1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
20#[derive(Debug)]
26pub struct DmaTensor<T>
27where
28 T: Num + Clone + fmt::Debug + Send + Sync,
29{
30 pub name: String,
31 pub fd: OwnedFd,
32 pub shape: Vec<usize>,
33 pub _marker: std::marker::PhantomData<T>,
34 #[cfg(target_os = "linux")]
35 _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
36 identity: crate::BufferIdentity,
37}
38
39unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
40unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
41
42impl<T> TensorTrait<T> for DmaTensor<T>
43where
44 T: Num + Clone + fmt::Debug + Send + Sync,
45{
46 #[cfg(target_os = "linux")]
47 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
48 use log::debug;
49 use nix::sys::stat::fstat;
50
51 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
52 let name = match name {
53 Some(name) => name.to_owned(),
54 None => {
55 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
56 format!("/{}", &uuid[..16])
57 }
58 };
59
60 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
61 Ok(heap) => heap,
62 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
63 };
64
65 let dma_fd = heap.allocate(size)?;
66 let stat = fstat(&dma_fd)?;
67 debug!("DMA memory stat: {stat:?}");
68
69 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
70
71 Ok(DmaTensor::<T> {
72 name: name.to_owned(),
73 fd: dma_fd,
74 shape: shape.to_vec(),
75 _marker: std::marker::PhantomData,
76 _drm_attachment: drm_attachment,
77 identity: crate::BufferIdentity::new(),
78 })
79 }
80
81 #[cfg(not(target_os = "linux"))]
82 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
83 Err(Error::NotImplemented(
84 "DMA tensors are not supported on this platform".to_owned(),
85 ))
86 }
87
88 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
89 if shape.is_empty() {
90 return Err(Error::InvalidSize(0));
91 }
92
93 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
94 if size == 0 {
95 return Err(Error::InvalidSize(0));
96 }
97
98 #[cfg(target_os = "linux")]
99 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
100
101 Ok(DmaTensor {
102 name: name.unwrap_or("").to_owned(),
103 fd,
104 shape: shape.to_vec(),
105 _marker: std::marker::PhantomData,
106 #[cfg(target_os = "linux")]
107 _drm_attachment: drm_attachment,
108 identity: crate::BufferIdentity::new(),
109 })
110 }
111
112 fn clone_fd(&self) -> Result<OwnedFd> {
113 Ok(self.fd.try_clone()?)
114 }
115
116 fn memory(&self) -> TensorMemory {
117 TensorMemory::Dma
118 }
119
120 fn name(&self) -> String {
121 self.name.clone()
122 }
123
124 fn shape(&self) -> &[usize] {
125 &self.shape
126 }
127
128 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
129 if shape.is_empty() {
130 return Err(Error::InvalidSize(0));
131 }
132
133 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
134 if new_size != self.size() {
135 return Err(Error::ShapeMismatch(format!(
136 "Cannot reshape incompatible shape: {:?} to {:?}",
137 self.shape, shape
138 )));
139 }
140
141 self.shape = shape.to_vec();
142 Ok(())
143 }
144
145 fn map(&self) -> Result<TensorMap<T>> {
146 Ok(TensorMap::Dma(DmaMap::new(
147 self.fd.try_clone()?,
148 &self.shape,
149 )?))
150 }
151
152 fn buffer_identity(&self) -> &crate::BufferIdentity {
153 &self.identity
154 }
155}
156
157impl<T> AsRawFd for DmaTensor<T>
158where
159 T: Num + Clone + fmt::Debug + Send + Sync,
160{
161 fn as_raw_fd(&self) -> std::os::fd::RawFd {
162 self.fd.as_raw_fd()
163 }
164}
165
166impl<T> DmaTensor<T>
167where
168 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
169{
170 pub fn try_clone(&self) -> Result<Self> {
171 let fd = self.clone_fd()?;
172 #[cfg(target_os = "linux")]
173 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
174 Ok(Self {
175 name: self.name.clone(),
176 fd,
177 shape: self.shape.clone(),
178 _marker: std::marker::PhantomData,
179 #[cfg(target_os = "linux")]
180 _drm_attachment: drm_attachment,
181 identity: self.identity.clone(),
182 })
183 }
184}
185
186#[derive(Debug)]
187pub struct DmaMap<T>
188where
189 T: Num + Clone + fmt::Debug,
190{
191 ptr: Arc<Mutex<DmaPtr>>,
192 fd: OwnedFd,
193 shape: Vec<usize>,
194 _marker: std::marker::PhantomData<T>,
195}
196
197impl<T> DmaMap<T>
198where
199 T: Num + Clone + fmt::Debug,
200{
201 pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
202 if shape.is_empty() {
203 return Err(Error::InvalidSize(0));
204 }
205
206 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
207 if size == 0 {
208 return Err(Error::InvalidSize(0));
209 }
210
211 #[cfg(target_os = "linux")]
212 {
213 trace!("DmaMap: sync start fd={} size={size}", fd.as_raw_fd());
214 if let Err(e) = crate::dmabuf::start_readwrite(&fd) {
215 warn!(
216 "DmaMap: DMA_BUF_IOCTL_SYNC(START) failed fd={}: {e}",
217 fd.as_raw_fd()
218 );
219 return Err(Error::NixError(e));
220 }
221 }
222
223 let ptr = unsafe {
224 nix::sys::mman::mmap(
225 None,
226 NonZero::new(size).ok_or(Error::InvalidSize(size))?,
227 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
228 nix::sys::mman::MapFlags::MAP_SHARED,
229 &fd,
230 0,
231 )?
232 };
233
234 trace!("Mapping DMA memory: {ptr:?}");
235 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
236 Ok(DmaMap {
237 ptr: Arc::new(Mutex::new(dma_ptr)),
238 fd,
239 shape: shape.to_vec(),
240 _marker: std::marker::PhantomData,
241 })
242 }
243}
244
245impl<T> Deref for DmaMap<T>
246where
247 T: Num + Clone + fmt::Debug,
248{
249 type Target = [T];
250
251 fn deref(&self) -> &[T] {
252 self.as_slice()
253 }
254}
255
256impl<T> DerefMut for DmaMap<T>
257where
258 T: Num + Clone + fmt::Debug,
259{
260 fn deref_mut(&mut self) -> &mut [T] {
261 self.as_mut_slice()
262 }
263}
264
265#[derive(Debug)]
266struct DmaPtr(NonNull<c_void>);
267impl Deref for DmaPtr {
268 type Target = NonNull<c_void>;
269
270 fn deref(&self) -> &Self::Target {
271 &self.0
272 }
273}
274
275unsafe impl Send for DmaPtr {}
276
277impl<T> TensorMapTrait<T> for DmaMap<T>
278where
279 T: Num + Clone + fmt::Debug,
280{
281 fn shape(&self) -> &[usize] {
282 &self.shape
283 }
284
285 fn unmap(&mut self) {
286 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
287
288 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
289 warn!("Failed to unmap DMA memory: {e}");
290 }
291
292 #[cfg(target_os = "linux")]
293 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
294 warn!("Failed to end read/write on DMA memory: {e}");
295 }
296 }
297
298 fn as_slice(&self) -> &[T] {
299 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
300 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
301 }
302
303 fn as_mut_slice(&mut self) -> &mut [T] {
304 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
305 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
306 }
307}
308
309impl<T> Drop for DmaMap<T>
310where
311 T: Num + Clone + fmt::Debug,
312{
313 fn drop(&mut self) {
314 trace!("DmaMap dropped, unmapping memory");
315 self.unmap();
316 }
317}