1use crate::config::ScheduleConfig;
7use crate::error::{MnnError, MnnResult};
8use crate::tensor::Tensor;
9use mnn_rs_sys::MNNSession;
10use std::ffi::CString;
11
/// Safe wrapper around a native MNN inference session.
///
/// Holds the raw session handle together with the interpreter handle that
/// created it (every session FFI call in this file takes both). The session
/// handle is released through the interpreter in `Drop`.
pub struct Session {
    // Raw native session handle; released in `Drop` when non-null.
    inner: *mut MNNSession,
    // Handle to the interpreter that created `inner`. NOTE(review): this type
    // does not release the interpreter, so presumably the creator keeps it
    // alive for the session's whole lifetime — confirm at the call sites.
    interpreter: *mut mnn_rs_sys::MNNInterpreter,
    // Set to true after the first successful `run()` (or `run_async()`).
    has_run: bool,
}
27
// SAFETY(review): these assert that the raw MNN handles may be moved to and
// referenced from other threads. Nothing visible in this file demonstrates
// that the MNN C API tolerates cross-thread (especially concurrent, for
// `Sync`) use of a session — confirm against MNN's threading documentation
// before relying on these impls.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
31
32impl Session {
33 pub(crate) unsafe fn new(
38 interpreter: *mut mnn_rs_sys::MNNInterpreter,
39 config: ScheduleConfig,
40 ) -> MnnResult<Self> {
41 let inner = unsafe {
43 mnn_rs_sys::mnn_interpreter_create_session(
44 interpreter,
45 config.backend_config.backend_type.to_mnn_type(),
46 config.num_threads as i32,
47 )
48 };
49
50 if inner.is_null() {
51 return Err(MnnError::session_error("Failed to create session"));
52 }
53
54 Ok(Self {
55 inner,
56 interpreter,
57 has_run: false,
58 })
59 }
60
61 pub fn get_input(&self, name: Option<&str>) -> MnnResult<Tensor> {
69 unsafe {
70 let name_ptr = match name {
71 Some(n) => {
72 let c_name = CString::new(n)?;
73 c_name.as_ptr()
74 }
75 None => std::ptr::null(),
76 };
77
78 let tensor_ptr =
79 mnn_rs_sys::mnn_interpreter_get_session_input(self.interpreter, self.inner, name_ptr);
80
81 if tensor_ptr.is_null() {
82 return Err(MnnError::tensor_error(match name {
83 Some(n) => format!("Input tensor '{}' not found", n),
84 None => "No input tensor found".to_string(),
85 }));
86 }
87
88 Ok(Tensor::from_ptr_with_name(
89 tensor_ptr,
90 name.map(|s| s.to_string()),
91 ))
92 }
93 }
94
95 pub fn get_output(&self, name: Option<&str>) -> MnnResult<Tensor> {
103 unsafe {
104 let name_ptr = match name {
105 Some(n) => {
106 let c_name = CString::new(n)?;
107 c_name.as_ptr()
108 }
109 None => std::ptr::null(),
110 };
111
112 let tensor_ptr = mnn_rs_sys::mnn_interpreter_get_session_output(
113 self.interpreter,
114 self.inner,
115 name_ptr,
116 );
117
118 if tensor_ptr.is_null() {
119 return Err(MnnError::tensor_error(match name {
120 Some(n) => format!("Output tensor '{}' not found", n),
121 None => "No output tensor found".to_string(),
122 }));
123 }
124
125 Ok(Tensor::from_ptr_with_name(
126 tensor_ptr,
127 name.map(|s| s.to_string()),
128 ))
129 }
130 }
131
132 pub fn run(&mut self) -> MnnResult<()> {
140 let result =
141 unsafe { mnn_rs_sys::mnn_interpreter_run_session(self.interpreter, self.inner) };
142
143 match result {
144 x if x == mnn_rs_sys::MNN_ERROR_NONE => {
145 self.has_run = true;
146 Ok(())
147 }
148 x if x == mnn_rs_sys::MNN_ERROR_OUT_OF_MEMORY => {
149 Err(MnnError::out_of_memory("Out of memory during inference"))
150 }
151 x if x == mnn_rs_sys::MNN_ERROR_NOT_SUPPORT => {
152 Err(MnnError::unsupported("Operation not supported"))
153 }
154 x if x == mnn_rs_sys::MNN_ERROR_EXECUTION => {
155 Err(MnnError::internal("Execution error during inference"))
156 }
157 code => Err(MnnError::internal(format!(
158 "Inference failed with error code: {}",
159 code
160 ))),
161 }
162 }
163
164 pub fn has_run(&self) -> bool {
166 self.has_run
167 }
168
169 pub fn memory_usage(&self) -> usize {
171 let memory_mb = unsafe {
172 mnn_rs_sys::mnn_interpreter_get_session_memory(self.interpreter, self.inner)
173 };
174 (memory_mb * 1024.0 * 1024.0) as usize
175 }
176
177 pub fn flops(&self) -> f32 {
179 unsafe { mnn_rs_sys::mnn_interpreter_get_session_flops(self.interpreter, self.inner) }
180 }
181
182 pub fn inner(&self) -> *mut MNNSession {
187 self.inner
188 }
189
190 pub fn inner_mut(&mut self) -> *mut MNNSession {
195 self.inner
196 }
197
198 pub fn interpreter(&self) -> *mut mnn_rs_sys::MNNInterpreter {
203 self.interpreter
204 }
205
206 pub unsafe fn from_ptr(inner: *mut MNNSession, interpreter: *mut mnn_rs_sys::MNNInterpreter) -> Self {
211 Self {
212 inner,
213 interpreter,
214 has_run: false,
215 }
216 }
217}
218
219impl Drop for Session {
220 fn drop(&mut self) {
221 if !self.inner.is_null() && !self.interpreter.is_null() {
222 unsafe {
223 mnn_rs_sys::mnn_interpreter_release_session(self.interpreter, self.inner);
224 }
225 }
226 }
227}
228
229impl std::fmt::Debug for Session {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("Session")
232 .field("has_run", &self.has_run)
233 .finish()
234 }
235}
236
/// RAII-style guard holding an exclusive borrow of a [`Session`].
///
/// NOTE(review): the guard currently only forwards `run()` and its `Drop` is
/// empty — presumably a hook for future scoped setup/teardown; confirm.
pub struct SessionGuard<'a> {
    // Exclusive borrow of the wrapped session for the guard's lifetime.
    session: &'a mut Session,
}
243
244impl std::fmt::Debug for SessionGuard<'_> {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 f.debug_struct("SessionGuard").finish_non_exhaustive()
247 }
248}
249
250impl<'a> SessionGuard<'a> {
251 pub fn new(session: &'a mut Session) -> Self {
253 Self { session }
254 }
255
256 pub fn run(&mut self) -> MnnResult<()> {
258 self.session.run()
259 }
260}
261
impl<'a> Drop for SessionGuard<'a> {
    // Intentionally empty: the guard performs no cleanup today, but the Drop
    // impl keeps it non-trivially-droppable. NOTE(review): presumably a
    // placeholder for future teardown logic — confirm before removing.
    fn drop(&mut self) {
    }
}
267
#[cfg(feature = "async")]
mod async_impl {
    use super::*;

