1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5
/// A SURF interest point, in pixel coordinates of the source image.
#[derive(Debug, Clone, PartialEq)]
pub struct SurfKeypoint {
    /// Column (horizontal pixel coordinate) of the keypoint.
    pub x: f32,
    /// Row (vertical pixel coordinate) of the keypoint.
    pub y: f32,
    /// Feature scale `s`; `detect_surf_keypoints` sets it to `filter_size * 1.2 / 9`.
    pub scale: f32,
    /// Dominant orientation in radians, as returned by `atan2` (range `[-π, π]`).
    pub orientation: f32,
    /// Hessian determinant response at the keypoint.
    pub response: f32,
    /// Sign of the Hessian trace (Laplacian): `1` when positive, `-1` otherwise.
    pub laplacian_sign: i8,
}

/// A SURF descriptor; `compute_surf_descriptors` produces 64 L2-normalised values.
pub type SurfDescriptor = Vec<f32>;
19
/// Builds a summed-area table for a row-major `width x height` single-channel image.
///
/// Entry `y * width + x` of the result holds the sum of every pixel in the inclusive
/// rectangle `(0, 0)..=(x, y)`, accumulated in `f64`. Only the first `width * height`
/// elements of `image` are read.
///
/// # Panics
/// Panics if `image` holds fewer than `width * height` elements.
pub fn build_integral_image(image: &[f32], width: usize, height: usize) -> Vec<f64> {
    let total = width * height;
    let mut table = vec![0.0f64; total];
    if width == 0 {
        // No columns: nothing to accumulate (and `chunks_exact(0)` would panic).
        return table;
    }
    for (row, pixels) in image[..total].chunks_exact(width).enumerate() {
        // Running sum of the current row, added to the table entry directly above.
        let mut running = 0.0f64;
        for (col, &pixel) in pixels.iter().enumerate() {
            running += f64::from(pixel);
            let above = if row == 0 {
                0.0
            } else {
                table[(row - 1) * width + col]
            };
            table[row * width + col] = running + above;
        }
    }
    table
}
37
/// Sums the inclusive rectangle `(x1, y1)..=(x2, y2)` using the summed-area table `integral`
/// (row stride `width`).
///
/// Coordinates are not clamped: the caller must pass in-range indices, otherwise indexing
/// `integral` panics.
fn box_sum(integral: &[f64], width: usize, x1: usize, y1: usize, x2: usize, y2: usize) -> f64 {
    let at = |row: usize, col: usize| integral[row * width + col];
    // Table entries just outside the rectangle; anything before row/column 0 contributes 0.
    let above_left = if x1 == 0 || y1 == 0 {
        0.0
    } else {
        at(y1 - 1, x1 - 1)
    };
    let above = if y1 == 0 { 0.0 } else { at(y1 - 1, x2) };
    let left = if x1 == 0 { 0.0 } else { at(y2, x1 - 1) };
    let whole = at(y2, x2);
    // Inclusion-exclusion; the corner term was subtracted twice so it is added back.
    whole - above - left + above_left
}
59
60fn box_sum_safe(
62 integral: &[f64],
63 width: usize,
64 height: usize,
65 x1: i32,
66 y1: i32,
67 x2: i32,
68 y2: i32,
69) -> f64 {
70 let x1 = x1.max(0) as usize;
71 let y1 = y1.max(0) as usize;
72 let x2 = (x2.min(width as i32 - 1)).max(0) as usize;
73 let y2 = (y2.min(height as i32 - 1)).max(0) as usize;
74 if x2 < x1 || y2 < y1 {
75 return 0.0;
76 }
77 box_sum(integral, width, x1, y1, x2, y2)
78}
79
80fn hessian_det(
83 integral: &[f64],
84 width: usize,
85 height: usize,
86 x: i32,
87 y: i32,
88 filter_size: usize,
89) -> (f64, f64) {
90 let fs = filter_size as i32;
91 let l = fs / 3; let dxx = box_sum_safe(integral, width, height, x - l, y - l / 2, x + l, y + l / 2)
95 - 3.0
96 * box_sum_safe(
97 integral,
98 width,
99 height,
100 x - l / 2,
101 y - l / 2,
102 x + l / 2,
103 y + l / 2,
104 );
105
106 let dyy = box_sum_safe(integral, width, height, x - l / 2, y - l, x + l / 2, y + l)
108 - 3.0
109 * box_sum_safe(
110 integral,
111 width,
112 height,
113 x - l / 2,
114 y - l / 2,
115 x + l / 2,
116 y + l / 2,
117 );
118
119 let dxy = box_sum_safe(integral, width, height, x + 1, y - l, x + l, y - 1)
121 + box_sum_safe(integral, width, height, x - l, y + 1, x - 1, y + l)
122 - box_sum_safe(integral, width, height, x - l, y - l, x - 1, y - 1)
123 - box_sum_safe(integral, width, height, x + 1, y + 1, x + l, y + l);
124
125 let area = (fs * fs) as f64;
127 let dxx = dxx / area;
128 let dyy = dyy / area;
129 let dxy = dxy / area;
130
131 let det = dxx * dyy - 0.81 * dxy * dxy;
133 let trace = dxx + dyy;
134 (det, trace)
135}
136
137pub fn detect_surf_keypoints(
139 image: &Tensor,
140 hessian_threshold: f32,
141 num_octaves: usize,
142 num_scales: usize,
143) -> Result<Vec<SurfKeypoint>, ImgProcError> {
144 let (h, w, c) = hwc_shape(image)?;
145 if c != 1 {
146 return Err(ImgProcError::InvalidChannelCount {
147 expected: 1,
148 got: c,
149 });
150 }
151 let data = image.data();
152
153 let integral = build_integral_image(data, w, h);
155
156 let mut scale_responses: Vec<(Vec<f64>, usize)> = Vec::new(); for octave in 0..num_octaves {
161 let step = 1usize << octave; for scale in 0..num_scales {
163 let filter_size = 3 * ((2usize.pow(octave as u32)) * (scale + 1) + 1);
164 if filter_size / 2 >= h.min(w) {
165 continue;
166 }
167 let mut response = vec![0.0f64; h * w];
168 let margin = (filter_size / 2 + 1) as i32;
169
170 for y in (margin as usize..h.saturating_sub(margin as usize)).step_by(step) {
171 for x in (margin as usize..w.saturating_sub(margin as usize)).step_by(step) {
172 let (det, _trace) =
173 hessian_det(&integral, w, h, x as i32, y as i32, filter_size);
174 response[y * w + x] = det;
175 }
176 }
177 scale_responses.push((response, filter_size));
178 }
179 }
180
181 let mut keypoints = Vec::new();
183 let thresh = hessian_threshold as f64;
184
185 for si in 1..scale_responses.len().saturating_sub(1) {
186 let filter_size = scale_responses[si].1;
187 let margin = filter_size / 2 + 1;
188 let _step = 1usize.max(filter_size / 9);
189
190 for y in margin..h.saturating_sub(margin) {
191 for x in margin..w.saturating_sub(margin) {
192 let val = scale_responses[si].0[y * w + x];
193 if val < thresh {
194 continue;
195 }
196
197 let mut is_max = true;
198 'nms: for ds in -1i32..=1 {
199 let si2 = (si as i32 + ds) as usize;
200 for dy in -1i32..=1 {
201 for dx in -1i32..=1 {
202 if ds == 0 && dy == 0 && dx == 0 {
203 continue;
204 }
205 let ny = (y as i32 + dy) as usize;
206 let nx = (x as i32 + dx) as usize;
207 if ny < h && nx < w && scale_responses[si2].0[ny * w + nx] >= val {
208 is_max = false;
209 break 'nms;
210 }
211 }
212 }
213 }
214
215 if is_max {
216 let scale = filter_size as f32 * 1.2 / 9.0;
218 let orientation =
219 compute_orientation(&integral, w, h, x as f32, y as f32, scale);
220
221 let (_, trace) = hessian_det(&integral, w, h, x as i32, y as i32, filter_size);
222
223 keypoints.push(SurfKeypoint {
224 x: x as f32,
225 y: y as f32,
226 scale,
227 orientation,
228 response: val as f32,
229 laplacian_sign: if trace > 0.0 { 1 } else { -1 },
230 });
231 }
232 }
233 }
234 }
235
236 keypoints.sort_by(|a, b| {
238 b.response
239 .partial_cmp(&a.response)
240 .unwrap_or(std::cmp::Ordering::Equal)
241 });
242
243 Ok(keypoints)
244}
245
246fn compute_orientation(
248 integral: &[f64],
249 width: usize,
250 height: usize,
251 x: f32,
252 y: f32,
253 scale: f32,
254) -> f32 {
255 let radius = (6.0 * scale).round() as i32;
256 let haar_size = (4.0 * scale).round().max(1.0) as i32;
257 let half_haar = haar_size / 2;
258
259 let mut dx_responses = Vec::new();
260 let mut dy_responses = Vec::new();
261 let mut angles = Vec::new();
262
263 for i in -radius..=radius {
265 for j in -radius..=radius {
266 if i * i + j * j > radius * radius {
267 continue;
268 }
269 let px = x as i32 + j;
270 let py = y as i32 + i;
271
272 let dx = box_sum_safe(
274 integral,
275 width,
276 height,
277 px,
278 py - half_haar,
279 px + half_haar,
280 py + half_haar,
281 ) - box_sum_safe(
282 integral,
283 width,
284 height,
285 px - half_haar,
286 py - half_haar,
287 px,
288 py + half_haar,
289 );
290
291 let dy = box_sum_safe(
293 integral,
294 width,
295 height,
296 px - half_haar,
297 py,
298 px + half_haar,
299 py + half_haar,
300 ) - box_sum_safe(
301 integral,
302 width,
303 height,
304 px - half_haar,
305 py - half_haar,
306 px + half_haar,
307 py,
308 );
309
310 let sigma = 2.5 * scale;
312 let weight = (-(i * i + j * j) as f32 / (2.0 * sigma * sigma)).exp();
313
314 dx_responses.push(dx as f32 * weight);
315 dy_responses.push(dy as f32 * weight);
316 angles.push((dy as f32).atan2(dx as f32));
317 }
318 }
319
320 if dx_responses.is_empty() {
321 return 0.0;
322 }
323
324 let window = std::f32::consts::PI / 3.0;
326 let mut best_angle = 0.0f32;
327 let mut best_magnitude = 0.0f32;
328
329 let steps = 36;
330 for step in 0..steps {
331 let angle = -std::f32::consts::PI + step as f32 * 2.0 * std::f32::consts::PI / steps as f32;
332 let mut sum_dx = 0.0f32;
333 let mut sum_dy = 0.0f32;
334
335 for i in 0..angles.len() {
336 let mut diff = angles[i] - angle;
337 while diff > std::f32::consts::PI {
339 diff -= 2.0 * std::f32::consts::PI;
340 }
341 while diff < -std::f32::consts::PI {
342 diff += 2.0 * std::f32::consts::PI;
343 }
344 if diff.abs() < window / 2.0 {
345 sum_dx += dx_responses[i];
346 sum_dy += dy_responses[i];
347 }
348 }
349
350 let mag = sum_dx * sum_dx + sum_dy * sum_dy;
351 if mag > best_magnitude {
352 best_magnitude = mag;
353 best_angle = sum_dy.atan2(sum_dx);
354 }
355 }
356
357 best_angle
358}
359
360pub fn compute_surf_descriptors(
362 image: &Tensor,
363 keypoints: &[SurfKeypoint],
364) -> Result<Vec<SurfDescriptor>, ImgProcError> {
365 let (h, w, c) = hwc_shape(image)?;
366 if c != 1 {
367 return Err(ImgProcError::InvalidChannelCount {
368 expected: 1,
369 got: c,
370 });
371 }
372 let data = image.data();
373 let integral = build_integral_image(data, w, h);
374
375 let mut descriptors = Vec::with_capacity(keypoints.len());
376
377 for kp in keypoints {
378 let scale = kp.scale;
379 let cos_ori = kp.orientation.cos();
380 let sin_ori = kp.orientation.sin();
381 let haar_size = (2.0 * scale).round().max(1.0) as i32;
382 let half_haar = haar_size / 2;
383
384 let mut desc = vec![0.0f32; 64];
385
386 let _sub_region_size = 5.0 * scale;
388
389 for i in 0..4 {
390 for j in 0..4 {
391 let mut sum_dx = 0.0f32;
392 let mut sum_abs_dx = 0.0f32;
393 let mut sum_dy = 0.0f32;
394 let mut sum_abs_dy = 0.0f32;
395
396 for k in 0..5 {
398 for l in 0..5 {
399 let sample_x = ((i as f32 - 2.0) * 5.0 + l as f32 + 0.5) * scale;
401 let sample_y = ((j as f32 - 2.0) * 5.0 + k as f32 + 0.5) * scale;
402
403 let rx = (cos_ori * sample_x - sin_ori * sample_y + kp.x).round() as i32;
405 let ry = (sin_ori * sample_x + cos_ori * sample_y + kp.y).round() as i32;
406
407 let dx = box_sum_safe(
409 &integral,
410 w,
411 h,
412 rx,
413 ry - half_haar,
414 rx + half_haar,
415 ry + half_haar,
416 ) - box_sum_safe(
417 &integral,
418 w,
419 h,
420 rx - half_haar,
421 ry - half_haar,
422 rx,
423 ry + half_haar,
424 );
425
426 let dy = box_sum_safe(
427 &integral,
428 w,
429 h,
430 rx - half_haar,
431 ry,
432 rx + half_haar,
433 ry + half_haar,
434 ) - box_sum_safe(
435 &integral,
436 w,
437 h,
438 rx - half_haar,
439 ry - half_haar,
440 rx + half_haar,
441 ry,
442 );
443
444 let cx = ((i as f32 - 1.5) * 5.0 + 2.5) * scale;
446 let cy = ((j as f32 - 1.5) * 5.0 + 2.5) * scale;
447 let dist_sq = (sample_x - cx).powi(2) + (sample_y - cy).powi(2);
448 let sigma = 3.3 * scale;
449 let gauss = (-dist_sq / (2.0 * sigma * sigma)).exp();
450
451 let rdx = cos_ori * dx as f32 + sin_ori * dy as f32;
453 let rdy = -sin_ori * dx as f32 + cos_ori * dy as f32;
454
455 sum_dx += rdx * gauss;
456 sum_abs_dx += rdx.abs() * gauss;
457 sum_dy += rdy * gauss;
458 sum_abs_dy += rdy.abs() * gauss;
459 }
460 }
461
462 let idx = (i * 4 + j) * 4;
463 desc[idx] = sum_dx;
464 desc[idx + 1] = sum_abs_dx;
465 desc[idx + 2] = sum_dy;
466 desc[idx + 3] = sum_abs_dy;
467 }
468 }
469
470 let norm = desc.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-7);
472 for v in &mut desc {
473 *v /= norm;
474 }
475
476 descriptors.push(desc);
477 }
478
479 Ok(descriptors)
480}
481
482pub fn match_surf_descriptors(
487 desc1: &[SurfDescriptor],
488 desc2: &[SurfDescriptor],
489 ratio_threshold: f32,
490) -> Vec<(usize, usize, f32)> {
491 let mut matches = Vec::new();
492
493 for (i, d1) in desc1.iter().enumerate() {
494 let mut best_dist = f32::MAX;
495 let mut second_dist = f32::MAX;
496 let mut best_idx = 0;
497
498 for (j, d2) in desc2.iter().enumerate() {
499 let dist: f32 = d1
500 .iter()
501 .zip(d2.iter())
502 .map(|(a, b)| (a - b) * (a - b))
503 .sum::<f32>()
504 .sqrt();
505 if dist < best_dist {
506 second_dist = best_dist;
507 best_dist = dist;
508 best_idx = j;
509 } else if dist < second_dist {
510 second_dist = dist;
511 }
512 }
513
514 if best_dist < 1e-9 || (second_dist > 0.0 && best_dist / second_dist < ratio_threshold) {
516 matches.push((i, best_idx, best_dist));
517 }
518 }
519
520 matches
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    // Checks `build_integral_image` against hand-computed 3x3 tables and `box_sum` on one
    // sub-rectangle.
    #[test]
    fn surf_integral_image() {
        // All-ones image: entry (x, y) equals (x + 1) * (y + 1).
        let data = vec![1.0f32; 9];
        let integral = build_integral_image(&data, 3, 3);
        assert_eq!(integral[0], 1.0);
        assert_eq!(integral[1], 2.0);
        assert_eq!(integral[2], 3.0);
        assert_eq!(integral[3], 2.0);
        assert_eq!(integral[4], 4.0);
        assert_eq!(integral[5], 6.0);
        assert_eq!(integral[6], 3.0);
        assert_eq!(integral[7], 6.0);
        assert_eq!(integral[8], 9.0);
        // Image holding 1..=9 in row-major order.
        let data2 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let integral2 = build_integral_image(&data2, 3, 3);
        assert_eq!(integral2[0], 1.0);
        assert_eq!(integral2[1], 3.0);
        assert_eq!(integral2[2], 6.0);
        assert_eq!(integral2[3], 5.0);
        assert_eq!(integral2[4], 12.0);
        assert_eq!(integral2[5], 21.0);
        assert_eq!(integral2[6], 12.0);
        assert_eq!(integral2[7], 27.0);
        assert_eq!(integral2[8], 45.0);

        // Bottom-right 2x2 block: 5 + 6 + 8 + 9.
        let sum = box_sum(&integral2, 3, 1, 1, 2, 2);
        assert_eq!(sum, 28.0);
    }

    // A bright 24x24 square on a dark 64x64 background must yield at least one keypoint,
    // and every reported response must be positive.
    #[test]
    fn surf_detect_on_gradient() {
        let (h, w) = (64, 64);
        let mut data = vec![0.0f32; h * w];
        for y in 20..44 {
            for x in 20..44 {
                data[y * w + x] = 1.0;
            }
        }
        let img = Tensor::from_vec(vec![h, w, 1], data).unwrap();
        let keypoints = detect_surf_keypoints(&img, 0.0001, 2, 4).unwrap();
        assert!(
            !keypoints.is_empty(),
            "image with strong edges should produce SURF keypoints"
        );
        for kp in &keypoints {
            assert!(kp.response > 0.0, "keypoint response should be positive");
        }
    }

    // On a horizontal intensity ramp, one hand-placed keypoint must give a single
    // 64-element, unit-norm descriptor.
    #[test]
    fn surf_descriptor_dimension() {
        let (h, w) = (64, 64);
        let data: Vec<f32> = (0..h * w).map(|i| (i % w) as f32 / w as f32).collect();
        let img = Tensor::from_vec(vec![h, w, 1], data).unwrap();

        let keypoints = vec![SurfKeypoint {
            x: 32.0,
            y: 32.0,
            scale: 1.2,
            orientation: 0.0,
            response: 1.0,
            laplacian_sign: 1,
        }];
        let descriptors = compute_surf_descriptors(&img, &keypoints).unwrap();
        assert_eq!(descriptors.len(), 1);
        assert_eq!(
            descriptors[0].len(),
            64,
            "SURF descriptor should be 64-element"
        );
        let norm: f32 = descriptors[0].iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.01,
            "descriptor should be L2-normalized, got norm={}",
            norm
        );
    }

    // Matching a descriptor set against itself must produce matches, all at distance ~0.
    #[test]
    fn surf_match_identical() {
        let (h, w) = (64, 64);
        // Two identical bright 10x10 squares on a 0.1 background; each keypoint below sits at
        // the same relative position inside its square.
        let mut data = vec![0.1f32; h * w];
        for y in 10..20 {
            for x in 10..20 {
                data[y * w + x] = 0.9;
            }
        }
        for y in 40..50 {
            for x in 40..50 {
                data[y * w + x] = 0.9;
            }
        }
        let img = Tensor::from_vec(vec![h, w, 1], data).unwrap();

        let keypoints = vec![
            SurfKeypoint {
                x: 15.0,
                y: 15.0,
                scale: 1.2,
                orientation: 0.0,
                response: 1.0,
                laplacian_sign: 1,
            },
            SurfKeypoint {
                x: 45.0,
                y: 45.0,
                scale: 1.2,
                orientation: 0.0,
                response: 1.0,
                laplacian_sign: 1,
            },
        ];

        let descriptors = compute_surf_descriptors(&img, &keypoints).unwrap();
        assert_eq!(descriptors.len(), 2);

        // A permissive ratio (0.99): exact self-matches pass via the distance < 1e-9 rule
        // in `match_surf_descriptors` regardless of the ratio.
        let matches = match_surf_descriptors(&descriptors, &descriptors, 0.99);
        assert!(
            !matches.is_empty(),
            "matching descriptors against themselves should produce matches"
        );
        for &(_i, _j, dist) in &matches {
            assert!(
                dist < 1e-5,
                "self-match distance should be ~0, got {}",
                dist
            );
        }
    }
}
678}