use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum AuthError {
#[error("Authentication token missing")]
TokenMissing,
#[error("Invalid token format: {0}")]
InvalidFormat(String),
#[error("Token has expired")]
TokenExpired,
#[error("Invalid token signature")]
InvalidSignature,
#[error("Token validation failed: {0}")]
ValidationFailed(String),
#[error("Insufficient permissions: {0}")]
InsufficientPermissions(String),
}
impl AuthError {
pub fn validation_failed(msg: impl Into<String>) -> Self {
Self::ValidationFailed(msg.into())
}
pub fn invalid_format(msg: impl Into<String>) -> Self {
Self::InvalidFormat(msg.into())
}
pub fn insufficient_permissions(msg: impl Into<String>) -> Self {
Self::InsufficientPermissions(msg.into())
}
}
/// Identity information returned by a successful token validation.
///
/// `PartialEq`/`Eq` are derived so claims can be compared directly, e.g. when
/// asserting what a validator produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Claims {
    /// Subject the token identifies (e.g. a user id).
    pub sub: String,
    /// Additional string-valued claims keyed by name.
    pub extra: HashMap<String, String>,
}

impl Claims {
    /// Creates claims for `sub` with no extra entries.
    pub fn new(sub: impl Into<String>) -> Self {
        Self {
            sub: sub.into(),
            extra: HashMap::new(),
        }
    }

    /// Creates claims for `sub` with a pre-populated `extra` map.
    pub fn with_extra(sub: impl Into<String>, extra: HashMap<String, String>) -> Self {
        Self {
            sub: sub.into(),
            extra,
        }
    }

    /// Returns the subject.
    pub fn subject(&self) -> &str {
        &self.sub
    }

    /// Looks up an extra claim by `key`, returning `None` if it is absent.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.extra.get(key).map(String::as_str)
    }

    /// Sets an extra claim, overwriting any previous value stored under `key`.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.extra.insert(key.into(), value.into());
    }
}
/// Where to look for the authentication token on an incoming upgrade request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenExtractor {
    /// Read the named HTTP header. A leading `Bearer ` scheme (matched
    /// case-insensitively) is stripped; any other value is used verbatim.
    Header(String),
    /// Read the named query-string parameter (form-urlencoded decoded).
    Query(String),
    /// Read the raw `Sec-WebSocket-Protocol` header value.
    Protocol,
}

impl Default for TokenExtractor {
    /// Defaults to the `Authorization` header.
    fn default() -> Self {
        Self::Header("Authorization".to_string())
    }
}

impl TokenExtractor {
    /// Authentication scheme prefix stripped from header values. Auth schemes
    /// are case-insensitive (RFC 7235 §2.1), so it is compared ignoring ASCII case.
    const BEARER_PREFIX: &'static str = "Bearer ";

    /// Extracts the token from the header called `name`.
    pub fn header(name: impl Into<String>) -> Self {
        Self::Header(name.into())
    }

    /// Extracts the token from the query parameter called `name`.
    pub fn query(name: impl Into<String>) -> Self {
        Self::Query(name.into())
    }

    /// Extracts the token from the `Sec-WebSocket-Protocol` header.
    pub fn protocol() -> Self {
        Self::Protocol
    }

    /// Extracts the token from `req`.
    ///
    /// Returns `None` when the token is absent, when it is empty (e.g.
    /// `Authorization: Bearer ` or `?token=`), or when the header value cannot
    /// be read as a string (`HeaderValue::to_str` fails). Treating an empty
    /// token as absent lets callers report it as missing instead of handing an
    /// empty string to the validator.
    pub fn extract<B>(&self, req: &http::Request<B>) -> Option<String> {
        let token = match self {
            Self::Header(name) => req
                .headers()
                .get(name)
                .and_then(|v| v.to_str().ok())
                .map(|value| Self::strip_bearer(value).to_string()),
            Self::Query(name) => req.uri().query().and_then(|query| {
                url::form_urlencoded::parse(query.as_bytes())
                    .find(|(key, _)| key == name)
                    .map(|(_, value)| value.into_owned())
            }),
            // NOTE: this header may carry a comma-separated list of
            // subprotocols; the whole value is returned unchanged.
            Self::Protocol => req
                .headers()
                .get("Sec-WebSocket-Protocol")
                .and_then(|v| v.to_str().ok())
                .map(str::to_string),
        };
        token.filter(|t| !t.is_empty())
    }

