1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54 fmt,
55 ops::{Deref, DerefMut},
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc, Weak,
59 },
60};
61pub use tensor_dyn::TensorDyn;
62
/// Element data types a tensor can hold.
///
/// `#[repr(u8)]` keeps the discriminant compact; `#[non_exhaustive]` lets
/// new dtypes be added without breaking downstream matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}
80
81impl DType {
82 pub const fn size(&self) -> usize {
84 match self {
85 Self::U8 | Self::I8 => 1,
86 Self::U16 | Self::I16 | Self::F16 => 2,
87 Self::U32 | Self::I32 | Self::F32 => 4,
88 Self::U64 | Self::I64 | Self::F64 => 8,
89 }
90 }
91
92 pub const fn name(&self) -> &'static str {
94 match self {
95 Self::U8 => "u8",
96 Self::I8 => "i8",
97 Self::U16 => "u16",
98 Self::I16 => "i16",
99 Self::U32 => "u32",
100 Self::I32 => "i32",
101 Self::U64 => "u64",
102 Self::I64 => "i64",
103 Self::F16 => "f16",
104 Self::F32 => "f32",
105 Self::F64 => "f64",
106 }
107 }
108}
109
110impl fmt::Display for DType {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 f.write_str(self.name())
113 }
114}
115
// Monotonic counter backing BufferIdentity ids; starts at 1 so id 0 is never issued.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
118
/// Process-unique identity of a tensor's underlying buffer.
///
/// Cloning shares the same identity. The `guard` `Arc` acts as a liveness
/// token: weak handles from [`BufferIdentity::weak`] stop upgrading once
/// every clone of the identity has been dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    // Unique id drawn from NEXT_BUFFER_ID at construction.
    id: u64,
    // Liveness token observed via Weak handles.
    guard: Arc<()>,
}
129
130impl BufferIdentity {
131 pub fn new() -> Self {
133 Self {
134 id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
135 guard: Arc::new(()),
136 }
137 }
138
139 pub fn id(&self) -> u64 {
141 self.id
142 }
143
144 pub fn weak(&self) -> Weak<()> {
147 Arc::downgrade(&self.guard)
148 }
149}
150
/// Equivalent to [`BufferIdentity::new`].
impl Default for BufferIdentity {
    fn default() -> Self {
        Self::new()
    }
}
156
157#[cfg(target_os = "linux")]
158use nix::sys::stat::{major, minor};
159
/// Common interface implemented by every tensor storage backend.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Allocates a new tensor with the given shape and optional debug name.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Adopts an existing file descriptor as the tensor's backing memory,
    /// taking ownership of `fd`.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Duplicates the backing file descriptor for sharing with another
    /// tensor or process.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Reports which memory backend holds this tensor.
    fn memory(&self) -> TensorMemory;

    /// Returns the tensor's debug name.
    fn name(&self) -> String;

    /// Total number of elements (product of all shape dimensions).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the tensor holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total payload size in bytes (`len * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Dimensions of the tensor.
    fn shape(&self) -> &[usize];

    /// Changes the shape in place; implementations reject shapes whose
    /// total element count differs from the current one (see tests).
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Maps the tensor memory for CPU access.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Identity of the underlying buffer allocation.
    fn buffer_identity(&self) -> &BufferIdentity;
}
219
/// Interface for a CPU-mapped view of tensor memory.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Dimensions of the mapped tensor.
    fn shape(&self) -> &[usize];

    /// Releases the mapping.
    fn unmap(&mut self);

    /// Total number of elements (product of all shape dimensions).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the mapping holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total payload size in bytes (`len * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Borrows the mapped memory as an immutable slice.
    fn as_slice(&self) -> &[T];

    /// Borrows the mapped memory as a mutable slice.
    fn as_mut_slice(&mut self) -> &mut [T];

    /// Borrows the map as an immutable `ndarray` view shaped like `shape()`.
    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    /// Borrows the map as a mutable `ndarray` view shaped like `shape()`.
    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Copy the shape out first: as_mut_slice() takes &mut self, which
        // would conflict with the shape() borrow.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
