use axum::{
body::Body,
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
ConnectInfo, State,
},
http::{header::HeaderMap, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
#[allow(unused_imports)]
use futures_util::StreamExt;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::{debug, error, info, instrument, warn};
use uuid::Uuid;
/// Process-wide replay-protection state consulted by first-message authentication
/// (`handle_unauthenticated_upgrade`). Built lazily on first use via `NonceTracker::new`,
/// so it uses that constructor's 300-second validity window. Nonces are shared across all
/// API keys.
static GLOBAL_NONCE_TRACKER: once_cell::sync::Lazy<NonceTracker> =
once_cell::sync::Lazy::new(NonceTracker::new);
/// Errors produced while authenticating and policing WebSocket connections.
///
/// The `#[error]` strings are client-facing: `IntoResponse` returns them as the HTTP body,
/// except for `Internal`, whose detail is replaced with "Internal error".
#[derive(Debug, Error)]
pub enum WsAuthError {
/// No API key could be extracted from the configured header.
#[error("Missing API key")]
MissingApiKey,
/// The presented key does not match any known key.
#[error("Invalid API key")]
InvalidApiKey,
/// The key matched but its `expires_at` instant has passed.
#[error("API key expired")]
ExpiredApiKey,
/// The key's tier does not permit WebSocket access; carries the tier name.
#[error("Subscription tier '{0}' does not allow WebSocket access")]
TierNotAllowed(String),
/// Registering another connection would exceed the tier cap; carries (tier name, cap).
#[error("Connection limit exceeded for tier '{0}': max {1} connections")]
ConnectionLimitExceeded(String, usize),
/// The connection exceeded its per-window message budget; carries the limit.
#[error("Rate limit exceeded: {0} requests per minute allowed")]
RateLimitExceeded(u32),
/// No authentication frame arrived within the timeout; carries the timeout in seconds.
#[error("Authentication timeout: must authenticate within {0} seconds")]
AuthTimeout(u64),
/// The first frame was not a text frame holding a valid `WsAuthMessage` JSON document.
#[error("Invalid authentication message format")]
InvalidAuthMessage,
/// Nonce/timestamp missing, timestamp outside the validity window, or nonce reused.
#[error("Replay attack detected: nonce already used or timestamp invalid")]
ReplayAttack,
/// Unexpected failure; the detail string is not exposed in HTTP responses.
#[error("Internal authentication error: {0}")]
Internal(String),
}
impl IntoResponse for WsAuthError {
    /// Converts an authentication failure into an HTTP response.
    ///
    /// Credential problems and replay detection map to 401, tier restrictions to 403, limit
    /// violations to 429, auth timeouts to 408, malformed auth messages to 400 and internal
    /// failures to 500. The body is the error's `Display` text, except for `Internal`, whose
    /// detail is replaced by a generic message so diagnostics never reach the client.
    fn into_response(self) -> Response {
        let status = match &self {
            WsAuthError::MissingApiKey
            | WsAuthError::InvalidApiKey
            | WsAuthError::ExpiredApiKey
            | WsAuthError::ReplayAttack => StatusCode::UNAUTHORIZED,
            WsAuthError::TierNotAllowed(_) => StatusCode::FORBIDDEN,
            WsAuthError::ConnectionLimitExceeded(..) | WsAuthError::RateLimitExceeded(_) => {
                StatusCode::TOO_MANY_REQUESTS
            }
            WsAuthError::AuthTimeout(_) => StatusCode::REQUEST_TIMEOUT,
            WsAuthError::InvalidAuthMessage => StatusCode::BAD_REQUEST,
            WsAuthError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
        };
        let body = if matches!(self, WsAuthError::Internal(_)) {
            String::from("Internal error")
        } else {
            self.to_string()
        };
        (status, body).into_response()
    }
}
/// Plan attached to an API key; every per-connection limit in this module derives from it.
///
/// Serialized in lowercase (`"free"`, `"pro"`, `"team"`, `"enterprise"`), the same spelling
/// that `Display` produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SubscriptionTier {
    Free,
    Pro,
    Team,
    Enterprise,
}

impl SubscriptionTier {
    /// Maximum number of concurrently registered connections per API key.
    pub fn max_connections(&self) -> usize {
        match *self {
            Self::Free => 1,
            Self::Pro => 5,
            Self::Team => 25,
            Self::Enterprise => 100,
        }
    }

    /// Messages allowed per connection within one 60-second rate window.
    pub fn rate_limit(&self) -> u32 {
        match *self {
            Self::Free => 60,
            Self::Pro => 300,
            Self::Team => 1000,
            Self::Enterprise => 10000,
        }
    }

    /// Largest accepted inbound message, in bytes.
    pub fn max_message_size(&self) -> usize {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        match *self {
            Self::Free => 64 * KIB,
            Self::Pro => MIB,
            Self::Team => 10 * MIB,
            Self::Enterprise => 100 * MIB,
        }
    }

    /// Session lifetime reported to clients (as `session_timeout_secs`) after authentication.
    pub fn session_timeout(&self) -> Duration {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        let secs = match *self {
            Self::Free => 30 * MINUTE,
            Self::Pro => 2 * HOUR,
            Self::Team => 8 * HOUR,
            Self::Enterprise => 24 * HOUR,
        };
        Duration::from_secs(secs)
    }

    /// Whether keys of this tier may open WebSocket connections; currently every tier may.
    pub fn allows_websocket(&self) -> bool {
        true
    }
}

