use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use doiget_core::provenance::{Capability, LogEvent, LogResult, RowInput};
use doiget_core::{RateLimits, Ref, MCP_BATCH_MAX_SIZE};
use super::fetch::{build_fetch_plan, emit_dry_run_plan_to_stdout, CliExit, FetchHarness};
use super::resolve_store_root;
/// Runs a batch over every reference listed in the file at `path`.
///
/// The file is read as one entry per line: surrounding whitespace is trimmed,
/// and blank lines and lines starting with `#` are dropped. A batch with more
/// than `MCP_BATCH_MAX_SIZE` entries is rejected before any work starts.
///
/// With `dry_run` set, the fetch plan of every parseable entry is written to
/// stdout and nothing is fetched. Otherwise each parseable entry is fetched on
/// its own task, throttled by a semaphore sized from
/// `RateLimits::HARD_CODED.max_concurrent_fetches()`. When `mode` is
/// `OutputMode::Json`, one JSON record per entry is printed to stdout; a
/// human-readable summary line is always written to stderr at the end.
///
/// # Errors
///
/// Returns an error when the batch file cannot be read, when the batch is
/// over the size limit, or when the harness cannot be created or its session
/// start cannot be logged. In dry-run mode an error is returned when the
/// store root cannot be resolved, a plan cannot be emitted, or any entry
/// fails to parse. In fetch mode, any parse or fetch failure yields a
/// `CliExit` whose code is the number of failed entries, capped at 255.
pub async fn run_with_options(
    path: String,
    dry_run: bool,
    mode: super::output::OutputMode,
) -> Result<()> {
    let json_mode = mode == super::output::OutputMode::Json;
    let raw =
        std::fs::read_to_string(&path).with_context(|| format!("reading batch file: {path}"))?;
    let inputs = collect_batch_inputs(&raw);
    if inputs.len() > MCP_BATCH_MAX_SIZE {
        return Err(anyhow!(
            "batch size {} exceeds limit {}",
            inputs.len(),
            MCP_BATCH_MAX_SIZE,
        ));
    }
    if dry_run {
        return run_dry_run(&inputs);
    }

    let harness = Arc::new(FetchHarness::from_env()?);
    harness.log_session_start(None)?;
    // A semaphore created with zero permits would park every spawned task
    // forever, so the limit is clamped to at least one concurrent fetch.
    let max_concurrent = (RateLimits::HARD_CODED.max_concurrent_fetches() as usize).max(1);
    let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent));
    let mut parse_errors: usize = 0;
    let mut joins: tokio::task::JoinSet<TaskOutcome> = tokio::task::JoinSet::new();
    for input in inputs {
        let ref_ = match Ref::parse(&input) {
            Ok(r) => r,
            Err(e) => {
                // Malformed entries are counted, reported and skipped; they
                // never abort the rest of the batch.
                parse_errors += 1;
                if json_mode {
                    emit_jsonl_failure(Some(&input), "INVALID_REF", &e.to_string());
                }
                // Best-effort: the result of appending the log row is discarded.
                let _ = harness.log.append(RowInput {
                    event: LogEvent::Resolve,
                    result: LogResult::Err,
                    capability: Capability::Oa,
                    ref_: Some(&input),
                    source: None,
                    error_code: Some("INVALID_REF"),
                    size_bytes: None,
                    license: None,
                    store_path: None,
                    canonical_digest: None,
                });
                tracing::warn!(
                    %input,
                    error = %e,
                    "skipping malformed batch entry",
                );
                continue;
            }
        };
        let harness_task = Arc::clone(&harness);
        let sem_task = Arc::clone(&semaphore);
        joins.spawn(async move {
            // The permit is held until the task finishes, which bounds the
            // number of fetches in flight.
            let _permit = match sem_task.acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    return TaskOutcome {
                        input,
                        result: Err(anyhow!("batch semaphore unexpectedly closed")),
                    }
                }
            };
            let result = harness_task.fetch_one(&ref_).await;
            TaskOutcome { input, result }
        });
    }

    // Drain tasks in completion order, tallying and reporting each outcome.
    let mut fetch_ok: usize = 0;
    let mut fetch_errors: usize = 0;
    while let Some(joined) = joins.join_next().await {
        let JoinedOutcome {
            is_error,
            json_record,
            log_breadcrumb,
        } = classify_joined(joined, json_mode);
        if is_error {
            fetch_errors += 1;
        } else {
            fetch_ok += 1;
        }
        if let Some(record) = json_record {
            #[allow(clippy::print_stdout)]
            {
                println!("{record}");
            }
        }
        log_breadcrumb.emit();
    }

    let total_errors = parse_errors + fetch_errors;
    let all_ok = total_errors == 0;
    harness.log_session_end(all_ok, None);
    print_summary(format_args!(
        "batch: {} OK, {} failed ({} parse errors, {} fetch errors)",
        fetch_ok, total_errors, parse_errors, fetch_errors,
    ));
    if all_ok {
        Ok(())
    } else {
        // The failure count doubles as the exit code, capped at 255.
        let code = total_errors.min(255) as i32;
        Err(anyhow::Error::new(CliExit(code)))
    }
}

