use std::collections::HashSet;
use serde::de::Error as DeError;
use serde::ser::{SerializeStruct, Serializer};
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use crate::types::{Message, MessageCreateParams};
/// Largest number of requests `MessageBatchCreateParams::validate` accepts in one batch.
const MAX_MESSAGE_BATCH_REQUESTS: usize = 100_000;
/// Largest JSON-encoded batch body, in bytes, that `MessageBatchCreateParams::validate`
/// accepts (256 MiB).
const MAX_MESSAGE_BATCH_BODY_BYTES: usize = 256 << 20;
/// Parameters for creating a message batch.
///
/// Only `requests` is part of the JSON body (see the manual `Serialize` impl); `betas` is
/// carried alongside the body and is neither serialized nor read during deserialization.
#[derive(Clone, Debug, PartialEq, Deserialize)]
pub struct MessageBatchCreateParams {
    /// The individual message requests that make up the batch.
    pub requests: Vec<MessageBatchCreateRequest>,
    /// Optional beta feature names. Excluded from serde in both directions; how they are
    /// transmitted is handled outside this type.
    #[serde(skip)]
    pub betas: Option<Vec<String>>,
}
impl MessageBatchCreateParams {
pub fn new(requests: Vec<MessageBatchCreateRequest>) -> Self {
Self {
requests,
betas: None,
}
}
pub fn with_betas(mut self, betas: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.betas = Some(betas.into_iter().map(Into::into).collect());
self
}
pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
self.betas.get_or_insert_with(Vec::new).push(beta.into());
self
}
pub fn validate(&self) -> crate::Result<()> {
if self.requests.is_empty() {
return Err(crate::Error::validation(
"At least one batch request is required",
Some("requests".to_string()),
));
}
if self.requests.len() > MAX_MESSAGE_BATCH_REQUESTS {
return Err(crate::Error::validation(
format!(
"Batch request count {} exceeds limit of {}",
self.requests.len(),
MAX_MESSAGE_BATCH_REQUESTS
),
Some("requests".to_string()),
));
}
let mut custom_ids = HashSet::with_capacity(self.requests.len());
for (i, request) in self.requests.iter().enumerate() {
if !is_valid_custom_id(&request.custom_id) {
return Err(crate::Error::validation(
"custom_id must be 1 to 64 characters and contain only alphanumeric characters, hyphens, and underscores",
Some(format!("requests[{i}].custom_id")),
));
}
if !custom_ids.insert(request.custom_id.as_str()) {
return Err(crate::Error::validation(
format!("Duplicate custom_id: {}", request.custom_id),
Some(format!("requests[{i}].custom_id")),
));
}
if request.params.stream {
return Err(crate::Error::validation(
"stream is not supported in message batch requests",
Some(format!("requests[{i}].params.stream")),
));
}
request.params.validate().map_err(|err| match err {
crate::Error::Validation { message, param } => crate::Error::validation(
message,
param.map(|param| format!("requests[{i}].params.{param}")),
),
other => other,
})?;
}
let body = serde_json::to_vec(self).map_err(|e| {
crate::Error::serialization(
format!("Failed to serialize message batch create params: {e}"),
Some(Box::new(e)),
)
})?;
if body.len() > MAX_MESSAGE_BATCH_BODY_BYTES {
return Err(crate::Error::validation(
format!(
"Serialized batch request size {} exceeds limit of {} bytes",
body.len(),
MAX_MESSAGE_BATCH_BODY_BYTES
),
Some("requests".to_string()),
));
}
Ok(())
}
}
impl Serialize for MessageBatchCreateParams {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("MessageBatchCreateParams", 1)?;
state.serialize_field("requests", &self.requests)?;
state.end()
}
}
/// One entry of a message batch: a caller-chosen identifier plus the message parameters.
///
/// Serialization is implemented manually so that `params` uses the batch wire format.
#[derive(Clone, Debug, PartialEq, Deserialize)]
pub struct MessageBatchCreateRequest {
    /// Identifier used to match this request with its result; checked by
    /// `MessageBatchCreateParams::validate`.
    pub custom_id: String,
    /// Parameters of the message-creation request.
    pub params: MessageCreateParams,
}
impl MessageBatchCreateRequest {
    /// Builds a batch entry from `custom_id` and the message parameters it should run.
    pub fn new(custom_id: impl Into<String>, params: MessageCreateParams) -> Self {
        let custom_id: String = custom_id.into();
        Self { custom_id, params }
    }
}
impl Serialize for MessageBatchCreateRequest {
    /// Writes `custom_id` followed by `params` rendered through `MessageBatchRequestParams`,
    /// the batch-specific view of the message parameters.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let batch_params = MessageBatchRequestParams(&self.params);
        let mut fields = serializer.serialize_struct("MessageBatchCreateRequest", 2)?;
        fields.serialize_field("custom_id", &self.custom_id)?;
        fields.serialize_field("params", &batch_params)?;
        fields.end()
    }
}
/// Borrowed view of `MessageCreateParams` that serializes in the batch-request wire format.
struct MessageBatchRequestParams<'a>(&'a MessageCreateParams);

