dragoon-server 0.1.0

Public-relay server for the dragoon remote-executor: axum + rusqlite + ed25519 task signing + per-user message inbox.
Documentation
//! axum application factory + shared state + signed-request middleware.

use std::sync::{Arc, Mutex};

use anyhow::{Context, Result};
use axum::{
    body::Body,
    extract::State,
    http::{header, Request, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
    Router,
};
use rusqlite::Connection;

use dragoon_proto::constants;

use crate::{audit, auth, db, server_keys, settings::Settings};

/// State cloned into every route handler.
///
/// `rusqlite` only offers a blocking API, so the one SQLite [`Connection`]
/// sits behind a `Mutex` that handlers hold for as short a time as they can.
/// Queries run inline on the handler's task instead of via `spawn_blocking`,
/// which is acceptable only while they stay microsecond-scale.
#[derive(Clone)]
pub struct AppState {
    /// Server configuration, shared immutably through an `Arc`.
    pub settings: Arc<Settings>,
    /// The single database connection: lock, query, release.
    pub conn: Arc<Mutex<Connection>>,
}

/// Largest request body (32 MiB) that [`signed_request`] will buffer.
const MAX_REQUEST_BODY: usize = 32 << 20;

/// Wrap a [`Connection`] in shared state and apply schema bootstrap +
/// the server task-signing key initialisation.
pub fn build_state(settings: Settings) -> Result<AppState> {
    std::fs::create_dir_all(&settings.data_dir)?;
    std::fs::create_dir_all(settings.blobs_dir())?;
    let conn = db::connect(settings.db_path())?;
    db::bootstrap(&conn)?;
    server_keys::ensure(&conn)?;
    Ok(AppState {
        settings: Arc::new(settings),
        conn: Arc::new(Mutex::new(conn)),
    })
}

pub fn build_state_in_memory(settings: Settings) -> Result<AppState> {
    let conn = db::connect_in_memory()?;
    db::bootstrap(&conn)?;
    server_keys::ensure(&conn)?;
    Ok(AppState {
        settings: Arc::new(settings),
        conn: Arc::new(Mutex::new(conn)),
    })
}

/// Assemble the full HTTP surface into one [`Router`].
///
/// Each route group gets its own handle to the shared [`AppState`]. The
/// groups are merged in a fixed order: auth, workers, tasks, messages,
/// worker API.
pub fn create_app(state: AppState) -> Router {
    use crate::routes::{auth as auth_routes, messages, tasks, worker_api, workers};

    let app = Router::new();
    let app = app.merge(auth_routes::router(state.clone()));
    let app = app.merge(workers::router(state.clone()));
    let app = app.merge(tasks::router(state.clone()));
    let app = app.merge(messages::router(state.clone()));
    // The last group takes ownership of `state`; no further clone needed.
    app.merge(worker_api::router(state))
}

// --------------------------------------------------------------------------
// Signed-request middleware
// --------------------------------------------------------------------------

/// Newtype around the [`auth::Session`] that [`signed_request`] verified.
///
/// The middleware inserts it into the request extensions, so handlers
/// receive it as `Extension<SignedSession>`.
#[derive(Debug, Clone)]
pub struct SignedSession(pub auth::Session);

/// Middleware that authenticates a signed request before the handler runs.
///
/// Steps, in order:
///
/// 1. Reads the `Authorization: Bearer <token>` header and the four `X-RE-*`
///    signature headers (timestamp, nonce, key fingerprint, signature). The
///    scheme is matched case-insensitively, per RFC 7235 §2.1.
/// 2. If any signature header is absent, or the timestamp is not an
///    integer, it audits `missing_headers` and returns 401 *before* reading
///    the body. This stops unauthenticated clients from making the server
///    buffer up to `MAX_REQUEST_BODY` bytes.
/// 3. Buffers the body, capped at `MAX_REQUEST_BODY`, so the canonical
///    request hash can be computed. Returns 413 if the body cannot be read.
/// 4. Runs [`auth::verify_signed_request`] while holding the connection lock.
/// 5. On success, re-attaches the body, inserts [`SignedSession`] into the
///    request extensions and forwards to `next`.
/// 6. On failure, audits the verifier's reason and returns 401.
///
/// # Panics
///
/// Panics if the connection mutex is poisoned.
pub async fn signed_request(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let (parts, body) = req.into_parts();

    // Path plus query string, which is what the client signed.
    let target_path = parts
        .uri
        .path_and_query()
        .map_or_else(|| parts.uri.path().to_string(), |pq| pq.as_str().to_string());

    let session_token = parts
        .headers
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token)
        .unwrap_or("");

    let ts = parts
        .headers
        .get(constants::HDR_TIMESTAMP)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<i64>().ok());
    let nonce = parts
        .headers
        .get(constants::HDR_NONCE)
        .and_then(|v| v.to_str().ok());
    let fp = parts
        .headers
        .get(constants::HDR_KEY_FP)
        .and_then(|v| v.to_str().ok());
    let sig = parts
        .headers
        .get(constants::HDR_SIG)
        .and_then(|v| v.to_str().ok());

    // Cheap header check first: never buffer a body for a request that can
    // not possibly authenticate.
    let (Some(ts), Some(nonce), Some(fp), Some(sig)) = (ts, nonce, fp, sig) else {
        audit_rejection(
            &state,
            &target_path,
            None,
            &serde_json::json!({"reason": "missing_headers"}),
        );
        return unauthenticated();
    };

    // NOTE(review): every body-read error is reported as 413, not only
    // exceeding the limit (e.g. a client disconnect mid-upload).
    let Ok(body_bytes) = axum::body::to_bytes(body, MAX_REQUEST_BODY).await else {
        return (StatusCode::PAYLOAD_TOO_LARGE, "body too large").into_response();
    };

    // The lock guard is scoped to this block, so it is released before any
    // `.await` below.
    let verdict = {
        let conn = state.conn.lock().unwrap();
        auth::verify_signed_request(
            &conn,
            session_token,
            parts.method.as_str(),
            &target_path,
            ts,
            nonce,
            fp,
            sig,
            &body_bytes,
            None,
        )
    };

    match verdict {
        Ok(session) => {
            let mut req = Request::from_parts(parts, Body::from(body_bytes));
            req.extensions_mut().insert(SignedSession(session));
            next.run(req).await
        }
        Err(e) => {
            audit_rejection(
                &state,
                &target_path,
                Some(fp),
                &serde_json::json!({"reason": e.reason()}),
            );
            unauthenticated()
        }
    }
}

