1extern crate alloc;
11
12use alloc::vec::Vec;
13
14use crate::errors::MemoryError;
15use crate::memory::header_store::HeaderStore;
16use crate::memory::manager::MemoryManager;
17use crate::memory::MemoryClass;
18use crate::message::payload::Payload;
19use crate::message::{Message, MessageHeader};
20use crate::types::MessageToken;
21
22#[cfg(feature = "checked-memory-manager-refs")]
23use core::cell::Cell;
24
25#[cfg(feature = "checked-memory-manager-refs")]
26use core::ops::{Deref, DerefMut};
27
28pub struct HeapMemoryManager<P: Payload> {
30 slots: Vec<Option<Message<P>>>,
31 #[cfg(feature = "checked-memory-manager-refs")]
32 borrow_states: Vec<Cell<u16>>,
33 mem_class: MemoryClass,
34}
35
36impl<P: Payload> HeapMemoryManager<P> {
37 pub fn new(capacity: usize) -> Self {
42 let mut slots = Vec::with_capacity(capacity);
43 for _ in 0..capacity {
44 slots.push(None);
45 }
46
47 #[cfg(feature = "checked-memory-manager-refs")]
48 let borrow_states = {
49 let mut v = Vec::with_capacity(capacity);
50 for _ in 0..capacity {
51 v.push(Cell::new(0));
52 }
53 v
54 };
55
56 Self {
57 slots,
58 #[cfg(feature = "checked-memory-manager-refs")]
59 borrow_states,
60 mem_class: MemoryClass::Host,
61 }
62 }
63
64 pub fn with_memory_class(capacity: usize, mem_class: MemoryClass) -> Self {
66 let mut slots = Vec::with_capacity(capacity);
67 for _ in 0..capacity {
68 slots.push(None);
69 }
70
71 #[cfg(feature = "checked-memory-manager-refs")]
72 let borrow_states = {
73 let mut v = Vec::with_capacity(capacity);
74 for _ in 0..capacity {
75 v.push(Cell::new(0));
76 }
77 v
78 };
79
80 Self {
81 slots,
82 #[cfg(feature = "checked-memory-manager-refs")]
83 borrow_states,
84 mem_class,
85 }
86 }
87
88 pub fn configured_capacity(&self) -> usize {
90 self.slots.len()
91 }
92}
93
94#[cfg(not(feature = "checked-memory-manager-refs"))]
99mod unchecked {
100 use super::*;
101
102 impl<P: Payload> HeaderStore for HeapMemoryManager<P> {
103 type HeaderGuard<'a>
104 = &'a MessageHeader
105 where
106 Self: 'a;
107
108 fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
109 let idx = token.index();
110 if idx >= self.slots.len() {
111 return Err(MemoryError::BadToken);
112 }
113 self.slots[idx]
114 .as_ref()
115 .map(|m| m.header())
116 .ok_or(MemoryError::BadToken)
117 }
118 }
119
120 impl<P: Payload> MemoryManager<P> for HeapMemoryManager<P> {
121 type ReadGuard<'a>
122 = &'a Message<P>
123 where
124 Self: 'a;
125
126 type WriteGuard<'a>
127 = &'a mut Message<P>
128 where
129 Self: 'a;
130
131 fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
132 for (i, slot) in self.slots.iter_mut().enumerate() {
134 if slot.is_none() {
135 *slot = Some(value);
136 return Ok(MessageToken::new(i as u32));
137 }
138 }
139 Err(MemoryError::NoFreeSlots)
140 }
141
142 fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
143 let idx = token.index();
144 if idx >= self.slots.len() {
145 return Err(MemoryError::BadToken);
146 }
147 self.slots[idx].as_ref().ok_or(MemoryError::BadToken)
148 }
149
150 fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
151 let idx = token.index();
152 if idx >= self.slots.len() {
153 return Err(MemoryError::BadToken);
154 }
155 self.slots[idx].as_mut().ok_or(MemoryError::BadToken)
156 }
157
158 fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
159 let idx = token.index();
160 if idx >= self.slots.len() {
161 return Err(MemoryError::BadToken);
162 }
163 if self.slots[idx].is_none() {
164 return Err(MemoryError::NotAllocated);
165 }
166 self.slots[idx] = None;
167 Ok(())
168 }
169
170 fn available(&self) -> usize {
171 self.slots.iter().filter(|s| s.is_none()).count()
172 }
173
174 fn capacity(&self) -> usize {
175 self.slots.len()
176 }
177
178 fn memory_class(&self) -> MemoryClass {
179 self.mem_class
180 }
181 }
182}
183
/// Runtime-checked implementation, compiled when the
/// `checked-memory-manager-refs` feature is enabled.
///
/// Every slot has a `Cell<u16>` borrow state next to it: `0` means the slot is
/// not borrowed, [`WRITE_BORROW_MARK`] means one write guard is live, and any
/// other value is the number of live read/header guards. Guards update that
/// state when they are created and undo the update when they are dropped.
#[cfg(feature = "checked-memory-manager-refs")]
mod checked {
    use super::*;

