Skip to main content

edgefirst_tensor/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
21
22## Overview
23The main structures and traits provided by the `edgefirst_tensor` crate are `TensorTrait` and `TensorMapTrait`,
24which define the behavior of Tensors and their memory mappings, respectively.
25The `Tensor<T>` struct wraps a backend-specific storage with optional image format metadata (`PixelFormat`),
26while the `TensorMap` enum provides access to the underlying data. The `TensorDyn` type-erased enum
27wraps `Tensor<T>` for runtime element-type dispatch.
28 */
29#[cfg(target_os = "linux")]
30mod dma;
31#[cfg(target_os = "linux")]
32mod dmabuf;
33mod error;
34mod format;
35mod mem;
36mod pbo;
37#[cfg(unix)]
38mod shm;
39mod tensor_dyn;
40
41#[cfg(target_os = "linux")]
42pub use crate::dma::{DmaMap, DmaTensor};
43pub use crate::mem::{MemMap, MemTensor};
44pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
45#[cfg(unix)]
46pub use crate::shm::{ShmMap, ShmTensor};
47pub use error::{Error, Result};
48pub use format::{PixelFormat, PixelLayout};
49use num_traits::Num;
50use serde::{Deserialize, Serialize};
51#[cfg(unix)]
52use std::os::fd::OwnedFd;
53use std::{
54    fmt,
55    ops::{Deref, DerefMut},
56    sync::{
57        atomic::{AtomicU64, Ordering},
58        Arc, Weak,
59    },
60};
61pub use tensor_dyn::TensorDyn;
62
63/// Element type discriminant for runtime type identification.
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
65#[repr(u8)]
66#[non_exhaustive]
67pub enum DType {
68    U8,
69    I8,
70    U16,
71    I16,
72    U32,
73    I32,
74    U64,
75    I64,
76    F16,
77    F32,
78    F64,
79}
80
81impl DType {
82    /// Size of one element in bytes.
83    pub const fn size(&self) -> usize {
84        match self {
85            Self::U8 | Self::I8 => 1,
86            Self::U16 | Self::I16 | Self::F16 => 2,
87            Self::U32 | Self::I32 | Self::F32 => 4,
88            Self::U64 | Self::I64 | Self::F64 => 8,
89        }
90    }
91
92    /// Short type name (e.g., "u8", "f32", "f16").
93    pub const fn name(&self) -> &'static str {
94        match self {
95            Self::U8 => "u8",
96            Self::I8 => "i8",
97            Self::U16 => "u16",
98            Self::I16 => "i16",
99            Self::U32 => "u32",
100            Self::I32 => "i32",
101            Self::U64 => "u64",
102            Self::I64 => "i64",
103            Self::F16 => "f16",
104            Self::F32 => "f32",
105            Self::F64 => "f64",
106        }
107    }
108}
109
110impl fmt::Display for DType {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        f.write_str(self.name())
113    }
114}
115
/// Process-wide source of buffer identity IDs. Starts at 1 and is only ever
/// incremented.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Identity token for the buffer that backs a tensor.
///
/// A new identity is minted for every buffer allocation or import.
/// [`id`](Self::id) returns a monotonically assigned `u64` meant for use as a
/// cache key. [`weak`](Self::weak) returns a `Weak<()>` tied to an internal
/// `Arc<()>` guard, so caches can tell when the identity (and every clone of
/// it) has been dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    id: u64,
    guard: Arc<()>,
}

impl BufferIdentity {
    /// Mint a fresh identity with a newly allocated ID.
    pub fn new() -> Self {
        // Only uniqueness of the counter value matters, so Relaxed suffices.
        let id = NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed);
        let guard = Arc::new(());
        Self { id, guard }
    }

    /// Numeric identifier of this buffer; a new buffer gets a new value.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Weak handle to the liveness guard.
    ///
    /// It stops upgrading once the owning tensor has been dropped and no
    /// clones of this identity remain.
    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}

