dragoon-server 0.1.0

Public-relay server for the dragoon remote-executor: axum + rusqlite + ed25519 task signing + per-user message inbox.
Documentation
//! `/v1/tasks/*` routes (submit / get / cancel / log tail / artifact download).

use std::time::{Duration, Instant};

use axum::{
    body::Body,
    extract::{Path, Query, State},
    http::{header, StatusCode},
    middleware,
    response::{IntoResponse, Response},
    routing::{get, post},
    Extension, Json, Router,
};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use serde::Deserialize;
use serde_json::{json, Value};

use dragoon_proto::models::{LogStream, Task, TaskSubmit};

use crate::{
    app::{signed_request, username_for, AppState, SignedSession},
    audit, messages_repo, storage, tasks_repo, users_repo, workers_repo,
};

pub fn router(state: AppState) -> Router {
    // Every route registered below is wrapped in the `signed_request` middleware.
    let require_signature = middleware::from_fn_with_state(state.clone(), signed_request);

    // Artifact downloads take their `name` as a query parameter rather than a
    // `*rest` path segment: matchit drops `/v1/tasks/:task_id` from the route
    // table when it has a `/v1/tasks/:task_id/*rest` sibling.
    let routes = Router::new()
        .route("/v1/tasks", post(submit))
        .route("/v1/tasks/:task_id", get(get_task))
        .route("/v1/tasks/:task_id/cancel", post(cancel))
        .route("/v1/tasks/:task_id/log", get(tail_log))
        .route("/v1/tasks/:task_id/artifact", get(download_artifact));

    routes.layer(require_signature).with_state(state)
}

/// `POST /v1/tasks` — queue a new task for a named worker.
///
/// The request is rejected when `workers_repo::lookup_by_name` finds no worker
/// called `req.worker_name`. Otherwise the task is inserted under a fresh id,
/// with the signed-in user's name as submitter. A `submit` audit entry is then
/// written on a best-effort basis, and the stored task is returned.
///
/// # Errors
/// - `404` if the worker is unknown.
/// - `500` if the worker lookup or the insert fails.
async fn submit(
    State(state): State<AppState>,
    Extension(sess): Extension<SignedSession>,
    Json(req): Json<TaskSubmit>,
) -> Result<Json<Task>, StatusCode> {
    let db = state.conn.lock().unwrap();

    let worker_known = workers_repo::lookup_by_name(&db, &req.worker_name)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .is_some();
    if !worker_known {
        return Err(StatusCode::NOT_FOUND);
    }

    let task_id = tasks_repo::new_task_id();
    let submitter = username_for(&db, sess.0.user_id);
    let created = tasks_repo::insert_task(
        &db,
        &task_id,
        &req.worker_name,
        &submitter,
        req.kind,
        &req.payload,
        &req.collect,
        &req.limits,
        req.fetch_path.as_deref(),
    )
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let details = json!({"worker": req.worker_name, "kind": req.kind.as_wire_str()});
    // Auditing is best effort: a failed audit write does not fail the submission.
    let _ = audit::log(
        &db,
        Some(&submitter),
        "submit",
        Some(&task_id),
        Some(&sess.0.fingerprint),
        &details,
    );
    Ok(Json(created))
}

/// `GET /v1/tasks/:task_id` — fetch a task's current record.
///
/// A successful lookup also calls `tasks_repo::touch_access`. That call is best
/// effort; its result is ignored.
///
/// # Errors
/// - `404` for an unknown id.
/// - `500` if the lookup fails.
async fn get_task(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Extension(_sess): Extension<SignedSession>,
) -> Result<Json<Task>, StatusCode> {
    let db = state.conn.lock().unwrap();
    let lookup = tasks_repo::get_task(&db, &task_id);
    match lookup {
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
        Ok(None) => Err(StatusCode::NOT_FOUND),
        Ok(Some(task)) => {
            let _ = tasks_repo::touch_access(&db, &task_id);
            Ok(Json(task))
        }
    }
}

