1#[cfg(target_os = "linux")]
29mod dma;
30#[cfg(target_os = "linux")]
31mod dmabuf;
32mod error;
33mod mem;
34#[cfg(unix)]
35mod shm;
36
37#[cfg(target_os = "linux")]
38pub use crate::dma::{DmaMap, DmaTensor};
39pub use crate::mem::{MemMap, MemTensor};
40#[cfg(unix)]
41pub use crate::shm::{ShmMap, ShmTensor};
42pub use error::{Error, Result};
43use num_traits::Num;
44#[cfg(unix)]
45use std::os::fd::OwnedFd;
46use std::{
47 fmt,
48 ops::{Deref, DerefMut},
49};
50
51#[cfg(target_os = "linux")]
52use nix::sys::stat::{major, minor};
53
54pub trait TensorTrait<T>: Send + Sync
55where
56 T: Num + Clone + fmt::Debug,
57{
58 fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
61 where
62 Self: Sized;
63
64 #[cfg(unix)]
65 fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
71 where
72 Self: Sized;
73
74 #[cfg(unix)]
75 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
77
78 fn memory(&self) -> TensorMemory;
80
81 fn name(&self) -> String;
83
84 fn len(&self) -> usize {
86 self.shape().iter().product()
87 }
88
89 fn is_empty(&self) -> bool {
91 self.len() == 0
92 }
93
94 fn size(&self) -> usize {
96 self.len() * std::mem::size_of::<T>()
97 }
98
99 fn shape(&self) -> &[usize];
101
102 fn reshape(&mut self, shape: &[usize]) -> Result<()>;
105
106 fn map(&self) -> Result<TensorMap<T>>;
109}
110
111pub trait TensorMapTrait<T>
112where
113 T: Num + Clone + fmt::Debug,
114{
115 fn shape(&self) -> &[usize];
117
118 fn unmap(&mut self);
120
121 fn len(&self) -> usize {
123 self.shape().iter().product()
124 }
125
126 fn is_empty(&self) -> bool {
128 self.len() == 0
129 }
130
131 fn size(&self) -> usize {
133 self.len() * std::mem::size_of::<T>()
134 }
135
136 fn as_slice(&self) -> &[T];
138
139 fn as_mut_slice(&mut self) -> &mut [T];
141
142 #[cfg(feature = "ndarray")]
143 fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
145 Ok(ndarray::ArrayView::from_shape(
146 self.shape(),
147 self.as_slice(),
148 )?)
149 }
150
151 #[cfg(feature = "ndarray")]
152 fn view_mut(
154 &'_ mut self,
155 ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
156 let shape = self.shape().to_vec();
157 Ok(ndarray::ArrayViewMut::from_shape(
158 shape,
159 self.as_mut_slice(),
160 )?)
161 }
162}
163
/// Identifies the kind of memory that backs a tensor.
///
/// Which variants exist depends on the target platform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Memory provided by [`DmaTensor`]; only compiled on Linux.
    #[cfg(target_os = "linux")]
    Dma,
    /// Memory provided by [`ShmTensor`]; only compiled on Unix targets.
    #[cfg(unix)]
    Shm,

    /// Memory provided by [`MemTensor`]; compiled on every target.
    Mem,
}
179
180impl From<TensorMemory> for String {
181 fn from(memory: TensorMemory) -> Self {
182 match memory {
183 #[cfg(target_os = "linux")]
184 TensorMemory::Dma => "dma".to_owned(),
185 #[cfg(unix)]
186 TensorMemory::Shm => "shm".to_owned(),
187 TensorMemory::Mem => "mem".to_owned(),
188 }
189 }
190}
191
192impl TryFrom<&str> for TensorMemory {
193 type Error = Error;
194
195 fn try_from(s: &str) -> Result<Self> {
196 match s {
197 #[cfg(target_os = "linux")]
198 "dma" => Ok(TensorMemory::Dma),
199 #[cfg(unix)]
200 "shm" => Ok(TensorMemory::Shm),
201 "mem" => Ok(TensorMemory::Mem),
202 _ => Err(Error::InvalidMemoryType(s.to_owned())),
203 }
204 }
205}
206
/// A tensor backed by one of the memory kinds available on the target.
///
/// Each variant wraps the matching backend type; the [`TensorTrait`]
/// implementation for this enum forwards every call to the wrapped value.
#[derive(Debug)]
pub enum Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wraps a [`DmaTensor`]; only compiled on Linux.
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    /// Wraps a [`ShmTensor`]; only compiled on Unix targets.
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    /// Wraps a [`MemTensor`]; compiled on every target.
    Mem(MemTensor<T>),
}
218
219impl<T> Tensor<T>
220where
221 T: Num + Clone + fmt::Debug + Send + Sync,
222{
223 pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
248 match memory {
249 #[cfg(target_os = "linux")]
250 Some(TensorMemory::Dma) => DmaTensor::<T>::new(shape, name).map(Tensor::Dma),
251 #[cfg(unix)]
252 Some(TensorMemory::Shm) => ShmTensor::<T>::new(shape, name).map(Tensor::Shm),
253 Some(TensorMemory::Mem) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
254 None => {
255 if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
256 .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
257 {
258 MemTensor::<T>::new(shape, name).map(Tensor::Mem)
259 } else {
260 #[cfg(target_os = "linux")]
261 {
262 match DmaTensor::<T>::new(shape, name) {
264 Ok(tensor) => Ok(Tensor::Dma(tensor)),
265 Err(_) => match ShmTensor::<T>::new(shape, name).map(Tensor::Shm) {
266 Ok(tensor) => Ok(tensor),
267 Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
268 },
269 }
270 }
271 #[cfg(all(unix, not(target_os = "linux")))]
272 {
273 match ShmTensor::<T>::new(shape, name) {
275 Ok(tensor) => Ok(Tensor::Shm(tensor)),
276 Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
277 }
278 }
279 #[cfg(not(unix))]
280 {
281 MemTensor::<T>::new(shape, name).map(Tensor::Mem)
283 }
284 }
285 }
286 }
287 }
288}
289
290impl<T> TensorTrait<T> for Tensor<T>
291where
292 T: Num + Clone + fmt::Debug + Send + Sync,
293{
294 fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
295 Self::new(shape, None, name)
296 }
297
298 #[cfg(unix)]
299 fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
306 #[cfg(target_os = "linux")]
307 {
308 use nix::sys::stat::fstat;
309
310 let stat = fstat(&fd)?;
311 let major = major(stat.st_dev);
312 let minor = minor(stat.st_dev);
313
314 log::debug!("Creating tensor from fd: major={major}, minor={minor}");
315
316 if major != 0 {
317 return Err(Error::UnknownDeviceType(major, minor));
319 }
320
321 match minor {
322 9 | 10 => {
323 DmaTensor::<T>::from_fd(fd, shape, name).map(Tensor::Dma)
325 }
326 _ => {
327 ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
329 }
330 }
331 }
332 #[cfg(all(unix, not(target_os = "linux")))]
333 {
334 ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
336 }
337 }
338
339 #[cfg(unix)]
340 fn clone_fd(&self) -> Result<OwnedFd> {
341 match self {
342 #[cfg(target_os = "linux")]
343 Tensor::Dma(t) => t.clone_fd(),
344 Tensor::Shm(t) => t.clone_fd(),
345 Tensor::Mem(t) => t.clone_fd(),
346 }
347 }
348
349 fn memory(&self) -> TensorMemory {
350 match self {
351 #[cfg(target_os = "linux")]
352 Tensor::Dma(_) => TensorMemory::Dma,
353 #[cfg(unix)]
354 Tensor::Shm(_) => TensorMemory::Shm,
355 Tensor::Mem(_) => TensorMemory::Mem,
356 }
357 }
358
359 fn name(&self) -> String {
360 match self {
361 #[cfg(target_os = "linux")]
362 Tensor::Dma(t) => t.name(),
363 #[cfg(unix)]
364 Tensor::Shm(t) => t.name(),
365 Tensor::Mem(t) => t.name(),
366 }
367 }
368
369 fn shape(&self) -> &[usize] {
370 match self {
371 #[cfg(target_os = "linux")]
372 Tensor::Dma(t) => t.shape(),
373 #[cfg(unix)]
374 Tensor::Shm(t) => t.shape(),
375 Tensor::Mem(t) => t.shape(),
376 }
377 }
378
379 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
380 match self {
381 #[cfg(target_os = "linux")]
382 Tensor::Dma(t) => t.reshape(shape),
383 #[cfg(unix)]
384 Tensor::Shm(t) => t.reshape(shape),
385 Tensor::Mem(t) => t.reshape(shape),
386 }
387 }
388
389 fn map(&self) -> Result<TensorMap<T>> {
390 match self {
391 #[cfg(target_os = "linux")]
392 Tensor::Dma(t) => t.map(),
393 #[cfg(unix)]
394 Tensor::Shm(t) => t.map(),
395 Tensor::Mem(t) => t.map(),
396 }
397 }
398}
399
/// A mapped tensor, wrapping the map type of whichever backend produced it.
///
/// Returned by [`TensorTrait::map`]; dereferences to a slice of `T`.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Wraps a [`DmaMap`]; only compiled on Linux.
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// Wraps a [`ShmMap`]; only compiled on Unix targets.
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Wraps a [`MemMap`]; compiled on every target.
    Mem(MemMap<T>),
}
410
411impl<T> TensorMapTrait<T> for TensorMap<T>
412where
413 T: Num + Clone + fmt::Debug,
414{
415 fn shape(&self) -> &[usize] {
416 match self {
417 #[cfg(target_os = "linux")]
418 TensorMap::Dma(map) => map.shape(),
419 #[cfg(unix)]
420 TensorMap::Shm(map) => map.shape(),
421 TensorMap::Mem(map) => map.shape(),
422 }
423 }
424
425 fn unmap(&mut self) {
426 match self {
427 #[cfg(target_os = "linux")]
428 TensorMap::Dma(map) => map.unmap(),
429 #[cfg(unix)]
430 TensorMap::Shm(map) => map.unmap(),
431 TensorMap::Mem(map) => map.unmap(),
432 }
433 }
434
435 fn as_slice(&self) -> &[T] {
436 match self {
437 #[cfg(target_os = "linux")]
438 TensorMap::Dma(map) => map.as_slice(),
439 #[cfg(unix)]
440 TensorMap::Shm(map) => map.as_slice(),
441 TensorMap::Mem(map) => map.as_slice(),
442 }
443 }
444
445 fn as_mut_slice(&mut self) -> &mut [T] {
446 match self {
447 #[cfg(target_os = "linux")]
448 TensorMap::Dma(map) => map.as_mut_slice(),
449 #[cfg(unix)]
450 TensorMap::Shm(map) => map.as_mut_slice(),
451 TensorMap::Mem(map) => map.as_mut_slice(),
452 }
453 }
454}
455
456impl<T> Deref for TensorMap<T>
457where
458 T: Num + Clone + fmt::Debug,
459{
460 type Target = [T];
461
462 fn deref(&self) -> &[T] {
463 match self {
464 #[cfg(target_os = "linux")]
465 TensorMap::Dma(map) => map.deref(),
466 #[cfg(unix)]
467 TensorMap::Shm(map) => map.deref(),
468 TensorMap::Mem(map) => map.deref(),
469 }
470 }
471}
472
473impl<T> DerefMut for TensorMap<T>
474where
475 T: Num + Clone + fmt::Debug,
476{
477 fn deref_mut(&mut self) -> &mut [T] {
478 match self {
479 #[cfg(target_os = "linux")]
480 TensorMap::Dma(map) => map.deref_mut(),
481 #[cfg(unix)]
482 TensorMap::Shm(map) => map.deref_mut(),
483 TensorMap::Mem(map) => map.deref_mut(),
484 }
485 }
486}
487
488#[cfg(target_os = "linux")]
500static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
501
502#[cfg(target_os = "linux")]
504pub fn is_dma_available() -> bool {
505 *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
506}
507
508#[cfg(not(target_os = "linux"))]
512pub fn is_dma_available() -> bool {
513 false
514}
515
516#[cfg(unix)]
523static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
524
525#[cfg(unix)]
527pub fn is_shm_available() -> bool {
528 *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
529}
530
531#[cfg(not(unix))]
535pub fn is_shm_available() -> bool {
536 false
537}
538
539#[cfg(test)]
540mod tests {
541 #[cfg(target_os = "linux")]
542 use nix::unistd::{access, AccessFlags};
543 #[cfg(target_os = "linux")]
544 use std::io::Write as _;
545 use std::sync::RwLock;
546
547 use super::*;
548
549 #[ctor::ctor]
550 fn init() {
551 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
552 }
553
    /// Expands to the name of the function the macro is invoked in, as a
    /// `&'static str`. Used by the tests below when logging a skip.
    ///
    /// It declares a nested `fn f` and takes that function's type name, which
    /// is expected to look like `path::to::enclosing::f`. The standard library
    /// does not guarantee the exact format of `type_name`, so this is
    /// best-effort and only suitable for diagnostics.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // Drop the trailing 3 bytes (the `::f` suffix) and keep whatever
            // follows the last ':' of the remainder, i.e. the bare name of the
            // enclosing function.
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
571
572 #[test]
573 #[cfg(target_os = "linux")]
574 fn test_tensor() {
575 let _lock = FD_LOCK.read().unwrap();
576 let shape = vec![1];
577 let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
578 let dma_enabled = tensor.is_ok();
579
580 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
581 match dma_enabled {
582 true => assert_eq!(tensor.memory(), TensorMemory::Dma),
583 false => assert_eq!(tensor.memory(), TensorMemory::Shm),
584 }
585 }
586
587 #[test]
588 #[cfg(all(unix, not(target_os = "linux")))]
589 fn test_tensor() {
590 let shape = vec![1];
591 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
592 assert!(
594 tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
595 "Expected SHM or Mem on macOS, got {:?}",
596 tensor.memory()
597 );
598 }
599
600 #[test]
601 #[cfg(not(unix))]
602 fn test_tensor() {
603 let shape = vec![1];
604 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
605 assert_eq!(tensor.memory(), TensorMemory::Mem);
606 }
607
608 #[test]
609 #[cfg(target_os = "linux")]
610 fn test_dma_tensor() {
611 let _lock = FD_LOCK.read().unwrap();
612 match access(
613 "/dev/dma_heap/linux,cma",
614 AccessFlags::R_OK | AccessFlags::W_OK,
615 ) {
616 Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
617 Err(_) => match access(
618 "/dev/dma_heap/system",
619 AccessFlags::R_OK | AccessFlags::W_OK,
620 ) {
621 Ok(_) => println!("/dev/dma_heap/system is available"),
622 Err(e) => {
623 writeln!(
624 &mut std::io::stdout(),
625 "[WARNING] DMA Heap is unavailable: {e}"
626 )
627 .unwrap();
628 return;
629 }
630 },
631 }
632
633 let shape = vec![2, 3, 4];
634 let tensor =
635 DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
636
637 const DUMMY_VALUE: f32 = 12.34;
638
639 assert_eq!(tensor.memory(), TensorMemory::Dma);
640 assert_eq!(tensor.name(), "test_tensor");
641 assert_eq!(tensor.shape(), &shape);
642 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
643 assert_eq!(tensor.len(), 2 * 3 * 4);
644
645 {
646 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
647 tensor_map.fill(42.0);
648 assert!(tensor_map.iter().all(|&x| x == 42.0));
649 }
650
651 {
652 let shared = Tensor::<f32>::from_fd(
653 tensor
654 .clone_fd()
655 .expect("Failed to duplicate tensor file descriptor"),
656 &shape,
657 Some("test_tensor_shared"),
658 )
659 .expect("Failed to create tensor from fd");
660
661 assert_eq!(shared.memory(), TensorMemory::Dma);
662 assert_eq!(shared.name(), "test_tensor_shared");
663 assert_eq!(shared.shape(), &shape);
664
665 let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
666 tensor_map.fill(DUMMY_VALUE);
667 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
668 }
669
670 {
671 let tensor_map = tensor.map().expect("Failed to map DMA memory");
672 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
673 }
674
675 let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
676 assert_eq!(tensor.shape(), &shape);
677 let new_shape = vec![3, 4, 4];
678 assert!(
679 tensor.reshape(&new_shape).is_err(),
680 "Reshape should fail due to size mismatch"
681 );
682 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
683
684 let new_shape = vec![2, 3, 4];
685 tensor.reshape(&new_shape).expect("Reshape should succeed");
686 assert_eq!(
687 tensor.shape(),
688 &new_shape,
689 "Shape should be updated after successful reshape"
690 );
691
692 {
693 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
694 tensor_map.fill(1);
695 assert!(tensor_map.iter().all(|&x| x == 1));
696 }
697
698 {
699 let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
700 tensor_map[2] = 42;
701 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
702 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
703 }
704 }
705
706 #[test]
707 #[cfg(unix)]
708 fn test_shm_tensor() {
709 let _lock = FD_LOCK.read().unwrap();
710 let shape = vec![2, 3, 4];
711 let tensor =
712 ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
713 assert_eq!(tensor.shape(), &shape);
714 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
715 assert_eq!(tensor.name(), "test_tensor");
716
717 const DUMMY_VALUE: f32 = 12.34;
718 {
719 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
720 tensor_map.fill(42.0);
721 assert!(tensor_map.iter().all(|&x| x == 42.0));
722 }
723
724 {
725 let shared = Tensor::<f32>::from_fd(
726 tensor
727 .clone_fd()
728 .expect("Failed to duplicate tensor file descriptor"),
729 &shape,
730 Some("test_tensor_shared"),
731 )
732 .expect("Failed to create tensor from fd");
733
734 assert_eq!(shared.memory(), TensorMemory::Shm);
735 assert_eq!(shared.name(), "test_tensor_shared");
736 assert_eq!(shared.shape(), &shape);
737
738 let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
739 tensor_map.fill(DUMMY_VALUE);
740 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
741 }
742
743 {
744 let tensor_map = tensor.map().expect("Failed to map shared memory");
745 assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
746 }
747
748 let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
749 assert_eq!(tensor.shape(), &shape);
750 let new_shape = vec![3, 4, 4];
751 assert!(
752 tensor.reshape(&new_shape).is_err(),
753 "Reshape should fail due to size mismatch"
754 );
755 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
756
757 let new_shape = vec![2, 3, 4];
758 tensor.reshape(&new_shape).expect("Reshape should succeed");
759 assert_eq!(
760 tensor.shape(),
761 &new_shape,
762 "Shape should be updated after successful reshape"
763 );
764
765 {
766 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
767 tensor_map.fill(1);
768 assert!(tensor_map.iter().all(|&x| x == 1));
769 }
770
771 {
772 let mut tensor_map = tensor.map().expect("Failed to map shared memory");
773 tensor_map[2] = 42;
774 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
775 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
776 }
777 }
778
779 #[test]
780 fn test_mem_tensor() {
781 let shape = vec![2, 3, 4];
782 let tensor =
783 MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
784 assert_eq!(tensor.shape(), &shape);
785 assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
786 assert_eq!(tensor.name(), "test_tensor");
787
788 {
789 let mut tensor_map = tensor.map().expect("Failed to map memory");
790 tensor_map.fill(42.0);
791 assert!(tensor_map.iter().all(|&x| x == 42.0));
792 }
793
794 let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
795 assert_eq!(tensor.shape(), &shape);
796 let new_shape = vec![3, 4, 4];
797 assert!(
798 tensor.reshape(&new_shape).is_err(),
799 "Reshape should fail due to size mismatch"
800 );
801 assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
802
803 let new_shape = vec![2, 3, 4];
804 tensor.reshape(&new_shape).expect("Reshape should succeed");
805 assert_eq!(
806 tensor.shape(),
807 &new_shape,
808 "Shape should be updated after successful reshape"
809 );
810
811 {
812 let mut tensor_map = tensor.map().expect("Failed to map memory");
813 tensor_map.fill(1);
814 assert!(tensor_map.iter().all(|&x| x == 1));
815 }
816
817 {
818 let mut tensor_map = tensor.map().expect("Failed to map memory");
819 tensor_map[2] = 42;
820 assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
821 assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
822 }
823 }
824
825 #[test]
826 #[cfg(target_os = "linux")]
827 fn test_dma_no_fd_leaks() {
828 let _lock = FD_LOCK.write().unwrap();
829 if !is_dma_available() {
830 log::warn!(
831 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
832 function!()
833 );
834 return;
835 }
836
837 let proc = procfs::process::Process::myself()
838 .expect("Failed to get current process using /proc/self");
839
840 let start_open_fds = proc
841 .fd_count()
842 .expect("Failed to get open file descriptor count");
843
844 for _ in 0..100 {
845 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
846 .expect("Failed to create tensor");
847 let mut map = tensor.map().unwrap();
848 map.as_mut_slice().fill(233);
849 }
850
851 let end_open_fds = proc
852 .fd_count()
853 .expect("Failed to get open file descriptor count");
854
855 assert_eq!(
856 start_open_fds, end_open_fds,
857 "File descriptor leak detected: {} -> {}",
858 start_open_fds, end_open_fds
859 );
860 }
861
862 #[test]
863 #[cfg(target_os = "linux")]
864 fn test_dma_from_fd_no_fd_leaks() {
865 let _lock = FD_LOCK.write().unwrap();
866 if !is_dma_available() {
867 log::warn!(
868 "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
869 function!()
870 );
871 return;
872 }
873
874 let proc = procfs::process::Process::myself()
875 .expect("Failed to get current process using /proc/self");
876
877 let start_open_fds = proc
878 .fd_count()
879 .expect("Failed to get open file descriptor count");
880
881 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
882
883 for _ in 0..100 {
884 let tensor =
885 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
886 let mut map = tensor.map().unwrap();
887 map.as_mut_slice().fill(233);
888 }
889 drop(orig);
890
891 let end_open_fds = proc.fd_count().unwrap();
892
893 assert_eq!(
894 start_open_fds, end_open_fds,
895 "File descriptor leak detected: {} -> {}",
896 start_open_fds, end_open_fds
897 );
898 }
899
900 #[test]
901 #[cfg(target_os = "linux")]
902 fn test_shm_no_fd_leaks() {
903 let _lock = FD_LOCK.write().unwrap();
904 if !is_shm_available() {
905 log::warn!(
906 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
907 function!()
908 );
909 return;
910 }
911
912 let proc = procfs::process::Process::myself()
913 .expect("Failed to get current process using /proc/self");
914
915 let start_open_fds = proc
916 .fd_count()
917 .expect("Failed to get open file descriptor count");
918
919 for _ in 0..100 {
920 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
921 .expect("Failed to create tensor");
922 let mut map = tensor.map().unwrap();
923 map.as_mut_slice().fill(233);
924 }
925
926 let end_open_fds = proc
927 .fd_count()
928 .expect("Failed to get open file descriptor count");
929
930 assert_eq!(
931 start_open_fds, end_open_fds,
932 "File descriptor leak detected: {} -> {}",
933 start_open_fds, end_open_fds
934 );
935 }
936
937 #[test]
938 #[cfg(target_os = "linux")]
939 fn test_shm_from_fd_no_fd_leaks() {
940 let _lock = FD_LOCK.write().unwrap();
941 if !is_shm_available() {
942 log::warn!(
943 "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
944 function!()
945 );
946 return;
947 }
948
949 let proc = procfs::process::Process::myself()
950 .expect("Failed to get current process using /proc/self");
951
952 let start_open_fds = proc
953 .fd_count()
954 .expect("Failed to get open file descriptor count");
955
956 let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
957
958 for _ in 0..100 {
959 let tensor =
960 Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
961 let mut map = tensor.map().unwrap();
962 map.as_mut_slice().fill(233);
963 }
964 drop(orig);
965
966 let end_open_fds = proc.fd_count().unwrap();
967
968 assert_eq!(
969 start_open_fds, end_open_fds,
970 "File descriptor leak detected: {} -> {}",
971 start_open_fds, end_open_fds
972 );
973 }
974
975 #[cfg(feature = "ndarray")]
976 #[test]
977 fn test_ndarray() {
978 let _lock = FD_LOCK.read().unwrap();
979 let shape = vec![2, 3, 4];
980 let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
981
982 let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
983 tensor_map.fill(1.0);
984
985 let view = tensor_map.view().expect("Failed to get ndarray view");
986 assert_eq!(view.shape(), &[2, 3, 4]);
987 assert!(view.iter().all(|&x| x == 1.0));
988
989 let mut view_mut = tensor_map
990 .view_mut()
991 .expect("Failed to get mutable ndarray view");
992 view_mut[[0, 0, 0]] = 42.0;
993 assert_eq!(view_mut[[0, 0, 0]], 42.0);
994 assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
995 }
996
    /// Serialises the fd-leak tests against every other test that opens or
    /// closes file descriptors.
    ///
    /// Tests that compare the process's open-fd count take the lock with
    /// `write()`; tests that merely create tensors (and so open fds) take it
    /// with `read()`, so they can run alongside each other but never during a
    /// count.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1001
1002 #[test]
1005 #[cfg(not(target_os = "linux"))]
1006 fn test_dma_not_available_on_non_linux() {
1007 assert!(
1008 !is_dma_available(),
1009 "DMA memory allocation should NOT be available on non-Linux platforms"
1010 );
1011 }
1012
1013 #[test]
1016 #[cfg(unix)]
1017 fn test_shm_available_and_usable() {
1018 assert!(
1019 is_shm_available(),
1020 "SHM memory allocation should be available on Unix systems"
1021 );
1022
1023 let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
1025 .expect("Failed to create SHM tensor");
1026
1027 let mut map = tensor.map().expect("Failed to map SHM tensor");
1029 map.as_mut_slice().fill(0xAB);
1030
1031 assert!(
1033 map.as_slice().iter().all(|&b| b == 0xAB),
1034 "SHM tensor data should be writable and readable"
1035 );
1036 }
1037}