#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
use crate::ocr::OcrProcessor;
use crate::types::{ExtractedImage, ExtractionResult, Metadata};
/// Maximum number of OCR jobs admitted at the same time.
///
/// Used as the permit count of the semaphore in [`process_images_with_ocr`]; each
/// image's OCR task holds one permit while its blocking OCR work runs.
#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
const MAX_CONCURRENT_OCR_TASKS: usize = 8;
/// Runs OCR over every extracted image concurrently and attaches the outcome to each image.
///
/// Returns `images` unchanged when the vector is empty or when OCR is not configured
/// (`config.ocr` is `None`). Otherwise each image's bytes are OCR'd on Tokio's blocking
/// thread pool, with at most [`MAX_CONCURRENT_OCR_TASKS`] jobs admitted at once through a
/// semaphore. Results are written back by index, so the returned vector keeps the input order.
///
/// On success an image's `ocr_result` is set to an [`ExtractionResult`] carrying the OCR
/// content, MIME type and OCR elements; every other field is left empty. OCR is best-effort
/// per image: if processing one image fails, its `ocr_result` is set to `None`, a warning is
/// logged, and the other images are still processed.
///
/// The cache directory handed to `OcrProcessor::new` comes from the `KREUZBERG_CACHE_DIR`
/// environment variable, when it is set.
///
/// # Errors
///
/// Returns [`crate::KreuzbergError::Ocr`] if a spawned OCR task, or the blocking worker it
/// started, fails to complete (e.g. it panicked).
#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
pub async fn process_images_with_ocr(
    mut images: Vec<ExtractedImage>,
    config: &crate::core::config::ExtractionConfig,
) -> crate::Result<Vec<ExtractedImage>> {
    use std::sync::Arc;
    use tokio::sync::Semaphore;
    use tokio::task::JoinSet;

    /// `(image index, outcome of the blocking OCR job)`; the outer `Result` reports whether
    /// the blocking task itself completed, the inner one whether OCR succeeded.
    type OcrTaskResult = (
        usize,
        Result<Result<crate::types::OcrExtractionResult, crate::ocr::error::OcrError>, tokio::task::JoinError>,
    );

    if images.is_empty() {
        return Ok(images);
    }
    let ocr_config = match config.ocr.as_ref() {
        Some(ocr_config) => ocr_config,
        None => return Ok(images),
    };

    let tess_config = ocr_config.tesseract_config.as_ref().cloned().unwrap_or_default();
    let output_format = config.output_format.clone();
    // Loop-invariant: resolve the cache directory once instead of once per image.
    let cache_dir = std::env::var("KREUZBERG_CACHE_DIR").ok().map(std::path::PathBuf::from);
    let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_OCR_TASKS));

    let mut join_set: JoinSet<OcrTaskResult> = JoinSet::new();
    for (idx, image) in images.iter().enumerate() {
        // Each task needs owned copies because it may outlive this loop iteration.
        let image_data = image.data.clone();
        let tess_config = tess_config.clone();
        let output_format = output_format.clone();
        let cache_dir = cache_dir.clone();
        let span = tracing::Span::current();
        let semaphore = Arc::clone(&semaphore);
        join_set.spawn(async move {
            // The permit is held until the blocking OCR job completes, capping concurrency.
            let _permit = semaphore.acquire().await.expect("semaphore should not be closed");
            let blocking_result = tokio::task::spawn_blocking(move || {
                // Re-enter the caller's span so logs from the worker thread stay correlated.
                let _guard = span.entered();
                let proc = OcrProcessor::new(cache_dir)?;
                let ocr_tess_config: crate::ocr::types::TesseractConfig = (&tess_config).into();
                proc.process_image_with_format(&image_data, &ocr_tess_config, output_format)
            })
            .await;
            (idx, blocking_result)
        });
    }

    while let Some(join_result) = join_set.join_next().await {
        let (idx, blocking_result) = join_result.map_err(|e| crate::KreuzbergError::Ocr {
            message: format!("OCR task panicked: {}", e),
            source: None,
        })?;
        let ocr_result = blocking_result.map_err(|e| crate::KreuzbergError::Ocr {
            message: format!("OCR blocking task panicked: {}", e),
            source: None,
        })?;
        images[idx].ocr_result = match ocr_result {
            Ok(ocr_extraction) => Some(Box::new(ExtractionResult {
                content: ocr_extraction.content,
                mime_type: ocr_extraction.mime_type.into(),
                metadata: Metadata::default(),
                tables: vec![],
                detected_languages: None,
                chunks: None,
                images: None,
                djot_content: None,
                pages: None,
                elements: None,
                ocr_elements: ocr_extraction.ocr_elements,
                document: None,
                #[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
                extracted_keywords: None,
                quality_score: None,
                processing_warnings: Vec::new(),
                annotations: None,
                children: None,
                uris: None,
                #[cfg(feature = "tree-sitter")]
                code_intelligence: None,
                formatted_content: None,
            })),
            Err(err) => {
                // Best-effort: one unreadable image must not fail the whole batch, but the
                // failure should be visible rather than silently discarded.
                tracing::warn!(image_index = idx, error = %err, "OCR failed for extracted image");
                None
            }
        };
    }

    Ok(images)
}