/// Splits raw batch-file text into trimmed entries, dropping blank lines and
/// lines that start with `#`.
fn collect_batch_inputs(raw: &str) -> Vec<String> {
    raw.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .map(str::to_owned)
        .collect()
}

/// Writes the fetch plan of every parseable entry to stdout without fetching,
/// and fails afterwards if any entry could not be parsed.
fn run_dry_run(inputs: &[String]) -> Result<()> {
    let store_root = resolve_store_root()?;
    let mut parse_errors: usize = 0;
    for input in inputs {
        match Ref::parse(input) {
            Ok(ref_) => {
                let plan = build_fetch_plan(&ref_, &store_root);
                emit_dry_run_plan_to_stdout(&ref_, &plan)?;
            }
            Err(e) => {
                parse_errors += 1;
                tracing::warn!(
                    %input,
                    error = %e,
                    "skipping malformed batch entry in dry-run mode",
                );
            }
        }
    }
    if parse_errors > 0 {
        return Err(anyhow!(
            "dry-run batch: {} parse errors (no fetches attempted)",
            parse_errors
        ));
    }
    Ok(())
}
/// What a single spawned batch task hands back once it has finished.
#[derive(Debug)]
struct TaskOutcome {
/// The batch-file entry the task was spawned for (already trimmed).
input: String,
/// Result of the fetch attempt, or the error produced when the concurrency
/// permit could not be acquired.
result: Result<()>,
}
/// Writes the pre-formatted summary `args` as a single line on stderr.
#[allow(clippy::print_stderr)]
fn print_summary(args: std::fmt::Arguments<'_>) {
    let line = args;
    eprintln!("{line}");
}
/// Classification of one joined batch task, as produced by `classify_joined`.
struct JoinedOutcome {
/// `true` when the entry counts as a failure in the batch tally.
is_error: bool,
/// Record to print on stdout; only populated in JSON mode.
json_record: Option<serde_json::Value>,
/// Tracing event to emit for this outcome (possibly none).
log_breadcrumb: LogBreadcrumb,
}
/// Deferred tracing event describing how a batch task ended.
enum LogBreadcrumb {
/// Nothing to log (the task succeeded).
None,
/// The task ran to completion but its fetch returned an error.
FetchFailed { input: String, error_dbg: String },
/// The task could not be joined (it panicked or was cancelled).
TaskPanicked { error_dbg: String },
}
impl LogBreadcrumb {
    /// Consumes the breadcrumb and emits the matching tracing event.
    ///
    /// Join failures are logged at error level, fetch failures at warn level,
    /// and `None` emits nothing.
    fn emit(self) {
        match self {
            Self::TaskPanicked { error_dbg } => {
                tracing::error!(%error_dbg, "batch task panicked or was cancelled");
            }
            Self::FetchFailed { input, error_dbg } => {
                tracing::warn!(%input, %error_dbg, "batch entry fetch failed");
            }
            Self::None => {}
        }
    }
}
/// Turns the result of joining one batch task into a `JoinedOutcome`.
///
/// A join error (the task panicked or was cancelled) and a fetch error both
/// count as failures and carry a `FETCH_ERROR` record; the join-error record
/// has no `ref` because the task's input is not available. JSON records are
/// only built when `json_mode` is set.
fn classify_joined(
    joined: Result<TaskOutcome, tokio::task::JoinError>,
    json_mode: bool,
) -> JoinedOutcome {
    // Handle the "task never returned" case first so the rest of the function
    // can work with a plain `TaskOutcome`.
    let TaskOutcome { input, result } = match joined {
        Ok(outcome) => outcome,
        Err(join_err) => {
            let error_dbg = format!("{join_err:?}");
            let json_msg = format!("batch task panicked: {join_err}");
            let json_record = if json_mode {
                Some(build_jsonl_failure(None, "FETCH_ERROR", &json_msg))
            } else {
                None
            };
            return JoinedOutcome {
                is_error: true,
                json_record,
                log_breadcrumb: LogBreadcrumb::TaskPanicked { error_dbg },
            };
        }
    };

    match result {
        Err(fetch_err) => {
            // Debug form goes to the log, alternate Display form to the JSON record.
            let error_dbg = format!("{fetch_err:?}");
            let json_msg = format!("{fetch_err:#}");
            let json_record = if json_mode {
                Some(build_jsonl_failure(Some(&input), "FETCH_ERROR", &json_msg))
            } else {
                None
            };
            JoinedOutcome {
                is_error: true,
                json_record,
                log_breadcrumb: LogBreadcrumb::FetchFailed { input, error_dbg },
            }
        }
        Ok(()) => {
            let json_record = if json_mode {
                Some(build_jsonl_success(&input))
            } else {
                None
            };
            JoinedOutcome {
                is_error: false,
                json_record,
                log_breadcrumb: LogBreadcrumb::None,
            }
        }
    }
}
fn build_jsonl_success(ref_input: &str) -> serde_json::Value {
serde_json::json!({ "ok": true, "ref": ref_input })
}
fn build_jsonl_failure(ref_input: Option<&str>, code: &str, message: &str) -> serde_json::Value {
serde_json::json!({
"ok": false,
"ref": ref_input,
"error": { "code": code, "message": message },
})
}
/// Prints the failure record for `ref_input` as a single line on stdout.
#[allow(clippy::print_stdout)]
fn emit_jsonl_failure(ref_input: Option<&str>, code: &str, message: &str) {
    let record = build_jsonl_failure(ref_input, code, message);
    println!("{record}");
}
// Unit tests for the JSON record builders, the join classifier, the
// breadcrumb emitter, and the batch-file line filter.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
// A success record carries `ok` and `ref` and no `error` member.
#[test]
fn jsonl_success_shape() {
let v = build_jsonl_success("10.1234/foo");
assert_eq!(v["ok"], true);
assert_eq!(v["ref"], "10.1234/foo");
assert!(v.get("error").is_none(), "no error field on success");
}
// A failure record echoes the ref and nests code/message under `error`.
#[test]
fn jsonl_failure_shape_invalid_ref() {
let v = build_jsonl_failure(Some("not-a-doi"), "INVALID_REF", "bad ref");
assert_eq!(v["ok"], false);
assert_eq!(v["ref"], "not-a-doi");
assert_eq!(v["error"]["code"], "INVALID_REF");
assert_eq!(v["error"]["message"], "bad ref");
}
// Same shape check with the FETCH_ERROR code.
#[test]
fn jsonl_failure_shape_fetch_error() {
let v = build_jsonl_failure(Some("arxiv:2401.12345"), "FETCH_ERROR", "boom");
assert_eq!(v["ok"], false);
assert_eq!(v["ref"], "arxiv:2401.12345");
assert_eq!(v["error"]["code"], "FETCH_ERROR");
assert_eq!(v["error"]["message"], "boom");
}
// With no ref available the `ref` member is JSON null rather than absent.
#[test]
fn jsonl_failure_shape_panic_ref_is_null() {
let v = build_jsonl_failure(None, "FETCH_ERROR", "batch task panicked: ...");
assert_eq!(v["ok"], false);
assert!(v["ref"].is_null(), "panic record's ref MUST be null: {v}");
assert_eq!(v["error"]["code"], "FETCH_ERROR");
}
// Successful task in JSON mode: not an error, record present, nothing to log.
#[test]
fn classify_joined_success_json_emits_record() {
let outcome = classify_joined(
Ok(TaskOutcome {
input: "10.1234/foo".to_string(),
result: Ok(()),
}),
true,
);
assert!(!outcome.is_error);
let rec = outcome.json_record.expect("json_mode → record");
assert_eq!(rec["ok"], true);
assert_eq!(rec["ref"], "10.1234/foo");
assert!(matches!(outcome.log_breadcrumb, LogBreadcrumb::None));
}
// Successful task in human mode: no JSON record is built.
#[test]
fn classify_joined_success_human_no_record() {
let outcome = classify_joined(
Ok(TaskOutcome {
input: "10.1234/foo".to_string(),
result: Ok(()),
}),
false,
);
assert!(!outcome.is_error);
assert!(outcome.json_record.is_none(), "human mode → no record");
}
// Failed fetch: counted as an error, FETCH_ERROR record, FetchFailed breadcrumb.
#[test]
fn classify_joined_fetch_failure_emits_fetch_error() {
let outcome = classify_joined(
Ok(TaskOutcome {
input: "arxiv:2401.99999".to_string(),
result: Err(anyhow!("connection refused")),
}),
true,
);
assert!(outcome.is_error);
let rec = outcome.json_record.expect("json_mode → record");
assert_eq!(rec["ok"], false);
assert_eq!(rec["ref"], "arxiv:2401.99999");
assert_eq!(rec["error"]["code"], "FETCH_ERROR");
assert!(matches!(
outcome.log_breadcrumb,
LogBreadcrumb::FetchFailed { .. }
));
}
// Panicking task: a real JoinError is obtained by panicking inside a JoinSet
// on a throwaway current-thread runtime, then fed to the classifier.
#[test]
fn classify_joined_panic_emits_null_ref_fetch_error() {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("tokio runtime");
let join_err = rt.block_on(async {
let mut js: tokio::task::JoinSet<TaskOutcome> = tokio::task::JoinSet::new();
js.spawn(async { panic!("synthetic panic for classify_joined") });
let joined = js.join_next().await.expect("one task");
joined.expect_err("expected panic → Err(JoinError)")
});
let outcome = classify_joined(Err(join_err), true);
assert!(outcome.is_error);
let rec = outcome.json_record.expect("json_mode → record");
assert_eq!(rec["ok"], false);
assert!(
rec["ref"].is_null(),
"panic record's ref MUST be null: {rec}"
);
assert_eq!(rec["error"]["code"], "FETCH_ERROR");
assert!(
rec["error"]["message"]
.as_str()
.unwrap_or("")
.contains("batch task panicked"),
"panic message preserved: {rec}"
);
assert!(matches!(
outcome.log_breadcrumb,
LogBreadcrumb::TaskPanicked { .. }
));
}
// Smoke test: emitting every breadcrumb variant must not panic.
#[test]
fn log_breadcrumb_emit_does_not_panic_on_any_variant() {
for variant in [
LogBreadcrumb::None,
LogBreadcrumb::FetchFailed {
input: "x".into(),
error_dbg: "y".into(),
},
LogBreadcrumb::TaskPanicked {
error_dbg: "z".into(),
},
] {
variant.emit();
}
}
// Each record must serialise without embedded newlines so that one record
// occupies exactly one output line.
#[test]
fn jsonl_records_are_single_line_serialised() {
let s = build_jsonl_success("10.1/x").to_string();
assert!(
!s.contains('\n'),
"JSONL success must be single-line: {s:?}"
);
let s2 = build_jsonl_failure(Some("10.1/x"), "FETCH_ERROR", "msg").to_string();
assert!(
!s2.contains('\n'),
"JSONL failure must be single-line: {s2:?}"
);
let s3 = build_jsonl_failure(None, "FETCH_ERROR", "msg").to_string();
assert!(
!s3.contains('\n'),
"null-ref JSONL must be single-line: {s3:?}"
);
}
// NOTE(review): this test re-implements the trim/filter chain used by
// `run_with_options` inline instead of calling production code, so it will
// keep passing even if the production filter changes.
#[test]
fn parses_and_filters_input_lines() {
let raw = "\
arxiv:2401.12345
# a comment line
# indented comment with leading whitespace
arxiv:2401.12346
\t\t
arxiv:2401.12347
";
let lines: Vec<String> = raw
.lines()
.map(str::trim)
.filter(|l| !l.is_empty() && !l.starts_with('#'))
.map(|s| s.to_string())
.collect();
assert_eq!(
lines,
vec![
"arxiv:2401.12345".to_string(),
"arxiv:2401.12346".to_string(),
"arxiv:2401.12347".to_string(),
],
);
}
// NOTE(review): only checks that a generated body of MCP_BATCH_MAX_SIZE + 1
// lines survives the inline filter intact; it never calls
// `run_with_options`, so the rejection itself is not exercised here.
#[test]
fn over_limit_input_is_rejected() {
let n = MCP_BATCH_MAX_SIZE + 1;
let body: String = (0..n)
.map(|i| format!("arxiv:2401.{:05}\n", 10000 + i))
.collect();
let lines: Vec<String> = body
.lines()
.map(str::trim)
.filter(|l| !l.is_empty() && !l.starts_with('#'))
.map(|s| s.to_string())
.collect();
assert_eq!(lines.len(), n);
assert!(lines.len() > MCP_BATCH_MAX_SIZE);
}
}