1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
use crate::{
    frontend::{CubeContext, CubeType},
    unexpanded,
};

use super::{CubePrimitive, ExpandElement, ExpandElementTyped, Init, UInt, Vectorized};

/// Wraps a value to signal that it must be consumed at compilation time instead of being used
/// inside the kernel.
///
/// Wrap an optional value (`Comptime<Option<T>>`) to provide an alternate runtime behaviour
/// for the case where the compilation time value is absent.
#[derive(Clone, Copy)]
pub struct Comptime<T> {
    /// The wrapped compile-time value.
    pub(crate) inner: T,
}

/// Type that can be used within [Comptime].
pub trait ComptimeType: CubeType {
    /// Create the expand type from the normal type.
    fn into_expand(self) -> Self::ExpandType;
}

impl ComptimeType for UInt {
    /// Converts the value with `Into` and wraps the result with `ExpandElementTyped::new`.
    fn into_expand(self) -> Self::ExpandType {
        // The conversion target is inferred from the parameter type of `new`.
        let converted = self.into();
        ExpandElementTyped::new(converted)
    }
}

impl<T> Comptime<T> {
    /// Create a new Comptime. Useful when hardcoding values in
    /// Cube kernels. For instance:
    /// if Comptime::new(false) {...} never generates the inner code block
    pub fn new(inner: T) -> Self {
        Self { inner }
    }

    /// Get the inner value of a Comptime. For instance:
    /// let c = Comptime::new(false);
    /// if Comptime::get(c) {...}
    ///
    /// The body of this function is only the `unexpanded!()` placeholder.
    pub fn get(_comptime: Self) -> T {
        unexpanded!()
    }

    /// Executes a closure on the comptime and returns a new comptime containing the value.
    ///
    /// The closure only has to be `FnOnce`, so closures that consume their captures are
    /// accepted too. The body of this function is only the `unexpanded!()` placeholder;
    /// the closure is actually invoked by [`Comptime::__expand_map`].
    pub fn map<R, F: FnOnce(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
        unexpanded!()
    }

    /// Expanded version of map: calls `closure` exactly once with the unwrapped `inner`
    /// value and returns its result as is (not wrapped in a `Comptime`).
    pub fn __expand_map<R, F: FnOnce(T) -> R>(inner: T, closure: F) -> R {
        closure(inner)
    }
}

impl<T: ComptimeType> Comptime<Option<T>> {
    /// Map a Comptime optional to a Comptime boolean that tells
    /// whether the optional contained a value
    pub fn is_some(comptime: Self) -> Comptime<bool> {
        Comptime::new(comptime.inner.is_some())
    }

    /// Return the inner value of the Comptime if it exists,
    /// otherwise tell how to compute it at runtime
    ///
    /// The body of this function is only the `unexpanded!()` placeholder; see
    /// [`Comptime::__expand_unwrap_or_else`] for the expanded version.
    pub fn unwrap_or_else<F>(_comptime: Self, _alt: F) -> T
    where
        F: FnOnce() -> T,
    {
        unexpanded!()
    }

    /// Expanded version of unwrap_or_else
    ///
    /// When `t` holds a value it is converted with [`ComptimeType::into_expand`] and `alt`
    /// is never called; otherwise `alt` is called once with `context` and its result is
    /// returned.
    pub fn __expand_unwrap_or_else<F>(
        context: &mut CubeContext,
        t: Option<T>,
        alt: F,
    ) -> <T as CubeType>::ExpandType
    where
        F: FnOnce(&mut CubeContext) -> T::ExpandType,
    {
        t.map_or_else(|| alt(context), T::into_expand)
    }
}

/// The expand type of a `Comptime<T>` is the wrapped type `T` itself.
impl<T> CubeType for Comptime<T>
where
    T: Clone + Init,
{
    type ExpandType = T;
}

impl<T> Comptime<T>
where
    T: Vectorized,
{
    /// Vectorization of `_state` as a comptime value.
    ///
    /// The body of this function is only the `unexpanded!()` placeholder.
    pub fn vectorization(_state: &T) -> Comptime<UInt> {
        unexpanded!()
    }

    /// Expanded version of vectorization: returns `state.vectorization_factor()`.
    /// The context is not used.
    pub fn __expand_vectorization(_context: &mut CubeContext, state: T) -> UInt {
        let factor = state.vectorization_factor();
        factor
    }
}

impl<T> Comptime<T>
where
    T: CubePrimitive + Into<ExpandElement>,
{
    /// Turns a comptime value into a plain `T`.
    ///
    /// The body of this function is only the `unexpanded!()` placeholder.
    pub fn runtime(_comptime: Self) -> T {
        unexpanded!()
    }

    /// Expanded version of runtime: converts `inner` into an [ExpandElement] and then into
    /// the typed expand element. The context is not used.
    pub fn __expand_runtime(_context: &mut CubeContext, inner: T) -> ExpandElementTyped<T> {
        Into::<ExpandElement>::into(inner).into()
    }
}

impl<T: core::ops::Add<T, Output = T>> core::ops::Add for Comptime<T> {
    type Output = Comptime<T>;

    fn add(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.add(rhs.inner))
    }
}

impl<T: core::ops::Sub<T, Output = T>> core::ops::Sub for Comptime<T> {
    type Output = Comptime<T>;

    fn sub(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.sub(rhs.inner))
    }
}

impl<T: core::ops::Div<T, Output = T>> core::ops::Div for Comptime<T> {
    type Output = Comptime<T>;

    fn div(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.div(rhs.inner))
    }
}

impl<T: core::ops::Mul<T, Output = T>> core::ops::Mul for Comptime<T> {
    type Output = Comptime<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.mul(rhs.inner))
    }
}

impl<T: core::ops::Rem<T, Output = T>> core::ops::Rem for Comptime<T> {
    type Output = Comptime<T>;

    fn rem(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.rem(rhs.inner))
    }
}

/// Two comptimes are equal exactly when their wrapped values are equal.
///
/// Only `T: PartialEq` is required: equality never needs an ordering, so types that are
/// comparable for equality but not ordered can be compared as well.
impl<T> core::cmp::PartialEq for Comptime<T>
where
    T: core::cmp::PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}

/// Comptimes are ordered like their wrapped values.
impl<T> core::cmp::PartialOrd for Comptime<T>
where
    T: core::cmp::PartialOrd + core::cmp::PartialEq,
{
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.inner.partial_cmp(&other.inner)
    }
}