Skip to main content

edgefirst_tensor/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
21
22## Overview
23The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
24which define the behavior of Tensors and their memory mappings, respectively.
25The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
26while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
27wraps `Tensor<T>` for runtime element-type dispatch.
28 */
29#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54    fmt,
55    ops::{Deref, DerefMut},
56    sync::{
57        atomic::{AtomicU64, Ordering},
58        Arc, Weak,
59    },
60};
61pub use tensor_dyn::TensorDyn;
62
/// Per-plane DMA-BUF descriptor for external buffer import.
///
/// Owns a duplicated file descriptor plus optional stride and offset metadata.
/// The fd is duplicated eagerly in [`new()`](Self::new) so that a bad fd is
/// caught immediately. `import_image` consumes the descriptor and takes
/// ownership of the duped fd — no further cleanup is needed by the caller.
///
/// # Examples
///
/// ```rust,no_run
/// use edgefirst_tensor::PlaneDescriptor;
/// use std::os::fd::BorrowedFd;
///
/// // SAFETY: fd 42 is hypothetical; real code must pass a valid fd.
/// let pd = unsafe { PlaneDescriptor::new(BorrowedFd::borrow_raw(42)) }
///     .unwrap()
///     .with_stride(2048)
///     .with_offset(0);
/// ```
#[cfg(unix)]
#[derive(Debug)]
pub struct PlaneDescriptor {
    /// Duplicated fd owned by this descriptor. Being an `OwnedFd`, it is
    /// closed on drop unless taken out with [`into_fd`](Self::into_fd).
    fd: OwnedFd,
    /// Row stride in bytes; `None` when the caller did not specify one.
    stride: Option<usize>,
    /// Byte offset of the plane within the buffer; `None` when unspecified.
    offset: Option<usize>,
}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91    /// Create a new plane descriptor by duplicating the given file descriptor.
92    ///
93    /// The fd is duped immediately — a bad fd fails here rather than inside
94    /// `import_image`. The caller retains ownership of the original fd.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the `dup()` syscall fails (e.g. invalid fd or
99    /// fd limit reached).
100    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101        let owned = fd.try_clone_to_owned()?;
102        Ok(Self {
103            fd: owned,
104            stride: None,
105            offset: None,
106        })
107    }
108
109    /// Set the row stride in bytes (consuming builder).
110    pub fn with_stride(mut self, stride: usize) -> Self {
111        self.stride = Some(stride);
112        self
113    }
114
115    /// Set the plane offset in bytes (consuming builder).
116    pub fn with_offset(mut self, offset: usize) -> Self {
117        self.offset = Some(offset);
118        self
119    }
120
121    /// Consume the descriptor and return the owned file descriptor.
122    pub fn into_fd(self) -> OwnedFd {
123        self.fd
124    }
125
126    /// Row stride in bytes, if set.
127    pub fn stride(&self) -> Option<usize> {
128        self.stride
129    }
130
131    /// Plane offset in bytes, if set.
132    pub fn offset(&self) -> Option<usize> {
133        self.offset
134    }
135}
136
137/// Element type discriminant for runtime type identification.
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
139#[repr(u8)]
140#[non_exhaustive]
141pub enum DType {
142    U8,
143    I8,
144    U16,
145    I16,
146    U32,
147    I32,
148    U64,
149    I64,
150    F16,
151    F32,
152    F64,
153}
154
155impl DType {
156    /// Size of one element in bytes.
157    pub const fn size(&self) -> usize {
158        match self {
159            Self::U8 | Self::I8 => 1,
160            Self::U16 | Self::I16 | Self::F16 => 2,
161            Self::U32 | Self::I32 | Self::F32 => 4,
162            Self::U64 | Self::I64 | Self::F64 => 8,
163        }
164    }
165
166    /// Short type name (e.g., "u8", "f32", "f16").
167    pub const fn name(&self) -> &'static str {
168        match self {
169            Self::U8 => "u8",
170            Self::I8 => "i8",
171            Self::U16 => "u16",
172            Self::I16 => "i16",
173            Self::U32 => "u32",
174            Self::I32 => "i32",
175            Self::U64 => "u64",
176            Self::I64 => "i64",
177            Self::F16 => "f16",
178            Self::F32 => "f32",
179            Self::F64 => "f64",
180        }
181    }
182}
183
184impl fmt::Display for DType {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        f.write_str(self.name())
187    }
188}
189
/// Process-wide source of buffer IDs; starts at 1 and is only ever incremented.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Identity token for the storage behind a tensor.
///
/// A fresh identity is minted for every allocation or import. [`id`](Self::id)
/// is a monotonically assigned `u64` meant to be used as a cache key, while
/// [`weak`](Self::weak) hands out a `Weak<()>` that stops upgrading once this
/// identity and all of its clones are dropped — letting caches detect that
/// the buffer is gone.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Mint a new identity with the next ID from the global counter.
    pub fn new() -> Self {
        let id = NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed);
        let guard = Arc::new(());
        Self { id, guard }
    }

    /// The identity's numeric ID; a different buffer gets a different ID.
    pub fn id(&self) -> u64 {
        let Self { id, .. } = self;
        *id
    }

    /// A weak handle to the liveness guard. It can no longer be upgraded once
    /// the owning tensor (and every clone of this identity) has been dropped.
    pub fn weak(&self) -> Weak<()> {
        let guard = &self.guard;
        Arc::downgrade(guard)
    }
}