impl Serialize for MessageBatchRequestParams<'_> {
    /// Writes `max_tokens`, `messages` and `model` unconditionally, then each optional
    /// field only when it is `Some`.
    ///
    /// `stream` is intentionally never written; `MessageBatchCreateParams::validate`
    /// rejects requests that set it.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let params = self.0;
        // Field count hint: three required fields plus every optional field that is set.
        // Must stay in sync with the `if let` chain below.
        let optional_present = [
            params.cache_control.is_some(),
            params.metadata.is_some(),
            params.output_format.is_some(),
            params.output_config.is_some(),
            params.stop_sequences.is_some(),
            params.system.is_some(),
            params.temperature.is_some(),
            params.thinking.is_some(),
            params.tool_choice.is_some(),
            params.tools.is_some(),
            params.top_k.is_some(),
            params.top_p.is_some(),
        ]
        .iter()
        .filter(|present| **present)
        .count();
        let mut state =
            serializer.serialize_struct("MessageBatchRequestParams", 3 + optional_present)?;
        state.serialize_field("max_tokens", &params.max_tokens)?;
        state.serialize_field("messages", &params.messages)?;
        state.serialize_field("model", &params.model)?;
        if let Some(cache_control) = &params.cache_control {
            state.serialize_field("cache_control", cache_control)?;
        }
        if let Some(metadata) = &params.metadata {
            state.serialize_field("metadata", metadata)?;
        }
        if let Some(output_format) = &params.output_format {
            state.serialize_field("output_format", output_format)?;
        }
        if let Some(output_config) = &params.output_config {
            state.serialize_field("output_config", output_config)?;
        }
        if let Some(stop_sequences) = &params.stop_sequences {
            state.serialize_field("stop_sequences", stop_sequences)?;
        }
        if let Some(system) = &params.system {
            state.serialize_field("system", system)?;
        }
        if let Some(temperature) = &params.temperature {
            state.serialize_field("temperature", temperature)?;
        }
        if let Some(thinking) = &params.thinking {
            state.serialize_field("thinking", thinking)?;
        }
        if let Some(tool_choice) = &params.tool_choice {
            state.serialize_field("tool_choice", tool_choice)?;
        }
        if let Some(tools) = &params.tools {
            state.serialize_field("tools", tools)?;
        }
        if let Some(top_k) = &params.top_k {
            state.serialize_field("top_k", top_k)?;
        }
        if let Some(top_p) = &params.top_p {
            state.serialize_field("top_p", top_p)?;
        }
        state.end()
    }
}
/// A message batch as returned by the API.
///
/// Timestamps are (de)serialized as RFC 3339 strings. Optional fields default to `None`
/// when absent or `null`, and are omitted on serialization when `None`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageBatch {
    /// Batch identifier.
    pub id: String,
    /// Object type, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub r#type: String,
    /// Current processing state of the batch.
    pub processing_status: MessageBatchProcessingStatus,
    /// Per-outcome request tallies.
    pub request_counts: MessageBatchRequestCounts,
    /// When processing ended, if it has.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(with = "time::serde::rfc3339::option")]
    pub ended_at: Option<OffsetDateTime>,
    /// Creation time.
    #[serde(with = "time::serde::rfc3339")]
    pub created_at: OffsetDateTime,
    /// Expiry time.
    #[serde(with = "time::serde::rfc3339")]
    pub expires_at: OffsetDateTime,
    /// When cancellation was requested, if it was.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(with = "time::serde::rfc3339::option")]
    pub cancel_initiated_at: Option<OffsetDateTime>,
    /// Location of the results, when available.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub results_url: Option<String>,
    /// When the batch was archived, if it was.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(with = "time::serde::rfc3339::option")]
    pub archived_at: Option<OffsetDateTime>,
}
/// Processing state of a message batch, encoded as a snake_case string.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageBatchProcessingStatus {
    /// `"in_progress"`
    #[serde(rename = "in_progress")]
    InProgress,
    /// `"canceling"`
    #[serde(rename = "canceling")]
    Canceling,
    /// `"ended"`
    #[serde(rename = "ended")]
    Ended,
}
/// Number of requests in a batch, bucketed by their reported state.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageBatchRequestCounts {
    /// Requests reported as processing.
    pub processing: u32,
    /// Requests reported as succeeded.
    pub succeeded: u32,
    /// Requests reported as errored.
    pub errored: u32,
    /// Requests reported as canceled.
    pub canceled: u32,
    /// Requests reported as expired.
    pub expired: u32,
}
/// Pagination parameters for listing message batches.
///
/// Unset cursor/limit fields are omitted when serialized. `betas` is excluded from serde
/// in both directions.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageBatchListParams {
    /// Cursor: list entries after this batch id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub after_id: Option<String>,
    /// Cursor: list entries before this batch id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub before_id: Option<String>,
    /// Maximum number of entries requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,
    /// Optional beta feature names; not part of the serialized form.
    #[serde(skip)]
    pub betas: Option<Vec<String>>,
}
impl MessageBatchListParams {
    /// Creates list params with every field unset.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the `after_id` cursor.
    pub fn with_after_id(self, after_id: impl Into<String>) -> Self {
        Self {
            after_id: Some(after_id.into()),
            ..self
        }
    }

