1use crate::errors::MemoryError;
9use crate::memory::header_store::HeaderStore;
10use crate::memory::manager::MemoryManager;
11use crate::memory::MemoryClass;
12use crate::message::payload::Payload;
13use crate::message::{Message, MessageHeader};
14use crate::types::MessageToken;
15
16#[cfg(feature = "checked-memory-manager-refs")]
17use core::cell::Cell;
18
19#[cfg(feature = "checked-memory-manager-refs")]
20use core::ops::{Deref, DerefMut};
21
22pub struct StaticMemoryManager<P: Payload, const DEPTH: usize> {
34 slots: [Option<Message<P>>; DEPTH],
35 #[cfg(feature = "checked-memory-manager-refs")]
36 borrow_states: [Cell<u16>; DEPTH], mem_class: MemoryClass,
38}
39
40impl<P: Payload, const DEPTH: usize> StaticMemoryManager<P, DEPTH> {
41 pub fn new() -> Self {
45 Self {
46 slots: core::array::from_fn(|_| None),
47 #[cfg(feature = "checked-memory-manager-refs")]
48 borrow_states: core::array::from_fn(|_| Cell::new(0)),
49 mem_class: MemoryClass::Host,
50 }
51 }
52
53 pub fn with_memory_class(mem_class: MemoryClass) -> Self {
55 Self {
56 slots: core::array::from_fn(|_| None),
57 #[cfg(feature = "checked-memory-manager-refs")]
58 borrow_states: core::array::from_fn(|_| Cell::new(0)),
59 mem_class,
60 }
61 }
62}
63
64impl<P: Payload, const DEPTH: usize> Default for StaticMemoryManager<P, DEPTH> {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70#[cfg(not(feature = "checked-memory-manager-refs"))]
75mod unchecked {
76 use super::*;
77
78 impl<P: Payload, const DEPTH: usize> HeaderStore for StaticMemoryManager<P, DEPTH> {
79 type HeaderGuard<'a>
80 = &'a MessageHeader
81 where
82 Self: 'a;
83
84 fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
85 let idx = token.index();
86 if idx >= DEPTH {
87 return Err(MemoryError::BadToken);
88 }
89 self.slots[idx]
91 .as_ref()
92 .map(|msg| msg.header())
93 .ok_or(MemoryError::BadToken)
94 }
95 }
96
97 impl<P: Payload, const DEPTH: usize> MemoryManager<P> for StaticMemoryManager<P, DEPTH> {
98 type ReadGuard<'a>
99 = &'a Message<P>
100 where
101 Self: 'a;
102
103 type WriteGuard<'a>
104 = &'a mut Message<P>
105 where
106 Self: 'a;
107
108 fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
109 for (i, slot) in self.slots.iter_mut().enumerate() {
110 if slot.is_none() {
111 *slot = Some(value);
112 return Ok(MessageToken::new(i as u32));
113 }
114 }
115 Err(MemoryError::NoFreeSlots)
116 }
117
118 fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
119 let idx = token.index();
120 if idx >= DEPTH {
121 return Err(MemoryError::BadToken);
122 }
123 self.slots[idx].as_ref().ok_or(MemoryError::BadToken)
125 }
126
127 fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
128 let idx = token.index();
129 if idx >= DEPTH {
130 return Err(MemoryError::BadToken);
131 }
132 self.slots[idx].as_mut().ok_or(MemoryError::BadToken)
134 }
135
136 fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
137 let idx = token.index();
138 if idx >= DEPTH {
139 return Err(MemoryError::BadToken);
140 }
141 if self.slots[idx].is_none() {
142 return Err(MemoryError::NotAllocated);
144 }
145 self.slots[idx] = None;
146 Ok(())
147 }
148
149 fn available(&self) -> usize {
150 self.slots.iter().filter(|s| s.is_none()).count()
151 }
152
153 fn capacity(&self) -> usize {
154 DEPTH
155 }
156
157 fn memory_class(&self) -> MemoryClass {
158 self.mem_class
159 }
160 }
161}
162
/// Runtime-checked trait implementations used when the
/// `checked-memory-manager-refs` feature is enabled.
///
/// Each slot has a `u16` counter in `borrow_states`:
/// - `0` means not borrowed.
/// - `1..WRITE_BORROW_MARK` counts outstanding shared guards.
/// - `WRITE_BORROW_MARK` marks an outstanding exclusive guard.
///
/// Guards restore the counter when dropped. `free` refuses to release a slot
/// whose counter is non-zero. Because `free` takes `&mut self`, a non-zero
/// counter there means a guard was leaked without running `Drop` (e.g. via
/// `core::mem::forget`).
#[cfg(feature = "checked-memory-manager-refs")]
mod checked {
    use super::*;

