// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

/*!
EdgeFirst HAL - Tensor Module

The `edgefirst_tensor` crate provides a unified interface for managing multi-dimensional arrays (tensors)
with support for different memory types, including Direct Memory Access (DMA), POSIX Shared Memory (Shm),
and system memory. The crate defines traits and structures for creating, reshaping, and mapping tensors into memory.

## Examples
```rust
use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
# fn main() -> Result<(), Error> {
let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
assert_eq!(tensor.memory(), TensorMemory::Mem);
assert_eq!(tensor.name(), "test_tensor");
#    Ok(())
# }
```

## Overview
The main structures and traits provided by the `edgefirst_tensor` crate are the `TensorTrait` and `TensorMapTrait` traits,
which define the behavior of Tensors and their memory mappings, respectively.
The `Tensor` enum encapsulates different tensor implementations based on the memory type, while the `TensorMap` enum
provides access to the underlying data.
 */
28#[cfg(target_os = "linux")]
29mod dma;
30#[cfg(target_os = "linux")]
31mod dmabuf;
32mod error;
33mod mem;
34#[cfg(unix)]
35mod shm;
36
37#[cfg(target_os = "linux")]
38pub use crate::dma::{DmaMap, DmaTensor};
39pub use crate::mem::{MemMap, MemTensor};
40#[cfg(unix)]
41pub use crate::shm::{ShmMap, ShmTensor};
42pub use error::{Error, Result};
43use num_traits::Num;
44#[cfg(unix)]
45use std::os::fd::OwnedFd;
46use std::{
47    fmt,
48    ops::{Deref, DerefMut},
49};
50
51#[cfg(target_os = "linux")]
52use nix::sys::stat::{major, minor};
53
54pub trait TensorTrait<T>: Send + Sync
55where
56    T: Num + Clone + fmt::Debug,
57{
58    /// Create a new tensor with the given shape and optional name. If no name
59    /// is given, a random name will be generated.
60    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
61    where
62        Self: Sized;
63
64    #[cfg(unix)]
65    /// Create a new tensor using the given file descriptor, shape, and optional
66    /// name. If no name is given, a random name will be generated.
67    ///
68    /// On Linux: Inspects the fd to determine DMA vs SHM based on device major/minor.
69    /// On other Unix (macOS): Always creates SHM tensor.
70    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
71    where
72        Self: Sized;
73
74    #[cfg(unix)]
75    /// Clone the file descriptor associated with this tensor.
76    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
77
78    /// Get the memory type of this tensor.
79    fn memory(&self) -> TensorMemory;
80
81    /// Get the name of this tensor.
82    fn name(&self) -> String;
83
84    /// Get the number of elements in this tensor.
85    fn len(&self) -> usize {
86        self.shape().iter().product()
87    }
88
89    /// Check if the tensor is empty.
90    fn is_empty(&self) -> bool {
91        self.len() == 0
92    }
93
94    /// Get the size in bytes of this tensor.
95    fn size(&self) -> usize {
96        self.len() * std::mem::size_of::<T>()
97    }
98
99    /// Get the shape of this tensor.
100    fn shape(&self) -> &[usize];
101
102    /// Reshape this tensor to the given shape. The total number of elements
103    /// must remain the same.
104    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
105
106    /// Map the tensor into memory and return a TensorMap for accessing the
107    /// data.
108    fn map(&self) -> Result<TensorMap<T>>;
109}
110
111pub trait TensorMapTrait<T>
112where
113    T: Num + Clone + fmt::Debug,
114{
115    /// Get the shape of this tensor map.
116    fn shape(&self) -> &[usize];
117
118    /// Unmap the tensor from memory.
119    fn unmap(&mut self);
120
121    /// Get the number of elements in this tensor map.
122    fn len(&self) -> usize {
123        self.shape().iter().product()
124    }
125
126    /// Check if the tensor map is empty.
127    fn is_empty(&self) -> bool {
128        self.len() == 0
129    }
130
131    /// Get the size in bytes of this tensor map.
132    fn size(&self) -> usize {
133        self.len() * std::mem::size_of::<T>()
134    }
135
136    /// Get a slice to the data in this tensor map.
137    fn as_slice(&self) -> &[T];
138
139    /// Get a mutable slice to the data in this tensor map.
140    fn as_mut_slice(&mut self) -> &mut [T];
141
142    #[cfg(feature = "ndarray")]
143    /// Get an ndarray ArrayView of the tensor data.
144    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
145        Ok(ndarray::ArrayView::from_shape(
146            self.shape(),
147            self.as_slice(),
148        )?)
149    }
150
151    #[cfg(feature = "ndarray")]
152    /// Get an ndarray ArrayViewMut of the tensor data.
153    fn view_mut(
154        &'_ mut self,
155    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
156        let shape = self.shape().to_vec();
157        Ok(ndarray::ArrayViewMut::from_shape(
158            shape,
159            self.as_mut_slice(),
160        )?)
161    }
162}
163
/// The kind of memory backing a [`Tensor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    #[cfg(target_os = "linux")]
    /// Direct Memory Access (DMA) allocation. Incurs additional overhead
    /// for memory reading/writing with the CPU, but allows for hardware
    /// acceleration when supported. Linux only.
    Dma,
    #[cfg(unix)]
    /// POSIX Shared Memory allocation. Suitable for inter-process
    /// communication, but not suitable for hardware acceleration.
    Shm,

    /// Regular system memory allocation. Available on every platform.
    Mem,
}
179
180impl From<TensorMemory> for String {
181    fn from(memory: TensorMemory) -> Self {
182        match memory {
183            #[cfg(target_os = "linux")]
184            TensorMemory::Dma => "dma".to_owned(),
185            #[cfg(unix)]
186            TensorMemory::Shm => "shm".to_owned(),
187            TensorMemory::Mem => "mem".to_owned(),
188        }
189    }
190}
191
192impl TryFrom<&str> for TensorMemory {
193    type Error = Error;
194
195    fn try_from(s: &str) -> Result<Self> {
196        match s {
197            #[cfg(target_os = "linux")]
198            "dma" => Ok(TensorMemory::Dma),
199            #[cfg(unix)]
200            "shm" => Ok(TensorMemory::Shm),
201            "mem" => Ok(TensorMemory::Mem),
202            _ => Err(Error::InvalidMemoryType(s.to_owned())),
203        }
204    }
205}
206
/// A tensor backed by one of the supported memory types. The active variant
/// records how the underlying buffer was allocated.
#[derive(Debug)]
pub enum Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Tensor backed by a DMA allocation (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    /// Tensor backed by POSIX shared memory (Unix only).
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    /// Tensor backed by regular system memory.
    Mem(MemTensor<T>),
}
218
219impl<T> Tensor<T>
220where
221    T: Num + Clone + fmt::Debug + Send + Sync,
222{
223    /// Create a new tensor with the given shape, memory type, and optional
224    /// name. If no name is given, a random name will be generated. If no
225    /// memory type is given, the best available memory type will be chosen
226    /// based on the platform and environment variables.
227    ///
228    /// On Linux platforms, the order of preference is: Dma -> Shm -> Mem.
229    /// On other Unix platforms (macOS), the order is: Shm -> Mem.
230    /// On non-Unix platforms, only Mem is available.
231    ///
232    /// # Environment Variables
233    /// - `EDGEFIRST_TENSOR_FORCE_MEM`: If set to a non-zero and non-false
234    ///   value, forces the use of regular system memory allocation
235    ///   (`TensorMemory::Mem`) regardless of platform capabilities.
236    ///
237    /// # Example
238    /// ```rust
239    /// use edgefirst_tensor::{Error, Tensor, TensorMemory, TensorTrait};
240    /// # fn main() -> Result<(), Error> {
241    /// let tensor = Tensor::<f32>::new(&[2, 3, 4], Some(TensorMemory::Mem), Some("test_tensor"))?;
242    /// assert_eq!(tensor.memory(), TensorMemory::Mem);
243    /// assert_eq!(tensor.name(), "test_tensor");
244    /// #    Ok(())
245    /// # }
246    /// ```
247    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
248        match memory {
249            #[cfg(target_os = "linux")]
250            Some(TensorMemory::Dma) => DmaTensor::<T>::new(shape, name).map(Tensor::Dma),
251            #[cfg(unix)]
252            Some(TensorMemory::Shm) => ShmTensor::<T>::new(shape, name).map(Tensor::Shm),
253            Some(TensorMemory::Mem) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
254            None => {
255                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
256                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
257                {
258                    MemTensor::<T>::new(shape, name).map(Tensor::Mem)
259                } else {
260                    #[cfg(target_os = "linux")]
261                    {
262                        // Linux: Try DMA -> SHM -> Mem
263                        match DmaTensor::<T>::new(shape, name) {
264                            Ok(tensor) => Ok(Tensor::Dma(tensor)),
265                            Err(_) => match ShmTensor::<T>::new(shape, name).map(Tensor::Shm) {
266                                Ok(tensor) => Ok(tensor),
267                                Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
268                            },
269                        }
270                    }
271                    #[cfg(all(unix, not(target_os = "linux")))]
272                    {
273                        // macOS/BSD: Try SHM -> Mem (no DMA)
274                        match ShmTensor::<T>::new(shape, name) {
275                            Ok(tensor) => Ok(Tensor::Shm(tensor)),
276                            Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
277                        }
278                    }
279                    #[cfg(not(unix))]
280                    {
281                        // Windows/other: Mem only
282                        MemTensor::<T>::new(shape, name).map(Tensor::Mem)
283                    }
284                }
285            }
286        }
287    }
288}
289
290impl<T> TensorTrait<T> for Tensor<T>
291where
292    T: Num + Clone + fmt::Debug + Send + Sync,
293{
294    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
295        Self::new(shape, None, name)
296    }
297
298    #[cfg(unix)]
299    /// Create a new tensor using the given file descriptor, shape, and optional
300    /// name. If no name is given, a random name will be generated.
301    ///
302    /// On Linux: Inspects the file descriptor to determine the appropriate tensor type
303    /// (Dma or Shm) based on the device major and minor numbers.
304    /// On other Unix (macOS): Always creates SHM tensor.
305    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
306        #[cfg(target_os = "linux")]
307        {
308            use nix::sys::stat::fstat;
309
310            let stat = fstat(&fd)?;
311            let major = major(stat.st_dev);
312            let minor = minor(stat.st_dev);
313
314            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
315
316            if major != 0 {
317                // Dma and Shm tensors are expected to have major number 0
318                return Err(Error::UnknownDeviceType(major, minor));
319            }
320
321            match minor {
322                9 | 10 => {
323                    // minor number 9 & 10 indicates DMA memory
324                    DmaTensor::<T>::from_fd(fd, shape, name).map(Tensor::Dma)
325                }
326                _ => {
327                    // other minor numbers are assumed to be shared memory
328                    ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
329                }
330            }
331        }
332        #[cfg(all(unix, not(target_os = "linux")))]
333        {
334            // On macOS/BSD, always use SHM (no DMA support)
335            ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
336        }
337    }
338
339    #[cfg(unix)]
340    fn clone_fd(&self) -> Result<OwnedFd> {
341        match self {
342            #[cfg(target_os = "linux")]
343            Tensor::Dma(t) => t.clone_fd(),
344            Tensor::Shm(t) => t.clone_fd(),
345            Tensor::Mem(t) => t.clone_fd(),
346        }
347    }
348
349    fn memory(&self) -> TensorMemory {
350        match self {
351            #[cfg(target_os = "linux")]
352            Tensor::Dma(_) => TensorMemory::Dma,
353            #[cfg(unix)]
354            Tensor::Shm(_) => TensorMemory::Shm,
355            Tensor::Mem(_) => TensorMemory::Mem,
356        }
357    }
358
359    fn name(&self) -> String {
360        match self {
361            #[cfg(target_os = "linux")]
362            Tensor::Dma(t) => t.name(),
363            #[cfg(unix)]
364            Tensor::Shm(t) => t.name(),
365            Tensor::Mem(t) => t.name(),
366        }
367    }
368
369    fn shape(&self) -> &[usize] {
370        match self {
371            #[cfg(target_os = "linux")]
372            Tensor::Dma(t) => t.shape(),
373            #[cfg(unix)]
374            Tensor::Shm(t) => t.shape(),
375            Tensor::Mem(t) => t.shape(),
376        }
377    }
378
379    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
380        match self {
381            #[cfg(target_os = "linux")]
382            Tensor::Dma(t) => t.reshape(shape),
383            #[cfg(unix)]
384            Tensor::Shm(t) => t.reshape(shape),
385            Tensor::Mem(t) => t.reshape(shape),
386        }
387    }
388
389    fn map(&self) -> Result<TensorMap<T>> {
390        match self {
391            #[cfg(target_os = "linux")]
392            Tensor::Dma(t) => t.map(),
393            #[cfg(unix)]
394            Tensor::Shm(t) => t.map(),
395            Tensor::Mem(t) => t.map(),
396        }
397    }
398}
399
/// A mapped view over a [`Tensor`]'s memory, providing slice access to the
/// element data. Obtain one via `TensorTrait::map`.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Mapping of a DMA-backed tensor (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// Mapping of a POSIX shared-memory tensor (Unix only).
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Mapping of a system-memory tensor.
    Mem(MemMap<T>),
}
410
411impl<T> TensorMapTrait<T> for TensorMap<T>
412where
413    T: Num + Clone + fmt::Debug,
414{
415    fn shape(&self) -> &[usize] {
416        match self {
417            #[cfg(target_os = "linux")]
418            TensorMap::Dma(map) => map.shape(),
419            #[cfg(unix)]
420            TensorMap::Shm(map) => map.shape(),
421            TensorMap::Mem(map) => map.shape(),
422        }
423    }
424
425    fn unmap(&mut self) {
426        match self {
427            #[cfg(target_os = "linux")]
428            TensorMap::Dma(map) => map.unmap(),
429            #[cfg(unix)]
430            TensorMap::Shm(map) => map.unmap(),
431            TensorMap::Mem(map) => map.unmap(),
432        }
433    }
434
435    fn as_slice(&self) -> &[T] {
436        match self {
437            #[cfg(target_os = "linux")]
438            TensorMap::Dma(map) => map.as_slice(),
439            #[cfg(unix)]
440            TensorMap::Shm(map) => map.as_slice(),
441            TensorMap::Mem(map) => map.as_slice(),
442        }
443    }
444
445    fn as_mut_slice(&mut self) -> &mut [T] {
446        match self {
447            #[cfg(target_os = "linux")]
448            TensorMap::Dma(map) => map.as_mut_slice(),
449            #[cfg(unix)]
450            TensorMap::Shm(map) => map.as_mut_slice(),
451            TensorMap::Mem(map) => map.as_mut_slice(),
452        }
453    }
454}
455
456impl<T> Deref for TensorMap<T>
457where
458    T: Num + Clone + fmt::Debug,
459{
460    type Target = [T];
461
462    fn deref(&self) -> &[T] {
463        match self {
464            #[cfg(target_os = "linux")]
465            TensorMap::Dma(map) => map.deref(),
466            #[cfg(unix)]
467            TensorMap::Shm(map) => map.deref(),
468            TensorMap::Mem(map) => map.deref(),
469        }
470    }
471}
472
473impl<T> DerefMut for TensorMap<T>
474where
475    T: Num + Clone + fmt::Debug,
476{
477    fn deref_mut(&mut self) -> &mut [T] {
478        match self {
479            #[cfg(target_os = "linux")]
480            TensorMap::Dma(map) => map.deref_mut(),
481            #[cfg(unix)]
482            TensorMap::Shm(map) => map.deref_mut(),
483            TensorMap::Mem(map) => map.deref_mut(),
484        }
485    }
486}
487
// ============================================================================
// Platform availability helpers
// ============================================================================

