use std::collections::HashMap;
use axum::http::HeaderMap;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
/// Identity attached to a request that [`AuthConfig::authenticate`] accepted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Auth {
    /// A configured API key matched; the payload is the project that key maps to.
    Project(String),
    /// No configured key matched, but dev mode is enabled so the request was admitted anyway.
    DevBypass,
}

impl Auth {
    /// Name of the authenticated project, or `None` when the request came in via the dev bypass.
    pub fn project(&self) -> Option<&str> {
        match self {
            Self::DevBypass => None,
            Self::Project(name) => Some(name.as_str()),
        }
    }
}
/// Authentication settings: the table of accepted API keys plus the dev-mode bypass switch.
///
/// The derived `Default` has no keys and dev mode off, i.e. a configuration under which
/// every request is rejected.
#[derive(Clone, Default)]
pub struct AuthConfig {
    /// Hex-encoded SHA-256 digest of each accepted API key, mapped to the project it grants.
    /// Raw keys are hashed on construction and never kept.
    keys: HashMap<String, String>,
    /// When `true`, requests that present no matching key are still admitted as a dev bypass.
    dev_mode: bool,
}
impl AuthConfig {
    /// Builds a strict configuration from the environment.
    ///
    /// Keys are read from `TRACE_WEFT_API_KEYS` and the dev-mode switch from
    /// `TRACE_WEFT_DEV_MODE`. When the switch is unset or unrecognized, dev mode is **off**.
    /// Logs a warning if the resulting configuration would reject every request.
    pub fn from_env() -> Self {
        let raw_keys = raw_keys_from_env();
        let dev_mode = dev_mode_from_env().unwrap_or(false);
        let config = Self::new(raw_keys, dev_mode);
        config.warn_if_rejecting_everything();
        config
    }

    /// Builds a configuration aimed at local use.
    ///
    /// Reads the same variables as [`AuthConfig::from_env`]. When `TRACE_WEFT_DEV_MODE` is
    /// unset or unrecognized, dev mode is on exactly when no API keys are configured.
    /// Logs a warning if the resulting configuration would reject every request.
    pub fn from_env_local_first() -> Self {
        let raw_keys = raw_keys_from_env();
        let dev_mode = dev_mode_from_env().unwrap_or(raw_keys.is_empty());
        let config = Self::new(raw_keys, dev_mode);
        // An explicit "dev mode off" with no keys locks everyone out; surface that at
        // startup, as `from_env` does, instead of failing silently at request time.
        config.warn_if_rejecting_everything();
        config
    }

    /// Creates a configuration from raw `(api_key, project)` pairs.
    ///
    /// Each key is replaced by its hex-encoded SHA-256 digest before being stored. If the
    /// same key appears more than once, the last pair wins (plain `HashMap` collection).
    pub fn new(raw_keys: impl IntoIterator<Item = (String, String)>, dev_mode: bool) -> Self {
        let keys = raw_keys
            .into_iter()
            .map(|(key, project)| (hash_key(&key), project))
            .collect();
        Self { keys, dev_mode }
    }

    /// Authenticates a request from its headers.
    ///
    /// Returns [`Auth::Project`] when the bearer token matches a configured key. Otherwise it
    /// returns [`Auth::DevBypass`] if dev mode is on, and `None` if it is off. A valid key
    /// always takes precedence over the bypass.
    pub fn authenticate(&self, headers: &HeaderMap) -> Option<Auth> {
        if let Some(project) = bearer_token(headers).and_then(|token| self.lookup(&token)) {
            return Some(Auth::Project(project));
        }
        self.dev_mode.then_some(Auth::DevBypass)
    }

    /// Finds the project for a presented (raw) API key.
    ///
    /// Every stored digest is compared with `ct_eq`, and the loop deliberately does not stop
    /// at the first hit, so the number of comparisons does not depend on where (or whether)
    /// a match occurs.
    fn lookup(&self, presented: &str) -> Option<String> {
        let presented_hash = hash_key(presented);
        let mut matched: Option<String> = None;
        for (stored_hash, project) in &self.keys {
            if bool::from(stored_hash.as_bytes().ct_eq(presented_hash.as_bytes())) {
                matched = Some(project.clone());
            }
        }
        matched
    }

