pure_onnx_ocr/
postprocessing.rs

1use crate::detection::DetInferenceOutput;
2use geo_types::{Coord, LineString, Polygon};
3use i_overlay::float::overlay::OverlayOptions;
4use i_overlay::mesh::outline::offset::OutlineOffset;
5use i_overlay::mesh::style::{LineJoin, OutlineStyle};
6use image::{GrayImage, Luma};
7use imageproc::contours::{find_contours, Contour};
8use imageproc::point::Point;
9use ndarray::Array2;
10use std::error::Error;
11use std::fmt;
12
/// Tunable parameters consumed by `DetPostProcessor`.
#[derive(Clone, Copy, Debug)]
pub struct DetPostProcessorConfig {
    /// Binarization cut-off in the `0.0..=1.0` range; probabilities at or above it
    /// are treated as foreground before contours are traced.
    pub threshold: f32,
    /// Smallest contour area (in pixels) that survives filtering.
    pub min_area: f32,
}

impl Default for DetPostProcessorConfig {
    /// Uses a threshold of `0.3` and a minimum contour area of `10.0` pixels.
    fn default() -> Self {
        DetPostProcessorConfig {
            min_area: 10.0,
            threshold: 0.3,
        }
    }
}
30
/// Failure modes of detection post-processing.
#[derive(Debug)]
pub enum DetPostProcessorError {
    /// The probability map had zero elements.
    EmptyProbabilityMap,
    /// The binarized buffer could not be turned into a grayscale image.
    ImageCreationFailed,
}

impl fmt::Display for DetPostProcessorError {
    /// Writes the human-readable description of the error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::EmptyProbabilityMap => "probability map must contain at least one element",
            Self::ImageCreationFailed => "failed to create grayscale image from probability map",
        };
        f.write_str(message)
    }
}

