use std::collections::HashMap;
use super::LlmError;
/// A single request to include in a batch submission.
///
/// Derives `Debug` so items can be logged and compared in assertions, and
/// `Clone` so callers can keep a copy after handing a slice to a provider.
///
/// NOTE(review): the `prompt_builder` accepted by `BatchProvider::submit_batch`
/// takes three `&str` arguments; they presumably correspond to `content`,
/// `context` and `language` in that order — confirm against implementations.
#[derive(Debug, Clone)]
pub struct BatchSubmitItem {
    /// Caller-chosen identifier for this item. Presumably the key under which
    /// its output appears in `fetch_batch_results` — verify per provider.
    pub custom_id: String,
    /// Main input text for this item (its exact role depends on the prompt builder).
    pub content: String,
    /// Supplementary text accompanying `content`.
    pub context: String,
    /// Language tag for `content` (TODO confirm: programming vs. natural language).
    pub language: String,
}
/// A backend that runs LLM requests in batches.
///
/// The methods describe a submit → wait/poll → fetch flow: each `submit_*`
/// call yields a batch id string, which is then handed to
/// `check_batch_status`, `wait_for_batch` and `fetch_batch_results`.
pub trait BatchProvider {
    /// Submits `items` as a new batch, rendering each prompt via `prompt_builder`.
    ///
    /// Returns the identifier of the created batch.
    ///
    /// NOTE(review): `purpose` looks like a free-form label and `max_tokens`
    /// like a per-request output limit — confirm against implementations.
    fn submit_batch(
        &self,
        items: &[BatchSubmitItem],
        max_tokens: u32,
        purpose: &str,
        prompt_builder: fn(&str, &str, &str) -> String,
    ) -> Result<String, LlmError>;

    /// Submits `items` without applying a prompt builder.
    ///
    /// NOTE(review): presumably each item's `content` is already the full
    /// prompt — confirm. Returns the identifier of the created batch.
    fn submit_batch_prebuilt(
        &self,
        items: &[BatchSubmitItem],
        max_tokens: u32,
    ) -> Result<String, LlmError>;

    /// Submits a documentation-oriented batch (purpose inferred from the name;
    /// the prompt format is implementation-defined). Returns the batch id.
    fn submit_doc_batch(
        &self,
        items: &[BatchSubmitItem],
        max_tokens: u32,
    ) -> Result<String, LlmError>;

    /// Submits a HyDE batch (inferred from the name — presumably hypothetical
    /// document generation; confirm). Returns the batch id.
    fn submit_hyde_batch(
        &self,
        items: &[BatchSubmitItem],
        max_tokens: u32,
    ) -> Result<String, LlmError>;

    /// Reports the current status of `batch_id` as a provider-defined string.
    fn check_batch_status(&self, batch_id: &str) -> Result<String, LlmError>;

    /// Waits for `batch_id` to finish.
    ///
    /// NOTE(review): `quiet` presumably suppresses progress output — confirm.
    fn wait_for_batch(&self, batch_id: &str, quiet: bool) -> Result<(), LlmError>;

    /// Retrieves the outputs of `batch_id` as a string-to-string map.
    ///
    /// NOTE(review): keys are presumably the items' `custom_id`s — confirm.
    fn fetch_batch_results(&self, batch_id: &str) -> Result<HashMap<String, String>, LlmError>;

    /// Syntactic sanity check for a batch id.
    ///
    /// The default accepts any non-empty string; providers whose ids follow a
    /// known format can override this with a stricter check.
    fn is_valid_batch_id(&self, batch_id: &str) -> bool {
        !batch_id.is_empty()
    }

    /// Model identifier reported by this provider.
    fn model_name(&self) -> &str;
}
/// Test double for `BatchProvider` that answers with canned values instead of
/// talking to a real service.
#[cfg(test)]
pub(crate) struct MockBatchProvider {
    /// Identifier handed back by every submission call.
    pub batch_id: String,
    /// Outputs returned (as a copy) when results are fetched.
    pub results: HashMap<String, String>,
    /// Model name reported by the mock; `new` sets it to `"mock-model"`.
    pub model: String,
}

#[cfg(test)]
impl MockBatchProvider {
    /// Creates a mock that reports `batch_id` and `results`, with the model
    /// name fixed to `"mock-model"`.
    pub fn new(batch_id: &str, results: HashMap<String, String>) -> Self {
        let model = String::from("mock-model");
        let batch_id = batch_id.to_owned();
        Self {
            model,
            results,
            batch_id,
        }
    }
}
// Canned `BatchProvider` behaviour: no network calls, every batch "succeeds".
#[cfg(test)]
impl BatchProvider for MockBatchProvider {
    /// Ignores its inputs and returns the configured `batch_id`.
    fn submit_batch(
        &self,
        _items: &[BatchSubmitItem],
        _max_tokens: u32,
        _purpose: &str,
        _prompt_builder: fn(&str, &str, &str) -> String,
    ) -> Result<String, LlmError> {
        Ok(self.batch_id.to_owned())
    }

