1#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54 fmt,
55 ops::{Deref, DerefMut},
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc, Weak,
59 },
60};
61pub use tensor_dyn::TensorDyn;
62
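/// Owned file descriptor for one externally allocated image plane, together
/// with optional `stride` and `offset` hints supplied by the producer. The
/// sketch below is illustrative only; where `frame_fd` comes from (a capture
/// or allocator API) is an assumption, not part of this module.
///
/// ```ignore
/// // `frame_fd: BorrowedFd<'_>` is assumed to come from an external source.
/// let plane = PlaneDescriptor::new(frame_fd)?
///     .with_stride(4096)
///     .with_offset(0);
/// assert_eq!(plane.stride(), Some(4096));
/// let fd: OwnedFd = plane.into_fd();
/// ```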
63#[cfg(unix)]
83pub struct PlaneDescriptor {
84 fd: OwnedFd,
85 stride: Option<usize>,
86 offset: Option<usize>,
87}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91 pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101 let owned = fd.try_clone_to_owned()?;
102 Ok(Self {
103 fd: owned,
104 stride: None,
105 offset: None,
106 })
107 }
108
109 pub fn with_stride(mut self, stride: usize) -> Self {
111 self.stride = Some(stride);
112 self
113 }
114
115 pub fn with_offset(mut self, offset: usize) -> Self {
117 self.offset = Some(offset);
118 self
119 }
120
121 pub fn into_fd(self) -> OwnedFd {
123 self.fd
124 }
125
126 pub fn stride(&self) -> Option<usize> {
128 self.stride
129 }
130
131 pub fn offset(&self) -> Option<usize> {
133 self.offset
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
139#[repr(u8)]
140#[non_exhaustive]
141pub enum DType {
142 U8,
143 I8,
144 U16,
145 I16,
146 U32,
147 I32,
148 U64,
149 I64,
150 F16,
151 F32,
152 F64,
153}
154
155impl DType {
156 pub const fn size(&self) -> usize {
158 match self {
159 Self::U8 | Self::I8 => 1,
160 Self::U16 | Self::I16 | Self::F16 => 2,
161 Self::U32 | Self::I32 | Self::F32 => 4,
162 Self::U64 | Self::I64 | Self::F64 => 8,
163 }
164 }
165
166 pub const fn name(&self) -> &'static str {
168 match self {
169 Self::U8 => "u8",
170 Self::I8 => "i8",
171 Self::U16 => "u16",
172 Self::I16 => "i16",
173 Self::U32 => "u32",
174 Self::I32 => "i32",
175 Self::U64 => "u64",
176 Self::I64 => "i64",
177 Self::F16 => "f16",
178 Self::F32 => "f32",
179 Self::F64 => "f64",
180 }
181 }
182}
183
184impl fmt::Display for DType {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 f.write_str(self.name())
187 }
188}
189
190mod sealed {
198 pub trait Sealed {}
199 impl Sealed for u8 {}
200 impl Sealed for i8 {}
201 impl Sealed for u16 {}
202 impl Sealed for i16 {}
203 impl Sealed for u32 {}
204 impl Sealed for i32 {}
205 impl Sealed for u64 {}
206 impl Sealed for i64 {}
207 }
209
210pub trait IntegerType: sealed::Sealed {}
217impl IntegerType for u8 {}
218impl IntegerType for i8 {}
219impl IntegerType for u16 {}
220impl IntegerType for i16 {}
221impl IntegerType for u32 {}
222impl IntegerType for i32 {}
223impl IntegerType for u64 {}
224impl IntegerType for i64 {}
225
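/// Quantization parameters attached to integer tensors: one or more `scale`
/// values, optional `zero_point`s, and an optional per-channel `axis`. A
/// single scale with no axis is per-tensor; several scales plus an axis are
/// per-channel (see [`QuantMode`]). Sketch of the constructors defined below
/// (not compiled as a doctest; error handling elided):
///
/// ```ignore
/// // Per-tensor affine quantization: one scale, one zero point.
/// let q = Quantization::per_tensor(0.02, 128);
/// assert!(q.is_per_tensor() && !q.is_symmetric());
///
/// // Per-channel symmetric quantization along axis 1 of an assumed
/// // [N, C, H, W] tensor with three channels.
/// let q = Quantization::per_channel_symmetric(vec![0.1, 0.2, 0.4], 1)?;
/// match q.mode() {
///     QuantMode::PerChannelSymmetric { scales, axis } => {
///         assert_eq!((scales.len(), axis), (3, 1));
///     }
///     _ => unreachable!(),
/// }
/// ```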
226#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
250pub struct Quantization {
251 #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
253 scale: Vec<f32>,
254
255 #[serde(
258 default,
259 deserialize_with = "deserialize_opt_scalar_or_vec_i32",
260 skip_serializing_if = "Option::is_none"
261 )]
262 zero_point: Option<Vec<i32>>,
263
264 #[serde(default, skip_serializing_if = "Option::is_none")]
268 axis: Option<usize>,
269}
270
271#[derive(Debug, Clone, Copy)]
277pub enum QuantMode<'a> {
278 PerTensorSymmetric {
279 scale: f32,
280 },
281 PerTensor {
282 scale: f32,
283 zero_point: i32,
284 },
285 PerChannelSymmetric {
286 scales: &'a [f32],
287 axis: usize,
288 },
289 PerChannel {
290 scales: &'a [f32],
291 zero_points: &'a [i32],
292 axis: usize,
293 },
294}
295
296impl Quantization {
297 pub fn per_tensor_symmetric(scale: f32) -> Self {
299 Self {
300 scale: vec![scale],
301 zero_point: None,
302 axis: None,
303 }
304 }
305
306 pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
308 Self {
309 scale: vec![scale],
310 zero_point: Some(vec![zero_point]),
311 axis: None,
312 }
313 }
314
315 pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
317 if scales.is_empty() {
318 return Err(Error::QuantizationInvalid {
319 field: "scale.len",
320 expected: "non-empty per-channel scales".to_string(),
321 got: "length 0".to_string(),
322 });
323 }
324 Ok(Self {
325 scale: scales,
326 zero_point: None,
327 axis: Some(axis),
328 })
329 }
330
331 pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
334 if scales.is_empty() {
335 return Err(Error::QuantizationInvalid {
336 field: "scale.len",
337 expected: "non-empty per-channel scales".to_string(),
338 got: "length 0".to_string(),
339 });
340 }
341 if scales.len() != zero_points.len() {
342 return Err(Error::QuantizationInvalid {
343 field: "zero_point.len",
344 expected: format!("length matches scale ({})", scales.len()),
345 got: format!("length {}", zero_points.len()),
346 });
347 }
348 Ok(Self {
349 scale: scales,
350 zero_point: Some(zero_points),
351 axis: Some(axis),
352 })
353 }
354
355 pub fn mode(&self) -> QuantMode<'_> {
357 match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
358 (1, None, _) => QuantMode::PerTensorSymmetric {
359 scale: self.scale[0],
360 },
361 (1, Some(zps), _) => QuantMode::PerTensor {
362 scale: self.scale[0],
363 zero_point: zps.first().copied().unwrap_or(0),
364 },
365 (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
366 scales: &self.scale,
367 axis,
368 },
369 (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
370 scales: &self.scale,
371 zero_points: zps,
372 axis,
373 },
374 _ => {
380 debug_assert!(
381 false,
382 "Quantization::mode: per-channel without axis is unreachable"
383 );
384 QuantMode::PerTensorSymmetric {
385 scale: self.scale.first().copied().unwrap_or(1.0),
386 }
387 }
388 }
389 }
390
391 pub fn is_per_tensor(&self) -> bool {
393 self.scale.len() == 1
394 }
395
396 pub fn is_per_channel(&self) -> bool {
398 self.scale.len() > 1
399 }
400
401 pub fn is_symmetric(&self) -> bool {
404 match &self.zero_point {
405 None => true,
406 Some(zps) => zps.iter().all(|&z| z == 0),
407 }
408 }
409
410 pub fn scale(&self) -> &[f32] {
413 &self.scale
414 }
415
416 pub fn zero_point(&self) -> Option<&[i32]> {
418 self.zero_point.as_deref()
419 }
420
421 pub fn axis(&self) -> Option<usize> {
423 self.axis
424 }
425
426 pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
436 if self.scale.is_empty() {
441 return Err(Error::QuantizationInvalid {
442 field: "scale.len",
443 expected: ">= 1".to_string(),
444 got: "0".to_string(),
445 });
446 }
447 if let Some(zps) = self.zero_point.as_ref() {
448 let expected = if self.scale.len() == 1 {
451 1
452 } else {
453 self.scale.len()
454 };
455 if zps.len() != expected {
456 return Err(Error::QuantizationInvalid {
457 field: "zero_point.len",
458 expected: format!(
459 "{expected} (matching {})",
460 if self.scale.len() == 1 {
461 "per-tensor scale"
462 } else {
463 "per-channel scale.len"
464 }
465 ),
466 got: format!("length {}", zps.len()),
467 });
468 }
469 }
470
471 match (self.scale.len(), self.axis) {
472 (1, None) => Ok(()),
473 (1, Some(_)) => Err(Error::QuantizationInvalid {
474 field: "per_tensor_redundant_axis",
475 expected: "axis=None for per-tensor quantization".to_string(),
476 got: format!("axis={:?}", self.axis),
477 }),
478 (_, None) => Err(Error::QuantizationInvalid {
479 field: "per_channel_requires_axis",
480 expected: format!(
481 "axis=Some(_) for per-channel quantization (scale.len={})",
482 self.scale.len()
483 ),
484 got: "axis=None".to_string(),
485 }),
486 (n, Some(axis)) => {
487 if axis >= shape.len() {
488 return Err(Error::QuantizationInvalid {
489 field: "axis",
490 expected: format!("axis < tensor rank ({})", shape.len()),
491 got: format!("axis={axis}"),
492 });
493 }
494 if shape[axis] != n {
495 return Err(Error::QuantizationInvalid {
496 field: "scale.len",
497 expected: format!("length matches shape[{axis}] ({})", shape[axis]),
498 got: format!("length {n}"),
499 });
500 }
501 Ok(())
502 }
503 }
504 }
505}
506
507impl From<(f32, i32)> for Quantization {
508 fn from((scale, zero_point): (f32, i32)) -> Self {
512 Self::per_tensor(scale, zero_point)
513 }
514}
515
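/// Serde helper used by [`Quantization`]: the `scale` field accepts either a
/// bare number or an array, so both JSON forms below deserialize to the same
/// value. Sketch only (uses `serde_json`, which the tests already rely on):
///
/// ```ignore
/// let a: Quantization = serde_json::from_str(r#"{ "scale": 0.5 }"#)?;
/// let b: Quantization = serde_json::from_str(r#"{ "scale": [0.5] }"#)?;
/// assert_eq!(a, b);
/// ```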
516fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
517 de: D,
518) -> std::result::Result<Vec<f32>, D::Error> {
519 use serde::de::{self, Visitor};
520 struct V;
521 impl<'de> Visitor<'de> for V {
522 type Value = Vec<f32>;
523 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
524 f.write_str("f32 or array of f32")
525 }
526 fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
527 Ok(vec![v as f32])
528 }
529 #[allow(clippy::cast_possible_truncation)]
530 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
531 Ok(vec![v as f32])
532 }
533 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
534 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
535 Ok(vec![v as f32])
536 }
537 fn visit_seq<A: de::SeqAccess<'de>>(
538 self,
539 mut seq: A,
540 ) -> std::result::Result<Self::Value, A::Error> {
541 let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
542 while let Some(x) = seq.next_element::<f32>()? {
543 out.push(x);
544 }
545 Ok(out)
546 }
547 }
548 de.deserialize_any(V)
549}
550
551fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
552 de: D,
553) -> std::result::Result<Option<Vec<i32>>, D::Error> {
554 use serde::de::{self, Visitor};
555 struct V;
556 impl<'de> Visitor<'de> for V {
557 type Value = Option<Vec<i32>>;
558 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
559 f.write_str("null, i32, or array of i32")
560 }
561 fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
562 Ok(None)
563 }
564 fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
565 Ok(None)
566 }
567 fn visit_some<D2: serde::Deserializer<'de>>(
568 self,
569 de: D2,
570 ) -> std::result::Result<Self::Value, D2::Error> {
571 struct Inner;
572 impl<'de> Visitor<'de> for Inner {
573 type Value = Vec<i32>;
574 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
575 f.write_str("i32 or array of i32")
576 }
577 #[allow(clippy::cast_possible_truncation)]
578 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
579 Ok(vec![v as i32])
580 }
581 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
582 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
583 Ok(vec![v as i32])
584 }
585 fn visit_seq<A: de::SeqAccess<'de>>(
586 self,
587 mut seq: A,
588 ) -> std::result::Result<Self::Value, A::Error> {
589 let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
590 while let Some(x) = seq.next_element::<i32>()? {
591 out.push(x);
592 }
593 Ok(out)
594 }
595 }
596 de.deserialize_any(Inner).map(Some)
597 }
598 #[allow(clippy::cast_possible_truncation)]
599 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
600 Ok(Some(vec![v as i32]))
601 }
602 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
603 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
604 Ok(Some(vec![v as i32]))
605 }
606 fn visit_seq<A: de::SeqAccess<'de>>(
607 self,
608 mut seq: A,
609 ) -> std::result::Result<Self::Value, A::Error> {
610 let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
611 while let Some(x) = seq.next_element::<i32>()? {
612 out.push(x);
613 }
614 Ok(Some(out))
615 }
616 }
617 de.deserialize_option(V)
618}
619
620static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
622
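/// Process-unique identity of a tensor's backing buffer: a monotonically
/// increasing `id` drawn from `NEXT_BUFFER_ID`, plus an `Arc<()>` guard whose
/// `Weak` handle can be used to notice when the identity (and the buffer it
/// describes) has gone away. Sketch (illustrative only):
///
/// ```ignore
/// let ident = BufferIdentity::new();
/// let weak = ident.weak();
/// assert!(weak.upgrade().is_some()); // identity still alive
/// drop(ident);
/// assert!(weak.upgrade().is_none()); // guard released with the identity
/// ```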
623#[derive(Debug, Clone)]
629pub struct BufferIdentity {
630 id: u64,
631 guard: Arc<()>,
632}
633
634impl BufferIdentity {
635 pub fn new() -> Self {
637 Self {
638 id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
639 guard: Arc::new(()),
640 }
641 }
642
643 pub fn id(&self) -> u64 {
645 self.id
646 }
647
648 pub fn weak(&self) -> Weak<()> {
651 Arc::downgrade(&self.guard)
652 }
653}
654
655impl Default for BufferIdentity {
656 fn default() -> Self {
657 Self::new()
658 }
659}
660
661#[cfg(target_os = "linux")]
662use nix::sys::stat::{major, minor};
663
664pub trait TensorTrait<T>: Send + Sync
665where
666 T: Num + Clone + fmt::Debug,
667{
668 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
671 where
672 Self: Sized;
673
674 #[cfg(unix)]
675 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
681 where
682 Self: Sized;
683
684 #[cfg(unix)]
685 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
687
688 fn memory(&self) -> TensorMemory;
690
691 fn name(&self) -> String;
693
694 fn len(&self) -> usize {
696 self.shape().iter().product()
697 }
698
699 fn is_empty(&self) -> bool {
701 self.len() == 0
702 }
703
704 fn size(&self) -> usize {
706 self.len() * std::mem::size_of::<T>()
707 }
708
709 fn shape(&self) -> &[usize];
711
712 fn reshape(&mut self, shape: &[usize]) -> Result<()>;
715
716 fn map(&self) -> Result<TensorMap<T>>;
719
720 fn buffer_identity(&self) -> &BufferIdentity;
722}
723
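/// Common interface over a mapped tensor: raw `as_slice`/`as_mut_slice`
/// access plus, behind the `ndarray` feature, dynamically shaped views built
/// from `shape()`. Minimal sketch, assuming `t` is any tensor from this
/// module and the `ndarray` feature is enabled:
///
/// ```ignore
/// let map = t.map()?;
/// let _first = map.as_slice()[0].clone();
///
/// #[cfg(feature = "ndarray")]
/// {
///     let view = map.view()?; // dynamic-dimension ndarray::ArrayView
///     assert_eq!(view.shape(), map.shape());
/// }
/// ```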
724pub trait TensorMapTrait<T>
725where
726 T: Num + Clone + fmt::Debug,
727{
728 fn shape(&self) -> &[usize];
730
731 fn unmap(&mut self);
733
734 fn len(&self) -> usize {
736 self.shape().iter().product()
737 }
738
739 fn is_empty(&self) -> bool {
741 self.len() == 0
742 }
743
744 fn size(&self) -> usize {
746 self.len() * std::mem::size_of::<T>()
747 }
748
749 fn as_slice(&self) -> &[T];
751
752 fn as_mut_slice(&mut self) -> &mut [T];
754
755 #[cfg(feature = "ndarray")]
756 fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
758 Ok(ndarray::ArrayView::from_shape(
759 self.shape(),
760 self.as_slice(),
761 )?)
762 }
763
764 #[cfg(feature = "ndarray")]
765 fn view_mut(
767 &'_ mut self,
768 ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
769 let shape = self.shape().to_vec();
770 Ok(ndarray::ArrayViewMut::from_shape(
771 shape,
772 self.as_mut_slice(),
773 )?)
774 }
775}
776
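/// Backing store used for a tensor. The string forms understood by
/// `From<TensorMemory> for String` and `TryFrom<&str>` are `"dma"`, `"shm"`,
/// `"mem"` and `"pbo"`; variants that are not compiled in on the current
/// target simply fail to parse. Round-trip sketch:
///
/// ```ignore
/// let s: String = TensorMemory::Mem.into();
/// assert_eq!(s, "mem");
/// assert_eq!(TensorMemory::try_from("mem")?, TensorMemory::Mem);
/// assert!(TensorMemory::try_from("gpu").is_err());
/// ```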
777#[derive(Debug, Clone, Copy, PartialEq, Eq)]
778pub enum TensorMemory {
779 #[cfg(target_os = "linux")]
780 Dma,
784 #[cfg(unix)]
785 Shm,
788
789 Mem,
791
792 Pbo,
795}
796
797impl From<TensorMemory> for String {
798 fn from(memory: TensorMemory) -> Self {
799 match memory {
800 #[cfg(target_os = "linux")]
801 TensorMemory::Dma => "dma".to_owned(),
802 #[cfg(unix)]
803 TensorMemory::Shm => "shm".to_owned(),
804 TensorMemory::Mem => "mem".to_owned(),
805 TensorMemory::Pbo => "pbo".to_owned(),
806 }
807 }
808}
809
810impl TryFrom<&str> for TensorMemory {
811 type Error = Error;
812
813 fn try_from(s: &str) -> Result<Self> {
814 match s {
815 #[cfg(target_os = "linux")]
816 "dma" => Ok(TensorMemory::Dma),
817 #[cfg(unix)]
818 "shm" => Ok(TensorMemory::Shm),
819 "mem" => Ok(TensorMemory::Mem),
820 "pbo" => Ok(TensorMemory::Pbo),
821 _ => Err(Error::InvalidMemoryType(s.to_owned())),
822 }
823 }
824}
825
826#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum TensorStorage<T>

829where
830 T: Num + Clone + fmt::Debug + Send + Sync,
831{
832 #[cfg(target_os = "linux")]
833 Dma(DmaTensor<T>),
834 #[cfg(unix)]
835 Shm(ShmTensor<T>),
836 Mem(MemTensor<T>),
837 Pbo(PboTensor<T>),
838}
839
840impl<T> TensorStorage<T>
841where
842 T: Num + Clone + fmt::Debug + Send + Sync,
843{
844 fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
849 match memory {
850 #[cfg(target_os = "linux")]
851 Some(TensorMemory::Dma) => {
852 DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
853 }
854 #[cfg(unix)]
855 Some(TensorMemory::Shm) => {
856 ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
857 }
858 Some(TensorMemory::Mem) => {
859 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
860 }
861 Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
862 "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
863 )),
864 None => {
865 if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
866 .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
867 {
868 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
869 } else {
870 #[cfg(target_os = "linux")]
871 {
872 match DmaTensor::<T>::new(shape, name) {
874 Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
875 Err(_) => {
876 match ShmTensor::<T>::new(shape, name)
877 .map(TensorStorage::Shm)
878 {
879 Ok(tensor) => Ok(tensor),
880 Err(_) => MemTensor::<T>::new(shape, name)
881 .map(TensorStorage::Mem),
882 }
883 }
884 }
885 }
886 #[cfg(all(unix, not(target_os = "linux")))]
887 {
888 match ShmTensor::<T>::new(shape, name) {
890 Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
891 Err(_) => {
892 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
893 }
894 }
895 }
896 #[cfg(not(unix))]
897 {
898 MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
900 }
901 }
902 }
903 }
904 }
905
906 #[cfg(target_os = "linux")]
915 pub(crate) fn new_dma_with_byte_size(
916 shape: &[usize],
917 byte_size: usize,
918 name: Option<&str>,
919 ) -> Result<Self> {
920 DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
921 }
922
923 #[cfg(unix)]
931 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
932 #[cfg(target_os = "linux")]
933 {
934 use nix::sys::stat::fstat;
935
936 let stat = fstat(&fd)?;
937 let major = major(stat.st_dev);
938 let minor = minor(stat.st_dev);
939
940 log::debug!("Creating tensor from fd: major={major}, minor={minor}");
941
942 if major != 0 {
943 return Err(Error::UnknownDeviceType(major, minor));
945 }
946
947 match minor {
948 9 | 10 => {
949 DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
951 }
952 _ => {
953 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
955 }
956 }
957 }
958 #[cfg(all(unix, not(target_os = "linux")))]
959 {
960 ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
962 }
963 }
964}
965
966impl<T> TensorTrait<T> for TensorStorage<T>
967where
968 T: Num + Clone + fmt::Debug + Send + Sync,
969{
970 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
971 Self::new(shape, None, name)
972 }
973
974 #[cfg(unix)]
975 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
976 Self::from_fd(fd, shape, name)
977 }
978
979 #[cfg(unix)]
980 fn clone_fd(&self) -> Result<OwnedFd> {
981 match self {
982 #[cfg(target_os = "linux")]
983 TensorStorage::Dma(t) => t.clone_fd(),
984 TensorStorage::Shm(t) => t.clone_fd(),
985 TensorStorage::Mem(t) => t.clone_fd(),
986 TensorStorage::Pbo(t) => t.clone_fd(),
987 }
988 }
989
990 fn memory(&self) -> TensorMemory {
991 match self {
992 #[cfg(target_os = "linux")]
993 TensorStorage::Dma(_) => TensorMemory::Dma,
994 #[cfg(unix)]
995 TensorStorage::Shm(_) => TensorMemory::Shm,
996 TensorStorage::Mem(_) => TensorMemory::Mem,
997 TensorStorage::Pbo(_) => TensorMemory::Pbo,
998 }
999 }
1000
1001 fn name(&self) -> String {
1002 match self {
1003 #[cfg(target_os = "linux")]
1004 TensorStorage::Dma(t) => t.name(),
1005 #[cfg(unix)]
1006 TensorStorage::Shm(t) => t.name(),
1007 TensorStorage::Mem(t) => t.name(),
1008 TensorStorage::Pbo(t) => t.name(),
1009 }
1010 }
1011
1012 fn shape(&self) -> &[usize] {
1013 match self {
1014 #[cfg(target_os = "linux")]
1015 TensorStorage::Dma(t) => t.shape(),
1016 #[cfg(unix)]
1017 TensorStorage::Shm(t) => t.shape(),
1018 TensorStorage::Mem(t) => t.shape(),
1019 TensorStorage::Pbo(t) => t.shape(),
1020 }
1021 }
1022
1023 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1024 match self {
1025 #[cfg(target_os = "linux")]
1026 TensorStorage::Dma(t) => t.reshape(shape),
1027 #[cfg(unix)]
1028 TensorStorage::Shm(t) => t.reshape(shape),
1029 TensorStorage::Mem(t) => t.reshape(shape),
1030 TensorStorage::Pbo(t) => t.reshape(shape),
1031 }
1032 }
1033
1034 fn map(&self) -> Result<TensorMap<T>> {
1035 match self {
1036 #[cfg(target_os = "linux")]
1037 TensorStorage::Dma(t) => t.map(),
1038 #[cfg(unix)]
1039 TensorStorage::Shm(t) => t.map(),
1040 TensorStorage::Mem(t) => t.map(),
1041 TensorStorage::Pbo(t) => t.map(),
1042 }
1043 }
1044
1045 fn buffer_identity(&self) -> &BufferIdentity {
1046 match self {
1047 #[cfg(target_os = "linux")]
1048 TensorStorage::Dma(t) => t.buffer_identity(),
1049 #[cfg(unix)]
1050 TensorStorage::Shm(t) => t.buffer_identity(),
1051 TensorStorage::Mem(t) => t.buffer_identity(),
1052 TensorStorage::Pbo(t) => t.buffer_identity(),
1053 }
1054 }
1055}
1056
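/// Shape-aware tensor over one of the [`TensorStorage`] backends (DMA, shared
/// memory, heap, or PBO) with optional image metadata (`format`, `chroma`,
/// `row_stride`, `plane_offset`) and quantization. When `memory` is `None`
/// the backend is chosen automatically (on Linux: DMA, then SHM, then heap;
/// `EDGEFIRST_TENSOR_FORCE_MEM` forces heap memory). Create/map/fill sketch
/// mirroring the unit tests (error handling elided):
///
/// ```ignore
/// let tensor = Tensor::<f32>::new(&[2, 3, 4], None, None)?;
/// {
///     let mut map = tensor.map()?;
///     map.as_mut_slice().fill(42.0);
/// } // drop the mapping before mapping again, as the tests do
/// assert_eq!(tensor.len(), 24);
/// ```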
1057#[derive(Debug)]
1063pub struct Tensor<T>
1064where
1065 T: Num + Clone + fmt::Debug + Send + Sync,
1066{
1067 pub(crate) storage: TensorStorage<T>,
1068 format: Option<PixelFormat>,
1069 chroma: Option<Box<Tensor<T>>>,
1070 row_stride: Option<usize>,
1073 plane_offset: Option<usize>,
1076 pub(crate) quantization: Option<Quantization>,
1080}
1081
1082impl<T> Tensor<T>
1083where
1084 T: Num + Clone + fmt::Debug + Send + Sync,
1085{
1086 pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
1088 Self {
1089 storage,
1090 format: None,
1091 chroma: None,
1092 row_stride: None,
1093 plane_offset: None,
1094 quantization: None,
1095 }
1096 }
1097
1098 pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
1107 where
1108 T: Copy,
1109 {
1110 let expected: usize = shape.iter().product();
1111 if values.len() != expected {
1112 return Err(Error::InvalidShape(format!(
1113 "from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
1114 values.len()
1115 )));
1116 }
1117 let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
1118 {
1119 let mut m = t.map()?;
1120 m.as_mut_slice().copy_from_slice(values);
1121 }
1122 Ok(t)
1123 }
1124
1125 #[cfg(feature = "ndarray")]
1130 pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
1131 where
1132 T: Copy,
1133 {
1134 let (h, w, c) = view.dim();
1135 let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
1136 {
1137 let mut m = t.map()?;
1138 let dst = m.as_mut_slice();
1139 if let Some(src) = view.as_slice() {
1140 dst.copy_from_slice(src);
1141 } else {
1142 for (d, &s) in dst.iter_mut().zip(view.iter()) {
1143 *d = s;
1144 }
1145 }
1146 }
1147 Ok(t)
1148 }
1149
1150 pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
1175 TensorStorage::new(shape, memory, name).map(Self::wrap)
1176 }
1177
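/// Allocates an image tensor whose shape follows the format's layout: packed
/// formats become `[H, W, C]`, planar formats `[C, H, W]`, and the
/// semi-planar NV12/NV16 formats a single contiguous `[H * k, W]` plane
/// (`k = 3/2` for NV12, `2` for NV16). Sketch matching the image tests:
///
/// ```ignore
/// let rgba = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None)?;
/// assert_eq!(rgba.shape(), &[480, 640, 4]);
///
/// let nv12 = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None)?;
/// assert_eq!(nv12.shape(), &[720, 640]); // 480 * 3 / 2 luma+chroma rows
/// assert_eq!(nv12.height(), Some(480));
/// ```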
1178 pub fn image(
1180 width: usize,
1181 height: usize,
1182 format: PixelFormat,
1183 memory: Option<TensorMemory>,
1184 ) -> Result<Self> {
1185 let shape = match format.layout() {
1186 PixelLayout::Packed => vec![height, width, format.channels()],
1187 PixelLayout::Planar => vec![format.channels(), height, width],
1188 PixelLayout::SemiPlanar => {
1189 let total_h = match format {
1193 PixelFormat::Nv12 => {
1194 if !height.is_multiple_of(2) {
1195 return Err(Error::InvalidArgument(format!(
1196 "NV12 requires even height, got {height}"
1197 )));
1198 }
1199 height * 3 / 2
1200 }
1201 PixelFormat::Nv16 => height * 2,
1202 _ => {
1203 return Err(Error::InvalidArgument(format!(
1204 "unknown semi-planar height multiplier for {format:?}"
1205 )))
1206 }
1207 };
1208 vec![total_h, width]
1209 }
1210 };
1211 let mut t = Self::new(&shape, memory, None)?;
1212 t.format = Some(format);
1213 Ok(t)
1214 }
1215
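/// Variant of [`Tensor::image`] for DMA buffers whose rows are padded to a
/// hardware-required stride: the logical shape stays `[H, W, C]` while the
/// allocation covers `row_stride_bytes * height` bytes. Only packed layouts
/// and `TensorMemory::Dma` (or `None`) are accepted, and on non-Linux targets
/// the call returns `Error::NotImplemented`. Sketch, assuming a DMA heap is
/// available (see [`is_dma_available`]):
///
/// ```ignore
/// let t = Tensor::<u8>::image_with_stride(
///     640, 480, PixelFormat::Rgba, 3072, Some(TensorMemory::Dma))?;
/// assert_eq!(t.shape(), &[480, 640, 4]);
/// assert_eq!(t.effective_row_stride(), Some(3072));
/// ```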
1216 pub fn image_with_stride(
1252 width: usize,
1253 height: usize,
1254 format: PixelFormat,
1255 row_stride_bytes: usize,
1256 memory: Option<TensorMemory>,
1257 ) -> Result<Self> {
1258 #[cfg(not(target_os = "linux"))]
1268 {
1269 let _ = (width, height, format, row_stride_bytes, memory);
1270 Err(Error::NotImplemented(
1271 "image_with_stride requires DMA support (Linux only)".to_owned(),
1272 ))
1273 }
1274
1275 #[cfg(target_os = "linux")]
1276 {
1277 if format.layout() != PixelLayout::Packed {
1278 return Err(Error::NotImplemented(format!(
1279 "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
1280 )));
1281 }
1282 let elem = std::mem::size_of::<T>();
1283 let min_stride = width
1284 .checked_mul(format.channels())
1285 .and_then(|p| p.checked_mul(elem))
1286 .ok_or_else(|| {
1287 Error::InvalidArgument(format!(
1288 "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
1289 overflows usize",
1290 format.channels()
1291 ))
1292 })?;
1293 if row_stride_bytes < min_stride {
1294 return Err(Error::InvalidArgument(format!(
1295 "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
1296 ({width} px × {} ch × {elem} B)",
1297 format.channels()
1298 )));
1299 }
1300 let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
1301 Error::InvalidArgument(format!(
1302 "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
1303 ))
1304 })?;
1305
1306 let shape = vec![height, width, format.channels()];
1307
1308 let storage = match memory {
1309 Some(TensorMemory::Dma) | None => {
1310 TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
1311 }
1312 Some(other) => {
1313 return Err(Error::NotImplemented(format!(
1314 "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
1315 )));
1316 }
1317 };
1318
1319 let mut t = Self::wrap(storage);
1320 t.format = Some(format);
1321 t.row_stride = Some(row_stride_bytes);
1322 Ok(t)
1323 }
1324 }
1325
1326 pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
1343 let shape = self.shape();
1344 match format.layout() {
1345 PixelLayout::Packed => {
1346 if shape.len() != 3 || shape[2] != format.channels() {
1347 return Err(Error::InvalidShape(format!(
1348 "packed format {format:?} expects [H, W, {}], got {shape:?}",
1349 format.channels()
1350 )));
1351 }
1352 }
1353 PixelLayout::Planar => {
1354 if shape.len() != 3 || shape[0] != format.channels() {
1355 return Err(Error::InvalidShape(format!(
1356 "planar format {format:?} expects [{}, H, W], got {shape:?}",
1357 format.channels()
1358 )));
1359 }
1360 }
1361 PixelLayout::SemiPlanar => {
1362 if shape.len() != 2 {
1363 return Err(Error::InvalidShape(format!(
1364 "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
1365 )));
1366 }
1367 match format {
1368 PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
1369 return Err(Error::InvalidShape(format!(
1370 "NV12 contiguous shape[0] must be divisible by 3, got {}",
1371 shape[0]
1372 )));
1373 }
1374 PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
1375 return Err(Error::InvalidShape(format!(
1376 "NV16 contiguous shape[0] must be even, got {}",
1377 shape[0]
1378 )));
1379 }
1380 _ => {}
1381 }
1382 }
1383 }
1384 if self.format != Some(format) {
1387 self.row_stride = None;
1388 self.plane_offset = None;
1389 #[cfg(target_os = "linux")]
1390 if let TensorStorage::Dma(ref mut dma) = self.storage {
1391 dma.mmap_offset = 0;
1392 }
1393 }
1394 self.format = Some(format);
1395 Ok(())
1396 }
1397
1398 pub fn format(&self) -> Option<PixelFormat> {
1400 self.format
1401 }
1402
1403 pub fn width(&self) -> Option<usize> {
1405 let fmt = self.format?;
1406 let shape = self.shape();
1407 match fmt.layout() {
1408 PixelLayout::Packed => Some(shape[1]),
1409 PixelLayout::Planar => Some(shape[2]),
1410 PixelLayout::SemiPlanar => Some(shape[1]),
1411 }
1412 }
1413
1414 pub fn height(&self) -> Option<usize> {
1416 let fmt = self.format?;
1417 let shape = self.shape();
1418 match fmt.layout() {
1419 PixelLayout::Packed => Some(shape[0]),
1420 PixelLayout::Planar => Some(shape[1]),
1421 PixelLayout::SemiPlanar => {
1422 if self.is_multiplane() {
1423 Some(shape[0])
1424 } else {
1425 match fmt {
1426 PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
1427 PixelFormat::Nv16 => Some(shape[0] / 2),
1428 _ => None,
1429 }
1430 }
1431 }
1432 }
1433 }
1434
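/// Builds a semi-planar (NV12/NV16) image from two separately allocated
/// planes: the luma tensor becomes the primary storage and the chroma tensor
/// is kept alongside it (see [`Tensor::chroma`]). Both planes must be raw 2-D
/// tensors of equal width; NV12 additionally requires an even luma height and
/// a half-height chroma plane. Sketch matching the `from_planes_nv12` test:
///
/// ```ignore
/// let y  = Tensor::<u8>::new(&[480, 640], None, None)?;
/// let uv = Tensor::<u8>::new(&[240, 640], None, None)?;
/// let img = Tensor::from_planes(y, uv, PixelFormat::Nv12)?;
/// assert!(img.is_multiplane());
/// assert_eq!(img.height(), Some(480));
/// ```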
1435 pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
1437 if format.layout() != PixelLayout::SemiPlanar {
1438 return Err(Error::InvalidArgument(format!(
1439 "from_planes requires a semi-planar format, got {format:?}"
1440 )));
1441 }
1442 if chroma.format.is_some() || chroma.chroma.is_some() {
1443 return Err(Error::InvalidArgument(
1444 "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
1445 ));
1446 }
1447 let luma_shape = luma.shape();
1448 let chroma_shape = chroma.shape();
1449 if luma_shape.len() != 2 || chroma_shape.len() != 2 {
1450 return Err(Error::InvalidArgument(format!(
1451 "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
1452 )));
1453 }
1454 if luma_shape[1] != chroma_shape[1] {
1455 return Err(Error::InvalidArgument(format!(
1456 "luma width {} != chroma width {}",
1457 luma_shape[1], chroma_shape[1]
1458 )));
1459 }
1460 match format {
1461 PixelFormat::Nv12 => {
1462 if luma_shape[0] % 2 != 0 {
1463 return Err(Error::InvalidArgument(format!(
1464 "NV12 requires even luma height, got {}",
1465 luma_shape[0]
1466 )));
1467 }
1468 if chroma_shape[0] != luma_shape[0] / 2 {
1469 return Err(Error::InvalidArgument(format!(
1470 "NV12 chroma height {} != luma height / 2 ({})",
1471 chroma_shape[0],
1472 luma_shape[0] / 2
1473 )));
1474 }
1475 }
1476 PixelFormat::Nv16 => {
1477 if chroma_shape[0] != luma_shape[0] {
1478 return Err(Error::InvalidArgument(format!(
1479 "NV16 chroma height {} != luma height {}",
1480 chroma_shape[0], luma_shape[0]
1481 )));
1482 }
1483 }
1484 _ => {
1485 return Err(Error::InvalidArgument(format!(
1486 "from_planes only supports NV12 and NV16, got {format:?}"
1487 )));
1488 }
1489 }
1490
1491 Ok(Tensor {
1492 storage: luma.storage,
1493 format: Some(format),
1494 chroma: Some(Box::new(chroma)),
1495 row_stride: luma.row_stride,
1496 plane_offset: luma.plane_offset,
1497 quantization: luma.quantization,
1498 })
1499 }
1500
1501 pub fn is_multiplane(&self) -> bool {
1503 self.chroma.is_some()
1504 }
1505
1506 pub fn chroma(&self) -> Option<&Tensor<T>> {
1508 self.chroma.as_deref()
1509 }
1510
1511 pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
1513 self.chroma.as_deref_mut()
1514 }
1515
1516 pub fn row_stride(&self) -> Option<usize> {
1518 self.row_stride
1519 }
1520
1521 pub fn effective_row_stride(&self) -> Option<usize> {
1526 if let Some(s) = self.row_stride {
1527 return Some(s);
1528 }
1529 let fmt = self.format?;
1530 let w = self.width()?;
1531 let elem = std::mem::size_of::<T>();
1532 Some(match fmt.layout() {
1533 PixelLayout::Packed => w * fmt.channels() * elem,
1534 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1535 })
1536 }
1537
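/// Overrides the row stride (in bytes) used when interpreting this tensor as
/// an image with padded rows. The stride must be at least
/// `width * channels * size_of::<T>()` for packed formats, or
/// `width * size_of::<T>()` for planar and semi-planar ones. With a stride in
/// place, the byte offset of packed pixel `(x, y)` is
///
/// ```text
/// offset = y * row_stride + x * channels * size_of::<T>()
/// ```
///
/// which is how the strided DMA unit test walks the mapped buffer. Sketch for
/// `u8` RGBA (so `size_of::<T>() == 1`, four channels), with `x` and `y`
/// assumed in range:
///
/// ```ignore
/// let stride = t.effective_row_stride().unwrap();
/// let map = t.map()?;
/// let px = &map.as_slice()[y * stride + x * 4..][..4]; // one RGBA pixel
/// ```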
1538 pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
1565 let fmt = self.format.ok_or_else(|| {
1566 Error::InvalidArgument("cannot set row_stride without a pixel format".into())
1567 })?;
1568 let w = self.width().ok_or_else(|| {
1569 Error::InvalidArgument("cannot determine width for row_stride validation".into())
1570 })?;
1571 let elem = std::mem::size_of::<T>();
1572 let min_stride = match fmt.layout() {
1573 PixelLayout::Packed => w * fmt.channels() * elem,
1574 PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1575 };
1576 if stride < min_stride {
1577 return Err(Error::InvalidArgument(format!(
1578 "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
1579 )));
1580 }
1581 self.row_stride = Some(stride);
1582 Ok(())
1583 }
1584
1585 pub fn set_row_stride_unchecked(&mut self, stride: usize) {
1591 self.row_stride = Some(stride);
1592 }
1593
1594 pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
1601 self.set_row_stride(stride)?;
1602 Ok(self)
1603 }
1604
1605 pub fn plane_offset(&self) -> Option<usize> {
1607 self.plane_offset
1608 }
1609
1610 pub fn set_plane_offset(&mut self, offset: usize) {
1616 self.plane_offset = Some(offset);
1617 #[cfg(target_os = "linux")]
1618 if let TensorStorage::Dma(ref mut dma) = self.storage {
1619 dma.mmap_offset = offset;
1620 }
1621 }
1622
1623 pub fn with_plane_offset(mut self, offset: usize) -> Self {
1626 self.set_plane_offset(offset);
1627 self
1628 }
1629
1630 pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1632 match &self.storage {
1633 TensorStorage::Pbo(p) => Some(p),
1634 _ => None,
1635 }
1636 }
1637
1638 #[cfg(target_os = "linux")]
1640 pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1641 match &self.storage {
1642 TensorStorage::Dma(d) => Some(d),
1643 _ => None,
1644 }
1645 }
1646
1647 #[cfg(target_os = "linux")]
1658 pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1659 use std::os::fd::AsFd;
1660 match &self.storage {
1661 TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1662 _ => Err(Error::NotImplemented(format!(
1663 "dmabuf requires DMA-backed tensor, got {:?}",
1664 self.storage.memory()
1665 ))),
1666 }
1667 }
1668
1669 pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1671 Self {
1672 storage: TensorStorage::Pbo(pbo),
1673 format: None,
1674 chroma: None,
1675 row_stride: None,
1676 plane_offset: None,
1677 quantization: None,
1678 }
1679 }
1680}
1681
1682impl<T> Tensor<T>
1686where
1687 T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
1688{
1689 pub fn quantization(&self) -> Option<&Quantization> {
1691 self.quantization.as_ref()
1692 }
1693
1694 pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
1698 q.validate(self.shape())?;
1699 self.quantization = Some(q);
1700 Ok(())
1701 }
1702
1703 pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
1709 self.set_quantization(q)?;
1710 Ok(self)
1711 }
1712
1713 pub fn clear_quantization(&mut self) {
1715 self.quantization = None;
1716 }
1717}
1718
1719impl<T> TensorTrait<T> for Tensor<T>
1720where
1721 T: Num + Clone + fmt::Debug + Send + Sync,
1722{
1723 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1724 where
1725 Self: Sized,
1726 {
1727 Self::new(shape, None, name)
1728 }
1729
1730 #[cfg(unix)]
1731 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1732 where
1733 Self: Sized,
1734 {
1735 Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1736 }
1737
1738 #[cfg(unix)]
1739 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1740 self.storage.clone_fd()
1741 }
1742
1743 fn memory(&self) -> TensorMemory {
1744 self.storage.memory()
1745 }
1746
1747 fn name(&self) -> String {
1748 self.storage.name()
1749 }
1750
1751 fn shape(&self) -> &[usize] {
1752 self.storage.shape()
1753 }
1754
1755 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1756 if self.chroma.is_some() {
1757 return Err(Error::InvalidOperation(
1758 "cannot reshape a multiplane tensor — decompose planes first".into(),
1759 ));
1760 }
1761 self.storage.reshape(shape)?;
1762 self.format = None;
1763 self.row_stride = None;
1764 self.plane_offset = None;
1765 #[cfg(target_os = "linux")]
1766 if let TensorStorage::Dma(ref mut dma) = self.storage {
1767 dma.mmap_offset = 0;
1768 }
1769 Ok(())
1770 }
1771
1772 fn map(&self) -> Result<TensorMap<T>> {
1773 #[cfg(target_os = "linux")]
1790 if let Some(stride) = self.row_stride {
1791 if let TensorStorage::Dma(dma) = &self.storage {
1792 if !dma.is_imported {
1793 let height = self.height().ok_or_else(|| {
1814 Error::InvalidOperation(
1815 "Tensor::map: strided DMA mapping requires a PixelFormat \
1816 so height() can be derived; set a format before mapping \
1817 or clear row_stride for raw tensor access"
1818 .into(),
1819 )
1820 })?;
1821 let total_bytes = stride.checked_mul(height).ok_or_else(|| {
1822 Error::InvalidOperation(format!(
1823 "Tensor::map: row_stride {stride} × height {height} overflows usize"
1824 ))
1825 })?;
1826 let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
1827 if total_bytes > available_bytes {
1828 return Err(Error::InvalidOperation(format!(
1829 "Tensor::map: strided mapping needs {total_bytes} bytes \
1830 but DMA buffer only has {available_bytes} available \
1831 (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
1832 the row_stride was likely set larger than the original allocation",
1833 dma.buf_size, dma.mmap_offset
1834 )));
1835 }
1836 return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
1837 }
1838 }
1839 return Err(Error::InvalidOperation(
1840 "CPU mapping of strided foreign tensors is not supported; \
1841 use GPU path only"
1842 .into(),
1843 ));
1844 }
1845 #[cfg(not(target_os = "linux"))]
1846 if self.row_stride.is_some() {
1847 return Err(Error::InvalidOperation(
1848 "CPU mapping of strided tensors is not supported on this \
1849 platform (DMA backing is Linux-only)"
1850 .into(),
1851 ));
1852 }
1853 if self.plane_offset.is_some_and(|o| o > 0) {
1857 #[cfg(target_os = "linux")]
1858 if !matches!(self.storage, TensorStorage::Dma(_)) {
1859 return Err(Error::InvalidOperation(
1860 "plane offset only supported for DMA tensors".into(),
1861 ));
1862 }
1863 #[cfg(not(target_os = "linux"))]
1864 return Err(Error::InvalidOperation(
1865 "plane offset only supported for DMA tensors".into(),
1866 ));
1867 }
1868 self.storage.map()
1869 }
1870
1871 fn buffer_identity(&self) -> &BufferIdentity {
1872 self.storage.buffer_identity()
1873 }
1874}
1875
1876pub enum TensorMap<T>
1877where
1878 T: Num + Clone + fmt::Debug,
1879{
1880 #[cfg(target_os = "linux")]
1881 Dma(DmaMap<T>),
1882 #[cfg(unix)]
1883 Shm(ShmMap<T>),
1884 Mem(MemMap<T>),
1885 Pbo(PboMap<T>),
1886}
1887
1888impl<T> TensorMapTrait<T> for TensorMap<T>
1889where
1890 T: Num + Clone + fmt::Debug,
1891{
1892 fn shape(&self) -> &[usize] {
1893 match self {
1894 #[cfg(target_os = "linux")]
1895 TensorMap::Dma(map) => map.shape(),
1896 #[cfg(unix)]
1897 TensorMap::Shm(map) => map.shape(),
1898 TensorMap::Mem(map) => map.shape(),
1899 TensorMap::Pbo(map) => map.shape(),
1900 }
1901 }
1902
1903 fn unmap(&mut self) {
1904 match self {
1905 #[cfg(target_os = "linux")]
1906 TensorMap::Dma(map) => map.unmap(),
1907 #[cfg(unix)]
1908 TensorMap::Shm(map) => map.unmap(),
1909 TensorMap::Mem(map) => map.unmap(),
1910 TensorMap::Pbo(map) => map.unmap(),
1911 }
1912 }
1913
1914 fn as_slice(&self) -> &[T] {
1915 match self {
1916 #[cfg(target_os = "linux")]
1917 TensorMap::Dma(map) => map.as_slice(),
1918 #[cfg(unix)]
1919 TensorMap::Shm(map) => map.as_slice(),
1920 TensorMap::Mem(map) => map.as_slice(),
1921 TensorMap::Pbo(map) => map.as_slice(),
1922 }
1923 }
1924
1925 fn as_mut_slice(&mut self) -> &mut [T] {
1926 match self {
1927 #[cfg(target_os = "linux")]
1928 TensorMap::Dma(map) => map.as_mut_slice(),
1929 #[cfg(unix)]
1930 TensorMap::Shm(map) => map.as_mut_slice(),
1931 TensorMap::Mem(map) => map.as_mut_slice(),
1932 TensorMap::Pbo(map) => map.as_mut_slice(),
1933 }
1934 }
1935}
1936
1937impl<T> Deref for TensorMap<T>
1938where
1939 T: Num + Clone + fmt::Debug,
1940{
1941 type Target = [T];
1942
1943 fn deref(&self) -> &[T] {
1944 match self {
1945 #[cfg(target_os = "linux")]
1946 TensorMap::Dma(map) => map.deref(),
1947 #[cfg(unix)]
1948 TensorMap::Shm(map) => map.deref(),
1949 TensorMap::Mem(map) => map.deref(),
1950 TensorMap::Pbo(map) => map.deref(),
1951 }
1952 }
1953}
1954
1955impl<T> DerefMut for TensorMap<T>
1956where
1957 T: Num + Clone + fmt::Debug,
1958{
1959 fn deref_mut(&mut self) -> &mut [T] {
1960 match self {
1961 #[cfg(target_os = "linux")]
1962 TensorMap::Dma(map) => map.deref_mut(),
1963 #[cfg(unix)]
1964 TensorMap::Shm(map) => map.deref_mut(),
1965 TensorMap::Mem(map) => map.deref_mut(),
1966 TensorMap::Pbo(map) => map.deref_mut(),
1967 }
1968 }
1969}
1970
1971#[cfg(target_os = "linux")]
1983static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1984
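/// Reports whether DMA-heap allocation works in this process, determined once
/// by attempting a small `TensorMemory::Dma` allocation and cached in
/// `DMA_AVAILABLE`. The unit tests use it to skip DMA-only cases; callers can
/// gate their own allocations the same way (sketch):
///
/// ```ignore
/// let memory = if is_dma_available() {
///     Some(TensorMemory::Dma)
/// } else {
///     Some(TensorMemory::Mem) // fall back to plain heap memory
/// };
/// let frame = Tensor::<u8>::new(&[1080, 1920, 4], memory, None)?;
/// ```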
1985#[cfg(target_os = "linux")]
1987pub fn is_dma_available() -> bool {
1988 *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1989}
1990
1991#[cfg(not(target_os = "linux"))]
1995pub fn is_dma_available() -> bool {
1996 false
1997}
1998
1999#[cfg(unix)]
2006static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2007
2008#[cfg(unix)]
2010pub fn is_shm_available() -> bool {
2011 *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
2012}
2013
2014#[cfg(not(unix))]
2018pub fn is_shm_available() -> bool {
2019 false
2020}
2021
2022#[cfg(test)]
2023mod dtype_tests {
2024 use super::*;
2025
2026 #[test]
2027 fn dtype_size() {
2028 assert_eq!(DType::U8.size(), 1);
2029 assert_eq!(DType::I8.size(), 1);
2030 assert_eq!(DType::U16.size(), 2);
2031 assert_eq!(DType::I16.size(), 2);
2032 assert_eq!(DType::U32.size(), 4);
2033 assert_eq!(DType::I32.size(), 4);
2034 assert_eq!(DType::U64.size(), 8);
2035 assert_eq!(DType::I64.size(), 8);
2036 assert_eq!(DType::F16.size(), 2);
2037 assert_eq!(DType::F32.size(), 4);
2038 assert_eq!(DType::F64.size(), 8);
2039 }
2040
2041 #[test]
2042 fn dtype_name() {
2043 assert_eq!(DType::U8.name(), "u8");
2044 assert_eq!(DType::F16.name(), "f16");
2045 assert_eq!(DType::F32.name(), "f32");
2046 }
2047
2048 #[test]
2049 fn dtype_serde_roundtrip() {
2050 use serde_json;
2051 let dt = DType::F16;
2052 let json = serde_json::to_string(&dt).unwrap();
2053 let back: DType = serde_json::from_str(&json).unwrap();
2054 assert_eq!(dt, back);
2055 }
2056}
2057
2058#[cfg(test)]
2059mod image_tests {
2060 use super::*;
2061
2062 #[test]
2063 fn raw_tensor_has_no_format() {
2064 let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
2065 assert!(t.format().is_none());
2066 assert!(t.width().is_none());
2067 assert!(t.height().is_none());
2068 assert!(!t.is_multiplane());
2069 assert!(t.chroma().is_none());
2070 }
2071
2072 #[test]
2073 fn image_tensor_packed() {
2074 let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
2075 assert_eq!(t.format(), Some(PixelFormat::Rgba));
2076 assert_eq!(t.width(), Some(640));
2077 assert_eq!(t.height(), Some(480));
2078 assert_eq!(t.shape(), &[480, 640, 4]);
2079 assert!(!t.is_multiplane());
2080 }
2081
2082 #[test]
2083 fn image_tensor_planar() {
2084 let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
2085 assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
2086 assert_eq!(t.width(), Some(640));
2087 assert_eq!(t.height(), Some(480));
2088 assert_eq!(t.shape(), &[3, 480, 640]);
2089 }
2090
2091 #[test]
2092 fn image_tensor_semi_planar_contiguous() {
2093 let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
2094 assert_eq!(t.format(), Some(PixelFormat::Nv12));
2095 assert_eq!(t.width(), Some(640));
2096 assert_eq!(t.height(), Some(480));
2097 assert_eq!(t.shape(), &[720, 640]);
2099 assert!(!t.is_multiplane());
2100 }
2101
2102 #[test]
2103 #[cfg(target_os = "linux")]
2104 fn image_tensor_with_stride_preserves_logical_width() {
2105 if !is_dma_available() {
2107 eprintln!("SKIPPED: DMA heap not available");
2108 return;
2109 }
2110 let stride = 12032;
2112 let t = Tensor::<u8>::image_with_stride(
2113 3004,
2114 1688,
2115 PixelFormat::Rgba,
2116 stride,
2117 Some(TensorMemory::Dma),
2118 )
2119 .unwrap();
2120 assert_eq!(t.width(), Some(3004));
2122 assert_eq!(t.height(), Some(1688));
2123 assert_eq!(t.shape(), &[1688, 3004, 4]);
2124 assert_eq!(t.effective_row_stride(), Some(stride));
2126 use crate::TensorMapTrait;
2129 {
2130 let map = t.map().unwrap();
2131 assert!(
2132 map.as_slice().len() >= stride * 1688,
2133 "mapped buffer {} bytes < expected {}",
2134 map.as_slice().len(),
2135 stride * 1688
2136 );
2137 }
2138 {
2141 let mut map = t.map().unwrap();
2142 let slice = map.as_mut_slice();
2143 for y in 0..1688 {
2144 let row_start = y * stride;
2145 for x in 0..3004 {
2146 let p = row_start + x * 4;
2147 slice[p] = (y & 0xFF) as u8;
2148 slice[p + 1] = (x & 0xFF) as u8;
2149 slice[p + 2] = 0x42;
2150 slice[p + 3] = 0xFF;
2151 }
2152 }
2153 }
2154 {
2155 let map = t.map().unwrap();
2156 let slice = map.as_slice();
2157 assert_eq!(slice[0], 0x00);
2159 assert_eq!(slice[1], 0x00);
2160 assert_eq!(slice[2], 0x42);
2161 assert_eq!(slice[3], 0xFF);
2162 let mid = 100 * stride + 50 * 4;
2163 assert_eq!(slice[mid], 100);
2164 assert_eq!(slice[mid + 1], 50);
2165 assert_eq!(slice[mid + 2], 0x42);
2166 }
2167 }
2168
2169 #[test]
2170 #[cfg(target_os = "linux")]
2171 fn image_tensor_with_stride_rejects_foreign_strided_map() {
2172 if !is_dma_available() {
2180 eprintln!("SKIPPED: DMA heap not available");
2181 return;
2182 }
2183 let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
2185 let fd = backing.clone_fd().unwrap();
2186 let shape = [240usize, 320, 4];
2188 let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
2189 let mut t = Tensor::<u8>::wrap(storage);
2190 t.set_format(PixelFormat::Bgra).unwrap();
        t.set_row_stride(320 * 4).unwrap();
        let err = t.map();
2193 assert!(
2194 matches!(err, Err(Error::InvalidOperation(_))),
2195 "foreign strided map should error"
2196 );
2197 }
2198
2199 #[test]
2200 #[cfg(target_os = "linux")]
2201 fn image_tensor_with_stride_map_rejects_tampered_stride() {
2202 if !is_dma_available() {
2209 eprintln!("SKIPPED: DMA heap not available");
2210 return;
2211 }
2212 let mut t = Tensor::<u8>::image_with_stride(
2215 640,
2216 480,
2217 PixelFormat::Rgba,
2218 3072,
2219 Some(TensorMemory::Dma),
2220 )
2221 .unwrap();
2222 t.set_row_stride(12288).unwrap();
2225 let err = t.map();
2227 assert!(
2228 matches!(err, Err(Error::InvalidOperation(_))),
2229 "map() with oversized stride must return InvalidOperation"
2230 );
2231 }
2232
2233 #[test]
2234 fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
2235 #[cfg(target_os = "linux")]
2242 {
2243 let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
2244 &[usize::MAX, 2, 2],
2245 usize::MAX,
2246 None,
2247 );
2248 assert!(
2249 matches!(err, Err(Error::InvalidArgument(_))),
2250 "new_with_byte_size must detect shape.product() overflow"
2251 );
2252 }
2253 }
2254
2255 #[test]
2256 #[cfg(target_os = "linux")]
2257 fn image_tensor_with_stride_rejects_too_small_stride() {
2258 let err = Tensor::<u8>::image_with_stride(
2260 640,
2261 480,
2262 PixelFormat::Rgba,
2263 2400,
2264 Some(TensorMemory::Dma),
2265 );
2266 assert!(matches!(err, Err(Error::InvalidArgument(_))));
2267 }
2268
2269 #[test]
2270 #[cfg(target_os = "linux")]
2271 fn image_tensor_with_stride_rejects_non_packed() {
2272 let err = Tensor::<u8>::image_with_stride(
2275 640,
2276 480,
2277 PixelFormat::Nv12,
2278 640,
2279 Some(TensorMemory::Dma),
2280 );
2281 assert!(matches!(err, Err(Error::NotImplemented(_))));
2282 }
2283
2284 #[test]
2285 fn set_format_valid() {
2286 let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
2287 assert!(t.format().is_none());
2288 t.set_format(PixelFormat::Rgb).unwrap();
2289 assert_eq!(t.format(), Some(PixelFormat::Rgb));
2290 assert_eq!(t.width(), Some(640));
2291 assert_eq!(t.height(), Some(480));
2292 }
2293
2294 #[test]
2295 fn set_format_invalid_shape() {
2296 let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
2297 let err = t.set_format(PixelFormat::Rgb);
2299 assert!(err.is_err());
2300 assert!(t.format().is_none());
2302 }
2303
2304 #[test]
2305 fn reshape_clears_format() {
2306 let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
2307 assert_eq!(t.format(), Some(PixelFormat::Rgba));
2308 t.reshape(&[480 * 640 * 4]).unwrap();
2310 assert!(t.format().is_none());
2311 }
2312
2313 #[test]
2314 fn from_planes_nv12() {
2315 let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2316 let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2317 let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
2318 assert_eq!(img.format(), Some(PixelFormat::Nv12));
2319 assert!(img.is_multiplane());
2320 assert!(img.chroma().is_some());
2321 assert_eq!(img.width(), Some(640));
2322 assert_eq!(img.height(), Some(480));
2323 }
2324
2325 #[test]
2326 fn from_planes_rejects_non_semiplanar() {
2327 let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2328 let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2329 let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
2330 assert!(err.is_err());
2331 }
2332
2333 #[test]
2334 fn reshape_multiplane_errors() {
2335 let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2336 let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2337 let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
2338 let err = img.reshape(&[480 * 640 + 240 * 640]);
2339 assert!(err.is_err());
2340 }
2341}
2342
2343#[cfg(test)]
2344mod tests {
2345 #[cfg(target_os = "linux")]
2346 use nix::unistd::{access, AccessFlags};
2347 #[cfg(target_os = "linux")]
2348 use std::io::Write as _;
2349 use std::sync::RwLock;
2350
2351 use super::*;
2352
2353 #[ctor::ctor]
2354 fn init() {
2355 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
2356 }
2357
2358 #[cfg(target_os = "linux")]
2360 macro_rules! function {
2361 () => {{
2362 fn f() {}
2363 fn type_name_of<T>(_: T) -> &'static str {
2364 std::any::type_name::<T>()
2365 }
2366 let name = type_name_of(f);
2367
2368 match &name[..name.len() - 3].rfind(':') {
2370 Some(pos) => &name[pos + 1..name.len() - 3],
2371 None => &name[..name.len() - 3],
2372 }
2373 }};
2374 }
2375
2376 #[test]
2377 #[cfg(target_os = "linux")]
2378 fn test_tensor() {
2379 let _lock = FD_LOCK.read().unwrap();
2380 let shape = vec![1];
2381 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
2382 let dma_enabled = tensor.is_ok();
2383
2384 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2385 match dma_enabled {
2386 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
2387 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
2388 }
2389 }
2390
2391 #[test]
2392 #[cfg(all(unix, not(target_os = "linux")))]
2393 fn test_tensor() {
2394 let shape = vec![1];
2395 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2396 assert!(
2398 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
            "Expected SHM or Mem on non-Linux Unix targets, got {:?}",
2400 tensor.memory()
2401 );
2402 }
2403
2404 #[test]
2405 #[cfg(not(unix))]
2406 fn test_tensor() {
2407 let shape = vec![1];
2408 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2409 assert_eq!(tensor.memory(), TensorMemory::Mem);
2410 }
2411
2412 #[test]
2413 #[cfg(target_os = "linux")]
2414 fn test_dma_tensor() {
2415 let _lock = FD_LOCK.read().unwrap();
2416 match access(
2417 "/dev/dma_heap/linux,cma",
2418 AccessFlags::R_OK | AccessFlags::W_OK,
2419 ) {
2420 Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
2421 Err(_) => match access(
2422 "/dev/dma_heap/system",
2423 AccessFlags::R_OK | AccessFlags::W_OK,
2424 ) {
2425 Ok(_) => println!("/dev/dma_heap/system is available"),
2426 Err(e) => {
2427 writeln!(
2428 &mut std::io::stdout(),
2429 "[WARNING] DMA Heap is unavailable: {e}"
2430 )
2431 .unwrap();
2432 return;
2433 }
2434 },
2435 }
2436
2437 let shape = vec![2, 3, 4];
2438 let tensor =
2439 DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2440
2441 const DUMMY_VALUE: f32 = 12.34;
2442
2443 assert_eq!(tensor.memory(), TensorMemory::Dma);
2444 assert_eq!(tensor.name(), "test_tensor");
2445 assert_eq!(tensor.shape(), &shape);
2446 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2447 assert_eq!(tensor.len(), 2 * 3 * 4);
2448
2449 {
2450 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2451 tensor_map.fill(42.0);
2452 assert!(tensor_map.iter().all(|&x| x == 42.0));
2453 }
2454
2455 {
2456 let shared = Tensor::<f32>::from_fd(
2457 tensor
2458 .clone_fd()
2459 .expect("Failed to duplicate tensor file descriptor"),
2460 &shape,
2461 Some("test_tensor_shared"),
2462 )
2463 .expect("Failed to create tensor from fd");
2464
2465 assert_eq!(shared.memory(), TensorMemory::Dma);
2466 assert_eq!(shared.name(), "test_tensor_shared");
2467 assert_eq!(shared.shape(), &shape);
2468
2469 let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
2470 tensor_map.fill(DUMMY_VALUE);
2471 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2472 }
2473
2474 {
2475 let tensor_map = tensor.map().expect("Failed to map DMA memory");
2476 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2477 }
2478
2479 let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2480 assert_eq!(tensor.shape(), &shape);
2481 let new_shape = vec![3, 4, 4];
2482 assert!(
2483 tensor.reshape(&new_shape).is_err(),
2484 "Reshape should fail due to size mismatch"
2485 );
2486 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2487
2488 let new_shape = vec![2, 3, 4];
2489 tensor.reshape(&new_shape).expect("Reshape should succeed");
2490 assert_eq!(
2491 tensor.shape(),
2492 &new_shape,
2493 "Shape should be updated after successful reshape"
2494 );
2495
2496 {
2497 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2498 tensor_map.fill(1);
2499 assert!(tensor_map.iter().all(|&x| x == 1));
2500 }
2501
2502 {
2503 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2504 tensor_map[2] = 42;
2505 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2506 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2507 }
2508 }
2509
2510 #[test]
2511 #[cfg(unix)]
2512 fn test_shm_tensor() {
2513 let _lock = FD_LOCK.read().unwrap();
2514 let shape = vec![2, 3, 4];
2515 let tensor =
2516 ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2517 assert_eq!(tensor.shape(), &shape);
2518 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2519 assert_eq!(tensor.name(), "test_tensor");
2520
2521 const DUMMY_VALUE: f32 = 12.34;
2522 {
2523 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2524 tensor_map.fill(42.0);
2525 assert!(tensor_map.iter().all(|&x| x == 42.0));
2526 }
2527
2528 {
2529 let shared = Tensor::<f32>::from_fd(
2530 tensor
2531 .clone_fd()
2532 .expect("Failed to duplicate tensor file descriptor"),
2533 &shape,
2534 Some("test_tensor_shared"),
2535 )
2536 .expect("Failed to create tensor from fd");
2537
2538 assert_eq!(shared.memory(), TensorMemory::Shm);
2539 assert_eq!(shared.name(), "test_tensor_shared");
2540 assert_eq!(shared.shape(), &shape);
2541
2542 let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2543 tensor_map.fill(DUMMY_VALUE);
2544 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2545 }
2546
2547 {
2548 let tensor_map = tensor.map().expect("Failed to map shared memory");
2549 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2550 }
2551
2552 let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2553 assert_eq!(tensor.shape(), &shape);
2554 let new_shape = vec![3, 4, 4];
2555 assert!(
2556 tensor.reshape(&new_shape).is_err(),
2557 "Reshape should fail due to size mismatch"
2558 );
2559 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2560
2561 let new_shape = vec![2, 3, 4];
2562 tensor.reshape(&new_shape).expect("Reshape should succeed");
2563 assert_eq!(
2564 tensor.shape(),
2565 &new_shape,
2566 "Shape should be updated after successful reshape"
2567 );
2568
2569 {
2570 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2571 tensor_map.fill(1);
2572 assert!(tensor_map.iter().all(|&x| x == 1));
2573 }
2574
2575 {
2576 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2577 tensor_map[2] = 42;
2578 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2579 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2580 }
2581 }
2582
    #[test]
    fn test_mem_tensor() {
        let shape = vec![2, 3, 4];
        let tensor =
            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }

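    // Holds the FD_LOCK write lock so no other test opens or closes file
    // descriptors while /proc/self fd counts are compared before and after
    // repeatedly allocating, mapping, and dropping DMA tensors.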
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

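    // Same leak check as above, but for tensors built from a duplicated fd of
    // an existing DMA tensor: each from_fd() tensor must close its fd on drop.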
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

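    // The two SHM leak tests below mirror the DMA variants above.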
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_from_fd_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

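    // Only compiled with the `ndarray` feature: view()/view_mut() should expose
    // the mapped buffer as an ndarray view with the tensor's shape.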
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
        tensor_map.fill(1.0);

        let view = tensor_map.view().expect("Failed to get ndarray view");
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert!(view.iter().all(|&x| x == 1.0));

        let mut view_mut = tensor_map
            .view_mut()
            .expect("Failed to get mutable ndarray view");
        view_mut[[0, 0, 0]] = 42.0;
        assert_eq!(view_mut[[0, 0, 0]], 42.0);
        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
    }

    #[test]
    fn test_buffer_identity_unique() {
        let id1 = BufferIdentity::new();
        let id2 = BufferIdentity::new();
        assert_ne!(
            id1.id(),
            id2.id(),
            "Two identities should have different ids"
        );
    }

    #[test]
    fn test_buffer_identity_clone_shares_guard() {
        let id1 = BufferIdentity::new();
        let weak = id1.weak();
        assert!(
            weak.upgrade().is_some(),
            "Weak should be alive while original exists"
        );

        let id2 = id1.clone();
        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");

        drop(id1);
        assert!(
            weak.upgrade().is_some(),
            "Weak should still be alive (clone holds Arc)"
        );

        drop(id2);
        assert!(
            weak.upgrade().is_none(),
            "Weak should be dead after all clones dropped"
        );
    }

    #[test]
    fn test_tensor_buffer_identity() {
        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
        assert_ne!(
            t1.buffer_identity().id(),
            t2.buffer_identity().id(),
            "Different tensors should have different buffer ids"
        );
    }

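    // Quantization: constructor invariants, serde validation, and mode dispatch.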
    #[test]
    fn test_quantization_per_tensor_constructors() {
        let q = Quantization::per_tensor(0.1, -5);
        assert!(q.is_per_tensor());
        assert!(!q.is_per_channel());
        assert!(!q.is_symmetric());
        assert_eq!(q.scale(), &[0.1]);
        assert_eq!(q.zero_point(), Some(&[-5][..]));

        let qs = Quantization::per_tensor_symmetric(0.05);
        assert!(qs.is_per_tensor());
        assert!(qs.is_symmetric());
        assert_eq!(qs.zero_point(), None);
    }

    #[test]
    fn test_quantization_per_channel_constructors() {
        let q = Quantization::per_channel(vec![0.1, 0.2, 0.3], vec![0, -1, 1], 2).unwrap();
        assert!(q.is_per_channel());
        assert!(!q.is_symmetric());
        assert_eq!(q.axis(), Some(2));
        assert_eq!(q.scale().len(), 3);

        let qs = Quantization::per_channel_symmetric(vec![0.054, 0.089, 0.195], 0).unwrap();
        assert!(qs.is_per_channel());
        assert!(qs.is_symmetric());
        assert_eq!(qs.axis(), Some(0));
    }

    #[test]
    fn test_quantization_per_channel_length_mismatch_rejected() {
        let err = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0).unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

    #[test]
    fn test_quantization_per_channel_empty_rejected() {
        let err = Quantization::per_channel_symmetric(vec![], 0).unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

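    // Malformed payloads deserialize successfully but must be rejected when
    // attached via set_quantization(): empty scale, per-tensor scale with a
    // per-channel zero_point, and mismatched scale/zero_point lengths.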
    #[test]
    fn test_quantization_validate_rejects_malformed_deserialize() {
        let mut t = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();

        let q: Quantization = serde_json::from_str(r#"{"scale": []}"#).unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));

        let q: Quantization =
            serde_json::from_str(r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#).unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));

        let q: Quantization = serde_json::from_str(
            r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
        )
        .unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));
    }

    #[test]
    fn test_quantization_mode_dispatch() {
        let pt = Quantization::per_tensor(0.1, -5);
        assert!(matches!(
            pt.mode(),
            QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
        ));

        let pts = Quantization::per_tensor_symmetric(0.05);
        assert!(matches!(
            pts.mode(),
            QuantMode::PerTensorSymmetric { scale } if scale == 0.05
        ));

        let pc = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
        assert!(matches!(pc.mode(), QuantMode::PerChannel { axis: 2, .. }));

        let pcs = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
        assert!(matches!(
            pcs.mode(),
            QuantMode::PerChannelSymmetric { axis: 0, .. }
        ));
    }

    #[test]
    fn test_tensor_quantization_roundtrip_integer() {
        let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
        assert!(t.quantization().is_none());
        t.set_quantization(Quantization::per_tensor(0.1, -5))
            .unwrap();
        let q = t.quantization().unwrap();
        assert_eq!(q.scale(), &[0.1]);
        t.clear_quantization();
        assert!(t.quantization().is_none());
    }

    #[test]
    fn test_tensor_with_quantization_builder() {
        let t = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None)
            .unwrap()
            .with_quantization(Quantization::per_tensor_symmetric(0.05))
            .unwrap();
        assert!(t.quantization().is_some());
    }

    #[test]
    fn test_tensor_dyn_quantization_float_arm_returns_none() {
        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
        let td = TensorDyn::F32(t);
        assert!(td.quantization().is_none());
    }

    #[test]
    fn test_tensor_dyn_set_quantization_float_arm_errors() {
        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
        let mut td = TensorDyn::F32(t);
        let err = td
            .set_quantization(Quantization::per_tensor(0.1, 0))
            .unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

    fn _compile_fail_doctest_anchor() {}

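    // Serializes fd-sensitive tests: the leak tests take the write lock while
    // counting open descriptors; tests that merely open fds hold the read lock.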
    pub static FD_LOCK: RwLock<()> = RwLock::new(());

    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_dma_not_available_on_non_linux() {
        assert!(
            !is_dma_available(),
            "DMA memory allocation should NOT be available on non-Linux platforms"
        );
    }

    #[test]
    #[cfg(unix)]
    fn test_shm_available_and_usable() {
        // Hold the read lock so the fd-counting leak tests don't observe the
        // SHM fd opened and closed by this test.
        let _lock = FD_LOCK.read().unwrap();
        assert!(
            is_shm_available(),
            "SHM memory allocation should be available on Unix systems"
        );

        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
            .expect("Failed to create SHM tensor");

        let mut map = tensor.map().expect("Failed to map SHM tensor");
        map.as_mut_slice().fill(0xAB);

        assert!(
            map.as_slice().iter().all(|&b| b == 0xAB),
            "SHM tensor data should be writable and readable"
        );
    }
}