Skip to main content

cubecl_runtime/memory_management/drop_queue/policy.rs

1use cubecl_common::bytes::Bytes;
2
/// Defines the thresholds that determine when a [`PendingDropQueue`] should be
/// flushed.
///
/// A flush is triggered as soon as **either** limit is reached (the staged
/// value is greater than or equal to the limit) — whichever comes first.
/// Both limits are `u32`; set a field to `u32::MAX` to effectively disable
/// that limit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FlushingPolicy {
    /// Flush when this many allocations have been staged.
    pub max_bytes_count: u32,
    /// Flush when the total staged size reaches this many bytes.
    pub max_bytes_size: u32,
}
15
16impl Default for FlushingPolicy {
17    fn default() -> Self {
18        Self {
19            max_bytes_count: 64,
20            max_bytes_size: 64 * 1024 * 1024, // 64 MiB
21        }
22    }
23}
24
/// Running totals of the allocations staged since the last reset, compared
/// against a [`FlushingPolicy`] to decide whether a flush is due.
#[derive(Debug, Default)]
pub(crate) struct FlushingPolicyState {
    /// Number of allocations registered since the last reset.
    bytes_count: u32,
    /// Accumulated length, in bytes, of the allocations registered since the
    /// last reset.
    bytes_size: u32,
}
31
32impl FlushingPolicyState {
33    /// Record a newly staged [`Bytes`] allocation.
34    pub(crate) fn register(&mut self, bytes: &Bytes) {
35        self.bytes_count += 1;
36        self.bytes_size += bytes.len() as u32;
37    }
38
39    /// Reset all counters, typically called after a flush.
40    pub(crate) fn reset(&mut self) {
41        self.bytes_count = 0;
42        self.bytes_size = 0;
43    }
44
45    /// Returns `true` if either threshold in `policy` has been reached.
46    pub(crate) fn should_flush(&self, policy: &FlushingPolicy) -> bool {
47        self.bytes_count >= policy.max_bytes_count || self.bytes_size >= policy.max_bytes_size
48    }
49}
50
#[cfg(test)]
mod policy_tests {
    //! Unit tests for `FlushingPolicyState` evaluated against a small
    //! `FlushingPolicy` (4 allocations / 100 bytes), including the exact
    //! boundary of each threshold.

    use std::vec;

    use super::*;

    /// A small policy so both limits are easy to hit: 4 allocations or 100 bytes.
    fn policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: 4,
            max_bytes_size: 100,
        }
    }

    /// An empty state with both counters at zero.
    fn state() -> FlushingPolicyState {
        FlushingPolicyState {
            bytes_count: 0,
            bytes_size: 0,
        }
    }

    /// Stage one zero-filled allocation of `len` bytes.
    fn stage(s: &mut FlushingPolicyState, len: usize) {
        s.register(&Bytes::from_elems(vec![0u8; len]));
    }

    #[test]
    fn no_flush_when_below_both_thresholds() {
        let s = state();
        assert!(!s.should_flush(&policy()));
    }

    #[test]
    fn no_flush_just_below_both_thresholds() {
        let mut s = state();
        // 3 allocations (< 4) totalling 99 bytes (< 100).
        for _ in 0..3 {
            stage(&mut s, 33);
        }
        assert!(!s.should_flush(&policy()));
    }

    #[test]
    fn flush_when_count_threshold_reached() {
        let mut s = state();
        for _ in 0..4 {
            stage(&mut s, 1);
        }
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn flush_when_size_threshold_exactly_reached() {
        let mut s = state();
        stage(&mut s, 100);
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn flush_when_size_threshold_exceeded() {
        let mut s = state();
        stage(&mut s, 101);
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn flush_triggered_by_whichever_limit_comes_first() {
        let mut s = state();
        // Only 2 allocations but already over the size limit.
        stage(&mut s, 60);
        stage(&mut s, 60);
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn reset_clears_state() {
        let mut s = state();
        for _ in 0..4 {
            stage(&mut s, 1);
        }
        assert!(s.should_flush(&policy()));
        s.reset();
        assert!(!s.should_flush(&policy()));
        // Counting restarts from zero: 3 allocations are again below the limit.
        for _ in 0..3 {
            stage(&mut s, 1);
        }
        assert!(!s.should_flush(&policy()));
    }

    #[test]
    fn default_policy_thresholds() {
        let p = FlushingPolicy::default();
        assert_eq!(p.max_bytes_count, 64);
        assert_eq!(p.max_bytes_size, 64 * 1024 * 1024);
    }
}