1#[cfg(target_os = "linux")]
29mod dma;
30#[cfg(target_os = "linux")]
31mod dmabuf;
32mod error;
33mod mem;
34mod pbo;
35#[cfg(unix)]
36mod shm;
37
38#[cfg(target_os = "linux")]
39pub use crate::dma::{DmaMap, DmaTensor};
40pub use crate::mem::{MemMap, MemTensor};
41pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
42#[cfg(unix)]
43pub use crate::shm::{ShmMap, ShmTensor};
44pub use error::{Error, Result};
45use num_traits::Num;
46#[cfg(unix)]
47use std::os::fd::OwnedFd;
48use std::{
49 fmt,
50 ops::{Deref, DerefMut},
51 sync::{
52 atomic::{AtomicU64, Ordering},
53 Arc, Weak,
54 },
55};
56
// Monotonically increasing source of process-unique buffer ids; starts at 1,
// so 0 is never handed out.
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);

/// Identity of a tensor's backing buffer: a process-unique `id` paired with
/// an `Arc<()>` liveness guard. `Clone` shares both fields, so clones report
/// the same id and keep `weak()` handles upgradeable until the last clone is
/// dropped (see `test_buffer_identity_clone_shares_guard`).
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    // Process-unique id taken from NEXT_BUFFER_ID.
    id: u64,
    // Shared token whose strong count tracks live clones of this identity.
    guard: Arc<()>,
}
70
71impl BufferIdentity {
72 pub fn new() -> Self {
74 Self {
75 id: NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed),
76 guard: Arc::new(()),
77 }
78 }
79
80 pub fn id(&self) -> u64 {
82 self.id
83 }
84
85 pub fn weak(&self) -> Weak<()> {
88 Arc::downgrade(&self.guard)
89 }
90}
91
impl Default for BufferIdentity {
    /// Same as [`BufferIdentity::new`]: a fresh, unique identity.
    fn default() -> Self {
        Self::new()
    }
}
97
98#[cfg(target_os = "linux")]
99use nix::sys::stat::{major, minor};
100
/// Common interface implemented by every tensor backend (DMA, SHM, heap,
/// PBO) and by the dispatching [`Tensor`] enum itself.
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Allocate a new tensor with the given shape and optional debug name.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Adopt an existing file descriptor as this tensor's backing buffer,
    /// taking ownership of `fd`.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;

    /// Duplicate the backing buffer's file descriptor so the buffer can be
    /// shared with another tensor or process.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;

    /// Which memory backend backs this tensor.
    fn memory(&self) -> TensorMemory;

    /// The tensor's debug name.
    fn name(&self) -> String;

    /// Number of elements: the product of all shape dimensions.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the tensor holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Buffer size in bytes: `len() * size_of::<T>()`.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// The tensor's dimensions.
    fn shape(&self) -> &[usize];

    /// Change the shape in place. Implementations reject a shape whose
    /// element count differs from the current one (exercised by the tests
    /// in this file).
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;

    /// Map the buffer into CPU-addressable memory for reading and writing.
    fn map(&self) -> Result<TensorMap<T>>;

    /// Identity of the underlying buffer (see [`BufferIdentity`]).
    fn buffer_identity(&self) -> &BufferIdentity;
}
160
/// Interface of a mapped (CPU-addressable) view over a tensor's buffer.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Dimensions of the mapped tensor.
    fn shape(&self) -> &[usize];

    /// Explicitly release the mapping.
    fn unmap(&mut self);

    /// Number of mapped elements: the product of all shape dimensions.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// True when the mapping holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Mapped size in bytes: `len() * size_of::<T>()`.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    /// Borrow the mapped elements as an immutable slice.
    fn as_slice(&self) -> &[T];

    /// Borrow the mapped elements as a mutable slice.
    fn as_mut_slice(&mut self) -> &mut [T];

    /// Immutable `ndarray` view over the mapped data, shaped like
    /// [`TensorMapTrait::shape`].
    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }

    /// Mutable `ndarray` view over the mapped data.
    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Copy the shape first so the immutable borrow of `self` ends
        // before `as_mut_slice` takes the mutable borrow.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
