cubecl_core/frontend/
polyfills.rs1use cubecl_ir::{Elem, ExpandElement, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl, unexpanded};
5
6pub fn set_polyfill<E: CubePrimitive>(_elem: Elem) {
12 unexpanded!()
13}
14
15pub mod set_polyfill {
17 use super::*;
18
19 pub fn expand<E: CubePrimitive>(scope: &mut Scope, elem: Elem) {
21 scope.register_elem::<E>(elem);
22 }
23}
24
25#[cube]
26fn checked_index_assign<E: CubePrimitive>(
27 index: u32,
28 value: Line<E>,
29 out: &mut Array<Line<E>>,
30 #[comptime] has_buffer_len: bool,
31) {
32 let array_len = if comptime![has_buffer_len] {
33 out.buffer_len()
34 } else {
35 out.len()
36 };
37
38 if index < array_len {
39 unsafe { out.index_assign_unchecked(index, value) };
40 }
41}
42
43#[allow(missing_docs)]
44pub fn expand_checked_index_assign(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
45 scope.register_elem::<FloatExpand<0>>(rhs.item.elem);
46 checked_index_assign::expand::<FloatExpand<0>>(
47 scope,
48 ExpandElement::Plain(lhs).into(),
49 ExpandElement::Plain(rhs).into(),
50 ExpandElement::Plain(out).into(),
51 out.has_buffer_length(),
52 );
53}
54
55#[cube]
56pub fn erf<F: Float>(x: Line<F>) -> Line<F> {
57 let erf = erf_positive(Abs::abs(x));
58 select_many(x.less_than(Line::new(F::new(0.0))), -erf, erf)
59}
60
61#[cube]
66fn erf_positive<F: Float>(x: Line<F>) -> Line<F> {
67 let p = Line::new(F::new(0.3275911));
68 let a1 = Line::new(F::new(0.2548296));
69 let a2 = Line::new(F::new(-0.28449674));
70 let a3 = Line::new(F::new(1.4214137));
71 let a4 = Line::new(F::new(-1.453152));
72 let a5 = Line::new(F::new(1.0614054));
73 let one = Line::new(F::new(1.0));
74
75 let t = one / (one + p * x);
76 let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
77
78 one - (tmp * t * Exp::exp(-x * x))
79}
80
81#[allow(missing_docs)]
82pub fn expand_erf(scope: &mut Scope, input: Variable, out: Variable) {
83 scope.register_elem::<FloatExpand<0>>(input.item.elem);
84 let res = erf::expand::<FloatExpand<0>>(scope, ExpandElement::Plain(input).into());
85 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
86}
87
88#[cube]
89fn himul_i64(lhs: Line<i32>, rhs: Line<i32>) -> Line<i32> {
90 let shift = Line::empty(lhs.size()).fill(32);
91 let mul = (Line::<i64>::cast_from(lhs) * Line::<i64>::cast_from(rhs)) >> shift;
92 Line::cast_from(mul)
93}
94
95#[cube]
96fn himul_u64(lhs: Line<u32>, rhs: Line<u32>) -> Line<u32> {
97 let shift = Line::empty(lhs.size()).fill(32);
98 let mul = (Line::<u64>::cast_from(lhs) * Line::<u64>::cast_from(rhs)) >> shift;
99 Line::cast_from(mul)
100}
101
102#[allow(missing_docs)]
103pub fn expand_himul_64(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
104 match lhs.item.elem {
105 Elem::Int(_) => {
106 let res = himul_i64::expand(
107 scope,
108 ExpandElement::Plain(lhs).into(),
109 ExpandElement::Plain(rhs).into(),
110 );
111 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
112 }
113 Elem::UInt(_) => {
114 let res = himul_u64::expand(
115 scope,
116 ExpandElement::Plain(lhs).into(),
117 ExpandElement::Plain(rhs).into(),
118 );
119 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
120 }
121 _ => unreachable!(),
122 };
123}
124
125#[cube]
126fn himul_sim(lhs: Line<u32>, rhs: Line<u32>) -> Line<u32> {
127 let low_mask = Line::empty(lhs.size()).fill(0xffff);
128 let shift = Line::empty(lhs.size()).fill(16);
129
130 let lhs_low = lhs & low_mask;
131 let lhs_hi = (lhs >> shift) & low_mask;
132 let rhs_low = rhs & low_mask;
133 let rhs_hi = (rhs >> shift) & low_mask;
134
135 let low_low = lhs_low * rhs_low;
136 let high_low = lhs_hi * rhs_low;
137 let low_high = lhs_low * rhs_hi;
138 let high_high = lhs_hi * rhs_hi;
139
140 let mid = ((low_low >> shift) & low_mask) + (high_low & low_mask) + (low_high & low_mask);
141 high_high
142 + ((high_low >> shift) & low_mask)
143 + ((low_high >> shift) & low_mask)
144 + ((mid >> shift) & low_mask)
145}
146
147#[allow(missing_docs)]
148pub fn expand_himul_sim(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
149 let res = himul_sim::expand(
150 scope,
151 ExpandElement::Plain(lhs).into(),
152 ExpandElement::Plain(rhs).into(),
153 );
154 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
155}