ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
use alloc::{borrow::ToOwned, format, string::String, vec::Vec};
use core::{
	ffi::{CStr, c_char},
	mem::ManuallyDrop,
	ops::Deref,
	ptr
};

use crate::{Result, memory::Allocator};

#[cfg(feature = "std")]
#[path = "mutex_std.rs"]
mod mutex;
#[cfg(not(feature = "std"))]
#[path = "mutex_spin.rs"]
mod mutex;

mod map;
mod once_lock;
mod stack;
pub(crate) use self::{
	map::MiniMap,
	mutex::{Mutex, MutexGuard},
	once_lock::OnceLock,
	stack::*
};
#[cfg(feature = "ndarray")]
mod ndarray;
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub use self::ndarray::ArrayExt;

/// Preloads the dynamic library at the given `path`.
///
/// Internally, some ONNX Runtime execution providers load DLLs with specific names (like `onnxruntime_providers_cuda`).
/// There is no dedicated API to configure where these libraries are loaded from; the `preload_dylib` function offers a
/// way to configure library paths without resorting to other tricks like `rpath`.
///
/// This is because most operating systems will keep a list of currently loaded dylibs in memory and fall back to
/// an already loaded dylib if its file name matches what is requested. If we explicitly load `libfoo.so` from an
/// absolute path, then any subsequent requests to load a library called `foo` will use the `libfoo.so` we already
/// loaded, instead of searching the system for a `foo` library.
///
/// See also [`crate::ep::cuda::preload_dylibs`], a helper that uses `preload_dylib` to load all required dependencies
/// of the [CUDA execution provider](crate::ep::CUDA).
///
/// ```
/// use std::env;
///
/// // Use the DirectML DLL we ship alongside the application.
/// // This needs to come before we use any other `ort` API.
/// let application_dir = current_exe().unwrap();
/// let _ = ort::util::preload_dylib(application_dir.join("DirectML.dll"));
/// ```
#[cfg_attr(docsrs, doc(cfg(any(feature = "preload-dylibs", feature = "load-dynamic"))))]
#[cfg(all(feature = "preload-dylibs", not(target_arch = "wasm32")))]
pub fn preload_dylib<P: libloading::AsFilename>(path: P) -> Result<(), libloading::Error> {
	let library = unsafe { libloading::Library::new(path) }?;
	// Do not run `FreeLibrary` so the library remains in the loaded modules list.
	let _ = ManuallyDrop::new(library);
	Ok(())
}

#[cfg(target_family = "windows")]
pub(crate) type OsCharArray = Vec<u16>;
#[cfg(not(target_family = "windows"))]
pub(crate) type OsCharArray = Vec<core::ffi::c_char>;

#[cfg(feature = "std")]
pub(crate) fn path_to_os_char(path: impl AsRef<std::path::Path>) -> OsCharArray {
	#[cfg(not(target_family = "windows"))]
	use core::ffi::c_char;
	#[cfg(target_family = "windows")]
	use std::os::windows::ffi::OsStrExt as _;

	let path = std::ffi::OsString::from(path.as_ref());
	#[cfg(target_family = "windows")]
	let path: Vec<u16> = path.encode_wide().chain(core::iter::once(0)).collect();
	#[cfg(not(target_family = "windows"))]
	let path: Vec<c_char> = path
		.as_encoded_bytes()
		.iter()
		.chain(core::iter::once(&b'\0'))
		.map(|b| *b as c_char)
		.collect();
	path
}

#[allow(unused)]
pub(crate) fn str_to_os_char(string: &str) -> OsCharArray {
	#[cfg(target_family = "windows")]
	let os_char = string.encode_utf16().chain(core::iter::once(0)).collect();
	#[cfg(not(target_family = "windows"))]
	let os_char = string
		.as_bytes()
		.iter()
		.copied()
		.chain(core::iter::once(0))
		.map(|b| b as c_char)
		.collect();
	os_char
}

#[cold]
#[inline]
#[doc(hidden)] // Used in the `ortsys!` macro, which is publicly exposed.
pub fn cold() {}

#[allow(unused)]
pub(crate) struct RunOnDrop<F: FnOnce()> {
	runner: ManuallyDrop<F>
}

impl<F: FnOnce()> Drop for RunOnDrop<F> {
	#[inline]
	fn drop(&mut self) {
		let runner = unsafe { ptr::read(&*self.runner) };
		runner()
	}
}

/// Runs the given closure at the end of the scope.
#[allow(unused)]
pub(crate) fn run_on_drop<F: FnOnce()>(f: F) -> RunOnDrop<F> {
	RunOnDrop { runner: ManuallyDrop::new(f) }
}

pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {
	if raw.is_null() {
		return Ok(String::new());
	}
	let c_string = unsafe { CStr::from_ptr(raw.cast_mut()).to_owned() };
	Ok(c_string.to_string_lossy().into())
}

/// A string allocated by an ONNX Runtime [`Allocator`].
pub(crate) struct AllocatedString<'a, 'p> {
	data: &'p str,
	allocator: &'a Allocator
}

static EMPTY_DANGLING: &str = "";

impl<'a, 'p> AllocatedString<'a, 'p> {
	pub unsafe fn from_ptr(ptr: *const c_char, allocator: &'a Allocator) -> Result<Self> {
		if ptr.is_null() {
			// string slices cannot have null pointers, so we use an empty string instead
			// (as a static so we can differentiate between allocated but zero-length and null pointer/non-allocated)
			return Ok(Self { data: EMPTY_DANGLING, allocator });
		}

		let c_string = unsafe { CStr::from_ptr(ptr.cast_mut()) };
		match c_string.to_str() {
			Ok(data) => Ok(Self { data, allocator }),
			Err(e) => {
				unsafe { allocator.free(ptr.cast_mut()) };
				Err(crate::Error::new(format!("string could not be converted to UTF-8: {e}")))
			}
		}
	}
}

