use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock};
use forge_core::cluster::NodeId;
use forge_core::realtime::{Delta, SessionId, SubscriptionId};
use crate::gateway::websocket::{JobData, WorkflowData};
/// Tunables for the WebSocket gateway.
///
/// Within this module only `max_subscriptions_per_connection` is enforced (by
/// `WebSocketServer::add_subscription`). The other fields are not read here.
/// NOTE(review): presumably consumed by the connection-handling code elsewhere — confirm.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Maximum number of subscriptions a single connection may hold; adding one more is
    /// rejected with a validation error.
    pub max_subscriptions_per_connection: usize,
    /// Subscription timeout. NOTE(review): not read in this module; exact semantics unconfirmed.
    pub subscription_timeout: Duration,
    /// Subscription rate limit. NOTE(review): not read in this module; the counting window /
    /// unit is not visible here — confirm.
    pub subscription_rate_limit: usize,
    /// Heartbeat interval. NOTE(review): not read in this module.
    pub heartbeat_interval: Duration,
    /// Maximum message size, presumably in bytes — confirm. Not read in this module.
    pub max_message_size: usize,
    /// Reconnection policy settings.
    pub reconnect: ReconnectConfig,
}
impl Default for WebSocketConfig {
    /// Default gateway settings.
    ///
    /// These are: 50 subscriptions per connection, a rate limit of 100, a 30-second
    /// subscription timeout and heartbeat interval, and a `1024 * 1024` message-size cap.
    /// Reconnection uses `ReconnectConfig::default()`.
    fn default() -> Self {
        // Timeout and heartbeat share the same default window.
        let thirty_seconds = Duration::from_secs(30);
        let reconnect = ReconnectConfig::default();
        Self {
            max_subscriptions_per_connection: 50,
            subscription_timeout: thirty_seconds,
            subscription_rate_limit: 100,
            heartbeat_interval: thirty_seconds,
            max_message_size: 1024 * 1024,
            reconnect,
        }
    }
}
/// Reconnection policy settings.
///
/// NOTE(review): none of these fields are read in this module; presumably they describe the
/// client-side reconnect behaviour — confirm against the consumer.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Whether reconnection is enabled.
    pub enabled: bool,
    /// Maximum number of reconnection attempts.
    pub max_attempts: usize,
    /// Base delay between attempts (presumably the starting point for `backoff`).
    pub delay: Duration,
    /// Upper bound on the delay between attempts (presumably a cap on `backoff`).
    pub max_delay: Duration,
    /// How the delay evolves across attempts.
    pub backoff: BackoffStrategy,
}
impl Default for ReconnectConfig {
    /// Default reconnection policy.
    ///
    /// Reconnection is enabled with up to 10 attempts, a `delay` of 1 s, a `max_delay` of
    /// 30 s, and exponential backoff.
    fn default() -> Self {
        let (delay, max_delay) = (Duration::from_secs(1), Duration::from_secs(30));
        Self {
            enabled: true,
            max_attempts: 10,
            delay,
            max_delay,
            backoff: BackoffStrategy::Exponential,
        }
    }
}
/// Strategy for spacing out reconnection attempts.
///
/// NOTE(review): the delay computation for each variant is not implemented in this module;
/// the descriptions below follow the variant names and should be confirmed against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Delay presumably grows by a constant step per attempt.
    Linear,
    /// Delay presumably multiplies per attempt (the `ReconnectConfig` default).
    Exponential,
    /// Delay presumably stays at `ReconnectConfig::delay`.
    Fixed,
}
/// Messages exchanged over a WebSocket connection.
///
/// Outbound messages for a session are delivered through that connection's
/// `mpsc::Sender<WebSocketMessage>`. NOTE(review): which variants are client→server versus
/// server→client is not established in this module — confirm against the codec.
#[derive(Debug, Clone)]
pub enum WebSocketMessage {
    /// Subscribe to `query` with JSON `args`. `id` is presumably a client-chosen correlation id.
    Subscribe {
        id: String,
        query: String,
        args: serde_json::Value,
    },
    /// Cancel the given subscription.
    Unsubscribe { subscription_id: SubscriptionId },
    /// Heartbeat ping.
    Ping,
    /// Heartbeat reply.
    Pong,
    /// Data payload for a subscription (presumably the full result).
    Data {
        subscription_id: SubscriptionId,
        data: serde_json::Value,
    },
    /// Incremental change for a subscription; built by `WebSocketServer::broadcast_delta`.
    DeltaUpdate {
        subscription_id: SubscriptionId,
        delta: Delta<serde_json::Value>,
    },
    /// Job status update addressed by a client-side subscription id.
    JobUpdate { client_sub_id: String, job: JobData },
    /// Workflow status update addressed by a client-side subscription id.
    WorkflowUpdate {
        client_sub_id: String,
        workflow: WorkflowData,
    },
    /// Error not tied to a specific request.
    Error { code: String, message: String },
    /// Error tied to a request `id` (presumably the `Subscribe::id` that failed).
    ErrorWithId {
        id: String,
        code: String,
        message: String,
    },
}
/// Server-side record of one live WebSocket session.
#[derive(Debug)]
pub struct WebSocketConnection {
    // `session_id` is not read anywhere in this module, hence the lint suppression.
    #[allow(dead_code)]
    pub session_id: SessionId,
    /// Subscriptions owned by this connection. Duplicates are not filtered on insert.
    pub subscriptions: Vec<SubscriptionId>,
    /// Outbound channel feeding messages to this client.
    pub sender: mpsc::Sender<WebSocketMessage>,
    // `connected_at` is set on creation but not read anywhere in this module.
    #[allow(dead_code)]
    pub connected_at: chrono::DateTime<chrono::Utc>,
    /// Last time the subscription list changed. It is compared against the cutoff in
    /// `WebSocketServer::cleanup_stale`.
    pub last_active: chrono::DateTime<chrono::Utc>,
}
impl WebSocketConnection {
    /// Creates a record for `session_id` with no subscriptions.
    ///
    /// `connected_at` and `last_active` are both stamped with the same current UTC instant.
    pub fn new(session_id: SessionId, sender: mpsc::Sender<WebSocketMessage>) -> Self {
        let created = chrono::Utc::now();
        Self {
            session_id,
            subscriptions: Vec::new(),
            sender,
            connected_at: created,
            last_active: created,
        }
    }

