use axum::{
extract::{Query, State},
response::sse::{Event, KeepAlive, Sse},
routing::get,
Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::broadcast;
/// Shared hub that records the latest state string per data type and fans
/// state changes out to every subscribed event-source stream.
///
/// Clones share the same state map (through the `Arc`) and the same broadcast
/// channel (through the cloned sender).
#[derive(Clone)]
pub struct EventSourceManager {
    /// Latest known state per data type name (e.g. `"Email"`).
    states: Arc<RwLock<HashMap<String, String>>>,
    /// Sending half of the broadcast channel; receivers are created from it.
    tx: broadcast::Sender<StateChange>,
}
/// Payload of a `state` event: the data types that changed and their new states.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct StateChange {
    /// Data type name (e.g. `"Email"`) mapped to its new opaque state string.
    pub changed: HashMap<String, String>,
}
/// Serializable hint describing where a client can open an event source and,
/// optionally, which data types it should subscribe to.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EventSourcePushHint {
    /// URL of the event-source endpoint.
    pub url: String,
    /// Data type names of interest; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub types: Option<Vec<String>>,
}
/// Query-string parameters accepted by the `/eventsource` endpoint.
///
/// NOTE(review): RFC 8620 §7.3 defines `closeafter` as the literal `state` or
/// `no`, whereas here it is a number of seconds; a JMAP client sending
/// `closeafter=no` will fail to deserialize and be rejected. Confirm this
/// deviation is intended.
#[derive(Debug, Deserialize)]
pub struct EventSourceQuery {
    /// Comma-separated list of data type names to receive.
    #[serde(default)]
    pub types: Option<String>,
    /// Number of seconds after which the stream should be closed.
    #[serde(default)]
    pub closeafter: Option<u64>,
    /// Ping interval in seconds.
    #[serde(default)]
    pub ping: Option<u64>,
}
impl EventSourceManager {
    /// Number of state changes buffered for each subscriber before a slow one
    /// starts lagging.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates a manager with no recorded states and no subscribers.
    pub fn new() -> Self {
        // The initial receiver is not needed: receivers are created on demand
        // from the sender in `subscribe`.
        let (tx, _rx) = broadcast::channel(Self::CHANNEL_CAPACITY);
        Self {
            tx,
            states: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records `new_state` as the latest state of `data_type` and broadcasts a
    /// [`StateChange`] containing only that entry to every current subscriber.
    ///
    /// With no subscribers the broadcast is dropped, but the state is still
    /// recorded and visible through [`get_state`](Self::get_state).
    pub fn notify_change(&self, data_type: String, new_state: String) {
        // Recover from poisoning instead of silently skipping the update: the
        // map is only touched by single `insert`/`get` calls here, so a panic in
        // another holder cannot leave it half-modified.
        let mut states = self
            .states
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        states.insert(data_type.clone(), new_state.clone());

        let state_change = StateChange {
            changed: HashMap::from([(data_type, new_state)]),
        };
        // Send while the write guard is still held so concurrent notifications
        // are broadcast in the same order they were stored; otherwise a
        // subscriber could receive an older state last. `send` is non-blocking
        // and only fails when there are no receivers, which is fine to ignore.
        let _ = self.tx.send(state_change);
    }

    /// Returns the most recently recorded state for `data_type`, if any.
    pub fn get_state(&self, data_type: &str) -> Option<String> {
        let states = self
            .states
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        states.get(data_type).cloned()
    }

    /// Creates a receiver that sees every change broadcast after this call.
    fn subscribe(&self) -> broadcast::Receiver<StateChange> {
        self.tx.subscribe()
    }
}
impl Default for EventSourceManager {
fn default() -> Self {
Self::new()
}
}
/// Builds a router that serves the SSE stream at `GET /eventsource`.
///
/// The router is typed over [`EventSourceManager`] state, which the caller
/// must supply (e.g. with `Router::with_state`) before serving it.
pub fn eventsource_routes() -> Router<EventSourceManager> {
    let endpoint = get(eventsource_handler);
    Router::new().route("/eventsource", endpoint)
}
async fn eventsource_handler(
Query(params): Query<EventSourceQuery>,
State(manager): State<EventSourceManager>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let types_filter: Option<Vec<String>> = params
.types
.map(|t| t.split(',').map(|s| s.trim().to_string()).collect());
let mut rx = manager.subscribe();
let close_after = params.closeafter.map(Duration::from_secs);
let ping_interval = params
.ping
.map(Duration::from_secs)
.unwrap_or(Duration::from_secs(30));
let stream = async_stream::stream! {
let start_time = tokio::time::Instant::now();
loop {
if let Some(timeout) = close_after {
if start_time.elapsed() >= timeout {
break;
}
}
tokio::select! {
result = rx.recv() => {
match result {
Ok(state_change) => {
let filtered_changes: HashMap<String, String> = if let Some(ref filter) = types_filter {
state_change.changed.into_iter()
.filter(|(k, _)| filter.contains(k))
.collect()
} else {
state_change.changed
};
if !filtered_changes.is_empty() {
let event_data = StateChange { changed: filtered_changes };
if let Ok(json) = serde_json::to_string(&event_data) {
yield Ok(Event::default()
.event("state")
.data(json));
}
}
}
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Ok(Event::default()
.event("error")
.data("Client lagged behind"));
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
_ = tokio::time::sleep(ping_interval) => {
yield Ok(Event::default().comment("ping"));
}
}
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_source_manager_new() {
        let manager = EventSourceManager::new();
        assert!(manager.get_state("Email").is_none());
    }

    #[test]
    fn test_notify_change() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "state1".to_string());
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_notify_multiple_changes() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "state1".to_string());
        manager.notify_change("Mailbox".to_string(), "state2".to_string());
        manager.notify_change("Thread".to_string(), "state3".to_string());
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
        assert_eq!(manager.get_state("Mailbox"), Some("state2".to_string()));
        assert_eq!(manager.get_state("Thread"), Some("state3".to_string()));
    }

    #[test]
    fn test_state_update() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "state1".to_string());
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
        manager.notify_change("Email".to_string(), "state2".to_string());
        assert_eq!(manager.get_state("Email"), Some("state2".to_string()));
    }

    #[test]
    fn test_subscribe() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();
        manager.notify_change("Email".to_string(), "state1".to_string());
        let change = rx.try_recv().unwrap();
        assert_eq!(change.changed.get("Email"), Some(&"state1".to_string()));
    }

