1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::trace;
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 ops::{Deref, DerefMut},
14 ptr::NonNull,
15 sync::{Arc, Mutex},
16};
17#[derive(Debug)]
18pub struct MemTensor<T>
19where
20 T: Num + Clone + fmt::Debug + Send + Sync,
21{
22 pub name: String,
23 pub shape: Vec<usize>,
24 pub data: Vec<T>,
25}
26
27unsafe impl<T> Send for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
28unsafe impl<T> Sync for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
29impl<T> TensorTrait<T> for MemTensor<T>
30where
31 T: Num + Clone + fmt::Debug + Send + Sync,
32{
33 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
34 if shape.is_empty() {
35 return Err(Error::InvalidSize(0));
36 }
37
38 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
39 if size == 0 {
40 return Err(Error::InvalidSize(0));
41 }
42
43 let name = name.unwrap_or("mem_tensor").to_owned();
44 let data = vec![T::zero(); size / std::mem::size_of::<T>()];
45
46 Ok(MemTensor {
47 name,
48 shape: shape.to_vec(),
49 data,
50 })
51 }
52
53 #[cfg(unix)]
54 fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
55 Err(Error::NotImplemented(
56 "MemTensor does not support from_fd".to_owned(),
57 ))
58 }
59
60 #[cfg(unix)]
61 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
62 Err(Error::NotImplemented(
63 "MemTensor does not support clone_fd".to_owned(),
64 ))
65 }
66
67 fn memory(&self) -> TensorMemory {
68 TensorMemory::Mem
69 }
70
71 fn name(&self) -> String {
72 self.name.clone()
73 }
74
75 fn shape(&self) -> &[usize] {
76 &self.shape
77 }
78
79 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
80 if shape.is_empty() {
81 return Err(Error::InvalidSize(0));
82 }
83
84 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
85 if new_size != self.size() {
86 return Err(Error::ShapeMismatch(format!(
87 "Cannot reshape incompatible shape: {:?} to {:?}",
88 self.shape, shape
89 )));
90 }
91
92 self.shape = shape.to_vec();
93 Ok(())
94 }
95
96 fn map(&self) -> Result<TensorMap<T>> {
97 let mem_ptr = MemPtr(
98 NonNull::new(self.data.as_ptr() as *mut c_void)
99 .ok_or(Error::InvalidSize(self.size()))?,
100 );
101 Ok(TensorMap::Mem(MemMap {
102 ptr: Arc::new(Mutex::new(mem_ptr)),
103 shape: self.shape.clone(),
104 _marker: std::marker::PhantomData,
105 }))
106 }
107}
108
109#[derive(Debug)]
110pub struct MemMap<T>
111where
112 T: Num + Clone + fmt::Debug,
113{
114 ptr: Arc<Mutex<MemPtr>>,
115 shape: Vec<usize>,
116 _marker: std::marker::PhantomData<T>,
117}
118
119impl<T> Deref for MemMap<T>
120where
121 T: Num + Clone + fmt::Debug,
122{
123 type Target = [T];
124
125 fn deref(&self) -> &[T] {
126 self.as_slice()
127 }
128}
129
130impl<T> DerefMut for MemMap<T>
131where
132 T: Num + Clone + fmt::Debug,
133{
134 fn deref_mut(&mut self) -> &mut [T] {
135 self.as_mut_slice()
136 }
137}
138
/// Thin newtype around the non-null, untyped base address of a mapped buffer.
#[derive(Debug)]
struct MemPtr(NonNull<c_void>);

impl Deref for MemPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let MemPtr(inner) = self;
        inner
    }
}

// SAFETY(review): `MemPtr` never dereferences its pointer itself; this impl
// only lets it be placed in an `Arc<Mutex<_>>`. Soundness of moving it across
// threads depends on the pointee staying valid and not being mutated
// concurrently. Nothing in this type enforces that — confirm at use sites.
unsafe impl Send for MemPtr {}
150
151impl<T> TensorMapTrait<T> for MemMap<T>
152where
153 T: Num + Clone + fmt::Debug,
154{
155 fn shape(&self) -> &[usize] {
156 &self.shape
157 }
158
159 fn unmap(&mut self) {
160 trace!("Unmapping MemMap memory");
161 }
162
163 fn as_slice(&self) -> &[T] {
164 let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
165 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
166 }
167
168 fn as_mut_slice(&mut self) -> &mut [T] {
169 let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
170 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
171 }
172}
173
174impl<T> Drop for MemMap<T>
175where
176 T: Num + Clone + fmt::Debug,
177{
178 fn drop(&mut self) {
179 self.unmap();
180 }
181}