272
/// The memory backend backing a tensor allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Linux DMA-BUF heap memory.
    #[cfg(target_os = "linux")]
    Dma,
    /// POSIX shared memory.
    #[cfg(unix)]
    Shm,

    /// Plain process-local heap memory.
    Mem,

    /// PBO-backed memory; created via `ImageProcessor::create_image()`
    /// rather than `Tensor::new` (see `TensorStorage::new`).
    Pbo,
}
292
293impl From<TensorMemory> for String {
294 fn from(memory: TensorMemory) -> Self {
295 match memory {
296 #[cfg(target_os = "linux")]
297 TensorMemory::Dma => "dma".to_owned(),
298 #[cfg(unix)]
299 TensorMemory::Shm => "shm".to_owned(),
300 TensorMemory::Mem => "mem".to_owned(),
301 TensorMemory::Pbo => "pbo".to_owned(),
302 }
303 }
304}
305
306impl TryFrom<&str> for TensorMemory {
307 type Error = Error;
308
309 fn try_from(s: &str) -> Result<Self> {
310 match s {
311 #[cfg(target_os = "linux")]
312 "dma" => Ok(TensorMemory::Dma),
313 #[cfg(unix)]
314 "shm" => Ok(TensorMemory::Shm),
315 "mem" => Ok(TensorMemory::Mem),
316 "pbo" => Ok(TensorMemory::Pbo),
317 _ => Err(Error::InvalidMemoryType(s.to_owned())),
318 }
319 }
320}
321
/// Internal dispatch enum over the concrete storage backends.
///
/// Variants for platform-specific backends are compiled out on other
/// targets, hence the `dead_code` allowance.
#[derive(Debug)]
#[allow(dead_code)] pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}
335
impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates storage for `shape`, using the requested backend, or —
    /// when `memory` is `None` — the best available one in the order
    /// DMA > SHM > heap.
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => {
                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
            }
            #[cfg(unix)]
            Some(TensorMemory::Shm) => {
                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
            }
            Some(TensorMemory::Mem) => {
                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
            }
            // PBO storage cannot be allocated through this path.
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                // Escape hatch: force plain heap memory regardless of platform
                // (any value other than "0"/"false", case-insensitive).
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    #[cfg(target_os = "linux")]
                    {
                        // Prefer DMA; fall back to SHM, then heap.
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name)
                                    .map(TensorStorage::Shm)
                                {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => MemTensor::<T>::new(shape, name)
                                        .map(TensorStorage::Mem),
                                }
                            }
                        }
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        // No DMA heaps off Linux: SHM first, then heap.
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => {
                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                            }
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    /// Wraps an existing file descriptor, probing its device numbers on
    /// Linux to decide between the DMA-BUF and SHM backends.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            // Only major 0 (anonymous/pseudo filesystems) is accepted here.
            if major != 0 {
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                // NOTE(review): minors 9 and 10 are assumed to identify
                // DMA-BUF-backed fds — confirm against the target kernels.
                9 | 10 => {
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}
439
// TensorTrait implementation that simply dispatches to the active backend.
impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection (see the inherent `new`).
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    /// Adopts an fd via the inherent `from_fd` (backend auto-detected).
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}
530
/// Public tensor type: a storage backend plus optional image metadata.
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    // Concrete backend holding the data (DMA/SHM/heap/PBO).
    pub(crate) storage: TensorStorage<T>,
    // Pixel format when this tensor is an image; None for raw tensors.
    format: Option<PixelFormat>,
    // Separate chroma plane for multi-plane semi-planar images; when set,
    // `storage` holds only the luma plane.
    chroma: Option<Box<Tensor<T>>>,
}
545
impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wraps raw storage as a tensor with no image metadata.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
        }
    }

    /// Creates a raw tensor with the given shape. `memory` selects the
    /// backend (`None` = best available) and `name` is a debug label.
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

    /// Creates an image tensor of `width`×`height` pixels in `format`,
    /// deriving the storage shape from the format's layout.
    ///
    /// # Errors
    /// Fails for NV12 with an odd height, or for a semi-planar format
    /// without a known height multiplier.
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            // Interleaved channels: [H, W, C].
            PixelLayout::Packed => vec![height, width, format.channels()],
            // One plane per channel: [C, H, W].
            PixelLayout::Planar => vec![format.channels(), height, width],
            // Luma rows followed by interleaved chroma rows in a single
            // buffer: [H * k, W], k depending on the chroma subsampling.
            PixelLayout::SemiPlanar => {
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        // NV12 chroma is half height, so H must be even.
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

    /// Tags an existing raw tensor with a pixel format after validating
    /// that the current shape matches the format's expected layout.
    ///
    /// NOTE(review): the semi-planar checks assume a contiguous layout;
    /// calling this on a multi-plane tensor (whose `shape()` is the luma
    /// plane) would validate against the wrong geometry — confirm callers
    /// never do that.
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                match format {
                    // Contiguous NV12 rows total H * 3/2, so divisible by 3.
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    // Contiguous NV16 rows total H * 2, so even.
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        self.format = Some(format);
        Ok(())
    }

    /// The pixel format, if this tensor was created or tagged as an image.
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    /// Image width in pixels; `None` for raw (untagged) tensors.
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    /// Image height in pixels; `None` for raw (untagged) tensors.
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    // Multi-plane: storage holds only the luma plane.
                    Some(shape[0])
                } else {
                    // Contiguous: invert the per-format height multiplier
                    // applied by `image()`.
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

    /// Builds a multi-plane semi-planar image from separate luma and
    /// chroma tensors, consuming both.
    ///
    /// # Errors
    /// The format must be NV12 or NV16; both planes must be raw 2D
    /// tensors of equal width with the format's chroma height.
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        match format {
            PixelFormat::Nv12 => {
                // NV12: chroma plane is half the luma height.
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                // NV16: chroma plane matches the luma height.
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        // The luma storage becomes the primary plane; chroma is attached.
        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
        })
    }

    /// True when luma and chroma live in separate buffers.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }

    /// The chroma plane of a multi-plane image, if any.
    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }

    /// Downcasts to the PBO backend, if that is what backs this tensor.
    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }

    /// Downcasts to the DMA backend, if that is what backs this tensor.
    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }

    /// Wraps an existing PBO tensor without image metadata.
    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
        Self {
            storage: TensorStorage::Pbo(pbo),
            format: None,
            chroma: None,
        }
    }
}
808
// TensorTrait implementation forwarding to the inner storage, with extra
// handling for image metadata on reshape.
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection and no image metadata.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    /// Adopts an existing fd (backend auto-detected, see
    /// `TensorStorage::from_fd`).
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    /// Reshapes the underlying storage. Multi-plane images are rejected;
    /// a successful reshape clears the pixel-format tag because the image
    /// geometry no longer holds.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        Ok(())
    }

    fn map(&self) -> Result<TensorMap<T>> {
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
864
/// A CPU mapping of tensor memory, dispatching over the backend map types.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}
876
// TensorMapTrait implementation that simply dispatches to the active
// backend map.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
925
// Allow a TensorMap to be used directly as a slice (e.g. indexing, iter).
impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}
943
// Allow a TensorMap to be used directly as a mutable slice (e.g. fill).
impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}
959
// Cached result of the one-time DMA availability probe.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// True when DMA-heap tensors can be allocated on this system.
/// Probes by allocating one small DMA tensor; the outcome is cached for
/// the lifetime of the process.
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}
979
/// DMA-BUF heaps are Linux-only; always false on other targets.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
987
// Cached result of the one-time SHM availability probe.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// True when shared-memory tensors can be allocated on this system.
/// Probes by allocating one small SHM tensor; the outcome is cached for
/// the lifetime of the process.
#[cfg(unix)]
pub fn is_shm_available() -> bool {
    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}
