1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
20#[derive(Debug)]
26pub struct DmaTensor<T>
27where
28 T: Num + Clone + fmt::Debug + Send + Sync,
29{
30 pub name: String,
31 pub fd: OwnedFd,
32 pub shape: Vec<usize>,
33 pub _marker: std::marker::PhantomData<T>,
34 #[cfg(target_os = "linux")]
35 _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
36}
37
38unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
39unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
40
41impl<T> TensorTrait<T> for DmaTensor<T>
42where
43 T: Num + Clone + fmt::Debug + Send + Sync,
44{
45 #[cfg(target_os = "linux")]
46 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
47 use log::debug;
48 use nix::sys::stat::fstat;
49
50 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
51 let name = match name {
52 Some(name) => name.to_owned(),
53 None => {
54 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
55 format!("/{}", &uuid[..16])
56 }
57 };
58
59 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
60 Ok(heap) => heap,
61 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
62 };
63
64 let dma_fd = heap.allocate(size)?;
65 let stat = fstat(&dma_fd)?;
66 debug!("DMA memory stat: {stat:?}");
67
68 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
69
70 Ok(DmaTensor::<T> {
71 name: name.to_owned(),
72 fd: dma_fd,
73 shape: shape.to_vec(),
74 _marker: std::marker::PhantomData,
75 _drm_attachment: drm_attachment,
76 })
77 }
78
79 #[cfg(not(target_os = "linux"))]
80 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
81 Err(Error::NotImplemented(
82 "DMA tensors are not supported on this platform".to_owned(),
83 ))
84 }
85
86 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
87 if shape.is_empty() {
88 return Err(Error::InvalidSize(0));
89 }
90
91 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
92 if size == 0 {
93 return Err(Error::InvalidSize(0));
94 }
95
96 #[cfg(target_os = "linux")]
97 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
98
99 Ok(DmaTensor {
100 name: name.unwrap_or("").to_owned(),
101 fd,
102 shape: shape.to_vec(),
103 _marker: std::marker::PhantomData,
104 #[cfg(target_os = "linux")]
105 _drm_attachment: drm_attachment,
106 })
107 }
108
109 fn clone_fd(&self) -> Result<OwnedFd> {
110 Ok(self.fd.try_clone()?)
111 }
112
113 fn memory(&self) -> TensorMemory {
114 TensorMemory::Dma
115 }
116
117 fn name(&self) -> String {
118 self.name.clone()
119 }
120
121 fn shape(&self) -> &[usize] {
122 &self.shape
123 }
124
125 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
126 if shape.is_empty() {
127 return Err(Error::InvalidSize(0));
128 }
129
130 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
131 if new_size != self.size() {
132 return Err(Error::ShapeMismatch(format!(
133 "Cannot reshape incompatible shape: {:?} to {:?}",
134 self.shape, shape
135 )));
136 }
137
138 self.shape = shape.to_vec();
139 Ok(())
140 }
141
142 fn map(&self) -> Result<TensorMap<T>> {
143 Ok(TensorMap::Dma(DmaMap::new(
144 self.fd.try_clone()?,
145 &self.shape,
146 )?))
147 }
148}
149
150impl<T> AsRawFd for DmaTensor<T>
151where
152 T: Num + Clone + fmt::Debug + Send + Sync,
153{
154 fn as_raw_fd(&self) -> std::os::fd::RawFd {
155 self.fd.as_raw_fd()
156 }
157}
158
159impl<T> DmaTensor<T>
160where
161 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
162{
163 pub fn try_clone(&self) -> Result<Self> {
164 let fd = self.clone_fd()?;
165 #[cfg(target_os = "linux")]
166 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
167 Ok(Self {
168 name: self.name.clone(),
169 fd,
170 shape: self.shape.clone(),
171 _marker: std::marker::PhantomData,
172 #[cfg(target_os = "linux")]
173 _drm_attachment: drm_attachment,
174 })
175 }
176}
177
178#[derive(Debug)]
179pub struct DmaMap<T>
180where
181 T: Num + Clone + fmt::Debug,
182{
183 ptr: Arc<Mutex<DmaPtr>>,
184 fd: OwnedFd,
185 shape: Vec<usize>,
186 _marker: std::marker::PhantomData<T>,
187}
188
189impl<T> DmaMap<T>
190where
191 T: Num + Clone + fmt::Debug,
192{
193 pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
194 if shape.is_empty() {
195 return Err(Error::InvalidSize(0));
196 }
197
198 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
199 if size == 0 {
200 return Err(Error::InvalidSize(0));
201 }
202
203 #[cfg(target_os = "linux")]
204 crate::dmabuf::start_readwrite(&fd)?;
205
206 let ptr = unsafe {
207 nix::sys::mman::mmap(
208 None,
209 NonZero::new(size).ok_or(Error::InvalidSize(size))?,
210 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
211 nix::sys::mman::MapFlags::MAP_SHARED,
212 &fd,
213 0,
214 )?
215 };
216
217 trace!("Mapping DMA memory: {ptr:?}");
218 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
219 Ok(DmaMap {
220 ptr: Arc::new(Mutex::new(dma_ptr)),
221 fd,
222 shape: shape.to_vec(),
223 _marker: std::marker::PhantomData,
224 })
225 }
226}
227
228impl<T> Deref for DmaMap<T>
229where
230 T: Num + Clone + fmt::Debug,
231{
232 type Target = [T];
233
234 fn deref(&self) -> &[T] {
235 self.as_slice()
236 }
237}
238
239impl<T> DerefMut for DmaMap<T>
240where
241 T: Num + Clone + fmt::Debug,
242{
243 fn deref_mut(&mut self) -> &mut [T] {
244 self.as_mut_slice()
245 }
246}
247
248#[derive(Debug)]
249struct DmaPtr(NonNull<c_void>);
250impl Deref for DmaPtr {
251 type Target = NonNull<c_void>;
252
253 fn deref(&self) -> &Self::Target {
254 &self.0
255 }
256}
257
258unsafe impl Send for DmaPtr {}
259
260impl<T> TensorMapTrait<T> for DmaMap<T>
261where
262 T: Num + Clone + fmt::Debug,
263{
264 fn shape(&self) -> &[usize] {
265 &self.shape
266 }
267
268 fn unmap(&mut self) {
269 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
270
271 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
272 warn!("Failed to unmap DMA memory: {e}");
273 }
274
275 #[cfg(target_os = "linux")]
276 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
277 warn!("Failed to end read/write on DMA memory: {e}");
278 }
279 }
280
281 fn as_slice(&self) -> &[T] {
282 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
283 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
284 }
285
286 fn as_mut_slice(&mut self) -> &mut [T] {
287 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
288 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
289 }
290}
291
292impl<T> Drop for DmaMap<T>
293where
294 T: Num + Clone + fmt::Debug,
295{
296 fn drop(&mut self) {
297 trace!("DmaMap dropped, unmapping memory");
298 self.unmap();
299 }
300}