    /// Appends `subscription_id` (without de-duplication) and refreshes `last_active`.
    pub fn add_subscription(&mut self, subscription_id: SubscriptionId) {
        self.subscriptions.push(subscription_id);
        self.mark_active();
    }

    /// Removes every occurrence of `subscription_id` and refreshes `last_active`.
    pub fn remove_subscription(&mut self, subscription_id: SubscriptionId) {
        self.subscriptions
            .retain(|existing| existing != &subscription_id);
        self.mark_active();
    }

    /// Queues `message` on this connection's outbound channel.
    ///
    /// Waits for channel capacity if needed. Fails with `SendError`, which hands the message
    /// back, once the receiving side has been dropped.
    pub async fn send(
        &self,
        message: WebSocketMessage,
    ) -> Result<(), mpsc::error::SendError<WebSocketMessage>> {
        mpsc::Sender::send(&self.sender, message).await
    }

    /// Stamps `last_active` with the current UTC time.
    fn mark_active(&mut self) {
        self.last_active = chrono::Utc::now();
    }
}
/// Per-node registry of live WebSocket connections and the subscriptions they own.
///
/// Two maps are kept in sync: `connections` (session → connection) and the reverse index
/// `subscription_sessions` (subscription → owning session). Methods that lock both acquire
/// `connections` first.
pub struct WebSocketServer {
    // Read by `config()` and by `add_subscription` (subscription limit). The former
    // `#[allow(dead_code)]` was stale and only served to mask future warnings.
    config: WebSocketConfig,
    /// Identifier of the node hosting this server.
    node_id: NodeId,
    /// Live connections keyed by session.
    connections: Arc<RwLock<HashMap<SessionId, WebSocketConnection>>>,
    /// Reverse index used to route per-subscription messages to a session.
    subscription_sessions: Arc<RwLock<HashMap<SubscriptionId, SessionId>>>,
}
impl WebSocketServer {
pub fn new(node_id: NodeId, config: WebSocketConfig) -> Self {
Self {
config,
node_id,
connections: Arc::new(RwLock::new(HashMap::new())),
subscription_sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn node_id(&self) -> NodeId {
self.node_id
}
pub fn config(&self) -> &WebSocketConfig {
&self.config
}
pub async fn register_connection(
&self,
session_id: SessionId,
sender: mpsc::Sender<WebSocketMessage>,
) {
let connection = WebSocketConnection::new(session_id, sender);
let mut connections = self.connections.write().await;
connections.insert(session_id, connection);
}
pub async fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
let mut connections = self.connections.write().await;
if let Some(conn) = connections.remove(&session_id) {
let mut sub_sessions = self.subscription_sessions.write().await;
for sub_id in &conn.subscriptions {
sub_sessions.remove(sub_id);
}
Some(conn.subscriptions)
} else {
None
}
}
pub async fn add_subscription(
&self,
session_id: SessionId,
subscription_id: SubscriptionId,
) -> forge_core::Result<()> {
let mut connections = self.connections.write().await;
let conn = connections
.get_mut(&session_id)
.ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
if conn.subscriptions.len() >= self.config.max_subscriptions_per_connection {
return Err(forge_core::ForgeError::Validation(format!(
"Maximum subscriptions per connection ({}) exceeded",
self.config.max_subscriptions_per_connection
)));
}
conn.add_subscription(subscription_id);
let mut sub_sessions = self.subscription_sessions.write().await;
sub_sessions.insert(subscription_id, session_id);
Ok(())
}
pub async fn remove_subscription(&self, subscription_id: SubscriptionId) {
let session_id = {
let mut sub_sessions = self.subscription_sessions.write().await;
sub_sessions.remove(&subscription_id)
};
if let Some(session_id) = session_id {
let mut connections = self.connections.write().await;
if let Some(conn) = connections.get_mut(&session_id) {
conn.remove_subscription(subscription_id);
}
}
}
pub async fn send_to_session(
&self,
session_id: SessionId,
message: WebSocketMessage,
) -> forge_core::Result<()> {
let connections = self.connections.read().await;
let conn = connections
.get(&session_id)
.ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
conn.send(message)
.await
.map_err(|_| forge_core::ForgeError::Internal("Failed to send message".to_string()))
}
pub async fn broadcast_delta(
&self,
subscription_id: SubscriptionId,
delta: Delta<serde_json::Value>,
) -> forge_core::Result<()> {
let session_id = {
let sub_sessions = self.subscription_sessions.read().await;
sub_sessions.get(&subscription_id).copied()
};
if let Some(session_id) = session_id {
let message = WebSocketMessage::DeltaUpdate {
subscription_id,
delta,
};
self.send_to_session(session_id, message).await?;
}
Ok(())
}
pub async fn connection_count(&self) -> usize {
self.connections.read().await.len()
}
pub async fn subscription_count(&self) -> usize {
self.subscription_sessions.read().await.len()
}
pub async fn stats(&self) -> WebSocketStats {
let connections = self.connections.read().await;
let total_subscriptions: usize = connections.values().map(|c| c.subscriptions.len()).sum();
WebSocketStats {
connections: connections.len(),
subscriptions: total_subscriptions,
node_id: self.node_id,
}
}
pub async fn cleanup_stale(&self, max_idle: Duration) {
let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_idle).unwrap();
let mut connections = self.connections.write().await;
let mut sub_sessions = self.subscription_sessions.write().await;
connections.retain(|_, conn| {
if conn.last_active < cutoff {
for sub_id in &conn.subscriptions {
sub_sessions.remove(sub_id);
}
false
} else {
true
}
});
}
}
/// Point-in-time totals returned by `WebSocketServer::stats`.
#[derive(Debug, Clone)]
pub struct WebSocketStats {
    /// Number of registered connections.
    pub connections: usize,
    /// Sum of the per-connection subscription list lengths.
    pub subscriptions: usize,
    /// Node the reporting server belongs to.
    pub node_id: NodeId,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server from `config` and registers one fresh session on it.
    ///
    /// The receiver is handed back so the caller keeps the channel open for the whole test.
    async fn server_with_session(
        config: WebSocketConfig,
    ) -> (NodeId, WebSocketServer, SessionId, mpsc::Receiver<WebSocketMessage>) {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, config);
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(100);
        server.register_connection(session_id, tx).await;
        (node_id, server, session_id, rx)
    }

    #[test]
    fn test_websocket_config_default() {
        let WebSocketConfig {
            max_subscriptions_per_connection,
            subscription_rate_limit,
            reconnect,
            ..
        } = WebSocketConfig::default();
        assert_eq!(max_subscriptions_per_connection, 50);
        assert_eq!(subscription_rate_limit, 100);
        assert!(reconnect.enabled);
    }

    #[test]
    fn test_reconnect_config_default() {
        let reconnect = ReconnectConfig::default();
        assert!(reconnect.enabled);
        assert_eq!(reconnect.max_attempts, 10);
        assert_eq!(reconnect.backoff, BackoffStrategy::Exponential);
    }

    #[tokio::test]
    async fn test_websocket_server_creation() {
        let node_id = NodeId::new();
        let server = WebSocketServer::new(node_id, WebSocketConfig::default());
        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count().await, 0);
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_connection() {
        let (_, server, session_id, _rx) =
            server_with_session(WebSocketConfig::default()).await;
        assert_eq!(server.connection_count().await, 1);
        assert!(server.remove_connection(session_id).await.is_some());
        assert_eq!(server.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription() {
        let (_, server, session_id, _rx) =
            server_with_session(WebSocketConfig::default()).await;
        let subscription_id = SubscriptionId::new();
        server
            .add_subscription(session_id, subscription_id)
            .await
            .unwrap();
        assert_eq!(server.subscription_count().await, 1);
        server.remove_subscription(subscription_id).await;
        assert_eq!(server.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_websocket_subscription_limit() {
        let config = WebSocketConfig {
            max_subscriptions_per_connection: 2,
            ..Default::default()
        };
        let (_, server, session_id, _rx) = server_with_session(config).await;
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }
        let overflow = server
            .add_subscription(session_id, SubscriptionId::new())
            .await;
        assert!(overflow.is_err());
    }

    #[tokio::test]
    async fn test_websocket_stats() {
        let (node_id, server, session_id, _rx) =
            server_with_session(WebSocketConfig::default()).await;
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .await
                .unwrap();
        }
        let stats = server.stats().await;
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}