1#![cfg(feature = "async")]
4
5use std::time::Instant;
6
7use futures_util::stream::{BoxStream, Stream, StreamExt};
8use serde::Serialize;
9
10use crate::client::Client;
11use crate::error::{Error, Result};
12use crate::pagination::Paginated;
13
14use super::types::{
15 BatchDeleted, BatchRequest, BatchResultItem, ListBatchesParams, MessageBatch, ProcessingStatus,
16 WaitOptions,
17};
18
19pub struct Batches<'a> {
21 client: &'a Client,
22}
23
24impl<'a> Batches<'a> {
25 pub(crate) fn new(client: &'a Client) -> Self {
26 Self { client }
27 }
28
29 pub async fn create(&self, requests: Vec<BatchRequest>) -> Result<MessageBatch> {
32 #[derive(Serialize)]
33 struct Envelope<'r> {
34 requests: &'r [BatchRequest],
35 }
36 let envelope = Envelope {
37 requests: &requests,
38 };
39 let envelope_ref = &envelope;
40 self.client
41 .execute_with_retry(
42 || {
43 self.client
44 .request_builder(reqwest::Method::POST, "/v1/messages/batches")
45 .json(envelope_ref)
46 },
47 &[],
48 )
49 .await
50 }
51
52 pub async fn get(&self, id: &str) -> Result<MessageBatch> {
54 let path = format!("/v1/messages/batches/{id}");
55 self.client
56 .execute_with_retry(
57 || self.client.request_builder(reqwest::Method::GET, &path),
58 &[],
59 )
60 .await
61 }
62
63 pub async fn list(&self, params: ListBatchesParams) -> Result<Paginated<MessageBatch>> {
65 let params_ref = ¶ms;
66 self.client
67 .execute_with_retry(
68 || {
69 self.client
70 .request_builder(reqwest::Method::GET, "/v1/messages/batches")
71 .query(params_ref)
72 },
73 &[],
74 )
75 .await
76 }
77
78 pub async fn list_all(&self) -> Result<Vec<MessageBatch>> {
80 let mut all = Vec::new();
81 let mut params = ListBatchesParams::default();
82 loop {
83 let page = self.list(params.clone()).await?;
84 let next_cursor = page.next_after().map(str::to_owned);
85 all.extend(page.data);
86 match next_cursor {
87 Some(cursor) => params.after_id = Some(cursor),
88 None => break,
89 }
90 }
91 Ok(all)
92 }
93
94 pub async fn cancel(&self, id: &str) -> Result<MessageBatch> {
98 let path = format!("/v1/messages/batches/{id}/cancel");
99 self.client
100 .execute_with_retry(
101 || self.client.request_builder(reqwest::Method::POST, &path),
102 &[],
103 )
104 .await
105 }
106
107 pub async fn delete(&self, id: &str) -> Result<BatchDeleted> {
109 let path = format!("/v1/messages/batches/{id}");
110 self.client
111 .execute_with_retry(
112 || self.client.request_builder(reqwest::Method::DELETE, &path),
113 &[],
114 )
115 .await
116 }
117
118 pub async fn wait_for(&self, id: &str, options: WaitOptions) -> Result<MessageBatch> {
125 let started = Instant::now();
126 loop {
127 let batch = self.get(id).await?;
128 if matches!(
129 batch.processing_status,
130 ProcessingStatus::Ended | ProcessingStatus::Other
131 ) {
132 return Ok(batch);
133 }
134 if let Some(timeout) = options.timeout
135 && started.elapsed() >= timeout
136 {
137 return Err(Error::InvalidConfig(format!(
138 "wait_for({id}) timed out after {:?}",
139 started.elapsed()
140 )));
141 }
142 tokio::time::sleep(options.poll_interval).await;
143 }
144 }
145
146 pub async fn results(&self, id: &str) -> Result<Vec<BatchResultItem>> {
149 let mut stream = self.results_stream(id).await?;
150 let mut out = Vec::new();
151 while let Some(item) = stream.next().await {
152 out.push(item?);
153 }
154 Ok(out)
155 }
156
157 pub async fn results_stream(
165 &self,
166 id: &str,
167 ) -> Result<BoxStream<'static, Result<BatchResultItem>>> {
168 let path = format!("/v1/messages/batches/{id}/results");
169 let response = self
170 .client
171 .execute_streaming(
172 self.client.request_builder(reqwest::Method::GET, &path),
173 &[],
174 )
175 .await?;
176 Ok(jsonl_stream(response).boxed())
177 }
178}
179
180fn jsonl_stream<T>(response: reqwest::Response) -> impl Stream<Item = Result<T>> + Send + 'static
183where
184 T: serde::de::DeserializeOwned + Send + 'static,
185{
186 futures_util::stream::unfold(
187 (response.bytes_stream(), Vec::<u8>::new(), false),
188 |(mut bytes, mut buffer, done)| async move {
189 if done && buffer.is_empty() {
190 return None;
191 }
192 loop {
193 if let Some(newline_idx) = buffer.iter().position(|&b| b == b'\n') {
195 let line: Vec<u8> = buffer.drain(..=newline_idx).collect();
196 let trimmed = trim_trailing_newline(&line);
197 if trimmed.is_empty() {
198 continue;
200 }
201 let parsed: Result<T> = serde_json::from_slice(trimmed).map_err(Error::from);
202 return Some((parsed, (bytes, buffer, done)));
203 }
204
205 match bytes.next().await {
207 Some(Ok(chunk)) => buffer.extend_from_slice(&chunk),
208 Some(Err(e)) => {
209 return Some((Err(Error::from(e)), (bytes, buffer, true)));
210 }
211 None => {
212 if buffer.is_empty() {
214 return None;
215 }
216 let trimmed = trim_trailing_newline(&buffer);
217 let parsed: Result<T> =
218 serde_json::from_slice(trimmed).map_err(Error::from);
219 buffer.clear();
220 return Some((parsed, (bytes, buffer, true)));
221 }
222 }
223 }
224 },
225 )
226}
227
/// Return `bytes` without any trailing run of `\n` / `\r` bytes.
///
/// Only the end of the slice is trimmed; leading bytes and interior line breaks
/// are kept. A slice made up entirely of line-break bytes becomes empty.
fn trim_trailing_newline(bytes: &[u8]) -> &[u8] {
    let keep = bytes
        .iter()
        .rposition(|&byte| byte != b'\n' && byte != b'\r')
        .map_or(0, |last_content| last_content + 1);
    &bytes[..keep]
}
235
#[cfg(test)]
mod tests {
    // Each test mounts `wiremock` responders for the batches routes it needs
    // and drives the public `Batches` API against that mock server.
    use super::*;
    use crate::batches::types::BatchResultPayload;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a client with a test API key whose base URL points at `mock`.
    fn client_for(mock: &MockServer) -> Client {
        Client::builder()
            .api_key("sk-ant-test")
            .base_url(mock.uri())
            .build()
            .unwrap()
    }

