use ahash::AHashMap;
use async_trait::async_trait;
use std::borrow::Cow;
use std::collections::HashMap;
use std::panic::catch_unwind;
use std::path::Path;
use std::sync::{Arc, Mutex};
use crate::Result;
use crate::core::config::OcrConfig;
use crate::ocr::conversion::{elements_to_hocr_words, text_block_to_element};
use crate::plugins::{OcrBackend, OcrBackendType, Plugin};
use crate::types::{ExtractionResult, FormatMetadata, Metadata, OcrElement, OcrMetadata, Table};
use html_to_markdown_rs::hocr::{reconstruct_table, table_to_markdown};
use super::config::PaddleOcrConfig;
use super::model_manager::{ModelManager, SharedModelPaths};
use super::{is_language_supported, language_to_script_family, map_language_code};
use kreuzberg_paddle_ocr::OcrLite;
/// PaddleOCR-based [`OcrBackend`] implementation.
///
/// Construction is cheap: model paths, ONNX engines and the document-orientation
/// detector are resolved lazily on first use and then cached on the instance.
pub struct PaddleOcrBackend {
/// Backend-wide configuration; its `model_tier` selects which models are loaded.
config: Arc<PaddleOcrConfig>,
/// Resolves shared (detection/classification) and per-family recognition model
/// locations under the configured cache directory.
model_manager: ModelManager,
/// Detection/classification model paths, resolved once and reused by every engine.
shared_paths: Mutex<Option<SharedModelPaths>>,
/// Initialised engines keyed by `"{model_tier}/{model_key}"`, so script families
/// that resolve to the same recognition model share one engine.
engine_pool: Mutex<HashMap<String, Arc<OcrLite>>>,
/// Orientation detector, created on the first call that has `auto_rotate` enabled.
doc_ori_detector: once_cell::sync::OnceCell<crate::doc_orientation::DocOrientationDetector>,
}
impl PaddleOcrBackend {
/// Creates a backend configured with [`PaddleOcrConfig::default`].
///
/// # Errors
///
/// Propagates any error returned by [`Self::with_config`].
pub fn new() -> Result<Self> {
    let default_config = PaddleOcrConfig::default();
    Self::with_config(default_config)
}
/// Builds a backend around `config`.
///
/// No models are touched here. The model cache directory comes from
/// [`PaddleOcrConfig::resolve_cache_dir`], and the shared model paths, engine
/// pool and orientation detector all start empty and are filled on demand.
///
/// # Errors
///
/// The current body never fails; the `Result` keeps the constructor signature stable.
pub fn with_config(config: PaddleOcrConfig) -> Result<Self> {
    let cache_root = config.resolve_cache_dir();
    let shared_config = Arc::new(config);
    let backend = Self {
        config: shared_config,
        model_manager: ModelManager::new(cache_root),
        shared_paths: Mutex::default(),
        engine_pool: Mutex::default(),
        doc_ori_detector: Default::default(),
    };
    Ok(backend)
}
/// Returns the detection/classification model paths for the configured tier.
///
/// The first call resolves them through `ModelManager::ensure_shared_models` and
/// caches the result; later calls return a clone of the cached value. The lock is
/// held while resolving, so concurrent first callers do not resolve twice.
///
/// # Errors
///
/// Fails if the lock is poisoned or if model resolution fails.
fn get_or_init_shared_paths(&self) -> Result<SharedModelPaths> {
    let mut slot = self.shared_paths.lock().map_err(|e| crate::KreuzbergError::Plugin {
        message: format!("Failed to acquire shared paths lock: {e}"),
        plugin_name: "paddle-ocr".to_string(),
    })?;
    if let Some(cached) = slot.as_ref() {
        return Ok(cached.clone());
    }
    let resolved = self.model_manager.ensure_shared_models(&self.config.model_tier)?;
    Ok(slot.insert(resolved).clone())
}
fn get_or_init_engine_for_family(&self, family: &str) -> Result<Arc<OcrLite>> {
let tier = &self.config.model_tier;
let resolved = self.model_manager.resolve_rec_model(family, tier)?;
let pool_key = format!("{tier}/{}", resolved.model_key);
{
let pool = self.engine_pool.lock().map_err(|e| crate::KreuzbergError::Plugin {
message: format!("Failed to acquire engine pool lock: {e}"),
plugin_name: "paddle-ocr".to_string(),
})?;
if let Some(engine) = pool.get(&pool_key) {
return Ok(Arc::clone(engine));
}
}
let shared = self.get_or_init_shared_paths()?;
crate::ort_discovery::ensure_ort_available();
tracing::info!(family, model_key = %resolved.model_key, tier, "Initializing PaddleOCR engine");
let mut ocr_lite = OcrLite::new();
let det_model_path = Self::find_onnx_model(&shared.det_model)?;
let cls_model_path = Self::find_onnx_model(&shared.cls_model)?;
let rec_model_path = Self::find_onnx_model(&resolved.model_dir)?;
let num_threads = crate::core::config::concurrency::resolve_thread_budget(None).min(4);
let dict_path = resolved.dict_file.to_str().ok_or_else(|| crate::KreuzbergError::Ocr {
message: "Invalid dictionary file path".to_string(),
source: None,
})?;
ocr_lite
.init_models_with_dict(
det_model_path.to_str().ok_or_else(|| crate::KreuzbergError::Ocr {
message: "Invalid detection model path".to_string(),
source: None,
})?,
cls_model_path.to_str().ok_or_else(|| crate::KreuzbergError::Ocr {
message: "Invalid classification model path".to_string(),
source: None,
})?,
rec_model_path.to_str().ok_or_else(|| crate::KreuzbergError::Ocr {
message: "Invalid recognition model path".to_string(),
source: None,
})?,
dict_path,
num_threads,
)
.map_err(|e| crate::KreuzbergError::Ocr {
message: format!(
"Failed to initialize PaddleOCR models for {family} ({}): {e}",
resolved.model_key
),
source: None,
})?;
tracing::info!(family, model_key = %resolved.model_key, "PaddleOCR engine initialized successfully");
let engine = Arc::new(ocr_lite);
let mut pool = self.engine_pool.lock().map_err(|e| crate::KreuzbergError::Plugin {
message: format!("Failed to acquire engine pool lock: {e}"),
plugin_name: "paddle-ocr".to_string(),
})?;
if let Some(existing_engine) = pool.get(&pool_key) {
return Ok(Arc::clone(existing_engine));
}
pool.insert(pool_key, Arc::clone(&engine));
Ok(engine)
}
/// Locates the ONNX model file inside `model_dir`.
///
/// `model_dir/model.onnx` is preferred when it is a regular file. Otherwise the
/// candidates are the regular files whose extension is `onnx` (ASCII
/// case-insensitive), and the one with the smallest path is returned. Picking the
/// smallest path makes the choice independent of the unspecified order in which
/// `read_dir` yields entries. Directories that happen to end in `.onnx` are ignored.
///
/// # Errors
///
/// Returns [`crate::KreuzbergError::Ocr`] if `model_dir` does not exist, cannot be
/// read, or contains no `.onnx` file.
fn find_onnx_model(model_dir: &std::path::Path) -> Result<std::path::PathBuf> {
    if !model_dir.exists() {
        return Err(crate::KreuzbergError::Ocr {
            message: format!("Model directory does not exist: {:?}", model_dir),
            source: None,
        });
    }
    let standard_path = model_dir.join("model.onnx");
    if standard_path.is_file() {
        return Ok(standard_path);
    }
    let entries = std::fs::read_dir(model_dir).map_err(|e| crate::KreuzbergError::Ocr {
        message: format!("Failed to read model directory {:?}: {}", model_dir, e),
        source: None,
    })?;
    let mut candidates = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|e| crate::KreuzbergError::Ocr {
            message: format!("Failed to read directory entry: {}", e),
            source: None,
        })?;
        let path = entry.path();
        let has_onnx_ext = path
            .extension()
            .is_some_and(|ext| ext.eq_ignore_ascii_case("onnx"));
        if has_onnx_ext && path.is_file() {
            candidates.push(path);
        }
    }
    // `read_dir` order is platform/filesystem dependent; pick deterministically.
    candidates.into_iter().min().ok_or_else(|| crate::KreuzbergError::Ocr {
        message: format!("No ONNX model file found in directory: {:?}", model_dir),
        source: None,
    })
}
/// Runs document-orientation detection on `image_bytes`.
///
/// The detector is built on first use, with its cache directory taken from
/// `doc_orientation::resolve_cache_dir()`, and is kept in `doc_ori_detector`
/// for later calls. The result comes straight from
/// `crate::doc_orientation::detect_and_rotate`. `process_image` treats
/// `Ok(Some(bytes))` as the rotated image and `Ok(None)` as "use the original bytes".
fn detect_and_rotate(&self, image_bytes: &[u8]) -> Result<Option<Vec<u8>>> {
    let detector = self.doc_ori_detector.get_or_init(|| {
        let cache_root = crate::doc_orientation::resolve_cache_dir();
        crate::doc_orientation::DocOrientationDetector::new(cache_root)
    });
    crate::doc_orientation::detect_and_rotate(detector, image_bytes)
}
/// OCRs `image_bytes` with the engine for `language`'s script family.
///
/// Inference runs on a blocking worker via `tokio::task::spawn_blocking`. Each
/// recognised block that converts to an element becomes an [`OcrElement`] on
/// page 1. The returned text is the non-empty block texts joined by `"\n\n"`.
///
/// # Errors
///
/// Propagates engine-initialisation and OCR errors. A panic inside inference or
/// a failed blocking task is reported as a plugin error.
async fn do_ocr(
    &self,
    image_bytes: &[u8],
    language: &str,
    effective_config: Arc<PaddleOcrConfig>,
) -> Result<(String, Vec<OcrElement>)> {
    let engine = self.get_or_init_engine_for_family(language_to_script_family(language))?;
    // `spawn_blocking` needs `'static` data, so the bytes are copied.
    let owned_bytes = image_bytes.to_vec();
    let inference = tokio::task::spawn_blocking(move || {
        let outcome = catch_unwind(std::panic::AssertUnwindSafe(|| {
            Self::perform_ocr(&owned_bytes, &engine, &effective_config)
        }));
        match outcome {
            Ok(result) => result,
            Err(_) => Err(crate::KreuzbergError::Plugin {
                message: "PaddleOCR inference panicked (ONNX Runtime error)".to_string(),
                plugin_name: "paddle-ocr".to_string(),
            }),
        }
    });
    let text_blocks = match inference.await {
        Ok(result) => result?,
        Err(join_err) => {
            return Err(crate::KreuzbergError::Plugin {
                message: format!("PaddleOCR task panicked: {}", join_err),
                plugin_name: "paddle-ocr".to_string(),
            });
        }
    };

    let mut ocr_elements = Vec::new();
    for block in &text_blocks {
        // `Ok(None)` means the block yields no element; the first error aborts.
        if let Some(element) = text_block_to_element(block, 1)? {
            ocr_elements.push(element);
        }
    }

    let non_empty: Vec<&str> = text_blocks
        .iter()
        .map(|block| block.text.as_str())
        .filter(|line| !line.is_empty())
        .collect();
    Ok((non_empty.join("\n\n"), ocr_elements))
}
/// Decodes `image_bytes` to RGB8 and runs `ocr_engine.detect` with the
/// thresholds from `config`.
///
/// Blocks whose `text_score` is NaN or below `config.drop_score` are discarded.
///
/// # Errors
///
/// Returns an OCR error if the image cannot be decoded or detection fails.
fn perform_ocr(
    image_bytes: &[u8],
    ocr_engine: &Arc<OcrLite>,
    config: &PaddleOcrConfig,
) -> Result<Vec<kreuzberg_paddle_ocr::TextBlock>> {
    let rgb = match crate::extraction::image::load_image_for_ocr(image_bytes) {
        Ok(decoded) => decoded.to_rgb8(),
        Err(e) => {
            return Err(crate::KreuzbergError::Ocr {
                message: e.to_string(),
                source: None,
            });
        }
    };

    let detection = ocr_engine
        .detect(
            &rgb,
            config.padding,
            config.det_limit_side_len,  // max_side_len
            config.det_db_box_thresh,   // box_score_thresh
            config.det_db_thresh,       // box_thresh
            config.det_db_unclip_ratio, // un_clip_ratio
            config.use_angle_cls,       // do_angle
            false,                      // most_angle
        )
        .map_err(|e| crate::KreuzbergError::Ocr {
            message: format!("PaddleOCR detection failed: {}", e),
            source: None,
        })?;

    let min_score = config.drop_score;
    let kept: Vec<_> = detection
        .text_blocks
        .into_iter()
        .filter(|block| !block.text_score.is_nan() && block.text_score >= min_score)
        .collect();
    tracing::debug!(text_block_count = kept.len(), "PaddleOCR detection completed");
    Ok(kept)
}
}
/// Plugin identity and lifecycle hooks.
///
/// Both hooks are no-ops: the backend builds its engines and detectors lazily in
/// its own methods, so there is nothing to set up or tear down here.
impl Plugin for PaddleOcrBackend {
    fn name(&self) -> &str {
        "paddle-ocr"
    }

