Skip to main content

edgefirst_tensor/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5EdgeFirst HAL - Tensor Module
6
7The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
8with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
9and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.
10
11## Examples
12```rust
13use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
14# fn main() -> Result<(), Error> {
15let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
16assert_eq!(tensor.memory(), TensorMemory::Mem);
17assert_eq!(tensor.name(), "test_tensor");
18#    Ok(())
19# }
20```
## Overview
The main structures and traits provided by the `edgefirst_tensor` crate are the `TensorTrait` and `TensorMapTrait` traits,
which define the behavior of Tensors and their memory mappings, respectively.
The `Tensor` enum encapsulates different tensor implementations based on the memory type, while the `TensorMap` enum
provides access to the underlying data.
 */
28#[cfg(target_os = "linux")]
29mod dma;
30#[cfg(target_os = "linux")]
31mod dmabuf;
32mod error;
33mod mem;
34mod pbo;
35#[cfg(unix)]
36mod shm;
37
38#[cfg(target_os = "linux")]
39pub use crate::dma::{DmaMap, DmaTensor};
40pub use crate::mem::{MemMap, MemTensor};
41pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
42#[cfg(unix)]
43pub use crate::shm::{ShmMap, ShmTensor};
44pub use error::{Error, Result};
45use num_traits::Num;
46#[cfg(unix)]
47use std::os::fd::OwnedFd;
48use std::{
49    fmt,
50    ops::{Deref, DerefMut},
51    sync::{
52        atomic::{AtomicU64, Ordering},
53        Arc, Weak,
54    },
55};
56
/// Monotonic counter for buffer identity IDs.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Unique identity for a tensor's underlying buffer.
///
/// Created fresh on every buffer allocation or import. The `id` is a monotonic
/// u64 used as a cache key. The `guard` is an `Arc<()>` whose weak references
/// allow downstream caches to detect when the buffer has been dropped.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    // Monotonically increasing cache key taken from NEXT_BUFFER_ID.
    id: u64,
    // Liveness anchor: weak handles to this Arc stop upgrading once every
    // clone of the identity (and thus the owning tensor) is dropped.
    guard: Arc<()>,
}
70
71impl BufferIdentity {
72    /// Create a new unique buffer identity.
73    pub fn new() -> Self {
74        Self {
75            id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
76            guard: Arc::new(()),
77        }
78    }
79
80    /// Unique identifier for this buffer. Changes when the buffer changes.
81    pub fn id(&self) -> u64 {
82        self.id
83    }
84
85    /// Returns a weak reference to the buffer guard. Goes dead when the
86    /// owning Tensor is dropped (and no clones remain).
87    pub fn weak(&self) -> Weak<()> {
88        Arc::downgrade(&self.guard)
89    }
90}
91
impl Default for BufferIdentity {
    // The default is a brand-new unique identity, not a shared "zero" value.
    fn default() -> Self {
        Self::new()
    }
}
97
98#[cfg(target_os = "linux")]
99use nix::sys::stat::{major, minor};
100
101pub trait TensorTrait<T>: Send + Sync
102where
103    T: Num + Clone + fmt::Debug,
104{
105    /// Create a new tensor with the given shape and optional name. If no name
106    /// is given, a random name will be generated.
107    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
108    where
109        Self: Sized;
110
111    #[cfg(unix)]
112    /// Create a new tensor using the given file descriptor, shape, and optional
113    /// name. If no name is given, a random name will be generated.
114    ///
115    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
116    /// On other Unix (macOS): Always creates SHM tensor.
117    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
118    where
119        Self: Sized;
120
121    #[cfg(unix)]
122    /// Clone the file descriptor associated with this tensor.
123    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
124
125    /// Get the memory type of this tensor.
126    fn memory(&self) -> TensorMemory;
127
128    /// Get the name of this tensor.
129    fn name(&self) -> String;
130
131    /// Get the number of elements in this tensor.
132    fn len(&self) -> usize {
133        self.shape().iter().product()
134    }
135
136    /// Check if the tensor is empty.
137    fn is_empty(&self) -> bool {
138        self.len() == 0
139    }
140
141    /// Get the size in bytes of this tensor.
142    fn size(&self) -> usize {
143        self.len() * std::mem::size_of::<T>()
144    }
145
146    /// Get the shape of this tensor.
147    fn shape(&self) -> &[usize];
148
149    /// Reshape this tensor to the given shape. The total number of elements
150    /// must remain the same.
151    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
152
153    /// Map the tensor into memory and return a TensorMap for accessing the
154    /// data.
155    fn map(&self) -> Result<TensorMap<T>>;
156
157    /// Get the buffer identity for cache keying and liveness tracking.
158    fn buffer_identity(&self) -> &BufferIdentity;
159}
160
161pub trait TensorMapTrait<T>
162where
163    T: Num + Clone + fmt::Debug,
164{
165    /// Get the shape of this tensor map.
166    fn shape(&self) -> &[usize];
167
168    /// Unmap the tensor from memory.
169    fn unmap(&mut self);
170
171    /// Get the number of elements in this tensor map.
172    fn len(&self) -> usize {
173        self.shape().iter().product()
174    }
175
176    /// Check if the tensor map is empty.
177    fn is_empty(&self) -> bool {
178        self.len() == 0
179    }
180
181    /// Get the size in bytes of this tensor map.
182    fn size(&self) -> usize {
183        self.len() * std::mem::size_of::<T>()
184    }
185
186    /// Get a slice to the data in this tensor map.
187    fn as_slice(&self) -> &[T];
188
189    /// Get a mutable slice to the data in this tensor map.
190    fn as_mut_slice(&mut self) -> &mut [T];
191
192    #[cfg(feature = "ndarray")]
193    /// Get an ndarray ArrayView of the tensor data.
194    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
195        Ok(ndarray::ArrayView::from_shape(
196            self.shape(),
197            self.as_slice(),
198        )?)
199    }
200
201    #[cfg(feature = "ndarray")]
202    /// Get an ndarray ArrayViewMut of the tensor data.
203    fn view_mut(
204        &'_ mut self,
205    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
206        let shape = self.shape().to_vec();
207        Ok(ndarray::ArrayViewMut::from_shape(
208            shape,
209            self.as_mut_slice(),
210        )?)
211    }
212}
213
/// The kind of memory backing a tensor's storage. Which variants exist
/// depends on the target platform (`Dma` is Linux-only, `Shm` is Unix-only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    /// Direct Memory Access (DMA) allocation. Incurs additional
    /// overhead for memory reading/writing with the CPU.  Allows for
    /// hardware acceleration when supported.
    Dma,
    #[cfg(unix)]
    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    Shm,

    /// Regular system memory allocation
    Mem,

    /// OpenGL Pixel Buffer Object memory. Created by ImageProcessor
    /// when DMA-buf is unavailable but OpenGL is present.
    Pbo,
}
233
234impl From<TensorMemory> for String {
235    fn from(memory: TensorMemory) -> Self {
236        match memory {
237            #[cfg(target_os = "linux")]
238            TensorMemory::Dma => "dma".to_owned(),
239            #[cfg(unix)]
240            TensorMemory::Shm => "shm".to_owned(),
241            TensorMemory::Mem => "mem".to_owned(),
242            TensorMemory::Pbo => "pbo".to_owned(),
243        }
244    }
245}
246
247impl TryFrom<&str> for TensorMemory {
248    type Error = Error;
249
250    fn try_from(s: &str) -> Result<Self> {
251        match s {
252            #[cfg(target_os = "linux")]
253            "dma" => Ok(TensorMemory::Dma),
254            #[cfg(unix)]
255            "shm" => Ok(TensorMemory::Shm),
256            "mem" => Ok(TensorMemory::Mem),
257            "pbo" => Ok(TensorMemory::Pbo),
258            _ => Err(Error::InvalidMemoryType(s.to_owned())),
259        }
260    }
261}
262
/// A multi-dimensional array backed by one of the supported memory types.
///
/// Each variant wraps the concrete implementation for that memory kind;
/// behavior is dispatched through this crate's `TensorTrait` impl. The
/// available variants depend on the platform (`Dma` Linux-only, `Shm`
/// Unix-only).
#[derive(Debug)]
pub enum Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}
275
276impl<T> Tensor<T>
277where
278    T: Num + Clone + fmt::Debug + Send + Sync,
279{
280    /// Create a new tensor with the given shape, memory type, and optional
281    /// name. If no name is given, a random name will be generated. If no
282    /// memory type is given, the best available memory type will be chosen
283    /// based on the platform and environment variables.
284    ///
285    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
286    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
287    /// On non-Unix platforms, only Mem is available.
288    ///
289    /// # Environment Variables
290    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
291    ///   value, forces the use of regular system memory allocation
292    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
293    ///
294    /// # Example
295    /// ```rust
296    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
297    /// # fn main() -> Result<(), Error> {
298    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
299    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
300    /// assert_eq!(tensor.name(), "test_tensor");
301    /// #    Ok(())
302    /// # }
303    /// ```
304    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
305        match memory {
306            #[cfg(target_os = "linux")]
307            Some(TensorMemory::Dma) => DmaTensor::<T>::new(shape, name).map(Tensor::Dma),
308            #[cfg(unix)]
309            Some(TensorMemory::Shm) => ShmTensor::<T>::new(shape, name).map(Tensor::Shm),
310            Some(TensorMemory::Mem) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
311            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
312                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
313            )),
314            None => {
315                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
316                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
317                {
318                    MemTensor::<T>::new(shape, name).map(Tensor::Mem)
319                } else {
320                    #[cfg(target_os = "linux")]
321                    {
322                        // Linux: Try DMA -> SHM -> Mem
323                        match DmaTensor::<T>::new(shape, name) {
324                            Ok(tensor) => Ok(Tensor::Dma(tensor)),
325                            Err(_) => match ShmTensor::<T>::new(shape, name).map(Tensor::Shm) {
326                                Ok(tensor) => Ok(tensor),
327                                Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
328                            },
329                        }
330                    }
331                    #[cfg(all(unix, not(target_os = "linux")))]
332                    {
333                        // macOS/BSD: Try SHM -> Mem (no DMA)
334                        match ShmTensor::<T>::new(shape, name) {
335                            Ok(tensor) => Ok(Tensor::Shm(tensor)),
336                            Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
337                        }
338                    }
339                    #[cfg(not(unix))]
340                    {
341                        // Windows/other: Mem only
342                        MemTensor::<T>::new(shape, name).map(Tensor::Mem)
343                    }
344                }
345            }
346        }
347    }
348}
349
// Dispatching impl: every method forwards to the concrete tensor behind the
// active enum variant. Variant arms are cfg-gated to mirror the enum itself.
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    // Trait-level `new` has no memory-type parameter; delegate to the
    // inherent constructor with automatic memory selection.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    #[cfg(unix)]
    /// Create a new tensor using the given file descriptor, shape, and optional
    /// name. If no name is given, a random name will be generated.
    ///
    /// On Linux: Inspects the file descriptor to determine the appropriate tensor type
    /// (Dma or Shm) based on the device major and minor numbers.
    /// On other Unix (macOS): Always creates SHM tensor.
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            if major != 0 {
                // Dma and Shm tensors are expected to have major number 0
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                9 | 10 => {
                    // minor number 9 & 10 indicates DMA memory
                    // NOTE(review): these minor numbers are hardcoded; verify
                    // they match the DMA-BUF devices on all supported kernels.
                    DmaTensor::<T>::from_fd(fd, shape, name).map(Tensor::Dma)
                }
                _ => {
                    // other minor numbers are assumed to be shared memory
                    ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            // On macOS/BSD, always use SHM (no DMA support)
            ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
        }
    }

    #[cfg(unix)]
    // Duplicate the underlying fd of whichever variant is active.
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.clone_fd(),
            Tensor::Shm(t) => t.clone_fd(),
            Tensor::Mem(t) => t.clone_fd(),
            Tensor::Pbo(t) => t.clone_fd(),
        }
    }

    // Report the memory kind from the enum variant alone (no delegation).
    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            Tensor::Shm(_) => TensorMemory::Shm,
            Tensor::Mem(_) => TensorMemory::Mem,
            Tensor::Pbo(_) => TensorMemory::Pbo,
        }
    }

    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.name(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.name(),
            Tensor::Mem(t) => t.name(),
            Tensor::Pbo(t) => t.name(),
        }
    }

    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.shape(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.shape(),
            Tensor::Mem(t) => t.shape(),
            Tensor::Pbo(t) => t.shape(),
        }
    }

    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            Tensor::Shm(t) => t.reshape(shape),
            Tensor::Mem(t) => t.reshape(shape),
            Tensor::Pbo(t) => t.reshape(shape),
        }
    }

    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.map(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.map(),
            Tensor::Mem(t) => t.map(),
            Tensor::Pbo(t) => t.map(),
        }
    }

    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.buffer_identity(),
            Tensor::Mem(t) => t.buffer_identity(),
            Tensor::Pbo(t) => t.buffer_identity(),
        }
    }
}
476
/// A mapped view of a tensor's data, obtained via `TensorTrait::map`.
///
/// Each variant wraps the mapping type of the corresponding memory backend
/// and provides slice access to the underlying elements.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}
488
// Dispatching impl: every method forwards to the concrete map behind the
// active enum variant.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }

    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }

    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }

    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
