use alloc::sync::Arc;
use core::{
ffi::{CStr, c_char, c_int, c_void},
mem,
ptr::{self, NonNull},
str
};
use crate::{
AsPointer,
error::Result,
ortsys,
session::{Session, SharedSessionInner}
};
/// Handle to an ONNX Runtime allocator (`OrtAllocator`) used to allocate and free raw memory blocks.
///
/// Dropping an `Allocator` releases the underlying `OrtAllocator`, unless it was obtained through
/// [`Allocator::default`], in which case the pointer is left untouched.
#[derive(Debug)]
pub struct Allocator {
/// Non-null pointer to the underlying `OrtAllocator`.
ptr: NonNull<ort_sys::OrtAllocator>,
/// `true` when this wraps the allocator returned by `GetAllocatorWithDefaultOptions`; such allocators are not released on drop.
is_default: bool,
/// Memory info describing this allocator; exposed through [`Allocator::memory_info`].
info: MemoryInfo,
/// Shared session state captured by [`Allocator::new`]; `None` for allocators not created from a session.
/// Presumably held so the session outlives the allocator - TODO confirm.
_session_inner: Option<Arc<SharedSessionInner>>
}
// `NonNull` is not `Send`, so `Send` is asserted manually here.
// NOTE(review): no `Sync` impl is provided - confirm that this is intentional.
unsafe impl Send for Allocator {}
impl Allocator {
pub(crate) unsafe fn from_raw(ptr: NonNull<ort_sys::OrtAllocator>) -> Allocator {
let mut memory_info_ptr = ptr::null();
ortsys![unsafe AllocatorGetInfo(ptr.as_ptr(), &mut memory_info_ptr).expect("Failed to get memory info"); nonNull(memory_info_ptr)];
Allocator {
ptr,
is_default: false,
info: MemoryInfo::from_raw(memory_info_ptr, false),
_session_inner: None
}
}
pub fn alloc<T>(&self, len: usize) -> Result<AllocatedBlock<'_>> {
let mut ptr = ptr::null_mut();
ortsys![unsafe AllocatorAlloc(self.ptr.as_ptr(), (len * mem::size_of::<T>()) as _, &mut ptr)?; nonNull(ptr)];
crate::logging::create!(AllocatedBlock, ptr);
Ok(AllocatedBlock { ptr, allocator: self })
}
pub unsafe fn free<T>(&self, ptr: *mut T) {
ortsys![unsafe AllocatorFree(self.ptr.as_ptr(), ptr.cast()).expect("Failed to free")];
}
pub fn memory_info(&self) -> &MemoryInfo {
&self.info
}
pub fn new(session: &Session, memory_info: MemoryInfo) -> Result<Self> {
let mut ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut ptr)?; nonNull(ptr)];
crate::logging::create!(Allocator, ptr);
Ok(Self {
ptr,
is_default: false,
info: memory_info,
_session_inner: Some(session.inner())
})
}
}
impl Default for Allocator {
    /// Returns a handle to the allocator provided by `GetAllocatorWithDefaultOptions`.
    ///
    /// The handle is flagged as the default allocator, so dropping it does not release the underlying pointer.
    ///
    /// # Panics
    /// Panics if the default allocator or its memory info cannot be retrieved.
    fn default() -> Self {
        let mut raw_allocator: *mut ort_sys::OrtAllocator = ptr::null_mut();
        ortsys![
            unsafe GetAllocatorWithDefaultOptions(&mut raw_allocator)
                .expect("Failed to get default allocator");
            nonNull(raw_allocator)
        ];

        let mut raw_info = ptr::null();
        ortsys![unsafe AllocatorGetInfo(raw_allocator.as_ptr(), &mut raw_info).expect("Failed to get memory info"); nonNull(raw_info)];

        // `should_release = false`: the memory info obtained from the allocator is not released by the wrapper.
        let info = MemoryInfo::from_raw(raw_info, false);
        Self {
            ptr: raw_allocator,
            is_default: true,
            info,
            _session_inner: None
        }
    }
}
impl AsPointer for Allocator {
    type Sys = ort_sys::OrtAllocator;

