Skip to main content

edgefirst_tensor/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
21
22## Overview
23The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
24which define the behavior of Tensors and their memory mappings, respectively.
25The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
26while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
27wraps `Tensor<T>` for runtime element-type dispatch.
28 */
29#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54    fmt,
55    ops::{Deref, DerefMut},
56    sync::{
57        atomic::{AtomicU64, Ordering},
58        Arc, Weak,
59    },
60};
61pub use tensor_dyn::TensorDyn;
62
63/// Per-plane DMA-BUF descriptor for external buffer import.
64///
65/// Owns a duplicated file descriptor plus optional stride and offset metadata.
66/// The fd is duplicated eagerly in [`new()`](Self::new) so that a bad fd is
67/// caught immediately. `import_image` consumes the descriptor and takes
68/// ownership of the duped fd — no further cleanup is needed by the caller.
69///
70/// # Examples
71///
72/// ```rust,no_run
73/// use edgefirst_tensor::PlaneDescriptor;
74/// use std::os::fd::BorrowedFd;
75///
76/// // SAFETY: fd 42 is hypothetical; real code must pass a valid fd.
77/// let pd = unsafe { PlaneDescriptor::new(BorrowedFd::borrow_raw(42)) }
78///     .unwrap()
79///     .with_stride(2048)
80///     .with_offset(0);
81/// ```
82#[cfg(unix)]
83pub struct PlaneDescriptor {
84    fd: OwnedFd,
85    stride: Option<usize>,
86    offset: Option<usize>,
87}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91    /// Create a new plane descriptor by duplicating the given file descriptor.
92    ///
93    /// The fd is duped immediately — a bad fd fails here rather than inside
94    /// `import_image`. The caller retains ownership of the original fd.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the `dup()` syscall fails (e.g. invalid fd or
99    /// fd limit reached).
100    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101        let owned = fd.try_clone_to_owned()?;
102        Ok(Self {
103            fd: owned,
104            stride: None,
105            offset: None,
106        })
107    }
108
109    /// Set the row stride in bytes (consuming builder).
110    pub fn with_stride(mut self, stride: usize) -> Self {
111        self.stride = Some(stride);
112        self
113    }
114
115    /// Set the plane offset in bytes (consuming builder).
116    pub fn with_offset(mut self, offset: usize) -> Self {
117        self.offset = Some(offset);
118        self
119    }
120
121    /// Consume the descriptor and return the owned file descriptor.
122    pub fn into_fd(self) -> OwnedFd {
123        self.fd
124    }
125
126    /// Row stride in bytes, if set.
127    pub fn stride(&self) -> Option<usize> {
128        self.stride
129    }
130
131    /// Plane offset in bytes, if set.
132    pub fn offset(&self) -> Option<usize> {
133        self.offset
134    }
135}
136
137/// Element type discriminant for runtime type identification.
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
139#[repr(u8)]
140#[non_exhaustive]
141pub enum DType {
142    U8,
143    I8,
144    U16,
145    I16,
146    U32,
147    I32,
148    U64,
149    I64,
150    F16,
151    F32,
152    F64,
153}
154
155impl DType {
156    /// Size of one element in bytes.
157    pub const fn size(&self) -> usize {
158        match self {
159            Self::U8 | Self::I8 => 1,
160            Self::U16 | Self::I16 | Self::F16 => 2,
161            Self::U32 | Self::I32 | Self::F32 => 4,
162            Self::U64 | Self::I64 | Self::F64 => 8,
163        }
164    }
165
166    /// Short type name (e.g., "u8", "f32", "f16").
167    pub const fn name(&self) -> &'static str {
168        match self {
169            Self::U8 => "u8",
170            Self::I8 => "i8",
171            Self::U16 => "u16",
172            Self::I16 => "i16",
173            Self::U32 => "u32",
174            Self::I32 => "i32",
175            Self::U64 => "u64",
176            Self::I64 => "i64",
177            Self::F16 => "f16",
178            Self::F32 => "f32",
179            Self::F64 => "f64",
180        }
181    }
182}
183
184impl fmt::Display for DType {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        f.write_str(self.name())
187    }
188}
189
190// =============================================================================
191// Quantization metadata — type-gated to integer element types via sealed
192// `IntegerType` trait. Accessors on `Tensor<T>` only compile when `T` is
193// an integer type; calling them on `Tensor<f32>` / `Tensor<f16>` etc. is a
194// compile error, not a runtime one.
195// =============================================================================
196
mod sealed {
    /// Private supertrait of [`IntegerType`](super::IntegerType). This module
    /// is not exported, so no other crate can name `Sealed` — and therefore
    /// no other crate can implement `IntegerType`.
    pub trait Sealed {}
}

/// Integer element types that may carry quantization metadata.
///
/// Sealed trait: implemented for `u8`, `i8`, `u16`, `i16`, `u32`, `i32`,
/// `u64`, `i64`. Cannot be implemented downstream. Float element types
/// (`half::f16`, `f32`, `f64`) are explicitly excluded — quantization
/// metadata does not apply to float tensors per the edgefirst.json spec.
pub trait IntegerType: sealed::Sealed {}

// Implements both the private `Sealed` marker and the public `IntegerType`
// marker for every listed type. Float types (f16 / f32 / f64) are
// deliberately NOT listed.
macro_rules! impl_integer_type {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl sealed::Sealed for $ty {}
            impl IntegerType for $ty {}
        )+
    };
}

