manga_ocr_rs/lib.rs
1//! Japanese manga OCR — image-to-text for scanned manga and printed Japanese.
2//!
3//! Runs [mayocream/manga-ocr-onnx](https://huggingface.co/mayocream/manga-ocr-onnx)
4//! (kha-white/manga-ocr-base exported to ONNX) via ONNX Runtime.
5//! Returns raw Japanese text; no translation, no furigana stripping.
6//!
7//! Handles yokogaki (horizontal), tategaki (vertical), and tegaki
8//! (handwritten) text. Images are squish-resized to 224×224 matching the
9//! original training pipeline.
10//!
11//! # Quick start
12//!
13//! ```no_run
14//! use manga_ocr_rs::MangaOcr;
15//!
16//! let ocr = MangaOcr::new(manga_ocr_rs::default_model_dir()).unwrap();
17//! let img = image::open("panel.png").unwrap();
18//! println!("{}", ocr.recognize(&img).unwrap());
19//!
20//! // With confidence scores:
21//! let r = ocr.recognize_with_score(&img).unwrap();
22//! println!("{} (confidence: {:.4})", r.text, r.confidence);
23//! ```
24//!
25//! Models are downloaded automatically on first `cargo build` via `build.rs`.
26//! Override the location by setting `MANGA_OCR_MODELS_DIR` before building.
27
28use anyhow::{Context, Result};
29use image::{imageops, DynamicImage};
30use ort::session::Session;
31use ort::value::Tensor as OrtTensor;
32use std::cmp::Ordering;
33use std::path::Path;
34use std::sync::Mutex;
35
// ── Preprocessing constants ───────────────────────────────────────────────────

/// Side length in pixels of the square image fed to the encoder.
const IMG_SIZE: usize = 224;
/// Channel mean used in `(v / 255 - PIXEL_MEAN) / PIXEL_STD`, which maps
/// 8-bit pixel values onto [-1, 1].
const PIXEL_MEAN: f32 = 0.5;
/// Channel standard deviation; see [`PIXEL_MEAN`].
const PIXEL_STD: f32 = 0.5;

// ── Generation ────────────────────────────────────────────────────────────────
// Everything except the step cap mirrors generation_config.json.

/// First decoder input token ([CLS] in vocab.txt).
const DECODER_START_TOKEN_ID: i64 = 2;
/// Token that ends a hypothesis ([SEP] in vocab.txt).
const EOS_TOKEN_ID: i64 = 3;
/// Default decoder step cap. This is intentionally far below the 300 shipped in
/// generation_config.json; `MangaOcr::with_max_decode_steps` explains why.
const DEFAULT_MAX_DECODE_STEPS: usize = 50;
/// Beam width.
const NUM_BEAMS: usize = 4;
/// Exponent applied to sequence length when normalising beam scores.
const LENGTH_PENALTY: f32 = 2.0;
/// Size of n-grams that may not repeat within one hypothesis.
const NO_REPEAT_NGRAM: usize = 3;

// ── Early bailout ─────────────────────────────────────────────────────────────
// A decoder that is hallucinating tends to keep going until the step cap. If,
// after EARLY_BAILOUT_MIN_TOKENS tokens, the best running beam's per-token
// geometric-mean probability is still below EARLY_BAILOUT_CONFIDENCE, decoding
// stops there. With the default cap this stops such runs after 16 of 50 steps,
// skipping roughly two-thirds of the wasted work. Without it, they would run to
// the cap and only be rejected afterwards by the confidence gate.

