use crate::config::Config;
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::broadcast;
use uuid::Uuid;
/// Event payload carried over a per-user broadcast channel (see `Notifier`).
#[derive(Debug, Clone, PartialEq)]
pub enum UserEvent {
// NOTE(review): presumably signals that a new message is available for the
// user — confirm against the code that calls `notify`.
MessageReceived,
// NOTE(review): presumably asks the user's live connections to close —
// confirm against the code that consumes the receiver.
Disconnect,
}
/// Per-user publish/subscribe interface for `UserEvent`s.
///
/// Implementors must be `Send + Sync`. Both methods are synchronous.
#[async_trait]
pub trait Notifier: Send + Sync {
/// Returns a broadcast receiver associated with `user_id`.
fn subscribe(&self, user_id: Uuid) -> broadcast::Receiver<UserEvent>;
/// Publishes `event` for `user_id`. Returns nothing, so callers get no
/// delivery feedback through this interface.
fn notify(&self, user_id: Uuid, event: UserEvent);
}
/// `Notifier` implementation that keeps one `tokio::sync::broadcast` sender
/// per user in a shared `DashMap`.
pub struct InMemoryNotifier {
// User id -> broadcast sender. Wrapped in an `Arc` because the background
// task spawned in `new` also holds a handle to the same map.
channels: std::sync::Arc<DashMap<Uuid, broadcast::Sender<UserEvent>>>,
// Capacity passed to `broadcast::channel` when a user's channel is created.
channel_capacity: usize,
}
impl InMemoryNotifier {
pub fn new(config: Config) -> Self {
let channels = std::sync::Arc::new(DashMap::new());
let map_ref = channels.clone();
let interval_secs = config.notifications.gc_interval_secs;
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
loop {
interval.tick().await;
map_ref.retain(|_, sender: &mut broadcast::Sender<UserEvent>| sender.receiver_count() > 0);
}
});
Self { channels, channel_capacity: config.notifications.channel_capacity }
}
}
#[async_trait]
impl Notifier for InMemoryNotifier {
    /// Returns a receiver for `user_id`, creating the user's broadcast channel
    /// on first use.
    ///
    /// # Panics
    ///
    /// `broadcast::channel` panics if `channel_capacity` is `0`, so a new
    /// channel cannot be created with a zero capacity.
    fn subscribe(&self, user_id: Uuid) -> broadcast::Receiver<UserEvent> {
        let entry = self
            .channels
            .entry(user_id)
            .or_insert_with(|| broadcast::channel(self.channel_capacity).0);
        // Subscribe while the entry guard is still held. If the guard were
        // released first, the GC task could observe `receiver_count() == 0`
        // and evict the sender before the receiver is attached, leaving the
        // caller with a receiver that `notify` can never reach.
        entry.subscribe()
    }

    /// Sends `event` to every current subscriber of `user_id`.
    ///
    /// Does nothing if the user has no channel in the map.
    fn notify(&self, user_id: Uuid, event: UserEvent) {
        if let Some(sender) = self.channels.get(&user_id) {
            // `send` fails only when there are no active receivers; dropping
            // the event in that case is intentional (best-effort delivery).
            let _ = sender.send(event);
        }
    }
}
// Tests for the background garbage collection performed by `InMemoryNotifier`.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
// Builds a `Config` whose notification settings come from the arguments.
// Every other section is filled with fixed values; these tests only vary
// `gc_interval` (seconds) and `capacity`.
fn test_config(gc_interval: u64, capacity: usize) -> Config {
Config {
database_url: "".to_string(),
ttl_days: 30,
server: crate::config::ServerConfig {
host: "0.0.0.0".to_string(),
port: 3000,
trusted_proxies: vec!["127.0.0.1/32".parse().unwrap()],
},
auth: crate::config::AuthConfig {
jwt_secret: "".to_string(),
access_token_ttl_secs: 900,
refresh_token_ttl_days: 30,
},
rate_limit: crate::config::RateLimitConfig { per_second: 5, burst: 10, auth_per_second: 1, auth_burst: 3 },
messaging: crate::config::MessagingConfig {
max_inbox_size: 1000,
cleanup_interval_secs: 300,
batch_limit: 50,
pre_key_refill_threshold: 20,
max_pre_keys: 100,
},
// The only section these tests parameterize.
notifications: crate::config::NotificationConfig {
gc_interval_secs: gc_interval,
channel_capacity: capacity,
},
websocket: crate::config::WsConfig {
outbound_buffer_size: 32,
ack_buffer_size: 100,
ack_batch_size: 50,
ack_flush_interval_ms: 500,
},
s3: crate::config::S3Config {
bucket: "".to_string(),
region: "us-east-1".to_string(),
endpoint: None,
access_key: None,
secret_key: None,
force_path_style: false,
attachment_max_size_bytes: 52_428_800,
},
}
}
// A channel whose only receiver has been dropped is removed by the GC task.
#[tokio::test]
async fn test_notifier_gc_cleans_up_unused_channels() {
let config = test_config(1, 16);
let notifier = InMemoryNotifier::new(config);
let user_id = Uuid::new_v4();
let rx = notifier.subscribe(user_id);
// Subscribing creates the channel entry.
assert!(notifier.channels.contains_key(&user_id));
assert_eq!(notifier.channels.len(), 1);
drop(rx);
// Wait slightly longer than the 1 s GC interval so a GC pass has run.
tokio::time::sleep(Duration::from_millis(1100)).await;
assert!(!notifier.channels.contains_key(&user_id));
assert_eq!(notifier.channels.len(), 0);
}
// A channel with a live receiver is still present after the same wait.
#[tokio::test]
async fn test_notifier_gc_keeps_active_channels() {
let config = test_config(1, 16);
let notifier = InMemoryNotifier::new(config);
let user_id = Uuid::new_v4();
// Keep the receiver bound so it stays alive for the whole test.
let _rx = notifier.subscribe(user_id);
tokio::time::sleep(Duration::from_millis(1100)).await;
assert!(notifier.channels.contains_key(&user_id));
}
}