Skip to main content

hashtree_cli/server/
auth.rs

1use crate::nostr_relay::NostrRelay;
2use crate::socialgraph;
3use crate::storage::HashtreeStore;
4use crate::webrtc::{PeerRootEvent, WebRTCState};
5use axum::{
6    body::Body,
7    extract::ws::Message,
8    extract::State,
9    http::{header, Request, Response, StatusCode},
10    middleware::Next,
11};
12use futures::future::{BoxFuture, Shared};
13use hashtree_core::{Cid, LinkType, TreeEntry};
14use lru::LruCache;
15use std::collections::{HashMap, HashSet};
16use std::hash::Hash;
17use std::num::NonZeroUsize;
18use std::sync::{
19    atomic::{AtomicU32, AtomicU64, Ordering},
20    Arc, Mutex as StdMutex,
21};
22use std::time::{Duration, Instant};
23use tokio::{
24    sync::{mpsc, watch, Mutex},
25    task::JoinHandle,
26};
27
/// Maximum number of entries held by each cache built with [`new_lookup_cache`].
const LOOKUP_CACHE_CAPACITY: usize = 4_096;
/// How long a positive lookup ([`LookupResult::Hit`]) stays cached: five minutes.
const LOOKUP_CACHE_HIT_TTL: Duration = Duration::from_secs(5 * 60);
/// How long a negative lookup ([`LookupResult::Miss`]) stays cached: one second.
const LOOKUP_CACHE_MISS_TTL: Duration = Duration::from_secs(1);

/// Outcome of a cacheable lookup.
///
/// Misses are represented explicitly so that a negative result can be cached as well,
/// with a much shorter TTL than a positive one (see [`LookupResult::ttl`]).
#[derive(Clone, Debug)]
pub enum LookupResult<T> {
    /// The lookup produced a value.
    Hit(T),
    /// The lookup produced nothing.
    Miss,
}

impl<T> LookupResult<T> {
    /// Builds a result from an `Option`: `Some(v)` maps to `Hit(v)`, `None` maps to `Miss`.
    pub fn from_option(value: Option<T>) -> Self {
        value.map_or(Self::Miss, Self::Hit)
    }

    /// Converts back into an `Option`; the inverse of [`LookupResult::from_option`].
    pub fn into_option(self) -> Option<T> {
        if let Self::Hit(found) = self {
            Some(found)
        } else {
            None
        }
    }

