// mnn_rs/interpreter.rs
1//! Model interpreter for MNN inference.
2//!
3//! The interpreter holds the model and manages inference sessions.
4//! It's the primary entry point for running neural network inference with MNN.
5
6use crate::config::ScheduleConfig;
7use crate::config::SessionMode;
8use crate::error::{MnnError, MnnResult};
9use crate::session::Session;
10use crate::tensor::Tensor;
11use mnn_rs_sys::MNNInterpreter;
12use std::ffi::CString;
13use std::ffi::CStr;
14use std::path::Path;
15
/// A model interpreter that holds a loaded neural network model.
///
/// The interpreter is the main entry point for MNN inference. It manages
/// the model and can create multiple sessions for concurrent inference.
///
/// # Thread Safety
/// The interpreter is thread-safe and can be shared across threads.
/// Multiple sessions can be created from a single interpreter.
///
/// # Example
/// ```no_run
/// use mnn_rs::{Interpreter, ScheduleConfig, BackendType};
///
/// // Load a model
/// let interpreter = Interpreter::from_file("model.mnn")?;
///
/// // Create a session with configuration
/// let config = ScheduleConfig::new()
///     .backend(BackendType::CPU)
///     .num_threads(4);
///
/// let mut session = interpreter.create_session(config)?;
///
/// // Get input tensor and fill with data
/// let input = session.get_input(None)?;
/// // ... fill input with data ...
///
/// // Run inference
/// session.run()?;
///
/// // Get output
/// let output = session.get_output(None)?;
/// # Ok::<(), mnn_rs::MnnError>(())
/// ```
pub struct Interpreter {
    // Raw owning pointer to the C-side interpreter; non-null after
    // construction and released exactly once in `Drop`.
    inner: *mut MNNInterpreter,
    /// Model path (if loaded from file)
    model_path: Option<String>,
}
55
// SAFETY: `inner` is an opaque handle whose operations are synchronized
// inside the MNN C library, so the raw pointer may be sent to and shared
// between threads. NOTE(review): this relies on MNN's documented internal
// synchronization — confirm against the linked MNN version.
unsafe impl Send for Interpreter {}
unsafe impl Sync for Interpreter {}
59
60impl Interpreter {
61    /// Create a new interpreter from a model file.
62    ///
63    /// # Arguments
64    /// * `path` - Path to the MNN model file
65    ///
66    /// # Returns
67    /// A new interpreter on success, or an error if the model cannot be loaded.
68    ///
69    /// # Example
70    /// ```no_run
71    /// use mnn_rs::Interpreter;
72    ///
73    /// let interpreter = Interpreter::from_file("model.mnn")?;
74    /// # Ok::<(), mnn_rs::MnnError>(())
75    /// ```
76    pub fn from_file<P: AsRef<Path>>(path: P) -> MnnResult<Self> {
77        let path_str = path.as_ref().to_string_lossy().into_owned();
78
79        // Check if file exists
80        if !path.as_ref().exists() {
81            return Err(MnnError::ModelNotFound(path.as_ref().to_path_buf()));
82        }
83
84        let c_path = CString::new(path_str.as_str())?;
85
86        let inner = unsafe { mnn_rs_sys::mnn_interpreter_create_from_file(c_path.as_ptr()) };
87
88        if inner.is_null() {
89            return Err(MnnError::invalid_model(format!(
90                "Failed to load model from: {}",
91                path_str
92            )));
93        }
94
95        Ok(Self {
96            inner,
97            model_path: Some(path_str),
98        })
99    }
100
101    /// Create a new interpreter from in-memory model data.
102    ///
103    /// This is useful for embedding models in the binary or loading
104    /// models from non-filesystem sources.
105    ///
106    /// # Arguments
107    /// * `data` - The model data as bytes
108    ///
109    /// # Returns
110    /// A new interpreter on success, or an error if the model cannot be loaded.
111    ///
112    /// # Example
113    /// ```ignore
114    /// use mnn_rs::Interpreter;
115    ///
116    /// let model_data = include_bytes!("model.mnn");
117    /// let interpreter = Interpreter::from_bytes(model_data)?;
118    /// # Ok::<(), mnn_rs::MnnError>(())
119    /// ```
120    pub fn from_bytes(data: &[u8]) -> MnnResult<Self> {
121        if data.is_empty() {
122            return Err(MnnError::invalid_model("Model data is empty"));
123        }
124
125        let inner = unsafe {
126            mnn_rs_sys::mnn_interpreter_create_from_buffer(
127                data.as_ptr() as *const std::ffi::c_void,
128                data.len(),
129            )
130        };
131
132        if inner.is_null() {
133            return Err(MnnError::invalid_model("Failed to load model from memory"));
134        }
135
136        Ok(Self {
137            inner,
138            model_path: None,
139        })
140    }
141
142    /// Create a new inference session.
143    ///
144    /// Sessions hold the runtime state for inference and can be created
145    /// with different backend configurations.
146    ///
147    /// # Arguments
148    /// * `config` - Schedule configuration specifying backend and settings
149    ///
150    /// # Returns
151    /// A new session on success, or an error if session creation fails.
152    pub fn create_session(&self, config: ScheduleConfig) -> MnnResult<Session> {
153        unsafe { Session::new(self.inner, config) }
154    }
155
156    /// Get the model path (if loaded from file).
157    pub fn model_path(&self) -> Option<&str> {
158        self.model_path.as_deref()
159    }
160
161    /// Get the business code (model identifier).
162    ///
163    /// # Returns
164    /// The business code string.
165    pub fn business_code(&self) -> String {
166        unsafe {
167            let ptr = mnn_rs_sys::mnn_interpreter_get_biz_code(self.inner);
168            if ptr.is_null() {
169                return String::new();
170            }
171            std::ffi::CStr::from_ptr(ptr)
172                .to_string_lossy()
173                .into_owned()
174        }
175    }
176
177    /// Get the model UUID.
178    ///
179    /// # Returns
180    /// The UUID string.
181    pub fn uuid(&self) -> String {
182        unsafe {
183            let ptr = mnn_rs_sys::mnn_interpreter_get_uuid(self.inner);
184            if ptr.is_null() {
185                return String::new();
186            }
187            std::ffi::CStr::from_ptr(ptr)
188                .to_string_lossy()
189                .into_owned()
190        }
191    }
192
193    /// Get the raw pointer to the underlying MNN Interpreter.
194    ///
195    /// # Safety
196    /// The returned pointer is owned by this Interpreter and must not be freed.
197    pub fn inner(&self) -> *mut mnn_rs_sys::MNNInterpreter {
198        self.inner
199    }
200
201    // ========================================================================
202    // Session Advanced Features
203    // ========================================================================
204
205    /// Set session mode.
206    ///
207    /// # Arguments
208    /// * `mode` - The session mode to set
209    pub fn set_session_mode(&self, mode: SessionMode) {
210        unsafe {
211            mnn_rs_sys::mnn_interpreter_set_session_mode(self.inner, mode.to_mnn());
212        }
213    }
214
215    /// Set cache file for optimization.
216    ///
217    /// # Arguments
218    /// * `path` - Path to the cache file
219    /// * `key_size` - Key size for cache lookup (default: 128)
220    pub fn set_cache_file<P: AsRef<Path>>(&self, path: P, key_size: usize) {
221        let path_str = path.as_ref().to_string_lossy();
222        if let Ok(c_path) = CString::new(path_str.as_ref()) {
223            unsafe {
224                mnn_rs_sys::mnn_interpreter_set_cache_file(self.inner, c_path.as_ptr(), key_size);
225            }
226        }
227    }
228
229    /// Update cache from a session.
230    ///
231    /// # Arguments
232    /// * `session` - The session to update cache from
233    ///
234    /// # Returns
235    /// Ok(()) on success, or an error on failure.
236    pub fn update_cache(&self, session: &Session) -> MnnResult<()> {
237        let result = unsafe {
238            mnn_rs_sys::mnn_interpreter_update_cache(self.inner, session.inner())
239        };
240
241        if result != mnn_rs_sys::MNN_ERROR_NONE as i32 {
242            return Err(MnnError::internal("Failed to update cache"));
243        }
244
245        Ok(())
246    }
247
248    /// Set external file for model.
249    ///
250    /// # Arguments
251    /// * `path` - Path to the external file
252    /// * `flag` - Flag for external file processing
253    pub fn set_external_file<P: AsRef<Path>>(&self, path: P, flag: usize) {
254        let path_str = path.as_ref().to_string_lossy();
255        if let Ok(c_path) = CString::new(path_str.as_ref()) {
256            unsafe {
257                mnn_rs_sys::mnn_interpreter_set_external_file(self.inner, c_path.as_ptr(), flag);
258            }
259        }
260    }
261
262    /// Get input tensor names for a session.
263    ///
264    /// # Arguments
265    /// * `session` - The session to get input names from
266    ///
267    /// # Returns
268    /// A vector of input tensor names.
269    pub fn get_input_names(&self, session: &Session) -> Vec<String> {
270        unsafe {
271            let array = mnn_rs_sys::mnn_interpreter_get_input_names(self.inner, session.inner());
272            let mut names = Vec::with_capacity(array.count as usize);
273
274            for i in 0..array.count {
275                let name_ptr = *array.names.offset(i as isize);
276                if !name_ptr.is_null() {
277                    let name = CStr::from_ptr(name_ptr).to_string_lossy().into_owned();
278                    names.push(name);
279                }
280            }
281
282            mnn_rs_sys::mnn_string_array_free(&mut std::ptr::read(&array) as *mut _);
283            names
284        }
285    }
286
287    /// Get output tensor names for a session.
288    ///
289    /// # Arguments
290    /// * `session` - The session to get output names from
291    ///
292    /// # Returns
293    /// A vector of output tensor names.
294    pub fn get_output_names(&self, session: &Session) -> Vec<String> {
295        unsafe {
296            let array = mnn_rs_sys::mnn_interpreter_get_output_names(self.inner, session.inner());
297            let mut names = Vec::with_capacity(array.count as usize);
298
299            for i in 0..array.count {
300                let name_ptr = *array.names.offset(i as isize);
301                if !name_ptr.is_null() {
302                    let name = CStr::from_ptr(name_ptr).to_string_lossy().into_owned();
303                    names.push(name);
304                }
305            }
306
307            mnn_rs_sys::mnn_string_array_free(&mut std::ptr::read(&array) as *mut _);
308            names
309        }
310    }
311
312    /// Resize a tensor with new shape.
313    ///
314    /// # Arguments
315    /// * `tensor` - The tensor to resize
316    /// * `shape` - New shape
317    pub fn resize_tensor(&self, tensor: &mut Tensor, shape: &[i32]) {
318        if shape.is_empty() {
319            return;
320        }
321
322        unsafe {
323            mnn_rs_sys::mnn_interpreter_resize_tensor(
324                self.inner,
325                tensor.inner_mut(),
326                shape.as_ptr(),
327                shape.len() as i32,
328            );
329        }
330    }
331
332    /// Resize a session after tensor resizing.
333    ///
334    /// # Arguments
335    /// * `session` - The session to resize
336    pub fn resize_session(&self, session: &mut Session) {
337        unsafe {
338            mnn_rs_sys::mnn_interpreter_resize_session(self.inner, session.inner_mut());
339        }
340    }
341
342    /// Get session FLOPS count.
343    ///
344    /// # Arguments
345    /// * `session` - The session to query
346    ///
347    /// # Returns
348    /// FLOPS count in millions.
349    pub fn get_session_flops(&self, session: &Session) -> f32 {
350        unsafe {
351            mnn_rs_sys::mnn_interpreter_get_session_flops(self.inner, session.inner())
352        }
353    }
354
355    /// Get session operator count.
356    ///
357    /// # Arguments
358    /// * `session` - The session to query
359    ///
360    /// # Returns
361    /// Operator count (approximate, based on FLOPS).
362    pub fn get_session_op_count(&self, session: &Session) -> i32 {
363        unsafe {
364            mnn_rs_sys::mnn_interpreter_get_session_op_count(self.inner, session.inner())
365        }
366    }
367}
368
369impl Drop for Interpreter {
370    fn drop(&mut self) {
371        if !self.inner.is_null() {
372            unsafe {
373                mnn_rs_sys::mnn_interpreter_destroy(self.inner);
374            }
375        }
376    }
377}
378
379impl std::fmt::Debug for Interpreter {
380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        f.debug_struct("Interpreter")
382            .field("model_path", &self.model_path)
383            .field("business_code", &self.business_code())
384            .finish()
385    }
386}
387
#[cfg(feature = "async")]
mod async_impl {
    use super::*;
    use std::sync::Arc;

