use anyhow::{Context, Result};
use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Capacity passed to `broadcast::channel` for the global channel created by
/// `NotificationManager::new` and for the per-client channel created in
/// `register_client`.
const MAX_NOTIFICATION_BUFFER: usize = 1000;
/// Payload of a notification pushed to subscribers.
///
/// Variant names are (de)serialized in `snake_case` (see the serde attribute
/// below); the same spelling is used by `matches_filter` when comparing
/// against `SubscriptionFilter::event_types`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum NotificationEvent {
// Dataset lifecycle events; each carries the dataset name that the
// `datasets` filter is matched against.
DatasetCreated { dataset: String },
DatasetUpdated {
dataset: String,
triple_count: usize,
},
DatasetDeleted { dataset: String },
// Query events carry no dataset, so a `datasets` filter never excludes them.
QueryCompleted {
query_id: String,
duration_ms: u64,
result_count: usize,
},
QueryFailed { query_id: String, error: String },
// Severity of this variant depends on the wrapped status (see `severity`).
SystemStatus { status: SystemStatus },
MetricsUpdate { metrics: SystemMetrics },
// Backup events carry a dataset name and are subject to the `datasets` filter.
BackupCompleted {
dataset: String,
backup_id: String,
size_bytes: u64,
},
BackupFailed { dataset: String, error: String },
// Severity of this variant depends on the endpoint status (see `severity`).
FederationUpdate {
endpoint: String,
status: EndpointStatus,
},
}
/// Health state carried by `NotificationEvent::SystemStatus`.
///
/// `NotificationEvent::severity` maps `Healthy` to `Info`, `Degraded` to
/// `Warning` and `Unhealthy` to `Critical`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SystemStatus {
Healthy,
Degraded { reason: String },
Unhealthy { reason: String },
}
/// Endpoint state carried by `NotificationEvent::FederationUpdate`.
///
/// `NotificationEvent::severity` maps `Available` to `Info`, `Slow` to
/// `Warning` and `Unavailable` to `Error`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EndpointStatus {
Available,
Unavailable,
Slow { latency_ms: u64 },
}
/// Metrics snapshot carried by `NotificationEvent::MetricsUpdate`.
///
/// `PartialEq` is implemented by hand for this type because of the `f64`
/// fields. Units are those suggested by the field names; nothing in this file
/// produces these values, so confirm against the code that fills them in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemMetrics {
// UTC instant attached to the snapshot.
pub timestamp: chrono::DateTime<chrono::Utc>,
pub cpu_usage_percent: f64,
pub memory_usage_mb: u64,
pub active_queries: usize,
pub queries_per_second: f64,
pub avg_query_time_ms: f64,
}
/// Field-wise equality for metrics snapshots.
///
/// Integer fields and the timestamp are compared exactly; floating-point
/// fields are considered equal when they differ by less than `f64::EPSILON`.
/// A `NaN` in any floating-point field makes the comparison return `false`,
/// even when comparing a value with itself.
impl PartialEq for SystemMetrics {
    fn eq(&self, other: &Self) -> bool {
        // Absolute-tolerance comparison shared by all floating-point fields.
        fn nearly_equal(lhs: f64, rhs: f64) -> bool {
            (lhs - rhs).abs() < f64::EPSILON
        }

        if self.timestamp != other.timestamp {
            return false;
        }
        if self.memory_usage_mb != other.memory_usage_mb {
            return false;
        }
        if self.active_queries != other.active_queries {
            return false;
        }

        nearly_equal(self.cpu_usage_percent, other.cpu_usage_percent)
            && nearly_equal(self.queries_per_second, other.queries_per_second)
            && nearly_equal(self.avg_query_time_ms, other.avg_query_time_ms)
    }
}