    /// Returns the raw `OrtAllocator` pointer wrapped by this allocator.
    fn ptr(&self) -> *const Self::Sys {
        self.ptr.as_ptr() as *const Self::Sys
    }
}
impl Drop for Allocator {
    /// Releases the underlying `OrtAllocator`, except for handles flagged as the default allocator.
    fn drop(&mut self) {
        // Handles to the default allocator are never released (nor logged as dropped).
        if self.is_default {
            return;
        }
        ortsys![unsafe ReleaseAllocator(self.ptr.as_ptr())];
        crate::logging::drop!(Allocator, self.ptr);
    }
}
/// A block of memory obtained from an [`Allocator`]; it is freed through that allocator when dropped.
pub struct AllocatedBlock<'a> {
/// Non-null pointer to the start of the allocated memory.
ptr: NonNull<c_void>,
/// The allocator that produced this block and that is used to free it on drop.
allocator: &'a Allocator
}
impl<'a> AllocatedBlock<'a> {
pub fn as_ptr(&self) -> *const c_void {
self.ptr.as_ptr()
}
pub fn as_mut_ptr(&mut self) -> *mut c_void {
self.ptr.as_ptr()
}
pub fn allocator(&self) -> &'a Allocator {
self.allocator
}
#[must_use = "the returned pointer must be freed with the allocator that created it"]
pub fn into_raw(self) -> *mut c_void {
let ptr = self.ptr;
mem::forget(self);
ptr.as_ptr()
}
}
impl Drop for AllocatedBlock<'_> {
    /// Returns the block's memory to the allocator it borrows.
    fn drop(&mut self) {
        let raw = self.ptr.as_ptr();
        // SAFETY (assumed): `raw` is the pointer stored when this block was created and has not been freed yet,
        // since `into_raw` bypasses this destructor.
        unsafe { self.allocator.free(raw) };
        crate::logging::drop!(AllocatedBlock, self.ptr);
    }
}
/// Name of the device on which memory is allocated.
///
/// The inner string always ends with a NUL byte so that its pointer can be handed to C APIs as a C string;
/// use [`AllocationDevice::as_str`] to get the name without the terminator.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct AllocationDevice(&'static str);

impl AllocationDevice {
    pub const CPU: AllocationDevice = AllocationDevice("Cpu\0");
    pub const CUDA: AllocationDevice = AllocationDevice("Cuda\0");
    pub const CUDA_PINNED: AllocationDevice = AllocationDevice("CudaPinned\0");
    pub const CANN: AllocationDevice = AllocationDevice("Cann\0");
    pub const CANN_PINNED: AllocationDevice = AllocationDevice("CannPinned\0");
    pub const DIRECTML: AllocationDevice = AllocationDevice("DML\0");
    pub const HIP: AllocationDevice = AllocationDevice("Hip\0");
    pub const HIP_PINNED: AllocationDevice = AllocationDevice("HipPinned\0");
    pub const OPENVINO_CPU: AllocationDevice = AllocationDevice("OpenVINO_CPU\0");
    pub const OPENVINO_GPU: AllocationDevice = AllocationDevice("OpenVINO_GPU\0");
    pub const QNN_HTP_SHARED: AllocationDevice = AllocationDevice("QnnHtpShared\0");
    pub const WEBGPU_BUFFER: AllocationDevice = AllocationDevice("WebGPU_Buffer\0");

    /// Returns the device name without its trailing NUL terminator.
    pub fn as_str(&self) -> &'static str {
        // The last byte of the inner string is always the NUL terminator; slice it off.
        &self.0[..self.0.len() - 1]
    }
}

impl PartialEq<str> for AllocationDevice {
    /// Compares the device name (without the NUL terminator) against `other`.
    fn eq(&self, other: &str) -> bool {
        // Comparing `self.0` directly would include the trailing NUL and never match a plain name like "Cpu".
        self.as_str() == other
    }
}
/// Kind of allocator, mirroring ONNX Runtime's `OrtAllocatorType`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AllocatorType {
/// Corresponds to `OrtDeviceAllocator`.
Device,
/// Corresponds to `OrtArenaAllocator`.
Arena
}
impl From<AllocatorType> for ort_sys::OrtAllocatorType {
    /// Converts the safe allocator type into its raw ONNX Runtime counterpart.
    fn from(val: AllocatorType) -> Self {
        match val {
            AllocatorType::Arena => Self::OrtArenaAllocator,
            AllocatorType::Device => Self::OrtDeviceAllocator
        }
    }
}
/// Memory type of an allocation, mirroring ONNX Runtime's `OrtMemType`.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum MemoryType {
    /// Corresponds to `OrtMemTypeCPUInput`.
    CPUInput,
    /// Corresponds to `OrtMemTypeCPUOutput`.
    CPUOutput,
    /// Corresponds to `OrtMemTypeDefault`; this is also the `Default` value of the enum.
    #[default]
    Default
}

