1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
20#[derive(Debug)]
26pub struct DmaTensor<T>
27where
28 T: Num + Clone + fmt::Debug + Send + Sync,
29{
30 pub name: String,
31 pub fd: OwnedFd,
32 pub shape: Vec<usize>,
33 pub _marker: std::marker::PhantomData<T>,
34 #[cfg(target_os = "linux")]
35 _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
36 identity: crate::BufferIdentity,
37 buf_size: usize,
41}
42
43unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
44unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
45
46impl<T> TensorTrait<T> for DmaTensor<T>
47where
48 T: Num + Clone + fmt::Debug + Send + Sync,
49{
50 #[cfg(target_os = "linux")]
51 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
52 use log::debug;
53 use nix::sys::stat::fstat;
54
55 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
56 let name = match name {
57 Some(name) => name.to_owned(),
58 None => {
59 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
60 format!("/{}", &uuid[..16])
61 }
62 };
63
64 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
65 Ok(heap) => heap,
66 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
67 };
68
69 let dma_fd = heap.allocate(logical_size)?;
70 let stat = fstat(&dma_fd)?;
71 debug!("DMA memory stat: {stat:?}");
72 let buf_size = if stat.st_size > 0 {
73 std::cmp::max(stat.st_size as usize, logical_size)
74 } else {
75 logical_size
76 };
77
78 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
79
80 Ok(DmaTensor::<T> {
81 name: name.to_owned(),
82 fd: dma_fd,
83 shape: shape.to_vec(),
84 _marker: std::marker::PhantomData,
85 _drm_attachment: drm_attachment,
86 identity: crate::BufferIdentity::new(),
87 buf_size,
88 })
89 }
90
91 #[cfg(not(target_os = "linux"))]
92 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
93 Err(Error::NotImplemented(
94 "DMA tensors are not supported on this platform".to_owned(),
95 ))
96 }
97
98 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
99 if shape.is_empty() {
100 return Err(Error::InvalidSize(0));
101 }
102
103 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
104 if logical_size == 0 {
105 return Err(Error::InvalidSize(0));
106 }
107
108 let buf_size = {
111 #[cfg(target_os = "linux")]
112 {
113 use nix::sys::stat::fstat;
114 match fstat(&fd) {
115 Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
116 stat.st_size as usize
117 }
118 _ => logical_size,
119 }
120 }
121 #[cfg(not(target_os = "linux"))]
122 {
123 logical_size
124 }
125 };
126
127 #[cfg(target_os = "linux")]
128 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
129
130 Ok(DmaTensor {
131 name: name.unwrap_or("").to_owned(),
132 fd,
133 shape: shape.to_vec(),
134 _marker: std::marker::PhantomData,
135 #[cfg(target_os = "linux")]
136 _drm_attachment: drm_attachment,
137 identity: crate::BufferIdentity::new(),
138 buf_size,
139 })
140 }
141
142 fn clone_fd(&self) -> Result<OwnedFd> {
143 Ok(self.fd.try_clone()?)
144 }
145
146 fn memory(&self) -> TensorMemory {
147 TensorMemory::Dma
148 }
149
150 fn name(&self) -> String {
151 self.name.clone()
152 }
153
154 fn shape(&self) -> &[usize] {
155 &self.shape
156 }
157
158 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
159 if shape.is_empty() {
160 return Err(Error::InvalidSize(0));
161 }
162
163 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
164 if new_size != self.size() {
165 return Err(Error::ShapeMismatch(format!(
166 "Cannot reshape incompatible shape: {:?} to {:?}",
167 self.shape, shape
168 )));
169 }
170
171 self.shape = shape.to_vec();
172 Ok(())
173 }
174
175 fn map(&self) -> Result<TensorMap<T>> {
176 Ok(TensorMap::Dma(DmaMap::new(
177 self.fd.try_clone()?,
178 &self.shape,
179 self.buf_size,
180 )?))
181 }
182
183 fn buffer_identity(&self) -> &crate::BufferIdentity {
184 &self.identity
185 }
186}
187
188impl<T> AsRawFd for DmaTensor<T>
189where
190 T: Num + Clone + fmt::Debug + Send + Sync,
191{
192 fn as_raw_fd(&self) -> std::os::fd::RawFd {
193 self.fd.as_raw_fd()
194 }
195}
196
197impl<T> DmaTensor<T>
198where
199 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
200{
201 pub fn try_clone(&self) -> Result<Self> {
202 let fd = self.clone_fd()?;
203 #[cfg(target_os = "linux")]
204 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
205 Ok(Self {
206 name: self.name.clone(),
207 fd,
208 shape: self.shape.clone(),
209 _marker: std::marker::PhantomData,
210 #[cfg(target_os = "linux")]
211 _drm_attachment: drm_attachment,
212 identity: self.identity.clone(),
213 buf_size: self.buf_size,
214 })
215 }
216}
217
218#[derive(Debug)]
219pub struct DmaMap<T>
220where
221 T: Num + Clone + fmt::Debug,
222{
223 ptr: Arc<Mutex<DmaPtr>>,
224 fd: OwnedFd,
225 shape: Vec<usize>,
226 mmap_size: usize,
228 _marker: std::marker::PhantomData<T>,
229}
230
231impl<T> DmaMap<T>
232where
233 T: Num + Clone + fmt::Debug,
234{
235 pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize) -> Result<Self> {
236 if shape.is_empty() {
237 return Err(Error::InvalidSize(0));
238 }
239
240 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
241 if logical_size == 0 {
242 return Err(Error::InvalidSize(0));
243 }
244
245 let mmap_size = std::cmp::max(buf_size, logical_size);
249
250 #[cfg(target_os = "linux")]
251 {
252 trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
253 if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
254 warn!(
255 "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
256 fd.as_raw_fd()
257 );
258 return Err(Error::NixError(e));
259 }
260 }
261
262 let ptr = unsafe {
263 nix::sys::mman::mmap(
264 None,
265 NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
266 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
267 nix::sys::mman::MapFlags::MAP_SHARED,
268 &fd,
269 0,
270 )?
271 };
272
273 trace!("Mapping DMA memory: {ptr:?}");
274 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
275 Ok(DmaMap {
276 ptr: Arc::new(Mutex::new(dma_ptr)),
277 fd,
278 shape: shape.to_vec(),
279 mmap_size,
280 _marker: std::marker::PhantomData,
281 })
282 }
283}
284
285impl<T> Deref for DmaMap<T>
286where
287 T: Num + Clone + fmt::Debug,
288{
289 type Target = [T];
290
291 fn deref(&self) -> &[T] {
292 self.as_slice()
293 }
294}
295
296impl<T> DerefMut for DmaMap<T>
297where
298 T: Num + Clone + fmt::Debug,
299{
300 fn deref_mut(&mut self) -> &mut [T] {
301 self.as_mut_slice()
302 }
303}
304
/// Newtype around the base address of a mapping, so that `Send` can be
/// implemented for it.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);

impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    /// Gives read access to the wrapped pointer.
    fn deref(&self) -> &NonNull<c_void> {
        let DmaPtr(inner) = self;
        inner
    }
}

// NOTE(review): asserts the wrapped address may be moved between threads;
// `NonNull` is not `Send` by itself — confirm the mapping has no thread
// affinity.
unsafe impl Send for DmaPtr {}
316
317impl<T> TensorMapTrait<T> for DmaMap<T>
318where
319 T: Num + Clone + fmt::Debug,
320{
321 fn shape(&self) -> &[usize] {
322 &self.shape
323 }
324
325 fn unmap(&mut self) {
326 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
327
328 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
329 warn!("Failed to unmap DMA memory: {e}");
330 }
331
332 #[cfg(target_os = "linux")]
333 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
334 warn!("Failed to end read/write on DMA memory: {e}");
335 }
336 }
337
338 fn as_slice(&self) -> &[T] {
339 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
340 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
341 }
342
343 fn as_mut_slice(&mut self) -> &mut [T] {
344 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
345 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
346 }
347}
348
349impl<T> Drop for DmaMap<T>
350where
351 T: Num + Clone + fmt::Debug,
352{
353 fn drop(&mut self) {
354 trace!("DmaMap dropped, unmapping memory");
355 self.unmap();
356 }
357}