hashtree_cli/server/
auth.rs1use crate::nostr_relay::NostrRelay;
2use crate::socialgraph;
3use crate::storage::HashtreeStore;
4use crate::webrtc::WebRTCState;
5use axum::{
6 body::Body,
7 extract::ws::Message,
8 extract::State,
9 http::{header, Request, Response, StatusCode},
10 middleware::Next,
11};
12use hashtree_core::Cid;
13use std::collections::{HashMap, HashSet};
14use std::sync::{
15 atomic::{AtomicU32, AtomicU64, Ordering},
16 Arc, Mutex as StdMutex,
17};
18use tokio::sync::{mpsc, Mutex};
19
/// Wire protocol recorded for a WebSocket client connection.
///
/// Variant order is significant for `as` casts (discriminants 0, 1, 2), so keep it stable.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum WsProtocol {
    /// Hashtree protocol, JSON flavour (per the variant name).
    HashtreeJson,
    /// Hashtree protocol, MessagePack flavour (per the variant name).
    HashtreeMsgpack,
    /// Presumably used when the client's protocol has not been identified — confirm at call sites.
    Unknown,
}
26
27pub struct PendingRequest {
28 pub origin_id: u64,
29 pub hash: String,
30 pub found: bool,
31 pub origin_protocol: WsProtocol,
32}
33
34pub struct WsRelayState {
35 pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
36 pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
37 pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
38 pub next_client_id: AtomicU64,
39 pub next_request_id: AtomicU32,
40}
41
42impl WsRelayState {
43 pub fn new() -> Self {
44 Self {
45 clients: Mutex::new(HashMap::new()),
46 pending: Mutex::new(HashMap::new()),
47 client_protocols: Mutex::new(HashMap::new()),
48 next_client_id: AtomicU64::new(1),
49 next_request_id: AtomicU32::new(1),
50 }
51 }
52
53 pub fn next_id(&self) -> u64 {
54 self.next_client_id.fetch_add(1, Ordering::SeqCst)
55 }
56
57 pub fn next_request_id(&self) -> u32 {
58 self.next_request_id.fetch_add(1, Ordering::SeqCst)
59 }
60}
61
62#[derive(Clone)]
63pub struct AppState {
64 pub store: Arc<HashtreeStore>,
65 pub auth: Option<AuthCredentials>,
66 pub webrtc_peers: Option<Arc<WebRTCState>>,
68 pub ws_relay: Arc<WsRelayState>,
70 pub max_upload_bytes: usize,
72 pub public_writes: bool,
75 pub allowed_pubkeys: HashSet<String>,
77 pub upstream_blossom: Vec<String>,
79 pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
81 pub social_graph_store: Option<Arc<dyn socialgraph::SocialGraphBackend>>,
83 pub social_graph_root: Option<[u8; 32]>,
85 pub socialgraph_snapshot_public: bool,
87 pub nostr_relay: Option<Arc<NostrRelay>>,
89 pub tree_root_cache: Arc<StdMutex<HashMap<String, Cid>>>,
91}
92
/// HTTP Basic credentials expected by `auth_middleware`.
#[derive(Clone)]
pub struct AuthCredentials {
    pub username: String,
    pub password: String,
}

// Hand-written instead of derived so the password never ends up in logs or panic messages.
impl std::fmt::Debug for AuthCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthCredentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
98
99pub async fn auth_middleware(
101 State(state): State<AppState>,
102 request: Request<Body>,
103 next: Next,
104) -> Result<Response<Body>, StatusCode> {
105 let Some(auth) = &state.auth else {
107 return Ok(next.run(request).await);
108 };
109
110 let auth_header = request
112 .headers()
113 .get(header::AUTHORIZATION)
114 .and_then(|v| v.to_str().ok());
115
116 let authorized = if let Some(header_value) = auth_header {
117 if let Some(credentials) = header_value.strip_prefix("Basic ") {
118 use base64::Engine;
119 let engine = base64::engine::general_purpose::STANDARD;
120 if let Ok(decoded) = engine.decode(credentials) {
121 if let Ok(decoded_str) = String::from_utf8(decoded) {
122 let expected = format!("{}:{}", auth.username, auth.password);
123 decoded_str == expected
124 } else {
125 false
126 }
127 } else {
128 false
129 }
130 } else {
131 false
132 }
133 } else {
134 false
135 };
136
137 if authorized {
138 Ok(next.run(request).await)
139 } else {
140 Ok(Response::builder()
141 .status(StatusCode::UNAUTHORIZED)
142 .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
143 .body(Body::from("Unauthorized"))
144 .unwrap())
145 }
146}