use crate::sys;
use crate::Error;
use std::ffi::{CStr, CString};
use std::ptr;
use std::sync::Arc;
/// Returns the number of physical devices reported by `oidnGetNumPhysicalDevices`.
pub fn num_physical_devices() -> i32 {
    // SAFETY: the call takes no arguments, so there are no pointer or
    // lifetime requirements to uphold on our side.
    unsafe {
        sys::oidnGetNumPhysicalDevices()
    }
}
/// Reads the boolean parameter `name` of the physical device `physical_device_id`.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn get_physical_device_bool(physical_device_id: i32, name: &str) -> bool {
    let key = CString::new(name).unwrap();
    let key_ptr = key.as_ptr();
    // SAFETY: `key_ptr` points into `key`, a NUL-terminated buffer that stays
    // alive until the end of this function, i.e. for the whole call.
    unsafe { sys::oidnGetPhysicalDeviceBool(physical_device_id, key_ptr) }
}
/// Reads the integer parameter `name` of the physical device `physical_device_id`.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn get_physical_device_int(physical_device_id: i32, name: &str) -> i32 {
    let key = CString::new(name).unwrap();
    let key_ptr = key.as_ptr();
    // SAFETY: `key_ptr` points into `key`, a NUL-terminated buffer that stays
    // alive until the end of this function, i.e. for the whole call.
    unsafe { sys::oidnGetPhysicalDeviceInt(physical_device_id, key_ptr) }
}
/// Reads the string parameter `name` of the physical device `physical_device_id`.
///
/// Returns `None` when the library hands back a null pointer. Otherwise the
/// C string is copied into an owned `String`, with invalid UTF-8 replaced
/// lossily.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn get_physical_device_string(physical_device_id: i32, name: &str) -> Option<String> {
    let key = CString::new(name).unwrap();
    // SAFETY: `key` is a NUL-terminated buffer that outlives the call.
    let text = unsafe { sys::oidnGetPhysicalDeviceString(physical_device_id, key.as_ptr()) };
    (!text.is_null()).then(|| {
        // SAFETY: `text` was just checked to be non-null; it is assumed to be
        // a NUL-terminated string owned by the library (C API contract, not
        // verifiable from this file). The contents are copied out immediately.
        unsafe { CStr::from_ptr(text) }
            .to_string_lossy()
            .into_owned()
    })
}
/// Reads the opaque data parameter `name` of the physical device
/// `physical_device_id`.
///
/// Returns the raw pointer together with the byte size written by the
/// library, or `None` when the returned pointer is null. The pointer is
/// passed through untouched.
/// NOTE(review): the memory is presumably owned by the library; confirm its
/// lifetime before dereferencing.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn get_physical_device_data(
    physical_device_id: i32,
    name: &str,
) -> Option<(*const std::ffi::c_void, usize)> {
    let key = CString::new(name).unwrap();
    let mut byte_len = 0usize;
    // SAFETY: `key` is a NUL-terminated buffer and `byte_len` is a live local;
    // both outlive the call.
    let data = unsafe {
        sys::oidnGetPhysicalDeviceData(physical_device_id, key.as_ptr(), &mut byte_len)
    };
    (!data.is_null()).then_some((data, byte_len))
}
/// Returns the result of `oidnIsCPUDeviceSupported`.
pub fn is_cpu_device_supported() -> bool {
    // SAFETY: the call takes no arguments, so there is nothing to validate.
    unsafe {
        sys::oidnIsCPUDeviceSupported()
    }
}
/// Returns the result of `oidnIsCUDADeviceSupported` for `device_id`.
pub fn is_cuda_device_supported(device_id: i32) -> bool {
    // SAFETY: only a plain integer is passed; no pointers are involved.
    unsafe {
        sys::oidnIsCUDADeviceSupported(device_id)
    }
}
/// Returns the result of `oidnIsHIPDeviceSupported` for `device_id`.
pub fn is_hip_device_supported(device_id: i32) -> bool {
    // SAFETY: only a plain integer is passed; no pointers are involved.
    unsafe {
        sys::oidnIsHIPDeviceSupported(device_id)
    }
}
/// Returns the result of `oidnIsMetalDeviceSupported` for `device`.
///
/// # Safety
/// `device` is forwarded unchanged to the C library. NOTE(review): it
/// presumably must be a valid Metal device pointer; confirm against the
/// OIDN API documentation.
pub unsafe fn is_metal_device_supported(device: *mut std::ffi::c_void) -> bool {
    let metal_device = device;
    sys::oidnIsMetalDeviceSupported(metal_device)
}
pub fn take_global_error() -> Option<Error> {
let mut msg_ptr: *const std::ffi::c_char = ptr::null();
let code = unsafe { sys::oidnGetDeviceError(ptr::null_mut(), &mut msg_ptr) };
if code == sys::OIDNError::None {
return None;
}
let message = if msg_ptr.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(msg_ptr).to_string_lossy().into_owned() }
};
Some(Error::OidnError { code: code as u32, message })
}
/// Handle to an OIDN device obtained through one of the `oidnNew*Device` calls.
///
/// The derived `Clone` copies the raw handle and shares the `_refcount`
/// token; it does not call `oidnRetainDevice`, so all clones refer to the
/// same underlying reference.
#[derive(Clone)]
pub struct OidnDevice {
/// Raw device handle passed to every `sys::oidn*Device*` call.
pub(crate) raw: sys::OIDNDevice,
/// Token shared by every clone of this wrapper; it carries no data.
_refcount: Arc<()>,
}
impl std::fmt::Debug for OidnDevice {
    /// Formats the device opaquely as `OidnDevice { .. }` without exposing
    /// the raw handle.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("OidnDevice");
        builder.finish_non_exhaustive()
    }
}
impl OidnDevice {
pub fn new() -> Result<Self, Error> {
Self::with_type(OidnDeviceType::Default)
}
pub fn cpu() -> Result<Self, Error> {
Self::with_type(OidnDeviceType::Cpu)
}
pub fn cuda() -> Result<Self, Error> {
Self::with_type(OidnDeviceType::Cuda)
}
pub fn sycl() -> Result<Self, Error> {
Self::with_type(OidnDeviceType::Sycl)
}
pub fn hip() -> Result<Self, Error> {
Self::with_type(OidnDeviceType::Hip)
}
pub fn metal() -> Result<Self, Error> {
Self::with_type(OidnDeviceType::Metal)
}
pub fn with_type(device_type: OidnDeviceType) -> Result<Self, Error> {
let raw = unsafe { sys::oidnNewDevice(device_type.to_raw()) };
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
unsafe { sys::oidnCommitDevice(raw) };
Ok(Self {
raw,
_refcount: Arc::new(()),
})
}
pub fn new_by_id(physical_device_id: i32) -> Result<Self, Error> {
let raw = unsafe { sys::oidnNewDeviceByID(physical_device_id) };
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
unsafe { sys::oidnCommitDevice(raw) };
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub fn new_by_uuid(uuid: &[u8; sys::OIDN_UUID_SIZE]) -> Result<Self, Error> {
let raw = unsafe { sys::oidnNewDeviceByUUID(uuid.as_ptr() as *const std::ffi::c_void) };
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
unsafe { sys::oidnCommitDevice(raw) };
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub fn new_by_luid(luid: &[u8; sys::OIDN_LUID_SIZE]) -> Result<Self, Error> {
let raw = unsafe { sys::oidnNewDeviceByLUID(luid.as_ptr() as *const std::ffi::c_void) };
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
unsafe { sys::oidnCommitDevice(raw) };
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub fn new_by_pci_address(
pci_domain: i32,
pci_bus: i32,
pci_device: i32,
pci_function: i32,
) -> Result<Self, Error> {
let raw = unsafe {
sys::oidnNewDeviceByPCIAddress(pci_domain, pci_bus, pci_device, pci_function)
};
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
unsafe { sys::oidnCommitDevice(raw) };
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub unsafe fn new_cuda_device(
device_id: i32,
stream: Option<*mut std::ffi::c_void>,
) -> Result<Self, Error> {
let stream_ptr = stream.unwrap_or(ptr::null_mut());
let raw = sys::oidnNewCUDADevice(&device_id, &stream_ptr, 1);
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
sys::oidnCommitDevice(raw);
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub unsafe fn new_hip_device(
device_id: i32,
stream: Option<*mut std::ffi::c_void>,
) -> Result<Self, Error> {
let stream_ptr = stream.unwrap_or(ptr::null_mut());
let raw = sys::oidnNewHIPDevice(&device_id, &stream_ptr, 1);
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
sys::oidnCommitDevice(raw);
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub unsafe fn new_metal_device(command_queues: &[*mut std::ffi::c_void]) -> Result<Self, Error> {
let raw = sys::oidnNewMetalDevice(command_queues.as_ptr(), command_queues.len() as i32);
if raw.is_null() {
return Err(Error::DeviceCreationFailed);
}
sys::oidnCommitDevice(raw);
Ok(Self { raw, _refcount: Arc::new(()) })
}
pub fn set_bool(&self, name: &str, value: bool) {
let c_name = CString::new(name).unwrap();
unsafe { sys::oidnSetDeviceBool(self.raw, c_name.as_ptr(), value) };
}
pub fn set_int(&self, name: &str, value: i32) {
let c_name = CString::new(name).unwrap();
unsafe { sys::oidnSetDeviceInt(self.raw, c_name.as_ptr(), value) };
}
pub fn get_bool(&self, name: &str) -> bool {
let c_name = CString::new(name).unwrap();
unsafe { sys::oidnGetDeviceBool(self.raw, c_name.as_ptr()) }
}
pub fn get_int(&self, name: &str) -> i32 {
let c_name = CString::new(name).unwrap();
unsafe { sys::oidnGetDeviceInt(self.raw, c_name.as_ptr()) }
}
pub fn get_uint(&self, name: &str) -> u32 {
self.get_int(name) as u32
}
pub fn commit(&self) {
unsafe { sys::oidnCommitDevice(self.raw) };
}
pub unsafe fn set_error_function_raw(
&self,
func: sys::OIDNErrorFunction,
user_ptr: *mut std::ffi::c_void,
) {
sys::oidnSetDeviceErrorFunction(self.raw, func, user_ptr);
}
pub fn take_error(&self) -> Option<Error> {
let mut msg_ptr: *const std::ffi::c_char = ptr::null();
let code = unsafe { sys::oidnGetDeviceError(self.raw, &mut msg_ptr) };
if code == sys::OIDNError::None {
return None;
}
let message = if msg_ptr.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(msg_ptr).to_string_lossy().into_owned() }
};
Some(Error::OidnError { code: code as u32, message })
}
pub fn sync(&self) {
unsafe { sys::oidnSyncDevice(self.raw) };
}
pub fn retain(&self) {
unsafe { sys::oidnRetainDevice(self.raw) };
}
pub(crate) fn raw(&self) -> sys::OIDNDevice {
self.raw
}
}
impl Drop for OidnDevice {
fn drop(&mut self) {
unsafe { sys::oidnReleaseDevice(self.raw) }
}
}
// The raw handle is not automatically `Send`/`Sync`, so both are asserted manually.
// NOTE(review): this assumes the OIDN device API may be called on the same
// handle from any thread, including concurrently — confirm against the OIDN
// documentation.
unsafe impl Send for OidnDevice {}
unsafe impl Sync for OidnDevice {}
/// Kind of device to create; each variant maps one-to-one onto a
/// `sys::OIDNDeviceType` value.
#[derive(Clone, Copy, Debug, Default)]
pub enum OidnDeviceType {
/// Maps to `sys::OIDNDeviceType::Default`; also this enum's `Default` value.
#[default]
Default,
/// Maps to `sys::OIDNDeviceType::CPU`.
Cpu,
/// Maps to `sys::OIDNDeviceType::SYCL`.
Sycl,
/// Maps to `sys::OIDNDeviceType::CUDA`.
Cuda,
/// Maps to `sys::OIDNDeviceType::HIP`.
Hip,
/// Maps to `sys::OIDNDeviceType::Metal`.
Metal,
}
impl OidnDeviceType {
fn to_raw(self) -> sys::OIDNDeviceType {
match self {
OidnDeviceType::Default => sys::OIDNDeviceType::Default,
OidnDeviceType::Cpu => sys::OIDNDeviceType::CPU,
OidnDeviceType::Sycl => sys::OIDNDeviceType::SYCL,
OidnDeviceType::Cuda => sys::OIDNDeviceType::CUDA,
OidnDeviceType::Hip => sys::OIDNDeviceType::HIP,
OidnDeviceType::Metal => sys::OIDNDeviceType::Metal,
}
}
}