    /// Sets the `before_id` cursor.
    pub fn with_before_id(self, before_id: impl Into<String>) -> Self {
        Self {
            before_id: Some(before_id.into()),
            ..self
        }
    }

    /// Sets the page-size limit.
    pub fn with_limit(self, limit: u32) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Replaces any previously configured beta feature names with `betas`.
    pub fn with_betas(self, betas: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let betas: Vec<String> = betas.into_iter().map(Into::into).collect();
        Self {
            betas: Some(betas),
            ..self
        }
    }

    /// Appends a single beta feature name, creating the list if none was set yet.
    pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
        let betas = self.betas.get_or_insert_with(Vec::new);
        betas.push(beta.into());
        self
    }
}
/// One page of message batches.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageBatchListResponse {
    /// Batches on this page.
    pub data: Vec<MessageBatch>,
    /// Whether more pages are available.
    pub has_more: bool,
    /// Id of the first batch on this page, when provided.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub first_id: Option<String>,
    /// Id of the last batch on this page, when provided.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_id: Option<String>,
}
/// Outcome of one batch request, keyed by the `custom_id` supplied at creation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MessageBatchResult {
    /// The `custom_id` of the originating request.
    pub custom_id: String,
    /// What happened to the request.
    pub result: MessageBatchResultVariant,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum MessageBatchResultVariant {
#[serde(rename = "succeeded")]
Succeeded {
message: Message,
},
#[serde(rename = "errored")]
Errored {
error: MessageBatchErrorResponse,
},
#[serde(rename = "canceled")]
Canceled,
#[serde(rename = "expired")]
Expired,
}
/// Error envelope of an errored batch request.
///
/// Always serializes as `{"type": ..., "error": {...}}`. Deserialization is hand-written
/// and also accepts a bare error object.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct MessageBatchErrorResponse {
    /// Envelope type, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub r#type: String,
    /// The wrapped error details.
    pub error: MessageBatchError,
}
impl<'de> Deserialize<'de> for MessageBatchErrorResponse {
    /// Accepts two shapes:
    /// - wrapped: `{"type": ..., "error": {...}}`. A missing `type` defaults to `"error"`.
    /// - bare: `{"type": ..., "message": ..., "param"?: ...}`. This is promoted into an
    ///   envelope of type `"error"`, and the top-level `type` becomes the error kind.
    ///
    /// In the bare shape, a missing `type` is reported before a missing `message`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // The struct name is kept as `Helper` because serde includes it in error messages.
        #[derive(Deserialize)]
        struct Helper {
            #[serde(rename = "type")]
            kind: Option<String>,
            error: Option<MessageBatchError>,
            message: Option<String>,
            param: Option<String>,
        }

