use anyhow::{Context, Result};
use image::{imageops, DynamicImage};
use ort::session::Session;
use ort::value::Tensor as OrtTensor;
use std::cmp::Ordering;
use std::path::Path;
use std::sync::Mutex;
// --- Image preprocessing -------------------------------------------------------

/// Side length (pixels) of the square image fed to the encoder.
const IMG_SIZE: usize = 224;
/// Mean subtracted from every channel after scaling pixel values to `[0, 1]`.
const PIXEL_MEAN: f32 = 0.5;
/// Standard deviation the mean-centred channel values are divided by.
const PIXEL_STD: f32 = 0.5;

// --- Decoder special tokens -----------------------------------------------------

/// Id fed to the decoder as the first token of every hypothesis.
const DECODER_START_TOKEN_ID: i64 = 2;
/// Id whose selection marks a hypothesis as finished.
const EOS_TOKEN_ID: i64 = 3;

// --- Beam search ----------------------------------------------------------------

/// Default cap on the number of tokens generated per image.
const DEFAULT_MAX_DECODE_STEPS: usize = 50;
/// Number of hypotheses kept alive at each decoding step.
const NUM_BEAMS: usize = 4;
/// Exponent of the length divisor used to normalise summed log-probabilities;
/// since those sums are non-positive, a positive value favours longer hypotheses.
const LENGTH_PENALTY: f32 = 2.0;
/// Size of the n-grams that may not occur twice within one hypothesis.
const NO_REPEAT_NGRAM: usize = 3;
/// Generated-token count from which the early-bailout check is applied.
const EARLY_BAILOUT_MIN_TOKENS: usize = 16;
/// Decoding is abandoned when the best live beam's per-token confidence,
/// `exp(score / tokens)`, drops below this value.
const EARLY_BAILOUT_CONFIDENCE: f32 = 0.30;
/// Model directory baked into the binary at compile time.
///
/// The path comes from the `MANGA_OCR_DEFAULT_MODEL_DIR` environment variable as
/// seen by the compiler (`env!`), so building fails if that variable is unset.
pub fn default_model_dir() -> &'static Path {
    const BAKED_DIR: &str = env!("MANGA_OCR_DEFAULT_MODEL_DIR");
    Path::new(BAKED_DIR)
}
/// Converts an image into the normalised `pixel_values` input of the encoder.
///
/// The image is converted to grayscale, resized (ignoring aspect ratio) to
/// `IMG_SIZE`×`IMG_SIZE` with a triangle filter, and expanded back to three RGB
/// channels. Each channel value `v` is then mapped to
/// `(v / 255 - PIXEL_MEAN) / PIXEL_STD`.
///
/// Returns the tensor shape `[1, 3, IMG_SIZE, IMG_SIZE]` together with the data
/// in channel-planar order: the whole red plane, then green, then blue, each
/// plane stored row by row.
fn preprocess(img: &DynamicImage) -> ([usize; 4], Vec<f32>) {
    let side = IMG_SIZE as u32;
    let resized = img
        .grayscale()
        .resize_exact(side, side, imageops::FilterType::Triangle)
        .to_rgb8();

    let plane = IMG_SIZE * IMG_SIZE;
    let mut flat = vec![0.0f32; 3 * plane];
    let normalize = |v: u8| (f32::from(v) / 255.0 - PIXEL_MEAN) / PIXEL_STD;

    let (red, rest) = flat.split_at_mut(plane);
    let (green, blue) = rest.split_at_mut(plane);
    // `pixels()` walks the buffer row by row, so index `i` equals `y * IMG_SIZE + x`;
    // `resize_exact` guarantees exactly `plane` pixels.
    for (i, px) in resized.pixels().enumerate() {
        red[i] = normalize(px[0]);
        green[i] = normalize(px[1]);
        blue[i] = normalize(px[2]);
    }

    ([1, 3, IMG_SIZE, IMG_SIZE], flat)
}
/// Maps decoder token ids back to text using a one-token-per-line vocabulary file.
struct VocabDecoder {
    /// Token strings indexed by id: line `i` of the vocabulary file is token `i`.
    tokens: Vec<String>,
}

