Skip to main content

cubecl_core/frontend/
polyfills.rs

1use cubecl_ir::{ElemType, ManagedVariable, Type, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl, unexpanded};
5
6define_scalar!(ElemA);
7define_size!(SizeA);
8
9/// Change the meaning of the given cube primitive type during compilation.
10///
11/// # Warning
12///
13/// To be used for very custom kernels, it would likely lead to a JIT compiler error otherwise.
14pub fn set_polyfill<E: Scalar, N: Size>(_elem: Type) {
15    unexpanded!()
16}
17
18/// Expand module of [`set_polyfill()`].
19pub mod set_polyfill {
20    use super::*;
21
22    /// Expand function of [`set_polyfill()`].
23    pub fn expand<E: Scalar, N: Size>(scope: &mut Scope, ty: Type) {
24        scope.register_type::<E>(ty.storage_type());
25        scope.register_size::<N>(ty.vector_size());
26    }
27}
28
29#[cube]
30pub fn erf<F: Float, N: Size>(x: Vector<F, N>) -> Vector<F, N> {
31    let erf = erf_positive(x.abs());
32    select_many(x.less_than(Vector::new(F::new(0.0))), -erf, erf)
33}
34
35/// An approximation of the error function: <https://en.wikipedia.org/wiki/Error_function#Numerical_approximations>
36///
37/// > (maximum error: 1.5×10−7)
38/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x).
39#[cube]
40fn erf_positive<F: Float, N: Size>(x: Vector<F, N>) -> Vector<F, N> {
41    let p = Vector::new(F::new(0.3275911));
42    let a1 = Vector::new(F::new(0.2548296));
43    let a2 = Vector::new(F::new(-0.28449674));
44    let a3 = Vector::new(F::new(1.4214137));
45    let a4 = Vector::new(F::new(-1.453152));
46    let a5 = Vector::new(F::new(1.0614054));
47    let one = Vector::new(F::new(1.0));
48
49    let t = one / (one + p * x);
50    let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
51
52    one - (tmp * t * (-x * x).exp())
53}
54
55#[allow(missing_docs)]
56pub fn expand_erf(scope: &mut Scope, input: Variable, out: Variable) {
57    scope.register_type::<ElemA>(input.ty.storage_type());
58    scope.register_size::<SizeA>(input.vector_size());
59    let res = erf::expand::<ElemA, SizeA>(scope, ManagedVariable::Plain(input).into());
60    assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
61}
62
63#[cube]
64fn himul_i64<N: Size>(lhs: Vector<i32, N>, rhs: Vector<i32, N>) -> Vector<i32, N> {
65    let shift = Vector::new(32);
66    let mul = (Vector::<i64, N>::cast_from(lhs) * Vector::<i64, N>::cast_from(rhs)) >> shift;
67    Vector::cast_from(mul)
68}
69
70#[cube]
71fn himul_u64<N: Size>(lhs: Vector<u32, N>, rhs: Vector<u32, N>) -> Vector<u32, N> {
72    let shift = Vector::new(32);
73    let mul = (Vector::<u64, N>::cast_from(lhs) * Vector::<u64, N>::cast_from(rhs)) >> shift;
74    Vector::cast_from(mul)
75}
76
77#[allow(missing_docs)]
78pub fn expand_himul_64(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
79    scope.register_size::<SizeA>(lhs.vector_size());
80    match lhs.ty.elem_type() {
81        ElemType::Int(_) => {
82            let res = himul_i64::expand::<SizeA>(
83                scope,
84                ManagedVariable::Plain(lhs).into(),
85                ManagedVariable::Plain(rhs).into(),
86            );
87            assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
88        }
89        ElemType::UInt(_) => {
90            let res = himul_u64::expand::<SizeA>(
91                scope,
92                ManagedVariable::Plain(lhs).into(),
93                ManagedVariable::Plain(rhs).into(),
94            );
95            assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
96        }
97        _ => unreachable!(),
98    };
99}
100
101#[cube]
102fn himul_sim<N: Size>(lhs: Vector<u32, N>, rhs: Vector<u32, N>) -> Vector<u32, N> {
103    let low_mask = Vector::new(0xffff);
104    let shift = Vector::new(16);
105
106    let lhs_low = lhs & low_mask;
107    let lhs_hi = (lhs >> shift) & low_mask;
108    let rhs_low = rhs & low_mask;
109    let rhs_hi = (rhs >> shift) & low_mask;
110
111    let low_low = lhs_low * rhs_low;
112    let high_low = lhs_hi * rhs_low;
113    let low_high = lhs_low * rhs_hi;
114    let high_high = lhs_hi * rhs_hi;
115
116    let mid = ((low_low >> shift) & low_mask) + (high_low & low_mask) + (low_high & low_mask);
117    high_high
118        + ((high_low >> shift) & low_mask)
119        + ((low_high >> shift) & low_mask)
120        + ((mid >> shift) & low_mask)
121}
122
123#[allow(missing_docs)]
124pub fn expand_himul_sim(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
125    scope.register_size::<SizeA>(lhs.vector_size());
126    let res = himul_sim::expand::<SizeA>(
127        scope,
128        ManagedVariable::Plain(lhs).into(),
129        ManagedVariable::Plain(rhs).into(),
130    );
131    assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
132}