// edgefirst_tensor/lib.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
21
22## Overview
23The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
24which define the behavior of Tensors and their memory mappings, respectively.
25The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
26while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
27wraps `Tensor<T>` for runtime element-type dispatch.
28 */
29#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54    fmt,
55    ops::{Deref, DerefMut},
56    sync::{
57        atomic::{AtomicU64, Ordering},
58        Arc, Weak,
59    },
60};
61pub use tensor_dyn::TensorDyn;
62
63/// Per-plane DMA-BUF descriptor for external buffer import.
64///
65/// Owns a duplicated file descriptor plus optional stride and offset metadata.
66/// The fd is duplicated eagerly in [`new()`](Self::new) so that a bad fd is
67/// caught immediately. `import_image` consumes the descriptor and takes
68/// ownership of the duped fd — no further cleanup is needed by the caller.
69///
70/// # Examples
71///
72/// ```rust,no_run
73/// use edgefirst_tensor::PlaneDescriptor;
74/// use std::os::fd::BorrowedFd;
75///
76/// // SAFETY: fd 42 is hypothetical; real code must pass a valid fd.
77/// let pd = unsafe { PlaneDescriptor::new(BorrowedFd::borrow_raw(42)) }
78///     .unwrap()
79///     .with_stride(2048)
80///     .with_offset(0);
81/// ```
#[cfg(unix)]
pub struct PlaneDescriptor {
    /// Duplicated file descriptor, owned by this descriptor (duped in `new`).
    fd: OwnedFd,
    /// Row stride in bytes; `None` means unspecified / tightly packed.
    stride: Option<usize>,
    /// Plane byte offset within the buffer; `None` means unset (offset 0).
    offset: Option<usize>,
}
88
89#[cfg(unix)]
90impl PlaneDescriptor {
91    /// Create a new plane descriptor by duplicating the given file descriptor.
92    ///
93    /// The fd is duped immediately — a bad fd fails here rather than inside
94    /// `import_image`. The caller retains ownership of the original fd.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the `dup()` syscall fails (e.g. invalid fd or
99    /// fd limit reached).
100    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
101        let owned = fd.try_clone_to_owned()?;
102        Ok(Self {
103            fd: owned,
104            stride: None,
105            offset: None,
106        })
107    }
108
109    /// Set the row stride in bytes (consuming builder).
110    pub fn with_stride(mut self, stride: usize) -> Self {
111        self.stride = Some(stride);
112        self
113    }
114
115    /// Set the plane offset in bytes (consuming builder).
116    pub fn with_offset(mut self, offset: usize) -> Self {
117        self.offset = Some(offset);
118        self
119    }
120
121    /// Consume the descriptor and return the owned file descriptor.
122    pub fn into_fd(self) -> OwnedFd {
123        self.fd
124    }
125
126    /// Row stride in bytes, if set.
127    pub fn stride(&self) -> Option<usize> {
128        self.stride
129    }
130
131    /// Plane offset in bytes, if set.
132    pub fn offset(&self) -> Option<usize> {
133        self.offset
134    }
135}
136
/// Element type discriminant for runtime type identification.
///
/// `#[non_exhaustive]` allows new element types to be added without a
/// breaking change; `#[repr(u8)]` keeps the discriminant compact.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    /// Unsigned 8-bit integer.
    U8,
    /// Signed 8-bit integer.
    I8,
    /// Unsigned 16-bit integer.
    U16,
    /// Signed 16-bit integer.
    I16,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 32-bit integer.
    I32,
    /// Unsigned 64-bit integer.
    U64,
    /// Signed 64-bit integer.
    I64,
    /// 16-bit (half-precision) floating point.
    F16,
    /// 32-bit (single-precision) floating point.
    F32,
    /// 64-bit (double-precision) floating point.
    F64,
}
154
155impl DType {
156    /// Size of one element in bytes.
157    pub const fn size(&self) -> usize {
158        match self {
159            Self::U8 | Self::I8 => 1,
160            Self::U16 | Self::I16 | Self::F16 => 2,
161            Self::U32 | Self::I32 | Self::F32 => 4,
162            Self::U64 | Self::I64 | Self::F64 => 8,
163        }
164    }
165
166    /// Short type name (e.g., "u8", "f32", "f16").
167    pub const fn name(&self) -> &'static str {
168        match self {
169            Self::U8 => "u8",
170            Self::I8 => "i8",
171            Self::U16 => "u16",
172            Self::I16 => "i16",
173            Self::U32 => "u32",
174            Self::I32 => "i32",
175            Self::U64 => "u64",
176            Self::I64 => "i64",
177            Self::F16 => "f16",
178            Self::F32 => "f32",
179            Self::F64 => "f64",
180        }
181    }
182}
183
184impl fmt::Display for DType {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        f.write_str(self.name())
187    }
188}
189
/// Monotonic counter for buffer identity IDs; incremented with relaxed
/// ordering on every [`BufferIdentity::new`] call. Starts at 1.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
192
/// Unique identity for a tensor's underlying buffer.
///
/// Created fresh on every buffer allocation or import. The `id` is a monotonic
/// u64 used as a cache key. The `guard` is an `Arc<()>` whose weak references
/// allow downstream caches to detect when the buffer has been dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    /// Monotonically increasing identifier taken from `NEXT_BUFFER_ID`.
    id: u64,
    /// Liveness guard; [`BufferIdentity::weak`] hands out weak refs to it.
    guard: Arc<()>,
}
203
204impl BufferIdentity {
205    /// Create a new unique buffer identity.
206    pub fn new() -> Self {
207        Self {
208            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
209            guard: Arc::new(()),
210        }
211    }
212
213    /// Unique identifier for this buffer. Changes when the buffer changes.
214    pub fn id(&self) -> u64 {
215        self.id
216    }
217
218    /// Returns a weak reference to the buffer guard. Goes dead when the
219    /// owning Tensor is dropped (and no clones remain).
220    pub fn weak(&self) -> Weak<()> {
221        Arc::downgrade(&self.guard)
222    }
223}
224
225impl Default for BufferIdentity {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231#[cfg(target_os = "linux")]
232use nix::sys::stat::{major, minor};
233
/// Common interface implemented by every tensor backend (DMA, SHM, system
/// memory, PBO) and by the internal dispatching storage wrapper.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Create a new tensor with the given shape and optional name. If no name
    /// is given, a random name will be generated.
    ///
    /// # Errors
    ///
    /// Returns an error when the backing allocation cannot be created.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Create a new tensor using the given file descriptor, shape, and optional
    /// name. If no name is given, a random name will be generated.
    ///
    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
    /// On other Unix (macOS): Always creates SHM tensor.
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Clone the file descriptor associated with this tensor.
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Get the memory type of this tensor.
    fn memory(&self) -> TensorMemory;

    /// Get the name of this tensor.
    fn name(&self) -> String;

    /// Get the number of elements in this tensor (product of all dimensions
    /// of [`shape`](Self::shape)).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor is empty (zero elements).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor (`len() * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get the shape of this tensor.
    fn shape(&self) -> &[usize];

    /// Reshape this tensor to the given shape. The total number of elements
    /// must remain the same.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Map the tensor into memory and return a TensorMap for accessing the
    /// data.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Get the buffer identity for cache keying and liveness tracking.
    fn buffer_identity(&self) -> &BufferIdentity;
}
293
/// Access interface for a tensor that has been mapped into CPU-addressable
/// memory via [`TensorTrait::map`].
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Get the shape of this tensor map.
    fn shape(&self) -> &[usize];

    /// Unmap the tensor from memory.
    fn unmap(&mut self);

    /// Get the number of elements in this tensor map (product of all
    /// dimensions of [`shape`](Self::shape)).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor map is empty (zero elements).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor map (`len() * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get a slice to the data in this tensor map.
    fn as_slice(&self) -> &[T];

    /// Get a mutable slice to the data in this tensor map.
    fn as_mut_slice(&mut self) -> &mut [T];

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayView of the tensor data.
    ///
    /// # Errors
    ///
    /// Returns an error when the shape and slice length disagree
    /// (propagated from `ndarray::ArrayView::from_shape`).
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayViewMut of the tensor data.
    ///
    /// # Errors
    ///
    /// Returns an error when the shape and slice length disagree
    /// (propagated from `ndarray::ArrayViewMut::from_shape`).
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Copy the shape out first: `as_mut_slice` takes a mutable borrow of
        // `self`, which would conflict with borrowing the shape slice.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
