1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54 fmt,
55 ops::{Deref, DerefMut},
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc, Weak,
59 },
60};
61pub use tensor_dyn::TensorDyn;
62
/// Describes one image plane shared via a file descriptor.
///
/// Producers may attach optional layout hints (stride and byte offset of
/// the plane within the underlying buffer) alongside the fd.
#[cfg(unix)]
pub struct PlaneDescriptor {
    // Owned duplicate of the plane's file descriptor.
    fd: OwnedFd,
    // Row stride hint, if known — NOTE(review): assumed to be in bytes,
    // matching Tensor::set_row_stride; confirm with producers.
    stride: Option<usize>,
    // Byte offset of this plane within the underlying buffer, if known.
    offset: Option<usize>,
}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91 pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101 let owned = fd.try_clone_to_owned()?;
102 Ok(Self {
103 fd: owned,
104 stride: None,
105 offset: None,
106 })
107 }
108
109 pub fn with_stride(mut self, stride: usize) -> Self {
111 self.stride = Some(stride);
112 self
113 }
114
115 pub fn with_offset(mut self, offset: usize) -> Self {
117 self.offset = Some(offset);
118 self
119 }
120
121 pub fn into_fd(self) -> OwnedFd {
123 self.fd
124 }
125
126 pub fn stride(&self) -> Option<usize> {
128 self.stride
129 }
130
131 pub fn offset(&self) -> Option<usize> {
133 self.offset
134 }
135}
136
/// Element data types supported by tensors.
///
/// `#[non_exhaustive]` allows adding dtypes without breaking downstream
/// matches; the serde derives make the type wire-serializable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}
154
155impl DType {
156 pub const fn size(&self) -> usize {
158 match self {
159 Self::U8 | Self::I8 => 1,
160 Self::U16 | Self::I16 | Self::F16 => 2,
161 Self::U32 | Self::I32 | Self::F32 => 4,
162 Self::U64 | Self::I64 | Self::F64 => 8,
163 }
164 }
165
166 pub const fn name(&self) -> &'static str {
168 match self {
169 Self::U8 => "u8",
170 Self::I8 => "i8",
171 Self::U16 => "u16",
172 Self::I16 => "i16",
173 Self::U32 => "u32",
174 Self::I32 => "i32",
175 Self::U64 => "u64",
176 Self::I64 => "i64",
177 Self::F16 => "f16",
178 Self::F32 => "f32",
179 Self::F64 => "f64",
180 }
181 }
182}
183
184impl fmt::Display for DType {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 f.write_str(self.name())
187 }
188}
189
190static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
192
/// Process-unique identity for a tensor's underlying buffer.
///
/// Clones share both the numeric id and the `Arc` guard, so weak handles
/// obtained via [`BufferIdentity::weak`] observe when every clone drops.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    // Unique id drawn from NEXT_BUFFER_ID.
    id: u64,
    // Liveness guard; downgraded references go dead once all clones drop.
    guard: Arc<()>,
}
203
204impl BufferIdentity {
205 pub fn new() -> Self {
207 Self {
208 id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
209 guard: Arc::new(()),
210 }
211 }
212
213 pub fn id(&self) -> u64 {
215 self.id
216 }
217
218 pub fn weak(&self) -> Weak<()> {
221 Arc::downgrade(&self.guard)
222 }
223}
224
impl Default for BufferIdentity {
    /// Equivalent to [`BufferIdentity::new`].
    fn default() -> Self {
        Self::new()
    }
}
230
231#[cfg(target_os = "linux")]
232use nix::sys::stat::{major, minor};
233
/// Common interface implemented by every tensor storage backend.
///
/// The `len`/`is_empty`/`size` helpers are derived from `shape()`.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Allocates a new tensor of `shape`, optionally named.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Adopts an existing buffer from a file descriptor.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Duplicates the backing file descriptor for sharing with another
    /// process or tensor.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Which memory backend this tensor uses.
    fn memory(&self) -> TensorMemory;

    /// Human-readable tensor name.
    fn name(&self) -> String;

    /// Total number of elements (product of all dimensions).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the tensor holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total size in bytes (`len * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Current shape of the tensor.
    fn shape(&self) -> &[usize];

    /// Changes the shape; returns `Err` for invalid shapes.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Maps the tensor into CPU-accessible memory.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Identity of the underlying buffer.
    fn buffer_identity(&self) -> &BufferIdentity;
}
293
/// Interface for CPU-mapped tensor views.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Shape of the mapped tensor.
    fn shape(&self) -> &[usize];

    /// Releases the mapping.
    fn unmap(&mut self);

    /// Number of elements in the mapping.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the mapping holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Size in bytes of the mapped data.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Read-only view of the mapped elements.
    fn as_slice(&self) -> &[T];

    /// Mutable view of the mapped elements.
    fn as_mut_slice(&mut self) -> &mut [T];

    /// Borrows the mapping as a dynamically-shaped `ndarray` view.
    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    /// Mutable `ndarray` view. The shape is copied to a `Vec` first so the
    /// mutable borrow for `as_mut_slice` does not conflict with `shape()`.
    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