1002
/// POSIX shared memory is unix-only; always false on other targets.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
1010
/// Unit tests for [`DType`] sizing, naming, and serde round-tripping.
#[cfg(test)]
mod dtype_tests {
    use super::*;

    /// Every dtype reports the byte width of its element type.
    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    /// Spot-check the canonical lowercase names.
    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    /// DType survives a JSON serialize/deserialize round trip.
    /// (No `use serde_json;` needed — in Rust 2018+ the crate name is
    /// available from the extern prelude.)
    #[test]
    fn dtype_serde_roundtrip() {
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1046
/// Unit tests for image metadata: formats, dimensions, and planes.
#[cfg(test)]
mod image_tests {
    use super::*;

    // A plain tensor carries no image metadata until tagged.
    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    // Packed layout stores [H, W, C].
    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    // Planar layout stores [C, H, W].
    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    // Contiguous NV12 stores [H * 3/2, W] in a single buffer.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    // Tagging a matching-shape raw tensor succeeds.
    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    // A failed set_format leaves the tensor untagged.
    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        assert!(t.format().is_none());
    }

    // Reshaping invalidates the image geometry, so the tag is dropped.
    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    // Separate NV12 planes combine into a multiplane image.
    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    // from_planes only accepts semi-planar formats.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    // Multiplane images must be decomposed before reshaping.
    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
1149
1150#[cfg(test)]
1151mod tests {
1152 #[cfg(target_os = "linux")]
1153 use nix::unistd::{access, AccessFlags};
1154 #[cfg(target_os = "linux")]
1155 use std::io::Write as _;
1156 use std::sync::RwLock;
1157
1158 use super::*;
1159
    // Initialize logging once before any test in this binary runs.
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }
1164
    // Expands to the name of the enclosing function (used in skip messages).
    // Works by taking the type name of a nested fn and stripping the
    // trailing "::f" plus any leading module path.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
1182
    // On Linux the default backend is DMA when available, SHM otherwise.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![1];
        // Probe DMA support directly so the expectation matches the host.
        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
        let dma_enabled = tensor.is_ok();

        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        match dma_enabled {
            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
        }
    }
1197
    // On non-Linux unix the default backend is SHM, falling back to heap.
    #[test]
    #[cfg(all(unix, not(target_os = "linux")))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert!(
            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
            "Expected SHM or Mem on macOS, got {:?}",
            tensor.memory()
        );
    }
