// mnn_rs/interpreter.rs
1//! Model interpreter for MNN inference.
2//!
3//! The interpreter holds the model and manages inference sessions.
4//! It's the primary entry point for running neural network inference with MNN.
5
6use crate::config::ScheduleConfig;
7use crate::config::SessionMode;
8use crate::error::{MnnError, MnnResult};
9use crate::session::Session;
10use crate::tensor::Tensor;
11use mnn_rs_sys::MNNInterpreter;
12use std::ffi::CString;
13use std::ffi::CStr;
14use std::path::Path;
15
16/// A model interpreter that holds a loaded neural network model.
17///
18/// The interpreter is the main entry point for MNN inference. It manages
19/// the model and can create multiple sessions for concurrent inference.
20///
21/// # Thread Safety
22/// The interpreter is thread-safe and can be shared across threads.
23/// Multiple sessions can be created from a single interpreter.
24///
25/// # Example
26/// ```no_run
27/// use mnn_rs::{Interpreter, ScheduleConfig, BackendType};
28///
29/// // Load a model
30/// let interpreter = Interpreter::from_file("model.mnn")?;
31///
32/// // Create a session with configuration
33/// let config = ScheduleConfig::new()
34/// .backend(BackendType::CPU)
35/// .num_threads(4);
36///
37/// let mut session = interpreter.create_session(config)?;
38///
39/// // Get input tensor and fill with data
40/// let input = session.get_input(None)?;
41/// // ... fill input with data ...
42///
43/// // Run inference
44/// session.run()?;
45///
46/// // Get output
47/// let output = session.get_output(None)?;
48/// # Ok::<(), mnn_rs::MnnError>(())
49/// ```
50pub struct Interpreter {
51 inner: *mut MNNInterpreter,
52 /// Model path (if loaded from file)
53 model_path: Option<String>,
54}
55
56// Safety: Interpreter operations are thread-safe through MNN's internal synchronization
57unsafe impl Send for Interpreter {}
58unsafe impl Sync for Interpreter {}
59
60impl Interpreter {
61 /// Create a new interpreter from a model file.
62 ///
63 /// # Arguments
64 /// * `path` - Path to the MNN model file
65 ///
66 /// # Returns
67 /// A new interpreter on success, or an error if the model cannot be loaded.
68 ///
69 /// # Example
70 /// ```no_run
71 /// use mnn_rs::Interpreter;
72 ///
73 /// let interpreter = Interpreter::from_file("model.mnn")?;
74 /// # Ok::<(), mnn_rs::MnnError>(())
75 /// ```
76 pub fn from_file<P: AsRef<Path>>(path: P) -> MnnResult<Self> {
77 let path_str = path.as_ref().to_string_lossy().into_owned();
78
79 // Check if file exists
80 if !path.as_ref().exists() {
81 return Err(MnnError::ModelNotFound(path.as_ref().to_path_buf()));
82 }
83
84 let c_path = CString::new(path_str.as_str())?;
85
86 let inner = unsafe { mnn_rs_sys::mnn_interpreter_create_from_file(c_path.as_ptr()) };
87
88 if inner.is_null() {
89 return Err(MnnError::invalid_model(format!(
90 "Failed to load model from: {}",
91 path_str
92 )));
93 }
94
95 Ok(Self {
96 inner,
97 model_path: Some(path_str),
98 })
99 }
100
101 /// Create a new interpreter from in-memory model data.
102 ///
103 /// This is useful for embedding models in the binary or loading
104 /// models from non-filesystem sources.
105 ///
106 /// # Arguments
107 /// * `data` - The model data as bytes
108 ///
109 /// # Returns
110 /// A new interpreter on success, or an error if the model cannot be loaded.
111 ///
112 /// # Example
113 /// ```ignore
114 /// use mnn_rs::Interpreter;
115 ///
116 /// let model_data = include_bytes!("model.mnn");
117 /// let interpreter = Interpreter::from_bytes(model_data)?;
118 /// # Ok::<(), mnn_rs::MnnError>(())
119 /// ```
120 pub fn from_bytes(data: &[u8]) -> MnnResult<Self> {
121 if data.is_empty() {
122 return Err(MnnError::invalid_model("Model data is empty"));
123 }
124
125 let inner = unsafe {
126 mnn_rs_sys::mnn_interpreter_create_from_buffer(
127 data.as_ptr() as *const std::ffi::c_void,
128 data.len(),
129 )
130 };
131
132 if inner.is_null() {
133 return Err(MnnError::invalid_model("Failed to load model from memory"));
134 }
135
136 Ok(Self {
137 inner,
138 model_path: None,
139 })
140 }
141
142 /// Create a new inference session.
143 ///
144 /// Sessions hold the runtime state for inference and can be created
145 /// with different backend configurations.
146 ///
147 /// # Arguments
148 /// * `config` - Schedule configuration specifying backend and settings
149 ///
150 /// # Returns
151 /// A new session on success, or an error if session creation fails.
152 pub fn create_session(&self, config: ScheduleConfig) -> MnnResult<Session> {
153 unsafe { Session::new(self.inner, config) }
154 }
155
156 /// Get the model path (if loaded from file).
157 pub fn model_path(&self) -> Option<&str> {
158 self.model_path.as_deref()
159 }
160
161 /// Get the business code (model identifier).
162 ///
163 /// # Returns
164 /// The business code string.
165 pub fn business_code(&self) -> String {
166 unsafe {
167 let ptr = mnn_rs_sys::mnn_interpreter_get_biz_code(self.inner);
168 if ptr.is_null() {
169 return String::new();
170 }
171 std::ffi::CStr::from_ptr(ptr)
172 .to_string_lossy()
173 .into_owned()
174 }
175 }
176
177 /// Get the model UUID.
178 ///
179 /// # Returns
180 /// The UUID string.
181 pub fn uuid(&self) -> String {
182 unsafe {
183 let ptr = mnn_rs_sys::mnn_interpreter_get_uuid(self.inner);
184 if ptr.is_null() {
185 return String::new();
186 }
187 std::ffi::CStr::from_ptr(ptr)
188 .to_string_lossy()
189 .into_owned()
190 }
191 }
192
193 /// Get the raw pointer to the underlying MNN Interpreter.
194 ///
195 /// # Safety
196 /// The returned pointer is owned by this Interpreter and must not be freed.
197 pub fn inner(&self) -> *mut mnn_rs_sys::MNNInterpreter {
198 self.inner
199 }
200
201 // ========================================================================
202 // Session Advanced Features
203 // ========================================================================
204
205 /// Set session mode.
206 ///
207 /// # Arguments
208 /// * `mode` - The session mode to set
209 pub fn set_session_mode(&self, mode: SessionMode) {
210 unsafe {
211 mnn_rs_sys::mnn_interpreter_set_session_mode(self.inner, mode.to_mnn());
212 }
213 }
214
215 /// Set cache file for optimization.
216 ///
217 /// # Arguments
218 /// * `path` - Path to the cache file
219 /// * `key_size` - Key size for cache lookup (default: 128)
220 pub fn set_cache_file<P: AsRef<Path>>(&self, path: P, key_size: usize) {
221 let path_str = path.as_ref().to_string_lossy();
222 if let Ok(c_path) = CString::new(path_str.as_ref()) {
223 unsafe {
224 mnn_rs_sys::mnn_interpreter_set_cache_file(self.inner, c_path.as_ptr(), key_size);
225 }
226 }
227 }
228
229 /// Update cache from a session.
230 ///
231 /// # Arguments
232 /// * `session` - The session to update cache from
233 ///
234 /// # Returns
235 /// Ok(()) on success, or an error on failure.
236 pub fn update_cache(&self, session: &Session) -> MnnResult<()> {
237 let result = unsafe {
238 mnn_rs_sys::mnn_interpreter_update_cache(self.inner, session.inner())
239 };
240
241 if result != mnn_rs_sys::MNN_ERROR_NONE as i32 {
242 return Err(MnnError::internal("Failed to update cache"));
243 }
244
245 Ok(())
246 }
247
248 /// Set external file for model.
249 ///
250 /// # Arguments
251 /// * `path` - Path to the external file
252 /// * `flag` - Flag for external file processing
253 pub fn set_external_file<P: AsRef<Path>>(&self, path: P, flag: usize) {
254 let path_str = path.as_ref().to_string_lossy();
255 if let Ok(c_path) = CString::new(path_str.as_ref()) {
256 unsafe {
257 mnn_rs_sys::mnn_interpreter_set_external_file(self.inner, c_path.as_ptr(), flag);
258 }
259 }
260 }
261
262 /// Get input tensor names for a session.
263 ///
264 /// # Arguments
265 /// * `session` - The session to get input names from
266 ///
267 /// # Returns
268 /// A vector of input tensor names.
269 pub fn get_input_names(&self, session: &Session) -> Vec<String> {
270 unsafe {
271 let array = mnn_rs_sys::mnn_interpreter_get_input_names(self.inner, session.inner());
272 let mut names = Vec::with_capacity(array.count as usize);
273
274 for i in 0..array.count {
275 let name_ptr = *array.names.offset(i as isize);
276 if !name_ptr.is_null() {
277 let name = CStr::from_ptr(name_ptr).to_string_lossy().into_owned();
278 names.push(name);
279 }
280 }
281
282 mnn_rs_sys::mnn_string_array_free(&mut std::ptr::read(&array) as *mut _);
283 names
284 }
285 }
286
287 /// Get output tensor names for a session.
288 ///
289 /// # Arguments
290 /// * `session` - The session to get output names from
291 ///
292 /// # Returns
293 /// A vector of output tensor names.
294 pub fn get_output_names(&self, session: &Session) -> Vec<String> {
295 unsafe {
296 let array = mnn_rs_sys::mnn_interpreter_get_output_names(self.inner, session.inner());
297 let mut names = Vec::with_capacity(array.count as usize);
298
299 for i in 0..array.count {
300 let name_ptr = *array.names.offset(i as isize);
301 if !name_ptr.is_null() {
302 let name = CStr::from_ptr(name_ptr).to_string_lossy().into_owned();
303 names.push(name);
304 }
305 }
306
307 mnn_rs_sys::mnn_string_array_free(&mut std::ptr::read(&array) as *mut _);
308 names
309 }
310 }
311
312 /// Resize a tensor with new shape.
313 ///
314 /// # Arguments
315 /// * `tensor` - The tensor to resize
316 /// * `shape` - New shape
317 pub fn resize_tensor(&self, tensor: &mut Tensor, shape: &[i32]) {
318 if shape.is_empty() {
319 return;
320 }
321
322 unsafe {
323 mnn_rs_sys::mnn_interpreter_resize_tensor(
324 self.inner,
325 tensor.inner_mut(),
326 shape.as_ptr(),
327 shape.len() as i32,
328 );
329 }
330 }
331
332 /// Resize a session after tensor resizing.
333 ///
334 /// # Arguments
335 /// * `session` - The session to resize
336 pub fn resize_session(&self, session: &mut Session) {
337 unsafe {
338 mnn_rs_sys::mnn_interpreter_resize_session(self.inner, session.inner_mut());
339 }
340 }
341
342 /// Get session FLOPS count.
343 ///
344 /// # Arguments
345 /// * `session` - The session to query
346 ///
347 /// # Returns
348 /// FLOPS count in millions.
349 pub fn get_session_flops(&self, session: &Session) -> f32 {
350 unsafe {
351 mnn_rs_sys::mnn_interpreter_get_session_flops(self.inner, session.inner())
352 }
353 }
354
355 /// Get session operator count.
356 ///
357 /// # Arguments
358 /// * `session` - The session to query
359 ///
360 /// # Returns
361 /// Operator count (approximate, based on FLOPS).
362 pub fn get_session_op_count(&self, session: &Session) -> i32 {
363 unsafe {
364 mnn_rs_sys::mnn_interpreter_get_session_op_count(self.inner, session.inner())
365 }
366 }
367}
368
369impl Drop for Interpreter {
370 fn drop(&mut self) {
371 if !self.inner.is_null() {
372 unsafe {
373 mnn_rs_sys::mnn_interpreter_destroy(self.inner);
374 }
375 }
376 }
377}
378
379impl std::fmt::Debug for Interpreter {
380 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 f.debug_struct("Interpreter")
382 .field("model_path", &self.model_path)
383 .field("business_code", &self.business_code())
384 .finish()
385 }
386}
387
#[cfg(feature = "async")]
mod async_impl {
    use super::*;
    use std::sync::Arc;

