// cubecl_ir/address.rs

1use crate::{IntKind, Scope, StorageType, UIntKind};
2
/// The type used for addressing storage types in a kernel.
///
/// Inside a kernel, `usize` maps to this type, and `isize` maps to its signed counterpart.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AddressType {
    // The derived `PartialOrd`/`Ord` compare enum variants by discriminant, so the explicit
    // values below fix the ordering `U32 < U64`.
    /// 32-bit addressing (`u32` / `i32`). This is the default.
    #[default]
    U32 = 0,
    /// 64-bit addressing (`u64` / `i64`).
    U64 = 1,
}
13
14impl core::fmt::Display for AddressType {
15    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16        match self {
17            AddressType::U32 => f.write_str("u32"),
18            AddressType::U64 => f.write_str("u64"),
19        }
20    }
21}
22
23impl AddressType {
24    /// Pick an address type based on the number of elements in a buffer.
25    pub fn from_len(num_elems: usize) -> Self {
26        if num_elems > u32::MAX as usize {
27            AddressType::U64
28        } else {
29            AddressType::U32
30        }
31    }
32
33    /// Pick an address type based on the number of elements in a buffer, for a kernel that requires
34    /// signed indices.
35    pub fn from_len_signed(num_elems: usize) -> Self {
36        if num_elems > i32::MAX as usize {
37            AddressType::U64
38        } else {
39            AddressType::U32
40        }
41    }
42
43    pub fn register(&self, scope: &mut Scope) {
44        scope.register_type::<usize>(self.unsigned_type());
45        scope.register_type::<isize>(self.signed_type());
46    }
47
48    pub fn unsigned_type(&self) -> StorageType {
49        match self {
50            AddressType::U32 => UIntKind::U32.into(),
51            AddressType::U64 => UIntKind::U64.into(),
52        }
53    }
54
55    pub fn signed_type(&self) -> StorageType {
56        match self {
57            AddressType::U32 => IntKind::I32.into(),
58            AddressType::U64 => IntKind::I64.into(),
59        }
60    }
61}