use std::collections::{HashMap, HashSet};
use uuid::Uuid;
/// Identifier of a [`SessionSubscription`]: 16 raw bytes.
///
/// Values are produced by `generate_subscription_id` in this module.
pub type SessionSubscriptionId = [u8; 16];

/// One registered query, its bound parameters, and the tables it depends on.
#[derive(Clone, Debug)]
pub struct SessionSubscription {
    /// Identifier returned to the caller by `subscribe`.
    pub id: SessionSubscriptionId,
    /// The query text as supplied by the caller.
    pub query: String,
    /// Positional parameter values; `None` presumably stands for SQL NULL — confirm with callers.
    pub params: Vec<Option<Vec<u8>>>,
    /// Names of the tables this subscription is indexed under.
    pub table_dependencies: HashSet<String>,
}

/// Stores subscriptions by id and keeps a reverse index from table name to
/// the ids of the subscriptions that depend on that table.
#[derive(Default, Debug)]
pub struct SessionSubscriptionManager {
    /// Every registered subscription, keyed by its id.
    subscriptions: HashMap<SessionSubscriptionId, SessionSubscription>,
    /// Reverse index: table name -> ids of subscriptions depending on it.
    table_to_subscriptions: HashMap<String, HashSet<SessionSubscriptionId>>,
}
impl SessionSubscriptionManager {
pub fn new() -> Self {
Self {
subscriptions: HashMap::new(),
table_to_subscriptions: HashMap::new(),
}
}
pub fn subscribe(
&mut self,
query: String,
params: Vec<Option<Vec<u8>>>,
table_dependencies: HashSet<String>,
) -> SessionSubscriptionId {
let id = generate_subscription_id();
let subscription = SessionSubscription {
id,
query,
params,
table_dependencies: table_dependencies.clone(),
};
for table in &table_dependencies {
self.table_to_subscriptions
.entry(table.clone())
.or_default()
.insert(id);
}
self.subscriptions.insert(id, subscription);
id
}
pub fn unsubscribe(&mut self, subscription_id: &SessionSubscriptionId) -> Option<SessionSubscription> {
if let Some(subscription) = self.subscriptions.remove(subscription_id) {
for table in &subscription.table_dependencies {
if let Some(sub_ids) = self.table_to_subscriptions.get_mut(table) {
sub_ids.remove(subscription_id);
if sub_ids.is_empty() {
self.table_to_subscriptions.remove(table);
}
}
}
Some(subscription)
} else {
None
}
}
#[allow(dead_code)]
pub fn get_subscriptions_for_table(&self, table: &str) -> Vec<SessionSubscriptionId> {
self.table_to_subscriptions
.get(table)
.map(|ids| ids.iter().copied().collect())
.unwrap_or_default()
}
#[allow(dead_code)]
pub fn get(&self, subscription_id: &SessionSubscriptionId) -> Option<&SessionSubscription> {
self.subscriptions.get(subscription_id)
}
#[allow(dead_code)]
pub fn subscription_count(&self) -> usize {
self.subscriptions.len()
}
pub fn clear(&mut self) {
self.subscriptions.clear();
self.table_to_subscriptions.clear();
}
#[allow(dead_code)]
pub fn all_subscription_ids(&self) -> Vec<SessionSubscriptionId> {
self.subscriptions.keys().copied().collect()
}
}
/// Produces a new subscription id from a freshly generated version-4 UUID.
fn generate_subscription_id() -> SessionSubscriptionId {
    // Copy the UUID's 16 raw bytes out by value.
    let uuid = Uuid::new_v4();
    *uuid.as_bytes()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dependency set from table-name literals.
    fn deps_of(tables: &[&str]) -> HashSet<String> {
        tables.iter().map(|t| t.to_string()).collect()
    }

    #[test]
    fn test_subscribe_and_unsubscribe() {
        let mut manager = SessionSubscriptionManager::new();
        let id = manager.subscribe(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id".to_string(),
            vec![],
            deps_of(&["users", "orders"]),
        );
        assert_eq!(manager.subscription_count(), 1);
        assert!(manager.get(&id).is_some());

        let user_subs = manager.get_subscriptions_for_table("users");
        assert_eq!(user_subs.len(), 1);
        assert_eq!(user_subs[0], id);
        let order_subs = manager.get_subscriptions_for_table("orders");
        assert_eq!(order_subs.len(), 1);
        assert_eq!(order_subs[0], id);

        let removed = manager.unsubscribe(&id);
        assert!(removed.is_some());
        assert_eq!(manager.subscription_count(), 0);
        assert!(manager.get(&id).is_none());
        // Every dependent table must be cleaned up, not just the first one.
        assert!(manager.get_subscriptions_for_table("users").is_empty());
        assert!(manager.get_subscriptions_for_table("orders").is_empty());
        // Emptied tables are dropped from the index rather than left as empty sets.
        assert!(manager.table_to_subscriptions.is_empty());
    }

    #[test]
    fn test_multiple_subscriptions_same_table() {
        let mut manager = SessionSubscriptionManager::new();
        let id1 = manager.subscribe("SELECT * FROM users".to_string(), vec![], deps_of(&["users"]));
        let id2 = manager.subscribe(
            "SELECT * FROM users JOIN products ON users.id = products.seller_id".to_string(),
            vec![],
            deps_of(&["users", "products"]),
        );
        assert_ne!(id1, id2);
        assert_eq!(manager.subscription_count(), 2);

        let user_subs = manager.get_subscriptions_for_table("users");
        assert_eq!(user_subs.len(), 2);
        assert!(user_subs.contains(&id1));
        assert!(user_subs.contains(&id2));
        let product_subs = manager.get_subscriptions_for_table("products");
        assert_eq!(product_subs.len(), 1);
        assert!(product_subs.contains(&id2));

        manager.unsubscribe(&id1);
        assert_eq!(manager.subscription_count(), 1);
        let user_subs = manager.get_subscriptions_for_table("users");
        assert_eq!(user_subs.len(), 1);
        assert!(user_subs.contains(&id2));
        // The other subscription's tables are untouched.
        assert_eq!(manager.get_subscriptions_for_table("products"), vec![id2]);
    }

    #[test]
    fn test_clear() {
        let mut manager = SessionSubscriptionManager::new();
        let deps = deps_of(&["users"]);
        manager.subscribe("SELECT * FROM users".to_string(), vec![], deps.clone());
        manager.subscribe("SELECT * FROM users WHERE id = 1".to_string(), vec![], deps);
        assert_eq!(manager.subscription_count(), 2);
        manager.clear();
        assert_eq!(manager.subscription_count(), 0);
        assert!(manager.get_subscriptions_for_table("users").is_empty());
        assert!(manager.all_subscription_ids().is_empty());
    }

    #[test]
    fn test_unsubscribe_nonexistent() {
        let mut manager = SessionSubscriptionManager::new();
        let fake_id = [0u8; 16];
        let result = manager.unsubscribe(&fake_id);
        assert!(result.is_none());
    }

    #[test]
    fn test_unsubscribe_twice_returns_none() {
        let mut manager = SessionSubscriptionManager::new();
        let id = manager.subscribe("SELECT 1".to_string(), vec![], deps_of(&["users"]));
        assert!(manager.unsubscribe(&id).is_some());
        assert!(manager.unsubscribe(&id).is_none());
        assert_eq!(manager.subscription_count(), 0);
    }

    #[test]
    fn test_subscribe_without_dependencies() {
        let mut manager = SessionSubscriptionManager::new();
        let id = manager.subscribe("SELECT 1".to_string(), vec![], HashSet::new());
        assert_eq!(manager.subscription_count(), 1);
        assert_eq!(manager.all_subscription_ids(), vec![id]);
        assert!(manager.table_to_subscriptions.is_empty());
        assert!(manager.unsubscribe(&id).is_some());
        assert_eq!(manager.subscription_count(), 0);
    }

    #[test]
    fn test_subscription_round_trips_query_and_params() {
        let mut manager = SessionSubscriptionManager::new();
        let params = vec![Some(b"42".to_vec()), None];
        let query = "SELECT * FROM users WHERE id = ?".to_string();
        let id = manager.subscribe(query.clone(), params.clone(), deps_of(&["users"]));
        let sub = manager.get(&id).expect("subscription was just registered");
        assert_eq!(sub.id, id);
        assert_eq!(sub.query, query);
        assert_eq!(sub.params, params);
        assert_eq!(sub.table_dependencies, deps_of(&["users"]));
    }

    #[test]
    fn test_all_subscription_ids() {
        let mut manager = SessionSubscriptionManager::new();
        let id1 = manager.subscribe("SELECT 1".to_string(), vec![], deps_of(&["a"]));
        let id2 = manager.subscribe("SELECT 2".to_string(), vec![], deps_of(&["b"]));
        let mut ids = manager.all_subscription_ids();
        ids.sort_unstable();
        let mut expected = vec![id1, id2];
        expected.sort_unstable();
        assert_eq!(ids, expected);
    }
}