346
/// Backing memory type used for a tensor allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    /// Direct Memory Access (DMA) allocation. Incurs additional
    /// overhead for memory reading/writing with the CPU.  Allows for
    /// hardware acceleration when supported.
    Dma,
    #[cfg(unix)]
    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    Shm,

    /// Regular system memory allocation; used as the final fallback when
    /// no other backend is available.
    Mem,

    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
    /// when DMA-buf is unavailable but OpenGL is present.
    Pbo,
}
366
367impl From<TensorMemory> for String {
368    fn from(memory: TensorMemory) -> Self {
369        match memory {
370            #[cfg(target_os = "linux")]
371            TensorMemory::Dma => "dma".to_owned(),
372            #[cfg(unix)]
373            TensorMemory::Shm => "shm".to_owned(),
374            TensorMemory::Mem => "mem".to_owned(),
375            TensorMemory::Pbo => "pbo".to_owned(),
376        }
377    }
378}
379
380impl TryFrom<&str> for TensorMemory {
381    type Error = Error;
382
383    fn try_from(s: &str) -> Result<Self> {
384        match s {
385            #[cfg(target_os = "linux")]
386            "dma" => Ok(TensorMemory::Dma),
387            #[cfg(unix)]
388            "shm" => Ok(TensorMemory::Shm),
389            "mem" => Ok(TensorMemory::Mem),
390            "pbo" => Ok(TensorMemory::Pbo),
391            _ => Err(Error::InvalidMemoryType(s.to_owned())),
392        }
393    }
394}
395
/// Internal backend dispatch for `Tensor<T>`: one variant per supported
/// allocation strategy.
#[derive(Debug)]
#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// DMA-BUF backed tensor (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    /// POSIX shared-memory backed tensor (Unix only).
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    /// System-memory backed tensor (always compiled in).
    Mem(MemTensor<T>),
    /// OpenGL pixel-buffer-object backed tensor.
    Pbo(PboTensor<T>),
}
409
impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Create a new tensor storage with the given shape, memory type, and
    /// optional name. If no name is given, a random name will be generated.
    /// If no memory type is given, the best available memory type will be
    /// chosen based on the platform and environment variables
    /// (`EDGEFIRST_TENSOR_FORCE_MEM` forces plain system memory).
    ///
    /// # Errors
    ///
    /// Fails when the explicitly requested backend cannot allocate, or when
    /// `TensorMemory::Pbo` is requested (PBO tensors are created elsewhere).
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => {
                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
            }
            #[cfg(unix)]
            Some(TensorMemory::Shm) => {
                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
            }
            Some(TensorMemory::Mem) => {
                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
            }
            // PBO tensors are tied to a GL context and cannot be allocated here.
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                // Any value other than "0"/"false" (case-insensitive) forces Mem.
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    #[cfg(target_os = "linux")]
                    {
                        // Linux: Try DMA -> SHM -> Mem
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name)
                                    .map(TensorStorage::Shm)
                                {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => MemTensor::<T>::new(shape, name)
                                        .map(TensorStorage::Mem),
                                }
                            }
                        }
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        // macOS/BSD: Try SHM -> Mem (no DMA)
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => {
                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                            }
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        // Windows/other: Mem only
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    /// Create a new tensor storage using the given file descriptor, shape,
    /// and optional name.
    ///
    /// On Linux the fd's device major/minor numbers (via `fstat`) select the
    /// backend: major 0 with minor 9 or 10 means DMA memory, any other minor
    /// is assumed to be shared memory, and a non-zero major is rejected.
    /// On other Unix platforms the fd is always treated as shared memory.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            if major != 0 {
                // Dma and Shm tensors are expected to have major number 0
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                9 | 10 => {
                    // minor number 9 & 10 indicates DMA memory
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    // other minor numbers are assumed to be shared memory
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            // On macOS/BSD, always use SHM (no DMA support)
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}
513
// Every method below is pure delegation: the trait's 2-arg `new`/3-arg
// `from_fd` forward to the inherent constructors (which pick the backend),
// and the rest dispatch on the active variant. The `#[cfg]` attributes on
// the match arms mirror the cfg-gating of the enum variants themselves.
impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    // Resolves to the inherent 3-arg `Self::new`, with automatic backend
    // selection (memory = None).
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    // Resolves to the inherent `Self::from_fd` (inherent methods take
    // precedence over trait methods in resolution).
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}
604
/// Multi-backend tensor with optional image format metadata.
///
/// When `format` is `Some`, this tensor represents an image. Width, height,
/// and channels are derived from `shape` + `format`. When `format` is `None`,
/// this is a raw tensor (identical to the pre-refactoring behavior).
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Backend storage holding the actual buffer.
    pub(crate) storage: TensorStorage<T>,
    /// Pixel format metadata; `None` for raw (non-image) tensors.
    format: Option<PixelFormat>,
    /// Separate chroma plane for multiplane semi-planar images (NV12/NV16);
    /// `None` for raw tensors and contiguous single-plane images.
    chroma: Option<Box<Tensor<T>>>,
    /// Row stride in bytes for externally allocated buffers with row padding.
    /// `None` means tightly packed (stride == width * bytes_per_pixel).
    row_stride: Option<usize>,
    /// Byte offset within the DMA-BUF where image data starts.
    /// `None` means offset 0 (data starts at the beginning of the buffer).
    plane_offset: Option<usize>,
}
625
626impl<T> Tensor<T>
627where
628    T: Num + Clone + fmt::Debug + Send + Sync,
629{
630    /// Wrap a TensorStorage in a Tensor with no image metadata.
631    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
632        Self {
633            storage,
634            format: None,
635            chroma: None,
636            row_stride: None,
637            plane_offset: None,
638        }
639    }
640
641    /// Create a new tensor with the given shape, memory type, and optional
642    /// name. If no name is given, a random name will be generated. If no
643    /// memory type is given, the best available memory type will be chosen
644    /// based on the platform and environment variables.
645    ///
646    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
647    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
648    /// On non-Unix platforms, only Mem is available.
649    ///
650    /// # Environment Variables
651    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
652    ///   value, forces the use of regular system memory allocation
653    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
654    ///
655    /// # Example
656    /// ```rust
657    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
658    /// # fn main() -> Result<(), Error> {
659    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
660    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
661    /// assert_eq!(tensor.name(), "test_tensor");
662    /// #    Ok(())
663    /// # }
664    /// ```
665    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
666        TensorStorage::new(shape, memory, name).map(Self::wrap)
667    }
668
669    /// Create an image tensor with the given format.
670    pub fn image(
671        width: usize,
672        height: usize,
673        format: PixelFormat,
674        memory: Option<TensorMemory>,
675    ) -> Result<Self> {
676        let shape = match format.layout() {
677            PixelLayout::Packed => vec![height, width, format.channels()],
678            PixelLayout::Planar => vec![format.channels(), height, width],
679            PixelLayout::SemiPlanar => {
680                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
681                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
682                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
683                let total_h = match format {
684                    PixelFormat::Nv12 => {
685                        if !height.is_multiple_of(2) {
686                            return Err(Error::InvalidArgument(format!(
687                                "NV12 requires even height, got {height}"
688                            )));
689                        }
690                        height * 3 / 2
691                    }
692                    PixelFormat::Nv16 => height * 2,
693                    _ => {
694                        return Err(Error::InvalidArgument(format!(
695                            "unknown semi-planar height multiplier for {format:?}"
696                        )))
697                    }
698                };
699                vec![total_h, width]
700            }
701        };
702        let mut t = Self::new(&shape, memory, None)?;
703        t.format = Some(format);
704        Ok(t)
705    }
706
    /// Attach format metadata to an existing tensor.
    ///
    /// # Arguments
    ///
    /// * `format` - The pixel format to attach
    ///
    /// # Returns
    ///
    /// `Ok(())` on success, with the format stored as metadata on the tensor.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
    /// the format's layout (packed expects `[H, W, C]`, planar expects
    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
    /// height constraints).
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                match format {
                    // NV12 contiguous stores H*3/2 rows (H even), so the
                    // total row count must be divisible by 3.
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    // NV16 contiguous stores H*2 rows, so it must be even.
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        // Clear stored stride/offset when format changes — they may be invalid
        // for the new format. Caller must re-set after changing format.
        if self.format != Some(format) {
            self.row_stride = None;
            self.plane_offset = None;
        }
        self.format = Some(format);
        Ok(())
    }