    /// An asynchronous interpreter wrapper.
    ///
    /// Wraps an [`Interpreter`] in an [`Arc`] so the blocking FFI calls can be
    /// moved onto tokio's blocking thread pool without tying up async worker
    /// threads. Cloning is cheap (reference-count bump).
    #[derive(Clone)]
    pub struct AsyncInterpreter {
        inner: Arc<Interpreter>,
    }

    impl AsyncInterpreter {
        /// Create an async interpreter by loading a model file on the
        /// blocking thread pool.
        ///
        /// # Errors
        /// Propagates any error from [`Interpreter::from_file`], or an
        /// `AsyncError` if the blocking task is cancelled or panics.
        pub async fn from_file<P: AsRef<Path> + Send + 'static>(path: P) -> MnnResult<Self> {
            // Convert to an owned PathBuf up front so only the owned path is
            // moved into the blocking task.
            let path = path.as_ref().to_path_buf();
            tokio::task::spawn_blocking(move || Interpreter::from_file(path))
                .await
                .map_err(|e| MnnError::AsyncError(e.to_string()))?
                .map(Self::new)
        }

        /// Create an async interpreter from in-memory model bytes on the
        /// blocking thread pool.
        ///
        /// # Errors
        /// Propagates any error from [`Interpreter::from_bytes`], or an
        /// `AsyncError` if the blocking task is cancelled or panics.
        pub async fn from_bytes(data: Vec<u8>) -> MnnResult<Self> {
            tokio::task::spawn_blocking(move || Interpreter::from_bytes(&data))
                .await
                .map_err(|e| MnnError::AsyncError(e.to_string()))?
                .map(Self::new)
        }