impl VocabDecoder {
    /// Ids below this value are dropped when decoding.
    // NOTE(review): presumably these are the tokenizer's special/reserved tokens
    // ([PAD], [CLS], [SEP], ...); confirm against the shipped vocab.txt.
    const FIRST_TEXT_TOKEN_ID: i64 = 15;

    /// WordPiece marker prefixed (once) to sub-word continuation pieces.
    const CONTINUATION_PREFIX: &'static str = "##";

    /// Reads the vocabulary from `path`, one token per line (`\n` or `\r\n`).
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text.
    fn from_file(path: &Path) -> Result<Self> {
        let content = std::fs::read_to_string(path)
            .with_context(|| format!("read {}", path.display()))?;
        let tokens = content.lines().map(str::to_owned).collect();
        Ok(Self { tokens })
    }

    /// Concatenates the tokens for `ids` without separators.
    ///
    /// Ids below `FIRST_TEXT_TOKEN_ID`, negative ids and ids outside the
    /// vocabulary are skipped. A single leading `##` continuation marker is
    /// removed from each token; further `#` characters belong to the token text.
    fn decode(&self, ids: &[i64]) -> String {
        ids.iter()
            .filter(|&&id| id >= Self::FIRST_TEXT_TOKEN_ID)
            .filter_map(|&id| self.tokens.get(id as usize))
            .map(String::as_str)
            .map(|tok| tok.strip_prefix(Self::CONTINUATION_PREFIX).unwrap_or(tok))
            .collect()
    }
}
/// Log-softmax of `logits`, i.e. `x_i - ln(sum_j exp(x_j))` for every element.
///
/// The maximum is subtracted before exponentiating (log-sum-exp trick), so large
/// logits do not overflow. An empty slice yields an empty vector.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted_total: f32 = logits.iter().map(|&v| (v - peak).exp()).sum();
    let normalizer = peak + shifted_total.ln();
    logits.iter().map(|&v| v - normalizer).collect()
}
/// Returns up to `k` `(index, value)` pairs with the largest values, best first.
///
/// The whole slice is sorted (O(n log n)); NaN values compare as equal to
/// everything. If `k` exceeds the slice length, every entry is returned.
fn top_k(log_probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut ranked: Vec<(usize, f32)> = log_probs.iter().copied().enumerate().collect();
    ranked.sort_unstable_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(Ordering::Equal));
    ranked.truncate(k);
    ranked
}
/// Bans every next token that would repeat an `NO_REPEAT_NGRAM`-gram already in `ids`.
///
/// The last `n - 1` ids form the current prefix. For every earlier occurrence of
/// that prefix in `ids`, the token that followed it gets a log-probability of
/// `-inf`. Tokens outside `log_probs` (including negative ids) are ignored, and
/// nothing happens while `ids` is shorter than the prefix or `n < 2`.
fn apply_no_repeat_ngram(ids: &[i64], log_probs: &mut [f32]) {
    let n = NO_REPEAT_NGRAM;
    // `n < 2` is checked first so `n - 1` can never underflow.
    if n < 2 || ids.len() < n - 1 {
        return;
    }
    let tail = &ids[ids.len() - (n - 1)..];
    for window in ids.windows(n) {
        let (head, next) = window.split_at(n - 1);
        if head == tail {
            if let Some(slot) = log_probs.get_mut(next[0] as usize) {
                *slot = f32::NEG_INFINITY;
            }
        }
    }
}
/// Area buckets for `dimension_calibration`, as `(exclusive upper bound in px², factor)`.
/// They are searched in order, and the first bound greater than `width * height` wins.
const AREA_CALIBRATION: &[(f32, f32)] = &[
    (2_500.0, 0.3),
    (10_000.0, 0.6),
    (500_000.0, 1.0),
    (2_000_000.0, 0.9),
    (f32::MAX, 0.7),
];

