1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35#[cfg(target_os = "macos")]
36mod iosurface;
37mod mem;
38mod pbo;
39#[cfg(unix)]
40mod shm;
41mod tensor_dyn;
42
43#[cfg(target_os = "linux")]
44pub use crate::dma::{DmaMap, DmaTensor};
45#[cfg(target_os = "macos")]
46pub use crate::iosurface::{image_iosurface_layout, IoSurfaceMap, IoSurfaceTensor};
47pub use crate::mem::{MemMap, MemTensor};
48pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
49#[cfg(unix)]
50pub use crate::shm::{ShmMap, ShmTensor};
51pub use error::{Error, Result};
52pub use format::{PixelFormat, PixelLayout};
53use num_traits::Num;
54use serde::{Deserialize, Serialize};
55#[cfg(unix)]
56use std::os::fd::OwnedFd;
57use std::{
58 fmt,
59 ops::{Deref, DerefMut},
60 sync::{
61 atomic::{AtomicU64, Ordering},
62 Arc, Weak,
63 },
64};
65pub use tensor_dyn::TensorDyn;
66
/// Describes one plane of a buffer that is backed by a file descriptor.
///
/// Holds an owned descriptor together with an optional stride and an
/// optional offset. Both optional values start out unset in
/// `PlaneDescriptor::new` and are filled in with `with_stride` /
/// `with_offset`.
#[cfg(unix)]
pub struct PlaneDescriptor {
    /// Descriptor owned by this value (a duplicate of the one given to `new`).
    fd: OwnedFd,
    /// Caller-supplied stride, if any.
    /// NOTE(review): presumably a row stride in bytes — confirm against callers.
    stride: Option<usize>,
    /// Caller-supplied offset, if any.
    /// NOTE(review): presumably a byte offset into the buffer — confirm against callers.
    offset: Option<usize>,
}
92
93#[cfg(unix)]
94impl PlaneDescriptor {
95 pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
105 let owned = fd.try_clone_to_owned()?;
106 Ok(Self {
107 fd: owned,
108 stride: None,
109 offset: None,
110 })
111 }
112
113 pub fn with_stride(mut self, stride: usize) -> Self {
115 self.stride = Some(stride);
116 self
117 }
118
119 pub fn with_offset(mut self, offset: usize) -> Self {
121 self.offset = Some(offset);
122 self
123 }
124
125 pub fn into_fd(self) -> OwnedFd {
127 self.fd
128 }
129
130 pub fn stride(&self) -> Option<usize> {
132 self.stride
133 }
134
135 pub fn offset(&self) -> Option<usize> {
137 self.offset
138 }
139}
140
/// Element data type tag for tensors.
///
/// Covers 8/16/32/64-bit unsigned and signed integers and 16/32/64-bit
/// floats. The enum is `#[non_exhaustive]`, so matches written outside this
/// crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    /// Unsigned 8-bit integer.
    U8,
    /// Signed 8-bit integer.
    I8,
    /// Unsigned 16-bit integer.
    U16,
    /// Signed 16-bit integer.
    I16,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 32-bit integer.
    I32,
    /// Unsigned 64-bit integer.
    U64,
    /// Signed 64-bit integer.
    I64,
    /// 16-bit float.
    F16,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
}
158
159impl DType {
160 pub const fn size(&self) -> usize {
162 match self {
163 Self::U8 | Self::I8 => 1,
164 Self::U16 | Self::I16 | Self::F16 => 2,
165 Self::U32 | Self::I32 | Self::F32 => 4,
166 Self::U64 | Self::I64 | Self::F64 => 8,
167 }
168 }
169
170 pub const fn name(&self) -> &'static str {
172 match self {
173 Self::U8 => "u8",
174 Self::I8 => "i8",
175 Self::U16 => "u16",
176 Self::I16 => "i16",
177 Self::U32 => "u32",
178 Self::I32 => "i32",
179 Self::U64 => "u64",
180 Self::I64 => "i64",
181 Self::F16 => "f16",
182 Self::F32 => "f32",
183 Self::F64 => "f64",
184 }
185 }
186}
187
188impl fmt::Display for DType {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 f.write_str(self.name())
191 }
192}
193
/// Private home of the `Sealed` marker trait.
///
/// The module itself is not `pub`, so code outside this crate cannot name
/// `Sealed` and therefore cannot implement traits that list it as a
/// supertrait.
mod sealed {
    /// Marker implemented only for the primitive integer types listed below.
    pub trait Sealed {}

    /// Emits an empty `Sealed` impl for every listed type.
    macro_rules! seal {
        ($($ty:ty),* $(,)?) => {
            $(impl Sealed for $ty {})*
        };
    }

    seal!(u8, i8, u16, i16, u32, i32, u64, i64);
}
213
214pub trait IntegerType: sealed::Sealed {}
221impl IntegerType for u8 {}
222impl IntegerType for i8 {}
223impl IntegerType for u16 {}
224impl IntegerType for i16 {}
225impl IntegerType for u32 {}
226impl IntegerType for i32 {}
227impl IntegerType for u64 {}
228impl IntegerType for i64 {}
229
/// Quantization parameters attached to a tensor.
///
/// A single scale describes per-tensor quantization; more than one scale
/// describes per-channel quantization along `axis` (see
/// `Quantization::is_per_tensor` / `Quantization::is_per_channel`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
    /// Scale value(s). Deserialization accepts either a single number or an
    /// array of numbers.
    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
    scale: Vec<f32>,

    /// Optional zero point(s). Deserialization accepts null, a single
    /// integer, or an array of integers; a missing field becomes `None`, and
    /// `None` is omitted when serializing.
    #[serde(
        default,
        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
        skip_serializing_if = "Option::is_none"
    )]
    zero_point: Option<Vec<i32>>,

    /// Axis the per-channel scales apply to. A missing field becomes `None`,
    /// and `None` is omitted when serializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    axis: Option<usize>,
}
274
/// Borrowed, classified view of a [`Quantization`], as produced by
/// `Quantization::mode`.
///
/// The lifetime `'a` ties the per-channel slices to the `Quantization` they
/// were borrowed from.
#[derive(Debug, Clone, Copy)]
pub enum QuantMode<'a> {
    /// One scale, no zero point.
    PerTensorSymmetric {
        /// The single scale value.
        scale: f32,
    },
    /// One scale with a zero point.
    PerTensor {
        /// The single scale value.
        scale: f32,
        /// The zero point paired with `scale`.
        zero_point: i32,
    },
    /// Several scales along `axis`, no zero points.
    PerChannelSymmetric {
        /// Scale values, borrowed from the `Quantization`.
        scales: &'a [f32],
        /// Axis the scales apply to.
        axis: usize,
    },
    /// Several scales along `axis`, each with a zero point.
    PerChannel {
        /// Scale values, borrowed from the `Quantization`.
        scales: &'a [f32],
        /// Zero points, borrowed from the `Quantization`.
        zero_points: &'a [i32],
        /// Axis the scales and zero points apply to.
        axis: usize,
    },
}
299
300impl Quantization {
301 pub fn per_tensor_symmetric(scale: f32) -> Self {
303 Self {
304 scale: vec![scale],
305 zero_point: None,
306 axis: None,
307 }
308 }
309
310 pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
312 Self {
313 scale: vec![scale],
314 zero_point: Some(vec![zero_point]),
315 axis: None,
316 }
317 }
318
319 pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
321 if scales.is_empty() {
322 return Err(Error::QuantizationInvalid {
323 field: "scale.len",
324 expected: "non-empty per-channel scales".to_string(),
325 got: "length 0".to_string(),
326 });
327 }
328 Ok(Self {
329 scale: scales,
330 zero_point: None,
331 axis: Some(axis),
332 })
333 }
334
335 pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
338 if scales.is_empty() {
339 return Err(Error::QuantizationInvalid {
340 field: "scale.len",
341 expected: "non-empty per-channel scales".to_string(),
342 got: "length 0".to_string(),
343 });
344 }
345 if scales.len() != zero_points.len() {
346 return Err(Error::QuantizationInvalid {
347 field: "zero_point.len",
348 expected: format!("length matches scale ({})", scales.len()),
349 got: format!("length {}", zero_points.len()),
350 });
351 }
352 Ok(Self {
353 scale: scales,
354 zero_point: Some(zero_points),
355 axis: Some(axis),
356 })
357 }
358
359 pub fn mode(&self) -> QuantMode<'_> {
361 match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
362 (1, None, _) => QuantMode::PerTensorSymmetric {
363 scale: self.scale[0],
364 },
365 (1, Some(zps), _) => QuantMode::PerTensor {
366 scale: self.scale[0],
367 zero_point: zps.first().copied().unwrap_or(0),
368 },
369 (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
370 scales: &self.scale,
371 axis,
372 },
373 (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
374 scales: &self.scale,
375 zero_points: zps,
376 axis,
377 },
378 _ => {
384 debug_assert!(
385 false,
386 "Quantization::mode: per-channel without axis is unreachable"
387 );
388 QuantMode::PerTensorSymmetric {
389 scale: self.scale.first().copied().unwrap_or(1.0),
390 }
391 }
392 }
393 }
394
395 pub fn is_per_tensor(&self) -> bool {
397 self.scale.len() == 1
398 }
399
400 pub fn is_per_channel(&self) -> bool {
402 self.scale.len() > 1
403 }
404
405 pub fn is_symmetric(&self) -> bool {
408 match &self.zero_point {
409 None => true,
410 Some(zps) => zps.iter().all(|&z| z == 0),
411 }
412 }
413
414 pub fn scale(&self) -> &[f32] {
417 &self.scale
418 }
419
420 pub fn zero_point(&self) -> Option<&[i32]> {
422 self.zero_point.as_deref()
423 }
424
425 pub fn axis(&self) -> Option<usize> {
427 self.axis
428 }
429
430 pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
440 if self.scale.is_empty() {
445 return Err(Error::QuantizationInvalid {
446 field: "scale.len",
447 expected: ">= 1".to_string(),
448 got: "0".to_string(),
449 });
450 }
451 if let Some(zps) = self.zero_point.as_ref() {
452 let expected = if self.scale.len() == 1 {
455 1
456 } else {
457 self.scale.len()
458 };
459 if zps.len() != expected {
460 return Err(Error::QuantizationInvalid {
461 field: "zero_point.len",
462 expected: format!(
463 "{expected} (matching {})",
464 if self.scale.len() == 1 {
465 "per-tensor scale"
466 } else {
467 "per-channel scale.len"
468 }
469 ),
470 got: format!("length {}", zps.len()),
471 });
472 }
473 }
474
475 match (self.scale.len(), self.axis) {
476 (1, None) => Ok(()),
477 (1, Some(_)) => Err(Error::QuantizationInvalid {
478 field: "per_tensor_redundant_axis",
479 expected: "axis=None for per-tensor quantization".to_string(),
480 got: format!("axis={:?}", self.axis),
481 }),
482 (_, None) => Err(Error::QuantizationInvalid {
483 field: "per_channel_requires_axis",
484 expected: format!(
485 "axis=Some(_) for per-channel quantization (scale.len={})",
486 self.scale.len()
487 ),
488 got: "axis=None".to_string(),
489 }),
490 (n, Some(axis)) => {
491 if axis >= shape.len() {
492 return Err(Error::QuantizationInvalid {
493 field: "axis",
494 expected: format!("axis < tensor rank ({})", shape.len()),
495 got: format!("axis={axis}"),
496 });
497 }
498 if shape[axis] != n {
499 return Err(Error::QuantizationInvalid {
500 field: "scale.len",
501 expected: format!("length matches shape[{axis}] ({})", shape[axis]),
502 got: format!("length {n}"),
503 });
504 }
505 Ok(())
506 }
507 }
508 }
509}
510
511impl From<(f32, i32)> for Quantization {
512 fn from((scale, zero_point): (f32, i32)) -> Self {
516 Self::per_tensor(scale, zero_point)
517 }
518}
519
520fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
521 de: D,
522) -> std::result::Result<Vec<f32>, D::Error> {
523 use serde::de::{self, Visitor};
524 struct V;
525 impl<'de> Visitor<'de> for V {
526 type Value = Vec<f32>;
527 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
528 f.write_str("f32 or array of f32")
529 }
530 fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
531 Ok(vec![v as f32])
532 }
533 #[allow(clippy::cast_possible_truncation)]
534 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
535 Ok(vec![v as f32])
536 }
537 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
538 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
539 Ok(vec![v as f32])
540 }
541 fn visit_seq<A: de::SeqAccess<'de>>(
542 self,
543 mut seq: A,
544 ) -> std::result::Result<Self::Value, A::Error> {
545 let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
546 while let Some(x) = seq.next_element::<f32>()? {
547 out.push(x);
548 }
549 Ok(out)
550 }
551 }
552 de.deserialize_any(V)
553}
554
555fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
556 de: D,
557) -> std::result::Result<Option<Vec<i32>>, D::Error> {
558 use serde::de::{self, Visitor};
559 struct V;
560 impl<'de> Visitor<'de> for V {
561 type Value = Option<Vec<i32>>;
562 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
563 f.write_str("null, i32, or array of i32")
564 }
565 fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
566 Ok(None)
567 }
568 fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
569 Ok(None)
570 }
571 fn visit_some<D2: serde::Deserializer<'de>>(
572 self,
573 de: D2,
574 ) -> std::result::Result<Self::Value, D2::Error> {
575 struct Inner;
576 impl<'de> Visitor<'de> for Inner {
577 type Value = Vec<i32>;
578 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
579 f.write_str("i32 or array of i32")
580 }
581 #[allow(clippy::cast_possible_truncation)]
582 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
583 Ok(vec![v as i32])
584 }
585 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
586 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
587 Ok(vec![v as i32])
588 }
589 fn visit_seq<A: de::SeqAccess<'de>>(
590 self,
591 mut seq: A,
592 ) -> std::result::Result<Self::Value, A::Error> {
593 let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
594 while let Some(x) = seq.next_element::<i32>()? {
595 out.push(x);
596 }
597 Ok(out)
598 }
599 }
600 de.deserialize_any(Inner).map(Some)
601 }
602 #[allow(clippy::cast_possible_truncation)]
603 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
604 Ok(Some(vec![v as i32]))
605 }
606 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
607 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
608 Ok(Some(vec![v as i32]))
609 }
610 fn visit_seq<A: de::SeqAccess<'de>>(
611 self,
612 mut seq: A,
613 ) -> std::result::Result<Self::Value, A::Error> {
614 let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
615 while let Some(x) = seq.next_element::<i32>()? {
616 out.push(x);
617 }
618 Ok(Some(out))
619 }
620 }
621 de.deserialize_option(V)
622}
623
/// Process-wide counter that hands out buffer ids, starting at 1.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Identity of a buffer: a numeric id plus a shared liveness guard.
///
/// Cloning keeps the same id and shares the same guard, so a `Weak` obtained
/// from [`BufferIdentity::weak`] stays upgradable until every clone has been
/// dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    /// Id taken from `NEXT_BUFFER_ID` when the identity was created.
    id: u64,
    /// Reference-counted token whose only job is to be observed via `Weak`.
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Creates an identity with the next id from the global counter.
    pub fn new() -> Self {
        let id = NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed);
        let guard = Arc::new(());
        BufferIdentity { id, guard }
    }

    /// Numeric id of this identity.
    pub fn id(&self) -> u64 {
        let BufferIdentity { id, .. } = self;
        *id
    }

    /// Weak handle to the liveness guard; it stops upgrading once every
    /// clone of this identity is gone.
    pub fn weak(&self) -> Weak<()> {
        let guard: &Arc<()> = &self.guard;
        Arc::downgrade(guard)
    }
}

