1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54 fmt,
55 ops::{Deref, DerefMut},
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc, Weak,
59 },
60};
61pub use tensor_dyn::TensorDyn;
62
#[cfg(unix)]
/// Describes a single image plane backed by a file descriptor, with
/// optional row stride and byte offset metadata (both `None` until set
/// via the builder methods on the impl below).
pub struct PlaneDescriptor {
    // Owned duplicate of the caller's fd (cloned in `PlaneDescriptor::new`).
    fd: OwnedFd,
    // Row stride in bytes; `None` means "use the natural/tight stride".
    stride: Option<usize>,
    // Byte offset of this plane within the buffer; `None` means 0/unset.
    offset: Option<usize>,
}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91 pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101 let owned = fd.try_clone_to_owned()?;
102 Ok(Self {
103 fd: owned,
104 stride: None,
105 offset: None,
106 })
107 }
108
109 pub fn with_stride(mut self, stride: usize) -> Self {
111 self.stride = Some(stride);
112 self
113 }
114
115 pub fn with_offset(mut self, offset: usize) -> Self {
117 self.offset = Some(offset);
118 self
119 }
120
121 pub fn into_fd(self) -> OwnedFd {
123 self.fd
124 }
125
126 pub fn stride(&self) -> Option<usize> {
128 self.stride
129 }
130
131 pub fn offset(&self) -> Option<usize> {
133 self.offset
134 }
135}
136
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
/// Element data type of a tensor. `#[non_exhaustive]` so new dtypes can
/// be added without breaking downstream matches.
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}
154
155impl DType {
156 pub const fn size(&self) -> usize {
158 match self {
159 Self::U8 | Self::I8 => 1,
160 Self::U16 | Self::I16 | Self::F16 => 2,
161 Self::U32 | Self::I32 | Self::F32 => 4,
162 Self::U64 | Self::I64 | Self::F64 => 8,
163 }
164 }
165
166 pub const fn name(&self) -> &'static str {
168 match self {
169 Self::U8 => "u8",
170 Self::I8 => "i8",
171 Self::U16 => "u16",
172 Self::I16 => "i16",
173 Self::U32 => "u32",
174 Self::I32 => "i32",
175 Self::U64 => "u64",
176 Self::I64 => "i64",
177 Self::F16 => "f16",
178 Self::F32 => "f32",
179 Self::F64 => "f64",
180 }
181 }
182}
183
184impl fmt::Display for DType {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 f.write_str(self.name())
187 }
188}
189
// Monotonically increasing source of process-unique buffer ids; starts at 1
// so 0 can never be a valid id.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

#[derive(Debug, Clone)]
/// Identity of a backing buffer: a process-unique numeric id plus an
/// `Arc` liveness guard. Clones share the same id and guard, so a
/// `Weak` obtained via `weak()` upgrades only while some clone is alive.
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}
203
204impl BufferIdentity {
205 pub fn new() -> Self {
207 Self {
208 id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
209 guard: Arc::new(()),
210 }
211 }
212
213 pub fn id(&self) -> u64 {
215 self.id
216 }
217
218 pub fn weak(&self) -> Weak<()> {
221 Arc::downgrade(&self.guard)
222 }
223}
224
225impl Default for BufferIdentity {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231#[cfg(target_os = "linux")]
232use nix::sys::stat::{major, minor};
233
/// Common interface implemented by every tensor backend (DMA, SHM, heap
/// memory, PBO) and by the type-erasing wrappers in this module.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Allocates a new tensor with the given shape and optional debug name.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Wraps an existing file descriptor as a tensor of the given shape,
    /// taking ownership of the fd.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Duplicates the tensor's underlying file descriptor (e.g. for
    /// sharing the buffer with another process or component).
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Which memory backend this tensor is allocated from.
    fn memory(&self) -> TensorMemory;

    /// Debug/identification name of the tensor.
    fn name(&self) -> String;

    /// Total number of elements (product of all dimensions).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the tensor holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total size in bytes (`len() * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Dimensions of the tensor.
    fn shape(&self) -> &[usize];

    /// Changes the shape; implementations reject shapes whose element
    /// count differs from the current one (see tests below).
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Maps the buffer into CPU-addressable memory for read/write access.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Stable identity of the backing buffer (survives reshapes).
    fn buffer_identity(&self) -> &BufferIdentity;
}
293
/// Interface of a CPU mapping of a tensor buffer, giving slice access to
/// the underlying elements while mapped.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Dimensions of the mapped tensor.
    fn shape(&self) -> &[usize];

    /// Releases the mapping early (also happens on drop in the backends).
    fn unmap(&mut self);

    /// Total number of elements (product of all dimensions).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the mapping holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total size in bytes (`len() * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Read-only view of the mapped elements.
    fn as_slice(&self) -> &[T];

    /// Mutable view of the mapped elements.
    fn as_mut_slice(&mut self) -> &mut [T];

    /// Borrow the mapping as an `ndarray` dynamic-dimension view.
    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    /// Borrow the mapping as a mutable `ndarray` dynamic-dimension view.
    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Shape is copied first because `as_mut_slice` takes `&mut self`,
        // which would conflict with the `shape()` borrow.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
