Skip to main content

edgefirst_tensor/
dma.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11    ffi::c_void,
12    fmt,
13    num::NonZero,
14    ops::{Deref, DerefMut},
15    os::fd::{AsRawFd, OwnedFd},
16    ptr::NonNull,
17    sync::{Arc, Mutex},
18};
19
20/// A tensor backed by DMA (Direct Memory Access) memory.
21///
22/// On Linux, for self-allocated (dma_heap) buffers a DRM PRIME attachment is
23/// created to enable CPU cache coherency via `DMA_BUF_IOCTL_SYNC`. Without an
24/// active attachment, sync ioctls are no-ops on cached CMA heaps.
25///
26/// For imported (foreign) DMA-BUF fds — e.g. those exported by the Neutron
27/// NPU driver — no DRM attachment is created. Cache coherency for foreign
28/// buffers is the responsibility of the buffer owner (the kernel driver).
/// A tensor backed by DMA (Direct Memory Access) memory.
///
/// On Linux, for self-allocated (dma_heap) buffers a DRM PRIME attachment is
/// created to enable CPU cache coherency via `DMA_BUF_IOCTL_SYNC`. Without an
/// active attachment, sync ioctls are no-ops on cached CMA heaps.
///
/// For imported (foreign) DMA-BUF fds — e.g. those exported by the Neutron
/// NPU driver — no DRM attachment is created. Cache coherency for foreign
/// buffers is the responsibility of the buffer owner (the kernel driver).
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Tensor name; auto-generated from a v4 UUID when not supplied to `new()`.
    pub name: String,
    /// Owned DMA-BUF file descriptor backing the tensor memory.
    pub fd: OwnedFd,
    /// Logical tensor shape; element count is the product of all dimensions.
    pub shape: Vec<usize>,
    /// Ties the element type `T` to the tensor without storing any `T` values.
    pub _marker: std::marker::PhantomData<T>,
    /// Held solely to keep the DRM PRIME attachment alive for the tensor's
    /// lifetime; `None` for imported fds or when attachment fails.
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
    /// Buffer identity; cloned (not regenerated) by `try_clone()` so clones
    /// are recognized as the same underlying buffer.
    identity: crate::BufferIdentity,
    /// Actual buffer size in bytes (from fstat at creation time).
    /// May be larger than shape.product() * sizeof(T) for externally
    /// allocated buffers with row padding.
    buf_size: usize,
    /// Byte offset into the DMA buffer where the tensor data begins.
    /// Set via `Tensor::set_plane_offset` for sub-region imports.
    pub(crate) mmap_offset: usize,
    /// Whether this tensor was created via `from_fd()` (imported from an
    /// external allocator).  Propagated through `try_clone()` so that DRM
    /// PRIME import failures are logged at DEBUG rather than WARN.
    #[cfg(target_os = "linux")]
    is_imported: bool,
}
54
// SAFETY: every field is itself safe to move/share across threads (`OwnedFd`,
// `String`, `Vec<usize>`, `usize`, `bool`), and `T: Send + Sync` is required
// by the bound for the `PhantomData<T>` marker.
// NOTE(review): these manual impls are only needed if `DrmAttachment` or
// `BufferIdentity` hold a raw pointer (or similar) that blocks the auto
// impls — confirm those types are in fact safe to send/share.
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
57
58impl<T> TensorTrait<T> for DmaTensor<T>
59where
60    T: Num + Clone + fmt::Debug + Send + Sync,
61{
62    #[cfg(target_os = "linux")]
63    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
64        use log::debug;
65        use nix::sys::stat::fstat;
66
67        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
68        let name = match name {
69            Some(name) => name.to_owned(),
70            None => {
71                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
72                format!("/{}", &uuid[..16])
73            }
74        };
75
76        let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
77            Ok(heap) => heap,
78            Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
79        };
80
81        let dma_fd = heap.allocate(logical_size)?;
82        let stat = fstat(&dma_fd)?;
83        debug!("DMA memory stat: {stat:?}");
84        let buf_size = if stat.st_size > 0 {
85            std::cmp::max(stat.st_size as usize, logical_size)
86        } else {
87            logical_size
88        };
89
90        let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd, false);
91
92        Ok(DmaTensor::<T> {
93            name: name.to_owned(),
94            fd: dma_fd,
95            shape: shape.to_vec(),
96            _marker: std::marker::PhantomData,
97            _drm_attachment: drm_attachment,
98            identity: crate::BufferIdentity::new(),
99            buf_size,
100            mmap_offset: 0,
101            is_imported: false,
102        })
103    }
104
105    #[cfg(not(target_os = "linux"))]
106    fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
107        Err(Error::NotImplemented(
108            "DMA tensors are not supported on this platform".to_owned(),
109        ))
110    }
111
112    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
113        if shape.is_empty() {
114            return Err(Error::InvalidSize(0));
115        }
116
117        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
118        if logical_size == 0 {
119            return Err(Error::InvalidSize(0));
120        }
121
122        // fstat may return st_size=0 for DMA-BUF fds on some kernels;
123        // fall back to logical_size in that case.
124        let buf_size = {
125            #[cfg(target_os = "linux")]
126            {
127                use nix::sys::stat::fstat;
128                match fstat(&fd) {
129                    Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
130                        stat.st_size as usize
131                    }
132                    _ => logical_size,
133                }
134            }
135            #[cfg(not(target_os = "linux"))]
136            {
137                logical_size
138            }
139        };
140
141        // Do NOT attempt a DRM attachment for foreign (imported) DMA-BUF fds.
142        // DRM PRIME import is only meaningful for DMA-BUF fds that were
143        // allocated by the same DRM device (e.g. via the CMA/system heap).
144        // For fds owned by other kernel drivers (e.g. Neutron NPU), the
145        // PRIME_FD_TO_HANDLE ioctl will fail and the resulting no-op
146        // attachment attempt adds unnecessary ioctl overhead on every import.
147        // DMA_BUF_IOCTL_SYNC coherency for foreign buffers is the
148        // responsibility of the buffer owner (the NPU driver in this case).
149        #[cfg(target_os = "linux")]
150        let drm_attachment = None;
151
152        Ok(DmaTensor {
153            name: name.unwrap_or("").to_owned(),
154            fd,
155            shape: shape.to_vec(),
156            _marker: std::marker::PhantomData,
157            #[cfg(target_os = "linux")]
158            _drm_attachment: drm_attachment,
159            identity: crate::BufferIdentity::new(),
160            buf_size,
161            mmap_offset: 0,
162            #[cfg(target_os = "linux")]
163            is_imported: true,
164        })
165    }
166
167    fn clone_fd(&self) -> Result<OwnedFd> {
168        Ok(self.fd.try_clone()?)
169    }
170
171    fn memory(&self) -> TensorMemory {
172        TensorMemory::Dma
173    }
174
175    fn name(&self) -> String {
176        self.name.clone()
177    }
178
179    fn shape(&self) -> &[usize] {
180        &self.shape
181    }
182
183    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
184        if shape.is_empty() {
185            return Err(Error::InvalidSize(0));
186        }
187
188        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
189        if new_size != self.size() {
190            return Err(Error::ShapeMismatch(format!(
191                "Cannot reshape incompatible shape: {:?} to {:?}",
192                self.shape, shape
193            )));
194        }
195
196        self.shape = shape.to_vec();
197        Ok(())
198    }
199
200    fn map(&self) -> Result<TensorMap<T>> {
201        Ok(TensorMap::Dma(DmaMap::new(
202            self.fd.try_clone()?,
203            &self.shape,
204            self.buf_size,
205            self.mmap_offset,
206        )?))
207    }
208
209    fn buffer_identity(&self) -> &crate::BufferIdentity {
210        &self.identity
211    }
212}
213
214impl<T> AsRawFd for DmaTensor<T>
215where
216    T: Num + Clone + fmt::Debug + Send + Sync,
217{
218    fn as_raw_fd(&self) -> std::os::fd::RawFd {
219        self.fd.as_raw_fd()
220    }
221}
222
223impl<T> DmaTensor<T>
224where
225    T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
226{
227    pub fn try_clone(&self) -> Result<Self> {
228        let fd = self.clone_fd()?;
229        // Preserve the imported/owned distinction: imported fds never get a
230        // DRM attachment (consistent with from_fd()).
231        #[cfg(target_os = "linux")]
232        let drm_attachment = if self.is_imported {
233            None
234        } else {
235            crate::dmabuf::DrmAttachment::new(&fd, false)
236        };
237        Ok(Self {
238            name: self.name.clone(),
239            fd,
240            shape: self.shape.clone(),
241            _marker: std::marker::PhantomData,
242            #[cfg(target_os = "linux")]
243            _drm_attachment: drm_attachment,
244            identity: self.identity.clone(),
245            buf_size: self.buf_size,
246            mmap_offset: self.mmap_offset,
247            #[cfg(target_os = "linux")]
248            is_imported: self.is_imported,
249        })
250    }
251}
252
/// A mapped (mmap'd) CPU view of DMA-BUF memory, usable as a `[T]` slice
/// via `Deref`/`DerefMut`.
///
/// On Linux, creation opens a `DMA_BUF_IOCTL_SYNC` read/write window that
/// `unmap()` (called on drop) closes again.
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Shared, lockable handle to the mmap'd base pointer.
    ptr: Arc<Mutex<DmaPtr>>,
    /// Duplicated DMA-BUF fd, kept open for the lifetime of the mapping.
    fd: OwnedFd,
    /// Logical tensor shape of the mapped view.
    shape: Vec<usize>,
    /// Actual mmap'd size (may be > shape.product() * sizeof(T) for padded buffers).
    mmap_size: usize,
    /// Byte offset into the mmap'd region where tensor data begins.
    offset: usize,
    /// Ties the element type `T` to the map without storing any `T` values.
    _marker: std::marker::PhantomData<T>,
}
267
impl<T> DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Maps `buf_size` bytes of the DMA buffer behind `fd` and exposes the
    /// logical tensor bytes starting at `offset`.
    ///
    /// Validation performed before any fd-specific syscall:
    /// * `shape` must be non-empty and describe a non-zero byte size,
    /// * `offset + logical_size` must not overflow and must fit within
    ///   `buf_size`,
    /// * `offset` must be aligned for `T`.
    ///
    /// On Linux a `DMA_BUF_IOCTL_SYNC(START)` is issued before mmap so CPU
    /// accesses through the mapping are cache-coherent; the matching END is
    /// issued by `unmap()`.
    ///
    /// # Errors
    /// `InvalidSize` for empty/overflowing/oversized geometry,
    /// `InvalidOperation` for a misaligned offset, plus the underlying
    /// errors from the sync ioctl and `mmap`.
    pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize, offset: usize) -> Result<Self> {
        if shape.is_empty() {
            return Err(Error::InvalidSize(0));
        }

        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
        if logical_size == 0 {
            return Err(Error::InvalidSize(0));
        }

        // Use the buffer's actual size (from fstat at DmaTensor creation).
        // as_slice() uses the logical element count from shape.
        // When an offset is present (sub-region of a larger DMA-BUF), verify
        // that offset + logical_size fits within the allocated buffer — mapping
        // beyond buf_size would cause SIGBUS on access.
        let total_needed = offset
            .checked_add(logical_size)
            .ok_or(Error::InvalidSize(0))?;
        if total_needed > buf_size {
            warn!(
                "DmaMap: offset={} + logical_size={} = {} exceeds buf_size={} (fd={})",
                offset,
                logical_size,
                total_needed,
                buf_size,
                fd.as_raw_fd()
            );
            return Err(Error::InvalidSize(total_needed));
        }
        // Reject offsets that would make the derived `*const T` misaligned.
        // (For size_of::<T>() == 1 the alignment is 1, so no check is needed.)
        if std::mem::size_of::<T>() > 1 && !offset.is_multiple_of(std::mem::align_of::<T>()) {
            return Err(Error::InvalidOperation(format!(
                "DmaMap: offset {} is not aligned to align_of::<T>()={}",
                offset,
                std::mem::align_of::<T>()
            )));
        }
        // Always map the whole buffer from file offset 0; the tensor offset
        // is applied in as_slice()/as_mut_slice() via pointer arithmetic.
        let mmap_size = buf_size;

        #[cfg(target_os = "linux")]
        {
            trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
            if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
                warn!(
                    "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
                    fd.as_raw_fd()
                );
                return Err(Error::NixError(e));
            }
        }

        // NOTE(review): if mmap below fails, the START sync above is never
        // paired with an END — confirm whether that leak matters in practice.
        let ptr = unsafe {
            nix::sys::mman::mmap(
                None,
                NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
                nix::sys::mman::MapFlags::MAP_SHARED,
                &fd,
                0,
            )?
        };

        trace!("Mapping DMA memory: {ptr:?}");
        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
        Ok(DmaMap {
            ptr: Arc::new(Mutex::new(dma_ptr)),
            fd,
            shape: shape.to_vec(),
            mmap_size,
            offset,
            _marker: std::marker::PhantomData,
        })
    }
}
345
346impl<T> Deref for DmaMap<T>
347where
348    T: Num + Clone + fmt::Debug,
349{
350    type Target = [T];
351
352    fn deref(&self) -> &[T] {
353        self.as_slice()
354    }
355}
356
357impl<T> DerefMut for DmaMap<T>
358where
359    T: Num + Clone + fmt::Debug,
360{
361    fn deref_mut(&mut self) -> &mut [T] {
362        self.as_mut_slice()
363    }
364}
365
/// Newtype around the raw mmap base pointer so it can live inside an
/// `Arc<Mutex<..>>` and be explicitly marked `Send`.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);
impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// SAFETY: the pointer refers to a MAP_SHARED mmap'd region owned by the
// enclosing `DmaMap`, which is valid until `unmap()` runs.
// NOTE(review): soundness also relies on no path dereferencing the pointer
// after `unmap()` — confirm callers never use a map past an explicit unmap.
unsafe impl Send for DmaPtr {}
377
impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Logical shape of the mapped tensor view.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Unmaps the memory and (on Linux) closes the DMA-BUF sync window
    /// opened in `DmaMap::new()`.
    ///
    /// NOTE(review): the stored pointer is not invalidated here, so calling
    /// `unmap()` explicitly and then dropping the map would issue a second
    /// `munmap` on a stale address — confirm this is never called twice.
    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        // End the read/write coherency window started in DmaMap::new().
        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    /// Returns the tensor data as a slice starting `offset` bytes into the
    /// mapping; the element count comes from `len()` (logical shape), so any
    /// trailing padding bytes are excluded.
    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        // SAFETY: DmaMap::new() verified that offset + logical size fits in
        // the mapping and that offset is aligned for T.
        let base = unsafe { (ptr.as_ptr() as *const u8).add(self.offset) as *const T };
        unsafe { std::slice::from_raw_parts(base, self.len()) }
    }

    /// Mutable counterpart of `as_slice`; `&mut self` guarantees exclusive
    /// access through this map handle.
    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        // SAFETY: same bounds/alignment guarantees as as_slice().
        let base = unsafe { (ptr.as_ptr() as *mut u8).add(self.offset) as *mut T };
        unsafe { std::slice::from_raw_parts_mut(base, self.len()) }
    }
}
411
impl<T> Drop for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Releases the mapping (and, on Linux, the DMA-BUF sync window) when
    /// the map goes out of scope.
    fn drop(&mut self) {
        trace!("DmaMap dropped, unmapping memory");
        self.unmap();
    }
}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a valid fd backed by /dev/null.  The new error paths in
    /// DmaMap::new() all fire before any fd-specific syscall (mmap,
    /// DMA_BUF_IOCTL_SYNC), so any readable fd is sufficient.
    #[cfg(target_os = "linux")]
    fn dummy_fd() -> std::os::fd::OwnedFd {
        let f = std::fs::File::open("/dev/null").expect("open /dev/null");
        // `File` converts infallibly into `OwnedFd`; no need for the unsafe
        // into_raw_fd()/from_raw_fd() round-trip the original used.
        std::os::fd::OwnedFd::from(f)
    }

    /// offset + logical_size exceeds buf_size — must return InvalidSize.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_exceeds_buf_size() {
        let fd = dummy_fd();
        // shape=[4096] u8 → logical_size=4096; offset=4096 → total_needed=8192
        // buf_size=4096 < 8192 → error
        let result = DmaMap::<u8>::new(fd, &[4096], 4096, 4096);
        match result {
            Err(Error::InvalidSize(n)) => assert_eq!(n, 8192),
            other => panic!("expected InvalidSize(8192), got {:?}", other),
        }
    }

    /// Offset not aligned to align_of::<T>() — must return InvalidOperation.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_misaligned_offset() {
        let fd = dummy_fd();
        // shape=[1024] u32 → logical_size=4096; offset=3 (not aligned to 4)
        // buf_size=8192 so total_needed check passes; alignment check fires
        let result = DmaMap::<u32>::new(fd, &[1024], 8192, 3);
        assert!(
            matches!(result, Err(Error::InvalidOperation(_))),
            "expected InvalidOperation for misaligned offset, got {:?}",
            result
        );
    }

    /// offset + logical_size overflows usize — must return InvalidSize(0).
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_overflow() {
        let fd = dummy_fd();
        // offset=usize::MAX, shape=[1] u8 → checked_add overflows
        let result = DmaMap::<u8>::new(fd, &[1], usize::MAX, usize::MAX);
        assert!(
            matches!(result, Err(Error::InvalidSize(0))),
            "expected InvalidSize(0) on overflow, got {:?}",
            result
        );
    }

    /// End-to-end: map a sub-region of a real DMA buffer at a byte offset and
    /// verify reads/writes land at the right place in the parent buffer.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_with_offset() {
        use crate::{Tensor, TensorMapTrait, TensorMemory, TensorTrait};

        // Skip if DMA heap not available
        let total_size: usize = 4096 * 4; // 16KB
        let offset: usize = 4096; // 4KB offset
        let data_size: usize = 4096; // 4KB of data after offset

        let large_buf = match Tensor::<u8>::new(&[total_size], Some(TensorMemory::Dma), None) {
            Ok(buf) => buf,
            Err(_) => {
                eprintln!("SKIPPED: DMA not available");
                return;
            }
        };

        // Fill entire buffer with sentinel
        {
            let mut map = large_buf.map().unwrap();
            map.as_mut_slice().fill(0xAA);
        }

        // Import at offset as a smaller tensor using clone_fd + set_plane_offset
        let fd = large_buf.clone_fd().unwrap();
        let mut offset_tensor = Tensor::<u8>::from_fd(fd, &[data_size], None).unwrap();
        offset_tensor.set_plane_offset(offset);

        // Map the offset tensor — should succeed (not rejected)
        let mut map = offset_tensor.map().unwrap();
        let slice = map.as_mut_slice();

        // Should see the sentinel at the offset position
        assert_eq!(slice.len(), data_size);
        assert!(
            slice.iter().all(|&b| b == 0xAA),
            "Offset tensor map should see sentinel data at offset"
        );

        // Write different data at offset
        slice.fill(0xBB);
        drop(map);

        // Verify via the original buffer: bytes before offset unchanged,
        // bytes at offset are 0xBB
        {
            let map = large_buf.map().unwrap();
            let buf = map.as_slice();
            assert!(
                buf[..offset].iter().all(|&b| b == 0xAA),
                "Data before offset should be unchanged"
            );
            assert!(
                buf[offset..offset + data_size].iter().all(|&b| b == 0xBB),
                "Data at offset should be 0xBB"
            );
        }
    }
}