#![allow(clippy::result_large_err)]
#![allow(clippy::cast_possible_truncation)]
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::SystemTime,
};
use {
futures::Stream,
reovim_protocol::v2::{
Notification, SubscribeRequest, notification::Payload,
notification_service_server::NotificationService,
},
tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError},
tonic::{Request, Response, Status},
};
use crate::session::{ClientId, Session, SessionId, SessionRegistry, TokenRegistry};
/// Tears down per-client state after a client's notification stream has been dropped.
///
/// Steps, in order:
/// 1. revoke the client's tokens via `TokenRegistry::revoke_by_client`;
/// 2. call `presence().leave` for the client and, if it returns an entry, emit a
///    `presence_left` notification carrying that entry's display name;
/// 3. remove the client from the session.
#[cfg_attr(coverage_nightly, coverage(off))]
fn on_notification_stream_dropped(client_id: ClientId, session: &Session, tokens: &TokenRegistry) {
    tracing::info!(client_id = client_id.as_usize(), "Notification stream dropped — auto-cleanup");
    tokens.revoke_by_client(client_id);

    // Only announce the departure if the client was actually present.
    if let Some(departed) = session.presence().leave(client_id) {
        let farewell = build_presence_left_notification(client_id, &departed.display_name);
        session.emit_notification(farewell);
    }

    session.remove_client(client_id);
}
/// Stream adapter that forwards every item of `inner` unchanged and runs a
/// one-shot callback when the adapter itself is dropped.
///
/// The callback fires from `Drop`, so it runs whether the stream was exhausted
/// or abandoned early.
struct CleanupStream {
    /// The wrapped notification stream; polled on every `poll_next`.
    inner: Pin<Box<dyn Stream<Item = Result<Notification, Status>> + Send>>,
    /// Callback to run on drop; becomes `None` once it has been taken.
    cleanup: Option<Box<dyn FnOnce() + Send>>,
}

impl CleanupStream {
    /// Wraps `inner` so that `cleanup` is invoked exactly once when the wrapper is dropped.
    fn new(
        inner: Pin<Box<dyn Stream<Item = Result<Notification, Status>> + Send>>,
        cleanup: impl FnOnce() + Send + 'static,
    ) -> Self {
        let on_drop: Box<dyn FnOnce() + Send> = Box::new(cleanup);
        Self { inner, cleanup: Some(on_drop) }
    }
}

#[cfg_attr(coverage_nightly, coverage(off))]
impl Stream for CleanupStream {
    type Item = Result<Notification, Status>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Both fields are boxed, so `CleanupStream: Unpin` and `get_mut` is available.
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}

impl Drop for CleanupStream {
    fn drop(&mut self) {
        // `take` leaves `None` behind, so the callback can never run twice.
        let Some(on_drop) = self.cleanup.take() else {
            return;
        };
        on_drop();
    }
}
/// Builds a `presence_left` [`Notification`] announcing that `client_id`
/// (shown as `display_name`) has left the session.
///
/// `timestamp_ms` is wall-clock milliseconds since the UNIX epoch.
///
/// This function is reached from `CleanupStream`'s `Drop` implementation (via
/// `on_notification_stream_dropped`), so it must never panic. A panic inside
/// `drop` during unwinding aborts the process. For that reason:
/// - a system clock set before the epoch yields a timestamp of `0`;
/// - an out-of-range millisecond count saturates at `u64::MAX` instead of being
///   silently truncated.
fn build_presence_left_notification(client_id: ClientId, display_name: &str) -> Notification {
    use reovim_protocol::v2::PresenceLeftPayload;
    let timestamp_ms = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map_or(0, |elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX));
    Notification {
        event_type: "presence_left".to_string(),
        timestamp_ms,
        payload: Some(Payload::PresenceLeft(PresenceLeftPayload {
            client_id: client_id.as_usize() as u64,
            display_name: display_name.to_string(),
        })),
    }
}
/// gRPC `NotificationService` implementation that streams session notifications
/// to subscribers.
///
/// Every subscription is served from one session: `default_session_id`, which is
/// looked up in `sessions` on each request.
pub struct NotificationServiceImpl {
    /// Registry used to resolve `default_session_id` to a live session.
    sessions: Arc<SessionRegistry>,
    /// Identifier of the session all subscriptions attach to.
    default_session_id: SessionId,
    /// Token registry handed to per-client cleanup when a subscriber's stream is dropped.
    tokens: Arc<TokenRegistry>,
}