impl Default for BufferIdentity {
    /// Same as [`BufferIdentity::new`]: mints a fresh identity.
    fn default() -> Self {
        BufferIdentity::new()
    }
}
156
157#[cfg(target_os = "linux")]
158use nix::sys::stat::{major, minor};
159
160pub trait TensorTrait<T>: Send + Sync
161where
162    T: Num + Clone + fmt::Debug,
163{
164    /// Create a new tensor with the given shape and optional name. If no name
165    /// is given, a random name will be generated.
166    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
167    where
168        Self: Sized;
169
170    #[cfg(unix)]
171    /// Create a new tensor using the given file descriptor, shape, and optional
172    /// name. If no name is given, a random name will be generated.
173    ///
174    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
175    /// On other Unix (macOS): Always creates SHM tensor.
176    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
177    where
178        Self: Sized;
179
180    #[cfg(unix)]
181    /// Clone the file descriptor associated with this tensor.
182    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
183
184    /// Get the memory type of this tensor.
185    fn memory(&self) -> TensorMemory;
186
187    /// Get the name of this tensor.
188    fn name(&self) -> String;
189
190    /// Get the number of elements in this tensor.
191    fn len(&self) -> usize {
192        self.shape().iter().product()
193    }
194
195    /// Check if the tensor is empty.
196    fn is_empty(&self) -> bool {
197        self.len() == 0
198    }
199
200    /// Get the size in bytes of this tensor.
201    fn size(&self) -> usize {
202        self.len() * std::mem::size_of::<T>()
203    }
204
205    /// Get the shape of this tensor.
206    fn shape(&self) -> &[usize];
207
208    /// Reshape this tensor to the given shape. The total number of elements
209    /// must remain the same.
210    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
211
212    /// Map the tensor into memory and return a TensorMap for accessing the
213    /// data.
214    fn map(&self) -> Result<TensorMap<T>>;
215
216    /// Get the buffer identity for cache keying and liveness tracking.
217    fn buffer_identity(&self) -> &BufferIdentity;
218}
219
220pub trait TensorMapTrait<T>
221where
222    T: Num + Clone + fmt::Debug,
223{
224    /// Get the shape of this tensor map.
225    fn shape(&self) -> &[usize];
226
227    /// Unmap the tensor from memory.
228    fn unmap(&mut self);
229
230    /// Get the number of elements in this tensor map.
231    fn len(&self) -> usize {
232        self.shape().iter().product()
233    }
234
235    /// Check if the tensor map is empty.
236    fn is_empty(&self) -> bool {
237        self.len() == 0
238    }
239
240    /// Get the size in bytes of this tensor map.
241    fn size(&self) -> usize {
242        self.len() * std::mem::size_of::<T>()
243    }
244
245    /// Get a slice to the data in this tensor map.
246    fn as_slice(&self) -> &[T];
247
248    /// Get a mutable slice to the data in this tensor map.
249    fn as_mut_slice(&mut self) -> &mut [T];
250
251    #[cfg(feature = "ndarray")]
252    /// Get an ndarray ArrayView of the tensor data.
253    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
254        Ok(ndarray::ArrayView::from_shape(
255            self.shape(),
256            self.as_slice(),
257        )?)
258    }
259
260    #[cfg(feature = "ndarray")]
261    /// Get an ndarray ArrayViewMut of the tensor data.
262    fn view_mut(
263        &'_ mut self,
264    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
265        let shape = self.shape().to_vec();
266        Ok(ndarray::ArrayViewMut::from_shape(
267            shape,
268            self.as_mut_slice(),
269        )?)
270    }
271}
272
/// Backend used to store a tensor's data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Direct Memory Access (DMA) allocation (Linux only). Enables hardware
    /// acceleration where supported, at the cost of extra overhead when the
    /// CPU reads or writes the memory.
    #[cfg(target_os = "linux")]
    Dma,

    /// POSIX shared-memory allocation (Unix only). Suitable for
    /// inter-process communication, but not for hardware acceleration.
    #[cfg(unix)]
    Shm,

    /// Ordinary system memory allocation.
    Mem,

    /// OpenGL Pixel Buffer Object memory. Created by `ImageProcessor` when
    /// OpenGL is present but DMA-buf is not; it cannot be requested through
    /// `Tensor::new`.
    Pbo,
}
292
293impl From<TensorMemory> for String {
294    fn from(memory: TensorMemory) -> Self {
295        match memory {
296            #[cfg(target_os = "linux")]
297            TensorMemory::Dma => "dma".to_owned(),
298            #[cfg(unix)]
299            TensorMemory::Shm => "shm".to_owned(),
300            TensorMemory::Mem => "mem".to_owned(),
301            TensorMemory::Pbo => "pbo".to_owned(),
302        }
303    }
304}
305
306impl TryFrom<&str> for TensorMemory {
307    type Error = Error;
308
309    fn try_from(s: &str) -> Result<Self> {
310        match s {
311            #[cfg(target_os = "linux")]
312            "dma" => Ok(TensorMemory::Dma),
313            #[cfg(unix)]
314            "shm" => Ok(TensorMemory::Shm),
315            "mem" => Ok(TensorMemory::Mem),
316            "pbo" => Ok(TensorMemory::Pbo),
317            _ => Err(Error::InvalidMemoryType(s.to_owned())),
318        }
319    }
320}
321
322#[derive(Debug)]
323#[allow(dead_code)] // Variants are constructed by downstream crates via pub(crate) helpers
324pub(crate) enum TensorStorage<T>
325where
326    T: Num + Clone + fmt::Debug + Send + Sync,
327{
328    #[cfg(target_os = "linux")]
329    Dma(DmaTensor<T>),
330    #[cfg(unix)]
331    Shm(ShmTensor<T>),
332    Mem(MemTensor<T>),
333    Pbo(PboTensor<T>),
334}
335
336impl<T> TensorStorage<T>
337where
338    T: Num + Clone + fmt::Debug + Send + Sync,
339{
340    /// Create a new tensor storage with the given shape, memory type, and
341    /// optional name. If no name is given, a random name will be generated.
342    /// If no memory type is given, the best available memory type will be
343    /// chosen based on the platform and environment variables.
344    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
345        match memory {
346            #[cfg(target_os = "linux")]
347            Some(TensorMemory::Dma) => {
348                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
349            }
350            #[cfg(unix)]
351            Some(TensorMemory::Shm) => {
352                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
353            }
354            Some(TensorMemory::Mem) => {
355                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
356            }
357            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
358                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
359            )),
360            None => {
361                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
362                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
363                {
364                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
365                } else {
366                    #[cfg(target_os = "linux")]
367                    {
368                        // Linux: Try DMA -> SHM -> Mem
369                        match DmaTensor::<T>::new(shape, name) {
370                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
371                            Err(_) => {
372                                match ShmTensor::<T>::new(shape, name)
373                                    .map(TensorStorage::Shm)
374                                {
375                                    Ok(tensor) => Ok(tensor),
376                                    Err(_) => MemTensor::<T>::new(shape, name)
377                                        .map(TensorStorage::Mem),
378                                }
379                            }
380                        }
381                    }
382                    #[cfg(all(unix, not(target_os = "linux")))]
383                    {
384                        // macOS/BSD: Try SHM -> Mem (no DMA)
385                        match ShmTensor::<T>::new(shape, name) {
386                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
387                            Err(_) => {
388                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
389                            }
390                        }
391                    }
392                    #[cfg(not(unix))]
393                    {
394                        // Windows/other: Mem only
395                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
396                    }
397                }
398            }
399        }
400    }
401
402    /// Create a new tensor storage using the given file descriptor, shape,
403    /// and optional name.
404    #[cfg(unix)]
405    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
406        #[cfg(target_os = "linux")]
407        {
408            use nix::sys::stat::fstat;
409
410            let stat = fstat(&fd)?;
411            let major = major(stat.st_dev);
412            let minor = minor(stat.st_dev);
413
414            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
415
416            if major != 0 {
417                // Dma and Shm tensors are expected to have major number 0
418                return Err(Error::UnknownDeviceType(major, minor));
419            }
420
421            match minor {
422                9 | 10 => {
423                    // minor number 9 & 10 indicates DMA memory
424                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
425                }
426                _ => {
427                    // other minor numbers are assumed to be shared memory
428                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
429                }
430            }
431        }
432        #[cfg(all(unix, not(target_os = "linux")))]
433        {
434            // On macOS/BSD, always use SHM (no DMA support)
435            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
436        }
437    }
438}
439
440impl<T> TensorTrait<T> for TensorStorage<T>
441where
442    T: Num + Clone + fmt::Debug + Send + Sync,
443{
444    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
445        Self::new(shape, None, name)
446    }
447
448    #[cfg(unix)]
449    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
450        Self::from_fd(fd, shape, name)
451    }
452
453    #[cfg(unix)]
454    fn clone_fd(&self) -> Result<OwnedFd> {
455        match self {
456            #[cfg(target_os = "linux")]
457            TensorStorage::Dma(t) => t.clone_fd(),
458            TensorStorage::Shm(t) => t.clone_fd(),
459            TensorStorage::Mem(t) => t.clone_fd(),
460            TensorStorage::Pbo(t) => t.clone_fd(),
461        }
462    }
463
464    fn memory(&self) -> TensorMemory {
465        match self {
466            #[cfg(target_os = "linux")]
467            TensorStorage::Dma(_) => TensorMemory::Dma,
468            #[cfg(unix)]
469            TensorStorage::Shm(_) => TensorMemory::Shm,
470            TensorStorage::Mem(_) => TensorMemory::Mem,
471            TensorStorage::Pbo(_) => TensorMemory::Pbo,
472        }
473    }
474
475    fn name(&self) -> String {
476        match self {
477            #[cfg(target_os = "linux")]
478            TensorStorage::Dma(t) => t.name(),
479            #[cfg(unix)]
480            TensorStorage::Shm(t) => t.name(),
481            TensorStorage::Mem(t) => t.name(),
482            TensorStorage::Pbo(t) => t.name(),
483        }
484    }
485
486    fn shape(&self) -> &[usize] {
487        match self {
488            #[cfg(target_os = "linux")]
489            TensorStorage::Dma(t) => t.shape(),
490            #[cfg(unix)]
491            TensorStorage::Shm(t) => t.shape(),
492            TensorStorage::Mem(t) => t.shape(),
493            TensorStorage::Pbo(t) => t.shape(),
494        }
495    }
496
497    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
498        match self {
499            #[cfg(target_os = "linux")]
500            TensorStorage::Dma(t) => t.reshape(shape),
501            #[cfg(unix)]
502            TensorStorage::Shm(t) => t.reshape(shape),
503            TensorStorage::Mem(t) => t.reshape(shape),
504            TensorStorage::Pbo(t) => t.reshape(shape),
505        }
506    }
507
508    fn map(&self) -> Result<TensorMap<T>> {
509        match self {
510            #[cfg(target_os = "linux")]
511            TensorStorage::Dma(t) => t.map(),
512            #[cfg(unix)]
513            TensorStorage::Shm(t) => t.map(),
514            TensorStorage::Mem(t) => t.map(),
515            TensorStorage::Pbo(t) => t.map(),
516        }
517    }
518
519    fn buffer_identity(&self) -> &BufferIdentity {
520        match self {
521            #[cfg(target_os = "linux")]
522            TensorStorage::Dma(t) => t.buffer_identity(),
523            #[cfg(unix)]
524            TensorStorage::Shm(t) => t.buffer_identity(),
525            TensorStorage::Mem(t) => t.buffer_identity(),
526            TensorStorage::Pbo(t) => t.buffer_identity(),
527        }
528    }
529}
530
531/// Multi-backend tensor with optional image format metadata.
532///
533/// When `format` is `Some`, this tensor represents an image. Width, height,
534/// and channels are derived from `shape` + `format`. When `format` is `None`,
535/// this is a raw tensor (identical to the pre-refactoring behavior).
536#[derive(Debug)]
537pub struct Tensor<T>
538where
539    T: Num + Clone + fmt::Debug + Send + Sync,
540{
541    pub(crate) storage: TensorStorage<T>,
542    format: Option<PixelFormat>,
543    chroma: Option<Box<Tensor<T>>>,
544}
545
546impl<T> Tensor<T>
547where
548    T: Num + Clone + fmt::Debug + Send + Sync,
549{
550    /// Wrap a TensorStorage in a Tensor with no image metadata.
551    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
552        Self {
553            storage,
554            format: None,
555            chroma: None,
556        }
557    }
558
559    /// Create a new tensor with the given shape, memory type, and optional
560    /// name. If no name is given, a random name will be generated. If no
561    /// memory type is given, the best available memory type will be chosen
562    /// based on the platform and environment variables.
563    ///
564    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
565    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
566    /// On non-Unix platforms, only Mem is available.
567    ///
568    /// # Environment Variables
569    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
570    ///   value, forces the use of regular system memory allocation
571    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
572    ///
573    /// # Example
574    /// ```rust
575    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
576    /// # fn main() -> Result<(), Error> {
577    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
578    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
579    /// assert_eq!(tensor.name(), "test_tensor");
580    /// #    Ok(())
581    /// # }
582    /// ```
583    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
584        TensorStorage::new(shape, memory, name).map(Self::wrap)
585    }
586
587    /// Create an image tensor with the given format.
588    pub fn image(
589        width: usize,
590        height: usize,
591        format: PixelFormat,
592        memory: Option<TensorMemory>,
593    ) -> Result<Self> {
594        let shape = match format.layout() {
595            PixelLayout::Packed => vec![height, width, format.channels()],
596            PixelLayout::Planar => vec![format.channels(), height, width],
597            PixelLayout::SemiPlanar => {
598                // Contiguous semi-planar: luma + interleaved chroma in one allocation.
599                // NV12 (4:2:0): H lines luma + H/2 lines chroma = H * 3/2 total
600                // NV16 (4:2:2): H lines luma + H lines chroma = H * 2 total
601                let total_h = match format {
602                    PixelFormat::Nv12 => {
603                        if !height.is_multiple_of(2) {
604                            return Err(Error::InvalidArgument(format!(
605                                "NV12 requires even height, got {height}"
606                            )));
607                        }
608                        height * 3 / 2
609                    }
610                    PixelFormat::Nv16 => height * 2,
611                    _ => {
612                        return Err(Error::InvalidArgument(format!(
613                            "unknown semi-planar height multiplier for {format:?}"
614                        )))
615                    }
616                };
617                vec![total_h, width]
618            }
619        };
620        let mut t = Self::new(&shape, memory, None)?;
621        t.format = Some(format);
622        Ok(t)
623    }
624
625    /// Attach format metadata to an existing tensor.
626    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
627        let shape = self.shape();
628        match format.layout() {
629            PixelLayout::Packed => {
630                if shape.len() != 3 || shape[2] != format.channels() {
631                    return Err(Error::InvalidShape(format!(
632                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
633                        format.channels()
634                    )));
635                }
636            }
637            PixelLayout::Planar => {
638                if shape.len() != 3 || shape[0] != format.channels() {
639                    return Err(Error::InvalidShape(format!(
640                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
641                        format.channels()
642                    )));
643                }
644            }
645            PixelLayout::SemiPlanar => {
646                if shape.len() != 2 {
647                    return Err(Error::InvalidShape(format!(
648                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
649                    )));
650                }
651                match format {
652                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
653                        return Err(Error::InvalidShape(format!(
654                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
655                            shape[0]
656                        )));
657                    }
658                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
659                        return Err(Error::InvalidShape(format!(
660                            "NV16 contiguous shape[0] must be even, got {}",
661                            shape[0]
662                        )));
663                    }
664                    _ => {}
665                }
666            }
667        }
668        self.format = Some(format);
669        Ok(())
670    }
671
672    /// Pixel format (None if not an image).
673    pub fn format(&self) -> Option<PixelFormat> {
674        self.format
675    }
676
677    /// Image width (None if not an image).
678    pub fn width(&self) -> Option<usize> {
679        let fmt = self.format?;
680        let shape = self.shape();
681        match fmt.layout() {
682            PixelLayout::Packed => Some(shape[1]),
683            PixelLayout::Planar => Some(shape[2]),
684            PixelLayout::SemiPlanar => Some(shape[1]),
685        }
686    }
687
688    /// Image height (None if not an image).
689    pub fn height(&self) -> Option<usize> {
690        let fmt = self.format?;
691        let shape = self.shape();
692        match fmt.layout() {
693            PixelLayout::Packed => Some(shape[0]),
694            PixelLayout::Planar => Some(shape[1]),
695            PixelLayout::SemiPlanar => {
696                if self.is_multiplane() {
697                    Some(shape[0])
698                } else {
699                    match fmt {
700                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
701                        PixelFormat::Nv16 => Some(shape[0] / 2),
702                        _ => None,
703                    }
704                }
705            }
706        }
707    }
708
709    /// Create from separate Y and UV planes (multiplane NV12/NV16).
710    pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
711        if format.layout() != PixelLayout::SemiPlanar {
712            return Err(Error::InvalidArgument(format!(
713                "from_planes requires a semi-planar format, got {format:?}"
714            )));
715        }
716        if chroma.format.is_some() || chroma.chroma.is_some() {
717            return Err(Error::InvalidArgument(
718                "chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
719            ));
720        }
721        let luma_shape = luma.shape();
722        let chroma_shape = chroma.shape();
723        if luma_shape.len() != 2 || chroma_shape.len() != 2 {
724            return Err(Error::InvalidArgument(format!(
725                "from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
726            )));
727        }
728        if luma_shape[1] != chroma_shape[1] {
729            return Err(Error::InvalidArgument(format!(
730                "luma width {} != chroma width {}",
731                luma_shape[1], chroma_shape[1]
732            )));
733        }
734        match format {
735            PixelFormat::Nv12 => {
736                if luma_shape[0] % 2 != 0 {
737                    return Err(Error::InvalidArgument(format!(
738                        "NV12 requires even luma height, got {}",
739                        luma_shape[0]
740                    )));
741                }
742                if chroma_shape[0] != luma_shape[0] / 2 {
743                    return Err(Error::InvalidArgument(format!(
744                        "NV12 chroma height {} != luma height / 2 ({})",
745                        chroma_shape[0],
746                        luma_shape[0] / 2
747                    )));
748                }
749            }
750            PixelFormat::Nv16 => {
751                if chroma_shape[0] != luma_shape[0] {
752                    return Err(Error::InvalidArgument(format!(
753                        "NV16 chroma height {} != luma height {}",
754                        chroma_shape[0], luma_shape[0]
755                    )));
756                }
757            }
758            _ => {
759                return Err(Error::InvalidArgument(format!(
760                    "from_planes only supports NV12 and NV16, got {format:?}"
761                )));
762            }
763        }
764
765        Ok(Tensor {
766            storage: luma.storage,
767            format: Some(format),
768            chroma: Some(Box::new(chroma)),
769        })
770    }
771
772    /// Whether this tensor uses separate plane allocations.
773    pub fn is_multiplane(&self) -> bool {
774        self.chroma.is_some()
775    }
776
777    /// Access the chroma plane for multiplane semi-planar images.
778    pub fn chroma(&self) -> Option<&Tensor<T>> {
779        self.chroma.as_deref()
780    }
781
782    /// Downcast to PBO tensor reference (for GL backends).
783    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
784        match &self.storage {
785            TensorStorage::Pbo(p) => Some(p),
786            _ => None,
787        }
788    }
789
790    /// Downcast to DMA tensor reference (for EGL import, G2D).
791    #[cfg(target_os = "linux")]
792    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
793        match &self.storage {
794            TensorStorage::Dma(d) => Some(d),
795            _ => None,
796        }
797    }
798
799    /// Construct a Tensor from a PBO tensor (for GL backends that allocate PBOs).
800    pub fn from_pbo(pbo: PboTensor<T>) -> Self {
801        Self {
802            storage: TensorStorage::Pbo(pbo),
803            format: None,
804            chroma: None,
805        }
806    }
807}
808
809impl<T> TensorTrait<T> for Tensor<T>
810where
811    T: Num + Clone + fmt::Debug + Send + Sync,
812{
813    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
814    where
815        Self: Sized,
816    {
817        Self::new(shape, None, name)
818    }
819
820    #[cfg(unix)]
821    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
822    where
823        Self: Sized,
824    {
825        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
826    }
827
828    #[cfg(unix)]
829    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
830        self.storage.clone_fd()
831    }
832
833    fn memory(&self) -> TensorMemory {
834        self.storage.memory()
835    }
836
837    fn name(&self) -> String {
838        self.storage.name()
839    }
840
841    fn shape(&self) -> &[usize] {
842        self.storage.shape()
843    }
844
845    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
846        if self.chroma.is_some() {
847            return Err(Error::InvalidOperation(
848                "cannot reshape a multiplane tensor — decompose planes first".into(),
849            ));
850        }
851        self.storage.reshape(shape)?;
852        self.format = None;
853        Ok(())
854    }
855
856    fn map(&self) -> Result<TensorMap<T>> {
857        self.storage.map()
858    }
859
860    fn buffer_identity(&self) -> &BufferIdentity {
861        self.storage.buffer_identity()
862    }
863}
864
865pub enum TensorMap<T>
866where
867    T: Num + Clone + fmt::Debug,
868{
869    #[cfg(target_os = "linux")]
870    Dma(DmaMap<T>),
871    #[cfg(unix)]
872    Shm(ShmMap<T>),
873    Mem(MemMap<T>),
874    Pbo(PboMap<T>),
875}
876
877impl<T> TensorMapTrait<T> for TensorMap<T>
878where
879    T: Num + Clone + fmt::Debug,
880{
881    fn shape(&self) -> &[usize] {
882        match self {
883            #[cfg(target_os = "linux")]
884            TensorMap::Dma(map) => map.shape(),
885            #[cfg(unix)]
886            TensorMap::Shm(map) => map.shape(),
887            TensorMap::Mem(map) => map.shape(),
888            TensorMap::Pbo(map) => map.shape(),
889        }
890    }
891
892    fn unmap(&mut self) {
893        match self {
894            #[cfg(target_os = "linux")]
895            TensorMap::Dma(map) => map.unmap(),
896            #[cfg(unix)]
897            TensorMap::Shm(map) => map.unmap(),
898            TensorMap::Mem(map) => map.unmap(),
899            TensorMap::Pbo(map) => map.unmap(),
900        }
901    }
902
903    fn as_slice(&self) -> &[T] {
904        match self {
905            #[cfg(target_os = "linux")]
906            TensorMap::Dma(map) => map.as_slice(),
907            #[cfg(unix)]
908            TensorMap::Shm(map) => map.as_slice(),
909            TensorMap::Mem(map) => map.as_slice(),
910            TensorMap::Pbo(map) => map.as_slice(),
911        }
912    }
913
914    fn as_mut_slice(&mut self) -> &mut [T] {
915        match self {
916            #[cfg(target_os = "linux")]
917            TensorMap::Dma(map) => map.as_mut_slice(),
918            #[cfg(unix)]
919            TensorMap::Shm(map) => map.as_mut_slice(),
920            TensorMap::Mem(map) => map.as_mut_slice(),
921            TensorMap::Pbo(map) => map.as_mut_slice(),
922        }
923    }
924}
925
926impl<T> Deref for TensorMap<T>
927where
928    T: Num + Clone + fmt::Debug,
929{
930    type Target = [T];
931
932    fn deref(&self) -> &[T] {
933        match self {
934            #[cfg(target_os = "linux")]
935            TensorMap::Dma(map) => map.deref(),
936            #[cfg(unix)]
937            TensorMap::Shm(map) => map.deref(),
938            TensorMap::Mem(map) => map.deref(),
939            TensorMap::Pbo(map) => map.deref(),
940        }
941    }
942}
943
944impl<T> DerefMut for TensorMap<T>
945where
946    T: Num + Clone + fmt::Debug,
947{
948    fn deref_mut(&mut self) -> &mut [T] {
949        match self {
950            #[cfg(target_os = "linux")]
951            TensorMap::Dma(map) => map.deref_mut(),
952            #[cfg(unix)]
953            TensorMap::Shm(map) => map.deref_mut(),
954            TensorMap::Mem(map) => map.deref_mut(),
955            TensorMap::Pbo(map) => map.deref_mut(),
956        }
957    }
958}
959
960// ============================================================================
961// Platform availability helpers
962// ============================================================================
963
964/// Check if DMA memory allocation is available on this system.
965///
966/// Returns `true` only on Linux systems with DMA-BUF heap access (typically
967/// requires running as root or membership in a video/render group).
968/// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
969///
970/// This function caches its result after the first call for efficiency.
971#[cfg(target_os = "linux")]
972static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
973
974/// Check if DMA memory allocation is available on this system.
975#[cfg(target_os = "linux")]
976pub fn is_dma_available() -> bool {
977    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
978}
979
980/// Check if DMA memory allocation is available on this system.
981///
982/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
983#[cfg(not(target_os = "linux"))]
984pub fn is_dma_available() -> bool {
985    false
986}
987
988/// Check if POSIX shared memory allocation is available on this system.
989///
990/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
991/// is supported. Always returns `false` on non-Unix platforms (Windows).
992///
993/// This function caches its result after the first call for efficiency.
994#[cfg(unix)]
995static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
996
997/// Check if POSIX shared memory allocation is available on this system.
998#[cfg(unix)]
999pub fn is_shm_available() -> bool {
1000    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
1001}
1002
1003/// Check if POSIX shared memory allocation is available on this system.
1004///
1005/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
1006#[cfg(not(unix))]
1007pub fn is_shm_available() -> bool {
1008    false
1009}
1010
#[cfg(test)]
mod dtype_tests {
    use super::*;

