codetether_agent/server/
auth.rs1use axum::{
12 body::Body,
13 http::{Request, StatusCode, header},
14 middleware::Next,
15 response::Response,
16};
17use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
18use rand::RngExt;
19use serde::{Deserialize, Serialize};
20use std::sync::Arc;
21
/// Request paths that [`require_auth`] forwards without checking credentials.
///
/// Matching is exact string equality against the request URI's path, so
/// `/health/` (trailing slash) is *not* treated as public.
const PUBLIC_PATHS: &[&str] = &["/health"];
24
25#[derive(Debug, Clone, Default, Serialize, Deserialize)]
27pub struct JwtClaims {
28 #[serde(default)]
30 pub topics: Vec<String>,
31 #[serde(default, rename = "sub")]
33 pub subject: Option<String>,
34 #[serde(default)]
36 pub scopes: Vec<String>,
37}
38
/// Zero-sized marker type associated with [`JwtClaims`].
///
/// NOTE(review): presumably used as a lookup key for stored claims; no use is
/// visible in this file — confirm against callers.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct JwtClaimsKey;
42
43#[derive(Debug, Clone)]
45pub struct JwtAppState {
46 pub jwt_claims: JwtClaims,
48}
49
50impl Default for JwtClaimsKey {
51 fn default() -> Self {
52 Self
53 }
54}
55
56pub fn extract_jwt_claims(token: &str) -> Option<JwtClaims> {
59 let parts: Vec<&str> = token.split('.').collect();
60 if parts.len() != 3 {
61 return None;
63 }
64
65 let payload = URL_SAFE_NO_PAD.decode(parts[1]).ok()?;
67
68 let claims: JwtClaims = serde_json::from_slice(&payload).ok()?;
70
71 Some(claims)
72}
73
74#[derive(Debug, Clone)]
76pub struct AuthState {
77 token: Arc<String>,
79}
80
81impl AuthState {
82 pub fn from_env() -> Self {
85 let token = match std::env::var("CODETETHER_AUTH_TOKEN") {
86 Ok(t) if !t.is_empty() => {
87 tracing::info!("Auth token loaded from CODETETHER_AUTH_TOKEN");
88 t
89 }
90 _ => {
91 let generated: String = {
92 let mut rng = rand::rng();
93 (0..32)
94 .map(|_| format!("{:02x}", rng.random::<u8>()))
95 .collect()
96 };
97 tracing::warn!(
98 token = %generated,
99 "No CODETETHER_AUTH_TOKEN set — generated a random token. \
100 Set CODETETHER_AUTH_TOKEN to use a stable token."
101 );
102 generated
103 }
104 };
105 Self {
106 token: Arc::new(token),
107 }
108 }
109
110 #[cfg(test)]
112 pub fn with_token(token: impl Into<String>) -> Self {
113 Self {
114 token: Arc::new(token.into()),
115 }
116 }
117
118 pub fn token(&self) -> &str {
120 &self.token
121 }
122}
123
124pub async fn require_auth(mut request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
127 let path = request.uri().path();
128
129 if PUBLIC_PATHS.iter().any(|p| path == *p) {
131 return Ok(next.run(request).await);
132 }
133
134 let auth_state = request
136 .extensions()
137 .get::<AuthState>()
138 .cloned()
139 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
140
141 let auth_header = request
143 .headers()
144 .get(header::AUTHORIZATION)
145 .and_then(|v| v.to_str().ok());
146
147 let provided_token = match auth_header {
148 Some(value) if value.starts_with("Bearer ") => &value[7..],
149 _ => {
150 let query = request.uri().query().unwrap_or("");
152 let token_param = query.split('&').find_map(|pair| {
153 let mut parts = pair.splitn(2, '=');
154 match (parts.next(), parts.next()) {
155 (Some("token"), Some(v)) => Some(v),
156 _ => None,
157 }
158 });
159 match token_param {
160 Some(t) => t,
161 None => return Err(StatusCode::UNAUTHORIZED),
162 }
163 }
164 };
165
166 if constant_time_eq(provided_token.as_bytes(), auth_state.token.as_bytes()) {
168 let claims = extract_jwt_claims(provided_token);
170 if let Some(claims) = claims {
171 request.extensions_mut().insert(claims);
172 }
173 Ok(next.run(request).await)
174 } else {
175 Err(StatusCode::UNAUTHORIZED)
176 }
177}
178
/// Compares two byte slices for equality.
///
/// Equal-length inputs are compared over every byte, without stopping at the
/// first difference. This is so the time taken does not depend on where they
/// differ. Inputs of different lengths are rejected immediately, so the
/// length itself is not hidden.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        accumulated == 0
    } else {
        false
    }
}
190
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"short", b"longer"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn auth_state_generates_token_when_env_missing() {
        // SAFETY: `remove_var` is unsafe because mutating the process
        // environment races with concurrent environment access on other threads.
        // NOTE(review): the test harness runs tests on parallel threads, so this
        // is only sound while no other test in this binary reads or writes the
        // environment concurrently — confirm, or serialize env-mutating tests.
        unsafe {
            std::env::remove_var("CODETETHER_AUTH_TOKEN");
        }
        let state = AuthState::from_env();
        assert_eq!(state.token().len(), 64);
        assert!(state.token().bytes().all(|b| b.is_ascii_hexdigit()));
    }

    #[test]
    fn with_token_round_trips() {
        assert_eq!(AuthState::with_token("abc").token(), "abc");
    }

    #[test]
    fn extract_jwt_claims_reads_payload_segment() {
        let payload =
            URL_SAFE_NO_PAD.encode(br#"{"sub":"alice","topics":["t1"],"scopes":["read"]}"#);
        let token = format!("header.{payload}.signature");
        let claims = extract_jwt_claims(&token).expect("payload is valid JSON claims");
        assert_eq!(claims.subject.as_deref(), Some("alice"));
        assert_eq!(claims.topics, vec!["t1".to_string()]);
        assert_eq!(claims.scopes, vec!["read".to_string()]);
    }

    #[test]
    fn extract_jwt_claims_rejects_malformed_tokens() {
        assert!(extract_jwt_claims("").is_none());
        assert!(extract_jwt_claims("deadbeef").is_none());
        assert!(extract_jwt_claims("a.b").is_none());
        assert!(extract_jwt_claims("a.b.c.d").is_none());
        assert!(extract_jwt_claims("a.!!!.c").is_none());
    }
}