213
/// The memory backend backing a [`Tensor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// Linux DMA-BUF heap memory (allocated from `/dev/dma_heap/*`, per the
    /// tests in this file).
    #[cfg(target_os = "linux")]
    Dma,
    /// POSIX shared memory.
    #[cfg(unix)]
    Shm,

    /// Plain process heap memory.
    Mem,

    /// PBO-backed memory; cannot be created through [`Tensor::new`] (see
    /// the `NotImplemented` arm there).
    Pbo,
}
233
234impl From<TensorMemory> for String {
235 fn from(memory: TensorMemory) -> Self {
236 match memory {
237 #[cfg(target_os = "linux")]
238 TensorMemory::Dma => "dma".to_owned(),
239 #[cfg(unix)]
240 TensorMemory::Shm => "shm".to_owned(),
241 TensorMemory::Mem => "mem".to_owned(),
242 TensorMemory::Pbo => "pbo".to_owned(),
243 }
244 }
245}
246
247impl TryFrom<&str> for TensorMemory {
248 type Error = Error;
249
250 fn try_from(s: &str) -> Result<Self> {
251 match s {
252 #[cfg(target_os = "linux")]
253 "dma" => Ok(TensorMemory::Dma),
254 #[cfg(unix)]
255 "shm" => Ok(TensorMemory::Shm),
256 "mem" => Ok(TensorMemory::Mem),
257 "pbo" => Ok(TensorMemory::Pbo),
258 _ => Err(Error::InvalidMemoryType(s.to_owned())),
259 }
260 }
261}
262
/// Backend-dispatching tensor: wraps one concrete tensor type and forwards
/// the [`TensorTrait`] API to it.
#[derive(Debug)]
pub enum Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// DMA-BUF heap backed tensor (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    /// POSIX shared-memory backed tensor.
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    /// Process-heap backed tensor.
    Mem(MemTensor<T>),
    /// PBO backed tensor.
    Pbo(PboTensor<T>),
}
275
impl<T> Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Create a tensor in the requested backend, or auto-select one when
    /// `memory` is `None`.
    ///
    /// Auto-selection order: Mem when the `EDGEFIRST_TENSOR_FORCE_MEM`
    /// env var is set (to anything other than "0"/"false"); otherwise
    /// DMA -> SHM -> Mem on Linux, SHM -> Mem on other Unix, Mem elsewhere.
    /// Explicitly requesting `Pbo` is rejected with `NotImplemented`.
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => DmaTensor::<T>::new(shape, name).map(Tensor::Dma),
            #[cfg(unix)]
            Some(TensorMemory::Shm) => ShmTensor::<T>::new(shape, name).map(Tensor::Shm),
            Some(TensorMemory::Mem) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                // Escape hatch: force plain heap memory regardless of what
                // the platform could provide.
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(Tensor::Mem)
                } else {
                    // Exactly one of the cfg blocks below survives
                    // compilation and is the `else` arm's value.
                    #[cfg(target_os = "linux")]
                    {
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(Tensor::Dma(tensor)),
                            Err(_) => match ShmTensor::<T>::new(shape, name).map(Tensor::Shm) {
                                Ok(tensor) => Ok(tensor),
                                Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
                            },
                        }
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(Tensor::Shm(tensor)),
                            Err(_) => MemTensor::<T>::new(shape, name).map(Tensor::Mem),
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        MemTensor::<T>::new(shape, name).map(Tensor::Mem)
                    }
                }
            }
        }
    }
}
349
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Backend auto-selection; see [`Tensor::new`] for the fallback order.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::new(shape, None, name)
    }

    /// Wrap an existing fd, inferring the backend from the fd itself.
    ///
    /// On Linux the fd's `st_dev` major/minor numbers decide the backend;
    /// on other Unix platforms the fd is always treated as shared memory.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;

            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);

            log::debug!("Creating tensor from fd: major={major}, minor={minor}");

            // NOTE(review): assumes anonymous-memory fds report device
            // major 0; anything else is rejected as unknown — confirm
            // against the target kernels.
            if major != 0 {
                return Err(Error::UnknownDeviceType(major, minor));
            }

            match minor {
                // NOTE(review): minors 9 and 10 are presumably the DMA-BUF
                // device nodes on the supported kernels — TODO confirm.
                9 | 10 => {
                    DmaTensor::<T>::from_fd(fd, shape, name).map(Tensor::Dma)
                }
                _ => {
                    ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            ShmTensor::<T>::from_fd(fd, shape, name).map(Tensor::Shm)
        }
    }

    /// Duplicate the backing buffer's fd via the active backend.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.clone_fd(),
            Tensor::Shm(t) => t.clone_fd(),
            Tensor::Mem(t) => t.clone_fd(),
            Tensor::Pbo(t) => t.clone_fd(),
        }
    }

    /// Report which backend variant this tensor wraps.
    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            Tensor::Shm(_) => TensorMemory::Shm,
            Tensor::Mem(_) => TensorMemory::Mem,
            Tensor::Pbo(_) => TensorMemory::Pbo,
        }
    }

    /// Forward to the active backend's name.
    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.name(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.name(),
            Tensor::Mem(t) => t.name(),
            Tensor::Pbo(t) => t.name(),
        }
    }

    /// Forward to the active backend's shape.
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.shape(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.shape(),
            Tensor::Mem(t) => t.shape(),
            Tensor::Pbo(t) => t.shape(),
        }
    }

    /// Forward to the active backend's reshape.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            Tensor::Shm(t) => t.reshape(shape),
            Tensor::Mem(t) => t.reshape(shape),
            Tensor::Pbo(t) => t.reshape(shape),
        }
    }

    /// Map the buffer through the active backend.
    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.map(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.map(),
            Tensor::Mem(t) => t.map(),
            Tensor::Pbo(t) => t.map(),
        }
    }

    /// Forward to the active backend's buffer identity.
    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            Tensor::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            Tensor::Shm(t) => t.buffer_identity(),
            Tensor::Mem(t) => t.buffer_identity(),
            Tensor::Pbo(t) => t.buffer_identity(),
        }
    }
}
476
/// Mapped (CPU-addressable) view over a [`Tensor`], wrapping the active
/// backend's own map type.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Mapping of a DMA-BUF tensor (Linux only).
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    /// Mapping of a shared-memory tensor.
    #[cfg(unix)]
    Shm(ShmMap<T>),
    /// Mapping of a heap tensor.
    Mem(MemMap<T>),
    /// Mapping of a PBO tensor.
    Pbo(PboMap<T>),
}
488
489impl<T> TensorMapTrait<T> for TensorMap<T>
490where
491 T: Num + Clone + fmt::Debug,
492{
493 fn shape(&self) -> &[usize] {
494 match self {
495 #[cfg(target_os = "linux")]
496 TensorMap::Dma(map) => map.shape(),
497 #[cfg(unix)]
498 TensorMap::Shm(map) => map.shape(),
499 TensorMap::Mem(map) => map.shape(),
500 TensorMap::Pbo(map) => map.shape(),
501 }
502 }
503
504 fn unmap(&mut self) {
505 match self {
506 #[cfg(target_os = "linux")]
507 TensorMap::Dma(map) => map.unmap(),
508 #[cfg(unix)]
509 TensorMap::Shm(map) => map.unmap(),
510 TensorMap::Mem(map) => map.unmap(),
511 TensorMap::Pbo(map) => map.unmap(),
512 }
513 }
514
515 fn as_slice(&self) -> &[T] {
516 match self {
517 #[cfg(target_os = "linux")]
518 TensorMap::Dma(map) => map.as_slice(),
519 #[cfg(unix)]
520 TensorMap::Shm(map) => map.as_slice(),
521 TensorMap::Mem(map) => map.as_slice(),
522 TensorMap::Pbo(map) => map.as_slice(),
523 }
524 }
525
526 fn as_mut_slice(&mut self) -> &mut [T] {
527 match self {
528 #[cfg(target_os = "linux")]
529 TensorMap::Dma(map) => map.as_mut_slice(),
530 #[cfg(unix)]
531 TensorMap::Shm(map) => map.as_mut_slice(),
532 TensorMap::Mem(map) => map.as_mut_slice(),
533 TensorMap::Pbo(map) => map.as_mut_slice(),
534 }
535 }
536}
537
538impl<T> Deref for TensorMap<T>
539where
540 T: Num + Clone + fmt::Debug,
541{
542 type Target = [T];
543
544 fn deref(&self) -> &[T] {
545 match self {
546 #[cfg(target_os = "linux")]
547 TensorMap::Dma(map) => map.deref(),
548 #[cfg(unix)]
549 TensorMap::Shm(map) => map.deref(),
550 TensorMap::Mem(map) => map.deref(),
551 TensorMap::Pbo(map) => map.deref(),
552 }
553 }
554}
555
556impl<T> DerefMut for TensorMap<T>
557where
558 T: Num + Clone + fmt::Debug,
559{
560 fn deref_mut(&mut self) -> &mut [T] {
561 match self {
562 #[cfg(target_os = "linux")]
563 TensorMap::Dma(map) => map.deref_mut(),
564 #[cfg(unix)]
565 TensorMap::Shm(map) => map.deref_mut(),
566 TensorMap::Mem(map) => map.deref_mut(),
567 TensorMap::Pbo(map) => map.deref_mut(),
568 }
569 }
570}
571
// Cached result of the one-time DMA availability probe.
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Whether DMA-heap allocation works on this system.
///
/// Probed once by attempting a small (64-byte) DMA tensor allocation; the
/// result is cached for the lifetime of the process.
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
    *DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}