    /// Every dtype reports its element size in bytes.
    #[test]
    fn dtype_size() {
        let cases = [
            (DType::U8, 1),
            (DType::I8, 1),
            (DType::U16, 2),
            (DType::I16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
            (DType::U64, 8),
            (DType::I64, 8),
            (DType::F16, 2),
            (DType::F32, 4),
            (DType::F64, 8),
        ];
        for (dtype, expected) in cases {
            assert_eq!(dtype.size(), expected, "size of {dtype:?}");
        }
    }

    #[test]
    fn dtype_name() {
        assert_eq!(DType::U8.name(), "u8");
        assert_eq!(DType::F16.name(), "f16");
        assert_eq!(DType::F32.name(), "f32");
    }

    /// A dtype survives a JSON serialize/deserialize round trip unchanged.
    #[test]
    fn dtype_serde_roundtrip() {
        // `serde_json` is reachable through the extern prelude; no `use` needed.
        let dt = DType::F16;
        let json = serde_json::to_string(&dt).unwrap();
        let back: DType = serde_json::from_str(&json).unwrap();
        assert_eq!(dt, back);
    }
}
1046
#[cfg(test)]
mod image_tests {
    use super::*;

    /// Holds the crate-wide `tests::FD_LOCK` shared for the rest of the test.
    ///
    /// Tensors created with the default memory type may be DMA- or SHM-backed
    /// and therefore open file descriptors. Without this lock, those fds can
    /// be counted by the concurrently running fd-leak tests and make them
    /// flaky. Poisoning is ignored because the lock guards no data.
    fn fd_guard() -> std::sync::RwLockReadGuard<'static, ()> {
        super::tests::FD_LOCK
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn raw_tensor_has_no_format() {
        let _lock = fd_guard();
        let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        assert!(t.width().is_none());
        assert!(t.height().is_none());
        assert!(!t.is_multiplane());
        assert!(t.chroma().is_none());
    }

