cubecl_core/post_processing/saturating.rs

1use crate as cubecl;
2use cubecl_ir::{
3    Allocator, Arithmetic, ElemType, ExpandElement, Instruction, IntKind, Operation, Processor,
4    Scope, ScopeProcessing, StorageType, UIntKind, Variable,
5};
6
7use crate::prelude::*;
8
9/// Replaces saturating arithmetic with a performant polyfill
10#[derive(new, Debug)]
11pub struct SaturatingArithmeticProcessor {
12    /// Whether to replace i32 saturating sub. Used for CUDA, because there's a more performant
13    /// PTX intrinsic for that specific type.
14    replace_i32: bool,
15}
16
17impl Processor for SaturatingArithmeticProcessor {
18    fn transform(
19        &self,
20        mut processing: cubecl_ir::ScopeProcessing,
21        allocator: Allocator,
22    ) -> cubecl_ir::ScopeProcessing {
23        let mut instructions = Vec::new();
24        core::mem::swap(&mut processing.instructions, &mut instructions);
25
26        for instruction in instructions {
27            if let Operation::Arithmetic(arithmetic) = &instruction.operation {
28                match arithmetic {
29                    Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
30                        run_polyfill(
31                            &mut processing,
32                            op.lhs,
33                            op.rhs,
34                            instruction.out(),
35                            &allocator,
36                            saturating_add_unsigned::expand::<IntExpand<0>>,
37                        );
38                        continue;
39                    }
40                    Arithmetic::SaturatingAdd(op)
41                        if op.lhs.elem_type().is_signed_int()
42                            && self.should_replace(op.lhs.storage_type()) =>
43                    {
44                        run_polyfill(
45                            &mut processing,
46                            op.lhs,
47                            op.rhs,
48                            instruction.out(),
49                            &allocator,
50                            saturating_add_signed::expand::<IntExpand<0>, IntExpand<1>>,
51                        );
52                        continue;
53                    }
54                    Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
55                        run_polyfill(
56                            &mut processing,
57                            op.lhs,
58                            op.rhs,
59                            instruction.out(),
60                            &allocator,
61                            saturating_sub_unsigned::expand::<IntExpand<0>>,
62                        );
63                        continue;
64                    }
65                    Arithmetic::SaturatingSub(op)
66                        if op.lhs.elem_type().is_signed_int()
67                            && self.should_replace(op.lhs.storage_type()) =>
68                    {
69                        run_polyfill(
70                            &mut processing,
71                            op.lhs,
72                            op.rhs,
73                            instruction.out(),
74                            &allocator,
75                            saturating_sub_signed::expand::<IntExpand<0>, IntExpand<1>>,
76                        );
77                        continue;
78                    }
79                    _ => {}
80                }
81            }
82
83            // When we have nothing to do.
84            processing.instructions.push(instruction);
85        }
86        processing
87    }
88}
89
90impl SaturatingArithmeticProcessor {
91    fn should_replace(&self, ty: StorageType) -> bool {
92        self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
93    }
94}
95
96fn run_polyfill<T: CubePrimitive>(
97    processing: &mut ScopeProcessing,
98    lhs: Variable,
99    rhs: Variable,
100    out: Variable,
101    allocator: &Allocator,
102    mut polyfill: impl FnMut(
103        &mut Scope,
104        ExpandElementTyped<T>,
105        ExpandElementTyped<T>,
106    ) -> ExpandElementTyped<T>,
107) {
108    let lhs = ExpandElement::Plain(lhs);
109    let rhs = ExpandElement::Plain(rhs);
110    let mut scope = Scope::root(false).with_allocator(allocator.clone());
111    scope.register_type::<IntExpand<0>>(lhs.storage_type());
112    if let ElemType::Int(kind) = lhs.elem_type() {
113        let unsigned_ty = match kind {
114            IntKind::I8 => UIntKind::U8,
115            IntKind::I16 => UIntKind::U16,
116            IntKind::I32 => UIntKind::U32,
117            IntKind::I64 => UIntKind::U64,
118        };
119        scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into())
120    }
121
122    let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
123    let tmp_processing = scope.process([]);
124
125    for inst in tmp_processing.instructions {
126        processing.instructions.push(inst);
127    }
128    for var in tmp_processing.variables {
129        processing.variables.push(var);
130    }
131
132    processing
133        .instructions
134        .push(Instruction::new(Operation::Copy(*out_poly), out));
135}
136
137#[cube]
138fn saturating_add_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
139    let c = Line::<U>::min(a, Line::<U>::bitwise_not(b));
140    c + b
141}
142
143#[cube]
144fn saturating_sub_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
145    let a = Max::max(a, b);
146    a - b
147}
148
149/// Don't ask me how this works
150/// <https://locklessinc.com/articles/sat_arithmetic/>
151#[cube]
152fn saturating_add_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
153    let bit_width = I::elem_size_bits();
154    let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
155
156    let ux = Line::<U>::cast_from(x);
157    let uy = Line::<U>::cast_from(y);
158    let res = ux + uy;
159    let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
160    let cond = Line::<I>::cast_from((ux ^ uy) | BitwiseNot::bitwise_not(uy ^ res))
161        .greater_equal(Line::new(I::new(0)));
162    select_many(cond, Line::cast_from(ux), Line::cast_from(res))
163}
164
165/// Don't ask me how this works
166/// <https://locklessinc.com/articles/sat_arithmetic/>
167#[cube]
168fn saturating_sub_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
169    let bit_width = I::elem_size_bits();
170    let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
171
172    let ux = Line::<U>::cast_from(x);
173    let uy = Line::<U>::cast_from(y);
174    let res = ux - uy;
175    let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
176    let cond = Line::<I>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Line::new(I::new(0)));
177    select_many(cond, Line::cast_from(ux), Line::cast_from(res))
178}