// hashtree_cli/server/auth.rs

1use crate::nostr_relay::NostrRelay;
2use crate::socialgraph;
3use crate::storage::HashtreeStore;
4use crate::webrtc::WebRTCState;
5use axum::{
6    body::Body,
7    extract::ws::Message,
8    extract::State,
9    http::{header, Request, Response, StatusCode},
10    middleware::Next,
11};
12use std::collections::{HashMap, HashSet};
13use std::sync::{
14    atomic::{AtomicU32, AtomicU64, Ordering},
15    Arc,
16};
17use tokio::sync::{mpsc, Mutex};
18
/// Encoding spoken by a client connected to the `/ws` endpoint.
///
/// The variant names suggest JSON- and MessagePack-framed hashtree messages;
/// `Unknown` stands for a client whose encoding is not (yet) known.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsProtocol {
    /// Hashtree frames, presumably JSON-encoded.
    HashtreeJson,
    /// Hashtree frames, presumably MessagePack-encoded.
    HashtreeMsgpack,
    /// Encoding not determined.
    Unknown,
}
25
26pub struct PendingRequest {
27    pub origin_id: u64,
28    pub hash: String,
29    pub found: bool,
30    pub origin_protocol: WsProtocol,
31}
32
33pub struct WsRelayState {
34    pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
35    pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
36    pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
37    pub next_client_id: AtomicU64,
38    pub next_request_id: AtomicU32,
39}
40
41impl WsRelayState {
42    pub fn new() -> Self {
43        Self {
44            clients: Mutex::new(HashMap::new()),
45            pending: Mutex::new(HashMap::new()),
46            client_protocols: Mutex::new(HashMap::new()),
47            next_client_id: AtomicU64::new(1),
48            next_request_id: AtomicU32::new(1),
49        }
50    }
51
52    pub fn next_id(&self) -> u64 {
53        self.next_client_id.fetch_add(1, Ordering::SeqCst)
54    }
55
56    pub fn next_request_id(&self) -> u32 {
57        self.next_request_id.fetch_add(1, Ordering::SeqCst)
58    }
59}
60
61#[derive(Clone)]
62pub struct AppState {
63    pub store: Arc<HashtreeStore>,
64    pub auth: Option<AuthCredentials>,
65    /// WebRTC peer state for forwarding requests to connected P2P peers
66    pub webrtc_peers: Option<Arc<WebRTCState>>,
67    /// WebSocket relay state for /ws clients
68    pub ws_relay: Arc<WsRelayState>,
69    /// Maximum upload size in bytes for Blossom uploads (default: 5 MB)
70    pub max_upload_bytes: usize,
71    /// Allow anyone with valid Nostr auth to write (default: true)
72    /// When false, only allowed_pubkeys can write
73    pub public_writes: bool,
74    /// Pubkeys allowed to write (hex format, from config allowed_npubs)
75    pub allowed_pubkeys: HashSet<String>,
76    /// Upstream Blossom servers for cascade fetching
77    pub upstream_blossom: Vec<String>,
78    /// Social graph access control (nostrdb-backed when feature enabled)
79    pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
80    /// Social graph nostrdb handle for snapshot export
81    pub social_graph_ndb: Option<Arc<socialgraph::Ndb>>,
82    /// Social graph root pubkey bytes for snapshot export
83    pub social_graph_root: Option<[u8; 32]>,
84    /// Allow public access to social graph snapshot endpoint
85    pub socialgraph_snapshot_public: bool,
86    /// Nostr relay state for /ws and WebRTC Nostr messages
87    pub nostr_relay: Option<Arc<NostrRelay>>,
88}
89
/// Username/password pair checked against incoming HTTP Basic credentials.
///
/// NOTE(review): this type does not derive `Debug`. Keep it that way, or
/// write a redacting impl, so the password cannot leak through `{:?}`.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected user id (the part before the first `:` in Basic auth).
    pub username: String,
    /// Expected password (everything after that `:`).
    pub password: String,
}
95
96/// Auth middleware - validates HTTP Basic Auth
97pub async fn auth_middleware(
98    State(state): State<AppState>,
99    request: Request<Body>,
100    next: Next,
101) -> Result<Response<Body>, StatusCode> {
102    // If auth is not enabled, allow request
103    let Some(auth) = &state.auth else {
104        return Ok(next.run(request).await);
105    };
106
107    // Check Authorization header
108    let auth_header = request
109        .headers()
110        .get(header::AUTHORIZATION)
111        .and_then(|v| v.to_str().ok());
112
113    let authorized = if let Some(header_value) = auth_header {
114        if let Some(credentials) = header_value.strip_prefix("Basic ") {
115            use base64::Engine;
116            let engine = base64::engine::general_purpose::STANDARD;
117            if let Ok(decoded) = engine.decode(credentials) {
118                if let Ok(decoded_str) = String::from_utf8(decoded) {
119                    let expected = format!("{}:{}", auth.username, auth.password);
120                    decoded_str == expected
121                } else {
122                    false
123                }
124            } else {
125                false
126            }
127        } else {
128            false
129        }
130    } else {
131        false
132    };
133
134    if authorized {
135        Ok(next.run(request).await)
136    } else {
137        Ok(Response::builder()
138            .status(StatusCode::UNAUTHORIZED)
139            .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
140            .body(Body::from("Unauthorized"))
141            .unwrap())
142    }
143}