    #[test]
    fn image_tensor_packed() {
        let _lock = fd_guard();
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[480, 640, 4]);
        assert!(!t.is_multiplane());
    }

    #[test]
    fn image_tensor_planar() {
        let _lock = fd_guard();
        let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        assert_eq!(t.shape(), &[3, 480, 640]);
    }

    #[test]
    fn image_tensor_semi_planar_contiguous() {
        let _lock = fd_guard();
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
        // NV12: H*3/2 = 720
        assert_eq!(t.shape(), &[720, 640]);
        assert!(!t.is_multiplane());
    }

    #[test]
    fn set_format_valid() {
        let _lock = fd_guard();
        let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
        assert!(t.format().is_none());
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_invalid_shape() {
        let _lock = fd_guard();
        let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
        // RGB expects 3 channels, not 4
        let err = t.set_format(PixelFormat::Rgb);
        assert!(err.is_err());
        // Original tensor is unmodified
        assert!(t.format().is_none());
    }

    #[test]
    fn reshape_clears_format() {
        let _lock = fd_guard();
        let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        // Reshape to flat — format cleared
        t.reshape(&[480 * 640 * 4]).unwrap();
        assert!(t.format().is_none());
    }

    #[test]
    fn from_planes_nv12() {
        let _lock = fd_guard();
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        assert_eq!(img.format(), Some(PixelFormat::Nv12));
        assert!(img.is_multiplane());
        assert!(img.chroma().is_some());
        assert_eq!(img.width(), Some(640));
        assert_eq!(img.height(), Some(480));
    }

    #[test]
    fn from_planes_rejects_non_semiplanar() {
        let _lock = fd_guard();
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
        assert!(err.is_err());
    }

    #[test]
    fn reshape_multiplane_errors() {
        let _lock = fd_guard();
        let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
        let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
        let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
        let err = img.reshape(&[480 * 640 + 240 * 640]);
        assert!(err.is_err());
    }
}
1149
#[cfg(test)]
mod tests {
    #[cfg(target_os = "linux")]
    use nix::unistd::{access, AccessFlags};
    #[cfg(target_os = "linux")]
    use std::io::Write as _;
    use std::sync::RwLock;

