use std::time::{Duration, Instant};
use axum::{
body::Body,
extract::{Path, Query, State},
http::{header, StatusCode},
middleware,
response::{IntoResponse, Response},
routing::{get, post},
Extension, Json, Router,
};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use serde::Deserialize;
use serde_json::{json, Value};
use dragoon_proto::models::{LogStream, Task, TaskSubmit};
use crate::{
app::{signed_request, username_for, AppState, SignedSession},
audit, messages_repo, storage, tasks_repo, users_repo, workers_repo,
};
/// Builds the task API router.
///
/// Every route registered here is wrapped by the `signed_request` middleware, which
/// runs with access to the shared `AppState`; the handlers then receive that same
/// state through `with_state`.
pub fn router(state: AppState) -> Router {
    let require_signature = middleware::from_fn_with_state(state.clone(), signed_request);

    Router::new()
        .route("/v1/tasks", post(submit))
        .route("/v1/tasks/:task_id", get(get_task))
        .route("/v1/tasks/:task_id/cancel", post(cancel))
        .route("/v1/tasks/:task_id/log", get(tail_log))
        .route("/v1/tasks/:task_id/artifact", get(download_artifact))
        .layer(require_signature)
        .with_state(state)
}
/// Creates a task for `req.worker_name` on behalf of the signed-in user.
///
/// # Errors
/// * `404 NOT_FOUND` – `workers_repo` has no worker with that name.
/// * `500 INTERNAL_SERVER_ERROR` – the worker lookup or the task insert failed.
///
/// The submission is also written to the audit log; that write is best effort and its
/// result is ignored.
async fn submit(
    State(state): State<AppState>,
    Extension(sess): Extension<SignedSession>,
    Json(req): Json<TaskSubmit>,
) -> Result<Json<Task>, StatusCode> {
    let db = state.conn.lock().unwrap();

    // Reject submissions that target a worker name the repository doesn't know.
    let worker_known = workers_repo::lookup_by_name(&db, &req.worker_name)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .is_some();
    if !worker_known {
        return Err(StatusCode::NOT_FOUND);
    }

    let task_id = tasks_repo::new_task_id();
    let submitter = username_for(&db, sess.0.user_id);

    let created = tasks_repo::insert_task(
        &db,
        &task_id,
        &req.worker_name,
        &submitter,
        req.kind,
        &req.payload,
        &req.collect,
        &req.limits,
        req.fetch_path.as_deref(),
    )
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let details = json!({"worker": req.worker_name, "kind": req.kind.as_wire_str()});
    let _ = audit::log(
        &db,
        Some(&submitter),
        "submit",
        Some(&task_id),
        Some(&sess.0.fingerprint),
        &details,
    );

    Ok(Json(created))
}
/// Returns a single task by id.
///
/// After a successful lookup, `tasks_repo::touch_access` is called for the task; its
/// result is ignored.
///
/// # Errors
/// * `404 NOT_FOUND` – no task with that id.
/// * `500 INTERNAL_SERVER_ERROR` – the lookup failed.
async fn get_task(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Extension(_sess): Extension<SignedSession>,
) -> Result<Json<Task>, StatusCode> {
    let db = state.conn.lock().unwrap();

    let task = match tasks_repo::get_task(&db, &task_id) {
        Ok(Some(found)) => found,
        Ok(None) => return Err(StatusCode::NOT_FOUND),
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    // Best effort: a failure here must not turn a successful read into an error.
    let _ = tasks_repo::touch_access(&db, &task_id);

    Ok(Json(task))
}
async fn cancel(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(sess): Extension<SignedSession>,
) -> Result<Json<Task>, StatusCode> {
let conn = state.conn.lock().unwrap();
let cur = tasks_repo::get_task(&conn, &task_id)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let new_t = tasks_repo::request_cancel(&conn, &task_id)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let username = username_for(&conn, sess.0.user_id);
let _ = audit::log(
&conn,
Some(&username),
"cancel",
Some(&task_id),
Some(&sess.0.fingerprint),
&json!({}),
);
if tasks_repo::is_terminal(new_t.state) {
if let Some(submitter_id) = users_repo::lookup_user_id_by_name(&conn, &cur.submitter)
.ok()
.flatten()
{
let payload = json!({
"state": match new_t.state {
dragoon_proto::models::TaskState::Cancelled => "CANCELLED",
dragoon_proto::models::TaskState::Completed => "COMPLETED",
dragoon_proto::models::TaskState::Failed => "FAILED",
dragoon_proto::models::TaskState::Timeout => "TIMEOUT",
_ => "UNKNOWN",
},
"task_payload": new_t.payload,
"worker_seq": new_t.worker_seq,
});
let _ = messages_repo::enqueue(
&conn,
submitter_id,
"task_state",
Some(&task_id),
Some(&new_t.worker_name),
new_t.error.as_deref(),
&payload,
);
}
}
Ok(Json(new_t))
}
/// Query string accepted by `GET /v1/tasks/:task_id/log` (handled by `tail_log`).
#[derive(Debug, Deserialize)]
struct LogQuery {
    /// Read cursor to resume from; `0` when the parameter is omitted.
    #[serde(default)]
    since_seq: u64,
    /// Stream name; `tail_log` accepts only `"stdout"` or `"stderr"` (anything else is
    /// answered with `422`). Falls back to `default_stream()` when omitted.
    #[serde(default = "default_stream")]
    stream: String,
}
/// Serde default for `LogQuery::stream`: the stdout stream.
fn default_stream() -> String {
    String::from("stdout")
}
/// Long-polls one log stream of a task and returns the next chunk, base64-encoded.
///
/// `q.since_seq` is the read cursor and `q.stream` picks `"stdout"` or `"stderr"`.
/// The log is re-read every `log_long_poll_step_sec` seconds and the handler returns as
/// soon as one of these holds:
/// * new bytes are available;
/// * the task is terminal (no more output can arrive);
/// * `log_long_poll_sec` has elapsed.
///
/// The response body is `{"stream", "seq", "data_b64", "eof"}`. `seq` is the cursor
/// to send as `since_seq` on the next call. `eof` is the `is_log_eof` verdict sampled
/// just before the read.
///
/// # Errors
/// * `422 UNPROCESSABLE_ENTITY` – unknown `stream` value.
/// * `404 NOT_FOUND` – no task with that id.
/// * `500 INTERNAL_SERVER_ERROR` – the task lookup or the log read failed.
async fn tail_log(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Query(q): Query<LogQuery>,
    Extension(_sess): Extension<SignedSession>,
) -> Result<Json<Value>, StatusCode> {
    let stream = match q.stream.as_str() {
        "stdout" => LogStream::Stdout,
        "stderr" => LogStream::Stderr,
        _ => return Err(StatusCode::UNPROCESSABLE_ENTITY),
    };
    {
        // Scoped so the (non-async) mutex guard is released before any `.await` below.
        let conn = state.conn.lock().unwrap();
        if tasks_repo::get_task(&conn, &task_id)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
            .is_none()
        {
            return Err(StatusCode::NOT_FOUND);
        }
    }
    let blobs = state.settings.blobs_dir();
    let deadline = Instant::now() + Duration::from_secs_f64(state.settings.log_long_poll_sec);
    let mut since = q.since_seq;
    loop {
        // Sample the terminal state *before* reading. This assumes log bytes are written
        // before the task is marked terminal (NOTE(review): confirm with the worker path);
        // if so, a terminal verdict here means the read below sees the whole log. Sampling
        // after the read could report `eof` while a final chunk raced in unread.
        let eof = is_log_eof(&state, &task_id);
        let (data, seq) = storage::read_log_slice(&blobs, &task_id, stream, since)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
        // A terminal task produces no further output, so waiting out the rest of the
        // long-poll would only delay the client's final response.
        if !data.is_empty() || eof || Instant::now() >= deadline {
            return Ok(Json(json!({
                "stream": q.stream,
                "seq": seq,
                "data_b64": B64.encode(&data),
                "eof": eof,
            })));
        }
        tokio::time::sleep(Duration::from_secs_f64(
            state.settings.log_long_poll_step_sec,
        ))
        .await;
        since = seq;
    }
}
/// Decides the `eof` flag reported by `tail_log`.
///
/// Returns the task's `tasks_repo::is_terminal` verdict when the task loads. If the
/// task is missing or the lookup fails, it also returns `true`, presumably so clients
/// stop polling rather than spin on a task that cannot be read.
fn is_log_eof(state: &AppState, task_id: &str) -> bool {
    let db = state.conn.lock().unwrap();
    let Ok(Some(task)) = tasks_repo::get_task(&db, task_id) else {
        return true;
    };
    tasks_repo::is_terminal(task.state)
}
/// Query string accepted by `GET /v1/tasks/:task_id/artifact` (handled by
/// `download_artifact`).
#[derive(Debug, Deserialize)]
struct ArtifactQuery {
    /// Artifact name, resolved relative to the task's artifacts directory. This is
    /// client-controlled input. The field has no serde default, so a request without
    /// `name` is rejected by the `Query` extractor.
    name: String,
}
/// Sends one artifact of a task back as an `application/octet-stream` attachment.
///
/// `q.name` comes straight from the query string, so it is checked with
/// `is_safe_artifact_name` before being joined onto the task's artifacts directory.
/// A name that could escape that directory (absolute, `..`, drive prefix) is answered
/// exactly like a missing file. After a successful read, `tasks_repo::touch_access` is
/// called for the task (result ignored).
///
/// # Errors
/// * `404 NOT_FOUND` – unsafe name, or no regular file at the resolved path.
/// * `500 INTERNAL_SERVER_ERROR` – the file could not be read, or the
///   `Content-Disposition` value contains bytes not allowed in a header.
async fn download_artifact(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Query(q): Query<ArtifactQuery>,
    Extension(_sess): Extension<SignedSession>,
) -> Result<Response, StatusCode> {
    let name = q.name.as_str();
    if !is_safe_artifact_name(name) {
        return Err(StatusCode::NOT_FOUND);
    }
    let blobs = state.settings.blobs_dir();
    let path = storage::artifacts_dir(&blobs, &task_id).join(name);
    if !path.is_file() {
        return Err(StatusCode::NOT_FOUND);
    }
    let bytes = std::fs::read(&path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    {
        let conn = state.conn.lock().unwrap();
        let _ = tasks_repo::touch_access(&conn, &task_id);
    }

    // Escape `\` and `"` so the name cannot terminate the quoted-string early.
    // `from_bytes` keeps non-ASCII (UTF-8) names and rejects control characters.
    let quoted = name.replace('\\', "\\\\").replace('"', "\\\"");
    let disposition =
        header::HeaderValue::from_bytes(format!("attachment; filename=\"{quoted}\"").as_bytes())
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let response = (
        [
            (
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/octet-stream"),
            ),
            (header::CONTENT_DISPOSITION, disposition),
        ],
        Body::from(bytes),
    )
        .into_response();
    Ok(response)
}

/// Returns `true` when `name` is a non-empty relative path made only of ordinary
/// components. Such a path has no root, no drive prefix, no `.` and no `..`, so
/// joining it onto a directory cannot lexically leave that directory.
fn is_safe_artifact_name(name: &str) -> bool {
    !name.is_empty()
        && std::path::Path::new(name)
            .components()
            .all(|c| matches!(c, std::path::Component::Normal(_)))
}