Skip to main content

hashtree_cli/server/
auth.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{header, Request, Response, StatusCode},
5    middleware::Next,
6    extract::ws::Message,
7};
8use crate::socialgraph;
9use crate::nostr_relay::NostrRelay;
10use crate::storage::HashtreeStore;
11use crate::webrtc::WebRTCState;
12use std::collections::{HashMap, HashSet};
13use std::sync::{Arc, atomic::{AtomicU32, AtomicU64, Ordering}};
14use tokio::sync::{mpsc, Mutex};
15
/// Protocol tag recorded for a WebSocket client or for a relayed request.
///
/// Stored per client in `WsRelayState::client_protocols` and per request in
/// `PendingRequest::origin_protocol`.
///
/// NOTE(review): the variant names suggest JSON- and MessagePack-encoded
/// hashtree messages; the encoding logic itself is not in this file, so
/// confirm against the WebSocket handler.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsProtocol {
    /// Hashtree protocol, JSON variant (per the name).
    HashtreeJson,
    /// Hashtree protocol, MessagePack variant (per the name).
    HashtreeMsgpack,
    /// Protocol has not been identified.
    Unknown,
}
22
23pub struct PendingRequest {
24    pub origin_id: u64,
25    pub hash: String,
26    pub found: bool,
27    pub origin_protocol: WsProtocol,
28}
29
30pub struct WsRelayState {
31    pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
32    pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
33    pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
34    pub next_client_id: AtomicU64,
35    pub next_request_id: AtomicU32,
36}
37
38impl WsRelayState {
39    pub fn new() -> Self {
40        Self {
41            clients: Mutex::new(HashMap::new()),
42            pending: Mutex::new(HashMap::new()),
43            client_protocols: Mutex::new(HashMap::new()),
44            next_client_id: AtomicU64::new(1),
45            next_request_id: AtomicU32::new(1),
46        }
47    }
48
49    pub fn next_id(&self) -> u64 {
50        self.next_client_id.fetch_add(1, Ordering::SeqCst)
51    }
52
53    pub fn next_request_id(&self) -> u32 {
54        self.next_request_id.fetch_add(1, Ordering::SeqCst)
55    }
56}
57
58#[derive(Clone)]
59pub struct AppState {
60    pub store: Arc<HashtreeStore>,
61    pub auth: Option<AuthCredentials>,
62    /// WebRTC peer state for forwarding requests to connected P2P peers
63    pub webrtc_peers: Option<Arc<WebRTCState>>,
64    /// WebSocket relay state for /ws clients
65    pub ws_relay: Arc<WsRelayState>,
66    /// Maximum upload size in bytes for Blossom uploads (default: 5 MB)
67    pub max_upload_bytes: usize,
68    /// Allow anyone with valid Nostr auth to write (default: true)
69    /// When false, only allowed_pubkeys can write
70    pub public_writes: bool,
71    /// Pubkeys allowed to write (hex format, from config allowed_npubs)
72    pub allowed_pubkeys: HashSet<String>,
73    /// Upstream Blossom servers for cascade fetching
74    pub upstream_blossom: Vec<String>,
75    /// Social graph access control (nostrdb-backed when feature enabled)
76    pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
77    /// Nostr relay state for /ws and WebRTC Nostr messages
78    pub nostr_relay: Option<Arc<NostrRelay>>,
79}
80
/// Username/password pair that `auth_middleware` checks incoming HTTP Basic
/// Auth headers against.
///
/// Both values are kept as plain strings and compared as `username:password`.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected user name (the part before the `:` in the decoded header).
    pub username: String,
    /// Expected password (the part after the `:` in the decoded header).
    pub password: String,
}
86
87/// Auth middleware - validates HTTP Basic Auth
88pub async fn auth_middleware(
89    State(state): State<AppState>,
90    request: Request<Body>,
91    next: Next,
92) -> Result<Response<Body>, StatusCode> {
93    // If auth is not enabled, allow request
94    let Some(auth) = &state.auth else {
95        return Ok(next.run(request).await);
96    };
97
98    // Check Authorization header
99    let auth_header = request
100        .headers()
101        .get(header::AUTHORIZATION)
102        .and_then(|v| v.to_str().ok());
103
104    let authorized = if let Some(header_value) = auth_header {
105        if let Some(credentials) = header_value.strip_prefix("Basic ") {
106            use base64::Engine;
107            let engine = base64::engine::general_purpose::STANDARD;
108            if let Ok(decoded) = engine.decode(credentials) {
109                if let Ok(decoded_str) = String::from_utf8(decoded) {
110                    let expected = format!("{}:{}", auth.username, auth.password);
111                    decoded_str == expected
112                } else {
113                    false
114                }
115            } else {
116                false
117            }
118        } else {
119            false
120        }
121    } else {
122        false
123    };
124
125    if authorized {
126        Ok(next.run(request).await)
127    } else {
128        Ok(Response::builder()
129            .status(StatusCode::UNAUTHORIZED)
130            .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
131            .body(Body::from("Unauthorized"))
132            .unwrap())
133    }
134}