reddb_server/server/
ui_auth.rs1use std::net::SocketAddr;
24use std::sync::{Arc, Mutex};
25
26use axum::extract::{Path, State};
27use axum::http::{header, StatusCode};
28use axum::response::{IntoResponse, Response};
29use axum::routing::get;
30use tokio::sync::oneshot;
31
/// How the bundled web UI obtains credentials when it loads.
///
/// The chosen mode is published to the page by name (see [`UiAuthMode::as_str`]).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum UiAuthMode {
    /// A token was supplied at launch.
    Injected,
    /// No token was supplied, but the database requires authentication.
    Prompt,
    /// No token was supplied and the database does not require authentication.
    /// This is the [`Default`] mode.
    #[default]
    Open,
}
49
50impl UiAuthMode {
51 pub fn resolve(token_supplied: bool, db_auth_required: bool) -> Self {
55 match (token_supplied, db_auth_required) {
56 (true, _) => UiAuthMode::Injected,
57 (false, true) => UiAuthMode::Prompt,
58 (false, false) => UiAuthMode::Open,
59 }
60 }
61
62 pub fn as_str(self) -> &'static str {
65 match self {
66 UiAuthMode::Injected => "injected",
67 UiAuthMode::Prompt => "prompt",
68 UiAuthMode::Open => "open",
69 }
70 }
71}
72
73pub fn auth_mode_config_snippet(mode: UiAuthMode) -> String {
77 format!(
78 "<script>window.REDDB_AUTH_MODE=\"{}\";</script>",
79 mode.as_str()
80 )
81}
82
83pub fn inject_auth_mode_config(html: Vec<u8>, mode: UiAuthMode) -> Vec<u8> {
88 let snippet = auth_mode_config_snippet(mode);
89 let marker = b"</head>";
90 match html.windows(marker.len()).position(|w| w == marker) {
91 Some(pos) => {
92 let mut out = Vec::with_capacity(html.len() + snippet.len());
93 out.extend_from_slice(&html[..pos]);
94 out.extend_from_slice(snippet.as_bytes());
95 out.extend_from_slice(&html[pos..]);
96 out
97 }
98 None => html,
99 }
100}
101
102pub fn new_handoff_nonce() -> String {
110 let mut bytes = [0u8; 16];
111 if crate::crypto::os_random::fill_bytes(&mut bytes).is_err() {
114 let seed = (&bytes as *const _ as usize) as u64;
115 bytes[..8].copy_from_slice(&seed.to_le_bytes());
116 }
117 let mut out = String::with_capacity(32);
118 for b in bytes {
119 out.push(nibble_hex(b >> 4));
120 out.push(nibble_hex(b & 0x0f));
121 }
122 out
123}
124
/// Map a nibble (`0..=15`) to its lowercase ASCII hex digit.
fn nibble_hex(n: u8) -> char {
    if n < 10 {
        char::from(b'0' + n)
    } else {
        char::from(b'a' + (n - 10))
    }
}
131
/// A secret string that can be taken out at most once.
///
/// `Debug` is implemented by hand so the secret value can never end up in logs
/// or panic messages through `{:?}` formatting.
pub struct OneTimeSecret {
    inner: Mutex<Option<String>>,
}

