// mnn_rs — session.rs
1//! Session management for MNN inference.
2//!
3//! A session represents an inference context with allocated resources.
4//! Sessions are created from interpreters and are used to run inference.
5
6use crate::config::ScheduleConfig;
7use crate::error::{MnnError, MnnResult};
8use crate::tensor::Tensor;
9use mnn_rs_sys::MNNSession;
10use std::ffi::CString;
11
/// An inference session.
///
/// Sessions hold the runtime state for model inference, including
/// allocated memory and intermediate tensors. The native session handle is
/// released via `mnn_interpreter_release_session` when this value is dropped.
///
/// # Thread Safety
/// Sessions can be used from multiple threads, but `run()` must not be
/// called concurrently on the same session. (`run()` takes `&mut self`, so
/// safe code already cannot do that.)
pub struct Session {
    /// Owned native session handle; freed in `Drop`.
    inner: *mut MNNSession,
    /// Interpreter pointer (not owned). Every FFI call here passes it
    /// alongside `inner`, so it must outlive this session.
    interpreter: *mut mnn_rs_sys::MNNInterpreter,
    /// Whether the session has been run at least once
    has_run: bool,
}
27
// Safety: Session operations are thread-safe through MNN's internal synchronization
// NOTE(review): `Send` moves the raw handles between threads; `Sync` additionally
// allows concurrent `get_input`/`get_output` calls through a shared `&Session`.
// `run()` requires `&mut self` and so is already exclusive, but the concurrent
// getter claim rests entirely on MNN's internal locking — confirm against the
// MNN C API documentation.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
31
32impl Session {
33    /// Create a new session.
34    ///
35    /// # Safety
36    /// The interpreter pointer must be valid.
37    pub(crate) unsafe fn new(
38        interpreter: *mut mnn_rs_sys::MNNInterpreter,
39        config: ScheduleConfig,
40    ) -> MnnResult<Self> {
41        // SAFETY: Caller ensures interpreter pointer is valid
42        let inner = unsafe {
43            mnn_rs_sys::mnn_interpreter_create_session(
44                interpreter,
45                config.backend_config.backend_type.to_mnn_type(),
46                config.num_threads as i32,
47            )
48        };
49
50        if inner.is_null() {
51            return Err(MnnError::session_error("Failed to create session"));
52        }
53
54        Ok(Self {
55            inner,
56            interpreter,
57            has_run: false,
58        })
59    }
60
61    /// Get an input tensor by name.
62    ///
63    /// # Arguments
64    /// * `name` - The name of the input tensor (None for the first input)
65    ///
66    /// # Returns
67    /// A mutable reference to the input tensor.
68    pub fn get_input(&self, name: Option<&str>) -> MnnResult<Tensor> {
69        unsafe {
70            let name_ptr = match name {
71                Some(n) => {
72                    let c_name = CString::new(n)?;
73                    c_name.as_ptr()
74                }
75                None => std::ptr::null(),
76            };
77
78            let tensor_ptr =
79                mnn_rs_sys::mnn_interpreter_get_session_input(self.interpreter, self.inner, name_ptr);
80
81            if tensor_ptr.is_null() {
82                return Err(MnnError::tensor_error(match name {
83                    Some(n) => format!("Input tensor '{}' not found", n),
84                    None => "No input tensor found".to_string(),
85                }));
86            }
87
88            Ok(Tensor::from_ptr_with_name(
89                tensor_ptr,
90                name.map(|s| s.to_string()),
91            ))
92        }
93    }
94
95    /// Get an output tensor by name.
96    ///
97    /// # Arguments
98    /// * `name` - The name of the output tensor (None for the first output)
99    ///
100    /// # Returns
101    /// A reference to the output tensor.
102    pub fn get_output(&self, name: Option<&str>) -> MnnResult<Tensor> {
103        unsafe {
104            let name_ptr = match name {
105                Some(n) => {
106                    let c_name = CString::new(n)?;
107                    c_name.as_ptr()
108                }
109                None => std::ptr::null(),
110            };
111
112            let tensor_ptr = mnn_rs_sys::mnn_interpreter_get_session_output(
113                self.interpreter,
114                self.inner,
115                name_ptr,
116            );
117
118            if tensor_ptr.is_null() {
119                return Err(MnnError::tensor_error(match name {
120                    Some(n) => format!("Output tensor '{}' not found", n),
121                    None => "No output tensor found".to_string(),
122                }));
123            }
124
125            Ok(Tensor::from_ptr_with_name(
126                tensor_ptr,
127                name.map(|s| s.to_string()),
128            ))
129        }
130    }
131
132    /// Run inference.
133    ///
134    /// This executes the model on the current input tensors and
135    /// populates the output tensors.
136    ///
137    /// # Returns
138    /// Ok(()) on success, or an error on failure.
139    pub fn run(&mut self) -> MnnResult<()> {
140        let result =
141            unsafe { mnn_rs_sys::mnn_interpreter_run_session(self.interpreter, self.inner) };
142
143        match result {
144            x if x == mnn_rs_sys::MNN_ERROR_NONE => {
145                self.has_run = true;
146                Ok(())
147            }
148            x if x == mnn_rs_sys::MNN_ERROR_OUT_OF_MEMORY => {
149                Err(MnnError::out_of_memory("Out of memory during inference"))
150            }
151            x if x == mnn_rs_sys::MNN_ERROR_NOT_SUPPORT => {
152                Err(MnnError::unsupported("Operation not supported"))
153            }
154            x if x == mnn_rs_sys::MNN_ERROR_EXECUTION => {
155                Err(MnnError::internal("Execution error during inference"))
156            }
157            code => Err(MnnError::internal(format!(
158                "Inference failed with error code: {}",
159                code
160            ))),
161        }
162    }
163
164    /// Check if the session has been run at least once.
165    pub fn has_run(&self) -> bool {
166        self.has_run
167    }
168
169    /// Get the memory usage of this session in bytes.
170    pub fn memory_usage(&self) -> usize {
171        let memory_mb = unsafe {
172            mnn_rs_sys::mnn_interpreter_get_session_memory(self.interpreter, self.inner)
173        };
174        (memory_mb * 1024.0 * 1024.0) as usize
175    }
176
177    /// Get the FLOPS count of this session.
178    pub fn flops(&self) -> f32 {
179        unsafe { mnn_rs_sys::mnn_interpreter_get_session_flops(self.interpreter, self.inner) }
180    }
181
182    /// Get the raw pointer to the underlying MNN session.
183    ///
184    /// # Safety
185    /// The returned pointer is owned by this Session and must not be freed.
186    pub fn inner(&self) -> *mut MNNSession {
187        self.inner
188    }
189
190    /// Get the mutable raw pointer to the underlying MNN session.
191    ///
192    /// # Safety
193    /// The returned pointer is owned by this Session and must not be freed.
194    pub fn inner_mut(&mut self) -> *mut MNNSession {
195        self.inner
196    }
197
198    /// Get the interpreter pointer (not owned).
199    ///
200    /// # Safety
201    /// The returned pointer is owned by the Interpreter.
202    pub fn interpreter(&self) -> *mut mnn_rs_sys::MNNInterpreter {
203        self.interpreter
204    }
205
206    /// Create a new session from raw pointers.
207    ///
208    /// # Safety
209    /// The pointers must be valid and the interpreter must outlive the session.
210    pub unsafe fn from_ptr(inner: *mut MNNSession, interpreter: *mut mnn_rs_sys::MNNInterpreter) -> Self {
211        Self {
212            inner,
213            interpreter,
214            has_run: false,
215        }
216    }
217}
218
219impl Drop for Session {
220    fn drop(&mut self) {
221        if !self.inner.is_null() && !self.interpreter.is_null() {
222            unsafe {
223                mnn_rs_sys::mnn_interpreter_release_session(self.interpreter, self.inner);
224            }
225        }
226    }
227}
228
229impl std::fmt::Debug for Session {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.debug_struct("Session")
232            .field("has_run", &self.has_run)
233            .finish()
234    }
235}
236
/// A guard providing scoped, exclusive access to a [`Session`].
///
/// Dropping the guard releases only the exclusive borrow of the session;
/// the session's native resources are freed when the [`Session`] itself is
/// dropped, not when the guard is.
pub struct SessionGuard<'a> {
    // Exclusive borrow: while the guard lives, nothing else can run or
    // mutate this session.
    session: &'a mut Session,
}
243
244impl std::fmt::Debug for SessionGuard<'_> {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        f.debug_struct("SessionGuard").finish_non_exhaustive()
247    }
248}
249
250impl<'a> SessionGuard<'a> {
251    /// Create a new guard for a session.
252    pub fn new(session: &'a mut Session) -> Self {
253        Self { session }
254    }
255
256    /// Run inference.
257    pub fn run(&mut self) -> MnnResult<()> {
258        self.session.run()
259    }
260}
261
impl<'a> Drop for SessionGuard<'a> {
    fn drop(&mut self) {
        // Intentionally a no-op: the guard only holds a borrow and owns no
        // resources; the native session is released by `Session`'s own Drop.
        // NOTE(review): an empty Drop impl adds drop-check constraints for no
        // benefit — consider removing it in a future breaking change.
    }
}
267
#[cfg(feature = "async")]
mod async_impl {
    use super::*;

