1#[cfg(feature = "docsrs")]
7mod docsrs_stub;
8
9#[cfg(feature = "docsrs")]
10pub use docsrs_stub::*;
11
12#[cfg(not(feature = "docsrs"))]
14mod normal_impl {
15
16 use ndarray::{ArrayD, ArrayViewD, IxDyn};
17 use std::ffi::CStr;
18 use std::ptr::NonNull;
19
20 #[allow(non_camel_case_types)]
21 #[allow(non_upper_case_globals)]
22 #[allow(non_snake_case)]
23 #[allow(dead_code)]
24 mod ffi {
25 include!(concat!(env!("OUT_DIR"), "/mnn_bindings.rs"));
26 }
27
    /// Errors produced by this MNN inference wrapper.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum MnnError {
        /// An argument failed validation before (or inside) the FFI layer.
        InvalidParameter(String),
        /// The FFI layer reported an allocation failure.
        OutOfMemory,
        /// A generic runtime failure, carrying the engine's error message.
        RuntimeError(String),
        /// The requested operation is not supported by the backend.
        Unsupported,
        /// The model could not be read, parsed, or loaded.
        ModelLoadFailed(String),
        /// An FFI call unexpectedly produced a null pointer.
        NullPointer,
        /// Input/output dimensions did not match the model's expectations.
        ShapeMismatch {
            /// Shape (or flat element count) the model expects.
            expected: Vec<usize>,
            /// Shape (or flat element count) actually supplied.
            got: Vec<usize>,
        },
    }
51
52 impl std::fmt::Display for MnnError {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 MnnError::InvalidParameter(msg) => write!(f, "Invalid parameter: {}", msg),
56 MnnError::OutOfMemory => write!(f, "Out of memory"),
57 MnnError::RuntimeError(msg) => write!(f, "Runtime error: {}", msg),
58 MnnError::Unsupported => write!(f, "Unsupported operation"),
59 MnnError::ModelLoadFailed(msg) => write!(f, "Model loading failed: {}", msg),
60 MnnError::NullPointer => write!(f, "Null pointer"),
61 MnnError::ShapeMismatch { expected, got } => {
62 write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
63 }
64 }
65 }
66 }
67
68 impl std::error::Error for MnnError {}
69
70 pub type Result<T> = std::result::Result<T, MnnError>;
71
    /// Numeric precision requested from the backend.
    ///
    /// `#[repr(i32)]` discriminants mirror the C API's constants; the value is
    /// passed through with `as i32` in `InferenceConfig::to_ffi`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    #[repr(i32)]
    pub enum PrecisionMode {
        /// Backend default precision.
        #[default]
        Normal = 0,
        /// Reduced precision. NOTE(review): presumably fp16 — confirm against MNN docs.
        Low = 1,
        /// Full/high precision.
        High = 2,
    }
86
    /// Tensor memory layout passed to the C API.
    ///
    /// `#[repr(i32)]` discriminants mirror the C constants (`as i32` in
    /// `InferenceConfig::to_ffi`).
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    #[repr(i32)]
    pub enum DataFormat {
        /// Batch, channel, height, width (default).
        #[default]
        NCHW = 0,
        /// Batch, height, width, channel.
        NHWC = 1,
        /// Let the backend pick a layout.
        Auto = 2,
    }
99
    /// Compute backend selection.
    ///
    /// NOTE(review): this value is stored on `InferenceConfig` but is not part
    /// of the `MNNR_Config` produced by `to_ffi` — confirm how backend
    /// selection actually reaches the C layer.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    pub enum Backend {
        /// Portable CPU execution (default).
        #[default]
        CPU,
        Metal,
        OpenCL,
        OpenGL,
        Vulkan,
        CUDA,
        CoreML,
    }
119
    /// Settings controlling how engines and sessions are created.
    #[derive(Debug, Clone)]
    pub struct InferenceConfig {
        /// Number of threads the backend may use (default: 4).
        pub thread_count: i32,
        /// Requested numeric precision.
        pub precision_mode: PrecisionMode,
        /// Caching flag forwarded verbatim to the C config.
        /// NOTE(review): exact caching semantics are defined by the C API — confirm.
        pub use_cache: bool,
        /// Tensor memory layout.
        pub data_format: DataFormat,
        /// Compute backend. NOTE(review): not forwarded by `to_ffi` — confirm.
        pub backend: Backend,
    }
