use serde::{Deserialize, Serialize};
/// Errors returned by the OCR providers in this module.
///
/// Every variant carries a free-form detail string that is appended to the
/// variant's fixed prefix when the error is displayed.
#[derive(Debug, thiserror::Error)]
pub enum OcrError {
/// An HTTP request could not be sent, or a fetched body could not be read.
#[error("Network error: {0}")]
Network(String),
/// Credentials are missing from the environment or were rejected
/// (the providers map HTTP 401 and 403 to this variant).
#[error("Authentication error: {0}")]
Auth(String),
/// The remote API answered with HTTP 429.
#[error("Rate limit exceeded: {0}")]
RateLimit(String),
/// A response body or a base64 payload could not be decoded.
#[error("Parse error: {0}")]
Parse(String),
/// The request's input kind is not supported by the chosen provider.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// Any other non-success API response; also used by the Tesseract stub to
/// report that it is not implemented.
#[error("API error: {0}")]
Api(String),
}
/// The document or image handed to an OCR provider.
#[derive(Debug, Clone)]
pub enum OcrInput {
/// Raw bytes of a PDF document.
PdfBytes(Vec<u8>),
/// Raw bytes of an image. The providers in this file label them as
/// `image/png` when they build a data URI.
ImageBytes(Vec<u8>),
/// Location of a remote document or image.
Url(String),
/// Already base64-encoded content. The media type is not recorded, so each
/// provider picks its own assumption (PDF for Mistral, PNG for DeepSeek).
Base64(String),
}
/// A single OCR job: the input plus extraction options.
///
/// NOTE(review): of the options below, the providers visible in this file only
/// read `extract_images` (Mistral) and `languages` (copied into provenance).
/// `output_format`, `extract_tables` and `page_range` are stored but not
/// consulted here.
#[derive(Debug, Clone)]
pub struct OcrRequest {
/// Content to run OCR on.
pub input: OcrInput,
/// Requested output flavour; constructors default to `Markdown`.
pub output_format: OcrOutputFormat,
/// Language hints; constructors default to an empty list.
pub languages: Vec<String>,
/// Whether tables should be extracted; constructors default to `true`.
pub extract_tables: bool,
/// Whether embedded images should be returned; constructors default to `false`.
pub extract_images: bool,
/// Optional `(start, end)` page bounds.
/// NOTE(review): indexing base and inclusivity are not defined in this file - confirm.
pub page_range: Option<(usize, usize)>,
}
impl OcrRequest {
#[must_use]
pub fn from_pdf_bytes(bytes: Vec<u8>) -> Self {
Self {
input: OcrInput::PdfBytes(bytes),
output_format: OcrOutputFormat::Markdown,
languages: vec![],
extract_tables: true,
extract_images: false,
page_range: None,
}
}
#[must_use]
pub fn from_image_bytes(bytes: Vec<u8>) -> Self {
Self {
input: OcrInput::ImageBytes(bytes),
output_format: OcrOutputFormat::Markdown,
languages: vec![],
extract_tables: true,
extract_images: false,
page_range: None,
}
}
#[must_use]
pub fn from_url(url: impl Into<String>) -> Self {
Self {
input: OcrInput::Url(url.into()),
output_format: OcrOutputFormat::Markdown,
languages: vec![],
extract_tables: true,
extract_images: false,
page_range: None,
}
}
#[must_use]
pub fn with_output_format(mut self, format: OcrOutputFormat) -> Self {
self.output_format = format;
self
}
#[must_use]
pub fn with_languages(mut self, languages: Vec<String>) -> Self {
self.languages = languages;
self
}
#[must_use]
pub fn with_extract_tables(mut self, extract: bool) -> Self {
self.extract_tables = extract;
self
}
#[must_use]
pub fn with_extract_images(mut self, extract: bool) -> Self {
self.extract_images = extract;
self
}
#[must_use]
pub fn with_page_range(mut self, start: usize, end: usize) -> Self {
self.page_range = Some((start, end));
self
}
}
/// Output flavour a caller can request for the extracted text.
///
/// `Markdown` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OcrOutputFormat {
/// Plain text.
Text,
/// Markdown (default).
#[default]
Markdown,
/// HTML.
Html,
/// JSON.
Json,
}
/// A table found during OCR.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrTable {
/// Page the table belongs to. The Mistral provider stores the zero-based
/// position of the page in the API response.
pub page: usize,
/// Table content. The Mistral provider stores the whole page's markdown here.
pub content: String,
/// Optional bounding box.
/// NOTE(review): the meaning of the four components is not defined in this file - confirm.
pub bbox: Option<(f64, f64, f64, f64)>,
}
/// An image found inside an OCR'd document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrImage {
/// Page the image belongs to. The Mistral provider stores the zero-based
/// position of the page in the API response.
pub page: usize,
/// Optional textual description; no provider in this file fills it.
pub description: Option<String>,
/// Optional bounding box.
/// NOTE(review): the meaning of the four components is not defined in this file - confirm.
pub bbox: Option<(f64, f64, f64, f64)>,
/// Optional image payload. The Mistral provider copies the response's
/// `image_base64` value here.
pub data: Option<String>,
}
/// Records how an `OcrResult` was produced.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OcrProvenance {
/// Provider identifier, e.g. `"mistral-ocr"`.
pub provider: String,
/// Version string; the providers in this file store their model identifier here.
pub version: String,
/// Language hints; the providers in this file copy them from the request.
pub languages: Vec<String>,
/// Preprocessing applied before OCR; the providers in this file use the default value.
pub preprocessing: OcrPreprocessing,
/// Hex SHA-256 of the input when set through `with_trace_hashes`; otherwise `None`.
pub input_hash: Option<String>,
/// Hex SHA-256 of the output text when set through `with_trace_hashes`; otherwise `None`.
pub output_hash: Option<String>,
/// Free-form key/value annotations; an absent field deserializes as an empty map.
#[serde(default)]
pub metadata: std::collections::HashMap<String, String>,
}
/// Preprocessing settings recorded in provenance.
///
/// NOTE(review): no code in this file sets these fields to anything but their
/// defaults; the per-field descriptions below are inferred from the names and
/// from the matching `TesseractConfig` fields - confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OcrPreprocessing {
/// Resolution used for rasterisation, if known.
pub dpi: Option<u32>,
/// Whether the image was binarized.
pub binarized: bool,
/// Whether the image was deskewed.
pub deskewed: bool,
/// Whether the image was denoised.
pub denoised: bool,
/// Presumably the Tesseract page segmentation mode.
pub psm: Option<u32>,
/// Presumably the Tesseract OCR engine mode.
pub oem: Option<u32>,
}
/// Aggregate confidence statistics for an OCR result.
///
/// NOTE(review): no provider in this file populates this struct, so the value
/// scale (0-1 vs. 0-100) is not established here - confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OcrConfidence {
/// Mean confidence.
pub mean: f64,
/// Lowest confidence.
pub min: f64,
/// Highest confidence.
pub max: f64,
/// Standard deviation, when available.
pub std_dev: Option<f64>,
/// Count of low-confidence words (presumably those below `threshold`).
pub low_confidence_words: usize,
/// Threshold associated with `low_confidence_words`.
pub threshold: f64,
}
/// A recognised piece of text with its confidence and optional layout position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrSpan {
/// Recognised text.
pub text: String,
/// Confidence of the recognition; compared against a caller-supplied
/// threshold by `is_low_confidence`.
pub confidence: f64,
/// Page number; `OcrSpan::new` initialises it to 0.
pub page: usize,
/// Bounding box stored as `(x, y, w, h)` by `with_bbox`.
pub bbox: Option<(i32, i32, i32, i32)>,
/// Optional block number.
pub block_num: Option<i32>,
/// Optional paragraph number.
pub par_num: Option<i32>,
/// Optional line number.
pub line_num: Option<i32>,
/// Optional word number.
pub word_num: Option<i32>,
}
impl OcrSpan {
    /// Creates a span on page 0 with no bounding box and no layout numbers.
    #[must_use]
    pub fn new(text: impl Into<String>, confidence: f64) -> Self {
        let text = text.into();
        Self {
            text,
            confidence,
            page: 0,
            bbox: None,
            block_num: None,
            par_num: None,
            line_num: None,
            word_num: None,
        }
    }

