use crate::batch::response::{BatchListResponse, BatchObject};
use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{OpenAIToolError, Result};
use serde::Serialize;
use std::collections::HashMap;
use std::time::Duration;
const BATCHES_PATH: &str = "batches";
/// Target API endpoint that every request inside a batch input file is executed against.
///
/// Serializes to the literal endpoint path (e.g. `"/v1/chat/completions"`) via the
/// `serde(rename = ...)` attributes; [`BatchEndpoint::as_str`] returns the same strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BatchEndpoint {
    /// `/v1/chat/completions`
    #[serde(rename = "/v1/chat/completions")]
    ChatCompletions,
    /// `/v1/embeddings`
    #[serde(rename = "/v1/embeddings")]
    Embeddings,
    /// `/v1/completions` (legacy completions)
    #[serde(rename = "/v1/completions")]
    Completions,
    /// `/v1/responses`
    #[serde(rename = "/v1/responses")]
    Responses,
    /// `/v1/moderations`
    #[serde(rename = "/v1/moderations")]
    Moderations,
}
impl BatchEndpoint {
pub fn as_str(&self) -> &'static str {
match self {
BatchEndpoint::ChatCompletions => "/v1/chat/completions",
BatchEndpoint::Embeddings => "/v1/embeddings",
BatchEndpoint::Completions => "/v1/completions",
BatchEndpoint::Responses => "/v1/responses",
BatchEndpoint::Moderations => "/v1/moderations",
}
}
}
/// Time window within which a batch must complete.
///
/// The Batch API currently only accepts `"24h"`, so a single variant (also the
/// `Default`) is modeled; new windows can be added as variants later.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
pub enum CompletionWindow {
    /// 24-hour completion window (`"24h"` on the wire).
    #[serde(rename = "24h")]
    #[default]
    Hours24,
}
impl CompletionWindow {
pub fn as_str(&self) -> &'static str {
match self {
CompletionWindow::Hours24 => "24h",
}
}
}
/// Request body for creating a batch (`POST /v1/batches`).
#[derive(Debug, Clone, Serialize)]
pub struct CreateBatchRequest {
    // ID of an already-uploaded JSONL input file containing the batched requests.
    pub input_file_id: String,
    // Endpoint every request in the input file is executed against.
    pub endpoint: BatchEndpoint,
    // Completion deadline; currently always 24h.
    pub completion_window: CompletionWindow,
    // Optional key/value metadata attached to the batch; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
}
impl CreateBatchRequest {
    /// Builds a request for the given input file and target endpoint, using the
    /// default (24h) completion window and no metadata.
    pub fn new(input_file_id: impl Into<String>, endpoint: BatchEndpoint) -> Self {
        Self {
            input_file_id: input_file_id.into(),
            endpoint,
            completion_window: CompletionWindow::default(),
            metadata: None,
        }
    }

    /// Builder-style setter attaching metadata to the batch; consumes and returns `self`.
    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = Some(metadata);
        self
    }
}
/// Client for the OpenAI Batch API (`/batches`): create, retrieve, cancel, and list batches.
pub struct Batches {
    // Credential/endpoint source; resolves base URL and auth headers per provider.
    auth: AuthProvider,
    // Optional per-request HTTP timeout; `None` uses the HTTP client's default.
    timeout: Option<Duration>,
}
impl Batches {
pub fn new() -> Result<Self> {
let auth = AuthProvider::openai_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, timeout: None }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, timeout: None })
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
let client = create_http_client(self.timeout)?;
let mut headers = request::header::HeaderMap::new();
self.auth.apply_headers(&mut headers)?;
headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
Ok((client, headers))
}
pub async fn create(&self, request: CreateBatchRequest) -> Result<BatchObject> {
let (client, headers) = self.create_client()?;
let body = serde_json::to_string(&request).map_err(OpenAIToolError::SerdeJsonError)?;
let url = self.auth.endpoint(BATCHES_PATH);
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn retrieve(&self, batch_id: &str) -> Result<BatchObject> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}", self.auth.endpoint(BATCHES_PATH), batch_id);
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn cancel(&self, batch_id: &str) -> Result<BatchObject> {
let (client, headers) = self.create_client()?;
let url = format!("{}/{}/cancel", self.auth.endpoint(BATCHES_PATH), batch_id);
let response = client.post(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<BatchObject>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn list(&self, limit: Option<u32>, after: Option<&str>) -> Result<BatchListResponse> {
let (client, headers) = self.create_client()?;
let mut url = self.auth.endpoint(BATCHES_PATH);
let mut params = Vec::new();
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(a) = after {
params.push(format!("after={}", a));
}
if !params.is_empty() {
url.push('?');
url.push_str(¶ms.join("&"));
}
let response = client.get(&url).headers(headers).send().await.map_err(OpenAIToolError::RequestError)?;
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
serde_json::from_str::<BatchListResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
}