1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
/// Tensor whose storage is a DMA buffer referenced by a file descriptor.
///
/// Instances are created by `TensorTrait::new` (fresh heap allocation) or
/// `TensorTrait::from_fd` (wrapping an existing descriptor); the element type
/// `T` is carried only at the type level.
#[derive(Debug)]
pub struct DmaTensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Name of the tensor: caller supplied, generated in `new`, or empty when
    /// `from_fd` is called without a name.
    pub name: String,
    /// Descriptor of the underlying buffer; duplicated for every mapping and clone.
    pub fd: OwnedFd,
    /// Dimension sizes of the tensor, in elements of `T`.
    pub shape: Vec<usize>,
    /// Ties the element type `T` to the tensor without storing a value of it.
    pub _marker: std::marker::PhantomData<T>,
    /// Result of `DrmAttachment::new` for `fd`, kept for the lifetime of the tensor.
    // NOTE(review): the attachment's purpose is defined in `crate::dmabuf` — confirm there.
    #[cfg(target_os = "linux")]
    _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
    /// Identity handed out by `buffer_identity`; fresh in `new`/`from_fd`, cloned by `try_clone`.
    identity: crate::BufferIdentity,
    /// Size in bytes passed to `DmaMap::new` as the length of the mapping; at least the
    /// logical size of `shape`, larger when `fstat` reports a bigger buffer.
    buf_size: usize,
    /// Byte offset passed to `DmaMap::new`; the constructors in this file initialise it to 0.
    pub(crate) mmap_offset: usize,
    /// `true` when built by `from_fd`, `false` when allocated by `new`; forwarded to
    /// `DrmAttachment::new` when the tensor is cloned.
    #[cfg(target_os = "linux")]
    is_imported: bool,
}
50
// SAFETY / NOTE(review): these impls assert that every field of `DmaTensor` — including the
// Linux-only `DrmAttachment`, which is defined outside this file — may be moved to and shared
// between threads. Confirm that holds for `DrmAttachment` before relying on it.
unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
53
54impl<T> TensorTrait<T> for DmaTensor<T>
55where
56 T: Num + Clone + fmt::Debug + Send + Sync,
57{
58 #[cfg(target_os = "linux")]
59 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
60 use log::debug;
61 use nix::sys::stat::fstat;
62
63 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
64 let name = match name {
65 Some(name) => name.to_owned(),
66 None => {
67 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
68 format!("/{}", &uuid[..16])
69 }
70 };
71
72 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
73 Ok(heap) => heap,
74 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
75 };
76
77 let dma_fd = heap.allocate(logical_size)?;
78 let stat = fstat(&dma_fd)?;
79 debug!("DMA memory stat: {stat:?}");
80 let buf_size = if stat.st_size > 0 {
81 std::cmp::max(stat.st_size as usize, logical_size)
82 } else {
83 logical_size
84 };
85
86 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd, false);
87
88 Ok(DmaTensor::<T> {
89 name: name.to_owned(),
90 fd: dma_fd,
91 shape: shape.to_vec(),
92 _marker: std::marker::PhantomData,
93 _drm_attachment: drm_attachment,
94 identity: crate::BufferIdentity::new(),
95 buf_size,
96 mmap_offset: 0,
97 is_imported: false,
98 })
99 }
100
101 #[cfg(not(target_os = "linux"))]
102 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
103 Err(Error::NotImplemented(
104 "DMA tensors are not supported on this platform".to_owned(),
105 ))
106 }
107
108 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
109 if shape.is_empty() {
110 return Err(Error::InvalidSize(0));
111 }
112
113 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
114 if logical_size == 0 {
115 return Err(Error::InvalidSize(0));
116 }
117
118 let buf_size = {
121 #[cfg(target_os = "linux")]
122 {
123 use nix::sys::stat::fstat;
124 match fstat(&fd) {
125 Ok(stat) if stat.st_size > 0 && stat.st_size as usize >= logical_size => {
126 stat.st_size as usize
127 }
128 _ => logical_size,
129 }
130 }
131 #[cfg(not(target_os = "linux"))]
132 {
133 logical_size
134 }
135 };
136
137 #[cfg(target_os = "linux")]
138 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd, true);
139
140 Ok(DmaTensor {
141 name: name.unwrap_or("").to_owned(),
142 fd,
143 shape: shape.to_vec(),
144 _marker: std::marker::PhantomData,
145 #[cfg(target_os = "linux")]
146 _drm_attachment: drm_attachment,
147 identity: crate::BufferIdentity::new(),
148 buf_size,
149 mmap_offset: 0,
150 #[cfg(target_os = "linux")]
151 is_imported: true,
152 })
153 }
154
155 fn clone_fd(&self) -> Result<OwnedFd> {
156 Ok(self.fd.try_clone()?)
157 }
158
159 fn memory(&self) -> TensorMemory {
160 TensorMemory::Dma
161 }
162
163 fn name(&self) -> String {
164 self.name.clone()
165 }
166
167 fn shape(&self) -> &[usize] {
168 &self.shape
169 }
170
171 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
172 if shape.is_empty() {
173 return Err(Error::InvalidSize(0));
174 }
175
176 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
177 if new_size != self.size() {
178 return Err(Error::ShapeMismatch(format!(
179 "Cannot reshape incompatible shape: {:?} to {:?}",
180 self.shape, shape
181 )));
182 }
183
184 self.shape = shape.to_vec();
185 Ok(())
186 }
187
188 fn map(&self) -> Result<TensorMap<T>> {
189 Ok(TensorMap::Dma(DmaMap::new(
190 self.fd.try_clone()?,
191 &self.shape,
192 self.buf_size,
193 self.mmap_offset,
194 )?))
195 }
196
197 fn buffer_identity(&self) -> &crate::BufferIdentity {
198 &self.identity
199 }
200}
201
202impl<T> AsRawFd for DmaTensor<T>
203where
204 T: Num + Clone + fmt::Debug + Send + Sync,
205{
206 fn as_raw_fd(&self) -> std::os::fd::RawFd {
207 self.fd.as_raw_fd()
208 }
209}
210
211impl<T> DmaTensor<T>
212where
213 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
214{
215 pub fn try_clone(&self) -> Result<Self> {
216 let fd = self.clone_fd()?;
217 #[cfg(target_os = "linux")]
218 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd, self.is_imported);
219 Ok(Self {
220 name: self.name.clone(),
221 fd,
222 shape: self.shape.clone(),
223 _marker: std::marker::PhantomData,
224 #[cfg(target_os = "linux")]
225 _drm_attachment: drm_attachment,
226 identity: self.identity.clone(),
227 buf_size: self.buf_size,
228 mmap_offset: self.mmap_offset,
229 #[cfg(target_os = "linux")]
230 is_imported: self.is_imported,
231 })
232 }
233}
234
/// Memory mapping of a DMA buffer, viewed as a slice of `T`.
///
/// Created by `DmaMap::new`; the mapping is released by `unmap`, which the `Drop`
/// implementation calls.
#[derive(Debug)]
pub struct DmaMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Base address returned by `mmap`, guarded by a mutex.
    ptr: Arc<Mutex<DmaPtr>>,
    /// Descriptor that was mapped; also used to end the read/write sync on unmap.
    fd: OwnedFd,
    /// Dimension sizes of the mapped tensor, in elements of `T`.
    shape: Vec<usize>,
    /// Length in bytes passed to `mmap` and later to `munmap`.
    mmap_size: usize,
    /// Byte offset from the base of the mapping to the first element.
    offset: usize,
    /// Ties the element type `T` to the mapping without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
