1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
/// A tensor backed by a Linux DMA-heap buffer (a dma-buf file descriptor),
/// which can be shared zero-copy with devices and other processes via `fd`.
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Buffer name; auto-generated as "/<16 hex chars>" when not supplied.
    pub name: String,
    /// Owning dma-buf file descriptor backing the tensor memory.
    pub fd: OwnedFd,
    /// Logical element shape; `shape.product() * size_of::<T>()` is the
    /// logical byte size (the buffer itself may be larger, see `buf_size`).
    pub shape: Vec<usize>,
    pub _marker: std::marker::PhantomData<T>,
    /// Optional DRM attachment kept alive for the buffer's lifetime
    /// (only created for buffers allocated here, not imported ones).
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
    // Presumably a stable identity for tracking/comparing buffers across
    // clones (try_clone shares it) — exposed via `buffer_identity()`.
    identity: crate::BufferIdentity,
    /// Actual byte size of the underlying buffer; may exceed the logical
    /// size when the kernel rounds the allocation up.
    pub(crate) buf_size: usize,
    /// Byte offset applied when mapping (e.g. a plane offset within a
    /// larger shared buffer).
    pub(crate) mmap_offset: usize,
    /// True when wrapping an fd imported via `from_fd` rather than one
    /// allocated by this crate.
    #[cfg(target_os = "linux")]
    pub(crate) is_imported: bool,
}
59
// SAFETY: all fields are plain data plus an `OwnedFd`, and `T` is itself
// required to be `Send + Sync`. NOTE(review): these manual impls are only
// needed if `DrmAttachment` is not automatically Send/Sync — confirm that
// assumption; if it is auto, these impls are redundant (but harmless).
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
62
63impl<T> TensorTrait<T> for DmaTensor<T>
64where
65 T: Num + Clone + fmt::Debug + Send + Sync,
66{
67 #[cfg(target_os = "linux")]
68 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
69 use log::debug;
70 use nix::sys::stat::fstat;
71
72 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
73 let name = match name {
74 Some(name) => name.to_owned(),
75 None => {
76 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
77 format!("/{}", &uuid[..16])
78 }
79 };
80
81 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
82 Ok(heap) => heap,
83 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
84 };
85
86 let dma_fd = heap.allocate(logical_size)?;
87 let stat = fstat(&dma_fd)?;
88 debug!("DMA memory stat: {stat:?}");
89 let buf_size = if stat.st_size > 0 {
90 std::cmp::max(stat.st_size as usize, logical_size)
91 } else {
92 logical_size
93 };
94
95 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd, false);
96
97 Ok(DmaTensor::<T> {
98 name: name.to_owned(),
99 fd: dma_fd,
100 shape: shape.to_vec(),
101 _marker: std::marker::PhantomData,
102 _drm_attachment: drm_attachment,
103 identity: crate::BufferIdentity::new(),
104 buf_size,
105 mmap_offset: 0,
106 is_imported: false,
107 })
108 }
109
110 #[cfg(not(target_os = "linux"))]
111 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
112 Err(Error::NotImplemented(
113 "DMA tensors are not supported on this platform".to_owned(),
114 ))
115 }
116
117 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
118 if shape.is_empty() {
119 return Err(Error::InvalidSize(0));
120 }
121
122 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
123 if logical_size == 0 {
124 return Err(Error::InvalidSize(0));
125 }
126
127 let buf_size = {
130 #[cfg(target_os = "linux")]
131 {
132 use nix::sys::stat::fstat;
133 match fstat(&fd) {
134 Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
135 stat.st_size as usize
136 }
137 _ => logical_size,
138 }
139 }
140 #[cfg(not(target_os = "linux"))]
141 {
142 logical_size
143 }
144 };
145
146 #[cfg(target_os = "linux")]
155 let drm_attachment = None;
156
157 Ok(DmaTensor {
158 name: name.unwrap_or("").to_owned(),
159 fd,
160 shape: shape.to_vec(),
161 _marker: std::marker::PhantomData,
162 #[cfg(target_os = "linux")]
163 _drm_attachment: drm_attachment,
164 identity: crate::BufferIdentity::new(),
165 buf_size,
166 mmap_offset: 0,
167 #[cfg(target_os = "linux")]
168 is_imported: true,
169 })
170 }
171
172 fn clone_fd(&self) -> Result<OwnedFd> {
173 Ok(self.fd.try_clone()?)
174 }
175
176 fn memory(&self) -> TensorMemory {
177 TensorMemory::Dma
178 }
179
180 fn name(&self) -> String {
181 self.name.clone()
182 }
183
184 fn shape(&self) -> &[usize] {
185 &self.shape
186 }
187
188 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
189 if shape.is_empty() {
190 return Err(Error::InvalidSize(0));
191 }
192
193 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
194 if new_size != self.size() {
195 return Err(Error::ShapeMismatch(format!(
196 "Cannot reshape incompatible shape: {:?} to {:?}",
197 self.shape, shape
198 )));
199 }
200
201 self.shape = shape.to_vec();
202 Ok(())
203 }
204
205 fn map(&self) -> Result<TensorMap<T>> {
206 Ok(TensorMap::Dma(DmaMap::new(
207 self.fd.try_clone()?,
208 &self.shape,
209 self.buf_size,
210 self.mmap_offset,
211 )?))
212 }
213
214 fn buffer_identity(&self) -> &crate::BufferIdentity {
215 &self.identity
216 }
217}
218
219impl<T> AsRawFd for DmaTensor<T>
220where
221 T: Num + Clone + fmt::Debug + Send + Sync,
222{
223 fn as_raw_fd(&self) -> std::os::fd::RawFd {
224 self.fd.as_raw_fd()
225 }
226}
227
228impl<T> DmaTensor<T>
229where
230 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
231{
232 #[cfg(target_os = "linux")]
250 pub(crate) fn new_with_byte_size(
251 shape: &[usize],
252 byte_size: usize,
253 name: Option<&str>,
254 ) -> Result<Self> {
255 use log::debug;
256 use nix::sys::stat::fstat;
257
258 let logical_elems = shape
263 .iter()
264 .copied()
265 .try_fold(1usize, |acc, dim| acc.checked_mul(dim))
266 .ok_or_else(|| {
267 Error::InvalidArgument(format!(
268 "DmaTensor::new_with_byte_size: shape.product() overflows usize \
269 (shape={shape:?})"
270 ))
271 })?;
272 let logical_size = logical_elems
273 .checked_mul(std::mem::size_of::<T>())
274 .ok_or_else(|| {
275 Error::InvalidArgument(format!(
276 "DmaTensor::new_with_byte_size: logical_elems {logical_elems} × \
277 sizeof::<T>={} overflows usize (shape={shape:?})",
278 std::mem::size_of::<T>()
279 ))
280 })?;
281 if byte_size < logical_size {
282 return Err(Error::InvalidArgument(format!(
283 "DmaTensor::new_with_byte_size: byte_size {byte_size} < logical {logical_size} \
284 (shape={shape:?}, sizeof::<T>={})",
285 std::mem::size_of::<T>()
286 )));
287 }
288 let name = match name {
289 Some(name) => name.to_owned(),
290 None => {
291 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
292 format!("/{}", &uuid[..16])
293 }
294 };
295
296 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
297 Ok(heap) => heap,
298 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
299 };
300
301 let dma_fd = heap.allocate(byte_size)?;
302 let stat = fstat(&dma_fd)?;
303 debug!("DMA padded memory stat: {stat:?}");
304 let buf_size = if stat.st_size > 0 {
305 std::cmp::max(stat.st_size as usize, byte_size)
306 } else {
307 byte_size
308 };
309
310 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd, false);
311
312 Ok(DmaTensor::<T> {
313 name,
314 fd: dma_fd,
315 shape: shape.to_vec(),
316 _marker: std::marker::PhantomData,
317 _drm_attachment: drm_attachment,
318 identity: crate::BufferIdentity::new(),
319 buf_size,
320 mmap_offset: 0,
321 is_imported: false,
322 })
323 }
324
325 #[cfg(not(target_os = "linux"))]
326 pub(crate) fn new_with_byte_size(
327 _shape: &[usize],
328 _byte_size: usize,
329 _name: Option<&str>,
330 ) -> Result<Self> {
331 Err(Error::NotImplemented(
332 "DMA tensors are not supported on this platform".to_owned(),
333 ))
334 }
335
336 pub(crate) fn map_with_byte_size(&self, byte_size: usize) -> Result<DmaMap<T>> {
344 DmaMap::new_with_byte_size(
345 self.fd.try_clone()?,
346 &self.shape,
347 self.buf_size,
348 self.mmap_offset,
349 byte_size,
350 )
351 }
352
353 pub fn try_clone(&self) -> Result<Self> {
354 let fd = self.clone_fd()?;
355 #[cfg(target_os = "linux")]
358 let drm_attachment = if self.is_imported {
359 None
360 } else {
361 crate::dmabuf::DrmAttachment::new(&fd, false)
362 };
363 Ok(Self {
364 name: self.name.clone(),
365 fd,
366 shape: self.shape.clone(),
367 _marker: std::marker::PhantomData,
368 #[cfg(target_os = "linux")]
369 _drm_attachment: drm_attachment,
370 identity: self.identity.clone(),
371 buf_size: self.buf_size,
372 mmap_offset: self.mmap_offset,
373 #[cfg(target_os = "linux")]
374 is_imported: self.is_imported,
375 })
376 }
377}
378
/// A CPU mapping of a DMA buffer, exposing its contents as a `[T]` slice
/// (via `Deref`/`DerefMut`). Unmaps on drop.
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// mmap'ed base address of the *whole* buffer (not offset-adjusted).
    ptr: Arc<Mutex<DmaPtr>>,
    /// Dup of the tensor's dma-buf fd; used for the end-of-sync ioctl on unmap.
    fd: OwnedFd,
    /// Element shape the slice is interpreted with.
    shape: Vec<usize>,
    /// Total mapped length in bytes (the full buffer size).
    mmap_size: usize,
    /// Byte offset into the mapping where this tensor's data starts.
    offset: usize,
    /// When set, the exposed slice covers this many bytes instead of the
    /// shape-derived length (see `slice_len_elems`).
    byte_size_override: Option<usize>,
    _marker: std::marker::PhantomData<T>,
}
401
impl<T> DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Maps `fd` and exposes `shape` elements of `T` starting at `offset`
    /// bytes into a buffer of `buf_size` bytes.
    pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize, offset: usize) -> Result<Self> {
        Self::new_internal(fd, shape, buf_size, offset, None)
    }

    /// Like `new`, but the exposed slice covers exactly `byte_size` bytes
    /// instead of the shape-derived length.
    pub(crate) fn new_with_byte_size(
        fd: OwnedFd,
        shape: &[usize],
        buf_size: usize,
        offset: usize,
        byte_size: usize,
    ) -> Result<Self> {
        Self::new_internal(fd, shape, buf_size, offset, Some(byte_size))
    }

    /// Shared constructor: validates geometry, starts the DMA read/write
    /// sync window, then mmaps the whole buffer.
    ///
    /// # Errors
    /// `InvalidSize` for empty/zero shapes, offset+size overflow or out of
    /// bounds, or an oversized `byte_size_override`; `InvalidOperation` for
    /// offsets/sizes misaligned w.r.t. `T`; `NixError` when the sync ioctl
    /// fails; any mmap error is propagated via `?`.
    fn new_internal(
        fd: OwnedFd,
        shape: &[usize],
        buf_size: usize,
        offset: usize,
        byte_size_override: Option<usize>,
    ) -> Result<Self> {
        if shape.is_empty() {
            return Err(Error::InvalidSize(0));
        }

        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
        if logical_size == 0 {
            return Err(Error::InvalidSize(0));
        }

        // The slice window [offset, offset + logical_size) must fit inside
        // the buffer; checked_add also catches usize overflow.
        let total_needed = offset
            .checked_add(logical_size)
            .ok_or(Error::InvalidSize(0))?;
        if total_needed > buf_size {
            warn!(
                "DmaMap: offset={} + logical_size={} = {} exceeds buf_size={} (fd={})",
                offset,
                logical_size,
                total_needed,
                buf_size,
                fd.as_raw_fd()
            );
            return Err(Error::InvalidSize(total_needed));
        }
        // An unaligned base pointer for T would make the later
        // from_raw_parts calls UB; single-byte T needs no check.
        if std::mem::size_of::<T>() > 1 && !offset.is_multiple_of(std::mem::align_of::<T>()) {
            return Err(Error::InvalidOperation(format!(
                "DmaMap: offset {} is not aligned to align_of::<T>()={}",
                offset,
                std::mem::align_of::<T>()
            )));
        }

        // Extra validation when the caller overrides the slice byte length.
        if let Some(byte_size) = byte_size_override {
            if byte_size == 0 {
                return Err(Error::InvalidSize(0));
            }
            let t_size = std::mem::size_of::<T>();
            if t_size > 1 && !byte_size.is_multiple_of(t_size) {
                return Err(Error::InvalidOperation(format!(
                    "DmaMap: byte_size_override {byte_size} is not a multiple of sizeof::<T>()={t_size}"
                )));
            }
            let available = buf_size.saturating_sub(offset);
            if byte_size > available {
                return Err(Error::InvalidSize(byte_size));
            }
        }

        // The whole buffer is mapped at file offset 0; `offset` is applied
        // later via pointer arithmetic in as_slice/as_mut_slice.
        let mmap_size = buf_size;

        #[cfg(target_os = "linux")]
        {
            trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
            if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
                warn!(
                    "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
                    fd.as_raw_fd()
                );
                return Err(Error::NixError(e));
            }
        }

        // NOTE(review): if mmap below fails after start_readwrite succeeded,
        // the DMA sync window is never ended — confirm whether an
        // end_readwrite should be issued on this error path.
        let ptr = unsafe {
            nix::sys::mman::mmap(
                None,
                NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
                nix::sys::mman::MapFlags::MAP_SHARED,
                &fd,
                0,
            )?
        };

        trace!("Mapping DMA memory: {ptr:?}");
        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
        Ok(DmaMap {
            ptr: Arc::new(Mutex::new(dma_ptr)),
            fd,
            shape: shape.to_vec(),
            mmap_size,
            offset,
            byte_size_override,
            _marker: std::marker::PhantomData,
        })
    }
}
531
532impl<T> Deref for DmaMap<T>
533where
534 T: Num + Clone + fmt::Debug,
535{
536 type Target = [T];
537
538 fn deref(&self) -> &[T] {
539 self.as_slice()
540 }
541}
542
543impl<T> DerefMut for DmaMap<T>
544where
545 T: Num + Clone + fmt::Debug,
546{
547 fn deref_mut(&mut self) -> &mut [T] {
548 self.as_mut_slice()
549 }
550}
551
/// Non-null pointer to the mmap'ed region, wrapped so it can live inside an
/// `Arc<Mutex<..>>` shared by the map.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);
impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// SAFETY: the pointer refers to a MAP_SHARED mmap region that remains valid
// until `munmap` in `DmaMap::unmap`; moving the wrapper between threads does
// not itself create aliasing. NOTE(review): soundness ultimately rests on
// DmaMap's unmap/drop discipline — see the note on `unmap`.
unsafe impl Send for DmaPtr {}
563
impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Unmaps the buffer and ends the DMA read/write sync window opened in
    /// `new_internal`.
    ///
    /// NOTE(review): the stored pointer is not invalidated here, and `unmap`
    /// is also invoked from `Drop` — calling it manually would cause a
    /// double `munmap`, and `as_slice` after `unmap` is a use-after-unmap.
    /// Confirm callers rely on Drop only.
    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        // The pointer/length pair came from the successful mmap of exactly
        // `mmap_size` bytes in `new_internal`.
        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        // Pairs with the DMA_BUF_IOCTL_SYNC(START) issued when mapping.
        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    /// Borrows the mapped data, starting `offset` bytes into the mapping.
    /// Alignment and bounds of the window were validated in `new_internal`.
    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        let base = unsafe { (ptr.as_ptr() as *const u8).add(self.offset) as *const T };
        unsafe { std::slice::from_raw_parts(base, self.slice_len_elems()) }
    }

    /// Mutable counterpart of `as_slice`; `&mut self` guarantees exclusivity.
    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        let base = unsafe { (ptr.as_ptr() as *mut u8).add(self.offset) as *mut T };
        unsafe { std::slice::from_raw_parts_mut(base, self.slice_len_elems()) }
    }
}
597
598impl<T> DmaMap<T>
599where
600 T: Num + Clone + fmt::Debug,
601{
602 fn slice_len_elems(&self) -> usize {
607 match self.byte_size_override {
608 Some(bytes) => bytes / std::mem::size_of::<T>(),
609 None => self.shape.iter().product(),
610 }
611 }
612}
613
impl<T> Drop for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Unmaps the memory (and ends the DMA sync window) when the map is
    /// dropped; failures are only logged by `unmap`, never panicked on.
    fn drop(&mut self) {
        trace!("DmaMap dropped, unmapping memory");
        self.unmap();
    }
}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens `/dev/null` and hands back an `OwnedFd`, for constructor tests
    /// whose validation fails before any mmap happens.
    #[cfg(target_os = "linux")]
    fn dummy_fd() -> std::os::fd::OwnedFd {
        std::fs::File::open("/dev/null")
            .expect("open /dev/null")
            .into()
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_exceeds_buf_size() {
        // offset (4096) + logical size (4096) = 8192 > buf_size (4096).
        let fd = dummy_fd();
        match DmaMap::<u8>::new(fd, &[4096], 4096, 4096) {
            Err(Error::InvalidSize(n)) => assert_eq!(n, 8192),
            other => panic!("expected InvalidSize(8192), got {:?}", other),
        }
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_misaligned_offset() {
        // Offset 3 is not a multiple of align_of::<u32>().
        let fd = dummy_fd();
        let result = DmaMap::<u32>::new(fd, &[1024], 8192, 3);
        assert!(
            matches!(result, Err(Error::InvalidOperation(_))),
            "expected InvalidOperation for misaligned offset, got {:?}",
            result
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_overflow() {
        // offset + logical_size overflows usize -> checked_add must fail.
        let fd = dummy_fd();
        let result = DmaMap::<u8>::new(fd, &[1], usize::MAX, usize::MAX);
        assert!(
            matches!(result, Err(Error::InvalidSize(0))),
            "expected InvalidSize(0) on overflow, got {:?}",
            result
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_with_offset() {
        use crate::{Tensor, TensorMapTrait, TensorMemory, TensorTrait};

        // Carve a 4 KiB view at a 4 KiB offset out of a 16 KiB DMA buffer.
        let total_size: usize = 4096 * 4;
        let offset: usize = 4096;
        let data_size: usize = 4096;
        let large_buf = match Tensor::<u8>::new(&[total_size], Some(TensorMemory::Dma), None) {
            Ok(buf) => buf,
            Err(_) => {
                eprintln!("SKIPPED: DMA not available");
                return;
            }
        };

        // Fill the whole buffer with a sentinel pattern.
        {
            let mut map = large_buf.map().unwrap();
            map.as_mut_slice().fill(0xAA);
        }

        // Re-import the same buffer with a plane offset applied.
        let fd = large_buf.clone_fd().unwrap();
        let mut offset_tensor = Tensor::<u8>::from_fd(fd, &[data_size], None).unwrap();
        offset_tensor.set_plane_offset(offset);

        let mut map = offset_tensor.map().unwrap();
        let slice = map.as_mut_slice();

        assert_eq!(slice.len(), data_size);
        assert!(
            slice.iter().all(|&b| b == 0xAA),
            "Offset tensor map should see sentinel data at offset"
        );

        // Write through the offset view, then check via a full mapping that
        // only the [offset, offset + data_size) window changed.
        slice.fill(0xBB);
        drop(map);

        {
            let map = large_buf.map().unwrap();
            let buf = map.as_slice();
            assert!(
                buf[..offset].iter().all(|&b| b == 0xAA),
                "Data before offset should be unchanged"
            );
            assert!(
                buf[offset..offset + data_size].iter().all(|&b| b == 0xBB),
                "Data at offset should be 0xBB"
            );
        }
    }
}