edgefirst_tensor/lib.rs

// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

/*!
EdgeFirst HAL - Tensor Module

The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.

## Examples
```rust
use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
# fn main() -> Result<(), Error> {
let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
assert_eq!(tensor.memory(), TensorMemory::Mem);
assert_eq!(tensor.name(), "test_tensor");
#    Ok(())
# }
```

## Overview
The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
which define the behavior of tensors and their memory mappings, respectively.
The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
wraps `Tensor<T>` for runtime element-type dispatch.
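
Mapping a tensor for CPU access goes through [`TensorTrait::map`], which returns a
`TensorMap` exposing the data as a slice. A minimal sketch using the always-available
system-memory backend:
```rust
use edgefirst_tensor::{Error, Tensor, TensorMapTrait, TensorMemory, TensorTrait};
# fn main() -> Result<(), Error> {
let tensor = Tensor::<u8>::new(&[2, 4], Some(TensorMemory::Mem), None)?;
let mut map = tensor.map()?;
map.as_mut_slice().fill(0); // zero all 2 × 4 = 8 elements through the mapping
assert_eq!(map.len(), 8);
#    Ok(())
# }
```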
 */
#[cfg(target_os = "linux")]
mod dma;
#[cfg(target_os = "linux")]
mod dmabuf;
mod error;
mod format;
mod mem;
mod pbo;
#[cfg(unix)]
mod shm;
mod tensor_dyn;

#[cfg(target_os = "linux")]
pub use crate::dma::{DmaMap, DmaTensor};
pub use crate::mem::{MemMap, MemTensor};
pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
#[cfg(unix)]
pub use crate::shm::{ShmMap, ShmTensor};
pub use error::{Error, Result};
pub use format::{PixelFormat, PixelLayout};
use num_traits::Num;
use serde::{Deserialize, Serialize};
#[cfg(unix)]
use std::os::fd::OwnedFd;
use std::{
    fmt,
    ops::{Deref, DerefMut},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Weak,
    },
};
pub use tensor_dyn::TensorDyn;

/// Per-plane DMA-BUF descriptor for external buffer import.
///
/// Owns a duplicated file descriptor plus optional stride and offset metadata.
/// The fd is duplicated eagerly in [`new()`](Self::new) so that a bad fd is
/// caught immediately. `import_image` consumes the descriptor and takes
/// ownership of the duped fd — no further cleanup is needed by the caller.
///
/// # Examples
///
/// ```rust,no_run
/// use edgefirst_tensor::PlaneDescriptor;
/// use std::os::fd::BorrowedFd;
///
/// // SAFETY: fd 42 is hypothetical; real code must pass a valid fd.
/// let pd = unsafe { PlaneDescriptor::new(BorrowedFd::borrow_raw(42)) }
///     .unwrap()
///     .with_stride(2048)
///     .with_offset(0);
/// ```
#[cfg(unix)]
pub struct PlaneDescriptor {
    fd: OwnedFd,
    stride: Option<usize>,
    offset: Option<usize>,
}

#[cfg(unix)]
impl PlaneDescriptor {
    /// Create a new plane descriptor by duplicating the given file descriptor.
    ///
    /// The fd is duped immediately — a bad fd fails here rather than inside
    /// `import_image`. The caller retains ownership of the original fd.
    ///
    /// # Errors
    ///
    /// Returns an error if the `dup()` syscall fails (e.g. invalid fd or
    /// fd limit reached).
    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
        let owned = fd.try_clone_to_owned()?;
        Ok(Self {
            fd: owned,
            stride: None,
            offset: None,
        })
    }

    /// Set the row stride in bytes (consuming builder).
    pub fn with_stride(mut self, stride: usize) -> Self {
        self.stride = Some(stride);
        self
    }

    /// Set the plane offset in bytes (consuming builder).
    pub fn with_offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    /// Consume the descriptor and return the owned file descriptor.
    pub fn into_fd(self) -> OwnedFd {
        self.fd
    }

    /// Row stride in bytes, if set.
    pub fn stride(&self) -> Option<usize> {
        self.stride
    }

    /// Plane offset in bytes, if set.
    pub fn offset(&self) -> Option<usize> {
        self.offset
    }
}

/// Element type discriminant for runtime type identification.
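///
/// # Examples
///
/// ```rust
/// use edgefirst_tensor::DType;
///
/// assert_eq!(DType::F32.size(), 4);
/// assert_eq!(DType::F16.name(), "f16");
/// assert_eq!(DType::U8.to_string(), "u8"); // Display delegates to name()
/// ```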
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    F16,
    F32,
    F64,
}

impl DType {
    /// Size of one element in bytes.
    pub const fn size(&self) -> usize {
        match self {
            Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
        }
    }

    /// Short type name (e.g., "u8", "f32", "f16").
    pub const fn name(&self) -> &'static str {
        match self {
            Self::U8 => "u8",
            Self::I8 => "i8",
            Self::U16 => "u16",
            Self::I16 => "i16",
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::U64 => "u64",
            Self::I64 => "i64",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
        }
    }
}

impl fmt::Display for DType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}

/// Monotonic counter for buffer identity IDs.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Unique identity for a tensor's underlying buffer.
///
/// Created fresh on every buffer allocation or import. The `id` is a monotonic
/// u64 used as a cache key. The `guard` is an `Arc<()>` whose weak references
/// allow downstream caches to detect when the buffer has been dropped.
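///
/// # Examples
///
/// A minimal liveness check through the weak guard:
/// ```rust
/// use edgefirst_tensor::BufferIdentity;
///
/// let identity = BufferIdentity::new();
/// let weak = identity.weak();
/// assert!(weak.upgrade().is_some()); // buffer still alive
/// drop(identity);
/// assert!(weak.upgrade().is_none()); // guard went dead with the identity
/// ```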
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Create a new unique buffer identity.
    pub fn new() -> Self {
        Self {
            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
            guard: Arc::new(()),
        }
    }

    /// Unique identifier for this buffer. Changes when the buffer changes.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Returns a weak reference to the buffer guard. Goes dead when the
    /// owning Tensor is dropped (and no clones remain).
    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}

impl Default for BufferIdentity {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(target_os = "linux")]
use nix::sys::stat::{major, minor};

pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Create a new tensor with the given shape and optional name. If no name
    /// is given, a random name will be generated.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Create a new tensor using the given file descriptor, shape, and optional
    /// name. If no name is given, a random name will be generated.
    ///
    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
    /// On other Unix (macOS): Always creates SHM tensor.
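    ///
    /// A sketch round-tripping a shared-memory tensor through its duplicated
    /// fd (assumes a platform where SHM allocation succeeds):
    /// ```rust,no_run
    /// use edgefirst_tensor::{Tensor, TensorMemory, TensorTrait};
    ///
    /// let src = Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).unwrap();
    /// let fd = src.clone_fd().unwrap();
    /// // `shared` refers to the same underlying buffer as `src`.
    /// let shared = Tensor::<u8>::from_fd(fd, &[64], None).unwrap();
    /// ```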
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Clone the file descriptor associated with this tensor.
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Get the memory type of this tensor.
    fn memory(&self) -> TensorMemory;

    /// Get the name of this tensor.
    fn name(&self) -> String;

    /// Get the number of elements in this tensor.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get the shape of this tensor.
    fn shape(&self) -> &[usize];

    /// Reshape this tensor to the given shape. The total number of elements
    /// must remain the same.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Map the tensor into memory and return a TensorMap for accessing the
    /// data.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Get the buffer identity for cache keying and liveness tracking.
    fn buffer_identity(&self) -> &BufferIdentity;
}

pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Get the shape of this tensor map.
    fn shape(&self) -> &[usize];

    /// Unmap the tensor from memory.
    fn unmap(&mut self);

    /// Get the number of elements in this tensor map.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor map is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor map.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get a slice to the data in this tensor map.
    fn as_slice(&self) -> &[T];

    /// Get a mutable slice to the data in this tensor map.
    fn as_mut_slice(&mut self) -> &mut [T];

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayView of the tensor data.
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayViewMut of the tensor data.
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}

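/// Memory backing used for a tensor allocation.
///
/// String conversions are provided in both directions; a round-trip through
/// the always-available `Mem` variant:
/// ```rust
/// use edgefirst_tensor::TensorMemory;
///
/// let mem = TensorMemory::try_from("mem").unwrap();
/// assert_eq!(String::from(mem), "mem");
/// assert!(TensorMemory::try_from("bogus").is_err());
/// ```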
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    /// Direct Memory Access (DMA) allocation. Incurs additional overhead
    /// for CPU reads and writes. Allows for hardware acceleration when
    /// supported.
    Dma,
    #[cfg(unix)]
    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    Shm,
    /// Regular system memory allocation.
    Mem,
    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
    /// when DMA-buf is unavailable but OpenGL is present.
    Pbo,
}

impl From<TensorMemory> for String {
    fn from(memory: TensorMemory) -> Self {
        match memory {
            #[cfg(target_os = "linux")]
            TensorMemory::Dma => "dma".to_owned(),
            #[cfg(unix)]
            TensorMemory::Shm => "shm".to_owned(),
            TensorMemory::Mem => "mem".to_owned(),
            TensorMemory::Pbo => "pbo".to_owned(),
        }
    }
}

impl TryFrom<&str> for TensorMemory {
    type Error = Error;

    fn try_from(s: &str) -> Result<Self> {
        match s {
            #[cfg(target_os = "linux")]
            "dma" => Ok(TensorMemory::Dma),
            #[cfg(unix)]
            "shm" => Ok(TensorMemory::Shm),
            "mem" => Ok(TensorMemory::Mem),
            "pbo" => Ok(TensorMemory::Pbo),
            _ => Err(Error::InvalidMemoryType(s.to_owned())),
        }
    }
}

#[derive(Debug)]
#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}

impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Create a new tensor storage with the given shape, memory type, and
    /// optional name. If no name is given, a random name will be generated.
    /// If no memory type is given, the best available memory type will be
    /// chosen based on the platform and environment variables.
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma),
            #[cfg(unix)]
            Some(TensorMemory::Shm) => ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm),
            Some(TensorMemory::Mem) => MemTensor::<T>::new(shape, name).map(TensorStorage::Mem),
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    #[cfg(target_os = "linux")]
                    {
                        // Linux: try DMA -> SHM -> Mem
                        DmaTensor::<T>::new(shape, name)
                            .map(TensorStorage::Dma)
                            .or_else(|_| ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm))
                            .or_else(|_| MemTensor::<T>::new(shape, name).map(TensorStorage::Mem))
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        // macOS/BSD: try SHM -> Mem (no DMA)
                        ShmTensor::<T>::new(shape, name)
                            .map(TensorStorage::Shm)
                            .or_else(|_| MemTensor::<T>::new(shape, name).map(TensorStorage::Mem))
                    }
                    #[cfg(not(unix))]
                    {
                        // Windows/other: Mem only
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }

    /// Create a DMA-backed tensor storage with an explicit byte size that
    /// may exceed `shape.product() * sizeof(T)`. Used for image tensors
    /// with row-padded layouts (see `DmaTensor::new_with_byte_size`).
    ///
    /// This is intentionally DMA-only: padding is only meaningful for
    /// buffers that will be imported as GPU textures via EGLImage. PBO,
    /// Shm, and Mem storage don't benefit from pitch alignment and
    /// shouldn't pay the memory cost.
    #[cfg(target_os = "linux")]
    pub(crate) fn new_dma_with_byte_size(
        shape: &[usize],
        byte_size: usize,
        name: Option<&str>,
    ) -> Result<Self> {
        DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
    }

    // No non-Linux stub: the only caller (`Tensor::image_with_stride`)
    // returns `NotImplemented` directly on non-Linux without ever
    // reaching the storage layer, so defining a stub here would be
    // dead code and fail the `-D warnings` clippy gate on macOS CI.

    /// Create a new tensor storage using the given file descriptor, shape,
    /// and optional name.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            if major != 0 {
                // Dma and Shm tensors are expected to have major number 0
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                9 | 10 => {
                    // minor numbers 9 and 10 indicate DMA memory
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    // other minor numbers are assumed to be shared memory
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            // On macOS/BSD, always use SHM (no DMA support)
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}

impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }

    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}

/// Multi-backend tensor with optional image format metadata.
///
/// When `format` is `Some`, this tensor represents an image. Width, height,
/// and channels are derived from `shape` + `format`. When `format` is `None`,
/// this is a raw tensor (identical to the pre-refactoring behavior).
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    format: Option<PixelFormat>,
    chroma: Option<Box<Tensor<T>>>,
    /// Row stride in bytes for externally allocated buffers with row padding.
    /// `None` means tightly packed (stride == width * bytes_per_pixel).
    row_stride: Option<usize>,
    /// Byte offset within the DMA-BUF where image data starts.
    /// `None` means offset 0 (data starts at the beginning of the buffer).
    plane_offset: Option<usize>,
}

impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wrap a TensorStorage in a Tensor with no image metadata.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
        }
    }

    /// Create a new tensor with the given shape, memory type, and optional
    /// name. If no name is given, a random name will be generated. If no
    /// memory type is given, the best available memory type will be chosen
    /// based on the platform and environment variables.
    ///
    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
    /// On non-Unix platforms, only Mem is available.
    ///
    /// # Environment Variables
    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
    ///   value, forces the use of regular system memory allocation
    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
    ///
    /// # Example
    /// ```rust
    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
    /// # fn main() -> Result<(), Error> {
    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
    /// assert_eq!(tensor.name(), "test_tensor");
    /// #    Ok(())
    /// # }
    /// ```
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

    /// Create an image tensor with the given format.
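    ///
    /// The shape is derived from the format's layout; a packed RGBA image,
    /// for example, is stored as `[height, width, 4]`:
    /// ```rust
    /// use edgefirst_tensor::{PixelFormat, Tensor, TensorTrait};
    ///
    /// let img = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
    /// assert_eq!(img.shape(), &[480, 640, 4]);
    /// assert_eq!(img.width(), Some(640));
    /// ```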
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

    /// Create a DMA-backed image tensor with an explicit row stride that
    /// may exceed the natural `width * channels * sizeof(T)` pitch.
    ///
    /// Used for image tensors that need GPU pitch alignment padding: the
    /// underlying DMA-BUF is sized to `row_stride * height` bytes, but
    /// the tensor's logical shape stays at `[height, width, channels]`.
    /// `width()` / `height()` / `shape()` continue to report the
    /// user-requested values; the padding is visible only via
    /// `row_stride()` / `effective_row_stride()` and is automatically
    /// propagated to the GL backend's EGLImage import so Mali Valhall
    /// accepts the buffer.
    ///
    /// # Supported formats
    ///
    /// Currently only **packed** pixel layouts (RGBA8, BGRA8, RGB888,
    /// Grey, etc.) are supported — the formats the GL backend uses as
    /// render destinations. Semi-planar formats (NV12, NV16) come from
    /// external allocators (camera capture, video decoders) and are
    /// imported via `TensorDyn::from_fd` + `set_row_stride`, which
    /// already supports padded strides.
    ///
    /// # Supported memory
    ///
    /// Currently only `TensorMemory::Dma` is supported. PBO and Mem
    /// storage don't go through EGLImage import so they don't need
    /// pitch alignment; if you pass any other memory type this returns
    /// `NotImplemented`. `None` (auto-select) is treated as `Dma`.
    ///
    /// # Errors
    ///
    /// - `InvalidArgument` if `row_stride_bytes < width * channels * sizeof(T)`
    ///   (the requested stride would not fit a single row)
    /// - `NotImplemented` for non-packed formats or non-DMA memory
    /// - `IoError` if the DMA-heap allocation fails (propagated from
    ///   `DmaTensor::new_with_byte_size`)
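    ///
    /// # Example
    ///
    /// A sketch for a GPU-padded RGBA image; the 4096-byte stride is an
    /// assumed alignment requirement, and the allocation needs a Linux
    /// DMA heap at runtime:
    /// ```rust,no_run
    /// use edgefirst_tensor::{PixelFormat, Tensor};
    ///
    /// // Natural pitch is 640 px × 4 ch × 1 B = 2560 B; pad rows to 4096 B.
    /// let img = Tensor::<u8>::image_with_stride(640, 480, PixelFormat::Rgba, 4096, None).unwrap();
    /// assert_eq!(img.row_stride(), Some(4096));
    /// ```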
    pub fn image_with_stride(
        width: usize,
        height: usize,
        format: PixelFormat,
        row_stride_bytes: usize,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        // DMA backing (the only thing this constructor produces) is
        // Linux-only. On macOS/BSD/Windows the non-Linux block below is
        // the only compiled body and returns `NotImplemented` directly;
        // on Linux the non-Linux block is cfg-removed and the function
        // falls through to the real validation + allocation path. Each
        // target compiles exactly one of the two blocks, and the block
        // serves as the function's tail expression in both cases — so
        // neither needs an explicit `return` (avoids
        // `clippy::needless_return` on the macOS CI gate).
        #[cfg(not(target_os = "linux"))]
        {
            let _ = (width, height, format, row_stride_bytes, memory);
            Err(Error::NotImplemented(
                "image_with_stride requires DMA support (Linux only)".to_owned(),
            ))
        }

        #[cfg(target_os = "linux")]
        {
            if format.layout() != PixelLayout::Packed {
                return Err(Error::NotImplemented(format!(
                    "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
                )));
            }
            let elem = std::mem::size_of::<T>();
            let min_stride = width
                .checked_mul(format.channels())
                .and_then(|p| p.checked_mul(elem))
                .ok_or_else(|| {
                    Error::InvalidArgument(format!(
                        "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
                         overflows usize",
                        format.channels()
                    ))
                })?;
            if row_stride_bytes < min_stride {
                return Err(Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
                     ({width} px × {} ch × {elem} B)",
                    format.channels()
                )));
            }
            let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
                ))
            })?;

            let shape = vec![height, width, format.channels()];

            let storage = match memory {
                Some(TensorMemory::Dma) | None => {
                    TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
                }
                Some(other) => {
                    return Err(Error::NotImplemented(format!(
                        "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
                    )));
                }
            };

            let mut t = Self::wrap(storage);
            t.format = Some(format);
            t.row_stride = Some(row_stride_bytes);
            Ok(t)
        }
    }

    /// Attach format metadata to an existing tensor.
    ///
    /// # Arguments
    ///
    /// * `format` - The pixel format to attach
    ///
    /// # Returns
    ///
    /// `Ok(())` on success, with the format stored as metadata on the tensor.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
    /// the format's layout (packed expects `[H, W, C]`, planar expects
    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
    /// height constraints).
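    ///
    /// # Example
    ///
    /// Attaching a packed RGBA format to a matching `[H, W, C]` tensor:
    /// ```rust
    /// use edgefirst_tensor::{PixelFormat, Tensor};
    ///
    /// let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
    /// t.set_format(PixelFormat::Rgba).unwrap();
    /// assert_eq!(t.width(), Some(640));
    /// assert_eq!(t.height(), Some(480));
    /// ```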
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                match format {
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        // Clear stored stride/offset when format changes — they may be invalid
        // for the new format. Caller must re-set after changing format.
        if self.format != Some(format) {
            self.row_stride = None;
            self.plane_offset = None;
            #[cfg(target_os = "linux")]
            if let TensorStorage::Dma(ref mut dma) = self.storage {
                dma.mmap_offset = 0;
            }
        }
        self.format = Some(format);
        Ok(())
    }

    /// Pixel format (None if not an image).
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    /// Image width (None if not an image).
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    /// Image height (None if not an image).
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    Some(shape[0])
                } else {
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

    /// Create from separate Y and UV planes (multiplane NV12/NV16).
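    ///
    /// A minimal NV12 sketch from two raw planes (the chroma plane is half
    /// the luma height at the same width):
    /// ```rust
    /// use edgefirst_tensor::{PixelFormat, Tensor};
    ///
    /// let luma = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
    /// let chroma = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
    /// let nv12 = Tensor::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
    /// assert!(nv12.is_multiplane());
    /// assert_eq!(nv12.height(), Some(480));
    /// ```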
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        match format {
            PixelFormat::Nv12 => {
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
            row_stride: luma.row_stride,
            plane_offset: luma.plane_offset,
        })
    }

    /// Whether this tensor uses separate plane allocations.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }

    /// Access the chroma plane for multiplane semi-planar images.
    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }

    /// Mutable access to the chroma plane for multiplane semi-planar images.
    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
        self.chroma.as_deref_mut()
    }

    /// Row stride in bytes (`None` = tightly packed).
    pub fn row_stride(&self) -> Option<usize> {
        self.row_stride
    }

    /// Effective row stride in bytes: the stored stride if set, otherwise the
    /// minimum stride computed from the format, width, and element size.
    /// Returns `None` only when no format is set and no explicit stride was
    /// stored via [`set_row_stride`](Self::set_row_stride).
    pub fn effective_row_stride(&self) -> Option<usize> {
        if let Some(s) = self.row_stride {
            return Some(s);
        }
        let fmt = self.format?;
        let w = self.width()?;
        let elem = std::mem::size_of::<T>();
        Some(match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        })
    }

    /// Set the row stride in bytes for externally allocated buffers with
    /// row padding (e.g. V4L2 or GStreamer allocators).
    ///
    /// The stride is propagated to the EGL DMA-BUF import attributes so
    /// the GPU interprets the padded buffer layout correctly. Must be
    /// called after [`set_format`](Self::set_format) and before the tensor
    /// is first passed to `ImageProcessor::convert`. The stored stride
    /// is cleared automatically if the pixel format is later changed.
    ///
    /// No stride-vs-buffer-size validation is performed because the
    /// backing allocation size is not reliably known: external DMA-BUFs
    /// may be over-allocated by the allocator, and internal tensors store
    /// a logical (unpadded) shape. An incorrect stride will be caught by
    /// the EGL driver at import time.
    ///
    /// # Arguments
    ///
    /// * `stride` - Row stride in bytes. Must be >= the minimum stride for
    ///   the format (width * channels * sizeof(T) for packed,
    ///   width * sizeof(T) for planar/semi-planar).
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if no pixel format is set on this tensor
    /// * `InvalidArgument` if `stride` is less than the minimum for the
    ///   format and width
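    ///
    /// # Example
    ///
    /// A sketch using a system-memory image so it runs anywhere; in practice
    /// the stride comes from an external allocator:
    /// ```rust
    /// use edgefirst_tensor::{PixelFormat, Tensor, TensorMemory};
    ///
    /// let mut img = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
    /// img.set_row_stride(2560).unwrap(); // exactly the minimum: 640 px × 4 ch × 1 B
    /// assert_eq!(img.effective_row_stride(), Some(2560));
    /// assert!(img.set_row_stride(100).is_err()); // below the minimum stride
    /// ```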
    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
        let fmt = self.format.ok_or_else(|| {
            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
        })?;
        let w = self.width().ok_or_else(|| {
            Error::InvalidArgument("cannot determine width for row_stride validation".into())
        })?;
        let elem = std::mem::size_of::<T>();
        let min_stride = match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        };
        if stride < min_stride {
            return Err(Error::InvalidArgument(format!(
                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
            )));
        }
        self.row_stride = Some(stride);
        Ok(())
    }

    /// Set the row stride without format validation.
    ///
    /// Use this for raw sub-tensors (e.g. chroma planes) that don't carry
    /// format metadata. The caller is responsible for ensuring the stride
    /// is valid.
    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
        self.row_stride = Some(stride);
    }

    /// Builder-style variant of [`set_row_stride`](Self::set_row_stride),
    /// consuming and returning `self`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`set_row_stride`](Self::set_row_stride).
    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
        self.set_row_stride(stride)?;
        Ok(self)
    }

    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
    pub fn plane_offset(&self) -> Option<usize> {
        self.plane_offset
    }

    /// Set the byte offset within the DMA-BUF where image data starts.
    ///
    /// Propagated to `EGL_DMA_BUF_PLANE0_OFFSET_EXT` on GPU import.
    /// Unlike [`set_row_stride`](Self::set_row_stride), no format is required
    /// since the offset is format-independent.
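    ///
    /// # Example
    ///
    /// The offset is stored as metadata (CPU mapping of offset tensors is
    /// only supported for DMA storage):
    /// ```rust
    /// use edgefirst_tensor::Tensor;
    ///
    /// let mut t = Tensor::<u8>::new(&[64], None, None).unwrap();
    /// t.set_plane_offset(4096);
    /// assert_eq!(t.plane_offset(), Some(4096));
    /// ```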
    pub fn set_plane_offset(&mut self, offset: usize) {
        self.plane_offset = Some(offset);
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = offset;
        }
    }

    /// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset),
    /// consuming and returning `self`.
    pub fn with_plane_offset(mut self, offset: usize) -> Self {
        self.set_plane_offset(offset);
        self
    }

    /// Downcast to PBO tensor reference (for GL backends).
    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }

    /// Downcast to DMA tensor reference (for EGL import, G2D).
    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }

    /// Borrow the DMA-BUF file descriptor backing this tensor.
    ///
    /// # Returns
    ///
    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
    /// lifetime.
    ///
    /// # Errors
    ///
    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
    #[cfg(target_os = "linux")]
    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
        use std::os::fd::AsFd;
        match &self.storage {
            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
            _ => Err(Error::NotImplemented(format!(
                "dmabuf requires DMA-backed tensor, got {:?}",
                self.storage.memory()
            ))),
        }
    }

    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
        Self {
            storage: TensorStorage::Pbo(pbo),
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
        }
    }
}

impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        self.row_stride = None;
        self.plane_offset = None;
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = 0;
        }
        Ok(())
    }

    fn map(&self) -> Result<TensorMap<T>> {
        // CPU mapping of strided tensors is allowed only when the HAL
        // owns the underlying allocation — i.e. self-allocated DMA
        // tensors with pitch padding added by `image_with_stride()`
        // for GPU import alignment. In that case we know the buffer
        // is exactly `row_stride × height` bytes (for packed formats)
        // and callers that respect the stride can iterate rows
        // correctly via `effective_row_stride()`.
        //
        // Foreign DMA-BUFs imported via `from_fd()` + `set_row_stride()`
        // (the V4L2 / GStreamer case) are rejected: their layout comes
        // from an external allocator and the HAL cannot validate what
        // the caller expects the mapping to look like. Those tensors
        // are intended for the GPU path only.
        //
        // The cfg split keeps `stride` from being an unused binding on
        // non-Linux builds (the Linux branch is the only consumer).
        #[cfg(target_os = "linux")]
        if let Some(stride) = self.row_stride {
            if let TensorStorage::Dma(dma) = &self.storage {
                if !dma.is_imported {
                    // Self-allocated strided DMA tensor — expose the
                    // full stride×height padded mmap via the override
                    // constructor so callers can iterate rows with
                    // `effective_row_stride()` without going past
                    // the end of the returned slice.
                    //
                    // Validate the requested mapping fits inside the
                    // actual DMA-BUF. `set_row_stride()` is a public
                    // API and only validates `stride >= min_stride`,
                    // not `stride × height <= buf_size`, so a caller
                    // that tampers with the stride after allocation
                    // could otherwise request a slice larger than the
                    // underlying mmap — which would be undefined
                    // behaviour in `DmaMap::as_slice`.
                    //
                    // Refuse to map if `height()` can't be derived
                    // (e.g. raw 2D tensors without a PixelFormat that
                    // got a `row_stride` set via `set_row_stride_unchecked`).
                    // Returning a 0-byte view would silently truncate
                    // rather than surface the misuse.
                    let height = self.height().ok_or_else(|| {
                        Error::InvalidOperation(
                            "Tensor::map: strided DMA mapping requires a PixelFormat \
                             so height() can be derived; set a format before mapping \
                             or clear row_stride for raw tensor access"
                                .into(),
                        )
                    })?;
                    let total_bytes = stride.checked_mul(height).ok_or_else(|| {
                        Error::InvalidOperation(format!(
                            "Tensor::map: row_stride {stride} × height {height} overflows usize"
                        ))
                    })?;
                    let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
                    if total_bytes > available_bytes {
                        return Err(Error::InvalidOperation(format!(
                            "Tensor::map: strided mapping needs {total_bytes} bytes \
                             but DMA buffer only has {available_bytes} available \
                             (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
                             the row_stride was likely set larger than the original allocation",
                            dma.buf_size, dma.mmap_offset
                        )));
                    }
                    return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
                }
            }
            return Err(Error::InvalidOperation(
                "CPU mapping of strided foreign tensors is not supported; \
                 use GPU path only"
                    .into(),
            ));
        }
        #[cfg(not(target_os = "linux"))]
        if self.row_stride.is_some() {
            return Err(Error::InvalidOperation(
                "CPU mapping of strided tensors is not supported on this \
                 platform (DMA backing is Linux-only)"
                    .into(),
            ));
        }
        // Offset tensors are supported for DMA storage — DmaMap adjusts the
        // mmap range and slice start position.  Non-DMA offset tensors are
        // not meaningful (offset only applies to DMA-BUF sub-regions).
        if self.plane_offset.is_some_and(|o| o > 0) {
            #[cfg(target_os = "linux")]
            if !matches!(self.storage, TensorStorage::Dma(_)) {
                return Err(Error::InvalidOperation(
                    "plane offset only supported for DMA tensors".into(),
                ));
            }
            #[cfg(not(target_os = "linux"))]
            return Err(Error::InvalidOperation(
                "plane offset only supported for DMA tensors".into(),
            ));
        }
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}

pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}

impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}

impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}

impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}
1444
1445// ============================================================================
1446// Platform availability helpers
1447// ============================================================================
1448
1449/// Check if DMA memory allocation is available on this system.
1450///
1451/// Returns `true` only on Linux systems with DMA-BUF heap access (typically
1452/// requires running as root or membership in a video/render group).
1453/// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
1454///
1455/// This function caches its result after the first call for efficiency.
1456#[cfg(target_os = "linux")]
1457static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1458
1459/// Check if DMA memory allocation is available on this system.
1460#[cfg(target_os = "linux")]
1461pub fn is_dma_available() -> bool {
1462    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1463}
1464
1465/// Check if DMA memory allocation is available on this system.
1466///
1467/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
1468#[cfg(not(target_os = "linux"))]
1469pub fn is_dma_available() -> bool {
1470    false
1471}
1472
/// Cached result of the one-time POSIX shared-memory availability probe; see
/// [`is_shm_available()`] for the documented behavior.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1481
/// Check if POSIX shared memory allocation is available on this system.
///
/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
/// is supported. Always returns `false` on non-Unix platforms (Windows).
///
/// The result is cached after the first call for efficiency.
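///
/// # Examples
///
/// A minimal sketch, mirroring `test_shm_available_and_usable` in the tests
/// below: probe, allocate an SHM-backed tensor, and write through a mapping.
///
/// ```rust
/// use edgefirst_tensor::{is_shm_available, Tensor, TensorMapTrait, TensorMemory, TensorTrait};
///
/// if is_shm_available() {
///     let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
///     let mut map = tensor.map().unwrap();
///     map.as_mut_slice().fill(0xAB);
/// }
/// ```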
1483#[cfg(unix)]
1484pub fn is_shm_available() -> bool {
1485    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
1486}
1487
1488/// Check if POSIX shared memory allocation is available on this system.
1489///
1490/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
1491#[cfg(not(unix))]
1492pub fn is_shm_available() -> bool {
1493    false
1494}
1495
1496#[cfg(test)]
1497mod dtype_tests {
1498    use super::*;
1499
1500    #[test]
1501    fn dtype_size() {
1502        assert_eq!(DType::U8.size(), 1);
1503        assert_eq!(DType::I8.size(), 1);
1504        assert_eq!(DType::U16.size(), 2);
1505        assert_eq!(DType::I16.size(), 2);
1506        assert_eq!(DType::U32.size(), 4);
1507        assert_eq!(DType::I32.size(), 4);
1508        assert_eq!(DType::U64.size(), 8);
1509        assert_eq!(DType::I64.size(), 8);
1510        assert_eq!(DType::F16.size(), 2);
1511        assert_eq!(DType::F32.size(), 4);
1512        assert_eq!(DType::F64.size(), 8);
1513    }
1514
1515    #[test]
1516    fn dtype_name() {
1517        assert_eq!(DType::U8.name(), "u8");
1518        assert_eq!(DType::F16.name(), "f16");
1519        assert_eq!(DType::F32.name(), "f32");
1520    }
1521
1522    #[test]
1523    fn dtype_serde_roundtrip() {
1524        use serde_json;
1525        let dt = DType::F16;
1526        let json = serde_json::to_string(&dt).unwrap();
1527        let back: DType = serde_json::from_str(&json).unwrap();
1528        assert_eq!(dt, back);
1529    }
1530}
1531
1532#[cfg(test)]
1533mod image_tests {
1534    use super::*;
1535
1536    #[test]
1537    fn raw_tensor_has_no_format() {
1538        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
1539        assert!(t.format().is_none());
1540        assert!(t.width().is_none());
1541        assert!(t.height().is_none());
1542        assert!(!t.is_multiplane());
1543        assert!(t.chroma().is_none());
1544    }
1545
1546    #[test]
1547    fn image_tensor_packed() {
1548        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
1549        assert_eq!(t.format(), Some(PixelFormat::Rgba));
1550        assert_eq!(t.width(), Some(640));
1551        assert_eq!(t.height(), Some(480));
1552        assert_eq!(t.shape(), &[480, 640, 4]);
1553        assert!(!t.is_multiplane());
1554    }
1555
1556    #[test]
1557    fn image_tensor_planar() {
1558        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
1559        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
1560        assert_eq!(t.width(), Some(640));
1561        assert_eq!(t.height(), Some(480));
1562        assert_eq!(t.shape(), &[3, 480, 640]);
1563    }
1564
1565    #[test]
1566    fn image_tensor_semi_planar_contiguous() {
1567        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
1568        assert_eq!(t.format(), Some(PixelFormat::Nv12));
1569        assert_eq!(t.width(), Some(640));
1570        assert_eq!(t.height(), Some(480));
        // NV12 is semi-planar: a full-height Y plane (480 rows) plus an
        // interleaved UV plane at half height (240 rows), so 480 * 3 / 2 = 720.
        assert_eq!(t.shape(), &[720, 640]);
1573        assert!(!t.is_multiplane());
1574    }
1575
1576    #[test]
1577    #[cfg(target_os = "linux")]
1578    fn image_tensor_with_stride_preserves_logical_width() {
1579        // Skip if DMA not available (e.g. sandboxed CI lacking dma_heap access).
1580        if !is_dma_available() {
1581            eprintln!("SKIPPED: DMA heap not available");
1582            return;
1583        }
1584        // 3004×1688 RGBA8: natural pitch 12016, padded to 12032 (64-aligned).
1585        let stride = 12032;
1586        let t = Tensor::<u8>::image_with_stride(
1587            3004,
1588            1688,
1589            PixelFormat::Rgba,
1590            stride,
1591            Some(TensorMemory::Dma),
1592        )
1593        .unwrap();
1594        // Logical dimensions unchanged by padding — this is the contract.
1595        assert_eq!(t.width(), Some(3004));
1596        assert_eq!(t.height(), Some(1688));
1597        assert_eq!(t.shape(), &[1688, 3004, 4]);
1598        // Stride is carried separately and reports the padded pitch.
1599        assert_eq!(t.effective_row_stride(), Some(stride));
1600        // Buffer is sized to stride × height so the full padded layout fits,
1601        // and CPU map() works for self-allocated strided DMA tensors.
1602        use crate::TensorMapTrait;
1603        {
1604            let map = t.map().unwrap();
1605            assert!(
1606                map.as_slice().len() >= stride * 1688,
1607                "mapped buffer {} bytes < expected {}",
1608                map.as_slice().len(),
1609                stride * 1688
1610            );
1611        }
1612        // CPU write access works too — iterate rows using the padded stride,
1613        // touch only the active `width × bpp` region, verify it round-trips.
1614        {
1615            let mut map = t.map().unwrap();
1616            let slice = map.as_mut_slice();
1617            for y in 0..1688 {
1618                let row_start = y * stride;
1619                for x in 0..3004 {
1620                    let p = row_start + x * 4;
1621                    slice[p] = (y & 0xFF) as u8;
1622                    slice[p + 1] = (x & 0xFF) as u8;
1623                    slice[p + 2] = 0x42;
1624                    slice[p + 3] = 0xFF;
1625                }
1626            }
1627        }
1628        {
1629            let map = t.map().unwrap();
1630            let slice = map.as_slice();
1631            // Sample a few pixels to confirm the round-trip.
1632            assert_eq!(slice[0], 0x00);
1633            assert_eq!(slice[1], 0x00);
1634            assert_eq!(slice[2], 0x42);
1635            assert_eq!(slice[3], 0xFF);
1636            let mid = 100 * stride + 50 * 4;
1637            assert_eq!(slice[mid], 100);
1638            assert_eq!(slice[mid + 1], 50);
1639            assert_eq!(slice[mid + 2], 0x42);
1640        }
1641    }
1642
1643    #[test]
1644    #[cfg(target_os = "linux")]
1645    fn image_tensor_with_stride_rejects_foreign_strided_map() {
1646        // A FOREIGN (imported via from_fd) DMA tensor with row_stride set
1647        // should still refuse CPU mapping — external allocator owns the
1648        // layout. This protects the V4L2 / GStreamer use case.
1649        //
        // We simulate a foreign import by wrapping our own allocation's fd
        // via `from_fd` and calling `set_row_stride` manually; `from_fd`
        // sets the `is_imported` flag by construction.
1653        if !is_dma_available() {
1654            eprintln!("SKIPPED: DMA heap not available");
1655            return;
1656        }
1657        // Allocate a backing buffer large enough for a 320×240 BGRA8 image.
1658        let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
1659        let fd = backing.clone_fd().unwrap();
1660        // Import it via from_fd — this marks is_imported=true.
1661        let shape = [240usize, 320, 4];
1662        let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
1663        let mut t = Tensor::<u8>::wrap(storage);
1664        t.set_format(PixelFormat::Bgra).unwrap();
1665        t.set_row_stride(320 * 4).unwrap(); // natural, but still marks it as strided
1666        let err = t.map();
1667        assert!(
1668            matches!(err, Err(Error::InvalidOperation(_))),
1669            "foreign strided map should error"
1670        );
1671    }
1672
1673    #[test]
1674    #[cfg(target_os = "linux")]
1675    fn image_tensor_with_stride_map_rejects_tampered_stride() {
1676        // Round-3 PR feedback (C1): `set_row_stride` is public and only
1677        // validates `stride >= min_stride`, not that the new stride × height
1678        // fits the underlying buffer. A caller that tampers with the stride
1679        // after allocation must not be able to coerce `Tensor::map()` into
1680        // returning a slice larger than the backing mmap (that would be UB
1681        // in `DmaMap::as_slice`).
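        //
        // The guard `map()` needs, in sketch form (names hypothetical, not
        // the exact internal code):
        //
        //     let needed = stride.checked_mul(height).unwrap_or(usize::MAX);
        //     if needed > buffer_len {
        //         return Err(Error::InvalidOperation("stride exceeds buffer".into()));
        //     }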
1682        if !is_dma_available() {
1683            eprintln!("SKIPPED: DMA heap not available");
1684            return;
1685        }
1686        // Allocate a 640×480 RGBA8 padded canvas (stride = 3072 = 768 px).
1687        // Backing buffer is 3072 × 480 = 1,474,560 bytes.
1688        let mut t = Tensor::<u8>::image_with_stride(
1689            640,
1690            480,
1691            PixelFormat::Rgba,
1692            3072,
1693            Some(TensorMemory::Dma),
1694        )
1695        .unwrap();
1696        // Tamper: push the stride up to 4 × the original. This is >=
1697        // min_stride (2560), so `set_row_stride` accepts it.
1698        t.set_row_stride(12288).unwrap();
1699        // Map must now refuse — 12288 × 480 = 5,898,240 > 1,474,560.
1700        let err = t.map();
1701        assert!(
1702            matches!(err, Err(Error::InvalidOperation(_))),
1703            "map() with oversized stride must return InvalidOperation"
1704        );
1705    }
1706
1707    #[test]
1708    fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
1709        // Round-3 PR feedback (C3): shape.product() * sizeof(T) must use
1710        // checked arithmetic so a pathological shape can't wrap usize and
1711        // make the byte_size-vs-logical-size comparison incorrect.
1712        //
1713        // This test only exercises the overflow rejection path, which is
1714        // pure-Rust and doesn't touch dma_heap — safe to run on any target.
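        //
        // The checked form the constructor needs, in sketch form (not the
        // exact internal code):
        //
        //     let elems = shape.iter().try_fold(1usize, |a, &d| a.checked_mul(d));
        //     let bytes = elems.and_then(|e| e.checked_mul(std::mem::size_of::<T>()));
        //     // `None` in either step maps to Error::InvalidArgument.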
1715        #[cfg(target_os = "linux")]
1716        {
1717            let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
1718                &[usize::MAX, 2, 2],
1719                usize::MAX,
1720                None,
1721            );
1722            assert!(
1723                matches!(err, Err(Error::InvalidArgument(_))),
1724                "new_with_byte_size must detect shape.product() overflow"
1725            );
1726        }
1727    }
1728
1729    #[test]
1730    #[cfg(target_os = "linux")]
1731    fn image_tensor_with_stride_rejects_too_small_stride() {
1732        // 640×480 RGBA8 natural pitch = 2560, request 2400 → should error.
1733        let err = Tensor::<u8>::image_with_stride(
1734            640,
1735            480,
1736            PixelFormat::Rgba,
1737            2400,
1738            Some(TensorMemory::Dma),
1739        );
1740        assert!(matches!(err, Err(Error::InvalidArgument(_))));
1741    }
1742
1743    #[test]
1744    #[cfg(target_os = "linux")]
1745    fn image_tensor_with_stride_rejects_non_packed() {
1746        // NV12 is SemiPlanar → not supported. (Linux-only because
1747        // `TensorMemory::Dma` itself is a Linux-only enum variant.)
1748        let err = Tensor::<u8>::image_with_stride(
1749            640,
1750            480,
1751            PixelFormat::Nv12,
1752            640,
1753            Some(TensorMemory::Dma),
1754        );
1755        assert!(matches!(err, Err(Error::NotImplemented(_))));
1756    }
1757
1758    #[test]
1759    fn set_format_valid() {
1760        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
1761        assert!(t.format().is_none());
1762        t.set_format(PixelFormat::Rgb).unwrap();
1763        assert_eq!(t.format(), Some(PixelFormat::Rgb));
1764        assert_eq!(t.width(), Some(640));
1765        assert_eq!(t.height(), Some(480));
1766    }
1767
1768    #[test]
1769    fn set_format_invalid_shape() {
1770        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
1771        // RGB expects 3 channels, not 4
1772        let err = t.set_format(PixelFormat::Rgb);
1773        assert!(err.is_err());
1774        // Original tensor is unmodified
1775        assert!(t.format().is_none());
1776    }
1777
1778    #[test]
1779    fn reshape_clears_format() {
1780        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
1781        assert_eq!(t.format(), Some(PixelFormat::Rgba));
1782        // Reshape to flat — format cleared
1783        t.reshape(&[480 * 640 * 4]).unwrap();
1784        assert!(t.format().is_none());
1785    }
1786
1787    #[test]
1788    fn from_planes_nv12() {
1789        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
1790        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
1791        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
1792        assert_eq!(img.format(), Some(PixelFormat::Nv12));
1793        assert!(img.is_multiplane());
1794        assert!(img.chroma().is_some());
1795        assert_eq!(img.width(), Some(640));
1796        assert_eq!(img.height(), Some(480));
1797    }
1798
1799    #[test]
1800    fn from_planes_rejects_non_semiplanar() {
1801        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
1802        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
1803        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
1804        assert!(err.is_err());
1805    }
1806
1807    #[test]
1808    fn reshape_multiplane_errors() {
1809        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
1810        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
1811        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
1812        let err = img.reshape(&[480 * 640 + 240 * 640]);
1813        assert!(err.is_err());
1814    }
1815}
1816
1817#[cfg(test)]
1818mod tests {
1819    #[cfg(target_os = "linux")]
1820    use nix::unistd::{access, AccessFlags};
1821    #[cfg(target_os = "linux")]
1822    use std::io::Write as _;
1823    use std::sync::RwLock;
1824
1825    use super::*;
1826
1827    #[ctor::ctor]
1828    fn init() {
1829        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
1830    }
1831
1832    /// Macro to get the current function name for logging in tests.
1833    #[cfg(target_os = "linux")]
1834    macro_rules! function {
1835        () => {{
1836            fn f() {}
1837            fn type_name_of<T>(_: T) -> &'static str {
1838                std::any::type_name::<T>()
1839            }
1840            let name = type_name_of(f);
1841
            // Strip the trailing "::f" and the leading module path, keeping
            // only the enclosing function's name.
1843            match &name[..name.len() - 3].rfind(':') {
1844                Some(pos) => &name[pos + 1..name.len() - 3],
1845                None => &name[..name.len() - 3],
1846            }
1847        }};
1848    }
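    // Usage sketch: inside `fn test_dma_no_fd_leaks`, `function!()` evaluates
    // to "test_dma_no_fd_leaks"; the SKIPPED log lines below rely on this.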
1849
1850    #[test]
1851    #[cfg(target_os = "linux")]
1852    fn test_tensor() {
1853        let _lock = FD_LOCK.read().unwrap();
1854        let shape = vec![1];
1855        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1856        let dma_enabled = tensor.is_ok();
1857
1858        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1859        match dma_enabled {
1860            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1861            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1862        }
1863    }
1864
1865    #[test]
1866    #[cfg(all(unix, not(target_os = "linux")))]
1867    fn test_tensor() {
1868        let shape = vec![1];
1869        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1870        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
1871        assert!(
1872            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1873            "Expected SHM or Mem on macOS, got {:?}",
1874            tensor.memory()
1875        );
1876    }
1877
1878    #[test]
1879    #[cfg(not(unix))]
1880    fn test_tensor() {
1881        let shape = vec![1];
1882        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1883        assert_eq!(tensor.memory(), TensorMemory::Mem);
1884    }
1885
1886    #[test]
1887    #[cfg(target_os = "linux")]
1888    fn test_dma_tensor() {
1889        let _lock = FD_LOCK.read().unwrap();
1890        match access(
1891            "/dev/dma_heap/linux,cma",
1892            AccessFlags::R_OK | AccessFlags::W_OK,
1893        ) {
1894            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
1895            Err(_) => match access(
1896                "/dev/dma_heap/system",
1897                AccessFlags::R_OK | AccessFlags::W_OK,
1898            ) {
1899                Ok(_) => println!("/dev/dma_heap/system is available"),
1900                Err(e) => {
1901                    writeln!(
1902                        &mut std::io::stdout(),
1903                        "[WARNING] DMA Heap is unavailable: {e}"
1904                    )
1905                    .unwrap();
1906                    return;
1907                }
1908            },
1909        }
1910
1911        let shape = vec![2, 3, 4];
1912        let tensor =
1913            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1914
1915        const DUMMY_VALUE: f32 = 12.34;
1916
1917        assert_eq!(tensor.memory(), TensorMemory::Dma);
1918        assert_eq!(tensor.name(), "test_tensor");
1919        assert_eq!(tensor.shape(), &shape);
1920        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1921        assert_eq!(tensor.len(), 2 * 3 * 4);
1922
1923        {
1924            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1925            tensor_map.fill(42.0);
1926            assert!(tensor_map.iter().all(|&x| x == 42.0));
1927        }
1928
1929        {
1930            let shared = Tensor::<f32>::from_fd(
1931                tensor
1932                    .clone_fd()
1933                    .expect("Failed to duplicate tensor file descriptor"),
1934                &shape,
1935                Some("test_tensor_shared"),
1936            )
1937            .expect("Failed to create tensor from fd");
1938
1939            assert_eq!(shared.memory(), TensorMemory::Dma);
1940            assert_eq!(shared.name(), "test_tensor_shared");
1941            assert_eq!(shared.shape(), &shape);
1942
1943            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
1944            tensor_map.fill(DUMMY_VALUE);
1945            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1946        }
1947
1948        {
1949            let tensor_map = tensor.map().expect("Failed to map DMA memory");
1950            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
1951        }
1952
1953        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1954        assert_eq!(tensor.shape(), &shape);
1955        let new_shape = vec![3, 4, 4];
1956        assert!(
1957            tensor.reshape(&new_shape).is_err(),
1958            "Reshape should fail due to size mismatch"
1959        );
1960        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1961
1962        let new_shape = vec![2, 3, 4];
1963        tensor.reshape(&new_shape).expect("Reshape should succeed");
1964        assert_eq!(
1965            tensor.shape(),
1966            &new_shape,
1967            "Shape should be updated after successful reshape"
1968        );
1969
1970        {
1971            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1972            tensor_map.fill(1);
1973            assert!(tensor_map.iter().all(|&x| x == 1));
1974        }
1975
1976        {
1977            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
1978            tensor_map[2] = 42;
1979            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1980            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1981        }
1982    }
1983
1984    #[test]
1985    #[cfg(unix)]
1986    fn test_shm_tensor() {
1987        let _lock = FD_LOCK.read().unwrap();
1988        let shape = vec![2, 3, 4];
1989        let tensor =
1990            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1991        assert_eq!(tensor.shape(), &shape);
1992        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1993        assert_eq!(tensor.name(), "test_tensor");
1994
1995        const DUMMY_VALUE: f32 = 12.34;
1996        {
1997            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
1998            tensor_map.fill(42.0);
1999            assert!(tensor_map.iter().all(|&x| x == 42.0));
2000        }
2001
2002        {
2003            let shared = Tensor::<f32>::from_fd(
2004                tensor
2005                    .clone_fd()
2006                    .expect("Failed to duplicate tensor file descriptor"),
2007                &shape,
2008                Some("test_tensor_shared"),
2009            )
2010            .expect("Failed to create tensor from fd");
2011
2012            assert_eq!(shared.memory(), TensorMemory::Shm);
2013            assert_eq!(shared.name(), "test_tensor_shared");
2014            assert_eq!(shared.shape(), &shape);
2015
2016            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
2017            tensor_map.fill(DUMMY_VALUE);
2018            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2019        }
2020
2021        {
2022            let tensor_map = tensor.map().expect("Failed to map shared memory");
2023            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
2024        }
2025
2026        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2027        assert_eq!(tensor.shape(), &shape);
2028        let new_shape = vec![3, 4, 4];
2029        assert!(
2030            tensor.reshape(&new_shape).is_err(),
2031            "Reshape should fail due to size mismatch"
2032        );
2033        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2034
2035        let new_shape = vec![2, 3, 4];
2036        tensor.reshape(&new_shape).expect("Reshape should succeed");
2037        assert_eq!(
2038            tensor.shape(),
2039            &new_shape,
2040            "Shape should be updated after successful reshape"
2041        );
2042
2043        {
2044            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2045            tensor_map.fill(1);
2046            assert!(tensor_map.iter().all(|&x| x == 1));
2047        }
2048
2049        {
2050            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
2051            tensor_map[2] = 42;
2052            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2053            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2054        }
2055    }
2056
2057    #[test]
2058    fn test_mem_tensor() {
2059        let shape = vec![2, 3, 4];
2060        let tensor =
2061            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
2062        assert_eq!(tensor.shape(), &shape);
2063        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
2064        assert_eq!(tensor.name(), "test_tensor");
2065
2066        {
2067            let mut tensor_map = tensor.map().expect("Failed to map memory");
2068            tensor_map.fill(42.0);
2069            assert!(tensor_map.iter().all(|&x| x == 42.0));
2070        }
2071
2072        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
2073        assert_eq!(tensor.shape(), &shape);
2074        let new_shape = vec![3, 4, 4];
2075        assert!(
2076            tensor.reshape(&new_shape).is_err(),
2077            "Reshape should fail due to size mismatch"
2078        );
2079        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
2080
2081        let new_shape = vec![2, 3, 4];
2082        tensor.reshape(&new_shape).expect("Reshape should succeed");
2083        assert_eq!(
2084            tensor.shape(),
2085            &new_shape,
2086            "Shape should be updated after successful reshape"
2087        );
2088
2089        {
2090            let mut tensor_map = tensor.map().expect("Failed to map memory");
2091            tensor_map.fill(1);
2092            assert!(tensor_map.iter().all(|&x| x == 1));
2093        }
2094
2095        {
2096            let mut tensor_map = tensor.map().expect("Failed to map memory");
2097            tensor_map[2] = 42;
2098            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
2099            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
2100        }
2101    }
2102
2103    #[test]
2104    #[cfg(target_os = "linux")]
2105    fn test_dma_no_fd_leaks() {
2106        let _lock = FD_LOCK.write().unwrap();
2107        if !is_dma_available() {
2108            log::warn!(
2109                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2110                function!()
2111            );
2112            return;
2113        }
2114
2115        let proc = procfs::process::Process::myself()
2116            .expect("Failed to get current process using /proc/self");
2117
2118        let start_open_fds = proc
2119            .fd_count()
2120            .expect("Failed to get open file descriptor count");
2121
2122        for _ in 0..100 {
2123            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
2124                .expect("Failed to create tensor");
2125            let mut map = tensor.map().unwrap();
2126            map.as_mut_slice().fill(233);
2127        }
2128
2129        let end_open_fds = proc
2130            .fd_count()
2131            .expect("Failed to get open file descriptor count");
2132
2133        assert_eq!(
2134            start_open_fds, end_open_fds,
2135            "File descriptor leak detected: {} -> {}",
2136            start_open_fds, end_open_fds
2137        );
2138    }
2139
2140    #[test]
2141    #[cfg(target_os = "linux")]
2142    fn test_dma_from_fd_no_fd_leaks() {
2143        let _lock = FD_LOCK.write().unwrap();
2144        if !is_dma_available() {
2145            log::warn!(
2146                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
2147                function!()
2148            );
2149            return;
2150        }
2151
2152        let proc = procfs::process::Process::myself()
2153            .expect("Failed to get current process using /proc/self");
2154
2155        let start_open_fds = proc
2156            .fd_count()
2157            .expect("Failed to get open file descriptor count");
2158
2159        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
2160
2161        for _ in 0..100 {
2162            let tensor =
2163                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2164            let mut map = tensor.map().unwrap();
2165            map.as_mut_slice().fill(233);
2166        }
2167        drop(orig);
2168
2169        let end_open_fds = proc.fd_count().unwrap();
2170
2171        assert_eq!(
2172            start_open_fds, end_open_fds,
2173            "File descriptor leak detected: {} -> {}",
2174            start_open_fds, end_open_fds
2175        );
2176    }
2177
2178    #[test]
2179    #[cfg(target_os = "linux")]
2180    fn test_shm_no_fd_leaks() {
2181        let _lock = FD_LOCK.write().unwrap();
2182        if !is_shm_available() {
2183            log::warn!(
2184                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2185                function!()
2186            );
2187            return;
2188        }
2189
2190        let proc = procfs::process::Process::myself()
2191            .expect("Failed to get current process using /proc/self");
2192
2193        let start_open_fds = proc
2194            .fd_count()
2195            .expect("Failed to get open file descriptor count");
2196
2197        for _ in 0..100 {
2198            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2199                .expect("Failed to create tensor");
2200            let mut map = tensor.map().unwrap();
2201            map.as_mut_slice().fill(233);
2202        }
2203
2204        let end_open_fds = proc
2205            .fd_count()
2206            .expect("Failed to get open file descriptor count");
2207
2208        assert_eq!(
2209            start_open_fds, end_open_fds,
2210            "File descriptor leak detected: {} -> {}",
2211            start_open_fds, end_open_fds
2212        );
2213    }
2214
2215    #[test]
2216    #[cfg(target_os = "linux")]
2217    fn test_shm_from_fd_no_fd_leaks() {
2218        let _lock = FD_LOCK.write().unwrap();
2219        if !is_shm_available() {
2220            log::warn!(
2221                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
2222                function!()
2223            );
2224            return;
2225        }
2226
2227        let proc = procfs::process::Process::myself()
2228            .expect("Failed to get current process using /proc/self");
2229
2230        let start_open_fds = proc
2231            .fd_count()
2232            .expect("Failed to get open file descriptor count");
2233
2234        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
2235
2236        for _ in 0..100 {
2237            let tensor =
2238                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
2239            let mut map = tensor.map().unwrap();
2240            map.as_mut_slice().fill(233);
2241        }
2242        drop(orig);
2243
2244        let end_open_fds = proc.fd_count().unwrap();
2245
2246        assert_eq!(
2247            start_open_fds, end_open_fds,
2248            "File descriptor leak detected: {} -> {}",
2249            start_open_fds, end_open_fds
2250        );
2251    }
2252
2253    #[cfg(feature = "ndarray")]
2254    #[test]
2255    fn test_ndarray() {
2256        let _lock = FD_LOCK.read().unwrap();
2257        let shape = vec![2, 3, 4];
2258        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
2259
2260        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
2261        tensor_map.fill(1.0);
2262
2263        let view = tensor_map.view().expect("Failed to get ndarray view");
2264        assert_eq!(view.shape(), &[2, 3, 4]);
2265        assert!(view.iter().all(|&x| x == 1.0));
2266
2267        let mut view_mut = tensor_map
2268            .view_mut()
2269            .expect("Failed to get mutable ndarray view");
2270        view_mut[[0, 0, 0]] = 42.0;
2271        assert_eq!(view_mut[[0, 0, 0]], 42.0);
2272        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
2273    }
2274
2275    #[test]
2276    fn test_buffer_identity_unique() {
2277        let id1 = BufferIdentity::new();
2278        let id2 = BufferIdentity::new();
2279        assert_ne!(
2280            id1.id(),
2281            id2.id(),
2282            "Two identities should have different ids"
2283        );
2284    }
2285
2286    #[test]
2287    fn test_buffer_identity_clone_shares_guard() {
2288        let id1 = BufferIdentity::new();
2289        let weak = id1.weak();
2290        assert!(
2291            weak.upgrade().is_some(),
2292            "Weak should be alive while original exists"
2293        );
2294
2295        let id2 = id1.clone();
2296        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
2297
2298        drop(id1);
2299        assert!(
2300            weak.upgrade().is_some(),
2301            "Weak should still be alive (clone holds Arc)"
2302        );
2303
2304        drop(id2);
2305        assert!(
2306            weak.upgrade().is_none(),
2307            "Weak should be dead after all clones dropped"
2308        );
2309    }
2310
2311    #[test]
2312    fn test_tensor_buffer_identity() {
2313        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
2314        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
2315        assert_ne!(
2316            t1.buffer_identity().id(),
2317            t2.buffer_identity().id(),
2318            "Different tensors should have different buffer ids"
2319        );
2320    }
2321
    // Any test that asserts on the process fd count must grab this lock
    // exclusively (write). Any test that modifies the fd count by opening or
    // closing fds must grab it shared (read).
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
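    //
    // Sketch of the protocol, as used by the tests in this module:
    //
    //     let _lock = FD_LOCK.write().unwrap(); // fd-counting test (exclusive)
    //     let _lock = FD_LOCK.read().unwrap();  // fd-opening test (shared)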
2326
2327    /// Test that DMA is NOT available on non-Linux platforms.
2328    /// This verifies the cross-platform behavior of is_dma_available().
2329    #[test]
2330    #[cfg(not(target_os = "linux"))]
2331    fn test_dma_not_available_on_non_linux() {
2332        assert!(
2333            !is_dma_available(),
2334            "DMA memory allocation should NOT be available on non-Linux platforms"
2335        );
2336    }
2337
2338    /// Test that SHM memory allocation is available and usable on Unix systems.
2339    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
2340    #[test]
2341    #[cfg(unix)]
    fn test_shm_available_and_usable() {
        // This test opens SHM fds, so take the fd lock shared per the
        // FD_LOCK protocol above.
        let _lock = FD_LOCK.read().unwrap();
2343        assert!(
2344            is_shm_available(),
2345            "SHM memory allocation should be available on Unix systems"
2346        );
2347
2348        // Create a tensor with SHM backing
2349        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
2350            .expect("Failed to create SHM tensor");
2351
2352        // Verify we can map and write to it
2353        let mut map = tensor.map().expect("Failed to map SHM tensor");
2354        map.as_mut_slice().fill(0xAB);
2355
2356        // Verify the data was written correctly
2357        assert!(
2358            map.as_slice().iter().all(|&b| b == 0xAB),
2359            "SHM tensor data should be writable and readable"
2360        );
2361    }
2362}