use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use dashmap::DashMap;
use serde::Serialize;
use tokio::sync::mpsc;
use forge_core::cluster::NodeId;
use forge_core::realtime::{Delta, SessionId, SubscriptionId};
/// Tunable limits applied by the realtime session server.
#[derive(Clone, Debug)]
pub struct RealtimeConfig {
    /// Upper bound on how many subscriptions a single session may hold at the same time.
    pub max_subscriptions_per_session: usize,
}

impl Default for RealtimeConfig {
    /// Allows up to 50 subscriptions per session.
    fn default() -> Self {
        const DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION: usize = 50;
        RealtimeConfig {
            max_subscriptions_per_session: DEFAULT_MAX_SUBSCRIPTIONS_PER_SESSION,
        }
    }
}
/// Snapshot of a job's state, carried by `RealtimeMessage::JobUpdate`.
///
/// Field order is kept as-is on purpose: it determines the key order of the serialized output.
#[derive(Clone, Debug, Serialize)]
pub struct JobData {
    /// Identifier of the job this snapshot describes.
    pub job_id: String,
    /// Job status as a free-form string.
    pub status: String,
    /// Progress value, serialized under the key `progress` (presumably 0-100 — confirm with producer).
    #[serde(rename = "progress")]
    pub progress_percent: Option<i32>,
    /// Human-readable progress note, serialized under the key `message`.
    #[serde(rename = "message")]
    pub progress_message: Option<String>,
    /// Job output as arbitrary JSON, if any.
    pub output: Option<serde_json::Value>,
    /// Error text, if any.
    pub error: Option<String>,
}
/// Snapshot of a workflow's state, carried by `RealtimeMessage::WorkflowUpdate`.
///
/// Field order is kept as-is on purpose: it determines the key order of the serialized output.
#[derive(Clone, Debug, Serialize)]
pub struct WorkflowData {
    /// Identifier of the workflow this snapshot describes.
    pub workflow_id: String,
    /// Workflow status as a free-form string.
    pub status: String,
    /// Name of the step currently running, serialized under the key `step`.
    #[serde(rename = "step")]
    pub current_step: Option<String>,
    /// Per-step state.
    pub steps: Vec<WorkflowStepData>,
    /// Workflow output as arbitrary JSON, if any.
    pub output: Option<serde_json::Value>,
    /// Error text, if any.
    pub error: Option<String>,
}
/// State of a single step inside a [`WorkflowData`] snapshot.
#[derive(Clone, Debug, Serialize)]
pub struct WorkflowStepData {
    /// Step name.
    pub name: String,
    /// Step status as a free-form string.
    pub status: String,
    /// Error text for this step, if any.
    pub error: Option<String>,
}
/// Messages exchanged over a realtime session channel.
#[derive(Clone, Debug)]
pub enum RealtimeMessage {
    /// Subscribe to `query` with `args`; `id` presumably is a client-chosen correlation id.
    Subscribe {
        id: String,
        query: String,
        args: serde_json::Value,
    },
    /// Drop the subscription identified by `subscription_id`.
    Unsubscribe {
        subscription_id: SubscriptionId,
    },
    /// Keep-alive probe.
    Ping,
    /// Keep-alive reply.
    Pong,
    /// Full result payload for a subscription.
    Data {
        subscription_id: String,
        data: serde_json::Value,
    },
    /// Incremental change for a subscription.
    DeltaUpdate {
        subscription_id: String,
        delta: Delta<serde_json::Value>,
    },
    /// Job state push for the client subscription `client_sub_id`.
    JobUpdate {
        client_sub_id: String,
        job: JobData,
    },
    /// Workflow state push for the client subscription `client_sub_id`.
    WorkflowUpdate {
        client_sub_id: String,
        workflow: WorkflowData,
    },
    /// Error not tied to a specific request.
    Error {
        code: String,
        message: String,
    },
    /// Error tied to the request identified by `id`.
    ErrorWithId {
        id: String,
        code: String,
        message: String,
    },
    /// Authentication was accepted.
    AuthSuccess,
    /// Authentication was rejected, with a reason.
    AuthFailed {
        reason: String,
    },
    /// Best-effort notice sent to a session just before it is evicted for falling behind.
    Lagging,
}
/// Per-session bookkeeping held by the session server.
struct SessionEntry {
    /// Bounded outbound channel for messages addressed to this session.
    sender: mpsc::Sender<RealtimeMessage>,
    /// Subscriptions owned by this session; each one is mirrored in the server's reverse index.
    subscriptions: Vec<SubscriptionId>,
    /// When the connection was registered.
    connected_at: chrono::DateTime<chrono::Utc>,
    /// Timestamp compared against the idle cutoff when stale sessions are cleaned up.
    last_active: chrono::DateTime<chrono::Utc>,
    /// Non-blocking sends dropped in a row because the buffer was full; reset on success.
    consecutive_drops: AtomicU32,
}
/// Number of consecutive full-buffer drops a session may accumulate; the drop after that
/// triggers eviction.
const MAX_CONSECUTIVE_DROPS: u32 = 10;
/// Node-local registry of realtime sessions and the subscriptions each one owns.
pub struct SessionServer {
    /// Limits enforced on every session.
    config: RealtimeConfig,
    /// Cluster node this server runs on.
    node_id: NodeId,
    /// Live connections keyed by session id.
    connections: DashMap<SessionId, SessionEntry>,
    /// Reverse index from subscription id to its owning session, used to route deltas.
    subscription_sessions: DashMap<SubscriptionId, SessionId>,
}
impl SessionServer {
    /// Creates an empty session server for `node_id` enforcing the limits in `config`.
    pub fn new(node_id: NodeId, config: RealtimeConfig) -> Self {
        Self {
            config,
            node_id,
            connections: DashMap::new(),
            subscription_sessions: DashMap::new(),
        }
    }