impl Default for BufferIdentity {
    /// Equivalent to [`BufferIdentity::new`]: always mints a fresh identity.
    fn default() -> Self {
        BufferIdentity::new()
    }
}
230
231#[cfg(target_os = "linux")]
232use nix::sys::stat::{major, minor};
233
234pub trait TensorTrait<T>: Send + Sync
235where
236    T: Num + Clone + fmt::Debug,
237{
238    /// Create a new tensor with the given shape and optional name. If no name
239    /// is given, a random name will be generated.
240    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
241    where
242        Self: Sized;
243
244    #[cfg(unix)]
245    /// Create a new tensor using the given file descriptor, shape, and optional
246    /// name. If no name is given, a random name will be generated.
247    ///
248    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
249    /// On other Unix (macOS): Always creates SHM tensor.
250    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
251    where
252        Self: Sized;
253
254    #[cfg(unix)]
255    /// Clone the file descriptor associated with this tensor.
256    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
257
258    /// Get the memory type of this tensor.
259    fn memory(&self) -> TensorMemory;
260
261    /// Get the name of this tensor.
262    fn name(&self) -> String;
263
264    /// Get the number of elements in this tensor.
265    fn len(&self) -> usize {
266        self.shape().iter().product()
267    }
268
269    /// Check if the tensor is empty.
270    fn is_empty(&self) -> bool {
271        self.len() == 0
272    }
273
274    /// Get the size in bytes of this tensor.
275    fn size(&self) -> usize {
276        self.len() * std::mem::size_of::<T>()
277    }
278
279    /// Get the shape of this tensor.
280    fn shape(&self) -> &[usize];
281
282    /// Reshape this tensor to the given shape. The total number of elements
283    /// must remain the same.
284    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
285
286    /// Map the tensor into memory and return a TensorMap for accessing the
287    /// data.
288    fn map(&self) -> Result<TensorMap<T>>;
289
290    /// Get the buffer identity for cache keying and liveness tracking.
291    fn buffer_identity(&self) -> &BufferIdentity;
292}
293
294pub trait TensorMapTrait<T>
295where
296    T: Num + Clone + fmt::Debug,
297{
298    /// Get the shape of this tensor map.
299    fn shape(&self) -> &[usize];
300
301    /// Unmap the tensor from memory.
302    fn unmap(&mut self);
303
304    /// Get the number of elements in this tensor map.
305    fn len(&self) -> usize {
306        self.shape().iter().product()
307    }
308
309    /// Check if the tensor map is empty.
310    fn is_empty(&self) -> bool {
311        self.len() == 0
312    }
313
314    /// Get the size in bytes of this tensor map.
315    fn size(&self) -> usize {
316        self.len() * std::mem::size_of::<T>()
317    }
318
319    /// Get a slice to the data in this tensor map.
320    fn as_slice(&self) -> &[T];
321
322    /// Get a mutable slice to the data in this tensor map.
323    fn as_mut_slice(&mut self) -> &mut [T];
324
325    #[cfg(feature = "ndarray")]
326    /// Get an ndarray ArrayView of the tensor data.
327    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
328        Ok(ndarray::ArrayView::from_shape(
329            self.shape(),
330            self.as_slice(),
331        )?)
332    }
333
334    #[cfg(feature = "ndarray")]
335    /// Get an ndarray ArrayViewMut of the tensor data.
336    fn view_mut(
337        &'_ mut self,
338    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
339        let shape = self.shape().to_vec();
340        Ok(ndarray::ArrayViewMut::from_shape(
341            shape,
342            self.as_mut_slice(),
343        )?)
344    }
345}
346
/// Kind of memory backing a tensor.
///
/// Which variants exist depends on the target platform: `Dma` is Linux-only,
/// `Shm` is Unix-only, and `Mem`/`Pbo` are always present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Direct Memory Access (DMA) allocation. Allows hardware acceleration
    /// where supported, at the cost of extra overhead for CPU reads/writes.
    #[cfg(target_os = "linux")]
    Dma,
    /// POSIX shared-memory allocation. Suitable for inter-process
    /// communication, but not for hardware acceleration.
    #[cfg(unix)]
    Shm,

    /// Ordinary system-memory allocation.
    Mem,

    /// OpenGL Pixel Buffer Object memory, created by ImageProcessor when
    /// DMA-buf is unavailable but OpenGL is present.
    Pbo,
}
366
367impl From<TensorMemory> for String {
368    fn from(memory: TensorMemory) -> Self {
369        match memory {
370            #[cfg(target_os = "linux")]
371            TensorMemory::Dma => "dma".to_owned(),
372            #[cfg(unix)]
373            TensorMemory::Shm => "shm".to_owned(),
374            TensorMemory::Mem => "mem".to_owned(),
375            TensorMemory::Pbo => "pbo".to_owned(),
376        }
377    }
378}
379
380impl TryFrom<&str> for TensorMemory {
381    type Error = Error;
382
383    fn try_from(s: &str) -> Result<Self> {
384        match s {
385            #[cfg(target_os = "linux")]
386            "dma" => Ok(TensorMemory::Dma),
387            #[cfg(unix)]
388            "shm" => Ok(TensorMemory::Shm),
389            "mem" => Ok(TensorMemory::Mem),
390            "pbo" => Ok(TensorMemory::Pbo),
391            _ => Err(Error::InvalidMemoryType(s.to_owned())),
392        }
393    }
394}
395
396#[derive(Debug)]
397#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
398pub(crate) enum TensorStorage<T>
399where
400    T: Num + Clone + fmt::Debug + Send + Sync,
401{
402    #[cfg(target_os = "linux")]
403    Dma(DmaTensor<T>),
404    #[cfg(unix)]
405    Shm(ShmTensor<T>),
406    Mem(MemTensor<T>),
407    Pbo(PboTensor<T>),
408}
409
410impl<T> TensorStorage<T>
411where
412    T: Num + Clone + fmt::Debug + Send + Sync,
413{
414    /// Create a new tensor storage with the given shape, memory type, and
415    /// optional name. If no name is given, a random name will be generated.
416    /// If no memory type is given, the best available memory type will be
417    /// chosen based on the platform and environment variables.
418    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
419        match memory {
420            #[cfg(target_os = "linux")]
421            Some(TensorMemory::Dma) => {
422                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
423            }
424            #[cfg(unix)]
425            Some(TensorMemory::Shm) => {
426                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
427            }
428            Some(TensorMemory::Mem) => {
429                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
430            }
431            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
432                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
433            )),
434            None => {
435                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
436                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
437                {
438                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
439                } else {
440                    #[cfg(target_os = "linux")]
441                    {
442                        // Linux: Try DMA -> SHM -> Mem
443                        match DmaTensor::<T>::new(shape, name) {
444                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
445                            Err(_) => {
446                                match ShmTensor::<T>::new(shape, name)
447                                    .map(TensorStorage::Shm)
448                                {
449                                    Ok(tensor) => Ok(tensor),
450                                    Err(_) => MemTensor::<T>::new(shape, name)
451                                        .map(TensorStorage::Mem),
452                                }
453                            }
454                        }
455                    }
456                    #[cfg(all(unix, not(target_os = "linux")))]
457                    {
458                        // macOS/BSD: Try SHM -> Mem (no DMA)
459                        match ShmTensor::<T>::new(shape, name) {
460                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
461                            Err(_) => {
462                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
463                            }
464                        }
465                    }
466                    #[cfg(not(unix))]
467                    {
468                        // Windows/other: Mem only
469                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
470                    }
471                }
472            }
473        }
474    }
475
476    /// Create a new tensor storage using the given file descriptor, shape,
477    /// and optional name.
478    #[cfg(unix)]
479    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
480        #[cfg(target_os = "linux")]
481        {
482            use nix::sys::stat::fstat;
483
484            let stat = fstat(&fd)?;
485            let major = major(stat.st_dev);
486            let minor = minor(stat.st_dev);
487
488            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
489
490            if major != 0 {
491                // Dma and Shm tensors are expected to have major number 0
492                return Err(Error::UnknownDeviceType(major, minor));
493            }
494
495            match minor {
496                9 | 10 => {
497                    // minor number 9 & 10 indicates DMA memory
498                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
499                }
500                _ => {
501                    // other minor numbers are assumed to be shared memory
502                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
503                }
504            }
505        }
506        #[cfg(all(unix, not(target_os = "linux")))]
507        {
508            // On macOS/BSD, always use SHM (no DMA support)
509            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
510        }
511    }
512}
513
514impl<T> TensorTrait<T> for TensorStorage<T>
515where
516    T: Num + Clone + fmt::Debug + Send + Sync,
517{
518    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
519        Self::new(shape, None, name)
520    }
521
522    #[cfg(unix)]
523    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
524        Self::from_fd(fd, shape, name)
525    }
526
527    #[cfg(unix)]
528    fn clone_fd(&self) -> Result<OwnedFd> {
529        match self {
530            #[cfg(target_os = "linux")]
531            TensorStorage::Dma(t) => t.clone_fd(),
532            TensorStorage::Shm(t) => t.clone_fd(),
533            TensorStorage::Mem(t) => t.clone_fd(),
534            TensorStorage::Pbo(t) => t.clone_fd(),
535        }
536    }
537
538    fn memory(&self) -> TensorMemory {
539        match self {
540            #[cfg(target_os = "linux")]
541            TensorStorage::Dma(_) => TensorMemory::Dma,
542            #[cfg(unix)]
543            TensorStorage::Shm(_) => TensorMemory::Shm,
544            TensorStorage::Mem(_) => TensorMemory::Mem,
545            TensorStorage::Pbo(_) => TensorMemory::Pbo,
546        }
547    }
548
549    fn name(&self) -> String {
550        match self {
551            #[cfg(target_os = "linux")]
552            TensorStorage::Dma(t) => t.name(),
553            #[cfg(unix)]
554            TensorStorage::Shm(t) => t.name(),
555            TensorStorage::Mem(t) => t.name(),
556            TensorStorage::Pbo(t) => t.name(),
557        }
558    }
559
560    fn shape(&self) -> &[usize] {
561        match self {
562            #[cfg(target_os = "linux")]
563            TensorStorage::Dma(t) => t.shape(),
564            #[cfg(unix)]
565            TensorStorage::Shm(t) => t.shape(),
566            TensorStorage::Mem(t) => t.shape(),
567            TensorStorage::Pbo(t) => t.shape(),
568        }
569    }
570
571    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
572        match self {
573            #[cfg(target_os = "linux")]
574            TensorStorage::Dma(t) => t.reshape(shape),
575            #[cfg(unix)]
576            TensorStorage::Shm(t) => t.reshape(shape),
577            TensorStorage::Mem(t) => t.reshape(shape),
578            TensorStorage::Pbo(t) => t.reshape(shape),
579        }
580    }
581
582    fn map(&self) -> Result<TensorMap<T>> {
583        match self {
584            #[cfg(target_os = "linux")]
585            TensorStorage::Dma(t) => t.map(),
586            #[cfg(unix)]
587            TensorStorage::Shm(t) => t.map(),
588            TensorStorage::Mem(t) => t.map(),
589            TensorStorage::Pbo(t) => t.map(),
590        }
591    }
592
593    fn buffer_identity(&self) -> &BufferIdentity {
594        match self {
595            #[cfg(target_os = "linux")]
596            TensorStorage::Dma(t) => t.buffer_identity(),
597            #[cfg(unix)]
598            TensorStorage::Shm(t) => t.buffer_identity(),
599            TensorStorage::Mem(t) => t.buffer_identity(),
600            TensorStorage::Pbo(t) => t.buffer_identity(),
601        }
602    }
603}
604
605/// Multi-backend tensor with optional image format metadata.
606///
607/// When `format` is `Some`, this tensor represents an image. Width, height,
608/// and channels are derived from `shape` + `format`. When `format` is `None`,
609/// this is a raw tensor (identical to the pre-refactoring behavior).
610#[derive(Debug)]
611pub struct Tensor<T>
612where
613    T: Num + Clone + fmt::Debug + Send + Sync,
614{
615    pub(crate) storage: TensorStorage<T>,
616    format: Option<PixelFormat>,
617    chroma: Option<Box<Tensor<T>>>,
618    /// Row stride in bytes for externally allocated buffers with row padding.
619    /// `None` means tightly packed (stride == width * bytes_per_pixel).
620    row_stride: Option<usize>,
621    /// Byte offset within the DMA-BUF where image data starts.
622    /// `None` means offset 0 (data starts at the beginning of the buffer).
623    plane_offset: Option<usize>,
624}
625
626impl<T> Tensor<T>
627where
628    T: Num + Clone + fmt::Debug + Send + Sync,
629{
630    /// Wrap a TensorStorage in a Tensor with no image metadata.
631    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
632        Self {
633            storage,
634            format: None,
635            chroma: None,
636            row_stride: None,
637            plane_offset: None,
638        }
639    }
640
641    /// Create a new tensor with the given shape, memory type, and optional
642    /// name. If no name is given, a random name will be generated. If no
643    /// memory type is given, the best available memory type will be chosen
644    /// based on the platform and environment variables.
645    ///
646    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
647    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
648    /// On non-Unix platforms, only Mem is available.
649    ///
650    /// # Environment Variables
651    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
652    ///   value, forces the use of regular system memory allocation
653    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
654    ///
655    /// # Example
656    /// ```rust
657    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
658    /// # fn main() -> Result<(), Error> {
659    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
660    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
661    /// assert_eq!(tensor.name(), "test_tensor");
662    /// #    Ok(())
663    /// # }
664    /// ```
665    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
666        TensorStorage::new(shape, memory, name).map(Self::wrap)
667    }
668
669    /// Create an image tensor with the given format.
670    pub fn image(
671        width: usize,
672        height: usize,
673        format: PixelFormat,
674        memory: Option<TensorMemory>,
675    ) -> Result<Self> {
676        let shape = match format.layout() {
677            PixelLayout::Packed => vec![height, width, format.channels()],
678            PixelLayout::Planar => vec![format.channels(), height, width],
679            PixelLayout::SemiPlanar => {
680                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
681                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
682                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
683                let total_h = match format {
684                    PixelFormat::Nv12 => {
685                        if !height.is_multiple_of(2) {
686                            return Err(Error::InvalidArgument(format!(
687                                "NV12 requires even height, got {height}"
688                            )));
689                        }
690                        height * 3 / 2
691                    }
692                    PixelFormat::Nv16 => height * 2,
693                    _ => {
694                        return Err(Error::InvalidArgument(format!(
695                            "unknown semi-planar height multiplier for {format:?}"
696                        )))
697                    }
698                };
699                vec![total_h, width]
700            }
701        };
702        let mut t = Self::new(&shape, memory, None)?;
703        t.format = Some(format);
704        Ok(t)
705    }
706
707    /// Attach format metadata to an existing tensor.
708    ///
709    /// # Arguments
710    ///
711    /// * `format` - The pixel format to attach
712    ///
713    /// # Returns
714    ///
715    /// `Ok(())` on success, with the format stored as metadata on the tensor.
716    ///
717    /// # Errors
718    ///
719    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
720    /// the format's layout (packed expects `[H, W, C]`, planar expects
721    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
722    /// height constraints).
723    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
724        let shape = self.shape();
725        match format.layout() {
726            PixelLayout::Packed => {
727                if shape.len() != 3 || shape[2] != format.channels() {
728                    return Err(Error::InvalidShape(format!(
729                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
730                        format.channels()
731                    )));
732                }
733            }
734            PixelLayout::Planar => {
735                if shape.len() != 3 || shape[0] != format.channels() {
736                    return Err(Error::InvalidShape(format!(
737                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
738                        format.channels()
739                    )));
740                }
741            }
742            PixelLayout::SemiPlanar => {
743                if shape.len() != 2 {
744                    return Err(Error::InvalidShape(format!(
745                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
746                    )));
747                }
748                match format {
749                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
750                        return Err(Error::InvalidShape(format!(
751                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
752                            shape[0]
753                        )));
754                    }
755                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
756                        return Err(Error::InvalidShape(format!(
757                            "NV16 contiguous shape[0] must be even, got {}",
758                            shape[0]
759                        )));
760                    }
761                    _ => {}
762                }
763            }
764        }
765        // Clear stored stride/offset when format changes — they may be invalid
766        // for the new format. Caller must re-set after changing format.
767        if self.format != Some(format) {
768            self.row_stride = None;
769            self.plane_offset = None;
770            #[cfg(target_os = "linux")]
771            if let TensorStorage::Dma(ref mut dma) = self.storage {
772                dma.mmap_offset = 0;
773            }
774        }
775        self.format = Some(format);
776        Ok(())
777    }
778
779    /// Pixel format (None if not an image).
780    pub fn format(&self) -> Option<PixelFormat> {
781        self.format
782    }
783
784    /// Image width (None if not an image).
785    pub fn width(&self) -> Option<usize> {
786        let fmt = self.format?;
787        let shape = self.shape();
788        match fmt.layout() {
789            PixelLayout::Packed => Some(shape[1]),
790            PixelLayout::Planar => Some(shape[2]),
791            PixelLayout::SemiPlanar => Some(shape[1]),
792        }
793    }
794
795    /// Image height (None if not an image).
796    pub fn height(&self) -> Option<usize> {
797        let fmt = self.format?;
798        let shape = self.shape();
799        match fmt.layout() {
800            PixelLayout::Packed => Some(shape[0]),
801            PixelLayout::Planar => Some(shape[1]),
802            PixelLayout::SemiPlanar => {
803                if self.is_multiplane() {
804                    Some(shape[0])
805                } else {
806                    match fmt {
807                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
808                        PixelFormat::Nv16 => Some(shape[0] / 2),
809                        _ => None,
810                    }
811                }
812            }
813        }
814    }
815
816    /// Create from separate Y and UV planes (multiplane NV12/NV16).
817    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
818        if format.layout() != PixelLayout::SemiPlanar {
819            return Err(Error::InvalidArgument(format!(
820                "from_planes requires a semi-planar format, got {format:?}"
821            )));
822        }
823        if chroma.format.is_some() || chroma.chroma.is_some() {
824            return Err(Error::InvalidArgument(
825                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
826            ));
827        }
828        let luma_shape = luma.shape();
829        let chroma_shape = chroma.shape();
830        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
831            return Err(Error::InvalidArgument(format!(
832                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
833            )));
834        }
835        if luma_shape[1] != chroma_shape[1] {
836            return Err(Error::InvalidArgument(format!(
837                "luma width {} != chroma width {}",
838                luma_shape[1], chroma_shape[1]
839            )));
840        }
841        match format {
842            PixelFormat::Nv12 => {
843                if luma_shape[0] % 2 != 0 {
844                    return Err(Error::InvalidArgument(format!(
845                        "NV12 requires even luma height, got {}",
846                        luma_shape[0]
847                    )));
848                }
849                if chroma_shape[0] != luma_shape[0] / 2 {
850                    return Err(Error::InvalidArgument(format!(
851                        "NV12 chroma height {} != luma height / 2 ({})",
852                        chroma_shape[0],
853                        luma_shape[0] / 2
854                    )));
855                }
856            }
857            PixelFormat::Nv16 => {
858                if chroma_shape[0] != luma_shape[0] {
859                    return Err(Error::InvalidArgument(format!(
860                        "NV16 chroma height {} != luma height {}",
861                        chroma_shape[0], luma_shape[0]
862                    )));
863                }
864            }
865            _ => {
866                return Err(Error::InvalidArgument(format!(
867                    "from_planes only supports NV12 and NV16, got {format:?}"
868                )));
869            }
870        }
871
872        Ok(Tensor {
873            storage: luma.storage,
874            format: Some(format),
875            chroma: Some(Box::new(chroma)),
876            row_stride: luma.row_stride,
877            plane_offset: luma.plane_offset,
878        })
879    }
880
881    /// Whether this tensor uses separate plane allocations.
882    pub fn is_multiplane(&self) -> bool {
883        self.chroma.is_some()
884    }
885
886    /// Access the chroma plane for multiplane semi-planar images.
887    pub fn chroma(&self) -> Option<&Tensor<T>> {
888        self.chroma.as_deref()
889    }
890
891    /// Mutable access to the chroma plane for multiplane semi-planar images.
892    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
893        self.chroma.as_deref_mut()
894    }
895
    /// Explicitly stored row stride in bytes, e.g. from
    /// [`set_row_stride`](Self::set_row_stride) or
    /// [`set_row_stride_unchecked`](Self::set_row_stride_unchecked).
    ///
    /// `None` means no stride was stored and rows are tightly packed. Use
    /// [`effective_row_stride`](Self::effective_row_stride) to get the
    /// packed stride derived from the pixel format instead.
    pub fn row_stride(&self) -> Option<usize> {
        self.row_stride
    }
