Skip to main content

edgefirst_tensor/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
21
22## Overview
23The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
24which define the behavior of Tensors and their memory mappings, respectively.
25The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
26while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
27wraps `Tensor<T>` for runtime element-type dispatch.
28 */
29#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54    fmt,
55    ops::{Deref, DerefMut},
56    sync::{
57        atomic::{AtomicU64, Ordering},
58        Arc, Weak,
59    },
60};
61pub use tensor_dyn::TensorDyn;
62
/// Element type discriminant for runtime type identification.
///
/// `#[non_exhaustive]` so new element types can be added without a breaking
/// change; downstream matches must carry a wildcard arm. The `#[repr(u8)]`
/// fixes the discriminant width, so variant order is part of the ABI and
/// must not be changed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    /// Unsigned 8-bit integer.
    U8,
    /// Signed 8-bit integer.
    I8,
    /// Unsigned 16-bit integer.
    U16,
    /// Signed 16-bit integer.
    I16,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 32-bit integer.
    I32,
    /// Unsigned 64-bit integer.
    U64,
    /// Signed 64-bit integer.
    I64,
    /// 16-bit floating point (half precision).
    F16,
    /// 32-bit floating point (single precision).
    F32,
    /// 64-bit floating point (double precision).
    F64,
}
80
81impl DType {
82    /// Size of one element in bytes.
83    pub const fn size(&self) -> usize {
84        match self {
85            Self::U8 | Self::I8 => 1,
86            Self::U16 | Self::I16 | Self::F16 => 2,
87            Self::U32 | Self::I32 | Self::F32 => 4,
88            Self::U64 | Self::I64 | Self::F64 => 8,
89        }
90    }
91
92    /// Short type name (e.g., "u8", "f32", "f16").
93    pub const fn name(&self) -> &'static str {
94        match self {
95            Self::U8 => "u8",
96            Self::I8 => "i8",
97            Self::U16 => "u16",
98            Self::I16 => "i16",
99            Self::U32 => "u32",
100            Self::I32 => "i32",
101            Self::U64 => "u64",
102            Self::I64 => "i64",
103            Self::F16 => "f16",
104            Self::F32 => "f32",
105            Self::F64 => "f64",
106        }
107    }
108}
109
110impl fmt::Display for DType {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        f.write_str(self.name())
113    }
114}
115
/// Monotonic counter for buffer identity IDs. `fetch_add` with
/// `Ordering::Relaxed` is sufficient here: only uniqueness of the returned
/// values is needed, not any cross-thread ordering. Starts at 1, leaving 0
/// unused (presumably reserved as a sentinel — not relied upon in this file).
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
118
/// Unique identity for a tensor's underlying buffer.
///
/// Created fresh on every buffer allocation or import. The `id` is a monotonic
/// u64 used as a cache key. The `guard` is an `Arc<()>` whose weak references
/// allow downstream caches to detect when the buffer has been dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    /// Monotonic ID taken from `NEXT_BUFFER_ID`; shared by all clones.
    id: u64,
    /// Liveness guard: weak handles from [`BufferIdentity::weak`] go dead
    /// once every clone of this identity has been dropped.
    guard: Arc<()>,
}
129
130impl BufferIdentity {
131    /// Create a new unique buffer identity.
132    pub fn new() -> Self {
133        Self {
134            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
135            guard: Arc::new(()),
136        }
137    }
138
139    /// Unique identifier for this buffer. Changes when the buffer changes.
140    pub fn id(&self) -> u64 {
141        self.id
142    }
143
144    /// Returns a weak reference to the buffer guard. Goes dead when the
145    /// owning Tensor is dropped (and no clones remain).
146    pub fn weak(&self) -> Weak<()> {
147        Arc::downgrade(&self.guard)
148    }
149}
150
151impl Default for BufferIdentity {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157#[cfg(target_os = "linux")]
158use nix::sys::stat::{major, minor};
159
/// Common behavior implemented by every tensor backend.
///
/// Implementors must be `Send + Sync` so tensors can be shared across
/// threads. Default methods (`len`, `is_empty`, `size`) are derived from
/// `shape()`, so backends only need to provide storage-specific operations.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Create a new tensor with the given shape and optional name. If no name
    /// is given, a random name will be generated.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Create a new tensor using the given file descriptor, shape, and optional
    /// name. If no name is given, a random name will be generated.
    ///
    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
    /// On other Unix (macOS): Always creates SHM tensor.
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    #[cfg(unix)]
    /// Clone the file descriptor associated with this tensor.
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Get the memory type of this tensor.
    fn memory(&self) -> TensorMemory;

    /// Get the name of this tensor.
    fn name(&self) -> String;

    /// Get the number of elements in this tensor (the product of all
    /// dimensions; note an empty shape yields 1, per `iter().product()`).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor is empty (zero elements, i.e. some dimension is 0).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor (`len() * size_of::<T>()`).
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get the shape of this tensor.
    fn shape(&self) -> &[usize];

    /// Reshape this tensor to the given shape. The total number of elements
    /// must remain the same.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Map the tensor into memory and return a TensorMap for accessing the
    /// data.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Get the buffer identity for cache keying and liveness tracking.
    fn buffer_identity(&self) -> &BufferIdentity;
}
219
/// Access to a tensor's mapped memory.
///
/// A map exposes the tensor data as a flat slice in row-major order
/// (per `shape()`); default methods derive element count and byte size
/// from the shape.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Get the shape of this tensor map.
    fn shape(&self) -> &[usize];

    /// Unmap the tensor from memory.
    fn unmap(&mut self);

    /// Get the number of elements in this tensor map.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Check if the tensor map is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size in bytes of this tensor map.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Get a slice to the data in this tensor map.
    fn as_slice(&self) -> &[T];

    /// Get a mutable slice to the data in this tensor map.
    fn as_mut_slice(&mut self) -> &mut [T];

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayView of the tensor data.
    ///
    /// Fails if the slice length does not match the shape's element count
    /// (propagated from `ndarray::ShapeError`).
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    #[cfg(feature = "ndarray")]
    /// Get an ndarray ArrayViewMut of the tensor data.
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Copy the shape out first: `as_mut_slice` borrows `self` mutably,
        // which cannot coexist with the shared borrow `shape()` would hold.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
