// axum_tasks/routes.rs

1use crate::{CachedJobResult, HasTasks, JobMetrics, TaskStatus};
2use axum::{
3    Router,
4    extract::{Path, Query, State},
5    http::StatusCode,
6    response::Json,
7    routing::{get, post},
8};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
12pub fn admin_routes<T>() -> Router<T>
13where
14    T: HasTasks + Clone + Send + Sync + 'static,
15{
16    Router::new()
17        .route("/tasks", get(list_tasks::<T>))
18        .route("/tasks/{task_id}", get(get_task_status::<T>))
19        .route("/tasks/{task_id}/metrics", get(get_job_metrics::<T>)) // NEW
20        .route("/tasks/{task_id}/result", get(get_job_result::<T>))
21        .route("/tasks/by-status/{status}", get(get_tasks_by_status::<T>))
22        .route("/cleanup", post(cleanup_old_tasks::<T>))
23        .route("/health", get(health_check::<T>))
24        .route("/metrics", get(get_metrics::<T>))
25}
26
27#[derive(Debug, Deserialize)]
28struct ListTasksQuery {
29    status: Option<String>,
30    limit: Option<usize>,
31}
32
33async fn list_tasks<T>(
34    State(state): State<T>,
35    Query(query): Query<ListTasksQuery>,
36) -> Result<Json<TaskListResponse>, ApiErrorResponse>
37where
38    T: HasTasks,
39{
40    let status = if let Some(status_str) = query.status {
41        Some(parse_task_status(&status_str)?)
42    } else {
43        None
44    };
45
46    let tasks = state.tasks().list_tasks(status.clone(), query.limit).await;
47
48    Ok(Json(TaskListResponse {
49        tasks: tasks.clone(),
50        total: tasks.len(),
51        filtered: status.is_some(),
52    }))
53}
54
55async fn get_task_status<T>(
56    State(state): State<T>,
57    Path(task_id): Path<String>,
58) -> Result<Json<crate::TaskState>, ApiErrorResponse>
59where
60    T: HasTasks,
61{
62    match state.tasks().get_task(&task_id).await {
63        Some(task) => Ok(Json(task)),
64        None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
65            .attach_printable("Task does not exist in the current state")
66            .into()),
67    }
68}
69
70async fn get_tasks_by_status<T>(
71    State(state): State<T>,
72    Path(status): Path<String>,
73) -> Result<Json<Vec<crate::TaskState>>, ApiErrorResponse>
74where
75    T: HasTasks,
76{
77    let status = parse_task_status(&status)?;
78    let tasks = state.tasks().get_tasks_by_status(status).await;
79    Ok(Json(tasks))
80}
81
82async fn get_job_metrics<T>(
83    State(state): State<T>,
84    Path(task_id): Path<String>,
85) -> Result<Json<JobMetrics>, ApiErrorResponse>
86where
87    T: HasTasks,
88{
89    match state.tasks().get_job_metrics(&task_id).await {
90        Some(metrics) => Ok(Json(metrics)),
91        None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
92            .attach_printable("Task does not exist")
93            .into()),
94    }
95}
96
97async fn get_job_result<T>(
98    State(state): State<T>,
99    Path(task_id): Path<String>,
100) -> Result<Json<CachedJobResult>, ApiErrorResponse>
101where
102    T: HasTasks,
103{
104    let status = state.tasks().get_status(&task_id).await.ok_or_else(|| {
105        error_stack::report!(ApiError::NotFound(task_id.clone()))
106            .attach_printable("Task does not exist")
107    })?;
108
109    match status {
110        TaskStatus::Completed | TaskStatus::Failed => {
111            match state.tasks().get_result(&task_id).await {
112                Some(result) => Ok(Json(result)),
113                None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
114                    .attach_printable("Result no longer in cache")
115                    .into()),
116            }
117        }
118        _ => Err(error_stack::report!(ApiError::BadRequest(task_id.clone()))
119            .attach_printable("Task not yet completed")
120            .into()),
121    }
122}
123
124#[derive(Debug, Deserialize)]
125struct CleanupRequest {
126    older_than_hours: Option<u64>,
127    older_than: Option<DateTime<Utc>>,
128}
129
130async fn cleanup_old_tasks<T>(
131    State(state): State<T>,
132    Json(request): Json<CleanupRequest>,
133) -> Result<Json<CleanupResponse>, ApiErrorResponse>
134where
135    T: HasTasks,
136{
137    let cutoff = if let Some(timestamp) = request.older_than {
138        timestamp
139    } else {
140        let hours = request.older_than_hours.unwrap_or(24); // Default 24 hours
141        Utc::now() - chrono::Duration::hours(hours as i64)
142    };
143
144    let removed = state.tasks().cleanup_old_tasks(cutoff).await;
145
146    Ok(Json(CleanupResponse {
147        removed_count: removed,
148        cutoff_time: cutoff,
149    }))
150}
151
152async fn health_check<T>(State(state): State<T>) -> Json<crate::types::HealthStatus>
153where
154    T: HasTasks,
155{
156    Json(state.tasks().health_status())
157}
158
159async fn get_metrics<T>(State(state): State<T>) -> Json<crate::metrics::MetricsSnapshot>
160where
161    T: HasTasks,
162{
163    Json(state.tasks().get_task_metrics())
164}
165
166#[derive(Debug, Serialize)]
167struct TaskListResponse {
168    tasks: Vec<crate::TaskState>,
169    total: usize,
170    filtered: bool,
171}
172
173#[derive(Debug, Serialize)]
174struct CleanupResponse {
175    removed_count: usize,
176    cutoff_time: DateTime<Utc>,
177}
178
/// Errors the admin API can report to a client.
///
/// Each variant carries a free-form string that is appended to a fixed,
/// variant-specific prefix when the error is displayed.
// NOTE(review): `Internal` is not constructed anywhere in this file, which is
// presumably why the lint is silenced — confirm before removing the allow.
#[allow(dead_code)]
#[derive(Debug)]
pub enum ApiError {
    /// The requested task (or its result) could not be found.
    NotFound(String),
    /// A status name supplied by the client was not recognised.
    InvalidStatus(String),
    /// The request was understood but cannot be served as sent.
    BadRequest(String),
    /// An unexpected server-side failure.
    Internal(String),
}

