use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::ApiErrorPayload;
use crate::messages::request::CreateMessageRequest;
use crate::messages::response::Message;
/// A single entry submitted as part of a message batch.
///
/// Each entry pairs a caller-chosen `custom_id` with the parameters of one
/// message-creation request.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct BatchRequest {
    /// Caller-chosen identifier for this entry. Presumably echoed back in the
    /// matching `BatchResultItem::custom_id`; verify against the API docs.
    pub custom_id: String,
    /// The message-creation request to run for this entry.
    pub params: CreateMessageRequest,
}

impl BatchRequest {
    /// Builds a batch entry from an identifier and the request parameters.
    #[must_use]
    pub fn new(custom_id: impl Into<String>, params: CreateMessageRequest) -> Self {
        let custom_id: String = custom_id.into();
        Self { params, custom_id }
    }
}
/// A message batch object as received from the API.
///
/// Timestamps are stored as the raw strings sent by the server (for example
/// `"2026-04-30T00:00:00Z"`); this type does not parse them.
///
/// On serialization, optional fields that are `None` are omitted rather than
/// written as `null`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct MessageBatch {
    /// Batch identifier (e.g. `msgbatch_...`).
    pub id: String,
    /// Object type, carried as `"type"` on the wire. Falls back to
    /// `"message_batch"` when the field is absent.
    #[serde(rename = "type", default = "default_batch_kind")]
    pub kind: String,
    /// Current lifecycle state of the batch.
    pub processing_status: ProcessingStatus,
    /// Counts of the batch's requests by state. The object itself is required,
    /// but each counter inside it defaults to 0.
    pub request_counts: RequestCounts,
    /// Creation timestamp string, as sent by the server.
    pub created_at: String,
    /// Expiry timestamp string, as sent by the server.
    pub expires_at: String,
    /// `None` when the field is absent or `null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ended_at: Option<String>,
    /// `None` when the field is absent or `null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub archived_at: Option<String>,
    /// `None` when the field is absent or `null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cancel_initiated_at: Option<String>,
    /// Location of the batch results. `None` when absent or `null`;
    /// presumably only populated once processing has ended (confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub results_url: Option<String>,
}
/// Serde fallback for `MessageBatch::kind` when the payload has no `"type"`.
fn default_batch_kind() -> String {
    String::from("message_batch")
}
/// Lifecycle state of a message batch (`processing_status` on the wire).
///
/// Known variants map to `snake_case` strings: `"in_progress"`, `"canceling"`
/// and `"ended"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ProcessingStatus {
    /// Wire value `"in_progress"`.
    InProgress,
    /// Wire value `"canceling"`.
    Canceling,
    /// Wire value `"ended"`.
    Ended,
    /// Catch-all for any status string not listed above, so that a status
    /// added on the server side does not fail deserialization.
    ///
    /// The original string is discarded. Serializing this variant emits
    /// `"other"`, so it does not round-trip the unknown value.
    #[serde(other)]
    Other,
}
/// Number of requests in a batch, broken down by state.
///
/// Every counter is optional in the incoming JSON. A missing counter is taken
/// from `RequestCounts::default()`, which derives to all zeros, so `{}`
/// deserializes to five zero counts.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct RequestCounts {
    /// Requests still being processed.
    pub processing: u32,
    /// Requests that completed successfully.
    pub succeeded: u32,
    /// Requests that ended in an error.
    pub errored: u32,
    /// Requests that were canceled.
    pub canceled: u32,
    /// Requests that expired.
    pub expired: u32,
}
/// One entry of a batch's results: a request identifier paired with that
/// request's outcome.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct BatchResultItem {
    /// Identifier of the request this result belongs to. Presumably the
    /// `custom_id` supplied in the corresponding `BatchRequest`; verify.
    pub custom_id: String,
    /// Outcome of that request.
    pub result: BatchResultPayload,
}
/// Outcome of a single batch request.
///
/// This enum is internally tagged by a `"type"` field whose value is
/// `"succeeded"`, `"errored"`, `"canceled"` or `"expired"`. Data-carrying
/// variants hold their payload in a sibling field (`"message"` or `"error"`).
// The lint is silenced presumably because `Succeeded` embeds a whole `Message`
// and dwarfs the unit variants. Boxing it would change the public variant
// field type, which is a breaking change.
// NOTE(review): consider `Box<Message>` at the next breaking release if large
// result lists make the per-item size matter.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum BatchResultPayload {
    /// `"succeeded"`: carries the resulting message.
    Succeeded {
        message: Message,
    },
    /// `"errored"`: carries the API error payload.
    Errored {
        error: ApiErrorPayload,
    },
    /// `"canceled"`: no payload.
    Canceled,
    /// `"expired"`: no payload.
    Expired,
}
/// Query parameters for listing message batches.
///
/// Start from `ListBatchesParams::default()` and chain the builder methods.
/// Fields left as `None` are omitted when serialized, so only the options the
/// caller set are sent.
#[derive(Debug, Clone, Default, Serialize)]
#[non_exhaustive]
pub struct ListBatchesParams {
    /// Presumably a pagination cursor: list batches before this id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub before_id: Option<String>,
    /// Presumably a pagination cursor: list batches after this id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub after_id: Option<String>,
    /// Maximum number of batches to return. Not range-checked here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,
}

