Skip to main content

manga_ocr_rs/
lib.rs

1//! Japanese manga OCR — image-to-text for scanned manga and printed Japanese.
2//!
3//! Runs [mayocream/manga-ocr-onnx](https://huggingface.co/mayocream/manga-ocr-onnx)
4//! (kha-white/manga-ocr-base exported to ONNX) via ONNX Runtime.
5//! Returns raw Japanese text; no translation, no furigana stripping.
6//!
7//! Handles yokogaki (horizontal), tategaki (vertical), and tegaki
8//! (handwritten) text.  Images are squish-resized to 224×224 matching the
9//! original training pipeline.
10//!
11//! # Quick start
12//!
13//! ```no_run
14//! use manga_ocr_rs::MangaOcr;
15//!
16//! let ocr = MangaOcr::new(manga_ocr_rs::default_model_dir()).unwrap();
17//! let img = image::open("panel.png").unwrap();
18//! println!("{}", ocr.recognize(&img).unwrap());
19//!
20//! // With confidence scores:
21//! let r = ocr.recognize_with_score(&img).unwrap();
22//! println!("{} (confidence: {:.4})", r.text, r.confidence);
23//! ```
24//!
25//! Models are downloaded automatically on first `cargo build` via `build.rs`.
26//! Override the location by setting `MANGA_OCR_MODELS_DIR` before building.
27
28use anyhow::{Context, Result};
29use image::{imageops, DynamicImage};
30use ort::session::Session;
31use ort::value::Tensor as OrtTensor;
32use std::cmp::Ordering;
33use std::path::Path;
34use std::sync::Mutex;
35
// ── Preprocessing ─────────────────────────────────────────────────────────────

/// Side length, in pixels, of the square image fed to the encoder.
const IMG_SIZE: usize = 224;
/// Mean subtracted from each channel after scaling pixels to [0, 1].
const PIXEL_MEAN: f32 = 0.5;
/// Divisor applied after mean subtraction; together with `PIXEL_MEAN` this
/// maps [0, 1] onto [-1, 1].
const PIXEL_STD: f32 = 0.5;
40
// ── Generation (from generation_config.json) ──────────────────────────────────

/// Token id given to the decoder as its very first input.
const DECODER_START_TOKEN_ID: i64 = 2;
/// End-of-sequence token id; emitting it finishes a beam.
const EOS_TOKEN_ID: i64 = 3;
/// Default cap on decoder iterations (the shipped config allows 300, which
/// only lengthens runaway decodes).
const DEFAULT_MAX_DECODE_STEPS: usize = 50;
/// Beam width used by the search.
const NUM_BEAMS: usize = 4;
/// Exponent applied to the sequence length when normalising beam scores.
const LENGTH_PENALTY: f32 = 2.0;
/// Size of the n-grams that may not repeat within one output.
const NO_REPEAT_NGRAM: usize = 3;
48
// ── Early bailout — abort hallucinating decoders before they waste time ───────

