use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use async_trait::async_trait;
use axum::Router;
use axum::body::Body;
use axum::http::{Request, StatusCode, Version, header};
use camel_api::security_policy::{AuthorizationDecision, Principal, SecurityPolicy};
use camel_api::{CamelError, Exchange};
use camel_auth::TokenAuthenticator;
use camel_component_api::{ExchangeEnvelope, SecurityContext};
use camel_component_ws::{WsAppState, WsPathConfig, dispatch_handler};
use serde_json::json;
use tokio::sync::{RwLock, mpsc};
use tower::ServiceExt;
/// Test double for [`TokenAuthenticator`] whose verdict is fixed up front.
struct MockAuthenticator {
    /// `true` accepts every bearer token; `false` rejects every bearer token.
    should_succeed: bool,
}

#[async_trait]
impl TokenAuthenticator for MockAuthenticator {
    /// Ignores the presented token entirely.
    ///
    /// Returns the shared test principal when `should_succeed` is set, and a
    /// `CamelError::Unauthenticated` error otherwise.
    async fn authenticate_bearer(&self, _token: &str) -> Result<Principal, CamelError> {
        // Guard clause: the rejection path is decided purely by configuration.
        if !self.should_succeed {
            return Err(CamelError::Unauthenticated("invalid token".into()));
        }
        Ok(test_principal())
    }
}
/// Security policy stub that authorizes every exchange.
struct AlwaysGrantPolicy;

#[async_trait]
impl SecurityPolicy for AlwaysGrantPolicy {
    /// Never inspects the exchange; always yields `Granted` carrying the
    /// shared test principal.
    async fn evaluate(
        &self,
        _exchange: &mut Exchange,
    ) -> Result<AuthorizationDecision, CamelError> {
        let principal = test_principal();
        Ok(AuthorizationDecision::Granted { principal })
    }
}
/// Security policy stub that rejects every exchange.
struct AlwaysDenyPolicy;