537
// Deref to a slice so callers can index/iterate a TensorMap directly
// (e.g. `map[2]`, `map.iter()`), forwarding to the active variant's Deref.
impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];

    fn deref(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref(),
            TensorMap::Mem(map) => map.deref(),
            TensorMap::Pbo(map) => map.deref(),
        }
    }
}
555
// Mutable counterpart of the Deref impl above: allows in-place writes such
// as `map[2] = x` or `map.fill(v)`.
impl<T> DerefMut for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn deref_mut(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.deref_mut(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.deref_mut(),
            TensorMap::Mem(map) => map.deref_mut(),
            TensorMap::Pbo(map) => map.deref_mut(),
        }
    }
}
571
572// ============================================================================
573// Platform availability helpers
574// ============================================================================
575
576/// Check if DMA memory allocation is available on this system.
577///
578/// Returns `true` only on Linux systems with DMA-BUF heap access (typically
579/// requires running as root or membership in a video/render group).
580/// Always returns `false` on non-Linux platforms (macOS, Windows, etc.).
581///
582/// This function caches its result after the first call for efficiency.
583#[cfg(target_os = "linux")]
584static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
585
586/// Check if DMA memory allocation is available on this system.
587#[cfg(target_os = "linux")]
588pub fn is_dma_available() -> bool {
589    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
590}
591
/// Check if DMA memory allocation is available on this system.
///
/// Always returns `false` on non-Linux platforms since DMA-BUF is Linux-specific.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    // Compile-time stub: no probe is needed off Linux.
    false
}
599
600/// Check if POSIX shared memory allocation is available on this system.
601///
602/// Returns `true` on Unix systems (Linux, macOS, BSD) where POSIX shared memory
603/// is supported. Always returns `false` on non-Unix platforms (Windows).
604///
605/// This function caches its result after the first call for efficiency.
606#[cfg(unix)]
607static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
608
609/// Check if POSIX shared memory allocation is available on this system.
610#[cfg(unix)]
611pub fn is_shm_available() -> bool {
612    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
613}
614
/// Check if POSIX shared memory allocation is available on this system.
///
/// Always returns `false` on non-Unix platforms since POSIX SHM is Unix-specific.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    // Compile-time stub: no probe is needed off Unix.
    false
}
622
623#[cfg(test)]
624mod tests {
625    #[cfg(target_os = "linux")]
626    use nix::unistd::{access, AccessFlags};
627    #[cfg(target_os = "linux")]
628    use std::io::Write as _;
629    use std::sync::RwLock;
630
631    use super::*;
632
    // Runs once before any test via the ctor crate: initialize logging for
    // the whole test binary, defaulting to "info" unless RUST_LOG overrides.
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }
637
    /// Macro to get the current function name for logging in tests.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            // `type_name` of the nested fn expands to something like
            // "crate::tests::test_fn::f"; strip the trailing "::f" and
            // everything up to the last path separator.
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // Find and cut the rest of the path
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
655
    #[test]
    #[cfg(target_os = "linux")]
    fn test_tensor() {
        // Shared read lock: the fd-leak tests take the write lock exclusively.
        // (FD_LOCK is declared elsewhere in this module.)
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![1];
        // Probe whether DMA allocation works in this environment.
        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
        let dma_enabled = tensor.is_ok();

        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        // NOTE(review): assumes SHM is always available as the fallback when
        // DMA is not; if SHM also failed, Tensor::new would fall back to Mem
        // and the second assertion would fail — confirm intent.
        match dma_enabled {
            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
        }
    }