impl_integer_type!(u8, i8, u16, i16, u32, i32, u64, i64);
225
226/// Quantization parameters for an integer tensor.
227///
228/// Covers all four modes the edgefirst.json spec defines:
229///
230/// | Mode | `scale.len()` | `zero_point` | `axis` |
231/// |---|---|---|---|
232/// | Per-tensor symmetric | 1 | `None` | `None` |
233/// | Per-tensor asymmetric | 1 | `Some(len == 1)` | `None` |
234/// | Per-channel symmetric | >1 | `None` | `Some(c)` |
235/// | Per-channel asymmetric | >1 | `Some(len == scale.len())` | `Some(c)` |
236///
237/// The quantized storage type is carried on the parent [`Tensor<T>`]; this
238/// struct does not duplicate it. Construct via the four named constructors
239/// (the only public entry points); direct field mutation is not allowed so
240/// invalid combinations cannot be represented.
241///
242/// Dequantization formula:
243///
244/// ```text
245///   real_value = scale[c] × (quantized_value[c] - zero_point[c])
246/// ```
247///
248/// where `c` is the channel index (always `0` for per-tensor).
249#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
250pub struct Quantization {
251    /// Per-tensor: `vec![scale]`. Per-channel: `vec![scale_0, scale_1, ...]`.
252    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
253    scale: Vec<f32>,
254
255    /// `None` means symmetric (zero-point is 0). `Some(vec)` must have the
256    /// same length as `scale`.
257    #[serde(
258        default,
259        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
260        skip_serializing_if = "Option::is_none"
261    )]
262    zero_point: Option<Vec<i32>>,
263
264    /// Channel axis for per-channel quantization. `Some(_)` iff
265    /// `scale.len() > 1`. Validated against the parent tensor's shape at
266    /// `set_quantization()` time.
267    #[serde(default, skip_serializing_if = "Option::is_none")]
268    axis: Option<usize>,
269}
270
/// Borrowing, mode-tagged view of a [`Quantization`] for kernel dispatch.
///
/// Obtain it with [`Quantization::mode`] once at kernel entry, never inside a
/// pixel-level loop. Per-channel variants borrow the scale / zero-point
/// arrays as slices, so producing the view allocates nothing.
#[derive(Debug, Clone, Copy)]
pub enum QuantMode<'a> {
    /// One scale for the whole tensor; zero-point is implicitly 0.
    PerTensorSymmetric {
        /// Scale applied to every element.
        scale: f32,
    },
    /// One scale and one zero-point for the whole tensor.
    PerTensor {
        /// Scale applied to every element.
        scale: f32,
        /// Zero-point subtracted from every element before scaling.
        zero_point: i32,
    },
    /// One scale per channel along `axis`; zero-points are implicitly 0.
    PerChannelSymmetric {
        /// Scale for each channel, indexed by channel.
        scales: &'a [f32],
        /// Dimension of the tensor that indexes channels.
        axis: usize,
    },
    /// One scale and one zero-point per channel along `axis`.
    PerChannel {
        /// Scale for each channel, indexed by channel.
        scales: &'a [f32],
        /// Zero-point for each channel, indexed by channel.
        zero_points: &'a [i32],
        /// Dimension of the tensor that indexes channels.
        axis: usize,
    },
}
295
296impl Quantization {
297    /// Per-tensor symmetric (zero_point = 0).
298    pub fn per_tensor_symmetric(scale: f32) -> Self {
299        Self {
300            scale: vec![scale],
301            zero_point: None,
302            axis: None,
303        }
304    }
305
306    /// Per-tensor asymmetric — the most common runtime shape.
307    pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
308        Self {
309            scale: vec![scale],
310            zero_point: Some(vec![zero_point]),
311            axis: None,
312        }
313    }
314
315    /// Per-channel symmetric. Errors on empty `scales`.
316    pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
317        if scales.is_empty() {
318            return Err(Error::QuantizationInvalid {
319                field: "scale.len",
320                expected: "non-empty per-channel scales".to_string(),
321                got: "length 0".to_string(),
322            });
323        }
324        Ok(Self {
325            scale: scales,
326            zero_point: None,
327            axis: Some(axis),
328        })
329    }
330
331    /// Per-channel asymmetric. Errors on length mismatch between `scales`
332    /// and `zero_points`, or empty arrays.
333    pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
334        if scales.is_empty() {
335            return Err(Error::QuantizationInvalid {
336                field: "scale.len",
337                expected: "non-empty per-channel scales".to_string(),
338                got: "length 0".to_string(),
339            });
340        }
341        if scales.len() != zero_points.len() {
342            return Err(Error::QuantizationInvalid {
343                field: "zero_point.len",
344                expected: format!("length matches scale ({})", scales.len()),
345                got: format!("length {}", zero_points.len()),
346            });
347        }
348        Ok(Self {
349            scale: scales,
350            zero_point: Some(zero_points),
351            axis: Some(axis),
352        })
353    }
354
355    /// Borrow-based dispatch view. Match once at kernel entry.
356    pub fn mode(&self) -> QuantMode<'_> {
357        match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
358            (1, None, _) => QuantMode::PerTensorSymmetric {
359                scale: self.scale[0],
360            },
361            (1, Some(zps), _) => QuantMode::PerTensor {
362                scale: self.scale[0],
363                zero_point: zps.first().copied().unwrap_or(0),
364            },
365            (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
366                scales: &self.scale,
367                axis,
368            },
369            (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
370                scales: &self.scale,
371                zero_points: zps,
372                axis,
373            },
374            // The `validate()` path prevents constructing a
375            // per-channel Quantization without an axis, so the remaining
376            // pattern is unreachable in practice. Fall back to
377            // per-tensor symmetric using scale[0] to avoid panicking in
378            // release; debug builds assert.
379            _ => {
380                debug_assert!(
381                    false,
382                    "Quantization::mode: per-channel without axis is unreachable"
383                );
384                QuantMode::PerTensorSymmetric {
385                    scale: self.scale.first().copied().unwrap_or(1.0),
386                }
387            }
388        }
389    }
390
391    /// Returns `true` for per-tensor quantization (`scale.len() == 1`).
392    pub fn is_per_tensor(&self) -> bool {
393        self.scale.len() == 1
394    }
395
396    /// Returns `true` for per-channel quantization (`scale.len() > 1`).
397    pub fn is_per_channel(&self) -> bool {
398        self.scale.len() > 1
399    }
400
401    /// Returns `true` for symmetric quantization (no zero-point, or
402    /// zero-point vector of all zeros).
403    pub fn is_symmetric(&self) -> bool {
404        match &self.zero_point {
405            None => true,
406            Some(zps) => zps.iter().all(|&z| z == 0),
407        }
408    }
409
410    /// Borrow the scale array. Length 1 for per-tensor; `num_channels` for
411    /// per-channel.
412    pub fn scale(&self) -> &[f32] {
413        &self.scale
414    }
415
416    /// Borrow the zero-point array. `None` for symmetric.
417    pub fn zero_point(&self) -> Option<&[i32]> {
418        self.zero_point.as_deref()
419    }
420
421    /// Channel axis for per-channel quantization. `None` for per-tensor.
422    pub fn axis(&self) -> Option<usize> {
423        self.axis
424    }
425
426    /// Validate against a target tensor shape. Runs in
427    /// `Tensor::set_quantization()`. Catches:
428    ///   - empty `scale` (reject — must declare at least one factor)
429    ///   - `zero_point` length inconsistent with `scale` (reject —
430    ///     per-tensor must have len 1, per-channel must match `scale.len`)
431    ///   - `axis >= shape.len()` (axis out of range)
432    ///   - `scale.len() != shape[axis]` for per-channel
433    ///   - per-channel without axis (reject)
434    ///   - per-tensor with redundant axis (reject)
435    pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
436        // `Quantization` is `Deserialize`, so malformed JSON like
437        // `{"scale": [], "zero_point": []}` could otherwise produce an
438        // ill-defined value that confuses `mode()` selection and the
439        // per-channel kernels' indexing.
440        if self.scale.is_empty() {
441            return Err(Error::QuantizationInvalid {
442                field: "scale.len",
443                expected: ">= 1".to_string(),
444                got: "0".to_string(),
445            });
446        }
447        if let Some(zps) = self.zero_point.as_ref() {
448            // Per-tensor: scale.len() == 1 and zero_point.len() must == 1.
449            // Per-channel: zero_point.len() must == scale.len().
450            let expected = if self.scale.len() == 1 {
451                1
452            } else {
453                self.scale.len()
454            };
455            if zps.len() != expected {
456                return Err(Error::QuantizationInvalid {
457                    field: "zero_point.len",
458                    expected: format!(
459                        "{expected} (matching {})",
460                        if self.scale.len() == 1 {
461                            "per-tensor scale"
462                        } else {
463                            "per-channel scale.len"
464                        }
465                    ),
466                    got: format!("length {}", zps.len()),
467                });
468            }
469        }
470
471        match (self.scale.len(), self.axis) {
472            (1, None) => Ok(()),
473            (1, Some(_)) => Err(Error::QuantizationInvalid {
474                field: "per_tensor_redundant_axis",
475                expected: "axis=None for per-tensor quantization".to_string(),
476                got: format!("axis={:?}", self.axis),
477            }),
478            (_, None) => Err(Error::QuantizationInvalid {
479                field: "per_channel_requires_axis",
480                expected: format!(
481                    "axis=Some(_) for per-channel quantization (scale.len={})",
482                    self.scale.len()
483                ),
484                got: "axis=None".to_string(),
485            }),
486            (n, Some(axis)) => {
487                if axis >= shape.len() {
488                    return Err(Error::QuantizationInvalid {
489                        field: "axis",
490                        expected: format!("axis < tensor rank ({})", shape.len()),
491                        got: format!("axis={axis}"),
492                    });
493                }
494                if shape[axis] != n {
495                    return Err(Error::QuantizationInvalid {
496                        field: "scale.len",
497                        expected: format!("length matches shape[{axis}] ({})", shape[axis]),
498                        got: format!("length {n}"),
499                    });
500                }
501                Ok(())
502            }
503        }
504    }
505}
506
507impl From<(f32, i32)> for Quantization {
508    /// Convenience construction from a `(scale, zero_point)` tuple. Matches
509    /// the legacy `QuantTuple` / `Quantization::new` calling convention so
510    /// existing `(0.1, -128).into()` sites keep working.
511    fn from((scale, zero_point): (f32, i32)) -> Self {
512        Self::per_tensor(scale, zero_point)
513    }
514}
515
516fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
517    de: D,
518) -> std::result::Result<Vec<f32>, D::Error> {
519    use serde::de::{self, Visitor};
520    struct V;
521    impl<'de> Visitor<'de> for V {
522        type Value = Vec<f32>;
523        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
524            f.write_str("f32 or array of f32")
525        }
526        fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
527            Ok(vec![v as f32])
528        }
529        #[allow(clippy::cast_possible_truncation)]
530        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
531            Ok(vec![v as f32])
532        }
533        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
534        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
535            Ok(vec![v as f32])
536        }
537        fn visit_seq<A: de::SeqAccess<'de>>(
538            self,
539            mut seq: A,
540        ) -> std::result::Result<Self::Value, A::Error> {
541            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
542            while let Some(x) = seq.next_element::<f32>()? {
543                out.push(x);
544            }
545            Ok(out)
546        }
547    }
548    de.deserialize_any(V)
549}
550
551fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
552    de: D,
553) -> std::result::Result<Option<Vec<i32>>, D::Error> {
554    use serde::de::{self, Visitor};
555    struct V;
556    impl<'de> Visitor<'de> for V {
557        type Value = Option<Vec<i32>>;
558        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
559            f.write_str("null, i32, or array of i32")
560        }
561        fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
562            Ok(None)
563        }
564        fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
565            Ok(None)
566        }
567        fn visit_some<D2: serde::Deserializer<'de>>(
568            self,
569            de: D2,
570        ) -> std::result::Result<Self::Value, D2::Error> {
571            struct Inner;
572            impl<'de> Visitor<'de> for Inner {
573                type Value = Vec<i32>;
574                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
575                    f.write_str("i32 or array of i32")
576                }
577                #[allow(clippy::cast_possible_truncation)]
578                fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
579                    Ok(vec![v as i32])
580                }
581                #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
582                fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
583                    Ok(vec![v as i32])
584                }
585                fn visit_seq<A: de::SeqAccess<'de>>(
586                    self,
587                    mut seq: A,
588                ) -> std::result::Result<Self::Value, A::Error> {
589                    let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
590                    while let Some(x) = seq.next_element::<i32>()? {
591                        out.push(x);
592                    }
593                    Ok(out)
594                }
595            }
596            de.deserialize_any(Inner).map(Some)
597        }
598        #[allow(clippy::cast_possible_truncation)]
599        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
600            Ok(Some(vec![v as i32]))
601        }
602        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
603        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
604            Ok(Some(vec![v as i32]))
605        }
606        fn visit_seq<A: de::SeqAccess<'de>>(
607            self,
608            mut seq: A,
609        ) -> std::result::Result<Self::Value, A::Error> {
610            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
611            while let Some(x) = seq.next_element::<i32>()? {
612                out.push(x);
613            }
614            Ok(Some(out))
615        }
616    }
617    de.deserialize_option(V)
618}
619
/// Process-wide source of buffer identity IDs. Starts at 1; each call to
/// [`BufferIdentity::new`] takes the current value and advances it by one.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Unique identity for a tensor's underlying buffer.
///
/// A fresh identity is created on every buffer allocation or import. The
/// `id` is a monotonic `u64` cache key. The `guard` is a shared `Arc<()>`
/// token; its weak handles let downstream caches notice when the buffer has
/// been dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    /// Value drawn from `NEXT_BUFFER_ID`; clones carry the same value.
    id: u64,
    /// Liveness token. Clones share it, so it dies with the last clone.
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Allocate an identity with the next counter value and a fresh
    /// liveness token.
    pub fn new() -> Self {
        // Relaxed suffices: callers only need distinct values, not ordering
        // relative to other memory operations.
        let id = NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed);
        BufferIdentity {
            id,
            guard: Arc::new(()),
        }
    }

    /// Cache key for this buffer. Changes whenever the buffer changes.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Weak handle to the liveness token. It can no longer be upgraded once
    /// the owning tensor and every clone of this identity have been dropped.
    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}

