use std::{
collections::HashMap,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::RwLock;
use uuid::Uuid;
/// A single request entry inside a batch submission.
///
/// Deserialized as an element of `CreateBatchBody::requests`. The handlers
/// visible in this file only count these entries; they never read the fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequestItem {
/// Caller-supplied identifier for this entry (presumably used to correlate
/// results with requests; confirm against the consumer of this type).
pub custom_id: String,
/// Presumably the HTTP method of the wrapped request; stored as a free-form
/// string and not validated in this file.
pub method: String,
/// Presumably the target path/URL of the wrapped request; not validated here.
pub url: String,
/// Arbitrary JSON payload carried with the entry.
pub body: Value,
}
/// Lifecycle state of a [`Batch`].
///
/// Serialized and deserialized in `snake_case`, so `InProgress` is written as
/// `"in_progress"`. In the code visible here, `create_batch` starts batches at
/// `InProgress` and `cancel_batch` moves them to `Cancelled`; nothing in this
/// file assigns the other variants.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum BatchStatus {
Validating,
InProgress,
Finalizing,
Completed,
Failed,
Expired,
Cancelling,
Cancelled,
}
/// Per-batch request tallies.
///
/// The derived `Default` yields all-zero counters.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BatchRequestCounts {
/// Number of requests in the batch; `create_batch` sets this from the length
/// of the submitted `requests` list.
pub total: u32,
/// Starts at 0 in `create_batch`; nothing in this file increments it.
pub completed: u32,
/// Starts at 0 in `create_batch`; nothing in this file increments it.
pub failed: u32,
}
/// A stored batch record, returned as JSON by the handlers in this file.
///
/// Timestamps are `i64` values; `create_batch` fills them with whole seconds
/// since the Unix epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Batch {
/// Unique identifier; `create_batch` generates it as `batch_<uuid>` and also
/// uses it as the key in the batch store.
pub id: String,
/// Object type tag; `create_batch` always sets this to `"batch"`.
pub object: String,
/// Endpoint string copied verbatim from the creation request.
pub endpoint: String,
/// Current lifecycle state.
pub status: BatchStatus,
/// Creation time.
pub created_at: i64,
/// Expiry time, if any; `create_batch` sets it to `created_at + 86400`.
pub expires_at: Option<i64>,
/// `None` at creation; nothing in this file sets it afterwards.
pub completed_at: Option<i64>,
/// `None` at creation; nothing in this file sets it afterwards.
pub failed_at: Option<i64>,
/// Request tallies for this batch.
pub request_counts: BatchRequestCounts,
/// Optional arbitrary JSON copied verbatim from the creation request.
pub metadata: Option<Value>,
}
/// JSON request body accepted by `create_batch`.
#[derive(Debug, Deserialize)]
pub struct CreateBatchBody {
/// Entries to include in the batch; `create_batch` only uses the count.
pub requests: Vec<BatchRequestItem>,
/// Copied into `Batch::endpoint` without validation.
pub endpoint: String,
/// Accepted on input but not read by `create_batch`.
pub completion_window: Option<String>,
/// Copied into `Batch::metadata` unchanged.
pub metadata: Option<Value>,
}
/// Shared in-memory batch table: a reference-counted, async read/write-locked
/// map from string key to [`Batch`]. `create_batch` keys entries by batch id.
pub type BatchStore = Arc<RwLock<HashMap<String, Batch>>>;
/// Builds a fresh, empty [`BatchStore`].
///
/// The returned handle is an `Arc`, so cloning it shares the same underlying map.
pub fn new_batch_store() -> BatchStore {
    let empty = HashMap::new();
    let locked = RwLock::new(empty);
    Arc::new(locked)
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn unix_now() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // Sub-second precision is discarded on purpose.
        Ok(elapsed) => elapsed.as_secs() as i64,
        // Clock is set before 1970-01-01; report the epoch itself.
        Err(_) => 0,
    }
}
/// Builds the `404 Not Found` error response for an unknown batch id.
///
/// The JSON body has a single `error` object carrying a human-readable
/// `message` (which embeds `id`), a `type`, and a `code`.
fn not_found(id: &str) -> (StatusCode, Json<Value>) {
    let message = format!("Batch '{id}' not found");
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
            "code": "batch_not_found"
        }
    });
    (StatusCode::NOT_FOUND, Json(payload))
}
use crate::state::AppState;
/// Handler: registers a new batch in the shared store and returns it as JSON.
///
/// The batch gets a generated id of the form `batch_<uuid>`, starts in
/// [`BatchStatus::InProgress`], and expires a fixed 24 hours after creation.
/// Only the *number* of submitted requests is recorded; their contents and the
/// `completion_window` field of the body are not used here.
///
/// # Errors
///
/// The signature allows an error response, but this handler currently always
/// returns `Ok`.
pub async fn create_batch(
    State(state): State<Arc<AppState>>,
    Json(body): Json<CreateBatchBody>,
) -> Result<Json<Batch>, (StatusCode, Json<Value>)> {
    /// Lifetime of a batch, in seconds (24 hours).
    const EXPIRY_WINDOW_SECS: i64 = 86_400;

    let id = format!("batch_{}", Uuid::new_v4().as_simple());
    // Saturate instead of using `as u32`, which would silently wrap the count
    // for lists longer than `u32::MAX`.
    let total = u32::try_from(body.requests.len()).unwrap_or(u32::MAX);
    let now = unix_now();

    let batch = Batch {
        id: id.clone(),
        object: "batch".into(),
        endpoint: body.endpoint,
        status: BatchStatus::InProgress,
        created_at: now,
        expires_at: Some(now + EXPIRY_WINDOW_SECS),
        completed_at: None,
        failed_at: None,
        request_counts: BatchRequestCounts {
            total,
            completed: 0,
            failed: 0,
        },
        metadata: body.metadata,
    };

    // The write guard is a temporary and is released at the end of this statement.
    state.batch_store.write().await.insert(id, batch.clone());
    Ok(Json(batch))
}
pub async fn get_batch(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> Result<Json<Batch>, (StatusCode, Json<Value>)> {
let guard = state.batch_store.read().await;
guard
.get(&id)
.cloned()
.map(Json)
.ok_or_else(|| not_found(&id))
}
/// Handler: returns every stored batch wrapped in a
/// `{"object": "list", "data": [...]}` envelope.
///
/// Batches are ordered newest first (`created_at` descending), with ties
/// broken by ascending `id`, so repeated calls return a stable order instead
/// of the arbitrary iteration order of the underlying `HashMap`.
pub async fn list_batches(State(state): State<Arc<AppState>>) -> Json<Value> {
    let guard = state.batch_store.read().await;
    let mut items: Vec<&Batch> = guard.values().collect();
    items.sort_unstable_by(|a, b| {
        b.created_at
            .cmp(&a.created_at)
            .then_with(|| a.id.cmp(&b.id))
    });
    Json(serde_json::json!({"object": "list", "data": items}))
}
/// Handler: cancels the batch identified by the `id` path segment and returns
/// its resulting state.
///
/// A batch that has not yet reached a final state is moved to
/// [`BatchStatus::Cancelled`]. A batch that is already `Completed`, `Failed`,
/// `Expired`, or `Cancelled` is returned unchanged, so a cancel request can
/// never overwrite a recorded outcome.
///
/// # Errors
///
/// Returns the `404` response built by `not_found` when no batch is stored
/// under `id`.
pub async fn cancel_batch(
    State(state): State<Arc<AppState>>,
    Path(id): Path<String>,
) -> Result<Json<Batch>, (StatusCode, Json<Value>)> {
    let mut guard = state.batch_store.write().await;
    match guard.get_mut(&id) {
        Some(batch) => {
            // Exhaustive on purpose: a new status variant must be classified here.
            let already_final = match batch.status {
                BatchStatus::Completed
                | BatchStatus::Failed
                | BatchStatus::Expired
                | BatchStatus::Cancelled => true,
                BatchStatus::Validating
                | BatchStatus::InProgress
                | BatchStatus::Finalizing
                | BatchStatus::Cancelling => false,
            };
            if !already_final {
                batch.status = BatchStatus::Cancelled;
            }
            Ok(Json(batch.clone()))
        }
        None => Err(not_found(&id)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `snake_case` rename must turn `InProgress` into `in_progress`.
    #[test]
    fn batch_status_serializes_snake_case() {
        let encoded = serde_json::to_string(&BatchStatus::InProgress).expect("serialize");
        assert_eq!(encoded, r#""in_progress""#);
    }

    /// The derived `Default` must zero every counter.
    #[test]
    fn batch_request_counts_default_is_zero() {
        let BatchRequestCounts {
            total,
            completed,
            failed,
        } = BatchRequestCounts::default();
        assert_eq!(total, 0);
        assert_eq!(completed, 0);
        assert_eq!(failed, 0);
    }

    /// A batch inserted directly into the store is visible to a later read.
    #[tokio::test]
    async fn create_and_retrieve_batch() {
        let key = "batch_test";
        let store = new_batch_store();
        let record = Batch {
            id: key.into(),
            object: "batch".into(),
            endpoint: "/v1/chat/completions".into(),
            status: BatchStatus::InProgress,
            created_at: 0,
            expires_at: None,
            completed_at: None,
            failed_at: None,
            request_counts: BatchRequestCounts {
                total: 2,
                ..BatchRequestCounts::default()
            },
            metadata: None,
        };
        store.write().await.insert(key.into(), record);
        let present = store.read().await.contains_key(key);
        assert!(present);
    }

    /// An empty store has no entry for an unknown id, and `not_found` maps
    /// that id to a 404 status.
    #[tokio::test]
    async fn cancel_nonexistent_batch_returns_not_found() {
        let missing = "no_such_id";
        let store = new_batch_store();
        {
            // Scope the read guard so it is released before the next step.
            let view = store.read().await;
            assert!(!view.contains_key(missing));
        }
        let (code, _body) = not_found(missing);
        assert_eq!(code, StatusCode::NOT_FOUND);
    }
}