burn_compute/memory_management/
simple.rs1use super::{MemoryHandle, MemoryManagement};
2use crate::{
3 memory_id_type,
4 storage::{ComputeStorage, StorageHandle, StorageUtilization},
5};
6use alloc::{sync::Arc, vec::Vec};
7use hashbrown::HashMap;
8
9#[cfg(all(not(target_family = "wasm"), feature = "std"))]
10use std::time;
11#[cfg(all(target_family = "wasm", feature = "std"))]
12use web_time as time;
13
14memory_id_type!(ChunkId);
16memory_id_type!(SliceId);
18
19impl ChunkId {
20 fn is_free(&self) -> bool {
22 Arc::strong_count(&self.id) <= 1
23 }
24}
25
26impl SliceId {
27 fn is_free(&self) -> bool {
29 Arc::strong_count(&self.id) <= 2
30 }
31}
32
33#[derive(Debug, Clone)]
35pub enum SimpleHandle {
36 Chunk(ChunkId),
38 Slice(SliceId),
40}
41
/// How often [`SimpleMemoryManagement`] returns unused chunks to its storage.
///
/// The strategy is consulted once per call to `reserve`.
#[derive(Debug)]
pub enum DeallocStrategy {
    /// Deallocate once every `period` calls to `reserve`.
    PeriodTick {
        /// Number of calls between two deallocations.
        period: usize,
        /// Calls counted since the last deallocation; start it at `0`.
        state: usize,
    },
    /// Deallocate once more than `period` has elapsed since the last deallocation.
    #[cfg(feature = "std")]
    PeriodTime {
        /// Minimum time between two deallocations.
        period: time::Duration,
        /// Moment of the last deallocation; start it at the current instant.
        state: time::Instant,
    },
    /// Never deallocate automatically.
    Never,
}
63
/// Decides when a free chunk larger than the requested size may be reused by
/// carving a slice out of it.
#[derive(Debug)]
pub enum SliceStrategy {
    /// Never slice a chunk.
    Never,
    /// Slice when `reserved_size / chunk_size` is at least this ratio. Since the
    /// reserved size never exceeds the chunk size, meaningful values lie in `0.0..=1.0`.
    Ratio(f32),
    /// Slice when the reserved size is at least this value.
    MinimumSize(usize),
    /// Slice when the reserved size is at most this value.
    MaximumSize(usize),
}

impl SliceStrategy {
    /// Returns whether a chunk of `chunk_size` may hold a slice of `reserved_size`.
    ///
    /// A chunk smaller than the reservation is always rejected, whatever the strategy.
    pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
        // A slice can never be larger than the chunk it lives in.
        let fits = reserved_size <= chunk_size;

