oar_ocr_core/processors/layout_postprocess.rs

1//! Layout Detection Post-processing
2//!
3//! This module implements post-processing for layout detection models including
4//! PicoDet, RT-DETR, and PP-DocLayout series models.
5
6use crate::core::Tensor4D;
7use crate::domain::tasks::MergeBboxMode;
8use crate::processors::{BoundingBox, ImageScaleInfo, Point};
9use ndarray::{ArrayView3, Axis};
10use std::borrow::Cow;
11use std::collections::HashMap;
12
/// Batch result of [`LayoutPostProcess::apply`]: `(boxes, class_ids, scores)`, where each
/// outer vector holds one inner vector per processed image, in batch order.
type LayoutPostprocessOutput = (Vec<Vec<BoundingBox>>, Vec<Vec<usize>>, Vec<Vec<f32>>);
/// NMS output that also carries each kept box's `(col_index, row_index)` reading-order key,
/// index-aligned with the boxes, classes and scores.
type NmsResult = (Vec<BoundingBox>, Vec<usize>, Vec<f32>, Vec<(f32, f32)>);
15
/// Layout detection post-processor for models like PicoDet and RT-DETR.
///
/// This processor converts model predictions into bounding boxes with class labels
/// and confidence scores for document layout elements. `model_type` selects which
/// decoder [`LayoutPostProcess::apply`] uses for each image in the batch.
#[derive(Debug, Clone)]
pub struct LayoutPostProcess {
    /// Number of classes the model predicts; decoded class ids at or above this are dropped
    num_classes: usize,
    /// Score threshold for filtering predictions (predictions scoring below it are dropped)
    score_threshold: f32,
    /// Non-maximum suppression threshold: a same-class box whose IoU with a kept box
    /// exceeds this value is suppressed
    nms_threshold: f32,
    /// Maximum number of detections to return per image after NMS
    max_detections: usize,
    /// Model type ("picodet", "rtdetr", "pp-doclayout"); any other value falls back to
    /// the standard (PicoDet-style) decoder
    model_type: String,
}
33
34impl LayoutPostProcess {
35    /// Creates a new layout detection post-processor.
36    pub fn new(
37        num_classes: usize,
38        score_threshold: f32,
39        nms_threshold: f32,
40        max_detections: usize,
41        model_type: String,
42    ) -> Self {
43        Self {
44            num_classes,
45            score_threshold,
46            nms_threshold,
47            max_detections,
48            model_type,
49        }
50    }
51
52    /// Applies post-processing to layout detection model predictions.
53    ///
54    /// # Arguments
55    /// * `predictions` - Model output tensor [batch, num_boxes, 4 + num_classes]
56    /// * `img_shapes` - Original image dimensions for each image in batch
57    ///
58    /// # Returns
59    /// Tuple of (bounding_boxes, class_ids, scores) for each image in batch
60    pub fn apply(
61        &self,
62        predictions: &Tensor4D,
63        img_shapes: Vec<ImageScaleInfo>,
64    ) -> LayoutPostprocessOutput {
65        let batch_size = predictions.shape()[0];
66        let mut all_boxes = Vec::with_capacity(batch_size);
67        let mut all_classes = Vec::with_capacity(batch_size);
68        let mut all_scores = Vec::with_capacity(batch_size);
69
70        // Process each image in batch
71        for (batch_idx, img_shape) in img_shapes.into_iter().enumerate().take(batch_size) {
72            let pred = predictions.index_axis(Axis(0), batch_idx);
73
74            let (boxes, classes, scores) = match self.model_type.as_str() {
75                "picodet" => self.process_picodet(pred, &img_shape),
76                "rtdetr" => self.process_rtdetr(pred, &img_shape),
77                "pp-doclayout" => self.process_pp_doclayout(pred, &img_shape),
78                _ => self.process_standard(pred, &img_shape),
79            };
80
81            all_boxes.push(boxes);
82            all_classes.push(classes);
83            all_scores.push(scores);
84        }
85
86        (all_boxes, all_classes, all_scores)
87    }
88
89    /// Process PicoDet model output.
90    fn process_picodet(
91        &self,
92        predictions: ArrayView3<f32>,
93        img_shape: &ImageScaleInfo,
94    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
95        let mut boxes = Vec::new();
96        let mut classes = Vec::new();
97        let mut scores = Vec::new();
98
99        let orig_width = img_shape.src_w;
100        let orig_height = img_shape.src_h;
101        let shape = predictions.shape();
102        if shape.len() != 3 || shape[2] == 0 {
103            return (boxes, classes, scores);
104        }
105
106        let total_boxes = shape[0] * shape[1];
107        if total_boxes == 0 {
108            return (boxes, classes, scores);
109        }
110
111        let feature_dim = shape[2];
112        let data: Cow<'_, [f32]> = match predictions.as_slice() {
113            Some(slice) => Cow::Borrowed(slice),
114            None => {
115                let (mut vec, offset) = predictions.to_owned().into_raw_vec_and_offset();
116                if let Some(offset) = offset
117                    && offset != 0
118                {
119                    vec.drain(0..offset);
120                }
121                Cow::Owned(vec)
122            }
123        };
124
125        for box_idx in 0..total_boxes {
126            let start = box_idx * feature_dim;
127            let end = start + feature_dim;
128
129            if end > data.len() {
130                break;
131            }
132
133            let row = &data[start..end];
134            if feature_dim == 4 + self.num_classes {
135                // Format: [x1, y1, x2, y2, scores...]
136                let (max_class, max_score) = row[4..].iter().enumerate().fold(
137                    (0usize, 0.0f32),
138                    |(best_cls, best_score), (cls_idx, &score)| {
139                        if score > best_score {
140                            (cls_idx, score)
141                        } else {
142                            (best_cls, best_score)
143                        }
144                    },
145                );
146
147                if max_score < self.score_threshold {
148                    continue;
149                }
150
151                let (sx1, sy1, sx2, sy2) = self.convert_bbox_coords(
152                    row[0],
153                    row[1],
154                    row[2],
155                    row[3],
156                    orig_width,
157                    orig_height,
158                );
159
160                if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
161                    continue;
162                }
163
164                let bbox = BoundingBox::new(vec![
165                    Point::new(sx1, sy1),
166                    Point::new(sx2, sy1),
167                    Point::new(sx2, sy2),
168                    Point::new(sx1, sy2),
169                ]);
170
171                boxes.push(bbox);
172                classes.push(max_class);
173                scores.push(max_score);
174            } else if feature_dim >= 6
175                && let Some((class_id, score, x1, y1, x2, y2)) = self.parse_compact_prediction(row)
176            {
177                if score < self.score_threshold || class_id >= self.num_classes {
178                    continue;
179                }
180
181                let (sx1, sy1, sx2, sy2) =
182                    self.convert_bbox_coords(x1, y1, x2, y2, orig_width, orig_height);
183
184                if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
185                    continue;
186                }
187
188                let bbox = BoundingBox::new(vec![
189                    Point::new(sx1, sy1),
190                    Point::new(sx2, sy1),
191                    Point::new(sx2, sy2),
192                    Point::new(sx1, sy2),
193                ]);
194
195                boxes.push(bbox);
196                classes.push(class_id);
197                scores.push(score);
198            }
199        }
200
201        self.apply_nms(boxes, classes, scores)
202    }
203
204    /// Process RT-DETR model output.
205    fn process_rtdetr(
206        &self,
207        predictions: ArrayView3<f32>,
208        img_shape: &ImageScaleInfo,
209    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
210        // RT-DETR has similar output format to PicoDet
211        self.process_picodet(predictions, img_shape)
212    }
213
214    /// Process PP-DocLayout model output.
215    ///
216    /// Handles both 6-dim format (PP-DocLayout) and 8-dim format (PP-DocLayoutV2).
217    /// - 6-dim: [class_id, score, x1, y1, x2, y2]
218    /// - 8-dim: [class_id, score, x1, y1, x2, y2, col_index, row_index]
219    ///
220    /// For 8-dim format, boxes are sorted by reading order (col_index ascending, row_index ascending)
221    /// after NMS filtering.
222    fn process_pp_doclayout(
223        &self,
224        predictions: ArrayView3<f32>,
225        img_shape: &ImageScaleInfo,
226    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
227        // PP-DocLayout outputs in [num_boxes, 1, N] format
228        // where N is 6 or 8 depending on model version
229        let shape = predictions.shape();
230        let num_boxes = shape[0];
231        let feature_dim = shape[2];
232
233        let mut boxes = Vec::new();
234        let mut classes = Vec::new();
235        let mut scores = Vec::new();
236        let mut reading_orders: Vec<(f32, f32)> = Vec::new();
237
238        let orig_width = img_shape.src_w;
239        let orig_height = img_shape.src_h;
240
241        let has_reading_order = feature_dim == 8;
242
243        // Extract predictions
244        for box_idx in 0..num_boxes {
245            // predictions is [num_boxes, 1, N], so we use 3D indexing [box_idx, 0, i]
246            let class_id = predictions[[box_idx, 0, 0]] as i32;
247            let score = predictions[[box_idx, 0, 1]];
248            let x1 = predictions[[box_idx, 0, 2]];
249            let y1 = predictions[[box_idx, 0, 3]];
250            let x2 = predictions[[box_idx, 0, 4]];
251            let y2 = predictions[[box_idx, 0, 5]];
252
253            // Extract reading order info if available (8-dim format)
254            // Default to (0, box_idx) for 6-dim format to maintain original order
255            let reading_order = if has_reading_order {
256                (predictions[[box_idx, 0, 6]], predictions[[box_idx, 0, 7]])
257            } else {
258                (0.0, box_idx as f32)
259            };
260
261            // Filter by threshold and valid class
262            if score < self.score_threshold
263                || class_id < 0
264                || (class_id as usize) >= self.num_classes
265            {
266                continue;
267            }
268
269            // PP-DocLayout-style models may emit either absolute pixel coords or normalized coords.
270            // Use the same normalization heuristic as other detectors for robustness.
271            let (sx1, sy1, sx2, sy2) =
272                self.convert_bbox_coords(x1, y1, x2, y2, orig_width, orig_height);
273            if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
274                continue;
275            }
276
277            let bbox = BoundingBox::new(vec![
278                Point::new(sx1, sy1),
279                Point::new(sx2, sy1),
280                Point::new(sx2, sy2),
281                Point::new(sx1, sy2),
282            ]);
283
284            boxes.push(bbox);
285            classes.push(class_id as usize);
286            scores.push(score);
287            reading_orders.push(reading_order);
288        }
289
290        // Apply NMS with reading order preservation
291        let (filtered_boxes, filtered_classes, filtered_scores, filtered_reading_orders) =
292            self.apply_nms_with_reading_order(boxes, classes, scores, reading_orders);
293
294        // Sort by reading order if we have 8-dim format
295        if has_reading_order && !filtered_boxes.is_empty() {
296            let mut indices: Vec<usize> = (0..filtered_boxes.len()).collect();
297            indices.sort_by(|&i, &j| {
298                let (col_i, row_i) = filtered_reading_orders[i];
299                let (col_j, row_j) = filtered_reading_orders[j];
300                // Sort by col_index ascending, then row_index ascending
301                // Use total_cmp to handle NaN/infinity values gracefully
302                col_i
303                    .total_cmp(&col_j)
304                    .then_with(|| row_i.total_cmp(&row_j))
305            });
306
307            let sorted_boxes = indices.iter().map(|&i| filtered_boxes[i].clone()).collect();
308            let sorted_classes = indices.iter().map(|&i| filtered_classes[i]).collect();
309            let sorted_scores = indices.iter().map(|&i| filtered_scores[i]).collect();
310
311            (sorted_boxes, sorted_classes, sorted_scores)
312        } else {
313            (filtered_boxes, filtered_classes, filtered_scores)
314        }
315    }
316
317    /// Apply NMS with reading order preservation.
318    fn apply_nms_with_reading_order(
319        &self,
320        boxes: Vec<BoundingBox>,
321        classes: Vec<usize>,
322        scores: Vec<f32>,
323        reading_orders: Vec<(f32, f32)>,
324    ) -> NmsResult {
325        if boxes.is_empty() {
326            return (boxes, classes, scores, reading_orders);
327        }
328
329        let keep = self.compute_nms_keep_indices(&boxes, &classes, &scores);
330
331        let filtered_boxes: Vec<BoundingBox> = keep.iter().map(|&i| boxes[i].clone()).collect();
332        let filtered_classes: Vec<usize> = keep.iter().map(|&i| classes[i]).collect();
333        let filtered_scores: Vec<f32> = keep.iter().map(|&i| scores[i]).collect();
334        let filtered_reading_orders: Vec<(f32, f32)> =
335            keep.iter().map(|&i| reading_orders[i]).collect();
336
337        (
338            filtered_boxes,
339            filtered_classes,
340            filtered_scores,
341            filtered_reading_orders,
342        )
343    }
344
345    /// Process standard detection model output.
346    fn process_standard(
347        &self,
348        predictions: ArrayView3<f32>,
349        img_shape: &ImageScaleInfo,
350    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
351        self.process_picodet(predictions, img_shape)
352    }
353
354    fn parse_compact_prediction(&self, row: &[f32]) -> Option<(usize, f32, f32, f32, f32, f32)> {
355        if row.len() < 6 {
356            return None;
357        }
358
359        // Format: [class_id, score, x1, y1, x2, y2]
360        let score_is_valid = if self.model_type == "rtdetr" {
361            row[1].is_finite()
362        } else {
363            Self::is_valid_score(row[1])
364        };
365
366        if score_is_valid && Self::is_valid_class(row[0], self.num_classes) {
367            let class_id = row[0].round() as i32;
368            if class_id >= 0 {
369                let score = self.adjust_score(row[1]);
370                return Some((class_id as usize, score, row[2], row[3], row[4], row[5]));
371            }
372        }
373
374        // Alternate format: [x1, y1, x2, y2, score, class_id]
375        let score_is_valid = if self.model_type == "rtdetr" {
376            row[4].is_finite()
377        } else {
378            Self::is_valid_score(row[4])
379        };
380        if score_is_valid && Self::is_valid_class(row[5], self.num_classes) {
381            let class_id = row[5].round() as i32;
382            if class_id >= 0 {
383                let score = self.adjust_score(row[4]);
384                return Some((class_id as usize, score, row[0], row[1], row[2], row[3]));
385            }
386        }
387
388        // Alternate format: [score, class_id, x1, y1, x2, y2]
389        let score_is_valid = if self.model_type == "rtdetr" {
390            row[0].is_finite()
391        } else {
392            Self::is_valid_score(row[0])
393        };
394        if score_is_valid && Self::is_valid_class(row[1], self.num_classes) {
395            let class_id = row[1].round() as i32;
396            if class_id >= 0 {
397                let score = self.adjust_score(row[0]);
398                return Some((class_id as usize, score, row[2], row[3], row[4], row[5]));
399            }
400        }
401
402        None
403    }
404
405    fn convert_bbox_coords(
406        &self,
407        x1: f32,
408        y1: f32,
409        x2: f32,
410        y2: f32,
411        orig_width: f32,
412        orig_height: f32,
413    ) -> (f32, f32, f32, f32) {
414        let normalized = x2 <= 1.05
415            && y2 <= 1.05
416            && x1 >= -0.05
417            && y1 >= -0.05
418            && orig_width > 0.0
419            && orig_height > 0.0;
420
421        if normalized {
422            (
423                x1.clamp(0.0, 1.0) * orig_width,
424                y1.clamp(0.0, 1.0) * orig_height,
425                x2.clamp(0.0, 1.0) * orig_width,
426                y2.clamp(0.0, 1.0) * orig_height,
427            )
428        } else {
429            (
430                x1.clamp(0.0, orig_width),
431                y1.clamp(0.0, orig_height),
432                x2.clamp(0.0, orig_width),
433                y2.clamp(0.0, orig_height),
434            )
435        }
436    }
437
438    fn is_valid_box(x1: f32, y1: f32, x2: f32, y2: f32) -> bool {
439        x2 > x1 && y2 > y1 && x1.is_finite() && y1.is_finite() && x2.is_finite() && y2.is_finite()
440    }
441
442    fn is_valid_score(score: f32) -> bool {
443        score.is_finite() && (0.0..=1.0 + f32::EPSILON).contains(&score)
444    }
445
446    fn is_valid_class(raw: f32, num_classes: usize) -> bool {
447        if !raw.is_finite() {
448            return false;
449        }
450        let class_id = raw.round() as i32;
451        class_id >= 0 && (class_id as usize) < num_classes + 5
452    }
453
454    fn adjust_score(&self, raw_score: f32) -> f32 {
455        if self.model_type == "rtdetr" {
456            raw_score.clamp(0.0, 1.0)
457        } else {
458            raw_score
459        }
460    }
461
462    /// Compute indices to keep after NMS.
463    /// Returns the indices of boxes that survive non-maximum suppression.
464    fn compute_nms_keep_indices(
465        &self,
466        boxes: &[BoundingBox],
467        classes: &[usize],
468        scores: &[f32],
469    ) -> Vec<usize> {
470        // Sort by score in descending order
471        let mut indices: Vec<usize> = (0..boxes.len()).collect();
472        indices.sort_by(|&a, &b| {
473            scores[b]
474                .partial_cmp(&scores[a])
475                .unwrap_or(std::cmp::Ordering::Equal)
476        });
477
478        let mut keep = Vec::new();
479        let mut suppressed = vec![false; boxes.len()];
480
481        for &i in &indices {
482            if suppressed[i] {
483                continue;
484            }
485
486            keep.push(i);
487            if keep.len() >= self.max_detections {
488                break;
489            }
490
491            // Suppress boxes with high IoU
492            for &j in &indices {
493                if i != j && !suppressed[j] && classes[i] == classes[j] {
494                    let iou = self.calculate_iou(&boxes[i], &boxes[j]);
495                    if iou > self.nms_threshold {
496                        suppressed[j] = true;
497                    }
498                }
499            }
500        }
501
502        keep
503    }
504
505    /// Apply Non-Maximum Suppression to filter overlapping boxes.
506    fn apply_nms(
507        &self,
508        boxes: Vec<BoundingBox>,
509        classes: Vec<usize>,
510        scores: Vec<f32>,
511    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
512        if boxes.is_empty() {
513            return (boxes, classes, scores);
514        }
515
516        let keep = self.compute_nms_keep_indices(&boxes, &classes, &scores);
517
518        let filtered_boxes: Vec<BoundingBox> = keep.iter().map(|&i| boxes[i].clone()).collect();
519        let filtered_classes: Vec<usize> = keep.iter().map(|&i| classes[i]).collect();
520        let filtered_scores: Vec<f32> = keep.iter().map(|&i| scores[i]).collect();
521
522        (filtered_boxes, filtered_classes, filtered_scores)
523    }
524
525    /// Calculate Intersection over Union between two bounding boxes.
526    fn calculate_iou(&self, box1: &BoundingBox, box2: &BoundingBox) -> f32 {
527        // Get bounding rectangle for box1
528        let (x1_min, y1_min, x1_max, y1_max) = self.get_bbox_bounds(box1);
529
530        // Get bounding rectangle for box2
531        let (x2_min, y2_min, x2_max, y2_max) = self.get_bbox_bounds(box2);
532
533        // Calculate intersection
534        let x_min = x1_min.max(x2_min);
535        let y_min = y1_min.max(y2_min);
536        let x_max = x1_max.min(x2_max);
537        let y_max = y1_max.min(y2_max);
538
539        if x_max <= x_min || y_max <= y_min {
540            return 0.0;
541        }
542
543        let intersection = (x_max - x_min) * (y_max - y_min);
544        let area1 = (x1_max - x1_min) * (y1_max - y1_min);
545        let area2 = (x2_max - x2_min) * (y2_max - y2_min);
546        let union = area1 + area2 - intersection;
547
548        if union > 0.0 {
549            intersection / union
550        } else {
551            0.0
552        }
553    }
554
555    /// Get the minimum and maximum coordinates from a bounding box.
556    fn get_bbox_bounds(&self, bbox: &BoundingBox) -> (f32, f32, f32, f32) {
557        if bbox.points.is_empty() {
558            return (0.0, 0.0, 0.0, 0.0);
559        }
560
561        let mut x_min = f32::INFINITY;
562        let mut y_min = f32::INFINITY;
563        let mut x_max = f32::NEG_INFINITY;
564        let mut y_max = f32::NEG_INFINITY;
565
566        for point in &bbox.points {
567            x_min = x_min.min(point.x);
568            y_min = y_min.min(point.y);
569            x_max = x_max.max(point.x);
570            y_max = y_max.max(point.y);
571        }
572
573        (x_min, y_min, x_max, y_max)
574    }
575}
576
577/// Apply unclip ratio to expand/shrink bounding boxes while keeping center fixed.
578///
579/// This follows PP-StructureV3's `layout_unclip_ratio` parameter behavior.
580///
581/// # Arguments
582/// * `boxes` - Input bounding boxes
583/// * `classes` - Class IDs for each box
584/// * `width_ratio` - Ratio to apply to box width (1.0 = no change)
585/// * `height_ratio` - Ratio to apply to box height (1.0 = no change)
586/// * `per_class_ratios` - Optional per-class ratios: class_id -> (width_ratio, height_ratio)
587///
588/// # Returns
589/// Transformed bounding boxes with same center but scaled dimensions
590pub fn unclip_boxes(
591    boxes: &[BoundingBox],
592    classes: &[usize],
593    width_ratio: f32,
594    height_ratio: f32,
595    per_class_ratios: Option<&std::collections::HashMap<usize, (f32, f32)>>,
596) -> Vec<BoundingBox> {
597    boxes
598        .iter()
599        .zip(classes.iter())
600        .map(|(bbox, &class_id)| {
601            // Get ratio for this class
602            let (w_ratio, h_ratio) = per_class_ratios
603                .and_then(|ratios| ratios.get(&class_id).copied())
604                .unwrap_or((width_ratio, height_ratio));
605
606            // Skip if ratios are 1.0 (no change)
607            if (w_ratio - 1.0).abs() < 1e-6 && (h_ratio - 1.0).abs() < 1e-6 {
608                return bbox.clone();
609            }
610
611            // Get current bounds
612            let x_min = bbox.x_min();
613            let y_min = bbox.y_min();
614            let x_max = bbox.x_max();
615            let y_max = bbox.y_max();
616
617            // Calculate center and dimensions
618            let width = x_max - x_min;
619            let height = y_max - y_min;
620            let center_x = x_min + width / 2.0;
621            let center_y = y_min + height / 2.0;
622
623            // Apply ratio
624            let new_width = width * w_ratio;
625            let new_height = height * h_ratio;
626
627            // Calculate new bounds
628            let new_x_min = center_x - new_width / 2.0;
629            let new_y_min = center_y - new_height / 2.0;
630            let new_x_max = center_x + new_width / 2.0;
631            let new_y_max = center_y + new_height / 2.0;
632
633            BoundingBox::from_coords(new_x_min, new_y_min, new_x_max, new_y_max)
634        })
635        .collect()
636}
637
638/// Merge two bounding boxes according to the specified mode.
639///
640/// # Arguments
641/// * `box1` - First bounding box
642/// * `box2` - Second bounding box
643/// * `mode` - Merge mode to apply
644///
645/// # Returns
646/// Merged bounding box according to the mode
647pub fn merge_boxes(box1: &BoundingBox, box2: &BoundingBox, mode: MergeBboxMode) -> BoundingBox {
648    let (x1_min, y1_min, x1_max, y1_max) = (box1.x_min(), box1.y_min(), box1.x_max(), box1.y_max());
649    let (x2_min, y2_min, x2_max, y2_max) = (box2.x_min(), box2.y_min(), box2.x_max(), box2.y_max());
650
651    let area1 = (x1_max - x1_min) * (y1_max - y1_min);
652    let area2 = (x2_max - x2_min) * (y2_max - y2_min);
653
654    match mode {
655        MergeBboxMode::Large => {
656            // Keep the larger bounding box
657            if area1 >= area2 {
658                box1.clone()
659            } else {
660                box2.clone()
661            }
662        }
663        MergeBboxMode::Small => {
664            // Keep the smaller bounding box
665            if area1 <= area2 {
666                box1.clone()
667            } else {
668                box2.clone()
669            }
670        }
671        MergeBboxMode::Union => {
672            // Merge to union of bounding boxes
673            let union_x_min = x1_min.min(x2_min);
674            let union_y_min = y1_min.min(y2_min);
675            let union_x_max = x1_max.max(x2_max);
676            let union_y_max = y1_max.max(y2_max);
677            BoundingBox::from_coords(union_x_min, union_y_min, union_x_max, union_y_max)
678        }
679    }
680}
681
682/// Apply Non-Maximum Suppression with per-class merge modes.
683///
684/// Unlike standard NMS which simply suppresses (discards) overlapping boxes,
685/// this function can merge overlapping boxes according to the specified mode.
686///
687/// # Arguments
688/// * `boxes` - Input bounding boxes
689/// * `classes` - Class IDs for each box
690/// * `scores` - Confidence scores for each box
691/// * `class_labels` - Mapping from class ID to label string
692/// * `class_merge_modes` - Per-class merge modes (label -> mode)
693/// * `nms_threshold` - IoU threshold for overlap detection
694/// * `max_detections` - Maximum number of detections to return
695///
696/// # Returns
697/// Tuple of (filtered_boxes, filtered_classes, filtered_scores)
698pub fn apply_nms_with_merge(
699    boxes: Vec<BoundingBox>,
700    classes: Vec<usize>,
701    scores: Vec<f32>,
702    class_labels: &HashMap<usize, String>,
703    class_merge_modes: &HashMap<String, MergeBboxMode>,
704    nms_threshold: f32,
705    max_detections: usize,
706) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
707    if boxes.is_empty() {
708        return (boxes, classes, scores);
709    }
710
711    // Sort by score in descending order
712    let mut indices: Vec<usize> = (0..boxes.len()).collect();
713    indices.sort_by(|&a, &b| {
714        scores[b]
715            .partial_cmp(&scores[a])
716            .unwrap_or(std::cmp::Ordering::Equal)
717    });
718
719    let mut result_boxes = Vec::new();
720    let mut result_classes = Vec::new();
721    let mut result_scores = Vec::new();
722    let mut result_order_indices = Vec::new();
723    let mut processed = vec![false; boxes.len()];
724
725    for &i in &indices {
726        if processed[i] {
727            continue;
728        }
729
730        processed[i] = true;
731
732        // Get merge mode for this class
733        let class_label = class_labels
734            .get(&classes[i])
735            .map(|s| s.as_str())
736            .unwrap_or("unknown");
737        let merge_mode = class_merge_modes
738            .get(class_label)
739            .copied()
740            .unwrap_or(MergeBboxMode::Large);
741
742        let mut merged_box = boxes[i].clone();
743        let mut best_score = scores[i];
744        let mut order_idx = i;
745
746        // Find overlapping boxes of the same class and merge them
747        for &j in &indices {
748            if i != j && !processed[j] && classes[i] == classes[j] {
749                let iou = calculate_iou_static(&merged_box, &boxes[j]);
750                if iou > nms_threshold {
751                    // Merge the boxes
752                    merged_box = merge_boxes(&merged_box, &boxes[j], merge_mode);
753                    best_score = best_score.max(scores[j]);
754                    order_idx = order_idx.min(j);
755                    processed[j] = true;
756                }
757            }
758        }
759
760        result_boxes.push(merged_box);
761        result_classes.push(classes[i]);
762        result_scores.push(best_score);
763        result_order_indices.push(order_idx);
764    }
765
766    // First, apply max_detections limit based on score (NMS already processed in score order,
767    // so result_* vectors are implicitly score-ordered). This ensures we keep the highest-scoring
768    // detections rather than earliest ones.
769    let take_count = max_detections.min(result_boxes.len());
770
771    // Preserve input ordering for downstream consumers (e.g., PP-DocLayoutV2 reading-order output).
772    // We keep the score-based selection above, but sort the top-N merged results by the earliest
773    // original index in each merged group.
774    let mut merged: Vec<(usize, BoundingBox, usize, f32)> = result_order_indices
775        .into_iter()
776        .zip(result_boxes)
777        .zip(result_classes)
778        .zip(result_scores)
779        .map(|(((order, bbox), class_id), score)| (order, bbox, class_id, score))
780        .take(take_count) // Apply max_detections limit BEFORE reordering
781        .collect();
782
783    merged.sort_by(|(a, _, _, _), (b, _, _, _)| a.cmp(b));
784
785    let mut final_boxes = Vec::new();
786    let mut final_classes = Vec::new();
787    let mut final_scores = Vec::new();
788
789    for (_, bbox, class_id, score) in merged {
790        final_boxes.push(bbox);
791        final_classes.push(class_id);
792        final_scores.push(score);
793    }
794
795    (final_boxes, final_classes, final_scores)
796}
797
798/// Calculate IoU between two bounding boxes (standalone function).
799fn calculate_iou_static(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
800    let (x1_min, y1_min, x1_max, y1_max) = (box1.x_min(), box1.y_min(), box1.x_max(), box1.y_max());
801    let (x2_min, y2_min, x2_max, y2_max) = (box2.x_min(), box2.y_min(), box2.x_max(), box2.y_max());
802
803    // Calculate intersection
804    let x_min = x1_min.max(x2_min);
805    let y_min = y1_min.max(y2_min);
806    let x_max = x1_max.min(x2_max);
807    let y_max = y1_max.min(y2_max);
808
809    if x_max <= x_min || y_max <= y_min {
810        return 0.0;
811    }
812
813    let intersection = (x_max - x_min) * (y_max - y_min);
814    let area1 = (x1_max - x1_min) * (y1_max - y1_min);
815    let area2 = (x2_max - x2_min) * (y2_max - y2_min);
816    let union = area1 + area2 - intersection;
817
818    if union > 0.0 {
819        intersection / union
820    } else {
821        0.0
822    }
823}
824
825impl Default for LayoutPostProcess {
826    fn default() -> Self {
827        Self {
828            num_classes: 5, // Default for basic layout detection
829            score_threshold: 0.5,
830            nms_threshold: 0.5,
831            max_detections: 100,
832            model_type: "picodet".to_string(),
833        }
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Axis-aligned square with top-left corner `(x, y)`, points listed clockwise.
    fn square(x: f32, y: f32, side: f32) -> BoundingBox {
        BoundingBox::new(vec![
            Point::new(x, y),
            Point::new(x + side, y),
            Point::new(x + side, y + side),
            Point::new(x, y + side),
        ])
    }

    #[test]
    fn test_layout_postprocess_creation() {
        let defaults = LayoutPostProcess::default();
        assert_eq!(defaults.score_threshold, 0.5);
        assert_eq!(defaults.num_classes, 5);
    }

    #[test]
    fn test_iou_calculation() {
        let processor = LayoutPostProcess::default();
        let base = square(0.0, 0.0, 100.0);

        // A box compared with an identical copy overlaps completely.
        let copy = base.clone();
        assert_eq!(processor.calculate_iou(&base, &copy), 1.0);

        // Disjoint boxes share no area.
        let far_away = square(200.0, 200.0, 100.0);
        assert_eq!(processor.calculate_iou(&base, &far_away), 0.0);
    }
}