/// Cached outcome of the one-time DMA availability probe performed by
/// [`is_dma_available`]; remains unset until the first call to that function.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
501
502/// Check if DMA memory allocation is available on this system.
503#[cfg(target_os = "linux")]
504pub fn is_dma_available() -> bool {
505    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
506}
507
/// Check if DMA memory allocation is available on this system.
///
/// Always returns `false` on non-Linux platforms, since DMA-BUF is a
/// Linux-specific mechanism.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
515
/// Cached outcome of the one-time POSIX shared-memory availability probe
/// performed by [`is_shm_available`]; remains unset until the first call to
/// that function.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
524
525/// Check if POSIX shared memory allocation is available on this system.
526#[cfg(unix)]
527pub fn is_shm_available() -> bool {
528    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
529}
530
/// Check if POSIX shared memory allocation is available on this system.
///
/// Always returns `false` on non-Unix platforms (Windows), since POSIX
/// shared memory is Unix-specific.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
538
539#[cfg(test)]
540mod tests {
541    #[cfg(target_os = "linux")]
542    use nix::unistd::{access, AccessFlags};
543    #[cfg(target_os = "linux")]
544    use std::io::Write as _;
545    use std::sync::RwLock;
546
547    use super::*;
548
    // Initialize the logger once for the whole test binary; #[ctor::ctor]
    // runs this before any test executes. Defaults to "info" level unless
    // overridden by the environment.
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }
553
    /// Macro to get the current function name for logging in tests.
    ///
    /// Works by taking `std::any::type_name` of a nested item `f` — which
    /// yields the full path "path::to::test_fn::f" — then trimming the
    /// trailing "::f" and the leading module path.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // Drop the trailing "::f" (3 chars) and everything up to and
            // including the last ':' of the module path, when present.
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
571
572    #[test]
573    #[cfg(target_os = "linux")]
574    fn test_tensor() {
575        let _lock = FD_LOCK.read().unwrap();
576        let shape = vec![1];
577        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
578        let dma_enabled = tensor.is_ok();
579
580        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
581        match dma_enabled {
582            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
583            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
584        }
585    }
586
587    #[test]
588    #[cfg(all(unix, not(target_os = "linux")))]
589    fn test_tensor() {
590        let shape = vec![1];
591        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
592        // On macOS/BSD, auto-detection tries SHM first, falls back to Mem
593        assert!(
594            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
595            "Expected SHM or Mem on macOS, got {:?}",
596            tensor.memory()
597        );
598    }
599
600    #[test]
601    #[cfg(not(unix))]
602    fn test_tensor() {
603        let shape = vec![1];
604        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
605        assert_eq!(tensor.memory(), TensorMemory::Mem);
606    }
607
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        // Shared lock: fd-counting tests take the write lock, so they are
        // excluded while this test holds file descriptors open.
        let _lock = FD_LOCK.read().unwrap();
        // Skip (with a warning) when neither DMA heap device is accessible.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        // Basic metadata checks.
        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        // Fill through a mapping, then drop the mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd and write DUMMY_VALUE through
        // the second tensor.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor must observe the write made through the fd copy
        // (same underlying buffer).
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: a size mismatch must fail and leave the shape untouched.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        // A same-element-count reshape succeeds.
        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Indexed writes persist and are visible in a fresh mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