134
135 impl Default for InferenceConfig {
136 fn default() -> Self {
137 InferenceConfig {
138 thread_count: 4,
139 precision_mode: PrecisionMode::Normal,
140 use_cache: false,
141 data_format: DataFormat::NCHW,
142 backend: Backend::CPU,
143 }
144 }
145 }
146
147 impl InferenceConfig {
148 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn with_threads(mut self, threads: i32) -> Self {
155 self.thread_count = threads;
156 self
157 }
158
159 pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
161 self.precision_mode = precision;
162 self
163 }
164
165 pub fn with_backend(mut self, backend: Backend) -> Self {
167 self.backend = backend;
168 self
169 }
170
171 pub fn with_data_format(mut self, format: DataFormat) -> Self {
173 self.data_format = format;
174 self
175 }
176
177 fn to_ffi(&self) -> ffi::MNNR_Config {
178 ffi::MNNR_Config {
179 thread_count: self.thread_count,
180 precision_mode: self.precision_mode as i32,
181 use_cache: self.use_cache,
182 data_format: self.data_format as i32,
183 }
184 }
185 }
186
    /// A backend runtime that can be shared across multiple engines.
    ///
    /// Owns the FFI handle; dropping this destroys the runtime.
    pub struct SharedRuntime {
        // Non-null by construction (checked in `new`).
        ptr: NonNull<ffi::MNN_SharedRuntime>,
    }
193
194 impl SharedRuntime {
195 pub fn new(config: &InferenceConfig) -> Result<Self> {
197 let c_config = config.to_ffi();
198 let runtime_ptr = unsafe { ffi::mnnr_create_runtime(&c_config) };
199
200 let ptr = NonNull::new(runtime_ptr).ok_or_else(|| {
201 MnnError::RuntimeError("Create shared runtime failed".to_string())
202 })?;
203
204 Ok(SharedRuntime { ptr })
205 }
206
207 pub(crate) fn as_ptr(&self) -> *mut ffi::MNN_SharedRuntime {
208 self.ptr.as_ptr()
209 }
210 }
211
    impl Drop for SharedRuntime {
        fn drop(&mut self) {
            // SAFETY: `self.ptr` was produced by `mnnr_create_runtime` and is
            // destroyed exactly once, here.
            unsafe {
                ffi::mnnr_destroy_runtime(self.ptr.as_ptr());
            }
        }
    }

    // SAFETY: NOTE(review): assumes the C runtime handle may be used and
    // destroyed from any thread — confirm against the MNN threading contract.
    unsafe impl Send for SharedRuntime {}
    unsafe impl Sync for SharedRuntime {}
222
223 fn get_last_error_message(engine: Option<*const ffi::MNN_InferenceEngine>) -> String {
226 match engine {
227 Some(ptr) => unsafe {
228 let c_str = ffi::mnnr_get_last_error(ptr);
229 if c_str.is_null() {
230 "Unknown error".to_string()
231 } else {
232 CStr::from_ptr(c_str).to_string_lossy().into_owned()
233 }
234 },
235 None => "Engine creation failed".to_string(),
236 }
237 }
238
    /// A loaded model ready to run inference.
    ///
    /// Input/output shapes are queried once at construction and cached.
    pub struct InferenceEngine {
        // Owned engine handle; destroyed in `Drop`.
        ptr: NonNull<ffi::MNN_InferenceEngine>,
        // Input shape reported by the model at load time.
        input_shape: Vec<usize>,
        // Output shape reported by the model at load time.
        output_shape: Vec<usize>,
    }
