1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
20#[derive(Debug)]
30pub struct DmaTensor<T>
31where
32 T: Num + Clone + fmt::Debug + Send + Sync,
33{
34 pub name: String,
35 pub fd: OwnedFd,
36 pub shape: Vec<usize>,
37 pub _marker: std::marker::PhantomData<T>,
38 #[cfg(target_os = "linux")]
39 _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
40 identity: crate::BufferIdentity,
41 buf_size: usize,
45 pub(crate) mmap_offset: usize,
48 #[cfg(target_os = "linux")]
52 is_imported: bool,
53}
54
55unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
56unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
57
58impl<T> TensorTrait<T> for DmaTensor<T>
59where
60 T: Num + Clone + fmt::Debug + Send + Sync,
61{
62 #[cfg(target_os = "linux")]
63 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
64 use log::debug;
65 use nix::sys::stat::fstat;
66
67 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
68 let name = match name {
69 Some(name) => name.to_owned(),
70 None => {
71 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
72 format!("/{}", &uuid[..16])
73 }
74 };
75
76 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
77 Ok(heap) => heap,
78 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
79 };
80
81 let dma_fd = heap.allocate(logical_size)?;
82 let stat = fstat(&dma_fd)?;
83 debug!("DMA memory stat: {stat:?}");
84 let buf_size = if stat.st_size > 0 {
85 std::cmp::max(stat.st_size as usize, logical_size)
86 } else {
87 logical_size
88 };
89
90 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd, false);
91
92 Ok(DmaTensor::<T> {
93 name: name.to_owned(),
94 fd: dma_fd,
95 shape: shape.to_vec(),
96 _marker: std::marker::PhantomData,
97 _drm_attachment: drm_attachment,
98 identity: crate::BufferIdentity::new(),
99 buf_size,
100 mmap_offset: 0,
101 is_imported: false,
102 })
103 }
104
105 #[cfg(not(target_os = "linux"))]
106 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
107 Err(Error::NotImplemented(
108 "DMA tensors are not supported on this platform".to_owned(),
109 ))
110 }
111
112 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
113 if shape.is_empty() {
114 return Err(Error::InvalidSize(0));
115 }
116
117 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
118 if logical_size == 0 {
119 return Err(Error::InvalidSize(0));
120 }
121
122 let buf_size = {
125 #[cfg(target_os = "linux")]
126 {
127 use nix::sys::stat::fstat;
128 match fstat(&fd) {
129 Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
130 stat.st_size as usize
131 }
132 _ => logical_size,
133 }
134 }
135 #[cfg(not(target_os = "linux"))]
136 {
137 logical_size
138 }
139 };
140
141 #[cfg(target_os = "linux")]
150 let drm_attachment = None;
151
152 Ok(DmaTensor {
153 name: name.unwrap_or("").to_owned(),
154 fd,
155 shape: shape.to_vec(),
156 _marker: std::marker::PhantomData,
157 #[cfg(target_os = "linux")]
158 _drm_attachment: drm_attachment,
159 identity: crate::BufferIdentity::new(),
160 buf_size,
161 mmap_offset: 0,
162 #[cfg(target_os = "linux")]
163 is_imported: true,
164 })
165 }
166
167 fn clone_fd(&self) -> Result<OwnedFd> {
168 Ok(self.fd.try_clone()?)
169 }
170
171 fn memory(&self) -> TensorMemory {
172 TensorMemory::Dma
173 }
174
175 fn name(&self) -> String {
176 self.name.clone()
177 }
178
179 fn shape(&self) -> &[usize] {
180 &self.shape
181 }
182
183 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
184 if shape.is_empty() {
185 return Err(Error::InvalidSize(0));
186 }
187
188 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
189 if new_size != self.size() {
190 return Err(Error::ShapeMismatch(format!(
191 "Cannot reshape incompatible shape: {:?} to {:?}",
192 self.shape, shape
193 )));
194 }
195
196 self.shape = shape.to_vec();
197 Ok(())
198 }
199
200 fn map(&self) -> Result<TensorMap<T>> {
201 Ok(TensorMap::Dma(DmaMap::new(
202 self.fd.try_clone()?,
203 &self.shape,
204 self.buf_size,
205 self.mmap_offset,
206 )?))
207 }
208
209 fn buffer_identity(&self) -> &crate::BufferIdentity {
210 &self.identity
211 }
212}
213
214impl<T> AsRawFd for DmaTensor<T>
215where
216 T: Num + Clone + fmt::Debug + Send + Sync,
217{
218 fn as_raw_fd(&self) -> std::os::fd::RawFd {
219 self.fd.as_raw_fd()
220 }
221}
222
223impl<T> DmaTensor<T>
224where
225 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
226{
227 pub fn try_clone(&self) -> Result<Self> {
228 let fd = self.clone_fd()?;
229 #[cfg(target_os = "linux")]
232 let drm_attachment = if self.is_imported {
233 None
234 } else {
235 crate::dmabuf::DrmAttachment::new(&fd, false)
236 };
237 Ok(Self {
238 name: self.name.clone(),
239 fd,
240 shape: self.shape.clone(),
241 _marker: std::marker::PhantomData,
242 #[cfg(target_os = "linux")]
243 _drm_attachment: drm_attachment,
244 identity: self.identity.clone(),
245 buf_size: self.buf_size,
246 mmap_offset: self.mmap_offset,
247 #[cfg(target_os = "linux")]
248 is_imported: self.is_imported,
249 })
250 }
251}
252
253#[derive(Debug)]
254pub struct DmaMap<T>
255where
256 T: Num + Clone + fmt::Debug,
257{
258 ptr: Arc<Mutex<DmaPtr>>,
259 fd: OwnedFd,
260 shape: Vec<usize>,
261 mmap_size: usize,
263 offset: usize,
265 _marker: std::marker::PhantomData<T>,
266}
267
268impl<T> DmaMap<T>
269where
270 T: Num + Clone + fmt::Debug,
271{
272 pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize, offset: usize) -> Result<Self> {
273 if shape.is_empty() {
274 return Err(Error::InvalidSize(0));
275 }
276
277 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
278 if logical_size == 0 {
279 return Err(Error::InvalidSize(0));
280 }
281
282 let total_needed = offset
288 .checked_add(logical_size)
289 .ok_or(Error::InvalidSize(0))?;
290 if total_needed > buf_size {
291 warn!(
292 "DmaMap: offset={} + logical_size={} = {} exceeds buf_size={} (fd={})",
293 offset,
294 logical_size,
295 total_needed,
296 buf_size,
297 fd.as_raw_fd()
298 );
299 return Err(Error::InvalidSize(total_needed));
300 }
301 if std::mem::size_of::<T>() > 1 && !offset.is_multiple_of(std::mem::align_of::<T>()) {
302 return Err(Error::InvalidOperation(format!(
303 "DmaMap: offset {} is not aligned to align_of::<T>()={}",
304 offset,
305 std::mem::align_of::<T>()
306 )));
307 }
308 let mmap_size = buf_size;
309
310 #[cfg(target_os = "linux")]
311 {
312 trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
313 if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
314 warn!(
315 "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
316 fd.as_raw_fd()
317 );
318 return Err(Error::NixError(e));
319 }
320 }
321
322 let ptr = unsafe {
323 nix::sys::mman::mmap(
324 None,
325 NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
326 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
327 nix::sys::mman::MapFlags::MAP_SHARED,
328 &fd,
329 0,
330 )?
331 };
332
333 trace!("Mapping DMA memory: {ptr:?}");
334 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
335 Ok(DmaMap {
336 ptr: Arc::new(Mutex::new(dma_ptr)),
337 fd,
338 shape: shape.to_vec(),
339 mmap_size,
340 offset,
341 _marker: std::marker::PhantomData,
342 })
343 }
344}
345
346impl<T> Deref for DmaMap<T>
347where
348 T: Num + Clone + fmt::Debug,
349{
350 type Target = [T];
351
352 fn deref(&self) -> &[T] {
353 self.as_slice()
354 }
355}
356
357impl<T> DerefMut for DmaMap<T>
358where
359 T: Num + Clone + fmt::Debug,
360{
361 fn deref_mut(&mut self) -> &mut [T] {
362 self.as_mut_slice()
363 }
364}
365
/// Non-null base address of a DMA-BUF mapping.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);

impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let DmaPtr(inner) = self;
        inner
    }
}

// SAFETY: NOTE(review) — `DmaPtr` only carries an address and never dereferences it.
// The address names a process-wide mapping that is valid on any thread; synchronizing
// access to the memory is left to the code that builds slices from it.
unsafe impl Send for DmaPtr {}
377
378impl<T> TensorMapTrait<T> for DmaMap<T>
379where
380 T: Num + Clone + fmt::Debug,
381{
382 fn shape(&self) -> &[usize] {
383 &self.shape
384 }
385
386 fn unmap(&mut self) {
387 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
388
389 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
390 warn!("Failed to unmap DMA memory: {e}");
391 }
392
393 #[cfg(target_os = "linux")]
394 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
395 warn!("Failed to end read/write on DMA memory: {e}");
396 }
397 }
398
399 fn as_slice(&self) -> &[T] {
400 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
401 let base = unsafe { (ptr.as_ptr() as *const u8).add(self.offset) as *const T };
402 unsafe { std::slice::from_raw_parts(base, self.len()) }
403 }
404
405 fn as_mut_slice(&mut self) -> &mut [T] {
406 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
407 let base = unsafe { (ptr.as_ptr() as *mut u8).add(self.offset) as *mut T };
408 unsafe { std::slice::from_raw_parts_mut(base, self.len()) }
409 }
410}
411
412impl<T> Drop for DmaMap<T>
413where
414 T: Num + Clone + fmt::Debug,
415{
416 fn drop(&mut self) {
417 trace!("DmaMap dropped, unmapping memory");
418 self.unmap();
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned descriptor for `/dev/null`. The tests below use it only where `DmaMap::new`
    /// is expected to fail validation before the descriptor is touched.
    #[cfg(target_os = "linux")]
    fn dummy_fd() -> std::os::fd::OwnedFd {
        // `OwnedFd: From<File>` transfers ownership safely; no `unsafe` raw-fd round trip.
        std::os::fd::OwnedFd::from(std::fs::File::open("/dev/null").expect("open /dev/null"))
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_exceeds_buf_size() {
        let fd = dummy_fd();
        // offset (4096) + logical size (4096) = 8192 > buf_size (4096).
        let result = DmaMap::<u8>::new(fd, &[4096], 4096, 4096);
        match result {
            Err(Error::InvalidSize(n)) => assert_eq!(n, 8192),
            other => panic!("expected InvalidSize(8192), got {:?}", other),
        }
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_misaligned_offset() {
        let fd = dummy_fd();
        // Fits in the buffer (3 + 4096 <= 8192) but 3 is not a multiple of align_of::<u32>().
        let result = DmaMap::<u32>::new(fd, &[1024], 8192, 3);
        assert!(
            matches!(result, Err(Error::InvalidOperation(_))),
            "expected InvalidOperation for misaligned offset, got {:?}",
            result
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_overflow() {
        let fd = dummy_fd();
        // offset + logical size overflows usize.
        let result = DmaMap::<u8>::new(fd, &[1], usize::MAX, usize::MAX);
        assert!(
            matches!(result, Err(Error::InvalidSize(0))),
            "expected InvalidSize(0) on overflow, got {:?}",
            result
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_with_offset() {
        use crate::{Tensor, TensorMapTrait, TensorMemory, TensorTrait};

        let total_size: usize = 4096 * 4;
        let offset: usize = 4096;
        let data_size: usize = 4096;
        let large_buf = match Tensor::<u8>::new(&[total_size], Some(TensorMemory::Dma), None) {
            Ok(buf) => buf,
            Err(_) => {
                eprintln!("SKIPPED: DMA not available");
                return;
            }
        };

        {
            let mut map = large_buf.map().unwrap();
            map.as_mut_slice().fill(0xAA);
        }

        let fd = large_buf.clone_fd().unwrap();
        let mut offset_tensor = Tensor::<u8>::from_fd(fd, &[data_size], None).unwrap();
        offset_tensor.set_plane_offset(offset);

        let mut map = offset_tensor.map().unwrap();
        let slice = map.as_mut_slice();

        assert_eq!(slice.len(), data_size);
        assert!(
            slice.iter().all(|&b| b == 0xAA),
            "Offset tensor map should see sentinel data at offset"
        );

        slice.fill(0xBB);
        drop(map);

        {
            let map = large_buf.map().unwrap();
            let buf = map.as_slice();
            assert!(
                buf[..offset].iter().all(|&b| b == 0xAA),
                "Data before offset should be unchanged"
            );
            assert!(
                buf[offset..offset + data_size].iter().all(|&b| b == 0xBB),
                "Data at offset should be 0xBB"
            );
        }
    }
}