impl Default for BufferIdentity {
    /// Same as [`BufferIdentity::new`].
    fn default() -> Self {
        BufferIdentity::new()
    }
}
664
665#[cfg(target_os = "linux")]
666use nix::sys::stat::{major, minor};
667
668pub trait TensorTrait<T>: Send + Sync
669where
670 T: Num + Clone + fmt::Debug,
671{
672 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
675 where
676 Self: Sized;
677
678 #[cfg(unix)]
679 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
685 where
686 Self: Sized;
687
688 #[cfg(unix)]
689 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
691
692 fn memory(&self) -> TensorMemory;
694
695 fn name(&self) -> String;
697
698 fn len(&self) -> usize {
700 self.shape().iter().product()
701 }
702
703 fn is_empty(&self) -> bool {
705 self.len() == 0
706 }
707
708 fn size(&self) -> usize {
710 self.len() * std::mem::size_of::<T>()
711 }
712
713 fn shape(&self) -> &[usize];
715
716 fn reshape(&mut self, shape: &[usize]) -> Result<()>;
719
720 fn map(&self) -> Result<TensorMap<T>>;
723
724 fn buffer_identity(&self) -> &BufferIdentity;
726}
727
728pub trait TensorMapTrait<T>
729where
730 T: Num + Clone + fmt::Debug,
731{
732 fn shape(&self) -> &[usize];
734
735 fn unmap(&mut self);
737
738 fn len(&self) -> usize {
740 self.shape().iter().product()
741 }
742
743 fn is_empty(&self) -> bool {
745 self.len() == 0
746 }
747
748 fn size(&self) -> usize {
750 self.len() * std::mem::size_of::<T>()
751 }
752
753 fn as_slice(&self) -> &[T];
755
756 fn as_mut_slice(&mut self) -> &mut [T];
758
759 #[cfg(feature = "ndarray")]
760 fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
762 Ok(ndarray::ArrayView::from_shape(
763 self.shape(),
764 self.as_slice(),
765 )?)
766 }
767
768 #[cfg(feature = "ndarray")]
769 fn view_mut(
771 &'_ mut self,
772 ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
773 let shape = self.shape().to_vec();
774 Ok(ndarray::ArrayViewMut::from_shape(
775 shape,
776 self.as_mut_slice(),
777 )?)
778 }
779}
780
/// Kind of memory backing a tensor.
///
/// The string forms used by the `String` / `&str` conversions are `"dma"`,
/// `"shm"`, `"mem"` and `"pbo"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// DMA memory: DMA-BUF on Linux, IOSurface on macOS. The variant exists
    /// on every platform, but allocation reports `NotImplemented` elsewhere.
    Dma,
    /// Storage backed by `ShmTensor` (Unix only).
    #[cfg(unix)]
    Shm,

    /// Storage backed by `MemTensor`.
    Mem,

    /// Storage backed by `PboTensor`; cannot be allocated through
    /// `Tensor::new`.
    Pbo,
}
810
811impl From<TensorMemory> for String {
812 fn from(memory: TensorMemory) -> Self {
813 match memory {
814 TensorMemory::Dma => "dma".to_owned(),
815 #[cfg(unix)]
816 TensorMemory::Shm => "shm".to_owned(),
817 TensorMemory::Mem => "mem".to_owned(),
818 TensorMemory::Pbo => "pbo".to_owned(),
819 }
820 }
821}
822
823impl TryFrom<&str> for TensorMemory {
824 type Error = Error;
825
826 fn try_from(s: &str) -> Result<Self> {
827 match s {
828 "dma" => Ok(TensorMemory::Dma),
829 #[cfg(unix)]
830 "shm" => Ok(TensorMemory::Shm),
831 "mem" => Ok(TensorMemory::Mem),
832 "pbo" => Ok(TensorMemory::Pbo),
833 _ => Err(Error::InvalidMemoryType(s.to_owned())),
834 }
835 }
836}
837
/// Concrete backing store of a tensor, one variant per memory kind.
///
/// The `Dma` variant wraps a different type per platform and does not exist
/// outside Linux and macOS; `Shm` exists only on Unix.
#[derive(Debug)]
#[allow(dead_code)] pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// DMA storage on Linux.
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    /// DMA storage on macOS, backed by an IOSurface tensor.
    #[cfg(target_os = "macos")]
    Dma(IoSurfaceTensor<T>),
    /// Storage backed by `ShmTensor` (Unix only).
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    /// Storage backed by `MemTensor`.
    Mem(MemTensor<T>),
    /// Storage backed by `PboTensor`.
    Pbo(PboTensor<T>),
}
857
858impl<T> TensorStorage<T>
859where
860 T: Num + Clone + fmt::Debug + Send + Sync,
861{
862 fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
867 match memory {
868 #[cfg(target_os = "linux")]
869 Some(TensorMemory::Dma) => {
870 DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
871 }
872 #[cfg(target_os = "macos")]
873 Some(TensorMemory::Dma) => {
874 IoSurfaceTensor::<T>::new(shape, name).map(TensorStorage::Dma)
875 }
876 #[cfg(not(any(target_os = "linux", target_os = "macos")))]
877 Some(TensorMemory::Dma) => Err(crate::error::Error::NotImplemented(
878 "TensorMemory::Dma is only available on Linux (DMA-BUF) and macOS (IOSurface)"
879 .to_owned(),
880 )),
881 #[cfg(unix)]
882 Some(TensorMemory::Shm) => {
883 ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
884 }
885 Some(TensorMemory::Mem) => {
886 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
887 }
888 Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
889 "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
890 )),
891 None => {
892 if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
893 .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
894 {
895 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
896 } else {
897 #[cfg(target_os = "linux")]
898 {
899 match DmaTensor::<T>::new(shape, name) {
901 Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
902 Err(_) => {
903 match ShmTensor::<T>::new(shape, name)
904 .map(TensorStorage::Shm)
905 {
906 Ok(tensor) => Ok(tensor),
907 Err(_) => MemTensor::<T>::new(shape, name)
908 .map(TensorStorage::Mem),
909 }
910 }
911 }
912 }
913 #[cfg(target_os = "macos")]
914 {
915 match IoSurfaceTensor::<T>::new(shape, name) {
921 Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
922 Err(_) => match ShmTensor::<T>::new(shape, name)
923 .map(TensorStorage::Shm)
924 {
925 Ok(tensor) => Ok(tensor),
926 Err(_) => MemTensor::<T>::new(shape, name)
927 .map(TensorStorage::Mem),
928 },
929 }
930 }
931 #[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))]
932 {
933 match ShmTensor::<T>::new(shape, name) {
935 Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
936 Err(_) => {
937 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
938 }
939 }
940 }
941 #[cfg(not(unix))]
942 {
943 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
945 }
946 }
947 }
948 }
949 }
950
951 #[cfg(target_os = "linux")]
960 pub(crate) fn new_dma_with_byte_size(
961 shape: &[usize],
962 byte_size: usize,
963 name: Option<&str>,
964 ) -> Result<Self> {
965 DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
966 }
967
968 #[cfg(target_os = "macos")]
980 pub(crate) fn new_image_iosurface(
981 width: usize,
982 height: usize,
983 format: PixelFormat,
984 shape: &[usize],
985 name: Option<&str>,
986 ) -> Result<Self> {
987 IoSurfaceTensor::<T>::new_image(width, height, format, shape, name).map(TensorStorage::Dma)
988 }
989
990 #[cfg(unix)]
993 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
994 #[cfg(target_os = "linux")]
995 {
996 use nix::sys::stat::fstat;
997
998 let stat = fstat(&fd)?;
999 let major = major(stat.st_dev);
1000 let minor = minor(stat.st_dev);
1001
1002 log::debug!("Creating tensor from fd: major={major}, minor={minor}");
1003
1004 if major != 0 {
1005 return Err(Error::UnknownDeviceType(major, minor));
1007 }
1008
1009 match minor {
1010 9 | 10 => {
1011 DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
1013 }
1014 _ => {
1015 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
1017 }
1018 }
1019 }
1020 #[cfg(all(unix, not(target_os = "linux")))]
1021 {
1022 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
1024 }
1025 }
1026}
1027
1028impl<T> TensorTrait<T> for TensorStorage<T>
1029where
1030 T: Num + Clone + fmt::Debug + Send + Sync,
1031{
1032 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
1033 Self::new(shape, None, name)
1034 }
1035
1036 #[cfg(unix)]
1037 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
1038 Self::from_fd(fd, shape, name)
1039 }
1040
1041 #[cfg(unix)]
1042 fn clone_fd(&self) -> Result<OwnedFd> {
1043 match self {
1044 TensorStorage::Dma(t) => t.clone_fd(),
1045 TensorStorage::Shm(t) => t.clone_fd(),
1046 TensorStorage::Mem(t) => t.clone_fd(),
1047 TensorStorage::Pbo(t) => t.clone_fd(),
1048 }
1049 }
1050
1051 fn memory(&self) -> TensorMemory {
1052 match self {
1053 #[cfg(any(target_os = "linux", target_os = "macos"))]
1054 TensorStorage::Dma(_) => TensorMemory::Dma,
1055 #[cfg(unix)]
1056 TensorStorage::Shm(_) => TensorMemory::Shm,
1057 TensorStorage::Mem(_) => TensorMemory::Mem,
1058 TensorStorage::Pbo(_) => TensorMemory::Pbo,
1059 }
1060 }
1061
1062 fn name(&self) -> String {
1063 match self {
1064 #[cfg(any(target_os = "linux", target_os = "macos"))]
1065 TensorStorage::Dma(t) => t.name(),
1066 #[cfg(unix)]
1067 TensorStorage::Shm(t) => t.name(),
1068 TensorStorage::Mem(t) => t.name(),
1069 TensorStorage::Pbo(t) => t.name(),
1070 }
1071 }
1072
1073 fn shape(&self) -> &[usize] {
1074 match self {
1075 #[cfg(any(target_os = "linux", target_os = "macos"))]
1076 TensorStorage::Dma(t) => t.shape(),
1077 #[cfg(unix)]
1078 TensorStorage::Shm(t) => t.shape(),
1079 TensorStorage::Mem(t) => t.shape(),
1080 TensorStorage::Pbo(t) => t.shape(),
1081 }
1082 }
1083
1084 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1085 match self {
1086 #[cfg(any(target_os = "linux", target_os = "macos"))]
1087 TensorStorage::Dma(t) => t.reshape(shape),
1088 #[cfg(unix)]
1089 TensorStorage::Shm(t) => t.reshape(shape),
1090 TensorStorage::Mem(t) => t.reshape(shape),
1091 TensorStorage::Pbo(t) => t.reshape(shape),
1092 }
1093 }
1094
1095 fn map(&self) -> Result<TensorMap<T>> {
1096 match self {
1097 #[cfg(any(target_os = "linux", target_os = "macos"))]
1098 TensorStorage::Dma(t) => t.map(),
1099 #[cfg(unix)]
1100 TensorStorage::Shm(t) => t.map(),
1101 TensorStorage::Mem(t) => t.map(),
1102 TensorStorage::Pbo(t) => t.map(),
1103 }
1104 }
1105
1106 fn buffer_identity(&self) -> &BufferIdentity {
1107 match self {
1108 #[cfg(any(target_os = "linux", target_os = "macos"))]
1109 TensorStorage::Dma(t) => t.buffer_identity(),
1110 #[cfg(unix)]
1111 TensorStorage::Shm(t) => t.buffer_identity(),
1112 TensorStorage::Mem(t) => t.buffer_identity(),
1113 TensorStorage::Pbo(t) => t.buffer_identity(),
1114 }
1115 }
1116}
1117
/// Typed tensor: a storage backend plus optional image and quantization
/// metadata.
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Backing storage.
    pub(crate) storage: TensorStorage<T>,
    /// Pixel format when the tensor is treated as an image; `None` otherwise.
    format: Option<PixelFormat>,
    /// Separate chroma plane of a multi-plane image, as attached by
    /// `from_planes`; `None` otherwise.
    chroma: Option<Box<Tensor<T>>>,
    /// Explicit row stride in bytes, when one has been set.
    row_stride: Option<usize>,
    /// Explicit plane offset, when one has been set.
    /// NOTE(review): presumably a byte offset — confirm against callers.
    plane_offset: Option<usize>,
    /// Quantization parameters, if any.
    pub(crate) quantization: Option<Quantization>,
}
1142
1143impl<T> Tensor<T>
1144where
1145 T: Num + Clone + fmt::Debug + Send + Sync,
1146{
1147 pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
1149 Self {
1150 storage,
1151 format: None,
1152 chroma: None,
1153 row_stride: None,
1154 plane_offset: None,
1155 quantization: None,
1156 }
1157 }
1158
1159 pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
1168 where
1169 T: Copy,
1170 {
1171 let expected: usize = shape.iter().product();
1172 if values.len() != expected {
1173 return Err(Error::InvalidShape(format!(
1174 "from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
1175 values.len()
1176 )));
1177 }
1178 let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
1179 {
1180 let mut m = t.map()?;
1181 m.as_mut_slice().copy_from_slice(values);
1182 }
1183 Ok(t)
1184 }
1185
1186 #[cfg(feature = "ndarray")]
1191 pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
1192 where
1193 T: Copy,
1194 {
1195 let (h, w, c) = view.dim();
1196 let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
1197 {
1198 let mut m = t.map()?;
1199 let dst = m.as_mut_slice();
1200 if let Some(src) = view.as_slice() {
1201 dst.copy_from_slice(src);
1202 } else {
1203 for (d, &s) in dst.iter_mut().zip(view.iter()) {
1204 *d = s;
1205 }
1206 }
1207 }
1208 Ok(t)
1209 }
1210
1211 pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
1236 let _span = tracing::trace_span!(
1237 "tensor.alloc",
1238 ?shape,
1239 memory = ?memory,
1240 dtype = std::any::type_name::<T>(),
1241 )
1242 .entered();
1243 TensorStorage::new(shape, memory, name).map(Self::wrap)
1244 }
1245
1246 pub fn image(
1248 width: usize,
1249 height: usize,
1250 format: PixelFormat,
1251 memory: Option<TensorMemory>,
1252 ) -> Result<Self> {
1253 let shape = match format.layout() {
1254 PixelLayout::Packed => vec![height, width, format.channels()],
1255 PixelLayout::Planar => vec![format.channels(), height, width],
1256 PixelLayout::SemiPlanar => {
1257 let total_h = match format {
1261 PixelFormat::Nv12 => {
1262 if !height.is_multiple_of(2) {
1263 return Err(Error::InvalidArgument(format!(
1264 "NV12 requires even height, got {height}"
1265 )));
1266 }
1267 height * 3 / 2
1268 }
1269 PixelFormat::Nv16 => height * 2,
1270 _ => {
1271 return Err(Error::InvalidArgument(format!(
1272 "unknown semi-planar height multiplier for {format:?}"
1273 )))
1274 }
1275 };
1276 vec![total_h, width]
1277 }
1278 };
1279
1280 #[cfg(target_os = "macos")]
1293 if matches!(memory, Some(TensorMemory::Dma)) {
1294 let natural_row_bytes = width * format.channels() * std::mem::size_of::<T>();
1295 if natural_row_bytes.is_multiple_of(64) {
1296 if let Ok(storage) =
1297 TensorStorage::<T>::new_image_iosurface(width, height, format, &shape, None)
1298 {
1299 let mut t = Self::wrap(storage);
1300 t.format = Some(format);
1301 return Ok(t);
1302 }
1303 }
1304 }
1308
1309 let mut t = Self::new(&shape, memory, None)?;
1310 t.format = Some(format);
1311 Ok(t)
1312 }
1313
1314 pub fn image_with_stride(
1350 width: usize,
1351 height: usize,
1352 format: PixelFormat,
1353 row_stride_bytes: usize,
1354 memory: Option<TensorMemory>,
1355 ) -> Result<Self> {
1356 #[cfg(not(target_os = "linux"))]
1366 {
1367 let _ = (width, height, format, row_stride_bytes, memory);
1368 Err(Error::NotImplemented(
1369 "image_with_stride requires DMA support (Linux only)".to_owned(),
1370 ))
1371 }
1372
1373 #[cfg(target_os = "linux")]
1374 {
1375 if format.layout() != PixelLayout::Packed {
1376 return Err(Error::NotImplemented(format!(
1377 "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
1378 )));
1379 }
1380 let elem = std::mem::size_of::<T>();
1381 let min_stride = width
1382 .checked_mul(format.channels())
1383 .and_then(|p| p.checked_mul(elem))
1384 .ok_or_else(|| {
1385 Error::InvalidArgument(format!(
1386 "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
1387 overflows usize",
1388 format.channels()
1389 ))
1390 })?;
1391 if row_stride_bytes < min_stride {
1392 return Err(Error::InvalidArgument(format!(
1393 "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
1394 ({width} px × {} ch × {elem} B)",
1395 format.channels()
1396 )));
1397 }
1398 let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
1399 Error::InvalidArgument(format!(
1400 "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
1401 ))
1402 })?;
1403
1404 let shape = vec![height, width, format.channels()];
1405
1406 let storage = match memory {
1407 Some(TensorMemory::Dma) | None => {
1408 TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
1409 }
1410 Some(other) => {
1411 return Err(Error::NotImplemented(format!(
1412 "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
1413 )));
1414 }
1415 };
1416
1417 let mut t = Self::wrap(storage);
1418 t.format = Some(format);
1419 t.row_stride = Some(row_stride_bytes);
1420 Ok(t)
1421 }
1422 }
1423
1424 pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
1441 let shape = self.shape();
1442 match format.layout() {
1443 PixelLayout::Packed => {
1444 if shape.len() != 3 || shape[2] != format.channels() {
1445 return Err(Error::InvalidShape(format!(
1446 "packed format {format:?} expects [H, W, {}], got {shape:?}",
1447 format.channels()
1448 )));
1449 }
1450 }
1451 PixelLayout::Planar => {
1452 if shape.len() != 3 || shape[0] != format.channels() {
1453 return Err(Error::InvalidShape(format!(
1454 "planar format {format:?} expects [{}, H, W], got {shape:?}",
1455 format.channels()
1456 )));
1457 }
1458 }
1459 PixelLayout::SemiPlanar => {
1460 if shape.len() != 2 {
1461 return Err(Error::InvalidShape(format!(
1462 "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
1463 )));
1464 }
1465 match format {
1466 PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
1467 return Err(Error::InvalidShape(format!(
1468 "NV12 contiguous shape[0] must be divisible by 3, got {}",
1469 shape[0]
1470 )));
1471 }
1472 PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
1473 return Err(Error::InvalidShape(format!(
1474 "NV16 contiguous shape[0] must be even, got {}",
1475 shape[0]
1476 )));
1477 }
1478 _ => {}
1479 }
1480 }
1481 }
1482 if self.format != Some(format) {
1485 self.row_stride = None;
1486 self.plane_offset = None;
1487 #[cfg(target_os = "linux")]
1488 if let TensorStorage::Dma(ref mut dma) = self.storage {
1489 dma.mmap_offset = 0;
1490 }
1491 }
1492 self.format = Some(format);
1493 Ok(())
1494 }
1495
    /// Returns the pixel format attached to this tensor, or `None` when no
    /// format has been set.
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }
1500
1501 pub fn width(&self) -> Option<usize> {
1503 let fmt = self.format?;
1504 let shape = self.shape();
1505 match fmt.layout() {
1506 PixelLayout::Packed => Some(shape[1]),
1507 PixelLayout::Planar => Some(shape[2]),
1508 PixelLayout::SemiPlanar => Some(shape[1]),
1509 }
1510 }
1511
1512 pub fn height(&self) -> Option<usize> {
1514 let fmt = self.format?;
1515 let shape = self.shape();
1516 match fmt.layout() {
1517 PixelLayout::Packed => Some(shape[0]),
1518 PixelLayout::Planar => Some(shape[1]),
1519 PixelLayout::SemiPlanar => {
1520 if self.is_multiplane() {
1521 Some(shape[0])
1522 } else {
1523 match fmt {
1524 PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
1525 PixelFormat::Nv16 => Some(shape[0] / 2),
1526 _ => None,
1527 }
1528 }
1529 }
1530 }
1531 }
1532
1533 pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
1535 if format.layout() != PixelLayout::SemiPlanar {
1536 return Err(Error::InvalidArgument(format!(
1537 "from_planes requires a semi-planar format, got {format:?}"
1538 )));
1539 }
1540 if chroma.format.is_some() || chroma.chroma.is_some() {
1541 return Err(Error::InvalidArgument(
1542 "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
1543 ));
1544 }
1545 let luma_shape = luma.shape();
1546 let chroma_shape = chroma.shape();
1547 if luma_shape.len() != 2 || chroma_shape.len() != 2 {
1548 return Err(Error::InvalidArgument(format!(
1549 "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
1550 )));
1551 }
1552 if luma_shape[1] != chroma_shape[1] {
1553 return Err(Error::InvalidArgument(format!(
1554 "luma width {} != chroma width {}",
1555 luma_shape[1], chroma_shape[1]
1556 )));
1557 }
1558 match format {
1559 PixelFormat::Nv12 => {
1560 if luma_shape[0] % 2 != 0 {
1561 return Err(Error::InvalidArgument(format!(
1562 "NV12 requires even luma height, got {}",
1563 luma_shape[0]
1564 )));
1565 }
1566 if chroma_shape[0] != luma_shape[0] / 2 {
1567 return Err(Error::InvalidArgument(format!(
1568 "NV12 chroma height {} != luma height / 2 ({})",
1569 chroma_shape[0],
1570 luma_shape[0] / 2
1571 )));
1572 }
1573 }
1574 PixelFormat::Nv16 => {
1575 if chroma_shape[0] != luma_shape[0] {
1576 return Err(Error::InvalidArgument(format!(
1577 "NV16 chroma height {} != luma height {}",
1578 chroma_shape[0], luma_shape[0]
1579 )));
1580 }
1581 }
1582 _ => {
1583 return Err(Error::InvalidArgument(format!(
1584 "from_planes only supports NV12 and NV16, got {format:?}"
1585 )));
1586 }
1587 }
1588
1589 Ok(Tensor {
1590 storage: luma.storage,
1591 format: Some(format),
1592 chroma: Some(Box::new(chroma)),
1593 row_stride: luma.row_stride,
1594 plane_offset: luma.plane_offset,
1595 quantization: luma.quantization,
1596 })
1597 }
1598
    /// Returns `true` when this tensor carries a separate chroma plane.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }
