// pineapple_core/mp/zernike.rs
//
// Copyright (c) 2025, Tom Ouellette
// Licensed under the BSD 3-Clause License
4use std::ops::Deref;
5
6use num::{FromPrimitive, ToPrimitive, complex::Complex};
7
8use crate::{constant::FACTORIAL, im::PineappleViewBuffer};
9
10#[inline]
11fn radial_polynomial(n: usize, m: usize, r: &mut [Complex<f32>]) {
12    let nf = n as i32;
13    let nsm = (n - m) / 2;
14    let nam = (n + m) / 2;
15
16    for ri in r.iter_mut() {
17        let mut r_nm_i = Complex::new(0.0, 0.0);
18        for si in 0..=nsm {
19            let sf = si as f32;
20            let exp = nf - 2 * si as i32;
21
22            let v = ((-1.0f32).powf(sf) * FACTORIAL[n - si])
23                / (FACTORIAL[si] * FACTORIAL[nam - si] * FACTORIAL[nsm - si]);
24
25            let pow_term = if exp >= 0 {
26                ri.powi(exp)
27            } else {
28                ri.powf(exp as f32)
29            };
30            r_nm_i += Complex::new(v, 0.0) * pow_term;
31        }
32        *ri = r_nm_i;
33    }
34}
35
36#[inline]
37fn zernike_polynomial(n: usize, m: usize, r: &mut [Complex<f32>], theta: &[f32]) {
38    let m_theta = Complex::new(0.0, m as f32);
39    radial_polynomial(n, m, r);
40    for (ri, &theta_i) in r.iter_mut().zip(theta.iter()) {
41        *ri *= (m_theta * theta_i).exp();
42    }
43}
44
45#[inline]
46pub fn zernike_moments<T>(pixels: &[T], width: usize, height: usize, n: usize, m: usize) -> f32
47where
48    T: ToPrimitive,
49{
50    let half_width = width as f32 / 2.0;
51    let half_height = height as f32 / 2.0;
52
53    let mut total_mass = 0.0;
54
55    let capacity = pixels.len();
56
57    let mut circle: Vec<f32> = Vec::with_capacity(capacity);
58    let mut theta = Vec::with_capacity(capacity);
59    let mut r: Vec<Complex<f32>> = Vec::with_capacity(capacity);
60
61    for (i, pixel) in pixels.iter().enumerate() {
62        let x_norm = ((i % width) as f32 - half_width) / half_width;
63        let y_norm = ((i / width) as f32 - half_height) / half_height;
64        let r_i = (x_norm * x_norm + y_norm * y_norm).sqrt();
65
66        if r_i <= 1.0 {
67            let pixel = pixel.to_f32().unwrap();
68            total_mass += pixel;
69            theta.push(y_norm.atan2(x_norm));
70            r.push(Complex::new(r_i, 0.0));
71            circle.push(pixel);
72        }
73    }
74
75    if total_mass == 0.0 {
76        return 0.0;
77    }
78
79    zernike_polynomial(n, m, &mut r, &theta);
80
81    let inv_mass = 1.0 / total_mass;
82    let mut a_nm = Complex::new(0.0, 0.0);
83
84    for (i, z_nm_i) in r.iter().enumerate() {
85        a_nm += z_nm_i.conj() * Complex::new(circle[i] * inv_mass, 0.0);
86    }
87
88    a_nm *= Complex::new((n as f32 + 1.0) / std::f32::consts::PI, 0.0);
89
90    (a_nm.re.powi(2) + a_nm.im.powi(2)).sqrt()
91}
92
93#[inline]
94pub fn descriptors<T>(pixels: &[T], width: usize, height: usize) -> [f32; 30]
95where
96    T: ToPrimitive,
97{
98    let mut descriptors: [f32; 30] = [0.0; 30];
99    let mut i = 0;
100    for n in 0..=9 {
101        for m in 0..=n {
102            if (n - m) % 2 == 0 {
103                descriptors[i] = zernike_moments(pixels, width, height, n, m);
104                i += 1;
105            }
106        }
107    }
108
109    descriptors
110}
111
112#[inline]
113pub fn zernike_moments_object<T, Container>(
114    object: &PineappleViewBuffer<T, Container>,
115    n: usize,
116    m: usize,
117) -> f32
118where
119    T: ToPrimitive + FromPrimitive,
120    Container: Deref<Target = [T]>,
121{
122    let width = object.width();
123    let half_width = width as f32 / 2.0;
124    let half_height = object.height() as f32 / 2.0;
125
126    let mut total_mass = 0.0;
127
128    let capacity = object.len();
129
130    let mut circle: Vec<f32> = Vec::with_capacity(capacity);
131    let mut theta = Vec::with_capacity(capacity);
132    let mut r: Vec<Complex<f32>> = Vec::with_capacity(capacity);
133
134    for (i, pixel) in object.iter().enumerate() {
135        let x_norm = ((i % width) as f32 - half_width) / half_width;
136        let y_norm = ((i / width) as f32 - half_height) / half_height;
137        let r_i = (x_norm * x_norm + y_norm * y_norm).sqrt();
138
139        if r_i <= 1.0 {
140            let pixel = pixel.to_f32().unwrap();
141            total_mass += pixel;
142            theta.push(y_norm.atan2(x_norm));
143            r.push(Complex::new(r_i, 0.0));
144            circle.push(pixel);
145        }
146    }
147
148    if total_mass == 0.0 {
149        return 0.0;
150    }
151
152    zernike_polynomial(n, m, &mut r, &theta);
153
154    let inv_mass = 1.0 / total_mass;
155    let mut a_nm = Complex::new(0.0, 0.0);
156
157    for (i, z_nm_i) in r.iter().enumerate() {
158        a_nm += z_nm_i.conj() * Complex::new(circle[i] * inv_mass, 0.0);
159    }
160
161    a_nm *= Complex::new((n as f32 + 1.0) / std::f32::consts::PI, 0.0);
162
163    (a_nm.re.powi(2) + a_nm.im.powi(2)).sqrt()
164}
165
166#[inline]
167pub fn objects<T, Container>(object: &PineappleViewBuffer<T, Container>) -> [f32; 30]
168where
169    T: ToPrimitive + FromPrimitive,
170    Container: Deref<Target = [T]>,
171{
172    let mut descriptors: [f32; 30] = [0.0; 30];
173    let mut i = 0;
174    for n in 0..=9 {
175        for m in 0..=n {
176            if (n - m) % 2 == 0 {
177                descriptors[i] = zernike_moments_object(object, n, m);
178                i += 1;
179            }
180        }
181    }
182
183    descriptors
184}
185
#[cfg(test)]
mod test {