    #[test]
    fn test_multiple_subscribers() {
        let manager = EventSourceManager::new();
        let mut rx1 = manager.subscribe();
        let mut rx2 = manager.subscribe();
        manager.notify_change("Email".to_string(), "state1".to_string());
        let change1 = rx1.try_recv().unwrap();
        let change2 = rx2.try_recv().unwrap();
        assert_eq!(change1.changed.get("Email"), Some(&"state1".to_string()));
        assert_eq!(change2.changed.get("Email"), Some(&"state1".to_string()));
    }

    #[test]
    fn test_state_change_serialization() {
        let mut changed = HashMap::new();
        changed.insert("Email".to_string(), "state123".to_string());
        changed.insert("Mailbox".to_string(), "state456".to_string());
        let state_change = StateChange { changed };
        let json = serde_json::to_string(&state_change).unwrap();
        assert!(json.contains("Email"));
        assert!(json.contains("state123"));
        // The wire format must round-trip losslessly.
        let back: StateChange = serde_json::from_str(&json).unwrap();
        assert_eq!(back.changed, state_change.changed);
    }

    #[test]
    fn test_push_subscription_serialization() {
        let subscription = EventSourcePushHint {
            url: "https://push.example.com/abc123".to_string(),
            types: Some(vec!["Email".to_string(), "Mailbox".to_string()]),
        };
        let json = serde_json::to_string(&subscription).unwrap();
        assert!(json.contains("push.example.com"));
        assert!(json.contains("types"));
    }

    #[test]
    fn test_event_source_manager_default() {
        let manager = EventSourceManager::default();
        assert!(manager.get_state("any").is_none());
    }

    #[test]
    fn test_event_source_manager_clone() {
        let manager1 = EventSourceManager::new();
        manager1.notify_change("Email".to_string(), "state1".to_string());
        let manager2 = manager1.clone();
        assert_eq!(manager2.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_get_nonexistent_state() {
        let manager = EventSourceManager::new();
        assert_eq!(manager.get_state("NonExistent"), None);
    }

    #[test]
    fn test_notify_empty_state() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "".to_string());
        assert_eq!(manager.get_state("Email"), Some("".to_string()));
    }