/// Aspect-ratio buckets (long side / short side) for `dimension_calibration`,
/// in the same `(exclusive upper bound, factor)` format.
const ASPECT_CALIBRATION: &[(f32, f32)] = &[
    (2.0, 1.0),
    (4.0, 0.95),
    (8.0, 0.85),
    (f32::MAX, 0.7),
];
/// Heuristic multiplier applied to the raw OCR confidence based on image size.
///
/// The result is the product of two bucket lookups:
/// - an area factor from `AREA_CALIBRATION`, keyed by `width * height`;
/// - an aspect factor from `ASPECT_CALIBRATION`, keyed by long side / short side,
///   with a zero short side treated as 1.
///
/// Each lookup takes the factor of the first bucket whose bound is strictly
/// greater than the value, falling back to `1.0` if there is none.
fn dimension_calibration(width: u32, height: u32) -> f32 {
    fn bucket(table: &[(f32, f32)], value: f32) -> f32 {
        for &(upper, factor) in table {
            if value < upper {
                return factor;
            }
        }
        1.0
    }

    let long_side = width.max(height) as f32;
    let short_side = width.min(height).max(1) as f32;
    let area = width as f32 * height as f32;
    bucket(AREA_CALIBRATION, area) * bucket(ASPECT_CALIBRATION, long_side / short_side)
}
/// Result of [`MangaOcr::recognize_with_score`]: the decoded text plus scoring
/// information about the beam-search hypothesis it came from.
#[derive(Debug, Clone)]
pub struct Recognition {
    /// Decoded text of the highest-ranked hypothesis; the vocabulary decoder
    /// drops ids below 15.
    pub text: String,
    /// Length-normalised score: summed token log-probability divided by
    /// `len^LENGTH_PENALTY`, where `len` counts the decoder start token. The final
    /// hypothesis is chosen by this value.
    pub score: f32,
    /// `exp(summed log-probability / token_count)` (count clamped to at least 1),
    /// before any calibration.
    pub raw_confidence: f32,
    /// `raw_confidence` multiplied by a heuristic factor derived from the input
    /// image's area and aspect ratio.
    pub confidence: f32,
    /// `true` when the chosen hypothesis never produced the end-of-sequence token
    /// (step budget exhausted or early bailout).
    pub truncated: bool,
    /// Number of generated ids in the chosen hypothesis, excluding the decoder
    /// start token.
    pub token_count: usize,
}
/// Manga OCR model: an image encoder and an autoregressive text decoder, both run
/// as ONNX Runtime sessions, plus the vocabulary used to turn ids into text.
pub struct MangaOcr {
    /// Encoder session; locked for each run on the preprocessed `pixel_values`.
    encoder: Mutex<Session>,
    /// Decoder session; locked for each decoding pass during beam search.
    decoder: Mutex<Session>,
    /// Id-to-token table used to render the chosen hypothesis as text.
    vocab: VocabDecoder,
    /// Maximum number of tokens generated per image. `new` sets it to
    /// `DEFAULT_MAX_DECODE_STEPS`; `with_max_decode_steps` keeps it at least 1.
    max_decode_steps: usize,
}
impl MangaOcr {
    /// Loads the encoder, decoder and vocabulary from `model_dir`.
    ///
    /// `model_dir` must contain `encoder_model.onnx`, `decoder_model.onnx` and
    /// `vocab.txt` (one token per line). Decoding is capped at
    /// `DEFAULT_MAX_DECODE_STEPS` tokens until changed with
    /// [`MangaOcr::with_max_decode_steps`].
    ///
    /// # Errors
    ///
    /// Returns an error if either ONNX session cannot be built or loaded from its
    /// file, or if the vocabulary file cannot be read.
    pub fn new(model_dir: &Path) -> Result<Self> {
        let enc_path = model_dir.join("encoder_model.onnx");
        let dec_path = model_dir.join("decoder_model.onnx");
        let tok_path = model_dir.join("vocab.txt");
        let encoder = Session::builder()
            .context("encoder: SessionBuilder")?
            .commit_from_file(&enc_path)
            .with_context(|| format!("encoder: open {}", enc_path.display()))?;
        let decoder = Session::builder()
            .context("decoder: SessionBuilder")?
            .commit_from_file(&dec_path)
            .with_context(|| format!("decoder: open {}", dec_path.display()))?;
        let vocab = VocabDecoder::from_file(&tok_path)?;
        Ok(Self {
            encoder: Mutex::new(encoder),
            decoder: Mutex::new(decoder),
            vocab,
            max_decode_steps: DEFAULT_MAX_DECODE_STEPS,
        })
    }

