1use std::collections::HashSet;
2
3use serde::de::Error as DeError;
4use serde::ser::{SerializeStruct, Serializer};
5use serde::{Deserialize, Serialize};
6use time::OffsetDateTime;
7
8use crate::types::{Message, MessageCreateParams};
9
/// Largest number of requests `MessageBatchCreateParams::validate` accepts in one batch.
const MAX_MESSAGE_BATCH_REQUESTS: usize = 100_000;
/// Largest serialized JSON body, in bytes (256 MiB), that
/// `MessageBatchCreateParams::validate` accepts.
const MAX_MESSAGE_BATCH_BODY_BYTES: usize = 256 * 1024 * 1024;
12
/// Parameters for creating a message batch.
///
/// `Serialize` is implemented by hand for this type and writes only
/// `requests`.
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct MessageBatchCreateParams {
    /// The individual requests that make up the batch.
    pub requests: Vec<MessageBatchCreateRequest>,

    /// Optional beta identifiers.
    ///
    /// Skipped by serde, so it is never read from or written to the JSON
    /// body. NOTE(review): presumably transmitted out-of-band (e.g. as a
    /// request header) by the caller — confirm against the client code.
    #[serde(skip)]
    pub betas: Option<Vec<String>>,
}
26
27impl MessageBatchCreateParams {
28 pub fn new(requests: Vec<MessageBatchCreateRequest>) -> Self {
30 Self {
31 requests,
32 betas: None,
33 }
34 }
35
36 pub fn with_betas(mut self, betas: impl IntoIterator<Item = impl Into<String>>) -> Self {
38 self.betas = Some(betas.into_iter().map(Into::into).collect());
39 self
40 }
41
42 pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
44 self.betas.get_or_insert_with(Vec::new).push(beta.into());
45 self
46 }
47
48 pub fn validate(&self) -> crate::Result<()> {
50 if self.requests.is_empty() {
51 return Err(crate::Error::validation(
52 "At least one batch request is required",
53 Some("requests".to_string()),
54 ));
55 }
56
57 if self.requests.len() > MAX_MESSAGE_BATCH_REQUESTS {
58 return Err(crate::Error::validation(
59 format!(
60 "Batch request count {} exceeds limit of {}",
61 self.requests.len(),
62 MAX_MESSAGE_BATCH_REQUESTS
63 ),
64 Some("requests".to_string()),
65 ));
66 }
67
68 let mut custom_ids = HashSet::with_capacity(self.requests.len());
69 for (i, request) in self.requests.iter().enumerate() {
70 if !is_valid_custom_id(&request.custom_id) {
71 return Err(crate::Error::validation(
72 "custom_id must be 1 to 64 characters and contain only alphanumeric characters, hyphens, and underscores",
73 Some(format!("requests[{i}].custom_id")),
74 ));
75 }
76
77 if !custom_ids.insert(request.custom_id.as_str()) {
78 return Err(crate::Error::validation(
79 format!("Duplicate custom_id: {}", request.custom_id),
80 Some(format!("requests[{i}].custom_id")),
81 ));
82 }
83
84 if request.params.stream {
85 return Err(crate::Error::validation(
86 "stream is not supported in message batch requests",
87 Some(format!("requests[{i}].params.stream")),
88 ));
89 }
90
91 request.params.validate().map_err(|err| match err {
92 crate::Error::Validation { message, param } => crate::Error::validation(
93 message,
94 param.map(|param| format!("requests[{i}].params.{param}")),
95 ),
96 other => other,
97 })?;
98 }
99
100 let body = serde_json::to_vec(self).map_err(|e| {
101 crate::Error::serialization(
102 format!("Failed to serialize message batch create params: {e}"),
103 Some(Box::new(e)),
104 )
105 })?;
106 if body.len() > MAX_MESSAGE_BATCH_BODY_BYTES {
107 return Err(crate::Error::validation(
108 format!(
109 "Serialized batch request size {} exceeds limit of {} bytes",
110 body.len(),
111 MAX_MESSAGE_BATCH_BODY_BYTES
112 ),
113 Some("requests".to_string()),
114 ));
115 }
116
117 Ok(())
118 }
119}
120
121impl Serialize for MessageBatchCreateParams {
122 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
123 where
124 S: Serializer,
125 {
126 let mut state = serializer.serialize_struct("MessageBatchCreateParams", 1)?;
127 state.serialize_field("requests", &self.requests)?;
128 state.end()
129 }
130}
131
/// One entry of a message batch: a caller-chosen id plus the message
/// parameters to run.
///
/// `Serialize` is implemented by hand for this type.
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct MessageBatchCreateRequest {
    /// Caller-supplied identifier for this entry.
    ///
    /// `MessageBatchCreateParams::validate` requires it to be 1 to 64 ASCII
    /// alphanumerics, hyphens, or underscores, and unique within the batch.
    pub custom_id: String,

    /// Message creation parameters for this entry.
    pub params: MessageCreateParams,
}
141
142impl MessageBatchCreateRequest {
143 pub fn new(custom_id: impl Into<String>, params: MessageCreateParams) -> Self {
145 Self {
146 custom_id: custom_id.into(),
147 params,
148 }
149 }
150}
151
152impl Serialize for MessageBatchCreateRequest {
153 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
154 where
155 S: Serializer,
156 {
157 let mut state = serializer.serialize_struct("MessageBatchCreateRequest", 2)?;
158 state.serialize_field("custom_id", &self.custom_id)?;
159 state.serialize_field("params", &MessageBatchRequestParams(&self.params))?;
160 state.end()
161 }
162}
163
164struct MessageBatchRequestParams<'a>(&'a MessageCreateParams);
165
166impl Serialize for MessageBatchRequestParams<'_> {
167 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168 where
169 S: Serializer,
170 {
171 let params = self.0;
172 let mut len = 3;
173 len += usize::from(params.cache_control.is_some());
174 len += usize::from(params.metadata.is_some());
175 len += usize::from(params.output_format.is_some());
176 len += usize::from(params.output_config.is_some());
177 len += usize::from(params.stop_sequences.is_some());
178 len += usize::from(params.system.is_some());
179 len += usize::from(params.temperature.is_some());
180 len += usize::from(params.thinking.is_some());
181 len += usize::from(params.tool_choice.is_some());
182 len += usize::from(params.tools.is_some());
183 len += usize::from(params.top_k.is_some());
184 len += usize::from(params.top_p.is_some());
185
186 let mut state = serializer.serialize_struct("MessageBatchRequestParams", len)?;
187 state.serialize_field("max_tokens", ¶ms.max_tokens)?;
188 state.serialize_field("messages", ¶ms.messages)?;
189 state.serialize_field("model", ¶ms.model)?;
190 if let Some(cache_control) = ¶ms.cache_control {
191 state.serialize_field("cache_control", cache_control)?;
192 }
193 if let Some(metadata) = ¶ms.metadata {
194 state.serialize_field("metadata", metadata)?;
195 }
196 if let Some(output_format) = ¶ms.output_format {
197 state.serialize_field("output_format", output_format)?;
198 }
199 if let Some(output_config) = ¶ms.output_config {
200 state.serialize_field("output_config", output_config)?;
201 }
202 if let Some(stop_sequences) = ¶ms.stop_sequences {
203 state.serialize_field("stop_sequences", stop_sequences)?;
204 }
205 if let Some(system) = ¶ms.system {
206 state.serialize_field("system", system)?;
207 }
208 if let Some(temperature) = ¶ms.temperature {
209 state.serialize_field("temperature", temperature)?;
210 }
211 if let Some(thinking) = ¶ms.thinking {
212 state.serialize_field("thinking", thinking)?;
213 }
214 if let Some(tool_choice) = ¶ms.tool_choice {
215 state.serialize_field("tool_choice", tool_choice)?;
216 }
217 if let Some(tools) = ¶ms.tools {
218 state.serialize_field("tools", tools)?;
219 }
220 if let Some(top_k) = ¶ms.top_k {
221 state.serialize_field("top_k", top_k)?;
222 }
223 if let Some(top_p) = ¶ms.top_p {
224 state.serialize_field("top_p", top_p)?;
225 }
226 state.end()
227 }
228}
229
/// A message batch as returned by the API.
///
/// All timestamps are parsed from and written as RFC 3339 strings.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageBatch {
    /// Identifier of the batch.
    pub id: String,

    /// Value of the JSON `type` field, kept as a free-form string.
    #[serde(rename = "type")]
    pub r#type: String,

    /// Current processing state of the batch.
    pub processing_status: MessageBatchProcessingStatus,

    /// Per-outcome tallies of the requests in the batch.
    pub request_counts: MessageBatchRequestCounts,

    /// Optional timestamp; a missing or `null` field deserializes to `None`,
    /// and `None` is omitted on serialization.
    /// NOTE(review): presumably when processing ended — confirm against the API docs.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "time::serde::rfc3339::option"
    )]
    pub ended_at: Option<OffsetDateTime>,

    /// Required timestamp. NOTE(review): presumably the batch creation time — confirm.
    #[serde(with = "time::serde::rfc3339")]
    pub created_at: OffsetDateTime,

    /// Required timestamp. NOTE(review): presumably when the batch expires — confirm.
    #[serde(with = "time::serde::rfc3339")]
    pub expires_at: OffsetDateTime,

    /// Optional timestamp, handled like `ended_at`.
    /// NOTE(review): presumably when cancellation was requested — confirm.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "time::serde::rfc3339::option"
    )]
    pub cancel_initiated_at: Option<OffsetDateTime>,

    /// Optional URL string; omitted on serialization when `None`.
    /// NOTE(review): presumably where the batch results can be fetched — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub results_url: Option<String>,

    /// Optional timestamp, handled like `ended_at`.
    /// NOTE(review): presumably when the batch was archived — confirm.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "time::serde::rfc3339::option"
    )]
    pub archived_at: Option<OffsetDateTime>,
}
282
/// Processing state of a message batch.
///
/// Variants are (de)serialized in snake_case: `in_progress`, `canceling`, `ended`.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessageBatchProcessingStatus {
    /// Serialized as `in_progress`.
    InProgress,

    /// Serialized as `canceling`.
    Canceling,

    /// Serialized as `ended`.
    Ended,
}
296
/// Tallies of the requests in a batch, one counter per outcome.
///
/// All five fields are required when deserializing.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageBatchRequestCounts {
    /// Count reported under `processing`.
    pub processing: u32,

    /// Count reported under `succeeded`.
    pub succeeded: u32,

    /// Count reported under `errored`.
    pub errored: u32,

    /// Count reported under `canceled`.
    pub canceled: u32,

    /// Count reported under `expired`.
    pub expired: u32,
}
315
/// Parameters for listing message batches.
///
/// Every field is optional; unset fields are left out of the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageBatchListParams {
    /// Pagination cursor. NOTE(review): presumably "return items after this id" — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub after_id: Option<String>,

    /// Pagination cursor. NOTE(review): presumably "return items before this id" — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub before_id: Option<String>,

    /// Requested page size. NOTE(review): allowed range is not visible here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,

    /// Optional beta identifiers; skipped entirely by serde in both directions.
    #[serde(skip)]
    pub betas: Option<Vec<String>>,
}
335
336impl MessageBatchListParams {
337 pub fn new() -> Self {
339 Self::default()
340 }
341
342 pub fn with_after_id(mut self, after_id: impl Into<String>) -> Self {
344 self.after_id = Some(after_id.into());
345 self
346 }
347
348 pub fn with_before_id(mut self, before_id: impl Into<String>) -> Self {
350 self.before_id = Some(before_id.into());
351 self
352 }
353
354 pub fn with_limit(mut self, limit: u32) -> Self {
356 self.limit = Some(limit);
357 self
358 }
359
360 pub fn with_betas(mut self, betas: impl IntoIterator<Item = impl Into<String>>) -> Self {
362 self.betas = Some(betas.into_iter().map(Into::into).collect());
363 self
364 }
365
366 pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
368 self.betas.get_or_insert_with(Vec::new).push(beta.into());
369 self
370 }
371}
372
/// One page of message batches returned by a list call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageBatchListResponse {
    /// The batches on this page.
    pub data: Vec<MessageBatch>,

    /// Value of the `has_more` flag.
    /// NOTE(review): presumably whether further pages exist — confirm.
    pub has_more: bool,

    /// Optional id; missing deserializes to `None`, and `None` is omitted on serialization.
    /// NOTE(review): presumably the id of the first item in `data` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub first_id: Option<String>,

    /// Optional id, handled like `first_id`.
    /// NOTE(review): presumably the id of the last item in `data` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_id: Option<String>,
}
390
/// Outcome of a single batch entry, keyed by the `custom_id` it was submitted with.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MessageBatchResult {
    /// The `custom_id` of the entry this result belongs to.
    pub custom_id: String,

    /// What happened to the entry.
    pub result: MessageBatchResultVariant,
}
400
401#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
403#[serde(tag = "type")]
404pub enum MessageBatchResultVariant {
405 #[serde(rename = "succeeded")]
407 Succeeded {
408 message: Message,
410 },
411
412 #[serde(rename = "errored")]
414 Errored {
415 error: MessageBatchErrorResponse,
417 },
418
419 #[serde(rename = "canceled")]
421 Canceled,
422
423 #[serde(rename = "expired")]
425 Expired,
426}
427
/// Error payload attached to an errored batch result.
///
/// Serialized as `{ "type": ..., "error": { ... } }`. `Deserialize` is
/// implemented by hand for this type.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct MessageBatchErrorResponse {
    /// Value of the outer `type` field.
    #[serde(rename = "type")]
    pub r#type: String,

    /// The error details.
    pub error: MessageBatchError,
}
438
439impl<'de> Deserialize<'de> for MessageBatchErrorResponse {
440 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
441 where
442 D: serde::Deserializer<'de>,
443 {
444 #[derive(Deserialize)]
445 struct Helper {
446 #[serde(rename = "type")]
447 r#type: Option<String>,
448 error: Option<MessageBatchError>,
449 message: Option<String>,
450 param: Option<String>,
451 }
452
453 let helper = Helper::deserialize(deserializer)?;
454 if let Some(error) = helper.error {
455 return Ok(Self {
456 r#type: helper.r#type.unwrap_or_else(|| "error".to_string()),
457 error,
458 });
459 }
460
461 let error_type = helper
462 .r#type
463 .ok_or_else(|| D::Error::missing_field("type"))?;
464 let message = helper
465 .message
466 .ok_or_else(|| D::Error::missing_field("message"))?;
467 Ok(Self {
468 r#type: "error".to_string(),
469 error: MessageBatchError {
470 r#type: error_type,
471 message,
472 param: helper.param,
473 },
474 })
475 }
476}
477
/// Details of an error reported for a batch entry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageBatchError {
    /// Value of the JSON `type` field (for example `invalid_request_error`).
    #[serde(rename = "type")]
    pub r#type: String,

    /// Error message text.
    pub message: String,

    /// Optional `param` value; missing deserializes to `None`, and `None` is
    /// omitted on serialization.
    /// NOTE(review): presumably names the offending parameter — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
}
492
/// Response describing a deleted message batch.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct DeletedMessageBatch {
    /// Identifier of the deleted batch.
    pub id: String,

    /// Value of the JSON `type` field, kept as a free-form string.
    #[serde(rename = "type")]
    pub r#type: String,
}
503
/// Returns `true` when `custom_id` is 1 to 64 bytes long and made up solely of
/// ASCII letters, ASCII digits, `_`, and `-`.
fn is_valid_custom_id(custom_id: &str) -> bool {
    // Length is measured in bytes; any non-ASCII byte fails the character
    // check below anyway.
    if !(1..=64).contains(&custom_id.len()) {
        return false;
    }

    custom_id
        .bytes()
        .all(|b| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-'))
}
511
/// Unit tests for batch parameter validation and (de)serialization of the
/// batch request and response types.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KnownModel, MessageParam, Model, TextBlock, Usage};
    use serde_json::{json, to_value};
    use time::macros::datetime;

    /// Minimal message parameters used as the baseline in the tests below.
    fn valid_message_params() -> MessageCreateParams {
        MessageCreateParams::new(
            1024,
            vec![MessageParam::user("Hello, world")],
            Model::Known(KnownModel::ClaudeOpus48),
        )
    }

    /// Batch entry wrapping `valid_message_params` under the given id.
    fn valid_batch_request(custom_id: &str) -> MessageBatchCreateRequest {
        MessageBatchCreateRequest::new(custom_id, valid_message_params())
    }

    /// The serialized body carries only `requests`, with neither `stream` nor `betas`.
    #[test]
    fn batch_create_params_serialize_without_stream_or_betas() {
        let params = MessageBatchCreateParams::new(vec![valid_batch_request("my-first-request")])
            .with_beta("output-300k-2026-03-24");

        let json = to_value(&params).unwrap();
        assert_eq!(
            json,
            json!({
                "requests": [{
                    "custom_id": "my-first-request",
                    "params": {
                        "max_tokens": 1024,
                        "messages": [{
                            "role": "user",
                            "content": "Hello, world"
                        }],
                        "model": "claude-opus-4-8"
                    }
                }]
            })
        );
        assert!(json["requests"][0]["params"].get("stream").is_none());
        assert!(json.get("betas").is_none());
    }

    /// A single well-formed request passes validation.
    #[test]
    fn batch_create_params_validate_success() {
        let params = MessageBatchCreateParams::new(vec![valid_batch_request("request_1")]);
        assert!(params.validate().is_ok());
    }

    /// An empty batch is rejected.
    #[test]
    fn batch_create_params_reject_empty_requests() {
        let params = MessageBatchCreateParams::new(Vec::new());
        assert!(params.validate().unwrap_err().is_validation());
    }

    /// One request over the count limit is rejected.
    #[test]
    fn batch_create_params_reject_too_many_requests() {
        let params = MessageBatchCreateParams::new(vec![
            valid_batch_request("request_1");
            MAX_MESSAGE_BATCH_REQUESTS + 1
        ]);
        assert!(params.validate().unwrap_err().is_validation());
    }

    /// A `custom_id` containing a space is rejected.
    #[test]
    fn batch_create_params_reject_invalid_custom_id() {
        let params = MessageBatchCreateParams::new(vec![valid_batch_request("bad id")]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("custom_id"));
    }

    /// Two entries sharing a `custom_id` are rejected.
    #[test]
    fn batch_create_params_reject_duplicate_custom_id() {
        let params = MessageBatchCreateParams::new(vec![
            valid_batch_request("same-id"),
            valid_batch_request("same-id"),
        ]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("Duplicate custom_id"));
    }

    /// An entry with `stream` enabled is rejected.
    #[test]
    fn batch_create_params_reject_streaming_request() {
        let mut request = valid_batch_request("streaming");
        request.params.stream = true;
        let params = MessageBatchCreateParams::new(vec![request]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("stream"));
    }

    /// Nested parameter validation errors are surfaced by batch validation.
    #[test]
    fn batch_create_params_reject_zero_max_tokens() {
        let mut request = valid_batch_request("zero-tokens");
        request.params.max_tokens = 0;
        let params = MessageBatchCreateParams::new(vec![request]);
        let err = params.validate().unwrap_err();
        assert!(err.is_validation());
        assert!(err.to_string().contains("max_tokens"));
    }

    /// A batch object with `null` optional fields deserializes, leaving them `None`.
    #[test]
    fn message_batch_deserialization() {
        let json = json!({
            "id": "msgbatch_01HkcTjaV5uDC8jWR4ZsDV8d",
            "type": "message_batch",
            "processing_status": "in_progress",
            "request_counts": {
                "processing": 2,
                "succeeded": 0,
                "errored": 0,
                "canceled": 0,
                "expired": 0
            },
            "ended_at": null,
            "created_at": "2024-09-24T18:37:24.100435Z",
            "expires_at": "2024-09-25T18:37:24.100435Z",
            "cancel_initiated_at": null,
            "results_url": null,
            "archived_at": null
        });

        let batch: MessageBatch = serde_json::from_value(json).unwrap();
        assert_eq!(batch.id, "msgbatch_01HkcTjaV5uDC8jWR4ZsDV8d");
        assert_eq!(
            batch.processing_status,
            MessageBatchProcessingStatus::InProgress
        );
        assert_eq!(batch.request_counts.processing, 2);
        assert!(batch.ended_at.is_none());
    }

    /// A list response survives a serialize/deserialize round trip.
    #[test]
    fn message_batch_list_response_deserialization() {
        let batch = MessageBatch {
            id: "msgbatch_123".to_string(),
            r#type: "message_batch".to_string(),
            processing_status: MessageBatchProcessingStatus::Ended,
            request_counts: MessageBatchRequestCounts {
                processing: 0,
                succeeded: 1,
                errored: 0,
                canceled: 0,
                expired: 0,
            },
            ended_at: Some(datetime!(2024-09-24 19:37:24 UTC)),
            created_at: datetime!(2024-09-24 18:37:24 UTC),
            expires_at: datetime!(2024-09-25 18:37:24 UTC),
            cancel_initiated_at: None,
            results_url: Some("https://api.anthropic.com/result".to_string()),
            archived_at: None,
        };
        let response = MessageBatchListResponse {
            data: vec![batch.clone()],
            has_more: false,
            first_id: Some(batch.id.clone()),
            last_id: Some(batch.id.clone()),
        };

        let json = to_value(&response).unwrap();
        let decoded: MessageBatchListResponse = serde_json::from_value(json).unwrap();
        assert_eq!(decoded.data[0], batch);
        assert!(!decoded.has_more);
    }

    /// A `succeeded` result deserializes with its message.
    #[test]
    fn batch_result_succeeded_deserialization() {
        let json = json!({
            "custom_id": "my-first-request",
            "result": {
                "type": "succeeded",
                "message": {
                    "id": "msg_123",
                    "type": "message",
                    "role": "assistant",
                    "model": "claude-opus-4-8",
                    "content": [{"type": "text", "text": "Hello"}],
                    "stop_reason": "end_turn",
                    "stop_sequence": null,
                    "usage": {"input_tokens": 10, "output_tokens": 2}
                }
            }
        });

        let result: MessageBatchResult = serde_json::from_value(json).unwrap();
        match result.result {
            MessageBatchResultVariant::Succeeded { message } => {
                assert_eq!(message.id, "msg_123");
            }
            _ => panic!("expected succeeded result"),
        }
    }

    /// An `errored` result in the wrapped `{type, error: {...}}` shape deserializes.
    #[test]
    fn batch_result_errored_deserializes_standard_error_shape() {
        let json = json!({
            "custom_id": "bad-request",
            "result": {
                "type": "errored",
                "error": {
                    "type": "error",
                    "error": {
                        "type": "invalid_request_error",
                        "message": "max_tokens must be at least 1"
                    }
                }
            }
        });

        let result: MessageBatchResult = serde_json::from_value(json).unwrap();
        match result.result {
            MessageBatchResultVariant::Errored { error } => {
                assert_eq!(error.r#type, "error");
                assert_eq!(error.error.r#type, "invalid_request_error");
            }
            _ => panic!("expected errored result"),
        }
    }

    /// An `errored` result in the direct `{type, message}` shape is wrapped
    /// under an outer `"error"` type.
    #[test]
    fn batch_result_errored_deserializes_direct_error_shape() {
        let json = json!({
            "custom_id": "bad-request",
            "result": {
                "type": "errored",
                "error": {
                    "type": "invalid_request_error",
                    "message": "max_tokens must be at least 1"
                }
            }
        });

        let result: MessageBatchResult = serde_json::from_value(json).unwrap();
        match result.result {
            MessageBatchResultVariant::Errored { error } => {
                assert_eq!(error.r#type, "error");
                assert_eq!(error.error.r#type, "invalid_request_error");
            }
            _ => panic!("expected errored result"),
        }
    }

    /// The payload-free `canceled` and `expired` results deserialize.
    #[test]
    fn batch_result_canceled_and_expired_deserialization() {
        let canceled: MessageBatchResult = serde_json::from_value(json!({
            "custom_id": "canceled-request",
            "result": {"type": "canceled"}
        }))
        .unwrap();
        assert!(matches!(
            canceled.result,
            MessageBatchResultVariant::Canceled
        ));

        let expired: MessageBatchResult = serde_json::from_value(json!({
            "custom_id": "expired-request",
            "result": {"type": "expired"}
        }))
        .unwrap();
        assert!(matches!(expired.result, MessageBatchResultVariant::Expired));
    }

    /// A deletion response deserializes with its id and type.
    #[test]
    fn deleted_message_batch_deserialization() {
        let deleted: DeletedMessageBatch = serde_json::from_value(json!({
            "id": "msgbatch_123",
            "type": "message_batch_deleted"
        }))
        .unwrap();
        assert_eq!(deleted.id, "msgbatch_123");
        assert_eq!(deleted.r#type, "message_batch_deleted");
    }

    /// A `succeeded` result survives a serialize/deserialize round trip.
    #[test]
    fn message_batch_result_round_trip_succeeded() {
        let message = Message::new(
            "msg_123".to_string(),
            vec![TextBlock::new("Hello").into()],
            Model::Known(KnownModel::ClaudeOpus48),
            Usage::new(1, 1),
        );
        let result = MessageBatchResult {
            custom_id: "request-1".to_string(),
            result: MessageBatchResultVariant::Succeeded { message },
        };

        let json = to_value(&result).unwrap();
        assert_eq!(json["result"]["type"], "succeeded");
        let decoded: MessageBatchResult = serde_json::from_value(json).unwrap();
        assert_eq!(decoded.custom_id, "request-1");
    }
}