    /// An asynchronous interpreter wrapper.
    ///
    /// Wraps an [`Interpreter`] in an [`Arc`] so it can be cheaply cloned
    /// and moved into blocking tasks for use in async contexts.
    #[derive(Clone)]
    pub struct AsyncInterpreter {
        inner: Arc<Interpreter>,
    }

    impl AsyncInterpreter {
        /// Create an async interpreter from a file.
        ///
        /// Loading runs on the blocking thread pool so the async runtime
        /// is not stalled.
        ///
        /// # Errors
        /// Propagates any error from [`Interpreter::from_file`], or an
        /// `AsyncError` if the blocking task panics or is cancelled.
        pub async fn from_file<P: AsRef<Path>>(path: P) -> MnnResult<Self> {
            // Convert to an owned PathBuf *before* spawning, so the generic
            // parameter needs no `Send + 'static` bounds (the previous
            // signature required them unnecessarily; dropping them is
            // backward compatible).
            let path = path.as_ref().to_path_buf();
            tokio::task::spawn_blocking(move || Interpreter::from_file(path))
                .await
                .map_err(|e| MnnError::AsyncError(e.to_string()))?
                .map(Self::new)
        }

        /// Create an async interpreter from bytes.
        ///
        /// # Errors
        /// Propagates any error from [`Interpreter::from_bytes`], or an
        /// `AsyncError` if the blocking task panics or is cancelled.
        pub async fn from_bytes(data: Vec<u8>) -> MnnResult<Self> {
            tokio::task::spawn_blocking(move || Interpreter::from_bytes(&data))
                .await
                .map_err(|e| MnnError::AsyncError(e.to_string()))?
                .map(Self::new)
        }