346
/// Memory backend used by a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Linux DMA-heap backed buffer.
    #[cfg(target_os = "linux")]
    Dma,
    /// Shared-memory backed buffer.
    #[cfg(unix)]
    Shm,

    /// Plain process-local heap memory.
    Mem,

    /// Pixel-buffer-object storage; created via `ImageProcessor::create_image()`,
    /// not via `Tensor::new()`.
    Pbo,
}
366
367impl From<TensorMemory> for String {
368 fn from(memory: TensorMemory) -> Self {
369 match memory {
370 #[cfg(target_os = "linux")]
371 TensorMemory::Dma => "dma".to_owned(),
372 #[cfg(unix)]
373 TensorMemory::Shm => "shm".to_owned(),
374 TensorMemory::Mem => "mem".to_owned(),
375 TensorMemory::Pbo => "pbo".to_owned(),
376 }
377 }
378}
379
380impl TryFrom<&str> for TensorMemory {
381 type Error = Error;
382
383 fn try_from(s: &str) -> Result<Self> {
384 match s {
385 #[cfg(target_os = "linux")]
386 "dma" => Ok(TensorMemory::Dma),
387 #[cfg(unix)]
388 "shm" => Ok(TensorMemory::Shm),
389 "mem" => Ok(TensorMemory::Mem),
390 "pbo" => Ok(TensorMemory::Pbo),
391 _ => Err(Error::InvalidMemoryType(s.to_owned())),
392 }
393 }
394}
395
/// Internal sum type over the concrete tensor backends.
///
/// `dead_code` is allowed because not every variant exists on every
/// platform/cfg combination.
#[derive(Debug)]
#[allow(dead_code)] pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}
409
impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates storage for `shape`, honoring an explicit backend choice.
    ///
    /// With `memory == None` the backends are probed in preference order
    /// DMA -> SHM -> Mem (as the platform allows), unless the
    /// `EDGEFIRST_TENSOR_FORCE_MEM` environment variable is set truthy,
    /// which forces plain heap memory.
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => {
                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
            }
            #[cfg(unix)]
            Some(TensorMemory::Shm) => {
                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
            }
            Some(TensorMemory::Mem) => {
                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
            }
            // PBO storage is created through the image pipeline, never here.
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                // Escape hatch: any value other than "0"/"false" forces heap.
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    // Linux: try DMA first, then SHM, then plain heap.
                    #[cfg(target_os = "linux")]
                    {
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name)
                                    .map(TensorStorage::Shm)
                                {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => MemTensor::<T>::new(shape, name)
                                        .map(TensorStorage::Mem),
                                }
                            }
                        }
                    }
                    // Other unix: no DMA heaps; try SHM, fall back to heap.
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => {
                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                            }
                        }
                    }
                    // Non-unix: heap memory is the only option.
                    #[cfg(not(unix))]
                    {
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    /// Adopts a foreign fd, choosing the backend from its device numbers.
    ///
    /// On Linux the fd is `fstat`ed: major 0 with minor 9 or 10 is treated
    /// as a DMA-heap buffer, any other major-0 fd as shared memory, and
    /// other majors are rejected. NOTE(review): the 9|10 minor mapping is
    /// assumed — confirm against the kernel's device assignments.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            if major != 0 {
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                9 | 10 => {
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        // Non-Linux unix: everything adopted from an fd is shared memory.
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}
513
/// Forwards the `TensorTrait` API to whichever backend variant is active.
impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection (inherent `new` with
    /// `memory = None`).
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    /// Delegates to the inherent `from_fd` (backend chosen from the fd).
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    /// Duplicates the active backend's file descriptor.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    /// Reports which backend variant is active.
    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}
604
/// Public tensor type: backend storage plus optional image metadata
/// (pixel format, a separate chroma plane, row stride, plane offset).
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    // Pixel format when this tensor is an image; None for raw tensors.
    format: Option<PixelFormat>,
    // Separate chroma plane for multi-plane semi-planar images (NV12/NV16).
    chroma: Option<Box<Tensor<T>>>,
    // Row stride in bytes; None means tightly packed.
    row_stride: Option<usize>,
    // Byte offset of the plane within the underlying buffer.
    plane_offset: Option<usize>,
}
625
impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wraps raw storage in a `Tensor` with no image metadata attached.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
        }
    }

    /// Allocates a raw tensor of `shape` on the requested (or best
    /// available) backend. See `TensorStorage::new` for selection rules.
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

    /// Allocates an image tensor sized for `width x height` in `format`.
    ///
    /// The shape is derived from the format's layout:
    /// - packed:      `[H, W, channels]`
    /// - planar:      `[channels, H, W]`
    /// - semi-planar: `[H * k, W]` contiguous (NV12: k = 3/2, even height
    ///   required; NV16: k = 2)
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        // NV12 chroma is half-height, so luma height must be even.
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

    /// Attaches a pixel format to this tensor after validating that the
    /// current shape is consistent with the format's layout.
    ///
    /// Changing to a different format clears the stride/offset metadata
    /// (and resets the DMA mmap offset on Linux).
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                // Contiguous semi-planar heights must divide back to an
                // integral luma height (NV12: H*3/2, NV16: H*2).
                match format {
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        // Stride/offset were computed for the previous format; drop them
        // when the format actually changes.
        if self.format != Some(format) {
            self.row_stride = None;
            self.plane_offset = None;
            #[cfg(target_os = "linux")]
            if let TensorStorage::Dma(ref mut dma) = self.storage {
                dma.mmap_offset = 0;
            }
        }
        self.format = Some(format);
        Ok(())
    }

    /// Pixel format, if this tensor is an image.
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    /// Image width in pixels, derived from shape + format layout.
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    /// Image height in pixels (luma height for semi-planar formats).
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    // Multiplane: shape is the luma plane [H, W] directly.
                    Some(shape[0])
                } else {
                    // Contiguous: undo the format's height multiplier.
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

    /// Builds a multi-plane semi-planar image from separate luma and
    /// chroma tensors, validating plane dimensions for NV12/NV16.
    ///
    /// The luma tensor's storage becomes the primary plane; the chroma
    /// tensor must be raw (no format/chroma metadata of its own).
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        match format {
            PixelFormat::Nv12 => {
                // NV12: chroma plane is half the luma height (4:2:0).
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                // NV16: chroma plane matches the luma height (4:2:2).
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
            row_stride: luma.row_stride,
            plane_offset: luma.plane_offset,
        })
    }

    /// True when this image carries a separate chroma plane.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }

    /// Borrow the chroma plane, if any.
    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }

    /// Mutably borrow the chroma plane, if any.
    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
        self.chroma.as_deref_mut()
    }

    /// Explicitly-set row stride in bytes, if any.
    pub fn row_stride(&self) -> Option<usize> {
        self.row_stride
    }

    /// Row stride in bytes: the explicit stride if set, otherwise the
    /// tightly-packed stride computed from format and width.
    pub fn effective_row_stride(&self) -> Option<usize> {
        if let Some(s) = self.row_stride {
            return Some(s);
        }
        let fmt = self.format?;
        let w = self.width()?;
        let elem = std::mem::size_of::<T>();
        Some(match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        })
    }

    /// Sets the row stride (bytes), rejecting values smaller than the
    /// tightly-packed minimum for the current format/width.
    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
        let fmt = self.format.ok_or_else(|| {
            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
        })?;
        let w = self.width().ok_or_else(|| {
            Error::InvalidArgument("cannot determine width for row_stride validation".into())
        })?;
        let elem = std::mem::size_of::<T>();
        let min_stride = match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        };
        if stride < min_stride {
            return Err(Error::InvalidArgument(format!(
                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
            )));
        }
        self.row_stride = Some(stride);
        Ok(())
    }

    /// Sets the row stride without any validation; caller guarantees it
    /// is meaningful for the current layout.
    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
        self.row_stride = Some(stride);
    }

    /// Builder-style variant of [`Self::set_row_stride`].
    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
        self.set_row_stride(stride)?;
        Ok(self)
    }

    /// Byte offset of this plane within the underlying buffer, if set.
    pub fn plane_offset(&self) -> Option<usize> {
        self.plane_offset
    }

    /// Sets the plane offset; for DMA storage on Linux this also adjusts
    /// the mmap offset so CPU maps start at the plane.
    pub fn set_plane_offset(&mut self, offset: usize) {
        self.plane_offset = Some(offset);
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = offset;
        }
    }

    /// Builder-style variant of [`Self::set_plane_offset`].
    pub fn with_plane_offset(mut self, offset: usize) -> Self {
        self.set_plane_offset(offset);
        self
    }

    /// Borrow the PBO backend, if this tensor is PBO-backed.
    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }

    /// Borrow the DMA backend, if this tensor is DMA-backed.
    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }

    /// Borrows the underlying dmabuf fd; errors for non-DMA storage.
    #[cfg(target_os = "linux")]
    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
        use std::os::fd::AsFd;
        match &self.storage {
            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
            _ => Err(Error::NotImplemented(format!(
                "dmabuf requires DMA-backed tensor, got {:?}",
                self.storage.memory()
            ))),
        }
    }

    /// Wraps an existing PBO tensor with no image metadata attached.
    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
        Self {
            storage: TensorStorage::Pbo(pbo),
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
        }
    }
}
1060
/// `TensorTrait` for the public tensor: delegates to storage while
/// enforcing the image-metadata invariants (strides, offsets, planes).
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocates with automatic backend selection.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    /// Adopts an fd; backend chosen by `TensorStorage::from_fd`.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    /// Reshapes the storage and clears all image metadata, which is no
    /// longer meaningful for the new shape. Multiplane tensors cannot be
    /// reshaped as a unit.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        self.row_stride = None;
        self.plane_offset = None;
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = 0;
        }
        Ok(())
    }

    /// CPU-maps the tensor. Strided tensors cannot be CPU-mapped, and a
    /// non-zero plane offset is only honored for DMA storage.
    fn map(&self) -> Result<TensorMap<T>> {
        if self.row_stride.is_some() {
            return Err(Error::InvalidOperation(
                "CPU mapping of strided tensors is not supported; use GPU path only".into(),
            ));
        }
        if self.plane_offset.is_some_and(|o| o > 0) {
            #[cfg(target_os = "linux")]
            if !matches!(self.storage, TensorStorage::Dma(_)) {
                return Err(Error::InvalidOperation(
                    "plane offset only supported for DMA tensors".into(),
                ));
            }
            #[cfg(not(target_os = "linux"))]
            return Err(Error::InvalidOperation(
                "plane offset only supported for DMA tensors".into(),
            ));
        }
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
1142
/// CPU mapping of a tensor, wrapping the backend-specific map type.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}
1154
/// Forwards the `TensorMapTrait` API to the active backend map.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
1203
/// Lets a map be used directly as a `&[T]` slice.
impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}
1221
/// Lets a map be used directly as a `&mut [T]` slice.
impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}
1237
// Cached result of the one-time DMA availability probe.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1251
/// Whether DMA-heap allocation works on this system.
///
/// Probed once by attempting a small (64-byte) DMA tensor allocation;
/// the result is cached for the lifetime of the process.
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}
1257
/// DMA-heap buffers only exist on Linux; always false elsewhere.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
1265
// Cached result of the one-time shared-memory availability probe.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1274
/// Whether shared-memory allocation works on this system.
///
/// Probed once by attempting a small (64-byte) SHM tensor allocation;
/// the result is cached for the lifetime of the process.
#[cfg(unix)]
pub fn is_shm_available() -> bool {
    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}
