#![warn(missing_docs)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod image_processor;
use anyhow::anyhow;
use candle_core::DType;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;
use candle_transformers::models::vit;
use hf_hub::api::sync::Api;
use image::{GenericImage, GenericImageView, ImageBuffer, Rgba};
use kalosm_common::*;
use tokenizers::Tokenizer;
/// Builder that configures and asynchronously constructs an [`Ocr`] model.
///
/// Obtain one with [`Ocr::builder`] or [`Default::default`]; the default value uses
/// [`OcrSource::default`] as its source.
#[derive(Default)]
pub struct OcrBuilder {
    /// Where the model weights and model configuration are fetched from.
    source: OcrSource,
}
impl OcrBuilder {
    /// Replaces the source the model weights and configuration will be loaded from.
    pub fn with_source(self, source: OcrSource) -> Self {
        let mut builder = self;
        builder.source = source;
        builder
    }

    /// Builds the [`Ocr`] model, discarding all loading progress events.
    ///
    /// # Errors
    ///
    /// Returns whatever error loading the model produces.
    pub async fn build(self) -> anyhow::Result<Ocr> {
        // Same as the progress-reporting variant, with a handler that ignores every event.
        self.build_with_loading_handler(|_| {}).await
    }

    /// Builds the [`Ocr`] model, passing each [`ModelLoadingProgress`] event to `handler`.
    ///
    /// # Errors
    ///
    /// Returns whatever error loading the model produces.
    pub async fn build_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> anyhow::Result<Ocr> {
        Ocr::new(self, handler).await
    }
}
/// The pair of files an [`Ocr`] model is loaded from: the weights and the configuration.
pub struct OcrSource {
    /// Source of the weights file; it is loaded as a memory-mapped safetensors file.
    model: FileSource,
    /// Source of the configuration file; it is parsed as JSON with `encoder` and `decoder`
    /// sections.
    config: FileSource,
}
impl OcrSource {
    /// Creates a source from explicit locations of the model weights and the model config.
    ///
    /// `model` is later loaded as a safetensors file and `config` is parsed as JSON
    /// containing an `encoder` and a `decoder` section.
    pub fn new(model: FileSource, config: FileSource) -> Self {
        Self { model, config }
    }

    /// Builds a source whose weights (`model.safetensors`) and configuration (`config.json`)
    /// both come from the Hugging Face repository `model_id` at `revision`.
    fn huggingface_repo(model_id: &str, revision: &str) -> Self {
        Self::new(
            FileSource::huggingface(
                model_id.to_string(),
                revision.to_string(),
                "model.safetensors".to_string(),
            ),
            FileSource::huggingface(
                model_id.to_string(),
                revision.to_string(),
                "config.json".to_string(),
            ),
        )
    }

    /// The `microsoft/trocr-base-handwritten` model at revision `refs/pr/3`.
    pub fn base() -> Self {
        Self::huggingface_repo("microsoft/trocr-base-handwritten", "refs/pr/3")
    }

    /// The `microsoft/trocr-large-handwritten` model at revision `refs/pr/6`.
    pub fn large() -> Self {
        Self::huggingface_repo("microsoft/trocr-large-handwritten", "refs/pr/6")
    }

    /// The `microsoft/trocr-base-printed` model at revision `refs/pr/7`.
    pub fn base_printed() -> Self {
        Self::huggingface_repo("microsoft/trocr-base-printed", "refs/pr/7")
    }

    /// The `microsoft/trocr-large-printed` model at revision `main`.
    pub fn large_printed() -> Self {
        Self::huggingface_repo("microsoft/trocr-large-printed", "main")
    }

    /// Downloads the model weights and wraps them in a [`VarBuilder`] on `device`.
    ///
    /// Download progress is forwarded to `handler`, labelled `Model (<model source>)`.
    /// The weights are loaded with [`DType::F32`].
    ///
    /// # Errors
    ///
    /// Fails if the download fails or the safetensors file cannot be mapped.
    async fn varbuilder(
        &self,
        device: &Device,
        mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync,
    ) -> anyhow::Result<VarBuilder> {
        let source = format!("Model ({})", self.model);
        let mut create_progress = ModelLoadingProgress::downloading_progress(source);
        let filename = self
            .model
            .download(|progress| handler(create_progress(progress)))
            .await?;
        // SAFETY: memory-mapping is only sound while the underlying file is not modified or
        // truncated. NOTE(review): this assumes nothing rewrites the downloaded file while
        // the model is alive — confirm against the download cache's behavior.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, device)? };
        Ok(vb)
    }