    /// Sets the bounding box to `(x, y, w, h)`.
    #[must_use]
    pub fn with_bbox(self, x: i32, y: i32, w: i32, h: i32) -> Self {
        let bbox = Some((x, y, w, h));
        Self { bbox, ..self }
    }

    /// Sets the page number.
    #[must_use]
    pub fn with_page(self, page: usize) -> Self {
        Self { page, ..self }
    }

    /// Returns `true` when the span's confidence is strictly below `threshold`.
    #[must_use]
    pub fn is_low_confidence(&self, threshold: f64) -> bool {
        threshold > self.confidence
    }
}
/// Output format selectable in `TesseractConfig`.
///
/// `Text` is the `Default` variant. The Tesseract provider in this file is a
/// stub, so the choice currently has no effect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TesseractOutputFormat {
/// Plain text (default).
#[default]
Text,
/// TSV output.
Tsv,
/// hOCR output.
Hocr,
/// ALTO output.
Alto,
}
/// Outcome of a successful OCR extraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
/// Extracted text.
pub text: String,
/// Number of pages. Mistral reports the number of pages in its response;
/// the DeepSeek and LightOn providers always report 1.
pub pages: usize,
/// Recognised spans; an absent field deserializes as an empty list. No
/// provider in this file fills it.
#[serde(default)]
pub spans: Vec<OcrSpan>,
/// Tables found; only the Mistral provider fills this.
pub tables: Vec<OcrTable>,
/// Images found; only the Mistral provider fills this.
pub images: Vec<OcrImage>,
/// Confidence statistics; `None` from every provider in this file.
pub confidence: Option<OcrConfidence>,
/// Processing time in milliseconds; `None` from every provider in this file.
pub processing_time_ms: Option<u64>,
/// How the result was produced.
pub provenance: OcrProvenance,
}
/// Common interface of the OCR backends; implementors must be `Send + Sync`.
pub trait OcrProvider: Send + Sync {
/// Static identifier of the provider, e.g. `"mistral-ocr"`.
fn name(&self) -> &'static str;
/// Identifier of the model the provider is configured with.
fn model(&self) -> &str;
/// Runs OCR for `request` synchronously and returns the result, or an
/// `OcrError` describing the failure.
fn extract(&self, request: &OcrRequest) -> Result<OcrResult, OcrError>;
}
/// OCR provider that posts documents to the Mistral `/ocr` endpoint.
pub struct MistralOcrProvider {
/// API key, sent as a bearer token.
api_key: crate::secret::SecretString,
/// Model identifier sent in the request body.
model: String,
/// API root; `/ocr` is appended when sending a request.
base_url: String,
/// Blocking HTTP client used for all requests.
client: reqwest::blocking::Client,
}
impl MistralOcrProvider {
    /// Creates a provider for `model` that talks to `https://api.mistral.ai/v1`.
    #[must_use]
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = crate::secret::SecretString::new(api_key);
        let model = model.into();
        let base_url = String::from("https://api.mistral.ai/v1");
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            model,
            base_url,
            client,
        }
    }

    /// Creates a provider for the `mistral-ocr-latest` model using the key in
    /// the `MISTRAL_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns `OcrError::Auth` when the variable cannot be read.
    pub fn from_env() -> Result<Self, OcrError> {
        Self::from_env_with_model("mistral-ocr-latest")
    }

    /// Creates a provider for `model` using the key in the `MISTRAL_API_KEY`
    /// environment variable.
    ///
    /// # Errors
    /// Returns `OcrError::Auth` when the variable cannot be read.
    pub fn from_env_with_model(model: impl Into<String>) -> Result<Self, OcrError> {
        match std::env::var("MISTRAL_API_KEY") {
            Ok(api_key) => Ok(Self::new(api_key, model)),
            Err(_) => Err(OcrError::Auth(String::from(
                "MISTRAL_API_KEY environment variable not set",
            ))),
        }
    }

    /// Overrides the API root (for proxies or tests).
    #[must_use]
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }
}
impl OcrProvider for MistralOcrProvider {
    /// Provider identifier, also recorded in provenance.
    fn name(&self) -> &'static str {
        "mistral-ocr"
    }

    /// Model identifier sent with every request.
    fn model(&self) -> &str {
        &self.model
    }

    /// Posts the request's input to `{base_url}/ocr` and collects the per-page
    /// markdown returned by the API.
    ///
    /// Input mapping:
    /// - `PdfBytes`: sent as a `data:application/pdf;base64,...` `document_url`.
    /// - `ImageBytes`: sent as a `data:image/png;base64,...` `image_url`.
    /// - `Url`: sent as `document_url` when the URL's path ends in `.pdf`
    ///   (case-insensitive, query string and fragment ignored), otherwise as
    ///   `image_url`.
    /// - `Base64`: assumed to be a PDF and sent as `document_url`.
    ///
    /// Of the request options only `extract_images` (forwarded as
    /// `include_image_base64`) and `languages` (copied into provenance) are used.
    ///
    /// The resulting text is every page's markdown joined by blank lines and
    /// trimmed. A page whose markdown contains `<table` is additionally recorded
    /// as an `OcrTable` holding that page's full markdown.
    ///
    /// # Errors
    /// - `OcrError::Network` when the request cannot be sent.
    /// - `OcrError::Auth` on HTTP 401/403.
    /// - `OcrError::RateLimit` on HTTP 429.
    /// - `OcrError::Api` on any other non-success status.
    /// - `OcrError::Parse` when the response body is not the expected JSON.
    fn extract(&self, request: &OcrRequest) -> Result<OcrResult, OcrError> {
        let document = match &request.input {
            OcrInput::PdfBytes(bytes) => {
                serde_json::json!({
                    "type": "document_url",
                    "document_url": format!(
                        "data:application/pdf;base64,{}",
                        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes)
                    )
                })
            }
            OcrInput::ImageBytes(bytes) => {
                serde_json::json!({
                    "type": "image_url",
                    "image_url": format!(
                        "data:image/png;base64,{}",
                        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes)
                    )
                })
            }
            OcrInput::Url(url) => {
                if mistral_url_is_pdf(url) {
                    serde_json::json!({
                        "type": "document_url",
                        "document_url": url
                    })
                } else {
                    serde_json::json!({
                        "type": "image_url",
                        "image_url": url
                    })
                }
            }
            OcrInput::Base64(data) => {
                serde_json::json!({
                    "type": "document_url",
                    "document_url": format!("data:application/pdf;base64,{}", data)
                })
            }
        };
        let body = serde_json::json!({
            "model": self.model,
            "document": document,
            "include_image_base64": request.extract_images
        });
        let response = self
            .client
            .post(format!("{}/ocr", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key.expose()))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .map_err(|e| OcrError::Network(format!("Request failed: {e}")))?;
        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty text.
            let error_text = response.text().unwrap_or_default();
            return match status.as_u16() {
                401 | 403 => Err(OcrError::Auth(format!(
                    "Authentication failed: {error_text}"
                ))),
                429 => Err(OcrError::RateLimit("Rate limit exceeded".to_string())),
                _ => Err(OcrError::Api(format!("API error ({status}): {error_text}"))),
            };
        }
        let api_response: MistralOcrResponse = response
            .json()
            .map_err(|e| OcrError::Parse(format!("Failed to parse response: {e}")))?;
        let mut tables = vec![];
        let mut images = vec![];
        let mut text = String::new();
        // `page_idx` is the zero-based position of the page in the response.
        for (page_idx, page) in api_response.pages.iter().enumerate() {
            text.push_str(&page.markdown);
            text.push_str("\n\n");
            if page.markdown.contains("<table") {
                tables.push(OcrTable {
                    page: page_idx,
                    content: page.markdown.clone(),
                    bbox: None,
                });
            }
            for img in &page.images {
                images.push(OcrImage {
                    page: page_idx,
                    description: None,
                    bbox: None,
                    data: img.image_base64.clone(),
                });
            }
        }
        Ok(OcrResult {
            // Trimming removes the separator appended after the last page.
            text: text.trim().to_string(),
            pages: api_response.pages.len(),
            spans: vec![],
            tables,
            images,
            confidence: None,
            processing_time_ms: None,
            provenance: OcrProvenance {
                provider: "mistral-ocr".to_string(),
                version: self.model.clone(),
                languages: request.languages.clone(),
                preprocessing: OcrPreprocessing::default(),
                input_hash: None,
                output_hash: None,
                metadata: std::collections::HashMap::new(),
            },
        })
    }
}