    /// Borrow-state value reserved for "a write guard is live".
    const WRITE_BORROW_MARK: u16 = u16::MAX;

    /// Registers one more reader on `cell`.
    ///
    /// Fails with `MemoryError::AlreadyBorrowed` when a write borrow is active
    /// or when the reader count is saturated (one more reader would make the
    /// count collide with [`WRITE_BORROW_MARK`]).
    fn try_increment_read(cell: &Cell<u16>) -> Result<(), MemoryError> {
        let readers = cell.get();
        if readers >= WRITE_BORROW_MARK - 1 {
            return Err(MemoryError::AlreadyBorrowed);
        }
        cell.set(readers + 1);
        Ok(())
    }

    /// Unregisters one reader from `cell`.
    ///
    /// A state of `0` or [`WRITE_BORROW_MARK`] is not a valid reader count;
    /// in that case the state is reset to `0` instead of being decremented.
    fn decrement_read(cell: &Cell<u16>) {
        let next = match cell.get() {
            0 | WRITE_BORROW_MARK => 0,
            readers => readers - 1,
        };
        cell.set(next);
    }

    /// Marks `cell` as write-borrowed.
    ///
    /// Fails with `MemoryError::AlreadyBorrowed` if any borrow (read or write)
    /// is already recorded.
    fn try_set_write(cell: &Cell<u16>) -> Result<(), MemoryError> {
        if cell.get() == 0 {
            cell.set(WRITE_BORROW_MARK);
            Ok(())
        } else {
            Err(MemoryError::AlreadyBorrowed)
        }
    }

    /// Clears the borrow state of `cell` back to "not borrowed".
    fn clear_write(cell: &Cell<u16>) {
        cell.set(0);
    }