774
775    /// Pixel format (None if not an image).
776    pub fn format(&self) -> Option<PixelFormat> {
777        self.format
778    }
779
780    /// Image width (None if not an image).
781    pub fn width(&self) -> Option<usize> {
782        let fmt = self.format?;
783        let shape = self.shape();
784        match fmt.layout() {
785            PixelLayout::Packed => Some(shape[1]),
786            PixelLayout::Planar => Some(shape[2]),
787            PixelLayout::SemiPlanar => Some(shape[1]),
788        }
789    }
790
791    /// Image height (None if not an image).
792    pub fn height(&self) -> Option<usize> {
793        let fmt = self.format?;
794        let shape = self.shape();
795        match fmt.layout() {
796            PixelLayout::Packed => Some(shape[0]),
797            PixelLayout::Planar => Some(shape[1]),
798            PixelLayout::SemiPlanar => {
799                if self.is_multiplane() {
800                    Some(shape[0])
801                } else {
802                    match fmt {
803                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
804                        PixelFormat::Nv16 => Some(shape[0] / 2),
805                        _ => None,
806                    }
807                }
808            }
809        }
810    }
811
    /// Create from separate Y and UV planes (multiplane NV12/NV16).
    ///
    /// The luma tensor becomes the primary storage and the chroma tensor is
    /// attached as the secondary plane; the luma's stride/offset metadata is
    /// carried over.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` when the format is not semi-planar
    /// NV12/NV16, when the chroma tensor already carries format or chroma
    /// metadata, when either shape is not 2D, when the plane widths differ,
    /// or when the chroma height does not match the format's subsampling
    /// (NV12: luma height even and chroma height == luma height / 2;
    /// NV16: chroma height == luma height).
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        // The chroma plane must be a plain raw tensor; nested metadata would
        // be ignored and likely indicates caller error.
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        // Validate plane heights against the format's chroma subsampling.
        match format {
            PixelFormat::Nv12 => {
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
            row_stride: luma.row_stride,
            plane_offset: luma.plane_offset,
        })
    }
