Skip to main content

oar_ocr_core/processors/
layout_postprocess.rs

1//! Layout Detection Post-processing
2//!
3//! This module implements post-processing for layout detection models including
4//! PicoDet, RT-DETR, and PP-DocLayout series models.
5
6use crate::core::Tensor4D;
7use crate::domain::tasks::MergeBboxMode;
8use crate::processors::{BoundingBox, ImageScaleInfo, Point};
9use ndarray::{ArrayView3, Axis};
10use std::borrow::Cow;
11use std::collections::HashMap;
12
/// Batched post-processing result: `(boxes, class_ids, scores)`, where each outer `Vec`
/// holds one entry per processed image.
type LayoutPostprocessOutput = (Vec<Vec<BoundingBox>>, Vec<Vec<usize>>, Vec<Vec<f32>>);
/// Per-image NMS result that also carries the reading-order pair of every surviving box:
/// `(boxes, class_ids, scores, reading_orders)`.
type NmsResult = (Vec<BoundingBox>, Vec<usize>, Vec<f32>, Vec<(f32, f32)>);
15
/// Layout detection post-processor for models like PicoDet and RT-DETR.
///
/// This processor converts model predictions into bounding boxes with class labels
/// and confidence scores for document layout elements. Detections are filtered by
/// score, rescaled/clamped to the source image, and de-duplicated with per-class
/// non-maximum suppression.
#[derive(Debug, Clone)]
pub struct LayoutPostProcess {
    /// Number of classes the model predicts; class ids at or above this value are discarded.
    num_classes: usize,
    /// Score threshold for filtering predictions (detections scoring below it are dropped).
    score_threshold: f32,
    /// Non-maximum suppression threshold: same-class boxes whose IoU with a kept box
    /// exceeds this value are suppressed.
    nms_threshold: f32,
    /// Maximum number of detections to return per image.
    max_detections: usize,
    /// Model type (e.g., "picodet", "rtdetr", "pp-doclayout"); any other value is
    /// handled by the standard processing path.
    model_type: String,
}
33
34impl LayoutPostProcess {
35    /// Creates a new layout detection post-processor.
36    pub fn new(
37        num_classes: usize,
38        score_threshold: f32,
39        nms_threshold: f32,
40        max_detections: usize,
41        model_type: String,
42    ) -> Self {
43        Self {
44            num_classes,
45            score_threshold,
46            nms_threshold,
47            max_detections,
48            model_type,
49        }
50    }
51
52    /// Applies post-processing to layout detection model predictions.
53    ///
54    /// # Arguments
55    /// * `predictions` - Model output tensor [batch, num_boxes, 4 + num_classes]
56    /// * `img_shapes` - Original image dimensions for each image in batch
57    ///
58    /// # Returns
59    /// Tuple of (bounding_boxes, class_ids, scores) for each image in batch
60    pub fn apply(
61        &self,
62        predictions: &Tensor4D,
63        img_shapes: Vec<ImageScaleInfo>,
64    ) -> LayoutPostprocessOutput {
65        let batch_size = predictions.shape()[0];
66        let mut all_boxes = Vec::with_capacity(batch_size);
67        let mut all_classes = Vec::with_capacity(batch_size);
68        let mut all_scores = Vec::with_capacity(batch_size);
69
70        // Process each image in batch
71        for (batch_idx, img_shape) in img_shapes.into_iter().enumerate().take(batch_size) {
72            let pred = predictions.index_axis(Axis(0), batch_idx);
73
74            let (boxes, classes, scores) = match self.model_type.as_str() {
75                "picodet" => self.process_picodet(pred, &img_shape),
76                "rtdetr" => self.process_rtdetr(pred, &img_shape),
77                "pp-doclayout" => self.process_pp_doclayout(pred, &img_shape),
78                _ => self.process_standard(pred, &img_shape),
79            };
80
81            all_boxes.push(boxes);
82            all_classes.push(classes);
83            all_scores.push(scores);
84        }
85
86        (all_boxes, all_classes, all_scores)
87    }
88
89    /// Process PicoDet model output.
90    fn process_picodet(
91        &self,
92        predictions: ArrayView3<f32>,
93        img_shape: &ImageScaleInfo,
94    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
95        let mut boxes = Vec::new();
96        let mut classes = Vec::new();
97        let mut scores = Vec::new();
98
99        let orig_width = img_shape.src_w;
100        let orig_height = img_shape.src_h;
101        let shape = predictions.shape();
102        if shape.len() != 3 || shape[2] == 0 {
103            return (boxes, classes, scores);
104        }
105
106        let total_boxes = shape[0] * shape[1];
107        if total_boxes == 0 {
108            return (boxes, classes, scores);
109        }
110
111        let feature_dim = shape[2];
112        let data: Cow<'_, [f32]> = match predictions.as_slice() {
113            Some(slice) => Cow::Borrowed(slice),
114            None => {
115                let (mut vec, offset) = predictions.to_owned().into_raw_vec_and_offset();
116                if let Some(offset) = offset
117                    && offset != 0
118                {
119                    vec.drain(0..offset);
120                }
121                Cow::Owned(vec)
122            }
123        };
124
125        for box_idx in 0..total_boxes {
126            let start = box_idx * feature_dim;
127            let end = start + feature_dim;
128
129            if end > data.len() {
130                break;
131            }
132
133            let row = &data[start..end];
134            if feature_dim == 4 + self.num_classes {
135                // Format: [x1, y1, x2, y2, scores...]
136                let (max_class, max_score) = row[4..].iter().enumerate().fold(
137                    (0usize, 0.0f32),
138                    |(best_cls, best_score), (cls_idx, &score)| {
139                        if score > best_score {
140                            (cls_idx, score)
141                        } else {
142                            (best_cls, best_score)
143                        }
144                    },
145                );
146
147                if max_score < self.score_threshold {
148                    continue;
149                }
150
151                let (sx1, sy1, sx2, sy2) = self.convert_bbox_coords(
152                    row[0],
153                    row[1],
154                    row[2],
155                    row[3],
156                    orig_width,
157                    orig_height,
158                );
159
160                if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
161                    continue;
162                }
163
164                let bbox = BoundingBox::new(vec![
165                    Point::new(sx1, sy1),
166                    Point::new(sx2, sy1),
167                    Point::new(sx2, sy2),
168                    Point::new(sx1, sy2),
169                ]);
170
171                boxes.push(bbox);
172                classes.push(max_class);
173                scores.push(max_score);
174            } else if feature_dim >= 6
175                && let Some((class_id, score, x1, y1, x2, y2)) = self.parse_compact_prediction(row)
176            {
177                if score < self.score_threshold || class_id >= self.num_classes {
178                    continue;
179                }
180
181                let (sx1, sy1, sx2, sy2) =
182                    self.convert_bbox_coords(x1, y1, x2, y2, orig_width, orig_height);
183
184                if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
185                    continue;
186                }
187
188                let bbox = BoundingBox::new(vec![
189                    Point::new(sx1, sy1),
190                    Point::new(sx2, sy1),
191                    Point::new(sx2, sy2),
192                    Point::new(sx1, sy2),
193                ]);
194
195                boxes.push(bbox);
196                classes.push(class_id);
197                scores.push(score);
198            }
199        }
200
201        self.apply_nms(boxes, classes, scores)
202    }
203
    /// Process RT-DETR model output.
    ///
    /// Delegates to `process_picodet`, so the same row layouts, filtering and NMS apply.
    ///
    /// # Returns
    /// Tuple of (bounding_boxes, class_ids, scores) for one image.
    fn process_rtdetr(
        &self,
        predictions: ArrayView3<f32>,
        img_shape: &ImageScaleInfo,
    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
        // RT-DETR has similar output format to PicoDet
        self.process_picodet(predictions, img_shape)
    }
