onnxruntime_ng/
session.rs

1//! Module containing session types
2
3use std::{ffi::CString, fmt::Debug, path::Path};
4
5#[cfg(not(target_family = "windows"))]
6use std::os::unix::ffi::OsStrExt;
7#[cfg(target_family = "windows")]
8use std::os::windows::ffi::OsStrExt;
9
10#[cfg(feature = "model-fetching")]
11use std::env;
12
13use ndarray::Array;
14use tracing::{debug, error};
15
16use onnxruntime_sys_ng as sys;
17
18use crate::{
19    char_p_to_string,
20    environment::Environment,
21    error::{
22        assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError,
23        OrtApiError, OrtError, Result,
24    },
25    g_ort,
26    memory::MemoryInfo,
27    tensor::{
28        ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor},
29        OrtTensor,
30    },
31    AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType,
32    TypeToTensorElementDataType,
33};
34
35#[cfg(feature = "model-fetching")]
36use crate::{download::AvailableOnnxModel, error::OrtDownloadError};
37
38/// Type used to create a session using the _builder pattern_
39///
40/// A `SessionBuilder` is created by calling the
41/// [`Environment::new_session_builder()`](../env/struct.Environment.html#method.new_session_builder)
42/// method on the environment.
43///
44/// Once created, use the different methods to configure the session.
45///
46/// Once configured, use the [`SessionBuilder::with_model_from_file()`](../session/struct.SessionBuilder.html#method.with_model_from_file)
47/// method to "commit" the builder configuration into a [`Session`](../session/struct.Session.html).
48///
49/// # Example
50///
51/// ```no_run
52/// # use std::error::Error;
53/// # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel};
54/// # fn main() -> Result<(), Box<dyn Error>> {
55/// let environment = Environment::builder()
56///     .with_name("test")
57///     .with_log_level(LoggingLevel::Verbose)
58///     .build()?;
59/// let mut session = environment
60///     .new_session_builder()?
61///     .with_optimization_level(GraphOptimizationLevel::Basic)?
62///     .with_number_threads(1)?
63///     .with_model_from_file("squeezenet.onnx")?;
64/// # Ok(())
65/// # }
66/// ```
67#[derive(Debug)]
68pub struct SessionBuilder<'a> {
69    env: &'a Environment,
70    session_options_ptr: *mut sys::OrtSessionOptions,
71
72    allocator: AllocatorType,
73    memory_type: MemType,
74}
75
76impl<'a> Drop for SessionBuilder<'a> {
77    #[tracing::instrument]
78    fn drop(&mut self) {
79        if self.session_options_ptr.is_null() {
80            error!("Session options pointer is null, not dropping");
81        } else {
82            debug!("Dropping the session options.");
83            unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
84        }
85    }
86}
87
88impl<'a> SessionBuilder<'a> {
89    pub(crate) fn new(env: &'a Environment) -> Result<SessionBuilder<'a>> {
90        let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
91        let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };
92
93        status_to_result(status).map_err(OrtError::SessionOptions)?;
94        assert_null_pointer(status, "SessionStatus")?;
95        assert_not_null_pointer(session_options_ptr, "SessionOptions")?;
96
97        Ok(SessionBuilder {
98            env,
99            session_options_ptr,
100            allocator: AllocatorType::Arena,
101            memory_type: MemType::Default,
102        })
103    }
104
105    /// Configure the session to use a number of threads
106    pub fn with_number_threads(self, num_threads: i16) -> Result<SessionBuilder<'a>> {
107        // FIXME: Pre-built binaries use OpenMP, set env variable instead
108
109        // We use a u16 in the builder to cover the 16-bits positive values of a i32.
110        let num_threads = num_threads as i32;
111        let status =
112            unsafe { g_ort().SetIntraOpNumThreads.unwrap()(self.session_options_ptr, num_threads) };
113        status_to_result(status).map_err(OrtError::SessionOptions)?;
114        assert_null_pointer(status, "SessionStatus")?;
115        Ok(self)
116    }
117
118    /// Set the session's optimization level
119    pub fn with_optimization_level(
120        self,
121        opt_level: GraphOptimizationLevel,
122    ) -> Result<SessionBuilder<'a>> {
123        // Sets graph optimization level
124        unsafe {
125            g_ort().SetSessionGraphOptimizationLevel.unwrap()(
126                self.session_options_ptr,
127                opt_level.into(),
128            )
129        };
130        Ok(self)
131    }
132
133    /// Set the session to use cpu
134    #[cfg(feature = "cuda")]
135    pub fn use_cpu(self, use_arena: i32) -> Result<SessionBuilder<'a>> {
136        unsafe {
137            sys::OrtSessionOptionsAppendExecutionProvider_CPU(self.session_options_ptr, use_arena);
138        }
139        Ok(self)
140    }
141
142    /// Set the session to use cuda
143    #[cfg(feature = "cuda")]
144    pub fn use_cuda(self, device_id: i32) -> Result<SessionBuilder<'a>> {
145        unsafe {
146            sys::OrtSessionOptionsAppendExecutionProvider_CUDA(self.session_options_ptr, device_id);
147        }
148        Ok(self)
149    }
150
151    /// Set the session's allocator
152    ///
153    /// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena)
154    pub fn with_allocator(mut self, allocator: AllocatorType) -> Result<SessionBuilder<'a>> {
155        self.allocator = allocator;
156        Ok(self)
157    }
158
159    /// Set the session's memory type
160    ///
161    /// Defaults to [`MemType::Default`](../enum.MemType.html#variant.Default)
162    pub fn with_memory_type(mut self, memory_type: MemType) -> Result<SessionBuilder<'a>> {
163        self.memory_type = memory_type;
164        Ok(self)
165    }
166
167    /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
168    #[cfg(feature = "model-fetching")]
169    pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>
170    where
171        M: Into<AvailableOnnxModel>,
172    {
173        self.with_model_downloaded_monomorphized(model.into())
174    }
175
176    #[cfg(feature = "model-fetching")]
177    fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result<Session<'a>> {
178        let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
179        let downloaded_path = model.download_to(download_dir)?;
180        self.with_model_from_file(downloaded_path)
181    }
182
183    // TODO: Add all functions changing the options.
184    //       See all OrtApi methods taking a `options: *mut OrtSessionOptions`.
185
186    /// Load an ONNX graph from a file and commit the session
187    pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> Result<Session<'a>>
188    where
189        P: AsRef<Path> + 'a,
190    {
191        let model_filepath = model_filepath_ref.as_ref();
192        let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
193
194        if !model_filepath.exists() {
195            return Err(OrtError::FileDoesNotExists {
196                filename: model_filepath.to_path_buf(),
197            });
198        }
199
200        // Build an OsString than a vector of bytes to pass to C
201        let model_path = std::ffi::OsString::from(model_filepath);
202        #[cfg(target_family = "windows")]
203        let model_path: Vec<u16> = model_path
204            .encode_wide()
205            .chain(std::iter::once(0)) // Make sure we have a null terminated string
206            .collect();
207        #[cfg(not(target_family = "windows"))]
208        let model_path: Vec<std::os::raw::c_char> = model_path
209            .as_bytes()
210            .iter()
211            .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string
212            .map(|b| *b as std::os::raw::c_char)
213            .collect();
214
215        let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
216
217        let status = unsafe {
218            g_ort().CreateSession.unwrap()(
219                env_ptr,
220                model_path.as_ptr(),
221                self.session_options_ptr,
222                &mut session_ptr,
223            )
224        };
225        status_to_result(status).map_err(OrtError::Session)?;
226        assert_null_pointer(status, "SessionStatus")?;
227        assert_not_null_pointer(session_ptr, "Session")?;
228
229        let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
230        let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
231        status_to_result(status).map_err(OrtError::Allocator)?;
232        assert_null_pointer(status, "SessionStatus")?;
233        assert_not_null_pointer(allocator_ptr, "Allocator")?;
234
235        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
236
237        // Extract input and output properties
238        let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
239        let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
240        let inputs = (0..num_input_nodes)
241            .map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
242            .collect::<Result<Vec<Input>>>()?;
243        let outputs = (0..num_output_nodes)
244            .map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
245            .collect::<Result<Vec<Output>>>()?;
246
247        Ok(Session {
248            env: self.env,
249            session_ptr,
250            allocator_ptr,
251            memory_info,
252            inputs,
253            outputs,
254        })
255    }
256
257    /// Load an ONNX graph from memory and commit the session
258    pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session<'a>>
259    where
260        B: AsRef<[u8]>,
261    {
262        self.with_model_from_memory_monomorphized(model_bytes.as_ref())
263    }
264
265    fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session<'a>> {
266        let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
267
268        let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
269
270        let status = unsafe {
271            let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
272            let model_data_length = model_bytes.len();
273            g_ort().CreateSessionFromArray.unwrap()(
274                env_ptr,
275                model_data,
276                model_data_length,
277                self.session_options_ptr,
278                &mut session_ptr,
279            )
280        };
281        status_to_result(status).map_err(OrtError::Session)?;
282        assert_null_pointer(status, "SessionStatus")?;
283        assert_not_null_pointer(session_ptr, "Session")?;
284
285        let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
286        let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
287        status_to_result(status).map_err(OrtError::Allocator)?;
288        assert_null_pointer(status, "SessionStatus")?;
289        assert_not_null_pointer(allocator_ptr, "Allocator")?;
290
291        let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
292
293        // Extract input and output properties
294        let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
295        let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
296        let inputs = (0..num_input_nodes)
297            .map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
298            .collect::<Result<Vec<Input>>>()?;
299        let outputs = (0..num_output_nodes)
300            .map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
301            .collect::<Result<Vec<Output>>>()?;
302
303        Ok(Session {
304            env: self.env,
305            session_ptr,
306            allocator_ptr,
307            memory_info,
308            inputs,
309            outputs,
310        })
311    }
312}
313
314/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
315#[derive(Debug)]
316pub struct Session<'a> {
317    env: &'a Environment,
318    session_ptr: *mut sys::OrtSession,
319    allocator_ptr: *mut sys::OrtAllocator,
320    memory_info: MemoryInfo,
321    /// Information about the ONNX's inputs as stored in loaded file
322    pub inputs: Vec<Input>,
323    /// Information about the ONNX's outputs as stored in loaded file
324    pub outputs: Vec<Output>,
325}
326
327/// Information about an ONNX's input as stored in loaded file
328#[derive(Debug)]
329pub struct Input {
330    /// Name of the input layer
331    pub name: String,
332    /// Type of the input layer's elements
333    pub input_type: TensorElementDataType,
334    /// Shape of the input layer
335    ///
336    /// C API uses a i64 for the dimensions. We use an unsigned of the same range of the positive values.
337    pub dimensions: Vec<Option<u32>>,
338}
339
340/// Information about an ONNX's output as stored in loaded file
341#[derive(Debug)]
342pub struct Output {
343    /// Name of the output layer
344    pub name: String,
345    /// Type of the output layer's elements
346    pub output_type: TensorElementDataType,
347    /// Shape of the output layer
348    ///
349    /// C API uses a i64 for the dimensions. We use an unsigned of the same range of the positive values.
350    pub dimensions: Vec<Option<u32>>,
351}
352
353impl Input {
354    /// Return an iterator over the shape elements of the input layer
355    ///
356    /// Note: The member [`Input::dimensions`](struct.Input.html#structfield.dimensions)
357    /// stores `u32` (since ONNX uses `i64` but which cannot be negative) so the
358    /// iterator converts to `usize`.
359    pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
360        self.dimensions.iter().map(|d| d.map(|d2| d2 as usize))
361    }
362}
363
364impl Output {
365    /// Return an iterator over the shape elements of the output layer
366    ///
367    /// Note: The member [`Output::dimensions`](struct.Output.html#structfield.dimensions)
368    /// stores `u32` (since ONNX uses `i64` but which cannot be negative) so the
369    /// iterator converts to `usize`.
370    pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
371        self.dimensions.iter().map(|d| d.map(|d2| d2 as usize))
372    }
373}
374
375impl<'a> Drop for Session<'a> {
376    #[tracing::instrument]
377    fn drop(&mut self) {
378        debug!("Dropping the session.");
379        if self.session_ptr.is_null() {
380            error!("Session pointer is null, not dropping.");
381        } else {
382            unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
383        }
384        // FIXME: There is no C function to release the allocator?
385
386        self.session_ptr = std::ptr::null_mut();
387        self.allocator_ptr = std::ptr::null_mut();
388    }
389}
390
391impl<'a> Session<'a> {
392    /// Run the input data through the ONNX graph, performing inference.
393    ///
394    /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
395    /// used for the input data here.
396    pub fn run<'s, 't, 'm, TIn, TOut, D>(
397        &'s mut self,
398        input_arrays: Vec<Array<TIn, D>>,
399    ) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
400    where
401        TIn: TypeToTensorElementDataType + Debug + Clone,
402        TOut: TypeToTensorElementDataType + Debug + Clone,
403        D: ndarray::Dimension,
404        'm: 't, // 'm outlives 't (memory info outlives tensor)
405        's: 'm, // 's outlives 'm (session outlives memory info)
406    {
407        self.validate_input_shapes(&input_arrays)?;
408
409        // Build arguments to Run()
410
411        let input_names_ptr: Vec<*const i8> = self
412            .inputs
413            .iter()
414            .map(|input| input.name.clone())
415            .map(|n| CString::new(n).unwrap())
416            .map(|n| n.into_raw() as *const i8)
417            .collect();
418
419        let output_names_cstring: Vec<CString> = self
420            .outputs
421            .iter()
422            .map(|output| output.name.clone())
423            .map(|n| CString::new(n).unwrap())
424            .collect();
425        let output_names_ptr: Vec<*const i8> = output_names_cstring
426            .iter()
427            .map(|n| n.as_ptr() as *const i8)
428            .collect();
429
430        let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
431            vec![std::ptr::null_mut(); self.outputs.len()];
432
433        // The C API expects pointers for the arrays (pointers to C-arrays)
434        let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
435            .into_iter()
436            .map(|input_array| {
437                OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
438            })
439            .collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
440        let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
441            .iter()
442            .map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue)
443            .collect();
444
445        let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
446
447        let status = unsafe {
448            g_ort().Run.unwrap()(
449                self.session_ptr,
450                run_options_ptr,
451                input_names_ptr.as_ptr(),
452                input_ort_values.as_ptr(),
453                input_ort_values.len(),
454                output_names_ptr.as_ptr(),
455                output_names_ptr.len(),
456                output_tensor_extractors_ptrs.as_mut_ptr(),
457            )
458        };
459        status_to_result(status).map_err(OrtError::Run)?;
460
461        let memory_info_ref = &self.memory_info;
462        let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
463            output_tensor_extractors_ptrs
464                .into_iter()
465                .map(|ptr| {
466                    let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
467                        std::ptr::null_mut();
468                    let status = unsafe {
469                        g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
470                    };
471                    status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
472                    let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
473                    unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
474                    let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();
475
476                    let mut output_tensor_extractor =
477                        OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
478                    output_tensor_extractor.tensor_ptr = ptr;
479                    output_tensor_extractor.extract::<TOut>()
480                })
481                .collect();
482
483        // Reconvert to CString so drop impl is called and memory is freed
484        let cstrings: Result<Vec<CString>> = input_names_ptr
485            .into_iter()
486            .map(|p| {
487                assert_not_null_pointer(p, "i8 for CString")?;
488                unsafe { Ok(CString::from_raw(p as *mut i8)) }
489            })
490            .collect();
491        cstrings?;
492
493        outputs
494    }
495
496    // pub fn tensor_from_array<'a, 'b, T, D>(&'a self, array: Array<T, D>) -> Tensor<'b, T, D>
497    // where
498    //     'a: 'b, // 'a outlives 'b
499    // {
500    //     Tensor::from_array(self, array)
501    // }
502
503    fn validate_input_shapes<TIn, D>(&mut self, input_arrays: &[Array<TIn, D>]) -> Result<()>
504    where
505        TIn: TypeToTensorElementDataType + Debug + Clone,
506        D: ndarray::Dimension,
507    {
508        // ******************************************************************
509        // FIXME: Properly handle errors here
510        // Make sure all dimensions match (except dynamic ones)
511
512        // Verify length of inputs
513        if input_arrays.len() != self.inputs.len() {
514            error!(
515                "Non-matching number of inputs: {} (inference) vs {} (model)",
516                input_arrays.len(),
517                self.inputs.len()
518            );
519            return Err(OrtError::NonMatchingDimensions(
520                NonMatchingDimensionsError::InputsCount {
521                    inference_input_count: 0,
522                    model_input_count: 0,
523                    inference_input: input_arrays
524                        .iter()
525                        .map(|input_array| input_array.shape().to_vec())
526                        .collect(),
527                    model_input: self
528                        .inputs
529                        .iter()
530                        .map(|input| input.dimensions.clone())
531                        .collect(),
532                },
533            ));
534        }
535
536        // Verify length of each individual inputs
537        let inputs_different_length = input_arrays
538            .iter()
539            .zip(self.inputs.iter())
540            .any(|(l, r)| l.shape().len() != r.dimensions.len());
541        if inputs_different_length {
542            error!(
543                "Different input lengths: {:?} vs {:?}",
544                self.inputs, input_arrays
545            );
546            return Err(OrtError::NonMatchingDimensions(
547                NonMatchingDimensionsError::InputsLength {
548                    inference_input: input_arrays
549                        .iter()
550                        .map(|input_array| input_array.shape().to_vec())
551                        .collect(),
552                    model_input: self
553                        .inputs
554                        .iter()
555                        .map(|input| input.dimensions.clone())
556                        .collect(),
557                },
558            ));
559        }
560
561        // Verify shape of each individual inputs
562        let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
563            let l_shape = l.shape();
564            let r_shape = r.dimensions.as_slice();
565            l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
566                Some(r3) => *r3 as usize != *l2,
567                None => false, // None means dynamic size; in that case shape always match
568            })
569        });
570        if inputs_different_shape {
571            error!(
572                "Different input lengths: {:?} vs {:?}",
573                self.inputs, input_arrays
574            );
575            return Err(OrtError::NonMatchingDimensions(
576                NonMatchingDimensionsError::InputsLength {
577                    inference_input: input_arrays
578                        .iter()
579                        .map(|input_array| input_array.shape().to_vec())
580                        .collect(),
581                    model_input: self
582                        .inputs
583                        .iter()
584                        .map(|input| input.dimensions.clone())
585                        .collect(),
586                },
587            ));
588        }
589
590        Ok(())
591    }
592}
593
594unsafe fn get_tensor_dimensions(
595    tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
596) -> Result<Vec<i64>> {
597    let mut num_dims = 0;
598    let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
599    status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
600    (num_dims != 0)
601        .then(|| ())
602        .ok_or(OrtError::InvalidDimensions)?;
603
604    let mut node_dims: Vec<i64> = vec![0; num_dims];
605    let status = g_ort().GetDimensions.unwrap()(
606        tensor_info_ptr,
607        node_dims.as_mut_ptr(), // FIXME: UB?
608        num_dims,
609    );
610    status_to_result(status).map_err(OrtError::GetDimensions)?;
611    Ok(node_dims)
612}
613
614/// This module contains dangerous functions working on raw pointers.
615/// Those functions are only to be used from inside the
616/// `SessionBuilder::with_model_from_file()` method.
617mod dangerous {
618    use super::*;
619
620    pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<usize> {
621        let f = g_ort().SessionGetInputCount.unwrap();
622        extract_io_count(f, session_ptr)
623    }
624
625    pub(super) fn extract_outputs_count(session_ptr: *mut sys::OrtSession) -> Result<usize> {
626        let f = g_ort().SessionGetOutputCount.unwrap();
627        extract_io_count(f, session_ptr)
628    }
629
630    fn extract_io_count(
631        f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut usize) -> *mut sys::OrtStatus },
632        session_ptr: *mut sys::OrtSession,
633    ) -> Result<usize> {
634        let mut num_nodes: usize = 0;
635        let status = unsafe { f(session_ptr, &mut num_nodes) };
636        status_to_result(status).map_err(OrtError::InOutCount)?;
637        assert_null_pointer(status, "SessionStatus")?;
638        (num_nodes != 0).then(|| ()).ok_or_else(|| {
639            OrtError::InOutCount(OrtApiError::Msg("No nodes in model".to_owned()))
640        })?;
641        Ok(num_nodes)
642    }
643
644    fn extract_input_name(
645        session_ptr: *mut sys::OrtSession,
646        allocator_ptr: *mut sys::OrtAllocator,
647        i: usize,
648    ) -> Result<String> {
649        let f = g_ort().SessionGetInputName.unwrap();
650        extract_io_name(f, session_ptr, allocator_ptr, i)
651    }
652
653    fn extract_output_name(
654        session_ptr: *mut sys::OrtSession,
655        allocator_ptr: *mut sys::OrtAllocator,
656        i: usize,
657    ) -> Result<String> {
658        let f = g_ort().SessionGetOutputName.unwrap();
659        extract_io_name(f, session_ptr, allocator_ptr, i)
660    }
661
662    fn extract_io_name(
663        f: extern_system_fn! { unsafe fn(
664            *const sys::OrtSession,
665            usize,
666            *mut sys::OrtAllocator,
667            *mut *mut i8,
668        ) -> *mut sys::OrtStatus },
669        session_ptr: *mut sys::OrtSession,
670        allocator_ptr: *mut sys::OrtAllocator,
671        i: usize,
672    ) -> Result<String> {
673        let mut name_bytes: *mut i8 = std::ptr::null_mut();
674
675        let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
676        status_to_result(status).map_err(OrtError::InputName)?;
677        assert_not_null_pointer(name_bytes, "InputName")?;
678
679        // FIXME: Is it safe to keep ownership of the memory?
680        let name = char_p_to_string(name_bytes)?;
681
682        Ok(name)
683    }
684
685    pub(super) fn extract_input(
686        session_ptr: *mut sys::OrtSession,
687        allocator_ptr: *mut sys::OrtAllocator,
688        i: usize,
689    ) -> Result<Input> {
690        let input_name = extract_input_name(session_ptr, allocator_ptr, i)?;
691        let f = g_ort().SessionGetInputTypeInfo.unwrap();
692        let (input_type, dimensions) = extract_io(f, session_ptr, i)?;
693        Ok(Input {
694            name: input_name,
695            input_type,
696            dimensions,
697        })
698    }
699
700    pub(super) fn extract_output(
701        session_ptr: *mut sys::OrtSession,
702        allocator_ptr: *mut sys::OrtAllocator,
703        i: usize,
704    ) -> Result<Output> {
705        let output_name = extract_output_name(session_ptr, allocator_ptr, i)?;
706        let f = g_ort().SessionGetOutputTypeInfo.unwrap();
707        let (output_type, dimensions) = extract_io(f, session_ptr, i)?;
708        Ok(Output {
709            name: output_name,
710            output_type,
711            dimensions,
712        })
713    }
714
715    fn extract_io(
716        f: extern_system_fn! { unsafe fn(
717            *const sys::OrtSession,
718            usize,
719            *mut *mut sys::OrtTypeInfo,
720        ) -> *mut sys::OrtStatus },
721        session_ptr: *mut sys::OrtSession,
722        i: usize,
723    ) -> Result<(TensorElementDataType, Vec<Option<u32>>)> {
724        let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
725
726        let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
727        status_to_result(status).map_err(OrtError::GetTypeInfo)?;
728        assert_not_null_pointer(typeinfo_ptr, "TypeInfo")?;
729
730        let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
731        let status = unsafe {
732            g_ort().CastTypeInfoToTensorInfo.unwrap()(typeinfo_ptr, &mut tensor_info_ptr)
733        };
734        status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
735        assert_not_null_pointer(tensor_info_ptr, "TensorInfo")?;
736
737        let mut type_sys = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
738        let status =
739            unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
740        status_to_result(status).map_err(OrtError::TensorElementType)?;
741        (type_sys != sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
742            .then(|| ())
743            .ok_or(OrtError::UndefinedTensorElementType)?;
744        // This transmute should be safe since its value is read from GetTensorElementType which we must trust.
745        let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };
746
747        // info!("{} : type={}", i, type_);
748
749        let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
750
751        // for j in 0..num_dims {
752        //     info!("{} : dim {}={}", i, j, node_dims[j as usize]);
753        // }
754
755        unsafe { g_ort().ReleaseTypeInfo.unwrap()(typeinfo_ptr) };
756
757        Ok((
758            io_type,
759            node_dims
760                .into_iter()
761                .map(|d| if d == -1 { None } else { Some(d as u32) })
762                .collect(),
763        ))
764    }
765}