Skip to main content

oidn_wgpu/
wgpu_integration.rs

1//! Denoise wgpu textures by copying to CPU, running OIDN, and copying back.
2
3use crate::device::OidnDevice;
4use crate::filter::{Quality, RtFilter};
5use crate::Error;
6use bytemuck::{cast_slice, cast_slice_mut};
7use std::sync::mpsc;
8use wgpu::util::DeviceExt;
9
/// Texel formats that can be read from and written back to a wgpu texture for denoising.
///
/// Both variants are four-channel float formats; only the colour channels are denoised and the
/// fourth (alpha) channel is carried through unchanged.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DenoiseTextureFormat {
    /// Four 32-bit float channels (RGBA); alpha passes through untouched.
    Rgba32Float,
    /// Four 16-bit half-float channels (RGBA); alpha passes through untouched.
    Rgba16Float,
}
18
19impl DenoiseTextureFormat {
20    /// Converts from a wgpu texture format if it is supported for denoising.
21    pub fn from_wgpu(format: wgpu::TextureFormat) -> Option<Self> {
22        format.try_into().ok()
23    }
24
25    fn bytes_per_pixel(self) -> u32 {
26        match self {
27            DenoiseTextureFormat::Rgba32Float => 16,
28            DenoiseTextureFormat::Rgba16Float => 8,
29        }
30    }
31}
32
33impl TryFrom<wgpu::TextureFormat> for DenoiseTextureFormat {
34    type Error = ();
35
36    fn try_from(format: wgpu::TextureFormat) -> Result<Self, Self::Error> {
37        match format {
38            wgpu::TextureFormat::Rgba32Float => Ok(Self::Rgba32Float),
39            wgpu::TextureFormat::Rgba16Float => Ok(Self::Rgba16Float),
40            _ => Err(()),
41        }
42    }
43}
44
45/// Options for denoising a wgpu texture.
46#[derive(Clone, Debug)]
47pub struct DenoiseOptions {
48    /// Quality vs performance: `Fast`, `Balanced`, or `High`.
49    pub quality: Quality,
50    /// `true` if the image is HDR (linear, possibly > 1.0).
51    pub hdr: bool,
52    /// `true` if the image is sRGB-encoded LDR.
53    pub srgb: bool,
54    /// Input scale for HDR (e.g. exposure). `None` = auto.
55    pub input_scale: Option<f32>,
56}
57
58impl Default for DenoiseOptions {
59    fn default() -> Self {
60        Self {
61            quality: Quality::Default,
62            hdr: true,
63            srgb: false,
64            input_scale: None,
65        }
66    }
67}
68
69/// Denoises a wgpu texture by readback → OIDN (CPU) → upload.
70///
71/// Input and output can be the same texture for in-place denoising, or different.
72/// Supported formats: [`DenoiseTextureFormat::Rgba32Float`], [`DenoiseTextureFormat::Rgba16Float`].
73/// Only RGB is denoised; alpha is preserved.
74///
75/// **Texture usage:** `input` must have [`TextureUsages::COPY_SRC`](wgpu::TextureUsages::COPY_SRC);
76/// `output` must have [`TextureUsages::COPY_DST`](wgpu::TextureUsages::COPY_DST).
77///
78/// This is a blocking call: it submits copy commands, waits for readback, runs OIDN, then uploads.
79///
80/// # Errors
81///
82/// Returns [`Error::InvalidDimensions`] if texture sizes or array layers are incompatible, or
83/// [`Error::BufferMapFailed`] if wgpu buffer mapping fails. OIDN execution errors are returned
84/// as [`Error::OidnError`] or other [`Error`] variants.
85pub fn denoise_texture(
86    device: &OidnDevice,
87    wgpu_device: &wgpu::Device,
88    wgpu_queue: &wgpu::Queue,
89    input: &wgpu::Texture,
90    output: &wgpu::Texture,
91    format: DenoiseTextureFormat,
92    options: &DenoiseOptions,
93) -> Result<(), Error> {
94    let size = input.size();
95    if size.depth_or_array_layers != 1 {
96        return Err(Error::InvalidDimensions);
97    }
98    let out_size = output.size();
99    if out_size.width != size.width
100        || out_size.height != size.height
101        || out_size.depth_or_array_layers != 1
102    {
103        return Err(Error::InvalidDimensions);
104    }
105    let w = size.width;
106    let h = size.height;
107    let bpp = format.bytes_per_pixel();
108    let bytes_per_row = w * bpp;
109    let alignment = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
110    let padded_bytes_per_row = (bytes_per_row + alignment - 1) / alignment * alignment;
111    let buffer_size = padded_bytes_per_row as u64 * h as u64;
112
113    // Create staging buffer for readback
114    let read_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
115        label: Some("oidn_wgpu readback"),
116        size: buffer_size,
117        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
118        mapped_at_creation: false,
119    });
120
121    let mut encoder = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
122    encoder.copy_texture_to_buffer(
123        wgpu::TexelCopyTextureInfo {
124            texture: input,
125            mip_level: 0,
126            origin: wgpu::Origin3d::ZERO,
127            aspect: wgpu::TextureAspect::All,
128        },
129        wgpu::TexelCopyBufferInfo {
130            buffer: &read_buffer,
131            layout: wgpu::TexelCopyBufferLayout {
132                offset: 0,
133                bytes_per_row: Some(padded_bytes_per_row),
134                rows_per_image: Some(h),
135            },
136        },
137        size,
138    );
139    wgpu_queue.submit(Some(encoder.finish()));
140
141    // Map and read
142    let slice = read_buffer.slice(..);
143    let (tx, rx) = mpsc::channel();
144    slice.map_async(wgpu::MapMode::Read, move |r| {
145        let _ = tx.send(r);
146    });
147    loop {
148        let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
149        match rx.try_recv() {
150            Ok(Ok(())) => break,
151            Ok(Err(_)) => return Err(Error::BufferMapFailed(wgpu::BufferAsyncError)),
152            Err(mpsc::TryRecvError::Disconnected) => return Err(Error::BufferMapFailed(wgpu::BufferAsyncError)),
153            Err(mpsc::TryRecvError::Empty) => std::thread::sleep(std::time::Duration::from_micros(100)),
154        }
155    }
156
157    let mapped = slice.get_mapped_range();
158    let raw: &[u8] = &mapped;
159
160    let n_pixels = (w * h) as usize;
161    let mut color_f32 = vec![0.0f32; n_pixels * 3];
162    let alpha_f32: Vec<f32> = match format {
163        DenoiseTextureFormat::Rgba32Float => {
164            let floats: &[f32] = cast_slice(raw);
165            for i in 0..n_pixels {
166                let j = i * 4;
167                color_f32[i * 3] = floats[j];
168                color_f32[i * 3 + 1] = floats[j + 1];
169                color_f32[i * 3 + 2] = floats[j + 2];
170            }
171            floats.iter().skip(3).step_by(4).copied().collect()
172        }
173        DenoiseTextureFormat::Rgba16Float => {
174            let u16s: &[u16] = cast_slice(raw);
175            for i in 0..n_pixels {
176                let j = i * 4;
177                color_f32[i * 3] = half::f16::from_bits(u16s[j]).to_f32();
178                color_f32[i * 3 + 1] = half::f16::from_bits(u16s[j + 1]).to_f32();
179                color_f32[i * 3 + 2] = half::f16::from_bits(u16s[j + 2]).to_f32();
180            }
181            u16s.iter()
182                .skip(3)
183                .step_by(4)
184                .map(|&b| half::f16::from_bits(b).to_f32())
185                .collect()
186        }
187    };
188
189    drop(mapped);
190
191    // Run OIDN
192    let mut filter = RtFilter::new(device)?;
193    filter
194        .set_dimensions(w, h)
195        .set_hdr(options.hdr)
196        .set_srgb(options.srgb)
197        .set_quality(options.quality);
198    if let Some(scale) = options.input_scale {
199        filter.set_input_scale(scale);
200    }
201    filter.execute_in_place(&mut color_f32)?;
202
203    // Convert back to RGBA (preserve alpha from original)
204    let mut output_bytes: Vec<u8> = vec![0; n_pixels * bpp as usize];
205    match format {
206        DenoiseTextureFormat::Rgba32Float => {
207            let out_f: &mut [f32] = cast_slice_mut(&mut output_bytes);
208            for i in 0..n_pixels {
209                let j = i * 4;
210                out_f[j] = color_f32[i * 3];
211                out_f[j + 1] = color_f32[i * 3 + 1];
212                out_f[j + 2] = color_f32[i * 3 + 2];
213                out_f[j + 3] = alpha_f32.get(i).copied().unwrap_or(1.0);
214            }
215        }
216        DenoiseTextureFormat::Rgba16Float => {
217            let out_u16: &mut [u16] = cast_slice_mut(&mut output_bytes);
218            for i in 0..n_pixels {
219                out_u16[i * 4] = half::f16::from_f32(color_f32[i * 3]).to_bits();
220                out_u16[i * 4 + 1] = half::f16::from_f32(color_f32[i * 3 + 1]).to_bits();
221                out_u16[i * 4 + 2] = half::f16::from_f32(color_f32[i * 3 + 2]).to_bits();
222                out_u16[i * 4 + 3] = half::f16::from_f32(alpha_f32.get(i).copied().unwrap_or(1.0)).to_bits();
223            }
224        }
225    }
226
227    // Upload: buffer must use padded row stride for WebGPU alignment.
228    let mut upload_data = vec![0u8; (padded_bytes_per_row * h) as usize];
229    for row in 0..h {
230        let src_off = (row * bytes_per_row) as usize;
231        let dst_off = (row * padded_bytes_per_row) as usize;
232        upload_data[dst_off..dst_off + bytes_per_row as usize]
233            .copy_from_slice(&output_bytes[src_off..src_off + bytes_per_row as usize]);
234    }
235    let write_buffer = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
236        label: Some("oidn_wgpu upload"),
237        contents: &upload_data,
238        usage: wgpu::BufferUsages::COPY_SRC,
239    });
240
241    let mut enc2 = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
242    enc2.copy_buffer_to_texture(
243        wgpu::TexelCopyBufferInfo {
244            buffer: &write_buffer,
245            layout: wgpu::TexelCopyBufferLayout {
246                offset: 0,
247                bytes_per_row: Some(padded_bytes_per_row),
248                rows_per_image: Some(h),
249            },
250        },
251        wgpu::TexelCopyTextureInfo {
252            texture: output,
253            mip_level: 0,
254            origin: wgpu::Origin3d::ZERO,
255            aspect: wgpu::TextureAspect::All,
256        },
257        size,
258    );
259    wgpu_queue.submit(Some(enc2.finish()));
260
261    Ok(())
262}
263
264/// Denoises a wgpu color texture with optional albedo and normal AOV textures (same size/format as color).
265///
266/// Higher quality when albedo and normal are provided. Otherwise identical to [`denoise_texture`].
267///
268/// # Errors
269///
270/// Same as [`denoise_texture`]; also [`Error::InvalidDimensions`] if any aux texture size does not match.
271pub fn denoise_texture_with_aux(
272    device: &OidnDevice,
273    wgpu_device: &wgpu::Device,
274    wgpu_queue: &wgpu::Queue,
275    input: &wgpu::Texture,
276    output: &wgpu::Texture,
277    format: DenoiseTextureFormat,
278    options: &DenoiseOptions,
279    albedo: Option<&wgpu::Texture>,
280    normal: Option<&wgpu::Texture>,
281) -> Result<(), Error> {
282    let size = input.size();
283    if size.depth_or_array_layers != 1 {
284        return Err(Error::InvalidDimensions);
285    }
286    let out_size = output.size();
287    if out_size.width != size.width
288        || out_size.height != size.height
289        || out_size.depth_or_array_layers != 1
290    {
291        return Err(Error::InvalidDimensions);
292    }
293    if let Some(tex) = albedo {
294        let s = tex.size();
295        if s.width != size.width || s.height != size.height || s.depth_or_array_layers != 1 {
296            return Err(Error::InvalidDimensions);
297        }
298    }
299    if let Some(tex) = normal {
300        let s = tex.size();
301        if s.width != size.width || s.height != size.height || s.depth_or_array_layers != 1 {
302            return Err(Error::InvalidDimensions);
303        }
304    }
305    let w = size.width;
306    let h = size.height;
307
308    let (mut color_rgb, alpha) = read_texture_to_rgba_f32(wgpu_device, wgpu_queue, input, format)?;
309    let albedo_rgb = albedo
310        .map(|t| read_texture_to_rgba_f32(wgpu_device, wgpu_queue, t, format).map(|(rgb, _)| rgb))
311        .transpose()?;
312    let normal_rgb = normal
313        .map(|t| read_texture_to_rgba_f32(wgpu_device, wgpu_queue, t, format).map(|(rgb, _)| rgb))
314        .transpose()?;
315
316    let mut filter = RtFilter::new(device)?;
317    filter
318        .set_dimensions(w, h)
319        .set_hdr(options.hdr)
320        .set_srgb(options.srgb)
321        .set_quality(options.quality);
322    if let Some(scale) = options.input_scale {
323        filter.set_input_scale(scale);
324    }
325    filter.execute_in_place_with_aux(
326        &mut color_rgb,
327        albedo_rgb.as_deref(),
328        normal_rgb.as_deref(),
329    )?;
330
331    upload_rgba_to_texture(
332        wgpu_device,
333        wgpu_queue,
334        output,
335        format,
336        w,
337        h,
338        &color_rgb,
339        &alpha,
340    )
341}
342
343/// Reads a wgpu texture to CPU as (RGB f32, alpha f32). Blocking.
344fn read_texture_to_rgba_f32(
345    wgpu_device: &wgpu::Device,
346    wgpu_queue: &wgpu::Queue,
347    texture: &wgpu::Texture,
348    format: DenoiseTextureFormat,
349) -> Result<(Vec<f32>, Vec<f32>), Error> {
350    let size = texture.size();
351    let w = size.width;
352    let h = size.height;
353    let bpp = format.bytes_per_pixel();
354    let bytes_per_row = w * bpp;
355    let alignment = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
356    let padded_bytes_per_row = (bytes_per_row + alignment - 1) / alignment * alignment;
357    let buffer_size = padded_bytes_per_row as u64 * h as u64;
358    let n_pixels = (w * h) as usize;
359
360    let read_buffer = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
361        label: Some("oidn_wgpu readback"),
362        size: buffer_size,
363        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
364        mapped_at_creation: false,
365    });
366    let mut encoder = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
367    encoder.copy_texture_to_buffer(
368        wgpu::TexelCopyTextureInfo {
369            texture,
370            mip_level: 0,
371            origin: wgpu::Origin3d::ZERO,
372            aspect: wgpu::TextureAspect::All,
373        },
374        wgpu::TexelCopyBufferInfo {
375            buffer: &read_buffer,
376            layout: wgpu::TexelCopyBufferLayout {
377                offset: 0,
378                bytes_per_row: Some(padded_bytes_per_row),
379                rows_per_image: Some(h),
380            },
381        },
382        size,
383    );
384    wgpu_queue.submit(Some(encoder.finish()));
385
386    let slice = read_buffer.slice(..);
387    let (tx, rx) = mpsc::channel();
388    slice.map_async(wgpu::MapMode::Read, move |r| {
389        let _ = tx.send(r);
390    });
391    loop {
392        let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
393        match rx.try_recv() {
394            Ok(Ok(())) => break,
395            Ok(Err(_)) => return Err(Error::BufferMapFailed(wgpu::BufferAsyncError)),
396            Err(mpsc::TryRecvError::Disconnected) => return Err(Error::BufferMapFailed(wgpu::BufferAsyncError)),
397            Err(mpsc::TryRecvError::Empty) => std::thread::sleep(std::time::Duration::from_micros(100)),
398        }
399    }
400    let mapped = slice.get_mapped_range();
401    let raw: &[u8] = &mapped;
402    let mut rgb = vec![0.0f32; n_pixels * 3];
403    let alpha: Vec<f32> = match format {
404        DenoiseTextureFormat::Rgba32Float => {
405            let floats: &[f32] = cast_slice(raw);
406            for i in 0..n_pixels {
407                let j = i * 4;
408                rgb[i * 3] = floats[j];
409                rgb[i * 3 + 1] = floats[j + 1];
410                rgb[i * 3 + 2] = floats[j + 2];
411            }
412            floats.iter().skip(3).step_by(4).copied().collect()
413        }
414        DenoiseTextureFormat::Rgba16Float => {
415            let u16s: &[u16] = cast_slice(raw);
416            for i in 0..n_pixels {
417                let j = i * 4;
418                rgb[i * 3] = half::f16::from_bits(u16s[j]).to_f32();
419                rgb[i * 3 + 1] = half::f16::from_bits(u16s[j + 1]).to_f32();
420                rgb[i * 3 + 2] = half::f16::from_bits(u16s[j + 2]).to_f32();
421            }
422            u16s.iter()
423                .skip(3)
424                .step_by(4)
425                .map(|&b| half::f16::from_bits(b).to_f32())
426                .collect()
427        }
428    };
429    drop(mapped);
430    Ok((rgb, alpha))
431}
432
433/// Uploads denoised RGB + preserved alpha to a wgpu texture (padded row alignment).
434fn upload_rgba_to_texture(
435    wgpu_device: &wgpu::Device,
436    wgpu_queue: &wgpu::Queue,
437    output: &wgpu::Texture,
438    format: DenoiseTextureFormat,
439    w: u32,
440    h: u32,
441    color_f32: &[f32],
442    alpha_f32: &[f32],
443) -> Result<(), Error> {
444    let n_pixels = (w * h) as usize;
445    let bpp = format.bytes_per_pixel();
446    let bytes_per_row = w * bpp;
447    let alignment = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
448    let padded_bytes_per_row = (bytes_per_row + alignment - 1) / alignment * alignment;
449    let mut output_bytes = vec![0u8; n_pixels * bpp as usize];
450    match format {
451        DenoiseTextureFormat::Rgba32Float => {
452            let out_f: &mut [f32] = cast_slice_mut(&mut output_bytes);
453            for i in 0..n_pixels {
454                let j = i * 4;
455                out_f[j] = color_f32[i * 3];
456                out_f[j + 1] = color_f32[i * 3 + 1];
457                out_f[j + 2] = color_f32[i * 3 + 2];
458                out_f[j + 3] = alpha_f32.get(i).copied().unwrap_or(1.0);
459            }
460        }
461        DenoiseTextureFormat::Rgba16Float => {
462            let out_u16: &mut [u16] = cast_slice_mut(&mut output_bytes);
463            for i in 0..n_pixels {
464                out_u16[i * 4] = half::f16::from_f32(color_f32[i * 3]).to_bits();
465                out_u16[i * 4 + 1] = half::f16::from_f32(color_f32[i * 3 + 1]).to_bits();
466                out_u16[i * 4 + 2] = half::f16::from_f32(color_f32[i * 3 + 2]).to_bits();
467                out_u16[i * 4 + 3] = half::f16::from_f32(alpha_f32.get(i).copied().unwrap_or(1.0)).to_bits();
468            }
469        }
470    }
471    let mut upload_data = vec![0u8; (padded_bytes_per_row * h) as usize];
472    for row in 0..h {
473        let src_off = (row * bytes_per_row) as usize;
474        let dst_off = (row * padded_bytes_per_row) as usize;
475        upload_data[dst_off..dst_off + bytes_per_row as usize]
476            .copy_from_slice(&output_bytes[src_off..src_off + bytes_per_row as usize]);
477    }
478    let size = output.size();
479    let write_buffer = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
480        label: Some("oidn_wgpu upload"),
481        contents: &upload_data,
482        usage: wgpu::BufferUsages::COPY_SRC,
483    });
484    let mut enc = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
485    enc.copy_buffer_to_texture(
486        wgpu::TexelCopyBufferInfo {
487            buffer: &write_buffer,
488            layout: wgpu::TexelCopyBufferLayout {
489                offset: 0,
490                bytes_per_row: Some(padded_bytes_per_row),
491                rows_per_image: Some(h),
492            },
493        },
494        wgpu::TexelCopyTextureInfo {
495            texture: output,
496            mip_level: 0,
497            origin: wgpu::Origin3d::ZERO,
498            aspect: wgpu::TextureAspect::All,
499        },
500        size,
501    );
502    wgpu_queue.submit(Some(enc.finish()));
503    Ok(())
504}