Skip to main content

edgefirst_tensor/
shm.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{debug, trace, warn};
9use nix::{fcntl::OFlag, sys::stat::fstat, unistd::ftruncate};
10use num_traits::Num;
11use std::{
12    ffi::c_void,
13    fmt,
14    num::NonZero,
15    ops::{Deref, DerefMut},
16    os::fd::{AsRawFd, OwnedFd},
17    ptr::NonNull,
18    sync::{Arc, Mutex},
19};
20#[derive(Debug)]
21pub struct ShmTensor<T>
22where
23    T: Num + Clone + fmt::Debug + Send + Sync,
24{
25    pub name: String,
26    pub fd: OwnedFd,
27    pub shape: Vec<usize>,
28    pub _marker: std::marker::PhantomData<T>,
29    identity: crate::BufferIdentity,
30}
31
32unsafe impl<T> Send for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
33unsafe impl<T> Sync for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
34impl<T> TensorTrait<T> for ShmTensor<T>
35where
36    T: Num + Clone + fmt::Debug + Send + Sync,
37{
38    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
39        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
40        let name = match name {
41            Some(name) => name.to_owned(),
42            None => {
43                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
44                format!("/{}", &uuid[..16])
45            }
46        };
47
48        let shm_fd = nix::sys::mman::shm_open(
49            name.as_str(),
50            OFlag::O_CREAT | OFlag::O_EXCL | OFlag::O_RDWR,
51            nix::sys::stat::Mode::S_IRUSR | nix::sys::stat::Mode::S_IWUSR,
52        )?;
53
54        trace!("Creating shared memory: {name}");
55
56        // We drop the shared memory object name after creating it to avoid
57        // leaving it in the system after the program exits.  The sharing model
58        // for the library is through file descriptors, not names.
59        let err = nix::sys::mman::shm_unlink(name.as_str());
60        if let Err(e) = err {
61            warn!("Failed to unlink shared memory: {e}");
62        }
63
64        ftruncate(&shm_fd, size as i64)?;
65        let stat = fstat(&shm_fd)?;
66        debug!("Shared memory stat: {stat:?}");
67
68        Ok(ShmTensor::<T> {
69            name: name.to_owned(),
70            fd: shm_fd,
71            shape: shape.to_vec(),
72            _marker: std::marker::PhantomData,
73            identity: crate::BufferIdentity::new(),
74        })
75    }
76
77    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
78        if shape.is_empty() {
79            return Err(Error::InvalidSize(0));
80        }
81
82        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
83        if size == 0 {
84            return Err(Error::InvalidSize(0));
85        }
86
87        Ok(ShmTensor {
88            name: name.unwrap_or("").to_owned(),
89            fd,
90            shape: shape.to_vec(),
91            _marker: std::marker::PhantomData,
92            identity: crate::BufferIdentity::new(),
93        })
94    }
95
96    fn clone_fd(&self) -> Result<OwnedFd> {
97        Ok(self.fd.try_clone()?)
98    }
99
100    fn memory(&self) -> TensorMemory {
101        TensorMemory::Shm
102    }
103
104    fn name(&self) -> String {
105        self.name.clone()
106    }
107
108    fn shape(&self) -> &[usize] {
109        &self.shape
110    }
111
112    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
113        if shape.is_empty() {
114            return Err(Error::InvalidSize(0));
115        }
116
117        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
118        if new_size != self.size() {
119            return Err(Error::ShapeMismatch(format!(
120                "Cannot reshape incompatible shape: {:?} to {:?}",
121                self.shape, shape
122            )));
123        }
124
125        self.shape = shape.to_vec();
126        Ok(())
127    }
128
129    fn map(&self) -> Result<TensorMap<T>> {
130        let size = NonZero::new(self.size()).ok_or(Error::InvalidSize(self.size()))?;
131        let ptr = unsafe {
132            nix::sys::mman::mmap(
133                None,
134                size,
135                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
136                nix::sys::mman::MapFlags::MAP_SHARED,
137                &self.fd,
138                0,
139            )?
140        };
141
142        trace!("Mapping shared memory: {ptr:?}");
143        let shm_ptr = ShmPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(self.size()))?);
144        Ok(TensorMap::Shm(ShmMap {
145            ptr: Arc::new(Mutex::new(shm_ptr)),
146            shape: self.shape.clone(),
147            _marker: std::marker::PhantomData,
148        }))
149    }
150
151    fn buffer_identity(&self) -> &crate::BufferIdentity {
152        &self.identity
153    }
154}
155
156impl<T> AsRawFd for ShmTensor<T>
157where
158    T: Num + Clone + fmt::Debug + Send + Sync,
159{
160    fn as_raw_fd(&self) -> std::os::fd::RawFd {
161        self.fd.as_raw_fd()
162    }
163}
164
165#[derive(Debug)]
166pub struct ShmMap<T>
167where
168    T: Num + Clone + fmt::Debug,
169{
170    ptr: Arc<Mutex<ShmPtr>>,
171    shape: Vec<usize>,
172    _marker: std::marker::PhantomData<T>,
173}
174
175impl<T> Deref for ShmMap<T>
176where
177    T: Num + Clone + fmt::Debug,
178{
179    type Target = [T];
180
181    fn deref(&self) -> &[T] {
182        self.as_slice()
183    }
184}
185
186impl<T> DerefMut for ShmMap<T>
187where
188    T: Num + Clone + fmt::Debug,
189{
190    fn deref_mut(&mut self) -> &mut [T] {
191        self.as_mut_slice()
192    }
193}
194
/// Base address of a memory-mapped shared-memory region.
#[derive(Debug)]
struct ShmPtr(NonNull<c_void>);

impl Deref for ShmPtr {
    type Target = NonNull<c_void>;

    /// Exposes the wrapped address.
    fn deref(&self) -> &NonNull<c_void> {
        let ShmPtr(address) = self;
        address
    }
}

// SAFETY: `ShmPtr` is only an address of a process-wide (MAP_SHARED) mapping, which is not
// tied to the thread that created it. Moving the address to another thread is therefore
// sound; all access to the pointed-to memory goes through the owning `ShmMap`.
unsafe impl Send for ShmPtr {}
206
207impl<T> TensorMapTrait<T> for ShmMap<T>
208where
209    T: Num + Clone + fmt::Debug,
210{
211    fn shape(&self) -> &[usize] {
212        &self.shape
213    }
214
215    fn unmap(&mut self) {
216        let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
217        let err = unsafe { nix::sys::mman::munmap(**ptr, self.size()) };
218        if let Err(e) = err {
219            warn!("Failed to unmap shared memory: {e}");
220        }
221    }
222
223    fn as_slice(&self) -> &[T] {
224        let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
225        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
226    }
227
228    fn as_mut_slice(&mut self) -> &mut [T] {
229        let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
230        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
231    }
232}
233
234impl<T> Drop for ShmMap<T>
235where
236    T: Num + Clone + fmt::Debug,
237{
238    fn drop(&mut self) {
239        trace!("ShmMap dropped, unmapping memory");
240        self.unmap();
241    }
242}