591
/// DMA-BUF heaps are Linux-only, so this is always `false` elsewhere.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
    false
}
599
// Cached result of the one-time SHM availability probe.
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Whether POSIX shared-memory allocation works on this system.
///
/// Probed once by attempting a small (64-byte) SHM tensor allocation; the
/// result is cached for the lifetime of the process.
#[cfg(unix)]
pub fn is_shm_available() -> bool {
    *SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}
614
/// POSIX shared memory is Unix-only, so this is always `false` elsewhere.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
    false
}
622
623#[cfg(test)]
624mod tests {
625 #[cfg(target_os = "linux")]
626 use nix::unistd::{access, AccessFlags};
627 #[cfg(target_os = "linux")]
628 use std::io::Write as _;
629 use std::sync::RwLock;
630
631 use super::*;
632
    /// Runs before any test: install an env_logger that defaults to `info`.
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }
637
    /// Expands to the name of the enclosing function (used in test skip
    /// messages). It takes the type name of a nested `fn f` — e.g.
    /// `crate::tests::foo::f` — strips the trailing `::f`, and keeps only
    /// the last path segment.
    #[cfg(target_os = "linux")]
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);

            // len() - 3 drops the "::f" suffix; rfind(':') locates the last
            // path separator before the function's own name.
            match &name[..name.len() - 3].rfind(':') {
                Some(pos) => &name[pos + 1..name.len() - 3],
                None => &name[..name.len() - 3],
            }
        }};
    }
