1use crate::{
5 error::{Error, Result},
6 TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::{trace, warn};
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 num::NonZero,
14 ops::{Deref, DerefMut},
15 os::fd::{AsRawFd, OwnedFd},
16 ptr::NonNull,
17 sync::{Arc, Mutex},
18};
19
20#[derive(Debug)]
26pub struct DmaTensor<T>
27where
28 T: Num + Clone + fmt::Debug + Send + Sync,
29{
30 pub name: String,
31 pub fd: OwnedFd,
32 pub shape: Vec<usize>,
33 pub _marker: std::marker::PhantomData<T>,
34 #[cfg(target_os = "linux")]
35 _drm_attachment: Option<crate::dmabuf::DrmAttachment>,
36 identity: crate::BufferIdentity,
37}
38
39unsafe impl<T> Send for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
40unsafe impl<T> Sync for DmaTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
41
42impl<T> TensorTrait<T> for DmaTensor<T>
43where
44 T: Num + Clone + fmt::Debug + Send + Sync,
45{
46 #[cfg(target_os = "linux")]
47 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
48 use log::debug;
49 use nix::sys::stat::fstat;
50
51 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
52 let name = match name {
53 Some(name) => name.to_owned(),
54 None => {
55 let uuid = uuid::Uuid::new_v4().as_simple().to_string();
56 format!("/{}", &uuid[..16])
57 }
58 };
59
60 let heap = match dma_heap::Heap::new(dma_heap::HeapKind::Cma) {
61 Ok(heap) => heap,
62 Err(_) => dma_heap::Heap::new(dma_heap::HeapKind::System)?,
63 };
64
65 let dma_fd = heap.allocate(size)?;
66 let stat = fstat(&dma_fd)?;
67 debug!("DMA memory stat: {stat:?}");
68
69 let drm_attachment = crate::dmabuf::DrmAttachment::new(&dma_fd);
70
71 Ok(DmaTensor::<T> {
72 name: name.to_owned(),
73 fd: dma_fd,
74 shape: shape.to_vec(),
75 _marker: std::marker::PhantomData,
76 _drm_attachment: drm_attachment,
77 identity: crate::BufferIdentity::new(),
78 })
79 }
80
81 #[cfg(not(target_os = "linux"))]
82 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
83 Err(Error::NotImplemented(
84 "DMA tensors are not supported on this platform".to_owned(),
85 ))
86 }
87
88 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
89 if shape.is_empty() {
90 return Err(Error::InvalidSize(0));
91 }
92
93 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
94 if size == 0 {
95 return Err(Error::InvalidSize(0));
96 }
97
98 #[cfg(target_os = "linux")]
99 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
100
101 Ok(DmaTensor {
102 name: name.unwrap_or("").to_owned(),
103 fd,
104 shape: shape.to_vec(),
105 _marker: std::marker::PhantomData,
106 #[cfg(target_os = "linux")]
107 _drm_attachment: drm_attachment,
108 identity: crate::BufferIdentity::new(),
109 })
110 }
111
112 fn clone_fd(&self) -> Result<OwnedFd> {
113 Ok(self.fd.try_clone()?)
114 }
115
116 fn memory(&self) -> TensorMemory {
117 TensorMemory::Dma
118 }
119
120 fn name(&self) -> String {
121 self.name.clone()
122 }
123
124 fn shape(&self) -> &[usize] {
125 &self.shape
126 }
127
128 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
129 if shape.is_empty() {
130 return Err(Error::InvalidSize(0));
131 }
132
133 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
134 if new_size != self.size() {
135 return Err(Error::ShapeMismatch(format!(
136 "Cannot reshape incompatible shape: {:?} to {:?}",
137 self.shape, shape
138 )));
139 }
140
141 self.shape = shape.to_vec();
142 Ok(())
143 }
144
145 fn map(&self) -> Result<TensorMap<T>> {
146 Ok(TensorMap::Dma(DmaMap::new(
147 self.fd.try_clone()?,
148 &self.shape,
149 )?))
150 }
151
152 fn buffer_identity(&self) -> &crate::BufferIdentity {
153 &self.identity
154 }
155}
156
157impl<T> AsRawFd for DmaTensor<T>
158where
159 T: Num + Clone + fmt::Debug + Send + Sync,
160{
161 fn as_raw_fd(&self) -> std::os::fd::RawFd {
162 self.fd.as_raw_fd()
163 }
164}
165
166impl<T> DmaTensor<T>
167where
168 T: Num + Clone + Send + Sync + std::fmt::Debug + Send + Sync,
169{
170 pub fn try_clone(&self) -> Result<Self> {
171 let fd = self.clone_fd()?;
172 #[cfg(target_os = "linux")]
173 let drm_attachment = crate::dmabuf::DrmAttachment::new(&fd);
174 Ok(Self {
175 name: self.name.clone(),
176 fd,
177 shape: self.shape.clone(),
178 _marker: std::marker::PhantomData,
179 #[cfg(target_os = "linux")]
180 _drm_attachment: drm_attachment,
181 identity: self.identity.clone(),
182 })
183 }
184}
185
186#[derive(Debug)]
187pub struct DmaMap<T>
188where
189 T: Num + Clone + fmt::Debug,
190{
191 ptr: Arc<Mutex<DmaPtr>>,
192 fd: OwnedFd,
193 shape: Vec<usize>,
194 _marker: std::marker::PhantomData<T>,
195}
196
197impl<T> DmaMap<T>
198where
199 T: Num + Clone + fmt::Debug,
200{
201 pub fn new(fd: OwnedFd, shape: &[usize]) -> Result<Self> {
202 if shape.is_empty() {
203 return Err(Error::InvalidSize(0));
204 }
205
206 let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
207 if size == 0 {
208 return Err(Error::InvalidSize(0));
209 }
210
211 #[cfg(target_os = "linux")]
212 crate::dmabuf::start_readwrite(&fd)?;
213
214 let ptr = unsafe {
215 nix::sys::mman::mmap(
216 None,
217 NonZero::new(size).ok_or(Error::InvalidSize(size))?,
218 nix::sys::mman::ProtFlags::PROT_READ | nix::sys::mman::ProtFlags::PROT_WRITE,
219 nix::sys::mman::MapFlags::MAP_SHARED,
220 &fd,
221 0,
222 )?
223 };
224
225 trace!("Mapping DMA memory: {ptr:?}");
226 let dma_ptr = DmaPtr(NonNull::new(ptr.as_ptr()).ok_or(Error::InvalidSize(size))?);
227 Ok(DmaMap {
228 ptr: Arc::new(Mutex::new(dma_ptr)),
229 fd,
230 shape: shape.to_vec(),
231 _marker: std::marker::PhantomData,
232 })
233 }
234}
235
236impl<T> Deref for DmaMap<T>
237where
238 T: Num + Clone + fmt::Debug,
239{
240 type Target = [T];
241
242 fn deref(&self) -> &[T] {
243 self.as_slice()
244 }
245}
246
247impl<T> DerefMut for DmaMap<T>
248where
249 T: Num + Clone + fmt::Debug,
250{
251 fn deref_mut(&mut self) -> &mut [T] {
252 self.as_mut_slice()
253 }
254}
255
/// Base address of a `DmaMap`'s memory, as returned by `mmap`.
#[derive(Debug)]
struct DmaPtr(NonNull<c_void>);

