1use crate::{CachedJobResult, HasTasks, JobMetrics, TaskStatus};
2use axum::{
3 Router,
4 extract::{Path, Query, State},
5 http::StatusCode,
6 response::Json,
7 routing::{get, post},
8};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
12pub fn admin_routes<T>() -> Router<T>
13where
14 T: HasTasks + Clone + Send + Sync + 'static,
15{
16 Router::new()
17 .route("/tasks", get(list_tasks::<T>))
18 .route("/tasks/{task_id}", get(get_task_status::<T>))
19 .route("/tasks/{task_id}/metrics", get(get_job_metrics::<T>)) .route("/tasks/{task_id}/result", get(get_job_result::<T>))
21 .route("/tasks/by-status/{status}", get(get_tasks_by_status::<T>))
22 .route("/cleanup", post(cleanup_old_tasks::<T>))
23 .route("/health", get(health_check::<T>))
24 .route("/metrics", get(get_metrics::<T>))
25}
26
27#[derive(Debug, Deserialize)]
28struct ListTasksQuery {
29 status: Option<String>,
30 limit: Option<usize>,
31}
32
33async fn list_tasks<T>(
34 State(state): State<T>,
35 Query(query): Query<ListTasksQuery>,
36) -> Result<Json<TaskListResponse>, ApiErrorResponse>
37where
38 T: HasTasks,
39{
40 let status = if let Some(status_str) = query.status {
41 Some(parse_task_status(&status_str)?)
42 } else {
43 None
44 };
45
46 let tasks = state.tasks().list_tasks(status.clone(), query.limit).await;
47
48 Ok(Json(TaskListResponse {
49 tasks: tasks.clone(),
50 total: tasks.len(),
51 filtered: status.is_some(),
52 }))
53}
54
55async fn get_task_status<T>(
56 State(state): State<T>,
57 Path(task_id): Path<String>,
58) -> Result<Json<crate::TaskState>, ApiErrorResponse>
59where
60 T: HasTasks,
61{
62 match state.tasks().get_task(&task_id).await {
63 Some(task) => Ok(Json(task)),
64 None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
65 .attach_printable("Task does not exist in the current state")
66 .into()),
67 }
68}
69
70async fn get_tasks_by_status<T>(
71 State(state): State<T>,
72 Path(status): Path<String>,
73) -> Result<Json<Vec<crate::TaskState>>, ApiErrorResponse>
74where
75 T: HasTasks,
76{
77 let status = parse_task_status(&status)?;
78 let tasks = state.tasks().get_tasks_by_status(status).await;
79 Ok(Json(tasks))
80}
81
82async fn get_job_metrics<T>(
83 State(state): State<T>,
84 Path(task_id): Path<String>,
85) -> Result<Json<JobMetrics>, ApiErrorResponse>
86where
87 T: HasTasks,
88{
89 match state.tasks().get_job_metrics(&task_id).await {
90 Some(metrics) => Ok(Json(metrics)),
91 None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
92 .attach_printable("Task does not exist")
93 .into()),
94 }
95}
96
97async fn get_job_result<T>(
98 State(state): State<T>,
99 Path(task_id): Path<String>,
100) -> Result<Json<CachedJobResult>, ApiErrorResponse>
101where
102 T: HasTasks,
103{
104 let status = state.tasks().get_status(&task_id).await.ok_or_else(|| {
105 error_stack::report!(ApiError::NotFound(task_id.clone()))
106 .attach_printable("Task does not exist")
107 })?;
108
109 match status {
110 TaskStatus::Completed | TaskStatus::Failed => {
111 match state.tasks().get_result(&task_id).await {
112 Some(result) => Ok(Json(result)),
113 None => Err(error_stack::report!(ApiError::NotFound(task_id.clone()))
114 .attach_printable("Result no longer in cache")
115 .into()),
116 }
117 }
118 _ => Err(error_stack::report!(ApiError::BadRequest(task_id.clone()))
119 .attach_printable("Task not yet completed")
120 .into()),
121 }
122}
123
124#[derive(Debug, Deserialize)]
125struct CleanupRequest {
126 older_than_hours: Option<u64>,
127 older_than: Option<DateTime<Utc>>,
128}
129
130async fn cleanup_old_tasks<T>(
131 State(state): State<T>,
132 Json(request): Json<CleanupRequest>,
133) -> Result<Json<CleanupResponse>, ApiErrorResponse>
134where
135 T: HasTasks,
136{
137 let cutoff = if let Some(timestamp) = request.older_than {
138 timestamp
139 } else {
140 let hours = request.older_than_hours.unwrap_or(24); Utc::now() - chrono::Duration::hours(hours as i64)
142 };
143
144 let removed = state.tasks().cleanup_old_tasks(cutoff).await;
145
146 Ok(Json(CleanupResponse {
147 removed_count: removed,
148 cutoff_time: cutoff,
149 }))
150}
151
152async fn health_check<T>(State(state): State<T>) -> Json<crate::types::HealthStatus>
153where
154 T: HasTasks,
155{
156 Json(state.tasks().health_status())
157}
158
159async fn get_metrics<T>(State(state): State<T>) -> Json<crate::metrics::MetricsSnapshot>
160where
161 T: HasTasks,
162{
163 Json(state.tasks().get_task_metrics())
164}
165
166#[derive(Debug, Serialize)]
167struct TaskListResponse {
168 tasks: Vec<crate::TaskState>,
169 total: usize,
170 filtered: bool,
171}
172
173#[derive(Debug, Serialize)]
174struct CleanupResponse {
175 removed_count: usize,
176 cutoff_time: DateTime<Utc>,
177}
178
/// Error categories reported by the admin API handlers.
///
/// Each variant carries a free-form `String` payload that is appended to the
/// variant's fixed prefix when the error is displayed.
// `dead_code` is allowed because not every variant is constructed by the
// handlers visible in this file (e.g. `Internal`).
#[allow(dead_code)]
#[derive(Debug)]
pub enum ApiError {
    /// The requested task (or one of its artefacts) could not be found.
    NotFound(String),
    /// A status name supplied by the client was not recognised.
    InvalidStatus(String),
    /// The request was well-formed but cannot be served as asked.
    BadRequest(String),
    /// An unexpected server-side failure.
    Internal(String),
}

