use std::collections::HashMap;
/// Error produced when token authentication fails.
///
/// Wraps a static, human-readable reason; that reason is also exactly what
/// the `Display` implementation prints.
#[derive(Clone, Debug)]
pub struct AuthError {
    /// Human-readable reason for the failure (e.g. which check rejected the request).
    pub message: &'static str,
}

impl AuthError {
    /// Builds an error carrying the given static `message`.
    pub fn new(message: &'static str) -> Self {
        AuthError { message }
    }
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The message is emitted verbatim: no prefix, and width/fill flags are ignored.
        f.write_str(self.message)
    }
}

impl std::error::Error for AuthError {}
/// Authentication token extracted from a request, if one was supplied.
///
/// The `Default` value carries no token.
#[derive(Clone, Debug, Default)]
pub struct TokenAuth {
    /// The supplied token, or `None` when the request carried no token.
    pub token: Option<String>,
}
impl TokenAuth {
    /// Creates a `TokenAuth` carrying `token` (`None` when no token was supplied).
    pub fn new(token: Option<String>) -> Self {
        Self { token }
    }

    /// Checks the carried token against `expected`.
    ///
    /// The bytes are compared without stopping at the first mismatch, so the
    /// time taken does not reveal how long a correct prefix a guessed token
    /// had. This is best-effort: the token *length* is still observable, and
    /// the optimizer gives no hard constant-time guarantee.
    ///
    /// # Errors
    ///
    /// * `"Missing token"` if no token was supplied.
    /// * `"Invalid token"` if the token differs from `expected`.
    pub fn validate(&self, expected: &str) -> Result<(), AuthError> {
        match &self.token {
            Some(token) if Self::constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
                Ok(())
            }
            Some(_) => Err(AuthError::new("Invalid token")),
            None => Err(AuthError::new("Missing token")),
        }
    }

    /// Validates against the expected token stored in environment variable `env_var`.
    ///
    /// If `env_var` is not set at all, authentication is treated as disabled
    /// and this returns `Ok(())`, even when no token was supplied.
    ///
    /// # Errors
    ///
    /// * The same errors as [`TokenAuth::validate`] when the variable is set.
    /// * `"Invalid token configuration"` when the variable is set but is not
    ///   valid Unicode. This fails closed instead of silently disabling auth.
    pub fn validate_env(&self, env_var: &str) -> Result<(), AuthError> {
        match std::env::var(env_var) {
            Ok(expected) => self.validate(&expected),
            // Unset variable: authentication is intentionally optional.
            Err(std::env::VarError::NotPresent) => Ok(()),
            // Set but unreadable: a misconfiguration must not open the door.
            Err(std::env::VarError::NotUnicode(_)) => {
                Err(AuthError::new("Invalid token configuration"))
            }
        }
    }

    /// Returns `true` if a token was supplied, regardless of whether it is valid.
    pub fn has_token(&self) -> bool {
        self.token.is_some()
    }

    /// Builds a `TokenAuth` from a raw URL query string (without the leading `?`).
    ///
    /// The `token` parameter is percent-decoded by `form_urlencoded::parse`.
    /// If it appears more than once, the last occurrence is used. Anything
    /// from the first `/api/` onward is discarded. This tolerates firmware
    /// that appends the request path to the token (e.g. `token=secret/api/display`),
    /// the same normalization the request extractor applies.
    pub fn from_query_string(query: &str) -> Self {
        let params: HashMap<_, _> = form_urlencoded::parse(query.as_bytes()).collect();
        Self {
            token: params.get("token").map(|s| Self::strip_firmware_suffix(s)),
        }
    }

    /// Returns `raw` truncated before its first `/api/` (or all of `raw` if absent).
    fn strip_firmware_suffix(raw: &str) -> String {
        raw.split_once("/api/")
            .map_or(raw, |(token, _path)| token)
            .to_owned()
    }

    /// Byte-wise equality that inspects every byte once the lengths match.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        // OR-accumulate differences rather than returning at the first mismatch.
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }
}
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use axum::extract::FromRequestParts;
    use axum::http::request::Parts;
    use axum::http::StatusCode;

    /// Pulls the `token` query parameter (if any) out of the request URI.
    ///
    /// Extraction never rejects. Checking the token is left to the handler
    /// (e.g. via `TokenAuth::validate` / `TokenAuth::validate_env`). Only the
    /// first `token` parameter is considered. Anything from the first `/api/`
    /// onward is cut off, so a path appended to the token value is ignored.
    impl<S> FromRequestParts<S> for TokenAuth
    where
        S: Send + Sync,
    {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let Some(query) = parts.uri.query() else {
                return Ok(TokenAuth { token: None });
            };

            let token = form_urlencoded::parse(query.as_bytes())
                .find(|(key, _)| key == "token")
                .map(|(_, value)| {
                    let mut owned = value.into_owned();
                    if let Some(cut) = owned.find("/api/") {
                        owned.truncate(cut);
                    }
                    owned
                });

            Ok(TokenAuth { token })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `TokenAuth` that carries `token`.
    fn carrying(token: &str) -> TokenAuth {
        TokenAuth::new(Some(token.to_owned()))
    }

    #[test]
    fn test_token_auth_validate() {
        let auth = carrying("secret123");
        // An exact match is accepted; anything else is rejected.
        assert!(auth.validate("secret123").is_ok());
        assert!(auth.validate("wrong").is_err());
    }

    #[test]
    fn test_token_auth_missing() {
        // Without a token, validation must fail whatever the expected value is.
        assert!(TokenAuth::new(None).validate("anything").is_err());
    }

    #[test]
    fn test_from_query_string() {
        let present = TokenAuth::from_query_string("token=mysecret&other=value");
        assert_eq!(present.token.as_deref(), Some("mysecret"));

        let absent = TokenAuth::from_query_string("other=value");
        assert!(absent.token.is_none());
    }

    #[test]
    fn test_validate_env_not_set() {
        // With the variable unset, auth is disabled, so even a missing token passes.
        let auth = TokenAuth::new(None);
        assert!(auth.validate_env("NONEXISTENT_VAR_12345").is_ok());
    }

    #[test]
    fn test_from_query_string_with_firmware_quirk() {
        let raw = TokenAuth::from_query_string("token=mysecret/api/display")
            .token
            .unwrap();
        // Drop the request path that the firmware appends after the token.
        let clean = raw
            .split_once("/api/")
            .map_or(raw.as_str(), |(head, _)| head);
        assert_eq!(clean, "mysecret");
    }
}