burn_compute/memory_management/
simple.rs

1use super::{MemoryHandle, MemoryManagement};
2use crate::{
3    memory_id_type,
4    storage::{ComputeStorage, StorageHandle, StorageUtilization},
5};
6use alloc::{sync::Arc, vec::Vec};
7use hashbrown::HashMap;
8
9#[cfg(all(not(target_family = "wasm"), feature = "std"))]
10use std::time;
11#[cfg(all(target_family = "wasm", feature = "std"))]
12use web_time as time;
13
// The ChunkId makes it possible to keep track of how many references there are to a specific chunk.
15memory_id_type!(ChunkId);
// The SliceId makes it possible to keep track of how many references there are to a specific slice.
17memory_id_type!(SliceId);
18
19impl ChunkId {
20    /// A chunk is free if it is only referred by the chunk hashmap.
21    fn is_free(&self) -> bool {
22        Arc::strong_count(&self.id) <= 1
23    }
24}
25
26impl SliceId {
27    /// A slice is free if it is only referred by the slice hashmap and the chunk it is in.
28    fn is_free(&self) -> bool {
29        Arc::strong_count(&self.id) <= 2
30    }
31}
32
/// The SimpleHandle is a memory handle, referring to either a chunk or a slice.
///
/// Cloning a handle clones the id it wraps; the reference count of that id is
/// what the memory management inspects to decide whether the memory is still
/// in use.
#[derive(Debug, Clone)]
pub enum SimpleHandle {
    /// A whole chunk of memory.
    Chunk(ChunkId),
    /// A slice of a chunk of memory.
    Slice(SliceId),
}
41
/// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
#[derive(Debug)]
pub enum DeallocStrategy {
    /// Once every `period` calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the deallocation.
        period: usize,
        /// Current state: a call counter that wraps back to zero when it
        /// reaches `period`. Should start at zero.
        state: usize,
    },
    #[cfg(feature = "std")]
    /// Once every period of time.
    PeriodTime {
        /// Amount of time that must elapse before triggering the deallocation.
        period: time::Duration,
        /// Current state: the instant of the last triggered deallocation.
        /// Should start at now.
        state: time::Instant,
    },
    /// Never deallocate.
    Never,
}
63
/// The strategy defines when an existing chunk may be reused through a slice.
#[derive(Debug)]
pub enum SliceStrategy {
    /// Never use slices.
    Never,
    /// Minimum ratio `reserved size / chunk size` needed before the chunk can
    /// be used as a slice. Between 0 and 1.
    Ratio(f32),
    /// When the reserved memory is at least this many bytes.
    MinimumSize(usize),
    /// When the reserved memory is at most this many bytes.
    MaximumSize(usize),
}
76
77impl SliceStrategy {
78    /// If the chunk can be used with a slice.
79    pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
80        if chunk_size < reserved_size {
81            return false;
82        }
83
84        match self {
85            SliceStrategy::Never => false,
86            SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio,
87            SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes,
88            SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes,
89        }
90    }
91}
92
93impl DeallocStrategy {
94    /// Create a new strategy with the given period.
95    pub fn new_period_tick(period: usize) -> Self {
96        DeallocStrategy::PeriodTick { period, state: 0 }
97    }
98
99    fn should_dealloc(&mut self) -> bool {
100        match self {
101            DeallocStrategy::PeriodTick { period, state } => {
102                *state = (*state + 1) % *period;
103                *state == 0
104            }
105            #[cfg(feature = "std")]
106            DeallocStrategy::PeriodTime { period, state } => {
107                if &state.elapsed() > period {
108                    *state = time::Instant::now();
109                    true
110                } else {
111                    false
112                }
113            }
114            DeallocStrategy::Never => false,
115        }
116    }
117}
118
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct SimpleMemoryManagement<Storage> {
    // Every known chunk: its storage handle and the ids of the slices built on it.
    chunks: HashMap<ChunkId, (StorageHandle, Vec<SliceId>)>,
    // Every known slice: its storage handle and the id of the chunk it is built on.
    slices: HashMap<SliceId, (StorageHandle, ChunkId)>,
    // Decides when free chunks are deallocated from the storage.
    dealloc_strategy: DeallocStrategy,
    // Decides when a free chunk may be reused through a slice.
    slice_strategy: SliceStrategy,
    // Backing storage on which chunks are allocated.
    storage: Storage,
}
127
128impl<Storage> core::fmt::Debug for SimpleMemoryManagement<Storage> {
129    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
130        f.write_str(
131            alloc::format!(
132                "SimpleMemoryManagement {:?} - {:?}",
133                self.dealloc_strategy,
134                core::any::type_name::<Storage>(),
135            )
136            .as_str(),
137        )
138    }
139}
140
141impl MemoryHandle for SimpleHandle {
142    /// Returns true if referenced by only one tensor, and only once by the
143    /// memory management hashmaps
144    fn can_mut(&self) -> bool {
145        // One reference in the chunk hashmap, another owned by one tensor.
146        const REFERENCE_LIMIT_CHUNK: usize = 2;
147        // One reference in the chunk hashmap (for the chunk on which this slice is built),
148        // another in the slice hashmap for this slice, and another owned by one tensor.
149        const REFERENCE_LIMIT_SLICE: usize = 3;
150
151        match &self {
152            SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK,
153            SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE,
154        }
155    }
156}
157
158impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement<Storage> {
159    type Handle = SimpleHandle;
160
161    /// Returns the resource from the storage, for the specified handle.
162    fn get(&mut self, handle: &Self::Handle) -> Storage::Resource {
163        let resource = match &handle {
164            SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0,
165            SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0,
166        };
167
168        self.storage.get(resource)
169    }
170
171    /// Reserves memory of specified size using the reserve algorithm, and return
172    /// a handle to the reserved memory.
173    ///
174    /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy.
175    fn reserve(&mut self, size: usize) -> Self::Handle {
176        self.cleanup_slices();
177
178        let handle = self.reserve_algorithm(size);
179
180        if self.dealloc_strategy.should_dealloc() {
181            self.cleanup_chunks();
182        }
183
184        handle
185    }
186
187    fn alloc(&mut self, size: usize) -> Self::Handle {
188        self.create_chunk(size)
189    }
190
191    fn dealloc(&mut self, handle: &Self::Handle) {
192        match handle {
193            SimpleHandle::Chunk(id) => {
194                if let Some((handle, _slices)) = self.chunks.remove(id) {
195                    self.storage.dealloc(handle.id);
196                }
197            }
198            SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"),
199        }
200    }
201
202    fn storage(&mut self) -> &mut Storage {
203        &mut self.storage
204    }
205}
206
207impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
208    /// Creates a new instance using the given storage, deallocation strategy and slice strategy.
209    pub fn new(
210        storage: Storage,
211        dealloc_strategy: DeallocStrategy,
212        slice_strategy: SliceStrategy,
213    ) -> Self {
214        Self {
215            chunks: HashMap::new(),
216            slices: HashMap::new(),
217            dealloc_strategy,
218            slice_strategy,
219            storage,
220        }
221    }
222
223    fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
224        // Looks for a large enough, existing but unused chunk of memory.
225        let chunk = self.find_free_chunk(size);
226
227        match chunk {
228            Some((chunk_id, chunk_size)) => {
229                if size == chunk_size {
230                    // If there is one of exactly the same size, it reuses it.
231                    SimpleHandle::Chunk(chunk_id.clone())
232                } else {
233                    // Otherwise creates a slice of the right size upon it, always starting at zero.
234                    self.create_slice(size, chunk_id)
235                }
236            }
237            // If no chunk available, creates one of exactly the right size.
238            None => self.create_chunk(size),
239        }
240    }
241
242    /// Finds the smallest of the free and large enough chunks to fit `size`
243    /// Returns the chunk's id and size.
244    fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> {
245        let mut size_diff_current = usize::MAX;
246        let mut current = None;
247
248        for (chunk_id, (resource, slices)) in self.chunks.iter() {
249            // If chunk is already used, we do not choose it
250            if !slices.is_empty() || !chunk_id.is_free() {
251                continue;
252            }
253
254            let resource_size = resource.size();
255
256            // If we find a chunk of exactly the right size, we stop searching altogether
257            if size == resource_size {
258                current = Some((chunk_id, resource));
259                break;
260            }
261
262            // Finds the smallest of the large enough chunks that can accept a slice
263            // of the given size
264            if self.slice_strategy.can_use_chunk(resource_size, size) {
265                let size_diff = resource_size - size;
266
267                if size_diff < size_diff_current {
268                    current = Some((chunk_id, resource));
269                    size_diff_current = size_diff;
270                }
271            }
272        }
273
274        current.map(|(id, handle)| (id.clone(), handle.size()))
275    }
276
277    /// Creates a slice of size `size` upon the given chunk.
278    ///
279    /// For now slices must start at zero, therefore there can be only one per chunk
280    fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle {
281        let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap();
282        let slice_id = SliceId::new();
283
284        let storage = StorageHandle {
285            id: handle.id.clone(),
286            utilization: StorageUtilization::Slice(0, size),
287        };
288
289        if slices.is_empty() {
290            self.slices.insert(slice_id.clone(), (storage, chunk_id));
291        } else {
292            panic!("Can't have more than 1 slice yet.");
293        }
294
295        slices.push(slice_id.clone());
296
297        SimpleHandle::Slice(slice_id)
298    }
299
300    /// Creates a chunk of given size by allocating on the storage.
301    fn create_chunk(&mut self, size: usize) -> SimpleHandle {
302        let resource = self.storage.alloc(size);
303        let chunk_id = ChunkId::new();
304
305        self.chunks.insert(chunk_id.clone(), (resource, Vec::new()));
306
307        SimpleHandle::Chunk(chunk_id)
308    }
309
310    /// Deallocates free chunks and remove them from chunks map.
311    fn cleanup_chunks(&mut self) {
312        let mut ids_to_remove = Vec::new();
313
314        self.chunks.iter().for_each(|(chunk_id, _resource)| {
315            if chunk_id.is_free() {
316                ids_to_remove.push(chunk_id.clone());
317            }
318        });
319
320        ids_to_remove
321            .iter()
322            .map(|chunk_id| self.chunks.remove(chunk_id).unwrap())
323            .for_each(|(resource, _slices)| {
324                self.storage.dealloc(resource.id);
325            });
326    }
327
328    /// Removes free slices from slice map and corresponding chunks.
329    fn cleanup_slices(&mut self) {
330        let mut ids_to_remove = Vec::new();
331
332        self.slices.iter().for_each(|(slice_id, _resource)| {
333            if slice_id.is_free() {
334                ids_to_remove.push(slice_id.clone());
335            }
336        });
337
338        ids_to_remove
339            .iter()
340            .map(|slice_id| {
341                let value = self.slices.remove(slice_id).unwrap();
342                (slice_id, value.1)
343            })
344            .for_each(|(slice_id, chunk_id)| {
345                let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap();
346                slices.retain(|id| id != slice_id);
347            });
348    }
349}
350
#[cfg(test)]
mod tests {
    use crate::{
        memory_management::{MemoryHandle, MemoryManagement, SliceStrategy},
        storage::BytesStorage,
    };