249
250 impl InferenceEngine {
        /// Builds an engine from an in-memory serialized model.
        ///
        /// `config` falls back to [`InferenceConfig::default`] when `None`.
        ///
        /// # Errors
        /// - [`MnnError::InvalidParameter`] for an empty buffer.
        /// - [`MnnError::ModelLoadFailed`] if the FFI rejects the model.
        /// - [`MnnError::RuntimeError`] if the shape queries fail afterwards.
        pub fn from_buffer(model_buffer: &[u8], config: Option<InferenceConfig>) -> Result<Self> {
            if model_buffer.is_empty() {
                return Err(MnnError::InvalidParameter(
                    "Model data is empty".to_string(),
                ));
            }

            let cfg = config.unwrap_or_default();
            let c_config = cfg.to_ffi();

            // SAFETY: pointer/length pair is valid for the call's duration.
            // NOTE(review): assumes the C side copies the model bytes rather
            // than retaining the pointer — confirm.
            let engine_ptr = unsafe {
                ffi::mnnr_create_engine(
                    model_buffer.as_ptr() as *const _,
                    model_buffer.len(),
                    &c_config,
                )
            };

            // `None`: no engine exists yet, so there is no last-error string
            // to fetch — a generic message is used instead.
            let ptr = NonNull::new(engine_ptr)
                .ok_or_else(|| MnnError::ModelLoadFailed(get_last_error_message(None)))?;

            // SAFETY: `ptr` was just verified non-null and refers to a live engine.
            let (input_shape, output_shape) = unsafe { Self::get_shapes(ptr.as_ptr())? };

            Ok(InferenceEngine {
                ptr,
                input_shape,
                output_shape,
            })
        }
291
292 pub fn from_file(
294 model_path: impl AsRef<std::path::Path>,
295 config: Option<InferenceConfig>,
296 ) -> Result<Self> {
297 let model_buffer = std::fs::read(model_path.as_ref()).map_err(|e| {
298 MnnError::ModelLoadFailed(format!("Failed to read model file: {}", e))
299 })?;
300 Self::from_buffer(&model_buffer, config)
301 }
302
        /// Builds an engine from model bytes, sharing an existing runtime
        /// instead of creating a fresh one.
        ///
        /// # Errors
        /// Same failure modes as [`Self::from_buffer`].
        pub fn from_buffer_with_runtime(
            model_buffer: &[u8],
            runtime: &SharedRuntime,
        ) -> Result<Self> {
            if model_buffer.is_empty() {
                return Err(MnnError::InvalidParameter(
                    "Model data is empty".to_string(),
                ));
            }

            // SAFETY: the buffer pointer/length pair and the runtime handle are
            // valid for the call's duration.
            // NOTE(review): the engine presumably keeps a reference into the
            // shared runtime — confirm lifetime requirements on the C side.
            let engine_ptr = unsafe {
                ffi::mnnr_create_engine_with_runtime(
                    model_buffer.as_ptr() as *const _,
                    model_buffer.len(),
                    runtime.as_ptr(),
                )
            };

            // `None`: no engine handle exists on failure, so use the generic message.
            let ptr = NonNull::new(engine_ptr)
                .ok_or_else(|| MnnError::ModelLoadFailed(get_last_error_message(None)))?;

            // SAFETY: `ptr` was just verified non-null and refers to a live engine.
            let (input_shape, output_shape) = unsafe { Self::get_shapes(ptr.as_ptr())? };

            Ok(InferenceEngine {
                ptr,
                input_shape,
                output_shape,
            })
        }
333
        /// Queries the model's input/output shapes via the FFI.
        ///
        /// # Safety
        /// `ptr` must be a valid, live engine handle.
        unsafe fn get_shapes(
            ptr: *mut ffi::MNN_InferenceEngine,
        ) -> Result<(Vec<usize>, Vec<usize>)> {
            // Scratch buffers the C API fills, truncated to the reported rank.
            // NOTE(review): assumes models never exceed 8 dimensions — the FFI
            // would otherwise overrun these buffers; confirm the C side caps it.
            let mut input_shape_vec = vec![0usize; 8];
            let mut input_ndims = 0;
            let mut output_shape_vec = vec![0usize; 8];
            let mut output_ndims = 0;

            if ffi::mnnr_get_input_shape(ptr, input_shape_vec.as_mut_ptr(), &mut input_ndims)
                != ffi::MNNR_ErrorCode_MNNR_SUCCESS
            {
                return Err(MnnError::RuntimeError(
                    "Failed to get input shape".to_string(),
                ));
            }
            input_shape_vec.truncate(input_ndims);

            if ffi::mnnr_get_output_shape(ptr, output_shape_vec.as_mut_ptr(), &mut output_ndims)
                != ffi::MNNR_ErrorCode_MNNR_SUCCESS
            {
                return Err(MnnError::RuntimeError(
                    "Failed to get output shape".to_string(),
                ));
            }
            output_shape_vec.truncate(output_ndims);

            Ok((input_shape_vec, output_shape_vec))
        }
362
        /// The model's input shape, as cached at construction time.
        pub fn input_shape(&self) -> &[usize] {
            &self.input_shape
        }

        /// The model's output shape, as cached at construction time.
        pub fn output_shape(&self) -> &[usize] {
            &self.output_shape
        }
372
        /// Runs inference on a fixed-shape model, returning a newly allocated
        /// output array.
        ///
        /// # Errors
        /// - [`MnnError::ShapeMismatch`] if `input_data` does not match the
        ///   model's input shape.
        /// - [`MnnError::InvalidParameter`] if the input is not contiguous or
        ///   the FFI rejects it.
        /// - Other variants as mapped from the FFI error code.
        pub fn run(&self, input_data: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
            if input_data.shape() != self.input_shape.as_slice() {
                return Err(MnnError::ShapeMismatch {
                    expected: self.input_shape.clone(),
                    got: input_data.shape().to_vec(),
                });
            }

            // `as_slice` is None for non-contiguous views; the FFI needs a
            // dense buffer.
            let input_slice = input_data.as_slice().ok_or_else(|| {
                MnnError::InvalidParameter("Input data must be contiguous".to_string())
            })?;

            let output_size: usize = self.output_shape.iter().product();
            let mut output_buffer = vec![0.0f32; output_size];

            // SAFETY: `self.ptr` is a live engine handle; both slices are
            // valid for their stated lengths.
            let error_code = unsafe {
                ffi::mnnr_run_inference(
                    self.ptr.as_ptr(),
                    input_slice.as_ptr(),
                    input_slice.len(),
                    output_buffer.as_mut_ptr(),
                    output_buffer.len(),
                )
            };

            // Map FFI error codes onto the wrapper's error enum.
            match error_code {
                ffi::MNNR_ErrorCode_MNNR_SUCCESS => {
                    ArrayD::from_shape_vec(IxDyn(&self.output_shape), output_buffer).map_err(|e| {
                        MnnError::RuntimeError(format!("Failed to create output array: {}", e))
                    })
                }
                ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
                    MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
                ),
                ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
                ffi::MNNR_ErrorCode_MNNR_ERROR_UNSUPPORTED => Err(MnnError::Unsupported),
                _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
                    self.ptr.as_ptr(),
                )))),
            }
        }
