use std::collections::HashMap;
use super::provider::BatchProvider;
use super::{
ApiError, BatchItem, BatchRequest, BatchResponse, BatchResult, ChatMessage, LlmClient,
LlmError, MessagesRequest, API_VERSION, BATCH_POLL_INTERVAL,
};
use crate::Store;
impl LlmClient {
fn submit_batch_inner(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
purpose: &str,
prompt_builder: fn(&str, &str, &str) -> String,
) -> Result<String, LlmError> {
if items.is_empty() {
return Err(LlmError::BatchFailed("Cannot submit empty batch".into()));
}
let _span =
tracing::info_span!("submit_batch_inner", purpose, count = items.len()).entered();
let model = self.llm_config.model.clone();
let requests: Vec<BatchItem> = items
.iter()
.map(|item| BatchItem {
custom_id: item.custom_id.clone(),
params: MessagesRequest {
model: model.clone(),
max_tokens,
messages: vec![ChatMessage {
role: "user".to_string(),
content: prompt_builder(&item.content, &item.context, &item.language),
}],
},
})
.collect();
let url = format!("{}/messages/batches", self.llm_config.api_base);
let response = self
.http
.post(&url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
.json(&BatchRequest { requests })
.send()?;
let status = response.status();
if status == 401 {
return Err(LlmError::Api {
status: 401,
message: "Invalid ANTHROPIC_API_KEY (401 Unauthorized)".to_string(),
});
}
if !status.is_success() {
let body = response.text().unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to read HTTP error response body");
String::new()
});
let message = serde_json::from_str::<ApiError>(&body)
.map(|err| format!("{purpose} submission failed: {}", err.error.message))
.unwrap_or_else(|_| format!("{purpose} submission failed: HTTP {status}: {body}"));
return Err(LlmError::Api {
status: status.as_u16(),
message,
});
}
let batch: BatchResponse = response.json()?;
if !super::is_valid_anthropic_batch_id(&batch.id) {
return Err(LlmError::InvalidBatchId(batch.id));
}
tracing::info!(batch_id = %batch.id, count = items.len(), "{purpose} submitted");
Ok(batch.id)
}
pub(super) fn submit_batch_prebuilt(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
) -> Result<String, LlmError> {
self.submit_batch_inner(items, max_tokens, "Batch", |content, _, _| {
content.to_string()
})
}
pub(super) fn submit_doc_batch(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
) -> Result<String, LlmError> {
self.submit_batch_inner(items, max_tokens, "Doc batch", Self::build_doc_prompt)
}
pub(super) fn submit_hyde_batch(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
) -> Result<String, LlmError> {
self.submit_batch_inner(items, max_tokens, "Hyde batch", Self::build_hyde_prompt)
}
pub(super) fn check_batch_status(&self, batch_id: &str) -> Result<String, LlmError> {
let _span = tracing::debug_span!("check_batch_status", batch_id).entered();
if !super::is_valid_anthropic_batch_id(batch_id) {
return Err(LlmError::InvalidBatchId(batch_id.to_string()));
}
let url = format!("{}/messages/batches/{}", self.llm_config.api_base, batch_id);
let response = self
.http
.get(&url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.send()?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to read HTTP error response body");
String::new()
});
return Err(LlmError::Api {
status,
message: format!("Batch status check failed: {body}"),
});
}
let batch: BatchResponse = response.json()?;
Ok(batch.processing_status)
}
pub(super) fn wait_for_batch(&self, batch_id: &str, quiet: bool) -> Result<(), LlmError> {
if !super::is_valid_anthropic_batch_id(batch_id) {
return Err(LlmError::InvalidBatchId(batch_id.to_string()));
}
let url = format!("{}/messages/batches/{}", self.llm_config.api_base, batch_id);
loop {
let response = self
.http
.get(&url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.send()?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to read HTTP error response body");
String::new()
});
return Err(LlmError::Api {
status,
message: format!("Batch status check failed: {body}"),
});
}
let batch: BatchResponse = response.json()?;
match batch.processing_status.as_str() {
"ended" => {
tracing::info!(batch_id, "Batch complete");
return Ok(());
}
"canceling" | "canceled" | "expired" => {
return Err(LlmError::BatchFailed(format!(
"Batch {} ended with status: {}",
batch_id, batch.processing_status
)));
}
_ => {
if !quiet {
eprint!(".");
}
tracing::debug!(batch_id, status = %batch.processing_status, "Batch still processing");
std::thread::sleep(BATCH_POLL_INTERVAL);
}
}
}
}
pub(super) fn fetch_batch_results(
&self,
batch_id: &str,
) -> Result<HashMap<String, String>, LlmError> {
if !super::is_valid_anthropic_batch_id(batch_id) {
return Err(LlmError::InvalidBatchId(batch_id.to_string()));
}
let url = format!(
"{}/messages/batches/{}/results",
self.llm_config.api_base, batch_id
);
let response = self
.http
.get(&url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.send()?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to read HTTP error response body");
String::new()
});
return Err(LlmError::Api {
status,
message: format!("Batch results fetch failed: {body}"),
});
}
const MAX_RESPONSE_BYTES: u64 = 100 * 1024 * 1024; if let Some(len) = response.content_length() {
if len > MAX_RESPONSE_BYTES {
return Err(LlmError::Api {
status: 200,
message: format!(
"Batch response too large: {} bytes (max {})",
len, MAX_RESPONSE_BYTES
),
});
}
}
use std::io::Read;
let mut body_bytes = Vec::new();
response
.take(MAX_RESPONSE_BYTES + 1)
.read_to_end(&mut body_bytes)
.map_err(|e| LlmError::Api {
status: 200,
message: format!("Failed to read batch response body: {e}"),
})?;
if body_bytes.len() as u64 > MAX_RESPONSE_BYTES {
return Err(LlmError::Api {
status: 200,
message: format!(
"Batch response exceeded {} byte limit while streaming",
MAX_RESPONSE_BYTES
),
});
}
let body = String::from_utf8(body_bytes).map_err(|e| LlmError::Api {
status: 200,
message: format!("Batch response not valid UTF-8: {e}"),
})?;
let mut results = HashMap::new();
for line in body.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
match serde_json::from_str::<BatchResult>(line) {
Ok(result) => {
if result.result.result_type == "succeeded" {
if let Some(msg) = result.result.message {
let text = msg
.content
.into_iter()
.find(|b| b.block_type == "text")
.and_then(|b| b.text);
if let Some(s) = text {
let trimmed = s.trim().to_string();
if !trimmed.is_empty() {
results.insert(result.custom_id, trimmed);
}
}
}
} else {
tracing::warn!(
custom_id = %result.custom_id,
result_type = %result.result.result_type,
"Batch item not succeeded"
);
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse batch result line");
}
}
}
tracing::info!(batch_id, succeeded = results.len(), "Batch results fetched");
Ok(results)
}
}
impl BatchProvider for LlmClient {
fn submit_batch(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
purpose: &str,
prompt_builder: fn(&str, &str, &str) -> String,
) -> Result<String, LlmError> {
self.submit_batch_inner(items, max_tokens, purpose, prompt_builder)
}
fn submit_batch_prebuilt(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
) -> Result<String, LlmError> {
LlmClient::submit_batch_prebuilt(self, items, max_tokens)
}
fn submit_doc_batch(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
) -> Result<String, LlmError> {
LlmClient::submit_doc_batch(self, items, max_tokens)
}
fn submit_hyde_batch(
&self,
items: &[super::provider::BatchSubmitItem],
max_tokens: u32,
) -> Result<String, LlmError> {
LlmClient::submit_hyde_batch(self, items, max_tokens)
}
fn check_batch_status(&self, batch_id: &str) -> Result<String, LlmError> {
LlmClient::check_batch_status(self, batch_id)
}
fn wait_for_batch(&self, batch_id: &str, quiet: bool) -> Result<(), LlmError> {
LlmClient::wait_for_batch(self, batch_id, quiet)
}
fn fetch_batch_results(&self, batch_id: &str) -> Result<HashMap<String, String>, LlmError> {
LlmClient::fetch_batch_results(self, batch_id)
}
fn is_valid_batch_id(&self, id: &str) -> bool {
super::is_valid_anthropic_batch_id(id)
}
fn model_name(&self) -> &str {
&self.llm_config.model
}
}
type PendingFn = dyn Fn(&Store, Option<&str>) -> Result<(), crate::store::StoreError>;
type SubmitFn = dyn Fn(
&dyn BatchProvider,
&[super::provider::BatchSubmitItem],
u32,
) -> Result<String, LlmError>;
pub(super) struct BatchPhase2<'a> {
pub purpose: &'static str,
pub max_tokens: u32,
pub quiet: bool,
pub lock_dir: Option<&'a std::path::Path>,
}
impl BatchPhase2<'_> {
fn acquire_batch_lock(&self) -> Result<Option<std::fs::File>, LlmError> {
let Some(dir) = self.lock_dir else {
return Ok(None);
};
let lock_path = dir.join("batch.lock");
let lock_file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&lock_path)
.map_err(|e| {
LlmError::Io(std::io::Error::new(
e.kind(),
format!(
"Failed to open batch lock at {}: {}",
lock_path.display(),
e
),
))
})?;
lock_file.lock().map_err(|e| {
LlmError::Io(std::io::Error::new(
e.kind(),
format!(
"Failed to acquire batch lock at {}: {}",
lock_path.display(),
e
),
))
})?;
tracing::debug!(
lock_path = %lock_path.display(),
purpose = self.purpose,
"Acquired batch lock"
);
Ok(Some(lock_file))
}
pub fn submit_or_resume(
&self,
client: &dyn BatchProvider,
store: &Store,
batch_items: &[super::provider::BatchSubmitItem],
get_pending: &dyn Fn(&Store) -> Result<Option<String>, crate::store::StoreError>,
set_pending: &PendingFn,
submit: &SubmitFn,
) -> Result<HashMap<String, String>, LlmError> {
let _span = tracing::info_span!(
"submit_or_resume",
purpose = self.purpose,
count = batch_items.len()
)
.entered();
let _batch_lock = self.acquire_batch_lock()?;
if batch_items.is_empty() {
return match get_pending(store) {
Ok(Some(pending)) => {
tracing::info!(batch_id = %pending, purpose = self.purpose, "Resuming pending batch");
let results = self.resume(client, store, &pending, set_pending)?;
tracing::info!(
count = results.len(),
purpose = self.purpose,
"Fetched pending batch results — new chunks will be processed on next run"
);
Ok(results)
}
Err(e) => {
tracing::error!(
error = %e,
purpose = self.purpose,
"Failed to read pending batch ID — if a batch was in progress, results may be lost. \
Check store health and re-run with --llm-summaries to recover."
);
Err(LlmError::Store(e))
}
_ => Ok(HashMap::new()),
};
}
let batch_id = match get_pending(store) {
Ok(Some(pending)) => {
tracing::info!(batch_id = %pending, purpose = self.purpose, "Found pending batch, checking status");
match client.check_batch_status(&pending) {
Ok(status)
if status == "in_progress"
|| status == "finalizing"
|| status == "created"
|| status == "ended" =>
{
tracing::info!(
batch_id = %pending,
status = %status,
purpose = self.purpose,
"Pending batch still active, resuming"
);
pending
}
Ok(status) => {
tracing::warn!(
old_batch = %pending,
status = %status,
purpose = self.purpose,
"Pending batch has unexpected status '{}', abandoning and submitting fresh. \
If the batch was valid, its results are lost.",
status
);
if let Err(e) = set_pending(store, None) {
tracing::warn!(error = %e, "Failed to clear stale pending batch ID");
}
self.submit_fresh(client, store, batch_items, set_pending, submit)?
}
Err(e) => {
tracing::warn!(
old_batch = %pending,
error = %e,
purpose = self.purpose,
"Failed to check pending batch status, submitting fresh"
);
if let Err(e) = set_pending(store, None) {
tracing::warn!(error = %e, "Failed to clear stale pending batch ID");
}
self.submit_fresh(client, store, batch_items, set_pending, submit)?
}
}
}
Err(e) => {
tracing::warn!(error = %e, purpose = self.purpose, "Failed to read pending batch ID");
self.submit_fresh(client, store, batch_items, set_pending, submit)?
}
_ => self.submit_fresh(client, store, batch_items, set_pending, submit)?,
};
let result = self.resume(client, store, &batch_id, set_pending);
drop(_batch_lock);
self.cleanup_batch_lock();
result
}
fn cleanup_batch_lock(&self) {
if let Some(dir) = self.lock_dir {
let lock_path = dir.join("batch.lock");
let _ = std::fs::remove_file(&lock_path);
}
}
fn resume(
&self,
client: &dyn BatchProvider,
store: &Store,
batch_id: &str,
clear_pending: &PendingFn,
) -> Result<HashMap<String, String>, LlmError> {
let _span = tracing::info_span!("batch_resume", batch_id, purpose = self.purpose).entered();
client.wait_for_batch(batch_id, self.quiet)?;
if !self.quiet {
eprintln!();
}
let results = client.fetch_batch_results(batch_id)?;
let hash_result = store.get_all_content_hashes();
let valid_hashes: std::collections::HashSet<String> = match &hash_result {
Ok(hashes) => hashes.iter().cloned().collect(),
Err(e) => {
tracing::warn!(error = %e, "Could not validate content hashes, storing all results");
std::collections::HashSet::new()
}
};
let (valid_results, stale_count) = if valid_hashes.is_empty() && hash_result.is_err() {
tracing::error!(purpose = self.purpose, "Cannot validate batch results — skipping storage to prevent stale data. Will retry on next run.");
if let Err(e) = clear_pending(store, None) {
tracing::warn!(error = %e, purpose = self.purpose, "Failed to clear pending batch ID");
}
return Ok(results);
} else if valid_hashes.is_empty() {
(results, 0usize)
} else {
let mut valid = HashMap::new();
let mut stale = 0usize;
for (hash, text) in results {
if valid_hashes.contains(&hash) {
valid.insert(hash, text);
} else {
stale += 1;
}
}
(valid, stale)
};
if stale_count > 0 {
tracing::warn!(
stale = stale_count,
valid = valid_results.len(),
purpose = self.purpose,
"Skipped {} stale batch results (content_hash not in current index — likely from a previous build)",
stale_count
);
}
let model = client.model_name().to_string();
let purpose = self.purpose.to_string();
let to_store: Vec<(String, String, String, String)> = valid_results
.iter()
.map(|(hash, text)| (hash.clone(), text.clone(), model.clone(), purpose.clone()))
.collect();
if !to_store.is_empty() {
store.upsert_summaries_batch(&to_store)?;
}
if let Err(e) = clear_pending(store, None) {
tracing::warn!(error = %e, purpose = self.purpose, "Failed to clear pending batch ID");
}
Ok(valid_results)
}
fn submit_fresh(
&self,
client: &dyn BatchProvider,
store: &Store,
batch_items: &[super::provider::BatchSubmitItem],
set_pending: &PendingFn,
submit: &SubmitFn,
) -> Result<String, LlmError> {
let _span = tracing::info_span!("submit_fresh", purpose = self.purpose).entered();
tracing::info!(
count = batch_items.len(),
purpose = self.purpose,
"Submitting batch to Claude API"
);
let id = submit(client, batch_items, self.max_tokens)?;
set_pending(store, Some(&id)).map_err(|e| {
tracing::error!(
error = %e,
batch_id = %id,
purpose = self.purpose,
"Failed to store pending batch ID — batch {} submitted but ID lost. \
Manual recovery: cqs llm-resume --batch-id {}",
id, id
);
LlmError::Store(e)
})?;
tracing::info!(batch_id = %id, purpose = self.purpose, "Batch submitted, waiting for results");
Ok(id)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::provider::{BatchSubmitItem, MockBatchProvider};
use crate::test_helpers::setup_store;
use std::collections::HashMap;
fn insert_chunk_with_hash(store: &Store, content_hash: &str) {
let embedding = crate::embedder::Embedding::new(vec![0.0f32; crate::EMBEDDING_DIM]);
let embedding_bytes =
crate::store::helpers::embedding_to_bytes(&embedding, crate::EMBEDDING_DIM).unwrap();
let now = chrono::Utc::now().to_rfc3339();
store.rt.block_on(async {
sqlx::query(
"INSERT INTO chunks (id, origin, source_type, language, chunk_type, name,
signature, content, content_hash, doc, line_start, line_end, embedding,
source_mtime, created_at, updated_at)
VALUES (?1, ?2, 'file', 'rust', 'function', ?3,
'', 'fn test() {}', ?4, NULL, 1, 10, ?5, 0, ?6, ?6)",
)
.bind(content_hash) .bind("test.rs")
.bind(format!("test_{}", content_hash))
.bind(content_hash)
.bind(&embedding_bytes)
.bind(&now)
.execute(&store.pool)
.await
.unwrap();
});
}
#[test]
fn test_resume_returns_valid_results_only() {
let (store, _dir) = setup_store();
insert_chunk_with_hash(&store, "hash_aaa");
insert_chunk_with_hash(&store, "hash_bbb");
let mut results = HashMap::new();
results.insert("hash_aaa".to_string(), "summary for aaa".to_string());
results.insert("hash_bbb".to_string(), "summary for bbb".to_string());
results.insert("hash_stale".to_string(), "summary for stale".to_string());
let mock = MockBatchProvider::new("msgbatch_test123", results);
let phase2 = BatchPhase2 {
purpose: "summary",
max_tokens: 1024,
quiet: true,
lock_dir: None,
};
let items = vec![BatchSubmitItem {
custom_id: "hash_aaa".to_string(),
content: "fn aaa() {}".to_string(),
context: "function".to_string(),
language: "rust".to_string(),
}];
let result = phase2.submit_or_resume(
&mock,
&store,
&items,
&|s| s.get_pending_batch_id(),
&|s, id| s.set_pending_batch_id(id),
&|c, items, max_tok| c.submit_batch_prebuilt(items, max_tok),
);
let map = result.unwrap();
assert!(
map.contains_key("hash_aaa"),
"hash_aaa should be in results"
);
assert!(
map.contains_key("hash_bbb"),
"hash_bbb should be in results"
);
assert!(
!map.contains_key("hash_stale"),
"stale hash should be filtered out"
);
assert_eq!(map.len(), 2);
}
#[test]
fn test_empty_batch_items() {
let (store, _dir) = setup_store();
let mock = MockBatchProvider::new("msgbatch_unused", HashMap::new());
let phase2 = BatchPhase2 {
purpose: "summary",
max_tokens: 1024,
quiet: true,
lock_dir: None,
};
let result = phase2.submit_or_resume(
&mock,
&store,
&[], &|s| s.get_pending_batch_id(),
&|s, id| s.set_pending_batch_id(id),
&|c, items, max_tok| c.submit_batch_prebuilt(items, max_tok),
);
let map = result.unwrap();
assert!(
map.is_empty(),
"empty batch_items should return empty results"
);
}
#[test]
fn test_is_valid_batch_id() {
assert!(super::super::is_valid_anthropic_batch_id("msgbatch_abc123"));
assert!(super::super::is_valid_anthropic_batch_id(
"msgbatch_0123456789"
));
assert!(super::super::is_valid_anthropic_batch_id(
"msgbatch_ABCdef_underscore"
));
assert!(!super::super::is_valid_anthropic_batch_id(""));
assert!(!super::super::is_valid_anthropic_batch_id("batch_123"));
assert!(!super::super::is_valid_anthropic_batch_id("not_a_batch"));
assert!(
!super::super::is_valid_anthropic_batch_id("msgbatch_has spaces"),
"spaces should be rejected"
);
assert!(
!super::super::is_valid_anthropic_batch_id("msgbatch_has/slash"),
"slash should be rejected"
);
assert!(
!super::super::is_valid_anthropic_batch_id(&format!("msgbatch_{}", "a".repeat(100))),
"over-length should be rejected"
);
}
#[test]
fn test_all_results_filtered_returns_empty() {
let (store, _dir) = setup_store();
insert_chunk_with_hash(&store, "hash_real");
let mut results = HashMap::new();
results.insert("nonexistent_hash_1".to_string(), "summary 1".to_string());
results.insert("nonexistent_hash_2".to_string(), "summary 2".to_string());
let mock = MockBatchProvider::new("msgbatch_filter", results);
let phase2 = BatchPhase2 {
purpose: "summary",
max_tokens: 1024,
quiet: true,
lock_dir: None,
};
let items = vec![BatchSubmitItem {
custom_id: "nonexistent_hash_1".to_string(),
content: "fn ghost() {}".to_string(),
context: "function".to_string(),
language: "rust".to_string(),
}];
let result = phase2.submit_or_resume(
&mock,
&store,
&items,
&|s| s.get_pending_batch_id(),
&|s, id| s.set_pending_batch_id(id),
&|c, items, max_tok| c.submit_batch_prebuilt(items, max_tok),
);
let map = result.unwrap();
assert!(map.is_empty(), "All stale results should be filtered out");
}
#[test]
fn test_resume_with_pending_batch() {
let (store, _dir) = setup_store();
insert_chunk_with_hash(&store, "hash_resume");
store
.set_pending_batch_id(Some("msgbatch_pending_resume"))
.unwrap();
let mut results = HashMap::new();
results.insert("hash_resume".to_string(), "resumed summary".to_string());
let mock = MockBatchProvider::new("msgbatch_pending_resume", results);
let phase2 = BatchPhase2 {
purpose: "summary",
max_tokens: 1024,
quiet: true,
lock_dir: None,
};
let result = phase2.submit_or_resume(
&mock,
&store,
&[],
&|s| s.get_pending_batch_id(),
&|s, id| s.set_pending_batch_id(id),
&|c, items, max_tok| c.submit_batch_prebuilt(items, max_tok),
);
let map = result.unwrap();
assert!(
map.contains_key("hash_resume"),
"Resumed batch should return valid results"
);
}
#[test]
fn test_clear_pending() {
let (store, _dir) = setup_store();
store
.set_pending_batch_id(Some("msgbatch_pending123"))
.unwrap();
assert_eq!(
store.get_pending_batch_id().unwrap(),
Some("msgbatch_pending123".to_string())
);
store.set_pending_batch_id(None).unwrap();
assert_eq!(store.get_pending_batch_id().unwrap(), None);
}
}