ort 1.16.3

A Rust wrapper for ONNX Runtime 1.16 - Optimize and Accelerate Machine Learning Inferencing
Documentation
use std::ffi::{c_char, c_int, CStr, CString};

use super::{error::OrtResult, ortsys, sys, AllocatorType, MemType};
use crate::OrtError;

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AllocationDevice {
	CPU,
	CUDA,
	CUDAPinned,
	CANN,
	CANNPinned,
	DirectML,
	OpenVINOCPU,
	OpenVINOGPU
}

impl AllocationDevice {
	pub fn as_str(&self) -> &'static str {
		match self {
			Self::CPU => "Cpu",
			Self::CUDA => "Cuda",
			Self::CUDAPinned => "CudaPinned",
			Self::CANN => "Cann",
			Self::CANNPinned => "CannPinned",
			Self::DirectML => "Dml",
			Self::OpenVINOCPU => "OpenVINO_CPU",
			Self::OpenVINOGPU => "OpenVINO_GPU"
		}
	}
}

impl From<&str> for AllocationDevice {
	fn from(value: &str) -> Self {
		match value {
			"Cpu" => AllocationDevice::CPU,
			"CUDA_CPU" => AllocationDevice::CPU,
			"Cuda" => AllocationDevice::CUDA,
			"CudaPinned" => AllocationDevice::CUDAPinned,
			"Cann" => AllocationDevice::CANN,
			"CannPinned" => AllocationDevice::CANNPinned,
			"Dml" => AllocationDevice::DirectML,
			"OpenVINO_CPU" => AllocationDevice::OpenVINOCPU,
			"OpenVINO_GPU" => AllocationDevice::OpenVINOGPU,
			other => unimplemented!("not implemented `{other}`")
		}
	}
}

#[derive(Debug)]
pub struct MemoryInfo {
	pub(crate) ptr: *mut sys::OrtMemoryInfo,
	pub(crate) should_release: bool
}

impl MemoryInfo {
	#[tracing::instrument]
	pub fn new_cpu(allocator: AllocatorType, memory_type: MemType) -> OrtResult<Self> {
		let mut memory_info_ptr: *mut sys::OrtMemoryInfo = std::ptr::null_mut();
		ortsys![
			unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr);
			nonNull(memory_info_ptr)
		];
		Ok(Self {
			ptr: memory_info_ptr,
			should_release: true
		})
	}

	#[tracing::instrument]
	pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemType) -> OrtResult<Self> {
		let mut memory_info_ptr: *mut sys::OrtMemoryInfo = std::ptr::null_mut();
		let allocator_name = CString::new(allocation_device.as_str()).unwrap();
		ortsys![
			unsafe CreateMemoryInfo(allocator_name.as_ptr(), allocator_type.into(), device_id, memory_type.into(), &mut memory_info_ptr);
			nonNull(memory_info_ptr)
		];
		Ok(Self {
			ptr: memory_info_ptr,
			should_release: true
		})
	}

	#[allow(clippy::not_unsafe_ptr_arg_deref)]
	pub fn allocation_device(memory_info_ptr: *const sys::OrtMemoryInfo) -> OrtResult<AllocationDevice> {
		let mut name_ptr: *const c_char = std::ptr::null_mut();
		ortsys![unsafe MemoryInfoGetName(memory_info_ptr, &mut name_ptr) -> OrtError::CreateCpuMemoryInfo; nonNull(name_ptr)];
		let name: String = unsafe { CStr::from_ptr(name_ptr) }.to_string_lossy().to_string();
		Ok(AllocationDevice::from(name.as_str()))
	}
}

impl Drop for MemoryInfo {
	#[tracing::instrument]
	fn drop(&mut self) {
		if !self.ptr.is_null() && self.should_release {
			ortsys![unsafe ReleaseMemoryInfo(self.ptr)];
		}

		self.ptr = std::ptr::null_mut();
	}
}

#[cfg(test)]
mod tests {
	use test_log::test;

	use super::*;

	#[test]
	fn create_memory_info() {
		let memory_info = MemoryInfo::new_cpu(AllocatorType::Device, MemType::Default).unwrap();
		std::mem::drop(memory_info);
	}
}