use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, RwLock};
use tokio::time::interval;
use tracing::{debug, info, warn};
use crate::StreamEvent;
/// Tuning knobs for the stream-event → GraphQL-subscription bridge.
#[derive(Debug, Clone)]
pub struct BridgeConfig {
/// Capacity of the broadcast channel used to fan updates out to receivers.
pub max_queue_size: usize,
/// Minimum interval between two updates with the same debounce key.
pub debounce_duration: Duration,
/// When true, updates are buffered and flushed in batches instead of
/// being broadcast one by one.
pub enable_batching: bool,
/// A batch is flushed early once it reaches this many buffered updates.
pub max_batch_size: usize,
/// Period of the background task that flushes partially filled batches.
pub batch_interval: Duration,
/// When true, updates are routed only to subscriptions whose filters
/// match the payload; when false, every subscription is considered relevant.
pub enable_query_filtering: bool,
/// Hard cap on concurrently registered subscriptions.
pub max_subscriptions: usize,
}
impl Default for BridgeConfig {
fn default() -> Self {
Self {
max_queue_size: 10000,
debounce_duration: Duration::from_millis(100),
enable_batching: true,
max_batch_size: 100,
batch_interval: Duration::from_millis(500),
enable_query_filtering: true,
max_subscriptions: 1000,
}
}
}
/// Bridges internal `StreamEvent`s to GraphQL subscription updates, with
/// optional debouncing, batching, and per-subscription filtering.
pub struct GraphQLBridge {
/// Static configuration, fixed at construction time.
config: BridgeConfig,
/// Registered subscriptions keyed by subscription id.
subscriptions: Arc<RwLock<HashMap<String, GraphQLSubscription>>>,
/// Broadcast channel fanning updates out to all `subscribe()` receivers.
event_sender: broadcast::Sender<GraphQLUpdate>,
/// Running counters; see `BridgeStats`.
stats: Arc<RwLock<BridgeStats>>,
/// Last-seen timestamps consulted by the debounce logic.
debounce_tracker: Arc<RwLock<HashMap<String, Instant>>>,
/// Pending updates awaiting the next batch flush (batching mode only).
batch_buffer: Arc<RwLock<Vec<GraphQLUpdate>>>,
}
/// A registered GraphQL subscription and its bookkeeping state.
#[derive(Debug, Clone)]
pub struct GraphQLSubscription {
/// Unique subscription id (a UUID string when built via
/// `create_simple_subscription`).
pub id: String,
/// The GraphQL subscription document text.
pub query: String,
/// Query variables, if any.
pub variables: HashMap<String, serde_json::Value>,
/// Filters deciding which updates are relevant; an empty list matches all.
pub filters: Vec<SubscriptionFilter>,
/// When the subscription was registered.
pub created_at: DateTime<Utc>,
/// Timestamp of the most recent update delivered, if any.
pub last_update: Option<DateTime<Utc>>,
/// Number of updates delivered so far.
pub update_count: u64,
}
/// Criteria restricting which updates a subscription receives.
///
/// Pattern variants support a `*` wildcard (see `pattern_matches`);
/// `GraphFilter` is an exact string comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubscriptionFilter {
/// Wildcard pattern matched against the update's `subject` field.
SubjectPattern(String),
/// Wildcard pattern matched against the update's `predicate` field.
PredicatePattern(String),
/// Wildcard pattern matched against the update's `object` field.
ObjectPattern(String),
/// Exact graph URI comparison against the update's `graph` field.
GraphFilter(String),
/// Free-form filter expression. NOTE(review): currently not evaluated —
/// it accepts every update (see `subscription_matches_data`).
CustomFilter(String),
}
/// A single update delivered to GraphQL subscription receivers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphQLUpdate {
/// Unique id of this update (fresh UUID per update).
pub id: String,
/// When the update was created by the bridge.
pub timestamp: DateTime<Utc>,
/// Broad category of the change.
pub update_type: GraphQLUpdateType,
/// JSON payload describing the change (triple/quad fields, query result, …).
pub data: serde_json::Value,
/// Ids of the subscriptions this update is relevant to.
pub subscriptions: Vec<String>,
}
/// Broad classification of a `GraphQLUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GraphQLUpdateType {
/// A triple/quad was added to the store.
DataAdded,
/// A triple/quad was removed from the store.
DataRemoved,
/// Existing data was modified.
DataModified,
/// Catch-all for stream events that have no dedicated mapping.
BulkUpdate,
/// A cached query result changed (added or removed bindings).
QueryResultChanged,
}
/// Running counters describing bridge activity.
#[derive(Debug, Clone, Default)]
pub struct BridgeStats {
/// Stream events fully processed (debounced events are not counted).
pub events_processed: u64,
/// Updates actually delivered to at least the broadcast channel.
pub updates_sent: u64,
/// Number of batch flushes performed.
pub updates_batched: u64,
/// Updates suppressed by debouncing.
pub updates_debounced: u64,
/// Current number of registered subscriptions.
pub active_subscriptions: usize,
/// Running mean of per-event processing time, in milliseconds.
pub avg_processing_time_ms: f64,
}
impl GraphQLBridge {
    /// Creates a bridge around a fresh broadcast channel of capacity
    /// `config.max_queue_size` and, when `config.enable_batching` is set,
    /// spawns the periodic batch-flush task.
    ///
    /// NOTE(review): `start_batch_processor` calls `tokio::spawn`, so with
    /// batching enabled `new` must be invoked from inside a Tokio runtime.
    pub fn new(config: BridgeConfig) -> Self {
        let (event_sender, _) = broadcast::channel(config.max_queue_size);
        let bridge = Self {
            config,
            subscriptions: Arc::new(RwLock::new(HashMap::new())),
            event_sender,
            stats: Arc::new(RwLock::new(BridgeStats::default())),
            debounce_tracker: Arc::new(RwLock::new(HashMap::new())),
            batch_buffer: Arc::new(RwLock::new(Vec::new())),
        };
        if bridge.config.enable_batching {
            bridge.start_batch_processor();
        }
        bridge
    }

