use koharu_core::download;
use ndarray::s;
use ort::{inputs, session::Session, value::TensorRef};
#[derive(Debug)]
pub struct MangaOCR {
encoder_model: Session,
decoder_model: Session,
vocab: Vec<String>,
}
impl MangaOCR {
pub async fn new() -> anyhow::Result<Self> {
let encoder_model_path =
download::hf_hub("mayocream/manga-ocr-onnx", "encoder_model.onnx").await?;
let decoder_model_path =
download::hf_hub("mayocream/manga-ocr-onnx", "decoder_model.onnx").await?;
let vocab_path = download::hf_hub("mayocream/manga-ocr-onnx", "vocab.txt").await?;
let encoder_model = Session::builder()?
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
.commit_from_file(encoder_model_path)?;
let decoder_model = Session::builder()?
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
.commit_from_file(decoder_model_path)?;
let vocab = std::fs::read_to_string(vocab_path)
.map_err(|e| anyhow::anyhow!("Failed to read vocab file: {e}"))?
.lines()
.map(|s| s.to_string())
.collect::<Vec<_>>();
Ok(Self {
encoder_model,
decoder_model,
vocab,
})
}
pub fn inference(&mut self, image: &image::DynamicImage) -> anyhow::Result<String> {
let image = image.grayscale().to_rgb8();
let image =
image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Lanczos3);
let mut tensor = ndarray::Array::zeros((1, 3, 224, 224));
for (x, y, pixel) in image.enumerate_pixels() {
let x = x as usize;
let y = y as usize;
tensor[[0, 0, y, x]] = (pixel[0] as f32 / 255.0 - 0.5) / 0.5;
tensor[[0, 1, y, x]] = (pixel[1] as f32 / 255.0 - 0.5) / 0.5;
tensor[[0, 2, y, x]] = (pixel[2] as f32 / 255.0 - 0.5) / 0.5;
}
let inputs = inputs! {
"pixel_values" => TensorRef::from_array_view(tensor.view())?,
};
let outputs = self.encoder_model.run(inputs)?;
let encoder_hidden_state = outputs[0].try_extract_array::<f32>()?;
let mut token_ids: Vec<i64> = vec![2i64];
for _ in 0..300 {
let input = ndarray::Array::from_shape_vec((1, token_ids.len()), token_ids.clone())?;
let inputs = inputs! {
"encoder_hidden_states" => TensorRef::from_array_view(encoder_hidden_state.view())?,
"input_ids" => TensorRef::from_array_view(input.view())?,
};
let outputs = self.decoder_model.run(inputs)?;
let logits = outputs["logits"].try_extract_array::<f32>()?;
let logits_view = logits.view();
let last_token_logits = logits_view.slice(s![0, -1, ..]);
let (token_id, _) = last_token_logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap_or((0, &0.0));
token_ids.push(token_id as i64);
if token_id as i64 == 3 {
break;
}
}
let text = token_ids
.iter()
.filter(|&&id| id >= 5)
.filter_map(|&id| self.vocab.get(id as usize).cloned())
.collect::<Vec<_>>();
let text = text.join("");
Ok(text)
}
}