Skip to main content

edgefirst_tensor/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
21
22## Overview
23The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
24which define the behavior of Tensors and their memory mappings, respectively.
25The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
26while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
27wraps `Tensor<T>` for runtime element-type dispatch.
28 */
29#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35#[cfg(target_os = "macos")]
36mod iosurface;
37mod mem;
38mod pbo;
39#[cfg(unix)]
40mod shm;
41mod tensor_dyn;
42
43#[cfg(target_os = "linux")]
44pub use crate::dma::{DmaMap, DmaTensor};
45#[cfg(target_os = "macos")]
46pub use crate::iosurface::{image_iosurface_layout, IoSurfaceMap, IoSurfaceTensor};
47pub use crate::mem::{MemMap, MemTensor};
48pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
49#[cfg(unix)]
50pub use crate::shm::{ShmMap, ShmTensor};
51pub use error::{Error, Result};
52pub use format::{PixelFormat, PixelLayout};
53use num_traits::Num;
54use serde::{Deserialize, Serialize};
55#[cfg(unix)]
56use std::os::fd::OwnedFd;
57use std::{
58    fmt,
59    ops::{Deref, DerefMut},
60    sync::{
61        atomic::{AtomicU64, Ordering},
62        Arc, Weak,
63    },
64};
65pub use tensor_dyn::TensorDyn;
66
67/// Per-plane DMA-BUF descriptor for external buffer import.
68///
69/// Owns a duplicated file descriptor plus optional stride and offset metadata.
70/// The fd is duplicated eagerly in [`new()`](Self::new) so that a bad fd is
71/// caught immediately. `import_image` consumes the descriptor and takes
72/// ownership of the duped fd — no further cleanup is needed by the caller.
73///
74/// # Examples
75///
76/// ```rust,no_run
77/// use edgefirst_tensor::PlaneDescriptor;
78/// use std::os::fd::BorrowedFd;
79///
80/// // SAFETY: fd 42 is hypothetical; real code must pass a valid fd.
81/// let pd = unsafe { PlaneDescriptor::new(BorrowedFd::borrow_raw(42)) }
82///     .unwrap()
83///     .with_stride(2048)
84///     .with_offset(0);
85/// ```
86#[cfg(unix)]
87pub struct PlaneDescriptor {
88    fd: OwnedFd,
89    stride: Option<usize>,
90    offset: Option<usize>,
91}
92
93#[cfg(unix)]
94impl PlaneDescriptor {
95    /// Create a new plane descriptor by duplicating the given file descriptor.
96    ///
97    /// The fd is duped immediately — a bad fd fails here rather than inside
98    /// `import_image`. The caller retains ownership of the original fd.
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if the `dup()` syscall fails (e.g. invalid fd or
103    /// fd limit reached).
104    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
105        let owned = fd.try_clone_to_owned()?;
106        Ok(Self {
107            fd: owned,
108            stride: None,
109            offset: None,
110        })
111    }
112
113    /// Set the row stride in bytes (consuming builder).
114    pub fn with_stride(mut self, stride: usize) -> Self {
115        self.stride = Some(stride);
116        self
117    }
118
119    /// Set the plane offset in bytes (consuming builder).
120    pub fn with_offset(mut self, offset: usize) -> Self {
121        self.offset = Some(offset);
122        self
123    }
124
125    /// Consume the descriptor and return the owned file descriptor.
126    pub fn into_fd(self) -> OwnedFd {
127        self.fd
128    }
129
130    /// Row stride in bytes, if set.
131    pub fn stride(&self) -> Option<usize> {
132        self.stride
133    }
134
135    /// Plane offset in bytes, if set.
136    pub fn offset(&self) -> Option<usize> {
137        self.offset
138    }
139}
140
141/// Element type discriminant for runtime type identification.
142#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
143#[repr(u8)]
144#[non_exhaustive]
145pub enum DType {
146    U8,
147    I8,
148    U16,
149    I16,
150    U32,
151    I32,
152    U64,
153    I64,
154    F16,
155    F32,
156    F64,
157}
158
159impl DType {
160    /// Size of one element in bytes.
161    pub const fn size(&self) -> usize {
162        match self {
163            Self::U8 | Self::I8 => 1,
164            Self::U16 | Self::I16 | Self::F16 => 2,
165            Self::U32 | Self::I32 | Self::F32 => 4,
166            Self::U64 | Self::I64 | Self::F64 => 8,
167        }
168    }
169
170    /// Short type name (e.g., "u8", "f32", "f16").
171    pub const fn name(&self) -> &'static str {
172        match self {
173            Self::U8 => "u8",
174            Self::I8 => "i8",
175            Self::U16 => "u16",
176            Self::I16 => "i16",
177            Self::U32 => "u32",
178            Self::I32 => "i32",
179            Self::U64 => "u64",
180            Self::I64 => "i64",
181            Self::F16 => "f16",
182            Self::F32 => "f32",
183            Self::F64 => "f64",
184        }
185    }
186}
187
188impl fmt::Display for DType {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        f.write_str(self.name())
191    }
192}
193
194// =============================================================================
195// Quantization metadata — type-gated to integer element types via sealed
196// `IntegerType` trait. Accessors on `Tensor<T>` only compile when `T` is
197// an integer type; calling them on `Tensor<f32>` / `Tensor<f16>` etc. is a
198// compile error, not a runtime one.
199// =============================================================================
200
201mod sealed {
202    pub trait Sealed {}
203    impl Sealed for u8 {}
204    impl Sealed for i8 {}
205    impl Sealed for u16 {}
206    impl Sealed for i16 {}
207    impl Sealed for u32 {}
208    impl Sealed for i32 {}
209    impl Sealed for u64 {}
210    impl Sealed for i64 {}
211    // Deliberately NOT implemented for f16 / f32 / f64.
212}
213
214/// Integer element types that may carry quantization metadata.
215///
216/// Sealed trait: implemented for `u8`, `i8`, `u16`, `i16`, `u32`, `i32`,
217/// `u64`, `i64`. Cannot be implemented downstream. Float element types
218/// (`half::f16`, `f32`, `f64`) are explicitly excluded — quantization
219/// metadata does not apply to float tensors per the edgefirst.json spec.
220pub trait IntegerType: sealed::Sealed {}
221impl IntegerType for u8 {}
222impl IntegerType for i8 {}
223impl IntegerType for u16 {}
224impl IntegerType for i16 {}
225impl IntegerType for u32 {}
226impl IntegerType for i32 {}
227impl IntegerType for u64 {}
228impl IntegerType for i64 {}
229
230/// Quantization parameters for an integer tensor.
231///
232/// Covers all four modes the edgefirst.json spec defines:
233///
234/// | Mode | `scale.len()` | `zero_point` | `axis` |
235/// |---|---|---|---|
236/// | Per-tensor symmetric | 1 | `None` | `None` |
237/// | Per-tensor asymmetric | 1 | `Some(len == 1)` | `None` |
238/// | Per-channel symmetric | >1 | `None` | `Some(c)` |
239/// | Per-channel asymmetric | >1 | `Some(len == scale.len())` | `Some(c)` |
240///
241/// The quantized storage type is carried on the parent [`Tensor<T>`]; this
242/// struct does not duplicate it. Construct via the four named constructors
243/// (the only public entry points); direct field mutation is not allowed so
244/// invalid combinations cannot be represented.
245///
246/// Dequantization formula:
247///
248/// ```text
249///   real_value = scale[c] × (quantized_value[c] - zero_point[c])
250/// ```
251///
252/// where `c` is the channel index (always `0` for per-tensor).
253#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
254pub struct Quantization {
255    /// Per-tensor: `vec![scale]`. Per-channel: `vec![scale_0, scale_1, ...]`.
256    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
257    scale: Vec<f32>,
258
259    /// `None` means symmetric (zero-point is 0). `Some(vec)` must have the
260    /// same length as `scale`.
261    #[serde(
262        default,
263        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
264        skip_serializing_if = "Option::is_none"
265    )]
266    zero_point: Option<Vec<i32>>,
267
268    /// Channel axis for per-channel quantization. `Some(_)` iff
269    /// `scale.len() > 1`. Validated against the parent tensor's shape at
270    /// `set_quantization()` time.
271    #[serde(default, skip_serializing_if = "Option::is_none")]
272    axis: Option<usize>,
273}
274
275/// Semantic mode discriminant for hot-path kernel dispatch.
276///
277/// Obtain via [`Quantization::mode`] once at kernel entry; never inside a
278/// pixel-level loop. The enum is borrow-based so the hot kernel receives
279/// the scales / zero-points as slices without reallocation.
280#[derive(Debug, Clone, Copy)]
281pub enum QuantMode<'a> {
282    PerTensorSymmetric {
283        scale: f32,
284    },
285    PerTensor {
286        scale: f32,
287        zero_point: i32,
288    },
289    PerChannelSymmetric {
290        scales: &'a [f32],
291        axis: usize,
292    },
293    PerChannel {
294        scales: &'a [f32],
295        zero_points: &'a [i32],
296        axis: usize,
297    },
298}
299
300impl Quantization {
301    /// Per-tensor symmetric (zero_point = 0).
302    pub fn per_tensor_symmetric(scale: f32) -> Self {
303        Self {
304            scale: vec![scale],
305            zero_point: None,
306            axis: None,
307        }
308    }
309
310    /// Per-tensor asymmetric — the most common runtime shape.
311    pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
312        Self {
313            scale: vec![scale],
314            zero_point: Some(vec![zero_point]),
315            axis: None,
316        }
317    }
318
319    /// Per-channel symmetric. Errors on empty `scales`.
320    pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
321        if scales.is_empty() {
322            return Err(Error::QuantizationInvalid {
323                field: "scale.len",
324                expected: "non-empty per-channel scales".to_string(),
325                got: "length 0".to_string(),
326            });
327        }
328        Ok(Self {
329            scale: scales,
330            zero_point: None,
331            axis: Some(axis),
332        })
333    }
334
335    /// Per-channel asymmetric. Errors on length mismatch between `scales`
336    /// and `zero_points`, or empty arrays.
337    pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
338        if scales.is_empty() {
339            return Err(Error::QuantizationInvalid {
340                field: "scale.len",
341                expected: "non-empty per-channel scales".to_string(),
342                got: "length 0".to_string(),
343            });
344        }
345        if scales.len() != zero_points.len() {
346            return Err(Error::QuantizationInvalid {
347                field: "zero_point.len",
348                expected: format!("length matches scale ({})", scales.len()),
349                got: format!("length {}", zero_points.len()),
350            });
351        }
352        Ok(Self {
353            scale: scales,
354            zero_point: Some(zero_points),
355            axis: Some(axis),
356        })
357    }
358
359    /// Borrow-based dispatch view. Match once at kernel entry.
360    pub fn mode(&self) -> QuantMode<'_> {
361        match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
362            (1, None, _) => QuantMode::PerTensorSymmetric {
363                scale: self.scale[0],
364            },
365            (1, Some(zps), _) => QuantMode::PerTensor {
366                scale: self.scale[0],
367                zero_point: zps.first().copied().unwrap_or(0),
368            },
369            (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
370                scales: &self.scale,
371                axis,
372            },
373            (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
374                scales: &self.scale,
375                zero_points: zps,
376                axis,
377            },
378            // The `validate()` path prevents constructing a
379            // per-channel Quantization without an axis, so the remaining
380            // pattern is unreachable in practice. Fall back to
381            // per-tensor symmetric using scale[0] to avoid panicking in
382            // release; debug builds assert.
383            _ => {
384                debug_assert!(
385                    false,
386                    "Quantization::mode: per-channel without axis is unreachable"
387                );
388                QuantMode::PerTensorSymmetric {
389                    scale: self.scale.first().copied().unwrap_or(1.0),
390                }
391            }
392        }
393    }
394
395    /// Returns `true` for per-tensor quantization (`scale.len() == 1`).
396    pub fn is_per_tensor(&self) -> bool {
397        self.scale.len() == 1
398    }
399
400    /// Returns `true` for per-channel quantization (`scale.len() > 1`).
401    pub fn is_per_channel(&self) -> bool {
402        self.scale.len() > 1
403    }
404
405    /// Returns `true` for symmetric quantization (no zero-point, or
406    /// zero-point vector of all zeros).
407    pub fn is_symmetric(&self) -> bool {
408        match &self.zero_point {
409            None => true,
410            Some(zps) => zps.iter().all(|&z| z == 0),
411        }
412    }
413
414    /// Borrow the scale array. Length 1 for per-tensor; `num_channels` for
415    /// per-channel.
416    pub fn scale(&self) -> &[f32] {
417        &self.scale
418    }
419
420    /// Borrow the zero-point array. `None` for symmetric.
421    pub fn zero_point(&self) -> Option<&[i32]> {
422        self.zero_point.as_deref()
423    }
424
425    /// Channel axis for per-channel quantization. `None` for per-tensor.
426    pub fn axis(&self) -> Option<usize> {
427        self.axis
428    }
429
430    /// Validate against a target tensor shape. Runs in
431    /// `Tensor::set_quantization()`. Catches:
432    ///   - empty `scale` (reject — must declare at least one factor)
433    ///   - `zero_point` length inconsistent with `scale` (reject —
434    ///     per-tensor must have len 1, per-channel must match `scale.len`)
435    ///   - `axis >= shape.len()` (axis out of range)
436    ///   - `scale.len() != shape[axis]` for per-channel
437    ///   - per-channel without axis (reject)
438    ///   - per-tensor with redundant axis (reject)
439    pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
440        // `Quantization` is `Deserialize`, so malformed JSON like
441        // `{"scale": [], "zero_point": []}` could otherwise produce an
442        // ill-defined value that confuses `mode()` selection and the
443        // per-channel kernels' indexing.
444        if self.scale.is_empty() {
445            return Err(Error::QuantizationInvalid {
446                field: "scale.len",
447                expected: ">= 1".to_string(),
448                got: "0".to_string(),
449            });
450        }
451        if let Some(zps) = self.zero_point.as_ref() {
452            // Per-tensor: scale.len() == 1 and zero_point.len() must == 1.
453            // Per-channel: zero_point.len() must == scale.len().
454            let expected = if self.scale.len() == 1 {
455                1
456            } else {
457                self.scale.len()
458            };
459            if zps.len() != expected {
460                return Err(Error::QuantizationInvalid {
461                    field: "zero_point.len",
462                    expected: format!(
463                        "{expected} (matching {})",
464                        if self.scale.len() == 1 {
465                            "per-tensor scale"
466                        } else {
467                            "per-channel scale.len"
468                        }
469                    ),
470                    got: format!("length {}", zps.len()),
471                });
472            }
473        }
474
475        match (self.scale.len(), self.axis) {
476            (1, None) => Ok(()),
477            (1, Some(_)) => Err(Error::QuantizationInvalid {
478                field: "per_tensor_redundant_axis",
479                expected: "axis=None for per-tensor quantization".to_string(),
480                got: format!("axis={:?}", self.axis),
481            }),
482            (_, None) => Err(Error::QuantizationInvalid {
483                field: "per_channel_requires_axis",
484                expected: format!(
485                    "axis=Some(_) for per-channel quantization (scale.len={})",
486                    self.scale.len()
487                ),
488                got: "axis=None".to_string(),
489            }),
490            (n, Some(axis)) => {
491                if axis >= shape.len() {
492                    return Err(Error::QuantizationInvalid {
493                        field: "axis",
494                        expected: format!("axis < tensor rank ({})", shape.len()),
495                        got: format!("axis={axis}"),
496                    });
497                }
498                if shape[axis] != n {
499                    return Err(Error::QuantizationInvalid {
500                        field: "scale.len",
501                        expected: format!("length matches shape[{axis}] ({})", shape[axis]),
502                        got: format!("length {n}"),
503                    });
504                }
505                Ok(())
506            }
507        }
508    }
509}
510
511impl From<(f32, i32)> for Quantization {
512    /// Convenience construction from a `(scale, zero_point)` tuple. Matches
513    /// the legacy `QuantTuple` / `Quantization::new` calling convention so
514    /// existing `(0.1, -128).into()` sites keep working.
515    fn from((scale, zero_point): (f32, i32)) -> Self {
516        Self::per_tensor(scale, zero_point)
517    }
518}
519
520fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
521    de: D,
522) -> std::result::Result<Vec<f32>, D::Error> {
523    use serde::de::{self, Visitor};
524    struct V;
525    impl<'de> Visitor<'de> for V {
526        type Value = Vec<f32>;
527        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
528            f.write_str("f32 or array of f32")
529        }
530        fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
531            Ok(vec![v as f32])
532        }
533        #[allow(clippy::cast_possible_truncation)]
534        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
535            Ok(vec![v as f32])
536        }
537        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
538        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
539            Ok(vec![v as f32])
540        }
541        fn visit_seq<A: de::SeqAccess<'de>>(
542            self,
543            mut seq: A,
544        ) -> std::result::Result<Self::Value, A::Error> {
545            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
546            while let Some(x) = seq.next_element::<f32>()? {
547                out.push(x);
548            }
549            Ok(out)
550        }
551    }
552    de.deserialize_any(V)
553}
554
555fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
556    de: D,
557) -> std::result::Result<Option<Vec<i32>>, D::Error> {
558    use serde::de::{self, Visitor};
559    struct V;
560    impl<'de> Visitor<'de> for V {
561        type Value = Option<Vec<i32>>;
562        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
563            f.write_str("null, i32, or array of i32")
564        }
565        fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
566            Ok(None)
567        }
568        fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
569            Ok(None)
570        }
571        fn visit_some<D2: serde::Deserializer<'de>>(
572            self,
573            de: D2,
574        ) -> std::result::Result<Self::Value, D2::Error> {
575            struct Inner;
576            impl<'de> Visitor<'de> for Inner {
577                type Value = Vec<i32>;
578                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
579                    f.write_str("i32 or array of i32")
580                }
581                #[allow(clippy::cast_possible_truncation)]
582                fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
583                    Ok(vec![v as i32])
584                }
585                #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
586                fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
587                    Ok(vec![v as i32])
588                }
589                fn visit_seq<A: de::SeqAccess<'de>>(
590                    self,
591                    mut seq: A,
592                ) -> std::result::Result<Self::Value, A::Error> {
593                    let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
594                    while let Some(x) = seq.next_element::<i32>()? {
595                        out.push(x);
596                    }
597                    Ok(out)
598                }
599            }
600            de.deserialize_any(Inner).map(Some)
601        }
602        #[allow(clippy::cast_possible_truncation)]
603        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
604            Ok(Some(vec![v as i32]))
605        }
606        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
607        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
608            Ok(Some(vec![v as i32]))
609        }
610        fn visit_seq<A: de::SeqAccess<'de>>(
611            self,
612            mut seq: A,
613        ) -> std::result::Result<Self::Value, A::Error> {
614            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
615            while let Some(x) = seq.next_element::<i32>()? {
616                out.push(x);
617            }
618            Ok(Some(out))
619        }
620    }
621    de.deserialize_option(V)
622}
623
624/// Monotonic counter for buffer identity IDs.
625static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
626
627/// Unique identity for a tensor's underlying buffer.
628///
629/// Created fresh on every buffer allocation or import. The `id` is a monotonic
630/// u64 used as a cache key. The `guard` is an `Arc<()>` whose weak references
631/// allow downstream caches to detect when the buffer has been dropped.
632#[derive(Debug, Clone)]
633pub struct BufferIdentity {
634    id: u64,
635    guard: Arc<()>,
636}
637
638impl BufferIdentity {
639    /// Create a new unique buffer identity.
640    pub fn new() -> Self {
641        Self {
642            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
643            guard: Arc::new(()),
644        }
645    }
646
647    /// Unique identifier for this buffer. Changes when the buffer changes.
648    pub fn id(&self) -> u64 {
649        self.id
650    }
651
652    /// Returns a weak reference to the buffer guard. Goes dead when the
653    /// owning Tensor is dropped (and no clones remain).
654    pub fn weak(&self) -> Weak<()> {
655        Arc::downgrade(&self.guard)
656    }
657}
658
659impl Default for BufferIdentity {
660    fn default() -> Self {
661        Self::new()
662    }
663}
664
665#[cfg(target_os = "linux")]
666use nix::sys::stat::{major, minor};
667
668pub trait TensorTrait<T>: Send + Sync
669where
670    T: Num + Clone + fmt::Debug,
671{
672    /// Create a new tensor with the given shape and optional name. If no name
673    /// is given, a random name will be generated.
674    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
675    where
676        Self: Sized;
677
678    #[cfg(unix)]
679    /// Create a new tensor using the given file descriptor, shape, and optional
680    /// name. If no name is given, a random name will be generated.
681    ///
682    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
683    /// On other Unix (macOS): Always creates SHM tensor.
684    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
685    where
686        Self: Sized;
687
688    #[cfg(unix)]
689    /// Clone the file descriptor associated with this tensor.
690    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
691
692    /// Get the memory type of this tensor.
693    fn memory(&self) -> TensorMemory;
694
695    /// Get the name of this tensor.
696    fn name(&self) -> String;
697
698    /// Get the number of elements in this tensor.
699    fn len(&self) -> usize {
700        self.shape().iter().product()
701    }
702
703    /// Check if the tensor is empty.
704    fn is_empty(&self) -> bool {
705        self.len() == 0
706    }
707
708    /// Get the size in bytes of this tensor.
709    fn size(&self) -> usize {
710        self.len() * std::mem::size_of::<T>()
711    }
712
713    /// Get the shape of this tensor.
714    fn shape(&self) -> &[usize];
715
716    /// Reshape this tensor to the given shape. The total number of elements
717    /// must remain the same.
718    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
719
720    /// Map the tensor into memory and return a TensorMap for accessing the
721    /// data.
722    fn map(&self) -> Result<TensorMap<T>>;
723
724    /// Get the buffer identity for cache keying and liveness tracking.
725    fn buffer_identity(&self) -> &BufferIdentity;
726}
727
728pub trait TensorMapTrait<T>
729where
730    T: Num + Clone + fmt::Debug,
731{
732    /// Get the shape of this tensor map.
733    fn shape(&self) -> &[usize];
734
735    /// Unmap the tensor from memory.
736    fn unmap(&mut self);
737
738    /// Get the number of elements in this tensor map.
739    fn len(&self) -> usize {
740        self.shape().iter().product()
741    }
742
743    /// Check if the tensor map is empty.
744    fn is_empty(&self) -> bool {
745        self.len() == 0
746    }
747
748    /// Get the size in bytes of this tensor map.
749    fn size(&self) -> usize {
750        self.len() * std::mem::size_of::<T>()
751    }
752
753    /// Get a slice to the data in this tensor map.
754    fn as_slice(&self) -> &[T];
755
756    /// Get a mutable slice to the data in this tensor map.
757    fn as_mut_slice(&mut self) -> &mut [T];
758
759    #[cfg(feature = "ndarray")]
760    /// Get an ndarray ArrayView of the tensor data.
761    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
762        Ok(ndarray::ArrayView::from_shape(
763            self.shape(),
764            self.as_slice(),
765        )?)
766    }
767
768    #[cfg(feature = "ndarray")]
769    /// Get an ndarray ArrayViewMut of the tensor data.
770    fn view_mut(
771        &'_ mut self,
772    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
773        let shape = self.shape().to_vec();
774        Ok(ndarray::ArrayViewMut::from_shape(
775            shape,
776            self.as_mut_slice(),
777        )?)
778    }
779}
780
781#[derive(Debug, Clone, Copy, PartialEq, Eq)]
782pub enum TensorMemory {
783    /// Platform-native zero-copy GPU buffer.
784    ///
785    /// On Linux this is a DMA-BUF (`DmaTensor` in `crates/tensor/src/dma.rs`)
786    /// allocated via the DRM/dma-heap subsystem. On macOS this is an
787    /// IOSurface (`IoSurfaceTensor` in `crates/tensor/src/iosurface.rs`).
788    /// Both fit into the same `TensorStorage::Dma` slot at the trait
789    /// level — the public C API discriminant (`HalTensorMemory::Dma=1`)
790    /// works on both platforms with no ABI break.
791    ///
792    /// Allows hardware-accelerated paths (OpenGL backend on Linux via
793    /// `EGL_EXT_image_dma_buf_import`; macOS via
794    /// `EGL_ANGLE_iosurface_client_buffer`). CPU access via `map()`
795    /// incurs cache-coherency overhead on Linux DMA-BUF and is similar
796    /// in cost on IOSurface; SHM/Mem are cheaper for CPU-only workloads.
797    Dma,
798    #[cfg(unix)]
799    /// POSIX Shared Memory allocation. Suitable for inter-process
800    /// communication, but not suitable for hardware acceleration.
801    Shm,
802
803    /// Regular system memory allocation
804    Mem,
805
806    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
807    /// when DMA-buf is unavailable but OpenGL is present.
808    Pbo,
809}
810
811impl From<TensorMemory> for String {
812    fn from(memory: TensorMemory) -> Self {
813        match memory {
814            TensorMemory::Dma => "dma".to_owned(),
815            #[cfg(unix)]
816            TensorMemory::Shm => "shm".to_owned(),
817            TensorMemory::Mem => "mem".to_owned(),
818            TensorMemory::Pbo => "pbo".to_owned(),
819        }
820    }
821}
822
823impl TryFrom<&str> for TensorMemory {
824    type Error = Error;
825
826    fn try_from(s: &str) -> Result<Self> {
827        match s {
828            "dma" => Ok(TensorMemory::Dma),
829            #[cfg(unix)]
830            "shm" => Ok(TensorMemory::Shm),
831            "mem" => Ok(TensorMemory::Mem),
832            "pbo" => Ok(TensorMemory::Pbo),
833            _ => Err(Error::InvalidMemoryType(s.to_owned())),
834        }
835    }
836}
837
838#[derive(Debug)]
839#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
840pub(crate) enum TensorStorage<T>
841where
842    T: Num + Clone + fmt::Debug + Send + Sync,
843{
844    /// Platform-native zero-copy GPU buffer. Inner type differs per
845    /// target: `DmaTensor` on Linux (DMA-BUF fd), `IoSurfaceTensor` on
846    /// macOS (CFRetained IOSurface). The shared variant name keeps the
847    /// public `TensorMemory::Dma` discriminant stable across platforms.
848    #[cfg(target_os = "linux")]
849    Dma(DmaTensor<T>),
850    #[cfg(target_os = "macos")]
851    Dma(IoSurfaceTensor<T>),
852    #[cfg(unix)]
853    Shm(ShmTensor<T>),
854    Mem(MemTensor<T>),
855    Pbo(PboTensor<T>),
856}
857
858impl<T> TensorStorage<T>
859where
860    T: Num + Clone + fmt::Debug + Send + Sync,
861{
862    /// Create a new tensor storage with the given shape, memory type, and
863    /// optional name. If no name is given, a random name will be generated.
864    /// If no memory type is given, the best available memory type will be
865    /// chosen based on the platform and environment variables.
866    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
867        match memory {
868            #[cfg(target_os = "linux")]
869            Some(TensorMemory::Dma) => {
870                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
871            }
872            #[cfg(target_os = "macos")]
873            Some(TensorMemory::Dma) => {
874                IoSurfaceTensor::<T>::new(shape, name).map(TensorStorage::Dma)
875            }
876            #[cfg(not(any(target_os = "linux", target_os = "macos")))]
877            Some(TensorMemory::Dma) => Err(crate::error::Error::NotImplemented(
878                "TensorMemory::Dma is only available on Linux (DMA-BUF) and macOS (IOSurface)"
879                    .to_owned(),
880            )),
881            #[cfg(unix)]
882            Some(TensorMemory::Shm) => {
883                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
884            }
885            Some(TensorMemory::Mem) => {
886                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
887            }
888            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
889                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
890            )),
891            None => {
892                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
893                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
894                {
895                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
896                } else {
897                    #[cfg(target_os = "linux")]
898                    {
899                        // Linux: Try DMA -> SHM -> Mem
900                        match DmaTensor::<T>::new(shape, name) {
901                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
902                            Err(_) => {
903                                match ShmTensor::<T>::new(shape, name)
904                                    .map(TensorStorage::Shm)
905                                {
906                                    Ok(tensor) => Ok(tensor),
907                                    Err(_) => MemTensor::<T>::new(shape, name)
908                                        .map(TensorStorage::Mem),
909                                }
910                            }
911                        }
912                    }
913                    #[cfg(target_os = "macos")]
914                    {
915                        // macOS: Try IOSurface -> SHM -> Mem. IOSurface
916                        // is the GPU-shareable backend (zero-copy via
917                        // ANGLE), filling the same role as DMA-BUF on
918                        // Linux. Falls back to SHM if IOSurface alloc
919                        // fails (memory pressure or sandboxed contexts).
920                        match IoSurfaceTensor::<T>::new(shape, name) {
921                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
922                            Err(_) => match ShmTensor::<T>::new(shape, name)
923                                .map(TensorStorage::Shm)
924                            {
925                                Ok(tensor) => Ok(tensor),
926                                Err(_) => MemTensor::<T>::new(shape, name)
927                                    .map(TensorStorage::Mem),
928                            },
929                        }
930                    }
931                    #[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))]
932                    {
933                        // Other Unix (BSD): Try SHM -> Mem (no DMA)
934                        match ShmTensor::<T>::new(shape, name) {
935                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
936                            Err(_) => {
937                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
938                            }
939                        }
940                    }
941                    #[cfg(not(unix))]
942                    {
943                        // Windows/other: Mem only
944                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
945                    }
946                }
947            }
948        }
949    }
950
951    /// Create a DMA-backed tensor storage with an explicit byte size that
952    /// may exceed `shape.product() * sizeof(T)`. Used for image tensors
953    /// with row-padded layouts (see `DmaTensor::new_with_byte_size`).
954    ///
955    /// This is intentionally DMA-only: padding is only meaningful for
956    /// buffers that will be imported as GPU textures via EGLImage. PBO,
957    /// Shm, and Mem storage doesn't benefit from pitch alignment and
958    /// shouldn't pay the memory cost.
959    #[cfg(target_os = "linux")]
960    pub(crate) fn new_dma_with_byte_size(
961        shape: &[usize],
962        byte_size: usize,
963        name: Option<&str>,
964    ) -> Result<Self> {
965        DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
966    }
967
968    // No non-Linux stub: the only caller (`Tensor::image_with_stride`)
969    // returns `NotImplemented` directly on non-Linux without ever
970    // reaching the storage layer, so defining a stub here would be
971    // dead code and fail the `-D warnings` clippy gate on macOS CI.
972
973    /// Allocate an image-formatted IOSurface-backed storage (macOS).
974    ///
975    /// Used by `Tensor::image()` when the caller requests
976    /// `TensorMemory::Dma` and the format has an IOSurface FourCC
977    /// mapping (YUYV, RGBA, BGRA today). Falls back to `new_with_byte_size`
978    /// otherwise.
979    #[cfg(target_os = "macos")]
980    pub(crate) fn new_image_iosurface(
981        width: usize,
982        height: usize,
983        format: PixelFormat,
984        shape: &[usize],
985        name: Option<&str>,
986    ) -> Result<Self> {
987        IoSurfaceTensor::<T>::new_image(width, height, format, shape, name).map(TensorStorage::Dma)
988    }
989
990    /// Create a new tensor storage using the given file descriptor, shape,
991    /// and optional name.
992    #[cfg(unix)]
993    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
994        #[cfg(target_os = "linux")]
995        {
996            use nix::sys::stat::fstat;
997
998            let stat = fstat(&fd)?;
999            let major = major(stat.st_dev);
1000            let minor = minor(stat.st_dev);
1001
1002            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
1003
1004            if major != 0 {
1005                // Dma and Shm tensors are expected to have major number 0
1006                return Err(Error::UnknownDeviceType(major, minor));
1007            }
1008
1009            match minor {
1010                9 | 10 => {
1011                    // minor number 9 & 10 indicates DMA memory
1012                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
1013                }
1014                _ => {
1015                    // other minor numbers are assumed to be shared memory
1016                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
1017                }
1018            }
1019        }
1020        #[cfg(all(unix, not(target_os = "linux")))]
1021        {
1022            // On macOS/BSD, always use SHM (no DMA support)
1023            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
1024        }
1025    }
1026}
1027
1028impl<T> TensorTrait<T> for TensorStorage<T>
1029where
1030    T: Num + Clone + fmt::Debug + Send + Sync,
1031{
1032    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
1033        Self::new(shape, None, name)
1034    }
1035
1036    #[cfg(unix)]
1037    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
1038        Self::from_fd(fd, shape, name)
1039    }
1040
1041    #[cfg(unix)]
1042    fn clone_fd(&self) -> Result<OwnedFd> {
1043        match self {
1044            TensorStorage::Dma(t) => t.clone_fd(),
1045            TensorStorage::Shm(t) => t.clone_fd(),
1046            TensorStorage::Mem(t) => t.clone_fd(),
1047            TensorStorage::Pbo(t) => t.clone_fd(),
1048        }
1049    }
1050
1051    fn memory(&self) -> TensorMemory {
1052        match self {
1053            #[cfg(any(target_os = "linux", target_os = "macos"))]
1054            TensorStorage::Dma(_) => TensorMemory::Dma,
1055            #[cfg(unix)]
1056            TensorStorage::Shm(_) => TensorMemory::Shm,
1057            TensorStorage::Mem(_) => TensorMemory::Mem,
1058            TensorStorage::Pbo(_) => TensorMemory::Pbo,
1059        }
1060    }
1061
1062    fn name(&self) -> String {
1063        match self {
1064            #[cfg(any(target_os = "linux", target_os = "macos"))]
1065            TensorStorage::Dma(t) => t.name(),
1066            #[cfg(unix)]
1067            TensorStorage::Shm(t) => t.name(),
1068            TensorStorage::Mem(t) => t.name(),
1069            TensorStorage::Pbo(t) => t.name(),
1070        }
1071    }
1072
1073    fn shape(&self) -> &[usize] {
1074        match self {
1075            #[cfg(any(target_os = "linux", target_os = "macos"))]
1076            TensorStorage::Dma(t) => t.shape(),
1077            #[cfg(unix)]
1078            TensorStorage::Shm(t) => t.shape(),
1079            TensorStorage::Mem(t) => t.shape(),
1080            TensorStorage::Pbo(t) => t.shape(),
1081        }
1082    }
1083
1084    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1085        match self {
1086            #[cfg(any(target_os = "linux", target_os = "macos"))]
1087            TensorStorage::Dma(t) => t.reshape(shape),
1088            #[cfg(unix)]
1089            TensorStorage::Shm(t) => t.reshape(shape),
1090            TensorStorage::Mem(t) => t.reshape(shape),
1091            TensorStorage::Pbo(t) => t.reshape(shape),
1092        }
1093    }
1094
1095    fn map(&self) -> Result<TensorMap<T>> {
1096        match self {
1097            #[cfg(any(target_os = "linux", target_os = "macos"))]
1098            TensorStorage::Dma(t) => t.map(),
1099            #[cfg(unix)]
1100            TensorStorage::Shm(t) => t.map(),
1101            TensorStorage::Mem(t) => t.map(),
1102            TensorStorage::Pbo(t) => t.map(),
1103        }
1104    }
1105
1106    fn buffer_identity(&self) -> &BufferIdentity {
1107        match self {
1108            #[cfg(any(target_os = "linux", target_os = "macos"))]
1109            TensorStorage::Dma(t) => t.buffer_identity(),
1110            #[cfg(unix)]
1111            TensorStorage::Shm(t) => t.buffer_identity(),
1112            TensorStorage::Mem(t) => t.buffer_identity(),
1113            TensorStorage::Pbo(t) => t.buffer_identity(),
1114        }
1115    }
1116}
1117
1118/// Multi-backend tensor with optional image format metadata.
1119///
1120/// When `format` is `Some`, this tensor represents an image. Width, height,
1121/// and channels are derived from `shape` + `format`. When `format` is `None`,
1122/// this is a raw tensor (identical to the pre-refactoring behavior).
1123#[derive(Debug)]
1124pub struct Tensor<T>
1125where
1126    T: Num + Clone + fmt::Debug + Send + Sync,
1127{
1128    pub(crate) storage: TensorStorage<T>,
1129    format: Option<PixelFormat>,
1130    chroma: Option<Box<Tensor<T>>>,
1131    /// Row stride in bytes for externally allocated buffers with row padding.
1132    /// `None` means tightly packed (stride == width * bytes_per_pixel).
1133    row_stride: Option<usize>,
1134    /// Byte offset within the DMA-BUF where image data starts.
1135    /// `None` means offset 0 (data starts at the beginning of the buffer).
1136    plane_offset: Option<usize>,
1137    /// Quantization metadata for integer-typed tensors. Public access is
1138    /// gated by the `IntegerType` trait — `Tensor<f32>` etc. carry the
1139    /// field for layout uniformity but have no way to read or write it.
1140    pub(crate) quantization: Option<Quantization>,
1141}
1142
1143impl<T> Tensor<T>
1144where
1145    T: Num + Clone + fmt::Debug + Send + Sync,
1146{
1147    /// Wrap a TensorStorage in a Tensor with no image metadata.
1148    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
1149        Self {
1150            storage,
1151            format: None,
1152            chroma: None,
1153            row_stride: None,
1154            plane_offset: None,
1155            quantization: None,
1156        }
1157    }
1158
1159    /// Construct a tensor from a row-major element slice + shape. Allocates a
1160    /// new buffer (`TensorMemory::Mem`) and memcpys the contents; caller
1161    /// retains ownership of the input slice.
1162    ///
1163    /// # Errors
1164    ///
1165    /// - [`Error::InvalidShape`] if `values.len() != shape.iter().product()`.
1166    /// - Propagates any allocation error from [`Self::new`].
1167    pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
1168    where
1169        T: Copy,
1170    {
1171        let expected: usize = shape.iter().product();
1172        if values.len() != expected {
1173            return Err(Error::InvalidShape(format!(
1174                "from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
1175                values.len()
1176            )));
1177        }
1178        let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
1179        {
1180            let mut m = t.map()?;
1181            m.as_mut_slice().copy_from_slice(values);
1182        }
1183        Ok(t)
1184    }
1185
1186    /// Construct a tensor from a 3-D ndarray view. Respects strides — one
1187    /// copy in all cases; contiguous views take a memcpy fast path.
1188    ///
1189    /// Only available when the `ndarray` feature is enabled.
1190    #[cfg(feature = "ndarray")]
1191    pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
1192    where
1193        T: Copy,
1194    {
1195        let (h, w, c) = view.dim();
1196        let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
1197        {
1198            let mut m = t.map()?;
1199            let dst = m.as_mut_slice();
1200            if let Some(src) = view.as_slice() {
1201                dst.copy_from_slice(src);
1202            } else {
1203                for (d, &s) in dst.iter_mut().zip(view.iter()) {
1204                    *d = s;
1205                }
1206            }
1207        }
1208        Ok(t)
1209    }
1210
1211    /// Create a new tensor with the given shape, memory type, and optional
1212    /// name. If no name is given, a random name will be generated. If no
1213    /// memory type is given, the best available memory type will be chosen
1214    /// based on the platform and environment variables.
1215    ///
1216    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
1217    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
1218    /// On non-Unix platforms, only Mem is available.
1219    ///
1220    /// # Environment Variables
1221    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
1222    ///   value, forces the use of regular system memory allocation
1223    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
1224    ///
1225    /// # Example
1226    /// ```rust
1227    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
1228    /// # fn main() -> Result<(), Error> {
1229    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
1230    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
1231    /// assert_eq!(tensor.name(), "test_tensor");
1232    /// #    Ok(())
1233    /// # }
1234    /// ```
1235    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
1236        let _span = tracing::trace_span!(
1237            "tensor.alloc",
1238            ?shape,
1239            memory = ?memory,
1240            dtype = std::any::type_name::<T>(),
1241        )
1242        .entered();
1243        TensorStorage::new(shape, memory, name).map(Self::wrap)
1244    }
1245
1246    /// Create an image tensor with the given format.
1247    pub fn image(
1248        width: usize,
1249        height: usize,
1250        format: PixelFormat,
1251        memory: Option<TensorMemory>,
1252    ) -> Result<Self> {
1253        let shape = match format.layout() {
1254            PixelLayout::Packed => vec![height, width, format.channels()],
1255            PixelLayout::Planar => vec![format.channels(), height, width],
1256            PixelLayout::SemiPlanar => {
1257                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
1258                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
1259                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
1260                let total_h = match format {
1261                    PixelFormat::Nv12 => {
1262                        if !height.is_multiple_of(2) {
1263                            return Err(Error::InvalidArgument(format!(
1264                                "NV12 requires even height, got {height}"
1265                            )));
1266                        }
1267                        height * 3 / 2
1268                    }
1269                    PixelFormat::Nv16 => height * 2,
1270                    _ => {
1271                        return Err(Error::InvalidArgument(format!(
1272                            "unknown semi-planar height multiplier for {format:?}"
1273                        )))
1274                    }
1275                };
1276                vec![total_h, width]
1277            }
1278        };
1279
1280        // macOS Dma path: allocate a format-aware IOSurface (FourCC +
1281        // 2D dimensions) so the GL backend can bind it via
1282        // `EGL_ANGLE_iosurface_client_buffer`. Without this, the IOSurface
1283        // would default to a generic byte buffer (FourCC 'L008') and
1284        // ANGLE would reject the import with `EGL_BAD_ATTRIBUTE`.
1285        //
1286        // Guard: IOSurface rounds `bytes_per_row` up to 64-byte alignment.
1287        // If the natural row pitch (`width * channels * sizeof(T)`) is not
1288        // already 64-byte aligned, the padded allocation cannot be mapped
1289        // as a contiguous packed tensor — CPU reads/writes would use the
1290        // wrong stride. Only proceed when alignment is natural; otherwise
1291        // fall through to SHM/Mem where no per-row padding exists.
1292        #[cfg(target_os = "macos")]
1293        if matches!(memory, Some(TensorMemory::Dma)) {
1294            let natural_row_bytes = width * format.channels() * std::mem::size_of::<T>();
1295            if natural_row_bytes.is_multiple_of(64) {
1296                if let Ok(storage) =
1297                    TensorStorage::<T>::new_image_iosurface(width, height, format, &shape, None)
1298                {
1299                    let mut t = Self::wrap(storage);
1300                    t.format = Some(format);
1301                    return Ok(t);
1302                }
1303            }
1304            // If row pitch is not 64-byte aligned or new_image_iosurface
1305            // fails (unsupported format), fall through to the generic
1306            // Tensor::new path which picks up SHM/Mem instead.
1307        }
1308
1309        let mut t = Self::new(&shape, memory, None)?;
1310        t.format = Some(format);
1311        Ok(t)
1312    }
1313
1314    /// Create a DMA-backed image tensor with an explicit row stride that
1315    /// may exceed the natural `width * channels * sizeof(T)` pitch.
1316    ///
1317    /// Used for image tensors that need GPU pitch alignment padding: the
1318    /// underlying DMA-BUF is sized to `row_stride * height` bytes, but
1319    /// the tensor's logical shape stays at `[height, width, channels]`.
1320    /// `width()` / `height()` / `shape()` continue to report the
1321    /// user-requested values; the padding is visible only via
1322    /// `row_stride()` / `effective_row_stride()` and is automatically
1323    /// propagated to the GL backend's EGLImage import so Mali Valhall
1324    /// accepts the buffer.
1325    ///
1326    /// # Supported formats
1327    ///
1328    /// Currently only **packed** pixel layouts (RGBA8, BGRA8, RGB888,
1329    /// Grey, etc.) are supported — the formats the GL backend uses as
1330    /// render destinations. Semi-planar formats (NV12, NV16) come from
1331    /// external allocators (camera capture, video decoders) and are
1332    /// imported via `TensorDyn::from_fd` + `set_row_stride`, which
1333    /// already supports padded strides.
1334    ///
1335    /// # Supported memory
1336    ///
1337    /// Currently only `TensorMemory::Dma` is supported. PBO and Mem
1338    /// storage don't go through EGLImage import so they don't need
1339    /// pitch alignment; if you pass any other memory type this returns
1340    /// `NotImplemented`. `None` (auto-select) is treated as `Dma`.
1341    ///
1342    /// # Errors
1343    ///
1344    /// - `InvalidArgument` if `row_stride_bytes < width * channels * sizeof(T)`
1345    ///   (the requested stride would not fit a single row)
1346    /// - `NotImplemented` for non-packed formats or non-DMA memory
1347    /// - `IoError` if the DMA-heap allocation fails (propagated from
1348    ///   `DmaTensor::new_with_byte_size`)
1349    pub fn image_with_stride(
1350        width: usize,
1351        height: usize,
1352        format: PixelFormat,
1353        row_stride_bytes: usize,
1354        memory: Option<TensorMemory>,
1355    ) -> Result<Self> {
1356        // DMA backing (the only thing this constructor produces) is
1357        // Linux-only. On macOS/BSD/Windows the non-Linux block below is
1358        // the only compiled body and returns `NotImplemented` directly;
1359        // on Linux the non-Linux block is cfg-removed and the function
1360        // falls through to the real validation + allocation path. Each
1361        // target compiles exactly one of the two blocks, and the block
1362        // serves as the function's tail expression in both cases — so
1363        // neither needs an explicit `return` (avoids
1364        // `clippy::needless_return` on the macOS CI gate).
1365        #[cfg(not(target_os = "linux"))]
1366        {
1367            let _ = (width, height, format, row_stride_bytes, memory);
1368            Err(Error::NotImplemented(
1369                "image_with_stride requires DMA support (Linux only)".to_owned(),
1370            ))
1371        }
1372
1373        #[cfg(target_os = "linux")]
1374        {
1375            if format.layout() != PixelLayout::Packed {
1376                return Err(Error::NotImplemented(format!(
1377                    "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
1378                )));
1379            }
1380            let elem = std::mem::size_of::<T>();
1381            let min_stride = width
1382                .checked_mul(format.channels())
1383                .and_then(|p| p.checked_mul(elem))
1384                .ok_or_else(|| {
1385                    Error::InvalidArgument(format!(
1386                        "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
1387                         overflows usize",
1388                        format.channels()
1389                    ))
1390                })?;
1391            if row_stride_bytes < min_stride {
1392                return Err(Error::InvalidArgument(format!(
1393                    "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
1394                     ({width} px × {} ch × {elem} B)",
1395                    format.channels()
1396                )));
1397            }
1398            let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
1399                Error::InvalidArgument(format!(
1400                    "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
1401                ))
1402            })?;
1403
1404            let shape = vec![height, width, format.channels()];
1405
1406            let storage = match memory {
1407                Some(TensorMemory::Dma) | None => {
1408                    TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
1409                }
1410                Some(other) => {
1411                    return Err(Error::NotImplemented(format!(
1412                        "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
1413                    )));
1414                }
1415            };
1416
1417            let mut t = Self::wrap(storage);
1418            t.format = Some(format);
1419            t.row_stride = Some(row_stride_bytes);
1420            Ok(t)
1421        }
1422    }
1423
1424    /// Attach format metadata to an existing tensor.
1425    ///
1426    /// # Arguments
1427    ///
1428    /// * `format` - The pixel format to attach
1429    ///
1430    /// # Returns
1431    ///
1432    /// `Ok(())` on success, with the format stored as metadata on the tensor.
1433    ///
1434    /// # Errors
1435    ///
1436    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
1437    /// the format's layout (packed expects `[H, W, C]`, planar expects
1438    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
1439    /// height constraints).
1440    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
1441        let shape = self.shape();
1442        match format.layout() {
1443            PixelLayout::Packed => {
1444                if shape.len() != 3 || shape[2] != format.channels() {
1445                    return Err(Error::InvalidShape(format!(
1446                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
1447                        format.channels()
1448                    )));
1449                }
1450            }
1451            PixelLayout::Planar => {
1452                if shape.len() != 3 || shape[0] != format.channels() {
1453                    return Err(Error::InvalidShape(format!(
1454                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
1455                        format.channels()
1456                    )));
1457                }
1458            }
1459            PixelLayout::SemiPlanar => {
1460                if shape.len() != 2 {
1461                    return Err(Error::InvalidShape(format!(
1462                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
1463                    )));
1464                }
1465                match format {
1466                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
1467                        return Err(Error::InvalidShape(format!(
1468                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
1469                            shape[0]
1470                        )));
1471                    }
1472                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
1473                        return Err(Error::InvalidShape(format!(
1474                            "NV16 contiguous shape[0] must be even, got {}",
1475                            shape[0]
1476                        )));
1477                    }
1478                    _ => {}
1479                }
1480            }
1481        }
1482        // Clear stored stride/offset when format changes — they may be invalid
1483        // for the new format. Caller must re-set after changing format.
1484        if self.format != Some(format) {
1485            self.row_stride = None;
1486            self.plane_offset = None;
1487            #[cfg(target_os = "linux")]
1488            if let TensorStorage::Dma(ref mut dma) = self.storage {
1489                dma.mmap_offset = 0;
1490            }
1491        }
1492        self.format = Some(format);
1493        Ok(())
1494    }
1495
1496    /// Pixel format (None if not an image).
1497    pub fn format(&self) -> Option<PixelFormat> {
1498        self.format
1499    }
1500
1501    /// Image width (None if not an image).
1502    pub fn width(&self) -> Option<usize> {
1503        let fmt = self.format?;
1504        let shape = self.shape();
1505        match fmt.layout() {
1506            PixelLayout::Packed => Some(shape[1]),
1507            PixelLayout::Planar => Some(shape[2]),
1508            PixelLayout::SemiPlanar => Some(shape[1]),
1509        }
1510    }
1511
1512    /// Image height (None if not an image).
1513    pub fn height(&self) -> Option<usize> {
1514        let fmt = self.format?;
1515        let shape = self.shape();
1516        match fmt.layout() {
1517            PixelLayout::Packed => Some(shape[0]),
1518            PixelLayout::Planar => Some(shape[1]),
1519            PixelLayout::SemiPlanar => {
1520                if self.is_multiplane() {
1521                    Some(shape[0])
1522                } else {
1523                    match fmt {
1524                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
1525                        PixelFormat::Nv16 => Some(shape[0] / 2),
1526                        _ => None,
1527                    }
1528                }
1529            }
1530        }
1531    }
1532
1533    /// Create from separate Y and UV planes (multiplane NV12/NV16).
1534    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
1535        if format.layout() != PixelLayout::SemiPlanar {
1536            return Err(Error::InvalidArgument(format!(
1537                "from_planes requires a semi-planar format, got {format:?}"
1538            )));
1539        }
1540        if chroma.format.is_some() || chroma.chroma.is_some() {
1541            return Err(Error::InvalidArgument(
1542                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
1543            ));
1544        }
1545        let luma_shape = luma.shape();
1546        let chroma_shape = chroma.shape();
1547        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
1548            return Err(Error::InvalidArgument(format!(
1549                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
1550            )));
1551        }
1552        if luma_shape[1] != chroma_shape[1] {
1553            return Err(Error::InvalidArgument(format!(
1554                "luma width {} != chroma width {}",
1555                luma_shape[1], chroma_shape[1]
1556            )));
1557        }
1558        match format {
1559            PixelFormat::Nv12 => {
1560                if luma_shape[0] % 2 != 0 {
1561                    return Err(Error::InvalidArgument(format!(
1562                        "NV12 requires even luma height, got {}",
1563                        luma_shape[0]
1564                    )));
1565                }
1566                if chroma_shape[0] != luma_shape[0] / 2 {
1567                    return Err(Error::InvalidArgument(format!(
1568                        "NV12 chroma height {} != luma height / 2 ({})",
1569                        chroma_shape[0],
1570                        luma_shape[0] / 2
1571                    )));
1572                }
1573            }
1574            PixelFormat::Nv16 => {
1575                if chroma_shape[0] != luma_shape[0] {
1576                    return Err(Error::InvalidArgument(format!(
1577                        "NV16 chroma height {} != luma height {}",
1578                        chroma_shape[0], luma_shape[0]
1579                    )));
1580                }
1581            }
1582            _ => {
1583                return Err(Error::InvalidArgument(format!(
1584                    "from_planes only supports NV12 and NV16, got {format:?}"
1585                )));
1586            }
1587        }
1588
1589        Ok(Tensor {
1590            storage: luma.storage,
1591            format: Some(format),
1592            chroma: Some(Box::new(chroma)),
1593            row_stride: luma.row_stride,
1594            plane_offset: luma.plane_offset,
1595            quantization: luma.quantization,
1596        })
1597    }
1598
1599    /// Whether this tensor uses separate plane allocations.
1600    pub fn is_multiplane(&self) -> bool {
1601        self.chroma.is_some()
1602    }
1603
1604    /// Access the chroma plane for multiplane semi-planar images.
1605    pub fn chroma(&self) -> Option<&Tensor<T>> {
1606        self.chroma.as_deref()
1607    }
1608
1609    /// Mutable access to the chroma plane for multiplane semi-planar images.
1610    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
1611        self.chroma.as_deref_mut()
1612    }
1613
1614    /// Row stride in bytes (`None` = tightly packed).
1615    pub fn row_stride(&self) -> Option<usize> {
1616        self.row_stride
1617    }
1618
1619    /// Effective row stride in bytes: the stored stride if set, otherwise the
1620    /// minimum stride computed from the format, width, and element size.
1621    /// Returns `None` only when no format is set and no explicit stride was
1622    /// stored via [`set_row_stride`](Self::set_row_stride).
1623    pub fn effective_row_stride(&self) -> Option<usize> {
1624        if let Some(s) = self.row_stride {
1625            return Some(s);
1626        }
1627        let fmt = self.format?;
1628        let w = self.width()?;
1629        let elem = std::mem::size_of::<T>();
1630        Some(match fmt.layout() {
1631            PixelLayout::Packed => w * fmt.channels() * elem,
1632            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1633        })
1634    }
1635
1636    /// Set the row stride in bytes for externally allocated buffers with
1637    /// row padding (e.g. V4L2 or GStreamer allocators).
1638    ///
1639    /// The stride is propagated to the EGL DMA-BUF import attributes so
1640    /// the GPU interprets the padded buffer layout correctly. Must be
1641    /// called after [`set_format`](Self::set_format) and before the tensor
1642    /// is first passed to [`ImageProcessor::convert`]. The stored stride
1643    /// is cleared automatically if the pixel format is later changed.
1644    ///
1645    /// No stride-vs-buffer-size validation is performed because the
1646    /// backing allocation size is not reliably known: external DMA-BUFs
1647    /// may be over-allocated by the allocator, and internal tensors store
1648    /// a logical (unpadded) shape. An incorrect stride will be caught by
1649    /// the EGL driver at import time.
1650    ///
1651    /// # Arguments
1652    ///
1653    /// * `stride` - Row stride in bytes. Must be >= the minimum stride for
1654    ///   the format (width * channels * sizeof(T) for packed,
1655    ///   width * sizeof(T) for planar/semi-planar).
1656    ///
1657    /// # Errors
1658    ///
1659    /// * `InvalidArgument` if no pixel format is set on this tensor
1660    /// * `InvalidArgument` if `stride` is less than the minimum for the
1661    ///   format and width
1662    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
1663        let fmt = self.format.ok_or_else(|| {
1664            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
1665        })?;
1666        let w = self.width().ok_or_else(|| {
1667            Error::InvalidArgument("cannot determine width for row_stride validation".into())
1668        })?;
1669        let elem = std::mem::size_of::<T>();
1670        let min_stride = match fmt.layout() {
1671            PixelLayout::Packed => w * fmt.channels() * elem,
1672            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1673        };
1674        if stride < min_stride {
1675            return Err(Error::InvalidArgument(format!(
1676                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
1677            )));
1678        }
1679        self.row_stride = Some(stride);
1680        Ok(())
1681    }
1682
1683    /// Set the row stride without format validation.
1684    ///
1685    /// Use this for raw sub-tensors (e.g. chroma planes) that don't carry
1686    /// format metadata. The caller is responsible for ensuring the stride
1687    /// is valid.
1688    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
1689        self.row_stride = Some(stride);
1690    }
1691
1692    /// Builder-style variant of [`set_row_stride`](Self::set_row_stride),
1693    /// consuming and returning `self`.
1694    ///
1695    /// # Errors
1696    ///
1697    /// Same conditions as [`set_row_stride`](Self::set_row_stride).
1698    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
1699        self.set_row_stride(stride)?;
1700        Ok(self)
1701    }
1702
1703    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
1704    pub fn plane_offset(&self) -> Option<usize> {
1705        self.plane_offset
1706    }
1707
1708    /// Set the byte offset within the DMA-BUF where image data starts.
1709    ///
1710    /// Propagated to `EGL_DMA_BUF_PLANE0_OFFSET_EXT` on GPU import.
1711    /// Unlike [`set_row_stride`](Self::set_row_stride), no format is required
1712    /// since the offset is format-independent.
1713    pub fn set_plane_offset(&mut self, offset: usize) {
1714        self.plane_offset = Some(offset);
1715        #[cfg(target_os = "linux")]
1716        if let TensorStorage::Dma(ref mut dma) = self.storage {
1717            dma.mmap_offset = offset;
1718        }
1719    }
1720
1721    /// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset),
1722    /// consuming and returning `self`.
1723    pub fn with_plane_offset(mut self, offset: usize) -> Self {
1724        self.set_plane_offset(offset);
1725        self
1726    }
1727
1728    /// Downcast to PBO tensor reference (for GL backends).
1729    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1730        match &self.storage {
1731            TensorStorage::Pbo(p) => Some(p),
1732            _ => None,
1733        }
1734    }
1735
1736    /// Downcast to DMA tensor reference (for EGL import, G2D).
1737    #[cfg(target_os = "linux")]
1738    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1739        match &self.storage {
1740            TensorStorage::Dma(d) => Some(d),
1741            _ => None,
1742        }
1743    }
1744
1745    /// Borrow the DMA-BUF file descriptor backing this tensor.
1746    ///
1747    /// # Returns
1748    ///
1749    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
1750    /// lifetime.
1751    ///
1752    /// # Errors
1753    ///
1754    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
1755    #[cfg(target_os = "linux")]
1756    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1757        use std::os::fd::AsFd;
1758        match &self.storage {
1759            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1760            _ => Err(Error::NotImplemented(format!(
1761                "dmabuf requires DMA-backed tensor, got {:?}",
1762                self.storage.memory()
1763            ))),
1764        }
1765    }
1766
1767    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
1768    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1769        Self {
1770            storage: TensorStorage::Pbo(pbo),
1771            format: None,
1772            chroma: None,
1773            row_stride: None,
1774            plane_offset: None,
1775            quantization: None,
1776        }
1777    }
1778}
1779
1780// Quantization accessors — type-gated to integer element types via the
1781// sealed `IntegerType` trait. Calling `.quantization()` on a `Tensor<f32>`
1782// produces a compile error, not a runtime one.
1783impl<T> Tensor<T>
1784where
1785    T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
1786{
1787    /// Quantization metadata for this tensor, if set.
1788    pub fn quantization(&self) -> Option<&Quantization> {
1789        self.quantization.as_ref()
1790    }
1791
1792    /// Attach quantization metadata to this tensor. Validates against the
1793    /// tensor's shape — returns [`Error::QuantizationInvalid`] on any
1794    /// inconsistency (mismatched scale/zp lengths, out-of-range axis, etc.).
1795    pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
1796        q.validate(self.shape())?;
1797        self.quantization = Some(q);
1798        Ok(())
1799    }
1800
1801    /// Builder-style variant of [`Self::set_quantization`]. Consumes `self`
1802    /// and returns `Result<Self>` — on success yields the tensor with the
1803    /// attached quantization; on validation failure returns
1804    /// [`Error::QuantizationInvalid`] and drops `self` (the tensor is not
1805    /// returned in the error arm).
1806    pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
1807        self.set_quantization(q)?;
1808        Ok(self)
1809    }
1810
1811    /// Clear any quantization metadata on this tensor.
1812    pub fn clear_quantization(&mut self) {
1813        self.quantization = None;
1814    }
1815}
1816
1817impl<T> TensorTrait<T> for Tensor<T>
1818where
1819    T: Num + Clone + fmt::Debug + Send + Sync,
1820{
1821    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1822    where
1823        Self: Sized,
1824    {
1825        Self::new(shape, None, name)
1826    }
1827
1828    #[cfg(unix)]
1829    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1830    where
1831        Self: Sized,
1832    {
1833        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1834    }
1835
1836    #[cfg(unix)]
1837    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1838        self.storage.clone_fd()
1839    }
1840
1841    fn memory(&self) -> TensorMemory {
1842        self.storage.memory()
1843    }
1844
1845    fn name(&self) -> String {
1846        self.storage.name()
1847    }
1848
1849    fn shape(&self) -> &[usize] {
1850        self.storage.shape()
1851    }
1852
1853    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1854        if self.chroma.is_some() {
1855            return Err(Error::InvalidOperation(
1856                "cannot reshape a multiplane tensor — decompose planes first".into(),
1857            ));
1858        }
1859        self.storage.reshape(shape)?;
1860        self.format = None;
1861        self.row_stride = None;
1862        self.plane_offset = None;
1863        #[cfg(target_os = "linux")]
1864        if let TensorStorage::Dma(ref mut dma) = self.storage {
1865            dma.mmap_offset = 0;
1866        }
1867        Ok(())
1868    }
1869
1870    fn map(&self) -> Result<TensorMap<T>> {
1871        let _span = tracing::trace_span!(
1872            "tensor.map",
1873            memory = ?self.storage.memory(),
1874        )
1875        .entered();
1876        // CPU mapping of strided tensors is allowed only when the HAL
1877        // owns the underlying allocation — i.e. self-allocated DMA
1878        // tensors with pitch padding added by `image_with_stride()`
1879        // for GPU import alignment. In that case we know the buffer
1880        // is exactly `row_stride × height` bytes (for packed formats)
1881        // and callers that respect the stride can iterate rows
1882        // correctly via `effective_row_stride()`.
1883        //
1884        // Foreign DMA-BUFs imported via `from_fd()` + `set_row_stride()`
1885        // (the V4L2 / GStreamer case) are rejected: their layout comes
1886        // from an external allocator and the HAL cannot validate what
1887        // the caller expects the mapping to look like. Those tensors
1888        // are intended for the GPU path only.
1889        //
1890        // The cfg split keeps `stride` from being an unused binding on
1891        // non-Linux builds (the Linux branch is the only consumer).
1892        #[cfg(target_os = "linux")]
1893        if let Some(stride) = self.row_stride {
1894            if let TensorStorage::Dma(dma) = &self.storage {
1895                if !dma.is_imported {
1896                    // Self-allocated strided DMA tensor — expose the
1897                    // full stride×height padded mmap via the override
1898                    // constructor so callers can iterate rows with
1899                    // `effective_row_stride()` without going past
1900                    // the end of the returned slice.
1901                    //
1902                    // Validate the requested mapping fits inside the
1903                    // actual DMA-BUF. `set_row_stride()` is a public
1904                    // API and only validates `stride >= min_stride`,
1905                    // not `stride × height <= buf_size`, so a caller
1906                    // that tampers with the stride after allocation
1907                    // could otherwise request a slice larger than the
1908                    // underlying mmap — which would be undefined
1909                    // behaviour in `DmaMap::as_slice`.
1910                    //
1911                    // Refuse to map if `height()` can't be derived
1912                    // (e.g. raw 2D tensors without a PixelFormat that
1913                    // got a `row_stride` set via `set_row_stride_unchecked`).
1914                    // Returning a 0-byte view would silently truncate
1915                    // rather than surface the misuse.
1916                    let height = self.height().ok_or_else(|| {
1917                        Error::InvalidOperation(
1918                            "Tensor::map: strided DMA mapping requires a PixelFormat \
1919                             so height() can be derived; set a format before mapping \
1920                             or clear row_stride for raw tensor access"
1921                                .into(),
1922                        )
1923                    })?;
1924                    let total_bytes = stride.checked_mul(height).ok_or_else(|| {
1925                        Error::InvalidOperation(format!(
1926                            "Tensor::map: row_stride {stride} × height {height} overflows usize"
1927                        ))
1928                    })?;
1929                    let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
1930                    if total_bytes > available_bytes {
1931                        return Err(Error::InvalidOperation(format!(
1932                            "Tensor::map: strided mapping needs {total_bytes} bytes \
1933                             but DMA buffer only has {available_bytes} available \
1934                             (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
1935                             the row_stride was likely set larger than the original allocation",
1936                            dma.buf_size, dma.mmap_offset
1937                        )));
1938                    }
1939                    return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
1940                }
1941            }
1942            return Err(Error::InvalidOperation(
1943                "CPU mapping of strided foreign tensors is not supported; \
1944                 use GPU path only"
1945                    .into(),
1946            ));
1947        }
1948        #[cfg(not(target_os = "linux"))]
1949        if self.row_stride.is_some() {
1950            return Err(Error::InvalidOperation(
1951                "CPU mapping of strided tensors is not supported on this \
1952                 platform (DMA backing is Linux-only)"
1953                    .into(),
1954            ));
1955        }
1956        // Offset tensors are supported for DMA storage — DmaMap adjusts the
1957        // mmap range and slice start position.  Non-DMA offset tensors are
1958        // not meaningful (offset only applies to DMA-BUF sub-regions).
1959        if self.plane_offset.is_some_and(|o| o > 0) {
1960            #[cfg(target_os = "linux")]
1961            if !matches!(self.storage, TensorStorage::Dma(_)) {
1962                return Err(Error::InvalidOperation(
1963                    "plane offset only supported for DMA tensors".into(),
1964                ));
1965            }
1966            #[cfg(not(target_os = "linux"))]
1967            return Err(Error::InvalidOperation(
1968                "plane offset only supported for DMA tensors".into(),
1969            ));
1970        }
1971        self.storage.map()
1972    }
1973
1974    fn buffer_identity(&self) -> &BufferIdentity {
1975        self.storage.buffer_identity()
1976    }
1977}
1978
1979pub enum TensorMap<T>
1980where
1981    T: Num + Clone + fmt::Debug,
1982{
1983    #[cfg(target_os = "linux")]
1984    Dma(DmaMap<T>),
1985    #[cfg(target_os = "macos")]
1986    IoSurface(IoSurfaceMap<T>),
1987    #[cfg(unix)]
1988    Shm(ShmMap<T>),
1989    Mem(MemMap<T>),
1990    Pbo(PboMap<T>),
1991}
1992
1993impl<T> TensorMapTrait<T> for TensorMap<T>
1994where
1995    T: Num + Clone + fmt::Debug,
1996{
1997    fn shape(&self) -> &[usize] {
1998        match self {
1999            #[cfg(target_os = "linux")]
2000            TensorMap::Dma(map) => map.shape(),
2001            #[cfg(target_os = "macos")]
2002            TensorMap::IoSurface(map) => map.shape(),
2003            #[cfg(unix)]
2004            TensorMap::Shm(map) => map.shape(),
2005            TensorMap::Mem(map) => map.shape(),
2006            TensorMap::Pbo(map) => map.shape(),
2007        }
2008    }
2009
2010    fn unmap(&mut self) {
2011        match self {
2012            #[cfg(target_os = "linux")]
2013            TensorMap::Dma(map) => map.unmap(),
2014            #[cfg(target_os = "macos")]
2015            TensorMap::IoSurface(map) => map.unmap(),
2016            #[cfg(unix)]
2017            TensorMap::Shm(map) => map.unmap(),
2018            TensorMap::Mem(map) => map.unmap(),
2019            TensorMap::Pbo(map) => map.unmap(),
2020        }
2021    }
2022
2023    fn as_slice(&self) -> &[T] {
2024        match self {
2025            #[cfg(target_os = "linux")]
2026            TensorMap::Dma(map) => map.as_slice(),
2027            #[cfg(target_os = "macos")]
2028            TensorMap::IoSurface(map) => map.deref(),
2029            #[cfg(unix)]
2030            TensorMap::Shm(map) => map.as_slice(),
2031            TensorMap::Mem(map) => map.as_slice(),
2032            TensorMap::Pbo(map) => map.as_slice(),
2033        }
2034    }
2035
2036    fn as_mut_slice(&mut self) -> &mut [T] {
2037        match self {
2038            #[cfg(target_os = "linux")]
2039            TensorMap::Dma(map) => map.as_mut_slice(),
2040            #[cfg(target_os = "macos")]
2041            TensorMap::IoSurface(map) => map.deref_mut(),
2042            #[cfg(unix)]
2043            TensorMap::Shm(map) => map.as_mut_slice(),
2044            TensorMap::Mem(map) => map.as_mut_slice(),
2045            TensorMap::Pbo(map) => map.as_mut_slice(),
2046        }
2047    }
2048}
2049
2050impl<T> Deref for TensorMap<T>
2051where
2052    T: Num + Clone + fmt::Debug,
2053{
2054    type Target = [T];
2055
2056    fn deref(&self) -> &[T] {
2057        match self {
2058            #[cfg(target_os = "linux")]
2059            TensorMap::Dma(map) => map.deref(),
2060            #[cfg(target_os = "macos")]
2061            TensorMap::IoSurface(map) => map.deref(),
2062            #[cfg(unix)]
2063            TensorMap::Shm(map) => map.deref(),
2064            TensorMap::Mem(map) => map.deref(),
2065            TensorMap::Pbo(map) => map.deref(),
2066        }
2067    }
2068}
2069
2070impl<T> DerefMut for TensorMap<T>
2071where
2072    T: Num + Clone + fmt::Debug,
2073{
2074    fn deref_mut(&mut self) -> &mut [T] {
2075        match self {
2076            #[cfg(target_os = "linux")]
2077            TensorMap::Dma(map) => map.deref_mut(),
2078            #[cfg(target_os = "macos")]
2079            TensorMap::IoSurface(map) => map.deref_mut(),
2080            #[cfg(unix)]
2081            TensorMap::Shm(map) => map.deref_mut(),
2082            TensorMap::Mem(map) => map.deref_mut(),
2083            TensorMap::Pbo(map) => map.deref_mut(),
2084        }
2085    }
2086}
2087
2088// ============================================================================
2089// Platform availability helpers
2090// ============================================================================
2091
2092/// Cached result of the Linux DMA-BUF availability probe.
2093#[cfg(target_os = "linux")]
2094static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2095/// Cached result of the macOS IOSurface availability probe.
2096#[cfg(target_os = "macos")]
2097static IOSURFACE_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2098
2099/// Check if Linux DMA-BUF allocation is available on this system.
2100///
2101/// Linux-specific availability check (typically requires `/dev/dma_heap`
2102/// access — running as root or membership in a video/render group). For
2103/// portable code that wants "any zero-copy GPU buffer", use
2104/// [`is_gpu_buffer_available`] which also covers IOSurface on macOS.
2105///
2106/// This function caches its result after the first call.
2107#[cfg(target_os = "linux")]
2108pub fn is_dma_available() -> bool {
2109    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
2110}
2111
2112/// Always returns `false` on non-Linux platforms.
2113#[cfg(not(target_os = "linux"))]
2114pub fn is_dma_available() -> bool {
2115    false
2116}
2117
2118/// Check if macOS IOSurface allocation is available on this system.
2119///
2120/// IOSurface is part of the macOS OS and is essentially always present;
2121/// this probe catches degraded scenarios such as memory pressure or
2122/// sandboxed contexts where `IOSurfaceCreate` fails. The result is
2123/// cached after the first call.
2124///
2125/// Always returns `false` on non-macOS platforms.
2126#[cfg(target_os = "macos")]
2127pub fn is_iosurface_available() -> bool {
2128    *IOSURFACE_AVAILABLE.get_or_init(|| {
2129        // Probe via the same Dma path — on macOS this routes through
2130        // IoSurfaceTensor::new.
2131        Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok()
2132    })
2133}
2134
2135#[cfg(not(target_os = "macos"))]
2136pub fn is_iosurface_available() -> bool {
2137    false
2138}
2139
2140/// Portable probe for the platform's native zero-copy GPU buffer
2141/// allocator (DMA-BUF on Linux, IOSurface on macOS). Returns `false` on
2142/// Windows and other platforms with no equivalent. Use this when writing
2143/// cross-platform code that cares whether the `Dma` tensor variant will
2144/// work, not which underlying mechanism is used.
2145pub fn is_gpu_buffer_available() -> bool {
2146    #[cfg(target_os = "linux")]
2147    {
2148        is_dma_available()
2149    }
2150    #[cfg(target_os = "macos")]
2151    {
2152        is_iosurface_available()
2153    }
2154    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
2155    {
2156        false
2157    }
2158}
2159
2160/// Check if POSIX shared memory allocation is available on this system.
2161///
2162/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
2163/// is supported. Always returns `false` on non-Unix platforms (Windows).
2164///
2165/// This function caches its result after the first call for efficiency.
2166#[cfg(unix)]
2167static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2168
2169/// Check if POSIX shared memory allocation is available on this system.
2170#[cfg(unix)]
2171pub fn is_shm_available() -> bool {
2172    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
2173}
2174
2175/// Check if POSIX shared memory allocation is available on this system.
2176///
2177/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
2178#[cfg(not(unix))]
2179pub fn is_shm_available() -> bool {
2180    false
2181}
2182
2183#[cfg(test)]
2184mod dtype_tests {
2185    use super::*;
2186
2187    #[test]
2188    fn dtype_size() {
2189        assert_eq!(DType::U8.size(), 1);
2190        assert_eq!(DType::I8.size(), 1);
2191        assert_eq!(DType::U16.size(), 2);
2192        assert_eq!(DType::I16.size(), 2);
2193        assert_eq!(DType::U32.size(), 4);
2194        assert_eq!(DType::I32.size(), 4);
2195        assert_eq!(DType::U64.size(), 8);
2196        assert_eq!(DType::I64.size(), 8);
2197        assert_eq!(DType::F16.size(), 2);
2198        assert_eq!(DType::F32.size(), 4);
2199        assert_eq!(DType::F64.size(), 8);
2200    }
2201
2202    #[test]
2203    fn dtype_name() {
2204        assert_eq!(DType::U8.name(), "u8");
2205        assert_eq!(DType::F16.name(), "f16");
2206        assert_eq!(DType::F32.name(), "f32");
2207    }
2208
2209    #[test]
2210    fn dtype_serde_roundtrip() {
2211        use serde_json;
2212        let dt = DType::F16;
2213        let json = serde_json::to_string(&dt).unwrap();
2214        let back: DType = serde_json::from_str(&json).unwrap();
2215        assert_eq!(dt, back);
2216    }
2217}
2218
2219#[cfg(test)]
2220mod image_tests {
2221    use super::*;
2222
2223    #[test]
2224    fn raw_tensor_has_no_format() {
2225        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
2226        assert!(t.format().is_none());
2227        assert!(t.width().is_none());
2228        assert!(t.height().is_none());
2229        assert!(!t.is_multiplane());
2230        assert!(t.chroma().is_none());
2231    }
2232
2233    #[test]
2234    fn image_tensor_packed() {
2235        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
2236        assert_eq!(t.format(), Some(PixelFormat::Rgba));
2237        assert_eq!(t.width(), Some(640));
2238        assert_eq!(t.height(), Some(480));
2239        assert_eq!(t.shape(), &[480, 640, 4]);
2240        assert!(!t.is_multiplane());
2241    }
2242
2243    #[test]
2244    fn image_tensor_planar() {
2245        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
2246        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
2247        assert_eq!(t.width(), Some(640));
2248        assert_eq!(t.height(), Some(480));
2249        assert_eq!(t.shape(), &[3, 480, 640]);
2250    }
2251
2252    #[test]
2253    fn image_tensor_semi_planar_contiguous() {
2254        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
2255        assert_eq!(t.format(), Some(PixelFormat::Nv12));
2256        assert_eq!(t.width(), Some(640));
2257        assert_eq!(t.height(), Some(480));
2258        // NV12: H*3/2 = 720
2259        assert_eq!(t.shape(), &[720, 640]);
2260        assert!(!t.is_multiplane());
2261    }
2262
2263    #[test]
2264    #[cfg(target_os = "linux")]
2265    fn image_tensor_with_stride_preserves_logical_width() {
2266        // Skip if DMA not available (e.g. sandboxed CI lacking dma_heap access).
2267        if !is_dma_available() {
2268            eprintln!("SKIPPED: DMA heap not available");
2269            return;
2270        }
2271        // 3004×1688 RGBA8: natural pitch 12016, padded to 12032 (64-aligned).
2272        let stride = 12032;
2273        let t = Tensor::<u8>::image_with_stride(
2274            3004,
2275            1688,
2276            PixelFormat::Rgba,
2277            stride,
2278            Some(TensorMemory::Dma),
2279        )
2280        .unwrap();
2281        // Logical dimensions unchanged by padding — this is the contract.
2282        assert_eq!(t.width(), Some(3004));
2283        assert_eq!(t.height(), Some(1688));
2284        assert_eq!(t.shape(), &[1688, 3004, 4]);
2285        // Stride is carried separately and reports the padded pitch.
2286        assert_eq!(t.effective_row_stride(), Some(stride));
2287        // Buffer is sized to stride × height so the full padded layout fits,
2288        // and CPU map() works for self-allocated strided DMA tensors.
2289        use crate::TensorMapTrait;
2290        {
2291            let map = t.map().unwrap();
2292            assert!(
2293                map.as_slice().len() >= stride * 1688,
2294                "mapped buffer {} bytes < expected {}",
2295                map.as_slice().len(),
2296                stride * 1688
2297            );
2298        }
2299        // CPU write access works too — iterate rows using the padded stride,
2300        // touch only the active `width × bpp` region, verify it round-trips.
2301        {
2302            let mut map = t.map().unwrap();
2303            let slice = map.as_mut_slice();
2304            for y in 0..1688 {
2305                let row_start = y * stride;
2306                for x in 0..3004 {
2307                    let p = row_start + x * 4;
2308                    slice[p] = (y & 0xFF) as u8;
2309                    slice[p + 1] = (x & 0xFF) as u8;
2310                    slice[p + 2] = 0x42;
2311                    slice[p + 3] = 0xFF;
2312                }
2313            }
2314        }
2315        {
2316            let map = t.map().unwrap();
2317            let slice = map.as_slice();
2318            // Sample a few pixels to confirm the round-trip.
2319            assert_eq!(slice[0], 0x00);
2320            assert_eq!(slice[1], 0x00);
2321            assert_eq!(slice[2], 0x42);
2322            assert_eq!(slice[3], 0xFF);
2323            let mid = 100 * stride + 50 * 4;
2324            assert_eq!(slice[mid], 100);
2325            assert_eq!(slice[mid + 1], 50);
2326            assert_eq!(slice[mid + 2], 0x42);
2327        }
2328    }
2329
2330    #[test]
2331    #[cfg(target_os = "linux")]
2332    fn image_tensor_with_stride_rejects_foreign_strided_map() {
2333        // A FOREIGN (imported via from_fd) DMA tensor with row_stride set
2334        // should still refuse CPU mapping — external allocator owns the
2335        // layout. This protects the V4L2 / GStreamer use case.
2336        //
2337        // We simulate a foreign import by wrapping our own allocation's
2338        // fd via `from_fd` and calling set_row_stride manually. The
2339        // `is_imported` flag on from_fd is true by construction.
2340        if !is_dma_available() {
2341            eprintln!("SKIPPED: DMA heap not available");
2342            return;
2343        }
2344        // Allocate a backing buffer large enough for a 320×240 BGRA8 image.
2345        let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
2346        let fd = backing.clone_fd().unwrap();
2347        // Import it via from_fd — this marks is_imported=true.
2348        let shape = [240usize, 320, 4];
2349        let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
2350        let mut t = Tensor::<u8>::wrap(storage);
2351        t.set_format(PixelFormat::Bgra).unwrap();
2352        t.set_row_stride(320 * 4).unwrap(); // natural, but still marks it as strided
2353        let err = t.map();
2354        assert!(
2355            matches!(err, Err(Error::InvalidOperation(_))),
2356            "foreign strided map should error"
2357        );
2358    }
2359
2360    #[test]
2361    #[cfg(target_os = "linux")]
2362    fn image_tensor_with_stride_map_rejects_tampered_stride() {
2363        // Round-3 PR feedback (C1): `set_row_stride` is public and only
2364        // validates `stride >= min_stride`, not that the new stride × height
2365        // fits the underlying buffer. A caller that tampers with the stride
2366        // after allocation must not be able to coerce `Tensor::map()` into
2367        // returning a slice larger than the backing mmap (that would be UB
2368        // in `DmaMap::as_slice`).
2369        if !is_dma_available() {
2370            eprintln!("SKIPPED: DMA heap not available");
2371            return;
2372        }
2373        // Allocate a 640×480 RGBA8 padded canvas (stride = 3072 = 768 px).
2374        // Backing buffer is 3072 × 480 = 1,474,560 bytes.
2375        let mut t = Tensor::<u8>::image_with_stride(
2376            640,
2377            480,
2378            PixelFormat::Rgba,
2379            3072,
2380            Some(TensorMemory::Dma),
2381        )
2382        .unwrap();
2383        // Tamper: push the stride up to 4 × the original. This is >=
2384        // min_stride (2560), so `set_row_stride` accepts it.
2385        t.set_row_stride(12288).unwrap();
2386        // Map must now refuse — 12288 × 480 = 5,898,240 > 1,474,560.
2387        let err = t.map();
2388        assert!(
2389            matches!(err, Err(Error::InvalidOperation(_))),
2390            "map() with oversized stride must return InvalidOperation"
2391        );
2392    }
2393
2394    #[test]
2395    fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
2396        // Round-3 PR feedback (C3): shape.product() * sizeof(T) must use
2397        // checked arithmetic so a pathological shape can't wrap usize and
2398        // make the byte_size-vs-logical-size comparison incorrect.
2399        //
2400        // This test only exercises the overflow rejection path, which is
2401        // pure-Rust and doesn't touch dma_heap — safe to run on any target.
2402        #[cfg(target_os = "linux")]
2403        {
2404            let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
2405                &[usize::MAX, 2, 2],
2406                usize::MAX,
2407                None,
2408            );
2409            assert!(
2410                matches!(err, Err(Error::InvalidArgument(_))),
2411                "new_with_byte_size must detect shape.product() overflow"
2412            );
2413        }
2414    }
2415
2416    #[test]
2417    #[cfg(target_os = "linux")]
2418    fn image_tensor_with_stride_rejects_too_small_stride() {
2419        // 640×480 RGBA8 natural pitch = 2560, request 2400 → should error.
2420        let err = Tensor::<u8>::image_with_stride(
2421            640,
2422            480,
2423            PixelFormat::Rgba,
2424            2400,
2425            Some(TensorMemory::Dma),
2426        );
2427        assert!(matches!(err, Err(Error::InvalidArgument(_))));
2428    }
2429
2430    #[test]
2431    #[cfg(target_os = "linux")]
2432    fn image_tensor_with_stride_rejects_non_packed() {
2433        // NV12 is SemiPlanar → not supported. (Linux-only because
2434        // `TensorMemory::Dma` itself is a Linux-only enum variant.)
2435        let err = Tensor::<u8>::image_with_stride(
2436            640,
2437            480,
2438            PixelFormat::Nv12,
2439            640,
2440            Some(TensorMemory::Dma),
2441        );
2442        assert!(matches!(err, Err(Error::NotImplemented(_))));
2443    }
2444
2445    #[test]
2446    fn set_format_valid() {
2447        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
2448        assert!(t.format().is_none());
2449        t.set_format(PixelFormat::Rgb).unwrap();
2450        assert_eq!(t.format(), Some(PixelFormat::Rgb));
2451        assert_eq!(t.width(), Some(640));
2452        assert_eq!(t.height(), Some(480));
2453    }
2454
2455    #[test]
2456    fn set_format_invalid_shape() {
2457        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
2458        // RGB expects 3 channels, not 4
2459        let err = t.set_format(PixelFormat::Rgb);
2460        assert!(err.is_err());
2461        // Original tensor is unmodified
2462        assert!(t.format().is_none());
2463    }
2464
2465    #[test]
2466    fn reshape_clears_format() {
2467        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
2468        assert_eq!(t.format(), Some(PixelFormat::Rgba));
2469        // Reshape to flat — format cleared
2470        t.reshape(&[480 * 640 * 4]).unwrap();
2471        assert!(t.format().is_none());
2472    }
2473
2474    #[test]
2475    fn from_planes_nv12() {
2476        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2477        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2478        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
2479        assert_eq!(img.format(), Some(PixelFormat::Nv12));
2480        assert!(img.is_multiplane());
2481        assert!(img.chroma().is_some());
2482        assert_eq!(img.width(), Some(640));
2483        assert_eq!(img.height(), Some(480));
2484    }
2485
2486    #[test]
2487    fn from_planes_rejects_non_semiplanar() {
2488        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2489        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2490        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
2491        assert!(err.is_err());
2492    }
2493
2494    #[test]
2495    fn reshape_multiplane_errors() {
2496        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2497        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2498        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
2499        let err = img.reshape(&[480 * 640 + 240 * 640]);
2500        assert!(err.is_err());
2501    }
2502}
2503
2504#[cfg(test)]
2505mod tests {
2506    #[cfg(target_os = "linux")]
2507    use nix::unistd::{access, AccessFlags};
2508    #[cfg(target_os = "linux")]
2509    use std::io::Write as _;
2510    use std::sync::RwLock;
2511
2512    use super::*;
2513
2514    #[ctor::ctor]
2515    fn init() {
2516        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
2517    }
2518
2519    /// Macro to get the current function name for logging in tests.
2520    #[cfg(target_os = "linux")]
2521    macro_rules! function {
2522        () => {{
2523            fn f() {}
2524            fn type_name_of<T>(_: T) -> &'static str {
2525                std::any::type_name::<T>()
2526            }
2527            let name = type_name_of(f);
2528
2529            // Find and cut the rest of the path
2530            match &name[..name.len() - 3].rfind(':') {
2531                Some(pos) => &name[pos + 1..name.len() - 3],
2532                None => &name[..name.len() - 3],
2533            }
2534        }};
2535    }
2536
2537    #[test]
2538    #[cfg(target_os = "linux")]
2539    fn test_tensor() {
2540        let _lock = FD_LOCK.read().unwrap();
2541        let shape = vec![1];
2542        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
2543        let dma_enabled = tensor.is_ok();
2544
2545        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2546        match dma_enabled {
2547            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
2548            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
2549        }
2550    }
2551
2552    #[test]
2553    #[cfg(target_os = "macos")]
2554    fn test_tensor() {
2555        let shape = vec![1];
2556        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2557        // macOS auto-fallback chain: IOSurface (Dma) → SHM → Mem.
2558        // Healthy systems always return Dma; SHM/Mem only appear under
2559        // memory pressure or sandboxed contexts where IOSurfaceCreate
2560        // fails.
2561        let m = tensor.memory();
2562        assert!(
2563            matches!(m, TensorMemory::Dma | TensorMemory::Shm | TensorMemory::Mem),
2564            "Unexpected auto-fallback result on macOS: {m:?}"
2565        );
2566    }
2567
2568    #[test]
2569    #[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))]
2570    fn test_tensor() {
2571        let shape = vec![1];
2572        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2573        // Other Unix (BSD): auto-detection tries SHM first, falls back to Mem.
2574        assert!(
2575            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
2576            "Expected SHM or Mem, got {:?}",
2577            tensor.memory()
2578        );
2579    }
2580
2581    #[test]
2582    #[cfg(not(unix))]
2583    fn test_tensor() {
2584        let shape = vec![1];
2585        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2586        assert_eq!(tensor.memory(), TensorMemory::Mem);
2587    }
2588
2589    #[test]
2590    #[cfg(target_os = "linux")]
2591    fn test_dma_tensor() {
2592        let _lock = FD_LOCK.read().unwrap();
2593        match access(
2594            "/dev/dma_heap/linux,cma",
2595            AccessFlags::R_OK | AccessFlags::W_OK,
2596        ) {
2597            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
2598            Err(_) => match access(
2599                "/dev/dma_heap/system",
2600                AccessFlags::R_OK | AccessFlags::W_OK,
2601            ) {
2602                Ok(_) => println!("/dev/dma_heap/system is available"),
2603                Err(e) => {
2604                    writeln!(
2605                        &mut std::io::stdout(),
2606                        "[WARNING] DMA Heap is unavailable: {e}"
2607                    )
2608                    .unwrap();
2609                    return;
2610                }
2611            },
2612        }
2613
2614        let shape = vec![2, 3, 4];
2615        let tensor =
2616            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2617
2618        const DUMMY_VALUE: f32 = 12.34;
2619
2620        assert_eq!(tensor.memory(), TensorMemory::Dma);
2621        assert_eq!(tensor.name(), "test_tensor");
2622        assert_eq!(tensor.shape(), &shape);
2623        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2624        assert_eq!(tensor.len(), 2 * 3 * 4);
2625
2626        {
2627            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2628            tensor_map.fill(42.0);
2629            assert!(tensor_map.iter().all(|&x| x == 42.0));
2630        }
2631
2632        {
2633            let shared = Tensor::<f32>::from_fd(
2634                tensor
2635                    .clone_fd()
2636                    .expect("Failed to duplicate tensor file descriptor"),
2637                &shape,
2638                Some("test_tensor_shared"),
2639            )
2640            .expect("Failed to create tensor from fd");
2641
2642            assert_eq!(shared.memory(), TensorMemory::Dma);
2643            assert_eq!(shared.name(), "test_tensor_shared");
2644            assert_eq!(shared.shape(), &shape);
2645
2646            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
2647            tensor_map.fill(DUMMY_VALUE);
2648            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2649        }
2650
2651        {
2652            let tensor_map = tensor.map().expect("Failed to map DMA memory");
2653            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2654        }
2655
2656        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2657        assert_eq!(tensor.shape(), &shape);
2658        let new_shape = vec![3, 4, 4];
2659        assert!(
2660            tensor.reshape(&new_shape).is_err(),
2661            "Reshape should fail due to size mismatch"
2662        );
2663        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2664
2665        let new_shape = vec![2, 3, 4];
2666        tensor.reshape(&new_shape).expect("Reshape should succeed");
2667        assert_eq!(
2668            tensor.shape(),
2669            &new_shape,
2670            "Shape should be updated after successful reshape"
2671        );
2672
2673        {
2674            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2675            tensor_map.fill(1);
2676            assert!(tensor_map.iter().all(|&x| x == 1));
2677        }
2678
2679        {
2680            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2681            tensor_map[2] = 42;
2682            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2683            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2684        }
2685    }
2686
2687    #[test]
2688    #[cfg(unix)]
2689    fn test_shm_tensor() {
2690        let _lock = FD_LOCK.read().unwrap();
2691        let shape = vec![2, 3, 4];
2692        let tensor =
2693            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2694        assert_eq!(tensor.shape(), &shape);
2695        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2696        assert_eq!(tensor.name(), "test_tensor");
2697
2698        const DUMMY_VALUE: f32 = 12.34;
2699        {
2700            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2701            tensor_map.fill(42.0);
2702            assert!(tensor_map.iter().all(|&x| x == 42.0));
2703        }
2704
2705        {
2706            let shared = Tensor::<f32>::from_fd(
2707                tensor
2708                    .clone_fd()
2709                    .expect("Failed to duplicate tensor file descriptor"),
2710                &shape,
2711                Some("test_tensor_shared"),
2712            )
2713            .expect("Failed to create tensor from fd");
2714
2715            assert_eq!(shared.memory(), TensorMemory::Shm);
2716            assert_eq!(shared.name(), "test_tensor_shared");
2717            assert_eq!(shared.shape(), &shape);
2718
2719            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2720            tensor_map.fill(DUMMY_VALUE);
2721            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2722        }
2723
2724        {
2725            let tensor_map = tensor.map().expect("Failed to map shared memory");
2726            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2727        }
2728
2729        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2730        assert_eq!(tensor.shape(), &shape);
2731        let new_shape = vec![3, 4, 4];
2732        assert!(
2733            tensor.reshape(&new_shape).is_err(),
2734            "Reshape should fail due to size mismatch"
2735        );
2736        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2737
2738        let new_shape = vec![2, 3, 4];
2739        tensor.reshape(&new_shape).expect("Reshape should succeed");
2740        assert_eq!(
2741            tensor.shape(),
2742            &new_shape,
2743            "Shape should be updated after successful reshape"
2744        );
2745
2746        {
2747            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2748            tensor_map.fill(1);
2749            assert!(tensor_map.iter().all(|&x| x == 1));
2750        }
2751
2752        {
2753            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2754            tensor_map[2] = 42;
2755            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2756            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2757        }
2758    }
2759
2760    #[test]
2761    fn test_mem_tensor() {
2762        let shape = vec![2, 3, 4];
2763        let tensor =
2764            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2765        assert_eq!(tensor.shape(), &shape);
2766        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2767        assert_eq!(tensor.name(), "test_tensor");
2768
2769        {
2770            let mut tensor_map = tensor.map().expect("Failed to map memory");
2771            tensor_map.fill(42.0);
2772            assert!(tensor_map.iter().all(|&x| x == 42.0));
2773        }
2774
2775        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2776        assert_eq!(tensor.shape(), &shape);
2777        let new_shape = vec![3, 4, 4];
2778        assert!(
2779            tensor.reshape(&new_shape).is_err(),
2780            "Reshape should fail due to size mismatch"
2781        );
2782        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2783
2784        let new_shape = vec![2, 3, 4];
2785        tensor.reshape(&new_shape).expect("Reshape should succeed");
2786        assert_eq!(
2787            tensor.shape(),
2788            &new_shape,
2789            "Shape should be updated after successful reshape"
2790        );
2791
2792        {
2793            let mut tensor_map = tensor.map().expect("Failed to map memory");
2794            tensor_map.fill(1);
2795            assert!(tensor_map.iter().all(|&x| x == 1));
2796        }
2797
2798        {
2799            let mut tensor_map = tensor.map().expect("Failed to map memory");
2800            tensor_map[2] = 42;
2801            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2802            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2803        }
2804    }
2805
2806    #[test]
2807    #[cfg(target_os = "linux")]
2808    fn test_dma_no_fd_leaks() {
2809        let _lock = FD_LOCK.write().unwrap();
2810        if !is_dma_available() {
2811            log::warn!(
2812                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2813                function!()
2814            );
2815            return;
2816        }
2817
2818        let proc = procfs::process::Process::myself()
2819            .expect("Failed to get current process using /proc/self");
2820
2821        let start_open_fds = proc
2822            .fd_count()
2823            .expect("Failed to get open file descriptor count");
2824
2825        for _ in 0..100 {
2826            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
2827                .expect("Failed to create tensor");
2828            let mut map = tensor.map().unwrap();
2829            map.as_mut_slice().fill(233);
2830        }
2831
2832        let end_open_fds = proc
2833            .fd_count()
2834            .expect("Failed to get open file descriptor count");
2835
2836        assert_eq!(
2837            start_open_fds, end_open_fds,
2838            "File descriptor leak detected: {} -> {}",
2839            start_open_fds, end_open_fds
2840        );
2841    }
2842
2843    #[test]
2844    #[cfg(target_os = "linux")]
2845    fn test_dma_from_fd_no_fd_leaks() {
2846        let _lock = FD_LOCK.write().unwrap();
2847        if !is_dma_available() {
2848            log::warn!(
2849                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2850                function!()
2851            );
2852            return;
2853        }
2854
2855        let proc = procfs::process::Process::myself()
2856            .expect("Failed to get current process using /proc/self");
2857
2858        let start_open_fds = proc
2859            .fd_count()
2860            .expect("Failed to get open file descriptor count");
2861
2862        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
2863
2864        for _ in 0..100 {
2865            let tensor =
2866                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2867            let mut map = tensor.map().unwrap();
2868            map.as_mut_slice().fill(233);
2869        }
2870        drop(orig);
2871
2872        let end_open_fds = proc.fd_count().unwrap();
2873
2874        assert_eq!(
2875            start_open_fds, end_open_fds,
2876            "File descriptor leak detected: {} -> {}",
2877            start_open_fds, end_open_fds
2878        );
2879    }
2880
2881    #[test]
2882    #[cfg(target_os = "linux")]
2883    fn test_shm_no_fd_leaks() {
2884        let _lock = FD_LOCK.write().unwrap();
2885        if !is_shm_available() {
2886            log::warn!(
2887                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2888                function!()
2889            );
2890            return;
2891        }
2892
2893        let proc = procfs::process::Process::myself()
2894            .expect("Failed to get current process using /proc/self");
2895
2896        let start_open_fds = proc
2897            .fd_count()
2898            .expect("Failed to get open file descriptor count");
2899
2900        for _ in 0..100 {
2901            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2902                .expect("Failed to create tensor");
2903            let mut map = tensor.map().unwrap();
2904            map.as_mut_slice().fill(233);
2905        }
2906
2907        let end_open_fds = proc
2908            .fd_count()
2909            .expect("Failed to get open file descriptor count");
2910
2911        assert_eq!(
2912            start_open_fds, end_open_fds,
2913            "File descriptor leak detected: {} -> {}",
2914            start_open_fds, end_open_fds
2915        );
2916    }
2917
2918    #[test]
2919    #[cfg(target_os = "linux")]
2920    fn test_shm_from_fd_no_fd_leaks() {
2921        let _lock = FD_LOCK.write().unwrap();
2922        if !is_shm_available() {
2923            log::warn!(
2924                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2925                function!()
2926            );
2927            return;
2928        }
2929
2930        let proc = procfs::process::Process::myself()
2931            .expect("Failed to get current process using /proc/self");
2932
2933        let start_open_fds = proc
2934            .fd_count()
2935            .expect("Failed to get open file descriptor count");
2936
2937        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
2938
2939        for _ in 0..100 {
2940            let tensor =
2941                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2942            let mut map = tensor.map().unwrap();
2943            map.as_mut_slice().fill(233);
2944        }
2945        drop(orig);
2946
2947        let end_open_fds = proc.fd_count().unwrap();
2948
2949        assert_eq!(
2950            start_open_fds, end_open_fds,
2951            "File descriptor leak detected: {} -> {}",
2952            start_open_fds, end_open_fds
2953        );
2954    }
2955
2956    #[cfg(feature = "ndarray")]
2957    #[test]
2958    fn test_ndarray() {
2959        let _lock = FD_LOCK.read().unwrap();
2960        let shape = vec![2, 3, 4];
2961        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2962
2963        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
2964        tensor_map.fill(1.0);
2965
2966        let view = tensor_map.view().expect("Failed to get ndarray view");
2967        assert_eq!(view.shape(), &[2, 3, 4]);
2968        assert!(view.iter().all(|&x| x == 1.0));
2969
2970        let mut view_mut = tensor_map
2971            .view_mut()
2972            .expect("Failed to get mutable ndarray view");
2973        view_mut[[0, 0, 0]] = 42.0;
2974        assert_eq!(view_mut[[0, 0, 0]], 42.0);
2975        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
2976    }
2977
2978    #[test]
2979    fn test_buffer_identity_unique() {
2980        let id1 = BufferIdentity::new();
2981        let id2 = BufferIdentity::new();
2982        assert_ne!(
2983            id1.id(),
2984            id2.id(),
2985            "Two identities should have different ids"
2986        );
2987    }
2988
2989    #[test]
2990    fn test_buffer_identity_clone_shares_guard() {
2991        let id1 = BufferIdentity::new();
2992        let weak = id1.weak();
2993        assert!(
2994            weak.upgrade().is_some(),
2995            "Weak should be alive while original exists"
2996        );
2997
2998        let id2 = id1.clone();
2999        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
3000
3001        drop(id1);
3002        assert!(
3003            weak.upgrade().is_some(),
3004            "Weak should still be alive (clone holds Arc)"
3005        );
3006
3007        drop(id2);
3008        assert!(
3009            weak.upgrade().is_none(),
3010            "Weak should be dead after all clones dropped"
3011        );
3012    }
3013
3014    #[test]
3015    fn test_tensor_buffer_identity() {
3016        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
3017        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
3018        assert_ne!(
3019            t1.buffer_identity().id(),
3020            t2.buffer_identity().id(),
3021            "Different tensors should have different buffer ids"
3022        );
3023    }
3024
3025    // ------------------------------------------------------------------------
3026    // Quantization — constructor validation + accessor correctness.
3027    // ------------------------------------------------------------------------
3028
3029    #[test]
3030    fn test_quantization_per_tensor_constructors() {
3031        let q = Quantization::per_tensor(0.1, -5);
3032        assert!(q.is_per_tensor());
3033        assert!(!q.is_per_channel());
3034        assert!(!q.is_symmetric());
3035        assert_eq!(q.scale(), &[0.1]);
3036        assert_eq!(q.zero_point(), Some(&[-5][..]));
3037
3038        let qs = Quantization::per_tensor_symmetric(0.05);
3039        assert!(qs.is_per_tensor());
3040        assert!(qs.is_symmetric());
3041        assert_eq!(qs.zero_point(), None);
3042    }
3043
3044    #[test]
3045    fn test_quantization_per_channel_constructors() {
3046        let q = Quantization::per_channel(vec![0.1, 0.2, 0.3], vec![0, -1, 1], 2).unwrap();
3047        assert!(q.is_per_channel());
3048        assert!(!q.is_symmetric());
3049        assert_eq!(q.axis(), Some(2));
3050        assert_eq!(q.scale().len(), 3);
3051
3052        let qs = Quantization::per_channel_symmetric(vec![0.054, 0.089, 0.195], 0).unwrap();
3053        assert!(qs.is_per_channel());
3054        assert!(qs.is_symmetric());
3055        assert_eq!(qs.axis(), Some(0));
3056    }
3057
3058    #[test]
3059    fn test_quantization_per_channel_length_mismatch_rejected() {
3060        // len(scales) != len(zero_points) → rejected at construction.
3061        let err = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0).unwrap_err();
3062        assert!(matches!(err, Error::QuantizationInvalid { .. }));
3063    }
3064
3065    #[test]
3066    fn test_quantization_per_channel_empty_rejected() {
3067        let err = Quantization::per_channel_symmetric(vec![], 0).unwrap_err();
3068        assert!(matches!(err, Error::QuantizationInvalid { .. }));
3069    }
3070
3071    /// Constructors guard scale/zero_point length invariants, but
3072    /// `Quantization` is `Deserialize`, so malformed JSON (e.g. an
3073    /// empty `scale` array, or `zero_point` length that disagrees with
3074    /// `scale`) bypasses the constructor checks. `set_quantization`
3075    /// must reject these via `validate()` so they don't poison
3076    /// downstream `mode()` selection or per-channel kernel indexing.
3077    #[test]
3078    fn test_quantization_validate_rejects_malformed_deserialize() {
3079        let mut t = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();
3080
3081        // Empty scale array: must be rejected.
3082        let q: Quantization = serde_json::from_str(r#"{"scale": []}"#).unwrap();
3083        assert!(matches!(
3084            t.set_quantization(q).unwrap_err(),
3085            Error::QuantizationInvalid { .. }
3086        ));
3087
3088        // Per-tensor with multi-element zero_point: must be rejected.
3089        let q: Quantization =
3090            serde_json::from_str(r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#).unwrap();
3091        assert!(matches!(
3092            t.set_quantization(q).unwrap_err(),
3093            Error::QuantizationInvalid { .. }
3094        ));
3095
3096        // Per-channel zero_point length != scale length: must be rejected.
3097        let q: Quantization = serde_json::from_str(
3098            r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
3099        )
3100        .unwrap();
3101        assert!(matches!(
3102            t.set_quantization(q).unwrap_err(),
3103            Error::QuantizationInvalid { .. }
3104        ));
3105    }
3106
3107    #[test]
3108    fn test_quantization_mode_dispatch() {
3109        let pt = Quantization::per_tensor(0.1, -5);
3110        assert!(matches!(
3111            pt.mode(),
3112            QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
3113        ));
3114
3115        let pts = Quantization::per_tensor_symmetric(0.05);
3116        assert!(matches!(
3117            pts.mode(),
3118            QuantMode::PerTensorSymmetric { scale } if scale == 0.05
3119        ));
3120
3121        let pc = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
3122        assert!(matches!(pc.mode(), QuantMode::PerChannel { axis: 2, .. }));
3123
3124        let pcs = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
3125        assert!(matches!(
3126            pcs.mode(),
3127            QuantMode::PerChannelSymmetric { axis: 0, .. }
3128        ));
3129    }
3130
3131    #[test]
3132    fn test_tensor_quantization_roundtrip_integer() {
3133        let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
3134        assert!(t.quantization().is_none());
3135        t.set_quantization(Quantization::per_tensor(0.1, -5))
3136            .unwrap();
3137        let q = t.quantization().unwrap();
3138        assert_eq!(q.scale(), &[0.1]);
3139        t.clear_quantization();
3140        assert!(t.quantization().is_none());
3141    }
3142
3143    #[test]
3144    fn test_tensor_with_quantization_builder() {
3145        let t = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None)
3146            .unwrap()
3147            .with_quantization(Quantization::per_tensor_symmetric(0.05))
3148            .unwrap();
3149        assert!(t.quantization().is_some());
3150    }
3151
3152    #[test]
3153    fn test_tensor_dyn_quantization_float_arm_returns_none() {
3154        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
3155        let td = TensorDyn::F32(t);
3156        assert!(td.quantization().is_none());
3157    }
3158
3159    #[test]
3160    fn test_tensor_dyn_set_quantization_float_arm_errors() {
3161        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
3162        let mut td = TensorDyn::F32(t);
3163        let err = td
3164            .set_quantization(Quantization::per_tensor(0.1, 0))
3165            .unwrap_err();
3166        // float path returns a QuantizationInvalid error.
3167        assert!(matches!(err, Error::QuantizationInvalid { .. }));
3168    }
3169
3170    /// Compile-time type gate — calling `Tensor::<f32>::quantization()` must
3171    /// fail to compile (the `IntegerType` trait bound is not satisfied by
3172    /// `f32`). This doctest anchors the invariant.
3173    ///
3174    /// ```compile_fail
3175    /// use edgefirst_tensor::{Tensor, TensorMemory};
3176    /// let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
3177    /// let _ = t.quantization(); // compile error: f32 not IntegerType
3178    /// ```
3179    fn _compile_fail_doctest_anchor() {}
3180
3181    // Any test that cares about the fd count must grab it exclusively.
3182    // Any tests which modifies the fd count by opening or closing fds must grab it
3183    // shared.
3184    pub static FD_LOCK: RwLock<()> = RwLock::new(());
3185
3186    /// Test that DMA is NOT available on non-Linux platforms.
3187    /// This verifies the cross-platform behavior of is_dma_available().
3188    #[test]
3189    #[cfg(not(target_os = "linux"))]
3190    fn test_dma_not_available_on_non_linux() {
3191        assert!(
3192            !is_dma_available(),
3193            "DMA memory allocation should NOT be available on non-Linux platforms"
3194        );
3195    }
3196
3197    /// Test that SHM memory allocation is available and usable on Unix systems.
3198    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
3199    #[test]
3200    #[cfg(unix)]
3201    fn test_shm_available_and_usable() {
3202        assert!(
3203            is_shm_available(),
3204            "SHM memory allocation should be available on Unix systems"
3205        );
3206
3207        // Create a tensor with SHM backing
3208        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
3209            .expect("Failed to create SHM tensor");
3210
3211        // Verify we can map and write to it
3212        let mut map = tensor.map().expect("Failed to map SHM tensor");
3213        map.as_mut_slice().fill(0xAB);
3214
3215        // Verify the data was written correctly
3216        assert!(
3217            map.as_slice().iter().all(|&b| b == 0xAB),
3218            "SHM tensor data should be writable and readable"
3219        );
3220    }
3221}