// ocr_rs/mnn/mod.rs

//! MNN Inference Engine FFI Binding Layer
//!
//! This module encapsulates the low-level interfaces of the MNN C++ inference framework, providing safe Rust APIs.
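//!
//! A minimal usage sketch (not compiled as a doctest; the crate path and the
//! `model.mnn` file name are illustrative):
//!
//! ```ignore
//! use ocr_rs::mnn::{InferenceConfig, InferenceEngine, PrecisionMode};
//!
//! // Configure the engine and load a model from disk.
//! let config = InferenceConfig::new()
//!     .with_threads(2)
//!     .with_precision(PrecisionMode::Normal);
//! let engine = InferenceEngine::from_file("model.mnn", Some(config))?;
//!
//! // Run inference on an all-zero input matching the model's input shape.
//! let input = ndarray::ArrayD::<f32>::zeros(ndarray::IxDyn(engine.input_shape()));
//! let output = engine.run(input.view())?;
//! ```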

// Use stub implementation when building on docs.rs
#[cfg(feature = "docsrs")]
mod docsrs_stub;

#[cfg(feature = "docsrs")]
pub use docsrs_stub::*;

// Use complete implementation for normal builds
#[cfg(not(feature = "docsrs"))]
mod normal_impl {

    use ndarray::{ArrayD, ArrayViewD, IxDyn};
    use std::ffi::CStr;
    use std::ptr::NonNull;

    #[allow(non_camel_case_types)]
    #[allow(non_upper_case_globals)]
    #[allow(non_snake_case)]
    #[allow(dead_code)]
    mod ffi {
        include!(concat!(env!("OUT_DIR"), "/mnn_bindings.rs"));
    }

    // ============== Error Types ==============

    /// MNN related errors
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum MnnError {
        /// Invalid parameter
        InvalidParameter(String),
        /// Out of memory
        OutOfMemory,
        /// Runtime error
        RuntimeError(String),
        /// Unsupported operation
        Unsupported,
        /// Model loading failed
        ModelLoadFailed(String),
        /// Null pointer error
        NullPointer,
        /// Shape mismatch
        ShapeMismatch {
            expected: Vec<usize>,
            got: Vec<usize>,
        },
    }

    impl std::fmt::Display for MnnError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                MnnError::InvalidParameter(msg) => write!(f, "Invalid parameter: {}", msg),
                MnnError::OutOfMemory => write!(f, "Out of memory"),
                MnnError::RuntimeError(msg) => write!(f, "Runtime error: {}", msg),
                MnnError::Unsupported => write!(f, "Unsupported operation"),
                MnnError::ModelLoadFailed(msg) => write!(f, "Model loading failed: {}", msg),
                MnnError::NullPointer => write!(f, "Null pointer"),
                MnnError::ShapeMismatch { expected, got } => {
                    write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
                }
            }
        }
    }

    impl std::error::Error for MnnError {}

    pub type Result<T> = std::result::Result<T, MnnError>;

    // ============== Configuration Types ==============

    /// Precision mode
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    #[repr(i32)]
    pub enum PrecisionMode {
        /// Normal precision
        #[default]
        Normal = 0,
        /// Low precision (faster)
        Low = 1,
        /// High precision (more accurate)
        High = 2,
    }

    /// Data format
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    #[repr(i32)]
    pub enum DataFormat {
        /// NCHW format (Caffe/PyTorch/ONNX)
        #[default]
        NCHW = 0,
        /// NHWC format (TensorFlow)
        NHWC = 1,
        /// Auto detect
        Auto = 2,
    }

    /// Inference backend type
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    pub enum Backend {
        /// CPU backend
        #[default]
        CPU,
        /// Metal GPU (macOS/iOS)
        Metal,
        /// OpenCL GPU
        OpenCL,
        /// OpenGL GPU
        OpenGL,
        /// Vulkan GPU
        Vulkan,
        /// CUDA GPU (NVIDIA)
        CUDA,
        /// CoreML (macOS/iOS)
        CoreML,
    }

    /// Inference configuration
    #[derive(Debug, Clone)]
    pub struct InferenceConfig {
        /// Thread count (0 means auto, default is 4)
        pub thread_count: i32,
        /// Precision mode
        pub precision_mode: PrecisionMode,
        /// Whether to use cache
        pub use_cache: bool,
        /// Data format
        pub data_format: DataFormat,
        /// Inference backend
        pub backend: Backend,
    }

    impl Default for InferenceConfig {
        fn default() -> Self {
            InferenceConfig {
                thread_count: 4,
                precision_mode: PrecisionMode::Normal,
                use_cache: false,
                data_format: DataFormat::NCHW,
                backend: Backend::CPU,
            }
        }
    }

    impl InferenceConfig {
        /// Create new inference configuration
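        ///
        /// # Example
        /// A minimal sketch of the builder-style configuration (not compiled here):
        ///
        /// ```ignore
        /// let config = InferenceConfig::new()
        ///     .with_threads(8)
        ///     .with_precision(PrecisionMode::High)
        ///     .with_backend(Backend::CPU)
        ///     .with_data_format(DataFormat::NCHW);
        /// ```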
        pub fn new() -> Self {
            Self::default()
        }

        /// Set thread count
        pub fn with_threads(mut self, threads: i32) -> Self {
            self.thread_count = threads;
            self
        }

        /// Set precision mode
        pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
            self.precision_mode = precision;
            self
        }

        /// Set backend
        pub fn with_backend(mut self, backend: Backend) -> Self {
            self.backend = backend;
            self
        }

        /// Set data format
        pub fn with_data_format(mut self, format: DataFormat) -> Self {
            self.data_format = format;
            self
        }

        fn to_ffi(&self) -> ffi::MNNR_Config {
            ffi::MNNR_Config {
                thread_count: self.thread_count,
                precision_mode: self.precision_mode as i32,
                use_cache: self.use_cache,
                data_format: self.data_format as i32,
            }
        }
    }

    // ============== Shared Runtime ==============

    /// Shared runtime for sharing resources among multiple engines
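    ///
    /// # Example
    /// A sketch of sharing one runtime between two engines (not compiled here;
    /// the model buffers are illustrative):
    ///
    /// ```ignore
    /// let runtime = SharedRuntime::new(&InferenceConfig::default())?;
    /// // `det_model_bytes` and `rec_model_bytes` are hypothetical model buffers.
    /// let det = InferenceEngine::from_buffer_with_runtime(&det_model_bytes, &runtime)?;
    /// let rec = InferenceEngine::from_buffer_with_runtime(&rec_model_bytes, &runtime)?;
    /// ```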
    pub struct SharedRuntime {
        ptr: NonNull<ffi::MNN_SharedRuntime>,
    }

    impl SharedRuntime {
        /// Create new shared runtime
        pub fn new(config: &InferenceConfig) -> Result<Self> {
            let c_config = config.to_ffi();
            let runtime_ptr = unsafe { ffi::mnnr_create_runtime(&c_config) };

            let ptr = NonNull::new(runtime_ptr).ok_or_else(|| {
                MnnError::RuntimeError("Failed to create shared runtime".to_string())
            })?;

            Ok(SharedRuntime { ptr })
        }

        pub(crate) fn as_ptr(&self) -> *mut ffi::MNN_SharedRuntime {
            self.ptr.as_ptr()
        }
    }

    impl Drop for SharedRuntime {
        fn drop(&mut self) {
            unsafe {
                ffi::mnnr_destroy_runtime(self.ptr.as_ptr());
            }
        }
    }

    unsafe impl Send for SharedRuntime {}
    unsafe impl Sync for SharedRuntime {}

    // ============== Helper Functions ==============

    fn get_last_error_message(engine: Option<*const ffi::MNN_InferenceEngine>) -> String {
        match engine {
            Some(ptr) => unsafe {
                let c_str = ffi::mnnr_get_last_error(ptr);
                if c_str.is_null() {
                    "Unknown error".to_string()
                } else {
                    CStr::from_ptr(c_str).to_string_lossy().into_owned()
                }
            },
            None => "Engine creation failed".to_string(),
        }
    }

    // ============== Inference Engine ==============

    /// MNN inference engine
    ///
    /// Encapsulates MNN model loading and inference functionality
    pub struct InferenceEngine {
        ptr: NonNull<ffi::MNN_InferenceEngine>,
        input_shape: Vec<usize>,
        output_shape: Vec<usize>,
    }

    impl InferenceEngine {
        /// Create inference engine from model byte data
        ///
        /// # Parameters
        /// - `model_buffer`: Model file byte data
        /// - `config`: Optional inference configuration
        ///
        /// # Example
        /// ```ignore
        /// let model_data = std::fs::read("model.mnn")?;
        /// let engine = InferenceEngine::from_buffer(&model_data, None)?;
        /// ```
        pub fn from_buffer(model_buffer: &[u8], config: Option<InferenceConfig>) -> Result<Self> {
            if model_buffer.is_empty() {
                return Err(MnnError::InvalidParameter(
                    "Model data is empty".to_string(),
                ));
            }

            let cfg = config.unwrap_or_default();
            let c_config = cfg.to_ffi();

            let engine_ptr = unsafe {
                ffi::mnnr_create_engine(
                    model_buffer.as_ptr() as *const _,
                    model_buffer.len(),
                    &c_config,
                )
            };

            let ptr = NonNull::new(engine_ptr)
                .ok_or_else(|| MnnError::ModelLoadFailed(get_last_error_message(None)))?;

            let (input_shape, output_shape) = unsafe { Self::get_shapes(ptr.as_ptr())? };

            Ok(InferenceEngine {
                ptr,
                input_shape,
                output_shape,
            })
        }

        /// Create inference engine from model file
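        ///
        /// # Example
        /// A sketch assuming a `model.mnn` file exists on disk (not compiled here):
        ///
        /// ```ignore
        /// let engine = InferenceEngine::from_file("model.mnn", None)?;
        /// println!("input shape: {:?}", engine.input_shape());
        /// ```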
        pub fn from_file(
            model_path: impl AsRef<std::path::Path>,
            config: Option<InferenceConfig>,
        ) -> Result<Self> {
            let model_buffer = std::fs::read(model_path.as_ref()).map_err(|e| {
                MnnError::ModelLoadFailed(format!("Failed to read model file: {}", e))
            })?;
            Self::from_buffer(&model_buffer, config)
        }

        /// Create inference engine from model byte data using shared runtime
        pub fn from_buffer_with_runtime(
            model_buffer: &[u8],
            runtime: &SharedRuntime,
        ) -> Result<Self> {
            if model_buffer.is_empty() {
                return Err(MnnError::InvalidParameter(
                    "Model data is empty".to_string(),
                ));
            }

            let engine_ptr = unsafe {
                ffi::mnnr_create_engine_with_runtime(
                    model_buffer.as_ptr() as *const _,
                    model_buffer.len(),
                    runtime.as_ptr(),
                )
            };

            let ptr = NonNull::new(engine_ptr)
                .ok_or_else(|| MnnError::ModelLoadFailed(get_last_error_message(None)))?;

            let (input_shape, output_shape) = unsafe { Self::get_shapes(ptr.as_ptr())? };

            Ok(InferenceEngine {
                ptr,
                input_shape,
                output_shape,
            })
        }

        unsafe fn get_shapes(
            ptr: *mut ffi::MNN_InferenceEngine,
        ) -> Result<(Vec<usize>, Vec<usize>)> {
            let mut input_shape_vec = vec![0usize; 8];
            let mut input_ndims = 0;
            let mut output_shape_vec = vec![0usize; 8];
            let mut output_ndims = 0;

            if ffi::mnnr_get_input_shape(ptr, input_shape_vec.as_mut_ptr(), &mut input_ndims)
                != ffi::MNNR_ErrorCode_MNNR_SUCCESS
            {
                return Err(MnnError::RuntimeError(
                    "Failed to get input shape".to_string(),
                ));
            }
            input_shape_vec.truncate(input_ndims);

            if ffi::mnnr_get_output_shape(ptr, output_shape_vec.as_mut_ptr(), &mut output_ndims)
                != ffi::MNNR_ErrorCode_MNNR_SUCCESS
            {
                return Err(MnnError::RuntimeError(
                    "Failed to get output shape".to_string(),
                ));
            }
            output_shape_vec.truncate(output_ndims);

            Ok((input_shape_vec, output_shape_vec))
        }

        /// Get input tensor shape
        pub fn input_shape(&self) -> &[usize] {
            &self.input_shape
        }

        /// Get output tensor shape
        pub fn output_shape(&self) -> &[usize] {
            &self.output_shape
        }

        /// Execute inference
        ///
        /// # Parameters
        /// - `input_data`: Input data, shape must match model input shape
        ///
        /// # Returns
        /// Inference result array
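        ///
        /// # Example
        /// A sketch using an all-zero input of the model's fixed input shape
        /// (not compiled here):
        ///
        /// ```ignore
        /// use ndarray::{ArrayD, IxDyn};
        ///
        /// let input = ArrayD::<f32>::zeros(IxDyn(engine.input_shape()));
        /// let output = engine.run(input.view())?;
        /// assert_eq!(output.shape(), engine.output_shape());
        /// ```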
        pub fn run(&self, input_data: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
            if input_data.shape() != self.input_shape.as_slice() {
                return Err(MnnError::ShapeMismatch {
                    expected: self.input_shape.clone(),
                    got: input_data.shape().to_vec(),
                });
            }

            let input_slice = input_data.as_slice().ok_or_else(|| {
                MnnError::InvalidParameter("Input data must be contiguous".to_string())
            })?;

            let output_size: usize = self.output_shape.iter().product();
            let mut output_buffer = vec![0.0f32; output_size];

            let error_code = unsafe {
                ffi::mnnr_run_inference(
                    self.ptr.as_ptr(),
                    input_slice.as_ptr(),
                    input_slice.len(),
                    output_buffer.as_mut_ptr(),
                    output_buffer.len(),
                )
            };

            match error_code {
                ffi::MNNR_ErrorCode_MNNR_SUCCESS => {
                    ArrayD::from_shape_vec(IxDyn(&self.output_shape), output_buffer).map_err(|e| {
                        MnnError::RuntimeError(format!("Failed to create output array: {}", e))
                    })
                }
                ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
                    MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
                ),
                ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
                ffi::MNNR_ErrorCode_MNNR_ERROR_UNSUPPORTED => Err(MnnError::Unsupported),
                _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
                    self.ptr.as_ptr(),
                )))),
            }
        }

        /// Execute inference (using raw slices)
        ///
        /// This is a low-level API for scenarios that require maximum performance;
        /// the caller supplies both the input and output buffers.
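        ///
        /// # Example
        /// A sketch with caller-allocated buffers sized from the model's shapes
        /// (not compiled here):
        ///
        /// ```ignore
        /// let input = vec![0.0f32; engine.input_shape().iter().product()];
        /// let mut output = vec![0.0f32; engine.output_shape().iter().product()];
        /// engine.run_raw(&input, &mut output)?;
        /// ```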
        pub fn run_raw(&self, input: &[f32], output: &mut [f32]) -> Result<()> {
            let expected_input: usize = self.input_shape.iter().product();
            let expected_output: usize = self.output_shape.iter().product();

            if input.len() != expected_input {
                return Err(MnnError::ShapeMismatch {
                    expected: vec![expected_input],
                    got: vec![input.len()],
                });
            }

            if output.len() != expected_output {
                return Err(MnnError::ShapeMismatch {
                    expected: vec![expected_output],
                    got: vec![output.len()],
                });
            }

            let error_code = unsafe {
                ffi::mnnr_run_inference(
                    self.ptr.as_ptr(),
                    input.as_ptr(),
                    input.len(),
                    output.as_mut_ptr(),
                    output.len(),
                )
            };

            match error_code {
                ffi::MNNR_ErrorCode_MNNR_SUCCESS => Ok(()),
                ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
                    MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
                ),
                ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
                _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
                    self.ptr.as_ptr(),
                )))),
            }
        }

        pub(crate) fn as_ptr(&self) -> NonNull<ffi::MNN_InferenceEngine> {
            self.ptr
        }

        /// Check if model has dynamic shape (contains -1 dimension)
        pub fn has_dynamic_shape(&self) -> bool {
            // A -1 dimension in the model wraps to a very large value when converted to usize,
            // so any huge dimension is treated as dynamic.
            self.input_shape.iter().any(|&d| d > 100000)
                || self.output_shape.iter().any(|&d| d > 100000)
        }

        /// Execute dynamic shape inference
        ///
        /// Suitable for models whose input shape changes at runtime (such as detection models).
        /// This function resizes the model's input tensor to match `input_data` before running.
        ///
        /// # Parameters
        /// - `input_data`: Input data array
        ///
        /// # Returns
        /// Inference result array; its shape is determined by the model at runtime
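        ///
        /// # Example
        /// A sketch feeding an input whose spatial size differs from the model's
        /// nominal shape (the dimensions are illustrative; not compiled here):
        ///
        /// ```ignore
        /// use ndarray::{ArrayD, IxDyn};
        ///
        /// let input = ArrayD::<f32>::zeros(IxDyn(&[1, 3, 640, 480]));
        /// let output = engine.run_dynamic(input.view())?;
        /// println!("output shape: {:?}", output.shape());
        /// ```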
        pub fn run_dynamic(&self, input_data: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
            let input_shape: Vec<usize> = input_data.shape().to_vec();
            let input_slice = input_data.as_slice().ok_or_else(|| {
                MnnError::InvalidParameter("Input data must be contiguous".to_string())
            })?;

            let mut output_data: *mut f32 = std::ptr::null_mut();
            let mut output_size: usize = 0;
            let mut output_dims = [0usize; 8];
            let mut output_ndims: usize = 0;

            let error_code = unsafe {
                ffi::mnnr_run_inference_dynamic(
                    self.ptr.as_ptr(),
                    input_slice.as_ptr(),
                    input_shape.as_ptr(),
                    input_shape.len(),
                    &mut output_data,
                    &mut output_size,
                    output_dims.as_mut_ptr(),
                    &mut output_ndims,
                )
            };

            if error_code != ffi::MNNR_ErrorCode_MNNR_SUCCESS {
                return match error_code {
                    ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
                        MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
                    ),
                    ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
                    ffi::MNNR_ErrorCode_MNNR_ERROR_UNSUPPORTED => Err(MnnError::Unsupported),
                    _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
                        self.ptr.as_ptr(),
                    )))),
                };
            }

            // Copy output data and free C buffer
            let output_shape: Vec<usize> = output_dims[..output_ndims].to_vec();
            let output_buffer = unsafe {
                let slice = std::slice::from_raw_parts(output_data, output_size);
                let buffer = slice.to_vec();
                ffi::mnnr_free_output(output_data);
                buffer
            };

            ArrayD::from_shape_vec(IxDyn(&output_shape), output_buffer).map_err(|e| {
                MnnError::RuntimeError(format!("Failed to create output array: {}", e))
            })
        }

        /// Execute dynamic shape inference (using raw slices)
        ///
        /// Low-level API that returns the flattened output data together with its shape.
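        ///
        /// # Example
        /// A sketch passing a flat input slice plus its shape (the dimensions are
        /// illustrative; not compiled here):
        ///
        /// ```ignore
        /// let input = vec![0.0f32; 1 * 3 * 640 * 480];
        /// let (output, output_shape) = engine.run_dynamic_raw(&input, &[1, 3, 640, 480])?;
        /// assert_eq!(output.len(), output_shape.iter().product::<usize>());
        /// ```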
        pub fn run_dynamic_raw(
            &self,
            input: &[f32],
            input_shape: &[usize],
        ) -> Result<(Vec<f32>, Vec<usize>)> {
            let mut output_data: *mut f32 = std::ptr::null_mut();
            let mut output_size: usize = 0;
            let mut output_dims = [0usize; 8];
            let mut output_ndims: usize = 0;

            let error_code = unsafe {
                ffi::mnnr_run_inference_dynamic(
                    self.ptr.as_ptr(),
                    input.as_ptr(),
                    input_shape.as_ptr(),
                    input_shape.len(),
                    &mut output_data,
                    &mut output_size,
                    output_dims.as_mut_ptr(),
                    &mut output_ndims,
                )
            };

            if error_code != ffi::MNNR_ErrorCode_MNNR_SUCCESS {
                return match error_code {
                    ffi::MNNR_ErrorCode_MNNR_ERROR_INVALID_PARAMETER => Err(
                        MnnError::InvalidParameter(get_last_error_message(Some(self.ptr.as_ptr()))),
                    ),
                    ffi::MNNR_ErrorCode_MNNR_ERROR_OUT_OF_MEMORY => Err(MnnError::OutOfMemory),
                    _ => Err(MnnError::RuntimeError(get_last_error_message(Some(
                        self.ptr.as_ptr(),
                    )))),
                };
            }

            // Copy output and free C buffer
            let output_shape = output_dims[..output_ndims].to_vec();
            let output_buffer = unsafe {
                let slice = std::slice::from_raw_parts(output_data, output_size);
                let buffer = slice.to_vec();
                ffi::mnnr_free_output(output_data);
                buffer
            };

            Ok((output_buffer, output_shape))
        }
    }

    impl Drop for InferenceEngine {
        fn drop(&mut self) {
            unsafe {
                ffi::mnnr_destroy_engine(self.ptr.as_ptr());
            }
        }
    }

    unsafe impl Send for InferenceEngine {}
    unsafe impl Sync for InferenceEngine {}

    // ============== Session Pool ==============

    /// Session pool for high-concurrency inference scenarios
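    ///
    /// # Example
    /// A sketch sharing one pool across threads (not compiled here; `make_input`
    /// is a hypothetical helper that produces an `ArrayD<f32>`):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// let pool = Arc::new(SessionPool::new(&engine, 4, None)?);
    /// let handles: Vec<_> = (0..4)
    ///     .map(|_| {
    ///         let pool = Arc::clone(&pool);
    ///         // `make_input` is a hypothetical helper returning an ArrayD<f32>.
    ///         std::thread::spawn(move || pool.run(make_input().view()))
    ///     })
    ///     .collect();
    /// for handle in handles {
    ///     let _output = handle.join().unwrap()?;
    /// }
    /// ```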
    pub struct SessionPool {
        ptr: NonNull<ffi::MNN_SessionPool>,
        input_shape: Vec<usize>,
        output_shape: Vec<usize>,
    }

    impl SessionPool {
        /// Create session pool
        ///
        /// # Parameters
        /// - `engine`: Inference engine
        /// - `pool_size`: Number of sessions in pool
        /// - `config`: Optional inference configuration
        pub fn new(
            engine: &InferenceEngine,
            pool_size: usize,
            config: Option<InferenceConfig>,
        ) -> Result<Self> {
            if pool_size == 0 {
                return Err(MnnError::InvalidParameter(
                    "Pool size cannot be 0".to_string(),
                ));
            }

            let cfg = config.unwrap_or_default();
            let c_config = cfg.to_ffi();

            let pool_ptr = unsafe {
                ffi::mnnr_create_session_pool(engine.as_ptr().as_ptr(), pool_size, &c_config)
            };

            let ptr = NonNull::new(pool_ptr)
                .ok_or_else(|| MnnError::RuntimeError("Failed to create session pool".to_string()))?;

            Ok(SessionPool {
                ptr,
                input_shape: engine.input_shape.clone(),
                output_shape: engine.output_shape.clone(),
            })
        }

        /// Execute inference (thread-safe)
        pub fn run(&self, input_data: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
            if input_data.shape() != self.input_shape.as_slice() {
                return Err(MnnError::ShapeMismatch {
                    expected: self.input_shape.clone(),
                    got: input_data.shape().to_vec(),
                });
            }

            let input_slice = input_data.as_slice().ok_or_else(|| {
                MnnError::InvalidParameter("Input data must be contiguous".to_string())
            })?;

            let output_size: usize = self.output_shape.iter().product();
            let mut output_buffer = vec![0.0f32; output_size];

            let error_code = unsafe {
                ffi::mnnr_session_pool_run(
                    self.ptr.as_ptr(),
                    input_slice.as_ptr(),
                    input_slice.len(),
                    output_buffer.as_mut_ptr(),
                    output_buffer.len(),
                )
            };

            match error_code {
                ffi::MNNR_ErrorCode_MNNR_SUCCESS => {
                    ArrayD::from_shape_vec(IxDyn(&self.output_shape), output_buffer).map_err(|e| {
                        MnnError::RuntimeError(format!("Failed to create output array: {}", e))
                    })
                }
                _ => Err(MnnError::RuntimeError(
                    "Session pool inference failed".to_string(),
                )),
            }
        }

        /// Get available session count
        pub fn available(&self) -> usize {
            unsafe { ffi::mnnr_session_pool_available(self.ptr.as_ptr()) }
        }
    }

    impl Drop for SessionPool {
        fn drop(&mut self) {
            unsafe {
                ffi::mnnr_destroy_session_pool(self.ptr.as_ptr());
            }
        }
    }

    unsafe impl Send for SessionPool {}
    unsafe impl Sync for SessionPool {}

    // ============== Utility Functions ==============

    /// Get MNN version number
    pub fn get_version() -> String {
        unsafe {
            let c_str = ffi::mnnr_get_version();
            if c_str.is_null() {
                "unknown".to_string()
            } else {
                CStr::from_ptr(c_str).to_string_lossy().into_owned()
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_config_default() {
            let config = InferenceConfig::default();
            assert_eq!(config.thread_count, 4);
            assert_eq!(config.precision_mode, PrecisionMode::Normal);
        }

        #[test]
        fn test_config_builder() {
            let config = InferenceConfig::new()
                .with_threads(8)
                .with_precision(PrecisionMode::High)
                .with_backend(Backend::Metal);

            assert_eq!(config.thread_count, 8);
            assert_eq!(config.precision_mode, PrecisionMode::High);
            assert_eq!(config.backend, Backend::Metal);
        }
    }
} // end of normal_impl module

// Re-export types from normal implementation
#[cfg(not(feature = "docsrs"))]
pub use normal_impl::*;