use crate::{
frontend::{CubeContext, CubeType},
unexpanded,
};
use super::{CubePrimitive, ExpandElement, ExpandElementTyped, Init, UInt, Vectorized};
/// Wrapper that tags a value as "comptime".
///
/// During kernel expansion the wrapped value is carried around as a plain
/// Rust value rather than as an expand element.
///
/// The field is crate-visible so expansion helpers in this crate can read it directly.
#[derive(Copy, Clone)]
pub struct Comptime<T> {
    /// The wrapped comptime value.
    pub(crate) inner: T,
}
pub trait ComptimeType: CubeType {
fn into_expand(self) -> Self::ExpandType;
}
impl ComptimeType for UInt {
    /// Lifts the `UInt` into a typed expand element.
    ///
    /// The value goes through its `Into` conversion, and the result is wrapped
    /// with `ExpandElementTyped::new`.
    fn into_expand(self) -> Self::ExpandType {
        let element = self.into();
        ExpandElementTyped::new(element)
    }
}
impl<T> Comptime<T> {
    /// Wraps `inner` as a comptime value.
    pub fn new(inner: T) -> Self {
        Self { inner }
    }

    /// Returns the wrapped value.
    ///
    /// This is a cube-code placeholder: its body is `unexpanded!()`, so it is
    /// not meant to be called directly. Presumably the direct call panics and
    /// the cube macro routes it to an expansion-time counterpart; confirm
    /// against the `unexpanded!` macro.
    pub fn get(_comptime: Self) -> T {
        unexpanded!()
    }

    /// Applies `closure` to the wrapped value, producing a new comptime value.
    ///
    /// This is a cube-code placeholder (its body is `unexpanded!()`). The
    /// expansion-time work is presumably done by [`Comptime::__expand_map`].
    ///
    /// The expansion calls the closure exactly once, so any `FnOnce` is accepted.
    pub fn map<R, F: FnOnce(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
        unexpanded!()
    }

    /// Expansion-time version of [`Comptime::map`].
    ///
    /// Calls `closure` once on `inner` and returns its result unwrapped.
    pub fn __expand_map<R, F: FnOnce(T) -> R>(inner: T, closure: F) -> R {
        closure(inner)
    }
}
impl<T: ComptimeType> Comptime<Option<T>> {
    /// Returns whether the wrapped option holds a value, as a comptime `bool`.
    pub fn is_some(comptime: Self) -> Comptime<bool> {
        Comptime::new(comptime.inner.is_some())
    }

    /// Returns the wrapped value, or the result of `alt` when it is `None`.
    ///
    /// This is a cube-code placeholder: its body is `unexpanded!()`. The
    /// expansion-time version is [`Self::__expand_unwrap_or_else`], which is
    /// presumably what the cube macro calls; confirm.
    pub fn unwrap_or_else<F>(_comptime: Self, _alt: F) -> T
    where
        F: FnOnce() -> T,
    {
        unexpanded!()
    }

    /// Expansion-time version of `unwrap_or_else`.
    ///
    /// If `t` is `Some`, the value is lifted with [`ComptimeType::into_expand`].
    /// Otherwise `alt` is invoked with `context` to build the expand value.
    pub fn __expand_unwrap_or_else<F>(
        context: &mut CubeContext,
        t: Option<T>,
        alt: F,
    ) -> <T as CubeType>::ExpandType
    where
        F: FnOnce(&mut CubeContext) -> T::ExpandType,
    {
        match t {
            Some(t) => t.into_expand(),
            None => alt(context),
        }
    }
}
/// During expansion, a `Comptime<T>` is represented by the bare `T` itself
/// rather than by an expand element.
impl<T> CubeType for Comptime<T>
where
    T: Clone + Init,
{
    type ExpandType = T;
}
impl<T> Comptime<T>
where
    T: Vectorized,
{
    /// Returns the vectorization factor of `state` as a comptime `UInt`.
    ///
    /// This is a cube-code placeholder: its body is `unexpanded!()`, so it is
    /// not meant to be called directly.
    pub fn vectorization(_state: &T) -> Comptime<UInt> {
        unexpanded!()
    }

    /// Expansion-time version of `vectorization`.
    ///
    /// Reads the factor through `vectorization_factor`, which is available via
    /// the `Vectorized` bound. The context is not used.
    pub fn __expand_vectorization(_context: &mut CubeContext, state: T) -> UInt {
        state.vectorization_factor()
    }
}
impl<T> Comptime<T>
where
    T: CubePrimitive + Into<ExpandElement>,
{
    /// Converts the wrapped primitive into a regular (non-comptime) value of type `T`.
    ///
    /// This is a cube-code placeholder: its body is `unexpanded!()`, so it is
    /// not meant to be called directly.
    pub fn runtime(_comptime: Self) -> T {
        unexpanded!()
    }

    /// Expansion-time version of `runtime`.
    ///
    /// Converts `inner` into an [`ExpandElement`], then into its typed
    /// [`ExpandElementTyped`] wrapper. The context is not used.
    pub fn __expand_runtime(_context: &mut CubeContext, inner: T) -> ExpandElementTyped<T> {
        <T as Into<ExpandElement>>::into(inner).into()
    }
}
impl<T: core::ops::Add<T, Output = T>> core::ops::Add for Comptime<T> {
type Output = Comptime<T>;
fn add(self, rhs: Self) -> Self::Output {
Comptime::new(self.inner.add(rhs.inner))
}
}
impl<T: core::ops::Sub<T, Output = T>> core::ops::Sub for Comptime<T> {
type Output = Comptime<T>;
fn sub(self, rhs: Self) -> Self::Output {
Comptime::new(self.inner.sub(rhs.inner))
}
}
impl<T: core::ops::Div<T, Output = T>> core::ops::Div for Comptime<T> {
type Output = Comptime<T>;
fn div(self, rhs: Self) -> Self::Output {
Comptime::new(self.inner.div(rhs.inner))
}
}
impl<T: core::ops::Mul<T, Output = T>> core::ops::Mul for Comptime<T> {
type Output = Comptime<T>;
fn mul(self, rhs: Self) -> Self::Output {
Comptime::new(self.inner.mul(rhs.inner))
}
}
impl<T: core::ops::Rem<T, Output = T>> core::ops::Rem for Comptime<T> {
type Output = Comptime<T>;
fn rem(self, rhs: Self) -> Self::Output {
Comptime::new(self.inner.rem(rhs.inner))
}
}
/// Two comptime values are equal when their wrapped values are equal.
///
/// Only `T: PartialEq` is required; equality does not need an ordering, so
/// types without `PartialOrd` can be compared too.
impl<T: core::cmp::PartialEq> core::cmp::PartialEq for Comptime<T> {
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}
/// Orders comptime values by their wrapped values.
///
/// `PartialOrd` already implies `PartialEq`, so no separate equality bound is needed.
impl<T: core::cmp::PartialOrd> core::cmp::PartialOrd for Comptime<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.inner.partial_cmp(&other.inner)
    }
}