    /// Strips a case-insensitive `Bearer ` prefix (and surrounding whitespace)
    /// from a header value; values without the prefix are returned unchanged.
    fn strip_bearer(value: &str) -> &str {
        let split = Self::BEARER_PREFIX.len();
        // `get` (not indexing) so short or non-ASCII values never panic.
        match (value.get(..split), value.get(split..)) {
            (Some(scheme), Some(rest)) if scheme.eq_ignore_ascii_case(Self::BEARER_PREFIX) => {
                rest.trim()
            }
            _ => value,
        }
    }
}
/// Checks a raw token string and, on success, produces the caller's [`Claims`].
///
/// The `Sync + Send` bounds allow a validator to be stored behind an
/// `Arc<dyn TokenValidator>` and shared across threads.
#[async_trait::async_trait]
pub trait TokenValidator: Sync + Send {
    /// Validates `token`, returning its claims or an [`AuthError`] explaining
    /// why it was rejected.
    async fn validate(&self, token: &str) -> Result<Claims, AuthError>;
}
/// Authentication settings for WebSocket upgrade requests: where to find the
/// token, how to validate it, and whether a token is mandatory.
#[derive(Clone)]
pub struct WsAuthConfig {
    /// Strategy used to locate the token on the request.
    pub extractor: TokenExtractor,
    /// Shared validator that turns a token into [`Claims`].
    pub validator: Arc<dyn TokenValidator>,
    /// When `true`, requests without a token fail with [`AuthError::TokenMissing`].
    pub required: bool,
}

impl WsAuthConfig {
    /// Creates a config that validates tokens with `validator`, reads them via
    /// [`TokenExtractor::default`] and requires them to be present.
    #[must_use]
    pub fn new<V: TokenValidator + 'static>(validator: V) -> Self {
        Self {
            extractor: TokenExtractor::default(),
            validator: Arc::new(validator),
            required: true,
        }
    }

    /// Replaces the token extractor (builder style).
    #[must_use]
    pub fn extractor(mut self, extractor: TokenExtractor) -> Self {
        self.extractor = extractor;
        self
    }

    /// Sets whether a token must be present (builder style).
    #[must_use]
    pub fn required(mut self, required: bool) -> Self {
        self.required = required;
        self
    }

    /// Authenticates `req`.
    ///
    /// Returns `Ok(Some(claims))` when a token was found and accepted, and
    /// `Ok(None)` when no token was found and authentication is optional.
    ///
    /// # Errors
    ///
    /// * [`AuthError::TokenMissing`] if no token was found and `required` is `true`.
    /// * Any error returned by the validator. A token that is present but
    ///   rejected is an error even when `required` is `false`, so a bad
    ///   credential never silently downgrades to anonymous access.
    pub async fn authenticate<B>(
        &self,
        req: &http::Request<B>,
    ) -> Result<Option<Claims>, AuthError> {
        match self.extractor.extract(req) {
            Some(token) => self.validator.validate(&token).await.map(Some),
            None if self.required => Err(AuthError::TokenMissing),
            None => Ok(None),
        }
    }
}

