1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54 fmt,
55 ops::{Deref, DerefMut},
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc, Weak,
59 },
60};
61pub use tensor_dyn::TensorDyn;
62
63#[cfg(unix)]
83pub struct PlaneDescriptor {
84 fd: OwnedFd,
85 stride: Option<usize>,
86 offset: Option<usize>,
87}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91 pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101 let owned = fd.try_clone_to_owned()?;
102 Ok(Self {
103 fd: owned,
104 stride: None,
105 offset: None,
106 })
107 }
108
109 pub fn with_stride(mut self, stride: usize) -> Self {
111 self.stride = Some(stride);
112 self
113 }
114
115 pub fn with_offset(mut self, offset: usize) -> Self {
117 self.offset = Some(offset);
118 self
119 }
120
121 pub fn into_fd(self) -> OwnedFd {
123 self.fd
124 }
125
126 pub fn stride(&self) -> Option<usize> {
128 self.stride
129 }
130
131 pub fn offset(&self) -> Option<usize> {
133 self.offset
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
139#[repr(u8)]
140#[non_exhaustive]
141pub enum DType {
142 U8,
143 I8,
144 U16,
145 I16,
146 U32,
147 I32,
148 U64,
149 I64,
150 F16,
151 F32,
152 F64,
153}
154
155impl DType {
156 pub const fn size(&self) -> usize {
158 match self {
159 Self::U8 | Self::I8 => 1,
160 Self::U16 | Self::I16 | Self::F16 => 2,
161 Self::U32 | Self::I32 | Self::F32 => 4,
162 Self::U64 | Self::I64 | Self::F64 => 8,
163 }
164 }
165
166 pub const fn name(&self) -> &'static str {
168 match self {
169 Self::U8 => "u8",
170 Self::I8 => "i8",
171 Self::U16 => "u16",
172 Self::I16 => "i16",
173 Self::U32 => "u32",
174 Self::I32 => "i32",
175 Self::U64 => "u64",
176 Self::I64 => "i64",
177 Self::F16 => "f16",
178 Self::F32 => "f32",
179 Self::F64 => "f64",
180 }
181 }
182}
183
184impl fmt::Display for DType {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 f.write_str(self.name())
187 }
188}
189
190static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
192
193#[derive(Debug, Clone)]
199pub struct BufferIdentity {
200 id: u64,
201 guard: Arc<()>,
202}
203
204impl BufferIdentity {
205 pub fn new() -> Self {
207 Self {
208 id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
209 guard: Arc::new(()),
210 }
211 }
212
213 pub fn id(&self) -> u64 {
215 self.id
216 }
217
218 pub fn weak(&self) -> Weak<()> {
221 Arc::downgrade(&self.guard)
222 }
223}
224
225impl Default for BufferIdentity {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231#[cfg(target_os = "linux")]
232use nix::sys::stat::{major, minor};
233
234pub trait TensorTrait<T>: Send + Sync
235where
236 T: Num + Clone + fmt::Debug,
237{
238 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
241 where
242 Self: Sized;
243
244 #[cfg(unix)]
245 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
251 where
252 Self: Sized;
253
254 #[cfg(unix)]
255 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
257
258 fn memory(&self) -> TensorMemory;
260
261 fn name(&self) -> String;
263
264 fn len(&self) -> usize {
266 self.shape().iter().product()
267 }
268
269 fn is_empty(&self) -> bool {
271 self.len() == 0
272 }
273
274 fn size(&self) -> usize {
276 self.len() * std::mem::size_of::<T>()
277 }
278
279 fn shape(&self) -> &[usize];
281
282 fn reshape(&mut self, shape: &[usize]) -> Result<()>;
285
286 fn map(&self) -> Result<TensorMap<T>>;
289
290 fn buffer_identity(&self) -> &BufferIdentity;
292}
293
294pub trait TensorMapTrait<T>
295where
296 T: Num + Clone + fmt::Debug,
297{
298 fn shape(&self) -> &[usize];
300
301 fn unmap(&mut self);
303
304 fn len(&self) -> usize {
306 self.shape().iter().product()
307 }
308
309 fn is_empty(&self) -> bool {
311 self.len() == 0
312 }
313
314 fn size(&self) -> usize {
316 self.len() * std::mem::size_of::<T>()
317 }
318
319 fn as_slice(&self) -> &[T];
321
322 fn as_mut_slice(&mut self) -> &mut [T];
324
325 #[cfg(feature = "ndarray")]
326 fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
328 Ok(ndarray::ArrayView::from_shape(
329 self.shape(),
330 self.as_slice(),
331 )?)
332 }
333
334 #[cfg(feature = "ndarray")]
335 fn view_mut(
337 &'_ mut self,
338 ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
339 let shape = self.shape().to_vec();
340 Ok(ndarray::ArrayViewMut::from_shape(
341 shape,
342 self.as_mut_slice(),
343 )?)
344 }
345}
346
347#[derive(Debug, Clone, Copy, PartialEq, Eq)]
348pub enum TensorMemory {
349 #[cfg(target_os = "linux")]
350 Dma,
354 #[cfg(unix)]
355 Shm,
358
359 Mem,
361
362 Pbo,
365}
366
367impl From<TensorMemory> for String {
368 fn from(memory: TensorMemory) -> Self {
369 match memory {
370 #[cfg(target_os = "linux")]
371 TensorMemory::Dma => "dma".to_owned(),
372 #[cfg(unix)]
373 TensorMemory::Shm => "shm".to_owned(),
374 TensorMemory::Mem => "mem".to_owned(),
375 TensorMemory::Pbo => "pbo".to_owned(),
376 }
377 }
378}
379
380impl TryFrom<&str> for TensorMemory {
381 type Error = Error;
382
383 fn try_from(s: &str) -> Result<Self> {
384 match s {
385 #[cfg(target_os = "linux")]
386 "dma" => Ok(TensorMemory::Dma),
387 #[cfg(unix)]
388 "shm" => Ok(TensorMemory::Shm),
389 "mem" => Ok(TensorMemory::Mem),
390 "pbo" => Ok(TensorMemory::Pbo),
391 _ => Err(Error::InvalidMemoryType(s.to_owned())),
392 }
393 }
394}
395
396#[derive(Debug)]
397#[allow(dead_code)] pub(crate) enum TensorStorage<T>
399where
400 T: Num + Clone + fmt::Debug + Send + Sync,
401{
402 #[cfg(target_os = "linux")]
403 Dma(DmaTensor<T>),
404 #[cfg(unix)]
405 Shm(ShmTensor<T>),
406 Mem(MemTensor<T>),
407 Pbo(PboTensor<T>),
408}
409
410impl<T> TensorStorage<T>
411where
412 T: Num + Clone + fmt::Debug + Send + Sync,
413{
414 fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
419 match memory {
420 #[cfg(target_os = "linux")]
421 Some(TensorMemory::Dma) => {
422 DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
423 }
424 #[cfg(unix)]
425 Some(TensorMemory::Shm) => {
426 ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
427 }
428 Some(TensorMemory::Mem) => {
429 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
430 }
431 Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
432 "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
433 )),
434 None => {
435 if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
436 .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
437 {
438 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
439 } else {
440 #[cfg(target_os = "linux")]
441 {
442 match DmaTensor::<T>::new(shape, name) {
444 Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
445 Err(_) => {
446 match ShmTensor::<T>::new(shape, name)
447 .map(TensorStorage::Shm)
448 {
449 Ok(tensor) => Ok(tensor),
450 Err(_) => MemTensor::<T>::new(shape, name)
451 .map(TensorStorage::Mem),
452 }
453 }
454 }
455 }
456 #[cfg(all(unix, not(target_os = "linux")))]
457 {
458 match ShmTensor::<T>::new(shape, name) {
460 Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
461 Err(_) => {
462 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
463 }
464 }
465 }
466 #[cfg(not(unix))]
467 {
468 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
470 }
471 }
472 }
473 }
474 }
475
476 #[cfg(target_os = "linux")]
485 pub(crate) fn new_dma_with_byte_size(
486 shape: &[usize],
487 byte_size: usize,
488 name: Option<&str>,
489 ) -> Result<Self> {
490 DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
491 }
492
493 #[cfg(unix)]
501 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
502 #[cfg(target_os = "linux")]
503 {
504 use nix::sys::stat::fstat;
505
506 let stat = fstat(&fd)?;
507 let major = major(stat.st_dev);
508 let minor = minor(stat.st_dev);
509
510 log::debug!("Creating tensor from fd: major={major}, minor={minor}");
511
512 if major != 0 {
513 return Err(Error::UnknownDeviceType(major, minor));
515 }
516
517 match minor {
518 9 | 10 => {
519 DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
521 }
522 _ => {
523 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
525 }
526 }
527 }
528 #[cfg(all(unix, not(target_os = "linux")))]
529 {
530 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
532 }
533 }
534}
535
536impl<T> TensorTrait<T> for TensorStorage<T>
537where
538 T: Num + Clone + fmt::Debug + Send + Sync,
539{
540 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
541 Self::new(shape, None, name)
542 }
543
544 #[cfg(unix)]
545 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
546 Self::from_fd(fd, shape, name)
547 }
548
549 #[cfg(unix)]
550 fn clone_fd(&self) -> Result<OwnedFd> {
551 match self {
552 #[cfg(target_os = "linux")]
553 TensorStorage::Dma(t) => t.clone_fd(),
554 TensorStorage::Shm(t) => t.clone_fd(),
555 TensorStorage::Mem(t) => t.clone_fd(),
556 TensorStorage::Pbo(t) => t.clone_fd(),
557 }
558 }
559
560 fn memory(&self) -> TensorMemory {
561 match self {
562 #[cfg(target_os = "linux")]
563 TensorStorage::Dma(_) => TensorMemory::Dma,
564 #[cfg(unix)]
565 TensorStorage::Shm(_) => TensorMemory::Shm,
566 TensorStorage::Mem(_) => TensorMemory::Mem,
567 TensorStorage::Pbo(_) => TensorMemory::Pbo,
568 }
569 }
570
571 fn name(&self) -> String {
572 match self {
573 #[cfg(target_os = "linux")]
574 TensorStorage::Dma(t) => t.name(),
575 #[cfg(unix)]
576 TensorStorage::Shm(t) => t.name(),
577 TensorStorage::Mem(t) => t.name(),
578 TensorStorage::Pbo(t) => t.name(),
579 }
580 }
581
582 fn shape(&self) -> &[usize] {
583 match self {
584 #[cfg(target_os = "linux")]
585 TensorStorage::Dma(t) => t.shape(),
586 #[cfg(unix)]
587 TensorStorage::Shm(t) => t.shape(),
588 TensorStorage::Mem(t) => t.shape(),
589 TensorStorage::Pbo(t) => t.shape(),
590 }
591 }
592
593 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
594 match self {
595 #[cfg(target_os = "linux")]
596 TensorStorage::Dma(t) => t.reshape(shape),
597 #[cfg(unix)]
598 TensorStorage::Shm(t) => t.reshape(shape),
599 TensorStorage::Mem(t) => t.reshape(shape),
600 TensorStorage::Pbo(t) => t.reshape(shape),
601 }
602 }
603
604 fn map(&self) -> Result<TensorMap<T>> {
605 match self {
606 #[cfg(target_os = "linux")]
607 TensorStorage::Dma(t) => t.map(),
608 #[cfg(unix)]
609 TensorStorage::Shm(t) => t.map(),
610 TensorStorage::Mem(t) => t.map(),
611 TensorStorage::Pbo(t) => t.map(),
612 }
613 }
614
615 fn buffer_identity(&self) -> &BufferIdentity {
616 match self {
617 #[cfg(target_os = "linux")]
618 TensorStorage::Dma(t) => t.buffer_identity(),
619 #[cfg(unix)]
620 TensorStorage::Shm(t) => t.buffer_identity(),
621 TensorStorage::Mem(t) => t.buffer_identity(),
622 TensorStorage::Pbo(t) => t.buffer_identity(),
623 }
624 }
625}
626
627#[derive(Debug)]
633pub struct Tensor<T>
634where
635 T: Num + Clone + fmt::Debug + Send + Sync,
636{
637 pub(crate) storage: TensorStorage<T>,
638 format: Option<PixelFormat>,
639 chroma: Option<Box<Tensor<T>>>,
640 row_stride: Option<usize>,
643 plane_offset: Option<usize>,
646}
647
648impl<T> Tensor<T>
649where
650 T: Num + Clone + fmt::Debug + Send + Sync,
651{
652 pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
654 Self {
655 storage,
656 format: None,
657 chroma: None,
658 row_stride: None,
659 plane_offset: None,
660 }
661 }
662
663 pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
688 TensorStorage::new(shape, memory, name).map(Self::wrap)
689 }
690
691 pub fn image(
693 width: usize,
694 height: usize,
695 format: PixelFormat,
696 memory: Option<TensorMemory>,
697 ) -> Result<Self> {
698 let shape = match format.layout() {
699 PixelLayout::Packed => vec![height, width, format.channels()],
700 PixelLayout::Planar => vec![format.channels(), height, width],
701 PixelLayout::SemiPlanar => {
702 let total_h = match format {
706 PixelFormat::Nv12 => {
707 if !height.is_multiple_of(2) {
708 return Err(Error::InvalidArgument(format!(
709 "NV12 requires even height, got {height}"
710 )));
711 }
712 height * 3 / 2
713 }
714 PixelFormat::Nv16 => height * 2,
715 _ => {
716 return Err(Error::InvalidArgument(format!(
717 "unknown semi-planar height multiplier for {format:?}"
718 )))
719 }
720 };
721 vec![total_h, width]
722 }
723 };
724 let mut t = Self::new(&shape, memory, None)?;
725 t.format = Some(format);
726 Ok(t)
727 }
728
729 pub fn image_with_stride(
765 width: usize,
766 height: usize,
767 format: PixelFormat,
768 row_stride_bytes: usize,
769 memory: Option<TensorMemory>,
770 ) -> Result<Self> {
771 #[cfg(not(target_os = "linux"))]
781 {
782 let _ = (width, height, format, row_stride_bytes, memory);
783 Err(Error::NotImplemented(
784 "image_with_stride requires DMA support (Linux only)".to_owned(),
785 ))
786 }
787
788 #[cfg(target_os = "linux")]
789 {
790 if format.layout() != PixelLayout::Packed {
791 return Err(Error::NotImplemented(format!(
792 "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
793 )));
794 }
795 let elem = std::mem::size_of::<T>();
796 let min_stride = width
797 .checked_mul(format.channels())
798 .and_then(|p| p.checked_mul(elem))
799 .ok_or_else(|| {
800 Error::InvalidArgument(format!(
801 "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
802 overflows usize",
803 format.channels()
804 ))
805 })?;
806 if row_stride_bytes < min_stride {
807 return Err(Error::InvalidArgument(format!(
808 "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
809 ({width} px × {} ch × {elem} B)",
810 format.channels()
811 )));
812 }
813 let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
814 Error::InvalidArgument(format!(
815 "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
816 ))
817 })?;
818
819 let shape = vec![height, width, format.channels()];
820
821 let storage = match memory {
822 Some(TensorMemory::Dma) | None => {
823 TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
824 }
825 Some(other) => {
826 return Err(Error::NotImplemented(format!(
827 "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
828 )));
829 }
830 };
831
832 let mut t = Self::wrap(storage);
833 t.format = Some(format);
834 t.row_stride = Some(row_stride_bytes);
835 Ok(t)
836 }
837 }
838
839 pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
856 let shape = self.shape();
857 match format.layout() {
858 PixelLayout::Packed => {
859 if shape.len() != 3 || shape[2] != format.channels() {
860 return Err(Error::InvalidShape(format!(
861 "packed format {format:?} expects [H, W, {}], got {shape:?}",
862 format.channels()
863 )));
864 }
865 }
866 PixelLayout::Planar => {
867 if shape.len() != 3 || shape[0] != format.channels() {
868 return Err(Error::InvalidShape(format!(
869 "planar format {format:?} expects [{}, H, W], got {shape:?}",
870 format.channels()
871 )));
872 }
873 }
874 PixelLayout::SemiPlanar => {
875 if shape.len() != 2 {
876 return Err(Error::InvalidShape(format!(
877 "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
878 )));
879 }
880 match format {
881 PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
882 return Err(Error::InvalidShape(format!(
883 "NV12 contiguous shape[0] must be divisible by 3, got {}",
884 shape[0]
885 )));
886 }
887 PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
888 return Err(Error::InvalidShape(format!(
889 "NV16 contiguous shape[0] must be even, got {}",
890 shape[0]
891 )));
892 }
893 _ => {}
894 }
895 }
896 }
897 if self.format != Some(format) {
900 self.row_stride = None;
901 self.plane_offset = None;
902 #[cfg(target_os = "linux")]
903 if let TensorStorage::Dma(ref mut dma) = self.storage {
904 dma.mmap_offset = 0;
905 }
906 }
907 self.format = Some(format);
908 Ok(())
909 }
910
911 pub fn format(&self) -> Option<PixelFormat> {
913 self.format
914 }
915
916 pub fn width(&self) -> Option<usize> {
918 let fmt = self.format?;
919 let shape = self.shape();
920 match fmt.layout() {
921 PixelLayout::Packed => Some(shape[1]),
922 PixelLayout::Planar => Some(shape[2]),
923 PixelLayout::SemiPlanar => Some(shape[1]),
924 }
925 }
926
927 pub fn height(&self) -> Option<usize> {
929 let fmt = self.format?;
930 let shape = self.shape();
931 match fmt.layout() {
932 PixelLayout::Packed => Some(shape[0]),
933 PixelLayout::Planar => Some(shape[1]),
934 PixelLayout::SemiPlanar => {
935 if self.is_multiplane() {
936 Some(shape[0])
937 } else {
938 match fmt {
939 PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
940 PixelFormat::Nv16 => Some(shape[0] / 2),
941 _ => None,
942 }
943 }
944 }
945 }
946 }
947
948 pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
950 if format.layout() != PixelLayout::SemiPlanar {
951 return Err(Error::InvalidArgument(format!(
952 "from_planes requires a semi-planar format, got {format:?}"
953 )));
954 }
955 if chroma.format.is_some() || chroma.chroma.is_some() {
956 return Err(Error::InvalidArgument(
957 "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
958 ));
959 }
960 let luma_shape = luma.shape();
961 let chroma_shape = chroma.shape();
962 if luma_shape.len() != 2 || chroma_shape.len() != 2 {
963 return Err(Error::InvalidArgument(format!(
964 "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
965 )));
966 }
967 if luma_shape[1] != chroma_shape[1] {
968 return Err(Error::InvalidArgument(format!(
969 "luma width {} != chroma width {}",
970 luma_shape[1], chroma_shape[1]
971 )));
972 }
973 match format {
974 PixelFormat::Nv12 => {
975 if luma_shape[0] % 2 != 0 {
976 return Err(Error::InvalidArgument(format!(
977 "NV12 requires even luma height, got {}",
978 luma_shape[0]
979 )));
980 }
981 if chroma_shape[0] != luma_shape[0] / 2 {
982 return Err(Error::InvalidArgument(format!(
983 "NV12 chroma height {} != luma height / 2 ({})",
984 chroma_shape[0],
985 luma_shape[0] / 2
986 )));
987 }
988 }
989 PixelFormat::Nv16 => {
990 if chroma_shape[0] != luma_shape[0] {
991 return Err(Error::InvalidArgument(format!(
992 "NV16 chroma height {} != luma height {}",
993 chroma_shape[0], luma_shape[0]
994 )));
995 }
996 }
997 _ => {
998 return Err(Error::InvalidArgument(format!(
999 "from_planes only supports NV12 and NV16, got {format:?}"
1000 )));
1001 }
1002 }
1003
1004 Ok(Tensor {
1005 storage: luma.storage,
1006 format: Some(format),
1007 chroma: Some(Box::new(chroma)),
1008 row_stride: luma.row_stride,
1009 plane_offset: luma.plane_offset,
1010 })
1011 }
1012
1013 pub fn is_multiplane(&self) -> bool {
1015 self.chroma.is_some()
1016 }
1017
1018 pub fn chroma(&self) -> Option<&Tensor<T>> {
1020 self.chroma.as_deref()
1021 }
1022
1023 pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
1025 self.chroma.as_deref_mut()
1026 }
1027
1028 pub fn row_stride(&self) -> Option<usize> {
1030 self.row_stride
1031 }
1032
1033 pub fn effective_row_stride(&self) -> Option<usize> {
1038 if let Some(s) = self.row_stride {
1039 return Some(s);
1040 }
1041 let fmt = self.format?;
1042 let w = self.width()?;
1043 let elem = std::mem::size_of::<T>();
1044 Some(match fmt.layout() {
1045 PixelLayout::Packed => w * fmt.channels() * elem,
1046 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1047 })
1048 }
1049
1050 pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
1077 let fmt = self.format.ok_or_else(|| {
1078 Error::InvalidArgument("cannot set row_stride without a pixel format".into())
1079 })?;
1080 let w = self.width().ok_or_else(|| {
1081 Error::InvalidArgument("cannot determine width for row_stride validation".into())
1082 })?;
1083 let elem = std::mem::size_of::<T>();
1084 let min_stride = match fmt.layout() {
1085 PixelLayout::Packed => w * fmt.channels() * elem,
1086 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1087 };
1088 if stride < min_stride {
1089 return Err(Error::InvalidArgument(format!(
1090 "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
1091 )));
1092 }
1093 self.row_stride = Some(stride);
1094 Ok(())
1095 }
1096
1097 pub fn set_row_stride_unchecked(&mut self, stride: usize) {
1103 self.row_stride = Some(stride);
1104 }
1105
1106 pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
1113 self.set_row_stride(stride)?;
1114 Ok(self)
1115 }
1116
1117 pub fn plane_offset(&self) -> Option<usize> {
1119 self.plane_offset
1120 }
1121
1122 pub fn set_plane_offset(&mut self, offset: usize) {
1128 self.plane_offset = Some(offset);
1129 #[cfg(target_os = "linux")]
1130 if let TensorStorage::Dma(ref mut dma) = self.storage {
1131 dma.mmap_offset = offset;
1132 }
1133 }
1134
1135 pub fn with_plane_offset(mut self, offset: usize) -> Self {
1138 self.set_plane_offset(offset);
1139 self
1140 }
1141
1142 pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1144 match &self.storage {
1145 TensorStorage::Pbo(p) => Some(p),
1146 _ => None,
1147 }
1148 }
1149
1150 #[cfg(target_os = "linux")]
1152 pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1153 match &self.storage {
1154 TensorStorage::Dma(d) => Some(d),
1155 _ => None,
1156 }
1157 }
1158
1159 #[cfg(target_os = "linux")]
1170 pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1171 use std::os::fd::AsFd;
1172 match &self.storage {
1173 TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1174 _ => Err(Error::NotImplemented(format!(
1175 "dmabuf requires DMA-backed tensor, got {:?}",
1176 self.storage.memory()
1177 ))),
1178 }
1179 }
1180
1181 pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1183 Self {
1184 storage: TensorStorage::Pbo(pbo),
1185 format: None,
1186 chroma: None,
1187 row_stride: None,
1188 plane_offset: None,
1189 }
1190 }
1191}
1192
1193impl<T> TensorTrait<T> for Tensor<T>
1194where
1195 T: Num + Clone + fmt::Debug + Send + Sync,
1196{
1197 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1198 where
1199 Self: Sized,
1200 {
1201 Self::new(shape, None, name)
1202 }
1203
1204 #[cfg(unix)]
1205 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1206 where
1207 Self: Sized,
1208 {
1209 Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1210 }
1211
1212 #[cfg(unix)]
1213 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1214 self.storage.clone_fd()
1215 }
1216
1217 fn memory(&self) -> TensorMemory {
1218 self.storage.memory()
1219 }
1220
1221 fn name(&self) -> String {
1222 self.storage.name()
1223 }
1224
1225 fn shape(&self) -> &[usize] {
1226 self.storage.shape()
1227 }
1228
1229 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1230 if self.chroma.is_some() {
1231 return Err(Error::InvalidOperation(
1232 "cannot reshape a multiplane tensor — decompose planes first".into(),
1233 ));
1234 }
1235 self.storage.reshape(shape)?;
1236 self.format = None;
1237 self.row_stride = None;
1238 self.plane_offset = None;
1239 #[cfg(target_os = "linux")]
1240 if let TensorStorage::Dma(ref mut dma) = self.storage {
1241 dma.mmap_offset = 0;
1242 }
1243 Ok(())
1244 }
1245
1246 fn map(&self) -> Result<TensorMap<T>> {
1247 #[cfg(target_os = "linux")]
1264 if let Some(stride) = self.row_stride {
1265 if let TensorStorage::Dma(dma) = &self.storage {
1266 if !dma.is_imported {
1267 let height = self.height().ok_or_else(|| {
1288 Error::InvalidOperation(
1289 "Tensor::map: strided DMA mapping requires a PixelFormat \
1290 so height() can be derived; set a format before mapping \
1291 or clear row_stride for raw tensor access"
1292 .into(),
1293 )
1294 })?;
1295 let total_bytes = stride.checked_mul(height).ok_or_else(|| {
1296 Error::InvalidOperation(format!(
1297 "Tensor::map: row_stride {stride} × height {height} overflows usize"
1298 ))
1299 })?;
1300 let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
1301 if total_bytes > available_bytes {
1302 return Err(Error::InvalidOperation(format!(
1303 "Tensor::map: strided mapping needs {total_bytes} bytes \
1304 but DMA buffer only has {available_bytes} available \
1305 (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
1306 the row_stride was likely set larger than the original allocation",
1307 dma.buf_size, dma.mmap_offset
1308 )));
1309 }
1310 return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
1311 }
1312 }
1313 return Err(Error::InvalidOperation(
1314 "CPU mapping of strided foreign tensors is not supported; \
1315 use GPU path only"
1316 .into(),
1317 ));
1318 }
1319 #[cfg(not(target_os = "linux"))]
1320 if self.row_stride.is_some() {
1321 return Err(Error::InvalidOperation(
1322 "CPU mapping of strided tensors is not supported on this \
1323 platform (DMA backing is Linux-only)"
1324 .into(),
1325 ));
1326 }
1327 if self.plane_offset.is_some_and(|o| o > 0) {
1331 #[cfg(target_os = "linux")]
1332 if !matches!(self.storage, TensorStorage::Dma(_)) {
1333 return Err(Error::InvalidOperation(
1334 "plane offset only supported for DMA tensors".into(),
1335 ));
1336 }
1337 #[cfg(not(target_os = "linux"))]
1338 return Err(Error::InvalidOperation(
1339 "plane offset only supported for DMA tensors".into(),
1340 ));
1341 }
1342 self.storage.map()
1343 }
1344
1345 fn buffer_identity(&self) -> &BufferIdentity {
1346 self.storage.buffer_identity()
1347 }
1348}
1349
1350pub enum TensorMap<T>
1351where
1352 T: Num + Clone + fmt::Debug,
1353{
1354 #[cfg(target_os = "linux")]
1355 Dma(DmaMap<T>),
1356 #[cfg(unix)]
1357 Shm(ShmMap<T>),
1358 Mem(MemMap<T>),
1359 Pbo(PboMap<T>),
1360}
1361
1362impl<T> TensorMapTrait<T> for TensorMap<T>
1363where
1364 T: Num + Clone + fmt::Debug,
1365{
1366 fn shape(&self) -> &[usize] {
1367 match self {
1368 #[cfg(target_os = "linux")]
1369 TensorMap::Dma(map) => map.shape(),
1370 #[cfg(unix)]
1371 TensorMap::Shm(map) => map.shape(),
1372 TensorMap::Mem(map) => map.shape(),
1373 TensorMap::Pbo(map) => map.shape(),
1374 }
1375 }
1376
1377 fn unmap(&mut self) {
1378 match self {
1379 #[cfg(target_os = "linux")]
1380 TensorMap::Dma(map) => map.unmap(),
1381 #[cfg(unix)]
1382 TensorMap::Shm(map) => map.unmap(),
1383 TensorMap::Mem(map) => map.unmap(),
1384 TensorMap::Pbo(map) => map.unmap(),
1385 }
1386 }
1387
1388 fn as_slice(&self) -> &[T] {
1389 match self {
1390 #[cfg(target_os = "linux")]
1391 TensorMap::Dma(map) => map.as_slice(),
1392 #[cfg(unix)]
1393 TensorMap::Shm(map) => map.as_slice(),
1394 TensorMap::Mem(map) => map.as_slice(),
1395 TensorMap::Pbo(map) => map.as_slice(),
1396 }
1397 }
1398
1399 fn as_mut_slice(&mut self) -> &mut [T] {
1400 match self {
1401 #[cfg(target_os = "linux")]
1402 TensorMap::Dma(map) => map.as_mut_slice(),
1403 #[cfg(unix)]
1404 TensorMap::Shm(map) => map.as_mut_slice(),
1405 TensorMap::Mem(map) => map.as_mut_slice(),
1406 TensorMap::Pbo(map) => map.as_mut_slice(),
1407 }
1408 }
1409}
1410
1411impl<T> Deref for TensorMap<T>
1412where
1413 T: Num + Clone + fmt::Debug,
1414{
1415 type Target = [T];
1416
1417 fn deref(&self) -> &[T] {
1418 match self {
1419 #[cfg(target_os = "linux")]
1420 TensorMap::Dma(map) => map.deref(),
1421 #[cfg(unix)]
1422 TensorMap::Shm(map) => map.deref(),
1423 TensorMap::Mem(map) => map.deref(),
1424 TensorMap::Pbo(map) => map.deref(),
1425 }
1426 }
1427}
1428
1429impl<T> DerefMut for TensorMap<T>
1430where
1431 T: Num + Clone + fmt::Debug,
1432{
1433 fn deref_mut(&mut self) -> &mut [T] {
1434 match self {
1435 #[cfg(target_os = "linux")]
1436 TensorMap::Dma(map) => map.deref_mut(),
1437 #[cfg(unix)]
1438 TensorMap::Shm(map) => map.deref_mut(),
1439 TensorMap::Mem(map) => map.deref_mut(),
1440 TensorMap::Pbo(map) => map.deref_mut(),
1441 }
1442 }
1443}
1444
1445#[cfg(target_os = "linux")]
1457static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1458
1459#[cfg(target_os = "linux")]
1461pub fn is_dma_available() -> bool {
1462 *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1463}
1464
1465#[cfg(not(target_os = "linux"))]
1469pub fn is_dma_available() -> bool {
1470 false
1471}
1472
1473#[cfg(unix)]
1480static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1481
1482#[cfg(unix)]
1484pub fn is_shm_available() -> bool {
1485 *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
1486}
1487
1488#[cfg(not(unix))]
1492pub fn is_shm_available() -> bool {
1493 false
1494}
1495
1496#[cfg(test)]
1497mod dtype_tests {
1498 use super::*;
1499
1500 #[test]
1501 fn dtype_size() {
1502 assert_eq!(DType::U8.size(), 1);
1503 assert_eq!(DType::I8.size(), 1);
1504 assert_eq!(DType::U16.size(), 2);
1505 assert_eq!(DType::I16.size(), 2);
1506 assert_eq!(DType::U32.size(), 4);
1507 assert_eq!(DType::I32.size(), 4);
1508 assert_eq!(DType::U64.size(), 8);
1509 assert_eq!(DType::I64.size(), 8);
1510 assert_eq!(DType::F16.size(), 2);
1511 assert_eq!(DType::F32.size(), 4);
1512 assert_eq!(DType::F64.size(), 8);
1513 }
1514
1515 #[test]
1516 fn dtype_name() {
1517 assert_eq!(DType::U8.name(), "u8");
1518 assert_eq!(DType::F16.name(), "f16");
1519 assert_eq!(DType::F32.name(), "f32");
1520 }
1521
1522 #[test]
1523 fn dtype_serde_roundtrip() {
1524 use serde_json;
1525 let dt = DType::F16;
1526 let json = serde_json::to_string(&dt).unwrap();
1527 let back: DType = serde_json::from_str(&json).unwrap();
1528 assert_eq!(dt, back);
1529 }
1530}
1531
1532#[cfg(test)]
1533mod image_tests {
1534 use super::*;
1535
1536 #[test]
1537 fn raw_tensor_has_no_format() {
1538 let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
1539 assert!(t.format().is_none());
1540 assert!(t.width().is_none());
1541 assert!(t.height().is_none());
1542 assert!(!t.is_multiplane());
1543 assert!(t.chroma().is_none());
1544 }
1545
1546 #[test]
1547 fn image_tensor_packed() {
1548 let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
1549 assert_eq!(t.format(), Some(PixelFormat::Rgba));
1550 assert_eq!(t.width(), Some(640));
1551 assert_eq!(t.height(), Some(480));
1552 assert_eq!(t.shape(), &[480, 640, 4]);
1553 assert!(!t.is_multiplane());
1554 }
1555
1556 #[test]
1557 fn image_tensor_planar() {
1558 let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
1559 assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
1560 assert_eq!(t.width(), Some(640));
1561 assert_eq!(t.height(), Some(480));
1562 assert_eq!(t.shape(), &[3, 480, 640]);
1563 }
1564
1565 #[test]
1566 fn image_tensor_semi_planar_contiguous() {
1567 let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
1568 assert_eq!(t.format(), Some(PixelFormat::Nv12));
1569 assert_eq!(t.width(), Some(640));
1570 assert_eq!(t.height(), Some(480));
1571 assert_eq!(t.shape(), &[720, 640]);
1573 assert!(!t.is_multiplane());
1574 }
1575
1576 #[test]
1577 #[cfg(target_os = "linux")]
1578 fn image_tensor_with_stride_preserves_logical_width() {
1579 if !is_dma_available() {
1581 eprintln!("SKIPPED: DMA heap not available");
1582 return;
1583 }
1584 let stride = 12032;
1586 let t = Tensor::<u8>::image_with_stride(
1587 3004,
1588 1688,
1589 PixelFormat::Rgba,
1590 stride,
1591 Some(TensorMemory::Dma),
1592 )
1593 .unwrap();
1594 assert_eq!(t.width(), Some(3004));
1596 assert_eq!(t.height(), Some(1688));
1597 assert_eq!(t.shape(), &[1688, 3004, 4]);
1598 assert_eq!(t.effective_row_stride(), Some(stride));
1600 use crate::TensorMapTrait;
1603 {
1604 let map = t.map().unwrap();
1605 assert!(
1606 map.as_slice().len() >= stride * 1688,
1607 "mapped buffer {} bytes < expected {}",
1608 map.as_slice().len(),
1609 stride * 1688
1610 );
1611 }
1612 {
1615 let mut map = t.map().unwrap();
1616 let slice = map.as_mut_slice();
1617 for y in 0..1688 {
1618 let row_start = y * stride;
1619 for x in 0..3004 {
1620 let p = row_start + x * 4;
1621 slice[p] = (y & 0xFF) as u8;
1622 slice[p + 1] = (x & 0xFF) as u8;
1623 slice[p + 2] = 0x42;
1624 slice[p + 3] = 0xFF;
1625 }
1626 }
1627 }
1628 {
1629 let map = t.map().unwrap();
1630 let slice = map.as_slice();
1631 assert_eq!(slice[0], 0x00);
1633 assert_eq!(slice[1], 0x00);
1634 assert_eq!(slice[2], 0x42);
1635 assert_eq!(slice[3], 0xFF);
1636 let mid = 100 * stride + 50 * 4;
1637 assert_eq!(slice[mid], 100);
1638 assert_eq!(slice[mid + 1], 50);
1639 assert_eq!(slice[mid + 2], 0x42);
1640 }
1641 }
1642
1643 #[test]
1644 #[cfg(target_os = "linux")]
1645 fn image_tensor_with_stride_rejects_foreign_strided_map() {
1646 if !is_dma_available() {
1654 eprintln!("SKIPPED: DMA heap not available");
1655 return;
1656 }
1657 let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
1659 let fd = backing.clone_fd().unwrap();
1660 let shape = [240usize, 320, 4];
1662 let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
1663 let mut t = Tensor::<u8>::wrap(storage);
1664 t.set_format(PixelFormat::Bgra).unwrap();
1665 t.set_row_stride(320 * 4).unwrap(); let err = t.map();
1667 assert!(
1668 matches!(err, Err(Error::InvalidOperation(_))),
1669 "foreign strided map should error"
1670 );
1671 }
1672
1673 #[test]
1674 #[cfg(target_os = "linux")]
1675 fn image_tensor_with_stride_map_rejects_tampered_stride() {
1676 if !is_dma_available() {
1683 eprintln!("SKIPPED: DMA heap not available");
1684 return;
1685 }
1686 let mut t = Tensor::<u8>::image_with_stride(
1689 640,
1690 480,
1691 PixelFormat::Rgba,
1692 3072,
1693 Some(TensorMemory::Dma),
1694 )
1695 .unwrap();
1696 t.set_row_stride(12288).unwrap();
1699 let err = t.map();
1701 assert!(
1702 matches!(err, Err(Error::InvalidOperation(_))),
1703 "map() with oversized stride must return InvalidOperation"
1704 );
1705 }
1706
1707 #[test]
1708 fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
1709 #[cfg(target_os = "linux")]
1716 {
1717 let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
1718 &[usize::MAX, 2, 2],
1719 usize::MAX,
1720 None,
1721 );
1722 assert!(
1723 matches!(err, Err(Error::InvalidArgument(_))),
1724 "new_with_byte_size must detect shape.product() overflow"
1725 );
1726 }
1727 }
1728
1729 #[test]
1730 #[cfg(target_os = "linux")]
1731 fn image_tensor_with_stride_rejects_too_small_stride() {
1732 let err = Tensor::<u8>::image_with_stride(
1734 640,
1735 480,
1736 PixelFormat::Rgba,
1737 2400,
1738 Some(TensorMemory::Dma),
1739 );
1740 assert!(matches!(err, Err(Error::InvalidArgument(_))));
1741 }
1742
1743 #[test]
1744 #[cfg(target_os = "linux")]
1745 fn image_tensor_with_stride_rejects_non_packed() {
1746 let err = Tensor::<u8>::image_with_stride(
1749 640,
1750 480,
1751 PixelFormat::Nv12,
1752 640,
1753 Some(TensorMemory::Dma),
1754 );
1755 assert!(matches!(err, Err(Error::NotImplemented(_))));
1756 }
1757
1758 #[test]
1759 fn set_format_valid() {
1760 let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
1761 assert!(t.format().is_none());
1762 t.set_format(PixelFormat::Rgb).unwrap();
1763 assert_eq!(t.format(), Some(PixelFormat::Rgb));
1764 assert_eq!(t.width(), Some(640));
1765 assert_eq!(t.height(), Some(480));
1766 }
1767
1768 #[test]
1769 fn set_format_invalid_shape() {
1770 let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
1771 let err = t.set_format(PixelFormat::Rgb);
1773 assert!(err.is_err());
1774 assert!(t.format().is_none());
1776 }
1777
1778 #[test]
1779 fn reshape_clears_format() {
1780 let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
1781 assert_eq!(t.format(), Some(PixelFormat::Rgba));
1782 t.reshape(&[480 * 640 * 4]).unwrap();
1784 assert!(t.format().is_none());
1785 }
1786
1787 #[test]
1788 fn from_planes_nv12() {
1789 let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
1790 let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
1791 let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
1792 assert_eq!(img.format(), Some(PixelFormat::Nv12));
1793 assert!(img.is_multiplane());
1794 assert!(img.chroma().is_some());
1795 assert_eq!(img.width(), Some(640));
1796 assert_eq!(img.height(), Some(480));
1797 }
1798
1799 #[test]
1800 fn from_planes_rejects_non_semiplanar() {
1801 let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
1802 let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
1803 let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
1804 assert!(err.is_err());
1805 }
1806
1807 #[test]
1808 fn reshape_multiplane_errors() {
1809 let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
1810 let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
1811 let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
1812 let err = img.reshape(&[480 * 640 + 240 * 640]);
1813 assert!(err.is_err());
1814 }
1815}
1816
1817#[cfg(test)]
1818mod tests {
1819 #[cfg(target_os = "linux")]
1820 use nix::unistd::{access, AccessFlags};
1821 #[cfg(target_os = "linux")]
1822 use std::io::Write as _;
1823 use std::sync::RwLock;
1824
1825 use super::*;
1826
1827 #[ctor::ctor]
1828 fn init() {
1829 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
1830 }
1831
1832 #[cfg(target_os = "linux")]
1834 macro_rules! function {
1835 () => {{
1836 fn f() {}
1837 fn type_name_of<T>(_: T) -> &'static str {
1838 std::any::type_name::<T>()
1839 }
1840 let name = type_name_of(f);
1841
1842 match &name[..name.len() - 3].rfind(':') {
1844 Some(pos) => &name[pos + 1..name.len() - 3],
1845 None => &name[..name.len() - 3],
1846 }
1847 }};
1848 }
1849
1850 #[test]
1851 #[cfg(target_os = "linux")]
1852 fn test_tensor() {
1853 let _lock = FD_LOCK.read().unwrap();
1854 let shape = vec![1];
1855 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1856 let dma_enabled = tensor.is_ok();
1857
1858 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1859 match dma_enabled {
1860 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1861 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1862 }
1863 }
1864
1865 #[test]
1866 #[cfg(all(unix, not(target_os = "linux")))]
1867 fn test_tensor() {
1868 let shape = vec![1];
1869 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1870 assert!(
1872 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1873 "Expected SHM or Mem on macOS, got {:?}",
1874 tensor.memory()
1875 );
1876 }
1877
1878 #[test]
1879 #[cfg(not(unix))]
1880 fn test_tensor() {
1881 let shape = vec![1];
1882 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1883 assert_eq!(tensor.memory(), TensorMemory::Mem);
1884 }
1885
1886 #[test]
1887 #[cfg(target_os = "linux")]
1888 fn test_dma_tensor() {
1889 let _lock = FD_LOCK.read().unwrap();
1890 match access(
1891 "/dev/dma_heap/linux,cma",
1892 AccessFlags::R_OK | AccessFlags::W_OK,
1893 ) {
1894 Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
1895 Err(_) => match access(
1896 "/dev/dma_heap/system",
1897 AccessFlags::R_OK | AccessFlags::W_OK,
1898 ) {
1899 Ok(_) => println!("/dev/dma_heap/system is available"),
1900 Err(e) => {
1901 writeln!(
1902 &mut std::io::stdout(),
1903 "[WARNING] DMA Heap is unavailable: {e}"
1904 )
1905 .unwrap();
1906 return;
1907 }
1908 },
1909 }
1910
1911 let shape = vec![2, 3, 4];
1912 let tensor =
1913 DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1914
1915 const DUMMY_VALUE: f32 = 12.34;
1916
1917 assert_eq!(tensor.memory(), TensorMemory::Dma);
1918 assert_eq!(tensor.name(), "test_tensor");
1919 assert_eq!(tensor.shape(), &shape);
1920 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1921 assert_eq!(tensor.len(), 2 * 3 * 4);
1922
1923 {
1924 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1925 tensor_map.fill(42.0);
1926 assert!(tensor_map.iter().all(|&x| x == 42.0));
1927 }
1928
1929 {
1930 let shared = Tensor::<f32>::from_fd(
1931 tensor
1932 .clone_fd()
1933 .expect("Failed to duplicate tensor file descriptor"),
1934 &shape,
1935 Some("test_tensor_shared"),
1936 )
1937 .expect("Failed to create tensor from fd");
1938
1939 assert_eq!(shared.memory(), TensorMemory::Dma);
1940 assert_eq!(shared.name(), "test_tensor_shared");
1941 assert_eq!(shared.shape(), &shape);
1942
1943 let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
1944 tensor_map.fill(DUMMY_VALUE);
1945 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1946 }
1947
1948 {
1949 let tensor_map = tensor.map().expect("Failed to map DMA memory");
1950 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1951 }
1952
1953 let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1954 assert_eq!(tensor.shape(), &shape);
1955 let new_shape = vec![3, 4, 4];
1956 assert!(
1957 tensor.reshape(&new_shape).is_err(),
1958 "Reshape should fail due to size mismatch"
1959 );
1960 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1961
1962 let new_shape = vec![2, 3, 4];
1963 tensor.reshape(&new_shape).expect("Reshape should succeed");
1964 assert_eq!(
1965 tensor.shape(),
1966 &new_shape,
1967 "Shape should be updated after successful reshape"
1968 );
1969
1970 {
1971 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1972 tensor_map.fill(1);
1973 assert!(tensor_map.iter().all(|&x| x == 1));
1974 }
1975
1976 {
1977 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1978 tensor_map[2] = 42;
1979 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1980 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1981 }
1982 }
1983
1984 #[test]
1985 #[cfg(unix)]
1986 fn test_shm_tensor() {
1987 let _lock = FD_LOCK.read().unwrap();
1988 let shape = vec![2, 3, 4];
1989 let tensor =
1990 ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1991 assert_eq!(tensor.shape(), &shape);
1992 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1993 assert_eq!(tensor.name(), "test_tensor");
1994
1995 const DUMMY_VALUE: f32 = 12.34;
1996 {
1997 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1998 tensor_map.fill(42.0);
1999 assert!(tensor_map.iter().all(|&x| x == 42.0));
2000 }
2001
2002 {
2003 let shared = Tensor::<f32>::from_fd(
2004 tensor
2005 .clone_fd()
2006 .expect("Failed to duplicate tensor file descriptor"),
2007 &shape,
2008 Some("test_tensor_shared"),
2009 )
2010 .expect("Failed to create tensor from fd");
2011
2012 assert_eq!(shared.memory(), TensorMemory::Shm);
2013 assert_eq!(shared.name(), "test_tensor_shared");
2014 assert_eq!(shared.shape(), &shape);
2015
2016 let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2017 tensor_map.fill(DUMMY_VALUE);
2018 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2019 }
2020
2021 {
2022 let tensor_map = tensor.map().expect("Failed to map shared memory");
2023 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2024 }
2025
2026 let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2027 assert_eq!(tensor.shape(), &shape);
2028 let new_shape = vec![3, 4, 4];
2029 assert!(
2030 tensor.reshape(&new_shape).is_err(),
2031 "Reshape should fail due to size mismatch"
2032 );
2033 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2034
2035 let new_shape = vec![2, 3, 4];
2036 tensor.reshape(&new_shape).expect("Reshape should succeed");
2037 assert_eq!(
2038 tensor.shape(),
2039 &new_shape,
2040 "Shape should be updated after successful reshape"
2041 );
2042
2043 {
2044 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2045 tensor_map.fill(1);
2046 assert!(tensor_map.iter().all(|&x| x == 1));
2047 }
2048
2049 {
2050 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2051 tensor_map[2] = 42;
2052 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2053 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2054 }
2055 }
2056
2057 #[test]
2058 fn test_mem_tensor() {
2059 let shape = vec![2, 3, 4];
2060 let tensor =
2061 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2062 assert_eq!(tensor.shape(), &shape);
2063 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2064 assert_eq!(tensor.name(), "test_tensor");
2065
2066 {
2067 let mut tensor_map = tensor.map().expect("Failed to map memory");
2068 tensor_map.fill(42.0);
2069 assert!(tensor_map.iter().all(|&x| x == 42.0));
2070 }
2071
2072 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2073 assert_eq!(tensor.shape(), &shape);
2074 let new_shape = vec![3, 4, 4];
2075 assert!(
2076 tensor.reshape(&new_shape).is_err(),
2077 "Reshape should fail due to size mismatch"
2078 );
2079 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2080
2081 let new_shape = vec![2, 3, 4];
2082 tensor.reshape(&new_shape).expect("Reshape should succeed");
2083 assert_eq!(
2084 tensor.shape(),
2085 &new_shape,
2086 "Shape should be updated after successful reshape"
2087 );
2088
2089 {
2090 let mut tensor_map = tensor.map().expect("Failed to map memory");
2091 tensor_map.fill(1);
2092 assert!(tensor_map.iter().all(|&x| x == 1));
2093 }
2094
2095 {
2096 let mut tensor_map = tensor.map().expect("Failed to map memory");
2097 tensor_map[2] = 42;
2098 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2099 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2100 }
2101 }
2102
2103 #[test]
2104 #[cfg(target_os = "linux")]
2105 fn test_dma_no_fd_leaks() {
2106 let _lock = FD_LOCK.write().unwrap();
2107 if !is_dma_available() {
2108 log::warn!(
2109 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2110 function!()
2111 );
2112 return;
2113 }
2114
2115 let proc = procfs::process::Process::myself()
2116 .expect("Failed to get current process using /proc/self");
2117
2118 let start_open_fds = proc
2119 .fd_count()
2120 .expect("Failed to get open file descriptor count");
2121
2122 for _ in 0..100 {
2123 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
2124 .expect("Failed to create tensor");
2125 let mut map = tensor.map().unwrap();
2126 map.as_mut_slice().fill(233);
2127 }
2128
2129 let end_open_fds = proc
2130 .fd_count()
2131 .expect("Failed to get open file descriptor count");
2132
2133 assert_eq!(
2134 start_open_fds, end_open_fds,
2135 "File descriptor leak detected: {} -> {}",
2136 start_open_fds, end_open_fds
2137 );
2138 }
2139
2140 #[test]
2141 #[cfg(target_os = "linux")]
2142 fn test_dma_from_fd_no_fd_leaks() {
2143 let _lock = FD_LOCK.write().unwrap();
2144 if !is_dma_available() {
2145 log::warn!(
2146 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2147 function!()
2148 );
2149 return;
2150 }
2151
2152 let proc = procfs::process::Process::myself()
2153 .expect("Failed to get current process using /proc/self");
2154
2155 let start_open_fds = proc
2156 .fd_count()
2157 .expect("Failed to get open file descriptor count");
2158
2159 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
2160
2161 for _ in 0..100 {
2162 let tensor =
2163 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2164 let mut map = tensor.map().unwrap();
2165 map.as_mut_slice().fill(233);
2166 }
2167 drop(orig);
2168
2169 let end_open_fds = proc.fd_count().unwrap();
2170
2171 assert_eq!(
2172 start_open_fds, end_open_fds,
2173 "File descriptor leak detected: {} -> {}",
2174 start_open_fds, end_open_fds
2175 );
2176 }
2177
2178 #[test]
2179 #[cfg(target_os = "linux")]
2180 fn test_shm_no_fd_leaks() {
2181 let _lock = FD_LOCK.write().unwrap();
2182 if !is_shm_available() {
2183 log::warn!(
2184 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2185 function!()
2186 );
2187 return;
2188 }
2189
2190 let proc = procfs::process::Process::myself()
2191 .expect("Failed to get current process using /proc/self");
2192
2193 let start_open_fds = proc
2194 .fd_count()
2195 .expect("Failed to get open file descriptor count");
2196
2197 for _ in 0..100 {
2198 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2199 .expect("Failed to create tensor");
2200 let mut map = tensor.map().unwrap();
2201 map.as_mut_slice().fill(233);
2202 }
2203
2204 let end_open_fds = proc
2205 .fd_count()
2206 .expect("Failed to get open file descriptor count");
2207
2208 assert_eq!(
2209 start_open_fds, end_open_fds,
2210 "File descriptor leak detected: {} -> {}",
2211 start_open_fds, end_open_fds
2212 );
2213 }
2214
2215 #[test]
2216 #[cfg(target_os = "linux")]
2217 fn test_shm_from_fd_no_fd_leaks() {
2218 let _lock = FD_LOCK.write().unwrap();
2219 if !is_shm_available() {
2220 log::warn!(
2221 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2222 function!()
2223 );
2224 return;
2225 }
2226
2227 let proc = procfs::process::Process::myself()
2228 .expect("Failed to get current process using /proc/self");
2229
2230 let start_open_fds = proc
2231 .fd_count()
2232 .expect("Failed to get open file descriptor count");
2233
2234 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
2235
2236 for _ in 0..100 {
2237 let tensor =
2238 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2239 let mut map = tensor.map().unwrap();
2240 map.as_mut_slice().fill(233);
2241 }
2242 drop(orig);
2243
2244 let end_open_fds = proc.fd_count().unwrap();
2245
2246 assert_eq!(
2247 start_open_fds, end_open_fds,
2248 "File descriptor leak detected: {} -> {}",
2249 start_open_fds, end_open_fds
2250 );
2251 }
2252
2253 #[cfg(feature = "ndarray")]
2254 #[test]
2255 fn test_ndarray() {
2256 let _lock = FD_LOCK.read().unwrap();
2257 let shape = vec![2, 3, 4];
2258 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2259
2260 let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
2261 tensor_map.fill(1.0);
2262
2263 let view = tensor_map.view().expect("Failed to get ndarray view");
2264 assert_eq!(view.shape(), &[2, 3, 4]);
2265 assert!(view.iter().all(|&x| x == 1.0));
2266
2267 let mut view_mut = tensor_map
2268 .view_mut()
2269 .expect("Failed to get mutable ndarray view");
2270 view_mut[[0, 0, 0]] = 42.0;
2271 assert_eq!(view_mut[[0, 0, 0]], 42.0);
2272 assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
2273 }
2274
2275 #[test]
2276 fn test_buffer_identity_unique() {
2277 let id1 = BufferIdentity::new();
2278 let id2 = BufferIdentity::new();
2279 assert_ne!(
2280 id1.id(),
2281 id2.id(),
2282 "Two identities should have different ids"
2283 );
2284 }
2285
2286 #[test]
2287 fn test_buffer_identity_clone_shares_guard() {
2288 let id1 = BufferIdentity::new();
2289 let weak = id1.weak();
2290 assert!(
2291 weak.upgrade().is_some(),
2292 "Weak should be alive while original exists"
2293 );
2294
2295 let id2 = id1.clone();
2296 assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
2297
2298 drop(id1);
2299 assert!(
2300 weak.upgrade().is_some(),
2301 "Weak should still be alive (clone holds Arc)"
2302 );
2303
2304 drop(id2);
2305 assert!(
2306 weak.upgrade().is_none(),
2307 "Weak should be dead after all clones dropped"
2308 );
2309 }
2310
2311 #[test]
2312 fn test_tensor_buffer_identity() {
2313 let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
2314 let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
2315 assert_ne!(
2316 t1.buffer_identity().id(),
2317 t2.buffer_identity().id(),
2318 "Different tensors should have different buffer ids"
2319 );
2320 }
2321
2322 pub static FD_LOCK: RwLock<()> = RwLock::new(());
2326
2327 #[test]
2330 #[cfg(not(target_os = "linux"))]
2331 fn test_dma_not_available_on_non_linux() {
2332 assert!(
2333 !is_dma_available(),
2334 "DMA memory allocation should NOT be available on non-Linux platforms"
2335 );
2336 }
2337
2338 #[test]
2341 #[cfg(unix)]
2342 fn test_shm_available_and_usable() {
2343 assert!(
2344 is_shm_available(),
2345 "SHM memory allocation should be available on Unix systems"
2346 );
2347
2348 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2350 .expect("Failed to create SHM tensor");
2351
2352 let mut map = tensor.map().expect("Failed to map SHM tensor");
2354 map.as_mut_slice().fill(0xAB);
2355
2356 assert!(
2358 map.as_slice().iter().all(|&b| b == 0xAB),
2359 "SHM tensor data should be writable and readable"
2360 );
2361 }
2362}