cubecl_core/frontend/container/line/
base.rs1use std::num::NonZero;
2
3use crate as cubecl;
4use crate::{
5 frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
6 prelude::MulHi,
7};
8use crate::{
9 ir::{Arithmetic, BinaryOperator, Elem, Instruction, Item, Scope},
10 prelude::{Dot, Numeric, binary_expand_fixed_output},
11 unexpanded,
12};
13use cubecl_ir::{Comparison, ExpandElement};
14use cubecl_macros::{cube, intrinsic};
15use derive_more::derive::Neg;
16#[derive(Neg)]
19pub struct Line<P> {
20 pub(crate) val: P,
22}
23
24type LineExpand<P> = ExpandElementTyped<Line<P>>;
25
26impl<P: CubePrimitive> Clone for Line<P> {
27 fn clone(&self) -> Self {
28 *self
29 }
30}
31impl<P: CubePrimitive> Eq for Line<P> {}
32impl<P: CubePrimitive> Copy for Line<P> {}
33
34mod new {
36 use cubecl_macros::comptime_type;
37
38 use super::*;
39
40 #[cube]
41 impl<P: CubePrimitive> Line<P> {
42 #[allow(unused_variables)]
44 pub fn new(val: P) -> Self {
45 intrinsic!(|_| {
46 let elem: ExpandElementTyped<P> = val;
47 elem.expand.into()
48 })
49 }
50
51 #[allow(unused_variables)]
52 pub fn line_size(&self) -> comptime_type!(u32) {
54 intrinsic!(|_| {
55 let elem: ExpandElementTyped<Line<P>> = self;
56 elem.expand
57 .item
58 .vectorization
59 .map(|a| a.get() as u32)
60 .unwrap_or(1)
61 })
62 }
63 }
64}
65
66mod fill {
68 use crate::prelude::cast;
69
70 use super::*;
71
72 #[cube]
73 impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
74 #[allow(unused_variables)]
85 pub fn fill(self, value: P) -> Self {
86 intrinsic!(|scope| {
87 let length = self.expand.item.vectorization;
88 let output = scope.create_local(Item::vectorized(P::as_elem(scope), length));
89
90 cast::expand::<P>(scope, value, output.clone().into());
91
92 output.into()
93 })
94 }
95 }
96}
97
98mod empty {
100 use crate::prelude::Cast;
101
102 use super::*;
103
104 #[cube]
105 impl<P: CubePrimitive> Line<P> {
106 #[allow(unused_variables)]
110 pub fn empty(#[comptime] size: u32) -> Self {
111 let zero = Line::<P>::cast_from(0);
112 intrinsic!(|scope| {
113 let length = NonZero::new(size as u8);
114 let var: ExpandElementTyped<Line<P>> = scope
117 .create_local_mut(Item::vectorized(Self::as_elem(scope), length))
118 .into();
119 cubecl::frontend::assign::expand(scope, zero, var.clone());
120 var
121 })
122 }
123 }
124}
125
126mod size {
128 use super::*;
129
130 impl<P: CubePrimitive> Line<P> {
131 pub fn size(&self) -> u32 {
142 unexpanded!()
143 }
144
145 pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> u32 {
147 element.__expand_vectorization_factor_method(scope)
148 }
149 }
150
151 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
152 pub fn size(&self) -> u32 {
154 self.expand
155 .item
156 .vectorization
157 .unwrap_or(NonZero::new(1).unwrap())
158 .get() as u32
159 }
160
161 pub fn __expand_size_method(&self, _scope: &mut Scope) -> u32 {
163 self.size()
164 }
165 }
166}
167
/// Generates `mod $name` containing the element-wise comparison method `Line::<P>::$name`.
///
/// - `$name`: identifier used for both the generated module and the method.
/// - `$operator`: the `Comparison` variant registered in the scope.
/// - `$comment`: phrase spliced into the generated method's documentation.
///
/// The generated method returns a `Line<bool>` that reuses `self`'s vectorization factor.
// NOTE(review): nothing below uses `[< ... >]` identifier pasting, so the `paste!` wrapper
// looks unnecessary. It is kept so the tokens reach `#[cube]` exactly as before.
macro_rules! impl_line_comparison {
    ($name:ident, $operator:ident, $comment:literal) => {
        ::paste::paste! {
            mod $name {
                use super::*;

                #[cube]
                impl<P: CubePrimitive> Line<P> {
                    #[doc = concat!(
                        "Return a new line with the element-wise comparison of the first line being ",
                        $comment,
                        " the second line."
                    )]
                    #[allow(unused_variables)]
                    pub fn $name(self, other: Self) -> Line<bool> {
                        intrinsic!(|scope| {
                            let factor = self.expand.item.vectorization;
                            // Operands are converted before the output local is created,
                            // matching the original allocation order.
                            let operands = BinaryOperator {
                                lhs: self.expand.into(),
                                rhs: other.expand.into(),
                            };

                            let result = scope
                                .create_local_mut(Item::vectorized(bool::as_elem(scope), factor));

                            scope.register(Instruction::new(
                                Comparison::$operator(operands),
                                result.clone().into(),
                            ));

                            result.into()
                        })
                    }
                }
            }
        }
    };
}
207
208impl_line_comparison!(equal, Equal, "equal to");
209impl_line_comparison!(not_equal, NotEqual, "not equal to");
210impl_line_comparison!(less_than, Lower, "less than");
211impl_line_comparison!(greater_than, Greater, "greater than");
212impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
213impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
214
215mod bool_and {
216 use cubecl_ir::Operator;
217
218 use crate::prelude::binary_expand;
219
220 use super::*;
221
222 #[cube]
223 impl Line<bool> {
224 #[allow(unused_variables)]
226 pub fn and(self, other: Self) -> Line<bool> {
227 intrinsic!(
228 |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
229 )
230 }
231 }
232}
233
234mod bool_or {
235 use cubecl_ir::Operator;
236
237 use crate::prelude::binary_expand;
238
239 use super::*;
240
241 #[cube]
242 impl Line<bool> {
243 #[allow(unused_variables)]
245 pub fn or(self, other: Self) -> Line<bool> {
246 intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
247 }
248 }
249}
250
251impl<P: CubePrimitive> CubeType for Line<P> {
252 type ExpandType = ExpandElementTyped<Self>;
253}
254
255impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
256 fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
257 P::elem_into_mut(scope, elem)
258 }
259}
260
261impl<P: CubePrimitive> CubePrimitive for Line<P> {
262 fn as_elem(scope: &Scope) -> Elem {
263 P::as_elem(scope)
264 }
265
266 fn as_elem_native() -> Option<Elem> {
267 P::as_elem_native()
268 }
269
270 fn size() -> Option<usize> {
271 P::size()
272 }
273}
274
275impl<N: Numeric> Dot for Line<N> {
276 fn dot(self, _rhs: Self) -> Self {
277 unexpanded!()
278 }
279
280 fn __expand_dot(
281 scope: &mut Scope,
282 lhs: ExpandElementTyped<Self>,
283 rhs: ExpandElementTyped<Self>,
284 ) -> ExpandElementTyped<Self> {
285 let lhs: ExpandElement = lhs.into();
286 let mut item = lhs.item;
287 item.vectorization = None;
288 binary_expand_fixed_output(scope, lhs, rhs.into(), item, Arithmetic::Dot).into()
289 }
290}
291
292impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}