hashtree_cli/server/
auth.rs1use axum::{
2 body::Body,
3 extract::State,
4 http::{header, Request, Response, StatusCode},
5 middleware::Next,
6 extract::ws::Message,
7};
8use crate::socialgraph;
9use crate::nostr_relay::NostrRelay;
10use crate::storage::HashtreeStore;
11use crate::webrtc::WebRTCState;
12use std::collections::{HashMap, HashSet};
13use std::sync::{Arc, atomic::{AtomicU32, AtomicU64, Ordering}};
14use tokio::sync::{mpsc, Mutex};
15
/// Wire protocol a WebSocket client has been identified as speaking.
///
/// The variant names suggest JSON vs. MessagePack framing of hashtree messages;
/// NOTE(review): the detection logic lives elsewhere — confirm against it.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum WsProtocol {
    /// Hashtree messages, presumably JSON-encoded.
    HashtreeJson,
    /// Hashtree messages, presumably MessagePack-encoded.
    HashtreeMsgpack,
    /// Protocol not (yet) determined.
    Unknown,
}
22
23pub struct PendingRequest {
24 pub origin_id: u64,
25 pub hash: String,
26 pub found: bool,
27 pub origin_protocol: WsProtocol,
28}
29
30pub struct WsRelayState {
31 pub clients: Mutex<HashMap<u64, mpsc::UnboundedSender<Message>>>,
32 pub pending: Mutex<HashMap<(u64, u32), PendingRequest>>,
33 pub client_protocols: Mutex<HashMap<u64, WsProtocol>>,
34 pub next_client_id: AtomicU64,
35 pub next_request_id: AtomicU32,
36}
37
38impl WsRelayState {
39 pub fn new() -> Self {
40 Self {
41 clients: Mutex::new(HashMap::new()),
42 pending: Mutex::new(HashMap::new()),
43 client_protocols: Mutex::new(HashMap::new()),
44 next_client_id: AtomicU64::new(1),
45 next_request_id: AtomicU32::new(1),
46 }
47 }
48
49 pub fn next_id(&self) -> u64 {
50 self.next_client_id.fetch_add(1, Ordering::SeqCst)
51 }
52
53 pub fn next_request_id(&self) -> u32 {
54 self.next_request_id.fetch_add(1, Ordering::SeqCst)
55 }
56}
57
58#[derive(Clone)]
59pub struct AppState {
60 pub store: Arc<HashtreeStore>,
61 pub auth: Option<AuthCredentials>,
62 pub webrtc_peers: Option<Arc<WebRTCState>>,
64 pub ws_relay: Arc<WsRelayState>,
66 pub max_upload_bytes: usize,
68 pub public_writes: bool,
71 pub allowed_pubkeys: HashSet<String>,
73 pub upstream_blossom: Vec<String>,
75 pub social_graph: Option<Arc<socialgraph::SocialGraphAccessControl>>,
77 pub nostr_relay: Option<Arc<NostrRelay>>,
79}
80
/// Username/password pair expected in HTTP Basic authentication.
///
/// Only `Clone` is derived, so there is no `Debug` impl that could print the password.
#[derive(Clone)]
pub struct AuthCredentials {
    /// Expected Basic-auth username.
    pub username: String,
    /// Expected Basic-auth password (kept in plain text in memory).
    pub password: String,
}
86
87pub async fn auth_middleware(
89 State(state): State<AppState>,
90 request: Request<Body>,
91 next: Next,
92) -> Result<Response<Body>, StatusCode> {
93 let Some(auth) = &state.auth else {
95 return Ok(next.run(request).await);
96 };
97
98 let auth_header = request
100 .headers()
101 .get(header::AUTHORIZATION)
102 .and_then(|v| v.to_str().ok());
103
104 let authorized = if let Some(header_value) = auth_header {
105 if let Some(credentials) = header_value.strip_prefix("Basic ") {
106 use base64::Engine;
107 let engine = base64::engine::general_purpose::STANDARD;
108 if let Ok(decoded) = engine.decode(credentials) {
109 if let Ok(decoded_str) = String::from_utf8(decoded) {
110 let expected = format!("{}:{}", auth.username, auth.password);
111 decoded_str == expected
112 } else {
113 false
114 }
115 } else {
116 false
117 }
118 } else {
119 false
120 }
121 } else {
122 false
123 };
124
125 if authorized {
126 Ok(next.run(request).await)
127 } else {
128 Ok(Response::builder()
129 .status(StatusCode::UNAUTHORIZED)
130 .header(header::WWW_AUTHENTICATE, "Basic realm=\"hashtree\"")
131 .body(Body::from("Unauthorized"))
132 .unwrap())
133 }
134}