    /// Counter value marking an outstanding exclusive (write) borrow.
    const WRITE_BORROW_MARK: u16 = u16::MAX;

    /// Largest shared-borrow count a slot may reach. It is kept below
    /// `WRITE_BORROW_MARK` so a reader count can never be mistaken for a write
    /// borrow.
    const MAX_READERS: u16 = WRITE_BORROW_MARK - 1;

    /// Registers one more shared borrow on `cell`.
    ///
    /// Fails with [`MemoryError::AlreadyBorrowed`] if the slot is exclusively
    /// borrowed or already holds `MAX_READERS` shared borrows.
    fn try_increment_read(cell: &Cell<u16>) -> Result<(), MemoryError> {
        let readers = cell.get();
        // One comparison covers both the write mark and a saturated reader count.
        if readers >= MAX_READERS {
            return Err(MemoryError::AlreadyBorrowed);
        }
        cell.set(readers + 1);
        Ok(())
    }

    /// Releases one shared borrow on `cell`.
    ///
    /// Saturates at zero instead of underflowing. A write mark found here is
    /// also reset to zero.
    fn decrement_read(cell: &Cell<u16>) {
        let next = match cell.get() {
            0 | WRITE_BORROW_MARK => 0,
            readers => readers - 1,
        };
        cell.set(next);
    }

    /// Marks `cell` as exclusively borrowed.
    ///
    /// Fails with [`MemoryError::AlreadyBorrowed`] if any borrow is outstanding.
    fn try_set_write(cell: &Cell<u16>) -> Result<(), MemoryError> {
        if cell.get() != 0 {
            return Err(MemoryError::AlreadyBorrowed);
        }
        cell.set(WRITE_BORROW_MARK);
        Ok(())
    }

    /// Releases the exclusive borrow on `cell`.
    fn clear_write(cell: &Cell<u16>) {
        cell.set(0);
    }

