1use std::sync::{Arc, Mutex};
4
5use anyhow::Result;
6use axum::{
7 body::Body,
8 extract::State,
9 http::{header, Request, StatusCode},
10 middleware::Next,
11 response::{IntoResponse, Response},
12 Router,
13};
14use rusqlite::Connection;
15
16use dragoon_proto::constants;
17
18use crate::{audit, auth, db, server_keys, settings::Settings};
19
20#[derive(Clone)]
25pub struct AppState {
26 pub settings: Arc<Settings>,
27 pub conn: Arc<Mutex<Connection>>,
28}
29
30const MAX_REQUEST_BODY: usize = 32 * 1024 * 1024; pub fn build_state(settings: Settings) -> Result<AppState> {
35 std::fs::create_dir_all(&settings.data_dir)?;
36 std::fs::create_dir_all(settings.blobs_dir())?;
37 let conn = db::connect(settings.db_path())?;
38 db::bootstrap(&conn)?;
39 server_keys::ensure(&conn)?;
40 Ok(AppState {
41 settings: Arc::new(settings),
42 conn: Arc::new(Mutex::new(conn)),
43 })
44}
45
46pub fn build_state_in_memory(settings: Settings) -> Result<AppState> {
47 let conn = db::connect_in_memory()?;
48 db::bootstrap(&conn)?;
49 server_keys::ensure(&conn)?;
50 Ok(AppState {
51 settings: Arc::new(settings),
52 conn: Arc::new(Mutex::new(conn)),
53 })
54}
55
56pub fn create_app(state: AppState) -> Router {
58 use crate::routes;
59 Router::new()
60 .merge(routes::auth::router(state.clone()))
61 .merge(routes::workers::router(state.clone()))
62 .merge(routes::tasks::router(state.clone()))
63 .merge(routes::messages::router(state.clone()))
64 .merge(routes::worker_api::router(state))
65}
66
67#[derive(Clone, Debug)]
74pub struct SignedSession(pub auth::Session);
75
76pub async fn signed_request(
85 State(state): State<AppState>,
86 req: Request<Body>,
87 next: Next,
88) -> Response {
89 let (parts, body) = req.into_parts();
90 let Ok(body_bytes) = axum::body::to_bytes(body, MAX_REQUEST_BODY).await else {
91 return (StatusCode::PAYLOAD_TOO_LARGE, "body too large").into_response();
92 };
93
94 let auth_header = parts
95 .headers
96 .get(header::AUTHORIZATION)
97 .and_then(|v| v.to_str().ok())
98 .unwrap_or("");
99 let session_token = auth_header.strip_prefix("Bearer ").unwrap_or("");
100
101 let ts = parts
102 .headers
103 .get(constants::HDR_TIMESTAMP)
104 .and_then(|v| v.to_str().ok())
105 .and_then(|s| s.parse::<i64>().ok());
106 let nonce = parts
107 .headers
108 .get(constants::HDR_NONCE)
109 .and_then(|v| v.to_str().ok());
110 let fp = parts
111 .headers
112 .get(constants::HDR_KEY_FP)
113 .and_then(|v| v.to_str().ok());
114 let sig = parts
115 .headers
116 .get(constants::HDR_SIG)
117 .and_then(|v| v.to_str().ok());
118
119 let target_path = parts
120 .uri
121 .path_and_query()
122 .map_or_else(|| parts.uri.path().to_string(), |pq| pq.as_str().to_string());
123
124 let result = if let (Some(ts), Some(nonce), Some(fp), Some(sig)) = (ts, nonce, fp, sig) {
125 let conn = state.conn.lock().unwrap();
126 auth::verify_signed_request(
127 &conn,
128 session_token,
129 parts.method.as_str(),
130 &target_path,
131 ts,
132 nonce,
133 fp,
134 sig,
135 &body_bytes,
136 None,
137 )
138 } else {
139 let conn = state.conn.lock().unwrap();
141 let _ = audit::log(
142 &conn,
143 None,
144 "signed_request",
145 Some(&target_path),
146 None,
147 &serde_json::json!({"reason": "missing_headers"}),
148 );
149 return (StatusCode::UNAUTHORIZED, "unauthenticated").into_response();
150 };
151
152 match result {
153 Ok(sess) => {
154 let mut req = Request::from_parts(parts, Body::from(body_bytes));
155 req.extensions_mut().insert(SignedSession(sess));
156 next.run(req).await
157 }
158 Err(e) => {
159 let conn = state.conn.lock().unwrap();
160 let _ = audit::log(
161 &conn,
162 None,
163 "signed_request",
164 Some(&target_path),
165 fp,
166 &serde_json::json!({"reason": e.reason()}),
167 );
168 (StatusCode::UNAUTHORIZED, "unauthenticated").into_response()
169 }
170 }
171}
172
173pub fn username_for(conn: &Connection, user_id: i64) -> String {
175 conn.query_row("SELECT username FROM users WHERE id=?", [user_id], |r| {
176 r.get::<_, String>(0)
177 })
178 .unwrap_or_else(|_| format!("user:{user_id}"))
179}