impl Error for DetPostProcessorError {}
54
55/// Extracts text candidate contours from the DBNet probability map.
56#[derive(Debug, Clone)]
57pub struct DetPostProcessor {
58    config: DetPostProcessorConfig,
59}
60
61impl DetPostProcessor {
62    pub fn new(config: DetPostProcessorConfig) -> Self {
63        Self { config }
64    }
65
66    pub fn process(
67        &self,
68        output: &DetInferenceOutput,
69    ) -> Result<Vec<Contour<i32>>, DetPostProcessorError> {
70        self.process_probability_map(&output.probability_map)
71    }
72
73    pub fn process_probability_map(
74        &self,
75        probability_map: &Array2<f32>,
76    ) -> Result<Vec<Contour<i32>>, DetPostProcessorError> {
77        if probability_map.is_empty() {
78            return Err(DetPostProcessorError::EmptyProbabilityMap);
79        }
80
81        let threshold = self.config.threshold.clamp(0.0, 1.0);
82        let (height, width) = probability_map.dim();
83        let mut buffer = Vec::with_capacity(height * width);
84
85        for &value in probability_map.iter() {
86            let clamped = value.clamp(0.0, 1.0);
87            let byte = if clamped >= threshold { 255 } else { 0 };
88            buffer.push(byte);
89        }
90
91        let mut gray = GrayImage::from_vec(width as u32, height as u32, buffer)
92            .ok_or(DetPostProcessorError::ImageCreationFailed)?;
93
94        // Ensure the binary image uses full white for foreground for consistent contour detection.
95        for pixel in gray.pixels_mut() {
96            *pixel = if pixel[0] > 0 { Luma([255]) } else { Luma([0]) };
97        }
98
99        let contours = find_contours::<i32>(&gray);
100        let min_area = self.config.min_area.max(0.0);
101
102        let filtered = contours
103            .into_iter()
104            .filter(|contour| contour.points.len() >= 3)
105            .filter(|contour| contour_area(contour) >= min_area)
106            .collect();
107
108        Ok(filtered)
109    }
110}
111
112fn contour_area(contour: &Contour<i32>) -> f32 {
113    if contour.points.len() < 3 {
114        return 0.0;
115    }
116
117    let mut area = 0f64;
118    for window in contour.points.windows(2) {
119        if let [Point { x: x1, y: y1 }, Point { x: x2, y: y2 }] = window {
120            area += (*x1 as f64) * (*y2 as f64) - (*x2 as f64) * (*y1 as f64);
121        }
122    }
123
124    let first = contour.points.first().unwrap();
125    let last = contour.points.last().unwrap();
126    area += (last.x as f64) * (first.y as f64) - (first.x as f64) * (last.y as f64);
127
128    (area.abs() * 0.5) as f32
129}
130
131/// Corner join style for unclip offsetting.
132#[derive(Debug, Clone, Copy)]
133pub enum DetUnclipLineJoin {
134    Bevel,
135    Miter(f32),
136    Round(f32),
137}
138
139impl DetUnclipLineJoin {
140    fn to_line_join(self) -> LineJoin<f64> {
141        match self {
142            DetUnclipLineJoin::Bevel => LineJoin::Bevel,
143            DetUnclipLineJoin::Miter(angle) => LineJoin::Miter(angle.max(0.01) as f64),
144            DetUnclipLineJoin::Round(angle) => LineJoin::Round(angle.max(0.01) as f64),
145        }
146    }
147}
148
149/// Configuration for polygon offsetting (unclip).
150#[derive(Debug, Clone, Copy)]
151pub struct DetPolygonUnclipperConfig {
152    /// Ratio applied to the DBNet area/perimeter heuristic.
153    pub unclip_ratio: f32,
154    /// Additional minimum area after unclipping; polygons smaller than this are discarded.
155    pub min_result_area: f32,
156    /// Join style applied to buffered corners.
157    pub join_style: DetUnclipLineJoin,
158}
159
160impl Default for DetPolygonUnclipperConfig {
161    fn default() -> Self {
162        Self {
163            unclip_ratio: 1.5,
164            min_result_area: 25.0,
165            join_style: DetUnclipLineJoin::Round(0.1),
166        }
167    }
168}
169
170/// Applies DBNet-style polygon offsetting (unclip) using `i_overlay`.
171#[derive(Debug, Clone)]
172pub struct DetPolygonUnclipper {
173    config: DetPolygonUnclipperConfig,
174}
175
176impl DetPolygonUnclipper {
177    pub fn new(config: DetPolygonUnclipperConfig) -> Self {
178        Self { config }
179    }
180
181    pub fn unclip_contours(&self, contours: &[Contour<i32>]) -> Vec<Polygon<f64>> {
182        contours
183            .iter()
184            .filter_map(|contour| contour_to_polygon(contour))
185            .flat_map(|polygon| self.unclip_polygon(&polygon))
186            .filter(|polygon| polygon_area(polygon) >= self.config.min_result_area)
187            .collect()
188    }
189
190    fn unclip_polygon(&self, polygon: &Polygon<f64>) -> Vec<Polygon<f64>> {
191        let distance = unclip_distance(polygon, self.config.unclip_ratio.max(0.0));
192        if distance <= f64::EPSILON {
193            return vec![polygon.clone()];
194        }
195
196        let shape = polygon_to_shape(polygon);
197        let style = OutlineStyle::default()
198            .outer_offset(distance)
199            .inner_offset(0.0)
200            .line_join(self.config.join_style.to_line_join());
201
202        let options = OverlayOptions::default();
203        shape
204            .outline_custom(&style, options)
205            .into_iter()
206            .filter_map(shape_to_polygon)
207            .collect()
208    }
209}
210
/// How restored polygon coordinates are rounded.
#[derive(Clone, Copy, Debug)]
pub enum DetScaleRounding {
    /// Leave coordinates exactly as computed.
    None,
    /// Keep the given number of digits after the decimal point.
    FractionalDigits(u32),
}

