use axum::{
body::Body,
extract::Request,
http::{header::AUTHORIZATION, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
/// Bearer-token authentication settings for the API.
///
/// `token: None` means authentication is disabled; `Some(t)` requires every
/// request to present `t`.
///
/// `Debug` is implemented by hand so the secret never ends up in logs or
/// panic messages: a configured token is rendered as `"<redacted>"`.
#[derive(Clone)]
pub struct AuthConfig {
    /// The expected bearer token, or `None` to disable authentication.
    pub token: Option<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a token is configured, never its value.
        f.debug_struct("AuthConfig")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl AuthConfig {
    /// Resolves the API token, preferring the environment over disk.
    ///
    /// Sources are tried in order. The first one whose value is non-empty
    /// after trimming whitespace wins:
    ///
    /// 1. the `ISELF_API_TOKEN` environment variable;
    /// 2. the file `.i-self/api_token` under the directory returned by
    ///    `dirs::home_dir()`.
    ///
    /// If a source fails (variable unset or not valid Unicode, home directory
    /// unknown, file unreadable), the next source is tried silently. If no
    /// source yields a token, the result is `token: None`, which makes
    /// `require_bearer` let every request through.
    pub fn from_env_or_disk() -> Self {
        /// Trims `raw` and keeps it only if something remains.
        fn non_blank(raw: String) -> Option<String> {
            let trimmed = raw.trim();
            if trimmed.is_empty() {
                None
            } else {
                Some(trimmed.to_owned())
            }
        }

        let token = std::env::var("ISELF_API_TOKEN")
            .ok()
            .and_then(non_blank)
            // Lazily consult the disk only when the environment gave nothing.
            .or_else(|| {
                let token_file = dirs::home_dir()?.join(".i-self").join("api_token");
                std::fs::read_to_string(token_file).ok().and_then(non_blank)
            });

        Self { token }
    }
}
/// Axum middleware that enforces bearer-token authentication.
///
/// If `cfg.token` is `None`, authentication is disabled and the request is
/// forwarded unchanged. Otherwise the presented credential comes from one of
/// two places:
///
/// - an `Authorization` header using the `Bearer` scheme, when that yields a
///   non-empty token;
/// - otherwise, a `token=` query-string parameter.
///
/// The credential is compared with the configured token using
/// [`constant_time_eq`]. A mismatch, or no credential at all, produces
/// `401 Unauthorized` with a `WWW-Authenticate: Bearer realm="i-self"`
/// challenge.
pub async fn require_bearer(
    axum::extract::State(cfg): axum::extract::State<Arc<AuthConfig>>,
    req: Request,
    next: Next,
) -> Response {
    let expected = match cfg.token.as_deref() {
        Some(token) => token,
        // No token configured: authentication is disabled.
        None => return next.run(req).await,
    };

    // Scope the borrows of `req` so it can be moved into `next.run` below.
    let authorized = {
        let header_token = req
            .headers()
            .get(AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .and_then(bearer_token_from_header)
            // An empty `Bearer ` credential counts as absent, so the query
            // string still gets a chance.
            .filter(|token| !token.is_empty());
        let presented = header_token
            .or_else(|| req.uri().query().and_then(extract_query_token))
            .unwrap_or("");
        constant_time_eq(presented.as_bytes(), expected.as_bytes())
    };

    if authorized {
        next.run(req).await
    } else {
        (
            StatusCode::UNAUTHORIZED,
            [("WWW-Authenticate", "Bearer realm=\"i-self\"")],
            Body::from("missing or invalid bearer token"),
        )
            .into_response()
    }
}

/// Returns the credential from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// Returns `None` for any other scheme, or for a value with no space after
/// the scheme name.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1). Whitespace
/// around the credential is trimmed because RFC 6750 §2.1 allows one or more
/// spaces after the scheme.
fn bearer_token_from_header(value: &str) -> Option<&str> {
    let (scheme, credential) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("Bearer")
        .then(|| credential.trim())
}
/// Returns the raw value of the first `token=` parameter in a query string.
///
/// Parameters are separated by `&` and split at their first `=`.
///
/// - Parameters without an `=` (including a bare `token`) are skipped.
/// - `token=` yields `Some("")`.
/// - The value is returned exactly as it appears; no percent-decoding is
///   performed.
fn extract_query_token(query: &str) -> Option<&str> {
    query
        .split('&')
        .filter_map(|param| param.split_once('='))
        .find(|&(name, _)| name == "token")
        .map(|(_, value)| value)
}
/// Compares two byte slices for equality without stopping at the first
/// mismatching byte.
///
/// Slices of different lengths are rejected immediately, so the length of
/// the expected token is not hidden; only its contents are. For equal lengths
/// every byte pair is XOR-ed and OR-accumulated, so the work done does not
/// depend on where the first difference occurs.
///
/// This is best-effort: neither Rust nor LLVM formally guarantees
/// constant-time execution.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // No early exit: every byte pair contributes to the accumulator.
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request as HttpRequest, StatusCode as HttpStatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    #[test]
    fn constant_time_eq_matches() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(!constant_time_eq(b"abc", b""));
        assert!(constant_time_eq(b"", b""));
    }

    /// Smoke test only. The result depends on the caller's environment and
    /// home directory, so no particular token value can be asserted here.
    #[test]
    fn from_env_or_disk_does_not_panic() {
        let _ = AuthConfig::from_env_or_disk();
    }

    /// Builds a router whose single `/api/secret` route is guarded by
    /// [`require_bearer`] configured with `token`.
    fn router_with_token(token: Option<&str>) -> Router {
        let cfg = Arc::new(AuthConfig {
            token: token.map(String::from),
        });
        Router::new()
            .route("/api/secret", get(|| async { "ok" }))
            .route_layer(middleware::from_fn_with_state(cfg, require_bearer))
    }

    /// Sends one GET request for `uri` through a router configured with
    /// `token`, optionally carrying an `Authorization` header, and returns the
    /// response status.
    async fn status_for(
        token: Option<&str>,
        uri: &str,
        authorization: Option<&str>,
    ) -> HttpStatusCode {
        let mut builder = HttpRequest::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        router_with_token(token)
            .oneshot(request)
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn no_token_configured_means_no_auth_required() {
        assert_eq!(status_for(None, "/api/secret", None).await, HttpStatusCode::OK);
    }

    #[tokio::test]
    async fn missing_authorization_header_rejected() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret", None).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn wrong_token_rejected() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret", Some("Bearer wrong")).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn correct_token_accepted() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret", Some("Bearer s3cret")).await,
            HttpStatusCode::OK
        );
    }

    #[tokio::test]
    async fn missing_bearer_prefix_rejected() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret", Some("s3cret")).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn non_bearer_scheme_rejected() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret", Some("Basic s3cret")).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn query_string_token_accepted() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret?token=s3cret", None).await,
            HttpStatusCode::OK
        );
    }

    #[tokio::test]
    async fn query_string_token_works_with_other_params() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret?foo=bar&token=s3cret&baz=qux", None).await,
            HttpStatusCode::OK
        );
    }

    #[tokio::test]
    async fn query_string_wrong_token_rejected() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret?token=wrong", None).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn empty_query_string_token_rejected() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret?token=", None).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn header_takes_precedence_over_query_string() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret?token=wrong", Some("Bearer s3cret")).await,
            HttpStatusCode::OK
        );
    }

    #[tokio::test]
    async fn wrong_header_is_not_rescued_by_query_string() {
        assert_eq!(
            status_for(Some("s3cret"), "/api/secret?token=s3cret", Some("Bearer wrong")).await,
            HttpStatusCode::UNAUTHORIZED
        );
    }

    #[test]
    fn extract_query_token_handles_edge_cases() {
        assert_eq!(extract_query_token("token=abc"), Some("abc"));
        assert_eq!(extract_query_token("a=1&token=abc&b=2"), Some("abc"));
        assert_eq!(extract_query_token(""), None);
        assert_eq!(extract_query_token("token"), None);
        assert_eq!(extract_query_token("foo=bar"), None);
        assert_eq!(extract_query_token("token="), Some(""));
        assert_eq!(extract_query_token("token=a&token=b"), Some("a"));
    }
}