    /// Time-to-live to apply when this result is stored in a lookup cache.
    pub fn ttl(&self) -> Duration {
        if matches!(self, Self::Hit(_)) {
            LOOKUP_CACHE_HIT_TTL
        } else {
            LOOKUP_CACHE_MISS_TTL
        }
    }
}
60
61pub struct TimedLruCache<K, V> {
62    cache: LruCache<K, TimedValue<V>>,
63}
64
65#[derive(Clone)]
66struct TimedValue<V> {
67    value: V,
68    expires_at: Instant,
69}
70
71impl<K: Eq + Hash, V: Clone> TimedLruCache<K, V> {
72    pub fn new(capacity: usize) -> Self {
73        Self {
74            cache: LruCache::new(NonZeroUsize::new(capacity.max(1)).unwrap()),
75        }
76    }
77
78    pub fn get_cloned(&mut self, key: &K) -> Option<V> {
79        let now = Instant::now();
80        if let Some(entry) = self.cache.get(key) {
81            if entry.expires_at > now {
82                return Some(entry.value.clone());
83            }
84        }
85        self.cache.pop(key);
86        None
87    }
88
89    pub fn put(&mut self, key: K, value: V, ttl: Duration) {
90        self.cache.put(
91            key,
92            TimedValue {
93                value,
94                expires_at: Instant::now() + ttl,
95            },
96        );
97    }
98}
99
100pub fn new_lookup_cache<K: Eq + Hash, V: Clone>() -> TimedLruCache<K, V> {
101    TimedLruCache::new(LOOKUP_CACHE_CAPACITY)
102}
103
104#[derive(Debug, Clone)]
105pub struct CachedResolvedPathEntry {
106    pub cid: Cid,
107    pub link_type: LinkType,
108}
109
110#[derive(Debug, Clone)]
111pub struct CachedTreeRootEntry {
112    pub cid: Cid,
113    pub source: &'static str,
114    pub root_event: Option<PeerRootEvent>,
115}
116
117pub type SharedBlobFetch = Shared<BoxFuture<'static, bool>>;
118
/// Message protocol recorded for a `/ws` client (tracked per client in
/// `WsRelayState::client_protocols`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsProtocol {
    /// Hashtree protocol with JSON framing (per the variant name).
    HashtreeJson,
    /// Hashtree protocol with MessagePack framing (per the variant name).
    HashtreeMsgpack,
    /// Protocol not identified.
    Unknown,
}
125
126pub struct PendingRequest {
127    pub origin_id: u64,
128    pub hash: String,
129    pub found: bool,
130    pub origin_protocol: WsProtocol,
131}
132
133pub struct UpstreamNostrSubscription {
134    pub close_tx: watch::Sender<bool>,
135    pub tasks: Vec<JoinHandle<()>>,
136}
137
138pub struct WsRelayState {
139    pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
140    pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
141    pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
142    pub upstream_nostr_subscriptions: Mutex<HashMap<(u64, String), UpstreamNostrSubscription>>,
143    pub upstream_seen_events: Mutex<HashMap<(u64, String), HashSet<String>>>,
144    pub upstream_pending_eose: Mutex<HashMap<(u64, String), usize>>,
145    pub next_client_id: AtomicU64,
146    pub next_request_id: AtomicU32,
147    pub upstream_relay_bytes_sent: AtomicU64,
148    pub upstream_relay_bytes_received: AtomicU64,
149}
150
151impl WsRelayState {
152    pub fn new() -> Self {
153        Self {
154            clients: Mutex::new(HashMap::new()),
155            pending: Mutex::new(HashMap::new()),
156            client_protocols: Mutex::new(HashMap::new()),
157            upstream_nostr_subscriptions: Mutex::new(HashMap::new()),
158            upstream_seen_events: Mutex::new(HashMap::new()),
159            upstream_pending_eose: Mutex::new(HashMap::new()),
160            next_client_id: AtomicU64::new(1),
161            next_request_id: AtomicU32::new(1),
162            upstream_relay_bytes_sent: AtomicU64::new(0),
163            upstream_relay_bytes_received: AtomicU64::new(0),
164        }
165    }
166
167    pub fn next_id(&self) -> u64 {
168        self.next_client_id.fetch_add(1, Ordering::SeqCst)
169    }
170
171    pub fn next_request_id(&self) -> u32 {
172        self.next_request_id.fetch_add(1, Ordering::SeqCst)
173    }
174
175    pub fn note_upstream_relay_send(&self, bytes: usize) {
176        self.upstream_relay_bytes_sent
177            .fetch_add(bytes as u64, Ordering::Relaxed);
178    }
179
180    pub fn note_upstream_relay_receive(&self, bytes: usize) {
181        self.upstream_relay_bytes_received
182            .fetch_add(bytes as u64, Ordering::Relaxed);
183    }
184
185    pub fn upstream_relay_bandwidth(&self) -> (u64, u64) {
186        (
187            self.upstream_relay_bytes_sent.load(Ordering::Relaxed),
188            self.upstream_relay_bytes_received.load(Ordering::Relaxed),
189        )
190    }
191}
192
193#[derive(Clone)]
194pub struct AppState {
195    pub store: Arc<HashtreeStore>,
196    pub auth: Option<AuthCredentials>,
197    pub peer_mode: crate::config::ServerMode,
198    pub hash_get_enabled: bool,
199    /// WebRTC peer state for forwarding requests to connected P2P peers
200    pub webrtc_peers: Option<Arc<WebRTCState>>,
201    /// WebSocket relay state for /ws clients
202    pub ws_relay: Arc<WsRelayState>,
203    /// Maximum upload size in bytes for Blossom uploads (default: 5 MB)
204    pub max_upload_bytes: usize,
205    /// Allow anyone with valid Nostr auth to write (default: true)
206    /// When false, only allowed_pubkeys can write
207    pub public_writes: bool,
208    /// Pubkeys allowed to write (hex format, from config allowed_npubs)
209    pub allowed_pubkeys: HashSet<String>,
210    /// Upstream Blossom servers for cascade fetching
211    pub upstream_blossom: Vec<String>,
212    /// Social graph access control
213    pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
214    /// Social graph store handle for snapshot export
215    pub social_graph_store: Option<Arc<dyn socialgraph::SocialGraphBackend>>,
216    /// Social graph root pubkey bytes for snapshot export
217    pub social_graph_root: Option<[u8; 32]>,
218    /// Allow public access to social graph snapshot endpoint
219    pub socialgraph_snapshot_public: bool,
220    /// Nostr relay state for /ws and WebRTC Nostr messages
221    pub nostr_relay: Option<Arc<NostrRelay>>,
222    /// Active upstream Nostr relays for HTTP resolver operations.
223    pub nostr_relay_urls: Vec<String>,
224    /// In-process cache for resolved mutable tree roots, keyed by npub/tree(+key)
225    pub tree_root_cache: Arc<StdMutex<HashMap<String, CachedTreeRootEntry>>>,
226    /// Shared in-flight blob fetches so concurrent misses only hit upstream once per hash
227    pub inflight_blob_fetches: Arc<Mutex<HashMap<String, SharedBlobFetch>>>,
228    /// Immutable directory listings keyed by CID
229    pub directory_listing_cache: Arc<StdMutex<TimedLruCache<String, LookupResult<Vec<TreeEntry>>>>>,
230    /// Immutable resolved paths keyed by root CID + path
231    pub resolved_path_cache:
232        Arc<StdMutex<TimedLruCache<String, LookupResult<CachedResolvedPathEntry>>>>,
233    /// Immutable thumbnail alias resolutions keyed by root CID + alias path
234    pub thumbnail_path_cache: Arc<StdMutex<TimedLruCache<String, LookupResult<String>>>>,
235    /// Immutable file sizes keyed by CID
236    pub cid_size_cache: Arc<StdMutex<TimedLruCache<String, LookupResult<u64>>>>,
237}
238
/// Username and password that `auth_middleware` expects in an HTTP Basic
/// `Authorization` header.
///
/// Only `Clone` is derived here (no `Debug` derive), so the password is not printable via
/// `{:?}` through a derived impl.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected username.
    pub username: String,
    /// Expected password.
    pub password: String,
}
244
245/// Auth middleware - validates HTTP Basic Auth
246pub async fn auth_middleware(
247    State(state): State<AppState>,
248    request: Request<Body>,
249    next: Next,
250) -> Result<Response<Body>, StatusCode> {
251    // If auth is not enabled, allow request
252    let Some(auth) = &state.auth else {
253        return Ok(next.run(request).await);
254    };
255
256    // Check Authorization header
257    let auth_header = request
258        .headers()
259        .get(header::AUTHORIZATION)
260        .and_then(|v| v.to_str().ok());
261
262    let authorized = if let Some(header_value) = auth_header {
263        if let Some(credentials) = header_value.strip_prefix("Basic ") {
264            use base64::Engine;
265            let engine = base64::engine::general_purpose::STANDARD;
266            if let Ok(decoded) = engine.decode(credentials) {
267                if let Ok(decoded_str) = String::from_utf8(decoded) {
268                    let expected = format!("{}:{}", auth.username, auth.password);
269                    decoded_str == expected
270                } else {
271                    false
272                }
273            } else {
274                false
275            }
276        } else {
277            false
278        }
279    } else {
280        false
281    };
282
283    if authorized {
284        Ok(next.run(request).await)
285    } else {
286        Ok(Response::builder()
287            .status(StatusCode::UNAUTHORIZED)
288            .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
289            .body(Body::from("Unauthorized"))
290            .unwrap())
291    }
292}