#[async_trait]
impl SecurityPolicy for AlwaysDenyPolicy {
    /// Never inspects the exchange; always yields `Denied`, reporting that the
    /// `admin` role was required and no roles were present.
    async fn evaluate(
        &self,
        _exchange: &mut Exchange,
    ) -> Result<AuthorizationDecision, CamelError> {
        let decision = AuthorizationDecision::Denied {
            reason: "no roles assigned".into(),
            required: vec!["admin".into()],
            actual: Vec::new(),
        };
        Ok(decision)
    }
}
struct FailPolicy;
#[async_trait]
impl SecurityPolicy for FailPolicy {
async fn evaluate(
&self,
_exchange: &mut Exchange,
) -> Result<AuthorizationDecision, CamelError> {
Err(CamelError::ProcessorError("policy error".into()))
}
}
/// Builds the fixed principal shared by the mock authenticator and the
/// always-grant policy.
///
/// The subject string appears both in the `subject` field and as the `sub`
/// entry of the JSON claims.
fn test_principal() -> Principal {
    const SUBJECT: &str = "test-user";

    Principal {
        subject: SUBJECT.into(),
        issuer: "test-issuer".into(),
        audience: vec!["api".into()],
        scopes: vec!["read".into(), "write".into()],
        roles: vec!["admin".into()],
        claims: json!({ "sub": SUBJECT }),
    }
}
/// Bundles a policy and an authenticator into a [`SecurityContext`].
///
/// The authenticator is moved behind an `Arc` before being handed to
/// `SecurityContext::new`; the policy is passed through by value.
fn test_security_context(
    authenticator: impl TokenAuthenticator + 'static,
    policy: impl SecurityPolicy + 'static,
) -> SecurityContext {
    let shared_authenticator = Arc::new(authenticator);
    SecurityContext::new(policy, shared_authenticator)
}
/// Assembles a [`WsAppState`] that knows about exactly one route.
///
/// * `path` – the route registered in the dispatch table and the per-path
///   configuration map.
/// * `sec_ctx` – when `Some`, stored as the security context for `path`;
///   when `None`, the policy map is left empty.
///
/// The receiving half of the dispatch channel is not kept beyond this
/// function, so nothing consumes envelopes sent to the route.
fn make_app_state(path: &str, sec_ctx: Option<SecurityContext>) -> WsAppState {
    use std::time::Duration;

    let (tx, _rx) = mpsc::channel::<ExchangeEnvelope>(64);

    // Dispatch table: a single sender keyed by the route.
    let mut senders: HashMap<String, mpsc::Sender<ExchangeEnvelope>> = HashMap::new();
    senders.insert(path.to_string(), tx);
    let dispatch = Arc::new(RwLock::new(senders));

    // Per-path settings: open origin, zero heartbeat interval and idle timeout.
    let configs = dashmap::DashMap::new();
    configs.insert(
        path.to_string(),
        WsPathConfig {
            max_connections: 100,
            max_message_size: 65536,
            heartbeat_interval: Duration::ZERO,
            idle_timeout: Duration::ZERO,
            allow_origin: "*".into(),
        },
    );
    let path_configs = Arc::new(configs);

    // Security contexts: populated only when the caller supplied one.
    let policies = dashmap::DashMap::new();
    if let Some(ctx) = sec_ctx {
        policies.insert(path.to_string(), ctx);
    }
    let path_policies = Arc::new(policies);

    let server_error = Arc::new(AtomicBool::new(false));

    WsAppState {
        dispatch,
        path_configs,
        path_policies,
        server_error,
    }
}
/// Returns a port number the OS reported as available on the loopback
/// interface.
///
/// Binds a listener to port 0 so the OS picks the port, reads it back, and
/// drops the listener on return. The port is therefore released again and
/// nothing prevents another socket from claiming it afterwards.
///
/// # Panics
///
/// Panics if binding to `127.0.0.1:0` or reading the local address fails.
fn free_port() -> u16 {
    let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
    let bound_addr = probe.local_addr().unwrap();
    bound_addr.port()
}
/// Builds an HTTP/1.1 `GET` request carrying WebSocket handshake headers.
///
/// * `port`, `path` – combined into the request URI
///   `http://127.0.0.1:<port><path>`.
/// * `auth_header` – when `Some`, sent verbatim as the `Authorization`
///   header; when `None`, no `Authorization` header is added.
///
/// The `Sec-WebSocket-Key` value is the base64 encoding of
/// "the sample nonce". The request body is empty.
///
/// # Panics
///
/// Panics if the builder rejects any component of the request.
fn ws_upgrade_request(port: u16, path: &str, auth_header: Option<&str>) -> Request<Body> {
    let uri = format!("http://127.0.0.1:{}{}", port, path);

    let handshake = Request::builder()
        .method("GET")
        .uri(uri)
        .version(Version::HTTP_11)
        .header(header::UPGRADE, "websocket")
        .header(header::CONNECTION, "Upgrade")
        .header(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")
        .header(header::SEC_WEBSOCKET_VERSION, "13");

    // Credentials are optional so callers can exercise the anonymous path.
    let handshake = match auth_header {
        Some(auth) => handshake.header(header::AUTHORIZATION, auth),
        None => handshake,
    };

    handshake.body(Body::empty()).unwrap()
}
/// An upgrade request with no `Authorization` header against a path that has
/// a security context must be answered with 401.
#[tokio::test]
async fn test_ws_401_without_auth() {
    let path = "/ws/auth";
    let authenticator = MockAuthenticator {
        should_succeed: true,
    };
    let security = test_security_context(authenticator, AlwaysGrantPolicy);
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, Some(security)));

    // No credentials are attached to the handshake.
    let request = ws_upgrade_request(free_port(), path, None);
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A bearer token that the authenticator rejects must be answered with 401,
/// even though the policy would grant access.
#[tokio::test]
async fn test_ws_401_invalid_token() {
    let path = "/ws/auth";
    let rejecting_authenticator = MockAuthenticator {
        should_succeed: false,
    };
    let security = test_security_context(rejecting_authenticator, AlwaysGrantPolicy);
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, Some(security)));

    let request = ws_upgrade_request(free_port(), path, Some("Bearer invalid-token"));
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A token accepted by the authenticator but denied by the policy must be
/// answered with 403.
#[tokio::test]
async fn test_ws_403_policy_deny() {
    let path = "/ws/auth";
    let authenticator = MockAuthenticator {
        should_succeed: true,
    };
    let security = test_security_context(authenticator, AlwaysDenyPolicy);
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, Some(security)));

    let request = ws_upgrade_request(free_port(), path, Some("Bearer valid-token"));
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
/// With an accepting authenticator and a granting policy, the handler must
/// not reject the request on auth grounds.
///
/// Only the absence of 401 and 403 is checked; the test does not pin the
/// exact status the handler returns once auth has passed.
#[tokio::test]
async fn test_ws_auth_passes_with_valid_token() {
    let path = "/ws/auth";
    let authenticator = MockAuthenticator {
        should_succeed: true,
    };
    let security = test_security_context(authenticator, AlwaysGrantPolicy);
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, Some(security)));

    let request = ws_upgrade_request(free_port(), path, Some("Bearer valid-token"));
    let status = app.oneshot(request).await.unwrap().status();

    assert_ne!(status, StatusCode::UNAUTHORIZED);
    assert_ne!(status, StatusCode::FORBIDDEN);
}
/// A path registered without any security context must not produce an auth
/// rejection, even when the request carries no credentials.
///
/// Only the absence of 401 and 403 is checked.
#[tokio::test]
async fn test_ws_unprotected_path_skips_auth() {
    let path = "/ws/open";
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, None));

    let request = ws_upgrade_request(free_port(), path, None);
    let status = app.oneshot(request).await.unwrap().status();

    assert_ne!(status, StatusCode::UNAUTHORIZED);
    assert_ne!(status, StatusCode::FORBIDDEN);
}
/// When policy evaluation itself returns an error (rather than a decision),
/// the handler must answer with 500.
#[tokio::test]
async fn test_ws_500_policy_error() {
    let path = "/ws/auth";
    let authenticator = MockAuthenticator {
        should_succeed: true,
    };
    let security = test_security_context(authenticator, FailPolicy);
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, Some(security)));

    let request = ws_upgrade_request(free_port(), path, Some("Bearer valid-token"));
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
/// An `Authorization` header using a non-bearer scheme (`Basic`) must be
/// answered with 401, even though the mock authenticator would accept any
/// bearer token.
#[tokio::test]
async fn test_ws_401_malformed_auth_scheme() {
    let path = "/ws/auth";
    let authenticator = MockAuthenticator {
        should_succeed: true,
    };
    let security = test_security_context(authenticator, AlwaysGrantPolicy);
    let app = Router::new()
        .fallback(dispatch_handler)
        .with_state(make_app_state(path, Some(security)));

    let request = ws_upgrade_request(free_port(), path, Some("Basic abc123"));
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}