Skip to main content

cubecl_core/post_processing/
predicate.rs

1use alloc::vec::Vec;
2use core::{f32, f64};
3
4use crate as cubecl;
5use cubecl_ir::{
6    Allocator, Comparison, ElemType, ExpandElement, FloatKind, Instruction, Operation, Processor,
7    Scope, ScopeProcessing, UIntKind, Variable,
8};
9use half::{bf16, f16};
10
11use crate::prelude::*;
12
13#[derive(Debug, Default)]
14pub struct PredicateProcessor;
15
16impl Processor for PredicateProcessor {
17    fn transform(
18        &self,
19        mut processing: cubecl_ir::ScopeProcessing,
20        allocator: Allocator,
21    ) -> cubecl_ir::ScopeProcessing {
22        let mut instructions = Vec::new();
23        core::mem::swap(&mut processing.instructions, &mut instructions);
24
25        for instruction in instructions {
26            if let Operation::Comparison(comparison) = &instruction.operation {
27                match comparison {
28                    Comparison::IsNan(op) => {
29                        run_polyfill(
30                            &mut processing,
31                            op.input,
32                            instruction.out(),
33                            &allocator,
34                            is_nan::expand::<FloatExpand<0>, IntExpand<1>>,
35                        );
36                        continue;
37                    }
38                    Comparison::IsInf(op) => {
39                        run_polyfill(
40                            &mut processing,
41                            op.input,
42                            instruction.out(),
43                            &allocator,
44                            is_inf::expand::<FloatExpand<0>, IntExpand<1>>,
45                        );
46                        continue;
47                    }
48                    _ => {}
49                }
50            }
51            processing.instructions.push(instruction);
52        }
53        processing
54    }
55}
56
57fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
58    processing: &mut ScopeProcessing,
59    input: Variable,
60    out: Variable,
61    allocator: &Allocator,
62    mut polyfill: impl FnMut(&mut Scope, ExpandElementTyped<T>, u32, u32) -> ExpandElementTyped<O>,
63) {
64    let input = ExpandElement::Plain(input);
65    let mut scope = Scope::root(false)
66        .with_allocator(allocator.clone())
67        .with_types(processing.typemap.clone());
68    scope.register_type::<FloatExpand<0>>(input.storage_type());
69
70    let out_poly = if let ElemType::Float(kind) = input.elem_type() {
71        let (unsigned_ty, bit_width, mantissa_bits) = match kind {
72            FloatKind::F64 => (
73                UIntKind::U64,
74                f64::size_bits().unwrap(),
75                f64::MANTISSA_DIGITS - 1,
76            ),
77            FloatKind::F32 => (
78                UIntKind::U32,
79                f32::size_bits().unwrap(),
80                f32::MANTISSA_DIGITS - 1,
81            ),
82            FloatKind::F16 => (
83                UIntKind::U16,
84                f16::size_bits().unwrap(),
85                f16::MANTISSA_DIGITS - 1,
86            ),
87            FloatKind::BF16 => (
88                UIntKind::U16,
89                bf16::size_bits().unwrap(),
90                bf16::MANTISSA_DIGITS - 1,
91            ),
92            _ => unreachable!(),
93        };
94        scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into());
95
96        let exp_bits = bit_width as u32 - mantissa_bits - 1;
97
98        polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
99    } else {
100        panic!("Should be float")
101    };
102
103    let tmp_processing = scope.process([]);
104
105    processing.instructions.extend(tmp_processing.instructions);
106    processing.variables.extend(tmp_processing.variables);
107
108    processing
109        .instructions
110        .push(Instruction::new(Operation::Copy(*out_poly), out));
111}
112
113#[cube]
114fn is_nan<F: Float, U: Int>(
115    x: Line<F>,
116    #[comptime] mantissa_bits: u32,
117    #[comptime] exp_bits: u32,
118) -> Line<bool> {
119    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
120    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
121    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
122
123    let bits: Line<U> = Line::<U>::reinterpret(x);
124
125    let abs_bits = bits & Line::new(U::cast_from(abs_mask));
126
127    abs_bits.greater_than(Line::new(U::cast_from(inf_bits)))
128}
129
130// Same trick as NaN detection following IEEE 754, but check for all 0 bits equality
131#[cube]
132fn is_inf<F: Float, U: Int>(
133    x: Line<F>,
134    #[comptime] mantissa_bits: u32,
135    #[comptime] exp_bits: u32,
136) -> Line<bool> {
137    // Need to mark as u64 otherwise it is coerced into i32 which does not fit the values for f64
138    let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
139    let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
140
141    let bits: Line<U> = Line::<U>::reinterpret(x);
142
143    let abs_bits = bits & Line::new(U::cast_from(abs_mask));
144
145    abs_bits.equal(Line::new(U::cast_from(inf_bits)))
146}