1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{debug, trace, warn};
9use nix::{fcntl::OFlag, sys::stat::fstat, unistd::ftruncate};
10use num_traits::Num;
11use std::{
12 ffi::c_void,
13 fmt,
14 num::NonZero,
15 ops::{Deref, DerefMut},
16 os::fd::{AsRawFd, OwnedFd},
17 ptr::NonNull,
18 sync::{Arc, Mutex},
19};
/// A tensor whose backing storage is a POSIX shared-memory object
/// (`shm_open`), shareable across processes through its file descriptor.
#[derive(Debug)]
pub struct ShmTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Name used with `shm_open` (empty when constructed via `from_fd`).
    pub name: String,
    /// Owned fd of the shared-memory object; closed when the tensor drops.
    pub fd: OwnedFd,
    /// Logical dimensions; element count is the product of these.
    pub shape: Vec<usize>,
    /// Carries the element type `T` without storing any `T` values.
    pub _marker: std::marker::PhantomData<T>,
    identity: crate::BufferIdentity,
}
31
// SAFETY: all fields (`String`, `OwnedFd`, `Vec<usize>`, `PhantomData<T>` with
// `T: Send + Sync`) are thread-safe to move/share.
// NOTE(review): these impls look redundant — the compiler should auto-derive
// Send/Sync here unless `crate::BufferIdentity` is !Send/!Sync; verify.
unsafe impl<T> Send for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for ShmTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
34impl<T> TensorTrait<T> for ShmTensor<T>
35where
36 T: Num + Clone + fmt::Debug + Send + Sync,
37{
38 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
39 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
40 let name = match name {
41 Some(name) => name.to_owned(),
42 None => {
43 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
44 format!("/{}", &uuid[..16])
45 }
46 };
47
48 let shm_fd = nix::sys::mman::shm_open(
49 name.as_str(),
50 OFlag::O_CREAT | OFlag::O_EXCL | OFlag::O_RDWR,
51 nix::sys::stat::Mode::S_IRUSR | nix::sys::stat::Mode::S_IWUSR,
52 )?;
53
54 trace!("Creating shared memory: {name}");
55
56 let err = nix::sys::mman::shm_unlink(name.as_str());
60 if let Err(e) = err {
61 warn!("Failed to unlink shared memory: {e}");
62 }
63
64 ftruncate(&shm_fd, size as i64)?;
65 let stat = fstat(&shm_fd)?;
66 debug!("Shared memory stat: {stat:?}");
67
68 Ok(ShmTensor::<T> {
69 name: name.to_owned(),
70 fd: shm_fd,
71 shape: shape.to_vec(),
72 _marker: std::marker::PhantomData,
73 identity: crate::BufferIdentity::new(),
74 })
75 }
76
77 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
78 if shape.is_empty() {
79 return Err(Error::InvalidSize(0));
80 }
81
82 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
83 if size == 0 {
84 return Err(Error::InvalidSize(0));
85 }
86
87 Ok(ShmTensor {
88 name: name.unwrap_or("").to_owned(),
89 fd,
90 shape: shape.to_vec(),
91 _marker: std::marker::PhantomData,
92 identity: crate::BufferIdentity::new(),
93 })
94 }
95
96 fn clone_fd(&self) -> Result<OwnedFd> {
97 Ok(self.fd.try_clone()?)
98 }
99
100 fn memory(&self) -> TensorMemory {
101 TensorMemory::Shm
102 }
103
104 fn name(&self) -> String {
105 self.name.clone()
106 }
107
108 fn shape(&self) -> &[usize] {
109 &self.shape
110 }
111
112 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
113 if shape.is_empty() {
114 return Err(Error::InvalidSize(0));
115 }
116
117 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
118 if new_size != self.size() {
119 return Err(Error::ShapeMismatch(format!(
120 "Cannot reshape incompatible shape: {:?} to {:?}",
121 self.shape, shape
122 )));
123 }
124
125 self.shape = shape.to_vec();
126 Ok(())
127 }
128
129 fn map(&self) -> Result<TensorMap<T>> {
130 let size = NonZero::new(self.size()).ok_or(Error::InvalidSize(self.size()))?;
131 let ptr = unsafe {
132 nix::sys::mman::mmap(
133 None,
134 size,
135 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
136 nix::sys::mman::MapFlags::MAP_SHARED,
137 &self.fd,
138 0,
139 )?
140 };
141
142 trace!("Mapping shared memory: {ptr:?}");
143 let shm_ptr = ShmPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(self.size()))?);
144 Ok(TensorMap::Shm(ShmMap {
145 ptr: Arc::new(Mutex::new(shm_ptr)),
146 shape: self.shape.clone(),
147 _marker: std::marker::PhantomData,
148 }))
149 }
150
151 fn buffer_identity(&self) -> &crate::BufferIdentity {
152 &self.identity
153 }
154}
155
156impl<T> AsRawFd for ShmTensor<T>
157where
158 T: Num + Clone + fmt::Debug + Send + Sync,
159{
160 fn as_raw_fd(&self) -> std::os::fd::RawFd {
161 self.fd.as_raw_fd()
162 }
163}
164
/// A live `mmap` view over a `ShmTensor`'s shared memory, usable as a slice
/// through `Deref`/`DerefMut`; the mapping is released on drop.
#[derive(Debug)]
pub struct ShmMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    // Base address of the mapping, behind a lock so `unmap` is serialized.
    ptr: Arc<Mutex<ShmPtr>>,
    // Logical dimensions; element count is their product.
    shape: Vec<usize>,
    // Ties the element type `T` to the map without storing data inline.
    _marker: std::marker::PhantomData<T>,
}
174
175impl<T> Deref for ShmMap<T>
176where
177 T: Num + Clone + fmt::Debug,
178{
179 type Target = [T];
180
181 fn deref(&self) -> &[T] {
182 self.as_slice()
183 }
184}
185
186impl<T> DerefMut for ShmMap<T>
187where
188 T: Num + Clone + fmt::Debug,
189{
190 fn deref_mut(&mut self) -> &mut [T] {
191 self.as_mut_slice()
192 }
193}
194
/// Newtype around the raw mapping base pointer so it can live inside a
/// `Mutex` and be sent across threads.
#[derive(Debug)]
struct ShmPtr(NonNull<c_void>);
impl Deref for ShmPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// SAFETY: the pointer refers to a MAP_SHARED mapping that is not tied to any
// particular thread; access is serialized by the `Mutex` in `ShmMap`.
unsafe impl Send for ShmPtr {}
206
207impl<T> TensorMapTrait<T> for ShmMap<T>
208where
209 T: Num + Clone + fmt::Debug,
210{
211 fn shape(&self) -> &[usize] {
212 &self.shape
213 }
214
215 fn unmap(&mut self) {
216 let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
217 let err = unsafe { nix::sys::mman::munmap(**ptr, self.size()) };
218 if let Err(e) = err {
219 warn!("Failed to unmap shared memory: {e}");
220 }
221 }
222
223 fn as_slice(&self) -> &[T] {
224 let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
225 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
226 }
227
228 fn as_mut_slice(&mut self) -> &mut [T] {
229 let ptr = self.ptr.lock().expect("Failed to lock ShmMap pointer");
230 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
231 }
232}
233
234impl<T> Drop for ShmMap<T>
235where
236 T: Num + Clone + fmt::Debug,
237{
238 fn drop(&mut self) {
239 trace!("ShmMap dropped, unmapping memory");
240 self.unmap();
241 }
242}