Skip to main content

edgefirst_tensor/
dma.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11    ffi::c_void,
12    fmt,
13    num::NonZero,
14    ops::{Deref, DerefMut},
15    os::fd::{AsRawFd, OwnedFd},
16    ptr::NonNull,
17    sync::{Arc, Mutex},
18};
19
/// A tensor backed by DMA (Direct Memory Access) memory.
///
/// On Linux, a DRM PRIME attachment is created to enable proper CPU cache
/// coherency via `DMA_BUF_IOCTL_SYNC`. Without this attachment, sync ioctls
/// are no-ops on cached CMA heaps.
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Human-readable identifier; auto-generated from a UUID prefix in
    /// `new()` when the caller supplies none.
    pub name: String,
    /// Owning handle to the DMA-BUF file descriptor.
    pub fd: OwnedFd,
    /// Logical tensor dimensions; the element count is the product of entries.
    pub shape: Vec<usize>,
    /// Records the element type `T` without storing any values of it.
    pub _marker: std::marker::PhantomData<T>,
    /// Held only for its side effect (cache-coherent sync ioctls); never read.
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
    /// Identity shared by clones of the same underlying buffer
    /// (`try_clone()` copies it rather than minting a new one).
    identity: crate::BufferIdentity,
    /// Actual buffer size in bytes (from fstat at creation time).
    /// May be larger than shape.product() * sizeof(T) for externally
    /// allocated buffers with row padding.
    buf_size: usize,
    /// Byte offset into the DMA buffer where the tensor data begins.
    /// Set via `Tensor::set_plane_offset` for sub-region imports.
    pub(crate) mmap_offset: usize,
    /// Whether this tensor was created via `from_fd()` (imported from an
    /// external allocator).  Propagated through `try_clone()` so that DRM
    /// PRIME import failures are logged at DEBUG rather than WARN.
    #[cfg(target_os = "linux")]
    is_imported: bool,
}
50
// SAFETY: `OwnedFd`, `String`, `Vec`, `BufferIdentity` usage here, and
// `PhantomData<T: Send + Sync>` are all thread-transferable; NOTE(review):
// these impls are presumably needed because `DrmAttachment` is not
// auto-`Send`/`Sync` — confirm that `DrmAttachment` has no thread-affine
// state (e.g. a DRM device handle that must stay on one thread).
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
53
54impl<T> TensorTrait<T> for DmaTensor<T>
55where
56    T: Num + Clone + fmt::Debug + Send + Sync,
57{
58    #[cfg(target_os = "linux")]
59    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
60        use log::debug;
61        use nix::sys::stat::fstat;
62
63        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
64        let name = match name {
65            Some(name) => name.to_owned(),
66            None => {
67                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
68                format!("/{}", &uuid[..16])
69            }
70        };
71
72        let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
73            Ok(heap) => heap,
74            Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
75        };
76
77        let dma_fd = heap.allocate(logical_size)?;
78        let stat = fstat(&dma_fd)?;
79        debug!("DMA memory stat: {stat:?}");
80        let buf_size = if stat.st_size > 0 {
81            std::cmp::max(stat.st_size as usize, logical_size)
82        } else {
83            logical_size
84        };
85
86        let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd, false);
87
88        Ok(DmaTensor::<T> {
89            name: name.to_owned(),
90            fd: dma_fd,
91            shape: shape.to_vec(),
92            _marker: std::marker::PhantomData,
93            _drm_attachment: drm_attachment,
94            identity: crate::BufferIdentity::new(),
95            buf_size,
96            mmap_offset: 0,
97            is_imported: false,
98        })
99    }
100
101    #[cfg(not(target_os = "linux"))]
102    fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
103        Err(Error::NotImplemented(
104            "DMA tensors are not supported on this platform".to_owned(),
105        ))
106    }
107
108    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
109        if shape.is_empty() {
110            return Err(Error::InvalidSize(0));
111        }
112
113        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
114        if logical_size == 0 {
115            return Err(Error::InvalidSize(0));
116        }
117
118        // fstat may return st_size=0 for DMA-BUF fds on some kernels;
119        // fall back to logical_size in that case.
120        let buf_size = {
121            #[cfg(target_os = "linux")]
122            {
123                use nix::sys::stat::fstat;
124                match fstat(&fd) {
125                    Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
126                        stat.st_size as usize
127                    }
128                    _ => logical_size,
129                }
130            }
131            #[cfg(not(target_os = "linux"))]
132            {
133                logical_size
134            }
135        };
136
137        #[cfg(target_os = "linux")]
138        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd, true);
139
140        Ok(DmaTensor {
141            name: name.unwrap_or("").to_owned(),
142            fd,
143            shape: shape.to_vec(),
144            _marker: std::marker::PhantomData,
145            #[cfg(target_os = "linux")]
146            _drm_attachment: drm_attachment,
147            identity: crate::BufferIdentity::new(),
148            buf_size,
149            mmap_offset: 0,
150            #[cfg(target_os = "linux")]
151            is_imported: true,
152        })
153    }
154
155    fn clone_fd(&self) -> Result<OwnedFd> {
156        Ok(self.fd.try_clone()?)
157    }
158
159    fn memory(&self) -> TensorMemory {
160        TensorMemory::Dma
161    }
162
163    fn name(&self) -> String {
164        self.name.clone()
165    }
166
167    fn shape(&self) -> &[usize] {
168        &self.shape
169    }
170
171    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
172        if shape.is_empty() {
173            return Err(Error::InvalidSize(0));
174        }
175
176        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
177        if new_size != self.size() {
178            return Err(Error::ShapeMismatch(format!(
179                "Cannot reshape incompatible shape: {:?} to {:?}",
180                self.shape, shape
181            )));
182        }
183
184        self.shape = shape.to_vec();
185        Ok(())
186    }
187
188    fn map(&self) -> Result<TensorMap<T>> {
189        Ok(TensorMap::Dma(DmaMap::new(
190            self.fd.try_clone()?,
191            &self.shape,
192            self.buf_size,
193            self.mmap_offset,
194        )?))
195    }
196
197    fn buffer_identity(&self) -> &crate::BufferIdentity {
198        &self.identity
199    }
200}
201
202impl<T> AsRawFd for DmaTensor<T>
203where
204    T: Num + Clone + fmt::Debug + Send + Sync,
205{
206    fn as_raw_fd(&self) -> std::os::fd::RawFd {
207        self.fd.as_raw_fd()
208    }
209}
210
211impl<T> DmaTensor<T>
212where
213    T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
214{
215    pub fn try_clone(&self) -> Result<Self> {
216        let fd = self.clone_fd()?;
217        #[cfg(target_os = "linux")]
218        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd, self.is_imported);
219        Ok(Self {
220            name: self.name.clone(),
221            fd,
222            shape: self.shape.clone(),
223            _marker: std::marker::PhantomData,
224            #[cfg(target_os = "linux")]
225            _drm_attachment: drm_attachment,
226            identity: self.identity.clone(),
227            buf_size: self.buf_size,
228            mmap_offset: self.mmap_offset,
229            #[cfg(target_os = "linux")]
230            is_imported: self.is_imported,
231        })
232    }
233}
234
/// A live CPU mapping of a [`DmaTensor`]'s buffer.
///
/// Created by `DmaTensor::map()`; unmaps and ends the DMA-BUF sync bracket
/// on drop.
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// The mmap'd base pointer, guarded so `DmaPtr` can be `Send`.
    ptr: Arc<Mutex<DmaPtr>>,
    /// Duplicated fd kept open for the lifetime of the mapping
    /// (needed by `end_readwrite` in `unmap`).
    fd: OwnedFd,
    /// Logical tensor dimensions for slice length computation.
    shape: Vec<usize>,
    /// Actual mmap'd size (may be > shape.product() * sizeof(T) for padded buffers).
    mmap_size: usize,
    /// Byte offset into the mmap'd region where tensor data begins.
    offset: usize,
    // Records the element type without storing values of it.
    _marker: std::marker::PhantomData<T>,
}
249
250impl<T> DmaMap<T>
251where
252    T: Num + Clone + fmt::Debug,
253{
254    pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize, offset: usize) -> Result<Self> {
255        if shape.is_empty() {
256            return Err(Error::InvalidSize(0));
257        }
258
259        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
260        if logical_size == 0 {
261            return Err(Error::InvalidSize(0));
262        }
263
264        // Use the buffer's actual size (from fstat at DmaTensor creation).
265        // as_slice() uses the logical element count from shape.
266        // When an offset is present (sub-region of a larger DMA-BUF), verify
267        // that offset + logical_size fits within the allocated buffer — mapping
268        // beyond buf_size would cause SIGBUS on access.
269        let total_needed = offset
270            .checked_add(logical_size)
271            .ok_or(Error::InvalidSize(0))?;
272        if total_needed > buf_size {
273            warn!(
274                "DmaMap: offset={} + logical_size={} = {} exceeds buf_size={} (fd={})",
275                offset,
276                logical_size,
277                total_needed,
278                buf_size,
279                fd.as_raw_fd()
280            );
281            return Err(Error::InvalidSize(total_needed));
282        }
283        if std::mem::size_of::<T>() > 1 && !offset.is_multiple_of(std::mem::align_of::<T>()) {
284            return Err(Error::InvalidOperation(format!(
285                "DmaMap: offset {} is not aligned to align_of::<T>()={}",
286                offset,
287                std::mem::align_of::<T>()
288            )));
289        }
290        let mmap_size = buf_size;
291
292        #[cfg(target_os = "linux")]
293        {
294            trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
295            if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
296                warn!(
297                    "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
298                    fd.as_raw_fd()
299                );
300                return Err(Error::NixError(e));
301            }
302        }
303
304        let ptr = unsafe {
305            nix::sys::mman::mmap(
306                None,
307                NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
308                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
309                nix::sys::mman::MapFlags::MAP_SHARED,
310                &fd,
311                0,
312            )?
313        };
314
315        trace!("Mapping DMA memory: {ptr:?}");
316        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
317        Ok(DmaMap {
318            ptr: Arc::new(Mutex::new(dma_ptr)),
319            fd,
320            shape: shape.to_vec(),
321            mmap_size,
322            offset,
323            _marker: std::marker::PhantomData,
324        })
325    }
326}
327
328impl<T> Deref for DmaMap<T>
329where
330    T: Num + Clone + fmt::Debug,
331{
332    type Target = [T];
333
334    fn deref(&self) -> &[T] {
335        self.as_slice()
336    }
337}
338
339impl<T> DerefMut for DmaMap<T>
340where
341    T: Num + Clone + fmt::Debug,
342{
343    fn deref_mut(&mut self) -> &mut [T] {
344        self.as_mut_slice()
345    }
346}
347
/// Newtype around the raw mmap base pointer so it can live behind
/// `Arc<Mutex<…>>` inside `DmaMap` and cross thread boundaries.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);
impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// SAFETY: the pointer comes from a MAP_SHARED mmap that remains valid until
// `DmaMap::unmap` calls munmap, and access goes through the Mutex in
// `DmaMap::ptr`.  NOTE(review): `NonNull` is `!Send` only as a lint against
// unsynchronized aliasing — confirm no other unsynchronized alias of this
// mapping exists.
unsafe impl Send for DmaPtr {}
359
impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Unmaps the region and ends the CPU read/write bracket on the DMA-BUF.
    ///
    /// NOTE(review): this is also invoked from `Drop`.  Calling `unmap()`
    /// explicitly and then dropping runs munmap/end_readwrite a second time
    /// (the second munmap fails and only logs a warning) — confirm callers
    /// never invoke `unmap` manually, or add an unmapped-state flag.
    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        // The pointer/length pair is exactly what mmap returned in
        // `DmaMap::new` (whole-buffer mapping of `mmap_size` bytes).
        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        // Close the DMA_BUF_IOCTL_SYNC bracket opened in `DmaMap::new`.
        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        // Apply the byte offset, then reinterpret as T.  Alignment of
        // base+offset for T was validated in `DmaMap::new`.
        let base = unsafe { (ptr.as_ptr() as *const u8).add(self.offset) as *const T };
        // NOTE(review): the guard drops when this function returns while the
        // slice lives on (borrowed from `self`); the lock serializes access
        // to the pointer value only, not to the mapped data.
        unsafe { std::slice::from_raw_parts(base, self.len()) }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        // Same offset/reinterpret as as_slice; `&mut self` guarantees
        // exclusive access for the returned mutable slice.
        let base = unsafe { (ptr.as_ptr() as *mut u8).add(self.offset) as *mut T };
        unsafe { std::slice::from_raw_parts_mut(base, self.len()) }
    }
}
393
394impl<T> Drop for DmaMap<T>
395where
396    T: Num + Clone + fmt::Debug,
397{
398    fn drop(&mut self) {
399        trace!("DmaMap dropped, unmapping memory");
400        self.unmap();
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a valid fd backed by /dev/null.  The new error paths in
    /// DmaMap::new() all fire before any fd-specific syscall (mmap,
    /// DMA_BUF_IOCTL_SYNC), so any readable fd is sufficient.
    #[cfg(target_os = "linux")]
    fn dummy_fd() -> std::os::fd::OwnedFd {
        let f = std::fs::File::open("/dev/null").expect("open /dev/null");
        // `File` converts infallibly into `OwnedFd` via `From`, replacing
        // the previous unsafe into_raw_fd()/from_raw_fd() round-trip.
        std::os::fd::OwnedFd::from(f)
    }

    /// offset + logical_size exceeds buf_size — must return InvalidSize.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_exceeds_buf_size() {
        let fd = dummy_fd();
        // shape=[4096] u8 → logical_size=4096; offset=4096 → total_needed=8192
        // buf_size=4096 < 8192 → error
        let result = DmaMap::<u8>::new(fd, &[4096], 4096, 4096);
        match result {
            Err(Error::InvalidSize(n)) => assert_eq!(n, 8192),
            other => panic!("expected InvalidSize(8192), got {:?}", other),
        }
    }

    /// Offset not aligned to align_of::<T>() — must return InvalidOperation.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_misaligned_offset() {
        let fd = dummy_fd();
        // shape=[1024] u32 → logical_size=4096; offset=3 (not aligned to 4)
        // buf_size=8192 so total_needed check passes; alignment check fires
        let result = DmaMap::<u32>::new(fd, &[1024], 8192, 3);
        assert!(
            matches!(result, Err(Error::InvalidOperation(_))),
            "expected InvalidOperation for misaligned offset, got {:?}",
            result
        );
    }

    /// offset + logical_size overflows usize — must return InvalidSize(0).
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_overflow() {
        let fd = dummy_fd();
        // offset=usize::MAX, shape=[1] u8 → checked_add overflows
        let result = DmaMap::<u8>::new(fd, &[1], usize::MAX, usize::MAX);
        assert!(
            matches!(result, Err(Error::InvalidSize(0))),
            "expected InvalidSize(0) on overflow, got {:?}",
            result
        );
    }

    /// End-to-end: map a sub-region at a byte offset and verify reads and
    /// writes land at the right place in the parent buffer.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_with_offset() {
        use crate::{Tensor, TensorMapTrait, TensorMemory, TensorTrait};

        // Skip if DMA heap not available
        let total_size: usize = 4096 * 4; // 16KB
        let offset: usize = 4096; // 4KB offset
        let data_size: usize = 4096; // 4KB of data after offset

        let large_buf = match Tensor::<u8>::new(&[total_size], Some(TensorMemory::Dma), None) {
            Ok(buf) => buf,
            Err(_) => {
                eprintln!("SKIPPED: DMA not available");
                return;
            }
        };

        // Fill entire buffer with sentinel
        {
            let mut map = large_buf.map().unwrap();
            map.as_mut_slice().fill(0xAA);
        }

        // Import at offset as a smaller tensor using clone_fd + set_plane_offset
        let fd = large_buf.clone_fd().unwrap();
        let mut offset_tensor = Tensor::<u8>::from_fd(fd, &[data_size], None).unwrap();
        offset_tensor.set_plane_offset(offset);

        // Map the offset tensor — should succeed (not rejected)
        let mut map = offset_tensor.map().unwrap();
        let slice = map.as_mut_slice();

        // Should see the sentinel at the offset position
        assert_eq!(slice.len(), data_size);
        assert!(
            slice.iter().all(|&b| b == 0xAA),
            "Offset tensor map should see sentinel data at offset"
        );

        // Write different data at offset
        slice.fill(0xBB);
        drop(map);

        // Verify via the original buffer: bytes before offset unchanged,
        // bytes at offset are 0xBB
        {
            let map = large_buf.map().unwrap();
            let buf = map.as_slice();
            assert!(
                buf[..offset].iter().all(|&b| b == 0xAA),
                "Data before offset should be unchanged"
            );
            assert!(
                buf[offset..offset + data_size].iter().all(|&b| b == 0xBB),
                "Data at offset should be 0xBB"
            );
        }
    }
}