/// Minimum generated tokens before the bailout check applies.
const EARLY_BAILOUT_MIN_TOKENS: usize = 16;
/// Running per-token confidence below which decoding is abandoned.
const EARLY_BAILOUT_CONFIDENCE: f32 = 0.3;
56
57// ── Default model directory (set by build.rs) ─────────────────────────────────
58
59/// Returns the directory where `build.rs` downloaded (or expects) the model files.
60///
61/// Set `MANGA_OCR_MODELS_DIR` at build time to override:
62/// ```bash
63/// MANGA_OCR_MODELS_DIR=/my/models cargo build
64/// ```
65pub fn default_model_dir() -> &'static Path {
66 Path::new(env!("MANGA_OCR_DEFAULT_MODEL_DIR"))
67}
68
69// ── Preprocessing ─────────────────────────────────────────────────────────────
70
71/// Preprocess matching kha-white/manga-ocr-base's ViTImageProcessor:
72///
73/// 1. Grayscale → RGB (`convert("L").convert("RGB")` in the original)
74/// 2. Resize directly to 224×224 with Bilinear (squishes — no padding)
75/// 3. Normalise to [-1, 1] (mean=0.5, std=0.5)
76///
77/// The original model was trained on squished (non-aspect-preserving) resizes,
78/// so we must NOT centre-pad to square — doing so degrades accuracy.
79///
80/// Returns shape `[1, 3, H, W]` + flat NCHW `Vec<f32>`.
81fn preprocess(img: &DynamicImage) -> ([usize; 4], Vec<f32>) {
82 // Bilinear matches preprocessor_config.json "resample": 2 (PIL.Image.BILINEAR)
83 let resized = img
84 .grayscale()
85 .resize_exact(IMG_SIZE as u32, IMG_SIZE as u32, imageops::FilterType::Triangle)
86 .to_rgb8();
87
88 let mut flat = vec![0.0f32; 3 * IMG_SIZE * IMG_SIZE];
89 for y in 0..IMG_SIZE {
90 for x in 0..IMG_SIZE {
91 let p = resized.get_pixel(x as u32, y as u32);
92 flat[0 * IMG_SIZE * IMG_SIZE + y * IMG_SIZE + x] = (p[0] as f32 / 255.0 - PIXEL_MEAN) / PIXEL_STD;
93 flat[1 * IMG_SIZE * IMG_SIZE + y * IMG_SIZE + x] = (p[1] as f32 / 255.0 - PIXEL_MEAN) / PIXEL_STD;
94 flat[2 * IMG_SIZE * IMG_SIZE + y * IMG_SIZE + x] = (p[2] as f32 / 255.0 - PIXEL_MEAN) / PIXEL_STD;
95 }
96 }
97 ([1, 3, IMG_SIZE, IMG_SIZE], flat)
98}
99
100// ── Vocabulary ────────────────────────────────────────────────────────────────
101//
102// vocab.txt is a line-indexed file: token ID = line number (0-based).
103// Special tokens to skip when decoding:
104// 0=[PAD] 1=[UNK] 2=[CLS]/BOS 3=[SEP]/EOS 4=[MASK] 5-14=<unused0-9>
105
106struct VocabDecoder {
107 tokens: Vec<String>,
108}
109
110impl VocabDecoder {
111 fn from_file(path: &Path) -> Result<Self> {
112 let content = std::fs::read_to_string(path)
113 .with_context(|| format!("read {}", path.display()))?;
114 let tokens = content.lines().map(str::to_owned).collect();
115 Ok(Self { tokens })
116 }
117
118 fn decode(&self, ids: &[i64]) -> String {
119 ids.iter()
120 .filter_map(|&id| {
121 let uid = id as usize;
122 if uid < 15 { return None; }
123 self.tokens.get(uid)
124 })
125 .map(|tok| tok.trim_start_matches("##"))
126 .collect()
127 }
128}
129
130// ── Beam search helpers ───────────────────────────────────────────────────────
131
/// Numerically stable log-softmax: `x_i - log(sum_j exp(x_j))`.
///
/// The maximum logit is subtracted before exponentiating, so large logits
/// cannot overflow `exp`. An empty input yields an empty output.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut shifted_total = 0.0f32;
    for &logit in logits {
        shifted_total += (logit - peak).exp();
    }
    let log_normaliser = shifted_total.ln() + peak;

    logits.iter().map(|&logit| logit - log_normaliser).collect()
}
138
/// Returns the `k` highest-scoring `(index, log_prob)` pairs, best first.
///
/// If `k` is at least the input length, every entry is returned in sorted
/// order. `k == 0` returns an empty vector.
///
/// NaN entries rank below every real value, including `-inf`. They are only
/// returned when fewer than `k` non-NaN entries exist. The comparator is a
/// true total order. This matters because `sort_unstable_by` and
/// `select_nth_unstable_by` may panic on an inconsistent comparator, which
/// `partial_cmp(..).unwrap_or(Equal)` becomes once NaN is present.
///
/// Selection is O(n + k log k), so the whole vocabulary is not sorted just to
/// keep the top few entries.
fn top_k(log_probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    /// Ranking key: NaN is demoted to -inf so it never wins.
    fn rank(p: f32) -> f32 {
        if p.is_nan() {
            f32::NEG_INFINITY
        } else {
            p
        }
    }

    /// Descending order by ranking key.
    fn best_first(a: &(usize, f32), b: &(usize, f32)) -> Ordering {
        rank(b.1).total_cmp(&rank(a.1))
    }

    if k == 0 {
        return Vec::new();
    }

    let mut scored: Vec<(usize, f32)> = log_probs.iter().copied().enumerate().collect();
    if k < scored.len() {
        // Partition so that the first `k` slots hold the best `k` entries,
        // in no particular order.
        scored.select_nth_unstable_by(k - 1, best_first);
        scored.truncate(k);
    }
    scored.sort_unstable_by(best_first);
    scored
}
145
146fn apply_no_repeat_ngram(ids: &[i64], log_probs: &mut [f32]) {
147 let n = NO_REPEAT_NGRAM;
148 if ids.len() < n - 1 || n < 2 { return; }
149 let prefix = &ids[ids.len() - (n - 1)..];
150 for i in 0..ids.len().saturating_sub(n - 1) {
151 if ids[i..i + n - 1] == *prefix {
152 let tok = ids[i + n - 1] as usize;
153 if tok < log_probs.len() {
154 log_probs[tok] = f32::NEG_INFINITY;
155 }
156 }
157 }
158}
159
160// ── Confidence calibration ────────────────────────────────────────────────────
161//
162// The model is trained on 224×224. Crops that are too small lose detail on
163// up-scale; very large crops lose detail on down-scale; extreme aspect ratios
164// cause heavy squishing. These tables apply a multiplicative penalty to the
165// raw token-level confidence so the returned value better reflects expected
166// accuracy. Thresholds are tuned empirically — adjust as needed.
167//
168// Format: (exclusive_upper_bound, factor). First match wins.
169
/// Confidence multiplier by crop area (width × height, in pixels).
/// Rows are `(exclusive_upper_bound, factor)`; the first matching row wins.
const AREA_CALIBRATION: &[(f32, f32)] = &[
    (2_500.0, 0.3),     // < ~50×50: too small for 224×224 resize
    (10_000.0, 0.6),    // < ~100×100: marginal
    (500_000.0, 1.0),   // sweet spot for manga speech-bubble crops
    (2_000_000.0, 0.9), // large — some detail loss in downscale
    (f32::MAX, 0.7),    // very large — heavy downscale
];

