use std::sync::{Arc, Mutex};
use anyhow::Result;
use axum::{
body::Body,
extract::State,
http::{header, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Router,
};
use rusqlite::Connection;
use dragoon_proto::constants;
use crate::{audit, auth, db, server_keys, settings::Settings};
/// Shared application state handed to every sub-router and to the signing middleware.
///
/// Both fields are `Arc`s, so cloning an `AppState` is cheap and every clone refers to the
/// same settings and the same database connection.
#[derive(Clone)]
pub struct AppState {
    /// The single SQLite connection; access is serialized through the mutex.
    pub conn: Arc<Mutex<Connection>>,
    /// Server configuration, shared by reference count.
    pub settings: Arc<Settings>,
}
/// Largest request body, in bytes (32 MiB), that `signed_request` buffers for signature
/// verification; bodies that cannot be read within this limit are answered with `413`.
const MAX_REQUEST_BODY: usize = 32 << 20;
pub fn build_state(settings: Settings) -> Result<AppState> {
std::fs::create_dir_all(&settings.data_dir)?;
std::fs::create_dir_all(settings.blobs_dir())?;
let conn = db::connect(settings.db_path())?;
db::bootstrap(&conn)?;
server_keys::ensure(&conn)?;
Ok(AppState {
settings: Arc::new(settings),
conn: Arc::new(Mutex::new(conn)),
})
}
pub fn build_state_in_memory(settings: Settings) -> Result<AppState> {
let conn = db::connect_in_memory()?;
db::bootstrap(&conn)?;
server_keys::ensure(&conn)?;
Ok(AppState {
settings: Arc::new(settings),
conn: Arc::new(Mutex::new(conn)),
})
}
/// Assembles the application router from the per-area route modules.
///
/// Each sub-router gets its own handle to `state` (an `Arc`-backed clone). The routers
/// are merged in a fixed order: auth, workers, tasks, messages, worker API.
pub fn create_app(state: AppState) -> Router {
    use crate::routes;

    let auth_routes = routes::auth::router(state.clone());
    let worker_routes = routes::workers::router(state.clone());
    let task_routes = routes::tasks::router(state.clone());
    let message_routes = routes::messages::router(state.clone());
    // The final consumer takes `state` by value, saving one clone.
    let worker_api_routes = routes::worker_api::router(state);

    Router::new()
        .merge(auth_routes)
        .merge(worker_routes)
        .merge(task_routes)
        .merge(message_routes)
        .merge(worker_api_routes)
}
/// Request extension holding the session produced by `auth::verify_signed_request`.
///
/// Within this file it is inserted only by `signed_request`, and only after signature
/// verification succeeded.
#[derive(Debug, Clone)]
pub struct SignedSession(pub auth::Session);
/// Returns header `name` as a `&str`, or `None` when the header is absent or its value is
/// not valid visible-ASCII text (the case where `HeaderValue::to_str` fails).
fn header_str<K: header::AsHeaderName>(headers: &header::HeaderMap, name: K) -> Option<&str> {
    headers.get(name).and_then(|v| v.to_str().ok())
}

/// Axum middleware that authenticates a request by its signature headers.
///
/// Required headers:
/// - `constants::HDR_TIMESTAMP` (must parse as `i64`)
/// - `constants::HDR_NONCE`
/// - `constants::HDR_KEY_FP`
/// - `constants::HDR_SIG`
///
/// An `Authorization: Bearer <token>` value, if present, is passed along as the session
/// token; otherwise an empty token is used. The method, the path-and-query target, these
/// values and the buffered body are handed to [`auth::verify_signed_request`]. On success
/// the session is attached as a [`SignedSession`] extension and the request continues to
/// `next` with the buffered body.
///
/// Responses produced here:
/// - `401 unauthenticated`: a required header is missing/unparseable, or verification
///   failed. Both cases are recorded with `audit::log`; audit failures are ignored.
/// - `413 body too large`: the body could not be buffered within [`MAX_REQUEST_BODY`].
///   `to_bytes` reports any body read error this way, not only oversize bodies.
///
/// # Panics
/// Panics if the database mutex is poisoned.
pub async fn signed_request(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let (parts, body) = req.into_parts();

    // Path plus query string, used as the verification target and in audit records.
    let target_path = parts
        .uri
        .path_and_query()
        .map_or_else(|| parts.uri.path().to_string(), |pq| pq.as_str().to_string());

    let session_token = header_str(&parts.headers, header::AUTHORIZATION)
        .and_then(|v| v.strip_prefix("Bearer "))
        .unwrap_or("");
    let ts = header_str(&parts.headers, constants::HDR_TIMESTAMP)
        .and_then(|s| s.parse::<i64>().ok());
    let nonce = header_str(&parts.headers, constants::HDR_NONCE);
    let fp = header_str(&parts.headers, constants::HDR_KEY_FP);
    let sig = header_str(&parts.headers, constants::HDR_SIG);

    // Reject header-less requests before touching the body, so an unauthenticated client
    // cannot make the server buffer up to MAX_REQUEST_BODY bytes just to receive a 401.
    let (Some(ts), Some(nonce), Some(fp), Some(sig)) = (ts, nonce, fp, sig) else {
        let conn = state.conn.lock().unwrap();
        let _ = audit::log(
            &conn,
            None,
            "signed_request",
            Some(&target_path),
            None,
            &serde_json::json!({"reason": "missing_headers"}),
        );
        return (StatusCode::UNAUTHORIZED, "unauthenticated").into_response();
    };

    let Ok(body_bytes) = axum::body::to_bytes(body, MAX_REQUEST_BODY).await else {
        return (StatusCode::PAYLOAD_TOO_LARGE, "body too large").into_response();
    };

    // The lock guard is confined to this block so it is released before `next.run` awaits.
    let result = {
        let conn = state.conn.lock().unwrap();
        auth::verify_signed_request(
            &conn,
            session_token,
            parts.method.as_str(),
            &target_path,
            ts,
            nonce,
            fp,
            sig,
            &body_bytes,
            None,
        )
    };

    match result {
        Ok(sess) => {
            // The body stream was consumed above; hand downstream the buffered bytes.
            let mut req = Request::from_parts(parts, Body::from(body_bytes));
            req.extensions_mut().insert(SignedSession(sess));
            next.run(req).await
        }
        Err(e) => {
            let conn = state.conn.lock().unwrap();
            let _ = audit::log(
                &conn,
                None,
                "signed_request",
                Some(&target_path),
                Some(fp),
                &serde_json::json!({"reason": e.reason()}),
            );
            (StatusCode::UNAUTHORIZED, "unauthenticated").into_response()
        }
    }
}
/// Returns the username stored for `user_id`, or the placeholder `user:<id>`.
///
/// Every query failure, including "no such row", yields the placeholder. Callers therefore
/// cannot tell a missing user apart from a database error.
pub fn username_for(conn: &Connection, user_id: i64) -> String {
    let lookup = conn.query_row(
        "SELECT username FROM users WHERE id=?",
        [user_id],
        |row| row.get::<_, String>(0),
    );
    match lookup {
        Ok(username) => username,
        Err(_) => format!("user:{user_id}"),
    }
}