cubecl_runtime/memory_management/drop_queue/
policy.rs1use cubecl_common::bytes::Bytes;
2
/// Limits that decide when the buffers registered so far should be flushed.
///
/// A flush is requested as soon as *either* limit is reached (the comparison is `>=`),
/// so whichever threshold is hit first wins.
#[derive(Debug)]
pub struct FlushingPolicy {
    /// Number of registered buffers at (or above) which a flush is requested.
    pub max_bytes_count: u32,
    /// Total length, in bytes, of the registered buffers at (or above) which a flush is
    /// requested.
    pub max_bytes_size: u32,
}
15
16impl Default for FlushingPolicy {
17 fn default() -> Self {
18 Self {
19 max_bytes_count: 64,
20 max_bytes_size: 64 * 1024 * 1024, }
22 }
23}
24
/// Running totals of the buffers registered since the last reset.
///
/// These totals are compared against a `FlushingPolicy` to decide whether a flush is due.
#[derive(Debug, Default)]
pub(crate) struct FlushingPolicyState {
    /// How many buffers have been registered.
    bytes_count: u32,
    /// Sum of the lengths, in bytes, of the registered buffers.
    bytes_size: u32,
}
31
32impl FlushingPolicyState {
33 pub(crate) fn register(&mut self, bytes: &Bytes) {
35 self.bytes_count += 1;
36 self.bytes_size += bytes.len() as u32;
37 }
38
39 pub(crate) fn reset(&mut self) {
41 self.bytes_count = 0;
42 self.bytes_size = 0;
43 }
44
45 pub(crate) fn should_flush(&self, policy: &FlushingPolicy) -> bool {
47 self.bytes_count >= policy.max_bytes_count || self.bytes_size >= policy.max_bytes_size
48 }
49}
50
#[cfg(test)]
mod policy_tests {
    use std::vec;

    use super::*;

    /// Small limits so the thresholds are easy to hit: 4 buffers or 100 bytes.
    fn policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: 4,
            max_bytes_size: 100,
        }
    }

    fn state() -> FlushingPolicyState {
        FlushingPolicyState {
            bytes_count: 0,
            bytes_size: 0,
        }
    }

    #[test]
    fn no_flush_on_fresh_state() {
        let s = state();
        assert!(!s.should_flush(&policy()));
    }

    #[test]
    fn no_flush_when_below_both_thresholds() {
        // 3 buffers (< 4) totalling 99 bytes (< 100).
        let mut s = state();
        for len in [33usize, 33, 33] {
            s.register(&Bytes::from_elems(vec![0u8; len]));
        }
        assert!(!s.should_flush(&policy()));
    }

    #[test]
    fn flush_when_count_threshold_reached() {
        let mut s = state();
        for _ in 0..3 {
            s.register(&Bytes::from_elems(vec![0u8]));
        }
        // One short of the count limit.
        assert!(!s.should_flush(&policy()));
        s.register(&Bytes::from_elems(vec![0u8]));
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn no_flush_just_below_size_threshold() {
        let mut s = state();
        s.register(&Bytes::from_elems(vec![0u8; 99]));
        assert!(!s.should_flush(&policy()));
    }

    #[test]
    fn flush_when_size_threshold_reached_exactly() {
        // The comparison is `>=`, so hitting the limit exactly must flush.
        let mut s = state();
        s.register(&Bytes::from_elems(vec![0u8; 100]));
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn flush_when_size_threshold_reached() {
        let mut s = state();
        s.register(&Bytes::from_elems(vec![0u8; 101]));
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn flush_triggered_by_whichever_limit_comes_first() {
        // Size limit first: 2 buffers (< 4) but 120 bytes (>= 100).
        let mut s = state();
        s.register(&Bytes::from_elems(vec![0u8; 60]));
        assert!(!s.should_flush(&policy()));
        s.register(&Bytes::from_elems(vec![0u8; 60]));
        assert!(s.should_flush(&policy()));

        // Count limit first: 4 buffers (>= 4) but only 4 bytes (< 100).
        let mut s = state();
        for _ in 0..4 {
            s.register(&Bytes::from_elems(vec![0u8]));
        }
        assert!(s.should_flush(&policy()));
    }

    #[test]
    fn reset_clears_state() {
        let mut s = state();
        for _ in 0..4 {
            s.register(&Bytes::from_elems(vec![0u8]));
        }
        assert!(s.should_flush(&policy()));
        s.reset();
        assert!(!s.should_flush(&policy()));

        // Counting restarts from zero after a reset.
        for _ in 0..3 {
            s.register(&Bytes::from_elems(vec![0u8]));
        }
        assert!(!s.should_flush(&policy()));
    }
}