    #[test]
    fn test_subscribe_before_notify() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();
        assert!(rx.try_recv().is_err());
        manager.notify_change("Email".to_string(), "state1".to_string());
        assert!(rx.try_recv().is_ok());
    }

    #[test]
    fn test_subscribe_after_notify() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "state1".to_string());
        let mut rx = manager.subscribe();
        assert!(rx.try_recv().is_err());
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_multiple_data_types() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "email_state".to_string());
        manager.notify_change("Mailbox".to_string(), "mailbox_state".to_string());
        manager.notify_change("Thread".to_string(), "thread_state".to_string());
        manager.notify_change("Identity".to_string(), "identity_state".to_string());
        assert_eq!(manager.get_state("Email"), Some("email_state".to_string()));
        assert_eq!(
            manager.get_state("Mailbox"),
            Some("mailbox_state".to_string())
        );
        assert_eq!(
            manager.get_state("Thread"),
            Some("thread_state".to_string())
        );
        assert_eq!(
            manager.get_state("Identity"),
            Some("identity_state".to_string())
        );
    }

    #[test]
    fn test_state_change_empty_changed() {
        let state_change = StateChange {
            changed: HashMap::new(),
        };
        let json = serde_json::to_string(&state_change).unwrap();
        assert!(json.contains("changed"));
    }

    #[test]
    fn test_push_subscription_without_types() {
        let subscription = EventSourcePushHint {
            url: "https://push.example.com/def456".to_string(),
            types: None,
        };
        let json = serde_json::to_string(&subscription).unwrap();
        assert!(!json.contains("types"));
    }

    #[test]
    fn test_sequential_notifications_all_delivered() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();
        for i in 0..10 {
            manager.notify_change(format!("Type{}", i), format!("state{}", i));
        }
        let mut received = 0;
        while rx.try_recv().is_ok() {
            received += 1;
        }
        // Well under the channel capacity, so nothing may be dropped.
        assert_eq!(received, 10);
    }

    #[test]
    fn test_state_persistence_across_notifications() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "state1".to_string());
        manager.notify_change("Mailbox".to_string(), "state2".to_string());
        manager.notify_change("Email".to_string(), "state3".to_string());
        assert_eq!(manager.get_state("Email"), Some("state3".to_string()));
        assert_eq!(manager.get_state("Mailbox"), Some("state2".to_string()));
    }

    #[test]
    fn test_subscriber_receives_only_new_changes() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "old_state".to_string());
        let mut rx = manager.subscribe();
        manager.notify_change("Email".to_string(), "new_state".to_string());
        let change = rx.try_recv().unwrap();
        assert_eq!(change.changed.get("Email"), Some(&"new_state".to_string()));
    }

    #[test]
    fn test_broadcast_channel_capacity() {
        const SENT: u64 = 200;
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();
        for i in 0..SENT {
            manager.notify_change(format!("Type{}", i), format!("state{}", i));
        }
        let mut received: u64 = 0;
        let mut skipped: u64 = 0;
        loop {
            match rx.try_recv() {
                Ok(_) => received += 1,
                Err(broadcast::error::TryRecvError::Lagged(n)) => skipped += n,
                Err(_) => break,
            }
        }
        // Overflowing the channel must make the idle receiver lag...
        assert!(skipped > 0, "receiver should have lagged");
        // ...after which it still drains the retained tail, with every sent
        // message accounted for exactly once.
        assert!(received > 0);
        assert_eq!(received + skipped, SENT);
    }

    #[test]
    fn test_state_change_deserialization() {
        let json = r#"{"changed":{"Email":"state1","Mailbox":"state2"}}"#;
        let state_change: StateChange = serde_json::from_str(json).unwrap();
        assert_eq!(
            state_change.changed.get("Email"),
            Some(&"state1".to_string())
        );
        assert_eq!(
            state_change.changed.get("Mailbox"),
            Some(&"state2".to_string())
        );
    }

    #[test]
    fn test_push_subscription_deserialization() {
        let json = r#"{"url":"https://example.com","types":["Email"]}"#;
        let subscription: EventSourcePushHint = serde_json::from_str(json).unwrap();
        assert_eq!(subscription.url, "https://example.com");
        assert_eq!(subscription.types, Some(vec!["Email".to_string()]));
    }
}