/// Return the credentials from an `Authorization` value of the form
/// `<scheme> <token>` when the scheme is `Bearer`, compared
/// case-insensitively (RFC 7235 §2.1). Otherwise return `None`.
fn bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(token)
}

/// The response used for every authentication failure in [`signed_request`].
/// Both rejection paths return this same 401.
fn unauthenticated() -> Response {
    (StatusCode::UNAUTHORIZED, "unauthenticated").into_response()
}

/// Record a rejected signed request in the audit log.
///
/// This is best effort: a failed audit write is ignored, so it can never
/// replace the 401 the caller is about to return.
///
/// # Panics
///
/// Panics if the connection mutex is poisoned.
fn audit_rejection(
    state: &AppState,
    target_path: &str,
    fp: Option<&str>,
    details: &serde_json::Value,
) {
    let conn = state.conn.lock().unwrap();
    let _ = audit::log(
        &conn,
        None,
        "signed_request",
        Some(target_path),
        fp,
        details,
    );
}

/// Look up the display name for `user_id`, used as the actor string in
/// audit logs.
///
/// Any lookup failure (unknown id or a query error) falls back to the
/// placeholder `user:<id>`, so this never fails.
pub fn username_for(conn: &Connection, user_id: i64) -> String {
    let lookup = conn.query_row(
        "SELECT username FROM users WHERE id=?",
        [user_id],
        |row| row.get::<_, String>(0),
    );
    match lookup {
        Ok(name) => name,
        Err(_) => format!("user:{user_id}"),
    }
}