hashtree_cli/server/
auth.rs1use axum::{
2 body::Body,
3 extract::State,
4 http::{header, Request, Response, StatusCode},
5 middleware::Next,
6 extract::ws::Message,
7};
8use crate::storage::HashtreeStore;
9use crate::webrtc::WebRTCState;
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
12use tokio::sync::{mpsc, Mutex};
13
/// One in-flight request tracked by [`WsRelayState::pending`].
///
/// NOTE(review): field meanings below are inferred from names and the
/// surrounding relay types; confirm against the WebSocket handler.
pub struct PendingRequest {
    /// Id of the client the request came from (presumably a key of
    /// `WsRelayState::clients`).
    pub origin_id: u64,
    /// Hash being requested (presumably hex-encoded; confirm).
    pub hash: String,
    /// Whether the request has been satisfied (presumably set once some
    /// peer reports having `hash`).
    pub found: bool,
}
19
20pub struct WsRelayState {
21 pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
22 pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
23 pub next_client_id: AtomicU64,
24}
25
26impl WsRelayState {
27 pub fn new() -> Self {
28 Self {
29 clients: Mutex::new(HashMap::new()),
30 pending: Mutex::new(HashMap::new()),
31 next_client_id: AtomicU64::new(1),
32 }
33 }
34
35 pub fn next_id(&self) -> u64 {
36 self.next_client_id.fetch_add(1, Ordering::SeqCst)
37 }
38}
39
40#[derive(Clone)]
41pub struct AppState {
42 pub store: Arc<HashtreeStore>,
43 pub auth: Option<AuthCredentials>,
44 pub webrtc_peers: Option<Arc<WebRTCState>>,
46 pub ws_relay: Arc<WsRelayState>,
48 pub max_upload_bytes: usize,
50 pub public_writes: bool,
53 pub allowed_pubkeys: HashSet<String>,
55 pub upstream_blossom: Vec<String>,
57}
58
/// Username/password pair checked by [`auth_middleware`] against HTTP Basic
/// `Authorization` headers.
///
/// NOTE(review): `password` is held in plain text. If `Debug` is ever
/// derived or implemented, redact it.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected user-id (the part before the first `:` in the Basic token).
    pub username: String,
    /// Expected password (everything after the `:` in the Basic token).
    pub password: String,
}
64
65pub async fn auth_middleware(
67 State(state): State<AppState>,
68 request: Request<Body>,
69 next: Next,
70) -> Result<Response<Body>, StatusCode> {
71 let Some(auth) = &state.auth else {
73 return Ok(next.run(request).await);
74 };
75
76 let auth_header = request
78 .headers()
79 .get(header::AUTHORIZATION)
80 .and_then(|v| v.to_str().ok());
81
82 let authorized = if let Some(header_value) = auth_header {
83 if let Some(credentials) = header_value.strip_prefix("Basic ") {
84 use base64::Engine;
85 let engine = base64::engine::general_purpose::STANDARD;
86 if let Ok(decoded) = engine.decode(credentials) {
87 if let Ok(decoded_str) = String::from_utf8(decoded) {
88 let expected = format!("{}:{}", auth.username, auth.password);
89 decoded_str == expected
90 } else {
91 false
92 }
93 } else {
94 false
95 }
96 } else {
97 false
98 }
99 } else {
100 false
101 };
102
103 if authorized {
104 Ok(next.run(request).await)
105 } else {
106 Ok(Response::builder()
107 .status(StatusCode::UNAUTHORIZED)
108 .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
109 .body(Body::from("Unauthorized"))
110 .unwrap())
111 }
112}