use std::{collections::HashSet, sync::Arc, time::Duration};
use axum::{
extract::{
Query, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
http::StatusCode,
response::IntoResponse,
};
use futures::{SinkExt, StreamExt, future::BoxFuture};
use serde::Deserialize;
use tracing::{debug, info, warn};
use super::{
connections::{ConnectionManager, ConnectionState},
protocol::{ClientMessage, ServerMessage},
subscriptions::{EventKind, SubscriptionDetails, SubscriptionManager, parse_filter},
};
/// Tunables for the realtime WebSocket server.
#[derive(Clone, Debug)]
pub struct RealtimeConfig {
    /// Connections allowed per security context; upgrades for a context that
    /// is already at this count are refused with `429 Too Many Requests`.
    pub max_connections_per_context: usize,
    /// Period of server `Ping` messages and of token-expiry checks.
    pub heartbeat_interval: Duration,
    /// How long a client may stay silent before the server closes the socket.
    pub idle_timeout: Duration,
    /// Handed to `SubscriptionManager::new`; presumably a per-entity
    /// subscription cap — confirm against `SubscriptionManager`.
    pub max_subscriptions_per_entity: usize,
    /// Presumably sizes the shared event channel — TODO confirm against its user.
    pub event_channel_capacity: usize,
    /// NOTE(review): the connection loop checks expiry on heartbeat ticks;
    /// confirm where this interval is consumed.
    pub token_revalidation_interval: Duration,
    /// Handed to `ConnectionManager::new`; presumably the number of dropped
    /// events tolerated before a slow consumer is closed — confirm there.
    pub max_consecutive_drops: usize,
    /// Handed to `ConnectionManager::new`; presumably the per-connection event
    /// queue capacity — confirm there.
    pub connection_event_capacity: usize,
}

impl Default for RealtimeConfig {
    /// Defaults:
    /// * 10 connections per context;
    /// * 30 s heartbeat and token-revalidation intervals;
    /// * 60 s idle timeout;
    /// * 10 000 subscriptions per entity and 10 000 event-channel slots;
    /// * 50 consecutive drops;
    /// * 256 per-connection events.
    fn default() -> Self {
        let half_minute = Duration::from_secs(30);
        Self {
            max_connections_per_context: 10,
            heartbeat_interval: half_minute,
            idle_timeout: half_minute * 2,
            max_subscriptions_per_entity: 10_000,
            event_channel_capacity: 10_000,
            token_revalidation_interval: half_minute,
            max_consecutive_drops: 50,
            connection_event_capacity: 256,
        }
    }
}
/// Validates the token a client presents when opening a realtime WebSocket.
///
/// Implementations resolve the token to a [`TokenInfo`] or return a rejection
/// reason; `ws_handler` logs that reason and answers `401 Unauthorized`.
pub trait TokenValidator: 'static + Send + Sync {
    /// Checks `token`; the returned future may borrow both `self` and `token`.
    fn validate<'a>(&'a self, token: &'a str) -> BoxFuture<'a, Result<TokenInfo, String>>;
}
/// Identity resolved from a validated token.
#[derive(Clone, Debug)]
pub struct TokenInfo {
    /// Authenticated user identifier.
    pub user_id: String,
    /// Security-context key: groups connections for the per-context limit and
    /// tags every subscription the connection creates.
    pub context_hash: u64,
    /// Expiry as a Unix timestamp in seconds (compared against
    /// `chrono::Utc::now().timestamp()` by the connection loop).
    pub expires_at: i64,
}
/// Shared axum state for the realtime endpoint.
#[derive(Clone)]
pub struct RealtimeState {
    /// Connection and subscription registries shared by every socket.
    pub server: Arc<RealtimeServer>,
    /// Authenticates the token presented on upgrade.
    pub validator: Arc<dyn TokenValidator>,
}

