use std::collections::{HashMap, HashSet, VecDeque};
use std::time::Instant;
use uuid::Uuid;
use super::{SubscriptionConfig, SubscriptionError};
/// Identifier of a subscription held by a session: 16 raw bytes.
///
/// Values are produced by `generate_subscription_id`, which copies the bytes
/// of a freshly generated version-4 UUID.
pub type SessionSubscriptionId = [u8; 16];
/// A single subscription registered through
/// `SessionSubscriptionManager::subscribe`.
#[derive(Debug, Clone)]
pub struct SessionSubscription {
/// Identifier assigned when the subscription was created.
pub id: SessionSubscriptionId,
/// Query text exactly as it was passed to `subscribe`.
pub query: String,
/// Parameter values passed to `subscribe` alongside the query.
/// NOTE(review): presumably `None` stands for a NULL parameter and
/// `Some(bytes)` for an encoded value -- confirm against the caller.
pub params: Vec<Option<Vec<u8>>>,
/// Names of the tables this subscription depends on; the manager uses them
/// as keys of its table-to-subscription index.
pub table_dependencies: HashSet<String>,
}
/// Tracks the subscriptions of one session, together with a reverse index
/// from table name to subscription ids and the bookkeeping needed to enforce
/// the configured per-connection and per-second limits.
#[derive(Debug)]
pub struct SessionSubscriptionManager {
// Every live subscription, keyed by its id.
subscriptions: HashMap<SessionSubscriptionId, SessionSubscription>,
// Reverse index: table name -> ids of the subscriptions that depend on it.
// `unsubscribe` removes an entry once its id set becomes empty.
table_to_subscriptions: HashMap<String, HashSet<SessionSubscriptionId>>,
// Limits consulted by `subscribe` (`max_per_connection`,
// `rate_limit_per_second`).
config: SubscriptionConfig,
// Creation times of recently accepted subscriptions, oldest first; entries
// older than one second are dropped by `check_rate_limit`.
recent_subscriptions: VecDeque<Instant>,
}
impl SessionSubscriptionManager {
    /// Width of the sliding window over which `rate_limit_per_second` is enforced.
    const RATE_LIMIT_WINDOW: std::time::Duration = std::time::Duration::from_secs(1);

