// limen_core/memory/heap_manager.rs
1//! Heap-backed fixed-capacity memory manager.
2//!
3//! Provides a fixed-capacity heap-backed manager that stores `Message<P>`
4//! values in a `Vec<Option<Message<P>>>`. The manager supports the same
5//! `checked-memory-manager-refs` feature as `StaticMemoryManager`:
6//!
7//! - Feature **disabled**: returns plain references (zero-cost).
8//! - Feature **enabled**: returns guard types and keeps per-slot borrow counters.
9
10extern crate alloc;
11
12use alloc::vec::Vec;
13
14use crate::errors::MemoryError;
15use crate::memory::header_store::HeaderStore;
16use crate::memory::manager::MemoryManager;
17use crate::memory::MemoryClass;
18use crate::message::payload::Payload;
19use crate::message::{Message, MessageHeader};
20use crate::types::MessageToken;
21
22#[cfg(feature = "checked-memory-manager-refs")]
23use core::cell::Cell;
24
25#[cfg(feature = "checked-memory-manager-refs")]
26use core::ops::{Deref, DerefMut};
27
/// A heap-backed fixed-capacity memory manager.
///
/// Stores `Message<P>` values in a pre-allocated `Vec<Option<Message<P>>>`;
/// the number of slots is fixed at construction time and never grows.
pub struct HeapMemoryManager<P: Payload> {
    // One slot per message; `None` marks a free slot.
    slots: Vec<Option<Message<P>>>,
    // Per-slot borrow counter: 0 = unborrowed, `u16::MAX` marks an active
    // write borrow, any other value counts active read borrows.
    #[cfg(feature = "checked-memory-manager-refs")]
    borrow_states: Vec<Cell<u16>>,
    // Memory class reported by `memory_class()`; `new` defaults it to `Host`.
    mem_class: MemoryClass,
}
35
36impl<P: Payload> HeapMemoryManager<P> {
37    /// Create a new manager with given fixed `capacity`.
38    ///
39    /// The manager pre-allocates `capacity` slots, all initially `None`.
40    /// `memory_class` defaults to `MemoryClass::Host`.
41    pub fn new(capacity: usize) -> Self {
42        let mut slots = Vec::with_capacity(capacity);
43        for _ in 0..capacity {
44            slots.push(None);
45        }
46
47        #[cfg(feature = "checked-memory-manager-refs")]
48        let borrow_states = {
49            let mut v = Vec::with_capacity(capacity);
50            for _ in 0..capacity {
51                v.push(Cell::new(0));
52            }
53            v
54        };
55
56        Self {
57            slots,
58            #[cfg(feature = "checked-memory-manager-refs")]
59            borrow_states,
60            mem_class: MemoryClass::Host,
61        }
62    }
63
64    /// Create a new manager with given `capacity` and explicit `memory_class`.
65    pub fn with_memory_class(capacity: usize, mem_class: MemoryClass) -> Self {
66        let mut slots = Vec::with_capacity(capacity);
67        for _ in 0..capacity {
68            slots.push(None);
69        }
70
71        #[cfg(feature = "checked-memory-manager-refs")]
72        let borrow_states = {
73            let mut v = Vec::with_capacity(capacity);
74            for _ in 0..capacity {
75                v.push(Cell::new(0));
76            }
77            v
78        };
79
80        Self {
81            slots,
82            #[cfg(feature = "checked-memory-manager-refs")]
83            borrow_states,
84            mem_class,
85        }
86    }
87
88    /// Return the configured capacity.
89    pub fn configured_capacity(&self) -> usize {
90        self.slots.len()
91    }
92}
93
94//
95// ===== Implementation when checked refs are disabled (zero-cost path) =====
96//
97
98#[cfg(not(feature = "checked-memory-manager-refs"))]
99mod unchecked {
100    use super::*;
101
102    impl<P: Payload> HeaderStore for HeapMemoryManager<P> {
103        type HeaderGuard<'a>
104            = &'a MessageHeader
105        where
106            Self: 'a;
107
108        fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
109            let idx = token.index();
110            if idx >= self.slots.len() {
111                return Err(MemoryError::BadToken);
112            }
113            self.slots[idx]
114                .as_ref()
115                .map(|m| m.header())
116                .ok_or(MemoryError::BadToken)
117        }
118    }
119
120    impl<P: Payload> MemoryManager<P> for HeapMemoryManager<P> {
121        type ReadGuard<'a>
122            = &'a Message<P>
123        where
124            Self: 'a;
125
126        type WriteGuard<'a>
127            = &'a mut Message<P>
128        where
129            Self: 'a;
130
131        fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
132            // Find first free slot
133            for (i, slot) in self.slots.iter_mut().enumerate() {
134                if slot.is_none() {
135                    *slot = Some(value);
136                    return Ok(MessageToken::new(i as u32));
137                }
138            }
139            Err(MemoryError::NoFreeSlots)
140        }
141
142        fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
143            let idx = token.index();
144            if idx >= self.slots.len() {
145                return Err(MemoryError::BadToken);
146            }
147            self.slots[idx].as_ref().ok_or(MemoryError::BadToken)
148        }
149
150        fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
151            let idx = token.index();
152            if idx >= self.slots.len() {
153                return Err(MemoryError::BadToken);
154            }
155            self.slots[idx].as_mut().ok_or(MemoryError::BadToken)
156        }
157
158        fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
159            let idx = token.index();
160            if idx >= self.slots.len() {
161                return Err(MemoryError::BadToken);
162            }
163            if self.slots[idx].is_none() {
164                return Err(MemoryError::NotAllocated);
165            }
166            self.slots[idx] = None;
167            Ok(())
168        }
169
170        fn available(&self) -> usize {
171            self.slots.iter().filter(|s| s.is_none()).count()
172        }
173
174        fn capacity(&self) -> usize {
175            self.slots.len()
176        }
177
178        fn memory_class(&self) -> MemoryClass {
179            self.mem_class
180        }
181    }
182}
183
184//
185// ===== Implementation when checked refs are enabled =====
186//
187
#[cfg(feature = "checked-memory-manager-refs")]
mod checked {
    use super::*;