        fits && match *self {
            Self::Never => false,
            Self::Ratio(min_ratio) => reserved_size as f32 / chunk_size as f32 >= min_ratio,
            Self::MinimumSize(min_size) => reserved_size >= min_size,
            Self::MaximumSize(max_size) => reserved_size <= max_size,
        }
    }
}
92
93impl DeallocStrategy {
94 pub fn new_period_tick(period: usize) -> Self {
96 DeallocStrategy::PeriodTick { period, state: 0 }
97 }
98
99 fn should_dealloc(&mut self) -> bool {
100 match self {
101 DeallocStrategy::PeriodTick { period, state } => {
102 *state = (*state + 1) % *period;
103 *state == 0
104 }
105 #[cfg(feature = "std")]
106 DeallocStrategy::PeriodTime { period, state } => {
107 if &state.elapsed() > period {
108 *state = time::Instant::now();
109 true
110 } else {
111 false
112 }
113 }
114 DeallocStrategy::Never => false,
115 }
116 }
117}
118
119pub struct SimpleMemoryManagement<Storage> {
121 chunks: HashMap<ChunkId, (StorageHandle, Vec<SliceId>)>,
122 slices: HashMap<SliceId, (StorageHandle, ChunkId)>,
123 dealloc_strategy: DeallocStrategy,
124 slice_strategy: SliceStrategy,
125 storage: Storage,
126}
127
128impl<Storage> core::fmt::Debug for SimpleMemoryManagement<Storage> {
129 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
130 f.write_str(
131 alloc::format!(
132 "SimpleMemoryManagement {:?} - {:?}",
133 self.dealloc_strategy,
134 core::any::type_name::<Storage>(),
135 )
136 .as_str(),
137 )
138 }
139}
140
141impl MemoryHandle for SimpleHandle {
142 fn can_mut(&self) -> bool {
145 const REFERENCE_LIMIT_CHUNK: usize = 2;
147 const REFERENCE_LIMIT_SLICE: usize = 3;
150
151 match &self {
152 SimpleHandle::Chunk(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_CHUNK,
153 SimpleHandle::Slice(id) => Arc::strong_count(&id.id) <= REFERENCE_LIMIT_SLICE,
154 }
155 }
156}
157
158impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement<Storage> {
159 type Handle = SimpleHandle;
160
161 fn get(&mut self, handle: &Self::Handle) -> Storage::Resource {
163 let resource = match &handle {
164 SimpleHandle::Chunk(id) => &self.chunks.get(id).unwrap().0,
165 SimpleHandle::Slice(id) => &self.slices.get(id).unwrap().0,
166 };
167
168 self.storage.get(resource)
169 }
170
171 fn reserve(&mut self, size: usize) -> Self::Handle {
176 self.cleanup_slices();
177
178 let handle = self.reserve_algorithm(size);
179
180 if self.dealloc_strategy.should_dealloc() {
181 self.cleanup_chunks();
182 }
183
184 handle
185 }
186
187 fn alloc(&mut self, size: usize) -> Self::Handle {
188 self.create_chunk(size)
189 }
190
191 fn dealloc(&mut self, handle: &Self::Handle) {
192 match handle {
193 SimpleHandle::Chunk(id) => {
194 if let Some((handle, _slices)) = self.chunks.remove(id) {
195 self.storage.dealloc(handle.id);
196 }
197 }
198 SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"),
199 }
200 }
201
202 fn storage(&mut self) -> &mut Storage {
203 &mut self.storage
204 }
205}
206
207impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
208 pub fn new(
210 storage: Storage,
211 dealloc_strategy: DeallocStrategy,
212 slice_strategy: SliceStrategy,
213 ) -> Self {
214 Self {
215 chunks: HashMap::new(),
216 slices: HashMap::new(),
217 dealloc_strategy,
218 slice_strategy,
219 storage,
220 }
221 }
222
223 fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
224 let chunk = self.find_free_chunk(size);
226
227 match chunk {
228 Some((chunk_id, chunk_size)) => {
229 if size == chunk_size {
230 SimpleHandle::Chunk(chunk_id.clone())
232 } else {
233 self.create_slice(size, chunk_id)
235 }
236 }
237 None => self.create_chunk(size),
239 }
240 }
241
242 fn find_free_chunk(&self, size: usize) -> Option<(ChunkId, usize)> {
245 let mut size_diff_current = usize::MAX;
246 let mut current = None;
247
248 for (chunk_id, (resource, slices)) in self.chunks.iter() {
249 if !slices.is_empty() || !chunk_id.is_free() {
251 continue;
252 }
253
254 let resource_size = resource.size();
255
256 if size == resource_size {
258 current = Some((chunk_id, resource));
259 break;
260 }
261
262 if self.slice_strategy.can_use_chunk(resource_size, size) {
265 let size_diff = resource_size - size;
266
267 if size_diff < size_diff_current {
268 current = Some((chunk_id, resource));
269 size_diff_current = size_diff;
270 }
271 }
272 }
273
274 current.map(|(id, handle)| (id.clone(), handle.size()))
275 }
276
277 fn create_slice(&mut self, size: usize, chunk_id: ChunkId) -> SimpleHandle {
281 let (handle, slices) = self.chunks.get_mut(&chunk_id).unwrap();
282 let slice_id = SliceId::new();
283
284 let storage = StorageHandle {
285 id: handle.id.clone(),
286 utilization: StorageUtilization::Slice(0, size),
287 };
288
289 if slices.is_empty() {
290 self.slices.insert(slice_id.clone(), (storage, chunk_id));
291 } else {
292 panic!("Can't have more than 1 slice yet.");
293 }
294
295 slices.push(slice_id.clone());
296
297 SimpleHandle::Slice(slice_id)
298 }
299
300 fn create_chunk(&mut self, size: usize) -> SimpleHandle {
302 let resource = self.storage.alloc(size);
303 let chunk_id = ChunkId::new();
304
305 self.chunks.insert(chunk_id.clone(), (resource, Vec::new()));
306
307 SimpleHandle::Chunk(chunk_id)
308 }
309
310 fn cleanup_chunks(&mut self) {
312 let mut ids_to_remove = Vec::new();
313
314 self.chunks.iter().for_each(|(chunk_id, _resource)| {
315 if chunk_id.is_free() {
316 ids_to_remove.push(chunk_id.clone());
317 }
318 });
319
320 ids_to_remove
321 .iter()
322 .map(|chunk_id| self.chunks.remove(chunk_id).unwrap())
323 .for_each(|(resource, _slices)| {
324 self.storage.dealloc(resource.id);
325 });
326 }
327
328 fn cleanup_slices(&mut self) {
330 let mut ids_to_remove = Vec::new();
331
332 self.slices.iter().for_each(|(slice_id, _resource)| {
333 if slice_id.is_free() {
334 ids_to_remove.push(slice_id.clone());
335 }
336 });
337
338 ids_to_remove
339 .iter()
340 .map(|slice_id| {
341 let value = self.slices.remove(slice_id).unwrap();
342 (slice_id, value.1)
343 })
344 .for_each(|(slice_id, chunk_id)| {
345 let (_chunk, slices) = self.chunks.get_mut(&chunk_id).unwrap();
346 slices.retain(|id| id != slice_id);
347 });
348 }
349}
350
#[cfg(test)]
mod tests {
    use crate::{
        memory_management::{MemoryHandle, MemoryManagement, SliceStrategy},
        storage::BytesStorage,
    };