/// Confidence multiplier by aspect ratio (longer side / shorter side).
/// Rows are `(exclusive_upper_bound, factor)`; the first matching row wins.
const ASPECT_CALIBRATION: &[(f32, f32)] = &[
    (2.0, 1.0),      // near-square: no penalty
    (4.0, 0.95),     // moderate stretch
    (8.0, 0.85),     // significant squish distortion
    (f32::MAX, 0.7), // extreme aspect ratio
];

/// Multiplicative confidence penalty for a `width × height` crop.
///
/// The result is the product of the area factor and the aspect-ratio factor.
/// A zero-length side is treated as 1 when computing the aspect ratio, so the
/// division is always defined.
fn dimension_calibration(width: u32, height: u32) -> f32 {
    /// Factor of the first row whose upper bound exceeds `value`, or 1.0
    /// (no penalty) when no row matches.
    fn factor_for(table: &[(f32, f32)], value: f32) -> f32 {
        for &(upper, factor) in table {
            if value < upper {
                return factor;
            }
        }
        1.0
    }

    let (long_side, short_side) = if width >= height {
        (width, height)
    } else {
        (height, width)
    };
    let area = width as f32 * height as f32;
    let aspect = long_side as f32 / short_side.max(1) as f32;

    factor_for(AREA_CALIBRATION, area) * factor_for(ASPECT_CALIBRATION, aspect)
}
203
204// ── Public API ────────────────────────────────────────────────────────────────
205
/// Result of OCR recognition, including confidence metrics.
///
/// Returned by [`MangaOcr::recognize_with_score`].
///
/// # Confidence scores
///
/// - `confidence` — geometric mean of per-token probabilities (0.0–1.0),
///   adjusted for image dimensions. Use this for accept/reject thresholding.
/// - `raw_confidence` — the same metric before dimension calibration.
/// - `score` — length-normalised beam score used internally for beam selection.
///
/// # Beam-search data not yet exposed
///
/// The beam search computes more than is surfaced here. The following could
/// be added to this struct later:
///
/// - **Per-token log probabilities** — each token's log-prob is computed while
///   scoring candidates, but only the accumulated sum survives. Exposing them
///   would let callers highlight the characters the model was unsure about.
///
/// - **Alternative beams** — every finished beam is kept during the search,
///   but only the best one is returned. Runner-up texts could aid
///   disambiguation (e.g. `テスト` vs `ラスト`). If the top two beams disagree,
///   that signals uncertainty the confidence score alone does not capture.
///
/// - **Decode step count** — the number of decoder iterations before the
///   winning beam stopped. Unlike the binary `truncated` flag, knowing the
///   model took 5 steps vs 50 gives a sense of output complexity.
#[derive(Debug, Clone, PartialEq)]
pub struct Recognition {
    /// Decoded Japanese text.
    pub text: String,
    /// Length-normalised beam score (higher is better).
    ///
    /// Computed as
    /// `accumulated_log_prob / (token_count + 1) ^ LENGTH_PENALTY`. The `+ 1`
    /// is there because the decoder start token counts towards the length.
    pub score: f32,
    /// Geometric mean of per-token probabilities (0.0–1.0), before calibration.
    pub raw_confidence: f32,
    /// Dimension-adjusted confidence (0.0–1.0). Penalises crops that are too
    /// small, too large, or have extreme aspect ratios.
    pub confidence: f32,
    /// `true` if the winning beam never emitted EOS.
    ///
    /// This happens when decoding hit `max_decode_steps` or was aborted early
    /// because the running confidence collapsed. It is a strong hallucination
    /// signal: runaway generation almost always means the output is garbage.
    pub truncated: bool,
    /// Number of tokens generated, excluding BOS.
    ///
    /// Combined with the image dimensions, this enables a characters-per-pixel
    /// heuristic: a 50×80 crop producing 200 tokens is garbage regardless of
    /// confidence.
    pub token_count: usize,
}
256
257/// OCR engine wrapping the encoder + decoder ONNX sessions and vocabulary.
258///
259/// Construct once (model load is expensive), then call [`recognize`] or
260/// [`recognize_with_score`] repeatedly.
261/// `Session::run` requires `&mut Session`, so each session is wrapped in a
262/// `Mutex` — this lets `MangaOcr` be shared as `Arc<MangaOcr>` across threads.
263///
264/// [`recognize`]: MangaOcr::recognize
265/// [`recognize_with_score`]: MangaOcr::recognize_with_score
266pub struct MangaOcr {
267 encoder: Mutex<Session>,
268 decoder: Mutex<Session>,
269 vocab: VocabDecoder,
270 max_decode_steps: usize,
271}
272
273impl MangaOcr {
274 /// Load models from `model_dir`.
275 ///
276 /// Expects these files inside the directory:
277 /// - `encoder_model.onnx`
278 /// - `decoder_model.onnx`
279 /// - `vocab.txt`
280 ///
281 /// Use [`default_model_dir()`] to get the path that `build.rs` prepared.
282 pub fn new(model_dir: &Path) -> Result<Self> {
283 let enc_path = model_dir.join("encoder_model.onnx");
284 let dec_path = model_dir.join("decoder_model.onnx");
285 let tok_path = model_dir.join("vocab.txt");
286
287 let encoder = Session::builder()
288 .context("encoder: SessionBuilder")?
289 .commit_from_file(&enc_path)
290 .with_context(|| format!("encoder: open {}", enc_path.display()))?;
291
292 let decoder = Session::builder()
293 .context("decoder: SessionBuilder")?
294 .commit_from_file(&dec_path)
295 .with_context(|| format!("decoder: open {}", dec_path.display()))?;
296
297 let vocab = VocabDecoder::from_file(&tok_path)?;
298
299 Ok(Self {
300 encoder: Mutex::new(encoder),
301 decoder: Mutex::new(decoder),
302 vocab,
303 max_decode_steps: DEFAULT_MAX_DECODE_STEPS,
304 })
305 }
306
307 /// Set the maximum number of decoder steps (default: 50).
308 ///
309 /// No manga speech bubble needs more than ~50 tokens. The original
310 /// `generation_config.json` ships 300, but that only increases worst-case
311 /// latency when the decoder runs away on garbage input.
312 pub fn with_max_decode_steps(mut self, n: usize) -> Self {
313 self.max_decode_steps = n.max(1);
314 self
315 }
316
317 /// OCR one image crop. Returns raw Japanese text; no translation.
318 ///
319 /// Convenience wrapper around [`recognize_with_score`] that discards
320 /// confidence metrics.
321 ///
322 /// [`recognize_with_score`]: MangaOcr::recognize_with_score
323 pub fn recognize(&self, img: &DynamicImage) -> Result<String> {
324 self.recognize_with_score(img).map(|r| r.text)
325 }
326
327 /// OCR one image crop, returning text with confidence scores.
328 ///
329 /// Works on any aspect ratio: tategaki (tall), yokogaki (wide), tegaki
330 /// (handwritten). Images are squish-resized to 224×224 (no padding),
331 /// matching the original training pipeline.
332 ///
333 /// Uses beam search (4 beams) matching `generation_config.json`.
334 pub fn recognize_with_score(&self, img: &DynamicImage) -> Result<Recognition> {
335 let (img_w, img_h) = (img.width(), img.height());
336
337 // ── 1. Encode ────────────────────────────────────────────────────────
338 let (pv_shape, pv_data) = preprocess(img);
339 let pv_tensor = OrtTensor::<f32>::from_array((pv_shape, pv_data))
340 .context("pixel_values tensor")?;
341
342 let (enc_seq_len, hidden_dim, enc_hidden) = {
343 let mut enc = self.encoder.lock()
344 .map_err(|e| anyhow::anyhow!("encoder lock poisoned: {e}"))?;
345 let out = enc.run(ort::inputs!["pixel_values" => pv_tensor])
346 .context("encoder run")?;
347 let (shape, data) = out["last_hidden_state"]
348 .try_extract_tensor::<f32>()
349 .context("encoder: extract last_hidden_state")?;
350 (shape[1] as usize, shape[2] as usize, data.to_vec())
351 };
352
353 // ── 2. Beam search ───────────────────────────────────────────────────
354 let seeds = {
355 let mut lp = self.decoder_logprobs_single(
356 &[DECODER_START_TOKEN_ID], enc_seq_len, hidden_dim, &enc_hidden,
357 )?;
358 apply_no_repeat_ngram(&[DECODER_START_TOKEN_ID], &mut lp);
359 top_k(&lp, NUM_BEAMS)
360 };
361
362 let mut beams: Vec<(Vec<i64>, f32, bool)> = seeds.iter()
363 .map(|&(tok, lp)| {
364 let done = tok as i64 == EOS_TOKEN_ID;
365 (vec![DECODER_START_TOKEN_ID, tok as i64], lp, done)
366 })
367 .collect();
368 // completed: (token_ids, accumulated_log_prob, hit_eos)
369 let mut completed: Vec<(Vec<i64>, f32, bool)> = beams.iter()
370 .filter(|(_, _, done)| *done)
371 .map(|(ids, score, _)| (ids.clone(), *score, true))
372 .collect();
373
374 for _step in 1..self.max_decode_steps {
375 let active: Vec<usize> = beams.iter().enumerate()
376 .filter(|(_, (_, _, done))| !*done)
377 .map(|(i, _)| i)
378 .collect();
379 if active.is_empty() { break; }
380
381 let batch = active.len();
382 let seq_len = beams[active[0]].0.len();
383
384 let flat_ids: Vec<i64> = active.iter()
385 .flat_map(|&i| beams[i].0.iter().copied())
386 .collect();
387
388 let flat_enc: Vec<f32> = enc_hidden.iter().copied()
389 .cycle()
390 .take(batch * enc_seq_len * hidden_dim)
391 .collect();
392
393 let batch_log_probs: Vec<Vec<f32>> = {
394 let mut dec = self.decoder.lock()
395 .map_err(|e| anyhow::anyhow!("decoder lock poisoned: {e}"))?;
396 let out = dec.run(ort::inputs![
397 "input_ids" =>
398 OrtTensor::<i64>::from_array(([batch, seq_len], flat_ids))
399 .context("input_ids tensor")?,
400 "encoder_hidden_states" =>
401 OrtTensor::<f32>::from_array(([batch, enc_seq_len, hidden_dim], flat_enc))
402 .context("encoder_hidden_states tensor")?
403 ]).context("decoder batch run")?;
404
405 let (logits_shape, logits_data) = out["logits"]
406 .try_extract_tensor::<f32>()
407 .context("logits")?;
408 let vocab_size = logits_shape[2] as usize;
409
410 (0..batch).map(|b| {
411 let offset = (b * seq_len + seq_len - 1) * vocab_size;
412 log_softmax(&logits_data[offset..offset + vocab_size])
413 }).collect()
414 };
415
416 let mut candidates: Vec<(usize, i64, f32)> = Vec::with_capacity(batch * NUM_BEAMS);
417 for (b, &beam_idx) in active.iter().enumerate() {
418 let beam_score = beams[beam_idx].1;
419 let mut lp = batch_log_probs[b].clone();
420 apply_no_repeat_ngram(&beams[beam_idx].0, &mut lp);
421 for (tok, lp) in top_k(&lp, NUM_BEAMS) {
422 candidates.push((beam_idx, tok as i64, beam_score + lp));
423 }
424 }
425 candidates.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
426 candidates.truncate(NUM_BEAMS);
427
428 beams = candidates.iter().map(|&(old_idx, tok, score)| {
429 let mut ids = beams[old_idx].0.clone();
430 let done = tok == EOS_TOKEN_ID;
431 if !done { ids.push(tok); }
432 (ids, score, done)
433 }).collect();
434
435 for (ids, score, done) in &beams {
436 if *done { completed.push((ids.clone(), *score, true)); }
437 }
438
439 // Early bailout: if the best active beam's running confidence
440 // drops below the threshold after enough tokens, the decoder is
441 // hallucinating. Break now — the force-push below will mark
442 // these beams as truncated (hit_eos=false).
443 let best_active = beams.iter()
444 .filter(|(_, _, done)| !*done)
445 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
446 if let Some((ids, score, _)) = best_active {
447 let num_generated = ids.len().saturating_sub(1);
448 if num_generated >= EARLY_BAILOUT_MIN_TOKENS {
449 let running_conf = (*score / num_generated as f32).exp();
450 if running_conf < EARLY_BAILOUT_CONFIDENCE {
451 eprintln!(
452 "[manga-ocr] early bailout at {} tokens — confidence {:.1}% < {:.0}%",
453 num_generated, running_conf * 100.0, EARLY_BAILOUT_CONFIDENCE * 100.0,
454 );
455 break;
456 }
457 }
458 }
459 }
460
461 // Force-push beams that never emitted EOS — these were truncated at
462 // max_decode_steps and are almost certainly hallucinations.
463 for (ids, score, done) in &beams {
464 if !done { completed.push((ids.clone(), *score, false)); }
465 }
466
467 // ── 3. Pick best beam (length-normalised score) ───────────────────────
468 let (best_ids, best_raw_score, hit_eos) = completed.iter()
469 .max_by(|(ids_a, score_a, _), (ids_b, score_b, _)| {
470 let norm = |ids: &[i64], s: f32| s / (ids.len() as f32).powf(LENGTH_PENALTY);
471 norm(ids_a, *score_a).partial_cmp(&norm(ids_b, *score_b))
472 .unwrap_or(Ordering::Equal)
473 })
474 .map(|(ids, score, eos)| (ids.as_slice(), *score, *eos))
475 .unwrap_or((&[], 0.0, false));
476
477 // ── 4. Detokenise & compute confidence ────────────────────────────────
478 let decode_ids = if best_ids.first() == Some(&DECODER_START_TOKEN_ID) {
479 &best_ids[1..]
480 } else {
481 best_ids
482 };
483 let text = self.vocab.decode(decode_ids);
484
485 let token_count = best_ids.len().saturating_sub(1); // exclude BOS
486 let num_tokens = token_count.max(1) as f32;
487 let score = best_raw_score / (best_ids.len().max(1) as f32).powf(LENGTH_PENALTY);
488 let raw_confidence = (best_raw_score / num_tokens).exp();
489 let calibration = dimension_calibration(img_w, img_h);
490 let confidence = raw_confidence * calibration;
491
492 Ok(Recognition { text, score, raw_confidence, confidence, truncated: !hit_eos, token_count })
493 }
494
495 fn decoder_logprobs_single(
496 &self,
497 ids: &[i64],
498 enc_seq_len: usize,
499 hidden_dim: usize,
500 enc_hidden: &[f32],
501 ) -> Result<Vec<f32>> {
502 let seq_len = ids.len();
503 let mut dec = self.decoder.lock()
504 .map_err(|e| anyhow::anyhow!("decoder lock poisoned: {e}"))?;
505 let out = dec.run(ort::inputs![
506 "input_ids" =>
507 OrtTensor::<i64>::from_array(([1usize, seq_len], ids.to_vec()))
508 .context("input_ids tensor")?,
509 "encoder_hidden_states" =>
510 OrtTensor::<f32>::from_array(([1usize, enc_seq_len, hidden_dim], enc_hidden.to_vec()))
511 .context("encoder_hidden_states tensor")?
512 ]).context("decoder bootstrap run")?;
513
514 let (logits_shape, logits_data) = out["logits"]
515 .try_extract_tensor::<f32>()
516 .context("logits")?;
517 let vocab_size = logits_shape[2] as usize;
518 let last = &logits_data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
519 Ok(log_softmax(last))
520 }
521}