use crate::config::Config;
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::broadcast;
use uuid::Uuid;
/// Event delivered to a user's subscribers through a [`Notifier`].
///
/// All variants are fieldless, so the type is fully `Eq` (not just `PartialEq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserEvent {
    /// Presumably: a new message is available for the user (producer not visible here).
    MessageReceived,
    /// Presumably: the user's session should be torn down (producer not visible here).
    Disconnect,
}
/// Publish/subscribe interface for per-user [`UserEvent`]s.
///
/// Implementors must be `Send + Sync` so a single instance can be shared
/// across tasks and threads.
#[async_trait]
pub trait Notifier: Send + Sync {
    /// Returns a new `broadcast` receiver for events addressed to `user_id`.
    fn subscribe(&self, user_id: Uuid) -> broadcast::Receiver<UserEvent>;

    /// Publishes `event` for `user_id`; delivery semantics are up to the implementor.
    fn notify(&self, user_id: Uuid, event: UserEvent);
}
/// In-process [`Notifier`] holding one tokio `broadcast` channel per user.
pub struct InMemoryNotifier {
    /// Per-user broadcast senders, keyed by user id. The background GC task
    /// spawned in `new` also reaches this map to evict unused channels.
    channels: std::sync::Arc<DashMap<Uuid, broadcast::Sender<UserEvent>>>,
    /// Capacity passed to `broadcast::channel` when a user's channel is created.
    channel_capacity: usize,
}
impl InMemoryNotifier {
pub fn new(config: Config) -> Self {
let channels = std::sync::Arc::new(DashMap::new());
let map_ref = channels.clone();
let interval_secs = config.notification_gc_interval_secs;
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
loop {
interval.tick().await;
map_ref.retain(|_, sender: &mut broadcast::Sender<UserEvent>| sender.receiver_count() > 0);
}
});
Self { channels, channel_capacity: config.notification_channel_capacity }
}
}
#[async_trait]
impl Notifier for InMemoryNotifier {
    /// Returns a receiver for `user_id`, creating the user's channel on first use.
    ///
    /// The receiver is created while the map entry is still locked. If the entry
    /// guard were released first (clone the sender, then subscribe), the GC task
    /// could see a sender with zero receivers and evict it in between. The caller
    /// would then hold a receiver on a channel `notify` can no longer reach, and
    /// that receiver reports `Closed` once the last sender clone is dropped.
    fn subscribe(&self, user_id: Uuid) -> broadcast::Receiver<UserEvent> {
        self.channels
            .entry(user_id)
            // The initial receiver returned by `channel` is discarded immediately;
            // callers get their own via `subscribe` below.
            .or_insert_with(|| broadcast::channel(self.channel_capacity).0)
            .subscribe()
    }

    /// Broadcasts `event` to the live receivers of `user_id`'s channel.
    ///
    /// Does nothing if the user currently has no channel. A send error (which
    /// `broadcast::Sender::send` returns when there are no live receivers) is
    /// deliberately ignored, because there is nobody to deliver the event to.
    fn notify(&self, user_id: Uuid, event: UserEvent) {
        if let Some(sender) = self.channels.get(&user_id) {
            let _ = sender.send(event);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Slightly longer than one GC period; every test here uses a 1-second interval.
    const PAST_ONE_GC_TICK: Duration = Duration::from_millis(1100);

    /// Builds a `Config` with placeholder values, except for the two notifier
    /// settings under test.
    fn test_config(gc_interval_secs: u64, channel_capacity: usize) -> Config {
        Config {
            database_url: "".to_string(),
            jwt_secret: "".to_string(),
            rate_limit_per_second: 5,
            rate_limit_burst: 10,
            auth_rate_limit_per_second: 1,
            auth_rate_limit_burst: 3,
            server_host: "0.0.0.0".to_string(),
            server_port: 3000,
            message_ttl_days: 30,
            max_inbox_size: 1000,
            access_token_ttl_secs: 900,
            refresh_token_ttl_days: 30,
            message_cleanup_interval_secs: 300,
            notification_gc_interval_secs: gc_interval_secs,
            notification_channel_capacity: channel_capacity,
            message_batch_limit: 50,
            trusted_proxies: "127.0.0.1/32".to_string(),
            ws_outbound_buffer_size: 32,
            ws_ack_buffer_size: 100,
            ws_ack_batch_size: 50,
            ws_ack_flush_interval_ms: 500,
        }
    }

    #[tokio::test]
    async fn test_notifier_gc_cleans_up_unused_channels() {
        let hub = InMemoryNotifier::new(test_config(1, 16));
        let uid = Uuid::new_v4();

        let receiver = hub.subscribe(uid);
        assert!(hub.channels.contains_key(&uid));
        assert_eq!(hub.channels.len(), 1);

        // With its only receiver gone, the channel should be evicted by GC.
        drop(receiver);
        tokio::time::sleep(PAST_ONE_GC_TICK).await;

        assert!(!hub.channels.contains_key(&uid));
        assert_eq!(hub.channels.len(), 0);
    }

    #[tokio::test]
    async fn test_notifier_gc_keeps_active_channels() {
        let hub = InMemoryNotifier::new(test_config(1, 16));
        let uid = Uuid::new_v4();

        // Held for the whole test so the channel stays in use across a GC pass.
        let _receiver = hub.subscribe(uid);
        tokio::time::sleep(PAST_ONE_GC_TICK).await;

        assert!(hub.channels.contains_key(&uid));
    }
}