Skip to main content

edgefirst_tensor/
dma.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11    ffi::c_void,
12    fmt,
13    num::NonZero,
14    ops::{Deref, DerefMut},
15    os::fd::{AsRawFd, OwnedFd},
16    ptr::NonNull,
17    sync::{Arc, Mutex},
18};
19
/// A tensor backed by DMA (Direct Memory Access) memory.
///
/// On Linux, a DRM PRIME attachment is created to enable proper CPU cache
/// coherency via `DMA_BUF_IOCTL_SYNC`. Without this attachment, sync ioctls
/// are no-ops on cached CMA heaps.
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    // Human-readable identifier; auto-generated from a UUID when the caller
    // does not supply one (see `TensorTrait::new`).
    pub name: String,
    // Owning file descriptor of the underlying dma-buf allocation.
    pub fd: OwnedFd,
    // Tensor dimensions; the element count is the product of all entries.
    pub shape: Vec<usize>,
    // Records the element type `T` without storing any data inline.
    pub _marker: std::marker::PhantomData<T>,
    // Held only to keep the DRM PRIME attachment alive for the tensor's
    // lifetime so sync ioctls perform real cache maintenance (see type docs).
    // `None` when the attachment could not be created.
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
    // Identity of the underlying buffer; propagated by `try_clone`.
    identity: crate::BufferIdentity,
}
38
// SAFETY: `name`, `fd`, `shape`, and `_marker` are Send + Sync under the
// `T: Send + Sync` bound. NOTE(review): these impls additionally assert that
// `DrmAttachment` and `BufferIdentity` are safe to move/share across threads —
// their definitions are not visible in this file; confirm there.
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
41
42impl<T> TensorTrait<T> for DmaTensor<T>
43where
44    T: Num + Clone + fmt::Debug + Send + Sync,
45{
46    #[cfg(target_os = "linux")]
47    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
48        use log::debug;
49        use nix::sys::stat::fstat;
50
51        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
52        let name = match name {
53            Some(name) => name.to_owned(),
54            None => {
55                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
56                format!("/{}", &uuid[..16])
57            }
58        };
59
60        let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
61            Ok(heap) => heap,
62            Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
63        };
64
65        let dma_fd = heap.allocate(size)?;
66        let stat = fstat(&dma_fd)?;
67        debug!("DMA memory stat: {stat:?}");
68
69        let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
70
71        Ok(DmaTensor::<T> {
72            name: name.to_owned(),
73            fd: dma_fd,
74            shape: shape.to_vec(),
75            _marker: std::marker::PhantomData,
76            _drm_attachment: drm_attachment,
77            identity: crate::BufferIdentity::new(),
78        })
79    }
80
81    #[cfg(not(target_os = "linux"))]
82    fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
83        Err(Error::NotImplemented(
84            "DMA tensors are not supported on this platform".to_owned(),
85        ))
86    }
87
88    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
89        if shape.is_empty() {
90            return Err(Error::InvalidSize(0));
91        }
92
93        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
94        if size == 0 {
95            return Err(Error::InvalidSize(0));
96        }
97
98        #[cfg(target_os = "linux")]
99        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
100
101        Ok(DmaTensor {
102            name: name.unwrap_or("").to_owned(),
103            fd,
104            shape: shape.to_vec(),
105            _marker: std::marker::PhantomData,
106            #[cfg(target_os = "linux")]
107            _drm_attachment: drm_attachment,
108            identity: crate::BufferIdentity::new(),
109        })
110    }
111
112    fn clone_fd(&self) -> Result<OwnedFd> {
113        Ok(self.fd.try_clone()?)
114    }
115
116    fn memory(&self) -> TensorMemory {
117        TensorMemory::Dma
118    }
119
120    fn name(&self) -> String {
121        self.name.clone()
122    }
123
124    fn shape(&self) -> &[usize] {
125        &self.shape
126    }
127
128    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
129        if shape.is_empty() {
130            return Err(Error::InvalidSize(0));
131        }
132
133        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
134        if new_size != self.size() {
135            return Err(Error::ShapeMismatch(format!(
136                "Cannot reshape incompatible shape: {:?} to {:?}",
137                self.shape, shape
138            )));
139        }
140
141        self.shape = shape.to_vec();
142        Ok(())
143    }
144
145    fn map(&self) -> Result<TensorMap<T>> {
146        Ok(TensorMap::Dma(DmaMap::new(
147            self.fd.try_clone()?,
148            &self.shape,
149        )?))
150    }
151
152    fn buffer_identity(&self) -> &crate::BufferIdentity {
153        &self.identity
154    }
155}
156
impl<T> AsRawFd for DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    // Expose the raw dma-buf descriptor without transferring ownership; the
    // fd remains owned (and closed) by this tensor.
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        self.fd.as_raw_fd()
    }
}
165
166impl<T> DmaTensor<T>
167where
168    T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
169{
170    pub fn try_clone(&self) -> Result<Self> {
171        let fd = self.clone_fd()?;
172        #[cfg(target_os = "linux")]
173        let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
174        Ok(Self {
175            name: self.name.clone(),
176            fd,
177            shape: self.shape.clone(),
178            _marker: std::marker::PhantomData,
179            #[cfg(target_os = "linux")]
180            _drm_attachment: drm_attachment,
181            identity: self.identity.clone(),
182        })
183    }
184}
185
/// A CPU-accessible mapping of a DMA buffer, created via [`DmaTensor`]'s
/// `map()`. The mapping is released (munmap + end-of-access sync) on drop.
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    // mmap'd base address of the buffer, guarded by a mutex so accessor
    // methods serialize their reads of the pointer.
    ptr: Arc<Mutex<DmaPtr>>,
    // Duplicated dma-buf fd kept open for the mapping's lifetime; used for
    // the end-of-access sync ioctl when unmapping.
    fd: OwnedFd,
    // Tensor dimensions; element count is the product of all entries.
    shape: Vec<usize>,
    // Records the element type `T` the mapped bytes are viewed as.
    _marker: std::marker::PhantomData<T>,
}
196
197impl<T> DmaMap<T>
198where
199    T: Num + Clone + fmt::Debug,
200{
201    pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
202        if shape.is_empty() {
203            return Err(Error::InvalidSize(0));
204        }
205
206        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
207        if size == 0 {
208            return Err(Error::InvalidSize(0));
209        }
210
211        #[cfg(target_os = "linux")]
212        {
213            trace!("DmaMap: sync start fd={} size={size}", fd.as_raw_fd());
214            if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
215                warn!(
216                    "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
217                    fd.as_raw_fd()
218                );
219                return Err(Error::NixError(e));
220            }
221        }
222
223        let ptr = unsafe {
224            nix::sys::mman::mmap(
225                None,
226                NonZero::new(size).ok_or(Error::InvalidSize(size))?,
227                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
228                nix::sys::mman::MapFlags::MAP_SHARED,
229                &fd,
230                0,
231            )?
232        };
233
234        trace!("Mapping DMA memory: {ptr:?}");
235        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
236        Ok(DmaMap {
237            ptr: Arc::new(Mutex::new(dma_ptr)),
238            fd,
239            shape: shape.to_vec(),
240            _marker: std::marker::PhantomData,
241        })
242    }
243}
244
// Allow a `DmaMap` to be used wherever a `&[T]` is expected, viewing the
// mapped bytes as a typed slice.
impl<T> Deref for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}