impl Default for DetScaleRounding {
    /// Rounds to two fractional digits.
    fn default() -> Self {
        DetScaleRounding::FractionalDigits(2)
    }
}
225
226/// Configuration for polygon scaling back to original image coordinates.
227#[derive(Debug, Clone, Copy)]
228pub struct DetPolygonScalerConfig {
229    /// Whether to clamp coordinates to the original image bounds.
230    pub clamp_to_image: bool,
231    /// Rounding strategy applied after scaling.
232    pub rounding: DetScaleRounding,
233}
234
235impl Default for DetPolygonScalerConfig {
236    fn default() -> Self {
237        Self {
238            clamp_to_image: true,
239            rounding: DetScaleRounding::FractionalDigits(2),
240        }
241    }
242}
243
244/// Scales polygons from resized space back to the original image coordinate system.
245#[derive(Debug, Clone)]
246pub struct DetPolygonScaler {
247    config: DetPolygonScalerConfig,
248}
249
250impl DetPolygonScaler {
251    pub fn new(config: DetPolygonScalerConfig) -> Self {
252        Self { config }
253    }
254
255    /// Converts polygons generated in resized space back to the original image space.
256    ///
257    /// * `scale_ratio` - Resize ratio used during preprocessing (resized / original).
258    /// * `original_dims` - Width and height of the original image (in pixels).
259    pub fn scale_polygons(
260        &self,
261        polygons: &[Polygon<f64>],
262        scale_ratio: f64,
263        original_dims: (u32, u32),
264    ) -> Vec<Polygon<f64>> {
265        if scale_ratio <= f64::EPSILON {
266            return polygons.to_vec();
267        }
268
269        let inverse_scale = 1.0 / scale_ratio;
270
271        polygons
272            .iter()
273            .map(|polygon| self.scale_polygon(polygon, inverse_scale, original_dims))
274            .collect()
275    }
276
277    fn scale_polygon(
278        &self,
279        polygon: &Polygon<f64>,
280        inverse_scale: f64,
281        original_dims: (u32, u32),
282    ) -> Polygon<f64> {
283        let exterior = self.scale_line_string(polygon.exterior(), inverse_scale, original_dims);
284        let interiors = polygon
285            .interiors()
286            .iter()
287            .map(|line| self.scale_line_string(line, inverse_scale, original_dims))
288            .collect();
289
290        Polygon::new(exterior, interiors)
291    }
292
293    fn scale_line_string(
294        &self,
295        line: &LineString<f64>,
296        inverse_scale: f64,
297        original_dims: (u32, u32),
298    ) -> LineString<f64> {
299        let precision = match self.config.rounding {
300            DetScaleRounding::None => None,
301            DetScaleRounding::FractionalDigits(p) => Some(p),
302        };
303
304        let mut coords: Vec<Coord<f64>> = line
305            .points()
306            .map(|p| {
307                let mut x = p.x() * inverse_scale;
308                let mut y = p.y() * inverse_scale;
309
310                if self.config.clamp_to_image {
311                    x = clamp_to_bounds(x, original_dims.0);
312                    y = clamp_to_bounds(y, original_dims.1);
313                }
314
315                if let Some(precision) = precision {
316                    x = round_fractional(x, precision);
317                    y = round_fractional(y, precision);
318                }
319
320                Coord { x, y }
321            })
322            .collect();
323
324        close_if_needed(&mut coords);
325        LineString::from(coords)
326    }
327}
328
329fn contour_to_polygon(contour: &Contour<i32>) -> Option<Polygon<f64>> {
330    if contour.points.len() < 3 {
331        return None;
332    }
333
334    let mut coords: Vec<Coord<f64>> = contour
335        .points
336        .iter()
337        .map(|point| Coord {
338            x: point.x as f64,
339            y: point.y as f64,
340        })
341        .collect();
342
343    close_if_needed(&mut coords);
344
345    // Ensure outer contour is counter-clockwise.
346    if signed_area_coords(&coords) < 0.0 {
347        coords.reverse();
348        close_if_needed(&mut coords);
349    }
350
351    let exterior = LineString::from(coords);
352    Some(Polygon::new(exterior, Vec::new()))
353}
354
355fn polygon_to_shape(polygon: &Polygon<f64>) -> Vec<Vec<[f64; 2]>> {
356    let mut shape = Vec::with_capacity(1 + polygon.interiors().len());
357    shape.push(linestring_to_points(polygon.exterior(), true));
358    for interior in polygon.interiors() {
359        shape.push(linestring_to_points(interior, false));
360    }
361    shape
362}
363
364fn shape_to_polygon(shape: Vec<Vec<[f64; 2]>>) -> Option<Polygon<f64>> {
365    if shape.is_empty() {
366        return None;
367    }
368
369    let exterior = LineString::from(points_to_coords(&shape[0]));
370    let interiors = shape
371        .iter()
372        .skip(1)
373        .map(|points| LineString::from(points_to_coords(points)))
374        .collect();
375
376    Some(Polygon::new(exterior, interiors))
377}
378
379fn linestring_to_points(line: &LineString<f64>, want_ccw: bool) -> Vec<[f64; 2]> {
380    let mut coords: Vec<Coord<f64>> = line
381        .points()
382        .map(|p| Coord { x: p.x(), y: p.y() })
383        .collect();
384    close_if_needed(&mut coords);
385
386    let area = signed_area_coords(&coords);
387    if want_ccw && area < 0.0 || !want_ccw && area > 0.0 {
388        coords.reverse();
389        close_if_needed(&mut coords);
390    }
391
392    coords.iter().map(|c| [c.x, c.y]).collect()
393}
394
395fn points_to_coords(points: &[[f64; 2]]) -> Vec<Coord<f64>> {
396    let mut coords: Vec<Coord<f64>> = points
397        .iter()
398        .map(|point| Coord {
399            x: point[0],
400            y: point[1],
401        })
402        .collect();
403    close_if_needed(&mut coords);
404    coords
405}
406
407fn close_if_needed(coords: &mut Vec<Coord<f64>>) {
408    if coords.len() < 2 {
409        return;
410    }
411    let first = coords.first().copied().unwrap();
412    let last = coords.last().copied().unwrap();
413    if first.x != last.x || first.y != last.y {
414        coords.push(first);
415    }
416}
417
418fn signed_area_coords(coords: &[Coord<f64>]) -> f64 {
419    if coords.len() < 2 {
420        return 0.0;
421    }
422
423    let mut area = 0.0;
424    for window in coords.windows(2) {
425        if let [a, b] = window {
426            area += a.x * b.y - b.x * a.y;
427        }
428    }
429    area * 0.5
430}
431
432fn perimeter_coords(coords: &[Coord<f64>]) -> f64 {
433    if coords.len() < 2 {
434        return 0.0;
435    }
436
437    let mut length = 0.0;
438    for window in coords.windows(2) {
439        if let [a, b] = window {
440            let dx = b.x - a.x;
441            let dy = b.y - a.y;
442            length += (dx * dx + dy * dy).sqrt();
443        }
444    }
445    length
446}
447
448fn polygon_area(polygon: &Polygon<f64>) -> f32 {
449    let mut area = signed_area_coords(&points_to_coords(&linestring_to_points(
450        polygon.exterior(),
451        true,
452    )))
453    .abs();
454
455    for interior in polygon.interiors() {
456        area -= signed_area_coords(&points_to_coords(&linestring_to_points(interior, false))).abs();
457    }
458
459    area as f32
460}
461
462fn unclip_distance(polygon: &Polygon<f64>, ratio: f32) -> f64 {
463    if ratio <= 0.0 {
464        return 0.0;
465    }
466
467    let exterior_coords = points_to_coords(&linestring_to_points(polygon.exterior(), true));
468    let area = signed_area_coords(&exterior_coords).abs();
469    let perimeter = perimeter_coords(&exterior_coords);
470
471    if perimeter <= f64::EPSILON {
472        0.0
473    } else {
474        (area / perimeter) * ratio as f64
475    }
476}
477
/// Restricts `value` to the inclusive range `0.0..=bound`; NaN is passed through.
fn clamp_to_bounds(value: f64, bound: u32) -> f64 {
    let limit = f64::from(bound);
    if value < 0.0 {
        0.0
    } else if value > limit {
        limit
    } else {
        value
    }
}
482
/// Rounds `value` to `digits` digits after the decimal point.
///
/// If the requested precision cannot be represented (`digits` exceeds `i32::MAX`, or
/// scaling by `10^digits` overflows to a non-finite number), `value` is returned
/// unchanged instead of producing NaN or rounding at the wrong magnitude.
fn round_fractional(value: f64, digits: u32) -> f64 {
    // `digits as i32` would wrap to a negative exponent for huge inputs.
    if digits > i32::MAX as u32 {
        return value;
    }

    let factor = 10_f64.powi(digits as i32);
    let scaled = value * factor;

    // Covers an infinite factor, overflow of the product, and NaN/infinite inputs.
    if !scaled.is_finite() {
        return value;
    }

    scaled.round() / factor
}
487
#[cfg(test)]
mod tests {
    use super::*;
    use imageproc::contours::BorderType;
    use ndarray::array;