    /// Sentinel borrow-state value marking an active write borrow.
    /// Any other non-zero value counts active read borrows.
    const WRITE_BORROW_MARK: u16 = u16::MAX;

    /// Register one additional read borrow on `cell`.
    ///
    /// Fails with `MemoryError::AlreadyBorrowed` when a write borrow is
    /// active, or when one more read would saturate into the write sentinel.
    fn try_increment_read(cell: &Cell<u16>) -> Result<(), MemoryError> {
        let value = cell.get();
        // A single comparison covers both rejection cases:
        // `WRITE_BORROW_MARK` (write active) and `WRITE_BORROW_MARK - 1`
        // (incrementing would collide with the sentinel).
        if value >= WRITE_BORROW_MARK - 1 {
            return Err(MemoryError::AlreadyBorrowed);
        }
        cell.set(value + 1);
        Ok(())
    }

    /// Drop one read borrow from `cell`.
    ///
    /// Defensive: an underflow (0) or a stray write sentinel resets the
    /// counter to 0 instead of wrapping.
    fn decrement_read(cell: &Cell<u16>) {
        let v = cell.get();
        if v == 0 || v == WRITE_BORROW_MARK {
            cell.set(0);
        } else {
            cell.set(v - 1);
        }
    }

    /// Claim an exclusive write borrow; fails if any borrow is active.
    fn try_set_write(cell: &Cell<u16>) -> Result<(), MemoryError> {
        if cell.get() != 0 {
            return Err(MemoryError::AlreadyBorrowed);
        }
        cell.set(WRITE_BORROW_MARK);
        Ok(())
    }

    /// Release a write borrow.
    fn clear_write(cell: &Cell<u16>) {
        cell.set(0);
    }

