hashtree_cli/server/
auth.rs1use crate::nostr_relay::NostrRelay;
2use crate::socialgraph;
3use crate::storage::HashtreeStore;
4use crate::webrtc::WebRTCState;
5use axum::{
6 body::Body,
7 extract::ws::Message,
8 extract::State,
9 http::{header, Request, Response, StatusCode},
10 middleware::Next,
11};
12use std::collections::{HashMap, HashSet};
13use std::sync::{
14 atomic::{AtomicU32, AtomicU64, Ordering},
15 Arc,
16};
17use tokio::sync::{mpsc, Mutex};
18
/// Wire protocol a WebSocket client has been identified as speaking.
///
/// `Copy` + `Eq` so it can be stored per client and compared cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsProtocol {
    /// Hashtree protocol, JSON flavour (per the variant name).
    HashtreeJson,
    /// Hashtree protocol, MessagePack flavour (per the variant name).
    HashtreeMsgpack,
    /// Protocol not identified.
    Unknown,
}
25
26pub struct PendingRequest {
27 pub origin_id: u64,
28 pub hash: String,
29 pub found: bool,
30 pub origin_protocol: WsProtocol,
31}
32
33pub struct WsRelayState {
34 pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
35 pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
36 pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
37 pub next_client_id: AtomicU64,
38 pub next_request_id: AtomicU32,
39}
40
41impl WsRelayState {
42 pub fn new() -> Self {
43 Self {
44 clients: Mutex::new(HashMap::new()),
45 pending: Mutex::new(HashMap::new()),
46 client_protocols: Mutex::new(HashMap::new()),
47 next_client_id: AtomicU64::new(1),
48 next_request_id: AtomicU32::new(1),
49 }
50 }
51
52 pub fn next_id(&self) -> u64 {
53 self.next_client_id.fetch_add(1, Ordering::SeqCst)
54 }
55
56 pub fn next_request_id(&self) -> u32 {
57 self.next_request_id.fetch_add(1, Ordering::SeqCst)
58 }
59}
60
61#[derive(Clone)]
62pub struct AppState {
63 pub store: Arc<HashtreeStore>,
64 pub auth: Option<AuthCredentials>,
65 pub webrtc_peers: Option<Arc<WebRTCState>>,
67 pub ws_relay: Arc<WsRelayState>,
69 pub max_upload_bytes: usize,
71 pub public_writes: bool,
74 pub allowed_pubkeys: HashSet<String>,
76 pub upstream_blossom: Vec<String>,
78 pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
80 pub social_graph_ndb: Option<Arc<socialgraph::Ndb>>,
82 pub social_graph_root: Option<[u8; 32]>,
84 pub socialgraph_snapshot_public: bool,
86 pub nostr_relay: Option<Arc<NostrRelay>>,
88}
89
/// Username/password pair checked by `auth_middleware` against HTTP Basic auth.
///
/// Intentionally has no `Debug` derive; keep it that way so the password is not
/// accidentally written to logs via `{:?}`.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected Basic-auth user name.
    pub username: String,
    /// Expected Basic-auth password (held in plain text).
    pub password: String,
}
96pub async fn auth_middleware(
98 State(state): State<AppState>,
99 request: Request<Body>,
100 next: Next,
101) -> Result<Response<Body>, StatusCode> {
102 let Some(auth) = &state.auth else {
104 return Ok(next.run(request).await);
105 };
106
107 let auth_header = request
109 .headers()
110 .get(header::AUTHORIZATION)
111 .and_then(|v| v.to_str().ok());
112
113 let authorized = if let Some(header_value) = auth_header {
114 if let Some(credentials) = header_value.strip_prefix("Basic ") {
115 use base64::Engine;
116 let engine = base64::engine::general_purpose::STANDARD;
117 if let Ok(decoded) = engine.decode(credentials) {
118 if let Ok(decoded_str) = String::from_utf8(decoded) {
119 let expected = format!("{}:{}", auth.username, auth.password);
120 decoded_str == expected
121 } else {
122 false
123 }
124 } else {
125 false
126 }
127 } else {
128 false
129 }
130 } else {
131 false
132 };
133
134 if authorized {
135 Ok(next.run(request).await)
136 } else {
137 Ok(Response::builder()
138 .status(StatusCode::UNAUTHORIZED)
139 .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
140 .body(Body::from("Unauthorized"))
141 .unwrap())
142 }
143}