1280
/// Shared-memory tensors are unix-only; always false elsewhere.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
1288
// Unit tests for the DType helpers.
#[cfg(test)]
mod dtype_tests {
    use super::*;

    // Every dtype reports its exact element size in bytes.
    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    // Spot-check canonical lowercase names.
    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    // serde round-trip preserves the variant.
    #[test]
    fn dtype_serde_roundtrip() {
        use serde_json;
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1324
// Unit tests for image metadata: formats, shapes, planes, reshaping.
#[cfg(test)]
mod image_tests {
    use super::*;

    // A plain tensor carries no image metadata.
    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    // Packed RGBA: shape is [H, W, C].
    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    // Planar RGB: shape is [C, H, W].
    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    // Contiguous NV12: shape is [H*3/2, W], still single-plane.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    // set_format accepts a matching shape and enables width/height.
    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    // set_format rejects a channel mismatch and leaves the tensor raw.
    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        assert!(t.format().is_none());
    }

    // Reshaping invalidates the image metadata.
    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    // Separate luma + half-height chroma planes combine into NV12.
    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    // from_planes only accepts semi-planar formats.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    // Multiplane tensors must be decomposed before reshaping.
    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
1427
1428#[cfg(test)]
1429mod tests {
1430 #[cfg(target_os = "linux")]
1431 use nix::unistd::{access, AccessFlags};
1432 #[cfg(target_os = "linux")]
1433 use std::io::Write as _;
1434 use std::sync::RwLock;
1435
1436 use super::*;
1437
1438 #[ctor::ctor]
1439 fn init() {
1440 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
1441 }
1442
1443 #[cfg(target_os = "linux")]
1445 macro_rules! function {
1446 () => {{
1447 fn f() {}
1448 fn type_name_of<T>(_: T) -> &'static str {
1449 std::any::type_name::<T>()
1450 }
1451 let name = type_name_of(f);
1452
1453 match &name[..name.len() - 3].rfind(':') {
1455 Some(pos) => &name[pos + 1..name.len() - 3],
1456 None => &name[..name.len() - 3],
1457 }
1458 }};
1459 }
1460
1461 #[test]
1462 #[cfg(target_os = "linux")]
1463 fn test_tensor() {
1464 let _lock = FD_LOCK.read().unwrap();
1465 let shape = vec![1];
1466 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1467 let dma_enabled = tensor.is_ok();
1468
1469 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1470 match dma_enabled {
1471 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1472 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1473 }
1474 }
1475
1476 #[test]
1477 #[cfg(all(unix, not(target_os = "linux")))]
1478 fn test_tensor() {
1479 let shape = vec![1];
1480 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1481 assert!(
1483 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1484 "Expected SHM or Mem on macOS, got {:?}",
1485 tensor.memory()
1486 );
1487 }
1488
1489 #[test]
1490 #[cfg(not(unix))]
1491 fn test_tensor() {
1492 let shape = vec![1];
1493 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1494 assert_eq!(tensor.memory(), TensorMemory::Mem);
1495 }
1496
    /// End-to-end exercise of the DMA-heap backed tensor: allocation,
    /// metadata, mapping, fd sharing, write persistence across maps, and
    /// reshape validation.
    ///
    /// Skips (with a warning on stdout) when neither DMA heap device node
    /// is accessible, e.g. in CI containers without `/dev/dma_heap`.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        // Shared lock: this test opens fds, so it must not run concurrently
        // with the fd-leak tests that count open descriptors.
        let _lock = FD_LOCK.read().unwrap();
        // Probe the two common DMA-heap device nodes for read/write access;
        // bail out gracefully when neither is usable.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        // Metadata must reflect the requested geometry and element type.
        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Fill through a mapping; the map guard drops at end of scope.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd: the second tensor carries
        // its own name but aliases the same DMA memory.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes made through the fd-shared tensor are visible through the
        // original tensor's mapping.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: rejected when the element count differs...
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        // ...and accepted when it matches.
        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Element writes through one mapping persist into the next mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1594
    /// End-to-end exercise of the POSIX SHM backed tensor: allocation,
    /// metadata, mapping, fd sharing, write persistence across maps, and
    /// reshape validation.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        // Shared lock: this test opens fds, so it must not run concurrently
        // with the fd-leak tests that count open descriptors.
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Fill through a mapping; the map guard drops at end of scope.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd: the second tensor carries
        // its own name but aliases the same shared memory.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes made through the fd-shared tensor are visible through the
        // original tensor's mapping.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: rejected when the element count differs...
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        // ...and accepted when it matches.
        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Element writes through one mapping persist into the next mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1667
1668 #[test]
1669 fn test_mem_tensor() {
1670 let shape = vec![2, 3, 4];
1671 let tensor =
1672 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1673 assert_eq!(tensor.shape(), &shape);
1674 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1675 assert_eq!(tensor.name(), "test_tensor");
1676
1677 {
1678 let mut tensor_map = tensor.map().expect("Failed to map memory");
1679 tensor_map.fill(42.0);
1680 assert!(tensor_map.iter().all(|&x| x == 42.0));
1681 }
1682
1683 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1684 assert_eq!(tensor.shape(), &shape);
1685 let new_shape = vec![3, 4, 4];
1686 assert!(
1687 tensor.reshape(&new_shape).is_err(),
1688 "Reshape should fail due to size mismatch"
1689 );
1690 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1691
1692 let new_shape = vec![2, 3, 4];
1693 tensor.reshape(&new_shape).expect("Reshape should succeed");
1694 assert_eq!(
1695 tensor.shape(),
1696 &new_shape,
1697 "Shape should be updated after successful reshape"
1698 );
1699
1700 {
1701 let mut tensor_map = tensor.map().expect("Failed to map memory");
1702 tensor_map.fill(1);
1703 assert!(tensor_map.iter().all(|&x| x == 1));
1704 }
1705
1706 {
1707 let mut tensor_map = tensor.map().expect("Failed to map memory");
1708 tensor_map[2] = 42;
1709 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1710 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1711 }
1712 }
1713
    /// Creating, mapping, and dropping 100 DMA tensors must leave the
    /// process's open-fd count (read via /proc/self) unchanged.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while counting.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, and writes a tensor; both the map
        // and the tensor drop at the end of the iteration, which must close
        // every descriptor they opened.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1750
    /// Wrapping a duplicated DMA fd in a new tensor 100 times must not leak
    /// descriptors: each `from_fd` tensor owns and closes its duplicate.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while counting.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // `orig` is created after the baseline count and dropped before the
        // final count, so its own descriptor cancels out of the comparison.
        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1788
    /// Creating, mapping, and dropping 100 SHM tensors must leave the
    /// process's open-fd count (read via /proc/self) unchanged.
    /// Linux-only because the count comes from procfs.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while counting.
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, and writes a tensor; both the map
        // and the tensor drop at the end of the iteration, which must close
        // every descriptor they opened.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1825
    /// Wrapping a duplicated SHM fd in a new tensor 100 times must not leak
    /// descriptors: each `from_fd` tensor owns and closes its duplicate.
    /// Linux-only because the count comes from procfs.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_from_fd_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while counting.
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // `orig` is created after the baseline count and dropped before the
        // final count, so its own descriptor cancels out of the comparison.
        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1863
    /// Verifies the optional ndarray integration: a mapped tensor exposes
    /// shaped immutable and mutable views over the same underlying buffer.
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray() {
        // Shared lock: the tensor may be fd-backed on this platform.
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
        tensor_map.fill(1.0);

        // Immutable view: carries the tensor's shape and sees the fill.
        let view = tensor_map.view().expect("Failed to get ndarray view");
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert!(view.iter().all(|&x| x == 1.0));

        // Mutable view: writes land in the shared buffer and are observable
        // through flat indexing on the map itself.
        let mut view_mut = tensor_map
            .view_mut()
            .expect("Failed to get mutable ndarray view");
        view_mut[[0, 0, 0]] = 42.0;
        assert_eq!(view_mut[[0, 0, 0]], 42.0);
        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
    }
1885
1886 #[test]
1887 fn test_buffer_identity_unique() {
1888 let id1 = BufferIdentity::new();
1889 let id2 = BufferIdentity::new();
1890 assert_ne!(
1891 id1.id(),
1892 id2.id(),
1893 "Two identities should have different ids"
1894 );
1895 }
1896
1897 #[test]
1898 fn test_buffer_identity_clone_shares_guard() {
1899 let id1 = BufferIdentity::new();
1900 let weak = id1.weak();
1901 assert!(
1902 weak.upgrade().is_some(),
1903 "Weak should be alive while original exists"
1904 );
1905
1906 let id2 = id1.clone();
1907 assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1908
1909 drop(id1);
1910 assert!(
1911 weak.upgrade().is_some(),
1912 "Weak should still be alive (clone holds Arc)"
1913 );
1914
1915 drop(id2);
1916 assert!(
1917 weak.upgrade().is_none(),
1918 "Weak should be dead after all clones dropped"
1919 );
1920 }
1921
1922 #[test]
1923 fn test_tensor_buffer_identity() {
1924 let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1925 let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1926 assert_ne!(
1927 t1.buffer_identity().id(),
1928 t2.buffer_identity().id(),
1929 "Different tensors should have different buffer ids"
1930 );
1931 }
1932
    // Serializes fd-sensitive tests: the fd-leak tests take the WRITE lock
    // so nothing else opens or closes descriptors while they count
    // /proc/self/fd entries; tests that merely create fd-backed tensors
    // take the READ lock and may run concurrently with each other.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1937
1938 #[test]
1941 #[cfg(not(target_os = "linux"))]
1942 fn test_dma_not_available_on_non_linux() {
1943 assert!(
1944 !is_dma_available(),
1945 "DMA memory allocation should NOT be available on non-Linux platforms"
1946 );
1947 }
1948
1949 #[test]
1952 #[cfg(unix)]
1953 fn test_shm_available_and_usable() {
1954 assert!(
1955 is_shm_available(),
1956 "SHM memory allocation should be available on Unix systems"
1957 );
1958
1959 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1961 .expect("Failed to create SHM tensor");
1962
1963 let mut map = tensor.map().expect("Failed to map SHM tensor");
1965 map.as_mut_slice().fill(0xAB);
1966
1967 assert!(
1969 map.as_slice().iter().all(|&b| b == 0xAB),
1970 "SHM tensor data should be writable and readable"
1971 );
1972 }
1973}