cubecl_core/post_processing/
predicate.rs1use core::{f32, f64};
2
3use crate as cubecl;
4use cubecl_ir::{
5 Allocator, Comparison, ElemType, ExpandElement, FloatKind, Instruction, Operation, Processor,
6 Scope, ScopeProcessing, UIntKind, Variable,
7};
8use half::{bf16, f16};
9
10use crate::prelude::*;
11
/// Post-processing pass that replaces `IsNan` / `IsInf` comparison instructions with
/// integer bit tests on the reinterpreted float value (see its `Processor` impl).
#[derive(Default, Debug)]
pub struct PredicateProcessor;
14
15impl Processor for PredicateProcessor {
16 fn transform(
17 &self,
18 mut processing: cubecl_ir::ScopeProcessing,
19 allocator: Allocator,
20 ) -> cubecl_ir::ScopeProcessing {
21 let mut instructions = Vec::new();
22 core::mem::swap(&mut processing.instructions, &mut instructions);
23
24 for instruction in instructions {
25 if let Operation::Comparison(comparison) = &instruction.operation {
26 match comparison {
27 Comparison::IsNan(op) => {
28 run_polyfill(
29 &mut processing,
30 op.input,
31 instruction.out(),
32 &allocator,
33 is_nan::expand::<FloatExpand<0>, IntExpand<1>>,
34 );
35 continue;
36 }
37 Comparison::IsInf(op) => {
38 run_polyfill(
39 &mut processing,
40 op.input,
41 instruction.out(),
42 &allocator,
43 is_inf::expand::<FloatExpand<0>, IntExpand<1>>,
44 );
45 continue;
46 }
47 _ => {}
48 }
49 }
50 processing.instructions.push(instruction);
51 }
52 processing
53 }
54}
55
56fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
57 processing: &mut ScopeProcessing,
58 input: Variable,
59 out: Variable,
60 allocator: &Allocator,
61 mut polyfill: impl FnMut(&mut Scope, ExpandElementTyped<T>, u32, u32) -> ExpandElementTyped<O>,
62) {
63 let input = ExpandElement::Plain(input);
64 let mut scope = Scope::root(false).with_allocator(allocator.clone());
65 scope.register_type::<FloatExpand<0>>(input.storage_type());
66
67 let out_poly = if let ElemType::Float(kind) = input.elem_type() {
68 let (unsigned_ty, bit_width, mantissa_bits) = match kind {
69 FloatKind::F64 => (
70 UIntKind::U64,
71 f64::size_bits().unwrap(),
72 f64::MANTISSA_DIGITS - 1,
73 ),
74 FloatKind::F32 => (
75 UIntKind::U32,
76 f32::size_bits().unwrap(),
77 f32::MANTISSA_DIGITS - 1,
78 ),
79 FloatKind::F16 => (
80 UIntKind::U16,
81 f16::size_bits().unwrap(),
82 f16::MANTISSA_DIGITS - 1,
83 ),
84 FloatKind::BF16 => (
85 UIntKind::U16,
86 bf16::size_bits().unwrap(),
87 bf16::MANTISSA_DIGITS - 1,
88 ),
89 _ => unreachable!(),
90 };
91 scope.register_type::<IntExpand<1>>(ElemType::UInt(unsigned_ty).into());
92
93 let exp_bits = bit_width as u32 - mantissa_bits - 1;
94
95 polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
96 } else {
97 panic!("Should be float")
98 };
99
100 let tmp_processing = scope.process([]);
101
102 processing.instructions.extend(tmp_processing.instructions);
103 processing.variables.extend(tmp_processing.variables);
104
105 processing
106 .instructions
107 .push(Instruction::new(Operation::Copy(*out_poly), out));
108}
109
110#[cube]
111fn is_nan<F: Float, U: Int>(
112 x: Line<F>,
113 #[comptime] mantissa_bits: u32,
114 #[comptime] exp_bits: u32,
115) -> Line<bool> {
116 let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
118 let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
119
120 let bits: Line<U> = Line::<U>::reinterpret(x);
121
122 let abs_bits = bits & Line::new(U::cast_from(abs_mask));
123
124 abs_bits.greater_than(Line::new(U::cast_from(inf_bits)))
125}
126
127#[cube]
129fn is_inf<F: Float, U: Int>(
130 x: Line<F>,
131 #[comptime] mantissa_bits: u32,
132 #[comptime] exp_bits: u32,
133) -> Line<bool> {
134 let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
136 let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
137
138 let bits: Line<U> = Line::<U>::reinterpret(x);
139
140 let abs_bits = bits & Line::new(U::cast_from(abs_mask));
141
142 abs_bits.equal(Line::new(U::cast_from(inf_bits)))
143}