cubecl_core/post_processing/
saturating.rs1use crate as cubecl;
2use cubecl_ir::{
3 Allocator, Arithmetic, ElemType, ExpandElement, Instruction, IntKind, Operation, Processor,
4 Scope, ScopeProcessing, StorageType, UIntKind, Variable,
5};
6
7use crate::prelude::*;
8
9#[derive(new, Debug)]
11pub struct SaturatingArithmeticProcessor {
12 replace_i32: bool,
15}
16
17impl Processor for SaturatingArithmeticProcessor {
18 fn transform(
19 &self,
20 mut processing: cubecl_ir::ScopeProcessing,
21 allocator: Allocator,
22 ) -> cubecl_ir::ScopeProcessing {
23 let mut instructions = Vec::new();
24 core::mem::swap(&mut processing.instructions, &mut instructions);
25
26 for instruction in instructions {
27 if let Operation::Arithmetic(arithmetic) = &instruction.operation {
28 match arithmetic {
29 Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
30 run_polyfill(
31 &mut processing,
32 op.lhs,
33 op.rhs,
34 instruction.out(),
35 &allocator,
36 saturating_add_unsigned::expand::<IntExpand<0>>,
37 );
38 continue;
39 }
40 Arithmetic::SaturatingAdd(op)
41 if op.lhs.elem_type().is_signed_int()
42 && self.should_replace(op.lhs.storage_type()) =>
43 {
44 run_polyfill(
45 &mut processing,
46 op.lhs,
47 op.rhs,
48 instruction.out(),
49 &allocator,
50 saturating_add_signed::expand::<IntExpand<0>, IntExpand<1>>,
51 );
52 continue;
53 }
54 Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
55 run_polyfill(
56 &mut processing,
57 op.lhs,
58 op.rhs,
59 instruction.out(),
60 &allocator,
61 saturating_sub_unsigned::expand::<IntExpand<0>>,
62 );
63 continue;
64 }
65 Arithmetic::SaturatingSub(op)
66 if op.lhs.elem_type().is_signed_int()
67 && self.should_replace(op.lhs.storage_type()) =>
68 {
69 run_polyfill(
70 &mut processing,
71 op.lhs,
72 op.rhs,
73 instruction.out(),
74 &allocator,
75 saturating_sub_signed::expand::<IntExpand<0>, IntExpand<1>>,
76 );
77 continue;
78 }
79 _ => {}
80 }
81 }
82
83 processing.instructions.push(instruction);
85 }
86 processing
87 }
88}
89
90impl SaturatingArithmeticProcessor {
91 fn should_replace(&self, ty: StorageType) -> bool {
92 self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
93 }
94}
95
96fn run_polyfill<T: CubePrimitive>(
97 processing: &mut ScopeProcessing,
98 lhs: Variable,
99 rhs: Variable,
100 out: Variable,
101 allocator: &Allocator,
102 mut polyfill: impl FnMut(
103 &mut Scope,
104 ExpandElementTyped<T>,
105 ExpandElementTyped<T>,
106 ) -> ExpandElementTyped<T>,
107) {
108 let lhs = ExpandElement::Plain(lhs);
109 let rhs = ExpandElement::Plain(rhs);
110 let mut scope = Scope::root(false)
111 .with_allocator(allocator.clone())
112 .with_types(processing.typemap.clone());
113 scope.register_type::<IntExpand<0>>(lhs.storage_type());
114 if let ElemType::Int(kind) = lhs.elem_type() {
115 let unsigned_ty = match kind {
116 IntKind::I8 => UIntKind::U8,
117 IntKind::I16 => UIntKind::U16,
118 IntKind::I32 => UIntKind::U32,
119 IntKind::I64 => UIntKind::U64,
120 };
121 scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into())
122 }
123
124 let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
125 let tmp_processing = scope.process([]);
126
127 for inst in tmp_processing.instructions {
128 processing.instructions.push(inst);
129 }
130 for var in tmp_processing.variables {
131 processing.variables.push(var);
132 }
133
134 processing
135 .instructions
136 .push(Instruction::new(Operation::Copy(*out_poly), out));
137}
138
139#[cube]
140fn saturating_add_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
141 let c = a.min(!b);
142 c + b
143}
144
145#[cube]
146fn saturating_sub_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
147 let a = a.max(b);
148 a - b
149}
150
151#[cube]
154fn saturating_add_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
155 let bit_width = I::type_size_bits();
156 let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
157
158 let ux = Line::<U>::cast_from(x);
159 let uy = Line::<U>::cast_from(y);
160 let res = ux + uy;
161 let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
162 let cond = Line::<I>::cast_from((ux ^ uy) | !(uy ^ res)).greater_equal(Line::new(I::new(0)));
163 select_many(cond, Line::cast_from(ux), Line::cast_from(res))
164}
165
166#[cube]
169fn saturating_sub_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
170 let bit_width = I::type_size_bits();
171 let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
172
173 let ux = Line::<U>::cast_from(x);
174 let uy = Line::<U>::cast_from(y);
175 let res = ux - uy;
176 let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
177 let cond = Line::<I>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Line::new(I::new(0)));
178 select_many(cond, Line::cast_from(ux), Line::cast_from(res))
179}