// rustautogui/normalized_x_corr/fft_ncc.rs

/*
 * Fast normalized cross-correlation (NCC) template matching.
 * Algorithm by J. P. Lewis:
 * http://scribblethink.org/Work/nvisionInterface/vi95_lewis.pdf
 */
6
7use crate::imgtools;
8use core::cmp::max;
9use image::{ImageBuffer, Luma};
10use rayon::prelude::*;
11use rustfft::{num_complex::Complex, Fft, FftPlanner};
12
13use super::{compute_integral_images, sum_region};
14
15pub fn fft_ncc(
16    image: &ImageBuffer<Luma<u8>, Vec<u8>>,
17    precision: f32,
18    prepared_data: &(Vec<Complex<f32>>, f32, u32, u32, u32),
19) -> Vec<(u32, u32, f64)> {
20    // retreive all precalculated template data, most importantly template with already fft and conjugation calculated
21    // sum squared deviations will be needed for denominator
22    let (
23        template_conj_freq,
24        template_sum_squared_deviations,
25        template_width,
26        template_height,
27        padded_size,
28    ) = prepared_data;
29
30    let mut planner = FftPlanner::<f32>::new();
31    let fft: std::sync::Arc<dyn Fft<f32>> =
32        planner.plan_fft_forward((padded_size * padded_size) as usize);
33    let (image_width, image_height) = image.dimensions();
34    let image_vec: Vec<Vec<u8>> = imgtools::imagebuffer_to_vec(&image);
35
36    if (image_width < *template_width) || (image_height < *template_height) {
37        return Vec::new();
38    }
39
40    // compute needed integral images for denominator calculation
41    let (image_integral, squared_image_integral) = compute_integral_images(&image_vec);
42
43    //// calculating zero mean image
44    let sum_image: u64 = sum_region(&image_integral, 0, 0, image_width, image_height);
45
46    // calculating zero mean image , meaning image pixel values - image zero value
47    let image_average_total = sum_image as f32 / (image_height * image_width) as f32; //@audit check image_height*image_width != 0
48    let mut zero_mean_image: Vec<Vec<f32>> =
49        vec![vec![0.0; image_width as usize]; image_height as usize];
50    for y in 0..image_height {
51        for x in 0..image_width {
52            let image_pixel_value = image.get_pixel(x, y)[0] as f32;
53            zero_mean_image[y as usize][x as usize] = image_pixel_value - image_average_total;
54        }
55    }
56
57    // padding to least squares and placing image in top left corner, same as template
58    let mut image_padded: Vec<Complex<f32>> =
59        vec![Complex::new(0.0, 0.0); (padded_size * padded_size) as usize];
60    for dy in 0..image_height {
61        for dx in 0..image_width {
62            let image_pixel_value = zero_mean_image[dy as usize][dx as usize];
63            image_padded[dy as usize * *padded_size as usize + dx as usize] =
64                Complex::new(image_pixel_value, 0.0);
65        }
66    }
67
68    // conver image into frequency domain
69    let ifft: std::sync::Arc<dyn Fft<f32>> =
70        planner.plan_fft_inverse((padded_size * padded_size) as usize);
71    fft.process(&mut image_padded);
72
73    // calculate F(image) * F(template).conjugate
74    let product_freq: Vec<Complex<f32>> = image_padded
75        .iter()
76        .zip(template_conj_freq.iter())
77        .map(|(&img_val, &tmpl_val)| img_val * tmpl_val)
78        .collect();
79    // do inverse fft
80    let mut fft_result: Vec<Complex<f32>> = product_freq.clone();
81    ifft.process(&mut fft_result);
82
83    // flatten for multithreading
84    let coords: Vec<(u32, u32)> = (0..=(image_height - template_height)) //@audit could underflow if image_height = 0
85        .flat_map(|y| (0..=(image_width - template_width)).map(move |x| (x, y))) //@audit could underflow if image_width = 0
86        .collect();
87    // multithreading pixel by pixel template sliding, where correlations are filtered by precision
88    // sending all needed data to calculate nominator and denominator at each of pixel positions
89    let mut found_points: Vec<(u32, u32, f64)> = coords
90        .par_iter()
91        .map(|&(x, y)| {
92            let corr = fft_correlation_calculation(
93                &image_integral,
94                &squared_image_integral,
95                *template_width,
96                *template_height,
97                *template_sum_squared_deviations,
98                x,
99                y,
100                *padded_size,
101                &fft_result,
102            );
103
104            (x, y, corr)
105        })
106        .filter(|&(_, _, corr)| corr > precision as f64)
107        .collect();
108    found_points.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
109
110    found_points
111}
112
113#[allow(dead_code)]
114fn fft_correlation_calculation(
115    image_integral: &[Vec<u64>],
116    squared_image_integral: &[Vec<u64>],
117    template_width: u32,
118    template_height: u32,
119    template_sum_squared_deviations: f32,
120    x: u32, // big image x value
121    y: u32, // big image y value,
122    padded_size: u32,
123    fft_result: &[Complex<f32>],
124) -> f64 {
125    /// Function for calculation of correlation at each pixel position
126    ////////// denominator calculation
127    let sum_image: u64 = sum_region(image_integral, x, y, template_width, template_height);
128
129    let sum_squared_image: u64 = sum_region(
130        squared_image_integral,
131        x,
132        y,
133        template_width,
134        template_height,
135    );
136    let image_sum_squared_deviations = sum_squared_image as f64
137        - (sum_image as f64).powi(2) / (template_height * template_width) as f64; //@audit check template_height*template_width!=0
138    let denominator =
139        (image_sum_squared_deviations * template_sum_squared_deviations as f64).sqrt();
140
141    /////////////// NOMINATOR CALCULATION
142
143    // fft result is calculated invert of whole image and template that were padded and zero valued
144    // each pixel position shows value for that template position
145    let numerator_value =
146        fft_result[(y * padded_size) as usize + x as usize].re / (padded_size * padded_size) as f32; //@audit guess the padded_size is always non zero but could be checked
147    let mut corr = numerator_value as f64 / denominator;
148
149    if corr > 2.0 {
150        corr = -100.0;
151    }
152    corr
153}
154
155pub fn prepare_template_picture(
156    template: &ImageBuffer<Luma<u8>, Vec<u8>>,
157    image_width: u32,
158    image_height: u32,
159) -> (Vec<Complex<f32>>, f32, u32, u32, u32) {
160    /// precalculate all the neccessary data so its not slowing down main process
161    /// returning template in frequency domain, with calculated conjugate
162    let (template_width, template_height) = template.dimensions();
163    let padded_width = image_width.next_power_of_two();
164    let padded_height = image_height.next_power_of_two();
165    let padded_size = max(padded_width, padded_height);
166
167    let mut sum_template = 0.0;
168    // calculate needed sums
169    for y in 0..template_height {
170        for x in 0..template_width {
171            let template_value = template.get_pixel(x, y)[0] as f32;
172            sum_template += template_value;
173        }
174    }
175    let mean_template_value = sum_template / (template_height * template_width) as f32;
176    // create zero mean template
177    let mut zero_mean_template: Vec<Vec<f32>> =
178        vec![vec![0.0; template_width as usize]; template_height as usize];
179    let mut template_sum_squared_deviations: f32 = 0.0;
180    for y in 0..template_height {
181        for x in 0..template_width {
182            let template_value = template.get_pixel(x, y)[0] as f32;
183            let squared_deviation = (template_value - mean_template_value as f32).powf(2.0);
184            template_sum_squared_deviations += squared_deviation;
185
186            // set zero mean value on new template
187            zero_mean_template[y as usize][x as usize] = template_value - mean_template_value;
188        }
189    }
190    // pad the zero mean template
191    let mut template_padded: Vec<Complex<f32>> =
192        vec![Complex::new(0.0, 0.0); (padded_size * padded_size) as usize];
193    for dy in 0..template_height {
194        for dx in 0..template_width {
195            let template_pixel_value = zero_mean_template[dy as usize][dx as usize];
196            template_padded[dy as usize * padded_size as usize + dx as usize] =
197                Complex::new(template_pixel_value, 0.0);
198        }
199    }
200    // convert template to frequency domain
201    let mut planner = FftPlanner::<f32>::new();
202    let fft: std::sync::Arc<dyn Fft<f32>> =
203        planner.plan_fft_forward((padded_size * padded_size) as usize);
204    fft.process(&mut template_padded);
205    // calculate template conjugate
206    let template_conj_freq: Vec<Complex<f32>> =
207        template_padded.iter().map(|&val| val.conj()).collect();
208    (
209        template_conj_freq,
210        template_sum_squared_deviations,
211        template_width,
212        template_height,
213        padded_size,
214    )
215}