1use crate::nostr_relay::NostrRelay;
2use crate::socialgraph;
3use crate::storage::HashtreeStore;
4use crate::webrtc::{PeerRootEvent, WebRTCState};
5use axum::{
6 body::Body,
7 extract::ws::Message,
8 extract::State,
9 http::{header, Request, Response, StatusCode},
10 middleware::Next,
11};
12use futures::future::{BoxFuture, Shared};
13use hashtree_core::{Cid, LinkType, TreeEntry};
14use lru::LruCache;
15use std::collections::{HashMap, HashSet};
16use std::hash::Hash;
17use std::num::NonZeroUsize;
18use std::sync::{
19 atomic::{AtomicU32, AtomicU64, Ordering},
20 Arc, Mutex as StdMutex,
21};
22use std::time::{Duration, Instant};
23use tokio::{
24 sync::{mpsc, watch, Mutex},
25 task::JoinHandle,
26};
27
/// Maximum number of entries held by each cache created with `new_lookup_cache`.
const LOOKUP_CACHE_CAPACITY: usize = 4096;
/// Time-to-live applied to cached `LookupResult::Hit` values (five minutes).
const LOOKUP_CACHE_HIT_TTL: Duration = Duration::from_secs(5 * 60);
/// Time-to-live applied to cached `LookupResult::Miss` values (one second).
const LOOKUP_CACHE_MISS_TTL: Duration = Duration::from_secs(1);

/// Outcome of a cacheable lookup: either a found value or a recorded absence.
///
/// Both outcomes can be stored in a cache; `ttl` gives misses a much shorter
/// lifetime than hits.
#[derive(Debug, Clone)]
pub enum LookupResult<T> {
    /// The lookup produced a value.
    Hit(T),
    /// The lookup produced nothing.
    Miss,
}

impl<T> LookupResult<T> {
    /// Builds a result from an `Option`: `Some(v)` becomes `Hit(v)`, `None` becomes `Miss`.
    pub fn from_option(value: Option<T>) -> Self {
        value.map_or(Self::Miss, Self::Hit)
    }