impl NotificationServiceImpl {
    /// Creates the service from the shared registries and the id of the session to serve.
    #[must_use]
    pub const fn new(
        sessions: Arc<SessionRegistry>,
        default_session_id: SessionId,
        tokens: Arc<TokenRegistry>,
    ) -> Self {
        Self { sessions, default_session_id, tokens }
    }

    /// Resolves the default session from the registry.
    ///
    /// # Errors
    ///
    /// Returns `Status::not_found("No active session")` when the registry has no
    /// entry for `default_session_id`.
    fn get_session(&self) -> Result<Arc<Session>, Status> {
        match self.sessions.get(&self.default_session_id) {
            Some(session) => Ok(session),
            None => Err(Status::not_found("No active session")),
        }
    }
}
#[tonic::async_trait]
impl NotificationService for NotificationServiceImpl {
    /// Server-streaming response type: a boxed, sendable stream of notifications.
    type SubscribeStream =
        Pin<Box<dyn Stream<Item = Result<Notification, Status>> + Send + 'static>>;

    /// Opens a notification stream on the default session.
    ///
    /// Notifications come from the session's broadcast channel. If the request's
    /// `event_types` is non-empty, only notifications whose `event_type` appears in
    /// that list are forwarded. An empty list forwards everything. When the
    /// subscriber lags behind the channel, the skipped notifications are lost and a
    /// warning is logged. The stream itself keeps running.
    ///
    /// If the request carries a [`ClientId`] extension, the returned stream is
    /// wrapped so that dropping it runs `on_notification_stream_dropped` for that
    /// client.
    ///
    /// # Errors
    ///
    /// Returns `Status::not_found` when the default session is not registered.
    async fn subscribe(
        &self,
        request: Request<SubscribeRequest>,
    ) -> Result<Response<Self::SubscribeStream>, Status> {
        // Read the extension before `into_inner` consumes the request.
        let client_id = request.extensions().get::<ClientId>().copied();
        let req = request.into_inner();
        let session = self.get_session()?;
        let rx = session.subscribe_notifications();
        let filter_types: Vec<String> = req.event_types;
        let stream = BroadcastStream::new(rx);
        let output_stream = async_stream::stream! {
            let mut stream = stream;
            while let Some(result) = futures::StreamExt::next(&mut stream).await {
                match result {
                    Ok(notification) => {
                        if filter_types.is_empty()
                            || filter_types.contains(&notification.event_type)
                        {
                            yield Ok(notification);
                        }
                    }
                    // Lagging is not fatal: log it and continue with the next notification.
                    Err(BroadcastStreamRecvError::Lagged(n)) => {
                        tracing::warn!(lagged = n, "Notification subscriber lagged behind");
                    }
                }
            }
        };

        // Attach per-client teardown only when the request identified the client.
        if let Some(cid) = client_id {
            let cleanup_session = Arc::clone(&session);
            let cleanup_tokens = Arc::clone(&self.tokens);
            let guarded = CleanupStream::new(Box::pin(output_stream), move || {
                on_notification_stream_dropped(cid, &cleanup_session, &cleanup_tokens);
            });
            Ok(Response::new(Box::pin(guarded)))
        } else {
            Ok(Response::new(Box::pin(output_stream)))
        }
    }
}
#[cfg(test)]
#[path = "notification_tests.rs"]
mod tests;