use std::sync::Arc;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::fs::File as TokioFile;
use uuid::Uuid;
use crate::batch_spool::queue::{BatchQueueSender, BatchWorkItem};
use crate::batch_spool::store::{BatchJobMeta, BatchJobStatus, BatchStore};
use crate::state::AppState;
/// Upper bound on the number of non-blank JSONL lines `create_batch` accepts in one request.
pub const MAX_BATCH_LINES: usize = 50_000;
/// Upper bound, in bytes, on the inline JSONL input `create_batch` accepts (10 MiB).
pub const MAX_BATCH_BYTES: usize = 10 * 1024 * 1024;
/// JSON request body deserialized by `create_batch`.
#[derive(Debug, Deserialize)]
pub struct CreateBatchBody {
/// Inline JSONL payload; an empty string when the field is absent.
#[serde(default)]
pub input_jsonl: String,
/// Identifier of an input file; an empty string when the field is absent.
/// NOTE(review): nothing in this file resolves this id to file contents — confirm intended use.
#[serde(default)]
pub input_file_id: String,
/// Target endpoint for the batch (required); stored with the job when it is created.
pub endpoint: String,
/// Optional completion window.
/// NOTE(review): accepted on input but not read anywhere in this file.
#[serde(default)]
pub completion_window: Option<String>,
/// Optional free-form metadata.
/// NOTE(review): accepted on input but not read anywhere in this file.
#[serde(default)]
pub metadata: Option<Value>,
}
/// Query-string parameters accepted by `list_batches`.
///
/// `Default` is implemented by hand so that `ListBatchesQuery::default()` yields the
/// same page size as a request that omits `limit` (a derived `Default` would give `0`).
#[derive(Debug, Deserialize)]
pub struct ListBatchesQuery {
    /// Maximum number of batches to return; falls back to [`default_limit`] when omitted.
    #[serde(default = "default_limit")]
    pub limit: usize,
    /// Pagination cursor: when set, `list_batches` skips every job up to and including
    /// the one whose id equals this value.
    pub after: Option<String>,
}

impl Default for ListBatchesQuery {
    fn default() -> Self {
        Self {
            limit: default_limit(),
            after: None,
        }
    }
}