1603
1604 pub fn chroma(&self) -> Option<&Tensor<T>> {
1606 self.chroma.as_deref()
1607 }
1608
1609 pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
1611 self.chroma.as_deref_mut()
1612 }
1613
1614 pub fn row_stride(&self) -> Option<usize> {
1616 self.row_stride
1617 }
1618
1619 pub fn effective_row_stride(&self) -> Option<usize> {
1624 if let Some(s) = self.row_stride {
1625 return Some(s);
1626 }
1627 let fmt = self.format?;
1628 let w = self.width()?;
1629 let elem = std::mem::size_of::<T>();
1630 Some(match fmt.layout() {
1631 PixelLayout::Packed => w * fmt.channels() * elem,
1632 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1633 })
1634 }
1635
1636 pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
1663 let fmt = self.format.ok_or_else(|| {
1664 Error::InvalidArgument("cannot set row_stride without a pixel format".into())
1665 })?;
1666 let w = self.width().ok_or_else(|| {
1667 Error::InvalidArgument("cannot determine width for row_stride validation".into())
1668 })?;
1669 let elem = std::mem::size_of::<T>();
1670 let min_stride = match fmt.layout() {
1671 PixelLayout::Packed => w * fmt.channels() * elem,
1672 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1673 };
1674 if stride < min_stride {
1675 return Err(Error::InvalidArgument(format!(
1676 "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
1677 )));
1678 }
1679 self.row_stride = Some(stride);
1680 Ok(())
1681 }
1682
1683 pub fn set_row_stride_unchecked(&mut self, stride: usize) {
1689 self.row_stride = Some(stride);
1690 }
1691
1692 pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
1699 self.set_row_stride(stride)?;
1700 Ok(self)
1701 }
1702
1703 pub fn plane_offset(&self) -> Option<usize> {
1705 self.plane_offset
1706 }
1707
1708 pub fn set_plane_offset(&mut self, offset: usize) {
1714 self.plane_offset = Some(offset);
1715 #[cfg(target_os = "linux")]
1716 if let TensorStorage::Dma(ref mut dma) = self.storage {
1717 dma.mmap_offset = offset;
1718 }
1719 }
1720
1721 pub fn with_plane_offset(mut self, offset: usize) -> Self {
1724 self.set_plane_offset(offset);
1725 self
1726 }
1727
1728 pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1730 match &self.storage {
1731 TensorStorage::Pbo(p) => Some(p),
1732 _ => None,
1733 }
1734 }
1735
1736 #[cfg(target_os = "linux")]
1738 pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1739 match &self.storage {
1740 TensorStorage::Dma(d) => Some(d),
1741 _ => None,
1742 }
1743 }
1744
1745 #[cfg(target_os = "linux")]
1756 pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1757 use std::os::fd::AsFd;
1758 match &self.storage {
1759 TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1760 _ => Err(Error::NotImplemented(format!(
1761 "dmabuf requires DMA-backed tensor, got {:?}",
1762 self.storage.memory()
1763 ))),
1764 }
1765 }
1766
1767 pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1769 Self {
1770 storage: TensorStorage::Pbo(pbo),
1771 format: None,
1772 chroma: None,
1773 row_stride: None,
1774 plane_offset: None,
1775 quantization: None,
1776 }
1777 }
1778}
1779
1780impl<T> Tensor<T>
1784where
1785 T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
1786{
1787 pub fn quantization(&self) -> Option<&Quantization> {
1789 self.quantization.as_ref()
1790 }
1791
1792 pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
1796 q.validate(self.shape())?;
1797 self.quantization = Some(q);
1798 Ok(())
1799 }
1800
1801 pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
1807 self.set_quantization(q)?;
1808 Ok(self)
1809 }
1810
1811 pub fn clear_quantization(&mut self) {
1813 self.quantization = None;
1814 }
1815}
1816
1817impl<T> TensorTrait<T> for Tensor<T>
1818where
1819 T: Num + Clone + fmt::Debug + Send + Sync,
1820{
1821 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1822 where
1823 Self: Sized,
1824 {
1825 Self::new(shape, None, name)
1826 }
1827
1828 #[cfg(unix)]
1829 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1830 where
1831 Self: Sized,
1832 {
1833 Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1834 }
1835
1836 #[cfg(unix)]
1837 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1838 self.storage.clone_fd()
1839 }
1840
1841 fn memory(&self) -> TensorMemory {
1842 self.storage.memory()
1843 }
1844
1845 fn name(&self) -> String {
1846 self.storage.name()
1847 }
1848
1849 fn shape(&self) -> &[usize] {
1850 self.storage.shape()
1851 }
1852
1853 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1854 if self.chroma.is_some() {
1855 return Err(Error::InvalidOperation(
1856 "cannot reshape a multiplane tensor — decompose planes first".into(),
1857 ));
1858 }
1859 self.storage.reshape(shape)?;
1860 self.format = None;
1861 self.row_stride = None;
1862 self.plane_offset = None;
1863 #[cfg(target_os = "linux")]
1864 if let TensorStorage::Dma(ref mut dma) = self.storage {
1865 dma.mmap_offset = 0;
1866 }
1867 Ok(())
1868 }
1869
1870 fn map(&self) -> Result<TensorMap<T>> {
1871 let _span = tracing::trace_span!(
1872 "tensor.map",
1873 memory = ?self.storage.memory(),
1874 )
1875 .entered();
1876 #[cfg(target_os = "linux")]
1893 if let Some(stride) = self.row_stride {
1894 if let TensorStorage::Dma(dma) = &self.storage {
1895 if !dma.is_imported {
1896 let height = self.height().ok_or_else(|| {
1917 Error::InvalidOperation(
1918 "Tensor::map: strided DMA mapping requires a PixelFormat \
1919 so height() can be derived; set a format before mapping \
1920 or clear row_stride for raw tensor access"
1921 .into(),
1922 )
1923 })?;
1924 let total_bytes = stride.checked_mul(height).ok_or_else(|| {
1925 Error::InvalidOperation(format!(
1926 "Tensor::map: row_stride {stride} × height {height} overflows usize"
1927 ))
1928 })?;
1929 let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
1930 if total_bytes > available_bytes {
1931 return Err(Error::InvalidOperation(format!(
1932 "Tensor::map: strided mapping needs {total_bytes} bytes \
1933 but DMA buffer only has {available_bytes} available \
1934 (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
1935 the row_stride was likely set larger than the original allocation",
1936 dma.buf_size, dma.mmap_offset
1937 )));
1938 }
1939 return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
1940 }
1941 }
1942 return Err(Error::InvalidOperation(
1943 "CPU mapping of strided foreign tensors is not supported; \
1944 use GPU path only"
1945 .into(),
1946 ));
1947 }
1948 #[cfg(not(target_os = "linux"))]
1949 if self.row_stride.is_some() {
1950 return Err(Error::InvalidOperation(
1951 "CPU mapping of strided tensors is not supported on this \
1952 platform (DMA backing is Linux-only)"
1953 .into(),
1954 ));
1955 }
1956 if self.plane_offset.is_some_and(|o| o > 0) {
1960 #[cfg(target_os = "linux")]
1961 if !matches!(self.storage, TensorStorage::Dma(_)) {
1962 return Err(Error::InvalidOperation(
1963 "plane offset only supported for DMA tensors".into(),
1964 ));
1965 }
1966 #[cfg(not(target_os = "linux"))]
1967 return Err(Error::InvalidOperation(
1968 "plane offset only supported for DMA tensors".into(),
1969 ));
1970 }
1971 self.storage.map()
1972 }
1973
1974 fn buffer_identity(&self) -> &BufferIdentity {
1975 self.storage.buffer_identity()
1976 }
1977}
1978
/// A mapped view of a tensor's contents, with one variant per storage backend
/// compiled in for the target platform.
///
/// The element data is reachable as a slice through the `TensorMapTrait`,
/// `Deref` and `DerefMut` implementations on this type.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Mapping of DMA-backed storage (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// Mapping of IOSurface-backed storage (macOS only).
    #[cfg(target_os = "macos")]
    IoSurface(IoSurfaceMap<T>),
    /// Mapping of shared-memory storage (unix only).
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Mapping wrapping a [`MemMap`].
    Mem(MemMap<T>),
    /// Mapping wrapping a [`PboMap`].
    Pbo(PboMap<T>),
}
1992
1993impl<T> TensorMapTrait<T> for TensorMap<T>
1994where
1995 T: Num + Clone + fmt::Debug,
1996{
1997 fn shape(&self) -> &[usize] {
1998 match self {
1999 #[cfg(target_os = "linux")]
2000 TensorMap::Dma(map) => map.shape(),
2001 #[cfg(target_os = "macos")]
2002 TensorMap::IoSurface(map) => map.shape(),
2003 #[cfg(unix)]
2004 TensorMap::Shm(map) => map.shape(),
2005 TensorMap::Mem(map) => map.shape(),
2006 TensorMap::Pbo(map) => map.shape(),
2007 }
2008 }
2009
2010 fn unmap(&mut self) {
2011 match self {
2012 #[cfg(target_os = "linux")]
2013 TensorMap::Dma(map) => map.unmap(),
2014 #[cfg(target_os = "macos")]
2015 TensorMap::IoSurface(map) => map.unmap(),
2016 #[cfg(unix)]
2017 TensorMap::Shm(map) => map.unmap(),
2018 TensorMap::Mem(map) => map.unmap(),
2019 TensorMap::Pbo(map) => map.unmap(),
2020 }
2021 }
2022
2023 fn as_slice(&self) -> &[T] {
2024 match self {
2025 #[cfg(target_os = "linux")]
2026 TensorMap::Dma(map) => map.as_slice(),
2027 #[cfg(target_os = "macos")]
2028 TensorMap::IoSurface(map) => map.deref(),
2029 #[cfg(unix)]
2030 TensorMap::Shm(map) => map.as_slice(),
2031 TensorMap::Mem(map) => map.as_slice(),
2032 TensorMap::Pbo(map) => map.as_slice(),
2033 }
2034 }
2035
2036 fn as_mut_slice(&mut self) -> &mut [T] {
2037 match self {
2038 #[cfg(target_os = "linux")]
2039 TensorMap::Dma(map) => map.as_mut_slice(),
2040 #[cfg(target_os = "macos")]
2041 TensorMap::IoSurface(map) => map.deref_mut(),
2042 #[cfg(unix)]
2043 TensorMap::Shm(map) => map.as_mut_slice(),
2044 TensorMap::Mem(map) => map.as_mut_slice(),
2045 TensorMap::Pbo(map) => map.as_mut_slice(),
2046 }
2047 }
2048}
2049
2050impl<T> Deref for TensorMap<T>
2051where
2052 T: Num + Clone + fmt::Debug,
2053{
2054 type Target = [T];
2055
2056 fn deref(&self) -> &[T] {
2057 match self {
2058 #[cfg(target_os = "linux")]
2059 TensorMap::Dma(map) => map.deref(),
2060 #[cfg(target_os = "macos")]
2061 TensorMap::IoSurface(map) => map.deref(),
2062 #[cfg(unix)]
2063 TensorMap::Shm(map) => map.deref(),
2064 TensorMap::Mem(map) => map.deref(),
2065 TensorMap::Pbo(map) => map.deref(),
2066 }
2067 }
2068}
2069
2070impl<T> DerefMut for TensorMap<T>
2071where
2072 T: Num + Clone + fmt::Debug,
2073{
2074 fn deref_mut(&mut self) -> &mut [T] {
2075 match self {
2076 #[cfg(target_os = "linux")]
2077 TensorMap::Dma(map) => map.deref_mut(),
2078 #[cfg(target_os = "macos")]
2079 TensorMap::IoSurface(map) => map.deref_mut(),
2080 #[cfg(unix)]
2081 TensorMap::Shm(map) => map.deref_mut(),
2082 TensorMap::Mem(map) => map.deref_mut(),
2083 TensorMap::Pbo(map) => map.deref_mut(),
2084 }
2085 }
2086}
2087
/// Process-wide cache for the DMA availability probe; initialised at most once
/// by `is_dma_available`.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
/// Process-wide cache for the IOSurface availability probe; initialised at
/// most once by `is_iosurface_available`.
#[cfg(target_os = "macos")]
static IOSURFACE_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2098
2099#[cfg(target_os = "linux")]
2108pub fn is_dma_available() -> bool {
2109 *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
2110}
2111
/// Non-Linux stand-in for the DMA availability probe: always returns `false`.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
2117
/// Reports whether the macOS GPU-shareable tensor backing can be allocated in
/// this process.
///
/// The first call probes by allocating a 64-element `u8` tensor with
/// [`TensorMemory::Dma`]; the outcome is cached, so later calls never probe
/// again.
///
/// NOTE(review): presumably `TensorMemory::Dma` selects IOSurface backing on
/// macOS — confirm against the allocator.
#[cfg(target_os = "macos")]
pub fn is_iosurface_available() -> bool {
    let probe = || Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok();
    let available: &bool = IOSURFACE_AVAILABLE.get_or_init(probe);
    *available
}
2134
/// Non-macOS stand-in for the IOSurface availability probe: always returns
/// `false`.
#[cfg(not(target_os = "macos"))]
pub fn is_iosurface_available() -> bool {
    false
}
2139
2140pub fn is_gpu_buffer_available() -> bool {
2146 #[cfg(target_os = "linux")]
2147 {
2148 is_dma_available()
2149 }
2150 #[cfg(target_os = "macos")]
2151 {
2152 is_iosurface_available()
2153 }
2154 #[cfg(not(any(target_os = "linux", target_os = "macos")))]
2155 {
2156 false
2157 }
2158}
2159
/// Process-wide cache for the shared-memory availability probe; initialised at
/// most once by `is_shm_available`.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2168
2169#[cfg(unix)]
2171pub fn is_shm_available() -> bool {
2172 *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
2173}
2174
/// Non-unix stand-in for the shared-memory availability probe: always returns
/// `false`.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
2182
/// Unit tests for [`DType`]: element sizes, names and serde round-tripping.
#[cfg(test)]
mod dtype_tests {
    use super::*;

