Skip to main content

cubecl_core/post_processing/
predicate.rs

1use alloc::vec::Vec;
2use core::{f32, f64};
3
4use crate as cubecl;
5use cubecl_ir::{
6    Allocator, Comparison, ElemType, FloatKind, Instruction, ManagedVariable, Operation, Processor,
7    Scope, ScopeProcessing, UIntKind, Variable,
8};
9use half::{bf16, f16};
10
11use crate::prelude::*;
12
13define_scalar!(ElemA);
14define_scalar!(IntB);
15define_size!(SizeA);
16
/// Scope processor that replaces `Comparison::IsNan` and `Comparison::IsInf` instructions
/// with equivalent instruction sequences that inspect the float's raw bit pattern.
/// Every other instruction is passed through unchanged.
#[derive(Default, Debug)]
pub struct PredicateProcessor;
19
20impl Processor for PredicateProcessor {
21    fn transform(
22        &self,
23        mut processing: cubecl_ir::ScopeProcessing,
24        allocator: Allocator,
25    ) -> cubecl_ir::ScopeProcessing {
26        let mut instructions = Vec::new();
27        core::mem::swap(&mut processing.instructions, &mut instructions);
28
29        for instruction in instructions {
30            if let Operation::Comparison(comparison) = &instruction.operation {
31                match comparison {
32                    Comparison::IsNan(op) => {
33                        run_polyfill(
34                            &mut processing,
35                            op.input,
36                            instruction.out(),
37                            &allocator,
38                            is_nan::expand::<ElemA, IntB, SizeA>,
39                        );
40                        continue;
41                    }
42                    Comparison::IsInf(op) => {
43                        run_polyfill(
44                            &mut processing,
45                            op.input,
46                            instruction.out(),
47                            &allocator,
48                            is_inf::expand::<ElemA, IntB, SizeA>,
49                        );
50                        continue;
51                    }
52                    _ => {}
53                }
54            }
55            processing.instructions.push(instruction);
56        }
57        processing
58    }
59}
60
61fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
62    processing: &mut ScopeProcessing,
63    input: Variable,
64    out: Variable,
65    allocator: &Allocator,
66    mut polyfill: impl FnMut(&mut Scope, NativeExpand<T>, u32, u32) -> NativeExpand<O>,
67) {
68    let input = ManagedVariable::Plain(input);
69    let mut scope = Scope::root(false)
70        .with_allocator(allocator.clone())
71        .with_types(processing.typemap.clone());
72    scope.register_type::<ElemA>(input.storage_type());
73    scope.register_size::<SizeA>(input.vector_size());
74
75    let out_poly = if let ElemType::Float(kind) = input.elem_type() {
76        let (unsigned_ty, bit_width, mantissa_bits) = match kind {
77            FloatKind::F64 => (
78                UIntKind::U64,
79                f64::size_bits().unwrap(),
80                f64::MANTISSA_DIGITS - 1,
81            ),
82            FloatKind::F32 => (
83                UIntKind::U32,
84                f32::size_bits().unwrap(),
85                f32::MANTISSA_DIGITS - 1,
86            ),
87            FloatKind::F16 => (
88                UIntKind::U16,
89                f16::size_bits().unwrap(),
90                f16::MANTISSA_DIGITS - 1,
91            ),
92            FloatKind::BF16 => (
93                UIntKind::U16,
94                bf16::size_bits().unwrap(),
95                bf16::MANTISSA_DIGITS - 1,
96            ),
97            _ => unreachable!(),
98        };
99        scope.register_type::<IntB>(ElemType::UInt(unsigned_ty).into());
100
101        let exp_bits = bit_width as u32 - mantissa_bits - 1;
102
103        polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
104    } else {
105        panic!("Should be float")
106    };
107
108    let tmp_processing = scope.process([]);
109
110    processing.instructions.extend(tmp_processing.instructions);
111    processing.variables.extend(tmp_processing.variables);
112
113    processing
114        .instructions
115        .push(Instruction::new(Operation::Copy(*out_poly), out));
116}
117
118#[cube]
119fn is_nan<F: Float, U: Int, N: Size>(
120    x: Vector<F, N>,
121    #[comptime] mantissa_bits: u32,
122    #[comptime] exp_bits: u32,
123) -> Vector<bool, N> {
124    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
125    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
126    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
127
128    let bits: Vector<U, N> = Vector::<U, N>::reinterpret(x);
129
130    let abs_bits = bits & Vector::new(U::cast_from(abs_mask));
131
132    abs_bits.greater_than(Vector::new(U::cast_from(inf_bits)))
133}
134
135// Same trick as NaN detection following IEEE 754, but check for all 0 bits equality
136#[cube]
137fn is_inf<F: Float, U: Int, N: Size>(
138    x: Vector<F, N>,
139    #[comptime] mantissa_bits: u32,
140    #[comptime] exp_bits: u32,
141) -> Vector<bool, N> {
142    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
143    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
144    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
145
146    let bits: Vector<U, N> = Vector::<U, N>::reinterpret(x);
147
148    let abs_bits = bits & Vector::new(U::cast_from(abs_mask));
149
150    abs_bits.equal(Vector::new(U::cast_from(inf_bits)))
151}