1use crate::types::{Message, MessageCreateParams};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use chrono::{DateTime, Utc};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MessageBatch {
9 pub id: String,
11
12 #[serde(rename = "type")]
14 pub object_type: String,
15
16 pub processing_status: BatchStatus,
18
19 pub request_counts: BatchRequestCounts,
21
22 pub created_at: DateTime<Utc>,
24
25 pub expires_at: DateTime<Utc>,
27
28 pub ended_at: Option<DateTime<Utc>>,
30
31 pub input_file_id: String,
33
34 pub output_file_id: Option<String>,
36
37 pub error_file_id: Option<String>,
39
40 #[serde(default)]
42 pub metadata: HashMap<String, String>,
43}
44
45#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
47#[serde(rename_all = "snake_case")]
48pub enum BatchStatus {
49 Validating,
51
52 InProgress,
54
55 Finalizing,
57
58 Completed,
60
61 Expired,
63
64 Cancelling,
66
67 Cancelled,
69
70 Failed,
72}
73
74impl BatchStatus {
75 pub fn is_terminal(&self) -> bool {
77 matches!(
78 self,
79 BatchStatus::Completed
80 | BatchStatus::Expired
81 | BatchStatus::Cancelled
82 | BatchStatus::Failed
83 )
84 }
85
86 pub fn is_processing(&self) -> bool {
88 matches!(
89 self,
90 BatchStatus::Validating
91 | BatchStatus::InProgress
92 | BatchStatus::Finalizing
93 )
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct BatchRequestCounts {
100 pub total: u32,
102
103 pub completed: u32,
105
106 pub failed: u32,
108}
109
110impl BatchRequestCounts {
111 pub fn pending(&self) -> u32 {
113 self.total.saturating_sub(self.completed + self.failed)
114 }
115
116 pub fn completion_percentage(&self) -> f64 {
118 if self.total == 0 {
119 0.0
120 } else {
121 (self.completed as f64 / self.total as f64) * 100.0
122 }
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct BatchRequest {
129 pub custom_id: String,
131
132 pub method: String,
134
135 pub url: String,
137
138 pub body: MessageCreateParams,
140}
141
142impl BatchRequest {
143 pub fn new(custom_id: impl Into<String>, model: impl Into<String>, max_tokens: u32) -> BatchRequestBuilder {
145 BatchRequestBuilder {
146 custom_id: custom_id.into(),
147 method: "POST".to_string(),
148 url: "/v1/messages".to_string(),
149 body: MessageCreateParams {
150 model: model.into(),
151 max_tokens,
152 messages: Vec::new(),
153 system: None,
154 temperature: None,
155 top_p: None,
156 top_k: None,
157 stop_sequences: None,
158 stream: Some(false), tools: None,
160 tool_choice: None,
161 metadata: None,
162 },
163 }
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct BatchRequestBuilder {
170 custom_id: String,
171 method: String,
172 url: String,
173 body: MessageCreateParams,
174}
175
176impl BatchRequestBuilder {
177 pub fn user(mut self, content: impl Into<String>) -> Self {
179 use crate::types::{MessageParam, Role, MessageContent};
180
181 self.body.messages.push(MessageParam {
182 role: Role::User,
183 content: MessageContent::Text(content.into()),
184 });
185 self
186 }
187
188 pub fn assistant(mut self, content: impl Into<String>) -> Self {
190 use crate::types::{MessageParam, Role, MessageContent};
191
192 self.body.messages.push(MessageParam {
193 role: Role::Assistant,
194 content: MessageContent::Text(content.into()),
195 });
196 self
197 }
198
199 pub fn system(mut self, system: impl Into<String>) -> Self {
201 self.body.system = Some(system.into());
202 self
203 }
204
205 pub fn temperature(mut self, temperature: f32) -> Self {
207 self.body.temperature = Some(temperature);
208 self
209 }
210
211 pub fn top_p(mut self, top_p: f32) -> Self {
213 self.body.top_p = Some(top_p);
214 self
215 }
216
217 pub fn top_k(mut self, top_k: u32) -> Self {
219 self.body.top_k = Some(top_k);
220 self
221 }
222
223 pub fn stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
225 self.body.stop_sequences = Some(stop_sequences);
226 self
227 }
228
229 pub fn tools(mut self, tools: Vec<crate::types::Tool>) -> Self {
231 self.body.tools = Some(tools);
232 self
233 }
234
235 pub fn tool_choice(mut self, tool_choice: crate::types::ToolChoice) -> Self {
237 self.body.tool_choice = Some(tool_choice);
238 self
239 }
240
241 pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
243 self.body.metadata = Some(metadata);
244 self
245 }
246
247 pub fn build(self) -> BatchRequest {
249 BatchRequest {
250 custom_id: self.custom_id,
251 method: self.method,
252 url: self.url,
253 body: self.body,
254 }
255 }
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct BatchResult {
261 pub custom_id: String,
263
264 pub response: BatchResponse,
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct BatchResponse {
271 pub status_code: u16,
273
274 #[serde(default)]
276 pub headers: HashMap<String, String>,
277
278 pub body: BatchResponseBody,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284#[serde(untagged)]
285pub enum BatchResponseBody {
286 Success(Message),
288
289 Error(BatchError),
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct BatchError {
296 #[serde(rename = "type")]
298 pub error_type: String,
299
300 pub message: String,
302
303 #[serde(default)]
305 pub details: HashMap<String, serde_json::Value>,
306}
307
308#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct BatchCreateParams {
311 pub requests: Vec<BatchRequest>,
313
314 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
316 pub metadata: HashMap<String, String>,
317
318 #[serde(skip_serializing_if = "Option::is_none")]
320 pub completion_window: Option<u32>,
321}
322
323impl BatchCreateParams {
324 pub fn new(requests: Vec<BatchRequest>) -> Self {
326 Self {
327 requests,
328 metadata: HashMap::new(),
329 completion_window: None,
330 }
331 }
332
333 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
335 self.metadata = metadata;
336 self
337 }
338
339 pub fn with_completion_window(mut self, hours: u32) -> Self {
341 self.completion_window = Some(hours);
342 self
343 }
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize, Default)]
348pub struct BatchListParams {
349 #[serde(skip_serializing_if = "Option::is_none")]
351 pub after: Option<String>,
352
353 #[serde(skip_serializing_if = "Option::is_none")]
355 pub limit: Option<u32>,
356}
357
358impl BatchListParams {
359 pub fn new() -> Self {
361 Self::default()
362 }
363
364 pub fn after(mut self, after: impl Into<String>) -> Self {
366 self.after = Some(after.into());
367 self
368 }
369
370 pub fn limit(mut self, limit: u32) -> Self {
372 self.limit = Some(limit.clamp(1, 100));
373 self
374 }
375}
376
377#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct BatchList {
380 pub data: Vec<MessageBatch>,
382
383 pub has_more: bool,
385
386 pub first_id: Option<String>,
388
389 pub last_id: Option<String>,
391}
392
393impl MessageBatch {
394 pub fn is_complete(&self) -> bool {
396 self.processing_status == BatchStatus::Completed
397 }
398
399 pub fn has_failed(&self) -> bool {
401 matches!(
402 self.processing_status,
403 BatchStatus::Failed | BatchStatus::Expired
404 )
405 }
406
407 pub fn can_cancel(&self) -> bool {
409 self.processing_status.is_processing()
410 }
411
412 pub fn completion_percentage(&self) -> f64 {
414 self.request_counts.completion_percentage()
415 }
416
417 pub fn pending_requests(&self) -> u32 {
419 self.request_counts.pending()
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Statuses for which `is_terminal()` must return `true`.
    const TERMINAL: [BatchStatus; 4] = [
        BatchStatus::Completed,
        BatchStatus::Expired,
        BatchStatus::Cancelled,
        BatchStatus::Failed,
    ];

    /// Statuses for which `is_terminal()` must return `false`.
    const NON_TERMINAL: [BatchStatus; 4] = [
        BatchStatus::Validating,
        BatchStatus::InProgress,
        BatchStatus::Finalizing,
        BatchStatus::Cancelling,
    ];

    #[test]
    fn test_batch_status_terminal() {
        for status in TERMINAL {
            assert!(status.is_terminal(), "{:?} should be terminal", status);
            assert!(!status.is_processing(), "{:?} should not be processing", status);
        }
        for status in NON_TERMINAL {
            assert!(!status.is_terminal(), "{:?} should not be terminal", status);
        }
    }

    #[test]
    fn test_batch_status_processing() {
        for status in [
            BatchStatus::Validating,
            BatchStatus::InProgress,
            BatchStatus::Finalizing,
        ] {
            assert!(status.is_processing(), "{:?} should be processing", status);
        }
        // Under the current definitions, `Cancelling` is neither processing nor terminal.
        assert!(!BatchStatus::Cancelling.is_processing());
        assert!(!BatchStatus::Cancelling.is_terminal());
    }

    #[test]
    fn test_batch_request_builder() {
        let request = BatchRequest::new("test1", "claude-3-5-sonnet-latest", 1024)
            .user("Hello, world!")
            .system("You are a helpful assistant")
            .temperature(0.7)
            .build();

        assert_eq!(request.custom_id, "test1");
        assert_eq!(request.method, "POST");
        assert_eq!(request.url, "/v1/messages");
        assert_eq!(request.body.model, "claude-3-5-sonnet-latest");
        assert_eq!(request.body.max_tokens, 1024);
        assert_eq!(request.body.messages.len(), 1);
        assert_eq!(request.body.system, Some("You are a helpful assistant".to_string()));
        assert_eq!(request.body.temperature, Some(0.7));
        assert_eq!(request.body.stream, Some(false));
        assert!(request.body.tools.is_none());
    }

    #[test]
    fn test_request_counts() {
        let counts = BatchRequestCounts {
            total: 100,
            completed: 75,
            failed: 10,
        };

        assert_eq!(counts.pending(), 15);
        assert_eq!(counts.completion_percentage(), 75.0);
    }

    #[test]
    fn test_request_counts_empty_batch() {
        let counts = BatchRequestCounts {
            total: 0,
            completed: 0,
            failed: 0,
        };

        assert_eq!(counts.pending(), 0);
        // Must not divide by zero.
        assert_eq!(counts.completion_percentage(), 0.0);
    }

    #[test]
    fn test_request_counts_pending_saturates() {
        // More finished requests than `total` must clamp to zero rather than wrap.
        let counts = BatchRequestCounts {
            total: 10,
            completed: 8,
            failed: 5,
        };

        assert_eq!(counts.pending(), 0);
    }

    #[test]
    fn test_batch_create_params() {
        let requests = vec![BatchRequest::new("req1", "claude-3-5-sonnet-latest", 1024)
            .user("Hello")
            .build()];

        let params = BatchCreateParams::new(requests).with_completion_window(12);

        assert_eq!(params.requests.len(), 1);
        assert!(params.metadata.is_empty());
        assert_eq!(params.completion_window, Some(12));
    }
}
480}