Skip to main content

agx/engine/gpu/
params.rs

1//! GPU-friendly parameter struct for uploading to uniform buffers.
2
3use crate::adjust::{ColorWheel, VignetteShape};
4use crate::engine::Parameters;
5
6/// Flat, repr(C) parameter struct for GPU uniform buffers.
7/// All fields are f32 or fixed-size f32 arrays — no enums, Options, or pointers.
8/// Field names mirror [`Parameters`] 1:1.
9#[repr(C)]
10#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
11#[allow(missing_docs)]
12pub struct GpuParameters {
13    // Linear adjustments
14    pub exposure: f32,
15    pub temperature: f32,
16    pub tint: f32,
17    pub _pad0: f32,
18
19    // Gamma adjustments — tone
20    pub contrast: f32,
21    pub highlights: f32,
22    pub shadows: f32,
23    pub whites: f32,
24
25    pub blacks: f32,
26    pub _pad1: [f32; 3],
27
28    // HSL — 8 channels x 3 values = 24 floats
29    pub hue_shifts: [f32; 8],
30    pub sat_shifts: [f32; 8],
31    pub lum_shifts: [f32; 8],
32
33    // Color grading — 4 wheels x [r_mult, g_mult, b_mult, luminance] + balance
34    pub cg_shadow_tint: [f32; 4],
35    pub cg_midtone_tint: [f32; 4],
36    pub cg_highlight_tint: [f32; 4],
37    pub cg_global_tint: [f32; 4],
38    pub cg_balance_factor: f32,
39    pub cg_balance_active: f32,
40    pub cg_active: f32,
41    pub _pad2: f32,
42
43    // Vignette
44    pub vignette_amount: f32,
45    pub vignette_shape: f32, // 0.0 = elliptical, 1.0 = circular
46    pub hsl_active: f32,
47    pub _pad3: f32,
48
49    // Dehaze
50    pub dehaze_amount: f32,
51    pub _pad4: [f32; 3],
52
53    // Grain
54    pub grain_amount: f32,
55    pub grain_size: f32,
56    pub grain_type: f32, // 0.0 = Fine, 1.0 = Silver, 2.0 = Harsh
57    pub grain_seed: f32,
58
59    // Tone curve active flags (1.0 = active, 0.0 = inactive)
60    pub tc_rgb_active: f32,
61    pub tc_luma_active: f32,
62    pub tc_red_active: f32,
63    pub tc_green_active: f32,
64    pub tc_blue_active: f32,
65    pub lut_active: f32,
66    pub _pad_tc: [f32; 2],
67
68    // Image dimensions (needed by vignette, grain, etc.)
69    pub width: f32,
70    pub height: f32,
71    pub _pad5: [f32; 2],
72
73    // Detail / unsharp mask (set per-dispatch)
74    pub detail_strength: f32,
75    pub detail_threshold: f32,
76    pub detail_masking: f32,
77    pub kernel_size: f32,
78
79    // Noise reduction (set per-dispatch for channel/level)
80    pub nr_luminance: f32,
81    pub nr_color: f32,
82    pub nr_detail: f32,
83    pub nr_channel: f32,
84
85    pub nr_gap: f32,
86    pub nr_threshold: f32,
87    pub nr_is_luma: f32,
88    pub _pad_nr: f32,
89
90    // Dehaze (set per-dispatch)
91    pub dehaze_airlight_r: f32,
92    pub dehaze_airlight_g: f32,
93    pub dehaze_airlight_b: f32,
94    pub dehaze_omega: f32,
95
96    pub dehaze_filter_radius: f32,
97    pub dehaze_mode: f32, // multi-purpose: pixel_min mode, filter direction, etc.
98    pub _pad_dh: [f32; 2],
99}
100
101impl From<&Parameters> for GpuParameters {
102    fn from(p: &Parameters) -> Self {
103        let shadow_tint = wheel_to_tint_and_lum(&p.color_grading.shadows);
104        let midtone_tint = wheel_to_tint_and_lum(&p.color_grading.midtones);
105        let highlight_tint = wheel_to_tint_and_lum(&p.color_grading.highlights);
106        let global_tint = wheel_to_tint_and_lum(&p.color_grading.global);
107
108        Self {
109            exposure: p.exposure,
110            temperature: p.temperature,
111            tint: p.tint,
112            _pad0: 0.0,
113            contrast: p.contrast,
114            highlights: p.highlights,
115            shadows: p.shadows,
116            whites: p.whites,
117            blacks: p.blacks,
118            _pad1: [0.0; 3],
119            hue_shifts: p.hsl.hue_shifts(),
120            sat_shifts: p.hsl.saturation_shifts(),
121            lum_shifts: p.hsl.luminance_shifts(),
122            cg_shadow_tint: shadow_tint,
123            cg_midtone_tint: midtone_tint,
124            cg_highlight_tint: highlight_tint,
125            cg_global_tint: global_tint,
126            cg_balance_factor: 2.0_f32.powf(-p.color_grading.balance / 100.0),
127            cg_balance_active: if p.color_grading.balance != 0.0 {
128                1.0
129            } else {
130                0.0
131            },
132            cg_active: if p.color_grading.is_default() {
133                0.0
134            } else {
135                1.0
136            },
137            _pad2: 0.0,
138            vignette_amount: p.vignette.amount,
139            vignette_shape: match p.vignette.shape {
140                VignetteShape::Elliptical => 0.0,
141                VignetteShape::Circular => 1.0,
142            },
143            hsl_active: if p.hsl.is_default() { 0.0 } else { 1.0 },
144            _pad3: 0.0,
145            dehaze_amount: p.dehaze.amount,
146            _pad4: [0.0; 3],
147            grain_amount: p.grain.amount,
148            grain_size: p.grain.size,
149            grain_type: match p.grain.grain_type {
150                crate::adjust::grain::GrainType::Fine => 0.0,
151                crate::adjust::grain::GrainType::Silver => 1.0,
152                crate::adjust::grain::GrainType::Harsh => 2.0,
153            },
154            grain_seed: 0.0,
155            tc_rgb_active: if p.tone_curve.rgb.is_identity() {
156                0.0
157            } else {
158                1.0
159            },
160            tc_luma_active: if p.tone_curve.luma.is_identity() {
161                0.0
162            } else {
163                1.0
164            },
165            tc_red_active: if p.tone_curve.red.is_identity() {
166                0.0
167            } else {
168                1.0
169            },
170            tc_green_active: if p.tone_curve.green.is_identity() {
171                0.0
172            } else {
173                1.0
174            },
175            tc_blue_active: if p.tone_curve.blue.is_identity() {
176                0.0
177            } else {
178                1.0
179            },
180            lut_active: 0.0, // set by GpuPipeline when LUT is present
181            _pad_tc: [0.0; 2],
182            width: 0.0,
183            height: 0.0,
184            _pad5: [0.0; 2],
185            detail_strength: 0.0,
186            detail_threshold: 0.0,
187            detail_masking: 0.0,
188            kernel_size: 0.0,
189            nr_luminance: p.noise_reduction.luminance,
190            nr_color: p.noise_reduction.color,
191            nr_detail: p.noise_reduction.detail,
192            nr_channel: 0.0,
193            nr_gap: 1.0,
194            nr_threshold: 0.0,
195            nr_is_luma: 0.0,
196            _pad_nr: 0.0,
197            dehaze_airlight_r: 0.0,
198            dehaze_airlight_g: 0.0,
199            dehaze_airlight_b: 0.0,
200            dehaze_omega: 0.0,
201            dehaze_filter_radius: 0.0,
202            dehaze_mode: 0.0,
203            _pad_dh: [0.0; 2],
204        }
205    }
206}
207
208/// Build the 5x256 tone curve data for GPU upload.
209/// Layout: [rgb_256, luma_256, red_256, green_256, blue_256] contiguous.
210/// Inactive curves are identity (value\[i\] = i / 255.0).
211pub fn build_tone_curve_data(params: &crate::engine::Parameters) -> [f32; 1280] {
212    let mut data = [0.0f32; 1280];
213    let identity: [f32; 256] = std::array::from_fn(|i| i as f32 / 255.0);
214    let curves = [
215        &params.tone_curve.rgb,
216        &params.tone_curve.luma,
217        &params.tone_curve.red,
218        &params.tone_curve.green,
219        &params.tone_curve.blue,
220    ];
221    for (ci, curve) in curves.iter().enumerate() {
222        let lut = if curve.is_identity() {
223            identity
224        } else {
225            crate::adjust::build_tone_curve_lut(curve)
226        };
227        data[ci * 256..(ci + 1) * 256].copy_from_slice(&lut);
228    }
229    data
230}
231
232fn wheel_to_tint_and_lum(wheel: &ColorWheel) -> [f32; 4] {
233    let hue_rad = wheel.hue * std::f32::consts::PI / 180.0;
234    let sat = wheel.saturation / 100.0;
235    [
236        1.0 + sat * hue_rad.cos(),
237        1.0 + sat * (hue_rad - 2.0 * std::f32::consts::PI / 3.0).cos(),
238        1.0 + sat * (hue_rad - 4.0 * std::f32::consts::PI / 3.0).cos(),
239        wheel.luminance / 100.0,
240    ]
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// The converted struct can be viewed as raw bytes, and the view covers
    /// the whole struct (no truncation).
    #[test]
    fn gpu_params_is_pod() {
        let p = Parameters::default();
        let gpu: GpuParameters = (&p).into();
        let bytes: &[u8] = bytemuck::bytes_of(&gpu);
        assert_eq!(bytes.len(), std::mem::size_of::<GpuParameters>());
    }

    #[test]
    fn gpu_params_default_values() {
        let p = Parameters::default();
        let gpu: GpuParameters = (&p).into();
        assert_eq!(gpu.exposure, 0.0);
        assert_eq!(gpu.contrast, 0.0);
        assert_eq!(gpu.temperature, 0.0);
        assert_eq!(gpu.vignette_amount, 0.0);
        assert_eq!(gpu.dehaze_amount, 0.0);
        assert_eq!(gpu.grain_amount, 0.0);
    }

    /// Fields the pipeline fills in per dispatch are left at their
    /// placeholders by the conversion, and all padding is zeroed.
    #[test]
    fn gpu_params_per_dispatch_placeholders_and_padding() {
        let gpu = GpuParameters::from(&Parameters::default());

        assert_eq!(gpu.width, 0.0);
        assert_eq!(gpu.height, 0.0);
        assert_eq!(gpu.lut_active, 0.0);
        assert_eq!(gpu.grain_seed, 0.0);
        assert_eq!(gpu.kernel_size, 0.0);
        assert_eq!(gpu.nr_channel, 0.0);
        assert_eq!(gpu.nr_gap, 1.0);
        assert_eq!(gpu.dehaze_omega, 0.0);
        assert_eq!(gpu.dehaze_mode, 0.0);

        assert_eq!(gpu._pad0, 0.0);
        assert_eq!(gpu._pad1, [0.0; 3]);
        assert_eq!(gpu._pad2, 0.0);
        assert_eq!(gpu._pad3, 0.0);
        assert_eq!(gpu._pad4, [0.0; 3]);
        assert_eq!(gpu._pad_tc, [0.0; 2]);
        assert_eq!(gpu._pad5, [0.0; 2]);
        assert_eq!(gpu._pad_nr, 0.0);
        assert_eq!(gpu._pad_dh, [0.0; 2]);
    }

    /// When balance is inactive (balance == 0), the balance factor
    /// 2^(-balance / 100) must be exactly 1.0.
    #[test]
    fn gpu_params_inactive_balance_is_unity() {
        let gpu = GpuParameters::from(&Parameters::default());
        if gpu.cg_balance_active == 0.0 {
            assert_eq!(gpu.cg_balance_factor, 1.0);
        }
    }

    #[test]
    fn gpu_params_size_is_16_aligned() {
        // WGSL uniform buffers require 16-byte alignment
        assert_eq!(std::mem::size_of::<GpuParameters>() % 16, 0);
    }
}
267    fn gpu_params_size_is_16_aligned() {
268        // WGSL uniform buffers require 16-byte alignment
269        assert_eq!(std::mem::size_of::<GpuParameters>() % 16, 0);
270    }
271}