Skip to main content

agentik_sdk/types/
batches.rs

1use crate::types::{Message, MessageCreateParams};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use chrono::{DateTime, Utc};
5
6/// A batch request for processing multiple messages efficiently
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MessageBatch {
9    /// Unique identifier for the batch
10    pub id: String,
11    
12    /// The type of object (always "message_batch")
13    #[serde(rename = "type")]
14    pub object_type: String,
15    
16    /// Current processing status of the batch
17    pub processing_status: BatchStatus,
18    
19    /// Total number of requests in the batch
20    pub request_counts: BatchRequestCounts,
21    
22    /// When the batch was created
23    pub created_at: DateTime<Utc>,
24    
25    /// When the batch processing will expire
26    pub expires_at: DateTime<Utc>,
27    
28    /// When the batch processing was completed (if applicable)
29    pub ended_at: Option<DateTime<Utc>>,
30    
31    /// File ID containing the batch requests
32    pub input_file_id: String,
33    
34    /// File ID containing the batch results (if completed)
35    pub output_file_id: Option<String>,
36    
37    /// File ID containing any errors (if applicable)
38    pub error_file_id: Option<String>,
39    
40    /// Custom metadata for the batch
41    #[serde(default)]
42    pub metadata: HashMap<String, String>,
43}
44
45/// Status of batch processing
46#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
47#[serde(rename_all = "snake_case")]
48pub enum BatchStatus {
49    /// Batch is validating inputs
50    Validating,
51    
52    /// Batch is in the processing queue
53    InProgress,
54    
55    /// Batch is being processed
56    Finalizing,
57    
58    /// Batch processing completed successfully
59    Completed,
60    
61    /// Batch processing expired before completion
62    Expired,
63    
64    /// Batch processing was cancelled
65    Cancelling,
66    
67    /// Batch processing was cancelled
68    Cancelled,
69    
70    /// Batch processing failed
71    Failed,
72}
73
74impl BatchStatus {
75    /// Check if the batch is in a terminal state
76    pub fn is_terminal(&self) -> bool {
77        matches!(
78            self,
79            BatchStatus::Completed 
80            | BatchStatus::Expired 
81            | BatchStatus::Cancelled 
82            | BatchStatus::Failed
83        )
84    }
85    
86    /// Check if the batch is still being processed
87    pub fn is_processing(&self) -> bool {
88        matches!(
89            self,
90            BatchStatus::Validating 
91            | BatchStatus::InProgress 
92            | BatchStatus::Finalizing
93        )
94    }
95}
96
97/// Count of requests in different states within a batch
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct BatchRequestCounts {
100    /// Total number of requests in the batch
101    pub total: u32,
102    
103    /// Number of requests completed successfully
104    pub completed: u32,
105    
106    /// Number of requests that failed
107    pub failed: u32,
108}
109
110impl BatchRequestCounts {
111    /// Calculate the number of pending requests
112    pub fn pending(&self) -> u32 {
113        self.total.saturating_sub(self.completed + self.failed)
114    }
115    
116    /// Calculate completion percentage
117    pub fn completion_percentage(&self) -> f64 {
118        if self.total == 0 {
119            0.0
120        } else {
121            (self.completed as f64 / self.total as f64) * 100.0
122        }
123    }
124}
125
126/// Individual request within a batch
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct BatchRequest {
129    /// Custom ID for this request (for result matching)
130    pub custom_id: String,
131    
132    /// HTTP method (always "POST" for messages)
133    pub method: String,
134    
135    /// API endpoint URL
136    pub url: String,
137    
138    /// Request body containing message parameters
139    pub body: MessageCreateParams,
140}
141
142impl BatchRequest {
143    /// Create a new batch request
144    pub fn new(custom_id: impl Into<String>, model: impl Into<String>, max_tokens: u32) -> BatchRequestBuilder {
145        BatchRequestBuilder {
146            custom_id: custom_id.into(),
147            method: "POST".to_string(),
148            url: "/v1/messages".to_string(),
149            body: MessageCreateParams {
150                model: model.into(),
151                max_tokens,
152                messages: Vec::new(),
153                system: None,
154                temperature: None,
155                top_p: None,
156                top_k: None,
157                stop_sequences: None,
158                stream: Some(false), // Batches don't support streaming
159                tools: None,
160                tool_choice: None,
161                metadata: None,
162            },
163        }
164    }
165}
166
167/// Builder for creating batch requests
168#[derive(Debug, Clone)]
169pub struct BatchRequestBuilder {
170    custom_id: String,
171    method: String,
172    url: String,
173    body: MessageCreateParams,
174}
175
176impl BatchRequestBuilder {
177    /// Add a user message to the request
178    pub fn user(mut self, content: impl Into<String>) -> Self {
179        use crate::types::{MessageParam, Role, MessageContent};
180        
181        self.body.messages.push(MessageParam {
182            role: Role::User,
183            content: MessageContent::Text(content.into()),
184        });
185        self
186    }
187    
188    /// Add an assistant message to the request
189    pub fn assistant(mut self, content: impl Into<String>) -> Self {
190        use crate::types::{MessageParam, Role, MessageContent};
191        
192        self.body.messages.push(MessageParam {
193            role: Role::Assistant,
194            content: MessageContent::Text(content.into()),
195        });
196        self
197    }
198    
199    /// Set the system prompt for the request
200    pub fn system(mut self, system: impl Into<String>) -> Self {
201        self.body.system = Some(system.into());
202        self
203    }
204    
205    /// Set the temperature for the request
206    pub fn temperature(mut self, temperature: f32) -> Self {
207        self.body.temperature = Some(temperature);
208        self
209    }
210    
211    /// Set the top_p for the request
212    pub fn top_p(mut self, top_p: f32) -> Self {
213        self.body.top_p = Some(top_p);
214        self
215    }
216    
217    /// Set the top_k for the request
218    pub fn top_k(mut self, top_k: u32) -> Self {
219        self.body.top_k = Some(top_k);
220        self
221    }
222    
223    /// Add stop sequences for the request
224    pub fn stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
225        self.body.stop_sequences = Some(stop_sequences);
226        self
227    }
228    
229    /// Add tools to the request
230    pub fn tools(mut self, tools: Vec<crate::types::Tool>) -> Self {
231        self.body.tools = Some(tools);
232        self
233    }
234    
235    /// Set tool choice for the request
236    pub fn tool_choice(mut self, tool_choice: crate::types::ToolChoice) -> Self {
237        self.body.tool_choice = Some(tool_choice);
238        self
239    }
240    
241    /// Add metadata to the request
242    pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
243        self.body.metadata = Some(metadata);
244        self
245    }
246    
247    /// Build the batch request
248    pub fn build(self) -> BatchRequest {
249        BatchRequest {
250            custom_id: self.custom_id,
251            method: self.method,
252            url: self.url,
253            body: self.body,
254        }
255    }
256}
257
258/// Result of a single request within a batch
259#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct BatchResult {
261    /// Custom ID from the original request
262    pub custom_id: String,
263    
264    /// HTTP response for this request
265    pub response: BatchResponse,
266}
267
268/// HTTP response for a batch request
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct BatchResponse {
271    /// HTTP status code
272    pub status_code: u16,
273    
274    /// Response headers
275    #[serde(default)]
276    pub headers: HashMap<String, String>,
277    
278    /// Response body (success or error)
279    pub body: BatchResponseBody,
280}
281
282/// Response body for a batch request
283#[derive(Debug, Clone, Serialize, Deserialize)]
284#[serde(untagged)]
285pub enum BatchResponseBody {
286    /// Successful message response
287    Success(Message),
288    
289    /// Error response
290    Error(BatchError),
291}
292
293/// Error response for a failed batch request
294#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct BatchError {
296    /// Error type
297    #[serde(rename = "type")]
298    pub error_type: String,
299    
300    /// Error message
301    pub message: String,
302    
303    /// Additional error details
304    #[serde(default)]
305    pub details: HashMap<String, serde_json::Value>,
306}
307
308/// Parameters for creating a new batch
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct BatchCreateParams {
311    /// Array of individual requests
312    pub requests: Vec<BatchRequest>,
313    
314    /// Custom metadata for the batch
315    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
316    pub metadata: HashMap<String, String>,
317    
318    /// Completion window for the batch (in hours, default 24)
319    #[serde(skip_serializing_if = "Option::is_none")]
320    pub completion_window: Option<u32>,
321}
322
323impl BatchCreateParams {
324    /// Create new batch parameters
325    pub fn new(requests: Vec<BatchRequest>) -> Self {
326        Self {
327            requests,
328            metadata: HashMap::new(),
329            completion_window: None,
330        }
331    }
332    
333    /// Add metadata to the batch
334    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
335        self.metadata = metadata;
336        self
337    }
338    
339    /// Set completion window in hours
340    pub fn with_completion_window(mut self, hours: u32) -> Self {
341        self.completion_window = Some(hours);
342        self
343    }
344}
345
346/// Parameters for listing batches
347#[derive(Debug, Clone, Serialize, Deserialize, Default)]
348pub struct BatchListParams {
349    /// A cursor for use in pagination
350    #[serde(skip_serializing_if = "Option::is_none")]
351    pub after: Option<String>,
352    
353    /// Number of items to return (1-100, default 20)
354    #[serde(skip_serializing_if = "Option::is_none")]
355    pub limit: Option<u32>,
356}
357
358impl BatchListParams {
359    /// Create new list parameters
360    pub fn new() -> Self {
361        Self::default()
362    }
363    
364    /// Set pagination cursor
365    pub fn after(mut self, after: impl Into<String>) -> Self {
366        self.after = Some(after.into());
367        self
368    }
369    
370    /// Set result limit
371    pub fn limit(mut self, limit: u32) -> Self {
372        self.limit = Some(limit.clamp(1, 100));
373        self
374    }
375}
376
377/// Response containing a list of batches
378#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct BatchList {
380    /// List of batch objects
381    pub data: Vec<MessageBatch>,
382    
383    /// Whether there are more items available
384    pub has_more: bool,
385    
386    /// First ID in the current page
387    pub first_id: Option<String>,
388    
389    /// Last ID in the current page
390    pub last_id: Option<String>,
391}
392
393impl MessageBatch {
394    /// Check if the batch is complete
395    pub fn is_complete(&self) -> bool {
396        self.processing_status == BatchStatus::Completed
397    }
398    
399    /// Check if the batch has failed
400    pub fn has_failed(&self) -> bool {
401        matches!(
402            self.processing_status,
403            BatchStatus::Failed | BatchStatus::Expired
404        )
405    }
406    
407    /// Check if the batch can be cancelled
408    pub fn can_cancel(&self) -> bool {
409        self.processing_status.is_processing()
410    }
411    
412    /// Get completion percentage
413    pub fn completion_percentage(&self) -> f64 {
414        self.request_counts.completion_percentage()
415    }
416    
417    /// Get the number of pending requests
418    pub fn pending_requests(&self) -> u32 {
419        self.request_counts.pending()
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_status_terminal() {
        for status in [
            BatchStatus::Completed,
            BatchStatus::Failed,
            BatchStatus::Cancelled,
            BatchStatus::Expired,
        ] {
            assert!(status.is_terminal(), "{:?} should be terminal", status);
            assert!(!status.is_processing(), "{:?} should not be processing", status);
        }
    }

    #[test]
    fn test_batch_status_processing() {
        for status in [
            BatchStatus::Validating,
            BatchStatus::InProgress,
            BatchStatus::Finalizing,
        ] {
            assert!(status.is_processing(), "{:?} should be processing", status);
            assert!(!status.is_terminal(), "{:?} should not be terminal", status);
        }
    }

    #[test]
    fn test_cancelling_is_neither_terminal_nor_processing() {
        // Transitional state: reported by neither predicate.
        assert!(!BatchStatus::Cancelling.is_terminal());
        assert!(!BatchStatus::Cancelling.is_processing());
    }

    #[test]
    fn test_batch_status_wire_format() {
        // `rename_all = "snake_case"` fixes the serialized names.
        assert_eq!(
            serde_json::to_string(&BatchStatus::InProgress).unwrap(),
            "\"in_progress\""
        );
        let parsed: BatchStatus = serde_json::from_str("\"cancelling\"").unwrap();
        assert_eq!(parsed, BatchStatus::Cancelling);
    }

    #[test]
    fn test_batch_request_builder() {
        let request = BatchRequest::new("test1", "claude-3-5-sonnet-latest", 1024)
            .user("Hello, world!")
            .system("You are a helpful assistant")
            .temperature(0.7)
            .build();

        assert_eq!(request.custom_id, "test1");
        assert_eq!(request.method, "POST");
        assert_eq!(request.url, "/v1/messages");
        assert_eq!(request.body.model, "claude-3-5-sonnet-latest");
        assert_eq!(request.body.max_tokens, 1024);
        assert_eq!(request.body.messages.len(), 1);
        assert_eq!(request.body.system, Some("You are a helpful assistant".to_string()));
        assert_eq!(request.body.temperature, Some(0.7));
        assert_eq!(request.body.stream, Some(false));
    }

    #[test]
    fn test_request_counts() {
        let counts = BatchRequestCounts {
            total: 100,
            completed: 75,
            failed: 10,
        };

        assert_eq!(counts.pending(), 15);
        assert_eq!(counts.completion_percentage(), 75.0);
    }

    #[test]
    fn test_request_counts_edge_cases() {
        let empty = BatchRequestCounts {
            total: 0,
            completed: 0,
            failed: 0,
        };
        assert_eq!(empty.pending(), 0);
        assert_eq!(empty.completion_percentage(), 0.0);

        // More settled requests than the total must clamp to zero, not underflow.
        let inconsistent = BatchRequestCounts {
            total: 5,
            completed: 3,
            failed: 4,
        };
        assert_eq!(inconsistent.pending(), 0);
    }

    #[test]
    fn test_batch_create_params() {
        let requests = vec![
            BatchRequest::new("req1", "claude-3-5-sonnet-latest", 1024)
                .user("Hello")
                .build(),
        ];

        let params = BatchCreateParams::new(requests)
            .with_completion_window(12);

        assert_eq!(params.requests.len(), 1);
        assert_eq!(params.completion_window, Some(12));
    }

    #[test]
    fn test_batch_create_params_omits_empty_optionals() {
        // Empty metadata and a `None` completion window are skipped when serializing.
        let json = serde_json::to_value(BatchCreateParams::new(Vec::new())).unwrap();
        assert_eq!(json, serde_json::json!({ "requests": [] }));
    }

    #[test]
    fn test_batch_list_params_limit_is_clamped() {
        assert_eq!(BatchListParams::new().limit(0).limit, Some(1));
        assert_eq!(BatchListParams::new().limit(50).limit, Some(50));
        assert_eq!(BatchListParams::new().limit(500).limit, Some(100));
    }
}