layoutparser_ort/models/
detectron2.rs

1use image::imageops;
2use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
3use ort::{Session, SessionBuilder, SessionOutputs};
4
5pub use crate::error::Result;
6use crate::{utils::vec_to_bbox, LayoutElement};
7
8/// A [`Detectron2`](https://github.com/facebookresearch/detectron2)-based model.
9pub struct Detectron2Model {
10    model_name: String,
11    model: ort::Session,
12    confidence_threshold: f32,
13    label_map: Vec<(i64, String)>,
14    confidence_score_index: usize,
15}
16
/// Pretrained Detectron2-based models from Hugging Face.
// Variant names mirror the upstream model identifiers, hence the
// non-CamelCase spelling; renaming them would break existing callers.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Detectron2PretrainedModels {
    FASTER_RCNN_R_50_FPN_3X,
    MASK_RCNN_X_101_32X8D_FPN_3x,
}

impl Detectron2PretrainedModels {
    /// Labels shared by every pretrained model currently listed; a label's
    /// position in this array is the id the label map assigns to it.
    const LAYOUT_LABELS: [&'static str; 5] = ["Text", "Title", "List", "Table", "Figure"];

    /// Model name.
    ///
    /// Currently identical to [`Self::hf_repo`] for every variant.
    pub fn name(&self) -> &'static str {
        self.hf_repo()
    }

