Skip to main content

edgefirst_tensor/
shm.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{debug, trace, warn};
9use nix::{fcntl::OFlag, sys::stat::fstat, unistd::ftruncate};
10use num_traits::Num;
11use std::{
12    ffi::c_void,
13    fmt,
14    num::NonZero,
15    ops::{Deref, DerefMut},
16    os::fd::{AsRawFd, OwnedFd},
17    ptr::NonNull,
18    sync::{Arc, Mutex},
19};
20#[derive(Debug)]
21pub struct ShmTensor<T>
22where
23    T: Num + Clone + fmt::Debug + Send + Sync,
24{
25    pub name: String,
26    pub fd: OwnedFd,
27    pub shape: Vec<usize>,
28    pub _marker: std::marker::PhantomData<T>,
29}
30
31unsafe impl<T> Send for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
32unsafe impl<T> Sync for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
33impl<T> TensorTrait<T> for ShmTensor<T>
34where
35    T: Num + Clone + fmt::Debug + Send + Sync,
36{
37    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
38        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
39        let name = match name {
40            Some(name) => name.to_owned(),
41            None => {
42                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
43                format!("/{}", &uuid[..16])
44            }
45        };
46
47        let shm_fd = nix::sys::mman::shm_open(
48            name.as_str(),
49            OFlag::O_CREAT | OFlag::O_EXCL | OFlag::O_RDWR,
50            nix::sys::stat::Mode::S_IRUSR | nix::sys::stat::Mode::S_IWUSR,
51        )?;
52
53        trace!("Creating shared memory: {name}");
54
55        // We drop the shared memory object name after creating it to avoid
56        // leaving it in the system after the program exits.  The sharing model
57        // for the library is through file descriptors, not names.
58        let err = nix::sys::mman::shm_unlink(name.as_str());
59        if let Err(e) = err {
60            warn!("Failed to unlink shared memory: {e}");
61        }
62
63        ftruncate(&shm_fd, size as i64)?;
64        let stat = fstat(&shm_fd)?;
65        debug!("Shared memory stat: {stat:?}");
66
67        Ok(ShmTensor::<T> {
68            name: name.to_owned(),
69            fd: shm_fd,
70            shape: shape.to_vec(),
71            _marker: std::marker::PhantomData,
72        })
73    }
74
75    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
76        if shape.is_empty() {
77            return Err(Error::InvalidSize(0));
78        }
79
80        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
81        if size == 0 {
82            return Err(Error::InvalidSize(0));
83        }
84
85        Ok(ShmTensor {
86            name: name.unwrap_or("").to_owned(),
87            fd,
88            shape: shape.to_vec(),
89            _marker: std::marker::PhantomData,
90        })
91    }
92
93    fn clone_fd(&self) -> Result<OwnedFd> {
94        Ok(self.fd.try_clone()?)
95    }
96
97    fn memory(&self) -> TensorMemory {
98        TensorMemory::Shm
99    }
100
101    fn name(&self) -> String {
102        self.name.clone()
103    }
104
105    fn shape(&self) -> &[usize] {
106        &self.shape
107    }
108
109    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
110        if shape.is_empty() {
111            return Err(Error::InvalidSize(0));
112        }
113
114        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
115        if new_size != self.size() {
116            return Err(Error::ShapeMismatch(format!(
117                "Cannot reshape incompatible shape: {:?} to {:?}",
118                self.shape, shape
119            )));
120        }
121
122        self.shape = shape.to_vec();
123        Ok(())
124    }
125
126    fn map(&self) -> Result<TensorMap<T>> {
127        let size = NonZero::new(self.size()).ok_or(Error::InvalidSize(self.size()))?;
128        let ptr = unsafe {
129            nix::sys::mman::mmap(
130                None,
131                size,
132                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
133                nix::sys::mman::MapFlags::MAP_SHARED,
134                &self.fd,
135                0,
136            )?
137        };
138
139        trace!("Mapping shared memory: {ptr:?}");
140        let shm_ptr = ShmPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(self.size()))?);
141        Ok(TensorMap::Shm(ShmMap {
142            ptr: Arc::new(Mutex::new(shm_ptr)),
143            shape: self.shape.clone(),
144            _marker: std::marker::PhantomData,
145        }))
146    }
147}
148
149impl<T> AsRawFd for ShmTensor<T>
150where
151    T: Num + Clone + fmt::Debug + Send + Sync,
152{
153    fn as_raw_fd(&self) -> std::os::fd::RawFd {
154        self.fd.as_raw_fd()
155    }
156}
157
158#[derive(Debug)]
159pub struct ShmMap<T>
160where
161    T: Num + Clone + fmt::Debug,
162{
163    ptr: Arc<Mutex<ShmPtr>>,
164    shape: Vec<usize>,
165    _marker: std::marker::PhantomData<T>,
166}
167
168impl<T> Deref for ShmMap<T>
169where
170    T: Num + Clone + fmt::Debug,
171{
172    type Target = [T];
173
174    fn deref(&self) -> &[T] {
175        self.as_slice()
176    }
177}
178
179impl<T> DerefMut for ShmMap<T>
180where
181    T: Num + Clone + fmt::Debug,
182{
183    fn deref_mut(&mut self) -> &mut [T] {
184        self.as_mut_slice()
185    }
186}
187
/// Base address of an `mmap`ed region.
///
/// It is wrapped in its own type so it can be stored in the `Mutex` inside
/// `ShmMap` and moved between threads.
#[derive(Debug)]
struct ShmPtr(NonNull<c_void>);

impl Deref for ShmPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let ShmPtr(inner) = self;
        inner
    }
}

// SAFETY: `ShmPtr` is only an address and performs no access itself. The
// region it points to comes from a `MAP_SHARED` mapping, which belongs to the
// process rather than to any thread. `NonNull` is `!Send` only as a
// conservative default, and moving the address to another thread does not
// create aliasing by itself.
unsafe impl Send for ShmPtr {}
199
200impl<T> TensorMapTrait<T> for ShmMap<T>
201where
202    T: Num + Clone + fmt::Debug,
203{
204    fn shape(&self) -> &[usize] {
205        &self.shape
206    }
207
208    fn unmap(&mut self) {
209        let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
210        let err = unsafe { nix::sys::mman::munmap(**ptr, self.size()) };
211        if let Err(e) = err {
212            warn!("Failed to unmap shared memory: {e}");
213        }
214    }
215
216    fn as_slice(&self) -> &[T] {
217        let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
218        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
219    }
220
221    fn as_mut_slice(&mut self) -> &mut [T] {
222        let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
223        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
224    }
225}
226
227impl<T> Drop for ShmMap<T>
228where
229    T: Num + Clone + fmt::Debug,
230{
231    fn drop(&mut self) {
232        trace!("ShmMap dropped, unmapping memory");
233        self.unmap();
234    }
235}