Skip to main content

solti_api/
http.rs

1//! # HTTP/JSON transport.
2//!
3//! Axum router exposing [`ApiHandler`] operations as REST-shaped JSON endpoints.
//! All paths share the `/api/v<MAJOR>` prefix, where `MAJOR` is [`crate::API_VERSION`].
//!
//! _The table below shows the paths for the current value (`v1`)._
7//!
8//! | Method | Endpoint                    | Handler              |
9//! |--------|-----------------------------|----------------------|
10//! | POST   | `/api/v1/tasks`             | submit               |
11//! | GET    | `/api/v1/tasks`             | list (query params)  |
12//! | GET    | `/api/v1/tasks/{id}`        | get status           |
13//! | GET    | `/api/v1/tasks/{id}/runs`   | list runs            |
14//! | GET    | `/api/v1/tasks/{id}/logs`   | live-tail SSE stream |
15//! | DELETE | `/api/v1/tasks/{id}`        | delete (stop+purge)  |
16
17use std::sync::Arc;
18
19use std::convert::Infallible;
20
21use axum::{
22    Json, Router,
23    extract::{FromRequest, Path, Query, Request, State, rejection::JsonRejection},
24    http::StatusCode,
25    middleware::{self, Next},
26    response::{
27        IntoResponse, Response,
28        sse::{Event, KeepAlive, Sse},
29    },
30    routing::{delete, get, post},
31};
32use serde::{Deserialize, de::DeserializeOwned};
33use solti_model::{OutputEvent, TaskId, TaskPhase, TaskQuery};
34use tokio_stream::StreamExt;
35use tower_http::limit::RequestBodyLimitLayer;
36use tracing::debug;
37
38use crate::{
39    MAX_REQUEST_BYTES,
40    convert::{self, tasks_page_to_proto},
41    error::ApiError,
42    handler::ApiHandler,
43    proto_api,
44    validate::{clamp_list_limit, non_empty_id},
45};
46// `api_url!` is `#[macro_export]`, so it's already accessible in this
47// module by its bare name — `use crate::api_url` would be redundant
48// (and warnings about unused imports broke a `cargo publish` on us).
49
50/// Wrapper around `axum::Json<T>` that maps `JsonRejection` into [`ApiError::InvalidRequest`].
51pub(crate) struct ApiJson<T>(pub T);
52
53impl<T, S> FromRequest<S> for ApiJson<T>
54where
55    T: DeserializeOwned,
56    S: Send + Sync,
57{
58    type Rejection = ApiError;
59
60    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
61        let Json(value) = axum::Json::<T>::from_request(req, state)
62            .await
63            .map_err(map_json_rejection)?;
64        Ok(ApiJson(value))
65    }
66}
67
68fn map_json_rejection(rej: JsonRejection) -> ApiError {
69    if rej.status() == StatusCode::PAYLOAD_TOO_LARGE {
70        return ApiError::PayloadTooLarge(format!(
71            "request body exceeds the maximum of {} bytes",
72            MAX_REQUEST_BYTES
73        ));
74    }
75
76    let msg = rej.body_text();
77    let trimmed = msg
78        .strip_prefix("Failed to deserialize the JSON body into the target type: ")
79        .or_else(|| msg.strip_prefix("Failed to parse the request body as JSON: "))
80        .unwrap_or(&msg)
81        .to_string();
82    ApiError::InvalidRequest(trimmed)
83}
84
85async fn map_413_envelope(req: Request, next: Next) -> Response {
86    let resp = next.run(req).await;
87    if resp.status() == StatusCode::PAYLOAD_TOO_LARGE {
88        let body = serde_json::json!({
89            "error": "PayloadTooLarge",
90            "message": format!(
91                "request body exceeds the maximum of {} bytes",
92                MAX_REQUEST_BYTES
93            ),
94        });
95        return (StatusCode::PAYLOAD_TOO_LARGE, Json(body)).into_response();
96    }
97    resp
98}
99
100/// HTTP API service builder.
101///
102/// ## Also
103///
104/// - [`ApiHandler`](crate::ApiHandler) the trait backing all endpoints.
105/// - [`ApiError`](crate::ApiError) mapped to JSON + HTTP status codes.
106pub struct HttpApi<H> {
107    handler: Arc<H>,
108}
109
110impl<H> HttpApi<H>
111where
112    H: ApiHandler,
113{
114    /// Create new HTTP API with the given handler.
115    pub fn new(handler: Arc<H>) -> Self {
116        Self { handler }
117    }
118
119    /// Build axum router with mounted endpoints.
120    ///
121    /// Applies a [`RequestBodyLimitLayer`] capped at [`MAX_REQUEST_BYTES`] bytes to every request.
122    pub fn router(self) -> Router {
123        Router::new()
124            .route(api_url!("/tasks"), post(submit_task::<H>))
125            .route(api_url!("/tasks"), get(list_tasks::<H>))
126            .route(api_url!("/tasks/{id}"), get(get_task_status::<H>))
127            .route(api_url!("/tasks/{id}"), delete(delete_task::<H>))
128            .route(api_url!("/tasks/{id}/runs"), get(list_task_runs::<H>))
129            .route(api_url!("/tasks/{id}/logs"), get(stream_task_logs::<H>))
130            .layer(RequestBodyLimitLayer::new(MAX_REQUEST_BYTES))
131            .layer(middleware::from_fn(map_413_envelope))
132            .with_state(self.handler)
133    }
134}
135
136#[derive(Debug, Deserialize)]
137struct ListTasksParams {
138    slot: Option<String>,
139    status: Option<String>,
140    limit: Option<u32>,
141    offset: Option<u32>,
142}
143
144async fn submit_task<H>(
145    State(handler): State<Arc<H>>,
146    ApiJson(req): ApiJson<proto_api::SubmitTaskRequest>,
147) -> Result<impl IntoResponse, ApiError>
148where
149    H: ApiHandler,
150{
151    let spec = req
152        .spec
153        .ok_or_else(|| ApiError::InvalidRequest("missing spec".into()))?;
154    let spec = convert::convert_create_spec(spec)?;
155
156    debug!(slot = %spec.slot(), kind = ?spec.kind(), "submitting task");
157    let task_id = handler.submit_task(spec).await?;
158
159    let response = proto_api::SubmitTaskResponse {
160        task_id: task_id.to_string(),
161    };
162    Ok((StatusCode::CREATED, Json(response)))
163}
164
165async fn get_task_status<H>(
166    State(handler): State<Arc<H>>,
167    Path(id): Path<String>,
168) -> Result<impl IntoResponse, ApiError>
169where
170    H: ApiHandler,
171{
172    non_empty_id("task_id", &id)?;
173
174    let task_id = TaskId::from(id);
175    debug!(%task_id, "getting task status");
176    let task = handler.get_task_status(&task_id).await?;
177
178    let task = task.map(proto_api::TaskData::try_from).transpose()?;
179    Ok(Json(proto_api::GetTaskStatusResponse { task }))
180}
181
182async fn list_tasks<H>(
183    State(handler): State<Arc<H>>,
184    Query(params): Query<ListTasksParams>,
185) -> Result<impl IntoResponse, ApiError>
186where
187    H: ApiHandler,
188{
189    let mut query = TaskQuery::new();
190
191    if let Some(slot) = params.slot {
192        non_empty_id("slot", &slot)?;
193        query = query.with_slot(slot);
194    }
195
196    if let Some(status_str) = params.status {
197        let status = status_str.parse::<TaskPhase>().map_err(|_| {
198            ApiError::InvalidRequest(format!(
199                "invalid status: '{status_str}' (valid: pending, running, succeeded, failed, timeout, canceled, exhausted)"
200            ))
201        })?;
202        query = query.with_status(status);
203    }
204
205    query = query.with_limit(clamp_list_limit(params.limit.unwrap_or(0)));
206    if let Some(offset) = params.offset {
207        query = query.with_offset(offset as usize);
208    }
209
210    let page = handler.query_tasks(query).await?;
211    debug!(count = page.items.len(), total = page.total, "tasks listed");
212
213    Ok(Json(tasks_page_to_proto(page)?))
214}
215
216async fn list_task_runs<H>(
217    State(handler): State<Arc<H>>,
218    Path(id): Path<String>,
219) -> Result<impl IntoResponse, ApiError>
220where
221    H: ApiHandler,
222{
223    non_empty_id("task_id", &id)?;
224
225    let task_id = TaskId::from(id);
226    debug!(%task_id, "listing task runs");
227    let runs = handler.list_task_runs(&task_id).await?;
228    let runs = runs.into_iter().map(proto_api::TaskRunInfo::from).collect();
229
230    Ok(Json(proto_api::ListTaskRunsResponse { runs }))
231}
232
233async fn delete_task<H>(
234    State(handler): State<Arc<H>>,
235    Path(id): Path<String>,
236) -> Result<impl IntoResponse, ApiError>
237where
238    H: ApiHandler,
239{
240    non_empty_id("task_id", &id)?;
241
242    let task_id = TaskId::from(id);
243    handler.delete_task(&task_id).await?;
244    debug!(%task_id, "task deleted");
245
246    Ok(StatusCode::NO_CONTENT)
247}
248
249/// `GET /tasks/{id}/logs` - Server-Sent Events stream of [`OutputEvent`]s (live tail of stdout/stderr + run boundary markers + lag signals).
250async fn stream_task_logs<H>(
251    State(handler): State<Arc<H>>,
252    Path(id): Path<String>,
253) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, ApiError>
254where
255    H: ApiHandler,
256{
257    non_empty_id("task_id", &id)?;
258
259    let task_id = TaskId::from(id);
260    debug!(%task_id, "subscribing to task log stream");
261    let stream = handler.stream_task_logs(&task_id).await?;
262
263    let sse_stream = stream.map(|ev| {
264        let name = match &ev {
265            OutputEvent::Chunk(_) => "chunk",
266            OutputEvent::RunStarted { .. } => "run-started",
267            OutputEvent::RunFinished { .. } => "run-finished",
268            OutputEvent::Lagged { .. } => "lagged",
269        };
270        let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".into());
271        Ok(Event::default().event(name).data(data))
272    });
273    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
274}