1#![warn(missing_docs)]
2
3#,
63a model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using
64[`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method
65(requires the `model-fetching` feature).
66
67```no_run
68# use std::error::Error;
69# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel};
70# fn main() -> Result<(), Box<dyn Error>> {
71# let environment = Environment::builder()
72# .with_name("test")
73# .with_log_level(LoggingLevel::Verbose)
74# .build()?;
75let mut session = environment
76 .new_session_builder()?
77 .with_optimization_level(GraphOptimizationLevel::Basic)?
78 .with_number_threads(1)?
79 .with_model_downloaded(ImageClassification::SqueezeNet)?;
80# Ok(())
81# }
82```
83
84See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available
85to download.
86"##
87)]
88use std::sync::{atomic::AtomicPtr, Arc, Mutex};
119
120use lazy_static::lazy_static;
121
122use onnxruntime_sys_ng as sys;
123
// `extern_system_fn!` rewrites a function item or function-pointer type so that
// it carries the ABI selected for the current target: `"stdcall"` on 32-bit x86
// Windows, `"C"` on every other target. Any leading attributes, an optional
// visibility and an optional `unsafe` qualifier are forwarded unchanged.
#[cfg(all(target_os = "windows", target_arch = "x86"))]
macro_rules! extern_system_fn {
    ($(#[$attr:meta])* fn $($rest:tt)*) => ($(#[$attr])* extern "stdcall" fn $($rest)*);
    ($(#[$attr:meta])* $v:vis fn $($rest:tt)*) => ($(#[$attr])* $v extern "stdcall" fn $($rest)*);
    ($(#[$attr:meta])* unsafe fn $($rest:tt)*) => ($(#[$attr])* unsafe extern "stdcall" fn $($rest)*);
    ($(#[$attr:meta])* $v:vis unsafe fn $($rest:tt)*) => ($(#[$attr])* $v unsafe extern "stdcall" fn $($rest)*);
}

// Same macro for every target that is not 32-bit x86 Windows: uses the "C" ABI.
#[cfg(not(all(target_os = "windows", target_arch = "x86")))]
macro_rules! extern_system_fn {
    ($(#[$attr:meta])* fn $($rest:tt)*) => ($(#[$attr])* extern "C" fn $($rest)*);
    ($(#[$attr:meta])* $v:vis fn $($rest:tt)*) => ($(#[$attr])* $v extern "C" fn $($rest)*);
    ($(#[$attr:meta])* unsafe fn $($rest:tt)*) => ($(#[$attr])* unsafe extern "C" fn $($rest)*);
    ($(#[$attr:meta])* $v:vis unsafe fn $($rest:tt)*) => ($(#[$attr])* $v unsafe extern "C" fn $($rest)*);
}
143
144pub mod download;
145pub mod environment;
146pub mod error;
147mod memory;
148pub mod session;
149pub mod tensor;
150
151pub use error::{OrtApiError, OrtError, Result};
153use sys::OnnxEnumInt;
154
155pub use ndarray;
157
158lazy_static! {
159 static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
164 let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
165 assert_ne!(base, std::ptr::null());
166 let get_api: extern_system_fn!{ unsafe fn(u32) -> *const onnxruntime_sys_ng::OrtApi } =
167 unsafe { (*base).GetApi.unwrap() };
168 let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
169 Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
170 };
171}
172
173fn g_ort() -> sys::OrtApi {
174 let mut api_ref = G_ORT_API
175 .lock()
176 .expect("Failed to acquire lock: another thread panicked?");
177 let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
178 let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;
179
180 assert_ne!(api_ptr_mut, std::ptr::null_mut());
181
182 unsafe { *api_ptr_mut }
183}
184
185fn char_p_to_string(raw: *const i8) -> Result<String> {
186 let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
187
188 match c_string.into_string() {
189 Ok(string) => Ok(string),
190 Err(e) => Err(OrtApiError::IntoStringError(e)),
191 }
192 .map_err(OrtError::StringConversion)
193}
194
195mod onnxruntime {
196 use std::ffi::CStr;
200 use tracing::{debug, error, info, span, trace, warn, Level};
201
202 use onnxruntime_sys_ng as sys;
203
204 #[derive(Debug)]
206 struct CodeLocation<'a> {
207 file: &'a str,
208 line_number: &'a str,
209 function: &'a str,
210 }
211
212 impl<'a> From<&'a str> for CodeLocation<'a> {
213 fn from(code_location: &'a str) -> Self {
214 let mut splitter = code_location.split(' ');
215 let file_and_line_number = splitter.next().unwrap_or("<unknown file:line>");
216 let function = splitter.next().unwrap_or("<unknown module>");
217 let mut file_and_line_number_splitter = file_and_line_number.split(':');
218 let file = file_and_line_number_splitter
219 .next()
220 .unwrap_or("<unknown file>");
221 let line_number = file_and_line_number_splitter
222 .next()
223 .unwrap_or("<unknown line number>");
224
225 CodeLocation {
226 file,
227 line_number,
228 function,
229 }
230 }
231 }
232
233 extern_system_fn! {
234 pub(crate) fn custom_logger(
236 _params: *mut std::ffi::c_void,
237 severity: sys::OrtLoggingLevel,
238 category: *const i8,
239 logid: *const i8,
240 code_location: *const i8,
241 message: *const i8,
242 ) {
243 let log_level = match severity {
244 sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
245 sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
246 sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING => Level::INFO,
247 sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR => Level::WARN,
248 sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
249 _ => Level::ERROR,
250 };
251
252 assert_ne!(category, std::ptr::null());
253 let category = unsafe { CStr::from_ptr(category) };
254 assert_ne!(code_location, std::ptr::null());
255 let code_location = unsafe { CStr::from_ptr(code_location) }
256 .to_str()
257 .unwrap_or("unknown");
258 assert_ne!(message, std::ptr::null());
259 let message = unsafe { CStr::from_ptr(message) };
260
261 assert_ne!(logid, std::ptr::null());
262 let logid = unsafe { CStr::from_ptr(logid) };
263
264 let code_location: CodeLocation = code_location.into();
266
267 let span = span!(
268 Level::TRACE,
269 "onnxruntime",
270 category = category.to_str().unwrap_or("<unknown>"),
271 file = code_location.file,
272 line_number = code_location.line_number,
273 function = code_location.function,
274 logid = logid.to_str().unwrap_or("<unknown>"),
275 );
276 let _enter = span.enter();
277
278 match log_level {
279 Level::TRACE => trace!("{:?}", message),
280 Level::DEBUG => debug!("{:?}", message),
281 Level::INFO => info!("{:?}", message),
282 Level::WARN => warn!("{:?}", message),
283 Level::ERROR => error!("{:?}", message),
284 }
285 }
286 }
287}
288
289#[derive(Debug)]
291#[cfg_attr(not(windows), repr(u32))]
292#[cfg_attr(windows, repr(i32))]
293pub enum LoggingLevel {
294 Verbose = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
296 Info = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
298 Warning = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
300 Error = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
302 Fatal = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt,
304}
305
306impl From<LoggingLevel> for sys::OrtLoggingLevel {
307 fn from(val: LoggingLevel) -> Self {
308 match val {
309 LoggingLevel::Verbose => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE,
310 LoggingLevel::Info => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO,
311 LoggingLevel::Warning => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING,
312 LoggingLevel::Error => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR,
313 LoggingLevel::Fatal => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL,
314 }
315 }
316}
317
318#[derive(Debug)]
323#[cfg_attr(not(windows), repr(u32))]
324#[cfg_attr(windows, repr(i32))]
325pub enum GraphOptimizationLevel {
326 DisableAll = sys::GraphOptimizationLevel_ORT_DISABLE_ALL as OnnxEnumInt,
328 Basic = sys::GraphOptimizationLevel_ORT_ENABLE_BASIC as OnnxEnumInt,
330 Extended = sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED as OnnxEnumInt,
332 All = sys::GraphOptimizationLevel_ORT_ENABLE_ALL as OnnxEnumInt,
334}
335
336impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
337 fn from(val: GraphOptimizationLevel) -> Self {
338 use GraphOptimizationLevel::*;
339 match val {
340 DisableAll => sys::GraphOptimizationLevel_ORT_DISABLE_ALL,
341 Basic => sys::GraphOptimizationLevel_ORT_ENABLE_BASIC,
342 Extended => sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED,
343 All => sys::GraphOptimizationLevel_ORT_ENABLE_ALL,
344 }
345 }
346}
347
348#[derive(Debug)]
352#[cfg_attr(not(windows), repr(u32))]
353#[cfg_attr(windows, repr(i32))]
354pub enum TensorElementDataType {
355 Float = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
357 Uint8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
359 Int8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
361 Uint16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
363 Int16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
365 Int32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
367 Int64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
369 String = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
371 Double = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
377 Uint32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
379 Uint64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
381 }
388
389impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
390 fn from(val: TensorElementDataType) -> Self {
391 use TensorElementDataType::*;
392 match val {
393 Float => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
394 Uint8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
395 Int8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
396 Uint16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
397 Int16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
398 Int32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
399 Int64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
400 String => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
401 Double => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
408 Uint32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
409 Uint64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
410 }
420 }
421}
422
423pub trait TypeToTensorElementDataType {
425 fn tensor_element_data_type() -> TensorElementDataType;
427
428 fn try_utf8_bytes(&self) -> Option<&[u8]>;
430}
431
432macro_rules! impl_type_trait {
433 ($type_:ty, $variant:ident) => {
434 impl TypeToTensorElementDataType for $type_ {
435 fn tensor_element_data_type() -> TensorElementDataType {
436 TensorElementDataType::$variant
438 }
439
440 fn try_utf8_bytes(&self) -> Option<&[u8]> {
441 None
442 }
443 }
444 };
445}
446
447impl_type_trait!(f32, Float);
448impl_type_trait!(u8, Uint8);
449impl_type_trait!(i8, Int8);
450impl_type_trait!(u16, Uint16);
451impl_type_trait!(i16, Int16);
452impl_type_trait!(i32, Int32);
453impl_type_trait!(i64, Int64);
454impl_type_trait!(f64, Double);
457impl_type_trait!(u32, Uint32);
458impl_type_trait!(u64, Uint64);
/// String-like data that can expose its contents as UTF-8 bytes.
pub trait Utf8Data {
    /// Returns the UTF-8 bytes of this value.
    fn utf8_bytes(&self) -> &[u8];
}

impl Utf8Data for String {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }
}

impl Utf8Data for &str {
    fn utf8_bytes(&self) -> &[u8] {
        str::as_bytes(self)
    }
}
485
486impl<T: Utf8Data> TypeToTensorElementDataType for T {
487 fn tensor_element_data_type() -> TensorElementDataType {
488 TensorElementDataType::String
489 }
490
491 fn try_utf8_bytes(&self) -> Option<&[u8]> {
492 Some(self.utf8_bytes())
493 }
494}
495
496#[derive(Debug, Clone)]
498#[repr(i32)]
499pub enum AllocatorType {
500 Device = sys::OrtAllocatorType_OrtDeviceAllocator as i32,
503 Arena = sys::OrtAllocatorType_OrtArenaAllocator as i32,
505}
506
507impl From<AllocatorType> for sys::OrtAllocatorType {
508 fn from(val: AllocatorType) -> Self {
509 use AllocatorType::*;
510 match val {
511 Device => sys::OrtAllocatorType_OrtDeviceAllocator,
513 Arena => sys::OrtAllocatorType_OrtArenaAllocator,
514 }
515 }
516}
517
518#[derive(Debug, Clone)]
522#[repr(i32)]
523pub enum MemType {
524 Default = sys::OrtMemType_OrtMemTypeDefault as i32,
530}
531
532impl From<MemType> for sys::OrtMemType {
533 fn from(val: MemType) -> Self {
534 use MemType::*;
535 match val {
536 Default => sys::OrtMemType_OrtMemTypeDefault,
540 }
541 }
542}
543
#[cfg(test)]
mod test {
    use super::*;

    /// A NUL-terminated C string converts into an equal Rust `String`.
    #[test]
    fn test_char_p_to_string() {
        let c_string = std::ffi::CString::new("foo").unwrap();
        let converted = char_p_to_string(c_string.as_ptr()).unwrap();
        assert_eq!("foo", converted);
    }
}