        /// Wrap an existing interpreter for async use.
        pub fn new(interpreter: Interpreter) -> Self {
            Self {
                inner: Arc::new(interpreter),
            }
        }

        /// Create a session on the blocking thread pool.
        ///
        /// # Errors
        /// Propagates any error from [`Interpreter::create_session`], or an
        /// `AsyncError` if the blocking task is cancelled or panics.
        pub async fn create_session(&self, config: ScheduleConfig) -> MnnResult<Session> {
            // Clone the Arc so the task owns a handle to the interpreter.
            let interpreter = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || interpreter.create_session(config))
                .await
                .map_err(|e| MnnError::AsyncError(e.to_string()))?
        }

        /// Get the inner interpreter.
        pub fn inner(&self) -> &Interpreter {
            &self.inner
        }
    }
}
440
441#[cfg(feature = "async")]
442pub use async_impl::AsyncInterpreter;
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Loading a path that does not exist must fail with `ModelNotFound`.
    #[test]
    fn test_interpreter_missing_file() {
        let err = Interpreter::from_file("nonexistent_model.mnn").unwrap_err();
        assert!(matches!(err, MnnError::ModelNotFound(_)));
    }

    /// An empty byte slice is rejected before reaching the FFI layer.
    #[test]
    fn test_interpreter_empty_bytes() {
        let err = Interpreter::from_bytes(&[]).unwrap_err();
        assert!(matches!(err, MnnError::InvalidModel(_)));
    }
}