impl ListBatchesParams {
    /// Returns these params with `after_id` set to `id`.
    #[must_use]
    pub fn after_id(self, id: impl Into<String>) -> Self {
        Self {
            after_id: Some(id.into()),
            ..self
        }
    }

    /// Returns these params with `before_id` set to `id`.
    #[must_use]
    pub fn before_id(self, id: impl Into<String>) -> Self {
        Self {
            before_id: Some(id.into()),
            ..self
        }
    }

    /// Returns these params with `limit` set.
    #[must_use]
    pub fn limit(self, limit: u32) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }
}
/// Response body for a deleted message batch.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct BatchDeleted {
    /// Identifier of the batch that was deleted.
    pub id: String,
    /// Object type, carried as `"type"` on the wire.
    ///
    /// Falls back to `"message_batch_deleted"` when the field is absent. This
    /// matches how `MessageBatch::kind` falls back to its real type tag,
    /// instead of defaulting to an empty string.
    #[serde(rename = "type", default = "default_batch_deleted_kind")]
    pub kind: String,
}

/// Serde fallback for `BatchDeleted::kind` when the payload has no `"type"`.
fn default_batch_deleted_kind() -> String {
    "message_batch_deleted".to_owned()
}
/// Polling configuration for a wait-until-done loop. Presumably this is
/// consumed by a batch "wait" helper elsewhere in the crate; confirm.
///
/// NOTE(review): nothing here rejects `Duration::ZERO` for `poll_interval`.
/// A zero interval would presumably poll back-to-back.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WaitOptions {
    /// Delay between polls. Defaults to 30 seconds.
    pub poll_interval: Duration,
    /// Overall timeout. Defaults to `None`, presumably meaning "no limit";
    /// verify against the waiting code.
    pub timeout: Option<Duration>,
}

impl Default for WaitOptions {
    fn default() -> Self {
        const THIRTY_SECONDS: Duration = Duration::from_secs(30);
        Self {
            timeout: None,
            poll_interval: THIRTY_SECONDS,
        }
    }
}

impl WaitOptions {
    /// Returns these options with `poll_interval` replaced by `d`.
    #[must_use]
    pub fn poll_interval(self, d: Duration) -> Self {
        Self {
            poll_interval: d,
            ..self
        }
    }