    /// Sets the maximum number of tokens generated per image; `0` is raised to `1`.
    pub fn with_max_decode_steps(mut self, n: usize) -> Self {
        self.max_decode_steps = n.max(1);
        self
    }

    /// Recognizes the text in `img`, discarding the scoring details.
    ///
    /// # Errors
    ///
    /// Same as [`MangaOcr::recognize_with_score`].
    pub fn recognize(&self, img: &DynamicImage) -> Result<String> {
        self.recognize_with_score(img).map(|r| r.text)
    }

    /// Recognizes the text in `img` using beam search and reports how the chosen
    /// hypothesis scored.
    ///
    /// The image is encoded once. At each step, the decoder is then re-run on the
    /// full id sequence of every live beam. `NUM_BEAMS` beams are kept, and
    /// repeated `NO_REPEAT_NGRAM`-grams are banned.
    ///
    /// Decoding stops when any of these happens:
    /// - no live (unfinished) beam remains;
    /// - `max_decode_steps` tokens have been generated;
    /// - after `EARLY_BAILOUT_MIN_TOKENS` tokens, the best live beam's per-token
    ///   confidence falls below `EARLY_BAILOUT_CONFIDENCE`.
    ///
    /// Finished and still-running hypotheses are then ranked by summed
    /// log-probability divided by `len^LENGTH_PENALTY`. Here `len` includes the
    /// start token but never the EOS token.
    ///
    /// # Errors
    ///
    /// Returns an error if a session mutex is poisoned, an input tensor cannot be
    /// built, an ONNX run fails, or an output tensor cannot be extracted as `f32`.
    pub fn recognize_with_score(&self, img: &DynamicImage) -> Result<Recognition> {
        let (img_w, img_h) = (img.width(), img.height());
        let (pv_shape, pv_data) = preprocess(img);
        let pv_tensor = OrtTensor::<f32>::from_array((pv_shape, pv_data))
            .context("pixel_values tensor")?;
        // Copy the encoder output out of this scope so the encoder lock is released
        // before decoding starts.
        let (enc_seq_len, hidden_dim, enc_hidden) = {
            let mut enc = self.encoder.lock()
                .map_err(|e| anyhow::anyhow!("encoder lock poisoned: {e}"))?;
            let out = enc.run(ort::inputs!["pixel_values" => pv_tensor])
                .context("encoder run")?;
            let (shape, data) = out["last_hidden_state"]
                .try_extract_tensor::<f32>()
                .context("encoder: extract last_hidden_state")?;
            (shape[1] as usize, shape[2] as usize, data.to_vec())
        };

        // Step 0: one decoder pass on the start token yields the first
        // NUM_BEAMS candidate tokens.
        let seeds = {
            let mut lp = self.decoder_logprobs_single(
                &[DECODER_START_TOKEN_ID], enc_seq_len, hidden_dim, &enc_hidden,
            )?;
            apply_no_repeat_ngram(&[DECODER_START_TOKEN_ID], &mut lp);
            top_k(&lp, NUM_BEAMS)
        };
        // Each beam is (ids, summed log-prob, finished). A seed that is EOS is
        // finished immediately. Like beams that finish in the loop below, its EOS
        // id is NOT appended, so length normalisation and `token_count` treat all
        // finished hypotheses alike.
        let mut beams: Vec<(Vec<i64>, f32, bool)> = seeds.iter()
            .map(|&(tok, lp)| {
                let tok = tok as i64;
                let done = tok == EOS_TOKEN_ID;
                let mut ids = vec![DECODER_START_TOKEN_ID];
                if !done { ids.push(tok); }
                (ids, lp, done)
            })
            .collect();
        let mut completed: Vec<(Vec<i64>, f32, bool)> = beams.iter()
            .filter(|(_, _, done)| *done)
            .map(|(ids, score, _)| (ids.clone(), *score, true))
            .collect();

        for _step in 1..self.max_decode_steps {
            let active: Vec<usize> = beams.iter().enumerate()
                .filter(|(_, (_, _, done))| !*done)
                .map(|(i, _)| i)
                .collect();
            if active.is_empty() { break; }
            let batch = active.len();
            // Every live beam grows by exactly one token per step, so all live
            // beams share one length.
            let seq_len = beams[active[0]].0.len();
            let flat_ids: Vec<i64> = active.iter()
                .flat_map(|&i| beams[i].0.iter().copied())
                .collect();
            let batch_log_probs = self.decoder_logprobs_batch(
                flat_ids, batch, seq_len, enc_seq_len, hidden_dim, &enc_hidden,
            )?;

            // Expand each live beam by its NUM_BEAMS best next tokens, then keep
            // the NUM_BEAMS best expansions overall.
            let mut candidates: Vec<(usize, i64, f32)> = Vec::with_capacity(batch * NUM_BEAMS);
            for (mut lp, &beam_idx) in batch_log_probs.into_iter().zip(&active) {
                let beam_score = beams[beam_idx].1;
                apply_no_repeat_ngram(&beams[beam_idx].0, &mut lp);
                for (tok, tok_lp) in top_k(&lp, NUM_BEAMS) {
                    candidates.push((beam_idx, tok as i64, beam_score + tok_lp));
                }
            }
            candidates.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal));
            candidates.truncate(NUM_BEAMS);
            beams = candidates.iter().map(|&(old_idx, tok, score)| {
                let mut ids = beams[old_idx].0.clone();
                let done = tok == EOS_TOKEN_ID;
                if !done { ids.push(tok); }
                (ids, score, done)
            }).collect();
            for (ids, score, done) in &beams {
                if *done { completed.push((ids.clone(), *score, true)); }
            }