// Mutable counterpart: in-place element writes through the mapping.
impl<T> DerefMut for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}
264
// Newtype around the raw mmap base address so it can live inside
// `Arc<Mutex<…>>` and be moved across threads.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);
impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// SAFETY: the pointer refers to an mmap'd region owned by the enclosing
// `DmaMap`, and all access to it is serialized through the owning `Mutex`.
unsafe impl Send for DmaPtr {}
276
impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    // Unmaps the memory and, on Linux, ends the CPU access window opened by
    // `DmaMap::new` via the end-of-access sync ioctl. Failures are logged but
    // not propagated, since this also runs from `Drop`.
    //
    // NOTE(review): the stored pointer is not invalidated after munmap, so
    // calling `unmap` explicitly and then dropping the map would munmap a
    // stale address a second time — confirm callers never invoke it directly.
    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    // View the mapped bytes as a typed slice.
    // SAFETY relies on the mapping being at least `len() * size_of::<T>()`
    // bytes (guaranteed by `DmaMap::new`) and staying valid until unmapped.
    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
    }

    // Mutable typed view of the mapped bytes; `&mut self` ensures exclusivity
    // at the borrow-checker level for this `DmaMap` instance.
    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
    }
}
308
// Release the mapping automatically when the map goes out of scope; `unmap`
// performs the munmap and the end-of-access sync.
impl<T> Drop for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn drop(&mut self) {
        trace!("DmaMap dropped, unmapping memory");
        self.unmap();
    }
}