Skip to main content

codetether_agent/server/
auth.rs

1//! Mandatory authentication middleware
2//!
3//! All endpoints except `/health` require a valid Bearer token.
4//! **Auth cannot be disabled.** If no `CODETETHER_AUTH_TOKEN` is set the
5//! server generates a secure random token at startup and prints it to stderr
6//! so the operator can copy it — but the gates never open without a token.
7
8use axum::{
9    body::Body,
10    http::{Request, StatusCode, header},
11    middleware::Next,
12    response::Response,
13};
14use rand::RngExt;
15use std::sync::Arc;
16
/// Request paths that are served without a token.
///
/// The middleware compares the request path against each entry for exact
/// equality, so `/health` is public while `/health/` is not.
const PUBLIC_PATHS: &[&str] = &[
    "/health",
];
19
/// Shared auth state.
///
/// Clones share the same token allocation through the inner [`Arc`].
#[derive(Clone)]
pub struct AuthState {
    /// The required Bearer token.
    token: Arc<String>,
}

// `Debug` is implemented by hand instead of derived so that formatting the
// state with `{:?}` (in a log line, a panic message, a tracing field, ...)
// never prints the secret token.
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}
26
27impl AuthState {
28    /// Build from the environment.  If `CODETETHER_AUTH_TOKEN` is not set a
29    /// 32-byte hex token is generated and logged once.
30    pub fn from_env() -> Self {
31        let token = match std::env::var("CODETETHER_AUTH_TOKEN") {
32            Ok(t) if !t.is_empty() => {
33                tracing::info!("Auth token loaded from CODETETHER_AUTH_TOKEN");
34                t
35            }
36            _ => {
37                let generated: String = {
38                    let mut rng = rand::rng();
39                    (0..32)
40                        .map(|_| format!("{:02x}", rng.random::<u8>()))
41                        .collect()
42                };
43                tracing::warn!(
44                    token = %generated,
45                    "No CODETETHER_AUTH_TOKEN set — generated a random token. \
46                     Set CODETETHER_AUTH_TOKEN to use a stable token."
47                );
48                generated
49            }
50        };
51        Self {
52            token: Arc::new(token),
53        }
54    }
55
56    /// Create with an explicit token (for tests).
57    #[cfg(test)]
58    pub fn with_token(token: impl Into<String>) -> Self {
59        Self {
60            token: Arc::new(token.into()),
61        }
62    }
63
64    /// Return the active token (for display at startup).
65    pub fn token(&self) -> &str {
66        &self.token
67    }
68}
69
70/// Axum middleware layer that enforces Bearer token auth on every request
71/// except public paths.
72pub async fn require_auth(request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
73    let path = request.uri().path();
74
75    // Allow public paths through without auth.
76    if PUBLIC_PATHS.iter().any(|p| path == *p) {
77        return Ok(next.run(request).await);
78    }
79
80    // Extract the AuthState from extensions (set by the server setup).
81    let auth_state = request
82        .extensions()
83        .get::<AuthState>()
84        .cloned()
85        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
86
87    // Extract Bearer token from Authorization header.
88    let auth_header = request
89        .headers()
90        .get(header::AUTHORIZATION)
91        .and_then(|v| v.to_str().ok());
92
93    let provided_token = match auth_header {
94        Some(value) if value.starts_with("Bearer ") => &value[7..],
95        _ => {
96            // Also accept token via query parameter for SSE/WebSocket clients.
97            let query = request.uri().query().unwrap_or("");
98            let token_param = query.split('&').find_map(|pair| {
99                let mut parts = pair.splitn(2, '=');
100                match (parts.next(), parts.next()) {
101                    (Some("token"), Some(v)) => Some(v),
102                    _ => None,
103                }
104            });
105            match token_param {
106                Some(t) => t,
107                None => return Err(StatusCode::UNAUTHORIZED),
108            }
109        }
110    };
111
112    // Constant-time comparison to prevent timing attacks.
113    if constant_time_eq(provided_token.as_bytes(), auth_state.token.as_bytes()) {
114        Ok(next.run(request).await)
115    } else {
116        Err(StatusCode::UNAUTHORIZED)
117    }
118}
119
/// Byte-slice equality whose loop has no data-dependent early exit.
///
/// Slices of different lengths are rejected up front. For equal lengths every
/// byte pair is XOR-ed and OR-accumulated, and the result is `true` only when
/// the accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"short", b"longer"));
        // Boundary cases: empty inputs and a strict prefix of the other side.
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
        assert!(!constant_time_eq(b"hell", b"hello"));
    }

    #[test]
    fn with_token_round_trips() {
        assert_eq!(AuthState::with_token("abc123").token(), "abc123");
    }

    #[test]
    fn auth_state_generates_token_when_env_missing() {
        // Ensure the env var is not set for this test.
        // SAFETY: `remove_var` is only sound while no other thread accesses the
        // process environment concurrently. No other test in this module
        // touches environment variables.
        // NOTE(review): the default test harness runs tests on parallel
        // threads, so tests elsewhere in the crate that read or write the
        // environment would need to be serialized with this one — confirm.
        unsafe {
            std::env::remove_var("CODETETHER_AUTH_TOKEN");
        }
        let state = AuthState::from_env();
        let token = state.token();
        assert_eq!(token.len(), 64); // 32 bytes = 64 hex chars
        // Length alone would also accept non-hex output; pin the alphabet too.
        assert!(
            token
                .bytes()
                .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
        );
    }
}
153}