900
901    /// Effective row stride in bytes: the stored stride if set, otherwise the
902    /// minimum stride computed from the format, width, and element size.
903    /// Returns `None` only when no format is set and no explicit stride was
904    /// stored via [`set_row_stride`](Self::set_row_stride).
905    pub fn effective_row_stride(&self) -> Option<usize> {
906        if let Some(s) = self.row_stride {
907            return Some(s);
908        }
909        let fmt = self.format?;
910        let w = self.width()?;
911        let elem = std::mem::size_of::<T>();
912        Some(match fmt.layout() {
913            PixelLayout::Packed => w * fmt.channels() * elem,
914            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
915        })
916    }
917
918    /// Set the row stride in bytes for externally allocated buffers with
919    /// row padding (e.g. V4L2 or GStreamer allocators).
920    ///
921    /// The stride is propagated to the EGL DMA-BUF import attributes so
922    /// the GPU interprets the padded buffer layout correctly. Must be
923    /// called after [`set_format`](Self::set_format) and before the tensor
924    /// is first passed to [`ImageProcessor::convert`]. The stored stride
925    /// is cleared automatically if the pixel format is later changed.
926    ///
927    /// No stride-vs-buffer-size validation is performed because the
928    /// backing allocation size is not reliably known: external DMA-BUFs
929    /// may be over-allocated by the allocator, and internal tensors store
930    /// a logical (unpadded) shape. An incorrect stride will be caught by
931    /// the EGL driver at import time.
932    ///
933    /// # Arguments
934    ///
935    /// * `stride` - Row stride in bytes. Must be >= the minimum stride for
936    ///   the format (width * channels * sizeof(T) for packed,
937    ///   width * sizeof(T) for planar/semi-planar).
938    ///
939    /// # Errors
940    ///
941    /// * `InvalidArgument` if no pixel format is set on this tensor
942    /// * `InvalidArgument` if `stride` is less than the minimum for the
943    ///   format and width
944    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
945        let fmt = self.format.ok_or_else(|| {
946            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
947        })?;
948        let w = self.width().ok_or_else(|| {
949            Error::InvalidArgument("cannot determine width for row_stride validation".into())
950        })?;
951        let elem = std::mem::size_of::<T>();
952        let min_stride = match fmt.layout() {
953            PixelLayout::Packed => w * fmt.channels() * elem,
954            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
955        };
956        if stride < min_stride {
957            return Err(Error::InvalidArgument(format!(
958                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
959            )));
960        }
961        self.row_stride = Some(stride);
962        Ok(())
963    }
964
965    /// Set the row stride without format validation.
966    ///
967    /// Use this for raw sub-tensors (e.g. chroma planes) that don't carry
968    /// format metadata. The caller is responsible for ensuring the stride
969    /// is valid.
970    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
971        self.row_stride = Some(stride);
972    }
973
974    /// Builder-style variant of [`set_row_stride`](Self::set_row_stride),
975    /// consuming and returning `self`.
976    ///
977    /// # Errors
978    ///
979    /// Same conditions as [`set_row_stride`](Self::set_row_stride).
980    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
981        self.set_row_stride(stride)?;
982        Ok(self)
983    }
984
    /// Byte offset within the DMA-BUF where image data starts.
    ///
    /// `None` is equivalent to an offset of 0. Set via
    /// [`set_plane_offset`](Self::set_plane_offset); cleared by `reshape`.
    pub fn plane_offset(&self) -> Option<usize> {
        self.plane_offset
    }