    /// Returns the id of the cluster node this server runs on.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Returns the configuration this server was created with.
    pub fn config(&self) -> &RealtimeConfig {
        &self.config
    }

    /// Registers a connection for `session_id` whose outbound messages go to `sender`.
    ///
    /// If a connection was already registered under the same id it is replaced. The fresh
    /// entry starts with no subscriptions, so the mappings owned by the replaced entry are
    /// unregistered here; otherwise they would be orphaned (never cleaned up, and not counted
    /// against the new entry's subscription limit).
    pub fn register_connection(
        &self,
        session_id: SessionId,
        sender: mpsc::Sender<RealtimeMessage>,
    ) {
        let now = chrono::Utc::now();
        let entry = SessionEntry {
            sender,
            subscriptions: Vec::new(),
            connected_at: now,
            last_active: now,
            consecutive_drops: AtomicU32::new(0),
        };
        if let Some(previous) = self.connections.insert(session_id, entry) {
            tracing::debug!(
                ?session_id,
                dropped_subscriptions = previous.subscriptions.len(),
                "Replaced existing connection"
            );
            self.release_subscriptions(session_id, &previous.subscriptions);
        }
    }

    /// Removes the connection for `session_id` together with all of its subscription mappings.
    ///
    /// Returns the subscriptions the session owned, or `None` if it was not registered.
    pub fn remove_connection(&self, session_id: SessionId) -> Option<Vec<SubscriptionId>> {
        let (_, entry) = self.connections.remove(&session_id)?;
        self.release_subscriptions(session_id, &entry.subscriptions);
        Some(entry.subscriptions)
    }