    /// Logs a warning when no key is configured and the dev bypass is off, i.e. when every
    /// call to [`AuthConfig::authenticate`] will return `None`.
    fn warn_if_rejecting_everything(&self) {
        if self.keys.is_empty() && !self.dev_mode {
            tracing::warn!(
                "No API keys configured and dev mode is off; all requests will be rejected with 401. \
                 Set TRACE_WEFT_API_KEYS or TRACE_WEFT_DEV_MODE=1."
            );
        }
    }
}
/// Reads `TRACE_WEFT_API_KEYS` and parses it into `(key, project)` pairs.
///
/// The variable holds comma-separated `key:project` entries, split on the first `:`, so a
/// project may itself contain `:`. Whitespace around each entry and around both halves is
/// ignored. Entries without a `:`, or with an empty key or project, are skipped. An unset or
/// non-UTF-8 variable yields no pairs.
fn raw_keys_from_env() -> Vec<(String, String)> {
    let raw = std::env::var("TRACE_WEFT_API_KEYS").unwrap_or_default();
    let mut pairs = Vec::new();
    for entry in raw.split(',') {
        if let Some((key_part, project_part)) = entry.trim().split_once(':') {
            let key = key_part.trim();
            let project = project_part.trim();
            if key.is_empty() || project.is_empty() {
                continue;
            }
            pairs.push((key.to_owned(), project.to_owned()));
        }
    }
    pairs
}
/// Reads the explicit dev-mode setting from `TRACE_WEFT_DEV_MODE`.
///
/// The value is trimmed. `1`/`true` (any ASCII case) mean on, and `0`/`false` (any ASCII
/// case) mean off. Returns `None` when the variable is unset, not UTF-8, or holds anything
/// else, and callers then apply their own default.
///
/// Case-insensitive matching matters for safety: with exact matching, `False` would be
/// treated as unset, and `AuthConfig::from_env_local_first` would then default the bypass to
/// *on* when no keys are configured, the opposite of what the operator asked for.
fn dev_mode_from_env() -> Option<bool> {
    let raw = std::env::var("TRACE_WEFT_DEV_MODE").ok()?;
    let value = raw.trim();
    if value == "1" || value.eq_ignore_ascii_case("true") {
        Some(true)
    } else if value == "0" || value.eq_ignore_ascii_case("false") {
        Some(false)
    } else {
        None
    }
}
/// Extracts the bearer token from the `Authorization` header.
///
/// The header is split at its first space into scheme and credentials, and the scheme is
/// compared case-insensitively. Auth-scheme names are case-insensitive (RFC 7235 §2.1), so
/// `Bearer`, `bearer` and `BEARER` are all accepted. Surrounding whitespace is trimmed from
/// the token.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - `HeaderValue::to_str` fails;
/// - the value has no space;
/// - the scheme is not `Bearer`;
/// - the token is empty.
fn bearer_token(headers: &HeaderMap) -> Option<String> {
    let value = headers.get("Authorization")?.to_str().ok()?;
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then(|| token.to_string())
}
/// Hashes an API key with SHA-256 and returns the digest as lowercase hex
/// (64 characters for the 32-byte digest).
fn hash_key(key: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let digest = Sha256::digest(key.as_bytes());
    let mut encoded = String::with_capacity(digest.len() * 2);
    for byte in digest.iter() {
        // High nibble first, matching `{:02x}` formatting; both indices are always < 16.
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a header map containing only `Authorization: <auth>`.
    /// Panics via `unwrap` if `auth` is not a valid header value (acceptable in tests).
    fn headers_with(auth: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", HeaderValue::from_str(auth).unwrap());
        headers
    }

    /// Strict fixture: two keys mapped to two distinct projects, dev mode off.
    fn config() -> AuthConfig {
        AuthConfig::new(
            [
                ("tw-alpha-key".to_string(), "proj_alpha".to_string()),
                ("tw-beta-key".to_string(), "proj_beta".to_string()),
            ],
            false,
        )
    }

    /// Each configured key resolves to its own project, not just "some" project.
    #[test]
    fn valid_key_resolves_to_its_project() {
        let auth = config().authenticate(&headers_with("Bearer tw-alpha-key"));
        assert_eq!(auth, Some(Auth::Project("proj_alpha".to_string())));
        let auth = config().authenticate(&headers_with("Bearer tw-beta-key"));
        assert_eq!(auth, Some(Auth::Project("proj_beta".to_string())));
    }

    /// A well-formed bearer header with an unconfigured key is rejected when dev mode is off.
    #[test]
    fn unknown_key_is_rejected() {
        assert_eq!(
            config().authenticate(&headers_with("Bearer tw-unknown")),
            None
        );
    }

    /// Covers three rejected inputs: no header, a key without the `Bearer ` scheme,
    /// and the scheme with an empty token.
    #[test]
    fn missing_or_malformed_header_is_rejected() {
        assert_eq!(config().authenticate(&HeaderMap::new()), None);
        assert_eq!(config().authenticate(&headers_with("tw-alpha-key")), None);
        assert_eq!(config().authenticate(&headers_with("Bearer ")), None);
    }

    /// With no keys at all, the outcome is decided purely by the dev-mode flag.
    #[test]
    fn dev_bypass_only_works_when_enabled() {
        let strict = AuthConfig::new([], false);
        assert_eq!(strict.authenticate(&HeaderMap::new()), None);
        let dev = AuthConfig::new([], true);
        assert_eq!(dev.authenticate(&HeaderMap::new()), Some(Auth::DevBypass));
    }

    /// In dev mode a valid key still yields its project.
    /// An invalid key falls back to the bypass rather than being rejected.
    #[test]
    fn valid_key_takes_precedence_over_dev_bypass() {
        let dev = AuthConfig::new(
            [("tw-alpha-key".to_string(), "proj_alpha".to_string())],
            true,
        );
        assert_eq!(
            dev.authenticate(&headers_with("Bearer tw-alpha-key")),
            Some(Auth::Project("proj_alpha".to_string()))
        );
        assert_eq!(
            dev.authenticate(&headers_with("Bearer tw-nope")),
            Some(Auth::DevBypass)
        );
    }

    /// Exercises the env-driven constructors.
    ///
    /// This test mutates process-wide environment variables, and each assertion depends on
    /// the mutations made before it, so the steps must stay in this order.
    #[test]
    fn local_first_defaults_bypass_on_only_without_keys() {
        // SAFETY (review): `set_var`/`remove_var` are only sound when no other thread accesses
        // the environment concurrently. The other tests in this module only call
        // `AuthConfig::new`, which does not read the environment. Confirm nothing else in the
        // test binary reads env vars in parallel.
        unsafe {
            std::env::remove_var("TRACE_WEFT_DEV_MODE");
            std::env::remove_var("TRACE_WEFT_API_KEYS");
        }
        // No keys and no explicit dev-mode setting: local-first enables the bypass,
        // the strict constructor does not.
        assert_eq!(
            AuthConfig::from_env_local_first().authenticate(&HeaderMap::new()),
            Some(Auth::DevBypass)
        );
        assert_eq!(AuthConfig::from_env().authenticate(&HeaderMap::new()), None);
        // Once a key is configured, local-first defaults dev mode off; only the key authenticates.
        unsafe { std::env::set_var("TRACE_WEFT_API_KEYS", "tw-x:proj_x") }
        let local = AuthConfig::from_env_local_first();
        assert_eq!(local.authenticate(&HeaderMap::new()), None);
        assert_eq!(
            local.authenticate(&headers_with("Bearer tw-x")),
            Some(Auth::Project("proj_x".to_string()))
        );
        // An explicit TRACE_WEFT_DEV_MODE=1 overrides the default even though a key exists.
        unsafe { std::env::set_var("TRACE_WEFT_DEV_MODE", "1") }
        assert_eq!(
            AuthConfig::from_env_local_first().authenticate(&HeaderMap::new()),
            Some(Auth::DevBypass)
        );
        // Restore a clean environment.
        // NOTE(review): this cleanup is skipped if an assertion above panics.
        unsafe {
            std::env::remove_var("TRACE_WEFT_DEV_MODE");
            std::env::remove_var("TRACE_WEFT_API_KEYS");
        }
    }

    /// Only digests are stored. The raw key is not a map key, and every stored key is
    /// 64 characters long (hex encoding of a 32-byte SHA-256 digest).
    #[test]
    fn stored_config_holds_hashes_not_raw_keys() {
        let config = config();
        assert!(!config.keys.contains_key("tw-alpha-key"));
        assert!(config.keys.keys().all(|k| k.len() == 64));
    }
}