impl std::fmt::Debug for WsAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn TokenValidator` has no `Debug` bound, so the validator cannot be
        // printed; `finish_non_exhaustive` renders a trailing `..` to make the
        // omission visible instead of implying the struct has only two fields.
        f.debug_struct("WsAuthConfig")
            .field("extractor", &self.extractor)
            .field("required", &self.required)
            .finish_non_exhaustive()
    }
}
/// Validator that accepts every non-empty token and uses the token itself as
/// the subject of the returned [`Claims`].
///
/// It performs no verification beyond the emptiness check, so it is only
/// suitable for tests and local development.
pub struct AcceptAllValidator;
#[async_trait::async_trait]
impl TokenValidator for AcceptAllValidator {
    async fn validate(&self, token: &str) -> Result<Claims, AuthError> {
        // Only the empty string is refused; anything else becomes the subject.
        (!token.is_empty())
            .then(|| Claims::new(token))
            .ok_or_else(|| AuthError::invalid_format("Token cannot be empty"))
    }
}
/// Validator that refuses every token; useful for exercising rejection paths.
pub struct RejectAllValidator;
#[async_trait::async_trait]
impl TokenValidator for RejectAllValidator {
    async fn validate(&self, _token: &str) -> Result<Claims, AuthError> {
        // The token is never inspected: every call fails the same way.
        Err(AuthError::ValidationFailed(String::from("All tokens rejected")))
    }
}
/// Validator backed by a fixed in-memory table mapping token strings to claims.
///
/// Tokens are matched by exact string equality, and the stored claims are
/// cloned out on each successful lookup.
#[derive(Default)]
pub struct StaticTokenValidator {
    tokens: HashMap<String, Claims>,
}
impl StaticTokenValidator {
    /// Creates a validator with no registered tokens; it rejects every token.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `token` so that validating it yields `claims` (builder style).
    /// Registering the same token again replaces its previous claims.
    pub fn add_token(mut self, token: impl Into<String>, claims: Claims) -> Self {
        let key: String = token.into();
        self.tokens.insert(key, claims);
        self
    }
}
#[async_trait::async_trait]
impl TokenValidator for StaticTokenValidator {
    async fn validate(&self, token: &str) -> Result<Claims, AuthError> {
        match self.tokens.get(token) {
            Some(claims) => Ok(claims.clone()),
            None => Err(AuthError::validation_failed("Invalid token")),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    /// Builds a body-less request that carries a single header.
    fn request_with_header(name: &str, value: &str) -> Request<()> {
        Request::builder()
            .header(name, value)
            .body(())
            .expect("test header should form a valid request")
    }

    /// Builds a body-less request for `uri` with no headers.
    fn request_for_uri(uri: &str) -> Request<()> {
        Request::builder()
            .uri(uri)
            .body(())
            .expect("test URI should form a valid request")
    }

    /// Builds a body-less request with no headers and the default URI.
    fn empty_request() -> Request<()> {
        Request::builder()
            .body(())
            .expect("empty request should be valid")
    }

    #[test]
    fn test_token_extractor_header() {
        let extractor = TokenExtractor::header("Authorization");
        let req = request_with_header("Authorization", "Bearer test-token");
        assert_eq!(extractor.extract(&req), Some("test-token".to_string()));
    }

    #[test]
    fn test_token_extractor_header_no_bearer() {
        let extractor = TokenExtractor::header("X-API-Key");
        let req = request_with_header("X-API-Key", "my-api-key");
        assert_eq!(extractor.extract(&req), Some("my-api-key".to_string()));
    }

    #[test]
    fn test_token_extractor_header_other_name_ignored() {
        let extractor = TokenExtractor::header("X-API-Key");
        let req = request_with_header("Authorization", "Bearer test-token");
        assert_eq!(extractor.extract(&req), None);
    }

    #[test]
    fn test_token_extractor_query() {
        let extractor = TokenExtractor::query("token");
        let req = request_for_uri("ws://localhost/ws?token=query-token&other=value");
        assert_eq!(extractor.extract(&req), Some("query-token".to_string()));
    }

    #[test]
    fn test_token_extractor_query_is_percent_decoded() {
        let extractor = TokenExtractor::query("token");
        let req = request_for_uri("ws://localhost/ws?token=a%20b%2Bc");
        assert_eq!(extractor.extract(&req), Some("a b+c".to_string()));
    }

    #[test]
    fn test_token_extractor_query_missing_param() {
        let extractor = TokenExtractor::query("token");
        let req = request_for_uri("ws://localhost/ws?other=value");
        assert_eq!(extractor.extract(&req), None);
    }

    #[test]
    fn test_token_extractor_protocol() {
        let extractor = TokenExtractor::protocol();
        let req = request_with_header("Sec-WebSocket-Protocol", "my-protocol-token");
        assert_eq!(
            extractor.extract(&req),
            Some("my-protocol-token".to_string())
        );
    }

    #[test]
    fn test_token_extractor_missing() {
        let extractor = TokenExtractor::header("Authorization");
        assert_eq!(extractor.extract(&empty_request()), None);
    }

    #[tokio::test]
    async fn test_accept_all_validator() {
        let claims = AcceptAllValidator
            .validate("any-token")
            .await
            .expect("non-empty token should be accepted");
        assert_eq!(claims.subject(), "any-token");
    }

    #[tokio::test]
    async fn test_accept_all_validator_empty() {
        let result = AcceptAllValidator.validate("").await;
        assert!(
            matches!(&result, Err(AuthError::InvalidFormat(_))),
            "got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_reject_all_validator() {
        let result = RejectAllValidator.validate("any-token").await;
        assert!(
            matches!(&result, Err(AuthError::ValidationFailed(_))),
            "got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_static_token_validator() {
        let validator =
            StaticTokenValidator::new().add_token("valid-token", Claims::new("user-123"));
        let claims = validator
            .validate("valid-token")
            .await
            .expect("registered token should be accepted");
        assert_eq!(claims.subject(), "user-123");

        let result = validator.validate("invalid-token").await;
        assert!(
            matches!(&result, Err(AuthError::ValidationFailed(_))),
            "got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_ws_auth_config_required() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(true);
        let result = config.authenticate(&empty_request()).await;
        assert!(
            matches!(&result, Err(AuthError::TokenMissing)),
            "got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_ws_auth_config_optional() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(false);
        let result = config.authenticate(&empty_request()).await;
        assert!(matches!(&result, Ok(None)), "got {:?}", result);
    }

    #[tokio::test]
    async fn test_ws_auth_config_optional_still_rejects_bad_token() {
        // A present-but-invalid token must not fall back to anonymous access.
        let config = WsAuthConfig::new(RejectAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(false);
        let req = request_with_header("Authorization", "Bearer bad-token");
        let result = config.authenticate(&req).await;
        assert!(
            matches!(&result, Err(AuthError::ValidationFailed(_))),
            "got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_ws_auth_config_with_token() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"));
        let req = request_with_header("Authorization", "Bearer my-token");
        let claims = config
            .authenticate(&req)
            .await
            .expect("valid token should authenticate")
            .expect("claims should be present when a token is supplied");
        assert_eq!(claims.subject(), "my-token");
    }

    #[test]
    fn test_ws_auth_config_debug_lists_settings() {
        let config = WsAuthConfig::new(AcceptAllValidator).required(false);
        let rendered = format!("{:?}", config);
        assert!(rendered.starts_with("WsAuthConfig"), "{}", rendered);
        assert!(rendered.contains("required: false"), "{}", rendered);
    }

    #[test]
    fn test_claims_extra() {
        let mut claims = Claims::new("user-123");
        claims.insert("role", "admin");
        claims.insert("tenant", "acme");
        assert_eq!(claims.subject(), "user-123");
        assert_eq!(claims.get("role"), Some("admin"));
        assert_eq!(claims.get("tenant"), Some("acme"));
        assert_eq!(claims.get("missing"), None);
    }

    #[test]
    fn test_auth_error_display() {
        assert_eq!(
            AuthError::TokenMissing.to_string(),
            "Authentication token missing"
        );
        assert_eq!(
            AuthError::validation_failed("custom error").to_string(),
            "Token validation failed: custom error"
        );
        assert_eq!(
            AuthError::invalid_format("bad").to_string(),
            "Invalid token format: bad"
        );
        assert_eq!(
            AuthError::insufficient_permissions("admin").to_string(),
            "Insufficient permissions: admin"
        );
    }

    #[test]
    fn test_token_extractor_default() {
        assert!(matches!(
            TokenExtractor::default(),
            TokenExtractor::Header(ref name) if name == "Authorization"
        ));
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Runs `future` to completion on a fresh single-threaded Tokio runtime.
    ///
    /// A current-thread runtime is far cheaper to build than the default
    /// multi-threaded one, which matters because proptest runs each property
    /// body many times.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build a current-thread Tokio runtime")
            .block_on(future)
    }

    fn token_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-zA-Z0-9._-]{1,100}").unwrap()
    }

    fn header_name_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[A-Za-z][A-Za-z0-9-]{0,30}").unwrap()
    }

    fn query_param_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-z][a-z0-9_]{0,20}").unwrap()
    }

    fn extractor_strategy() -> impl Strategy<Value = TokenExtractor> {
        prop_oneof![
            header_name_strategy().prop_map(TokenExtractor::Header),
            query_param_strategy().prop_map(TokenExtractor::Query),
            Just(TokenExtractor::Protocol),
        ]
    }

    proptest! {
        #[test]
        fn prop_auth_required_rejects_missing_token(
            extractor in extractor_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(extractor)
                    .required(true);
                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .body(())
                    .unwrap();
                let result = config.authenticate(&req).await;
                prop_assert!(
                    matches!(&result, Err(AuthError::TokenMissing)),
                    "got {:?}",
                    result
                );
                Ok(())
            })?;
        }

        #[test]
        fn prop_auth_accepts_valid_token_in_header(
            token in token_strategy(),
            header_name in header_name_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(TokenExtractor::Header(header_name.clone()))
                    .required(true);
                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .header(&header_name, format!("Bearer {}", token))
                    .body(())
                    .unwrap();
                let result = config.authenticate(&req).await;
                prop_assert!(
                    matches!(&result, Ok(Some(claims)) if claims.subject() == token),
                    "expected claims for {:?}, got {:?}",
                    token,
                    result
                );
                Ok(())
            })?;
        }

        #[test]
        fn prop_auth_accepts_valid_token_in_query(
            token in token_strategy(),
            param_name in query_param_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(TokenExtractor::Query(param_name.clone()))
                    .required(true);
                let uri = format!("ws://localhost/ws?{}={}", param_name, token);
                let req = http::Request::builder()
                    .uri(&uri)
                    .body(())
                    .unwrap();
                let result = config.authenticate(&req).await;
                prop_assert!(
                    matches!(&result, Ok(Some(claims)) if claims.subject() == token),
                    "expected claims for {:?}, got {:?}",
                    token,
                    result
                );
                Ok(())
            })?;
        }

        #[test]
        fn prop_auth_rejects_invalid_token(
            token in token_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(RejectAllValidator)
                    .extractor(TokenExtractor::Header("Authorization".to_string()))
                    .required(true);
                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .header("Authorization", format!("Bearer {}", token))
                    .body(())
                    .unwrap();
                let result = config.authenticate(&req).await;
                prop_assert!(
                    matches!(&result, Err(AuthError::ValidationFailed(_))),
                    "got {:?}",
                    result
                );
                Ok(())
            })?;
        }

        #[test]
        fn prop_optional_auth_allows_missing_token(
            extractor in extractor_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(extractor)
                    .required(false);
                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .body(())
                    .unwrap();
                let result = config.authenticate(&req).await;
                prop_assert!(matches!(&result, Ok(None)), "got {:?}", result);
                Ok(())
            })?;
        }

        #[test]
        fn prop_static_validator_only_accepts_known_tokens(
            valid_token in token_strategy(),
            other_token in token_strategy(),
            probe_registered in any::<bool>(),
            user_id in "[a-z]{3,10}"
        ) {
            // Two independently generated tokens almost never collide, so
            // without `probe_registered` the acceptance branch below would
            // practically never execute.
            let probe = if probe_registered { valid_token.clone() } else { other_token };
            block_on(async {
                let validator = StaticTokenValidator::new()
                    .add_token(valid_token.clone(), Claims::new(user_id.clone()));
                let result = validator.validate(&probe).await;
                if probe == valid_token {
                    prop_assert!(
                        matches!(&result, Ok(claims) if claims.subject() == user_id),
                        "expected claims for {:?}, got {:?}",
                        user_id,
                        result
                    );
                } else {
                    prop_assert!(
                        matches!(&result, Err(AuthError::ValidationFailed(_))),
                        "got {:?}",
                        result
                    );
                }
                Ok(())
            })?;
        }
    }
}