// openai_tools/batch/request.rs
1//! OpenAI Batch API Request Module
2//!
3//! This module provides the functionality to interact with the OpenAI Batch API.
4//! It allows you to create, list, retrieve, and cancel batch jobs.
5//!
6//! # Key Features
7//!
8//! - **Create Batch**: Submit a batch of requests for asynchronous processing
9//! - **Retrieve Batch**: Get the status and details of a batch job
10//! - **List Batches**: List all batch jobs
11//! - **Cancel Batch**: Cancel an in-progress batch job
12//!
13//! # Quick Start
14//!
15//! ```rust,no_run
16//! use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint, CompletionWindow};
17//!
18//! #[tokio::main]
19//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//!     let batches = Batches::new()?;
21//!
22//!     // List all batches
23//!     let response = batches.list(None, None).await?;
24//!     for batch in &response.data {
25//!         println!("{}: {:?}", batch.id, batch.status);
26//!     }
27//!
28//!     Ok(())
29//! }
30//! ```
31
32use crate::batch::response::{BatchListResponse, BatchObject};
33use crate::common::auth::AuthProvider;
34use crate::common::client::create_http_client;
35use crate::common::errors::{OpenAIToolError, Result};
36use serde::Serialize;
37use std::collections::HashMap;
38use std::time::Duration;
39
/// Default API path segment for the Batch API; appended to the provider base URL
/// via `AuthProvider::endpoint` when building request URLs.
const BATCHES_PATH: &str = "batches";
42
/// The API endpoint to use for batch requests.
///
/// The serde-serialized value of each variant is the endpoint path string,
/// identical to what [`BatchEndpoint::as_str`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BatchEndpoint {
    /// Chat Completions API (/v1/chat/completions)
    #[serde(rename = "/v1/chat/completions")]
    ChatCompletions,
    /// Embeddings API (/v1/embeddings)
    #[serde(rename = "/v1/embeddings")]
    Embeddings,
    /// Completions API (/v1/completions)
    #[serde(rename = "/v1/completions")]
    Completions,
    /// Responses API (/v1/responses)
    #[serde(rename = "/v1/responses")]
    Responses,
    /// Moderations API (/v1/moderations)
    #[serde(rename = "/v1/moderations")]
    Moderations,
}
62
63impl BatchEndpoint {
64    /// Returns the string representation of the endpoint.
65    pub fn as_str(&self) -> &'static str {
66        match self {
67            BatchEndpoint::ChatCompletions => "/v1/chat/completions",
68            BatchEndpoint::Embeddings => "/v1/embeddings",
69            BatchEndpoint::Completions => "/v1/completions",
70            BatchEndpoint::Responses => "/v1/responses",
71            BatchEndpoint::Moderations => "/v1/moderations",
72        }
73    }
74}
75
/// The time window in which the batch must be completed.
///
/// Only a single 24-hour window is defined here; it is also the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
pub enum CompletionWindow {
    /// 24 hours (serialized as `"24h"`)
    #[serde(rename = "24h")]
    #[default]
    Hours24,
}
84
85impl CompletionWindow {
86    /// Returns the string representation of the completion window.
87    pub fn as_str(&self) -> &'static str {
88        match self {
89            CompletionWindow::Hours24 => "24h",
90        }
91    }
92}
93
/// Request to create a new batch job.
///
/// Serialized with serde as the JSON body sent by [`Batches::create`];
/// `metadata` is omitted from the payload when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct CreateBatchRequest {
    /// The ID of an uploaded file that contains requests for the batch.
    /// The file must be uploaded with purpose "batch".
    pub input_file_id: String,

    /// The endpoint to use for all requests in the batch.
    pub endpoint: BatchEndpoint,

    /// The time window in which the batch must be completed.
    pub completion_window: CompletionWindow,

    /// Optional metadata to attach to the batch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
}
111
112impl CreateBatchRequest {
113    /// Creates a new batch request with the given input file ID and endpoint.
114    ///
115    /// # Arguments
116    ///
117    /// * `input_file_id` - The ID of the uploaded input file
118    /// * `endpoint` - The API endpoint to use for the batch
119    ///
120    /// # Example
121    ///
122    /// ```rust
123    /// use openai_tools::batch::request::{CreateBatchRequest, BatchEndpoint};
124    ///
125    /// let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions);
126    /// ```
127    pub fn new(input_file_id: impl Into<String>, endpoint: BatchEndpoint) -> Self {
128        Self { input_file_id: input_file_id.into(), endpoint, completion_window: CompletionWindow::default(), metadata: None }
129    }
130
131    /// Sets the metadata for the batch.
132    ///
133    /// # Arguments
134    ///
135    /// * `metadata` - Key-value pairs to attach to the batch
136    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
137        self.metadata = Some(metadata);
138        self
139    }
140}
141
/// Client for interacting with the OpenAI Batch API.
///
/// This struct provides methods to create, list, retrieve, and cancel batch jobs.
/// Use [`Batches::new()`] to create a new instance.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let batches = Batches::new()?;
///
///     // Create a batch job
///     let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions);
///     let batch = batches.create(request).await?;
///     println!("Created batch: {} ({:?})", batch.id, batch.status);
///
///     Ok(())
/// }
/// ```
pub struct Batches {
    /// Authentication provider (OpenAI or Azure); supplies credentials and base URL
    auth: AuthProvider,
    /// Optional request timeout duration; `None` uses the HTTP client's default
    timeout: Option<Duration>,
}
170
171impl Batches {
172    /// Creates a new Batches client for OpenAI API.
173    ///
174    /// Initializes the client by loading the OpenAI API key from
175    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
176    /// via dotenvy.
177    ///
178    /// # Returns
179    ///
180    /// * `Ok(Batches)` - A new Batches client ready for use
181    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
182    ///
183    /// # Example
184    ///
185    /// ```rust,no_run
186    /// use openai_tools::batch::request::Batches;
187    ///
188    /// let batches = Batches::new().expect("API key should be set");
189    /// ```
190    pub fn new() -> Result<Self> {
191        let auth = AuthProvider::openai_from_env()?;
192        Ok(Self { auth, timeout: None })
193    }
194
195    /// Creates a new Batches client with a custom authentication provider
196    pub fn with_auth(auth: AuthProvider) -> Self {
197        Self { auth, timeout: None }
198    }
199
200    /// Creates a new Batches client for Azure OpenAI API
201    pub fn azure() -> Result<Self> {
202        let auth = AuthProvider::azure_from_env()?;
203        Ok(Self { auth, timeout: None })
204    }
205
206    /// Creates a new Batches client by auto-detecting the provider
207    pub fn detect_provider() -> Result<Self> {
208        let auth = AuthProvider::from_env()?;
209        Ok(Self { auth, timeout: None })
210    }
211
212    /// Creates a new Batches client with URL-based provider detection
213    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
214        let auth = AuthProvider::from_url_with_key(base_url, api_key);
215        Self { auth, timeout: None }
216    }
217
218    /// Creates a new Batches client from URL using environment variables
219    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
220        let auth = AuthProvider::from_url(url)?;
221        Ok(Self { auth, timeout: None })
222    }
223
224    /// Returns the authentication provider
225    pub fn auth(&self) -> &AuthProvider {
226        &self.auth
227    }
228
229    /// Sets the request timeout duration.
230    ///
231    /// # Arguments
232    ///
233    /// * `timeout` - The maximum time to wait for a response
234    ///
235    /// # Returns
236    ///
237    /// A mutable reference to self for method chaining
238    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
239        self.timeout = Some(timeout);
240        self
241    }
242
243    /// Creates the HTTP client with default headers.
244    fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
245        let client = create_http_client(self.timeout)?;
246        let mut headers = request::header::HeaderMap::new();
247        self.auth.apply_headers(&mut headers)?;
248        headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
249        headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
250        Ok((client, headers))
251    }
252
253    /// Creates a new batch job.
254    ///
255    /// # Arguments
256    ///
257    /// * `request` - The batch creation request
258    ///
259    /// # Returns
260    ///
261    /// * `Ok(BatchObject)` - The created batch object
262    /// * `Err(OpenAIToolError)` - If the request fails
263    ///
264    /// # Example
265    ///
266    /// ```rust,no_run
267    /// use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint};
268    /// use std::collections::HashMap;
269    ///
270    /// #[tokio::main]
271    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
272    ///     let batches = Batches::new()?;
273    ///
274    ///     let mut metadata = HashMap::new();
275    ///     metadata.insert("customer_id".to_string(), "user_123".to_string());
276    ///
277    ///     let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions)
278    ///         .with_metadata(metadata);
279    ///
280    ///     let batch = batches.create(request).await?;
281    ///     println!("Created batch: {}", batch.id);
282    ///     Ok(())
283    /// }
284    /// ```
285    pub async fn create(&self, request: CreateBatchRequest) -> Result<BatchObject> {
286        let (client, headers) = self.create_client()?;
287
288        let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;
289
290        let url = self.auth.endpoint(BATCHES_PATH);
291        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
292
293        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
294
295        if cfg!(test) {
296            tracing::info!("Response content: {}", content);
297        }
298
299        serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
300    }
301
302    /// Retrieves details of a specific batch job.
303    ///
304    /// # Arguments
305    ///
306    /// * `batch_id` - The ID of the batch to retrieve
307    ///
308    /// # Returns
309    ///
310    /// * `Ok(BatchObject)` - The batch details
311    /// * `Err(OpenAIToolError)` - If the batch is not found or the request fails
312    ///
313    /// # Example
314    ///
315    /// ```rust,no_run
316    /// use openai_tools::batch::request::Batches;
317    ///
318    /// #[tokio::main]
319    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
320    ///     let batches = Batches::new()?;
321    ///     let batch = batches.retrieve("batch_abc123").await?;
322    ///
323    ///     println!("Status: {:?}", batch.status);
324    ///     if let Some(counts) = &batch.request_counts {
325    ///         println!("Completed: {}/{}", counts.completed, counts.total);
326    ///     }
327    ///     Ok(())
328    /// }
329    /// ```
330    pub async fn retrieve(&self, batch_id: &str) -> Result<BatchObject> {
331        let (client, headers) = self.create_client()?;
332        let url = format!("{}/{}", self.auth.endpoint(BATCHES_PATH), batch_id);
333
334        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
335
336        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
337
338        if cfg!(test) {
339            tracing::info!("Response content: {}", content);
340        }
341
342        serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
343    }
344
345    /// Cancels an in-progress batch job.
346    ///
347    /// The batch will transition to "cancelling" and eventually "cancelled".
348    ///
349    /// # Arguments
350    ///
351    /// * `batch_id` - The ID of the batch to cancel
352    ///
353    /// # Returns
354    ///
355    /// * `Ok(BatchObject)` - The updated batch object
356    /// * `Err(OpenAIToolError)` - If the batch cannot be cancelled or the request fails
357    ///
358    /// # Example
359    ///
360    /// ```rust,no_run
361    /// use openai_tools::batch::request::Batches;
362    ///
363    /// #[tokio::main]
364    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
365    ///     let batches = Batches::new()?;
366    ///     let batch = batches.cancel("batch_abc123").await?;
367    ///
368    ///     println!("Batch status: {:?}", batch.status);
369    ///     Ok(())
370    /// }
371    /// ```
372    pub async fn cancel(&self, batch_id: &str) -> Result<BatchObject> {
373        let (client, headers) = self.create_client()?;
374        let url = format!("{}/{}/cancel", self.auth.endpoint(BATCHES_PATH), batch_id);
375
376        let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
377
378        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
379
380        if cfg!(test) {
381            tracing::info!("Response content: {}", content);
382        }
383
384        serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
385    }
386
387    /// Lists all batch jobs.
388    ///
389    /// Supports pagination through `limit` and `after` parameters.
390    ///
391    /// # Arguments
392    ///
393    /// * `limit` - Maximum number of batches to return (default: 20)
394    /// * `after` - Cursor for pagination (batch ID to start after)
395    ///
396    /// # Returns
397    ///
398    /// * `Ok(BatchListResponse)` - The list of batch jobs
399    /// * `Err(OpenAIToolError)` - If the request fails
400    ///
401    /// # Example
402    ///
403    /// ```rust,no_run
404    /// use openai_tools::batch::request::Batches;
405    ///
406    /// #[tokio::main]
407    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
408    ///     let batches = Batches::new()?;
409    ///
410    ///     // Get first page
411    ///     let response = batches.list(Some(10), None).await?;
412    ///     for batch in &response.data {
413    ///         println!("{}: {:?}", batch.id, batch.status);
414    ///     }
415    ///
416    ///     // Get next page if available
417    ///     if response.has_more {
418    ///         if let Some(last_id) = &response.last_id {
419    ///             let next_page = batches.list(Some(10), Some(last_id)).await?;
420    ///             // ...
421    ///         }
422    ///     }
423    ///
424    ///     Ok(())
425    /// }
426    /// ```
427    pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<BatchListResponse> {
428        let (client, headers) = self.create_client()?;
429
430        let mut url = self.auth.endpoint(BATCHES_PATH);
431        let mut params = Vec::new();
432
433        if let Some(l) = limit {
434            params.push(format!("limit={}", l));
435        }
436        if let Some(a) = after {
437            params.push(format!("after={}", a));
438        }
439
440        if !params.is_empty() {
441            url.push('?');
442            url.push_str(&params.join("&"));
443        }
444
445        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
446
447        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
448
449        if cfg!(test) {
450            tracing::info!("Response content: {}", content);
451        }
452
453        serde_json::from_str::<BatchListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
454    }
455}