edgefirst_tensor/dma.rs

// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

use crate::{
    error::{Error, Result},
    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
};
use log::{trace, warn};
use num_traits::Num;
use std::{
    ffi::c_void,
    fmt,
    num::NonZero,
    ops::{Deref, DerefMut},
    os::fd::{AsRawFd, OwnedFd},
    ptr::NonNull,
    sync::{Arc, Mutex},
};

/// A tensor backed by DMA (Direct Memory Access) memory.
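///
/// The backing buffer is allocated from a Linux DMA-BUF heap (CMA when
/// available, otherwise the system heap) and is held as an [`OwnedFd`], so it
/// can be shared with other processes or devices without copying.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore` so it is not compiled as a doc
/// test; it assumes a Linux host with an accessible DMA heap and that
/// `DmaTensor` and `TensorTrait` are re-exported at the crate root):
///
/// ```ignore
/// use edgefirst_tensor::{DmaTensor, TensorTrait};
///
/// // Allocate a float32 tensor of shape [1, 3, 224, 224] in DMA memory.
/// let mut tensor = DmaTensor::<f32>::new(&[1, 3, 224, 224], None)
///     .expect("failed to allocate DMA tensor");
///
/// // Reshaping is allowed as long as the total byte size is unchanged.
/// tensor.reshape(&[3, 224 * 224]).expect("incompatible shape");
///
/// // Duplicate the file descriptor to share the buffer without copying.
/// let fd = tensor.clone_fd().expect("failed to clone fd");
/// ```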
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub name: String,
    pub fd: OwnedFd,
    pub shape: Vec<usize>,
    pub _marker: std::marker::PhantomData<T>,
}

// SAFETY: every field (String, OwnedFd, Vec<usize>, PhantomData<T>) is
// Send + Sync when T is, so sharing the tensor across threads is sound.
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}

impl<T> TensorTrait<T> for DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        use log::debug;
        use nix::sys::stat::fstat;

        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
        let name = match name {
            Some(name) => name.to_owned(),
            None => {
                let uuid = uuid::Uuid::new_v4().as_simple().to_string();
                format!("/{}", &uuid[..16])
            }
        };

        let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
            Ok(heap) => heap,
            Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
        };

        let dma_fd = heap.allocate(size)?;
        let stat = fstat(&dma_fd)?;
        debug!("DMA memory stat: {stat:?}");

        Ok(DmaTensor::<T> {
            name,
            fd: dma_fd,
            shape: shape.to_vec(),
            _marker: std::marker::PhantomData,
        })
    }

    #[cfg(not(target_os = "linux"))]
    fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
        Err(Error::NotImplemented(
            "DMA tensors are not supported on this platform".to_owned(),
        ))
    }

    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        if shape.is_empty() {
            return Err(Error::InvalidSize(0));
        }

        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
        if size == 0 {
            return Err(Error::InvalidSize(0));
        }

        Ok(DmaTensor {
            name: name.unwrap_or("").to_owned(),
            fd,
            shape: shape.to_vec(),
            _marker: std::marker::PhantomData,
        })
    }

    fn clone_fd(&self) -> Result<OwnedFd> {
        Ok(self.fd.try_clone()?)
    }

    fn memory(&self) -> TensorMemory {
        TensorMemory::Dma
    }

    fn name(&self) -> String {
        self.name.clone()
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if shape.is_empty() {
            return Err(Error::InvalidSize(0));
        }

        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
        if new_size != self.size() {
            return Err(Error::ShapeMismatch(format!(
                "Cannot reshape incompatible shape: {:?} to {:?}",
                self.shape, shape
            )));
        }

        self.shape = shape.to_vec();
        Ok(())
    }

    fn map(&self) -> Result<TensorMap<T>> {
        Ok(TensorMap::Dma(DmaMap::new(
            self.fd.try_clone()?,
            &self.shape,
        )?))
    }
}

impl<T> AsRawFd for DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        self.fd.as_raw_fd()
    }
}

impl<T> DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Duplicates this tensor by cloning the underlying file descriptor; the
    /// clone refers to the same DMA buffer.
    pub fn try_clone(&self) -> Result<Self> {
        let fd = self.clone_fd()?;
        Ok(Self {
            name: self.name.clone(),
            fd,
            shape: self.shape.clone(),
            _marker: std::marker::PhantomData,
        })
    }
}

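/// A mapped (mmap'd) view of a DMA-BUF backed tensor's memory.
///
/// The buffer is mapped read/write and exposed as a slice of `T` through
/// `Deref`/`DerefMut`; the mapping is released when the map is dropped.
///
/// # Example
///
/// A hedged sketch of writing through a map (marked `ignore` so it is not
/// compiled as a doc test; it assumes the crate-root re-exports used in the
/// [`DmaTensor`] example plus `TensorMap`/`TensorMapTrait`, and a Linux host
/// with an accessible DMA heap):
///
/// ```ignore
/// use edgefirst_tensor::{DmaTensor, TensorMap, TensorMapTrait, TensorTrait};
///
/// let tensor = DmaTensor::<u8>::new(&[16, 16], None).expect("allocation failed");
/// if let TensorMap::Dma(mut map) = tensor.map().expect("mmap failed") {
///     // Zero the mapped DMA buffer through the mutable slice view.
///     map.as_mut_slice().fill(0);
///     assert_eq!(map.as_slice().len(), 16 * 16);
/// }
/// ```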
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    ptr: Arc<Mutex<DmaPtr>>,
    fd: OwnedFd,
    shape: Vec<usize>,
    _marker: std::marker::PhantomData<T>,
}

impl<T> DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Maps the DMA-BUF behind `fd` read/write for a tensor with the given
    /// `shape` of elements of type `T`.
    pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
        if shape.is_empty() {
            return Err(Error::InvalidSize(0));
        }

        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
        if size == 0 {
            return Err(Error::InvalidSize(0));
        }

        #[cfg(target_os = "linux")]
        crate::dmabuf::start_readwrite(&fd)?;

        let ptr = unsafe {
            nix::sys::mman::mmap(
                None,
                NonZero::new(size).ok_or(Error::InvalidSize(size))?,
                nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
                nix::sys::mman::MapFlags::MAP_SHARED,
                &fd,
                0,
            )?
        };

        trace!("Mapping DMA memory: {ptr:?}");
        let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
        Ok(DmaMap {
            ptr: Arc::new(Mutex::new(dma_ptr)),
            fd,
            shape: shape.to_vec(),
            _marker: std::marker::PhantomData,
        })
    }
}

impl<T> Deref for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}

impl<T> DerefMut for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}

/// Wrapper around the base address of the mmap'd DMA region.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);

impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// SAFETY: DmaPtr only carries the base address of a MAP_SHARED mapping and is
// always accessed through the Mutex held by DmaMap, so it may be moved across
// threads.
unsafe impl Send for DmaPtr {}

impl<T> TensorMapTrait<T> for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn unmap(&mut self) {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");

        if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
            warn!("Failed to unmap DMA memory: {e}");
        }

        #[cfg(target_os = "linux")]
        if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
            warn!("Failed to end read/write on DMA memory: {e}");
        }
    }

    fn as_slice(&self) -> &[T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
    }
}

impl<T> Drop for DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn drop(&mut self) {
        trace!("DmaMap dropped, unmapping memory");
        self.unmap();
    }
}