dragoon-server 0.1.0

Public-relay server for the dragoon remote-executor: axum + rusqlite + ed25519 task signing + per-user message inbox.
Documentation
//! `/v1/worker/*` — endpoints called by the worker (dragoon-worker).
//!
//! Auth model: a Bearer token (workers.token_hash) — NOT the per-request
//! SSH signature. The /init endpoint takes a one-shot register code
//! instead.

use axum::{
    extract::{Multipart, State},
    http::{header, HeaderMap, StatusCode},
    response::IntoResponse,
    routing::post,
    Json, Router,
};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use chrono::Utc;
use serde_json::json;
use sha2::{Digest, Sha256};

use dragoon_proto::{
    models::{
        LogChunk, SignedTask, WorkerFinish, WorkerInitRequest, WorkerInitResponse,
        WorkerPollRequest, WorkerPollResponse, WorkerStatus,
    },
    pubkey::parse_pubkey_blob,
    task_sig::sign_task,
};

use crate::{
    app::AppState, messages_repo, server_keys, storage, tasks_repo, users_repo, workers_repo,
};

/// Builds the router for the `/v1/worker/*` endpoints.
///
/// All five routes accept `POST` only; the returned router already has
/// `state` attached.
pub fn router(state: AppState) -> Router {
    let mut routes = Router::new();
    for (path, handler) in [
        ("/v1/worker/init", post(init)),
        ("/v1/worker/poll", post(poll)),
        ("/v1/worker/log", post(log_chunk)),
        ("/v1/worker/finish", post(finish)),
        ("/v1/worker/blob", post(blob)),
    ] {
        routes = routes.route(path, handler);
    }
    routes.with_state(state)
}