    /// Hugging Face repository for this model.
    pub fn hf_repo(&self) -> &'static str {
        match self {
            Self::FASTER_RCNN_R_50_FPN_3X => "unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x",
            Self::MASK_RCNN_X_101_32X8D_FPN_3x => {
                "unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x"
            }
        }
    }

    /// Path for this model file in Hugging Face repository.
    pub fn hf_filename(&self) -> &'static str {
        match self {
            Self::FASTER_RCNN_R_50_FPN_3X | Self::MASK_RCNN_X_101_32X8D_FPN_3x => "model.onnx",
        }
    }

    /// The label map for this model.
    ///
    /// Returns `(id, label)` pairs with ids counting up from `0` in the
    /// order the labels are listed.
    pub fn label_map(&self) -> Vec<(i64, String)> {
        // Kept as an exhaustive match so that adding a variant forces an
        // explicit decision about which label set it uses.
        let labels: &[&str] = match self {
            Self::FASTER_RCNN_R_50_FPN_3X | Self::MASK_RCNN_X_101_32X8D_FPN_3x => {
                &Self::LAYOUT_LABELS
            }
        };

        (0_i64..)
            .zip(labels)
            .map(|(id, label)| (id, (*label).to_owned()))
            .collect()
    }

    /// Index for the confidence score in this model's outputs.
    pub fn confidence_score_index(&self) -> usize {
        match self {
            Self::FASTER_RCNN_R_50_FPN_3X => 2,
            Self::MASK_RCNN_X_101_32X8D_FPN_3x => 3,
        }
    }
}
78
79impl Detectron2Model {
80    /// Required input image width.
81    pub const REQUIRED_WIDTH: u32 = 800;
82    /// Required input image height.
83    pub const REQUIRED_HEIGHT: u32 = 1035;
84    /// Default confidence threshold for detections.
85    pub const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.8;
86
87    /// Construct a [`Detectron2Model`] with a pretrained model downloaded from Hugging Face.
88    pub fn pretrained(p_model: Detectron2PretrainedModels) -> Result<Self> {
89        let session_builder = Session::builder()?;
90        let api = hf_hub::api::sync::Api::new()?;
91        let filename = api
92            .model(p_model.hf_repo().to_string())
93            .get(p_model.hf_filename())?;
94
95        let model = session_builder.commit_from_file(filename)?;
96
97        Ok(Self {
98            model_name: p_model.name().to_string(),
99            model,
100            label_map: p_model.label_map(),
101            confidence_threshold: Self::DEFAULT_CONFIDENCE_THRESHOLD,
102            confidence_score_index: p_model.confidence_score_index(),
103        })
104    }
105
106    /// Construct a configured [`Detectron2Model`] with a pretrained model downloaded from Hugging Face.
107    pub fn configure_pretrained(
108        p_model: Detectron2PretrainedModels,
109        confidence_threshold: f32,
110        session_builder: SessionBuilder,
111    ) -> Result<Self> {
112        let api = hf_hub::api::sync::Api::new()?;
113        let filename = api
114            .model(p_model.hf_repo().to_string())
115            .get(p_model.hf_filename())?;
116
117        let model = session_builder.commit_from_file(filename)?;
118
119        Ok(Self {
120            model_name: p_model.name().to_string(),
121            model,
122            label_map: p_model.label_map(),
123            confidence_threshold,
124            confidence_score_index: p_model.confidence_score_index(),
125        })
126    }
127
128    /// Construct a [`Detectron2Model`] from a model file.
129    pub fn new_from_file(
130        file_path: &str,
131        model_name: &str,
132        label_map: &[(i64, &str)],
133        confidence_threshold: f32,
134        confidence_score_index: usize,
135        session_builder: SessionBuilder,
136    ) -> Result<Self> {
137        let model = session_builder.commit_from_file(file_path)?;
138
139        Ok(Self {
140            model_name: model_name.to_string(),
141            model,
142            label_map: label_map.iter().map(|(i, l)| (*i, l.to_string())).collect(),
143            confidence_threshold,
144            confidence_score_index,
145        })
146    }
147
148    /// Predict [`LayoutElement`]s from the image provided.
149    pub fn predict(&self, img: &image::DynamicImage) -> Result<Vec<LayoutElement>> {
150        let (img_width, img_height, input) = self.preprocess(img);
151
152        let run_result = self.model.run(ort::inputs!["x.1" => input]?);
153        match run_result {
154            Ok(outputs) => {
155                let elements = self.postprocess(&outputs, img_width, img_height)?;
156                return Ok(elements);
157            }
158            Err(_err) => {
159                tracing::warn!(
160                    "Ignoring runtime error from onnx (likely due to encountering blank page)."
161                );
162                return Ok(vec![]);
163            }
164        }
165    }
166
167    fn preprocess(
168        &self,
169        img: &image::DynamicImage,
170    ) -> (u32, u32, ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>) {
171        let (img_width, img_height) = (img.width(), img.height());
172        let img = img.resize_exact(
173            Self::REQUIRED_WIDTH,
174            Self::REQUIRED_HEIGHT,
175            imageops::FilterType::Triangle,
176        );
177        let img_rgb8 = img.into_rgba8();
178
179        let mut input = Array::zeros((3, 1035, 800));
180
181        for pixel in img_rgb8.enumerate_pixels() {
182            let x = pixel.0 as _;
183            let y = pixel.1 as _;
184            let [r, g, b, _] = pixel.2 .0;
185            input[[0, y, x]] = r as f32;
186            input[[1, y, x]] = g as f32;
187            input[[2, y, x]] = b as f32;
188        }
189
190        return (img_width, img_height, input);
191    }
192
193    fn postprocess<'s>(
194        &self,
195        outputs: &SessionOutputs<'s>,
196        img_width: u32,
197        img_height: u32,
198    ) -> Result<Vec<LayoutElement>> {
199        let bboxes = &outputs[0].try_extract_tensor::<f32>()?;
200        let labels = &outputs[1].try_extract_tensor::<i64>()?;
201        let confidence_scores =
202            &outputs[self.confidence_score_index].try_extract_tensor::<f32>()?;
203
204        let width_conversion = img_width as f32 / Self::REQUIRED_WIDTH as f32;
205        let height_conversion = img_height as f32 / Self::REQUIRED_HEIGHT as f32;
206
207        let mut elements = vec![];
208
209        for (bbox, (label, confidence_score)) in bboxes
210            .rows()
211            .into_iter()
212            .zip(labels.iter().zip(confidence_scores))
213        {
214            let [x1, y1, x2, y2] = vec_to_bbox(bbox.iter().copied().collect());
215
216            let detected_label = &self
217                .label_map
218                .iter()
219                .find(|(l_i, _)| l_i == label)
220                .unwrap() // SAFETY: the model always yields one of these labels
221                .1;
222
223            if *confidence_score > self.confidence_threshold as f32 {
224                elements.push(LayoutElement::new(
225                    x1 * width_conversion,
226                    y1 * height_conversion,
227                    x2 * width_conversion,
228                    y2 * height_conversion,
229                    &detected_label,
230                    *confidence_score,
231                    &self.model_name,
232                ))
233            }
234        }
235
236        elements.sort_by(|a, b| a.bbox.max().y.total_cmp(&b.bbox.max().y));
237
238        return Ok(elements);
239    }
240}