// cubecl_core/frontend/element/cube_elem.rs

1use cubecl_ir::ExpandElement;
2use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
3
4use crate::ir::{Elem, Scope};
5use crate::{Feature, frontend::CubeType};
6
7use super::{ExpandElementIntoMut, ExpandElementTyped};
8
9/// Form of CubeType that encapsulates all primitive types:
10/// Numeric, UInt, Bool
11pub trait CubePrimitive:
12    CubeType<ExpandType = ExpandElementTyped<Self>>
13    + ExpandElementIntoMut
14    // + IntoRuntime
15    + core::cmp::PartialEq
16    + Send
17    + Sync
18    + 'static
19    + Clone
20    + Copy
21{
22    /// Return the element type to use on GPU.
23    fn as_elem(_scope: &Scope) -> Elem {
24        Self::as_elem_native().expect("To be overridden if not native")
25    }
26
27    /// Native or static element type.
28    fn as_elem_native() -> Option<Elem> {
29        None
30    }
31
32    /// Native or static element type.
33    fn as_elem_native_unchecked() -> Elem {
34        Self::as_elem_native().expect("To be a native type")
35    }
36
37    /// Only native element types have a size.
38    fn size() -> Option<usize> {
39        Self::as_elem_native().map(|t| t.size())
40    }
41
42    /// Only native element types have a size.
43    fn size_bits() -> Option<usize> {
44        Self::as_elem_native().map(|t| t.size_bits())
45    }
46
47    /// Minimum line size of this type
48    fn min_line_size(&self) -> Option<u8> {
49        Self::as_elem_native().map(|t| t.min_line_size())
50    }
51
52
53    fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
54        ExpandElementTyped::new(elem)
55    }
56
57    fn is_supported<S: ComputeServer<Feature = Feature>, C: ComputeChannel<S>>(
58        client: &ComputeClient<S, C>,
59    ) -> bool {
60        let elem = Self::as_elem_native_unchecked();
61        client.properties().feature_enabled(Feature::Type(elem))
62    }
63
64    fn elem_size() -> u32 {
65        Self::as_elem_native_unchecked().size() as u32
66    }
67
68    fn elem_size_bits() -> u32 {
69        Self::as_elem_native_unchecked().size_bits() as u32
70    }
71
72    fn __expand_elem_size(scope: &Scope) -> u32 {
73        Self::as_elem(scope).size() as u32
74    }
75
76    fn __expand_elem_size_bits(scope: &Scope) -> u32 {
77        Self::as_elem(scope).size_bits() as u32
78    }
79}