impl MemoryType {
    /// Alias for [`MemoryType::CPUOutput`].
    pub const CPU: Self = Self::CPUOutput;
}
impl From<MemoryType> for ort_sys::OrtMemType {
    /// Converts the safe memory type into its raw ONNX Runtime counterpart.
    fn from(val: MemoryType) -> Self {
        match val {
            MemoryType::Default => Self::OrtMemTypeDefault,
            MemoryType::CPUInput => Self::OrtMemTypeCPUInput,
            MemoryType::CPUOutput => Self::OrtMemTypeCPUOutput
        }
    }
}
impl From<ort_sys::OrtMemType> for MemoryType {
    /// Converts a raw ONNX Runtime memory type into the safe enum.
    fn from(value: ort_sys::OrtMemType) -> Self {
        match value {
            ort_sys::OrtMemType::OrtMemTypeDefault => Self::Default,
            ort_sys::OrtMemType::OrtMemTypeCPUInput => Self::CPUInput,
            ort_sys::OrtMemType::OrtMemTypeCPUOutput => Self::CPUOutput
        }
    }
}
/// Kind of device a memory region lives on, mirroring ONNX Runtime's `OrtMemoryInfoDeviceType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::upper_case_acronyms)]
pub enum DeviceType {
/// Corresponds to `OrtMemoryInfoDeviceType_CPU`.
CPU,
/// Corresponds to `OrtMemoryInfoDeviceType_GPU`.
GPU,
/// Corresponds to the raw `OrtMemoryInfoDeviceType_FPGA` value.
NPU
}
impl From<DeviceType> for ort_sys::OrtMemoryInfoDeviceType {
    /// Converts the safe device type into its raw ONNX Runtime counterpart.
    fn from(value: DeviceType) -> Self {
        match value {
            DeviceType::NPU => Self::OrtMemoryInfoDeviceType_FPGA,
            DeviceType::GPU => Self::OrtMemoryInfoDeviceType_GPU,
            DeviceType::CPU => Self::OrtMemoryInfoDeviceType_CPU
        }
    }
}
impl From<ort_sys::OrtMemoryInfoDeviceType> for DeviceType {
    /// Converts a raw ONNX Runtime device type into the safe enum.
    fn from(value: ort_sys::OrtMemoryInfoDeviceType) -> Self {
        match value {
            ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA => Self::NPU,
            ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU => Self::GPU,
            ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU => Self::CPU
        }
    }
}
/// Wrapper around an ONNX Runtime `OrtMemoryInfo` pointer.
///
/// Its accessors expose the allocation device name, device ID, allocator type, memory type and device type.
#[derive(Debug)]
pub struct MemoryInfo {
/// Non-null pointer to the underlying `OrtMemoryInfo`.
ptr: NonNull<ort_sys::OrtMemoryInfo>,
/// Whether `ReleaseMemoryInfo` is called on `ptr` when this value is dropped.
should_release: bool
}
impl MemoryInfo {
pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result<Self> {
let mut ptr: *mut ort_sys::OrtMemoryInfo = ptr::null_mut();
ortsys![
unsafe CreateMemoryInfo(allocation_device.as_str().as_ptr().cast(), allocator_type.into(), device_id, memory_type.into(), &mut ptr)?;
nonNull(ptr)
];
crate::logging::create!(MemoryInfo, ptr);
Ok(Self { ptr, should_release: true })
}
pub(crate) unsafe fn from_value(value_ptr: NonNull<ort_sys::OrtValue>) -> Option<Self> {
let mut is_tensor = 0;
ortsys![unsafe IsTensor(value_ptr.as_ptr(), &mut is_tensor).expect("infallible")];
if is_tensor != 0 {
let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = ptr::null_mut();
ortsys![unsafe GetTensorMemoryInfo(value_ptr.as_ptr(), &mut memory_info_ptr).expect("infallible"); nonNull(memory_info_ptr)];
Some(Self::from_raw(memory_info_ptr, false))
} else {
None
}
}
pub(crate) fn from_raw(ptr: NonNull<ort_sys::OrtMemoryInfo>, should_release: bool) -> Self {
MemoryInfo { ptr, should_release }
}
pub fn memory_type(&self) -> MemoryType {
let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault;
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type).expect("infallible")];
MemoryType::from(raw_type)
}
pub fn allocator_type(&self) -> AllocatorType {
let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator;
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type).expect("infallible")];
match raw_type {
ort_sys::OrtAllocatorType::OrtArenaAllocator => AllocatorType::Arena,
ort_sys::OrtAllocatorType::OrtDeviceAllocator => AllocatorType::Device,
_ => unreachable!()
}
}
pub fn allocation_device(&self) -> AllocationDevice {
let mut name_ptr: *const c_char = ptr::null_mut();
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr).expect("infallible"); nonNull(name_ptr)];
let name = unsafe { CStr::from_ptr(name_ptr.as_ptr()) };
AllocationDevice(core::str::from_utf8(name.to_bytes_with_nul()).expect("invalid allocation device name"))
}
pub fn device_id(&self) -> i32 {
let mut raw: ort_sys::c_int = 0;
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw).expect("infallible")];
raw as _
}
pub fn device_type(&self) -> DeviceType {
let mut raw: ort_sys::OrtMemoryInfoDeviceType = ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU;
ortsys![unsafe MemoryInfoGetDeviceType(self.ptr.as_ptr(), &mut raw)];
raw.into()
}
pub fn is_cpu_accessible(&self) -> bool {
#[cfg(feature = "api-23")]
return self.device_type() == DeviceType::CPU
|| ortsys![unsafe MemoryInfoGetDeviceMemType(self.ptr.as_ptr())] == ort_sys::OrtDeviceMemoryType::OrtDeviceMemoryType_HOST_ACCESSIBLE;
#[cfg(not(feature = "api-23"))]
return self.device_type() == DeviceType::CPU;
}
}
impl Default for MemoryInfo {
    /// Creates a memory info for CPU device 0 with the device allocator and the default memory type.
    ///
    /// # Panics
    /// Panics if the memory info cannot be created.
    fn default() -> Self {
        let created = Self::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default);
        created.expect("failed to create default memory info")
    }
}
impl Clone for MemoryInfo {
fn clone(&self) -> Self {
MemoryInfo::new(self.allocation_device(), self.device_id(), self.allocator_type(), self.memory_type()).expect("failed to clone memory info")
}
}
impl PartialEq<MemoryInfo> for MemoryInfo {
    /// Compares two memory infos with `CompareMemoryInfo`; they are equal when it reports `0`.
    fn eq(&self, other: &MemoryInfo) -> bool {
        let mut comparison = 0;
        ortsys![unsafe CompareMemoryInfo(self.ptr.as_ptr(), other.ptr.as_ptr(), &mut comparison).expect("infallible")];
        comparison == 0
    }
}
impl AsPointer for MemoryInfo {
    type Sys = ort_sys::OrtMemoryInfo;