989
990    /// Set the byte offset within the DMA-BUF where image data starts.
991    ///
992    /// Propagated to `EGL_DMA_BUF_PLANE0_OFFSET_EXT` on GPU import.
993    /// Unlike [`set_row_stride`](Self::set_row_stride), no format is required
994    /// since the offset is format-independent.
995    pub fn set_plane_offset(&mut self, offset: usize) {
996        self.plane_offset = Some(offset);
997        #[cfg(target_os = "linux")]
998        if let TensorStorage::Dma(ref mut dma) = self.storage {
999            dma.mmap_offset = offset;
1000        }
1001    }
1002
1003    /// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset),
1004    /// consuming and returning `self`.
1005    pub fn with_plane_offset(mut self, offset: usize) -> Self {
1006        self.set_plane_offset(offset);
1007        self
1008    }
1009
1010    /// Downcast to PBO tensor reference (for GL backends).
1011    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1012        match &self.storage {
1013            TensorStorage::Pbo(p) => Some(p),
1014            _ => None,
1015        }
1016    }
1017
1018    /// Downcast to DMA tensor reference (for EGL import, G2D).
1019    #[cfg(target_os = "linux")]
1020    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1021        match &self.storage {
1022            TensorStorage::Dma(d) => Some(d),
1023            _ => None,
1024        }
1025    }
1026
1027    /// Borrow the DMA-BUF file descriptor backing this tensor.
1028    ///
1029    /// # Returns
1030    ///
1031    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
1032    /// lifetime.
1033    ///
1034    /// # Errors
1035    ///
1036    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
1037    #[cfg(target_os = "linux")]
1038    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1039        use std::os::fd::AsFd;
1040        match &self.storage {
1041            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1042            _ => Err(Error::NotImplemented(format!(
1043                "dmabuf requires DMA-backed tensor, got {:?}",
1044                self.storage.memory()
1045            ))),
1046        }
1047    }
1048
1049    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
1050    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1051        Self {
1052            storage: TensorStorage::Pbo(pbo),
1053            format: None,
1054            chroma: None,
1055            row_stride: None,
1056            plane_offset: None,
1057        }
1058    }
1059}
1060
1061impl<T> TensorTrait<T> for Tensor<T>
1062where
1063    T: Num + Clone + fmt::Debug + Send + Sync,
1064{
1065    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1066    where
1067        Self: Sized,
1068    {
1069        Self::new(shape, None, name)
1070    }
1071
1072    #[cfg(unix)]
1073    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1074    where
1075        Self: Sized,
1076    {
1077        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1078    }
1079
1080    #[cfg(unix)]
1081    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1082        self.storage.clone_fd()
1083    }
1084
1085    fn memory(&self) -> TensorMemory {
1086        self.storage.memory()
1087    }
1088
1089    fn name(&self) -> String {
1090        self.storage.name()
1091    }
1092
1093    fn shape(&self) -> &[usize] {
1094        self.storage.shape()
1095    }
1096
1097    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1098        if self.chroma.is_some() {
1099            return Err(Error::InvalidOperation(
1100                "cannot reshape a multiplane tensor — decompose planes first".into(),
1101            ));
1102        }
1103        self.storage.reshape(shape)?;
1104        self.format = None;
1105        self.row_stride = None;
1106        self.plane_offset = None;
1107        #[cfg(target_os = "linux")]
1108        if let TensorStorage::Dma(ref mut dma) = self.storage {
1109            dma.mmap_offset = 0;
1110        }
1111        Ok(())
1112    }
1113
1114    fn map(&self) -> Result<TensorMap<T>> {
1115        if self.row_stride.is_some() {
1116            return Err(Error::InvalidOperation(
1117                "CPU mapping of strided tensors is not supported; use GPU path only".into(),
1118            ));
1119        }
1120        // Offset tensors are supported for DMA storage — DmaMap adjusts the
1121        // mmap range and slice start position.  Non-DMA offset tensors are
1122        // not meaningful (offset only applies to DMA-BUF sub-regions).
1123        if self.plane_offset.is_some_and(|o| o > 0) {
1124            #[cfg(target_os = "linux")]
1125            if !matches!(self.storage, TensorStorage::Dma(_)) {
1126                return Err(Error::InvalidOperation(
1127                    "plane offset only supported for DMA tensors".into(),
1128                ));
1129            }
1130            #[cfg(not(target_os = "linux"))]
1131            return Err(Error::InvalidOperation(
1132                "plane offset only supported for DMA tensors".into(),
1133            ));
1134        }
1135        self.storage.map()
1136    }
1137
1138    fn buffer_identity(&self) -> &BufferIdentity {
1139        self.storage.buffer_identity()
1140    }
1141}
1142
1143pub enum TensorMap<T>
1144where
1145    T: Num + Clone + fmt::Debug,
1146{
1147    #[cfg(target_os = "linux")]
1148    Dma(DmaMap<T>),
1149    #[cfg(unix)]
1150    Shm(ShmMap<T>),
1151    Mem(MemMap<T>),
1152    Pbo(PboMap<T>),
1153}
1154
1155impl<T> TensorMapTrait<T> for TensorMap<T>
1156where
1157    T: Num + Clone + fmt::Debug,
1158{
1159    fn shape(&self) -> &[usize] {
1160        match self {
1161            #[cfg(target_os = "linux")]
1162            TensorMap::Dma(map) => map.shape(),
1163            #[cfg(unix)]
1164            TensorMap::Shm(map) => map.shape(),
1165            TensorMap::Mem(map) => map.shape(),
1166            TensorMap::Pbo(map) => map.shape(),
1167        }
1168    }
1169
1170    fn unmap(&mut self) {
1171        match self {
1172            #[cfg(target_os = "linux")]
1173            TensorMap::Dma(map) => map.unmap(),
1174            #[cfg(unix)]
1175            TensorMap::Shm(map) => map.unmap(),
1176            TensorMap::Mem(map) => map.unmap(),
1177            TensorMap::Pbo(map) => map.unmap(),
1178        }
1179    }
1180
1181    fn as_slice(&self) -> &[T] {
1182        match self {
1183            #[cfg(target_os = "linux")]
1184            TensorMap::Dma(map) => map.as_slice(),
1185            #[cfg(unix)]
1186            TensorMap::Shm(map) => map.as_slice(),
1187            TensorMap::Mem(map) => map.as_slice(),
1188            TensorMap::Pbo(map) => map.as_slice(),
1189        }
1190    }
1191
1192    fn as_mut_slice(&mut self) -> &mut [T] {
1193        match self {
1194            #[cfg(target_os = "linux")]
1195            TensorMap::Dma(map) => map.as_mut_slice(),
1196            #[cfg(unix)]
1197            TensorMap::Shm(map) => map.as_mut_slice(),
1198            TensorMap::Mem(map) => map.as_mut_slice(),
1199            TensorMap::Pbo(map) => map.as_mut_slice(),
1200        }
1201    }
1202}
1203
1204impl<T> Deref for TensorMap<T>
1205where
1206    T: Num + Clone + fmt::Debug,
1207{
1208    type Target = [T];
1209
1210    fn deref(&self) -> &[T] {
1211        match self {
1212            #[cfg(target_os = "linux")]
1213            TensorMap::Dma(map) => map.deref(),
1214            #[cfg(unix)]
1215            TensorMap::Shm(map) => map.deref(),
1216            TensorMap::Mem(map) => map.deref(),
1217            TensorMap::Pbo(map) => map.deref(),
1218        }
1219    }
1220}
1221
1222impl<T> DerefMut for TensorMap<T>
1223where
1224    T: Num + Clone + fmt::Debug,
1225{
1226    fn deref_mut(&mut self) -> &mut [T] {
1227        match self {
1228            #[cfg(target_os = "linux")]
1229            TensorMap::Dma(map) => map.deref_mut(),
1230            #[cfg(unix)]
1231            TensorMap::Shm(map) => map.deref_mut(),
1232            TensorMap::Mem(map) => map.deref_mut(),
1233            TensorMap::Pbo(map) => map.deref_mut(),
1234        }
1235    }
1236}
1237
1238// ============================================================================
1239// Platform availability helpers
1240// ============================================================================
1241
1242/// Check if DMA memory allocation is available on this system.
1243///
1244/// Returns `true` only on Linux systems with DMA-BUF heap access (typically
1245/// requires running as root or membership in a video/render group).
1246/// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
1247///
1248/// This function caches its result after the first call for efficiency.
1249#[cfg(target_os = "linux")]
1250static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1251
1252/// Check if DMA memory allocation is available on this system.
1253#[cfg(target_os = "linux")]
1254pub fn is_dma_available() -> bool {
1255    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1256}
1257
1258/// Check if DMA memory allocation is available on this system.
1259///
1260/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
1261#[cfg(not(target_os = "linux"))]
1262pub fn is_dma_available() -> bool {
1263    false
1264}
1265
1266/// Check if POSIX shared memory allocation is available on this system.
1267///
1268/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
1269/// is supported. Always returns `false` on non-Unix platforms (Windows).
1270///
1271/// This function caches its result after the first call for efficiency.
1272#[cfg(unix)]
1273static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1274
1275/// Check if POSIX shared memory allocation is available on this system.
1276#[cfg(unix)]
1277pub fn is_shm_available() -> bool {
1278    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
1279}
1280
1281/// Check if POSIX shared memory allocation is available on this system.
1282///
1283/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
1284#[cfg(not(unix))]
1285pub fn is_shm_available() -> bool {
1286    false
1287}
1288
#[cfg(test)]
mod dtype_tests {
    use super::*;