249
250impl<T> DmaMap<T>
251where
252 T: Num + Clone + fmt::Debug,
253{
254 pub fn new(fd: OwnedFd, shape: &[usize], buf_size: usize, offset: usize) -> Result<Self> {
255 if shape.is_empty() {
256 return Err(Error::InvalidSize(0));
257 }
258
259 let logical_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
260 if logical_size == 0 {
261 return Err(Error::InvalidSize(0));
262 }
263
264 let total_needed = offset
270 .checked_add(logical_size)
271 .ok_or(Error::InvalidSize(0))?;
272 if total_needed > buf_size {
273 warn!(
274 "DmaMap: offset={} + logical_size={} = {} exceeds buf_size={} (fd={})",
275 offset,
276 logical_size,
277 total_needed,
278 buf_size,
279 fd.as_raw_fd()
280 );
281 return Err(Error::InvalidSize(total_needed));
282 }
283 if std::mem::size_of::<T>() > 1 && !offset.is_multiple_of(std::mem::align_of::<T>()) {
284 return Err(Error::InvalidOperation(format!(
285 "DmaMap: offset {} is not aligned to align_of::<T>()={}",
286 offset,
287 std::mem::align_of::<T>()
288 )));
289 }
290 let mmap_size = buf_size;
291
292 #[cfg(target_os = "linux")]
293 {
294 trace!("DmaMap: sync start fd={} size={mmap_size}", fd.as_raw_fd());
295 if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
296 warn!(
297 "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
298 fd.as_raw_fd()
299 );
300 return Err(Error::NixError(e));
301 }
302 }
303
304 let ptr = unsafe {
305 nix::sys::mman::mmap(
306 None,
307 NonZero::new(mmap_size).ok_or(Error::InvalidSize(mmap_size))?,
308 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
309 nix::sys::mman::MapFlags::MAP_SHARED,
310 &fd,
311 0,
312 )?
313 };
314
315 trace!("Mapping DMA memory: {ptr:?}");
316 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(mmap_size))?);
317 Ok(DmaMap {
318 ptr: Arc::new(Mutex::new(dma_ptr)),
319 fd,
320 shape: shape.to_vec(),
321 mmap_size,
322 offset,
323 _marker: std::marker::PhantomData,
324 })
325 }
326}
327
328impl<T> Deref for DmaMap<T>
329where
330 T: Num + Clone + fmt::Debug,
331{
332 type Target = [T];
333
334 fn deref(&self) -> &[T] {
335 self.as_slice()
336 }
337}
338
339impl<T> DerefMut for DmaMap<T>
340where
341 T: Num + Clone + fmt::Debug,
342{
343 fn deref_mut(&mut self) -> &mut [T] {
344 self.as_mut_slice()
345 }
346}
347
/// Newtype around the base address of a mapping so that it can be given a `Send`
/// implementation.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);

impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    /// Gives access to the wrapped pointer.
    fn deref(&self) -> &NonNull<c_void> {
        let Self(inner) = self;
        inner
    }
}

// SAFETY / NOTE(review): asserts that the wrapped address may be moved to another
// thread; in this file it is only stored behind a `Mutex` inside `DmaMap`.
unsafe impl Send for DmaPtr {}
359
360impl<T> TensorMapTrait<T> for DmaMap<T>
361where
362 T: Num + Clone + fmt::Debug,
363{
364 fn shape(&self) -> &[usize] {
365 &self.shape
366 }
367
368 fn unmap(&mut self) {
369 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
370
371 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.mmap_size) } {
372 warn!("Failed to unmap DMA memory: {e}");
373 }
374
375 #[cfg(target_os = "linux")]
376 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
377 warn!("Failed to end read/write on DMA memory: {e}");
378 }
379 }
380
381 fn as_slice(&self) -> &[T] {
382 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
383 let base = unsafe { (ptr.as_ptr() as *const u8).add(self.offset) as *const T };
384 unsafe { std::slice::from_raw_parts(base, self.len()) }
385 }
386
387 fn as_mut_slice(&mut self) -> &mut [T] {
388 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
389 let base = unsafe { (ptr.as_ptr() as *mut u8).add(self.offset) as *mut T };
390 unsafe { std::slice::from_raw_parts_mut(base, self.len()) }
391 }
392}
393
394impl<T> Drop for DmaMap<T>
395where
396 T: Num + Clone + fmt::Debug,
397{
398 fn drop(&mut self) {
399 trace!("DmaMap dropped, unmapping memory");
400 self.unmap();
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens `/dev/null` to obtain a valid descriptor for tests in which
    /// `DmaMap::new` must fail during argument validation.
    #[cfg(target_os = "linux")]
    fn dummy_fd() -> std::os::fd::OwnedFd {
        std::fs::File::open("/dev/null")
            .expect("open /dev/null")
            .into()
    }

    /// `offset + logical_size` larger than `buf_size` is rejected with the required size.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_exceeds_buf_size() {
        let outcome = DmaMap::<u8>::new(dummy_fd(), &[4096], 4096, 4096);
        match &outcome {
            Err(Error::InvalidSize(n)) => assert_eq!(*n, 8192),
            other => panic!("expected InvalidSize(8192), got {:?}", other),
        }
    }

    /// An offset that is not a multiple of the element alignment is rejected.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_misaligned_offset() {
        let outcome = DmaMap::<u32>::new(dummy_fd(), &[1024], 8192, 3);
        let rejected = matches!(outcome, Err(Error::InvalidOperation(_)));
        assert!(
            rejected,
            "expected InvalidOperation for misaligned offset, got {:?}",
            outcome
        );
    }

    /// `offset + logical_size` overflowing `usize` is reported as `InvalidSize(0)`.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_offset_overflow() {
        let outcome = DmaMap::<u8>::new(dummy_fd(), &[1], usize::MAX, usize::MAX);
        let rejected = matches!(outcome, Err(Error::InvalidSize(0)));
        assert!(
            rejected,
            "expected InvalidSize(0) on overflow, got {:?}",
            outcome
        );
    }

    /// A tensor imported at a plane offset sees, and writes to, exactly its window of
    /// the larger buffer. Skipped when no DMA heap is available.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_map_with_offset() {
        use crate::{Tensor, TensorMapTrait, TensorMemory, TensorTrait};

        let total_size: usize = 4096 * 4;
        let offset: usize = 4096;
        let data_size: usize = 4096;

        let Ok(large_buf) = Tensor::<u8>::new(&[total_size], Some(TensorMemory::Dma), None) else {
            eprintln!("SKIPPED: DMA not available");
            return;
        };

        // Fill the whole buffer with a sentinel; the mapping is released at the end
        // of the scope.
        {
            let mut whole = large_buf.map().unwrap();
            whole.as_mut_slice().fill(0xAA);
        }

        // Import the same buffer as a smaller tensor starting at `offset`.
        let shared_fd = large_buf.clone_fd().unwrap();
        let mut offset_tensor = Tensor::<u8>::from_fd(shared_fd, &[data_size], None).unwrap();
        offset_tensor.set_plane_offset(offset);

        let mut map = offset_tensor.map().unwrap();
        let window = map.as_mut_slice();

        assert_eq!(window.len(), data_size);
        assert!(
            window.iter().all(|&b| b == 0xAA),
            "Offset tensor map should see sentinel data at offset"
        );

        window.fill(0xBB);
        drop(map);

        // Check through the original tensor that only the window was modified.
        {
            let whole = large_buf.map().unwrap();
            let bytes = whole.as_slice();
            let (before, rest) = (&bytes[..offset], &bytes[offset..offset + data_size]);
            assert!(
                before.iter().all(|&b| b == 0xAA),
                "Data before offset should be unchanged"
            );
            assert!(
                rest.iter().all(|&b| b == 0xBB),
                "Data at offset should be 0xBB"
            );
        }
    }
}