use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use axum::extract::{Path, State};
use axum::http::{header, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use tokio::sync::oneshot;
/// How the web UI is told to handle authentication.
///
/// Chosen by [`UiAuthMode::resolve`] and exposed as a stable lowercase string
/// via [`UiAuthMode::as_str`]. `Open` is the [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UiAuthMode {
    /// The caller supplied a token.
    Injected,
    /// No token was supplied, but the database requires authentication.
    Prompt,
    /// No token was supplied and the database does not require authentication.
    #[default]
    Open,
}

impl UiAuthMode {
    /// Picks the mode for a UI session.
    ///
    /// A supplied token always wins (`Injected`). Without one, the database's
    /// auth requirement decides between `Prompt` and `Open`.
    pub fn resolve(token_supplied: bool, db_auth_required: bool) -> Self {
        if token_supplied {
            Self::Injected
        } else if db_auth_required {
            Self::Prompt
        } else {
            Self::Open
        }
    }

    /// Stable lowercase identifier for the mode (`"injected"`, `"prompt"`, `"open"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Injected => "injected",
            Self::Prompt => "prompt",
            Self::Open => "open",
        }
    }
}
/// Builds the inline `<script>` that publishes the auth mode to the page as
/// `window.REDDB_AUTH_MODE`.
///
/// Only the mode's static identifier is embedded; no credential material is
/// ever part of the snippet.
pub fn auth_mode_config_snippet(mode: UiAuthMode) -> String {
    let value = mode.as_str();
    format!(r#"<script>window.REDDB_AUTH_MODE="{value}";</script>"#)
}
/// Inserts the auth-mode `<script>` snippet (see `auth_mode_config_snippet`)
/// immediately before the first `</head>` closing tag of an HTML document.
///
/// The closing tag is matched ASCII case-insensitively, because HTML tag names
/// are case-insensitive and `</HEAD>` or `</Head>` close the head element just
/// as well. If no closing head tag is found (including input shorter than the
/// tag), `html` is returned unchanged.
pub fn inject_auth_mode_config(html: Vec<u8>, mode: UiAuthMode) -> Vec<u8> {
    const MARKER: &[u8] = b"</head>";
    let found = html
        .windows(MARKER.len())
        .position(|window| window.eq_ignore_ascii_case(MARKER));
    match found {
        Some(pos) => {
            let snippet = auth_mode_config_snippet(mode);
            let (before, from_marker) = html.split_at(pos);
            let mut out = Vec::with_capacity(html.len() + snippet.len());
            out.extend_from_slice(before);
            out.extend_from_slice(snippet.as_bytes());
            // The original closing tag (with its original casing) is kept.
            out.extend_from_slice(from_marker);
            out
        }
        None => html,
    }
}
/// Draws a fresh handoff nonce: 16 bytes encoded as 32 lowercase hex characters.
///
/// Bytes come from `crate::crypto::os_random::fill_bytes`, presumably the OS
/// CSPRNG (TODO confirm). If that call fails, every byte is regenerated by
/// [`fallback_nonce_bytes`]. The fallback is weaker than OS randomness, but it
/// never leaves bytes zeroed and never repeats within a process.
pub fn new_handoff_nonce() -> String {
    let mut bytes = [0u8; 16];
    if crate::crypto::os_random::fill_bytes(&mut bytes).is_err() {
        // Previously only 8 of the 16 bytes were overwritten, with a buffer
        // address. That left half the nonce zero and the rest predictable and
        // repeatable across calls.
        fallback_nonce_bytes(&mut bytes);
    }
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        out.push(nibble_hex(b >> 4));
        out.push(nibble_hex(b & 0x0f));
    }
    out
}