670
    #[test]
    #[cfg(all(unix, not(target_os = "linux")))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
        assert!(
            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
            "Expected SHM or Mem on macOS, got {:?}",
            tensor.memory()
        );
    }
683
    #[test]
    #[cfg(not(unix))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        // Non-Unix platforms have exactly one backend: plain system memory.
        assert_eq!(tensor.memory(), TensorMemory::Mem);
    }
691
    // End-to-end exercise of DmaTensor: creation, metadata, mapping, fd
    // sharing, reshape validation, and element writes. Skips (with a warning)
    // when no DMA heap device is accessible.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Require read/write access to one of the known DMA heap devices;
        // otherwise skip the test rather than fail in restricted CI.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Write through a fresh mapping, then drop it.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Import the same buffer via a duplicated fd: writes through the
        // shared tensor must be visible through the original below.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor must observe the values written via the fd copy.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: rejecting a size-changing shape, accepting a same-size one.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Indexed writes via DerefMut persist across re-mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
789
    // End-to-end exercise of ShmTensor: creation, metadata, mapping, fd
    // sharing, reshape validation, and element writes. Mirrors
    // test_dma_tensor for the shared-memory backend.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Import the same buffer via a duplicated fd; writes through the
        // shared tensor must be visible through the original below.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: rejecting a size-changing shape, accepting a same-size one.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Indexed writes via DerefMut persist across re-mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
