cubecl_core/post_processing/
saturating.rs1use crate as cubecl;
2use alloc::vec::Vec;
3use cubecl_ir::{
4 Allocator, Arithmetic, ElemType, Instruction, IntKind, ManagedVariable, Operation, Processor,
5 Scope, ScopeProcessing, StorageType, UIntKind, Variable,
6};
7
8use crate::prelude::*;
9
10define_scalar!(ElemA);
11define_scalar!(ElemB);
12define_size!(SizeA);
13
14#[derive(new, Debug)]
16pub struct SaturatingArithmeticProcessor {
17 replace_i32: bool,
20}
21
22impl Processor for SaturatingArithmeticProcessor {
23 fn transform(
24 &self,
25 mut processing: cubecl_ir::ScopeProcessing,
26 allocator: Allocator,
27 ) -> cubecl_ir::ScopeProcessing {
28 let mut instructions = Vec::new();
29 core::mem::swap(&mut processing.instructions, &mut instructions);
30
31 for instruction in instructions {
32 if let Operation::Arithmetic(arithmetic) = &instruction.operation {
33 match arithmetic {
34 Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
35 run_polyfill(
36 &mut processing,
37 op.lhs,
38 op.rhs,
39 instruction.out(),
40 &allocator,
41 saturating_add_unsigned::expand::<ElemA, SizeA>,
42 );
43 continue;
44 }
45 Arithmetic::SaturatingAdd(op)
46 if op.lhs.elem_type().is_signed_int()
47 && self.should_replace(op.lhs.storage_type()) =>
48 {
49 run_polyfill(
50 &mut processing,
51 op.lhs,
52 op.rhs,
53 instruction.out(),
54 &allocator,
55 saturating_add_signed::expand::<ElemA, ElemB, SizeA>,
56 );
57 continue;
58 }
59 Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
60 run_polyfill(
61 &mut processing,
62 op.lhs,
63 op.rhs,
64 instruction.out(),
65 &allocator,
66 saturating_sub_unsigned::expand::<ElemA, SizeA>,
67 );
68 continue;
69 }
70 Arithmetic::SaturatingSub(op)
71 if op.lhs.elem_type().is_signed_int()
72 && self.should_replace(op.lhs.storage_type()) =>
73 {
74 run_polyfill(
75 &mut processing,
76 op.lhs,
77 op.rhs,
78 instruction.out(),
79 &allocator,
80 saturating_sub_signed::expand::<ElemA, ElemB, SizeA>,
81 );
82 continue;
83 }
84 _ => {}
85 }
86 }
87
88 processing.instructions.push(instruction);
90 }
91 processing
92 }
93}
94
95impl SaturatingArithmeticProcessor {
96 fn should_replace(&self, ty: StorageType) -> bool {
97 self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
98 }
99}
100
101fn run_polyfill<T: CubePrimitive>(
102 processing: &mut ScopeProcessing,
103 lhs: Variable,
104 rhs: Variable,
105 out: Variable,
106 allocator: &Allocator,
107 mut polyfill: impl FnMut(&mut Scope, NativeExpand<T>, NativeExpand<T>) -> NativeExpand<T>,
108) {
109 let lhs = ManagedVariable::Plain(lhs);
110 let rhs = ManagedVariable::Plain(rhs);
111 let mut scope = Scope::root(false)
112 .with_allocator(allocator.clone())
113 .with_types(processing.typemap.clone());
114 scope.register_type::<ElemA>(lhs.storage_type());
115 scope.register_size::<SizeA>(lhs.vector_size());
116 if let ElemType::Int(kind) = lhs.elem_type() {
117 let unsigned_ty = match kind {
118 IntKind::I8 => UIntKind::U8,
119 IntKind::I16 => UIntKind::U16,
120 IntKind::I32 => UIntKind::U32,
121 IntKind::I64 => UIntKind::U64,
122 };
123 scope.register_type::<ElemB>(ElemType::UInt(unsigned_ty).into())
124 }
125
126 let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
127 let tmp_processing = scope.process([]);
128
129 for inst in tmp_processing.instructions {
130 processing.instructions.push(inst);
131 }
132 for var in tmp_processing.variables {
133 processing.variables.push(var);
134 }
135
136 processing
137 .instructions
138 .push(Instruction::new(Operation::Copy(*out_poly), out));
139}
140
141#[cube]
142fn saturating_add_unsigned<U: Int, N: Size>(a: Vector<U, N>, b: Vector<U, N>) -> Vector<U, N> {
143 let c = a.min(!b);
144 c + b
145}
146
147#[cube]
148fn saturating_sub_unsigned<U: Int, N: Size>(a: Vector<U, N>, b: Vector<U, N>) -> Vector<U, N> {
149 let a = a.max(b);
150 a - b
151}
152
153#[cube]
156fn saturating_add_signed<I: Int, U: Int, N: Size>(
157 x: Vector<I, N>,
158 y: Vector<I, N>,
159) -> Vector<I, N> {
160 let bit_width = I::type_size_bits();
161 let shift = Vector::<U, N>::new(U::new(comptime![(bit_width - 1) as i64]));
162
163 let ux = Vector::<U, N>::cast_from(x);
164 let uy = Vector::<U, N>::cast_from(y);
165 let res = ux + uy;
166 let ux = (ux >> shift) + Vector::<U, N>::cast_from(I::max_value());
167 let cond =
168 Vector::<I, N>::cast_from((ux ^ uy) | !(uy ^ res)).greater_equal(Vector::new(I::new(0)));
169 select_many(cond, Vector::cast_from(ux), Vector::cast_from(res))
170}
171
172#[cube]
175fn saturating_sub_signed<I: Int, U: Int, N: Size>(
176 x: Vector<I, N>,
177 y: Vector<I, N>,
178) -> Vector<I, N> {
179 let bit_width = I::type_size_bits();
180 let shift = Vector::<U, N>::new(U::new(comptime![(bit_width - 1) as i64]));
181
182 let ux = Vector::<U, N>::cast_from(x);
183 let uy = Vector::<U, N>::cast_from(y);
184 let res = ux - uy;
185 let ux = (ux >> shift) + Vector::<U, N>::cast_from(I::max_value());
186 let cond = Vector::<I, N>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Vector::new(I::new(0)));
187 select_many(cond, Vector::cast_from(ux), Vector::cast_from(res))
188}