// mcai_onnxruntime/lib.rs

1#![warn(missing_docs)]
2
3//! ONNX Runtime
4//!
5//! This crate is a (safe) wrapper around Microsoft's [ONNX Runtime](https://github.com/microsoft/onnxruntime/)
6//! through its C API.
7//!
8//! From its [GitHub page](https://github.com/microsoft/onnxruntime/):
9//!
10//! > ONNX Runtime is a cross-platform, high performance ML inferencing and training accelerator.
11//!
12//! The (highly) unsafe [C API](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h)
13//! is wrapped using bindgen as [`onnxruntime-sys`](https://crates.io/crates/onnxruntime-sys).
14//!
15//! The unsafe bindings are wrapped in this crate to expose a safe API.
16//!
17//! For now, efforts are concentrated on the inference API. Training is _not_ supported.
18//!
19//! # Example
20//!
21//! The C++ example that uses the C API
22//! ([`C_Api_Sample.cpp`](https://github.com/microsoft/onnxruntime/blob/v1.3.1/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp))
23//! was ported to
24//! [`onnxruntime`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs).
25//!
//! First, an environment must be created using an [`EnvBuilder`](environment/struct.EnvBuilder.html):
27//!
28//! ```no_run
29//! # use std::error::Error;
30//! # use onnxruntime::{environment::Environment, LoggingLevel};
31//! # fn main() -> Result<(), Box<dyn Error>> {
32//! let environment = Environment::builder()
33//!     .with_name("test")
34//!     .with_log_level(LoggingLevel::Verbose)
35//!     .build()?;
36//! # Ok(())
37//! # }
38//! ```
39//!
40//! Then a [`Session`](session/struct.Session.html) is created from the environment, some options and an ONNX archive:
41//!
42//! ```no_run
43//! # use std::error::Error;
44//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel};
45//! # fn main() -> Result<(), Box<dyn Error>> {
46//! # let environment = Environment::builder()
47//! #     .with_name("test")
48//! #     .with_log_level(LoggingLevel::Verbose)
49//! #     .build()?;
50//! let mut session = environment
51//!     .new_session_builder()?
52//!     .with_optimization_level(GraphOptimizationLevel::Basic)?
53//!     .with_number_threads(1)?
54//!     .with_model_from_file("squeezenet.onnx")?;
55//! # Ok(())
56//! # }
57//! ```
58//!
59#![cfg_attr(
60    feature = "model-fetching",
61    doc = r##"
62Instead of loading a model from file using [`with_model_from_file()`](session/struct.SessionBuilder.html#method.with_model_from_file),
63a model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using
64[`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method
65(requires the `model-fetching` feature).
66
67```no_run
68# use std::error::Error;
69# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel};
70# fn main() -> Result<(), Box<dyn Error>> {
71# let environment = Environment::builder()
72#     .with_name("test")
73#     .with_log_level(LoggingLevel::Verbose)
74#     .build()?;
75let mut session = environment
76    .new_session_builder()?
77    .with_optimization_level(GraphOptimizationLevel::Basic)?
78    .with_number_threads(1)?
79    .with_model_downloaded(ImageClassification::SqueezeNet)?;
80# Ok(())
81# }
82```
83
84See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available
85to download.
86"##
87)]
88//!
89//! Inference will be run on data passed as an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
90//!
91//! ```no_run
92//! # use std::error::Error;
93//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::{OrtOwnedTensor, InputTensor, FromArray}};
94//! # fn main() -> Result<(), Box<dyn Error>> {
95//! # let environment = Environment::builder()
96//! #     .with_name("test")
97//! #     .with_log_level(LoggingLevel::Verbose)
98//! #     .build()?;
99//! # let mut session = environment
100//! #     .new_session_builder()?
101//! #     .with_optimization_level(GraphOptimizationLevel::Basic)?
102//! #     .with_number_threads(1)?
103//! #     .with_model_from_file("squeezenet.onnx")?;
104//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
105//! // Multiple inputs and outputs are possible
106//! let input_tensor = vec![InputTensor::from_array(array)];
107//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
108//! # Ok(())
109//! # }
110//! ```
111//!
112//! The outputs are of type [`OrtOwnedTensor`](tensor/ort_owned_tensor/struct.OrtOwnedTensor.html)s inside a vector,
113//! with the same length as the inputs.
114//!
115//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs)
116//! example for more details.
117
118use std::sync::{atomic::AtomicPtr, Arc, Mutex};
119
120use lazy_static::lazy_static;
121
122use mcai_onnxruntime_sys as sys;
123
// Make functions `extern "stdcall"` for Windows 32bit.
// This behaves like `extern "system"`: on win32/x86 the system ABI is stdcall.
// Each arm mirrors one allowed shape: plain fn, `vis fn`, `unsafe fn`, `vis unsafe fn`.
macro_rules! extern_system_fn {
    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
}
133
// Make functions `extern "C"` for normal targets.
// This behaves like `extern "system"`: everywhere except win32/x86 the system ABI is cdecl.
// Each arm mirrors one allowed shape: plain fn, `vis fn`, `unsafe fn`, `vis unsafe fn`.
macro_rules! extern_system_fn {
    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
143
pub mod download;
pub mod environment;
pub mod error;
mod memory;
pub mod session;
pub mod tensor;

// Re-export the error types so callers can use them directly from the crate root.
pub use error::{OrtApiError, OrtError, Result};
use sys::OnnxEnumInt;

// Re-export ndarray as it's part of the public API anyway
pub use ndarray;
157
158lazy_static! {
159    // static ref G_ORT: Arc<Mutex<AtomicPtr<sys::OrtApi>>> =
160    //     Arc::new(Mutex::new(AtomicPtr::new(unsafe {
161    //         sys::OrtGetApiBase().as_ref().unwrap().GetApi.unwrap()(sys::ORT_API_VERSION)
162    //     } as *mut sys::OrtApi)));
163    static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
164        let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
165        assert_ne!(base, std::ptr::null());
166        let get_api: extern_system_fn!{ unsafe fn(u32) -> *const mcai_onnxruntime_sys::OrtApi } =
167            unsafe { (*base).GetApi.unwrap() };
168        let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
169        Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
170    };
171}
172
173fn g_ort() -> sys::OrtApi {
174    let mut api_ref = G_ORT_API
175        .lock()
176        .expect("Failed to acquire lock: another thread panicked?");
177    let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
178    let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;
179
180    assert_ne!(api_ptr_mut, std::ptr::null_mut());
181
182    unsafe { *api_ptr_mut }
183}
184
185fn char_p_to_string(raw: *const i8) -> Result<String> {
186    let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
187
188    match c_string.into_string() {
189        Ok(string) => Ok(string),
190        Err(e) => Err(OrtApiError::IntoStringError(e)),
191    }
192    .map_err(OrtError::StringConversion)
193}
194
195mod onnxruntime {
196    //! Module containing a custom logger, used to catch the runtime's own logging and send it
197    //! to Rust's tracing logging instead.
198
199    use std::ffi::CStr;
200    use tracing::{debug, error, info, span, trace, warn, Level};
201
202    use mcai_onnxruntime_sys as sys;
203
204    /// Runtime's logging sends the code location where the log happened, will be parsed to this struct.
205    #[derive(Debug)]
206    struct CodeLocation<'a> {
207        file: &'a str,
208        line_number: &'a str,
209        function: &'a str,
210    }
211
212    impl<'a> From<&'a str> for CodeLocation<'a> {
213        fn from(code_location: &'a str) -> Self {
214            let mut splitter = code_location.split(' ');
215            let file_and_line_number = splitter.next().unwrap_or("<unknown file:line>");
216            let function = splitter.next().unwrap_or("<unknown module>");
217            let mut file_and_line_number_splitter = file_and_line_number.split(':');
218            let file = file_and_line_number_splitter
219                .next()
220                .unwrap_or("<unknown file>");
221            let line_number = file_and_line_number_splitter
222                .next()
223                .unwrap_or("<unknown line number>");
224
225            CodeLocation {
226                file,
227                line_number,
228                function,
229            }
230        }
231    }
232
233    extern_system_fn! {
234        /// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate.
235        pub(crate) fn custom_logger(
236            _params: *mut std::ffi::c_void,
237            severity: sys::OrtLoggingLevel,
238            category: *const i8,
239            logid: *const i8,
240            code_location: *const i8,
241            message: *const i8,
242        ) {
243            let log_level = match severity {
244                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
245                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
246                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => Level::INFO,
247                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => Level::WARN,
248                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
249            };
250
251            assert_ne!(category, std::ptr::null());
252            let category = unsafe { CStr::from_ptr(category) };
253            assert_ne!(code_location, std::ptr::null());
254            let code_location = unsafe { CStr::from_ptr(code_location) }
255                .to_str()
256                .unwrap_or("unknown");
257            assert_ne!(message, std::ptr::null());
258            let message = unsafe { CStr::from_ptr(message) };
259
260            assert_ne!(logid, std::ptr::null());
261            let logid = unsafe { CStr::from_ptr(logid) };
262
263            // Parse the code location
264            let code_location: CodeLocation = code_location.into();
265
266            let span = span!(
267                Level::TRACE,
268                "onnxruntime",
269                category = category.to_str().unwrap_or("<unknown>"),
270                file = code_location.file,
271                line_number = code_location.line_number,
272                function = code_location.function,
273                logid = logid.to_str().unwrap_or("<unknown>"),
274            );
275            let _enter = span.enter();
276
277            match log_level {
278                Level::TRACE => trace!("{:?}", message),
279                Level::DEBUG => debug!("{:?}", message),
280                Level::INFO => info!("{:?}", message),
281                Level::WARN => warn!("{:?}", message),
282                Level::ERROR => error!("{:?}", message),
283            }
284        }
285    }
286}
287
288/// Logging level of the ONNX Runtime C API
289#[derive(Debug)]
290#[cfg_attr(not(windows), repr(u32))]
291#[cfg_attr(windows, repr(i32))]
292pub enum LoggingLevel {
293    /// Verbose log level
294    Verbose = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
295    /// Info log level
296    Info = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
297    /// Warning log level
298    Warning = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
299    /// Error log level
300    Error = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
301    /// Fatal log level
302    Fatal = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt,
303}
304
305impl From<LoggingLevel> for sys::OrtLoggingLevel {
306    fn from(val: LoggingLevel) -> Self {
307        match val {
308            LoggingLevel::Verbose => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
309            LoggingLevel::Info => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
310            LoggingLevel::Warning => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
311            LoggingLevel::Error => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
312            LoggingLevel::Fatal => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL,
313        }
314    }
315}
316
317/// Optimization level performed by ONNX Runtime of the loaded graph
318///
319/// See the [official documentation](https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md)
320/// for more information on the different optimization levels.
321#[derive(Debug)]
322#[cfg_attr(not(windows), repr(u32))]
323#[cfg_attr(windows, repr(i32))]
324pub enum GraphOptimizationLevel {
325    /// Disable optimization
326    DisableAll = sys::GraphOptimizationLevel::ORT_DISABLE_ALL as OnnxEnumInt,
327    /// Basic optimization
328    Basic = sys::GraphOptimizationLevel::ORT_ENABLE_BASIC as OnnxEnumInt,
329    /// Extended optimization
330    Extended = sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED as OnnxEnumInt,
331    /// Add optimization
332    All = sys::GraphOptimizationLevel::ORT_ENABLE_ALL as OnnxEnumInt,
333}
334
335impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
336    fn from(val: GraphOptimizationLevel) -> Self {
337        use GraphOptimizationLevel::*;
338        match val {
339            DisableAll => sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
340            Basic => sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
341            Extended => sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
342            All => sys::GraphOptimizationLevel::ORT_ENABLE_ALL,
343        }
344    }
345}
346
347// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
348// FIXME: Add tests to cover the commented out types
349/// Enum mapping ONNX Runtime's supported tensor types
350#[derive(Debug)]
351#[cfg_attr(not(windows), repr(u32))]
352#[cfg_attr(windows, repr(i32))]
353pub enum TensorElementDataType {
354    /// 32-bit floating point, equivalent to Rust's `f32`
355    Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
356    /// Unsigned 8-bit int, equivalent to Rust's `u8`
357    Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
358    /// Signed 8-bit int, equivalent to Rust's `i8`
359    Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
360    /// Unsigned 16-bit int, equivalent to Rust's `u16`
361    Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
362    /// Signed 16-bit int, equivalent to Rust's `i16`
363    Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
364    /// Signed 32-bit int, equivalent to Rust's `i32`
365    Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
366    /// Signed 64-bit int, equivalent to Rust's `i64`
367    Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
368    /// String, equivalent to Rust's `String`
369    String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
370    // /// Boolean, equivalent to Rust's `bool`
371    // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
372    // /// 16-bit floating point, equivalent to Rust's `f16`
373    // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
374    /// 64-bit floating point, equivalent to Rust's `f64`
375    Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
376    /// Unsigned 32-bit int, equivalent to Rust's `u32`
377    Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
378    /// Unsigned 64-bit int, equivalent to Rust's `u64`
379    Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
380    // /// Complex 64-bit floating point, equivalent to Rust's `???`
381    // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
382    // /// Complex 128-bit floating point, equivalent to Rust's `???`
383    // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
384    // /// Brain 16-bit floating point
385    // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
386}
387
388impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
389    fn from(val: TensorElementDataType) -> Self {
390        use TensorElementDataType::*;
391        match val {
392            Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
393            Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
394            Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
395            Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
396            Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
397            Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
398            Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
399            String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
400            // Bool => {
401            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
402            // }
403            // Float16 => {
404            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
405            // }
406            Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
407            Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
408            Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
409            // Complex64 => {
410            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
411            // }
412            // Complex128 => {
413            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
414            // }
415            // Bfloat16 => {
416            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
417            // }
418        }
419    }
420}
421
/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
///
/// Implemented for the primitive numeric types via the `impl_type_trait!` macro
/// below, and for string-like types through the blanket impl over [`Utf8Data`].
pub trait TypeToTensorElementDataType {
    /// Return the ONNX type for a Rust type
    fn tensor_element_data_type() -> TensorElementDataType;

