1#![warn(missing_docs)]
23#[cfg(feature = "mkl")]
24extern crate intel_mkl_src;
25
26#[cfg(feature = "accelerate")]
27extern crate accelerate_src;
28
29mod image_processor;
30
31use candle_core::DType;
32use candle_core::{Device, Tensor};
33use candle_nn::VarBuilder;
34use candle_transformers::models::trocr;
35use candle_transformers::models::vit;
36use hf_hub::api::sync::Api;
37use image::{GenericImage, GenericImageView, ImageBuffer, Rgba};
38use kalosm_common::*;
39use kalosm_model_types::{FileSource, ModelLoadingProgress};
40use tokenizers::Tokenizer;
41
/// A builder for an [`Ocr`] model.
#[derive(Default)]
pub struct OcrBuilder {
    // Where to fetch the weights/config from; `OcrSource`'s `Default` impl
    // resolves to `OcrSource::base()`.
    source: OcrSource,
}
47
impl OcrBuilder {
    /// Set the source the model weights and config are loaded from.
    pub fn with_source(mut self, source: OcrSource) -> Self {
        self.source = source;
        self
    }

    /// Build the [`Ocr`] model, ignoring loading progress.
    pub async fn build(self) -> Result<Ocr, LoadOcrError> {
        Ocr::new(self, |_| {}).await
    }

