use crate::core::config::ExtractionConfig;
use crate::core::config::extraction::FileExtractionConfig;
use crate::types::ExtractionResult;
use crate::{KreuzbergError, Result};
use std::future::Future;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use super::bytes::extract_bytes;
use super::file::extract_file;
use super::helpers::error_extraction_result;
/// Run `count` extraction tasks concurrently and gather their results in index order.
///
/// `spawn_task(index, semaphore)` is called once for every index in `0..count`; each
/// returned future is spawned onto a tokio `JoinSet` and must resolve to
/// `(index, result, elapsed_ms)`. All futures receive the same semaphore, whose permit
/// count is taken from `config.max_concurrent_extractions`, falling back to
/// `config.concurrency.max_threads`, falling back to `ceil(1.5 * num_cpus::get())`.
///
/// The permit count is clamped to `1..=Semaphore::MAX_PERMITS`. A configured `0` would
/// otherwise create a semaphore no task can ever acquire, so the batch would hang forever.
/// A value above `MAX_PERMITS` would make `Semaphore::new` panic.
///
/// A task whose extraction returns `Err` does not fail the batch: the error is converted
/// into an error `ExtractionResult` via `error_extraction_result`, using the task's
/// reported elapsed time.
///
/// # Errors
///
/// Returns `KreuzbergError::Other` if a spawned task fails to complete (e.g. it panics),
/// or if no task reported a result for some index.
#[cfg(feature = "tokio-runtime")]
async fn collect_batch<F, Fut>(count: usize, config: &ExtractionConfig, spawn_task: F) -> Result<Vec<ExtractionResult>>
where
    F: Fn(usize, Arc<tokio::sync::Semaphore>) -> Fut,
    Fut: Future<Output = (usize, Result<ExtractionResult>, u64)> + Send + 'static,
{
    use tokio::sync::Semaphore;
    use tokio::task::JoinSet;

    if count == 0 {
        return Ok(vec![]);
    }

    let max_concurrent = config
        .max_concurrent_extractions
        .or_else(|| config.concurrency.as_ref().and_then(|c| c.max_threads))
        .unwrap_or_else(|| (num_cpus::get() as f64 * 1.5).ceil() as usize)
        // Zero permits would block every task forever; `Semaphore::new` panics above
        // `MAX_PERMITS`.
        .clamp(1, Semaphore::MAX_PERMITS);
    let semaphore = Arc::new(Semaphore::new(max_concurrent));

    let mut tasks = JoinSet::new();
    for index in 0..count {
        tasks.spawn(spawn_task(index, Arc::clone(&semaphore)));
    }

    // Tasks complete in arbitrary order; each result is slotted back at its own index.
    let mut results: Vec<Option<ExtractionResult>> = (0..count).map(|_| None).collect();
    while let Some(task_result) = tasks.join_next().await {
        match task_result {
            Ok((index, Ok(result), _elapsed_ms)) => {
                results[index] = Some(result);
            }
            Ok((index, Err(e), elapsed_ms)) => {
                // Per-item failures become error results so one bad input does not
                // fail the whole batch.
                results[index] = Some(error_extraction_result(&e, Some(elapsed_ms)));
            }
            Err(join_err) => {
                return Err(KreuzbergError::Other(format!("Task panicked: {}", join_err)));
            }
        }
    }

    results
        .into_iter()
        .enumerate()
        .map(|(index, slot)| {
            slot.ok_or_else(|| {
                KreuzbergError::Other(format!("Batch task for item {} produced no result", index))
            })
        })
        .collect()
}
/// Run one batch item's extraction while holding a concurrency permit, and time it.
///
/// First waits for a permit from `semaphore`. Then it awaits `extract_fn()` wrapped in
/// `crate::core::batch_mode::with_batch_mode`. The permit is held until this function
/// returns.
///
/// The measured duration starts only after the permit is obtained, so queueing time is
/// excluded. On success the duration is also stored in the result's
/// `metadata.extraction_duration_ms`.
///
/// Returns `(index, result, elapsed_ms)` so the caller can place the outcome at the
/// item's original position.
///
/// # Panics
///
/// Panics if acquiring the permit fails (tokio reports this for a closed semaphore).
#[cfg(feature = "tokio-runtime")]
async fn run_timed_extraction<F, Fut>(
    index: usize,
    semaphore: Arc<tokio::sync::Semaphore>,
    extract_fn: F,
) -> (usize, Result<ExtractionResult>, u64)
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = Result<ExtractionResult>>,
{
    // Bound to a named variable (not `_`) so the permit lives until the end of the function.
    let _permit = semaphore.acquire().await.unwrap();

    let started_at = Instant::now();
    let outcome = crate::core::batch_mode::with_batch_mode(extract_fn()).await;
    let elapsed_ms = started_at.elapsed().as_millis() as u64;

    let outcome = outcome.map(|mut extracted| {
        extracted.metadata.extraction_duration_ms = Some(elapsed_ms);
        extracted
    });

    (index, outcome, elapsed_ms)
}
/// Compute the effective configuration for a single batch item.
///
/// When the item carries a `FileExtractionConfig`, the result is
/// `base.with_file_overrides(..)` applied to it. Otherwise it is a plain clone of `base`.
fn resolve_config(base: &ExtractionConfig, file_config: &Option<FileExtractionConfig>) -> ExtractionConfig {
    file_config
        .as_ref()
        .map_or_else(|| base.clone(), |overrides| base.with_file_overrides(overrides))
}
/// Extract a batch of files concurrently, returning one result per input in input order.
///
/// Each entry pairs a path with an optional per-file `FileExtractionConfig`. That config
/// is combined with `config` via `resolve_config`, and the file is then extracted with
/// `extract_file`.
///
/// Scheduling, the concurrency limit, and result ordering are handled by `collect_batch`.
/// A file that fails to extract yields an error `ExtractionResult` in its slot rather
/// than failing the whole batch.
///
/// # Errors
///
/// Returns an error if a batch task fails to complete (for example, it panics).
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(feature = "otel", tracing::instrument(
    skip(config, items),
    fields(
        extraction.batch_size = items.len(),
    )
))]
pub async fn batch_extract_file(
    items: Vec<(PathBuf, Option<FileExtractionConfig>)>,
    config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
    let shared_config = Arc::new(config.clone());
    let total = items.len();
    let shared_items = Arc::new(items);

    collect_batch(total, config, |index, permits| {
        // Every spawned future gets its own handles so it can be `'static`.
        let task_config = Arc::clone(&shared_config);
        let task_items = Arc::clone(&shared_items);

        async move {
            let (path, file_config) = &task_items[index];
            let effective = resolve_config(&task_config, file_config);

            run_timed_extraction(index, permits, || {
                let owned_path = path.clone();
                async move { extract_file(&owned_path, None, &effective).await }
            })
            .await
        }
    })
    .await
}
/// Extract a batch of in-memory documents concurrently, returning one result per input in
/// input order.
///
/// Each entry is `(bytes, mime_type, per-item config)`. The optional
/// `FileExtractionConfig` is combined with `config` via `resolve_config`, and the payload
/// is then extracted with `extract_bytes`.
///
/// Scheduling, the concurrency limit, and result ordering are handled by `collect_batch`.
/// A document that fails to extract yields an error `ExtractionResult` in its slot rather
/// than failing the whole batch.
///
/// # Errors
///
/// Returns an error if a batch task fails to complete (for example, it panics).
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(feature = "otel", tracing::instrument(
    skip(config, items),
    fields(
        extraction.batch_size = items.len(),
    )
))]
pub async fn batch_extract_bytes(
    items: Vec<(Vec<u8>, String, Option<FileExtractionConfig>)>,
    config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
    type PendingItem = (Vec<u8>, String, Option<FileExtractionConfig>);

    let shared_config = Arc::new(config.clone());
    let total = items.len();

    // `collect_batch` takes an `Fn` closure, which only has shared access to its captures.
    // Each payload therefore sits in its own take-once slot. The task that owns an index
    // moves its payload out without copying the bytes.
    let pending: Arc<Vec<std::sync::Mutex<Option<PendingItem>>>> = Arc::new(
        items
            .into_iter()
            .map(Some)
            .map(std::sync::Mutex::new)
            .collect(),
    );

    collect_batch(total, config, |index, permits| {
        let task_config = Arc::clone(&shared_config);
        let task_pending = Arc::clone(&pending);

        async move {
            // Scoped so the mutex guard is released before any `.await`.
            let taken = {
                let mut slot = task_pending[index].lock().unwrap();
                slot.take()
            };
            let (bytes, mime_type, file_config) = taken.expect("batch item already consumed");
            let effective = resolve_config(&task_config, &file_config);

            run_timed_extraction(index, permits, move || async move {
                extract_bytes(&bytes, &mime_type, &effective).await
            })
            .await
        }
    })
    .await
}