impl std::fmt::Debug for OneTimeSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately does not touch `inner`: no lock, and nothing secret to print.
        f.debug_struct("OneTimeSecret").finish_non_exhaustive()
    }
}
140
141impl OneTimeSecret {
142 pub fn new(secret: String) -> Self {
144 Self {
145 inner: Mutex::new(Some(secret)),
146 }
147 }
148
149 pub fn take(&self) -> Option<String> {
152 self.inner.lock().expect("one-time secret lock").take()
153 }
154
155 pub fn is_consumed(&self) -> bool {
157 self.inner.lock().expect("one-time secret lock").is_none()
158 }
159}
160
161#[derive(Clone)]
167struct HandoffState {
168 nonce: Arc<String>,
170 secret: Arc<OneTimeSecret>,
172}
173
174pub struct HandoffServer {
179 local_addr: SocketAddr,
180 nonce: String,
181 secret: Arc<OneTimeSecret>,
182 shutdown_tx: Option<oneshot::Sender<()>>,
183 join: tokio::task::JoinHandle<()>,
184}
185
186impl HandoffServer {
187 pub fn handoff_url(&self) -> String {
190 format!("http://{}/handoff/{}", self.local_addr, self.nonce)
191 }
192
193 pub fn local_addr(&self) -> SocketAddr {
195 self.local_addr
196 }
197
198 pub fn is_consumed(&self) -> bool {
200 self.secret.is_consumed()
201 }
202
203 pub async fn shutdown(mut self) {
205 if let Some(tx) = self.shutdown_tx.take() {
206 let _ = tx.send(());
207 }
208 let _ = self.join.await;
209 }
210}
211
212pub async fn spawn_handoff_server(token: String) -> std::io::Result<HandoffServer> {
216 let nonce = new_handoff_nonce();
217 let secret = Arc::new(OneTimeSecret::new(token));
218
219 let state = HandoffState {
220 nonce: Arc::new(nonce.clone()),
221 secret: Arc::clone(&secret),
222 };
223
224 let listener = tokio::net::TcpListener::bind(("127.0.0.1", 0)).await?;
225 let local_addr = listener.local_addr()?;
226
227 let router = axum::Router::new()
228 .route("/handoff/{nonce}", get(serve_handoff))
229 .with_state(state);
230
231 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
232 let join = tokio::spawn(async move {
233 let _ = axum::serve(listener, router)
234 .with_graceful_shutdown(async move {
235 let _ = shutdown_rx.await;
236 })
237 .await;
238 });
239
240 Ok(HandoffServer {
241 local_addr,
242 nonce,
243 secret,
244 shutdown_tx: Some(shutdown_tx),
245 join,
246 })
247}
248
249async fn serve_handoff(State(state): State<HandoffState>, Path(nonce): Path<String>) -> Response {
254 if !crate::crypto::constant_time_eq(nonce.as_bytes(), state.nonce.as_bytes()) {
255 return not_found();
256 }
257 match state.secret.take() {
258 Some(token) => (
259 StatusCode::OK,
260 [
261 (header::CONTENT_TYPE, "text/plain; charset=utf-8"),
262 (header::CACHE_CONTROL, "no-store"),
263 ],
264 token,
265 )
266 .into_response(),
267 None => not_found(),
268 }
269}
270
271fn not_found() -> Response {
272 (StatusCode::NOT_FOUND, "not found").into_response()
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for table-driven checks.
    const ALL_MODES: [UiAuthMode; 3] = [UiAuthMode::Injected, UiAuthMode::Prompt, UiAuthMode::Open];

    #[test]
    fn resolve_supplied_token_is_always_injected() {
        for db_auth_required in [true, false] {
            assert_eq!(
                UiAuthMode::resolve(true, db_auth_required),
                UiAuthMode::Injected
            );
        }
    }

    #[test]
    fn resolve_no_token_follows_db_auth_config() {
        let cases = [(true, UiAuthMode::Prompt), (false, UiAuthMode::Open)];
        for (db_auth_required, expected) in cases {
            assert_eq!(UiAuthMode::resolve(false, db_auth_required), expected);
        }
    }

    #[test]
    fn auth_mode_strings_are_stable() {
        let expected = ["injected", "prompt", "open"];
        for (mode, name) in ALL_MODES.into_iter().zip(expected) {
            assert_eq!(mode.as_str(), name);
        }
    }

    #[test]
    fn config_snippet_never_carries_a_token() {
        for mode in ALL_MODES {
            let snippet = auth_mode_config_snippet(mode);
            assert!(snippet.contains(mode.as_str()));
            let lowered = snippet.to_ascii_lowercase();
            for forbidden in ["token", "bearer"] {
                assert!(!lowered.contains(forbidden));
            }
        }
    }

    #[test]
    fn inject_auth_mode_inserts_before_head_close() {
        let page = b"<html><head></head><body></body></html>".to_vec();
        let rendered =
            String::from_utf8(inject_auth_mode_config(page, UiAuthMode::Injected)).unwrap();
        let expected = "<script>window.REDDB_AUTH_MODE=\"injected\";</script></head>";
        assert!(
            rendered.contains(expected),
            "snippet must appear before </head>: {rendered}"
        );
    }

    #[test]
    fn inject_auth_mode_noop_without_head_close() {
        let page = b"<html><body>no head</body></html>".to_vec();
        assert_eq!(inject_auth_mode_config(page.clone(), UiAuthMode::Prompt), page);
    }

    #[test]
    fn handoff_nonce_is_32_hex_chars_and_varies() {
        let (first, second) = (new_handoff_nonce(), new_handoff_nonce());
        assert_eq!(first.len(), 32, "nonce is 16 bytes hex-encoded");
        assert!(first.bytes().all(|b| b.is_ascii_hexdigit()));
        assert_ne!(first, second, "nonces must be unique per draw");
    }

    #[test]
    fn one_time_secret_yields_once_then_empty() {
        let secret = OneTimeSecret::new("rk_supersecret".to_string());
        assert!(!secret.is_consumed());
        assert_eq!(secret.take().as_deref(), Some("rk_supersecret"));
        assert!(secret.is_consumed());
        assert_eq!(secret.take(), None);
    }
}