openai-tools 1.1.0

Tools for OpenAI API — Documentation
//! OpenAI Batch API Request Module
//!
//! This module provides the functionality to interact with the OpenAI Batch API.
//! It allows you to create, list, retrieve, and cancel batch jobs.
//!
//! # Key Features
//!
//! - **Create Batch**: Submit a batch of requests for asynchronous processing
//! - **Retrieve Batch**: Get the status and details of a batch job
//! - **List Batches**: List all batch jobs
//! - **Cancel Batch**: Cancel an in-progress batch job
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint, CompletionWindow};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let batches = Batches::new()?;
//!
//!     // List all batches
//!     let response = batches.list(None, None).await?;
//!     for batch in &response.data {
//!         println!("{}: {:?}", batch.id, batch.status);
//!     }
//!
//!     Ok(())
//! }
//! ```

use crate::batch::response::{BatchListResponse, BatchObject};
use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{OpenAIToolError, Result};
use serde::Serialize;
use std::collections::HashMap;
use std::time::Duration;

/// Relative API path for the Batch endpoints. Passed to
/// `AuthProvider::endpoint` to build the full request URL, and extended
/// with `/{batch_id}` or `/{batch_id}/cancel` for per-batch operations.
const BATCHES_PATH: &str = "batches";

/// The API endpoint to use for batch requests.
///
/// Each variant's `#[serde(rename = "...")]` value is the exact path string
/// the Batch API expects; [`BatchEndpoint::as_str`] returns the same string,
/// so the two must be kept in sync when adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BatchEndpoint {
    /// Chat Completions API (/v1/chat/completions)
    #[serde(rename = "/v1/chat/completions")]
    ChatCompletions,
    /// Embeddings API (/v1/embeddings)
    #[serde(rename = "/v1/embeddings")]
    Embeddings,
    /// Completions API (/v1/completions)
    #[serde(rename = "/v1/completions")]
    Completions,
    /// Responses API (/v1/responses)
    #[serde(rename = "/v1/responses")]
    Responses,
    /// Moderations API (/v1/moderations)
    #[serde(rename = "/v1/moderations")]
    Moderations,
}

impl BatchEndpoint {
    /// Returns the string representation of the endpoint.
    pub fn as_str(&self) -> &'static str {
        match self {
            BatchEndpoint::ChatCompletions => "/v1/chat/completions",
            BatchEndpoint::Embeddings => "/v1/embeddings",
            BatchEndpoint::Completions => "/v1/completions",
            BatchEndpoint::Responses => "/v1/responses",
            BatchEndpoint::Moderations => "/v1/moderations",
        }
    }
}

/// The time window in which the batch must be completed.
///
/// The Batch API currently supports only "24h", so this enum has a single
/// variant, which is also the [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
pub enum CompletionWindow {
    /// 24 hours
    #[serde(rename = "24h")]
    #[default]
    Hours24,
}

impl CompletionWindow {
    /// Returns the string representation of the completion window.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompletionWindow::Hours24 => "24h",
        }
    }
}

/// Request to create a new batch job.
///
/// Serialized directly as the JSON body of `POST /batches`.
#[derive(Debug, Clone, Serialize)]
pub struct CreateBatchRequest {
    /// The ID of an uploaded file that contains requests for the batch.
    /// The file must be uploaded with purpose "batch".
    pub input_file_id: String,

    /// The endpoint to use for all requests in the batch.
    pub endpoint: BatchEndpoint,

    /// The time window in which the batch must be completed.
    pub completion_window: CompletionWindow,

    /// Optional metadata to attach to the batch.
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
}

impl CreateBatchRequest {
    /// Builds a batch-creation request for the given input file and endpoint.
    ///
    /// The completion window defaults to 24 hours (the only supported value)
    /// and no metadata is attached; use [`CreateBatchRequest::with_metadata`]
    /// to add metadata afterwards.
    ///
    /// # Arguments
    ///
    /// * `input_file_id` - The ID of the uploaded input file
    /// * `endpoint` - The API endpoint to use for the batch
    ///
    /// # Example
    ///
    /// ```rust
    /// use openai_tools::batch::request::{CreateBatchRequest, BatchEndpoint};
    ///
    /// let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions);
    /// ```
    pub fn new(input_file_id: impl Into<String>, endpoint: BatchEndpoint) -> Self {
        Self {
            input_file_id: input_file_id.into(),
            endpoint,
            completion_window: CompletionWindow::default(),
            metadata: None,
        }
    }

    /// Attaches metadata to the batch, consuming and returning `self`
    /// (builder style).
    ///
    /// # Arguments
    ///
    /// * `metadata` - Key-value pairs to attach to the batch
    pub fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
        Self { metadata: Some(metadata), ..self }
    }
}

/// Client for interacting with the OpenAI Batch API.
///
/// This struct provides methods to create, list, retrieve, and cancel batch jobs.
/// Use [`Batches::new()`] to create a new instance.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let batches = Batches::new()?;
///
///     // Create a batch job
///     let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions);
///     let batch = batches.create(request).await?;
///     println!("Created batch: {} ({:?})", batch.id, batch.status);
///
///     Ok(())
/// }
/// ```
pub struct Batches {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Optional request timeout duration; `None` uses the HTTP client's default
    timeout: Option<Duration>,
}