impl std::fmt::Display for SubscriptionTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SubscriptionTier::Free => "free",
            SubscriptionTier::Pro => "pro",
            SubscriptionTier::Team => "team",
            SubscriptionTier::Enterprise => "enterprise",
        };
        f.write_str(name)
    }
}
/// Metadata resolved from a raw API key by an `ApiKeyValidator`.
#[derive(Debug, Clone)]
pub struct ApiKeyInfo {
/// Stable identifier of the key: used as the per-key connection-count key by
/// `ConnectionTracker::register` and matched by `InMemoryApiKeyValidator::revoke`.
pub key_id: String,
/// Owner of the key; copied into each `ConnectionInfo`.
pub owner_id: String,
/// Subscription tier that determines connection, rate and size limits.
pub tier: SubscriptionTier,
/// Optional expiry; `InMemoryApiKeyValidator` rejects the key once `Instant::now()` is past it.
pub expires_at: Option<Instant>,
/// Free-form metadata; not read by the code in this module.
pub metadata: HashMap<String, String>,
}
/// Pluggable source of truth for API keys.
#[async_trait::async_trait]
pub trait ApiKeyValidator: Send + Sync + 'static {
/// Resolves a raw API key to its metadata, or returns why it is unacceptable
/// (e.g. `InvalidApiKey`, `ExpiredApiKey`).
async fn validate(&self, api_key: &str) -> Result<ApiKeyInfo, WsAuthError>;
/// Revokes all keys with the given `key_id`. The default implementation does nothing
/// and returns `Ok(())`.
async fn revoke(&self, key_id: &str) -> Result<(), WsAuthError> {
let _ = key_id;
Ok(())
}
}
/// API-key store backed by an in-process map from raw key string to [`ApiKeyInfo`].
///
/// Cloning is cheap: every clone shares the same underlying map.
#[derive(Debug, Clone)]
pub struct InMemoryApiKeyValidator {
    keys: Arc<RwLock<HashMap<String, ApiKeyInfo>>>,
}

impl InMemoryApiKeyValidator {
    /// Creates an empty store.
    pub fn new() -> Self {
        let keys = HashMap::new();
        Self {
            keys: Arc::new(RwLock::new(keys)),
        }
    }

    /// Registers `api_key` with `info`, replacing any previous entry for the same key.
    pub fn add_key(&self, api_key: String, info: ApiKeyInfo) {
        let mut keys = self.keys.write();
        keys.insert(api_key, info);
    }

    /// Forgets `api_key`; unknown keys are ignored.
    pub fn remove_key(&self, api_key: &str) {
        let mut keys = self.keys.write();
        keys.remove(api_key);
    }
}

impl Default for InMemoryApiKeyValidator {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait::async_trait]
impl ApiKeyValidator for InMemoryApiKeyValidator {
    /// Finds the stored entry for `api_key` and returns a copy of its metadata.
    ///
    /// Each stored key is compared with `constant_time_compare`; the scan stops at the first
    /// match.
    ///
    /// # Errors
    /// * `InvalidApiKey` if no stored key matches.
    /// * `ExpiredApiKey` if the matching entry's `expires_at` lies in the past.
    async fn validate(&self, api_key: &str) -> Result<ApiKeyInfo, WsAuthError> {
        let keys = self.keys.read();
        let info = keys
            .iter()
            .find(|(stored_key, _)| constant_time_compare(api_key, stored_key))
            .map(|(_, info)| info)
            .ok_or(WsAuthError::InvalidApiKey)?;
        let expired = matches!(info.expires_at, Some(deadline) if Instant::now() > deadline);
        if expired {
            return Err(WsAuthError::ExpiredApiKey);
        }
        Ok(info.clone())
    }

    /// Removes every stored key whose `key_id` equals `key_id`; always succeeds.
    async fn revoke(&self, key_id: &str) -> Result<(), WsAuthError> {
        self.keys.write().retain(|_, info| info.key_id != key_id);
        Ok(())
    }
}
/// Snapshot of one registered WebSocket connection, as stored by `ConnectionTracker`.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
/// Random (v4) id assigned by `ConnectionTracker::register`.
pub connection_id: Uuid,
/// `key_id` of the API key that opened the connection.
pub key_id: String,
/// Owner of that API key.
pub owner_id: String,
/// Tier of that API key; drives rate and size limits.
pub tier: SubscriptionTier,
/// Peer address reported by the server.
pub remote_addr: SocketAddr,
/// When the connection was registered.
pub connected_at: Instant,
/// Refreshed by `ConnectionTracker::check_rate_limit`; compared by `cleanup_stale`.
pub last_activity: Instant,
/// Requests counted in the current fixed 60-second rate window.
pub request_count: u32,
/// Start of the current rate window.
pub rate_window_start: Instant,
}
/// Registry of live connections plus per-`key_id` counts used to enforce tier connection caps.
///
/// `register` acquires `connection_counts` before `connections`; any code holding both locks
/// should use the same order.
#[derive(Debug)]
pub struct ConnectionTracker {
/// All registered connections, keyed by connection id.
connections: RwLock<HashMap<Uuid, ConnectionInfo>>,
/// Number of registered connections per API `key_id`; entries are dropped when they reach zero.
connection_counts: RwLock<HashMap<String, usize>>,
}
impl ConnectionTracker {
    /// Creates a tracker with no registered connections.
    pub fn new() -> Self {
        Self {
            connections: RwLock::new(HashMap::new()),
            connection_counts: RwLock::new(HashMap::new()),
        }
    }

    /// Registers a new connection for `key_info` originating from `remote_addr`.
    ///
    /// The per-key cap check and the increment both happen under the `connection_counts`
    /// write lock, so concurrent registrations for one key cannot both slip past the cap.
    ///
    /// Lock order: `connection_counts`, then `connections`. Every method that holds both locks
    /// must acquire them in this order, otherwise two threads can deadlock.
    ///
    /// # Errors
    /// `ConnectionLimitExceeded` if the key already has `tier.max_connections()` connections.
    pub fn register(
        &self,
        key_info: &ApiKeyInfo,
        remote_addr: SocketAddr,
    ) -> Result<ConnectionInfo, WsAuthError> {
        let mut counts = self.connection_counts.write();
        let current_count = counts.get(&key_info.key_id).copied().unwrap_or(0);
        let max_connections = key_info.tier.max_connections();
        if current_count >= max_connections {
            return Err(WsAuthError::ConnectionLimitExceeded(
                key_info.tier.to_string(),
                max_connections,
            ));
        }
        let now = Instant::now();
        let conn_info = ConnectionInfo {
            connection_id: Uuid::new_v4(),
            key_id: key_info.key_id.clone(),
            owner_id: key_info.owner_id.clone(),
            tier: key_info.tier,
            remote_addr,
            connected_at: now,
            last_activity: now,
            request_count: 0,
            rate_window_start: now,
        };
        *counts.entry(key_info.key_id.clone()).or_insert(0) += 1;
        self.connections
            .write()
            .insert(conn_info.connection_id, conn_info.clone());
        info!(
            connection_id = %conn_info.connection_id,
            key_id = %key_info.key_id,
            tier = %key_info.tier,
            "New WebSocket connection registered"
        );
        Ok(conn_info)
    }

