#[cfg(target_os = "linux")]
mod dma;
#[cfg(target_os = "linux")]
mod dmabuf;
mod error;
mod format;
mod mem;
mod pbo;
#[cfg(unix)]
mod shm;
mod tensor_dyn;

#[cfg(target_os = "linux")]
pub use crate::dma::{DmaMap, DmaTensor};
pub use crate::mem::{MemMap, MemTensor};
pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
#[cfg(unix)]
pub use crate::shm::{ShmMap, ShmTensor};
pub use error::{Error, Result};
pub use format::{PixelFormat, PixelLayout};
use num_traits::Num;
use serde::{Deserialize, Serialize};
#[cfg(unix)]
use std::os::fd::OwnedFd;
use std::{
    fmt,
    ops::{Deref, DerefMut},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Weak,
    },
};
pub use tensor_dyn::TensorDyn;

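/// Describes one plane of a descriptor-backed image buffer: the file
/// descriptor plus optional stride and offset metadata reported by the
/// producer.
///
/// A minimal construction sketch (marked `ignore`; any duplicable fd works
/// for the illustration, though in practice this is a dmabuf or memfd plane):
///
/// ```ignore
/// use std::os::fd::AsFd;
///
/// let file = std::fs::File::open("/dev/null").unwrap();
/// let plane = PlaneDescriptor::new(file.as_fd())
///     .unwrap()
///     .with_stride(4096)
///     .with_offset(0);
/// assert_eq!(plane.stride(), Some(4096));
/// assert_eq!(plane.offset(), Some(0));
/// ```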
#[cfg(unix)]
pub struct PlaneDescriptor {
    fd: OwnedFd,
    stride: Option<usize>,
    offset: Option<usize>,
}

#[cfg(unix)]
impl PlaneDescriptor {
    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
        let owned = fd.try_clone_to_owned()?;
        Ok(Self {
            fd: owned,
            stride: None,
            offset: None,
        })
    }

    pub fn with_stride(mut self, stride: usize) -> Self {
        self.stride = Some(stride);
        self
    }

    pub fn with_offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    pub fn into_fd(self) -> OwnedFd {
        self.fd
    }

    pub fn stride(&self) -> Option<usize> {
        self.stride
    }

    pub fn offset(&self) -> Option<usize> {
        self.offset
    }
}

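/// Tag enum naming the supported element types and their byte widths.
///
/// `size()` and `name()` are `const`; a small usage sketch:
///
/// ```ignore
/// assert_eq!(DType::F32.size(), 4);
/// assert_eq!(DType::F32.name(), "f32");
/// assert_eq!(DType::F32.to_string(), "f32"); // via the Display impl
/// ```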
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}

impl DType {
    pub const fn size(&self) -> usize {
        match self {
            Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
        }
    }

    pub const fn name(&self) -> &'static str {
        match self {
            Self::U8 => "u8",
            Self::I8 => "i8",
            Self::U16 => "u16",
            Self::I16 => "i16",
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::U64 => "u64",
            Self::I64 => "i64",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
        }
    }
}

impl fmt::Display for DType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}

mod sealed {
    pub trait Sealed {}
    impl Sealed for u8 {}
    impl Sealed for i8 {}
    impl Sealed for u16 {}
    impl Sealed for i16 {}
    impl Sealed for u32 {}
    impl Sealed for i32 {}
    impl Sealed for u64 {}
    impl Sealed for i64 {}
}

pub trait IntegerType: sealed::Sealed {}
impl IntegerType for u8 {}
impl IntegerType for i8 {}
impl IntegerType for u16 {}
impl IntegerType for i16 {}
impl IntegerType for u32 {}
impl IntegerType for i32 {}
impl IntegerType for u64 {}
impl IntegerType for i64 {}

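/// Affine quantization parameters: per-tensor or per-channel scales, with
/// optional zero points and, for per-channel data, the quantized axis.
///
/// On deserialization, `scale` and `zero_point` accept either a scalar or an
/// array (see the `deserialize_scalar_or_vec_*` helpers below). A sketch of
/// both spellings, assuming `serde_json` is available:
///
/// ```ignore
/// let q: Quantization = serde_json::from_str(r#"{"scale": 0.5}"#).unwrap();
/// assert!(q.is_per_tensor() && q.is_symmetric());
///
/// let q: Quantization =
///     serde_json::from_str(r#"{"scale": [0.5, 0.25], "axis": 0}"#).unwrap();
/// assert!(q.is_per_channel());
/// ```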
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
    scale: Vec<f32>,

    #[serde(
        default,
        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
        skip_serializing_if = "Option::is_none"
    )]
    zero_point: Option<Vec<i32>>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    axis: Option<usize>,
}

#[derive(Debug, Clone, Copy)]
pub enum QuantMode<'a> {
    PerTensorSymmetric {
        scale: f32,
    },
    PerTensor {
        scale: f32,
        zero_point: i32,
    },
    PerChannelSymmetric {
        scales: &'a [f32],
        axis: usize,
    },
    PerChannel {
        scales: &'a [f32],
        zero_points: &'a [i32],
        axis: usize,
    },
}

impl Quantization {
    pub fn per_tensor_symmetric(scale: f32) -> Self {
        Self {
            scale: vec![scale],
            zero_point: None,
            axis: None,
        }
    }

    pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
        Self {
            scale: vec![scale],
            zero_point: Some(vec![zero_point]),
            axis: None,
        }
    }

    pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
        if scales.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: "non-empty per-channel scales".to_string(),
                got: "length 0".to_string(),
            });
        }
        Ok(Self {
            scale: scales,
            zero_point: None,
            axis: Some(axis),
        })
    }

    pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
        if scales.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: "non-empty per-channel scales".to_string(),
                got: "length 0".to_string(),
            });
        }
        if scales.len() != zero_points.len() {
            return Err(Error::QuantizationInvalid {
                field: "zero_point.len",
                expected: format!("length matches scale ({})", scales.len()),
                got: format!("length {}", zero_points.len()),
            });
        }
        Ok(Self {
            scale: scales,
            zero_point: Some(zero_points),
            axis: Some(axis),
        })
    }

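    /// Classifies the stored parameters into a borrowed [`QuantMode`] view.
    ///
    /// A matching sketch:
    ///
    /// ```ignore
    /// let q = Quantization::per_tensor(0.5, 10);
    /// match q.mode() {
    ///     QuantMode::PerTensor { scale, zero_point } => {
    ///         assert_eq!((scale, zero_point), (0.5, 10));
    ///     }
    ///     _ => unreachable!(),
    /// }
    /// ```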
    pub fn mode(&self) -> QuantMode<'_> {
        match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
            (1, None, _) => QuantMode::PerTensorSymmetric {
                scale: self.scale[0],
            },
            (1, Some(zps), _) => QuantMode::PerTensor {
                scale: self.scale[0],
                zero_point: zps.first().copied().unwrap_or(0),
            },
            (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
                scales: &self.scale,
                axis,
            },
            (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
                scales: &self.scale,
                zero_points: zps,
                axis,
            },
            _ => {
                debug_assert!(
                    false,
                    "Quantization::mode: per-channel without axis is unreachable"
                );
                QuantMode::PerTensorSymmetric {
                    scale: self.scale.first().copied().unwrap_or(1.0),
                }
            }
        }
    }

    pub fn is_per_tensor(&self) -> bool {
        self.scale.len() == 1
    }

    pub fn is_per_channel(&self) -> bool {
        self.scale.len() > 1
    }

    pub fn is_symmetric(&self) -> bool {
        match &self.zero_point {
            None => true,
            Some(zps) => zps.iter().all(|&z| z == 0),
        }
    }

    pub fn scale(&self) -> &[f32] {
        &self.scale
    }

    pub fn zero_point(&self) -> Option<&[i32]> {
        self.zero_point.as_deref()
    }

    pub fn axis(&self) -> Option<usize> {
        self.axis
    }

    pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
        if self.scale.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: ">= 1".to_string(),
                got: "0".to_string(),
            });
        }
        if let Some(zps) = self.zero_point.as_ref() {
            let expected = if self.scale.len() == 1 {
                1
            } else {
                self.scale.len()
            };
            if zps.len() != expected {
                return Err(Error::QuantizationInvalid {
                    field: "zero_point.len",
                    expected: format!(
                        "{expected} (matching {})",
                        if self.scale.len() == 1 {
                            "per-tensor scale"
                        } else {
                            "per-channel scale.len"
                        }
                    ),
                    got: format!("length {}", zps.len()),
                });
            }
        }

        match (self.scale.len(), self.axis) {
            (1, None) => Ok(()),
            (1, Some(_)) => Err(Error::QuantizationInvalid {
                field: "per_tensor_redundant_axis",
                expected: "axis=None for per-tensor quantization".to_string(),
                got: format!("axis={:?}", self.axis),
            }),
            (_, None) => Err(Error::QuantizationInvalid {
                field: "per_channel_requires_axis",
                expected: format!(
                    "axis=Some(_) for per-channel quantization (scale.len={})",
                    self.scale.len()
                ),
                got: "axis=None".to_string(),
            }),
            (n, Some(axis)) => {
                if axis >= shape.len() {
                    return Err(Error::QuantizationInvalid {
                        field: "axis",
                        expected: format!("axis < tensor rank ({})", shape.len()),
                        got: format!("axis={axis}"),
                    });
                }
                if shape[axis] != n {
                    return Err(Error::QuantizationInvalid {
                        field: "scale.len",
                        expected: format!("length matches shape[{axis}] ({})", shape[axis]),
                        got: format!("length {n}"),
                    });
                }
                Ok(())
            }
        }
    }
}

impl From<(f32, i32)> for Quantization {
    fn from((scale, zero_point): (f32, i32)) -> Self {
        Self::per_tensor(scale, zero_point)
    }
}

fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
    de: D,
) -> std::result::Result<Vec<f32>, D::Error> {
    use serde::de::{self, Visitor};
    struct V;
    impl<'de> Visitor<'de> for V {
        type Value = Vec<f32>;
        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("f32 or array of f32")
        }
        fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        #[allow(clippy::cast_possible_truncation)]
        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        fn visit_seq<A: de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
            while let Some(x) = seq.next_element::<f32>()? {
                out.push(x);
            }
            Ok(out)
        }
    }
    de.deserialize_any(V)
}

fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
    de: D,
) -> std::result::Result<Option<Vec<i32>>, D::Error> {
    use serde::de::{self, Visitor};
    struct V;
    impl<'de> Visitor<'de> for V {
        type Value = Option<Vec<i32>>;
        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("null, i32, or array of i32")
        }
        fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
        fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
        fn visit_some<D2: serde::Deserializer<'de>>(
            self,
            de: D2,
        ) -> std::result::Result<Self::Value, D2::Error> {
            struct Inner;
            impl<'de> Visitor<'de> for Inner {
                type Value = Vec<i32>;
                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                    f.write_str("i32 or array of i32")
                }
                #[allow(clippy::cast_possible_truncation)]
                fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
                    Ok(vec![v as i32])
                }
                #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
                fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
                    Ok(vec![v as i32])
                }
                fn visit_seq<A: de::SeqAccess<'de>>(
                    self,
                    mut seq: A,
                ) -> std::result::Result<Self::Value, A::Error> {
                    let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
                    while let Some(x) = seq.next_element::<i32>()? {
                        out.push(x);
                    }
                    Ok(out)
                }
            }
            de.deserialize_any(Inner).map(Some)
        }
        #[allow(clippy::cast_possible_truncation)]
        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            Ok(Some(vec![v as i32]))
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(Some(vec![v as i32]))
        }
        fn visit_seq<A: de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
            while let Some(x) = seq.next_element::<i32>()? {
                out.push(x);
            }
            Ok(Some(out))
        }
    }
    de.deserialize_option(V)
}

static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

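/// Process-unique identity for an underlying buffer allocation.
///
/// Clones share the same `id` and the same liveness guard, so a [`Weak`]
/// obtained from [`BufferIdentity::weak`] can tell a caching layer whether
/// the buffer is still alive. A sketch:
///
/// ```ignore
/// let ident = BufferIdentity::new();
/// let weak = ident.weak();
/// assert!(weak.upgrade().is_some());
/// drop(ident); // last clone gone: the guard is released
/// assert!(weak.upgrade().is_none());
/// ```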
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}

impl BufferIdentity {
    pub fn new() -> Self {
        Self {
            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
            guard: Arc::new(()),
        }
    }

    pub fn id(&self) -> u64 {
        self.id
    }

    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}

impl Default for BufferIdentity {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(target_os = "linux")]
use nix::sys::stat::{major, minor};

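/// Common interface over the concrete tensor backings (DMA, SHM, heap, PBO).
///
/// A usage sketch against the unified [`Tensor`] type (marked `ignore`:
/// allocation may legitimately fail in constrained environments):
///
/// ```ignore
/// let t = Tensor::<u8>::new(&[2, 3], Some(TensorMemory::Mem), None).unwrap();
/// assert_eq!(t.len(), 6);
/// assert_eq!(t.size(), 6); // 6 elements × 1 byte
/// let mut m = t.map().unwrap();
/// m.as_mut_slice().fill(7);
/// ```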
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    fn memory(&self) -> TensorMemory;

    fn name(&self) -> String;

    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    fn shape(&self) -> &[usize];

    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    fn map(&self) -> Result<TensorMap<T>>;

    fn buffer_identity(&self) -> &BufferIdentity;
}

pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize];

    fn unmap(&mut self);

    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    fn as_slice(&self) -> &[T];

    fn as_mut_slice(&mut self) -> &mut [T];

    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}

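/// The backing memory class of a tensor; convertible to and from its string
/// form ("dma", "shm", "mem", "pbo"), e.g.:
///
/// ```ignore
/// assert_eq!(TensorMemory::try_from("mem").unwrap(), TensorMemory::Mem);
/// assert_eq!(String::from(TensorMemory::Mem), "mem");
/// ```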
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    Dma,
    #[cfg(unix)]
    Shm,
    Mem,
    Pbo,
}

impl From<TensorMemory> for String {
    fn from(memory: TensorMemory) -> Self {
        match memory {
            #[cfg(target_os = "linux")]
            TensorMemory::Dma => "dma".to_owned(),
            #[cfg(unix)]
            TensorMemory::Shm => "shm".to_owned(),
            TensorMemory::Mem => "mem".to_owned(),
            TensorMemory::Pbo => "pbo".to_owned(),
        }
    }
}

impl TryFrom<&str> for TensorMemory {
    type Error = Error;

    fn try_from(s: &str) -> Result<Self> {
        match s {
            #[cfg(target_os = "linux")]
            "dma" => Ok(TensorMemory::Dma),
            #[cfg(unix)]
            "shm" => Ok(TensorMemory::Shm),
            "mem" => Ok(TensorMemory::Mem),
            "pbo" => Ok(TensorMemory::Pbo),
            _ => Err(Error::InvalidMemoryType(s.to_owned())),
        }
    }
}

#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}

impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma),
            #[cfg(unix)]
            Some(TensorMemory::Shm) => ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm),
            Some(TensorMemory::Mem) => MemTensor::<T>::new(shape, name).map(TensorStorage::Mem),
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    #[cfg(target_os = "linux")]
                    {
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm) {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => {
                                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                                    }
                                }
                            }
                        }
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => MemTensor::<T>::new(shape, name).map(TensorStorage::Mem),
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    #[cfg(target_os = "linux")]
    pub(crate) fn new_dma_with_byte_size(
        shape: &[usize],
        byte_size: usize,
        name: Option<&str>,
    ) -> Result<Self> {
        DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
    }

    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            if major != 0 {
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                9 | 10 => DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma),
                _ => ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm),
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}

impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}

#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    format: Option<PixelFormat>,
    chroma: Option<Box<Tensor<T>>>,
    row_stride: Option<usize>,
    plane_offset: Option<usize>,
    pub(crate) quantization: Option<Quantization>,
}

impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
            quantization: None,
        }
    }

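    /// Allocates a heap-backed tensor and copies `values` into it, after
    /// checking that `values.len()` matches the product of `shape`.
    ///
    /// ```ignore
    /// let t = Tensor::<u8>::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]).unwrap();
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.map().unwrap().as_slice(), &[1, 2, 3, 4, 5, 6]);
    /// ```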
    pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
    where
        T: Copy,
    {
        let expected: usize = shape.iter().product();
        if values.len() != expected {
            return Err(Error::InvalidShape(format!(
                "from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
                values.len()
            )));
        }
        let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
        {
            let mut m = t.map()?;
            m.as_mut_slice().copy_from_slice(values);
        }
        Ok(t)
    }

    #[cfg(feature = "ndarray")]
    pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
    where
        T: Copy,
    {
        let (h, w, c) = view.dim();
        let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
        {
            let mut m = t.map()?;
            let dst = m.as_mut_slice();
            if let Some(src) = view.as_slice() {
                dst.copy_from_slice(src);
            } else {
                for (d, &s) in dst.iter_mut().zip(view.iter()) {
                    *d = s;
                }
            }
        }
        Ok(t)
    }

    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        let _span = tracing::trace_span!(
            "tensor_alloc",
            ?shape,
            memory = ?memory,
            dtype = std::any::type_name::<T>(),
        )
        .entered();
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

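    /// Allocates an image tensor whose shape is derived from `format`'s
    /// layout: packed `[H, W, C]`, planar `[C, H, W]`, or contiguous
    /// semi-planar `[H*k, W]` (NV12: k = 3/2, NV16: k = 2).
    ///
    /// ```ignore
    /// let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
    /// assert_eq!(t.shape(), &[480, 640, 4]);
    /// assert_eq!((t.width(), t.height()), (Some(640), Some(480)));
    /// ```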
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

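    /// Allocates a packed-format DMA image whose rows are padded to
    /// `row_stride_bytes` (Linux only; other platforms return
    /// [`Error::NotImplemented`]). A sketch, assuming a DMA heap is present:
    ///
    /// ```ignore
    /// let t = Tensor::<u8>::image_with_stride(
    ///     640, 480, PixelFormat::Rgba, 3072, Some(TensorMemory::Dma),
    /// ).unwrap();
    /// assert_eq!(t.effective_row_stride(), Some(3072));
    /// ```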
    pub fn image_with_stride(
        width: usize,
        height: usize,
        format: PixelFormat,
        row_stride_bytes: usize,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        #[cfg(not(target_os = "linux"))]
        {
            let _ = (width, height, format, row_stride_bytes, memory);
            Err(Error::NotImplemented(
                "image_with_stride requires DMA support (Linux only)".to_owned(),
            ))
        }

        #[cfg(target_os = "linux")]
        {
            if format.layout() != PixelLayout::Packed {
                return Err(Error::NotImplemented(format!(
                    "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
                )));
            }
            let elem = std::mem::size_of::<T>();
            let min_stride = width
                .checked_mul(format.channels())
                .and_then(|p| p.checked_mul(elem))
                .ok_or_else(|| {
                    Error::InvalidArgument(format!(
                        "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
                         overflows usize",
                        format.channels()
                    ))
                })?;
            if row_stride_bytes < min_stride {
                return Err(Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
                     ({width} px × {} ch × {elem} B)",
                    format.channels()
                )));
            }
            let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
                ))
            })?;

            let shape = vec![height, width, format.channels()];

            let storage = match memory {
                Some(TensorMemory::Dma) | None => {
                    TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
                }
                Some(other) => {
                    return Err(Error::NotImplemented(format!(
                        "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
                    )));
                }
            };

            let mut t = Self::wrap(storage);
            t.format = Some(format);
            t.row_stride = Some(row_stride_bytes);
            Ok(t)
        }
    }

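    /// Attaches a pixel format to an existing tensor after validating that
    /// the current shape is compatible with the format's layout; changing
    /// the format resets stride and plane-offset metadata.
    ///
    /// ```ignore
    /// let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
    /// t.set_format(PixelFormat::Rgb).unwrap();
    /// assert_eq!(t.width(), Some(640));
    /// ```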
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                match format {
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        if self.format != Some(format) {
            self.row_stride = None;
            self.plane_offset = None;
            #[cfg(target_os = "linux")]
            if let TensorStorage::Dma(ref mut dma) = self.storage {
                dma.mmap_offset = 0;
            }
        }
        self.format = Some(format);
        Ok(())
    }

    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    Some(shape[0])
                } else {
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

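    /// Builds a semi-planar (NV12/NV16) image from separate luma and chroma
    /// tensors; the chroma plane is kept alongside the luma storage and is
    /// reachable via [`Tensor::chroma`].
    ///
    /// ```ignore
    /// let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
    /// let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
    /// let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
    /// assert!(img.is_multiplane());
    /// ```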
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        match format {
            PixelFormat::Nv12 => {
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
            row_stride: luma.row_stride,
            plane_offset: luma.plane_offset,
            quantization: luma.quantization,
        })
    }

    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }

    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }

    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
        self.chroma.as_deref_mut()
    }

    pub fn row_stride(&self) -> Option<usize> {
        self.row_stride
    }

    pub fn effective_row_stride(&self) -> Option<usize> {
        if let Some(s) = self.row_stride {
            return Some(s);
        }
        let fmt = self.format?;
        let w = self.width()?;
        let elem = std::mem::size_of::<T>();
        Some(match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        })
    }

    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
        let fmt = self.format.ok_or_else(|| {
            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
        })?;
        let w = self.width().ok_or_else(|| {
            Error::InvalidArgument("cannot determine width for row_stride validation".into())
        })?;
        let elem = std::mem::size_of::<T>();
        let min_stride = match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        };
        if stride < min_stride {
            return Err(Error::InvalidArgument(format!(
                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
            )));
        }
        self.row_stride = Some(stride);
        Ok(())
    }

    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
        self.row_stride = Some(stride);
    }

    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
        self.set_row_stride(stride)?;
        Ok(self)
    }

    pub fn plane_offset(&self) -> Option<usize> {
        self.plane_offset
    }

    pub fn set_plane_offset(&mut self, offset: usize) {
        self.plane_offset = Some(offset);
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = offset;
        }
    }

    pub fn with_plane_offset(mut self, offset: usize) -> Self {
        self.set_plane_offset(offset);
        self
    }

    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }

    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }

    #[cfg(target_os = "linux")]
    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
        use std::os::fd::AsFd;
        match &self.storage {
            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
            _ => Err(Error::NotImplemented(format!(
                "dmabuf requires DMA-backed tensor, got {:?}",
                self.storage.memory()
            ))),
        }
    }

    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
        Self {
            storage: TensorStorage::Pbo(pbo),
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
            quantization: None,
        }
    }
}

impl<T> Tensor<T>
where
    T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
{
    pub fn quantization(&self) -> Option<&Quantization> {
        self.quantization.as_ref()
    }

    pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
        q.validate(self.shape())?;
        self.quantization = Some(q);
        Ok(())
    }

    pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
        self.set_quantization(q)?;
        Ok(self)
    }

    pub fn clear_quantization(&mut self) {
        self.quantization = None;
    }
}

impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        self.row_stride = None;
        self.plane_offset = None;
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = 0;
        }
        Ok(())
    }

    fn map(&self) -> Result<TensorMap<T>> {
        let _span = tracing::trace_span!(
            "tensor_map",
            memory = ?self.storage.memory(),
        )
        .entered();
        #[cfg(target_os = "linux")]
        if let Some(stride) = self.row_stride {
            if let TensorStorage::Dma(dma) = &self.storage {
                if !dma.is_imported {
                    let height = self.height().ok_or_else(|| {
                        Error::InvalidOperation(
                            "Tensor::map: strided DMA mapping requires a PixelFormat \
                             so height() can be derived; set a format before mapping \
                             or clear row_stride for raw tensor access"
                                .into(),
                        )
                    })?;
                    let total_bytes = stride.checked_mul(height).ok_or_else(|| {
                        Error::InvalidOperation(format!(
                            "Tensor::map: row_stride {stride} × height {height} overflows usize"
                        ))
                    })?;
                    let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
                    if total_bytes > available_bytes {
                        return Err(Error::InvalidOperation(format!(
                            "Tensor::map: strided mapping needs {total_bytes} bytes \
                             but DMA buffer only has {available_bytes} available \
                             (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
                             the row_stride was likely set larger than the original allocation",
                            dma.buf_size, dma.mmap_offset
                        )));
                    }
                    return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
                }
            }
            return Err(Error::InvalidOperation(
                "CPU mapping of strided foreign tensors is not supported; \
                 use GPU path only"
                    .into(),
            ));
        }
        #[cfg(not(target_os = "linux"))]
        if self.row_stride.is_some() {
            return Err(Error::InvalidOperation(
                "CPU mapping of strided tensors is not supported on this \
                 platform (DMA backing is Linux-only)"
                    .into(),
            ));
        }
        if self.plane_offset.is_some_and(|o| o > 0) {
            #[cfg(target_os = "linux")]
            if !matches!(self.storage, TensorStorage::Dma(_)) {
                return Err(Error::InvalidOperation(
                    "plane offset only supported for DMA tensors".into(),
                ));
            }
            #[cfg(not(target_os = "linux"))]
            return Err(Error::InvalidOperation(
                "plane offset only supported for DMA tensors".into(),
            ));
        }
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}

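/// A CPU mapping of a tensor's memory, dereferencing to `[T]` for direct
/// slice access. A sketch:
///
/// ```ignore
/// let t = Tensor::<f32>::new(&[4], None, None).unwrap();
/// let mut map = t.map().unwrap();
/// map.fill(1.0); // via DerefMut<Target = [f32]>
/// assert_eq!(map.len(), 4);
/// ```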
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}

impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}

impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}

impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}

#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

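/// Reports whether DMA-heap tensor allocation works on this system. The
/// probe allocates one small DMA tensor and caches the result in
/// `DMA_AVAILABLE`, so only the first call pays the allocation cost.
///
/// A gating sketch:
///
/// ```ignore
/// if is_dma_available() {
///     let t = Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).unwrap();
///     assert_eq!(t.memory(), TensorMemory::Dma);
/// }
/// ```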
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}

#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}

#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

#[cfg(unix)]
pub fn is_shm_available() -> bool {
    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}

#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}

#[cfg(test)]
mod dtype_tests {
    use super::*;

    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    #[test]
    fn dtype_serde_roundtrip() {
        use serde_json;
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}

#[cfg(test)]
mod image_tests {
    use super::*;

    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_preserves_logical_width() {
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        let stride = 12032;
        let t = Tensor::<u8>::image_with_stride(
            3004,
            1688,
            PixelFormat::Rgba,
            stride,
            Some(TensorMemory::Dma),
        )
        .unwrap();
        assert_eq!(t.width(), Some(3004));
        assert_eq!(t.height(), Some(1688));
        assert_eq!(t.shape(), &[1688, 3004, 4]);
        assert_eq!(t.effective_row_stride(), Some(stride));
        use crate::TensorMapTrait;
        {
            let map = t.map().unwrap();
            assert!(
                map.as_slice().len() >= stride * 1688,
                "mapped buffer {} bytes < expected {}",
                map.as_slice().len(),
                stride * 1688
            );
        }
        {
            let mut map = t.map().unwrap();
            let slice = map.as_mut_slice();
            for y in 0..1688 {
                let row_start = y * stride;
                for x in 0..3004 {
                    let p = row_start + x * 4;
                    slice[p] = (y & 0xFF) as u8;
                    slice[p + 1] = (x & 0xFF) as u8;
                    slice[p + 2] = 0x42;
                    slice[p + 3] = 0xFF;
                }
            }
        }
        {
            let map = t.map().unwrap();
            let slice = map.as_slice();
            assert_eq!(slice[0], 0x00);
            assert_eq!(slice[1], 0x00);
            assert_eq!(slice[2], 0x42);
            assert_eq!(slice[3], 0xFF);
            let mid = 100 * stride + 50 * 4;
            assert_eq!(slice[mid], 100);
            assert_eq!(slice[mid + 1], 50);
            assert_eq!(slice[mid + 2], 0x42);
        }
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_foreign_strided_map() {
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
        let fd = backing.clone_fd().unwrap();
        let shape = [240usize, 320, 4];
        let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
        let mut t = Tensor::<u8>::wrap(storage);
        t.set_format(PixelFormat::Bgra).unwrap();
        t.set_row_stride(320 * 4).unwrap();
        let err = t.map();
        assert!(
            matches!(err, Err(Error::InvalidOperation(_))),
            "foreign strided map should error"
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_map_rejects_tampered_stride() {
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        let mut t = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Rgba,
            3072,
            Some(TensorMemory::Dma),
        )
        .unwrap();
        t.set_row_stride(12288).unwrap();
        let err = t.map();
        assert!(
            matches!(err, Err(Error::InvalidOperation(_))),
            "map() with oversized stride must return InvalidOperation"
        );
    }

    #[test]
    fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
        #[cfg(target_os = "linux")]
        {
            let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
                &[usize::MAX, 2, 2],
                usize::MAX,
                None,
            );
            assert!(
                matches!(err, Err(Error::InvalidArgument(_))),
                "new_with_byte_size must detect shape.product() overflow"
            );
        }
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_too_small_stride() {
        let err = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Rgba,
            2400,
            Some(TensorMemory::Dma),
        );
        assert!(matches!(err, Err(Error::InvalidArgument(_))));
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_non_packed() {
        let err = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Nv12,
            640,
            Some(TensorMemory::Dma),
        );
        assert!(matches!(err, Err(Error::NotImplemented(_))));
    }

    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        assert!(t.format().is_none());
    }

    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}

#[cfg(test)]
mod tests {
    #[cfg(target_os = "linux")]
    use nix::unistd::{access, AccessFlags};
    #[cfg(target_os = "linux")]
    use std::io::Write as _;
    use std::sync::RwLock;

    use super::*;

    // Guards tests that count open file descriptors: fd counts are only
    // stable while no other test is concurrently creating tensors, so leak
    // tests take the write lock and ordinary tests take the read lock.
    static FD_LOCK: RwLock<()> = RwLock::new(());

    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }

    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            match name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }

2387
2388 #[test]
2389 #[cfg(target_os = "linux")]
2390 fn test_tensor() {
2391 let _lock = FD_LOCK.read().unwrap();
2392 let shape = vec![1];
2393 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
2394 let dma_enabled = tensor.is_ok();
2395
2396 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2397 match dma_enabled {
2398 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
2399 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
2400 }
2401 }
2402
2403 #[test]
2404 #[cfg(all(unix, not(target_os = "linux")))]
2405 fn test_tensor() {
2406 let shape = vec![1];
2407 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2408 assert!(
2410 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
2411 "Expected SHM or Mem on macOS, got {:?}",
2412 tensor.memory()
2413 );
2414 }
2415
2416 #[test]
2417 #[cfg(not(unix))]
2418 fn test_tensor() {
2419 let shape = vec![1];
2420 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2421 assert_eq!(tensor.memory(), TensorMemory::Mem);
2422 }
2423
2424 #[test]
2425 #[cfg(target_os = "linux")]
2426 fn test_dma_tensor() {
2427 let _lock = FD_LOCK.read().unwrap();
2428 match access(
2429 "/dev/dma_heap/linux,cma",
2430 AccessFlags::R_OK | AccessFlags::W_OK,
2431 ) {
2432 Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
2433 Err(_) => match access(
2434 "/dev/dma_heap/system",
2435 AccessFlags::R_OK | AccessFlags::W_OK,
2436 ) {
2437 Ok(_) => println!("/dev/dma_heap/system is available"),
2438 Err(e) => {
2439 writeln!(
2440 &mut std::io::stdout(),
2441 "[WARNING] DMA Heap is unavailable: {e}"
2442 )
2443 .unwrap();
2444 return;
2445 }
2446 },
2447 }
2448
2449 let shape = vec![2, 3, 4];
2450 let tensor =
2451 DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2452
2453 const DUMMY_VALUE: f32 = 12.34;
2454
2455 assert_eq!(tensor.memory(), TensorMemory::Dma);
2456 assert_eq!(tensor.name(), "test_tensor");
2457 assert_eq!(tensor.shape(), &shape);
2458 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2459 assert_eq!(tensor.len(), 2 * 3 * 4);
2460
2461 {
2462 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2463 tensor_map.fill(42.0);
2464 assert!(tensor_map.iter().all(|&x| x == 42.0));
2465 }
2466
2467 {
2468 let shared = Tensor::<f32>::from_fd(
2469 tensor
2470 .clone_fd()
2471 .expect("Failed to duplicate tensor file descriptor"),
2472 &shape,
2473 Some("test_tensor_shared"),
2474 )
2475 .expect("Failed to create tensor from fd");
2476
2477 assert_eq!(shared.memory(), TensorMemory::Dma);
2478 assert_eq!(shared.name(), "test_tensor_shared");
2479 assert_eq!(shared.shape(), &shape);
2480
2481 let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
2482 tensor_map.fill(DUMMY_VALUE);
2483 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2484 }
2485
2486 {
2487 let tensor_map = tensor.map().expect("Failed to map DMA memory");
2488 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2489 }
2490
2491 let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2492 assert_eq!(tensor.shape(), &shape);
2493 let new_shape = vec![3, 4, 4];
2494 assert!(
2495 tensor.reshape(&new_shape).is_err(),
2496 "Reshape should fail due to size mismatch"
2497 );
2498 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2499
2500 let new_shape = vec![2, 3, 4];
2501 tensor.reshape(&new_shape).expect("Reshape should succeed");
2502 assert_eq!(
2503 tensor.shape(),
2504 &new_shape,
2505 "Shape should be updated after successful reshape"
2506 );
2507
2508 {
2509 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2510 tensor_map.fill(1);
2511 assert!(tensor_map.iter().all(|&x| x == 1));
2512 }
2513
2514 {
2515 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2516 tensor_map[2] = 42;
2517 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2518 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2519 }
2520 }
2521
2522 #[test]
2523 #[cfg(unix)]
2524 fn test_shm_tensor() {
2525 let _lock = FD_LOCK.read().unwrap();
2526 let shape = vec![2, 3, 4];
2527 let tensor =
2528 ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2529 assert_eq!(tensor.shape(), &shape);
2530 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2531 assert_eq!(tensor.name(), "test_tensor");
2532
2533 const DUMMY_VALUE: f32 = 12.34;
2534 {
2535 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2536 tensor_map.fill(42.0);
2537 assert!(tensor_map.iter().all(|&x| x == 42.0));
2538 }
2539
2540 {
2541 let shared = Tensor::<f32>::from_fd(
2542 tensor
2543 .clone_fd()
2544 .expect("Failed to duplicate tensor file descriptor"),
2545 &shape,
2546 Some("test_tensor_shared"),
2547 )
2548 .expect("Failed to create tensor from fd");
2549
2550 assert_eq!(shared.memory(), TensorMemory::Shm);
2551 assert_eq!(shared.name(), "test_tensor_shared");
2552 assert_eq!(shared.shape(), &shape);
2553
2554 let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2555 tensor_map.fill(DUMMY_VALUE);
2556 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2557 }
2558
2559 {
2560 let tensor_map = tensor.map().expect("Failed to map shared memory");
2561 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2562 }
2563
2564 let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2565 assert_eq!(tensor.shape(), &shape);
2566 let new_shape = vec![3, 4, 4];
2567 assert!(
2568 tensor.reshape(&new_shape).is_err(),
2569 "Reshape should fail due to size mismatch"
2570 );
2571 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2572
2573 let new_shape = vec![2, 3, 4];
2574 tensor.reshape(&new_shape).expect("Reshape should succeed");
2575 assert_eq!(
2576 tensor.shape(),
2577 &new_shape,
2578 "Shape should be updated after successful reshape"
2579 );
2580
2581 {
2582 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2583 tensor_map.fill(1);
2584 assert!(tensor_map.iter().all(|&x| x == 1));
2585 }
2586
2587 {
2588 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2589 tensor_map[2] = 42;
2590 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2591 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2592 }
2593 }
2594
    #[test]
    fn test_mem_tensor() {
        let shape = vec![2, 3, 4];
        let tensor =
            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }

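    /// Allocates and drops many DMA tensors while holding the write lock, then
    /// checks via procfs that the process fd count is unchanged.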
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, and drops a tensor; once all of them
        // are gone the fd count must return to its starting value.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

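    /// Same leak check for the `from_fd` path: repeatedly duplicating an
    /// existing DMA tensor's fd must not leave descriptors behind.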
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the original before sampling the final count so its descriptor
        // is not counted against the baseline taken before it was created.
        drop(orig);

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

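    /// SHM counterpart of the DMA leak test: allocate, map, and drop many SHM
    /// tensors and verify the fd count returns to its baseline.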
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

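    /// SHM counterpart of the DMA `from_fd` leak test.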
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_from_fd_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the original before sampling the final count so its descriptor
        // is not counted against the baseline taken before it was created.
        drop(orig);

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

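    /// With the `ndarray` feature enabled, a mapping can be viewed as an
    /// `ndarray` array; writes through the mutable view land in the mapping.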
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
        tensor_map.fill(1.0);

        let view = tensor_map.view().expect("Failed to get ndarray view");
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert!(view.iter().all(|&x| x == 1.0));

        let mut view_mut = tensor_map
            .view_mut()
            .expect("Failed to get mutable ndarray view");
        view_mut[[0, 0, 0]] = 42.0;
        assert_eq!(view_mut[[0, 0, 0]], 42.0);
        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
    }

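    /// Every freshly created `BufferIdentity` must get a unique id.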
    #[test]
    fn test_buffer_identity_unique() {
        let id1 = BufferIdentity::new();
        let id2 = BufferIdentity::new();
        assert_ne!(
            id1.id(),
            id2.id(),
            "Two identities should have different ids"
        );
    }

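    /// Clones share the underlying guard: the weak handle stays alive until
    /// the last clone is dropped.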
    #[test]
    fn test_buffer_identity_clone_shares_guard() {
        let id1 = BufferIdentity::new();
        let weak = id1.weak();
        assert!(
            weak.upgrade().is_some(),
            "Weak should be alive while original exists"
        );

        let id2 = id1.clone();
        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");

        drop(id1);
        assert!(
            weak.upgrade().is_some(),
            "Weak should still be alive (clone holds Arc)"
        );

        drop(id2);
        assert!(
            weak.upgrade().is_none(),
            "Weak should be dead after all clones dropped"
        );
    }

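    /// Distinct tensors must expose distinct buffer identities.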
    #[test]
    fn test_tensor_buffer_identity() {
        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
        assert_ne!(
            t1.buffer_identity().id(),
            t2.buffer_identity().id(),
            "Different tensors should have different buffer ids"
        );
    }

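    /// Per-tensor constructors: the asymmetric form carries a zero point, the
    /// symmetric form does not.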
    #[test]
    fn test_quantization_per_tensor_constructors() {
        let q = Quantization::per_tensor(0.1, -5);
        assert!(q.is_per_tensor());
        assert!(!q.is_per_channel());
        assert!(!q.is_symmetric());
        assert_eq!(q.scale(), &[0.1]);
        assert_eq!(q.zero_point(), Some(&[-5][..]));

        let qs = Quantization::per_tensor_symmetric(0.05);
        assert!(qs.is_per_tensor());
        assert!(qs.is_symmetric());
        assert_eq!(qs.zero_point(), None);
    }

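    /// Per-channel constructors record the axis and the per-channel scales.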
    #[test]
    fn test_quantization_per_channel_constructors() {
        let q = Quantization::per_channel(vec![0.1, 0.2, 0.3], vec![0, -1, 1], 2).unwrap();
        assert!(q.is_per_channel());
        assert!(!q.is_symmetric());
        assert_eq!(q.axis(), Some(2));
        assert_eq!(q.scale().len(), 3);

        let qs = Quantization::per_channel_symmetric(vec![0.054, 0.089, 0.195], 0).unwrap();
        assert!(qs.is_per_channel());
        assert!(qs.is_symmetric());
        assert_eq!(qs.axis(), Some(0));
    }

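    /// Mismatched scale/zero_point lengths are rejected at construction.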
    #[test]
    fn test_quantization_per_channel_length_mismatch_rejected() {
        let err = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0).unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

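    /// An empty per-channel scale vector is rejected at construction.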
    #[test]
    fn test_quantization_per_channel_empty_rejected() {
        let err = Quantization::per_channel_symmetric(vec![], 0).unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

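    /// Parameters that deserialize successfully can still be structurally
    /// invalid; `set_quantization` must reject them.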
    #[test]
    fn test_quantization_validate_rejects_malformed_deserialize() {
        let mut t = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();

        // An empty scale vector is never valid.
        let q: Quantization = serde_json::from_str(r#"{"scale": []}"#).unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));

        // A scalar (per-tensor) scale cannot be paired with a per-channel
        // zero_point vector.
        let q: Quantization =
            serde_json::from_str(r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#).unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));

        // In per-channel mode, scale and zero_point must have the same length.
        let q: Quantization = serde_json::from_str(
            r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
        )
        .unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));
    }

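    /// `mode()` maps the stored parameters onto the right `QuantMode` variant.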
    #[test]
    fn test_quantization_mode_dispatch() {
        let pt = Quantization::per_tensor(0.1, -5);
        assert!(matches!(
            pt.mode(),
            QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
        ));

        let pts = Quantization::per_tensor_symmetric(0.05);
        assert!(matches!(
            pts.mode(),
            QuantMode::PerTensorSymmetric { scale } if scale == 0.05
        ));

        let pc = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
        assert!(matches!(pc.mode(), QuantMode::PerChannel { axis: 2, .. }));

        let pcs = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
        assert!(matches!(
            pcs.mode(),
            QuantMode::PerChannelSymmetric { axis: 0, .. }
        ));
    }

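    /// Quantization metadata can be set, read back, and cleared on a tensor.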
    #[test]
    fn test_tensor_quantization_roundtrip_integer() {
        let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
        assert!(t.quantization().is_none());
        t.set_quantization(Quantization::per_tensor(0.1, -5))
            .unwrap();
        let q = t.quantization().unwrap();
        assert_eq!(q.scale(), &[0.1]);
        t.clear_quantization();
        assert!(t.quantization().is_none());
    }

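    /// The builder-style `with_quantization` attaches metadata at creation.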
    #[test]
    fn test_tensor_with_quantization_builder() {
        let t = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None)
            .unwrap()
            .with_quantization(Quantization::per_tensor_symmetric(0.05))
            .unwrap();
        assert!(t.quantization().is_some());
    }

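    /// Float variants of `TensorDyn` carry no quantization metadata.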
    #[test]
    fn test_tensor_dyn_quantization_float_arm_returns_none() {
        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
        let td = TensorDyn::F32(t);
        assert!(td.quantization().is_none());
    }

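    /// Setting quantization on a float `TensorDyn` variant must error.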
    #[test]
    fn test_tensor_dyn_set_quantization_float_arm_errors() {
        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
        let mut td = TensorDyn::F32(t);
        let err = td
            .set_quantization(Quantization::per_tensor(0.1, 0))
            .unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

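    /// A minimal additional sketch exercising the public `DType` accessors
    /// (`size`, `name`, and the `Display` impl). Assumes `DType` is in scope
    /// via the test module's imports (e.g. `use super::*;`).
    #[test]
    fn test_dtype_size_and_name() {
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.to_string(), "f16");
    }

    // Anchor item for the crate's compile-fail doctests.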
    fn _compile_fail_doctest_anchor() {}

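    /// Serializes tests that observe process-wide file-descriptor state: the
    /// fd-leak tests take the write lock so nothing else opens or closes
    /// descriptors while they count; all other tests take the read lock.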
    pub static FD_LOCK: RwLock<()> = RwLock::new(());

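    /// On non-Linux platforms the DMA backend must report itself unavailable.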
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_dma_not_available_on_non_linux() {
        assert!(
            !is_dma_available(),
            "DMA memory allocation should NOT be available on non-Linux platforms"
        );
    }

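    /// On Unix, SHM must be both reported available and actually usable.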
    #[test]
    #[cfg(unix)]
    fn test_shm_available_and_usable() {
        assert!(
            is_shm_available(),
            "SHM memory allocation should be available on Unix systems"
        );

        // Availability must translate into a working allocation and mapping.
        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
            .expect("Failed to create SHM tensor");

        let mut map = tensor.map().expect("Failed to map SHM tensor");
        map.as_mut_slice().fill(0xAB);

        assert!(
            map.as_slice().iter().all(|&b| b == 0xAB),
            "SHM tensor data should be writable and readable"
        );
    }
}