ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
use alloc::{string::String, sync::Arc, vec::Vec};
use core::{
	ffi::{CStr, c_char, c_int},
	marker::PhantomData,
	mem,
	ptr::{self, NonNull}
};

use smallvec::SmallVec;

#[cfg(feature = "api-20")]
use crate::session::adapter::{Adapter, AdapterInner};
use crate::{
	AsPointer,
	error::Result,
	logging::LogLevel,
	ortsys,
	session::Outlet,
	util::{MiniMap, STACK_SESSION_OUTPUTS, with_cstr},
	value::{DynValue, Value, ValueTypeMarker}
};

/// Allows selecting/deselecting/preallocating the outputs of a [`Session`] inference call.
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{Session, RunOptions, OutputSelector}, memory::Allocator, value::Tensor};
/// # fn main() -> ort::Result<()> {
/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = Tensor::<f32>::new(&Allocator::default(), [1_usize, 64, 64, 3])?;
///
/// let output0 = session.outputs()[0].name();
/// let options = RunOptions::new()?.with_outputs(
/// 	// Disable all outputs...
/// 	OutputSelector::no_default()
/// 		// except for the first one...
/// 		.with(output0)
/// 		// and since this is a 2x upsampler model, pre-allocate the output to be twice as large.
/// 		.preallocate(output0, Tensor::<f32>::new(&Allocator::default(), [1_usize, 128, 128, 3])?)
/// );
///
/// // `outputs[0]` will be the tensor we just pre-allocated.
/// let outputs = session.run_with_options(ort::inputs![input], &options)?;
/// # 	Ok(())
/// # }
/// ```
///
/// [`Session`]: crate::session::Session
#[derive(Debug)]
pub struct OutputSelector {
	/// When `true`, every output declared by the session is selected unless it appears in `default_blocklist`.
	/// When `false`, only names in `allowlist` are selected.
	use_defaults: bool,
	/// Names removed from the session's default outputs. Only has an effect while `use_defaults` is `true`.
	default_blocklist: Vec<String>,
	/// Names explicitly requested via [`OutputSelector::with`]; appended after the (filtered) default outputs.
	allowlist: Vec<String>,
	/// Pre-allocated values keyed by output name, looked up for each selected output when resolving outputs.
	preallocated_outputs: MiniMap<String, Value>
}

impl Default for OutputSelector {
	/// Builds an [`OutputSelector`] in which every session output starts out enabled. Individual outputs can then be
	/// switched off with [`OutputSelector::without`].
	fn default() -> Self {
		let preallocated_outputs = MiniMap::new();
		let (default_blocklist, allowlist) = (Vec::new(), Vec::new());
		Self {
			use_defaults: true,
			default_blocklist,
			allowlist,
			preallocated_outputs
		}
	}
}

impl OutputSelector {
	/// Creates an [`OutputSelector`] that does not enable any outputs. Use [`OutputSelector::with`] to enable a
	/// specific output.
	pub fn no_default() -> Self {
		Self {
			use_defaults: false,
			..Default::default()
		}
	}

	/// Mark the output specified by the `name` for inclusion.
	pub fn with(mut self, name: impl Into<String>) -> Self {
		self.allowlist.push(name.into());
		self
	}

	/// Mark the output specified by `name` to be **excluded**. ONNX Runtime may prune some of the output node's
	/// ancestor nodes.
	pub fn without(mut self, name: impl Into<String>) -> Self {
		self.default_blocklist.push(name.into());
		self
	}

	/// Pre-allocates an output. Assuming the type & shape of the value matches what is expected by the model, the
	/// output value corresponding to `name` returned by the inference call will be the exact same value as the
	/// pre-allocated value.
	///
	/// **The same value will be reused as long as this [`OutputSelector`] and its parent [`RunOptions`] is used**, so
	/// if you use the same `RunOptions` across multiple runs with a preallocated value, the preallocated value will be
	/// overwritten upon each run.
	///
	/// This can improve performance if the size and type of the output is known, and does not change between runs, i.e.
	/// for an ODE or embeddings model.
	///
	/// ```
	/// # use std::sync::Arc;
	/// # use ort::{session::{Session, RunOptions, OutputSelector}, memory::Allocator, value::Tensor};
	/// # fn main() -> ort::Result<()> {
	/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
	/// let input = Tensor::<f32>::new(&Allocator::default(), [1_usize, 64, 64, 3])?;
	///
	/// let output0 = session.outputs()[0].name();
	/// let options = RunOptions::new()?.with_outputs(
	/// 	OutputSelector::default()
	/// 		.preallocate(output0, Tensor::<f32>::new(&Allocator::default(), [1_usize, 128, 128, 3])?)
	/// );
	///
	/// let outputs = session.run_with_options(ort::inputs![input], &options)?;
	/// # 	Ok(())
	/// # }
	/// ```
	pub fn preallocate<T: ValueTypeMarker>(mut self, name: impl Into<String>, value: Value<T>) -> Self {
		self.preallocated_outputs.insert(name.into(), value.into_dyn());
		self
	}

