1use serde::{Deserialize, Serialize};
4use url::form_urlencoded;
5
6use super::messages::{CreateMessageRequest, ErrorResponse};
7use crate::types::ApiResponse;
8
/// Fallback API origin for the Message Batches endpoints, used by `BatchClient::base_url`
/// when the `ANTHROPIC_BASE_URL` environment variable does not supply one.
const BATCH_BASE_URL: &str = "https://api.anthropic.com";
10
11#[derive(Debug, Clone, Serialize)]
12pub struct BatchRequest {
13 pub custom_id: String,
14 pub params: CreateMessageRequest,
15}
16
17impl BatchRequest {
18 pub fn new(custom_id: impl Into<String>, params: CreateMessageRequest) -> Self {
19 Self {
20 custom_id: custom_id.into(),
21 params,
22 }
23 }
24}
25
26#[derive(Debug, Clone, Serialize)]
27pub struct CreateBatchRequest {
28 pub requests: Vec<BatchRequest>,
29}
30
31impl CreateBatchRequest {
32 pub fn new(requests: Vec<BatchRequest>) -> Self {
33 Self { requests }
34 }
35
36 pub fn with_request(mut self, request: BatchRequest) -> Self {
37 self.requests.push(request);
38 self
39 }
40}
41
42#[derive(Debug, Clone, Deserialize)]
43pub struct MessageBatch {
44 pub id: String,
45 #[serde(rename = "type")]
46 pub batch_type: String,
47 pub processing_status: BatchStatus,
48 pub request_counts: RequestCounts,
49 pub ended_at: Option<String>,
50 pub created_at: String,
51 pub expires_at: String,
52 pub cancel_initiated_at: Option<String>,
53 pub results_url: Option<String>,
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum BatchStatus {
59 InProgress,
60 Canceling,
61 Ended,
62}
63
64#[derive(Debug, Clone, Copy, Default, Deserialize)]
65pub struct RequestCounts {
66 pub processing: u32,
67 pub succeeded: u32,
68 pub errored: u32,
69 pub canceled: u32,
70 pub expired: u32,
71}
72
73#[derive(Debug, Clone, Deserialize)]
74pub struct BatchResult {
75 pub custom_id: String,
76 pub result: BatchResultType,
77}
78
79#[derive(Debug, Clone, Deserialize)]
80#[serde(tag = "type", rename_all = "snake_case")]
81pub enum BatchResultType {
82 Succeeded { message: ApiResponse },
83 Errored { error: BatchError },
84 Canceled,
85 Expired,
86}
87
88#[derive(Debug, Clone, Deserialize)]
89pub struct BatchError {
90 #[serde(rename = "type")]
91 pub error_type: String,
92 pub message: String,
93}
94
95#[derive(Debug, Clone, Deserialize)]
96pub struct BatchListResponse {
97 pub data: Vec<MessageBatch>,
98 pub has_more: bool,
99 pub first_id: Option<String>,
100 pub last_id: Option<String>,
101}
102
103pub struct BatchClient<'a> {
104 client: &'a super::Client,
105}
106
107impl<'a> BatchClient<'a> {
108 pub fn new(client: &'a super::Client) -> Self {
109 Self { client }
110 }
111
112 fn base_url(&self) -> String {
113 std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| BATCH_BASE_URL.into())
114 }
115
116 fn api_version(&self) -> &str {
117 &self.client.config().api_version
118 }
119
120 fn build_url(&self, path: &str) -> String {
121 format!("{}/v1/messages/batches{}", self.base_url(), path)
122 }
123
124 async fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
125 if let Err(e) = self.client.adapter().ensure_fresh_credentials().await {
126 tracing::debug!("Proactive credential refresh failed: {}", e);
127 }
128
129 let mut request = self
130 .client
131 .http()
132 .request(method, url)
133 .header("anthropic-version", self.api_version())
134 .header("content-type", "application/json");
135
136 request = self.client.adapter().apply_auth_headers(request).await;
137
138 if let Some(beta_header) = self.client.config().beta.header_value() {
139 request = request.header("anthropic-beta", beta_header);
140 }
141
142 request
143 }
144
145 pub async fn create(&self, request: CreateBatchRequest) -> crate::Result<MessageBatch> {
146 let url = self.build_url("");
147 let request = self
148 .build_request(reqwest::Method::POST, &url)
149 .await
150 .json(&request);
151 let response = request.send().await?;
152
153 if !response.status().is_success() {
154 let status = response.status().as_u16();
155 let error: ErrorResponse = response.json().await?;
156 return Err(error.into_error(status));
157 }
158
159 Ok(response.json().await?)
160 }
161
162 pub async fn get(&self, batch_id: &str) -> crate::Result<MessageBatch> {
163 let url = self.build_url(&format!("/{}", batch_id));
164 let response = self
165 .build_request(reqwest::Method::GET, &url)
166 .await
167 .send()
168 .await?;
169
170 if !response.status().is_success() {
171 let status = response.status().as_u16();
172 let error: ErrorResponse = response.json().await?;
173 return Err(error.into_error(status));
174 }
175
176 Ok(response.json().await?)
177 }
178
179 pub async fn cancel(&self, batch_id: &str) -> crate::Result<MessageBatch> {
180 let url = self.build_url(&format!("/{}/cancel", batch_id));
181 let response = self
182 .build_request(reqwest::Method::POST, &url)
183 .await
184 .send()
185 .await?;
186
187 if !response.status().is_success() {
188 let status = response.status().as_u16();
189 let error: ErrorResponse = response.json().await?;
190 return Err(error.into_error(status));
191 }
192
193 Ok(response.json().await?)
194 }
195
196 pub async fn list(
197 &self,
198 limit: Option<u32>,
199 after_id: Option<&str>,
200 ) -> crate::Result<BatchListResponse> {
201 let mut url = self.build_url("");
202
203 let mut query_params: Vec<(&str, String)> = Vec::new();
204 if let Some(limit) = limit {
205 query_params.push(("limit", limit.to_string()));
206 }
207 if let Some(after_id) = after_id {
208 query_params.push(("after_id", after_id.to_string()));
209 }
210 if !query_params.is_empty() {
211 let encoded: String = form_urlencoded::Serializer::new(String::new())
212 .extend_pairs(query_params.iter().map(|(k, v)| (*k, v.as_str())))
213 .finish();
214 url = format!("{}?{}", url, encoded);
215 }
216
217 let response = self
218 .build_request(reqwest::Method::GET, &url)
219 .await
220 .send()
221 .await?;
222
223 if !response.status().is_success() {
224 let status = response.status().as_u16();
225 let error: ErrorResponse = response.json().await?;
226 return Err(error.into_error(status));
227 }
228
229 Ok(response.json().await?)
230 }
231
232 pub async fn results(&self, batch_id: &str) -> crate::Result<Vec<BatchResult>> {
233 let batch = self.get(batch_id).await?;
234
235 let results_url = batch.results_url.ok_or_else(|| crate::Error::Api {
236 message: "Batch results not yet available".to_string(),
237 status: None,
238 error_type: None,
239 })?;
240
241 let mut request = self
242 .client
243 .http()
244 .get(&results_url)
245 .header("anthropic-version", self.api_version());
246
247 request = self.client.adapter().apply_auth_headers(request).await;
248
249 let response = request.send().await?;
250
251 if !response.status().is_success() {
252 let status = response.status().as_u16();
253 return Err(crate::Error::Api {
254 message: format!("Failed to fetch batch results: HTTP {}", status),
255 status: Some(status),
256 error_type: None,
257 });
258 }
259
260 let text = response.text().await?;
261 let results: Vec<BatchResult> = text
262 .lines()
263 .filter(|line| !line.is_empty())
264 .filter_map(|line| serde_json::from_str(line).ok())
265 .collect();
266
267 Ok(results)
268 }
269
270 pub async fn wait_for_completion(
271 &self,
272 batch_id: &str,
273 poll_interval: std::time::Duration,
274 ) -> crate::Result<MessageBatch> {
275 loop {
276 let batch = self.get(batch_id).await?;
277 if batch.processing_status == BatchStatus::Ended {
278 return Ok(batch);
279 }
280 tokio::time::sleep(poll_interval).await;
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_request_serialization() {
        let request = BatchRequest::new(
            "test-1",
            CreateMessageRequest::new(
                "claude-sonnet-4-5",
                vec![crate::types::Message::user("Hello")],
            ),
        );
        // Assert the JSON shape, not merely that the id appears somewhere in the output.
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["custom_id"], "test-1");
        assert!(json["params"].is_object());
    }

    #[test]
    fn test_create_batch_request_serialization() {
        let batch = CreateBatchRequest::new(vec![BatchRequest::new(
            "req-1",
            CreateMessageRequest::new("claude-sonnet-4-5", vec![crate::types::Message::user("A")]),
        )]);
        let json = serde_json::to_value(&batch).unwrap();
        assert_eq!(json["requests"].as_array().map(Vec::len), Some(1));
        assert_eq!(json["requests"][0]["custom_id"], "req-1");
    }

    #[test]
    fn test_batch_status_deserialization() {
        let json = r#""in_progress""#;
        let status: BatchStatus = serde_json::from_str(json).unwrap();
        assert_eq!(status, BatchStatus::InProgress);
    }

    #[test]
    fn test_batch_status_all_variants() {
        assert_eq!(
            serde_json::from_str::<BatchStatus>(r#""canceling""#).unwrap(),
            BatchStatus::Canceling
        );
        assert_eq!(
            serde_json::from_str::<BatchStatus>(r#""ended""#).unwrap(),
            BatchStatus::Ended
        );
    }

    #[test]
    fn test_batch_status_round_trip() {
        let cases = [
            (BatchStatus::InProgress, r#""in_progress""#),
            (BatchStatus::Canceling, r#""canceling""#),
            (BatchStatus::Ended, r#""ended""#),
        ];
        for (status, text) in cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), text);
            assert_eq!(serde_json::from_str::<BatchStatus>(text).unwrap(), status);
        }
    }

    #[test]
    fn test_create_batch_request_builder() {
        let req1 = BatchRequest::new(
            "req-1",
            CreateMessageRequest::new("claude-sonnet-4-5", vec![crate::types::Message::user("A")]),
        );
        let req2 = BatchRequest::new(
            "req-2",
            CreateMessageRequest::new("claude-sonnet-4-5", vec![crate::types::Message::user("B")]),
        );

        let batch = CreateBatchRequest::new(vec![req1]).with_request(req2);
        assert_eq!(batch.requests.len(), 2);
        assert_eq!(batch.requests[0].custom_id, "req-1");
        assert_eq!(batch.requests[1].custom_id, "req-2");
    }

    #[test]
    fn test_request_counts_default() {
        let counts = RequestCounts::default();
        assert_eq!(counts.processing, 0);
        assert_eq!(counts.succeeded, 0);
        assert_eq!(counts.errored, 0);
        assert_eq!(counts.canceled, 0);
        assert_eq!(counts.expired, 0);
    }

    #[test]
    fn test_request_counts_deserialization() {
        let json = r#"{"processing":5,"succeeded":10,"errored":2,"canceled":1,"expired":0}"#;
        let counts: RequestCounts = serde_json::from_str(json).unwrap();
        assert_eq!(counts.processing, 5);
        assert_eq!(counts.succeeded, 10);
        assert_eq!(counts.errored, 2);
        assert_eq!(counts.canceled, 1);
        assert_eq!(counts.expired, 0);
    }

    #[test]
    fn test_batch_error_deserialization() {
        let json = r#"{"type":"invalid_request","message":"Bad input"}"#;
        let error: BatchError = serde_json::from_str(json).unwrap();
        assert_eq!(error.error_type, "invalid_request");
        assert_eq!(error.message, "Bad input");
    }

    #[test]
    fn test_batch_result_succeeded() {
        let json = r#"{
            "custom_id": "req-1",
            "result": {
                "type": "succeeded",
                "message": {
                    "id": "msg_123",
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Hello"}],
                    "model": "claude-sonnet-4-5",
                    "stop_reason": "end_turn",
                    "usage": {"input_tokens": 10, "output_tokens": 5}
                }
            }
        }"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.custom_id, "req-1");
        assert!(matches!(result.result, BatchResultType::Succeeded { .. }));
    }

    #[test]
    fn test_batch_result_errored() {
        let json = r#"{
            "custom_id": "req-2",
            "result": {
                "type": "errored",
                "error": {
                    "type": "rate_limit",
                    "message": "Too many requests"
                }
            }
        }"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.custom_id, "req-2");
        if let BatchResultType::Errored { error } = result.result {
            assert_eq!(error.error_type, "rate_limit");
            assert_eq!(error.message, "Too many requests");
        } else {
            panic!("Expected Errored variant");
        }
    }

    #[test]
    fn test_batch_result_canceled() {
        let json = r#"{"custom_id": "req-3", "result": {"type": "canceled"}}"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert!(matches!(result.result, BatchResultType::Canceled));
    }

    #[test]
    fn test_batch_result_expired() {
        let json = r#"{"custom_id": "req-4", "result": {"type": "expired"}}"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert!(matches!(result.result, BatchResultType::Expired));
    }

    #[test]
    fn test_batch_result_unknown_type_is_rejected() {
        // The result enum has a fixed set of tags; anything else must not parse silently.
        let json = r#"{"custom_id": "req-5", "result": {"type": "something_new"}}"#;
        assert!(serde_json::from_str::<BatchResult>(json).is_err());
    }

    #[test]
    fn test_message_batch_deserialization() {
        let json = r#"{
            "id": "batch_123",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {"processing": 5, "succeeded": 0, "errored": 0, "canceled": 0, "expired": 0},
            "created_at": "2024-01-01T00:00:00Z",
            "expires_at": "2024-01-02T00:00:00Z",
            "ended_at": null,
            "cancel_initiated_at": null,
            "results_url": null
        }"#;
        let batch: MessageBatch = serde_json::from_str(json).unwrap();
        assert_eq!(batch.id, "batch_123");
        assert_eq!(batch.processing_status, BatchStatus::InProgress);
        assert_eq!(batch.request_counts.processing, 5);
        assert!(batch.ended_at.is_none());
        assert!(batch.results_url.is_none());
    }

    #[test]
    fn test_message_batch_optional_fields_may_be_omitted() {
        let json = r#"{
            "id": "batch_456",
            "type": "message_batch",
            "processing_status": "ended",
            "request_counts": {"processing": 0, "succeeded": 3, "errored": 0, "canceled": 0, "expired": 0},
            "created_at": "2024-01-01T00:00:00Z",
            "expires_at": "2024-01-02T00:00:00Z",
            "results_url": "https://api.anthropic.com/v1/messages/batches/batch_456/results"
        }"#;
        let batch: MessageBatch = serde_json::from_str(json).unwrap();
        assert_eq!(batch.processing_status, BatchStatus::Ended);
        assert!(batch.ended_at.is_none());
        assert!(batch.cancel_initiated_at.is_none());
        assert_eq!(
            batch.results_url.as_deref(),
            Some("https://api.anthropic.com/v1/messages/batches/batch_456/results")
        );
    }

    #[test]
    fn test_batch_list_response() {
        let json = r#"{
            "data": [],
            "has_more": false,
            "first_id": null,
            "last_id": null
        }"#;
        let response: BatchListResponse = serde_json::from_str(json).unwrap();
        assert!(response.data.is_empty());
        assert!(!response.has_more);
    }

    #[test]
    fn test_batch_request_with_all_params() {
        let request = CreateMessageRequest::new(
            "claude-sonnet-4-5",
            vec![crate::types::Message::user("Test")],
        )
        .with_max_tokens(1000)
        .with_temperature(0.5);

        let batch_req = BatchRequest::new("custom-id-123", request);
        assert_eq!(batch_req.custom_id, "custom-id-123");
        assert_eq!(batch_req.params.max_tokens, 1000);
    }
}