421
422 pub fn run_raw(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
426 let expected_input: usize = self.input_shape.iter().product();
427 let expected_output: usize = self.output_shape.iter().product();
428
429 if input.len() != expected_input {
430 return Err(MnnError::ShapeMismatch {
431 expected: vec![expected_input],
432 got: vec![input.len()],
433 });
434 }
435
436 if output.len() != expected_output {
437 return Err(MnnError::ShapeMismatch {
438 expected: vec![expected_output],
439 got: vec![output.len()],
440 });
441 }
442
443 let error_code = unsafe {
444 ffi::mnnr_run_inference(
445 self.ptr.as_ptr(),
446 input.as_ptr(),
447 input.len(),
448 output.as_mut_ptr(),
449 output.len(),
450 )
451 };
452
453 match error_code {
454 ffi::MNNR_ErrorCode_MNNR_SUCCESS => Ok(()),
455 ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
456 MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
457 ),
458 ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
459 _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
460 self.ptr.as_ptr(),
461 )))),
462 }
463 }
464
        /// Raw engine handle for sibling wrappers (e.g. `SessionPool::new`).
        pub(crate) fn as_ptr(&self) -> NonNull<ffi::MNN_InferenceEngine> {
            self.ptr
        }

        /// Heuristically reports whether the model has dynamic dimensions.
        // NOTE(review): relies on dynamic dims surfacing as implausibly large
        // values (e.g. a C-side -1 wrapped into usize) — confirm the FFI's
        // convention before trusting this threshold.
        pub fn has_dynamic_shape(&self) -> bool {
            self.input_shape.iter().any(|&d| d > 100000)
                || self.output_shape.iter().any(|&d| d > 100000)
        }