    /// A 2x2 foreground square yields exactly one contour with positive area.
    #[test]
    fn extracts_single_square_contour() {
        let probability_map = array![
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 1.0, 0.0],
            [0.0, 1.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0]
        ];

        let processor = DetPostProcessor::new(DetPostProcessorConfig {
            threshold: 0.5,
            min_area: 1.0,
        });

        let contours = processor.process_probability_map(&probability_map).unwrap();
        assert_eq!(contours.len(), 1);

        let area = contour_area(&contours[0]);
        assert!(area >= 1.0, "expected positive area, got {}", area);
    }

    /// A single foreground pixel is removed by the filters.
    #[test]
    fn filters_small_regions() {
        let probability_map = array![
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0]
        ];

        let processor = DetPostProcessor::new(DetPostProcessorConfig {
            threshold: 0.5,
            min_area: 5.0,
        });

        let contours = processor.process_probability_map(&probability_map).unwrap();
        assert!(contours.is_empty());
    }

    /// An empty map is rejected with `EmptyProbabilityMap`.
    #[test]
    fn empty_probability_map_is_error() {
        let probability_map = Array2::<f32>::zeros((0, 0));
        let processor = DetPostProcessor::new(DetPostProcessorConfig::default());

        let err = processor
            .process_probability_map(&probability_map)
            .unwrap_err();
        // `matches!` only evaluates to a bool; it must be asserted to check anything.
        assert!(
            matches!(err, DetPostProcessorError::EmptyProbabilityMap),
            "unexpected error: {:?}",
            err
        );
    }

    /// Unclipping a 4x4 square produces a polygon with a larger area.
    #[test]
    fn unclip_makes_polygon_larger() {
        let contour = Contour::new(
            vec![
                Point::new(0, 0),
                Point::new(4, 0),
                Point::new(4, 4),
                Point::new(0, 4),
            ],
            BorderType::Outer,
            None,
        );

        let unclipper = DetPolygonUnclipper::new(DetPolygonUnclipperConfig {
            unclip_ratio: 2.0,
            min_result_area: 1.0,
            join_style: DetUnclipLineJoin::Round(0.1),
        });

        let unclipped = unclipper.unclip_contours(&[contour]);
        assert!(!unclipped.is_empty());

        let original_area = 16.0;
        let enlarged = unclipped
            .iter()
            .map(|poly| polygon_area(poly) as f64)
            .fold(0.0, f64::max);

        assert!(
            enlarged > original_area,
            "expected unclip area ({}) to exceed original ({})",
            enlarged,
            original_area
        );
    }

    /// A ratio of 0.5 doubles every coordinate.
    #[test]
    fn scaler_restores_original_coordinates() {
        let polygon = Polygon::new(
            LineString::from(vec![
                Coord { x: 50.0, y: 20.0 },
                Coord { x: 150.0, y: 20.0 },
                Coord { x: 150.0, y: 120.0 },
                Coord { x: 50.0, y: 120.0 },
                Coord { x: 50.0, y: 20.0 },
            ]),
            Vec::new(),
        );

        let scaler = DetPolygonScaler::new(DetPolygonScalerConfig::default());
        let scaled = scaler.scale_polygons(&[polygon], 0.5, (400, 400));

        assert_eq!(scaled.len(), 1);
        let exterior = scaled[0].exterior();
        let expected = vec![
            Coord { x: 100.0, y: 40.0 },
            Coord { x: 300.0, y: 40.0 },
            Coord { x: 300.0, y: 240.0 },
            Coord { x: 100.0, y: 240.0 },
            Coord { x: 100.0, y: 40.0 },
        ];

        // `zip` stops at the shorter side, so check the length separately.
        assert_eq!(exterior.0.len(), expected.len());

        for (point, expected) in exterior.points().zip(expected.iter()) {
            assert!((point.x() - expected.x).abs() < 1e-6 && (point.y() - expected.y).abs() < 1e-6);
        }
    }

    /// With clamping enabled, out-of-image coordinates end up inside the bounds.
    #[test]
    fn scaler_clamps_coordinates_when_enabled() {
        let polygon = Polygon::new(
            LineString::from(vec![
                Coord { x: 500.0, y: 500.0 },
                Coord { x: 600.0, y: 500.0 },
                Coord { x: 600.0, y: 600.0 },
                Coord { x: 500.0, y: 600.0 },
                Coord { x: 500.0, y: 500.0 },
            ]),
            Vec::new(),
        );

        let scaler = DetPolygonScaler::new(DetPolygonScalerConfig {
            clamp_to_image: true,
            rounding: DetScaleRounding::None,
        });

        let scaled = scaler.scale_polygons(&[polygon], 1.0, (256, 256));
        let exterior = scaled[0].exterior();

        for point in exterior.points() {
            assert!(
                (0.0..=256.0).contains(&point.x()),
                "expected x within bounds, got {}",
                point.x()
            );
            assert!(
                (0.0..=256.0).contains(&point.y()),
                "expected y within bounds, got {}",
                point.y()
            );
        }
    }
}