1use crate::backend::{DataType, BackendType};
7use crate::config::DataFormat;
8use crate::error::{MnnError, MnnResult};
9use mnn_rs_sys::MNNTensor;
10use std::ffi::c_void;
11use std::marker::PhantomData;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct TensorInfo {
16 pub name: String,
18
19 pub shape: Vec<i32>,
21
22 pub dtype: DataType,
24
25 pub format: DataFormat,
27}
28
29impl TensorInfo {
30 pub fn element_count(&self) -> i32 {
32 self.shape.iter().product()
33 }
34
35 pub fn byte_size(&self) -> usize {
37 self.element_count() as usize * self.dtype.size()
38 }
39}
40
41pub struct Tensor {
46 inner: *mut MNNTensor,
47 name: Option<String>,
49}
50
51unsafe impl Send for Tensor {}
53unsafe impl Sync for Tensor {}
54
55impl Tensor {
56 pub(crate) unsafe fn from_ptr_with_name(ptr: *mut MNNTensor, name: Option<String>) -> Self {
61 Self { inner: ptr, name }
62 }
63
64 pub unsafe fn from_ptr(ptr: *mut MNNTensor, name: Option<String>) -> Self {
69 Self { inner: ptr, name }
70 }
71
72 pub fn inner_mut(&mut self) -> *mut MNNTensor {
74 self.inner
75 }
76
77 pub fn as_ptr(&self) -> *const MNNTensor {
79 self.inner
80 }
81
82 pub fn shape(&self) -> Vec<i32> {
84 unsafe {
85 let dim_count = mnn_rs_sys::mnn_tensor_get_dimensions(self.inner);
86 if dim_count <= 0 {
87 return Vec::new();
88 }
89
90 let mut shape = Vec::with_capacity(dim_count as usize);
91 for i in 0..dim_count {
92 let dim = mnn_rs_sys::mnn_tensor_get_dim(self.inner, i);
93 shape.push(dim);
94 }
95 shape
96 }
97 }
98
99 pub fn ndim(&self) -> usize {
101 unsafe { mnn_rs_sys::mnn_tensor_get_dimensions(self.inner) as usize }
102 }
103
104 pub fn dim(&self, axis: usize) -> MnnResult<i32> {
112 let shape = self.shape();
113 if axis >= shape.len() {
114 return Err(MnnError::index_out_of_bounds(axis, 0, shape.len() as i32));
115 }
116 Ok(shape[axis])
117 }
118
119 pub fn dtype(&self) -> DataType {
121 unsafe {
122 let type_code = mnn_rs_sys::mnn_tensor_get_type_code(self.inner);
123 DataType::from_type_code(type_code)
124 }
125 }
126
127 pub fn format(&self) -> DataFormat {
129 unsafe {
130 let dim_type = mnn_rs_sys::mnn_tensor_get_dimension_type(self.inner);
131 match dim_type {
132 0 => DataFormat::Nhwc,
133 1 => DataFormat::Nc4hw4,
134 2 => DataFormat::Nchw,
135 _ => DataFormat::Nhwc,
136 }
137 }
138 }
139
140 pub fn element_count(&self) -> i32 {
142 unsafe { mnn_rs_sys::mnn_tensor_get_element_count(self.inner) }
143 }
144
145 pub fn byte_size(&self) -> usize {
147 unsafe { mnn_rs_sys::mnn_tensor_get_size(self.inner) as usize }
148 }
149
150 pub fn name(&self) -> Option<&str> {
152 self.name.as_deref()
153 }
154
155 pub fn write<T: TensorData>(&self, data: &[T]) -> MnnResult<()> {
163 if data.is_empty() {
164 return Err(MnnError::EmptyData);
165 }
166
167 let expected_count = self.element_count() as usize;
168 if data.len() != expected_count {
169 return Err(MnnError::shape_mismatch(
170 &[expected_count as i32],
171 &[data.len() as i32],
172 ));
173 }
174
175 let host_data = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
176 if host_data.is_null() {
177 return Err(MnnError::tensor_error("Tensor has no host data"));
178 }
179
180 unsafe {
181 std::ptr::copy_nonoverlapping(
182 data.as_ptr() as *const c_void,
183 host_data,
184 data.len() * std::mem::size_of::<T>(),
185 );
186 }
187
188 Ok(())
189 }
190
191 pub fn read<T: TensorData>(&self) -> MnnResult<Vec<T>> {
196 let count = self.element_count() as usize;
197 let mut data = vec![T::default(); count];
198
199 let host_data = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
200 if host_data.is_null() {
201 return Err(MnnError::tensor_error("Tensor has no host data"));
202 }
203
204 unsafe {
205 std::ptr::copy_nonoverlapping(
206 host_data,
207 data.as_mut_ptr() as *mut c_void,
208 count * std::mem::size_of::<T>(),
209 );
210 }
211
212 Ok(data)
213 }
214
215 pub unsafe fn as_slice_mut<T: TensorData>(&mut self) -> MnnResult<&mut [T]> {
221 let count = self.element_count() as usize;
222 let ptr = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
223
224 if ptr.is_null() {
225 return Err(MnnError::tensor_error("Tensor has no host data"));
226 }
227
228 Ok(unsafe { std::slice::from_raw_parts_mut(ptr as *mut T, count) })
229 }
230
231 pub unsafe fn as_slice<T: TensorData>(&self) -> MnnResult<&[T]> {
237 let count = self.element_count() as usize;
238 let ptr = unsafe { mnn_rs_sys::mnn_tensor_get_host_data(self.inner) };
239
240 if ptr.is_null() {
241 return Err(MnnError::tensor_error("Tensor has no host data"));
242 }
243
244 Ok(unsafe { std::slice::from_raw_parts(ptr as *const T, count) })
245 }
246
247 pub fn info(&self) -> TensorInfo {
249 TensorInfo {
250 name: self.name.clone().unwrap_or_default(),
251 shape: self.shape(),
252 dtype: self.dtype(),
253 format: self.format(),
254 }
255 }
256
257 pub fn copy_from_host(&mut self, host_tensor: &Tensor) -> MnnResult<()> {
269 let result = unsafe {
270 mnn_rs_sys::mnn_tensor_copy_from_host(self.inner, host_tensor.inner)
271 };
272
273 if result != mnn_rs_sys::MNN_ERROR_NONE as i32 {
274 return Err(MnnError::internal("Failed to copy from host tensor"));
275 }
276
277 Ok(())
278 }
279
280 pub fn copy_to_host(&self, host_tensor: &mut Tensor) -> MnnResult<()> {
288 let result = unsafe {
289 mnn_rs_sys::mnn_tensor_copy_to_host(host_tensor.inner, self.inner)
290 };
291
292 if result != mnn_rs_sys::MNN_ERROR_NONE as i32 {
293 return Err(MnnError::internal("Failed to copy to host tensor"));
294 }
295
296 Ok(())
297 }
298
299 pub fn create_device(
309 shape: &[i32],
310 format: DataFormat,
311 dtype: DataType,
312 ) -> MnnResult<Tensor> {
313 if shape.is_empty() {
314 return Err(MnnError::internal("Shape cannot be empty"));
315 }
316
317 let type_code = dtype.to_type_code();
318 let format_code = format.to_mnn();
319
320 let inner = unsafe {
321 mnn_rs_sys::mnn_tensor_create_device(
322 shape.as_ptr(),
323 shape.len() as i32,
324 type_code,
325 format_code,
326 )
327 };
328
329 if inner.is_null() {
330 return Err(MnnError::internal("Failed to create device tensor"));
331 }
332
333 Ok(unsafe { Tensor::from_ptr(inner, None) })
334 }
335
336 pub fn clone(&self, deep_copy: bool) -> MnnResult<Tensor> {
344 let inner = unsafe {
345 mnn_rs_sys::mnn_tensor_clone(self.inner, if deep_copy { 1 } else { 0 })
346 };
347
348 if inner.is_null() {
349 return Err(MnnError::internal("Failed to clone tensor"));
350 }
351
352 Ok(unsafe { Tensor::from_ptr(inner, None) })
353 }
354
355 pub fn device_id(&self) -> u64 {
360 unsafe { mnn_rs_sys::mnn_tensor_device_id(self.inner) }
361 }
362
363 pub fn backend(&self) -> BackendType {
368 let backend_code = unsafe { mnn_rs_sys::mnn_tensor_get_backend(self.inner) };
369 BackendType::from_mnn_type(backend_code)
370 }
371}
372
373impl std::fmt::Debug for Tensor {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 f.debug_struct("Tensor")
376 .field("shape", &self.shape())
377 .field("dtype", &self.dtype())
378 .field("format", &self.format())
379 .field("name", &self.name)
380 .finish()
381 }
382}
383
384pub trait TensorData: Default + Clone + Copy + 'static {
388 fn dtype() -> DataType;
390}
391
392impl TensorData for f32 {
393 fn dtype() -> DataType {
394 DataType::Float32
395 }
396}
397
398impl TensorData for f64 {
399 fn dtype() -> DataType {
400 DataType::Float64
401 }
402}
403
404impl TensorData for i32 {
405 fn dtype() -> DataType {
406 DataType::Int32
407 }
408}
409
410impl TensorData for i16 {
411 fn dtype() -> DataType {
412 DataType::Int16
413 }
414}
415
416impl TensorData for u8 {
417 fn dtype() -> DataType {
418 DataType::UInt8
419 }
420}
421
422#[cfg(feature = "fp16")]
423impl TensorData for half::f16 {
424 fn dtype() -> DataType {
425 DataType::Float16
426 }
427}
428
429#[cfg(feature = "int8")]
430impl TensorData for i8 {
431 fn dtype() -> DataType {
432 DataType::Int8
433 }
434}
435
436pub struct TensorView<'a> {
440 inner: *mut MNNTensor,
441 _marker: PhantomData<&'a Tensor>,
442}
443
444impl std::fmt::Debug for TensorView<'_> {
445 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446 f.debug_struct("TensorView").finish_non_exhaustive()
447 }
448}
449
450impl<'a> TensorView<'a> {
451 pub fn from_tensor(tensor: &'a Tensor) -> Self {
453 Self {
454 inner: tensor.inner,
455 _marker: PhantomData,
456 }
457 }
458
459 pub fn shape(&self) -> Vec<i32> {
461 unsafe {
462 let dim_count = mnn_rs_sys::mnn_tensor_get_dimensions(self.inner);
463 if dim_count <= 0 {
464 return Vec::new();
465 }
466
467 let mut shape = Vec::with_capacity(dim_count as usize);
468 for i in 0..dim_count {
469 let dim = mnn_rs_sys::mnn_tensor_get_dim(self.inner, i);
470 shape.push(dim);
471 }
472 shape
473 }
474 }
475
476 pub fn dtype(&self) -> DataType {
478 DataType::Float32
479 }
480}
481
482impl TensorInfo {
483 pub fn dtype(&self) -> DataType {
485 self.dtype
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Each primitive reports its matching `DataType` variant.
    #[test]
    fn test_tensor_data_types() {
        assert_eq!(<f32 as TensorData>::dtype(), DataType::Float32);
        assert_eq!(<i32 as TensorData>::dtype(), DataType::Int32);
        assert_eq!(<u8 as TensorData>::dtype(), DataType::UInt8);
    }

    /// Element count and byte size follow from the shape and data type.
    #[test]
    fn test_tensor_info() {
        let shape = vec![1, 3, 224, 224];
        let info = TensorInfo {
            name: String::from("test"),
            shape,
            dtype: DataType::Float32,
            format: DataFormat::Nchw,
        };

        // 3 channels of 224 x 224 values, one batch entry.
        let expected_elements: i32 = 3 * 224 * 224;
        assert_eq!(info.element_count(), expected_elements);
        // The byte size is four times the element count.
        assert_eq!(info.byte_size(), expected_elements as usize * 4);
    }
}