Skip to main content

claude_agent/client/
batch.rs

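//! Message Batches API client for large-scale asynchronous requests: create, inspect,
//! cancel, and list batches of Messages requests, and download their results.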
2
3use serde::{Deserialize, Serialize};
4use url::form_urlencoded;
5
6use super::messages::{CreateMessageRequest, ErrorResponse};
7use crate::types::ApiResponse;
8
9#[derive(Debug, Clone, Serialize)]
10pub struct BatchRequest {
11    pub custom_id: String,
12    pub params: CreateMessageRequest,
13}
14
15impl BatchRequest {
16    pub fn new(custom_id: impl Into<String>, params: CreateMessageRequest) -> Self {
17        Self {
18            custom_id: custom_id.into(),
19            params,
20        }
21    }
22}
23
24#[derive(Debug, Clone, Serialize)]
25pub struct CreateBatchRequest {
26    pub requests: Vec<BatchRequest>,
27}
28
29impl CreateBatchRequest {
30    pub fn new(requests: Vec<BatchRequest>) -> Self {
31        Self { requests }
32    }
33
34    pub fn request(mut self, request: BatchRequest) -> Self {
35        self.requests.push(request);
36        self
37    }
38}
39
40#[derive(Debug, Clone, Deserialize)]
41pub struct MessageBatch {
42    pub id: String,
43    #[serde(rename = "type")]
44    pub batch_type: String,
45    pub processing_status: BatchStatus,
46    pub request_counts: RequestCounts,
47    pub ended_at: Option<String>,
48    pub created_at: String,
49    pub expires_at: String,
50    pub cancel_initiated_at: Option<String>,
51    pub results_url: Option<String>,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum BatchStatus {
57    InProgress,
58    Canceling,
59    Ended,
60}
61
62#[derive(Debug, Clone, Copy, Default, Deserialize)]
63pub struct RequestCounts {
64    pub processing: u32,
65    pub succeeded: u32,
66    pub errored: u32,
67    pub canceled: u32,
68    pub expired: u32,
69}
70
71#[derive(Debug, Clone, Deserialize)]
72pub struct BatchResult {
73    pub custom_id: String,
74    pub result: BatchResultType,
75}
76
77#[derive(Debug, Clone, Deserialize)]
78#[serde(tag = "type", rename_all = "snake_case")]
79pub enum BatchResultType {
80    Succeeded { message: ApiResponse },
81    Errored { error: BatchError },
82    Canceled,
83    Expired,
84}
85
86#[derive(Debug, Clone, Deserialize)]
87pub struct BatchError {
88    #[serde(rename = "type")]
89    pub error_type: String,
90    pub message: String,
91}
92
93#[derive(Debug, Clone, Deserialize)]
94pub struct BatchListResponse {
95    pub data: Vec<MessageBatch>,
96    pub has_more: bool,
97    pub first_id: Option<String>,
98    pub last_id: Option<String>,
99}
100
101pub struct BatchClient<'a> {
102    client: &'a super::Client,
103}
104
105impl<'a> BatchClient<'a> {
106    pub fn new(client: &'a super::Client) -> Self {
107        Self { client }
108    }
109
110    fn base_url(&self) -> &str {
111        self.client.adapter().base_url()
112    }
113
114    fn api_version(&self) -> &str {
115        &self.client.config().api_version
116    }
117
118    fn build_url(&self, path: &str) -> String {
119        format!("{}/v1/messages/batches{}", self.base_url(), path)
120    }
121
122    async fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
123        if let Err(e) = self.client.adapter().ensure_fresh_credentials().await {
124            tracing::debug!("Proactive credential refresh failed: {}", e);
125        }
126
127        let mut request = self
128            .client
129            .http()
130            .request(method, url)
131            .header("anthropic-version", self.api_version())
132            .header("content-type", "application/json");
133
134        request = self.client.adapter().apply_auth_headers(request).await;
135
136        if let Some(beta_header) = self.client.config().beta.header_value() {
137            request = request.header("anthropic-beta", beta_header);
138        }
139
140        request
141    }
142
143    async fn handle_response<T: serde::de::DeserializeOwned>(
144        response: reqwest::Response,
145    ) -> crate::Result<T> {
146        if !response.status().is_success() {
147            let status = response.status().as_u16();
148            let error: ErrorResponse = response.json().await?;
149            return Err(error.into_error(status));
150        }
151        Ok(response.json().await?)
152    }
153
154    pub async fn create(&self, request: CreateBatchRequest) -> crate::Result<MessageBatch> {
155        let url = self.build_url("");
156        let response = self
157            .build_request(reqwest::Method::POST, &url)
158            .await
159            .json(&request)
160            .send()
161            .await?;
162        Self::handle_response(response).await
163    }
164
165    pub async fn get(&self, batch_id: &str) -> crate::Result<MessageBatch> {
166        let url = self.build_url(&format!("/{}", batch_id));
167        let response = self
168            .build_request(reqwest::Method::GET, &url)
169            .await
170            .send()
171            .await?;
172        Self::handle_response(response).await
173    }
174
175    pub async fn cancel(&self, batch_id: &str) -> crate::Result<MessageBatch> {
176        let url = self.build_url(&format!("/{}/cancel", batch_id));
177        let response = self
178            .build_request(reqwest::Method::POST, &url)
179            .await
180            .send()
181            .await?;
182        Self::handle_response(response).await
183    }
184
185    pub async fn list(
186        &self,
187        limit: Option<u32>,
188        after_id: Option<&str>,
189    ) -> crate::Result<BatchListResponse> {
190        let mut url = self.build_url("");
191
192        let mut query_params: Vec<(&str, String)> = Vec::new();
193        if let Some(limit) = limit {
194            query_params.push(("limit", limit.to_string()));
195        }
196        if let Some(after_id) = after_id {
197            query_params.push(("after_id", after_id.to_string()));
198        }
199        if !query_params.is_empty() {
200            let encoded: String = form_urlencoded::Serializer::new(String::new())
201                .extend_pairs(query_params.iter().map(|(k, v)| (*k, v.as_str())))
202                .finish();
203            url = format!("{}?{}", url, encoded);
204        }
205
206        let response = self
207            .build_request(reqwest::Method::GET, &url)
208            .await
209            .send()
210            .await?;
211        Self::handle_response(response).await
212    }
213
214    pub async fn results(&self, batch_id: &str) -> crate::Result<Vec<BatchResult>> {
215        let batch = self.get(batch_id).await?;
216
217        let results_url = batch.results_url.ok_or_else(|| crate::Error::Api {
218            message: "Batch results not yet available".to_string(),
219            status: None,
220            error_type: None,
221        })?;
222
223        let mut request = self
224            .client
225            .http()
226            .get(&results_url)
227            .header("anthropic-version", self.api_version());
228
229        request = self.client.adapter().apply_auth_headers(request).await;
230
231        let response = request.send().await?;
232
233        if !response.status().is_success() {
234            let status = response.status().as_u16();
235            return Err(crate::Error::Api {
236                message: format!("Failed to fetch batch results: HTTP {}", status),
237                status: Some(status),
238                error_type: None,
239            });
240        }
241
242        let text = response.text().await?;
243        let results: Vec<BatchResult> = text
244            .lines()
245            .filter(|line| !line.is_empty())
246            .filter_map(|line| serde_json::from_str(line).ok())
247            .collect();
248
249        Ok(results)
250    }
251
252    pub async fn wait_for_completion(
253        &self,
254        batch_id: &str,
255        poll_interval: std::time::Duration,
256    ) -> crate::Result<MessageBatch> {
257        loop {
258            let batch = self.get(batch_id).await?;
259            if batch.processing_status == BatchStatus::Ended {
260                return Ok(batch);
261            }
262            tokio::time::sleep(poll_interval).await;
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// Single user-turn request against the model used throughout these tests.
    fn user_request(text: &'static str) -> CreateMessageRequest {
        CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user(text)])
    }

    /// Deserializes a bare JSON string literal into a [`BatchStatus`].
    fn parse_status(raw: &str) -> BatchStatus {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn test_batch_request_serialization() {
        let entry = BatchRequest::new("test-1", user_request("Hello"));
        let encoded = serde_json::to_string(&entry).unwrap();
        assert!(encoded.contains("test-1"));
    }

    #[test]
    fn test_batch_status_deserialization() {
        assert_eq!(parse_status(r#""in_progress""#), BatchStatus::InProgress);
    }

    #[test]
    fn test_batch_status_all_variants() {
        for (raw, expected) in [
            (r#""canceling""#, BatchStatus::Canceling),
            (r#""ended""#, BatchStatus::Ended),
        ] {
            assert_eq!(parse_status(raw), expected);
        }
    }

    #[test]
    fn test_create_batch_request_builder() {
        let first = BatchRequest::new("req-1", user_request("A"));
        let second = BatchRequest::new("req-2", user_request("B"));

        let batch = CreateBatchRequest::new(vec![first]).request(second);
        let ids: Vec<&str> = batch
            .requests
            .iter()
            .map(|entry| entry.custom_id.as_str())
            .collect();
        assert_eq!(ids, ["req-1", "req-2"]);
    }

    #[test]
    fn test_request_counts_default() {
        let RequestCounts {
            processing,
            succeeded,
            errored,
            canceled,
            expired,
        } = RequestCounts::default();
        assert_eq!([processing, succeeded, errored, canceled, expired], [0; 5]);
    }

    #[test]
    fn test_request_counts_deserialization() {
        let json = r#"{"processing":5,"succeeded":10,"errored":2,"canceled":1,"expired":0}"#;
        let c: RequestCounts = serde_json::from_str(json).unwrap();
        assert_eq!(
            [c.processing, c.succeeded, c.errored, c.canceled, c.expired],
            [5, 10, 2, 1, 0]
        );
    }

    #[test]
    fn test_batch_error_deserialization() {
        let json = r#"{"type":"invalid_request","message":"Bad input"}"#;
        let BatchError {
            error_type,
            message,
        } = serde_json::from_str(json).unwrap();
        assert_eq!(error_type, "invalid_request");
        assert_eq!(message, "Bad input");
    }

    #[test]
    fn test_batch_result_succeeded() {
        let json = r#"{
            "custom_id": "req-1",
            "result": {
                "type": "succeeded",
                "message": {
                    "id": "msg_123",
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Hello"}],
                    "model": "claude-sonnet-4-5",
                    "stop_reason": "end_turn",
                    "usage": {"input_tokens": 10, "output_tokens": 5}
                }
            }
        }"#;
        let parsed: BatchResult = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.custom_id, "req-1");
        assert!(matches!(parsed.result, BatchResultType::Succeeded { .. }));
    }

    #[test]
    fn test_batch_result_errored() {
        let json = r#"{
            "custom_id": "req-2",
            "result": {
                "type": "errored",
                "error": {
                    "type": "rate_limit",
                    "message": "Too many requests"
                }
            }
        }"#;
        let parsed: BatchResult = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.custom_id, "req-2");
        match parsed.result {
            BatchResultType::Errored { error } => {
                assert_eq!(error.error_type, "rate_limit");
                assert_eq!(error.message, "Too many requests");
            }
            _ => panic!("Expected Errored variant"),
        }
    }

    #[test]
    fn test_batch_result_canceled() {
        let json = r#"{"custom_id": "req-3", "result": {"type": "canceled"}}"#;
        let parsed: BatchResult = serde_json::from_str(json).unwrap();
        assert!(matches!(parsed.result, BatchResultType::Canceled));
    }

    #[test]
    fn test_batch_result_expired() {
        let json = r#"{"custom_id": "req-4", "result": {"type": "expired"}}"#;
        let parsed: BatchResult = serde_json::from_str(json).unwrap();
        assert!(matches!(parsed.result, BatchResultType::Expired));
    }

    #[test]
    fn test_message_batch_deserialization() {
        let json = r#"{
            "id": "batch_123",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {"processing": 5, "succeeded": 0, "errored": 0, "canceled": 0, "expired": 0},
            "created_at": "2024-01-01T00:00:00Z",
            "expires_at": "2024-01-02T00:00:00Z",
            "ended_at": null,
            "cancel_initiated_at": null,
            "results_url": null
        }"#;
        let parsed: MessageBatch = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.id, "batch_123");
        assert_eq!(parsed.processing_status, BatchStatus::InProgress);
        assert_eq!(parsed.request_counts.processing, 5);
        assert!(parsed.ended_at.is_none());
        assert!(parsed.results_url.is_none());
    }

    #[test]
    fn test_batch_list_response() {
        let json = r#"{
            "data": [],
            "has_more": false,
            "first_id": null,
            "last_id": null
        }"#;
        let page: BatchListResponse = serde_json::from_str(json).unwrap();
        assert!(page.data.is_empty());
        assert!(!page.has_more);
    }

    #[test]
    fn test_batch_request_with_all_params() {
        let params = user_request("Test").max_tokens(1000).temperature(0.5);

        let entry = BatchRequest::new("custom-id-123", params);
        assert_eq!(entry.custom_id, "custom-id-123");
        assert_eq!(entry.params.max_tokens, 1000);
    }
}