claude_agent/client/batch.rs

1//! Batch Processing API for large-scale asynchronous requests.
2
3use serde::{Deserialize, Serialize};
4use url::form_urlencoded;
5
6use super::messages::{CreateMessageRequest, ErrorResponse};
7use crate::types::ApiResponse;
8
/// Default API host for batch requests, used when `ANTHROPIC_BASE_URL` is not set.
const BATCH_BASE_URL: &str = concat!("https://", "api.anthropic.com");
10
11#[derive(Debug, Clone, Serialize)]
12pub struct BatchRequest {
13    pub custom_id: String,
14    pub params: CreateMessageRequest,
15}
16
17impl BatchRequest {
18    pub fn new(custom_id: impl Into<String>, params: CreateMessageRequest) -> Self {
19        Self {
20            custom_id: custom_id.into(),
21            params,
22        }
23    }
24}
25
26#[derive(Debug, Clone, Serialize)]
27pub struct CreateBatchRequest {
28    pub requests: Vec<BatchRequest>,
29}
30
31impl CreateBatchRequest {
32    pub fn new(requests: Vec<BatchRequest>) -> Self {
33        Self { requests }
34    }
35
36    pub fn with_request(mut self, request: BatchRequest) -> Self {
37        self.requests.push(request);
38        self
39    }
40}
41
42#[derive(Debug, Clone, Deserialize)]
43pub struct MessageBatch {
44    pub id: String,
45    #[serde(rename = "type")]
46    pub batch_type: String,
47    pub processing_status: BatchStatus,
48    pub request_counts: RequestCounts,
49    pub ended_at: Option<String>,
50    pub created_at: String,
51    pub expires_at: String,
52    pub cancel_initiated_at: Option<String>,
53    pub results_url: Option<String>,
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum BatchStatus {
59    InProgress,
60    Canceling,
61    Ended,
62}
63
64#[derive(Debug, Clone, Copy, Default, Deserialize)]
65pub struct RequestCounts {
66    pub processing: u32,
67    pub succeeded: u32,
68    pub errored: u32,
69    pub canceled: u32,
70    pub expired: u32,
71}
72
73#[derive(Debug, Clone, Deserialize)]
74pub struct BatchResult {
75    pub custom_id: String,
76    pub result: BatchResultType,
77}
78
79#[derive(Debug, Clone, Deserialize)]
80#[serde(tag = "type", rename_all = "snake_case")]
81pub enum BatchResultType {
82    Succeeded { message: ApiResponse },
83    Errored { error: BatchError },
84    Canceled,
85    Expired,
86}
87
88#[derive(Debug, Clone, Deserialize)]
89pub struct BatchError {
90    #[serde(rename = "type")]
91    pub error_type: String,
92    pub message: String,
93}
94
95#[derive(Debug, Clone, Deserialize)]
96pub struct BatchListResponse {
97    pub data: Vec<MessageBatch>,
98    pub has_more: bool,
99    pub first_id: Option<String>,
100    pub last_id: Option<String>,
101}
102
103pub struct BatchClient<'a> {
104    client: &'a super::Client,
105}
106
107impl<'a> BatchClient<'a> {
108    pub fn new(client: &'a super::Client) -> Self {
109        Self { client }
110    }
111
112    fn base_url(&self) -> String {
113        std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| BATCH_BASE_URL.into())
114    }
115
116    fn api_version(&self) -> &str {
117        &self.client.config().api_version
118    }
119
120    fn build_url(&self, path: &str) -> String {
121        format!("{}/v1/messages/batches{}", self.base_url(), path)
122    }
123
124    async fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
125        if let Err(e) = self.client.adapter().ensure_fresh_credentials().await {
126            tracing::debug!("Proactive credential refresh failed: {}", e);
127        }
128
129        let mut request = self
130            .client
131            .http()
132            .request(method, url)
133            .header("anthropic-version", self.api_version())
134            .header("content-type", "application/json");
135
136        request = self.client.adapter().apply_auth_headers(request).await;
137
138        if let Some(beta_header) = self.client.config().beta.header_value() {
139            request = request.header("anthropic-beta", beta_header);
140        }
141
142        request
143    }
144
145    pub async fn create(&self, request: CreateBatchRequest) -> crate::Result<MessageBatch> {
146        let url = self.build_url("");
147        let request = self
148            .build_request(reqwest::Method::POST, &url)
149            .await
150            .json(&request);
151        let response = request.send().await?;
152
153        if !response.status().is_success() {
154            let status = response.status().as_u16();
155            let error: ErrorResponse = response.json().await?;
156            return Err(error.into_error(status));
157        }
158
159        Ok(response.json().await?)
160    }
161
162    pub async fn get(&self, batch_id: &str) -> crate::Result<MessageBatch> {
163        let url = self.build_url(&format!("/{}", batch_id));
164        let response = self
165            .build_request(reqwest::Method::GET, &url)
166            .await
167            .send()
168            .await?;
169
170        if !response.status().is_success() {
171            let status = response.status().as_u16();
172            let error: ErrorResponse = response.json().await?;
173            return Err(error.into_error(status));
174        }
175
176        Ok(response.json().await?)
177    }
178
179    pub async fn cancel(&self, batch_id: &str) -> crate::Result<MessageBatch> {
180        let url = self.build_url(&format!("/{}/cancel", batch_id));
181        let response = self
182            .build_request(reqwest::Method::POST, &url)
183            .await
184            .send()
185            .await?;
186
187        if !response.status().is_success() {
188            let status = response.status().as_u16();
189            let error: ErrorResponse = response.json().await?;
190            return Err(error.into_error(status));
191        }
192
193        Ok(response.json().await?)
194    }
195
196    pub async fn list(
197        &self,
198        limit: Option<u32>,
199        after_id: Option<&str>,
200    ) -> crate::Result<BatchListResponse> {
201        let mut url = self.build_url("");
202
203        let mut query_params: Vec<(&str, String)> = Vec::new();
204        if let Some(limit) = limit {
205            query_params.push(("limit", limit.to_string()));
206        }
207        if let Some(after_id) = after_id {
208            query_params.push(("after_id", after_id.to_string()));
209        }
210        if !query_params.is_empty() {
211            let encoded: String = form_urlencoded::Serializer::new(String::new())
212                .extend_pairs(query_params.iter().map(|(k, v)| (*k, v.as_str())))
213                .finish();
214            url = format!("{}?{}", url, encoded);
215        }
216
217        let response = self
218            .build_request(reqwest::Method::GET, &url)
219            .await
220            .send()
221            .await?;
222
223        if !response.status().is_success() {
224            let status = response.status().as_u16();
225            let error: ErrorResponse = response.json().await?;
226            return Err(error.into_error(status));
227        }
228
229        Ok(response.json().await?)
230    }
231
232    pub async fn results(&self, batch_id: &str) -> crate::Result<Vec<BatchResult>> {
233        let batch = self.get(batch_id).await?;
234
235        let results_url = batch.results_url.ok_or_else(|| crate::Error::Api {
236            message: "Batch results not yet available".to_string(),
237            status: None,
238            error_type: None,
239        })?;
240
241        let mut request = self
242            .client
243            .http()
244            .get(&results_url)
245            .header("anthropic-version", self.api_version());
246
247        request = self.client.adapter().apply_auth_headers(request).await;
248
249        let response = request.send().await?;
250
251        if !response.status().is_success() {
252            let status = response.status().as_u16();
253            return Err(crate::Error::Api {
254                message: format!("Failed to fetch batch results: HTTP {}", status),
255                status: Some(status),
256                error_type: None,
257            });
258        }
259
260        let text = response.text().await?;
261        let results: Vec<BatchResult> = text
262            .lines()
263            .filter(|line| !line.is_empty())
264            .filter_map(|line| serde_json::from_str(line).ok())
265            .collect();
266
267        Ok(results)
268    }
269
270    pub async fn wait_for_completion(
271        &self,
272        batch_id: &str,
273        poll_interval: std::time::Duration,
274    ) -> crate::Result<MessageBatch> {
275        loop {
276            let batch = self.get(batch_id).await?;
277            if batch.processing_status == BatchStatus::Ended {
278                return Ok(batch);
279            }
280            tokio::time::sleep(poll_interval).await;
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal message request for use in batch entries.
    fn sample_params(text: &'static str) -> CreateMessageRequest {
        CreateMessageRequest::new("claude-sonnet-4-5", vec![crate::types::Message::user(text)])
    }

    #[test]
    fn test_batch_request_serialization() {
        let request = BatchRequest::new("test-1", sample_params("Hello"));
        let value = serde_json::to_value(&request).unwrap();
        // Pin the wire shape, not just the presence of the id somewhere in the text.
        assert_eq!(value["custom_id"], "test-1");
        assert!(value["params"].is_object());
        assert_eq!(value.as_object().map(|fields| fields.len()), Some(2));
    }

    #[test]
    fn test_create_batch_request_serialization() {
        let batch = CreateBatchRequest::new(vec![BatchRequest::new("a", sample_params("A"))])
            .with_request(BatchRequest::new("b", sample_params("B")));
        let value = serde_json::to_value(&batch).unwrap();
        let entries = value["requests"]
            .as_array()
            .expect("requests should serialize as a JSON array");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0]["custom_id"], "a");
        assert_eq!(entries[1]["custom_id"], "b");
    }

    #[test]
    fn test_batch_status_deserialization() {
        let json = r#""in_progress""#;
        let status: BatchStatus = serde_json::from_str(json).unwrap();
        assert_eq!(status, BatchStatus::InProgress);
    }

    #[test]
    fn test_batch_status_all_variants() {
        assert_eq!(
            serde_json::from_str::<BatchStatus>(r#""canceling""#).unwrap(),
            BatchStatus::Canceling
        );
        assert_eq!(
            serde_json::from_str::<BatchStatus>(r#""ended""#).unwrap(),
            BatchStatus::Ended
        );
    }

    #[test]
    fn test_batch_status_round_trip() {
        for (status, wire) in [
            (BatchStatus::InProgress, r#""in_progress""#),
            (BatchStatus::Canceling, r#""canceling""#),
            (BatchStatus::Ended, r#""ended""#),
        ] {
            assert_eq!(serde_json::to_string(&status).unwrap(), wire);
            assert_eq!(serde_json::from_str::<BatchStatus>(wire).unwrap(), status);
        }
    }

    #[test]
    fn test_create_batch_request_builder() {
        let req1 = BatchRequest::new("req-1", sample_params("A"));
        let req2 = BatchRequest::new("req-2", sample_params("B"));

        let batch = CreateBatchRequest::new(vec![req1]).with_request(req2);
        assert_eq!(batch.requests.len(), 2);
        assert_eq!(batch.requests[0].custom_id, "req-1");
        assert_eq!(batch.requests[1].custom_id, "req-2");
    }

    #[test]
    fn test_request_counts_default() {
        let counts = RequestCounts::default();
        assert_eq!(counts.processing, 0);
        assert_eq!(counts.succeeded, 0);
        assert_eq!(counts.errored, 0);
        assert_eq!(counts.canceled, 0);
        assert_eq!(counts.expired, 0);
    }

    #[test]
    fn test_request_counts_deserialization() {
        let json = r#"{"processing":5,"succeeded":10,"errored":2,"canceled":1,"expired":0}"#;
        let counts: RequestCounts = serde_json::from_str(json).unwrap();
        assert_eq!(counts.processing, 5);
        assert_eq!(counts.succeeded, 10);
        assert_eq!(counts.errored, 2);
        assert_eq!(counts.canceled, 1);
        assert_eq!(counts.expired, 0);
    }

    #[test]
    fn test_batch_error_deserialization() {
        let json = r#"{"type":"invalid_request","message":"Bad input"}"#;
        let error: BatchError = serde_json::from_str(json).unwrap();
        assert_eq!(error.error_type, "invalid_request");
        assert_eq!(error.message, "Bad input");
    }

    #[test]
    fn test_batch_result_succeeded() {
        let json = r#"{
            "custom_id": "req-1",
            "result": {
                "type": "succeeded",
                "message": {
                    "id": "msg_123",
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Hello"}],
                    "model": "claude-sonnet-4-5",
                    "stop_reason": "end_turn",
                    "usage": {"input_tokens": 10, "output_tokens": 5}
                }
            }
        }"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.custom_id, "req-1");
        assert!(matches!(result.result, BatchResultType::Succeeded { .. }));
    }

    #[test]
    fn test_batch_result_errored() {
        let json = r#"{
            "custom_id": "req-2",
            "result": {
                "type": "errored",
                "error": {
                    "type": "rate_limit",
                    "message": "Too many requests"
                }
            }
        }"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.custom_id, "req-2");
        if let BatchResultType::Errored { error } = result.result {
            assert_eq!(error.error_type, "rate_limit");
            assert_eq!(error.message, "Too many requests");
        } else {
            panic!("Expected Errored variant");
        }
    }

    #[test]
    fn test_batch_result_canceled() {
        let json = r#"{"custom_id": "req-3", "result": {"type": "canceled"}}"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert!(matches!(result.result, BatchResultType::Canceled));
    }

    #[test]
    fn test_batch_result_expired() {
        let json = r#"{"custom_id": "req-4", "result": {"type": "expired"}}"#;
        let result: BatchResult = serde_json::from_str(json).unwrap();
        assert!(matches!(result.result, BatchResultType::Expired));
    }

    #[test]
    fn test_message_batch_deserialization() {
        let json = r#"{
            "id": "batch_123",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {"processing": 5, "succeeded": 0, "errored": 0, "canceled": 0, "expired": 0},
            "created_at": "2024-01-01T00:00:00Z",
            "expires_at": "2024-01-02T00:00:00Z",
            "ended_at": null,
            "cancel_initiated_at": null,
            "results_url": null
        }"#;
        let batch: MessageBatch = serde_json::from_str(json).unwrap();
        assert_eq!(batch.id, "batch_123");
        assert_eq!(batch.processing_status, BatchStatus::InProgress);
        assert_eq!(batch.request_counts.processing, 5);
        assert!(batch.ended_at.is_none());
        assert!(batch.results_url.is_none());
    }

    #[test]
    fn test_message_batch_ended_with_results_url() {
        let json = r#"{
            "id": "batch_456",
            "type": "message_batch",
            "processing_status": "ended",
            "request_counts": {"processing": 0, "succeeded": 3, "errored": 1, "canceled": 0, "expired": 0},
            "created_at": "2024-01-01T00:00:00Z",
            "expires_at": "2024-01-02T00:00:00Z",
            "ended_at": "2024-01-01T01:00:00Z",
            "cancel_initiated_at": null,
            "results_url": "https://api.anthropic.com/v1/messages/batches/batch_456/results"
        }"#;
        let batch: MessageBatch = serde_json::from_str(json).unwrap();
        assert_eq!(batch.batch_type, "message_batch");
        assert_eq!(batch.processing_status, BatchStatus::Ended);
        assert_eq!(batch.ended_at.as_deref(), Some("2024-01-01T01:00:00Z"));
        assert_eq!(
            batch.results_url.as_deref(),
            Some("https://api.anthropic.com/v1/messages/batches/batch_456/results")
        );
        assert_eq!(batch.request_counts.succeeded, 3);
        assert_eq!(batch.request_counts.errored, 1);
    }

    #[test]
    fn test_batch_list_response() {
        let json = r#"{
            "data": [],
            "has_more": false,
            "first_id": null,
            "last_id": null
        }"#;
        let response: BatchListResponse = serde_json::from_str(json).unwrap();
        assert!(response.data.is_empty());
        assert!(!response.has_more);
    }

    #[test]
    fn test_batch_list_response_with_pagination_ids() {
        let json = r#"{
            "data": [],
            "has_more": true,
            "first_id": "batch_a",
            "last_id": "batch_z"
        }"#;
        let response: BatchListResponse = serde_json::from_str(json).unwrap();
        assert!(response.has_more);
        assert_eq!(response.first_id.as_deref(), Some("batch_a"));
        assert_eq!(response.last_id.as_deref(), Some("batch_z"));
    }

    #[test]
    fn test_batch_request_with_all_params() {
        let request = sample_params("Test")
            .with_max_tokens(1000)
            .with_temperature(0.5);

        let batch_req = BatchRequest::new("custom-id-123", request);
        assert_eq!(batch_req.custom_id, "custom-id-123");
        assert_eq!(batch_req.params.max_tokens, 1000);
    }
}