    /// Returns these options with `timeout` set to `Some(d)`.
    #[must_use]
    pub fn timeout(self, d: Duration) -> Self {
        Self {
            timeout: Some(d),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    //! Wire-format tests for the batch types.
    //!
    //! These cover deserialization of representative API payloads and real
    //! serialize -> deserialize round trips where the types allow it.
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    #[test]
    fn message_batch_in_progress_round_trips() {
        let raw = json!({
            "id": "msgbatch_01ABC",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {
                "processing": 100,
                "succeeded": 0,
                "errored": 0,
                "canceled": 0,
                "expired": 0
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z",
            "ended_at": null,
            "archived_at": null,
            "cancel_initiated_at": null,
            "results_url": null
        });
        let parsed: MessageBatch = serde_json::from_value(raw).unwrap();
        assert_eq!(parsed.id, "msgbatch_01ABC");
        assert_eq!(parsed.kind, "message_batch");
        assert_eq!(parsed.processing_status, ProcessingStatus::InProgress);
        assert_eq!(parsed.request_counts.processing, 100);
        assert_eq!(parsed.ended_at, None);

        // `None` optionals are omitted on output rather than written as
        // `null`, and the output must parse back to an identical value.
        let reserialized = serde_json::to_value(&parsed).unwrap();
        assert!(reserialized.get("ended_at").is_none());
        assert!(reserialized.get("results_url").is_none());
        let reparsed: MessageBatch = serde_json::from_value(reserialized).unwrap();
        assert_eq!(reparsed, parsed);
    }

    #[test]
    fn message_batch_ended_includes_results_url() {
        let raw = json!({
            "id": "msgbatch_01XYZ",
            "type": "message_batch",
            "processing_status": "ended",
            "request_counts": {
                "processing": 0, "succeeded": 95, "errored": 3,
                "canceled": 0, "expired": 2
            },
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z",
            "ended_at": "2026-04-30T01:00:00Z",
            "results_url": "https://example/results"
        });
        let parsed: MessageBatch = serde_json::from_value(raw.clone()).unwrap();
        assert_eq!(parsed.processing_status, ProcessingStatus::Ended);
        assert_eq!(parsed.request_counts.succeeded, 95);
        assert_eq!(parsed.ended_at.as_deref(), Some("2026-04-30T01:00:00Z"));
        assert_eq!(parsed.results_url.as_deref(), Some("https://example/results"));
        // Every field present in the payload is written back unchanged, and
        // the absent optionals stay absent.
        assert_eq!(serde_json::to_value(&parsed).unwrap(), raw);
    }

    #[test]
    fn message_batch_missing_type_and_counts_use_defaults() {
        let raw = json!({
            "id": "msgbatch_01DEF",
            "processing_status": "canceling",
            "request_counts": {},
            "created_at": "2026-04-30T00:00:00Z",
            "expires_at": "2026-05-01T00:00:00Z"
        });
        let parsed: MessageBatch = serde_json::from_value(raw).unwrap();
        assert_eq!(parsed.kind, "message_batch");
        assert_eq!(parsed.processing_status, ProcessingStatus::Canceling);
        assert_eq!(parsed.request_counts, RequestCounts::default());
        assert_eq!(parsed.archived_at, None);
    }

    #[test]
    fn processing_status_unknown_falls_back_to_other() {
        let parsed: ProcessingStatus = serde_json::from_str("\"future_status\"").unwrap();
        assert_eq!(parsed, ProcessingStatus::Other);
    }

    #[test]
    fn processing_status_serializes_snake_case() {
        assert_eq!(
            serde_json::to_value(ProcessingStatus::InProgress).unwrap(),
            json!("in_progress")
        );
        assert_eq!(
            serde_json::to_value(ProcessingStatus::Canceling).unwrap(),
            json!("canceling")
        );
        assert_eq!(
            serde_json::to_value(ProcessingStatus::Ended).unwrap(),
            json!("ended")
        );
    }

    #[test]
    fn batch_result_payload_succeeded_round_trips() {
        let raw = json!({
            "type": "succeeded",
            "message": {
                "id": "msg_X",
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": "hi"}],
                "model": "claude-sonnet-4-6",
                "stop_reason": "end_turn",
                "usage": {"input_tokens": 5, "output_tokens": 1}
            }
        });
        let parsed: BatchResultPayload = serde_json::from_value(raw).unwrap();
        match parsed {
            BatchResultPayload::Succeeded { message } => {
                assert_eq!(message.id, "msg_X");
            }
            other => panic!("expected Succeeded, got {other:?}"),
        }
    }

    #[test]
    fn batch_result_payload_errored_round_trips() {
        let raw = json!({
            "type": "errored",
            "error": {"type": "rate_limit_error", "message": "slow down"}
        });
        let parsed: BatchResultPayload = serde_json::from_value(raw).unwrap();
        assert!(matches!(parsed, BatchResultPayload::Errored { .. }));
    }

    #[test]
    fn batch_result_payload_canceled_and_expired_round_trip() {
        let raw = json!({"type": "canceled"});
        let parsed: BatchResultPayload = serde_json::from_value(raw.clone()).unwrap();
        assert!(matches!(parsed, BatchResultPayload::Canceled));
        assert_eq!(serde_json::to_value(&parsed).unwrap(), raw);

        let raw = json!({"type": "expired"});
        let parsed: BatchResultPayload = serde_json::from_value(raw.clone()).unwrap();
        assert!(matches!(parsed, BatchResultPayload::Expired));
        assert_eq!(serde_json::to_value(&parsed).unwrap(), raw);
    }

    #[test]
    fn batch_result_item_round_trips() {
        let raw = json!({
            "custom_id": "req-42",
            "result": {"type": "canceled"}
        });
        let parsed: BatchResultItem = serde_json::from_value(raw.clone()).unwrap();
        assert_eq!(parsed.custom_id, "req-42");
        assert!(matches!(parsed.result, BatchResultPayload::Canceled));
        assert_eq!(serde_json::to_value(&parsed).unwrap(), raw);
    }

    #[test]
    fn list_batches_params_omit_unset_fields() {
        assert_eq!(
            serde_json::to_value(ListBatchesParams::default()).unwrap(),
            json!({})
        );
        let params = ListBatchesParams::default()
            .after_id("msgbatch_01ABC")
            .limit(20);
        assert_eq!(
            serde_json::to_value(&params).unwrap(),
            json!({"after_id": "msgbatch_01ABC", "limit": 20})
        );
    }

    #[test]
    fn wait_options_default_and_builders() {
        let opts = WaitOptions::default();
        assert_eq!(opts.poll_interval, Duration::from_secs(30));
        assert_eq!(opts.timeout, None);
        let opts = opts
            .poll_interval(Duration::from_secs(5))
            .timeout(Duration::from_secs(60));
        assert_eq!(opts.poll_interval, Duration::from_secs(5));
        assert_eq!(opts.timeout, Some(Duration::from_secs(60)));
    }
}