// kalosm_ocr/lib.rs

//! # Kalosm OCR
//!
//! A rust wrapper for [TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)
//!
//! ## Usage
//!
//! ```rust, no_run
//! # #[tokio::main]
//! # async fn main() {
//! use kalosm_ocr::*;
//!
//! let mut model = Ocr::builder().build().await.unwrap();
//! let image = image::open("examples/ocr.png").unwrap();
//! let text = model
//!     .recognize_text(OcrInferenceSettings::new(image))
//!     .unwrap();
//!
//! println!("{}", text);
//! # }
//! ```

#![warn(missing_docs)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod image_processor;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{trocr, vit};
use hf_hub::api::sync::Api;
use image::{GenericImage, GenericImageView, ImageBuffer, Rgba};
use kalosm_common::*;
use kalosm_model_types::{FileSource, ModelLoadingProgress};
use tokenizers::Tokenizer;
41
/// A builder for [`Ocr`].
#[derive(Default)]
pub struct OcrBuilder {
    // Where to fetch the model weights/config from. `Default` yields
    // `OcrSource::base()` (the base handwritten TrOCR model).
    source: OcrSource,
}
47
48impl OcrBuilder {
49    /// Sets the source of the model.
50    pub fn with_source(mut self, source: OcrSource) -> Self {
51        self.source = source;
52        self
53    }
54
55    /// Builds the [`Ocr`] model.
56    pub async fn build(self) -> Result<Ocr, LoadOcrError> {
57        Ocr::new(self, |_| {}).await
58    }
59
60    /// Builds the [`Ocr`] model.
61    pub async fn build_with_loading_handler(
62        self,
63        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
64    ) -> Result<Ocr, LoadOcrError> {
65        Ocr::new(self, handler).await
66    }
67}
68
/// The source of the model.
pub struct OcrSource {
    // File holding the model weights (a `.safetensors` file).
    model: FileSource,
    // JSON config describing the encoder/decoder architecture.
    config: FileSource,
}
74
75impl OcrSource {
76    /// Creates a new [`OcrSource`].
77    pub fn new(model: FileSource, config: FileSource) -> Self {
78        Self { model, config }
79    }
80
81    /// Create the base model source.
82    pub fn base() -> Self {
83        Self::new(
84            FileSource::huggingface(
85                "microsoft/trocr-base-handwritten".to_string(),
86                "refs/pr/3".to_string(),
87                "model.safetensors".to_string(),
88            ),
89            FileSource::huggingface(
90                "microsoft/trocr-base-handwritten".to_string(),
91                "refs/pr/3".to_string(),
92                "config.json".to_string(),
93            ),
94        )
95    }
96
97    /// Create a normal sized model source.
98    pub fn large() -> Self {
99        Self::new(
100            FileSource::huggingface(
101                "microsoft/trocr-large-handwritten".to_string(),
102                "refs/pr/6".to_string(),
103                "model.safetensors".to_string(),
104            ),
105            FileSource::huggingface(
106                "microsoft/trocr-large-handwritten".to_string(),
107                "refs/pr/6".to_string(),
108                "config.json".to_string(),
109            ),
110        )
111    }
112
113    /// Create a base printed model source.
114    pub fn base_printed() -> Self {
115        Self::new(
116            FileSource::huggingface(
117                "microsoft/trocr-base-printed".to_string(),
118                "refs/pr/7".to_string(),
119                "model.safetensors".to_string(),
120            ),
121            FileSource::huggingface(
122                "microsoft/trocr-base-printed".to_string(),
123                "refs/pr/7".to_string(),
124                "config.json".to_string(),
125            ),
126        )
127    }
128
129    /// Create a large printed model source.
130    pub fn large_printed() -> Self {
131        Self::new(
132            FileSource::huggingface(
133                "microsoft/trocr-large-printed".to_string(),
134                "main".to_string(),
135                "model.safetensors".to_string(),
136            ),
137            FileSource::huggingface(
138                "microsoft/trocr-large-printed".to_string(),
139                "main".to_string(),
140                "config.json".to_string(),
141            ),
142        )
143    }
144
145    async fn varbuilder(
146        &self,
147        device: &Device,
148        mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync,
149    ) -> Result<VarBuilder, LoadOcrError> {
150        let source = format!("Model ({})", self.model);
151        let mut create_progress = ModelLoadingProgress::downloading_progress(source);
152        let cache = Cache::default();
153        let filename = cache
154            .get(&self.model, |progress| handler(create_progress(progress)))
155            .await?;
156        Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, device)? })
157    }
158
159    async fn config(
160        &self,
161        mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync,
162    ) -> Result<(vit::Config, trocr::TrOCRConfig), LoadOcrError> {
163        #[derive(Debug, Clone, serde::Deserialize)]
164        struct Config {
165            encoder: vit::Config,
166            decoder: trocr::TrOCRConfig,
167        }
168
169        let (encoder_config, decoder_config) = {
170            let source = format!("Config ({})", self.model);
171            let mut create_progress = ModelLoadingProgress::downloading_progress(source);
172            let cache = Cache::default();
173            let config_filename = cache
174                .get(&self.config, |progress| handler(create_progress(progress)))
175                .await?;
176            let config: Config = serde_json::from_reader(
177                std::fs::File::open(config_filename)
178                    .expect("FileSource::download should return a valid path"),
179            )
180            .map_err(LoadOcrError::LoadConfig)?;
181            (config.encoder, config.decoder)
182        };
183
184        Ok((encoder_config, decoder_config))
185    }
186}
187
188impl Default for OcrSource {
189    fn default() -> Self {
190        Self::base()
191    }
192}
193
/// Settings for running inference on [`Ocr`].
pub struct OcrInferenceSettings {
    // Owned RGBA copy of the image to run OCR on.
    image: ImageBuffer<image::Rgba<u8>, Vec<u8>>,
}
198
199impl OcrInferenceSettings {
200    /// Creates a new [`OcrInferenceSettings`] from an image.
201    pub fn new<I: GenericImageView<Pixel = Rgba<u8>>>(input: I) -> Self {
202        let mut image = ImageBuffer::new(input.width(), input.height());
203        image.copy_from(&input, 0, 0).unwrap();
204        Self { image }
205    }
206}
207
/// An error that can occur when loading an [`Ocr`] model.
#[derive(Debug, thiserror::Error)]
pub enum LoadOcrError {
    /// An error that can occur when trying to load an [`Ocr`] model into a device.
    #[error("Failed to load model into device: {0}")]
    LoadModel(#[from] candle_core::Error),
    /// An error that can occur when downloading an [`Ocr`] model from the cache.
    #[error("Failed to download model: {0}")]
    DownloadModel(#[from] CacheError),
    /// An error that can occur when loading the tokenizer.
    // No `#[from]`: tokenizers::Error is also used by Decode in
    // OcrInferenceError, so conversion is done explicitly at the call site.
    #[error("Failed to load tokenizer: {0}")]
    LoadTokenizer(tokenizers::Error),
    /// An error that can occur when loading (parsing) the JSON config.
    #[error("Failed to load config: {0}")]
    LoadConfig(serde_json::Error),
}
224
/// An error that can occur when running an [`Ocr`] model.
#[derive(Debug, thiserror::Error)]
pub enum OcrInferenceError {
    /// An error that can occur when trying to run an [`Ocr`] model.
    #[error("Failed to run model: {0}")]
    RunModel(#[from] candle_core::Error),
    /// An error that can occur when decoding token ids produced by an [`Ocr`]
    /// model back into text.
    #[error("Failed to decode: {0}")]
    Decode(tokenizers::Error),
}
235
/// The [TrOCR](https://www.microsoft.com/en-us/research/publication/trocr-transformer-based-optical-character-recognition-with-pre-trained-models/) optical character recognition model.
pub struct Ocr {
    // Device the input tensors are created on.
    device: Device,
    // Full encoder/decoder model; the encoder is reached through it via
    // `decoder.encoder()`.
    decoder: trocr::TrOCRModel,
    // Decoder config; supplies the decoder-start and EOS token ids used
    // during generation.
    decoder_config: trocr::TrOCRConfig,
    // Converts input images into normalized model tensors.
    processor: image_processor::ViTImageProcessor,
    // Tokenizer used to turn generated token ids back into a string.
    tokenizer_dec: Tokenizer,
}
244
impl Ocr {
    /// Creates a new [`OcrBuilder`].
    pub fn builder() -> OcrBuilder {
        OcrBuilder::default()
    }

    // Loads the tokenizer, weights, and config (reporting progress through
    // `handler`), then assembles the full OCR pipeline.
    async fn new(
        settings: OcrBuilder,
        mut handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Self, LoadOcrError> {
        let OcrBuilder { source } = settings;
        let tokenizer_dec = {
            // The tokenizer is always fetched from this fixed repository,
            // regardless of which model `source` points at.
            let tokenizer = Api::new()
                .map_err(CacheError::HuggingFaceApi)?
                .model(String::from("ToluClassics/candle-trocr-tokenizer"))
                .get("tokenizer.json")
                .map_err(CacheError::HuggingFaceApi)?;

            Tokenizer::from_file(&tokenizer).map_err(LoadOcrError::LoadTokenizer)?
        };
        let device = accelerated_device_if_available()?;

        let vb = source.varbuilder(&device, &mut handler).await?;

        let (encoder_config, decoder_config) = source.config(&mut handler).await?;

        let model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

        let config = image_processor::ProcessorConfig::default();
        let processor = image_processor::ViTImageProcessor::new(&config);

        Ok(Self {
            device,
            decoder: model,
            processor,
            decoder_config,
            tokenizer_dec,
        })
    }

    /// Recognize text from an image. Returns the recognized text.
    ///
    /// # Example
    /// ```rust, no_run
    /// # #[tokio::main]
    /// # async fn main() {
    /// use kalosm_ocr::*;
    ///
    /// let mut model = Ocr::builder().build().await.unwrap();
    /// let image = image::open("examples/ocr.png").unwrap();
    /// let text = model
    ///     .recognize_text(OcrInferenceSettings::new(image))
    ///     .unwrap();
    ///
    /// println!("{}", text);
    /// # }
    /// ```
    pub fn recognize_text(
        &mut self,
        settings: OcrInferenceSettings,
    ) -> Result<String, OcrInferenceError> {
        let OcrInferenceSettings { image } = settings;

        let image = image::DynamicImage::ImageRgba8(image);

        // Preprocess a single-image batch into a tensor on the model device.
        let image = vec![image];
        let image = self.processor.preprocess(image, &self.device)?;

        // Run the ViT encoder once; every decode step attends to this output.
        let encoder_xs = self.decoder.encoder().forward(&image)?;

        // Fixed seed with no temperature/top-p, so decoding is deterministic.
        let mut logits_processor =
            candle_transformers::generation::LogitsProcessor::new(1337, None, None);

        // Autoregressive decode starting from the decoder-start token, capped
        // at 1000 generated tokens.
        // NOTE(review): the decoder's KV cache is not reset here — confirm
        // that calling `recognize_text` more than once on the same `Ocr`
        // instance produces correct results.
        let mut token_ids: Vec<u32> = vec![self.decoder_config.decoder_start_token_id];
        for index in 0..1000 {
            // First step feeds the whole prompt (just the start token); later
            // steps feed only the newest token, relying on the decoder's
            // cached past key/values for earlier positions.
            let context_size = if index >= 1 { 1 } else { token_ids.len() };
            let start_pos = token_ids.len().saturating_sub(context_size);
            let input_ids = Tensor::new(&token_ids[start_pos..], &self.device)?.unsqueeze(0)?;

            let logits = self.decoder.decode(&input_ids, &encoder_xs, start_pos)?;

            // Keep only the logits for the last position, then sample.
            let logits = logits.squeeze(0)?;
            let logits = logits.get(logits.dim(0)? - 1)?;
            let token = logits_processor.sample(&logits)?;
            token_ids.push(token);

            if token == self.decoder_config.eos_token_id {
                break;
            }
        }

        // `true` skips special tokens (start/EOS) in the decoded string.
        let decoded = self
            .tokenizer_dec
            .decode(&token_ids, true)
            .map_err(OcrInferenceError::Decode)?;

        Ok(decoded)
    }
}