    /// Registers a subscription and returns its id.
    ///
    /// # Errors
    /// Fails when `config.max_subscriptions` has already been reached.
    pub async fn register_subscription(&self, subscription: GraphQLSubscription) -> Result<String> {
        let mut subscriptions = self.subscriptions.write().await;
        if subscriptions.len() >= self.config.max_subscriptions {
            return Err(anyhow!("Maximum subscriptions limit reached"));
        }
        let id = subscription.id.clone();
        subscriptions.insert(id.clone(), subscription);
        self.stats.write().await.active_subscriptions = subscriptions.len();
        info!("Registered GraphQL subscription: {}", id);
        Ok(id)
    }

    /// Removes a subscription by id.
    ///
    /// # Errors
    /// Fails when no subscription with that id exists.
    pub async fn unregister_subscription(&self, subscription_id: &str) -> Result<()> {
        let mut subscriptions = self.subscriptions.write().await;
        subscriptions
            .remove(subscription_id)
            .ok_or_else(|| anyhow!("Subscription not found"))?;
        self.stats.write().await.active_subscriptions = subscriptions.len();
        info!("Unregistered GraphQL subscription: {}", subscription_id);
        Ok(())
    }

    /// Converts a stream event into a `GraphQLUpdate` and routes it through
    /// debouncing and (optionally) batching before broadcast.
    pub async fn process_stream_event(&self, event: &StreamEvent) -> Result<()> {
        let start_time = Instant::now();
        let update = self.convert_stream_event_to_update(event).await?;
        if self.should_debounce(&update).await {
            self.stats.write().await.updates_debounced += 1;
            return Ok(());
        }
        self.update_debounce_tracker(&update).await;
        if self.config.enable_batching {
            self.add_to_batch(update).await?;
        } else {
            self.send_update(update).await?;
        }
        let mut stats = self.stats.write().await;
        stats.events_processed += 1;
        let processing_time = start_time.elapsed().as_millis() as f64;
        // True incremental running mean over all processed events. (The old
        // formula `(avg + t) / 2` weighted the most recent sample at 50%.)
        let n = stats.events_processed as f64;
        stats.avg_processing_time_ms += (processing_time - stats.avg_processing_time_ms) / n;
        Ok(())
    }