impl Batches {
    /// Creates a new Batches client for OpenAI API.
    ///
    /// Initializes the client by loading the OpenAI API key from
    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
    /// via dotenvy.
    ///
    /// # Returns
    ///
    /// * `Ok(Batches)` - A new Batches client ready for use
    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::batch::request::Batches;
    ///
    /// let batches = Batches::new().expect("API key should be set");
    /// ```
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Batches client with a custom authentication provider
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    /// Creates a new Batches client for Azure OpenAI API
    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Batches client by auto-detecting the provider
    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Batches client with URL-based provider detection
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, timeout: None }
    }

    /// Creates a new Batches client from URL using environment variables
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, timeout: None })
    }

    /// Returns the authentication provider
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Sets the request timeout duration.
    ///
    /// # Arguments
    ///
    /// * `timeout` - The maximum time to wait for a response
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Creates the HTTP client with default headers.
    fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = request::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
        headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
        Ok((client, headers))
    }

    /// Creates a new batch job.
    ///
    /// # Arguments
    ///
    /// * `request` - The batch creation request
    ///
    /// # Returns
    ///
    /// * `Ok(BatchObject)` - The created batch object
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::batch::request::{Batches, CreateBatchRequest, BatchEndpoint};
    /// use std::collections::HashMap;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let batches = Batches::new()?;
    ///
    ///     let mut metadata = HashMap::new();
    ///     metadata.insert("customer_id".to_string(), "user_123".to_string());
    ///
    ///     let request = CreateBatchRequest::new("file-abc123", BatchEndpoint::ChatCompletions)
    ///         .with_metadata(metadata);
    ///
    ///     let batch = batches.create(request).await?;
    ///     println!("Created batch: {}", batch.id);
    ///     Ok(())
    /// }
    /// ```
    pub async fn create(&self, request: CreateBatchRequest) -> Result<BatchObject> {
        let (client, headers) = self.create_client()?;

        let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;

        let url = self.auth.endpoint(BATCHES_PATH);
        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;

        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Retrieves details of a specific batch job.
    ///
    /// # Arguments
    ///
    /// * `batch_id` - The ID of the batch to retrieve
    ///
    /// # Returns
    ///
    /// * `Ok(BatchObject)` - The batch details
    /// * `Err(OpenAIToolError)` - If the batch is not found or the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::batch::request::Batches;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let batches = Batches::new()?;
    ///     let batch = batches.retrieve("batch_abc123").await?;
    ///
    ///     println!("Status: {:?}", batch.status);
    ///     if let Some(counts) = &batch.request_counts {
    ///         println!("Completed: {}/{}", counts.completed, counts.total);
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn retrieve(&self, batch_id: &str) -> Result<BatchObject> {
        let (client, headers) = self.create_client()?;
        let url = format!("{}/{}", self.auth.endpoint(BATCHES_PATH), batch_id);

        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;

        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Cancels an in-progress batch job.
    ///
    /// The batch will transition to "cancelling" and eventually "cancelled".
    ///
    /// # Arguments
    ///
    /// * `batch_id` - The ID of the batch to cancel
    ///
    /// # Returns
    ///
    /// * `Ok(BatchObject)` - The updated batch object
    /// * `Err(OpenAIToolError)` - If the batch cannot be cancelled or the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::batch::request::Batches;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let batches = Batches::new()?;
    ///     let batch = batches.cancel("batch_abc123").await?;
    ///
    ///     println!("Batch status: {:?}", batch.status);
    ///     Ok(())
    /// }
    /// ```
    pub async fn cancel(&self, batch_id: &str) -> Result<BatchObject> {
        let (client, headers) = self.create_client()?;
        let url = format!("{}/{}/cancel", self.auth.endpoint(BATCHES_PATH), batch_id);

        let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;

        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Lists all batch jobs.
    ///
    /// Supports pagination through `limit` and `after` parameters.
    ///
    /// # Arguments
    ///
    /// * `limit` - Maximum number of batches to return (default: 20)
    /// * `after` - Cursor for pagination (batch ID to start after)
    ///
    /// # Returns
    ///
    /// * `Ok(BatchListResponse)` - The list of batch jobs
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::batch::request::Batches;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let batches = Batches::new()?;
    ///
    ///     // Get first page
    ///     let response = batches.list(Some(10), None).await?;
    ///     for batch in &response.data {
    ///         println!("{}: {:?}", batch.id, batch.status);
    ///     }
    ///
    ///     // Get next page if available
    ///     if response.has_more {
    ///         if let Some(last_id) = &response.last_id {
    ///             let next_page = batches.list(Some(10), Some(last_id)).await?;
    ///             // ...
    ///         }
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<BatchListResponse> {
        let (client, headers) = self.create_client()?;

        let mut url = self.auth.endpoint(BATCHES_PATH);
        let mut params = Vec::new();

        if let Some(l) = limit {
            params.push(format!("limit={}", l));
        }
        if let Some(a) = after {
            params.push(format!("after={}", a));
        }

        if !params.is_empty() {
            url.push('?');
            url.push_str(&params.join("&"));
        }

        let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;

        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        serde_json::from_str::<BatchListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}