impl std::fmt::Display for ApiError {
    /// Writes `"<prefix>: <payload>"`, where the prefix depends on the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::NotFound(detail) => ("Task not found", detail),
            Self::InvalidStatus(detail) => ("Invalid task status", detail),
            Self::BadRequest(detail) => ("Bad request", detail),
            Self::Internal(detail) => ("Internal server error", detail),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}
198
199impl error_stack::Context for ApiError {}
200
201#[derive(Debug)]
202pub struct ApiErrorResponse(error_stack::Report<ApiError>);
203
204impl From<error_stack::Report<ApiError>> for ApiErrorResponse {
205 fn from(report: error_stack::Report<ApiError>) -> Self {
206 Self(report)
207 }
208}
209
210impl From<ApiError> for ApiErrorResponse {
211 fn from(error: ApiError) -> Self {
212 Self(error_stack::report!(error))
213 }
214}
215
216impl axum::response::IntoResponse for ApiErrorResponse {
217 fn into_response(self) -> axum::response::Response {
218 let context = self.0.current_context();
219
220 let (status, error_message) = match context {
221 ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()),
222 ApiError::InvalidStatus(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
223 ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
224 ApiError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
225 };
226
227 let main_error_msg = context.to_string();
228 let additional_context: Vec<String> = format!("{}", self.0)
229 .lines()
230 .map(str::trim)
231 .filter(|line| {
232 !line.is_empty()
233 && !line.starts_with("at ")
234 && *line != main_error_msg
235 && !line.contains("src/")
236 })
237 .map(String::from)
238 .collect();
239
240 let body = if additional_context.is_empty() {
241 Json(serde_json::json!({"error": error_message}))
242 } else {
243 Json(serde_json::json!({
244 "error": error_message,
245 "context": additional_context
246 }))
247 };
248
249 (status, body).into_response()
250 }
251}
252
253fn parse_task_status(status_str: &str) -> Result<TaskStatus, ApiErrorResponse> {
254 match status_str.to_lowercase().as_str() {
255 "queued" => Ok(TaskStatus::Queued),
256 "in_progress" | "inprogress" | "running" => Ok(TaskStatus::InProgress),
257 "completed" | "success" | "done" => Ok(TaskStatus::Completed),
258 "failed" | "error" => Ok(TaskStatus::Failed),
259 "retrying" | "retry" => Ok(TaskStatus::Retrying),
260 _ => Err(
261 error_stack::report!(ApiError::InvalidStatus(status_str.to_string()))
262 .attach_printable(
263 "Valid statuses: queued, in_progress, completed, failed, retrying",
264 )
265 .into(),
266 ),
267 }
268}