/// Page size used when the `limit` query parameter is not supplied.
fn default_limit() -> usize {
    20
}
/// Serialized representation of a batch job returned by the handlers in this file.
///
/// Instances are built from a `BatchJobMeta` by `meta_to_object`.
#[derive(Debug, Serialize)]
pub struct BatchObject {
/// Job identifier, copied from the stored metadata.
pub id: String,
/// Object type tag; `meta_to_object` always sets this to `"batch"`.
pub object: String,
/// Endpoint recorded for the job.
pub endpoint: String,
/// Status string derived from the stored job status (e.g. `"validating"`, `"in_progress"`).
pub status: String,
/// Creation timestamp copied from the stored metadata (presumably Unix seconds — confirm).
pub created_at: i64,
/// Expiry timestamp; `meta_to_object` sets it to `created_at + 86400`.
pub expires_at: Option<i64>,
/// Set to the job's `updated_at` once the status is completed, failed, or cancelled.
pub completed_at: Option<i64>,
/// Set to the job's `updated_at` only when the status is failed.
pub failed_at: Option<i64>,
/// Per-line progress counters.
pub request_counts: BatchRequestCounts,
/// Caller metadata; `meta_to_object` currently always sets this to `None`.
pub metadata: Option<Value>,
}
/// Line counters reported inside a [`BatchObject`]; all default to zero.
#[derive(Debug, Serialize, Default)]
pub struct BatchRequestCounts {
/// Total number of lines in the job (from `total_lines` of the stored metadata).
pub total: u32,
/// Number of lines recorded as completed (from `completed_lines`).
pub completed: u32,
/// Number of lines recorded as failed (from `failed_lines`).
pub failed: u32,
}
/// Bundles a shared batch store handle with the sender side of the batch work queue.
///
/// NOTE(review): no handler in this file uses this type; they read the store and queue
/// sender from `AppState` instead — confirm whether this struct is still needed.
pub struct BatchRouterState {
/// Shared handle to the batch job store.
pub store: Arc<BatchStore>,
/// Sender used to hand work items to the batch queue.
pub queue_tx: BatchQueueSender,
}
/// Converts a stored job record into the [`BatchObject`] returned by the API.
///
/// `Pending` is reported as `"validating"`; the other statuses use their snake_case name.
/// `completed_at` is filled from `updated_at` for every terminal status (completed,
/// failed, cancelled), `failed_at` only for failed jobs, and `expires_at` is always
/// `created_at + 86400`. `metadata` is always `None`.
fn meta_to_object(meta: &BatchJobMeta) -> BatchObject {
    // One exhaustive match yields the label plus the two timestamp flags.
    let (label, reached_end, has_failed) = match meta.status {
        BatchJobStatus::Pending => ("validating", false, false),
        BatchJobStatus::InProgress => ("in_progress", false, false),
        BatchJobStatus::Completed => ("completed", true, false),
        BatchJobStatus::Failed => ("failed", true, true),
        BatchJobStatus::Cancelled => ("cancelled", true, false),
    };

    let counts = BatchRequestCounts {
        total: meta.total_lines,
        completed: meta.completed_lines,
        failed: meta.failed_lines,
    };

    BatchObject {
        id: meta.id.clone(),
        object: String::from("batch"),
        endpoint: meta.endpoint.clone(),
        status: String::from(label),
        created_at: meta.created_at,
        expires_at: Some(meta.created_at + 86400),
        completed_at: reached_end.then_some(meta.updated_at),
        failed_at: has_failed.then_some(meta.updated_at),
        request_counts: counts,
        metadata: None,
    }
}
/// Builds the `404 Not Found` JSON error response for an unknown batch id.
///
/// The body carries an `error` object with a message naming `id`, the type
/// `invalid_request_error`, and the code `batch_not_found`.
fn not_found(id: &str) -> Response {
    let message = format!("Batch '{id}' not found");
    let payload = Json(serde_json::json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
            "code": "batch_not_found"
        }
    }));
    (StatusCode::NOT_FOUND, payload).into_response()
}
pub async fn create_batch(
State(state): State<Arc<AppState>>,
Json(body): Json<CreateBatchBody>,
) -> Response {
let input_jsonl = if !body.input_jsonl.is_empty() {
body.input_jsonl.clone()
} else {
String::new()
};
if input_jsonl.len() > MAX_BATCH_BYTES {
let err = serde_json::json!({
"error": {
"message": format!("input exceeds max size of {} bytes", MAX_BATCH_BYTES),
"type": "invalid_request_error",
}
});
return (StatusCode::PAYLOAD_TOO_LARGE, Json(err)).into_response();
}
let line_count = input_jsonl.lines().filter(|l| !l.trim().is_empty()).count();
if line_count > MAX_BATCH_LINES {
let err = serde_json::json!({
"error": {
"message": format!("input exceeds max {} lines", MAX_BATCH_LINES),
"type": "invalid_request_error",
}
});
return (StatusCode::PAYLOAD_TOO_LARGE, Json(err)).into_response();
}
let job_id = format!("batch_{}", Uuid::new_v4().as_simple());
let store = Arc::clone(&state.batch_disk_store);
let job_id_c = job_id.clone();
let endpoint = body.endpoint.clone();
let jsonl = input_jsonl.clone();
let total = line_count as u32;
let meta_result =
tokio::task::spawn_blocking(move || store.create_job(&job_id_c, &jsonl, &endpoint, total))
.await;
match meta_result {
Ok(Ok(meta)) => {
if let Err(e) = state
.batch_queue_tx
.send(BatchWorkItem {
job_id: job_id.clone(),
})
.await
{
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("failed to enqueue batch job: {e}"),
)
} else {
(StatusCode::ACCEPTED, Json(meta_to_object(&meta))).into_response()
}
}
Ok(Err(e)) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("failed to create batch job on disk: {e}"),
),
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("task join error: {e}"),
),
}
}
pub async fn get_batch(State(state): State<Arc<AppState>>, Path(id): Path<String>) -> Response {
let store = Arc::clone(&state.batch_disk_store);
let id_c = id.clone();
match tokio::task::spawn_blocking(move || store.read_status(&id_c)).await {
Ok(Ok(meta)) => (StatusCode::OK, Json(meta_to_object(&meta))).into_response(),
Ok(Err(_)) => not_found(&id),
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("task join error: {e}"),
),
}
}
pub async fn list_batches(
State(state): State<Arc<AppState>>,
Query(query): Query<ListBatchesQuery>,
) -> Response {
let store = Arc::clone(&state.batch_disk_store);
let ids_result = tokio::task::spawn_blocking(move || store.list_jobs()).await;
let ids = match ids_result {
Ok(Ok(ids)) => ids,
Ok(Err(e)) => {
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("list_jobs error: {e}"),
);
}
Err(e) => {
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("task join: {e}"),
);
}
};
let mut skip = query.after.is_some();
let mut objects = Vec::new();
for id in &ids {
if skip {
if query.after.as_deref() == Some(id.as_str()) {
skip = false;
}
continue;
}
if objects.len() >= query.limit {
break;
}
let store2 = Arc::clone(&state.batch_disk_store);
let id2 = id.clone();
if let Ok(Ok(meta)) = tokio::task::spawn_blocking(move || store2.read_status(&id2)).await {
objects.push(meta_to_object(&meta));
}
}
let body = serde_json::json!({
"object": "list",
"data": objects,
});
(StatusCode::OK, Json(body)).into_response()
}
pub async fn get_batch_output(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> Response {
let output_path = state.batch_disk_store.job_dir(&id).join("output.jsonl");
match TokioFile::open(&output_path).await {
Ok(file) => {
use axum::http::header;
use tokio_util::io::ReaderStream;
let stream = ReaderStream::new(file);
let body = Body::from_stream(stream);
(
StatusCode::OK,
[(
header::CONTENT_TYPE,
"application/jsonl"
.parse::<axum::http::HeaderValue>()
.unwrap_or_else(|_| "text/plain".parse().expect("header")),
)],
body,
)
.into_response()
}
Err(_) => not_found(&id),
}
}
pub async fn cancel_batch(State(state): State<Arc<AppState>>, Path(id): Path<String>) -> Response {
let store = Arc::clone(&state.batch_disk_store);
let id_c = id.clone();
let result = tokio::task::spawn_blocking(move || -> std::io::Result<BatchJobMeta> {
let mut meta = store.read_status(&id_c)?;
meta.cancel_requested = true;
meta.updated_at = unix_now();
store.update_status(&id_c, &meta)?;
Ok(meta)
})
.await;
match result {
Ok(Ok(meta)) => (StatusCode::OK, Json(meta_to_object(&meta))).into_response(),
Ok(Err(_)) => not_found(&id),
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("task join: {e}"),
),
}
}
/// Builds a JSON error response with the given status.
///
/// The body is `{"error": {"message": <message>, "type": "server_error"}}`.
fn error_response(status: StatusCode, message: &str) -> Response {
    let payload = Json(serde_json::json!({
        "error": {
            "message": message,
            "type": "server_error",
        }
    }));
    (status, payload).into_response()
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}