// Source path: cubecl_core/frontend/container/line/base.rs

use cubecl_ir::{Comparison, ConstantValue, ExpandElement, StorageType};
use cubecl_macros::{cube, intrinsic};
use derive_more::derive::Neg;

use crate::{self as cubecl, prelude::FloatOps};
use crate::{
    frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
    prelude::MulHi,
};
use crate::{
    ir::{BinaryOperator, Instruction, Scope, Type},
    prelude::Dot,
    unexpanded,
};
14#[derive(Neg)]
17pub struct Line<P> {
18 pub(crate) val: P,
20}
21
22type LineExpand<P> = ExpandElementTyped<Line<P>>;
23
24impl<P: CubePrimitive> Clone for Line<P> {
25 fn clone(&self) -> Self {
26 *self
27 }
28}
29impl<P: CubePrimitive> Eq for Line<P> {}
30impl<P: CubePrimitive> Copy for Line<P> {}
31
32mod new {
34 use cubecl_ir::LineSize;
35 use cubecl_macros::comptime_type;
36
37 use super::*;
38
39 #[cube]
40 impl<P: CubePrimitive> Line<P> {
41 #[allow(unused_variables)]
43 pub fn new(val: P) -> Self {
44 intrinsic!(|_| {
45 let elem: ExpandElementTyped<P> = val;
46 elem.expand.into()
47 })
48 }
49 }
50
51 impl<P: CubePrimitive> Line<P> {
52 pub fn line_size(&self) -> comptime_type!(LineSize) {
54 unexpanded!()
55 }
56 }
57}
58
59mod fill {
61 use crate::prelude::cast;
62
63 use super::*;
64
65 #[cube]
66 impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
67 #[allow(unused_variables)]
78 pub fn fill(self, value: P) -> Self {
79 intrinsic!(|scope| {
80 let length = self.expand.ty.line_size();
81 let output = scope.create_local(Type::new(P::as_type(scope)).line(length));
82
83 cast::expand::<P, Line<P>>(scope, value, output.clone().into());
84
85 output.into()
86 })
87 }
88 }
89}
90
91mod empty {
93 use crate::prelude::Cast;
94
95 use super::*;
96
97 #[cube]
98 impl<P: CubePrimitive> Line<P> {
99 #[allow(unused_variables)]
103 pub fn empty(#[comptime] size: usize) -> Self {
104 let zero = Line::<P>::cast_from(0);
105 intrinsic!(|scope| {
106 let var: ExpandElementTyped<Line<P>> = scope
109 .create_local_mut(Type::new(Self::as_type(scope)).line(size))
110 .into();
111 cubecl::frontend::assign::expand(scope, zero, var.clone());
112 var
113 })
114 }
115 }
116}
117
118mod size {
120 use cubecl_ir::LineSize;
121
122 use super::*;
123
124 impl<P: CubePrimitive> Line<P> {
125 pub fn size(&self) -> LineSize {
136 unexpanded!()
137 }
138
139 pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> LineSize {
141 element.__expand_line_size_method(scope)
142 }
143 }
144
145 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
146 pub fn size(&self) -> LineSize {
148 self.expand.ty.line_size()
149 }
150
151 pub fn __expand_size_method(&self, _scope: &mut Scope) -> LineSize {
153 self.size()
154 }
155 }
156}
157
158macro_rules! impl_line_comparison {
160 ($name:ident, $operator:ident, $comment:literal) => {
161 ::paste::paste! {
162 mod $name {
164
165 use super::*;
166
167 #[cube]
168 impl<P: CubePrimitive> Line<P> {
169 #[doc = concat!(
170 "Return a new line with the element-wise comparison of the first line being ",
171 $comment,
172 " the second line."
173 )]
174 #[allow(unused_variables)]
175 pub fn $name(self, other: Self) -> Line<bool> {
176 intrinsic!(|scope| {
177 let size = self.expand.ty.line_size();
178 let lhs = self.expand.into();
179 let rhs = other.expand.into();
180
181 let output = scope.create_local_mut(Type::new(bool::as_type(scope)).line(size));
182
183 scope.register(Instruction::new(
184 Comparison::$operator(BinaryOperator { lhs, rhs }),
185 output.clone().into(),
186 ));
187
188 output.into()
189 })
190 }
191 }
192 }
193 }
194
195 };
196}
197
198impl_line_comparison!(equal, Equal, "equal to");
199impl_line_comparison!(not_equal, NotEqual, "not equal to");
200impl_line_comparison!(less_than, Lower, "less than");
201impl_line_comparison!(greater_than, Greater, "greater than");
202impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
203impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
204
205mod bool_and {
206 use cubecl_ir::Operator;
207
208 use crate::prelude::binary_expand;
209
210 use super::*;
211
212 #[cube]
213 impl Line<bool> {
214 #[allow(unused_variables)]
216 pub fn and(self, other: Self) -> Line<bool> {
217 intrinsic!(
218 |scope| binary_expand(scope, self.expand, other.expand, Operator::And).into()
219 )
220 }
221 }
222}
223
224mod bool_or {
225 use cubecl_ir::Operator;
226
227 use crate::prelude::binary_expand;
228
229 use super::*;
230
231 #[cube]
232 impl Line<bool> {
233 #[allow(unused_variables)]
235 pub fn or(self, other: Self) -> Line<bool> {
236 intrinsic!(|scope| binary_expand(scope, self.expand, other.expand, Operator::Or).into())
237 }
238 }
239}
240
241impl<P: CubePrimitive> CubeType for Line<P> {
242 type ExpandType = ExpandElementTyped<Self>;
243}
244
245impl<P: CubePrimitive> CubeType for &Line<P> {
246 type ExpandType = ExpandElementTyped<Line<P>>;
247}
248
249impl<P: CubePrimitive> CubeType for &mut Line<P> {
250 type ExpandType = ExpandElementTyped<Line<P>>;
251}
252
253impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
254 fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
255 P::elem_into_mut(scope, elem)
256 }
257}
258
259impl<P: CubePrimitive> CubePrimitive for Line<P> {
260 fn as_type(scope: &Scope) -> StorageType {
261 P::as_type(scope)
262 }
263
264 fn as_type_native() -> Option<StorageType> {
265 P::as_type_native()
266 }
267
268 fn size() -> Option<usize> {
269 P::size()
270 }
271
272 fn from_const_value(value: ConstantValue) -> Self {
273 Self::new(P::from_const_value(value))
274 }
275}
276
277impl<N: Dot + CubePrimitive> Dot for Line<N> {}
278impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}
279impl<N: FloatOps + CubePrimitive> FloatOps for Line<N> {}