Skip to main content

oidn_wgpu_interop/
lib.rs

1use std::fmt::Debug;
2
// Backend-specific interop implementations. `dx12` and `vulkan` are custom
// cfg flags, presumably emitted by a build script based on target/platform
// support — TODO confirm where they are set.
#[cfg(dx12)]
mod dx12;
#[cfg(vulkan)]
mod vulkan;
7
/// Errors that can occur while creating a paired wgpu/OIDN `Device`.
pub enum DeviceCreateError {
    /// wgpu failed to create the logical device for the adapter.
    RequestDeviceError(wgpu::RequestDeviceError),
    /// OIDN returned a null device handle for this adapter.
    OidnUnsupported,
    /// OIDN created a device, but it supports none of the external-memory
    /// import methods this backend needs.
    OidnImportUnsupported,
    /// A required wgpu feature is missing.
    MissingFeature,
    /// The adapter runs on a backend with no interop path (only the
    /// cfg-enabled Vulkan/DX12 backends are supported).
    UnsupportedBackend(wgpu::Backend),
}
15
16impl Debug for DeviceCreateError {
17    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18        match self {
19            DeviceCreateError::RequestDeviceError(err) => err.fmt(f),
20            DeviceCreateError::OidnUnsupported => f.write_str(
21                "OIDN could not create a device for this Adapter (does this adapter support OIDN?)",
22            ),
23            DeviceCreateError::OidnImportUnsupported => {
24                f.write_str("OIDN does not support the required import method")
25            }
26            DeviceCreateError::MissingFeature => f.write_str("A required feature is missing"),
27            DeviceCreateError::UnsupportedBackend(backend) => {
28                f.write_str("The backend ")?;
29                backend.fmt(f)?;
30                f.write_str(" is not supported.")
31            }
32        }
33    }
34}
35
/// Errors that can occur while allocating a buffer shared between wgpu
/// and OIDN.
pub enum SharedBufferCreateError {
    /// The requested size is not allowed (zero is rejected up front in
    /// `Device::allocate_shared_buffers`).
    InvalidSize(wgpu::BufferAddress),
    /// OIDN rejected the shared-buffer import; carries the error code and
    /// the descriptive message reported by OIDN.
    Oidn((oidn::Error, String)),
    /// The allocation failed for lack of memory.
    OutOfMemory,
}
41
42impl Debug for SharedBufferCreateError {
43    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
44        match self {
45            SharedBufferCreateError::InvalidSize(size) => {
46                f.write_str("Size ")?;
47                size.fmt(f)?;
48                f.write_str(" is not allowed")
49            }
50            SharedBufferCreateError::Oidn((error, desc)) => {
51                f.write_str("OIDN shared buffer creation failed with error ")?;
52                error.fmt(f)?;
53                f.write_str(": ")?;
54                desc.fmt(f)
55            }
56            SharedBufferCreateError::OutOfMemory => f.write_str("Out of memory"),
57        }
58    }
59}
60
/// Tag identifying which interop backend a device uses. Variants exist only
/// when the matching custom cfg flag is enabled at build time.
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
enum Backend {
    #[cfg(dx12)]
    Dx12,
    #[cfg(vulkan)]
    Vulkan,
}
68
/// Backend tag plus any per-backend configuration carried by the device.
enum BackendData {
    #[cfg(dx12)]
    Dx12,
    // Vulkan additionally records how memory is shared with OIDN.
    #[cfg(vulkan)]
    Vulkan(vulkan::VulkanSharingMode),
}
75
impl BackendData {
    /// Strips the payload, yielding just the backend tag for dispatch.
    fn as_backend(&self) -> Backend {
        match self {
            #[cfg(dx12)]
            BackendData::Dx12 => Backend::Dx12,
            #[cfg(vulkan)]
            BackendData::Vulkan(_) => Backend::Vulkan,
        }
    }
}
86
/// A wgpu device paired with an OIDN device created on the same adapter,
/// enabling buffers shared between the two APIs.
pub struct Device {
    // wgpu half of the pair.
    wgpu_device: wgpu::Device,
    // OIDN half of the pair.
    oidn_device: oidn::Device,
    // Clone of the queue returned to the caller from `Device::new`; not read
    // in this file — presumably used by the backend modules. TODO confirm.
    queue: wgpu::Queue,
    // Backend tag plus backend-specific sharing configuration.
    backend_data: BackendData,
}
93
impl Device {
    /// Creates a paired wgpu + OIDN device on `adapter`, dispatching to the
    /// backend-specific constructor. Backends compiled out via the `dx12` /
    /// `vulkan` cfg flags fall through to `UnsupportedBackend`.
    pub async fn new(
        adapter: &wgpu::Adapter,
        desc: &wgpu::DeviceDescriptor<'_>,
    ) -> Result<(Self, wgpu::Queue), DeviceCreateError> {
        match adapter.get_info().backend {
            #[cfg(vulkan)]
            wgpu::Backend::Vulkan => Self::new_vulkan(adapter, desc).await,
            #[cfg(dx12)]
            wgpu::Backend::Dx12 => Self::new_dx12(adapter, desc).await,
            _ => Err(DeviceCreateError::UnsupportedBackend(
                adapter.get_info().backend,
            )),
        }
    }
    /// Allocates a buffer visible to both wgpu and OIDN. A zero `size` is
    /// rejected up front; the actual allocation is backend-specific.
    pub fn allocate_shared_buffers(
        &self,
        size: wgpu::BufferAddress,
    ) -> Result<SharedBuffer, SharedBufferCreateError> {
        if size == 0 {
            return Err(SharedBufferCreateError::InvalidSize(size));
        }
        match self.backend_data.as_backend() {
            #[cfg(dx12)]
            Backend::Dx12 => self.allocate_shared_buffers_dx12(size),
            #[cfg(vulkan)]
            Backend::Vulkan => self.allocate_shared_buffers_vulkan(size),
        }
    }
    /// The OIDN half of the device pair.
    pub fn oidn_device(&self) -> &oidn::Device {
        &self.oidn_device
    }

