1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::trace;
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 ops::{Deref, DerefMut},
14 ptr::NonNull,
15 sync::{Arc, Mutex},
16};
17#[derive(Debug)]
18pub struct MemTensor<T>
19where
20 T: Num + Clone + fmt::Debug + Send + Sync,
21{
22 pub name: String,
23 pub shape: Vec<usize>,
24 pub data: Vec<T>,
25 identity: crate::BufferIdentity,
26}
27
28unsafe impl<T> Send for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
29unsafe impl<T> Sync for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
30impl<T> TensorTrait<T> for MemTensor<T>
31where
32 T: Num + Clone + fmt::Debug + Send + Sync,
33{
34 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
35 if shape.is_empty() {
36 return Err(Error::InvalidSize(0));
37 }
38
39 let element_count: usize = shape.iter().product();
44 let data = vec![T::zero(); element_count];
45
46 let name = name.unwrap_or("mem_tensor").to_owned();
47 Ok(MemTensor {
48 name,
49 shape: shape.to_vec(),
50 data,
51 identity: crate::BufferIdentity::new(),
52 })
53 }
54
55 #[cfg(unix)]
56 fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
57 Err(Error::NotImplemented(
58 "MemTensor does not support from_fd".to_owned(),
59 ))
60 }
61
62 #[cfg(unix)]
63 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
64 Err(Error::NotImplemented(
65 "MemTensor does not support clone_fd".to_owned(),
66 ))
67 }
68
69 fn memory(&self) -> TensorMemory {
70 TensorMemory::Mem
71 }
72
73 fn name(&self) -> String {
74 self.name.clone()
75 }
76
77 fn shape(&self) -> &[usize] {
78 &self.shape
79 }
80
81 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
82 if shape.is_empty() {
83 return Err(Error::InvalidSize(0));
84 }
85
86 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
87 if new_size != self.size() {
88 return Err(Error::ShapeMismatch(format!(
89 "Cannot reshape incompatible shape: {:?} to {:?}",
90 self.shape, shape
91 )));
92 }
93
94 self.shape = shape.to_vec();
95 Ok(())
96 }
97
98 fn map(&self) -> Result<TensorMap<T>> {
99 let mem_ptr = MemPtr(
100 NonNull::new(self.data.as_ptr() as *mut c_void)
101 .ok_or(Error::InvalidSize(self.size()))?,
102 );
103 Ok(TensorMap::Mem(MemMap {
104 ptr: Arc::new(Mutex::new(mem_ptr)),
105 shape: self.shape.clone(),
106 _marker: std::marker::PhantomData,
107 }))
108 }
109
110 fn buffer_identity(&self) -> &crate::BufferIdentity {
111 &self.identity
112 }
113}
114
115#[derive(Debug)]
116pub struct MemMap<T>
117where
118 T: Num + Clone + fmt::Debug,
119{
120 ptr: Arc<Mutex<MemPtr>>,
121 shape: Vec<usize>,
122 _marker: std::marker::PhantomData<T>,
123}
124
125impl<T> Deref for MemMap<T>
126where
127 T: Num + Clone + fmt::Debug,
128{
129 type Target = [T];
130
131 fn deref(&self) -> &[T] {
132 self.as_slice()
133 }
134}
135
136impl<T> DerefMut for MemMap<T>
137where
138 T: Num + Clone + fmt::Debug,
139{
140 fn deref_mut(&mut self) -> &mut [T] {
141 self.as_mut_slice()
142 }
143}
144
/// Newtype around the untyped, non-null base address of a mapped buffer.
#[derive(Debug)]
struct MemPtr(NonNull<c_void>);

/// Exposes the wrapped `NonNull<c_void>` so its methods can be called directly
/// on a `MemPtr`.
impl Deref for MemPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let Self(inner) = self;
        inner
    }
}

// SAFETY: NOTE(review) — `NonNull` is not `Send` by default; this impl asserts
// that moving the address to another thread is sound. That depends on how the
// pointee is owned and accessed, which this type does not enforce — confirm
// against the code that creates and dereferences the pointer.
unsafe impl Send for MemPtr {}
156
157impl<T> TensorMapTrait<T> for MemMap<T>
158where
159 T: Num + Clone + fmt::Debug,
160{
161 fn shape(&self) -> &[usize] {
162 &self.shape
163 }
164
165 fn unmap(&mut self) {
166 trace!("Unmapping MemMap memory");
167 }
168
169 fn as_slice(&self) -> &[T] {
170 let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
171 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
172 }
173
174 fn as_mut_slice(&mut self) -> &mut [T] {
175 let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
176 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
177 }
178}
179
180impl<T> Drop for MemMap<T>
181where
182 T: Num + Clone + fmt::Debug,
183{
184 fn drop(&mut self) {
185 self.unmap();
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TensorMapTrait, TensorMemory, TensorTrait};

    /// A freshly built tensor reports the shape, name, memory kind and sizes
    /// it was created with.
    #[test]
    fn test_new_valid_shape() {
        let dims = [2, 3, 4];
        let created = MemTensor::<u8>::new(&dims, Some("test")).unwrap();
        assert_eq!(created.shape(), &dims);
        assert_eq!(created.memory(), TensorMemory::Mem);
        assert_eq!(created.name(), "test");
        assert_eq!(created.size(), 24);
        assert_eq!(created.len(), 24);
    }

    /// An empty shape is rejected with `InvalidSize`.
    #[test]
    fn test_new_empty_shape_error() {
        let outcome = MemTensor::<u8>::new(&[], Some("test"));
        assert!(matches!(outcome, Err(Error::InvalidSize(_))));
    }

    /// A zero-sized dimension is allowed and yields an empty tensor.
    #[test]
    fn test_new_zero_dim_is_accepted() {
        let dims = [2, 0, 4];
        let created = MemTensor::<u8>::new(&dims, Some("test")).unwrap();
        assert_eq!(created.shape(), &dims);
        assert_eq!(created.size(), 0);
        assert_eq!(created.len(), 0);
    }

    /// Writes made through a map are visible to later reads; untouched
    /// elements stay zero.
    #[test]
    fn test_map_read_write() {
        let tensor = MemTensor::<u8>::new(&[2, 3], Some("rw")).unwrap();
        let mut view = tensor.map().unwrap();
        {
            let cells = view.as_mut_slice();
            cells[0] = 42;
            cells[1] = 99;
        }
        let cells = view.as_slice();
        assert_eq!(cells[0], 42);
        assert_eq!(cells[1], 99);
        assert_eq!(cells[2], 0);
    }

    /// Reshaping to a shape with the same element count succeeds.
    #[test]
    fn test_reshape_compatible() {
        let mut tensor = MemTensor::<u8>::new(&[2, 3], None).unwrap();
        tensor.reshape(&[6]).unwrap();
        assert_eq!(tensor.shape(), &[6]);
        assert_eq!(tensor.len(), 6);
    }

    /// Reshaping to a different element count fails with `ShapeMismatch`.
    #[test]
    fn test_reshape_incompatible() {
        let mut tensor = MemTensor::<u8>::new(&[2, 3], None).unwrap();
        let outcome = tensor.reshape(&[7]);
        assert!(matches!(outcome, Err(Error::ShapeMismatch(_))));
    }
}