/// Best-effort nonce material for when the OS RNG call fails.
///
/// Each 8-byte lane is a SipHash digest from a fresh std `RandomState` (its
/// keys are randomly seeded by std). The digest covers:
/// - the lane index,
/// - a process-wide call counter,
/// - wall-clock nanoseconds,
/// - the process id,
/// - the buffer address.
///
/// The counter guarantees distinct output for every call within one process.
fn fallback_nonce_bytes(out: &mut [u8; 16]) {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static CALLS: AtomicUsize = AtomicUsize::new(0);
    let call = CALLS.fetch_add(1, Ordering::Relaxed);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());
    let buf_addr = out.as_ptr() as usize;

    for (lane, chunk) in out.chunks_exact_mut(8).enumerate() {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_usize(lane);
        hasher.write_usize(call);
        hasher.write_u128(nanos);
        hasher.write_u32(std::process::id());
        hasher.write_usize(buf_addr);
        chunk.copy_from_slice(&hasher.finish().to_le_bytes());
    }
}
/// Converts a nibble (`0..=15`) to its lowercase hexadecimal digit.
///
/// Inputs above 15 are outside the contract; callers only pass `b >> 4` or
/// `b & 0x0f`.
fn nibble_hex(n: u8) -> char {
    // Digits start at '0'; letters are offset so that 10 maps to 'a'.
    let base = if n < 10 { b'0' } else { b'a' - 10 };
    char::from(base + n)
}
/// A secret that can be taken out exactly once.
///
/// The first [`OneTimeSecret::take`] moves the value out; every later call
/// returns `None`. The interior `Mutex` lets `take` be called through a shared
/// reference from multiple threads.
pub struct OneTimeSecret {
    inner: Mutex<Option<String>>,
}

// Hand-written instead of derived. The derived impl formats the `Mutex`'s
// contents, which would print the secret verbatim in any `{:?}` or log line.
impl std::fmt::Debug for OneTimeSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OneTimeSecret")
            .field("consumed", &self.is_consumed())
            .finish_non_exhaustive()
    }
}

impl OneTimeSecret {
    /// Wraps `secret` so it can be retrieved exactly once.
    pub fn new(secret: String) -> Self {
        Self {
            inner: Mutex::new(Some(secret)),
        }
    }

    /// Removes and returns the secret, or `None` if it was already taken.
    ///
    /// # Panics
    /// Panics if the lock is poisoned. Guards here only span `Option::take` or
    /// `is_none`, neither of which panics, so this is not expected.
    pub fn take(&self) -> Option<String> {
        self.inner.lock().expect("one-time secret lock").take()
    }

    /// Reports whether the secret has already been taken.
    pub fn is_consumed(&self) -> bool {
        self.inner.lock().expect("one-time secret lock").is_none()
    }
}
/// State handed to `serve_handoff` through axum's `State` extractor.
///
/// Both fields are `Arc`s, so the derived `Clone` only bumps reference counts.
#[derive(Clone)]
struct HandoffState {
/// Expected value of the `{nonce}` path segment.
nonce: Arc<String>,
/// Token slot; `spawn_handoff_server` also keeps a clone of this `Arc` in the
/// returned `HandoffServer`.
secret: Arc<OneTimeSecret>,
}
/// Handle to a running loopback server that hands a token out at most once.
///
/// Returned by [`spawn_handoff_server`]. Dropping the handle without calling
/// [`HandoffServer::shutdown`] drops `shutdown_tx`. That also completes the
/// server's graceful-shutdown wait, and the task then finishes detached.
pub struct HandoffServer {
    /// Address of the bound listener.
    local_addr: SocketAddr,
    /// Random path segment a request must present to receive the token.
    nonce: String,
    /// Token slot shared with the request handler.
    secret: Arc<OneTimeSecret>,
    /// Graceful-shutdown trigger for the server task.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Handle of the spawned task running `axum::serve`.
    join: tokio::task::JoinHandle<()>,
}

impl HandoffServer {
    /// URL of the form `http://<addr>/handoff/<nonce>` that yields the token once.
    pub fn handoff_url(&self) -> String {
        let Self {
            local_addr, nonce, ..
        } = self;
        format!("http://{local_addr}/handoff/{nonce}")
    }

    /// Address the server is listening on.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// `true` once the token has been taken by a request.
    pub fn is_consumed(&self) -> bool {
        self.secret.is_consumed()
    }