    // Every element type reports its size in bytes.
    #[test]
    fn dtype_size() {
        let expected_sizes = [
            (DType::U8, 1),
            (DType::I8, 1),
            (DType::U16, 2),
            (DType::I16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
            (DType::U64, 8),
            (DType::I64, 8),
            (DType::F16, 2),
            (DType::F32, 4),
            (DType::F64, 8),
        ];
        for (dtype, bytes) in expected_sizes {
            assert_eq!(dtype.size(), bytes);
        }
    }

    // Spot-check the lowercase names of a few element types.
    #[test]
    fn dtype_name() {
        let expected_names = [(DType::U8, "u8"), (DType::F16, "f16"), (DType::F32, "f32")];
        for (dtype, label) in expected_names {
            assert_eq!(dtype.name(), label);
        }
    }

    // Serialising to JSON and back yields the same variant.
    #[test]
    fn dtype_serde_roundtrip() {
        use serde_json;
        let original = DType::F16;
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
2218
/// Tests for the image-related metadata on [`Tensor`]: pixel format, strides,
/// multiplane construction and how `reshape` interacts with them.
#[cfg(test)]
mod image_tests {
    use super::*;

    // A tensor created without a pixel format exposes no image metadata.
    #[test]
    fn raw_tensor_has_no_format() {
        let raw = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert_eq!(raw.format(), None);
        assert_eq!(raw.width(), None);
        assert_eq!(raw.height(), None);
        assert!(!raw.is_multiplane());
        assert!(raw.chroma().is_none());
    }

    // Packed RGBA images are laid out as [height, width, channels].
    #[test]
    fn image_tensor_packed() {
        let image = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(image.format(), Some(PixelFormat::Rgba));
        assert_eq!(image.width(), Some(640));
        assert_eq!(image.height(), Some(480));
        assert_eq!(image.shape(), &[480, 640, 4]);
        assert!(!image.is_multiplane());
    }

    // Planar RGB images are laid out as [channels, height, width].
    #[test]
    fn image_tensor_planar() {
        let image = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(image.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(image.width(), Some(640));
        assert_eq!(image.height(), Some(480));
        assert_eq!(image.shape(), &[3, 480, 640]);
    }

    // A contiguous NV12 image is a single [height * 3 / 2, width] plane and
    // therefore not multiplane.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let image = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(image.format(), Some(PixelFormat::Nv12));
        assert_eq!(image.width(), Some(640));
        assert_eq!(image.height(), Some(480));
        assert_eq!(image.shape(), &[720, 640]);
        assert!(!image.is_multiplane());
    }

    // A padded row stride must not leak into the logical width/shape, and the
    // mapping must cover stride * height bytes so padded rows are addressable.
    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_preserves_logical_width() {
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        let stride = 12032;
        let image = Tensor::<u8>::image_with_stride(
            3004,
            1688,
            PixelFormat::Rgba,
            stride,
            Some(TensorMemory::Dma),
        )
        .unwrap();
        assert_eq!(image.width(), Some(3004));
        assert_eq!(image.height(), Some(1688));
        assert_eq!(image.shape(), &[1688, 3004, 4]);
        assert_eq!(image.effective_row_stride(), Some(stride));
        use crate::TensorMapTrait;

        // The mapping spans every padded row.
        {
            let mapping = image.map().unwrap();
            assert!(
                mapping.as_slice().len() >= stride * 1688,
                "mapped buffer {} bytes < expected {}",
                mapping.as_slice().len(),
                stride * 1688
            );
        }

        // Write a position-dependent pattern into the 3004 visible pixels of
        // each of the 1688 rows, leaving the row padding untouched.
        {
            let mut mapping = image.map().unwrap();
            let rows = mapping.as_mut_slice().chunks_mut(stride).take(1688);
            for (y, row) in rows.enumerate() {
                for (x, pixel) in row.chunks_exact_mut(4).take(3004).enumerate() {
                    pixel.copy_from_slice(&[(y & 0xFF) as u8, (x & 0xFF) as u8, 0x42, 0xFF]);
                }
            }
        }

        // A fresh mapping sees the pattern at stride-based offsets.
        {
            let mapping = image.map().unwrap();
            let bytes = mapping.as_slice();
            assert_eq!(bytes[0], 0x00);
            assert_eq!(bytes[1], 0x00);
            assert_eq!(bytes[2], 0x42);
            assert_eq!(bytes[3], 0xFF);
            let probe = 100 * stride + 50 * 4;
            assert_eq!(bytes[probe], 100);
            assert_eq!(bytes[probe + 1], 50);
            assert_eq!(bytes[probe + 2], 0x42);
        }
    }

    // A DMA tensor rebuilt from a file descriptor must refuse CPU mapping once
    // a row stride is set on it.
    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_foreign_strided_map() {
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
        let shared_fd = backing.clone_fd().unwrap();
        let dims = [240usize, 320, 4];
        let imported = TensorStorage::<u8>::from_fd(shared_fd, &dims, None).unwrap();
        let mut foreign = Tensor::<u8>::wrap(imported);
        foreign.set_format(PixelFormat::Bgra).unwrap();
        foreign.set_row_stride(320 * 4).unwrap();
        let outcome = foreign.map();
        assert!(
            matches!(outcome, Err(Error::InvalidOperation(_))),
            "foreign strided map should error"
        );
    }

    // Raising the stride after allocation must be caught by map(): the strided
    // extent would no longer fit in the buffer.
    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_map_rejects_tampered_stride() {
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        let mut image = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Rgba,
            3072,
            Some(TensorMemory::Dma),
        )
        .unwrap();
        image.set_row_stride(12288).unwrap();
        let outcome = image.map();
        assert!(
            matches!(outcome, Err(Error::InvalidOperation(_))),
            "map() with oversized stride must return InvalidOperation"
        );
    }

    // A shape whose element count overflows must be rejected as an invalid
    // argument. The body only exists on Linux.
    #[test]
    fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
        #[cfg(target_os = "linux")]
        {
            let outcome = crate::dma::DmaTensor::<u64>::new_with_byte_size(
                &[usize::MAX, 2, 2],
                usize::MAX,
                None,
            );
            assert!(
                matches!(outcome, Err(Error::InvalidArgument(_))),
                "new_with_byte_size must detect shape.product() overflow"
            );
        }
    }

    // A stride of 2400 bytes is below the 640 * 4 bytes an RGBA row needs.
    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_too_small_stride() {
        let outcome = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Rgba,
            2400,
            Some(TensorMemory::Dma),
        );
        assert!(matches!(outcome, Err(Error::InvalidArgument(_))));
    }