/// Realtime server state.
///
/// Holds the connection registry, the subscription registry, the optional
/// entity allow-list and the configuration.
pub struct RealtimeServer {
    /// Live connections; also consulted for the per-context connection limit.
    pub(crate) connections: Arc<ConnectionManager>,
    /// Entity subscriptions keyed by connection.
    pub(crate) subscriptions: Arc<SubscriptionManager>,
    /// Entities clients may subscribe to; an empty set accepts any entity.
    pub(crate) known_entities: HashSet<String>,
    /// Tunables supplied at construction.
    pub(crate) config: RealtimeConfig,
}
impl RealtimeServer {
    /// Creates a server that accepts subscriptions to any entity name.
    ///
    /// Equivalent to [`RealtimeServer::with_entities`] with an empty allow-list.
    #[must_use]
    pub fn new(config: RealtimeConfig) -> Self {
        Self::with_entities(config, HashSet::new())
    }

    /// Creates a server whose clients may only subscribe to `entities`.
    ///
    /// An empty `entities` set disables the check, so every entity is accepted.
    /// The connection and subscription managers are sized from `config`.
    #[must_use]
    pub fn with_entities(config: RealtimeConfig, entities: HashSet<String>) -> Self {
        let connections = Arc::new(ConnectionManager::new(
            config.max_consecutive_drops,
            config.connection_event_capacity,
        ));
        let subscriptions = Arc::new(SubscriptionManager::new(
            config.max_subscriptions_per_entity,
        ));
        Self {
            connections,
            subscriptions,
            known_entities: entities,
            config,
        }
    }

