comic_text_detector/
lib.rs

1use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
2use hf_hub::api::sync::Api;
3use image::{DynamicImage, GenericImageView};
4use ort::{inputs, session::Session, value::TensorRef};
5
/// Text detector for comic pages backed by an ONNX Runtime session.
///
/// Construct it with [`ComicTextDetector::new`] and run it on an image with
/// [`ComicTextDetector::inference`].
#[derive(Debug)]
pub struct ComicTextDetector {
    /// ONNX Runtime session holding the loaded `comic-text-detector.onnx` model.
    model: Session,
}
10
/// Result of a single [`ComicTextDetector::inference`] call.
#[derive(Debug)]
pub struct Output {
    /// Detected boxes that survived non-maximum suppression, with coordinates
    /// scaled back to the dimensions of the image passed to `inference`.
    pub bboxes: Vec<ClassifiedBbox>,
    /// Grayscale (`Luma8`) segmentation mask, resized to the dimensions of the
    /// image passed to `inference`.
    pub segment: DynamicImage,
}
16
/// An axis-aligned detection box together with its score and class.
///
/// Coordinates are expressed in pixels of the image that was passed to
/// `ComicTextDetector::inference`; they are not clamped to the image bounds.
///
/// The type is plain data, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ClassifiedBbox {
    /// Left edge of the box.
    pub xmin: f32,
    /// Top edge of the box.
    pub ymin: f32,
    /// Right edge of the box.
    pub xmax: f32,
    /// Bottom edge of the box.
    pub ymax: f32,
    /// Detection confidence reported by the model for this box.
    pub confidence: f32,
    /// Index (`0` or `1`) of the class with the higher score for this box.
    pub class: usize,
}
26
/// Minimum mask intensity (on the 0-255 scale) that is kept; segmentation
/// values below this are set to 0 before the mask is post-processed.
const MASK_THRESHOLD: u8 = 30;
28
29impl ComicTextDetector {
30    pub fn new() -> anyhow::Result<Self> {
31        let api = Api::new()?;
32        let repo = api.model("mayocream/comic-text-detector-onnx".to_string());
33        let model_path = repo.get("comic-text-detector.onnx")?;
34
35        let model = Session::builder()?
36            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
37            .commit_from_file(model_path)?;
38
39        Ok(ComicTextDetector { model })
40    }
41
42    pub fn inference(
43        &mut self,
44        image: &image::DynamicImage,
45        confidence_threshold: f32,
46        nms_threshold: f32,
47    ) -> anyhow::Result<Output> {
48        let (orig_width, orig_height) = image.dimensions();
49        let w_ratio = orig_width as f32 / 1024.0;
50        let h_ratio = orig_height as f32 / 1024.0;
51        let image = image.resize_exact(1024, 1024, image::imageops::FilterType::CatmullRom);
52
53        let mut input = ndarray::Array::zeros((1, 3, 1024, 1024));
54        for pixel in image.pixels() {
55            let x = pixel.0 as usize;
56            let y = pixel.1 as usize;
57            let [r, g, b, _] = pixel.2.0;
58            input[[0, 0, y, x]] = (r as f32) / 255.0;
59            input[[0, 1, y, x]] = (g as f32) / 255.0;
60            input[[0, 2, y, x]] = (b as f32) / 255.0;
61        }
62
63        let inputs = inputs!["images" => TensorRef::from_array_view(input.view())?];
64        let outputs = self.model.run(inputs)?;
65
66        // handle blocks
67        let blk = outputs["blk"].try_extract_array::<f32>()?;
68        let blk = blk.view();
69
70        let mut boxes: Vec<Vec<Bbox<_>>> = (0..=1).map(|_| vec![]).collect();
71        for i in 0..blk.shape()[1] {
72            let confidence = blk[[0, i, 4]];
73            if confidence < confidence_threshold {
74                continue;
75            }
76
77            let mut class_index = 0;
78            if blk[[0, i, 5]] < blk[[0, i, 6]] {
79                class_index = 1;
80            }
81
82            let center_x = blk[[0, i, 0]] * w_ratio;
83            let center_y = blk[[0, i, 1]] * h_ratio;
84            let width = blk[[0, i, 2]] * w_ratio;
85            let height = blk[[0, i, 3]] * h_ratio;
86
87            boxes[class_index].push(Bbox {
88                confidence,
89                xmin: center_x - width / 2.,
90                ymin: center_y - height / 2.,
91                xmax: center_x + width / 2.,
92                ymax: center_y + height / 2.,
93                data: (),
94            });
95        }
96
97        non_maximum_suppression(&mut boxes, nms_threshold);
98
99        // Convert to output format
100        let mut bboxes: Vec<ClassifiedBbox> = vec![];
101        for (class_index, bboxes_for_class) in boxes.iter().enumerate() {
102            for bbox in bboxes_for_class {
103                bboxes.push(ClassifiedBbox {
104                    xmin: bbox.xmin,
105                    ymin: bbox.ymin,
106                    xmax: bbox.xmax,
107                    ymax: bbox.ymax,
108                    confidence: bbox.confidence,
109                    class: class_index,
110                });
111            }
112        }
113
114        // handle masks
115        let mask = outputs["seg"].try_extract_array::<f32>()?;
116        let mask = mask
117            .view()
118            .to_owned()
119            .into_dimensionality::<ndarray::Ix4>()?;
120        // Extract the relevant 2D slice from the 4D array
121        let mask_slice = mask.slice(ndarray::s![0, 0, .., ..]);
122
123        // Create a new 2D array for the thresholded values
124        let thresholded = mask_slice.mapv(|x| {
125            let val = (255.0 * x).round() as u8;
126            if val < MASK_THRESHOLD { 0 } else { val }
127        });
128
129        // Convert to Vec
130        let (segment, _) = thresholded.into_raw_vec_and_offset();
131        let segment = image::GrayImage::from_vec(1024, 1024, segment)
132            .ok_or_else(|| anyhow::anyhow!("Failed to create GrayImage"))?;
133
134        // dilate the mask
135        let segment = imageproc::morphology::grayscale_dilate(
136            &segment,
137            &imageproc::morphology::Mask::square(3),
138        );
139        let segment =
140            imageproc::morphology::erode(&segment, imageproc::distance_transform::Norm::L2, 1);
141
142        // resize back to original size
143        let segment = DynamicImage::ImageLuma8(segment);
144        let segment = segment.resize_exact(
145            orig_width,
146            orig_height,
147            image::imageops::FilterType::CatmullRom,
148        );
149
150        Ok(Output { bboxes, segment })
151    }
152}