475
476 pub fn run_dynamic(&self, input_data: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
487 let input_shape: Vec<usize> = input_data.shape().to_vec();
488 let input_slice = input_data.as_slice().ok_or_else(|| {
489 MnnError::InvalidParameter("Input data must be contiguous".to_string())
490 })?;
491
492 let mut output_data: *mut f32 = std::ptr::null_mut();
493 let mut output_size: usize = 0;
494 let mut output_dims = [0usize; 8];
495 let mut output_ndims: usize = 0;
496
497 let error_code = unsafe {
498 ffi::mnnr_run_inference_dynamic(
499 self.ptr.as_ptr(),
500 input_slice.as_ptr(),
501 input_shape.as_ptr(),
502 input_shape.len(),
503 &mut output_data,
504 &mut output_size,
505 output_dims.as_mut_ptr(),
506 &mut output_ndims,
507 )
508 };
509
510 if error_code != ffi::MNNR_ErrorCode_MNNR_SUCCESS {
511 return match error_code {
512 ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
513 MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
514 ),
515 ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
516 ffi::MNNR_ErrorCode_MNNR_ERROR_UNSUPPORTED => Err(MnnError::Unsupported),
517 _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
518 self.ptr.as_ptr(),
519 )))),
520 };
521 }
522
523 let output_shape: Vec<usize> = output_dims[..output_ndims].to_vec();
525 let output_buffer = unsafe {
526 let slice = std::slice::from_raw_parts(output_data, output_size);
527 let buffer = slice.to_vec();
528 ffi::mnnr_free_output(output_data);
529 buffer
530 };
531
532 ArrayD::from_shape_vec(IxDyn(&output_shape), output_buffer).map_err(|e| {
533 MnnError::RuntimeError(format!("Failed to create output array: {}", e))
534 })
535 }
536
537 pub fn run_dynamic_raw(
541 &self,
542 input: &[f32],
543 input_shape: &[usize],
544 ) -> Result<(Vec<f32>, Vec<usize>)> {
545 let mut output_data: *mut f32 = std::ptr::null_mut();
546 let mut output_size: usize = 0;
547 let mut output_dims = [0usize; 8];
548 let mut output_ndims: usize = 0;
549
550 let error_code = unsafe {
551 ffi::mnnr_run_inference_dynamic(
552 self.ptr.as_ptr(),
553 input.as_ptr(),
554 input_shape.as_ptr(),
555 input_shape.len(),
556 &mut output_data,
557 &mut output_size,
558 output_dims.as_mut_ptr(),
559 &mut output_ndims,
560 )
561 };
562
563 if error_code != ffi::MNNR_ErrorCode_MNNR_SUCCESS {
564 return match error_code {
565 ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
566 MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
567 ),
568 ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
569 _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
570 self.ptr.as_ptr(),
571 )))),
572 };
573 }
574
575 let output_shape = output_dims[..output_ndims].to_vec();
577 let output_buffer = unsafe {
578 let slice = std::slice::from_raw_parts(output_data, output_size);
579 let buffer = slice.to_vec();
580 ffi::mnnr_free_output(output_data);
581 buffer
582 };
583
584 Ok((output_buffer, output_shape))
585 }
586 }
587
    impl Drop for InferenceEngine {
        fn drop(&mut self) {
            // SAFETY: `self.ptr` was produced by an `mnnr_create_engine*` call
            // and is destroyed exactly once, here.
            unsafe {
                ffi::mnnr_destroy_engine(self.ptr.as_ptr());
            }
        }
    }

    // SAFETY: NOTE(review): assumes the C engine handle is safe to use and
    // drop from any thread — confirm against the MNN threading contract.
    unsafe impl Send for InferenceEngine {}
    unsafe impl Sync for InferenceEngine {}
598
    /// A pool of inference sessions built over one engine, for concurrent use.
    pub struct SessionPool {
        // Owned pool handle; destroyed in `Drop`.
        ptr: NonNull<ffi::MNN_SessionPool>,
        // Shapes copied from the engine so `run` can validate without FFI calls.
        input_shape: Vec<usize>,
        output_shape: Vec<usize>,
    }