1210
    // Outside unix the only backend is plain heap memory.
    #[test]
    #[cfg(not(unix))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert_eq!(tensor.memory(), TensorMemory::Mem);
    }
1218
    // End-to-end DMA backend test: allocation, mapping, fd sharing, and
    // reshape. Skips (with a warning) when no DMA heap is accessible.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Check for an accessible DMA heap; otherwise skip the test.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Basic map-and-fill round trip.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer through a duplicated fd and write via the clone.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes through the shared fd are visible in the original tensor.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: size mismatch is rejected, same-size reshape succeeds.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Data persists across successive maps of the same tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1316
    // End-to-end SHM backend test: allocation, mapping, fd sharing, and
    // reshape (mirrors test_dma_tensor).
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Basic map-and-fill round trip.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer through a duplicated fd and write via the clone.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes through the shared fd are visible in the original tensor.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: size mismatch is rejected, same-size reshape succeeds.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Data persists across successive maps of the same tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1389
1390 #[test]
1391 fn test_mem_tensor() {
1392 let shape = vec![2, 3, 4];
1393 let tensor =
1394 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1395 assert_eq!(tensor.shape(), &shape);
1396 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1397 assert_eq!(tensor.name(), "test_tensor");
1398
1399 {
1400 let mut tensor_map = tensor.map().expect("Failed to map memory");
1401 tensor_map.fill(42.0);
1402 assert!(tensor_map.iter().all(|&x| x == 42.0));
1403 }
1404
1405 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1406 assert_eq!(tensor.shape(), &shape);
1407 let new_shape = vec![3, 4, 4];
1408 assert!(
1409 tensor.reshape(&new_shape).is_err(),
1410 "Reshape should fail due to size mismatch"
1411 );
1412 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1413
1414 let new_shape = vec![2, 3, 4];
1415 tensor.reshape(&new_shape).expect("Reshape should succeed");
1416 assert_eq!(
1417 tensor.shape(),
1418 &new_shape,
1419 "Shape should be updated after successful reshape"
1420 );
1421
1422 {
1423 let mut tensor_map = tensor.map().expect("Failed to map memory");
1424 tensor_map.fill(1);
1425 assert!(tensor_map.iter().all(|&x| x == 1));
1426 }
1427
1428 {
1429 let mut tensor_map = tensor.map().expect("Failed to map memory");
1430 tensor_map[2] = 42;
1431 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1432 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1433 }
1434 }
1435
1436 #[test]
1437 #[cfg(target_os = "linux")]
1438 fn test_dma_no_fd_leaks() {
1439 let _lock = FD_LOCK.write().unwrap();
1440 if !is_dma_available() {
1441 log::warn!(
1442 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1443 function!()
1444 );
1445 return;
1446 }
1447
1448 let proc = procfs::process::Process::myself()
1449 .expect("Failed to get current process using /proc/self");
1450
1451 let start_open_fds = proc
1452 .fd_count()
1453 .expect("Failed to get open file descriptor count");
1454
1455 for _ in 0..100 {
1456 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
1457 .expect("Failed to create tensor");
1458 let mut map = tensor.map().unwrap();
1459 map.as_mut_slice().fill(233);
1460 }
1461
1462 let end_open_fds = proc
1463 .fd_count()
1464 .expect("Failed to get open file descriptor count");
1465
1466 assert_eq!(
1467 start_open_fds, end_open_fds,
1468 "File descriptor leak detected: {} -> {}",
1469 start_open_fds, end_open_fds
1470 );
1471 }
1472
1473 #[test]
1474 #[cfg(target_os = "linux")]
1475 fn test_dma_from_fd_no_fd_leaks() {
1476 let _lock = FD_LOCK.write().unwrap();
1477 if !is_dma_available() {
1478 log::warn!(
1479 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1480 function!()
1481 );
1482 return;
1483 }
1484
1485 let proc = procfs::process::Process::myself()
1486 .expect("Failed to get current process using /proc/self");
1487
1488 let start_open_fds = proc
1489 .fd_count()
1490 .expect("Failed to get open file descriptor count");
1491
1492 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
1493
1494 for _ in 0..100 {
1495 let tensor =
1496 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1497 let mut map = tensor.map().unwrap();
1498 map.as_mut_slice().fill(233);
1499 }
1500 drop(orig);
1501
1502 let end_open_fds = proc.fd_count().unwrap();
1503
1504 assert_eq!(
1505 start_open_fds, end_open_fds,
1506 "File descriptor leak detected: {} -> {}",
1507 start_open_fds, end_open_fds
1508 );
1509 }
1510
1511 #[test]
1512 #[cfg(target_os = "linux")]
1513 fn test_shm_no_fd_leaks() {
1514 let _lock = FD_LOCK.write().unwrap();
1515 if !is_shm_available() {
1516 log::warn!(
1517 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1518 function!()
1519 );
1520 return;
1521 }
1522
1523 let proc = procfs::process::Process::myself()
1524 .expect("Failed to get current process using /proc/self");
1525
1526 let start_open_fds = proc
1527 .fd_count()
1528 .expect("Failed to get open file descriptor count");
1529
1530 for _ in 0..100 {
1531 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1532 .expect("Failed to create tensor");
1533 let mut map = tensor.map().unwrap();
1534 map.as_mut_slice().fill(233);
1535 }
1536
1537 let end_open_fds = proc
1538 .fd_count()
1539 .expect("Failed to get open file descriptor count");
1540
1541 assert_eq!(
1542 start_open_fds, end_open_fds,
1543 "File descriptor leak detected: {} -> {}",
1544 start_open_fds, end_open_fds
1545 );
1546 }
1547
1548 #[test]
1549 #[cfg(target_os = "linux")]
1550 fn test_shm_from_fd_no_fd_leaks() {
1551 let _lock = FD_LOCK.write().unwrap();
1552 if !is_shm_available() {
1553 log::warn!(
1554 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1555 function!()
1556 );
1557 return;
1558 }
1559
1560 let proc = procfs::process::Process::myself()
1561 .expect("Failed to get current process using /proc/self");
1562
1563 let start_open_fds = proc
1564 .fd_count()
1565 .expect("Failed to get open file descriptor count");
1566
1567 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
1568
1569 for _ in 0..100 {
1570 let tensor =
1571 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1572 let mut map = tensor.map().unwrap();
1573 map.as_mut_slice().fill(233);
1574 }
1575 drop(orig);
1576
1577 let end_open_fds = proc.fd_count().unwrap();
1578
1579 assert_eq!(
1580 start_open_fds, end_open_fds,
1581 "File descriptor leak detected: {} -> {}",
1582 start_open_fds, end_open_fds
1583 );
1584 }
1585
1586 #[cfg(feature = "ndarray")]
1587 #[test]
1588 fn test_ndarray() {
1589 let _lock = FD_LOCK.read().unwrap();
1590 let shape = vec![2, 3, 4];
1591 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1592
1593 let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
1594 tensor_map.fill(1.0);
1595
1596 let view = tensor_map.view().expect("Failed to get ndarray view");
1597 assert_eq!(view.shape(), &[2, 3, 4]);
1598 assert!(view.iter().all(|&x| x == 1.0));
1599
1600 let mut view_mut = tensor_map
1601 .view_mut()
1602 .expect("Failed to get mutable ndarray view");
1603 view_mut[[0, 0, 0]] = 42.0;
1604 assert_eq!(view_mut[[0, 0, 0]], 42.0);
1605 assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
1606 }
1607
1608 #[test]
1609 fn test_buffer_identity_unique() {
1610 let id1 = BufferIdentity::new();
1611 let id2 = BufferIdentity::new();
1612 assert_ne!(
1613 id1.id(),
1614 id2.id(),
1615 "Two identities should have different ids"
1616 );
1617 }
1618
1619 #[test]
1620 fn test_buffer_identity_clone_shares_guard() {
1621 let id1 = BufferIdentity::new();
1622 let weak = id1.weak();
1623 assert!(
1624 weak.upgrade().is_some(),
1625 "Weak should be alive while original exists"
1626 );
1627
1628 let id2 = id1.clone();
1629 assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1630
1631 drop(id1);
1632 assert!(
1633 weak.upgrade().is_some(),
1634 "Weak should still be alive (clone holds Arc)"
1635 );
1636
1637 drop(id2);
1638 assert!(
1639 weak.upgrade().is_none(),
1640 "Weak should be dead after all clones dropped"
1641 );
1642 }
1643
1644 #[test]
1645 fn test_tensor_buffer_identity() {
1646 let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1647 let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1648 assert_ne!(
1649 t1.buffer_identity().id(),
1650 t2.buffer_identity().id(),
1651 "Different tensors should have different buffer ids"
1652 );
1653 }
1654
    /// Serializes tests that touch the process-wide open-fd count: the fd-leak
    /// tests take the write lock while counting, and tests that merely open
    /// fds (e.g. `test_ndarray`) take the read lock, so a count taken under
    /// the write lock cannot be perturbed by a concurrently running test.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1659
1660 #[test]
1663 #[cfg(not(target_os = "linux"))]
1664 fn test_dma_not_available_on_non_linux() {
1665 assert!(
1666 !is_dma_available(),
1667 "DMA memory allocation should NOT be available on non-Linux platforms"
1668 );
1669 }
1670
1671 #[test]
1674 #[cfg(unix)]
1675 fn test_shm_available_and_usable() {
1676 assert!(
1677 is_shm_available(),
1678 "SHM memory allocation should be available on Unix systems"
1679 );
1680
1681 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1683 .expect("Failed to create SHM tensor");
1684
1685 let mut map = tensor.map().expect("Failed to map SHM tensor");
1687 map.as_mut_slice().fill(0xAB);
1688
1689 assert!(
1691 map.as_slice().iter().all(|&b| b == 0xAB),
1692 "SHM tensor data should be writable and readable"
1693 );
1694 }
1695}