    /// Signals graceful shutdown and waits for the server task to end.
    ///
    /// A failed send (the server task is already gone) and a join error are
    /// both ignored.
    pub async fn shutdown(self) {
        let Self {
            shutdown_tx, join, ..
        } = self;
        if let Some(trigger) = shutdown_tx {
            let _ = trigger.send(());
        }
        let _ = join.await;
    }
}
/// Starts a one-shot token handoff server on an OS-assigned port on `127.0.0.1`.
///
/// `token` is served once, at [`HandoffServer::handoff_url`], to the first
/// request carrying the freshly drawn nonce. Keep the returned handle alive
/// until the token has been fetched. Dropping it drops the shutdown sender,
/// which ends the graceful-shutdown wait below.
///
/// # Errors
/// Returns the I/O error if binding the listener or reading its local address fails.
pub async fn spawn_handoff_server(token: String) -> std::io::Result<HandoffServer> {
    let nonce = new_handoff_nonce();
    let secret = Arc::new(OneTimeSecret::new(token));

    let listener = tokio::net::TcpListener::bind(("127.0.0.1", 0)).await?;
    let local_addr = listener.local_addr()?;

    let app = axum::Router::new()
        .route("/handoff/{nonce}", get(serve_handoff))
        .with_state(HandoffState {
            nonce: Arc::new(nonce.clone()),
            secret: Arc::clone(&secret),
        });

    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let stop_signal = async move {
        // Resolves on an explicit send or when the sender is dropped.
        let _ = shutdown_rx.await;
    };
    let join = tokio::spawn(async move {
        let _ = axum::serve(listener, app)
            .with_graceful_shutdown(stop_signal)
            .await;
    });

    Ok(HandoffServer {
        local_addr,
        nonce,
        secret,
        shutdown_tx: Some(shutdown_tx),
        join,
    })
}
async fn serve_handoff(State(state): State<HandoffState>, Path(nonce): Path<String>) -> Response {
if !crate::crypto::constant_time_eq(nonce.as_bytes(), state.nonce.as_bytes()) {
return not_found();
}
match state.secret.take() {
Some(token) => (
StatusCode::OK,
[
(header::CONTENT_TYPE, "text/plain; charset=utf-8"),
(header::CACHE_CONTROL, "no-store"),
],
token,
)
.into_response(),
None => not_found(),
}
}
fn not_found() -> Response {
(StatusCode::NOT_FOUND, "not found").into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration order.
    const ALL_MODES: [UiAuthMode; 3] = [UiAuthMode::Injected, UiAuthMode::Prompt, UiAuthMode::Open];

    #[test]
    fn resolve_supplied_token_is_always_injected() {
        for db_auth_required in [true, false] {
            assert_eq!(UiAuthMode::resolve(true, db_auth_required), UiAuthMode::Injected);
        }
    }

    #[test]
    fn resolve_no_token_follows_db_auth_config() {
        let cases = [(true, UiAuthMode::Prompt), (false, UiAuthMode::Open)];
        for (db_auth_required, expected) in cases {
            assert_eq!(UiAuthMode::resolve(false, db_auth_required), expected);
        }
    }

    #[test]
    fn auth_mode_strings_are_stable() {
        let expected = ["injected", "prompt", "open"];
        for (mode, want) in ALL_MODES.into_iter().zip(expected) {
            assert_eq!(mode.as_str(), want);
        }
    }

    #[test]
    fn config_snippet_never_carries_a_token() {
        for mode in ALL_MODES {
            let snippet = auth_mode_config_snippet(mode);
            assert!(snippet.contains(mode.as_str()));
            let lowered = snippet.to_ascii_lowercase();
            for forbidden in ["token", "bearer"] {
                assert!(!lowered.contains(forbidden));
            }
        }
    }

    #[test]
    fn inject_auth_mode_inserts_before_head_close() {
        let html = b"<html><head></head><body></body></html>".to_vec();
        let s = String::from_utf8(inject_auth_mode_config(html, UiAuthMode::Injected)).unwrap();
        assert!(
            s.contains("<script>window.REDDB_AUTH_MODE=\"injected\";</script></head>"),
            "snippet must appear before </head>: {s}"
        );
    }

    #[test]
    fn inject_auth_mode_noop_without_head_close() {
        let html = b"<html><body>no head</body></html>".to_vec();
        let expected = html.clone();
        assert_eq!(inject_auth_mode_config(html, UiAuthMode::Prompt), expected);
    }

    #[test]
    fn handoff_nonce_is_32_hex_chars_and_varies() {
        let (first, second) = (new_handoff_nonce(), new_handoff_nonce());
        assert_eq!(first.len(), 32, "nonce is 16 bytes hex-encoded");
        assert!(first.bytes().all(|c| c.is_ascii_hexdigit()));
        assert_ne!(first, second, "nonces must be unique per draw");
    }

    #[test]
    fn one_time_secret_yields_once_then_empty() {
        let secret = OneTimeSecret::new(String::from("rk_supersecret"));
        assert!(!secret.is_consumed());
        assert_eq!(secret.take().as_deref(), Some("rk_supersecret"));
        assert!(secret.is_consumed());
        assert!(secret.take().is_none());
    }
}