Skip to main content

yscv_imgproc/ops/
surf.rs

1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5
/// SURF keypoint with position, scale, orientation, and response.
///
/// Plain value type: every field is a scalar, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SurfKeypoint {
    /// Horizontal position (column) in pixels.
    pub x: f32,
    /// Vertical position (row) in pixels.
    pub y: f32,
    /// Scale the keypoint was detected at.
    pub scale: f32,
    /// Dominant orientation in radians.
    pub orientation: f32,
    /// Detector response (approximate Hessian determinant) at the keypoint.
    pub response: f32,
    /// Sign of the Hessian trace: `+1` or `-1`.
    pub laplacian_sign: i8,
}

/// SURF descriptor (64-element vector).
pub type SurfDescriptor = Vec<f32>;
19
/// Build the integral image (summed-area table) of a row-major `width * height` image.
///
/// Entry `(x, y)` of the result holds the sum of all pixels in the rectangle spanning
/// `(0, 0)` to `(x, y)` inclusive, accumulated in `f64`.
///
/// # Panics
///
/// Panics if `image` holds fewer than `width * height` values.
pub fn build_integral_image(image: &[f32], width: usize, height: usize) -> Vec<f64> {
    let mut table = vec![0.0f64; width * height];
    for row in 0..height {
        let base = row * width;
        // Sum of the current row from column 0 up to the current column.
        let mut running = 0.0f64;
        for col in 0..width {
            running += f64::from(image[base + col]);
            // The entry directly above already covers every earlier row.
            let above = if row == 0 {
                0.0
            } else {
                table[base - width + col]
            };
            table[base + col] = running + above;
        }
    }
    table
}
37
/// Query a box sum from an integral image.
///
/// Returns the sum of the pixels in the rectangle `[x1, y1]` to `[x2, y2]` (inclusive),
/// using the four-corner lookup `D - B - C + A`. Corners that would fall outside the
/// table (row or column `-1`) contribute zero.
fn box_sum(integral: &[f64], width: usize, x1: usize, y1: usize, x2: usize, y2: usize) -> f64 {
    let cell = |col: usize, row: usize| integral[row * width + col];

    let above_left = if x1 == 0 || y1 == 0 {
        0.0
    } else {
        cell(x1 - 1, y1 - 1)
    };
    let above = if y1 == 0 { 0.0 } else { cell(x2, y1 - 1) };
    let left = if x1 == 0 { 0.0 } else { cell(x1 - 1, y2) };
    let total = cell(x2, y2);

    total - above - left + above_left
}
59
60/// Safe box sum that clamps coordinates to image bounds.
61fn box_sum_safe(
62    integral: &[f64],
63    width: usize,
64    height: usize,
65    x1: i32,
66    y1: i32,
67    x2: i32,
68    y2: i32,
69) -> f64 {
70    let x1 = x1.max(0) as usize;
71    let y1 = y1.max(0) as usize;
72    let x2 = (x2.min(width as i32 - 1)).max(0) as usize;
73    let y2 = (y2.min(height as i32 - 1)).max(0) as usize;
74    if x2 < x1 || y2 < y1 {
75        return 0.0;
76    }
77    box_sum(integral, width, x1, y1, x2, y2)
78}
79
80/// Compute the approximate Hessian determinant at a given point and filter size.
81/// Uses box-filter approximation of second-order Gaussian derivatives.
82fn hessian_det(
83    integral: &[f64],
84    width: usize,
85    height: usize,
86    x: i32,
87    y: i32,
88    filter_size: usize,
89) -> (f64, f64) {
90    let fs = filter_size as i32;
91    let l = fs / 3; // lobe size
92
93    // Dxx: three horizontal lobes
94    let dxx = box_sum_safe(integral, width, height, x - l, y - l / 2, x + l, y + l / 2)
95        - 3.0
96            * box_sum_safe(
97                integral,
98                width,
99                height,
100                x - l / 2,
101                y - l / 2,
102                x + l / 2,
103                y + l / 2,
104            );
105
106    // Dyy: three vertical lobes
107    let dyy = box_sum_safe(integral, width, height, x - l / 2, y - l, x + l / 2, y + l)
108        - 3.0
109            * box_sum_safe(
110                integral,
111                width,
112                height,
113                x - l / 2,
114                y - l / 2,
115                x + l / 2,
116                y + l / 2,
117            );
118
119    // Dxy: four quadrant lobes
120    let dxy = box_sum_safe(integral, width, height, x + 1, y - l, x + l, y - 1)
121        + box_sum_safe(integral, width, height, x - l, y + 1, x - 1, y + l)
122        - box_sum_safe(integral, width, height, x - l, y - l, x - 1, y - 1)
123        - box_sum_safe(integral, width, height, x + 1, y + 1, x + l, y + l);
124
125    // Normalize by filter area
126    let area = (fs * fs) as f64;
127    let dxx = dxx / area;
128    let dyy = dyy / area;
129    let dxy = dxy / area;
130
131    // Hessian determinant approximation (weight 0.9 for Dxy per SURF paper)
132    let det = dxx * dyy - 0.81 * dxy * dxy;
133    let trace = dxx + dyy;
134    (det, trace)
135}
136
137/// Detect SURF keypoints using box-filter approximation of Hessian.
138pub fn detect_surf_keypoints(
139    image: &Tensor,
140    hessian_threshold: f32,
141    num_octaves: usize,
142    num_scales: usize,
143) -> Result<Vec<SurfKeypoint>, ImgProcError> {
144    let (h, w, c) = hwc_shape(image)?;
145    if c != 1 {
146        return Err(ImgProcError::InvalidChannelCount {
147            expected: 1,
148            got: c,
149        });
150    }
151    let data = image.data();
152
153    // 1. Build integral image
154    let integral = build_integral_image(data, w, h);
155
156    // 2. Compute Hessian response at multiple scales
157    // SURF uses filter sizes: 9, 15, 21, 27 for octave 1; 15, 27, 39, 51 for octave 2; etc.
158    let mut scale_responses: Vec<(Vec<f64>, usize)> = Vec::new(); // (response_map, filter_size)
159
160    for octave in 0..num_octaves {
161        let step = 1usize << octave; // sampling step doubles per octave
162        for scale in 0..num_scales {
163            let filter_size = 3 * ((2usize.pow(octave as u32)) * (scale + 1) + 1);
164            if filter_size / 2 >= h.min(w) {
165                continue;
166            }
167            let mut response = vec![0.0f64; h * w];
168            let margin = (filter_size / 2 + 1) as i32;
169
170            for y in (margin as usize..h.saturating_sub(margin as usize)).step_by(step) {
171                for x in (margin as usize..w.saturating_sub(margin as usize)).step_by(step) {
172                    let (det, _trace) =
173                        hessian_det(&integral, w, h, x as i32, y as i32, filter_size);
174                    response[y * w + x] = det;
175                }
176            }
177            scale_responses.push((response, filter_size));
178        }
179    }
180
181    // 3. Non-maximum suppression in 3x3x3 neighborhood (x, y, scale)
182    let mut keypoints = Vec::new();
183    let thresh = hessian_threshold as f64;
184
185    for si in 1..scale_responses.len().saturating_sub(1) {
186        let filter_size = scale_responses[si].1;
187        let margin = filter_size / 2 + 1;
188        let _step = 1usize.max(filter_size / 9);
189
190        for y in margin..h.saturating_sub(margin) {
191            for x in margin..w.saturating_sub(margin) {
192                let val = scale_responses[si].0[y * w + x];
193                if val < thresh {
194                    continue;
195                }
196
197                let mut is_max = true;
198                'nms: for ds in -1i32..=1 {
199                    let si2 = (si as i32 + ds) as usize;
200                    for dy in -1i32..=1 {
201                        for dx in -1i32..=1 {
202                            if ds == 0 && dy == 0 && dx == 0 {
203                                continue;
204                            }
205                            let ny = (y as i32 + dy) as usize;
206                            let nx = (x as i32 + dx) as usize;
207                            if ny < h && nx < w && scale_responses[si2].0[ny * w + nx] >= val {
208                                is_max = false;
209                                break 'nms;
210                            }
211                        }
212                    }
213                }
214
215                if is_max {
216                    // Compute orientation using Haar wavelets in circular neighborhood
217                    let scale = filter_size as f32 * 1.2 / 9.0;
218                    let orientation =
219                        compute_orientation(&integral, w, h, x as f32, y as f32, scale);
220
221                    let (_, trace) = hessian_det(&integral, w, h, x as i32, y as i32, filter_size);
222
223                    keypoints.push(SurfKeypoint {
224                        x: x as f32,
225                        y: y as f32,
226                        scale,
227                        orientation,
228                        response: val as f32,
229                        laplacian_sign: if trace > 0.0 { 1 } else { -1 },
230                    });
231                }
232            }
233        }
234    }
235
236    // Sort by response (descending)
237    keypoints.sort_by(|a, b| {
238        b.response
239            .partial_cmp(&a.response)
240            .unwrap_or(std::cmp::Ordering::Equal)
241    });
242
243    Ok(keypoints)
244}
245
246/// Compute dominant orientation using Haar wavelet responses in a circular region.
247fn compute_orientation(
248    integral: &[f64],
249    width: usize,
250    height: usize,
251    x: f32,
252    y: f32,
253    scale: f32,
254) -> f32 {
255    let radius = (6.0 * scale).round() as i32;
256    let haar_size = (4.0 * scale).round().max(1.0) as i32;
257    let half_haar = haar_size / 2;
258
259    let mut dx_responses = Vec::new();
260    let mut dy_responses = Vec::new();
261    let mut angles = Vec::new();
262
263    // Sample Haar wavelet responses in circular neighborhood
264    for i in -radius..=radius {
265        for j in -radius..=radius {
266            if i * i + j * j > radius * radius {
267                continue;
268            }
269            let px = x as i32 + j;
270            let py = y as i32 + i;
271
272            // Haar wavelet response in x direction
273            let dx = box_sum_safe(
274                integral,
275                width,
276                height,
277                px,
278                py - half_haar,
279                px + half_haar,
280                py + half_haar,
281            ) - box_sum_safe(
282                integral,
283                width,
284                height,
285                px - half_haar,
286                py - half_haar,
287                px,
288                py + half_haar,
289            );
290
291            // Haar wavelet response in y direction
292            let dy = box_sum_safe(
293                integral,
294                width,
295                height,
296                px - half_haar,
297                py,
298                px + half_haar,
299                py + half_haar,
300            ) - box_sum_safe(
301                integral,
302                width,
303                height,
304                px - half_haar,
305                py - half_haar,
306                px + half_haar,
307                py,
308            );
309
310            // Gaussian weight
311            let sigma = 2.5 * scale;
312            let weight = (-(i * i + j * j) as f32 / (2.0 * sigma * sigma)).exp();
313
314            dx_responses.push(dx as f32 * weight);
315            dy_responses.push(dy as f32 * weight);
316            angles.push((dy as f32).atan2(dx as f32));
317        }
318    }
319
320    if dx_responses.is_empty() {
321        return 0.0;
322    }
323
324    // Sliding window of pi/3 to find dominant orientation
325    let window = std::f32::consts::PI / 3.0;
326    let mut best_angle = 0.0f32;
327    let mut best_magnitude = 0.0f32;
328
329    let steps = 36;
330    for step in 0..steps {
331        let angle = -std::f32::consts::PI + step as f32 * 2.0 * std::f32::consts::PI / steps as f32;
332        let mut sum_dx = 0.0f32;
333        let mut sum_dy = 0.0f32;
334
335        for i in 0..angles.len() {
336            let mut diff = angles[i] - angle;
337            // Normalize to [-pi, pi]
338            while diff > std::f32::consts::PI {
339                diff -= 2.0 * std::f32::consts::PI;
340            }
341            while diff < -std::f32::consts::PI {
342                diff += 2.0 * std::f32::consts::PI;
343            }
344            if diff.abs() < window / 2.0 {
345                sum_dx += dx_responses[i];
346                sum_dy += dy_responses[i];
347            }
348        }
349
350        let mag = sum_dx * sum_dx + sum_dy * sum_dy;
351        if mag > best_magnitude {
352            best_magnitude = mag;
353            best_angle = sum_dy.atan2(sum_dx);
354        }
355    }
356
357    best_angle
358}
359
360/// Compute SURF descriptors for detected keypoints.
361pub fn compute_surf_descriptors(
362    image: &Tensor,
363    keypoints: &[SurfKeypoint],
364) -> Result<Vec<SurfDescriptor>, ImgProcError> {
365    let (h, w, c) = hwc_shape(image)?;
366    if c != 1 {
367        return Err(ImgProcError::InvalidChannelCount {
368            expected: 1,
369            got: c,
370        });
371    }
372    let data = image.data();
373    let integral = build_integral_image(data, w, h);
374
375    let mut descriptors = Vec::with_capacity(keypoints.len());
376
377    for kp in keypoints {
378        let scale = kp.scale;
379        let cos_ori = kp.orientation.cos();
380        let sin_ori = kp.orientation.sin();
381        let haar_size = (2.0 * scale).round().max(1.0) as i32;
382        let half_haar = haar_size / 2;
383
384        let mut desc = vec![0.0f32; 64];
385
386        // 4x4 sub-regions, each covering a 5s x 5s area
387        let _sub_region_size = 5.0 * scale;
388
389        for i in 0..4 {
390            for j in 0..4 {
391                let mut sum_dx = 0.0f32;
392                let mut sum_abs_dx = 0.0f32;
393                let mut sum_dy = 0.0f32;
394                let mut sum_abs_dy = 0.0f32;
395
396                // 5x5 sample points per sub-region
397                for k in 0..5 {
398                    for l in 0..5 {
399                        // Position relative to keypoint (in rotated coordinates)
400                        let sample_x = ((i as f32 - 2.0) * 5.0 + l as f32 + 0.5) * scale;
401                        let sample_y = ((j as f32 - 2.0) * 5.0 + k as f32 + 0.5) * scale;
402
403                        // Rotate to image coordinates
404                        let rx = (cos_ori * sample_x - sin_ori * sample_y + kp.x).round() as i32;
405                        let ry = (sin_ori * sample_x + cos_ori * sample_y + kp.y).round() as i32;
406
407                        // Haar wavelet responses
408                        let dx = box_sum_safe(
409                            &integral,
410                            w,
411                            h,
412                            rx,
413                            ry - half_haar,
414                            rx + half_haar,
415                            ry + half_haar,
416                        ) - box_sum_safe(
417                            &integral,
418                            w,
419                            h,
420                            rx - half_haar,
421                            ry - half_haar,
422                            rx,
423                            ry + half_haar,
424                        );
425
426                        let dy = box_sum_safe(
427                            &integral,
428                            w,
429                            h,
430                            rx - half_haar,
431                            ry,
432                            rx + half_haar,
433                            ry + half_haar,
434                        ) - box_sum_safe(
435                            &integral,
436                            w,
437                            h,
438                            rx - half_haar,
439                            ry - half_haar,
440                            rx + half_haar,
441                            ry,
442                        );
443
444                        // Gaussian weight centered on sub-region
445                        let cx = ((i as f32 - 1.5) * 5.0 + 2.5) * scale;
446                        let cy = ((j as f32 - 1.5) * 5.0 + 2.5) * scale;
447                        let dist_sq = (sample_x - cx).powi(2) + (sample_y - cy).powi(2);
448                        let sigma = 3.3 * scale;
449                        let gauss = (-dist_sq / (2.0 * sigma * sigma)).exp();
450
451                        // Rotate wavelet responses to be relative to keypoint orientation
452                        let rdx = cos_ori * dx as f32 + sin_ori * dy as f32;
453                        let rdy = -sin_ori * dx as f32 + cos_ori * dy as f32;
454
455                        sum_dx += rdx * gauss;
456                        sum_abs_dx += rdx.abs() * gauss;
457                        sum_dy += rdy * gauss;
458                        sum_abs_dy += rdy.abs() * gauss;
459                    }
460                }
461
462                let idx = (i * 4 + j) * 4;
463                desc[idx] = sum_dx;
464                desc[idx + 1] = sum_abs_dx;
465                desc[idx + 2] = sum_dy;
466                desc[idx + 3] = sum_abs_dy;
467            }
468        }
469
470        // Normalize to unit vector
471        let norm = desc.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-7);
472        for v in &mut desc {
473            *v /= norm;
474        }
475
476        descriptors.push(desc);
477    }
478
479    Ok(descriptors)
480}
481
482/// Match SURF descriptors between two sets using nearest-neighbor ratio test.
483///
484/// Returns `(idx1, idx2, distance)` for accepted matches where the ratio of
485/// best to second-best distance is below `ratio_threshold`.
486pub fn match_surf_descriptors(
487    desc1: &[SurfDescriptor],
488    desc2: &[SurfDescriptor],
489    ratio_threshold: f32,
490) -> Vec<(usize, usize, f32)> {
491    let mut matches = Vec::new();
492
493    for (i, d1) in desc1.iter().enumerate() {
494        let mut best_dist = f32::MAX;
495        let mut second_dist = f32::MAX;
496        let mut best_idx = 0;
497
498        for (j, d2) in desc2.iter().enumerate() {
499            let dist: f32 = d1
500                .iter()
501                .zip(d2.iter())
502                .map(|(a, b)| (a - b) * (a - b))
503                .sum::<f32>()
504                .sqrt();
505            if dist < best_dist {
506                second_dist = best_dist;
507                best_dist = dist;
508                best_idx = j;
509            } else if dist < second_dist {
510                second_dist = dist;
511            }
512        }
513
514        // Accept exact matches (dist ≈ 0) unconditionally; otherwise apply ratio test.
515        if best_dist < 1e-9 || (second_dist > 0.0 && best_dist / second_dist < ratio_threshold) {
516            matches.push((i, best_idx, best_dist));
517        }
518    }
519
520    matches
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Side length of the square synthetic images used by the detector/descriptor tests.
    const SIDE: usize = 64;

    /// Keypoint at `(x, y)` with the fixed scale, orientation and response used below.
    fn keypoint_at(x: f32, y: f32) -> SurfKeypoint {
        SurfKeypoint {
            x,
            y,
            scale: 1.2,
            orientation: 0.0,
            response: 1.0,
            laplacian_sign: 1,
        }
    }

    /// Sets every pixel of the square `span x span` in a `SIDE`-wide image to `value`.
    fn paint_square(pixels: &mut [f32], span: std::ops::Range<usize>, value: f32) {
        for row in span.clone() {
            for col in span.clone() {
                pixels[row * SIDE + col] = value;
            }
        }
    }

    #[test]
    fn surf_integral_image() {
        // 3x3 image of ones: entry (x, y) is (x + 1) * (y + 1).
        let ones = build_integral_image(&vec![1.0f32; 9], 3, 3);
        let expected_ones = [1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0];
        for (got, want) in ones.iter().zip(expected_ones.iter()) {
            assert_eq!(got, want);
        }
        assert_eq!(ones.len(), expected_ones.len());

        // 3x3 image holding 1..=9 in row-major order.
        let ramp: Vec<f32> = (1..=9).map(|v| v as f32).collect();
        let ramp_integral = build_integral_image(&ramp, 3, 3);
        let expected_ramp = [
            1.0, 3.0, 6.0, // row 0
            5.0, 12.0, 21.0, // row 1
            12.0, 27.0, 45.0, // row 2
        ];
        for (got, want) in ramp_integral.iter().zip(expected_ramp.iter()) {
            assert_eq!(got, want);
        }
        assert_eq!(ramp_integral.len(), expected_ramp.len());

        // Bottom-right 2x2 region: 5 + 6 + 8 + 9 = 28.
        assert_eq!(box_sum(&ramp_integral, 3, 1, 1, 2, 2), 28.0);
    }

    #[test]
    fn surf_detect_on_gradient() {
        // Bright 24x24 square on a dark background.
        let mut pixels = vec![0.0f32; SIDE * SIDE];
        paint_square(&mut pixels, 20..44, 1.0);
        let img = Tensor::from_vec(vec![SIDE, SIDE, 1], pixels).unwrap();

        let keypoints = detect_surf_keypoints(&img, 0.0001, 2, 4).unwrap();
        assert!(
            !keypoints.is_empty(),
            "image with strong edges should produce SURF keypoints"
        );
        // Everything reported must have cleared the (positive) threshold.
        for kp in keypoints.iter() {
            assert!(kp.response > 0.0, "keypoint response should be positive");
        }
    }

    #[test]
    fn surf_descriptor_dimension() {
        // Horizontal intensity ramp with one hand-placed keypoint in the middle.
        let pixels: Vec<f32> = (0..SIDE * SIDE)
            .map(|i| (i % SIDE) as f32 / SIDE as f32)
            .collect();
        let img = Tensor::from_vec(vec![SIDE, SIDE, 1], pixels).unwrap();

        let keypoints = vec![keypoint_at(32.0, 32.0)];
        let descriptors = compute_surf_descriptors(&img, &keypoints).unwrap();
        assert_eq!(descriptors.len(), 1);

        let descriptor = &descriptors[0];
        assert_eq!(
            descriptor.len(),
            64,
            "SURF descriptor should be 64-element"
        );

        // The descriptor must come back L2-normalised.
        let norm: f32 = descriptor.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.01,
            "descriptor should be L2-normalized, got norm={}",
            norm
        );
    }

    #[test]
    fn surf_match_identical() {
        // Two bright 10x10 squares on a dim background.
        let mut pixels = vec![0.1f32; SIDE * SIDE];
        paint_square(&mut pixels, 10..20, 0.9);
        paint_square(&mut pixels, 40..50, 0.9);
        let img = Tensor::from_vec(vec![SIDE, SIDE, 1], pixels).unwrap();

        // One keypoint in the middle of each square.
        let keypoints = vec![keypoint_at(15.0, 15.0), keypoint_at(45.0, 45.0)];

        let descriptors = compute_surf_descriptors(&img, &keypoints).unwrap();
        assert_eq!(descriptors.len(), 2);

        // Matching a set against itself must give near-zero distances. The two patches are
        // symmetric, so their descriptors may coincide and pair up in either order; only
        // the distances are checked, not which index matched which.
        let matches = match_surf_descriptors(&descriptors, &descriptors, 0.99);
        assert!(
            !matches.is_empty(),
            "matching descriptors against themselves should produce matches"
        );
        for (_, _, dist) in matches.iter().copied() {
            assert!(
                dist < 1e-5,
                "self-match distance should be ~0, got {}",
                dist
            );
        }
    }
}