        /// Create a new async interpreter from an existing interpreter.
        pub fn new(interpreter: Interpreter) -> Self {
            Self {
                inner: Arc::new(interpreter),
            }
        }

        /// Create a session on the blocking thread pool.
        ///
        /// # Errors
        /// Propagates any error from [`Interpreter::create_session`], or an
        /// `AsyncError` if the blocking task panics or is cancelled.
        pub async fn create_session(&self, config: ScheduleConfig) -> MnnResult<Session> {
            let interpreter = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || interpreter.create_session(config))
                .await
                .map_err(|e| MnnError::AsyncError(e.to_string()))?
        }

        /// Get the inner interpreter.
        pub fn inner(&self) -> &Interpreter {
            &self.inner
        }
    }
}
440
441#[cfg(feature = "async")]
442pub use async_impl::AsyncInterpreter;
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Loading a nonexistent path must yield `ModelNotFound`.
    #[test]
    fn test_interpreter_missing_file() {
        let err = Interpreter::from_file("nonexistent_model.mnn").unwrap_err();
        assert!(matches!(err, MnnError::ModelNotFound(_)));
    }

    /// An empty byte slice must be rejected as an invalid model.
    #[test]
    fn test_interpreter_empty_bytes() {
        let err = Interpreter::from_bytes(&[]).unwrap_err();
        assert!(matches!(err, MnnError::InvalidModel(_)));
    }
}