    /// Shared guard over a message header; releases one read borrow on drop.
    pub struct HeapHeaderGuard<'a> {
        header: &'a MessageHeader,
        borrow_state: &'a Cell<u16>,
    }

    impl<'a> Deref for HeapHeaderGuard<'a> {
        type Target = MessageHeader;
        fn deref(&self) -> &Self::Target {
            self.header
        }
    }

    impl<'a> Drop for HeapHeaderGuard<'a> {
        fn drop(&mut self) {
            decrement_read(self.borrow_state);
        }
    }

    /// Shared guard over a whole message; releases one read borrow on drop.
    pub struct HeapReadGuard<'a, P: Payload> {
        msg: &'a Message<P>,
        borrow_state: &'a Cell<u16>,
    }

    impl<'a, P: Payload> Deref for HeapReadGuard<'a, P> {
        type Target = Message<P>;
        fn deref(&self) -> &Self::Target {
            self.msg
        }
    }

    impl<'a, P: Payload> Drop for HeapReadGuard<'a, P> {
        fn drop(&mut self) {
            decrement_read(self.borrow_state);
        }
    }

    /// Exclusive guard over a message; clears the write borrow on drop.
    ///
    /// The `&'a mut Message<P>` field already ties this type to `'a` and
    /// makes it invariant over `P`, so no `PhantomData` marker is needed.
    pub struct HeapWriteGuard<'a, P: Payload> {
        msg: &'a mut Message<P>,
        borrow_state: &'a Cell<u16>,
    }

    impl<'a, P: Payload> Deref for HeapWriteGuard<'a, P> {
        type Target = Message<P>;
        fn deref(&self) -> &Self::Target {
            self.msg
        }
    }

    impl<'a, P: Payload> DerefMut for HeapWriteGuard<'a, P> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            self.msg
        }
    }

    impl<'a, P: Payload> Drop for HeapWriteGuard<'a, P> {
        fn drop(&mut self) {
            clear_write(self.borrow_state);
        }
    }

    impl<P: Payload> HeaderStore for HeapMemoryManager<P> {
        type HeaderGuard<'a>
            = HeapHeaderGuard<'a>
        where
            Self: 'a;

        /// Borrow the header of the message stored under `token`,
        /// registering a read borrow on the slot.
        fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
            let idx = token.index();
            if idx >= self.slots.len() {
                return Err(MemoryError::BadToken);
            }

            match self.slots[idx].as_ref() {
                Some(msg) => {
                    try_increment_read(&self.borrow_states[idx])?;
                    Ok(HeapHeaderGuard {
                        header: msg.header(),
                        borrow_state: &self.borrow_states[idx],
                    })
                }
                None => Err(MemoryError::BadToken),
            }
        }
    }

    impl<P: Payload> MemoryManager<P> for HeapMemoryManager<P> {
        type ReadGuard<'a>
            = HeapReadGuard<'a, P>
        where
            Self: 'a;

        type WriteGuard<'a>
            = HeapWriteGuard<'a, P>
        where
            Self: 'a;

        /// Place `value` into the first free slot and return its token.
        fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
            for (i, slot) in self.slots.iter_mut().enumerate() {
                if slot.is_none() {
                    slot.replace(value);
                    // A free slot should already have a zero borrow state;
                    // reset defensively so a stale counter cannot leak in.
                    self.borrow_states[i].set(0);
                    return Ok(MessageToken::new(i as u32));
                }
            }
            Err(MemoryError::NoFreeSlots)
        }

        /// Borrow the message stored under `token`, registering a read
        /// borrow on the slot.
        fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
            let idx = token.index();
            if idx >= self.slots.len() {
                return Err(MemoryError::BadToken);
            }

            match self.slots[idx].as_ref() {
                Some(msg) => {
                    try_increment_read(&self.borrow_states[idx])?;
                    Ok(HeapReadGuard {
                        msg,
                        borrow_state: &self.borrow_states[idx],
                    })
                }
                None => Err(MemoryError::BadToken),
            }
        }

        /// Mutably borrow the message stored under `token`, claiming an
        /// exclusive write borrow on the slot.
        fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
            let idx = token.index();
            if idx >= self.slots.len() {
                return Err(MemoryError::BadToken);
            }

            match self.slots[idx].as_mut() {
                Some(msg) => {
                    try_set_write(&self.borrow_states[idx])?;
                    Ok(HeapWriteGuard {
                        msg,
                        borrow_state: &self.borrow_states[idx],
                    })
                }
                None => Err(MemoryError::BadToken),
            }
        }

        /// Release the slot referenced by `token`.
        ///
        /// Fails with `BorrowActive` while any guard on the slot is alive.
        fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
            let idx = token.index();
            if idx >= self.slots.len() {
                return Err(MemoryError::BadToken);
            }

            if self.slots[idx].is_none() {
                return Err(MemoryError::NotAllocated);
            }

            let state = self.borrow_states[idx].get();
            if state != 0 {
                return Err(MemoryError::BorrowActive);
            }

            self.slots[idx] = None;
            self.borrow_states[idx].set(0);
            Ok(())
        }

        /// Number of currently free slots.
        fn available(&self) -> usize {
            self.slots.iter().filter(|s| s.is_none()).count()
        }

        /// Total number of slots.
        fn capacity(&self) -> usize {
            self.slots.len()
        }

        /// The memory class this manager was configured with.
        fn memory_class(&self) -> MemoryClass {
            self.mem_class
        }
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        message::MessageHeader,
        prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT},
    };

    /// Build a `Message<TestTensor>` whose payload is filled with `val`.
    fn make_msg(val: u32) -> Message<TestTensor> {
        Message::new(MessageHeader::empty(), create_test_tensor_filled_with(val))
    }

    #[test]
    fn store_read_free_cycle() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        assert_eq!(manager.capacity(), 4);
        assert_eq!(manager.available(), 4);

        let tok = manager.store(make_msg(42)).unwrap();
        assert_eq!(manager.available(), 3);

        {
            let guard = manager.read(tok).unwrap();
            assert_eq!(*guard.payload(), create_test_tensor_filled_with(42));
        }

        manager.free(tok).unwrap();
        assert_eq!(manager.available(), 4);
    }

    #[test]
    fn read_mut_works() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        let tok = manager.store(make_msg(10)).unwrap();

        // Overwrite the payload through the write guard.
        {
            let mut guard = manager.read_mut(tok).unwrap();
            let msg = core::ops::DerefMut::deref_mut(&mut guard);
            *msg.payload_mut() = create_test_tensor_filled_with(99);
        }

        // The change must be visible through a subsequent read.
        {
            let guard = manager.read(tok).unwrap();
            assert_eq!(*guard.payload(), create_test_tensor_filled_with(99));
        }

        manager.free(tok).unwrap();
    }

    #[test]
    fn peek_header_works() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        let tok = manager.store(make_msg(7)).unwrap();

        {
            let header = manager.peek_header(tok).unwrap();
            assert_eq!(*header.payload_size_bytes(), TEST_TENSOR_BYTE_COUNT);
        }

        manager.free(tok).unwrap();
    }

    #[test]
    fn capacity_exhaustion() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(2);
        let _first = manager.store(make_msg(1)).unwrap();
        let _second = manager.store(make_msg(2)).unwrap();
        assert_eq!(manager.available(), 0);

        // A third store must be rejected once all slots are occupied.
        assert_eq!(manager.store(make_msg(3)), Err(MemoryError::NoFreeSlots));
    }

    #[test]
    fn double_free_detected() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        let tok = manager.store(make_msg(1)).unwrap();
        manager.free(tok).unwrap();

        // Freeing the same token again must report `NotAllocated`.
        assert_eq!(manager.free(tok), Err(MemoryError::NotAllocated));
    }

    #[test]
    fn bad_token_detected() {
        let manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        let out_of_range = MessageToken::new(99);

        assert!(matches!(manager.read(out_of_range), Err(MemoryError::BadToken)));
        assert!(matches!(
            manager.peek_header(out_of_range),
            Err(MemoryError::BadToken)
        ));
    }

    #[test]
    fn read_freed_slot_is_bad_token() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        let tok = manager.store(make_msg(1)).unwrap();
        manager.free(tok).unwrap();

        // A token to a freed slot behaves like an invalid token.
        assert!(matches!(manager.read(tok), Err(MemoryError::BadToken)));
        assert!(matches!(manager.peek_header(tok), Err(MemoryError::BadToken)));
    }

    #[test]
    fn slot_reuse_after_free() {
        let mut manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        let first = manager.store(make_msg(10)).unwrap();
        manager.free(first).unwrap();

        // The freed slot 0 should be handed out again.
        let second = manager.store(make_msg(20)).unwrap();
        assert_eq!(second.index(), 0);
        assert_eq!(
            *manager.read(second).unwrap().payload(),
            create_test_tensor_filled_with(20)
        );
    }

    #[test]
    fn memory_class_configurable() {
        let manager: HeapMemoryManager<TestTensor> =
            HeapMemoryManager::with_memory_class(4, MemoryClass::Device(0));
        assert_eq!(manager.memory_class(), MemoryClass::Device(0));
    }

    #[test]
    fn default_memory_class_is_host() {
        let manager: HeapMemoryManager<TestTensor> = HeapMemoryManager::new(4);
        assert_eq!(manager.memory_class(), MemoryClass::Host);
    }
}