use crate::{CachedJobResult, HasTasks, JobMetrics, TaskStatus};
use axum::{
Router,
extract::{Path, Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Builds the admin API router over any shared state that exposes a task store via [`HasTasks`].
///
/// Routes:
/// - `GET  /tasks`                      — list tasks (optional `status` / `limit` query)
/// - `GET  /tasks/{task_id}`            — full state of one task
/// - `GET  /tasks/{task_id}/metrics`    — job metrics of one task
/// - `GET  /tasks/{task_id}/result`     — cached result of a finished task
/// - `POST /tasks/{task_id}/cancel`     — cancel a task
/// - `GET  /tasks/by-status/{status}`   — tasks in a given status
/// - `POST /cleanup`                    — remove old tasks
/// - `GET  /health`                     — health status of the task store
/// - `GET  /metrics`                    — aggregate task metrics snapshot
pub fn admin_routes<T>() -> Router<T>
where
    T: HasTasks + Clone + Send + Sync + 'static,
{
    // Task inspection and control endpoints.
    let task_router = Router::new()
        .route("/tasks", get(list_tasks::<T>))
        .route("/tasks/{task_id}", get(get_task_status::<T>))
        .route("/tasks/{task_id}/metrics", get(get_job_metrics::<T>))
        .route("/tasks/{task_id}/result", get(get_job_result::<T>))
        .route("/tasks/{task_id}/cancel", post(cancel_task::<T>))
        .route("/tasks/by-status/{status}", get(get_tasks_by_status::<T>));

    // Maintenance and observability endpoints.
    task_router
        .route("/cleanup", post(cleanup_old_tasks::<T>))
        .route("/health", get(health_check::<T>))
        .route("/metrics", get(get_metrics::<T>))
}
/// Query-string parameters accepted by `GET /tasks`.
#[derive(Deserialize, Debug)]
struct ListTasksQuery {
    /// Optional status filter. It is parsed by `parse_task_status`, so aliases such as
    /// `running` or `done` are accepted case-insensitively.
    status: Option<String>,
    /// Optional cap on the number of tasks, forwarded unchanged to the task store.
    limit: Option<usize>,
}
async fn list_tasks<T>(
State(state): State<T>,
Query(query): Query<ListTasksQuery>,
) -> Result<Json<TaskListResponse>, ApiErrorResponse>
where
T: HasTasks,
{
let status = if let Some(status_str) = query.status {
Some(parse_task_status(&status_str)?)
} else {
None
};
let tasks = state.tasks().list_tasks(status.clone(), query.limit).await;
Ok(Json(TaskListResponse {
tasks: tasks.clone(),
total: tasks.len(),
filtered: status.is_some(),
}))
}
/// `GET /tasks/{task_id}` — returns the full state of a single task.
///
/// # Errors
///
/// Responds with `ApiError::NotFound` (HTTP 404) when the task store has no such task.
async fn get_task_status<T>(
    State(state): State<T>,
    Path(task_id): Path<String>,
) -> Result<Json<crate::TaskState>, ApiErrorResponse>
where
    T: HasTasks,
{
    let task_state = state.tasks().get_task(&task_id).await.ok_or_else(|| {
        error_stack::report!(ApiError::NotFound(task_id.clone()))
            .attach_printable("Task does not exist in the current state")
    })?;
    Ok(Json(task_state))
}
/// `GET /tasks/by-status/{status}` — returns every task currently in the given status.
///
/// # Errors
///
/// Responds with `ApiError::InvalidStatus` when the path segment is not a recognised status.
async fn get_tasks_by_status<T>(
    State(state): State<T>,
    Path(status): Path<String>,
) -> Result<Json<Vec<crate::TaskState>>, ApiErrorResponse>
where
    T: HasTasks,
{
    let wanted = parse_task_status(&status)?;
    Ok(Json(state.tasks().get_tasks_by_status(wanted).await))
}
/// `GET /tasks/{task_id}/metrics` — returns the job metrics recorded for a task.
///
/// # Errors
///
/// Responds with `ApiError::NotFound` (HTTP 404) when the task store has no metrics for the id.
async fn get_job_metrics<T>(
    State(state): State<T>,
    Path(task_id): Path<String>,
) -> Result<Json<JobMetrics>, ApiErrorResponse>
where
    T: HasTasks,
{
    let job_metrics = state.tasks().get_job_metrics(&task_id).await.ok_or_else(|| {
        error_stack::report!(ApiError::NotFound(task_id.clone()))
            .attach_printable("Task does not exist")
    })?;
    Ok(Json(job_metrics))
}
/// `GET /tasks/{task_id}/result` — returns the cached result of a finished task.
///
/// A result is served only for tasks whose status is `Completed` or `Failed`.
///
/// # Errors
///
/// - `ApiError::NotFound` when the task is unknown, or when it finished but its result
///   is no longer in the cache.
/// - `ApiError::BadRequest` when the task was cancelled, or has not finished yet.
async fn get_job_result<T>(
    State(state): State<T>,
    Path(task_id): Path<String>,
) -> Result<Json<CachedJobResult>, ApiErrorResponse>
where
    T: HasTasks,
{
    let status = state.tasks().get_status(&task_id).await.ok_or_else(|| {
        error_stack::report!(ApiError::NotFound(task_id.clone()))
            .attach_printable("Task does not exist")
    })?;
    match status {
        TaskStatus::Completed | TaskStatus::Failed => {
            match state.tasks().get_result(&task_id).await {
                Some(result) => Ok(Json(result)),
                None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
                    .attach_printable("Result no longer in cache")
                    .into()),
            }
        }
        // A cancelled task is terminal and will never produce a result, so
        // "not yet completed" would mislead the caller into polling forever.
        TaskStatus::Cancelled => Err(error_stack::report!(ApiError::BadRequest(task_id.clone()))
            .attach_printable("Task was cancelled and will not produce a result")
            .into()),
        _ => Err(error_stack::report!(ApiError::BadRequest(task_id.clone()))
            .attach_printable("Task not yet completed")
            .into()),
    }
}
/// JSON body accepted by `POST /cleanup`.
///
/// When both fields are present, `older_than` wins. When neither is present, the handler
/// falls back to a 24-hour age threshold.
#[derive(Deserialize, Debug)]
struct CleanupRequest {
    /// Remove tasks older than this many hours before "now".
    older_than_hours: Option<u64>,
    /// Remove tasks older than this absolute UTC timestamp.
    older_than: Option<DateTime<Utc>>,
}
async fn cleanup_old_tasks<T>(
State(state): State<T>,
Json(request): Json<CleanupRequest>,
) -> Result<Json<CleanupResponse>, ApiErrorResponse>
where
T: HasTasks,
{
let cutoff = if let Some(timestamp) = request.older_than {
timestamp
} else {
let hours = request.older_than_hours.unwrap_or(24); Utc::now() - chrono::Duration::hours(hours as i64)
};
let removed = state.tasks().cleanup_old_tasks(cutoff).await;
Ok(Json(CleanupResponse {
removed_count: removed,
cutoff_time: cutoff,
}))
}
/// `GET /health` — reports the task store's health status.
async fn health_check<T>(State(state): State<T>) -> Json<crate::types::HealthStatus>
where
    T: HasTasks,
{
    let health = state.tasks().health_status();
    Json(health)
}
/// `GET /metrics` — returns a snapshot of the task store's aggregate metrics.
async fn get_metrics<T>(State(state): State<T>) -> Json<crate::metrics::MetricsSnapshot>
where
    T: HasTasks,
{
    let snapshot = state.tasks().get_task_metrics();
    Json(snapshot)
}
/// JSON body returned by `GET /tasks`.
#[derive(Serialize, Debug)]
struct TaskListResponse {
    /// The tasks returned by the task store.
    tasks: Vec<crate::TaskState>,
    /// Number of entries in `tasks`.
    total: usize,
    /// `true` when a status filter was applied to the listing.
    filtered: bool,
}
/// JSON body returned by `POST /cleanup`.
#[derive(Serialize, Debug)]
struct CleanupResponse {
    /// Count reported by the task store's `cleanup_old_tasks`.
    removed_count: usize,
    /// The cutoff timestamp the cleanup was run with.
    cutoff_time: DateTime<Utc>,
}
/// Error kinds produced by the admin API handlers.
///
/// Each variant carries a short detail string, typically the offending task id or input.
// `Internal` is not constructed anywhere in this file; the lint is silenced so the
// variant can stay available for future handlers.
#[allow(dead_code)]
#[derive(Debug)]
pub enum ApiError {
    /// The requested task (or its data) does not exist.
    NotFound(String),
    /// A status string could not be parsed into a known task status.
    InvalidStatus(String),
    /// The request is well-formed but cannot be honoured.
    BadRequest(String),
    /// An unexpected server-side failure.
    Internal(String),
}
impl std::fmt::Display for ApiError {
    /// Formats the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (category, detail) = match self {
            Self::NotFound(detail) => ("Task not found", detail),
            Self::InvalidStatus(detail) => ("Invalid task status", detail),
            Self::BadRequest(detail) => ("Bad request", detail),
            Self::Internal(detail) => ("Internal server error", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}
// Registers `ApiError` as an error-stack context so it can be the current context of a
// `Report<ApiError>` (as built by the handlers via `error_stack::report!`).
impl error_stack::Context for ApiError {}
#[derive(Debug)]
pub struct ApiErrorResponse(error_stack::Report<ApiError>);
impl From<error_stack::Report<ApiError>> for ApiErrorResponse {
fn from(report: error_stack::Report<ApiError>) -> Self {
Self(report)
}
}
impl From<ApiError> for ApiErrorResponse {
fn from(error: ApiError) -> Self {
Self(error_stack::report!(error))
}
}
impl axum::response::IntoResponse for ApiErrorResponse {
fn into_response(self) -> axum::response::Response {
let context = self.0.current_context();
let (status, error_message) = match context {
ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()),
ApiError::InvalidStatus(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
ApiError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
};
let main_error_msg = context.to_string();
let additional_context: Vec<String> = format!("{}", self.0)
.lines()
.map(str::trim)
.filter(|line| {
!line.is_empty()
&& !line.starts_with("at ")
&& *line != main_error_msg
&& !line.contains("src/")
})
.map(String::from)
.collect();
let body = if additional_context.is_empty() {
Json(serde_json::json!({"error": error_message}))
} else {
Json(serde_json::json!({
"error": error_message,
"context": additional_context
}))
};
(status, body).into_response()
}
}
/// `POST /tasks/{task_id}/cancel` — asks the task store to cancel a task.
///
/// On success, responds with `{"task_id": ..., "status": "cancelled"}`.
///
/// # Errors
///
/// Responds with `ApiError::BadRequest` when the store reports that the task could not
/// be cancelled.
async fn cancel_task<T>(
    State(state): State<T>,
    Path(task_id): Path<String>,
) -> Result<Json<serde_json::Value>, ApiErrorResponse>
where
    T: HasTasks,
{
    if !state.tasks().cancel_task(&task_id).await {
        return Err(error_stack::report!(ApiError::BadRequest(task_id.clone()))
            .attach_printable("Task not found or already in terminal state")
            .into());
    }
    let payload = serde_json::json!({
        "task_id": task_id,
        "status": "cancelled"
    });
    Ok(Json(payload))
}
fn parse_task_status(status_str: &str) -> Result<TaskStatus, ApiErrorResponse> {
match status_str.to_lowercase().as_str() {
"queued" => Ok(TaskStatus::Queued),
"in_progress" | "inprogress" | "running" => Ok(TaskStatus::InProgress),
"completed" | "success" | "done" => Ok(TaskStatus::Completed),
"failed" | "error" => Ok(TaskStatus::Failed),
"retrying" | "retry" => Ok(TaskStatus::Retrying),
"cancelled" | "canceled" => Ok(TaskStatus::Cancelled),
_ => Err(
error_stack::report!(ApiError::InvalidStatus(status_str.to_string()))
.attach_printable(
"Valid statuses: queued, in_progress, completed, failed, retrying, cancelled",
)
.into(),
),
}
}