use axum::{
body::Body,
http::{Request, StatusCode, header},
middleware::Next,
response::Response,
};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Request paths that [`require_auth`] lets through without a token.
///
/// Matching is exact against the URI path (no prefix or pattern matching).
const PUBLIC_PATHS: &[&str] = &[
    "/health",
];
/// Claims decoded from a JWT payload by [`extract_jwt_claims`].
///
/// Every field may be absent from the payload. Missing fields take their
/// value from [`JwtClaims::default`]: empty lists and no subject.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct JwtClaims {
    /// Values of the `topics` claim.
    pub topics: Vec<String>,
    /// The registered `sub` (subject) claim.
    #[serde(rename = "sub")]
    pub subject: Option<String>,
    /// Values of the `scopes` claim.
    pub scopes: Vec<String>,
}
/// Zero-sized marker type.
///
/// NOTE(review): not referenced in this file; presumably used elsewhere as a
/// typed key for looking up [`JwtClaims`] — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct JwtClaimsKey;

/// State wrapper holding a set of [`JwtClaims`].
#[derive(Debug, Clone)]
pub struct JwtAppState {
    /// The wrapped claims.
    pub jwt_claims: JwtClaims,
}
/// Decodes the claims in the payload segment of a compact JWT.
///
/// The token must consist of exactly three `.`-separated segments. The
/// middle one is base64url-decoded and parsed as JSON into [`JwtClaims`].
/// Trailing `=` padding on the payload is tolerated, even though compact JWTs
/// normally omit it.
///
/// **The signature is NOT verified, and the header segment is not inspected.**
/// Only trust the result when the token has already been authenticated by
/// other means.
///
/// Returns `None` if the token has the wrong number of segments, the payload
/// is not valid base64url, or the JSON does not match [`JwtClaims`].
pub fn extract_jwt_claims(token: &str) -> Option<JwtClaims> {
    let mut segments = token.split('.');
    // Exactly three segments: header, payload, signature, and nothing after.
    let (Some(_header), Some(payload), Some(_signature), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return None;
    };
    // URL_SAFE_NO_PAD rejects padding, so strip any the issuer added.
    let payload = URL_SAFE_NO_PAD
        .decode(payload.trim_end_matches('='))
        .ok()?;
    serde_json::from_slice(&payload).ok()
}
/// Shared secret that [`require_auth`] checks incoming requests against.
///
/// Cloning is cheap because the token sits behind an [`Arc`]. The [`Debug`]
/// implementation redacts the token so it cannot leak through `{:?}`
/// formatting.
#[derive(Clone)]
pub struct AuthState {
    token: Arc<String>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself.
        f.debug_struct("AuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}

impl AuthState {
    /// Builds the auth state from the `CODETETHER_AUTH_TOKEN` environment variable.
    ///
    /// A non-empty value is used as-is and an `info` event is logged. Otherwise
    /// a random token is generated: 32 random bytes, each rendered as two
    /// lowercase hex digits, for 64 characters in total. "Otherwise" covers an
    /// unset variable, an empty one, or one `std::env::var` cannot read as
    /// Unicode.
    ///
    /// The generated token is written to the logs at `warn` level so the
    /// operator can retrieve it. Be aware that this places the secret in the
    /// logs.
    pub fn from_env() -> Self {
        let token = match std::env::var("CODETETHER_AUTH_TOKEN") {
            Ok(t) if !t.is_empty() => {
                tracing::info!("Auth token loaded from CODETETHER_AUTH_TOKEN");
                t
            }
            _ => {
                let generated: String = {
                    let mut rng = rand::rng();
                    (0..32)
                        .map(|_| format!("{:02x}", rng.random::<u8>()))
                        .collect()
                };
                tracing::warn!(
                    token = %generated,
                    "No CODETETHER_AUTH_TOKEN set — generated a random token. \
                     Set CODETETHER_AUTH_TOKEN to use a stable token."
                );
                generated
            }
        };
        Self {
            token: Arc::new(token),
        }
    }

    /// Test-only constructor that uses `token` verbatim.
    #[cfg(test)]
    pub fn with_token(token: impl Into<String>) -> Self {
        Self {
            token: Arc::new(token.into()),
        }
    }

    /// Returns the shared secret.
    pub fn token(&self) -> &str {
        &self.token
    }
}
/// Axum middleware that rejects requests which do not carry the shared token.
///
/// Requests whose path exactly matches an entry of [`PUBLIC_PATHS`] pass
/// through untouched. For all other requests, the token is looked up in order:
///
/// 1. An `Authorization: <scheme> <token>` header whose scheme is `Bearer`.
///    The scheme is matched case-insensitively, as RFC 7235 requires.
/// 2. Failing that, the first `token=<value>` pair in the query string.
///
/// The token is then compared in constant time with the [`AuthState`] found in
/// the request extensions. On success, if the token also decodes as a JWT, its
/// [`JwtClaims`] are inserted into the request extensions. They are trusted
/// only because the whole token matched the secret.
///
/// # Errors
///
/// * `500 Internal Server Error` if no [`AuthState`] extension is present on
///   the request.
/// * `401 Unauthorized` if no token is supplied or it does not match.
pub async fn require_auth(mut request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let path = request.uri().path();
    if PUBLIC_PATHS.iter().any(|public| *public == path) {
        return Ok(next.run(request).await);
    }

    let auth_state = request
        .extensions()
        .get::<AuthState>()
        .cloned()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let provided_token = bearer_token(&request)
        .or_else(|| request.uri().query().and_then(query_token))
        .ok_or(StatusCode::UNAUTHORIZED)?;

    if !constant_time_eq(provided_token.as_bytes(), auth_state.token().as_bytes()) {
        return Err(StatusCode::UNAUTHORIZED);
    }

    // `provided_token` borrows `request`, so decode the claims (owned) before
    // taking the mutable borrow needed to insert them.
    if let Some(claims) = extract_jwt_claims(provided_token) {
        request.extensions_mut().insert(claims);
    }
    Ok(next.run(request).await)
}

/// Returns the credentials of an `Authorization: Bearer <token>` header.
///
/// Returns `None` when the header is missing, not visible ASCII, or uses
/// another scheme. Everything after the first space is returned verbatim.
fn bearer_token(request: &Request<Body>) -> Option<&str> {
    let value = request
        .headers()
        .get(header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, credentials) = value.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(credentials)
}

/// Returns the raw value of the first `token=` pair in a URL query string.
///
/// The value is not percent-decoded, so a token passed this way must contain
/// only characters that are legal unescaped in a query string.
fn query_token(query: &str) -> Option<&str> {
    query.split('&').find_map(|pair| pair.strip_prefix("token="))
}
/// Compares two byte strings for equality without short-circuiting.
///
/// The loop always runs over every byte of `b`, which is the secret in
/// [`require_auth`]. Neither the position of the first difference nor a length
/// mismatch ends the comparison early. The timing therefore reveals neither
/// how much of a guess was correct nor the secret's length.
///
/// This is best-effort. The compiler gives no formal guarantee that it will
/// not optimise the loop into something data-dependent.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Record a length mismatch in the accumulator instead of returning early.
    // An early return would leak the secret's length through response timing.
    let mut diff = u8::from(a.len() != b.len());
    for (index, &expected) in b.iter().enumerate() {
        // When `a` is shorter, its missing bytes read as 0. The length flag
        // above already forces the result to `false` in that case.
        let provided = a.get(index).copied().unwrap_or(0);
        diff |= provided ^ expected;
    }
    diff == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"short", b"longer"));
        assert!(!constant_time_eq(b"longer", b"short"));
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
        assert!(!constant_time_eq(b"abc\0", b"abc"));
    }

    #[test]
    fn auth_state_generates_token_when_env_missing() {
        // SAFETY: `remove_var` is unsafe because another thread may read or
        // write the environment concurrently. No other test in this module
        // touches the environment. NOTE(review): this assumes no test elsewhere
        // in the crate does either — confirm, or serialise env-mutating tests.
        unsafe {
            std::env::remove_var("CODETETHER_AUTH_TOKEN");
        }
        let state = AuthState::from_env();
        let token = state.token();
        assert_eq!(token.len(), 64);
        assert!(
            token
                .chars()
                .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)),
            "generated token should be lowercase hex"
        );
    }

    #[test]
    fn with_token_round_trips() {
        assert_eq!(AuthState::with_token("secret").token(), "secret");
    }

    #[test]
    fn extract_jwt_claims_reads_payload() {
        let payload =
            URL_SAFE_NO_PAD.encode(br#"{"sub":"alice","topics":["a"],"scopes":["read"]}"#);
        let token = format!("e30.{payload}.sig");
        let claims = extract_jwt_claims(&token).expect("payload is valid");
        assert_eq!(claims.subject.as_deref(), Some("alice"));
        assert_eq!(claims.topics, vec!["a".to_string()]);
        assert_eq!(claims.scopes, vec!["read".to_string()]);
    }

    #[test]
    fn extract_jwt_claims_defaults_missing_fields() {
        // "e30" is base64url for "{}".
        let claims = extract_jwt_claims("e30.e30.sig").expect("empty object is valid");
        assert!(claims.topics.is_empty());
        assert!(claims.scopes.is_empty());
        assert!(claims.subject.is_none());
    }

    #[test]
    fn extract_jwt_claims_rejects_malformed_tokens() {
        assert!(extract_jwt_claims("not-a-jwt").is_none());
        assert!(extract_jwt_claims("e30.e30").is_none());
        assert!(extract_jwt_claims("e30.e30.sig.extra").is_none());
        assert!(extract_jwt_claims("e30.!!!.sig").is_none());
    }
}