Skip to main content

cubecl_ir/
address.rs

1use crate::{IntKind, Scope, StorageType, UIntKind};
2
/// Integer type used to address storage inside a kernel.
///
/// A kernel-side `usize` is lowered to this type, and a kernel-side `isize` is lowered to its
/// signed counterpart.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AddressType {
    // Derived `PartialOrd`/`Ord` compare enum variants by discriminant, so the explicit values
    // below pin the ordering `U32 < U64` regardless of future reordering of the variants.
    /// 32-bit addressing. This is the default.
    #[default]
    U32 = 0,
    /// 64-bit addressing.
    U64 = 1,
}
14
15impl core::fmt::Display for AddressType {
16    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
17        match self {
18            AddressType::U32 => f.write_str("u32"),
19            AddressType::U64 => f.write_str("u64"),
20        }
21    }
22}
23
24impl AddressType {
25    /// Pick an address type based on the number of elements in a buffer.
26    pub fn from_len(num_elems: usize) -> Self {
27        if num_elems > u32::MAX as usize {
28            AddressType::U64
29        } else {
30            AddressType::U32
31        }
32    }
33
34    /// Pick an address type based on the number of elements in a buffer, for a kernel that requires
35    /// signed indices.
36    pub fn from_len_signed(num_elems: usize) -> Self {
37        if num_elems > i32::MAX as usize {
38            AddressType::U64
39        } else {
40            AddressType::U32
41        }
42    }
43
44    pub fn register(&self, scope: &mut Scope) {
45        scope.register_type::<usize>(self.unsigned_type());
46        scope.register_type::<isize>(self.signed_type());
47    }
48
49    pub fn unsigned_type(&self) -> StorageType {
50        match self {
51            AddressType::U32 => UIntKind::U32.into(),
52            AddressType::U64 => UIntKind::U64.into(),
53        }
54    }
55
56    pub fn signed_type(&self) -> StorageType {
57        match self {
58            AddressType::U32 => IntKind::I32.into(),
59            AddressType::U64 => IntKind::I64.into(),
60        }
61    }
62}