    /// Maps a `StreamEvent` to a `GraphQLUpdate`, tagging it with the ids of
    /// all subscriptions whose filters match the payload. Event kinds without
    /// a dedicated mapping collapse into a generic `BulkUpdate`.
    async fn convert_stream_event_to_update(&self, event: &StreamEvent) -> Result<GraphQLUpdate> {
        let (update_type, data) = match event {
            StreamEvent::TripleAdded {
                subject,
                predicate,
                object,
                graph,
                metadata,
            } => (
                GraphQLUpdateType::DataAdded,
                serde_json::json!({
                    "subject": subject,
                    "predicate": predicate,
                    "object": object,
                    "graph": graph,
                    "timestamp": metadata.timestamp,
                }),
            ),
            StreamEvent::TripleRemoved {
                subject,
                predicate,
                object,
                graph,
                metadata,
            } => (
                GraphQLUpdateType::DataRemoved,
                serde_json::json!({
                    "subject": subject,
                    "predicate": predicate,
                    "object": object,
                    "graph": graph,
                    "timestamp": metadata.timestamp,
                }),
            ),
            StreamEvent::QuadAdded {
                subject,
                predicate,
                object,
                graph,
                metadata,
            } => (
                GraphQLUpdateType::DataAdded,
                serde_json::json!({
                    "subject": subject,
                    "predicate": predicate,
                    "object": object,
                    "graph": graph,
                    "timestamp": metadata.timestamp,
                }),
            ),
            StreamEvent::QuadRemoved {
                subject,
                predicate,
                object,
                graph,
                metadata,
            } => (
                GraphQLUpdateType::DataRemoved,
                serde_json::json!({
                    "subject": subject,
                    "predicate": predicate,
                    "object": object,
                    "graph": graph,
                    "timestamp": metadata.timestamp,
                }),
            ),
            StreamEvent::QueryResultAdded {
                query_id,
                result,
                metadata,
            } => (
                GraphQLUpdateType::QueryResultChanged,
                serde_json::json!({
                    "query_id": query_id,
                    "result": result.bindings,
                    "execution_time": result.execution_time.as_millis(),
                    "timestamp": metadata.timestamp,
                }),
            ),
            StreamEvent::QueryResultRemoved {
                query_id,
                result,
                metadata,
            } => (
                GraphQLUpdateType::QueryResultChanged,
                serde_json::json!({
                    "query_id": query_id,
                    "result": result.bindings,
                    "execution_time": result.execution_time.as_millis(),
                    "timestamp": metadata.timestamp,
                }),
            ),
            _ => (
                GraphQLUpdateType::BulkUpdate,
                serde_json::json!({
                    "message": "Bulk update occurred",
                    "timestamp": Utc::now(),
                }),
            ),
        };
        let relevant_subscriptions = self.find_relevant_subscriptions(&data).await;
        Ok(GraphQLUpdate {
            id: uuid::Uuid::new_v4().to_string(),
            timestamp: Utc::now(),
            update_type,
            data,
            subscriptions: relevant_subscriptions,
        })
    }

    /// Returns the ids of subscriptions whose filters match `data`. With
    /// query filtering disabled, every subscription is considered relevant.
    async fn find_relevant_subscriptions(&self, data: &serde_json::Value) -> Vec<String> {
        let subscriptions = self.subscriptions.read().await;
        if !self.config.enable_query_filtering {
            return subscriptions.keys().cloned().collect();
        }
        subscriptions
            .iter()
            .filter(|(_, sub)| self.subscription_matches_data(sub, data))
            .map(|(id, _)| id.clone())
            .collect()
    }

    /// Tests a subscription's filters against an update payload.
    ///
    /// Filters combine with OR semantics: the first matching filter accepts
    /// the update. An empty filter list matches everything.
    fn subscription_matches_data(
        &self,
        subscription: &GraphQLSubscription,
        data: &serde_json::Value,
    ) -> bool {
        if subscription.filters.is_empty() {
            return true;
        }
        for filter in &subscription.filters {
            let matched = match filter {
                SubscriptionFilter::SubjectPattern(pattern) => data
                    .get("subject")
                    .and_then(|v| v.as_str())
                    .map_or(false, |subject| self.pattern_matches(pattern, subject)),
                SubscriptionFilter::PredicatePattern(pattern) => data
                    .get("predicate")
                    .and_then(|v| v.as_str())
                    .map_or(false, |predicate| self.pattern_matches(pattern, predicate)),
                SubscriptionFilter::ObjectPattern(pattern) => data
                    .get("object")
                    .and_then(|v| v.as_str())
                    .map_or(false, |object| self.pattern_matches(pattern, object)),
                SubscriptionFilter::GraphFilter(graph_uri) => data
                    .get("graph")
                    .and_then(|v| v.as_str())
                    .map_or(false, |graph| graph == graph_uri),
                // TODO(review): custom filter expressions are not evaluated
                // yet; they currently accept every update.
                SubscriptionFilter::CustomFilter(_expr) => true,
            };
            if matched {
                return true;
            }
        }
        false
    }

