1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54 fmt,
55 ops::{Deref, DerefMut},
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc, Weak,
59 },
60};
61pub use tensor_dyn::TensorDyn;
62
/// Element data types supported by tensors.
///
/// `#[repr(u8)]` keeps the discriminant a single byte; `#[non_exhaustive]`
/// allows new dtypes to be added without breaking downstream matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}
80
81impl DType {
82 pub const fn size(&self) -> usize {
84 match self {
85 Self::U8 | Self::I8 => 1,
86 Self::U16 | Self::I16 | Self::F16 => 2,
87 Self::U32 | Self::I32 | Self::F32 => 4,
88 Self::U64 | Self::I64 | Self::F64 => 8,
89 }
90 }
91
92 pub const fn name(&self) -> &'static str {
94 match self {
95 Self::U8 => "u8",
96 Self::I8 => "i8",
97 Self::U16 => "u16",
98 Self::I16 => "i16",
99 Self::U32 => "u32",
100 Self::I32 => "i32",
101 Self::U64 => "u64",
102 Self::I64 => "i64",
103 Self::F16 => "f16",
104 Self::F32 => "f32",
105 Self::F64 => "f64",
106 }
107 }
108}
109
110impl fmt::Display for DType {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 f.write_str(self.name())
113 }
114}
115
// Monotonic counter handing out process-unique buffer ids (starts at 1).
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Identity of a tensor's backing buffer.
///
/// Pairs a process-unique numeric id with an `Arc<()>` guard; `weak()` hands
/// out a `Weak` handle that upgrades only while this identity (and every
/// clone of it) is still alive.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Allocates a fresh identity with the next process-unique id.
    pub fn new() -> Self {
        // Relaxed ordering suffices: only the uniqueness of the returned
        // value matters, not ordering with respect to other memory.
        let id = NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed);
        BufferIdentity {
            id,
            guard: Arc::new(()),
        }
    }

    /// The process-unique numeric id of this buffer.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Weak liveness handle: upgrades succeed only while the buffer's
    /// identity is alive.
    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}

impl Default for BufferIdentity {
    fn default() -> Self {
        BufferIdentity::new()
    }
}
156
157#[cfg(target_os = "linux")]
158use nix::sys::stat::{major, minor};
159
/// Common interface implemented by every tensor storage backend.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Allocates a new tensor with the given shape and optional debug name.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Wraps an existing file descriptor as a tensor of the given shape.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Duplicates the tensor's backing file descriptor so it can be handed
    /// to another owner (see `from_fd`).
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Memory backend that holds this tensor's data.
    fn memory(&self) -> TensorMemory;

    /// Debug/diagnostic name of the tensor.
    fn name(&self) -> String;

    /// Number of elements: product of the shape (1 for an empty shape,
    /// following the scalar convention of `Iterator::product`).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the tensor holds no elements (some dimension is 0).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total size of the element data in bytes.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Dimensions of the tensor.
    fn shape(&self) -> &[usize];

    /// Changes the shape in place; implementations reject shapes whose
    /// element count does not match the allocation.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Maps the tensor's memory for CPU access.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Stable identity of the underlying buffer.
    fn buffer_identity(&self) -> &BufferIdentity;
}
219
/// Common interface for mapped (CPU-accessible) tensor memory.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Dimensions of the mapped tensor.
    fn shape(&self) -> &[usize];

    /// Explicitly releases the mapping.
    fn unmap(&mut self);

    /// Number of elements: product of the shape.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the mapping holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total size of the mapped data in bytes.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Read-only view of the mapped elements.
    fn as_slice(&self) -> &[T];

    /// Mutable view of the mapped elements.
    fn as_mut_slice(&mut self) -> &mut [T];

    /// Read-only `ndarray` view over the mapped data.
    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    /// Mutable `ndarray` view over the mapped data.
    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Copy the shape out first: `as_mut_slice` borrows self mutably, so
        // the shape slice cannot stay borrowed at the same time.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