    /// Creates an empty manager that uses `SubscriptionConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(SubscriptionConfig::default())
    }

    /// Creates an empty manager that enforces the limits in `config`.
    pub fn with_config(config: SubscriptionConfig) -> Self {
        Self {
            subscriptions: HashMap::new(),
            table_to_subscriptions: HashMap::new(),
            config,
            recent_subscriptions: VecDeque::new(),
        }
    }

    /// Registers a new subscription and returns its freshly generated id.
    ///
    /// The subscription is indexed under every table in `table_dependencies`
    /// so it can later be found through `get_subscriptions_for_table`.
    ///
    /// # Errors
    ///
    /// * `SubscriptionError::ConnectionLimitExceeded` when the manager already
    ///   holds `max_per_connection` (or more) subscriptions.
    /// * `SubscriptionError::RateLimited` when `rate_limit_per_second`
    ///   subscriptions were already accepted within the last second.
    ///
    /// A rejected call leaves the set of subscriptions untouched.
    pub fn subscribe(
        &mut self,
        query: String,
        params: Vec<Option<Vec<u8>>>,
        table_dependencies: HashSet<String>,
    ) -> Result<SessionSubscriptionId, SubscriptionError> {
        let current = self.subscriptions.len();
        if current >= self.config.max_per_connection {
            return Err(SubscriptionError::ConnectionLimitExceeded {
                current,
                max: self.config.max_per_connection,
            });
        }
        self.check_rate_limit()?;

        let id = generate_subscription_id();
        // Index by reference first so the set itself can be moved into the
        // stored record afterwards instead of being cloned.
        for table in &table_dependencies {
            self.table_to_subscriptions
                .entry(table.clone())
                .or_default()
                .insert(id);
        }
        self.subscriptions.insert(
            id,
            SessionSubscription {
                id,
                query,
                params,
                table_dependencies,
            },
        );
        // Only accepted subscriptions count towards the rate limit.
        self.recent_subscriptions.push_back(Instant::now());
        Ok(id)
    }

    /// Drops timestamps that have left the rate-limit window and reports
    /// whether one more subscription may be accepted right now.
    ///
    /// On rejection, `retry_after_ms` is the time remaining until the oldest
    /// timestamp still in the window turns one second old.
    fn check_rate_limit(&mut self) -> Result<(), SubscriptionError> {
        let now = Instant::now();

        // Ages are measured with `saturating_duration_since` rather than by
        // computing `now - window`: subtracting a `Duration` from an `Instant`
        // may panic when the result is not representable, while the saturating
        // form never panics.
        while let Some(&oldest) = self.recent_subscriptions.front() {
            if now.saturating_duration_since(oldest) > Self::RATE_LIMIT_WINDOW {
                self.recent_subscriptions.pop_front();
            } else {
                break;
            }
        }

        // NOTE(review): with a limit of 0 and an empty window, `front()` is
        // `None` and the call is admitted -- confirm whether 0 is meant to
        // block every subscription.
        let limit = self.config.rate_limit_per_second as usize;
        if self.recent_subscriptions.len() >= limit {
            if let Some(&oldest) = self.recent_subscriptions.front() {
                // After the eviction above the age is at most one window, so
                // the narrowing cast cannot truncate.
                let elapsed_ms = now.saturating_duration_since(oldest).as_millis() as u64;
                let window_ms = Self::RATE_LIMIT_WINDOW.as_millis() as u64;
                let retry_after_ms = window_ms.saturating_sub(elapsed_ms);
                return Err(SubscriptionError::RateLimited { retry_after_ms });
            }
        }
        Ok(())
    }

    /// Removes the subscription with the given id and returns it, or returns
    /// `None` when the id is unknown.
    ///
    /// The id is also removed from the index of every table the subscription
    /// depended on; index entries left without subscribers are deleted.
    pub fn unsubscribe(
        &mut self,
        subscription_id: &SessionSubscriptionId,
    ) -> Option<SessionSubscription> {
        let subscription = self.subscriptions.remove(subscription_id)?;
        for table in &subscription.table_dependencies {
            if let Some(sub_ids) = self.table_to_subscriptions.get_mut(table) {
                sub_ids.remove(subscription_id);
                if sub_ids.is_empty() {
                    self.table_to_subscriptions.remove(table);
                }
            }
        }
        Some(subscription)
    }

    /// Returns the ids of all subscriptions that depend on `table`, in
    /// unspecified order; empty when no subscription references the table.
    #[allow(dead_code)]
    pub fn get_subscriptions_for_table(&self, table: &str) -> Vec<SessionSubscriptionId> {
        self.table_to_subscriptions
            .get(table)
            .map(|ids| ids.iter().copied().collect())
            .unwrap_or_default()
    }

    /// Looks up a subscription by id.
    #[allow(dead_code)]
    pub fn get(&self, subscription_id: &SessionSubscriptionId) -> Option<&SessionSubscription> {
        self.subscriptions.get(subscription_id)
    }

    /// Number of subscriptions currently held.
    #[allow(dead_code)]
    pub fn subscription_count(&self) -> usize {
        self.subscriptions.len()
    }

    /// Removes every subscription, the whole table index, and the rate-limit
    /// history. The configuration is kept.
    pub fn clear(&mut self) {
        self.subscriptions.clear();
        self.table_to_subscriptions.clear();
        self.recent_subscriptions.clear();
    }

    /// Returns the ids of all subscriptions currently held, in unspecified order.
    #[allow(dead_code)]
    pub fn all_subscription_ids(&self) -> Vec<SessionSubscriptionId> {
        self.subscriptions.keys().copied().collect()
    }

    /// The limits this manager enforces.
    pub fn config(&self) -> &SubscriptionConfig {
        &self.config
    }
}
impl Default for SessionSubscriptionManager {
fn default() -> Self {
Self::new()
}
}
/// Produces a new subscription id by copying the 16 bytes of a freshly
/// generated version-4 UUID.
fn generate_subscription_id() -> SessionSubscriptionId {
    let uuid = Uuid::new_v4();
    *uuid.as_bytes()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dependency set from a list of table names.
    fn tables(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// A subscription is indexed under each of its tables and fully removed
    /// again by `unsubscribe`.
    #[test]
    fn test_subscribe_and_unsubscribe() {
        let mut manager = SessionSubscriptionManager::new();
        let query = "SELECT * FROM users JOIN orders ON users.id = orders.user_id";

        let id = manager
            .subscribe(query.to_string(), Vec::new(), tables(&["users", "orders"]))
            .unwrap();

        assert_eq!(manager.subscription_count(), 1);
        assert!(manager.get(&id).is_some());
        for table in ["users", "orders"] {
            assert_eq!(manager.get_subscriptions_for_table(table), vec![id]);
        }

        assert!(manager.unsubscribe(&id).is_some());
        assert_eq!(manager.subscription_count(), 0);
        assert!(manager.get_subscriptions_for_table("users").is_empty());
    }

    /// Two subscriptions may share a table; removing one leaves the other
    /// indexed.
    #[test]
    fn test_multiple_subscriptions_same_table() {
        let mut manager = SessionSubscriptionManager::new();

        let users_only = manager
            .subscribe("SELECT * FROM users".to_string(), Vec::new(), tables(&["users"]))
            .unwrap();
        let users_and_products = manager
            .subscribe(
                "SELECT * FROM users JOIN products ON users.id = products.seller_id".to_string(),
                Vec::new(),
                tables(&["users", "products"]),
            )
            .unwrap();
        assert_eq!(manager.subscription_count(), 2);

        let on_users = manager.get_subscriptions_for_table("users");
        assert_eq!(on_users.len(), 2);
        assert!(on_users.contains(&users_only));
        assert!(on_users.contains(&users_and_products));
        assert_eq!(
            manager.get_subscriptions_for_table("products"),
            vec![users_and_products]
        );

        manager.unsubscribe(&users_only);
        assert_eq!(manager.subscription_count(), 1);
        assert_eq!(
            manager.get_subscriptions_for_table("users"),
            vec![users_and_products]
        );
    }

    /// `clear` drops both the subscriptions and the table index.
    #[test]
    fn test_clear() {
        let mut manager = SessionSubscriptionManager::new();
        for query in ["SELECT * FROM users", "SELECT * FROM users WHERE id = 1"] {
            manager
                .subscribe(query.to_string(), Vec::new(), tables(&["users"]))
                .unwrap();
        }
        assert_eq!(manager.subscription_count(), 2);

        manager.clear();

        assert_eq!(manager.subscription_count(), 0);
        assert!(manager.get_subscriptions_for_table("users").is_empty());
    }

    /// Unsubscribing an unknown id is a no-op that reports `None`.
    #[test]
    fn test_unsubscribe_nonexistent() {
        let mut manager = SessionSubscriptionManager::new();
        let unknown_id: SessionSubscriptionId = [0u8; 16];
        assert!(manager.unsubscribe(&unknown_id).is_none());
    }

    /// The per-connection cap rejects the subscription that would exceed it.
    #[test]
    fn test_connection_limit_exceeded() {
        let mut manager = SessionSubscriptionManager::with_config(SubscriptionConfig {
            max_per_connection: 2,
            max_global: 10000,
            max_result_rows: 10000,
            rate_limit_per_second: 100,
            ..Default::default()
        });
        for query in ["SELECT * FROM users", "SELECT * FROM users WHERE id = 1"] {
            manager
                .subscribe(query.to_string(), Vec::new(), tables(&["users"]))
                .unwrap();
        }

        let rejected = manager.subscribe(
            "SELECT * FROM users WHERE id = 2".to_string(),
            Vec::new(),
            tables(&["users"]),
        );

        assert!(matches!(
            rejected,
            Err(SubscriptionError::ConnectionLimitExceeded { current: 2, max: 2 })
        ));
    }

    /// A third subscription within the same second trips a limit of two per
    /// second.
    #[test]
    fn test_rate_limit() {
        let mut manager = SessionSubscriptionManager::with_config(SubscriptionConfig {
            max_per_connection: 100,
            max_global: 10000,
            max_result_rows: 10000,
            rate_limit_per_second: 2,
            ..Default::default()
        });
        for query in ["SELECT * FROM users", "SELECT * FROM users WHERE id = 1"] {
            manager
                .subscribe(query.to_string(), Vec::new(), tables(&["users"]))
                .unwrap();
        }

        let rejected = manager.subscribe(
            "SELECT * FROM users WHERE id = 2".to_string(),
            Vec::new(),
            tables(&["users"]),
        );

        assert!(matches!(rejected, Err(SubscriptionError::RateLimited { .. })));
    }
}