impl Default for BufferIdentity {
    /// Same as [`BufferIdentity::new`].
    fn default() -> Self {
        BufferIdentity::new()
    }
}
660
661#[cfg(target_os = "linux")]
662use nix::sys::stat::{major, minor};
663
664pub trait TensorTrait<T>: Send + Sync
665where
666    T: Num + Clone + fmt::Debug,
667{
668    /// Create a new tensor with the given shape and optional name. If no name
669    /// is given, a random name will be generated.
670    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
671    where
672        Self: Sized;
673
674    #[cfg(unix)]
675    /// Create a new tensor using the given file descriptor, shape, and optional
676    /// name. If no name is given, a random name will be generated.
677    ///
678    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
679    /// On other Unix (macOS): Always creates SHM tensor.
680    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
681    where
682        Self: Sized;
683
684    #[cfg(unix)]
685    /// Clone the file descriptor associated with this tensor.
686    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
687
688    /// Get the memory type of this tensor.
689    fn memory(&self) -> TensorMemory;
690
691    /// Get the name of this tensor.
692    fn name(&self) -> String;
693
694    /// Get the number of elements in this tensor.
695    fn len(&self) -> usize {
696        self.shape().iter().product()
697    }
698
699    /// Check if the tensor is empty.
700    fn is_empty(&self) -> bool {
701        self.len() == 0
702    }
703
704    /// Get the size in bytes of this tensor.
705    fn size(&self) -> usize {
706        self.len() * std::mem::size_of::<T>()
707    }
708
709    /// Get the shape of this tensor.
710    fn shape(&self) -> &[usize];
711
712    /// Reshape this tensor to the given shape. The total number of elements
713    /// must remain the same.
714    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
715
716    /// Map the tensor into memory and return a TensorMap for accessing the
717    /// data.
718    fn map(&self) -> Result<TensorMap<T>>;
719
720    /// Get the buffer identity for cache keying and liveness tracking.
721    fn buffer_identity(&self) -> &BufferIdentity;
722}
723
724pub trait TensorMapTrait<T>
725where
726    T: Num + Clone + fmt::Debug,
727{
728    /// Get the shape of this tensor map.
729    fn shape(&self) -> &[usize];
730
731    /// Unmap the tensor from memory.
732    fn unmap(&mut self);
733
734    /// Get the number of elements in this tensor map.
735    fn len(&self) -> usize {
736        self.shape().iter().product()
737    }
738
739    /// Check if the tensor map is empty.
740    fn is_empty(&self) -> bool {
741        self.len() == 0
742    }
743
744    /// Get the size in bytes of this tensor map.
745    fn size(&self) -> usize {
746        self.len() * std::mem::size_of::<T>()
747    }
748
749    /// Get a slice to the data in this tensor map.
750    fn as_slice(&self) -> &[T];
751
752    /// Get a mutable slice to the data in this tensor map.
753    fn as_mut_slice(&mut self) -> &mut [T];
754
755    #[cfg(feature = "ndarray")]
756    /// Get an ndarray ArrayView of the tensor data.
757    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
758        Ok(ndarray::ArrayView::from_shape(
759            self.shape(),
760            self.as_slice(),
761        )?)
762    }
763
764    #[cfg(feature = "ndarray")]
765    /// Get an ndarray ArrayViewMut of the tensor data.
766    fn view_mut(
767        &'_ mut self,
768    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
769        let shape = self.shape().to_vec();
770        Ok(ndarray::ArrayViewMut::from_shape(
771            shape,
772            self.as_mut_slice(),
773        )?)
774    }
775}
776
/// Kind of memory backing a tensor's buffer.
///
/// `Dma` is only compiled on Linux and `Shm` only on Unix targets; `Mem` and
/// `Pbo` are available everywhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Direct Memory Access (DMA) allocation. Incurs additional
    /// overhead for memory reading/writing with the CPU.  Allows for
    /// hardware acceleration when supported.
    #[cfg(target_os = "linux")]
    Dma,

    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    #[cfg(unix)]
    Shm,

    /// Regular system memory allocation
    Mem,

    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
    /// when DMA-buf is unavailable but OpenGL is present.
    Pbo,
}
796
797impl From<TensorMemory> for String {
798    fn from(memory: TensorMemory) -> Self {
799        match memory {
800            #[cfg(target_os = "linux")]
801            TensorMemory::Dma => "dma".to_owned(),
802            #[cfg(unix)]
803            TensorMemory::Shm => "shm".to_owned(),
804            TensorMemory::Mem => "mem".to_owned(),
805            TensorMemory::Pbo => "pbo".to_owned(),
806        }
807    }
808}
809
810impl TryFrom<&str> for TensorMemory {
811    type Error = Error;
812
813    fn try_from(s: &str) -> Result<Self> {
814        match s {
815            #[cfg(target_os = "linux")]
816            "dma" => Ok(TensorMemory::Dma),
817            #[cfg(unix)]
818            "shm" => Ok(TensorMemory::Shm),
819            "mem" => Ok(TensorMemory::Mem),
820            "pbo" => Ok(TensorMemory::Pbo),
821            _ => Err(Error::InvalidMemoryType(s.to_owned())),
822        }
823    }
824}
825
826#[derive(Debug)]
827#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
828pub(crate) enum TensorStorage<T>
829where
830    T: Num + Clone + fmt::Debug + Send + Sync,
831{
832    #[cfg(target_os = "linux")]
833    Dma(DmaTensor<T>),
834    #[cfg(unix)]
835    Shm(ShmTensor<T>),
836    Mem(MemTensor<T>),
837    Pbo(PboTensor<T>),
838}
839
840impl<T> TensorStorage<T>
841where
842    T: Num + Clone + fmt::Debug + Send + Sync,
843{
844    /// Create a new tensor storage with the given shape, memory type, and
845    /// optional name. If no name is given, a random name will be generated.
846    /// If no memory type is given, the best available memory type will be
847    /// chosen based on the platform and environment variables.
848    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
849        match memory {
850            #[cfg(target_os = "linux")]
851            Some(TensorMemory::Dma) => {
852                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
853            }
854            #[cfg(unix)]
855            Some(TensorMemory::Shm) => {
856                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
857            }
858            Some(TensorMemory::Mem) => {
859                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
860            }
861            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
862                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
863            )),
864            None => {
865                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
866                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
867                {
868                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
869                } else {
870                    #[cfg(target_os = "linux")]
871                    {
872                        // Linux: Try DMA -> SHM -> Mem
873                        match DmaTensor::<T>::new(shape, name) {
874                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
875                            Err(_) => {
876                                match ShmTensor::<T>::new(shape, name)
877                                    .map(TensorStorage::Shm)
878                                {
879                                    Ok(tensor) => Ok(tensor),
880                                    Err(_) => MemTensor::<T>::new(shape, name)
881                                        .map(TensorStorage::Mem),
882                                }
883                            }
884                        }
885                    }
886                    #[cfg(all(unix, not(target_os = "linux")))]
887                    {
888                        // macOS/BSD: Try SHM -> Mem (no DMA)
889                        match ShmTensor::<T>::new(shape, name) {
890                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
891                            Err(_) => {
892                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
893                            }
894                        }
895                    }
896                    #[cfg(not(unix))]
897                    {
898                        // Windows/other: Mem only
899                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
900                    }
901                }
902            }
903        }
904    }
905
906    /// Create a DMA-backed tensor storage with an explicit byte size that
907    /// may exceed `shape.product() * sizeof(T)`. Used for image tensors
908    /// with row-padded layouts (see `DmaTensor::new_with_byte_size`).
909    ///
910    /// This is intentionally DMA-only: padding is only meaningful for
911    /// buffers that will be imported as GPU textures via EGLImage. PBO,
912    /// Shm, and Mem storage doesn't benefit from pitch alignment and
913    /// shouldn't pay the memory cost.
914    #[cfg(target_os = "linux")]
915    pub(crate) fn new_dma_with_byte_size(
916        shape: &[usize],
917        byte_size: usize,
918        name: Option<&str>,
919    ) -> Result<Self> {
920        DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
921    }
922
923    // No non-Linux stub: the only caller (`Tensor::image_with_stride`)
924    // returns `NotImplemented` directly on non-Linux without ever
925    // reaching the storage layer, so defining a stub here would be
926    // dead code and fail the `-D warnings` clippy gate on macOS CI.
927
928    /// Create a new tensor storage using the given file descriptor, shape,
929    /// and optional name.
930    #[cfg(unix)]
931    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
932        #[cfg(target_os = "linux")]
933        {
934            use nix::sys::stat::fstat;
935
936            let stat = fstat(&fd)?;
937            let major = major(stat.st_dev);
938            let minor = minor(stat.st_dev);
939
940            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
941
942            if major != 0 {
943                // Dma and Shm tensors are expected to have major number 0
944                return Err(Error::UnknownDeviceType(major, minor));
945            }
946
947            match minor {
948                9 | 10 => {
949                    // minor number 9 & 10 indicates DMA memory
950                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
951                }
952                _ => {
953                    // other minor numbers are assumed to be shared memory
954                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
955                }
956            }
957        }
958        #[cfg(all(unix, not(target_os = "linux")))]
959        {
960            // On macOS/BSD, always use SHM (no DMA support)
961            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
962        }
963    }
964}
965
966impl<T> TensorTrait<T> for TensorStorage<T>
967where
968    T: Num + Clone + fmt::Debug + Send + Sync,
969{
970    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
971        Self::new(shape, None, name)
972    }
973
974    #[cfg(unix)]
975    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
976        Self::from_fd(fd, shape, name)
977    }
978
979    #[cfg(unix)]
980    fn clone_fd(&self) -> Result<OwnedFd> {
981        match self {
982            #[cfg(target_os = "linux")]
983            TensorStorage::Dma(t) => t.clone_fd(),
984            TensorStorage::Shm(t) => t.clone_fd(),
985            TensorStorage::Mem(t) => t.clone_fd(),
986            TensorStorage::Pbo(t) => t.clone_fd(),
987        }
988    }
989
990    fn memory(&self) -> TensorMemory {
991        match self {
992            #[cfg(target_os = "linux")]
993            TensorStorage::Dma(_) => TensorMemory::Dma,
994            #[cfg(unix)]
995            TensorStorage::Shm(_) => TensorMemory::Shm,
996            TensorStorage::Mem(_) => TensorMemory::Mem,
997            TensorStorage::Pbo(_) => TensorMemory::Pbo,
998        }
999    }
1000
1001    fn name(&self) -> String {
1002        match self {
1003            #[cfg(target_os = "linux")]
1004            TensorStorage::Dma(t) => t.name(),
1005            #[cfg(unix)]
1006            TensorStorage::Shm(t) => t.name(),
1007            TensorStorage::Mem(t) => t.name(),
1008            TensorStorage::Pbo(t) => t.name(),
1009        }
1010    }
1011
1012    fn shape(&self) -> &[usize] {
1013        match self {
1014            #[cfg(target_os = "linux")]
1015            TensorStorage::Dma(t) => t.shape(),
1016            #[cfg(unix)]
1017            TensorStorage::Shm(t) => t.shape(),
1018            TensorStorage::Mem(t) => t.shape(),
1019            TensorStorage::Pbo(t) => t.shape(),
1020        }
1021    }
1022
1023    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1024        match self {
1025            #[cfg(target_os = "linux")]
1026            TensorStorage::Dma(t) => t.reshape(shape),
1027            #[cfg(unix)]
1028            TensorStorage::Shm(t) => t.reshape(shape),
1029            TensorStorage::Mem(t) => t.reshape(shape),
1030            TensorStorage::Pbo(t) => t.reshape(shape),
1031        }
1032    }
1033
1034    fn map(&self) -> Result<TensorMap<T>> {
1035        match self {
1036            #[cfg(target_os = "linux")]
1037            TensorStorage::Dma(t) => t.map(),
1038            #[cfg(unix)]
1039            TensorStorage::Shm(t) => t.map(),
1040            TensorStorage::Mem(t) => t.map(),
1041            TensorStorage::Pbo(t) => t.map(),
1042        }
1043    }
1044
1045    fn buffer_identity(&self) -> &BufferIdentity {
1046        match self {
1047            #[cfg(target_os = "linux")]
1048            TensorStorage::Dma(t) => t.buffer_identity(),
1049            #[cfg(unix)]
1050            TensorStorage::Shm(t) => t.buffer_identity(),
1051            TensorStorage::Mem(t) => t.buffer_identity(),
1052            TensorStorage::Pbo(t) => t.buffer_identity(),
1053        }
1054    }
1055}
1056
1057/// Multi-backend tensor with optional image format metadata.
1058///
1059/// When `format` is `Some`, this tensor represents an image. Width, height,
1060/// and channels are derived from `shape` + `format`. When `format` is `None`,
1061/// this is a raw tensor (identical to the pre-refactoring behavior).
1062#[derive(Debug)]
1063pub struct Tensor<T>
1064where
1065    T: Num + Clone + fmt::Debug + Send + Sync,
1066{
1067    pub(crate) storage: TensorStorage<T>,
1068    format: Option<PixelFormat>,
1069    chroma: Option<Box<Tensor<T>>>,
1070    /// Row stride in bytes for externally allocated buffers with row padding.
1071    /// `None` means tightly packed (stride == width * bytes_per_pixel).
1072    row_stride: Option<usize>,
1073    /// Byte offset within the DMA-BUF where image data starts.
1074    /// `None` means offset 0 (data starts at the beginning of the buffer).
1075    plane_offset: Option<usize>,
1076    /// Quantization metadata for integer-typed tensors. Public access is
1077    /// gated by the `IntegerType` trait — `Tensor<f32>` etc. carry the
1078    /// field for layout uniformity but have no way to read or write it.
1079    pub(crate) quantization: Option<Quantization>,
1080}
1081
1082impl<T> Tensor<T>
1083where
1084    T: Num + Clone + fmt::Debug + Send + Sync,
1085{
1086    /// Wrap a TensorStorage in a Tensor with no image metadata.
1087    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
1088        Self {
1089            storage,
1090            format: None,
1091            chroma: None,
1092            row_stride: None,
1093            plane_offset: None,
1094            quantization: None,
1095        }
1096    }
1097
1098    /// Construct a tensor from a row-major element slice + shape. Allocates a
1099    /// new buffer (`TensorMemory::Mem`) and memcpys the contents; caller
1100    /// retains ownership of the input slice.
1101    ///
1102    /// # Errors
1103    ///
1104    /// - [`Error::InvalidShape`] if `values.len() != shape.iter().product()`.
1105    /// - Propagates any allocation error from [`Self::new`].
1106    pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
1107    where
1108        T: Copy,
1109    {
1110        let expected: usize = shape.iter().product();
1111        if values.len() != expected {
1112            return Err(Error::InvalidShape(format!(
1113                "from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
1114                values.len()
1115            )));
1116        }
1117        let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
1118        {
1119            let mut m = t.map()?;
1120            m.as_mut_slice().copy_from_slice(values);
1121        }
1122        Ok(t)
1123    }
1124
1125    /// Construct a tensor from a 3-D ndarray view. Respects strides — one
1126    /// copy in all cases; contiguous views take a memcpy fast path.
1127    ///
1128    /// Only available when the `ndarray` feature is enabled.
1129    #[cfg(feature = "ndarray")]
1130    pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
1131    where
1132        T: Copy,
1133    {
1134        let (h, w, c) = view.dim();
1135        let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
1136        {
1137            let mut m = t.map()?;
1138            let dst = m.as_mut_slice();
1139            if let Some(src) = view.as_slice() {
1140                dst.copy_from_slice(src);
1141            } else {
1142                for (d, &s) in dst.iter_mut().zip(view.iter()) {
1143                    *d = s;
1144                }
1145            }
1146        }
1147        Ok(t)
1148    }
1149
1150    /// Create a new tensor with the given shape, memory type, and optional
1151    /// name. If no name is given, a random name will be generated. If no
1152    /// memory type is given, the best available memory type will be chosen
1153    /// based on the platform and environment variables.
1154    ///
1155    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
1156    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
1157    /// On non-Unix platforms, only Mem is available.
1158    ///
1159    /// # Environment Variables
1160    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
1161    ///   value, forces the use of regular system memory allocation
1162    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
1163    ///
1164    /// # Example
1165    /// ```rust
1166    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
1167    /// # fn main() -> Result<(), Error> {
1168    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
1169    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
1170    /// assert_eq!(tensor.name(), "test_tensor");
1171    /// #    Ok(())
1172    /// # }
1173    /// ```
1174    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
1175        TensorStorage::new(shape, memory, name).map(Self::wrap)
1176    }
1177
1178    /// Create an image tensor with the given format.
1179    pub fn image(
1180        width: usize,
1181        height: usize,
1182        format: PixelFormat,
1183        memory: Option<TensorMemory>,
1184    ) -> Result<Self> {
1185        let shape = match format.layout() {
1186            PixelLayout::Packed => vec![height, width, format.channels()],
1187            PixelLayout::Planar => vec![format.channels(), height, width],
1188            PixelLayout::SemiPlanar => {
1189                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
1190                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
1191                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
1192                let total_h = match format {
1193                    PixelFormat::Nv12 => {
1194                        if !height.is_multiple_of(2) {
1195                            return Err(Error::InvalidArgument(format!(
1196                                "NV12 requires even height, got {height}"
1197                            )));
1198                        }
1199                        height * 3 / 2
1200                    }
1201                    PixelFormat::Nv16 => height * 2,
1202                    _ => {
1203                        return Err(Error::InvalidArgument(format!(
1204                            "unknown semi-planar height multiplier for {format:?}"
1205                        )))
1206                    }
1207                };
1208                vec![total_h, width]
1209            }
1210        };
1211        let mut t = Self::new(&shape, memory, None)?;
1212        t.format = Some(format);
1213        Ok(t)
1214    }
1215
1216    /// Create a DMA-backed image tensor with an explicit row stride that
1217    /// may exceed the natural `width * channels * sizeof(T)` pitch.
1218    ///
1219    /// Used for image tensors that need GPU pitch alignment padding: the
1220    /// underlying DMA-BUF is sized to `row_stride * height` bytes, but
1221    /// the tensor's logical shape stays at `[height, width, channels]`.
1222    /// `width()` / `height()` / `shape()` continue to report the
1223    /// user-requested values; the padding is visible only via
1224    /// `row_stride()` / `effective_row_stride()` and is automatically
1225    /// propagated to the GL backend's EGLImage import so Mali Valhall
1226    /// accepts the buffer.
1227    ///
1228    /// # Supported formats
1229    ///
1230    /// Currently only **packed** pixel layouts (RGBA8, BGRA8, RGB888,
1231    /// Grey, etc.) are supported — the formats the GL backend uses as
1232    /// render destinations. Semi-planar formats (NV12, NV16) come from
1233    /// external allocators (camera capture, video decoders) and are
1234    /// imported via `TensorDyn::from_fd` + `set_row_stride`, which
1235    /// already supports padded strides.
1236    ///
1237    /// # Supported memory
1238    ///
1239    /// Currently only `TensorMemory::Dma` is supported. PBO and Mem
1240    /// storage don't go through EGLImage import so they don't need
1241    /// pitch alignment; if you pass any other memory type this returns
1242    /// `NotImplemented`. `None` (auto-select) is treated as `Dma`.
1243    ///
1244    /// # Errors
1245    ///
1246    /// - `InvalidArgument` if `row_stride_bytes < width * channels * sizeof(T)`
1247    ///   (the requested stride would not fit a single row)
1248    /// - `NotImplemented` for non-packed formats or non-DMA memory
1249    /// - `IoError` if the DMA-heap allocation fails (propagated from
1250    ///   `DmaTensor::new_with_byte_size`)
1251    pub fn image_with_stride(
1252        width: usize,
1253        height: usize,
1254        format: PixelFormat,
1255        row_stride_bytes: usize,
1256        memory: Option<TensorMemory>,
1257    ) -> Result<Self> {
1258        // DMA backing (the only thing this constructor produces) is
1259        // Linux-only. On macOS/BSD/Windows the non-Linux block below is
1260        // the only compiled body and returns `NotImplemented` directly;
1261        // on Linux the non-Linux block is cfg-removed and the function
1262        // falls through to the real validation + allocation path. Each
1263        // target compiles exactly one of the two blocks, and the block
1264        // serves as the function's tail expression in both cases — so
1265        // neither needs an explicit `return` (avoids
1266        // `clippy::needless_return` on the macOS CI gate).
1267        #[cfg(not(target_os = "linux"))]
1268        {
1269            let _ = (width, height, format, row_stride_bytes, memory);
1270            Err(Error::NotImplemented(
1271                "image_with_stride requires DMA support (Linux only)".to_owned(),
1272            ))
1273        }
1274
1275        #[cfg(target_os = "linux")]
1276        {
1277            if format.layout() != PixelLayout::Packed {
1278                return Err(Error::NotImplemented(format!(
1279                    "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
1280                )));
1281            }
1282            let elem = std::mem::size_of::<T>();
1283            let min_stride = width
1284                .checked_mul(format.channels())
1285                .and_then(|p| p.checked_mul(elem))
1286                .ok_or_else(|| {
1287                    Error::InvalidArgument(format!(
1288                        "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
1289                         overflows usize",
1290                        format.channels()
1291                    ))
1292                })?;
1293            if row_stride_bytes < min_stride {
1294                return Err(Error::InvalidArgument(format!(
1295                    "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
1296                     ({width} px × {} ch × {elem} B)",
1297                    format.channels()
1298                )));
1299            }
1300            let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
1301                Error::InvalidArgument(format!(
1302                    "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
1303                ))
1304            })?;
1305
1306            let shape = vec![height, width, format.channels()];
1307
1308            let storage = match memory {
1309                Some(TensorMemory::Dma) | None => {
1310                    TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
1311                }
1312                Some(other) => {
1313                    return Err(Error::NotImplemented(format!(
1314                        "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
1315                    )));
1316                }
1317            };
1318
1319            let mut t = Self::wrap(storage);
1320            t.format = Some(format);
1321            t.row_stride = Some(row_stride_bytes);
1322            Ok(t)
1323        }
1324    }
1325
1326    /// Attach format metadata to an existing tensor.
1327    ///
1328    /// # Arguments
1329    ///
1330    /// * `format` - The pixel format to attach
1331    ///
1332    /// # Returns
1333    ///
1334    /// `Ok(())` on success, with the format stored as metadata on the tensor.
1335    ///
1336    /// # Errors
1337    ///
1338    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
1339    /// the format's layout (packed expects `[H, W, C]`, planar expects
1340    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
1341    /// height constraints).
1342    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
1343        let shape = self.shape();
1344        match format.layout() {
1345            PixelLayout::Packed => {
1346                if shape.len() != 3 || shape[2] != format.channels() {
1347                    return Err(Error::InvalidShape(format!(
1348                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
1349                        format.channels()
1350                    )));
1351                }
1352            }
1353            PixelLayout::Planar => {
1354                if shape.len() != 3 || shape[0] != format.channels() {
1355                    return Err(Error::InvalidShape(format!(
1356                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
1357                        format.channels()
1358                    )));
1359                }
1360            }
1361            PixelLayout::SemiPlanar => {
1362                if shape.len() != 2 {
1363                    return Err(Error::InvalidShape(format!(
1364                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
1365                    )));
1366                }
1367                match format {
1368                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
1369                        return Err(Error::InvalidShape(format!(
1370                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
1371                            shape[0]
1372                        )));
1373                    }
1374                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
1375                        return Err(Error::InvalidShape(format!(
1376                            "NV16 contiguous shape[0] must be even, got {}",
1377                            shape[0]
1378                        )));
1379                    }
1380                    _ => {}
1381                }
1382            }
1383        }
1384        // Clear stored stride/offset when format changes — they may be invalid
1385        // for the new format. Caller must re-set after changing format.
1386        if self.format != Some(format) {
1387            self.row_stride = None;
1388            self.plane_offset = None;
1389            #[cfg(target_os = "linux")]
1390            if let TensorStorage::Dma(ref mut dma) = self.storage {
1391                dma.mmap_offset = 0;
1392            }
1393        }
1394        self.format = Some(format);
1395        Ok(())
1396    }
1397
1398    /// Pixel format (None if not an image).
1399    pub fn format(&self) -> Option<PixelFormat> {
1400        self.format
1401    }
1402
1403    /// Image width (None if not an image).
1404    pub fn width(&self) -> Option<usize> {
1405        let fmt = self.format?;
1406        let shape = self.shape();
1407        match fmt.layout() {
1408            PixelLayout::Packed => Some(shape[1]),
1409            PixelLayout::Planar => Some(shape[2]),
1410            PixelLayout::SemiPlanar => Some(shape[1]),
1411        }
1412    }
1413
1414    /// Image height (None if not an image).
1415    pub fn height(&self) -> Option<usize> {
1416        let fmt = self.format?;
1417        let shape = self.shape();
1418        match fmt.layout() {
1419            PixelLayout::Packed => Some(shape[0]),
1420            PixelLayout::Planar => Some(shape[1]),
1421            PixelLayout::SemiPlanar => {
1422                if self.is_multiplane() {
1423                    Some(shape[0])
1424                } else {
1425                    match fmt {
1426                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
1427                        PixelFormat::Nv16 => Some(shape[0] / 2),
1428                        _ => None,
1429                    }
1430                }
1431            }
1432        }
1433    }
1434
1435    /// Create from separate Y and UV planes (multiplane NV12/NV16).
1436    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
1437        if format.layout() != PixelLayout::SemiPlanar {
1438            return Err(Error::InvalidArgument(format!(
1439                "from_planes requires a semi-planar format, got {format:?}"
1440            )));
1441        }
1442        if chroma.format.is_some() || chroma.chroma.is_some() {
1443            return Err(Error::InvalidArgument(
1444                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
1445            ));
1446        }
1447        let luma_shape = luma.shape();
1448        let chroma_shape = chroma.shape();
1449        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
1450            return Err(Error::InvalidArgument(format!(
1451                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
1452            )));
1453        }
1454        if luma_shape[1] != chroma_shape[1] {
1455            return Err(Error::InvalidArgument(format!(
1456                "luma width {} != chroma width {}",
1457                luma_shape[1], chroma_shape[1]
1458            )));
1459        }
1460        match format {
1461            PixelFormat::Nv12 => {
1462                if luma_shape[0] % 2 != 0 {
1463                    return Err(Error::InvalidArgument(format!(
1464                        "NV12 requires even luma height, got {}",
1465                        luma_shape[0]
1466                    )));
1467                }
1468                if chroma_shape[0] != luma_shape[0] / 2 {
1469                    return Err(Error::InvalidArgument(format!(
1470                        "NV12 chroma height {} != luma height / 2 ({})",
1471                        chroma_shape[0],
1472                        luma_shape[0] / 2
1473                    )));
1474                }
1475            }
1476            PixelFormat::Nv16 => {
1477                if chroma_shape[0] != luma_shape[0] {
1478                    return Err(Error::InvalidArgument(format!(
1479                        "NV16 chroma height {} != luma height {}",
1480                        chroma_shape[0], luma_shape[0]
1481                    )));
1482                }
1483            }
1484            _ => {
1485                return Err(Error::InvalidArgument(format!(
1486                    "from_planes only supports NV12 and NV16, got {format:?}"
1487                )));
1488            }
1489        }
1490
1491        Ok(Tensor {
1492            storage: luma.storage,
1493            format: Some(format),
1494            chroma: Some(Box::new(chroma)),
1495            row_stride: luma.row_stride,
1496            plane_offset: luma.plane_offset,
1497            quantization: luma.quantization,
1498        })
1499    }
1500
1501    /// Whether this tensor uses separate plane allocations.
1502    pub fn is_multiplane(&self) -> bool {
1503        self.chroma.is_some()
1504    }
1505
1506    /// Access the chroma plane for multiplane semi-planar images.
1507    pub fn chroma(&self) -> Option<&Tensor<T>> {
1508        self.chroma.as_deref()
1509    }
1510
1511    /// Mutable access to the chroma plane for multiplane semi-planar images.
1512    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
1513        self.chroma.as_deref_mut()
1514    }
1515
1516    /// Row stride in bytes (`None` = tightly packed).
1517    pub fn row_stride(&self) -> Option<usize> {
1518        self.row_stride
1519    }
1520
1521    /// Effective row stride in bytes: the stored stride if set, otherwise the
1522    /// minimum stride computed from the format, width, and element size.
1523    /// Returns `None` only when no format is set and no explicit stride was
1524    /// stored via [`set_row_stride`](Self::set_row_stride).
1525    pub fn effective_row_stride(&self) -> Option<usize> {
1526        if let Some(s) = self.row_stride {
1527            return Some(s);
1528        }
1529        let fmt = self.format?;
1530        let w = self.width()?;
1531        let elem = std::mem::size_of::<T>();
1532        Some(match fmt.layout() {
1533            PixelLayout::Packed => w * fmt.channels() * elem,
1534            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1535        })
1536    }
1537
1538    /// Set the row stride in bytes for externally allocated buffers with
1539    /// row padding (e.g. V4L2 or GStreamer allocators).
1540    ///
1541    /// The stride is propagated to the EGL DMA-BUF import attributes so
1542    /// the GPU interprets the padded buffer layout correctly. Must be
1543    /// called after [`set_format`](Self::set_format) and before the tensor
1544    /// is first passed to [`ImageProcessor::convert`]. The stored stride
1545    /// is cleared automatically if the pixel format is later changed.
1546    ///
1547    /// No stride-vs-buffer-size validation is performed because the
1548    /// backing allocation size is not reliably known: external DMA-BUFs
1549    /// may be over-allocated by the allocator, and internal tensors store
1550    /// a logical (unpadded) shape. An incorrect stride will be caught by
1551    /// the EGL driver at import time.
1552    ///
1553    /// # Arguments
1554    ///
1555    /// * `stride` - Row stride in bytes. Must be >= the minimum stride for
1556    ///   the format (width * channels * sizeof(T) for packed,
1557    ///   width * sizeof(T) for planar/semi-planar).
1558    ///
1559    /// # Errors
1560    ///
1561    /// * `InvalidArgument` if no pixel format is set on this tensor
1562    /// * `InvalidArgument` if `stride` is less than the minimum for the
1563    ///   format and width
1564    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
1565        let fmt = self.format.ok_or_else(|| {
1566            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
1567        })?;
1568        let w = self.width().ok_or_else(|| {
1569            Error::InvalidArgument("cannot determine width for row_stride validation".into())
1570        })?;
1571        let elem = std::mem::size_of::<T>();
1572        let min_stride = match fmt.layout() {
1573            PixelLayout::Packed => w * fmt.channels() * elem,
1574            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1575        };
1576        if stride < min_stride {
1577            return Err(Error::InvalidArgument(format!(
1578                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
1579            )));
1580        }
1581        self.row_stride = Some(stride);
1582        Ok(())
1583    }
1584
1585    /// Set the row stride without format validation.
1586    ///
1587    /// Use this for raw sub-tensors (e.g. chroma planes) that don't carry
1588    /// format metadata. The caller is responsible for ensuring the stride
1589    /// is valid.
1590    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
1591        self.row_stride = Some(stride);
1592    }
1593
1594    /// Builder-style variant of [`set_row_stride`](Self::set_row_stride),
1595    /// consuming and returning `self`.
1596    ///
1597    /// # Errors
1598    ///
1599    /// Same conditions as [`set_row_stride`](Self::set_row_stride).
1600    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
1601        self.set_row_stride(stride)?;
1602        Ok(self)
1603    }
1604
1605    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
1606    pub fn plane_offset(&self) -> Option<usize> {
1607        self.plane_offset
1608    }
1609
1610    /// Set the byte offset within the DMA-BUF where image data starts.
1611    ///
1612    /// Propagated to `EGL_DMA_BUF_PLANE0_OFFSET_EXT` on GPU import.
1613    /// Unlike [`set_row_stride`](Self::set_row_stride), no format is required
1614    /// since the offset is format-independent.
1615    pub fn set_plane_offset(&mut self, offset: usize) {
1616        self.plane_offset = Some(offset);
1617        #[cfg(target_os = "linux")]
1618        if let TensorStorage::Dma(ref mut dma) = self.storage {
1619            dma.mmap_offset = offset;
1620        }
1621    }
1622
1623    /// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset),
1624    /// consuming and returning `self`.
1625    pub fn with_plane_offset(mut self, offset: usize) -> Self {
1626        self.set_plane_offset(offset);
1627        self
1628    }
1629
1630    /// Downcast to PBO tensor reference (for GL backends).
1631    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1632        match &self.storage {
1633            TensorStorage::Pbo(p) => Some(p),
1634            _ => None,
1635        }
1636    }
1637
1638    /// Downcast to DMA tensor reference (for EGL import, G2D).
1639    #[cfg(target_os = "linux")]
1640    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1641        match &self.storage {
1642            TensorStorage::Dma(d) => Some(d),
1643            _ => None,
1644        }
1645    }
1646
1647    /// Borrow the DMA-BUF file descriptor backing this tensor.
1648    ///
1649    /// # Returns
1650    ///
1651    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
1652    /// lifetime.
1653    ///
1654    /// # Errors
1655    ///
1656    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
1657    #[cfg(target_os = "linux")]
1658    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1659        use std::os::fd::AsFd;
1660        match &self.storage {
1661            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1662            _ => Err(Error::NotImplemented(format!(
1663                "dmabuf requires DMA-backed tensor, got {:?}",
1664                self.storage.memory()
1665            ))),
1666        }
1667    }
1668
1669    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
1670    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1671        Self {
1672            storage: TensorStorage::Pbo(pbo),
1673            format: None,
1674            chroma: None,
1675            row_stride: None,
1676            plane_offset: None,
1677            quantization: None,
1678        }
1679    }
1680}
1681
1682// Quantization accessors — type-gated to integer element types via the
1683// sealed `IntegerType` trait. Calling `.quantization()` on a `Tensor<f32>`
1684// produces a compile error, not a runtime one.
1685impl<T> Tensor<T>
1686where
1687    T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
1688{
1689    /// Quantization metadata for this tensor, if set.
1690    pub fn quantization(&self) -> Option<&Quantization> {
1691        self.quantization.as_ref()
1692    }
1693
1694    /// Attach quantization metadata to this tensor. Validates against the
1695    /// tensor's shape — returns [`Error::QuantizationInvalid`] on any
1696    /// inconsistency (mismatched scale/zp lengths, out-of-range axis, etc.).
1697    pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
1698        q.validate(self.shape())?;
1699        self.quantization = Some(q);
1700        Ok(())
1701    }
1702
1703    /// Builder-style variant of [`Self::set_quantization`]. Consumes `self`
1704    /// and returns `Result<Self>` — on success yields the tensor with the
1705    /// attached quantization; on validation failure returns
1706    /// [`Error::QuantizationInvalid`] and drops `self` (the tensor is not
1707    /// returned in the error arm).
1708    pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
1709        self.set_quantization(q)?;
1710        Ok(self)
1711    }
1712
1713    /// Clear any quantization metadata on this tensor.
1714    pub fn clear_quantization(&mut self) {
1715        self.quantization = None;
1716    }
1717}
1718
// `TensorTrait` for the `Tensor` wrapper. Most methods forward straight to the
// backing `TensorStorage`; `reshape` and `map` additionally consult or reset
// the image metadata (`format`, `chroma`, `row_stride`, `plane_offset`) that
// the wrapper keeps alongside the storage.
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocate a tensor of `shape`. Calls the inherent
    /// `Tensor::new(shape, None, name)`; the `None` memory type leaves the
    /// choice of backing memory to that constructor.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    /// Import an existing file descriptor via `TensorStorage::from_fd` and
    /// wrap the resulting storage with `Self::wrap`.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    /// Duplicate the backing file descriptor (delegated to the storage).
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    /// Memory type of the backing storage.
    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    /// Name of the backing storage.
    fn name(&self) -> String {
        self.storage.name()
    }

    /// Logical shape of the backing storage.
    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    /// Change the logical shape of the storage.
    ///
    /// Multiplane tensors (those carrying a chroma plane) are rejected with
    /// `Error::InvalidOperation`. If the storage reshape fails, the error is
    /// returned and no metadata is touched. On success the metadata tied to
    /// the old shape — pixel format, row stride and plane offset — is cleared.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        self.row_stride = None;
        self.plane_offset = None;
        // Keep the DMA storage's mapping offset consistent with the cleared
        // `plane_offset` (`set_plane_offset` writes both fields together).
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = 0;
        }
        // NOTE(review): `quantization` is neither cleared nor re-validated
        // here, although `set_quantization` validates it against the shape.
        // Confirm a reshape cannot leave per-axis parameters out of range.
        Ok(())
    }

    /// Map the tensor's memory for CPU access.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidOperation` when:
    /// - a row stride is set and the tensor is not a self-allocated DMA
    ///   tensor (on non-Linux platforms every strided tensor is rejected);
    /// - a strided DMA tensor has no derivable `height()`, or
    ///   `row_stride × height` overflows or exceeds the DMA buffer size
    ///   remaining after `mmap_offset`;
    /// - a non-zero plane offset is set on non-DMA storage.
    ///
    /// Any other error comes from the storage's own mapping call.
    fn map(&self) -> Result<TensorMap<T>> {
        // CPU mapping of strided tensors is allowed only when the HAL
        // owns the underlying allocation — i.e. self-allocated DMA
        // tensors with pitch padding added by `image_with_stride()`
        // for GPU import alignment. In that case we know the buffer
        // is exactly `row_stride × height` bytes (for packed formats)
        // and callers that respect the stride can iterate rows
        // correctly via `effective_row_stride()`.
        //
        // Foreign DMA-BUFs imported via `from_fd()` + `set_row_stride()`
        // (the V4L2 / GStreamer case) are rejected: their layout comes
        // from an external allocator and the HAL cannot validate what
        // the caller expects the mapping to look like. Those tensors
        // are intended for the GPU path only.
        //
        // The cfg split keeps `stride` from being an unused binding on
        // non-Linux builds (the Linux branch is the only consumer).
        #[cfg(target_os = "linux")]
        if let Some(stride) = self.row_stride {
            if let TensorStorage::Dma(dma) = &self.storage {
                if !dma.is_imported {
                    // Self-allocated strided DMA tensor — expose the
                    // full stride×height padded mmap via the override
                    // constructor so callers can iterate rows with
                    // `effective_row_stride()` without going past
                    // the end of the returned slice.
                    //
                    // Validate the requested mapping fits inside the
                    // actual DMA-BUF. `set_row_stride()` is a public
                    // API and only validates `stride >= min_stride`,
                    // not `stride × height <= buf_size`, so a caller
                    // that tampers with the stride after allocation
                    // could otherwise request a slice larger than the
                    // underlying mmap — which would be undefined
                    // behaviour in `DmaMap::as_slice`.
                    //
                    // Refuse to map if `height()` can't be derived
                    // (e.g. raw 2D tensors without a PixelFormat that
                    // got a `row_stride` set via `set_row_stride_unchecked`).
                    // Returning a 0-byte view would silently truncate
                    // rather than surface the misuse.
                    let height = self.height().ok_or_else(|| {
                        Error::InvalidOperation(
                            "Tensor::map: strided DMA mapping requires a PixelFormat \
                             so height() can be derived; set a format before mapping \
                             or clear row_stride for raw tensor access"
                                .into(),
                        )
                    })?;
                    let total_bytes = stride.checked_mul(height).ok_or_else(|| {
                        Error::InvalidOperation(format!(
                            "Tensor::map: row_stride {stride} × height {height} overflows usize"
                        ))
                    })?;
                    // Bytes left in the DMA-BUF once the plane offset is
                    // skipped; saturating so an offset past the end gives 0.
                    let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
                    if total_bytes > available_bytes {
                        return Err(Error::InvalidOperation(format!(
                            "Tensor::map: strided mapping needs {total_bytes} bytes \
                             but DMA buffer only has {available_bytes} available \
                             (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
                             the row_stride was likely set larger than the original allocation",
                            dma.buf_size, dma.mmap_offset
                        )));
                    }
                    return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
                }
            }
            // Strided but not a self-allocated DMA tensor (foreign import or
            // non-DMA storage): CPU mapping is refused.
            return Err(Error::InvalidOperation(
                "CPU mapping of strided foreign tensors is not supported; \
                 use GPU path only"
                    .into(),
            ));
        }
        #[cfg(not(target_os = "linux"))]
        if self.row_stride.is_some() {
            return Err(Error::InvalidOperation(
                "CPU mapping of strided tensors is not supported on this \
                 platform (DMA backing is Linux-only)"
                    .into(),
            ));
        }
        // Offset tensors are supported for DMA storage — DmaMap adjusts the
        // mmap range and slice start position.  Non-DMA offset tensors are
        // not meaningful (offset only applies to DMA-BUF sub-regions).
        if self.plane_offset.is_some_and(|o| o > 0) {
            #[cfg(target_os = "linux")]
            if !matches!(self.storage, TensorStorage::Dma(_)) {
                return Err(Error::InvalidOperation(
                    "plane offset only supported for DMA tensors".into(),
                ));
            }
            #[cfg(not(target_os = "linux"))]
            return Err(Error::InvalidOperation(
                "plane offset only supported for DMA tensors".into(),
            ));
        }
        self.storage.map()
    }

    /// Identity of the underlying buffer, as reported by the storage.
    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