/// Returns `true` when the path component of `url` has a `pdf` extension
/// (case-insensitive).
///
/// The query string and fragment are stripped first, so a signed URL such as
/// `https://host/doc.pdf?sig=abc` counts as a PDF, while
/// `https://host/img.png?file=doc.pdf` does not.
fn mistral_url_is_pdf(url: &str) -> bool {
    let path = url
        .split(|c: char| matches!(c, '?' | '#'))
        .next()
        .unwrap_or(url);
    std::path::Path::new(path)
        .extension()
        .is_some_and(|ext| ext.eq_ignore_ascii_case("pdf"))
}
/// The part of the Mistral OCR response body that this module reads.
#[derive(Debug, Deserialize)]
struct MistralOcrResponse {
/// Pages in the order returned by the API.
pages: Vec<MistralOcrPage>,
}
/// One page of a Mistral OCR response.
#[derive(Debug, Deserialize)]
struct MistralOcrPage {
/// Markdown text of the page.
markdown: String,
/// Images on the page; an absent field deserializes as an empty list.
#[serde(default)]
images: Vec<MistralOcrImage>,
}
/// One image entry of a Mistral OCR page.
#[derive(Debug, Deserialize)]
struct MistralOcrImage {
/// Base64 image payload; `None` when the field is absent or null.
#[serde(default)]
image_base64: Option<String>,
}
/// OCR provider that sends the document to a DeepSeek chat-completions endpoint.
pub struct DeepSeekOcrProvider {
/// API key, sent as a bearer token.
api_key: crate::secret::SecretString,
/// Model identifier sent in the request body.
model: String,
/// API root; `/chat/completions` is appended when sending a request.
base_url: String,
/// Blocking HTTP client used for all requests.
client: reqwest::blocking::Client,
}
impl DeepSeekOcrProvider {
    /// Creates a provider for `model` that talks to `https://api.deepseek.com/v1`.
    #[must_use]
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = crate::secret::SecretString::new(api_key);
        let model = model.into();
        let base_url = String::from("https://api.deepseek.com/v1");
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            model,
            base_url,
            client,
        }
    }

    /// Creates a provider for the `deepseek-ocr-2` model using the key in the
    /// `DEEPSEEK_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns `OcrError::Auth` when the variable cannot be read.
    pub fn from_env() -> Result<Self, OcrError> {
        std::env::var("DEEPSEEK_API_KEY")
            .map(|api_key| Self::new(api_key, "deepseek-ocr-2"))
            .map_err(|_| {
                OcrError::Auth(String::from(
                    "DEEPSEEK_API_KEY environment variable not set",
                ))
            })
    }

    /// Overrides the API root (for proxies or tests).
    #[must_use]
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }
}
impl OcrProvider for DeepSeekOcrProvider {
fn name(&self) -> &'static str {
"deepseek-ocr"
}
fn model(&self) -> &str {
&self.model
}
fn extract(&self, request: &OcrRequest) -> Result<OcrResult, OcrError> {
let image_content = match &request.input {
OcrInput::ImageBytes(bytes) => {
format!(
"data:image/png;base64,{}",
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes)
)
}
OcrInput::PdfBytes(bytes) => {
format!(
"data:application/pdf;base64,{}",
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes)
)
}
OcrInput::Url(url) => url.clone(),
OcrInput::Base64(data) => format!("data:image/png;base64,{data}"),
};
let body = serde_json::json!({
"model": self.model,
"messages": [{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_content
}
},
{
"type": "text",
"text": "Extract all text from this document, preserving structure, tables, and reading order. Output in markdown format."
}
]
}],
"max_tokens": 8192
});
let response = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key.expose()))
.header("Content-Type", "application/json")
.json(&body)
.send()
.map_err(|e| OcrError::Network(format!("Request failed: {e}")))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().unwrap_or_default();
return match status.as_u16() {
401 | 403 => Err(OcrError::Auth(format!(
"Authentication failed: {error_text}"
))),
429 => Err(OcrError::RateLimit("Rate limit exceeded".to_string())),
_ => Err(OcrError::Api(format!("API error ({status}): {error_text}"))),
};
}
let api_response: DeepSeekOcrResponse = response
.json()
.map_err(|e| OcrError::Parse(format!("Failed to parse response: {e}")))?;
let text = api_response
.choices
.first()
.and_then(|c| c.message.content.clone())
.unwrap_or_default();
Ok(OcrResult {
text,
pages: 1, spans: vec![], tables: vec![],
images: vec![],
confidence: None,
processing_time_ms: None,
provenance: OcrProvenance {
provider: "deepseek-ocr".to_string(),
version: self.model.clone(),
languages: request.languages.clone(),
preprocessing: OcrPreprocessing::default(),
input_hash: None,
output_hash: None,
metadata: std::collections::HashMap::new(),
},
})
}
}
/// The part of the DeepSeek chat-completions response that this module reads.
#[derive(Debug, Deserialize)]
struct DeepSeekOcrResponse {
/// Returned choices; only the first one is used.
choices: Vec<DeepSeekOcrChoice>,
}
/// One choice of a DeepSeek chat-completions response.
#[derive(Debug, Deserialize)]
struct DeepSeekOcrChoice {
/// Message carried by the choice.
message: DeepSeekOcrMessage,
}
/// Message body of a DeepSeek choice.
#[derive(Debug, Deserialize)]
struct DeepSeekOcrMessage {
/// Generated text; `None` when the field is absent or null.
content: Option<String>,
}
/// OCR provider that posts raw image bytes to a LightOnOCR model hosted behind
/// the Hugging Face inference API.
pub struct LightOnOcrProvider {
/// API key, sent as a bearer token.
api_key: crate::secret::SecretString,
/// Model identifier; appended to `base_url` to form the request URL.
model: String,
/// API root under which models are addressed.
base_url: String,
/// Blocking HTTP client used for all requests, including image downloads.
client: reqwest::blocking::Client,
}
impl LightOnOcrProvider {
    /// Creates a provider for `model` that talks to
    /// `https://api-inference.huggingface.co/models`.
    #[must_use]
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = crate::secret::SecretString::new(api_key);
        let model = model.into();
        let base_url = String::from("https://api-inference.huggingface.co/models");
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            model,
            base_url,
            client,
        }
    }

    /// Shared body of the `from_env*` constructors: reads `HUGGINGFACE_API_KEY`
    /// and builds a provider for `model`.
    fn from_env_for_model(model: &str) -> Result<Self, OcrError> {
        match std::env::var("HUGGINGFACE_API_KEY") {
            Ok(api_key) => Ok(Self::new(api_key, model)),
            Err(_) => Err(OcrError::Auth(String::from(
                "HUGGINGFACE_API_KEY environment variable not set",
            ))),
        }
    }

    /// Creates a provider for `lightonai/LightOnOCR-2-1B` using the key in the
    /// `HUGGINGFACE_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns `OcrError::Auth` when the variable cannot be read.
    pub fn from_env() -> Result<Self, OcrError> {
        Self::from_env_for_model("lightonai/LightOnOCR-2-1B")
    }

    /// Creates a provider for `lightonai/LightOnOCR-2-1B-bbox` using the key in
    /// the `HUGGINGFACE_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns `OcrError::Auth` when the variable cannot be read.
    pub fn from_env_with_bbox() -> Result<Self, OcrError> {
        Self::from_env_for_model("lightonai/LightOnOCR-2-1B-bbox")
    }

    /// Overrides the API root (for proxies or tests).
    #[must_use]
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }
}
impl OcrProvider for LightOnOcrProvider {
    /// Provider identifier, also recorded in provenance.
    fn name(&self) -> &'static str {
        "lighton-ocr"
    }

    /// Model identifier; it is also the last path segment of the request URL.
    fn model(&self) -> &str {
        &self.model
    }

    /// Posts the image bytes to `{base_url}/{model}` and returns the response
    /// body verbatim as the extracted text.
    ///
    /// Input handling:
    /// - `ImageBytes`: sent as-is.
    /// - `Url`: downloaded first with the provider's HTTP client; the download
    ///   must succeed with a 2xx status.
    /// - `Base64`: decoded with the standard alphabet.
    /// - `PdfBytes`: rejected, this provider only accepts images.
    ///
    /// The result always reports one page.
    ///
    /// # Errors
    /// - `OcrError::InvalidInput` for PDF input.
    /// - `OcrError::Network` when the image download or the OCR request cannot
    ///   be sent, the download answers with a non-success status, or its body
    ///   cannot be read.
    /// - `OcrError::Parse` for invalid base64 or an unreadable response body.
    /// - `OcrError::Auth` on HTTP 401/403, `OcrError::RateLimit` on 429,
    ///   `OcrError::Api` on 503 (model loading) and any other failure status.
    fn extract(&self, request: &OcrRequest) -> Result<OcrResult, OcrError> {
        let image_bytes = match &request.input {
            OcrInput::ImageBytes(bytes) => bytes.clone(),
            OcrInput::PdfBytes(_) => {
                return Err(OcrError::InvalidInput(
                    "LightOnOCR requires image input. Convert PDF pages to images first."
                        .to_string(),
                ));
            }
            OcrInput::Url(url) => {
                let response = self
                    .client
                    .get(url)
                    .send()
                    .map_err(|e| OcrError::Network(format!("Failed to fetch image: {e}")))?;
                // Without this check the body of a 404/500 error page would be
                // forwarded to the model as if it were the image.
                let fetch_status = response.status();
                if !fetch_status.is_success() {
                    return Err(OcrError::Network(format!(
                        "Failed to fetch image: HTTP {fetch_status}"
                    )));
                }
                response
                    .bytes()
                    .map_err(|e| OcrError::Network(format!("Failed to read image: {e}")))?
                    .to_vec()
            }
            OcrInput::Base64(data) => {
                base64::Engine::decode(&base64::engine::general_purpose::STANDARD, data)
                    .map_err(|e| OcrError::Parse(format!("Invalid base64: {e}")))?
            }
        };
        let response = self
            .client
            .post(format!("{}/{}", self.base_url, self.model))
            .header("Authorization", format!("Bearer {}", self.api_key.expose()))
            .header("Content-Type", "application/octet-stream")
            .body(image_bytes)
            .send()
            .map_err(|e| OcrError::Network(format!("Request failed: {e}")))?;
        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty text.
            let error_text = response.text().unwrap_or_default();
            return match status.as_u16() {
                401 | 403 => Err(OcrError::Auth(format!(
                    "Authentication failed: {error_text}"
                ))),
                429 => Err(OcrError::RateLimit("Rate limit exceeded".to_string())),
                503 => Err(OcrError::Api("Model is loading, please retry".to_string())),
                _ => Err(OcrError::Api(format!("API error ({status}): {error_text}"))),
            };
        }
        let text = response
            .text()
            .map_err(|e| OcrError::Parse(format!("Failed to read response: {e}")))?;
        Ok(OcrResult {
            text,
            pages: 1,
            spans: vec![],
            tables: vec![],
            images: vec![],
            confidence: None,
            processing_time_ms: None,
            provenance: OcrProvenance {
                provider: "lighton-ocr".to_string(),
                version: self.model.clone(),
                languages: request.languages.clone(),
                preprocessing: OcrPreprocessing::default(),
                input_hash: None,
                output_hash: None,
                metadata: std::collections::HashMap::new(),
            },
        })
    }
}
/// Settings for the (currently stubbed) local Tesseract provider.
///
/// The defaults quoted below come from this type's `Default` implementation.
#[derive(Debug, Clone)]
pub struct TesseractConfig {
/// Path or name of the Tesseract executable; defaults to `"tesseract"`.
pub binary_path: String,
/// Optional tessdata directory; defaults to `None`.
pub tessdata_path: Option<String>,
/// Language codes; defaults to `["eng"]`.
pub languages: Vec<String>,
/// Resolution in DPI; defaults to 300.
pub dpi: u32,
/// Presumably the Tesseract page segmentation mode; defaults to 3.
pub psm: u32,
/// Presumably the Tesseract OCR engine mode; defaults to 3.
pub oem: u32,
/// Output format; defaults to `Text`.
pub output_format: TesseractOutputFormat,
/// Whether preprocessing is enabled; defaults to `true`.
pub preprocess: bool,
/// Timeout in seconds; defaults to 60.
pub timeout_secs: u64,
}
impl Default for TesseractConfig {
    /// Default configuration: the `tesseract` binary on the search path,
    /// English, 300 DPI, PSM 3, OEM 3, plain-text output, preprocessing
    /// enabled and a 60 second timeout.
    fn default() -> Self {
        let binary_path = String::from("tesseract");
        let languages = vec![String::from("eng")];
        Self {
            binary_path,
            tessdata_path: None,
            languages,
            dpi: 300,
            psm: 3,
            oem: 3,
            output_format: TesseractOutputFormat::Text,
            preprocess: true,
            timeout_secs: 60,
        }
    }
}
impl TesseractConfig {
    /// Creates a configuration with the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the path or name of the Tesseract executable.
    #[must_use]
    pub fn with_binary_path(self, path: impl Into<String>) -> Self {
        Self {
            binary_path: path.into(),
            ..self
        }
    }

