Skip to main content

mnn_rs/
tensor.rs

1//! Tensor types for MNN inference.
2//!
3//! This module provides safe wrappers around MNN tensor operations,
4//! including creating, reading, and writing tensor data.
5
6use crate::backend::{DataType, BackendType};
7use crate::config::DataFormat;
8use crate::error::{MnnError, MnnResult};
9use mnn_rs_sys::MNNTensor;
10use std::ffi::c_void;
11use std::marker::PhantomData;
12
13/// Information about a tensor's shape and type.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct TensorInfo {
16    /// Name of the tensor (may be empty)
17    pub name: String,
18
19    /// Shape of the tensor (dimensions)
20    pub shape: Vec<i32>,
21
22    /// Data type of tensor elements
23    pub dtype: DataType,
24
25    /// Data format (layout)
26    pub format: DataFormat,
27}
28
29impl TensorInfo {
30    /// Get the total number of elements in the tensor.
31    pub fn element_count(&self) -> i32 {
32        self.shape.iter().product()
33    }
34
35    /// Get the size in bytes of the tensor data.
36    pub fn byte_size(&self) -> usize {
37        self.element_count() as usize * self.dtype.size()
38    }
39}
40
41/// A multi-dimensional array for neural network operations.
42///
43/// Tensors are the primary data structure for MNN inference,
44/// holding input and output data for models.
45pub struct Tensor {
46    inner: *mut MNNTensor,
47    /// Name of the tensor (if any)
48    name: Option<String>,
49}
50
51// Safety: Tensor operations are thread-safe through MNN's internal synchronization
52unsafe impl Send for Tensor {}
53unsafe impl Sync for Tensor {}
54
55impl Tensor {
56    /// Create a tensor wrapper around an existing MNN tensor pointer.
57    ///
58    /// # Safety
59    /// The pointer must be valid and remain valid for the lifetime of this tensor.
60    pub(crate) unsafe fn from_ptr_with_name(ptr: *mut MNNTensor, name: Option<String>) -> Self {
61        Self { inner: ptr, name }
62    }
63
64    /// Create a tensor wrapper around an existing MNN tensor pointer (public version).
65    ///
66    /// # Safety
67    /// The pointer must be valid and remain valid for the lifetime of this tensor.
68    pub unsafe fn from_ptr(ptr: *mut MNNTensor, name: Option<String>) -> Self {
69        Self { inner: ptr, name }
70    }
71
72    /// Get the mutable raw pointer to the underlying MNN tensor.
73    pub fn inner_mut(&mut self) -> *mut MNNTensor {
74        self.inner
75    }
76
77    /// Get the raw pointer to the underlying MNN tensor.
78    pub fn as_ptr(&self) -> *const MNNTensor {
79        self.inner
80    }
81
82    /// Get the shape of the tensor.
83    pub fn shape(&self) -> Vec<i32> {
84        unsafe {
85            let dim_count = mnn_rs_sys::mnn_tensor_get_dimensions(self.inner);
86            if dim_count <= 0 {
87                return Vec::new();
88            }
89
90            let mut shape = Vec::with_capacity(dim_count as usize);
91            for i in 0..dim_count {
92                let dim = mnn_rs_sys::mnn_tensor_get_dim(self.inner, i);
93                shape.push(dim);
94            }
95            shape
96        }
97    }
98
99    /// Get the number of dimensions.
100    pub fn ndim(&self) -> usize {
101        unsafe { mnn_rs_sys::mnn_tensor_get_dimensions(self.inner) as usize }
102    }
103
104    /// Get the size of a specific dimension.
105    ///
106    /// # Arguments
107    /// * `axis` - The dimension index (0-based)
108    ///
109    /// # Returns
110    /// The size of the dimension, or an error if the axis is out of bounds.
111    pub fn dim(&self, axis: usize) -> MnnResult<i32> {
112        let shape = self.shape();
113        if axis >= shape.len() {
114            return Err(MnnError::index_out_of_bounds(axis, 0, shape.len() as i32));
115        }
116        Ok(shape[axis])
117    }
118
119    /// Get the data type of the tensor.
120    pub fn dtype(&self) -> DataType {
121        unsafe {
122            let type_code = mnn_rs_sys::mnn_tensor_get_type_code(self.inner);
123            DataType::from_type_code(type_code)
124        }
125    }
126
127    /// Get the data format of the tensor.
128    pub fn format(&self) -> DataFormat {
129        unsafe {
130            let dim_type = mnn_rs_sys::mnn_tensor_get_dimension_type(self.inner);
131            match dim_type {
132                0 => DataFormat::Nhwc,
133                1 => DataFormat::Nc4hw4,
134                2 => DataFormat::Nchw,
135                _ => DataFormat::Nhwc,
136            }
137        }
138    }
139
140    /// Get the total number of elements in the tensor.
141    pub fn element_count(&self) -> i32 {
142        unsafe { mnn_rs_sys::mnn_tensor_get_element_count(self.inner) }
143    }
144
145    /// Get the size of the tensor data in bytes.
146    pub fn byte_size(&self) -> usize {
147        unsafe { mnn_rs_sys::mnn_tensor_get_size(self.inner) as usize }
148    }
149
150    /// Get the name of the tensor.
151    pub fn name(&self) -> Option<&str> {
152        self.name.as_deref()
153    }
154
155    /// Write data to the tensor.
156    ///
157    /// # Arguments
158    /// * `data` - The data to write
159    ///
160    /// # Errors
161    /// Returns an error if the data size doesn't match.
162    pub fn write<T: TensorData>(&self, data: &[T]) -> MnnResult<()> {
163        if data.is_empty() {
164            return Err(MnnError::EmptyData);
165        }
166
167        let expected_count = self.element_count() as usize;
168        if data.len() != expected_count {
169            return Err(MnnError::shape_mismatch(
170                &[expected_count as i32],
171                &[data.len() as i32],
172            ));
173        }
174
175        let host_data = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
176        if host_data.is_null() {
177            return Err(MnnError::tensor_error("Tensor has no host data"));
178        }
179
180        unsafe {
181            std::ptr::copy_nonoverlapping(
182                data.as_ptr() as *const c_void,
183                host_data,
184                data.len() * std::mem::size_of::<T>(),
185            );
186        }
187
188        Ok(())
189    }
190
191    /// Read data from the tensor.
192    ///
193    /// # Returns
194    /// A vector containing the tensor data.
195    pub fn read<T: TensorData>(&self) -> MnnResult<Vec<T>> {
196        let count = self.element_count() as usize;
197        let mut data = vec![T::default(); count];
198
199        let host_data = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
200        if host_data.is_null() {
201            return Err(MnnError::tensor_error("Tensor has no host data"));
202        }
203
204        unsafe {
205            std::ptr::copy_nonoverlapping(
206                host_data,
207                data.as_mut_ptr() as *mut c_void,
208                count * std::mem::size_of::<T>(),
209            );
210        }
211
212        Ok(data)
213    }
214
215    /// Get a mutable reference to the tensor's host data.
216    ///
217    /// # Safety
218    /// The returned slice is valid only as long as no other operations
219    /// are performed on the tensor.
220    pub unsafe fn as_slice_mut<T: TensorData>(&mut self) -> MnnResult<&mut [T]> {
221        let count = self.element_count() as usize;
222        let ptr = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
223
224        if ptr.is_null() {
225            return Err(MnnError::tensor_error("Tensor has no host data"));
226        }
227
228        Ok(unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) })
229    }
230
231    /// Get a reference to the tensor's host data.
232    ///
233    /// # Safety
234    /// The returned slice is valid only as long as no other operations
235    /// are performed on the tensor.
236    pub unsafe fn as_slice<T: TensorData>(&self) -> MnnResult<&[T]> {
237        let count = self.element_count() as usize;
238        let ptr = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
239
240        if ptr.is_null() {
241            return Err(MnnError::tensor_error("Tensor has no host data"));
242        }
243
244        Ok(unsafe { std::slice::from_raw_parts(ptr as *const T, count) })
245    }
246
247    /// Get the tensor info.
248    pub fn info(&self) -> TensorInfo {
249        TensorInfo {
250            name: self.name.clone().unwrap_or_default(),
251            shape: self.shape(),
252            dtype: self.dtype(),
253            format: self.format(),
254        }
255    }
256
257    // ========================================================================
258    // GPU Memory Operations
259    // ========================================================================
260
261    /// Copy data from a host tensor to this tensor (potentially on device).
262    ///
263    /// # Arguments
264    /// * `host_tensor` - The host tensor to copy from
265    ///
266    /// # Returns
267    /// Ok(()) on success, or an error on failure.
268    pub fn copy_from_host(&mut self, host_tensor: &Tensor) -> MnnResult<()> {
269        let result = unsafe {
270            mnn_rs_sys::mnn_tensor_copy_from_host(self.inner, host_tensor.inner)
271        };
272
273        if result != mnn_rs_sys::MNN_ERROR_NONE as i32 {
274            return Err(MnnError::internal("Failed to copy from host tensor"));
275        }
276
277        Ok(())
278    }
279
280    /// Copy data from this tensor (potentially on device) to a host tensor.
281    ///
282    /// # Arguments
283    /// * `host_tensor` - The host tensor to copy to
284    ///
285    /// # Returns
286    /// Ok(()) on success, or an error on failure.
287    pub fn copy_to_host(&self, host_tensor: &mut Tensor) -> MnnResult<()> {
288        let result = unsafe {
289            mnn_rs_sys::mnn_tensor_copy_to_host(host_tensor.inner, self.inner)
290        };
291
292        if result != mnn_rs_sys::MNN_ERROR_NONE as i32 {
293            return Err(MnnError::internal("Failed to copy to host tensor"));
294        }
295
296        Ok(())
297    }
298
299    /// Create a device tensor with the given shape and format.
300    ///
301    /// # Arguments
302    /// * `shape` - The tensor shape
303    /// * `format` - The data format (NHWC, NCHW, etc.)
304    /// * `dtype` - The data type
305    ///
306    /// # Returns
307    /// A new device tensor on success, or an error on failure.
308    pub fn create_device(
309        shape: &[i32],
310        format: DataFormat,
311        dtype: DataType,
312    ) -> MnnResult<Tensor> {
313        if shape.is_empty() {
314            return Err(MnnError::internal("Shape cannot be empty"));
315        }
316
317        let type_code = dtype.to_type_code();
318        let format_code = format.to_mnn();
319
320        let inner = unsafe {
321            mnn_rs_sys::mnn_tensor_create_device(
322                shape.as_ptr(),
323                shape.len() as i32,
324                type_code,
325                format_code,
326            )
327        };
328
329        if inner.is_null() {
330            return Err(MnnError::internal("Failed to create device tensor"));
331        }
332
333        Ok(unsafe { Tensor::from_ptr(inner, None) })
334    }
335
336    /// Clone this tensor.
337    ///
338    /// # Arguments
339    /// * `deep_copy` - If true, copy data; if false, only copy metadata
340    ///
341    /// # Returns
342    /// A cloned tensor on success, or an error on failure.
343    pub fn clone(&self, deep_copy: bool) -> MnnResult<Tensor> {
344        let inner = unsafe {
345            mnn_rs_sys::mnn_tensor_clone(self.inner, if deep_copy { 1 } else { 0 })
346        };
347
348        if inner.is_null() {
349            return Err(MnnError::internal("Failed to clone tensor"));
350        }
351
352        Ok(unsafe { Tensor::from_ptr(inner, None) })
353    }
354
355    /// Get the device ID for this tensor (for GPU tensors).
356    ///
357    /// # Returns
358    /// The device ID, or 0 if not a GPU tensor or unknown.
359    pub fn device_id(&self) -> u64 {
360        unsafe { mnn_rs_sys::mnn_tensor_device_id(self.inner) }
361    }
362
363    /// Get the backend type for this tensor.
364    ///
365    /// # Returns
366    /// The backend type (CPU, CUDA, OpenCL, etc.).
367    pub fn backend(&self) -> BackendType {
368        let backend_code = unsafe { mnn_rs_sys::mnn_tensor_get_backend(self.inner) };
369        BackendType::from_mnn_type(backend_code)
370    }
371}
372
373impl std::fmt::Debug for Tensor {
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        f.debug_struct("Tensor")
376            .field("shape", &self.shape())
377            .field("dtype", &self.dtype())
378            .field("format", &self.format())
379            .field("name", &self.name)
380            .finish()
381    }
382}
383
384/// Trait for types that can be stored in a tensor.
385///
386/// This trait is implemented for primitive numeric types that MNN supports.
387pub trait TensorData: Default + Clone + Copy + 'static {
388    /// Get the MNN data type for this Rust type.
389    fn dtype() -> DataType;
390}
391
392impl TensorData for f32 {
393    fn dtype() -> DataType {
394        DataType::Float32
395    }
396}
397
398impl TensorData for f64 {
399    fn dtype() -> DataType {
400        DataType::Float64
401    }
402}
403
404impl TensorData for i32 {
405    fn dtype() -> DataType {
406        DataType::Int32
407    }
408}
409
410impl TensorData for i16 {
411    fn dtype() -> DataType {
412        DataType::Int16
413    }
414}
415
416impl TensorData for u8 {
417    fn dtype() -> DataType {
418        DataType::UInt8
419    }
420}
421
422#[cfg(feature = "fp16")]
423impl TensorData for half::f16 {
424    fn dtype() -> DataType {
425        DataType::Float16
426    }
427}
428
429#[cfg(feature = "int8")]
430impl TensorData for i8 {
431    fn dtype() -> DataType {
432        DataType::Int8
433    }
434}
435
436/// A view into a tensor's data without ownership.
437///
438/// This is useful for zero-copy access to tensor data.
439pub struct TensorView<'a> {
440    inner: *mut MNNTensor,
441    _marker: PhantomData<&'a Tensor>,
442}
443
444impl std::fmt::Debug for TensorView<'_> {
445    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446        f.debug_struct("TensorView").finish_non_exhaustive()
447    }
448}
449
450impl<'a> TensorView<'a> {
451    /// Create a view from a tensor reference.
452    pub fn from_tensor(tensor: &'a Tensor) -> Self {
453        Self {
454            inner: tensor.inner,
455            _marker: PhantomData,
456        }
457    }
458
459    /// Get the shape of the tensor.
460    pub fn shape(&self) -> Vec<i32> {
461        unsafe {
462            let dim_count = mnn_rs_sys::mnn_tensor_get_dimensions(self.inner);
463            if dim_count <= 0 {
464                return Vec::new();
465            }
466
467            let mut shape = Vec::with_capacity(dim_count as usize);
468            for i in 0..dim_count {
469                let dim = mnn_rs_sys::mnn_tensor_get_dim(self.inner, i);
470                shape.push(dim);
471            }
472            shape
473        }
474    }
475
476    /// Get the data type.
477    pub fn dtype(&self) -> DataType {
478        DataType::Float32
479    }
480}
481
482impl TensorInfo {
483    /// Get the data type.
484    pub fn dtype(&self) -> DataType {
485        self.dtype
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Each primitive element type maps to the matching `DataType` variant.
    #[test]
    fn test_tensor_data_types() {
        assert_eq!(<f32 as TensorData>::dtype(), DataType::Float32);
        assert_eq!(<i32 as TensorData>::dtype(), DataType::Int32);
        assert_eq!(<u8 as TensorData>::dtype(), DataType::UInt8);
    }

    /// Element count and byte size follow from the shape and the element size.
    #[test]
    fn test_tensor_info() {
        let info = TensorInfo {
            name: String::from("test"),
            shape: vec![1, 3, 224, 224],
            dtype: DataType::Float32,
            format: DataFormat::Nchw,
        };

        // 1 * 3 * 224 * 224 elements, 4 bytes per Float32 element.
        assert_eq!(info.element_count(), 150_528);
        assert_eq!(info.byte_size(), 150_528 * 4);
    }
}
508
509        assert_eq!(info.element_count(), 1 * 3 * 224 * 224);
510        assert_eq!(info.byte_size(), 1 * 3 * 224 * 224 * 4);
511    }
512}