    /// Removes `connection_id` and decrements its key's connection count.
    ///
    /// Unknown ids are ignored, so calling this more than once for the same id is harmless.
    pub fn unregister(&self, connection_id: Uuid) {
        let removed = {
            // Same lock order as `register` (counts, then connections). Taking `connections`
            // first here would invert the order and could deadlock with a concurrent `register`.
            let mut counts = self.connection_counts.write();
            let removed = self.connections.write().remove(&connection_id);
            if let Some(conn_info) = &removed {
                if let Some(count) = counts.get_mut(&conn_info.key_id) {
                    *count = count.saturating_sub(1);
                    if *count == 0 {
                        counts.remove(&conn_info.key_id);
                    }
                }
            }
            removed
        };
        if let Some(conn_info) = removed {
            info!(
                connection_id = %connection_id,
                key_id = %conn_info.key_id,
                "WebSocket connection unregistered"
            );
        }
    }

    /// Charges one request to `connection_id` within a fixed 60-second window and refreshes
    /// its `last_activity`.
    ///
    /// Unknown connection ids are not limited and return `Ok(())`.
    ///
    /// # Errors
    /// `RateLimitExceeded` once the window's count exceeds `tier.rate_limit()`.
    pub fn check_rate_limit(&self, connection_id: Uuid) -> Result<(), WsAuthError> {
        let mut connections = self.connections.write();
        if let Some(conn_info) = connections.get_mut(&connection_id) {
            let now = Instant::now();
            let rate_limit = conn_info.tier.rate_limit();
            if now.duration_since(conn_info.rate_window_start) > Duration::from_secs(60) {
                conn_info.rate_window_start = now;
                conn_info.request_count = 0;
            }
            // Saturate: a client that keeps sending after being limited must not overflow
            // the counter (which would panic in debug builds).
            conn_info.request_count = conn_info.request_count.saturating_add(1);
            conn_info.last_activity = now;
            if conn_info.request_count > rate_limit {
                return Err(WsAuthError::RateLimitExceeded(rate_limit));
            }
        }
        Ok(())
    }

    /// Returns a copy of the connection's info, if registered.
    pub fn get(&self, connection_id: Uuid) -> Option<ConnectionInfo> {
        self.connections.read().get(&connection_id).cloned()
    }

    /// Returns copies of all connections opened with `key_id`.
    pub fn get_by_key(&self, key_id: &str) -> Vec<ConnectionInfo> {
        self.connections
            .read()
            .values()
            .filter(|c| c.key_id == key_id)
            .cloned()
            .collect()
    }

    /// Number of registered connections across all keys.
    pub fn total_connections(&self) -> usize {
        self.connections.read().len()
    }

    /// Number of registered connections for `key_id`.
    pub fn connection_count(&self, key_id: &str) -> usize {
        self.connection_counts
            .read()
            .get(key_id)
            .copied()
            .unwrap_or(0)
    }

    /// Unregisters every connection whose `last_activity` is older than `max_idle`.
    ///
    /// Ids are collected under a read lock that is released before `unregister` takes
    /// write locks.
    pub fn cleanup_stale(&self, max_idle: Duration) {
        let now = Instant::now();
        let mut to_remove = Vec::new();
        {
            let connections = self.connections.read();
            for (id, info) in connections.iter() {
                if now.duration_since(info.last_activity) > max_idle {
                    to_remove.push(*id);
                }
            }
        }
        for id in to_remove {
            self.unregister(id);
            debug!(connection_id = %id, "Cleaned up stale connection");
        }
    }
}

impl Default for ConnectionTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Shared state for the WebSocket authentication handlers and middleware.
pub struct WsAuthState<V: ApiKeyValidator> {
    /// Validator used to resolve API keys.
    pub validator: Arc<V>,
    /// Registry of live connections, shared by all handlers.
    pub tracker: Arc<ConnectionTracker>,
    /// How long first-message authentication waits for the auth frame (default 30 s).
    pub auth_timeout: Duration,
    /// Header carrying the API key (default `Authorization`).
    pub api_key_header: String,
    /// Whether `ws_auth_middleware` requires a secure connection for upgrades (default `false`).
    pub require_tls: bool,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a `V: Clone` bound
// even though only the `Arc<V>` is cloned, needlessly excluding non-`Clone` validators.
impl<V: ApiKeyValidator> Clone for WsAuthState<V> {
    fn clone(&self) -> Self {
        Self {
            validator: Arc::clone(&self.validator),
            tracker: Arc::clone(&self.tracker),
            auth_timeout: self.auth_timeout,
            api_key_header: self.api_key_header.clone(),
            require_tls: self.require_tls,
        }
    }
}

impl<V: ApiKeyValidator> WsAuthState<V> {
    /// Creates state with a fresh tracker, a 30-second auth timeout, the `Authorization`
    /// header and TLS enforcement disabled.
    pub fn new(validator: V) -> Self {
        Self {
            validator: Arc::new(validator),
            tracker: Arc::new(ConnectionTracker::new()),
            auth_timeout: Duration::from_secs(30),
            api_key_header: "Authorization".to_string(),
            require_tls: false,
        }
    }

    /// Sets how long first-message authentication waits for the auth frame.
    pub fn with_auth_timeout(mut self, timeout: Duration) -> Self {
        self.auth_timeout = timeout;
        self
    }

    /// Sets the header the API key is read from.
    pub fn with_api_key_header(mut self, header: impl Into<String>) -> Self {
        self.api_key_header = header.into();
        self
    }

