comic_text_detector/
lib.rs1use std::thread;
2
3use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
4use hf_hub::api::sync::Api;
5use image::GenericImageView;
6use ort::{inputs, session::Session, value::TensorRef};
7use serde::Serialize;
8
9#[derive(Debug)]
10pub struct ComicTextDetector {
11 model: Session,
12}
13
14#[derive(Debug, Serialize)]
15pub struct Output {
16 pub bboxes: Vec<ClassifiedBbox>,
17 pub segment: Vec<u8>,
18}
19
20#[derive(Debug, Serialize)]
21pub struct ClassifiedBbox {
22 pub xmin: f32,
23 pub ymin: f32,
24 pub xmax: f32,
25 pub ymax: f32,
26 pub confidence: f32,
27 pub class: usize,
28}
29
/// Segmentation mask intensities (on the 0-255 scale) below this value are
/// zeroed before the morphological clean-up.
const MASK_THRESHOLD: u8 = 30;
31
32impl ComicTextDetector {
33 pub fn new() -> anyhow::Result<Self> {
34 let api = Api::new()?;
35 let repo = api.model("mayocream/comic-text-detector-onnx".to_string());
36 let model_path = repo.get("comic-text-detector.onnx")?;
37
38 let model = Session::builder()?
39 .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
40 .with_intra_threads(thread::available_parallelism()?.get())?
41 .commit_from_file(model_path)?;
42
43 Ok(ComicTextDetector { model })
44 }
45
46 pub fn inference(
47 &mut self,
48 image: &image::DynamicImage,
49 confidence_threshold: f32,
50 nms_threshold: f32,
51 ) -> anyhow::Result<Output> {
52 let (orig_width, orig_height) = image.dimensions();
53 let w_ratio = orig_width as f32 / 1024.0;
54 let h_ratio = orig_height as f32 / 1024.0;
55 let image = image.resize_exact(1024, 1024, image::imageops::FilterType::CatmullRom);
56
57 let mut input = ndarray::Array::zeros((1, 3, 1024, 1024));
58 for pixel in image.pixels() {
59 let x = pixel.0 as usize;
60 let y = pixel.1 as usize;
61 let [r, g, b, _] = pixel.2.0;
62 input[[0, 0, y, x]] = (r as f32) / 255.0;
63 input[[0, 1, y, x]] = (g as f32) / 255.0;
64 input[[0, 2, y, x]] = (b as f32) / 255.0;
65 }
66
67 let inputs = inputs!["images" => TensorRef::from_array_view(input.view())?];
68 let outputs = self.model.run(inputs)?;
69
70 let blk = outputs["blk"].try_extract_array::<f32>()?;
72 let blk = blk.view();
73
74 let mut boxes: Vec<Vec<Bbox<_>>> = (0..=1).map(|_| vec![]).collect();
75 for i in 0..blk.shape()[1] {
76 let confidence = blk[[0, i, 4]];
77 if confidence < confidence_threshold {
78 continue;
79 }
80
81 let mut class_index = 0;
82 if blk[[0, i, 5]] < blk[[0, i, 6]] {
83 class_index = 1;
84 }
85
86 let center_x = blk[[0, i, 0]] * w_ratio;
87 let center_y = blk[[0, i, 1]] * h_ratio;
88 let width = blk[[0, i, 2]] * w_ratio;
89 let height = blk[[0, i, 3]] * h_ratio;
90
91 boxes[class_index].push(Bbox {
92 confidence,
93 xmin: center_x - width / 2.,
94 ymin: center_y - height / 2.,
95 xmax: center_x + width / 2.,
96 ymax: center_y + height / 2.,
97 data: (),
98 });
99 }
100
101 non_maximum_suppression(&mut boxes, nms_threshold);
102
103 let mut bboxes: Vec<ClassifiedBbox> = vec![];
105 for (class_index, bboxes_for_class) in boxes.iter().enumerate() {
106 for bbox in bboxes_for_class {
107 bboxes.push(ClassifiedBbox {
108 xmin: bbox.xmin,
109 ymin: bbox.ymin,
110 xmax: bbox.xmax,
111 ymax: bbox.ymax,
112 confidence: bbox.confidence,
113 class: class_index,
114 });
115 }
116 }
117
118 let mask = outputs["seg"].try_extract_array::<f32>()?;
120 let mask = mask
121 .view()
122 .to_owned()
123 .into_dimensionality::<ndarray::Ix4>()?;
124 let mask_slice = mask.slice(ndarray::s![0, 0, .., ..]);
126
127 let thresholded = mask_slice.mapv(|x| {
129 let val = (255.0 * x).round() as u8;
130 if val < MASK_THRESHOLD { 0 } else { val }
131 });
132
133 let (segment, _) = thresholded.into_raw_vec_and_offset();
135 let segment = image::GrayImage::from_vec(1024, 1024, segment)
137 .ok_or_else(|| anyhow::anyhow!("Failed to create GrayImage"))?;
138 let segment = imageproc::morphology::grayscale_dilate(
139 &segment,
140 &imageproc::morphology::Mask::square(3),
141 );
142 let segment =
143 imageproc::morphology::erode(&segment, imageproc::distance_transform::Norm::L2, 1);
144 let segment = segment.into_raw();
145
146 Ok(Output { bboxes, segment })
147 }
148}