    /// Shared guard over a stored message's header; releases its shared borrow
    /// when dropped.
    pub struct StaticHeaderGuard<'a> {
        header: &'a MessageHeader,
        borrow_state: &'a Cell<u16>,
    }

    impl Deref for StaticHeaderGuard<'_> {
        type Target = MessageHeader;

        fn deref(&self) -> &Self::Target {
            self.header
        }
    }

    impl Drop for StaticHeaderGuard<'_> {
        fn drop(&mut self) {
            decrement_read(self.borrow_state);
        }
    }

    /// Shared guard over a stored message; releases its shared borrow when dropped.
    pub struct StaticReadGuard<'a, P: Payload> {
        msg: &'a Message<P>,
        borrow_state: &'a Cell<u16>,
    }

    impl<P: Payload> Deref for StaticReadGuard<'_, P> {
        type Target = Message<P>;

        fn deref(&self) -> &Self::Target {
            self.msg
        }
    }

    impl<P: Payload> Drop for StaticReadGuard<'_, P> {
        fn drop(&mut self) {
            decrement_read(self.borrow_state);
        }
    }

    /// Exclusive guard over a stored message; releases its write borrow when dropped.
    pub struct StaticWriteGuard<'a, P: Payload> {
        msg: &'a mut Message<P>,
        borrow_state: &'a Cell<u16>,
    }

    impl<P: Payload> Deref for StaticWriteGuard<'_, P> {
        type Target = Message<P>;

        fn deref(&self) -> &Self::Target {
            self.msg
        }
    }

    impl<P: Payload> DerefMut for StaticWriteGuard<'_, P> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            self.msg
        }
    }

    impl<P: Payload> Drop for StaticWriteGuard<'_, P> {
        fn drop(&mut self) {
            clear_write(self.borrow_state);
        }
    }

    impl<P: Payload, const DEPTH: usize> StaticMemoryManager<P, DEPTH> {
        /// Resolves `token` to its stored message and that slot's borrow counter.
        ///
        /// Returns [`MemoryError::BadToken`] if the index is out of range or the
        /// slot is empty. The counter is not modified.
        fn occupied(&self, token: MessageToken) -> Result<(&Message<P>, &Cell<u16>), MemoryError> {
            let idx = token.index();
            let msg = self
                .slots
                .get(idx)
                .and_then(Option::as_ref)
                .ok_or(MemoryError::BadToken)?;
            // `get` succeeded, so `idx < DEPTH` and this index cannot panic.
            Ok((msg, &self.borrow_states[idx]))
        }

        /// Mutable counterpart of [`Self::occupied`]. It borrows `slots`
        /// mutably and `borrow_states` shared; the two are disjoint fields.
        fn occupied_mut(
            &mut self,
            token: MessageToken,
        ) -> Result<(&mut Message<P>, &Cell<u16>), MemoryError> {
            let idx = token.index();
            let msg = self
                .slots
                .get_mut(idx)
                .and_then(Option::as_mut)
                .ok_or(MemoryError::BadToken)?;
            Ok((msg, &self.borrow_states[idx]))
        }
    }

    impl<P: Payload, const DEPTH: usize> HeaderStore for StaticMemoryManager<P, DEPTH> {
        type HeaderGuard<'a>
            = StaticHeaderGuard<'a>
        where
            Self: 'a;

        /// Borrows the header of the message stored under `token`, counting it
        /// as a shared borrow of the slot.
        ///
        /// # Errors
        ///
        /// - [`MemoryError::BadToken`] for an out-of-range index or an empty slot.
        /// - [`MemoryError::AlreadyBorrowed`] if the slot is exclusively borrowed
        ///   or its reader count is saturated.
        fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
            let (msg, borrow_state) = self.occupied(token)?;
            try_increment_read(borrow_state)?;
            Ok(StaticHeaderGuard {
                header: msg.header(),
                borrow_state,
            })
        }
    }

    impl<P: Payload, const DEPTH: usize> MemoryManager<P> for StaticMemoryManager<P, DEPTH> {
        type ReadGuard<'a>
            = StaticReadGuard<'a, P>
        where
            Self: 'a;

        type WriteGuard<'a>
            = StaticWriteGuard<'a, P>
        where
            Self: 'a;

        /// Moves `value` into the lowest-indexed empty slot and returns its token.
        ///
        /// # Errors
        ///
        /// [`MemoryError::NoFreeSlots`] in either of these cases:
        /// - every slot is occupied;
        /// - the first empty slot's index does not fit in the `u32` that
        ///   `MessageToken::new` takes (only possible when `DEPTH > u32::MAX`).
        fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
            let (index, slot) = self
                .slots
                .iter_mut()
                .enumerate()
                .find(|(_, slot)| slot.is_none())
                .ok_or(MemoryError::NoFreeSlots)?;
            // Refuse instead of truncating: `as u32` would alias a lower slot.
            let raw = u32::try_from(index).map_err(|_| MemoryError::NoFreeSlots)?;
            *slot = Some(value);
            // An empty slot's counter should already be 0: it starts at 0, and
            // `free` only empties a slot whose counter is 0. Reset it anyway so
            // a new message never inherits stale state.
            self.borrow_states[index].set(0);
            Ok(MessageToken::new(raw))
        }

        /// Borrows the message stored under `token` as a shared borrow.
        ///
        /// # Errors
        ///
        /// - [`MemoryError::BadToken`] for an out-of-range index or an empty slot.
        /// - [`MemoryError::AlreadyBorrowed`] if the slot is exclusively borrowed
        ///   or its reader count is saturated.
        fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
            let (msg, borrow_state) = self.occupied(token)?;
            try_increment_read(borrow_state)?;
            Ok(StaticReadGuard { msg, borrow_state })
        }

        /// Borrows the message stored under `token` exclusively.
        ///
        /// # Errors
        ///
        /// - [`MemoryError::BadToken`] for an out-of-range index or an empty slot.
        /// - [`MemoryError::AlreadyBorrowed`] if any borrow of the slot is
        ///   outstanding.
        fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
            let (msg, borrow_state) = self.occupied_mut(token)?;
            try_set_write(borrow_state)?;
            Ok(StaticWriteGuard { msg, borrow_state })
        }

        /// Drops the message stored under `token` and marks its slot free.
        ///
        /// # Errors
        ///
        /// Checks are made in this order:
        /// 1. [`MemoryError::BadToken`] if the index is out of range.
        /// 2. [`MemoryError::NotAllocated`] if the slot is empty.
        /// 3. [`MemoryError::BorrowActive`] if the slot's borrow counter is
        ///    non-zero.
        fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
            let idx = token.index();
            let slot = self.slots.get_mut(idx).ok_or(MemoryError::BadToken)?;
            if slot.is_none() {
                return Err(MemoryError::NotAllocated);
            }
            if self.borrow_states[idx].get() != 0 {
                return Err(MemoryError::BorrowActive);
            }
            // The counter is verified to be 0 above, so it needs no reset here.
            *slot = None;
            Ok(())
        }

        /// Number of empty slots (linear scan over all `DEPTH` slots).
        fn available(&self) -> usize {
            self.slots.iter().filter(|s| s.is_none()).count()
        }

        /// Total number of slots, i.e. `DEPTH`.
        fn capacity(&self) -> usize {
            DEPTH
        }

        /// Memory class supplied at construction.
        fn memory_class(&self) -> MemoryClass {
            self.mem_class
        }
    }
}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        message::MessageHeader,
        prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT},
    };

    /// Builds a message with an empty header and a test tensor filled with `val`.
    fn make_msg(val: u32) -> Message<TestTensor> {
        Message::new(MessageHeader::empty(), create_test_tensor_filled_with(val))
    }

    #[test]
    fn store_read_free_cycle() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        assert_eq!(mgr.available(), 4);
        assert_eq!(mgr.capacity(), 4);

        let token = mgr.store(make_msg(42)).unwrap();
        assert_eq!(mgr.available(), 3);

        {
            let msg = mgr.read(token).unwrap();
            assert_eq!(*msg.payload(), create_test_tensor_filled_with(42));
        }

        mgr.free(token).unwrap();
        assert_eq!(mgr.available(), 4);
    }

    #[test]
    fn read_mut_works() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(10)).unwrap();

        {
            let mut write_guard = mgr.read_mut(token).unwrap();
            let msg = core::ops::DerefMut::deref_mut(&mut write_guard);
            *msg.payload_mut() = create_test_tensor_filled_with(99);
        }

        {
            let msg = mgr.read(token).unwrap();
            assert_eq!(*msg.payload(), create_test_tensor_filled_with(99));
        }

        mgr.free(token).unwrap();
    }

    #[test]
    fn peek_header_works() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(7)).unwrap();

        {
            let header = mgr.peek_header(token).unwrap();
            assert_eq!(*header.payload_size_bytes(), TEST_TENSOR_BYTE_COUNT);
        }

        mgr.free(token).unwrap();
    }

    #[test]
    fn capacity_exhaustion() {
        let mut mgr: StaticMemoryManager<TestTensor, 2> = StaticMemoryManager::new();
        let _t0 = mgr.store(make_msg(1)).unwrap();
        let _t1 = mgr.store(make_msg(2)).unwrap();
        assert_eq!(mgr.available(), 0);

        let err = mgr.store(make_msg(3));
        assert_eq!(err, Err(MemoryError::NoFreeSlots));
    }

    #[test]
    fn double_free_detected() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(1)).unwrap();
        mgr.free(token).unwrap();

        let err = mgr.free(token);
        assert_eq!(err, Err(MemoryError::NotAllocated));
    }

    #[test]
    fn bad_token_detected() {
        let mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let bad = MessageToken::new(99);

        assert!(matches!(mgr.read(bad), Err(MemoryError::BadToken)));
        assert!(matches!(mgr.peek_header(bad), Err(MemoryError::BadToken)));
    }

    #[test]
    fn bad_token_detected_by_mutating_ops() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let bad = MessageToken::new(99);

        assert!(matches!(mgr.read_mut(bad), Err(MemoryError::BadToken)));
        assert_eq!(mgr.free(bad), Err(MemoryError::BadToken));
    }

    #[test]
    fn read_freed_slot_is_bad_token() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(1)).unwrap();
        mgr.free(token).unwrap();

        assert!(matches!(mgr.read(token), Err(MemoryError::BadToken)));
        assert!(matches!(mgr.peek_header(token), Err(MemoryError::BadToken)));
    }

    #[test]
    fn read_mut_freed_slot_is_bad_token() {
        let mut mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(1)).unwrap();
        mgr.free(token).unwrap();

        assert!(matches!(mgr.read_mut(token), Err(MemoryError::BadToken)));
    }

    #[test]
    fn shared_borrows_coexist() {
        let mut mgr: StaticMemoryManager<TestTensor, 2> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(3)).unwrap();

        {
            let first = mgr.read(token).unwrap();
            let second = mgr.read(token).unwrap();
            let header = mgr.peek_header(token).unwrap();
            assert_eq!(*first.payload(), *second.payload());
            assert_eq!(*header.payload_size_bytes(), TEST_TENSOR_BYTE_COUNT);
        }

        // Every guard has been dropped, so the slot can be released.
        mgr.free(token).unwrap();
        assert_eq!(mgr.available(), 2);
    }

    #[test]
    fn slot_reuse_after_free() {
        let mut mgr: StaticMemoryManager<TestTensor, 1> = StaticMemoryManager::new();
        let t0 = mgr.store(make_msg(10)).unwrap();
        mgr.free(t0).unwrap();

        let t1 = mgr.store(make_msg(20)).unwrap();
        assert_eq!(t1.index(), 0);
        assert_eq!(
            *mgr.read(t1).unwrap().payload(),
            create_test_tensor_filled_with(20)
        );
    }

    #[test]
    fn memory_class_configurable() {
        let mgr: StaticMemoryManager<TestTensor, 4> =
            StaticMemoryManager::with_memory_class(MemoryClass::Device(0));
        assert_eq!(mgr.memory_class(), MemoryClass::Device(0));
    }

    #[test]
    fn default_memory_class_is_host() {
        let mgr: StaticMemoryManager<TestTensor, 4> = StaticMemoryManager::new();
        assert_eq!(mgr.memory_class(), MemoryClass::Host);
    }

    #[cfg(feature = "checked-memory-manager-refs")]
    #[test]
    fn leaked_read_guard_blocks_free() {
        let mut mgr: StaticMemoryManager<TestTensor, 2> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(5)).unwrap();

        // A forgotten guard never runs `Drop`, so its shared borrow stays counted.
        core::mem::forget(mgr.read(token).unwrap());

        assert_eq!(mgr.free(token), Err(MemoryError::BorrowActive));
        // Additional shared borrows are still allowed; an exclusive one is not.
        assert!(mgr.read(token).is_ok());
        assert!(matches!(mgr.read_mut(token), Err(MemoryError::AlreadyBorrowed)));
    }

    #[cfg(feature = "checked-memory-manager-refs")]
    #[test]
    fn leaked_write_guard_blocks_all_access() {
        let mut mgr: StaticMemoryManager<TestTensor, 2> = StaticMemoryManager::new();
        let token = mgr.store(make_msg(6)).unwrap();

        // A forgotten write guard leaves the slot marked as exclusively borrowed.
        core::mem::forget(mgr.read_mut(token).unwrap());

        assert!(matches!(mgr.read(token), Err(MemoryError::AlreadyBorrowed)));
        assert!(matches!(mgr.peek_header(token), Err(MemoryError::AlreadyBorrowed)));
        assert!(matches!(mgr.read_mut(token), Err(MemoryError::AlreadyBorrowed)));
        assert_eq!(mgr.free(token), Err(MemoryError::BorrowActive));
    }
}