use serde::{Deserialize, Serialize};
use std::path::PathBuf;
mod confidence;
mod decoder;
mod engine;
mod inference;
mod models;
pub use confidence::{aggregate_confidence, calculate_confidence, ConfidenceCalibrator};
pub use decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary};
pub use engine::{OcrEngine, OcrProcessor};
pub use inference::{DetectionResult, InferenceEngine, RecognitionResult};
pub use models::{ModelHandle, ModelRegistry};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrOptions {
pub detection_threshold: f32,
pub recognition_threshold: f32,
pub enable_math: bool,
pub decoder_type: DecoderType,
pub beam_width: usize,
pub batch_size: usize,
pub use_gpu: bool,
pub languages: Vec<String>,
}
impl Default for OcrOptions {
fn default() -> Self {
Self {
detection_threshold: 0.5,
recognition_threshold: 0.6,
enable_math: true,
decoder_type: DecoderType::BeamSearch,
beam_width: 5,
batch_size: 1,
use_gpu: false,
languages: vec!["en".to_string()],
}
}
}
/// Decoding strategy selected through `OcrOptions::decoder_type`.
///
/// Presumably each variant maps to one of the decoders re-exported from the
/// `decoder` module (`BeamSearchDecoder`, `GreedyDecoder`, `CTCDecoder`);
/// confirm the mapping in the engine.
// NOTE(review): variant order is kept as-is because index-based serde formats
// encode unit variants by position; reordering would change their encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DecoderType {
/// Beam-search decoding (the `OcrOptions` default).
BeamSearch,
/// Greedy decoding.
Greedy,
/// CTC decoding.
CTC,
}
/// Result of running OCR on one input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
/// Recognized text for the whole input.
pub text: String,
/// Overall confidence score.
// NOTE(review): presumably in [0, 1] and aggregated from the regions — confirm.
pub confidence: f32,
/// Individual regions found in the input.
pub regions: Vec<TextRegion>,
/// Whether math content was found.
// NOTE(review): presumably true when any region has `RegionType::Math` — confirm.
pub has_math: bool,
/// Elapsed processing time, in milliseconds (per the `_ms` suffix).
pub processing_time_ms: u64,
}
/// A single region of the input, together with its recognized content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegion {
/// Bounding box as four coordinates.
// NOTE(review): layout is not established here ([x, y, w, h] vs
// [x1, y1, x2, y2]) and units (pixels vs normalized) are unknown — confirm.
pub bbox: [f32; 4],
/// Recognized text within this region.
pub text: String,
/// Confidence score for this region.
// NOTE(review): presumably in [0, 1] — confirm.
pub confidence: f32,
/// Classification of the region's content.
pub region_type: RegionType,
/// Per-character breakdown; may be empty.
pub characters: Vec<Character>,
}
/// Kind of content a `TextRegion` contains.
// NOTE(review): variant order is kept because index-based serde formats
// encode unit variants by position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionType {
/// Plain text.
Text,
/// Mathematical content.
Math,
/// A diagram.
Diagram,
/// A table.
Table,
/// Content that could not be classified.
Unknown,
}
/// One recognized character within a `TextRegion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Character {
/// The recognized character.
pub char: char,
/// Confidence score for this character.
// NOTE(review): presumably in [0, 1] — confirm.
pub confidence: f32,
/// Optional bounding box; `None` when no per-character box is available.
// NOTE(review): presumably uses the same coordinate layout as `TextRegion::bbox` — confirm.
pub bbox: Option<[f32; 4]>,
}
/// Errors produced by this crate.
///
/// `Display` is generated by `thiserror` from the `#[error(...)]` attributes.
#[derive(Debug, thiserror::Error)]
pub enum OcrError {
/// A model could not be loaded.
#[error("Model loading error: {0}")]
ModelLoading(String),
/// Model inference failed.
#[error("Inference error: {0}")]
Inference(String),
/// An input image could not be processed.
#[error("Image processing error: {0}")]
ImageProcessing(String),
/// Decoder output could not be turned into text.
#[error("Decoding error: {0}")]
Decoding(String),
/// The supplied configuration is invalid.
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
/// Wrapped I/O error. `#[from]` generates `From<std::io::Error>`, so `?`
/// converts I/O errors automatically, and it also marks the field as the
/// error's `source()`.
// NOTE(review): the Display text also embeds the source error, so reporters
// that walk the `source()` chain may print the I/O message twice.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Error reported by ONNX Runtime, carried as a message string.
#[error("ONNX Runtime error: {0}")]
OnnxRuntime(String),
}
/// Crate-wide result type whose error defaults to `OcrError`.
///
/// The error parameter has a default rather than being fixed. `Result<T>`
/// still means `std::result::Result<T, OcrError>`, but code that
/// glob-imports this crate (`use ocr::*`) can also write `Result<T, OtherError>`
/// without the alias shadowing `std::result::Result`.
pub type Result<T, E = OcrError> = std::result::Result<T, E>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of `OcrOptions::default()`, so a change to any default is caught.
    #[test]
    fn test_ocr_options_default() {
        let options = OcrOptions::default();
        assert_eq!(options.detection_threshold, 0.5);
        assert_eq!(options.recognition_threshold, 0.6);
        assert!(options.enable_math);
        assert_eq!(options.decoder_type, DecoderType::BeamSearch);
        assert_eq!(options.beam_width, 5);
        assert_eq!(options.batch_size, 1);
        assert!(!options.use_gpu);
        assert_eq!(options.languages, vec!["en".to_string()]);
    }

    #[test]
    fn test_text_region_creation() {
        let region = TextRegion {
            bbox: [10.0, 20.0, 100.0, 30.0],
            text: "Test".to_string(),
            confidence: 0.95,
            region_type: RegionType::Text,
            characters: vec![],
        };
        assert_eq!(region.bbox[0], 10.0);
        assert_eq!(region.text, "Test");
        assert_eq!(region.region_type, RegionType::Text);
    }

    #[test]
    fn test_decoder_type_equality() {
        assert_eq!(DecoderType::BeamSearch, DecoderType::BeamSearch);
        assert_ne!(DecoderType::BeamSearch, DecoderType::Greedy);
        assert_ne!(DecoderType::Greedy, DecoderType::CTC);
    }

    /// Checks the generated `Display` text and the `#[from]` I/O conversion.
    #[test]
    fn test_error_display_and_io_conversion() {
        let err = OcrError::InvalidConfig("beam_width must be > 0".to_string());
        assert_eq!(err.to_string(), "Invalid configuration: beam_width must be > 0");

        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "missing model");
        let err: OcrError = io.into();
        assert!(matches!(err, OcrError::Io(_)));
        assert_eq!(err.to_string(), "I/O error: missing model");
    }
}