impl Eq for SystemMetrics {}
/// Envelope delivered to subscribers: an event plus identifying metadata.
///
/// `Notification::new` fills `id` with a freshly generated v4 UUID rendered as
/// a string and `timestamp` with the current UTC time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Notification {
pub id: String,
pub timestamp: chrono::DateTime<chrono::Utc>,
pub event: NotificationEvent,
}
impl Notification {
pub fn new(event: NotificationEvent) -> Self {
Self {
id: Uuid::new_v4().to_string(),
timestamp: chrono::Utc::now(),
event,
}
}
}
/// Criteria a subscriber uses to restrict which events it receives.
///
/// Every field is optional; `None` means "no restriction" for that criterion
/// (see `NotificationEvent::matches_filter`). The default filter therefore
/// accepts every event.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SubscriptionFilter {
// Accepted event type names in snake_case, e.g. "dataset_created".
pub event_types: Option<Vec<String>>,
// Accepted dataset names; only applied to events that carry a dataset.
pub datasets: Option<Vec<String>>,
// Events whose severity is below this value are rejected.
pub min_severity: Option<NotificationSeverity>,
}
/// Severity level assigned to an event by `NotificationEvent::severity`.
///
/// The derived ordering follows declaration order, so
/// `Info < Warning < Error < Critical`; `matches_filter` relies on this when
/// applying `SubscriptionFilter::min_severity`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "snake_case")]
pub enum NotificationSeverity {
Info = 0,
Warning = 1,
Error = 2,
Critical = 3,
}
impl NotificationEvent {
pub fn severity(&self) -> NotificationSeverity {
match self {
Self::DatasetCreated { .. } | Self::DatasetUpdated { .. } => NotificationSeverity::Info,
Self::DatasetDeleted { .. } => NotificationSeverity::Warning,
Self::QueryCompleted { .. } => NotificationSeverity::Info,
Self::QueryFailed { .. } => NotificationSeverity::Warning,
Self::SystemStatus { status } => match status {
SystemStatus::Healthy => NotificationSeverity::Info,
SystemStatus::Degraded { .. } => NotificationSeverity::Warning,
SystemStatus::Unhealthy { .. } => NotificationSeverity::Critical,
},
Self::MetricsUpdate { .. } => NotificationSeverity::Info,
Self::BackupCompleted { .. } => NotificationSeverity::Info,
Self::BackupFailed { .. } => NotificationSeverity::Error,
Self::FederationUpdate { status, .. } => match status {
EndpointStatus::Available => NotificationSeverity::Info,
EndpointStatus::Unavailable => NotificationSeverity::Error,
EndpointStatus::Slow { .. } => NotificationSeverity::Warning,
},
}
}
pub fn matches_filter(&self, filter: &SubscriptionFilter) -> bool {
if let Some(min_severity) = filter.min_severity {
if self.severity() < min_severity {
return false;
}
}
if let Some(ref event_types) = filter.event_types {
let event_type = match self {
Self::DatasetCreated { .. } => "dataset_created",
Self::DatasetUpdated { .. } => "dataset_updated",
Self::DatasetDeleted { .. } => "dataset_deleted",
Self::QueryCompleted { .. } => "query_completed",
Self::QueryFailed { .. } => "query_failed",
Self::SystemStatus { .. } => "system_status",
Self::MetricsUpdate { .. } => "metrics_update",
Self::BackupCompleted { .. } => "backup_completed",
Self::BackupFailed { .. } => "backup_failed",
Self::FederationUpdate { .. } => "federation_update",
};
if !event_types.contains(&event_type.to_string()) {
return false;
}
}
if let Some(ref datasets) = filter.datasets {
let dataset = match self {
Self::DatasetCreated { dataset }
| Self::DatasetUpdated { dataset, .. }
| Self::DatasetDeleted { dataset }
| Self::BackupCompleted { dataset, .. }
| Self::BackupFailed { dataset, .. } => Some(dataset),
_ => None,
};
if let Some(dataset) = dataset {
if !datasets.contains(dataset) {
return false;
}
}
}
true
}
}
/// Registry entry kept for each registered client.
///
/// NOTE(review): within this file only `filter` is ever written after
/// construction (by the `update_filter` command); `id` and `tx` are stored
/// but not read here — confirm whether other code uses them.
#[derive(Debug)]
struct ClientConnection {
// Same value as the key under which the entry is stored in the registry.
id: String,
filter: SubscriptionFilter,
// Sender half of a per-client channel created in `register_client`.
tx: broadcast::Sender<Notification>,
}
/// Central hub that fans notifications out to subscribers and keeps a bounded
/// history plus running statistics.
pub struct NotificationManager {
// Registered clients keyed by client id.
clients: Arc<DashMap<String, ClientConnection>>,
// Sender every registered client subscribes to.
global_tx: broadcast::Sender<Notification>,
// Most recent notifications, oldest first; trimmed to `max_history` entries.
history: Arc<RwLock<Vec<Notification>>>,
max_history: usize,
stats: Arc<RwLock<NotificationStats>>,
}
/// Counters maintained by `NotificationManager`.
#[derive(Debug, Default)]
pub struct NotificationStats {
// Incremented once per `broadcast` call.
pub total_notifications: u64,
// Incremented once per `register_client` call; never decremented.
pub total_clients: u64,
// Number of entries currently in the client registry.
pub active_clients: usize,
// Per-key notification counts; the key is chosen by `broadcast`.
pub notifications_by_type: std::collections::HashMap<String, u64>,
// Incremented when sending on the global channel returns an error.
pub dropped_notifications: u64,
}
impl NotificationManager {
pub fn new() -> Self {
let (global_tx, _) = broadcast::channel(MAX_NOTIFICATION_BUFFER);
Self {
clients: Arc::new(DashMap::new()),
global_tx,
history: Arc::new(RwLock::new(Vec::new())),
max_history: 100,
stats: Arc::new(RwLock::new(NotificationStats::default())),
}
}
/// Creates a manager with a custom history length and channel capacity.
///
/// * `max_history` - maximum number of notifications retained for
///   `get_history`; `0` disables history retention.
/// * `buffer_size` - capacity of the global broadcast channel. A value of
///   `0` is raised to `1`, because `broadcast::channel` panics on a zero
///   capacity.
pub fn with_config(max_history: usize, buffer_size: usize) -> Self {
    // Guard against the documented panic of `broadcast::channel(0)`.
    let capacity = buffer_size.max(1);
    let (global_tx, _) = broadcast::channel(capacity);
    Self {
        clients: Arc::new(DashMap::new()),
        global_tx,
        history: Arc::new(RwLock::new(Vec::new())),
        max_history,
        stats: Arc::new(RwLock::new(NotificationStats::default())),
    }
}
pub async fn broadcast(&self, event: NotificationEvent) -> Result<()> {
let notification = Notification::new(event.clone());
{
let mut history = self.history.write().await;
history.push(notification.clone());
if history.len() > self.max_history {
history.remove(0);
}
}
{
let mut stats = self.stats.write().await;
stats.total_notifications += 1;
*stats
.notifications_by_type
.entry(format!("{:?}", event))
.or_insert(0) += 1;
}
match self.global_tx.send(notification.clone()) {
Ok(receiver_count) => {
debug!("Broadcasted notification to {} clients", receiver_count);
}
Err(e) => {
warn!("Failed to broadcast notification: {}", e);
let mut stats = self.stats.write().await;
stats.dropped_notifications += 1;
}
}
Ok(())
}
/// Registers a new client and subscribes it to the global channel.
///
/// Returns the generated client id (a v4 UUID string) together with a
/// receiver on the global broadcast channel. The receiver yields every
/// notification sent after this call; applying `filter` is left to the
/// consumer of the receiver.
///
/// # Errors
///
/// Currently always returns `Ok`.
pub async fn register_client(
    &self,
    filter: SubscriptionFilter,
) -> Result<(String, broadcast::Receiver<Notification>)> {
    let client_id = Uuid::new_v4().to_string();
    let rx = self.global_tx.subscribe();

    // `ClientConnection` stores a sender of its own; nothing in this function
    // needs the matching receiver, so it is discarded right away.
    let (tx, _) = broadcast::channel(MAX_NOTIFICATION_BUFFER);
    let connection = ClientConnection {
        id: client_id.clone(),
        filter,
        tx,
    };
    self.clients.insert(client_id.clone(), connection);

    {
        let mut stats = self.stats.write().await;
        stats.total_clients += 1;
        stats.active_clients = self.clients.len();
    }

    info!("Registered new notification client: {}", client_id);
    Ok((client_id, rx))
}
/// Removes `client_id` from the registry and refreshes the active-client
/// count. Unknown ids are ignored.
///
/// # Errors
///
/// Currently always returns `Ok(())`.
pub async fn unregister_client(&self, client_id: &str) -> Result<()> {
    self.clients.remove(client_id);

    let mut guard = self.stats.write().await;
    guard.active_clients = self.clients.len();
    drop(guard);

    info!("Unregistered notification client: {}", client_id);
    Ok(())
}
/// Returns retained notifications, newest first.
///
/// With `Some(n)` at most `n` entries are returned; with `None` the whole
/// retained history is returned.
pub async fn get_history(&self, limit: Option<usize>) -> Vec<Notification> {
    let entries = self.history.read().await;
    let newest_first = entries.iter().rev();
    match limit {
        Some(max) => newest_first.take(max).cloned().collect(),
        None => newest_first.cloned().collect(),
    }
}
/// Returns a snapshot of the current statistics.
///
/// `active_clients` is taken from the live client registry rather than from
/// the stored counter.
pub async fn get_statistics(&self) -> NotificationStats {
    let guard = self.stats.read().await;
    let NotificationStats {
        total_notifications,
        total_clients,
        notifications_by_type,
        dropped_notifications,
        ..
    } = &*guard;

    NotificationStats {
        total_notifications: *total_notifications,
        total_clients: *total_clients,
        active_clients: self.clients.len(),
        notifications_by_type: notifications_by_type.clone(),
        dropped_notifications: *dropped_notifications,
    }
}
/// Drives one WebSocket connection until it closes.
///
/// The client is registered, greeted with a `welcome` message, and then
/// served by a loop that concurrently
/// * forwards notifications from the global channel that pass the client's
///   filter, and
/// * reacts to frames from the client (text commands, ping, close).
///
/// The filter applied to each notification is the one currently stored for
/// this client in the registry, so an `update_filter` command takes effect
/// for subsequent notifications; `filter` is used as the fallback when the
/// registry has no entry for the client.
///
/// The client is unregistered on every exit path once registration has
/// succeeded.
///
/// # Errors
///
/// Returns an error if registration fails, if a notification cannot be
/// serialized, or if unregistering fails.
pub async fn handle_websocket(
    self: Arc<Self>,
    mut socket: WebSocket,
    filter: SubscriptionFilter,
) -> Result<()> {
    let (client_id, mut rx) = self
        .register_client(filter.clone())
        .await
        .context("Failed to register client")?;
    info!("WebSocket client connected: {}", client_id);

    let welcome = serde_json::json!({
        "type": "welcome",
        "client_id": client_id,
        "message": "Connected to OxiRS Fuseki real-time notifications",
    });
    // `Value`'s `Display` renders compact JSON and cannot fail, so there is no
    // early-return path here that could skip the unregister step.
    if let Err(e) = socket
        .send(Message::Text(welcome.to_string().into()))
        .await
    {
        error!("Failed to send welcome message: {}", e);
        self.unregister_client(&client_id).await?;
        return Ok(());
    }

    // The loop yields the result to return after cleanup instead of using `?`,
    // which would leave the client registered.
    let outcome: Result<()> = loop {
        tokio::select! {
            result = rx.recv() => {
                match result {
                    Ok(notification) => {
                        // Consult the registry so filter updates are honoured.
                        // The map guard is released at the end of this
                        // statement, before any `.await`.
                        let wanted = match self.clients.get(client_id.as_str()) {
                            Some(conn) => notification.event.matches_filter(&conn.filter),
                            None => notification.event.matches_filter(&filter),
                        };
                        if !wanted {
                            continue;
                        }
                        let json = match serde_json::to_string(&notification) {
                            Ok(json) => json,
                            Err(e) => break Err(e.into()),
                        };
                        if let Err(e) = socket.send(Message::Text(json.into())).await {
                            error!("Failed to send notification to client {}: {}", client_id, e);
                            break Ok(());
                        }
                    }
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        warn!("Client {} lagged behind by {} notifications", client_id, n);
                        let lag_msg = serde_json::json!({
                            "type": "lag_warning",
                            "lagged_by": n,
                        });
                        // Best effort: a failed warning is not fatal here.
                        let _ = socket.send(Message::Text(lag_msg.to_string().into())).await;
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        info!("Notification channel closed for client {}", client_id);
                        break Ok(());
                    }
                }
            }
            message = socket.recv() => {
                match message {
                    Some(Ok(Message::Text(text))) => {
                        if let Err(e) = self.handle_client_message(&client_id, &text).await {
                            error!("Failed to handle client message: {}", e);
                        }
                    }
                    Some(Ok(Message::Close(_))) | None => {
                        info!("Client {} disconnected", client_id);
                        break Ok(());
                    }
                    Some(Ok(Message::Ping(data))) => {
                        if let Err(e) = socket.send(Message::Pong(data)).await {
                            error!("Failed to send pong: {}", e);
                            break Ok(());
                        }
                    }
                    Some(Err(e)) => {
                        error!("WebSocket error for client {}: {}", client_id, e);
                        break Ok(());
                    }
                    _ => {}
                }
            }
        }
    };

