1use image::imageops;
2use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
3use ort::{Session, SessionBuilder, SessionOutputs};
4
5pub use crate::error::Result;
6use crate::{utils::vec_to_bbox, LayoutElement};
7
8pub struct Detectron2Model {
10 model_name: String,
11 model: ort::Session,
12 confidence_threshold: f32,
13 label_map: Vec<(i64, String)>,
14 confidence_score_index: usize,
15}
16
/// Pretrained Detectron2 layout models that can be fetched from the Hugging Face hub.
// The variant names deliberately mirror the model identifiers used in the
// repository names returned by `hf_repo`, hence the non-camel-case spelling.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Detectron2PretrainedModels {
    /// The `faster_rcnn_R_50_FPN_3x` model.
    FASTER_RCNN_R_50_FPN_3X,
    /// The `mask_rcnn_X_101_32x8d_FPN_3x` model.
    MASK_RCNN_X_101_32X8D_FPN_3x,
}

impl Detectron2PretrainedModels {
    /// Returns the model name, which is currently the same string as [`Self::hf_repo`].
    pub fn name(&self) -> &str {
        self.hf_repo()
    }

    /// Returns the Hugging Face repository that hosts this model.
    pub fn hf_repo(&self) -> &str {
        match self {
            Self::FASTER_RCNN_R_50_FPN_3X => "unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x",
            Self::MASK_RCNN_X_101_32X8D_FPN_3x => {
                "unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x"
            }
        }
    }

    /// Returns the name of the model file inside the repository.
    pub fn hf_filename(&self) -> &str {
        // Kept as an exhaustive match so a new variant must state its file name.
        match self {
            Self::FASTER_RCNN_R_50_FPN_3X | Self::MASK_RCNN_X_101_32X8D_FPN_3x => "model.onnx",
        }
    }

    /// Returns the `(label id, label name)` pairs for this model.
    ///
    /// Ids are assigned consecutively from `0` in the order the names are listed.
    pub fn label_map(&self) -> Vec<(i64, String)> {
        const LAYOUT_LABELS: [&str; 5] = ["Text", "Title", "List", "Table", "Figure"];

        // Both models currently share one label set; the exhaustive match forces
        // a decision when a variant with different labels is added.
        let labels: &[&str] = match self {
            Self::FASTER_RCNN_R_50_FPN_3X | Self::MASK_RCNN_X_101_32X8D_FPN_3x => &LAYOUT_LABELS,
        };

        (0_i64..)
            .zip(labels)
            .map(|(id, label)| (id, (*label).to_string()))
            .collect()
    }