862
    // Exercise of MemTensor: creation, metadata, mapping, reshape validation,
    // and element writes. No fd-sharing section — plain memory tensors are
    // process-local.
    #[test]
    fn test_mem_tensor() {
        let shape = vec![2, 3, 4];
        let tensor =
            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Reshape: rejecting a size-changing shape, accepting a same-size one.
        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Indexed writes via DerefMut persist across re-mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
908
909    #[test]
910    #[cfg(target_os = "linux")]
911    fn test_dma_no_fd_leaks() {
912        let _lock = FD_LOCK.write().unwrap();
913        if !is_dma_available() {
914            log::warn!(
915                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
916                function!()
917            );
918            return;
919        }
920
921        let proc = procfs::process::Process::myself()
922            .expect("Failed to get current process using /proc/self");
923
924        let start_open_fds = proc
925            .fd_count()
926            .expect("Failed to get open file descriptor count");
927
928        for _ in 0..100 {
929            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
930                .expect("Failed to create tensor");
931            let mut map = tensor.map().unwrap();
932            map.as_mut_slice().fill(233);
933        }
934
935        let end_open_fds = proc
936            .fd_count()
937            .expect("Failed to get open file descriptor count");
938
939        assert_eq!(
940            start_open_fds, end_open_fds,
941            "File descriptor leak detected: {} -> {}",
942            start_open_fds, end_open_fds
943        );
944    }
945
946    #[test]
947    #[cfg(target_os = "linux")]
948    fn test_dma_from_fd_no_fd_leaks() {
949        let _lock = FD_LOCK.write().unwrap();
950        if !is_dma_available() {
951            log::warn!(
952                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
953                function!()
954            );
955            return;
956        }
957
958        let proc = procfs::process::Process::myself()
959            .expect("Failed to get current process using /proc/self");
960
961        let start_open_fds = proc
962            .fd_count()
963            .expect("Failed to get open file descriptor count");
964
965        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
966
967        for _ in 0..100 {
968            let tensor =
969                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
970            let mut map = tensor.map().unwrap();
971            map.as_mut_slice().fill(233);
972        }
973        drop(orig);
974
975        let end_open_fds = proc.fd_count().unwrap();
976
977        assert_eq!(
978            start_open_fds, end_open_fds,
979            "File descriptor leak detected: {} -> {}",
980            start_open_fds, end_open_fds
981        );
982    }
983
984    #[test]
985    #[cfg(target_os = "linux")]
986    fn test_shm_no_fd_leaks() {
987        let _lock = FD_LOCK.write().unwrap();
988        if !is_shm_available() {
989            log::warn!(
990                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
991                function!()
992            );
993            return;
994        }
995
996        let proc = procfs::process::Process::myself()
997            .expect("Failed to get current process using /proc/self");
998
999        let start_open_fds = proc
1000            .fd_count()
1001            .expect("Failed to get open file descriptor count");
1002
1003        for _ in 0..100 {
1004            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1005                .expect("Failed to create tensor");
1006            let mut map = tensor.map().unwrap();
1007            map.as_mut_slice().fill(233);
1008        }
1009
1010        let end_open_fds = proc
1011            .fd_count()
1012            .expect("Failed to get open file descriptor count");
1013
1014        assert_eq!(
1015            start_open_fds, end_open_fds,
1016            "File descriptor leak detected: {} -> {}",
1017            start_open_fds, end_open_fds
1018        );
1019    }
1020
1021    #[test]
1022    #[cfg(target_os = "linux")]
1023    fn test_shm_from_fd_no_fd_leaks() {
1024        let _lock = FD_LOCK.write().unwrap();
1025        if !is_shm_available() {
1026            log::warn!(
1027                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
1028                function!()
1029            );
1030            return;
1031        }
1032
1033        let proc = procfs::process::Process::myself()
1034            .expect("Failed to get current process using /proc/self");
1035
1036        let start_open_fds = proc
1037            .fd_count()
1038            .expect("Failed to get open file descriptor count");
1039
1040        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
1041
1042        for _ in 0..100 {
1043            let tensor =
1044                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
1045            let mut map = tensor.map().unwrap();
1046            map.as_mut_slice().fill(233);
1047        }
1048        drop(orig);
1049
1050        let end_open_fds = proc.fd_count().unwrap();
1051
1052        assert_eq!(
1053            start_open_fds, end_open_fds,
1054            "File descriptor leak detected: {} -> {}",
1055            start_open_fds, end_open_fds
1056        );
1057    }
1058
1059    #[cfg(feature = "ndarray")]
1060    #[test]
1061    fn test_ndarray() {
1062        let _lock = FD_LOCK.read().unwrap();
1063        let shape = vec![2, 3, 4];
1064        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
1065
1066        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
1067        tensor_map.fill(1.0);
1068
1069        let view = tensor_map.view().expect("Failed to get ndarray view");
1070        assert_eq!(view.shape(), &[2, 3, 4]);
1071        assert!(view.iter().all(|&x| x == 1.0));
1072
1073        let mut view_mut = tensor_map
1074            .view_mut()
1075            .expect("Failed to get mutable ndarray view");
1076        view_mut[[0, 0, 0]] = 42.0;
1077        assert_eq!(view_mut[[0, 0, 0]], 42.0);
1078        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
1079    }
1080
1081    #[test]
1082    fn test_buffer_identity_unique() {
1083        let id1 = BufferIdentity::new();
1084        let id2 = BufferIdentity::new();
1085        assert_ne!(
1086            id1.id(),
1087            id2.id(),
1088            "Two identities should have different ids"
1089        );
1090    }
1091
1092    #[test]
1093    fn test_buffer_identity_clone_shares_guard() {
1094        let id1 = BufferIdentity::new();
1095        let weak = id1.weak();
1096        assert!(
1097            weak.upgrade().is_some(),
1098            "Weak should be alive while original exists"
1099        );
1100
1101        let id2 = id1.clone();
1102        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
1103
1104        drop(id1);
1105        assert!(
1106            weak.upgrade().is_some(),
1107            "Weak should still be alive (clone holds Arc)"
1108        );
1109
1110        drop(id2);
1111        assert!(
1112            weak.upgrade().is_none(),
1113            "Weak should be dead after all clones dropped"
1114        );
1115    }
1116
1117    #[test]
1118    fn test_tensor_buffer_identity() {
1119        let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1120        let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1121        assert_ne!(
1122            t1.buffer_identity().id(),
1123            t2.buffer_identity().id(),
1124            "Different tensors should have different buffer ids"
1125        );
1126    }
1127
    // Synchronizes tests that observe the process-wide open-fd count with
    // tests that merely change it.
    //
    // Any test that asserts on the fd count must grab this lock exclusively
    // (write), so no other test can open or close descriptors while it
    // measures. Any test which modifies the fd count by opening or closing
    // fds, without asserting on it, must grab the lock shared (read); such
    // tests may still run concurrently with each other.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1132
1133    /// Test that DMA is NOT available on non-Linux platforms.
1134    /// This verifies the cross-platform behavior of is_dma_available().
1135    #[test]
1136    #[cfg(not(target_os = "linux"))]
1137    fn test_dma_not_available_on_non_linux() {
1138        assert!(
1139            !is_dma_available(),
1140            "DMA memory allocation should NOT be available on non-Linux platforms"
1141        );
1142    }
1143
1144    /// Test that SHM memory allocation is available and usable on Unix systems.
1145    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
1146    #[test]
1147    #[cfg(unix)]
1148    fn test_shm_available_and_usable() {
1149        assert!(
1150            is_shm_available(),
1151            "SHM memory allocation should be available on Unix systems"
1152        );
1153
1154        // Create a tensor with SHM backing
1155        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1156            .expect("Failed to create SHM tensor");
1157
1158        // Verify we can map and write to it
1159        let mut map = tensor.map().expect("Failed to map SHM tensor");
1160        map.as_mut_slice().fill(0xAB);
1161
1162        // Verify the data was written correctly
1163        assert!(
1164            map.as_slice().iter().all(|&b| b == 0xAB),
1165            "SHM tensor data should be writable and readable"
1166        );
1167    }
1168}