	pub(crate) fn resolve_outputs<'a, 's: 'a>(
		&'a self,
		outputs: &'s [Outlet]
	) -> (SmallVec<[&'a str; STACK_SESSION_OUTPUTS]>, SmallVec<[Option<DynValue>; STACK_SESSION_OUTPUTS]>) {
		if self.use_defaults { outputs.iter() } else { [].iter() }
			.map(|o| o.name())
			.filter(|n| !self.default_blocklist.iter().any(|e| e == n))
			.chain(self.allowlist.iter().map(|x| x.as_str()))
			.map(|n| (n, self.preallocated_outputs.get(n).map(DynValue::clone_of)))
			.unzip()
	}
}

/// Types that specify whether a [`RunOptions`] was configured with an [`OutputSelector`].
///
/// This is a type-level marker only; it has no methods. It is used as the type parameter of [`RunOptions`].
pub trait SelectedOutputMarker {}
/// Marks that a [`RunOptions`] was not configured with an [`OutputSelector`].
///
/// This is the default type parameter of [`RunOptions`], and the only one for which [`RunOptions`] is `Sync`.
pub struct NoSelectedOutputs;
impl SelectedOutputMarker for NoSelectedOutputs {}
/// Marks that a [`RunOptions`] was configured with an [`OutputSelector`].
///
/// Produced by [`RunOptions::with_outputs`].
pub struct HasSelectedOutputs;
impl SelectedOutputMarker for HasSelectedOutputs {}

/// The shared state behind a [`RunOptions`], independent of its output-selection marker type.
#[derive(Debug)]
pub(crate) struct UntypedRunOptions {
	/// The underlying ONNX Runtime run options handle; released when this struct is dropped.
	pub(crate) ptr: NonNull<ort_sys::OrtRunOptions>,
	/// The output selection applied to runs using these options.
	pub(crate) outputs: OutputSelector,
	/// Shared references to the adapters registered via `add_adapter`.
	// NOTE(review): presumably held so the adapters outlive the run options that reference them — confirm.
	#[cfg(feature = "api-20")]
	adapters: Vec<Arc<AdapterInner>>
}

impl UntypedRunOptions {
	/// Sets the termination flag on the underlying run options via `RunOptionsSetTerminate`.
	///
	/// Only requires `&self`, so it can be called while the options are shared.
	///
	/// # Errors
	/// Returns an error if the ONNX Runtime call fails.
	pub fn terminate(&self) -> Result<()> {
		ortsys![unsafe RunOptionsSetTerminate(self.ptr.as_ptr())?];
		Ok(())
	}
}

// SAFETY (NOTE(review): justification is taken from the ONNX Runtime C API documentation linked below — confirm it
// covers moving an `OrtRunOptions` handle between threads):
// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675
unsafe impl Send for UntypedRunOptions {}

impl Drop for UntypedRunOptions {
	/// Releases the underlying ONNX Runtime run options handle and logs the drop.
	fn drop(&mut self) {
		ortsys![unsafe ReleaseRunOptions(self.ptr.as_ptr())];
		// Only the pointer value is passed to the logging macro here; the handle itself has already been released.
		crate::logging::drop!(RunOptions, self.ptr);
	}
}

/// Allows for finer control over session inference.
///
/// [`RunOptions`] provides three main features:
/// - **Run tagging**: Each individual session run can have a uniquely identifiable tag attached with
///   [`RunOptions::set_tag`], which will show up in logs. This can be especially useful for debugging
///   performance/errors in inference servers.
/// - **Termination**: Allows for terminating an inference call from another thread; when [`RunOptions::terminate`] is
///   called, any sessions currently running under that [`RunOptions`] instance will halt graph execution as soon as the
///   termination signal is received. This allows for [`Session::run_async`]'s cancel-safety.
/// - **Output specification**: Certain session outputs can be [disabled](`OutputSelector::without`) or
///   [pre-allocated](`OutputSelector::preallocate`). Disabling an output might mean ONNX Runtime will not execute parts
///   of the graph that are only used by that output. Pre-allocation can reduce expensive re-allocations by allowing you
///   to use the same memory across runs.
///
/// [`RunOptions`] can be passed to most places where a session can be inferred, e.g.
/// [`Session::run_with_options`], [`Session::run_async`],
/// [`Session::run_binding_with_options`]. Some of these patterns (notably `IoBinding`) do not accept
/// [`OutputSelector`], hence [`RunOptions`] contains an additional type parameter that marks whether or not outputs
/// have been selected.
///
/// [`Session::run_async`]: crate::session::Session::run_async
/// [`Session::run_with_options`]: crate::session::Session::run_with_options
/// [`Session::run_binding_with_options`]: crate::session::Session::run_binding_with_options
#[derive(Debug)]
pub struct RunOptions<O: SelectedOutputMarker = NoSelectedOutputs> {
	/// Shared, untyped state (native handle, selected outputs, adapters).
	pub(crate) inner: Arc<UntypedRunOptions>,
	/// Zero-sized marker recording whether outputs have been selected.
	_marker: PhantomData<O>
}