    // Strided allocation of a non-packed format (NV12) is reported as not
    // implemented.
    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_non_packed() {
        let outcome = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Nv12,
            640,
            Some(TensorMemory::Dma),
        );
        assert!(matches!(outcome, Err(Error::NotImplemented(_))));
    }

    // Attaching a format that matches the shape makes width/height available.
    #[test]
    fn set_format_valid() {
        let mut tensor = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert_eq!(tensor.format(), None);
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Rgb));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    // A format whose channel count disagrees with the shape is rejected and
    // leaves the tensor without a format.
    #[test]
    fn set_format_invalid_shape() {
        let mut tensor = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        assert!(tensor.set_format(PixelFormat::Rgb).is_err());
        assert_eq!(tensor.format(), None);
    }

    // Reshaping discards the pixel format.
    #[test]
    fn reshape_clears_format() {
        let mut image = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(image.format(), Some(PixelFormat::Rgba));
        image.reshape(&[480 * 640 * 4]).unwrap();
        assert_eq!(image.format(), None);
    }

    // Separate Y and UV planes combine into a multiplane NV12 image that
    // reports the luma plane's dimensions.
    #[test]
    fn from_planes_nv12() {
        let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let image = Tensor::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        assert_eq!(image.format(), Some(PixelFormat::Nv12));
        assert!(image.is_multiplane());
        assert!(image.chroma().is_some());
        assert_eq!(image.width(), Some(640));
        assert_eq!(image.height(), Some(480));
    }

    // from_planes only accepts semi-planar formats.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let outcome = Tensor::from_planes(luma, chroma, PixelFormat::Rgb);
        assert!(outcome.is_err());
    }

    // Multiplane tensors cannot be reshaped.
    #[test]
    fn reshape_multiplane_errors() {
        let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut image = Tensor::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        let outcome = image.reshape(&[480 * 640 + 240 * 640]);
        assert!(outcome.is_err());
    }
}
2503
2504#[cfg(test)]
2505mod tests {
2506 #[cfg(target_os = "linux")]
2507 use nix::unistd::{access, AccessFlags};
2508 #[cfg(target_os = "linux")]
2509 use std::io::Write as _;
2510 use std::sync::RwLock;
2511
2512 use super::*;
2513
2514 #[ctor::ctor]
2515 fn init() {
2516 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
2517 }
2518
    // Expands to the name of the function the macro is invoked in, for use in
    // "SKIPPED" log messages.
    //
    // It declares a nested `fn f`, asks `std::any::type_name` for that item's
    // type name (which typically looks like `path::to::enclosing_fn::f`), and
    // slices the enclosing function's name out of it.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // Drop the trailing 3 characters (the `::f` suffix), then keep
            // whatever follows the last ':' of the remaining path.
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
2536
2537 #[test]
2538 #[cfg(target_os = "linux")]
2539 fn test_tensor() {
2540 let _lock = FD_LOCK.read().unwrap();
2541 let shape = vec![1];
2542 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
2543 let dma_enabled = tensor.is_ok();
2544
2545 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2546 match dma_enabled {
2547 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
2548 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
2549 }
2550 }
2551
2552 #[test]
2553 #[cfg(target_os = "macos")]
2554 fn test_tensor() {
2555 let shape = vec![1];
2556 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2557 let m = tensor.memory();
2562 assert!(
2563 matches!(m, TensorMemory::Dma | TensorMemory::Shm | TensorMemory::Mem),
2564 "Unexpected auto-fallback result on macOS: {m:?}"
2565 );
2566 }
2567
2568 #[test]
2569 #[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))]
2570 fn test_tensor() {
2571 let shape = vec![1];
2572 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2573 assert!(
2575 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
2576 "Expected SHM or Mem, got {:?}",
2577 tensor.memory()
2578 );
2579 }
2580
2581 #[test]
2582 #[cfg(not(unix))]
2583 fn test_tensor() {
2584 let shape = vec![1];
2585 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2586 assert_eq!(tensor.memory(), TensorMemory::Mem);
2587 }
2588
    // End-to-end check of `DmaTensor`: metadata, mapping, sharing through a
    // duplicated file descriptor, and reshape validation.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Skip (with a warning on stdout) when neither DMA heap device is
        // readable and writable.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        // Metadata reflects the constructor arguments.
        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Writes through a mapping are readable through the same mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // A second tensor built from a duplicated fd is recognised as DMA and
        // overwrites the contents with DUMMY_VALUE.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor observes the write made through the shared fd.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: a shape with a different element count is rejected and
        // leaves the shape untouched; an equal-count shape is accepted.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Contents persist across mappings, and single elements can be
        // written by index.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
2686
    // End-to-end check of `ShmTensor`: metadata, mapping, sharing through a
    // duplicated file descriptor, and reshape validation.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Writes through a mapping are readable through the same mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // A second tensor built from a duplicated fd is recognised as shared
        // memory and overwrites the contents with DUMMY_VALUE.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor observes the write made through the shared fd.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: a shape with a different element count is rejected and
        // leaves the shape untouched; an equal-count shape is accepted.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Contents persist across mappings, and single elements can be
        // written by index.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
2759
2760 #[test]
2761 fn test_mem_tensor() {
2762 let shape = vec![2, 3, 4];
2763 let tensor =
2764 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2765 assert_eq!(tensor.shape(), &shape);
2766 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2767 assert_eq!(tensor.name(), "test_tensor");
2768
2769 {
2770 let mut tensor_map = tensor.map().expect("Failed to map memory");
2771 tensor_map.fill(42.0);
2772 assert!(tensor_map.iter().all(|&x| x == 42.0));
2773 }
2774
2775 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2776 assert_eq!(tensor.shape(), &shape);
2777 let new_shape = vec![3, 4, 4];
2778 assert!(
2779 tensor.reshape(&new_shape).is_err(),
2780 "Reshape should fail due to size mismatch"
2781 );
2782 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2783
2784 let new_shape = vec![2, 3, 4];
2785 tensor.reshape(&new_shape).expect("Reshape should succeed");
2786 assert_eq!(
2787 tensor.shape(),
2788 &new_shape,
2789 "Shape should be updated after successful reshape"
2790 );
2791
2792 {
2793 let mut tensor_map = tensor.map().expect("Failed to map memory");
2794 tensor_map.fill(1);
2795 assert!(tensor_map.iter().all(|&x| x == 1));
2796 }
2797
2798 {
2799 let mut tensor_map = tensor.map().expect("Failed to map memory");
2800 tensor_map[2] = 42;
2801 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2802 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2803 }
2804 }
2805
2806 #[test]
2807 #[cfg(target_os = "linux")]
2808 fn test_dma_no_fd_leaks() {
2809 let _lock = FD_LOCK.write().unwrap();
2810 if !is_dma_available() {
2811 log::warn!(
2812 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2813 function!()
2814 );
2815 return;
2816 }
2817
2818 let proc = procfs::process::Process::myself()
2819 .expect("Failed to get current process using /proc/self");
2820
2821 let start_open_fds = proc
2822 .fd_count()
2823 .expect("Failed to get open file descriptor count");
2824
2825 for _ in 0..100 {
2826 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
2827 .expect("Failed to create tensor");
2828 let mut map = tensor.map().unwrap();
2829 map.as_mut_slice().fill(233);
2830 }
2831
2832 let end_open_fds = proc
2833 .fd_count()
2834 .expect("Failed to get open file descriptor count");
2835
2836 assert_eq!(
2837 start_open_fds, end_open_fds,
2838 "File descriptor leak detected: {} -> {}",
2839 start_open_fds, end_open_fds
2840 );
2841 }
2842
2843 #[test]
2844 #[cfg(target_os = "linux")]
2845 fn test_dma_from_fd_no_fd_leaks() {
2846 let _lock = FD_LOCK.write().unwrap();
2847 if !is_dma_available() {
2848 log::warn!(
2849 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2850 function!()
2851 );
2852 return;
2853 }
2854
2855 let proc = procfs::process::Process::myself()
2856 .expect("Failed to get current process using /proc/self");
2857
2858 let start_open_fds = proc
2859 .fd_count()
2860 .expect("Failed to get open file descriptor count");
2861
2862 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
2863
2864 for _ in 0..100 {
2865 let tensor =
2866 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2867 let mut map = tensor.map().unwrap();
2868 map.as_mut_slice().fill(233);
2869 }
2870 drop(orig);
2871
2872 let end_open_fds = proc.fd_count().unwrap();
2873
2874 assert_eq!(
2875 start_open_fds, end_open_fds,
2876 "File descriptor leak detected: {} -> {}",
2877 start_open_fds, end_open_fds
2878 );
2879 }
2880
/// Checks that allocating, mapping, and dropping SHM tensors in a loop leaves
/// the process's open-descriptor count unchanged.
///
/// Skips with a warning when `is_shm_available()` reports no SHM support.
#[test]
#[cfg(target_os = "linux")]
fn test_shm_no_fd_leaks() {
    // The lock guards `()`, so a poisoned guard carries no broken state.
    // Recovering it keeps a panic in one fd test from failing every later
    // test with an unrelated `PoisonError`.
    let _lock = FD_LOCK
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    if !is_shm_available() {
        log::warn!(
            "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
            function!()
        );
        return;
    }

    let proc = procfs::process::Process::myself()
        .expect("Failed to get current process using /proc/self");

    // Baseline sampled before any tensor is allocated.
    let start_open_fds = proc
        .fd_count()
        .expect("Failed to get open file descriptor count");

    for _ in 0..100 {
        // `tensor` and `map` go out of scope at the end of every iteration.
        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
            .expect("Failed to create tensor");
        let mut map = tensor.map().unwrap();
        map.as_mut_slice().fill(233);
    }

    let end_open_fds = proc
        .fd_count()
        .expect("Failed to get open file descriptor count");

    assert_eq!(
        start_open_fds, end_open_fds,
        "File descriptor leak detected: {} -> {}",
        start_open_fds, end_open_fds
    );
}
2917
/// Checks that repeatedly importing an SHM tensor from a duplicated file
/// descriptor, mapping it, and dropping it leaves the process's
/// open-descriptor count unchanged.
///
/// Skips with a warning when `is_shm_available()` reports no SHM support.
#[test]
#[cfg(target_os = "linux")]
fn test_shm_from_fd_no_fd_leaks() {
    // The lock guards `()`, so a poisoned guard carries no broken state.
    // Recovering it keeps a panic in one fd test from failing every later
    // test with an unrelated `PoisonError`.
    let _lock = FD_LOCK
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    if !is_shm_available() {
        log::warn!(
            "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
            function!()
        );
        return;
    }

    let proc = procfs::process::Process::myself()
        .expect("Failed to get current process using /proc/self");

    // Baseline sampled before any tensor is allocated.
    let start_open_fds = proc
        .fd_count()
        .expect("Failed to get open file descriptor count");

    let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

    for _ in 0..100 {
        // `tensor` and `map` go out of scope at the end of every iteration.
        let tensor =
            Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
        let mut map = tensor.map().unwrap();
        map.as_mut_slice().fill(233);
    }
    // Drop the source tensor as well so it is gone before the final count.
    drop(orig);

    let end_open_fds = proc
        .fd_count()
        .expect("Failed to get open file descriptor count");

    assert_eq!(
        start_open_fds, end_open_fds,
        "File descriptor leak detected: {} -> {}",
        start_open_fds, end_open_fds
    );
}
2955
/// Exercises the ndarray views exposed by a tensor mapping: the read-only
/// view reports the tensor's shape and contents, and a write through the
/// mutable view is visible when indexing the mapping directly.
#[cfg(feature = "ndarray")]
#[test]
fn test_ndarray() {
    // Only a panicking writer poisons an `RwLock`; the guarded value is `()`,
    // so recovering the guard is harmless and keeps a failure in one of the
    // fd-leak tests from also failing this test with a `PoisonError`.
    let _lock = FD_LOCK
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let shape = vec![2, 3, 4];
    let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

    let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
    tensor_map.fill(1.0);

    let view = tensor_map.view().expect("Failed to get ndarray view");
    assert_eq!(view.shape(), &[2, 3, 4]);
    assert!(view.iter().all(|&x| x == 1.0));

    let mut view_mut = tensor_map
        .view_mut()
        .expect("Failed to get mutable ndarray view");
    view_mut[[0, 0, 0]] = 42.0;
    assert_eq!(view_mut[[0, 0, 0]], 42.0);
    // The write through the view must be visible at linear index 0 of the map.
    assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
}
2977
/// Two independently constructed `BufferIdentity` values must report
/// different ids.
#[test]
fn test_buffer_identity_unique() {
    let first = BufferIdentity::new();
    let second = BufferIdentity::new();

    let first_id = first.id();
    let second_id = second.id();
    assert_ne!(
        first_id, second_id,
        "Two identities should have different ids"
    );
}
2988
/// A cloned `BufferIdentity` keeps the same id and keeps the weak handle
/// upgradable until the last clone is dropped.
#[test]
fn test_buffer_identity_clone_shares_guard() {
    let original = BufferIdentity::new();
    let observer = original.weak();

    let alive_initially = observer.upgrade().is_some();
    assert!(
        alive_initially,
        "Weak should be alive while original exists"
    );

    let duplicate = original.clone();
    assert_eq!(
        original.id(),
        duplicate.id(),
        "Cloned identity should have same id"
    );

    // Dropping the original alone must not invalidate the weak handle.
    drop(original);
    let alive_after_first_drop = observer.upgrade().is_some();
    assert!(
        alive_after_first_drop,
        "Weak should still be alive (clone holds Arc)"
    );

    // Once the last clone is gone the weak handle can no longer upgrade.
    drop(duplicate);
    let dead_after_last_drop = observer.upgrade().is_none();
    assert!(
        dead_after_last_drop,
        "Weak should be dead after all clones dropped"
    );
}
3013
/// Two separately allocated tensors must carry different buffer identity ids.
#[test]
fn test_tensor_buffer_identity() {
    let first = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
    let second = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();

    let first_id = first.buffer_identity().id();
    let second_id = second.buffer_identity().id();
    assert_ne!(
        first_id, second_id,
        "Different tensors should have different buffer ids"
    );
}
3024
/// Per-tensor constructors: the asymmetric form stores scale and zero point,
/// the symmetric form reports no zero point.
#[test]
fn test_quantization_per_tensor_constructors() {
    // Asymmetric: one scale plus an explicit zero point.
    {
        let asymmetric = Quantization::per_tensor(0.1, -5);
        assert!(asymmetric.is_per_tensor());
        assert!(!asymmetric.is_per_channel());
        assert!(!asymmetric.is_symmetric());
        assert_eq!(asymmetric.scale(), &[0.1]);
        assert_eq!(asymmetric.zero_point(), Some(&[-5][..]));
    }

    // Symmetric: one scale and no zero point.
    {
        let symmetric = Quantization::per_tensor_symmetric(0.05);
        assert!(symmetric.is_per_tensor());
        assert!(symmetric.is_symmetric());
        assert_eq!(symmetric.zero_point(), None);
    }
}
3043
/// Per-channel constructors: both forms report per-channel mode and the axis
/// they were built with; only the symmetric form reports as symmetric.
#[test]
fn test_quantization_per_channel_constructors() {
    // Asymmetric: three scales and three zero points along axis 2.
    let scales = vec![0.1, 0.2, 0.3];
    let zero_points = vec![0, -1, 1];
    let asymmetric = Quantization::per_channel(scales, zero_points, 2).unwrap();
    assert!(asymmetric.is_per_channel());
    assert!(!asymmetric.is_symmetric());
    assert_eq!(asymmetric.axis(), Some(2));
    assert_eq!(asymmetric.scale().len(), 3);

    // Symmetric: scales only, along axis 0.
    let symmetric_scales = vec![0.054, 0.089, 0.195];
    let symmetric = Quantization::per_channel_symmetric(symmetric_scales, 0).unwrap();
    assert!(symmetric.is_per_channel());
    assert!(symmetric.is_symmetric());
    assert_eq!(symmetric.axis(), Some(0));
}
3057
/// Two scales paired with three zero points must be rejected as
/// `Error::QuantizationInvalid`.
#[test]
fn test_quantization_per_channel_length_mismatch_rejected() {
    let outcome = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0);
    let rejection = outcome.unwrap_err();
    assert!(matches!(rejection, Error::QuantizationInvalid { .. }));
}
3064
/// An empty scale vector must be rejected as `Error::QuantizationInvalid`.
#[test]
fn test_quantization_per_channel_empty_rejected() {
    let outcome = Quantization::per_channel_symmetric(vec![], 0);
    let rejection = outcome.unwrap_err();
    assert!(matches!(rejection, Error::QuantizationInvalid { .. }));
}
3070
/// Quantization values that deserialize successfully but are malformed must
/// be rejected by `Tensor::set_quantization` with
/// `Error::QuantizationInvalid`.
#[test]
fn test_quantization_validate_rejects_malformed_deserialize() {
    let mut tensor = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();

    // Each payload parses as a `Quantization` but is expected to fail when
    // attached to the tensor.
    let malformed_payloads = [
        // Empty scale list.
        r#"{"scale": []}"#,
        // Scalar scale paired with three zero points.
        r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#,
        // Four scales paired with two zero points.
        r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
    ];

    for payload in malformed_payloads {
        let parsed: Quantization = serde_json::from_str(payload).unwrap();
        let rejection = tensor.set_quantization(parsed).unwrap_err();
        assert!(matches!(rejection, Error::QuantizationInvalid { .. }));
    }
}
3106
/// `Quantization::mode()` must report the `QuantMode` variant matching the
/// constructor that built the value.
#[test]
fn test_quantization_mode_dispatch() {
    let per_tensor = Quantization::per_tensor(0.1, -5);
    let is_per_tensor = matches!(
        per_tensor.mode(),
        QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
    );
    assert!(is_per_tensor);

    let per_tensor_sym = Quantization::per_tensor_symmetric(0.05);
    let is_per_tensor_sym = matches!(
        per_tensor_sym.mode(),
        QuantMode::PerTensorSymmetric { scale } if scale == 0.05
    );
    assert!(is_per_tensor_sym);

    let per_channel = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
    let is_per_channel = matches!(per_channel.mode(), QuantMode::PerChannel { axis: 2, .. });
    assert!(is_per_channel);

    let per_channel_sym = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
    let is_per_channel_sym = matches!(
        per_channel_sym.mode(),
        QuantMode::PerChannelSymmetric { axis: 0, .. }
    );
    assert!(is_per_channel_sym);
}
3130
/// An integer tensor starts without quantization, returns what was set, and
/// reports none again after `clear_quantization`.
#[test]
fn test_tensor_quantization_roundtrip_integer() {
    let mut tensor = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
    assert!(tensor.quantization().is_none());

    let params = Quantization::per_tensor(0.1, -5);
    tensor.set_quantization(params).unwrap();
    {
        // Scope the borrow so it ends before the tensor is mutated again.
        let stored = tensor.quantization().unwrap();
        assert_eq!(stored.scale(), &[0.1]);
    }

    tensor.clear_quantization();
    assert!(tensor.quantization().is_none());
}
3142
/// The consuming `with_quantization` builder returns a tensor that reports
/// quantization metadata.
#[test]
fn test_tensor_with_quantization_builder() {
    let plain = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None).unwrap();
    let params = Quantization::per_tensor_symmetric(0.05);
    let quantized = plain.with_quantization(params).unwrap();
    assert!(quantized.quantization().is_some());
}
3151
/// A `TensorDyn` wrapping an `f32` tensor reports no quantization.
#[test]
fn test_tensor_dyn_quantization_float_arm_returns_none() {
    let float_tensor = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
    let dynamic = TensorDyn::F32(float_tensor);
    let quantization = dynamic.quantization();
    assert!(quantization.is_none());
}
3158
/// Setting quantization on a `TensorDyn` wrapping an `f32` tensor must fail
/// with `Error::QuantizationInvalid`.
#[test]
fn test_tensor_dyn_set_quantization_float_arm_errors() {
    let float_tensor = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
    let mut dynamic = TensorDyn::F32(float_tensor);

    let params = Quantization::per_tensor(0.1, 0);
    let rejection = dynamic.set_quantization(params).unwrap_err();
    assert!(matches!(rejection, Error::QuantizationInvalid { .. }));
}
3169
// Intentionally empty function; nothing in this part of the file calls it.
// NOTE(review): the name suggests it exists only as an item to attach
// `compile_fail` doctests to — confirm against the accompanying doc comment.
fn _compile_fail_doctest_anchor() {}
3180
/// Lock shared by the tests in this module.
///
/// The fd-leak tests take the write guard while they compare the process's
/// open file descriptor count before and after their workload; other tests
/// (for example `test_ndarray`) take the read guard, so they may run alongside
/// each other but never alongside a leak test.
pub static FD_LOCK: RwLock<()> = RwLock::new(());
3185
/// On targets other than Linux, `is_dma_available()` must report `false`.
#[test]
#[cfg(not(target_os = "linux"))]
fn test_dma_not_available_on_non_linux() {
    let dma_available = is_dma_available();
    assert!(
        !dma_available,
        "DMA memory allocation should NOT be available on non-Linux platforms"
    );
}
3196
/// On Unix targets SHM must be reported as available, and an SHM-backed
/// tensor must be allocatable, mappable, writable, and readable.
#[test]
#[cfg(unix)]
fn test_shm_available_and_usable() {
    // SHM tensors are backed by a file descriptor, and the fd-leak tests
    // compare the process-wide descriptor count while holding the write guard.
    // Taking the read guard here keeps this test's descriptors from appearing
    // in the middle of such a measurement. The lock guards `()`, so a poisoned
    // guard can simply be recovered.
    let _lock = FD_LOCK
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    assert!(
        is_shm_available(),
        "SHM memory allocation should be available on Unix systems"
    );

    let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
        .expect("Failed to create SHM tensor");

    let mut map = tensor.map().expect("Failed to map SHM tensor");
    map.as_mut_slice().fill(0xAB);

    // Every byte written through the mapping must read back unchanged.
    assert!(
        map.as_slice().iter().all(|&b| b == 0xAB),
        "SHM tensor data should be writable and readable"
    );
}
3221}