705
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        // Shared lock: fd-counting tests take the write lock, so they are
        // excluded while this test holds file descriptors open.
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        // Basic metadata checks.
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        // Fill through a mapping, then drop the mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd and write DUMMY_VALUE through
        // the second tensor.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // The original tensor must observe the write made through the fd copy
        // (same underlying buffer).
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: a size mismatch must fail and leave the shape untouched.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        // A same-element-count reshape succeeds.
        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Indexed writes persist and are visible in a fresh mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
778
    #[test]
    fn test_mem_tensor() {
        let shape = vec![2, 3, 4];
        let tensor =
            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        // Basic metadata checks.
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        // Fill through a mapping, then drop the mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Reshape: a size mismatch must fail and leave the shape untouched.
        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        // A same-element-count reshape succeeds.
        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Indexed writes persist and are visible in a fresh mapping.
        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
824
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        // Exclusive lock: fd counting is racy if other tests open or close
        // file descriptors concurrently.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        // Baseline fd count before exercising the allocation path.
        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Create, map, write, and drop many tensors; each iteration should
        // release every fd it opened.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Any difference means a descriptor leaked somewhere in the cycle.
        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
861
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        // Exclusive lock: fd counting is racy if other tests open or close
        // file descriptors concurrently.
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        // Baseline fd count before exercising the from_fd path.
        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        // Repeatedly duplicate the fd into a new tensor; each duplicate must
        // be released when the tensor is dropped.
        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the original so the fd count should return to the baseline.
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        // Any difference means a descriptor leaked somewhere in the cycle.
        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
899
900    #[test]
901    #[cfg(target_os = "linux")]
902    fn test_shm_no_fd_leaks() {
903        let _lock = FD_LOCK.write().unwrap();
904        if !is_shm_available() {
905            log::warn!(
906                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
907                function!()
908            );
909            return;
910        }
911
912        let proc = procfs::process::Process::myself()
913            .expect("Failed to get current process using /proc/self");
914
915        let start_open_fds = proc
916            .fd_count()
917            .expect("Failed to get open file descriptor count");
918
919        for _ in 0..100 {
920            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
921                .expect("Failed to create tensor");
922            let mut map = tensor.map().unwrap();
923            map.as_mut_slice().fill(233);
924        }
925
926        let end_open_fds = proc
927            .fd_count()
928            .expect("Failed to get open file descriptor count");
929
930        assert_eq!(
931            start_open_fds, end_open_fds,
932            "File descriptor leak detected: {} -> {}",
933            start_open_fds, end_open_fds
934        );
935    }
936
937    #[test]
938    #[cfg(target_os = "linux")]
939    fn test_shm_from_fd_no_fd_leaks() {
940        let _lock = FD_LOCK.write().unwrap();
941        if !is_shm_available() {
942            log::warn!(
943                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
944                function!()
945            );
946            return;
947        }
948
949        let proc = procfs::process::Process::myself()
950            .expect("Failed to get current process using /proc/self");
951
952        let start_open_fds = proc
953            .fd_count()
954            .expect("Failed to get open file descriptor count");
955
956        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
957
958        for _ in 0..100 {
959            let tensor =
960                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
961            let mut map = tensor.map().unwrap();
962            map.as_mut_slice().fill(233);
963        }
964        drop(orig);
965
966        let end_open_fds = proc.fd_count().unwrap();
967
968        assert_eq!(
969            start_open_fds, end_open_fds,
970            "File descriptor leak detected: {} -> {}",
971            start_open_fds, end_open_fds
972        );
973    }
974
975    #[cfg(feature = "ndarray")]
976    #[test]
977    fn test_ndarray() {
978        let _lock = FD_LOCK.read().unwrap();
979        let shape = vec![2, 3, 4];
980        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
981
982        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
983        tensor_map.fill(1.0);
984
985        let view = tensor_map.view().expect("Failed to get ndarray view");
986        assert_eq!(view.shape(), &[2, 3, 4]);
987        assert!(view.iter().all(|&x| x == 1.0));
988
989        let mut view_mut = tensor_map
990            .view_mut()
991            .expect("Failed to get mutable ndarray view");
992        view_mut[[0, 0, 0]] = 42.0;
993        assert_eq!(view_mut[[0, 0, 0]], 42.0);
994        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
995    }
996
    // Serializes tests that interact with the process fd count.
    // Any test that asserts on the fd count must take the lock exclusively
    // (write). Any test that modifies the fd count by opening or closing fds
    // must take it shared (read), so it cannot overlap with a counting test.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1001
1002    /// Test that DMA is NOT available on non-Linux platforms.
1003    /// This verifies the cross-platform behavior of is_dma_available().
1004    #[test]
1005    #[cfg(not(target_os = "linux"))]
1006    fn test_dma_not_available_on_non_linux() {
1007        assert!(
1008            !is_dma_available(),
1009            "DMA memory allocation should NOT be available on non-Linux platforms"
1010        );
1011    }
1012
1013    /// Test that SHM memory allocation is available and usable on Unix systems.
1014    /// This is a basic functional test; Linux has additional FD leak tests using procfs.
1015    #[test]
1016    #[cfg(unix)]
1017    fn test_shm_available_and_usable() {
1018        assert!(
1019            is_shm_available(),
1020            "SHM memory allocation should be available on Unix systems"
1021        );
1022
1023        // Create a tensor with SHM backing
1024        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1025            .expect("Failed to create SHM tensor");
1026
1027        // Verify we can map and write to it
1028        let mut map = tensor.map().expect("Failed to map SHM tensor");
1029        map.as_mut_slice().fill(0xAB);
1030
1031        // Verify the data was written correctly
1032        assert!(
1033            map.as_slice().iter().all(|&b| b == 0xAB),
1034            "SHM tensor data should be writable and readable"
1035        );
1036    }
1037}