272
/// Memory backend used (or requested) for a tensor's allocation.
///
/// Platform-gated variants: `Dma` exists only on Linux, `Shm` only on Unix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    /// Direct Memory Access (DMA) allocation. Incurs additional
    /// overhead for memory reading/writing with the CPU.  Allows for
    /// hardware acceleration when supported.
    Dma,
    #[cfg(unix)]
    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    Shm,

    /// Regular system memory allocation
    Mem,

    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
    /// when DMA-buf is unavailable but OpenGL is present.
    Pbo,
}
292
293impl From<TensorMemory> for String {
294    fn from(memory: TensorMemory) -> Self {
295        match memory {
296            #[cfg(target_os = "linux")]
297            TensorMemory::Dma => "dma".to_owned(),
298            #[cfg(unix)]
299            TensorMemory::Shm => "shm".to_owned(),
300            TensorMemory::Mem => "mem".to_owned(),
301            TensorMemory::Pbo => "pbo".to_owned(),
302        }
303    }
304}
305
306impl TryFrom<&str> for TensorMemory {
307    type Error = Error;
308
309    fn try_from(s: &str) -> Result<Self> {
310        match s {
311            #[cfg(target_os = "linux")]
312            "dma" => Ok(TensorMemory::Dma),
313            #[cfg(unix)]
314            "shm" => Ok(TensorMemory::Shm),
315            "mem" => Ok(TensorMemory::Mem),
316            "pbo" => Ok(TensorMemory::Pbo),
317            _ => Err(Error::InvalidMemoryType(s.to_owned())),
318        }
319    }
320}
321
/// Backend-specific storage behind `Tensor`, one variant per memory type.
#[derive(Debug)]
#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Linux DMA-buf allocation.
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    /// POSIX shared-memory allocation.
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    /// Plain system-memory allocation.
    Mem(MemTensor<T>),
    /// OpenGL Pixel Buffer Object allocation.
    Pbo(PboTensor<T>),
}
335
336impl<T> TensorStorage<T>
337where
338    T: Num + Clone + fmt::Debug + Send + Sync,
339{
340    /// Create a new tensor storage with the given shape, memory type, and
341    /// optional name. If no name is given, a random name will be generated.
342    /// If no memory type is given, the best available memory type will be
343    /// chosen based on the platform and environment variables.
344    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
345        match memory {
346            #[cfg(target_os = "linux")]
347            Some(TensorMemory::Dma) => {
348                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
349            }
350            #[cfg(unix)]
351            Some(TensorMemory::Shm) => {
352                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
353            }
354            Some(TensorMemory::Mem) => {
355                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
356            }
357            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
358                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
359            )),
360            None => {
361                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
362                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
363                {
364                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
365                } else {
366                    #[cfg(target_os = "linux")]
367                    {
368                        // Linux: Try DMA -> SHM -> Mem
369                        match DmaTensor::<T>::new(shape, name) {
370                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
371                            Err(_) => {
372                                match ShmTensor::<T>::new(shape, name)
373                                    .map(TensorStorage::Shm)
374                                {
375                                    Ok(tensor) => Ok(tensor),
376                                    Err(_) => MemTensor::<T>::new(shape, name)
377                                        .map(TensorStorage::Mem),
378                                }
379                            }
380                        }
381                    }
382                    #[cfg(all(unix, not(target_os = "linux")))]
383                    {
384                        // macOS/BSD: Try SHM -> Mem (no DMA)
385                        match ShmTensor::<T>::new(shape, name) {
386                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
387                            Err(_) => {
388                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
389                            }
390                        }
391                    }
392                    #[cfg(not(unix))]
393                    {
394                        // Windows/other: Mem only
395                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
396                    }
397                }
398            }
399        }
400    }
401
402    /// Create a new tensor storage using the given file descriptor, shape,
403    /// and optional name.
404    #[cfg(unix)]
405    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
406        #[cfg(target_os = "linux")]
407        {
408            use nix::sys::stat::fstat;
409
410            let stat = fstat(&fd)?;
411            let major = major(stat.st_dev);
412            let minor = minor(stat.st_dev);
413
414            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
415
416            if major != 0 {
417                // Dma and Shm tensors are expected to have major number 0
418                return Err(Error::UnknownDeviceType(major, minor));
419            }
420
421            match minor {
422                9 | 10 => {
423                    // minor number 9 & 10 indicates DMA memory
424                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
425                }
426                _ => {
427                    // other minor numbers are assumed to be shared memory
428                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
429                }
430            }
431        }
432        #[cfg(all(unix, not(target_os = "linux")))]
433        {
434            // On macOS/BSD, always use SHM (no DMA support)
435            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
436        }
437    }
438}
439
440impl<T> TensorTrait<T> for TensorStorage<T>
441where
442    T: Num + Clone + fmt::Debug + Send + Sync,
443{
444    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
445        Self::new(shape, None, name)
446    }
447
448    #[cfg(unix)]
449    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
450        Self::from_fd(fd, shape, name)
451    }
452
453    #[cfg(unix)]
454    fn clone_fd(&self) -> Result<OwnedFd> {
455        match self {
456            #[cfg(target_os = "linux")]
457            TensorStorage::Dma(t) => t.clone_fd(),
458            TensorStorage::Shm(t) => t.clone_fd(),
459            TensorStorage::Mem(t) => t.clone_fd(),
460            TensorStorage::Pbo(t) => t.clone_fd(),
461        }
462    }
463
464    fn memory(&self) -> TensorMemory {
465        match self {
466            #[cfg(target_os = "linux")]
467            TensorStorage::Dma(_) => TensorMemory::Dma,
468            #[cfg(unix)]
469            TensorStorage::Shm(_) => TensorMemory::Shm,
470            TensorStorage::Mem(_) => TensorMemory::Mem,
471            TensorStorage::Pbo(_) => TensorMemory::Pbo,
472        }
473    }
474
475    fn name(&self) -> String {
476        match self {
477            #[cfg(target_os = "linux")]
478            TensorStorage::Dma(t) => t.name(),
479            #[cfg(unix)]
480            TensorStorage::Shm(t) => t.name(),
481            TensorStorage::Mem(t) => t.name(),
482            TensorStorage::Pbo(t) => t.name(),
483        }
484    }
485
486    fn shape(&self) -> &[usize] {
487        match self {
488            #[cfg(target_os = "linux")]
489            TensorStorage::Dma(t) => t.shape(),
490            #[cfg(unix)]
491            TensorStorage::Shm(t) => t.shape(),
492            TensorStorage::Mem(t) => t.shape(),
493            TensorStorage::Pbo(t) => t.shape(),
494        }
495    }
496
497    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
498        match self {
499            #[cfg(target_os = "linux")]
500            TensorStorage::Dma(t) => t.reshape(shape),
501            #[cfg(unix)]
502            TensorStorage::Shm(t) => t.reshape(shape),
503            TensorStorage::Mem(t) => t.reshape(shape),
504            TensorStorage::Pbo(t) => t.reshape(shape),
505        }
506    }
507
508    fn map(&self) -> Result<TensorMap<T>> {
509        match self {
510            #[cfg(target_os = "linux")]
511            TensorStorage::Dma(t) => t.map(),
512            #[cfg(unix)]
513            TensorStorage::Shm(t) => t.map(),
514            TensorStorage::Mem(t) => t.map(),
515            TensorStorage::Pbo(t) => t.map(),
516        }
517    }
518
519    fn buffer_identity(&self) -> &BufferIdentity {
520        match self {
521            #[cfg(target_os = "linux")]
522            TensorStorage::Dma(t) => t.buffer_identity(),
523            #[cfg(unix)]
524            TensorStorage::Shm(t) => t.buffer_identity(),
525            TensorStorage::Mem(t) => t.buffer_identity(),
526            TensorStorage::Pbo(t) => t.buffer_identity(),
527        }
528    }
529}
530
/// Multi-backend tensor with optional image format metadata.
///
/// When `format` is `Some`, this tensor represents an image. Width, height,
/// and channels are derived from `shape` + `format`. When `format` is `None`,
/// this is a raw tensor (identical to the pre-refactoring behavior).
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Backend allocation. For multiplane images this holds the luma plane.
    pub(crate) storage: TensorStorage<T>,
    /// Pixel format when this tensor represents an image; `None` for raw tensors.
    format: Option<PixelFormat>,
    /// Separately-allocated chroma plane for multiplane NV12/NV16 images
    /// built via `Tensor::from_planes`; `None` otherwise.
    chroma: Option<Box<Tensor<T>>>,
}
545
impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wrap a TensorStorage in a Tensor with no image metadata.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
        }
    }

    /// Create a new tensor with the given shape, memory type, and optional
    /// name. If no name is given, a random name will be generated. If no
    /// memory type is given, the best available memory type will be chosen
    /// based on the platform and environment variables.
    ///
    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
    /// On non-Unix platforms, only Mem is available.
    ///
    /// # Environment Variables
    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
    ///   value, forces the use of regular system memory allocation
    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
    ///
    /// # Example
    /// ```rust
    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
    /// # fn main() -> Result<(), Error> {
    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
    /// assert_eq!(tensor.name(), "test_tensor");
    /// #    Ok(())
    /// # }
    /// ```
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }

    /// Create an image tensor with the given format.
    ///
    /// The allocation shape is derived from the format's layout:
    /// `[height, width, channels]` for packed, `[channels, height, width]`
    /// for planar, and a single contiguous `[total_height, width]` plane
    /// for semi-planar formats (luma rows followed by interleaved chroma).
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` for NV12 with an odd `height` or for
    /// a semi-planar format with no known height multiplier; allocation
    /// errors from the chosen backend are propagated.
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }

    /// Attach format metadata to an existing tensor.
    ///
    /// # Arguments
    ///
    /// * `format` - The pixel format to attach
    ///
    /// # Returns
    ///
    /// `Ok(())` on success, with the format stored as metadata on the tensor.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidShape` if the tensor shape is incompatible with
    /// the format's layout (packed expects `[H, W, C]`, planar expects
    /// `[C, H, W]`, semi-planar expects `[H*k, W]` with format-specific
    /// height constraints).
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                // Packed images are [H, W, C]: channel count must sit last.
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                // Planar images are [C, H, W]: channel count must sit first.
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                match format {
                    // Contiguous NV12 total height is H*3/2, which for even H
                    // is always a multiple of 3.
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    // Contiguous NV16 total height is H*2, i.e. always even.
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        self.format = Some(format);
        Ok(())
    }

    /// Pixel format (None if not an image).
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }

    /// Image width (None if not an image).
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            // Packed is [H, W, C]; planar is [C, H, W]; semi-planar is
            // [H*k, W] (contiguous) or the luma plane [H, W] (multiplane) —
            // width is the trailing dimension in the semi-planar cases.
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }

    /// Image height (None if not an image).
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    // Multiplane: storage holds only the luma plane [H, W].
                    Some(shape[0])
                } else {
                    // Contiguous: invert the total-height math from `image()`
                    // (NV12 total = H*3/2, NV16 total = H*2).
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }

    /// Create from separate Y and UV planes (multiplane NV12/NV16).
    ///
    /// `luma` supplies the resulting tensor's storage (the Y plane) and
    /// `chroma` is kept as a separate allocation reachable via
    /// [`Tensor::chroma`].
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` unless: `format` is semi-planar
    /// NV12/NV16, both planes are 2D with equal widths, `chroma` is a raw
    /// tensor, and the chroma height matches the format's subsampling
    /// (luma/2 for NV12 — with even luma height — or equal for NV16).
    ///
    /// NOTE(review): `luma.format` and `luma.chroma` are silently discarded
    /// here; unlike `chroma`, `luma` is not required to be a raw tensor.
    /// Confirm whether that asymmetry is intentional.
    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
        if format.layout() != PixelLayout::SemiPlanar {
            return Err(Error::InvalidArgument(format!(
                "from_planes requires a semi-planar format, got {format:?}"
            )));
        }
        if chroma.format.is_some() || chroma.chroma.is_some() {
            return Err(Error::InvalidArgument(
                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
            ));
        }
        let luma_shape = luma.shape();
        let chroma_shape = chroma.shape();
        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
            return Err(Error::InvalidArgument(format!(
                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
            )));
        }
        if luma_shape[1] != chroma_shape[1] {
            return Err(Error::InvalidArgument(format!(
                "luma width {} != chroma width {}",
                luma_shape[1], chroma_shape[1]
            )));
        }
        match format {
            PixelFormat::Nv12 => {
                // 4:2:0 subsampling: chroma plane is half the luma height.
                if luma_shape[0] % 2 != 0 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 requires even luma height, got {}",
                        luma_shape[0]
                    )));
                }
                if chroma_shape[0] != luma_shape[0] / 2 {
                    return Err(Error::InvalidArgument(format!(
                        "NV12 chroma height {} != luma height / 2 ({})",
                        chroma_shape[0],
                        luma_shape[0] / 2
                    )));
                }
            }
            PixelFormat::Nv16 => {
                // 4:2:2 subsampling: chroma plane matches the luma height.
                if chroma_shape[0] != luma_shape[0] {
                    return Err(Error::InvalidArgument(format!(
                        "NV16 chroma height {} != luma height {}",
                        chroma_shape[0], luma_shape[0]
                    )));
                }
            }
            _ => {
                return Err(Error::InvalidArgument(format!(
                    "from_planes only supports NV12 and NV16, got {format:?}"
                )));
            }
        }

        Ok(Tensor {
            storage: luma.storage,
            format: Some(format),
            chroma: Some(Box::new(chroma)),
        })
    }

    /// Whether this tensor uses separate plane allocations.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }

    /// Access the chroma plane for multiplane semi-planar images.
    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }

    /// Downcast to PBO tensor reference (for GL backends).
    ///
    /// Returns `None` when the storage is not PBO-backed.
    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }

    /// Downcast to DMA tensor reference (for EGL import, G2D).
    ///
    /// Returns `None` when the storage is not DMA-backed.
    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }

    /// Borrow the DMA-BUF file descriptor backing this tensor.
    ///
    /// # Returns
    ///
    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
    /// lifetime.
    ///
    /// # Errors
    ///
    /// Returns `Error::NotImplemented` if the tensor is not DMA-backed.
    #[cfg(target_os = "linux")]
    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
        use std::os::fd::AsFd;
        match &self.storage {
            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
            _ => Err(Error::NotImplemented(format!(
                "dmabuf requires DMA-backed tensor, got {:?}",
                self.storage.memory()
            ))),
        }
    }

    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
    ///
    /// The result carries no image metadata; attach it with `set_format`.
    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
        Self {
            storage: TensorStorage::Pbo(pbo),
            format: None,
            chroma: None,
        }
    }
}
845
// Trait implementation for the unified `Tensor<T>` wrapper. Nearly every
// method forwards to the backend-specific `TensorStorage`; the exceptions
// that touch the image metadata (`format`/`chroma`) are commented below.
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    // Trait-level `new` has no memory argument: delegate to the inherent
    // constructor with `None` so the backend is auto-detected.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }

    // Adopt an existing fd (DMA-BUF or SHM); storage decides which backend.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }

    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }

    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }

    fn name(&self) -> String {
        self.storage.name()
    }

    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }

    // Reshape is refused for multiplane images (chroma present) and, on
    // success, drops any single-plane pixel format since the new shape may
    // no longer describe an image.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        Ok(())
    }

    fn map(&self) -> Result<TensorMap<T>> {
        self.storage.map()
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
901
/// A mapped view over a tensor's memory, wrapping the backend-specific map
/// type. Obtained from [`TensorTrait::map`]; element access goes through
/// [`TensorMapTrait`] or the slice `Deref`/`DerefMut` impls below.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// DMA-BUF backed mapping (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// POSIX shared-memory mapping (Unix only).
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Plain system-memory mapping.
    Mem(MemMap<T>),
    /// OpenGL pixel-buffer-object mapping.
    Pbo(PboMap<T>),
}
913
// Fan-out dispatch: each trait method forwards to the matching backend map.
// The `cfg` attributes on the match arms mirror the ones on the enum
// variants, so the match stays exhaustive on every platform.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
962
// Lets a `TensorMap` be used directly as a `&[T]` slice by forwarding to the
// backend map's own `Deref` (kept separate from `as_slice` in case a backend
// distinguishes the two paths).
impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}
980
// Mutable counterpart of the `Deref` impl above: exposes the mapping as a
// `&mut [T]` slice via the backend map's `DerefMut`.
impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}
996
997// ============================================================================
998// Platform availability helpers
999// ============================================================================
1000
1001/// Check if DMA memory allocation is available on this system.
1002///
1003/// Returns `true` only on Linux systems with DMA-BUF heap access (typically
1004/// requires running as root or membership in a video/render group).
1005/// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
1006///
1007/// This function caches its result after the first call for efficiency.
1008#[cfg(target_os = "linux")]
1009static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1010
1011/// Check if DMA memory allocation is available on this system.
1012#[cfg(target_os = "linux")]
1013pub fn is_dma_available() -> bool {
1014    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
1015}
1016
1017/// Check if DMA memory allocation is available on this system.
1018///
1019/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
1020#[cfg(not(target_os = "linux"))]
1021pub fn is_dma_available() -> bool {
1022    false
1023}
1024
1025/// Check if POSIX shared memory allocation is available on this system.
1026///
1027/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
1028/// is supported. Always returns `false` on non-Unix platforms (Windows).
1029///
1030/// This function caches its result after the first call for efficiency.
1031#[cfg(unix)]
1032static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
1033
1034/// Check if POSIX shared memory allocation is available on this system.
1035#[cfg(unix)]
1036pub fn is_shm_available() -> bool {
1037    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
1038}
1039
1040/// Check if POSIX shared memory allocation is available on this system.
1041///
1042/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
1043#[cfg(not(unix))]
1044pub fn is_shm_available() -> bool {
1045    false
1046}
1047
#[cfg(test)]
mod dtype_tests {
    use super::*;

    /// Every dtype reports the size in bytes of a single element.
    #[test]
    fn dtype_size() {
        assert_eq!(DType::U8.size(), 1);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U16.size(), 2);
        assert_eq!(DType::I16.size(), 2);
        assert_eq!(DType::U32.size(), 4);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
    }

    /// Spot-check the lowercase dtype names.
    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    /// `DType` survives a serde JSON round trip unchanged.
    /// (No `use serde_json;` needed: the crate uses 2018+ edition paths, so
    /// extern-prelude crates are already in scope.)
    #[test]
    fn dtype_serde_roundtrip() {
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1083
// Tests for the image-metadata layer (`PixelFormat`, width/height, planes)
// layered on top of `Tensor`.
#[cfg(test)]
mod image_tests {
    use super::*;

    // A plain tensor carries no image metadata at all.
    #[test]
    fn raw_tensor_has_no_format() {
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    // Packed RGBA image: shape is HWC with 4 channels.
    #[test]
    fn image_tensor_packed() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    // Planar RGB: shape is CHW (channel-major).
    #[test]
    fn image_tensor_planar() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    // NV12 allocated as one contiguous buffer is NOT multiplane.
    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        // NV12: H*3/2 = 720
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    // A format matching the existing shape can be attached after creation.
    #[test]
    fn set_format_valid() {
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    // A mismatched format is rejected and leaves the tensor untouched.
    #[test]
    fn set_format_invalid_shape() {
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        // RGB expects 3 channels, not 4
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        // Original tensor is unmodified
        assert!(t.format().is_none());
    }

    // Reshaping invalidates any attached pixel format.
    #[test]
    fn reshape_clears_format() {
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        // Reshape to flat — format cleared
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    // Separate Y + UV planes combine into a multiplane NV12 image.
    #[test]
    fn from_planes_nv12() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    // Two-plane construction only makes sense for semi-planar formats.
    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    // Multiplane tensors refuse reshape even when the element count matches.
    #[test]
    fn reshape_multiplane_errors() {
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
1186
1187#[cfg(test)]
1188mod tests {
1189    #[cfg(target_os = "linux")]
1190    use nix::unistd::{access, AccessFlags};
1191    #[cfg(target_os = "linux")]
1192    use std::io::Write as _;
1193    use std::sync::RwLock;
1194
1195    use super::*;
1196
    /// One-time test-binary setup (runs before any test via `ctor`): install
    /// an env_logger defaulting to `info` so the SKIPPED warnings emitted by
    /// gated tests are visible in the test output.
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }
1201
    /// Macro to get the current function name for logging in tests.
    ///
    /// Works by taking `type_name` of a nested item, which embeds the full
    /// module path of the enclosing function, then trimming it down.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // `name` ends in "::f" (the nested item) — drop those 3 chars,
            // then keep only the segment after the last ':' so just the
            // enclosing function's own name remains.
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
1219
1220    #[test]
1221    #[cfg(target_os = "linux")]
1222    fn test_tensor() {
1223        let _lock = FD_LOCK.read().unwrap();
1224        let shape = vec![1];
1225        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
1226        let dma_enabled = tensor.is_ok();
1227
1228        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1229        match dma_enabled {
1230            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
1231            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
1232        }
1233    }
1234
1235    #[test]
1236    #[cfg(all(unix, not(target_os = "linux")))]
1237    fn test_tensor() {
1238        let shape = vec![1];
1239        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1240        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
1241        assert!(
1242            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
1243            "Expected SHM or Mem on macOS, got {:?}",
1244            tensor.memory()
1245        );
1246    }
1247
1248    #[test]
1249    #[cfg(not(unix))]
1250    fn test_tensor() {
1251        let shape = vec![1];
1252        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1253        assert_eq!(tensor.memory(), TensorMemory::Mem);
1254    }
1255
    /// End-to-end exercise of the DMA backend: allocation, metadata, mapping,
    /// fd sharing between tensors, and reshape rules. Skips gracefully when
    /// no DMA heap is accessible.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Probe the CMA heap first, then the generic system heap; skip the
        // test (with a warning) if neither is readable+writable.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        // Metadata accessors on a freshly allocated named tensor.
        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Basic map + fill round trip.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // A second tensor built from a duplicated fd shares the same buffer:
        // writes through it must be visible through the original below.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor observes the writes made via the shared fd.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape rules: size mismatch is rejected, same-size succeeds.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        // Writes persist across independent map/unmap cycles.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1353
    /// End-to-end exercise of the SHM backend, mirroring `test_dma_tensor`:
    /// allocation, metadata, mapping, fd sharing, and reshape rules.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Basic map + fill round trip.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // A second tensor built from a duplicated fd shares the same buffer.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor observes the writes made via the shared fd.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape rules: size mismatch is rejected, same-size succeeds.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        // Writes persist across independent map/unmap cycles.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
1426
1427    #[test]
1428    fn test_mem_tensor() {
1429        let shape = vec![2, 3, 4];
1430        let tensor =
1431            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
1432        assert_eq!(tensor.shape(), &shape);
1433        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
1434        assert_eq!(tensor.name(), "test_tensor");
1435
1436        {
1437            let mut tensor_map = tensor.map().expect("Failed to map memory");
1438            tensor_map.fill(42.0);
1439            assert!(tensor_map.iter().all(|&x| x == 42.0));
1440        }
1441
1442        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
1443        assert_eq!(tensor.shape(), &shape);
1444        let new_shape = vec![3, 4, 4];
1445        assert!(
1446            tensor.reshape(&new_shape).is_err(),
1447            "Reshape should fail due to size mismatch"
1448        );
1449        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
1450
1451        let new_shape = vec![2, 3, 4];
1452        tensor.reshape(&new_shape).expect("Reshape should succeed");
1453        assert_eq!(
1454            tensor.shape(),
1455            &new_shape,
1456            "Shape should be updated after successful reshape"
1457        );
1458
1459        {
1460            let mut tensor_map = tensor.map().expect("Failed to map memory");
1461            tensor_map.fill(1);
1462            assert!(tensor_map.iter().all(|&x| x == 1));
1463        }
1464
1465        {
1466            let mut tensor_map = tensor.map().expect("Failed to map memory");
1467            tensor_map[2] = 42;
1468            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
1469            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
1470        }
1471    }
1472
    /// Allocate and map 100 DMA tensors and verify the process-wide open-fd
    /// count is unchanged afterwards (i.e. Drop closes every fd).
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while we count.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, writes, and drops a tensor.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1509
    /// Repeatedly duplicate one DMA tensor's fd into new tensors and verify
    /// the open-fd count is unchanged once everything (including the
    /// original) is dropped.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while we count.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        // Each iteration clones the fd into a fresh tensor, maps and drops it.
        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the source tensor before recounting so its own fd is closed.
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1547
    /// Allocate and map 100 SHM tensors and verify the process-wide open-fd
    /// count is unchanged afterwards (i.e. Drop closes every fd).
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while we count.
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, writes, and drops a tensor.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1584
    /// Repeatedly duplicate one SHM tensor's fd into new tensors and verify
    /// the open-fd count is unchanged once everything (including the
    /// original) is dropped.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_from_fd_no_fd_leaks() {
        // Exclusive lock: no other test may open/close fds while we count.
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

        // Each iteration clones the fd into a fresh tensor, maps and drops it.
        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the source tensor before recounting so its own fd is closed.
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1622
    /// With the optional `ndarray` feature, a tensor map exposes shaped
    /// read-only and mutable views that alias the same underlying buffer.
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
        tensor_map.fill(1.0);

        // Read-only view reflects the tensor's shape and contents.
        let view = tensor_map.view().expect("Failed to get ndarray view");
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert!(view.iter().all(|&x| x == 1.0));

        // Mutable view writes are visible through the flat map as well.
        let mut view_mut = tensor_map
            .view_mut()
            .expect("Failed to get mutable ndarray view");
        view_mut[[0, 0, 0]] = 42.0;
        assert_eq!(view_mut[[0, 0, 0]], 42.0);
        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
    }
1644
1645    #[test]
1646    fn test_buffer_identity_unique() {
1647        let id1 = BufferIdentity::new();
1648        let id2 = BufferIdentity::new();
1649        assert_ne!(
1650            id1.id(),
1651            id2.id(),
1652            "Two identities should have different ids"
1653        );
1654    }
1655
1656    #[test]
1657    fn test_buffer_identity_clone_shares_guard() {
1658        let id1 = BufferIdentity::new();
1659        let weak = id1.weak();
1660        assert!(
1661            weak.upgrade().is_some(),
1662            "Weak should be alive while original exists"
1663        );
1664
1665        let id2 = id1.clone();
1666        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1667
1668        drop(id1);
1669        assert!(
1670            weak.upgrade().is_some(),
1671            "Weak should still be alive (clone holds Arc)"
1672        );
1673
1674        drop(id2);
1675        assert!(
1676            weak.upgrade().is_none(),
1677            "Weak should be dead after all clones dropped"
1678        );
1679    }
1680
1681    #[test]
1682    fn test_tensor_buffer_identity() {
1683        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1684        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1685        assert_ne!(
1686            t1.buffer_identity().id(),
1687            t2.buffer_identity().id(),
1688            "Different tensors should have different buffer ids"
1689        );
1690    }
1691
    // Coordinates tests around the process-wide open-fd count:
    //  - tests that COUNT fds (the leak tests) take this lock exclusively
    //    (`write`) so no fds open or close while they measure;
    //  - tests that merely CHANGE the fd count by opening or closing fds
    //    take it shared (`read`), which lets them run concurrently with each
    //    other but never alongside a counting test.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1696
1697    /// Test that DMA is NOT available on non-Linux platforms.
1698    /// This verifies the cross-platform behavior of is_dma_available().
1699    #[test]
1700    #[cfg(not(target_os = "linux"))]
1701    fn test_dma_not_available_on_non_linux() {
1702        assert!(
1703            !is_dma_available(),
1704            "DMA memory allocation should NOT be available on non-Linux platforms"
1705        );
1706    }
1707
1708    /// Test that SHM memory allocation is available and usable on Unix systems.
1709    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
1710    #[test]
1711    #[cfg(unix)]
1712    fn test_shm_available_and_usable() {
1713        assert!(
1714            is_shm_available(),
1715            "SHM memory allocation should be available on Unix systems"
1716        );
1717
1718        // Create a tensor with SHM backing
1719        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1720            .expect("Failed to create SHM tensor");
1721
1722        // Verify we can map and write to it
1723        let mut map = tensor.map().expect("Failed to map SHM tensor");
1724        map.as_mut_slice().fill(0xAB);
1725
1726        // Verify the data was written correctly
1727        assert!(
1728            map.as_slice().iter().all(|&b| b == 0xAB),
1729            "SHM tensor data should be writable and readable"
1730        );
1731    }
1732}