    impl Session {
        /// Run inference asynchronously.
        ///
        /// The blocking FFI call is moved onto tokio's blocking thread pool
        /// so the async executor is not stalled for the duration of
        /// inference.
        ///
        /// # Errors
        /// Returns `MnnError::AsyncError` if the blocking task fails to join,
        /// or the mapped MNN error for a failed inference run.
        pub async fn run_async(&mut self) -> MnnResult<()> {
            // BUGFIX: raw pointers are `!Send`, so a closure capturing
            // `*mut MNNSession` / `*mut MNNInterpreter` directly does not
            // satisfy `spawn_blocking`'s `F: Send + 'static` bound and fails
            // to compile. Round-trip the addresses through `usize` instead.
            let inner_addr = self.inner as usize;
            let interpreter_addr = self.interpreter as usize;

            let result = tokio::task::spawn_blocking(move || unsafe {
                // SAFETY: the addresses come from a live `Session` that is
                // exclusively borrowed (`&mut self`) for the whole await, so
                // the pointers stay valid and no other call can race this
                // run.
                mnn_rs_sys::mnn_interpreter_run_session(
                    interpreter_addr as *mut mnn_rs_sys::MNNInterpreter,
                    inner_addr as *mut MNNSession,
                )
            })
            .await
            .map_err(|e| MnnError::AsyncError(e.to_string()))?;

            // Same error-code mapping as the synchronous `Session::run`.
            match result {
                x if x == mnn_rs_sys::MNN_ERROR_NONE => {
                    self.has_run = true;
                    Ok(())
                }
                x if x == mnn_rs_sys::MNN_ERROR_OUT_OF_MEMORY => {
                    Err(MnnError::out_of_memory("Out of memory during inference"))
                }
                x if x == mnn_rs_sys::MNN_ERROR_NOT_SUPPORT => {
                    Err(MnnError::unsupported("Operation not supported"))
                }
                x if x == mnn_rs_sys::MNN_ERROR_EXECUTION => {
                    Err(MnnError::internal("Execution error during inference"))
                }
                code => Err(MnnError::internal(format!(
                    "Inference failed with error code: {}",
                    code
                ))),
            }
        }
    }
}
306
#[cfg(test)]
mod tests {
    // TODO: no unit tests yet — constructing a `Session` requires a loaded
    // MNN model through the FFI layer, which is not available in a plain
    // `cargo test` environment.
}