    /// Build the [`Ocr`] model, reporting download/loading progress to `handler`.
    pub async fn build_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Ocr, LoadOcrError> {
        Ocr::new(self, handler).await
    }
}
68
/// A source to fetch an OCR model's files from.
pub struct OcrSource {
    // The safetensors weights file.
    model: FileSource,
    // The JSON config file (parsed into encoder and decoder configs).
    config: FileSource,
}
74
75impl OcrSource {
76 pub fn new(model: FileSource, config: FileSource) -> Self {
78 Self { model, config }
79 }
80
81 pub fn base() -> Self {
83 Self::new(
84 FileSource::huggingface(
85 "microsoft/trocr-base-handwritten".to_string(),
86 "refs/pr/3".to_string(),
87 "model.safetensors".to_string(),
88 ),
89 FileSource::huggingface(
90 "microsoft/trocr-base-handwritten".to_string(),
91 "refs/pr/3".to_string(),
92 "config.json".to_string(),
93 ),
94 )
95 }
96
97 pub fn large() -> Self {
99 Self::new(
100 FileSource::huggingface(
101 "microsoft/trocr-large-handwritten".to_string(),
102 "refs/pr/6".to_string(),
103 "model.safetensors".to_string(),
104 ),
105 FileSource::huggingface(
106 "microsoft/trocr-large-handwritten".to_string(),
107 "refs/pr/6".to_string(),
108 "config.json".to_string(),
109 ),
110 )
111 }
112
113 pub fn base_printed() -> Self {
115 Self::new(
116 FileSource::huggingface(
117 "microsoft/trocr-base-printed".to_string(),
118 "refs/pr/7".to_string(),
119 "model.safetensors".to_string(),
120 ),
121 FileSource::huggingface(
122 "microsoft/trocr-base-printed".to_string(),
123 "refs/pr/7".to_string(),
124 "config.json".to_string(),
125 ),
126 )
127 }
128
129 pub fn large_printed() -> Self {
131 Self::new(
132 FileSource::huggingface(
133 "microsoft/trocr-large-printed".to_string(),
134 "main".to_string(),
135 "model.safetensors".to_string(),
136 ),
137 FileSource::huggingface(
138 "microsoft/trocr-large-printed".to_string(),
139 "main".to_string(),
140 "config.json".to_string(),
141 ),
142 )
143 }
144
145 async fn varbuilder(
146 &self,
147 device: &Device,
148 mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync,
149 ) -> Result<VarBuilder, LoadOcrError> {
150 let source = format!("Model ({})", self.model);
151 let mut create_progress = ModelLoadingProgress::downloading_progress(source);
152 let cache = Cache::default();
153 let filename = cache
154 .get(&self.model, |progress| handler(create_progress(progress)))
155 .await?;
156 Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, device)? })
157 }
158
159 async fn config(
160 &self,
161 mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync,
162 ) -> Result<(vit::Config, trocr::TrOCRConfig), LoadOcrError> {
163 #[derive(Debug, Clone, serde::Deserialize)]
164 struct Config {
165 encoder: vit::Config,
166 decoder: trocr::TrOCRConfig,
167 }
168
169 let (encoder_config, decoder_config) = {
170 let source = format!("Config ({})", self.model);
171 let mut create_progress = ModelLoadingProgress::downloading_progress(source);
172 let cache = Cache::default();
173 let config_filename = cache
174 .get(&self.config, |progress| handler(create_progress(progress)))
175 .await?;
176 let config: Config = serde_json::from_reader(
177 std::fs::File::open(config_filename)
178 .expect("FileSource::download should return a valid path"),
179 )
180 .map_err(LoadOcrError::LoadConfig)?;
181 (config.encoder, config.decoder)
182 };
183
184 Ok((encoder_config, decoder_config))
185 }
186}
187
188impl Default for OcrSource {
189 fn default() -> Self {
190 Self::base()
191 }
192}
193
/// Settings for a single OCR inference run.
pub struct OcrInferenceSettings {
    // The RGBA image to recognize text in.
    image: ImageBuffer<image::Rgba<u8>, Vec<u8>>,
}
198
199impl OcrInferenceSettings {
200 pub fn new<I: GenericImageView<Pixel = Rgba<u8>>>(input: I) -> Self {
202 let mut image = ImageBuffer::new(input.width(), input.height());
203 image.copy_from(&input, 0, 0).unwrap();
204 Self { image }
205 }
206}
207
/// An error that can occur while loading an [`Ocr`] model.
#[derive(Debug, thiserror::Error)]
pub enum LoadOcrError {
    /// Loading the model weights onto the device failed.
    #[error("Failed to load model into device: {0}")]
    LoadModel(#[from] candle_core::Error),
    /// Downloading the model or config file failed.
    #[error("Failed to download model: {0}")]
    DownloadModel(#[from] CacheError),
    /// Loading the decoder tokenizer failed.
    #[error("Failed to load tokenizer: {0}")]
    LoadTokenizer(tokenizers::Error),
    /// Parsing the model's `config.json` failed.
    #[error("Failed to load config: {0}")]
    LoadConfig(serde_json::Error),
}
224
/// An error that can occur while running OCR inference.
#[derive(Debug, thiserror::Error)]
pub enum OcrInferenceError {
    /// Running the model failed.
    #[error("Failed to run model: {0}")]
    RunModel(#[from] candle_core::Error),
    /// Decoding the generated token ids into text failed.
    #[error("Failed to decode: {0}")]
    Decode(tokenizers::Error),
}
235
/// A TrOCR-based text recognition (OCR) model: a ViT image encoder feeding an
/// autoregressive text decoder.
pub struct Ocr {
    // The device the model runs on.
    device: Device,
    // The full encoder/decoder TrOCR model.
    decoder: trocr::TrOCRModel,
    // Decoder config; supplies the start and EOS token ids used at inference.
    decoder_config: trocr::TrOCRConfig,
    // Preprocessor that converts images into model input tensors.
    processor: image_processor::ViTImageProcessor,
    // Tokenizer used to decode generated token ids back into text.
    tokenizer_dec: Tokenizer,
}
244
245impl Ocr {
246 pub fn builder() -> OcrBuilder {
248 OcrBuilder::default()
249 }
250
251 async fn new(
252 settings: OcrBuilder,
253 mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
254 ) -> Result<Self, LoadOcrError> {
255 let OcrBuilder { source } = settings;
256 let tokenizer_dec = {
257 let tokenizer = Api::new()
258 .map_err(CacheError::HuggingFaceApi)?
259 .model(String::from("ToluClassics/candle-trocr-tokenizer"))
260 .get("tokenizer.json")
261 .map_err(CacheError::HuggingFaceApi)?;
262
263 Tokenizer::from_file(&tokenizer).map_err(LoadOcrError::LoadTokenizer)?
264 };
265 let device = accelerated_device_if_available()?;
266
267 let vb = source.varbuilder(&device, &mut handler).await?;
268
269 let (encoder_config, decoder_config) = source.config(&mut handler).await?;
270
271 let model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
272
273 let config = image_processor::ProcessorConfig::default();
274 let processor = image_processor::ViTImageProcessor::new(&config);
275
276 Ok(Self {
277 device,
278 decoder: model,
279 processor,
280 decoder_config,
281 tokenizer_dec,
282 })
283 }
284
285 pub fn recognize_text(
303 &mut self,
304 settings: OcrInferenceSettings,
305 ) -> Result<String, OcrInferenceError> {
306 let OcrInferenceSettings { image } = settings;
307
308 let image = image::DynamicImage::ImageRgba8(image);
309
310 let image = vec![image];
311 let image = self.processor.preprocess(image, &self.device)?;
312
313 let encoder_xs = self.decoder.encoder().forward(&image)?;
314
315 let mut logits_processor =
316 candle_transformers::generation::LogitsProcessor::new(1337, None, None);
317
318 let mut token_ids: Vec<u32> = vec![self.decoder_config.decoder_start_token_id];
319 for index in 0..1000 {
320 let context_size = if index >= 1 { 1 } else { token_ids.len() };
321 let start_pos = token_ids.len().saturating_sub(context_size);
322 let input_ids = Tensor::new(&token_ids[start_pos..], &self.device)?.unsqueeze(0)?;
323
324 let logits = self.decoder.decode(&input_ids, &encoder_xs, start_pos)?;
325
326 let logits = logits.squeeze(0)?;
327 let logits = logits.get(logits.dim(0)? - 1)?;
328 let token = logits_processor.sample(&logits)?;
329 token_ids.push(token);
330
331 if token == self.decoder_config.eos_token_id {
332 break;
333 }
334 }
335
336 let decoded = self
337 .tokenizer_dec
338 .decode(&token_ids, true)
339 .map_err(OcrInferenceError::Decode)?;
340
341 Ok(decoded)
342 }
343}