Skip to main content

oar_ocr_core/processors/
layout_postprocess.rs

1//! Layout Detection Post-processing
2//!
3//! This module implements post-processing for layout detection models including
4//! PicoDet, RT-DETR, and PP-DocLayout series models.
5
6use crate::domain::tasks::MergeBboxMode;
7use crate::processors::{BoundingBox, ImageScaleInfo, Point};
8use ndarray::{ArrayView3, Axis};
9use std::borrow::Cow;
10use std::collections::HashMap;
11
/// Batch output of [`LayoutPostProcess::apply`]: `(boxes, class_ids, scores)`, where each outer
/// `Vec` holds one entry per processed image in batch order.
type LayoutPostprocessOutput = (Vec<Vec<BoundingBox>>, Vec<Vec<usize>>, Vec<Vec<f32>>);
/// NMS survivors as parallel `(boxes, class_ids, scores, reading_orders)` vectors. Each
/// reading-order entry is a `(col_index, row_index)` sort key (for 8-dim PP-DocLayoutV2 output).
type NmsResult = (Vec<BoundingBox>, Vec<usize>, Vec<f32>, Vec<(f32, f32)>);
14
/// Layout detection post-processor for models like PicoDet and RT-DETR.
///
/// This processor converts model predictions into axis-aligned bounding boxes with class
/// labels and confidence scores for document layout elements, then applies per-class
/// non-maximum suppression.
#[derive(Debug, Clone)]
pub struct LayoutPostProcess {
    /// Number of classes the model predicts; detections whose class id is at or above this
    /// value are discarded.
    num_classes: usize,
    /// Minimum confidence score a detection must reach to be kept.
    score_threshold: f32,
    /// Non-maximum suppression IoU threshold: a lower-scoring box of the same class whose IoU
    /// with a kept box exceeds this value is suppressed.
    nms_threshold: f32,
    /// Maximum number of detections to return per image.
    max_detections: usize,
    /// Model type selecting the decoder (e.g., "picodet", "rtdetr", "pp-doclayout");
    /// unrecognized values use the PicoDet-style decoder.
    model_type: String,
}
32
33impl LayoutPostProcess {
34    /// Creates a new layout detection post-processor.
35    pub fn new(
36        num_classes: usize,
37        score_threshold: f32,
38        nms_threshold: f32,
39        max_detections: usize,
40        model_type: String,
41    ) -> Self {
42        Self {
43            num_classes,
44            score_threshold,
45            nms_threshold,
46            max_detections,
47            model_type,
48        }
49    }
50
51    /// Applies post-processing to layout detection model predictions.
52    ///
53    /// # Arguments
54    /// * `predictions` - Model output tensor [batch, num_boxes, 4 + num_classes]
55    /// * `img_shapes` - Original image dimensions for each image in batch
56    ///
57    /// # Returns
58    /// Tuple of (bounding_boxes, class_ids, scores) for each image in batch
59    pub fn apply(
60        &self,
61        predictions: &ndarray::Array4<f32>,
62        img_shapes: Vec<ImageScaleInfo>,
63    ) -> LayoutPostprocessOutput {
64        let batch_size = predictions.shape()[0];
65        let mut all_boxes = Vec::with_capacity(batch_size);
66        let mut all_classes = Vec::with_capacity(batch_size);
67        let mut all_scores = Vec::with_capacity(batch_size);
68
69        // Process each image in batch
70        for (batch_idx, img_shape) in img_shapes.into_iter().enumerate().take(batch_size) {
71            let pred = predictions.index_axis(Axis(0), batch_idx);
72
73            let (boxes, classes, scores) = match self.model_type.as_str() {
74                "picodet" => self.process_picodet(pred, &img_shape),
75                "rtdetr" => self.process_rtdetr(pred, &img_shape),
76                "pp-doclayout" => self.process_pp_doclayout(pred, &img_shape),
77                _ => self.process_standard(pred, &img_shape),
78            };
79
80            all_boxes.push(boxes);
81            all_classes.push(classes);
82            all_scores.push(scores);
83        }
84
85        (all_boxes, all_classes, all_scores)
86    }
87
88    /// Process PicoDet model output.
89    fn process_picodet(
90        &self,
91        predictions: ArrayView3<f32>,
92        img_shape: &ImageScaleInfo,
93    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
94        let mut boxes = Vec::new();
95        let mut classes = Vec::new();
96        let mut scores = Vec::new();
97
98        let orig_width = img_shape.src_w;
99        let orig_height = img_shape.src_h;
100        let shape = predictions.shape();
101        if shape.len() != 3 || shape[2] == 0 {
102            return (boxes, classes, scores);
103        }
104
105        let total_boxes = shape[0] * shape[1];
106        if total_boxes == 0 {
107            return (boxes, classes, scores);
108        }
109
110        let feature_dim = shape[2];
111        let data: Cow<'_, [f32]> = match predictions.as_slice() {
112            Some(slice) => Cow::Borrowed(slice),
113            None => {
114                let (mut vec, offset) = predictions.to_owned().into_raw_vec_and_offset();
115                if let Some(offset) = offset
116                    && offset != 0
117                {
118                    vec.drain(0..offset);
119                }
120                Cow::Owned(vec)
121            }
122        };
123
124        for box_idx in 0..total_boxes {
125            let start = box_idx * feature_dim;
126            let end = start + feature_dim;
127
128            if end > data.len() {
129                break;
130            }
131
132            let row = &data[start..end];
133            if feature_dim == 4 + self.num_classes {
134                // Format: [x1, y1, x2, y2, scores...]
135                let (max_class, max_score) = row[4..].iter().enumerate().fold(
136                    (0usize, 0.0f32),
137                    |(best_cls, best_score), (cls_idx, &score)| {
138                        if score > best_score {
139                            (cls_idx, score)
140                        } else {
141                            (best_cls, best_score)
142                        }
143                    },
144                );
145
146                if max_score < self.score_threshold {
147                    continue;
148                }
149
150                let (sx1, sy1, sx2, sy2) = self.convert_bbox_coords(
151                    row[0],
152                    row[1],
153                    row[2],
154                    row[3],
155                    orig_width,
156                    orig_height,
157                );
158
159                if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
160                    continue;
161                }
162
163                let bbox = BoundingBox::new(vec![
164                    Point::new(sx1, sy1),
165                    Point::new(sx2, sy1),
166                    Point::new(sx2, sy2),
167                    Point::new(sx1, sy2),
168                ]);
169
170                boxes.push(bbox);
171                classes.push(max_class);
172                scores.push(max_score);
173            } else if feature_dim >= 6
174                && let Some((class_id, score, x1, y1, x2, y2)) = self.parse_compact_prediction(row)
175            {
176                if score < self.score_threshold || class_id >= self.num_classes {
177                    continue;
178                }
179
180                let (sx1, sy1, sx2, sy2) =
181                    self.convert_bbox_coords(x1, y1, x2, y2, orig_width, orig_height);
182
183                if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
184                    continue;
185                }
186
187                let bbox = BoundingBox::new(vec![
188                    Point::new(sx1, sy1),
189                    Point::new(sx2, sy1),
190                    Point::new(sx2, sy2),
191                    Point::new(sx1, sy2),
192                ]);
193
194                boxes.push(bbox);
195                classes.push(class_id);
196                scores.push(score);
197            }
198        }
199
200        self.apply_nms(boxes, classes, scores)
201    }
202
203    /// Process RT-DETR model output.
204    fn process_rtdetr(
205        &self,
206        predictions: ArrayView3<f32>,
207        img_shape: &ImageScaleInfo,
208    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
209        // RT-DETR has similar output format to PicoDet
210        self.process_picodet(predictions, img_shape)
211    }
212
213    /// Process PP-DocLayout model output.
214    ///
215    /// Handles 6-dim format (PP-DocLayout), 7-dim format (PP-DocLayoutV3), and 8-dim format (PP-DocLayoutV2).
216    /// - 6-dim: [class_id, score, x1, y1, x2, y2]
217    /// - 7-dim: [class_id, score, x1, y1, x2, y2, extra]
218    /// - 8-dim: [class_id, score, x1, y1, x2, y2, col_index, row_index]
219    ///
220    /// For 8-dim format, boxes are sorted by reading order (col_index ascending, row_index ascending)
221    /// after NMS filtering.
222    fn process_pp_doclayout(
223        &self,
224        predictions: ArrayView3<f32>,
225        img_shape: &ImageScaleInfo,
226    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
227        // PP-DocLayout outputs in [num_boxes, 1, N] format
228        // where N is 6 or 8 depending on model version
229        let shape = predictions.shape();
230        let num_boxes = shape[0];
231        let feature_dim = shape[2];
232
233        let mut boxes = Vec::new();
234        let mut classes = Vec::new();
235        let mut scores = Vec::new();
236        let mut reading_orders: Vec<(f32, f32)> = Vec::new();
237
238        let orig_width = img_shape.src_w;
239        let orig_height = img_shape.src_h;
240
241        let has_reading_order = feature_dim == 8;
242
243        // Extract predictions
244        for box_idx in 0..num_boxes {
245            // predictions is [num_boxes, 1, N], so we use 3D indexing [box_idx, 0, i]
246            let class_id = predictions[[box_idx, 0, 0]] as i32;
247            let score = predictions[[box_idx, 0, 1]];
248            let x1 = predictions[[box_idx, 0, 2]];
249            let y1 = predictions[[box_idx, 0, 3]];
250            let x2 = predictions[[box_idx, 0, 4]];
251            let y2 = predictions[[box_idx, 0, 5]];
252
253            // Extract reading order info if available (8-dim format)
254            // Default to (0, box_idx) for 6-dim format to maintain original order
255            let reading_order = if has_reading_order {
256                (predictions[[box_idx, 0, 6]], predictions[[box_idx, 0, 7]])
257            } else {
258                (0.0, box_idx as f32)
259            };
260
261            // Filter by threshold and valid class
262            if score < self.score_threshold
263                || class_id < 0
264                || (class_id as usize) >= self.num_classes
265            {
266                continue;
267            }
268
269            // PP-DocLayout-style models may emit either absolute pixel coords or normalized coords.
270            // Use the same normalization heuristic as other detectors for robustness.
271            let (sx1, sy1, sx2, sy2) =
272                self.convert_bbox_coords(x1, y1, x2, y2, orig_width, orig_height);
273            if !Self::is_valid_box(sx1, sy1, sx2, sy2) {
274                continue;
275            }
276
277            let bbox = BoundingBox::new(vec![
278                Point::new(sx1, sy1),
279                Point::new(sx2, sy1),
280                Point::new(sx2, sy2),
281                Point::new(sx1, sy2),
282            ]);
283
284            boxes.push(bbox);
285            classes.push(class_id as usize);
286            scores.push(score);
287            reading_orders.push(reading_order);
288        }
289
290        // Apply NMS with reading order preservation
291        let (filtered_boxes, filtered_classes, filtered_scores, filtered_reading_orders) =
292            self.apply_nms_with_reading_order(boxes, classes, scores, reading_orders);
293
294        // Sort by reading order if we have 8-dim format
295        if has_reading_order && !filtered_boxes.is_empty() {
296            let mut indices: Vec<usize> = (0..filtered_boxes.len()).collect();
297            indices.sort_by(|&i, &j| {
298                let (col_i, row_i) = filtered_reading_orders[i];
299                let (col_j, row_j) = filtered_reading_orders[j];
300                // Sort by col_index ascending, then row_index ascending
301                // Use total_cmp to handle NaN/infinity values gracefully
302                col_i
303                    .total_cmp(&col_j)
304                    .then_with(|| row_i.total_cmp(&row_j))
305            });
306
307            let sorted_boxes = indices.iter().map(|&i| filtered_boxes[i].clone()).collect();
308            let sorted_classes = indices.iter().map(|&i| filtered_classes[i]).collect();
309            let sorted_scores = indices.iter().map(|&i| filtered_scores[i]).collect();
310
311            (sorted_boxes, sorted_classes, sorted_scores)
312        } else {
313            (filtered_boxes, filtered_classes, filtered_scores)
314        }
315    }
316
317    /// Apply NMS with reading order preservation.
318    fn apply_nms_with_reading_order(
319        &self,
320        boxes: Vec<BoundingBox>,
321        classes: Vec<usize>,
322        scores: Vec<f32>,
323        reading_orders: Vec<(f32, f32)>,
324    ) -> NmsResult {
325        if boxes.is_empty() {
326            return (boxes, classes, scores, reading_orders);
327        }
328
329        let keep = self.compute_nms_keep_indices(&boxes, &classes, &scores);
330
331        let filtered_boxes: Vec<BoundingBox> = keep.iter().map(|&i| boxes[i].clone()).collect();
332        let filtered_classes: Vec<usize> = keep.iter().map(|&i| classes[i]).collect();
333        let filtered_scores: Vec<f32> = keep.iter().map(|&i| scores[i]).collect();
334        let filtered_reading_orders: Vec<(f32, f32)> =
335            keep.iter().map(|&i| reading_orders[i]).collect();
336
337        (
338            filtered_boxes,
339            filtered_classes,
340            filtered_scores,
341            filtered_reading_orders,
342        )
343    }
344
345    /// Process standard detection model output.
346    fn process_standard(
347        &self,
348        predictions: ArrayView3<f32>,
349        img_shape: &ImageScaleInfo,
350    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
351        self.process_picodet(predictions, img_shape)
352    }
353
354    fn parse_compact_prediction(&self, row: &[f32]) -> Option<(usize, f32, f32, f32, f32, f32)> {
355        if row.len() < 6 {
356            return None;
357        }
358
359        // Format: [class_id, score, x1, y1, x2, y2]
360        let score_is_valid = if self.model_type == "rtdetr" {
361            row[1].is_finite()
362        } else {
363            Self::is_valid_score(row[1])
364        };
365
366        if score_is_valid && Self::is_valid_class(row[0], self.num_classes) {
367            let class_id = row[0].round() as i32;
368            if class_id >= 0 {
369                let score = self.adjust_score(row[1]);
370                return Some((class_id as usize, score, row[2], row[3], row[4], row[5]));
371            }
372        }
373
374        // Alternate format: [x1, y1, x2, y2, score, class_id]
375        let score_is_valid = if self.model_type == "rtdetr" {
376            row[4].is_finite()
377        } else {
378            Self::is_valid_score(row[4])
379        };
380        if score_is_valid && Self::is_valid_class(row[5], self.num_classes) {
381            let class_id = row[5].round() as i32;
382            if class_id >= 0 {
383                let score = self.adjust_score(row[4]);
384                return Some((class_id as usize, score, row[0], row[1], row[2], row[3]));
385            }
386        }
387
388        // Alternate format: [score, class_id, x1, y1, x2, y2]
389        let score_is_valid = if self.model_type == "rtdetr" {
390            row[0].is_finite()
391        } else {
392            Self::is_valid_score(row[0])
393        };
394        if score_is_valid && Self::is_valid_class(row[1], self.num_classes) {
395            let class_id = row[1].round() as i32;
396            if class_id >= 0 {
397                let score = self.adjust_score(row[0]);
398                return Some((class_id as usize, score, row[2], row[3], row[4], row[5]));
399            }
400        }
401
402        None
403    }
404
405    fn convert_bbox_coords(
406        &self,
407        x1: f32,
408        y1: f32,
409        x2: f32,
410        y2: f32,
411        orig_width: f32,
412        orig_height: f32,
413    ) -> (f32, f32, f32, f32) {
414        let normalized = x2 <= 1.05
415            && y2 <= 1.05
416            && x1 >= -0.05
417            && y1 >= -0.05
418            && orig_width > 0.0
419            && orig_height > 0.0;
420
421        if normalized {
422            (
423                x1.clamp(0.0, 1.0) * orig_width,
424                y1.clamp(0.0, 1.0) * orig_height,
425                x2.clamp(0.0, 1.0) * orig_width,
426                y2.clamp(0.0, 1.0) * orig_height,
427            )
428        } else {
429            (
430                x1.clamp(0.0, orig_width),
431                y1.clamp(0.0, orig_height),
432                x2.clamp(0.0, orig_width),
433                y2.clamp(0.0, orig_height),
434            )
435        }
436    }
437
438    fn is_valid_box(x1: f32, y1: f32, x2: f32, y2: f32) -> bool {
439        x2 > x1 && y2 > y1 && x1.is_finite() && y1.is_finite() && x2.is_finite() && y2.is_finite()
440    }
441
442    fn is_valid_score(score: f32) -> bool {
443        score.is_finite() && (0.0..=1.0 + f32::EPSILON).contains(&score)
444    }
445
446    fn is_valid_class(raw: f32, num_classes: usize) -> bool {
447        if !raw.is_finite() {
448            return false;
449        }
450        let class_id = raw.round() as i32;
451        class_id >= 0 && (class_id as usize) < num_classes + 5
452    }
453
454    fn adjust_score(&self, raw_score: f32) -> f32 {
455        if self.model_type == "rtdetr" {
456            raw_score.clamp(0.0, 1.0)
457        } else {
458            raw_score
459        }
460    }
461
462    /// Compute indices to keep after NMS.
463    /// Returns the indices of boxes that survive non-maximum suppression.
464    fn compute_nms_keep_indices(
465        &self,
466        boxes: &[BoundingBox],
467        classes: &[usize],
468        scores: &[f32],
469    ) -> Vec<usize> {
470        // Sort by score in descending order
471        let mut indices: Vec<usize> = (0..boxes.len()).collect();
472        indices.sort_by(|&a, &b| {
473            scores[b]
474                .partial_cmp(&scores[a])
475                .unwrap_or(std::cmp::Ordering::Equal)
476        });
477
478        let mut keep = Vec::new();
479        let mut suppressed = vec![false; boxes.len()];
480
481        for &i in &indices {
482            if suppressed[i] {
483                continue;
484            }
485
486            keep.push(i);
487            if keep.len() >= self.max_detections {
488                break;
489            }
490
491            // Suppress boxes with high IoU
492            for &j in &indices {
493                if i != j && !suppressed[j] && classes[i] == classes[j] {
494                    let iou = self.calculate_iou(&boxes[i], &boxes[j]);
495                    if iou > self.nms_threshold {
496                        suppressed[j] = true;
497                    }
498                }
499            }
500        }
501
502        keep
503    }
504
505    /// Apply Non-Maximum Suppression to filter overlapping boxes.
506    fn apply_nms(
507        &self,
508        boxes: Vec<BoundingBox>,
509        classes: Vec<usize>,
510        scores: Vec<f32>,
511    ) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
512        if boxes.is_empty() {
513            return (boxes, classes, scores);
514        }
515
516        let keep = self.compute_nms_keep_indices(&boxes, &classes, &scores);
517
518        let filtered_boxes: Vec<BoundingBox> = keep.iter().map(|&i| boxes[i].clone()).collect();
519        let filtered_classes: Vec<usize> = keep.iter().map(|&i| classes[i]).collect();
520        let filtered_scores: Vec<f32> = keep.iter().map(|&i| scores[i]).collect();
521
522        (filtered_boxes, filtered_classes, filtered_scores)
523    }
524
525    /// Calculate Intersection over Union between two bounding boxes.
526    fn calculate_iou(&self, box1: &BoundingBox, box2: &BoundingBox) -> f32 {
527        // Get bounding rectangle for box1
528        let (x1_min, y1_min, x1_max, y1_max) = self.get_bbox_bounds(box1);
529
530        // Get bounding rectangle for box2
531        let (x2_min, y2_min, x2_max, y2_max) = self.get_bbox_bounds(box2);
532
533        // Calculate intersection
534        let x_min = x1_min.max(x2_min);
535        let y_min = y1_min.max(y2_min);
536        let x_max = x1_max.min(x2_max);
537        let y_max = y1_max.min(y2_max);
538
539        if x_max <= x_min || y_max <= y_min {
540            return 0.0;
541        }
542
543        let intersection = (x_max - x_min) * (y_max - y_min);
544        let area1 = (x1_max - x1_min) * (y1_max - y1_min);
545        let area2 = (x2_max - x2_min) * (y2_max - y2_min);
546        let union = area1 + area2 - intersection;
547
548        if union > 0.0 {
549            intersection / union
550        } else {
551            0.0
552        }
553    }
554
555    /// Get the minimum and maximum coordinates from a bounding box.
556    fn get_bbox_bounds(&self, bbox: &BoundingBox) -> (f32, f32, f32, f32) {
557        if bbox.points.is_empty() {
558            return (0.0, 0.0, 0.0, 0.0);
559        }
560
561        let mut x_min = f32::INFINITY;
562        let mut y_min = f32::INFINITY;
563        let mut x_max = f32::NEG_INFINITY;
564        let mut y_max = f32::NEG_INFINITY;
565
566        for point in &bbox.points {
567            x_min = x_min.min(point.x);
568            y_min = y_min.min(point.y);
569            x_max = x_max.max(point.x);
570            y_max = y_max.max(point.y);
571        }
572
573        (x_min, y_min, x_max, y_max)
574    }
575}
576
577/// Apply unclip ratio to expand/shrink bounding boxes while keeping center fixed.
578///
579/// This follows PP-StructureV3's `layout_unclip_ratio` parameter behavior.
580///
581/// # Arguments
582/// * `boxes` - Input bounding boxes
583/// * `classes` - Class IDs for each box
584/// * `width_ratio` - Ratio to apply to box width (1.0 = no change)
585/// * `height_ratio` - Ratio to apply to box height (1.0 = no change)
586/// * `per_class_ratios` - Optional per-class ratios: class_id -> (width_ratio, height_ratio)
587///
588/// # Returns
589/// Transformed bounding boxes with same center but scaled dimensions
590pub fn unclip_boxes(
591    boxes: &[BoundingBox],
592    classes: &[usize],
593    width_ratio: f32,
594    height_ratio: f32,
595    per_class_ratios: Option<&std::collections::HashMap<usize, (f32, f32)>>,
596) -> Vec<BoundingBox> {
597    boxes
598        .iter()
599        .zip(classes.iter())
600        .map(|(bbox, &class_id)| {
601            // Get ratio for this class
602            let (w_ratio, h_ratio) = per_class_ratios
603                .and_then(|ratios| ratios.get(&class_id).copied())
604                .unwrap_or((width_ratio, height_ratio));
605
606            // Skip if ratios are 1.0 (no change)
607            if (w_ratio - 1.0).abs() < 1e-6 && (h_ratio - 1.0).abs() < 1e-6 {
608                return bbox.clone();
609            }
610
611            // Get current bounds
612            let x_min = bbox.x_min();
613            let y_min = bbox.y_min();
614            let x_max = bbox.x_max();
615            let y_max = bbox.y_max();
616
617            // Calculate center and dimensions
618            let width = x_max - x_min;
619            let height = y_max - y_min;
620            let center_x = x_min + width / 2.0;
621            let center_y = y_min + height / 2.0;
622
623            // Apply ratio
624            let new_width = width * w_ratio;
625            let new_height = height * h_ratio;
626
627            // Calculate new bounds
628            let new_x_min = center_x - new_width / 2.0;
629            let new_y_min = center_y - new_height / 2.0;
630            let new_x_max = center_x + new_width / 2.0;
631            let new_y_max = center_y + new_height / 2.0;
632
633            BoundingBox::from_coords(new_x_min, new_y_min, new_x_max, new_y_max)
634        })
635        .collect()
636}
637
638/// Merge two bounding boxes according to the specified mode.
639///
640/// # Arguments
641/// * `box1` - First bounding box
642/// * `box2` - Second bounding box
643/// * `mode` - Merge mode to apply
644///
645/// # Returns
646/// Merged bounding box according to the mode
647pub fn merge_boxes(box1: &BoundingBox, box2: &BoundingBox, mode: MergeBboxMode) -> BoundingBox {
648    let (x1_min, y1_min, x1_max, y1_max) = (box1.x_min(), box1.y_min(), box1.x_max(), box1.y_max());
649    let (x2_min, y2_min, x2_max, y2_max) = (box2.x_min(), box2.y_min(), box2.x_max(), box2.y_max());
650
651    let area1 = (x1_max - x1_min) * (y1_max - y1_min);
652    let area2 = (x2_max - x2_min) * (y2_max - y2_min);
653
654    match mode {
655        MergeBboxMode::Large => {
656            // Keep the larger bounding box
657            if area1 >= area2 {
658                box1.clone()
659            } else {
660                box2.clone()
661            }
662        }
663        MergeBboxMode::Small => {
664            // Keep the smaller bounding box
665            if area1 <= area2 {
666                box1.clone()
667            } else {
668                box2.clone()
669            }
670        }
671        MergeBboxMode::Union => {
672            // Merge to union of bounding boxes
673            let union_x_min = x1_min.min(x2_min);
674            let union_y_min = y1_min.min(y2_min);
675            let union_x_max = x1_max.max(x2_max);
676            let union_y_max = y1_max.max(y2_max);
677            BoundingBox::from_coords(union_x_min, union_y_min, union_x_max, union_y_max)
678        }
679    }
680}
681
682/// Apply Non-Maximum Suppression with per-class merge modes.
683///
684/// Unlike standard NMS which simply suppresses (discards) overlapping boxes,
685/// this function can merge overlapping boxes according to the specified mode.
686///
687/// # Arguments
688/// * `boxes` - Input bounding boxes
689/// * `classes` - Class IDs for each box
690/// * `scores` - Confidence scores for each box
691/// * `class_labels` - Mapping from class ID to label string
692/// * `class_merge_modes` - Per-class merge modes (label -> mode)
693/// * `nms_threshold` - IoU threshold for overlap detection
694/// * `max_detections` - Maximum number of detections to return
695///
696/// # Returns
697/// Tuple of (filtered_boxes, filtered_classes, filtered_scores)
698pub fn apply_nms_with_merge(
699    boxes: Vec<BoundingBox>,
700    classes: Vec<usize>,
701    scores: Vec<f32>,
702    class_labels: &HashMap<usize, String>,
703    class_merge_modes: &HashMap<String, MergeBboxMode>,
704    nms_threshold: f32,
705    max_detections: usize,
706) -> (Vec<BoundingBox>, Vec<usize>, Vec<f32>) {
707    if boxes.is_empty() {
708        return (boxes, classes, scores);
709    }
710
711    // Sort by score in descending order
712    let mut indices: Vec<usize> = (0..boxes.len()).collect();
713    indices.sort_by(|&a, &b| {
714        scores[b]
715            .partial_cmp(&scores[a])
716            .unwrap_or(std::cmp::Ordering::Equal)
717    });
718
719    let mut result_boxes = Vec::new();
720    let mut result_classes = Vec::new();
721    let mut result_scores = Vec::new();
722    let mut result_order_indices = Vec::new();
723    let mut processed = vec![false; boxes.len()];
724
725    for &i in &indices {
726        if processed[i] {
727            continue;
728        }
729
730        processed[i] = true;
731
732        // Get merge mode for this class
733        let class_label = class_labels
734            .get(&classes[i])
735            .map(|s| s.as_str())
736            .unwrap_or("unknown");
737        let merge_mode = class_merge_modes
738            .get(class_label)
739            .copied()
740            .unwrap_or(MergeBboxMode::Large);
741
742        let mut merged_box = boxes[i].clone();
743        let mut best_score = scores[i];
744        let mut order_idx = i;
745
746        // Find overlapping boxes of the same class and merge them
747        for &j in &indices {
748            if i != j && !processed[j] && classes[i] == classes[j] {
749                let iou = calculate_iou_static(&merged_box, &boxes[j]);
750                if iou > nms_threshold {
751                    // Merge the boxes
752                    merged_box = merge_boxes(&merged_box, &boxes[j], merge_mode);
753                    best_score = best_score.max(scores[j]);
754                    order_idx = order_idx.min(j);
755                    processed[j] = true;
756                }
757            }
758        }
759
760        result_boxes.push(merged_box);
761        result_classes.push(classes[i]);
762        result_scores.push(best_score);
763        result_order_indices.push(order_idx);
764    }
765
766    // First, apply max_detections limit based on score (NMS already processed in score order,
767    // so result_* vectors are implicitly score-ordered). This ensures we keep the highest-scoring
768    // detections rather than earliest ones.
769    let take_count = max_detections.min(result_boxes.len());
770
771    // Preserve input ordering for downstream consumers (e.g., PP-DocLayoutV2 reading-order output).
772    // We keep the score-based selection above, but sort the top-N merged results by the earliest
773    // original index in each merged group.
774    let mut merged: Vec<(usize, BoundingBox, usize, f32)> = result_order_indices
775        .into_iter()
776        .zip(result_boxes)
777        .zip(result_classes)
778        .zip(result_scores)
779        .map(|(((order, bbox), class_id), score)| (order, bbox, class_id, score))
780        .take(take_count) // Apply max_detections limit BEFORE reordering
781        .collect();
782
783    merged.sort_by(|(a, _, _, _), (b, _, _, _)| a.cmp(b));
784
785    let mut final_boxes = Vec::new();
786    let mut final_classes = Vec::new();
787    let mut final_scores = Vec::new();
788
789    for (_, bbox, class_id, score) in merged {
790        final_boxes.push(bbox);
791        final_classes.push(class_id);
792        final_scores.push(score);
793    }
794
795    (final_boxes, final_classes, final_scores)
796}
797
798/// Calculate IoU between two bounding boxes (standalone function).
799fn calculate_iou_static(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
800    let (x1_min, y1_min, x1_max, y1_max) = (box1.x_min(), box1.y_min(), box1.x_max(), box1.y_max());
801    let (x2_min, y2_min, x2_max, y2_max) = (box2.x_min(), box2.y_min(), box2.x_max(), box2.y_max());
802
803    // Calculate intersection
804    let x_min = x1_min.max(x2_min);
805    let y_min = y1_min.max(y2_min);
806    let x_max = x1_max.min(x2_max);
807    let y_max = y1_max.min(y2_max);
808
809    if x_max <= x_min || y_max <= y_min {
810        return 0.0;
811    }
812
813    let intersection = (x_max - x_min) * (y_max - y_min);
814    let area1 = (x1_max - x1_min) * (y1_max - y1_min);
815    let area2 = (x2_max - x2_min) * (y2_max - y2_min);
816    let union = area1 + area2 - intersection;
817
818    if union > 0.0 {
819        intersection / union
820    } else {
821        0.0
822    }
823}
824
825impl Default for LayoutPostProcess {
826    fn default() -> Self {
827        Self {
828            num_classes: 5, // Default for basic layout detection
829            score_threshold: 0.5,
830            nms_threshold: 0.5,
831            max_detections: 100,
832            model_type: "picodet".to_string(),
833        }
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a four-point box spanning `(x0, y0)`..`(x1, y1)`.
    fn rect(x0: f32, y0: f32, x1: f32, y1: f32) -> BoundingBox {
        BoundingBox::new(vec![
            Point::new(x0, y0),
            Point::new(x1, y0),
            Point::new(x1, y1),
            Point::new(x0, y1),
        ])
    }

    #[test]
    fn test_layout_postprocess_creation() {
        let defaults = LayoutPostProcess::default();
        assert_eq!(defaults.num_classes, 5);
        assert_eq!(defaults.score_threshold, 0.5);
    }

    #[test]
    fn test_iou_calculation() {
        let processor = LayoutPostProcess::default();
        let base = rect(0.0, 0.0, 100.0, 100.0);

        // Identical boxes overlap completely.
        assert_eq!(processor.calculate_iou(&base, &base), 1.0);

        // Disjoint boxes share no area.
        let far_away = rect(200.0, 200.0, 300.0, 300.0);
        assert_eq!(processor.calculate_iou(&base, &far_away), 0.0);
    }
}