Skip to main content

claude_api/batches/
api.rs

1//! The async `Batches<'a>` namespace.
2
3#![cfg(feature = "async")]
4
5use std::time::Instant;
6
7use futures_util::stream::{BoxStream, Stream, StreamExt};
8use serde::Serialize;
9
10use crate::client::Client;
11use crate::error::{Error, Result};
12use crate::pagination::Paginated;
13
14use super::types::{
15    BatchDeleted, BatchRequest, BatchResultItem, ListBatchesParams, MessageBatch, ProcessingStatus,
16    WaitOptions,
17};
18
/// Namespace handle for the Batches API.
///
/// Borrows the [`Client`] for `'a`; all requests are built and executed
/// through it (see [`Client::request_builder`] / retry execution).
pub struct Batches<'a> {
    // Borrowed client used to build, sign, and execute every HTTP request.
    client: &'a Client,
}
23
24impl<'a> Batches<'a> {
25    pub(crate) fn new(client: &'a Client) -> Self {
26        Self { client }
27    }
28
29    /// Submit a batch of [`BatchRequest`] entries. Returns the initial
30    /// [`MessageBatch`] status (typically `processing_status: in_progress`).
31    pub async fn create(&self, requests: Vec<BatchRequest>) -> Result<MessageBatch> {
32        #[derive(Serialize)]
33        struct Envelope<'r> {
34            requests: &'r [BatchRequest],
35        }
36        let envelope = Envelope {
37            requests: &requests,
38        };
39        let envelope_ref = &envelope;
40        self.client
41            .execute_with_retry(
42                || {
43                    self.client
44                        .request_builder(reqwest::Method::POST, "/v1/messages/batches")
45                        .json(envelope_ref)
46                },
47                &[],
48            )
49            .await
50    }
51
52    /// Fetch the current status of a batch by id.
53    pub async fn get(&self, id: &str) -> Result<MessageBatch> {
54        let path = format!("/v1/messages/batches/{id}");
55        self.client
56            .execute_with_retry(
57                || self.client.request_builder(reqwest::Method::GET, &path),
58                &[],
59            )
60            .await
61    }
62
63    /// Fetch one page of batches.
64    pub async fn list(&self, params: ListBatchesParams) -> Result<Paginated<MessageBatch>> {
65        let params_ref = &params;
66        self.client
67            .execute_with_retry(
68                || {
69                    self.client
70                        .request_builder(reqwest::Method::GET, "/v1/messages/batches")
71                        .query(params_ref)
72                },
73                &[],
74            )
75            .await
76    }
77
78    /// Fetch every batch, transparently paging.
79    pub async fn list_all(&self) -> Result<Vec<MessageBatch>> {
80        let mut all = Vec::new();
81        let mut params = ListBatchesParams::default();
82        loop {
83            let page = self.list(params.clone()).await?;
84            let next_cursor = page.next_after().map(str::to_owned);
85            all.extend(page.data);
86            match next_cursor {
87                Some(cursor) => params.after_id = Some(cursor),
88                None => break,
89            }
90        }
91        Ok(all)
92    }
93
94    /// Request cancellation of a batch. Already-running entries continue
95    /// until they finish; the batch transitions to `Canceling` and then
96    /// to `Ended` once those settle.
97    pub async fn cancel(&self, id: &str) -> Result<MessageBatch> {
98        let path = format!("/v1/messages/batches/{id}/cancel");
99        self.client
100            .execute_with_retry(
101                || self.client.request_builder(reqwest::Method::POST, &path),
102                &[],
103            )
104            .await
105    }
106
107    /// Delete a batch. Allowed only after the batch has ended.
108    pub async fn delete(&self, id: &str) -> Result<BatchDeleted> {
109        let path = format!("/v1/messages/batches/{id}");
110        self.client
111            .execute_with_retry(
112                || self.client.request_builder(reqwest::Method::DELETE, &path),
113                &[],
114            )
115            .await
116    }
117
118    /// Poll [`Self::get`] until the batch's `processing_status` is
119    /// [`ProcessingStatus::Ended`] (or any other terminal status the
120    /// server reports). Returns the final [`MessageBatch`].
121    ///
122    /// Honors [`WaitOptions::poll_interval`] between calls and
123    /// [`WaitOptions::timeout`] as an overall ceiling.
124    pub async fn wait_for(&self, id: &str, options: WaitOptions) -> Result<MessageBatch> {
125        let started = Instant::now();
126        loop {
127            let batch = self.get(id).await?;
128            if matches!(
129                batch.processing_status,
130                ProcessingStatus::Ended | ProcessingStatus::Other
131            ) {
132                return Ok(batch);
133            }
134            if let Some(timeout) = options.timeout
135                && started.elapsed() >= timeout
136            {
137                return Err(Error::InvalidConfig(format!(
138                    "wait_for({id}) timed out after {:?}",
139                    started.elapsed()
140                )));
141            }
142            tokio::time::sleep(options.poll_interval).await;
143        }
144    }
145
146    /// Fetch all batch results into a Vec. Convenience wrapper over
147    /// [`Self::results_stream`] for callers that don't need streaming.
148    pub async fn results(&self, id: &str) -> Result<Vec<BatchResultItem>> {
149        let mut stream = self.results_stream(id).await?;
150        let mut out = Vec::new();
151        while let Some(item) = stream.next().await {
152            out.push(item?);
153        }
154        Ok(out)
155    }
156
157    /// Stream the JSONL results body line-by-line, decoding each line as
158    /// a [`BatchResultItem`]. Returns immediately after the connection
159    /// is established; consumes lazily as the caller polls the stream.
160    ///
161    /// Mid-stream connection failures are surfaced as stream items;
162    /// retries are *not* applied (consistent with the SSE streaming
163    /// design -- silent retry would drop content).
164    pub async fn results_stream(
165        &self,
166        id: &str,
167    ) -> Result<BoxStream<'static, Result<BatchResultItem>>> {
168        let path = format!("/v1/messages/batches/{id}/results");
169        let response = self
170            .client
171            .execute_streaming(
172                self.client.request_builder(reqwest::Method::GET, &path),
173                &[],
174            )
175            .await?;
176        Ok(jsonl_stream(response).boxed())
177    }
178}
179
180/// Wrap a streaming response body in a `Stream` that yields one decoded
181/// `T` per JSONL line.
182fn jsonl_stream<T>(response: reqwest::Response) -> impl Stream<Item = Result<T>> + Send + 'static
183where
184    T: serde::de::DeserializeOwned + Send + 'static,
185{
186    futures_util::stream::unfold(
187        (response.bytes_stream(), Vec::<u8>::new(), false),
188        |(mut bytes, mut buffer, done)| async move {
189            if done && buffer.is_empty() {
190                return None;
191            }
192            loop {
193                // Try to extract the next complete line from the buffer.
194                if let Some(newline_idx) = buffer.iter().position(|&b| b == b'\n') {
195                    let line: Vec<u8> = buffer.drain(..=newline_idx).collect();
196                    let trimmed = trim_trailing_newline(&line);
197                    if trimmed.is_empty() {
198                        // Skip blank lines; loop to extract the next.
199                        continue;
200                    }
201                    let parsed: Result<T> = serde_json::from_slice(trimmed).map_err(Error::from);
202                    return Some((parsed, (bytes, buffer, done)));
203                }
204
205                // Need more bytes from the upstream stream.
206                match bytes.next().await {
207                    Some(Ok(chunk)) => buffer.extend_from_slice(&chunk),
208                    Some(Err(e)) => {
209                        return Some((Err(Error::from(e)), (bytes, buffer, true)));
210                    }
211                    None => {
212                        // Upstream EOF. Flush any trailing partial line.
213                        if buffer.is_empty() {
214                            return None;
215                        }
216                        let trimmed = trim_trailing_newline(&buffer);
217                        let parsed: Result<T> =
218                            serde_json::from_slice(trimmed).map_err(Error::from);
219                        buffer.clear();
220                        return Some((parsed, (bytes, buffer, true)));
221                    }
222                }
223            }
224        },
225    )
226}
227
/// Strip any run of trailing `\n` / `\r` bytes, returning the remaining
/// prefix as a subslice of `bytes` (no allocation).
fn trim_trailing_newline(bytes: &[u8]) -> &[u8] {
    // Index one past the last byte that is neither LF nor CR; 0 when the
    // whole slice is line terminators (or empty).
    let keep = bytes
        .iter()
        .rposition(|&b| b != b'\n' && b != b'\r')
        .map_or(0, |idx| idx + 1);
    &bytes[..keep]
}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batches::types::BatchResultPayload;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Build a Client pointed at the given wiremock server.
    fn client_for(mock: &MockServer) -> Client {
        Client::builder()
            .api_key("sk-ant-test")
            .base_url(mock.uri())
            .build()
            .unwrap()
    }

    // Fixture: a batch still processing (2 requests in flight).
    fn batch_in_progress() -> serde_json::Value {
        json!({
            "id": "msgbatch_01",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {
                "processing": 2, "succeeded": 0, "errored": 0,
                "canceled": 0, "expired": 0
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z"
        })
    }

    // Fixture: the same batch after it has ended (both requests succeeded).
    fn batch_ended() -> serde_json::Value {
        json!({
            "id": "msgbatch_01",
            "type": "message_batch",
            "processing_status": "ended",
            "request_counts": {
                "processing": 0, "succeeded": 2, "errored": 0,
                "canceled": 0, "expired": 0
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z",
            "ended_at": "2026-04-30T01:00:00Z",
            "results_url": "https://example/results"
        })
    }

    // `create` must POST a `{"requests": [...]}` envelope with custom_id +
    // params per entry, and deserialize the returned batch.
    #[tokio::test]
    async fn create_posts_envelope_with_requests_array() {
        use crate::messages::request::CreateMessageRequest;
        use crate::types::ModelId;

        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/batches"))
            .and(body_partial_json(json!({
                "requests": [
                    {
                        "custom_id": "r1",
                        "params": {
                            "model": "claude-sonnet-4-6",
                            "max_tokens": 8,
                            "messages": [{"role": "user", "content": "hi"}]
                        }
                    }
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_in_progress()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("hi")
            .build()
            .unwrap();
        let batch = client
            .batches()
            .create(vec![BatchRequest::new("r1", req)])
            .await
            .unwrap();
        assert_eq!(batch.id, "msgbatch_01");
        assert_eq!(batch.processing_status, ProcessingStatus::InProgress);
    }

    // `get` hits /v1/messages/batches/{id} and decodes the status payload.
    #[tokio::test]
    async fn get_returns_status_for_id() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_ended()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let b = client.batches().get("msgbatch_01").await.unwrap();
        assert_eq!(b.processing_status, ProcessingStatus::Ended);
        assert_eq!(b.request_counts.succeeded, 2);
    }

    // `cancel` POSTs to .../cancel and surfaces the `canceling` status plus
    // the `cancel_initiated_at` timestamp.
    #[tokio::test]
    async fn cancel_transitions_to_canceling() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/batches/msgbatch_01/cancel"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "msgbatch_01",
                "type": "message_batch",
                "processing_status": "canceling",
                "request_counts": {
                    "processing": 1, "succeeded": 0, "errored": 0,
                    "canceled": 1, "expired": 0
                },
                "created_at": "2026-04-30T00:00:00Z",
                "expires_at": "2026-05-01T00:00:00Z",
                "cancel_initiated_at": "2026-04-30T00:30:00Z"
            })))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let b = client.batches().cancel("msgbatch_01").await.unwrap();
        assert_eq!(b.processing_status, ProcessingStatus::Canceling);
        assert!(b.cancel_initiated_at.is_some());
    }

    // DELETE returns the typed `message_batch_deleted` confirmation.
    #[tokio::test]
    async fn delete_returns_typed_confirmation() {
        let mock = MockServer::start().await;
        Mock::given(method("DELETE"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "msgbatch_01",
                "type": "message_batch_deleted"
            })))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let d = client.batches().delete("msgbatch_01").await.unwrap();
        assert_eq!(d.id, "msgbatch_01");
        assert_eq!(d.kind, "message_batch_deleted");
    }

    // `list` decodes the paginated envelope (data/has_more/first_id/last_id).
    #[tokio::test]
    async fn list_returns_paginated_envelope() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [batch_in_progress()],
                "has_more": false,
                "first_id": "msgbatch_01",
                "last_id": "msgbatch_01"
            })))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let page = client
            .batches()
            .list(ListBatchesParams::default())
            .await
            .unwrap();
        assert_eq!(page.data.len(), 1);
    }

    // `wait_for` keeps polling past `in_progress` responses and returns
    // once the server reports `ended`. NOTE: the two mocks are mounted in
    // order -- the first answers twice (`up_to_n_times(2)`), then the
    // second takes over with the ended payload.
    #[tokio::test]
    async fn wait_for_polls_until_ended() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_in_progress()))
            .up_to_n_times(2)
            .mount(&mock)
            .await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_ended()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let opts = WaitOptions::default()
            .poll_interval(std::time::Duration::from_millis(1))
            .timeout(std::time::Duration::from_secs(5));
        let final_batch = client
            .batches()
            .wait_for("msgbatch_01", opts)
            .await
            .unwrap();
        assert_eq!(final_batch.processing_status, ProcessingStatus::Ended);
    }

    // A server that never ends the batch must trip the 20ms timeout and
    // surface it as the error variant `wait_for` uses for timeouts.
    #[tokio::test]
    async fn wait_for_honors_timeout() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_in_progress()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let opts = WaitOptions::default()
            .poll_interval(std::time::Duration::from_millis(1))
            .timeout(std::time::Duration::from_millis(20));
        let err = client
            .batches()
            .wait_for("msgbatch_01", opts)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::InvalidConfig(_)));
    }

    // `results` decodes a three-line JSONL body into typed succeeded /
    // errored / canceled payloads, preserving order.
    #[tokio::test]
    async fn results_decodes_jsonl_into_typed_items() {
        let jsonl = "\
{\"custom_id\":\"r1\",\"result\":{\"type\":\"succeeded\",\"message\":{\"id\":\"m1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"a\"}],\"model\":\"claude-sonnet-4-6\",\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}}
{\"custom_id\":\"r2\",\"result\":{\"type\":\"errored\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow\"}}}
{\"custom_id\":\"r3\",\"result\":{\"type\":\"canceled\"}}
";
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "application/x-jsonl")
                    .set_body_string(jsonl),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let items = client.batches().results("msgbatch_01").await.unwrap();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].custom_id, "r1");
        assert!(matches!(
            items[0].result,
            BatchResultPayload::Succeeded { .. }
        ));
        assert_eq!(items[1].custom_id, "r2");
        assert!(matches!(
            items[1].result,
            BatchResultPayload::Errored { .. }
        ));
        assert!(matches!(items[2].result, BatchResultPayload::Canceled));
    }

    // The streaming variant yields items one poll at a time and ends with
    // None after the last line.
    #[tokio::test]
    async fn results_stream_yields_items_lazily() {
        let jsonl = "\
{\"custom_id\":\"a\",\"result\":{\"type\":\"canceled\"}}
{\"custom_id\":\"b\",\"result\":{\"type\":\"expired\"}}
";
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "application/x-jsonl")
                    .set_body_string(jsonl),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut stream = client
            .batches()
            .results_stream("msgbatch_01")
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.custom_id, "a");
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second.custom_id, "b");
        assert!(stream.next().await.is_none());
    }

    // Blank lines interleaved with records (leading, between, trailing)
    // must be skipped, not surfaced as decode errors.
    #[tokio::test]
    async fn results_stream_skips_blank_lines() {
        let jsonl = concat!(
            "\n",
            "{\"custom_id\":\"a\",\"result\":{\"type\":\"canceled\"}}\n",
            "\n",
            "{\"custom_id\":\"b\",\"result\":{\"type\":\"expired\"}}\n",
            "\n",
        );
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(ResponseTemplate::new(200).set_body_string(jsonl))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let items = client.batches().results("msgbatch_01").await.unwrap();
        assert_eq!(items.len(), 2);
    }

    #[tokio::test]
    async fn results_stream_handles_missing_trailing_newline() {
        // Last line has no trailing \n; must still be emitted.
        let jsonl = "{\"custom_id\":\"a\",\"result\":{\"type\":\"canceled\"}}\n{\"custom_id\":\"b\",\"result\":{\"type\":\"expired\"}}";
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(ResponseTemplate::new(200).set_body_string(jsonl))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let items = client.batches().results("msgbatch_01").await.unwrap();
        assert_eq!(items.len(), 2);
    }
}