213
214    /// Process PP-DocLayout model output.
215    ///
216    /// Handles 6-dim format (PP-DocLayout), 7-dim format (PP-DocLayoutV3), and 8-dim format (PP-DocLayoutV2).
217    /// - 6-dim: [class_id, score, x1, y1, x2, y2]
218    /// - 7-dim: [class_id, score, x1, y1, x2, y2, extra]
219    /// - 8-dim: [class_id, score, x1, y1, x2, y2, col_index, row_index]
220    ///
221    /// For 8-dim format, boxes are sorted by reading order (col_index ascending, row_index ascending)
222    /// after NMS filtering.
223    fn process_pp_doclayout(
224        &self,
225        predictions: ArrayView3<f32>,
226        img_shape: &ImageScaleInfo,
227    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
228        // PP-DocLayout outputs in [num_boxes, 1, N] format
229        // where N is 6 or 8 depending on model version
230        let shape = predictions.shape();
231        let num_boxes = shape[0];
232        let feature_dim = shape[2];
233
234        let mut boxes = Vec::new();
235        let mut classes = Vec::new();
236        let mut scores = Vec::new();
237        let mut reading_orders: Vec<(f32, f32)> = Vec::new();
238
239        let orig_width = img_shape.src_w;
240        let orig_height = img_shape.src_h;
241
242        let has_reading_order = feature_dim == 8;
243
244        // Extract predictions
245        for box_idx in 0..num_boxes {
246            // predictions is [num_boxes, 1, N], so we use 3D indexing [box_idx, 0, i]
247            let class_id = predictions[[box_idx, 0, 0]] as i32;
248            let score = predictions[[box_idx, 0, 1]];
249            let x1 = predictions[[box_idx, 0, 2]];
250            let y1 = predictions[[box_idx, 0, 3]];
251            let x2 = predictions[[box_idx, 0, 4]];
252            let y2 = predictions[[box_idx, 0, 5]];
253
254            // Extract reading order info if available (8-dim format)
255            // Default to (0, box_idx) for 6-dim format to maintain original order
256            let reading_order = if has_reading_order {
257                (predictions[[box_idx, 0, 6]], predictions[[box_idx, 0, 7]])
258            } else {
259                (0.0, box_idx as f32)
260            };
261
262            // Filter by threshold and valid class
263            if score < self.score_threshold
264                || class_id < 0
265                || (class_id as usize) >= self.num_classes
266            {
267                continue;
268            }
269
270            // PP-DocLayout-style models may emit either absolute pixel coords or normalized coords.
271            // Use the same normalization heuristic as other detectors for robustness.
272            let (sx1, sy1, sx2, sy2) =
273                self.convert_bbox_coords(x1, y1, x2, y2, orig_width, orig_height);
274            if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
275                continue;
276            }
277
278            let bbox = BoundingBox::new(vec![
279                Point::new(sx1, sy1),
280                Point::new(sx2, sy1),
281                Point::new(sx2, sy2),
282                Point::new(sx1, sy2),
283            ]);
284
285            boxes.push(bbox);
286            classes.push(class_id as usize);
287            scores.push(score);
288            reading_orders.push(reading_order);
289        }
290
291        // Apply NMS with reading order preservation
292        let (filtered_boxes, filtered_classes, filtered_scores, filtered_reading_orders) =
293            self.apply_nms_with_reading_order(boxes, classes, scores, reading_orders);
294
295        // Sort by reading order if we have 8-dim format
296        if has_reading_order && !filtered_boxes.is_empty() {
297            let mut indices: Vec<usize> = (0..filtered_boxes.len()).collect();
298            indices.sort_by(|&i, &j| {
299                let (col_i, row_i) = filtered_reading_orders[i];
300                let (col_j, row_j) = filtered_reading_orders[j];
301                // Sort by col_index ascending, then row_index ascending
302                // Use total_cmp to handle NaN/infinity values gracefully
303                col_i
304                    .total_cmp(&col_j)
305                    .then_with(|| row_i.total_cmp(&row_j))
306            });
307
308            let sorted_boxes = indices.iter().map(|&i| filtered_boxes[i].clone()).collect();
309            let sorted_classes = indices.iter().map(|&i| filtered_classes[i]).collect();
310            let sorted_scores = indices.iter().map(|&i| filtered_scores[i]).collect();
311
312            (sorted_boxes, sorted_classes, sorted_scores)
313        } else {
314            (filtered_boxes, filtered_classes, filtered_scores)
315        }
316    }
317
318    /// Apply NMS with reading order preservation.
319    fn apply_nms_with_reading_order(
320        &self,
321        boxes: Vec<BoundingBox>,
322        classes: Vec<usize>,
323        scores: Vec<f32>,
324        reading_orders: Vec<(f32, f32)>,
325    ) -> NmsResult {
326        if boxes.is_empty() {
327            return (boxes, classes, scores, reading_orders);
328        }
329
330        let keep = self.compute_nms_keep_indices(&boxes, &classes, &scores);
331
332        let filtered_boxes: Vec<BoundingBox> = keep.iter().map(|&i| boxes[i].clone()).collect();
333        let filtered_classes: Vec<usize> = keep.iter().map(|&i| classes[i]).collect();
334        let filtered_scores: Vec<f32> = keep.iter().map(|&i| scores[i]).collect();
335        let filtered_reading_orders: Vec<(f32, f32)> =
336            keep.iter().map(|&i| reading_orders[i]).collect();
337
338        (
339            filtered_boxes,
340            filtered_classes,
341            filtered_scores,
342            filtered_reading_orders,
343        )
344    }
345
    /// Process standard detection model output.
    ///
    /// Used for any model type without a dedicated decoder; delegates to
    /// `process_picodet`, so the same row layouts, filtering and NMS apply.
    ///
    /// # Returns
    /// Tuple of (bounding_boxes, class_ids, scores) for one image.
    fn process_standard(
        &self,
        predictions: ArrayView3<f32>,
        img_shape: &ImageScaleInfo,
    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
        self.process_picodet(predictions, img_shape)
    }
