cubecl_core/frontend/container/line/
base.rs1use crate as cubecl;
2use crate::{
3 frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
4 prelude::MulHi,
5};
6use crate::{
7 ir::{Arithmetic, BinaryOperator, Instruction, Scope, Type},
8 prelude::{Dot, Numeric, binary_expand_fixed_output},
9 unexpanded,
10};
11use cubecl_ir::{Comparison, ConstantScalarValue, ExpandElement, StorageType};
12use cubecl_macros::{cube, intrinsic};
13use derive_more::derive::Neg;
/// Line container over a primitive `P`.
///
/// The struct itself only stores a single `P`. The number of elements in a line is not kept
/// here; on the expand side it is read from the type of the expand element (see
/// `ExpandElementTyped<Line<P>>::size`).
///
/// `Neg` is derived through `derive_more`.
#[derive(Neg)]
pub struct Line<P> {
    /// The single wrapped value; only visible inside this crate.
    pub(crate) val: P,
}
21
/// Shorthand for the expand-side representation of a `Line<P>`.
type LineExpand<P> = ExpandElementTyped<Line<P>>;
23
// `Clone` is written by hand rather than derived: cloning is a plain copy of `*self`,
// which relies on the `Copy` impl below.
impl<P: CubePrimitive> Clone for Line<P> {
    fn clone(&self) -> Self {
        *self
    }
}
// Marker impls. `Eq` needs a `PartialEq` impl for `Line<P>`, which is not defined in this
// part of the file.
impl<P: CubePrimitive> Eq for Line<P> {}
impl<P: CubePrimitive> Copy for Line<P> {}
31
/// Implementation of `Line::new` and `Line::line_size`.
mod new {
    use cubecl_macros::comptime_type;

    use super::*;

    #[cube]
    impl<P: CubePrimitive> Line<P> {
        /// Creates a line from a single value.
        ///
        /// The expansion reuses the expand element of `val` as-is and converts it into the
        /// expand type of the line. The scope is ignored, so no new variable or instruction
        /// is created here.
        // `val` is only referenced inside `intrinsic!`, which is presumably why the lint is
        // silenced.
        #[allow(unused_variables)]
        pub fn new(val: P) -> Self {
            intrinsic!(|_| {
                // The annotation pins `val` to its expand representation.
                let elem: ExpandElementTyped<P> = val;
                elem.expand.into()
            })
        }
    }

    impl<P: CubePrimitive> Line<P> {
        /// Returns the line size, declared with the `comptime_type!(u32)` return type.
        ///
        /// This body is only an `unexpanded!()` placeholder; the matching expand function
        /// is not defined in this block.
        pub fn line_size(&self) -> comptime_type!(u32) {
            unexpanded!()
        }
    }
}
57
/// Implementation of `Line::fill`.
mod fill {
    use crate::prelude::cast;

    use super::*;

    #[cube]
    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
        /// Returns a line with the same line size as `self`, produced by casting `value`.
        ///
        /// The expansion creates a new local whose element type is `P` and whose line size
        /// is the one of `self`, then emits a cast of `value` into that local. Only the line
        /// size of `self` is used; its contents are not read.
        // NOTE(review): presumably the cast writes `value` to every element of the line;
        // confirm against `cast::expand`.
        #[allow(unused_variables)]
        pub fn fill(self, value: P) -> Self {
            intrinsic!(|scope| {
                // The output keeps the line size of the receiver.
                let length = self.expand.ty.line_size();
                let output = scope.create_local(Type::new(P::as_type(scope)).line(length));

                cast::expand::<P>(scope, value, output.clone().into());

                output.into()
            })
        }
    }
}
89
/// Implementation of `Line::empty`.
mod empty {
    use crate::prelude::Cast;

    use super::*;

