// cubecl_core/frontend/container/line/
base.rs1use std::num::NonZero;
2
3use crate::{
4 ir::{BinaryOperator, ConstantScalarValue, Elem, Instruction, Item, Operator},
5 prelude::{binary_expand_fixed_output, CubeContext, Dot, ExpandElement, Numeric},
6 unexpanded,
7};
8
9use crate::frontend::{
10 CubePrimitive, CubeType, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime,
11};
12
/// A line of `P` elements on which comparisons and other operations are applied
/// element-wise inside kernels.
///
/// In its non-expanded form a line only stores a single value.
pub struct Line<P> {
    /// The single value held by the non-expanded representation.
    pub(crate) val: P,
}
18
19impl<P: CubePrimitive> Clone for Line<P> {
20 fn clone(&self) -> Self {
21 *self
22 }
23}
24impl<P: CubePrimitive> Eq for Line<P> {}
25impl<P: CubePrimitive> Copy for Line<P> {}
26
27mod new {
29 use super::*;
30
31 impl<P: CubePrimitive> Line<P> {
32 pub fn new(val: P) -> Self {
34 Self { val }
35 }
36
37 pub fn __expand_new(
39 _context: &mut CubeContext,
40 val: P::ExpandType,
41 ) -> ExpandElementTyped<Self> {
42 let elem: ExpandElementTyped<P> = val;
43 elem.expand.into()
44 }
45 }
46}
47
48mod fill {
50 use crate::prelude::cast;
51
52 use super::*;
53
54 impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
55 #[allow(unused_variables)]
66 pub fn fill(mut self, value: P) -> Self {
67 self.val = value;
68 self
69 }
70
71 pub fn __expand_fill(
73 context: &mut CubeContext,
74 line: ExpandElementTyped<Self>,
75 value: ExpandElementTyped<P>,
76 ) -> ExpandElementTyped<Self> {
77 line.__expand_fill_method(context, value)
78 }
79 }
80
81 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
82 pub fn __expand_fill_method(
84 self,
85 context: &mut CubeContext,
86 value: ExpandElementTyped<P>,
87 ) -> Self {
88 let length = self.expand.item.vectorization;
89 let output = context.create_local(Item::vectorized(P::as_elem(context), length));
90
91 cast::expand::<P>(context, value, output.clone().into());
92
93 output.into()
94 }
95 }
96}
97
98mod empty {
100 use super::*;
101
102 impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
103 #[allow(unused_variables)]
107 pub fn empty(size: u32) -> Self {
108 unexpanded!()
109 }
110
111 pub fn __expand_empty(
113 context: &mut CubeContext,
114 length: ExpandElementTyped<u32>,
115 ) -> ExpandElementTyped<Self> {
116 let length = match length.expand.as_const() {
117 Some(val) => match val {
118 ConstantScalarValue::Int(val, _) => NonZero::new(val)
119 .map(|val| val.get() as u8)
120 .map(|val| NonZero::new(val).unwrap()),
121 ConstantScalarValue::Float(val, _) => NonZero::new(val as i64)
122 .map(|val| val.get() as u8)
123 .map(|val| NonZero::new(val).unwrap()),
124 ConstantScalarValue::UInt(val, _) => NonZero::new(val as u8),
125 ConstantScalarValue::Bool(_) => None,
126 },
127 None => None,
128 };
129 context
130 .create_local_mut(Item::vectorized(Self::as_elem(context), length))
131 .into()
132 }
133 }
134}
135
136mod size {
138 use super::*;
139
140 impl<P: CubePrimitive> Line<P> {
141 pub fn size(&self) -> u32 {
152 unexpanded!()
153 }
154
155 pub fn __expand_size(context: &mut CubeContext, element: ExpandElementTyped<P>) -> u32 {
157 element.__expand_vectorization_factor_method(context)
158 }
159 }
160
161 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
162 pub fn size(&self) -> u32 {
164 self.expand
165 .item
166 .vectorization
167 .unwrap_or(NonZero::new(1).unwrap())
168 .get() as u32
169 }
170
171 pub fn __expand_size_method(&self, _context: &mut CubeContext) -> u32 {
173 self.size()
174 }
175 }
176}
177
/// Generates a module named `$fn_name` holding one element-wise comparison between lines.
///
/// Each invocation emits:
/// - `Line::<P>::$fn_name`, the user-facing stub (body is `unexpanded!()`);
/// - `Line::<P>::__expand_$fn_name`, its expand function, forwarding to the method below;
/// - `ExpandElementTyped::<Line<P>>::__expand_$fn_name_method`, which registers
///   `Operator::$op` into a new mutable boolean local vectorized like the left operand.
macro_rules! impl_line_comparison {
    ($fn_name:ident, $op:ident, $relation:literal) => {
        ::paste::paste! {
            mod $fn_name {
                use super::*;

                impl<P: CubePrimitive> Line<P> {
                    #[doc = concat!(
                        "Return a new line with the element-wise comparison of the first line being ",
                        $relation,
                        " the second line."
                    )]
                    pub fn $fn_name(self, _other: Self) -> Line<bool> {
                        unexpanded!()
                    }

                    pub fn [< __expand_ $fn_name >](
                        context: &mut CubeContext,
                        lhs: ExpandElementTyped<Self>,
                        rhs: ExpandElementTyped<Self>,
                    ) -> ExpandElementTyped<Line<bool>> {
                        lhs.[< __expand_ $fn_name _method >](context, rhs)
                    }
                }

                impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
                    pub fn [< __expand_ $fn_name _method >](
                        self,
                        context: &mut CubeContext,
                        rhs: Self,
                    ) -> ExpandElementTyped<Line<bool>> {
                        // Read the vectorization before `self.expand` is consumed below.
                        let vectorization = self.expand.item.vectorization;
                        let lhs_var = self.expand.into();
                        let rhs_var = rhs.expand.into();

                        let bool_item = Item::vectorized(bool::as_elem(context), vectorization);
                        let output = context.create_local_mut(bool_item);

                        let op = Operator::$op(BinaryOperator {
                            lhs: lhs_var,
                            rhs: rhs_var,
                        });
                        context.register(Instruction::new(op, output.clone().into()));

                        output.into()
                    }
                }
            }
        }
    };
}
233
234impl_line_comparison!(equal, Equal, "equal to");
235impl_line_comparison!(not_equal, NotEqual, "not equal to");
236impl_line_comparison!(less_than, Lower, "less than");
237impl_line_comparison!(greater_than, Greater, "greater than");
238impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
239impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
240
241impl<P: CubePrimitive> CubeType for Line<P> {
242 type ExpandType = ExpandElementTyped<Self>;
243}
244
245impl<P: CubePrimitive> ExpandElementBaseInit for Line<P> {
246 fn init_elem(context: &mut crate::prelude::CubeContext, elem: ExpandElement) -> ExpandElement {
247 P::init_elem(context, elem)
248 }
249}
250
251impl<P: CubePrimitive> IntoRuntime for Line<P> {
252 fn __expand_runtime_method(
253 self,
254 context: &mut crate::prelude::CubeContext,
255 ) -> Self::ExpandType {
256 self.val.__expand_runtime_method(context).expand.into()
257 }
258}
259
260impl<P: CubePrimitive> CubePrimitive for Line<P> {
261 fn as_elem(context: &CubeContext) -> Elem {
262 P::as_elem(context)
263 }
264
265 fn as_elem_native() -> Option<Elem> {
266 P::as_elem_native()
267 }
268
269 fn size() -> Option<usize> {
270 P::size()
271 }
272}
273
274impl<N: Numeric> Dot for Line<N> {
275 fn dot(self, _rhs: Self) -> Self {
276 unexpanded!()
277 }
278
279 fn __expand_dot(
280 context: &mut CubeContext,
281 lhs: ExpandElementTyped<Self>,
282 rhs: ExpandElementTyped<Self>,
283 ) -> ExpandElementTyped<Self> {
284 let lhs: ExpandElement = lhs.into();
285 let mut item = lhs.item;
286 item.vectorization = None;
287 binary_expand_fixed_output(context, lhs, rhs.into(), item, Operator::Dot).into()
288 }
289}