use crate::{self as cubecl, prelude::FloatOps};
use crate::{
frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped},
prelude::MulHi,
};
use crate::{
ir::{BinaryOperator, Instruction, Scope, Type},
prelude::Dot,
unexpanded,
};
use cubecl_ir::{Comparison, ConstantValue, ExpandElement, StorageType};
use cubecl_macros::{cube, intrinsic};
use derive_more::derive::Neg;
/// Frontend wrapper type for a line of `P` values.
///
/// The struct itself only stores a single `P`; the number of elements in the
/// line is not a field here. Negation is derived through `derive_more`'s `Neg`.
///
/// NOTE(review): the expand-side code in this file reads the line size from the
/// expanded element's type (`expand.ty.line_size()`), not from this struct.
#[derive(Neg)]
pub struct Line<P> {
// The single wrapped value; visible inside the crate only.
pub(crate) val: P,
}
/// Shorthand for the typed expand element of a [`Line<P>`].
type LineExpand<P> = ExpandElementTyped<Line<P>>;
// `Line<P>` is `Copy` for every cube primitive, so cloning is a plain bitwise copy.
impl<P> Clone for Line<P>
where
    P: CubePrimitive,
{
    fn clone(&self) -> Self {
        *self
    }
}
// Marker impl: equality on `Line<P>` is declared to be a full equivalence relation.
impl<P> Eq for Line<P> where P: CubePrimitive {}
// Marker impl: a `Line<P>` can be duplicated by simple copy for any cube primitive `P`.
impl<P> Copy for Line<P> where P: CubePrimitive {}
/// Construction of a [`Line`] from a single value, plus the line-size accessor.
mod new {
    use super::*;
    use cubecl_ir::LineSize;
    use cubecl_macros::comptime_type;

    #[cube]
    impl<P: CubePrimitive> Line<P> {
        /// Wrap `val` in a line.
        ///
        /// During expansion the expand element behind `val` is reused and simply
        /// converted into the line's expand type; the scope argument is ignored.
        #[allow(unused_variables)]
        pub fn new(val: P) -> Self {
            intrinsic!(|_| {
                let typed_val: ExpandElementTyped<P> = val;
                let raw = typed_val.expand;
                raw.into()
            })
        }
    }

    impl<P: CubePrimitive> Line<P> {
        /// Line size of this line, typed as a comptime `LineSize`.
        ///
        /// The body is `unexpanded!()`; presumably this is only meaningful inside
        /// a cube kernel — confirm against the macro's definition.
        pub fn line_size(&self) -> comptime_type!(LineSize) {
            unexpanded!()
        }
    }
}
/// Implementation of [`Line::fill`].
mod fill {
    use super::*;
    use crate::prelude::cast;

    #[cube]
    impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
        /// Produce a line with the same line size as `self`, filled from `value`.
        ///
        /// During expansion a fresh local of element type `P` and `self`'s line
        /// size is created, `value` is cast into it, and that local is returned.
        #[allow(unused_variables)]
        pub fn fill(self, value: P) -> Self {
            intrinsic!(|scope| {
                // The result takes its line size from the receiver.
                let line_size = self.expand.ty.line_size();
                let line_ty = Type::new(P::as_type(scope)).line(line_size);
                let filled = scope.create_local(line_ty);
                cast::expand::<P, Line<P>>(scope, value, filled.clone().into());
                filled.into()
            })
        }
    }
}
/// Implementation of [`Line::empty`].
mod empty {
    use super::*;
    use crate::prelude::Cast;

    #[cube]
    impl<P: CubePrimitive> Line<P> {
        /// Create a line with `size` elements, initialised from a zero value.
        ///
        /// The zero is obtained with `cast_from(0)` and assigned to a newly
        /// created mutable local whose type carries the requested line size.
        #[allow(unused_variables)]
        pub fn empty(#[comptime] size: usize) -> Self {
            let zero = Line::<P>::cast_from(0);
            intrinsic!(|scope| {
                let line_ty = Type::new(Self::as_type(scope)).line(size);
                let line: LineExpand<P> = scope.create_local_mut(line_ty).into();
                // Write the zero value into the new variable before handing it back.
                cubecl::frontend::assign::expand(scope, zero, line.clone());
                line
            })
        }
    }
}
/// Accessors for the number of elements in a [`Line`].
mod size {
    use super::*;
    use cubecl_ir::LineSize;

    impl<P> Line<P>
    where
        P: CubePrimitive,
    {
        /// Line size of this line.
        ///
        /// The body is `unexpanded!()`; presumably this is only meaningful inside
        /// a cube kernel — confirm against the macro's definition.
        pub fn size(&self) -> LineSize {
            unexpanded!()
        }

        /// Expand counterpart of [`size`](Self::size); forwards to the element's
        /// `__expand_line_size_method`.
        pub fn __expand_size(scope: &mut Scope, element: ExpandElementTyped<P>) -> LineSize {
            element.__expand_line_size_method(scope)
        }
    }

