use std::net::IpAddr;
use thiserror::Error;
/// Errors returned by the request and configuration checks on `BearerValidator`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum AuthError {
/// No `Authorization` header was supplied while bearer-token auth is configured.
#[error("missing Authorization header")]
MissingHeader,
/// The `Authorization` header is present but does not use the `Bearer` scheme.
#[error("unsupported authorization scheme (expected Bearer)")]
UnsupportedScheme,
/// The bearer credential was rejected (it does not match the configured token).
#[error("invalid bearer token")]
InvalidToken,
/// A request arrived from a non-loopback address while in localhost-only mode.
#[error("request from non-localhost address {0} rejected (localhost-only mode)")]
NonLocalhostAddress(IpAddr),
/// Startup refused: a non-loopback bind address was combined with localhost-only mode.
///
/// NOTE(review): this is returned for any non-loopback bind address, not only `0.0.0.0`,
/// and precisely when localhost-only mode is active (see `check_bind_safety`), so the
/// suggestion to use `--localhost-only` in the message looks contradictory — confirm wording.
#[error("refusing to start: --bind 0.0.0.0 requires --token or --localhost-only")]
UnsafeConfig,
}
/// Authentication mode enforced by `BearerValidator`.
///
/// `Debug` is implemented by hand instead of derived. Formatting a `Bearer` configuration,
/// directly or through a containing struct's `{:?}`, must never print the secret token.
#[derive(Clone)]
pub enum AuthConfig {
    /// Only loopback clients are accepted; no token is checked.
    LocalhostOnly,
    /// Requests must carry `Authorization: Bearer <token>` matching this token.
    Bearer(String),
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LocalhostOnly => f.write_str("LocalhostOnly"),
            // Keep the variant visible for diagnostics but never the secret itself.
            Self::Bearer(_) => f
                .debug_tuple("Bearer")
                .field(&format_args!("<redacted>"))
                .finish(),
        }
    }
}
impl AuthConfig {
#[must_use]
pub fn localhost_only() -> Self {
Self::LocalhostOnly
}
#[must_use]
pub fn bearer(token: String) -> Self {
Self::Bearer(token)
}
#[must_use]
pub fn is_localhost_only(&self) -> bool {
matches!(self, Self::LocalhostOnly)
}
#[must_use]
pub fn is_bearer(&self) -> bool {
matches!(self, Self::Bearer(_))
}
}
/// Performs the authentication checks (bearer header, source IP, bind address) for the
/// configured [`AuthConfig`].
#[derive(Debug, Clone)]
pub struct BearerValidator {
/// The authentication mode every check consults.
config: AuthConfig,
}
impl BearerValidator {
    /// Creates a validator that enforces `config`.
    #[must_use]
    pub fn new(config: AuthConfig) -> Self {
        Self { config }
    }

    /// Checks a raw `Authorization` header value against the configured bearer token.
    ///
    /// In [`AuthConfig::LocalhostOnly`] mode no token is required, and this always returns
    /// `Ok(())`.
    ///
    /// # Errors
    ///
    /// - [`AuthError::UnsupportedScheme`] if the header does not use the `Bearer` scheme.
    /// - [`AuthError::InvalidToken`] if the credential is empty or does not match the
    ///   configured token.
    pub fn validate_bearer(&self, raw_header: &str) -> Result<(), AuthError> {
        let AuthConfig::Bearer(expected) = &self.config else {
            return Ok(());
        };
        let Some(credential) = strip_bearer_scheme(raw_header) else {
            return Err(AuthError::UnsupportedScheme);
        };
        // RFC 6750 bearer credentials are never empty. Rejecting an empty one up front
        // also fails closed when the configured token is accidentally empty. Otherwise a
        // bare `Bearer ` header would compare equal and authenticate.
        if credential.is_empty() {
            return Err(AuthError::InvalidToken);
        }
        if constant_time_eq(credential.as_bytes(), expected.as_bytes()) {
            Ok(())
        } else {
            Err(AuthError::InvalidToken)
        }
    }

