use alloc::vec::Vec;
use core::{f32, f64};
use crate as cubecl;
use cubecl_ir::{
Allocator, Comparison, ElemType, FloatKind, Instruction, ManagedVariable, Operation, Processor,
Scope, ScopeProcessing, UIntKind, Variable,
};
use half::{bf16, f16};
use crate::prelude::*;
define_scalar!(ElemA);
define_scalar!(IntB);
define_size!(SizeA);
#[derive(Debug, Default)]
pub struct PredicateProcessor;
impl Processor for PredicateProcessor {
fn transform(
&self,
mut processing: cubecl_ir::ScopeProcessing,
allocator: Allocator,
) -> cubecl_ir::ScopeProcessing {
let mut instructions = Vec::new();
core::mem::swap(&mut processing.instructions, &mut instructions);
for instruction in instructions {
if let Operation::Comparison(comparison) = &instruction.operation {
match comparison {
Comparison::IsNan(op) => {
run_polyfill(
&mut processing,
op.input,
instruction.out(),
&allocator,
is_nan::expand::<ElemA, IntB, SizeA>,
);
continue;
}
Comparison::IsInf(op) => {
run_polyfill(
&mut processing,
op.input,
instruction.out(),
&allocator,
is_inf::expand::<ElemA, IntB, SizeA>,
);
continue;
}
_ => {}
}
}
processing.instructions.push(instruction);
}
processing
}
}
fn run_polyfill<T: CubePrimitive, O: CubePrimitive>(
processing: &mut ScopeProcessing,
input: Variable,
out: Variable,
allocator: &Allocator,
mut polyfill: impl FnMut(&mut Scope, NativeExpand<T>, u32, u32) -> NativeExpand<O>,
) {
let input = ManagedVariable::Plain(input);
let mut scope = Scope::root(false)
.with_allocator(allocator.clone())
.with_types(processing.typemap.clone());
scope.register_type::<ElemA>(input.storage_type());
scope.register_size::<SizeA>(input.vector_size());
let out_poly = if let ElemType::Float(kind) = input.elem_type() {
let (unsigned_ty, bit_width, mantissa_bits) = match kind {
FloatKind::F64 => (
UIntKind::U64,
f64::size_bits().unwrap(),
f64::MANTISSA_DIGITS - 1,
),
FloatKind::F32 => (
UIntKind::U32,
f32::size_bits().unwrap(),
f32::MANTISSA_DIGITS - 1,
),
FloatKind::F16 => (
UIntKind::U16,
f16::size_bits().unwrap(),
f16::MANTISSA_DIGITS - 1,
),
FloatKind::BF16 => (
UIntKind::U16,
bf16::size_bits().unwrap(),
bf16::MANTISSA_DIGITS - 1,
),
_ => unreachable!(),
};
scope.register_type::<IntB>(ElemType::UInt(unsigned_ty).into());
let exp_bits = bit_width as u32 - mantissa_bits - 1;
polyfill(&mut scope, input.into(), mantissa_bits, exp_bits).expand
} else {
panic!("Should be float")
};
let tmp_processing = scope.process([]);
processing.instructions.extend(tmp_processing.instructions);
processing.variables.extend(tmp_processing.variables);
processing
.instructions
.push(Instruction::new(Operation::Copy(*out_poly), out));
}
#[cube]
fn is_nan<F: Float, U: Int, N: Size>(
x: Vector<F, N>,
#[comptime] mantissa_bits: u32,
#[comptime] exp_bits: u32,
) -> Vector<bool, N> {
let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
let bits: Vector<U, N> = Vector::<U, N>::reinterpret(x);
let abs_bits = bits & Vector::new(U::cast_from(abs_mask));
abs_bits.greater_than(Vector::new(U::cast_from(inf_bits)))
}
#[cube]
fn is_inf<F: Float, U: Int, N: Size>(
x: Vector<F, N>,
#[comptime] mantissa_bits: u32,
#[comptime] exp_bits: u32,
) -> Vector<bool, N> {
let inf_bits = comptime![((1u64 << exp_bits as u64) - 1u64) << mantissa_bits as u64];
let abs_mask = comptime![(1u64 << (exp_bits as u64 + mantissa_bits as u64)) - 1u64];
let bits: Vector<U, N> = Vector::<U, N>::reinterpret(x);
let abs_bits = bits & Vector::new(U::cast_from(abs_mask));
abs_bits.equal(Vector::new(U::cast_from(inf_bits)))
}