1#![warn(missing_docs)]
2
3#,
a model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using the
[`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method
(requires the `model-fetching` feature).
66
67```no_run
68# use std::error::Error;
69# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel};
70# fn main() -> Result<(), Box<dyn Error>> {
71# let environment = Environment::builder()
72# .with_name("test")
73# .with_log_level(LoggingLevel::Verbose)
74# .build()?;
75let mut session = environment
76 .new_session_builder()?
77 .with_optimization_level(GraphOptimizationLevel::Basic)?
78 .with_number_threads(1)?
79 .with_model_downloaded(ImageClassification::SqueezeNet)?;
80# Ok(())
81# }
82```
83
84See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available
85to download.
86"##
87)]
88use std::sync::{atomic::AtomicPtr, Arc, Mutex};
119
120use lazy_static::lazy_static;
121
122use mcai_onnxruntime_sys as sys;
123
// `extern_system_fn!` wraps a function definition or a function-pointer type and
// inserts an explicit calling convention in front of `fn`.
//
// This definition is compiled only for 32-bit x86 Windows, where it emits
// `extern "stdcall"`; the `not(all(target_os = "windows", target_arch = "x86"))`
// definition of the same macro emits `extern "C"` everywhere else.
// NOTE(review): presumably this mirrors the calling convention the ONNX Runtime C API
// (and the generated sys bindings) use on win32 — confirm before changing.
//
// Each arm accepts optional outer attributes (doc comments included), an optional
// visibility and an optional `unsafe` qualifier; everything after `fn` (name,
// parameters, return type and body, or just the signature for a pointer type) is
// forwarded untouched.
#[cfg(all(target_os = "windows", target_arch = "x86"))]
macro_rules! extern_system_fn {
    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
}
133
// `extern_system_fn!` for every target except 32-bit x86 Windows: the wrapped
// function definition or function-pointer type gets the `extern "C"` calling
// convention.
//
// The arms mirror the win32 definition exactly (optional attributes, optional
// visibility, optional `unsafe`), so call sites are identical on all targets.
#[cfg(not(all(target_os = "windows", target_arch = "x86")))]
macro_rules! extern_system_fn {
    ($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
    ($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
    ($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
143
144pub mod download;
145pub mod environment;
146pub mod error;
147mod memory;
148pub mod session;
149pub mod tensor;
150
151pub use error::{OrtApiError, OrtError, Result};
153use sys::OnnxEnumInt;
154
155pub use ndarray;
157
158lazy_static! {
159 static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
164 let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
165 assert_ne!(base, std::ptr::null());
166 let get_api: extern_system_fn!{ unsafe fn(u32) -> *const mcai_onnxruntime_sys::OrtApi } =
167 unsafe { (*base).GetApi.unwrap() };
168 let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
169 Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
170 };
171}
172
173fn g_ort() -> sys::OrtApi {
174 let mut api_ref = G_ORT_API
175 .lock()
176 .expect("Failed to acquire lock: another thread panicked?");
177 let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
178 let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;
179
180 assert_ne!(api_ptr_mut, std::ptr::null_mut());
181
182 unsafe { *api_ptr_mut }
183}
184
185fn char_p_to_string(raw: *const i8) -> Result<String> {
186 let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
187
188 match c_string.into_string() {
189 Ok(string) => Ok(string),
190 Err(e) => Err(OrtApiError::IntoStringError(e)),
191 }
192 .map_err(OrtError::StringConversion)
193}
194
195mod onnxruntime {
196 use std::ffi::CStr;
200 use tracing::{debug, error, info, span, trace, warn, Level};
201
202 use mcai_onnxruntime_sys as sys;
203
204 #[derive(Debug)]
206 struct CodeLocation<'a> {
207 file: &'a str,
208 line_number: &'a str,
209 function: &'a str,
210 }
211
212 impl<'a> From<&'a str> for CodeLocation<'a> {
213 fn from(code_location: &'a str) -> Self {
214 let mut splitter = code_location.split(' ');
215 let file_and_line_number = splitter.next().unwrap_or("<unknown file:line>");
216 let function = splitter.next().unwrap_or("<unknown module>");
217 let mut file_and_line_number_splitter = file_and_line_number.split(':');
218 let file = file_and_line_number_splitter
219 .next()
220 .unwrap_or("<unknown file>");
221 let line_number = file_and_line_number_splitter
222 .next()
223 .unwrap_or("<unknown line number>");
224
225 CodeLocation {
226 file,
227 line_number,
228 function,
229 }
230 }
231 }
232
233 extern_system_fn! {
234 pub(crate) fn custom_logger(
236 _params: *mut std::ffi::c_void,
237 severity: sys::OrtLoggingLevel,
238 category: *const i8,
239 logid: *const i8,
240 code_location: *const i8,
241 message: *const i8,
242 ) {
243 let log_level = match severity {
244 sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
245 sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
246 sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => Level::INFO,
247 sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => Level::WARN,
248 sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
249 };
250
251 assert_ne!(category, std::ptr::null());
252 let category = unsafe { CStr::from_ptr(category) };
253 assert_ne!(code_location, std::ptr::null());
254 let code_location = unsafe { CStr::from_ptr(code_location) }
255 .to_str()
256 .unwrap_or("unknown");
257 assert_ne!(message, std::ptr::null());
258 let message = unsafe { CStr::from_ptr(message) };
259
260 assert_ne!(logid, std::ptr::null());
261 let logid = unsafe { CStr::from_ptr(logid) };
262
263 let code_location: CodeLocation = code_location.into();
265
266 let span = span!(
267 Level::TRACE,
268 "onnxruntime",
269 category = category.to_str().unwrap_or("<unknown>"),
270 file = code_location.file,
271 line_number = code_location.line_number,
272 function = code_location.function,
273 logid = logid.to_str().unwrap_or("<unknown>"),
274 );
275 let _enter = span.enter();
276
277 match log_level {
278 Level::TRACE => trace!("{:?}", message),
279 Level::DEBUG => debug!("{:?}", message),
280 Level::INFO => info!("{:?}", message),
281 Level::WARN => warn!("{:?}", message),
282 Level::ERROR => error!("{:?}", message),
283 }
284 }
285 }
286}
287
288#[derive(Debug)]
290#[cfg_attr(not(windows), repr(u32))]
291#[cfg_attr(windows, repr(i32))]
292pub enum LoggingLevel {
293 Verbose = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
295 Info = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
297 Warning = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
299 Error = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
301 Fatal = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt,
303}
304
305impl From<LoggingLevel> for sys::OrtLoggingLevel {
306 fn from(val: LoggingLevel) -> Self {
307 match val {
308 LoggingLevel::Verbose => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
309 LoggingLevel::Info => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
310 LoggingLevel::Warning => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
311 LoggingLevel::Error => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
312 LoggingLevel::Fatal => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL,
313 }
314 }
315}
316
317#[derive(Debug)]
322#[cfg_attr(not(windows), repr(u32))]
323#[cfg_attr(windows, repr(i32))]
324pub enum GraphOptimizationLevel {
325 DisableAll = sys::GraphOptimizationLevel::ORT_DISABLE_ALL as OnnxEnumInt,
327 Basic = sys::GraphOptimizationLevel::ORT_ENABLE_BASIC as OnnxEnumInt,
329 Extended = sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED as OnnxEnumInt,
331 All = sys::GraphOptimizationLevel::ORT_ENABLE_ALL as OnnxEnumInt,
333}
334
335impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
336 fn from(val: GraphOptimizationLevel) -> Self {
337 use GraphOptimizationLevel::*;
338 match val {
339 DisableAll => sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
340 Basic => sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
341 Extended => sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
342 All => sys::GraphOptimizationLevel::ORT_ENABLE_ALL,
343 }
344 }
345}
346
347#[derive(Debug)]
351#[cfg_attr(not(windows), repr(u32))]
352#[cfg_attr(windows, repr(i32))]
353pub enum TensorElementDataType {
354 Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
356 Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
358 Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
360 Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
362 Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
364 Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
366 Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
368 String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
370 Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
376 Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
378 Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
380 }
387
388impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
389 fn from(val: TensorElementDataType) -> Self {
390 use TensorElementDataType::*;
391 match val {
392 Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
393 Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
394 Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
395 Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
396 Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
397 Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
398 Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
399 String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
400 Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
407 Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
408 Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
409 }
419 }
420}
421
422pub trait TypeToTensorElementDataType {
424 fn tensor_element_data_type() -> TensorElementDataType;
426
427 fn try_utf8_bytes(&self) -> Option<&[u8]>;
429}
430
// Generates a `TypeToTensorElementDataType` impl for a non-string Rust type:
// `$rust_ty` reports `TensorElementDataType::$element` and never exposes UTF-8 bytes.
macro_rules! impl_type_trait {
    ($rust_ty:ty, $element:ident) => {
        impl TypeToTensorElementDataType for $rust_ty {
            fn tensor_element_data_type() -> TensorElementDataType {
                TensorElementDataType::$element
            }

            // Numeric data has no UTF-8 representation.
            fn try_utf8_bytes(&self) -> Option<&[u8]> {
                None
            }
        }
    };
}
445
446impl_type_trait!(f32, Float);
447impl_type_trait!(u8, Uint8);
448impl_type_trait!(i8, Int8);
449impl_type_trait!(u16, Uint16);
450impl_type_trait!(i16, Int16);
451impl_type_trait!(i32, Int32);
452impl_type_trait!(i64, Int64);
453impl_type_trait!(f64, Double);
456impl_type_trait!(u32, Uint32);
457impl_type_trait!(u64, Uint64);
/// Types whose value can be viewed as UTF-8 bytes, so they can be stored in string tensors.
///
/// Implemented here for `String` and `&str`.
pub trait Utf8Data {
    /// Returns the UTF-8 encoded bytes of `self`.
    fn utf8_bytes(&self) -> &[u8];
}

impl Utf8Data for String {
    fn utf8_bytes(&self) -> &[u8] {
        str::as_bytes(self)
    }
}

impl<'a> Utf8Data for &'a str {
    fn utf8_bytes(&self) -> &[u8] {
        str::as_bytes(self)
    }
}
484
485impl<T: Utf8Data> TypeToTensorElementDataType for T {
486 fn tensor_element_data_type() -> TensorElementDataType {
487 TensorElementDataType::String
488 }
489
490 fn try_utf8_bytes(&self) -> Option<&[u8]> {
491 Some(self.utf8_bytes())
492 }
493}
494
495#[derive(Debug, Clone)]
497#[repr(i32)]
498pub enum AllocatorType {
499 Device = sys::OrtAllocatorType::OrtDeviceAllocator as i32,
502 Arena = sys::OrtAllocatorType::OrtArenaAllocator as i32,
504}
505
506impl From<AllocatorType> for sys::OrtAllocatorType {
507 fn from(val: AllocatorType) -> Self {
508 use AllocatorType::*;
509 match val {
510 Device => sys::OrtAllocatorType::OrtDeviceAllocator,
512 Arena => sys::OrtAllocatorType::OrtArenaAllocator,
513 }
514 }
515}
516
517#[derive(Debug, Clone)]
521#[repr(i32)]
522pub enum MemType {
523 Default = sys::OrtMemType::OrtMemTypeDefault as i32,
529}
530
531impl From<MemType> for sys::OrtMemType {
532 fn from(val: MemType) -> Self {
533 use MemType::*;
534 match val {
535 Default => sys::OrtMemType::OrtMemTypeDefault,
539 }
540 }
541}
542
#[cfg(test)]
mod test {
    use super::*;
    use std::ffi::CString;

    #[test]
    fn test_char_p_to_string() {
        let s = CString::new("foo").unwrap();
        assert_eq!("foo", char_p_to_string(s.as_ptr()).unwrap());
    }

    #[test]
    fn test_char_p_to_string_empty() {
        let s = CString::new("").unwrap();
        assert_eq!("", char_p_to_string(s.as_ptr()).unwrap());
    }

    #[test]
    fn test_char_p_to_string_invalid_utf8() {
        // 0xFF can never appear in well-formed UTF-8, so conversion must fail.
        let s = CString::new(vec![0x66_u8, 0xff, 0x6f]).unwrap();
        assert!(char_p_to_string(s.as_ptr()).is_err());
    }
}