Skip to main content

edgefirst_tensor/
mem.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::trace;
9use num_traits::Num;
10use std::{
11    ffi::c_void,
12    fmt,
13    ops::{Deref, DerefMut},
14    ptr::NonNull,
15    sync::{Arc, Mutex},
16};
17#[derive(Debug)]
18pub struct MemTensor<T>
19where
20    T: Num + Clone + fmt::Debug + Send + Sync,
21{
22    pub name: String,
23    pub shape: Vec<usize>,
24    pub data: Vec<T>,
25}
26
27unsafe impl<T> Send for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
28unsafe impl<T> Sync for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
29impl<T> TensorTrait<T> for MemTensor<T>
30where
31    T: Num + Clone + fmt::Debug + Send + Sync,
32{
33    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
34        if shape.is_empty() {
35            return Err(Error::InvalidSize(0));
36        }
37
38        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
39        if size == 0 {
40            return Err(Error::InvalidSize(0));
41        }
42
43        let name = name.unwrap_or("mem_tensor").to_owned();
44        let data = vec![T::zero(); size / std::mem::size_of::<T>()];
45
46        Ok(MemTensor {
47            name,
48            shape: shape.to_vec(),
49            data,
50        })
51    }
52
53    #[cfg(unix)]
54    fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
55        Err(Error::NotImplemented(
56            "MemTensor does not support from_fd".to_owned(),
57        ))
58    }
59
60    #[cfg(unix)]
61    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
62        Err(Error::NotImplemented(
63            "MemTensor does not support clone_fd".to_owned(),
64        ))
65    }
66
67    fn memory(&self) -> TensorMemory {
68        TensorMemory::Mem
69    }
70
71    fn name(&self) -> String {
72        self.name.clone()
73    }
74
75    fn shape(&self) -> &[usize] {
76        &self.shape
77    }
78
79    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
80        if shape.is_empty() {
81            return Err(Error::InvalidSize(0));
82        }
83
84        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
85        if new_size != self.size() {
86            return Err(Error::ShapeMismatch(format!(
87                "Cannot reshape incompatible shape: {:?} to {:?}",
88                self.shape, shape
89            )));
90        }
91
92        self.shape = shape.to_vec();
93        Ok(())
94    }
95
96    fn map(&self) -> Result<TensorMap<T>> {
97        let mem_ptr = MemPtr(
98            NonNull::new(self.data.as_ptr() as *mut c_void)
99                .ok_or(Error::InvalidSize(self.size()))?,
100        );
101        Ok(TensorMap::Mem(MemMap {
102            ptr: Arc::new(Mutex::new(mem_ptr)),
103            shape: self.shape.clone(),
104            _marker: std::marker::PhantomData,
105        }))
106    }
107}
108
109#[derive(Debug)]
110pub struct MemMap<T>
111where
112    T: Num + Clone + fmt::Debug,
113{
114    ptr: Arc<Mutex<MemPtr>>,
115    shape: Vec<usize>,
116    _marker: std::marker::PhantomData<T>,
117}
118
119impl<T> Deref for MemMap<T>
120where
121    T: Num + Clone + fmt::Debug,
122{
123    type Target = [T];
124
125    fn deref(&self) -> &[T] {
126        self.as_slice()
127    }
128}
129
130impl<T> DerefMut for MemMap<T>
131where
132    T: Num + Clone + fmt::Debug,
133{
134    fn deref_mut(&mut self) -> &mut [T] {
135        self.as_mut_slice()
136    }
137}
138
/// Non-null, type-erased pointer to the start of a tensor's element buffer.
#[derive(Debug)]
struct MemPtr(NonNull<c_void>);

/// Exposes the wrapped `NonNull` so callers can use its methods directly.
impl Deref for MemPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let MemPtr(inner) = self;
        inner
    }
}

// SAFETY: `MemPtr` is only an address; moving it to another thread performs no
// memory access by itself. NOTE(review): the soundness of accesses made through
// it rests on `MemMap` and on the owning tensor outliving the map, neither of
// which this type enforces.
unsafe impl Send for MemPtr {}

151impl<T> TensorMapTrait<T> for MemMap<T>
152where
153    T: Num + Clone + fmt::Debug,
154{
155    fn shape(&self) -> &[usize] {
156        &self.shape
157    }
158
159    fn unmap(&mut self) {
160        trace!("Unmapping MemMap memory");
161    }
162
163    fn as_slice(&self) -> &[T] {
164        let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
165        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
166    }
167
168    fn as_mut_slice(&mut self) -> &mut [T] {
169        let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
170        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
171    }
172}
173
174impl<T> Drop for MemMap<T>
175where
176    T: Num + Clone + fmt::Debug,
177{
178    fn drop(&mut self) {
179        self.unmap();
180    }
181}