// cubecl_core/post_processing/saturating.rs

use crate as cubecl;
use alloc::vec::Vec;
use cubecl_ir::{
    Allocator, Arithmetic, ElemType, ExpandElement, Instruction, IntKind, Operation, Processor,
    Scope, ScopeProcessing, StorageType, UIntKind, Variable,
};

use crate::prelude::*;

10/// Replaces saturating arithmetic with a performant polyfill
11#[derive(new, Debug)]
12pub struct SaturatingArithmeticProcessor {
13    /// Whether to replace i32 saturating sub. Used for CUDA, because there's a more performant
14    /// PTX intrinsic for that specific type.
15    replace_i32: bool,
16}
17
18impl Processor for SaturatingArithmeticProcessor {
19    fn transform(
20        &self,
21        mut processing: cubecl_ir::ScopeProcessing,
22        allocator: Allocator,
23    ) -> cubecl_ir::ScopeProcessing {
24        let mut instructions = Vec::new();
25        core::mem::swap(&mut processing.instructions, &mut instructions);
26
27        for instruction in instructions {
28            if let Operation::Arithmetic(arithmetic) = &instruction.operation {
29                match arithmetic {
30                    Arithmetic::SaturatingAdd(op) if op.lhs.elem_type().is_unsigned_int() => {
31                        run_polyfill(
32                            &mut processing,
33                            op.lhs,
34                            op.rhs,
35                            instruction.out(),
36                            &allocator,
37                            saturating_add_unsigned::expand::<IntExpand<0>>,
38                        );
39                        continue;
40                    }
41                    Arithmetic::SaturatingAdd(op)
42                        if op.lhs.elem_type().is_signed_int()
43                            && self.should_replace(op.lhs.storage_type()) =>
44                    {
45                        run_polyfill(
46                            &mut processing,
47                            op.lhs,
48                            op.rhs,
49                            instruction.out(),
50                            &allocator,
51                            saturating_add_signed::expand::<IntExpand<0>, IntExpand<1>>,
52                        );
53                        continue;
54                    }
55                    Arithmetic::SaturatingSub(op) if op.lhs.elem_type().is_unsigned_int() => {
56                        run_polyfill(
57                            &mut processing,
58                            op.lhs,
59                            op.rhs,
60                            instruction.out(),
61                            &allocator,
62                            saturating_sub_unsigned::expand::<IntExpand<0>>,
63                        );
64                        continue;
65                    }
66                    Arithmetic::SaturatingSub(op)
67                        if op.lhs.elem_type().is_signed_int()
68                            && self.should_replace(op.lhs.storage_type()) =>
69                    {
70                        run_polyfill(
71                            &mut processing,
72                            op.lhs,
73                            op.rhs,
74                            instruction.out(),
75                            &allocator,
76                            saturating_sub_signed::expand::<IntExpand<0>, IntExpand<1>>,
77                        );
78                        continue;
79                    }
80                    _ => {}
81                }
82            }
83
84            // When we have nothing to do.
85            processing.instructions.push(instruction);
86        }
87        processing
88    }
89}
90
91impl SaturatingArithmeticProcessor {
92    fn should_replace(&self, ty: StorageType) -> bool {
93        self.replace_i32 || !matches!(ty, StorageType::Scalar(ElemType::Int(IntKind::I32)))
94    }
95}
96
97fn run_polyfill<T: CubePrimitive>(
98    processing: &mut ScopeProcessing,
99    lhs: Variable,
100    rhs: Variable,
101    out: Variable,
102    allocator: &Allocator,
103    mut polyfill: impl FnMut(
104        &mut Scope,
105        ExpandElementTyped<T>,
106        ExpandElementTyped<T>,
107    ) -> ExpandElementTyped<T>,
108) {
109    let lhs = ExpandElement::Plain(lhs);
110    let rhs = ExpandElement::Plain(rhs);
111    let mut scope = Scope::root(false)
112        .with_allocator(allocator.clone())
113        .with_types(processing.typemap.clone());
114    scope.register_type::<IntExpand<0>>(lhs.storage_type());
115    if let ElemType::Int(kind) = lhs.elem_type() {
116        let unsigned_ty = match kind {
117            IntKind::I8 => UIntKind::U8,
118            IntKind::I16 => UIntKind::U16,
119            IntKind::I32 => UIntKind::U32,
120            IntKind::I64 => UIntKind::U64,
121        };
122        scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into())
123    }
124
125    let out_poly = polyfill(&mut scope, lhs.into(), rhs.into()).expand;
126    let tmp_processing = scope.process([]);
127
128    for inst in tmp_processing.instructions {
129        processing.instructions.push(inst);
130    }
131    for var in tmp_processing.variables {
132        processing.variables.push(var);
133    }
134
135    processing
136        .instructions
137        .push(Instruction::new(Operation::Copy(*out_poly), out));
138}
139
140#[cube]
141fn saturating_add_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
142    let c = a.min(!b);
143    c + b
144}
145
146#[cube]
147fn saturating_sub_unsigned<U: Int>(a: Line<U>, b: Line<U>) -> Line<U> {
148    let a = a.max(b);
149    a - b
150}
151
152/// Don't ask me how this works
153/// <https://locklessinc.com/articles/sat_arithmetic/>
154#[cube]
155fn saturating_add_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
156    let bit_width = I::type_size_bits();
157    let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
158
159    let ux = Line::<U>::cast_from(x);
160    let uy = Line::<U>::cast_from(y);
161    let res = ux + uy;
162    let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
163    let cond = Line::<I>::cast_from((ux ^ uy) | !(uy ^ res)).greater_equal(Line::new(I::new(0)));
164    select_many(cond, Line::cast_from(ux), Line::cast_from(res))
165}
166
167/// Don't ask me how this works
168/// <https://locklessinc.com/articles/sat_arithmetic/>
169#[cube]
170fn saturating_sub_signed<I: Int, U: Int>(x: Line<I>, y: Line<I>) -> Line<I> {
171    let bit_width = I::type_size_bits();
172    let shift = Line::<U>::new(U::new(comptime![(bit_width - 1) as i64]));
173
174    let ux = Line::<U>::cast_from(x);
175    let uy = Line::<U>::cast_from(y);
176    let res = ux - uy;
177    let ux = (ux >> shift) + Line::<U>::cast_from(I::max_value());
178    let cond = Line::<I>::cast_from((ux ^ uy) & (ux ^ res)).less_than(Line::new(I::new(0)));
179    select_many(cond, Line::cast_from(ux), Line::cast_from(res))
180}