    use super::{DeallocStrategy, SimpleHandle, SimpleMemoryManagement};

    #[test]
    fn can_mut_with_single_tensor_reference() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();
        core::mem::drop(simple_handle);

        assert!(x.can_mut());
    }

    #[test]
    fn two_tensor_references_remove_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();

        assert!(!simple_handle.can_mut());
        assert!(!x.can_mut());
    }

    #[test]
    fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let _chunk_handle = memory_management.reserve(chunk_size);
        let _new_handle = memory_management.reserve(chunk_size);

        assert_eq!(memory_management.chunks.len(), 2);
    }

    #[test]
    fn when_empty_chunk_is_cleaned_up_it_disappears() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let chunk_handle = memory_management.reserve(chunk_size);
        drop(chunk_handle);
        memory_management.cleanup_chunks();

        assert_eq!(memory_management.chunks.len(), 0);
    }

    #[test]
    fn when_free_chunk_has_exact_size_it_is_reused() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let chunk_handle = memory_management.reserve(chunk_size);
        drop(chunk_handle);
        let _reused_handle = memory_management.reserve(chunk_size);

        assert_eq!(memory_management.chunks.len(), 1);
    }

    #[test]
    fn when_free_chunk_is_bigger_it_is_reused_as_a_slice() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.9),
        );
        let chunk_handle = memory_management.reserve(100);
        drop(chunk_handle);
        let slice_handle = memory_management.reserve(95);

        assert!(matches!(slice_handle, SimpleHandle::Slice(_)));
        assert!(slice_handle.can_mut());
        assert_eq!(memory_management.chunks.len(), 1);
        assert_eq!(memory_management.slices.len(), 1);
    }

    #[test]
    fn never_dealloc_strategy_never_deallocs() {
        let mut never_dealloc = DeallocStrategy::Never;
        for _ in 0..20 {
            assert!(!never_dealloc.should_dealloc())
        }
    }

    #[test]
    fn period_tick_dealloc_strategy_should_dealloc_after_period() {
        let period = 3;
        let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period);

        for _ in 0..3 {
            for _ in 0..period - 1 {
                assert!(!period_tick_dealloc.should_dealloc());
            }
            assert!(period_tick_dealloc.should_dealloc());
        }
    }

    #[test]
    fn slice_strategy_minimum_bytes() {
        let strategy = SliceStrategy::MinimumSize(100);

        assert!(strategy.can_use_chunk(200, 101));
        assert!(!strategy.can_use_chunk(200, 99));
    }

    #[test]
    fn slice_strategy_maximum_bytes() {
        let strategy = SliceStrategy::MaximumSize(100);

        assert!(strategy.can_use_chunk(200, 99));
        assert!(!strategy.can_use_chunk(200, 101));
    }

    #[test]
    fn slice_strategy_ratio() {
        let strategy = SliceStrategy::Ratio(0.9);

        assert!(strategy.can_use_chunk(200, 180));
        assert!(!strategy.can_use_chunk(200, 179));
    }
}
467}