    self.unregister_client(&client_id).await?;
    info!("WebSocket client disconnected: {}", client_id);
    outcome
}
/// Parses and executes a JSON command sent by a client.
///
/// The message must be an object with a `command` string and an optional
/// `params` value. Supported commands:
/// * `ping` - logged only;
/// * `get_history` - loads up to `params.limit` retained notifications;
/// * `update_filter` - replaces the filter stored for `client_id` with
///   `params` deserialized as a [`SubscriptionFilter`].
///
/// Unknown commands, invalid filters and unknown client ids are logged and
/// otherwise ignored.
///
/// # Errors
///
/// Returns an error when `message` is not a valid command object.
async fn handle_client_message(&self, client_id: &str, message: &str) -> Result<()> {
    #[derive(Deserialize)]
    struct ClientCommand {
        command: String,
        #[serde(default)]
        params: serde_json::Value,
    }

    let cmd: ClientCommand =
        serde_json::from_str(message).context("Failed to parse client command")?;

    match cmd.command.as_str() {
        "ping" => {
            debug!("Received ping from client {}", client_id);
        }
        "get_history" => {
            // Saturate instead of truncating when the requested limit does not
            // fit in `usize` (possible on 32-bit targets).
            let limit: Option<usize> = cmd
                .params
                .get("limit")
                .and_then(|v| v.as_u64())
                .map(|v| v.min(usize::MAX as u64) as usize);
            let history = self.get_history(limit).await;
            // NOTE(review): this method has no handle to the socket, so the
            // loaded history is not delivered to the client from here.
            debug!(
                "Fetched {} historical notifications for client {}",
                history.len(),
                client_id
            );
        }
        "update_filter" => match serde_json::from_value::<SubscriptionFilter>(cmd.params) {
            Ok(new_filter) => match self.clients.get_mut(client_id) {
                Some(mut client) => {
                    client.filter = new_filter;
                    info!("Updated filter for client {}", client_id);
                }
                None => {
                    warn!("Cannot update filter for unknown client {}", client_id);
                }
            },
            Err(e) => {
                warn!("Ignoring invalid filter from client {}: {}", client_id, e);
            }
        },
        _ => {
            warn!("Unknown command from client {}: {}", client_id, cmd.command);
        }
    }

    Ok(())
}
}
impl Default for NotificationManager {
fn default() -> Self {
Self::new()
}
}
// Unit tests for notification construction, severity mapping, filter
// matching and the manager's history / statistics / registry bookkeeping.
#[cfg(test)]
mod tests {
use super::*;
// A new notification gets a non-empty id and a timestamp that is not in the future.
#[tokio::test]
async fn test_notification_creation() {
let event = NotificationEvent::DatasetCreated {
dataset: "test".to_string(),
};
let notification = Notification::new(event);
assert!(!notification.id.is_empty());
assert!(notification.timestamp <= chrono::Utc::now());
}
// Spot-checks one Info, one Error and one Critical mapping of `severity`.
#[tokio::test]
async fn test_event_severity() {
let info_event = NotificationEvent::QueryCompleted {
query_id: "q1".to_string(),
duration_ms: 100,
result_count: 10,
};
assert_eq!(info_event.severity(), NotificationSeverity::Info);
let error_event = NotificationEvent::BackupFailed {
dataset: "test".to_string(),
error: "disk full".to_string(),
};
assert_eq!(error_event.severity(), NotificationSeverity::Error);
let critical_event = NotificationEvent::SystemStatus {
status: SystemStatus::Unhealthy {
reason: "out of memory".to_string(),
},
};
assert_eq!(critical_event.severity(), NotificationSeverity::Critical);
}
// Exercises the default filter, the dataset criterion and the severity floor.
#[tokio::test]
async fn test_filter_matching() {
let event = NotificationEvent::DatasetCreated {
dataset: "test".to_string(),
};
// Default filter has no criteria and accepts the event.
let filter = SubscriptionFilter::default();
assert!(event.matches_filter(&filter));
// Matching dataset name is accepted.
let filter = SubscriptionFilter {
datasets: Some(vec!["test".to_string()]),
..Default::default()
};
assert!(event.matches_filter(&filter));
// Non-matching dataset name is rejected.
let filter = SubscriptionFilter {
datasets: Some(vec!["other".to_string()]),
..Default::default()
};
assert!(!event.matches_filter(&filter));
// An Info event passes a floor of Info.
let filter = SubscriptionFilter {
min_severity: Some(NotificationSeverity::Info),
..Default::default()
};
assert!(event.matches_filter(&filter));
// An Info event is rejected by a floor of Critical.
let filter = SubscriptionFilter {
min_severity: Some(NotificationSeverity::Critical),
..Default::default()
};
assert!(!event.matches_filter(&filter));
}
// Broadcasting without subscribers still updates statistics and history.
#[tokio::test]
async fn test_notification_manager() {
let manager = NotificationManager::new();
let event = NotificationEvent::DatasetCreated {
dataset: "test".to_string(),
};
manager.broadcast(event).await.unwrap();
let stats = manager.get_statistics().await;
assert_eq!(stats.total_notifications, 1);
let history = manager.get_history(None).await;
assert_eq!(history.len(), 1);
}
// Registering and unregistering a client is reflected in `active_clients`.
#[tokio::test]
async fn test_client_registration() {
let manager = NotificationManager::new();
let filter = SubscriptionFilter::default();
let (client_id, _rx) = manager.register_client(filter).await.unwrap();
let stats = manager.get_statistics().await;
assert_eq!(stats.active_clients, 1);
manager.unregister_client(&client_id).await.unwrap();
let stats = manager.get_statistics().await;
assert_eq!(stats.active_clients, 0);
}
// History is capped at `max_history` (10 here) and honours an explicit limit.
#[tokio::test]
async fn test_notification_history() {
let manager = NotificationManager::with_config(10, 100);
for i in 0..15 {
let event = NotificationEvent::DatasetCreated {
dataset: format!("test{}", i),
};
manager.broadcast(event).await.unwrap();
}
let history = manager.get_history(None).await;
assert_eq!(history.len(), 10);
let limited = manager.get_history(Some(5)).await;
assert_eq!(limited.len(), 5);
}
}