ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
use super::{ExecutionProvider, RegisterError};
use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};

/// The default CPU execution provider, powered by MLAS.
///
/// Its only setting is whether the CPU memory arena allocator is used. The derived
/// [`Default`] leaves the arena **disabled**; opt in with [`CPU::with_arena_allocator`].
#[derive(Debug, Default, Clone)]
pub struct CPU {
	/// `true` if registration should enable the CPU memory arena, `false` to disable it.
	use_arena: bool,
}

// Expands the parent module's shared `impl_ep!` boilerplate for `CPU`.
// NOTE(review): presumably this supplies `build()` (used in the `with_arena_allocator`
// doc example) and related conversions — confirm against the macro definition.
super::impl_ep!(CPU);

impl CPU {
	/// Chooses whether the CPU memory arena allocator is used.
	///
	/// Pass `true` to turn the arena on or `false` to turn it off. A freshly constructed
	/// `CPU::default()` has it turned off. Calling this repeatedly keeps only the last value.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CPU::default().with_arena_allocator(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_arena_allocator(self, enable: bool) -> Self {
		let mut provider = self;
		provider.use_arena = enable;
		provider
	}
}

impl ExecutionProvider for CPU {
	/// The identifier of this execution provider.
	fn name(&self) -> &'static str {
		"CPUExecutionProvider"
	}

	/// Unconditionally `true`; there is no platform restriction for the CPU provider.
	fn supported_by_platform(&self) -> bool {
		true
	}

	/// The CPU execution provider is always available, so this never fails and always
	/// reports `true`.
	fn is_available(&self) -> Result<bool> {
		Ok(true)
	}

	/// Applies this provider's arena setting to the session options behind `session_builder`.
	///
	/// Both branches make an explicit call. As a result, `use_arena == false` (the derived
	/// `Default`) actively disables the CPU memory arena rather than leaving ONNX Runtime's
	/// own default untouched.
	///
	/// # Errors
	/// Propagates any error reported by `EnableCpuMemArena` / `DisableCpuMemArena`.
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		let options = session_builder.ptr_mut();
		if self.use_arena {
			ortsys![unsafe EnableCpuMemArena(options)?];
		} else {
			ortsys![unsafe DisableCpuMemArena(options)?];
		}
		Ok(())
	}
}