    /// Sets the tessdata directory.
    #[must_use]
    pub fn with_tessdata_path(self, path: impl Into<String>) -> Self {
        Self {
            tessdata_path: Some(path.into()),
            ..self
        }
    }

    /// Replaces the language list.
    #[must_use]
    pub fn with_languages(self, languages: Vec<impl Into<String>>) -> Self {
        let languages = languages.into_iter().map(Into::into).collect();
        Self { languages, ..self }
    }

    /// Sets the resolution in DPI.
    #[must_use]
    pub fn with_dpi(self, dpi: u32) -> Self {
        Self { dpi, ..self }
    }

    /// Sets the `psm` value.
    #[must_use]
    pub fn with_psm(self, psm: u32) -> Self {
        Self { psm, ..self }
    }

    /// Sets the `oem` value.
    #[must_use]
    pub fn with_oem(self, oem: u32) -> Self {
        Self { oem, ..self }
    }

    /// Enables or disables preprocessing.
    #[must_use]
    pub fn with_preprocess(self, preprocess: bool) -> Self {
        Self { preprocess, ..self }
    }

    /// Sets the timeout in seconds.
    #[must_use]
    pub fn with_timeout(self, secs: u64) -> Self {
        Self {
            timeout_secs: secs,
            ..self
        }
    }

