use super::{
confidence::aggregate_confidence,
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
inference::{DetectionResult, InferenceEngine, RecognitionResult},
models::{ModelHandle, ModelRegistry},
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info, warn};
/// Synchronous OCR interface.
///
/// Implementors must be `Send + Sync` so a processor can be shared between threads.
pub trait OcrProcessor: Send + Sync {
/// Runs OCR on the bytes of a single image using the supplied `options`.
///
/// Returns the recognition result, or an error if processing fails.
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;
/// Runs OCR on several images, applying the same `options` to each of them.
///
/// Returns the collected results, or an error if processing fails.
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
}
/// OCR engine that ties together model loading, inference and output decoding.
pub struct OcrEngine {
/// Lock-protected model registry; `registry()` hands out additional handles to it.
registry: Arc<RwLock<ModelRegistry>>,
/// Inference engine built from the loaded detection, recognition and optional math models.
inference: Arc<InferenceEngine>,
/// Options applied by `recognize` when the caller does not pass its own.
default_options: OcrOptions,
/// Vocabulary handed to the decoders when recognition output is turned into text.
vocabulary: Arc<Vocabulary>,
/// Flag set to `true` once `warmup` has run; read by `is_warmed_up`.
warmed_up: Arc<RwLock<bool>>,
}
impl OcrEngine {
/// Creates an engine configured with `OcrOptions::default()`.
///
/// # Errors
///
/// Returns whatever error `with_options` reports for the default options.
pub async fn new() -> Result<Self> {
    let defaults = OcrOptions::default();
    Self::with_options(defaults).await
}
/// Creates an engine configured with `options`.
///
/// Loads the detection and recognition models, plus the math model when
/// `options.enable_math` is set, and builds the inference engine from them.
///
/// # Errors
///
/// Returns `OcrError::ModelLoading` when any of the models fails to load, and propagates
/// the error from `InferenceEngine::new` if the inference engine cannot be constructed.
pub async fn with_options(options: OcrOptions) -> Result<Self> {
    info!("Initializing OCR engine with options: {:?}", options);

    // Nothing else can observe the registry until construction finishes, so the models are
    // loaded through a plain local binding and the registry is wrapped in the shared lock
    // afterwards. This way no lock guard is kept alive across an `.await`.
    let mut registry = ModelRegistry::new();

    debug!("Loading detection model...");
    let detection_model = registry.load_detection_model().await.map_err(|e| {
        OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
    })?;

    debug!("Loading recognition model...");
    let recognition_model = registry.load_recognition_model().await.map_err(|e| {
        OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
    })?;

    // The math model is optional and only loaded when math support is requested.
    let math_model = if options.enable_math {
        debug!("Loading math recognition model...");
        Some(registry.load_math_model().await.map_err(|e| {
            OcrError::ModelLoading(format!("Failed to load math model: {}", e))
        })?)
    } else {
        None
    };

    let inference = Arc::new(InferenceEngine::new(
        detection_model,
        recognition_model,
        math_model,
        options.use_gpu,
    )?);
    let vocabulary = Arc::new(Vocabulary::default());

    let engine = Self {
        registry: Arc::new(RwLock::new(registry)),
        inference,
        default_options: options,
        vocabulary,
        warmed_up: Arc::new(RwLock::new(false)),
    };
    info!("OCR engine initialized successfully");
    Ok(engine)
}
pub async fn warmup(&self) -> Result<()> {
if *self.warmed_up.read() {
debug!("Engine already warmed up, skipping");
return Ok(());
}
info!("Warming up OCR engine...");
let start = Instant::now();
let dummy_image = vec![0u8; 100 * 100 * 3];
let _ = self.recognize(&dummy_image).await;
*self.warmed_up.write() = true;
info!("Engine warmup completed in {:?}", start.elapsed());
Ok(())
}
/// Runs OCR on `image_data` using the engine's default options.
///
/// # Errors
///
/// Propagates any error returned by `recognize_with_options`.
pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
    let options = &self.default_options;
    self.recognize_with_options(image_data, options).await
}
/// Runs the full OCR pipeline on `image_data` with explicit `options`.
///
/// Steps: detect regions, recognize each region (with the math recognizer when
/// `options.enable_math` is set and the detection is flagged as likely math), decode the
/// recognition output to text, and drop regions whose aggregated confidence is below
/// `options.recognition_threshold`.
///
/// The returned text is the retained region texts joined by single spaces, the overall
/// confidence is the mean of the retained region confidences (`0.0` when none remain), and
/// `has_math` is `true` only when at least one retained region is a math region.
///
/// # Errors
///
/// Propagates errors from detection, recognition and decoding.
pub async fn recognize_with_options(
    &self,
    image_data: &[u8],
    options: &OcrOptions,
) -> Result<OcrResult> {
    let start = Instant::now();
    debug!("Starting OCR recognition");

    debug!("Running text detection...");
    let detection_results = self
        .inference
        .run_detection(image_data, options.detection_threshold)
        .await?;
    debug!("Detected {} regions", detection_results.len());

    if detection_results.is_empty() {
        warn!("No text regions detected");
        return Ok(OcrResult {
            text: String::new(),
            confidence: 0.0,
            regions: vec![],
            has_math: false,
            processing_time_ms: start.elapsed().as_millis() as u64,
        });
    }

    debug!("Running text recognition...");
    let mut text_regions = Vec::new();
    for detection in detection_results {
        let is_math = options.enable_math && detection.is_math_likely;
        let region_type = if is_math {
            RegionType::Math
        } else {
            RegionType::Text
        };

        let recognition = if is_math {
            self.inference
                .run_math_recognition(&detection.region_image, options)
                .await?
        } else {
            self.inference
                .run_recognition(&detection.region_image, options)
                .await?
        };

        let decoded_text = self.decode_output(&recognition, options)?;
        let confidence = aggregate_confidence(&recognition.character_confidences);
        if confidence < options.recognition_threshold {
            debug!(
                "Skipping region with low confidence: {:.2} < {:.2}",
                confidence, options.recognition_threshold
            );
            continue;
        }

        // Pairs each decoded character with a confidence value; `zip` stops at the shorter
        // of the two sequences.
        let characters = decoded_text
            .chars()
            .zip(recognition.character_confidences.iter())
            .map(|(ch, &conf)| Character {
                char: ch,
                confidence: conf,
                bbox: None,
            })
            .collect();

        text_regions.push(TextRegion {
            bbox: detection.bbox,
            text: decoded_text,
            confidence,
            region_type,
            characters,
        });
    }

    // Derived from the regions that survived the confidence filter, so the flag can never
    // claim math content that is absent from `regions`.
    let has_math = text_regions
        .iter()
        .any(|region| region.region_type == RegionType::Math);

    let combined_text = text_regions
        .iter()
        .map(|r| r.text.as_str())
        .collect::<Vec<_>>()
        .join(" ");

    let overall_confidence = if text_regions.is_empty() {
        0.0
    } else {
        text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
    };

    let processing_time_ms = start.elapsed().as_millis() as u64;
    debug!(
        "OCR completed in {}ms, recognized {} regions",
        processing_time_ms,
        text_regions.len()
    );

    Ok(OcrResult {
        text: combined_text,
        confidence: overall_confidence,
        regions: text_regions,
        has_math,
        processing_time_ms,
    })
}
/// Runs OCR on every image in `images`, in order, applying the same `options` to each.
///
/// Images are processed one after another by awaiting each recognition. No blocking
/// executor is started inside this future, so it is safe to drive from any executor,
/// including `futures::executor::block_on`.
///
/// # Errors
///
/// Stops at the first image that fails and returns its error; results of earlier images
/// are discarded.
pub async fn recognize_batch(
    &self,
    images: &[&[u8]],
    options: &OcrOptions,
) -> Result<Vec<OcrResult>> {
    info!("Processing batch of {} images", images.len());
    let start = Instant::now();

    let mut recognized = Vec::with_capacity(images.len());
    let mut failure = None;
    for &image_data in images {
        match self.recognize_with_options(image_data, options).await {
            Ok(result) => recognized.push(result),
            Err(err) => {
                failure = Some(err);
                break;
            }
        }
    }

    // Logged on both the success and the failure path.
    info!("Batch processing completed in {:?}", start.elapsed());
    match failure {
        Some(err) => Err(err),
        None => Ok(recognized),
    }
}
fn decode_output(
&self,
recognition: &RecognitionResult,
options: &OcrOptions,
) -> Result<String> {
debug!("Decoding output with {:?} decoder", options.decoder_type);
let decoded = match options.decoder_type {
DecoderType::BeamSearch => {
let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
decoder.decode(&recognition.logits)?
}
DecoderType::Greedy => {
let decoder = GreedyDecoder::new(self.vocabulary.clone());
decoder.decode(&recognition.logits)?
}
DecoderType::CTC => {
let decoder = CTCDecoder::new(self.vocabulary.clone());
decoder.decode(&recognition.logits)?
}
};
Ok(decoded)
}
/// Returns a new handle to the engine's lock-protected model registry.
pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
    self.registry.clone()
}
pub fn default_options(&self) -> &OcrOptions {
&self.default_options
}
/// Reports whether the warmed-up flag is currently set.
pub fn is_warmed_up(&self) -> bool {
    let flag = self.warmed_up.read();
    *flag
}
}
/// Synchronous facade over the engine's async API.
///
/// NOTE(review): both methods drive the future with `futures::executor::block_on`, so they
/// block the calling thread and are meant for synchronous callers; confirm they are never
/// invoked from inside another `futures` executor.
impl OcrProcessor for OcrEngine {
    /// Blocks until `recognize_with_options` finishes for `image_data` and returns its result.
    fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
        let job = self.recognize_with_options(image_data, options);
        futures::executor::block_on(job)
    }

    /// Blocks until `recognize_batch` finishes for `images` and returns its result.
    fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
        let job = self.recognize_batch(images, options);
        futures::executor::block_on(job)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicitly chosen decoder type is kept when the remaining fields use defaults.
    #[test]
    fn test_decoder_selection() {
        let configured = OcrOptions {
            decoder_type: DecoderType::BeamSearch,
            ..OcrOptions::default()
        };
        assert_eq!(configured.decoder_type, DecoderType::BeamSearch);
    }

    /// A shared boolean behind the lock starts out `false` and reads `true` after a write.
    #[test]
    fn test_warmup_flag() {
        let warmed = Arc::new(RwLock::new(false));
        assert!(!*warmed.read());
        {
            let mut guard = warmed.write();
            *guard = true;
        }
        assert!(*warmed.read());
    }
}