// edgefirst_tensor/dma.rs
// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

use crate::{
    error::{Error, Result},
    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
};
use log::{trace, warn};
use num_traits::Num;
use std::{
    ffi::c_void,
    fmt,
    num::NonZero,
    ops::{Deref, DerefMut},
    os::fd::{AsRawFd, OwnedFd},
    ptr::NonNull,
    sync::{Arc, Mutex},
};
/// A tensor backed by DMA (Direct Memory Access) memory.
///
/// On Linux, a DRM PRIME attachment is created to enable proper CPU cache
/// coherency via `DMA_BUF_IOCTL_SYNC`. Without this attachment, sync ioctls
/// are no-ops on cached CMA heaps.
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Tensor name; auto-generated from a UUID prefix when not supplied.
    pub name: String,
    /// Owning handle to the underlying DMA-BUF file descriptor.
    pub fd: OwnedFd,
    /// Logical dimensions; element count is the product of all entries.
    pub shape: Vec<usize>,
    /// Ties the element type `T` to the tensor without storing any `T` values.
    pub _marker: std::marker::PhantomData<T>,
    // Held only for its lifetime/Drop side effect: keeps the DRM PRIME
    // attachment alive so cache-sync ioctls stay effective (see type docs).
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
    // Identity token for the underlying buffer; shared by `try_clone` so
    // clones compare as the same buffer.
    identity: crate::BufferIdentity,
    /// Actual buffer size in bytes (from fstat at creation time).
    /// May be larger than shape.product() * sizeof(T) for externally
    /// allocated buffers with row padding.
    buf_size: usize,
}
42
// SAFETY: Fields are plain owned data (String, Vec, usize), an OwnedFd (a
// thread-safe kernel handle), or PhantomData<T> with T: Send + Sync.
// NOTE(review): these impls additionally assert that
// crate::dmabuf::DrmAttachment and crate::BufferIdentity are safe to move and
// share across threads — confirm against their definitions.
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
45
46impl<T> TensorTrait<T> for DmaTensor<T>
47where
48    T: Num + Clone + fmt::Debug + Send + Sync,
49{
50    #[cfg(target_os = "linux")]
51    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
52        use log::debug;
53        use nix::sys::stat::fstat;
54
55        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
56        let name = match name {
57            Some(name) => name.to_owned(),
58            None => {
59                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
60                format!("/{}", &uuid[..16])
61            }
62        };
63
64        let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
65            Ok(heap) => heap,
66            Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
67        };
68
69        let dma_fd = heap.allocate(logical_size)?;
70        let stat = fstat(&dma_fd)?;
71        debug!("DMA memory stat: {stat:?}");
72        let buf_size = if stat.st_size > 0 {
73            std::cmp::max(stat.st_size as usize, logical_size)
74        } else {
75            logical_size
76        };
77
78        let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
79
80        Ok(DmaTensor::<T> {
81            name: name.to_owned(),
82            fd: dma_fd,
83            shape: shape.to_vec(),
84            _marker: std::marker::PhantomData,
85            _drm_attachment: drm_attachment,
86            identity: crate::BufferIdentity::new(),
87            buf_size,
88        })
89    }
90
91    #[cfg(not(target_os = "linux"))]
92    fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
93        Err(Error::NotImplemented(
94            "DMA tensors are not supported on this platform".to_owned(),
95        ))
96    }
97
98    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
99        if shape.is_empty() {
100            return Err(Error::InvalidSize(0));
101        }
102
103        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
104        if logical_size == 0 {
105            return Err(Error::InvalidSize(0));
106        }
107
108        // fstat may return st_size=0 for DMA-BUF fds on some kernels;
109        // fall back to logical_size in that case.
110        let buf_size = {
111            #[cfg(target_os = "linux")]
112            {
113                use nix::sys::stat::fstat;
114                match fstat(&fd) {
115                    Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
116                        stat.st_size as usize
117                    }
118                    _ => logical_size,
119                }
120            }
121            #[cfg(not(target_os = "linux"))]
122            {
123                logical_size
124            }
125        };
126
127        #[cfg(target_os = "linux")]
128        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
129
130        Ok(DmaTensor {
131            name: name.unwrap_or("").to_owned(),
132            fd,
133            shape: shape.to_vec(),
134            _marker: std::marker::PhantomData,
135            #[cfg(target_os = "linux")]
136            _drm_attachment: drm_attachment,
137            identity: crate::BufferIdentity::new(),
138            buf_size,
139        })
140    }
141
142    fn clone_fd(&self) -> Result<OwnedFd> {
143        Ok(self.fd.try_clone()?)
144    }
145
146    fn memory(&self) -> TensorMemory {
147        TensorMemory::Dma
148    }
149
150    fn name(&self) -> String {
151        self.name.clone()
152    }
153
154    fn shape(&self) -> &[usize] {
155        &self.shape
156    }
157
158    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
159        if shape.is_empty() {
160            return Err(Error::InvalidSize(0));
161        }
162
163        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
164        if new_size != self.size() {
165            return Err(Error::ShapeMismatch(format!(
166                "Cannot reshape incompatible shape: {:?} to {:?}",
167                self.shape, shape
168            )));
169        }
170
171        self.shape = shape.to_vec();
172        Ok(())
173    }
174
175    fn map(&self) -> Result<TensorMap<T>> {
176        Ok(TensorMap::Dma(DmaMap::new(
177            self.fd.try_clone()?,
178            &self.shape,
179            self.buf_size,
180        )?))
181    }
182
183    fn buffer_identity(&self) -> &crate::BufferIdentity {
184        &self.identity
185    }
186}
187
// Expose the raw DMA-BUF fd so the tensor can be handed to C APIs and ioctls
// that take a file descriptor. The fd remains owned by the tensor.
impl<T> AsRawFd for DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        self.fd.as_raw_fd()
    }
}
196
197impl<T> DmaTensor<T>
198where
199    T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
200{
201    pub fn try_clone(&self) -> Result<Self> {
202        let fd = self.clone_fd()?;
203        #[cfg(target_os = "linux")]
204        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
205        Ok(Self {
206            name: self.name.clone(),
207            fd,
208            shape: self.shape.clone(),
209            _marker: std::marker::PhantomData,
210            #[cfg(target_os = "linux")]
211            _drm_attachment: drm_attachment,
212            identity: self.identity.clone(),
213            buf_size: self.buf_size,
214        })
215    }
216}
217
/// A CPU mapping of a [`DmaTensor`]'s buffer, created via `mmap`.
///
/// Dereferences to `[T]` over the logical element count from `shape`; the
/// mapping itself may cover a larger region (see `mmap_size`). Unmapped on
/// drop, which also ends the DMA-BUF CPU sync on Linux.
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    // Mapped base address, behind Arc<Mutex<..>> so the raw pointer can be
    // accessed from the slice methods. NOTE(review): the Arc is never cloned
    // within this file — presumably shared elsewhere or future-proofing.
    ptr: Arc<Mutex<DmaPtr>>,
    // Duplicated DMA-BUF fd; kept so unmap() can issue the sync END ioctl.
    fd: OwnedFd,
    // Logical dimensions; as_slice() exposes shape.product() elements.
    shape: Vec<usize>,
    /// Actual mmap'd size (may be > shape.product() * sizeof(T) for padded buffers).
    mmap_size: usize,
    _marker: std::marker::PhantomData<T>,
}
230
231impl<T> DmaMap<T>
232where
233    T: Num + Clone + fmt::Debug,
234{
235    pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize) -> Result<Self> {
236        if shape.is_empty() {
237            return Err(Error::InvalidSize(0));
238        }
239
240        let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
241        if logical_size == 0 {
242            return Err(Error::InvalidSize(0));
243        }
244
245        // Use the buffer's actual size (from fstat at DmaTensor creation) to ensure
246        // mmap covers the full allocation, including any row padding from external
247        // allocators. as_slice() still uses the logical element count from shape.
248        let mmap_size = std::cmp::max(buf_size, logical_size);
249
250        #[cfg(target_os = "linux")]
251        {
252            trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
253            if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
254                warn!(
255                    "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
256                    fd.as_raw_fd()
257                );
258                return Err(Error::NixError(e));
259            }
260        }
261
262        let ptr = unsafe {
263            nix::sys::mman::mmap(
264                None,
265                NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
266                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
267                nix::sys::mman::MapFlags::MAP_SHARED,
268                &fd,
269                0,
270            )?
271        };
272
273        trace!("Mapping DMA memory: {ptr:?}");
274        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
275        Ok(DmaMap {
276            ptr: Arc::new(Mutex::new(dma_ptr)),
277            fd,
278            shape: shape.to_vec(),
279            mmap_size,
280            _marker: std::marker::PhantomData,
281        })
282    }
283}
284
// Allow a DmaMap to be used anywhere a `&[T]` is expected (indexing,
// iteration) by delegating to the logical-slice view.
impl<T> Deref for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}
295
// Mutable counterpart of Deref: exposes the mapped memory as `&mut [T]`.
impl<T> DerefMut for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}
304
/// Newtype around the mmap'd base address, used so the raw pointer can carry
/// a `Send` impl (see below) while living inside `Arc<Mutex<..>>`.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);
impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
314
// SAFETY: DmaPtr is just the base address of a MAP_SHARED mapping; moving it
// between threads is sound because the mapping is process-wide. All accesses
// in this file go through the enclosing Mutex, which provides the exclusion.
unsafe impl Send for DmaPtr {}
316
impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Unmaps the memory and, on Linux, issues `DMA_BUF_IOCTL_SYNC(END)` to
    /// finish the CPU access window begun in `DmaMap::new`.
    ///
    /// NOTE(review): calling this explicitly and then dropping the map runs
    /// munmap twice on the same address — the second call is at best an error
    /// (warned) and at worst unmaps an unrelated mapping reused at that
    /// address. Guarding would require a "already unmapped" state on the
    /// struct; confirm whether external callers invoke unmap() directly.
    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        // Unmap the full mmap'd region (mmap_size, which may exceed the
        // logical slice size for padded buffers).
        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    /// Views the mapped memory as a slice of `self.len()` elements
    /// (presumably the product of `shape` — provided by the trait; confirm).
    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        // SAFETY: ptr points at a live MAP_SHARED mapping of at least
        // mmap_size >= len * size_of::<T>() bytes (enforced in DmaMap::new).
        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
    }

    /// Mutable view over the same logical element range as `as_slice`.
    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        // SAFETY: as above; &mut self guarantees exclusive access to the map.
        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
    }
}
348
// Ensure the mapping is released and the DMA-BUF sync window is closed even
// if the caller never invokes unmap() explicitly.
impl<T> Drop for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn drop(&mut self) {
        trace!("DmaMap dropped, unmapping memory");
        self.unmap();
    }
}