1use axum::{
29 body::Body,
30 extract::{
31 ws::{Message, WebSocket, WebSocketUpgrade},
32 ConnectInfo, State,
33 },
34 http::{header::HeaderMap, Request, StatusCode},
35 middleware::Next,
36 response::{IntoResponse, Response},
37};
38#[allow(unused_imports)]
40use futures_util::StreamExt;
41use parking_lot::RwLock;
42use serde::{Deserialize, Serialize};
43use std::{
44 collections::HashMap,
45 net::SocketAddr,
46 sync::Arc,
47 time::{Duration, Instant},
48};
49use thiserror::Error;
50use tokio::sync::mpsc;
51use tracing::{debug, error, info, instrument, warn};
52use uuid::Uuid;
53
54#[derive(Debug, Error)]
60pub enum WsAuthError {
61 #[error("Missing API key")]
62 MissingApiKey,
63
64 #[error("Invalid API key")]
65 InvalidApiKey,
66
67 #[error("API key expired")]
68 ExpiredApiKey,
69
70 #[error("Subscription tier '{0}' does not allow WebSocket access")]
71 TierNotAllowed(String),
72
73 #[error("Connection limit exceeded for tier '{0}': max {1} connections")]
74 ConnectionLimitExceeded(String, usize),
75
76 #[error("Rate limit exceeded: {0} requests per minute allowed")]
77 RateLimitExceeded(u32),
78
79 #[error("Authentication timeout: must authenticate within {0} seconds")]
80 AuthTimeout(u64),
81
82 #[error("Invalid authentication message format")]
83 InvalidAuthMessage,
84
85 #[error("Internal authentication error: {0}")]
86 Internal(String),
87}
88
89impl IntoResponse for WsAuthError {
90 fn into_response(self) -> Response {
91 let (status, message) = match &self {
92 WsAuthError::MissingApiKey => (StatusCode::UNAUTHORIZED, self.to_string()),
93 WsAuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, self.to_string()),
94 WsAuthError::ExpiredApiKey => (StatusCode::UNAUTHORIZED, self.to_string()),
95 WsAuthError::TierNotAllowed(_) => (StatusCode::FORBIDDEN, self.to_string()),
96 WsAuthError::ConnectionLimitExceeded(_, _) => {
97 (StatusCode::TOO_MANY_REQUESTS, self.to_string())
98 }
99 WsAuthError::RateLimitExceeded(_) => (StatusCode::TOO_MANY_REQUESTS, self.to_string()),
100 WsAuthError::AuthTimeout(_) => (StatusCode::REQUEST_TIMEOUT, self.to_string()),
101 WsAuthError::InvalidAuthMessage => (StatusCode::BAD_REQUEST, self.to_string()),
102 WsAuthError::Internal(_) => (
103 StatusCode::INTERNAL_SERVER_ERROR,
104 "Internal error".to_string(),
105 ),
106 };
107
108 (status, message).into_response()
109 }
110}
111
112#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
118#[serde(rename_all = "lowercase")]
119pub enum SubscriptionTier {
120 Free,
122 Pro,
124 Team,
126 Enterprise,
128}
129
130impl SubscriptionTier {
131 pub fn max_connections(&self) -> usize {
133 match self {
134 SubscriptionTier::Free => 1,
135 SubscriptionTier::Pro => 5,
136 SubscriptionTier::Team => 25,
137 SubscriptionTier::Enterprise => 100,
138 }
139 }
140
141 pub fn rate_limit(&self) -> u32 {
143 match self {
144 SubscriptionTier::Free => 60,
145 SubscriptionTier::Pro => 300,
146 SubscriptionTier::Team => 1000,
147 SubscriptionTier::Enterprise => 10000,
148 }
149 }
150
151 pub fn max_message_size(&self) -> usize {
153 match self {
154 SubscriptionTier::Free => 64 * 1024, SubscriptionTier::Pro => 1024 * 1024, SubscriptionTier::Team => 10 * 1024 * 1024, SubscriptionTier::Enterprise => 100 * 1024 * 1024, }
159 }
160
161 pub fn session_timeout(&self) -> Duration {
163 match self {
164 SubscriptionTier::Free => Duration::from_secs(30 * 60), SubscriptionTier::Pro => Duration::from_secs(2 * 60 * 60), SubscriptionTier::Team => Duration::from_secs(8 * 60 * 60), SubscriptionTier::Enterprise => Duration::from_secs(24 * 60 * 60), }
169 }
170
171 pub fn allows_websocket(&self) -> bool {
173 true
175 }
176}
177
178impl std::fmt::Display for SubscriptionTier {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 match self {
181 SubscriptionTier::Free => write!(f, "free"),
182 SubscriptionTier::Pro => write!(f, "pro"),
183 SubscriptionTier::Team => write!(f, "team"),
184 SubscriptionTier::Enterprise => write!(f, "enterprise"),
185 }
186 }
187}
188
189#[derive(Debug, Clone)]
195pub struct ApiKeyInfo {
196 pub key_id: String,
198 pub owner_id: String,
200 pub tier: SubscriptionTier,
202 pub expires_at: Option<Instant>,
204 pub metadata: HashMap<String, String>,
206}
207
208#[async_trait::async_trait]
211pub trait ApiKeyValidator: Send + Sync + 'static {
212 async fn validate(&self, api_key: &str) -> Result<ApiKeyInfo, WsAuthError>;
214
215 async fn revoke(&self, key_id: &str) -> Result<(), WsAuthError> {
217 let _ = key_id;
218 Ok(())
219 }
220}
221
222#[derive(Debug, Clone)]
224pub struct InMemoryApiKeyValidator {
225 keys: Arc<RwLock<HashMap<String, ApiKeyInfo>>>,
226}
227
228impl InMemoryApiKeyValidator {
229 pub fn new() -> Self {
230 Self {
231 keys: Arc::new(RwLock::new(HashMap::new())),
232 }
233 }
234
235 pub fn add_key(&self, api_key: String, info: ApiKeyInfo) {
237 self.keys.write().insert(api_key, info);
238 }
239
240 pub fn remove_key(&self, api_key: &str) {
242 self.keys.write().remove(api_key);
243 }
244}
245
246impl Default for InMemoryApiKeyValidator {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
252#[async_trait::async_trait]
253impl ApiKeyValidator for InMemoryApiKeyValidator {
254 async fn validate(&self, api_key: &str) -> Result<ApiKeyInfo, WsAuthError> {
255 let keys = self.keys.read();
256
257 let mut found_info: Option<&ApiKeyInfo> = None;
259 for (stored_key, info) in keys.iter() {
260 if constant_time_compare(api_key, stored_key) {
261 found_info = Some(info);
262 break;
263 }
264 }
265
266 match found_info {
267 Some(info) => {
268 if let Some(expires_at) = info.expires_at {
270 if Instant::now() > expires_at {
271 return Err(WsAuthError::ExpiredApiKey);
272 }
273 }
274 Ok(info.clone())
275 }
276 None => Err(WsAuthError::InvalidApiKey),
277 }
278 }
279
280 async fn revoke(&self, key_id: &str) -> Result<(), WsAuthError> {
281 let mut keys = self.keys.write();
282 keys.retain(|_, v| v.key_id != key_id);
283 Ok(())
284 }
285}
286
287#[derive(Debug, Clone)]
293pub struct ConnectionInfo {
294 pub connection_id: Uuid,
296 pub key_id: String,
298 pub owner_id: String,
300 pub tier: SubscriptionTier,
302 pub remote_addr: SocketAddr,
304 pub connected_at: Instant,
306 pub last_activity: Instant,
308 pub request_count: u32,
310 pub rate_window_start: Instant,
312}
313
314#[derive(Debug)]
316pub struct ConnectionTracker {
317 connections: RwLock<HashMap<Uuid, ConnectionInfo>>,
319 connection_counts: RwLock<HashMap<String, usize>>,
321}
322
323impl ConnectionTracker {
324 pub fn new() -> Self {
325 Self {
326 connections: RwLock::new(HashMap::new()),
327 connection_counts: RwLock::new(HashMap::new()),
328 }
329 }
330
331 pub fn register(
333 &self,
334 key_info: &ApiKeyInfo,
335 remote_addr: SocketAddr,
336 ) -> Result<ConnectionInfo, WsAuthError> {
337 let mut counts = self.connection_counts.write();
338 let current_count = counts.get(&key_info.key_id).copied().unwrap_or(0);
339 let max_connections = key_info.tier.max_connections();
340
341 if current_count >= max_connections {
342 return Err(WsAuthError::ConnectionLimitExceeded(
343 key_info.tier.to_string(),
344 max_connections,
345 ));
346 }
347
348 let now = Instant::now();
349 let conn_info = ConnectionInfo {
350 connection_id: Uuid::new_v4(),
351 key_id: key_info.key_id.clone(),
352 owner_id: key_info.owner_id.clone(),
353 tier: key_info.tier,
354 remote_addr,
355 connected_at: now,
356 last_activity: now,
357 request_count: 0,
358 rate_window_start: now,
359 };
360
361 *counts.entry(key_info.key_id.clone()).or_insert(0) += 1;
363
364 self.connections
366 .write()
367 .insert(conn_info.connection_id, conn_info.clone());
368
369 info!(
370 connection_id = %conn_info.connection_id,
371 key_id = %key_info.key_id,
372 tier = %key_info.tier,
373 "New WebSocket connection registered"
374 );
375
376 Ok(conn_info)
377 }
378
379 pub fn unregister(&self, connection_id: Uuid) {
381 let mut connections = self.connections.write();
382 if let Some(conn_info) = connections.remove(&connection_id) {
383 let mut counts = self.connection_counts.write();
384 if let Some(count) = counts.get_mut(&conn_info.key_id) {
385 *count = count.saturating_sub(1);
386 if *count == 0 {
387 counts.remove(&conn_info.key_id);
388 }
389 }
390
391 info!(
392 connection_id = %connection_id,
393 key_id = %conn_info.key_id,
394 "WebSocket connection unregistered"
395 );
396 }
397 }
398
399 pub fn check_rate_limit(&self, connection_id: Uuid) -> Result<(), WsAuthError> {
401 let mut connections = self.connections.write();
402
403 if let Some(conn_info) = connections.get_mut(&connection_id) {
404 let now = Instant::now();
405 let rate_limit = conn_info.tier.rate_limit();
406
407 if now.duration_since(conn_info.rate_window_start) > Duration::from_secs(60) {
409 conn_info.rate_window_start = now;
410 conn_info.request_count = 0;
411 }
412
413 conn_info.request_count += 1;
414 conn_info.last_activity = now;
415
416 if conn_info.request_count > rate_limit {
417 return Err(WsAuthError::RateLimitExceeded(rate_limit));
418 }
419 }
420
421 Ok(())
422 }
423
424 pub fn get(&self, connection_id: Uuid) -> Option<ConnectionInfo> {
426 self.connections.read().get(&connection_id).cloned()
427 }
428
429 pub fn get_by_key(&self, key_id: &str) -> Vec<ConnectionInfo> {
431 self.connections
432 .read()
433 .values()
434 .filter(|c| c.key_id == key_id)
435 .cloned()
436 .collect()
437 }
438
439 pub fn total_connections(&self) -> usize {
441 self.connections.read().len()
442 }
443
444 pub fn connection_count(&self, key_id: &str) -> usize {
446 self.connection_counts
447 .read()
448 .get(key_id)
449 .copied()
450 .unwrap_or(0)
451 }
452
453 pub fn cleanup_stale(&self, max_idle: Duration) {
455 let now = Instant::now();
456 let mut to_remove = Vec::new();
457
458 {
459 let connections = self.connections.read();
460 for (id, info) in connections.iter() {
461 if now.duration_since(info.last_activity) > max_idle {
462 to_remove.push(*id);
463 }
464 }
465 }
466
467 for id in to_remove {
468 self.unregister(id);
469 debug!(connection_id = %id, "Cleaned up stale connection");
470 }
471 }
472}
473
474impl Default for ConnectionTracker {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480#[derive(Clone)]
486pub struct WsAuthState<V: ApiKeyValidator> {
487 pub validator: Arc<V>,
489 pub tracker: Arc<ConnectionTracker>,
491 pub auth_timeout: Duration,
493 pub api_key_header: String,
495 pub require_tls: bool,
497}
498
499impl<V: ApiKeyValidator> WsAuthState<V> {
500 pub fn new(validator: V) -> Self {
501 Self {
502 validator: Arc::new(validator),
503 tracker: Arc::new(ConnectionTracker::new()),
504 auth_timeout: Duration::from_secs(30),
505 api_key_header: "Authorization".to_string(),
506 require_tls: false,
507 }
508 }
509
510 pub fn with_auth_timeout(mut self, timeout: Duration) -> Self {
511 self.auth_timeout = timeout;
512 self
513 }
514
515 pub fn with_api_key_header(mut self, header: impl Into<String>) -> Self {
516 self.api_key_header = header.into();
517 self
518 }
519
520 pub fn with_require_tls(mut self, require: bool) -> Self {
521 self.require_tls = require;
522 self
523 }
524
525 pub fn extract_api_key_from_headers(&self, headers: &HeaderMap) -> Option<String> {
527 headers
528 .get(&self.api_key_header)
529 .and_then(|v| v.to_str().ok())
530 .map(|s| {
531 s.strip_prefix("Bearer ").unwrap_or(s).to_string()
533 })
534 }
535}
536
537#[derive(Debug, Deserialize)]
543pub struct WsAuthMessage {
544 pub api_key: String,
546 #[serde(default)]
548 pub client_info: HashMap<String, String>,
549}
550
551#[derive(Debug, Serialize)]
553pub struct WsAuthResult {
554 pub success: bool,
556 #[serde(skip_serializing_if = "Option::is_none")]
558 pub error: Option<String>,
559 #[serde(skip_serializing_if = "Option::is_none")]
561 pub connection_id: Option<String>,
562 #[serde(skip_serializing_if = "Option::is_none")]
564 pub tier: Option<String>,
565 #[serde(skip_serializing_if = "Option::is_none")]
567 pub rate_limit: Option<u32>,
568 #[serde(skip_serializing_if = "Option::is_none")]
570 pub session_timeout_secs: Option<u64>,
571}
572
573pub struct AuthenticatedWsConnection {
575 pub info: ConnectionInfo,
577 pub socket: WebSocket,
579 tracker: Arc<ConnectionTracker>,
581}
582
583impl AuthenticatedWsConnection {
584 pub async fn send(&mut self, msg: Message) -> Result<(), WsAuthError> {
586 self.tracker.check_rate_limit(self.info.connection_id)?;
587 self.socket
588 .send(msg)
589 .await
590 .map_err(|e| WsAuthError::Internal(e.to_string()))
591 }
592
593 pub async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
595 self.socket.recv().await
596 }
597
598 pub fn connection_id(&self) -> Uuid {
600 self.info.connection_id
601 }
602
603 pub fn tier(&self) -> SubscriptionTier {
605 self.info.tier
606 }
607}
608
609#[instrument(skip(ws, state))]
611pub async fn ws_handler_with_header_auth<V: ApiKeyValidator>(
612 ws: WebSocketUpgrade,
613 ConnectInfo(addr): ConnectInfo<SocketAddr>,
614 State(state): State<WsAuthState<V>>,
615 headers: HeaderMap,
616) -> Result<Response, WsAuthError> {
617 let api_key = state
619 .extract_api_key_from_headers(&headers)
620 .ok_or(WsAuthError::MissingApiKey)?;
621
622 let key_info = state.validator.validate(&api_key).await?;
624
625 if !key_info.tier.allows_websocket() {
627 return Err(WsAuthError::TierNotAllowed(key_info.tier.to_string()));
628 }
629
630 let conn_info = state.tracker.register(&key_info, addr)?;
632
633 info!(
634 connection_id = %conn_info.connection_id,
635 tier = %key_info.tier,
636 remote_addr = %addr,
637 "WebSocket connection authenticated via header"
638 );
639
640 let tracker = Arc::clone(&state.tracker);
642
643 Ok(ws.on_upgrade(move |socket| async move {
644 handle_authenticated_socket(socket, conn_info, tracker).await;
645 }))
646}
647
648#[instrument(skip(ws, state))]
650pub async fn ws_handler_with_message_auth<V: ApiKeyValidator>(
651 ws: WebSocketUpgrade,
652 ConnectInfo(addr): ConnectInfo<SocketAddr>,
653 State(state): State<WsAuthState<V>>,
654 headers: HeaderMap,
655) -> impl IntoResponse {
656 if let Some(api_key) = state.extract_api_key_from_headers(&headers) {
658 match state.validator.validate(&api_key).await {
659 Ok(key_info) => {
660 if !key_info.tier.allows_websocket() {
661 return Err(WsAuthError::TierNotAllowed(key_info.tier.to_string()));
662 }
663
664 match state.tracker.register(&key_info, addr) {
665 Ok(conn_info) => {
666 let tracker = Arc::clone(&state.tracker);
667 return Ok(ws.on_upgrade(move |socket| async move {
668 handle_authenticated_socket(socket, conn_info, tracker).await;
669 }));
670 }
671 Err(e) => return Err(e),
672 }
673 }
674 Err(_) => {
675 }
677 }
678 }
679
680 let validator = Arc::clone(&state.validator);
682 let tracker = Arc::clone(&state.tracker);
683 let auth_timeout = state.auth_timeout;
684
685 Ok(ws.on_upgrade(move |socket| async move {
686 handle_unauthenticated_upgrade(socket, addr, validator, tracker, auth_timeout).await;
687 }))
688}
689
690async fn handle_unauthenticated_upgrade<V: ApiKeyValidator>(
692 mut socket: WebSocket,
693 addr: SocketAddr,
694 validator: Arc<V>,
695 tracker: Arc<ConnectionTracker>,
696 auth_timeout: Duration,
697) {
698 let auth_result = tokio::time::timeout(auth_timeout, socket.recv()).await;
700
701 let auth_msg = match auth_result {
702 Ok(Some(Ok(Message::Text(text)))) => match serde_json::from_str::<WsAuthMessage>(&text) {
703 Ok(msg) => msg,
704 Err(e) => {
705 let _ = send_auth_error(&mut socket, &WsAuthError::InvalidAuthMessage).await;
706 warn!(error = %e, "Invalid auth message format");
707 return;
708 }
709 },
710 Ok(Some(Ok(_))) => {
711 let _ = send_auth_error(&mut socket, &WsAuthError::InvalidAuthMessage).await;
712 warn!("First message must be text auth message");
713 return;
714 }
715 Ok(Some(Err(e))) => {
716 warn!(error = %e, "WebSocket error during auth");
717 return;
718 }
719 Ok(None) => {
720 warn!("Connection closed before authentication");
721 return;
722 }
723 Err(_) => {
724 let _ = send_auth_error(
725 &mut socket,
726 &WsAuthError::AuthTimeout(auth_timeout.as_secs()),
727 )
728 .await;
729 warn!(
730 timeout_secs = auth_timeout.as_secs(),
731 "Authentication timeout"
732 );
733 return;
734 }
735 };
736
737 let key_info = match validator.validate(&auth_msg.api_key).await {
739 Ok(info) => info,
740 Err(e) => {
741 let _ = send_auth_error(&mut socket, &e).await;
742 warn!(error = %e, "API key validation failed");
743 return;
744 }
745 };
746
747 if !key_info.tier.allows_websocket() {
749 let err = WsAuthError::TierNotAllowed(key_info.tier.to_string());
750 let _ = send_auth_error(&mut socket, &err).await;
751 return;
752 }
753
754 let conn_info = match tracker.register(&key_info, addr) {
756 Ok(info) => info,
757 Err(e) => {
758 let _ = send_auth_error(&mut socket, &e).await;
759 return;
760 }
761 };
762
763 let auth_result = WsAuthResult {
765 success: true,
766 error: None,
767 connection_id: Some(conn_info.connection_id.to_string()),
768 tier: Some(conn_info.tier.to_string()),
769 rate_limit: Some(conn_info.tier.rate_limit()),
770 session_timeout_secs: Some(conn_info.tier.session_timeout().as_secs()),
771 };
772
773 if let Ok(json) = serde_json::to_string(&auth_result) {
774 let _ = socket.send(Message::Text(json)).await;
775 }
776
777 info!(
778 connection_id = %conn_info.connection_id,
779 tier = %key_info.tier,
780 remote_addr = %addr,
781 "WebSocket connection authenticated via first message"
782 );
783
784 handle_authenticated_socket(socket, conn_info, tracker).await;
786}
787
788async fn send_auth_error(socket: &mut WebSocket, error: &WsAuthError) -> Result<(), axum::Error> {
790 let result = WsAuthResult {
791 success: false,
792 error: Some(error.to_string()),
793 connection_id: None,
794 tier: None,
795 rate_limit: None,
796 session_timeout_secs: None,
797 };
798
799 if let Ok(json) = serde_json::to_string(&result) {
800 socket.send(Message::Text(json)).await?;
801 }
802
803 socket
805 .send(Message::Close(Some(axum::extract::ws::CloseFrame {
806 code: axum::extract::ws::close_code::POLICY,
807 reason: error.to_string().into(),
808 })))
809 .await?;
810
811 Ok(())
812}
813
814async fn handle_authenticated_socket(
816 mut socket: WebSocket,
817 conn_info: ConnectionInfo,
818 tracker: Arc<ConnectionTracker>,
819) {
820 let connection_id = conn_info.connection_id;
821 let tier = conn_info.tier;
822
823 let (_tx, mut rx) = mpsc::channel::<Message>(100);
825
826 let send_task = tokio::spawn({
828 let tracker = Arc::clone(&tracker);
829 async move {
830 while let Some(_msg) = rx.recv().await {
831 if let Err(e) = tracker.check_rate_limit(connection_id) {
833 warn!(
834 connection_id = %connection_id,
835 error = %e,
836 "Rate limit exceeded"
837 );
838 let _error_msg = serde_json::json!({
840 "jsonrpc": "2.0",
841 "error": {
842 "code": -32000,
843 "message": e.to_string()
844 }
845 });
846 break;
848 }
849 }
850 }
851 });
852
853 while let Some(msg) = socket.recv().await {
855 match msg {
856 Ok(Message::Text(text)) => {
857 debug!(
858 connection_id = %connection_id,
859 msg_len = text.len(),
860 "Received text message"
861 );
862
863 if text.len() > tier.max_message_size() {
865 warn!(
866 connection_id = %connection_id,
867 size = text.len(),
868 max = tier.max_message_size(),
869 "Message size exceeds tier limit"
870 );
871 let error_msg = serde_json::json!({
873 "jsonrpc": "2.0",
874 "error": {
875 "code": -32000,
876 "message": format!("Message size {} exceeds limit {}", text.len(), tier.max_message_size())
877 }
878 });
879 if let Ok(json) = serde_json::to_string(&error_msg) {
880 let _ = socket.send(Message::Text(json)).await;
881 }
882 continue;
883 }
884
885 let _ = socket.send(Message::Text(text)).await;
888 }
889 Ok(Message::Binary(data)) => {
890 debug!(
891 connection_id = %connection_id,
892 size = data.len(),
893 "Received binary message"
894 );
895
896 if data.len() > tier.max_message_size() {
898 warn!(
899 connection_id = %connection_id,
900 size = data.len(),
901 max = tier.max_message_size(),
902 "Binary message size exceeds tier limit"
903 );
904 continue;
905 }
906
907 let _ = socket.send(Message::Binary(data)).await;
909 }
910 Ok(Message::Ping(data)) => {
911 let _ = socket.send(Message::Pong(data)).await;
912 }
913 Ok(Message::Pong(_)) => {
914 }
916 Ok(Message::Close(_)) => {
917 info!(connection_id = %connection_id, "Client initiated close");
918 break;
919 }
920 Err(e) => {
921 error!(
922 connection_id = %connection_id,
923 error = %e,
924 "WebSocket error"
925 );
926 break;
927 }
928 }
929 }
930
931 send_task.abort();
933 tracker.unregister(connection_id);
934 info!(connection_id = %connection_id, "Connection closed");
935}
936
937pub async fn ws_auth_middleware<V: ApiKeyValidator>(
943 State(state): State<WsAuthState<V>>,
944 request: Request<Body>,
945 next: Next,
946) -> Result<Response, WsAuthError> {
947 let is_upgrade = request
949 .headers()
950 .get("upgrade")
951 .and_then(|v| v.to_str().ok())
952 .map(|v| v.eq_ignore_ascii_case("websocket"))
953 .unwrap_or(false);
954
955 if !is_upgrade {
956 return Ok(next.run(request).await);
958 }
959
960 if state.require_tls {
962 let scheme = request.uri().scheme_str().unwrap_or("http");
963 if scheme != "https" && scheme != "wss" {
964 warn!("WebSocket connection rejected: TLS required");
965 return Err(WsAuthError::Internal(
966 "Secure connection (wss://) required".to_string(),
967 ));
968 }
969 }
970
971 Ok(next.run(request).await)
973}
974
/// Compares two secrets for equality without data-dependent early exits.
///
/// The original returned early on a length mismatch, and its dummy loop had no observable
/// effect, so the optimizer could delete it. Here the length difference is folded into the
/// same accumulator as the byte differences. The loop always runs over every byte of `b`
/// (callers pass the stored key as `b`), so its iteration count does not depend on where
/// the inputs first differ.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let a = a.as_bytes();
    let b = b.as_bytes();

    // Non-zero iff the lengths differ.
    let mut diff: usize = a.len() ^ b.len();

    for (i, &y) in b.iter().enumerate() {
        // Bytes past the end of `a` compare against 0. The length term above already
        // guarantees a mismatch in that case.
        let x = a.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }

    // Keep the optimizer from turning the accumulation back into an early-exit comparison.
    std::hint::black_box(diff) == 0
}
1000
1001pub fn generate_api_key() -> String {
1003 format!("rk_{}", Uuid::new_v4().to_string().replace('-', ""))
1004}
1005
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `key_123` / `user_456` fixture used throughout these tests.
    fn key_info(tier: SubscriptionTier) -> ApiKeyInfo {
        ApiKeyInfo {
            key_id: "key_123".to_string(),
            owner_id: "user_456".to_string(),
            tier,
            expires_at: None,
            metadata: HashMap::new(),
        }
    }

    /// Loopback peer address used for every registration in these tests.
    fn peer() -> SocketAddr {
        "127.0.0.1:9100".parse().unwrap()
    }

    #[test]
    fn test_subscription_tier_limits() {
        assert_eq!(SubscriptionTier::Free.max_connections(), 1);
        assert_eq!(SubscriptionTier::Pro.max_connections(), 5);
        assert_eq!(SubscriptionTier::Team.max_connections(), 25);
        assert_eq!(SubscriptionTier::Enterprise.max_connections(), 100);
    }

    #[test]
    fn test_subscription_tier_rate_limits() {
        assert_eq!(SubscriptionTier::Free.rate_limit(), 60);
        assert_eq!(SubscriptionTier::Pro.rate_limit(), 300);
        assert_eq!(SubscriptionTier::Team.rate_limit(), 1000);
        assert_eq!(SubscriptionTier::Enterprise.rate_limit(), 10000);
    }

    #[test]
    fn test_tier_display_matches_serde_name() {
        let cases = [
            (SubscriptionTier::Free, "free"),
            (SubscriptionTier::Pro, "pro"),
            (SubscriptionTier::Team, "team"),
            (SubscriptionTier::Enterprise, "enterprise"),
        ];
        for (tier, name) in cases {
            let quoted = format!("\"{}\"", name);
            assert_eq!(tier.to_string(), name);
            assert_eq!(serde_json::to_string(&tier).unwrap(), quoted);
            assert_eq!(serde_json::from_str::<SubscriptionTier>(&quoted).unwrap(), tier);
        }
    }

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("secret", "secret"));
        assert!(constant_time_compare("", ""));
        assert!(!constant_time_compare("secret", "Secret"));
        assert!(!constant_time_compare("short", "longer"));
        assert!(!constant_time_compare("", "nonempty"));
    }

    #[test]
    fn test_generate_api_key() {
        let key = generate_api_key();
        assert!(key.starts_with("rk_"));
        assert_eq!(key.len(), 35);
        assert!(key[3..].chars().all(|c| c.is_ascii_hexdigit()));
        assert_ne!(key, generate_api_key());
    }

    #[tokio::test]
    async fn test_in_memory_validator() {
        let validator = InMemoryApiKeyValidator::new();
        validator.add_key("test_api_key".to_string(), key_info(SubscriptionTier::Pro));

        let validated = validator.validate("test_api_key").await.unwrap();
        assert_eq!(validated.tier, SubscriptionTier::Pro);

        let result = validator.validate("wrong_key").await;
        assert!(matches!(result, Err(WsAuthError::InvalidApiKey)));
    }

    #[tokio::test]
    async fn test_expired_key_is_rejected() {
        // `checked_sub` avoids a panic on platforms where `Instant` cannot go back a
        // second (e.g. shortly after boot). The test is skipped in that case.
        let past = match Instant::now().checked_sub(Duration::from_secs(1)) {
            Some(t) => t,
            None => return,
        };

        let validator = InMemoryApiKeyValidator::new();
        let mut info = key_info(SubscriptionTier::Pro);
        info.expires_at = Some(past);
        validator.add_key("expired_key".to_string(), info);

        let result = validator.validate("expired_key").await;
        assert!(matches!(result, Err(WsAuthError::ExpiredApiKey)));
    }

    #[tokio::test]
    async fn test_revoke_removes_key_by_id() {
        let validator = InMemoryApiKeyValidator::new();
        validator.add_key("raw_key".to_string(), key_info(SubscriptionTier::Team));
        assert!(validator.validate("raw_key").await.is_ok());

        validator.revoke("key_123").await.unwrap();

        let result = validator.validate("raw_key").await;
        assert!(matches!(result, Err(WsAuthError::InvalidApiKey)));
    }

    #[test]
    fn test_connection_tracker() {
        let tracker = ConnectionTracker::new();
        let info = key_info(SubscriptionTier::Free);

        let conn1 = tracker.register(&info, peer());
        assert!(conn1.is_ok());

        let conn2 = tracker.register(&info, peer());
        assert!(matches!(
            conn2,
            Err(WsAuthError::ConnectionLimitExceeded(_, 1))
        ));

        tracker.unregister(conn1.unwrap().connection_id);

        let conn3 = tracker.register(&info, peer());
        assert!(conn3.is_ok());
    }

    #[test]
    fn test_connection_bookkeeping() {
        let tracker = ConnectionTracker::new();
        let info = key_info(SubscriptionTier::Pro);

        let a = tracker.register(&info, peer()).unwrap();
        let b = tracker.register(&info, peer()).unwrap();

        assert_eq!(tracker.total_connections(), 2);
        assert_eq!(tracker.connection_count("key_123"), 2);
        assert_eq!(tracker.get_by_key("key_123").len(), 2);
        assert_eq!(
            tracker.get(a.connection_id).map(|c| c.owner_id),
            Some("user_456".to_string())
        );

        tracker.unregister(a.connection_id);
        // Unregistering the same id twice must be a no-op.
        tracker.unregister(a.connection_id);
        assert_eq!(tracker.connection_count("key_123"), 1);
        assert!(tracker.get(a.connection_id).is_none());

        tracker.unregister(b.connection_id);
        assert_eq!(tracker.connection_count("key_123"), 0);
        assert_eq!(tracker.total_connections(), 0);
    }

    #[test]
    fn test_rate_limiting() {
        let tracker = ConnectionTracker::new();
        let conn = tracker
            .register(&key_info(SubscriptionTier::Free), peer())
            .unwrap();

        for _ in 0..60 {
            assert!(tracker.check_rate_limit(conn.connection_id).is_ok());
        }

        assert!(matches!(
            tracker.check_rate_limit(conn.connection_id),
            Err(WsAuthError::RateLimitExceeded(60))
        ));
    }

    #[test]
    fn test_rate_limit_unknown_connection_is_ok() {
        let tracker = ConnectionTracker::new();
        assert!(tracker.check_rate_limit(Uuid::new_v4()).is_ok());
    }

    #[test]
    fn test_cleanup_stale_keeps_recent_connections() {
        let tracker = ConnectionTracker::new();
        tracker
            .register(&key_info(SubscriptionTier::Pro), peer())
            .unwrap();

        tracker.cleanup_stale(Duration::from_secs(3600));

        assert_eq!(tracker.total_connections(), 1);
        assert_eq!(tracker.connection_count("key_123"), 1);
    }

    #[test]
    fn test_api_key_extraction() {
        let state = WsAuthState::new(InMemoryApiKeyValidator::new());
        let mut headers = HeaderMap::new();

        headers.insert("Authorization", "Bearer my_api_key".parse().unwrap());
        assert_eq!(
            state.extract_api_key_from_headers(&headers),
            Some("my_api_key".to_string())
        );

        headers.insert("Authorization", "raw_api_key".parse().unwrap());
        assert_eq!(
            state.extract_api_key_from_headers(&headers),
            Some("raw_api_key".to_string())
        );

        let state = state.with_api_key_header("X-Api-Key");
        headers.insert("X-Api-Key", "custom_key".parse().unwrap());
        assert_eq!(
            state.extract_api_key_from_headers(&headers),
            Some("custom_key".to_string())
        );
    }
}