cubecl_runtime/memory_management/
memory_lock.rs

1use crate::storage::StorageId;
2use alloc::collections::BTreeSet;
3
4/// A set of storage buffers that are 'locked' and cannot be
5/// used for allocations currently.
6#[derive(Debug)]
7pub struct MemoryLock {
8    locked: BTreeSet<StorageId>,
9    flush_threshold: usize,
10}
11
12impl MemoryLock {
13    /// Create a new memory lock with the given flushing threshold.
14    pub fn new(flush_threshold: usize) -> Self {
15        Self {
16            locked: Default::default(),
17            flush_threshold,
18        }
19    }
20    /// Check whether a particular storage ID is locked currently.
21    pub fn is_locked(&self, storage: &StorageId) -> bool {
22        self.locked.contains(storage)
23    }
24
25    /// Whether the flushing threshold has been reached.
26    pub fn has_reached_threshold(&self) -> bool {
27        // For now we only consider the number of handles locked, but we may consider the amount in
28        // bytes at some point.
29        self.locked.len() >= self.flush_threshold
30    }
31
32    /// Add a storage ID to be locked.
33    pub fn add_locked(&mut self, storage: StorageId) {
34        self.locked.insert(storage);
35    }
36
37    /// Remove all locks at once.
38    pub fn clear_locked(&mut self) {
39        self.locked.clear();
40    }
41}