1use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6
7use crate::error::ApiErrorPayload;
8use crate::messages::request::CreateMessageRequest;
9use crate::messages::response::Message;
10
11#[derive(Debug, Clone, Serialize)]
13#[non_exhaustive]
14pub struct BatchRequest {
15 pub custom_id: String,
18 pub params: CreateMessageRequest,
20}
21
22impl BatchRequest {
23 #[must_use]
25 pub fn new(custom_id: impl Into<String>, params: CreateMessageRequest) -> Self {
26 Self {
27 custom_id: custom_id.into(),
28 params,
29 }
30 }
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35#[non_exhaustive]
36pub struct MessageBatch {
37 pub id: String,
39 #[serde(rename = "type", default = "default_batch_kind")]
41 pub kind: String,
42 pub processing_status: ProcessingStatus,
44 pub request_counts: RequestCounts,
46 pub created_at: String,
48 pub expires_at: String,
50 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub ended_at: Option<String>,
53 #[serde(default, skip_serializing_if = "Option::is_none")]
55 pub archived_at: Option<String>,
56 #[serde(default, skip_serializing_if = "Option::is_none")]
58 pub cancel_initiated_at: Option<String>,
59 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub results_url: Option<String>,
64}
65
/// Serde fallback for `MessageBatch::kind` when the payload has no `type` field.
fn default_batch_kind() -> String {
    String::from("message_batch")
}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
75#[serde(rename_all = "snake_case")]
76#[non_exhaustive]
77pub enum ProcessingStatus {
78 InProgress,
80 Canceling,
82 Ended,
84 #[serde(other)]
86 Other,
87}
88
89#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
91#[non_exhaustive]
92pub struct RequestCounts {
93 #[serde(default)]
95 pub processing: u32,
96 #[serde(default)]
98 pub succeeded: u32,
99 #[serde(default)]
101 pub errored: u32,
102 #[serde(default)]
104 pub canceled: u32,
105 #[serde(default)]
107 pub expired: u32,
108}
109
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112#[non_exhaustive]
113pub struct BatchResultItem {
114 pub custom_id: String,
116 pub result: BatchResultPayload,
118}
119
120#[allow(clippy::large_enum_variant)]
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
123#[serde(tag = "type", rename_all = "snake_case")]
124#[non_exhaustive]
125pub enum BatchResultPayload {
126 Succeeded {
128 message: Message,
130 },
131 Errored {
133 error: ApiErrorPayload,
135 },
136 Canceled,
138 Expired,
140}
141
142#[derive(Debug, Clone, Default, Serialize)]
144#[non_exhaustive]
145pub struct ListBatchesParams {
146 #[serde(skip_serializing_if = "Option::is_none")]
148 pub before_id: Option<String>,
149 #[serde(skip_serializing_if = "Option::is_none")]
151 pub after_id: Option<String>,
152 #[serde(skip_serializing_if = "Option::is_none")]
154 pub limit: Option<u32>,
155}
156
157impl ListBatchesParams {
158 #[must_use]
160 pub fn after_id(mut self, id: impl Into<String>) -> Self {
161 self.after_id = Some(id.into());
162 self
163 }
164
165 #[must_use]
167 pub fn before_id(mut self, id: impl Into<String>) -> Self {
168 self.before_id = Some(id.into());
169 self
170 }
171
172 #[must_use]
174 pub fn limit(mut self, limit: u32) -> Self {
175 self.limit = Some(limit);
176 self
177 }
178}
179
180#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
182#[non_exhaustive]
183pub struct BatchDeleted {
184 pub id: String,
186 #[serde(rename = "type", default)]
188 pub kind: String,
189}
190
/// Settings for waiting on a batch. Start from `Default::default()` and chain
/// the builder methods.
///
/// Derives `PartialEq`/`Eq` like the other public value types here, so
/// configurations can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct WaitOptions {
    /// Polling interval.
    pub poll_interval: Duration,
    /// Optional overall wait limit. `None` presumably means no limit; confirm
    /// against the waiting loop, which is not in this file.
    pub timeout: Option<Duration>,
}
200
201impl Default for WaitOptions {
202 fn default() -> Self {
203 Self {
204 poll_interval: Duration::from_secs(30),
205 timeout: None,
206 }
207 }
208}
209
210impl WaitOptions {
211 #[must_use]
213 pub fn poll_interval(mut self, d: Duration) -> Self {
214 self.poll_interval = d;
215 self
216 }
217
218 #[must_use]
221 pub fn timeout(mut self, d: Duration) -> Self {
222 self.timeout = Some(d);
223 self
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Serializes `value` to JSON and parses it back, so "round trip" tests
    /// actually exercise the `Serialize` side too.
    fn reserialize<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_value(value).unwrap();
        serde_json::from_value(encoded).unwrap()
    }

    #[test]
    fn message_batch_in_progress_round_trips() {
        let raw = json!({
            "id": "msgbatch_01ABC",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {
                "processing": 100,
                "succeeded": 0,
                "errored": 0,
                "canceled": 0,
                "expired": 0
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z",
            "ended_at": null,
            "archived_at": null,
            "cancel_initiated_at": null,
            "results_url": null
        });
        let parsed: MessageBatch = serde_json::from_value(raw).unwrap();
        assert_eq!(parsed.id, "msgbatch_01ABC");
        assert_eq!(parsed.kind, "message_batch");
        assert_eq!(parsed.processing_status, ProcessingStatus::InProgress);
        assert_eq!(parsed.request_counts.processing, 100);
        assert_eq!(parsed.ended_at, None);
        assert_eq!(parsed.results_url, None);

        // `None` optionals are skipped on output, and the value survives a
        // serialize/deserialize cycle unchanged.
        let encoded = serde_json::to_value(&parsed).unwrap();
        assert!(encoded.get("ended_at").is_none());
        assert!(encoded.get("results_url").is_none());
        assert_eq!(reserialize(&parsed), parsed);
    }

    #[test]
    fn message_batch_ended_includes_results_url() {
        let raw = json!({
            "id": "msgbatch_01XYZ",
            "type": "message_batch",
            "processing_status": "ended",
            "request_counts": {
                "processing": 0, "succeeded": 95, "errored": 3,
                "canceled": 0, "expired": 2
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z",
            "ended_at": "2026-04-30T01:00:00Z",
            "results_url": "https://example/results"
        });
        let parsed: MessageBatch = serde_json::from_value(raw).unwrap();
        assert_eq!(parsed.processing_status, ProcessingStatus::Ended);
        assert_eq!(parsed.request_counts.succeeded, 95);
        assert_eq!(parsed.request_counts.errored, 3);
        assert_eq!(parsed.request_counts.expired, 2);
        assert_eq!(parsed.ended_at.as_deref(), Some("2026-04-30T01:00:00Z"));
        assert_eq!(parsed.results_url.as_deref(), Some("https://example/results"));
        assert_eq!(parsed.archived_at, None);
        assert_eq!(reserialize(&parsed), parsed);
    }

    #[test]
    fn message_batch_defaults_missing_type_and_counts() {
        let raw = json!({
            "id": "msgbatch_01DEF",
            "processing_status": "canceling",
            "request_counts": {"succeeded": 4},
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z"
        });
        let parsed: MessageBatch = serde_json::from_value(raw).unwrap();
        assert_eq!(parsed.kind, "message_batch");
        assert_eq!(parsed.processing_status, ProcessingStatus::Canceling);
        assert_eq!(
            parsed.request_counts,
            RequestCounts {
                succeeded: 4,
                ..RequestCounts::default()
            }
        );
        assert_eq!(parsed.cancel_initiated_at, None);
    }

    #[test]
    fn processing_status_unknown_falls_back_to_other() {
        let parsed: ProcessingStatus = serde_json::from_str("\"future_status\"").unwrap();
        assert_eq!(parsed, ProcessingStatus::Other);
    }

    #[test]
    fn processing_status_known_values_use_snake_case() {
        assert_eq!(
            serde_json::to_value(ProcessingStatus::InProgress).unwrap(),
            json!("in_progress")
        );
        let parsed: ProcessingStatus = serde_json::from_str("\"canceling\"").unwrap();
        assert_eq!(parsed, ProcessingStatus::Canceling);
    }

    #[test]
    fn batch_result_payload_succeeded_deserializes() {
        let raw = json!({
            "type": "succeeded",
            "message": {
                "id": "msg_X",
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": "hi"}],
                "model": "claude-sonnet-4-6",
                "stop_reason": "end_turn",
                "usage": {"input_tokens": 5, "output_tokens": 1}
            }
        });
        let parsed: BatchResultPayload = serde_json::from_value(raw).unwrap();
        match parsed {
            BatchResultPayload::Succeeded { message } => {
                assert_eq!(message.id, "msg_X");
            }
            other => panic!("expected Succeeded, got {other:?}"),
        }
    }

    #[test]
    fn batch_result_payload_errored_deserializes() {
        let raw = json!({
            "type": "errored",
            "error": {"type": "rate_limit_error", "message": "slow down"}
        });
        let parsed: BatchResultPayload = serde_json::from_value(raw).unwrap();
        assert!(matches!(parsed, BatchResultPayload::Errored { .. }));
    }

    #[test]
    fn batch_result_payload_canceled_and_expired_round_trip() {
        let canceled = json!({"type": "canceled"});
        let parsed: BatchResultPayload = serde_json::from_value(canceled.clone()).unwrap();
        assert!(matches!(parsed, BatchResultPayload::Canceled));
        assert_eq!(serde_json::to_value(&parsed).unwrap(), canceled);

        let expired = json!({"type": "expired"});
        let parsed: BatchResultPayload = serde_json::from_value(expired.clone()).unwrap();
        assert!(matches!(parsed, BatchResultPayload::Expired));
        assert_eq!(serde_json::to_value(&parsed).unwrap(), expired);
    }

    #[test]
    fn batch_result_item_round_trips() {
        let raw = json!({
            "custom_id": "req-42",
            "result": {"type": "canceled"}
        });
        let parsed: BatchResultItem = serde_json::from_value(raw.clone()).unwrap();
        assert_eq!(parsed.custom_id, "req-42");
        assert!(matches!(parsed.result, BatchResultPayload::Canceled));
        assert_eq!(serde_json::to_value(&parsed).unwrap(), raw);
    }
}