ocr_rs/
postprocess.rs

1//! Postprocessing Utilities
2//!
3//! Provides post-processing functions for text detection results, including bounding box extraction, NMS, box merging, etc.
4
5use image::GrayImage;
6use imageproc::contours::{find_contours, Contour};
7use imageproc::point::Point;
8use imageproc::rect::Rect;
9
10/// Text bounding box
11#[derive(Debug, Clone)]
12pub struct TextBox {
13    /// Bounding box rectangle
14    pub rect: Rect,
15    /// Confidence score
16    pub score: f32,
17    /// Four corner points (optional, for rotated boxes)
18    pub points: Option<[Point<f32>; 4]>,
19}
20
21impl TextBox {
22    /// Create new text bounding box
23    pub fn new(rect: Rect, score: f32) -> Self {
24        Self {
25            rect,
26            score,
27            points: None,
28        }
29    }
30
31    /// Create with corner points
32    pub fn with_points(rect: Rect, score: f32, points: [Point<f32>; 4]) -> Self {
33        Self {
34            rect,
35            score,
36            points: Some(points),
37        }
38    }
39
40    /// Calculate area
41    pub fn area(&self) -> u32 {
42        self.rect.width() * self.rect.height()
43    }
44
45    /// Expand bounding box
46    pub fn expand(&self, border: u32, max_width: u32, max_height: u32) -> Self {
47        let x = (self.rect.left() - border as i32).max(0) as u32;
48        let y = (self.rect.top() - border as i32).max(0) as u32;
49        let right = ((self.rect.left() as u32 + self.rect.width()) + border).min(max_width);
50        let bottom = ((self.rect.top() as u32 + self.rect.height()) + border).min(max_height);
51
52        // 确保 right >= x 和 bottom >= y,避免减法溢出
53        let width = if right > x { right - x } else { 1 };
54        let height = if bottom > y { bottom - y } else { 1 };
55
56        Self {
57            rect: Rect::at(x as i32, y as i32).of_size(width, height),
58            score: self.score,
59            points: self.points,
60        }
61    }
62}
63
64/// Extract text bounding boxes from segmentation mask
65///
66/// # Parameters
67/// - `mask`: Binarized mask (0 or 255)
68/// - `width`: Mask width
69/// - `height`: Mask height
70/// - `original_width`: Original image width
71/// - `original_height`: Original image height
72/// - `min_area`: Minimum bounding box area
73/// - `box_threshold`: Bounding box score threshold
74pub fn extract_boxes_from_mask(
75    mask: &[u8],
76    width: u32,
77    height: u32,
78    original_width: u32,
79    original_height: u32,
80    min_area: u32,
81    _box_threshold: f32,
82) -> Vec<TextBox> {
83    extract_boxes_from_mask_with_padding(
84        mask,
85        width,
86        height,
87        width,
88        height,
89        original_width,
90        original_height,
91        min_area,
92        _box_threshold,
93    )
94}
95
96/// Extract text bounding boxes from segmentation mask with padding
97///
98/// # Parameters
99/// - `mask`: Binarized mask (0 or 255)
100/// - `mask_width`: Mask width (including padding)
101/// - `mask_height`: Mask height (including padding)
102/// - `valid_width`: Valid region width (excluding padding)
103/// - `valid_height`: Valid region height (excluding padding)
104/// - `original_width`: Original image width
105/// - `original_height`: Original image height
106/// - `min_area`: Minimum bounding box area
107/// - `box_threshold`: Bounding box score threshold
108pub fn extract_boxes_from_mask_with_padding(
109    mask: &[u8],
110    mask_width: u32,
111    mask_height: u32,
112    valid_width: u32,
113    valid_height: u32,
114    original_width: u32,
115    original_height: u32,
116    min_area: u32,
117    _box_threshold: f32,
118) -> Vec<TextBox> {
119    extract_boxes_with_unclip(
120        mask,
121        mask_width,
122        mask_height,
123        valid_width,
124        valid_height,
125        original_width,
126        original_height,
127        min_area,
128        1.5, // 默认 unclip_ratio
129    )
130}
131
132/// Extract text bounding boxes from segmentation mask (with unclip expansion)
133///
134/// Core of DB algorithm is to perform unclip expansion on detected contours,
135/// because model output segmentation mask is usually smaller than actual text region.
136pub fn extract_boxes_with_unclip(
137    mask: &[u8],
138    mask_width: u32,
139    mask_height: u32,
140    valid_width: u32,
141    valid_height: u32,
142    original_width: u32,
143    original_height: u32,
144    min_area: u32,
145    unclip_ratio: f32,
146) -> Vec<TextBox> {
147    // Create grayscale image
148    let gray_image = GrayImage::from_raw(mask_width, mask_height, mask.to_vec())
149        .unwrap_or_else(|| GrayImage::new(mask_width, mask_height));
150
151    // Find contours
152    let contours = find_contours::<i32>(&gray_image);
153
154    // Calculate scale ratio (from valid region to original image)
155    let scale_x = original_width as f32 / valid_width as f32;
156    let scale_y = original_height as f32 / valid_height as f32;
157
158    let mut boxes = Vec::new();
159
160    for contour in contours {
161        // Only keep outer contours (without parent), filter out inner/nested contours
162        // This avoids producing overlapping detection boxes
163        if contour.parent.is_some() {
164            continue;
165        }
166
167        if contour.points.len() < 4 {
168            continue;
169        }
170
171        // Calculate bounding box
172        let (min_x, min_y, max_x, max_y) = get_contour_bounds(&contour);
173
174        // Filter out contours in padding area
175        if min_x >= valid_width as i32 || min_y >= valid_height as i32 {
176            continue;
177        }
178
179        // Clip to valid region
180        let min_x = min_x.max(0);
181        let min_y = min_y.max(0);
182        let max_x = max_x.min(valid_width as i32);
183        let max_y = max_y.min(valid_height as i32);
184
185        let box_width = (max_x - min_x) as u32;
186        let box_height = (max_y - min_y) as u32;
187
188        // Filter boxes that are too small
189        if box_width * box_height < min_area {
190            continue;
191        }
192
193        // Calculate unclip expansion amount
194        // DB algorithm uses area and perimeter to calculate expansion distance: distance = Area * unclip_ratio / Perimeter
195        let area = box_width as f32 * box_height as f32;
196        let perimeter = 2.0 * (box_width + box_height) as f32;
197        let expand_dist = (area * unclip_ratio / perimeter).max(1.0);
198
199        // Apply unclip expansion (on coordinates before scaling)
200        let expanded_min_x = (min_x as f32 - expand_dist).max(0.0) as i32;
201        let expanded_min_y = (min_y as f32 - expand_dist).max(0.0) as i32;
202        let expanded_max_x = (max_x as f32 + expand_dist).min(valid_width as f32) as i32;
203        let expanded_max_y = (max_y as f32 + expand_dist).min(valid_height as f32) as i32;
204
205        let expanded_w = (expanded_max_x - expanded_min_x) as u32;
206        let expanded_h = (expanded_max_y - expanded_min_y) as u32;
207
208        // Scale to original image size
209        let scaled_x = (expanded_min_x as f32 * scale_x) as i32;
210        let scaled_y = (expanded_min_y as f32 * scale_y) as i32;
211        let scaled_w = (expanded_w as f32 * scale_x) as u32;
212        let scaled_h = (expanded_h as f32 * scale_y) as u32;
213
214        // Ensure boundaries are within valid range
215        let final_x = scaled_x.max(0) as u32;
216        let final_y = scaled_y.max(0) as u32;
217        let final_w = scaled_w.min(original_width.saturating_sub(final_x));
218        let final_h = scaled_h.min(original_height.saturating_sub(final_y));
219
220        if final_w > 0 && final_h > 0 {
221            let rect = Rect::at(final_x as i32, final_y as i32).of_size(final_w, final_h);
222            boxes.push(TextBox::new(rect, 1.0));
223        }
224    }
225
226    boxes
227}
228
229/// Get contour bounds
230fn get_contour_bounds(contour: &Contour<i32>) -> (i32, i32, i32, i32) {
231    let mut min_x = i32::MAX;
232    let mut min_y = i32::MAX;
233    let mut max_x = i32::MIN;
234    let mut max_y = i32::MIN;
235
236    for point in &contour.points {
237        min_x = min_x.min(point.x);
238        min_y = min_y.min(point.y);
239        max_x = max_x.max(point.x);
240        max_y = max_y.max(point.y);
241    }
242
243    (min_x, min_y, max_x, max_y)
244}
245
246/// Calculate containment ratio of one box inside another
247fn compute_containment_ratio(inner: &Rect, outer: &Rect) -> f32 {
248    let x1 = inner.left().max(outer.left());
249    let y1 = inner.top().max(outer.top());
250    let x2 = (inner.left() + inner.width() as i32).min(outer.left() + outer.width() as i32);
251    let y2 = (inner.top() + inner.height() as i32).min(outer.top() + outer.height() as i32);
252
253    if x2 <= x1 || y2 <= y1 {
254        return 0.0;
255    }
256
257    let intersection = (x2 - x1) as f32 * (y2 - y1) as f32;
258    let inner_area = inner.width() as f32 * inner.height() as f32;
259
260    if inner_area <= 0.0 {
261        0.0
262    } else {
263        intersection / inner_area
264    }
265}
266
267/// Non-Maximum Suppression (NMS)
268///
269/// Filter overlapping bounding boxes, keep ones with highest scores
270/// Also filters small boxes that are largely contained within other boxes
271///
272/// # Parameters
273/// - `boxes`: List of bounding boxes
274/// - `iou_threshold`: IoU threshold, boxes exceeding this value are considered overlapping
275pub fn nms(boxes: &[TextBox], iou_threshold: f32) -> Vec<TextBox> {
276    if boxes.is_empty() {
277        return Vec::new();
278    }
279
280    // Sort by score descending, area descending (keep boxes with higher score and larger area first)
281    let mut indices: Vec<usize> = (0..boxes.len()).collect();
282    indices.sort_by(|&a, &b| {
283        // First sort by score descending
284        let score_cmp = boxes[b]
285            .score
286            .partial_cmp(&boxes[a].score)
287            .unwrap_or(std::cmp::Ordering::Equal);
288        if score_cmp != std::cmp::Ordering::Equal {
289            return score_cmp;
290        }
291        // When scores are equal, sort by area descending (prefer larger boxes)
292        boxes[b].area().cmp(&boxes[a].area())
293    });
294
295    let mut keep = Vec::new();
296    let mut suppressed = vec![false; boxes.len()];
297
298    for (pos, &i) in indices.iter().enumerate() {
299        if suppressed[i] {
300            continue;
301        }
302
303        keep.push(boxes[i].clone());
304
305        // Check all subsequent boxes (lower score or smaller area)
306        for &j in indices.iter().skip(pos + 1) {
307            if suppressed[j] {
308                continue;
309            }
310
311            // Check IoU
312            let iou = compute_iou(&boxes[i].rect, &boxes[j].rect);
313            if iou > iou_threshold {
314                suppressed[j] = true;
315                continue;
316            }
317
318            // Check containment relationship: if j is largely contained (>50%) by i, suppress j
319            let containment_j_in_i = compute_containment_ratio(&boxes[j].rect, &boxes[i].rect);
320            if containment_j_in_i > 0.5 {
321                suppressed[j] = true;
322                continue;
323            }
324
325            // Check reverse containment: if i is largely contained (>70%) by j,
326            // since i was selected first (higher score or larger area), suppress j
327            let containment_i_in_j = compute_containment_ratio(&boxes[i].rect, &boxes[j].rect);
328            if containment_i_in_j > 0.7 {
329                suppressed[j] = true;
330                continue;
331            }
332        }
333    }
334
335    keep
336}
337
338/// Calculate IoU (Intersection over Union) of two rectangles
339pub fn compute_iou(a: &Rect, b: &Rect) -> f32 {
340    let x1 = a.left().max(b.left());
341    let y1 = a.top().max(b.top());
342    let x2 = (a.left() + a.width() as i32).min(b.left() + b.width() as i32);
343    let y2 = (a.top() + a.height() as i32).min(b.top() + b.height() as i32);
344
345    if x2 <= x1 || y2 <= y1 {
346        return 0.0;
347    }
348
349    let intersection = (x2 - x1) as f32 * (y2 - y1) as f32;
350    let area_a = a.width() as f32 * a.height() as f32;
351    let area_b = b.width() as f32 * b.height() as f32;
352    let union = area_a + area_b - intersection;
353
354    if union <= 0.0 {
355        0.0
356    } else {
357        intersection / union
358    }
359}
360
361/// Merge adjacent bounding boxes
362///
363/// Merge bounding boxes that are close to each other into one
364///
365/// # Parameters
366/// - `boxes`: List of bounding boxes
367/// - `distance_threshold`: Distance threshold, boxes below this value will be merged
368pub fn merge_adjacent_boxes(boxes: &[TextBox], distance_threshold: i32) -> Vec<TextBox> {
369    if boxes.is_empty() {
370        return Vec::new();
371    }
372
373    let mut merged = Vec::new();
374    let mut used = vec![false; boxes.len()];
375
376    for i in 0..boxes.len() {
377        if used[i] {
378            continue;
379        }
380
381        let mut current = boxes[i].rect;
382        let mut group_score = boxes[i].score;
383        let mut count = 1;
384        used[i] = true;
385
386        // Find boxes that can be merged
387        loop {
388            let mut found = false;
389
390            for j in 0..boxes.len() {
391                if used[j] {
392                    continue;
393                }
394
395                if can_merge(&current, &boxes[j].rect, distance_threshold) {
396                    current = merge_rects(&current, &boxes[j].rect);
397                    group_score += boxes[j].score;
398                    count += 1;
399                    used[j] = true;
400                    found = true;
401                }
402            }
403
404            if !found {
405                break;
406            }
407        }
408
409        merged.push(TextBox::new(current, group_score / count as f32));
410    }
411
412    merged
413}
414
415/// Check if two boxes can be merged
416fn can_merge(a: &Rect, b: &Rect, threshold: i32) -> bool {
417    // Calculate vertical distance
418    let a_bottom = a.top() + a.height() as i32;
419    let b_bottom = b.top() + b.height() as i32;
420
421    let _vertical_dist = if a.top() > b_bottom {
422        a.top() - b_bottom
423    } else if b.top() > a_bottom {
424        b.top() - a_bottom
425    } else {
426        0 // Vertical overlap
427    };
428
429    // Calculate horizontal distance
430    let a_right = a.left() + a.width() as i32;
431    let b_right = b.left() + b.width() as i32;
432
433    let horizontal_dist = if a.left() > b_right {
434        a.left() - b_right
435    } else if b.left() > a_right {
436        b.left() - a_right
437    } else {
438        0 // Horizontal overlap
439    };
440
441    // Check if on same line (vertical overlap) and horizontal distance is less than threshold
442    let vertical_overlap = !(a.top() > b_bottom || b.top() > a_bottom);
443
444    vertical_overlap && horizontal_dist <= threshold
445}
446
447/// Merge two rectangles
448fn merge_rects(a: &Rect, b: &Rect) -> Rect {
449    let x1 = a.left().min(b.left());
450    let y1 = a.top().min(b.top());
451    let x2 = (a.left() + a.width() as i32).max(b.left() + b.width() as i32);
452    let y2 = (a.top() + a.height() as i32).max(b.top() + b.height() as i32);
453
454    Rect::at(x1, y1).of_size((x2 - x1) as u32, (y2 - y1) as u32)
455}
456
457/// Sort bounding boxes by reading order (top to bottom, left to right)
458pub fn sort_boxes_by_reading_order(boxes: &mut [TextBox]) {
459    boxes.sort_by(|a, b| {
460        // First sort by y coordinate (row)
461        let y_cmp = a.rect.top().cmp(&b.rect.top());
462        if y_cmp != std::cmp::Ordering::Equal {
463            return y_cmp;
464        }
465        // Same row, sort by x coordinate
466        a.rect.left().cmp(&b.rect.left())
467    });
468}
469
470/// Group bounding boxes by line
471///
472/// Group boxes with close y coordinates into the same line
473pub fn group_boxes_by_line(boxes: &[TextBox], line_threshold: i32) -> Vec<Vec<TextBox>> {
474    if boxes.is_empty() {
475        return Vec::new();
476    }
477
478    let mut sorted_boxes = boxes.to_vec();
479    sorted_boxes.sort_by_key(|b| b.rect.top());
480
481    let mut lines: Vec<Vec<TextBox>> = Vec::new();
482    let mut current_line: Vec<TextBox> = vec![sorted_boxes[0].clone()];
483    let mut current_y = sorted_boxes[0].rect.top();
484
485    for box_item in sorted_boxes.iter().skip(1) {
486        if (box_item.rect.top() - current_y).abs() <= line_threshold {
487            current_line.push(box_item.clone());
488        } else {
489            // Sort current line by x
490            current_line.sort_by_key(|b| b.rect.left());
491            lines.push(current_line);
492            current_line = vec![box_item.clone()];
493            current_y = box_item.rect.top();
494        }
495    }
496
497    // Add last line
498    if !current_line.is_empty() {
499        current_line.sort_by_key(|b| b.rect.left());
500        lines.push(current_line);
501    }
502
503    lines
504}
505
506/// Merge bounding boxes from multiple detection results (for high precision mode)
507///
508/// # Parameters
509/// - `results`: Multiple detection results, each element is (boxes, offset_x, offset_y, scale)
510/// - `iou_threshold`: NMS IoU threshold
511pub fn merge_multi_scale_results(
512    results: &[(Vec<TextBox>, u32, u32, f32)],
513    iou_threshold: f32,
514) -> Vec<TextBox> {
515    let mut all_boxes = Vec::new();
516
517    for (boxes, offset_x, offset_y, scale) in results {
518        for box_item in boxes {
519            // Convert box coordinates to original image coordinate system
520            let scaled_x = (box_item.rect.left() as f32 / scale) as i32 + *offset_x as i32;
521            let scaled_y = (box_item.rect.top() as f32 / scale) as i32 + *offset_y as i32;
522            let scaled_w = (box_item.rect.width() as f32 / scale) as u32;
523            let scaled_h = (box_item.rect.height() as f32 / scale) as u32;
524
525            let rect = Rect::at(scaled_x, scaled_y).of_size(scaled_w, scaled_h);
526            all_boxes.push(TextBox::new(rect, box_item.score));
527        }
528    }
529
530    // Apply NMS to remove duplicates
531    nms(&all_boxes, iou_threshold)
532}
533
534// ============== Traditional Algorithm Detection ==============
535
536/// Detect text regions using traditional algorithm (suitable for solid background)
537///
538/// Based on OTSU binarization + connected component analysis, suitable for:
539/// - Document images with solid background
540/// - High contrast text
541/// - As supplement to deep learning detection
542///
543/// # Parameters
544/// - `gray_image`: Grayscale image
545/// - `min_area`: Minimum text region area
546/// - `expand_ratio`: Bounding box expansion ratio
547pub fn detect_text_traditional(
548    gray_image: &GrayImage,
549    min_area: u32,
550    expand_ratio: f32,
551) -> Vec<TextBox> {
552    let (width, height) = gray_image.dimensions();
553
554    // 1. Calculate OTSU threshold
555    let threshold = otsu_threshold(gray_image);
556
557    // 2. Binarization
558    let binary: Vec<u8> = gray_image
559        .pixels()
560        .map(|p| if p.0[0] < threshold { 255 } else { 0 })
561        .collect();
562
563    // 3. Create binary image and find contours
564    let binary_image =
565        GrayImage::from_raw(width, height, binary).unwrap_or_else(|| GrayImage::new(width, height));
566    let contours = find_contours::<i32>(&binary_image);
567
568    // 4. Extract bounding boxes
569    let mut boxes = Vec::new();
570    for contour in contours {
571        if contour.points.len() < 4 {
572            continue;
573        }
574
575        let (min_x, min_y, max_x, max_y) = get_contour_bounds(&contour);
576        let box_width = (max_x - min_x) as u32;
577        let box_height = (max_y - min_y) as u32;
578
579        if box_width * box_height < min_area {
580            continue;
581        }
582
583        // Expand bounding box
584        let expand_w = (box_width as f32 * expand_ratio * 0.5) as i32;
585        let expand_h = (box_height as f32 * expand_ratio * 0.5) as i32;
586
587        let final_x = (min_x - expand_w).max(0) as u32;
588        let final_y = (min_y - expand_h).max(0) as u32;
589        let final_w = ((max_x + expand_w) as u32)
590            .min(width)
591            .saturating_sub(final_x);
592        let final_h = ((max_y + expand_h) as u32)
593            .min(height)
594            .saturating_sub(final_y);
595
596        if final_w > 0 && final_h > 0 {
597            let rect = Rect::at(final_x as i32, final_y as i32).of_size(final_w, final_h);
598            boxes.push(TextBox::new(rect, 1.0));
599        }
600    }
601
602    // 5. Merge adjacent boxes to form text lines
603    merge_into_text_lines(&boxes, 10)
604}
605
606/// OTSU adaptive threshold calculation
607fn otsu_threshold(image: &GrayImage) -> u8 {
608    // Calculate histogram
609    let mut histogram = [0u32; 256];
610    for pixel in image.pixels() {
611        histogram[pixel.0[0] as usize] += 1;
612    }
613
614    let total = image.pixels().count() as f64;
615    let mut sum = 0.0;
616    for (i, &count) in histogram.iter().enumerate() {
617        sum += i as f64 * count as f64;
618    }
619
620    let mut sum_b = 0.0;
621    let mut w_b = 0.0;
622    let mut max_variance = 0.0;
623    let mut threshold = 0u8;
624
625    for (t, &count) in histogram.iter().enumerate() {
626        w_b += count as f64;
627        if w_b == 0.0 {
628            continue;
629        }
630
631        let w_f = total - w_b;
632        if w_f == 0.0 {
633            break;
634        }
635
636        sum_b += t as f64 * count as f64;
637        let m_b = sum_b / w_b;
638        let m_f = (sum - sum_b) / w_f;
639
640        let variance = w_b * w_f * (m_b - m_f).powi(2);
641        if variance > max_variance {
642            max_variance = variance;
643            threshold = t as u8;
644        }
645    }
646
647    threshold
648}
649
650/// Merge independent character boxes into text lines
651fn merge_into_text_lines(boxes: &[TextBox], gap_threshold: i32) -> Vec<TextBox> {
652    if boxes.is_empty() {
653        return Vec::new();
654    }
655
656    // Group by y coordinate
657    let mut sorted_boxes: Vec<_> = boxes.iter().collect();
658    sorted_boxes.sort_by_key(|b| b.rect.top());
659
660    let mut lines: Vec<TextBox> = Vec::new();
661
662    for bbox in sorted_boxes {
663        let mut merged = false;
664
665        // Try to merge into existing lines
666        for line in &mut lines {
667            let line_center_y = line.rect.top() + line.rect.height() as i32 / 2;
668            let box_center_y = bbox.rect.top() + bbox.rect.height() as i32 / 2;
669
670            // If vertical overlap and horizontal proximity
671            if (line_center_y - box_center_y).abs() < line.rect.height() as i32 / 2 {
672                let line_right = line.rect.left() + line.rect.width() as i32;
673                let box_left = bbox.rect.left();
674
675                if (box_left - line_right).abs() < gap_threshold * 3 {
676                    // Merge
677                    let new_left = line.rect.left().min(bbox.rect.left());
678                    let new_top = line.rect.top().min(bbox.rect.top());
679                    let new_right = (line.rect.left() + line.rect.width() as i32)
680                        .max(bbox.rect.left() + bbox.rect.width() as i32);
681                    let new_bottom = (line.rect.top() + line.rect.height() as i32)
682                        .max(bbox.rect.top() + bbox.rect.height() as i32);
683
684                    line.rect = Rect::at(new_left, new_top)
685                        .of_size((new_right - new_left) as u32, (new_bottom - new_top) as u32);
686                    merged = true;
687                    break;
688                }
689            }
690        }
691
692        if !merged {
693            lines.push(bbox.clone());
694        }
695    }
696
697    lines
698}
699
#[cfg(test)]
mod tests {
    //! Unit tests for the post-processing helpers.

