use crate::error::{Error, ErrorCode};
use crate::types::Message;
use dashmap::DashMap;
use tokio::sync::mpsc::{Sender, error::TrySendError};
use uuid::Uuid;
/// Routes outbound [`Message`]s to per-session channels, looked up by the
/// session id carried on each message.
#[derive(Default)]
pub(crate) struct MessageRegistry {
    /// Registered sessions, keyed by session id.
    inner: DashMap<Uuid, MessageSession>,
}
/// One registered session: the channel that feeds it plus the generation it
/// was registered under.
struct MessageSession {
    /// Outbound channel; [`MessageRegistry::send`] pushes into it with `try_send`.
    sender: Sender<Message>,
    /// Registration generation, compared by
    /// [`MessageRegistry::unregister_if_generation`] before removal.
    generation: u64,
}
#[allow(dead_code)]
impl MessageRegistry {
#[inline]
pub(crate) fn new() -> Self {
Self {
inner: DashMap::new(),
}
}
#[inline]
pub(crate) fn register(&self, key: Uuid, generation: u64, sender: Sender<Message>) {
self.inner
.insert(key, MessageSession { sender, generation });
}
#[inline]
pub(crate) fn unregister(&self, key: &Uuid) -> bool {
self.inner.remove(key).is_some()
}
#[inline]
pub(crate) fn unregister_if_generation(&self, key: &Uuid, generation: u64) {
self.inner
.remove_if(key, |_, current| current.generation == generation);
}
#[inline]
pub(crate) fn send(&self, message: Message) -> Result<(), Error> {
let session_id = *message.session_id().ok_or(ErrorCode::InvalidParams)?;
if let Some(entry) = self.inner.get(&session_id) {
match entry.sender.try_send(message) {
Ok(()) => Ok(()),
Err(err) => {
let err_text = err.to_string();
match err {
TrySendError::Full(_message) => {
#[cfg(feature = "tracing")]
tracing::warn!(
logger = "neva",
"Dropping SSE log message for session {}: {}",
session_id,
err_text
);
Ok(())
}
TrySendError::Closed(_message) => {
#[cfg(feature = "tracing")]
tracing::warn!(
logger = "neva",
"Failed to deliver SSE log message for session {}: {}",
session_id,
err_text
);
Err(Error::new(ErrorCode::InternalError, err_text))
}
}
}
}
} else {
Err(Error::new(ErrorCode::InvalidParams, "Sender not found"))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;
    use crate::types::notification::Notification;
    use tokio::sync::mpsc;

    /// Builds a minimal notification addressed to `session_id`.
    fn message_for(session_id: Uuid) -> Message {
        Message::Notification(Notification::new("test", None)).set_session_id(session_id)
    }

    #[test]
    fn it_creates_new_registry() {
        let registry = MessageRegistry::new();
        assert!(registry.inner.is_empty());
    }

    #[test]
    fn it_registers_and_unregisters() {
        let registry = MessageRegistry::new();
        let session_id = Uuid::new_v4();
        let (tx, _rx) = mpsc::channel(8);

        registry.register(session_id, 1, tx.clone());
        assert!(registry.inner.contains_key(&session_id));

        assert!(registry.unregister(&session_id));
        assert!(!registry.inner.contains_key(&session_id));

        assert!(!registry.unregister(&Uuid::new_v4()));
    }

    #[test]
    fn it_unregisters_only_matching_generation() {
        let registry = MessageRegistry::new();
        let session_id = Uuid::new_v4();
        let (tx1, _rx1) = mpsc::channel(8);
        let (tx2, _rx2) = mpsc::channel(8);

        registry.register(session_id, 1, tx1);
        registry.unregister_if_generation(&session_id, 2);
        assert!(registry.inner.contains_key(&session_id));

        registry.register(session_id, 2, tx2);
        registry.unregister_if_generation(&session_id, 1);
        assert!(registry.inner.contains_key(&session_id));

        registry.unregister_if_generation(&session_id, 2);
        assert!(!registry.inner.contains_key(&session_id));
    }

    #[test]
    fn it_replaces_existing_registration() {
        let registry = MessageRegistry::new();
        let session_id = Uuid::new_v4();
        let (old_tx, mut old_rx) = mpsc::channel(8);
        let (new_tx, mut new_rx) = mpsc::channel(8);

        registry.register(session_id, 1, old_tx);
        registry.register(session_id, 2, new_tx);

        assert!(registry.send(message_for(session_id)).is_ok());
        assert!(new_rx.try_recv().is_ok());
        assert!(old_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn it_sends_message() {
        let registry = MessageRegistry::new();
        let session_id = Uuid::new_v4();
        let (tx, mut rx) = mpsc::channel(8);
        registry.register(session_id, 1, tx);

        assert!(registry.send(message_for(session_id)).is_ok());

        let received = rx.recv().await.expect("message should have been delivered");
        assert_eq!(received.session_id(), Some(&session_id));
    }

    #[test]
    fn it_drops_message_when_channel_is_full() {
        let registry = MessageRegistry::new();
        let session_id = Uuid::new_v4();
        let (tx, mut rx) = mpsc::channel(1);
        registry.register(session_id, 1, tx);

        assert!(registry.send(message_for(session_id)).is_ok());
        // The second message does not fit; it is dropped but still reported as success.
        assert!(registry.send(message_for(session_id)).is_ok());

        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn it_fails_when_receiver_is_dropped() {
        let registry = MessageRegistry::new();
        let session_id = Uuid::new_v4();
        let (tx, rx) = mpsc::channel(8);
        registry.register(session_id, 1, tx);
        drop(rx);

        let send_result = registry.send(message_for(session_id));
        assert_eq!(send_result.unwrap_err().code, ErrorCode::InternalError);
        // `send` does not evict the stale session; that is left to the owner.
        assert!(registry.inner.contains_key(&session_id));
    }

    #[test]
    fn it_sends_to_nonexistent_session() {
        let registry = MessageRegistry::new();
        let send_result = registry.send(message_for(Uuid::new_v4()));
        assert!(send_result.is_err());
        assert_eq!(send_result.unwrap_err().code, ErrorCode::InvalidParams);
    }

    #[test]
    fn it_sends_message_without_session_id() {
        let registry = MessageRegistry::new();
        let test_message = Message::Notification(Notification::new("test", None));
        let send_result = registry.send(test_message);
        assert!(send_result.is_err());
        assert_eq!(send_result.unwrap_err().code, ErrorCode::InvalidParams);
    }
}