cubecl_core/frontend/
polyfills.rs1use cubecl_ir::{ElemType, ExpandElement, StorageType, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl, unexpanded};
5
6pub fn set_polyfill<E: CubePrimitive>(_elem: StorageType) {
12 unexpanded!()
13}
14
15pub mod set_polyfill {
17 use super::*;
18
19 pub fn expand<E: CubePrimitive>(scope: &mut Scope, ty: StorageType) {
21 scope.register_type::<E>(ty);
22 }
23}
24
25#[cube]
26fn checked_index_assign<E: CubePrimitive>(
27 index: u32,
28 value: Line<E>,
29 out: &mut Array<Line<E>>,
30 #[comptime] has_buffer_len: bool,
31 #[comptime] unroll_factor: u32,
32) {
33 let array_len = if comptime![has_buffer_len] {
34 out.buffer_len()
35 } else {
36 out.len()
37 };
38
39 if index < array_len * unroll_factor {
40 unsafe { out.index_assign_unchecked(index, value) };
41 }
42}
43
44#[allow(missing_docs)]
45pub fn expand_checked_index_assign(
46 scope: &mut Scope,
47 lhs: Variable,
48 rhs: Variable,
49 out: Variable,
50 unroll_factor: u32,
51) {
52 scope.register_type::<FloatExpand<0>>(rhs.ty.storage_type());
53 checked_index_assign::expand::<FloatExpand<0>>(
54 scope,
55 ExpandElement::Plain(lhs).into(),
56 ExpandElement::Plain(rhs).into(),
57 ExpandElement::Plain(out).into(),
58 out.has_buffer_length(),
59 unroll_factor,
60 );
61}
62
63#[cube]
64pub fn erf<F: Float>(x: Line<F>) -> Line<F> {
65 let erf = erf_positive(Abs::abs(x));
66 select_many(x.less_than(Line::new(F::new(0.0))), -erf, erf)
67}
68
69#[cube]
74fn erf_positive<F: Float>(x: Line<F>) -> Line<F> {
75 let p = Line::new(F::new(0.3275911));
76 let a1 = Line::new(F::new(0.2548296));
77 let a2 = Line::new(F::new(-0.28449674));
78 let a3 = Line::new(F::new(1.4214137));
79 let a4 = Line::new(F::new(-1.453152));
80 let a5 = Line::new(F::new(1.0614054));
81 let one = Line::new(F::new(1.0));
82
83 let t = one / (one + p * x);
84 let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
85
86 one - (tmp * t * Exp::exp(-x * x))
87}
88
89#[allow(missing_docs)]
90pub fn expand_erf(scope: &mut Scope, input: Variable, out: Variable) {
91 scope.register_type::<FloatExpand<0>>(input.ty.storage_type());
92 let res = erf::expand::<FloatExpand<0>>(scope, ExpandElement::Plain(input).into());
93 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
94}
95
96#[cube]
97fn himul_i64(lhs: Line<i32>, rhs: Line<i32>) -> Line<i32> {
98 let shift = Line::empty(lhs.size()).fill(32);
99 let mul = (Line::<i64>::cast_from(lhs) * Line::<i64>::cast_from(rhs)) >> shift;
100 Line::cast_from(mul)
101}
102
103#[cube]
104fn himul_u64(lhs: Line<u32>, rhs: Line<u32>) -> Line<u32> {
105 let shift = Line::empty(lhs.size()).fill(32);
106 let mul = (Line::<u64>::cast_from(lhs) * Line::<u64>::cast_from(rhs)) >> shift;
107 Line::cast_from(mul)
108}
109
110#[allow(missing_docs)]
111pub fn expand_himul_64(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
112 match lhs.ty.elem_type() {
113 ElemType::Int(_) => {
114 let res = himul_i64::expand(
115 scope,
116 ExpandElement::Plain(lhs).into(),
117 ExpandElement::Plain(rhs).into(),
118 );
119 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
120 }
121 ElemType::UInt(_) => {
122 let res = himul_u64::expand(
123 scope,
124 ExpandElement::Plain(lhs).into(),
125 ExpandElement::Plain(rhs).into(),
126 );
127 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
128 }
129 _ => unreachable!(),
130 };
131}
132
133#[cube]
134fn himul_sim(lhs: Line<u32>, rhs: Line<u32>) -> Line<u32> {
135 let low_mask = Line::empty(lhs.size()).fill(0xffff);
136 let shift = Line::empty(lhs.size()).fill(16);
137
138 let lhs_low = lhs & low_mask;
139 let lhs_hi = (lhs >> shift) & low_mask;
140 let rhs_low = rhs & low_mask;
141 let rhs_hi = (rhs >> shift) & low_mask;
142
143 let low_low = lhs_low * rhs_low;
144 let high_low = lhs_hi * rhs_low;
145 let low_high = lhs_low * rhs_hi;
146 let high_high = lhs_hi * rhs_hi;
147
148 let mid = ((low_low >> shift) & low_mask) + (high_low & low_mask) + (low_high & low_mask);
149 high_high
150 + ((high_low >> shift) & low_mask)
151 + ((low_high >> shift) & low_mask)
152 + ((mid >> shift) & low_mask)
153}
154
155#[allow(missing_docs)]
156pub fn expand_himul_sim(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
157 let res = himul_sim::expand(
158 scope,
159 ExpandElement::Plain(lhs).into(),
160 ExpandElement::Plain(rhs).into(),
161 );
162 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
163}