    /// Sets the output format.
    #[must_use]
    pub fn with_output_format(self, format: TesseractOutputFormat) -> Self {
        Self {
            output_format: format,
            ..self
        }
    }
}
/// Placeholder provider for local Tesseract OCR; `extract` always fails with a
/// "not yet implemented" error.
#[derive(Debug)]
pub struct TesseractOcrProvider {
/// Configuration held by the provider; the stub does not act on it yet.
config: TesseractConfig,
}
impl TesseractOcrProvider {
#[must_use]
pub fn new() -> Self {
Self {
config: TesseractConfig::default(),
}
}
#[must_use]
pub fn with_config(config: TesseractConfig) -> Self {
Self { config }
}
#[must_use]
pub fn with_languages(mut self, languages: Vec<impl Into<String>>) -> Self {
self.config.languages = languages.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn with_dpi(mut self, dpi: u32) -> Self {
self.config.dpi = dpi;
self
}
pub fn check_availability(&self) -> Result<String, OcrError> {
Err(OcrError::Api(
"Tesseract provider not yet implemented. Enable the 'tesseract' feature.".to_string(),
))
}
#[must_use]
pub fn version(&self) -> Option<String> {
None }
}
impl Default for TesseractOcrProvider {
/// Equivalent to `TesseractOcrProvider::new()`.
fn default() -> Self {
Self::new()
}
}
impl OcrProvider for TesseractOcrProvider {
    /// Provider identifier.
    fn name(&self) -> &'static str {
        "tesseract"
    }