// `RunOptions` may be sent to another thread regardless of whether outputs were selected.
// NOTE(review): this presumably relies on `UntypedRunOptions` being safe to use from whichever thread owns it — confirm.
unsafe impl<O: SelectedOutputMarker> Send for RunOptions<O> {}
// Only allow `Sync` if we don't have (potentially pre-allocated) outputs selected.
// Allowing `Sync` here would mean a single pre-allocated `Value` could be mutated simultaneously in different threads -
// a brazen crime against crabkind.
unsafe impl Sync for RunOptions<NoSelectedOutputs> {}

impl RunOptions {
	/// Creates a new [`RunOptions`] struct.
	///
	/// The returned options carry [`OutputSelector::default`] and are marked as having no outputs selected.
	///
	/// # Errors
	/// Returns an error if ONNX Runtime fails to create the underlying run options object.
	pub fn new() -> Result<RunOptions<NoSelectedOutputs>> {
		let mut ptr: *mut ort_sys::OrtRunOptions = ptr::null_mut();
		// NOTE(review): the `nonNull(ptr)` clause presumably validates the pointer and rebinds `ptr` as a `NonNull`,
		// which is what the `ptr` field below expects — confirm against the `ortsys!` macro definition.
		ortsys![unsafe CreateRunOptions(&mut ptr)?; nonNull(ptr)];
		crate::logging::create!(RunOptions, ptr);
		Ok(RunOptions {
			inner: Arc::new(UntypedRunOptions {
				ptr,
				outputs: OutputSelector::default(),
				#[cfg(feature = "api-20")]
				adapters: Vec::new()
			}),
			_marker: PhantomData
		})
	}
}

impl<O: SelectedOutputMarker> RunOptions<O> {
	/// Select/deselect/preallocate outputs for this run.
	///
	/// See [`OutputSelector`] for more details.
	///
	/// ```
	/// # use std::sync::Arc;
	/// # use ort::{session::{Session, RunOptions, OutputSelector}, memory::Allocator, value::Tensor};
	/// # fn main() -> ort::Result<()> {
	/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
	/// let input = Tensor::<f32>::new(&Allocator::default(), [1_usize, 64, 64, 3])?;
	///
	/// let output0 = session.outputs()[0].name();
	/// let options = RunOptions::new()?.with_outputs(
	/// 	// Disable all outputs...
	/// 	OutputSelector::no_default()
	/// 		// except for the first one...
	/// 		.with(output0)
	/// 		// and since this is a 2x upsampler model, pre-allocate the output to be twice as large.
	/// 		.preallocate(output0, Tensor::<f32>::new(&Allocator::default(), [1_usize, 128, 128, 3])?)
	/// );
	///
	/// // `outputs[0]` will be the tensor we just pre-allocated.
	/// let outputs = session.run_with_options(ort::inputs![input], &options)?;
	/// # 	Ok(())
	/// # }
	/// ```
	pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions<HasSelectedOutputs> {
		let Some(inner) = Arc::get_mut(&mut self.inner) else {
			panic!("Expected RunOptions to have exclusive access");
		};
		inner.outputs = outputs;
		unsafe { mem::transmute(self) }
	}

	/// Sets a tag to identify this run in logs.
	pub fn with_tag(mut self, tag: impl AsRef<str>) -> Result<Self> {
		self.set_tag(tag).map(|_| self)
	}

	/// Sets a tag to identify this run in logs.
	pub fn set_tag(&mut self, tag: impl AsRef<str>) -> Result<()> {
		with_cstr(tag.as_ref().as_bytes(), &|tag| {
			ortsys![unsafe RunOptionsSetRunTag(self.inner.ptr.as_ptr(), tag.as_ptr())?];
			Ok(())
		})
	}

	pub fn tag(&self) -> Result<&str> {
		let mut tag_ptr: *const c_char = ptr::null();
		ortsys![unsafe RunOptionsGetRunTag(self.inner.ptr.as_ptr(), &mut tag_ptr)?];
		Ok(unsafe { CStr::from_ptr(tag_ptr) }.to_str()?)
	}

