use anyhow::{anyhow, Result};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, RwLock};
use tokio::time::interval;
use tracing::{debug, info};
use crate::StreamEvent;
/// Registry for GraphQL subscriptions.
///
/// It tracks registered subscriptions and their event windows, and publishes
/// [`SubscriptionEvent`]s to listeners through a broadcast channel.
pub struct GraphQLSubscriptionManager {
    /// Registered subscriptions, keyed by `EnhancedSubscription::id`.
    subscriptions: Arc<RwLock<HashMap<String, EnhancedSubscription>>>,
    /// Subscription groups, presumably keyed by `SubscriptionGroup::id`.
    /// NOTE(review): no method in this file reads or writes this map — confirm it is
    /// used elsewhere.
    groups: Arc<RwLock<HashMap<String, SubscriptionGroup>>>,
    /// Per-subscription event windows, keyed by subscription id (not by window id).
    windows: Arc<RwLock<HashMap<String, SubscriptionWindow>>>,
    /// Sender side of the channel every `SubscriptionEvent` is published on;
    /// listeners obtain a receiver through `subscribe`.
    event_tx: broadcast::Sender<SubscriptionEvent>,
    /// Limits and feature switches supplied at construction.
    config: SubscriptionConfig,
    /// Aggregate counters handed out by `get_stats`.
    stats: Arc<RwLock<SubscriptionStats>>,
}
/// Limits and feature switches for a `GraphQLSubscriptionManager`.
#[derive(Clone, Debug)]
pub struct SubscriptionConfig {
    /// Upper bound on the number of registered subscriptions across all clients.
    pub max_subscriptions: usize,
    /// Upper bound on the number of subscriptions held by a single `client_id`.
    pub max_subscriptions_per_client: usize,
    /// Default window length.
    /// NOTE(review): not read by the manager code in this file.
    pub default_window_size: Duration,
    /// Whether subscriptions that carry a window spec get their events buffered.
    pub enable_windowing: bool,
    /// Whether subscription filters are evaluated at all.
    pub enable_advanced_filtering: bool,
    /// Period of the heartbeat task.
    pub heartbeat_interval: Duration,
    /// Idle time after which the cleanup task removes a subscription.
    pub subscription_timeout: Duration,
}