    use super::*;

    #[test]
    fn test_textbox_new() {
        let rect = Rect::at(10, 20).of_size(100, 50);
        let tb = TextBox::new(rect, 0.95);

        assert_eq!(tb.rect.left(), 10);
        assert_eq!(tb.rect.top(), 20);
        assert_eq!(tb.rect.width(), 100);
        assert_eq!(tb.rect.height(), 50);
        assert_eq!(tb.score, 0.95);
        assert!(tb.points.is_none());
    }

    #[test]
    fn test_textbox_with_points() {
        let rect = Rect::at(0, 0).of_size(100, 50);
        let points = [
            Point::new(0.0, 0.0),
            Point::new(100.0, 0.0),
            Point::new(100.0, 50.0),
            Point::new(0.0, 50.0),
        ];
        let tb = TextBox::with_points(rect, 0.9, points);

        assert!(tb.points.is_some());
        let pts = tb.points.unwrap();
        assert_eq!(pts[0].x, 0.0);
        assert_eq!(pts[1].x, 100.0);
    }

    #[test]
    fn test_textbox_area() {
        let tb = TextBox::new(Rect::at(0, 0).of_size(100, 50), 0.9);
        assert_eq!(tb.area(), 5000);
    }

    #[test]
    fn test_textbox_expand() {
        let tb = TextBox::new(Rect::at(50, 50).of_size(100, 100), 0.9);
        let expanded = tb.expand(10, 500, 500);

        assert_eq!(expanded.rect.left(), 40);
        assert_eq!(expanded.rect.top(), 40);
        assert_eq!(expanded.rect.width(), 120);
        assert_eq!(expanded.rect.height(), 120);
    }

