layoutparser_ort/models/
yolox.rs

1use image::imageops;
2use itertools::Itertools;
3use ndarray::{
4    concatenate, s, stack, Array, Array1, Array2, ArrayBase, ArrayD, Axis, Dim, IxDyn, OwnedRepr,
5};
6use ort::{Session, SessionBuilder, SessionOutputs};
7
8pub use crate::error::Result;
9use crate::{utils, LayoutElement};
10
11/// A [`YOLOX`](https://github.com/Megvii-BaseDetection/YOLOX)-based model.
12pub struct YOLOXModel {
13    model_name: String,
14    model: ort::Session,
15    is_quantized: bool,
16    label_map: Vec<(i64, String)>,
17}
18
/// Pretrained YOLOX-based models from Hugging Face.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum YOLOXPretrainedModels {
    /// Large model (`yolox_l0.05.onnx`).
    Large,
    /// Quantized large model (`yolox_l0.05_quantized.onnx`).
    LargeQuantized,
    /// Tiny model (`yolox_tiny.onnx`).
    Tiny,
}

impl YOLOXPretrainedModels {
    /// Hugging Face repository shared by every pretrained variant.
    const HF_REPO: &'static str = "unstructuredio/yolo_x_layout";

    /// Class ids and labels shared by every pretrained variant.
    const LABELS: [(i64, &'static str); 11] = [
        (0, "Caption"),
        (1, "Footnote"),
        (2, "Formula"),
        (3, "List-item"),
        (4, "Page-footer"),
        (5, "Page-header"),
        (6, "Picture"),
        (7, "Section-header"),
        (8, "Table"),
        (9, "Text"),
        (10, "Title"),
    ];

    /// Model name.
    ///
    /// Every variant is named after its Hugging Face repository.
    pub fn name(&self) -> &'static str {
        self.hf_repo()
    }

