fleischwolf_pdf/
layout.rs1use image::imageops::FilterType;
9use image::RgbImage;
10use ort::session::Session;
11use ort::value::Tensor;
12
/// Names of the layout classes, indexed by class id.
///
/// `LayoutModel::predict` derives a class id from the position of a score
/// along the class axis of the `logits` output and looks the name up here,
/// so the order of the entries is significant and must not be changed.
pub const LABELS: [&str; 17] = [
    // ids 0..=10
    "caption",
    "footnote",
    "formula",
    "list_item",
    "page_footer",
    "page_header",
    "picture",
    "section_header",
    "table",
    "text",
    "title",
    // ids 11..=16
    "document_index",
    "code",
    "checkbox_selected",
    "checkbox_unselected",
    "form",
    "key_value_region",
];
34
/// One detected layout region: a class label, its confidence and a bounding box.
///
/// The box edges are produced by `LayoutModel::predict`, which scales the
/// model's box output by the `page_w` / `page_h` values supplied by the caller,
/// so they are expressed in whatever unit those page dimensions use.
#[derive(Clone, Debug)]
pub struct Region {
    /// Class name taken from [`LABELS`].
    pub label: &'static str,
    /// Sigmoid of the logit that produced this detection.
    pub score: f32,
    /// Left edge (`(cx - w / 2) * page_w`).
    pub l: f32,
    /// Top edge (`(cy - h / 2) * page_h`).
    pub t: f32,
    /// Right edge (`(cx + w / 2) * page_w`).
    pub r: f32,
    /// Bottom edge (`(cy + h / 2) * page_h`).
    pub b: f32,
}
45
/// Score cut-off: a detection is kept only when its sigmoid score is strictly
/// greater than this value.
const THRESHOLD: f32 = 0.3;

/// Edge length, in pixels, of the square image fed to the model; every page
/// image is resized to `SIDE` x `SIDE` before inference.
const SIDE: u32 = 640;
49
50pub struct LayoutModel {
51 session: Session,
52}
53
54impl LayoutModel {
55 pub fn load() -> Result<Self, String> {
59 Self::load_with(crate::intra_threads())
60 }
61
62 pub fn load_with(intra: usize) -> Result<Self, String> {
66 let path = crate::model_path(
67 "DOCLING_LAYOUT_ONNX",
68 "models/layout_heron.onnx",
69 "models/layout_heron_int8.onnx",
70 );
71 let session = Session::builder()
72 .map_err(|e| format!("layout: builder: {e}"))?
73 .with_intra_threads(intra)
76 .map_err(|e| format!("layout: intra_threads: {e}"))?
77 .commit_from_file(&path)
78 .map_err(|e| format!("layout: load {path}: {e}"))?;
79 Ok(Self { session })
80 }
81
82 pub fn predict(
85 &mut self,
86 img: &RgbImage,
87 page_w: f32,
88 page_h: f32,
89 ) -> Result<Vec<Region>, String> {
90 let resized = image::imageops::resize(img, SIDE, SIDE, FilterType::Triangle);
93 let n = (SIDE * SIDE) as usize;
94 let mut data = vec![0f32; 3 * n];
95 for (i, px) in resized.pixels().enumerate() {
96 data[i] = px[0] as f32 / 255.0;
97 data[n + i] = px[1] as f32 / 255.0;
98 data[2 * n + i] = px[2] as f32 / 255.0;
99 }
100 let input = Tensor::from_array(([1usize, 3, SIDE as usize, SIDE as usize], data))
101 .map_err(|e| format!("layout: input tensor: {e}"))?;
102 let outputs = self
103 .session
104 .run(ort::inputs!["pixel_values" => input])
105 .map_err(|e| format!("layout: inference: {e}"))?;
106 let (lshape, logits) = outputs["logits"]
107 .try_extract_tensor::<f32>()
108 .map_err(|e| format!("layout: extract logits: {e}"))?;
109 let (_, boxes) = outputs["pred_boxes"]
110 .try_extract_tensor::<f32>()
111 .map_err(|e| format!("layout: extract boxes: {e}"))?;
112
113 let num_queries = lshape[1] as usize;
114 let num_classes = lshape[2] as usize;
115
116 let mut scored: Vec<(f32, usize)> = (0..num_queries * num_classes)
118 .map(|idx| (sigmoid(logits[idx]), idx))
119 .collect();
120 scored.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
121 scored.truncate(num_queries);
122
123 let mut regions = Vec::new();
124 for (score, idx) in scored {
125 if score <= THRESHOLD {
126 continue;
127 }
128 let label_id = idx % num_classes;
129 let q = idx / num_classes;
130 let cx = boxes[q * 4];
131 let cy = boxes[q * 4 + 1];
132 let w = boxes[q * 4 + 2];
133 let h = boxes[q * 4 + 3];
134 let l = (cx - w / 2.0) * page_w;
136 let t = (cy - h / 2.0) * page_h;
137 let r = (cx + w / 2.0) * page_w;
138 let b = (cy + h / 2.0) * page_h;
139 regions.push(Region {
140 label: LABELS.get(label_id).copied().unwrap_or("text"),
141 score,
142 l,
143 t,
144 r,
145 b,
146 });
147 }
148 Ok(regions)
149 }
150}
151
/// Logistic function `1 / (1 + e^(-x))`, mapping a raw logit to a score.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}