876
877    /// Whether this tensor uses separate plane allocations.
878    pub fn is_multiplane(&self) -> bool {
879        self.chroma.is_some()
880    }
881
882    /// Access the chroma plane for multiplane semi-planar images.
883    pub fn chroma(&self) -> Option<&Tensor<T>> {
884        self.chroma.as_deref()
885    }
886
887    /// Mutable access to the chroma plane for multiplane semi-planar images.
888    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
889        self.chroma.as_deref_mut()
890    }
891
892    /// Row stride in bytes (`None` = tightly packed).
893    pub fn row_stride(&self) -> Option<usize> {
894        self.row_stride
895    }
896
897    /// Effective row stride in bytes: the stored stride if set, otherwise the
898    /// minimum stride computed from the format, width, and element size.
899    /// Returns `None` only when no format is set and no explicit stride was
900    /// stored via [`set_row_stride`](Self::set_row_stride).
901    pub fn effective_row_stride(&self) -> Option<usize> {
902        if let Some(s) = self.row_stride {
903            return Some(s);
904        }
905        let fmt = self.format?;
906        let w = self.width()?;
907        let elem = std::mem::size_of::<T>();
908        Some(match fmt.layout() {
909            PixelLayout::Packed => w * fmt.channels() * elem,
910            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
911        })
912    }
913
914    /// Set the row stride in bytes for externally allocated buffers with
915    /// row padding (e.g. V4L2 or GStreamer allocators).
916    ///
917    /// The stride is propagated to the EGL DMA-BUF import attributes so
918    /// the GPU interprets the padded buffer layout correctly. Must be
919    /// called after [`set_format`](Self::set_format) and before the tensor
920    /// is first passed to [`ImageProcessor::convert`]. The stored stride
921    /// is cleared automatically if the pixel format is later changed.
922    ///
923    /// No stride-vs-buffer-size validation is performed because the
924    /// backing allocation size is not reliably known: external DMA-BUFs
925    /// may be over-allocated by the allocator, and internal tensors store
926    /// a logical (unpadded) shape. An incorrect stride will be caught by
927    /// the EGL driver at import time.
928    ///
929    /// # Arguments
930    ///
931    /// * `stride` - Row stride in bytes. Must be >= the minimum stride for
932    ///   the format (width * channels * sizeof(T) for packed,
933    ///   width * sizeof(T) for planar/semi-planar).
934    ///
935    /// # Errors
936    ///
937    /// * `InvalidArgument` if no pixel format is set on this tensor
938    /// * `InvalidArgument` if `stride` is less than the minimum for the
939    ///   format and width
940    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
941        let fmt = self.format.ok_or_else(|| {
942            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
943        })?;
944        let w = self.width().ok_or_else(|| {
945            Error::InvalidArgument("cannot determine width for row_stride validation".into())
946        })?;
947        let elem = std::mem::size_of::<T>();
948        let min_stride = match fmt.layout() {
949            PixelLayout::Packed => w * fmt.channels() * elem,
950            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
951        };
952        if stride < min_stride {
953            return Err(Error::InvalidArgument(format!(
954                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
955            )));
956        }
957        self.row_stride = Some(stride);
958        Ok(())
959    }
960
961    /// Set the row stride without format validation.
962    ///
963    /// Use this for raw sub-tensors (e.g. chroma planes) that don't carry
964    /// format metadata. The caller is responsible for ensuring the stride
965    /// is valid.
966    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
967        self.row_stride = Some(stride);
968    }
969
970    /// Builder-style variant of [`set_row_stride`](Self::set_row_stride),
971    /// consuming and returning `self`.
972    ///
973    /// # Errors
974    ///
975    /// Same conditions as [`set_row_stride`](Self::set_row_stride).
976    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
977        self.set_row_stride(stride)?;
978        Ok(self)
979    }
980
981    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
982    pub fn plane_offset(&self) -> Option<usize> {
983        self.plane_offset
984    }
985
986    /// Set the byte offset within the DMA-BUF where image data starts.
987    ///
988    /// Propagated to `EGL_DMA_BUF_PLANE0_OFFSET_EXT` on GPU import.
989    /// Unlike [`set_row_stride`](Self::set_row_stride), no format is required
990    /// since the offset is format-independent.
991    pub fn set_plane_offset(&mut self, offset: usize) {
992        self.plane_offset = Some(offset);
993    }
994
995    /// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset),
996    /// consuming and returning `self`.
997    pub fn with_plane_offset(mut self, offset: usize) -> Self {
998        self.set_plane_offset(offset);
999        self
1000    }
1001
1002    /// Downcast to PBO tensor reference (for GL backends).
1003    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
1004        match &self.storage {
1005            TensorStorage::Pbo(p) => Some(p),
1006            _ => None,
1007        }
1008    }
1009
1010    /// Downcast to DMA tensor reference (for EGL import, G2D).
1011    #[cfg(target_os = "linux")]
1012    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
1013        match &self.storage {
1014            TensorStorage::Dma(d) => Some(d),
1015            _ => None,
1016        }
1017    }
1018
1019    /// Borrow the DMA-BUF file descriptor backing this tensor.
1020    ///
1021    /// # Returns
1022    ///
1023    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
1024    /// lifetime.
1025    ///
1026    /// # Errors
1027    ///
1028    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
1029    #[cfg(target_os = "linux")]
1030    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
1031        use std::os::fd::AsFd;
1032        match &self.storage {
1033            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
1034            _ => Err(Error::NotImplemented(format!(
1035                "dmabuf requires DMA-backed tensor, got {:?}",
1036                self.storage.memory()
1037            ))),
1038        }
1039    }
1040
1041    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
1042    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
1043        Self {
1044            storage: TensorStorage::Pbo(pbo),
1045            format: None,
1046            chroma: None,
1047            row_stride: None,
1048            plane_offset: None,
1049        }
1050    }
1051}
1052
1053impl<T> TensorTrait<T> for Tensor<T>
1054where
1055    T: Num + Clone + fmt::Debug + Send + Sync,
1056{
1057    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
1058    where
1059        Self: Sized,
1060    {
1061        Self::new(shape, None, name)
1062    }
1063
1064    #[cfg(unix)]
1065    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
1066    where
1067        Self: Sized,
1068    {
1069        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
1070    }
1071
1072    #[cfg(unix)]
1073    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
1074        self.storage.clone_fd()
1075    }
1076
1077    fn memory(&self) -> TensorMemory {
1078        self.storage.memory()
1079    }
1080
1081    fn name(&self) -> String {
1082        self.storage.name()
1083    }
1084
1085    fn shape(&self) -> &[usize] {
1086        self.storage.shape()
1087    }
1088
1089    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
1090        if self.chroma.is_some() {
1091            return Err(Error::InvalidOperation(
1092                "cannot reshape a multiplane tensor — decompose planes first".into(),
1093            ));
1094        }
1095        self.storage.reshape(shape)?;
1096        self.format = None;
1097        self.row_stride = None;
1098        self.plane_offset = None;
1099        Ok(())
1100    }
1101
1102    fn map(&self) -> Result<TensorMap<T>> {
1103        if self.row_stride.is_some() {
1104            return Err(Error::InvalidOperation(
1105                "CPU mapping of strided tensors is not supported; use GPU path only".into(),
1106            ));
1107        }
1108        if self.plane_offset.is_some_and(|o| o > 0) {
1109            return Err(Error::InvalidOperation(
1110                "CPU mapping of offset tensors is not supported; use GPU path only".into(),
1111            ));
1112        }
1113        self.storage.map()
1114    }
1115
1116    fn buffer_identity(&self) -> &BufferIdentity {
1117        self.storage.buffer_identity()
1118    }
1119}
1120
/// CPU-side mapping of a tensor's backing storage, one variant per memory
/// backend. Obtained via [`TensorTrait::map`]; element access is provided
/// through [`TensorMapTrait`], `Deref`, and `DerefMut`.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Mapping of a Linux DMA-BUF allocation.
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// Mapping of a POSIX shared-memory allocation.
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Mapping of a plain system-memory allocation.
    Mem(MemMap<T>),
    /// Mapping of a GL pixel-buffer-object allocation.
    Pbo(PboMap<T>),
}
1132
1133impl<T> TensorMapTrait<T> for TensorMap<T>
1134where
1135    T: Num + Clone + fmt::Debug,
1136{
1137    fn shape(&self) -> &[usize] {
1138        match self {
1139            #[cfg(target_os = "linux")]
1140            TensorMap::Dma(map) => map.shape(),
1141            #[cfg(unix)]
1142            TensorMap::Shm(map) => map.shape(),
1143            TensorMap::Mem(map) => map.shape(),
1144            TensorMap::Pbo(map) => map.shape(),
1145        }
1146    }
1147
1148    fn unmap(&mut self) {
1149        match self {
1150            #[cfg(target_os = "linux")]
1151            TensorMap::Dma(map) => map.unmap(),
1152            #[cfg(unix)]
1153            TensorMap::Shm(map) => map.unmap(),
1154            TensorMap::Mem(map) => map.unmap(),
1155            TensorMap::Pbo(map) => map.unmap(),
1156        }
1157    }
1158
1159    fn as_slice(&self) -> &[T] {
1160        match self {
1161            #[cfg(target_os = "linux")]
1162            TensorMap::Dma(map) => map.as_slice(),
1163            #[cfg(unix)]
1164            TensorMap::Shm(map) => map.as_slice(),
1165            TensorMap::Mem(map) => map.as_slice(),
1166            TensorMap::Pbo(map) => map.as_slice(),
1167        }
1168    }
1169
1170    fn as_mut_slice(&mut self) -> &mut [T] {
1171        match self {
1172            #[cfg(target_os = "linux")]
1173            TensorMap::Dma(map) => map.as_mut_slice(),
1174            #[cfg(unix)]
1175            TensorMap::Shm(map) => map.as_mut_slice(),
1176            TensorMap::Mem(map) => map.as_mut_slice(),
1177            TensorMap::Pbo(map) => map.as_mut_slice(),
1178        }
1179    }
1180}
1181
1182impl<T> Deref for TensorMap<T>
1183where
1184    T: Num + Clone + fmt::Debug,
1185{
1186    type Target = [T];
1187
1188    fn deref(&self) -> &[T] {
1189        match self {
1190            #[cfg(target_os = "linux")]
1191            TensorMap::Dma(map) => map.deref(),
1192            #[cfg(unix)]
1193            TensorMap::Shm(map) => map.deref(),
1194            TensorMap::Mem(map) => map.deref(),
1195            TensorMap::Pbo(map) => map.deref(),
1196        }
1197    }
1198}
1199
1200impl<T> DerefMut for TensorMap<T>
1201where
1202    T: Num + Clone + fmt::Debug,
1203{
1204    fn deref_mut(&mut self) -> &mut [T] {
1205        match self {
1206            #[cfg(target_os = "linux")]
1207            TensorMap::Dma(map) => map.deref_mut(),
1208            #[cfg(unix)]
1209            TensorMap::Shm(map) => map.deref_mut(),
1210            TensorMap::Mem(map) => map.deref_mut(),
1211            TensorMap::Pbo(map) => map.deref_mut(),
1212        }
1213    }
1214}
1215
1216// ============================================================================
1217// Platform availability helpers
1218// ============================================================================
1219
/// Cached result of the DMA availability probe performed by
/// [`is_dma_available`]; written once on first query and reused afterwards.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1229
1230/// Check if DMA memory allocation is available on this system.
1231#[cfg(target_os = "linux")]
1232pub fn is_dma_available() -> bool {
1233    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1234}
1235
/// Check if DMA memory allocation is available on this system.
///
/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
1243
/// Cached result of the shared-memory availability probe performed by
/// [`is_shm_available`]; written once on first query and reused afterwards.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1252
1253/// Check if POSIX shared memory allocation is available on this system.
1254#[cfg(unix)]
1255pub fn is_shm_available() -> bool {
1256    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
1257}
1258
/// Check if POSIX shared memory allocation is available on this system.
///
/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
1266
#[cfg(test)]
mod dtype_tests {
    use super::*;

