reasonkit/mcp/
ws_auth.rs

1//! WebSocket Authentication Middleware for MCP Server
2//!
3//! This module implements secure WebSocket authentication with:
4//! - API key validation via header or first message
5//! - Subscription tier enforcement (Free, Pro, Team, Enterprise)
6//! - Active connection tracking per API key
7//! - Rate limiting based on tier
8//! - Secure connection upgrade handling
9//!
10//! # Security Features
11//!
12//! - Constant-time API key comparison to prevent timing attacks
13//! - Connection limits per tier to prevent resource exhaustion
14//! - Automatic connection cleanup on disconnect
15//! - Optional TLS enforcement (wss:// only) via `WsAuthState::with_require_tls`; off by default
16//!
17//! # Usage
18//!
19//! ```rust,ignore
20//! use reasonkit::mcp::ws_auth::{WsAuthLayer, WsAuthState, SubscriptionTier};
21//!
22//! let auth_state = WsAuthState::new(api_key_validator);
23//! let app = Router::new()
24//!     .route("/ws", get(ws_handler))
25//!     .layer(WsAuthLayer::new(auth_state));
26//! ```
27
28use axum::{
29    body::Body,
30    extract::{
31        ws::{Message, WebSocket, WebSocketUpgrade},
32        ConnectInfo, State,
33    },
34    http::{header::HeaderMap, Request, StatusCode},
35    middleware::Next,
36    response::{IntoResponse, Response},
37};
38// Note: StreamExt is used indirectly via WebSocket stream operations
39#[allow(unused_imports)]
40use futures_util::StreamExt;
41use parking_lot::RwLock;
42use serde::{Deserialize, Serialize};
43use std::{
44    collections::HashMap,
45    net::SocketAddr,
46    sync::Arc,
47    time::{Duration, Instant},
48};
49use thiserror::Error;
50use tokio::sync::mpsc;
51use tracing::{debug, error, info, instrument, warn};
52use uuid::Uuid;
53
54// ============================================================================
55// Error Types
56// ============================================================================
57
58/// Authentication and connection errors
59#[derive(Debug, Error)]
60pub enum WsAuthError {
61    #[error("Missing API key")]
62    MissingApiKey,
63
64    #[error("Invalid API key")]
65    InvalidApiKey,
66
67    #[error("API key expired")]
68    ExpiredApiKey,
69
70    #[error("Subscription tier '{0}' does not allow WebSocket access")]
71    TierNotAllowed(String),
72
73    #[error("Connection limit exceeded for tier '{0}': max {1} connections")]
74    ConnectionLimitExceeded(String, usize),
75
76    #[error("Rate limit exceeded: {0} requests per minute allowed")]
77    RateLimitExceeded(u32),
78
79    #[error("Authentication timeout: must authenticate within {0} seconds")]
80    AuthTimeout(u64),
81
82    #[error("Invalid authentication message format")]
83    InvalidAuthMessage,
84
85    #[error("Internal authentication error: {0}")]
86    Internal(String),
87}
88
89impl IntoResponse for WsAuthError {
90    fn into_response(self) -> Response {
91        let (status, message) = match &self {
92            WsAuthError::MissingApiKey => (StatusCode::UNAUTHORIZED, self.to_string()),
93            WsAuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, self.to_string()),
94            WsAuthError::ExpiredApiKey => (StatusCode::UNAUTHORIZED, self.to_string()),
95            WsAuthError::TierNotAllowed(_) => (StatusCode::FORBIDDEN, self.to_string()),
96            WsAuthError::ConnectionLimitExceeded(_, _) => {
97                (StatusCode::TOO_MANY_REQUESTS, self.to_string())
98            }
99            WsAuthError::RateLimitExceeded(_) => (StatusCode::TOO_MANY_REQUESTS, self.to_string()),
100            WsAuthError::AuthTimeout(_) => (StatusCode::REQUEST_TIMEOUT, self.to_string()),
101            WsAuthError::InvalidAuthMessage => (StatusCode::BAD_REQUEST, self.to_string()),
102            WsAuthError::Internal(_) => (
103                StatusCode::INTERNAL_SERVER_ERROR,
104                "Internal error".to_string(),
105            ),
106        };
107
108        (status, message).into_response()
109    }
110}
111
112// ============================================================================
113// Subscription Tiers
114// ============================================================================
115
116/// Subscription tier levels with associated limits
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
118#[serde(rename_all = "lowercase")]
119pub enum SubscriptionTier {
120    /// Free tier: Limited features
121    Free,
122    /// Pro tier: Enhanced limits
123    Pro,
124    /// Team tier: Collaborative features
125    Team,
126    /// Enterprise tier: Unlimited access
127    Enterprise,
128}
129
130impl SubscriptionTier {
131    /// Maximum concurrent WebSocket connections allowed
132    pub fn max_connections(&self) -> usize {
133        match self {
134            SubscriptionTier::Free => 1,
135            SubscriptionTier::Pro => 5,
136            SubscriptionTier::Team => 25,
137            SubscriptionTier::Enterprise => 100,
138        }
139    }
140
141    /// Maximum requests per minute
142    pub fn rate_limit(&self) -> u32 {
143        match self {
144            SubscriptionTier::Free => 60,
145            SubscriptionTier::Pro => 300,
146            SubscriptionTier::Team => 1000,
147            SubscriptionTier::Enterprise => 10000,
148        }
149    }
150
151    /// Maximum message size in bytes
152    pub fn max_message_size(&self) -> usize {
153        match self {
154            SubscriptionTier::Free => 64 * 1024,               // 64 KB
155            SubscriptionTier::Pro => 1024 * 1024,              // 1 MB
156            SubscriptionTier::Team => 10 * 1024 * 1024,        // 10 MB
157            SubscriptionTier::Enterprise => 100 * 1024 * 1024, // 100 MB
158        }
159    }
160
161    /// Session timeout duration
162    pub fn session_timeout(&self) -> Duration {
163        match self {
164            SubscriptionTier::Free => Duration::from_secs(30 * 60), // 30 min
165            SubscriptionTier::Pro => Duration::from_secs(2 * 60 * 60), // 2 hours
166            SubscriptionTier::Team => Duration::from_secs(8 * 60 * 60), // 8 hours
167            SubscriptionTier::Enterprise => Duration::from_secs(24 * 60 * 60), // 24 hours
168        }
169    }
170
171    /// Whether WebSocket access is allowed
172    pub fn allows_websocket(&self) -> bool {
173        // All tiers allow WebSocket, but with different limits
174        true
175    }
176}
177
178impl std::fmt::Display for SubscriptionTier {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match self {
181            SubscriptionTier::Free => write!(f, "free"),
182            SubscriptionTier::Pro => write!(f, "pro"),
183            SubscriptionTier::Team => write!(f, "team"),
184            SubscriptionTier::Enterprise => write!(f, "enterprise"),
185        }
186    }
187}
188
189// ============================================================================
190// API Key Types
191// ============================================================================
192
193/// Validated API key information
194#[derive(Debug, Clone)]
195pub struct ApiKeyInfo {
196    /// Unique API key identifier (hashed or prefix)
197    pub key_id: String,
198    /// User or organization identifier
199    pub owner_id: String,
200    /// Subscription tier
201    pub tier: SubscriptionTier,
202    /// Key expiration timestamp (None = never expires)
203    pub expires_at: Option<Instant>,
204    /// Custom metadata
205    pub metadata: HashMap<String, String>,
206}
207
208/// Trait for validating API keys
209/// Implement this for your storage backend (database, Redis, etc.)
210#[async_trait::async_trait]
211pub trait ApiKeyValidator: Send + Sync + 'static {
212    /// Validate an API key and return its info if valid
213    async fn validate(&self, api_key: &str) -> Result<ApiKeyInfo, WsAuthError>;
214
215    /// Revoke an API key (optional)
216    async fn revoke(&self, key_id: &str) -> Result<(), WsAuthError> {
217        let _ = key_id;
218        Ok(())
219    }
220}
221
222/// In-memory API key validator for development/testing
223#[derive(Debug, Clone)]
224pub struct InMemoryApiKeyValidator {
225    keys: Arc<RwLock<HashMap<String, ApiKeyInfo>>>,
226}
227
228impl InMemoryApiKeyValidator {
229    pub fn new() -> Self {
230        Self {
231            keys: Arc::new(RwLock::new(HashMap::new())),
232        }
233    }
234
235    /// Add a new API key
236    pub fn add_key(&self, api_key: String, info: ApiKeyInfo) {
237        self.keys.write().insert(api_key, info);
238    }
239
240    /// Remove an API key
241    pub fn remove_key(&self, api_key: &str) {
242        self.keys.write().remove(api_key);
243    }
244}
245
246impl Default for InMemoryApiKeyValidator {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
252#[async_trait::async_trait]
253impl ApiKeyValidator for InMemoryApiKeyValidator {
254    async fn validate(&self, api_key: &str) -> Result<ApiKeyInfo, WsAuthError> {
255        let keys = self.keys.read();
256
257        // Constant-time comparison for all keys to prevent timing attacks
258        let mut found_info: Option<&ApiKeyInfo> = None;
259        for (stored_key, info) in keys.iter() {
260            if constant_time_compare(api_key, stored_key) {
261                found_info = Some(info);
262                break;
263            }
264        }
265
266        match found_info {
267            Some(info) => {
268                // Check expiration
269                if let Some(expires_at) = info.expires_at {
270                    if Instant::now() > expires_at {
271                        return Err(WsAuthError::ExpiredApiKey);
272                    }
273                }
274                Ok(info.clone())
275            }
276            None => Err(WsAuthError::InvalidApiKey),
277        }
278    }
279
280    async fn revoke(&self, key_id: &str) -> Result<(), WsAuthError> {
281        let mut keys = self.keys.write();
282        keys.retain(|_, v| v.key_id != key_id);
283        Ok(())
284    }
285}
286
287// ============================================================================
288// Connection Tracking
289// ============================================================================
290
291/// Information about an active WebSocket connection
292#[derive(Debug, Clone)]
293pub struct ConnectionInfo {
294    /// Unique connection ID
295    pub connection_id: Uuid,
296    /// API key ID (for grouping connections)
297    pub key_id: String,
298    /// Owner ID
299    pub owner_id: String,
300    /// Subscription tier
301    pub tier: SubscriptionTier,
302    /// Remote address
303    pub remote_addr: SocketAddr,
304    /// Connection established time
305    pub connected_at: Instant,
306    /// Last activity time
307    pub last_activity: Instant,
308    /// Request count for rate limiting
309    pub request_count: u32,
310    /// Rate limit window start
311    pub rate_window_start: Instant,
312}
313
314/// Connection tracker for managing active WebSocket connections
315#[derive(Debug)]
316pub struct ConnectionTracker {
317    /// Active connections by connection ID
318    connections: RwLock<HashMap<Uuid, ConnectionInfo>>,
319    /// Connection count by API key ID
320    connection_counts: RwLock<HashMap<String, usize>>,
321}
322
323impl ConnectionTracker {
324    pub fn new() -> Self {
325        Self {
326            connections: RwLock::new(HashMap::new()),
327            connection_counts: RwLock::new(HashMap::new()),
328        }
329    }
330
331    /// Register a new connection
332    pub fn register(
333        &self,
334        key_info: &ApiKeyInfo,
335        remote_addr: SocketAddr,
336    ) -> Result<ConnectionInfo, WsAuthError> {
337        let mut counts = self.connection_counts.write();
338        let current_count = counts.get(&key_info.key_id).copied().unwrap_or(0);
339        let max_connections = key_info.tier.max_connections();
340
341        if current_count >= max_connections {
342            return Err(WsAuthError::ConnectionLimitExceeded(
343                key_info.tier.to_string(),
344                max_connections,
345            ));
346        }
347
348        let now = Instant::now();
349        let conn_info = ConnectionInfo {
350            connection_id: Uuid::new_v4(),
351            key_id: key_info.key_id.clone(),
352            owner_id: key_info.owner_id.clone(),
353            tier: key_info.tier,
354            remote_addr,
355            connected_at: now,
356            last_activity: now,
357            request_count: 0,
358            rate_window_start: now,
359        };
360
361        // Update counts
362        *counts.entry(key_info.key_id.clone()).or_insert(0) += 1;
363
364        // Store connection info
365        self.connections
366            .write()
367            .insert(conn_info.connection_id, conn_info.clone());
368
369        info!(
370            connection_id = %conn_info.connection_id,
371            key_id = %key_info.key_id,
372            tier = %key_info.tier,
373            "New WebSocket connection registered"
374        );
375
376        Ok(conn_info)
377    }
378
379    /// Unregister a connection
380    pub fn unregister(&self, connection_id: Uuid) {
381        let mut connections = self.connections.write();
382        if let Some(conn_info) = connections.remove(&connection_id) {
383            let mut counts = self.connection_counts.write();
384            if let Some(count) = counts.get_mut(&conn_info.key_id) {
385                *count = count.saturating_sub(1);
386                if *count == 0 {
387                    counts.remove(&conn_info.key_id);
388                }
389            }
390
391            info!(
392                connection_id = %connection_id,
393                key_id = %conn_info.key_id,
394                "WebSocket connection unregistered"
395            );
396        }
397    }
398
399    /// Check and update rate limit, returns true if allowed
400    pub fn check_rate_limit(&self, connection_id: Uuid) -> Result<(), WsAuthError> {
401        let mut connections = self.connections.write();
402
403        if let Some(conn_info) = connections.get_mut(&connection_id) {
404            let now = Instant::now();
405            let rate_limit = conn_info.tier.rate_limit();
406
407            // Reset window if more than 1 minute has passed
408            if now.duration_since(conn_info.rate_window_start) > Duration::from_secs(60) {
409                conn_info.rate_window_start = now;
410                conn_info.request_count = 0;
411            }
412
413            conn_info.request_count += 1;
414            conn_info.last_activity = now;
415
416            if conn_info.request_count > rate_limit {
417                return Err(WsAuthError::RateLimitExceeded(rate_limit));
418            }
419        }
420
421        Ok(())
422    }
423
424    /// Get connection info
425    pub fn get(&self, connection_id: Uuid) -> Option<ConnectionInfo> {
426        self.connections.read().get(&connection_id).cloned()
427    }
428
429    /// Get all connections for an API key
430    pub fn get_by_key(&self, key_id: &str) -> Vec<ConnectionInfo> {
431        self.connections
432            .read()
433            .values()
434            .filter(|c| c.key_id == key_id)
435            .cloned()
436            .collect()
437    }
438
439    /// Get total active connection count
440    pub fn total_connections(&self) -> usize {
441        self.connections.read().len()
442    }
443
444    /// Get connection count for a specific API key
445    pub fn connection_count(&self, key_id: &str) -> usize {
446        self.connection_counts
447            .read()
448            .get(key_id)
449            .copied()
450            .unwrap_or(0)
451    }
452
453    /// Clean up expired/stale connections
454    pub fn cleanup_stale(&self, max_idle: Duration) {
455        let now = Instant::now();
456        let mut to_remove = Vec::new();
457
458        {
459            let connections = self.connections.read();
460            for (id, info) in connections.iter() {
461                if now.duration_since(info.last_activity) > max_idle {
462                    to_remove.push(*id);
463                }
464            }
465        }
466
467        for id in to_remove {
468            self.unregister(id);
469            debug!(connection_id = %id, "Cleaned up stale connection");
470        }
471    }
472}
473
474impl Default for ConnectionTracker {
475    fn default() -> Self {
476        Self::new()
477    }
478}
479
480// ============================================================================
481// Authentication State
482// ============================================================================
483
484/// Shared authentication state for the WebSocket server
485#[derive(Clone)]
486pub struct WsAuthState<V: ApiKeyValidator> {
487    /// API key validator
488    pub validator: Arc<V>,
489    /// Connection tracker
490    pub tracker: Arc<ConnectionTracker>,
491    /// Authentication timeout duration
492    pub auth_timeout: Duration,
493    /// Header name for API key (default: "Authorization")
494    pub api_key_header: String,
495    /// Whether to require TLS (wss://)
496    pub require_tls: bool,
497}
498
499impl<V: ApiKeyValidator> WsAuthState<V> {
500    pub fn new(validator: V) -> Self {
501        Self {
502            validator: Arc::new(validator),
503            tracker: Arc::new(ConnectionTracker::new()),
504            auth_timeout: Duration::from_secs(30),
505            api_key_header: "Authorization".to_string(),
506            require_tls: false,
507        }
508    }
509
510    pub fn with_auth_timeout(mut self, timeout: Duration) -> Self {
511        self.auth_timeout = timeout;
512        self
513    }
514
515    pub fn with_api_key_header(mut self, header: impl Into<String>) -> Self {
516        self.api_key_header = header.into();
517        self
518    }
519
520    pub fn with_require_tls(mut self, require: bool) -> Self {
521        self.require_tls = require;
522        self
523    }
524
525    /// Extract API key from request headers
526    pub fn extract_api_key_from_headers(&self, headers: &HeaderMap) -> Option<String> {
527        headers
528            .get(&self.api_key_header)
529            .and_then(|v| v.to_str().ok())
530            .map(|s| {
531                // Handle "Bearer <token>" format
532                s.strip_prefix("Bearer ").unwrap_or(s).to_string()
533            })
534    }
535}
536
537// ============================================================================
538// WebSocket Handler
539// ============================================================================
540
541/// Authentication message sent by client in first WebSocket message
542#[derive(Debug, Deserialize)]
543pub struct WsAuthMessage {
544    /// API key for authentication
545    pub api_key: String,
546    /// Optional client metadata
547    #[serde(default)]
548    pub client_info: HashMap<String, String>,
549}
550
551/// Authentication result sent to client
552#[derive(Debug, Serialize)]
553pub struct WsAuthResult {
554    /// Whether authentication succeeded
555    pub success: bool,
556    /// Error message if failed
557    #[serde(skip_serializing_if = "Option::is_none")]
558    pub error: Option<String>,
559    /// Connection ID if succeeded
560    #[serde(skip_serializing_if = "Option::is_none")]
561    pub connection_id: Option<String>,
562    /// Subscription tier if succeeded
563    #[serde(skip_serializing_if = "Option::is_none")]
564    pub tier: Option<String>,
565    /// Rate limit (requests per minute)
566    #[serde(skip_serializing_if = "Option::is_none")]
567    pub rate_limit: Option<u32>,
568    /// Session timeout in seconds
569    #[serde(skip_serializing_if = "Option::is_none")]
570    pub session_timeout_secs: Option<u64>,
571}
572
573/// Authenticated WebSocket connection handle
574pub struct AuthenticatedWsConnection {
575    /// Connection info
576    pub info: ConnectionInfo,
577    /// WebSocket stream
578    pub socket: WebSocket,
579    /// State reference for rate limiting
580    tracker: Arc<ConnectionTracker>,
581}
582
583impl AuthenticatedWsConnection {
584    /// Send a message (with rate limit check)
585    pub async fn send(&mut self, msg: Message) -> Result<(), WsAuthError> {
586        self.tracker.check_rate_limit(self.info.connection_id)?;
587        self.socket
588            .send(msg)
589            .await
590            .map_err(|e| WsAuthError::Internal(e.to_string()))
591    }
592
593    /// Receive a message
594    pub async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
595        self.socket.recv().await
596    }
597
598    /// Get the connection ID
599    pub fn connection_id(&self) -> Uuid {
600        self.info.connection_id
601    }
602
603    /// Get the subscription tier
604    pub fn tier(&self) -> SubscriptionTier {
605        self.info.tier
606    }
607}
608
609/// WebSocket upgrade handler with header-based authentication
610#[instrument(skip(ws, state))]
611pub async fn ws_handler_with_header_auth<V: ApiKeyValidator>(
612    ws: WebSocketUpgrade,
613    ConnectInfo(addr): ConnectInfo<SocketAddr>,
614    State(state): State<WsAuthState<V>>,
615    headers: HeaderMap,
616) -> Result<Response, WsAuthError> {
617    // Try to extract API key from headers
618    let api_key = state
619        .extract_api_key_from_headers(&headers)
620        .ok_or(WsAuthError::MissingApiKey)?;
621
622    // Validate API key
623    let key_info = state.validator.validate(&api_key).await?;
624
625    // Check tier allows WebSocket
626    if !key_info.tier.allows_websocket() {
627        return Err(WsAuthError::TierNotAllowed(key_info.tier.to_string()));
628    }
629
630    // Register connection
631    let conn_info = state.tracker.register(&key_info, addr)?;
632
633    info!(
634        connection_id = %conn_info.connection_id,
635        tier = %key_info.tier,
636        remote_addr = %addr,
637        "WebSocket connection authenticated via header"
638    );
639
640    // Upgrade connection
641    let tracker = Arc::clone(&state.tracker);
642
643    Ok(ws.on_upgrade(move |socket| async move {
644        handle_authenticated_socket(socket, conn_info, tracker).await;
645    }))
646}
647
648/// WebSocket upgrade handler with first-message authentication
649#[instrument(skip(ws, state))]
650pub async fn ws_handler_with_message_auth<V: ApiKeyValidator>(
651    ws: WebSocketUpgrade,
652    ConnectInfo(addr): ConnectInfo<SocketAddr>,
653    State(state): State<WsAuthState<V>>,
654    headers: HeaderMap,
655) -> impl IntoResponse {
656    // Check for API key in header first (preferred method)
657    if let Some(api_key) = state.extract_api_key_from_headers(&headers) {
658        match state.validator.validate(&api_key).await {
659            Ok(key_info) => {
660                if !key_info.tier.allows_websocket() {
661                    return Err(WsAuthError::TierNotAllowed(key_info.tier.to_string()));
662                }
663
664                match state.tracker.register(&key_info, addr) {
665                    Ok(conn_info) => {
666                        let tracker = Arc::clone(&state.tracker);
667                        return Ok(ws.on_upgrade(move |socket| async move {
668                            handle_authenticated_socket(socket, conn_info, tracker).await;
669                        }));
670                    }
671                    Err(e) => return Err(e),
672                }
673            }
674            Err(_) => {
675                // Fall through to message-based auth
676            }
677        }
678    }
679
680    // No valid header auth, require first-message authentication
681    let validator = Arc::clone(&state.validator);
682    let tracker = Arc::clone(&state.tracker);
683    let auth_timeout = state.auth_timeout;
684
685    Ok(ws.on_upgrade(move |socket| async move {
686        handle_unauthenticated_upgrade(socket, addr, validator, tracker, auth_timeout).await;
687    }))
688}
689
690/// Handle socket that requires first-message authentication
691async fn handle_unauthenticated_upgrade<V: ApiKeyValidator>(
692    mut socket: WebSocket,
693    addr: SocketAddr,
694    validator: Arc<V>,
695    tracker: Arc<ConnectionTracker>,
696    auth_timeout: Duration,
697) {
698    // Wait for authentication message with timeout
699    let auth_result = tokio::time::timeout(auth_timeout, socket.recv()).await;
700
701    let auth_msg = match auth_result {
702        Ok(Some(Ok(Message::Text(text)))) => match serde_json::from_str::<WsAuthMessage>(&text) {
703            Ok(msg) => msg,
704            Err(e) => {
705                let _ = send_auth_error(&mut socket, &WsAuthError::InvalidAuthMessage).await;
706                warn!(error = %e, "Invalid auth message format");
707                return;
708            }
709        },
710        Ok(Some(Ok(_))) => {
711            let _ = send_auth_error(&mut socket, &WsAuthError::InvalidAuthMessage).await;
712            warn!("First message must be text auth message");
713            return;
714        }
715        Ok(Some(Err(e))) => {
716            warn!(error = %e, "WebSocket error during auth");
717            return;
718        }
719        Ok(None) => {
720            warn!("Connection closed before authentication");
721            return;
722        }
723        Err(_) => {
724            let _ = send_auth_error(
725                &mut socket,
726                &WsAuthError::AuthTimeout(auth_timeout.as_secs()),
727            )
728            .await;
729            warn!(
730                timeout_secs = auth_timeout.as_secs(),
731                "Authentication timeout"
732            );
733            return;
734        }
735    };
736
737    // Validate API key
738    let key_info = match validator.validate(&auth_msg.api_key).await {
739        Ok(info) => info,
740        Err(e) => {
741            let _ = send_auth_error(&mut socket, &e).await;
742            warn!(error = %e, "API key validation failed");
743            return;
744        }
745    };
746
747    // Check tier
748    if !key_info.tier.allows_websocket() {
749        let err = WsAuthError::TierNotAllowed(key_info.tier.to_string());
750        let _ = send_auth_error(&mut socket, &err).await;
751        return;
752    }
753
754    // Register connection
755    let conn_info = match tracker.register(&key_info, addr) {
756        Ok(info) => info,
757        Err(e) => {
758            let _ = send_auth_error(&mut socket, &e).await;
759            return;
760        }
761    };
762
763    // Send success response
764    let auth_result = WsAuthResult {
765        success: true,
766        error: None,
767        connection_id: Some(conn_info.connection_id.to_string()),
768        tier: Some(conn_info.tier.to_string()),
769        rate_limit: Some(conn_info.tier.rate_limit()),
770        session_timeout_secs: Some(conn_info.tier.session_timeout().as_secs()),
771    };
772
773    if let Ok(json) = serde_json::to_string(&auth_result) {
774        let _ = socket.send(Message::Text(json)).await;
775    }
776
777    info!(
778        connection_id = %conn_info.connection_id,
779        tier = %key_info.tier,
780        remote_addr = %addr,
781        "WebSocket connection authenticated via first message"
782    );
783
784    // Continue with authenticated handler
785    handle_authenticated_socket(socket, conn_info, tracker).await;
786}
787
788/// Send authentication error response
789async fn send_auth_error(socket: &mut WebSocket, error: &WsAuthError) -> Result<(), axum::Error> {
790    let result = WsAuthResult {
791        success: false,
792        error: Some(error.to_string()),
793        connection_id: None,
794        tier: None,
795        rate_limit: None,
796        session_timeout_secs: None,
797    };
798
799    if let Ok(json) = serde_json::to_string(&result) {
800        socket.send(Message::Text(json)).await?;
801    }
802
803    // Send close frame
804    socket
805        .send(Message::Close(Some(axum::extract::ws::CloseFrame {
806            code: axum::extract::ws::close_code::POLICY,
807            reason: error.to_string().into(),
808        })))
809        .await?;
810
811    Ok(())
812}
813
814/// Handle an authenticated WebSocket connection
815async fn handle_authenticated_socket(
816    mut socket: WebSocket,
817    conn_info: ConnectionInfo,
818    tracker: Arc<ConnectionTracker>,
819) {
820    let connection_id = conn_info.connection_id;
821    let tier = conn_info.tier;
822
823    // Create a channel for the MCP message handler (for future bidirectional messaging)
824    let (_tx, mut rx) = mpsc::channel::<Message>(100);
825
826    // Spawn task to forward messages from channel to socket
827    let send_task = tokio::spawn({
828        let tracker = Arc::clone(&tracker);
829        async move {
830            while let Some(_msg) = rx.recv().await {
831                // Check rate limit before sending
832                if let Err(e) = tracker.check_rate_limit(connection_id) {
833                    warn!(
834                        connection_id = %connection_id,
835                        error = %e,
836                        "Rate limit exceeded"
837                    );
838                    // Send error and close
839                    let _error_msg = serde_json::json!({
840                        "jsonrpc": "2.0",
841                        "error": {
842                            "code": -32000,
843                            "message": e.to_string()
844                        }
845                    });
846                    // Ignore send errors during rate limit
847                    break;
848                }
849            }
850        }
851    });
852
853    // Process incoming messages
854    while let Some(msg) = socket.recv().await {
855        match msg {
856            Ok(Message::Text(text)) => {
857                debug!(
858                    connection_id = %connection_id,
859                    msg_len = text.len(),
860                    "Received text message"
861                );
862
863                // Check message size limit
864                if text.len() > tier.max_message_size() {
865                    warn!(
866                        connection_id = %connection_id,
867                        size = text.len(),
868                        max = tier.max_message_size(),
869                        "Message size exceeds tier limit"
870                    );
871                    // Send error response
872                    let error_msg = serde_json::json!({
873                        "jsonrpc": "2.0",
874                        "error": {
875                            "code": -32000,
876                            "message": format!("Message size {} exceeds limit {}", text.len(), tier.max_message_size())
877                        }
878                    });
879                    if let Ok(json) = serde_json::to_string(&error_msg) {
880                        let _ = socket.send(Message::Text(json)).await;
881                    }
882                    continue;
883                }
884
885                // TODO: Dispatch to MCP handler
886                // For now, echo back
887                let _ = socket.send(Message::Text(text)).await;
888            }
889            Ok(Message::Binary(data)) => {
890                debug!(
891                    connection_id = %connection_id,
892                    size = data.len(),
893                    "Received binary message"
894                );
895
896                // Check message size limit
897                if data.len() > tier.max_message_size() {
898                    warn!(
899                        connection_id = %connection_id,
900                        size = data.len(),
901                        max = tier.max_message_size(),
902                        "Binary message size exceeds tier limit"
903                    );
904                    continue;
905                }
906
907                // Echo back for now
908                let _ = socket.send(Message::Binary(data)).await;
909            }
910            Ok(Message::Ping(data)) => {
911                let _ = socket.send(Message::Pong(data)).await;
912            }
913            Ok(Message::Pong(_)) => {
914                // Ignore pongs
915            }
916            Ok(Message::Close(_)) => {
917                info!(connection_id = %connection_id, "Client initiated close");
918                break;
919            }
920            Err(e) => {
921                error!(
922                    connection_id = %connection_id,
923                    error = %e,
924                    "WebSocket error"
925                );
926                break;
927            }
928        }
929    }
930
931    // Clean up
932    send_task.abort();
933    tracker.unregister(connection_id);
934    info!(connection_id = %connection_id, "Connection closed");
935}
936
937// ============================================================================
938// Axum Middleware Layer
939// ============================================================================
940
941/// Middleware for pre-upgrade authentication checks
942pub async fn ws_auth_middleware<V: ApiKeyValidator>(
943    State(state): State<WsAuthState<V>>,
944    request: Request<Body>,
945    next: Next,
946) -> Result<Response, WsAuthError> {
947    // Check if this is a WebSocket upgrade request
948    let is_upgrade = request
949        .headers()
950        .get("upgrade")
951        .and_then(|v| v.to_str().ok())
952        .map(|v| v.eq_ignore_ascii_case("websocket"))
953        .unwrap_or(false);
954
955    if !is_upgrade {
956        // Not a WebSocket request, pass through
957        return Ok(next.run(request).await);
958    }
959
960    // TLS check in production
961    if state.require_tls {
962        let scheme = request.uri().scheme_str().unwrap_or("http");
963        if scheme != "https" && scheme != "wss" {
964            warn!("WebSocket connection rejected: TLS required");
965            return Err(WsAuthError::Internal(
966                "Secure connection (wss://) required".to_string(),
967            ));
968        }
969    }
970
971    // Continue to handler
972    Ok(next.run(request).await)
973}
974
975// ============================================================================
976// Helper Functions
977// ============================================================================
978
/// Compares two strings for equality without stopping at the first differing byte.
///
/// Intended for secrets such as API keys, so that response timing does not reveal how much
/// of a guessed key matched. The byte loop always runs over the full common prefix. Its
/// accumulator is routed through [`std::hint::black_box`] to discourage the optimizer from
/// turning it into an early-exit comparison. That is best effort, not a hard guarantee.
///
/// The lengths of the inputs are *not* hidden. A length mismatch is folded into the result
/// rather than returned early, but the loop still runs over the shorter input only.
///
/// Returns `true` iff `a` and `b` are byte-for-byte identical.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let a = a.as_bytes();
    let b = b.as_bytes();

    // Seed with the length difference: any mismatch leaves a non-zero bit, so unequal lengths
    // compare unequal without a separate branch. (The previous "dummy" loop for this case had
    // an unused result and was removed by the optimizer, so it never did what it claimed.)
    let mut diff = a.len() ^ b.len();
    for (x, y) in a.iter().zip(b) {
        diff = std::hint::black_box(diff | usize::from(x ^ y));
    }

    diff == 0
}
1000
1001/// Generate a new API key (for development/testing)
1002pub fn generate_api_key() -> String {
1003    format!("rk_{}", Uuid::new_v4().to_string().replace('-', ""))
1004}
1005
1006// ============================================================================
1007// Tests
1008// ============================================================================
1009
// Unit tests for the tier limits, the constant-time comparison and key-generation helpers,
// the in-memory API key validator, the connection tracker, and header-based key extraction.
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the per-tier cap on concurrent connections.
    #[test]
    fn test_subscription_tier_limits() {
        assert_eq!(SubscriptionTier::Free.max_connections(), 1);
        assert_eq!(SubscriptionTier::Pro.max_connections(), 5);
        assert_eq!(SubscriptionTier::Team.max_connections(), 25);
        assert_eq!(SubscriptionTier::Enterprise.max_connections(), 100);
    }

    // Pins the per-tier rate limit values returned by `rate_limit()`.
    #[test]
    fn test_subscription_tier_rate_limits() {
        assert_eq!(SubscriptionTier::Free.rate_limit(), 60);
        assert_eq!(SubscriptionTier::Pro.rate_limit(), 300);
        assert_eq!(SubscriptionTier::Team.rate_limit(), 1000);
        assert_eq!(SubscriptionTier::Enterprise.rate_limit(), 10000);
    }

    // Equal strings match; a case difference, a length difference, or an empty side does not.
    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("secret", "secret"));
        assert!(!constant_time_compare("secret", "Secret"));
        assert!(!constant_time_compare("short", "longer"));
        assert!(!constant_time_compare("", "nonempty"));
    }

    // Generated keys carry the `rk_` prefix and have a fixed total length.
    #[test]
    fn test_generate_api_key() {
        let key = generate_api_key();
        assert!(key.starts_with("rk_"));
        assert_eq!(key.len(), 35); // "rk_" + 32 hex chars
    }

    // A registered key validates and keeps its tier; an unregistered key yields `InvalidApiKey`.
    #[tokio::test]
    async fn test_in_memory_validator() {
        let validator = InMemoryApiKeyValidator::new();

        let info = ApiKeyInfo {
            key_id: "key_123".to_string(),
            owner_id: "user_456".to_string(),
            tier: SubscriptionTier::Pro,
            expires_at: None,
            metadata: HashMap::new(),
        };

        validator.add_key("test_api_key".to_string(), info.clone());

        // Valid key
        let result = validator.validate("test_api_key").await;
        assert!(result.is_ok());
        let validated = result.unwrap();
        assert_eq!(validated.tier, SubscriptionTier::Pro);

        // Invalid key
        let result = validator.validate("wrong_key").await;
        assert!(matches!(result, Err(WsAuthError::InvalidApiKey)));
    }

    // The Free tier admits one live connection per key. A second `register` fails until the
    // first connection is unregistered, after which the slot can be reused.
    #[test]
    fn test_connection_tracker() {
        let tracker = ConnectionTracker::new();

        let key_info = ApiKeyInfo {
            key_id: "key_123".to_string(),
            owner_id: "user_456".to_string(),
            tier: SubscriptionTier::Free, // Only 1 connection allowed
            expires_at: None,
            metadata: HashMap::new(),
        };

        let addr: SocketAddr = "127.0.0.1:9100".parse().unwrap();

        // First connection should succeed
        let conn1 = tracker.register(&key_info, addr);
        assert!(conn1.is_ok());

        // Second connection should fail (Free tier = 1 max)
        // NOTE(review): the second field of `ConnectionLimitExceeded` is presumably the tier's
        // connection limit; confirm against the error definition.
        let conn2 = tracker.register(&key_info, addr);
        assert!(matches!(
            conn2,
            Err(WsAuthError::ConnectionLimitExceeded(_, 1))
        ));

        // Unregister first connection
        tracker.unregister(conn1.unwrap().connection_id);

        // Now should be able to connect again
        let conn3 = tracker.register(&key_info, addr);
        assert!(conn3.is_ok());
    }

    // For a Free-tier connection, 60 `check_rate_limit` calls succeed and the 61st fails with
    // `RateLimitExceeded(60)`.
    // NOTE(review): assumes all 61 calls land in the same rate-limit window; confirm the window
    // cannot roll over mid-test.
    #[test]
    fn test_rate_limiting() {
        let tracker = ConnectionTracker::new();

        let key_info = ApiKeyInfo {
            key_id: "key_123".to_string(),
            owner_id: "user_456".to_string(),
            tier: SubscriptionTier::Free, // 60 requests per minute
            expires_at: None,
            metadata: HashMap::new(),
        };

        let addr: SocketAddr = "127.0.0.1:9100".parse().unwrap();
        let conn = tracker.register(&key_info, addr).unwrap();

        // Should allow 60 requests
        for _ in 0..60 {
            assert!(tracker.check_rate_limit(conn.connection_id).is_ok());
        }

        // 61st request should fail
        assert!(matches!(
            tracker.check_rate_limit(conn.connection_id),
            Err(WsAuthError::RateLimitExceeded(60))
        ));
    }

    // Key extraction strips a `Bearer ` prefix, accepts a bare Authorization value, and reads a
    // custom header once one is configured with `with_api_key_header`.
    #[test]
    fn test_api_key_extraction() {
        let validator = InMemoryApiKeyValidator::new();
        let state = WsAuthState::new(validator);

        let mut headers = HeaderMap::new();

        // Test Bearer format
        headers.insert("Authorization", "Bearer my_api_key".parse().unwrap());
        assert_eq!(
            state.extract_api_key_from_headers(&headers),
            Some("my_api_key".to_string())
        );

        // Test raw format
        headers.insert("Authorization", "raw_api_key".parse().unwrap());
        assert_eq!(
            state.extract_api_key_from_headers(&headers),
            Some("raw_api_key".to_string())
        );

        // Test custom header
        // NOTE(review): `Authorization: raw_api_key` is still present here, so this assertion
        // also pins that the configured custom header is used instead of it; confirm that this
        // precedence is intended.
        let state = state.with_api_key_header("X-Api-Key");
        headers.insert("X-Api-Key", "custom_key".parse().unwrap());
        assert_eq!(
            state.extract_api_key_from_headers(&headers),
            Some("custom_key".to_string())
        );
    }
}