1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{debug, trace, warn};
9use nix::{fcntl::OFlag, sys::stat::fstat, unistd::ftruncate};
10use num_traits::Num;
11use std::{
12 ffi::c_void,
13 fmt,
14 num::NonZero,
15 ops::{Deref, DerefMut},
16 os::fd::{AsRawFd, OwnedFd},
17 ptr::NonNull,
18 sync::{Arc, Mutex},
19};
20#[derive(Debug)]
21pub struct ShmTensor<T>
22where
23 T: Num + Clone + fmt::Debug + Send + Sync,
24{
25 pub name: String,
26 pub fd: OwnedFd,
27 pub shape: Vec<usize>,
28 pub _marker: std::marker::PhantomData<T>,
29}
30
31unsafe impl<T> Send for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
32unsafe impl<T> Sync for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
33impl<T> TensorTrait<T> for ShmTensor<T>
34where
35 T: Num + Clone + fmt::Debug + Send + Sync,
36{
37 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
38 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
39 let name = match name {
40 Some(name) => name.to_owned(),
41 None => {
42 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
43 format!("/{}", &uuid[..16])
44 }
45 };
46
47 let shm_fd = nix::sys::mman::shm_open(
48 name.as_str(),
49 OFlag::O_CREAT | OFlag::O_EXCL | OFlag::O_RDWR,
50 nix::sys::stat::Mode::S_IRUSR | nix::sys::stat::Mode::S_IWUSR,
51 )?;
52
53 trace!("Creating shared memory: {name}");
54
55 let err = nix::sys::mman::shm_unlink(name.as_str());
59 if let Err(e) = err {
60 warn!("Failed to unlink shared memory: {e}");
61 }
62
63 ftruncate(&shm_fd, size as i64)?;
64 let stat = fstat(&shm_fd)?;
65 debug!("Shared memory stat: {stat:?}");
66
67 Ok(ShmTensor::<T> {
68 name: name.to_owned(),
69 fd: shm_fd,
70 shape: shape.to_vec(),
71 _marker: std::marker::PhantomData,
72 })
73 }
74
75 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
76 if shape.is_empty() {
77 return Err(Error::InvalidSize(0));
78 }
79
80 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
81 if size == 0 {
82 return Err(Error::InvalidSize(0));
83 }
84
85 Ok(ShmTensor {
86 name: name.unwrap_or("").to_owned(),
87 fd,
88 shape: shape.to_vec(),
89 _marker: std::marker::PhantomData,
90 })
91 }
92
93 fn clone_fd(&self) -> Result<OwnedFd> {
94 Ok(self.fd.try_clone()?)
95 }
96
97 fn memory(&self) -> TensorMemory {
98 TensorMemory::Shm
99 }
100
101 fn name(&self) -> String {
102 self.name.clone()
103 }
104
105 fn shape(&self) -> &[usize] {
106 &self.shape
107 }
108
109 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
110 if shape.is_empty() {
111 return Err(Error::InvalidSize(0));
112 }
113
114 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
115 if new_size != self.size() {
116 return Err(Error::ShapeMismatch(format!(
117 "Cannot reshape incompatible shape: {:?} to {:?}",
118 self.shape, shape
119 )));
120 }
121
122 self.shape = shape.to_vec();
123 Ok(())
124 }
125
126 fn map(&self) -> Result<TensorMap<T>> {
127 let size = NonZero::new(self.size()).ok_or(Error::InvalidSize(self.size()))?;
128 let ptr = unsafe {
129 nix::sys::mman::mmap(
130 None,
131 size,
132 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
133 nix::sys::mman::MapFlags::MAP_SHARED,
134 &self.fd,
135 0,
136 )?
137 };
138
139 trace!("Mapping shared memory: {ptr:?}");
140 let shm_ptr = ShmPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(self.size()))?);
141 Ok(TensorMap::Shm(ShmMap {
142 ptr: Arc::new(Mutex::new(shm_ptr)),
143 shape: self.shape.clone(),
144 _marker: std::marker::PhantomData,
145 }))
146 }
147}
148
149impl<T> AsRawFd for ShmTensor<T>
150where
151 T: Num + Clone + fmt::Debug + Send + Sync,
152{
153 fn as_raw_fd(&self) -> std::os::fd::RawFd {
154 self.fd.as_raw_fd()
155 }
156}
157
158#[derive(Debug)]
159pub struct ShmMap<T>
160where
161 T: Num + Clone + fmt::Debug,
162{
163 ptr: Arc<Mutex<ShmPtr>>,
164 shape: Vec<usize>,
165 _marker: std::marker::PhantomData<T>,
166}
167
168impl<T> Deref for ShmMap<T>
169where
170 T: Num + Clone + fmt::Debug,
171{
172 type Target = [T];
173
174 fn deref(&self) -> &[T] {
175 self.as_slice()
176 }
177}
178
179impl<T> DerefMut for ShmMap<T>
180where
181 T: Num + Clone + fmt::Debug,
182{
183 fn deref_mut(&mut self) -> &mut [T] {
184 self.as_mut_slice()
185 }
186}
187
/// Base address of a shared-memory mapping. It is wrapped so it can be stored in the
/// `Arc<Mutex<_>>` held by [`ShmMap`].
#[derive(Debug)]
struct ShmPtr(NonNull<c_void>);

impl Deref for ShmPtr {
    type Target = NonNull<c_void>;

    /// Borrows the wrapped base address.
    fn deref(&self) -> &NonNull<c_void> {
        let ShmPtr(base) = self;
        base
    }
}

// SAFETY: the address is the start of a memory mapping (created by `ShmTensor::map`). A mapping
// is valid for every thread of the process, so moving the address to another thread does not
// invalidate it. Shared access to the pointer goes through the `Mutex` in `ShmMap`.
unsafe impl Send for ShmPtr {}
199
200impl<T> TensorMapTrait<T> for ShmMap<T>
201where
202 T: Num + Clone + fmt::Debug,
203{
204 fn shape(&self) -> &[usize] {
205 &self.shape
206 }
207
208 fn unmap(&mut self) {
209 let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
210 let err = unsafe { nix::sys::mman::munmap(**ptr, self.size()) };
211 if let Err(e) = err {
212 warn!("Failed to unmap shared memory: {e}");
213 }
214 }
215
216 fn as_slice(&self) -> &[T] {
217 let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
218 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
219 }
220
221 fn as_mut_slice(&mut self) -> &mut [T] {
222 let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
223 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
224 }
225}
226
227impl<T> Drop for ShmMap<T>
228where
229 T: Num + Clone + fmt::Debug,
230{
231 fn drop(&mut self) {
232 trace!("ShmMap dropped, unmapping memory");
233 self.unmap();
234 }
235}