    /// The crate version baked in at compile time.
    fn version(&self) -> String {
        String::from(env!("CARGO_PKG_VERSION"))
    }

    fn initialize(&self) -> Result<()> {
        Ok(())
    }

    fn shutdown(&self) -> Result<()> {
        Ok(())
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl OcrBackend for PaddleOcrBackend {
    /// Runs PaddleOCR on an in-memory image and builds an [`ExtractionResult`].
    ///
    /// Pipeline:
    /// 1. Validate the input.
    /// 2. Apply the optional per-call `paddle_ocr_config` override.
    /// 3. Resolve the PaddleOCR language code.
    /// 4. Auto-rotate when `config.auto_rotate` is set.
    /// 5. Run OCR.
    /// 6. When `enable_table_detection` is set, try to rebuild one table from the
    ///    recognised words.
    ///
    /// The language is first mapped with `map_language_code`. If that gives
    /// nothing but the code is itself accepted by `is_language_supported` (the
    /// same rule `supports_language` uses), the code is used unchanged. Only
    /// otherwise does the backend fall back to `"en"`, and it logs a warning.
    ///
    /// # Errors
    ///
    /// Returns a validation error for empty input or an undeserialisable
    /// `paddle_ocr_config`, and propagates OCR/engine errors. Orientation
    /// detection failures are logged and ignored.
    async fn process_image(&self, image_bytes: &[u8], config: &OcrConfig) -> Result<ExtractionResult> {
        if image_bytes.is_empty() {
            return Err(crate::KreuzbergError::Validation {
                message: "Empty image data provided to PaddleOCR".to_string(),
                source: None,
            });
        }

        // The override replaces the backend config for this call only.
        // NOTE(review): engines are keyed on `self.config.model_tier`, so a
        // `model_tier` given in the override does not change which models run.
        let effective_config: Arc<PaddleOcrConfig> = if let Some(ref paddle_json) = config.paddle_ocr_config {
            let overridden: PaddleOcrConfig =
                serde_json::from_value(paddle_json.clone()).map_err(|e| crate::KreuzbergError::Validation {
                    message: format!("Failed to deserialize paddle_ocr_config: {}", e),
                    source: None,
                })?;
            Arc::new(overridden)
        } else {
            Arc::clone(&self.config)
        };

        let paddle_lang: &str = match map_language_code(&config.language) {
            Some(mapped) => mapped,
            // Codes that `supports_language` accepts directly must not be
            // silently swapped for the English model.
            None if is_language_supported(&config.language) => config.language.as_str(),
            None => {
                tracing::warn!(
                    language = %config.language,
                    "Language not supported by PaddleOCR, falling back to English"
                );
                "en"
            }
        };

        let ocr_image_bytes: std::borrow::Cow<'_, [u8]> = if config.auto_rotate {
            match self.detect_and_rotate(image_bytes) {
                Ok(Some(rotated)) => std::borrow::Cow::Owned(rotated),
                Ok(None) => std::borrow::Cow::Borrowed(image_bytes),
                Err(e) => {
                    tracing::warn!("Doc orientation detection failed, proceeding without rotation: {e}");
                    std::borrow::Cow::Borrowed(image_bytes)
                }
            }
        } else {
            std::borrow::Cow::Borrowed(image_bytes)
        };

        let (text, ocr_elements) = self
            .do_ocr(&ocr_image_bytes, paddle_lang, Arc::clone(&effective_config))
            .await?;

        let mut tables: Vec<Table> = Vec::new();
        let mut table_count = 0;
        let mut table_rows: Option<usize> = None;
        let mut table_cols: Option<usize> = None;
        if effective_config.enable_table_detection && !ocr_elements.is_empty() {
            // NOTE(review): 0.3, 20 and 0.5 are fixed heuristics for the hOCR
            // table helpers; their exact meaning is defined in those helpers.
            let words = elements_to_hocr_words(&ocr_elements, 0.3);
            if !words.is_empty() {
                let cells = reconstruct_table(&words, 20, 0.5);
                if !cells.is_empty() {
                    table_count = 1;
                    table_rows = Some(cells.len());
                    table_cols = cells.first().map(|row| row.len());
                    let table_markdown = table_to_markdown(&cells);
                    tables.push(Table {
                        cells,
                        markdown: table_markdown,
                        page_number: 1,
                        bounding_box: None,
                    });
                }
            }
        }

        let mut additional = AHashMap::new();
        additional.insert(Cow::Borrowed("backend"), serde_json::json!("paddle-ocr"));
        let metadata = Metadata {
            format: Some(FormatMetadata::Ocr(OcrMetadata {
                language: config.language.clone(),
                psm: 3,
                output_format: "text".to_string(),
                table_count,
                table_rows,
                table_cols,
            })),
            additional,
            ..Default::default()
        };

        // Elements are only returned when the caller asked for them and there are some.
        let include_elements = config.element_config.as_ref().is_some_and(|ec| ec.include_elements);
        let ocr_elements_opt = if include_elements && !ocr_elements.is_empty() {
            Some(ocr_elements)
        } else {
            None
        };

        Ok(ExtractionResult {
            content: text,
            mime_type: Cow::Borrowed("text/plain"),
            metadata,
            tables,
            detected_languages: Some(vec![config.language.clone()]),
            chunks: None,
            images: None,
            djot_content: None,
            pages: None,
            elements: None,
            ocr_elements: ocr_elements_opt,
            document: None,
            #[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
            extracted_keywords: None,
            quality_score: None,
            processing_warnings: Vec::new(),
            annotations: None,
            children: None,
        })
    }

    /// Reads `path` asynchronously and delegates to [`Self::process_image`].
    async fn process_image_file(&self, path: &Path, config: &OcrConfig) -> Result<ExtractionResult> {
        let bytes = tokio::fs::read(path).await?;
        self.process_image(&bytes, config).await
    }

    /// True for native PaddleOCR codes and for codes `map_language_code` can translate.
    fn supports_language(&self, lang: &str) -> bool {
        is_language_supported(lang) || map_language_code(lang).is_some()
    }

    fn backend_type(&self) -> OcrBackendType {
        OcrBackendType::PaddleOCR
    }

    fn supported_languages(&self) -> Vec<String> {
        super::SUPPORTED_LANGUAGES.iter().map(|s| s.to_string()).collect()
    }

    /// Reflects the backend-wide config only; per-call overrides are not considered.
    fn supports_table_detection(&self) -> bool {
        self.config.enable_table_detection
    }
}
impl Default for PaddleOcrBackend {
    /// Builds a backend from [`PaddleOcrConfig::default`].
    ///
    /// # Panics
    ///
    /// Panics if [`PaddleOcrBackend::with_config`] returns an error.
    fn default() -> Self {
        match Self::with_config(PaddleOcrConfig::default()) {
            Ok(backend) => backend,
            Err(e) => panic!("Failed to create default PaddleOcrBackend: {}", e),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backend built from the default config; construction loads no models.
    fn default_backend() -> PaddleOcrBackend {
        PaddleOcrBackend::new().unwrap()
    }

    #[test]
    fn test_paddle_ocr_backend_creation() {
        let outcome = PaddleOcrBackend::new();
        assert!(outcome.is_ok(), "Failed to create PaddleOCR backend");
    }

    #[test]
    fn test_paddle_ocr_backend_with_config() {
        assert!(PaddleOcrBackend::with_config(PaddleOcrConfig::default()).is_ok());
    }

    #[test]
    fn test_paddle_ocr_language_support_direct() {
        let backend = default_backend();
        // Native PaddleOCR language codes.
        for code in ["ch", "en", "japan", "korean", "french", "thai", "greek"] {
            assert!(backend.supports_language(code));
        }
    }

    #[test]
    fn test_paddle_ocr_language_support_mapped() {
        let backend = default_backend();
        // Tesseract / ISO 639 style codes.
        for code in ["chi_sim", "eng", "jpn", "kor", "fra", "zho", "tha", "ell", "rus"] {
            assert!(backend.supports_language(code));
        }
    }

    #[test]
    fn test_paddle_ocr_language_unsupported() {
        let backend = default_backend();
        for code in ["xyz", "invalid"] {
            assert!(!backend.supports_language(code));
        }
    }

    #[test]
    fn test_paddle_ocr_plugin_interface() {
        let backend = default_backend();
        assert_eq!(backend.name(), "paddle-ocr");
        assert!(!backend.version().is_empty());
        assert!(backend.initialize().is_ok());
        assert!(backend.shutdown().is_ok());
    }

    #[test]
    fn test_paddle_ocr_backend_type() {
        assert_eq!(default_backend().backend_type(), OcrBackendType::PaddleOCR);
    }

    #[test]
    fn test_paddle_ocr_supported_languages() {
        let languages = default_backend().supported_languages();
        assert!(!languages.is_empty());
        for expected in ["ch", "en", "thai", "greek"] {
            assert!(languages.iter().any(|lang| lang == expected));
        }
    }

    #[test]
    fn test_paddle_ocr_table_detection_disabled_by_default() {
        assert!(!default_backend().supports_table_detection());
    }

    #[test]
    fn test_paddle_ocr_table_detection_enabled() {
        let with_tables = PaddleOcrConfig::default().with_table_detection(true);
        let backend = PaddleOcrBackend::with_config(with_tables).unwrap();
        assert!(backend.supports_table_detection());
    }

    #[test]
    fn test_paddle_ocr_default() {
        assert_eq!(PaddleOcrBackend::default().name(), "paddle-ocr");
    }

    #[tokio::test]
    async fn test_paddle_ocr_process_empty_image() {
        let backend = default_backend();
        let config = OcrConfig {
            backend: "paddle-ocr".to_string(),
            language: "ch".to_string(),
            ..Default::default()
        };
        let result = backend.process_image(&[], &config).await;
        assert!(result.is_err(), "Should error on empty image");
    }
}