    /// Each dtype reports its element width in bytes.
    #[test]
    fn dtype_size() {
        let cases = [
            (DType::U8, 1),
            (DType::I8, 1),
            (DType::U16, 2),
            (DType::I16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
            (DType::U64, 8),
            (DType::I64, 8),
            (DType::F16, 2),
            (DType::F32, 4),
            (DType::F64, 8),
        ];
        for (dtype, bytes) in cases {
            assert_eq!(dtype.size(), bytes);
        }
    }

    /// Short lowercase names match the Rust primitive spelling.
    #[test]
    fn dtype_name() {
        let cases = [(DType::U8, "u8"), (DType::F16, "f16"), (DType::F32, "f32")];
        for (dtype, name) in cases {
            assert_eq!(dtype.name(), name);
        }
    }

    /// A dtype survives a JSON serialize/deserialize round trip unchanged.
    #[test]
    fn dtype_serde_roundtrip() {
        let original = DType::F16;
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
1324
#[cfg(test)]
mod image_tests {
    use super::*;

    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        // NV12: H*3/2 = 720
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        // RGB expects 3 channels, not 4
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        // Original tensor is unmodified
        assert!(t.format().is_none());
    }

    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        // Reshape to flat — format cleared
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    #[test]
    fn from_planes_nv16() {
        // NV16 chroma has the same height as luma.
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv16).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv16));
        assert!(img.is_multiplane());
        assert_eq!(img.chroma().unwrap().shape(), &[480, 640]);
    }

    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    #[test]
    fn from_planes_rejects_odd_nv12_height() {
        let y = Tensor::<u8>::new(&[481, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        assert!(Tensor::from_planes(y, uv, PixelFormat::Nv12).is_err());
    }

    #[test]
    fn from_planes_rejects_mismatched_planes() {
        // Luma and chroma widths differ.
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 320], None, None).unwrap();
        assert!(Tensor::from_planes(y, uv, PixelFormat::Nv12).is_err());

        // NV12 chroma must be half the luma height.
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        assert!(Tensor::from_planes(y, uv, PixelFormat::Nv12).is_err());

        // NV16 chroma must match the luma height.
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        assert!(Tensor::from_planes(y, uv, PixelFormat::Nv16).is_err());

        // Both planes must be 2D.
        let y = Tensor::<u8>::new(&[480, 640, 1], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        assert!(Tensor::from_planes(y, uv, PixelFormat::Nv12).is_err());
    }

    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }

    #[test]
    fn set_row_stride_requires_format() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.set_row_stride(4096).is_err());
        assert!(t.row_stride().is_none());
        assert!(t.effective_row_stride().is_none());
    }

    #[test]
    fn set_row_stride_validates_and_blocks_cpu_map() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        // Without an explicit stride the packed stride is derived.
        assert_eq!(t.effective_row_stride(), Some(640 * 4));

        // Below the packed minimum is rejected and leaves no stride stored.
        assert!(t.set_row_stride(640 * 4 - 1).is_err());
        assert!(t.row_stride().is_none());

        t.set_row_stride(4096).unwrap();
        assert_eq!(t.row_stride(), Some(4096));
        assert_eq!(t.effective_row_stride(), Some(4096));

        // Padded buffers are GPU-only; CPU mapping must be refused.
        assert!(t.map().is_err());

        // Reshape discards the stride along with the format.
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.row_stride().is_none());
        assert!(t.format().is_none());
    }

    #[test]
    fn plane_offset_rules_for_non_dma_storage() {
        let mut t = Tensor::<u8>::new(&[64], Some(TensorMemory::Mem), None).unwrap();
        assert_eq!(t.plane_offset(), None);

        t.set_plane_offset(16);
        assert_eq!(t.plane_offset(), Some(16));
        // A non-zero offset is only meaningful for DMA-backed storage.
        assert!(t.map().is_err());

        // Reshape clears the offset, after which mapping works again.
        t.reshape(&[8, 8]).unwrap();
        assert_eq!(t.plane_offset(), None);
        assert!(t.map().is_ok());

        // An explicit zero offset is treated like no offset at all.
        let t = Tensor::<u8>::new(&[64], Some(TensorMemory::Mem), None)
            .unwrap()
            .with_plane_offset(0);
        assert_eq!(t.plane_offset(), Some(0));
        assert!(t.map().is_ok());
    }
}
1427
1428#[cfg(test)]
1429mod tests {
1430    #[cfg(target_os = "linux")]
1431    use nix::unistd::{access, AccessFlags};
1432    #[cfg(target_os = "linux")]
1433    use std::io::Write as _;
1434    use std::sync::RwLock;
1435
1436    use super::*;
1437
1438    #[ctor::ctor]
1439    fn init() {
1440        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
1441    }
1442
1443    /// Macro to get the current function name for logging in tests.
1444    #[cfg(target_os = "linux")]
1445    macro_rules! function {
1446        () => {{
1447            fn f() {}
1448            fn type_name_of<T>(_: T) -> &'static str {
1449                std::any::type_name::<T>()
1450            }
1451            let name = type_name_of(f);
1452
1453            // Find and cut the rest of the path
1454            match &name[..name.len() - 3].rfind(':') {
1455                Some(pos) => &name[pos + 1..name.len() - 3],
1456                None => &name[..name.len() - 3],
1457            }
1458        }};
1459    }
1460
1461    #[test]
1462    #[cfg(target_os = "linux")]
1463    fn test_tensor() {
1464        let _lock = FD_LOCK.read().unwrap();
1465        let shape = vec![1];
1466        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1467        let dma_enabled = tensor.is_ok();
1468
1469        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1470        match dma_enabled {
1471            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1472            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1473        }
1474    }
1475
1476    #[test]
1477    #[cfg(all(unix, not(target_os = "linux")))]
1478    fn test_tensor() {
1479        let shape = vec![1];
1480        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1481        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
1482        assert!(
1483            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1484            "Expected SHM or Mem on macOS, got {:?}",
1485            tensor.memory()
1486        );
1487    }
1488
1489    #[test]
1490    #[cfg(not(unix))]
1491    fn test_tensor() {
1492        let shape = vec![1];
1493        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1494        assert_eq!(tensor.memory(), TensorMemory::Mem);
1495    }
1496
1497    #[test]
1498    #[cfg(target_os = "linux")]
1499    fn test_dma_tensor() {
1500        let _lock = FD_LOCK.read().unwrap();
1501        match access(
1502            "/dev/dma_heap/linux,cma",
1503            AccessFlags::R_OK | AccessFlags::W_OK,
1504        ) {
1505            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
1506            Err(_) => match access(
1507                "/dev/dma_heap/system",
1508                AccessFlags::R_OK | AccessFlags::W_OK,
1509            ) {
1510                Ok(_) => println!("/dev/dma_heap/system is available"),
1511                Err(e) => {
1512                    writeln!(
1513                        &mut std::io::stdout(),
1514                        "[WARNING] DMA Heap is unavailable: {e}"
1515                    )
1516                    .unwrap();
1517                    return;
1518                }
1519            },
1520        }
1521
1522        let shape = vec![2, 3, 4];
1523        let tensor =
1524            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1525
1526        const DUMMY_VALUE: f32 = 12.34;
1527
1528        assert_eq!(tensor.memory(), TensorMemory::Dma);
1529        assert_eq!(tensor.name(), "test_tensor");
1530        assert_eq!(tensor.shape(), &shape);
1531        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1532        assert_eq!(tensor.len(), 2 * 3 * 4);
1533
1534        {
1535            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1536            tensor_map.fill(42.0);
1537            assert!(tensor_map.iter().all(|&x| x == 42.0));
1538        }
1539
1540        {
1541            let shared = Tensor::<f32>::from_fd(
1542                tensor
1543                    .clone_fd()
1544                    .expect("Failed to duplicate tensor file descriptor"),
1545                &shape,
1546                Some("test_tensor_shared"),
1547            )
1548            .expect("Failed to create tensor from fd");
1549
1550            assert_eq!(shared.memory(), TensorMemory::Dma);
1551            assert_eq!(shared.name(), "test_tensor_shared");
1552            assert_eq!(shared.shape(), &shape);
1553
1554            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
1555            tensor_map.fill(DUMMY_VALUE);
1556            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1557        }
1558
1559        {
1560            let tensor_map = tensor.map().expect("Failed to map DMA memory");
1561            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1562        }
1563
1564        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1565        assert_eq!(tensor.shape(), &shape);
1566        let new_shape = vec![3, 4, 4];
1567        assert!(
1568            tensor.reshape(&new_shape).is_err(),
1569            "Reshape should fail due to size mismatch"
1570        );
1571        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1572
1573        let new_shape = vec![2, 3, 4];
1574        tensor.reshape(&new_shape).expect("Reshape should succeed");
1575        assert_eq!(
1576            tensor.shape(),
1577            &new_shape,
1578            "Shape should be updated after successful reshape"
1579        );
1580
1581        {
1582            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1583            tensor_map.fill(1);
1584            assert!(tensor_map.iter().all(|&x| x == 1));
1585        }
1586
1587        {
1588            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1589            tensor_map[2] = 42;
1590            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1591            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1592        }
1593    }
1594
1595    #[test]
1596    #[cfg(unix)]
1597    fn test_shm_tensor() {
1598        let _lock = FD_LOCK.read().unwrap();
1599        let shape = vec![2, 3, 4];
1600        let tensor =
1601            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1602        assert_eq!(tensor.shape(), &shape);
1603        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1604        assert_eq!(tensor.name(), "test_tensor");
1605
1606        const DUMMY_VALUE: f32 = 12.34;
1607        {
1608            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1609            tensor_map.fill(42.0);
1610            assert!(tensor_map.iter().all(|&x| x == 42.0));
1611        }
1612
1613        {
1614            let shared = Tensor::<f32>::from_fd(
1615                tensor
1616                    .clone_fd()
1617                    .expect("Failed to duplicate tensor file descriptor"),
1618                &shape,
1619                Some("test_tensor_shared"),
1620            )
1621            .expect("Failed to create tensor from fd");
1622
1623            assert_eq!(shared.memory(), TensorMemory::Shm);
1624            assert_eq!(shared.name(), "test_tensor_shared");
1625            assert_eq!(shared.shape(), &shape);
1626
1627            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
1628            tensor_map.fill(DUMMY_VALUE);
1629            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1630        }
1631
1632        {
1633            let tensor_map = tensor.map().expect("Failed to map shared memory");
1634            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1635        }
1636
1637        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1638        assert_eq!(tensor.shape(), &shape);
1639        let new_shape = vec![3, 4, 4];
1640        assert!(
1641            tensor.reshape(&new_shape).is_err(),
1642            "Reshape should fail due to size mismatch"
1643        );
1644        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1645
1646        let new_shape = vec![2, 3, 4];
1647        tensor.reshape(&new_shape).expect("Reshape should succeed");
1648        assert_eq!(
1649            tensor.shape(),
1650            &new_shape,
1651            "Shape should be updated after successful reshape"
1652        );
1653
1654        {
1655            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1656            tensor_map.fill(1);
1657            assert!(tensor_map.iter().all(|&x| x == 1));
1658        }
1659
1660        {
1661            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1662            tensor_map[2] = 42;
1663            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1664            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1665        }
1666    }
1667
1668    #[test]
1669    fn test_mem_tensor() {
1670        let shape = vec![2, 3, 4];
1671        let tensor =
1672            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1673        assert_eq!(tensor.shape(), &shape);
1674        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1675        assert_eq!(tensor.name(), "test_tensor");
1676
1677        {
1678            let mut tensor_map = tensor.map().expect("Failed to map memory");
1679            tensor_map.fill(42.0);
1680            assert!(tensor_map.iter().all(|&x| x == 42.0));
1681        }
1682
1683        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1684        assert_eq!(tensor.shape(), &shape);
1685        let new_shape = vec![3, 4, 4];
1686        assert!(
1687            tensor.reshape(&new_shape).is_err(),
1688            "Reshape should fail due to size mismatch"
1689        );
1690        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1691
1692        let new_shape = vec![2, 3, 4];
1693        tensor.reshape(&new_shape).expect("Reshape should succeed");
1694        assert_eq!(
1695            tensor.shape(),
1696            &new_shape,
1697            "Shape should be updated after successful reshape"
1698        );
1699
1700        {
1701            let mut tensor_map = tensor.map().expect("Failed to map memory");
1702            tensor_map.fill(1);
1703            assert!(tensor_map.iter().all(|&x| x == 1));
1704        }
1705
1706        {
1707            let mut tensor_map = tensor.map().expect("Failed to map memory");
1708            tensor_map[2] = 42;
1709            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1710            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1711        }
1712    }
1713
1714    #[test]
1715    #[cfg(target_os = "linux")]
1716    fn test_dma_no_fd_leaks() {
1717        let _lock = FD_LOCK.write().unwrap();
1718        if !is_dma_available() {
1719            log::warn!(
1720                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1721                function!()
1722            );
1723            return;
1724        }
1725
1726        let proc = procfs::process::Process::myself()
1727            .expect("Failed to get current process using /proc/self");
1728
1729        let start_open_fds = proc
1730            .fd_count()
1731            .expect("Failed to get open file descriptor count");
1732
1733        for _ in 0..100 {
1734            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
1735                .expect("Failed to create tensor");
1736            let mut map = tensor.map().unwrap();
1737            map.as_mut_slice().fill(233);
1738        }
1739
1740        let end_open_fds = proc
1741            .fd_count()
1742            .expect("Failed to get open file descriptor count");
1743
1744        assert_eq!(
1745            start_open_fds, end_open_fds,
1746            "File descriptor leak detected: {} -> {}",
1747            start_open_fds, end_open_fds
1748        );
1749    }
1750
1751    #[test]
1752    #[cfg(target_os = "linux")]
1753    fn test_dma_from_fd_no_fd_leaks() {
1754        let _lock = FD_LOCK.write().unwrap();
1755        if !is_dma_available() {
1756            log::warn!(
1757                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
1758                function!()
1759            );
1760            return;
1761        }
1762
1763        let proc = procfs::process::Process::myself()
1764            .expect("Failed to get current process using /proc/self");
1765
1766        let start_open_fds = proc
1767            .fd_count()
1768            .expect("Failed to get open file descriptor count");
1769
1770        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
1771
1772        for _ in 0..100 {
1773            let tensor =
1774                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1775            let mut map = tensor.map().unwrap();
1776            map.as_mut_slice().fill(233);
1777        }
1778        drop(orig);
1779
1780        let end_open_fds = proc.fd_count().unwrap();
1781
1782        assert_eq!(
1783            start_open_fds, end_open_fds,
1784            "File descriptor leak detected: {} -> {}",
1785            start_open_fds, end_open_fds
1786        );
1787    }
1788
1789    #[test]
1790    #[cfg(target_os = "linux")]
1791    fn test_shm_no_fd_leaks() {
1792        let _lock = FD_LOCK.write().unwrap();
1793        if !is_shm_available() {
1794            log::warn!(
1795                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1796                function!()
1797            );
1798            return;
1799        }
1800
1801        let proc = procfs::process::Process::myself()
1802            .expect("Failed to get current process using /proc/self");
1803
1804        let start_open_fds = proc
1805            .fd_count()
1806            .expect("Failed to get open file descriptor count");
1807
1808        for _ in 0..100 {
1809            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1810                .expect("Failed to create tensor");
1811            let mut map = tensor.map().unwrap();
1812            map.as_mut_slice().fill(233);
1813        }
1814
1815        let end_open_fds = proc
1816            .fd_count()
1817            .expect("Failed to get open file descriptor count");
1818
1819        assert_eq!(
1820            start_open_fds, end_open_fds,
1821            "File descriptor leak detected: {} -> {}",
1822            start_open_fds, end_open_fds
1823        );
1824    }
1825
1826    #[test]
1827    #[cfg(target_os = "linux")]
1828    fn test_shm_from_fd_no_fd_leaks() {
1829        let _lock = FD_LOCK.write().unwrap();
1830        if !is_shm_available() {
1831            log::warn!(
1832                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1833                function!()
1834            );
1835            return;
1836        }
1837
1838        let proc = procfs::process::Process::myself()
1839            .expect("Failed to get current process using /proc/self");
1840
1841        let start_open_fds = proc
1842            .fd_count()
1843            .expect("Failed to get open file descriptor count");
1844
1845        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
1846
1847        for _ in 0..100 {
1848            let tensor =
1849                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1850            let mut map = tensor.map().unwrap();
1851            map.as_mut_slice().fill(233);
1852        }
1853        drop(orig);
1854
1855        let end_open_fds = proc.fd_count().unwrap();
1856
1857        assert_eq!(
1858            start_open_fds, end_open_fds,
1859            "File descriptor leak detected: {} -> {}",
1860            start_open_fds, end_open_fds
1861        );
1862    }
1863
1864    #[cfg(feature = "ndarray")]
1865    #[test]
1866    fn test_ndarray() {
1867        let _lock = FD_LOCK.read().unwrap();
1868        let shape = vec![2, 3, 4];
1869        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1870
1871        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
1872        tensor_map.fill(1.0);
1873
1874        let view = tensor_map.view().expect("Failed to get ndarray view");
1875        assert_eq!(view.shape(), &[2, 3, 4]);
1876        assert!(view.iter().all(|&x| x == 1.0));
1877
1878        let mut view_mut = tensor_map
1879            .view_mut()
1880            .expect("Failed to get mutable ndarray view");
1881        view_mut[[0, 0, 0]] = 42.0;
1882        assert_eq!(view_mut[[0, 0, 0]], 42.0);
1883        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
1884    }
1885
1886    #[test]
1887    fn test_buffer_identity_unique() {
1888        let id1 = BufferIdentity::new();
1889        let id2 = BufferIdentity::new();
1890        assert_ne!(
1891            id1.id(),
1892            id2.id(),
1893            "Two identities should have different ids"
1894        );
1895    }
1896
1897    #[test]
1898    fn test_buffer_identity_clone_shares_guard() {
1899        let id1 = BufferIdentity::new();
1900        let weak = id1.weak();
1901        assert!(
1902            weak.upgrade().is_some(),
1903            "Weak should be alive while original exists"
1904        );
1905
1906        let id2 = id1.clone();
1907        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1908
1909        drop(id1);
1910        assert!(
1911            weak.upgrade().is_some(),
1912            "Weak should still be alive (clone holds Arc)"
1913        );
1914
1915        drop(id2);
1916        assert!(
1917            weak.upgrade().is_none(),
1918            "Weak should be dead after all clones dropped"
1919        );
1920    }
1921
1922    #[test]
1923    fn test_tensor_buffer_identity() {
1924        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1925        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1926        assert_ne!(
1927            t1.buffer_identity().id(),
1928            t2.buffer_identity().id(),
1929            "Different tensors should have different buffer ids"
1930        );
1931    }
1932
1933    // Any test that cares about the fd count must grab it exclusively.
1934    // Any tests which modifies the fd count by opening or closing fds must grab it
1935    // shared.
1936    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1937
1938    /// Test that DMA is NOT available on non-Linux platforms.
1939    /// This verifies the cross-platform behavior of is_dma_available().
1940    #[test]
1941    #[cfg(not(target_os = "linux"))]
1942    fn test_dma_not_available_on_non_linux() {
1943        assert!(
1944            !is_dma_available(),
1945            "DMA memory allocation should NOT be available on non-Linux platforms"
1946        );
1947    }
1948
1949    /// Test that SHM memory allocation is available and usable on Unix systems.
1950    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
1951    #[test]
1952    #[cfg(unix)]
1953    fn test_shm_available_and_usable() {
1954        assert!(
1955            is_shm_available(),
1956            "SHM memory allocation should be available on Unix systems"
1957        );
1958
1959        // Create a tensor with SHM backing
1960        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1961            .expect("Failed to create SHM tensor");
1962
1963        // Verify we can map and write to it
1964        let mut map = tensor.map().expect("Failed to map SHM tensor");
1965        map.as_mut_slice().fill(0xAB);
1966
1967        // Verify the data was written correctly
1968        assert!(
1969            map.as_slice().iter().all(|&b| b == 0xAB),
1970            "SHM tensor data should be writable and readable"
1971        );
1972    }
1973}