            // Early bailout: stop once the best live beam has generated enough
            // tokens and its per-token confidence exp(score / n) is too low.
            let best_active = beams.iter()
                .filter(|(_, _, done)| !*done)
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
            if let Some((ids, score, _)) = best_active {
                let num_generated = ids.len().saturating_sub(1);
                if num_generated >= EARLY_BAILOUT_MIN_TOKENS {
                    let running_conf = (*score / num_generated as f32).exp();
                    if running_conf < EARLY_BAILOUT_CONFIDENCE {
                        eprintln!(
                            "[manga-ocr] early bailout at {} tokens — confidence {:.1}% < {:.0}%",
                            num_generated, running_conf * 100.0, EARLY_BAILOUT_CONFIDENCE * 100.0,
                        );
                        break;
                    }
                }
            }
        }

        // Beams still running at the end compete as well, flagged as not having
        // reached EOS.
        for (ids, score, done) in &beams {
            if !done { completed.push((ids.clone(), *score, false)); }
        }
        let (best_ids, best_raw_score, hit_eos) = completed.iter()
            .max_by(|(ids_a, score_a, _), (ids_b, score_b, _)| {
                let norm = |ids: &[i64], s: f32| s / (ids.len() as f32).powf(LENGTH_PENALTY);
                norm(ids_a, *score_a).partial_cmp(&norm(ids_b, *score_b))
                    .unwrap_or(Ordering::Equal)
            })
            .map(|(ids, score, eos)| (ids.as_slice(), *score, *eos))
            .unwrap_or((&[], 0.0, false));
        let decode_ids = if best_ids.first() == Some(&DECODER_START_TOKEN_ID) {
            &best_ids[1..]
        } else {
            best_ids
        };
        let text = self.vocab.decode(decode_ids);
        let token_count = best_ids.len().saturating_sub(1);
        let num_tokens = token_count.max(1) as f32;
        let score = best_raw_score / (best_ids.len().max(1) as f32).powf(LENGTH_PENALTY);
        // For hypotheses that ended on EOS, `best_raw_score` also includes the EOS
        // log-probability, while `num_tokens` counts only the ids kept in the
        // hypothesis (start token excluded).
        let raw_confidence = (best_raw_score / num_tokens).exp();
        let calibration = dimension_calibration(img_w, img_h);
        let confidence = raw_confidence * calibration;
        Ok(Recognition { text, score, raw_confidence, confidence, truncated: !hit_eos, token_count })
    }

    /// Runs one batched decoder pass and returns next-token log-probabilities.
    ///
    /// `flat_ids` holds `batch` id sequences of `seq_len` ids each, concatenated.
    /// The single-image encoder output is repeated `batch` times, so every
    /// sequence is decoded against the same image. One log-softmax vector is
    /// returned per sequence, taken from the logits at its last position.
    /// This assumes the decoder's `logits` output is laid out as
    /// `[batch, seq_len, vocab]`; the offset math below relies on it.
    fn decoder_logprobs_batch(
        &self,
        flat_ids: Vec<i64>,
        batch: usize,
        seq_len: usize,
        enc_seq_len: usize,
        hidden_dim: usize,
        enc_hidden: &[f32],
    ) -> Result<Vec<Vec<f32>>> {
        let flat_enc: Vec<f32> = enc_hidden.iter().copied()
            .cycle()
            .take(batch * enc_seq_len * hidden_dim)
            .collect();
        let mut dec = self.decoder.lock()
            .map_err(|e| anyhow::anyhow!("decoder lock poisoned: {e}"))?;
        let out = dec.run(ort::inputs![
            "input_ids" =>
                OrtTensor::<i64>::from_array(([batch, seq_len], flat_ids))
                    .context("input_ids tensor")?,
            "encoder_hidden_states" =>
                OrtTensor::<f32>::from_array(([batch, enc_seq_len, hidden_dim], flat_enc))
                    .context("encoder_hidden_states tensor")?
        ]).context("decoder batch run")?;
        let (logits_shape, logits_data) = out["logits"]
            .try_extract_tensor::<f32>()
            .context("logits")?;
        let vocab_size = logits_shape[2] as usize;
        let log_probs: Vec<Vec<f32>> = (0..batch)
            .map(|b| {
                let offset = (b * seq_len + seq_len - 1) * vocab_size;
                log_softmax(&logits_data[offset..offset + vocab_size])
            })
            .collect();
        Ok(log_probs)
    }

    /// Runs the decoder on a single id sequence and returns the log-probabilities
    /// of the token that follows it.
    ///
    /// `ids` must be non-empty, because the last position is read at
    /// `seq_len - 1`. `enc_hidden` is the batch-1 encoder output of shape
    /// `[1, enc_seq_len, hidden_dim]`, flattened.
    fn decoder_logprobs_single(
        &self,
        ids: &[i64],
        enc_seq_len: usize,
        hidden_dim: usize,
        enc_hidden: &[f32],
    ) -> Result<Vec<f32>> {
        let seq_len = ids.len();
        let mut dec = self.decoder.lock()
            .map_err(|e| anyhow::anyhow!("decoder lock poisoned: {e}"))?;
        let out = dec.run(ort::inputs![
            "input_ids" =>
                OrtTensor::<i64>::from_array(([1usize, seq_len], ids.to_vec()))
                    .context("input_ids tensor")?,
            "encoder_hidden_states" =>
                OrtTensor::<f32>::from_array(([1usize, enc_seq_len, hidden_dim], enc_hidden.to_vec()))
                    .context("encoder_hidden_states tensor")?
        ]).context("decoder bootstrap run")?;
        let (logits_shape, logits_data) = out["logits"]
            .try_extract_tensor::<f32>()
            .context("logits")?;
        let vocab_size = logits_shape[2] as usize;
        let last = &logits_data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
        Ok(log_softmax(last))
    }
}