cubecl_core/post_processing/
saturating.rs1use crate as cubecl;
2use alloc::vec::Vec;
3use cubecl_ir::{
4 Allocator, Arithmetic, ElemType, ExpandElement, Instruction, IntKind, Operation, Processor,
5 Scope, ScopeProcessing, StorageType, UIntKind, Variable,
6};
7
8use crate::prelude::*;
9
10#[derive(new, Debug)]
12pub struct SaturatingArithmeticProcessor {
13 replace_i32: bool,
16}
17
18impl Processor for SaturatingArithmeticProcessor {
19 fn transform(
20 &self,
21 mut processing: cubecl_ir::ScopeProcessing,
22 allocator: Allocator,
23 ) -> cubecl_ir::ScopeProcessing {
24 let mut instructions = Vec::new();
25 core::mem::swap(&mut processing.instructions, &mut instructions);
26
27 for instruction in instructions {
28 if let Operation::Arithmetic(arithmetic) = &instruction.operation {
29 match arithmetic {
30 Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
31 run_polyfill(
32 &mut processing,
33 op.lhs,
34 op.rhs,
35 instruction.out(),
36 &allocator,
37 saturating_add_unsigned::expand::<IntExpand<0>>,
38 );
39 continue;
40 }
41 Arithmetic::SaturatingAdd(op)
42 if op.lhs.elem_type().is_signed_int()
43 && self.should_replace(op.lhs.storage_type()) =>
44 {
45 run_polyfill(
46 &mut processing,
47 op.lhs,
48 op.rhs,
49 instruction.out(),
50 &allocator,
51 saturating_add_signed::expand::<IntExpand<0>, IntExpand<1>>,
52 );
53 continue;
54 }
55 Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
56 run_polyfill(
57 &mut processing,
58 op.lhs,
59 op.rhs,
60 instruction.out(),
61 &allocator,
62 saturating_sub_unsigned::expand::<IntExpand<0>>,
63 );
64 continue;
65 }
66 Arithmetic::SaturatingSub(op)
67 if op.lhs.elem_type().is_signed_int()
68 && self.should_replace(op.lhs.storage_type()) =>
69 {
70 run_polyfill(
71 &mut processing,
72 op.lhs,
73 op.rhs,
74 instruction.out(),
75 &allocator,
76 saturating_sub_signed::expand::<IntExpand<0>, IntExpand<1>>,
77 );
78 continue;
79 }
80 _ => {}
81 }
82 }
83
84 processing.instructions.push(instruction);
86 }
87 processing
88 }
89}
90
91impl SaturatingArithmeticProcessor {
92 fn should_replace(&self, ty: StorageType) -> bool {
93 self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
94 }
95}
96
97fn run_polyfill<T: CubePrimitive>(
98 processing: &mut ScopeProcessing,
99 lhs: Variable,
100 rhs: Variable,
101 out: Variable,
102 allocator: &Allocator,
103 mut polyfill: impl FnMut(
104 &mut Scope,
105 ExpandElementTyped<T>,
106 ExpandElementTyped<T>,
107 ) -> ExpandElementTyped<T>,
108) {
109 let lhs = ExpandElement::Plain(lhs);
110 let rhs = ExpandElement::Plain(rhs);
111 let mut scope = Scope::root(false)
112 .with_allocator(allocator.clone())
113 .with_types(processing.typemap.clone());
114 scope.register_type::<IntExpand<0>>(lhs.storage_type());
115 if let ElemType::Int(kind) = lhs.elem_type() {
116 let unsigned_ty = match kind {
117 IntKind::I8 => UIntKind::U8,
118 IntKind::I16 => UIntKind::U16,
119 IntKind::I32 => UIntKind::U32,
120 IntKind::I64 => UIntKind::U64,
121 };
122 scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into())
123 }
124
125 let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
126 let tmp_processing = scope.process([]);
127
128 for inst in tmp_processing.instructions {
129 processing.instructions.push(inst);
130 }
131 for var in tmp_processing.variables {
132 processing.variables.push(var);
133 }
134
135 processing
136 .instructions
137 .push(Instruction::new(Operation::Copy(*out_poly), out));
138}
139
140#[cube]
141fn saturating_add_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
142 let c = a.min(!b);
143 c + b
144}
145
146#[cube]
147fn saturating_sub_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
148 let a = a.max(b);
149 a - b
150}
151
152#[cube]
155fn saturating_add_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
156 let bit_width = I::type_size_bits();
157 let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
158
159 let ux = Line::<U>::cast_from(x);
160 let uy = Line::<U>::cast_from(y);
161 let res = ux + uy;
162 let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
163 let cond = Line::<I>::cast_from((ux ^ uy) | !(uy ^ res)).greater_equal(Line::new(I::new(0)));
164 select_many(cond, Line::cast_from(ux), Line::cast_from(res))
165}
166
167#[cube]
170fn saturating_sub_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
171 let bit_width = I::type_size_bits();
172 let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
173
174 let ux = Line::<U>::cast_from(x);
175 let uy = Line::<U>::cast_from(y);
176 let res = ux - uy;
177 let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
178 let cond = Line::<I>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Line::new(I::new(0)));
179 select_many(cond, Line::cast_from(ux), Line::cast_from(res))
180}