Skip to main content

opendefocus_kernel/
lib.rs

1#![warn(unused_extern_crates)]
2#![cfg_attr(target_arch = "spirv", no_std)]
3
4mod datamodel;
5mod stages;
6mod util;
7use glam::UVec2;
8#[cfg(not(any(target_arch = "spirv")))]
9use image::{LumaA, Rgba};
10#[cfg(not(any(target_arch = "spirv")))]
11use opendefocus_shared::cpu_image::{CPUImage, Sampler};
12use opendefocus_shared::{ConvolveSettings, GlobalFlags, ThreadId, math::get_real_coordinates};
13
14#[cfg(target_arch = "spirv")]
15use glam::UVec3;
16#[cfg(target_arch = "spirv")]
17use spirv_std::{
18    Sampler,
19    image::{Image, Image2d},
20    spirv,
21};
22
23use crate::{
24    stages::{
25        ring::{Rings, calculate_ring},
26        sample::get_alpha_map,
27    },
28    util::image::{load_texture, write_texture},
29};
30
31fn skip_overlap(position: UVec2, settings: &ConvolveSettings) -> bool {
32    let coordinates = get_real_coordinates(settings, position);
33    !((coordinates.x as i32) >= settings.process_region.x
34        && (coordinates.x as i32) < settings.process_region.z
35        && (coordinates.y as i32) >= settings.process_region.y - 1
36        && (coordinates.y as i32) <= settings.process_region.w)
37}
38
39#[inline(always)]
40pub fn global_entrypoint(
41    thread_id: ThreadId,
42    output_image: &mut [f32],
43    settings: &ConvolveSettings,
44    cached_samples: &[f32],
45    #[cfg(not(any(target_arch = "spirv")))] input_image: &CPUImage<Rgba<f32>>,
46    #[cfg(not(any(target_arch = "spirv")))] inpaint: &CPUImage<Rgba<f32>>,
47    #[cfg(not(any(target_arch = "spirv")))] filter: &CPUImage<Rgba<f32>>,
48    #[cfg(not(any(target_arch = "spirv")))] depth: &CPUImage<LumaA<f32>>,
49    #[cfg(not(any(target_arch = "spirv")))] bilinear_sampler: &Sampler,
50    #[cfg(not(any(target_arch = "spirv")))] nearest_sampler: &Sampler,
51    #[cfg(target_arch = "spirv")] input_image: &Image2d,
52    #[cfg(target_arch = "spirv")] inpaint: &Image2d,
53    #[cfg(target_arch = "spirv")] filter: &Image2d,
54    #[cfg(target_arch = "spirv")] depth: &Image2d,
55    #[cfg(target_arch = "spirv")] bilinear_sampler: &Sampler,
56    #[cfg(target_arch = "spirv")] nearest_sampler: &Sampler,
57) {
58    let coords = thread_id.get_coordinates();
59    if coords.x >= settings.full_region.z as u32 || coords.y >= settings.full_region.w as u32 {
60        return;
61    }
62    let resolution = settings.get_image_resolution();
63    let coordinates_scale = coords.as_vec2() / settings.get_image_resolution().as_vec2();
64    if skip_overlap(coords, settings) {
65        return;
66    }
67
68    let center_size = if GlobalFlags::from_bits_retain(settings.flags).contains(GlobalFlags::IS_2D)
69    {
70        settings.max_size
71    } else {
72        load_texture(depth, coordinates_scale, nearest_sampler, 0.0).x
73    };
74    let mut main_rings = Rings::new(0, 1.0);
75
76    for ring_id in (0..settings.samples).rev() {
77        let current_rings = calculate_ring(
78            settings,
79            input_image,
80            filter,
81            inpaint,
82            depth,
83            cached_samples,
84            ring_id,
85            center_size,
86            coords,
87            bilinear_sampler,
88            nearest_sampler,
89        );
90        main_rings.merge(cached_samples, current_rings);
91    }
92    main_rings.background.normalize();
93
94    if GlobalFlags::from_bits_retain(settings.flags).contains(GlobalFlags::IS_2D) {
95        let alpha = get_alpha_map(settings, center_size);
96        write_texture(
97            output_image,
98            &[
99                main_rings.background.color.x * alpha,
100                main_rings.background.color.y * alpha,
101                main_rings.background.color.z * alpha,
102                main_rings.background.color.w * alpha,
103                alpha,
104            ],
105            coords,
106            resolution.x,
107        );
108        return;
109    }
110    main_rings.foreground.normalize();
111
112    let mut output = main_rings.background.color;
113    output = main_rings.foreground.color + output * (1.0 - main_rings.foreground.alpha);
114    let alpha = main_rings.foreground.alpha_masked
115        + main_rings.background.alpha_masked * (1.0 - main_rings.foreground.alpha);
116
117    write_texture(
118        output_image,
119        &[
120            output.x,
121            output.y,
122            output.z,
123            output.w,
124            alpha.clamp(0.0, 1.0),
125        ],
126        coords,
127        resolution.x,
128    );
129}
130
/// Vulkan / SPIR-V compute entry point (workgroups of 16×16 invocations).
///
/// The x and y components of `global_invocation_id` become the [`ThreadId`] (z is unused),
/// and every bound resource is forwarded unchanged to [`global_entrypoint`], which does the
/// per-pixel work. Note the binding order here differs from the argument order there:
/// `filter` (binding 5) is passed after `inpaint_image` (binding 6).
#[cfg(target_arch = "spirv")]
#[spirv(compute(threads(16, 16)))]
pub fn convolve_kernel_f32(
    #[spirv(global_invocation_id)] gid: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] settings: &ConvolveSettings,
    #[spirv(descriptor_set = 0, binding = 1)] bilinear_sampler: &Sampler,
    #[spirv(descriptor_set = 0, binding = 2)] nearest_sampler: &Sampler,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] output_image: &mut [f32],
    #[spirv(descriptor_set = 0, binding = 4)] input_image: &Image!(2D, type=f32, sampled),
    #[spirv(descriptor_set = 0, binding = 5)] filter: &Image!(2D, type=f32, sampled),
    #[spirv(descriptor_set = 0, binding = 6)] inpaint_image: &Image!(2D, type=f32, sampled),
    #[spirv(descriptor_set = 0, binding = 7)] depth: &Image!(2D, type=f32, sampled),
    #[spirv(storage_buffer, descriptor_set = 0, binding = 8)] cached_sample_weights: &[f32],
) {
    global_entrypoint(
        ThreadId::new(gid.x, gid.y),
        output_image,
        settings,
        cached_sample_weights,
        input_image,
        inpaint_image,
        filter,
        depth,
        bilinear_sampler,
        nearest_sampler,
    );
}