    /// The wgpu half of the device pair.
    pub fn wgpu_device(&self) -> &wgpu::Device {
        &self.wgpu_device
    }

    /// Shared tail of the backend-specific constructors: takes a raw OIDN
    /// device handle, queries which external-memory import types it supports,
    /// and lets `backend_data_callback` decide whether any of them is usable
    /// (returning `None` releases the handle and fails with
    /// `OidnImportUnsupported`).
    async fn new_from_raw_oidn_adapter<
        F: FnOnce(oidn::sys::OIDNExternalMemoryTypeFlag) -> Option<BackendData>,
    >(
        device: oidn::sys::OIDNDevice,
        adapter: &wgpu::Adapter,
        desc: &wgpu::DeviceDescriptor<'_>,
        backend_data_callback: F,
    ) -> Result<(Self, wgpu::Queue), DeviceCreateError> {
        // A null handle means OIDN refused to create a device for this adapter.
        if device.is_null() {
            return Err(crate::DeviceCreateError::OidnUnsupported);
        }

        // SAFETY: `device` is non-null and exclusively owned here. The device
        // is committed before its `externalMemoryTypes` parameter is queried.
        let supported_memory_types = unsafe {
            oidn::sys::oidnCommitDevice(device);
            oidn::sys::oidnGetDeviceInt(device, b"externalMemoryTypes\0" as *const _ as _)
        } as oidn::sys::OIDNExternalMemoryTypeFlag;
        let Some(backend_data) = backend_data_callback(supported_memory_types) else {
            // SAFETY: the raw handle was never wrapped, so we still own it and
            // must release it here to avoid leaking the OIDN device.
            unsafe {
                oidn::sys::oidnReleaseDevice(device);
            }
            return Err(DeviceCreateError::OidnImportUnsupported);
        };
        // SAFETY: ownership of the committed, non-null handle transfers to the
        // safe wrapper (expected to release it on drop — per `from_raw`
        // convention; TODO confirm oidn's contract).
        let oidn_device = unsafe { oidn::Device::from_raw(device) };
        let (wgpu_device, queue) = adapter
            .request_device(desc)
            .await
            .map_err(crate::DeviceCreateError::RequestDeviceError)?;
        Ok((
            Self {
                wgpu_device,
                oidn_device,
                // Keep a clone; the original queue is handed back to the caller.
                queue: queue.clone(),
                backend_data,
            },
            queue,
        ))
    }
}
169
/// Backend-specific handle that keeps the underlying shared allocation alive
/// for as long as the `SharedBuffer` exists.
enum Allocation {
    // we keep these around to keep the allocations alive
    #[cfg(dx12)]
    Dx12 { _dx12: dx12::Dx12Allocation },
    // No payload needed on Vulkan; presumably the buffer handles themselves
    // keep the memory alive — TODO confirm against the `vulkan` module.
    #[cfg(vulkan)]
    Vulkan,
}
177
/// Two views — one wgpu, one OIDN — of the same underlying allocation.
pub struct SharedBuffer {
    // Held only to keep the backing allocation alive; never read.
    _allocation: Allocation,
    // OIDN-side view of the memory.
    oidn_buffer: oidn::Buffer,
    // wgpu-side view of the memory.
    wgpu_buffer: wgpu::Buffer,
}
183
impl SharedBuffer {
    /// Shared access to the OIDN-side view of the buffer.
    pub fn oidn_buffer(&self) -> &oidn::Buffer {
        &self.oidn_buffer
    }
    /// Mutable access to the OIDN-side view (used for reads and in-place
    /// filtering in the tests below).
    pub fn oidn_buffer_mut(&mut self) -> &mut oidn::Buffer {
        &mut self.oidn_buffer
    }
    /// The wgpu-side view of the same memory.
    pub fn wgpu_buffer(&self) -> &wgpu::Buffer {
        &self.wgpu_buffer
    }
}
195
/// Smoke test: for every Vulkan/DX12 adapter, create an interop device, write
/// through the wgpu side of a shared buffer, read the value back through the
/// OIDN side, then run a 1x1 in-place denoise.
#[cfg(test)]
#[async_std::test]
async fn test() {
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: wgpu::Backends::DX12 | wgpu::Backends::VULKAN,
        ..wgpu::InstanceDescriptor::new_without_display_handle()
    });
    let adapters = instance.enumerate_adapters(wgpu::Backends::all()).await;
    for adapter in adapters {
        let info = adapter.get_info();
        match info.backend {
            wgpu::Backend::Vulkan => eprintln!("Testing vulkan device {}", info.name),
            wgpu::Backend::Dx12 => eprintln!("Testing dx12 device {}", info.name),
            // Interop only supports Vulkan and DX12; skip everything else.
            _ => continue,
        }
        let (device, queue) = match Device::new(&adapter, &wgpu::DeviceDescriptor::default()).await
        {
            Ok((device, queue)) => (device, queue),
            Err(err) => {
                eprintln!("Device creation failed");
                eprintln!("    {err:?}");
                continue;
            }
        };
        let mut bufs = device
            .allocate_shared_buffers(size_of::<[f32; 3]>() as wgpu::BufferAddress)
            .unwrap();
        // Write 1.0 into the first f32 via wgpu, then wait for it to land.
        queue.write_buffer(bufs.wgpu_buffer(), 0, &1.0_f32.to_ne_bytes());
        queue.submit([]);
        device
            .wgpu_device()
            .poll(wgpu::PollType::wait_indefinitely())
            .unwrap();
        // The write must be visible through the shared OIDN view of the memory.
        assert_eq!(bufs.oidn_buffer_mut().read()[0], 1.0);
        let mut filter = oidn::RayTracing::new(device.oidn_device());
        filter.image_dimensions(1, 1);
        // Fix: `oidn_buffer_mut()` already yields `&mut oidn::Buffer`; the
        // previous extra `&mut` produced `&mut &mut Buffer` and only compiled
        // via deref coercion (clippy::needless_borrow).
        filter
            .filter_in_place_buffer(bufs.oidn_buffer_mut())
            .unwrap();
        match device.oidn_device().get_error() {
            // Tolerate out-of-memory so the test passes on small devices.
            Ok(_) | Err((oidn::Error::OutOfMemory, _)) => {}
            Err(err) => panic!("{err:?}"),
        }
    }
}
244
// Ensure that dropping one or the other shared buffers does not break anything.
#[cfg(test)]
#[async_std::test]
async fn test_validity() {
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: wgpu::Backends::DX12 | wgpu::Backends::VULKAN,
        ..wgpu::InstanceDescriptor::new_without_display_handle()
    });
    let adapters = instance.enumerate_adapters(wgpu::Backends::all()).await;
    for adapter in adapters {
        match adapter.get_info().backend {
            wgpu::Backend::Vulkan => {
                eprintln!("Testing vulkan device {}", adapter.get_info().name);
            }
            wgpu::Backend::Dx12 => {
                eprintln!("Testing dx12 device {}", adapter.get_info().name);
            }
            // Interop only supports Vulkan and DX12; skip everything else.
            _ => continue,
        }
        let (device, queue) = match Device::new(&adapter, &wgpu::DeviceDescriptor::default()).await
        {
            Ok((device, queue)) => (device, queue),
            Err(err) => {
                eprintln!("Device creation failed");
                eprintln!("    {err:?}");
                continue;
            }
        };
        // Case 1: destroy the wgpu buffer after queuing a write; the OIDN view
        // of the shared memory must still observe the written value.
        {
            let mut bufs = device
                .allocate_shared_buffers(size_of::<[f32; 3]>() as wgpu::BufferAddress)
                .unwrap();
            queue.write_buffer(bufs.wgpu_buffer(), 0, &1.0_f32.to_ne_bytes());
            queue.submit([]);
            // Destroy before polling — the queued write must still complete.
            bufs.wgpu_buffer().destroy();
            device
                .wgpu_device()
                .poll(wgpu::PollType::wait_indefinitely())
                .unwrap();
            assert_eq!(bufs.oidn_buffer_mut().read()[0], 1.0);
            eprintln!("    Tested wgpu destroy");
        }
        // Case 2: drop the SharedBuffer (and with it the OIDN buffer and the
        // allocation guard) while a clone of the wgpu buffer lives on; the
        // clone must remain usable for writes, copies, and mapped readback.
        {
            use std::sync::mpsc;

            use wgpu::{BufferAddress, BufferUsages, PollType, wgt::BufferDescriptor};

            let bufs = device
                .allocate_shared_buffers(size_of::<[f32; 3]>() as wgpu::BufferAddress)
                .unwrap();
            let buffer = bufs.wgpu_buffer().clone();
            drop(bufs);
            queue.write_buffer(&buffer, 0, &1.0_f32.to_ne_bytes());
            // Stage a CPU-mappable buffer to read the value back out.
            let readback_buffer = device.wgpu_device().create_buffer(&BufferDescriptor {
                label: Some("readback"),
                size: size_of::<f32>() as _,
                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
            let mut encoder = device
                .wgpu_device()
                .create_command_encoder(&Default::default());
            encoder.copy_buffer_to_buffer(
                &buffer,
                0,
                &readback_buffer,
                0,
                size_of::<f32>() as BufferAddress,
            );
            // The channel signals that the map-on-submit callback has fired.
            let (send, recv) = mpsc::channel();
            encoder.map_buffer_on_submit(&readback_buffer, wgpu::MapMode::Read, .., move |_| {
                send.send(()).unwrap()
            });
            queue.submit([encoder.finish()]);
            device
                .wgpu_device()
                .poll(PollType::wait_indefinitely())
                .unwrap();
            recv.recv().unwrap();

            // Reassemble the f32 from the mapped bytes and check the round trip.
            let view = readback_buffer.get_mapped_range(..);
            assert_eq!(
                f32::from_ne_bytes([view[0], view[1], view[2], view[3]]),
                1.0_f32
            );

            eprintln!("    Tested oidn drop");
        }
    }
}
334}