1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::trace;
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 ops::{Deref, DerefMut},
14 ptr::NonNull,
15 sync::{Arc, Mutex},
16};
17#[derive(Debug)]
18pub struct MemTensor<T>
19where
20 T: Num + Clone + fmt::Debug + Send + Sync,
21{
22 pub name: String,
23 pub shape: Vec<usize>,
24 pub data: Vec<T>,
25 identity: crate::BufferIdentity,
26}
27
28unsafe impl<T> Send for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
29unsafe impl<T> Sync for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
30impl<T> TensorTrait<T> for MemTensor<T>
31where
32 T: Num + Clone + fmt::Debug + Send + Sync,
33{
34 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
35 if shape.is_empty() {
36 return Err(Error::InvalidSize(0));
37 }
38
39 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
40 if size == 0 {
41 return Err(Error::InvalidSize(0));
42 }
43
44 let name = name.unwrap_or("mem_tensor").to_owned();
45 let data = vec![T::zero(); size / std::mem::size_of::<T>()];
46
47 Ok(MemTensor {
48 name,
49 shape: shape.to_vec(),
50 data,
51 identity: crate::BufferIdentity::new(),
52 })
53 }
54
55 #[cfg(unix)]
56 fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
57 Err(Error::NotImplemented(
58 "MemTensor does not support from_fd".to_owned(),
59 ))
60 }
61
62 #[cfg(unix)]
63 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
64 Err(Error::NotImplemented(
65 "MemTensor does not support clone_fd".to_owned(),
66 ))
67 }
68
69 fn memory(&self) -> TensorMemory {
70 TensorMemory::Mem
71 }
72
73 fn name(&self) -> String {
74 self.name.clone()
75 }
76
77 fn shape(&self) -> &[usize] {
78 &self.shape
79 }
80
81 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
82 if shape.is_empty() {
83 return Err(Error::InvalidSize(0));
84 }
85
86 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
87 if new_size != self.size() {
88 return Err(Error::ShapeMismatch(format!(
89 "Cannot reshape incompatible shape: {:?} to {:?}",
90 self.shape, shape
91 )));
92 }
93
94 self.shape = shape.to_vec();
95 Ok(())
96 }
97
98 fn map(&self) -> Result<TensorMap<T>> {
99 let mem_ptr = MemPtr(
100 NonNull::new(self.data.as_ptr() as *mut c_void)
101 .ok_or(Error::InvalidSize(self.size()))?,
102 );
103 Ok(TensorMap::Mem(MemMap {
104 ptr: Arc::new(Mutex::new(mem_ptr)),
105 shape: self.shape.clone(),
106 _marker: std::marker::PhantomData,
107 }))
108 }
109
110 fn buffer_identity(&self) -> &crate::BufferIdentity {
111 &self.identity
112 }
113}
114
115#[derive(Debug)]
116pub struct MemMap<T>
117where
118 T: Num + Clone + fmt::Debug,
119{
120 ptr: Arc<Mutex<MemPtr>>,
121 shape: Vec<usize>,
122 _marker: std::marker::PhantomData<T>,
123}
124
125impl<T> Deref for MemMap<T>
126where
127 T: Num + Clone + fmt::Debug,
128{
129 type Target = [T];
130
131 fn deref(&self) -> &[T] {
132 self.as_slice()
133 }
134}
135
136impl<T> DerefMut for MemMap<T>
137where
138 T: Num + Clone + fmt::Debug,
139{
140 fn deref_mut(&mut self) -> &mut [T] {
141 self.as_mut_slice()
142 }
143}
144
/// Non-null, untyped base address of a mapped buffer.
#[derive(Debug)]
struct MemPtr(NonNull<c_void>);

impl Deref for MemPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let MemPtr(inner) = self;
        inner
    }
}

// SAFETY: `MemPtr` only carries an address; moving it to another thread does
// not access the memory behind it. NOTE(review): whether dereferencing it on
// that thread is sound depends on how `MemMap` users share the buffer; the
// lock around it does not serialize access to slices built from it.
unsafe impl Send for MemPtr {}
156
157impl<T> TensorMapTrait<T> for MemMap<T>
158where
159 T: Num + Clone + fmt::Debug,
160{
161 fn shape(&self) -> &[usize] {
162 &self.shape
163 }
164
165 fn unmap(&mut self) {
166 trace!("Unmapping MemMap memory");
167 }
168
169 fn as_slice(&self) -> &[T] {
170 let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
171 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
172 }
173
174 fn as_mut_slice(&mut self) -> &mut [T] {
175 let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
176 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
177 }
178}
179
180impl<T> Drop for MemMap<T>
181where
182 T: Num + Clone + fmt::Debug,
183{
184 fn drop(&mut self) {
185 self.unmap();
186 }
187}