1use serde::{Deserialize, Serialize};
4use url::form_urlencoded;
5
6use super::messages::{CreateMessageRequest, ErrorResponse};
7use crate::types::ApiResponse;
8
9#[derive(Debug, Clone, Serialize)]
10pub struct BatchRequest {
11 pub custom_id: String,
12 pub params: CreateMessageRequest,
13}
14
15impl BatchRequest {
16 pub fn new(custom_id: impl Into<String>, params: CreateMessageRequest) -> Self {
17 Self {
18 custom_id: custom_id.into(),
19 params,
20 }
21 }
22}
23
24#[derive(Debug, Clone, Serialize)]
25pub struct CreateBatchRequest {
26 pub requests: Vec<BatchRequest>,
27}
28
29impl CreateBatchRequest {
30 pub fn new(requests: Vec<BatchRequest>) -> Self {
31 Self { requests }
32 }
33
34 pub fn request(mut self, request: BatchRequest) -> Self {
35 self.requests.push(request);
36 self
37 }
38}
39
40#[derive(Debug, Clone, Deserialize)]
41pub struct MessageBatch {
42 pub id: String,
43 #[serde(rename = "type")]
44 pub batch_type: String,
45 pub processing_status: BatchStatus,
46 pub request_counts: RequestCounts,
47 pub ended_at: Option<String>,
48 pub created_at: String,
49 pub expires_at: String,
50 pub cancel_initiated_at: Option<String>,
51 pub results_url: Option<String>,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum BatchStatus {
57 InProgress,
58 Canceling,
59 Ended,
60}
61
62#[derive(Debug, Clone, Copy, Default, Deserialize)]
63pub struct RequestCounts {
64 pub processing: u32,
65 pub succeeded: u32,
66 pub errored: u32,
67 pub canceled: u32,
68 pub expired: u32,
69}
70
71#[derive(Debug, Clone, Deserialize)]
72pub struct BatchResult {
73 pub custom_id: String,
74 pub result: BatchResultType,
75}
76
77#[derive(Debug, Clone, Deserialize)]
78#[serde(tag = "type", rename_all = "snake_case")]
79pub enum BatchResultType {
80 Succeeded { message: ApiResponse },
81 Errored { error: BatchError },
82 Canceled,
83 Expired,
84}
85
86#[derive(Debug, Clone, Deserialize)]
87pub struct BatchError {
88 #[serde(rename = "type")]
89 pub error_type: String,
90 pub message: String,
91}
92
93#[derive(Debug, Clone, Deserialize)]
94pub struct BatchListResponse {
95 pub data: Vec<MessageBatch>,
96 pub has_more: bool,
97 pub first_id: Option<String>,
98 pub last_id: Option<String>,
99}
100
101pub struct BatchClient<'a> {
102 client: &'a super::Client,
103}
104
105impl<'a> BatchClient<'a> {
106 pub fn new(client: &'a super::Client) -> Self {
107 Self { client }
108 }
109
110 fn base_url(&self) -> &str {
111 self.client.adapter().base_url()
112 }
113
114 fn api_version(&self) -> &str {
115 &self.client.config().api_version
116 }
117
118 fn build_url(&self, path: &str) -> String {
119 format!("{}/v1/messages/batches{}", self.base_url(), path)
120 }
121
122 async fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
123 if let Err(e) = self.client.adapter().ensure_fresh_credentials().await {
124 tracing::debug!("Proactive credential refresh failed: {}", e);
125 }
126
127 let mut request = self
128 .client
129 .http()
130 .request(method, url)
131 .header("anthropic-version", self.api_version())
132 .header("content-type", "application/json");
133
134 request = self.client.adapter().apply_auth_headers(request).await;
135
136 if let Some(beta_header) = self.client.config().beta.header_value() {
137 request = request.header("anthropic-beta", beta_header);
138 }
139
140 request
141 }
142
143 async fn handle_response<T: serde::de::DeserializeOwned>(
144 response: reqwest::Response,
145 ) -> crate::Result<T> {
146 if !response.status().is_success() {
147 let status = response.status().as_u16();
148 let error: ErrorResponse = response.json().await?;
149 return Err(error.into_error(status));
150 }
151 Ok(response.json().await?)
152 }
153
154 pub async fn create(&self, request: CreateBatchRequest) -> crate::Result<MessageBatch> {
155 let url = self.build_url("");
156 let response = self
157 .build_request(reqwest::Method::POST, &url)
158 .await
159 .json(&request)
160 .send()
161 .await?;
162 Self::handle_response(response).await
163 }
164
165 pub async fn get(&self, batch_id: &str) -> crate::Result<MessageBatch> {
166 let url = self.build_url(&format!("/{}", batch_id));
167 let response = self
168 .build_request(reqwest::Method::GET, &url)
169 .await
170 .send()
171 .await?;
172 Self::handle_response(response).await
173 }
174
175 pub async fn cancel(&self, batch_id: &str) -> crate::Result<MessageBatch> {
176 let url = self.build_url(&format!("/{}/cancel", batch_id));
177 let response = self
178 .build_request(reqwest::Method::POST, &url)
179 .await
180 .send()
181 .await?;
182 Self::handle_response(response).await
183 }
184
185 pub async fn list(
186 &self,
187 limit: Option<u32>,
188 after_id: Option<&str>,
189 ) -> crate::Result<BatchListResponse> {
190 let mut url = self.build_url("");
191
192 let mut query_params: Vec<(&str, String)> = Vec::new();
193 if let Some(limit) = limit {
194 query_params.push(("limit", limit.to_string()));
195 }
196 if let Some(after_id) = after_id {
197 query_params.push(("after_id", after_id.to_string()));
198 }
199 if !query_params.is_empty() {
200 let encoded: String = form_urlencoded::Serializer::new(String::new())
201 .extend_pairs(query_params.iter().map(|(k, v)| (*k, v.as_str())))
202 .finish();
203 url = format!("{}?{}", url, encoded);
204 }
205
206 let response = self
207 .build_request(reqwest::Method::GET, &url)
208 .await
209 .send()
210 .await?;
211 Self::handle_response(response).await
212 }
213
214 pub async fn results(&self, batch_id: &str) -> crate::Result<Vec<BatchResult>> {
215 let batch = self.get(batch_id).await?;
216
217 let results_url = batch.results_url.ok_or_else(|| crate::Error::Api {
218 message: "Batch results not yet available".to_string(),
219 status: None,
220 error_type: None,
221 })?;
222
223 let mut request = self
224 .client
225 .http()
226 .get(&results_url)
227 .header("anthropic-version", self.api_version());
228
229 request = self.client.adapter().apply_auth_headers(request).await;
230
231 let response = request.send().await?;
232
233 if !response.status().is_success() {
234 let status = response.status().as_u16();
235 return Err(crate::Error::Api {
236 message: format!("Failed to fetch batch results: HTTP {}", status),
237 status: Some(status),
238 error_type: None,
239 });
240 }
241
242 let text = response.text().await?;
243 let results: Vec<BatchResult> = text
244 .lines()
245 .filter(|line| !line.is_empty())
246 .filter_map(|line| serde_json::from_str(line).ok())
247 .collect();
248
249 Ok(results)
250 }
251
252 pub async fn wait_for_completion(
253 &self,
254 batch_id: &str,
255 poll_interval: std::time::Duration,
256 ) -> crate::Result<MessageBatch> {
257 loop {
258 let batch = self.get(batch_id).await?;
259 if batch.processing_status == BatchStatus::Ended {
260 return Ok(batch);
261 }
262 tokio::time::sleep(poll_interval).await;
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal single-user-message request used by several tests.
    fn user_request(text: &str) -> CreateMessageRequest {
        CreateMessageRequest::new("claude-sonnet-4-5", vec![crate::types::Message::user(text)])
    }

    /// Deserializes a `BatchResult` fixture, panicking on malformed JSON.
    fn parse_result(json: &str) -> BatchResult {
        serde_json::from_str(json).unwrap()
    }

    // A serialized entry must contain its custom id.
    #[test]
    fn test_batch_request_serialization() {
        let entry = BatchRequest::new("test-1", user_request("Hello"));
        let serialized = serde_json::to_string(&entry).unwrap();
        assert!(serialized.contains("test-1"));
    }

    // `"in_progress"` maps onto `InProgress`.
    #[test]
    fn test_batch_status_deserialization() {
        let parsed: BatchStatus = serde_json::from_str(r#""in_progress""#).unwrap();
        assert_eq!(parsed, BatchStatus::InProgress);
    }

    // The remaining wire values map onto their variants.
    #[test]
    fn test_batch_status_all_variants() {
        let canceling: BatchStatus = serde_json::from_str(r#""canceling""#).unwrap();
        assert_eq!(canceling, BatchStatus::Canceling);

        let ended: BatchStatus = serde_json::from_str(r#""ended""#).unwrap();
        assert_eq!(ended, BatchStatus::Ended);
    }

    // `request` appends after the entries passed to `new`.
    #[test]
    fn test_create_batch_request_builder() {
        let first = BatchRequest::new("req-1", user_request("A"));
        let second = BatchRequest::new("req-2", user_request("B"));

        let batch = CreateBatchRequest::new(vec![first]).request(second);
        assert_eq!(batch.requests.len(), 2);
        assert_eq!(batch.requests[0].custom_id, "req-1");
        assert_eq!(batch.requests[1].custom_id, "req-2");
    }

    // `Default` zeroes every counter.
    #[test]
    fn test_request_counts_default() {
        let zeroed = RequestCounts::default();
        assert_eq!(zeroed.processing, 0);
        assert_eq!(zeroed.succeeded, 0);
        assert_eq!(zeroed.errored, 0);
        assert_eq!(zeroed.canceled, 0);
        assert_eq!(zeroed.expired, 0);
    }

    // Each counter is read from its same-named JSON field.
    #[test]
    fn test_request_counts_deserialization() {
        let parsed: RequestCounts = serde_json::from_str(
            r#"{"processing":5,"succeeded":10,"errored":2,"canceled":1,"expired":0}"#,
        )
        .unwrap();
        assert_eq!(parsed.processing, 5);
        assert_eq!(parsed.succeeded, 10);
        assert_eq!(parsed.errored, 2);
        assert_eq!(parsed.canceled, 1);
        assert_eq!(parsed.expired, 0);
    }

    // The JSON `type` field lands in `error_type`.
    #[test]
    fn test_batch_error_deserialization() {
        let parsed: BatchError =
            serde_json::from_str(r#"{"type":"invalid_request","message":"Bad input"}"#).unwrap();
        assert_eq!(parsed.error_type, "invalid_request");
        assert_eq!(parsed.message, "Bad input");
    }

    // A `succeeded` result carries a full message payload.
    #[test]
    fn test_batch_result_succeeded() {
        let parsed = parse_result(
            r#"{
            "custom_id": "req-1",
            "result": {
                "type": "succeeded",
                "message": {
                    "id": "msg_123",
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Hello"}],
                    "model": "claude-sonnet-4-5",
                    "stop_reason": "end_turn",
                    "usage": {"input_tokens": 10, "output_tokens": 5}
                }
            }
        }"#,
        );
        assert_eq!(parsed.custom_id, "req-1");
        assert!(matches!(parsed.result, BatchResultType::Succeeded { .. }));
    }

    // An `errored` result exposes the error type and message.
    #[test]
    fn test_batch_result_errored() {
        let parsed = parse_result(
            r#"{
            "custom_id": "req-2",
            "result": {
                "type": "errored",
                "error": {
                    "type": "rate_limit",
                    "message": "Too many requests"
                }
            }
        }"#,
        );
        assert_eq!(parsed.custom_id, "req-2");
        match parsed.result {
            BatchResultType::Errored { error } => {
                assert_eq!(error.error_type, "rate_limit");
                assert_eq!(error.message, "Too many requests");
            }
            _ => panic!("Expected Errored variant"),
        }
    }

    // `canceled` has no payload.
    #[test]
    fn test_batch_result_canceled() {
        let parsed = parse_result(r#"{"custom_id": "req-3", "result": {"type": "canceled"}}"#);
        assert!(matches!(parsed.result, BatchResultType::Canceled));
    }

    // `expired` has no payload.
    #[test]
    fn test_batch_result_expired() {
        let parsed = parse_result(r#"{"custom_id": "req-4", "result": {"type": "expired"}}"#);
        assert!(matches!(parsed.result, BatchResultType::Expired));
    }

    // Null optional fields deserialize to `None`.
    #[test]
    fn test_message_batch_deserialization() {
        let fixture = r#"{
            "id": "batch_123",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {"processing": 5, "succeeded": 0, "errored": 0, "canceled": 0, "expired": 0},
            "created_at": "2024-01-01T00:00:00Z",
            "expires_at": "2024-01-02T00:00:00Z",
            "ended_at": null,
            "cancel_initiated_at": null,
            "results_url": null
        }"#;
        let parsed: MessageBatch = serde_json::from_str(fixture).unwrap();
        assert_eq!(parsed.id, "batch_123");
        assert_eq!(parsed.processing_status, BatchStatus::InProgress);
        assert_eq!(parsed.request_counts.processing, 5);
        assert!(parsed.ended_at.is_none());
        assert!(parsed.results_url.is_none());
    }

    // An empty page deserializes cleanly.
    #[test]
    fn test_batch_list_response() {
        let fixture = r#"{
            "data": [],
            "has_more": false,
            "first_id": null,
            "last_id": null
        }"#;
        let page: BatchListResponse = serde_json::from_str(fixture).unwrap();
        assert!(page.data.is_empty());
        assert!(!page.has_more);
    }

    // Builder-configured parameters survive wrapping in a batch entry.
    #[test]
    fn test_batch_request_with_all_params() {
        let params = user_request("Test").max_tokens(1000).temperature(0.5);

        let entry = BatchRequest::new("custom-id-123", params);
        assert_eq!(entry.custom_id, "custom-id-123");
        assert_eq!(entry.params.max_tokens, 1000);
    }
}