use crate::client::OpenAI;
use crate::error::OpenAIError;
use crate::pagination::{Page, Paginator};
use crate::types::batch::{Batch, BatchCreateRequest, BatchList, BatchListParams};
/// Handle for the `/batches` endpoints, borrowing the [`OpenAI`] client that
/// performs the HTTP requests.
///
/// Constructed inside the crate only (its constructor is `pub(crate)`); callers
/// reach it through the client, e.g. `client.batches()`.
pub struct Batches<'a> {
    /// Borrowed client through which every batch request is sent.
    client: &'a OpenAI,
}
impl<'a> Batches<'a> {
    /// Binds a `Batches` handle to `client`.
    pub(crate) fn new(client: &'a OpenAI) -> Self {
        Self { client }
    }

    /// Creates a batch by POSTing `request` to `/batches`.
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if the request fails or the response cannot be
    /// decoded into a [`Batch`].
    pub async fn create(&self, request: BatchCreateRequest) -> Result<Batch, OpenAIError> {
        self.client.post("/batches", &request).await
    }

    /// Fetches `/batches` with no query parameters and returns the single
    /// [`BatchList`] the server responds with.
    ///
    /// Use [`Batches::list_page`] to pass pagination parameters, or
    /// [`Batches::list_auto`] to walk every page.
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if the request or response decoding fails.
    pub async fn list(&self) -> Result<BatchList, OpenAIError> {
        self.client.get("/batches").await
    }

    /// Fetches one page of `/batches`, sending `params` as the query string.
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if the request or response decoding fails.
    pub async fn list_page(&self, params: BatchListParams) -> Result<BatchList, OpenAIError> {
        self.client
            .get_with_query("/batches", &params.to_query())
            .await
    }

    /// Returns a [`Paginator`] over every batch, fetching successive pages of
    /// `/batches`.
    ///
    /// `params` is used for the first request. Each later request reuses it with
    /// `after` set to the cursor reported for the previous page. That cursor is
    /// the page's `last_id`, falling back to the id of the last batch on the page.
    pub fn list_auto(&self, params: BatchListParams) -> Paginator<Batch> {
        let client = self.client.clone();
        Paginator::new(move |cursor| {
            let client = client.clone();
            let mut page_params = params.clone();
            // Only override `after` when the paginator supplies a cursor, so a
            // caller-provided starting `after` is honoured on the first request.
            if let Some(after) = cursor {
                page_params.after = Some(after);
            }
            async move {
                let list: BatchList = client
                    .get_with_query("/batches", &page_params.to_query())
                    .await?;
                let after_cursor = list
                    .last_id
                    .clone()
                    .or_else(|| list.data.last().map(|b| b.id.clone()));
                // Without a cursor, the next request would omit `after` and
                // re-fetch the first page. Report the listing as finished
                // instead of risking an endless loop.
                let has_more = list.has_more.unwrap_or(false) && after_cursor.is_some();
                Ok(Page {
                    has_more,
                    after_cursor,
                    data: list.data,
                })
            }
        })
    }

    /// Retrieves a single batch via `GET /batches/{batch_id}`.
    ///
    /// `batch_id` is interpolated into the path verbatim (it is not
    /// percent-encoded).
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if the request or response decoding fails.
    pub async fn retrieve(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
        self.client.get(&format!("/batches/{batch_id}")).await
    }

    /// Requests cancellation of a batch via `POST /batches/{batch_id}/cancel`,
    /// returning the updated [`Batch`].
    ///
    /// `batch_id` is interpolated into the path verbatim (it is not
    /// percent-encoded).
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if the request or response decoding fails.
    pub async fn cancel(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
        // The endpoint takes no payload. `post` requires a body, so
        // `Value::Null` is passed (presumably serialized as JSON `null`).
        self.client
            .post(
                &format!("/batches/{batch_id}/cancel"),
                &serde_json::Value::Null,
            )
            .await
    }
}
#[cfg(test)]
mod tests {
    use crate::OpenAI;
    use crate::config::ClientConfig;
    use crate::types::batch::{BatchCreateRequest, BatchListParams};

