cubecl_runtime/memory_management/
memory_lock.rs1use crate::storage::StorageId;
2use alloc::collections::BTreeSet;
3
4#[derive(Debug)]
7pub struct MemoryLock {
8 locked: BTreeSet<StorageId>,
9 flush_threshold: usize,
10}
11
12impl MemoryLock {
13 pub fn new(flush_threshold: usize) -> Self {
15 Self {
16 locked: Default::default(),
17 flush_threshold,
18 }
19 }
20 pub fn is_locked(&self, storage: &StorageId) -> bool {
22 self.locked.contains(storage)
23 }
24
25 pub fn has_reached_threshold(&self) -> bool {
27 self.locked.len() >= self.flush_threshold
30 }
31
32 pub fn add_locked(&mut self, storage: StorageId) {
34 self.locked.insert(storage);
35 }
36
37 pub fn clear_locked(&mut self) {
39 self.locked.clear();
40 }
41}