impl std::fmt::Display for ApiError {
    /// Writes the variant's fixed prefix followed by its payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(detail) => write!(f, "Task not found: {}", detail),
            Self::InvalidStatus(detail) => write!(f, "Invalid task status: {}", detail),
            Self::BadRequest(detail) => write!(f, "Bad request: {}", detail),
            Self::Internal(detail) => write!(f, "Internal server error: {}", detail),
        }
    }
}
198
199impl error_stack::Context for ApiError {}
200
201#[derive(Debug)]
202pub struct ApiErrorResponse(error_stack::Report<ApiError>);
203
204impl From<error_stack::Report<ApiError>> for ApiErrorResponse {
205    fn from(report: error_stack::Report<ApiError>) -> Self {
206        Self(report)
207    }
208}
209
210impl From<ApiError> for ApiErrorResponse {
211    fn from(error: ApiError) -> Self {
212        Self(error_stack::report!(error))
213    }
214}
215
216impl axum::response::IntoResponse for ApiErrorResponse {
217    fn into_response(self) -> axum::response::Response {
218        let context = self.0.current_context();
219
220        let (status, error_message) = match context {
221            ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()),
222            ApiError::InvalidStatus(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
223            ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
224            ApiError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
225        };
226
227        let main_error_msg = context.to_string();
228        let additional_context: Vec<String> = format!("{}", self.0)
229            .lines()
230            .map(str::trim)
231            .filter(|line| {
232                !line.is_empty()
233                    && !line.starts_with("at ")
234                    && *line != main_error_msg
235                    && !line.contains("src/")
236            })
237            .map(String::from)
238            .collect();
239
240        let body = if additional_context.is_empty() {
241            Json(serde_json::json!({"error": error_message}))
242        } else {
243            Json(serde_json::json!({
244                "error": error_message,
245                "context": additional_context
246            }))
247        };
248
249        (status, body).into_response()
250    }
251}
252
253fn parse_task_status(status_str: &str) -> Result<TaskStatus, ApiErrorResponse> {
254    match status_str.to_lowercase().as_str() {
255        "queued" => Ok(TaskStatus::Queued),
256        "in_progress" | "inprogress" | "running" => Ok(TaskStatus::InProgress),
257        "completed" | "success" | "done" => Ok(TaskStatus::Completed),
258        "failed" | "error" => Ok(TaskStatus::Failed),
259        "retrying" | "retry" => Ok(TaskStatus::Retrying),
260        _ => Err(
261            error_stack::report!(ApiError::InvalidStatus(status_str.to_string()))
262                .attach_printable(
263                    "Valid statuses: queued, in_progress, completed, failed, retrying",
264                )
265                .into(),
266        ),
267    }
268}