// codetether_agent/server/auth.rs
//! Mandatory authentication middleware
//!
//! All endpoints except `/health` require a valid Bearer token.
//! **Auth cannot be disabled.** If `CODETETHER_AUTH_TOKEN` is not set (or is
//! empty), the server generates a secure random token at startup and logs it
//! once (via `tracing`, at warn level) so the operator can copy it — but the
//! gates never open without a token.
//!
//! JWT support: if the accepted token is shaped like a JWT, its claims (topics,
//! subject, scopes) are decoded — without signature verification — and stored
//! in the request extensions for use by the bus stream endpoint.

use std::borrow::Cow;
use std::fmt;
use std::sync::Arc;

use axum::{
    body::Body,
    http::{Request, StatusCode, header},
    middleware::Next,
    response::Response,
};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::RngExt;
use serde::{Deserialize, Serialize};

/// Request paths that bypass authentication entirely.
///
/// Entries are compared against the request path for exact equality, so
/// `/health` is public while `/health/` or `/health/x` are not.
const PUBLIC_PATHS: &[&str] = &[
    "/health",
];

25/// JWT claims extracted from the Bearer token for topic filtering.
26#[derive(Debug, Clone, Default, Serialize, Deserialize)]
27pub struct JwtClaims {
28    /// Allowed topics for bus stream filtering.
29    #[serde(default)]
30    pub topics: Vec<String>,
31    /// Optional user identifier.
32    #[serde(default, rename = "sub")]
33    pub subject: Option<String>,
34    /// Additional scopes from the JWT.
35    #[serde(default)]
36    pub scopes: Vec<String>,
37}
38
39/// Request extension key for JWT claims.
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
41pub struct JwtClaimsKey;
42
43/// Application state that includes JWT claims for extraction in handlers.
44#[derive(Debug, Clone)]
45pub struct JwtAppState {
46    /// JWT claims extracted from the Bearer token.
47    pub jwt_claims: JwtClaims,
48}
49
50impl Default for JwtClaimsKey {
51    fn default() -> Self {
52        Self
53    }
54}
55
56/// Parse a JWT token and extract claims from the payload.
57/// Returns None if the token is not a valid JWT (e.g., it's a static token).
58pub fn extract_jwt_claims(token: &str) -> Option<JwtClaims> {
59    let parts: Vec<&str> = token.split('.').collect();
60    if parts.len() != 3 {
61        // Not a JWT - it's likely a static token
62        return None;
63    }
64
65    // Decode the payload (second part)
66    let payload = URL_SAFE_NO_PAD.decode(parts[1]).ok()?;
67
68    // Parse JSON
69    let claims: JwtClaims = serde_json::from_slice(&payload).ok()?;
70
71    Some(claims)
72}
73
/// Shared auth state holding the single accepted Bearer token.
///
/// `Debug` is implemented by hand so the secret never appears in log lines or
/// panic messages. A derived impl would print the token verbatim.
#[derive(Clone)]
pub struct AuthState {
    /// The required Bearer token.
    token: Arc<String>,
}

