1use std::fmt::Debug;
2
3#[cfg(dx12)]
4mod dx12;
5#[cfg(vulkan)]
6mod vulkan;
7
8pub enum DeviceCreateError {
9 RequestDeviceError(wgpu::RequestDeviceError),
10 OidnUnsupported,
11 OidnImportUnsupported,
12 MissingFeature,
13 UnsupportedBackend(wgpu::Backend),
14}
15
16impl Debug for DeviceCreateError {
17 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18 match self {
19 DeviceCreateError::RequestDeviceError(err) => err.fmt(f),
20 DeviceCreateError::OidnUnsupported => f.write_str(
21 "OIDN could not create a device for this Adapter (does this adapter support OIDN?)",
22 ),
23 DeviceCreateError::OidnImportUnsupported => {
24 f.write_str("OIDN does not support the required import method")
25 }
26 DeviceCreateError::MissingFeature => f.write_str("A required feature is missing"),
27 DeviceCreateError::UnsupportedBackend(backend) => {
28 f.write_str("The backend ")?;
29 backend.fmt(f)?;
30 f.write_str(" is not supported.")
31 }
32 }
33 }
34}
35
36pub enum SharedBufferCreateError {
37 InvalidSize(wgpu::BufferAddress),
38 Oidn((oidn::Error, String)),
39 OutOfMemory,
40}
41
42impl Debug for SharedBufferCreateError {
43 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
44 match self {
45 SharedBufferCreateError::InvalidSize(size) => {
46 f.write_str("Size ")?;
47 size.fmt(f)?;
48 f.write_str(" is not allowed")
49 }
50 SharedBufferCreateError::Oidn((error, desc)) => {
51 f.write_str("OIDN shared buffer creation failed with error ")?;
52 error.fmt(f)?;
53 f.write_str(": ")?;
54 desc.fmt(f)
55 }
56 SharedBufferCreateError::OutOfMemory => f.write_str("Out of memory"),
57 }
58 }
59}
60
/// Graphics backends this build has interop support for.
///
/// Each variant is gated on its `dx12` / `vulkan` cfg, so the enum can be empty.
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
enum Backend {
    #[cfg(dx12)]
    Dx12,
    #[cfg(vulkan)]
    Vulkan,
}

/// Backend-specific state stored in a [`Device`]; one variant per [`Backend`] variant.
enum BackendData {
    #[cfg(dx12)]
    Dx12,
    #[cfg(vulkan)]
    Vulkan(vulkan::VulkanSharingMode),
}

