// openai_tools/batch/request.rs
//! OpenAI Batch API Request Module
//!
//! This module provides the functionality to interact with the OpenAI Batch API.
//! It allows you to create, list, retrieve, and cancel batch jobs.
//!
//! # Key Features
//!
//! - **Create Batch**: Submit a batch of requests for asynchronous processing
//! - **Retrieve Batch**: Get the status and details of a batch job
//! - **List Batches**: List all batch jobs
//! - **Cancel Batch**: Cancel an in-progress batch job
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint, CompletionWindow};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let batches = Batches::new()?;
//!
//!     // List all batches
//!     let response = batches.list(None, None).await?;
//!     for batch in &response.data {
//!         println!("{}: {:?}", batch.id, batch.status);
//!     }
//!
//!     Ok(())
//! }
//! ```
31
32use crate::batch::response::{BatchListResponse, BatchObject};
33use crate::common::auth::AuthProvider;
34use crate::common::client::create_http_client;
35use crate::common::errors::{OpenAIToolError, Result};
36use serde::Serialize;
37use std::collections::HashMap;
38use std::time::Duration;
39
/// Relative API path for the Batch endpoints; appended to the provider's
/// base URL by `AuthProvider::endpoint` to form the full request URL.
const BATCHES_PATH: &str = "batches";
42
/// The API endpoint to use for batch requests.
///
/// Each variant serializes (via the `#[serde(rename)]` attributes) to the exact
/// path string the Batch API expects in the `endpoint` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BatchEndpoint {
    /// Chat Completions API (/v1/chat/completions)
    #[serde(rename = "/v1/chat/completions")]
    ChatCompletions,
    /// Embeddings API (/v1/embeddings)
    #[serde(rename = "/v1/embeddings")]
    Embeddings,
    /// Completions API (/v1/completions)
    #[serde(rename = "/v1/completions")]
    Completions,
    /// Responses API (/v1/responses)
    #[serde(rename = "/v1/responses")]
    Responses,
    /// Moderations API (/v1/moderations)
    #[serde(rename = "/v1/moderations")]
    Moderations,
}
62
63impl BatchEndpoint {
64 /// Returns the string representation of the endpoint.
65 pub fn as_str(&self) -> &'static str {
66 match self {
67 BatchEndpoint::ChatCompletions => "/v1/chat/completions",
68 BatchEndpoint::Embeddings => "/v1/embeddings",
69 BatchEndpoint::Completions => "/v1/completions",
70 BatchEndpoint::Responses => "/v1/responses",
71 BatchEndpoint::Moderations => "/v1/moderations",
72 }
73 }
74}
75
/// The time window in which the batch must be completed.
///
/// The Batch API currently only supports a 24-hour window, so this enum has a
/// single variant; it serializes to the wire string `"24h"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
pub enum CompletionWindow {
    /// 24 hours
    #[serde(rename = "24h")]
    #[default]
    Hours24,
}
84
85impl CompletionWindow {
86 /// Returns the string representation of the completion window.
87 pub fn as_str(&self) -> &'static str {
88 match self {
89 CompletionWindow::Hours24 => "24h",
90 }
91 }
92}
93
/// Request to create a new batch job.
///
/// Serialized as the JSON body of `POST /batches`. Construct via
/// [`CreateBatchRequest::new`] and optionally chain `with_metadata`.
#[derive(Debug, Clone, Serialize)]
pub struct CreateBatchRequest {
    /// The ID of an uploaded file that contains requests for the batch.
    /// The file must be uploaded with purpose "batch".
    pub input_file_id: String,

    /// The endpoint to use for all requests in the batch.
    pub endpoint: BatchEndpoint,

    /// The time window in which the batch must be completed.
    pub completion_window: CompletionWindow,