impl fmt::Debug for AuthState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render a placeholder instead of the secret.
        f.debug_struct("AuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}

81impl AuthState {
82    /// Build from the environment.  If `CODETETHER_AUTH_TOKEN` is not set a
83    /// 32-byte hex token is generated and logged once.
84    pub fn from_env() -> Self {
85        let token = match std::env::var("CODETETHER_AUTH_TOKEN") {
86            Ok(t) if !t.is_empty() => {
87                tracing::info!("Auth token loaded from CODETETHER_AUTH_TOKEN");
88                t
89            }
90            _ => {
91                let generated: String = {
92                    let mut rng = rand::rng();
93                    (0..32)
94                        .map(|_| format!("{:02x}", rng.random::<u8>()))
95                        .collect()
96                };
97                tracing::warn!(
98                    token = %generated,
99                    "No CODETETHER_AUTH_TOKEN set — generated a random token. \
100                     Set CODETETHER_AUTH_TOKEN to use a stable token."
101                );
102                generated
103            }
104        };
105        Self {
106            token: Arc::new(token),
107        }
108    }
109
110    /// Create with an explicit token (for tests).
111    #[cfg(test)]
112    pub fn with_token(token: impl Into<String>) -> Self {
113        Self {
114            token: Arc::new(token.into()),
115        }
116    }
117
118    /// Return the active token (for display at startup).
119    pub fn token(&self) -> &str {
120        &self.token
121    }
122}
123
124/// Axum middleware layer that enforces Bearer token auth on every request
125/// except public paths.
126pub async fn require_auth(mut request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
127    let path = request.uri().path();
128
129    // Allow public paths through without auth.
130    if PUBLIC_PATHS.iter().any(|p| path == *p) {
131        return Ok(next.run(request).await);
132    }
133
134    // Extract the AuthState from extensions (set by the server setup).
135    let auth_state = request
136        .extensions()
137        .get::<AuthState>()
138        .cloned()
139        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
140
141    // Extract Bearer token from Authorization header.
142    let auth_header = request
143        .headers()
144        .get(header::AUTHORIZATION)
145        .and_then(|v| v.to_str().ok());
146
147    let provided_token = match auth_header {
148        Some(value) if value.starts_with("Bearer ") => &value[7..],
149        _ => {
150            // Also accept token via query parameter for SSE/WebSocket clients.
151            let query = request.uri().query().unwrap_or("");
152            let token_param = query.split('&').find_map(|pair| {
153                let mut parts = pair.splitn(2, '=');
154                match (parts.next(), parts.next()) {
155                    (Some("token"), Some(v)) => Some(v),
156                    _ => None,
157                }
158            });
159            match token_param {
160                Some(t) => t,
161                None => return Err(StatusCode::UNAUTHORIZED),
162            }
163        }
164    };
165
166    // Constant-time comparison to prevent timing attacks.
167    if constant_time_eq(provided_token.as_bytes(), auth_state.token.as_bytes()) {
168        // Extract JWT claims and add to request extensions for downstream handlers
169        let claims = extract_jwt_claims(provided_token);
170        if let Some(claims) = claims {
171            request.extensions_mut().insert(claims);
172        }
173        Ok(next.run(request).await)
174    } else {
175        Err(StatusCode::UNAUTHORIZED)
176    }
177}
178
/// Compare two byte slices without branching on their contents.
///
/// If the lengths differ, the function returns `false` immediately, so only
/// the length is observable through timing. Otherwise every byte pair is
/// XOR-ed and OR-ed into a single accumulator. The inputs are equal exactly
/// when that accumulator ends up zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"short", b"longer"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn extract_jwt_claims_rejects_non_jwt_tokens() {
        assert!(extract_jwt_claims("plain-static-token").is_none());
        assert!(extract_jwt_claims("a.b").is_none());
        assert!(extract_jwt_claims("a.b.c.d").is_none());
        // Three segments, but the payload is not base64url-encoded JSON.
        assert!(extract_jwt_claims("x.!!!.y").is_none());
    }

    #[test]
    fn extract_jwt_claims_reads_payload() {
        let payload = URL_SAFE_NO_PAD.encode(br#"{"sub":"alice","topics":["a","b"]}"#);
        let token = format!("header.{payload}.signature");
        let claims = extract_jwt_claims(&token).expect("payload should parse");
        assert_eq!(claims.subject.as_deref(), Some("alice"));
        assert_eq!(claims.topics, vec!["a", "b"]);
        assert!(claims.scopes.is_empty());
    }

    #[test]
    fn with_token_round_trips() {
        assert_eq!(AuthState::with_token("abc123").token(), "abc123");
    }

    #[test]
    fn auth_state_generates_token_when_env_missing() {
        // Ensure the env var is not set for this test.
        // SAFETY: the test harness runs tests on parallel threads, so this is only
        // sound while no other test reads or writes the process environment at the
        // same time. NOTE(review): confirm that holds crate-wide, or serialize
        // env-mutating tests behind a shared lock.
        unsafe {
            std::env::remove_var("CODETETHER_AUTH_TOKEN");
        }
        let state = AuthState::from_env();
        let token = state.token();
        assert_eq!(token.len(), 64); // 32 bytes = 64 hex chars
        assert!(token.bytes().all(|b| b.is_ascii_hexdigit()));
    }
}