rustautogui/core/template_match/
opencl_kernel.rs

1#[cfg(feature = "opencl")]
2pub const OCL_KERNEL: &str = r#"
// Sum of the source values inside the rectangle whose top-left corner is
// (x, y) and whose size is width x height, looked up in an inclusive
// integral image stored row-major with a row stride of image_width.
//
// Corners that would lie left of column 0 or above row 0 contribute zero.
// No bounds checking is performed: the caller must keep the rectangle
// inside the table.
inline ulong sum_region(
    __global const ulong* integral,
    int x,
    int y,
    int width,
    int height,
    int image_width
) {
    const int right = x + width - 1;
    const int bottom = y + height - 1;
    const int bottom_row = bottom * image_width;

    ulong bottom_right = integral[bottom_row + right];
    ulong bottom_left = 0;
    ulong top_right = 0;
    ulong top_left = 0;

    if (x != 0) {
        bottom_left = integral[bottom_row + (x - 1)];
    }
    if (y != 0) {
        const int above_row = (y - 1) * image_width;
        top_right = integral[above_row + right];
        if (x != 0) {
            top_left = integral[above_row + (x - 1)];
        }
    }

    // Inclusion-exclusion over the four corners, done in signed arithmetic
    // and converted back to unsigned on return.
    long total = (long)bottom_right + (long)top_left - (long)bottom_left - (long)top_right;
    return (ulong)total;
}
23
24
// Rectangle sum over the squared-value integral image.
//
// Returns the total of integral_sq's underlying values inside the
// width x height rectangle anchored at (x, y). The table is row-major
// with image_width entries per row. When the rectangle touches the left
// or top edge, the corresponding corner terms are taken as zero.
// The caller is responsible for keeping the rectangle inside the table.
inline ulong sum_region_squared(
    __global const ulong* integral_sq,
    int x,
    int y,
    int width,
    int height,
    int image_width
) {
    const int last_col = x + width - 1;
    const int last_row = y + height - 1;
    const int has_left = (x != 0);
    const int has_above = (y != 0);

    const ulong corner_br = integral_sq[last_row * image_width + last_col];
    const ulong corner_bl = has_left ? integral_sq[last_row * image_width + (x - 1)] : 0;
    const ulong corner_tr = has_above ? integral_sq[(y - 1) * image_width + last_col] : 0;
    const ulong corner_tl = (has_left && has_above) ? integral_sq[(y - 1) * image_width + (x - 1)] : 0;

    long acc = (long)corner_br;
    acc = acc + (long)corner_tl;
    acc = acc - (long)corner_bl;
    acc = acc - (long)corner_tr;
    return (ulong)acc;
}
43
44
// Segmented normalized cross-correlation, one work-item per candidate
// top-left position of the template inside the image.
//
// For position (x, y) the kernel first evaluates the correlation with the
// coarse (fast) segment set. If that value is below
// (min_expected_corr - 0.001) * precision it is stored in results[idx] and
// the work-item stops. Otherwise the correlation is re-evaluated with the
// finer (slow) segment set and that value is stored instead.
//
// Parameters:
//   integral, integral_sq    integral images of the pixel values and of the
//                            squared pixel values (row stride image_width)
//   segments, segments_slow  rectangles as (x, y, width, height) relative to
//                            the template origin
//   segment_values(_slow)    value associated with each rectangle
//   num_segments(_slow)      number of entries in each segment list
//   template_mean(_slow)     mean of the respective segmented template
//   template_sq_dev(_slow)   sum of squared deviations of the respective
//                            segmented template
//   results                  one float per candidate position, indexed
//                            idx = y * result_width + x
//   precision_buff           precision_buff[0] scales the early-exit threshold
//
// A correlation of -1 is reported when the denominator is exactly zero.
__kernel void segmented_match_integral(
    __global const ulong* integral,
    __global const ulong* integral_sq,
    __global const int4* segments,
    __global const int4* segments_slow,
    __global const float* segment_values,
    __global const float* segment_values_slow,
    const int num_segments,
    const int num_segments_slow,
    const float template_mean,
    const float template_mean_slow,
    const float template_sq_dev,
    const float template_sq_dev_slow,
    __global float* results,
    const int image_width,
    const int image_height,
    const int template_width,
    const int template_height,
    const float min_expected_corr,
    __global float* precision_buff
) {
    float precision = precision_buff[0];
    int idx = get_global_id(0);
    int result_width = image_width - template_width + 1;
    int result_height = image_height - template_height + 1;

    // Work-items beyond the last candidate position have nothing to do.
    if (idx >= result_width * result_height) return;

    int x = idx % result_width;
    int y = idx / result_width;

    ulong patch_sum = sum_region(integral, x, y, template_width, template_height, image_width);
    ulong patch_sq_sum = sum_region_squared(integral_sq, x, y, template_width, template_height, image_width);

    float area = (float)(template_width * template_height);
    float mean_img = (float)(patch_sum) / area;
    // Sum of squared deviations of the image patch from its own mean.
    float var_img = (float)(patch_sq_sum) - ((float)(patch_sum) * (float)(patch_sum)) / area;

    // ---- fast pass: coarse segments --------------------------------------
    float nominator = 0.0f;
    for (int i = 0; i < num_segments; i++) {
        int4 seg = segments[i];
        float seg_val = segment_values[i];
        int seg_area = seg.z * seg.w;

        ulong region_sum = sum_region(integral, x + seg.x, y + seg.y, seg.z, seg.w, image_width);

        nominator += ((float)(region_sum) - mean_img * seg_area) * (seg_val - template_mean);
    }

    float denominator = sqrt(var_img * template_sq_dev);
    float corr = (denominator != 0.0f) ? (nominator / denominator) : -1.0f;

    // Float literal keeps the comparison in single precision, so the kernel
    // does not depend on double-precision support.
    if (corr < (min_expected_corr - 0.001f) * precision) {
        results[idx] = corr;
        return;
    }

    // ---- slow pass: fine segments ----------------------------------------
    float denominator_slow = sqrt(var_img * template_sq_dev_slow);
    float nominator_slow = 0.0f;
    for (int i = 0; i < num_segments_slow; i++) {
        int4 seg_slow = segments_slow[i];
        float seg_val_slow = segment_values_slow[i];
        int seg_area = seg_slow.z * seg_slow.w;

        ulong region_sum = sum_region(integral, x + seg_slow.x, y + seg_slow.y, seg_slow.z, seg_slow.w, image_width);

        // The slow segment values are centred on the slow template's mean,
        // matching the slow denominator (template_sq_dev_slow).
        nominator_slow += ((float)(region_sum) - mean_img * seg_area) * (seg_val_slow - template_mean_slow);
    }
    float corr_slow = (denominator_slow != 0.0f) ? (nominator_slow / denominator_slow) : -1.0f;
    results[idx] = corr_slow;
}
122
123
// Fast (coarse-segment) pass of the two-stage segmented correlation.
//
// Each work-group handles pixels_per_workgroup consecutive candidate
// positions, and the segments of each position are spread over the
// work-items of the group. Every position whose correlation is
// >= (min_expected_corr - 0.01) * precision and < 2 is appended to
// results as (x, y); the slot is obtained with atomic_add on
// valid_corr_count[0].
//
// Host contract visible from this kernel:
//   * valid_corr_count[0] must be zero before the launch. It cannot be
//     reset here because work-groups are not synchronized with each other
//     and another group may already have appended a result.
//   * sum_template_region_buff and sum_sq_template_region_buff are indexed
//     by pixel slot (0 .. pixels_per_workgroup - 1);
//     thread_segment_sum_buff is indexed by local id.
//   * segments_per_thread_fast is used as a divisor and must be non-zero.
//
// NOTE(review): candidate positions are decoded with a row length of
// image_width - template_width (no + 1), unlike segmented_match_integral.
// Confirm against the host-side dispatch size before changing it.
__kernel void v2_segmented_match_integral_fast_pass(
    __global const ulong* integral,
    __global const ulong* integral_sq,
    __global const int4* segments,
    __global const float* segment_values,
    const int num_segments,
    const float template_mean,
    const float template_sq_dev,
    __global int2* results,
    const int image_width,
    const int image_height,
    const int template_width,
    const int template_height,
    const float min_expected_corr,
    const int remainder_segments_fast,
    const int segments_per_thread_fast,
    const int pixels_per_workgroup,
    const int workgroup_size,
    __local ulong* sum_template_region_buff,
    __local ulong* sum_sq_template_region_buff,
    __local float* thread_segment_sum_buff,
    __global int* valid_corr_count,
    __global float* precision_buff
) {
    const int local_id = get_local_id(0);
    const int workgroup_id = get_group_id(0);
    const int result_w = image_width - template_width;
    const float area = (float)(template_width * template_height);

    // First candidate position handled by this work-group.
    const int first_pixel = workgroup_id * pixels_per_workgroup;

    // Work-items without any segment to process stay idle, but they must not
    // return early: every work-item of the group has to reach both barriers.
    const bool active =
        (local_id * segments_per_thread_fast + remainder_segments_fast)
        < (num_segments * pixels_per_workgroup);

    // Work-items 0 .. P-1 compute the template-sized patch sum of pixel slot
    // local_id; work-items P .. 2P-1 compute the squared patch sum of pixel
    // slot local_id - P (P = pixels_per_workgroup). The coordinates are
    // derived from the slot, so each buffer entry belongs to the position it
    // is indexed by.
    if (active && local_id < pixels_per_workgroup) {
        const int pos = first_pixel + local_id;
        sum_template_region_buff[local_id] =
            sum_region(integral, pos % result_w, pos / result_w, template_width, template_height, image_width);
    }
    if (active && local_id >= pixels_per_workgroup && local_id < pixels_per_workgroup * 2) {
        const int slot = local_id - pixels_per_workgroup;
        const int pos = first_pixel + slot;
        sum_sq_template_region_buff[slot] =
            sum_region_squared(integral_sq, pos % result_w, pos / result_w, template_width, template_height, image_width);
    }

    // Patch sums must be visible to the whole group before they are read.
    barrier(CLK_LOCAL_MEM_FENCE);

    if (active) {
        // Position whose segments this work-item contributes to.
        const int pixel_slot = local_id / num_segments;
        const int pixel_pos = first_pixel + pixel_slot;
        const int image_x = pixel_pos % result_w;
        const int image_y = pixel_pos / result_w;

        const float mean_img = (float)(sum_template_region_buff[pixel_slot]) / area;

        // When the segments do not divide evenly, the first
        // remainder_segments_fast work-items each take one extra segment and
        // the later ones are shifted by that many segments.
        int remainder_offset = 0;
        int remainder_addition = 0;
        if (remainder_segments_fast > 0) {
            if (local_id >= remainder_segments_fast) {
                remainder_offset = remainder_segments_fast;
            } else {
                remainder_offset = local_id;
                remainder_addition = 1;
            }
        }

        // AUDIT - DOUBLE CHECK THIS LOGIC
        const int thread_segment_start =
            (local_id * segments_per_thread_fast + remainder_offset) % num_segments;
        const int thread_segment_end =
            thread_segment_start + segments_per_thread_fast + remainder_addition;

        float nominator = 0.0f;
        for (int i = thread_segment_start; i < thread_segment_end; i++) {
            int4 seg = segments[i];
            float seg_val = segment_values[i];
            int seg_area = seg.z * seg.w;
            ulong region_sum = sum_region(integral, image_x + seg.x, image_y + seg.y, seg.z, seg.w, image_width);

            nominator += ((float)(region_sum) - mean_img * seg_area) * (seg_val - template_mean);
        }

        thread_segment_sum_buff[local_id] = nominator;
    }

    // Partial numerators must be visible before they are reduced.
    barrier(CLK_LOCAL_MEM_FENCE);

    // One work-item per pixel slot reduces the partial numerators and
    // evaluates the correlation for that position.
    if (active && local_id < pixels_per_workgroup) {
        float nominator_sum = 0.0f;
        const int sum_start = local_id * num_segments;
        const int sum_end = sum_start
            + (num_segments / segments_per_thread_fast)
            - (remainder_segments_fast / segments_per_thread_fast);
        for (int i = sum_start; i < sum_end; i++) {
            nominator_sum = nominator_sum + thread_segment_sum_buff[i];
        }

        const int pixel_pos_final = first_pixel + local_id;
        const int final_x = pixel_pos_final % result_w;
        const int final_y = pixel_pos_final / result_w;

        const float precision = precision_buff[0];
        // Both patch statistics are taken from the per-slot buffers so they
        // describe the same position as the reduced numerator.
        const ulong patch_sum = sum_template_region_buff[local_id];
        const ulong patch_sq_sum = sum_sq_template_region_buff[local_id];
        const float var_img = (float)patch_sq_sum - ((float)patch_sum * (float)patch_sum) / area;
        const float denominator = sqrt(var_img * template_sq_dev);
        const float corr = (denominator != 0.0f) ? (nominator_sum / denominator) : -1.0f;

        // Float literals keep the comparison in single precision.
        if (corr >= (min_expected_corr - 0.01f) * precision && corr < 2.0f) {
            int index = atomic_add(valid_corr_count, 1);
            results[index] = (int2)(final_x, final_y);
        }
    }
}
257
258
259
// Slow (fine-segment) pass of the two-stage segmented correlation.
//
// Each work-group re-evaluates one candidate position taken from
// fast_pass_results[get_group_id(0)], with the segments spread over the
// work-items of the group. If the resulting correlation is
// >= (min_expected_corr - 0.001) * precision and < 2, the position and its
// correlation are appended to position_results / corr_results at a slot
// obtained with atomic_add on valid_corr_count_slow[0].
//
// Host contract visible from this kernel:
//   * valid_corr_count_slow[0] must be zero before the launch.
//   * One work-group is expected per entry of fast_pass_results; the kernel
//     does not check the group id against any count.
//     NOTE(review): valid_corr_count_fast is accepted but never read here -
//     presumably the host sizes the launch from it; confirm.
//   * segments_per_thread_slow is used as a divisor and must be non-zero.
//   * Work-item 1 produces the squared patch sum, so at least two
//     work-items must have segments assigned.
__kernel void v2_segmented_match_integral_slow_pass (
    __global const ulong* integral,
    __global const ulong* integral_sq,
    __global const int4* segments,
    __global const float* segment_values,
    const int num_segments,
    const float template_mean,
    const float template_sq_dev,
    __global int2* position_results,
    __global float* corr_results,
    const int image_width,
    const int image_height,
    const int template_width,
    const int template_height,
    const float min_expected_corr,
    const int remainder_segments_slow,
    const int segments_per_thread_slow,
    const int workgroup_size,
    __local ulong* sum_template_region_buff,
    __local ulong* sum_sq_template_region_buff,
    __local float* thread_segment_sum_buff,
    __global int* valid_corr_count_slow,
    __global int* valid_corr_count_fast,
    __global int2* fast_pass_results,
    __global float* precision_buff
) {
    const int local_id = get_local_id(0);
    const int workgroup_id = get_group_id(0);

    // Candidate position assigned to this work-group.
    const int2 candidate = fast_pass_results[workgroup_id];
    const int image_x = candidate.x;
    const int image_y = candidate.y;

    const float area = (float)(template_width * template_height);

    // Work-items without any segment to process stay idle, but they must not
    // return early: every work-item of the group has to reach both barriers.
    const bool active =
        (local_id * segments_per_thread_slow + remainder_segments_slow) < num_segments;

    // Work-item 0 computes the patch sum, work-item 1 the squared patch sum.
    if (active && local_id == 0) {
        sum_template_region_buff[0] =
            sum_region(integral, image_x, image_y, template_width, template_height, image_width);
    }
    if (active && local_id == 1) {
        sum_sq_template_region_buff[0] =
            sum_region_squared(integral_sq, image_x, image_y, template_width, template_height, image_width);
    }

    // Patch sums must be visible to the whole group before they are read.
    barrier(CLK_LOCAL_MEM_FENCE);

    if (active) {
        const float mean_img = (float)(sum_template_region_buff[0]) / area;

        // When the segments do not divide evenly, the first
        // remainder_segments_slow work-items each take one extra segment and
        // the later ones are shifted by that many segments.
        int remainder_offset = 0;
        int remainder_addition = 0;
        if (remainder_segments_slow > 0) {
            if (local_id >= remainder_segments_slow) {
                remainder_offset = remainder_segments_slow;
            } else {
                remainder_offset = local_id;
                remainder_addition = 1;
            }
        }

        const int thread_segment_start =
            (local_id * segments_per_thread_slow + remainder_offset) % num_segments;
        const int thread_segment_end =
            thread_segment_start + segments_per_thread_slow + remainder_addition;

        float nominator = 0.0f;
        for (int i = thread_segment_start; i < thread_segment_end; i++) {
            int4 seg = segments[i];
            float seg_val = segment_values[i];
            int seg_area = seg.z * seg.w;
            ulong region_sum = sum_region(integral, image_x + seg.x, image_y + seg.y, seg.z, seg.w, image_width);

            nominator += ((float)(region_sum) - mean_img * seg_area) * (seg_val - template_mean);
        }

        thread_segment_sum_buff[local_id] = nominator;
    }

    // Partial numerators must be visible before they are reduced.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Work-item 0 reduces the partial numerators and evaluates the result.
    if (active && local_id == 0) {
        float nominator_sum = 0.0f;
        const int sum_end = (num_segments / segments_per_thread_slow)
            - (remainder_segments_slow / segments_per_thread_slow);
        for (int i = 0; i < sum_end; i++) {
            nominator_sum = nominator_sum + thread_segment_sum_buff[i];
        }

        const ulong patch_sum = sum_template_region_buff[0];
        const ulong patch_sq_sum = sum_sq_template_region_buff[0];
        const float var_img = (float)patch_sq_sum - ((float)patch_sum * (float)patch_sum) / area;
        const float denominator = sqrt(var_img * template_sq_dev);
        const float corr = (denominator != 0.0f) ? (nominator_sum / denominator) : -1.0f;
        const float precision = precision_buff[0];

        // Float literals keep the comparison in single precision.
        if (corr >= (min_expected_corr - 0.001f) * precision && corr < 2.0f) {
            int index = atomic_add(valid_corr_count_slow, 1);
            position_results[index] = (int2)(image_x, image_y);
            corr_results[index] = corr;
        }
    }
}
384
385
386"#;