    /// Registers `subscription_id` as owned by `session_id` and records client activity.
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Validation` if the session is not registered or already holds
    /// `max_subscriptions_per_session` subscriptions.
    pub fn add_subscription(
        &self,
        session_id: SessionId,
        subscription_id: SubscriptionId,
    ) -> forge_core::Result<()> {
        let mut conn = self
            .connections
            .get_mut(&session_id)
            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
        if conn.subscriptions.len() >= self.config.max_subscriptions_per_session {
            return Err(forge_core::ForgeError::Validation(format!(
                "Maximum subscriptions per session ({}) exceeded",
                self.config.max_subscriptions_per_session
            )));
        }
        conn.subscriptions.push(subscription_id);
        conn.last_active = chrono::Utc::now();
        // Publish the reverse mapping while the entry is still locked: a concurrent
        // `remove_connection` blocks on that lock until this insert is visible and therefore
        // cleans it up, instead of leaving it orphaned. Locks are only ever nested in the order
        // `connections` -> `subscription_sessions`, so this cannot deadlock.
        self.subscription_sessions.insert(subscription_id, session_id);
        Ok(())
    }

    /// Unregisters `subscription_id` from whichever session owns it; unknown ids are ignored.
    pub fn remove_subscription(&self, subscription_id: SubscriptionId) {
        if let Some((_, session_id)) = self.subscription_sessions.remove(&subscription_id)
            && let Some(mut conn) = self.connections.get_mut(&session_id)
        {
            conn.subscriptions.retain(|id| *id != subscription_id);
        }
    }

    /// Records client activity on `session_id`, postponing its removal by `cleanup_stale`.
    ///
    /// Returns `false` if the session is not registered.
    pub fn mark_active(&self, session_id: SessionId) -> bool {
        match self.connections.get_mut(&session_id) {
            Some(mut conn) => {
                conn.last_active = chrono::Utc::now();
                true
            }
            None => false,
        }
    }

    /// Sends `message` to `session_id` without waiting for buffer space.
    ///
    /// A successful send resets the session's drop counter. A full buffer counts as a drop;
    /// once more than `MAX_CONSECUTIVE_DROPS` drops happen in a row the session is evicted.
    /// A closed channel removes the session.
    ///
    /// # Errors
    ///
    /// Returns the [`SendError`] describing why the message was not delivered.
    pub fn try_send_to_session(
        &self,
        session_id: SessionId,
        message: RealtimeMessage,
    ) -> Result<(), SendError> {
        let conn = self
            .connections
            .get(&session_id)
            .ok_or(SendError::SessionNotFound)?;
        match conn.sender.try_send(message) {
            Ok(()) => {
                conn.consecutive_drops.store(0, Ordering::Relaxed);
                Ok(())
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                // `fetch_add` yields the count *before* this drop, so the session survives
                // MAX_CONSECUTIVE_DROPS drops and is evicted on the next one.
                let previous_drops = conn.consecutive_drops.fetch_add(1, Ordering::Relaxed);
                if previous_drops < MAX_CONSECUTIVE_DROPS {
                    return Err(SendError::Full);
                }
                // Best effort only: the buffer was just reported full, so this usually fails.
                let _ = conn.sender.try_send(RealtimeMessage::Lagging);
                // The read guard must be released before removing the same key.
                let sender = conn.sender.clone();
                drop(conn);
                self.evict_session(session_id, &sender);
                Err(SendError::Evicted)
            }
            Err(mpsc::error::TrySendError::Closed(_)) => {
                let sender = conn.sender.clone();
                drop(conn);
                self.remove_connection_if_same_channel(session_id, &sender);
                Err(SendError::Closed)
            }
        }
    }

    /// Sends `message` to `session_id`, waiting for buffer space if necessary.
    ///
    /// If the receiving side has gone away the session is removed, mirroring
    /// `try_send_to_session`.
    ///
    /// # Errors
    ///
    /// `ForgeError::Validation` if the session is not registered; `ForgeError::Internal` if the
    /// channel is closed.
    pub async fn send_to_session(
        &self,
        session_id: SessionId,
        message: RealtimeMessage,
    ) -> forge_core::Result<()> {
        // Clone the sender so no map guard is held across the `.await`.
        let sender = self
            .connections
            .get(&session_id)
            .map(|conn| conn.sender.clone())
            .ok_or_else(|| forge_core::ForgeError::Validation("Session not found".to_string()))?;
        if sender.send(message).await.is_err() {
            self.remove_connection_if_same_channel(session_id, &sender);
            return Err(forge_core::ForgeError::Internal(
                "Failed to send message".to_string(),
            ));
        }
        Ok(())
    }

    /// Forwards `delta` to the session owning `subscription_id`, waiting for buffer space.
    ///
    /// Subscriptions with no registered owner are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from `send_to_session`.
    pub async fn broadcast_delta(
        &self,
        subscription_id: SubscriptionId,
        delta: Delta<serde_json::Value>,
    ) -> forge_core::Result<()> {
        let session_id = self.subscription_sessions.get(&subscription_id).map(|r| *r);
        if let Some(session_id) = session_id {
            let message = RealtimeMessage::DeltaUpdate {
                subscription_id: subscription_id.to_string(),
                delta,
            };
            self.send_to_session(session_id, message).await?;
        }
        Ok(())
    }

    /// Logs and removes a session that kept its outbound buffer full for too long.
    fn evict_session(&self, session_id: SessionId, sender: &mpsc::Sender<RealtimeMessage>) {
        tracing::warn!(?session_id, "Evicting slow client");
        self.remove_connection_if_same_channel(session_id, sender);
    }

    /// Removes `session_id` only while it is still backed by `sender`'s channel.
    ///
    /// Used after the map guard has been released (e.g. across an `.await`), so a connection
    /// re-registered under the same id in the meantime is not torn down by mistake.
    fn remove_connection_if_same_channel(
        &self,
        session_id: SessionId,
        sender: &mpsc::Sender<RealtimeMessage>,
    ) {
        if let Some((_, entry)) = self
            .connections
            .remove_if(&session_id, |_, entry| entry.sender.same_channel(sender))
        {
            self.release_subscriptions(session_id, &entry.subscriptions);
        }
    }

    /// Drops reverse mappings for `subscriptions`, but only those still pointing at
    /// `session_id`, so a mapping since re-pointed at another session is left intact.
    fn release_subscriptions(&self, session_id: SessionId, subscriptions: &[SubscriptionId]) {
        for subscription_id in subscriptions {
            self.subscription_sessions
                .remove_if(subscription_id, |_, owner| *owner == session_id);
        }
    }

    /// Number of registered connections.
    pub fn connection_count(&self) -> usize {
        self.connections.len()
    }

    /// Number of entries in the subscription → session reverse index.
    pub fn subscription_count(&self) -> usize {
        self.subscription_sessions.len()
    }

    /// Point-in-time counters; subscriptions are summed over the per-session lists.
    pub fn stats(&self) -> SessionStats {
        let total_subscriptions: usize =
            self.connections.iter().map(|c| c.subscriptions.len()).sum();
        SessionStats {
            connections: self.connections.len(),
            subscriptions: total_subscriptions,
            node_id: self.node_id,
        }
    }

    /// Removes every session with no recorded activity within `max_idle`.
    ///
    /// Activity is recorded at registration, by `add_subscription` and by `mark_active`.
    /// A `max_idle` too large to subtract from the current time means nothing can be stale,
    /// so the call is a no-op instead of panicking on `DateTime` overflow.
    pub fn cleanup_stale(&self, max_idle: Duration) {
        let Some(cutoff) = chrono::Duration::from_std(max_idle)
            .ok()
            .and_then(|idle| chrono::Utc::now().checked_sub_signed(idle))
        else {
            return;
        };
        let stale: Vec<(SessionId, chrono::DateTime<chrono::Utc>)> = self
            .connections
            .iter()
            .filter(|entry| entry.last_active < cutoff)
            .map(|entry| (*entry.key(), entry.connected_at))
            .collect();
        if let Some(oldest_connected_at) = stale.iter().map(|(_, at)| *at).min() {
            tracing::debug!(
                count = stale.len(),
                oldest_connected_at = %oldest_connected_at,
                "Cleaning up stale connections"
            );
        }
        for (session_id, _) in stale {
            // Re-check under the entry lock: the session may have been marked active or
            // replaced by a new connection since the scan above.
            if let Some((_, entry)) = self
                .connections
                .remove_if(&session_id, |_, entry| entry.last_active < cutoff)
            {
                self.release_subscriptions(session_id, &entry.subscriptions);
            }
        }
    }
}
/// Reason a non-blocking send to a session did not deliver its message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No connection is registered under the session id.
    SessionNotFound,
    /// The session's outbound buffer is full; the message was dropped.
    Full,
    /// The session's channel is closed; the session has been removed.
    Closed,
    /// The session dropped too many messages in a row and was evicted.
    Evicted,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let reason = match self {
            Self::SessionNotFound => "session not found",
            Self::Full => "session send buffer is full",
            Self::Closed => "session channel is closed",
            Self::Evicted => "session was evicted for falling behind",
        };
        f.write_str(reason)
    }
}

impl std::error::Error for SendError {}
/// Point-in-time counters reported by the session server.
#[derive(Clone, Debug)]
pub struct SessionStats {
    /// Number of registered connections.
    pub connections: usize,
    /// Total subscriptions summed over all registered connections.
    pub subscriptions: usize,
    /// Cluster node that produced these counters.
    pub node_id: NodeId,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a server on a fresh node id and returns both.
    fn server_with(config: RealtimeConfig) -> (NodeId, SessionServer) {
        let node_id = NodeId::new();
        (node_id, SessionServer::new(node_id, config))
    }

    /// Registers a new session whose outbound buffer holds `capacity` messages.
    ///
    /// The receiver is returned so the caller keeps the channel open; dropping it would make
    /// sends report a closed channel.
    fn connect(
        server: &SessionServer,
        capacity: usize,
    ) -> (SessionId, mpsc::Receiver<RealtimeMessage>) {
        let session_id = SessionId::new();
        let (tx, rx) = mpsc::channel(capacity);
        server.register_connection(session_id, tx);
        (session_id, rx)
    }

    #[test]
    fn test_realtime_config_default() {
        assert_eq!(RealtimeConfig::default().max_subscriptions_per_session, 50);
    }

    #[test]
    fn test_session_server_creation() {
        let (node_id, server) = server_with(RealtimeConfig::default());
        assert_eq!(server.node_id(), node_id);
        assert_eq!(server.connection_count(), 0);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_connection() {
        let (_, server) = server_with(RealtimeConfig::default());
        let (session_id, _rx) = connect(&server, 100);
        assert_eq!(server.connection_count(), 1);
        assert!(server.remove_connection(session_id).is_some());
        assert_eq!(server.connection_count(), 0);
    }

    #[test]
    fn test_session_subscription() {
        let (_, server) = server_with(RealtimeConfig::default());
        let (session_id, _rx) = connect(&server, 100);
        let subscription_id = SubscriptionId::new();
        server.add_subscription(session_id, subscription_id).unwrap();
        assert_eq!(server.subscription_count(), 1);
        server.remove_subscription(subscription_id);
        assert_eq!(server.subscription_count(), 0);
    }

    #[test]
    fn test_session_subscription_limit() {
        let (_, server) = server_with(RealtimeConfig {
            max_subscriptions_per_session: 2,
        });
        let (session_id, _rx) = connect(&server, 100);
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .unwrap();
        }
        assert!(
            server
                .add_subscription(session_id, SubscriptionId::new())
                .is_err()
        );
    }

    #[test]
    fn test_try_send_backpressure() {
        let (_, server) = server_with(RealtimeConfig::default());
        let (session_id, _rx) = connect(&server, 1);
        assert!(
            server
                .try_send_to_session(session_id, RealtimeMessage::Ping)
                .is_ok()
        );
        assert!(matches!(
            server.try_send_to_session(session_id, RealtimeMessage::Ping),
            Err(SendError::Full)
        ));
    }

    #[test]
    fn test_session_stats() {
        let (node_id, server) = server_with(RealtimeConfig::default());
        let (session_id, _rx) = connect(&server, 100);
        for _ in 0..2 {
            server
                .add_subscription(session_id, SubscriptionId::new())
                .unwrap();
        }
        let stats = server.stats();
        assert_eq!(stats.connections, 1);
        assert_eq!(stats.subscriptions, 2);
        assert_eq!(stats.node_id, node_id);
    }
}