impl BackendData {
    /// Returns the [`Backend`] this data belongs to, ignoring any payload.
    fn as_backend(&self) -> Backend {
        // Match on the place `*self`, not the reference `self`. When no backend cfg is
        // enabled this enum is uninhabited, and only a by-place match is accepted as
        // exhaustive with zero arms. No arm binds anything, so nothing is moved.
        match *self {
            #[cfg(dx12)]
            BackendData::Dx12 => Backend::Dx12,
            #[cfg(vulkan)]
            BackendData::Vulkan(_) => Backend::Vulkan,
        }
    }
}
86
87pub struct Device {
88 wgpu_device: wgpu::Device,
89 oidn_device: oidn::Device,
90 queue: wgpu::Queue,
91 backend_data: BackendData,
92}
93
94impl Device {
95 pub async fn new(
96 adapter: &wgpu::Adapter,
97 desc: &wgpu::DeviceDescriptor<'_>,
98 ) -> Result<(Self, wgpu::Queue), DeviceCreateError> {
99 match adapter.get_info().backend {
100 #[cfg(vulkan)]
101 wgpu::Backend::Vulkan => Self::new_vulkan(adapter, desc).await,
102 #[cfg(dx12)]
103 wgpu::Backend::Dx12 => Self::new_dx12(adapter, desc).await,
104 _ => Err(DeviceCreateError::UnsupportedBackend(
105 adapter.get_info().backend,
106 )),
107 }
108 }
109 pub fn allocate_shared_buffers(
110 &self,
111 size: wgpu::BufferAddress,
112 ) -> Result<SharedBuffer, SharedBufferCreateError> {
113 if size == 0 {
114 return Err(SharedBufferCreateError::InvalidSize(size));
115 }
116 match self.backend_data.as_backend() {
117 #[cfg(dx12)]
118 Backend::Dx12 => self.allocate_shared_buffers_dx12(size),
119 #[cfg(vulkan)]
120 Backend::Vulkan => self.allocate_shared_buffers_vulkan(size),
121 }
122 }
123 pub fn oidn_device(&self) -> &oidn::Device {
124 &self.oidn_device
125 }
126
127 pub fn wgpu_device(&self) -> &wgpu::Device {
128 &self.wgpu_device
129 }
130
131 async fn new_from_raw_oidn_adapter<
132 F: FnOnce(oidn::sys::OIDNExternalMemoryTypeFlag) -> Option<BackendData>,
133 >(
134 device: oidn::sys::OIDNDevice,
135 adapter: &wgpu::Adapter,
136 desc: &wgpu::DeviceDescriptor<'_>,
137 backend_data_callback: F,
138 ) -> Result<(Self, wgpu::Queue), DeviceCreateError> {
139 if device.is_null() {
140 return Err(crate::DeviceCreateError::OidnUnsupported);
141 }
142
143 let supported_memory_types = unsafe {
144 oidn::sys::oidnCommitDevice(device);
145 oidn::sys::oidnGetDeviceInt(device, b"externalMemoryTypes\0" as *const _ as _)
146 } as oidn::sys::OIDNExternalMemoryTypeFlag;
147 let Some(backend_data) = backend_data_callback(supported_memory_types) else {
148 unsafe {
149 oidn::sys::oidnReleaseDevice(device);
150 }
151 return Err(DeviceCreateError::OidnImportUnsupported);
152 };
153 let oidn_device = unsafe { oidn::Device::from_raw(device) };
154 let (wgpu_device, queue) = adapter
155 .request_device(desc)
156 .await
157 .map_err(crate::DeviceCreateError::RequestDeviceError)?;
158 Ok((
159 Self {
160 wgpu_device,
161 oidn_device,
162 queue: queue.clone(),
163 backend_data,
164 },
165 queue,
166 ))
167 }
168}
169
170enum Allocation {
171 #[cfg(dx12)]
173 Dx12 { _dx12: dx12::Dx12Allocation },
174 #[cfg(vulkan)]
175 Vulkan,
176}
177
178pub struct SharedBuffer {
179 _allocation: Allocation,
180 oidn_buffer: oidn::Buffer,
181 wgpu_buffer: wgpu::Buffer,
182}
183
184impl SharedBuffer {
185 pub fn oidn_buffer(&self) -> &oidn::Buffer {
186 &self.oidn_buffer
187 }
188 pub fn oidn_buffer_mut(&mut self) -> &mut oidn::Buffer {
189 &mut self.oidn_buffer
190 }
191 pub fn wgpu_buffer(&self) -> &wgpu::Buffer {
192 &self.wgpu_buffer
193 }
194}
195
/// End-to-end check on every DX12/Vulkan adapter.
///
/// For each adapter it:
/// 1. Creates a device pair.
/// 2. Writes a value through wgpu and checks that OIDN reads the same memory.
/// 3. Runs a 1x1 in-place filter on the shared buffer.
///
/// Adapters whose device creation fails are logged and skipped.
#[cfg(test)]
#[async_std::test]
async fn test() {
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: wgpu::Backends::DX12 | wgpu::Backends::VULKAN,
        ..Default::default()
    });
    for adapter in instance.enumerate_adapters(wgpu::Backends::all()) {
        // Query the (owned) adapter info once per adapter.
        let info = adapter.get_info();
        let backend_name = match info.backend {
            wgpu::Backend::Vulkan => "vulkan",
            wgpu::Backend::Dx12 => "dx12",
            _ => continue,
        };
        eprintln!("Testing {backend_name} device {}", info.name);

        let (device, queue) =
            match Device::new(&adapter, &wgpu::DeviceDescriptor::default()).await {
                Ok(pair) => pair,
                Err(err) => {
                    eprintln!("Device creation failed");
                    eprintln!(" {err:?}");
                    continue;
                }
            };

        // Room for one RGB pixel of f32 components.
        let mut bufs = device
            .allocate_shared_buffers(size_of::<[f32; 3]>() as wgpu::BufferAddress)
            .unwrap();
        queue.write_buffer(bufs.wgpu_buffer(), 0, &1.0_f32.to_ne_bytes());
        queue.submit([]);
        device.wgpu_device().poll(wgpu::PollType::Wait).unwrap();
        // The value written through wgpu must be visible through OIDN.
        assert_eq!(bufs.oidn_buffer_mut().read()[0], 1.0);

        let mut filter = oidn::RayTracing::new(device.oidn_device());
        filter.image_dimensions(1, 1);
        filter.filter_in_place_buffer(bufs.oidn_buffer_mut()).unwrap();
        // Out-of-memory is tolerated; any other OIDN error fails the test.
        match device.oidn_device().get_error() {
            Ok(_) | Err((oidn::Error::OutOfMemory, _)) => {}
            Err(err) => panic!("{err:?}"),
        }
    }
}