    use super::{DeallocStrategy, SimpleHandle, SimpleMemoryManagement};

    /// A handle held by exactly one owner can be mutated in place.
    #[test]
    fn can_mut_with_single_tensor_reference() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();
        drop(simple_handle);

        assert!(x.can_mut());
    }

    /// Two live clones of the same handle make both of them immutable.
    #[test]
    fn two_tensor_references_remove_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();

        assert!(!simple_handle.can_mut());
        assert!(!x.can_mut())
    }

    /// A chunk that is still referenced cannot be reused by a second reservation.
    #[test]
    fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let _chunk_handle = memory_management.reserve(chunk_size);
        let _new_handle = memory_management.reserve(chunk_size);

        assert_eq!(memory_management.chunks.len(), 2);
    }

    /// Once its last handle is dropped, a chunk is removed by the chunk cleanup.
    #[test]
    fn when_empty_chunk_is_cleaned_up_it_disappears() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let chunk_handle = memory_management.reserve(chunk_size);
        drop(chunk_handle);
        memory_management.cleanup_chunks();

        assert_eq!(memory_management.chunks.len(), 0);
    }

    /// A free, larger chunk is reused through a slice when the strategy accepts it.
    #[test]
    fn free_chunk_is_reused_through_a_slice_when_strategy_allows() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );

        let chunk_handle = memory_management.reserve(8);
        drop(chunk_handle);

        let slice_handle = memory_management.reserve(6);

        assert!(matches!(slice_handle, SimpleHandle::Slice(_)));
        assert_eq!(memory_management.chunks.len(), 1);
        assert_eq!(memory_management.slices.len(), 1);
    }

    #[test]
    fn never_dealloc_strategy_never_deallocs() {
        let mut never_dealloc = DeallocStrategy::Never;
        for _ in 0..20 {
            assert!(!never_dealloc.should_dealloc())
        }
    }

    #[test]
    fn period_tick_dealloc_strategy_should_dealloc_after_period() {
        let period = 3;
        let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period);

        for _ in 0..3 {
            for _ in 0..period - 1 {
                assert!(!period_tick_dealloc.should_dealloc());
            }
            assert!(period_tick_dealloc.should_dealloc());
        }
    }

    #[test]
    fn slice_strategy_never() {
        let strategy = SliceStrategy::Never;

        assert!(!strategy.can_use_chunk(200, 100));
        assert!(!strategy.can_use_chunk(200, 200));
    }

    #[test]
    fn slice_strategy_minimum_bytes() {
        let strategy = SliceStrategy::MinimumSize(100);

        assert!(strategy.can_use_chunk(200, 101));
        // The bound is inclusive.
        assert!(strategy.can_use_chunk(200, 100));
        assert!(!strategy.can_use_chunk(200, 99));
    }

    #[test]
    fn slice_strategy_maximum_bytes() {
        let strategy = SliceStrategy::MaximumSize(100);

        assert!(strategy.can_use_chunk(200, 99));
        // The bound is inclusive.
        assert!(strategy.can_use_chunk(200, 100));
        assert!(!strategy.can_use_chunk(200, 101));
    }

    #[test]
    fn slice_strategy_ratio() {
        let strategy = SliceStrategy::Ratio(0.9);

        assert!(strategy.can_use_chunk(200, 180));
        assert!(!strategy.can_use_chunk(200, 179));
    }

    /// No strategy may select a chunk smaller than the reservation.
    #[test]
    fn slice_strategy_rejects_too_small_chunk() {
        assert!(!SliceStrategy::MinimumSize(1).can_use_chunk(10, 11));
        assert!(!SliceStrategy::MaximumSize(1000).can_use_chunk(10, 11));
        assert!(!SliceStrategy::Ratio(0.1).can_use_chunk(10, 11));
    }
}