cubecl_core/frontend/container/line/
base.rs1use std::num::NonZero;
2
3use crate::{
4 frontend::{CubePrimitive, CubeType, ExpandElementBaseInit, ExpandElementTyped},
5 prelude::MulHi,
6};
7use crate::{
8 ir::{Arithmetic, BinaryOperator, Elem, Instruction, Item, Scope},
9 prelude::{Dot, Numeric, binary_expand_fixed_output},
10 unexpanded,
11};
12use cubecl_ir::{Comparison, ExpandElement};
13use derive_more::derive::Neg;
14#[derive(Neg)]
17pub struct Line<P> {
18 pub(crate) val: P,
20}
21
22impl<P: CubePrimitive> Clone for Line<P> {
23 fn clone(&self) -> Self {
24 *self
25 }
26}
27impl<P: CubePrimitive> Eq for Line<P> {}
28impl<P: CubePrimitive> Copy for Line<P> {}
29
30mod new {
32 use super::*;
33
34 impl<P: CubePrimitive> Line<P> {
35 pub fn new(val: P) -> Self {
37 Self { val }
38 }
39
40 pub fn __expand_new(_scope: &mut Scope, val: P::ExpandType) -> ExpandElementTyped<Self> {
42 let elem: ExpandElementTyped<P> = val;
43 elem.expand.into()
44 }
45 }
46}
47
48mod fill {
50 use crate::prelude::cast;
51
52 use super::*;
53
54 impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
55 #[allow(unused_variables)]
66 pub fn fill(mut self, value: P) -> Self {
67 self.val = value;
68 self
69 }
70
71 pub fn __expand_fill(
73 scope: &mut Scope,
74 line: ExpandElementTyped<Self>,
75 value: ExpandElementTyped<P>,
76 ) -> ExpandElementTyped<Self> {
77 line.__expand_fill_method(scope, value)
78 }
79 }
80
81 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
82 pub fn __expand_fill_method(self, scope: &mut Scope, value: ExpandElementTyped<P>) -> Self {
84 let length = self.expand.item.vectorization;
85 let output = scope.create_local(Item::vectorized(P::as_elem(scope), length));
86
87 cast::expand::<P>(scope, value, output.clone().into());
88
89 output.into()
90 }
91 }
92}
93
94mod empty {
96 use super::*;
97
98 impl<P: CubePrimitive> Line<P> {
99 #[allow(unused_variables)]
103 pub fn empty(size: u32) -> Self {
104 unexpanded!()
105 }
106
107 pub fn __expand_empty(scope: &mut Scope, length: u32) -> ExpandElementTyped<Self> {
109 let length = NonZero::new(length as u8);
110 scope
111 .create_local_mut(Item::vectorized(Self::as_elem(scope), length))
112 .into()
113 }
114 }
115}
116
117mod size {
119 use super::*;
120
121 impl<P: CubePrimitive> Line<P> {
122 pub fn size(&self) -> u32 {
133 unexpanded!()
134 }
135
136 pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> u32 {
138 element.__expand_vectorization_factor_method(scope)
139 }
140 }
141
142 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
143 pub fn size(&self) -> u32 {
145 self.expand
146 .item
147 .vectorization
148 .unwrap_or(NonZero::new(1).unwrap())
149 .get() as u32
150 }
151
152 pub fn __expand_size_method(&self, _scope: &mut Scope) -> u32 {
154 self.size()
155 }
156 }
157}
158
159macro_rules! impl_line_comparison {
161 ($name:ident, $operator:ident, $comment:literal) => {
162 ::paste::paste! {
163 mod $name {
165
166 use super::*;
167
168 impl<P: CubePrimitive> Line<P> {
169 #[doc = concat!(
170 "Return a new line with the element-wise comparison of the first line being ",
171 $comment,
172 " the second line."
173 )]
174 pub fn $name(self, _other: Self) -> Line<bool> {
175 unexpanded!()
176 }
177
178 pub fn [< __expand_ $name >](
180 scope: &mut Scope,
181 lhs: ExpandElementTyped<Self>,
182 rhs: ExpandElementTyped<Self>,
183 ) -> ExpandElementTyped<Line<bool>> {
184 lhs.[< __expand_ $name _method >](scope, rhs)
185 }
186 }
187
188 impl<P: CubePrimitive> ExpandElementTyped<Line<P>> {
189 pub fn [< __expand_ $name _method >](
191 self,
192 scope: &mut Scope,
193 rhs: Self,
194 ) -> ExpandElementTyped<Line<bool>> {
195 let size = self.expand.item.vectorization;
196 let lhs = self.expand.into();
197 let rhs = rhs.expand.into();
198
199 let output = scope.create_local_mut(Item::vectorized(bool::as_elem(scope), size));
200
201 scope.register(Instruction::new(
202 Comparison::$operator(BinaryOperator { lhs, rhs }),
203 output.clone().into(),
204 ));
205
206 output.into()
207 }
208 }
209 }
210 }
211
212 };
213}
214
215impl_line_comparison!(equal, Equal, "equal to");
216impl_line_comparison!(not_equal, NotEqual, "not equal to");
217impl_line_comparison!(less_than, Lower, "less than");
218impl_line_comparison!(greater_than, Greater, "greater than");
219impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
220impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
221
222mod bool_and {
223 use cubecl_ir::Operator;
224
225 use crate::prelude::binary_expand;
226
227 use super::*;
228
229 impl Line<bool> {
230 pub fn and(self, _other: Self) -> Line<bool> {
232 unexpanded!()
233 }
234
235 pub fn __expand_and(
237 scope: &mut Scope,
238 lhs: ExpandElementTyped<Self>,
239 rhs: ExpandElementTyped<Self>,
240 ) -> ExpandElementTyped<Line<bool>> {
241 lhs.__expand_and_method(scope, rhs)
242 }
243 }
244
245 impl ExpandElementTyped<Line<bool>> {
246 pub fn __expand_and_method(
248 self,
249 scope: &mut Scope,
250 rhs: Self,
251 ) -> ExpandElementTyped<Line<bool>> {
252 binary_expand(scope, self.expand, rhs.expand, Operator::And).into()
253 }
254 }
255}
256
257mod bool_or {
258 use cubecl_ir::Operator;
259
260 use crate::prelude::binary_expand;
261
262 use super::*;
263
264 impl Line<bool> {
265 pub fn or(self, _other: Self) -> Line<bool> {
267 unexpanded!()
268 }
269
270 pub fn __expand_or(
272 scope: &mut Scope,
273 lhs: ExpandElementTyped<Self>,
274 rhs: ExpandElementTyped<Self>,
275 ) -> ExpandElementTyped<Line<bool>> {
276 lhs.__expand_and_method(scope, rhs)
277 }
278 }
279
280 impl ExpandElementTyped<Line<bool>> {
281 pub fn __expand_or_method(
283 self,
284 scope: &mut Scope,
285 rhs: Self,
286 ) -> ExpandElementTyped<Line<bool>> {
287 binary_expand(scope, self.expand, rhs.expand, Operator::Or).into()
288 }
289 }
290}
291
292impl<P: CubePrimitive> CubeType for Line<P> {
293 type ExpandType = ExpandElementTyped<Self>;
294}
295
296impl<P: CubePrimitive> ExpandElementBaseInit for Line<P> {
297 fn init_elem(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
298 P::init_elem(scope, elem)
299 }
300}
301
302impl<P: CubePrimitive> CubePrimitive for Line<P> {
303 fn as_elem(scope: &Scope) -> Elem {
304 P::as_elem(scope)
305 }
306
307 fn as_elem_native() -> Option<Elem> {
308 P::as_elem_native()
309 }
310
311 fn size() -> Option<usize> {
312 P::size()
313 }
314}
315
316impl<N: Numeric> Dot for Line<N> {
317 fn dot(self, _rhs: Self) -> Self {
318 unexpanded!()
319 }
320
321 fn __expand_dot(
322 scope: &mut Scope,
323 lhs: ExpandElementTyped<Self>,
324 rhs: ExpandElementTyped<Self>,
325 ) -> ExpandElementTyped<Self> {
326 let lhs: ExpandElement = lhs.into();
327 let mut item = lhs.item;
328 item.vectorization = None;
329 binary_expand_fixed_output(scope, lhs, rhs.into(), item, Arithmetic::Dot).into()
330 }
331}
332
333impl<N: MulHi + CubePrimitive> MulHi for Line<N> {}