Skip to main content

edgefirst_tensor/
dma.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11    ffi::c_void,
12    fmt,
13    num::NonZero,
14    ops::{Deref, DerefMut},
15    os::fd::{AsRawFd, OwnedFd},
16    ptr::NonNull,
17    sync::{Arc, Mutex},
18};
19
20/// A tensor backed by DMA (Direct Memory Access) memory.
21///
22/// On Linux, a DRM PRIME attachment is created to enable proper CPU cache
23/// coherency via `DMA_BUF_IOCTL_SYNC`. Without this attachment, sync ioctls
24/// are no-ops on cached CMA heaps.
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Human-readable identifier; auto-generated from a UUID when not supplied.
    pub name: String,
    /// Owned file descriptor of the underlying dma-buf allocation.
    pub fd: OwnedFd,
    /// Logical dimensions; their product times `size_of::<T>()` is the byte size.
    pub shape: Vec<usize>,
    /// Marks the element type `T` without storing any `T` values.
    pub _marker: std::marker::PhantomData<T>,
    /// Held for its lifetime only: keeps the DRM PRIME attachment alive so
    /// `DMA_BUF_IOCTL_SYNC` performs real cache maintenance (see type docs).
    /// `None` when the attachment could not be created.
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
}
37
// SAFETY: `String`, `OwnedFd`, `Vec<usize>` and `PhantomData<T>` (with
// `T: Send + Sync`) are already `Send + Sync`, so these impls only assert the
// same for `_drm_attachment`. NOTE(review): presumably `DrmAttachment` is a
// plain fd/handle wrapper with no thread affinity — confirm in `crate::dmabuf`;
// if it is auto-`Send + Sync`, these unsafe impls are redundant.
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
40
41impl<T> TensorTrait<T> for DmaTensor<T>
42where
43    T: Num + Clone + fmt::Debug + Send + Sync,
44{
45    #[cfg(target_os = "linux")]
46    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
47        use log::debug;
48        use nix::sys::stat::fstat;
49
50        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
51        let name = match name {
52            Some(name) => name.to_owned(),
53            None => {
54                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
55                format!("/{}", &uuid[..16])
56            }
57        };
58
59        let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
60            Ok(heap) => heap,
61            Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
62        };
63
64        let dma_fd = heap.allocate(size)?;
65        let stat = fstat(&dma_fd)?;
66        debug!("DMA memory stat: {stat:?}");
67
68        let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
69
70        Ok(DmaTensor::<T> {
71            name: name.to_owned(),
72            fd: dma_fd,
73            shape: shape.to_vec(),
74            _marker: std::marker::PhantomData,
75            _drm_attachment: drm_attachment,
76        })
77    }
78
79    #[cfg(not(target_os = "linux"))]
80    fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
81        Err(Error::NotImplemented(
82            "DMA tensors are not supported on this platform".to_owned(),
83        ))
84    }
85
86    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
87        if shape.is_empty() {
88            return Err(Error::InvalidSize(0));
89        }
90
91        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
92        if size == 0 {
93            return Err(Error::InvalidSize(0));
94        }
95
96        #[cfg(target_os = "linux")]
97        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
98
99        Ok(DmaTensor {
100            name: name.unwrap_or("").to_owned(),
101            fd,
102            shape: shape.to_vec(),
103            _marker: std::marker::PhantomData,
104            #[cfg(target_os = "linux")]
105            _drm_attachment: drm_attachment,
106        })
107    }
108
109    fn clone_fd(&self) -> Result<OwnedFd> {
110        Ok(self.fd.try_clone()?)
111    }
112
113    fn memory(&self) -> TensorMemory {
114        TensorMemory::Dma
115    }
116
117    fn name(&self) -> String {
118        self.name.clone()
119    }
120
121    fn shape(&self) -> &[usize] {
122        &self.shape
123    }
124
125    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
126        if shape.is_empty() {
127            return Err(Error::InvalidSize(0));
128        }
129
130        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
131        if new_size != self.size() {
132            return Err(Error::ShapeMismatch(format!(
133                "Cannot reshape incompatible shape: {:?} to {:?}",
134                self.shape, shape
135            )));
136        }
137
138        self.shape = shape.to_vec();
139        Ok(())
140    }
141
142    fn map(&self) -> Result<TensorMap<T>> {
143        Ok(TensorMap::Dma(DmaMap::new(
144            self.fd.try_clone()?,
145            &self.shape,
146        )?))
147    }
148}
149
150impl<T> AsRawFd for DmaTensor<T>
151where
152    T: Num + Clone + fmt::Debug + Send + Sync,
153{
154    fn as_raw_fd(&self) -> std::os::fd::RawFd {
155        self.fd.as_raw_fd()
156    }
157}
158
159impl<T> DmaTensor<T>
160where
161    T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
162{
163    pub fn try_clone(&self) -> Result<Self> {
164        let fd = self.clone_fd()?;
165        #[cfg(target_os = "linux")]
166        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
167        Ok(Self {
168            name: self.name.clone(),
169            fd,
170            shape: self.shape.clone(),
171            _marker: std::marker::PhantomData,
172            #[cfg(target_os = "linux")]
173            _drm_attachment: drm_attachment,
174        })
175    }
176}
177
#[derive(Debug)]
pub struct DmaMap<T>
where
    // Bound kept on the struct: the `Drop` impl below requires it, and Rust
    // forbids a `Drop` impl with stricter bounds than its struct.
    T: Num + Clone + fmt::Debug,
{
    /// mmap'd base address, behind a mutex so the pointer itself is Send.
    ptr: Arc<Mutex<DmaPtr>>,
    /// Own clone of the dma-buf fd; keeps the buffer alive while mapped.
    fd: OwnedFd,
    /// Logical dimensions of the mapped tensor.
    shape: Vec<usize>,
    /// Marks the element type `T` without storing any `T` values.
    _marker: std::marker::PhantomData<T>,
}
188
189impl<T> DmaMap<T>
190where
191    T: Num + Clone + fmt::Debug,
192{
193    pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
194        if shape.is_empty() {
195            return Err(Error::InvalidSize(0));
196        }
197
198        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
199        if size == 0 {
200            return Err(Error::InvalidSize(0));
201        }
202
203        #[cfg(target_os = "linux")]
204        crate::dmabuf::start_readwrite(&fd)?;
205
206        let ptr = unsafe {
207            nix::sys::mman::mmap(
208                None,
209                NonZero::new(size).ok_or(Error::InvalidSize(size))?,
210                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
211                nix::sys::mman::MapFlags::MAP_SHARED,
212                &fd,
213                0,
214            )?
215        };
216
217        trace!("Mapping DMA memory: {ptr:?}");
218        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
219        Ok(DmaMap {
220            ptr: Arc::new(Mutex::new(dma_ptr)),
221            fd,
222            shape: shape.to_vec(),
223            _marker: std::marker::PhantomData,
224        })
225    }
226}
227
228impl<T> Deref for DmaMap<T>
229where
230    T: Num + Clone + fmt::Debug,
231{
232    type Target = [T];
233
234    fn deref(&self) -> &[T] {
235        self.as_slice()
236    }
237}
238
239impl<T> DerefMut for DmaMap<T>
240where
241    T: Num + Clone + fmt::Debug,
242{
243    fn deref_mut(&mut self) -> &mut [T] {
244        self.as_mut_slice()
245    }
246}
247
/// Newtype around the raw mmap base address so it can be stored behind a
/// mutex and explicitly marked `Send` (raw pointers are not `Send` by default).
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);
250impl Deref for DmaPtr {
251    type Target = NonNull<c_void>;
252
253    fn deref(&self) -> &Self::Target {
254        &self.0
255    }
256}
257
258unsafe impl Send for DmaPtr {}
259
impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Logical dimensions of the mapped tensor.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Unmaps the memory and closes the dma-buf CPU access window.
    ///
    /// NOTE(review): calling this explicitly and then dropping the map runs
    /// `munmap`/`end_readwrite` a second time on an already-released mapping
    /// (the pointer is not invalidated here); the second failure is only
    /// logged. Confirm callers never touch the map after `unmap`.
    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        // SAFETY: `ptr` and `self.size()` describe the region mapped in `new`.
        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        // Balance the `start_readwrite` issued in `new`.
        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    /// Views the mapping as `&[T]` with `self.len()` elements.
    ///
    /// SAFETY rationale: `new` guarantees the mapping covers the full tensor
    /// byte size. The mutex guard is released when this returns, so the lock
    /// does not protect the returned slice; aliasing is governed by the usual
    /// `&self`/`&mut self` borrow rules on `DmaMap`.
    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
    }

    /// Views the mapping as `&mut [T]`; exclusivity follows from `&mut self`.
    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
    }
}
291
292impl<T> Drop for DmaMap<T>
293where
294    T: Num + Clone + fmt::Debug,
295{
296    fn drop(&mut self) {
297        trace!("DmaMap dropped, unmapping memory");
298        self.unmap();
299    }
300}