codetether_agent/server/
auth.rs1use axum::{
9 body::Body,
10 http::{Request, StatusCode, header},
11 middleware::Next,
12 response::Response,
13};
14use rand::RngExt;
15use std::sync::Arc;
16
/// Request paths that `require_auth` lets through without checking a token.
///
/// Matching is an exact string comparison against the request URI path.
const PUBLIC_PATHS: &[&str] = &[
    "/health",
];
19
/// Authentication state holding the bearer token that requests must present.
///
/// The token lives behind an [`Arc`], so cloning the state only bumps a
/// reference count instead of copying the string.
#[derive(Clone)]
pub struct AuthState {
    /// The expected token, compared against what the client supplies.
    token: Arc<String>,
}

/// Hand-written so that the secret never appears in `{:?}` output.
///
/// A derived `Debug` would print the token verbatim wherever the state is
/// debug-formatted (log fields, panic messages, `dbg!`).
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}
26
27impl AuthState {
28 pub fn from_env() -> Self {
31 let token = match std::env::var("CODETETHER_AUTH_TOKEN") {
32 Ok(t) if !t.is_empty() => {
33 tracing::info!("Auth token loaded from CODETETHER_AUTH_TOKEN");
34 t
35 }
36 _ => {
37 let generated: String = {
38 let mut rng = rand::rng();
39 (0..32)
40 .map(|_| format!("{:02x}", rng.random::<u8>()))
41 .collect()
42 };
43 tracing::warn!(
44 token = %generated,
45 "No CODETETHER_AUTH_TOKEN set — generated a random token. \
46 Set CODETETHER_AUTH_TOKEN to use a stable token."
47 );
48 generated
49 }
50 };
51 Self {
52 token: Arc::new(token),
53 }
54 }
55
56 #[cfg(test)]
58 pub fn with_token(token: impl Into<String>) -> Self {
59 Self {
60 token: Arc::new(token.into()),
61 }
62 }
63
64 pub fn token(&self) -> &str {
66 &self.token
67 }
68}
69
70pub async fn require_auth(request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
73 let path = request.uri().path();
74
75 if PUBLIC_PATHS.iter().any(|p| path == *p) {
77 return Ok(next.run(request).await);
78 }
79
80 let auth_state = request
82 .extensions()
83 .get::<AuthState>()
84 .cloned()
85 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
86
87 let auth_header = request
89 .headers()
90 .get(header::AUTHORIZATION)
91 .and_then(|v| v.to_str().ok());
92
93 let provided_token = match auth_header {
94 Some(value) if value.starts_with("Bearer ") => &value[7..],
95 _ => {
96 let query = request.uri().query().unwrap_or("");
98 let token_param = query.split('&').find_map(|pair| {
99 let mut parts = pair.splitn(2, '=');
100 match (parts.next(), parts.next()) {
101 (Some("token"), Some(v)) => Some(v),
102 _ => None,
103 }
104 });
105 match token_param {
106 Some(t) => t,
107 None => return Err(StatusCode::UNAUTHORIZED),
108 }
109 }
110 };
111
112 if constant_time_eq(provided_token.as_bytes(), auth_state.token.as_bytes()) {
114 Ok(next.run(request).await)
115 } else {
116 Err(StatusCode::UNAUTHORIZED)
117 }
118}
119
/// Reports whether two byte slices are equal.
///
/// Slices of different lengths are rejected immediately. For equal lengths,
/// every byte pair is XOR-ed and OR-ed into an accumulator, so the comparison
/// does not stop at the first mismatching byte.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let accumulated = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (left, right)| acc | (left ^ right));

    accumulated == 0
}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal input, same-length different input, and different-length input.
    #[test]
    fn constant_time_eq_works() {
        let identical = constant_time_eq(b"hello", b"hello");
        let same_len_different = constant_time_eq(b"hello", b"world");
        let different_len = constant_time_eq(b"short", b"longer");

        assert!(identical);
        assert!(!same_len_different);
        assert!(!different_len);
    }

    /// With the variable removed, `from_env` must fall back to a generated
    /// token of 64 characters.
    #[test]
    fn auth_state_generates_token_when_env_missing() {
        // SAFETY: NOTE(review) — mutating the process environment is only
        // sound if no other thread touches it at the same time; this assumes
        // no concurrently running test reads or writes the environment.
        // Confirm against the rest of the test suite.
        unsafe {
            std::env::remove_var("CODETETHER_AUTH_TOKEN");
        }

        let generated_len = AuthState::from_env().token().len();
        assert_eq!(generated_len, 64);
    }
}