#![doc = include_str!("../README.md")]

pub mod download;
pub mod environment;
pub mod error;
pub mod execution_providers;
pub mod io_binding;
pub mod memory;
pub mod metadata;
pub mod session;
pub mod sys;
pub mod tensor;
pub mod value;

use std::{
	ffi::{self, CStr},
	os::raw::c_char,
	ptr,
	sync::{atomic::AtomicPtr, Arc, Mutex}
};

use lazy_static::lazy_static;
use tracing::warn;

pub use self::environment::Environment;
pub use self::error::{OrtApiError, OrtError, OrtResult};
pub use self::execution_providers::ExecutionProvider;
pub use self::io_binding::IoBinding;
pub use self::memory::{AllocationDevice, MemoryInfo};
pub use self::session::{InMemorySession, Session, SessionBuilder};
pub use self::tensor::NdArrayExtensions;
pub use self::value::Value;

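// ONNX Runtime's C API uses the "system" calling convention: `stdcall` on 32-bit x86 Windows and the
// standard `C` convention everywhere else. This macro declares `extern` functions with the correct ABI
// for the current target so that callbacks such as `custom_logger` match what the runtime expects.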
#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
	($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
	($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
#[cfg(all(target_arch = "x86", target_os = "windows"))]
macro_rules! extern_system_fn {
	($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
	($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
}

pub(crate) use extern_system_fn;

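// With the `load-dynamic` feature, the ONNX Runtime library is loaded at runtime instead of being linked
// at build time. The path comes from the `ORT_DYLIB_PATH` environment variable if set, falling back to the
// platform's default library name; relative paths are resolved against the directory containing the current
// executable before being handed to the system loader as-is.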
#[cfg(feature = "load-dynamic")]
lazy_static! {
	pub(crate) static ref G_ORT_DYLIB_PATH: Arc<String> = {
		let path = match std::env::var("ORT_DYLIB_PATH") {
			Ok(s) if !s.is_empty() => s,
			#[cfg(target_os = "windows")]
			_ => "onnxruntime.dll".to_owned(),
			#[cfg(any(target_os = "linux", target_os = "android"))]
			_ => "libonnxruntime.so".to_owned(),
			#[cfg(target_os = "macos")]
			_ => "libonnxruntime.dylib".to_owned()
		};
		Arc::new(path)
	};
	pub(crate) static ref G_ORT_LIB: Arc<Mutex<AtomicPtr<libloading::Library>>> = {
		unsafe {
			// resolve path relative to executable
			let path: std::path::PathBuf = (&**G_ORT_DYLIB_PATH).into();
			let absolute_path = if path.is_absolute() {
				path
			} else {
				let relative = std::env::current_exe().expect("could not get current executable path").parent().unwrap().join(&path);
				if relative.exists() {
					relative
				} else {
					path
				}
			};
			let lib = libloading::Library::new(&absolute_path).unwrap_or_else(|e| panic!("could not load the library at `{}`: {e:?}", absolute_path.display()));
			Arc::new(Mutex::new(AtomicPtr::new(Box::leak(Box::new(lib)) as *mut _)))
		}
	};
}

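// Global pointer to the ONNX Runtime API table. With `load-dynamic` it is resolved through the loaded
// library's `OrtGetApiBase` symbol, with a version check against the 1.16 API this crate targets; otherwise
// it comes from the statically linked `OrtGetApiBase`.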
lazy_static! {
	pub(crate) static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
		#[cfg(feature = "load-dynamic")]
		unsafe {
			let dylib = *G_ORT_LIB
				.lock()
				.expect("failed to acquire ONNX Runtime dylib lock; another thread panicked?")
				.get_mut();
			let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const sys::OrtApiBase> = (*dylib).get(b"OrtGetApiBase").expect("`OrtGetApiBase` symbol not found in the loaded ONNX Runtime library");
			let base: *const sys::OrtApiBase = base_getter();
			assert_ne!(base, ptr::null());

			let get_version_string: extern_system_fn! { unsafe fn () -> *const ffi::c_char } = (*base).GetVersionString.unwrap();
			let version_string = get_version_string();
			let version_string = CStr::from_ptr(version_string).to_string_lossy();
			let lib_minor_version = version_string.split('.').nth(1).map(|x| x.parse::<u32>().unwrap_or(0)).unwrap_or(0);
			match lib_minor_version.cmp(&16) {
				std::cmp::Ordering::Less => panic!(
					"ort 1.16 is not compatible with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.16.x', but got '{version_string}'",
					**G_ORT_DYLIB_PATH
				),
				std::cmp::Ordering::Greater => warn!(
					"ort 1.16 may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.16.x', but got '{version_string}'",
					**G_ORT_DYLIB_PATH
				),
				std::cmp::Ordering::Equal => {}
			};
			let get_api: extern_system_fn! { unsafe fn(u32) -> *const sys::OrtApi } = (*base).GetApi.unwrap();
			let api: *const sys::OrtApi = get_api(sys::ORT_API_VERSION);
			Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
		}
		#[cfg(not(feature = "load-dynamic"))]
		{
			let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
			assert_ne!(base, ptr::null());
			let get_api: extern_system_fn! { unsafe fn(u32) -> *const sys::OrtApi } = unsafe { (*base).GetApi.unwrap() };
			let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
			Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
		}
	};
}

/// Attempts to acquire the global [`sys::OrtApi`] object.
///
/// # Panics
///
/// Panics if another thread panicked while holding the API lock, or if the ONNX Runtime API could not be initialized.
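///
/// # Example
///
/// A minimal sketch of grabbing a copy of the raw API table; each function pointer inside is an `Option`
/// that must be unwrapped before use (which the internal `ortsys!` macro handles).
///
/// ```no_run
/// let api: ort::sys::OrtApi = ort::ort();
/// ```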
pub fn ort() -> sys::OrtApi {
	let mut api_ref = G_ORT_API.lock().expect("failed to acquire OrtApi lock; another thread panicked?");
	let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
	let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;

	assert_ne!(api_ptr_mut, ptr::null_mut());

	unsafe { *api_ptr_mut }
}

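// Convenience macro for invoking `OrtApi` functions. Each form unwraps the function pointer from the global
// API table and calls it; an optional `unsafe` prefix wraps the call in an `unsafe` block, a trailing
// `-> $err` converts the returned status into an `OrtResult` using the given error constructor, and
// `nonNull(...)` asserts that the listed out-pointers were actually populated by the call.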
macro_rules! ortsys {
	($method:tt) => {
		$crate::ort().$method.unwrap()
	};
	(unsafe $method:tt) => {
		unsafe { $crate::ort().$method.unwrap() }
	};
	($method:tt($($n:expr),+ $(,)?)) => {
		$crate::ort().$method.unwrap()($($n),+)
	};
	(unsafe $method:tt($($n:expr),+ $(,)?)) => {
		unsafe { $crate::ort().$method.unwrap()($($n),+) }
	};
	($method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
		$crate::ort().$method.unwrap()($($n),+);
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	};
	(unsafe $method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
		unsafe { $crate::ort().$method.unwrap()($($n),+) };
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	};
	($method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
		$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
	};
	(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
		$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
	};
	($method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
		$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	};
	(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
		$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
		$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
	};
}

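// Frees memory previously handed out by an `OrtAllocator` by calling the allocator's `Free` function pointer.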
macro_rules! ortfree {
	(unsafe $allocator_ptr:expr, $ptr:tt) => {
		unsafe { (*$allocator_ptr).Free.unwrap()($allocator_ptr, $ptr as *mut std::ffi::c_void) }
	};
	($allocator_ptr:expr, $ptr:tt) => {
		(*$allocator_ptr).Free.unwrap()($allocator_ptr, $ptr as *mut std::ffi::c_void)
	};
}

pub(crate) use ortfree;
pub(crate) use ortsys;

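/// Converts a C string pointer into an owned Rust [`String`], mapping invalid UTF-8 into
/// [`OrtError::FfiStringConversion`].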
pub(crate) fn char_p_to_string(raw: *const c_char) -> OrtResult<String> {
	let c_string = unsafe { CStr::from_ptr(raw).to_owned() };
	match c_string.into_string() {
		Ok(string) => Ok(string),
		Err(e) => Err(OrtApiError::IntoStringError(e))
	}
	.map_err(OrtError::FfiStringConversion)
}

/// ONNX Runtime's logger passes the code location where the log occurred as a string, which is parsed into this struct
/// (expected format: `file:line function`).
#[derive(Debug)]
struct CodeLocation<'a> {
	file: &'a str,
	line: &'a str,
	function: &'a str
}

impl<'a> From<&'a str> for CodeLocation<'a> {
	fn from(code_location: &'a str) -> Self {
		let mut splitter = code_location.split(' ');
		let file_and_line = splitter.next().unwrap_or("<unknown file>:<unknown line>");
		let function = splitter.next().unwrap_or("<unknown function>");
		let mut file_and_line_splitter = file_and_line.split(':');
		let file = file_and_line_splitter.next().unwrap_or("<unknown file>");
		let line = file_and_line_splitter.next().unwrap_or("<unknown line>");

		CodeLocation { file, line, function }
	}
}

extern_system_fn! {
	/// Callback from C that handles ONNX Runtime logging, forwarding the runtime's logs to the `tracing` crate.
	pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const c_char, log_id: *const c_char, code_location: *const c_char, message: *const c_char) {
		use tracing::{span, Level, trace, debug, warn, info, error};

		let log_level = match severity {
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => Level::INFO,
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => Level::WARN,
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR
		};

		assert_ne!(category, ptr::null());
		let category = unsafe { CStr::from_ptr(category) };
		assert_ne!(code_location, ptr::null());
		let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("unknown");
		assert_ne!(message, ptr::null());
		let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<invalid>");
		assert_ne!(log_id, ptr::null());
		let log_id = unsafe { CStr::from_ptr(log_id) };

		let code_location = CodeLocation::from(code_location);
		let span = span!(
			Level::TRACE,
			"ort",
			category = category.to_str().unwrap_or("<unknown>"),
			file = code_location.file,
			line = code_location.line,
			function = code_location.function,
			log_id = log_id.to_str().unwrap_or("<unknown>")
		);
		let _enter = span.enter();

		match log_level {
			Level::TRACE => trace!("{}", message),
			Level::DEBUG => debug!("{}", message),
			Level::INFO => info!("{}", message),
			Level::WARN => warn!("{}", message),
			Level::ERROR => error!("{}", message)
		}
	}
}

/// The minimum logging level. Logs will be handled by the `tracing` crate.
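///
/// `LoggingLevel` converts into the raw logging level used by the C API:
///
/// ```
/// use ort::{sys, LoggingLevel};
///
/// let raw: sys::OrtLoggingLevel = LoggingLevel::Warning.into();
/// assert!(matches!(raw, sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING));
/// ```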
#[derive(Debug)]
pub enum LoggingLevel {
	/// Verbose logging level. This will log *a lot* of messages!
	Verbose,
	/// Info logging level.
	Info,
	/// Warning logging level. Recommended to receive potentially important warnings.
	Warning,
	/// Error logging level.
	Error,
	/// Fatal logging level.
	Fatal
}

impl From<LoggingLevel> for sys::OrtLoggingLevel {
	fn from(logging_level: LoggingLevel) -> Self {
		match logging_level {
			LoggingLevel::Verbose => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
			LoggingLevel::Info => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
			LoggingLevel::Warning => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
			LoggingLevel::Error => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
			LoggingLevel::Fatal => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL
		}
	}
}

/// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially
/// graph-level transformations, ranging from small graph simplifications and node eliminations to more complex node
/// fusions and layout optimizations.
///
/// Graph optimizations are divided in several categories (or levels) based on their complexity and functionality. They
/// can be performed either online or offline. In online mode, the optimizations are done before performing the
/// inference, while in offline mode, the runtime saves the optimized graph to disk (most commonly used when converting
/// an ONNX model to an ONNX Runtime model).
///
/// The optimizations belonging to one level are performed after the optimizations of the previous level have been
/// applied (e.g., extended optimizations are applied after basic optimizations have been applied).
///
/// **All optimizations (i.e. [`GraphOptimizationLevel::Level3`]) are enabled by default.**
///
/// # Online/offline mode
/// All optimizations can be performed either online or offline. In online mode, when initializing an inference session,
/// we also apply all enabled graph optimizations before performing model inference. Applying all optimizations each
/// time we initiate a session can add overhead to the model startup time (especially for complex models), which can be
/// critical in production scenarios. This is where the offline mode can bring a lot of benefit. In offline mode, after
/// performing graph optimizations, ONNX Runtime serializes the resulting model to disk. Subsequently, we can reduce
/// startup time by using the already optimized model and disabling all optimizations.
///
/// ## Notes:
/// - When running in offline mode, make sure to use the exact same options (e.g., execution providers, optimization
///   level) and hardware as the target machine that the model inference will run on (e.g., you cannot run a model
///   pre-optimized for a GPU execution provider on a machine that is equipped only with a CPU).
/// - When layout optimizations are enabled, the offline mode can only be used on hardware compatible with the
///   environment in which the offline model was saved. For example, if the model has a layout optimized for AVX2, the
///   offline model would require CPUs that support AVX2.
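///
/// # Example
///
/// A rough sketch of selecting an optimization level when building a session; the model path here is a placeholder,
/// and [`SessionBuilder`] offers many more options than shown.
///
/// ```no_run
/// use ort::{Environment, GraphOptimizationLevel, SessionBuilder};
///
/// # fn main() -> ort::OrtResult<()> {
/// let environment = Environment::builder().build()?.into_arc();
/// let session = SessionBuilder::new(&environment)?
/// 	.with_optimization_level(GraphOptimizationLevel::Level3)?
/// 	.with_model_from_file("model.onnx")?;
/// # Ok(())
/// # }
/// ```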
#[derive(Debug)]
pub enum GraphOptimizationLevel {
	/// Disables all graph optimizations.
	Disable,
	/// Level 1 includes semantics-preserving graph rewrites which remove redundant nodes and redundant computation.
	/// They run before graph partitioning and thus apply to all the execution providers. Available basic/level 1 graph
	/// optimizations are as follows:
	///
	/// - Constant Folding: Statically computes parts of the graph that rely only on constant initializers. This
	///   eliminates the need to compute them during runtime.
	/// - Redundant node eliminations: Remove all redundant nodes without changing the graph structure. The following
	///   such optimizations are currently supported:
	///   * Identity Elimination
	///   * Slice Elimination
	///   * Unsqueeze Elimination
	///   * Dropout Elimination
	/// - Semantics-preserving node fusions: Fuse/fold multiple nodes into a single node. For example, Conv Add fusion
	///   folds the Add operator as the bias of the Conv operator. The following such optimizations are currently
	///   supported:
	///   * Conv Add Fusion
	///   * Conv Mul Fusion
	///   * Conv BatchNorm Fusion
	///   * Relu Clip Fusion
	///   * Reshape Fusion
	Level1,
	#[rustfmt::skip]
	/// Level 2 optimizations include complex node fusions. They are run after graph partitioning and are only applied to
	/// the nodes assigned to the CPU or CUDA execution provider. Available extended/level 2 graph optimizations are as follows:
	///
	/// | Optimization                    | EPs       | Comments                                                                       |
	/// |:------------------------------- |:--------- |:------------------------------------------------------------------------------ |
	/// | GEMM Activation Fusion          | CPU       |                                                                                |
	/// | Matmul Add Fusion               | CPU       |                                                                                |
	/// | Conv Activation Fusion          | CPU       |                                                                                |
	/// | GELU Fusion                     | CPU, CUDA |                                                                                |
	/// | Layer Normalization Fusion      | CPU, CUDA |                                                                                |
	/// | BERT Embedding Layer Fusion     | CPU, CUDA | Fuses BERT embedding layers, layer normalization, & attention mask length      |
	/// | Attention Fusion*               | CPU, CUDA |                                                                                |
	/// | Skip Layer Normalization Fusion | CPU, CUDA | Fuse bias of fully connected layers, skip connections, and layer normalization |
	/// | Bias GELU Fusion                | CPU, CUDA | Fuse bias of fully connected layers & GELU activation                          |
	/// | GELU Approximation*             | CUDA      | Disabled by default; enable with `OrtSessionOptions::EnableGeluApproximation`  |
	///
	/// > **NOTE**: To optimize performance of the BERT model, approximation is used in GELU Approximation and Attention
	/// > Fusion for the CUDA execution provider. The impact on accuracy is negligible based on our evaluation; the F1 score
	/// > for a BERT model on SQuAD v1.1 is almost the same (87.05 vs 87.03).
	Level2,
	/// Level 3 optimizations include memory layout optimizations, which may optimize the graph to use the NCHWc memory
	/// layout rather than NCHW to improve spatial locality for some targets.
	Level3
}

impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
	fn from(val: GraphOptimizationLevel) -> Self {
		match val {
			GraphOptimizationLevel::Disable => sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
			GraphOptimizationLevel::Level1 => sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
			GraphOptimizationLevel::Level2 => sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
			GraphOptimizationLevel::Level3 => sys::GraphOptimizationLevel::ORT_ENABLE_ALL
		}
	}
}

/// Execution provider allocator type.
#[derive(Debug, Copy, Clone)]
pub enum AllocatorType {
	/// Default device-specific allocator.
	Device,
	/// Arena allocator.
	Arena
}

impl From<AllocatorType> for sys::OrtAllocatorType {
	fn from(val: AllocatorType) -> Self {
		match val {
			AllocatorType::Device => sys::OrtAllocatorType::OrtDeviceAllocator,
			AllocatorType::Arena => sys::OrtAllocatorType::OrtArenaAllocator
		}
	}
}

/// Memory types for allocated memory.
#[derive(Debug, Copy, Clone)]
pub enum MemType {
	/// Any CPU memory used by a non-CPU execution provider.
	CPUInput,
	/// CPU-accessible memory output by a non-CPU execution provider, e.g. `CUDA_PINNED`.
	CPUOutput,
	/// The default allocator for an execution provider.
	Default
}

impl MemType {
	/// Temporary CPU-accessible memory allocated by a non-CPU execution provider, e.g. `CUDA_PINNED`.
	pub const CPU: MemType = MemType::CPUOutput;
}

impl From<MemType> for sys::OrtMemType {
	fn from(val: MemType) -> Self {
		match val {
			MemType::CPUInput => sys::OrtMemType::OrtMemTypeCPUInput,
			MemType::CPUOutput => sys::OrtMemType::OrtMemTypeCPUOutput,
			MemType::Default => sys::OrtMemType::OrtMemTypeDefault
		}
	}
}

#[cfg(test)]
mod test {
	use super::*;

	#[test]
	fn test_char_p_to_string() {
		let s = ffi::CString::new("foo").unwrap();
		let ptr = s.as_c_str().as_ptr();
		assert_eq!("foo", char_p_to_string(ptr).unwrap());
	}
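
	// Sanity check for `CodeLocation` parsing; the input mirrors the `file:line function` format the ONNX
	// Runtime logger is expected to pass to `custom_logger`.
	#[test]
	fn test_code_location_parsing() {
		let location = CodeLocation::from("inference_session.cc:123 Initialize");
		assert_eq!(location.file, "inference_session.cc");
		assert_eq!(location.line, "123");
		assert_eq!(location.function, "Initialize");
	}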
}