1875
/// CPU-accessible mapping of a tensor's memory, as returned by
/// [`TensorTrait::map`].
///
/// There is one variant per storage backend: the DMA variant exists only on
/// Linux, and the shared-memory variant only on Unix. The mapped elements are
/// reachable through [`TensorMapTrait`] or by dereferencing to `[T]`.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Mapping of a DMA-BUF backed tensor (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// Mapping of a POSIX shared-memory tensor (Unix only).
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Mapping of a system-memory tensor.
    Mem(MemMap<T>),
    /// Mapping of a PBO-backed tensor (GL backends).
    Pbo(PboMap<T>),
}
1887
1888impl<T> TensorMapTrait<T> for TensorMap<T>
1889where
1890    T: Num + Clone + fmt::Debug,
1891{
1892    fn shape(&self) -> &[usize] {
1893        match self {
1894            #[cfg(target_os = "linux")]
1895            TensorMap::Dma(map) => map.shape(),
1896            #[cfg(unix)]
1897            TensorMap::Shm(map) => map.shape(),
1898            TensorMap::Mem(map) => map.shape(),
1899            TensorMap::Pbo(map) => map.shape(),
1900        }
1901    }
1902
1903    fn unmap(&mut self) {
1904        match self {
1905            #[cfg(target_os = "linux")]
1906            TensorMap::Dma(map) => map.unmap(),
1907            #[cfg(unix)]
1908            TensorMap::Shm(map) => map.unmap(),
1909            TensorMap::Mem(map) => map.unmap(),
1910            TensorMap::Pbo(map) => map.unmap(),
1911        }
1912    }
1913
1914    fn as_slice(&self) -> &[T] {
1915        match self {
1916            #[cfg(target_os = "linux")]
1917            TensorMap::Dma(map) => map.as_slice(),
1918            #[cfg(unix)]
1919            TensorMap::Shm(map) => map.as_slice(),
1920            TensorMap::Mem(map) => map.as_slice(),
1921            TensorMap::Pbo(map) => map.as_slice(),
1922        }
1923    }
1924
1925    fn as_mut_slice(&mut self) -> &mut [T] {
1926        match self {
1927            #[cfg(target_os = "linux")]
1928            TensorMap::Dma(map) => map.as_mut_slice(),
1929            #[cfg(unix)]
1930            TensorMap::Shm(map) => map.as_mut_slice(),
1931            TensorMap::Mem(map) => map.as_mut_slice(),
1932            TensorMap::Pbo(map) => map.as_mut_slice(),
1933        }
1934    }
1935}
1936
1937impl<T> Deref for TensorMap<T>
1938where
1939    T: Num + Clone + fmt::Debug,
1940{
1941    type Target = [T];
1942
1943    fn deref(&self) -> &[T] {
1944        match self {
1945            #[cfg(target_os = "linux")]
1946            TensorMap::Dma(map) => map.deref(),
1947            #[cfg(unix)]
1948            TensorMap::Shm(map) => map.deref(),
1949            TensorMap::Mem(map) => map.deref(),
1950            TensorMap::Pbo(map) => map.deref(),
1951        }
1952    }
1953}
1954
1955impl<T> DerefMut for TensorMap<T>
1956where
1957    T: Num + Clone + fmt::Debug,
1958{
1959    fn deref_mut(&mut self) -> &mut [T] {
1960        match self {
1961            #[cfg(target_os = "linux")]
1962            TensorMap::Dma(map) => map.deref_mut(),
1963            #[cfg(unix)]
1964            TensorMap::Shm(map) => map.deref_mut(),
1965            TensorMap::Mem(map) => map.deref_mut(),
1966            TensorMap::Pbo(map) => map.deref_mut(),
1967        }
1968    }
1969}
1970
1971// ============================================================================
1972// Platform availability helpers
1973// ============================================================================
1974
1975/// Check if DMA memory allocation is available on this system.
1976///
1977/// Returns `true` only on Linux systems with DMA-BUF heap access (typically
1978/// requires running as root or membership in a video/render group).
1979/// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
1980///
1981/// This function caches its result after the first call for efficiency.
1982#[cfg(target_os = "linux")]
1983static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1984
1985/// Check if DMA memory allocation is available on this system.
1986#[cfg(target_os = "linux")]
1987pub fn is_dma_available() -> bool {
1988    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1989}
1990
1991/// Check if DMA memory allocation is available on this system.
1992///
1993/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
1994#[cfg(not(target_os = "linux"))]
1995pub fn is_dma_available() -> bool {
1996    false
1997}
1998
1999/// Check if POSIX shared memory allocation is available on this system.
2000///
2001/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
2002/// is supported. Always returns `false` on non-Unix platforms (Windows).
2003///
2004/// This function caches its result after the first call for efficiency.
2005#[cfg(unix)]
2006static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2007
2008/// Check if POSIX shared memory allocation is available on this system.
2009#[cfg(unix)]
2010pub fn is_shm_available() -> bool {
2011    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
2012}
2013
2014/// Check if POSIX shared memory allocation is available on this system.
2015///
2016/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
2017#[cfg(not(unix))]
2018pub fn is_shm_available() -> bool {
2019    false
2020}
2021
#[cfg(test)]
mod dtype_tests {
    use super::*;