607
608 impl SessionPool {
        /// Builds a pool of `pool_size` sessions over an existing engine.
        ///
        /// # Errors
        /// Returns [`MnnError::InvalidParameter`] for a zero pool size and
        /// [`MnnError::RuntimeError`] if the FFI fails to create the pool.
        pub fn new(
            engine: &InferenceEngine,
            pool_size: usize,
            config: Option<InferenceConfig>,
        ) -> Result<Self> {
            if pool_size == 0 {
                return Err(MnnError::InvalidParameter(
                    "Pool size cannot be 0".to_string(),
                ));
            }

            let cfg = config.unwrap_or_default();
            let c_config = cfg.to_ffi();

            // SAFETY: the engine handle is live for the duration of the call.
            // NOTE(review): nothing here ties the pool's lifetime to `engine`;
            // confirm the C pool does not retain the engine pointer.
            let pool_ptr = unsafe {
                ffi::mnnr_create_session_pool(engine.as_ptr().as_ptr(), pool_size, &c_config)
            };

            let ptr = NonNull::new(pool_ptr)
                .ok_or_else(|| MnnError::RuntimeError("Create session pool failed".to_string()))?;

            Ok(SessionPool {
                ptr,
                // Cache shapes so `run` can validate without touching the FFI.
                input_shape: engine.input_shape.clone(),
                output_shape: engine.output_shape.clone(),
            })
        }
642
643 pub fn run(&self, input_data: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
645 if input_data.shape() != self.input_shape.as_slice() {
646 return Err(MnnError::ShapeMismatch {
647 expected: self.input_shape.clone(),
648 got: input_data.shape().to_vec(),
649 });
650 }
651
652 let input_slice = input_data.as_slice().ok_or_else(|| {
653 MnnError::InvalidParameter("Input data must be contiguous".to_string())
654 })?;
655
656 let output_size: usize = self.output_shape.iter().product();
657 let mut output_buffer = vec![0.0f32; output_size];
658
659 let error_code = unsafe {
660 ffi::mnnr_session_pool_run(
661 self.ptr.as_ptr(),
662 input_slice.as_ptr(),
663 input_slice.len(),
664 output_buffer.as_mut_ptr(),
665 output_buffer.len(),
666 )
667 };
668
669 match error_code {
670 ffi::MNNR_ErrorCode_MNNR_SUCCESS => {
671 ArrayD::from_shape_vec(IxDyn(&self.output_shape), output_buffer).map_err(|e| {
672 MnnError::RuntimeError(format!("Failed to create output array: {}", e))
673 })
674 }
675 _ => Err(MnnError::RuntimeError(
676 "Session pool inference failed".to_string(),
677 )),
678 }
679 }
680
        /// Number of sessions currently available in the pool.
        pub fn available(&self) -> usize {
            // SAFETY: `self.ptr` is a live pool handle.
            unsafe { ffi::mnnr_session_pool_available(self.ptr.as_ptr()) }
        }
685 }
686
    impl Drop for SessionPool {
        fn drop(&mut self) {
            // SAFETY: `self.ptr` was produced by `mnnr_create_session_pool`
            // and is destroyed exactly once, here.
            unsafe {
                ffi::mnnr_destroy_session_pool(self.ptr.as_ptr());
            }
        }
    }

    // SAFETY: NOTE(review): assumes the C pool is internally synchronized for
    // cross-thread use — confirm against the C API's threading contract.
    unsafe impl Send for SessionPool {}
    unsafe impl Sync for SessionPool {}
697
698 pub fn get_version() -> String {
702 unsafe {
703 let c_str = ffi::mnnr_get_version();
704 if c_str.is_null() {
705 "unknown".to_string()
706 } else {
707 CStr::from_ptr(c_str).to_string_lossy().into_owned()
708 }
709 }
710 }
711
    #[cfg(test)]
    mod tests {
        use super::*;

        // Verifies the documented defaults (4 threads, normal precision).
        #[test]
        fn test_config_default() {
            let config = InferenceConfig::default();
            assert_eq!(config.thread_count, 4);
            assert_eq!(config.precision_mode, PrecisionMode::Normal);
        }

        // Verifies that chained builder setters each land on the right field.
        #[test]
        fn test_config_builder() {
            let config = InferenceConfig::new()
                .with_threads(8)
                .with_precision(PrecisionMode::High)
                .with_backend(Backend::Metal);

            assert_eq!(config.thread_count, 8);
            assert_eq!(config.precision_mode, PrecisionMode::High);
            assert_eq!(config.backend, Backend::Metal);
        }
    }
735} #[cfg(not(feature = "docsrs"))]
739pub use normal_impl::*;