Skip to main content

yscv_detect/
heatmap.rs

1use yscv_tensor::Tensor;
2
3use crate::nms::validate_nms_args;
4use crate::{BoundingBox, CLASS_ID_PERSON, DetectError, Detection, non_max_suppression};
5
6/// Reusable scratch storage for connected-component heatmap detection.
7///
8/// This allows callers with stable heatmap dimensions (for example camera loops)
9/// to avoid reallocating traversal buffers on each detection call.
10#[derive(Debug, Default, Clone, PartialEq)]
11pub struct HeatmapDetectScratch {
12    active: Vec<bool>,
13    visited: Vec<bool>,
14    stack: Vec<usize>,
15    detections: Vec<Detection>,
16}
17
18/// Connected-component detector over heatmaps `[H, W, 1]`.
19pub fn detect_from_heatmap(
20    heatmap: &Tensor,
21    score_threshold: f32,
22    min_area: usize,
23    iou_threshold: f32,
24    max_detections: usize,
25) -> Result<Vec<Detection>, DetectError> {
26    let mut scratch = HeatmapDetectScratch::default();
27    detect_from_heatmap_with_scratch(
28        heatmap,
29        score_threshold,
30        min_area,
31        iou_threshold,
32        max_detections,
33        &mut scratch,
34    )
35}
36
37/// Connected-component detector over heatmaps `[H, W, 1]` with reusable scratch storage.
38pub fn detect_from_heatmap_with_scratch(
39    heatmap: &Tensor,
40    score_threshold: f32,
41    min_area: usize,
42    iou_threshold: f32,
43    max_detections: usize,
44    scratch: &mut HeatmapDetectScratch,
45) -> Result<Vec<Detection>, DetectError> {
46    let (h, w, c) = map_shape(heatmap)?;
47    if c != 1 {
48        return Err(DetectError::InvalidChannelCount {
49            expected: 1,
50            got: c,
51        });
52    }
53    detect_from_heatmap_data_with_scratch(
54        (h, w),
55        heatmap.data(),
56        score_threshold,
57        min_area,
58        iou_threshold,
59        max_detections,
60        scratch,
61    )
62}
63
64pub(crate) fn detect_from_heatmap_data_with_scratch(
65    shape: (usize, usize),
66    data: &[f32],
67    score_threshold: f32,
68    min_area: usize,
69    iou_threshold: f32,
70    max_detections: usize,
71    scratch: &mut HeatmapDetectScratch,
72) -> Result<Vec<Detection>, DetectError> {
73    let (h, w) = shape;
74    if !score_threshold.is_finite() || !(0.0..=1.0).contains(&score_threshold) {
75        return Err(DetectError::InvalidThreshold {
76            threshold: score_threshold,
77        });
78    }
79    if min_area == 0 {
80        return Err(DetectError::InvalidMinArea { min_area });
81    }
82    validate_nms_args(iou_threshold, max_detections)?;
83    let pixel_count = h.saturating_mul(w);
84    debug_assert_eq!(data.len(), pixel_count);
85
86    if scratch.active.len() != pixel_count {
87        scratch.active.resize(pixel_count, false);
88    }
89    if scratch.visited.len() != pixel_count {
90        scratch.visited.resize(pixel_count, false);
91    }
92
93    for ((active, visited), value) in scratch
94        .active
95        .iter_mut()
96        .zip(scratch.visited.iter_mut())
97        .zip(data.iter().copied())
98    {
99        *active = is_active_score(value, score_threshold);
100        *visited = false;
101    }
102
103    scratch.stack.clear();
104    scratch.detections.clear();
105    for start in 0..pixel_count {
106        if scratch.visited[start] || !scratch.active[start] {
107            continue;
108        }
109
110        scratch.visited[start] = true;
111        scratch.stack.clear();
112        scratch.stack.push(start);
113
114        let start_y = start / w;
115        let start_x = start - start_y * w;
116        let mut min_x = start_x;
117        let mut max_x = start_x;
118        let mut min_y = start_y;
119        let mut max_y = start_y;
120        let mut area = 0usize;
121        let mut score_sum = 0.0f32;
122        let mut score_max = 0.0f32;
123
124        while let Some(current) = scratch.stack.pop() {
125            let cy = current / w;
126            let cx = current - cy * w;
127            let current_score = data[current];
128
129            area += 1;
130            score_sum += current_score;
131            score_max = score_max.max(current_score);
132            min_x = min_x.min(cx);
133            max_x = max_x.max(cx);
134            min_y = min_y.min(cy);
135            max_y = max_y.max(cy);
136
137            if cx > 0 {
138                visit_neighbor(
139                    current - 1,
140                    &scratch.active,
141                    &mut scratch.visited,
142                    &mut scratch.stack,
143                );
144            }
145            if cx + 1 < w {
146                visit_neighbor(
147                    current + 1,
148                    &scratch.active,
149                    &mut scratch.visited,
150                    &mut scratch.stack,
151                );
152            }
153            if cy > 0 {
154                visit_neighbor(
155                    current - w,
156                    &scratch.active,
157                    &mut scratch.visited,
158                    &mut scratch.stack,
159                );
160            }
161            if cy + 1 < h {
162                visit_neighbor(
163                    current + w,
164                    &scratch.active,
165                    &mut scratch.visited,
166                    &mut scratch.stack,
167                );
168            }
169        }
170
171        if area >= min_area {
172            let avg_score = score_sum / area as f32;
173            scratch.detections.push(Detection {
174                bbox: BoundingBox {
175                    x1: min_x as f32,
176                    y1: min_y as f32,
177                    x2: (max_x + 1) as f32,
178                    y2: (max_y + 1) as f32,
179                },
180                score: (avg_score + score_max) * 0.5,
181                class_id: CLASS_ID_PERSON,
182            });
183        }
184    }
185
186    Ok(non_max_suppression(
187        &scratch.detections,
188        iou_threshold,
189        max_detections,
190    ))
191}
192
193pub(crate) fn map_shape(input: &Tensor) -> Result<(usize, usize, usize), DetectError> {
194    if input.rank() != 3 {
195        return Err(DetectError::InvalidMapShape {
196            expected_rank: 3,
197            got: input.shape().to_vec(),
198        });
199    }
200    Ok((input.shape()[0], input.shape()[1], input.shape()[2]))
201}
202
/// Returns `true` when `value` is a finite score that reaches `threshold`.
///
/// NaN and infinite scores are never considered active.
fn is_active_score(value: f32, threshold: f32) -> bool {
    if !value.is_finite() {
        return false;
    }
    threshold <= value
}
206
/// Marks `index` as visited and, if it was not seen before and is active,
/// queues it on `stack` for traversal.
///
/// Inactive pixels are still marked visited so they are not re-examined.
fn visit_neighbor(index: usize, active: &[bool], visited: &mut [bool], stack: &mut Vec<usize>) {
    let already_seen = std::mem::replace(&mut visited[index], true);
    if !already_seen && active[index] {
        stack.push(index);
    }
}