        let Helper {
            kind,
            error,
            message,
            param,
        } = Helper::deserialize(deserializer)?;

        match error {
            Some(inner) => Ok(Self {
                r#type: kind.unwrap_or_else(|| String::from("error")),
                error: inner,
            }),
            None => {
                let error_type = match kind {
                    Some(value) => value,
                    None => return Err(D::Error::missing_field("type")),
                };
                let message = match message {
                    Some(value) => value,
                    None => return Err(D::Error::missing_field("message")),
                };
                Ok(Self {
                    r#type: String::from("error"),
                    error: MessageBatchError {
                        r#type: error_type,
                        message,
                        param,
                    },
                })
            }
        }
    }
}
/// Error details for a failed batch request.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageBatchError {
    /// Error kind, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub r#type: String,
    /// Human-readable description.
    pub message: String,
    /// Offending parameter, when one was reported; omitted when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
}
/// Acknowledgement returned after deleting a message batch.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DeletedMessageBatch {
    /// Id of the deleted batch.
    pub id: String,
    /// Object type, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub r#type: String,
}
/// Reports whether `custom_id` is an acceptable batch request identifier.
///
/// It must be 1 to 64 bytes long and consist only of ASCII letters, ASCII digits, `_` and
/// `-`. Any non-ASCII byte is rejected, so the byte length equals the character count.
fn is_valid_custom_id(custom_id: &str) -> bool {
    let allowed =
        |byte: u8| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-');
    (1..=64).contains(&custom_id.len()) && custom_id.bytes().all(allowed)
}
#[cfg(test)]
mod tests {
    //! Unit tests for batch request validation/serialization and batch response parsing.
    use super::*;
    use crate::types::{KnownModel, MessageParam, Model, TextBlock, Usage};
    use serde_json::{json, to_value};
    use time::macros::datetime;

    /// Minimal message params that pass `MessageCreateParams::validate`.
    fn valid_message_params() -> MessageCreateParams {
        MessageCreateParams::new(
            1024,
            vec![MessageParam::user("Hello, world")],
            Model::Known(KnownModel::ClaudeOpus48),
        )
    }

    fn valid_batch_request(custom_id: &str) -> MessageBatchCreateRequest {
        MessageBatchCreateRequest::new(custom_id, valid_message_params())
    }

    #[test]
    fn batch_create_params_serialize_without_stream_or_betas() {
        let params = MessageBatchCreateParams::new(vec![valid_batch_request("my-first-request")])
            .with_beta("output-300k-2026-03-24");
        let json = to_value(&params).unwrap();
        assert_eq!(
            json,
            json!({
                "requests": [{
                    "custom_id": "my-first-request",
                    "params": {
                        "max_tokens": 1024,
                        "messages": [{
                            "role": "user",
                            "content": "Hello, world"
                        }],
                        "model": "claude-opus-4-8"
                    }
                }]
            })
        );
        assert!(json["requests"][0]["params"].get("stream").is_none());
        assert!(json.get("betas").is_none());
    }