    use super::*;

    /// Serialises tests against the file-descriptor leak checks.
    ///
    /// Any test that measures the process's open fd count must hold this lock
    /// exclusively. Any test that changes the fd count by opening or closing
    /// fds (for example by allocating DMA or SHM tensors) must hold it shared.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());

    /// Takes `FD_LOCK` shared.
    ///
    /// Poisoning is ignored: the lock guards no data, and a failed leak test
    /// (which panics while holding the write lock) must not cascade into
    /// spurious failures of every other test.
    #[cfg(any(unix, feature = "ndarray"))]
    fn fd_shared() -> std::sync::RwLockReadGuard<'static, ()> {
        FD_LOCK
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes `FD_LOCK` exclusively, ignoring poisoning for the same reason as
    /// `fd_shared`.
    #[cfg(target_os = "linux")]
    fn fd_exclusive() -> std::sync::RwLockWriteGuard<'static, ()> {
        FD_LOCK
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }

    /// Macro to get the current function name for logging in tests.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // Find and cut the rest of the path
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }

    /// Runs `body` and asserts that the process's open fd count is the same
    /// before and after it.
    ///
    /// The caller must hold `fd_exclusive()` so that concurrently running
    /// tests cannot perturb the count.
    #[cfg(target_os = "linux")]
    fn assert_no_fd_leaks(body: impl FnOnce()) {
        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        body();

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_tensor() {
        let _lock = fd_shared();
        let shape = vec![1];
        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
        let dma_enabled = tensor.is_ok();

        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        match dma_enabled {
            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
        }
    }

    #[test]
    #[cfg(all(unix, not(target_os = "linux")))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
        assert!(
            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
            "Expected SHM or Mem on macOS, got {:?}",
            tensor.memory()
        );
    }

    #[test]
    #[cfg(not(unix))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert_eq!(tensor.memory(), TensorMemory::Mem);
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = fd_shared();
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = fd_shared();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }

    #[test]
    fn test_mem_tensor() {
        let shape = vec![2, 3, 4];
        let tensor =
            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        let _lock = fd_exclusive();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        assert_no_fd_leaks(|| {
            for _ in 0..100 {
                let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                    .expect("Failed to create tensor");
                let mut map = tensor.map().unwrap();
                map.as_mut_slice().fill(233);
            }
        });
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        let _lock = fd_exclusive();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        assert_no_fd_leaks(|| {
            let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

            for _ in 0..100 {
                let tensor =
                    Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
                let mut map = tensor.map().unwrap();
                map.as_mut_slice().fill(233);
            }
            // The source tensor must be released before the final fd count.
            drop(orig);
        });
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_no_fd_leaks() {
        let _lock = fd_exclusive();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        assert_no_fd_leaks(|| {
            for _ in 0..100 {
                let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
                    .expect("Failed to create tensor");
                let mut map = tensor.map().unwrap();
                map.as_mut_slice().fill(233);
            }
        });
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_from_fd_no_fd_leaks() {
        let _lock = fd_exclusive();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        assert_no_fd_leaks(|| {
            let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

            for _ in 0..100 {
                let tensor =
                    Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
                let mut map = tensor.map().unwrap();
                map.as_mut_slice().fill(233);
            }
            // The source tensor must be released before the final fd count.
            drop(orig);
        });
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray() {
        let _lock = fd_shared();
        let shape = vec![2, 3, 4];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
        tensor_map.fill(1.0);

        let view = tensor_map.view().expect("Failed to get ndarray view");
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert!(view.iter().all(|&x| x == 1.0));

        let mut view_mut = tensor_map
            .view_mut()
            .expect("Failed to get mutable ndarray view");
        view_mut[[0, 0, 0]] = 42.0;
        assert_eq!(view_mut[[0, 0, 0]], 42.0);
        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
    }

    #[test]
    fn test_buffer_identity_unique() {
        let id1 = BufferIdentity::new();
        let id2 = BufferIdentity::new();
        assert_ne!(
            id1.id(),
            id2.id(),
            "Two identities should have different ids"
        );
    }

    #[test]
    fn test_buffer_identity_clone_shares_guard() {
        let id1 = BufferIdentity::new();
        let weak = id1.weak();
        assert!(
            weak.upgrade().is_some(),
            "Weak should be alive while original exists"
        );

        let id2 = id1.clone();
        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");

        drop(id1);
        assert!(
            weak.upgrade().is_some(),
            "Weak should still be alive (clone holds Arc)"
        );

        drop(id2);
        assert!(
            weak.upgrade().is_none(),
            "Weak should be dead after all clones dropped"
        );
    }

    #[test]
    fn test_tensor_buffer_identity() {
        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
        assert_ne!(
            t1.buffer_identity().id(),
            t2.buffer_identity().id(),
            "Different tensors should have different buffer ids"
        );
    }

    /// Test that DMA is NOT available on non-Linux platforms.
    /// This verifies the cross-platform behavior of is_dma_available().
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_dma_not_available_on_non_linux() {
        assert!(
            !is_dma_available(),
            "DMA memory allocation should NOT be available on non-Linux platforms"
        );
    }

    /// Test that SHM memory allocation is available and usable on Unix systems.
    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
    #[test]
    #[cfg(unix)]
    fn test_shm_available_and_usable() {
        // Allocating an SHM tensor opens an fd, so hold the fd lock shared.
        let _lock = fd_shared();
        assert!(
            is_shm_available(),
            "SHM memory allocation should be available on Unix systems"
        );

        // Create a tensor with SHM backing
        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
            .expect("Failed to create SHM tensor");

        // Verify we can map and write to it
        let mut map = tensor.map().expect("Failed to map SHM tensor");
        map.as_mut_slice().fill(0xAB);

        // Verify the data was written correctly
        assert!(
            map.as_slice().iter().all(|&b| b == 0xAB),
            "SHM tensor data should be writable and readable"
        );
    }
}