    /// Inverse of `from_option`: `Hit(v)` becomes `Some(v)`, `Miss` becomes `None`.
    pub fn into_option(self) -> Option<T> {
        if let Self::Hit(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// How long this result should stay cached: `LOOKUP_CACHE_HIT_TTL` for hits,
    /// `LOOKUP_CACHE_MISS_TTL` for misses.
    pub fn ttl(&self) -> Duration {
        if matches!(self, Self::Hit(_)) {
            LOOKUP_CACHE_HIT_TTL
        } else {
            LOOKUP_CACHE_MISS_TTL
        }
    }
}
60
61pub struct TimedLruCache<K, V> {
62 cache: LruCache<K, TimedValue<V>>,
63}
64
65#[derive(Clone)]
66struct TimedValue<V> {
67 value: V,
68 expires_at: Instant,
69}
70
71impl<K: Eq + Hash, V: Clone> TimedLruCache<K, V> {
72 pub fn new(capacity: usize) -> Self {
73 Self {
74 cache: LruCache::new(NonZeroUsize::new(capacity.max(1)).unwrap()),
75 }
76 }
77
78 pub fn get_cloned(&mut self, key: &K) -> Option<V> {
79 let now = Instant::now();
80 if let Some(entry) = self.cache.get(key) {
81 if entry.expires_at > now {
82 return Some(entry.value.clone());
83 }
84 }
85 self.cache.pop(key);
86 None
87 }
88
89 pub fn put(&mut self, key: K, value: V, ttl: Duration) {
90 self.cache.put(
91 key,
92 TimedValue {
93 value,
94 expires_at: Instant::now() + ttl,
95 },
96 );
97 }
98}
99
100pub fn new_lookup_cache<K: Eq + Hash, V: Clone>() -> TimedLruCache<K, V> {
101 TimedLruCache::new(LOOKUP_CACHE_CAPACITY)
102}
103
/// Cached outcome of resolving a path to a tree node: the CID it resolved to and
/// its `LinkType`.
///
/// Stored (wrapped in `LookupResult`) in `AppState::resolved_path_cache`.
#[derive(Debug, Clone)]
pub struct CachedResolvedPathEntry {
    /// CID the path resolved to.
    pub cid: Cid,
    /// Link type of the resolved entry — presumably file vs. directory/tree; confirm in `hashtree_core`.
    pub link_type: LinkType,
}

/// Cached tree root, stored in `AppState::tree_root_cache`.
#[derive(Debug, Clone)]
pub struct CachedTreeRootEntry {
    /// Root CID of the tree.
    pub cid: Cid,
    /// Static label for where this root was obtained — NOTE(review): exact set of values not visible here.
    pub source: &'static str,
    /// Peer root event backing this root, when one is available.
    pub root_event: Option<PeerRootEvent>,
}

/// A boxed, clonable (`Shared`) future resolving to `bool`.
///
/// `AppState::inflight_blob_fetches` maps keys to these so that several callers can
/// await one in-flight fetch. The `bool` presumably reports whether the blob was
/// fetched successfully — TODO confirm at the call site that builds the future.
pub type SharedBlobFetch = Shared<BoxFuture<'static, bool>>;
118
/// Message encoding associated with a WebSocket client.
///
/// Recorded per client in `WsRelayState::client_protocols` and per request in
/// `PendingRequest::origin_protocol`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsProtocol {
    /// Hashtree messages encoded as JSON.
    HashtreeJson,
    /// Hashtree messages encoded as MessagePack.
    HashtreeMsgpack,
    /// Protocol not (yet) determined.
    Unknown,
}

/// Bookkeeping for a request that is waiting on a response, stored in
/// `WsRelayState::pending`.
pub struct PendingRequest {
    /// ID of the client that originated the request — presumably a key of `WsRelayState::clients`.
    pub origin_id: u64,
    /// Hash being requested, as a string (encoding not visible here — TODO confirm, e.g. hex).
    pub hash: String,
    /// Whether the request has been satisfied — NOTE(review): confirm who sets this.
    pub found: bool,
    /// Encoding the originating client uses, so a reply can be encoded to match.
    pub origin_protocol: WsProtocol,
}

/// Handle to an upstream Nostr subscription opened on behalf of a client.
pub struct UpstreamNostrSubscription {
    /// Watch channel sender — presumably set to `true` to tell `tasks` to close the subscription.
    pub close_tx: watch::Sender<bool>,
    /// Spawned tasks servicing this subscription.
    pub tasks: Vec<JoinHandle<()>>,
}
137
/// Shared bookkeeping for connected WebSocket clients and the upstream Nostr
/// subscriptions opened for them.
///
/// The maps sit behind `tokio::sync::Mutex`. The ID and bandwidth counters are
/// atomics and need no lock.
pub struct WsRelayState {
    /// Outbound message channel per connected client, keyed by client ID.
    pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
    /// Requests awaiting a response, keyed by `(u64, u32)` — presumably (client ID, request ID); confirm.
    pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
    /// Encoding recorded for each client ID.
    pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
    /// Open upstream subscriptions keyed by `(u64, String)` — presumably (client ID, subscription ID).
    pub upstream_nostr_subscriptions: Mutex<HashMap<(u64, String), UpstreamNostrSubscription>>,
    /// Set of event IDs per subscription key — NOTE(review): presumably used to de-duplicate events arriving from several upstream relays.
    pub upstream_seen_events: Mutex<HashMap<(u64, String), HashSet<String>>>,
    /// Count per subscription key — NOTE(review): presumably upstream relays that have not yet sent EOSE.
    pub upstream_pending_eose: Mutex<HashMap<(u64, String), usize>>,
    /// Counter from which client IDs are allocated.
    pub next_client_id: AtomicU64,
    /// Counter from which request IDs are allocated.
    pub next_request_id: AtomicU32,
    /// Running total of bytes sent to upstream relays.
    pub upstream_relay_bytes_sent: AtomicU64,
    /// Running total of bytes received from upstream relays.
    pub upstream_relay_bytes_received: AtomicU64,
}
150
151impl WsRelayState {
152 pub fn new() -> Self {
153 Self {
154 clients: Mutex::new(HashMap::new()),
155 pending: Mutex::new(HashMap::new()),
156 client_protocols: Mutex::new(HashMap::new()),
157 upstream_nostr_subscriptions: Mutex::new(HashMap::new()),
158 upstream_seen_events: Mutex::new(HashMap::new()),
159 upstream_pending_eose: Mutex::new(HashMap::new()),
160 next_client_id: AtomicU64::new(1),
161 next_request_id: AtomicU32::new(1),
162 upstream_relay_bytes_sent: AtomicU64::new(0),
163 upstream_relay_bytes_received: AtomicU64::new(0),
164 }
165 }
166
167 pub fn next_id(&self) -> u64 {
168 self.next_client_id.fetch_add(1, Ordering::SeqCst)
169 }
170
171 pub fn next_request_id(&self) -> u32 {
172 self.next_request_id.fetch_add(1, Ordering::SeqCst)
173 }
174
175 pub fn note_upstream_relay_send(&self, bytes: usize) {
176 self.upstream_relay_bytes_sent
177 .fetch_add(bytes as u64, Ordering::Relaxed);
178 }
179
180 pub fn note_upstream_relay_receive(&self, bytes: usize) {
181 self.upstream_relay_bytes_received
182 .fetch_add(bytes as u64, Ordering::Relaxed);
183 }
184
185 pub fn upstream_relay_bandwidth(&self) -> (u64, u64) {
186 (
187 self.upstream_relay_bytes_sent.load(Ordering::Relaxed),
188 self.upstream_relay_bytes_received.load(Ordering::Relaxed),
189 )
190 }
191}
192
/// Shared state handed to every HTTP handler and middleware.
///
/// Cloning is cheap for the `Arc` fields. The plain `Vec`/`HashSet` fields are
/// deep-copied on each clone.
#[derive(Clone)]
pub struct AppState {
    /// Content store backing the server.
    pub store: Arc<HashtreeStore>,
    /// When `Some`, `auth_middleware` requires matching HTTP Basic credentials;
    /// when `None`, every request passes through unauthenticated.
    pub auth: Option<AuthCredentials>,
    /// WebRTC peer state, if configured.
    pub webrtc_peers: Option<Arc<WebRTCState>>,
    /// Shared WebSocket relay bookkeeping (clients, pending requests, upstream subscriptions).
    pub ws_relay: Arc<WsRelayState>,
    /// Upload size limit in bytes — presumably enforced by the upload handlers; TODO confirm.
    pub max_upload_bytes: usize,
    /// NOTE(review): presumably allows writes from any pubkey when `true` — confirm against the write handlers.
    pub public_writes: bool,
    /// Pubkeys (as strings) granted access — presumably write access; confirm encoding (hex/npub) and usage.
    pub allowed_pubkeys: HashSet<String>,
    /// Upstream Blossom server URLs — presumably consulted for blobs missing locally; confirm.
    pub upstream_blossom: Vec<String>,
    /// Social-graph-based access control, if enabled.
    pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
    /// Backend storing social graph data, if enabled.
    pub social_graph_store: Option<Arc<dyn socialgraph::SocialGraphBackend>>,
    /// 32-byte social graph root — presumably the root user's public key; confirm.
    pub social_graph_root: Option<[u8; 32]>,
    /// NOTE(review): presumably controls whether social graph snapshots are served without auth — confirm.
    pub socialgraph_snapshot_public: bool,
    /// Embedded Nostr relay, if enabled.
    pub nostr_relay: Option<Arc<NostrRelay>>,
    /// Nostr relay URLs this server uses as upstreams.
    pub nostr_relay_urls: Vec<String>,
    /// Tree roots keyed by string. Guarded by a std (blocking) mutex, so the guard
    /// must not be held across an `.await`.
    pub tree_root_cache: Arc<StdMutex<HashMap<String, CachedTreeRootEntry>>>,
    /// Fetches currently in progress, keyed by string (presumably blob hash), so
    /// concurrent requests can await the same `SharedBlobFetch`. Uses an async mutex.
    pub inflight_blob_fetches: Arc<Mutex<HashMap<String, SharedBlobFetch>>>,
    /// TTL cache of directory listings; misses are cached too (as `LookupResult::Miss`).
    pub directory_listing_cache: Arc<StdMutex<TimedLruCache<String, LookupResult<Vec<TreeEntry>>>>>,
    /// TTL cache of path-resolution results.
    pub resolved_path_cache:
        Arc<StdMutex<TimedLruCache<String, LookupResult<CachedResolvedPathEntry>>>>,
    /// TTL cache mapping a key to a thumbnail path string.
    pub thumbnail_path_cache: Arc<StdMutex<TimedLruCache<String, LookupResult<String>>>>,
    /// TTL cache mapping a CID string to a size — presumably in bytes; confirm.
    pub cid_size_cache: Arc<StdMutex<TimedLruCache<String, LookupResult<u64>>>>,
}
236
/// Username/password pair checked by `auth_middleware` against HTTP Basic auth.
///
/// Does not derive `Debug`, so the password cannot be printed via `{:?}`.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected Basic-auth username.
    pub username: String,
    /// Expected Basic-auth password (held in plain text).
    pub password: String,
}
242
243pub async fn auth_middleware(
245 State(state): State<AppState>,
246 request: Request<Body>,
247 next: Next,
248) -> Result<Response<Body>, StatusCode> {
249 let Some(auth) = &state.auth else {
251 return Ok(next.run(request).await);
252 };
253
254 let auth_header = request
256 .headers()
257 .get(header::AUTHORIZATION)
258 .and_then(|v| v.to_str().ok());
259
260 let authorized = if let Some(header_value) = auth_header {
261 if let Some(credentials) = header_value.strip_prefix("Basic ") {
262 use base64::Engine;
263 let engine = base64::engine::general_purpose::STANDARD;
264 if let Ok(decoded) = engine.decode(credentials) {
265 if let Ok(decoded_str) = String::from_utf8(decoded) {
266 let expected = format!("{}:{}", auth.username, auth.password);
267 decoded_str == expected
268 } else {
269 false
270 }
271 } else {
272 false
273 }
274 } else {
275 false
276 }
277 } else {
278 false
279 };
280
281 if authorized {
282 Ok(next.run(request).await)
283 } else {
284 Ok(Response::builder()
285 .status(StatusCode::UNAUTHORIZED)
286 .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
287 .body(Body::from("Unauthorized"))
288 .unwrap())
289 }
290}