impl Default for SubscriptionConfig {
    /// Built-in defaults:
    /// - 10 000 subscriptions in total and 100 per client;
    /// - one-minute windows;
    /// - windowing and filtering enabled;
    /// - a 30-second heartbeat;
    /// - a five-minute idle timeout.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        SubscriptionConfig {
            max_subscriptions: 10_000,
            max_subscriptions_per_client: 100,
            default_window_size: one_minute,
            enable_windowing: true,
            enable_advanced_filtering: true,
            heartbeat_interval: one_minute / 2,
            subscription_timeout: one_minute * 5,
        }
    }
}
/// A registered subscription.
///
/// It holds the client's GraphQL request plus the filtering, windowing,
/// lifecycle and bookkeeping state the manager keeps for it.
#[derive(Debug, Clone)]
pub struct EnhancedSubscription {
    /// Key under which the subscription is stored in the manager.
    pub id: String,
    /// Owning client; used to enforce `max_subscriptions_per_client`.
    pub client_id: String,
    /// GraphQL subscription text.
    /// NOTE(review): not read by the manager code in this file.
    pub query: String,
    /// Variables for `query`.
    /// NOTE(review): not read by the manager code in this file.
    pub variables: HashMap<String, serde_json::Value>,
    /// Filters an event must pass before it reaches this subscription.
    pub filters: Vec<AdvancedFilter>,
    /// Optional window spec. When it is set and windowing is enabled, events are
    /// buffered into a window instead of being sent immediately.
    pub window: Option<WindowSpec>,
    /// Lifecycle state; only `Active` subscriptions receive events and heartbeats.
    pub state: SubscriptionState,
    /// Timestamps, tags, priority and grouping information.
    pub metadata: SubscriptionMetadata,
    /// Per-subscription delivery counters.
    pub stats: SubscriptionStatistics,
}
/// Lifecycle state of a subscription.
#[derive(Debug, Clone, PartialEq)]
pub enum SubscriptionState {
    /// Receives matching events and heartbeats.
    Active,
    /// Set by `pause_subscription`; skipped by event processing and heartbeats.
    Paused,
    /// Presumably the number of reconnect attempts so far and the time of the next
    /// retry — TODO confirm. Not set by the code in this file.
    Reconnecting {
        attempts: u32,
        next_retry: DateTime<Utc>,
    },
    /// Presumably rate-limited until `until` — TODO confirm. Not set in this file.
    Throttled { until: DateTime<Utc> },
    /// Presumably a final state recording why (`reason`) and when (`timestamp`)
    /// the subscription ended — TODO confirm. Not set in this file.
    Terminated {
        reason: String,
        timestamp: DateTime<Utc>,
    },
}
/// Declarative filter attached to a subscription or a group.
///
/// NOTE(review): `apply_filters` in this file evaluates only `TimeRange`. Every
/// other variant currently lets all events through.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AdvancedFilter {
    /// Passes while the current wall-clock time lies in `[start, end]`.
    ///
    /// Both bounds are inclusive, and a `None` bound is open. The check uses the
    /// processing time, not a timestamp carried by the event.
    TimeRange {
        start: Option<DateTime<Utc>>,
        end: Option<DateTime<Utc>>,
    },
    /// Intended to compare the event's `field` against `value` using `operator`.
    ValueFilter {
        field: String,
        operator: FilterOperator,
        value: serde_json::Value,
    },
    /// Intended to match the event's `field` against `pattern`.
    /// NOTE(review): the pattern syntax (glob vs. regex) is not defined here.
    PatternMatch {
        field: String,
        pattern: String,
        case_sensitive: bool,
    },
    /// Intended to pass events within `radius_km` of the given coordinates.
    GeoFilter {
        latitude: f64,
        longitude: f64,
        radius_km: f64,
    },
    /// Presumably matches triples by subject/predicate/object pattern, where
    /// `None` matches anything — TODO confirm.
    SemanticFilter {
        subject_pattern: Option<String>,
        predicate_pattern: Option<String>,
        object_pattern: Option<String>,
    },
    /// Intended to compare an aggregate computed with `function` against `threshold`.
    AggregationFilter {
        function: AggregationFunction,
        threshold: f64,
    },
    /// Combines nested filters with a logical operator.
    ///
    /// NOTE(review): `Vec<Box<_>>` boxes twice, and `Vec<AdvancedFilter>` would
    /// suffice. Changing it is a breaking API change, though.
    CompositeFilter {
        operator: LogicalOperator,
        filters: Vec<Box<AdvancedFilter>>,
    },
}
/// Comparison operator used by `AdvancedFilter::ValueFilter`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum FilterOperator {
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    Contains,
    StartsWith,
    EndsWith,
    In,
    NotIn,
}
/// Aggregate function used by `AdvancedFilter::AggregationFilter`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AggregationFunction {
    Count,
    Sum,
    Average,
    Min,
    Max,
    StdDev,
}
/// Logical combinator used by `AdvancedFilter::CompositeFilter`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum LogicalOperator {
    And,
    Or,
    Not,
    Xor,
}
/// Describes how a subscription's events are grouped into windows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowSpec {
    /// Windowing strategy.
    /// NOTE(review): not read by the manager code in this file.
    pub window_type: WindowType,
    /// Extent of one window.
    /// NOTE(review): not enforced by the manager code in this file.
    pub size: WindowSize,
    /// Step between consecutive windows; presumably meaningful only for
    /// `WindowType::Sliding` — TODO confirm.
    pub slide: Option<WindowSize>,
    /// Conditions checked each time an event is appended to the window.
    pub triggers: Vec<WindowTrigger>,
}
/// Windowing strategy.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WindowType {
    Tumbling,
    Sliding,
    Session,
    Global,
}
/// Extent of a window, measured by elapsed time, event count or payload bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WindowSize {
    Time(Duration),
    Count(usize),
    Bytes(usize),
}
/// Condition that makes a window "fire".
///
/// NOTE(review): `check_window_triggers` evaluates only `TimeInterval` and
/// `EventCount`, and only logs when they fire. The other variants are not
/// evaluated in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WindowTrigger {
    /// Fires once strictly more than this duration has elapsed since the window started.
    TimeInterval(Duration),
    /// Fires once the window has received at least this many events.
    EventCount(usize),
    Watermark,
    EventType(String),
    Custom(String),
}
/// Descriptive and bookkeeping data attached to a subscription.
#[derive(Debug, Clone)]
pub struct SubscriptionMetadata {
    /// Creation time, supplied by the caller; not modified in this file.
    pub created_at: DateTime<Utc>,
    /// Last time the subscription was resumed or sent an update.
    ///
    /// The cleanup task removes subscriptions that have been idle for at least
    /// `subscription_timeout`.
    pub last_activity: DateTime<Utc>,
    /// Free-form labels.
    pub tags: Vec<String>,
    /// Delivery priority.
    /// NOTE(review): not used for scheduling in this file.
    pub priority: SubscriptionPriority,
    /// Optional namespace.
    /// NOTE(review): not interpreted in this file.
    pub namespace: Option<String>,
    /// Presumably ids of the `SubscriptionGroup`s this subscription belongs to —
    /// TODO confirm.
    pub groups: Vec<String>,
}
/// Subscription priority.
///
/// The derived `Ord` orders variants by discriminant:
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SubscriptionPriority {
    Low = 1,
    Normal = 2,
    High = 3,
    Critical = 4,
}
/// Per-subscription delivery counters.
///
/// NOTE(review): only `updates_sent` is incremented by the code in this file
/// (in `send_update`). The remaining fields keep their default values here.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionStatistics {
    pub events_received: u64,
    pub updates_sent: u64,
    pub bytes_sent: u64,
    pub avg_latency_ms: f64,
    pub max_latency_ms: f64,
    pub error_count: u64,
    pub last_error: Option<String>,
}
/// A named set of subscriptions sharing filters and group-level settings.
///
/// NOTE(review): the manager has a `groups` map, but no method in this file
/// creates, reads or applies groups.
#[derive(Debug, Clone)]
pub struct SubscriptionGroup {
    pub id: String,
    /// Human-readable group name.
    pub name: String,
    /// Presumably the ids of member subscriptions — TODO confirm.
    pub members: HashSet<String>,
    /// Filters shared by the group (not evaluated in this file).
    pub filters: Vec<AdvancedFilter>,
    /// Group-level settings.
    pub config: GroupConfig,
}
/// Group-level settings.
///
/// NOTE(review): none of these fields are interpreted in this file. The
/// descriptions below are inferred from the names.
#[derive(Debug, Clone)]
pub struct GroupConfig {
    /// Presumably: members share one window instead of one each.
    pub shared_windowing: bool,
    /// Presumably: distribute events across members instead of broadcasting.
    pub load_balancing: bool,
    /// Presumably an upper bound on `members.len()` (not enforced here).
    pub max_members: usize,
}
/// Runtime buffer of events collected for one subscription's window.
///
/// `Debug` is derived because every field type already implements it. Every
/// sibling public type is debuggable, and this one previously was not.
#[derive(Debug)]
pub struct SubscriptionWindow {
    /// Random (v4) UUID assigned when the window is created.
    pub id: String,
    /// Owning subscription; the manager also stores the window under this key.
    pub subscription_id: String,
    /// Window spec cloned from the subscription at registration.
    pub spec: WindowSpec,
    /// Buffered events in arrival order (appended at the back).
    /// NOTE(review): the buffer is never drained or trimmed in this file.
    pub buffer: VecDeque<WindowedEvent>,
    /// Counters and lifecycle flags for the window.
    pub state: WindowState,
}
/// An event captured into a window, together with capture metadata.
#[derive(Debug, Clone)]
pub struct WindowedEvent {
    /// The buffered event (cloned from the one being processed).
    pub event: StreamEvent,
    /// Wall-clock time at which the event was buffered (not the event's own time).
    pub timestamp: DateTime<Utc>,
    /// Zero-based position of the event within its window.
    pub sequence_id: u64,
}
/// Counters and lifecycle flags for a `SubscriptionWindow`.
#[derive(Debug, Clone)]
pub struct WindowState {
    /// Creation time of the window; `TimeInterval` triggers measure from here.
    pub start_time: DateTime<Utc>,
    /// When the window closed; `None` while it is open.
    /// NOTE(review): never set in this file.
    pub end_time: Option<DateTime<Utc>>,
    /// Number of events appended so far.
    pub event_count: usize,
    /// Size of the buffered payload.
    /// NOTE(review): not updated in this file, so it stays 0.
    pub total_bytes: usize,
    /// Whether the window has closed.
    /// NOTE(review): never set to `true` in this file.
    pub is_closed: bool,
}
/// Notification published on the manager's broadcast channel.
#[derive(Debug, Clone)]
pub enum SubscriptionEvent {
    /// Payload for `subscription_id`, built by `convert_event_to_graphql`.
    Update {
        subscription_id: String,
        data: serde_json::Value,
        timestamp: DateTime<Utc>,
    },
    /// Lifecycle transition, emitted by `pause_subscription` and `resume_subscription`.
    StateChanged {
        subscription_id: String,
        old_state: SubscriptionState,
        new_state: SubscriptionState,
    },
    /// Periodic keep-alive sent for each `Active` subscription.
    Heartbeat {
        subscription_id: String,
        timestamp: DateTime<Utc>,
    },
    /// Error report for a subscription.
    /// NOTE(review): not emitted by the code in this file.
    Error {
        subscription_id: String,
        error: String,
        timestamp: DateTime<Utc>,
    },
}
/// Aggregate counters for the manager, as returned by `get_stats`.
///
/// NOTE(review): the per-state counters are snapshots. They are recomputed only
/// at the points where the manager explicitly refreshes them, so check which
/// methods do so before relying on them being current.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionStats {
    /// Number of registered subscriptions.
    pub total_subscriptions: usize,
    /// Number of registered subscriptions in the `Active` state.
    pub active_subscriptions: usize,
    /// Number of registered subscriptions in the `Paused` state.
    pub paused_subscriptions: usize,
    /// Number of registered subscriptions in the `Reconnecting` state.
    pub reconnecting_subscriptions: usize,
    /// Number of `process_event` calls that completed without error.
    pub total_events_processed: u64,
    /// NOTE(review): not updated in this file.
    pub total_updates_sent: u64,
    /// NOTE(review): not updated in this file.
    pub avg_processing_time_ms: f64,
}
impl GraphQLSubscriptionManager {
pub fn new(config: SubscriptionConfig) -> Self {
let (event_tx, _) = broadcast::channel(10000);
let manager = Self {
subscriptions: Arc::new(RwLock::new(HashMap::new())),
groups: Arc::new(RwLock::new(HashMap::new())),
windows: Arc::new(RwLock::new(HashMap::new())),
event_tx,
config,
stats: Arc::new(RwLock::new(SubscriptionStats::default())),
};
manager.start_heartbeat_task();
manager.start_cleanup_task();
manager
}
pub async fn register_subscription(
&self,
subscription: EnhancedSubscription,
) -> Result<String> {
let mut subscriptions = self.subscriptions.write().await;
if subscriptions.len() >= self.config.max_subscriptions {
return Err(anyhow!("Maximum subscriptions limit reached"));
}
let client_count = subscriptions
.values()
.filter(|s| s.client_id == subscription.client_id)
.count();
if client_count >= self.config.max_subscriptions_per_client {
return Err(anyhow!("Client subscription limit reached"));
}
let id = subscription.id.clone();
if self.config.enable_windowing {
if let Some(window_spec) = &subscription.window {
self.create_window(&id, window_spec.clone()).await?;
}
}
subscriptions.insert(id.clone(), subscription);
let mut stats = self.stats.write().await;
stats.total_subscriptions = subscriptions.len();
stats.active_subscriptions = subscriptions
.values()
.filter(|s| s.state == SubscriptionState::Active)
.count();
info!("Registered GraphQL subscription: {}", id);
Ok(id)
}
pub async fn unregister_subscription(&self, subscription_id: &str) -> Result<()> {
let mut subscriptions = self.subscriptions.write().await;
subscriptions
.remove(subscription_id)
.ok_or_else(|| anyhow!("Subscription not found"))?;
self.windows.write().await.remove(subscription_id);
let mut stats = self.stats.write().await;
stats.total_subscriptions = subscriptions.len();
stats.active_subscriptions = subscriptions
.values()
.filter(|s| s.state == SubscriptionState::Active)
.count();
info!("Unregistered GraphQL subscription: {}", subscription_id);
Ok(())
}
pub async fn pause_subscription(&self, subscription_id: &str) -> Result<()> {
let mut subscriptions = self.subscriptions.write().await;
let subscription = subscriptions
.get_mut(subscription_id)
.ok_or_else(|| anyhow!("Subscription not found"))?;
let old_state = subscription.state.clone();
subscription.state = SubscriptionState::Paused;
let _ = self.event_tx.send(SubscriptionEvent::StateChanged {
subscription_id: subscription_id.to_string(),
old_state,
new_state: SubscriptionState::Paused,
});
info!("Paused subscription: {}", subscription_id);
Ok(())
}
pub async fn resume_subscription(&self, subscription_id: &str) -> Result<()> {
let mut subscriptions = self.subscriptions.write().await;
let subscription = subscriptions
.get_mut(subscription_id)
.ok_or_else(|| anyhow!("Subscription not found"))?;
let old_state = subscription.state.clone();
subscription.state = SubscriptionState::Active;
subscription.metadata.last_activity = Utc::now();
let _ = self.event_tx.send(SubscriptionEvent::StateChanged {
subscription_id: subscription_id.to_string(),
old_state,
new_state: SubscriptionState::Active,
});
info!("Resumed subscription: {}", subscription_id);
Ok(())
}
pub async fn process_event(&self, event: &StreamEvent) -> Result<()> {
let subscriptions = self.subscriptions.read().await;
for (sub_id, subscription) in subscriptions.iter() {
if subscription.state != SubscriptionState::Active {
continue;
}
if !self.apply_filters(event, &subscription.filters).await? {
continue;
}
if self.config.enable_windowing && subscription.window.is_some() {
self.add_to_window(sub_id, event).await?;
} else {
self.send_update(sub_id, event).await?;
}
}
let mut stats = self.stats.write().await;
stats.total_events_processed += 1;
Ok(())
}
async fn apply_filters(
&self,
_event: &StreamEvent,
filters: &[AdvancedFilter],
) -> Result<bool> {
if !self.config.enable_advanced_filtering || filters.is_empty() {
return Ok(true);
}
for filter in filters {
match filter {
AdvancedFilter::TimeRange { start, end } => {
let now = Utc::now();
if let Some(start) = start {
if &now < start {
return Ok(false);
}
}
if let Some(end) = end {
if &now > end {
return Ok(false);
}
}
}
_ => {
}
}
}
Ok(true)
}
async fn create_window(&self, subscription_id: &str, spec: WindowSpec) -> Result<()> {
let window = SubscriptionWindow {
id: uuid::Uuid::new_v4().to_string(),
subscription_id: subscription_id.to_string(),
spec,
buffer: VecDeque::new(),
state: WindowState {
start_time: Utc::now(),
end_time: None,
event_count: 0,
total_bytes: 0,
is_closed: false,
},
};
self.windows
.write()
.await
.insert(subscription_id.to_string(), window);
Ok(())
}
async fn add_to_window(&self, subscription_id: &str, event: &StreamEvent) -> Result<()> {
let mut windows = self.windows.write().await;
if let Some(window) = windows.get_mut(subscription_id) {
let windowed_event = WindowedEvent {
event: event.clone(),
timestamp: Utc::now(),
sequence_id: window.state.event_count as u64,
};
window.buffer.push_back(windowed_event);
window.state.event_count += 1;
self.check_window_triggers(window).await?;
}
Ok(())
}
async fn check_window_triggers(&self, window: &mut SubscriptionWindow) -> Result<()> {
for trigger in &window.spec.triggers {
match trigger {
WindowTrigger::EventCount(count) if window.state.event_count >= *count => {
debug!("Window trigger fired: event count {}", count);
}
WindowTrigger::TimeInterval(duration) => {
let elapsed = Utc::now() - window.state.start_time;
if elapsed > ChronoDuration::from_std(*duration)? {
debug!("Window trigger fired: time interval {:?}", duration);
}
}
_ => {}
}
}
Ok(())
}
async fn send_update(&self, subscription_id: &str, event: &StreamEvent) -> Result<()> {
let data = self.convert_event_to_graphql(event)?;
let _ = self.event_tx.send(SubscriptionEvent::Update {
subscription_id: subscription_id.to_string(),
data,
timestamp: Utc::now(),
});
let mut subscriptions = self.subscriptions.write().await;
if let Some(subscription) = subscriptions.get_mut(subscription_id) {
subscription.stats.updates_sent += 1;
subscription.metadata.last_activity = Utc::now();
}
Ok(())
}
fn convert_event_to_graphql(&self, event: &StreamEvent) -> Result<serde_json::Value> {
Ok(serde_json::json!({
"type": format!("{:?}", event),
"timestamp": Utc::now().to_rfc3339(),
}))
}
fn start_heartbeat_task(&self) {
let subscriptions = self.subscriptions.clone();
let event_tx = self.event_tx.clone();
let interval_duration = self.config.heartbeat_interval;
tokio::spawn(async move {
let mut interval_timer = interval(interval_duration);
loop {
interval_timer.tick().await;
let subs = subscriptions.read().await;
for (sub_id, subscription) in subs.iter() {
if subscription.state == SubscriptionState::Active {
let _ = event_tx.send(SubscriptionEvent::Heartbeat {
subscription_id: sub_id.clone(),
timestamp: Utc::now(),
});
}
}
}
});
}
fn start_cleanup_task(&self) {
let subscriptions = self.subscriptions.clone();
let timeout = self.config.subscription_timeout;
tokio::spawn(async move {
let mut interval_timer = interval(Duration::from_secs(60));
loop {
interval_timer.tick().await;
let mut subs = subscriptions.write().await;
let now = Utc::now();
subs.retain(|_, subscription| {
let inactive_duration = now - subscription.metadata.last_activity;
inactive_duration
< ChronoDuration::from_std(timeout)
.expect("timeout should be valid chrono Duration")
});
}
});
}
pub async fn get_stats(&self) -> SubscriptionStats {
self.stats.read().await.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
self.event_tx.subscribe()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // None of these tests await anything, so plain `#[test]` is enough and no
    // Tokio runtime has to be built per test.

    #[test]
    fn test_subscription_config_defaults() {
        let config = SubscriptionConfig::default();
        assert_eq!(config.max_subscriptions, 10000);
        assert_eq!(config.max_subscriptions_per_client, 100);
        assert!(config.enable_windowing);
        assert!(config.enable_advanced_filtering);
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.subscription_timeout, Duration::from_secs(300));
    }

    #[test]
    fn test_subscription_states() {
        assert_eq!(SubscriptionState::Active, SubscriptionState::Active);
        assert_eq!(SubscriptionState::Paused, SubscriptionState::Paused);
        assert_ne!(SubscriptionState::Active, SubscriptionState::Paused);
    }

    #[test]
    fn test_filter_operators() {
        assert_eq!(FilterOperator::Equal, FilterOperator::Equal);
        assert_ne!(FilterOperator::Equal, FilterOperator::NotEqual);
    }

    #[test]
    fn test_window_types() {
        let window = WindowSpec {
            window_type: WindowType::Tumbling,
            size: WindowSize::Time(Duration::from_secs(60)),
            slide: None,
            triggers: vec![WindowTrigger::EventCount(100)],
        };
        assert_eq!(window.window_type, WindowType::Tumbling);
        assert!(window.slide.is_none());
        assert!(matches!(window.triggers[..], [WindowTrigger::EventCount(100)]));
    }

    #[test]
    fn test_subscription_priority() {
        assert!(SubscriptionPriority::Critical > SubscriptionPriority::High);
        assert!(SubscriptionPriority::High > SubscriptionPriority::Normal);
        assert!(SubscriptionPriority::Normal > SubscriptionPriority::Low);
    }
}