use crate::plugins::OcrBackend;
use crate::{KreuzbergError, Result};
use ahash::AHashMap;
use std::sync::Arc;
pub struct OcrBackendRegistry {
pub(super) backends: AHashMap<String, Arc<dyn OcrBackend>>,
}
impl OcrBackendRegistry {
#[tracing::instrument(name = "ocr_backend_registry_init")]
pub fn new() -> Self {
#[cfg(any(feature = "ocr", feature = "paddle-ocr", feature = "liter-llm"))]
let mut registry = Self {
backends: AHashMap::new(),
};
#[cfg(not(any(feature = "ocr", feature = "paddle-ocr", feature = "liter-llm")))]
let registry = Self {
backends: AHashMap::new(),
};
#[cfg(feature = "ocr")]
{
use crate::ocr::tesseract_backend::TesseractBackend;
tracing::info!("Initializing Tesseract OCR backend");
let backend = TesseractBackend::new().unwrap_or_else(|e| {
panic!(
"Tesseract OCR backend initialization failed: {e}. \
The 'ocr' feature is enabled but the backend cannot start. \
Check TESSDATA_PREFIX environment variable, tessdata file permissions, \
and cache directory writability."
);
});
registry.register(Arc::new(backend)).unwrap_or_else(|e| {
panic!(
"Failed to register Tesseract OCR backend: {e}. \
Check TESSDATA_PREFIX environment variable and tessdata file permissions."
);
});
tracing::info!("Tesseract OCR backend registered successfully");
}
#[cfg(feature = "paddle-ocr")]
{
use crate::paddle_ocr::PaddleOcrBackend;
tracing::info!("Initializing PaddleOCR backend");
let backend = PaddleOcrBackend::new().unwrap_or_else(|e| {
panic!(
"PaddleOCR backend initialization failed: {e}. \
The 'paddle-ocr' feature is enabled but the backend cannot start. \
Check ONNX Runtime availability and model files."
);
});
registry.register(Arc::new(backend)).unwrap_or_else(|e| {
panic!("Failed to register PaddleOCR backend: {e}.");
});
tracing::info!("PaddleOCR backend registered successfully");
}
#[cfg(feature = "liter-llm")]
{
use crate::llm::vlm_ocr::VlmOcrBackend;
tracing::info!("Registering VLM OCR backend");
registry.register(Arc::new(VlmOcrBackend)).unwrap_or_else(|e| {
tracing::warn!("Failed to register VLM OCR backend: {e}");
});
}
registry
}
pub fn new_empty() -> Self {
Self {
backends: AHashMap::new(),
}
}
#[tracing::instrument(skip(self, backend), fields(backend_name))]
pub fn register(&mut self, backend: Arc<dyn OcrBackend>) -> Result<()> {
let name = backend.name().to_string();
tracing::Span::current().record("backend_name", name.as_str());
super::validate_plugin_name(&name)?;
backend.initialize()?;
tracing::info!(backend = %name, "OCR backend registered");
self.backends.insert(name, backend);
Ok(())
}
#[tracing::instrument(skip(self), fields(registered_backends = ?self.backends.keys().collect::<Vec<_>>()))]
pub fn get(&self, name: &str) -> Result<Arc<dyn OcrBackend>> {
let canonical = match name {
"paddleocr" => "paddle-ocr",
_ => name,
};
self.backends.get(canonical).cloned().ok_or_else(|| {
tracing::error!(
backend = name,
available = ?self.backends.keys().collect::<Vec<_>>(),
"OCR backend not found in registry"
);
KreuzbergError::Plugin {
message: format!(
"OCR backend '{}' not registered. Available backends: {:?}",
name,
self.backends.keys().collect::<Vec<_>>()
),
plugin_name: name.to_string(),
}
})
}
pub fn get_for_language(&self, language: &str) -> Result<Arc<dyn OcrBackend>> {
self.backends
.values()
.find(|backend| backend.supports_language(language))
.cloned()
.ok_or_else(|| KreuzbergError::Plugin {
message: format!("No OCR backend supports language '{}'", language),
plugin_name: language.to_string(),
})
}
pub fn list(&self) -> Vec<String> {
self.backends.keys().cloned().collect()
}
pub fn remove(&mut self, name: &str) -> Result<()> {
if let Some(backend) = self.backends.remove(name) {
backend.shutdown()?;
}
Ok(())
}
pub fn shutdown_all(&mut self) -> Result<()> {
let names: Vec<_> = self.backends.keys().cloned().collect();
for name in names {
self.remove(&name)?;
}
Ok(())
}
pub fn reset_to_defaults(&mut self) -> Result<()> {
self.shutdown_all()?;
let fresh = Self::new();
for (_name, backend) in fresh.backends {
let _ = self.register(backend);
}
Ok(())
}
}
impl Default for OcrBackendRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::OcrConfig;
    use crate::plugins::{OcrBackend, Plugin};
    use crate::types::ExtractionResult;
    use async_trait::async_trait;
    use std::borrow::Cow;
    use std::sync::atomic::{AtomicUsize, Ordering};
    /// Minimal backend whose lifecycle always succeeds and which supports a fixed language list.
    struct MockOcrBackend {
        name: String,
        languages: Vec<String>,
    }
    impl Plugin for MockOcrBackend {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> String {
            "1.0.0".to_string()
        }
        fn initialize(&self) -> Result<()> {
            Ok(())
        }
        fn shutdown(&self) -> Result<()> {
            Ok(())
        }
    }
    #[async_trait]
    impl OcrBackend for MockOcrBackend {
        async fn process_image(&self, _: &[u8], _: &OcrConfig) -> Result<ExtractionResult> {
            Ok(ExtractionResult {
                content: "test".to_string(),
                mime_type: Cow::Borrowed("text/plain"),
                ..Default::default()
            })
        }
        fn supports_language(&self, lang: &str) -> bool {
            self.languages.iter().any(|l| l == lang)
        }
        fn backend_type(&self) -> crate::plugins::ocr::OcrBackendType {
            crate::plugins::ocr::OcrBackendType::Custom
        }
    }
    #[test]
    fn test_ocr_backend_registry() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(MockOcrBackend {
            name: "test-ocr".to_string(),
            languages: vec!["eng".to_string(), "deu".to_string()],
        });
        registry.register(backend).unwrap();
        let retrieved = registry.get("test-ocr").unwrap();
        assert_eq!(retrieved.name(), "test-ocr");
        let eng_backend = registry.get_for_language("eng").unwrap();
        assert_eq!(eng_backend.name(), "test-ocr");
        let names = registry.list();
        assert_eq!(names.len(), 1);
        assert!(names.contains(&"test-ocr".to_string()));
    }
    #[test]
    fn test_ocr_backend_registry_new_empty() {
        let registry = OcrBackendRegistry::new_empty();
        assert_eq!(registry.list().len(), 0);
    }
    #[test]
    fn test_ocr_backend_get_missing() {
        let registry = OcrBackendRegistry::new_empty();
        let result = registry.get("nonexistent");
        assert!(result.is_err());
    }
    #[test]
    fn test_ocr_backend_get_for_language_missing() {
        let registry = OcrBackendRegistry::new_empty();
        let result = registry.get_for_language("fra");
        assert!(result.is_err());
    }
    #[test]
    fn test_ocr_backend_remove() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(MockOcrBackend {
            name: "test-backend".to_string(),
            languages: vec!["eng".to_string()],
        });
        registry.register(backend).unwrap();
        registry.remove("test-backend").unwrap();
        assert_eq!(registry.list().len(), 0);
    }
    #[test]
    fn test_ocr_backend_remove_missing_is_noop() {
        let mut registry = OcrBackendRegistry::new_empty();
        assert!(registry.remove("nonexistent").is_ok());
        assert!(registry.list().is_empty());
    }
    #[test]
    fn test_ocr_backend_shutdown_all() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend1 = Arc::new(MockOcrBackend {
            name: "backend1".to_string(),
            languages: vec!["eng".to_string()],
        });
        let backend2 = Arc::new(MockOcrBackend {
            name: "backend2".to_string(),
            languages: vec!["deu".to_string()],
        });
        registry.register(backend1).unwrap();
        registry.register(backend2).unwrap();
        registry.shutdown_all().unwrap();
        assert_eq!(registry.list().len(), 0);
    }
    /// Backend that counts lifecycle calls so tests can check when `initialize` and
    /// `shutdown` are invoked by the registry.
    struct CountingOcrBackend {
        name: String,
        init_calls: AtomicUsize,
        shutdown_calls: AtomicUsize,
    }
    impl CountingOcrBackend {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                init_calls: AtomicUsize::new(0),
                shutdown_calls: AtomicUsize::new(0),
            }
        }
    }
    impl Plugin for CountingOcrBackend {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> String {
            "1.0.0".to_string()
        }
        fn initialize(&self) -> Result<()> {
            self.init_calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        fn shutdown(&self) -> Result<()> {
            self.shutdown_calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }
    #[async_trait]
    impl OcrBackend for CountingOcrBackend {
        async fn process_image(&self, _: &[u8], _: &OcrConfig) -> Result<ExtractionResult> {
            Ok(ExtractionResult {
                content: "test".to_string(),
                mime_type: Cow::Borrowed("text/plain"),
                ..Default::default()
            })
        }
        fn supports_language(&self, _lang: &str) -> bool {
            false
        }
        fn backend_type(&self) -> crate::plugins::ocr::OcrBackendType {
            crate::plugins::ocr::OcrBackendType::Custom
        }
    }
    #[test]
    fn test_ocr_backend_register_initializes_once_and_remove_shuts_down() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(CountingOcrBackend::new("counting-ocr"));
        registry.register(backend.clone()).unwrap();
        assert_eq!(backend.init_calls.load(Ordering::SeqCst), 1);
        assert_eq!(backend.shutdown_calls.load(Ordering::SeqCst), 0);
        registry.remove("counting-ocr").unwrap();
        assert_eq!(backend.shutdown_calls.load(Ordering::SeqCst), 1);
        assert!(registry.list().is_empty());
    }
    #[test]
    fn test_ocr_backend_shutdown_all_shuts_down_every_backend() {
        let mut registry = OcrBackendRegistry::new_empty();
        let first = Arc::new(CountingOcrBackend::new("counting-1"));
        let second = Arc::new(CountingOcrBackend::new("counting-2"));
        registry.register(first.clone()).unwrap();
        registry.register(second.clone()).unwrap();
        registry.shutdown_all().unwrap();
        assert_eq!(first.shutdown_calls.load(Ordering::SeqCst), 1);
        assert_eq!(second.shutdown_calls.load(Ordering::SeqCst), 1);
        assert!(registry.list().is_empty());
    }
    /// Backend whose `initialize` can be configured to fail.
    struct FailingOcrBackend {
        name: String,
        fail_on_init: bool,
    }
    impl Plugin for FailingOcrBackend {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> String {
            "1.0.0".to_string()
        }
        fn initialize(&self) -> Result<()> {
            if self.fail_on_init {
                Err(KreuzbergError::Plugin {
                    message: "Backend initialization failed".to_string(),
                    plugin_name: self.name.clone(),
                })
            } else {
                Ok(())
            }
        }
        fn shutdown(&self) -> Result<()> {
            Ok(())
        }
    }
    #[async_trait]
    impl OcrBackend for FailingOcrBackend {
        async fn process_image(&self, _: &[u8], _: &OcrConfig) -> Result<ExtractionResult> {
            Ok(ExtractionResult {
                content: "test".to_string(),
                mime_type: Cow::Borrowed("text/plain"),
                ..Default::default()
            })
        }
        fn supports_language(&self, _lang: &str) -> bool {
            false
        }
        fn backend_type(&self) -> crate::plugins::ocr::OcrBackendType {
            crate::plugins::ocr::OcrBackendType::Custom
        }
    }
    #[test]
    fn test_ocr_backend_initialization_failure_logs_error() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(FailingOcrBackend {
            name: "failing-ocr".to_string(),
            fail_on_init: true,
        });
        let result = registry.register(backend);
        assert!(result.is_err());
        assert_eq!(registry.list().len(), 0);
    }
    #[test]
    fn test_ocr_backend_invalid_name_empty_logs_warning() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(MockOcrBackend {
            name: "".to_string(),
            languages: vec!["eng".to_string()],
        });
        let result = registry.register(backend);
        assert!(matches!(result, Err(KreuzbergError::Validation { .. })));
    }
    #[test]
    fn test_ocr_backend_invalid_name_with_spaces_logs_warning() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(MockOcrBackend {
            name: "invalid ocr backend".to_string(),
            languages: vec!["eng".to_string()],
        });
        let result = registry.register(backend);
        assert!(matches!(result, Err(KreuzbergError::Validation { .. })));
    }
    #[test]
    fn test_ocr_backend_successful_registration_logs_debug() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(MockOcrBackend {
            name: "valid-ocr".to_string(),
            languages: vec!["eng".to_string()],
        });
        let result = registry.register(backend);
        assert!(result.is_ok());
        assert_eq!(registry.list().len(), 1);
    }
    #[test]
    fn test_ocr_backend_multiple_registrations() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend1 = Arc::new(MockOcrBackend {
            name: "ocr-backend-1".to_string(),
            languages: vec!["eng".to_string()],
        });
        let backend2 = Arc::new(MockOcrBackend {
            name: "ocr-backend-2".to_string(),
            languages: vec!["deu".to_string()],
        });
        registry.register(backend1).unwrap();
        registry.register(backend2).unwrap();
        assert_eq!(registry.list().len(), 2);
    }
    #[test]
    fn test_ocr_backend_reregister_replaces_existing() {
        let mut registry = OcrBackendRegistry::new_empty();
        let original = Arc::new(MockOcrBackend {
            name: "dup-ocr".to_string(),
            languages: vec!["eng".to_string()],
        });
        let replacement = Arc::new(MockOcrBackend {
            name: "dup-ocr".to_string(),
            languages: vec!["fra".to_string()],
        });
        registry.register(original).unwrap();
        registry.register(replacement).unwrap();
        assert_eq!(registry.list().len(), 1);
        assert!(registry.get_for_language("fra").is_ok());
        assert!(registry.get_for_language("eng").is_err());
    }
    #[test]
    fn test_ocr_backend_paddleocr_alias_resolves() {
        let mut registry = OcrBackendRegistry::new_empty();
        let backend = Arc::new(MockOcrBackend {
            name: "paddle-ocr".to_string(),
            languages: vec!["en".to_string()],
        });
        registry.register(backend).unwrap();
        let retrieved = registry.get("paddleocr").unwrap();
        assert_eq!(retrieved.name(), "paddle-ocr");
        let retrieved = registry.get("paddle-ocr").unwrap();
        assert_eq!(retrieved.name(), "paddle-ocr");
    }
}