1use image::imageops;
2use itertools::Itertools;
3use ndarray::{
4 concatenate, s, stack, Array, Array1, Array2, ArrayBase, ArrayD, Axis, Dim, IxDyn, OwnedRepr,
5};
6use ort::{Session, SessionBuilder, SessionOutputs};
7
8pub use crate::error::Result;
9use crate::{utils, LayoutElement};
10
11pub struct YOLOXModel {
13 model_name: String,
14 model: ort::Session,
15 is_quantized: bool,
16 label_map: Vec<(i64, String)>,
17}
18
/// Pretrained YOLOX layout checkpoints that can be downloaded from the Hugging Face hub.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum YOLOXPretrainedModels {
    /// The `yolox_l0.05.onnx` checkpoint.
    Large,
    /// The `yolox_l0.05_quantized.onnx` checkpoint.
    LargeQuantized,
    /// The `yolox_tiny.onnx` checkpoint.
    Tiny,
}
26
27impl YOLOXPretrainedModels {
28 pub fn name(&self) -> &str {
30 match self {
31 _ => self.hf_repo(),
32 }
33 }
34
35 pub fn hf_repo(&self) -> &str {
37 match self {
38 _ => "unstructuredio/yolo_x_layout",
39 }
40 }
41
42 pub fn hf_filename(&self) -> &str {
44 match self {
45 YOLOXPretrainedModels::Large => "yolox_l0.05.onnx",
46 YOLOXPretrainedModels::LargeQuantized => "yolox_l0.05_quantized.onnx",
47 YOLOXPretrainedModels::Tiny => "yolox_tiny.onnx",
48 }
49 }
50
51 pub fn label_map(&self) -> Vec<(i64, String)> {
53 match self {
54 _ => Vec::from_iter(
55 [
56 (0, "Caption"),
57 (1, "Footnote"),
58 (2, "Formula"),
59 (3, "List-item"),
60 (4, "Page-footer"),
61 (5, "Page-header"),
62 (6, "Picture"),
63 (7, "Section-header"),
64 (8, "Table"),
65 (9, "Text"),
66 (10, "Title"),
67 ]
68 .iter()
69 .map(|(i, l)| (*i as i64, l.to_string())),
70 ),
71 }
72 }
73}
74
75impl YOLOXModel {
76 pub const REQUIRED_WIDTH: u32 = 768;
78 pub const REQUIRED_HEIGHT: u32 = 1024;
80
81 pub fn pretrained(p_model: YOLOXPretrainedModels) -> Result<Self> {
83 let session_builder = Session::builder()?;
84 let api = hf_hub::api::sync::Api::new()?;
85 let filename = api
86 .model(p_model.hf_repo().to_string())
87 .get(p_model.hf_filename())?;
88
89 let model = session_builder.commit_from_file(filename)?;
90
91 Ok(Self {
92 model_name: p_model.name().to_string(),
93 model,
94 label_map: p_model.label_map(),
95 is_quantized: p_model == YOLOXPretrainedModels::LargeQuantized,
96 })
97 }
98
99 pub fn configure_pretrained(
101 p_model: YOLOXPretrainedModels,
102 session_builder: SessionBuilder,
103 ) -> Result<Self> {
104 let api = hf_hub::api::sync::Api::new()?;
105 let filename = api
106 .model(p_model.hf_repo().to_string())
107 .get(p_model.hf_filename())?;
108
109 let model = session_builder.commit_from_file(filename)?;
110
111 Ok(Self {
112 model_name: p_model.name().to_string(),
113 model,
114 label_map: p_model.label_map(),
115 is_quantized: p_model == YOLOXPretrainedModels::LargeQuantized,
116 })
117 }
118
119 pub fn new_from_file(
121 file_path: &str,
122 model_name: &str,
123 label_map: &[(i64, &str)],
124 is_quantized: bool,
125 session_builder: SessionBuilder,
126 ) -> Result<Self> {
127 let model = session_builder.commit_from_file(file_path)?;
128
129 Ok(Self {
130 model_name: model_name.to_string(),
131 model,
132 label_map: label_map.iter().map(|(i, l)| (*i, l.to_string())).collect(),
133 is_quantized,
134 })
135 }
136
137 pub fn predict(&self, img: &image::DynamicImage) -> Result<Vec<LayoutElement>> {
139 let (input, r) = self.preprocess(img);
141
142 let input_name = &self.model.inputs[0].name;
143
144 let run_result = self.model.run(ort::inputs![input_name => input]?);
145 match run_result {
146 Ok(outputs) => {
147 let predictions = self
148 .postprocess(&outputs, false)?
149 .slice(s![0, .., ..])
150 .to_owned();
151
152 let boxes = predictions
153 .slice(s![.., 0..4])
154 .to_shape([16128, 4])
155 .unwrap()
156 .to_owned();
157 let scores = predictions
158 .slice(s![.., 4..5])
159 .to_shape([16128, 1])
160 .unwrap()
161 .to_owned()
162 * predictions.slice(s![.., 5..]);
163
164 let mut boxes_xyxy: Array<f32, _> = ndarray::Array::ones([16128, 4]);
165
166 let s0 =
167 boxes.slice(s![.., 0]).to_owned() - (boxes.slice(s![.., 2]).to_owned() / 2.0);
168 let s1 =
169 boxes.slice(s![.., 1]).to_owned() - (boxes.slice(s![.., 3]).to_owned() / 2.0);
170 let s2 =
171 boxes.slice(s![.., 0]).to_owned() + (boxes.slice(s![.., 2]).to_owned() / 2.0);
172 let s3 =
173 boxes.slice(s![.., 1]).to_owned() + (boxes.slice(s![.., 3]).to_owned() / 2.0);
174
175 boxes_xyxy
176 .slice_mut(s![.., 0])
177 .iter_mut()
178 .zip_eq(s0.iter())
179 .for_each(|(old, new)| *old = *new);
180 boxes_xyxy
181 .slice_mut(s![.., 1])
182 .iter_mut()
183 .zip_eq(s1.iter())
184 .for_each(|(old, new)| *old = *new);
185 boxes_xyxy
186 .slice_mut(s![.., 2])
187 .iter_mut()
188 .zip_eq(s2.iter())
189 .for_each(|(old, new)| *old = *new);
190 boxes_xyxy
191 .slice_mut(s![.., 3])
192 .iter_mut()
193 .zip_eq(s3.iter())
194 .for_each(|(old, new)| *old = *new);
195
196 boxes_xyxy /= r;
197
198 let mut regions = vec![];
199
200 let (nms_thr, score_thr) = if self.is_quantized {
201 (0.0, 0.07)
202 } else {
203 (0.1, 0.25)
204 };
205
206 let dets = multiclass_nms_class_agnostic(&boxes_xyxy, &scores, nms_thr, score_thr);
207
208 for det in dets.outer_iter() {
209 let [x1, y1, x2, y2, prob, class_id] =
210 extract_bbox_etc(&det.into_iter().copied().collect());
211 let detected_class = self.get_label(class_id as i64);
212 regions.push(LayoutElement::new(
213 x1,
214 y1,
215 x2,
216 y2,
217 &detected_class,
218 prob,
219 &self.model_name,
220 ));
221 }
222
223 regions.sort_by(|a, b| a.bbox.max().y.total_cmp(&b.bbox.max().y));
224
225 return Ok(regions);
226 }
227 Err(_err) => {
228 eprintln!("{_err:?}");
229 tracing::warn!(
230 "Ignoring runtime error from onnx (likely due to encountering blank page)."
231 );
232 return Ok(vec![]);
233 }
234 }
235 }
236
237 fn postprocess<'s>(
238 &self,
239 outputs: &SessionOutputs<'s>,
240 p6: bool,
241 ) -> Result<Array<f32, Dim<[usize; 3]>>> {
242 let output_m = &outputs[0].try_extract_tensor::<f32>()?;
243 let mut shaped_output = output_m.to_shape([1, 16128, 16]).unwrap().to_owned();
244
245 let strides = if !p6 {
246 vec![8, 16, 32]
247 } else {
248 vec![8, 16, 32, 64]
249 };
250
251 let hsizes: Vec<u32> = strides.iter().map(|s| Self::REQUIRED_HEIGHT / s).collect();
252 let wsizes: Vec<u32> = strides.iter().map(|s| Self::REQUIRED_WIDTH / s).collect();
253
254 let mut grids = vec![];
255 let mut expanded_strides = vec![];
256
257 for (stride, (hsize, wsize)) in strides.iter().zip(hsizes.iter().zip(wsizes.iter())) {
258 let meshgrid_res = meshgrid(
259 &[Array1::from_iter(0..*wsize), Array1::from_iter(0..*hsize)],
260 Indexing::Xy,
261 );
262 let xv = meshgrid_res[0].to_owned();
263 let yv = meshgrid_res[1].to_owned();
264
265 let grid = stack![Axis(2), xv, yv]
266 .to_shape((1, (hsize * wsize) as usize, 2))
267 .unwrap()
268 .to_owned();
269
270 let shape_1 = &grid.shape()[0..2];
271 expanded_strides.push(Array::from_elem((shape_1[0], shape_1[1], 1), stride));
272
273 grids.push(grid);
274 }
275
276 let grids =
277 ndarray::concatenate(Axis(1), &grids.iter().map(|g| g.view()).collect::<Vec<_>>())
278 .unwrap();
279 let expanded_strides = ndarray::concatenate(
280 Axis(1),
281 &expanded_strides
282 .iter()
283 .map(|g| g.view())
284 .collect::<Vec<_>>(),
285 )
286 .unwrap();
287
288 let s1 = (shaped_output.slice(s![.., .., 0..2]).to_owned() + grids.mapv(|e| e as f32))
289 * expanded_strides.mapv(|e| *e as f32);
290 let s2 = (shaped_output
291 .slice(s![.., .., 2..4])
292 .mapv(|e| e.exp())
293 .to_owned())
294 * expanded_strides.mapv(|e| *e as f32);
295
296 shaped_output
297 .slice_mut(s![.., .., 0..2])
298 .into_iter()
299 .zip_eq(s1.into_iter())
300 .for_each(|(old, new)| {
301 *old = new;
302 });
303
304 shaped_output
305 .slice_mut(s![.., .., 2..4])
306 .into_iter()
307 .zip_eq(s2.into_iter())
308 .for_each(|(old, new)| {
309 *old = new;
310 });
311
312 Ok(shaped_output)
313 }
314
315 fn preprocess(
316 &self,
317 img: &image::DynamicImage,
318 ) -> (ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>, f32) {
319 let (img_width, img_height) = (img.width(), img.height());
320
321 let mut padded_img: ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> = Array::ones((
322 1,
323 3,
324 Self::REQUIRED_HEIGHT as usize,
325 Self::REQUIRED_WIDTH as usize,
326 )) * 114_f32;
327
328 let r: f64 = f64::min(
329 Self::REQUIRED_HEIGHT as f64 / img_height as f64,
330 Self::REQUIRED_WIDTH as f64 / img_width as f64,
331 );
332
333 let resized_img = img.resize_exact(
334 (img_width as f64 * r) as u32,
335 (img_height as f64 * r) as u32,
336 imageops::FilterType::Triangle,
337 );
338
339 for pixel in resized_img.into_rgba8().enumerate_pixels() {
340 let x = pixel.0 as _;
341 let y = pixel.1 as _;
342 let [r, g, b, _] = pixel.2 .0;
343 padded_img[[0, 0, y, x]] = r as f32;
344 padded_img[[0, 1, y, x]] = g as f32;
345 padded_img[[0, 2, y, x]] = b as f32;
346 }
347
348 (padded_img, r as f32)
349 }
350
351 fn get_label(&self, label_id: i64) -> String {
352 self.label_map
353 .iter()
354 .find(|(l_i, _)| l_i == &label_id)
355 .unwrap()
356 .1
357 .clone()
358 }
359}
360
361fn multiclass_nms_class_agnostic(
362 boxes: &Array<f32, Dim<[usize; 2]>>,
363 scores: &Array<f32, Dim<[usize; 2]>>,
364 nms_thr: f32,
365 score_thr: f32,
366) -> Array2<f32> {
367 let cls_inds = Array1::from_iter(scores.axis_iter(Axis(0)).map(|e| {
368 let (max_i, _max) = e.iter().enumerate().fold((0_usize, 0_f32), |acc, (i, e)| {
369 let (max_i, max) = acc;
370 if *e > max {
371 (i, *e)
372 } else {
373 (max_i, max)
374 }
375 });
376 max_i
377 }));
378
379 let cls_scores = Array1::from_iter(
380 scores
381 .axis_iter(Axis(0))
382 .zip_eq(cls_inds.iter())
383 .map(|(e, i)| e[*i]),
384 );
385
386 let valid_score_mask = cls_scores.mapv(|s| s > score_thr);
387 let valid_scores = Array1::from_iter(
388 cls_scores
389 .iter()
390 .zip_eq(valid_score_mask.iter())
391 .filter(|(_, b)| **b)
392 .map(|(s, _)| *s),
393 );
394
395 let valid_boxes: Array2<f32> = to_array2(
396 &boxes
397 .outer_iter()
398 .zip_eq(valid_score_mask.iter())
399 .filter(|(_, b)| **b)
400 .map(|(s, _)| s.to_owned())
401 .collect::<Vec<_>>(),
402 )
403 .unwrap();
404
405 let valid_cls_inds = Array1::from_iter(
406 cls_inds
407 .iter()
408 .zip_eq(valid_score_mask.iter())
409 .filter(|(_, b)| **b)
410 .map(|(s, _)| s)
411 .collect::<Vec<_>>(),
412 );
413
414 let keep = nms(&valid_boxes.to_owned(), &valid_scores, nms_thr);
415
416 let valid_boxes_vec: Vec<_> = valid_boxes.outer_iter().collect();
417 let valid_boxes_kept = to_array2(
418 &keep
419 .iter()
420 .map(|i| valid_boxes_vec[*i])
421 .map(|e| e.to_owned())
422 .collect::<Vec<_>>(),
423 )
424 .unwrap();
425
426 let valid_scores_vec: Vec<_> = valid_scores.into_iter().collect();
427 let valid_scores_kept = to_array2(
428 &keep
429 .iter()
430 .map(|i| valid_scores_vec[*i])
431 .map(|e| Array1::from_elem(1, e))
432 .collect::<Vec<_>>(),
433 )
434 .unwrap();
435
436 let valid_cls_inds_vec: Vec<_> = valid_cls_inds.into_iter().collect();
437 let valid_cls_inds_kept = to_array2(
438 &keep
439 .iter()
440 .map(|i| valid_cls_inds_vec[*i])
441 .map(|e| Array1::from_elem(1, e))
442 .collect::<Vec<_>>(),
443 )
444 .unwrap();
445
446 let dets = concatenate(
447 Axis(1),
448 &[
449 valid_boxes_kept.view(),
450 valid_scores_kept.view(),
451 valid_cls_inds_kept.mapv(|e| *e as f32).view(),
452 ],
453 )
454 .unwrap();
455
456 return dets;
457}
458
459fn nms(
460 boxes: &Array<f32, Dim<[usize; 2]>>,
461 scores: &Array<f32, Dim<[usize; 1]>>,
462 nms_thr: f32,
463) -> Vec<usize> {
464 let x1 = boxes.slice(s![.., 0]);
465 let y1 = boxes.slice(s![.., 1]);
466 let x2 = boxes.slice(s![.., 2]);
467 let y2 = boxes.slice(s![.., 3]);
468
469 let areas = (&x2 - &x1 + 1_f32) * (&y2 - &y1 + 1_f32);
470 let mut order = {
471 let mut o = utils::argsort_by(&scores, |a, b| a.partial_cmp(b).unwrap());
472 o.reverse();
473 o
474 };
475
476 let mut keep = vec![];
477
478 while !order.is_empty() {
479 let i = order[0];
480 keep.push(i);
481
482 let order_sliced = Array1::from_iter(order.iter().skip(1));
483
484 let xx1 = order_sliced.mapv(|o_i| f32::max(x1[i], x1[*o_i]));
485 let yy1 = order_sliced.mapv(|o_i| f32::max(y1[i], y1[*o_i]));
486 let xx2 = order_sliced.mapv(|o_i| f32::min(x2[i], x2[*o_i]));
487 let yy2 = order_sliced.mapv(|o_i| f32::min(y2[i], y2[*o_i]));
488
489 let w = ((&xx2 - &xx1) + 1_f32).mapv(|v| f32::max(0.0, v));
490 let h = ((&yy2 - &yy1) + 1_f32).mapv(|v| f32::max(0.0, v));
491 let inter = w * h;
492 let ovr = &inter / (areas[i] + order_sliced.mapv(|e| areas[*e]) - &inter);
493
494 let inds = Array1::from_iter(
495 ovr.iter()
496 .map(|e| *e <= nms_thr)
497 .enumerate()
498 .filter(|(_, p)| *p)
499 .map(|(i, _)| i),
500 );
501
502 drop(order_sliced);
503
504 order = inds.into_iter().map(|i| order[i + 1]).collect();
505 }
506
507 return keep;
508}
509
510fn to_array2<T: Copy>(source: &[Array1<T>]) -> Result<Array2<T>, impl std::error::Error> {
511 let width = source.len();
512 let flattened: Array1<T> = source.into_iter().flat_map(|row| row.to_vec()).collect();
513 let height = if width == 0 {
514 flattened.len()
515 } else {
516 flattened.len() / width
517 };
518 flattened.into_shape((width, height))
519}
520
/// Copies the first six values of a detection row: `[x1, y1, x2, y2, score, class_id]`.
///
/// # Panics
/// Panics if `v` has fewer than six elements.
// Kept as `&Vec<f32>`: callers may pass `&iter.collect()`, which needs a concrete
// collection type to infer.
#[allow(clippy::ptr_arg)]
fn extract_bbox_etc(v: &Vec<f32>) -> [f32; 6] {
    std::array::from_fn(|idx| v[idx])
}
525
/// Output layout for [`meshgrid`], analogous to NumPy's `meshgrid(indexing=...)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Indexing {
    /// Cartesian layout: the first two output axes are swapped.
    Xy,
    /// Matrix layout: output axes follow the input order.
    Ij,
}
532pub(crate) fn meshgrid<T>(
534 xi: &[Array1<T>],
535 indexing: Indexing,
536) -> Vec<ArrayBase<OwnedRepr<T>, Dim<ndarray::IxDynImpl>>>
537where
538 T: Copy,
539{
540 let ndim = xi.len();
541 let product = xi.iter().map(|x| x.iter()).multi_cartesian_product();
542
543 let mut grids: Vec<ArrayD<T>> = Vec::with_capacity(ndim);
544
545 for (dim_index, _) in xi.iter().enumerate() {
546 let values: Vec<T> = product.clone().map(|p| *p[dim_index]).collect();
548
549 let mut grid_shape: Vec<usize> = vec![1; ndim];
550 grid_shape[dim_index] = xi[dim_index].len();
551
552 for (j, len) in xi.iter().map(|x| x.len()).enumerate() {
554 if j != dim_index {
555 grid_shape[j] = len;
556 }
557 }
558
559 let grid = Array::from_shape_vec(IxDyn(&grid_shape), values).unwrap();
560 grids.push(grid);
561 }
562
563 if matches!(indexing, Indexing::Xy) && ndim > 1 {
565 for grid in &mut grids {
566 grid.swap_axes(0, 1);
567 }
568 }
569
570 grids
571}