/// Number of generated tokens after which the bailout check starts applying.
///
/// From then on, if the best live beam's per-token geometric-mean probability
/// is below `EARLY_BAILOUT_CONFIDENCE`, the decoder is treated as producing
/// garbage and decoding stops.  Stopping after 16 of the default 50 steps
/// avoids roughly two thirds of the work that a later confidence gate would
/// have thrown away anyway.
const EARLY_BAILOUT_MIN_TOKENS: usize = 16;
/// Geometric-mean token probability below which decoding is abandoned.
const EARLY_BAILOUT_CONFIDENCE: f32 = 0.30;
56
57// ── Default model directory (set by build.rs) ─────────────────────────────────
58
59/// Returns the directory where `build.rs` downloaded (or expects) the model files.
60///
61/// Set `MANGA_OCR_MODELS_DIR` at build time to override:
62/// ```bash
63/// MANGA_OCR_MODELS_DIR=/my/models cargo build
64/// ```
65pub fn default_model_dir() -> &'static Path {
66    Path::new(env!("MANGA_OCR_DEFAULT_MODEL_DIR"))
67}
68
69// ── Preprocessing ─────────────────────────────────────────────────────────────
70
71/// Preprocess matching kha-white/manga-ocr-base's ViTImageProcessor:
72///
73/// 1. Grayscale → RGB  (`convert("L").convert("RGB")` in the original)
74/// 2. Resize directly to 224×224 with Bilinear  (squishes — no padding)
75/// 3. Normalise to [-1, 1]  (mean=0.5, std=0.5)
76///
77/// The original model was trained on squished (non-aspect-preserving) resizes,
78/// so we must NOT centre-pad to square — doing so degrades accuracy.
79///
80/// Returns shape `[1, 3, H, W]` + flat NCHW `Vec<f32>`.
81fn preprocess(img: &DynamicImage) -> ([usize; 4], Vec<f32>) {
82    // Bilinear matches preprocessor_config.json "resample": 2 (PIL.Image.BILINEAR)
83    let resized = img
84        .grayscale()
85        .resize_exact(IMG_SIZE as u32, IMG_SIZE as u32, imageops::FilterType::Triangle)
86        .to_rgb8();
87
88    let mut flat = vec![0.0f32; 3 * IMG_SIZE * IMG_SIZE];
89    for y in 0..IMG_SIZE {
90        for x in 0..IMG_SIZE {
91            let p = resized.get_pixel(x as u32, y as u32);
92            flat[0 * IMG_SIZE * IMG_SIZE + y * IMG_SIZE + x] = (p[0] as f32 / 255.0 - PIXEL_MEAN) / PIXEL_STD;
93            flat[1 * IMG_SIZE * IMG_SIZE + y * IMG_SIZE + x] = (p[1] as f32 / 255.0 - PIXEL_MEAN) / PIXEL_STD;
94            flat[2 * IMG_SIZE * IMG_SIZE + y * IMG_SIZE + x] = (p[2] as f32 / 255.0 - PIXEL_MEAN) / PIXEL_STD;
95        }
96    }
97    ([1, 3, IMG_SIZE, IMG_SIZE], flat)
98}
99
100// ── Vocabulary ────────────────────────────────────────────────────────────────
101//
102// vocab.txt is a line-indexed file: token ID = line number (0-based).
103// Special tokens to skip when decoding:
104//   0=[PAD]  1=[UNK]  2=[CLS]/BOS  3=[SEP]/EOS  4=[MASK]  5-14=<unused0-9>
105
106struct VocabDecoder {
107    tokens: Vec<String>,
108}
109
110impl VocabDecoder {
111    fn from_file(path: &Path) -> Result<Self> {
112        let content = std::fs::read_to_string(path)
113            .with_context(|| format!("read {}", path.display()))?;
114        let tokens = content.lines().map(str::to_owned).collect();
115        Ok(Self { tokens })
116    }
117
118    fn decode(&self, ids: &[i64]) -> String {
119        ids.iter()
120            .filter_map(|&id| {
121                let uid = id as usize;
122                if uid < 15 { return None; }
123                self.tokens.get(uid)
124            })
125            .map(|tok| tok.trim_start_matches("##"))
126            .collect()
127    }
128}
129
130// ── Beam search helpers ───────────────────────────────────────────────────────
131
/// Log-softmax of `logits`: `x_i - ln(sum_j exp(x_j))`.
///
/// Exponentials are taken after subtracting the largest logit, so large
/// logits cannot overflow `exp`.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
    let partition: f32 = logits.iter().map(|&v| (v - peak).exp()).sum();
    let log_norm = partition.ln() + peak;
    logits.iter().map(|&v| v - log_norm).collect()
}
138
/// Return the `k` highest-scoring `(index, log_prob)` pairs, best first.
///
/// A linear-time partial selection picks the survivors, and only those `k`
/// entries are sorted; sorting the whole vocabulary every step is wasted work.
/// `f32::total_cmp` is used because a comparator that is not a total order
/// (e.g. `partial_cmp` with NaN mapped to `Equal`) may make the standard
/// sort routines panic.  If `k` exceeds the input length, every entry is
/// returned; `k == 0` returns an empty vector.
fn top_k(log_probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    if k == 0 {
        return Vec::new();
    }
    let best_first = |a: &(usize, f32), b: &(usize, f32)| b.1.total_cmp(&a.1);

    let mut scored: Vec<(usize, f32)> = log_probs.iter().copied().enumerate().collect();
    if k < scored.len() {
        // Afterwards, indices 0..k hold the k best entries (unordered).
        scored.select_nth_unstable_by(k - 1, best_first);
        scored.truncate(k);
    }
    scored.sort_unstable_by(best_first);
    scored
}
145
146fn apply_no_repeat_ngram(ids: &[i64], log_probs: &mut [f32]) {
147    let n = NO_REPEAT_NGRAM;
148    if ids.len() < n - 1 || n < 2 { return; }
149    let prefix = &ids[ids.len() - (n - 1)..];
150    for i in 0..ids.len().saturating_sub(n - 1) {
151        if ids[i..i + n - 1] == *prefix {
152            let tok = ids[i + n - 1] as usize;
153            if tok < log_probs.len() {
154                log_probs[tok] = f32::NEG_INFINITY;
155            }
156        }
157    }
158}
159
160// ── Confidence calibration ────────────────────────────────────────────────────
161//
162// The model is trained on 224×224.  Crops that are too small lose detail on
163// up-scale; very large crops lose detail on down-scale; extreme aspect ratios
164// cause heavy squishing.  These tables apply a multiplicative penalty to the
165// raw token-level confidence so the returned value better reflects expected
166// accuracy.  Thresholds are tuned empirically — adjust as needed.
167//
168// Format: (exclusive_upper_bound, factor).  First match wins.
169
/// Confidence multiplier by crop area (width × height, in pixels).
///
/// Entries are `(exclusive_upper_bound, factor)`; the first entry whose bound
/// exceeds the area applies.
const AREA_CALIBRATION: &[(f32, f32)] = &[
    (2_500.0, 0.3),     // below ~50×50: too small for the 224×224 resize
    (10_000.0, 0.6),    // below ~100×100: marginal
    (500_000.0, 1.0),   // sweet spot for manga speech-bubble crops
    (2_000_000.0, 0.9), // large: some detail lost when downscaling
    (f32::MAX, 0.7),    // very large: heavy downscale
];