    /// Glob-style matching: `*` matches any run of characters; any pattern
    /// without `*` must equal `value` exactly.
    ///
    /// The previous version contained a mojibake compile error
    /// (`®ex_pattern`), did not escape regex metacharacters (so `.` in URIs
    /// matched any character), and left the regex unanchored (so `foo*`
    /// matched `barfoo`). The glob is now escaped and anchored.
    fn pattern_matches(&self, pattern: &str, value: &str) -> bool {
        if pattern == "*" {
            return true;
        }
        if pattern.contains('*') {
            let regex_pattern = format!("^{}$", regex::escape(pattern).replace(r"\*", ".*"));
            if let Ok(regex) = regex::Regex::new(&regex_pattern) {
                return regex.is_match(value);
            }
        }
        pattern == value
    }

    /// Builds a stable debounce key from an update's payload.
    ///
    /// The update `id` is a fresh UUID per update, so keying the tracker by
    /// `id` (as the original code did) could never coalesce anything and
    /// leaked one tracker entry per event. Instead we key by the logical
    /// identity of the change, falling back to the serialized payload for
    /// updates without triple fields.
    fn debounce_key(update: &GraphQLUpdate) -> String {
        let field = |k: &str| update.data.get(k).and_then(|v| v.as_str()).unwrap_or("");
        let key = format!(
            "{}|{}|{}|{}",
            field("subject"),
            field("predicate"),
            field("object"),
            field("graph")
        );
        if key == "|||" {
            update.data.to_string()
        } else {
            key
        }
    }

    /// True when an update with the same debounce key was forwarded less
    /// than `config.debounce_duration` ago.
    async fn should_debounce(&self, update: &GraphQLUpdate) -> bool {
        let tracker = self.debounce_tracker.read().await;
        match tracker.get(&Self::debounce_key(update)) {
            Some(last_update) => last_update.elapsed() < self.config.debounce_duration,
            None => false,
        }
    }

    /// Records the current instant for the update's debounce key,
    /// opportunistically pruning stale entries so the tracker stays bounded.
    async fn update_debounce_tracker(&self, update: &GraphQLUpdate) {
        let mut tracker = self.debounce_tracker.write().await;
        if tracker.len() >= 10_000 {
            let window = self.config.debounce_duration;
            tracker.retain(|_, last| last.elapsed() < window);
        }
        tracker.insert(Self::debounce_key(update), Instant::now());
    }

    /// Buffers an update for the next batch flush, flushing immediately once
    /// the buffer reaches `config.max_batch_size`.
    async fn add_to_batch(&self, update: GraphQLUpdate) -> Result<()> {
        let mut buffer = self.batch_buffer.write().await;
        buffer.push(update);
        if buffer.len() >= self.config.max_batch_size {
            let updates = std::mem::take(&mut *buffer);
            // Release the buffer lock before sending so producers aren't
            // blocked while the batch is broadcast.
            drop(buffer);
            self.send_batch(updates).await?;
        }
        Ok(())
    }

    /// Broadcasts a single update. A send error only means there are no
    /// active receivers, which is not treated as a failure.
    async fn send_update(&self, update: GraphQLUpdate) -> Result<()> {
        // `send` consumes the update; no clone is needed.
        match self.event_sender.send(update) {
            Ok(receiver_count) => {
                debug!("Sent GraphQL update to {} receivers", receiver_count);
                self.stats.write().await.updates_sent += 1;
                Ok(())
            }
            Err(e) => {
                warn!("No active GraphQL subscription receivers: {}", e);
                Ok(())
            }
        }
    }

    /// Broadcasts every update of a batch and counts one batch flush.
    async fn send_batch(&self, updates: Vec<GraphQLUpdate>) -> Result<()> {
        for update in updates {
            self.send_update(update).await?;
        }
        self.stats.write().await.updates_batched += 1;
        Ok(())
    }

