// hashtree_cli/server/auth.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{header, Request, Response, StatusCode},
5    middleware::Next,
6    extract::ws::Message,
7};
8use crate::storage::HashtreeStore;
9use crate::webrtc::WebRTCState;
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
12use tokio::sync::{mpsc, Mutex};
13
/// A relayed data request that has not yet been resolved.
///
/// NOTE(review): instances live in `WsRelayState::pending`, keyed by
/// `(u64, u32)`; presumably (originating client id, request id) — confirm
/// against the `/ws/data` handler.
#[derive(Debug)]
pub struct PendingRequest {
    /// ID of the client the request came from (presumably an ID handed out by
    /// `WsRelayState::next_id` — TODO confirm).
    pub origin_id: u64,
    /// Hash being requested; its encoding (hex?) is not established in this file.
    pub hash: String,
    /// Whether the requested data has been found.
    /// NOTE(review): meaning inferred from the name — verify where it is set.
    pub found: bool,
}
19
/// Relay state for `/ws/data` WebSocket clients (held in `AppState::ws_relay`).
///
/// Both maps are guarded by async-aware `tokio::sync::Mutex`es.
pub struct WsRelayState {
    /// Unbounded sender of WebSocket `Message`s for each connected client,
    /// keyed by client ID (presumably the outbound queue to that client — confirm).
    pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
    /// Requests that are still outstanding.
    /// NOTE(review): key looks like (client id, request id) — confirm against the handler.
    pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
    /// Counter backing `next_id`; `new` initialises it to 1.
    pub next_client_id: AtomicU64,
}
25
26impl WsRelayState {
27    pub fn new() -> Self {
28        Self {
29            clients: Mutex::new(HashMap::new()),
30            pending: Mutex::new(HashMap::new()),
31            next_client_id: AtomicU64::new(1),
32        }
33    }
34
35    pub fn next_id(&self) -> u64 {
36        self.next_client_id.fetch_add(1, Ordering::SeqCst)
37    }
38}
39
/// Application state shared with request handlers and middleware through
/// axum's `State` extractor (for example `auth_middleware`).
///
/// `Clone` is required by axum's `State` extractor. The `Arc` fields are cheap
/// to clone, but `allowed_pubkeys` and `upstream_blossom` are deep-copied on
/// every clone.
#[derive(Clone)]
pub struct AppState {
    /// Storage backend (`crate::storage::HashtreeStore`).
    pub store: Arc<HashtreeStore>,
    /// HTTP Basic Auth credentials; when `None`, `auth_middleware` lets every
    /// request through.
    pub auth: Option<AuthCredentials>,
    /// WebRTC peer state for forwarding requests to connected P2P peers
    pub webrtc_peers: Option<Arc<WebRTCState>>,
    /// WebSocket relay state for /ws/data clients
    pub ws_relay: Arc<WsRelayState>,
    /// Maximum upload size in bytes for Blossom uploads (default: 5 MB)
    pub max_upload_bytes: usize,
    /// Allow anyone with valid Nostr auth to write (default: true)
    /// When false, only allowed_pubkeys can write
    pub public_writes: bool,
    /// Pubkeys allowed to write (hex format, from config allowed_npubs)
    pub allowed_pubkeys: HashSet<String>,
    /// Upstream Blossom servers for cascade fetching
    pub upstream_blossom: Vec<String>,
}
58
/// Username/password pair that `auth_middleware` requires via HTTP Basic Auth.
///
/// `Debug` is implemented by hand so that the password never ends up in logs
/// or panic messages.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected Basic Auth username.
    pub username: String,
    /// Expected Basic Auth password, held in plaintext and compared against
    /// the decoded `Authorization` header.
    pub password: String,
}

impl std::fmt::Debug for AuthCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthCredentials")
            .field("username", &self.username)
            // Never print the secret itself.
            .field("password", &"<redacted>")
            .finish()
    }
}
64
65/// Auth middleware - validates HTTP Basic Auth
66pub async fn auth_middleware(
67    State(state): State<AppState>,
68    request: Request<Body>,
69    next: Next,
70) -> Result<Response<Body>, StatusCode> {
71    // If auth is not enabled, allow request
72    let Some(auth) = &state.auth else {
73        return Ok(next.run(request).await);
74    };
75
76    // Check Authorization header
77    let auth_header = request
78        .headers()
79        .get(header::AUTHORIZATION)
80        .and_then(|v| v.to_str().ok());
81
82    let authorized = if let Some(header_value) = auth_header {
83        if let Some(credentials) = header_value.strip_prefix("Basic ") {
84            use base64::Engine;
85            let engine = base64::engine::general_purpose::STANDARD;
86            if let Ok(decoded) = engine.decode(credentials) {
87                if let Ok(decoded_str) = String::from_utf8(decoded) {
88                    let expected = format!("{}:{}", auth.username, auth.password);
89                    decoded_str == expected
90                } else {
91                    false
92                }
93            } else {
94                false
95            }
96        } else {
97            false
98        }
99    } else {
100        false
101    };
102
103    if authorized {
104        Ok(next.run(request).await)
105    } else {
106        Ok(Response::builder()
107            .status(StatusCode::UNAUTHORIZED)
108            .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
109            .body(Body::from("Unauthorized"))
110            .unwrap())
111    }
112}