// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

/*!
EdgeFirst HAL - Tensor Module

The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.

## Examples
```rust
use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
# fn main() -> Result<(), Error> {
let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
assert_eq!(tensor.memory(), TensorMemory::Mem);
assert_eq!(tensor.name(), "test_tensor");
#    Ok(())
# }
```

## Overview
The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
which define the behavior of tensors and their memory mappings, respectively.
The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
wraps `Tensor<T>` for runtime element-type dispatch.
 */
#[cfg(target_os = "linux")]
mod dma;
#[cfg(target_os = "linux")]
mod dmabuf;
mod error;
mod format;
mod mem;
mod pbo;
#[cfg(unix)]
mod shm;
mod tensor_dyn;

#[cfg(target_os = "linux")]
pub use crate::dma::{DmaMap, DmaTensor};
pub use crate::mem::{MemMap, MemTensor};
pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
#[cfg(unix)]
pub use crate::shm::{ShmMap, ShmTensor};
pub use error::{Error, Result};
pub use format::{PixelFormat, PixelLayout};
use num_traits::Num;
use serde::{Deserialize, Serialize};
#[cfg(unix)]
use std::os::fd::OwnedFd;
use std::{
    fmt,
    ops::{Deref, DerefMut},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Weak,
    },
};
pub use tensor_dyn::TensorDyn;

/// Per-plane DMA-BUF descriptor for external buffer import.
///
/// Owns a duplicated file descriptor plus optional stride and offset metadata.
/// The fd is duplicated eagerly in [`new()`](Self::new) so that a bad fd is
/// caught immediately. `import_image` consumes the descriptor and takes
/// ownership of the duped fd — no further cleanup is needed by the caller.
///
/// # Examples
///
/// ```rust,no_run
/// use edgefirst_tensor::PlaneDescriptor;
/// use std::os::fd::BorrowedFd;
///
/// // SAFETY: fd 42 is hypothetical; real code must pass a valid fd.
/// let pd = unsafe { PlaneDescriptor::new(BorrowedFd::borrow_raw(42)) }
///     .unwrap()
///     .with_stride(2048)
///     .with_offset(0);
/// ```
#[cfg(unix)]
pub struct PlaneDescriptor {
    fd: OwnedFd,
    stride: Option<usize>,
    offset: Option<usize>,
}

#[cfg(unix)]
impl PlaneDescriptor {
    /// Create a new plane descriptor by duplicating the given file descriptor.
    ///
    /// The fd is duped immediately — a bad fd fails here rather than inside
    /// `import_image`. The caller retains ownership of the original fd.
    ///
    /// # Errors
    ///
    /// Returns an error if the `dup()` syscall fails (e.g. invalid fd or
    /// fd limit reached).
    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
        let owned = fd.try_clone_to_owned()?;
        Ok(Self {
            fd: owned,
            stride: None,
            offset: None,
        })
    }

    /// Set the row stride in bytes (consuming builder).
    pub fn with_stride(mut self, stride: usize) -> Self {
        self.stride = Some(stride);
        self
    }

    /// Set the plane offset in bytes (consuming builder).
    pub fn with_offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    /// Consume the descriptor and return the owned file descriptor.
    pub fn into_fd(self) -> OwnedFd {
        self.fd
    }

    /// Row stride in bytes, if set.
    pub fn stride(&self) -> Option<usize> {
        self.stride
    }

    /// Plane offset in bytes, if set.
    pub fn offset(&self) -> Option<usize> {
        self.offset
    }
}

/// Element type discriminant for runtime type identification.
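///
/// # Examples
///
/// A quick check of the size and name accessors defined below:
///
/// ```rust
/// use edgefirst_tensor::DType;
///
/// assert_eq!(DType::F32.size(), 4);
/// assert_eq!(DType::F16.name(), "f16");
/// assert_eq!(DType::U8.to_string(), "u8");
/// ```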
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}

impl DType {
    /// Size of one element in bytes.
    pub const fn size(&self) -> usize {
        match self {
            Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
        }
    }

    /// Short type name (e.g., "u8", "f32", "f16").
    pub const fn name(&self) -> &'static str {
        match self {
            Self::U8 => "u8",
            Self::I8 => "i8",
            Self::U16 => "u16",
            Self::I16 => "i16",
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::U64 => "u64",
            Self::I64 => "i64",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
        }
    }
}

impl fmt::Display for DType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}

// =============================================================================
// Quantization metadata — type-gated to integer element types via sealed
// `IntegerType` trait. Accessors on `Tensor<T>` only compile when `T` is
// an integer type; calling them on `Tensor<f32>` / `Tensor<f16>` etc. is a
// compile error, not a runtime one.
// =============================================================================

mod sealed {
    pub trait Sealed {}
    impl Sealed for u8 {}
    impl Sealed for i8 {}
    impl Sealed for u16 {}
    impl Sealed for i16 {}
    impl Sealed for u32 {}
    impl Sealed for i32 {}
    impl Sealed for u64 {}
    impl Sealed for i64 {}
    // Deliberately NOT implemented for f16 / f32 / f64.
}

/// Integer element types that may carry quantization metadata.
///
/// Sealed trait: implemented for `u8`, `i8`, `u16`, `i16`, `u32`, `i32`,
/// `u64`, `i64`. Cannot be implemented downstream. Float element types
/// (`half::f16`, `f32`, `f64`) are explicitly excluded — quantization
/// metadata does not apply to float tensors per the edgefirst.json spec.
pub trait IntegerType: sealed::Sealed {}
impl IntegerType for u8 {}
impl IntegerType for i8 {}
impl IntegerType for u16 {}
impl IntegerType for i16 {}
impl IntegerType for u32 {}
impl IntegerType for i32 {}
impl IntegerType for u64 {}
impl IntegerType for i64 {}

/// Quantization parameters for an integer tensor.
///
/// Covers all four modes the edgefirst.json spec defines:
///
/// | Mode | `scale.len()` | `zero_point` | `axis` |
/// |---|---|---|---|
/// | Per-tensor symmetric | 1 | `None` | `None` |
/// | Per-tensor asymmetric | 1 | `Some(len == 1)` | `None` |
/// | Per-channel symmetric | >1 | `None` | `Some(c)` |
/// | Per-channel asymmetric | >1 | `Some(len == scale.len())` | `Some(c)` |
///
/// The quantized storage type is carried on the parent [`Tensor<T>`]; this
/// struct does not duplicate it. Construct via the four named constructors
/// (the only public entry points); direct field mutation is not allowed so
/// invalid combinations cannot be represented.
///
/// Dequantization formula:
///
/// ```text
///   real_value = scale[c] × (quantized_value[c] - zero_point[c])
/// ```
///
/// where `c` is the channel index (always `0` for per-tensor).
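///
/// # Examples
///
/// A minimal sketch using the named constructors below:
///
/// ```rust
/// use edgefirst_tensor::{QuantMode, Quantization};
///
/// let q = Quantization::per_tensor(0.1, -128);
/// assert!(q.is_per_tensor());
/// assert!(!q.is_symmetric());
/// assert!(matches!(q.mode(), QuantMode::PerTensor { zero_point: -128, .. }));
/// ```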
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
    /// Per-tensor: `vec![scale]`. Per-channel: `vec![scale_0, scale_1, ...]`.
    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
    scale: Vec<f32>,

    /// `None` means symmetric (zero-point is 0). `Some(vec)` must have the
    /// same length as `scale`.
    #[serde(
        default,
        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
        skip_serializing_if = "Option::is_none"
    )]
    zero_point: Option<Vec<i32>>,

    /// Channel axis for per-channel quantization. `Some(_)` iff
    /// `scale.len() > 1`. Validated against the parent tensor's shape at
    /// `set_quantization()` time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    axis: Option<usize>,
}

/// Semantic mode discriminant for hot-path kernel dispatch.
///
/// Obtain via [`Quantization::mode`] once at kernel entry; never inside a
/// pixel-level loop. The enum is borrow-based so the hot kernel receives
/// the scales / zero-points as slices without reallocation.
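///
/// # Examples
///
/// A dequantization sketch covering the per-tensor modes (the per-channel
/// modes need shape-aware indexing along `axis` and are elided here):
///
/// ```rust
/// use edgefirst_tensor::{QuantMode, Quantization};
///
/// fn dequantize_per_tensor(data: &[u8], q: &Quantization) -> Vec<f32> {
///     // Match once at kernel entry, then run the tight loop per mode.
///     match q.mode() {
///         QuantMode::PerTensorSymmetric { scale } => {
///             data.iter().map(|&v| scale * f32::from(v)).collect()
///         }
///         QuantMode::PerTensor { scale, zero_point } => {
///             data.iter()
///                 .map(|&v| scale * (i32::from(v) - zero_point) as f32)
///                 .collect()
///         }
///         _ => unreachable!("per-channel handled elsewhere"),
///     }
/// }
///
/// let q = Quantization::per_tensor(0.5, 10);
/// assert_eq!(dequantize_per_tensor(&[12], &q), vec![1.0]);
/// ```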
#[derive(Debug, Clone, Copy)]
pub enum QuantMode<'a> {
    PerTensorSymmetric {
        scale: f32,
    },
    PerTensor {
        scale: f32,
        zero_point: i32,
    },
    PerChannelSymmetric {
        scales: &'a [f32],
        axis: usize,
    },
    PerChannel {
        scales: &'a [f32],
        zero_points: &'a [i32],
        axis: usize,
    },
}

impl Quantization {
    /// Per-tensor symmetric (zero_point = 0).
    pub fn per_tensor_symmetric(scale: f32) -> Self {
        Self {
            scale: vec![scale],
            zero_point: None,
            axis: None,
        }
    }

    /// Per-tensor asymmetric — the most common runtime shape.
    pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
        Self {
            scale: vec![scale],
            zero_point: Some(vec![zero_point]),
            axis: None,
        }
    }

    /// Per-channel symmetric. Errors on empty `scales`.
    pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
        if scales.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: "non-empty per-channel scales".to_string(),
                got: "length 0".to_string(),
            });
        }
        Ok(Self {
            scale: scales,
            zero_point: None,
            axis: Some(axis),
        })
    }

    /// Per-channel asymmetric. Errors on length mismatch between `scales`
    /// and `zero_points`, or empty arrays.
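    ///
    /// # Examples
    ///
    /// ```rust
    /// use edgefirst_tensor::Quantization;
    ///
    /// let q = Quantization::per_channel(vec![0.1, 0.2], vec![0, 4], 0)?;
    /// assert!(q.is_per_channel());
    /// // Length mismatch between scales and zero-points is rejected.
    /// assert!(Quantization::per_channel(vec![0.1], vec![0, 4], 0).is_err());
    /// # Ok::<(), edgefirst_tensor::Error>(())
    /// ```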
    pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
        if scales.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: "non-empty per-channel scales".to_string(),
                got: "length 0".to_string(),
            });
        }
        if scales.len() != zero_points.len() {
            return Err(Error::QuantizationInvalid {
                field: "zero_point.len",
                expected: format!("length matches scale ({})", scales.len()),
                got: format!("length {}", zero_points.len()),
            });
        }
        Ok(Self {
            scale: scales,
            zero_point: Some(zero_points),
            axis: Some(axis),
        })
    }

    /// Borrow-based dispatch view. Match once at kernel entry.
    pub fn mode(&self) -> QuantMode<'_> {
        match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
            (1, None, _) => QuantMode::PerTensorSymmetric {
                scale: self.scale[0],
            },
            (1, Some(zps), _) => QuantMode::PerTensor {
                scale: self.scale[0],
                zero_point: zps.first().copied().unwrap_or(0),
            },
            (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
                scales: &self.scale,
                axis,
            },
            (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
                scales: &self.scale,
                zero_points: zps,
                axis,
            },
            // The `validate()` path prevents constructing a
            // per-channel Quantization without an axis, so the remaining
            // pattern is unreachable in practice. Fall back to
            // per-tensor symmetric using scale[0] to avoid panicking in
            // release; debug builds assert.
            _ => {
                debug_assert!(
                    false,
                    "Quantization::mode: per-channel without axis is unreachable"
                );
                QuantMode::PerTensorSymmetric {
                    scale: self.scale.first().copied().unwrap_or(1.0),
                }
            }
        }
    }

    /// Returns `true` for per-tensor quantization (`scale.len() == 1`).
    pub fn is_per_tensor(&self) -> bool {
        self.scale.len() == 1
    }

    /// Returns `true` for per-channel quantization (`scale.len() > 1`).
    pub fn is_per_channel(&self) -> bool {
        self.scale.len() > 1
    }

    /// Returns `true` for symmetric quantization (no zero-point, or
    /// zero-point vector of all zeros).
    pub fn is_symmetric(&self) -> bool {
        match &self.zero_point {
            None => true,
            Some(zps) => zps.iter().all(|&z| z == 0),
        }
    }

    /// Borrow the scale array. Length 1 for per-tensor; `num_channels` for
    /// per-channel.
    pub fn scale(&self) -> &[f32] {
        &self.scale
    }

    /// Borrow the zero-point array. `None` for symmetric.
    pub fn zero_point(&self) -> Option<&[i32]> {
        self.zero_point.as_deref()
    }

    /// Channel axis for per-channel quantization. `None` for per-tensor.
    pub fn axis(&self) -> Option<usize> {
        self.axis
    }

    /// Validate against a target tensor shape. Runs in
    /// `Tensor::set_quantization()`. Catches:
    ///   - empty `scale` (reject — must declare at least one factor)
    ///   - `zero_point` length inconsistent with `scale` (reject —
    ///     per-tensor must have len 1, per-channel must match `scale.len`)
    ///   - `axis >= shape.len()` (axis out of range)
    ///   - `scale.len() != shape[axis]` for per-channel
    ///   - per-channel without axis (reject)
    ///   - per-tensor with redundant axis (reject)
    pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
        // `Quantization` is `Deserialize`, so malformed JSON like
        // `{"scale": [], "zero_point": []}` could otherwise produce an
        // ill-defined value that confuses `mode()` selection and the
        // per-channel kernels' indexing.
        if self.scale.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: ">= 1".to_string(),
                got: "0".to_string(),
            });
        }
        if let Some(zps) = self.zero_point.as_ref() {
            // Per-tensor: scale.len() == 1 and zero_point.len() must == 1.
            // Per-channel: zero_point.len() must == scale.len().
            let expected = if self.scale.len() == 1 {
                1
            } else {
                self.scale.len()
            };
            if zps.len() != expected {
                return Err(Error::QuantizationInvalid {
                    field: "zero_point.len",
                    expected: format!(
                        "{expected} (matching {})",
                        if self.scale.len() == 1 {
                            "per-tensor scale"
                        } else {
                            "per-channel scale.len"
                        }
                    ),
                    got: format!("length {}", zps.len()),
                });
            }
        }

        match (self.scale.len(), self.axis) {
            (1, None) => Ok(()),
            (1, Some(_)) => Err(Error::QuantizationInvalid {
                field: "per_tensor_redundant_axis",
                expected: "axis=None for per-tensor quantization".to_string(),
                got: format!("axis={:?}", self.axis),
            }),
            (_, None) => Err(Error::QuantizationInvalid {
                field: "per_channel_requires_axis",
                expected: format!(
                    "axis=Some(_) for per-channel quantization (scale.len={})",
                    self.scale.len()
                ),
                got: "axis=None".to_string(),
            }),
            (n, Some(axis)) => {
                if axis >= shape.len() {
                    return Err(Error::QuantizationInvalid {
                        field: "axis",
                        expected: format!("axis < tensor rank ({})", shape.len()),
                        got: format!("axis={axis}"),
                    });
                }
                if shape[axis] != n {
                    return Err(Error::QuantizationInvalid {
                        field: "scale.len",
                        expected: format!("length matches shape[{axis}] ({})", shape[axis]),
                        got: format!("length {n}"),
                    });
                }
                Ok(())
            }
        }
    }
}

impl From<(f32, i32)> for Quantization {
    /// Convenience construction from a `(scale, zero_point)` tuple. Matches
    /// the legacy `QuantTuple` / `Quantization::new` calling convention so
    /// existing `(0.1, -128).into()` sites keep working.
    fn from((scale, zero_point): (f32, i32)) -> Self {
        Self::per_tensor(scale, zero_point)
    }
}

fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
    de: D,
) -> std::result::Result<Vec<f32>, D::Error> {
    use serde::de::{self, Visitor};
    struct V;
    impl<'de> Visitor<'de> for V {
        type Value = Vec<f32>;
        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("f32 or array of f32")
        }
        fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        #[allow(clippy::cast_possible_truncation)]
        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        fn visit_seq<A: de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
            while let Some(x) = seq.next_element::<f32>()? {
                out.push(x);
            }
            Ok(out)
        }
    }
    de.deserialize_any(V)
}

fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
    de: D,
) -> std::result::Result<Option<Vec<i32>>, D::Error> {
    use serde::de::{self, Visitor};
    struct V;
    impl<'de> Visitor<'de> for V {
        type Value = Option<Vec<i32>>;
        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("null, i32, or array of i32")
        }
        fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
        fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
        fn visit_some<D2: serde::Deserializer<'de>>(
            self,
            de: D2,
        ) -> std::result::Result<Self::Value, D2::Error> {
            struct Inner;
            impl<'de> Visitor<'de> for Inner {
                type Value = Vec<i32>;
                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                    f.write_str("i32 or array of i32")
                }
                #[allow(clippy::cast_possible_truncation)]
                fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
                    Ok(vec![v as i32])
                }
                #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
                fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
                    Ok(vec![v as i32])
                }
                fn visit_seq<A: de::SeqAccess<'de>>(
                    self,
                    mut seq: A,
                ) -> std::result::Result<Self::Value, A::Error> {
                    let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
                    while let Some(x) = seq.next_element::<i32>()? {
                        out.push(x);
                    }
                    Ok(out)
                }
            }
            de.deserialize_any(Inner).map(Some)
        }
        #[allow(clippy::cast_possible_truncation)]
        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            Ok(Some(vec![v as i32]))
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(Some(vec![v as i32]))
        }
        fn visit_seq<A: de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
            while let Some(x) = seq.next_element::<i32>()? {
                out.push(x);
            }
            Ok(Some(out))
        }
    }
    de.deserialize_option(V)
}

/// Monotonic counter for buffer identity IDs.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Unique identity for a tensor's underlying buffer.
///
/// Created fresh on every buffer allocation or import. The `id` is a monotonic
/// u64 used as a cache key. The `guard` is an `Arc<()>` whose weak references
/// allow downstream caches to detect when the buffer has been dropped.
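///
/// # Examples
///
/// ```rust
/// use edgefirst_tensor::BufferIdentity;
///
/// let ident = BufferIdentity::new();
/// let weak = ident.weak();
/// assert!(weak.upgrade().is_some());
/// drop(ident); // last owner gone, caches can now detect staleness
/// assert!(weak.upgrade().is_none());
/// ```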
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Create a new unique buffer identity.
    pub fn new() -> Self {
        Self {
            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
            guard: Arc::new(()),
        }
    }

    /// Unique identifier for this buffer. Changes when the buffer changes.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Returns a weak reference to the buffer guard. Goes dead when the
    /// owning Tensor is dropped (and no clones remain).
    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}

impl Default for BufferIdentity {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(target_os = "linux")]
use nix::sys::stat::{major, minor};

pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Create a new tensor with the given shape and optional name. If no name
    /// is given, a random name will be generated.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Create a new tensor using the given file descriptor, shape, and optional
    /// name. If no name is given, a random name will be generated.
    ///
    /// On Linux: inspects the fd to determine DMA vs. SHM from the device
    /// major/minor numbers. On other Unix (macOS): always creates an SHM tensor.
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Clone the file descriptor associated with this tensor.
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Get the memory type of this tensor.
    fn memory(&self) -> TensorMemory;

    /// Get the name of this tensor.
    fn name(&self) -> String;

    /// Get the number of elements in this tensor.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get the shape of this tensor.
    fn shape(&self) -> &[usize];

    /// Reshape this tensor to the given shape. The total number of elements
    /// must remain the same.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Map the tensor into memory and return a TensorMap for accessing the
    /// data.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Get the buffer identity for cache keying and liveness tracking.
    fn buffer_identity(&self) -> &BufferIdentity;
}

pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Get the shape of this tensor map.
    fn shape(&self) -> &[usize];

    /// Unmap the tensor from memory.
    fn unmap(&mut self);

    /// Get the number of elements in this tensor map.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor map is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor map.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get a slice to the data in this tensor map.
    fn as_slice(&self) -> &[T];

    /// Get a mutable slice to the data in this tensor map.
    fn as_mut_slice(&mut self) -> &mut [T];

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayView of the tensor data.
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayViewMut of the tensor data.
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}

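/// Memory backing used for a tensor allocation.
///
/// # Examples
///
/// String conversions round-trip through the lowercase names:
///
/// ```rust
/// use edgefirst_tensor::TensorMemory;
///
/// let mem = TensorMemory::try_from("mem").unwrap();
/// assert_eq!(String::from(mem), "mem");
/// assert!(TensorMemory::try_from("bogus").is_err());
/// ```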
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    /// Direct Memory Access (DMA) allocation. Incurs additional overhead
    /// for CPU reads and writes, but allows hardware acceleration when
    /// supported.
    Dma,
    #[cfg(unix)]
    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    Shm,

    /// Regular system memory allocation.
    Mem,

    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
    /// when DMA-buf is unavailable but OpenGL is present.
    Pbo,
}

impl From<TensorMemory> for String {
    fn from(memory: TensorMemory) -> Self {
        match memory {
            #[cfg(target_os = "linux")]
            TensorMemory::Dma => "dma".to_owned(),
            #[cfg(unix)]
            TensorMemory::Shm => "shm".to_owned(),
            TensorMemory::Mem => "mem".to_owned(),
            TensorMemory::Pbo => "pbo".to_owned(),
        }
    }
}

impl TryFrom<&str> for TensorMemory {
    type Error = Error;

    fn try_from(s: &str) -> Result<Self> {
        match s {
            #[cfg(target_os = "linux")]
            "dma" => Ok(TensorMemory::Dma),
            #[cfg(unix)]
            "shm" => Ok(TensorMemory::Shm),
            "mem" => Ok(TensorMemory::Mem),
            "pbo" => Ok(TensorMemory::Pbo),
            _ => Err(Error::InvalidMemoryType(s.to_owned())),
        }
    }
}

#[derive(Debug)]
#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}

impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Create a new tensor storage with the given shape, memory type, and
    /// optional name. If no name is given, a random name will be generated.
    /// If no memory type is given, the best available memory type will be
    /// chosen based on the platform and environment variables.
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => {
                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
            }
            #[cfg(unix)]
            Some(TensorMemory::Shm) => {
                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
            }
            Some(TensorMemory::Mem) => {
                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
            }
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    #[cfg(target_os = "linux")]
                    {
                        // Linux: Try DMA -> SHM -> Mem
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name)
                                    .map(TensorStorage::Shm)
                                {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => MemTensor::<T>::new(shape, name)
                                        .map(TensorStorage::Mem),
                                }
                            }
                        }
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        // macOS/BSD: Try SHM -> Mem (no DMA)
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => {
                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                            }
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        // Windows/other: Mem only
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    /// Create a DMA-backed tensor storage with an explicit byte size that
    /// may exceed `shape.product() * sizeof(T)`. Used for image tensors
    /// with row-padded layouts (see `DmaTensor::new_with_byte_size`).
    ///
    /// This is intentionally DMA-only: padding is only meaningful for
    /// buffers that will be imported as GPU textures via EGLImage. PBO,
    /// Shm, and Mem storage doesn't benefit from pitch alignment and
    /// shouldn't pay the memory cost.
    #[cfg(target_os = "linux")]
    pub(crate) fn new_dma_with_byte_size(
        shape: &[usize],
        byte_size: usize,
        name: Option<&str>,
    ) -> Result<Self> {
        DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
    }

    // No non-Linux stub: the only caller (`Tensor::image_with_stride`)
    // returns `NotImplemented` directly on non-Linux without ever
    // reaching the storage layer, so defining a stub here would be
    // dead code and fail the `-D warnings` clippy gate on macOS CI.

    /// Create a new tensor storage using the given file descriptor, shape,
    /// and optional name.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            if major != 0 {
                // Dma and Shm tensors are expected to have major number 0
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                9 | 10 => {
                    // minor number 9 & 10 indicates DMA memory
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    // other minor numbers are assumed to be shared memory
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            // On macOS/BSD, always use SHM (no DMA support)
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}

impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}

/// Multi-backend tensor with optional image format metadata.
///
/// When `format` is `Some`, this tensor represents an image. Width, height,
/// and channels are derived from `shape` + `format`. When `format` is `None`,
/// this is a raw tensor (identical to the pre-refactoring behavior).
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    format: Option<PixelFormat>,
    chroma: Option<Box<Tensor<T>>>,
    /// Row stride in bytes for externally allocated buffers with row padding.
    /// `None` means tightly packed (stride == width * bytes_per_pixel).
    row_stride: Option<usize>,
    /// Byte offset within the DMA-BUF where image data starts.
    /// `None` means offset 0 (data starts at the beginning of the buffer).
    plane_offset: Option<usize>,
    /// Quantization metadata for integer-typed tensors. Public access is
    /// gated by the `IntegerType` trait — `Tensor<f32>` etc. carry the
    /// field for layout uniformity but have no way to read or write it.
    pub(crate) quantization: Option<Quantization>,
}

impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wrap a TensorStorage in a Tensor with no image metadata.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
            quantization: None,
        }
    }

    /// Construct a tensor from a row-major element slice + shape. Allocates a
    /// new buffer (`TensorMemory::Mem`) and memcpys the contents; caller
    /// retains ownership of the input slice.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidShape`] if `values.len() != shape.iter().product()`.
    /// - Propagates any allocation error from [`Self::new`].
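    ///
    /// # Examples
    ///
    /// ```rust
    /// use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
    ///
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.map()?.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// // Shape/length mismatch is rejected up front.
    /// assert!(Tensor::from_slice(&[1.0f32, 2.0], &[3]).is_err());
    /// # Ok::<(), edgefirst_tensor::Error>(())
    /// ```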
    pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
    where
        T: Copy,
    {
        let expected: usize = shape.iter().product();
        if values.len() != expected {
            return Err(Error::InvalidShape(format!(
                "from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
                values.len()
            )));
        }
        let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
        {
            let mut m = t.map()?;
            m.as_mut_slice().copy_from_slice(values);
        }
        Ok(t)
    }

    /// Construct a tensor from a 3-D ndarray view. Respects strides — one
    /// copy in all cases; contiguous views take a memcpy fast path.
    ///
    /// Only available when the `ndarray` feature is enabled.
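    ///
    /// # Examples
    ///
    /// ```rust
    /// use edgefirst_tensor::{Tensor, TensorTrait};
    /// use ndarray::Array3;
    ///
    /// let a = Array3::<f32>::from_shape_fn((2, 2, 3), |(y, x, c)| (y + x + c) as f32);
    /// let t = Tensor::from_arrayview3(a.view()).unwrap();
    /// assert_eq!(t.shape(), &[2, 2, 3]);
    /// ```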
    #[cfg(feature = "ndarray")]
    pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
    where
        T: Copy,
    {
        let (h, w, c) = view.dim();
        let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
        {
            let mut m = t.map()?;
            let dst = m.as_mut_slice();
            if let Some(src) = view.as_slice() {
                dst.copy_from_slice(src);
            } else {
                for (d, &s) in dst.iter_mut().zip(view.iter()) {
                    *d = s;
                }
            }
        }
        Ok(t)
    }

    /// Create a new tensor with the given shape, memory type, and optional
    /// name. If no name is given, a random name will be generated. If no
    /// memory type is given, the best available memory type will be chosen
    /// based on the platform and environment variables.
    ///
    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
    /// On non-Unix platforms, only Mem is available.
    ///
    /// # Environment Variables
    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
    ///   value, forces the use of regular system memory allocation
    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
    ///
    /// # Example
    /// ```rust
    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
    /// # fn main() -> Result<(), Error> {
    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
    /// assert_eq!(tensor.name(), "test_tensor");
    /// #    Ok(())
    /// # }
    /// ```
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        let _span = tracing::trace_span!(
            "tensor_alloc",
            ?shape,
            memory = ?memory,
            dtype = std::any::type_name::<T>(),
        )
        .entered();
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

    /// Create an image tensor with the given format.
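    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use edgefirst_tensor::{PixelFormat, Tensor, TensorTrait};
    ///
    /// // NV12 packs luma and chroma contiguously: 480 * 3/2 = 720 rows.
    /// let img = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None)?;
    /// assert_eq!(img.shape(), &[720, 640]);
    /// assert_eq!(img.width(), Some(640));
    /// assert_eq!(img.height(), Some(480));
    /// # Ok::<(), edgefirst_tensor::Error>(())
    /// ```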
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

    /// Create a DMA-backed image tensor with an explicit row stride that
    /// may exceed the natural `width * channels * sizeof(T)` pitch.
    ///
    /// Used for image tensors that need GPU pitch alignment padding: the
    /// underlying DMA-BUF is sized to `row_stride * height` bytes, but
    /// the tensor's logical shape stays at `[height, width, channels]`.
    /// `width()` / `height()` / `shape()` continue to report the
    /// user-requested values; the padding is visible only via
    /// `row_stride()` / `effective_row_stride()` and is automatically
    /// propagated to the GL backend's EGLImage import so Mali Valhall
    /// accepts the buffer.
    ///
    /// # Supported formats
    ///
    /// Currently only **packed** pixel layouts (RGBA8, BGRA8, RGB888,
    /// Grey, etc.) are supported — the formats the GL backend uses as
    /// render destinations. Semi-planar formats (NV12, NV16) come from
    /// external allocators (camera capture, video decoders) and are
    /// imported via `TensorDyn::from_fd` + `set_row_stride`, which
    /// already supports padded strides.
    ///
    /// # Supported memory
    ///
    /// Currently only `TensorMemory::Dma` is supported. PBO and Mem
    /// storage don't go through EGLImage import so they don't need
    /// pitch alignment; if you pass any other memory type this returns
    /// `NotImplemented`. `None` (auto-select) is treated as `Dma`.
    ///
    /// # Errors
    ///
    /// - `InvalidArgument` if `row_stride_bytes < width * channels * sizeof(T)`
    ///   (the requested stride would not fit a single row)
    /// - `NotImplemented` for non-packed formats or non-DMA memory
    /// - `IoError` if the DMA-heap allocation fails (propagated from
    ///   `DmaTensor::new_with_byte_size`)
    pub fn image_with_stride(
        width: usize,
        height: usize,
        format: PixelFormat,
        row_stride_bytes: usize,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        // DMA backing (the only thing this constructor produces) is
        // Linux-only. On macOS/BSD/Windows the non-Linux block below is
        // the only compiled body and returns `NotImplemented` directly;
        // on Linux the non-Linux block is cfg-removed and the function
        // falls through to the real validation + allocation path. Each
        // target compiles exactly one of the two blocks, and the block
        // serves as the function's tail expression in both cases — so
        // neither needs an explicit `return` (avoids
        // `clippy::needless_return` on the macOS CI gate).
        #[cfg(not(target_os = "linux"))]
        {
            let _ = (width, height, format, row_stride_bytes, memory);
            Err(Error::NotImplemented(
                "image_with_stride requires DMA support (Linux only)".to_owned(),
            ))
        }

        #[cfg(target_os = "linux")]
        {
            if format.layout() != PixelLayout::Packed {
                return Err(Error::NotImplemented(format!(
                    "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
                )));
            }
            let elem = std::mem::size_of::<T>();
            let min_stride = width
                .checked_mul(format.channels())
                .and_then(|p| p.checked_mul(elem))
                .ok_or_else(|| {
                    Error::InvalidArgument(format!(
                        "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
                         overflows usize",
                        format.channels()
                    ))
                })?;
            if row_stride_bytes < min_stride {
                return Err(Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
                     ({width} px × {} ch × {elem} B)",
                    format.channels()
                )));
            }
            let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
                ))
            })?;

            let shape = vec![height, width, format.channels()];

            let storage = match memory {
                Some(TensorMemory::Dma) | None => {
                    TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
                }
                Some(other) => {
                    return Err(Error::NotImplemented(format!(
                        "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
                    )));
                }
            };

            let mut t = Self::wrap(storage);
            t.format = Some(format);
            t.row_stride = Some(row_stride_bytes);
            Ok(t)
        }
    }

    /// Attach format metadata to an existing tensor.
    ///
    /// # Arguments
    ///
    /// * `format` - The pixel format to attach
    ///
    /// # Returns
    ///
    /// `Ok(())` on success, with the format stored as metadata on the tensor.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
    /// the format's layout (packed expects `[H, W, C]`, planar expects
    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
    /// height constraints).
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                match format {
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        // Clear stored stride/offset when format changes — they may be invalid
        // for the new format. Caller must re-set after changing format.
        if self.format != Some(format) {
            self.row_stride = None;
            self.plane_offset = None;
            #[cfg(target_os = "linux")]
            if let TensorStorage::Dma(ref mut dma) = self.storage {
                dma.mmap_offset = 0;
            }
        }
        self.format = Some(format);
        Ok(())
    }

    /// Pixel format (None if not an image).
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    /// Image width (None if not an image).
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    /// Image height (None if not an image).
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    Some(shape[0])
                } else {
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

    /// Create from separate Y and UV planes (multiplane NV12/NV16).
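    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use edgefirst_tensor::{PixelFormat, Tensor, TensorMemory};
    ///
    /// // NV12: the chroma plane has half the luma height at the same width.
    /// let luma = Tensor::<u8>::new(&[480, 640], Some(TensorMemory::Mem), None)?;
    /// let chroma = Tensor::<u8>::new(&[240, 640], Some(TensorMemory::Mem), None)?;
    /// let nv12 = Tensor::from_planes(luma, chroma, PixelFormat::Nv12)?;
    /// assert!(nv12.is_multiplane());
    /// assert_eq!(nv12.height(), Some(480));
    /// # Ok::<(), edgefirst_tensor::Error>(())
    /// ```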
1443    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
1444        if format.layout() != PixelLayout::SemiPlanar {
1445            return Err(Error::InvalidArgument(format!(
1446                "from_planes requires a semi-planar format, got {format:?}"
1447            )));
1448        }
1449        if chroma.format.is_some() || chroma.chroma.is_some() {
1450            return Err(Error::InvalidArgument(
1451                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
1452            ));
1453        }
1454        let luma_shape = luma.shape();
1455        let chroma_shape = chroma.shape();
1456        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
1457            return Err(Error::InvalidArgument(format!(
1458                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
1459            )));
1460        }
1461        if luma_shape[1] != chroma_shape[1] {
1462            return Err(Error::InvalidArgument(format!(
1463                "luma width {} != chroma width {}",
1464                luma_shape[1], chroma_shape[1]
1465            )));
1466        }
1467        match format {
1468            PixelFormat::Nv12 => {
1469                if luma_shape[0] % 2 != 0 {
1470                    return Err(Error::InvalidArgument(format!(
1471                        "NV12 requires even luma height, got {}",
1472                        luma_shape[0]
1473                    )));
1474                }
1475                if chroma_shape[0] != luma_shape[0] / 2 {
1476                    return Err(Error::InvalidArgument(format!(
1477                        "NV12 chroma height {} != luma height / 2 ({})",
1478                        chroma_shape[0],
1479                        luma_shape[0] / 2
1480                    )));
1481                }
1482            }
1483            PixelFormat::Nv16 => {
1484                if chroma_shape[0] != luma_shape[0] {
1485                    return Err(Error::InvalidArgument(format!(
1486                        "NV16 chroma height {} != luma height {}",
1487                        chroma_shape[0], luma_shape[0]
1488                    )));
1489                }
1490            }
1491            _ => {
1492                return Err(Error::InvalidArgument(format!(
1493                    "from_planes only supports NV12 and NV16, got {format:?}"
1494                )));
1495            }
1496        }
1497
1498        Ok(Tensor {
1499            storage: luma.storage,
1500            format: Some(format),
1501            chroma: Some(Box::new(chroma)),
1502            row_stride: luma.row_stride,
1503            plane_offset: luma.plane_offset,
1504            quantization: luma.quantization,
1505        })
1506    }
1507
1508    /// Whether this tensor uses separate plane allocations.
1509    pub fn is_multiplane(&self) -> bool {
1510        self.chroma.is_some()
1511    }
1512
1513    /// Access the chroma plane for multiplane semi-planar images.
1514    pub fn chroma(&self) -> Option<&Tensor<T>> {
1515        self.chroma.as_deref()
1516    }
1517
1518    /// Mutable access to the chroma plane for multiplane semi-planar images.
1519    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
1520        self.chroma.as_deref_mut()
1521    }
1522
1523    /// Row stride in bytes (`None` = tightly packed).
1524    pub fn row_stride(&self) -> Option<usize> {
1525        self.row_stride
1526    }
1527
1528    /// Effective row stride in bytes: the stored stride if set, otherwise the
1529    /// minimum stride computed from the format, width, and element size.
1530    /// Returns `None` only when no format is set and no explicit stride was
1531    /// stored via [`set_row_stride`](Self::set_row_stride).
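    ///
    /// A small doctest of the fallback path (packed RGBA, no stored stride):
    ///
    /// ```rust
    /// use edgefirst_tensor::{Error, PixelFormat, Tensor};
    /// # fn main() -> Result<(), Error> {
    /// let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None)?;
    /// // Minimum packed pitch: width * channels * size_of::<u8>() = 640 * 4.
    /// assert_eq!(t.effective_row_stride(), Some(2560));
    /// # Ok(())
    /// # }
    /// ```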
1532    pub fn effective_row_stride(&self) -> Option<usize> {
1533        if let Some(s) = self.row_stride {
1534            return Some(s);
1535        }
1536        let fmt = self.format?;
1537        let w = self.width()?;
1538        let elem = std::mem::size_of::<T>();
1539        Some(match fmt.layout() {
1540            PixelLayout::Packed => w * fmt.channels() * elem,
1541            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1542        })
1543    }
1544
1545    /// Set the row stride in bytes for externally allocated buffers with
1546    /// row padding (e.g. V4L2 or GStreamer allocators).
1547    ///
1548    /// The stride is propagated to the EGL DMA-BUF import attributes so
1549    /// the GPU interprets the padded buffer layout correctly. Must be
1550    /// called after [`set_format`](Self::set_format) and before the tensor
1551    /// is first passed to [`ImageProcessor::convert`]. The stored stride
1552    /// is cleared automatically if the pixel format is later changed.
1553    ///
1554    /// No stride-vs-buffer-size validation is performed because the
1555    /// backing allocation size is not reliably known: external DMA-BUFs
1556    /// may be over-allocated by the allocator, and internal tensors store
1557    /// a logical (unpadded) shape. An incorrect stride will be caught by
1558    /// the EGL driver at import time.
1559    ///
1560    /// # Arguments
1561    ///
1562    /// * `stride` - Row stride in bytes. Must be >= the minimum stride for
1563    ///   the format (width * channels * sizeof(T) for packed,
1564    ///   width * sizeof(T) for planar/semi-planar).
1565    ///
1566    /// # Errors
1567    ///
1568    /// * `InvalidArgument` if no pixel format is set on this tensor
1569    /// * `InvalidArgument` if `stride` is less than the minimum for the
1570    ///   format and width
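    ///
    /// # Examples
    ///
    /// A minimal sketch of the metadata half of the import flow, assuming an
    /// externally padded 640×480 BGRA buffer with a 3072-byte pitch (the
    /// actual fd import step is omitted here):
    ///
    /// ```rust
    /// use edgefirst_tensor::{Error, PixelFormat, Tensor};
    /// # fn main() -> Result<(), Error> {
    /// let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None)?;
    /// t.set_format(PixelFormat::Bgra)?;
    /// t.set_row_stride(3072)?; // >= minimum 640 * 4 * 1 = 2560
    /// assert_eq!(t.effective_row_stride(), Some(3072));
    /// # Ok(())
    /// # }
    /// ```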
1571    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
1572        let fmt = self.format.ok_or_else(|| {
1573            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
1574        })?;
1575        let w = self.width().ok_or_else(|| {
1576            Error::InvalidArgument("cannot determine width for row_stride validation".into())
1577        })?;
1578        let elem = std::mem::size_of::<T>();
1579        let min_stride = match fmt.layout() {
1580            PixelLayout::Packed => w * fmt.channels() * elem,
1581            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
1582        };
1583        if stride < min_stride {
1584            return Err(Error::InvalidArgument(format!(
1585                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
1586            )));
1587        }
1588        self.row_stride = Some(stride);
1589        Ok(())
1590    }
1591
1592    /// Set the row stride without format validation.
1593    ///
1594    /// Use this for raw sub-tensors (e.g. chroma planes) that don't carry
1595    /// format metadata. The caller is responsible for ensuring the stride
1596    /// is valid.
1597    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
1598        self.row_stride = Some(stride);
1599    }
1600
1601    /// Builder-style variant of [`set_row_stride`](Self::set_row_stride),
1602    /// consuming and returning `self`.
1603    ///
1604    /// # Errors
1605    ///
1606    /// Same conditions as [`set_row_stride`](Self::set_row_stride).
1607    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
1608        self.set_row_stride(stride)?;
1609        Ok(self)
1610    }
1611
1612    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
1613    pub fn plane_offset(&self) -> Option<usize> {
1614        self.plane_offset
1615    }
1616
1617    /// Set the byte offset within the DMA-BUF where image data starts.
1618    ///
1619    /// Propagated to `EGL_DMA_BUF_PLANE0_OFFSET_EXT` on GPU import.
1620    /// Unlike [`set_row_stride`](Self::set_row_stride), no format is required
1621    /// since the offset is format-independent.
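    ///
    /// A builder-style sketch with a hypothetical 4096-byte data offset
    /// (the offset is only honoured for DMA-backed tensors at map/import
    /// time; a plain tensor is used here just to show the call shape):
    ///
    /// ```rust
    /// use edgefirst_tensor::{Error, Tensor};
    /// # fn main() -> Result<(), Error> {
    /// let t = Tensor::<u8>::new(&[64], None, None)?.with_plane_offset(4096);
    /// assert_eq!(t.plane_offset(), Some(4096));
    /// # Ok(())
    /// # }
    /// ```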
1622    pub fn set_plane_offset(&mut self, offset: usize) {
1623        self.plane_offset = Some(offset);
1624        #[cfg(target_os = "linux")]
1625        if let TensorStorage::Dma(ref mut dma) = self.storage {
1626            dma.mmap_offset = offset;
1627        }
1628    }
1629
1630    /// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset),
1631    /// consuming and returning `self`.
1632    pub fn with_plane_offset(mut self, offset: usize) -> Self {
1633        self.set_plane_offset(offset);
1634        self
1635    }
1636
1637    /// Downcast to PBO tensor reference (for GL backends).
1638    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1639        match &self.storage {
1640            TensorStorage::Pbo(p) => Some(p),
1641            _ => None,
1642        }
1643    }
1644
1645    /// Downcast to DMA tensor reference (for EGL import, G2D).
1646    #[cfg(target_os = "linux")]
1647    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1648        match &self.storage {
1649            TensorStorage::Dma(d) => Some(d),
1650            _ => None,
1651        }
1652    }
1653
1654    /// Borrow the DMA-BUF file descriptor backing this tensor.
1655    ///
1656    /// # Returns
1657    ///
1658    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
1659    /// lifetime.
1660    ///
1661    /// # Errors
1662    ///
1663    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
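    ///
    /// A minimal Linux-only sketch (compile-only, since it needs a real
    /// DMA heap at run time):
    ///
    /// ```rust,no_run
    /// use edgefirst_tensor::{Tensor, TensorMemory};
    /// # fn main() -> Result<(), edgefirst_tensor::Error> {
    /// let t = Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None)?;
    /// let _fd = t.dmabuf()?; // borrowed for as long as `t` lives
    /// # Ok(())
    /// # }
    /// ```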
1664    #[cfg(target_os = "linux")]
1665    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1666        use std::os::fd::AsFd;
1667        match &self.storage {
1668            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1669            _ => Err(Error::NotImplemented(format!(
1670                "dmabuf requires DMA-backed tensor, got {:?}",
1671                self.storage.memory()
1672            ))),
1673        }
1674    }
1675
1676    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
1677    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1678        Self {
1679            storage: TensorStorage::Pbo(pbo),
1680            format: None,
1681            chroma: None,
1682            row_stride: None,
1683            plane_offset: None,
1684            quantization: None,
1685        }
1686    }
1687}
1688
1689// Quantization accessors — type-gated to integer element types via the
1690// sealed `IntegerType` trait. Calling `.quantization()` on a `Tensor<f32>`
1691// produces a compile error, not a runtime one.
1692impl<T> Tensor<T>
1693where
1694    T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
1695{
1696    /// Quantization metadata for this tensor, if set.
1697    pub fn quantization(&self) -> Option<&Quantization> {
1698        self.quantization.as_ref()
1699    }
1700
1701    /// Attach quantization metadata to this tensor. Validates against the
1702    /// tensor's shape — returns [`Error::QuantizationInvalid`] on any
1703    /// inconsistency (mismatched scale/zp lengths, out-of-range axis, etc.).
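    ///
    /// # Examples
    ///
    /// A minimal sketch, following the same pattern as the quantization
    /// round-trip test below:
    ///
    /// ```rust
    /// use edgefirst_tensor::{Error, Quantization, Tensor, TensorMemory};
    /// # fn main() -> Result<(), Error> {
    /// let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None)?;
    /// t.set_quantization(Quantization::per_tensor(0.1, -5))?;
    /// assert_eq!(t.quantization().unwrap().scale(), &[0.1]);
    /// # Ok(())
    /// # }
    /// ```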
1704    pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
1705        q.validate(self.shape())?;
1706        self.quantization = Some(q);
1707        Ok(())
1708    }
1709
1710    /// Builder-style variant of [`Self::set_quantization`]. Consumes `self`
1711    /// and returns `Result<Self>` — on success yields the tensor with the
1712    /// attached quantization; on validation failure returns
1713    /// [`Error::QuantizationInvalid`] and drops `self` (the tensor is not
1714    /// returned in the error arm).
1715    pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
1716        self.set_quantization(q)?;
1717        Ok(self)
1718    }
1719
1720    /// Clear any quantization metadata on this tensor.
1721    pub fn clear_quantization(&mut self) {
1722        self.quantization = None;
1723    }
1724}
1725
1726impl<T> TensorTrait<T> for Tensor<T>
1727where
1728    T: Num + Clone + fmt::Debug + Send + Sync,
1729{
1730    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1731    where
1732        Self: Sized,
1733    {
1734        Self::new(shape, None, name)
1735    }
1736
1737    #[cfg(unix)]
1738    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1739    where
1740        Self: Sized,
1741    {
1742        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1743    }
1744
1745    #[cfg(unix)]
1746    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1747        self.storage.clone_fd()
1748    }
1749
1750    fn memory(&self) -> TensorMemory {
1751        self.storage.memory()
1752    }
1753
1754    fn name(&self) -> String {
1755        self.storage.name()
1756    }
1757
1758    fn shape(&self) -> &[usize] {
1759        self.storage.shape()
1760    }
1761
1762    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1763        if self.chroma.is_some() {
1764            return Err(Error::InvalidOperation(
1765                "cannot reshape a multiplane tensor — decompose planes first".into(),
1766            ));
1767        }
1768        self.storage.reshape(shape)?;
1769        self.format = None;
1770        self.row_stride = None;
1771        self.plane_offset = None;
1772        #[cfg(target_os = "linux")]
1773        if let TensorStorage::Dma(ref mut dma) = self.storage {
1774            dma.mmap_offset = 0;
1775        }
1776        Ok(())
1777    }
1778
1779    fn map(&self) -> Result<TensorMap<T>> {
1780        let _span = tracing::trace_span!(
1781            "tensor_map",
1782            memory = ?self.storage.memory(),
1783        )
1784        .entered();
1785        // CPU mapping of strided tensors is allowed only when the HAL
1786        // owns the underlying allocation — i.e. self-allocated DMA
1787        // tensors with pitch padding added by `image_with_stride()`
1788        // for GPU import alignment. In that case we know the buffer
1789        // is exactly `row_stride × height` bytes (for packed formats)
1790        // and callers that respect the stride can iterate rows
1791        // correctly via `effective_row_stride()`.
1792        //
1793        // Foreign DMA-BUFs imported via `from_fd()` + `set_row_stride()`
1794        // (the V4L2 / GStreamer case) are rejected: their layout comes
1795        // from an external allocator and the HAL cannot validate what
1796        // the caller expects the mapping to look like. Those tensors
1797        // are intended for the GPU path only.
1798        //
1799        // The cfg split keeps `stride` from being an unused binding on
1800        // non-Linux builds (the Linux branch is the only consumer).
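        //
        // Caller-side sketch of the supported pattern (`row_bytes` here
        // stands for width × channels × size_of::<T>(); see the stride
        // round-trip test below for the concrete version):
        //
        //     let stride = tensor.effective_row_stride().unwrap();
        //     let map = tensor.map()?;
        //     for y in 0..tensor.height().unwrap() {
        //         let row = &map.as_slice()[y * stride..y * stride + row_bytes];
        //         // process one row, ignoring the padding tail
        //     }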
1801        #[cfg(target_os = "linux")]
1802        if let Some(stride) = self.row_stride {
1803            if let TensorStorage::Dma(dma) = &self.storage {
1804                if !dma.is_imported {
1805                    // Self-allocated strided DMA tensor — expose the
1806                    // full stride×height padded mmap via the override
1807                    // constructor so callers can iterate rows with
1808                    // `effective_row_stride()` without going past
1809                    // the end of the returned slice.
1810                    //
1811                    // Validate the requested mapping fits inside the
1812                    // actual DMA-BUF. `set_row_stride()` is a public
1813                    // API and only validates `stride >= min_stride`,
1814                    // not `stride × height <= buf_size`, so a caller
1815                    // that tampers with the stride after allocation
1816                    // could otherwise request a slice larger than the
1817                    // underlying mmap — which would be undefined
1818                    // behaviour in `DmaMap::as_slice`.
1819                    //
1820                    // Refuse to map if `height()` can't be derived
1821                    // (e.g. raw 2D tensors without a PixelFormat that
1822                    // got a `row_stride` set via `set_row_stride_unchecked`).
1823                    // Returning a 0-byte view would silently truncate
1824                    // rather than surface the misuse.
1825                    let height = self.height().ok_or_else(|| {
1826                        Error::InvalidOperation(
1827                            "Tensor::map: strided DMA mapping requires a PixelFormat \
1828                             so height() can be derived; set a format before mapping \
1829                             or clear row_stride for raw tensor access"
1830                                .into(),
1831                        )
1832                    })?;
1833                    let total_bytes = stride.checked_mul(height).ok_or_else(|| {
1834                        Error::InvalidOperation(format!(
1835                            "Tensor::map: row_stride {stride} × height {height} overflows usize"
1836                        ))
1837                    })?;
1838                    let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
1839                    if total_bytes > available_bytes {
1840                        return Err(Error::InvalidOperation(format!(
1841                            "Tensor::map: strided mapping needs {total_bytes} bytes \
1842                             but DMA buffer only has {available_bytes} available \
1843                             (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
1844                             the row_stride was likely set larger than the original allocation",
1845                            dma.buf_size, dma.mmap_offset
1846                        )));
1847                    }
1848                    return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
1849                }
1850            }
1851            return Err(Error::InvalidOperation(
1852                "CPU mapping of strided foreign tensors is not supported; \
1853                 use GPU path only"
1854                    .into(),
1855            ));
1856        }
1857        #[cfg(not(target_os = "linux"))]
1858        if self.row_stride.is_some() {
1859            return Err(Error::InvalidOperation(
1860                "CPU mapping of strided tensors is not supported on this \
1861                 platform (DMA backing is Linux-only)"
1862                    .into(),
1863            ));
1864        }
1865        // Offset tensors are supported for DMA storage: DmaMap adjusts the
1866        // mmap range and slice start position. A plane offset on non-DMA
1867        // storage is not meaningful (it only applies to DMA-BUF sub-regions).
1868        if self.plane_offset.is_some_and(|o| o > 0) {
1869            #[cfg(target_os = "linux")]
1870            if !matches!(self.storage, TensorStorage::Dma(_)) {
1871                return Err(Error::InvalidOperation(
1872                    "plane offset only supported for DMA tensors".into(),
1873                ));
1874            }
1875            #[cfg(not(target_os = "linux"))]
1876            return Err(Error::InvalidOperation(
1877                "plane offset only supported for DMA tensors".into(),
1878            ));
1879        }
1880        self.storage.map()
1881    }
1882
1883    fn buffer_identity(&self) -> &BufferIdentity {
1884        self.storage.buffer_identity()
1885    }
1886}
1887
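/// A CPU-visible mapping of a tensor's backing storage. The variant mirrors
/// the [`TensorMemory`] backend; every variant exposes the element data as a
/// `[T]` slice through [`TensorMapTrait`] and `Deref`/`DerefMut`.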
1888pub enum TensorMap<T>
1889where
1890    T: Num + Clone + fmt::Debug,
1891{
1892    #[cfg(target_os = "linux")]
1893    Dma(DmaMap<T>),
1894    #[cfg(unix)]
1895    Shm(ShmMap<T>),
1896    Mem(MemMap<T>),
1897    Pbo(PboMap<T>),
1898}
1899
1900impl<T> TensorMapTrait<T> for TensorMap<T>
1901where
1902    T: Num + Clone + fmt::Debug,
1903{
1904    fn shape(&self) -> &[usize] {
1905        match self {
1906            #[cfg(target_os = "linux")]
1907            TensorMap::Dma(map) => map.shape(),
1908            #[cfg(unix)]
1909            TensorMap::Shm(map) => map.shape(),
1910            TensorMap::Mem(map) => map.shape(),
1911            TensorMap::Pbo(map) => map.shape(),
1912        }
1913    }
1914
1915    fn unmap(&mut self) {
1916        match self {
1917            #[cfg(target_os = "linux")]
1918            TensorMap::Dma(map) => map.unmap(),
1919            #[cfg(unix)]
1920            TensorMap::Shm(map) => map.unmap(),
1921            TensorMap::Mem(map) => map.unmap(),
1922            TensorMap::Pbo(map) => map.unmap(),
1923        }
1924    }
1925
1926    fn as_slice(&self) -> &[T] {
1927        match self {
1928            #[cfg(target_os = "linux")]
1929            TensorMap::Dma(map) => map.as_slice(),
1930            #[cfg(unix)]
1931            TensorMap::Shm(map) => map.as_slice(),
1932            TensorMap::Mem(map) => map.as_slice(),
1933            TensorMap::Pbo(map) => map.as_slice(),
1934        }
1935    }
1936
1937    fn as_mut_slice(&mut self) -> &mut [T] {
1938        match self {
1939            #[cfg(target_os = "linux")]
1940            TensorMap::Dma(map) => map.as_mut_slice(),
1941            #[cfg(unix)]
1942            TensorMap::Shm(map) => map.as_mut_slice(),
1943            TensorMap::Mem(map) => map.as_mut_slice(),
1944            TensorMap::Pbo(map) => map.as_mut_slice(),
1945        }
1946    }
1947}
1948
1949impl<T> Deref for TensorMap<T>
1950where
1951    T: Num + Clone + fmt::Debug,
1952{
1953    type Target = [T];
1954
1955    fn deref(&self) -> &[T] {
1956        match self {
1957            #[cfg(target_os = "linux")]
1958            TensorMap::Dma(map) => map.deref(),
1959            #[cfg(unix)]
1960            TensorMap::Shm(map) => map.deref(),
1961            TensorMap::Mem(map) => map.deref(),
1962            TensorMap::Pbo(map) => map.deref(),
1963        }
1964    }
1965}
1966
1967impl<T> DerefMut for TensorMap<T>
1968where
1969    T: Num + Clone + fmt::Debug,
1970{
1971    fn deref_mut(&mut self) -> &mut [T] {
1972        match self {
1973            #[cfg(target_os = "linux")]
1974            TensorMap::Dma(map) => map.deref_mut(),
1975            #[cfg(unix)]
1976            TensorMap::Shm(map) => map.deref_mut(),
1977            TensorMap::Mem(map) => map.deref_mut(),
1978            TensorMap::Pbo(map) => map.deref_mut(),
1979        }
1980    }
1981}
1982
1983// ============================================================================
1984// Platform availability helpers
1985// ============================================================================
1986
1987 /// Cached result of the first DMA availability probe.
1988 #[cfg(target_os = "linux")]
1989 static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1990 
1991 /// Check if DMA memory allocation is available on this system.
1992 ///
1993 /// Returns `true` only on Linux systems with DMA-BUF heap access (typically
1994 /// requires running as root or membership in a video/render group).
1995 /// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
1996 ///
1997 /// The result of the first probe is cached for efficiency.
1998 #[cfg(target_os = "linux")]
1999 pub fn is_dma_available() -> bool {
2000     *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
2001 }
2002
2003/// Check if DMA memory allocation is available on this system.
2004///
2005/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
2006#[cfg(not(target_os = "linux"))]
2007pub fn is_dma_available() -> bool {
2008    false
2009}
2010
2011 /// Cached result of the first SHM availability probe.
2012 #[cfg(unix)]
2013 static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
2014 
2015 /// Check if POSIX shared memory allocation is available on this system.
2016 ///
2017 /// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
2018 /// is supported. Always returns `false` on non-Unix platforms (Windows).
2019 ///
2020 /// The result of the first probe is cached for efficiency.
2021 #[cfg(unix)]
2022 pub fn is_shm_available() -> bool {
2023     *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
2024 }
2025
2026/// Check if POSIX shared memory allocation is available on this system.
2027///
2028/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
2029#[cfg(not(unix))]
2030pub fn is_shm_available() -> bool {
2031    false
2032}
2033
2034#[cfg(test)]
2035mod dtype_tests {
2036    use super::*;
2037
2038    #[test]
2039    fn dtype_size() {
2040        assert_eq!(DType::U8.size(), 1);
2041        assert_eq!(DType::I8.size(), 1);
2042        assert_eq!(DType::U16.size(), 2);
2043        assert_eq!(DType::I16.size(), 2);
2044        assert_eq!(DType::U32.size(), 4);
2045        assert_eq!(DType::I32.size(), 4);
2046        assert_eq!(DType::U64.size(), 8);
2047        assert_eq!(DType::I64.size(), 8);
2048        assert_eq!(DType::F16.size(), 2);
2049        assert_eq!(DType::F32.size(), 4);
2050        assert_eq!(DType::F64.size(), 8);
2051    }
2052
2053    #[test]
2054    fn dtype_name() {
2055        assert_eq!(DType::U8.name(), "u8");
2056        assert_eq!(DType::F16.name(), "f16");
2057        assert_eq!(DType::F32.name(), "f32");
2058    }
2059
2060    #[test]
2061    fn dtype_serde_roundtrip() {
2062        use serde_json;
2063        let dt = DType::F16;
2064        let json = serde_json::to_string(&dt).unwrap();
2065        let back: DType = serde_json::from_str(&json).unwrap();
2066        assert_eq!(dt, back);
2067    }
2068}
2069
2070#[cfg(test)]
2071mod image_tests {
2072    use super::*;
2073
2074    #[test]
2075    fn raw_tensor_has_no_format() {
2076        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
2077        assert!(t.format().is_none());
2078        assert!(t.width().is_none());
2079        assert!(t.height().is_none());
2080        assert!(!t.is_multiplane());
2081        assert!(t.chroma().is_none());
2082    }
2083
2084    #[test]
2085    fn image_tensor_packed() {
2086        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
2087        assert_eq!(t.format(), Some(PixelFormat::Rgba));
2088        assert_eq!(t.width(), Some(640));
2089        assert_eq!(t.height(), Some(480));
2090        assert_eq!(t.shape(), &[480, 640, 4]);
2091        assert!(!t.is_multiplane());
2092    }
2093
2094    #[test]
2095    fn image_tensor_planar() {
2096        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
2097        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
2098        assert_eq!(t.width(), Some(640));
2099        assert_eq!(t.height(), Some(480));
2100        assert_eq!(t.shape(), &[3, 480, 640]);
2101    }
2102
2103    #[test]
2104    fn image_tensor_semi_planar_contiguous() {
2105        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
2106        assert_eq!(t.format(), Some(PixelFormat::Nv12));
2107        assert_eq!(t.width(), Some(640));
2108        assert_eq!(t.height(), Some(480));
2109        // NV12: H*3/2 = 720
2110        assert_eq!(t.shape(), &[720, 640]);
2111        assert!(!t.is_multiplane());
2112    }
2113
2114    #[test]
2115    #[cfg(target_os = "linux")]
2116    fn image_tensor_with_stride_preserves_logical_width() {
2117        // Skip if DMA not available (e.g. sandboxed CI lacking dma_heap access).
2118        if !is_dma_available() {
2119            eprintln!("SKIPPED: DMA heap not available");
2120            return;
2121        }
2122        // 3004×1688 RGBA8: natural pitch 12016, padded to 12032 (64-aligned).
2123        let stride = 12032;
2124        let t = Tensor::<u8>::image_with_stride(
2125            3004,
2126            1688,
2127            PixelFormat::Rgba,
2128            stride,
2129            Some(TensorMemory::Dma),
2130        )
2131        .unwrap();
2132        // Logical dimensions unchanged by padding — this is the contract.
2133        assert_eq!(t.width(), Some(3004));
2134        assert_eq!(t.height(), Some(1688));
2135        assert_eq!(t.shape(), &[1688, 3004, 4]);
2136        // Stride is carried separately and reports the padded pitch.
2137        assert_eq!(t.effective_row_stride(), Some(stride));
2138        // Buffer is sized to stride × height so the full padded layout fits,
2139        // and CPU map() works for self-allocated strided DMA tensors.
2140        use crate::TensorMapTrait;
2141        {
2142            let map = t.map().unwrap();
2143            assert!(
2144                map.as_slice().len() >= stride * 1688,
2145                "mapped buffer {} bytes < expected {}",
2146                map.as_slice().len(),
2147                stride * 1688
2148            );
2149        }
2150        // CPU write access works too — iterate rows using the padded stride,
2151        // touch only the active `width × bpp` region, verify it round-trips.
2152        {
2153            let mut map = t.map().unwrap();
2154            let slice = map.as_mut_slice();
2155            for y in 0..1688 {
2156                let row_start = y * stride;
2157                for x in 0..3004 {
2158                    let p = row_start + x * 4;
2159                    slice[p] = (y & 0xFF) as u8;
2160                    slice[p + 1] = (x & 0xFF) as u8;
2161                    slice[p + 2] = 0x42;
2162                    slice[p + 3] = 0xFF;
2163                }
2164            }
2165        }
2166        {
2167            let map = t.map().unwrap();
2168            let slice = map.as_slice();
2169            // Sample a few pixels to confirm the round-trip.
2170            assert_eq!(slice[0], 0x00);
2171            assert_eq!(slice[1], 0x00);
2172            assert_eq!(slice[2], 0x42);
2173            assert_eq!(slice[3], 0xFF);
2174            let mid = 100 * stride + 50 * 4;
2175            assert_eq!(slice[mid], 100);
2176            assert_eq!(slice[mid + 1], 50);
2177            assert_eq!(slice[mid + 2], 0x42);
2178        }
2179    }
2180
2181    #[test]
2182    #[cfg(target_os = "linux")]
2183    fn image_tensor_with_stride_rejects_foreign_strided_map() {
2184        // A FOREIGN (imported via from_fd) DMA tensor with row_stride set
2185        // should still refuse CPU mapping — external allocator owns the
2186        // layout. This protects the V4L2 / GStreamer use case.
2187        //
2188        // We simulate a foreign import by wrapping our own allocation's
2189        // fd via `from_fd` and calling `set_row_stride` manually; tensors
2190        // created via `from_fd` are marked `is_imported` by construction.
2191        if !is_dma_available() {
2192            eprintln!("SKIPPED: DMA heap not available");
2193            return;
2194        }
2195        // Allocate a backing buffer large enough for a 320×240 BGRA8 image.
2196        let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
2197        let fd = backing.clone_fd().unwrap();
2198        // Import it via from_fd — this marks is_imported=true.
2199        let shape = [240usize, 320, 4];
2200        let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
2201        let mut t = Tensor::<u8>::wrap(storage);
2202        t.set_format(PixelFormat::Bgra).unwrap();
2203        t.set_row_stride(320 * 4).unwrap(); // natural, but still marks it as strided
2204        let err = t.map();
2205        assert!(
2206            matches!(err, Err(Error::InvalidOperation(_))),
2207            "foreign strided map should error"
2208        );
2209    }
2210
2211    #[test]
2212    #[cfg(target_os = "linux")]
2213    fn image_tensor_with_stride_map_rejects_tampered_stride() {
2214        // Round-3 PR feedback (C1): `set_row_stride` is public and only
2215        // validates `stride >= min_stride`, not that the new stride × height
2216        // fits the underlying buffer. A caller that tampers with the stride
2217        // after allocation must not be able to coerce `Tensor::map()` into
2218        // returning a slice larger than the backing mmap (that would be UB
2219        // in `DmaMap::as_slice`).
2220        if !is_dma_available() {
2221            eprintln!("SKIPPED: DMA heap not available");
2222            return;
2223        }
2224        // Allocate a 640×480 RGBA8 padded canvas (stride = 3072 = 768 px).
2225        // Backing buffer is 3072 × 480 = 1,474,560 bytes.
2226        let mut t = Tensor::<u8>::image_with_stride(
2227            640,
2228            480,
2229            PixelFormat::Rgba,
2230            3072,
2231            Some(TensorMemory::Dma),
2232        )
2233        .unwrap();
2234        // Tamper: push the stride up to 4 × the original. This is >=
2235        // min_stride (2560), so `set_row_stride` accepts it.
2236        t.set_row_stride(12288).unwrap();
2237        // Map must now refuse — 12288 × 480 = 5,898,240 > 1,474,560.
2238        let err = t.map();
2239        assert!(
2240            matches!(err, Err(Error::InvalidOperation(_))),
2241            "map() with oversized stride must return InvalidOperation"
2242        );
2243    }
2244
2245    #[test]
2246    fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
2247        // Round-3 PR feedback (C3): shape.product() * sizeof(T) must use
2248        // checked arithmetic so a pathological shape can't wrap usize and
2249        // make the byte_size-vs-logical-size comparison incorrect.
2250        //
2251        // This test only exercises the overflow rejection path, which is
2252        // pure-Rust and doesn't touch dma_heap — safe to run on any target.
2253        #[cfg(target_os = "linux")]
2254        {
2255            let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
2256                &[usize::MAX, 2, 2],
2257                usize::MAX,
2258                None,
2259            );
2260            assert!(
2261                matches!(err, Err(Error::InvalidArgument(_))),
2262                "new_with_byte_size must detect shape.product() overflow"
2263            );
2264        }
2265    }
2266
2267    #[test]
2268    #[cfg(target_os = "linux")]
2269    fn image_tensor_with_stride_rejects_too_small_stride() {
2270        // 640×480 RGBA8 natural pitch = 2560, request 2400 → should error.
2271        let err = Tensor::<u8>::image_with_stride(
2272            640,
2273            480,
2274            PixelFormat::Rgba,
2275            2400,
2276            Some(TensorMemory::Dma),
2277        );
2278        assert!(matches!(err, Err(Error::InvalidArgument(_))));
2279    }
2280
2281    #[test]
2282    #[cfg(target_os = "linux")]
2283    fn image_tensor_with_stride_rejects_non_packed() {
2284        // NV12 is SemiPlanar → not supported. (Linux-only because
2285        // `TensorMemory::Dma` itself is a Linux-only enum variant.)
2286        let err = Tensor::<u8>::image_with_stride(
2287            640,
2288            480,
2289            PixelFormat::Nv12,
2290            640,
2291            Some(TensorMemory::Dma),
2292        );
2293        assert!(matches!(err, Err(Error::NotImplemented(_))));
2294    }
2295
2296    #[test]
2297    fn set_format_valid() {
2298        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
2299        assert!(t.format().is_none());
2300        t.set_format(PixelFormat::Rgb).unwrap();
2301        assert_eq!(t.format(), Some(PixelFormat::Rgb));
2302        assert_eq!(t.width(), Some(640));
2303        assert_eq!(t.height(), Some(480));
2304    }
2305
2306    #[test]
2307    fn set_format_invalid_shape() {
2308        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
2309        // RGB expects 3 channels, not 4
2310        let err = t.set_format(PixelFormat::Rgb);
2311        assert!(err.is_err());
2312        // Original tensor is unmodified
2313        assert!(t.format().is_none());
2314    }
2315
2316    #[test]
2317    fn reshape_clears_format() {
2318        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
2319        assert_eq!(t.format(), Some(PixelFormat::Rgba));
2320        // Reshape to flat — format cleared
2321        t.reshape(&[480 * 640 * 4]).unwrap();
2322        assert!(t.format().is_none());
2323    }
2324
2325    #[test]
2326    fn from_planes_nv12() {
2327        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2328        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2329        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
2330        assert_eq!(img.format(), Some(PixelFormat::Nv12));
2331        assert!(img.is_multiplane());
2332        assert!(img.chroma().is_some());
2333        assert_eq!(img.width(), Some(640));
2334        assert_eq!(img.height(), Some(480));
2335    }
2336
2337    #[test]
2338    fn from_planes_rejects_non_semiplanar() {
2339        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2340        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2341        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
2342        assert!(err.is_err());
2343    }
2344
2345    #[test]
2346    fn reshape_multiplane_errors() {
2347        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
2348        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
2349        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
2350        let err = img.reshape(&[480 * 640 + 240 * 640]);
2351        assert!(err.is_err());
2352    }
2353}
2354
2355#[cfg(test)]
2356mod tests {
2357    #[cfg(target_os = "linux")]
2358    use nix::unistd::{access, AccessFlags};
2359    #[cfg(target_os = "linux")]
2360    use std::io::Write as _;
2361    use std::sync::RwLock;
2362
2363    use super::*;
2364
2365    #[ctor::ctor]
2366    fn init() {
2367        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
2368    }
2369
2370    /// Macro to get the current function name for logging in tests.
2371    #[cfg(target_os = "linux")]
2372    macro_rules! function {
2373        () => {{
2374            fn f() {}
2375            fn type_name_of<T>(_: T) -> &'static str {
2376                std::any::type_name::<T>()
2377            }
2378            let name = type_name_of(f);
2379
2380            // Find and cut the rest of the path
2381            match &name[..name.len() - 3].rfind(':') {
2382                Some(pos) => &name[pos + 1..name.len() - 3],
2383                None => &name[..name.len() - 3],
2384            }
2385        }};
2386    }
2387
2388    #[test]
2389    #[cfg(target_os = "linux")]
2390    fn test_tensor() {
2391        let _lock = FD_LOCK.read().unwrap();
2392        let shape = vec![1];
2393        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
2394        let dma_enabled = tensor.is_ok();
2395
2396        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2397        match dma_enabled {
2398            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
2399            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
2400        }
2401    }
2402
2403    #[test]
2404    #[cfg(all(unix, not(target_os = "linux")))]
2405    fn test_tensor() {
2406        let shape = vec![1];
2407        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2408        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
2409        assert!(
2410            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
2411            "Expected SHM or Mem on macOS, got {:?}",
2412            tensor.memory()
2413        );
2414    }
2415
2416    #[test]
2417    #[cfg(not(unix))]
2418    fn test_tensor() {
2419        let shape = vec![1];
2420        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2421        assert_eq!(tensor.memory(), TensorMemory::Mem);
2422    }
2423
2424    #[test]
2425    #[cfg(target_os = "linux")]
2426    fn test_dma_tensor() {
2427        let _lock = FD_LOCK.read().unwrap();
2428        match access(
2429            "/dev/dma_heap/linux,cma",
2430            AccessFlags::R_OK | AccessFlags::W_OK,
2431        ) {
2432            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
2433            Err(_) => match access(
2434                "/dev/dma_heap/system",
2435                AccessFlags::R_OK | AccessFlags::W_OK,
2436            ) {
2437                Ok(_) => println!("/dev/dma_heap/system is available"),
2438                Err(e) => {
2439                    writeln!(
2440                        &mut std::io::stdout(),
2441                        "[WARNING] DMA Heap is unavailable: {e}"
2442                    )
2443                    .unwrap();
2444                    return;
2445                }
2446            },
2447        }
2448
2449        let shape = vec![2, 3, 4];
2450        let tensor =
2451            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2452
2453        const DUMMY_VALUE: f32 = 12.34;
2454
2455        assert_eq!(tensor.memory(), TensorMemory::Dma);
2456        assert_eq!(tensor.name(), "test_tensor");
2457        assert_eq!(tensor.shape(), &shape);
2458        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2459        assert_eq!(tensor.len(), 2 * 3 * 4);
2460
2461        {
2462            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2463            tensor_map.fill(42.0);
2464            assert!(tensor_map.iter().all(|&x| x == 42.0));
2465        }
2466
2467        {
2468            let shared = Tensor::<f32>::from_fd(
2469                tensor
2470                    .clone_fd()
2471                    .expect("Failed to duplicate tensor file descriptor"),
2472                &shape,
2473                Some("test_tensor_shared"),
2474            )
2475            .expect("Failed to create tensor from fd");
2476
2477            assert_eq!(shared.memory(), TensorMemory::Dma);
2478            assert_eq!(shared.name(), "test_tensor_shared");
2479            assert_eq!(shared.shape(), &shape);
2480
2481            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
2482            tensor_map.fill(DUMMY_VALUE);
2483            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2484        }
2485
2486        {
2487            let tensor_map = tensor.map().expect("Failed to map DMA memory");
2488            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2489        }
2490
2491        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2492        assert_eq!(tensor.shape(), &shape);
2493        let new_shape = vec![3, 4, 4];
2494        assert!(
2495            tensor.reshape(&new_shape).is_err(),
2496            "Reshape should fail due to size mismatch"
2497        );
2498        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2499
2500        let new_shape = vec![2, 3, 4];
2501        tensor.reshape(&new_shape).expect("Reshape should succeed");
2502        assert_eq!(
2503            tensor.shape(),
2504            &new_shape,
2505            "Shape should be updated after successful reshape"
2506        );
2507
2508        {
2509            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2510            tensor_map.fill(1);
2511            assert!(tensor_map.iter().all(|&x| x == 1));
2512        }
2513
2514        {
2515            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
2516            tensor_map[2] = 42;
2517            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2518            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2519        }
2520    }
2521
2522    #[test]
2523    #[cfg(unix)]
2524    fn test_shm_tensor() {
2525        let _lock = FD_LOCK.read().unwrap();
2526        let shape = vec![2, 3, 4];
2527        let tensor =
2528            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2529        assert_eq!(tensor.shape(), &shape);
2530        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2531        assert_eq!(tensor.name(), "test_tensor");
2532
2533        const DUMMY_VALUE: f32 = 12.34;
2534        {
2535            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2536            tensor_map.fill(42.0);
2537            assert!(tensor_map.iter().all(|&x| x == 42.0));
2538        }
2539
2540        {
2541            let shared = Tensor::<f32>::from_fd(
2542                tensor
2543                    .clone_fd()
2544                    .expect("Failed to duplicate tensor file descriptor"),
2545                &shape,
2546                Some("test_tensor_shared"),
2547            )
2548            .expect("Failed to create tensor from fd");
2549
2550            assert_eq!(shared.memory(), TensorMemory::Shm);
2551            assert_eq!(shared.name(), "test_tensor_shared");
2552            assert_eq!(shared.shape(), &shape);
2553
2554            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2555            tensor_map.fill(DUMMY_VALUE);
2556            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2557        }
2558
2559        {
2560            let tensor_map = tensor.map().expect("Failed to map shared memory");
2561            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2562        }
2563
2564        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2565        assert_eq!(tensor.shape(), &shape);
2566        let new_shape = vec![3, 4, 4];
2567        assert!(
2568            tensor.reshape(&new_shape).is_err(),
2569            "Reshape should fail due to size mismatch"
2570        );
2571        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2572
2573        let new_shape = vec![2, 3, 4];
2574        tensor.reshape(&new_shape).expect("Reshape should succeed");
2575        assert_eq!(
2576            tensor.shape(),
2577            &new_shape,
2578            "Shape should be updated after successful reshape"
2579        );
2580
2581        {
2582            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2583            tensor_map.fill(1);
2584            assert!(tensor_map.iter().all(|&x| x == 1));
2585        }
2586
2587        {
2588            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2589            tensor_map[2] = 42;
2590            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2591            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2592        }
2593    }
2594
2595    #[test]
2596    fn test_mem_tensor() {
2597        let shape = vec![2, 3, 4];
2598        let tensor =
2599            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2600        assert_eq!(tensor.shape(), &shape);
2601        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2602        assert_eq!(tensor.name(), "test_tensor");
2603
2604        {
2605            let mut tensor_map = tensor.map().expect("Failed to map memory");
2606            tensor_map.fill(42.0);
2607            assert!(tensor_map.iter().all(|&x| x == 42.0));
2608        }
2609
2610        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2611        assert_eq!(tensor.shape(), &shape);
2612        let new_shape = vec![3, 4, 4];
2613        assert!(
2614            tensor.reshape(&new_shape).is_err(),
2615            "Reshape should fail due to size mismatch"
2616        );
2617        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2618
2619        let new_shape = vec![2, 3, 4];
2620        tensor.reshape(&new_shape).expect("Reshape should succeed");
2621        assert_eq!(
2622            tensor.shape(),
2623            &new_shape,
2624            "Shape should be updated after successful reshape"
2625        );
2626
2627        {
2628            let mut tensor_map = tensor.map().expect("Failed to map memory");
2629            tensor_map.fill(1);
2630            assert!(tensor_map.iter().all(|&x| x == 1));
2631        }
2632
2633        {
2634            let mut tensor_map = tensor.map().expect("Failed to map memory");
2635            tensor_map[2] = 42;
2636            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2637            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2638        }
2639    }
2640
2641    #[test]
2642    #[cfg(target_os = "linux")]
2643    fn test_dma_no_fd_leaks() {
2644        let _lock = FD_LOCK.write().unwrap();
2645        if !is_dma_available() {
2646            log::warn!(
2647                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2648                function!()
2649            );
2650            return;
2651        }
2652
2653        let proc = procfs::process::Process::myself()
2654            .expect("Failed to get current process using /proc/self");
2655
2656        let start_open_fds = proc
2657            .fd_count()
2658            .expect("Failed to get open file descriptor count");
2659
2660        for _ in 0..100 {
2661            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
2662                .expect("Failed to create tensor");
2663            let mut map = tensor.map().unwrap();
2664            map.as_mut_slice().fill(233);
2665        }
2666
2667        let end_open_fds = proc
2668            .fd_count()
2669            .expect("Failed to get open file descriptor count");
2670
2671        assert_eq!(
2672            start_open_fds, end_open_fds,
2673            "File descriptor leak detected: {} -> {}",
2674            start_open_fds, end_open_fds
2675        );
2676    }
2677
2678    #[test]
2679    #[cfg(target_os = "linux")]
2680    fn test_dma_from_fd_no_fd_leaks() {
2681        let _lock = FD_LOCK.write().unwrap();
2682        if !is_dma_available() {
2683            log::warn!(
2684                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2685                function!()
2686            );
2687            return;
2688        }
2689
2690        let proc = procfs::process::Process::myself()
2691            .expect("Failed to get current process using /proc/self");
2692
2693        let start_open_fds = proc
2694            .fd_count()
2695            .expect("Failed to get open file descriptor count");
2696
2697        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
2698
2699        for _ in 0..100 {
2700            let tensor =
2701                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2702            let mut map = tensor.map().unwrap();
2703            map.as_mut_slice().fill(233);
2704        }
2705        drop(orig);
2706
2707        let end_open_fds = proc.fd_count().unwrap();
2708
2709        assert_eq!(
2710            start_open_fds, end_open_fds,
2711            "File descriptor leak detected: {} -> {}",
2712            start_open_fds, end_open_fds
2713        );
2714    }
2715
2716    #[test]
2717    #[cfg(target_os = "linux")]
2718    fn test_shm_no_fd_leaks() {
2719        let _lock = FD_LOCK.write().unwrap();
2720        if !is_shm_available() {
2721            log::warn!(
2722                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2723                function!()
2724            );
2725            return;
2726        }
2727
2728        let proc = procfs::process::Process::myself()
2729            .expect("Failed to get current process using /proc/self");
2730
2731        let start_open_fds = proc
2732            .fd_count()
2733            .expect("Failed to get open file descriptor count");
2734
2735        for _ in 0..100 {
2736            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2737                .expect("Failed to create tensor");
2738            let mut map = tensor.map().unwrap();
2739            map.as_mut_slice().fill(233);
2740        }
2741
2742        let end_open_fds = proc
2743            .fd_count()
2744            .expect("Failed to get open file descriptor count");
2745
2746        assert_eq!(
2747            start_open_fds, end_open_fds,
2748            "File descriptor leak detected: {} -> {}",
2749            start_open_fds, end_open_fds
2750        );
2751    }
2752
2753    #[test]
2754    #[cfg(target_os = "linux")]
2755    fn test_shm_from_fd_no_fd_leaks() {
2756        let _lock = FD_LOCK.write().unwrap();
2757        if !is_shm_available() {
2758            log::warn!(
2759                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2760                function!()
2761            );
2762            return;
2763        }
2764
2765        let proc = procfs::process::Process::myself()
2766            .expect("Failed to get current process using /proc/self");
2767
2768        let start_open_fds = proc
2769            .fd_count()
2770            .expect("Failed to get open file descriptor count");
2771
2772        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
2773
2774        for _ in 0..100 {
2775            let tensor =
2776                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2777            let mut map = tensor.map().unwrap();
2778            map.as_mut_slice().fill(233);
2779        }
2780        drop(orig);
2781
2782        let end_open_fds = proc.fd_count().unwrap();
2783
2784        assert_eq!(
2785            start_open_fds, end_open_fds,
2786            "File descriptor leak detected: {} -> {}",
2787            start_open_fds, end_open_fds
2788        );
2789    }
2790
2791    #[cfg(feature = "ndarray")]
2792    #[test]
2793    fn test_ndarray() {
2794        let _lock = FD_LOCK.read().unwrap();
2795        let shape = vec![2, 3, 4];
2796        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2797
2798        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
2799        tensor_map.fill(1.0);
2800
2801        let view = tensor_map.view().expect("Failed to get ndarray view");
2802        assert_eq!(view.shape(), &[2, 3, 4]);
2803        assert!(view.iter().all(|&x| x == 1.0));
2804
2805        let mut view_mut = tensor_map
2806            .view_mut()
2807            .expect("Failed to get mutable ndarray view");
2808        view_mut[[0, 0, 0]] = 42.0;
2809        assert_eq!(view_mut[[0, 0, 0]], 42.0);
2810        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
2811    }
2812
2813    #[test]
2814    fn test_buffer_identity_unique() {
2815        let id1 = BufferIdentity::new();
2816        let id2 = BufferIdentity::new();
2817        assert_ne!(
2818            id1.id(),
2819            id2.id(),
2820            "Two identities should have different ids"
2821        );
2822    }
2823
2824    #[test]
2825    fn test_buffer_identity_clone_shares_guard() {
2826        let id1 = BufferIdentity::new();
2827        let weak = id1.weak();
2828        assert!(
2829            weak.upgrade().is_some(),
2830            "Weak should be alive while original exists"
2831        );
2832
2833        let id2 = id1.clone();
2834        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
2835
2836        drop(id1);
2837        assert!(
2838            weak.upgrade().is_some(),
2839            "Weak should still be alive (clone holds Arc)"
2840        );
2841
2842        drop(id2);
2843        assert!(
2844            weak.upgrade().is_none(),
2845            "Weak should be dead after all clones dropped"
2846        );
2847    }

    #[test]
    fn test_tensor_buffer_identity() {
        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
        assert_ne!(
            t1.buffer_identity().id(),
            t2.buffer_identity().id(),
            "Different tensors should have different buffer ids"
        );
    }

    // ------------------------------------------------------------------------
    // Quantization — constructor validation + accessor correctness.
    // ------------------------------------------------------------------------

    #[test]
    fn test_quantization_per_tensor_constructors() {
        let q = Quantization::per_tensor(0.1, -5);
        assert!(q.is_per_tensor());
        assert!(!q.is_per_channel());
        assert!(!q.is_symmetric());
        assert_eq!(q.scale(), &[0.1]);
        assert_eq!(q.zero_point(), Some(&[-5][..]));

        let qs = Quantization::per_tensor_symmetric(0.05);
        assert!(qs.is_per_tensor());
        assert!(qs.is_symmetric());
        assert_eq!(qs.zero_point(), None);
    }

    #[test]
    fn test_quantization_per_channel_constructors() {
        let q = Quantization::per_channel(vec![0.1, 0.2, 0.3], vec![0, -1, 1], 2).unwrap();
        assert!(q.is_per_channel());
        assert!(!q.is_symmetric());
        assert_eq!(q.axis(), Some(2));
        assert_eq!(q.scale().len(), 3);

        let qs = Quantization::per_channel_symmetric(vec![0.054, 0.089, 0.195], 0).unwrap();
        assert!(qs.is_per_channel());
        assert!(qs.is_symmetric());
        assert_eq!(qs.axis(), Some(0));
    }

    #[test]
    fn test_quantization_per_channel_length_mismatch_rejected() {
        // len(scales) != len(zero_points) → rejected at construction.
        let err = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0).unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

    #[test]
    fn test_quantization_per_channel_empty_rejected() {
        let err = Quantization::per_channel_symmetric(vec![], 0).unwrap_err();
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }
    /// Constructors guard the scale/zero_point length invariants, but
    /// `Quantization` is `Deserialize`, so malformed JSON (e.g. an
    /// empty `scale` array, or a `zero_point` length that disagrees with
    /// `scale`) bypasses the constructor checks. `set_quantization`
    /// must reject these via `validate()` so they don't poison
    /// downstream `mode()` selection or per-channel kernel indexing.
    #[test]
    fn test_quantization_validate_rejects_malformed_deserialize() {
        let mut t = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();

        // Empty scale array: must be rejected.
        let q: Quantization = serde_json::from_str(r#"{"scale": []}"#).unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));

        // Per-tensor with multi-element zero_point: must be rejected.
        let q: Quantization =
            serde_json::from_str(r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#).unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));

        // Per-channel zero_point length != scale length: must be rejected.
        let q: Quantization = serde_json::from_str(
            r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
        )
        .unwrap();
        assert!(matches!(
            t.set_quantization(q).unwrap_err(),
            Error::QuantizationInvalid { .. }
        ));
    }
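
    // The safe ingestion pattern implied by the test above (a sketch; error
    // handling elided): route deserialized metadata through
    // `set_quantization` so `validate()` runs, rather than attaching a
    // `Quantization` straight from JSON:
    //
    //     let q: Quantization = serde_json::from_str(payload)?;
    //     tensor.set_quantization(q)?; // length/shape invariants checked here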

    #[test]
    fn test_quantization_mode_dispatch() {
        let pt = Quantization::per_tensor(0.1, -5);
        assert!(matches!(
            pt.mode(),
            QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
        ));

        let pts = Quantization::per_tensor_symmetric(0.05);
        assert!(matches!(
            pts.mode(),
            QuantMode::PerTensorSymmetric { scale } if scale == 0.05
        ));

        let pc = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
        assert!(matches!(pc.mode(), QuantMode::PerChannel { axis: 2, .. }));

        let pcs = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
        assert!(matches!(
            pcs.mode(),
            QuantMode::PerChannelSymmetric { axis: 0, .. }
        ));
    }
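
    // A minimal dispatch sketch (illustrative, not crate API) of what
    // `mode()` buys a kernel: one branch per tensor instead of re-checking
    // scale/zero_point shapes per element. The variants and fields match the
    // patterns asserted above:
    //
    //     match q.mode() {
    //         QuantMode::PerTensor { scale, zero_point } => { /* single pair */ }
    //         QuantMode::PerTensorSymmetric { scale } => { /* zero_point is 0 */ }
    //         QuantMode::PerChannel { axis, .. } => { /* index scales by channel along `axis` */ }
    //         QuantMode::PerChannelSymmetric { axis, .. } => { /* per-channel, zero_point 0 */ }
    //     }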

    #[test]
    fn test_tensor_quantization_roundtrip_integer() {
        let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
        assert!(t.quantization().is_none());
        t.set_quantization(Quantization::per_tensor(0.1, -5))
            .unwrap();
        let q = t.quantization().unwrap();
        assert_eq!(q.scale(), &[0.1]);
        t.clear_quantization();
        assert!(t.quantization().is_none());
    }

    #[test]
    fn test_tensor_with_quantization_builder() {
        let t = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None)
            .unwrap()
            .with_quantization(Quantization::per_tensor_symmetric(0.05))
            .unwrap();
        assert!(t.quantization().is_some());
    }

    #[test]
    fn test_tensor_dyn_quantization_float_arm_returns_none() {
        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
        let td = TensorDyn::F32(t);
        assert!(td.quantization().is_none());
    }

    #[test]
    fn test_tensor_dyn_set_quantization_float_arm_errors() {
        let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
        let mut td = TensorDyn::F32(t);
        let err = td
            .set_quantization(Quantization::per_tensor(0.1, 0))
            .unwrap_err();
        // The float path returns a QuantizationInvalid error.
        assert!(matches!(err, Error::QuantizationInvalid { .. }));
    }

    /// Compile-time type gate — calling `Tensor::<f32>::quantization()` must
    /// fail to compile (the `IntegerType` trait bound is not satisfied by
    /// `f32`). This doctest anchors the invariant.
    ///
    /// ```compile_fail
    /// use edgefirst_tensor::{Tensor, TensorMemory};
    /// let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
    /// let _ = t.quantization(); // compile error: f32 not IntegerType
    /// ```
    fn _compile_fail_doctest_anchor() {}

    // Tests that assert on the process fd count must take this lock
    // exclusively (write); tests that modify the fd count by opening or
    // closing fds must take it shared (read). This inverts the usual
    // reader/writer roles: fd-modifying tests may run concurrently with
    // one another, but never alongside a test inspecting the count.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
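
    // Usage sketch for the discipline above (hypothetical test bodies;
    // `FD_LOCK` is a `std::sync::RwLock`, so both guards unwrap):
    //
    //     let _guard = FD_LOCK.write().unwrap(); // asserts on the fd count
    //     let _guard = FD_LOCK.read().unwrap();  // opens or closes fds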

    /// Test that DMA is NOT available on non-Linux platforms.
    /// This verifies the cross-platform behavior of `is_dma_available()`.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_dma_not_available_on_non_linux() {
        assert!(
            !is_dma_available(),
            "DMA memory allocation should NOT be available on non-Linux platforms"
        );
    }

    /// Test that SHM memory allocation is available and usable on Unix systems.
    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
    #[test]
    #[cfg(unix)]
    fn test_shm_available_and_usable() {
        assert!(
            is_shm_available(),
            "SHM memory allocation should be available on Unix systems"
        );

        // Create a tensor with SHM backing.
        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
            .expect("Failed to create SHM tensor");

        // Verify we can map and write to it.
        let mut map = tensor.map().expect("Failed to map SHM tensor");
        map.as_mut_slice().fill(0xAB);

        // Verify the data was written correctly.
        assert!(
            map.as_slice().iter().all(|&b| b == 0xAB),
            "SHM tensor data should be writable and readable"
        );
    }
}