    #[cube]
    impl<P: CubePrimitive> Line<P> {
        /// Creates a line with the given comptime line size.
        ///
        /// Despite the name, the new variable is not left unassigned: it receives
        /// `Line::<P>::cast_from(0)` right after being created.
        #[allow(unused_variables)]
        pub fn empty(#[comptime] size: u32) -> Self {
            // Initial value, built outside of the intrinsic closure.
            let zero = Line::<P>::cast_from(0);
            intrinsic!(|scope| {
                // Mutable local whose element type comes from `Self::as_type` and whose
                // line size is `size`.
                let var: ExpandElementTyped<Line<P>> = scope
                    .create_local_mut(Type::new(Self::as_type(scope)).line(size))
                    .into();
                // Initialize the new variable with `zero` before handing it back.
                cubecl::frontend::assign::expand(scope, zero, var.clone());
                var
            })
        }
    }
}
116
117mod size {
119 use super::*;
120
121 impl<P: CubePrimitive> Line<P> {
122 pub fn size(&self) -> u32 {
133 unexpanded!()
134 }
135
136 pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> u32 {
138 element.__expand_line_size_method(scope)
139 }
140 }
141
142 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
143 pub fn size(&self) -> u32 {
145 self.expand.ty.line_size()
146 }
147
148 pub fn __expand_size_method(&self, _scope: &mut Scope) -> u32 {
150 self.size()
151 }
152 }
153}
154
// Generates a module `$name` that adds the method `Line::<P>::$name`.
//
// The generated method registers a `Comparison::$operator` instruction over the two operands
// and returns a `Line<bool>` with the same line size as `self`. `$comment` is spliced into
// the generated doc string.
macro_rules! impl_line_comparison {
    ($name:ident, $operator:ident, $comment:literal) => {
        // NOTE(review): no `[< ... >]` pasting is used below; check whether the `paste!`
        // wrapper is still required.
        ::paste::paste! {
            mod $name {

                use super::*;

                #[cube]
                impl<P: CubePrimitive> Line<P> {
                    #[doc = concat!(
                        "Return a new line with the element-wise comparison of the first line being ",
                        $comment,
                        " the second line."
                    )]
                    #[allow(unused_variables)]
                    pub fn $name(self, other: Self) -> Line<bool> {
                        intrinsic!(|scope| {
                            let size = self.expand.ty.line_size();
                            let lhs = self.expand.into();
                            let rhs = other.expand.into();

                            // The result is a bool line with the line size of `self`.
                            let output = scope.create_local_mut(Type::new(bool::as_type(scope)).line(size));

                            scope.register(Instruction::new(
                                Comparison::$operator(BinaryOperator { lhs, rhs }),
                                output.clone().into(),
                            ));

                            output.into()
                        })
                    }
                }
            }
        }

    };
}
194
// Comparison methods on `Line<P>`.
// Arguments: module / method name, `Comparison` variant, fragment for the doc string.
// Note that `less_than` maps to the variant named `Lower` (and `less_equal` to `LowerEqual`).
impl_line_comparison!(equal, Equal, "equal to");
impl_line_comparison!(not_equal, NotEqual, "not equal to");
impl_line_comparison!(less_than, Lower, "less than");
impl_line_comparison!(greater_than, Greater, "greater than");
impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
201
/// Implementation of `Line::<bool>::and`.
mod bool_and {
    use cubecl_ir::Operator;

    use crate::prelude::binary_expand;

    use super::*;

    #[cube]
    impl Line<bool> {
        /// Combines two boolean lines with `Operator::And`.
        ///
        /// The expansion is delegated to `binary_expand` on the two expand elements.
        #[allow(unused_variables)]
        pub fn and(self, other: Self) -> Line<bool> {
            intrinsic!(
                |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
            )
        }
    }
}
220
/// Implementation of `Line::<bool>::or`.
mod bool_or {
    use cubecl_ir::Operator;

    use crate::prelude::binary_expand;

    use super::*;

    #[cube]
    impl Line<bool> {
        /// Combines two boolean lines with `Operator::Or`.
        ///
        /// The expansion is delegated to `binary_expand` on the two expand elements.
        #[allow(unused_variables)]
        pub fn or(self, other: Self) -> Line<bool> {
            intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
        }
    }
}
237
// `Line<P>`, `&Line<P>` and `&mut Line<P>` all expand to the same type,
// `ExpandElementTyped<Line<P>>`.
impl<P: CubePrimitive> CubeType for Line<P> {
    type ExpandType = ExpandElementTyped<Self>;
}

impl<P: CubePrimitive> CubeType for &Line<P> {
    type ExpandType = ExpandElementTyped<Line<P>>;
}

impl<P: CubePrimitive> CubeType for &mut Line<P> {
    type ExpandType = ExpandElementTyped<Line<P>>;
}
249
250impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
251 fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
252 P::elem_into_mut(scope, elem)
253 }
254}
255
256impl<P: CubePrimitive> CubePrimitive for Line<P> {
257 fn as_type(scope: &Scope) -> StorageType {
258 P::as_type(scope)
259 }
260
261 fn as_type_native() -> Option<StorageType> {
262 P::as_type_native()
263 }
264
265 fn size() -> Option<usize> {
266 P::size()
267 }
268
269 fn from_const_value(value: ConstantScalarValue) -> Self {
270 Self::new(P::from_const_value(value))
271 }
272}
273
274impl<N: Numeric> Dot for Line<N> {
275 fn dot(self, _rhs: Self) -> Self {
276 unexpanded!()
277 }
278
279 fn __expand_dot(
280 scope: &mut Scope,
281 lhs: ExpandElementTyped<Self>,
282 rhs: ExpandElementTyped<Self>,
283 ) -> ExpandElementTyped<Self> {
284 let lhs: ExpandElement = lhs.into();
285 let item = lhs.ty.storage_type().into();
286 binary_expand_fixed_output(scope, lhs, rhs.into(), item, Arithmetic::Dot).into()
287 }
288}
289
// `MulHi` is available on lines of any `MulHi` primitive. The impl body is empty, so it
// relies entirely on what the trait itself provides.
impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}