    impl<P> ExpandElementTyped<Line<P>>
    where
        P: CubePrimitive,
    {
        /// Line size recorded on the type of the underlying expand element.
        pub fn size(&self) -> LineSize {
            self.expand.ty.line_size()
        }

        /// Method form used by expanded code; the scope is not needed and the
        /// call simply returns [`size`](Self::size).
        pub fn __expand_size_method(&self, _scope: &mut Scope) -> LineSize {
            self.size()
        }
    }
}
// Generates a module `$name` containing a `Line<P>::$name(self, other)` method
// that registers the `Comparison::$operator` instruction and returns a
// `Line<bool>` with the same line size as `self`. `$comment` is spliced into
// the generated method's documentation.
macro_rules! impl_line_comparison {
    ($name:ident, $operator:ident, $comment:literal) => {
        ::paste::paste! {
            mod $name {
                use super::*;

                #[cube]
                impl<P: CubePrimitive> Line<P> {
                    #[doc = concat!(
                        "Return a new line with the element-wise comparison of the first line being ",
                        $comment,
                        " the second line."
                    )]
                    #[allow(unused_variables)]
                    pub fn $name(self, other: Self) -> Line<bool> {
                        intrinsic!(|scope| {
                            let line_size = self.expand.ty.line_size();
                            // Both operands are converted before the result variable
                            // is created, matching the original evaluation order.
                            let operands = BinaryOperator {
                                lhs: self.expand.into(),
                                rhs: other.expand.into(),
                            };
                            let result_ty = Type::new(bool::as_type(scope)).line(line_size);
                            let result = scope.create_local_mut(result_ty);
                            scope.register(Instruction::new(
                                Comparison::$operator(operands),
                                result.clone().into(),
                            ));
                            result.into()
                        })
                    }
                }
            }
        }
    };
}
// Instantiate the element-wise comparison methods on `Line<P>`.
// Arguments: generated module/method name, `Comparison` variant to emit,
// and the phrase inserted into the generated method's doc string.
impl_line_comparison!(equal, Equal, "equal to");
impl_line_comparison!(not_equal, NotEqual, "not equal to");
impl_line_comparison!(less_than, Lower, "less than");
impl_line_comparison!(greater_than, Greater, "greater than");
impl_line_comparison!(less_equal, LowerEqual, "less than or equal to");
impl_line_comparison!(greater_equal, GreaterEqual, "greater than or equal to");
/// Implementation of [`Line::<bool>::and`].
mod bool_and {
    use super::*;
    use crate::prelude::binary_expand;
    use cubecl_ir::Operator;

    #[cube]
    impl Line<bool> {
        /// Combine two boolean lines with the `And` operator.
        ///
        /// Expansion goes through `binary_expand` with `Operator::And`.
        #[allow(unused_variables)]
        pub fn and(self, other: Self) -> Line<bool> {
            intrinsic!(|scope| {
                let conjunction = binary_expand(scope, self.expand, other.expand, Operator::And);
                conjunction.into()
            })
        }
    }
}
/// Implementation of [`Line::<bool>::or`].
mod bool_or {
    use super::*;
    use crate::prelude::binary_expand;
    use cubecl_ir::Operator;

    #[cube]
    impl Line<bool> {
        /// Combine two boolean lines with the `Or` operator.
        ///
        /// Expansion goes through `binary_expand` with `Operator::Or`.
        #[allow(unused_variables)]
        pub fn or(self, other: Self) -> Line<bool> {
            intrinsic!(|scope| {
                let disjunction = binary_expand(scope, self.expand, other.expand, Operator::Or);
                disjunction.into()
            })
        }
    }
}
impl<P: CubePrimitive> CubeType for Line<P> {
type ExpandType = ExpandElementTyped<Self>;
}
impl<P: CubePrimitive> CubeType for &Line<P> {
type ExpandType = ExpandElementTyped<Line<P>>;
}
impl<P: CubePrimitive> CubeType for &mut Line<P> {
type ExpandType = ExpandElementTyped<Line<P>>;
}
impl<P: CubePrimitive> ExpandElementIntoMut for Line<P> {
fn elem_into_mut(scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
P::elem_into_mut(scope, elem)
}
}
// A line is itself a cube primitive; every query is answered by its element type `P`.
impl<P> CubePrimitive for Line<P>
where
    P: CubePrimitive,
{
    /// Storage type of the element type `P` in the given scope.
    fn as_type(scope: &Scope) -> StorageType {
        <P as CubePrimitive>::as_type(scope)
    }

    /// Native storage type of `P`, if it has one.
    fn as_type_native() -> Option<StorageType> {
        <P as CubePrimitive>::as_type_native()
    }

    /// Size reported by `P`, forwarded unchanged.
    fn size() -> Option<usize> {
        <P as CubePrimitive>::size()
    }

    /// Build the element from the constant, then wrap it with [`Line::new`].
    fn from_const_value(value: ConstantValue) -> Self {
        let element = <P as CubePrimitive>::from_const_value(value);
        Self::new(element)
    }
}
// `Line<N>` supports `Dot` whenever its element type does (default trait methods only).
impl<N> Dot for Line<N> where N: Dot + CubePrimitive {}
// `Line<N>` supports `MulHi` whenever its element type does (default trait methods only).
impl<N> MulHi for Line<N> where N: MulHi + CubePrimitive {}
// `Line<N>` supports `FloatOps` whenever its element type does (default trait methods only).
impl<N> FloatOps for Line<N> where N: FloatOps + CubePrimitive {}