    /// Optional metadata to attach to the batch.
    /// Omitted from the serialized body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
}
111
112impl CreateBatchRequest {
113 /// Creates a new batch request with the given input file ID and endpoint.
114 ///
115 /// # Arguments
116 ///
117 /// * `input_file_id` - The ID of the uploaded input file
118 /// * `endpoint` - The API endpoint to use for the batch
119 ///
120 /// # Example
121 ///
122 /// ```rust
123 /// use openai_tools::batch::request::{CreateBatchRequest, BatchEndpoint};
124 ///
125 /// let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions);
126 /// ```
127 pub fn new(input_file_id: impl Into<String>, endpoint: BatchEndpoint) -> Self {
128 Self { input_file_id: input_file_id.into(), endpoint, completion_window: CompletionWindow::default(), metadata: None }
129 }
130
131 /// Sets the metadata for the batch.
132 ///
133 /// # Arguments
134 ///
135 /// * `metadata` - Key-value pairs to attach to the batch
136 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
137 self.metadata = Some(metadata);
138 self
139 }
140}
141
/// Client for interacting with the OpenAI Batch API.
///
/// This struct provides methods to create, list, retrieve, and cancel batch jobs.
/// Use [`Batches::new()`] to create a new instance.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let batches = Batches::new()?;
///
///     // Create a batch job
///     let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions);
///     let batch = batches.create(request).await?;
///     println!("Created batch: {} ({:?})", batch.id, batch.status);
///
///     Ok(())
/// }
/// ```
pub struct Batches {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Optional request timeout duration; `None` uses the HTTP client default
    timeout: Option<Duration>,
}
170
171impl Batches {
172 /// Creates a new Batches client for OpenAI API.
173 ///
174 /// Initializes the client by loading the OpenAI API key from
175 /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
176 /// via dotenvy.
177 ///
178 /// # Returns
179 ///
180 /// * `Ok(Batches)` - A new Batches client ready for use
181 /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
182 ///
183 /// # Example
184 ///
185 /// ```rust,no_run
186 /// use openai_tools::batch::request::Batches;
187 ///
188 /// let batches = Batches::new().expect("API key should be set");
189 /// ```
190 pub fn new() -> Result<Self> {
191 let auth = AuthProvider::openai_from_env()?;
192 Ok(Self { auth, timeout: None })
193 }
194
195 /// Creates a new Batches client with a custom authentication provider
196 pub fn with_auth(auth: AuthProvider) -> Self {
197 Self { auth, timeout: None }
198 }
199
200 /// Creates a new Batches client for Azure OpenAI API
201 pub fn azure() -> Result<Self> {
202 let auth = AuthProvider::azure_from_env()?;
203 Ok(Self { auth, timeout: None })
204 }
205
206 /// Creates a new Batches client by auto-detecting the provider
207 pub fn detect_provider() -> Result<Self> {
208 let auth = AuthProvider::from_env()?;
209 Ok(Self { auth, timeout: None })
210 }
211
212 /// Creates a new Batches client with URL-based provider detection
213 pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
214 let auth = AuthProvider::from_url_with_key(base_url, api_key);
215 Self { auth, timeout: None }
216 }
217
218 /// Creates a new Batches client from URL using environment variables
219 pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
220 let auth = AuthProvider::from_url(url)?;
221 Ok(Self { auth, timeout: None })
222 }
223
224 /// Returns the authentication provider
225 pub fn auth(&self) -> &AuthProvider {
226 &self.auth
227 }
228
229 /// Sets the request timeout duration.
230 ///
231 /// # Arguments
232 ///
233 /// * `timeout` - The maximum time to wait for a response
234 ///
235 /// # Returns
236 ///
237 /// A mutable reference to self for method chaining
238 pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
239 self.timeout = Some(timeout);
240 self
241 }
242
243 /// Creates the HTTP client with default headers.
244 fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
245 let client = create_http_client(self.timeout)?;
246 let mut headers = request::header::HeaderMap::new();
247 self.auth.apply_headers(&mut headers)?;
248 headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
249 headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
250 Ok((client, headers))
251 }
252
253 /// Creates a new batch job.
254 ///
255 /// # Arguments
256 ///
257 /// * `request` - The batch creation request
258 ///
259 /// # Returns
260 ///
261 /// * `Ok(BatchObject)` - The created batch object
262 /// * `Err(OpenAIToolError)` - If the request fails
263 ///
264 /// # Example
265 ///
266 /// ```rust,no_run
267 /// use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint};
268 /// use std::collections::HashMap;
269 ///
270 /// #[tokio::main]
271 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
272 /// let batches = Batches::new()?;
273 ///
274 /// let mut metadata = HashMap::new();
275 /// metadata.insert("customer_id".to_string(), "user_123".to_string());
276 ///
277 /// let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions)
278 /// .with_metadata(metadata);
279 ///
280 /// let batch = batches.create(request).await?;
281 /// println!("Created batch: {}", batch.id);
282 /// Ok(())
283 /// }
284 /// ```
285 pub async fn create(&self, request: CreateBatchRequest) -> Result<BatchObject> {
286 let (client, headers) = self.create_client()?;
287
288 let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;
289
290 let url = self.auth.endpoint(BATCHES_PATH);
291 let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
292
293 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
294
295 if cfg!(test) {
296 tracing::info!("Response content: {}", content);
297 }
298
299 serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
300 }
301
302 /// Retrieves details of a specific batch job.
303 ///
304 /// # Arguments
305 ///
306 /// * `batch_id` - The ID of the batch to retrieve
307 ///
308 /// # Returns
309 ///
310 /// * `Ok(BatchObject)` - The batch details
311 /// * `Err(OpenAIToolError)` - If the batch is not found or the request fails
312 ///
313 /// # Example
314 ///
315 /// ```rust,no_run
316 /// use openai_tools::batch::request::Batches;
317 ///
318 /// #[tokio::main]
319 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
320 /// let batches = Batches::new()?;
321 /// let batch = batches.retrieve("batch_abc123").await?;
322 ///
323 /// println!("Status: {:?}", batch.status);
324 /// if let Some(counts) = &batch.request_counts {
325 /// println!("Completed: {}/{}", counts.completed, counts.total);
326 /// }
327 /// Ok(())
328 /// }
329 /// ```
330 pub async fn retrieve(&self, batch_id: &str) -> Result<BatchObject> {
331 let (client, headers) = self.create_client()?;
332 let url = format!("{}/{}", self.auth.endpoint(BATCHES_PATH), batch_id);
333
334 let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
335
336 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
337
338 if cfg!(test) {
339 tracing::info!("Response content: {}", content);
340 }
341
342 serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
343 }
344
345 /// Cancels an in-progress batch job.
346 ///
347 /// The batch will transition to "cancelling" and eventually "cancelled".
348 ///
349 /// # Arguments
350 ///
351 /// * `batch_id` - The ID of the batch to cancel
352 ///
353 /// # Returns
354 ///
355 /// * `Ok(BatchObject)` - The updated batch object
356 /// * `Err(OpenAIToolError)` - If the batch cannot be cancelled or the request fails
357 ///
358 /// # Example
359 ///
360 /// ```rust,no_run
361 /// use openai_tools::batch::request::Batches;
362 ///
363 /// #[tokio::main]
364 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
365 /// let batches = Batches::new()?;
366 /// let batch = batches.cancel("batch_abc123").await?;
367 ///
368 /// println!("Batch status: {:?}", batch.status);
369 /// Ok(())
370 /// }
371 /// ```
372 pub async fn cancel(&self, batch_id: &str) -> Result<BatchObject> {
373 let (client, headers) = self.create_client()?;
374 let url = format!("{}/{}/cancel", self.auth.endpoint(BATCHES_PATH), batch_id);
375
376 let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
377
378 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
379
380 if cfg!(test) {
381 tracing::info!("Response content: {}", content);
382 }
383
384 serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
385 }
386
387 /// Lists all batch jobs.
388 ///
389 /// Supports pagination through `limit` and `after` parameters.
390 ///
391 /// # Arguments
392 ///
393 /// * `limit` - Maximum number of batches to return (default: 20)
394 /// * `after` - Cursor for pagination (batch ID to start after)
395 ///
396 /// # Returns
397 ///
398 /// * `Ok(BatchListResponse)` - The list of batch jobs
399 /// * `Err(OpenAIToolError)` - If the request fails
400 ///
401 /// # Example
402 ///
403 /// ```rust,no_run
404 /// use openai_tools::batch::request::Batches;
405 ///
406 /// #[tokio::main]
407 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
408 /// let batches = Batches::new()?;
409 ///
410 /// // Get first page
411 /// let response = batches.list(Some(10), None).await?;
412 /// for batch in &response.data {
413 /// println!("{}: {:?}", batch.id, batch.status);
414 /// }
415 ///
416 /// // Get next page if available
417 /// if response.has_more {
418 /// if let Some(last_id) = &response.last_id {
419 /// let next_page = batches.list(Some(10), Some(last_id)).await?;
420 /// // ...
421 /// }
422 /// }
423 ///
424 /// Ok(())
425 /// }
426 /// ```
427 pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<BatchListResponse> {
428 let (client, headers) = self.create_client()?;
429
430 let mut url = self.auth.endpoint(BATCHES_PATH);
431 let mut params = Vec::new();
432
433 if let Some(l) = limit {
434 params.push(format!("limit={}", l));
435 }
436 if let Some(a) = after {
437 params.push(format!("after={}", a));
438 }
439
440 if !params.is_empty() {
441 url.push('?');
442 url.push_str(¶ms.join("&"));
443 }
444
445 let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
446
447 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
448
449 if cfg!(test) {
450 tracing::info!("Response content: {}", content);
451 }
452
453 serde_json::from_str::<BatchListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
454 }
455}