use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use doiget_core::provenance::{Capability, LogEvent, LogResult, RowInput};
use doiget_core::{RateLimits, Ref, MCP_BATCH_MAX_SIZE};
use super::fetch::{build_fetch_plan, emit_dry_run_plan_to_stdout, FetchHarness};
use super::resolve_store_root;
/// Runs a batch of fetches for every reference listed in the file at `path`.
///
/// The file is read as text, one entry per line. Each line is trimmed; blank
/// lines and lines starting with `#` are ignored.
///
/// * `path` - location of the batch file.
/// * `dry_run` - when `true`, only the fetch plan of each entry is emitted to
///   stdout and nothing is fetched; when `false`, every parseable entry is
///   fetched concurrently.
///
/// # Errors
///
/// Returns an error when:
/// * the batch file cannot be read,
/// * the number of entries exceeds `MCP_BATCH_MAX_SIZE`,
/// * (dry-run) the store root cannot be resolved, a plan cannot be emitted, or
///   at least one entry fails to parse,
/// * (live) the harness cannot be created, the session start cannot be logged,
///   or at least one entry fails to parse or to fetch.
pub async fn run_with_options(path: String, dry_run: bool) -> Result<()> {
    let raw =
        std::fs::read_to_string(&path).with_context(|| format!("reading batch file: {path}"))?;
    let inputs = parse_batch_lines(&raw);
    // Enforce the size limit before doing any other work, in both modes.
    if inputs.len() > MCP_BATCH_MAX_SIZE {
        return Err(anyhow!(
            "batch size {} exceeds limit {}",
            inputs.len(),
            MCP_BATCH_MAX_SIZE,
        ));
    }
    if dry_run {
        run_dry(&inputs)
    } else {
        run_live(inputs).await
    }
}

/// Extracts the batch entries from the raw file contents: trims every line and
/// drops the ones that are empty or start with `#`.
fn parse_batch_lines(raw: &str) -> Vec<String> {
    raw.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .map(str::to_owned)
        .collect()
}

/// Emits the fetch plan of every parseable entry to stdout without fetching.
///
/// Malformed entries are reported with a warning and counted; the remaining
/// entries are still processed, and an error is returned at the end if any
/// entry was malformed.
fn run_dry(inputs: &[String]) -> Result<()> {
    let store_root = resolve_store_root()?;
    let mut parse_errors: usize = 0;
    for input in inputs {
        match Ref::parse(input) {
            Ok(ref_) => {
                let plan = build_fetch_plan(&ref_, &store_root);
                emit_dry_run_plan_to_stdout(&ref_, &plan)?;
            }
            Err(e) => {
                parse_errors += 1;
                tracing::warn!(
                    %input,
                    error = %e,
                    "skipping malformed batch entry in dry-run mode",
                );
            }
        }
    }
    if parse_errors > 0 {
        return Err(anyhow!(
            "dry-run batch: {} parse errors (no fetches attempted)",
            parse_errors
        ));
    }
    Ok(())
}

/// Fetches every parseable entry concurrently and reports a summary.
///
/// One task is spawned per entry; a shared semaphore caps how many of them run
/// `fetch_one` at the same time. Malformed entries are logged, counted and
/// skipped. The session end is logged and the summary printed before the
/// result is returned, whether or not the batch succeeded.
async fn run_live(inputs: Vec<String>) -> Result<()> {
    let harness = Arc::new(FetchHarness::from_env()?);
    harness.log_session_start(None)?;
    // A semaphore with zero permits would leave every spawned task waiting
    // forever, so never go below one concurrent fetch.
    let max_concurrent = (RateLimits::HARD_CODED.max_concurrent_fetches() as usize).max(1);
    let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent));
    let mut parse_errors: usize = 0;
    let mut joins: tokio::task::JoinSet<TaskOutcome> = tokio::task::JoinSet::new();
    for input in inputs {
        let ref_ = match Ref::parse(&input) {
            Ok(r) => r,
            Err(e) => {
                parse_errors += 1;
                // Best effort: a failure to append this log row must not abort
                // the rest of the batch, so the append result is discarded.
                let _ = harness.log.append(RowInput {
                    event: LogEvent::Resolve,
                    result: LogResult::Err,
                    capability: Capability::Oa,
                    ref_: Some(&input),
                    source: None,
                    error_code: Some("INVALID_REF"),
                    size_bytes: None,
                    license: None,
                    store_path: None,
                    canonical_digest: None,
                });
                tracing::warn!(
                    %input,
                    error = %e,
                    "skipping malformed batch entry",
                );
                continue;
            }
        };
        let harness_task = Arc::clone(&harness);
        let sem_task = Arc::clone(&semaphore);
        joins.spawn(async move {
            // The permit is held until the task finishes, which is what bounds
            // the number of simultaneous fetches.
            let _permit = match sem_task.acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    return TaskOutcome {
                        input,
                        result: Err(anyhow!("batch semaphore unexpectedly closed")),
                    }
                }
            };
            let result = harness_task.fetch_one(&ref_).await;
            TaskOutcome { input, result }
        });
    }
    let (fetch_ok, fetch_errors) = drain_outcomes(joins).await;
    let total_errors = parse_errors + fetch_errors;
    let all_ok = total_errors == 0;
    harness.log_session_end(all_ok, None);
    print_summary(format_args!(
        "batch: {} OK, {} failed ({} parse errors, {} fetch errors)",
        fetch_ok, total_errors, parse_errors, fetch_errors,
    ));
    if all_ok {
        Ok(())
    } else {
        Err(anyhow!(
            "batch failed: {} OK, {} parse errors, {} fetch errors",
            fetch_ok,
            parse_errors,
            fetch_errors,
        ))
    }
}