    /// Ignores its inputs and returns the configured `batch_id`.
    fn submit_batch_prebuilt(
        &self,
        _items: &[BatchSubmitItem],
        _max_tokens: u32,
    ) -> Result<String, LlmError> {
        Ok(self.batch_id.to_owned())
    }

    /// The mock treats this exactly like `submit_batch_prebuilt`.
    fn submit_doc_batch(
        &self,
        items: &[BatchSubmitItem],
        max_tokens: u32,
    ) -> Result<String, LlmError> {
        self.submit_batch_prebuilt(items, max_tokens)
    }

    /// The mock treats this exactly like `submit_batch_prebuilt`.
    fn submit_hyde_batch(
        &self,
        items: &[BatchSubmitItem],
        max_tokens: u32,
    ) -> Result<String, LlmError> {
        self.submit_batch_prebuilt(items, max_tokens)
    }

    /// Always reports `"ended"`, whatever `batch_id` is passed.
    fn check_batch_status(&self, _batch_id: &str) -> Result<String, LlmError> {
        Ok(String::from("ended"))
    }

    /// Returns immediately; there is nothing to wait for.
    fn wait_for_batch(&self, _batch_id: &str, _quiet: bool) -> Result<(), LlmError> {
        Ok(())
    }

    /// Hands back a copy of the configured `results`, whatever `batch_id` is passed.
    fn fetch_batch_results(&self, _batch_id: &str) -> Result<HashMap<String, String>, LlmError> {
        // `clone` (not a rebuild) keeps the map's hasher, so iteration order matches.
        let snapshot = self.results.clone();
        Ok(snapshot)
    }

    /// Accepts only ids with Anthropic's `msgbatch_` prefix, so tests can
    /// exercise provider-specific validation rather than the trait default.
    fn is_valid_batch_id(&self, id: &str) -> bool {
        const BATCH_ID_PREFIX: &str = "msgbatch_";
        id.starts_with(BATCH_ID_PREFIX)
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        self.model.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub provider that does not override `is_valid_batch_id`, so calls to it
    /// exercise the trait's default implementation. Every other method returns
    /// a trivial success value.
    struct DefaultValidationProvider;

    impl BatchProvider for DefaultValidationProvider {
        fn submit_batch(
            &self,
            _items: &[BatchSubmitItem],
            _max_tokens: u32,
            _purpose: &str,
            _prompt_builder: fn(&str, &str, &str) -> String,
        ) -> Result<String, LlmError> {
            Ok(String::default())
        }

        fn submit_batch_prebuilt(
            &self,
            _items: &[BatchSubmitItem],
            _max_tokens: u32,
        ) -> Result<String, LlmError> {
            Ok(String::default())
        }

        fn submit_doc_batch(
            &self,
            _items: &[BatchSubmitItem],
            _max_tokens: u32,
        ) -> Result<String, LlmError> {
            Ok(String::default())
        }

        fn submit_hyde_batch(
            &self,
            _items: &[BatchSubmitItem],
            _max_tokens: u32,
        ) -> Result<String, LlmError> {
            Ok(String::default())
        }

        fn check_batch_status(&self, _batch_id: &str) -> Result<String, LlmError> {
            Ok(String::from("ended"))
        }

        fn wait_for_batch(&self, _batch_id: &str, _quiet: bool) -> Result<(), LlmError> {
            Ok(())
        }

        fn fetch_batch_results(
            &self,
            _batch_id: &str,
        ) -> Result<HashMap<String, String>, LlmError> {
            Ok(HashMap::default())
        }

        fn model_name(&self) -> &str {
            "default-test"
        }
    }

    #[test]
    fn default_is_valid_batch_id_accepts_any_nonempty() {
        let provider = DefaultValidationProvider;
        let accepted = ["any_format_123", "custom-provider-batch-xyz", "msgbatch_abc"];
        for &id in &accepted {
            assert!(provider.is_valid_batch_id(id), "default validation rejected {:?}", id);
        }
    }

    #[test]
    fn default_is_valid_batch_id_rejects_empty() {
        let provider = DefaultValidationProvider;
        let empty = "";
        assert!(!provider.is_valid_batch_id(empty));
    }

    #[test]
    fn mock_provider_uses_anthropic_validation() {
        let mock = MockBatchProvider::new("msgbatch_test", HashMap::default());
        let (well_formed, foreign) = ("msgbatch_abc123", "other_format");
        assert!(mock.is_valid_batch_id(well_formed));
        assert!(!mock.is_valid_batch_id(foreign));
    }
}