onnxruntime_ng/
lib.rs

#![warn(missing_docs)]

//! ONNX Runtime
//!
//! This crate is a (safe) wrapper around Microsoft's [ONNX Runtime](https://github.com/microsoft/onnxruntime/)
//! through its C API.
//!
//! From its [GitHub page](https://github.com/microsoft/onnxruntime/):
//!
//! > ONNX Runtime is a cross-platform, high performance ML inferencing and training accelerator.
//!
//! The (highly) unsafe [C API](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/onnxruntime_c_api.h)
//! is wrapped using bindgen as [`onnxruntime-sys-ng`](https://crates.io/crates/onnxruntime-sys).
//!
//! The unsafe bindings are wrapped in this crate to expose a safe API.
//!
//! For now, efforts are concentrated on the inference API. Training is _not_ supported.
//!
//! # Example
//!
//! The C++ example that uses the C API
//! ([`C_Api_Sample.cpp`](https://github.com/microsoft/onnxruntime/blob/v1.3.1/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp))
//! was ported to
//! [`onnxruntime`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs).
//!
//! First, an environment must be created using an [`EnvBuilder`](environment/struct.EnvBuilder.html):
//!
//! ```no_run
//! # use std::error::Error;
//! # use onnxruntime::{environment::Environment, LoggingLevel};
//! # fn main() -> Result<(), Box<dyn Error>> {
//! let environment = Environment::builder()
//!     .with_name("test")
//!     .with_log_level(LoggingLevel::Verbose)
//!     .build()?;
//! # Ok(())
//! # }
//! ```
//!
//! Then a [`Session`](session/struct.Session.html) is created from the environment, some options and an ONNX archive:
//!
//! ```no_run
//! # use std::error::Error;
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel};
//! # fn main() -> Result<(), Box<dyn Error>> {
//! # let environment = Environment::builder()
//! #     .with_name("test")
//! #     .with_log_level(LoggingLevel::Verbose)
//! #     .build()?;
//! let mut session = environment
//!     .new_session_builder()?
//!     .with_optimization_level(GraphOptimizationLevel::Basic)?
//!     .with_number_threads(1)?
//!     .with_model_from_file("squeezenet.onnx")?;
//! # Ok(())
//! # }
//! ```
//!
#![cfg_attr(
    feature = "model-fetching",
    doc = r##"
Instead of loading a model from file using [`with_model_from_file()`](session/struct.SessionBuilder.html#method.with_model_from_file),
a model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using the
[`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method
(requires the `model-fetching` feature).

```no_run
# use std::error::Error;
# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel};
# fn main() -> Result<(), Box<dyn Error>> {
# let environment = Environment::builder()
#     .with_name("test")
#     .with_log_level(LoggingLevel::Verbose)
#     .build()?;
let mut session = environment
    .new_session_builder()?
    .with_optimization_level(GraphOptimizationLevel::Basic)?
    .with_number_threads(1)?
    .with_model_downloaded(ImageClassification::SqueezeNet)?;
# Ok(())
# }
```

See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available
to download.
"##
)]
//!
//! Inference will be run on data passed as an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
//!
//! ```no_run
//! # use std::error::Error;
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::OrtOwnedTensor};
//! # fn main() -> Result<(), Box<dyn Error>> {
//! # let environment = Environment::builder()
//! #     .with_name("test")
//! #     .with_log_level(LoggingLevel::Verbose)
//! #     .build()?;
//! # let mut session = environment
//! #     .new_session_builder()?
//! #     .with_optimization_level(GraphOptimizationLevel::Basic)?
//! #     .with_number_threads(1)?
//! #     .with_model_from_file("squeezenet.onnx")?;
//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! // Multiple inputs and outputs are possible
//! let input_tensor = vec![array];
//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
//! # Ok(())
//! # }
//! ```
//!
//! The outputs are [`OrtOwnedTensor`](tensor/ort_owned_tensor/struct.OrtOwnedTensor.html)s returned
//! in a vector, with the same length as the inputs.
//!
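//! Each output can then be read through its [`ndarray`] view (see the
//! [`tensor`](tensor/index.html) module). A minimal sketch, assuming a model with a single
//! `f32` output as in the SqueezeNet example above:
//!
//! ```no_run
//! # use std::error::Error;
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::OrtOwnedTensor};
//! # fn main() -> Result<(), Box<dyn Error>> {
//! # let environment = Environment::builder()
//! #     .with_name("test")
//! #     .with_log_level(LoggingLevel::Verbose)
//! #     .build()?;
//! # let mut session = environment
//! #     .new_session_builder()?
//! #     .with_optimization_level(GraphOptimizationLevel::Basic)?
//! #     .with_number_threads(1)?
//! #     .with_model_from_file("squeezenet.onnx")?;
//! # let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! # let input_tensor = vec![array];
//! let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor)?;
//! // `OrtOwnedTensor` can be used as an `ndarray` view, so the values can be
//! // inspected like any other array:
//! let sum: f32 = outputs[0].iter().sum();
//! println!("sum of the first output tensor: {}", sum);
//! # Ok(())
//! # }
//! ```
//!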
//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs)
//! example for more details.

use std::sync::{atomic::AtomicPtr, Arc, Mutex};

use lazy_static::lazy_static;

use onnxruntime_sys_ng as sys;

// Make functions `extern "stdcall"` for Windows 32bit.
// This behaves like `extern "system"`.
#[cfg(all(target_os = "windows", target_arch = "x86"))]
macro_rules! extern_system_fn {
    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
}

// Make functions `extern "C"` for normal targets.
// This behaves like `extern "system"`.
#[cfg(not(all(target_os = "windows", target_arch = "x86")))]
macro_rules! extern_system_fn {
    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
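
// For illustration (a hypothetical invocation, not code from this crate): an item such as
// `extern_system_fn!{ pub(crate) fn callback(msg: *const i8) { /* ... */ } }`
// expands to `pub(crate) extern "C" fn callback(msg: *const i8) { /* ... */ }` on most targets,
// and to an `extern "stdcall"` function on 32-bit Windows. See `custom_logger` below for a real use.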

pub mod download;
pub mod environment;
pub mod error;
mod memory;
pub mod session;
pub mod tensor;

// Re-export
pub use error::{OrtApiError, OrtError, Result};
use sys::OnnxEnumInt;

// Re-export ndarray as it's part of the public API anyway
pub use ndarray;

lazy_static! {
    // static ref G_ORT: Arc<Mutex<AtomicPtr<sys::OrtApi>>> =
    //     Arc::new(Mutex::new(AtomicPtr::new(unsafe {
    //         sys::OrtGetApiBase().as_ref().unwrap().GetApi.unwrap()(sys::ORT_API_VERSION)
    //     } as *mut sys::OrtApi)));
    static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
        let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
        assert_ne!(base, std::ptr::null());
        let get_api: extern_system_fn!{ unsafe fn(u32) -> *const onnxruntime_sys_ng::OrtApi } =
            unsafe { (*base).GetApi.unwrap() };
        let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
        Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
    };
}

/// Return a copy of the ONNX Runtime API table stored in the global `G_ORT_API`.
fn g_ort() -> sys::OrtApi {
    let mut api_ref = G_ORT_API
        .lock()
        .expect("Failed to acquire lock: another thread panicked?");
    let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
    let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;

    assert_ne!(api_ptr_mut, std::ptr::null_mut());

    unsafe { *api_ptr_mut }
}

/// Convert a C string pointer returned by the ONNX Runtime API into an owned Rust `String`.
fn char_p_to_string(raw: *const i8) -> Result<String> {
    let c_string = unsafe { std::ffi::CStr::from_ptr(raw).to_owned() };

    match c_string.into_string() {
        Ok(string) => Ok(string),
        Err(e) => Err(OrtApiError::IntoStringError(e)),
    }
    .map_err(OrtError::StringConversion)
}

mod onnxruntime {
    //! Module containing a custom logger, used to catch the runtime's own logging and send it
    //! to Rust's tracing logging instead.

    use std::ffi::CStr;
    use tracing::{debug, error, info, span, trace, warn, Level};

    use onnxruntime_sys_ng as sys;

    /// The runtime's logging sends the code location where the log happened; it is parsed into this struct.
    #[derive(Debug)]
    struct CodeLocation<'a> {
        file: &'a str,
        line_number: &'a str,
        function: &'a str,
    }

    impl<'a> From<&'a str> for CodeLocation<'a> {
        fn from(code_location: &'a str) -> Self {
            let mut splitter = code_location.split(' ');
            let file_and_line_number = splitter.next().unwrap_or("<unknown file:line>");
            let function = splitter.next().unwrap_or("<unknown module>");
            let mut file_and_line_number_splitter = file_and_line_number.split(':');
            let file = file_and_line_number_splitter
                .next()
                .unwrap_or("<unknown file>");
            let line_number = file_and_line_number_splitter
                .next()
                .unwrap_or("<unknown line number>");

            CodeLocation {
                file,
                line_number,
                function,
            }
        }
    }
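
    // A small test sketching the expected parsing; the "file:line function" input string used
    // here is a hypothetical example of what the runtime sends in its log callback.
    #[cfg(test)]
    mod code_location_test {
        use super::CodeLocation;

        #[test]
        fn parses_file_line_number_and_function() {
            let loc = CodeLocation::from("squeezenet.cc:123 ComputeSoftmax");
            assert_eq!(loc.file, "squeezenet.cc");
            assert_eq!(loc.line_number, "123");
            assert_eq!(loc.function, "ComputeSoftmax");
        }
    }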

    extern_system_fn! {
        /// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate.
        pub(crate) fn custom_logger(
            _params: *mut std::ffi::c_void,
            severity: sys::OrtLoggingLevel,
            category: *const i8,
            logid: *const i8,
            code_location: *const i8,
            message: *const i8,
        ) {
            let log_level = match severity {
                sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
                sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
                sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING => Level::INFO,
                sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR => Level::WARN,
                sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
                _ => Level::ERROR,
            };

            assert_ne!(category, std::ptr::null());
            let category = unsafe { CStr::from_ptr(category) };
            assert_ne!(code_location, std::ptr::null());
            let code_location = unsafe { CStr::from_ptr(code_location) }
                .to_str()
                .unwrap_or("unknown");
            assert_ne!(message, std::ptr::null());
            let message = unsafe { CStr::from_ptr(message) };

            assert_ne!(logid, std::ptr::null());
            let logid = unsafe { CStr::from_ptr(logid) };

            // Parse the code location
            let code_location: CodeLocation = code_location.into();

            let span = span!(
                Level::TRACE,
                "onnxruntime",
                category = category.to_str().unwrap_or("<unknown>"),
                file = code_location.file,
                line_number = code_location.line_number,
                function = code_location.function,
                logid = logid.to_str().unwrap_or("<unknown>"),
            );
            let _enter = span.enter();

            match log_level {
                Level::TRACE => trace!("{:?}", message),
                Level::DEBUG => debug!("{:?}", message),
                Level::INFO => info!("{:?}", message),
                Level::WARN => warn!("{:?}", message),
                Level::ERROR => error!("{:?}", message),
            }
        }
    }
}

/// Logging level of the ONNX Runtime C API
#[derive(Debug)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum LoggingLevel {
    /// Verbose log level
    Verbose = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
    /// Info log level
    Info = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
    /// Warning log level
    Warning = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
    /// Error log level
    Error = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
    /// Fatal log level
    Fatal = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt,
}

impl From<LoggingLevel> for sys::OrtLoggingLevel {
    fn from(val: LoggingLevel) -> Self {
        match val {
            LoggingLevel::Verbose => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE,
            LoggingLevel::Info => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO,
            LoggingLevel::Warning => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING,
            LoggingLevel::Error => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR,
            LoggingLevel::Fatal => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL,
        }
    }
}

/// Optimization level applied by ONNX Runtime to the loaded graph
///
/// See the [official documentation](https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md)
/// for more information on the different optimization levels.
#[derive(Debug)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum GraphOptimizationLevel {
    /// Disable optimization
    DisableAll = sys::GraphOptimizationLevel_ORT_DISABLE_ALL as OnnxEnumInt,
    /// Basic optimization
    Basic = sys::GraphOptimizationLevel_ORT_ENABLE_BASIC as OnnxEnumInt,
    /// Extended optimization
    Extended = sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED as OnnxEnumInt,
    /// Enable all optimizations
    All = sys::GraphOptimizationLevel_ORT_ENABLE_ALL as OnnxEnumInt,
}

impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
    fn from(val: GraphOptimizationLevel) -> Self {
        use GraphOptimizationLevel::*;
        match val {
            DisableAll => sys::GraphOptimizationLevel_ORT_DISABLE_ALL,
            Basic => sys::GraphOptimizationLevel_ORT_ENABLE_BASIC,
            Extended => sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED,
            All => sys::GraphOptimizationLevel_ORT_ENABLE_ALL,
        }
    }
}

// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
// FIXME: Add tests to cover the commented out types
/// Enum mapping ONNX Runtime's supported tensor types
#[derive(Debug)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum TensorElementDataType {
    /// 32-bit floating point, equivalent to Rust's `f32`
    Float = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
    /// Unsigned 8-bit int, equivalent to Rust's `u8`
    Uint8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
    /// Signed 8-bit int, equivalent to Rust's `i8`
    Int8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
    /// Unsigned 16-bit int, equivalent to Rust's `u16`
    Uint16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
    /// Signed 16-bit int, equivalent to Rust's `i16`
    Int16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
    /// Signed 32-bit int, equivalent to Rust's `i32`
    Int32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
    /// Signed 64-bit int, equivalent to Rust's `i64`
    Int64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
    /// String, equivalent to Rust's `String`
    String = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
    // /// Boolean, equivalent to Rust's `bool`
    // Bool = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
    // /// 16-bit floating point, equivalent to Rust's `f16`
    // Float16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
    /// 64-bit floating point, equivalent to Rust's `f64`
    Double = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
    /// Unsigned 32-bit int, equivalent to Rust's `u32`
    Uint32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
    /// Unsigned 64-bit int, equivalent to Rust's `u64`
    Uint64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
    // /// Complex 64-bit floating point, equivalent to Rust's `???`
    // Complex64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
    // /// Complex 128-bit floating point, equivalent to Rust's `???`
    // Complex128 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
    // /// Brain 16-bit floating point
    // Bfloat16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
}

impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
    fn from(val: TensorElementDataType) -> Self {
        use TensorElementDataType::*;
        match val {
            Float => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
            Uint8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
            Int8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
            Uint16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
            Int16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
            Int32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
            Int64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
            String => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
            // Bool => {
            //     sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
            // }
            // Float16 => {
            //     sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
            // }
            Double => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
            Uint32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
            Uint64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
            // Complex64 => {
            //     sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
            // }
            // Complex128 => {
            //     sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
            // }
            // Bfloat16 => {
            //     sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
            // }
        }
    }
}

/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
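///
/// # Example
///
/// A small sketch of the mapping (a doctest; the variants shown match the implementations below):
///
/// ```
/// use onnxruntime::{TensorElementDataType, TypeToTensorElementDataType};
///
/// assert!(matches!(f32::tensor_element_data_type(), TensorElementDataType::Float));
/// assert!(matches!(i64::tensor_element_data_type(), TensorElementDataType::Int64));
/// // Numeric types carry no string data:
/// assert_eq!(1.0_f32.try_utf8_bytes(), None);
/// ```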
pub trait TypeToTensorElementDataType {
    /// Return the ONNX type for a Rust type
    fn tensor_element_data_type() -> TensorElementDataType;

    /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
    fn try_utf8_bytes(&self) -> Option<&[u8]>;
}

macro_rules! impl_type_trait {
    ($type_:ty, $variant:ident) => {
        impl TypeToTensorElementDataType for $type_ {
            fn tensor_element_data_type() -> TensorElementDataType {
                // unsafe { std::mem::transmute(TensorElementDataType::$variant) }
                TensorElementDataType::$variant
            }

            fn try_utf8_bytes(&self) -> Option<&[u8]> {
                None
            }
        }
    };
}

impl_type_trait!(f32, Float);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
// impl_type_trait!(bool, Bool);
// impl_type_trait!(f16, Float16);
impl_type_trait!(f64, Double);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
// impl_type_trait!(, Complex64);
// impl_type_trait!(, Complex128);
// impl_type_trait!(, Bfloat16);

/// Adapter for common Rust string types to ONNX strings.
///
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
/// types (which might implement `AsRef<str>` at some point in the future).
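///
/// # Example
///
/// A small sketch of how string data flows through [TypeToTensorElementDataType] via this trait:
///
/// ```
/// use onnxruntime::{TensorElementDataType, TypeToTensorElementDataType, Utf8Data};
///
/// assert!(matches!(String::tensor_element_data_type(), TensorElementDataType::String));
/// assert_eq!("abc".try_utf8_bytes(), Some("abc".as_bytes()));
/// assert_eq!("abc".utf8_bytes(), "abc".as_bytes());
/// ```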
pub trait Utf8Data {
    /// Returns the utf8 contents.
    fn utf8_bytes(&self) -> &[u8];
}

impl Utf8Data for String {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}

impl<'a> Utf8Data for &'a str {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}

impl<T: Utf8Data> TypeToTensorElementDataType for T {
    fn tensor_element_data_type() -> TensorElementDataType {
        TensorElementDataType::String
    }

    fn try_utf8_bytes(&self) -> Option<&[u8]> {
        Some(self.utf8_bytes())
    }
}

/// Allocator type
#[derive(Debug, Clone)]
#[repr(i32)]
pub enum AllocatorType {
    // Invalid = sys::OrtAllocatorType::Invalid as i32,
    /// Device allocator
    Device = sys::OrtAllocatorType_OrtDeviceAllocator as i32,
    /// Arena allocator
    Arena = sys::OrtAllocatorType_OrtArenaAllocator as i32,
}

impl From<AllocatorType> for sys::OrtAllocatorType {
    fn from(val: AllocatorType) -> Self {
        use AllocatorType::*;
        match val {
            // Invalid => sys::OrtAllocatorType::Invalid,
            Device => sys::OrtAllocatorType_OrtDeviceAllocator,
            Arena => sys::OrtAllocatorType_OrtArenaAllocator,
        }
    }
}

/// Memory type
///
/// Only ONNX Runtime's default memory type is supported for now.
#[derive(Debug, Clone)]
#[repr(i32)]
pub enum MemType {
    // FIXME: C API's `OrtMemType_OrtMemTypeCPU` defines it equal to `OrtMemType_OrtMemTypeCPUOutput`. How to handle this??
    // CPUInput = sys::OrtMemType::OrtMemTypeCPUInput as i32,
    // CPUOutput = sys::OrtMemType::OrtMemTypeCPUOutput as i32,
    // CPU = sys::OrtMemType::OrtMemTypeCPU as i32,
    /// Default memory type
    Default = sys::OrtMemType_OrtMemTypeDefault as i32,
}

impl From<MemType> for sys::OrtMemType {
    fn from(val: MemType) -> Self {
        use MemType::*;
        match val {
            // CPUInput => sys::OrtMemType::OrtMemTypeCPUInput,
            // CPUOutput => sys::OrtMemType::OrtMemTypeCPUOutput,
            // CPU => sys::OrtMemType::OrtMemTypeCPU,
            Default => sys::OrtMemType_OrtMemTypeDefault,
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_char_p_to_string() {
        let s = std::ffi::CString::new("foo").unwrap();
        let ptr = s.as_c_str().as_ptr();
        assert_eq!("foo", char_p_to_string(ptr).unwrap());
    }
}