    /// Downloads and parses the model configuration.
    ///
    /// Download progress is forwarded to `handler`, labelled `Config (<config source>)`.
    /// Returns the encoder (ViT) config and the decoder (TrOCR) config, in that order.
    ///
    /// # Errors
    ///
    /// Fails if the download fails, the file cannot be opened, or the JSON does not contain
    /// valid `encoder` and `decoder` sections.
    async fn config(
        &self,
        mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync,
    ) -> anyhow::Result<(vit::Config, trocr::TrOCRConfig)> {
        /// On-disk layout of the configuration file.
        #[derive(Debug, Clone, serde::Deserialize)]
        struct Config {
            encoder: vit::Config,
            decoder: trocr::TrOCRConfig,
        }

        // Label the events with the *config* source so that progress for the weights and the
        // configuration downloads can be told apart by the handler.
        let source = format!("Config ({})", self.config);
        let mut create_progress = ModelLoadingProgress::downloading_progress(source);
        let config_filename = self
            .config
            .download(|progress| handler(create_progress(progress)))
            .await?;

        // `serde_json::from_reader` issues many small reads, so buffer the file.
        let reader = std::io::BufReader::new(std::fs::File::open(config_filename)?);
        let config: Config = serde_json::from_reader(reader)?;
        Ok((config.encoder, config.decoder))
    }
}
impl Default for OcrSource {
fn default() -> Self {
Self::base()
}
}
/// Input for one [`Ocr::recognize_text`] call.
pub struct OcrInferenceSettings {
    /// Owned RGBA copy of the image that text will be recognized in.
    image: ImageBuffer<image::Rgba<u8>, Vec<u8>>,
}
impl OcrInferenceSettings {
pub fn new<I: GenericImageView<Pixel = Rgba<u8>>>(input: I) -> anyhow::Result<Self> {
let mut image = ImageBuffer::new(input.width(), input.height());
image.copy_from(&input, 0, 0)?;
Ok(Self { image })
}
pub fn set_image<I: GenericImageView<Pixel = Rgba<u8>>>(
mut self,
image: I,
) -> anyhow::Result<Self> {
self.image = ImageBuffer::new(image.width(), image.height());
self.image.copy_from(&image, 0, 0)?;
Ok(self)
}
}
/// A loaded TrOCR model that recognizes text in images.
///
/// Create one through [`Ocr::builder`], then call [`Ocr::recognize_text`].
pub struct Ocr {
    /// Device the model was loaded on and on which input tensors are created.
    device: Device,
    /// The complete TrOCR model; despite the field name it provides both the image encoder
    /// (via `encoder()`) and the text decoder (via `decode()`).
    decoder: trocr::TrOCRModel,
    /// Decoder configuration; supplies the start and end-of-sequence token ids.
    decoder_config: trocr::TrOCRConfig,
    /// Converts input images into the tensor passed to the encoder.
    processor: image_processor::ViTImageProcessor,
    /// Tokenizer used to turn the generated token ids back into text.
    tokenizer_dec: Tokenizer,
}
impl Ocr {
    /// Upper bound on the number of decoding steps performed for one image.
    const MAX_DECODE_STEPS: usize = 1000;
    /// Fixed seed handed to the logits processor.
    const SAMPLING_SEED: u64 = 1337;

    /// Returns a builder with the default [`OcrSource`].
    pub fn builder() -> OcrBuilder {
        Default::default()
    }

    /// Loads the tokenizer, weights and configuration and assembles the model.
    ///
    /// The tokenizer is always fetched from the `ToluClassics/candle-trocr-tokenizer`
    /// repository, independent of the configured source. Download progress for the weights
    /// and the configuration is reported through `handler`.
    ///
    /// NOTE(review): the tokenizer is fetched through hf_hub's *sync* API inside this async
    /// fn, which presumably blocks the executor thread while downloading — confirm.
    async fn new(
        settings: OcrBuilder,
        mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> anyhow::Result<Self> {
        let source = settings.source;

        let tokenizer_path = Api::new()?
            .model("ToluClassics/candle-trocr-tokenizer".to_string())
            .get("tokenizer.json")?;
        let tokenizer_dec = Tokenizer::from_file(&tokenizer_path).map_err(|err| anyhow!(err))?;

        let device = accelerated_device_if_available()?;

        // Weights are fetched before the configuration; both report to the same handler.
        let weights = source.varbuilder(&device, &mut handler).await?;
        let (encoder_config, decoder_config) = source.config(&mut handler).await?;
        let decoder = trocr::TrOCRModel::new(&encoder_config, &decoder_config, weights)?;

        let processor =
            image_processor::ViTImageProcessor::new(&image_processor::ProcessorConfig::default());

        Ok(Self {
            device,
            decoder,
            decoder_config,
            processor,
            tokenizer_dec,
        })
    }

    /// Recognizes the text in the image carried by `settings`.
    ///
    /// The image is preprocessed and encoded once, then tokens are generated one at a time,
    /// starting from the decoder start token, until the end-of-sequence token is produced or
    /// the step limit is reached. The generated ids are decoded with special tokens skipped.
    ///
    /// NOTE(review): `decode` receives only the newest token plus a position offset after the
    /// first step, which suggests the model keeps an internal cache; nothing here resets it
    /// between calls — verify that repeated calls on one `Ocr` are independent.
    ///
    /// # Errors
    ///
    /// Fails if preprocessing, any model forward pass, sampling, or detokenization fails.
    pub fn recognize_text(&mut self, settings: OcrInferenceSettings) -> anyhow::Result<String> {
        let batch = vec![image::DynamicImage::ImageRgba8(settings.image)];
        let pixel_values = self.processor.preprocess(batch, &self.device)?;
        let encoder_xs = self.decoder.encoder().forward(&pixel_values)?;

        // No temperature and no top-p are configured.
        let mut sampler = candle_transformers::generation::LogitsProcessor::new(
            Self::SAMPLING_SEED,
            None,
            None,
        );
        let eos_token_id = self.decoder_config.eos_token_id;
        let mut token_ids: Vec<u32> = vec![self.decoder_config.decoder_start_token_id];

        for step in 0..Self::MAX_DECODE_STEPS {
            // The first step feeds every token gathered so far; later steps feed only the
            // most recent one.
            let start_pos = if step == 0 {
                0
            } else {
                token_ids.len().saturating_sub(1)
            };
            let input_ids = Tensor::new(&token_ids[start_pos..], &self.device)?.unsqueeze(0)?;

            let step_logits = self
                .decoder
                .decode(&input_ids, &encoder_xs, start_pos)?
                .squeeze(0)?;
            // Only the logits of the last position are sampled from.
            let last_row = step_logits.dim(0)? - 1;
            let next_token = sampler.sample(&step_logits.get(last_row)?)?;

            token_ids.push(next_token);
            if next_token == eos_token_id {
                break;
            }
        }

        self.tokenizer_dec
            .decode(&token_ids, true)
            .map_err(|err| anyhow!(err))
    }
}