comic_text_detector/
lib.rs1use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
2use hf_hub::api::sync::Api;
3use image::{DynamicImage, GenericImageView};
4use ort::{inputs, session::Session, value::TensorRef};
5
/// Comic text detector backed by an ONNX Runtime [`Session`].
///
/// Construct it with [`ComicTextDetector::new`] and run it with
/// [`ComicTextDetector::inference`].
#[derive(Debug)]
pub struct ComicTextDetector {
    // ONNX Runtime session built from `comic-text-detector.onnx`.
    model: Session,
}
10
/// Result of [`ComicTextDetector::inference`].
#[derive(Debug)]
pub struct Output {
    /// Text-region boxes that survived confidence filtering and per-class NMS,
    /// in original-image pixel coordinates.
    pub bboxes: Vec<ClassifiedBbox>,
    /// Grayscale text-segmentation mask, resized to the input image's dimensions.
    pub segment: DynamicImage,
}
15}
16
/// A detected text region, in original-image pixel coordinates.
///
/// Coordinates are not clamped, so a box may extend slightly past the image edges.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ClassifiedBbox {
    /// Left edge (x of the top-left corner).
    pub xmin: f32,
    /// Top edge (y of the top-left corner).
    pub ymin: f32,
    /// Right edge (x of the bottom-right corner).
    pub xmax: f32,
    /// Bottom edge (y of the bottom-right corner).
    pub ymax: f32,
    /// Detection score reported by the model for this box.
    pub confidence: f32,
    /// Class index assigned by the detector (`0` or `1`).
    pub class: usize,
}
26
/// Segmentation-mask pixels whose value (on the 0–255 scale) falls below this
/// threshold are set to 0 before the morphological clean-up.
const MASK_THRESHOLD: u8 = 30;
28
29impl ComicTextDetector {
30 pub fn new() -> anyhow::Result<Self> {
31 let api = Api::new()?;
32 let repo = api.model("mayocream/comic-text-detector-onnx".to_string());
33 let model_path = repo.get("comic-text-detector.onnx")?;
34
35 let model = Session::builder()?
36 .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
37 .commit_from_file(model_path)?;
38
39 Ok(ComicTextDetector { model })
40 }
41
42 pub fn inference(
43 &mut self,
44 image: &image::DynamicImage,
45 confidence_threshold: f32,
46 nms_threshold: f32,
47 ) -> anyhow::Result<Output> {
48 let (orig_width, orig_height) = image.dimensions();
49 let w_ratio = orig_width as f32 / 1024.0;
50 let h_ratio = orig_height as f32 / 1024.0;
51 let image = image.resize_exact(1024, 1024, image::imageops::FilterType::CatmullRom);
52
53 let mut input = ndarray::Array::zeros((1, 3, 1024, 1024));
54 for pixel in image.pixels() {
55 let x = pixel.0 as usize;
56 let y = pixel.1 as usize;
57 let [r, g, b, _] = pixel.2.0;
58 input[[0, 0, y, x]] = (r as f32) / 255.0;
59 input[[0, 1, y, x]] = (g as f32) / 255.0;
60 input[[0, 2, y, x]] = (b as f32) / 255.0;
61 }
62
63 let inputs = inputs!["images" => TensorRef::from_array_view(input.view())?];
64 let outputs = self.model.run(inputs)?;
65
66 let blk = outputs["blk"].try_extract_array::<f32>()?;
68 let blk = blk.view();
69
70 let mut boxes: Vec<Vec<Bbox<_>>> = (0..=1).map(|_| vec![]).collect();
71 for i in 0..blk.shape()[1] {
72 let confidence = blk[[0, i, 4]];
73 if confidence < confidence_threshold {
74 continue;
75 }
76
77 let mut class_index = 0;
78 if blk[[0, i, 5]] < blk[[0, i, 6]] {
79 class_index = 1;
80 }
81
82 let center_x = blk[[0, i, 0]] * w_ratio;
83 let center_y = blk[[0, i, 1]] * h_ratio;
84 let width = blk[[0, i, 2]] * w_ratio;
85 let height = blk[[0, i, 3]] * h_ratio;
86
87 boxes[class_index].push(Bbox {
88 confidence,
89 xmin: center_x - width / 2.,
90 ymin: center_y - height / 2.,
91 xmax: center_x + width / 2.,
92 ymax: center_y + height / 2.,
93 data: (),
94 });
95 }
96
97 non_maximum_suppression(&mut boxes, nms_threshold);
98
99 let mut bboxes: Vec<ClassifiedBbox> = vec![];
101 for (class_index, bboxes_for_class) in boxes.iter().enumerate() {
102 for bbox in bboxes_for_class {
103 bboxes.push(ClassifiedBbox {
104 xmin: bbox.xmin,
105 ymin: bbox.ymin,
106 xmax: bbox.xmax,
107 ymax: bbox.ymax,
108 confidence: bbox.confidence,
109 class: class_index,
110 });
111 }
112 }
113
114 let mask = outputs["seg"].try_extract_array::<f32>()?;
116 let mask = mask
117 .view()
118 .to_owned()
119 .into_dimensionality::<ndarray::Ix4>()?;
120 let mask_slice = mask.slice(ndarray::s![0, 0, .., ..]);
122
123 let thresholded = mask_slice.mapv(|x| {
125 let val = (255.0 * x).round() as u8;
126 if val < MASK_THRESHOLD { 0 } else { val }
127 });
128
129 let (segment, _) = thresholded.into_raw_vec_and_offset();
131 let segment = image::GrayImage::from_vec(1024, 1024, segment)
132 .ok_or_else(|| anyhow::anyhow!("Failed to create GrayImage"))?;
133
134 let segment = imageproc::morphology::grayscale_dilate(
136 &segment,
137 &imageproc::morphology::Mask::square(3),
138 );
139 let segment =
140 imageproc::morphology::erode(&segment, imageproc::distance_transform::Norm::L2, 1);
141
142 let segment = DynamicImage::ImageLuma8(segment);
144 let segment = segment.resize_exact(
145 orig_width,
146 orig_height,
147 image::imageops::FilterType::CatmullRom,
148 );
149
150 Ok(Output { bboxes, segment })
151 }
152}