    /// Enables or disables TLS enforcement in `ws_auth_middleware`.
    pub fn with_require_tls(mut self, require: bool) -> Self {
        self.require_tls = require;
        self
    }

    /// Extracts the API key from the configured header.
    ///
    /// Accepts either a bare key or `Bearer <key>`. The `Bearer` scheme is matched
    /// case-insensitively, because HTTP authentication schemes are case-insensitive
    /// (RFC 7235 §2.1), and surrounding whitespace is trimmed.
    ///
    /// Returns `None` in three cases:
    /// * the header is absent;
    /// * the header value is rejected by `HeaderValue::to_str`;
    /// * the resulting key is empty.
    ///
    /// In each case callers report `MissingApiKey` instead of validating an empty string.
    pub fn extract_api_key_from_headers(&self, headers: &HeaderMap) -> Option<String> {
        let raw = headers.get(&self.api_key_header)?.to_str().ok()?.trim();
        let key = match raw.split_once(' ') {
            Some((scheme, credentials)) if scheme.eq_ignore_ascii_case("Bearer") => {
                credentials.trim()
            }
            _ => raw,
        };
        if key.is_empty() {
            None
        } else {
            Some(key.to_string())
        }
    }
}
/// JSON document expected as the first text frame when a client authenticates after the upgrade.
#[derive(Debug, Deserialize)]
pub struct WsAuthMessage {
/// Raw API key, checked with the configured `ApiKeyValidator`.
pub api_key: String,
/// Single-use nonce for replay protection. Optional for deserialization, but
/// `handle_unauthenticated_upgrade` rejects messages that omit it.
#[serde(default)]
pub nonce: Option<String>,
/// Client time in Unix seconds (compared against the server's `SystemTime` in
/// `NonceTracker`); also required by `handle_unauthenticated_upgrade`.
#[serde(default)]
pub timestamp: Option<i64>,
/// Free-form client metadata; not read by the code in this module.
#[serde(default)]
pub client_info: HashMap<String, String>,
}
/// JSON reply describing the outcome of first-message authentication.
/// `None` fields are omitted from the serialized output.
#[derive(Debug, Serialize)]
pub struct WsAuthResult {
pub success: bool,
/// Error text on failure.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Assigned connection id on success.
#[serde(skip_serializing_if = "Option::is_none")]
pub connection_id: Option<String>,
/// Tier name on success.
#[serde(skip_serializing_if = "Option::is_none")]
pub tier: Option<String>,
/// Per-window message budget on success.
#[serde(skip_serializing_if = "Option::is_none")]
pub rate_limit: Option<u32>,
/// Tier session timeout in seconds on success.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_timeout_secs: Option<u64>,
}
/// A WebSocket paired with its registration; `send` charges each outbound message to the
/// connection's rate limit.
pub struct AuthenticatedWsConnection {
pub info: ConnectionInfo,
pub socket: WebSocket,
tracker: Arc<ConnectionTracker>,
}
impl AuthenticatedWsConnection {
pub async fn send(&mut self, msg: Message) -> Result<(), WsAuthError> {
self.tracker.check_rate_limit(self.info.connection_id)?;
self.socket
.send(msg)
.await
.map_err(|e| WsAuthError::Internal(e.to_string()))
}
pub async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
self.socket.recv().await
}
pub fn connection_id(&self) -> Uuid {
self.info.connection_id
}
pub fn tier(&self) -> SubscriptionTier {
self.info.tier
}
}
#[instrument(skip(ws, state))]
pub async fn ws_handler_with_header_auth<V: ApiKeyValidator>(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<WsAuthState<V>>,
headers: HeaderMap,
) -> Result<Response, WsAuthError> {
let api_key = state
.extract_api_key_from_headers(&headers)
.ok_or(WsAuthError::MissingApiKey)?;
let key_info = state.validator.validate(&api_key).await?;
if !key_info.tier.allows_websocket() {
return Err(WsAuthError::TierNotAllowed(key_info.tier.to_string()));
}
let conn_info = state.tracker.register(&key_info, addr)?;
info!(
connection_id = %conn_info.connection_id,
tier = %key_info.tier,
remote_addr = %addr,
"WebSocket connection authenticated via header"
);
let tracker = Arc::clone(&state.tracker);
Ok(ws.on_upgrade(move |socket| async move {
handle_authenticated_socket(socket, conn_info, tracker).await;
}))
}
#[instrument(skip(ws, state))]
pub async fn ws_handler_with_message_auth<V: ApiKeyValidator>(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<WsAuthState<V>>,
headers: HeaderMap,
) -> impl IntoResponse {
if let Some(api_key) = state.extract_api_key_from_headers(&headers) {
match state.validator.validate(&api_key).await {
Ok(key_info) => {
if !key_info.tier.allows_websocket() {
return Err(WsAuthError::TierNotAllowed(key_info.tier.to_string()));
}
match state.tracker.register(&key_info, addr) {
Ok(conn_info) => {
let tracker = Arc::clone(&state.tracker);
return Ok(ws.on_upgrade(move |socket| async move {
handle_authenticated_socket(socket, conn_info, tracker).await;
}));
}
Err(e) => return Err(e),
}
}
Err(_) => {
}
}
}
let validator = Arc::clone(&state.validator);
let tracker = Arc::clone(&state.tracker);
let auth_timeout = state.auth_timeout;
Ok(ws.on_upgrade(move |socket| async move {
handle_unauthenticated_upgrade(socket, addr, validator, tracker, auth_timeout).await;
}))
}
/// Runs the first-message authentication handshake on a socket upgraded without header auth.
///
/// Steps, in order:
/// 1. Wait up to `auth_timeout` for the first frame and parse it as a `WsAuthMessage` JSON text frame.
/// 2. Validate the API key with `validator`.
/// 3. Require both `nonce` and `timestamp`, and consume the nonce in `GLOBAL_NONCE_TRACKER`.
/// 4. Check that the tier allows WebSocket access.
/// 5. Register the connection in `tracker` and reply with a successful `WsAuthResult`.
/// 6. Hand the socket to `handle_authenticated_socket`.
///
/// On any failure, an error result and a Close frame are sent via `send_auth_error`, and the
/// function returns without registering anything. The exception is a socket error or closure
/// during the wait, which is only logged.
async fn handle_unauthenticated_upgrade<V: ApiKeyValidator>(
mut socket: WebSocket,
addr: SocketAddr,
validator: Arc<V>,
tracker: Arc<ConnectionTracker>,
auth_timeout: Duration,
) {
// The timeout bounds only this first receive.
let auth_result = tokio::time::timeout(auth_timeout, socket.recv()).await;
let auth_msg = match auth_result {
Ok(Some(Ok(Message::Text(text)))) => match serde_json::from_str::<WsAuthMessage>(&text) {
Ok(msg) => msg,
Err(e) => {
let _ = send_auth_error(&mut socket, &WsAuthError::InvalidAuthMessage).await;
warn!(error = %e, "Invalid auth message format");
return;
}
},
// Any non-text first frame (binary, ping, pong, close) is treated as an invalid auth message.
Ok(Some(Ok(_))) => {
let _ = send_auth_error(&mut socket, &WsAuthError::InvalidAuthMessage).await;
warn!("First message must be text auth message");
return;
}
Ok(Some(Err(e))) => {
warn!(error = %e, "WebSocket error during auth");
return;
}
Ok(None) => {
warn!("Connection closed before authentication");
return;
}
// `Err` here is the elapsed timeout, not a socket error.
Err(_) => {
let _ = send_auth_error(
&mut socket,
&WsAuthError::AuthTimeout(auth_timeout.as_secs()),
)
.await;
warn!(
timeout_secs = auth_timeout.as_secs(),
"Authentication timeout"
);
return;
}
};
let key_info = match validator.validate(&auth_msg.api_key).await {
Ok(info) => info,
Err(e) => {
let _ = send_auth_error(&mut socket, &e).await;
warn!(error = %e, "API key validation failed");
return;
}
};
// `nonce` and `timestamp` are optional in `WsAuthMessage` but mandatory here.
let (nonce, timestamp) = match (auth_msg.nonce, auth_msg.timestamp) {
(Some(n), Some(t)) => (n, t),
(None, _) => {
let err = WsAuthError::ReplayAttack;
let _ = send_auth_error(&mut socket, &err).await;
warn!("Authentication rejected: missing nonce (replay protection required)");
return;
}
(_, None) => {
let err = WsAuthError::ReplayAttack;
let _ = send_auth_error(&mut socket, &err).await;
warn!("Authentication rejected: missing timestamp (replay protection required)");
return;
}
};
// Runs after key validation, so only holders of a valid key consume nonces. Whatever error
// the tracker returns, the client is told `ReplayAttack`.
if let Err(e) = GLOBAL_NONCE_TRACKER.validate_and_consume(&nonce, timestamp) {
let _ = send_auth_error(&mut socket, &WsAuthError::ReplayAttack).await;
warn!(
nonce = %nonce,
timestamp = %timestamp,
error = %e,
"Replay attack prevention: nonce validation failed"
);
return;
}
debug!(
nonce = %nonce,
timestamp = %timestamp,
"Nonce validated successfully - replay attack protection active"
);
if !key_info.tier.allows_websocket() {
let err = WsAuthError::TierNotAllowed(key_info.tier.to_string());
let _ = send_auth_error(&mut socket, &err).await;
return;
}
// Registration happens only after every check has passed; nothing above needs cleanup.
let conn_info = match tracker.register(&key_info, addr) {
Ok(info) => info,
Err(e) => {
let _ = send_auth_error(&mut socket, &e).await;
return;
}
};
let auth_result = WsAuthResult {
success: true,
error: None,
connection_id: Some(conn_info.connection_id.to_string()),
tier: Some(conn_info.tier.to_string()),
rate_limit: Some(conn_info.tier.rate_limit()),
session_timeout_secs: Some(conn_info.tier.session_timeout().as_secs()),
};
// A failure to deliver the success reply is ignored; the socket loop below (which always
// unregisters on exit) will observe a broken connection.
if let Ok(json) = serde_json::to_string(&auth_result) {
let _ = socket.send(Message::Text(json)).await;
}
info!(
connection_id = %conn_info.connection_id,
tier = %key_info.tier,
remote_addr = %addr,
"WebSocket connection authenticated via first message"
);
handle_authenticated_socket(socket, conn_info, tracker).await;
}
async fn send_auth_error(socket: &mut WebSocket, error: &WsAuthError) -> Result<(), axum::Error> {
let result = WsAuthResult {
success: false,
error: Some(error.to_string()),
connection_id: None,
tier: None,
rate_limit: None,
session_timeout_secs: None,
};
if let Ok(json) = serde_json::to_string(&result) {
socket.send(Message::Text(json)).await?;
}
socket
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: axum::extract::ws::close_code::POLICY,
reason: error.to_string().into(),
})))
.await?;
Ok(())
}
async fn handle_authenticated_socket(
mut socket: WebSocket,
conn_info: ConnectionInfo,
tracker: Arc<ConnectionTracker>,
) {
let connection_id = conn_info.connection_id;
let tier = conn_info.tier;
let (_tx, mut rx) = mpsc::channel::<Message>(100);
let send_task = tokio::spawn({
let tracker = Arc::clone(&tracker);
async move {
while let Some(_msg) = rx.recv().await {
if let Err(e) = tracker.check_rate_limit(connection_id) {
warn!(
connection_id = %connection_id,
error = %e,
"Rate limit exceeded"
);
let _error_msg = serde_json::json!({
"jsonrpc": "2.0",
"error": {
"code": -32000,
"message": e.to_string()
}
});
break;
}
}
}
});
while let Some(msg) = socket.recv().await {
match msg {
Ok(Message::Text(text)) => {
debug!(
connection_id = %connection_id,
msg_len = text.len(),
"Received text message"
);
if text.len() > tier.max_message_size() {
warn!(
connection_id = %connection_id,
size = text.len(),
max = tier.max_message_size(),
"Message size exceeds tier limit"
);
let error_msg = serde_json::json!({
"jsonrpc": "2.0",
"error": {
"code": -32000,
"message": format!("Message size {} exceeds limit {}", text.len(), tier.max_message_size())
}
});
if let Ok(json) = serde_json::to_string(&error_msg) {
let _ = socket.send(Message::Text(json)).await;
}
continue;
}
let _ = socket.send(Message::Text(text)).await;
}
Ok(Message::Binary(data)) => {
debug!(
connection_id = %connection_id,
size = data.len(),
"Received binary message"
);
if data.len() > tier.max_message_size() {
warn!(
connection_id = %connection_id,
size = data.len(),
max = tier.max_message_size(),
"Binary message size exceeds tier limit"
);
continue;
}
let _ = socket.send(Message::Binary(data)).await;
}
Ok(Message::Ping(data)) => {
let _ = socket.send(Message::Pong(data)).await;
}
Ok(Message::Pong(_)) => {
}
Ok(Message::Close(_)) => {
info!(connection_id = %connection_id, "Client initiated close");
break;
}
Err(e) => {
error!(
connection_id = %connection_id,
error = %e,
"WebSocket error"
);
break;
}
}
}
send_task.abort();
tracker.unregister(connection_id);
info!(connection_id = %connection_id, "Connection closed");
}
pub async fn ws_auth_middleware<V: ApiKeyValidator>(
State(state): State<WsAuthState<V>>,
request: Request<Body>,
next: Next,
) -> Result<Response, WsAuthError> {
let is_upgrade = request
.headers()
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
if !is_upgrade {
return Ok(next.run(request).await);
}
if state.require_tls {
let scheme = request.uri().scheme_str().unwrap_or("http");
if scheme != "https" && scheme != "wss" {
warn!("WebSocket connection rejected: TLS required");
return Err(WsAuthError::Internal(
"Secure connection (wss://) required".to_string(),
));
}
}
Ok(next.run(request).await)
}
/// Compares two strings for equality without exiting early on the first differing byte.
///
/// Every byte of `a` is visited regardless of where, or whether, the inputs differ. A length
/// mismatch is folded into the same accumulator instead of causing an early return. So the work
/// done depends on `a.len()` (in this module, the caller-supplied candidate key), not on the
/// contents of `b`.
///
/// `std::hint::black_box` serves as a best-effort barrier against the optimiser turning the final
/// check into a short-circuit. Lengths are not treated as secret: the bounds check on `b` can still
/// reflect its length.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // XOR of two different lengths is non-zero, so any length mismatch forces `false`.
    let mut diff = a.len() ^ b.len();
    for (i, &x) in a.iter().enumerate() {
        // Positions past the end of a shorter `b` compare against 0; the length term above
        // already guarantees the result is `false` in that case.
        let y = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    std::hint::black_box(diff) == 0
}
/// Generates a fresh API key: the prefix `rk_` followed by the hex digits of a random (v4) UUID
/// with its hyphens removed (35 characters in total).
pub fn generate_api_key() -> String {
    let hex: String = Uuid::new_v4()
        .to_string()
        .chars()
        .filter(|&c| c != '-')
        .collect();
    format!("rk_{}", hex)
}
/// Replay-protection state.
///
/// Rejects authentication messages whose timestamp is too far from the server clock, and
/// remembers accepted nonces so that each can be used only once.
#[derive(Debug)]
pub struct NonceTracker {
    /// Accepted nonce → the `Instant` at which it was accepted.
    used_nonces: RwLock<HashMap<String, Instant>>,
    /// Maximum allowed distance, in either direction, between a message timestamp and now.
    validity_window: Duration,
}