    /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
    fn try_utf8_bytes(&self) -> Option<&[u8]>;
}
430
/// Implement [`TypeToTensorElementDataType`] for a primitive type, mapping it
/// to the given [`TensorElementDataType`] variant.
macro_rules! impl_type_trait {
    ($type_:ty, $variant:ident) => {
        impl TypeToTensorElementDataType for $type_ {
            fn tensor_element_data_type() -> TensorElementDataType {
                TensorElementDataType::$variant
            }

            fn try_utf8_bytes(&self) -> Option<&[u8]> {
                // Primitive numeric types never carry string data.
                None
            }
        }
    };
}
445
// Map each supported primitive Rust type onto its ONNX tensor element type.
// Commented-out lines track variants not yet exposed by `TensorElementDataType`.
impl_type_trait!(f32, Float);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
// impl_type_trait!(bool, Bool);
// impl_type_trait!(f16, Float16);
impl_type_trait!(f64, Double);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
// impl_type_trait!(, Complex64);
// impl_type_trait!(, Complex128);
// impl_type_trait!(, Bfloat16);
461
/// Adapter for common Rust string types to Onnx strings.
///
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
/// types (which might implement `AsRef<str>` at some point in the future).
pub trait Utf8Data {
    /// Returns the utf8 contents.
    ///
    /// For the provided `String`/`&str` impls the slice is always valid UTF-8.
    fn utf8_bytes(&self) -> &[u8];
}
472
473impl Utf8Data for String {
474    fn utf8_bytes(&self) -> &[u8] {
475        self.as_bytes()
476    }
477}
478
479impl<'a> Utf8Data for &'a str {
480    fn utf8_bytes(&self) -> &[u8] {
481        self.as_bytes()
482    }
483}
484
485impl<T: Utf8Data> TypeToTensorElementDataType for T {
486    fn tensor_element_data_type() -> TensorElementDataType {
487        TensorElementDataType::String
488    }
489
490    fn try_utf8_bytes(&self) -> Option<&[u8]> {
491        Some(self.utf8_bytes())
492    }
493}
494
/// Allocator type
///
/// Mirrors the C API's `OrtAllocatorType`; the C enum's `Invalid` value is
/// deliberately not exposed by the safe wrapper.
#[derive(Debug, Clone)]
#[repr(i32)]
pub enum AllocatorType {
    // Invalid = sys::OrtAllocatorType::Invalid as i32,
    /// Device allocator
    Device = sys::OrtAllocatorType::OrtDeviceAllocator as i32,
    /// Arena allocator
    Arena = sys::OrtAllocatorType::OrtArenaAllocator as i32,
}
505
506impl From<AllocatorType> for sys::OrtAllocatorType {
507    fn from(val: AllocatorType) -> Self {
508        use AllocatorType::*;
509        match val {
510            // Invalid => sys::OrtAllocatorType::Invalid,
511            Device => sys::OrtAllocatorType::OrtDeviceAllocator,
512            Arena => sys::OrtAllocatorType::OrtArenaAllocator,
513        }
514    }
515}
516
/// Memory type
///
/// Only support ONNX's default type for now.
#[derive(Debug, Clone)]
#[repr(i32)]
pub enum MemType {
    // FIXME: C API's `OrtMemType_OrtMemTypeCPU` defines it equal to `OrtMemType_OrtMemTypeCPUOutput`. How to handle this??
    // (Rust enums cannot have two variants with the same discriminant, so the
    // aliased CPU variants are left out until a mapping strategy is decided.)
    // CPUInput = sys::OrtMemType::OrtMemTypeCPUInput as i32,
    // CPUOutput = sys::OrtMemType::OrtMemTypeCPUOutput as i32,
    // CPU = sys::OrtMemType::OrtMemTypeCPU as i32,
    /// Default memory type
    Default = sys::OrtMemType::OrtMemTypeDefault as i32,
}
530
531impl From<MemType> for sys::OrtMemType {
532    fn from(val: MemType) -> Self {
533        use MemType::*;
534        match val {
535            // CPUInput => sys::OrtMemType::OrtMemTypeCPUInput,
536            // CPUOutput => sys::OrtMemType::OrtMemTypeCPUOutput,
537            // CPU => sys::OrtMemType::OrtMemTypeCPU,
538            Default => sys::OrtMemType::OrtMemTypeDefault,
539        }
540    }
541}
542
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_char_p_to_string() {
        // Round-trip a plain ASCII string through the C-string helper.
        let s = std::ffi::CString::new("foo").unwrap();
        let ptr = s.as_c_str().as_ptr();
        assert_eq!("foo", char_p_to_string(ptr).unwrap());
    }

    #[test]
    fn test_char_p_to_string_invalid_utf8() {
        // Bytes that form a valid C string (no interior NUL) but invalid UTF-8
        // must surface as a conversion error rather than a panic.
        let s = std::ffi::CString::new(vec![0xf0u8, 0x28, 0x8c, 0x28]).unwrap();
        assert!(char_p_to_string(s.as_c_str().as_ptr()).is_err());
    }
}