Skip to main content

memlink_protocol/
shm.rs

//! Shared memory view for zero-copy access.
//!
//! Defines [`ShmView`], which provides safe, bounds-checked access to shared
//! memory regions, with methods for reading message headers and payloads.
5
6use alloc::string::ToString;
7use core::marker::PhantomData;
8use core::ptr::NonNull;
9
10use crate::error::{ProtocolError, Result};
11use crate::header::MessageHeader;
12use crate::magic::HEADER_SIZE;
13
14#[cfg(feature = "shm")]
15pub use memlink_shm::{
16    RingBuffer, Priority as ShmPriority, Platform, MmapSegment, ControlRegion,
17    Futex, FutexError, FutexResult, PriorityRingBuffer,
18    ShmTransport, ShmError, ShmResult,
19    RecoveryManager, Heartbeat, SlotMetadata, SlotState,
20    BoundsChecker, PoisonGuard, BoundsError,
21};
22
/// Alignment, in bytes, that [`is_aligned`] checks pointers against (2^6 = 64).
pub const SHM_ALIGNMENT: usize = 1 << 6;
24
/// A read-only, bounds-checked view over a contiguous run of bytes.
///
/// A view is either created from a borrowed slice (`ShmView::from_slice`) or
/// from a raw pointer whose validity the caller vouches for (`ShmView::new`).
/// Like `&'a [u8]`, it is `Copy`: duplicating a view only duplicates the
/// pointer and length, never the underlying bytes.
#[derive(Debug, Clone, Copy)]
pub struct ShmView<'a> {
    /// Start of the viewed region; never null.
    ptr: NonNull<u8>,
    /// Number of bytes readable starting at `ptr`.
    len: usize,
    /// Ties the view to the lifetime `'a` of the memory it points into.
    _phantom: PhantomData<&'a ()>,
}
31
32impl<'a> ShmView<'a> {
33    pub unsafe fn new(ptr: *const u8, len: usize) -> Self {
34        let non_null_ptr = NonNull::new(ptr as *mut u8).unwrap_or_else(|| {
35            panic!("ShmView::new called with null pointer");
36        });
37
38        Self {
39            ptr: non_null_ptr,
40            len,
41            _phantom: PhantomData,
42        }
43    }
44
45    pub fn from_slice(slice: &'a [u8]) -> Self {
46        Self {
47            ptr: NonNull::from(slice).cast(),
48            len: slice.len(),
49            _phantom: PhantomData,
50        }
51    }
52
53    pub fn len(&self) -> usize {
54        self.len
55    }
56
57    pub fn is_empty(&self) -> bool {
58        self.len == 0
59    }
60
61    pub fn as_slice(&self) -> &'a [u8] {
62        unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
63    }
64
65    pub fn read_header(&self) -> Result<MessageHeader> {
66        if self.len < HEADER_SIZE {
67            return Err(ProtocolError::BufferOverflow {
68                required: HEADER_SIZE,
69                available: self.len,
70            });
71        }
72
73        let header_bytes: &[u8; HEADER_SIZE] = self.as_slice()[..HEADER_SIZE]
74            .try_into()
75            .map_err(|_| ProtocolError::InvalidHeader("failed to convert to header array".to_string()))?;
76
77        MessageHeader::from_bytes(header_bytes)
78    }
79
80    pub fn read_payload_at(&self, offset: usize, payload_len: usize) -> Result<&'a [u8]> {
81        let end = offset
82            .checked_add(payload_len)
83            .ok_or(ProtocolError::BufferOverflow {
84                required: offset + payload_len,
85                available: self.len,
86            })?;
87
88        if end > self.len {
89            return Err(ProtocolError::BufferOverflow {
90                required: end,
91                available: self.len,
92            });
93        }
94
95        Ok(&self.as_slice()[offset..end])
96    }
97
98    pub fn read_payload(&self, payload_len: usize) -> Result<&'a [u8]> {
99        self.read_payload_at(HEADER_SIZE, payload_len)
100    }
101
102    pub fn sub_view(&self, offset: usize, len: usize) -> Result<ShmView<'a>> {
103        let end = offset
104            .checked_add(len)
105            .ok_or(ProtocolError::BufferOverflow {
106                required: offset + len,
107                available: self.len,
108            })?;
109
110        if end > self.len {
111            return Err(ProtocolError::BufferOverflow {
112                required: end,
113                available: self.len,
114            });
115        }
116
117        unsafe {
118            Ok(ShmView::new(
119                self.ptr.as_ptr().add(offset),
120                len,
121            ))
122        }
123    }
124
125    pub fn has_minimum(&self, min_bytes: usize) -> bool {
126        self.len >= min_bytes
127    }
128
129    pub unsafe fn as_ptr(&self) -> *const u8 {
130        self.ptr.as_ptr()
131    }
132}
133
134pub fn is_aligned(ptr: *const u8) -> bool {
135    (ptr as usize) % SHM_ALIGNMENT == 0
136}