impl Deref for DmaPtr {
    type Target = NonNull<c_void>;

    /// Borrows the wrapped address.
    fn deref(&self) -> &NonNull<c_void> {
        let DmaPtr(inner) = self;
        inner
    }
}

// SAFETY: the value is only an address. A memory mapping belongs to the whole process,
// so the address is equally valid from any thread, and `DmaPtr` holds no thread-affine
// state. Access to the pointed-to memory goes through `DmaMap`'s methods.
unsafe impl Send for DmaPtr {}
267
268impl<T> TensorMapTrait<T> for DmaMap<T>
269where
270 T: Num + Clone + fmt::Debug,
271{
272 fn shape(&self) -> &[usize] {
273 &self.shape
274 }
275
276 fn unmap(&mut self) {
277 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
278
279 if let Err(e) = unsafe { nix::sys::mman::munmap(**ptr, self.size()) } {
280 warn!("Failed to unmap DMA memory: {e}");
281 }
282
283 #[cfg(target_os = "linux")]
284 if let Err(e) = crate::dmabuf::end_readwrite(&self.fd) {
285 warn!("Failed to end read/write on DMA memory: {e}");
286 }
287 }
288
289 fn as_slice(&self) -> &[T] {
290 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
291 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
292 }
293
294 fn as_mut_slice(&mut self) -> &mut [T] {
295 let ptr = self.ptr.lock().expect("Failed to lock DmaMap pointer");
296 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
297 }
298}
299
300impl<T> Drop for DmaMap<T>
301where
302 T: Num + Clone + fmt::Debug,
303{
304 fn drop(&mut self) {
305 trace!("DmaMap dropped, unmapping memory");
306 self.unmap();
307 }
308}