1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
20#[derive(Debug)]
22pub struct DmaTensor<T>
23where
24 T: Num + Clone + fmt::Debug + Send + Sync,
25{
26 pub name: String,
27 pub fd: OwnedFd,
28 pub shape: Vec<usize>,
29 pub _marker: std::marker::PhantomData<T>,
30}
31
32unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
33unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
34
35impl<T> TensorTrait<T> for DmaTensor<T>
36where
37 T: Num + Clone + fmt::Debug + Send + Sync,
38{
39 #[cfg(target_os = "linux")]
40 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
41 use log::debug;
42 use nix::sys::stat::fstat;
43
44 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
45 let name = match name {
46 Some(name) => name.to_owned(),
47 None => {
48 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
49 format!("/{}", &uuid[..16])
50 }
51 };
52
53 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
54 Ok(heap) => heap,
55 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
56 };
57
58 let dma_fd = heap.allocate(size)?;
59 let stat = fstat(&dma_fd)?;
60 debug!("DMA memory stat: {stat:?}");
61
62 Ok(DmaTensor::<T> {
63 name: name.to_owned(),
64 fd: dma_fd,
65 shape: shape.to_vec(),
66 _marker: std::marker::PhantomData,
67 })
68 }
69
70 #[cfg(not(target_os = "linux"))]
71 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
72 Err(Error::NotImplemented(
73 "DMA tensors are not supported on this platform".to_owned(),
74 ))
75 }
76
77 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
78 if shape.is_empty() {
79 return Err(Error::InvalidSize(0));
80 }
81
82 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
83 if size == 0 {
84 return Err(Error::InvalidSize(0));
85 }
86
87 Ok(DmaTensor {
88 name: name.unwrap_or("").to_owned(),
89 fd,
90 shape: shape.to_vec(),
91 _marker: std::marker::PhantomData,
92 })
93 }
94
95 fn clone_fd(&self) -> Result<OwnedFd> {
96 Ok(self.fd.try_clone()?)
97 }
98
99 fn memory(&self) -> TensorMemory {
100 TensorMemory::Dma
101 }
102
103 fn name(&self) -> String {
104 self.name.clone()
105 }
106
107 fn shape(&self) -> &[usize] {
108 &self.shape
109 }
110
111 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
112 if shape.is_empty() {
113 return Err(Error::InvalidSize(0));
114 }
115
116 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
117 if new_size != self.size() {
118 return Err(Error::ShapeMismatch(format!(
119 "Cannot reshape incompatible shape: {:?} to {:?}",
120 self.shape, shape
121 )));
122 }
123
124 self.shape = shape.to_vec();
125 Ok(())
126 }
127
128 fn map(&self) -> Result<TensorMap<T>> {
129 Ok(TensorMap::Dma(DmaMap::new(
130 self.fd.try_clone()?,
131 &self.shape,
132 )?))
133 }
134}
135
136impl<T> AsRawFd for DmaTensor<T>
137where
138 T: Num + Clone + fmt::Debug + Send + Sync,
139{
140 fn as_raw_fd(&self) -> std::os::fd::RawFd {
141 self.fd.as_raw_fd()
142 }
143}
144
145impl<T> DmaTensor<T>
146where
147 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
148{
149 pub fn try_clone(&self) -> Result<Self> {
150 let fd = self.clone_fd()?;
151 Ok(Self {
152 name: self.name.clone(),
153 fd,
154 shape: self.shape.clone(),
155 _marker: std::marker::PhantomData,
156 })
157 }
158}
159
160#[derive(Debug)]
161pub struct DmaMap<T>
162where
163 T: Num + Clone + fmt::Debug,
164{
165 ptr: Arc<Mutex<DmaPtr>>,
166 fd: OwnedFd,
167 shape: Vec<usize>,
168 _marker: std::marker::PhantomData<T>,
169}
170
171impl<T> DmaMap<T>
172where
173 T: Num + Clone + fmt::Debug,
174{
175 pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
176 if shape.is_empty() {
177 return Err(Error::InvalidSize(0));
178 }
179
180 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
181 if size == 0 {
182 return Err(Error::InvalidSize(0));
183 }
184
185 #[cfg(target_os = "linux")]
186 crate::dmabuf::start_readwrite(&fd)?;
187
188 let ptr = unsafe {
189 nix::sys::mman::mmap(
190 None,
191 NonZero::new(size).ok_or(Error::InvalidSize(size))?,
192 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
193 nix::sys::mman::MapFlags::MAP_SHARED,
194 &fd,
195 0,
196 )?
197 };
198
199 trace!("Mapping DMA memory: {ptr:?}");
200 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
201 Ok(DmaMap {
202 ptr: Arc::new(Mutex::new(dma_ptr)),
203 fd,
204 shape: shape.to_vec(),
205 _marker: std::marker::PhantomData,
206 })
207 }
208}
209
210impl<T> Deref for DmaMap<T>
211where
212 T: Num + Clone + fmt::Debug,
213{
214 type Target = [T];
215
216 fn deref(&self) -> &[T] {
217 self.as_slice()
218 }
219}
220
221impl<T> DerefMut for DmaMap<T>
222where
223 T: Num + Clone + fmt::Debug,
224{
225 fn deref_mut(&mut self) -> &mut [T] {
226 self.as_mut_slice()
227 }
228}
229
230#[derive(Debug)]
231struct DmaPtr(NonNull<c_void>);
232impl Deref for DmaPtr {
233 type Target = NonNull<c_void>;
234
235 fn deref(&self) -> &Self::Target {
236 &self.0
237 }
238}
239
240unsafe impl Send for DmaPtr {}
241
242impl<T> TensorMapTrait<T> for DmaMap<T>
243where
244 T: Num + Clone + fmt::Debug,
245{
246 fn shape(&self) -> &[usize] {
247 &self.shape
248 }
249
250 fn unmap(&mut self) {
251 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
252
253 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
254 warn!("Failed to unmap DMA memory: {e}");
255 }
256
257 #[cfg(target_os = "linux")]
258 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
259 warn!("Failed to end read/write on DMA memory: {e}");
260 }
261 }
262
263 fn as_slice(&self) -> &[T] {
264 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
265 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
266 }
267
268 fn as_mut_slice(&mut self) -> &mut [T] {
269 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
270 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
271 }
272}
273
274impl<T> Drop for DmaMap<T>
275where
276 T: Num + Clone + fmt::Debug,
277{
278 fn drop(&mut self) {
279 trace!("DmaMap dropped, unmapping memory");
280 self.unmap();
281 }
282}