/// `POST /v1/tasks/:task_id/cancel` — ask for a task to be cancelled.
///
/// The state change is delegated to `tasks_repo::request_cancel`. The handler
/// then records a `cancel` audit entry (best effort) and returns the task as it
/// stands afterwards.
///
/// When this request is what moved the task from a non-terminal into a terminal
/// state, a `task_state` message is queued in the submitter's inbox (best effort).
///
/// # Errors
/// - `404` if the task does not exist.
/// - `500` if the lookup or the cancel itself fails.
async fn cancel(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Extension(sess): Extension<SignedSession>,
) -> Result<Json<Task>, StatusCode> {
    let conn = state.conn.lock().unwrap();
    let cur = tasks_repo::get_task(&conn, &task_id)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::NOT_FOUND)?;
    let new_t = tasks_repo::request_cancel(&conn, &task_id)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let username = username_for(&conn, sess.0.user_id);
    let _ = audit::log(
        &conn,
        Some(&username),
        "cancel",
        Some(&task_id),
        Some(&sess.0.fingerprint),
        &json!({}),
    );
    // Notify only on an actual transition into a terminal state. If the task was
    // already terminal before this request, this cancel did not finish it.
    // Previously such a cancel still queued a fresh `task_state` message for a
    // task whose state had not changed.
    let became_terminal =
        !tasks_repo::is_terminal(cur.state) && tasks_repo::is_terminal(new_t.state);
    if became_terminal {
        if let Some(submitter_id) = users_repo::lookup_user_id_by_name(&conn, &cur.submitter)
            .ok()
            .flatten()
        {
            let payload = json!({
                "state": match new_t.state {
                    dragoon_proto::models::TaskState::Cancelled => "CANCELLED",
                    dragoon_proto::models::TaskState::Completed => "COMPLETED",
                    dragoon_proto::models::TaskState::Failed => "FAILED",
                    dragoon_proto::models::TaskState::Timeout => "TIMEOUT",
                    _ => "UNKNOWN",
                },
                "task_payload": new_t.payload,
                "worker_seq": new_t.worker_seq,
            });
            let _ = messages_repo::enqueue(
                &conn,
                submitter_id,
                "task_state",
                Some(&task_id),
                Some(&new_t.worker_name),
                new_t.error.as_deref(),
                &payload,
            );
        }
    }
    Ok(Json(new_t))
}

/// Query string accepted by `GET /v1/tasks/:task_id/log`.
#[derive(Deserialize, Debug)]
struct LogQuery {
    /// Position passed to `storage::read_log_slice` as the read start; `0` when omitted.
    #[serde(default)]
    since_seq: u64,
    /// Stream to tail. `tail_log` accepts only `"stdout"` or `"stderr"`; when the
    /// parameter is omitted, `default_stream` supplies the value.
    #[serde(default = "default_stream")]
    stream: String,
}

/// Serde default for `LogQuery::stream`: tail standard output unless told otherwise.
fn default_stream() -> String {
    String::from("stdout")
}

/// `GET /v1/tasks/:task_id/log` — long-poll for new output on one log stream.
///
/// `since_seq` is passed to `storage::read_log_slice` as the read start.
/// `stream` selects `stdout` (the default) or `stderr`.
///
/// The log is re-read every `log_long_poll_step_sec` seconds. The handler
/// responds as soon as any of these holds:
/// - new data is available;
/// - the task was already finished when the log was read;
/// - `log_long_poll_sec` has elapsed.
///
/// The JSON body carries:
/// - `stream`: the stream name;
/// - `seq`: the position returned by the read;
/// - `data_b64`: the new bytes, base64 encoded (empty when there are none);
/// - `eof`: the finished flag.
///
/// # Errors
/// - `422` for an unknown `stream` value.
/// - `404` if the task does not exist.
/// - `500` on database or log-storage failures.
async fn tail_log(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Query(q): Query<LogQuery>,
    Extension(_sess): Extension<SignedSession>,
) -> Result<Json<Value>, StatusCode> {
    let stream = match q.stream.as_str() {
        "stdout" => LogStream::Stdout,
        "stderr" => LogStream::Stderr,
        _ => return Err(StatusCode::UNPROCESSABLE_ENTITY),
    };
    {
        let conn = state.conn.lock().unwrap();
        if tasks_repo::get_task(&conn, &task_id)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
            .is_none()
        {
            return Err(StatusCode::NOT_FOUND);
        }
    }
    let blobs = state.settings.blobs_dir();
    let deadline = Instant::now() + Duration::from_secs_f64(state.settings.log_long_poll_sec);
    let mut since = q.since_seq;
    loop {
        // Sample the task's terminal state *before* reading the log. Sampling it
        // afterwards races with the task finishing between the two steps: an
        // empty read followed by a fresh `eof = true` tells the client to stop
        // even if output was appended in that window. Sampled first, `eof` is
        // only reported when the task was already finished at read time.
        // NOTE(review): this assumes log chunks are persisted before the task
        // is marked terminal — confirm against the worker/result path.
        let finished = is_log_eof(&state, &task_id);
        let (data, seq) = storage::read_log_slice(&blobs, &task_id, stream, since)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
        // Answer on new data, or right away once nothing more can arrive
        // (instead of idling until the deadline), or when the poll times out.
        if !data.is_empty() || finished || Instant::now() >= deadline {
            return Ok(Json(json!({
                "stream": q.stream,
                "seq": seq,
                "data_b64": B64.encode(&data),
                "eof": finished,
            })));
        }
        tokio::time::sleep(Duration::from_secs_f64(
            state.settings.log_long_poll_step_sec,
        ))
        .await;
        since = seq;
    }
}

