cubecl_core/post_processing/
predicate.rs

1use core::{f32, f64};
2
3use crate as cubecl;
4use cubecl_ir::{
5    Allocator, Comparison, ElemType, ExpandElement, FloatKind, Instruction, Operation, Processor,
6    Scope, ScopeProcessing, UIntKind, Variable,
7};
8use half::{bf16, f16};
9
10use crate::prelude::*;
11
/// Post-processing pass that lowers `IsNan` / `IsInf` comparisons into integer bit tests.
///
/// The type carries no state; all of the work happens in its [`Processor`] implementation.
#[derive(Debug, Default)]
pub struct PredicateProcessor;
14
15impl Processor for PredicateProcessor {
16    fn transform(
17        &self,
18        mut processing: cubecl_ir::ScopeProcessing,
19        allocator: Allocator,
20    ) -> cubecl_ir::ScopeProcessing {
21        let mut instructions = Vec::new();
22        core::mem::swap(&mut processing.instructions, &mut instructions);
23
24        for instruction in instructions {
25            if let Operation::Comparison(comparison) = &instruction.operation {
26                match comparison {
27                    Comparison::IsNan(op) => {
28                        run_polyfill(
29                            &mut processing,
30                            op.input,
31                            instruction.out(),
32                            &allocator,
33                            is_nan::expand::<FloatExpand<0>, IntExpand<1>>,
34                        );
35                        continue;
36                    }
37                    Comparison::IsInf(op) => {
38                        run_polyfill(
39                            &mut processing,
40                            op.input,
41                            instruction.out(),
42                            &allocator,
43                            is_inf::expand::<FloatExpand<0>, IntExpand<1>>,
44                        );
45                        continue;
46                    }
47                    _ => {}
48                }
49            }
50            processing.instructions.push(instruction);
51        }
52        processing
53    }
54}
55
56fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
57    processing: &mut ScopeProcessing,
58    input: Variable,
59    out: Variable,
60    allocator: &Allocator,
61    mut polyfill: impl FnMut(&mut Scope, ExpandElementTyped<T>, u32, u32) -> ExpandElementTyped<O>,
62) {
63    let input = ExpandElement::Plain(input);
64    let mut scope = Scope::root(false)
65        .with_allocator(allocator.clone())
66        .with_types(processing.typemap.clone());
67    scope.register_type::<FloatExpand<0>>(input.storage_type());
68
69    let out_poly = if let ElemType::Float(kind) = input.elem_type() {
70        let (unsigned_ty, bit_width, mantissa_bits) = match kind {
71            FloatKind::F64 => (
72                UIntKind::U64,
73                f64::size_bits().unwrap(),
74                f64::MANTISSA_DIGITS - 1,
75            ),
76            FloatKind::F32 => (
77                UIntKind::U32,
78                f32::size_bits().unwrap(),
79                f32::MANTISSA_DIGITS - 1,
80            ),
81            FloatKind::F16 => (
82                UIntKind::U16,
83                f16::size_bits().unwrap(),
84                f16::MANTISSA_DIGITS - 1,
85            ),
86            FloatKind::BF16 => (
87                UIntKind::U16,
88                bf16::size_bits().unwrap(),
89                bf16::MANTISSA_DIGITS - 1,
90            ),
91            _ => unreachable!(),
92        };
93        scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into());
94
95        let exp_bits = bit_width as u32 - mantissa_bits - 1;
96
97        polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
98    } else {
99        panic!("Should be float")
100    };
101
102    let tmp_processing = scope.process([]);
103
104    processing.instructions.extend(tmp_processing.instructions);
105    processing.variables.extend(tmp_processing.variables);
106
107    processing
108        .instructions
109        .push(Instruction::new(Operation::Copy(*out_poly), out));
110}
111
112#[cube]
113fn is_nan<F: Float, U: Int>(
114    x: Line<F>,
115    #[comptime] mantissa_bits: u32,
116    #[comptime] exp_bits: u32,
117) -> Line<bool> {
118    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
119    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
120    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
121
122    let bits: Line<U> = Line::<U>::reinterpret(x);
123
124    let abs_bits = bits & Line::new(U::cast_from(abs_mask));
125
126    abs_bits.greater_than(Line::new(U::cast_from(inf_bits)))
127}
128
129// Same trick as NaN detection following IEEE 754, but check for all 0 bits equality
130#[cube]
131fn is_inf<F: Float, U: Int>(
132    x: Line<F>,
133    #[comptime] mantissa_bits: u32,
134    #[comptime] exp_bits: u32,
135) -> Line<bool> {
136    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
137    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
138    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
139
140    let bits: Line<U> = Line::<U>::reinterpret(x);
141
142    let abs_bits = bits & Line::new(U::cast_from(abs_mask));
143
144    abs_bits.equal(Line::new(U::cast_from(inf_bits)))
145}