/// Waits for every spawned fetch task and returns `(succeeded, failed)`.
///
/// A task that could not be joined (it panicked or was cancelled) is counted
/// as a failed fetch.
async fn drain_outcomes(mut joins: tokio::task::JoinSet<TaskOutcome>) -> (usize, usize) {
    let mut fetch_ok: usize = 0;
    let mut fetch_errors: usize = 0;
    while let Some(joined) = joins.join_next().await {
        match joined {
            Ok(TaskOutcome { input, result }) => match result {
                Ok(()) => fetch_ok += 1,
                Err(e) => {
                    fetch_errors += 1;
                    tracing::warn!(%input, error = ?e, "batch entry fetch failed");
                }
            },
            Err(join_err) => {
                fetch_errors += 1;
                tracing::error!(error = ?join_err, "batch task panicked or was cancelled");
            }
        }
    }
    (fetch_ok, fetch_errors)
}
/// Result of one spawned batch fetch task, paired with the entry it was for.
#[derive(Debug)]
struct TaskOutcome {
    /// The batch line this task was spawned for, used when reporting failures.
    input: String,
    /// `Ok(())` when the fetch succeeded, otherwise the error that stopped it.
    result: Result<()>,
}
/// Writes the already-formatted summary `args` as a single line on stderr.
// The lint exemption is scoped to this function so the direct `eprintln!`
// call stays confined to one place.
#[allow(clippy::print_stderr)]
fn print_summary(args: std::fmt::Arguments<'_>) {
    eprintln!("{args}");
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Applies the same line filter as `run_with_options`: trim each line and
    /// drop the ones that are empty or start with `#`.
    fn filter_lines(raw: &str) -> Vec<String> {
        raw.lines()
            .map(str::trim)
            .filter(|l| !l.is_empty() && !l.starts_with('#'))
            .map(|s| s.to_string())
            .collect()
    }

    #[test]
    fn parses_and_filters_input_lines() {
        // Built from explicit lines so the leading whitespace of the indented
        // comment and the tab-only line cannot be lost to source formatting.
        let raw = [
            "arxiv:2401.12345",
            "# a comment line",
            "    # indented comment with leading whitespace",
            "arxiv:2401.12346",
            "\t\t",
            "arxiv:2401.12347",
            "",
        ]
        .join("\n");
        assert_eq!(
            filter_lines(&raw),
            vec![
                "arxiv:2401.12345".to_string(),
                "arxiv:2401.12346".to_string(),
                "arxiv:2401.12347".to_string(),
            ],
        );
    }

    #[test]
    fn over_limit_input_is_rejected() {
        let n = MCP_BATCH_MAX_SIZE + 1;
        let body: String = (0..n)
            .map(|i| format!("arxiv:2401.{:05}\n", 10000 + i))
            .collect();
        assert_eq!(filter_lines(&body).len(), n);

        // Exercise the real entry point. The size check runs before anything
        // else, and dry-run mode guarantees no fetch is attempted even if the
        // check were to regress.
        let path = std::env::temp_dir().join(format!(
            "doiget-batch-over-limit-{}.txt",
            std::process::id()
        ));
        std::fs::write(&path, &body).expect("write temporary batch file");
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("build tokio runtime");
        let outcome =
            runtime.block_on(run_with_options(path.to_string_lossy().into_owned(), true));
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        let err = outcome.expect_err("over-limit batch must be rejected");
        assert!(
            err.to_string().contains("exceeds limit"),
            "unexpected error: {err}"
        );
    }
}