    /// Fixed placeholder model name of the stub.
    fn model(&self) -> &'static str {
        "tesseract-stub"
    }

    /// Always fails: local Tesseract support is not implemented yet.
    ///
    /// # Errors
    /// Returns `OcrError::Api` with a message pointing to the remote providers.
    fn extract(&self, _request: &OcrRequest) -> Result<OcrResult, OcrError> {
        const NOT_IMPLEMENTED: &str = "Tesseract OCR provider not yet implemented. \
            This is a placeholder for future local OCR support. \
            For now, use MistralOcrProvider, DeepSeekOcrProvider, or LightOnOcrProvider.";
        Err(OcrError::Api(String::from(NOT_IMPLEMENTED)))
    }
}
/// Returns the SHA-256 digest of `data` as a lowercase hexadecimal string.
#[must_use]
pub fn compute_hash(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};
    // One-shot hashing: same digest as feeding `data` to a fresh hasher.
    format!("{:x}", Sha256::digest(data))
}
#[must_use]
pub fn with_trace_hashes(
mut provenance: OcrProvenance,
input: &[u8],
output: &str,
) -> OcrProvenance {
provenance.input_hash = Some(compute_hash(input));
provenance.output_hash = Some(compute_hash(output.as_bytes()));
provenance
}
// Unit tests for the request builder and the default output format.
#[cfg(test)]
mod tests {
use super::*;
// Every builder setter must overwrite the value chosen by the constructor.
#[test]
fn test_ocr_request_builder() {
let request = OcrRequest::from_pdf_bytes(vec![1, 2, 3])
.with_output_format(OcrOutputFormat::Html)
.with_languages(vec!["en".to_string(), "de".to_string()])
.with_extract_tables(true)
.with_extract_images(true)
.with_page_range(0, 10);
assert_eq!(request.output_format, OcrOutputFormat::Html);
assert_eq!(request.languages, vec!["en", "de"]);
assert!(request.extract_tables);
assert!(request.extract_images);
assert_eq!(request.page_range, Some((0, 10)));
}
// `OcrOutputFormat::default()` must be the variant marked `#[default]`.
#[test]
fn test_ocr_output_format_default() {
let format = OcrOutputFormat::default();
assert_eq!(format, OcrOutputFormat::Markdown);
}
}