    /// Number of connections currently registered with the connection manager.
    #[must_use]
    pub fn active_connections(&self) -> usize {
        self.connections.count()
    }
}
/// Query string accepted by the realtime WebSocket endpoint.
#[derive(Deserialize, Debug)]
pub struct WsQueryParams {
    /// Auth token passed as `?token=…`.
    ///
    /// It is checked before the `Authorization` header.
    pub token: Option<String>,
}
pub async fn ws_handler(
headers: axum::http::HeaderMap,
Query(params): Query<WsQueryParams>,
ws: WebSocketUpgrade,
State(state): State<RealtimeState>,
) -> impl IntoResponse {
let token = params.token.or_else(|| {
headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.map(str::to_owned)
});
let Some(token) = token else {
return StatusCode::UNAUTHORIZED.into_response();
};
let token_info = match state.validator.validate(&token).await {
Ok(info) => info,
Err(reason) => {
warn!(reason = %reason, "Realtime WebSocket auth failed");
return StatusCode::UNAUTHORIZED.into_response();
},
};
let context_hash = token_info.context_hash;
let current = state.server.connections.count_by_context(context_hash);
if current >= state.server.config.max_connections_per_context {
return StatusCode::TOO_MANY_REQUESTS.into_response();
}
let server = state.server.clone();
ws.on_upgrade(move |socket| handle_realtime_connection(socket, server, token_info))
.into_response()
}
#[allow(clippy::cognitive_complexity)] async fn handle_realtime_connection(
socket: WebSocket,
server: Arc<RealtimeServer>,
token_info: TokenInfo,
) {
let context_hash = token_info.context_hash;
let connection_id = uuid::Uuid::new_v4().to_string();
let config = &server.config;
let conn_state = ConnectionState::new(
connection_id.clone(),
token_info.user_id.clone(),
token_info.context_hash,
token_info.expires_at,
);
let (mut event_rx, control_rx) = server.connections.insert(conn_state);
tokio::pin!(control_rx);
info!(
connection_id = %connection_id,
user_id = %token_info.user_id,
"Realtime WebSocket connected"
);
let (mut sender, mut receiver) = socket.split();
let connected_msg = ServerMessage::Connected {
connection_id: connection_id.clone(),
};
if let Ok(json) = connected_msg.to_json() {
if sender.send(Message::Text(json.into())).await.is_err() {
server.connections.remove(&connection_id);
return;
}
}
let mut heartbeat_interval = tokio::time::interval(config.heartbeat_interval);
heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
heartbeat_interval.tick().await;
let mut idle_deadline = tokio::time::Instant::now() + config.idle_timeout;
loop {
tokio::select! {
_ = heartbeat_interval.tick() => {
let now_ts = chrono::Utc::now().timestamp();
if now_ts >= token_info.expires_at {
debug!(connection_id = %connection_id, "Token expired, closing connection");
if let Ok(json) = ServerMessage::TokenExpired.to_json() {
let _ = sender.send(Message::Text(json.into())).await;
}
let _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: 4401,
reason: "token expired".into(),
}))).await;
break;
}
if let Ok(json) = ServerMessage::Ping.to_json() {
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
() = tokio::time::sleep_until(idle_deadline) => {
debug!(connection_id = %connection_id, "Idle timeout, closing connection");
let _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: 1000,
reason: "idle timeout".into(),
}))).await;
break;
}
Some(event_json) = event_rx.recv() => {
if sender.send(Message::Text(event_json.into())).await.is_err() {
break;
}
}
signal = control_rx.as_mut() => {
if let Ok(sig) = signal {
debug!(
connection_id = %connection_id,
code = sig.code,
reason = %sig.reason,
"Slow consumer: closing connection"
);
let _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: sig.code,
reason: sig.reason.into(),
}))).await;
}
break;
}
msg = receiver.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
idle_deadline = tokio::time::Instant::now() + config.idle_timeout;
match serde_json::from_str::<ClientMessage>(&text) {
Ok(ClientMessage::Pong) => {
debug!(connection_id = %connection_id, "Received pong");
}
Ok(ClientMessage::Subscribe { entity, event, filter }) => {
let reply = handle_subscribe(
&server,
&connection_id,
context_hash,
&entity,
&event,
filter.as_deref(),
);
if let Ok(json) = reply.to_json() {
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Ok(ClientMessage::Unsubscribe { entity }) => {
let _ = server.subscriptions.unsubscribe(&connection_id, &entity);
let reply = ServerMessage::Unsubscribed { entity };
if let Ok(json) = reply.to_json() {
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Err(_) => {
}
}
}
Some(Ok(Message::Close(_))) => {
debug!(connection_id = %connection_id, "Client sent close");
break;
}
Some(Ok(Message::Pong(_))) => {
idle_deadline = tokio::time::Instant::now() + config.idle_timeout;
}
Some(Err(e)) => {
warn!(connection_id = %connection_id, error = %e, "WebSocket error");
break;
}
None => break,
_ => {}
}
}
}
}
server.subscriptions.unsubscribe_all(&connection_id);
server.connections.remove(&connection_id);
info!(connection_id = %connection_id, "Realtime WebSocket disconnected");
}
fn handle_subscribe(
server: &RealtimeServer,
connection_id: &str,
context_hash: u64,
entity: &str,
event: &str,
filter: Option<&str>,
) -> ServerMessage {
if !server.known_entities.is_empty() && !server.known_entities.contains(entity) {
return ServerMessage::Error {
message: format!("unknown entity: {entity}"),
};
}
let event_filter = if event == "*" {
None
} else {
match EventKind::parse(event) {
Ok(kind) => Some(kind),
Err(e) => return ServerMessage::Error { message: e },
}
};
let field_filters = if let Some(f) = filter {
match parse_filter(f) {
Ok(filters) => filters,
Err(e) => return ServerMessage::Error { message: e },
}
} else {
Vec::new()
};
let details = SubscriptionDetails {
event_filter,
field_filters,
security_context_hash: context_hash,
};
match server.subscriptions.subscribe(connection_id, entity, details) {
Ok(_) => ServerMessage::Subscribed {
entity: entity.to_owned(),
},
Err(e) => ServerMessage::Error { message: e },
}
}