/// Reports whether a task's log should be treated as complete.
///
/// Returns `true` when `tasks_repo::is_terminal` says the task's state is
/// terminal. It also returns `true` when the task is missing or the lookup
/// fails, so callers treat a task they cannot read as finished.
fn is_log_eof(state: &AppState, task_id: &str) -> bool {
    let db = state.conn.lock().unwrap();
    let Ok(Some(task)) = tasks_repo::get_task(&db, task_id) else {
        return true;
    };
    tasks_repo::is_terminal(task.state)
}

/// Query string accepted by `GET /v1/tasks/:task_id/artifact`.
#[derive(Deserialize, Debug)]
struct ArtifactQuery {
    /// Artifact path, joined onto the task's artifact directory (required).
    name: String,
}

/// `GET /v1/tasks/:task_id/artifact?name=<relative path>` — download one artifact file.
///
/// `name` must be a non-empty relative path made only of ordinary components:
/// no `..`, no leading `/`, no `.`. It is resolved under
/// `storage::artifacts_dir` for the task. The whole file is read into memory and
/// returned as `application/octet-stream` with an `attachment`
/// Content-Disposition.
///
/// # Errors
/// - `400` if `name` is empty or could escape the task's artifact directory.
/// - `404` if the task or the artifact does not exist.
/// - `500` on database or filesystem failures.
async fn download_artifact(
    State(state): State<AppState>,
    Path(task_id): Path<String>,
    Query(q): Query<ArtifactQuery>,
    Extension(_sess): Extension<SignedSession>,
) -> Result<Response, StatusCode> {
    let name = q.name.as_str();
    // `name` comes straight from the query string. Without this check,
    // `../../<file>` or an absolute path (which `PathBuf::join` substitutes
    // wholesale) would read arbitrary files readable by the server.
    if !is_safe_artifact_name(name) {
        return Err(StatusCode::BAD_REQUEST);
    }
    {
        // `task_id` is client-controlled too, and axum's `Path` extractor
        // percent-decodes it. Requiring it to name a real task, as the other
        // task handlers do, keeps it from steering `artifacts_dir` elsewhere.
        let conn = state.conn.lock().unwrap();
        if tasks_repo::get_task(&conn, &task_id)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
            .is_none()
        {
            return Err(StatusCode::NOT_FOUND);
        }
    }
    let blobs = state.settings.blobs_dir();
    let path = storage::artifacts_dir(&blobs, &task_id).join(name);
    if !path.is_file() {
        return Err(StatusCode::NOT_FOUND);
    }
    // The file can disappear between `is_file` and the read; report that as 404.
    let bytes = std::fs::read(&path).map_err(|e| match e.kind() {
        std::io::ErrorKind::NotFound => StatusCode::NOT_FOUND,
        _ => StatusCode::INTERNAL_SERVER_ERROR,
    })?;
    {
        let conn = state.conn.lock().unwrap();
        let _ = tasks_repo::touch_access(&conn, &task_id);
    }
    // Use an owned `HeaderValue`; the previous `Box::leak` leaked one string per
    // download. The filename is sanitized so the header is always valid.
    let disposition = header::HeaderValue::from_str(&format!(
        "attachment; filename=\"{}\"",
        disposition_filename(name)
    ))
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let response = (
        [
            (
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/octet-stream"),
            ),
            (header::CONTENT_DISPOSITION, disposition),
        ],
        Body::from(bytes),
    )
        .into_response();
    Ok(response)
}

/// Returns `true` if `name` is a non-empty relative path made solely of
/// ordinary components. Such a path cannot climb out (`..`) or re-root
/// (`/`, drive prefixes) when joined onto a directory.
fn is_safe_artifact_name(name: &str) -> bool {
    let mut parts = std::path::Path::new(name).components().peekable();
    parts.peek().is_some()
        && parts.all(|part| matches!(part, std::path::Component::Normal(_)))
}

/// Renders `name` for the quoted `filename="…"` parameter of Content-Disposition.
///
/// Characters outside printable ASCII (space included as printable) become `_`.
/// So do `"` and `\`, which would end or escape the quoted string. The result is
/// therefore always accepted by `HeaderValue::from_str`.
fn disposition_filename(name: &str) -> String {
    name.chars()
        .map(|c| match c {
            '"' | '\\' => '_',
            c if c == ' ' || c.is_ascii_graphic() => c,
            _ => '_',
        })
        .collect()
}