346
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Memory backend a tensor is allocated from. Variants are cfg-gated to
/// match platform support (DMA heaps are Linux-only; SHM is Unix-only).
pub enum TensorMemory {
    /// Linux DMA-heap buffer (zero-copy sharing with devices).
    #[cfg(target_os = "linux")]
    Dma,
    /// POSIX shared memory.
    #[cfg(unix)]
    Shm,

    /// Plain process-local heap memory.
    Mem,

    /// OpenGL pixel buffer object (created only via the image pipeline;
    /// see `TensorStorage::new`).
    Pbo,
}
366
367impl From<TensorMemory> for String {
368 fn from(memory: TensorMemory) -> Self {
369 match memory {
370 #[cfg(target_os = "linux")]
371 TensorMemory::Dma => "dma".to_owned(),
372 #[cfg(unix)]
373 TensorMemory::Shm => "shm".to_owned(),
374 TensorMemory::Mem => "mem".to_owned(),
375 TensorMemory::Pbo => "pbo".to_owned(),
376 }
377 }
378}
379
380impl TryFrom<&str> for TensorMemory {
381 type Error = Error;
382
383 fn try_from(s: &str) -> Result<Self> {
384 match s {
385 #[cfg(target_os = "linux")]
386 "dma" => Ok(TensorMemory::Dma),
387 #[cfg(unix)]
388 "shm" => Ok(TensorMemory::Shm),
389 "mem" => Ok(TensorMemory::Mem),
390 "pbo" => Ok(TensorMemory::Pbo),
391 _ => Err(Error::InvalidMemoryType(s.to_owned())),
392 }
393 }
394}
395
#[derive(Debug)]
// Variants are cfg-gated, so on some targets a variant may be unused —
// hence the allow.
#[allow(dead_code)]
/// Internal type-erased union of the concrete backend tensors; `Tensor`
/// wraps one of these and dispatches through it.
pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}
409
impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates backing storage.
    ///
    /// With an explicit `memory` the requested backend is used (PBO is
    /// rejected — it can only come from the image pipeline). With `None`,
    /// the best available backend is chosen with graceful fallback:
    /// DMA → SHM → Mem on Linux, SHM → Mem on other Unix, Mem elsewhere.
    /// Setting `EDGEFIRST_TENSOR_FORCE_MEM` (to anything but "0"/"false")
    /// forces plain heap memory regardless of platform.
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => {
                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
            }
            #[cfg(unix)]
            Some(TensorMemory::Shm) => {
                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
            }
            Some(TensorMemory::Mem) => {
                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
            }
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                // Env-var escape hatch: force plain heap memory.
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    // Linux: prefer DMA heaps, fall back to SHM, then heap.
                    #[cfg(target_os = "linux")]
                    {
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name)
                                    .map(TensorStorage::Shm)
                                {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => MemTensor::<T>::new(shape, name)
                                        .map(TensorStorage::Mem),
                                }
                            }
                        }
                    }
                    // Other Unix (e.g. macOS): SHM with heap fallback.
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => {
                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                            }
                        }
                    }
                    // Non-Unix: heap memory is the only option.
                    #[cfg(not(unix))]
                    {
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    /// Wraps an existing fd, picking the backend from the fd's device
    /// numbers on Linux (anything else on that platform is rejected).
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            // NOTE(review): assumes anonymous-inode fds (dmabuf/memfd)
            // report st_dev major 0 — confirm against target kernels.
            if major != 0 {
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                // NOTE(review): minors 9/10 are treated as dmabuf here;
                // verify these values hold on all supported kernels.
                9 | 10 => {
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        // Non-Linux Unix: only SHM-style fds are supported.
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}
513
/// `TensorTrait` for the type-erased storage: every method dispatches to
/// the active backend variant.
impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection (delegates to the
    /// inherent three-argument `new` with `memory = None`).
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    /// Delegates to the inherent `from_fd` (backend chosen from the fd).
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}
604
#[derive(Debug)]
/// Public tensor type: type-erased backend storage plus optional image
/// metadata (pixel format, a secondary chroma plane for semi-planar
/// multiplane images, and GPU-path stride/offset hints).
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    /// Pixel format when this tensor represents an image; `None` for raw
    /// tensors (cleared by `reshape`).
    format: Option<PixelFormat>,
    /// Separate chroma plane for multiplane semi-planar images (NV12/NV16)
    /// built via `from_planes`; `None` for single-plane tensors.
    chroma: Option<Box<Tensor<T>>>,
    /// Explicit row stride in bytes; when set, CPU mapping is refused
    /// (see the `map` impl below).
    row_stride: Option<usize>,
    /// Byte offset of the image plane within the buffer; a non-zero value
    /// also disables CPU mapping.
    plane_offset: Option<usize>,
}
625
626impl<T> Tensor<T>
627where
628 T: Num + Clone + fmt::Debug + Send + Sync,
629{
630 pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
632 Self {
633 storage,
634 format: None,
635 chroma: None,
636 row_stride: None,
637 plane_offset: None,
638 }
639 }
640
641 pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
666 TensorStorage::new(shape, memory, name).map(Self::wrap)
667 }
668
669 pub fn image(
671 width: usize,
672 height: usize,
673 format: PixelFormat,
674 memory: Option<TensorMemory>,
675 ) -> Result<Self> {
676 let shape = match format.layout() {
677 PixelLayout::Packed => vec![height, width, format.channels()],
678 PixelLayout::Planar => vec![format.channels(), height, width],
679 PixelLayout::SemiPlanar => {
680 let total_h = match format {
684 PixelFormat::Nv12 => {
685 if !height.is_multiple_of(2) {
686 return Err(Error::InvalidArgument(format!(
687 "NV12 requires even height, got {height}"
688 )));
689 }
690 height * 3 / 2
691 }
692 PixelFormat::Nv16 => height * 2,
693 _ => {
694 return Err(Error::InvalidArgument(format!(
695 "unknown semi-planar height multiplier for {format:?}"
696 )))
697 }
698 };
699 vec![total_h, width]
700 }
701 };
702 let mut t = Self::new(&shape, memory, None)?;
703 t.format = Some(format);
704 Ok(t)
705 }
706
707 pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
724 let shape = self.shape();
725 match format.layout() {
726 PixelLayout::Packed => {
727 if shape.len() != 3 || shape[2] != format.channels() {
728 return Err(Error::InvalidShape(format!(
729 "packed format {format:?} expects [H, W, {}], got {shape:?}",
730 format.channels()
731 )));
732 }
733 }
734 PixelLayout::Planar => {
735 if shape.len() != 3 || shape[0] != format.channels() {
736 return Err(Error::InvalidShape(format!(
737 "planar format {format:?} expects [{}, H, W], got {shape:?}",
738 format.channels()
739 )));
740 }
741 }
742 PixelLayout::SemiPlanar => {
743 if shape.len() != 2 {
744 return Err(Error::InvalidShape(format!(
745 "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
746 )));
747 }
748 match format {
749 PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
750 return Err(Error::InvalidShape(format!(
751 "NV12 contiguous shape[0] must be divisible by 3, got {}",
752 shape[0]
753 )));
754 }
755 PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
756 return Err(Error::InvalidShape(format!(
757 "NV16 contiguous shape[0] must be even, got {}",
758 shape[0]
759 )));
760 }
761 _ => {}
762 }
763 }
764 }
765 if self.format != Some(format) {
768 self.row_stride = None;
769 self.plane_offset = None;
770 }
771 self.format = Some(format);
772 Ok(())
773 }
774
775 pub fn format(&self) -> Option<PixelFormat> {
777 self.format
778 }
779
780 pub fn width(&self) -> Option<usize> {
782 let fmt = self.format?;
783 let shape = self.shape();
784 match fmt.layout() {
785 PixelLayout::Packed => Some(shape[1]),
786 PixelLayout::Planar => Some(shape[2]),
787 PixelLayout::SemiPlanar => Some(shape[1]),
788 }
789 }
790
791 pub fn height(&self) -> Option<usize> {
793 let fmt = self.format?;
794 let shape = self.shape();
795 match fmt.layout() {
796 PixelLayout::Packed => Some(shape[0]),
797 PixelLayout::Planar => Some(shape[1]),
798 PixelLayout::SemiPlanar => {
799 if self.is_multiplane() {
800 Some(shape[0])
801 } else {
802 match fmt {
803 PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
804 PixelFormat::Nv16 => Some(shape[0] / 2),
805 _ => None,
806 }
807 }
808 }
809 }
810 }
811
812 pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
814 if format.layout() != PixelLayout::SemiPlanar {
815 return Err(Error::InvalidArgument(format!(
816 "from_planes requires a semi-planar format, got {format:?}"
817 )));
818 }
819 if chroma.format.is_some() || chroma.chroma.is_some() {
820 return Err(Error::InvalidArgument(
821 "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
822 ));
823 }
824 let luma_shape = luma.shape();
825 let chroma_shape = chroma.shape();
826 if luma_shape.len() != 2 || chroma_shape.len() != 2 {
827 return Err(Error::InvalidArgument(format!(
828 "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
829 )));
830 }
831 if luma_shape[1] != chroma_shape[1] {
832 return Err(Error::InvalidArgument(format!(
833 "luma width {} != chroma width {}",
834 luma_shape[1], chroma_shape[1]
835 )));
836 }
837 match format {
838 PixelFormat::Nv12 => {
839 if luma_shape[0] % 2 != 0 {
840 return Err(Error::InvalidArgument(format!(
841 "NV12 requires even luma height, got {}",
842 luma_shape[0]
843 )));
844 }
845 if chroma_shape[0] != luma_shape[0] / 2 {
846 return Err(Error::InvalidArgument(format!(
847 "NV12 chroma height {} != luma height / 2 ({})",
848 chroma_shape[0],
849 luma_shape[0] / 2
850 )));
851 }
852 }
853 PixelFormat::Nv16 => {
854 if chroma_shape[0] != luma_shape[0] {
855 return Err(Error::InvalidArgument(format!(
856 "NV16 chroma height {} != luma height {}",
857 chroma_shape[0], luma_shape[0]
858 )));
859 }
860 }
861 _ => {
862 return Err(Error::InvalidArgument(format!(
863 "from_planes only supports NV12 and NV16, got {format:?}"
864 )));
865 }
866 }
867
868 Ok(Tensor {
869 storage: luma.storage,
870 format: Some(format),
871 chroma: Some(Box::new(chroma)),
872 row_stride: luma.row_stride,
873 plane_offset: luma.plane_offset,
874 })
875 }
876
877 pub fn is_multiplane(&self) -> bool {
879 self.chroma.is_some()
880 }
881
882 pub fn chroma(&self) -> Option<&Tensor<T>> {
884 self.chroma.as_deref()
885 }
886
887 pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
889 self.chroma.as_deref_mut()
890 }
891
892 pub fn row_stride(&self) -> Option<usize> {
894 self.row_stride
895 }
896
897 pub fn effective_row_stride(&self) -> Option<usize> {
902 if let Some(s) = self.row_stride {
903 return Some(s);
904 }
905 let fmt = self.format?;
906 let w = self.width()?;
907 let elem = std::mem::size_of::<T>();
908 Some(match fmt.layout() {
909 PixelLayout::Packed => w * fmt.channels() * elem,
910 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
911 })
912 }
913
914 pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
941 let fmt = self.format.ok_or_else(|| {
942 Error::InvalidArgument("cannot set row_stride without a pixel format".into())
943 })?;
944 let w = self.width().ok_or_else(|| {
945 Error::InvalidArgument("cannot determine width for row_stride validation".into())
946 })?;
947 let elem = std::mem::size_of::<T>();
948 let min_stride = match fmt.layout() {
949 PixelLayout::Packed => w * fmt.channels() * elem,
950 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
951 };
952 if stride < min_stride {
953 return Err(Error::InvalidArgument(format!(
954 "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
955 )));
956 }
957 self.row_stride = Some(stride);
958 Ok(())
959 }
960
961 pub fn set_row_stride_unchecked(&mut self, stride: usize) {
967 self.row_stride = Some(stride);
968 }
969
970 pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
977 self.set_row_stride(stride)?;
978 Ok(self)
979 }
980
981 pub fn plane_offset(&self) -> Option<usize> {
983 self.plane_offset
984 }
985
986 pub fn set_plane_offset(&mut self, offset: usize) {
992 self.plane_offset = Some(offset);
993 }
994
995 pub fn with_plane_offset(mut self, offset: usize) -> Self {
998 self.set_plane_offset(offset);
999 self
1000 }
1001
1002 pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1004 match &self.storage {
1005 TensorStorage::Pbo(p) => Some(p),
1006 _ => None,
1007 }
1008 }
1009
1010 #[cfg(target_os = "linux")]
1012 pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1013 match &self.storage {
1014 TensorStorage::Dma(d) => Some(d),
1015 _ => None,
1016 }
1017 }
1018
1019 #[cfg(target_os = "linux")]
1030 pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1031 use std::os::fd::AsFd;
1032 match &self.storage {
1033 TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1034 _ => Err(Error::NotImplemented(format!(
1035 "dmabuf requires DMA-backed tensor, got {:?}",
1036 self.storage.memory()
1037 ))),
1038 }
1039 }
1040
1041 pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1043 Self {
1044 storage: TensorStorage::Pbo(pbo),
1045 format: None,
1046 chroma: None,
1047 row_stride: None,
1048 plane_offset: None,
1049 }
1050 }
1051}
1052
/// `TensorTrait` for the public wrapper: mostly delegates to the inner
/// storage, with extra invariants around image metadata.
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection (inherent `new` with
    /// `memory = None`).
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    /// Wraps an fd-backed storage as a raw tensor (no image metadata).
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    /// Reshapes the underlying storage. Refused for multiplane images
    /// (the chroma plane would be orphaned); on success all image
    /// metadata (format, stride, offset) is cleared.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        self.row_stride = None;
        self.plane_offset = None;
        Ok(())
    }

    /// Maps the buffer for CPU access. Tensors carrying an explicit row
    /// stride or a non-zero plane offset are GPU-path only and refuse to
    /// map (the flat slice view would misrepresent the layout).
    fn map(&self) -> Result<TensorMap<T>> {
        if self.row_stride.is_some() {
            return Err(Error::InvalidOperation(
                "CPU mapping of strided tensors is not supported; use GPU path only".into(),
            ));
        }
        if self.plane_offset.is_some_and(|o| o > 0) {
            return Err(Error::InvalidOperation(
                "CPU mapping of offset tensors is not supported; use GPU path only".into(),
            ));
        }
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
1120
/// Type-erased CPU mapping returned by [`TensorTrait::map`]; one variant
/// per memory backend.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}
1132
/// `TensorMapTrait` for the type-erased mapping: every method dispatches
/// to the active backend variant.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
1181
1182impl<T> Deref for TensorMap<T>
1183where
1184 T: Num + Clone + fmt::Debug,
1185{
1186 type Target = [T];
1187
1188 fn deref(&self) -> &[T] {
1189 match self {
1190 #[cfg(target_os = "linux")]
1191 TensorMap::Dma(map) => map.deref(),
1192 #[cfg(unix)]
1193 TensorMap::Shm(map) => map.deref(),
1194 TensorMap::Mem(map) => map.deref(),
1195 TensorMap::Pbo(map) => map.deref(),
1196 }
1197 }
1198}
1199
1200impl<T> DerefMut for TensorMap<T>
1201where
1202 T: Num + Clone + fmt::Debug,
1203{
1204 fn deref_mut(&mut self) -> &mut [T] {
1205 match self {
1206 #[cfg(target_os = "linux")]
1207 TensorMap::Dma(map) => map.deref_mut(),
1208 #[cfg(unix)]
1209 TensorMap::Shm(map) => map.deref_mut(),
1210 TensorMap::Mem(map) => map.deref_mut(),
1211 TensorMap::Pbo(map) => map.deref_mut(),
1212 }
1213 }
1214}
1215
// Cached result of the DMA availability probe; computed once per process.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Returns whether DMA-heap allocation works on this system, determined
/// by attempting a small (64-byte) DMA tensor allocation once and caching
/// the outcome.
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}