    use super::*;

    /// Asserts that two complex values agree in both parts within `tol`.
    fn assert_close(res: &Complex<f32>, exp: &Complex<f32>, tol: f32) {
        assert!((res.re - exp.re).abs() < tol, "re: {} vs {}", res.re, exp.re);
        assert!((res.im - exp.im).abs() < tol, "im: {} vs {}", res.im, exp.im);
    }

    /// A 1x1 image maps its only pixel to (-1, -1), outside the unit disk, so every
    /// moment is zero.
    #[test]
    fn test_zernike_moment() {
        let zm = descriptors(&[1], 1, 1);
        for i in zm.iter() {
            assert_eq!(*i, 0.0);
        }
    }

    /// With V_00 = 1 and intensities normalised to sum to one, |A_00| = 1/π.
    #[test]
    fn test_zernike_moment_uniform_disk() {
        let pixels = [1u8; 16];
        let zm = descriptors(&pixels, 4, 4);
        assert!((zm[0] - 1.0 / std::f32::consts::PI).abs() < 1e-5);
    }

    /// R_2^0(ρ) = 2ρ² - 1.
    #[test]
    fn test_radial_polynomial() {
        let mut r = vec![
            Complex::new(0.0, 0.0),
            Complex::new(0.5, 0.0),
            Complex::new(1.0, 0.0),
        ];

        radial_polynomial(2, 0, &mut r);

        let expected = [
            Complex::new(-1.0, 0.0),
            Complex::new(-0.5, 0.0),
            Complex::new(1.0, 0.0),
        ];

        for (res, exp) in r.iter().zip(expected.iter()) {
            assert_close(res, exp, 1e-6);
        }
    }

    /// R_1^1(ρ) = ρ and R_3^1(ρ) = 3ρ³ - 2ρ.
    #[test]
    fn test_radial_polynomial_higher_orders() {
        let radii = [0.0f32, 0.5, 1.0];

        let mut r: Vec<Complex<f32>> = radii.iter().map(|&x| Complex::new(x, 0.0)).collect();
        radial_polynomial(1, 1, &mut r);
        for (res, &x) in r.iter().zip(radii.iter()) {
            assert_close(res, &Complex::new(x, 0.0), 1e-6);
        }

        let mut r: Vec<Complex<f32>> = radii.iter().map(|&x| Complex::new(x, 0.0)).collect();
        radial_polynomial(3, 1, &mut r);
        let expected = [
            Complex::new(0.0, 0.0),
            Complex::new(-0.625, 0.0),
            Complex::new(1.0, 0.0),
        ];
        for (res, exp) in r.iter().zip(expected.iter()) {
            assert_close(res, exp, 1e-6);
        }
    }

    /// V_2^2(1, θ) = e^{2iθ}.
    #[test]
    fn test_zernike_polynomial() {
        let mut r = vec![
            Complex::new(1.0, 0.0),
            Complex::new(1.0, 0.0),
            Complex::new(1.0, 0.0),
        ];

        let theta = vec![
            0.0,
            std::f32::consts::FRAC_PI_4,
            std::f32::consts::FRAC_PI_2,
        ];

        zernike_polynomial(2, 2, &mut r, &theta);

        let expected = [
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(-1.0, 0.0),
        ];

        for (res, exp) in r.iter().zip(expected.iter()) {
            assert_close(res, exp, 1e-6);
        }
    }
}