    /// Returns the index of the model output that holds the confidence scores.
    pub fn confidence_score_index(&self) -> usize {
        match self {
            Self::FASTER_RCNN_R_50_FPN_3X => 2,
            Self::MASK_RCNN_X_101_32X8D_FPN_3x => 3,
        }
    }
}
78
79impl Detectron2Model {
80 pub const REQUIRED_WIDTH: u32 = 800;
82 pub const REQUIRED_HEIGHT: u32 = 1035;
84 pub const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.8;
86
87 pub fn pretrained(p_model: Detectron2PretrainedModels) -> Result<Self> {
89 let session_builder = Session::builder()?;
90 let api = hf_hub::api::sync::Api::new()?;
91 let filename = api
92 .model(p_model.hf_repo().to_string())
93 .get(p_model.hf_filename())?;
94
95 let model = session_builder.commit_from_file(filename)?;
96
97 Ok(Self {
98 model_name: p_model.name().to_string(),
99 model,
100 label_map: p_model.label_map(),
101 confidence_threshold: Self::DEFAULT_CONFIDENCE_THRESHOLD,
102 confidence_score_index: p_model.confidence_score_index(),
103 })
104 }
105
106 pub fn configure_pretrained(
108 p_model: Detectron2PretrainedModels,
109 confidence_threshold: f32,
110 session_builder: SessionBuilder,
111 ) -> Result<Self> {
112 let api = hf_hub::api::sync::Api::new()?;
113 let filename = api
114 .model(p_model.hf_repo().to_string())
115 .get(p_model.hf_filename())?;
116
117 let model = session_builder.commit_from_file(filename)?;
118
119 Ok(Self {
120 model_name: p_model.name().to_string(),
121 model,
122 label_map: p_model.label_map(),
123 confidence_threshold,
124 confidence_score_index: p_model.confidence_score_index(),
125 })
126 }
127
128 pub fn new_from_file(
130 file_path: &str,
131 model_name: &str,
132 label_map: &[(i64, &str)],
133 confidence_threshold: f32,
134 confidence_score_index: usize,
135 session_builder: SessionBuilder,
136 ) -> Result<Self> {
137 let model = session_builder.commit_from_file(file_path)?;
138
139 Ok(Self {
140 model_name: model_name.to_string(),
141 model,
142 label_map: label_map.iter().map(|(i, l)| (*i, l.to_string())).collect(),
143 confidence_threshold,
144 confidence_score_index,
145 })
146 }
147
148 pub fn predict(&self, img: &image::DynamicImage) -> Result<Vec<LayoutElement>> {
150 let (img_width, img_height, input) = self.preprocess(img);
151
152 let run_result = self.model.run(ort::inputs!["x.1" => input]?);
153 match run_result {
154 Ok(outputs) => {
155 let elements = self.postprocess(&outputs, img_width, img_height)?;
156 return Ok(elements);
157 }
158 Err(_err) => {
159 tracing::warn!(
160 "Ignoring runtime error from onnx (likely due to encountering blank page)."
161 );
162 return Ok(vec![]);
163 }
164 }
165 }
166
167 fn preprocess(
168 &self,
169 img: &image::DynamicImage,
170 ) -> (u32, u32, ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>) {
171 let (img_width, img_height) = (img.width(), img.height());
172 let img = img.resize_exact(
173 Self::REQUIRED_WIDTH,
174 Self::REQUIRED_HEIGHT,
175 imageops::FilterType::Triangle,
176 );
177 let img_rgb8 = img.into_rgba8();
178
179 let mut input = Array::zeros((3, 1035, 800));
180
181 for pixel in img_rgb8.enumerate_pixels() {
182 let x = pixel.0 as _;
183 let y = pixel.1 as _;
184 let [r, g, b, _] = pixel.2 .0;
185 input[[0, y, x]] = r as f32;
186 input[[1, y, x]] = g as f32;
187 input[[2, y, x]] = b as f32;
188 }
189
190 return (img_width, img_height, input);
191 }
192
193 fn postprocess<'s>(
194 &self,
195 outputs: &SessionOutputs<'s>,
196 img_width: u32,
197 img_height: u32,
198 ) -> Result<Vec<LayoutElement>> {
199 let bboxes = &outputs[0].try_extract_tensor::<f32>()?;
200 let labels = &outputs[1].try_extract_tensor::<i64>()?;
201 let confidence_scores =
202 &outputs[self.confidence_score_index].try_extract_tensor::<f32>()?;
203
204 let width_conversion = img_width as f32 / Self::REQUIRED_WIDTH as f32;
205 let height_conversion = img_height as f32 / Self::REQUIRED_HEIGHT as f32;
206
207 let mut elements = vec![];
208
209 for (bbox, (label, confidence_score)) in bboxes
210 .rows()
211 .into_iter()
212 .zip(labels.iter().zip(confidence_scores))
213 {
214 let [x1, y1, x2, y2] = vec_to_bbox(bbox.iter().copied().collect());
215
216 let detected_label = &self
217 .label_map
218 .iter()
219 .find(|(l_i, _)| l_i == label)
220 .unwrap() .1;
222
223 if *confidence_score > self.confidence_threshold as f32 {
224 elements.push(LayoutElement::new(
225 x1 * width_conversion,
226 y1 * height_conversion,
227 x2 * width_conversion,
228 y2 * height_conversion,
229 &detected_label,
230 *confidence_score,
231 &self.model_name,
232 ))
233 }
234 }
235
236 elements.sort_by(|a, b| a.bbox.max().y.total_cmp(&b.bbox.max().y));
237
238 return Ok(elements);
239 }
240}