    /// Every `DType` reports the byte width of its corresponding Rust type.
    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    /// `name()` returns the lowercase Rust-style type name.
    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    /// `DType` survives a JSON serialize/deserialize round trip.
    ///
    /// Note: the redundant `use serde_json;` was removed — the crate path is
    /// already usable directly (clippy: single_component_path_imports).
    #[test]
    fn dtype_serde_roundtrip() {
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1302
#[cfg(test)]
mod image_tests {
    use super::*;

    /// A plain allocation carries no image metadata at all.
    #[test]
    fn raw_tensor_has_no_format() {
        let tensor = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(tensor.format().is_none());
        assert!(tensor.width().is_none());
        assert!(tensor.height().is_none());
        assert!(!tensor.is_multiplane());
        assert!(tensor.chroma().is_none());
    }

    /// Packed RGBA images use an [H, W, C] shape with C = 4.
    #[test]
    fn image_tensor_packed() {
        let img = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Rgba));
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
        assert_eq!(img.shape(), &[480, 640, 4]);
        assert!(!img.is_multiplane());
    }

    /// Planar RGB images use a [C, H, W] shape.
    #[test]
    fn image_tensor_planar() {
        let img = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
        assert_eq!(img.shape(), &[3, 480, 640]);
    }

    /// Contiguous NV12 stores luma and chroma in a single allocation.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let img = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
        // NV12: H*3/2 = 720
        assert_eq!(img.shape(), &[720, 640]);
        assert!(!img.is_multiplane());
    }

    /// A shape-compatible format can be attached after allocation.
    #[test]
    fn set_format_valid() {
        let mut tensor = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(tensor.format().is_none());
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Rgb));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    /// A shape-incompatible format is rejected without mutating the tensor.
    #[test]
    fn set_format_invalid_shape() {
        let mut tensor = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        // RGB expects 3 channels, not 4
        assert!(tensor.set_format(PixelFormat::Rgb).is_err());
        // Original tensor is unmodified
        assert!(tensor.format().is_none());
    }

    /// Reshaping drops the attached format metadata.
    #[test]
    fn reshape_clears_format() {
        let mut img = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Rgba));
        // Reshape to flat — format cleared
        img.reshape(&[480 * 640 * 4]).unwrap();
        assert!(img.format().is_none());
    }

    /// Assembling NV12 from separate luma/chroma planes yields a
    /// multiplane image with the expected dimensions.
    #[test]
    fn from_planes_nv12() {
        let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    /// `from_planes` rejects formats that are not semi-planar.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        assert!(Tensor::from_planes(luma, chroma, PixelFormat::Rgb).is_err());
    }

    /// Multiplane images cannot be reshaped while still assembled.
    #[test]
    fn reshape_multiplane_errors() {
        let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        assert!(img.reshape(&[480 * 640 + 240 * 640]).is_err());
    }
}
1405
1406#[cfg(test)]
1407mod tests {
1408    #[cfg(target_os = "linux")]
1409    use nix::unistd::{access, AccessFlags};
1410    #[cfg(target_os = "linux")]
1411    use std::io::Write as _;
1412    use std::sync::RwLock;
1413
1414    use super::*;
1415
1416    #[ctor::ctor]
1417    fn init() {
1418        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
1419    }
1420
    /// Macro to get the current function name for logging in tests.
    ///
    /// Works by taking `type_name` of a local `fn f`, which yields a path
    /// such as `crate::tests::test_name::f`, then trimming the trailing
    /// `::f` and everything up to the last path separator.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // Drop the trailing "::f" (3 chars), then keep only the final
            // path segment — the enclosing test function's name. When no
            // separator is found, return the whole trimmed path.
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
1438
1439    #[test]
1440    #[cfg(target_os = "linux")]
1441    fn test_tensor() {
1442        let _lock = FD_LOCK.read().unwrap();
1443        let shape = vec![1];
1444        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1445        let dma_enabled = tensor.is_ok();
1446
1447        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1448        match dma_enabled {
1449            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1450            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1451        }
1452    }
1453
1454    #[test]
1455    #[cfg(all(unix, not(target_os = "linux")))]
1456    fn test_tensor() {
1457        let shape = vec![1];
1458        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1459        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
1460        assert!(
1461            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1462            "Expected SHM or Mem on macOS, got {:?}",
1463            tensor.memory()
1464        );
1465    }
1466
1467    #[test]
1468    #[cfg(not(unix))]
1469    fn test_tensor() {
1470        let shape = vec![1];
1471        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1472        assert_eq!(tensor.memory(), TensorMemory::Mem);
1473    }
1474
    /// End-to-end exercise of DMA-backed tensors: allocation, metadata,
    /// mapping, fd sharing, and reshape rules. Skips (with a warning) when
    /// no DMA heap is accessible.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Probe for an accessible DMA heap; bail out early when neither the
        // CMA heap nor the system heap can be opened read/write.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        // Basic metadata: memory type, name, shape, byte size, element count.
        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Map, fill, and read back through the same mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer through a duplicated fd; writes via the shared
        // tensor must be visible through the original tensor below.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor observes the value written via the shared fd.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape must preserve the total element count: a size-changing
        // reshape fails and leaves the shape untouched.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        // Mapped data supports bulk fill ...
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // ... and per-element indexing, with earlier writes persisting
        // across re-mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1572
    /// End-to-end exercise of SHM-backed tensors: allocation, metadata,
    /// mapping, fd sharing, and reshape rules. Mirrors `test_dma_tensor`.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Map, fill, and read back through the same mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer through a duplicated fd; writes via the shared
        // tensor must be visible through the original tensor below.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor observes the value written via the shared fd.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape must preserve the total element count: a size-changing
        // reshape fails and leaves the shape untouched.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        // Mapped data supports bulk fill ...
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // ... and per-element indexing, with earlier writes persisting
        // across re-mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1645
1646    #[test]
1647    fn test_mem_tensor() {
1648        let shape = vec![2, 3, 4];
1649        let tensor =
1650            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1651        assert_eq!(tensor.shape(), &shape);
1652        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1653        assert_eq!(tensor.name(), "test_tensor");
1654
1655        {
1656            let mut tensor_map = tensor.map().expect("Failed to map memory");
1657            tensor_map.fill(42.0);
1658            assert!(tensor_map.iter().all(|&x| x == 42.0));
1659        }
1660
1661        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1662        assert_eq!(tensor.shape(), &shape);
1663        let new_shape = vec![3, 4, 4];
1664        assert!(
1665            tensor.reshape(&new_shape).is_err(),
1666            "Reshape should fail due to size mismatch"
1667        );
1668        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1669
1670        let new_shape = vec![2, 3, 4];
1671        tensor.reshape(&new_shape).expect("Reshape should succeed");
1672        assert_eq!(
1673            tensor.shape(),
1674            &new_shape,
1675            "Shape should be updated after successful reshape"
1676        );
1677
1678        {
1679            let mut tensor_map = tensor.map().expect("Failed to map memory");
1680            tensor_map.fill(1);
1681            assert!(tensor_map.iter().all(|&x| x == 1));
1682        }
1683
1684        {
1685            let mut tensor_map = tensor.map().expect("Failed to map memory");
1686            tensor_map[2] = 42;
1687            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1688            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1689        }
1690    }
1691
    /// Repeatedly allocate/map/drop DMA tensors and assert the process's
    /// open-fd count is unchanged, i.e. no descriptor leaks.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        // Write lock: fd counting must not race other tests that open fds.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        // Baseline fd count before any tensor churn.
        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, writes, and drops a DMA tensor.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1728
    /// Repeatedly duplicate a DMA tensor's fd via `from_fd` and drop the
    /// clone; the process's open-fd count must return to its baseline.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        // Write lock: fd counting must not race other tests that open fds.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        // Baseline fd count before the source tensor is created; the source
        // is dropped before the final count so its fd is accounted for.
        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        // Each iteration clones the fd, wraps it, maps, writes, and drops.
        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1766
1767    #[test]
1768    #[cfg(target_os = "linux")]
1769    fn test_shm_no_fd_leaks() {
1770        let _lock = FD_LOCK.write().unwrap();
1771        if !is_shm_available() {
1772            log::warn!(
1773                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1774                function!()
1775            );
1776            return;
1777        }
1778
1779        let proc = procfs::process::Process::myself()
1780            .expect("Failed to get current process using /proc/self");
1781
1782        let start_open_fds = proc
1783            .fd_count()
1784            .expect("Failed to get open file descriptor count");
1785
1786        for _ in 0..100 {
1787            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1788                .expect("Failed to create tensor");
1789            let mut map = tensor.map().unwrap();
1790            map.as_mut_slice().fill(233);
1791        }
1792
1793        let end_open_fds = proc
1794            .fd_count()
1795            .expect("Failed to get open file descriptor count");
1796
1797        assert_eq!(
1798            start_open_fds, end_open_fds,
1799            "File descriptor leak detected: {} -> {}",
1800            start_open_fds, end_open_fds
1801        );
1802    }
1803
1804    #[test]
1805    #[cfg(target_os = "linux")]
1806    fn test_shm_from_fd_no_fd_leaks() {
1807        let _lock = FD_LOCK.write().unwrap();
1808        if !is_shm_available() {
1809            log::warn!(
1810                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1811                function!()
1812            );
1813            return;
1814        }
1815
1816        let proc = procfs::process::Process::myself()
1817            .expect("Failed to get current process using /proc/self");
1818
1819        let start_open_fds = proc
1820            .fd_count()
1821            .expect("Failed to get open file descriptor count");
1822
1823        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
1824
1825        for _ in 0..100 {
1826            let tensor =
1827                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1828            let mut map = tensor.map().unwrap();
1829            map.as_mut_slice().fill(233);
1830        }
1831        drop(orig);
1832
1833        let end_open_fds = proc.fd_count().unwrap();
1834
1835        assert_eq!(
1836            start_open_fds, end_open_fds,
1837            "File descriptor leak detected: {} -> {}",
1838            start_open_fds, end_open_fds
1839        );
1840    }
1841
1842    #[cfg(feature = "ndarray")]
1843    #[test]
1844    fn test_ndarray() {
1845        let _lock = FD_LOCK.read().unwrap();
1846        let shape = vec![2, 3, 4];
1847        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1848
1849        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
1850        tensor_map.fill(1.0);
1851
1852        let view = tensor_map.view().expect("Failed to get ndarray view");
1853        assert_eq!(view.shape(), &[2, 3, 4]);
1854        assert!(view.iter().all(|&x| x == 1.0));
1855
1856        let mut view_mut = tensor_map
1857            .view_mut()
1858            .expect("Failed to get mutable ndarray view");
1859        view_mut[[0, 0, 0]] = 42.0;
1860        assert_eq!(view_mut[[0, 0, 0]], 42.0);
1861        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
1862    }
1863
1864    #[test]
1865    fn test_buffer_identity_unique() {
1866        let id1 = BufferIdentity::new();
1867        let id2 = BufferIdentity::new();
1868        assert_ne!(
1869            id1.id(),
1870            id2.id(),
1871            "Two identities should have different ids"
1872        );
1873    }
1874
1875    #[test]
1876    fn test_buffer_identity_clone_shares_guard() {
1877        let id1 = BufferIdentity::new();
1878        let weak = id1.weak();
1879        assert!(
1880            weak.upgrade().is_some(),
1881            "Weak should be alive while original exists"
1882        );
1883
1884        let id2 = id1.clone();
1885        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1886
1887        drop(id1);
1888        assert!(
1889            weak.upgrade().is_some(),
1890            "Weak should still be alive (clone holds Arc)"
1891        );
1892
1893        drop(id2);
1894        assert!(
1895            weak.upgrade().is_none(),
1896            "Weak should be dead after all clones dropped"
1897        );
1898    }
1899
1900    #[test]
1901    fn test_tensor_buffer_identity() {
1902        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1903        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1904        assert_ne!(
1905            t1.buffer_identity().id(),
1906            t2.buffer_identity().id(),
1907            "Different tensors should have different buffer ids"
1908        );
1909    }
1910
    // Synchronizes tests that interact with the process-wide open-fd count:
    // - Any test that *measures* the fd count (the `*_no_fd_leaks` tests) must
    //   grab the lock exclusively (write) so nothing opens or closes fds
    //   between its two counts.
    // - Any test which *modifies* the fd count by opening or closing fds must
    //   grab it shared (read); such tests may run concurrently with each other
    //   but never overlap a measurement.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1915
1916    /// Test that DMA is NOT available on non-Linux platforms.
1917    /// This verifies the cross-platform behavior of is_dma_available().
1918    #[test]
1919    #[cfg(not(target_os = "linux"))]
1920    fn test_dma_not_available_on_non_linux() {
1921        assert!(
1922            !is_dma_available(),
1923            "DMA memory allocation should NOT be available on non-Linux platforms"
1924        );
1925    }
1926
1927    /// Test that SHM memory allocation is available and usable on Unix systems.
1928    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
1929    #[test]
1930    #[cfg(unix)]
1931    fn test_shm_available_and_usable() {
1932        assert!(
1933            is_shm_available(),
1934            "SHM memory allocation should be available on Unix systems"
1935        );
1936
1937        // Create a tensor with SHM backing
1938        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1939            .expect("Failed to create SHM tensor");
1940
1941        // Verify we can map and write to it
1942        let mut map = tensor.map().expect("Failed to map SHM tensor");
1943        map.as_mut_slice().fill(0xAB);
1944
1945        // Verify the data was written correctly
1946        assert!(
1947            map.as_slice().iter().all(|&b| b == 0xAB),
1948            "SHM tensor data should be writable and readable"
1949        );
1950    }
1951}