Skip to main content

edgefirst_tensor/
mem.rs

// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    error::{Error, Result},
6    TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::trace;
9use num_traits::Num;
10use std::{
11    ffi::c_void,
12    fmt,
13    ops::{Deref, DerefMut},
14    ptr::NonNull,
15    sync::{Arc, Mutex},
16};
17#[derive(Debug)]
18pub struct MemTensor<T>
19where
20    T: Num + Clone + fmt::Debug + Send + Sync,
21{
22    pub name: String,
23    pub shape: Vec<usize>,
24    pub data: Vec<T>,
25    identity: crate::BufferIdentity,
26}
27
28unsafe impl<T> Send for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
29unsafe impl<T> Sync for MemTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
30impl<T> TensorTrait<T> for MemTensor<T>
31where
32    T: Num + Clone + fmt::Debug + Send + Sync,
33{
34    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
35        if shape.is_empty() {
36            return Err(Error::InvalidSize(0));
37        }
38
39        let size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
40        if size == 0 {
41            return Err(Error::InvalidSize(0));
42        }
43
44        let name = name.unwrap_or("mem_tensor").to_owned();
45        let data = vec![T::zero(); size / std::mem::size_of::<T>()];
46
47        Ok(MemTensor {
48            name,
49            shape: shape.to_vec(),
50            data,
51            identity: crate::BufferIdentity::new(),
52        })
53    }
54
55    #[cfg(unix)]
56    fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
57        Err(Error::NotImplemented(
58            "MemTensor does not support from_fd".to_owned(),
59        ))
60    }
61
62    #[cfg(unix)]
63    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
64        Err(Error::NotImplemented(
65            "MemTensor does not support clone_fd".to_owned(),
66        ))
67    }
68
69    fn memory(&self) -> TensorMemory {
70        TensorMemory::Mem
71    }
72
73    fn name(&self) -> String {
74        self.name.clone()
75    }
76
77    fn shape(&self) -> &[usize] {
78        &self.shape
79    }
80
81    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
82        if shape.is_empty() {
83            return Err(Error::InvalidSize(0));
84        }
85
86        let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
87        if new_size != self.size() {
88            return Err(Error::ShapeMismatch(format!(
89                "Cannot reshape incompatible shape: {:?} to {:?}",
90                self.shape, shape
91            )));
92        }
93
94        self.shape = shape.to_vec();
95        Ok(())
96    }
97
98    fn map(&self) -> Result<TensorMap<T>> {
99        let mem_ptr = MemPtr(
100            NonNull::new(self.data.as_ptr() as *mut c_void)
101                .ok_or(Error::InvalidSize(self.size()))?,
102        );
103        Ok(TensorMap::Mem(MemMap {
104            ptr: Arc::new(Mutex::new(mem_ptr)),
105            shape: self.shape.clone(),
106            _marker: std::marker::PhantomData,
107        }))
108    }
109
110    fn buffer_identity(&self) -> &crate::BufferIdentity {
111        &self.identity
112    }
113}
114
115#[derive(Debug)]
116pub struct MemMap<T>
117where
118    T: Num + Clone + fmt::Debug,
119{
120    ptr: Arc<Mutex<MemPtr>>,
121    shape: Vec<usize>,
122    _marker: std::marker::PhantomData<T>,
123}
124
125impl<T> Deref for MemMap<T>
126where
127    T: Num + Clone + fmt::Debug,
128{
129    type Target = [T];
130
131    fn deref(&self) -> &[T] {
132        self.as_slice()
133    }
134}
135
136impl<T> DerefMut for MemMap<T>
137where
138    T: Num + Clone + fmt::Debug,
139{
140    fn deref_mut(&mut self) -> &mut [T] {
141        self.as_mut_slice()
142    }
143}
144
/// Untyped, non-null pointer to the start of a tensor's element buffer.
///
/// In this file it is only constructed by `MemTensor::map`, from the
/// pointer of the tensor's `data` vector.
#[derive(Debug)]
struct MemPtr(NonNull<c_void>);

impl Deref for MemPtr {
    type Target = NonNull<c_void>;

    fn deref(&self) -> &NonNull<c_void> {
        let MemPtr(inner) = self;
        inner
    }
}

// SAFETY: `NonNull` is `!Send` only because the compiler cannot know what it
// points to. Here it points at a `MemTensor`'s `Vec<T>` buffer, and
// `MemTensor` requires `T: Send + Sync`. NOTE(review): soundness also relies
// on the owning tensor outliving every `MemMap` holding this pointer, which
// nothing in this file enforces.
unsafe impl Send for MemPtr {}
156
157impl<T> TensorMapTrait<T> for MemMap<T>
158where
159    T: Num + Clone + fmt::Debug,
160{
161    fn shape(&self) -> &[usize] {
162        &self.shape
163    }
164
165    fn unmap(&mut self) {
166        trace!("Unmapping MemMap memory");
167    }
168
169    fn as_slice(&self) -> &[T] {
170        let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
171        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
172    }
173
174    fn as_mut_slice(&mut self) -> &mut [T] {
175        let ptr = self.ptr.lock().expect("Failed to lock MemMap pointer");
176        unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
177    }
178}
179
180impl<T> Drop for MemMap<T>
181where
182    T: Num + Clone + fmt::Debug,
183{
184    fn drop(&mut self) {
185        self.unmap();
186    }
187}