/// DMA heaps are Linux-only; always false elsewhere.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
1243
// Cached result of the SHM availability probe; computed once per process.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Returns whether POSIX shared-memory allocation works on this system,
/// determined by attempting a small (64-byte) SHM tensor allocation once
/// and caching the outcome.
#[cfg(unix)]
pub fn is_shm_available() -> bool {
    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}

/// POSIX shared memory is Unix-only; always false elsewhere.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
1266
#[cfg(test)]
mod dtype_tests {
    use super::*;

    // Pins the byte size of every DType variant.
    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    // Spot-checks the canonical short names.
    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    // Serde derive must round-trip through JSON without loss.
    #[test]
    fn dtype_serde_roundtrip() {
        use serde_json;
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1302
#[cfg(test)]
mod image_tests {
    use super::*;

    // A tensor created without a format has no image metadata at all.
    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    // Packed RGBA images are shaped [H, W, C].
    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    // Planar RGB images are shaped [C, H, W].
    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    // Contiguous NV12 is [H*3/2, W] (here 720 = 480*3/2) and single-plane.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    // set_format accepts a matching shape and derives width/height.
    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    // set_format rejects a shape whose channel count mismatches, leaving
    // the tensor untagged.
    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        assert!(t.format().is_none());
    }

    // A successful reshape drops the pixel-format tag.
    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    // NV12 multiplane: chroma plane is half the luma height, same width.
    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    // from_planes only accepts semi-planar formats.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    // Multiplane tensors refuse to reshape (chroma would be orphaned).
    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
1405
1406#[cfg(test)]
1407mod tests {
1408 #[cfg(target_os = "linux")]
1409 use nix::unistd::{access, AccessFlags};
1410 #[cfg(target_os = "linux")]
1411 use std::io::Write as _;
1412 use std::sync::RwLock;
1413
1414 use super::*;
1415
1416 #[ctor::ctor]
1417 fn init() {
1418 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
1419 }
1420
1421 #[cfg(target_os = "linux")]
1423 macro_rules! function {
1424 () => {{
1425 fn f() {}
1426 fn type_name_of<T>(_: T) -> &'static str {
1427 std::any::type_name::<T>()
1428 }
1429 let name = type_name_of(f);
1430
1431 match &name[..name.len() - 3].rfind(':') {
1433 Some(pos) => &name[pos + 1..name.len() - 3],
1434 None => &name[..name.len() - 3],
1435 }
1436 }};
1437 }
1438
1439 #[test]
1440 #[cfg(target_os = "linux")]
1441 fn test_tensor() {
1442 let _lock = FD_LOCK.read().unwrap();
1443 let shape = vec![1];
1444 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1445 let dma_enabled = tensor.is_ok();
1446
1447 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1448 match dma_enabled {
1449 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1450 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1451 }
1452 }
1453
1454 #[test]
1455 #[cfg(all(unix, not(target_os = "linux")))]
1456 fn test_tensor() {
1457 let shape = vec![1];
1458 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1459 assert!(
1461 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1462 "Expected SHM or Mem on macOS, got {:?}",
1463 tensor.memory()
1464 );
1465 }
1466
1467 #[test]
1468 #[cfg(not(unix))]
1469 fn test_tensor() {
1470 let shape = vec![1];
1471 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1472 assert_eq!(tensor.memory(), TensorMemory::Mem);
1473 }
1474
1475 #[test]
1476 #[cfg(target_os = "linux")]
1477 fn test_dma_tensor() {
1478 let _lock = FD_LOCK.read().unwrap();
1479 match access(
1480 "/dev/dma_heap/linux,cma",
1481 AccessFlags::R_OK | AccessFlags::W_OK,
1482 ) {
1483 Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
1484 Err(_) => match access(
1485 "/dev/dma_heap/system",
1486 AccessFlags::R_OK | AccessFlags::W_OK,
1487 ) {
1488 Ok(_) => println!("/dev/dma_heap/system is available"),
1489 Err(e) => {
1490 writeln!(
1491 &mut std::io::stdout(),
1492 "[WARNING] DMA Heap is unavailable: {e}"
1493 )
1494 .unwrap();
1495 return;
1496 }
1497 },
1498 }
1499
1500 let shape = vec![2, 3, 4];
1501 let tensor =
1502 DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1503
1504 const DUMMY_VALUE: f32 = 12.34;
1505
1506 assert_eq!(tensor.memory(), TensorMemory::Dma);
1507 assert_eq!(tensor.name(), "test_tensor");
1508 assert_eq!(tensor.shape(), &shape);
1509 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1510 assert_eq!(tensor.len(), 2 * 3 * 4);
1511
1512 {
1513 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1514 tensor_map.fill(42.0);
1515 assert!(tensor_map.iter().all(|&x| x == 42.0));
1516 }
1517
1518 {
1519 let shared = Tensor::<f32>::from_fd(
1520 tensor
1521 .clone_fd()
1522 .expect("Failed to duplicate tensor file descriptor"),
1523 &shape,
1524 Some("test_tensor_shared"),
1525 )
1526 .expect("Failed to create tensor from fd");
1527
1528 assert_eq!(shared.memory(), TensorMemory::Dma);
1529 assert_eq!(shared.name(), "test_tensor_shared");
1530 assert_eq!(shared.shape(), &shape);
1531
1532 let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
1533 tensor_map.fill(DUMMY_VALUE);
1534 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1535 }
1536
1537 {
1538 let tensor_map = tensor.map().expect("Failed to map DMA memory");
1539 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1540 }
1541
1542 let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1543 assert_eq!(tensor.shape(), &shape);
1544 let new_shape = vec![3, 4, 4];
1545 assert!(
1546 tensor.reshape(&new_shape).is_err(),
1547 "Reshape should fail due to size mismatch"
1548 );
1549 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1550
1551 let new_shape = vec![2, 3, 4];
1552 tensor.reshape(&new_shape).expect("Reshape should succeed");
1553 assert_eq!(
1554 tensor.shape(),
1555 &new_shape,
1556 "Shape should be updated after successful reshape"
1557 );
1558
1559 {
1560 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1561 tensor_map.fill(1);
1562 assert!(tensor_map.iter().all(|&x| x == 1));
1563 }
1564
1565 {
1566 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1567 tensor_map[2] = 42;
1568 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1569 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1570 }
1571 }
1572
1573 #[test]
1574 #[cfg(unix)]
1575 fn test_shm_tensor() {
1576 let _lock = FD_LOCK.read().unwrap();
1577 let shape = vec![2, 3, 4];
1578 let tensor =
1579 ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1580 assert_eq!(tensor.shape(), &shape);
1581 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1582 assert_eq!(tensor.name(), "test_tensor");
1583
1584 const DUMMY_VALUE: f32 = 12.34;
1585 {
1586 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1587 tensor_map.fill(42.0);
1588 assert!(tensor_map.iter().all(|&x| x == 42.0));
1589 }
1590
1591 {
1592 let shared = Tensor::<f32>::from_fd(
1593 tensor
1594 .clone_fd()
1595 .expect("Failed to duplicate tensor file descriptor"),
1596 &shape,
1597 Some("test_tensor_shared"),
1598 )
1599 .expect("Failed to create tensor from fd");
1600
1601 assert_eq!(shared.memory(), TensorMemory::Shm);
1602 assert_eq!(shared.name(), "test_tensor_shared");
1603 assert_eq!(shared.shape(), &shape);
1604
1605 let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
1606 tensor_map.fill(DUMMY_VALUE);
1607 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1608 }
1609
1610 {
1611 let tensor_map = tensor.map().expect("Failed to map shared memory");
1612 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1613 }
1614
1615 let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1616 assert_eq!(tensor.shape(), &shape);
1617 let new_shape = vec![3, 4, 4];
1618 assert!(
1619 tensor.reshape(&new_shape).is_err(),
1620 "Reshape should fail due to size mismatch"
1621 );
1622 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1623
1624 let new_shape = vec![2, 3, 4];
1625 tensor.reshape(&new_shape).expect("Reshape should succeed");
1626 assert_eq!(
1627 tensor.shape(),
1628 &new_shape,
1629 "Shape should be updated after successful reshape"
1630 );
1631
1632 {
1633 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1634 tensor_map.fill(1);
1635 assert!(tensor_map.iter().all(|&x| x == 1));
1636 }
1637
1638 {
1639 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1640 tensor_map[2] = 42;
1641 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1642 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1643 }
1644 }
1645
1646 #[test]
1647 fn test_mem_tensor() {
1648 let shape = vec![2, 3, 4];
1649 let tensor =
1650 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1651 assert_eq!(tensor.shape(), &shape);
1652 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1653 assert_eq!(tensor.name(), "test_tensor");
1654
1655 {
1656 let mut tensor_map = tensor.map().expect("Failed to map memory");
1657 tensor_map.fill(42.0);
1658 assert!(tensor_map.iter().all(|&x| x == 42.0));
1659 }
1660
1661 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1662 assert_eq!(tensor.shape(), &shape);
1663 let new_shape = vec![3, 4, 4];
1664 assert!(
1665 tensor.reshape(&new_shape).is_err(),
1666 "Reshape should fail due to size mismatch"
1667 );
1668 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1669
1670 let new_shape = vec![2, 3, 4];
1671 tensor.reshape(&new_shape).expect("Reshape should succeed");
1672 assert_eq!(
1673 tensor.shape(),
1674 &new_shape,
1675 "Shape should be updated after successful reshape"
1676 );
1677
1678 {
1679 let mut tensor_map = tensor.map().expect("Failed to map memory");
1680 tensor_map.fill(1);
1681 assert!(tensor_map.iter().all(|&x| x == 1));
1682 }
1683
1684 {
1685 let mut tensor_map = tensor.map().expect("Failed to map memory");
1686 tensor_map[2] = 42;
1687 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1688 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1689 }
1690 }
1691
1692 #[test]
1693 #[cfg(target_os = "linux")]
1694 fn test_dma_no_fd_leaks() {
1695 let _lock = FD_LOCK.write().unwrap();
1696 if !is_dma_available() {
1697 log::warn!(
1698 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1699 function!()
1700 );
1701 return;
1702 }
1703
1704 let proc = procfs::process::Process::myself()
1705 .expect("Failed to get current process using /proc/self");
1706
1707 let start_open_fds = proc
1708 .fd_count()
1709 .expect("Failed to get open file descriptor count");
1710
1711 for _ in 0..100 {
1712 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
1713 .expect("Failed to create tensor");
1714 let mut map = tensor.map().unwrap();
1715 map.as_mut_slice().fill(233);
1716 }
1717
1718 let end_open_fds = proc
1719 .fd_count()
1720 .expect("Failed to get open file descriptor count");
1721
1722 assert_eq!(
1723 start_open_fds, end_open_fds,
1724 "File descriptor leak detected: {} -> {}",
1725 start_open_fds, end_open_fds
1726 );
1727 }
1728
1729 #[test]
1730 #[cfg(target_os = "linux")]
1731 fn test_dma_from_fd_no_fd_leaks() {
1732 let _lock = FD_LOCK.write().unwrap();
1733 if !is_dma_available() {
1734 log::warn!(
1735 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1736 function!()
1737 );
1738 return;
1739 }
1740
1741 let proc = procfs::process::Process::myself()
1742 .expect("Failed to get current process using /proc/self");
1743
1744 let start_open_fds = proc
1745 .fd_count()
1746 .expect("Failed to get open file descriptor count");
1747
1748 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
1749
1750 for _ in 0..100 {
1751 let tensor =
1752 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1753 let mut map = tensor.map().unwrap();
1754 map.as_mut_slice().fill(233);
1755 }
1756 drop(orig);
1757
1758 let end_open_fds = proc.fd_count().unwrap();
1759
1760 assert_eq!(
1761 start_open_fds, end_open_fds,
1762 "File descriptor leak detected: {} -> {}",
1763 start_open_fds, end_open_fds
1764 );
1765 }
1766
1767 #[test]
1768 #[cfg(target_os = "linux")]
1769 fn test_shm_no_fd_leaks() {
1770 let _lock = FD_LOCK.write().unwrap();
1771 if !is_shm_available() {
1772 log::warn!(
1773 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1774 function!()
1775 );
1776 return;
1777 }
1778
1779 let proc = procfs::process::Process::myself()
1780 .expect("Failed to get current process using /proc/self");
1781
1782 let start_open_fds = proc
1783 .fd_count()
1784 .expect("Failed to get open file descriptor count");
1785
1786 for _ in 0..100 {
1787 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1788 .expect("Failed to create tensor");
1789 let mut map = tensor.map().unwrap();
1790 map.as_mut_slice().fill(233);
1791 }
1792
1793 let end_open_fds = proc
1794 .fd_count()
1795 .expect("Failed to get open file descriptor count");
1796
1797 assert_eq!(
1798 start_open_fds, end_open_fds,
1799 "File descriptor leak detected: {} -> {}",
1800 start_open_fds, end_open_fds
1801 );
1802 }
1803
1804 #[test]
1805 #[cfg(target_os = "linux")]
1806 fn test_shm_from_fd_no_fd_leaks() {
1807 let _lock = FD_LOCK.write().unwrap();
1808 if !is_shm_available() {
1809 log::warn!(
1810 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1811 function!()
1812 );
1813 return;
1814 }
1815
1816 let proc = procfs::process::Process::myself()
1817 .expect("Failed to get current process using /proc/self");
1818
1819 let start_open_fds = proc
1820 .fd_count()
1821 .expect("Failed to get open file descriptor count");
1822
1823 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
1824
1825 for _ in 0..100 {
1826 let tensor =
1827 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1828 let mut map = tensor.map().unwrap();
1829 map.as_mut_slice().fill(233);
1830 }
1831 drop(orig);
1832
1833 let end_open_fds = proc.fd_count().unwrap();
1834
1835 assert_eq!(
1836 start_open_fds, end_open_fds,
1837 "File descriptor leak detected: {} -> {}",
1838 start_open_fds, end_open_fds
1839 );
1840 }
1841
1842 #[cfg(feature = "ndarray")]
1843 #[test]
1844 fn test_ndarray() {
1845 let _lock = FD_LOCK.read().unwrap();
1846 let shape = vec![2, 3, 4];
1847 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1848
1849 let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
1850 tensor_map.fill(1.0);
1851
1852 let view = tensor_map.view().expect("Failed to get ndarray view");
1853 assert_eq!(view.shape(), &[2, 3, 4]);
1854 assert!(view.iter().all(|&x| x == 1.0));
1855
1856 let mut view_mut = tensor_map
1857 .view_mut()
1858 .expect("Failed to get mutable ndarray view");
1859 view_mut[[0, 0, 0]] = 42.0;
1860 assert_eq!(view_mut[[0, 0, 0]], 42.0);
1861 assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
1862 }
1863
1864 #[test]
1865 fn test_buffer_identity_unique() {
1866 let id1 = BufferIdentity::new();
1867 let id2 = BufferIdentity::new();
1868 assert_ne!(
1869 id1.id(),
1870 id2.id(),
1871 "Two identities should have different ids"
1872 );
1873 }
1874
1875 #[test]
1876 fn test_buffer_identity_clone_shares_guard() {
1877 let id1 = BufferIdentity::new();
1878 let weak = id1.weak();
1879 assert!(
1880 weak.upgrade().is_some(),
1881 "Weak should be alive while original exists"
1882 );
1883
1884 let id2 = id1.clone();
1885 assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1886
1887 drop(id1);
1888 assert!(
1889 weak.upgrade().is_some(),
1890 "Weak should still be alive (clone holds Arc)"
1891 );
1892
1893 drop(id2);
1894 assert!(
1895 weak.upgrade().is_none(),
1896 "Weak should be dead after all clones dropped"
1897 );
1898 }
1899
1900 #[test]
1901 fn test_tensor_buffer_identity() {
1902 let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1903 let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1904 assert_ne!(
1905 t1.buffer_identity().id(),
1906 t2.buffer_identity().id(),
1907 "Different tensors should have different buffer ids"
1908 );
1909 }
1910
    /// Serializes tests that touch the process file-descriptor table.
    ///
    /// Protocol (as used by the tests in this module): the fd-leak tests take
    /// the WRITE lock so nothing else opens or closes fds while they compare
    /// `/proc/self` fd counts; ordinary tensor tests that create fd-backed
    /// buffers take the READ lock, so they may run concurrently with each
    /// other but never alongside an fd-count measurement.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1915
1916 #[test]
1919 #[cfg(not(target_os = "linux"))]
1920 fn test_dma_not_available_on_non_linux() {
1921 assert!(
1922 !is_dma_available(),
1923 "DMA memory allocation should NOT be available on non-Linux platforms"
1924 );
1925 }
1926
1927 #[test]
1930 #[cfg(unix)]
1931 fn test_shm_available_and_usable() {
1932 assert!(
1933 is_shm_available(),
1934 "SHM memory allocation should be available on Unix systems"
1935 );
1936
1937 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1939 .expect("Failed to create SHM tensor");
1940
1941 let mut map = tensor.map().expect("Failed to map SHM tensor");
1943 map.as_mut_slice().fill(0xAB);
1944
1945 assert!(
1947 map.as_slice().iter().all(|&b| b == 0xAB),
1948 "SHM tensor data should be writable and readable"
1949 );
1950 }
1951}