cubecl_runtime/memory_management/
base.rs

1#[cfg(not(feature = "std"))]
2use alloc::string::{String, ToString};
3
/// Amount of memory in use by this allocator
/// and statistics on how much memory is reserved and
/// wasted in total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryUsage {
    /// The number of allocations currently active.
    ///
    /// This is not the number of times an actual allocation happens to create a new memory page,
    /// but really the number of active slices.
    pub number_allocs: u64,
    /// The number of bytes that are currently actually in use.
    ///
    /// This doesn't include any padding or other memory that needs to be
    /// reserved, and is the minimum amount of memory that could possibly
    /// be allocated.
    pub bytes_in_use: u64,
    /// The number of bytes used for padding memory in currently active allocations.
    pub bytes_padding: u64,
    /// The total amount of memory reserved on the device.
    ///
    /// This will be at least as much as `bytes_in_use` but in practice will
    /// be higher, as allocations reserve memory for future allocations
    /// and for padding.
    pub bytes_reserved: u64,
}
29
30impl MemoryUsage {
31    /// Calculate the combined memory usage of two reports (summing them).
32    pub fn combine(&self, other: MemoryUsage) -> MemoryUsage {
33        MemoryUsage {
34            number_allocs: self.number_allocs + other.number_allocs,
35            bytes_in_use: self.bytes_in_use + other.bytes_in_use,
36            bytes_padding: self.bytes_padding + other.bytes_padding,
37            bytes_reserved: self.bytes_reserved + other.bytes_reserved,
38        }
39    }
40}
41
/// Wrapper around a raw byte count whose `Display` implementation renders it
/// as a human-readable size.
#[derive(new)]
pub(crate) struct BytesFormat {
    /// The number of bytes to format.
    bytes: u64,
}
46
47impl core::fmt::Display for BytesFormat {
48    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
49        let unit = 1000;
50
51        if self.bytes < unit {
52            f.write_fmt(format_args!("{} B", self.bytes))
53        } else {
54            let size = self.bytes as f64;
55            let exp = match size.log(1000.0).floor() as usize {
56                0 => 1,
57                e => e,
58            };
59            let unit_prefix = "KMGTPEZY".as_bytes();
60            f.write_fmt(format_args!(
61                "{:.2} {}B",
62                (size / unit.pow(exp as u32) as f64),
63                unit_prefix[exp - 1] as char,
64            ))
65        }
66    }
67}
68
69fn bytes_format(bytes: u64) -> String {
70    BytesFormat::new(bytes).to_string()
71}
72
73impl core::fmt::Display for MemoryUsage {
74    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75        // In the future it'd be nice if MemoryUsage also held some stats about say,
76        // the 5 biggest allocations, to show when you an OOM.
77        let usage_percentage = (self.bytes_in_use as f32 / self.bytes_reserved as f32) * 100.0;
78        let padding_percentage = (self.bytes_padding as f32 / self.bytes_in_use as f32) * 100.0;
79        writeln!(f, "Memory Usage Report:")?;
80        writeln!(f, "  Number of allocations: {}", self.number_allocs)?;
81        writeln!(f, "  Bytes in use: {}", bytes_format(self.bytes_in_use))?;
82        writeln!(
83            f,
84            "  Bytes used for padding: {}",
85            bytes_format(self.bytes_padding)
86        )?;
87        writeln!(
88            f,
89            "  Total bytes reserved: {}",
90            bytes_format(self.bytes_reserved)
91        )?;
92        writeln!(f, "  Usage efficiency: {usage_percentage:.2}%")?;
93        writeln!(f, "  Padding overhead: {padding_percentage:.2}%")
94    }
95}
96
/// The managed tensor buffer handle that points to some memory segment.
/// It should not contain actual data.
///
/// Implementors must be cloneable, sendable and shareable across threads, and
/// debuggable, as required by the supertrait bounds.
pub trait MemoryHandle<Binding>: Clone + Send + Sync + core::fmt::Debug {
    /// Checks if the underlying memory can be safely mutated.
    fn can_mut(&self) -> bool;
    /// Get the binding associated to the current handle.
    ///
    /// This consumes the handle; clone it first if it is still needed afterwards.
    fn binding(self) -> Binding;
}