    /// In localhost-only mode, rejects requests whose peer address is not loopback.
    ///
    /// IPv4-mapped IPv6 loopback addresses (`::ffff:127.x.y.z`) count as loopback.
    /// In bearer mode every address is accepted.
    ///
    /// # Errors
    ///
    /// [`AuthError::NonLocalhostAddress`] carrying `addr` when the check fails.
    pub fn validate_source_ip(&self, addr: IpAddr) -> Result<(), AuthError> {
        if matches!(self.config, AuthConfig::LocalhostOnly) && !Self::is_loopback_addr(addr) {
            return Err(AuthError::NonLocalhostAddress(addr));
        }
        Ok(())
    }

    /// Validates an optional `Authorization` header.
    ///
    /// A missing header is accepted in localhost-only mode and rejected in bearer mode.
    /// A present header is checked by [`Self::validate_bearer`].
    ///
    /// # Errors
    ///
    /// [`AuthError::MissingHeader`] when `header` is `None` in bearer mode, otherwise any
    /// error from [`Self::validate_bearer`].
    pub fn validate_header(&self, header: Option<&str>) -> Result<(), AuthError> {
        match header {
            None if matches!(self.config, AuthConfig::LocalhostOnly) => Ok(()),
            None => Err(AuthError::MissingHeader),
            Some(h) => self.validate_bearer(h),
        }
    }

    /// Refuses to combine a non-loopback bind address with localhost-only mode.
    ///
    /// # Errors
    ///
    /// [`AuthError::UnsafeConfig`] when `bind_addr` is not loopback and the configuration
    /// is localhost-only.
    pub fn check_bind_safety(&self, bind_addr: IpAddr) -> Result<(), AuthError> {
        if matches!(self.config, AuthConfig::LocalhostOnly) && !Self::is_loopback_addr(bind_addr)
        {
            return Err(AuthError::UnsafeConfig);
        }
        Ok(())
    }

