Skip to main content

opendefocus_shared/
math.rs

1use crate::internal_settings::ConvolveSettings;
2#[cfg(target_arch = "spirv")]
3use core::arch::asm;
4use core::f32::consts::PI;
5use glam::{UVec2, Vec2, Vec4};
6
/// Number of extra points added to a ring's weight when base points are requested.
pub const BASE_SAMPLES: u32 = 20;

#[must_use]
#[inline]
/// Calculate the base number of points to distribute along a ring.
///
/// Ring `0` always gets exactly one point. For any other ring the weight is
/// `ring_id * 2`, plus [`BASE_SAMPLES`] when `use_base_points` is set, and the
/// result is `samples * 0.1 * weight` converted to `u32` with an `as` cast.
pub fn get_points_for_ring(ring_id: u32, samples: u32, use_base_points: bool) -> u32 {
    if ring_id == 0 {
        return 1;
    }

    let offset = if use_base_points { BASE_SAMPLES } else { 0 };
    let ring_weight = offset + ring_id * 2;

    // Keep the multiplication order: (samples * 0.1) first, then the ring weight.
    let scaled = (samples as f32) * 0.1 * (ring_weight as f32);
    scaled as u32
}
24
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Raises `value` to the power `pow` using the built-in `f32::powf`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn powf(value: f32, pow: f32) -> f32 {
    f32::powf(value, pow)
}
31
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Raises `value` to the power `pow` by delegating to `libm::powf`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn powf(value: f32, pow: f32) -> f32 {
    // `pow` is already an `f32`, so it is passed through without a cast.
    libm::powf(value, pow)
}
38
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Square root of `value` using the built-in `f32::sqrt`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn sqrt(value: f32) -> f32 {
    f32::sqrt(value)
}
45
46#[inline]
47#[cfg(not(feature = "libm"))]
48/// Just a simple port from <https://mazzo.li/posts/vectorized-atan2.html>
49/// This is a fast port mostly without branching
50pub fn atan2f(a: f32, b: f32) -> f32 {
51    let swap = a.abs() > b.abs();
52    let atan_input = if swap { b / a } else { a / b };
53
54    const A1: f32 = 0.99997726;
55    const A3: f32 = -0.33262347;
56    const A5: f32 = 0.19354346;
57    const A7: f32 = -0.11643287;
58    const A9: f32 = 0.05265332;
59    const A11: f32 = -0.01172120;
60
61    let z_sq = atan_input * atan_input;
62    let res =
63        atan_input * (A1 + z_sq * (A3 + z_sq * (A5 + z_sq * (A7 + z_sq * (A9 + z_sq * A11)))));
64
65    let mut res = if_else!(swap, core::f32::consts::FRAC_PI_2 - res, res);
66    res = if_else!(b < 0.0, core::f32::consts::PI - res, res);
67
68    let temp_res_bits = res.to_bits();
69    let temp_y_bits = a.to_bits();
70    let result_bits = (temp_res_bits & 0x7FFFFFFF) | (temp_y_bits & 0x80000000);
71
72    f32::from_bits(result_bits)
73}
74
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Four-quadrant arctangent of `a` and `b`, delegated to `libm::atan2f(a, b)`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn atan2f(a: f32, b: f32) -> f32 {
    libm::atan2f(a, b)
}
80
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Square root of `value`, delegated to `libm::sqrtf`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn sqrt(value: f32) -> f32 {
    libm::sqrtf(value)
}
87
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Cosine of `value` (radians) using the built-in `f32::cos`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn cosf(value: f32) -> f32 {
    f32::cos(value)
}
94
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Cosine of `value`, delegated to `libm::cosf`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn cosf(value: f32) -> f32 {
    libm::cosf(value)
}
101
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Sine of `value` (radians) using the built-in `f32::sin`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn sinf(value: f32) -> f32 {
    f32::sin(value)
}
108
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Sine of `value`, delegated to `libm::sinf`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn sinf(value: f32) -> f32 {
    libm::sinf(value)
}
115
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Largest integer value not greater than `value`, using the built-in `f32::floor`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn floorf(value: f32) -> f32 {
    f32::floor(value)
}
122
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Floor of `value`, delegated to `libm::floorf`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn floorf(value: f32) -> f32 {
    libm::floorf(value)
}
129
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Smallest integer value not less than `value`, using the built-in `f32::ceil`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn ceilf(value: f32) -> f32 {
    f32::ceil(value)
}
136
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Ceiling of `value`, delegated to `libm::ceilf`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn ceilf(value: f32) -> f32 {
    libm::ceilf(value)
}
143
#[must_use]
#[inline]
#[cfg(not(feature = "libm"))]
/// Base-2 logarithm of `value` using the built-in `f32::log2`.
///
/// Compiled only when the `libm` feature is disabled.
pub fn log2f(value: f32) -> f32 {
    f32::log2(value)
}
150
#[must_use]
#[inline]
#[cfg(feature = "libm")]
/// Base-2 logarithm of `value`, delegated to `libm::log2f`.
///
/// Compiled only when the `libm` feature is enabled.
pub fn log2f(value: f32) -> f32 {
    libm::log2f(value)
}
157
158#[must_use]
159#[inline]
160/// Function that allows pow to be in the negative.
161pub fn neg_pow(number: f32, exponent: f32) -> f32 {
162    f32::copysign(powf(number.abs(), exponent), number)
163}
164
165#[must_use]
166#[inline]
167/// Linearly interpolates between two arrays of the same size.
168pub fn mix_vec(a: Vec4, b: Vec4, t: f32) -> Vec4 {
169    a * (1.0 - t) + b * t
170}
171
#[must_use]
#[inline]
#[cfg(not(target_arch = "spirv"))]
/// Linearly interpolates between two values.
///
/// Returns `x * (1 - a) + y * a`; `a` is not clamped. This definition is used on
/// every target except `spirv`.
pub fn mix(x: f32, y: f32, a: f32) -> f32 {
    let weight_x = 1.0 - a;
    x * weight_x + y * a
}
179
#[must_use]
#[inline]
#[cfg(target_arch = "spirv")]
/// Linearly interpolates between two values.
///
/// SPIR-V build of `mix`: instead of computing the blend in Rust it emits the
/// `GLSL.std.450` extended instruction number 46 with operands `x`, `y`, `a`.
pub fn mix(x: f32, y: f32, a: f32) -> f32 {
    let result;

    // NOTE(review): the asm block only loads through the three references passed
    // as inputs and writes `result` — confirm against the rust-gpu asm rules.
    unsafe {
        asm!(
            "%glsl = OpExtInstImport \"GLSL.std.450\"",
            "%float = OpTypeFloat 32",
            "%x = OpLoad _ {x}",
            "%y = OpLoad _ {y}",
            "%a = OpLoad _ {a}",

            // 46 = FMix
            "{result} = OpExtInst %float %glsl 46 %x %y %a",
            x = in(reg) &x,
            y = in(reg) &y,
            a = in(reg) &a,
            result = out(reg) result,
        );
    }

    result
}
206
#[must_use]
#[inline]
/// Clamps `x` to the closed range `[0.0, 1.0]`.
pub fn saturate(x: f32) -> f32 {
    f32::clamp(x, 0.0, 1.0)
}
212
#[must_use]
#[inline]
#[cfg(target_arch = "spirv")]
/// Smooth interpolation of `x` between `edge0` and `edge1`.
///
/// SPIR-V build of `smoothstep`: emits the `GLSL.std.450` extended instruction
/// number 49 with operands `edge0`, `edge1`, `x` instead of computing it in Rust.
pub fn smoothstep(edge0: f32, edge1: f32, x: f32) -> f32 {
    let result;

    // NOTE(review): the asm block only loads through the three references passed
    // as inputs and writes `result` — confirm against the rust-gpu asm rules.
    unsafe {
        asm!(
            "%glsl = OpExtInstImport \"GLSL.std.450\"",
            "%float = OpTypeFloat 32",
            "%edge0 = OpLoad _ {edge0}",
            "%edge1 = OpLoad _ {edge1}",
            "%x = OpLoad _ {x}",

            // 49 = SmoothStep
            "{result} = OpExtInst %float %glsl 49 %edge0 %edge1 %x",
            edge0 = in(reg) &edge0,
            edge1 = in(reg) &edge1,
            x = in(reg) &x,
            result = out(reg) result,
        );
    }

    result
}
238
239#[must_use]
240#[inline]
241#[cfg(not(target_arch = "spirv"))]
242pub fn smoothstep(edge0: f32, edge1: f32, x: f32) -> f32 {
243    if edge1 == edge0 {
244        return 0.0;
245    }
246    let x = saturate((x - edge0) / (edge1 - edge0));
247    x * x * (3.0 - 2.0 * x)
248}
249
250#[must_use]
251#[inline]
252/// Calculate the coordinates in X and Y on a circle by the radius.
253pub fn get_coordinates_on_circle(angle: f32, radius: f32) -> Vec2 {
254    let theta = angle.to_radians();
255
256    Vec2::new(radius * cosf(theta), radius * sinf(theta))
257}
258
259#[must_use]
260#[inline]
261/// Get the interpolated sample weight based on Circle of Confusion (CoC)
262pub fn get_sample_weight(cached_samples: &[f32], coc: f32) -> f32 {
263    let bottom = floorf(coc.abs()) as usize;
264    let top = ceilf(coc.abs()) as usize;
265    let bottom_value = cached_samples[bottom];
266    let top_value = cached_samples[top];
267
268    let interpolation_weight = coc.abs() - floorf(coc.abs());
269    mix(bottom_value, top_value, interpolation_weight)
270}
271
#[must_use]
#[inline]
/// Calculate the coverage weight for a pixel position in a ring.
///
/// The weight is the reciprocal of the arc length between two neighbouring
/// samples on a ring of the given `radius` holding `total_points` samples. The arc
/// length is floored at `0.0001` before inverting, and a ring without points
/// yields `1.0`.
pub fn calculate_coverage_weight(radius: f32, total_points: u32) -> f32 {
    match total_points {
        0 => 1.0,
        points => {
            let degrees_per_sample = 360.0 / points as f32;
            // Degrees to radians, multiplied in the same order as radius * angle * factor.
            let arc_length = radius * degrees_per_sample * (PI / 180.0);
            1.0 / arc_length.max(0.0001)
        }
    }
}
287
288#[must_use]
289#[inline]
290/// Calculate the normalized coverage weight for a pixel position in a ring
291pub fn calculate_normalized_coverage_weight(
292    radius: f32,
293    base_points: u32,
294    total_points: u32,
295) -> f32 {
296    if base_points == 0 || total_points == 0 {
297        return 1.0;
298    }
299
300    let base_weight = calculate_coverage_weight(radius, base_points);
301    let total_weight = calculate_coverage_weight(radius, total_points);
302
303    total_weight / base_weight
304}
305
306#[must_use]
307#[inline]
308/// Get the coordinates in screen-space based on the full region
309pub fn get_real_coordinates(settings: &ConvolveSettings, coordinates: UVec2) -> UVec2 {
310    UVec2::new(
311        settings.full_region.x as u32 + coordinates.x,
312        settings.full_region.y as u32 + coordinates.y,
313    )
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // Checks `atan2f` against reference angles with a 1e-3 tolerance.
    // The cases cover the positive axes and the first-quadrant diagonal only.
    #[rstest]
    #[case(1.0, 0.0, 1.5707964)]
    #[case(0.0, 1.0, 0.0)]
    #[case(1.0, 1.0, core::f32::consts::FRAC_PI_4)]
    fn test_atan2(#[case] a: f32, #[case] b: f32, #[case] expected: f32) {
        let result = atan2f(a, b);
        println!("result: '{result}', expected: '{expected}'");
        assert!((result - expected).abs() < 1e-3, "Difference is too large");
    }

    // Checks `neg_pow` for sign handling; results are compared for exact equality.
    #[rstest]
    #[case(10.0, 1.0, 10.0)] // Positive number with positive exponent
    #[case(-10.0, 1.0, -10.0)] // Negative number with positive exponent
    #[case(10.0, -1.0, 0.1)] // Positive number with negative exponent
    #[case(-10.0, -1.0, -0.1)] // Negative number with negative exponent
    #[case(0.0, 2.0, 0.0)] // Zero with positive exponent
    #[case(2.0, 0.5, core::f32::consts::SQRT_2)] // Positive number with fractional exponent
    #[case(1e-6, 2.0, 1e-12)] // Very small number with positive exponent
    #[case(1e6, -2.0, 1e-12)] // Very large number with negative exponent

    fn test_neg_pow(#[case] input: f32, #[case] exponent: f32, #[case] expected: f32) {
        assert_eq!(neg_pow(input, exponent), expected)
    }
    // Checks `smoothstep` at the boundaries and for degenerate edge ranges.
    #[rstest]
    #[case(0.0, 1.0, 0.0, 0.0)] // x equal to edge0
    #[case(0.0, 1.0, 0.5, 0.5)] // x between edge0 and edge1
    #[case(2.0, 1.0, 0.5, 1.0)] // reversed edges (edge0 > edge1), x past edge1: saturates to 1.0
    #[case(0.0, 0.0, 0.0, 0.0)] // edge0 == edge1
    fn test_smoothstep(
        #[case] edge0: f32,
        #[case] edge1: f32,
        #[case] x: f32,
        #[case] expected: f32,
    ) {
        assert_eq!(smoothstep(edge0, edge1, x), expected);
    }
}
353        #[case] expected: f32,
354    ) {
355        assert_eq!(smoothstep(edge0, edge1, x), expected);
356    }
357}