    /// Hugging Face repository for this model.
    pub fn hf_repo(&self) -> &'static str {
        Self::HF_REPO
    }

    /// Path for this model file in Hugging Face repository.
    pub fn hf_filename(&self) -> &'static str {
        match self {
            YOLOXPretrainedModels::Large => "yolox_l0.05.onnx",
            YOLOXPretrainedModels::LargeQuantized => "yolox_l0.05_quantized.onnx",
            YOLOXPretrainedModels::Tiny => "yolox_tiny.onnx",
        }
    }

    /// The label map for this model, as owned `(class id, label)` pairs.
    pub fn label_map(&self) -> Vec<(i64, String)> {
        Self::LABELS
            .iter()
            .map(|&(id, label)| (id, label.to_string()))
            .collect()
    }
}
74
75impl YOLOXModel {
76    /// Required input image width.
77    pub const REQUIRED_WIDTH: u32 = 768;
78    /// Required input image height.
79    pub const REQUIRED_HEIGHT: u32 = 1024;
80
81    /// Construct a [`YOLOXModel`] with a pretrained model downloaded from Hugging Face.
82    pub fn pretrained(p_model: YOLOXPretrainedModels) -> Result<Self> {
83        let session_builder = Session::builder()?;
84        let api = hf_hub::api::sync::Api::new()?;
85        let filename = api
86            .model(p_model.hf_repo().to_string())
87            .get(p_model.hf_filename())?;
88
89        let model = session_builder.commit_from_file(filename)?;
90
91        Ok(Self {
92            model_name: p_model.name().to_string(),
93            model,
94            label_map: p_model.label_map(),
95            is_quantized: p_model == YOLOXPretrainedModels::LargeQuantized,
96        })
97    }
98
99    /// Construct a configured [`YOLOXModel`] with a pretrained model downloaded from Hugging Face.
100    pub fn configure_pretrained(
101        p_model: YOLOXPretrainedModels,
102        session_builder: SessionBuilder,
103    ) -> Result<Self> {
104        let api = hf_hub::api::sync::Api::new()?;
105        let filename = api
106            .model(p_model.hf_repo().to_string())
107            .get(p_model.hf_filename())?;
108
109        let model = session_builder.commit_from_file(filename)?;
110
111        Ok(Self {
112            model_name: p_model.name().to_string(),
113            model,
114            label_map: p_model.label_map(),
115            is_quantized: p_model == YOLOXPretrainedModels::LargeQuantized,
116        })
117    }
118
119    /// Construct a [`YOLOXModel`] from a model file.
120    pub fn new_from_file(
121        file_path: &str,
122        model_name: &str,
123        label_map: &[(i64, &str)],
124        is_quantized: bool,
125        session_builder: SessionBuilder,
126    ) -> Result<Self> {
127        let model = session_builder.commit_from_file(file_path)?;
128
129        Ok(Self {
130            model_name: model_name.to_string(),
131            model,
132            label_map: label_map.iter().map(|(i, l)| (*i, l.to_string())).collect(),
133            is_quantized,
134        })
135    }
136
137    /// Predict [`LayoutElement`]s from the image provided.
138    pub fn predict(&self, img: &image::DynamicImage) -> Result<Vec<LayoutElement>> {
139        // UNWRAP SAFETY: shape unwraps are never a problem because we know the size of the output tensor
140        let (input, r) = self.preprocess(img);
141
142        let input_name = &self.model.inputs[0].name;
143
144        let run_result = self.model.run(ort::inputs![input_name => input]?);
145        match run_result {
146            Ok(outputs) => {
147                let predictions = self
148                    .postprocess(&outputs, false)?
149                    .slice(s![0, .., ..])
150                    .to_owned();
151
152                let boxes = predictions
153                    .slice(s![.., 0..4])
154                    .to_shape([16128, 4])
155                    .unwrap()
156                    .to_owned();
157                let scores = predictions
158                    .slice(s![.., 4..5])
159                    .to_shape([16128, 1])
160                    .unwrap()
161                    .to_owned()
162                    * predictions.slice(s![.., 5..]);
163
164                let mut boxes_xyxy: Array<f32, _> = ndarray::Array::ones([16128, 4]);
165
166                let s0 =
167                    boxes.slice(s![.., 0]).to_owned() - (boxes.slice(s![.., 2]).to_owned() / 2.0);
168                let s1 =
169                    boxes.slice(s![.., 1]).to_owned() - (boxes.slice(s![.., 3]).to_owned() / 2.0);
170                let s2 =
171                    boxes.slice(s![.., 0]).to_owned() + (boxes.slice(s![.., 2]).to_owned() / 2.0);
172                let s3 =
173                    boxes.slice(s![.., 1]).to_owned() + (boxes.slice(s![.., 3]).to_owned() / 2.0);
174
175                boxes_xyxy
176                    .slice_mut(s![.., 0])
177                    .iter_mut()
178                    .zip_eq(s0.iter())
179                    .for_each(|(old, new)| *old = *new);
180                boxes_xyxy
181                    .slice_mut(s![.., 1])
182                    .iter_mut()
183                    .zip_eq(s1.iter())
184                    .for_each(|(old, new)| *old = *new);
185                boxes_xyxy
186                    .slice_mut(s![.., 2])
187                    .iter_mut()
188                    .zip_eq(s2.iter())
189                    .for_each(|(old, new)| *old = *new);
190                boxes_xyxy
191                    .slice_mut(s![.., 3])
192                    .iter_mut()
193                    .zip_eq(s3.iter())
194                    .for_each(|(old, new)| *old = *new);
195
196                boxes_xyxy /= r;
197
198                let mut regions = vec![];
199
200                let (nms_thr, score_thr) = if self.is_quantized {
201                    (0.0, 0.07)
202                } else {
203                    (0.1, 0.25)
204                };
205
206                let dets = multiclass_nms_class_agnostic(&boxes_xyxy, &scores, nms_thr, score_thr);
207
208                for det in dets.outer_iter() {
209                    let [x1, y1, x2, y2, prob, class_id] =
210                        extract_bbox_etc(&det.into_iter().copied().collect());
211                    let detected_class = self.get_label(class_id as i64);
212                    regions.push(LayoutElement::new(
213                        x1,
214                        y1,
215                        x2,
216                        y2,
217                        &detected_class,
218                        prob,
219                        &self.model_name,
220                    ));
221                }
222
223                regions.sort_by(|a, b| a.bbox.max().y.total_cmp(&b.bbox.max().y));
224
225                return Ok(regions);
226            }
227            Err(_err) => {
228                eprintln!("{_err:?}");
229                tracing::warn!(
230                    "Ignoring runtime error from onnx (likely due to encountering blank page)."
231                );
232                return Ok(vec![]);
233            }
234        }
235    }
236
237    fn postprocess<'s>(
238        &self,
239        outputs: &SessionOutputs<'s>,
240        p6: bool,
241    ) -> Result<Array<f32, Dim<[usize; 3]>>> {
242        let output_m = &outputs[0].try_extract_tensor::<f32>()?;
243        let mut shaped_output = output_m.to_shape([1, 16128, 16]).unwrap().to_owned();
244
245        let strides = if !p6 {
246            vec![8, 16, 32]
247        } else {
248            vec![8, 16, 32, 64]
249        };
250
251        let hsizes: Vec<u32> = strides.iter().map(|s| Self::REQUIRED_HEIGHT / s).collect();
252        let wsizes: Vec<u32> = strides.iter().map(|s| Self::REQUIRED_WIDTH / s).collect();
253
254        let mut grids = vec![];
255        let mut expanded_strides = vec![];
256
257        for (stride, (hsize, wsize)) in strides.iter().zip(hsizes.iter().zip(wsizes.iter())) {
258            let meshgrid_res = meshgrid(
259                &[Array1::from_iter(0..*wsize), Array1::from_iter(0..*hsize)],
260                Indexing::Xy,
261            );
262            let xv = meshgrid_res[0].to_owned();
263            let yv = meshgrid_res[1].to_owned();
264
265            let grid = stack![Axis(2), xv, yv]
266                .to_shape((1, (hsize * wsize) as usize, 2))
267                .unwrap()
268                .to_owned();
269
270            let shape_1 = &grid.shape()[0..2];
271            expanded_strides.push(Array::from_elem((shape_1[0], shape_1[1], 1), stride));
272
273            grids.push(grid);
274        }
275
276        let grids =
277            ndarray::concatenate(Axis(1), &grids.iter().map(|g| g.view()).collect::<Vec<_>>())
278                .unwrap();
279        let expanded_strides = ndarray::concatenate(
280            Axis(1),
281            &expanded_strides
282                .iter()
283                .map(|g| g.view())
284                .collect::<Vec<_>>(),
285        )
286        .unwrap();
287
288        let s1 = (shaped_output.slice(s![.., .., 0..2]).to_owned() + grids.mapv(|e| e as f32))
289            * expanded_strides.mapv(|e| *e as f32);
290        let s2 = (shaped_output
291            .slice(s![.., .., 2..4])
292            .mapv(|e| e.exp())
293            .to_owned())
294            * expanded_strides.mapv(|e| *e as f32);
295
296        shaped_output
297            .slice_mut(s![.., .., 0..2])
298            .into_iter()
299            .zip_eq(s1.into_iter())
300            .for_each(|(old, new)| {
301                *old = new;
302            });
303
304        shaped_output
305            .slice_mut(s![.., .., 2..4])
306            .into_iter()
307            .zip_eq(s2.into_iter())
308            .for_each(|(old, new)| {
309                *old = new;
310            });
311
312        Ok(shaped_output)
313    }
314
315    fn preprocess(
316        &self,
317        img: &image::DynamicImage,
318    ) -> (ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>, f32) {
319        let (img_width, img_height) = (img.width(), img.height());
320
321        let mut padded_img: ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> = Array::ones((
322            1,
323            3,
324            Self::REQUIRED_HEIGHT as usize,
325            Self::REQUIRED_WIDTH as usize,
326        )) * 114_f32;
327
328        let r: f64 = f64::min(
329            Self::REQUIRED_HEIGHT as f64 / img_height as f64,
330            Self::REQUIRED_WIDTH as f64 / img_width as f64,
331        );
332
333        let resized_img = img.resize_exact(
334            (img_width as f64 * r) as u32,
335            (img_height as f64 * r) as u32,
336            imageops::FilterType::Triangle,
337        );
338
339        for pixel in resized_img.into_rgba8().enumerate_pixels() {
340            let x = pixel.0 as _;
341            let y = pixel.1 as _;
342            let [r, g, b, _] = pixel.2 .0;
343            padded_img[[0, 0, y, x]] = r as f32;
344            padded_img[[0, 1, y, x]] = g as f32;
345            padded_img[[0, 2, y, x]] = b as f32;
346        }
347
348        (padded_img, r as f32)
349    }
350
351    fn get_label(&self, label_id: i64) -> String {
352        self.label_map
353            .iter()
354            .find(|(l_i, _)| l_i == &label_id)
355            .unwrap()
356            .1
357            .clone()
358    }
359}
360
361fn multiclass_nms_class_agnostic(
362    boxes: &Array<f32, Dim<[usize; 2]>>,
363    scores: &Array<f32, Dim<[usize; 2]>>,
364    nms_thr: f32,
365    score_thr: f32,
366) -> Array2<f32> {
367    let cls_inds = Array1::from_iter(scores.axis_iter(Axis(0)).map(|e| {
368        let (max_i, _max) = e.iter().enumerate().fold((0_usize, 0_f32), |acc, (i, e)| {
369            let (max_i, max) = acc;
370            if *e > max {
371                (i, *e)
372            } else {
373                (max_i, max)
374            }
375        });
376        max_i
377    }));
378
379    let cls_scores = Array1::from_iter(
380        scores
381            .axis_iter(Axis(0))
382            .zip_eq(cls_inds.iter())
383            .map(|(e, i)| e[*i]),
384    );
385
386    let valid_score_mask = cls_scores.mapv(|s| s > score_thr);
387    let valid_scores = Array1::from_iter(
388        cls_scores
389            .iter()
390            .zip_eq(valid_score_mask.iter())
391            .filter(|(_, b)| **b)
392            .map(|(s, _)| *s),
393    );
394
395    let valid_boxes: Array2<f32> = to_array2(
396        &boxes
397            .outer_iter()
398            .zip_eq(valid_score_mask.iter())
399            .filter(|(_, b)| **b)
400            .map(|(s, _)| s.to_owned())
401            .collect::<Vec<_>>(),
402    )
403    .unwrap();
404
405    let valid_cls_inds = Array1::from_iter(
406        cls_inds
407            .iter()
408            .zip_eq(valid_score_mask.iter())
409            .filter(|(_, b)| **b)
410            .map(|(s, _)| s)
411            .collect::<Vec<_>>(),
412    );
413
414    let keep = nms(&valid_boxes.to_owned(), &valid_scores, nms_thr);
415
416    let valid_boxes_vec: Vec<_> = valid_boxes.outer_iter().collect();
417    let valid_boxes_kept = to_array2(
418        &keep
419            .iter()
420            .map(|i| valid_boxes_vec[*i])
421            .map(|e| e.to_owned())
422            .collect::<Vec<_>>(),
423    )
424    .unwrap();
425
426    let valid_scores_vec: Vec<_> = valid_scores.into_iter().collect();
427    let valid_scores_kept = to_array2(
428        &keep
429            .iter()
430            .map(|i| valid_scores_vec[*i])
431            .map(|e| Array1::from_elem(1, e))
432            .collect::<Vec<_>>(),
433    )
434    .unwrap();
435
436    let valid_cls_inds_vec: Vec<_> = valid_cls_inds.into_iter().collect();
437    let valid_cls_inds_kept = to_array2(
438        &keep
439            .iter()
440            .map(|i| valid_cls_inds_vec[*i])
441            .map(|e| Array1::from_elem(1, e))
442            .collect::<Vec<_>>(),
443    )
444    .unwrap();
445
446    let dets = concatenate(
447        Axis(1),
448        &[
449            valid_boxes_kept.view(),
450            valid_scores_kept.view(),
451            valid_cls_inds_kept.mapv(|e| *e as f32).view(),
452        ],
453    )
454    .unwrap();
455
456    return dets;
457}
458
459fn nms(
460    boxes: &Array<f32, Dim<[usize; 2]>>,
461    scores: &Array<f32, Dim<[usize; 1]>>,
462    nms_thr: f32,
463) -> Vec<usize> {
464    let x1 = boxes.slice(s![.., 0]);
465    let y1 = boxes.slice(s![.., 1]);
466    let x2 = boxes.slice(s![.., 2]);
467    let y2 = boxes.slice(s![.., 3]);
468
469    let areas = (&x2 - &x1 + 1_f32) * (&y2 - &y1 + 1_f32);
470    let mut order = {
471        let mut o = utils::argsort_by(&scores, |a, b| a.partial_cmp(b).unwrap());
472        o.reverse();
473        o
474    };
475
476    let mut keep = vec![];
477
478    while !order.is_empty() {
479        let i = order[0];
480        keep.push(i);
481
482        let order_sliced = Array1::from_iter(order.iter().skip(1));
483
484        let xx1 = order_sliced.mapv(|o_i| f32::max(x1[i], x1[*o_i]));
485        let yy1 = order_sliced.mapv(|o_i| f32::max(y1[i], y1[*o_i]));
486        let xx2 = order_sliced.mapv(|o_i| f32::min(x2[i], x2[*o_i]));
487        let yy2 = order_sliced.mapv(|o_i| f32::min(y2[i], y2[*o_i]));
488
489        let w = ((&xx2 - &xx1) + 1_f32).mapv(|v| f32::max(0.0, v));
490        let h = ((&yy2 - &yy1) + 1_f32).mapv(|v| f32::max(0.0, v));
491        let inter = w * h;
492        let ovr = &inter / (areas[i] + order_sliced.mapv(|e| areas[*e]) - &inter);
493
494        let inds = Array1::from_iter(
495            ovr.iter()
496                .map(|e| *e <= nms_thr)
497                .enumerate()
498                .filter(|(_, p)| *p)
499                .map(|(i, _)| i),
500        );
501
502        drop(order_sliced);
503
504        order = inds.into_iter().map(|i| order[i + 1]).collect();
505    }
506
507    return keep;
508}
509
510fn to_array2<T: Copy>(source: &[Array1<T>]) -> Result<Array2<T>, impl std::error::Error> {
511    let width = source.len();
512    let flattened: Array1<T> = source.into_iter().flat_map(|row| row.to_vec()).collect();
513    let height = if width == 0 {
514        flattened.len()
515    } else {
516        flattened.len() / width
517    };
518    flattened.into_shape((width, height))
519}
520
/// Copies the first six values of `v` into `[x1, y1, x2, y2, prob, class_id]`.
///
/// # Panics
///
/// Panics if `v` holds fewer than six values.
fn extract_bbox_etc(v: &Vec<f32>) -> [f32; 6] {
    let mut fields = [0_f32; 6];
    fields.copy_from_slice(&v[..6]);
    fields
}
525
// from: https://github.com/jreniel/meshgridrs (licensed under MIT)
/// Axis convention used by [`meshgrid`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Indexing {
    /// The first two axes of every returned grid are swapped.
    Xy,
    /// Grids keep the axis order of the input arrays.
    Ij,
}
532// from: https://github.com/jreniel/meshgridrs (licensed under MIT)
533pub(crate) fn meshgrid<T>(
534    xi: &[Array1<T>],
535    indexing: Indexing,
536) -> Vec<ArrayBase<OwnedRepr<T>, Dim<ndarray::IxDynImpl>>>
537where
538    T: Copy,
539{
540    let ndim = xi.len();
541    let product = xi.iter().map(|x| x.iter()).multi_cartesian_product();
542
543    let mut grids: Vec<ArrayD<T>> = Vec::with_capacity(ndim);
544
545    for (dim_index, _) in xi.iter().enumerate() {
546        // Generate a flat vector with the correct repeated pattern
547        let values: Vec<T> = product.clone().map(|p| *p[dim_index]).collect();
548
549        let mut grid_shape: Vec<usize> = vec![1; ndim];
550        grid_shape[dim_index] = xi[dim_index].len();
551
552        // Determine the correct repetition for each dimension
553        for (j, len) in xi.iter().map(|x| x.len()).enumerate() {
554            if j != dim_index {
555                grid_shape[j] = len;
556            }
557        }
558
559        let grid = Array::from_shape_vec(IxDyn(&grid_shape), values).unwrap();
560        grids.push(grid);
561    }
562
563    // Swap axes for "xy" indexing
564    if matches!(indexing, Indexing::Xy) && ndim > 1 {
565        for grid in &mut grids {
566            grid.swap_axes(0, 1);
567        }
568    }
569
570    grids
571}