Skip to main content

hashtree_cli/server/auth.rs

1use crate::nostr_relay::NostrRelay;
2use crate::socialgraph;
3use crate::storage::HashtreeStore;
4use crate::webrtc::WebRTCState;
5use axum::{
6    body::Body,
7    extract::ws::Message,
8    extract::State,
9    http::{header, Request, Response, StatusCode},
10    middleware::Next,
11};
12use hashtree_core::Cid;
13use std::collections::{HashMap, HashSet};
14use std::sync::{
15    atomic::{AtomicU32, AtomicU64, Ordering},
16    Arc, Mutex as StdMutex,
17};
18use tokio::sync::{mpsc, Mutex};
19
/// Protocol variant associated with a WebSocket client.
///
/// Recorded per client in `WsRelayState::client_protocols` and copied onto
/// each `PendingRequest` as `origin_protocol`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsProtocol {
    /// Hashtree messages, presumably JSON-encoded (inferred from the name).
    HashtreeJson,
    /// Hashtree messages, presumably MessagePack-encoded (inferred from the name).
    HashtreeMsgpack,
    /// Protocol not (yet) identified.
    Unknown,
}
26
27pub struct PendingRequest {
28    pub origin_id: u64,
29    pub hash: String,
30    pub found: bool,
31    pub origin_protocol: WsProtocol,
32}
33
34pub struct WsRelayState {
35    pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
36    pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
37    pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
38    pub next_client_id: AtomicU64,
39    pub next_request_id: AtomicU32,
40}
41
42impl WsRelayState {
43    pub fn new() -> Self {
44        Self {
45            clients: Mutex::new(HashMap::new()),
46            pending: Mutex::new(HashMap::new()),
47            client_protocols: Mutex::new(HashMap::new()),
48            next_client_id: AtomicU64::new(1),
49            next_request_id: AtomicU32::new(1),
50        }
51    }
52
53    pub fn next_id(&self) -> u64 {
54        self.next_client_id.fetch_add(1, Ordering::SeqCst)
55    }
56
57    pub fn next_request_id(&self) -> u32 {
58        self.next_request_id.fetch_add(1, Ordering::SeqCst)
59    }
60}
61
62#[derive(Clone)]
63pub struct AppState {
64    pub store: Arc<HashtreeStore>,
65    pub auth: Option<AuthCredentials>,
66    /// WebRTC peer state for forwarding requests to connected P2P peers
67    pub webrtc_peers: Option<Arc<WebRTCState>>,
68    /// WebSocket relay state for /ws clients
69    pub ws_relay: Arc<WsRelayState>,
70    /// Maximum upload size in bytes for Blossom uploads (default: 5 MB)
71    pub max_upload_bytes: usize,
72    /// Allow anyone with valid Nostr auth to write (default: true)
73    /// When false, only allowed_pubkeys can write
74    pub public_writes: bool,
75    /// Pubkeys allowed to write (hex format, from config allowed_npubs)
76    pub allowed_pubkeys: HashSet<String>,
77    /// Upstream Blossom servers for cascade fetching
78    pub upstream_blossom: Vec<String>,
79    /// Social graph access control
80    pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
81    /// Social graph store handle for snapshot export
82    pub social_graph_store: Option<Arc<dyn socialgraph::SocialGraphBackend>>,
83    /// Social graph root pubkey bytes for snapshot export
84    pub social_graph_root: Option<[u8; 32]>,
85    /// Allow public access to social graph snapshot endpoint
86    pub socialgraph_snapshot_public: bool,
87    /// Nostr relay state for /ws and WebRTC Nostr messages
88    pub nostr_relay: Option<Arc<NostrRelay>>,
89    /// In-process cache for resolved mutable tree roots, keyed by npub/tree(+key)
90    pub tree_root_cache: Arc<StdMutex<HashMap<String, Cid>>>,
91}
92
/// Username/password pair that `auth_middleware` expects in an HTTP Basic
/// `Authorization` header.
///
/// Only `Clone` is derived (no `Debug`), so the password is never printed via
/// `{:?}` formatting.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected Basic auth user name.
    pub username: String,
    /// Expected Basic auth password.
    pub password: String,
}
98
99/// Auth middleware - validates HTTP Basic Auth
100pub async fn auth_middleware(
101    State(state): State<AppState>,
102    request: Request<Body>,
103    next: Next,
104) -> Result<Response<Body>, StatusCode> {
105    // If auth is not enabled, allow request
106    let Some(auth) = &state.auth else {
107        return Ok(next.run(request).await);
108    };
109
110    // Check Authorization header
111    let auth_header = request
112        .headers()
113        .get(header::AUTHORIZATION)
114        .and_then(|v| v.to_str().ok());
115
116    let authorized = if let Some(header_value) = auth_header {
117        if let Some(credentials) = header_value.strip_prefix("Basic ") {
118            use base64::Engine;
119            let engine = base64::engine::general_purpose::STANDARD;
120            if let Ok(decoded) = engine.decode(credentials) {
121                if let Ok(decoded_str) = String::from_utf8(decoded) {
122                    let expected = format!("{}:{}", auth.username, auth.password);
123                    decoded_str == expected
124                } else {
125                    false
126                }
127            } else {
128                false
129            }
130        } else {
131            false
132        }
133    } else {
134        false
135    };
136
137    if authorized {
138        Ok(next.run(request).await)
139    } else {
140        Ok(Response::builder()
141            .status(StatusCode::UNAUTHORIZED)
142            .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
143            .body(Body::from("Unauthorized"))
144            .unwrap())
145    }
146}