Skip to main content

dragoon_server/
app.rs

1//! axum application factory + shared state + signed-request middleware.
2
3use std::sync::{Arc, Mutex};
4
5use anyhow::Result;
6use axum::{
7    body::Body,
8    extract::State,
9    http::{header, Request, StatusCode},
10    middleware::Next,
11    response::{IntoResponse, Response},
12    Router,
13};
14use rusqlite::Connection;
15
16use dragoon_proto::constants;
17
18use crate::{audit, auth, db, server_keys, settings::Settings};
19
/// Application state shared by every route.
///
/// Cloning is cheap: both fields are reference-counted, so every clone
/// refers to the same settings and the same database connection.
///
/// The single SQLite connection is wrapped in a `Mutex` since rusqlite is
/// sync-only; handlers acquire the lock briefly. For long-running queries
/// we'd `spawn_blocking`, but the workload here is microsecond-scale.
#[derive(Clone)]
pub struct AppState {
    /// Server configuration, shared immutably behind an `Arc`.
    pub settings: Arc<Settings>,
    /// The one shared database connection. Lock it only for the duration of
    /// the query; a `std::sync::Mutex` guard must not be held across `.await`.
    pub conn: Arc<Mutex<Connection>>,
}
29
/// Largest request body (32 MiB) that the signed-request middleware will
/// buffer in memory before hashing it.
const MAX_REQUEST_BODY: usize = 32 << 20;
31
32/// Wrap a [`Connection`] in shared state and apply schema bootstrap +
33/// the server task-signing key initialisation.
34pub fn build_state(settings: Settings) -> Result<AppState> {
35    std::fs::create_dir_all(&settings.data_dir)?;
36    std::fs::create_dir_all(settings.blobs_dir())?;
37    let conn = db::connect(settings.db_path())?;
38    db::bootstrap(&conn)?;
39    server_keys::ensure(&conn)?;
40    Ok(AppState {
41        settings: Arc::new(settings),
42        conn: Arc::new(Mutex::new(conn)),
43    })
44}
45
46pub fn build_state_in_memory(settings: Settings) -> Result<AppState> {
47    let conn = db::connect_in_memory()?;
48    db::bootstrap(&conn)?;
49    server_keys::ensure(&conn)?;
50    Ok(AppState {
51        settings: Arc::new(settings),
52        conn: Arc::new(Mutex::new(conn)),
53    })
54}
55
56/// Build the axum [`Router`] with every route and middleware wired up.
57pub fn create_app(state: AppState) -> Router {
58    use crate::routes;
59    Router::new()
60        .merge(routes::auth::router(state.clone()))
61        .merge(routes::workers::router(state.clone()))
62        .merge(routes::tasks::router(state.clone()))
63        .merge(routes::messages::router(state.clone()))
64        .merge(routes::worker_api::router(state))
65}
66
67// --------------------------------------------------------------------------
68// Signed-request middleware
69// --------------------------------------------------------------------------
70
/// Carries the verified [`auth::Session`] into the route handler via
/// `Extension<SignedSession>`.
///
/// [`signed_request`] inserts this into the request extensions only after
/// [`auth::verify_signed_request`] has returned `Ok`. Its presence therefore
/// means the request passed signature verification.
#[derive(Clone, Debug)]
pub struct SignedSession(pub auth::Session);
75
76/// Middleware factory that:
77///
78/// 1. Reads the body bytes (so the canonical request hash can be computed).
79/// 2. Reads the four `X-RE-*` signature headers + `Authorization: Bearer`.
80/// 3. Runs [`auth::verify_signed_request`].
81/// 4. On success, re-attaches the body and inserts `SignedSession` into
82///    `request.extensions()` so handlers can pull it out.
83/// 5. On failure, audits the reason and returns 401.
84pub async fn signed_request(
85    State(state): State<AppState>,
86    req: Request<Body>,
87    next: Next,
88) -> Response {
89    let (parts, body) = req.into_parts();
90    let Ok(body_bytes) = axum::body::to_bytes(body, MAX_REQUEST_BODY).await else {
91        return (StatusCode::PAYLOAD_TOO_LARGE, "body too large").into_response();
92    };
93
94    let auth_header = parts
95        .headers
96        .get(header::AUTHORIZATION)
97        .and_then(|v| v.to_str().ok())
98        .unwrap_or("");
99    let session_token = auth_header.strip_prefix("Bearer ").unwrap_or("");
100
101    let ts = parts
102        .headers
103        .get(constants::HDR_TIMESTAMP)
104        .and_then(|v| v.to_str().ok())
105        .and_then(|s| s.parse::<i64>().ok());
106    let nonce = parts
107        .headers
108        .get(constants::HDR_NONCE)
109        .and_then(|v| v.to_str().ok());
110    let fp = parts
111        .headers
112        .get(constants::HDR_KEY_FP)
113        .and_then(|v| v.to_str().ok());
114    let sig = parts
115        .headers
116        .get(constants::HDR_SIG)
117        .and_then(|v| v.to_str().ok());
118
119    let target_path = parts
120        .uri
121        .path_and_query()
122        .map_or_else(|| parts.uri.path().to_string(), |pq| pq.as_str().to_string());
123
124    let result = if let (Some(ts), Some(nonce), Some(fp), Some(sig)) = (ts, nonce, fp, sig) {
125        let conn = state.conn.lock().unwrap();
126        auth::verify_signed_request(
127            &conn,
128            session_token,
129            parts.method.as_str(),
130            &target_path,
131            ts,
132            nonce,
133            fp,
134            sig,
135            &body_bytes,
136            None,
137        )
138    } else {
139        // missing-headers path: log and reject
140        let conn = state.conn.lock().unwrap();
141        let _ = audit::log(
142            &conn,
143            None,
144            "signed_request",
145            Some(&target_path),
146            None,
147            &serde_json::json!({"reason": "missing_headers"}),
148        );
149        return (StatusCode::UNAUTHORIZED, "unauthenticated").into_response();
150    };
151
152    match result {
153        Ok(sess) => {
154            let mut req = Request::from_parts(parts, Body::from(body_bytes));
155            req.extensions_mut().insert(SignedSession(sess));
156            next.run(req).await
157        }
158        Err(e) => {
159            let conn = state.conn.lock().unwrap();
160            let _ = audit::log(
161                &conn,
162                None,
163                "signed_request",
164                Some(&target_path),
165                fp,
166                &serde_json::json!({"reason": e.reason()}),
167            );
168            (StatusCode::UNAUTHORIZED, "unauthenticated").into_response()
169        }
170    }
171}
172
173/// Resolve `user_id -> username` for audit log actor strings.
174pub fn username_for(conn: &Connection, user_id: i64) -> String {
175    conn.query_row("SELECT username FROM users WHERE id=?", [user_id], |r| {
176        r.get::<_, String>(0)
177    })
178    .unwrap_or_else(|_| format!("user:{user_id}"))
179}