    /// Pair of raw MNN handles that can be moved onto tokio's blocking pool.
    struct RawHandles {
        session: *mut MNNSession,
        interpreter: *mut mnn_rs_sys::MNNInterpreter,
    }

    // SAFETY: `Session` is already (unsafely) declared `Send`, so shipping its
    // raw handles to another thread adds no new assumption. The wrapper is
    // required because bare raw pointers are `!Send`, and a closure capturing
    // them cannot satisfy `spawn_blocking`'s `F: Send + 'static` bound — the
    // previous code did not compile with the `async` feature enabled.
    unsafe impl Send for RawHandles {}

    impl Session {
        /// Runs inference on tokio's blocking thread pool so the async
        /// executor is not stalled for the duration of the inference.
        ///
        /// On success, marks the session as having run (see [`Session::has_run`]).
        ///
        /// # Errors
        /// Maps MNN status codes onto the corresponding `MnnError` variants;
        /// a panicked/cancelled blocking task surfaces as `MnnError::AsyncError`.
        pub async fn run_async(&mut self) -> MnnResult<()> {
            let handles = RawHandles {
                session: self.inner,
                interpreter: self.interpreter,
            };

            // `&mut self` guarantees exclusive access while we await, so the
            // blocking task is the only user of the session handle.
            let result = tokio::task::spawn_blocking(move || unsafe {
                mnn_rs_sys::mnn_interpreter_run_session(handles.interpreter, handles.session)
            })
            .await
            .map_err(|e| MnnError::AsyncError(e.to_string()))?;

            match result {
                x if x == mnn_rs_sys::MNN_ERROR_NONE => {
                    self.has_run = true;
                    Ok(())
                }
                x if x == mnn_rs_sys::MNN_ERROR_OUT_OF_MEMORY => {
                    Err(MnnError::out_of_memory("Out of memory during inference"))
                }
                x if x == mnn_rs_sys::MNN_ERROR_NOT_SUPPORT => {
                    Err(MnnError::unsupported("Operation not supported"))
                }
                x if x == mnn_rs_sys::MNN_ERROR_EXECUTION => {
                    Err(MnnError::internal("Execution error during inference"))
                }
                code => Err(MnnError::internal(format!(
                    "Inference failed with error code: {}",
                    code
                ))),
            }
        }
    }
}
306
#[cfg(test)]
// NOTE(review): empty — constructing a Session requires a loaded model and
// interpreter, so coverage presumably lives in integration tests; confirm.
mod tests {}