    /// Minimal batch object in the `validating` state, shared by the tests below.
    const BATCH_JSON: &str = r#"{
        "id": "batch_abc123",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": "file-abc123",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1699012949
    }"#;

    /// List response body containing only [`BATCH_JSON`] and no further pages.
    fn single_batch_list() -> String {
        format!(
            r#"{{"object": "list", "data": [{BATCH_JSON}], "has_more": false, "last_id": "batch_abc123"}}"#
        )
    }

    #[tokio::test]
    async fn test_batches_create() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/batches")
            .match_header("authorization", "Bearer sk-test")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(BATCH_JSON)
            .create_async()
            .await;
        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let request = BatchCreateRequest::new("file-abc123", "/v1/chat/completions", "24h");
        let batch = client.batches().create(request).await.unwrap();
        assert_eq!(batch.id, "batch_abc123");
        assert_eq!(batch.status, crate::types::batch::BatchStatus::Validating);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_batches_list() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/batches")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(single_batch_list())
            .create_async()
            .await;
        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let list = client.batches().list().await.unwrap();
        assert_eq!(list.data.len(), 1);
        assert_eq!(list.data[0].id, "batch_abc123");
        assert_eq!(list.has_more, Some(false));
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_batches_list_page_default_params() {
        let mut server = mockito::Server::new_async().await;
        // Default params must not add any query string.
        let mock = server
            .mock("GET", "/batches")
            .match_query(mockito::Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(single_batch_list())
            .create_async()
            .await;
        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let list = client
            .batches()
            .list_page(BatchListParams::new())
            .await
            .unwrap();
        assert_eq!(list.data.len(), 1);
        assert_eq!(list.data[0].id, "batch_abc123");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_batches_retrieve() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/batches/batch_abc123")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(BATCH_JSON)
            .create_async()
            .await;
        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let batch = client.batches().retrieve("batch_abc123").await.unwrap();
        assert_eq!(batch.id, "batch_abc123");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_batches_cancel() {
        let mut server = mockito::Server::new_async().await;
        // Same batch as BATCH_JSON, but reported as already cancelling.
        let cancelling = BATCH_JSON.replace(
            r#""status": "validating""#,
            r#""status": "cancelling""#,
        );
        let mock = server
            .mock("POST", "/batches/batch_abc123/cancel")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(cancelling)
            .create_async()
            .await;
        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let batch = client.batches().cancel("batch_abc123").await.unwrap();
        assert_eq!(batch.status, crate::types::batch::BatchStatus::Cancelling);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_batches_list_auto_multi_page() {
        use futures_util::StreamExt;
        let mut server = mockito::Server::new_async().await;
        let mock_p1 = server
            .mock("GET", "/batches")
            .match_query(mockito::Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "object": "list",
                "data": [
                    {"id": "batch_1", "object": "batch", "endpoint": "/v1/chat/completions", "input_file_id": "file-1", "completion_window": "24h", "status": "completed", "created_at": 1},
                    {"id": "batch_2", "object": "batch", "endpoint": "/v1/chat/completions", "input_file_id": "file-2", "completion_window": "24h", "status": "completed", "created_at": 2}
                ],
                "has_more": true,
                "last_id": "batch_2"
            }"#,
            )
            .create_async()
            .await;
        let mock_p2 = server
            .mock("GET", "/batches")
            .match_query(mockito::Matcher::UrlEncoded(
                "after".into(),
                "batch_2".into(),
            ))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "object": "list",
                "data": [
                    {"id": "batch_3", "object": "batch", "endpoint": "/v1/chat/completions", "input_file_id": "file-3", "completion_window": "24h", "status": "completed", "created_at": 3}
                ],
                "has_more": false,
                "last_id": "batch_3"
            }"#,
            )
            .create_async()
            .await;
        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let stream = client.batches().list_auto(BatchListParams::new());
        let batches: Vec<_> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();
        assert_eq!(batches.len(), 3);
        for (batch, expected) in batches.iter().zip(["batch_1", "batch_2", "batch_3"]) {
            assert_eq!(batch.id, expected);
        }
        // Each page must be requested exactly once (mockito's default expectation).
        mock_p1.assert_async().await;
        mock_p2.assert_async().await;
    }
}