272
/// Memory backend used for a tensor's storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Linux DMA-heap buffer (allocated from /dev/dma_heap — see tests).
    #[cfg(target_os = "linux")]
    Dma,
    /// Shared-memory buffer (unix only).
    #[cfg(unix)]
    Shm,

    /// Plain process-local heap memory.
    Mem,

    /// Pixel-buffer-object storage; created via `ImageProcessor::create_image()`,
    /// never through `Tensor::new()` (see `TensorStorage::new`).
    Pbo,
}
292
293impl From<TensorMemory> for String {
294 fn from(memory: TensorMemory) -> Self {
295 match memory {
296 #[cfg(target_os = "linux")]
297 TensorMemory::Dma => "dma".to_owned(),
298 #[cfg(unix)]
299 TensorMemory::Shm => "shm".to_owned(),
300 TensorMemory::Mem => "mem".to_owned(),
301 TensorMemory::Pbo => "pbo".to_owned(),
302 }
303 }
304}
305
306impl TryFrom<&str> for TensorMemory {
307 type Error = Error;
308
309 fn try_from(s: &str) -> Result<Self> {
310 match s {
311 #[cfg(target_os = "linux")]
312 "dma" => Ok(TensorMemory::Dma),
313 #[cfg(unix)]
314 "shm" => Ok(TensorMemory::Shm),
315 "mem" => Ok(TensorMemory::Mem),
316 "pbo" => Ok(TensorMemory::Pbo),
317 _ => Err(Error::InvalidMemoryType(s.to_owned())),
318 }
319 }
320}
321
/// Internal storage dispatch: one variant per memory backend.
// NOTE(review): `dead_code` is presumably allowed because not every variant
// is constructed on every platform/feature combination — confirm.
#[derive(Debug)]
#[allow(dead_code)] pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}
335
336impl<T> TensorStorage<T>
337where
338 T: Num + Clone + fmt::Debug + Send + Sync,
339{
340 fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
345 match memory {
346 #[cfg(target_os = "linux")]
347 Some(TensorMemory::Dma) => {
348 DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
349 }
350 #[cfg(unix)]
351 Some(TensorMemory::Shm) => {
352 ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
353 }
354 Some(TensorMemory::Mem) => {
355 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
356 }
357 Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
358 "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
359 )),
360 None => {
361 if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
362 .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
363 {
364 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
365 } else {
366 #[cfg(target_os = "linux")]
367 {
368 match DmaTensor::<T>::new(shape, name) {
370 Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
371 Err(_) => {
372 match ShmTensor::<T>::new(shape, name)
373 .map(TensorStorage::Shm)
374 {
375 Ok(tensor) => Ok(tensor),
376 Err(_) => MemTensor::<T>::new(shape, name)
377 .map(TensorStorage::Mem),
378 }
379 }
380 }
381 }
382 #[cfg(all(unix, not(target_os = "linux")))]
383 {
384 match ShmTensor::<T>::new(shape, name) {
386 Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
387 Err(_) => {
388 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
389 }
390 }
391 }
392 #[cfg(not(unix))]
393 {
394 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
396 }
397 }
398 }
399 }
400 }
401
402 #[cfg(unix)]
405 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
406 #[cfg(target_os = "linux")]
407 {
408 use nix::sys::stat::fstat;
409
410 let stat = fstat(&fd)?;
411 let major = major(stat.st_dev);
412 let minor = minor(stat.st_dev);
413
414 log::debug!("Creating tensor from fd: major={major}, minor={minor}");
415
416 if major != 0 {
417 return Err(Error::UnknownDeviceType(major, minor));
419 }
420
421 match minor {
422 9 | 10 => {
423 DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
425 }
426 _ => {
427 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
429 }
430 }
431 }
432 #[cfg(all(unix, not(target_os = "linux")))]
433 {
434 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
436 }
437 }
438}
439
/// Trait implementation that fans each call out to the concrete backend.
impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        // Resolves to the inherent 3-argument `TensorStorage::new` (inherent
        // methods take precedence over trait methods), with no backend
        // preference.
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        // Inherent `TensorStorage::from_fd` — not a recursive trait call.
        Self::from_fd(fd, shape, name)
    }

    // Delegation note: the `Shm` arm needs no `#[cfg(unix)]` here because
    // the whole function only exists on unix.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}