    #[test]
    fn test_textbox_expand_clamp() {
        // Clamping at the top-left boundary
        let tb = TextBox::new(Rect::at(5, 5).of_size(100, 100), 0.9);
        let expanded = tb.expand(10, 200, 200);

        // The top-left corner must be limited to (0, 0)
        assert_eq!(expanded.rect.left(), 0);
        assert_eq!(expanded.rect.top(), 0);
        // The right/bottom edge still grows by the border: 5 + 100 + 10 = 115
        assert_eq!(expanded.rect.width(), 115);
        assert_eq!(expanded.rect.height(), 115);

        // Clamping at the bottom-right boundary
        let tb = TextBox::new(Rect::at(150, 150).of_size(100, 100), 0.9);
        let expanded = tb.expand(10, 200, 200);

        assert_eq!(expanded.rect.left(), 140);
        assert_eq!(expanded.rect.top(), 140);
        // The right/bottom edge is limited to 200, so the size is 200 - 140 = 60
        assert_eq!(expanded.rect.width(), 60);
        assert_eq!(expanded.rect.height(), 60);
    }

    #[test]
    fn test_compute_iou() {
        let a = Rect::at(0, 0).of_size(10, 10);
        let b = Rect::at(5, 5).of_size(10, 10);

        let iou = compute_iou(&a, &b);
        assert!(iou > 0.0 && iou < 1.0);

        // Disjoint
        let c = Rect::at(100, 100).of_size(10, 10);
        assert_eq!(compute_iou(&a, &c), 0.0);

        // Complete overlap
        assert_eq!(compute_iou(&a, &a), 1.0);
    }