impl NonceTracker {
    /// Creates a tracker with a 300-second validity window.
    pub fn new() -> Self {
        Self {
            used_nonces: RwLock::new(HashMap::new()),
            validity_window: Duration::from_secs(300),
        }
    }

    /// Creates a tracker with a custom validity window.
    pub fn with_validity(validity_window: Duration) -> Self {
        Self {
            used_nonces: RwLock::new(HashMap::new()),
            validity_window,
        }
    }

    /// Accepts `nonce` once, provided `timestamp` (Unix seconds) lies within `validity_window`
    /// of the server's wall clock, in the past or the future.
    ///
    /// A given timestamp stays acceptable for a span of `2 * validity_window` around it.
    /// Accepted nonces are therefore remembered for twice the window, plus one second of slack
    /// for whole-second timestamps. Forgetting them sooner would let a future-dated message be
    /// replayed once its nonce had been pruned.
    ///
    /// # Errors
    /// * `ReplayAttack` if the timestamp is outside the window or the nonce was already used.
    /// * `Internal` if the system clock reports a time before the Unix epoch.
    pub fn validate_and_consume(&self, nonce: &str, timestamp: i64) -> Result<(), WsAuthError> {
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|_| {
                WsAuthError::Internal("system clock is set before the Unix epoch".to_string())
            })?
            .as_secs() as i64; // seconds since the epoch fit in i64 for ~292 billion years
        // `abs_diff` is total over i64. `(now - timestamp).abs()` overflows (and panics in debug
        // builds) for client-chosen values such as `i64::MIN`.
        let age_secs = now_secs.abs_diff(timestamp);
        if age_secs > self.validity_window.as_secs() {
            warn!(
                timestamp = %timestamp,
                age_secs = %age_secs,
                validity_window_secs = %self.validity_window.as_secs(),
                "Replay attack detected: timestamp outside validity window"
            );
            return Err(WsAuthError::ReplayAttack);
        }
        let retention = self
            .validity_window
            .saturating_mul(2)
            .saturating_add(Duration::from_secs(1));
        let now = Instant::now();
        let mut used = self.used_nonces.write();
        // `saturating_duration_since` cannot panic. The alternative, `Instant::now() - window`,
        // may panic when the result is not representable (e.g. shortly after boot).
        used.retain(|_, accepted_at| now.saturating_duration_since(*accepted_at) <= retention);
        if used.contains_key(nonce) {
            warn!(nonce = %nonce, "Replay attack detected: nonce already used");
            return Err(WsAuthError::ReplayAttack);
        }
        used.insert(nonce.to_string(), now);
        Ok(())
    }

    /// Builds a client nonce of the form `<unix-nanos>:<random uuid>`.
    ///
    /// If the clock reads before the Unix epoch, the time part is `0`; the UUID alone still makes
    /// the nonce unique.
    pub fn generate_nonce() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        let random = Uuid::new_v4();
        format!("{}:{}", timestamp, random)
    }
}