/// Confidence multiplier by aspect ratio (long side / short side).
///
/// Same `(exclusive_upper_bound, factor)` format as `AREA_CALIBRATION`.
const ASPECT_CALIBRATION: &[(f32, f32)] = &[
    (2.0, 1.0),      // near-square: no penalty
    (4.0, 0.95),     // moderate stretch
    (8.0, 0.85),     // significant squish distortion
    (f32::MAX, 0.7), // extreme aspect ratio
];

/// Multiplicative confidence penalty for a `width` × `height` crop.
///
/// The area and the aspect ratio are each looked up in their table and the
/// two factors multiplied.  The short side is clamped to at least 1, so
/// zero-sized input cannot divide by zero.  A value that matches no table
/// entry contributes a factor of 1.0.
fn dimension_calibration(width: u32, height: u32) -> f32 {
    fn first_factor(table: &[(f32, f32)], value: f32) -> f32 {
        for &(upper, factor) in table {
            if value < upper {
                return factor;
            }
        }
        1.0
    }

    let (long_side, short_side) = if width >= height { (width, height) } else { (height, width) };
    let area = width as f32 * height as f32;
    let aspect = long_side as f32 / short_side.max(1) as f32;

    first_factor(AREA_CALIBRATION, area) * first_factor(ASPECT_CALIBRATION, aspect)
}
203
204// ── Public API ────────────────────────────────────────────────────────────────
205
/// Result of OCR recognition, including confidence metrics.
///
/// Returned by [`MangaOcr::recognize_with_score`].
///
/// # Confidence scores
///
/// - `confidence` — `raw_confidence` multiplied by a penalty for crop size and
///   aspect ratio (0.0–1.0).  Use this for accept/reject thresholding.
/// - `raw_confidence` — geometric mean of per-token probabilities, before the
///   dimension calibration.
/// - `score` — length-normalised beam score used internally to rank beams.
///
/// # Beam-search data not (yet) exposed
///
/// The beam search computes more than is surfaced here; these could be added:
///
/// - **Per-token log probabilities** — each token's log-prob is computed while
///   scoring candidates, but only the running sum is kept.  Exposing them would
///   let callers highlight the specific characters the model was unsure of.
///
/// - **Alternative beams** — every finished hypothesis is collected, but only
///   the best one is returned.  Runner-up texts could aid disambiguation (e.g.
///   `テスト` vs `ラスト`: if the top-2 beams disagree, that signals uncertainty
///   the confidence score alone does not capture).
///
/// - **Decode step count** — how many decoder iterations ran before the
///   winning beam stopped.  Unlike the binary `truncated` flag, knowing the
///   model took 5 steps vs. the default limit of 50 hints at output complexity.
#[derive(Debug, Clone, PartialEq)]
pub struct Recognition {
    /// Decoded Japanese text, with special tokens removed.
    pub text: String,
    /// Length-normalised beam score (higher is better): the accumulated
    /// log-probability divided by `len ^ LENGTH_PENALTY`, where `len` is the
    /// length of the selected token sequence including the decoder start token.
    pub score: f32,
    /// Geometric mean of per-token probabilities (0.0–1.0), before calibration.
    pub raw_confidence: f32,
    /// Dimension-adjusted confidence (0.0–1.0): `raw_confidence` times a factor
    /// that penalises crops that are too small, too large, or have extreme
    /// aspect ratios.
    pub confidence: f32,
    /// `true` if the selected beam never emitted EOS — decoding either hit
    /// `max_decode_steps` or was stopped by the low-confidence early bailout.
    /// Strong hallucination signal: runaway generation almost always means
    /// the output is garbage.
    pub truncated: bool,
    /// Number of tokens in the selected beam, excluding the decoder start
    /// token.  Combined with the image dimensions this enables a
    /// characters-per-pixel sanity check: a 50×80 crop producing 200 tokens
    /// is garbage regardless of confidence.
    pub token_count: usize,
}
256
257/// OCR engine wrapping the encoder + decoder ONNX sessions and vocabulary.
258///
259/// Construct once (model load is expensive), then call [`recognize`] or
260/// [`recognize_with_score`] repeatedly.
261/// `Session::run` requires `&mut Session`, so each session is wrapped in a
262/// `Mutex` — this lets `MangaOcr` be shared as `Arc<MangaOcr>` across threads.
263///
264/// [`recognize`]: MangaOcr::recognize
265/// [`recognize_with_score`]: MangaOcr::recognize_with_score
266pub struct MangaOcr {
267    encoder: Mutex<Session>,
268    decoder: Mutex<Session>,
269    vocab:   VocabDecoder,
270    max_decode_steps: usize,
271}
272
273impl MangaOcr {
274    /// Load models from `model_dir`.
275    ///
276    /// Expects these files inside the directory:
277    /// - `encoder_model.onnx`
278    /// - `decoder_model.onnx`
279    /// - `vocab.txt`
280    ///
281    /// Use [`default_model_dir()`] to get the path that `build.rs` prepared.
282    pub fn new(model_dir: &Path) -> Result<Self> {
283        let enc_path = model_dir.join("encoder_model.onnx");
284        let dec_path = model_dir.join("decoder_model.onnx");
285        let tok_path = model_dir.join("vocab.txt");
286
287        let encoder = Session::builder()
288            .context("encoder: SessionBuilder")?
289            .commit_from_file(&enc_path)
290            .with_context(|| format!("encoder: open {}", enc_path.display()))?;
291
292        let decoder = Session::builder()
293            .context("decoder: SessionBuilder")?
294            .commit_from_file(&dec_path)
295            .with_context(|| format!("decoder: open {}", dec_path.display()))?;
296
297        let vocab = VocabDecoder::from_file(&tok_path)?;
298
299        Ok(Self {
300            encoder: Mutex::new(encoder),
301            decoder: Mutex::new(decoder),
302            vocab,
303            max_decode_steps: DEFAULT_MAX_DECODE_STEPS,
304        })
305    }
306
307    /// Set the maximum number of decoder steps (default: 50).
308    ///
309    /// No manga speech bubble needs more than ~50 tokens.  The original
310    /// `generation_config.json` ships 300, but that only increases worst-case
311    /// latency when the decoder runs away on garbage input.
312    pub fn with_max_decode_steps(mut self, n: usize) -> Self {
313        self.max_decode_steps = n.max(1);
314        self
315    }
316
317    /// OCR one image crop.  Returns raw Japanese text; no translation.
318    ///
319    /// Convenience wrapper around [`recognize_with_score`] that discards
320    /// confidence metrics.
321    ///
322    /// [`recognize_with_score`]: MangaOcr::recognize_with_score
323    pub fn recognize(&self, img: &DynamicImage) -> Result<String> {
324        self.recognize_with_score(img).map(|r| r.text)
325    }
326
327    /// OCR one image crop, returning text with confidence scores.
328    ///
329    /// Works on any aspect ratio: tategaki (tall), yokogaki (wide), tegaki
330    /// (handwritten).  Images are squish-resized to 224×224 (no padding),
331    /// matching the original training pipeline.
332    ///
333    /// Uses beam search (4 beams) matching `generation_config.json`.
334    pub fn recognize_with_score(&self, img: &DynamicImage) -> Result<Recognition> {
335        let (img_w, img_h) = (img.width(), img.height());
336
337        // ── 1. Encode ────────────────────────────────────────────────────────
338        let (pv_shape, pv_data) = preprocess(img);
339        let pv_tensor = OrtTensor::<f32>::from_array((pv_shape, pv_data))
340            .context("pixel_values tensor")?;
341
342        let (enc_seq_len, hidden_dim, enc_hidden) = {
343            let mut enc = self.encoder.lock()
344                .map_err(|e| anyhow::anyhow!("encoder lock poisoned: {e}"))?;
345            let out = enc.run(ort::inputs!["pixel_values" => pv_tensor])
346                .context("encoder run")?;
347            let (shape, data) = out["last_hidden_state"]
348                .try_extract_tensor::<f32>()
349                .context("encoder: extract last_hidden_state")?;
350            (shape[1] as usize, shape[2] as usize, data.to_vec())
351        };
352
353        // ── 2. Beam search ───────────────────────────────────────────────────
354        let seeds = {
355            let mut lp = self.decoder_logprobs_single(
356                &[DECODER_START_TOKEN_ID], enc_seq_len, hidden_dim, &enc_hidden,
357            )?;
358            apply_no_repeat_ngram(&[DECODER_START_TOKEN_ID], &mut lp);
359            top_k(&lp, NUM_BEAMS)
360        };
361
362        let mut beams: Vec<(Vec<i64>, f32, bool)> = seeds.iter()
363            .map(|&(tok, lp)| {
364                let done = tok as i64 == EOS_TOKEN_ID;
365                (vec![DECODER_START_TOKEN_ID, tok as i64], lp, done)
366            })
367            .collect();
368        // completed: (token_ids, accumulated_log_prob, hit_eos)
369        let mut completed: Vec<(Vec<i64>, f32, bool)> = beams.iter()
370            .filter(|(_, _, done)| *done)
371            .map(|(ids, score, _)| (ids.clone(), *score, true))
372            .collect();
373
374        for _step in 1..self.max_decode_steps {
375            let active: Vec<usize> = beams.iter().enumerate()
376                .filter(|(_, (_, _, done))| !*done)
377                .map(|(i, _)| i)
378                .collect();
379            if active.is_empty() { break; }
380
381            let batch = active.len();
382            let seq_len = beams[active[0]].0.len();
383
384            let flat_ids: Vec<i64> = active.iter()
385                .flat_map(|&i| beams[i].0.iter().copied())
386                .collect();
387
388            let flat_enc: Vec<f32> = enc_hidden.iter().copied()
389                .cycle()
390                .take(batch * enc_seq_len * hidden_dim)
391                .collect();
392
393            let batch_log_probs: Vec<Vec<f32>> = {
394                let mut dec = self.decoder.lock()
395                    .map_err(|e| anyhow::anyhow!("decoder lock poisoned: {e}"))?;
396                let out = dec.run(ort::inputs![
397                    "input_ids" =>
398                        OrtTensor::<i64>::from_array(([batch, seq_len], flat_ids))
399                            .context("input_ids tensor")?,
400                    "encoder_hidden_states" =>
401                        OrtTensor::<f32>::from_array(([batch, enc_seq_len, hidden_dim], flat_enc))
402                            .context("encoder_hidden_states tensor")?
403                ]).context("decoder batch run")?;
404
405                let (logits_shape, logits_data) = out["logits"]
406                    .try_extract_tensor::<f32>()
407                    .context("logits")?;
408                let vocab_size = logits_shape[2] as usize;
409
410                (0..batch).map(|b| {
411                    let offset = (b * seq_len + seq_len - 1) * vocab_size;
412                    log_softmax(&logits_data[offset..offset + vocab_size])
413                }).collect()
414            };
415
416            let mut candidates: Vec<(usize, i64, f32)> = Vec::with_capacity(batch * NUM_BEAMS);
417            for (b, &beam_idx) in active.iter().enumerate() {
418                let beam_score = beams[beam_idx].1;
419                let mut lp = batch_log_probs[b].clone();
420                apply_no_repeat_ngram(&beams[beam_idx].0, &mut lp);
421                for (tok, lp) in top_k(&lp, NUM_BEAMS) {
422                    candidates.push((beam_idx, tok as i64, beam_score + lp));
423                }
424            }
425            candidates.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
426            candidates.truncate(NUM_BEAMS);
427
428            beams = candidates.iter().map(|&(old_idx, tok, score)| {
429                let mut ids = beams[old_idx].0.clone();
430                let done = tok == EOS_TOKEN_ID;
431                if !done { ids.push(tok); }
432                (ids, score, done)
433            }).collect();
434
435            for (ids, score, done) in &beams {
436                if *done { completed.push((ids.clone(), *score, true)); }
437            }
438
439            // Early bailout: if the best active beam's running confidence
440            // drops below the threshold after enough tokens, the decoder is
441            // hallucinating.  Break now — the force-push below will mark
442            // these beams as truncated (hit_eos=false).
443            let best_active = beams.iter()
444                .filter(|(_, _, done)| !*done)
445                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
446            if let Some((ids, score, _)) = best_active {
447                let num_generated = ids.len().saturating_sub(1);
448                if num_generated >= EARLY_BAILOUT_MIN_TOKENS {
449                    let running_conf = (*score / num_generated as f32).exp();
450                    if running_conf < EARLY_BAILOUT_CONFIDENCE {
451                        eprintln!(
452                            "[manga-ocr] early bailout at {} tokens — confidence {:.1}% < {:.0}%",
453                            num_generated, running_conf * 100.0, EARLY_BAILOUT_CONFIDENCE * 100.0,
454                        );
455                        break;
456                    }
457                }
458            }
459        }
460
461        // Force-push beams that never emitted EOS — these were truncated at
462        // max_decode_steps and are almost certainly hallucinations.
463        for (ids, score, done) in &beams {
464            if !done { completed.push((ids.clone(), *score, false)); }
465        }
466
467        // ── 3. Pick best beam (length-normalised score) ───────────────────────
468        let (best_ids, best_raw_score, hit_eos) = completed.iter()
469            .max_by(|(ids_a, score_a, _), (ids_b, score_b, _)| {
470                let norm = |ids: &[i64], s: f32| s / (ids.len() as f32).powf(LENGTH_PENALTY);
471                norm(ids_a, *score_a).partial_cmp(&norm(ids_b, *score_b))
472                    .unwrap_or(Ordering::Equal)
473            })
474            .map(|(ids, score, eos)| (ids.as_slice(), *score, *eos))
475            .unwrap_or((&[], 0.0, false));
476
477        // ── 4. Detokenise & compute confidence ────────────────────────────────
478        let decode_ids = if best_ids.first() == Some(&DECODER_START_TOKEN_ID) {
479            &best_ids[1..]
480        } else {
481            best_ids
482        };
483        let text = self.vocab.decode(decode_ids);
484
485        let token_count = best_ids.len().saturating_sub(1); // exclude BOS
486        let num_tokens = token_count.max(1) as f32;
487        let score = best_raw_score / (best_ids.len().max(1) as f32).powf(LENGTH_PENALTY);
488        let raw_confidence = (best_raw_score / num_tokens).exp();
489        let calibration = dimension_calibration(img_w, img_h);
490        let confidence = raw_confidence * calibration;
491
492        Ok(Recognition { text, score, raw_confidence, confidence, truncated: !hit_eos, token_count })
493    }
494
495    fn decoder_logprobs_single(
496        &self,
497        ids: &[i64],
498        enc_seq_len: usize,
499        hidden_dim: usize,
500        enc_hidden: &[f32],
501    ) -> Result<Vec<f32>> {
502        let seq_len = ids.len();
503        let mut dec = self.decoder.lock()
504            .map_err(|e| anyhow::anyhow!("decoder lock poisoned: {e}"))?;
505        let out = dec.run(ort::inputs![
506            "input_ids" =>
507                OrtTensor::<i64>::from_array(([1usize, seq_len], ids.to_vec()))
508                    .context("input_ids tensor")?,
509            "encoder_hidden_states" =>
510                OrtTensor::<f32>::from_array(([1usize, enc_seq_len, hidden_dim], enc_hidden.to_vec()))
511                    .context("encoder_hidden_states tensor")?
512        ]).context("decoder bootstrap run")?;
513
514        let (logits_shape, logits_data) = out["logits"]
515            .try_extract_tensor::<f32>()
516            .context("logits")?;
517        let vocab_size = logits_shape[2] as usize;
518        let last = &logits_data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
519        Ok(log_softmax(last))
520    }
521}