    /// Spawns the background task that flushes the batch buffer every
    /// `config.batch_interval`. Must run inside a Tokio runtime.
    fn start_batch_processor(&self) {
        let batch_buffer = Arc::clone(&self.batch_buffer);
        let event_sender = self.event_sender.clone();
        let batch_interval = self.config.batch_interval;
        let stats = Arc::clone(&self.stats);
        tokio::spawn(async move {
            let mut interval = interval(batch_interval);
            loop {
                interval.tick().await;
                // Swap the buffer out under the lock, then send lock-free.
                let updates = {
                    let mut buffer = batch_buffer.write().await;
                    if buffer.is_empty() {
                        continue;
                    }
                    std::mem::take(&mut *buffer)
                };
                if !updates.is_empty() {
                    debug!("Processing batch of {} updates", updates.len());
                    for update in updates {
                        if let Err(e) = event_sender.send(update) {
                            warn!("Failed to send batched update: {}", e);
                        } else {
                            stats.write().await.updates_sent += 1;
                        }
                    }
                    stats.write().await.updates_batched += 1;
                }
            }
        });
    }

    /// Returns a new receiver for the update broadcast channel.
    pub fn subscribe(&self) -> broadcast::Receiver<GraphQLUpdate> {
        self.event_sender.subscribe()
    }

    /// Snapshot of the current statistics.
    pub async fn get_stats(&self) -> BridgeStats {
        self.stats.read().await.clone()
    }

    /// Ids of all registered subscriptions.
    pub async fn list_subscriptions(&self) -> Vec<String> {
        self.subscriptions.read().await.keys().cloned().collect()
    }

    /// Looks up a subscription by id, cloning it out of the registry.
    pub async fn get_subscription(&self, id: &str) -> Option<GraphQLSubscription> {
        self.subscriptions.read().await.get(id).cloned()
    }
}
/// Convenience constructor: builds a `GraphQLSubscription` with a random
/// UUID id, empty variables, and zeroed bookkeeping fields.
pub fn create_simple_subscription(
    query: String,
    filters: Vec<SubscriptionFilter>,
) -> GraphQLSubscription {
    let id = uuid::Uuid::new_v4().to_string();
    GraphQLSubscription {
        id,
        query,
        variables: HashMap::new(),
        filters,
        created_at: Utc::now(),
        last_update: None,
        update_count: 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config values must stay stable — callers size channels off them.
    #[test]
    fn test_bridge_config_default() {
        let config = BridgeConfig::default();
        assert_eq!(config.max_queue_size, 10000);
        assert!(config.enable_batching);
        assert!(config.enable_query_filtering);
    }

    /// A fresh bridge starts with zeroed statistics.
    #[tokio::test]
    async fn test_graphql_bridge_creation() {
        let bridge = GraphQLBridge::new(BridgeConfig::default());
        let stats = bridge.get_stats().await;
        assert_eq!(stats.active_subscriptions, 0);
        assert_eq!(stats.events_processed, 0);
    }

    /// Registering a subscription returns its id and bumps the counter.
    #[tokio::test]
    async fn test_subscription_registration() {
        let bridge = GraphQLBridge::new(BridgeConfig::default());
        let subscription = create_simple_subscription(
            "subscription { triples { subject predicate object } }".to_string(),
            vec![],
        );
        let id = bridge.register_subscription(subscription).await.unwrap();
        assert!(!id.is_empty());
        let stats = bridge.get_stats().await;
        assert_eq!(stats.active_subscriptions, 1);
    }

    /// Glob semantics: bare `*` matches anything; `prefix*` matches only
    /// values with that prefix; patterns without `*` require exact equality.
    #[tokio::test]
    async fn test_pattern_matching() {
        let bridge = GraphQLBridge::new(BridgeConfig::default());
        assert!(bridge.pattern_matches("*", "anything"));
        assert!(bridge.pattern_matches("http://example.org/*", "http://example.org/resource"));
        assert!(!bridge.pattern_matches("http://example.org/*", "http://other.org/resource"));
        assert!(bridge.pattern_matches("exact_match", "exact_match"));
        assert!(!bridge.pattern_matches("exact_match", "different"));
    }

    /// Filter variants construct as expected.
    ///
    /// Bug fix: the original discarded the `bool` returned by `matches!`,
    /// so these checks could never fail; they are now wrapped in `assert!`.
    #[test]
    fn test_subscription_filter_types() {
        let filter = SubscriptionFilter::SubjectPattern("http://example.org/*".to_string());
        assert!(matches!(filter, SubscriptionFilter::SubjectPattern(_)));
        let filter2 = SubscriptionFilter::GraphFilter("http://example.org/graph1".to_string());
        assert!(matches!(filter2, SubscriptionFilter::GraphFilter(_)));
    }
}