530
/// User-facing tensor: backend storage plus optional image metadata.
///
/// `format` tags the tensor as an image of a given pixel format; `chroma`
/// holds the separate UV plane for multiplane images built via
/// `Tensor::from_planes`.
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    format: Option<PixelFormat>,
    chroma: Option<Box<Tensor<T>>>,
}
545
impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wraps backend storage in a `Tensor` with no image metadata.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
        }
    }

    /// Allocates a tensor of `shape`, optionally forcing a memory backend
    /// and attaching a debug name.
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

    /// Allocates an image tensor for `width` x `height` pixels of `format`,
    /// deriving the storage shape from the format's layout:
    /// packed -> [H, W, C]; planar -> [C, H, W]; semi-planar -> a single
    /// contiguous [H*k, W] plane (NV12: k = 3/2, NV16: k = 2).
    ///
    /// # Errors
    /// NV12 requires an even height; semi-planar formats other than
    /// NV12/NV16 are rejected.
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                // Semi-planar images are one contiguous plane: chroma rows
                // appended below the luma rows.
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

    /// Tags an existing tensor as an image of `format` after validating
    /// that the current shape matches the format's expected layout.
    ///
    /// # Errors
    /// Returns `InvalidShape` on a mismatch; the tensor's metadata is left
    /// unchanged in that case.
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                // For the contiguous layout, shape[0] is H*3/2 (NV12) or
                // H*2 (NV16), so it must divide evenly.
                match format {
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        self.format = Some(format);
        Ok(())
    }

    /// The pixel format, when this tensor has been tagged as an image.
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    /// Image width in pixels, derived from the shape per layout; `None`
    /// for raw (untagged) tensors.
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    /// Image height in pixels; `None` for raw tensors.
    ///
    /// For a contiguous semi-planar plane the luma height is recovered
    /// from the combined plane height (NV12: H*2/3, NV16: H/2); for a
    /// multiplane image shape[0] is already the luma height.
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    Some(shape[0])
                } else {
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

    /// Builds a multiplane semi-planar image from separate luma and chroma
    /// tensors: the luma tensor's storage becomes the primary plane and the
    /// chroma tensor is kept as a secondary plane.
    ///
    /// # Errors
    /// Rejects non-semi-planar formats, chroma tensors that already carry
    /// metadata, non-2D planes, mismatched widths, and plane heights
    /// inconsistent with NV12/NV16 subsampling.
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        // Validate the chroma plane height against the format's vertical
        // subsampling (NV12: half the luma rows; NV16: same row count).
        match format {
            PixelFormat::Nv12 => {
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
        })
    }

    /// True when this image keeps its chroma in a separate plane.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }

    /// The separate chroma plane, when present.
    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }

    /// Downcast to the PBO backend, when applicable.
    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }

    /// Downcast to the DMA backend, when applicable.
    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }

    /// Borrows the underlying dmabuf file descriptor.
    ///
    /// # Errors
    /// `NotImplemented` when the tensor is not DMA-backed.
    #[cfg(target_os = "linux")]
    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
        use std::os::fd::AsFd;
        match &self.storage {
            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
            _ => Err(Error::NotImplemented(format!(
                "dmabuf requires DMA-backed tensor, got {:?}",
                self.storage.memory()
            ))),
        }
    }

    /// Wraps an existing PBO tensor; the result carries no image metadata.
    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
        Self {
            storage: TensorStorage::Pbo(pbo),
            format: None,
            chroma: None,
        }
    }
}
845
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection (resolves to the inherent
    /// 3-argument `Tensor::new` with `memory = None`).
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    /// Wraps an inherited file descriptor; the result carries no image
    /// metadata.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    /// Reshapes the storage and clears any image format tag.
    ///
    /// # Errors
    /// Refuses to reshape a multiplane image (the separate chroma plane
    /// would go stale); storage errors (e.g. size mismatch) propagate.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        // The previously validated pixel format may no longer match the new
        // shape, so it is always dropped.
        self.format = None;
        Ok(())
    }

    fn map(&self) -> Result<TensorMap<T>> {
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
901
/// Mapped (CPU-accessible) view of a tensor: one variant per backend.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}
913
/// Trait implementation that fans each call out to the concrete map type.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
962
// Lets a map be used directly as a slice (indexing, iteration, `fill`),
// delegating to each backend map's own `Deref`.
impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}
980
// Mutable counterpart of the `Deref` impl above.
impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}
996
// Cached result of the one-time DMA availability probe.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Returns true when DMA-heap tensors can be allocated on this system.
/// The first call probes by allocating a small DMA tensor; the outcome is
/// cached for the lifetime of the process.
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}

