Skip to main content

fleischwolf_pdf/
layout.rs

1//! Layout detection via the RT-DETR (`docling-layout-heron`) model exported to
2//! ONNX, run with `ort`. A port of docling-ibm-models' `LayoutPredictor`:
3//! resize the page image to 640×640 and rescale to `[0,1]` (the heron processor
4//! has `do_normalize=false`), run the model, then RT-DETR
5//! `post_process_object_detection` (sigmoid → top-k over query×class →
6//! center-to-corners boxes scaled to the page).
7
8use image::imageops::FilterType;
9use image::RgbImage;
10use ort::session::Session;
11use ort::value::Tensor;
12
/// The 17 canonical layout classes, indexed by the model's class id
/// (`config.json` `id2label`).
///
/// Order is significant: `LayoutModel::predict` uses the class index decoded
/// from the flattened logits as a position in this array, so entries must not
/// be reordered.
pub const LABELS: [&str; 17] = [
    "caption",             // 0
    "footnote",            // 1
    "formula",             // 2
    "list_item",           // 3
    "page_footer",         // 4
    "page_header",         // 5
    "picture",             // 6
    "section_header",      // 7
    "table",               // 8
    "text",                // 9 (also the fallback label for out-of-range class ids)
    "title",               // 10
    "document_index",      // 11
    "code",                // 12
    "checkbox_selected",   // 13
    "checkbox_unselected", // 14
    "form",                // 15
    "key_value_region",    // 16
];
34
/// One detected region, in page points (top-left origin).
///
/// As produced by `LayoutModel::predict`, the box edges are the model's
/// normalized center/size box converted to corners and multiplied by the page
/// width (`l`, `r`) or page height (`t`, `b`).
#[derive(Debug, Clone)]
pub struct Region {
    /// Class name; `predict` always fills this with an entry of [`LABELS`].
    pub label: &'static str,
    /// Sigmoid of the class logit this detection was decoded from.
    pub score: f32,
    /// Left edge: `(cx - w / 2) * page_w`.
    pub l: f32,
    /// Top edge: `(cy - h / 2) * page_h`.
    pub t: f32,
    /// Right edge: `(cx + w / 2) * page_w`.
    pub r: f32,
    /// Bottom edge: `(cy + h / 2) * page_h`.
    pub b: f32,
}
45
/// Confidence threshold (docling-ibm-models `base_threshold`). Detections whose
/// score is at or below this value are dropped; only strictly higher scores
/// are returned.
const THRESHOLD: f32 = 0.3;
/// Side length, in pixels, of the square image the page is resized to before
/// it is fed to the model.
const SIDE: u32 = 640;
49
/// Layout detector: owns the `ort` session the layout ONNX model is loaded
/// into. Construct it with `load` / `load_with`, then call `predict` per page.
pub struct LayoutModel {
    /// Inference session; `predict` runs it once per call.
    session: Session,
}
53
54impl LayoutModel {
55    /// Load the ONNX model from `DOCLING_LAYOUT_ONNX`. Without the override,
56    /// prefers `models/layout_heron_int8.onnx` when present (the quantized
57    /// default; `FLEISCHWOLF_FP32=1` opts out), else `models/layout_heron.onnx`.
58    pub fn load() -> Result<Self, String> {
59        Self::load_with(crate::intra_threads())
60    }
61
62    /// Like [`load`](Self::load) but with an explicit intra-op thread count. A
63    /// parallel page-worker pool loads its helper models on a single thread each
64    /// and gets its speed-up from running pages concurrently instead.
65    pub fn load_with(intra: usize) -> Result<Self, String> {
66        let path = crate::model_path(
67            "DOCLING_LAYOUT_ONNX",
68            "models/layout_heron.onnx",
69            "models/layout_heron_int8.onnx",
70        );
71        let session = Session::builder()
72            .map_err(|e| format!("layout: builder: {e}"))?
73            // Let inference use the available cores (ort otherwise defaults low);
74            // a large PDF runs this model once per page.
75            .with_intra_threads(intra)
76            .map_err(|e| format!("layout: intra_threads: {e}"))?
77            .commit_from_file(&path)
78            .map_err(|e| format!("layout: load {path}: {e}"))?;
79        Ok(Self { session })
80    }
81
82    /// Detect layout regions on a page image. `page_w`/`page_h` are the page size
83    /// in points; returned boxes are in those coordinates.
84    pub fn predict(
85        &mut self,
86        img: &RgbImage,
87        page_w: f32,
88        page_h: f32,
89    ) -> Result<Vec<Region>, String> {
90        // Resize to 640×640 (RT-DETR ignores aspect ratio), rescale to [0,1],
91        // lay out as CHW.
92        let resized = image::imageops::resize(img, SIDE, SIDE, FilterType::Triangle);
93        let n = (SIDE * SIDE) as usize;
94        let mut data = vec![0f32; 3 * n];
95        for (i, px) in resized.pixels().enumerate() {
96            data[i] = px[0] as f32 / 255.0;
97            data[n + i] = px[1] as f32 / 255.0;
98            data[2 * n + i] = px[2] as f32 / 255.0;
99        }
100        let input = Tensor::from_array(([1usize, 3, SIDE as usize, SIDE as usize], data))
101            .map_err(|e| format!("layout: input tensor: {e}"))?;
102        let outputs = self
103            .session
104            .run(ort::inputs!["pixel_values" => input])
105            .map_err(|e| format!("layout: inference: {e}"))?;
106        let (lshape, logits) = outputs["logits"]
107            .try_extract_tensor::<f32>()
108            .map_err(|e| format!("layout: extract logits: {e}"))?;
109        let (_, boxes) = outputs["pred_boxes"]
110            .try_extract_tensor::<f32>()
111            .map_err(|e| format!("layout: extract boxes: {e}"))?;
112
113        let num_queries = lshape[1] as usize;
114        let num_classes = lshape[2] as usize;
115
116        // sigmoid over every (query, class); take the top `num_queries` scores.
117        let mut scored: Vec<(f32, usize)> = (0..num_queries * num_classes)
118            .map(|idx| (sigmoid(logits[idx]), idx))
119            .collect();
120        scored.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
121        scored.truncate(num_queries);
122
123        let mut regions = Vec::new();
124        for (score, idx) in scored {
125            if score <= THRESHOLD {
126                continue;
127            }
128            let label_id = idx % num_classes;
129            let q = idx / num_classes;
130            let cx = boxes[q * 4];
131            let cy = boxes[q * 4 + 1];
132            let w = boxes[q * 4 + 2];
133            let h = boxes[q * 4 + 3];
134            // center_to_corners, then scale normalized coords to page points.
135            let l = (cx - w / 2.0) * page_w;
136            let t = (cy - h / 2.0) * page_h;
137            let r = (cx + w / 2.0) * page_w;
138            let b = (cy + h / 2.0) * page_h;
139            regions.push(Region {
140                label: LABELS.get(label_id).copied().unwrap_or("text"),
141                score,
142                l,
143                t,
144                r,
145                b,
146            });
147        }
148        Ok(regions)
149    }
150}
151
/// Logistic function `1 / (1 + e^(-x))`, mapping a raw logit to a score in
/// `[0, 1]`.
fn sigmoid(x: f32) -> f32 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}