Skip to main content

cubecl_core/post_processing/
saturating.rs

1use crate as cubecl;
2use alloc::vec::Vec;
3use cubecl_ir::{
4    Allocator, Arithmetic, ElemType, Instruction, IntKind, ManagedVariable, Operation, Processor,
5    Scope, ScopeProcessing, StorageType, UIntKind, Variable,
6};
7
8use crate::prelude::*;
9
10define_scalar!(ElemA);
11define_scalar!(ElemB);
12define_size!(SizeA);
13
14/// Replaces saturating arithmetic with a performant polyfill
15#[derive(new, Debug)]
16pub struct SaturatingArithmeticProcessor {
17    /// Whether to replace i32 saturating sub. Used for CUDA, because there's a more performant
18    /// PTX intrinsic for that specific type.
19    replace_i32: bool,
20}
21
22impl Processor for SaturatingArithmeticProcessor {
23    fn transform(
24        &self,
25        mut processing: cubecl_ir::ScopeProcessing,
26        allocator: Allocator,
27    ) -> cubecl_ir::ScopeProcessing {
28        let mut instructions = Vec::new();
29        core::mem::swap(&mut processing.instructions, &mut instructions);
30
31        for instruction in instructions {
32            if let Operation::Arithmetic(arithmetic) = &instruction.operation {
33                match arithmetic {
34                    Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
35                        run_polyfill(
36                            &mut processing,
37                            op.lhs,
38                            op.rhs,
39                            instruction.out(),
40                            &allocator,
41                            saturating_add_unsigned::expand::<ElemA, SizeA>,
42                        );
43                        continue;
44                    }
45                    Arithmetic::SaturatingAdd(op)
46                        if op.lhs.elem_type().is_signed_int()
47                            && self.should_replace(op.lhs.storage_type()) =>
48                    {
49                        run_polyfill(
50                            &mut processing,
51                            op.lhs,
52                            op.rhs,
53                            instruction.out(),
54                            &allocator,
55                            saturating_add_signed::expand::<ElemA, ElemB, SizeA>,
56                        );
57                        continue;
58                    }
59                    Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
60                        run_polyfill(
61                            &mut processing,
62                            op.lhs,
63                            op.rhs,
64                            instruction.out(),
65                            &allocator,
66                            saturating_sub_unsigned::expand::<ElemA, SizeA>,
67                        );
68                        continue;
69                    }
70                    Arithmetic::SaturatingSub(op)
71                        if op.lhs.elem_type().is_signed_int()
72                            && self.should_replace(op.lhs.storage_type()) =>
73                    {
74                        run_polyfill(
75                            &mut processing,
76                            op.lhs,
77                            op.rhs,
78                            instruction.out(),
79                            &allocator,
80                            saturating_sub_signed::expand::<ElemA, ElemB, SizeA>,
81                        );
82                        continue;
83                    }
84                    _ => {}
85                }
86            }
87
88            // When we have nothing to do.
89            processing.instructions.push(instruction);
90        }
91        processing
92    }
93}
94
95impl SaturatingArithmeticProcessor {
96    fn should_replace(&self, ty: StorageType) -> bool {
97        self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
98    }
99}
100
101fn run_polyfill<T: CubePrimitive>(
102    processing: &mut ScopeProcessing,
103    lhs: Variable,
104    rhs: Variable,
105    out: Variable,
106    allocator: &Allocator,
107    mut polyfill: impl FnMut(&mut Scope, NativeExpand<T>, NativeExpand<T>) -> NativeExpand<T>,
108) {
109    let lhs = ManagedVariable::Plain(lhs);
110    let rhs = ManagedVariable::Plain(rhs);
111    let mut scope = Scope::root(false)
112        .with_allocator(allocator.clone())
113        .with_types(processing.typemap.clone());
114    scope.register_type::<ElemA>(lhs.storage_type());
115    scope.register_size::<SizeA>(lhs.vector_size());
116    if let ElemType::Int(kind) = lhs.elem_type() {
117        let unsigned_ty = match kind {
118            IntKind::I8 => UIntKind::U8,
119            IntKind::I16 => UIntKind::U16,
120            IntKind::I32 => UIntKind::U32,
121            IntKind::I64 => UIntKind::U64,
122        };
123        scope.register_type::<ElemB>(ElemType::UInt(unsigned_ty).into())
124    }
125
126    let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
127    let tmp_processing = scope.process([]);
128
129    for inst in tmp_processing.instructions {
130        processing.instructions.push(inst);
131    }
132    for var in tmp_processing.variables {
133        processing.variables.push(var);
134    }
135
136    processing
137        .instructions
138        .push(Instruction::new(Operation::Copy(*out_poly), out));
139}
140
141#[cube]
142fn saturating_add_unsigned<U: Int, N: Size>(a: Vector<U, N>, b: Vector<U, N>) -> Vector<U, N> {
143    let c = a.min(!b);
144    c + b
145}
146
147#[cube]
148fn saturating_sub_unsigned<U: Int, N: Size>(a: Vector<U, N>, b: Vector<U, N>) -> Vector<U, N> {
149    let a = a.max(b);
150    a - b
151}
152
153/// Don't ask me how this works
154/// <https://locklessinc.com/articles/sat_arithmetic/>
155#[cube]
156fn saturating_add_signed<I: Int, U: Int, N: Size>(
157    x: Vector<I, N>,
158    y: Vector<I, N>,
159) -> Vector<I, N> {
160    let bit_width = I::type_size_bits();
161    let shift = Vector::<U, N>::new(U::new(comptime![(bit_width - 1) as i64]));
162
163    let ux = Vector::<U, N>::cast_from(x);
164    let uy = Vector::<U, N>::cast_from(y);
165    let res = ux + uy;
166    let ux = (ux >> shift) + Vector::<U, N>::cast_from(I::max_value());
167    let cond =
168        Vector::<I, N>::cast_from((ux ^ uy) | !(uy ^ res)).greater_equal(Vector::new(I::new(0)));
169    select_many(cond, Vector::cast_from(ux), Vector::cast_from(res))
170}
171
172/// Don't ask me how this works
173/// <https://locklessinc.com/articles/sat_arithmetic/>
174#[cube]
175fn saturating_sub_signed<I: Int, U: Int, N: Size>(
176    x: Vector<I, N>,
177    y: Vector<I, N>,
178) -> Vector<I, N> {
179    let bit_width = I::type_size_bits();
180    let shift = Vector::<U, N>::new(U::new(comptime![(bit_width - 1) as i64]));
181
182    let ux = Vector::<U, N>::cast_from(x);
183    let uy = Vector::<U, N>::cast_from(y);
184    let res = ux - uy;
185    let ux = (ux >> shift) + Vector::<U, N>::cast_from(I::max_value());
186    let cond = Vector::<I, N>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Vector::new(I::new(0)));
187    select_many(cond, Vector::cast_from(ux), Vector::cast_from(res))
188}