/// DMA-heap buffers are Linux-only; always false elsewhere.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
1024
// Cached result of the one-time SHM availability probe.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Returns true when shared-memory tensors can be allocated on this system.
/// The first call probes by allocating a small SHM tensor; the outcome is
/// cached for the lifetime of the process.
#[cfg(unix)]
pub fn is_shm_available() -> bool {
    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}

/// Shared-memory tensors are unix-only; always false elsewhere.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
1047
#[cfg(test)]
mod dtype_tests {
    use super::*;

    // Every dtype variant reports its exact element byte width.
    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    // Spot-check canonical lowercase names.
    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    // serde round-trip preserves the variant.
    #[test]
    fn dtype_serde_roundtrip() {
        use serde_json;
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1083
#[cfg(test)]
mod image_tests {
    use super::*;

    // A tensor created without a format carries no image metadata at all.
    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    // Packed layout stores [H, W, C].
    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    // Planar layout stores [C, H, W].
    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    // NV12 built via `image()` is a single contiguous [H*3/2, W] plane.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    // Tagging a raw tensor with a matching format succeeds and enables
    // width/height queries.
    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    // A channel-count mismatch is rejected and leaves the tensor untagged.
    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        assert!(t.format().is_none());
    }

    // Reshaping always drops the format tag, even to a compatible size.
    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    // Separate luma + half-height chroma planes combine into NV12.
    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    // from_planes only accepts semi-planar formats.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    // Multiplane images refuse reshape (chroma plane would go stale).
    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
1186
1187#[cfg(test)]
1188mod tests {
1189 #[cfg(target_os = "linux")]
1190 use nix::unistd::{access, AccessFlags};
1191 #[cfg(target_os = "linux")]
1192 use std::io::Write as _;
1193 use std::sync::RwLock;
1194
1195 use super::*;
1196
    // One-time logger setup for the test binary (runs before any test via
    // the `ctor` crate).
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }

    // Expands to the name of the enclosing function: takes the type name of
    // a local fn `f`, strips the trailing "::f" (3 chars), and keeps only
    // the last path segment.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
1219
    // Linux: automatic backend selection picks DMA when a heap is usable,
    // otherwise falls back to SHM.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![1];
        // Probe whether DMA allocation works in this environment.
        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
        let dma_enabled = tensor.is_ok();

        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        match dma_enabled {
            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
        }
    }

    // Non-Linux unix: selection falls back to SHM or plain heap memory.
    #[test]
    #[cfg(all(unix, not(target_os = "linux")))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert!(
            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
            "Expected SHM or Mem on macOS, got {:?}",
            tensor.memory()
        );
    }

    // Non-unix: plain heap memory is the only backend.
    #[test]
    #[cfg(not(unix))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert_eq!(tensor.memory(), TensorMemory::Mem);
    }
1255
    // End-to-end DMA backend test: allocation, metadata, mapping, fd
    // sharing, and reshape semantics. Skips gracefully when no DMA heap
    // device is accessible.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Require read/write access to a known dma_heap device, else skip.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Map and fill through the original tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd and write through the copy.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes through the shared fd are visible in the original buffer.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: size mismatch rejected, matching size accepted.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Writes persist across successive mappings.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1353
    // End-to-end SHM backend test mirroring test_dma_tensor: allocation,
    // metadata, mapping, fd sharing, and reshape semantics.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Map and fill through the original tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd and write through the copy.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes through the shared fd are visible in the original buffer.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: size mismatch rejected, matching size accepted.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Writes persist across successive mappings.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1426
