1use std::sync::Arc;
17
18use axum::{
19 Json, Router,
20 extract::{FromRequest, Path, Query, Request, State, rejection::JsonRejection},
21 http::StatusCode,
22 middleware::{self, Next},
23 response::{IntoResponse, Response},
24 routing::{delete, get, post},
25};
26use serde::{Deserialize, de::DeserializeOwned};
27use solti_model::{TaskId, TaskPhase, TaskQuery};
28use tower_http::limit::RequestBodyLimitLayer;
29use tracing::debug;
30
31use crate::{
32 MAX_REQUEST_BYTES,
33 convert::{self, tasks_page_to_proto},
34 error::ApiError,
35 handler::ApiHandler,
36 proto_api,
37 validate::{clamp_list_limit, non_empty_id},
38};
39pub(crate) struct ApiJson<T>(pub T);
45
46impl<T, S> FromRequest<S> for ApiJson<T>
47where
48 T: DeserializeOwned,
49 S: Send + Sync,
50{
51 type Rejection = ApiError;
52
53 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
54 let Json(value) = axum::Json::<T>::from_request(req, state)
55 .await
56 .map_err(map_json_rejection)?;
57 Ok(ApiJson(value))
58 }
59}
60
61fn map_json_rejection(rej: JsonRejection) -> ApiError {
62 if rej.status() == StatusCode::PAYLOAD_TOO_LARGE {
63 return ApiError::PayloadTooLarge(format!(
64 "request body exceeds the maximum of {} bytes",
65 MAX_REQUEST_BYTES
66 ));
67 }
68
69 let msg = rej.body_text();
70 let trimmed = msg
71 .strip_prefix("Failed to deserialize the JSON body into the target type: ")
72 .or_else(|| msg.strip_prefix("Failed to parse the request body as JSON: "))
73 .unwrap_or(&msg)
74 .to_string();
75 ApiError::InvalidRequest(trimmed)
76}
77
78async fn map_413_envelope(req: Request, next: Next) -> Response {
79 let resp = next.run(req).await;
80 if resp.status() == StatusCode::PAYLOAD_TOO_LARGE {
81 let body = serde_json::json!({
82 "error": "PayloadTooLarge",
83 "message": format!(
84 "request body exceeds the maximum of {} bytes",
85 MAX_REQUEST_BYTES
86 ),
87 });
88 return (StatusCode::PAYLOAD_TOO_LARGE, Json(body)).into_response();
89 }
90 resp
91}
92
93pub struct HttpApi<H> {
100 handler: Arc<H>,
101}
102
103impl<H> HttpApi<H>
104where
105 H: ApiHandler,
106{
107 pub fn new(handler: Arc<H>) -> Self {
109 Self { handler }
110 }
111
112 pub fn router(self) -> Router {
116 Router::new()
117 .route(api_url!("/tasks"), post(submit_task::<H>))
118 .route(api_url!("/tasks"), get(list_tasks::<H>))
119 .route(api_url!("/tasks/{id}"), get(get_task_status::<H>))
120 .route(api_url!("/tasks/{id}"), delete(delete_task::<H>))
121 .route(api_url!("/tasks/{id}/runs"), get(list_task_runs::<H>))
122 .layer(RequestBodyLimitLayer::new(MAX_REQUEST_BYTES))
123 .layer(middleware::from_fn(map_413_envelope))
124 .with_state(self.handler)
125 }
126}
127
/// Query-string parameters accepted by the task listing endpoint.
///
/// Every field is optional; an absent key deserializes to `None`.
#[derive(Debug, Deserialize)]
struct ListTasksParams {
    /// Restrict results to this slot; rejected by `list_tasks` if empty.
    slot: Option<String>,
    /// Phase filter as text; `list_tasks` parses it into a `TaskPhase`.
    status: Option<String>,
    /// Page size; `list_tasks` passes it (or 0 when absent) through `clamp_list_limit`.
    limit: Option<u32>,
    /// Number of matching tasks to skip.
    offset: Option<u32>,
}
135
136async fn submit_task<H>(
137 State(handler): State<Arc<H>>,
138 ApiJson(req): ApiJson<proto_api::SubmitTaskRequest>,
139) -> Result<impl IntoResponse, ApiError>
140where
141 H: ApiHandler,
142{
143 let spec = req
144 .spec
145 .ok_or_else(|| ApiError::InvalidRequest("missing spec".into()))?;
146 let spec = convert::convert_create_spec(spec)?;
147
148 debug!(slot = %spec.slot(), kind = ?spec.kind(), "submitting task");
149 let task_id = handler.submit_task(spec).await?;
150
151 let response = proto_api::SubmitTaskResponse {
152 task_id: task_id.to_string(),
153 };
154 Ok((StatusCode::CREATED, Json(response)))
155}
156
157async fn get_task_status<H>(
158 State(handler): State<Arc<H>>,
159 Path(id): Path<String>,
160) -> Result<impl IntoResponse, ApiError>
161where
162 H: ApiHandler,
163{
164 non_empty_id("task_id", &id)?;
165
166 let task_id = TaskId::from(id);
167 debug!(%task_id, "getting task status");
168 let task = handler.get_task_status(&task_id).await?;
169
170 let task = task.map(proto_api::TaskData::try_from).transpose()?;
171 Ok(Json(proto_api::GetTaskStatusResponse { task }))
172}
173
174async fn list_tasks<H>(
175 State(handler): State<Arc<H>>,
176 Query(params): Query<ListTasksParams>,
177) -> Result<impl IntoResponse, ApiError>
178where
179 H: ApiHandler,
180{
181 let mut query = TaskQuery::new();
182
183 if let Some(slot) = params.slot {
184 non_empty_id("slot", &slot)?;
185 query = query.with_slot(slot);
186 }
187
188 if let Some(status_str) = params.status {
189 let status = status_str.parse::<TaskPhase>().map_err(|_| {
190 ApiError::InvalidRequest(format!(
191 "invalid status: '{status_str}' (valid: pending, running, succeeded, failed, timeout, canceled, exhausted)"
192 ))
193 })?;
194 query = query.with_status(status);
195 }
196
197 query = query.with_limit(clamp_list_limit(params.limit.unwrap_or(0)));
198 if let Some(offset) = params.offset {
199 query = query.with_offset(offset as usize);
200 }
201
202 let page = handler.query_tasks(query).await?;
203 debug!(count = page.items.len(), total = page.total, "tasks listed");
204
205 Ok(Json(tasks_page_to_proto(page)?))
206}
207
208async fn list_task_runs<H>(
209 State(handler): State<Arc<H>>,
210 Path(id): Path<String>,
211) -> Result<impl IntoResponse, ApiError>
212where
213 H: ApiHandler,
214{
215 non_empty_id("task_id", &id)?;
216
217 let task_id = TaskId::from(id);
218 debug!(%task_id, "listing task runs");
219 let runs = handler.list_task_runs(&task_id).await?;
220 let runs = runs.into_iter().map(proto_api::TaskRunInfo::from).collect();
221
222 Ok(Json(proto_api::ListTaskRunsResponse { runs }))
223}
224
225async fn delete_task<H>(
226 State(handler): State<Arc<H>>,
227 Path(id): Path<String>,
228) -> Result<impl IntoResponse, ApiError>
229where
230 H: ApiHandler,
231{
232 non_empty_id("task_id", &id)?;
233
234 let task_id = TaskId::from(id);
235 handler.delete_task(&task_id).await?;
236 debug!(%task_id, "task deleted");
237
238 Ok(StatusCode::NO_CONTENT)
239}