    #[test]
    fn batch_create_params_validate_success() {
        let params = MessageBatchCreateParams::new(vec![valid_batch_request("request_1")]);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn batch_create_params_reject_empty_requests() {
        let params = MessageBatchCreateParams::new(Vec::new());
        assert!(params.validate().unwrap_err().is_validation());
    }

    #[test]
    fn batch_create_params_reject_too_many_requests() {
        let params = MessageBatchCreateParams::new(vec![
            valid_batch_request("request_1");
            MAX_MESSAGE_BATCH_REQUESTS + 1
        ]);
        assert!(params.validate().unwrap_err().is_validation());
    }

    #[test]
    fn batch_create_params_reject_invalid_custom_id() {
        let params = MessageBatchCreateParams::new(vec![valid_batch_request("bad id")]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("custom_id"));
    }

    #[test]
    fn custom_id_length_and_charset_boundaries() {
        assert!(is_valid_custom_id(&"a".repeat(64)));
        assert!(!is_valid_custom_id(&"a".repeat(65)));
        assert!(!is_valid_custom_id(""));
        assert!(!is_valid_custom_id("naïve"));
        assert!(is_valid_custom_id("Req_01-abc"));
    }

    #[test]
    fn batch_create_params_reject_duplicate_custom_id() {
        let params = MessageBatchCreateParams::new(vec![
            valid_batch_request("same-id"),
            valid_batch_request("same-id"),
        ]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("Duplicate custom_id"));
    }

    #[test]
    fn batch_create_params_reject_streaming_request() {
        let mut request = valid_batch_request("streaming");
        request.params.stream = true;
        let params = MessageBatchCreateParams::new(vec![request]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("stream"));
    }

    #[test]
    fn batch_create_params_reject_zero_max_tokens() {
        let mut request = valid_batch_request("zero-tokens");
        request.params.max_tokens = 0;
        let params = MessageBatchCreateParams::new(vec![request]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("max_tokens"));
    }

    #[test]
    fn message_batch_deserialization() {
        let json = json!({
            "id": "msgbatch_01HkcTjaV5uDC8jWR4ZsDV8d",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {
                "processing": 2,
                "succeeded": 0,
                "errored": 0,
                "canceled": 0,
                "expired": 0
            },
            "ended_at": null,
            "created_at": "2024-09-24T18:37:24.100435Z",
            "expires_at": "2024-09-25T18:37:24.100435Z",
            "cancel_initiated_at": null,
            "results_url": null,
            "archived_at": null
        });
        let batch: MessageBatch = serde_json::from_value(json).unwrap();
        assert_eq!(batch.id, "msgbatch_01HkcTjaV5uDC8jWR4ZsDV8d");
        assert_eq!(
            batch.processing_status,
            MessageBatchProcessingStatus::InProgress
        );
        assert_eq!(batch.request_counts.processing, 2);
        assert!(batch.ended_at.is_none());
    }

    #[test]
    fn message_batch_list_response_deserialization() {
        let batch = MessageBatch {
            id: "msgbatch_123".to_string(),
            r#type: "message_batch".to_string(),
            processing_status: MessageBatchProcessingStatus::Ended,
            request_counts: MessageBatchRequestCounts {
                processing: 0,
                succeeded: 1,
                errored: 0,
                canceled: 0,
                expired: 0,
            },
            ended_at: Some(datetime!(2024-09-24 19:37:24 UTC)),
            created_at: datetime!(2024-09-24 18:37:24 UTC),
            expires_at: datetime!(2024-09-25 18:37:24 UTC),
            cancel_initiated_at: None,
            results_url: Some("https://api.anthropic.com/result".to_string()),
            archived_at: None,
        };
        let response = MessageBatchListResponse {
            data: vec![batch.clone()],
            has_more: false,
            first_id: Some(batch.id.clone()),
            last_id: Some(batch.id.clone()),
        };
        let json = to_value(&response).unwrap();
        let decoded: MessageBatchListResponse = serde_json::from_value(json).unwrap();
        assert_eq!(decoded.data[0], batch);
        assert!(!decoded.has_more);
    }

    #[test]
    fn batch_result_succeeded_deserialization() {
        let json = json!({
            "custom_id": "my-first-request",
            "result": {
                "type": "succeeded",
                "message": {
                    "id": "msg_123",
                    "type": "message",
                    "role": "assistant",
                    "model": "claude-opus-4-8",
                    "content": [{"type": "text", "text": "Hello"}],
                    "stop_reason": "end_turn",
                    "stop_sequence": null,
                    "usage": {"input_tokens": 10, "output_tokens": 2}
                }
            }
        });
        let result: MessageBatchResult = serde_json::from_value(json).unwrap();
        match result.result {
            MessageBatchResultVariant::Succeeded { message } => {
                assert_eq!(message.id, "msg_123");
            }
            _ => panic!("expected succeeded result"),
        }
    }

    #[test]
    fn batch_result_errored_deserializes_standard_error_shape() {
        let json = json!({
            "custom_id": "bad-request",
            "result": {
                "type": "errored",
                "error": {
                    "type": "error",
                    "error": {
                        "type": "invalid_request_error",
                        "message": "max_tokens must be at least 1"
                    }
                }
            }
        });
        let result: MessageBatchResult = serde_json::from_value(json).unwrap();
        match result.result {
            MessageBatchResultVariant::Errored { error } => {
                assert_eq!(error.r#type, "error");
                assert_eq!(error.error.r#type, "invalid_request_error");
            }
            _ => panic!("expected errored result"),
        }
    }

    #[test]
    fn batch_result_errored_deserializes_direct_error_shape() {
        let json = json!({
            "custom_id": "bad-request",
            "result": {
                "type": "errored",
                "error": {
                    "type": "invalid_request_error",
                    "message": "max_tokens must be at least 1"
                }
            }
        });
        let result: MessageBatchResult = serde_json::from_value(json).unwrap();
        match result.result {
            MessageBatchResultVariant::Errored { error } => {
                assert_eq!(error.r#type, "error");
                assert_eq!(error.error.r#type, "invalid_request_error");
            }
            _ => panic!("expected errored result"),
        }
    }

    #[test]
    fn batch_result_errored_direct_shape_requires_message() {
        let json = json!({
            "custom_id": "bad-request",
            "result": {
                "type": "errored",
                "error": {
                    "type": "invalid_request_error"
                }
            }
        });
        assert!(serde_json::from_value::<MessageBatchResult>(json).is_err());
    }

    #[test]
    fn batch_result_canceled_and_expired_deserialization() {
        let canceled: MessageBatchResult = serde_json::from_value(json!({
            "custom_id": "canceled-request",
            "result": {"type": "canceled"}
        }))
        .unwrap();
        assert!(matches!(
            canceled.result,
            MessageBatchResultVariant::Canceled
        ));
        let expired: MessageBatchResult = serde_json::from_value(json!({
            "custom_id": "expired-request",
            "result": {"type": "expired"}
        }))
        .unwrap();
        assert!(matches!(expired.result, MessageBatchResultVariant::Expired));
    }

    #[test]
    fn deleted_message_batch_deserialization() {
        let deleted: DeletedMessageBatch = serde_json::from_value(json!({
            "id": "msgbatch_123",
            "type": "message_batch_deleted"
        }))
        .unwrap();
        assert_eq!(deleted.id, "msgbatch_123");
        assert_eq!(deleted.r#type, "message_batch_deleted");
    }

    #[test]
    fn message_batch_result_round_trip_succeeded() {
        let message = Message::new(
            "msg_123".to_string(),
            vec![TextBlock::new("Hello").into()],
            Model::Known(KnownModel::ClaudeOpus48),
            Usage::new(1, 1),
        );
        let result = MessageBatchResult {
            custom_id: "request-1".to_string(),
            result: MessageBatchResultVariant::Succeeded { message },
        };
        let json = to_value(&result).unwrap();
        assert_eq!(json["result"]["type"], "succeeded");
        let decoded: MessageBatchResult = serde_json::from_value(json).unwrap();
        assert_eq!(decoded.custom_id, "request-1");
    }
}