impl<'p> Deref for AllocatedString<'_, 'p> {
	type Target = str;
	fn deref(&self) -> &Self::Target {
		self.data
	}
}

impl<'a, 'p> Drop for AllocatedString<'a, 'p> {
	fn drop(&mut self) {
		let ptr = self.data.as_ptr();
		// strings originally created from null pointers don't need freeing since nothing was allocated
		if !ptr::eq(ptr, EMPTY_DANGLING.as_ptr()) {
			unsafe { self.allocator.free(ptr.cast_mut()) };
		}
	}
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
	use alloc::{ffi::CString, sync::Arc};
	use core::ffi::CStr;
	use std::thread;

	use super::{MiniMap, Mutex, char_p_to_string, run_on_drop, with_cstr, with_cstr_ptr_array};

	#[test]
	fn test_mutex_sanity() {
		let mutex = Mutex::new(());
		for _ in 0..4 {
			drop(mutex.lock());
		}
	}

	#[test]
	fn test_mutex_threaded() {
		let mutex = Arc::new(Mutex::new(0usize));
		let threads = (0..4)
			.map(|_| {
				let mutex = Arc::clone(&mutex);
				thread::spawn(move || {
					for _ in 0..1000 {
						*mutex.lock() += 1;
					}
				})
			})
			.collect::<Vec<_>>();

		for t in threads {
			t.join().unwrap();
		}

		assert_eq!(*mutex.lock(), 4000);
	}

	#[test]
	fn test_run_on_drop() {
		let mut x = 0;
		{
			let _d = run_on_drop(|| {
				x = 42;
			});
		}
		assert_eq!(x, 42);
	}

	#[test]
	fn test_with_cstr() {
		let normal_str = "hello world";
		with_cstr(normal_str.as_bytes(), &|s| {
			assert_eq!(s.to_string_lossy(), normal_str);
			Ok(())
		})
		.unwrap();

		let longer_str = "HyperSuperUltraLongInputNameLikeAReallyLongNameSoLongInFactThatItDoesntFitOnTheStackAsSpecifiedInTheSTACK_CSTR_ARRAY_MAX_TOTAL_ConstantDefinedInUtilDotRsThisStringIsSoLongThatImStartingToRunOutOfThingsToSaySoIllJustPutZeros000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000hi0000000000000000000000000000000000000000000000000000000000000000000000000000000";
		with_cstr(longer_str.as_bytes(), &|s| {
			assert_eq!(s.to_string_lossy(), longer_str);
			Ok(())
		})
		.unwrap();
	}

	#[test]
	fn test_with_cstr_arr() {
		let normal_str = "hello world";
		let few_normal_strs = vec![normal_str; 4];
		with_cstr_ptr_array(&few_normal_strs, &|arr| {
			for &ptr in arr {
				let cstr = unsafe { CStr::from_ptr(ptr) };
				assert_eq!(cstr.to_string_lossy(), normal_str);
			}
			Ok(())
		})
		.unwrap();

		let many_normal_strs = vec![normal_str; 1000];
		with_cstr_ptr_array(&many_normal_strs, &|arr| {
			for &ptr in arr {
				let cstr = unsafe { CStr::from_ptr(ptr) };
				assert_eq!(cstr.to_string_lossy(), normal_str);
			}
			Ok(())
		})
		.unwrap();

		let longer_str = "HyperSuperUltraLongInputNameLikeAReallyLongNameSoLongInFactThatItDoesntFitOnTheStackAsSpecifiedInTheSTACK_CSTR_ARRAY_MAX_TOTAL_ConstantDefinedInUtilDotRsThisStringIsSoLongThatImStartingToRunOutOfThingsToSaySoIllJustPutZeros000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000hi0000000000000000000000000000000000000000000000000000000000000000000000000000000";
		let mut mixed_strs = vec![normal_str; 3];
		mixed_strs[1] = longer_str;
		with_cstr_ptr_array(&mixed_strs, &|arr| {
			for (i, &ptr) in arr.iter().enumerate() {
				let cstr = unsafe { CStr::from_ptr(ptr) };
				if i != 1 {
					assert_eq!(cstr.to_string_lossy(), normal_str);
				} else {
					assert_eq!(cstr.to_string_lossy(), longer_str);
				}
			}
			Ok(())
		})
		.unwrap();

		let many_longer_strs = vec![longer_str; 200];
		with_cstr_ptr_array(&many_longer_strs, &|arr| {
			for &ptr in arr {
				let cstr = unsafe { CStr::from_ptr(ptr) };
				assert_eq!(cstr.to_string_lossy(), longer_str);
			}
			Ok(())
		})
		.unwrap();
	}

	#[test]
	fn test_mini_map() {
		let mut map = MiniMap::<&'static str, u32>::new();
		map.insert("meaning", 42);
		assert_eq!(map.get("meaning"), Some(&42));
		assert_eq!(map.get("nothing"), None);

		for (k, v) in map.iter() {
			assert_eq!(*k, "meaning");
			assert_eq!(*v, 42);
		}

		map.insert("meaning", 24);
		map.insert("other", 21);
		assert_eq!(map.get("meaning"), Some(&24));
		assert_eq!(map.get("other"), Some(&21));
		assert_eq!(map.len(), 2);

		assert_eq!(map.drain().next(), Some(("meaning", 24)));
		assert_eq!(map.len(), 0);
	}

	#[test]
	fn test_char_p_to_string() {
		let s = CString::new("foo").unwrap_or_else(|_| unreachable!());
		let ptr = s.as_c_str().as_ptr();
		assert_eq!("foo", char_p_to_string(ptr).expect("failed to convert string"));
	}
}