1427 #[test]
1428 fn test_mem_tensor() {
1429 let shape = vec![2, 3, 4];
1430 let tensor =
1431 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1432 assert_eq!(tensor.shape(), &shape);
1433 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1434 assert_eq!(tensor.name(), "test_tensor");
1435
1436 {
1437 let mut tensor_map = tensor.map().expect("Failed to map memory");
1438 tensor_map.fill(42.0);
1439 assert!(tensor_map.iter().all(|&x| x == 42.0));
1440 }
1441
1442 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1443 assert_eq!(tensor.shape(), &shape);
1444 let new_shape = vec![3, 4, 4];
1445 assert!(
1446 tensor.reshape(&new_shape).is_err(),
1447 "Reshape should fail due to size mismatch"
1448 );
1449 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1450
1451 let new_shape = vec![2, 3, 4];
1452 tensor.reshape(&new_shape).expect("Reshape should succeed");
1453 assert_eq!(
1454 tensor.shape(),
1455 &new_shape,
1456 "Shape should be updated after successful reshape"
1457 );
1458
1459 {
1460 let mut tensor_map = tensor.map().expect("Failed to map memory");
1461 tensor_map.fill(1);
1462 assert!(tensor_map.iter().all(|&x| x == 1));
1463 }
1464
1465 {
1466 let mut tensor_map = tensor.map().expect("Failed to map memory");
1467 tensor_map[2] = 42;
1468 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1469 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1470 }
1471 }
1472
1473 #[test]
1474 #[cfg(target_os = "linux")]
1475 fn test_dma_no_fd_leaks() {
1476 let _lock = FD_LOCK.write().unwrap();
1477 if !is_dma_available() {
1478 log::warn!(
1479 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1480 function!()
1481 );
1482 return;
1483 }
1484
1485 let proc = procfs::process::Process::myself()
1486 .expect("Failed to get current process using /proc/self");
1487
1488 let start_open_fds = proc
1489 .fd_count()
1490 .expect("Failed to get open file descriptor count");
1491
1492 for _ in 0..100 {
1493 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
1494 .expect("Failed to create tensor");
1495 let mut map = tensor.map().unwrap();
1496 map.as_mut_slice().fill(233);
1497 }
1498
1499 let end_open_fds = proc
1500 .fd_count()
1501 .expect("Failed to get open file descriptor count");
1502
1503 assert_eq!(
1504 start_open_fds, end_open_fds,
1505 "File descriptor leak detected: {} -> {}",
1506 start_open_fds, end_open_fds
1507 );
1508 }
1509
1510 #[test]
1511 #[cfg(target_os = "linux")]
1512 fn test_dma_from_fd_no_fd_leaks() {
1513 let _lock = FD_LOCK.write().unwrap();
1514 if !is_dma_available() {
1515 log::warn!(
1516 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1517 function!()
1518 );
1519 return;
1520 }
1521
1522 let proc = procfs::process::Process::myself()
1523 .expect("Failed to get current process using /proc/self");
1524
1525 let start_open_fds = proc
1526 .fd_count()
1527 .expect("Failed to get open file descriptor count");
1528
1529 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
1530
1531 for _ in 0..100 {
1532 let tensor =
1533 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1534 let mut map = tensor.map().unwrap();
1535 map.as_mut_slice().fill(233);
1536 }
1537 drop(orig);
1538
1539 let end_open_fds = proc.fd_count().unwrap();
1540
1541 assert_eq!(
1542 start_open_fds, end_open_fds,
1543 "File descriptor leak detected: {} -> {}",
1544 start_open_fds, end_open_fds
1545 );
1546 }
1547
1548 #[test]
1549 #[cfg(target_os = "linux")]
1550 fn test_shm_no_fd_leaks() {
1551 let _lock = FD_LOCK.write().unwrap();
1552 if !is_shm_available() {
1553 log::warn!(
1554 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1555 function!()
1556 );
1557 return;
1558 }
1559
1560 let proc = procfs::process::Process::myself()
1561 .expect("Failed to get current process using /proc/self");
1562
1563 let start_open_fds = proc
1564 .fd_count()
1565 .expect("Failed to get open file descriptor count");
1566
1567 for _ in 0..100 {
1568 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1569 .expect("Failed to create tensor");
1570 let mut map = tensor.map().unwrap();
1571 map.as_mut_slice().fill(233);
1572 }
1573
1574 let end_open_fds = proc
1575 .fd_count()
1576 .expect("Failed to get open file descriptor count");
1577
1578 assert_eq!(
1579 start_open_fds, end_open_fds,
1580 "File descriptor leak detected: {} -> {}",
1581 start_open_fds, end_open_fds
1582 );
1583 }
1584
1585 #[test]
1586 #[cfg(target_os = "linux")]
1587 fn test_shm_from_fd_no_fd_leaks() {
1588 let _lock = FD_LOCK.write().unwrap();
1589 if !is_shm_available() {
1590 log::warn!(
1591 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1592 function!()
1593 );
1594 return;
1595 }
1596
1597 let proc = procfs::process::Process::myself()
1598 .expect("Failed to get current process using /proc/self");
1599
1600 let start_open_fds = proc
1601 .fd_count()
1602 .expect("Failed to get open file descriptor count");
1603
1604 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
1605
1606 for _ in 0..100 {
1607 let tensor =
1608 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1609 let mut map = tensor.map().unwrap();
1610 map.as_mut_slice().fill(233);
1611 }
1612 drop(orig);
1613
1614 let end_open_fds = proc.fd_count().unwrap();
1615
1616 assert_eq!(
1617 start_open_fds, end_open_fds,
1618 "File descriptor leak detected: {} -> {}",
1619 start_open_fds, end_open_fds
1620 );
1621 }
1622
1623 #[cfg(feature = "ndarray")]
1624 #[test]
1625 fn test_ndarray() {
1626 let _lock = FD_LOCK.read().unwrap();
1627 let shape = vec![2, 3, 4];
1628 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1629
1630 let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
1631 tensor_map.fill(1.0);
1632
1633 let view = tensor_map.view().expect("Failed to get ndarray view");
1634 assert_eq!(view.shape(), &[2, 3, 4]);
1635 assert!(view.iter().all(|&x| x == 1.0));
1636
1637 let mut view_mut = tensor_map
1638 .view_mut()
1639 .expect("Failed to get mutable ndarray view");
1640 view_mut[[0, 0, 0]] = 42.0;
1641 assert_eq!(view_mut[[0, 0, 0]], 42.0);
1642 assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
1643 }
1644
1645 #[test]
1646 fn test_buffer_identity_unique() {
1647 let id1 = BufferIdentity::new();
1648 let id2 = BufferIdentity::new();
1649 assert_ne!(
1650 id1.id(),
1651 id2.id(),
1652 "Two identities should have different ids"
1653 );
1654 }
1655
1656 #[test]
1657 fn test_buffer_identity_clone_shares_guard() {
1658 let id1 = BufferIdentity::new();
1659 let weak = id1.weak();
1660 assert!(
1661 weak.upgrade().is_some(),
1662 "Weak should be alive while original exists"
1663 );
1664
1665 let id2 = id1.clone();
1666 assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1667
1668 drop(id1);
1669 assert!(
1670 weak.upgrade().is_some(),
1671 "Weak should still be alive (clone holds Arc)"
1672 );
1673
1674 drop(id2);
1675 assert!(
1676 weak.upgrade().is_none(),
1677 "Weak should be dead after all clones dropped"
1678 );
1679 }
1680
1681 #[test]
1682 fn test_tensor_buffer_identity() {
1683 let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1684 let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1685 assert_ne!(
1686 t1.buffer_identity().id(),
1687 t2.buffer_identity().id(),
1688 "Different tensors should have different buffer ids"
1689 );
1690 }
1691
    /// Serializes tests that open file descriptors: the fd-leak tests take
    /// the write lock so no other test can open/close fds while they compare
    /// `/proc/self` fd counts, while ordinary fd-opening tests take the read
    /// lock and may run concurrently with each other.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1696
1697 #[test]
1700 #[cfg(not(target_os = "linux"))]
1701 fn test_dma_not_available_on_non_linux() {
1702 assert!(
1703 !is_dma_available(),
1704 "DMA memory allocation should NOT be available on non-Linux platforms"
1705 );
1706 }
1707
1708 #[test]
1711 #[cfg(unix)]
1712 fn test_shm_available_and_usable() {
1713 assert!(
1714 is_shm_available(),
1715 "SHM memory allocation should be available on Unix systems"
1716 );
1717
1718 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1720 .expect("Failed to create SHM tensor");
1721
1722 let mut map = tensor.map().expect("Failed to map SHM tensor");
1724 map.as_mut_slice().fill(0xAB);
1725
1726 assert!(
1728 map.as_slice().iter().all(|&b| b == 0xAB),
1729 "SHM tensor data should be writable and readable"
1730 );
1731 }
1732}