    /// Read guard over a stored message's header; counts as one reader of the
    /// slot until dropped.
    pub struct HeapHeaderGuard<'a> {
        /// Header of the message in the guarded slot.
        header: &'a MessageHeader,
        /// Borrow state of the guarded slot.
        borrow_state: &'a Cell<u16>,
    }

    impl<'a> Deref for HeapHeaderGuard<'a> {
        type Target = MessageHeader;

        fn deref(&self) -> &Self::Target {
            self.header
        }
    }

    impl<'a> Drop for HeapHeaderGuard<'a> {
        /// Releases this guard's reader registration.
        fn drop(&mut self) {
            decrement_read(self.borrow_state);
        }
    }

    /// Read guard over a stored message; counts as one reader of the slot
    /// until dropped.
    pub struct HeapReadGuard<'a, P: Payload> {
        /// Message in the guarded slot.
        msg: &'a Message<P>,
        /// Borrow state of the guarded slot.
        borrow_state: &'a Cell<u16>,
    }

    impl<'a, P: Payload> Deref for HeapReadGuard<'a, P> {
        type Target = Message<P>;

        fn deref(&self) -> &Self::Target {
            self.msg
        }
    }

    impl<'a, P: Payload> Drop for HeapReadGuard<'a, P> {
        /// Releases this guard's reader registration.
        fn drop(&mut self) {
            decrement_read(self.borrow_state);
        }
    }

    /// Write guard over a stored message; the slot is marked write-borrowed
    /// until the guard is dropped.
    pub struct HeapWriteGuard<'a, P: Payload> {
        /// Message in the guarded slot. This exclusive reference already ties
        /// the guard to `'a` and `Message<P>`, so no marker field is needed.
        msg: &'a mut Message<P>,
        /// Borrow state of the guarded slot.
        borrow_state: &'a Cell<u16>,
    }

    impl<'a, P: Payload> Deref for HeapWriteGuard<'a, P> {
        type Target = Message<P>;

        fn deref(&self) -> &Self::Target {
            self.msg
        }
    }

    impl<'a, P: Payload> DerefMut for HeapWriteGuard<'a, P> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            self.msg
        }
    }

    impl<'a, P: Payload> Drop for HeapWriteGuard<'a, P> {
        /// Clears the slot's write mark.
        fn drop(&mut self) {
            clear_write(self.borrow_state);
        }
    }

    impl<P: Payload> HeaderStore for HeapMemoryManager<P> {
        type HeaderGuard<'a>
            = HeapHeaderGuard<'a>
        where
            Self: 'a;

        /// Returns a guard giving read access to the header stored under `token`.
        ///
        /// # Errors
        ///
        /// * `MemoryError::BadToken` if the index is out of range or the slot is empty.
        /// * `MemoryError::AlreadyBorrowed` if the slot is write-borrowed or its
        ///   reader count is saturated.
        fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
            let idx = token.index();
            let msg = self
                .slots
                .get(idx)
                .and_then(Option::as_ref)
                .ok_or(MemoryError::BadToken)?;

            let borrow_state = &self.borrow_states[idx];
            try_increment_read(borrow_state)?;
            Ok(HeapHeaderGuard {
                header: msg.header(),
                borrow_state,
            })
        }
    }

    impl<P: Payload> MemoryManager<P> for HeapMemoryManager<P> {
        type ReadGuard<'a>
            = HeapReadGuard<'a, P>
        where
            Self: 'a;

        type WriteGuard<'a>
            = HeapWriteGuard<'a, P>
        where
            Self: 'a;

        /// Places `value` in the lowest-indexed free slot, resets that slot's
        /// borrow state, and returns its token.
        ///
        /// # Errors
        ///
        /// `MemoryError::NoFreeSlots` if every slot is occupied, or if the only
        /// free slots have indices that do not fit in the token's `u32` index.
        fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
            let (idx, slot) = self
                .slots
                .iter_mut()
                .enumerate()
                .find(|(_, slot)| slot.is_none())
                .ok_or(MemoryError::NoFreeSlots)?;

            // Convert before writing: a truncating cast would hand out a token
            // that aliases a different slot once the index exceeds `u32::MAX`.
            let index = u32::try_from(idx).map_err(|_| MemoryError::NoFreeSlots)?;

            *slot = Some(value);
            self.borrow_states[idx].set(0);
            Ok(MessageToken::new(index))
        }

        /// Returns a read guard for the message stored under `token`.
        ///
        /// # Errors
        ///
        /// * `MemoryError::BadToken` if the index is out of range or the slot is empty.
        /// * `MemoryError::AlreadyBorrowed` if the slot is write-borrowed or its
        ///   reader count is saturated.
        fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
            let idx = token.index();
            let msg = self
                .slots
                .get(idx)
                .and_then(Option::as_ref)
                .ok_or(MemoryError::BadToken)?;

            let borrow_state = &self.borrow_states[idx];
            try_increment_read(borrow_state)?;
            Ok(HeapReadGuard { msg, borrow_state })
        }

        /// Returns a write guard for the message stored under `token`.
        ///
        /// # Errors
        ///
        /// * `MemoryError::BadToken` if the index is out of range or the slot is empty.
        /// * `MemoryError::AlreadyBorrowed` if the slot's borrow state is not `0`.
        fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
            let idx = token.index();
            let msg = self
                .slots
                .get_mut(idx)
                .and_then(Option::as_mut)
                .ok_or(MemoryError::BadToken)?;

            let borrow_state = &self.borrow_states[idx];
            try_set_write(borrow_state)?;
            Ok(HeapWriteGuard { msg, borrow_state })
        }

        /// Drops the message stored under `token` and marks its slot free.
        ///
        /// # Errors
        ///
        /// * `MemoryError::BadToken` if the index is out of range.
        /// * `MemoryError::NotAllocated` if the slot is already empty.
        /// * `MemoryError::BorrowActive` if the slot's borrow state is not `0`.
        fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
            let idx = token.index();
            let slot = self.slots.get_mut(idx).ok_or(MemoryError::BadToken)?;
            if slot.is_none() {
                return Err(MemoryError::NotAllocated);
            }

            let borrow_state = &self.borrow_states[idx];
            if borrow_state.get() != 0 {
                return Err(MemoryError::BorrowActive);
            }

            // Assign in place so the message is dropped without being moved out.
            *slot = None;
            borrow_state.set(0);
            Ok(())
        }

        /// Counts the slots that are currently empty.
        fn available(&self) -> usize {
            self.slots.iter().filter(|slot| slot.is_none()).count()
        }

        /// Returns the total number of slots, occupied or not.
        fn capacity(&self) -> usize {
            self.slots.len()
        }

        /// Returns the memory class this manager was constructed with.
        fn memory_class(&self) -> MemoryClass {
            self.mem_class
        }
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        message::MessageHeader,
        prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT},
    };

    /// Manager instantiation exercised by every test in this module.
    type TensorManager = HeapMemoryManager<TestTensor>;

    /// Builds a message with an empty header and a tensor filled with `val`.
    fn make_msg(val: u32) -> Message<TestTensor> {
        let header = MessageHeader::empty();
        let payload = create_test_tensor_filled_with(val);
        Message::new(header, payload)
    }

    /// Storing, reading back and freeing a message updates `available()` accordingly.
    #[test]
    fn store_read_free_cycle() {
        let mut manager = TensorManager::new(4);
        assert_eq!(manager.available(), 4);
        assert_eq!(manager.capacity(), 4);

        let stored = manager.store(make_msg(42)).unwrap();
        assert_eq!(manager.available(), 3);

        // The read guard is a temporary and is released at the end of the statement.
        assert_eq!(
            *manager.read(stored).unwrap().payload(),
            create_test_tensor_filled_with(42)
        );

        manager.free(stored).unwrap();
        assert_eq!(manager.available(), 4);
    }

    /// A payload written through `read_mut` is visible to a later `read`.
    #[test]
    fn read_mut_works() {
        let mut manager = TensorManager::new(4);
        let stored = manager.store(make_msg(10)).unwrap();

        {
            let mut guard = manager.read_mut(stored).unwrap();
            // Explicit `deref_mut` works whether the guard is `&mut Message`
            // or a dedicated guard type.
            let message = core::ops::DerefMut::deref_mut(&mut guard);
            *message.payload_mut() = create_test_tensor_filled_with(99);
        }

        assert_eq!(
            *manager.read(stored).unwrap().payload(),
            create_test_tensor_filled_with(99)
        );

        manager.free(stored).unwrap();
    }

    /// `peek_header` exposes the stored message's header.
    #[test]
    fn peek_header_works() {
        let mut manager = TensorManager::new(4);
        let stored = manager.store(make_msg(7)).unwrap();

        assert_eq!(
            *manager.peek_header(stored).unwrap().payload_size_bytes(),
            TEST_TENSOR_BYTE_COUNT
        );

        manager.free(stored).unwrap();
    }

    /// Storing into a full manager reports `NoFreeSlots`.
    #[test]
    fn capacity_exhaustion() {
        let mut manager = TensorManager::new(2);
        let _first = manager.store(make_msg(1)).unwrap();
        let _second = manager.store(make_msg(2)).unwrap();
        assert_eq!(manager.available(), 0);

        assert_eq!(
            manager.store(make_msg(3)),
            Err(MemoryError::NoFreeSlots)
        );
    }

    /// Freeing the same token twice reports `NotAllocated` the second time.
    #[test]
    fn double_free_detected() {
        let mut manager = TensorManager::new(4);
        let stored = manager.store(make_msg(1)).unwrap();
        manager.free(stored).unwrap();

        assert_eq!(manager.free(stored), Err(MemoryError::NotAllocated));
    }

    /// A token whose index lies outside the slot table is rejected.
    #[test]
    fn bad_token_detected() {
        let manager = TensorManager::new(4);
        let out_of_range = MessageToken::new(99);

        assert!(matches!(
            manager.read(out_of_range),
            Err(MemoryError::BadToken)
        ));
        assert!(matches!(
            manager.peek_header(out_of_range),
            Err(MemoryError::BadToken)
        ));
    }

    /// Reading through a token whose slot has been freed reports `BadToken`.
    #[test]
    fn read_freed_slot_is_bad_token() {
        let mut manager = TensorManager::new(4);
        let stale = manager.store(make_msg(1)).unwrap();
        manager.free(stale).unwrap();

        assert!(matches!(manager.read(stale), Err(MemoryError::BadToken)));
        assert!(matches!(
            manager.peek_header(stale),
            Err(MemoryError::BadToken)
        ));
    }

    /// A freed slot is handed out again by the next `store`.
    #[test]
    fn slot_reuse_after_free() {
        let mut manager = TensorManager::new(4);
        let first = manager.store(make_msg(10)).unwrap();
        manager.free(first).unwrap();

        let second = manager.store(make_msg(20)).unwrap();
        assert_eq!(second.index(), 0);
        assert_eq!(
            *manager.read(second).unwrap().payload(),
            create_test_tensor_filled_with(20)
        );
    }

    /// `with_memory_class` stores the requested memory class.
    #[test]
    fn memory_class_configurable() {
        let manager = TensorManager::with_memory_class(4, MemoryClass::Device(0));
        assert_eq!(manager.memory_class(), MemoryClass::Device(0));
    }

    /// `new` defaults the memory class to `Host`.
    #[test]
    fn default_memory_class_is_host() {
        let manager = TensorManager::new(4);
        assert_eq!(manager.memory_class(), MemoryClass::Host);
    }
}