1use crate::sys;
5use crate::Error;
6use std::ffi::{CStr, CString};
7use std::ptr;
8use std::sync::Arc;
9
10pub fn num_physical_devices() -> i32 {
12 unsafe { sys::oidnGetNumPhysicalDevices() }
13}
14
15pub fn get_physical_device_bool(physical_device_id: i32, name: &str) -> bool {
17 let c_name = CString::new(name).unwrap();
18 unsafe { sys::oidnGetPhysicalDeviceBool(physical_device_id, c_name.as_ptr()) }
19}
20
21pub fn get_physical_device_int(physical_device_id: i32, name: &str) -> i32 {
23 let c_name = CString::new(name).unwrap();
24 unsafe { sys::oidnGetPhysicalDeviceInt(physical_device_id, c_name.as_ptr()) }
25}
26
27pub fn get_physical_device_string(physical_device_id: i32, name: &str) -> Option<String> {
29 let c_name = CString::new(name).unwrap();
30 let p = unsafe { sys::oidnGetPhysicalDeviceString(physical_device_id, c_name.as_ptr()) };
31 if p.is_null() {
32 return None;
33 }
34 Some(unsafe { CStr::from_ptr(p).to_string_lossy().into_owned() })
35}
36
37pub fn get_physical_device_data(physical_device_id: i32, name: &str) -> Option<(*const std::ffi::c_void, usize)> {
39 let c_name = CString::new(name).unwrap();
40 let mut size = 0usize;
41 let p = unsafe { sys::oidnGetPhysicalDeviceData(physical_device_id, c_name.as_ptr(), &mut size) };
42 if p.is_null() {
43 None
44 } else {
45 Some((p, size))
46 }
47}
48
49pub fn is_cpu_device_supported() -> bool {
51 unsafe { sys::oidnIsCPUDeviceSupported() }
52}
53
54pub fn is_cuda_device_supported(device_id: i32) -> bool {
56 unsafe { sys::oidnIsCUDADeviceSupported(device_id) }
57}
58
59pub fn is_hip_device_supported(device_id: i32) -> bool {
61 unsafe { sys::oidnIsHIPDeviceSupported(device_id) }
62}
63
64pub unsafe fn is_metal_device_supported(device: *mut std::ffi::c_void) -> bool {
66 sys::oidnIsMetalDeviceSupported(device)
67}
68
69pub fn take_global_error() -> Option<Error> {
72 let mut msg_ptr: *const std::ffi::c_char = ptr::null();
73 let code = unsafe { sys::oidnGetDeviceError(ptr::null_mut(), &mut msg_ptr) };
74 if code == sys::OIDNError::None {
75 return None;
76 }
77 let message = if msg_ptr.is_null() {
78 String::new()
79 } else {
80 unsafe { CStr::from_ptr(msg_ptr).to_string_lossy().into_owned() }
81 };
82 Some(Error::OidnError { code: code as u32, message })
83}
84
85#[derive(Clone)]
90pub struct OidnDevice {
91 pub(crate) raw: sys::OIDNDevice,
92 _refcount: Arc<()>,
93}
94
95impl std::fmt::Debug for OidnDevice {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("OidnDevice").finish_non_exhaustive()
98 }
99}
100
101impl OidnDevice {
102 pub fn new() -> Result<Self, Error> {
109 Self::with_type(OidnDeviceType::Default)
110 }
111
112 pub fn cpu() -> Result<Self, Error> {
114 Self::with_type(OidnDeviceType::Cpu)
115 }
116
117 pub fn cuda() -> Result<Self, Error> {
120 Self::with_type(OidnDeviceType::Cuda)
121 }
122
123 pub fn sycl() -> Result<Self, Error> {
125 Self::with_type(OidnDeviceType::Sycl)
126 }
127
128 pub fn hip() -> Result<Self, Error> {
130 Self::with_type(OidnDeviceType::Hip)
131 }
132
133 pub fn metal() -> Result<Self, Error> {
135 Self::with_type(OidnDeviceType::Metal)
136 }
137
138 pub fn with_type(device_type: OidnDeviceType) -> Result<Self, Error> {
140 let raw = unsafe { sys::oidnNewDevice(device_type.to_raw()) };
141 if raw.is_null() {
142 return Err(Error::DeviceCreationFailed);
143 }
144 unsafe { sys::oidnCommitDevice(raw) };
145 Ok(Self {
146 raw,
147 _refcount: Arc::new(()),
148 })
149 }
150
151 pub fn new_by_id(physical_device_id: i32) -> Result<Self, Error> {
153 let raw = unsafe { sys::oidnNewDeviceByID(physical_device_id) };
154 if raw.is_null() {
155 return Err(Error::DeviceCreationFailed);
156 }
157 unsafe { sys::oidnCommitDevice(raw) };
158 Ok(Self { raw, _refcount: Arc::new(()) })
159 }
160
161 pub fn new_by_uuid(uuid: &[u8; sys::OIDN_UUID_SIZE]) -> Result<Self, Error> {
163 let raw = unsafe { sys::oidnNewDeviceByUUID(uuid.as_ptr() as *const std::ffi::c_void) };
164 if raw.is_null() {
165 return Err(Error::DeviceCreationFailed);
166 }
167 unsafe { sys::oidnCommitDevice(raw) };
168 Ok(Self { raw, _refcount: Arc::new(()) })
169 }
170
171 pub fn new_by_luid(luid: &[u8; sys::OIDN_LUID_SIZE]) -> Result<Self, Error> {
173 let raw = unsafe { sys::oidnNewDeviceByLUID(luid.as_ptr() as *const std::ffi::c_void) };
174 if raw.is_null() {
175 return Err(Error::DeviceCreationFailed);
176 }
177 unsafe { sys::oidnCommitDevice(raw) };
178 Ok(Self { raw, _refcount: Arc::new(()) })
179 }
180
181 pub fn new_by_pci_address(
183 pci_domain: i32,
184 pci_bus: i32,
185 pci_device: i32,
186 pci_function: i32,
187 ) -> Result<Self, Error> {
188 let raw = unsafe {
189 sys::oidnNewDeviceByPCIAddress(pci_domain, pci_bus, pci_device, pci_function)
190 };
191 if raw.is_null() {
192 return Err(Error::DeviceCreationFailed);
193 }
194 unsafe { sys::oidnCommitDevice(raw) };
195 Ok(Self { raw, _refcount: Arc::new(()) })
196 }
197
198 pub unsafe fn new_cuda_device(
202 device_id: i32,
203 stream: Option<*mut std::ffi::c_void>,
204 ) -> Result<Self, Error> {
205 let stream_ptr = stream.unwrap_or(ptr::null_mut());
206 let raw = sys::oidnNewCUDADevice(&device_id, &stream_ptr, 1);
207 if raw.is_null() {
208 return Err(Error::DeviceCreationFailed);
209 }
210 sys::oidnCommitDevice(raw);
211 Ok(Self { raw, _refcount: Arc::new(()) })
212 }
213
214 pub unsafe fn new_hip_device(
217 device_id: i32,
218 stream: Option<*mut std::ffi::c_void>,
219 ) -> Result<Self, Error> {
220 let stream_ptr = stream.unwrap_or(ptr::null_mut());
221 let raw = sys::oidnNewHIPDevice(&device_id, &stream_ptr, 1);
222 if raw.is_null() {
223 return Err(Error::DeviceCreationFailed);
224 }
225 sys::oidnCommitDevice(raw);
226 Ok(Self { raw, _refcount: Arc::new(()) })
227 }
228
229 pub unsafe fn new_metal_device(command_queues: &[*mut std::ffi::c_void]) -> Result<Self, Error> {
232 let raw = sys::oidnNewMetalDevice(command_queues.as_ptr(), command_queues.len() as i32);
233 if raw.is_null() {
234 return Err(Error::DeviceCreationFailed);
235 }
236 sys::oidnCommitDevice(raw);
237 Ok(Self { raw, _refcount: Arc::new(()) })
238 }
239
240 pub fn set_bool(&self, name: &str, value: bool) {
242 let c_name = CString::new(name).unwrap();
243 unsafe { sys::oidnSetDeviceBool(self.raw, c_name.as_ptr(), value) };
244 }
245
246 pub fn set_int(&self, name: &str, value: i32) {
248 let c_name = CString::new(name).unwrap();
249 unsafe { sys::oidnSetDeviceInt(self.raw, c_name.as_ptr(), value) };
250 }
251
252 pub fn get_bool(&self, name: &str) -> bool {
254 let c_name = CString::new(name).unwrap();
255 unsafe { sys::oidnGetDeviceBool(self.raw, c_name.as_ptr()) }
256 }
257
258 pub fn get_int(&self, name: &str) -> i32 {
260 let c_name = CString::new(name).unwrap();
261 unsafe { sys::oidnGetDeviceInt(self.raw, c_name.as_ptr()) }
262 }
263
264 pub fn get_uint(&self, name: &str) -> u32 {
266 self.get_int(name) as u32
267 }
268
269 pub fn commit(&self) {
271 unsafe { sys::oidnCommitDevice(self.raw) };
272 }
273
274 pub unsafe fn set_error_function_raw(
277 &self,
278 func: sys::OIDNErrorFunction,
279 user_ptr: *mut std::ffi::c_void,
280 ) {
281 sys::oidnSetDeviceErrorFunction(self.raw, func, user_ptr);
282 }
283
284 pub fn take_error(&self) -> Option<Error> {
286 let mut msg_ptr: *const std::ffi::c_char = ptr::null();
287 let code = unsafe { sys::oidnGetDeviceError(self.raw, &mut msg_ptr) };
288 if code == sys::OIDNError::None {
289 return None;
290 }
291 let message = if msg_ptr.is_null() {
292 String::new()
293 } else {
294 unsafe { CStr::from_ptr(msg_ptr).to_string_lossy().into_owned() }
295 };
296 Some(Error::OidnError { code: code as u32, message })
297 }
298
299 pub fn sync(&self) {
301 unsafe { sys::oidnSyncDevice(self.raw) };
302 }
303
304 pub fn retain(&self) {
306 unsafe { sys::oidnRetainDevice(self.raw) };
307 }
308
309 pub(crate) fn raw(&self) -> sys::OIDNDevice {
310 self.raw
311 }
312}
313
314impl Drop for OidnDevice {
315 fn drop(&mut self) {
316 unsafe { sys::oidnReleaseDevice(self.raw) }
317 }
318}
319
320unsafe impl Send for OidnDevice {}
321unsafe impl Sync for OidnDevice {}
322
/// Backend requested when creating an [`OidnDevice`].
///
/// `OidnDeviceType::default()` is [`OidnDeviceType::Default`]. The type is `Copy` and can be
/// compared and hashed, e.g. to key caches by backend.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum OidnDeviceType {
    /// Maps to OIDN's default device type, letting the library choose the backend.
    #[default]
    Default,
    /// CPU backend.
    Cpu,
    /// SYCL backend.
    Sycl,
    /// CUDA backend (NVIDIA GPUs).
    Cuda,
    /// HIP backend (AMD GPUs).
    Hip,
    /// Metal backend (Apple GPUs).
    Metal,
}
340
341impl OidnDeviceType {
342 fn to_raw(self) -> sys::OIDNDeviceType {
343 match self {
344 OidnDeviceType::Default => sys::OIDNDeviceType::Default,
345 OidnDeviceType::Cpu => sys::OIDNDeviceType::CPU,
346 OidnDeviceType::Sycl => sys::OIDNDeviceType::SYCL,
347 OidnDeviceType::Cuda => sys::OIDNDeviceType::CUDA,
348 OidnDeviceType::Hip => sys::OIDNDeviceType::HIP,
349 OidnDeviceType::Metal => sys::OIDNDeviceType::Metal,
350 }
351 }
352}