    #[test]
    fn test_compute_iou_partial_overlap() {
        // 50% overlap case
        let a = Rect::at(0, 0).of_size(10, 10);
        let b = Rect::at(5, 0).of_size(10, 10);

        let iou = compute_iou(&a, &b);
        // Intersection area = 5 * 10 = 50
        // Union area = 100 + 100 - 50 = 150
        // IoU = 50 / 150 ≈ 0.333
        assert!((iou - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_nms() {
        // The first and second boxes overlap heavily, the third box is independent
        let boxes = vec![
            TextBox::new(Rect::at(0, 0).of_size(10, 10), 0.9),
            TextBox::new(Rect::at(1, 1).of_size(10, 10), 0.8), // Overlaps heavily with the first box
            TextBox::new(Rect::at(100, 100).of_size(10, 10), 0.7),
        ];

        // IoU of the first two boxes is 81 / 119 ≈ 0.68, well above the threshold
        let result = nms(&boxes, 0.3);

        // Exactly the first box (highest score) and the third box (no overlap) survive
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].rect.left(), 0);
        assert_eq!(result[0].score, 0.9);
        assert_eq!(result[1].rect.left(), 100);
        assert_eq!(result[1].score, 0.7);
    }

    #[test]
    fn test_nms_empty() {
        let boxes: Vec<TextBox> = vec![];
        let result = nms(&boxes, 0.5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_nms_single() {
        let boxes = vec![TextBox::new(Rect::at(0, 0).of_size(10, 10), 0.9)];
        let result = nms(&boxes, 0.5);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_nms_no_overlap() {
        let boxes = vec![
            TextBox::new(Rect::at(0, 0).of_size(10, 10), 0.9),
            TextBox::new(Rect::at(50, 50).of_size(10, 10), 0.8),
            TextBox::new(Rect::at(100, 100).of_size(10, 10), 0.7),
        ];

        let result = nms(&boxes, 0.5);
        assert_eq!(result.len(), 3); // All boxes are kept
    }

    #[test]
    fn test_merge_adjacent() {
        let boxes = vec![
            TextBox::new(Rect::at(0, 0).of_size(10, 10), 1.0),
            TextBox::new(Rect::at(12, 0).of_size(10, 10), 1.0), // Horizontal distance 2
            TextBox::new(Rect::at(100, 100).of_size(10, 10), 1.0),
        ];

        let result = merge_adjacent_boxes(&boxes, 5);
        assert_eq!(result.len(), 2); // The first two should be merged

        // The merged box spans both inputs: x from 0 to 22, y from 0 to 10
        assert_eq!(result[0].rect.left(), 0);
        assert_eq!(result[0].rect.top(), 0);
        assert_eq!(result[0].rect.width(), 22);
        assert_eq!(result[0].rect.height(), 10);

        // The distant box is untouched
        assert_eq!(result[1].rect.left(), 100);
        assert_eq!(result[1].rect.width(), 10);
    }

    #[test]
    fn test_merge_adjacent_empty() {
        let boxes: Vec<TextBox> = vec![];
        let result = merge_adjacent_boxes(&boxes, 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_sort_boxes_by_reading_order() {
        let mut boxes = vec![
            TextBox::new(Rect::at(100, 0).of_size(10, 10), 0.9), // First row, right
            TextBox::new(Rect::at(0, 0).of_size(10, 10), 0.9),   // First row, left
            TextBox::new(Rect::at(0, 50).of_size(10, 10), 0.9),  // Second row
        ];

        sort_boxes_by_reading_order(&mut boxes);

        // Sorted by row first, then by x coordinate within a row
        assert_eq!(boxes[0].rect.left(), 0);
        assert_eq!(boxes[0].rect.top(), 0);
        assert_eq!(boxes[1].rect.left(), 100);
        assert_eq!(boxes[1].rect.top(), 0);
        assert_eq!(boxes[2].rect.left(), 0);
        assert_eq!(boxes[2].rect.top(), 50);
    }

    #[test]
    fn test_group_boxes_by_line() {
        let boxes = vec![
            TextBox::new(Rect::at(0, 0).of_size(50, 20), 0.9),
            TextBox::new(Rect::at(60, 0).of_size(50, 20), 0.9),
            TextBox::new(Rect::at(0, 50).of_size(50, 20), 0.9),
        ];

        let lines = group_boxes_by_line(&boxes, 10);

        // Should be split into two lines
        assert_eq!(lines.len(), 2);

        // First line holds the two boxes at y = 0, ordered left to right
        assert_eq!(lines[0].len(), 2);
        assert_eq!(lines[0][0].rect.left(), 0);
        assert_eq!(lines[0][1].rect.left(), 60);

        // Second line holds the box at y = 50
        assert_eq!(lines[1].len(), 1);
        assert_eq!(lines[1][0].rect.top(), 50);
    }
}