655
    /// Default-backend selection on Linux: DMA when available, else SHM.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![1];
        // Probe whether DMA allocation works in this environment.
        let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
        let dma_enabled = tensor.is_ok();

        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        match dma_enabled {
            true => assert_eq!(tensor.memory(), TensorMemory::Dma),
            // NOTE(review): Tensor::new can also fall back to Mem when SHM
            // fails; this arm assumes SHM always works on Linux — confirm.
            false => assert_eq!(tensor.memory(), TensorMemory::Shm),
        }
    }
670
    /// Default-backend selection on non-Linux Unix: SHM, or Mem as fallback.
    #[test]
    #[cfg(all(unix, not(target_os = "linux")))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert!(
            tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
            "Expected SHM or Mem on macOS, got {:?}",
            tensor.memory()
        );
    }
683
    /// Default-backend selection on non-Unix platforms: always Mem.
    #[test]
    #[cfg(not(unix))]
    fn test_tensor() {
        let shape = vec![1];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
        assert_eq!(tensor.memory(), TensorMemory::Mem);
    }
691
    /// End-to-end exercise of the DMA backend: allocation, metadata,
    /// mapping, fd sharing, reshape rules, and element access. Returns
    /// early (skips) when no DMA heap device is accessible.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        // Require an accessible DMA heap (CMA or system); otherwise warn and skip.
        match access(
            "/dev/dma_heap/linux,cma",
            AccessFlags::R_OK | AccessFlags::W_OK,
        ) {
            Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
            Err(_) => match access(
                "/dev/dma_heap/system",
                AccessFlags::R_OK | AccessFlags::W_OK,
            ) {
                Ok(_) => println!("/dev/dma_heap/system is available"),
                Err(e) => {
                    writeln!(
                        &mut std::io::stdout(),
                        "[WARNING] DMA Heap is unavailable: {e}"
                    )
                    .unwrap();
                    return;
                }
            },
        }

        let shape = vec![2, 3, 4];
        let tensor =
            DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");

        const DUMMY_VALUE: f32 = 12.34;

        assert_eq!(tensor.memory(), TensorMemory::Dma);
        assert_eq!(tensor.name(), "test_tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.len(), 2 * 3 * 4);

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd and write through the clone.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Dma);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes through the shared mapping must be visible in the original.
        {
            let tensor_map = tensor.map().expect("Failed to map DMA memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: rejected on element-count mismatch, accepted otherwise.
        let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Writes persist across successive map() calls on the same tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
789
    /// End-to-end exercise of the SHM backend: allocation, metadata,
    /// mapping, fd sharing, reshape rules, and element access.
    #[test]
    #[cfg(unix)]
    fn test_shm_tensor() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor =
            ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        const DUMMY_VALUE: f32 = 12.34;
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Share the buffer via a duplicated fd and write through the clone.
        {
            let shared = Tensor::<f32>::from_fd(
                tensor
                    .clone_fd()
                    .expect("Failed to duplicate tensor file descriptor"),
                &shape,
                Some("test_tensor_shared"),
            )
            .expect("Failed to create tensor from fd");

            assert_eq!(shared.memory(), TensorMemory::Shm);
            assert_eq!(shared.name(), "test_tensor_shared");
            assert_eq!(shared.shape(), &shape);

            let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
            tensor_map.fill(DUMMY_VALUE);
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Writes through the shared mapping must be visible in the original.
        {
            let tensor_map = tensor.map().expect("Failed to map shared memory");
            assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
        }

        // Reshape: rejected on element-count mismatch, accepted otherwise.
        let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Writes persist across successive map() calls on the same tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map shared memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
862
    /// End-to-end exercise of the heap (Mem) backend: allocation, metadata,
    /// mapping, reshape rules, and element access.
    #[test]
    fn test_mem_tensor() {
        let shape = vec![2, 3, 4];
        let tensor =
            MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
        assert_eq!(tensor.name(), "test_tensor");

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(42.0);
            assert!(tensor_map.iter().all(|&x| x == 42.0));
        }

        // Reshape: rejected on element-count mismatch, accepted otherwise.
        let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
        assert_eq!(tensor.shape(), &shape);
        let new_shape = vec![3, 4, 4];
        assert!(
            tensor.reshape(&new_shape).is_err(),
            "Reshape should fail due to size mismatch"
        );
        assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");

        let new_shape = vec![2, 3, 4];
        tensor.reshape(&new_shape).expect("Reshape should succeed");
        assert_eq!(
            tensor.shape(),
            &new_shape,
            "Shape should be updated after successful reshape"
        );

        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map.fill(1);
            assert!(tensor_map.iter().all(|&x| x == 1));
        }

        // Writes persist across successive map() calls on the same tensor.
        {
            let mut tensor_map = tensor.map().expect("Failed to map memory");
            tensor_map[2] = 42;
            assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
            assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
        }
    }
908
    /// Allocating and mapping many DMA tensors must not leak file
    /// descriptors; compares /proc fd counts before and after. Takes the
    /// FD_LOCK writer side so no other test perturbs the count.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, writes, and drops a DMA tensor.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
945
    /// Repeatedly cloning a DMA tensor's fd and adopting it via from_fd
    /// must not leak file descriptors.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_dma_from_fd_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_dma_available() {
            log::warn!(
                "SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();

        // Each iteration duplicates the fd, wraps it, maps it, and drops it.
        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the original too so its fd is released before the recount.
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
983
    /// Allocating and mapping many SHM tensors must not leak file
    /// descriptors; compares /proc fd counts before and after.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        // Each iteration allocates, maps, writes, and drops an SHM tensor.
        for _ in 0..100 {
            let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
                .expect("Failed to create tensor");
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }

        let end_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1020
    /// Repeatedly cloning an SHM tensor's fd and adopting it via from_fd
    /// must not leak file descriptors.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_shm_from_fd_no_fd_leaks() {
        let _lock = FD_LOCK.write().unwrap();
        if !is_shm_available() {
            log::warn!(
                "SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
                function!()
            );
            return;
        }

        let proc = procfs::process::Process::myself()
            .expect("Failed to get current process using /proc/self");

        let start_open_fds = proc
            .fd_count()
            .expect("Failed to get open file descriptor count");

        let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();

        // Each iteration duplicates the fd, wraps it, maps it, and drops it.
        for _ in 0..100 {
            let tensor =
                Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
            let mut map = tensor.map().unwrap();
            map.as_mut_slice().fill(233);
        }
        // Drop the original too so its fd is released before the recount.
        drop(orig);

        let end_open_fds = proc.fd_count().unwrap();

        assert_eq!(
            start_open_fds, end_open_fds,
            "File descriptor leak detected: {} -> {}",
            start_open_fds, end_open_fds
        );
    }
1058
    /// ndarray integration: view() and view_mut() expose the mapped buffer
    /// with the tensor's shape, and writes through the view hit the buffer.
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray() {
        let _lock = FD_LOCK.read().unwrap();
        let shape = vec![2, 3, 4];
        let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");

        let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
        tensor_map.fill(1.0);

        let view = tensor_map.view().expect("Failed to get ndarray view");
        assert_eq!(view.shape(), &[2, 3, 4]);
        assert!(view.iter().all(|&x| x == 1.0));

        let mut view_mut = tensor_map
            .view_mut()
            .expect("Failed to get mutable ndarray view");
        view_mut[[0, 0, 0]] = 42.0;
        assert_eq!(view_mut[[0, 0, 0]], 42.0);
        // The write through the ndarray view is visible via plain indexing.
        assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
    }
1080
1081 #[test]
1082 fn test_buffer_identity_unique() {
1083 let id1 = BufferIdentity::new();
1084 let id2 = BufferIdentity::new();
1085 assert_ne!(
1086 id1.id(),
1087 id2.id(),
1088 "Two identities should have different ids"
1089 );
1090 }
1091
    /// Clones of a BufferIdentity share the same id and the same liveness
    /// guard: the weak handle stays upgradeable until every clone is gone.
    #[test]
    fn test_buffer_identity_clone_shares_guard() {
        let id1 = BufferIdentity::new();
        let weak = id1.weak();
        assert!(
            weak.upgrade().is_some(),
            "Weak should be alive while original exists"
        );

        let id2 = id1.clone();
        assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");

        // Dropping only the original must not kill the guard.
        drop(id1);
        assert!(
            weak.upgrade().is_some(),
            "Weak should still be alive (clone holds Arc)"
        );

        drop(id2);
        assert!(
            weak.upgrade().is_none(),
            "Weak should be dead after all clones dropped"
        );
    }
1116
1117 #[test]
1118 fn test_tensor_buffer_identity() {
1119 let t1 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
1120 let t2 = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
1121 assert_ne!(
1122 t1.buffer_identity().id(),
1123 t2.buffer_identity().id(),
1124 "Different tensors should have different buffer ids"
1125 );
1126 }
1127
    // Serializes tests that touch file descriptors: ordinary tests take the
    // read side, while the fd-leak tests take the write side so nothing else
    // can perturb the process fd count while they measure it.
    pub static FD_LOCK: RwLock<()> = RwLock::new(());
1132
    /// The non-Linux stub of is_dma_available() must report false.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_dma_not_available_on_non_linux() {
        assert!(
            !is_dma_available(),
            "DMA memory allocation should NOT be available on non-Linux platforms"
        );
    }
1143
    /// SHM must be reported available on Unix, and an SHM tensor must be
    /// writable and readable through its mapping.
    #[test]
    #[cfg(unix)]
    fn test_shm_available_and_usable() {
        assert!(
            is_shm_available(),
            "SHM memory allocation should be available on Unix systems"
        );

        let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
            .expect("Failed to create SHM tensor");

        let mut map = tensor.map().expect("Failed to map SHM tensor");
        map.as_mut_slice().fill(0xAB);

        assert!(
            map.as_slice().iter().all(|&b| b == 0xAB),
            "SHM tensor data should be writable and readable"
        );
    }
1168}