use log::{debug, info};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::io::Write;
use std::time::Duration;
use thiserror::Error;
use tokio::time::sleep;
use crate::chat_completions::{ChatClient, ChatRequest};
use crate::files::{FilePurpose, FilesClient, FilesError};
use crate::utils::remove_trailing_slash;
use crate::OpenAiError;
/// Client for the OpenAI Batch API (`/v1/batches`).
///
/// Typically built from an existing [`ChatClient`] via `BatchClient::from`,
/// sharing its credentials and HTTP connection pool.
pub struct BatchClient {
    /// Bearer token sent in the `Authorization` header of every request.
    pub api_key: String,
    /// Base API URL; batch endpoints are resolved relative to this.
    pub base_url: url::Url,
    /// Relative path of the batches endpoint (ends with a slash so that
    /// `Url::join` appends batch ids instead of replacing the segment).
    pub batches_path: String,
    /// Endpoint value sent in the batch-creation payload
    /// (e.g. "/v1/chat/completions").
    pub endpoint: String,
    /// Default model name carried over from the source chat client.
    pub model: String,
    /// Client used to upload batch input files and download result files.
    pub files_client: FilesClient,
    /// Shared HTTP client used for all batch endpoint calls.
    pub http_client: Client,
}
impl From<&ChatClient> for BatchClient {
fn from(client: &ChatClient) -> Self {
Self {
api_key: client.api_key.clone(),
base_url: client.base_url.clone(),
batches_path: "batches/".to_string(),
endpoint: "/v1/chat/completions".to_string(),
model: client.model.clone(),
files_client: FilesClient::from(client),
http_client: client.http_client.clone(),
}
}
}
/// Errors returned by [`BatchClient::upload_batch_file`].
#[derive(Error, Debug)]
pub enum UploadBatchFileError {
    /// The underlying file upload to the Files API failed.
    #[error("Error uploading file")]
    FileUploadError(#[from] FilesError),
}
/// Errors returned by [`BatchClient::create_batch`].
#[derive(Error, Debug)]
pub enum CreateBatchError {
    /// The HTTP request could not be sent or its body could not be read.
    #[error("Error sending request to the API")]
    RequestError(#[from] reqwest::Error),
    /// The response was neither a `Batch` nor a recognizable API error;
    /// the raw response text is kept for diagnostics.
    #[error("JSON error when parsing {1}")]
    JsonParseError(#[source] serde_json::Error, String),
    /// The API returned a well-formed error payload.
    #[error("OpenAI API error: {0}")]
    OpenAiError(#[from] OpenAiError),
}
/// Errors returned by [`BatchClient::get_batch_status`].
#[derive(Error, Debug)]
pub enum GetBatchStatusError {
    /// The HTTP request could not be sent or its body could not be read.
    #[error("Error sending request to the API")]
    RequestError(#[from] reqwest::Error),
    /// The response was neither a `Batch` nor a recognizable API error;
    /// the raw response text is kept for diagnostics.
    #[error("JSON error when parsing {1}")]
    JsonParseError(#[source] serde_json::Error, String),
    /// The API returned a well-formed error payload.
    #[error("OpenAI API error: {0}")]
    OpenAiError(#[from] OpenAiError),
}
/// Errors returned by [`BatchClient::wait_for_batch`].
#[derive(Error, Debug)]
pub enum WaitForBatchError {
    /// A status poll failed.
    #[error("Error getting batch status")]
    GetBatchStatusError(#[from] GetBatchStatusError),
    /// The batch reached the `failed` status; `error` is the stringified
    /// `errors` value from the batch object (empty JSON if absent).
    #[error("Batch {id} failed: {error}")]
    BatchFailed {
        id: String,
        error: String,
    },
    /// The batch was cancelled (or is being cancelled) while waiting.
    #[error("Batch cancelled: {0}")]
    BatchCancelled(String),
    /// The local 24h waiting budget was exhausted before completion.
    #[error("Timeout waiting for batch to complete: {0}")]
    BatchTimeout(String),
    /// The batch expired server-side before completing.
    #[error("Batch expired: {0}")]
    BatchExpired(String),
}
/// Errors returned by [`BatchClient::get_batch_results`].
#[derive(Error, Debug)]
pub enum GetBatchResultsError {
    /// Results were requested for a batch whose status is not `completed`.
    #[error("Batch is not completed: {0}")]
    BatchNotCompleted(BatchStatus),
    /// The completed batch has no `output_file_id`; payload is the batch id.
    #[error("Batch has no output file")]
    BatchNoOutputFile(String),
    /// Downloading the output file from the Files API failed.
    #[error("File error: {0}")]
    DownloadFileError(#[from] FilesError),
    /// A line of the output JSONL could not be parsed; the full file
    /// content is kept for diagnostics.
    #[error("JSON error when parsing {1}")]
    JsonParseError(#[source] serde_json::Error, String),
}
/// Errors returned by [`BatchClient::cancel_batch`].
#[derive(Error, Debug)]
pub enum CancelBatchError {
    /// The HTTP request could not be sent or its body could not be read.
    #[error("Error sending request to the API")]
    RequestError(#[from] reqwest::Error),
    /// The response was neither a `Batch` nor a recognizable API error;
    /// the raw response text is kept for diagnostics.
    #[error("JSON error when parsing {1}")]
    JsonParseError(#[source] serde_json::Error, String),
    /// The API returned a well-formed error payload.
    #[error("OpenAI API error: {0}")]
    OpenAiError(#[from] OpenAiError),
}
/// Errors returned by [`BatchClient::list_batches`] and its paged helper.
#[derive(Error, Debug)]
pub enum ListBatchesError {
    /// The HTTP request could not be sent or its body could not be read.
    #[error("Error sending request to the API")]
    RequestError(#[from] reqwest::Error),
    /// The response was neither a `BatchList` nor a recognizable API error;
    /// the raw response text is kept for diagnostics.
    #[error("JSON error when parsing {1}")]
    JsonParseError(#[source] serde_json::Error, String),
    /// The API returned a well-formed error payload.
    #[error("OpenAI API error: {0}")]
    OpenAiError(#[from] OpenAiError),
}
/// One line of a batch input file (JSONL): a single API request to replay.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct BatchRequestItem {
    /// Caller-chosen id used to match results back to requests.
    pub custom_id: String,
    /// HTTP method for the replayed request (constructors use "POST").
    pub method: String,
    /// Target API path, e.g. "/v1/chat/completions".
    pub url: String,
    /// JSON body forwarded verbatim to the target endpoint.
    pub body: Value,
}
/// One line of a batch output file: the result of a single request.
#[derive(Deserialize, Debug, Clone)]
pub struct BatchResponseItem {
    /// Server-assigned id of this result line.
    pub id: String,
    /// Echo of the `custom_id` from the corresponding request item.
    pub custom_id: String,
    /// Present when the request produced an HTTP response.
    pub response: Option<BatchItemResponse>,
    /// Present when the request failed before producing a response.
    pub error: Option<BatchItemError>,
}
/// The HTTP response captured for one batch request item.
#[derive(Deserialize, Debug, Clone)]
pub struct BatchItemResponse {
    /// HTTP status code of the replayed request.
    pub status_code: u16,
    /// Server-side request id (useful for support/debugging).
    pub request_id: String,
    /// Raw JSON body of the endpoint's response.
    pub body: Value,
}
/// Per-item error reported in a batch output file.
#[derive(Deserialize, Debug, Clone, Error)]
#[error("Batch item error: ({code}) {message}")]
pub struct BatchItemError {
    /// Machine-readable error code; defaults to "" when the API omits it.
    #[serde(default)]
    pub code: String,
    /// Human-readable error description.
    pub message: String,
}
/// A batch object as returned by the Batch API.
///
/// Timestamps (`*_at`) are presumably Unix epoch seconds per the OpenAI API
/// convention — confirm against the API reference before relying on them.
#[derive(Deserialize, Debug, Clone)]
pub struct Batch {
    /// Server-assigned batch id ("batch_...").
    pub id: String,
    /// Object type discriminator (expected: "batch").
    pub object: String,
    /// Endpoint all requests in this batch target.
    pub endpoint: String,
    /// Raw error payload when the batch failed, if any.
    pub errors: Option<Value>,
    /// Id of the uploaded JSONL input file.
    pub input_file_id: String,
    /// Requested completion window (this client always sends "24h").
    pub completion_window: String,
    /// Current lifecycle state.
    pub status: BatchStatus,
    /// Id of the JSONL results file; set once the batch completes.
    pub output_file_id: Option<String>,
    /// Id of the JSONL file holding per-request errors, if any.
    pub error_file_id: Option<String>,
    pub created_at: u64,
    pub in_progress_at: Option<u64>,
    pub expires_at: Option<u64>,
    pub completed_at: Option<u64>,
    pub failed_at: Option<u64>,
    pub expired_at: Option<u64>,
    /// Progress counters (total / completed / failed requests).
    pub request_counts: BatchRequestCounts,
    /// Caller-supplied metadata echoed back by the API.
    pub metadata: Option<HashMap<String, String>>,
}
/// Lifecycle state of a batch, as reported by the API.
///
/// `rename_all = "snake_case"` maps each variant to the wire value the API
/// uses ("validating", "in_progress", ...), identical to the previous
/// per-variant renames.
#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum BatchStatus {
    Validating,
    Failed,
    InProgress,
    Finalizing,
    Completed,
    Expired,
    Cancelling,
    Cancelled,
}
impl std::fmt::Display for BatchStatus {
    /// Displays the status as its variant name (e.g. "InProgress") by
    /// delegating to the derived `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
/// Progress counters for the requests inside a batch.
#[derive(Deserialize, Debug, Clone)]
pub struct BatchRequestCounts {
    /// Total number of requests in the batch.
    pub total: u32,
    /// Requests finished successfully so far.
    pub completed: u32,
    /// Requests that failed so far.
    pub failed: u32,
}
/// One page of results from the list-batches endpoint.
#[derive(Deserialize, Debug, Clone)]
pub struct BatchList {
    /// Batches in this page.
    pub data: Vec<Batch>,
    /// Object type discriminator (expected: "list").
    pub object: String,
    /// True when more pages can be fetched with an `after` cursor.
    pub has_more: bool,
}
impl BatchRequestItem {
    /// Builds a batch line targeting `/v1/chat/completions` from an
    /// existing chat request (model, messages and response format only).
    pub fn new_chat(custom_id: impl Into<String>, chat_request: ChatRequest) -> Self {
        let body = serde_json::json!({
            "model": chat_request.model,
            "messages": chat_request.messages,
            "response_format": chat_request.response_format,
        });
        Self {
            custom_id: custom_id.into(),
            method: "POST".to_string(),
            url: "/v1/chat/completions".to_string(),
            body,
        }
    }
    /// Builds a batch line targeting `/v1/embeddings` for a list of inputs.
    pub fn new_embedding(
        custom_id: impl Into<String>,
        model: impl Into<String>,
        input: Vec<String>,
    ) -> Self {
        Self {
            custom_id: custom_id.into(),
            method: "POST".to_string(),
            url: "/v1/embeddings".to_string(),
            body: serde_json::json!({
                "model": model.into(),
                "input": input,
            }),
        }
    }
    /// Builds a batch line targeting the legacy `/v1/completions` endpoint,
    /// capped at 1000 output tokens.
    pub fn new_completion(
        custom_id: impl Into<String>,
        model: impl Into<String>,
        prompt: impl Into<String>,
    ) -> Self {
        Self {
            custom_id: custom_id.into(),
            method: "POST".to_string(),
            url: "/v1/completions".to_string(),
            body: serde_json::json!({
                "model": model.into(),
                "prompt": prompt.into(),
                "max_tokens": 1000,
            }),
        }
    }
    /// Builds a batch line targeting `/v1/responses`, capped at 1000 output
    /// tokens.
    pub fn new_response(
        custom_id: impl Into<String>,
        model: impl Into<String>,
        prompt: impl Into<String>,
    ) -> Self {
        Self {
            custom_id: custom_id.into(),
            method: "POST".to_string(),
            url: "/v1/responses".to_string(),
            // The Responses API takes `input` (not `prompt`) and
            // `max_output_tokens` (not `max_tokens`); the previous
            // `prompt`/`max_tokens` body is rejected by that endpoint as
            // unknown parameters.
            body: serde_json::json!({
                "model": model.into(),
                "input": prompt.into(),
                "max_output_tokens": 1000,
            }),
        }
    }
}
impl BatchClient {
fn batches_url(&self) -> url::Url {
    // `batches_path` ends with '/', so further joins append path segments.
    // Panics only if the configured path is not a valid relative URL.
    self.base_url
        .join(self.batches_path.as_str())
        .unwrap()
}
/// Serializes the requests to newline-delimited JSON (one request per line),
/// the format the Batch API expects for input files.
pub fn create_batch_content(&self, requests: &[BatchRequestItem]) -> Vec<u8> {
    let mut buffer = Vec::new();
    requests.iter().for_each(|item| {
        let line = serde_json::to_string(item).unwrap();
        writeln!(&mut buffer, "{}", line).unwrap();
    });
    buffer
}
/// Serializes `requests` to JSONL and uploads them as a batch input file,
/// returning the new file's id.
pub async fn upload_batch_file(
    &self,
    filename: impl AsRef<str>,
    requests: &[BatchRequestItem],
) -> Result<String, UploadBatchFileError> {
    let name = filename.as_ref();
    let payload = self.create_batch_content(requests);
    let uploaded = self
        .files_client
        .upload_bytes(name, payload, FilePurpose::Batch)
        .await?;
    info!("Batch file {} uploaded with ID {}", name, uploaded.id);
    Ok(uploaded.id)
}
/// Creates a batch from a previously uploaded input file with a 24h
/// completion window. Returns the parsed `Batch` on success; otherwise an
/// `OpenAiError` if the API reported one, or a parse error carrying the raw
/// response text.
pub async fn create_batch(
    &self,
    input_file_id: impl AsRef<str>,
    metadata: HashMap<String, String>,
) -> Result<Batch, CreateBatchError> {
    let payload = serde_json::json!({
        "input_file_id": input_file_id.as_ref(),
        "endpoint": &self.endpoint,
        "completion_window": "24h",
        "metadata": metadata,
    });
    // The collection URL must not end in '/', hence the trim.
    let body = self
        .http_client
        .post(remove_trailing_slash(self.batches_url()))
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .json(&payload)
        .send()
        .await?
        .text()
        .await?;
    match serde_json::from_str::<Batch>(&body) {
        Ok(batch) => {
            info!(
                "Batch {} created with file id {}",
                batch.id,
                input_file_id.as_ref()
            );
            Ok(batch)
        }
        Err(parse_err) => match serde_json::from_str::<OpenAiError>(&body) {
            Ok(api_err) => Err(CreateBatchError::OpenAiError(api_err)),
            Err(_) => Err(CreateBatchError::JsonParseError(parse_err, body)),
        },
    }
}
/// Fetches the current state of a batch by id.
pub async fn get_batch_status(&self, batch_id: &str) -> Result<Batch, GetBatchStatusError> {
    // `batches_url()` ends with '/', so this appends the id as a segment.
    let endpoint = self.batches_url().join(batch_id).unwrap();
    let body = self
        .http_client
        .get(endpoint)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .send()
        .await?
        .text()
        .await?;
    match serde_json::from_str::<Batch>(&body) {
        Ok(batch) => Ok(batch),
        Err(parse_err) => match serde_json::from_str::<OpenAiError>(&body) {
            Ok(api_err) => Err(GetBatchStatusError::OpenAiError(api_err)),
            Err(_) => Err(GetBatchStatusError::JsonParseError(parse_err, body)),
        },
    }
}
/// Polls a batch with exponential backoff (capped at 120 s per poll) until
/// it completes, fails, expires, is cancelled, or 24h of waiting elapses.
///
/// # Errors
/// Propagates status-poll failures and maps terminal non-success states to
/// the corresponding `WaitForBatchError` variant.
pub async fn wait_for_batch(&self, batch_id: &str) -> Result<Batch, WaitForBatchError> {
    /// Total waiting budget before giving up locally.
    const MAX_WAIT_SECS: u64 = 86_400;
    /// Cap for a single backoff delay.
    const MAX_DELAY_SECS: u64 = 120;
    let mut attempts: u32 = 0;
    let mut seconds_waited: u64 = 0;
    loop {
        let batch = self.get_batch_status(batch_id).await?;
        match batch.status {
            BatchStatus::Completed => return Ok(batch),
            BatchStatus::Failed => {
                return Err(WaitForBatchError::BatchFailed {
                    id: batch_id.to_string(),
                    error: batch.errors.unwrap_or_default().to_string(),
                })
            }
            BatchStatus::Expired => {
                return Err(WaitForBatchError::BatchExpired(batch_id.to_string()))
            }
            BatchStatus::Cancelled | BatchStatus::Cancelling => {
                return Err(WaitForBatchError::BatchCancelled(batch_id.to_string()))
            }
            BatchStatus::InProgress | BatchStatus::Validating | BatchStatus::Finalizing => {
                attempts += 1;
                if seconds_waited >= MAX_WAIT_SECS {
                    return Err(WaitForBatchError::BatchTimeout(batch_id.to_string()));
                }
                // saturating_pow fixes an overflow in the old
                // `2_u64.pow(attempts)`: with a 120 s cap and a 24h budget
                // the loop runs far more than 63 times, at which point
                // `pow` panics in debug builds and wraps to 0 in release
                // builds (turning the poll into a busy loop).
                let delay = 2_u64.saturating_pow(attempts).min(MAX_DELAY_SECS);
                info!(
                    "batch {} is still in progress, waiting {} seconds",
                    batch_id, delay
                );
                sleep(Duration::from_secs(delay)).await;
                seconds_waited += delay;
            }
        }
    }
}
/// Downloads and parses a completed batch's output file (JSONL, one
/// `BatchResponseItem` per line).
///
/// # Errors
/// Fails if the batch is not `Completed`, has no output file, the download
/// fails, or any output line is not valid JSON.
pub async fn get_batch_results(
    &self,
    batch: &Batch,
) -> Result<Vec<BatchResponseItem>, GetBatchResultsError> {
    if batch.status != BatchStatus::Completed {
        return Err(GetBatchResultsError::BatchNotCompleted(batch.status));
    }
    let file_id = match batch.output_file_id.as_ref() {
        Some(id) => id,
        None => return Err(GetBatchResultsError::BatchNoOutputFile(batch.id.clone())),
    };
    let content = self.files_client.download_file(file_id).await?;
    debug!("Got results for batch {}: {}", batch.id, content);
    // `collect` over Results short-circuits on the first bad line.
    content
        .lines()
        .map(|line| {
            serde_json::from_str::<BatchResponseItem>(line)
                .map_err(|e| GetBatchResultsError::JsonParseError(e, content.clone()))
        })
        .collect()
}
/// Requests cancellation of a batch and returns the updated batch object.
///
/// # Errors
/// Returns transport errors, a parsed `OpenAiError` from the API, or a
/// parse error carrying the raw response text.
pub async fn cancel_batch(&self, batch_id: &str) -> Result<Batch, CancelBatchError> {
    // Bug fix: `Url::join` treats a base without a trailing slash as a file
    // name, so the old `.join(batch_id).join("cancel")` REPLACED the id
    // segment and posted to ".../batches/cancel". Joining the whole
    // relative path at once yields ".../batches/{id}/cancel".
    let url = self
        .batches_url()
        .join(&format!("{}/cancel", batch_id))
        .unwrap();
    let response_text = self
        .http_client
        .post(url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .send()
        .await?
        .text()
        .await?;
    match serde_json::from_str::<Batch>(&response_text) {
        Ok(batch) => Ok(batch),
        Err(e) => match serde_json::from_str::<OpenAiError>(&response_text) {
            Ok(error) => Err(CancelBatchError::OpenAiError(error)),
            Err(_) => Err(CancelBatchError::JsonParseError(e, response_text)),
        },
    }
}
/// Fetches every batch by walking the paginated list endpoint with the
/// last batch id of each page as the `after` cursor.
pub async fn list_batches(&self) -> Result<Vec<Batch>, ListBatchesError> {
    let mut collected = Vec::new();
    let mut cursor: Option<String> = None;
    loop {
        let page = self
            .list_batches_limited(None, cursor.as_deref())
            .await?;
        if page.data.is_empty() {
            break;
        }
        cursor = page.data.last().map(|batch| batch.id.clone());
        let has_more = page.has_more;
        collected.extend(page.data);
        if !has_more {
            break;
        }
    }
    Ok(collected)
}
/// Fetches one page of batches, optionally bounded by `limit` and starting
/// after the batch id given in `after`.
async fn list_batches_limited(
    &self,
    limit: Option<u32>,
    after: Option<&str>,
) -> Result<BatchList, ListBatchesError> {
    let mut url = self.batches_url();
    let query: Vec<String> = limit
        .map(|l| format!("limit={}", l))
        .into_iter()
        .chain(after.map(|a| format!("after={}", a)))
        .collect();
    if !query.is_empty() {
        url.set_query(Some(&query.join("&")));
    }
    let response_text = self
        .http_client
        .get(remove_trailing_slash(url))
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .send()
        .await?
        .text()
        .await?;
    match serde_json::from_str::<BatchList>(&response_text) {
        Ok(list) => Ok(list),
        Err(e) => match serde_json::from_str::<OpenAiError>(&response_text) {
            Ok(error) => Err(ListBatchesError::OpenAiError(error)),
            Err(_) => Err(ListBatchesError::JsonParseError(e, response_text)),
        },
    }
}
}
#[test]
fn test_batch_request_serialization() {
    use serde_json::json;
    // A representative chat-completion batch item.
    let request = BatchRequestItem {
        custom_id: "request-1".to_string(),
        method: "POST".to_string(),
        url: "/v1/chat/completions".to_string(),
        body: json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello world!"}
            ],
            "max_tokens": 1000
        }),
    };
    let serialized = serde_json::to_string(&request).unwrap();
    // Every field name and value must survive serialization.
    let expected_fragments = [
        "custom_id",
        "request-1",
        "method",
        "POST",
        "url",
        "/v1/chat/completions",
        "body",
        "gpt-4o",
        "helpful assistant",
        "Hello world!",
    ];
    for &fragment in expected_fragments.iter() {
        assert!(
            serialized.contains(fragment),
            "serialized batch request is missing {:?}",
            fragment
        );
    }
}