comic_text_detector/
lib.rs

1use std::thread;
2
3use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
4use hf_hub::api::sync::Api;
5use image::GenericImageView;
6use ort::{inputs, session::Session, value::TensorRef};
7use serde::Serialize;
8
9#[derive(Debug)]
10pub struct ComicTextDetector {
11    model: Session,
12}
13
14#[derive(Debug, Serialize)]
15pub struct Output {
16    pub bboxes: Vec<ClassifiedBbox>,
17    pub segment: Vec<u8>,
18}
19
20#[derive(Debug, Serialize)]
21pub struct ClassifiedBbox {
22    pub xmin: f32,
23    pub ymin: f32,
24    pub xmax: f32,
25    pub ymax: f32,
26    pub confidence: f32,
27    pub class: usize,
28}
29
/// Minimum mask intensity, on the 0..=255 scale, that is kept as-is.
///
/// Mask values below this are set to zero before the morphological clean-up.
const MASK_THRESHOLD: u8 = 30;
31
32impl ComicTextDetector {
33    pub fn new() -> anyhow::Result<Self> {
34        let api = Api::new()?;
35        let repo = api.model("mayocream/comic-text-detector-onnx".to_string());
36        let model_path = repo.get("comic-text-detector.onnx")?;
37
38        let model = Session::builder()?
39            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
40            .with_intra_threads(thread::available_parallelism()?.get())?
41            .commit_from_file(model_path)?;
42
43        Ok(ComicTextDetector { model })
44    }
45
46    pub fn inference(
47        &mut self,
48        image: &image::DynamicImage,
49        confidence_threshold: f32,
50        nms_threshold: f32,
51    ) -> anyhow::Result<Output> {
52        let (orig_width, orig_height) = image.dimensions();
53        let w_ratio = orig_width as f32 / 1024.0;
54        let h_ratio = orig_height as f32 / 1024.0;
55        let image = image.resize_exact(1024, 1024, image::imageops::FilterType::CatmullRom);
56
57        let mut input = ndarray::Array::zeros((1, 3, 1024, 1024));
58        for pixel in image.pixels() {
59            let x = pixel.0 as usize;
60            let y = pixel.1 as usize;
61            let [r, g, b, _] = pixel.2.0;
62            input[[0, 0, y, x]] = (r as f32) / 255.0;
63            input[[0, 1, y, x]] = (g as f32) / 255.0;
64            input[[0, 2, y, x]] = (b as f32) / 255.0;
65        }
66
67        let inputs = inputs!["images" => TensorRef::from_array_view(input.view())?];
68        let outputs = self.model.run(inputs)?;
69
70        // handle blocks
71        let blk = outputs["blk"].try_extract_array::<f32>()?;
72        let blk = blk.view();
73
74        let mut boxes: Vec<Vec<Bbox<_>>> = (0..=1).map(|_| vec![]).collect();
75        for i in 0..blk.shape()[1] {
76            let confidence = blk[[0, i, 4]];
77            if confidence < confidence_threshold {
78                continue;
79            }
80
81            let mut class_index = 0;
82            if blk[[0, i, 5]] < blk[[0, i, 6]] {
83                class_index = 1;
84            }
85
86            let center_x = blk[[0, i, 0]] * w_ratio;
87            let center_y = blk[[0, i, 1]] * h_ratio;
88            let width = blk[[0, i, 2]] * w_ratio;
89            let height = blk[[0, i, 3]] * h_ratio;
90
91            boxes[class_index].push(Bbox {
92                confidence,
93                xmin: center_x - width / 2.,
94                ymin: center_y - height / 2.,
95                xmax: center_x + width / 2.,
96                ymax: center_y + height / 2.,
97                data: (),
98            });
99        }
100
101        non_maximum_suppression(&mut boxes, nms_threshold);
102
103        // Convert to output format
104        let mut bboxes: Vec<ClassifiedBbox> = vec![];
105        for (class_index, bboxes_for_class) in boxes.iter().enumerate() {
106            for bbox in bboxes_for_class {
107                bboxes.push(ClassifiedBbox {
108                    xmin: bbox.xmin,
109                    ymin: bbox.ymin,
110                    xmax: bbox.xmax,
111                    ymax: bbox.ymax,
112                    confidence: bbox.confidence,
113                    class: class_index,
114                });
115            }
116        }
117
118        // handle masks
119        let mask = outputs["seg"].try_extract_array::<f32>()?;
120        let mask = mask
121            .view()
122            .to_owned()
123            .into_dimensionality::<ndarray::Ix4>()?;
124        // Extract the relevant 2D slice from the 4D array
125        let mask_slice = mask.slice(ndarray::s![0, 0, .., ..]);
126
127        // Create a new 2D array for the thresholded values
128        let thresholded = mask_slice.mapv(|x| {
129            let val = (255.0 * x).round() as u8;
130            if val < MASK_THRESHOLD { 0 } else { val }
131        });
132
133        // Convert to Vec
134        let (segment, _) = thresholded.into_raw_vec_and_offset();
135        // dilate the mask
136        let segment = image::GrayImage::from_vec(1024, 1024, segment)
137            .ok_or_else(|| anyhow::anyhow!("Failed to create GrayImage"))?;
138        let segment = imageproc::morphology::grayscale_dilate(
139            &segment,
140            &imageproc::morphology::Mask::square(3),
141        );
142        let segment =
143            imageproc::morphology::erode(&segment, imageproc::distance_transform::Norm::L2, 1);
144        let segment = segment.into_raw();
145
146        Ok(Output { bboxes, segment })
147    }
148}