    /// Loopback test that also recognises IPv4-mapped IPv6 loopback addresses.
    ///
    /// On a dual-stack IPv6 socket, IPv4 peers appear as `::ffff:a.b.c.d`, and
    /// `Ipv6Addr::is_loopback` is only true for `::1`. Without this, an IPv4 localhost
    /// client would be rejected. `to_ipv4_mapped` is used rather than `to_ipv4`, because
    /// the latter would also map `::1` to `0.0.0.1`.
    fn is_loopback_addr(addr: IpAddr) -> bool {
        match addr {
            IpAddr::V4(v4) => v4.is_loopback(),
            IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
                Some(v4) => v4.is_loopback(),
                None => v6.is_loopback(),
            },
        }
    }
}
/// Generates a fresh random API token: `axt_` followed by 64 lowercase hex characters
/// (32 random bytes).
///
/// NOTE(review): the token's strength relies on `rand::rng()` being a cryptographically
/// secure generator — re-check this on rand upgrades.
#[cfg(feature = "http-transport")]
#[must_use]
pub fn generate_token() -> String {
    // `fill_bytes` is a method of `RngCore`, not of `Rng`. Trait methods are only callable
    // when their defining trait is in scope, so import `RngCore` directly.
    use rand::RngCore as _;
    let mut bytes = [0u8; 32];
    rand::rng().fill_bytes(&mut bytes);
    format!("axt_{}", hex::encode_bytes(&bytes))
}
/// Builds a deterministic `axt_`-prefixed token of exactly 68 bytes for tests.
///
/// The body is the leading part of `seed` that fits in 64 bytes, cut at a character
/// boundary and right-padded with `'0'` to 64 bytes. The seed is not validated, so the
/// body is hex only if the seed is.
#[must_use]
pub fn make_test_token(seed: &str) -> String {
    const PREFIX: &str = "axt_";
    const BODY_LEN: usize = 64;

    let mut token = String::with_capacity(PREFIX.len() + BODY_LEN);
    token.push_str(PREFIX);

    // Budget in bytes, not chars: a char-based width would overshoot for non-ASCII seeds.
    let mut body_len = 0;
    for ch in seed.chars() {
        let width = ch.len_utf8();
        if body_len + width > BODY_LEN {
            break;
        }
        token.push(ch);
        body_len += width;
    }
    token.push_str(&"0".repeat(BODY_LEN - body_len));
    token
}
/// Returns the credential from an `Authorization` header value that uses the `Bearer`
/// scheme, or `None` for any other scheme.
///
/// The scheme name is matched ASCII case-insensitively. Whitespace before the scheme and
/// around the credential is ignored; HTTP optional whitespace is not part of a field value.
///
/// Slicing goes through `str::get`. A header whose 7th byte falls inside a multi-byte
/// character therefore yields `None` instead of panicking, as `split_at` would. Header
/// values are client-controlled.
fn strip_bearer_scheme(header: &str) -> Option<&str> {
    const SCHEME: &str = "Bearer ";
    let trimmed = header.trim_start();
    let scheme = trimmed.get(..SCHEME.len())?;
    let rest = trimmed.get(SCHEME.len()..)?;
    if scheme.eq_ignore_ascii_case(SCHEME) {
        Some(rest.trim())
    } else {
        None
    }
}
/// Compares two byte slices for equality without stopping at the first mismatch.
///
/// Every byte pair is XOR-ed into an accumulator, so the loop does not exit early on
/// content. Lengths are compared first, so the length of the secret is not hidden.
/// NOTE(review): this is best-effort; nothing formally prevents the optimizer from
/// introducing a data-dependent branch.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut acc = 0u8;
        for (&lhs, &rhs) in a.iter().zip(b) {
            acc |= lhs ^ rhs;
        }
        acc == 0
    } else {
        false
    }
}
/// Lowercase hexadecimal encoding helper.
mod hex {
    /// Encodes `bytes` as lowercase hex, two characters per byte (`[0x0A]` -> `"0a"`).
    #[cfg(feature = "http-transport")]
    pub fn encode_bytes(bytes: &[u8]) -> String {
        use std::fmt::Write as _;
        let mut encoded = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
            let _ = write!(encoded, "{byte:02x}");
        }
        encoded
    }
}
// Unit tests for this module. `use super::*` also brings the private helpers
// (`strip_bearer_scheme`, `constant_time_eq`, `hex`) into scope.
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
// --- AuthConfig constructors and predicates ---
#[test]
fn localhost_only_config_reports_correctly() {
let cfg = AuthConfig::localhost_only();
assert!(cfg.is_localhost_only());
assert!(!cfg.is_bearer());
}
#[test]
fn bearer_config_reports_correctly() {
let cfg = AuthConfig::bearer("secret".into());
assert!(!cfg.is_localhost_only());
assert!(cfg.is_bearer());
}
// --- strip_bearer_scheme: parsing of the raw header value ---
#[test]
fn strip_bearer_scheme_extracts_credential() {
let result = strip_bearer_scheme("Bearer my-secret-token");
assert_eq!(result, Some("my-secret-token"));
}
#[test]
fn strip_bearer_scheme_is_case_insensitive() {
assert_eq!(strip_bearer_scheme("bearer TOKEN"), Some("TOKEN"));
assert_eq!(strip_bearer_scheme("BEARER TOKEN"), Some("TOKEN"));
}
#[test]
fn strip_bearer_scheme_returns_none_for_basic() {
assert_eq!(strip_bearer_scheme("Basic dXNlcjpwYXNz"), None);
}
#[test]
fn strip_bearer_scheme_returns_none_for_empty_string() {
assert_eq!(strip_bearer_scheme(""), None);
}
// --- constant_time_eq: byte-slice equality used for the token comparison ---
#[test]
fn constant_time_eq_matches_identical_slices() {
assert!(constant_time_eq(b"hello", b"hello"));
}
#[test]
fn constant_time_eq_rejects_different_slices() {
assert!(!constant_time_eq(b"hello", b"world"));
}
#[test]
fn constant_time_eq_rejects_different_lengths() {
assert!(!constant_time_eq(b"hi", b"hello"));
}
#[test]
fn constant_time_eq_handles_empty_slices() {
assert!(constant_time_eq(b"", b""));
}
// --- Bearer-token mode ---
/// Builds a validator that expects `Authorization: Bearer <token>`.
fn make_validator(token: &str) -> BearerValidator {
BearerValidator::new(AuthConfig::bearer(token.to_string()))
}
#[test]
fn validate_bearer_accepts_correct_token() {
let v = make_validator("my-secret");
assert!(v.validate_bearer("Bearer my-secret").is_ok());
}
#[test]
fn validate_bearer_rejects_wrong_token() {
let v = make_validator("correct");
assert_eq!(
v.validate_bearer("Bearer wrong"),
Err(AuthError::InvalidToken)
);
}
#[test]
fn validate_bearer_rejects_non_bearer_scheme() {
let v = make_validator("tok");
assert_eq!(
v.validate_bearer("Basic dXNlcjpwYXNz"),
Err(AuthError::UnsupportedScheme)
);
}
#[test]
fn validate_header_returns_missing_header_when_none() {
let v = make_validator("tok");
assert_eq!(v.validate_header(None), Err(AuthError::MissingHeader));
}
#[test]
fn validate_header_accepts_correct_bearer() {
let v = make_validator("tok");
assert!(v.validate_header(Some("Bearer tok")).is_ok());
}
// --- Localhost-only mode ---
/// Builds a validator in localhost-only mode (no token configured).
fn localhost_validator() -> BearerValidator {
BearerValidator::new(AuthConfig::localhost_only())
}
#[test]
fn localhost_only_allows_loopback_ipv4() {
let v = localhost_validator();
let r = v.validate_source_ip(IpAddr::V4(Ipv4Addr::LOCALHOST));
assert!(r.is_ok());
}
#[test]
fn localhost_only_allows_loopback_ipv6() {
let v = localhost_validator();
assert!(v
.validate_source_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))
.is_ok());
}
#[test]
fn localhost_only_rejects_external_ip() {
let v = localhost_validator();
let ext = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));
assert_eq!(
v.validate_source_ip(ext),
Err(AuthError::NonLocalhostAddress(ext))
);
}
// In localhost-only mode both a missing header and any bearer header are accepted:
// token checks are skipped entirely.
#[test]
fn localhost_only_header_validation_skips_token_check() {
let v = localhost_validator();
assert!(v.validate_header(None).is_ok());
assert!(v.validate_bearer("Bearer anything").is_ok());
}
// --- Bind-address safety ---
// 0.0.0.0 (unspecified) is not loopback, so it is refused in localhost-only mode.
#[test]
fn check_bind_safety_rejects_external_bind_without_token() {
let v = localhost_validator();
let addr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));
assert_eq!(v.check_bind_safety(addr), Err(AuthError::UnsafeConfig));
}
#[test]
fn check_bind_safety_allows_loopback_without_token() {
let v = localhost_validator();
assert!(v.check_bind_safety(IpAddr::V4(Ipv4Addr::LOCALHOST)).is_ok());
}
#[test]
fn check_bind_safety_allows_external_with_token() {
let v = make_validator("tok");
let addr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));
assert!(v.check_bind_safety(addr).is_ok());
}
// --- make_test_token: "axt_" prefix + 64-character body = 68 ---
#[test]
fn make_test_token_produces_axt_prefix() {
let t = make_test_token("abc");
assert!(t.starts_with("axt_"));
}
#[test]
fn make_test_token_produces_correct_total_length() {
let t = make_test_token("anything");
assert_eq!(t.len(), 68);
}
#[test]
fn make_test_token_pads_short_seed() {
let t = make_test_token("x");
assert_eq!(t.len(), 68);
}
// --- hex encoding (only compiled with the `http-transport` feature) ---
#[cfg(feature = "http-transport")]
#[test]
fn hex_encode_produces_lowercase_hex() {
assert_eq!(hex::encode_bytes(&[0x0A, 0xFF, 0x00]), "0aff00");
}
#[cfg(feature = "http-transport")]
#[test]
fn hex_encode_empty_slice_is_empty_string() {
assert_eq!(hex::encode_bytes(&[]), "");
}
}