// cubecl_core/post_processing/saturating.rs

1use crate as cubecl;
2use cubecl_ir::{
3    Allocator, Arithmetic, ElemType, ExpandElement, Instruction, IntKind, Operation, Processor,
4    Scope, ScopeProcessing, StorageType, UIntKind, Variable,
5};
6
7use crate::prelude::*;
8
9/// Replaces saturating arithmetic with a performant polyfill
10#[derive(new, Debug)]
11pub struct SaturatingArithmeticProcessor {
12    /// Whether to replace i32 saturating sub. Used for CUDA, because there's a more performant
13    /// PTX intrinsic for that specific type.
14    replace_i32: bool,
15}
16
17impl Processor for SaturatingArithmeticProcessor {
18    fn transform(
19        &self,
20        mut processing: cubecl_ir::ScopeProcessing,
21        allocator: Allocator,
22    ) -> cubecl_ir::ScopeProcessing {
23        let mut instructions = Vec::new();
24        core::mem::swap(&mut processing.instructions, &mut instructions);
25
26        for instruction in instructions {
27            if let Operation::Arithmetic(arithmetic) = &instruction.operation {
28                match arithmetic {
29                    Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
30                        run_polyfill(
31                            &mut processing,
32                            op.lhs,
33                            op.rhs,
34                            instruction.out(),
35                            &allocator,
36                            saturating_add_unsigned::expand::<IntExpand<0>>,
37                        );
38                        continue;
39                    }
40                    Arithmetic::SaturatingAdd(op)
41                        if op.lhs.elem_type().is_signed_int()
42                            && self.should_replace(op.lhs.storage_type()) =>
43                    {
44                        run_polyfill(
45                            &mut processing,
46                            op.lhs,
47                            op.rhs,
48                            instruction.out(),
49                            &allocator,
50                            saturating_add_signed::expand::<IntExpand<0>, IntExpand<1>>,
51                        );
52                        continue;
53                    }
54                    Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
55                        run_polyfill(
56                            &mut processing,
57                            op.lhs,
58                            op.rhs,
59                            instruction.out(),
60                            &allocator,
61                            saturating_sub_unsigned::expand::<IntExpand<0>>,
62                        );
63                        continue;
64                    }
65                    Arithmetic::SaturatingSub(op)
66                        if op.lhs.elem_type().is_signed_int()
67                            && self.should_replace(op.lhs.storage_type()) =>
68                    {
69                        run_polyfill(
70                            &mut processing,
71                            op.lhs,
72                            op.rhs,
73                            instruction.out(),
74                            &allocator,
75                            saturating_sub_signed::expand::<IntExpand<0>, IntExpand<1>>,
76                        );
77                        continue;
78                    }
79                    _ => {}
80                }
81            }
82
83            // When we have nothing to do.
84            processing.instructions.push(instruction);
85        }
86        processing
87    }
88}
89
90impl SaturatingArithmeticProcessor {
91    fn should_replace(&self, ty: StorageType) -> bool {
92        self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
93    }
94}
95
96fn run_polyfill<T: CubePrimitive>(
97    processing: &mut ScopeProcessing,
98    lhs: Variable,
99    rhs: Variable,
100    out: Variable,
101    allocator: &Allocator,
102    mut polyfill: impl FnMut(
103        &mut Scope,
104        ExpandElementTyped<T>,
105        ExpandElementTyped<T>,
106    ) -> ExpandElementTyped<T>,
107) {
108    let lhs = ExpandElement::Plain(lhs);
109    let rhs = ExpandElement::Plain(rhs);
110    let mut scope = Scope::root(false)
111        .with_allocator(allocator.clone())
112        .with_types(processing.typemap.clone());
113    scope.register_type::<IntExpand<0>>(lhs.storage_type());
114    if let ElemType::Int(kind) = lhs.elem_type() {
115        let unsigned_ty = match kind {
116            IntKind::I8 => UIntKind::U8,
117            IntKind::I16 => UIntKind::U16,
118            IntKind::I32 => UIntKind::U32,
119            IntKind::I64 => UIntKind::U64,
120        };
121        scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into())
122    }
123
124    let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
125    let tmp_processing = scope.process([]);
126
127    for inst in tmp_processing.instructions {
128        processing.instructions.push(inst);
129    }
130    for var in tmp_processing.variables {
131        processing.variables.push(var);
132    }
133
134    processing
135        .instructions
136        .push(Instruction::new(Operation::Copy(*out_poly), out));
137}
138
139#[cube]
140fn saturating_add_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
141    let c = a.min(!b);
142    c + b
143}
144
145#[cube]
146fn saturating_sub_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
147    let a = a.max(b);
148    a - b
149}
150
151/// Don't ask me how this works
152/// <https://locklessinc.com/articles/sat_arithmetic/>
153#[cube]
154fn saturating_add_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
155    let bit_width = I::type_size_bits();
156    let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
157
158    let ux = Line::<U>::cast_from(x);
159    let uy = Line::<U>::cast_from(y);
160    let res = ux + uy;
161    let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
162    let cond = Line::<I>::cast_from((ux ^ uy) | !(uy ^ res)).greater_equal(Line::new(I::new(0)));
163    select_many(cond, Line::cast_from(ux), Line::cast_from(res))
164}
165
166/// Don't ask me how this works
167/// <https://locklessinc.com/articles/sat_arithmetic/>
168#[cube]
169fn saturating_sub_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
170    let bit_width = I::type_size_bits();
171    let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
172
173    let ux = Line::<U>::cast_from(x);
174    let uy = Line::<U>::cast_from(y);
175    let res = ux - uy;
176    let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
177    let cond = Line::<I>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Line::new(I::new(0)));
178    select_many(cond, Line::cast_from(ux), Line::cast_from(res))
179}