cubecl_core/post_processing/
saturating.rs1use crate as cubecl;
2use cubecl_ir::{
3 Allocator, Arithmetic, ElemType, ExpandElement, Instruction, IntKind, Operation, Processor,
4 Scope, ScopeProcessing, StorageType, UIntKind, Variable,
5};
6
7use crate::prelude::*;
8
9#[derive(new, Debug)]
11pub struct SaturatingArithmeticProcessor {
12 replace_i32: bool,
15}
16
17impl Processor for SaturatingArithmeticProcessor {
18 fn transform(
19 &self,
20 mut processing: cubecl_ir::ScopeProcessing,
21 allocator: Allocator,
22 ) -> cubecl_ir::ScopeProcessing {
23 let mut instructions = Vec::new();
24 core::mem::swap(&mut processing.instructions, &mut instructions);
25
26 for instruction in instructions {
27 if let Operation::Arithmetic(arithmetic) = &instruction.operation {
28 match arithmetic {
29 Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
30 run_polyfill(
31 &mut processing,
32 op.lhs,
33 op.rhs,
34 instruction.out(),
35 &allocator,
36 saturating_add_unsigned::expand::<IntExpand<0>>,
37 );
38 continue;
39 }
40 Arithmetic::SaturatingAdd(op)
41 if op.lhs.elem_type().is_signed_int()
42 && self.should_replace(op.lhs.storage_type()) =>
43 {
44 run_polyfill(
45 &mut processing,
46 op.lhs,
47 op.rhs,
48 instruction.out(),
49 &allocator,
50 saturating_add_signed::expand::<IntExpand<0>, IntExpand<1>>,
51 );
52 continue;
53 }
54 Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
55 run_polyfill(
56 &mut processing,
57 op.lhs,
58 op.rhs,
59 instruction.out(),
60 &allocator,
61 saturating_sub_unsigned::expand::<IntExpand<0>>,
62 );
63 continue;
64 }
65 Arithmetic::SaturatingSub(op)
66 if op.lhs.elem_type().is_signed_int()
67 && self.should_replace(op.lhs.storage_type()) =>
68 {
69 run_polyfill(
70 &mut processing,
71 op.lhs,
72 op.rhs,
73 instruction.out(),
74 &allocator,
75 saturating_sub_signed::expand::<IntExpand<0>, IntExpand<1>>,
76 );
77 continue;
78 }
79 _ => {}
80 }
81 }
82
83 processing.instructions.push(instruction);
85 }
86 processing
87 }
88}
89
90impl SaturatingArithmeticProcessor {
91 fn should_replace(&self, ty: StorageType) -> bool {
92 self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
93 }
94}
95
96fn run_polyfill<T: CubePrimitive>(
97 processing: &mut ScopeProcessing,
98 lhs: Variable,
99 rhs: Variable,
100 out: Variable,
101 allocator: &Allocator,
102 mut polyfill: impl FnMut(
103 &mut Scope,
104 ExpandElementTyped<T>,
105 ExpandElementTyped<T>,
106 ) -> ExpandElementTyped<T>,
107) {
108 let lhs = ExpandElement::Plain(lhs);
109 let rhs = ExpandElement::Plain(rhs);
110 let mut scope = Scope::root(false).with_allocator(allocator.clone());
111 scope.register_type::<IntExpand<0>>(lhs.storage_type());
112 if let ElemType::Int(kind) = lhs.elem_type() {
113 let unsigned_ty = match kind {
114 IntKind::I8 => UIntKind::U8,
115 IntKind::I16 => UIntKind::U16,
116 IntKind::I32 => UIntKind::U32,
117 IntKind::I64 => UIntKind::U64,
118 };
119 scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into())
120 }
121
122 let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
123 let tmp_processing = scope.process([]);
124
125 for inst in tmp_processing.instructions {
126 processing.instructions.push(inst);
127 }
128 for var in tmp_processing.variables {
129 processing.variables.push(var);
130 }
131
132 processing
133 .instructions
134 .push(Instruction::new(Operation::Copy(*out_poly), out));
135}
136
137#[cube]
138fn saturating_add_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
139 let c = Line::<U>::min(a, Line::<U>::bitwise_not(b));
140 c + b
141}
142
143#[cube]
144fn saturating_sub_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
145 let a = Max::max(a, b);
146 a - b
147}
148
149#[cube]
152fn saturating_add_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
153 let bit_width = I::elem_size_bits();
154 let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
155
156 let ux = Line::<U>::cast_from(x);
157 let uy = Line::<U>::cast_from(y);
158 let res = ux + uy;
159 let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
160 let cond = Line::<I>::cast_from((ux ^ uy) | BitwiseNot::bitwise_not(uy ^ res))
161 .greater_equal(Line::new(I::new(0)));
162 select_many(cond, Line::cast_from(ux), Line::cast_from(res))
163}
164
165#[cube]
168fn saturating_sub_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
169 let bit_width = I::elem_size_bits();
170 let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
171
172 let ux = Line::<U>::cast_from(x);
173 let uy = Line::<U>::cast_from(y);
174 let res = ux - uy;
175 let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
176 let cond = Line::<I>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Line::new(I::new(0)));
177 select_many(cond, Line::cast_from(ux), Line::cast_from(res))
178}