impl Default for NonceTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Heuristic SSRF guard for outbound URLs.
///
/// Returns `Ok(true)` only when all of the following hold:
/// * the URL uses `http`/`https`;
/// * it has a host;
/// * the host is not a localhost alias;
/// * the host is not a non-public IP literal;
/// * the host is not an internal-looking domain or a known cloud-metadata host.
///
/// `Err(Internal)` is returned when the URL does not parse.
///
/// Hosts are judged from the parsed `url::Host`, not by re-parsing `host_str`. For IPv6 literals
/// `host_str` is bracketed (`[fd00::1]`), which `IpAddr::from_str` rejects, so such addresses
/// would otherwise skip the IP check. IPv4-mapped IPv6 literals (`[::ffff:10.0.0.1]`) are judged
/// as the IPv4 address they reach.
///
/// Domain names are compared case-insensitively with trailing dots removed, so `localhost.`
/// cannot sidestep the name checks.
///
/// Limitation: hostnames are not resolved. A public-looking name that resolves to a private
/// address (including via DNS rebinding) is not caught here. Callers should re-check the address
/// they actually connect to.
pub fn is_url_safe(url: &str) -> Result<bool, WsAuthError> {
    use std::net::IpAddr;
    let parsed =
        url::Url::parse(url).map_err(|_| WsAuthError::Internal("Invalid URL".to_string()))?;
    match parsed.scheme() {
        "http" | "https" => {}
        scheme => {
            warn!(scheme = %scheme, "SSRF: Blocked scheme");
            return Ok(false);
        }
    }
    let (host, parsed_host) = match (parsed.host_str(), parsed.host()) {
        (Some(text), Some(parsed_host)) => (text, parsed_host),
        _ => return Ok(false),
    };
    let name = host.trim_end_matches('.').to_ascii_lowercase();

    let localhost_variants = ["localhost", "127.0.0.1", "::1", "[::1]", "0.0.0.0", "0"];
    if localhost_variants.contains(&name.as_str()) {
        warn!(host = %host, "SSRF: Blocked localhost");
        return Ok(false);
    }

    let ip = match parsed_host {
        url::Host::Ipv4(v4) => Some(IpAddr::V4(v4)),
        url::Host::Ipv6(v6) => Some(match v6.to_ipv4_mapped() {
            Some(v4) => IpAddr::V4(v4),
            None => IpAddr::V6(v6),
        }),
        url::Host::Domain(_) => None,
    };
    if let Some(ip) = ip {
        if !is_public_ip(&ip) {
            warn!(ip = %ip, "SSRF: Blocked private/reserved IP");
            return Ok(false);
        }
    }

    let blocked_suffixes = [
        ".internal",
        ".local",
        ".localhost",
        ".lan",
        ".corp",
        ".home",
    ];
    if blocked_suffixes.iter().any(|&s| name.ends_with(s)) {
        warn!(host = %host, "SSRF: Blocked internal domain");
        return Ok(false);
    }

    let blocked_hosts = ["169.254.169.254", "metadata.google.internal", "metadata"];
    if blocked_hosts.contains(&name.as_str()) {
        warn!(host = %host, "SSRF: Blocked cloud metadata endpoint");
        return Ok(false);
    }
    Ok(true)
}
/// Returns `true` only for addresses that are plausibly reachable on the public internet.
///
/// IPv4 addresses are rejected if they fall in any of these ranges:
/// * private (RFC 1918) and loopback;
/// * link-local, which includes the 169.254.169.254 cloud-metadata address;
/// * broadcast, documentation, and `0.0.0.0/8` (including the unspecified address);
/// * carrier-grade NAT `100.64.0.0/10`;
/// * IETF protocol assignments `192.0.0.0/24`;
/// * benchmarking `198.18.0.0/15`;
/// * multicast `224.0.0.0/4` and reserved `240.0.0.0/4`.
///
/// IPv6 addresses are rejected if they are loopback, unspecified, multicast `ff00::/8`,
/// link-local `fe80::/10`, unique-local `fc00::/7`, or documentation `2001:db8::/32`.
/// IPv4-mapped addresses (`::ffff:a.b.c.d`) reach the embedded IPv4 host, so they are judged by
/// that IPv4 address.
fn is_public_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(ipv4) => {
            let [a, b, c, _] = ipv4.octets();
            !(ipv4.is_private()
                || ipv4.is_loopback()
                || ipv4.is_link_local()
                || ipv4.is_broadcast()
                || ipv4.is_documentation()
                || ipv4.is_unspecified()
                || ipv4.is_multicast()
                || a == 0
                || (a == 100 && (64..=127).contains(&b))
                || (a == 192 && b == 0 && c == 0)
                || (a == 198 && (b == 18 || b == 19))
                || a >= 240)
        }
        std::net::IpAddr::V6(ipv6) => {
            if let Some(ipv4) = ipv6.to_ipv4_mapped() {
                return is_public_ip(&std::net::IpAddr::V4(ipv4));
            }
            let segments = ipv6.segments();
            !(ipv6.is_loopback()
                || ipv6.is_unspecified()
                || ipv6.is_multicast()
                || (segments[0] & 0xffc0) == 0xfe80
                || (segments[0] & 0xfe00) == 0xfc00
                || (segments[0] == 0x2001 && segments[1] == 0x0db8))
        }
    }
}
/// Describes an in-progress API key rotation from `old_key_id` to `new_key_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRotationInfo {
/// Key id being retired.
pub old_key_id: String,
/// Key id replacing it.
pub new_key_id: String,
/// Wall-clock time until which `validate_with_rotation_grace` still accepts `old_key_id`.
pub grace_period_end: std::time::SystemTime,
/// NOTE(review): not read or updated by the code in this module.
pub old_key_used: bool,
}
/// Key metadata together with optional rotation state.
#[derive(Debug, Clone)]
pub struct ApiKeyInfoWithRotation {
/// Metadata of the current key.
pub info: ApiKeyInfo,
/// Present while a rotation is in progress.
pub rotation: Option<KeyRotationInfo>,
/// NOTE(review): not consulted by `validate_with_rotation_grace`; confirm its intended meaning.
pub accepts_rotated_key: bool,
}
pub fn validate_with_rotation_grace(
provided_key_id: &str,
current_key: &ApiKeyInfoWithRotation,
) -> bool {
if provided_key_id == current_key.info.key_id {
return true;
}
if let Some(ref rotation) = current_key.rotation {
if provided_key_id == rotation.old_key_id {
if std::time::SystemTime::now() < rotation.grace_period_end {
debug!(
old_key = %rotation.old_key_id,
new_key = %rotation.new_key_id,
"Accepting old key during rotation grace period"
);
return true;
}
}
}
false
}
// Unit tests for tier limits, key comparison/generation, the in-memory validator,
// connection tracking, rate limiting and header key extraction.
#[cfg(test)]
mod tests {
use super::*;
// Pins the per-tier connection caps.
#[test]
fn test_subscription_tier_limits() {
assert_eq!(SubscriptionTier::Free.max_connections(), 1);
assert_eq!(SubscriptionTier::Pro.max_connections(), 5);
assert_eq!(SubscriptionTier::Team.max_connections(), 25);
assert_eq!(SubscriptionTier::Enterprise.max_connections(), 100);
}
// Pins the per-tier message budgets per 60-second window.
#[test]
fn test_subscription_tier_rate_limits() {
assert_eq!(SubscriptionTier::Free.rate_limit(), 60);
assert_eq!(SubscriptionTier::Pro.rate_limit(), 300);
assert_eq!(SubscriptionTier::Team.rate_limit(), 1000);
assert_eq!(SubscriptionTier::Enterprise.rate_limit(), 10000);
}
// Equality, case sensitivity, and unequal lengths (including the empty string).
#[test]
fn test_constant_time_compare() {
assert!(constant_time_compare("secret", "secret"));
assert!(!constant_time_compare("secret", "Secret"));
assert!(!constant_time_compare("short", "longer"));
assert!(!constant_time_compare("", "nonempty"));
}
// Keys are the `rk_` prefix (3 chars) plus 32 hex digits.
#[test]
fn test_generate_api_key() {
let key = generate_api_key();
assert!(key.starts_with("rk_"));
assert_eq!(key.len(), 35); }
// A stored key validates to its metadata; an unknown key is `InvalidApiKey`.
#[tokio::test]
async fn test_in_memory_validator() {
let validator = InMemoryApiKeyValidator::new();
let info = ApiKeyInfo {
key_id: "key_123".to_string(),
owner_id: "user_456".to_string(),
tier: SubscriptionTier::Pro,
expires_at: None,
metadata: HashMap::new(),
};
validator.add_key("test_api_key".to_string(), info.clone());
let result = validator.validate("test_api_key").await;
assert!(result.is_ok());
let validated = result.unwrap();
assert_eq!(validated.tier, SubscriptionTier::Pro);
let result = validator.validate("wrong_key").await;
assert!(matches!(result, Err(WsAuthError::InvalidApiKey)));
}
// Free tier allows one connection per key: a second registration hits the cap, and
// unregistering the first frees the slot again.
#[test]
fn test_connection_tracker() {
let tracker = ConnectionTracker::new();
let key_info = ApiKeyInfo {
key_id: "key_123".to_string(),
owner_id: "user_456".to_string(),
tier: SubscriptionTier::Free, expires_at: None,
metadata: HashMap::new(),
};
let addr: SocketAddr = "127.0.0.1:9100".parse().unwrap();
let conn1 = tracker.register(&key_info, addr);
assert!(conn1.is_ok());
let conn2 = tracker.register(&key_info, addr);
assert!(matches!(
conn2,
Err(WsAuthError::ConnectionLimitExceeded(_, 1))
));
tracker.unregister(conn1.unwrap().connection_id);
let conn3 = tracker.register(&key_info, addr);
assert!(conn3.is_ok());
}
// The Free tier's 60 requests in one window succeed; the 61st is rejected.
#[test]
fn test_rate_limiting() {
let tracker = ConnectionTracker::new();
let key_info = ApiKeyInfo {
key_id: "key_123".to_string(),
owner_id: "user_456".to_string(),
tier: SubscriptionTier::Free, expires_at: None,
metadata: HashMap::new(),
};
let addr: SocketAddr = "127.0.0.1:9100".parse().unwrap();
let conn = tracker.register(&key_info, addr).unwrap();
for _ in 0..60 {
assert!(tracker.check_rate_limit(conn.connection_id).is_ok());
}
assert!(matches!(
tracker.check_rate_limit(conn.connection_id),
Err(WsAuthError::RateLimitExceeded(60))
));
}
// `Bearer ` prefix is stripped, bare keys pass through, and a custom header name is honoured.
#[test]
fn test_api_key_extraction() {
let validator = InMemoryApiKeyValidator::new();
let state = WsAuthState::new(validator);
let mut headers = HeaderMap::new();
headers.insert("Authorization", "Bearer my_api_key".parse().unwrap());
assert_eq!(
state.extract_api_key_from_headers(&headers),
Some("my_api_key".to_string())
);
headers.insert("Authorization", "raw_api_key".parse().unwrap());
assert_eq!(
state.extract_api_key_from_headers(&headers),
Some("raw_api_key".to_string())
);
let state = state.with_api_key_header("X-Api-Key");
headers.insert("X-Api-Key", "custom_key".parse().unwrap());
assert_eq!(
state.extract_api_key_from_headers(&headers),
Some("custom_key".to_string())
);
}
}