	/// Sets the termination flag for the runs associated with this [`RunOptions`].
	///
	/// This function returns immediately (it does not wait for the session run to terminate). The run will terminate as
	/// soon as it is able to.
	///
	/// ```no_run
	/// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough
	/// # use std::sync::Arc;
	/// # use ort::{session::{Session, RunOptions, OutputSelector}, value::Value};
	/// # fn main() -> ort::Result<()> {
	/// # 	let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
	/// # 	let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
	/// let run_options = Arc::new(RunOptions::new()?);
	///
	/// let run_options_ = Arc::clone(&run_options);
	/// std::thread::spawn(move || {
	/// 	let _ = run_options_.terminate();
	/// });
	///
	/// let res = session.run_with_options(ort::inputs![input], &*run_options);
	/// // upon termination, the session will return an `Error::SessionRun` error.`
	/// assert_eq!(
	/// 	&res.unwrap_err().to_string(),
	/// 	"Failed to run inference on model: Exiting due to terminate flag being set to true."
	/// );
	/// # 	Ok(())
	/// # }
	/// ```
	pub fn terminate(&self) -> Result<()> {
		self.inner.terminate()
	}

	/// Resets the termination flag for the runs associated with [`RunOptions`].
	///
	/// ```no_run
	/// # use std::sync::Arc;
	/// # use ort::{session::{Session, RunOptions, OutputSelector}, value::Value};
	/// # fn main() -> ort::Result<()> {
	/// # 	let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
	/// # 	let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
	/// let run_options = Arc::new(RunOptions::new()?);
	///
	/// let run_options_ = Arc::clone(&run_options);
	/// std::thread::spawn(move || {
	/// 	let _ = run_options_.terminate();
	/// 	// ...oops, didn't mean to do that
	/// 	let _ = run_options_.unterminate();
	/// });
	///
	/// let res = session.run_with_options(ort::inputs![input], &*run_options);
	/// assert!(res.is_ok());
	/// # 	Ok(())
	/// # }
	/// ```
	pub fn unterminate(&self) -> Result<()> {
		ortsys![unsafe RunOptionsUnsetTerminate(self.inner.ptr.as_ptr())?];
		Ok(())
	}

	/// Adds a custom configuration option to the `RunOptions`.
	///
	/// This can be used to, for example, configure the graph ID when using compute graphs with an execution provider
	/// like CUDA:
	/// ```no_run
	/// # use std::sync::Arc;
	/// # use ort::session::RunOptions;
	/// # fn main() -> ort::Result<()> {
	/// let mut run_options = RunOptions::new()?;
	/// run_options.add_config_entry("gpu_graph_id", "1")?;
	/// # 	Ok(())
	/// # }
	/// ```
	pub fn add_config_entry(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
		with_cstr(key.as_ref().as_bytes(), &|key| {
			with_cstr(value.as_ref().as_bytes(), &|value| {
				ortsys![unsafe AddRunConfigEntry(self.inner.ptr.as_ptr(), key.as_ptr(), value.as_ptr())?];
				Ok(())
			})
		})
	}

	#[cfg(feature = "api-20")]
	#[cfg_attr(docsrs, doc(cfg(feature = "api-20")))]
	pub fn add_adapter(&mut self, adapter: &Adapter) -> Result<()> {
		let Some(inner) = Arc::get_mut(&mut self.inner) else {
			panic!("Expected RunOptions to have exclusive access");
		};
		ortsys![unsafe RunOptionsAddActiveLoraAdapter(inner.ptr.as_ptr(), adapter.ptr())?];
		inner.adapters.push(Arc::clone(&adapter.inner));
		Ok(())
	}

	pub fn set_log_level(&mut self, level: LogLevel) -> Result<()> {
		ortsys![unsafe RunOptionsSetRunLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _)?];
		Ok(())
	}

	pub fn log_level(&self) -> Result<LogLevel> {
		let mut log_level = ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE;
		ortsys![unsafe RunOptionsGetRunLogSeverityLevel(self.ptr(), &mut log_level as *mut ort_sys::OrtLoggingLevel as *mut _)?];
		Ok(LogLevel::from(log_level))
	}

	pub fn set_log_verbosity(&mut self, verbosity: c_int) -> Result<()> {
		ortsys![unsafe RunOptionsSetRunLogVerbosityLevel(self.ptr_mut(), verbosity)?];
		Ok(())
	}

	pub fn log_verbosity(&self) -> Result<i32> {
		let mut verbosity = 0;
		ortsys![unsafe RunOptionsGetRunLogVerbosityLevel(self.ptr(), &mut verbosity)?];
		Ok(verbosity)
	}

	pub fn disable_device_sync(&mut self) -> Result<()> {
		self.add_config_entry("disable_synchronize_execution_providers", "1")
	}
}

impl<O: SelectedOutputMarker> AsPointer for RunOptions<O> {
	type Sys = ort_sys::OrtRunOptions;

	fn ptr(&self) -> *const Self::Sys {
		self.inner.ptr.as_ptr()
	}
}