use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use crate::EmbeddedDatabase;
/// A single change notification, broadcast to receivers by `ChangeNotifier::notify`.
///
/// The type derives `serde::Serialize`, so an event can be serialized with serde.
#[derive(Clone, Debug, serde::Serialize)]
pub struct ChangeEvent {
/// Name of the table the change was reported for.
pub table: String,
/// Caller-supplied label for the kind of change; the tests in this file use
/// `"INSERT"` and `"DELETE"`.
pub event_type: String,
/// Optional JSON payload supplied by the caller of `notify`.
/// NOTE(review): presumably the record as it looks after the change — confirm with callers.
pub new_record: Option<serde_json::Value>,
/// Optional JSON payload supplied by the caller of `notify`.
/// NOTE(review): presumably the record as it looked before the change — confirm with callers.
pub old_record: Option<serde_json::Value>,
/// RFC 3339 formatted UTC time, taken when `notify` builds the event.
pub timestamp: String,
}
/// Publishes `ChangeEvent`s on a Tokio broadcast channel, gated by a per-table
/// subscription count.
///
/// `notify` only sends an event when the affected table (or the wildcard key
/// `"*"`) has an entry in `subscriptions`.
pub struct ChangeNotifier {
// Not read by any code visible in this file, hence the lint suppression.
// NOTE(review): presumably kept so the database outlives the notifier — confirm.
#[allow(dead_code)]
db: Arc<EmbeddedDatabase>,
// Table name -> number of registered subscriptions. Entries are removed once
// their count drops to zero, so presence of a key means "at least one".
subscriptions: RwLock<HashMap<String, usize>>,
// Sending half of the broadcast channel; `subscribe` creates receivers from it.
sender: tokio::sync::broadcast::Sender<ChangeEvent>,
}
impl ChangeNotifier {
    /// Creates a notifier with an empty subscription table and a broadcast
    /// channel created with a capacity of 1024.
    ///
    /// The receiver returned alongside the sender is discarded immediately;
    /// consumers obtain their own receiver through [`ChangeNotifier::subscribe`].
    pub fn new(db: Arc<EmbeddedDatabase>) -> Self {
        let (tx, _) = tokio::sync::broadcast::channel::<ChangeEvent>(1024);
        let subscriptions = RwLock::new(HashMap::new());
        Self {
            db,
            subscriptions,
            sender: tx,
        }
    }

    /// Returns a new receiver attached to the broadcast channel.
    ///
    /// The receiver is not filtered per table: it sees every event that
    /// [`ChangeNotifier::notify`] decides to send.
    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ChangeEvent> {
        self.sender.subscribe()
    }

    /// Broadcasts a change for `table` if that table, or the wildcard `"*"`,
    /// currently has a subscription entry; otherwise does nothing.
    ///
    /// * `table` - name of the table the change belongs to.
    /// * `event_type` - caller-defined label copied into the event.
    /// * `new_record` / `old_record` - optional JSON payloads moved into the event.
    ///
    /// The event timestamp is the current UTC time in RFC 3339 format.
    pub fn notify(
        &self,
        table: &str,
        event_type: &str,
        new_record: Option<serde_json::Value>,
        old_record: Option<serde_json::Value>,
    ) {
        // The read guard lives until the end of the function, so the send below
        // still happens while the subscription table is read-locked.
        let registry = self.subscriptions.read();
        let is_watched = registry.contains_key(table) || registry.contains_key("*");
        if !is_watched {
            return;
        }

        let event = ChangeEvent {
            table: table.to_owned(),
            event_type: event_type.to_owned(),
            new_record,
            old_record,
            timestamp: chrono::Utc::now().to_rfc3339(),
        };

        // Best effort: the result of the send is deliberately ignored (a Tokio
        // broadcast send fails only when no receiver is currently attached).
        let _ = self.sender.send(event);
    }

    /// Registers one more subscription for `table`, creating its counter on
    /// first use. Pass `"*"` to make `notify` forward events for every table.
    pub fn add_table_subscription(&self, table: &str) {
        let mut registry = self.subscriptions.write();
        registry
            .entry(table.to_owned())
            .and_modify(|count| *count += 1)
            .or_insert(1);
    }

    /// Releases one subscription for `table`.
    ///
    /// The entry is dropped from the table once its counter reaches zero.
    /// Calling this for a table without any subscription is a no-op.
    pub fn remove_table_subscription(&self, table: &str) {
        let mut registry = self.subscriptions.write();
        let remaining = match registry.get_mut(table) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count
            }
            None => return,
        };
        if remaining == 0 {
            registry.remove(table);
        }
    }

    /// Test-only helper: number of subscriptions currently recorded for
    /// `table`, or `0` when the table has no entry.
    #[cfg(test)]
    pub fn subscriber_count(&self, table: &str) -> usize {
        let registry = self.subscriptions.read();
        registry.get(table).map_or(0, |count| *count)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a notifier backed by a fresh in-memory database.
    fn make_notifier() -> ChangeNotifier {
        let in_memory = EmbeddedDatabase::new_in_memory().unwrap();
        ChangeNotifier::new(Arc::new(in_memory))
    }

    /// The per-table counter goes up on add and back down on remove.
    #[test]
    fn test_subscribe_and_unsubscribe() {
        let notifier = make_notifier();
        assert_eq!(notifier.subscriber_count("users"), 0);

        notifier.add_table_subscription("users");
        assert_eq!(notifier.subscriber_count("users"), 1);
        notifier.add_table_subscription("users");
        assert_eq!(notifier.subscriber_count("users"), 2);

        notifier.remove_table_subscription("users");
        assert_eq!(notifier.subscriber_count("users"), 1);
        notifier.remove_table_subscription("users");
        assert_eq!(notifier.subscriber_count("users"), 0);
    }

    /// Removing a subscription that was never added must not panic or underflow.
    #[test]
    fn test_remove_below_zero_is_safe() {
        let notifier = make_notifier();
        notifier.remove_table_subscription("unknown");
        assert_eq!(notifier.subscriber_count("unknown"), 0);
    }

    /// An event for a subscribed table reaches an attached receiver intact.
    #[tokio::test]
    async fn test_notify_sends_to_subscriber() {
        let notifier = make_notifier();
        notifier.add_table_subscription("orders");
        let mut receiver = notifier.subscribe();

        let inserted = serde_json::json!({"id": 1});
        notifier.notify("orders", "INSERT", Some(inserted), None);

        let received = receiver.recv().await.unwrap();
        assert_eq!(received.table, "orders");
        assert_eq!(received.event_type, "INSERT");
        assert!(received.new_record.is_some());
        assert!(received.old_record.is_none());
    }

    /// Events for tables without a subscription entry are never sent.
    #[tokio::test]
    async fn test_notify_skips_unsubscribed_table() {
        let notifier = make_notifier();
        notifier.add_table_subscription("orders");
        let mut receiver = notifier.subscribe();

        let payload = serde_json::json!({"id": 1});
        notifier.notify("other", "INSERT", Some(payload), None);

        assert!(receiver.try_recv().is_err());
    }

    /// A `"*"` subscription lets events for any table through.
    #[tokio::test]
    async fn test_wildcard_subscription() {
        let notifier = make_notifier();
        notifier.add_table_subscription("*");
        let mut receiver = notifier.subscribe();

        let deleted = serde_json::json!({"id": 99});
        notifier.notify("any_table", "DELETE", None, Some(deleted));

        let received = receiver.recv().await.unwrap();
        assert_eq!(received.table, "any_table");
        assert_eq!(received.event_type, "DELETE");
    }
}