354
355    fn parse_compact_prediction(&self, row: &[f32]) -> Option<(usize, f32, f32, f32, f32, f32)> {
356        if row.len() < 6 {
357            return None;
358        }
359
360        // Format: [class_id, score, x1, y1, x2, y2]
361        let score_is_valid = if self.model_type == "rtdetr" {
362            row[1].is_finite()
363        } else {
364            Self::is_valid_score(row[1])
365        };
366
367        if score_is_valid && Self::is_valid_class(row[0], self.num_classes) {
368            let class_id = row[0].round() as i32;
369            if class_id >= 0 {
370                let score = self.adjust_score(row[1]);
371                return Some((class_id as usize, score, row[2], row[3], row[4], row[5]));
372            }
373        }
374
375        // Alternate format: [x1, y1, x2, y2, score, class_id]
376        let score_is_valid = if self.model_type == "rtdetr" {
377            row[4].is_finite()
378        } else {
379            Self::is_valid_score(row[4])
380        };
381        if score_is_valid && Self::is_valid_class(row[5], self.num_classes) {
382            let class_id = row[5].round() as i32;
383            if class_id >= 0 {
384                let score = self.adjust_score(row[4]);
385                return Some((class_id as usize, score, row[0], row[1], row[2], row[3]));
386            }
387        }
388
389        // Alternate format: [score, class_id, x1, y1, x2, y2]
390        let score_is_valid = if self.model_type == "rtdetr" {
391            row[0].is_finite()
392        } else {
393            Self::is_valid_score(row[0])
394        };
395        if score_is_valid && Self::is_valid_class(row[1], self.num_classes) {
396            let class_id = row[1].round() as i32;
397            if class_id >= 0 {
398                let score = self.adjust_score(row[0]);
399                return Some((class_id as usize, score, row[2], row[3], row[4], row[5]));
400            }
401        }
402
403        None
404    }
405
406    fn convert_bbox_coords(
407        &self,
408        x1: f32,
409        y1: f32,
410        x2: f32,
411        y2: f32,
412        orig_width: f32,
413        orig_height: f32,
414    ) -> (f32, f32, f32, f32) {
415        let normalized = x2 <= 1.05
416            && y2 <= 1.05
417            && x1 >= -0.05
418            && y1 >= -0.05
419            && orig_width > 0.0
420            && orig_height > 0.0;
421
422        if normalized {
423            (
424                x1.clamp(0.0, 1.0) * orig_width,
425                y1.clamp(0.0, 1.0) * orig_height,
426                x2.clamp(0.0, 1.0) * orig_width,
427                y2.clamp(0.0, 1.0) * orig_height,
428            )
429        } else {
430            (
431                x1.clamp(0.0, orig_width),
432                y1.clamp(0.0, orig_height),
433                x2.clamp(0.0, orig_width),
434                y2.clamp(0.0, orig_height),
435            )
436        }
437    }
438
439    fn is_valid_box(x1: f32, y1: f32, x2: f32, y2: f32) -> bool {
440        x2 > x1 && y2 > y1 && x1.is_finite() && y1.is_finite() && x2.is_finite() && y2.is_finite()
441    }
442
443    fn is_valid_score(score: f32) -> bool {
444        score.is_finite() && (0.0..=1.0 + f32::EPSILON).contains(&score)
445    }
446
447    fn is_valid_class(raw: f32, num_classes: usize) -> bool {
448        if !raw.is_finite() {
449            return false;
450        }
451        let class_id = raw.round() as i32;
452        class_id >= 0 && (class_id as usize) < num_classes + 5
453    }
454
455    fn adjust_score(&self, raw_score: f32) -> f32 {
456        if self.model_type == "rtdetr" {
457            raw_score.clamp(0.0, 1.0)
458        } else {
459            raw_score
460        }
461    }
462
463    /// Compute indices to keep after NMS.
464    /// Returns the indices of boxes that survive non-maximum suppression.
465    fn compute_nms_keep_indices(
466        &self,
467        boxes: &[BoundingBox],
468        classes: &[usize],
469        scores: &[f32],
470    ) -> Vec<usize> {
471        // Sort by score in descending order
472        let mut indices: Vec<usize> = (0..boxes.len()).collect();
473        indices.sort_by(|&a, &b| {
474            scores[b]
475                .partial_cmp(&scores[a])
476                .unwrap_or(std::cmp::Ordering::Equal)
477        });
478
479        let mut keep = Vec::new();
480        let mut suppressed = vec![false; boxes.len()];
481
482        for &i in &indices {
483            if suppressed[i] {
484                continue;
485            }
486
487            keep.push(i);
488            if keep.len() >= self.max_detections {
489                break;
490            }
491
492            // Suppress boxes with high IoU
493            for &j in &indices {
494                if i != j && !suppressed[j] && classes[i] == classes[j] {
495                    let iou = self.calculate_iou(&boxes[i], &boxes[j]);
496                    if iou > self.nms_threshold {
497                        suppressed[j] = true;
498                    }
499                }
500            }
501        }
502
503        keep
504    }
505
506    /// Apply Non-Maximum Suppression to filter overlapping boxes.
507    fn apply_nms(
508        &self,
509        boxes: Vec<BoundingBox>,
510        classes: Vec<usize>,
511        scores: Vec<f32>,
512    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
513        if boxes.is_empty() {
514            return (boxes, classes, scores);
515        }
516
517        let keep = self.compute_nms_keep_indices(&boxes, &classes, &scores);
518
519        let filtered_boxes: Vec<BoundingBox> = keep.iter().map(|&i| boxes[i].clone()).collect();
520        let filtered_classes: Vec<usize> = keep.iter().map(|&i| classes[i]).collect();
521        let filtered_scores: Vec<f32> = keep.iter().map(|&i| scores[i]).collect();
522
523        (filtered_boxes, filtered_classes, filtered_scores)
524    }
525
526    /// Calculate Intersection over Union between two bounding boxes.
527    fn calculate_iou(&self, box1: &BoundingBox, box2: &BoundingBox) -> f32 {
528        // Get bounding rectangle for box1
529        let (x1_min, y1_min, x1_max, y1_max) = self.get_bbox_bounds(box1);
530
531        // Get bounding rectangle for box2
532        let (x2_min, y2_min, x2_max, y2_max) = self.get_bbox_bounds(box2);
533
534        // Calculate intersection
535        let x_min = x1_min.max(x2_min);
536        let y_min = y1_min.max(y2_min);
537        let x_max = x1_max.min(x2_max);
538        let y_max = y1_max.min(y2_max);
539
540        if x_max <= x_min || y_max <= y_min {
541            return 0.0;
542        }
543
544        let intersection = (x_max - x_min) * (y_max - y_min);
545        let area1 = (x1_max - x1_min) * (y1_max - y1_min);
546        let area2 = (x2_max - x2_min) * (y2_max - y2_min);
547        let union = area1 + area2 - intersection;
548
549        if union > 0.0 {
550            intersection / union
551        } else {
552            0.0
553        }
554    }
555
556    /// Get the minimum and maximum coordinates from a bounding box.
557    fn get_bbox_bounds(&self, bbox: &BoundingBox) -> (f32, f32, f32, f32) {
558        if bbox.points.is_empty() {
559            return (0.0, 0.0, 0.0, 0.0);
560        }
561
562        let mut x_min = f32::INFINITY;
563        let mut y_min = f32::INFINITY;
564        let mut x_max = f32::NEG_INFINITY;
565        let mut y_max = f32::NEG_INFINITY;
566
567        for point in &bbox.points {
568            x_min = x_min.min(point.x);
569            y_min = y_min.min(point.y);
570            x_max = x_max.max(point.x);
571            y_max = y_max.max(point.y);
572        }
573
574        (x_min, y_min, x_max, y_max)
575    }
576}
577
578/// Apply unclip ratio to expand/shrink bounding boxes while keeping center fixed.
579///
580/// This follows PP-StructureV3's `layout_unclip_ratio` parameter behavior.
581///
582/// # Arguments
583/// * `boxes` - Input bounding boxes
584/// * `classes` - Class IDs for each box
585/// * `width_ratio` - Ratio to apply to box width (1.0 = no change)
586/// * `height_ratio` - Ratio to apply to box height (1.0 = no change)
587/// * `per_class_ratios` - Optional per-class ratios: class_id -> (width_ratio, height_ratio)
588///
589/// # Returns
590/// Transformed bounding boxes with same center but scaled dimensions
591pub fn unclip_boxes(
592    boxes: &[BoundingBox],
593    classes: &[usize],
594    width_ratio: f32,
595    height_ratio: f32,
596    per_class_ratios: Option<&std::collections::HashMap<usize, (f32, f32)>>,
597) -> Vec<BoundingBox> {
598    boxes
599        .iter()
600        .zip(classes.iter())
601        .map(|(bbox, &class_id)| {
602            // Get ratio for this class
603            let (w_ratio, h_ratio) = per_class_ratios
604                .and_then(|ratios| ratios.get(&class_id).copied())
605                .unwrap_or((width_ratio, height_ratio));
606
607            // Skip if ratios are 1.0 (no change)
608            if (w_ratio - 1.0).abs() < 1e-6 && (h_ratio - 1.0).abs() < 1e-6 {
609                return bbox.clone();
610            }
611
612            // Get current bounds
613            let x_min = bbox.x_min();
614            let y_min = bbox.y_min();
615            let x_max = bbox.x_max();
616            let y_max = bbox.y_max();
617
618            // Calculate center and dimensions
619            let width = x_max - x_min;
620            let height = y_max - y_min;
621            let center_x = x_min + width / 2.0;
622            let center_y = y_min + height / 2.0;
623
624            // Apply ratio
625            let new_width = width * w_ratio;
626            let new_height = height * h_ratio;
627
628            // Calculate new bounds
629            let new_x_min = center_x - new_width / 2.0;
630            let new_y_min = center_y - new_height / 2.0;
631            let new_x_max = center_x + new_width / 2.0;
632            let new_y_max = center_y + new_height / 2.0;
633
634            BoundingBox::from_coords(new_x_min, new_y_min, new_x_max, new_y_max)
635        })
636        .collect()
637}
638
639/// Merge two bounding boxes according to the specified mode.
640///
641/// # Arguments
642/// * `box1` - First bounding box
643/// * `box2` - Second bounding box
644/// * `mode` - Merge mode to apply
645///
646/// # Returns
647/// Merged bounding box according to the mode
648pub fn merge_boxes(box1: &BoundingBox, box2: &BoundingBox, mode: MergeBboxMode) -> BoundingBox {
649    let (x1_min, y1_min, x1_max, y1_max) = (box1.x_min(), box1.y_min(), box1.x_max(), box1.y_max());
650    let (x2_min, y2_min, x2_max, y2_max) = (box2.x_min(), box2.y_min(), box2.x_max(), box2.y_max());
651
652    let area1 = (x1_max - x1_min) * (y1_max - y1_min);
653    let area2 = (x2_max - x2_min) * (y2_max - y2_min);
654
655    match mode {
656        MergeBboxMode::Large => {
657            // Keep the larger bounding box
658            if area1 >= area2 {
659                box1.clone()
660            } else {
661                box2.clone()
662            }
663        }
664        MergeBboxMode::Small => {
665            // Keep the smaller bounding box
666            if area1 <= area2 {
667                box1.clone()
668            } else {
669                box2.clone()
670            }
671        }
672        MergeBboxMode::Union => {
673            // Merge to union of bounding boxes
674            let union_x_min = x1_min.min(x2_min);
675            let union_y_min = y1_min.min(y2_min);
676            let union_x_max = x1_max.max(x2_max);
677            let union_y_max = y1_max.max(y2_max);
678            BoundingBox::from_coords(union_x_min, union_y_min, union_x_max, union_y_max)
679        }
680    }
681}
682
683/// Apply Non-Maximum Suppression with per-class merge modes.
684///
685/// Unlike standard NMS which simply suppresses (discards) overlapping boxes,
686/// this function can merge overlapping boxes according to the specified mode.
687///
688/// # Arguments
689/// * `boxes` - Input bounding boxes
690/// * `classes` - Class IDs for each box
691/// * `scores` - Confidence scores for each box
692/// * `class_labels` - Mapping from class ID to label string
693/// * `class_merge_modes` - Per-class merge modes (label -> mode)
694/// * `nms_threshold` - IoU threshold for overlap detection
695/// * `max_detections` - Maximum number of detections to return
696///
697/// # Returns
698/// Tuple of (filtered_boxes, filtered_classes, filtered_scores)
699pub fn apply_nms_with_merge(
700    boxes: Vec<BoundingBox>,
701    classes: Vec<usize>,
702    scores: Vec<f32>,
703    class_labels: &HashMap<usize, String>,
704    class_merge_modes: &HashMap<String, MergeBboxMode>,
705    nms_threshold: f32,
706    max_detections: usize,
707) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
708    if boxes.is_empty() {
709        return (boxes, classes, scores);
710    }
711
712    // Sort by score in descending order
713    let mut indices: Vec<usize> = (0..boxes.len()).collect();
714    indices.sort_by(|&a, &b| {
715        scores[b]
716            .partial_cmp(&scores[a])
717            .unwrap_or(std::cmp::Ordering::Equal)
718    });
719
720    let mut result_boxes = Vec::new();
721    let mut result_classes = Vec::new();
722    let mut result_scores = Vec::new();
723    let mut result_order_indices = Vec::new();
724    let mut processed = vec![false; boxes.len()];
725
726    for &i in &indices {
727        if processed[i] {
728            continue;
729        }
730
731        processed[i] = true;
732
733        // Get merge mode for this class
734        let class_label = class_labels
735            .get(&classes[i])
736            .map(|s| s.as_str())
737            .unwrap_or("unknown");
738        let merge_mode = class_merge_modes
739            .get(class_label)
740            .copied()
741            .unwrap_or(MergeBboxMode::Large);
742
743        let mut merged_box = boxes[i].clone();
744        let mut best_score = scores[i];
745        let mut order_idx = i;
746
747        // Find overlapping boxes of the same class and merge them
748        for &j in &indices {
749            if i != j && !processed[j] && classes[i] == classes[j] {
750                let iou = calculate_iou_static(&merged_box, &boxes[j]);
751                if iou > nms_threshold {
752                    // Merge the boxes
753                    merged_box = merge_boxes(&merged_box, &boxes[j], merge_mode);
754                    best_score = best_score.max(scores[j]);
755                    order_idx = order_idx.min(j);
756                    processed[j] = true;
757                }
758            }
759        }
760
761        result_boxes.push(merged_box);
762        result_classes.push(classes[i]);
763        result_scores.push(best_score);
764        result_order_indices.push(order_idx);
765    }
766
767    // First, apply max_detections limit based on score (NMS already processed in score order,
768    // so result_* vectors are implicitly score-ordered). This ensures we keep the highest-scoring
769    // detections rather than earliest ones.
770    let take_count = max_detections.min(result_boxes.len());
771
772    // Preserve input ordering for downstream consumers (e.g., PP-DocLayoutV2 reading-order output).
773    // We keep the score-based selection above, but sort the top-N merged results by the earliest
774    // original index in each merged group.
775    let mut merged: Vec<(usize, BoundingBox, usize, f32)> = result_order_indices
776        .into_iter()
777        .zip(result_boxes)
778        .zip(result_classes)
779        .zip(result_scores)
780        .map(|(((order, bbox), class_id), score)| (order, bbox, class_id, score))
781        .take(take_count) // Apply max_detections limit BEFORE reordering
782        .collect();
783
784    merged.sort_by(|(a, _, _, _), (b, _, _, _)| a.cmp(b));
785
786    let mut final_boxes = Vec::new();
787    let mut final_classes = Vec::new();
788    let mut final_scores = Vec::new();
789
790    for (_, bbox, class_id, score) in merged {
791        final_boxes.push(bbox);
792        final_classes.push(class_id);
793        final_scores.push(score);
794    }
795
796    (final_boxes, final_classes, final_scores)
797}
798
799/// Calculate IoU between two bounding boxes (standalone function).
800fn calculate_iou_static(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
801    let (x1_min, y1_min, x1_max, y1_max) = (box1.x_min(), box1.y_min(), box1.x_max(), box1.y_max());
802    let (x2_min, y2_min, x2_max, y2_max) = (box2.x_min(), box2.y_min(), box2.x_max(), box2.y_max());
803
804    // Calculate intersection
805    let x_min = x1_min.max(x2_min);
806    let y_min = y1_min.max(y2_min);
807    let x_max = x1_max.min(x2_max);
808    let y_max = y1_max.min(y2_max);
809
810    if x_max <= x_min || y_max <= y_min {
811        return 0.0;
812    }
813
814    let intersection = (x_max - x_min) * (y_max - y_min);
815    let area1 = (x1_max - x1_min) * (y1_max - y1_min);
816    let area2 = (x2_max - x2_min) * (y2_max - y2_min);
817    let union = area1 + area2 - intersection;
818
819    if union > 0.0 {
820        intersection / union
821    } else {
822        0.0
823    }
824}
825
826impl Default for LayoutPostProcess {
827    fn default() -> Self {
828        Self {
829            num_classes: 5, // Default for basic layout detection
830            score_threshold: 0.5,
831            nms_threshold: 0.5,
832            max_detections: 100,
833            model_type: "picodet".to_string(),
834        }
835    }
836}
837
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an axis-aligned square whose top-left corner is `(origin, origin)`.
    fn square(origin: f32, side: f32) -> BoundingBox {
        let far = origin + side;
        BoundingBox::new(vec![
            Point::new(origin, origin),
            Point::new(far, origin),
            Point::new(far, far),
            Point::new(origin, far),
        ])
    }

    #[test]
    fn test_layout_postprocess_creation() {
        let processor = LayoutPostProcess::default();
        assert_eq!(processor.num_classes, 5);
        assert_eq!(processor.score_threshold, 0.5);
    }

    #[test]
    fn test_iou_calculation() {
        let processor = LayoutPostProcess::default();

        // A box compared with its own copy overlaps completely.
        let base = square(0.0, 100.0);
        let duplicate = base.clone();
        assert_eq!(processor.calculate_iou(&base, &duplicate), 1.0);

        // A box far away from the first one shares no area with it.
        let distant = square(200.0, 100.0);
        assert_eq!(processor.calculate_iou(&base, &distant), 0.0);
    }
}