    /// Fixture: batch `msgbatch_01` still `in_progress` with 2 requests processing.
    fn batch_in_progress() -> serde_json::Value {
        json!({
            "id": "msgbatch_01",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {
                "processing": 2, "succeeded": 0, "errored": 0,
                "canceled": 0, "expired": 0
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z"
        })
    }

    /// Fixture: batch `msgbatch_01` `ended` with 2 successes; also carries the
    /// `ended_at` and `results_url` fields.
    fn batch_ended() -> serde_json::Value {
        json!({
            "id": "msgbatch_01",
            "type": "message_batch",
            "processing_status": "ended",
            "request_counts": {
                "processing": 0, "succeeded": 2, "errored": 0,
                "canceled": 0, "expired": 0
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z",
            "ended_at": "2026-04-30T01:00:00Z",
            "results_url": "https://example/results"
        })
    }

    /// `create` must send the requests wrapped in a `{"requests": [...]}` envelope.
    #[tokio::test]
    async fn create_posts_envelope_with_requests_array() {
        use crate::messages::request::CreateMessageRequest;
        use crate::types::ModelId;

        let mock = MockServer::start().await;
        // `body_partial_json` only matches when the POST body contains at least
        // these fields, so a missing envelope makes the request go unmatched.
        Mock::given(method("POST"))
            .and(path("/v1/messages/batches"))
            .and(body_partial_json(json!({
                "requests": [
                    {
                        "custom_id": "r1",
                        "params": {
                            "model": "claude-sonnet-4-6",
                            "max_tokens": 8,
                            "messages": [{"role": "user", "content": "hi"}]
                        }
                    }
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_in_progress()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("hi")
            .build()
            .unwrap();
        let batch = client
            .batches()
            .create(vec![BatchRequest::new("r1", req)])
            .await
            .unwrap();
        assert_eq!(batch.id, "msgbatch_01");
        assert_eq!(batch.processing_status, ProcessingStatus::InProgress);
    }

    /// `get` hits `/v1/messages/batches/{id}` and decodes status and counts.
    #[tokio::test]
    async fn get_returns_status_for_id() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_ended()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let b = client.batches().get("msgbatch_01").await.unwrap();
        assert_eq!(b.processing_status, ProcessingStatus::Ended);
        assert_eq!(b.request_counts.succeeded, 2);
    }

    /// `cancel` POSTs to `.../cancel` and surfaces the `canceling` state and
    /// `cancel_initiated_at` timestamp from the response.
    #[tokio::test]
    async fn cancel_transitions_to_canceling() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/batches/msgbatch_01/cancel"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "msgbatch_01",
                "type": "message_batch",
                "processing_status": "canceling",
                "request_counts": {
                    "processing": 1, "succeeded": 0, "errored": 0,
                    "canceled": 1, "expired": 0
                },
                "created_at": "2026-04-30T00:00:00Z",
                "expires_at": "2026-05-01T00:00:00Z",
                "cancel_initiated_at": "2026-04-30T00:30:00Z"
            })))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let b = client.batches().cancel("msgbatch_01").await.unwrap();
        assert_eq!(b.processing_status, ProcessingStatus::Canceling);
        assert!(b.cancel_initiated_at.is_some());
    }

    /// `delete` decodes the `message_batch_deleted` confirmation body.
    #[tokio::test]
    async fn delete_returns_typed_confirmation() {
        let mock = MockServer::start().await;
        Mock::given(method("DELETE"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "msgbatch_01",
                "type": "message_batch_deleted"
            })))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let d = client.batches().delete("msgbatch_01").await.unwrap();
        assert_eq!(d.id, "msgbatch_01");
        assert_eq!(d.kind, "message_batch_deleted");
    }

    /// `list` decodes the paginated envelope (`data`, `has_more`, ids).
    #[tokio::test]
    async fn list_returns_paginated_envelope() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [batch_in_progress()],
                "has_more": false,
                "first_id": "msgbatch_01",
                "last_id": "msgbatch_01"
            })))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let page = client
            .batches()
            .list(ListBatchesParams::default())
            .await
            .unwrap();
        assert_eq!(page.data.len(), 1);
    }

    /// `wait_for` keeps polling while the batch is in progress and returns the
    /// first `ended` state it sees.
    #[tokio::test]
    async fn wait_for_polls_until_ended() {
        let mock = MockServer::start().await;
        // The in-progress responder is limited to two matches; later polls fall
        // through to the `ended` responder mounted after it.
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_in_progress()))
            .up_to_n_times(2)
            .mount(&mock)
            .await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_ended()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let opts = WaitOptions::default()
            .poll_interval(std::time::Duration::from_millis(1))
            .timeout(std::time::Duration::from_secs(5));
        let final_batch = client
            .batches()
            .wait_for("msgbatch_01", opts)
            .await
            .unwrap();
        assert_eq!(final_batch.processing_status, ProcessingStatus::Ended);
    }

    /// A batch that never ends must trip the 20 ms timeout, which `wait_for`
    /// reports as `Error::InvalidConfig`.
    #[tokio::test]
    async fn wait_for_honors_timeout() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01"))
            .respond_with(ResponseTemplate::new(200).set_body_json(batch_in_progress()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let opts = WaitOptions::default()
            .poll_interval(std::time::Duration::from_millis(1))
            .timeout(std::time::Duration::from_millis(20));
        let err = client
            .batches()
            .wait_for("msgbatch_01", opts)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::InvalidConfig(_)));
    }

    /// `results` decodes each JSONL line into the matching result variant.
    #[tokio::test]
    async fn results_decodes_jsonl_into_typed_items() {
        // Succeeded / errored / canceled records, one per line. The continuation
        // lines stay at column 0: only the line right after the trailing `\`
        // has its leading whitespace stripped by the string-continuation escape.
        let jsonl = "\
{\"custom_id\":\"r1\",\"result\":{\"type\":\"succeeded\",\"message\":{\"id\":\"m1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"a\"}],\"model\":\"claude-sonnet-4-6\",\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}}
{\"custom_id\":\"r2\",\"result\":{\"type\":\"errored\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow\"}}}
{\"custom_id\":\"r3\",\"result\":{\"type\":\"canceled\"}}
";
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "application/x-jsonl")
                    .set_body_string(jsonl),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let items = client.batches().results("msgbatch_01").await.unwrap();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].custom_id, "r1");
        assert!(matches!(
            items[0].result,
            BatchResultPayload::Succeeded { .. }
        ));
        assert_eq!(items[1].custom_id, "r2");
        assert!(matches!(
            items[1].result,
            BatchResultPayload::Errored { .. }
        ));
        assert!(matches!(items[2].result, BatchResultPayload::Canceled));
    }

    /// `results_stream` yields records one at a time and then terminates.
    #[tokio::test]
    async fn results_stream_yields_items_lazily() {
        let jsonl = "\
{\"custom_id\":\"a\",\"result\":{\"type\":\"canceled\"}}
{\"custom_id\":\"b\",\"result\":{\"type\":\"expired\"}}
";
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "application/x-jsonl")
                    .set_body_string(jsonl),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut stream = client
            .batches()
            .results_stream("msgbatch_01")
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.custom_id, "a");
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second.custom_id, "b");
        assert!(stream.next().await.is_none());
    }

    /// Empty lines before, between and after records are ignored.
    #[tokio::test]
    async fn results_stream_skips_blank_lines() {
        let jsonl = concat!(
            "\n",
            "{\"custom_id\":\"a\",\"result\":{\"type\":\"canceled\"}}\n",
            "\n",
            "{\"custom_id\":\"b\",\"result\":{\"type\":\"expired\"}}\n",
            "\n",
        );
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(ResponseTemplate::new(200).set_body_string(jsonl))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let items = client.batches().results("msgbatch_01").await.unwrap();
        assert_eq!(items.len(), 2);
    }

    /// The last record has no trailing `\n`; it must still be decoded once the
    /// body ends.
    #[tokio::test]
    async fn results_stream_handles_missing_trailing_newline() {
        let jsonl = "{\"custom_id\":\"a\",\"result\":{\"type\":\"canceled\"}}\n{\"custom_id\":\"b\",\"result\":{\"type\":\"expired\"}}";
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/messages/batches/msgbatch_01/results"))
            .respond_with(ResponseTemplate::new(200).set_body_string(jsonl))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let items = client.batches().results("msgbatch_01").await.unwrap();
        assert_eq!(items.len(), 2);
    }
}