    /// Every element type reports its width in bytes.
    #[test]
    fn dtype_size() {
        let cases = [
            (DType::U8, 1),
            (DType::I8, 1),
            (DType::U16, 2),
            (DType::I16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
            (DType::U64, 8),
            (DType::I64, 8),
            (DType::F16, 2),
            (DType::F32, 4),
            (DType::F64, 8),
        ];
        for (dtype, bytes) in cases {
            assert_eq!(dtype.size(), bytes);
        }
    }

    /// Element types render with their Rust primitive spelling.
    #[test]
    fn dtype_name() {
        for (dtype, label) in [(DType::U8, "u8"), (DType::F16, "f16"), (DType::F32, "f32")] {
            assert_eq!(dtype.name(), label);
        }
    }

    /// A DType survives a JSON encode/decode cycle unchanged.
    #[test]
    fn dtype_serde_roundtrip() {
        let original = DType::F16;
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
2057
#[cfg(test)]
mod image_tests {
    use super::*;

    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        // NV12: H*3/2 = 720
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_preserves_logical_width() {
        // Skip if DMA not available (e.g. sandboxed CI lacking dma_heap access).
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        // 3004×1688 RGBA8: natural pitch 12016, padded to 12032 (64-aligned).
        let stride = 12032;
        let t = Tensor::<u8>::image_with_stride(
            3004,
            1688,
            PixelFormat::Rgba,
            stride,
            Some(TensorMemory::Dma),
        )
        .unwrap();
        // Logical dimensions unchanged by padding — this is the contract.
        assert_eq!(t.width(), Some(3004));
        assert_eq!(t.height(), Some(1688));
        assert_eq!(t.shape(), &[1688, 3004, 4]);
        // Stride is carried separately and reports the padded pitch.
        assert_eq!(t.effective_row_stride(), Some(stride));
        // Buffer is sized to stride × height so the full padded layout fits,
        // and CPU map() works for self-allocated strided DMA tensors.
        use crate::TensorMapTrait;
        {
            let map = t.map().unwrap();
            assert!(
                map.as_slice().len() >= stride * 1688,
                "mapped buffer {} bytes < expected {}",
                map.as_slice().len(),
                stride * 1688
            );
        }
        // CPU write access works too — iterate rows using the padded stride,
        // touch only the active `width × bpp` region, verify it round-trips.
        {
            let mut map = t.map().unwrap();
            let slice = map.as_mut_slice();
            for y in 0..1688 {
                let row_start = y * stride;
                for x in 0..3004 {
                    let p = row_start + x * 4;
                    slice[p] = (y & 0xFF) as u8;
                    slice[p + 1] = (x & 0xFF) as u8;
                    slice[p + 2] = 0x42;
                    slice[p + 3] = 0xFF;
                }
            }
        }
        {
            let map = t.map().unwrap();
            let slice = map.as_slice();
            // Sample a few pixels to confirm the round-trip.
            assert_eq!(slice[0], 0x00);
            assert_eq!(slice[1], 0x00);
            assert_eq!(slice[2], 0x42);
            assert_eq!(slice[3], 0xFF);
            let mid = 100 * stride + 50 * 4;
            assert_eq!(slice[mid], 100);
            assert_eq!(slice[mid + 1], 50);
            assert_eq!(slice[mid + 2], 0x42);
        }
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_foreign_strided_map() {
        // A FOREIGN (imported via from_fd) DMA tensor with row_stride set
        // should still refuse CPU mapping — external allocator owns the
        // layout. This protects the V4L2 / GStreamer use case.
        //
        // We simulate a foreign import by wrapping our own allocation's
        // fd via `from_fd` and calling set_row_stride manually. The
        // `is_imported` flag on from_fd is true by construction.
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        // Allocate a backing buffer large enough for a 320×240 BGRA8 image.
        let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
        let fd = backing.clone_fd().unwrap();
        // Import it via from_fd — this marks is_imported=true.
        let shape = [240usize, 320, 4];
        let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
        let mut t = Tensor::<u8>::wrap(storage);
        t.set_format(PixelFormat::Bgra).unwrap();
        t.set_row_stride(320 * 4).unwrap(); // natural, but still marks it as strided
        let err = t.map();
        assert!(
            matches!(err, Err(Error::InvalidOperation(_))),
            "foreign strided map should error"
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_map_rejects_tampered_stride() {
        // `set_row_stride` is public and only validates
        // `stride >= min_stride`, not that the new stride × height fits the
        // underlying buffer. A caller that tampers with the stride after
        // allocation must not be able to coerce `Tensor::map()` into
        // returning a slice larger than the backing mmap (that would be UB
        // in `DmaMap::as_slice`).
        if !is_dma_available() {
            eprintln!("SKIPPED: DMA heap not available");
            return;
        }
        // Allocate a 640×480 RGBA8 padded canvas (stride = 3072 = 768 px).
        // Backing buffer is 3072 × 480 = 1,474,560 bytes.
        let mut t = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Rgba,
            3072,
            Some(TensorMemory::Dma),
        )
        .unwrap();
        // Tamper: push the stride up to 4 × the original. This is >=
        // min_stride (2560), so `set_row_stride` accepts it.
        t.set_row_stride(12288).unwrap();
        // Map must now refuse — 12288 × 480 = 5,898,240 > 1,474,560.
        let err = t.map();
        assert!(
            matches!(err, Err(Error::InvalidOperation(_))),
            "map() with oversized stride must return InvalidOperation"
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
        // shape.product() * sizeof(T) must use checked arithmetic so a
        // pathological shape can't wrap usize and make the
        // byte_size-vs-logical-size comparison incorrect.
        //
        // Only the overflow-rejection path is exercised; it is pure Rust and
        // doesn't touch dma_heap, so no `is_dma_available()` guard is needed.
        // `DmaTensor` only exists on Linux, so the whole test is gated there
        // (like its siblings). Gating the function rather than its body keeps
        // it from being reported as a vacuous pass on other targets.
        let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
            &[usize::MAX, 2, 2],
            usize::MAX,
            None,
        );
        assert!(
            matches!(err, Err(Error::InvalidArgument(_))),
            "new_with_byte_size must detect shape.product() overflow"
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_too_small_stride() {
        // 640×480 RGBA8 natural pitch = 2560, request 2400 → should error.
        let err = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Rgba,
            2400,
            Some(TensorMemory::Dma),
        );
        assert!(matches!(err, Err(Error::InvalidArgument(_))));
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn image_tensor_with_stride_rejects_non_packed() {
        // NV12 is SemiPlanar → not supported. (Linux-only because
        // `TensorMemory::Dma` itself is a Linux-only enum variant.)
        let err = Tensor::<u8>::image_with_stride(
            640,
            480,
            PixelFormat::Nv12,
            640,
            Some(TensorMemory::Dma),
        );
        assert!(matches!(err, Err(Error::NotImplemented(_))));
    }

    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        // RGB expects 3 channels, not 4
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        // Original tensor is unmodified
        assert!(t.format().is_none());
    }

    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        // Reshape to flat — format cleared
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
2342
2343#[cfg(test)]
2344mod tests {
2345    #[cfg(target_os = "linux")]
2346    use nix::unistd::{access, AccessFlags};
2347    #[cfg(target_os = "linux")]
2348    use std::io::Write as _;
2349    use std::sync::RwLock;
2350
2351    use super::*;
2352
2353    #[ctor::ctor]
2354    fn init() {
2355        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
2356    }
2357
2358    /// Macro to get the current function name for logging in tests.
2359    #[cfg(target_os = "linux")]
2360    macro_rules! function {
2361        () => {{
2362            fn f() {}
2363            fn type_name_of<T>(_: T) -> &'static str {
2364                std::any::type_name::<T>()
2365            }
2366            let name = type_name_of(f);
2367
2368            // Find and cut the rest of the path
2369            match &name[..name.len() - 3].rfind(':') {
2370                Some(pos) => &name[pos + 1..name.len() - 3],
2371                None => &name[..name.len() - 3],
2372            }
2373        }};
2374    }
2375
2376    #[test]
2377    #[cfg(target_os = "linux")]
2378    fn test_tensor() {
2379        let _lock = FD_LOCK.read().unwrap();
2380        let shape = vec![1];
2381        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
2382        let dma_enabled = tensor.is_ok();
2383
2384        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2385        match dma_enabled {
2386            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
2387            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
2388        }
2389    }
2390
2391    #[test]
2392    #[cfg(all(unix, not(target_os = "linux")))]
2393    fn test_tensor() {
2394        let shape = vec![1];
2395        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2396        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
2397        assert!(
2398            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
2399            "Expected SHM or Mem on macOS, got {:?}",
2400            tensor.memory()
2401        );
2402    }
2403
2404    #[test]
2405    #[cfg(not(unix))]
2406    fn test_tensor() {
2407        let shape = vec![1];
2408        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2409        assert_eq!(tensor.memory(), TensorMemory::Mem);
2410    }
2411
2412    #[test]
2413    #[cfg(target_os = "linux")]
2414    fn test_dma_tensor() {
2415        let _lock = FD_LOCK.read().unwrap();
2416        match access(
2417            "/dev/dma_heap/linux,cma",
2418            AccessFlags::R_OK | AccessFlags::W_OK,
2419        ) {
2420            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
2421            Err(_) => match access(
2422                "/dev/dma_heap/system",
2423                AccessFlags::R_OK | AccessFlags::W_OK,
2424            ) {
2425                Ok(_) => println!("/dev/dma_heap/system is available"),
2426                Err(e) => {
2427                    writeln!(
2428                        &mut std::io::stdout(),
2429                        "[WARNING] DMA Heap is unavailable: {e}"
2430                    )
2431                    .unwrap();
2432                    return;
2433                }
2434            },
2435        }
2436
2437        let shape = vec![2, 3, 4];
2438        let tensor =
2439            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2440
2441        const DUMMY_VALUE: f32 = 12.34;
2442
2443        assert_eq!(tensor.memory(), TensorMemory::Dma);
2444        assert_eq!(tensor.name(), "test_tensor");
2445        assert_eq!(tensor.shape(), &shape);
2446        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2447        assert_eq!(tensor.len(), 2 * 3 * 4);
2448
2449        {
2450            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2451            tensor_map.fill(42.0);
2452            assert!(tensor_map.iter().all(|&x| x == 42.0));
2453        }
2454
2455        {
2456            let shared = Tensor::<f32>::from_fd(
2457                tensor
2458                    .clone_fd()
2459                    .expect("Failed to duplicate tensor file descriptor"),
2460                &shape,
2461                Some("test_tensor_shared"),
2462            )
2463            .expect("Failed to create tensor from fd");
2464
2465            assert_eq!(shared.memory(), TensorMemory::Dma);
2466            assert_eq!(shared.name(), "test_tensor_shared");
2467            assert_eq!(shared.shape(), &shape);
2468
2469            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
2470            tensor_map.fill(DUMMY_VALUE);
2471            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2472        }
2473
2474        {
2475            let tensor_map = tensor.map().expect("Failed to map DMA memory");
2476            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2477        }
2478
2479        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2480        assert_eq!(tensor.shape(), &shape);
2481        let new_shape = vec![3, 4, 4];
2482        assert!(
2483            tensor.reshape(&new_shape).is_err(),
2484            "Reshape should fail due to size mismatch"
2485        );
2486        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2487
2488        let new_shape = vec![2, 3, 4];
2489        tensor.reshape(&new_shape).expect("Reshape should succeed");
2490        assert_eq!(
2491            tensor.shape(),
2492            &new_shape,
2493            "Shape should be updated after successful reshape"
2494        );
2495
2496        {
2497            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2498            tensor_map.fill(1);
2499            assert!(tensor_map.iter().all(|&x| x == 1));
2500        }
2501
2502        {
2503            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2504            tensor_map[2] = 42;
2505            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2506            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2507        }
2508    }
2509
2510    #[test]
2511    #[cfg(unix)]
2512    fn test_shm_tensor() {
2513        let _lock = FD_LOCK.read().unwrap();
2514        let shape = vec![2, 3, 4];
2515        let tensor =
2516            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2517        assert_eq!(tensor.shape(), &shape);
2518        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2519        assert_eq!(tensor.name(), "test_tensor");
2520
2521        const DUMMY_VALUE: f32 = 12.34;
2522        {
2523            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2524            tensor_map.fill(42.0);
2525            assert!(tensor_map.iter().all(|&x| x == 42.0));
2526        }
2527
2528        {
2529            let shared = Tensor::<f32>::from_fd(
2530                tensor
2531                    .clone_fd()
2532                    .expect("Failed to duplicate tensor file descriptor"),
2533                &shape,
2534                Some("test_tensor_shared"),
2535            )
2536            .expect("Failed to create tensor from fd");
2537
2538            assert_eq!(shared.memory(), TensorMemory::Shm);
2539            assert_eq!(shared.name(), "test_tensor_shared");
2540            assert_eq!(shared.shape(), &shape);
2541
2542            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2543            tensor_map.fill(DUMMY_VALUE);
2544            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2545        }
2546
2547        {
2548            let tensor_map = tensor.map().expect("Failed to map shared memory");
2549            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2550        }
2551
2552        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2553        assert_eq!(tensor.shape(), &shape);
2554        let new_shape = vec![3, 4, 4];
2555        assert!(
2556            tensor.reshape(&new_shape).is_err(),
2557            "Reshape should fail due to size mismatch"
2558        );
2559        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2560
2561        let new_shape = vec![2, 3, 4];
2562        tensor.reshape(&new_shape).expect("Reshape should succeed");
2563        assert_eq!(
2564            tensor.shape(),
2565            &new_shape,
2566            "Shape should be updated after successful reshape"
2567        );
2568
2569        {
2570            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2571            tensor_map.fill(1);
2572            assert!(tensor_map.iter().all(|&x| x == 1));
2573        }
2574
2575        {
2576            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2577            tensor_map[2] = 42;
2578            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2579            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2580        }
2581    }
2582
2583    #[test]
2584    fn test_mem_tensor() {
2585        let shape = vec![2, 3, 4];
2586        let tensor =
2587            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2588        assert_eq!(tensor.shape(), &shape);
2589        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2590        assert_eq!(tensor.name(), "test_tensor");
2591
2592        {
2593            let mut tensor_map = tensor.map().expect("Failed to map memory");
2594            tensor_map.fill(42.0);
2595            assert!(tensor_map.iter().all(|&x| x == 42.0));
2596        }
2597
2598        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2599        assert_eq!(tensor.shape(), &shape);
2600        let new_shape = vec![3, 4, 4];
2601        assert!(
2602            tensor.reshape(&new_shape).is_err(),
2603            "Reshape should fail due to size mismatch"
2604        );
2605        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2606
2607        let new_shape = vec![2, 3, 4];
2608        tensor.reshape(&new_shape).expect("Reshape should succeed");
2609        assert_eq!(
2610            tensor.shape(),
2611            &new_shape,
2612            "Shape should be updated after successful reshape"
2613        );
2614
2615        {
2616            let mut tensor_map = tensor.map().expect("Failed to map memory");
2617            tensor_map.fill(1);
2618            assert!(tensor_map.iter().all(|&x| x == 1));
2619        }
2620
2621        {
2622            let mut tensor_map = tensor.map().expect("Failed to map memory");
2623            tensor_map[2] = 42;
2624            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2625            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2626        }
2627    }
2628
2629    #[test]
2630    #[cfg(target_os = "linux")]
2631    fn test_dma_no_fd_leaks() {
2632        let _lock = FD_LOCK.write().unwrap();
2633        if !is_dma_available() {
2634            log::warn!(
2635                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2636                function!()
2637            );
2638            return;
2639        }
2640
2641        let proc = procfs::process::Process::myself()
2642            .expect("Failed to get current process using /proc/self");
2643
2644        let start_open_fds = proc
2645            .fd_count()
2646            .expect("Failed to get open file descriptor count");
2647
2648        for _ in 0..100 {
2649            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
2650                .expect("Failed to create tensor");
2651            let mut map = tensor.map().unwrap();
2652            map.as_mut_slice().fill(233);
2653        }
2654
2655        let end_open_fds = proc
2656            .fd_count()
2657            .expect("Failed to get open file descriptor count");
2658
2659        assert_eq!(
2660            start_open_fds, end_open_fds,
2661            "File descriptor leak detected: {} -> {}",
2662            start_open_fds, end_open_fds
2663        );
2664    }
2665
2666    #[test]
2667    #[cfg(target_os = "linux")]
2668    fn test_dma_from_fd_no_fd_leaks() {
2669        let _lock = FD_LOCK.write().unwrap();
2670        if !is_dma_available() {
2671            log::warn!(
2672                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2673                function!()
2674            );
2675            return;
2676        }
2677
2678        let proc = procfs::process::Process::myself()
2679            .expect("Failed to get current process using /proc/self");
2680
2681        let start_open_fds = proc
2682            .fd_count()
2683            .expect("Failed to get open file descriptor count");
2684
2685        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
2686
2687        for _ in 0..100 {
2688            let tensor =
2689                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2690            let mut map = tensor.map().unwrap();
2691            map.as_mut_slice().fill(233);
2692        }
2693        drop(orig);
2694
2695        let end_open_fds = proc.fd_count().unwrap();
2696
2697        assert_eq!(
2698            start_open_fds, end_open_fds,
2699            "File descriptor leak detected: {} -> {}",
2700            start_open_fds, end_open_fds
2701        );
2702    }
2703
2704    #[test]
2705    #[cfg(target_os = "linux")]
2706    fn test_shm_no_fd_leaks() {
2707        let _lock = FD_LOCK.write().unwrap();
2708        if !is_shm_available() {
2709            log::warn!(
2710                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2711                function!()
2712            );
2713            return;
2714        }
2715
2716        let proc = procfs::process::Process::myself()
2717            .expect("Failed to get current process using /proc/self");
2718
2719        let start_open_fds = proc
2720            .fd_count()
2721            .expect("Failed to get open file descriptor count");
2722
2723        for _ in 0..100 {
2724            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2725                .expect("Failed to create tensor");
2726            let mut map = tensor.map().unwrap();
2727            map.as_mut_slice().fill(233);
2728        }
2729
2730        let end_open_fds = proc
2731            .fd_count()
2732            .expect("Failed to get open file descriptor count");
2733
2734        assert_eq!(
2735            start_open_fds, end_open_fds,
2736            "File descriptor leak detected: {} -> {}",
2737            start_open_fds, end_open_fds
2738        );
2739    }
2740
2741    #[test]
2742    #[cfg(target_os = "linux")]
2743    fn test_shm_from_fd_no_fd_leaks() {
2744        let _lock = FD_LOCK.write().unwrap();
2745        if !is_shm_available() {
2746            log::warn!(
2747                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2748                function!()
2749            );
2750            return;
2751        }
2752
2753        let proc = procfs::process::Process::myself()
2754            .expect("Failed to get current process using /proc/self");
2755
2756        let start_open_fds = proc
2757            .fd_count()
2758            .expect("Failed to get open file descriptor count");
2759
2760        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
2761
2762        for _ in 0..100 {
2763            let tensor =
2764                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2765            let mut map = tensor.map().unwrap();
2766            map.as_mut_slice().fill(233);
2767        }
2768        drop(orig);
2769
2770        let end_open_fds = proc.fd_count().unwrap();
2771
2772        assert_eq!(
2773            start_open_fds, end_open_fds,
2774            "File descriptor leak detected: {} -> {}",
2775            start_open_fds, end_open_fds
2776        );
2777    }
2778
2779    #[cfg(feature = "ndarray")]
2780    #[test]
2781    fn test_ndarray() {
2782        let _lock = FD_LOCK.read().unwrap();
2783        let shape = vec![2, 3, 4];
2784        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2785
2786        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
2787        tensor_map.fill(1.0);
2788
2789        let view = tensor_map.view().expect("Failed to get ndarray view");
2790        assert_eq!(view.shape(), &[2, 3, 4]);
2791        assert!(view.iter().all(|&x| x == 1.0));
2792
2793        let mut view_mut = tensor_map
2794            .view_mut()
2795            .expect("Failed to get mutable ndarray view");
2796        view_mut[[0, 0, 0]] = 42.0;
2797        assert_eq!(view_mut[[0, 0, 0]], 42.0);
2798        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
2799    }
2800
2801    #[test]
2802    fn test_buffer_identity_unique() {
2803        let id1 = BufferIdentity::new();
2804        let id2 = BufferIdentity::new();
2805        assert_ne!(
2806            id1.id(),
2807            id2.id(),
2808            "Two identities should have different ids"
2809        );
2810    }
2811
2812    #[test]
2813    fn test_buffer_identity_clone_shares_guard() {
2814        let id1 = BufferIdentity::new();
2815        let weak = id1.weak();
2816        assert!(
2817            weak.upgrade().is_some(),
2818            "Weak should be alive while original exists"
2819        );
2820
2821        let id2 = id1.clone();
2822        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
2823
2824        drop(id1);
2825        assert!(
2826            weak.upgrade().is_some(),
2827            "Weak should still be alive (clone holds Arc)"
2828        );
2829
2830        drop(id2);
2831        assert!(
2832            weak.upgrade().is_none(),
2833            "Weak should be dead after all clones dropped"
2834        );
2835    }
2836
2837    #[test]
2838    fn test_tensor_buffer_identity() {
2839        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
2840        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
2841        assert_ne!(
2842            t1.buffer_identity().id(),
2843            t2.buffer_identity().id(),
2844            "Different tensors should have different buffer ids"
2845        );
2846    }
2847
2848    // ------------------------------------------------------------------------
2849    // Quantization — constructor validation + accessor correctness.
2850    // ------------------------------------------------------------------------
2851
2852    #[test]
2853    fn test_quantization_per_tensor_constructors() {
2854        let q = Quantization::per_tensor(0.1, -5);
2855        assert!(q.is_per_tensor());
2856        assert!(!q.is_per_channel());
2857        assert!(!q.is_symmetric());
2858        assert_eq!(q.scale(), &[0.1]);
2859        assert_eq!(q.zero_point(), Some(&[-5][..]));
2860
2861        let qs = Quantization::per_tensor_symmetric(0.05);
2862        assert!(qs.is_per_tensor());
2863        assert!(qs.is_symmetric());
2864        assert_eq!(qs.zero_point(), None);
2865    }
2866
2867    #[test]
2868    fn test_quantization_per_channel_constructors() {
2869        let q = Quantization::per_channel(vec![0.1, 0.2, 0.3], vec![0, -1, 1], 2).unwrap();
2870        assert!(q.is_per_channel());
2871        assert!(!q.is_symmetric());
2872        assert_eq!(q.axis(), Some(2));
2873        assert_eq!(q.scale().len(), 3);
2874
2875        let qs = Quantization::per_channel_symmetric(vec![0.054, 0.089, 0.195], 0).unwrap();
2876        assert!(qs.is_per_channel());
2877        assert!(qs.is_symmetric());
2878        assert_eq!(qs.axis(), Some(0));
2879    }
2880
2881    #[test]
2882    fn test_quantization_per_channel_length_mismatch_rejected() {
2883        // len(scales) != len(zero_points) → rejected at construction.
2884        let err = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0).unwrap_err();
2885        assert!(matches!(err, Error::QuantizationInvalid { .. }));
2886    }
2887
2888    #[test]
2889    fn test_quantization_per_channel_empty_rejected() {
2890        let err = Quantization::per_channel_symmetric(vec![], 0).unwrap_err();
2891        assert!(matches!(err, Error::QuantizationInvalid { .. }));
2892    }
2893
2894    /// Constructors guard scale/zero_point length invariants, but
2895    /// `Quantization` is `Deserialize`, so malformed JSON (e.g. an
2896    /// empty `scale` array, or `zero_point` length that disagrees with
2897    /// `scale`) bypasses the constructor checks. `set_quantization`
2898    /// must reject these via `validate()` so they don't poison
2899    /// downstream `mode()` selection or per-channel kernel indexing.
2900    #[test]
2901    fn test_quantization_validate_rejects_malformed_deserialize() {
2902        let mut t = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();
2903
2904        // Empty scale array: must be rejected.
2905        let q: Quantization = serde_json::from_str(r#"{"scale": []}"#).unwrap();
2906        assert!(matches!(
2907            t.set_quantization(q).unwrap_err(),
2908            Error::QuantizationInvalid { .. }
2909        ));
2910
2911        // Per-tensor with multi-element zero_point: must be rejected.
2912        let q: Quantization =
2913            serde_json::from_str(r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#).unwrap();
2914        assert!(matches!(
2915            t.set_quantization(q).unwrap_err(),
2916            Error::QuantizationInvalid { .. }
2917        ));
2918
2919        // Per-channel zero_point length != scale length: must be rejected.
2920        let q: Quantization = serde_json::from_str(
2921            r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
2922        )
2923        .unwrap();
2924        assert!(matches!(
2925            t.set_quantization(q).unwrap_err(),
2926            Error::QuantizationInvalid { .. }
2927        ));
2928    }
2929
2930    #[test]
2931    fn test_quantization_mode_dispatch() {
2932        let pt = Quantization::per_tensor(0.1, -5);
2933        assert!(matches!(
2934            pt.mode(),
2935            QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
2936        ));
2937
2938        let pts = Quantization::per_tensor_symmetric(0.05);
2939        assert!(matches!(
2940            pts.mode(),
2941            QuantMode::PerTensorSymmetric { scale } if scale == 0.05
2942        ));
2943
2944        let pc = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
2945        assert!(matches!(pc.mode(), QuantMode::PerChannel { axis: 2, .. }));
2946
2947        let pcs = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
2948        assert!(matches!(
2949            pcs.mode(),
2950            QuantMode::PerChannelSymmetric { axis: 0, .. }
2951        ));
2952    }
2953
2954    #[test]
2955    fn test_tensor_quantization_roundtrip_integer() {
2956        let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
2957        assert!(t.quantization().is_none());
2958        t.set_quantization(Quantization::per_tensor(0.1, -5))
2959            .unwrap();
2960        let q = t.quantization().unwrap();
2961        assert_eq!(q.scale(), &[0.1]);
2962        t.clear_quantization();
2963        assert!(t.quantization().is_none());
2964    }
2965
2966    #[test]
2967    fn test_tensor_with_quantization_builder() {
2968        let t = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None)
2969            .unwrap()
2970            .with_quantization(Quantization::per_tensor_symmetric(0.05))
2971            .unwrap();
2972        assert!(t.quantization().is_some());
2973    }
2974
2975    #[test]
2976    fn test_tensor_dyn_quantization_float_arm_returns_none() {
2977        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
2978        let td = TensorDyn::F32(t);
2979        assert!(td.quantization().is_none());
2980    }
2981
2982    #[test]
2983    fn test_tensor_dyn_set_quantization_float_arm_errors() {
2984        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
2985        let mut td = TensorDyn::F32(t);
2986        let err = td
2987            .set_quantization(Quantization::per_tensor(0.1, 0))
2988            .unwrap_err();
2989        // float path returns a QuantizationInvalid error.
2990        assert!(matches!(err, Error::QuantizationInvalid { .. }));
2991    }
2992
    /// Compile-time type gate — calling `Tensor::<f32>::quantization()` must
    /// fail to compile (the `IntegerType` trait bound is not satisfied by
    /// `f32`). This doctest anchors the invariant.
    ///
    /// ```compile_fail
    /// use edgefirst_tensor::{Tensor, TensorMemory};
    /// let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
    /// let _ = t.quantization(); // compile error: f32 not IntegerType
    /// ```
    // NOTE(review): rustdoc collects doctests without `--cfg test`. If the
    // enclosing module is gated by `#[cfg(test)]` (likely, given the `#[test]`
    // functions around it), this doctest is never compiled and the invariant is
    // not actually enforced. Consider moving the `compile_fail` example onto a
    // public item outside the test module (e.g. the `quantization` accessor).
    // Also note that a plain `compile_fail` doctest passes on *any* compile error
    // (a renamed import would also "pass"); pinning the expected error code
    // (e.g. `compile_fail,E0599`) makes the check specific.
    fn _compile_fail_doctest_anchor() {}
3003
    /// Serializes tests with respect to the process-wide open-fd count.
    ///
    /// A test that measures the fd count (the `*_no_fd_leaks` tests) must take
    /// this lock exclusively (`write()`). A test that changes the fd count by
    /// opening or closing fds must take it shared (`read()`), so that it never
    /// runs while a measurement is in progress.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
3008
3009    /// Test that DMA is NOT available on non-Linux platforms.
3010    /// This verifies the cross-platform behavior of is_dma_available().
3011    #[test]
3012    #[cfg(not(target_os = "linux"))]
3013    fn test_dma_not_available_on_non_linux() {
3014        assert!(
3015            !is_dma_available(),
3016            "DMA memory allocation should NOT be available on non-Linux platforms"
3017        );
3018    }
3019
3020    /// Test that SHM memory allocation is available and usable on Unix systems.
3021    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
3022    #[test]
3023    #[cfg(unix)]
3024    fn test_shm_available_and_usable() {
3025        assert!(
3026            is_shm_available(),
3027            "SHM memory allocation should be available on Unix systems"
3028        );
3029
3030        // Create a tensor with SHM backing
3031        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
3032            .expect("Failed to create SHM tensor");
3033
3034        // Verify we can map and write to it
3035        let mut map = tensor.map().expect("Failed to map SHM tensor");
3036        map.as_mut_slice().fill(0xAB);
3037
3038        // Verify the data was written correctly
3039        assert!(
3040            map.as_slice().iter().all(|&b| b == 0xAB),
3041            "SHM tensor data should be writable and readable"
3042        );
3043    }
3044}