    /// Returns the raw `OrtMemoryInfo` pointer wrapped by this value.
    fn ptr(&self) -> *const Self::Sys {
        self.ptr.as_ptr() as *const Self::Sys
    }
}
impl Drop for MemoryInfo {
    /// Releases the underlying `OrtMemoryInfo` when this value was created with `should_release = true`.
    fn drop(&mut self) {
        // Memory infos wrapped with `should_release = false` are left untouched (and not logged as dropped).
        if !self.should_release {
            return;
        }
        ortsys![unsafe ReleaseMemoryInfo(self.ptr.as_ptr())];
        crate::logging::drop!(MemoryInfo, self.ptr);
    }
}
#[cfg(test)]
mod tests {
    use super::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType};

    /// Memory infos built from identical parameters compare equal; different devices compare unequal.
    #[test]
    fn test_memory_info_eq() -> crate::Result<()> {
        let cuda_first = MemoryInfo::new(AllocationDevice::CUDA, 1, AllocatorType::Device, MemoryType::Default)?;
        let cuda_second = MemoryInfo::new(AllocationDevice::CUDA, 1, AllocatorType::Device, MemoryType::Default)?;
        assert_eq!(cuda_first, cuda_second);

        let cpu = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
        assert_ne!(cuda_first, cpu);

        Ok(())
    }

    /// Pinned CUDA memory and CPU memory are CPU-accessible; plain CUDA memory is not.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cpu_accessible() -> crate::Result<()> {
        let pinned = MemoryInfo::new(AllocationDevice::CUDA_PINNED, 0, AllocatorType::Device, MemoryType::Default)?;
        assert!(pinned.is_cpu_accessible());

        let cuda = MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?;
        assert!(!cuda.is_cpu_accessible());

        let cpu = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
        assert!(cpu.is_cpu_accessible());

        Ok(())
    }
}