/// Resolves the calling worker from an `Authorization: Bearer <token>` header.
///
/// # Errors
/// - `401 UNAUTHORIZED` when the header is absent, cannot be read as text,
///   lacks the `Bearer ` prefix, or the token matches no worker.
/// - `500 INTERNAL_SERVER_ERROR` when the worker lookup itself fails.
fn worker_from_bearer(
    state: &AppState,
    headers: &HeaderMap,
) -> Result<workers_repo::Worker, StatusCode> {
    let Some(token) = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| raw.strip_prefix("Bearer "))
    else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    let conn = state.conn.lock().unwrap();
    match workers_repo::lookup_by_token(&conn, token) {
        Ok(Some(worker)) => Ok(worker),
        Ok(None) => Err(StatusCode::UNAUTHORIZED),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// `POST /v1/worker/init` — redeem a one-shot register code for a worker token.
///
/// `client_pubkey_b64` is optional. When present, it must be base64 of a key blob
/// accepted by `parse_pubkey_blob`. The decoded bytes are passed on to
/// `finalize_register`.
///
/// On success, returns the worker token and id together with the server's
/// public key (base64) and its fingerprint. The worker presumably uses these to
/// verify the signed tasks handed out by `poll`.
///
/// # Errors
/// - `422` when `client_pubkey_b64` is not valid base64 or not a parseable key blob.
/// - `500` when the server key material cannot be loaded.
/// - `401` when `finalize_register` fails, for any reason.
async fn init(
    State(state): State<AppState>,
    Json(req): Json<WorkerInitRequest>,
) -> Result<Json<WorkerInitResponse>, StatusCode> {
    // Validate the optional client key before touching the database.
    let client_pub: Option<Vec<u8>> = match &req.client_pubkey_b64 {
        Some(b64) => {
            let bytes = B64
                .decode(b64.as_bytes())
                .map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
            parse_pubkey_blob(&bytes).map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
            Some(bytes)
        }
        None => None,
    };

    let conn = state.conn.lock().unwrap();

    // Load the server key material *before* redeeming the register code. The
    // code is one-shot. Failing after `finalize_register` would burn it without
    // the worker ever receiving its token, leaving the worker unable to retry.
    let pub_blob =
        server_keys::get_public_blob(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let fp = server_keys::fingerprint(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let (token, wid) =
        workers_repo::finalize_register(&conn, &req.name, &req.register_code, client_pub.as_deref())
            .map_err(|_| StatusCode::UNAUTHORIZED)?;

    Ok(Json(WorkerInitResponse {
        worker_token: token,
        worker_id: wid,
        server_pubkey_b64: B64.encode(&pub_blob),
        server_key_fingerprint: fp,
    }))
}

async fn poll(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(req): Json<WorkerPollRequest>,
) -> Result<Json<WorkerPollResponse>, StatusCode> {
    let w = worker_from_bearer(&state, &headers)?;
    let conn = state.conn.lock().unwrap();
    workers_repo::update_status(
        &conn,
        &w.name,
        match req.status {
            WorkerStatus::Idle => "IDLE",
            WorkerStatus::Busy => "BUSY",
        },
        None,
        req.current_task_id.as_deref(),
    )
    .ok();

    if let Some(current) = &req.current_task_id {
        // already running — just report any cancel signal
        match tasks_repo::get_task(&conn, current).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
            None => {
                return Ok(Json(WorkerPollResponse {
                    task: None,
                    cancel_task_id: None,
                }));
            }
            Some(_) => {
                let cancel = tasks_repo::consume_cancel_signal(&conn, current)
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
                return Ok(Json(WorkerPollResponse {
                    task: None,
                    cancel_task_id: if cancel { Some(current.clone()) } else { None },
                }));
            }
        }
    }

    if !matches!(req.status, WorkerStatus::Idle) {
        return Ok(Json(WorkerPollResponse {
            task: None,
            cancel_task_id: None,
        }));
    }

    let nxt = tasks_repo::next_queued_for_worker(&conn, &w.name)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let Some(nxt) = nxt else {
        return Ok(Json(WorkerPollResponse {
            task: None,
            cancel_task_id: None,
        }));
    };
    let started = Utc::now();
    let t = tasks_repo::transition(
        &conn,
        &nxt.task_id,
        dragoon_proto::models::TaskState::Running,
        tasks_repo::TransitionUpdate {
            started_at: Some(started),
            ..Default::default()
        },
    )
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let priv_key = server_keys::get_private(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let (signature_b64, server_key_fingerprint) = sign_task(&priv_key, &t);
    Ok(Json(WorkerPollResponse {
        task: Some(SignedTask {
            task: t,
            signature_b64,
            server_key_fingerprint,
        }),
        cancel_task_id: None,
    }))
}

/// `POST /v1/worker/log` — append a chunk of task output to the task's log.
///
/// Only the worker the task is assigned to may append. `data_b64` is decoded
/// from base64. An empty payload is accepted and writes nothing.
///
/// # Errors
/// - `401` bad/missing bearer token, `404` unknown task, `403` task assigned
///   to another worker.
/// - `422` when `data_b64` is not valid base64.
/// - `500` when the task lookup or the log append fails.
async fn log_chunk(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(req): Json<LogChunk>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let w = worker_from_bearer(&state, &headers)?;
    {
        // Scoped so the connection lock is released before decoding/writing.
        let conn = state.conn.lock().unwrap();
        let t = tasks_repo::get_task(&conn, &req.task_id)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
            .ok_or(StatusCode::NOT_FOUND)?;
        if t.worker_name != w.name {
            return Err(StatusCode::FORBIDDEN);
        }
    }

    let data = if req.data_b64.is_empty() {
        Vec::new()
    } else {
        B64.decode(req.data_b64.as_bytes())
            .map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?
    };

    if !data.is_empty() {
        // Report a failed append as 500. Acknowledging it with `ok: true` would
        // silently drop task output.
        // NOTE(review): a worker that retries after a 500 may re-append a chunk
        // that was partially written — confirm the worker's retry policy.
        storage::append_log(&state.settings.blobs_dir(), &req.task_id, req.stream, &data)
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    }
    Ok(Json(json!({"ok": true})))
}

async fn finish(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(req): Json<WorkerFinish>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use dragoon_proto::models::TaskState;

    let w = worker_from_bearer(&state, &headers)?;
    let conn = state.conn.lock().unwrap();
    let t = tasks_repo::get_task(&conn, &req.task_id)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::NOT_FOUND)?;
    if t.worker_name != w.name {
        return Err(StatusCode::FORBIDDEN);
    }

    let new_state = if req.error.as_deref() == Some("timeout") {
        TaskState::Timeout
    } else if tasks_repo::consume_cancel_signal(&conn, &req.task_id)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    {
        TaskState::Cancelled
    } else if req.exit_code == Some(0) && req.error.is_none() {
        TaskState::Completed
    } else {
        TaskState::Failed
    };

    if matches!(new_state, TaskState::Cancelled) && t.state == TaskState::Running {
        let _ = tasks_repo::transition(
            &conn,
            &req.task_id,
            TaskState::Cancelling,
            tasks_repo::TransitionUpdate::default(),
        );
    }

    let new_t = tasks_repo::transition(
        &conn,
        &req.task_id,
        new_state,
        tasks_repo::TransitionUpdate {
            finished_at: Some(Utc::now()),
            exit_code: req.exit_code,
            final_pwd: req.final_pwd.clone(),
            error: req.error.clone(),
            ..Default::default()
        },
    )
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    workers_repo::update_status(
        &conn,
        &w.name,
        "IDLE",
        req.final_pwd.as_deref(),
        None,
    )
    .ok();

    for art in &req.artifacts {
        // skip duplicates
        let exists: Option<i64> = conn
            .query_row(
                "SELECT id FROM artifacts WHERE task_id=? AND path=? AND sha256=?",
                rusqlite::params![req.task_id, art.path, art.sha256],
                |r| r.get(0),
            )
            .ok();
        if exists.is_none() {
            tasks_repo::add_artifact(&conn, &req.task_id, art, "")
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
        }
    }

    if let Some(submitter_id) = users_repo::lookup_user_id_by_name(&conn, &t.submitter)
        .ok()
        .flatten()
    {
        let kind = if req.error.is_none() {
            "task_state"
        } else {
            "worker_error"
        };
        let payload = json!({
            "state": match new_t.state {
                TaskState::Completed => "COMPLETED",
                TaskState::Failed => "FAILED",
                TaskState::Timeout => "TIMEOUT",
                TaskState::Cancelled => "CANCELLED",
                TaskState::Cancelling => "CANCELLING",
                TaskState::Running => "RUNNING",
                TaskState::Queued => "QUEUED",
            },
            "exit_code": req.exit_code,
            "final_pwd": req.final_pwd,
            "task_payload": t.payload,
            "worker_seq": t.worker_seq,
        });
        let _ = messages_repo::enqueue(
            &conn,
            submitter_id,
            kind,
            Some(&req.task_id),
            Some(&w.name),
            req.error.as_deref(),
            &payload,
        );
    }

    Ok(Json(json!({"ok": true, "state": match new_t.state {
        TaskState::Completed => "COMPLETED",
        TaskState::Failed => "FAILED",
        TaskState::Timeout => "TIMEOUT",
        TaskState::Cancelled => "CANCELLED",
        TaskState::Cancelling => "CANCELLING",
        TaskState::Running => "RUNNING",
        TaskState::Queued => "QUEUED",
    }})))
}

/// `POST /v1/worker/blob` — upload one artifact file for a task (multipart form).
///
/// Expected form fields:
/// - `task_id`
/// - `path`
/// - `sha256`: hex digest of the file
/// - `file`: the raw bytes
///
/// Unknown fields are ignored.
///
/// The caller must be the worker the task is assigned to. The uploaded bytes
/// are hashed with SHA-256 and compared, ignoring ASCII case, to `sha256`. On
/// a match, they are written via `storage::store_artifact`. An artifact row is
/// then added unless one with the same task/path/digest already exists.
///
/// Responses:
/// - `200 {"ok": true, "size": N}` on success;
/// - `401` bad or missing token;
/// - `400` malformed or unreadable multipart body;
/// - `422` missing fields or digest mismatch;
/// - `404` unknown task;
/// - `403` task not owned;
/// - `500` lookup, storage, or DB-insert failure.
async fn blob(
    State(state): State<AppState>,
    headers: HeaderMap,
    mut multipart: Multipart,
) -> impl IntoResponse {
    let w = match worker_from_bearer(&state, &headers) {
        Ok(w) => w,
        Err(s) => return (s, "unauthenticated").into_response(),
    };

    let mut task_id: Option<String> = None;
    let mut path: Option<String> = None;
    let mut sha: Option<String> = None;
    let mut file_bytes: Option<Vec<u8>> = None;

    // Read every form field. A multipart read error is reported as 400. It is
    // not treated as end-of-form, or as an empty file (which would otherwise
    // surface as a misleading "sha256 mismatch" or store an empty artifact).
    loop {
        let field = match multipart.next_field().await {
            Ok(Some(field)) => field,
            Ok(None) => break,
            Err(_) => {
                return (StatusCode::BAD_REQUEST, "invalid multipart body").into_response()
            }
        };
        let name = field.name().unwrap_or("").to_string();
        let slot = match name.as_str() {
            "task_id" => &mut task_id,
            "path" => &mut path,
            "sha256" => &mut sha,
            "file" => {
                match field.bytes().await {
                    Ok(data) => file_bytes = Some(data.to_vec()),
                    Err(_) => {
                        return (StatusCode::BAD_REQUEST, "invalid multipart body")
                            .into_response()
                    }
                }
                continue;
            }
            _ => continue,
        };
        match field.text().await {
            Ok(text) => *slot = Some(text),
            Err(_) => {
                return (StatusCode::BAD_REQUEST, "invalid multipart body").into_response()
            }
        }
    }

    let (Some(task_id), Some(path), Some(sha), Some(data)) = (task_id, path, sha, file_bytes)
    else {
        return (StatusCode::UNPROCESSABLE_ENTITY, "missing fields").into_response();
    };

    {
        // Scoped so the connection lock is released before hashing/storing.
        let conn = state.conn.lock().unwrap();
        let t = match tasks_repo::get_task(&conn, &task_id) {
            Ok(Some(t)) => t,
            Ok(None) => return (StatusCode::NOT_FOUND, "unknown task").into_response(),
            // A DB failure is not "unknown task"; don't report it as 404.
            Err(_) => {
                return (StatusCode::INTERNAL_SERVER_ERROR, "task lookup failed").into_response()
            }
        };
        if t.worker_name != w.name {
            return (StatusCode::FORBIDDEN, "not owned").into_response();
        }
    }

    let actual_sha = hex::encode(Sha256::digest(&data));
    if !actual_sha.eq_ignore_ascii_case(&sha) {
        return (StatusCode::UNPROCESSABLE_ENTITY, "sha256 mismatch").into_response();
    }
    // NOTE(review): `path` is worker-supplied. This relies on
    // `storage::store_artifact` to keep it inside the blobs dir — confirm.
    if storage::store_artifact(&state.settings.blobs_dir(), &task_id, &path, &data).is_err() {
        return (StatusCode::INTERNAL_SERVER_ERROR, "store failed").into_response();
    }

    let conn = state.conn.lock().unwrap();
    // NOTE(review): `finish` records reported artifacts with an empty storage
    // path. If it ran before this upload, a row with the same task/path/sha256
    // already exists and is left unchanged here. Confirm that workers upload
    // blobs before calling finish.
    let exists: Option<i64> = conn
        .query_row(
            "SELECT id FROM artifacts WHERE task_id=? AND path=? AND sha256=?",
            rusqlite::params![task_id, path, actual_sha],
            |r| r.get(0),
        )
        .ok();
    if exists.is_none() {
        let art = dragoon_proto::models::Artifact {
            path: path.clone(),
            size: data.len() as u64,
            sha256: actual_sha.clone(),
        };
        // The blob is already stored. If the row can't be recorded, report it
        // rather than claiming success.
        if tasks_repo::add_artifact(&conn, &task_id, &art, &path).is_err() {
            return (StatusCode::INTERNAL_SERVER_ERROR, "artifact record failed").into_response();
        }
    }
    Json(json!({"ok": true, "size": data.len()})).into_response()
}