use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use crate::notifications::legacy::{
LegacyNotificationChannel, SubscriptionAuthorizer, PROTOCOL_GREETING,
};
use crate::storage::StorageEvent;
use tokio::sync::broadcast::Receiver;
/// A frame emitted by the driver loop on its outbound queue.
///
/// The consumer of the outbound receiver is expected to translate these into
/// whatever the underlying transport uses (presumably WebSocket frames — the
/// transport itself is not part of this file).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutboundFrame {
/// A text payload: the protocol greeting, a reply to a client line, a
/// notification line for a matching storage event, or an empty string used
/// as the heartbeat.
Text(String),
/// A request to close the connection, carrying a human-readable reason
/// (e.g. `unknown opcode: <token>` for an unrecognised client line).
Close(String),
}
/// Builder-style wrapper that turns a [`LegacyNotificationChannel`] into a
/// pair of message queues plus a future driving the session (see `split`).
pub struct LegacyWsDriver {
// The protocol state machine the driver loop operates on.
channel: LegacyNotificationChannel,
// Buffer size of the outbound frame queue created by `split`.
outbound_cap: usize,
}
impl LegacyWsDriver {
    /// Outbound queue capacity used unless overridden with
    /// [`with_outbound_capacity`](Self::with_outbound_capacity).
    const DEFAULT_OUTBOUND_CAPACITY: usize = 256;

    /// Capacity of the inbound (client line) queue created by [`split`](Self::split).
    const INBOUND_CAPACITY: usize = 64;

    /// Creates a driver around a fresh [`LegacyNotificationChannel`] fed by
    /// `events`, with the default outbound queue capacity.
    pub fn new(events: Receiver<StorageEvent>) -> Self {
        Self {
            channel: LegacyNotificationChannel::new(events),
            outbound_cap: Self::DEFAULT_OUTBOUND_CAPACITY,
        }
    }

    /// Sets the buffer size of the outbound frame queue.
    ///
    /// A capacity of `0` is raised to `1` when the queue is created, because
    /// `tokio::sync::mpsc::channel` panics on a zero-sized buffer.
    pub fn with_outbound_capacity(mut self, cap: usize) -> Self {
        self.outbound_cap = cap;
        self
    }

    /// Forwards the heartbeat interval to the underlying channel.
    pub fn with_heartbeat(mut self, interval: Duration) -> Self {
        self.channel = self.channel.with_heartbeat(interval);
        self
    }

    /// Forwards the subscription authorizer to the underlying channel.
    pub fn with_authorizer(mut self, authorizer: Arc<dyn SubscriptionAuthorizer>) -> Self {
        self.channel = self.channel.with_authorizer(authorizer);
        self
    }

    /// Forwards the server origin to the underlying channel.
    pub fn with_server_origin(mut self, origin: String) -> Self {
        self.channel = self.channel.with_server_origin(origin);
        self
    }

    /// Forwards the (optional) WebID to the underlying channel.
    pub fn with_web_id(mut self, web_id: Option<String>) -> Self {
        self.channel = self.channel.with_web_id(web_id);
        self
    }

    /// Consumes the driver and returns the three pieces a transport needs:
    ///
    /// 1. a sender for lines received from the client,
    /// 2. a receiver for [`OutboundFrame`]s to deliver to the client,
    /// 3. the future that runs the session loop.
    ///
    /// The future is lazy: nothing is sent or processed until it is polled
    /// (e.g. spawned or awaited).
    pub fn split(
        self,
    ) -> (
        mpsc::Sender<String>,
        mpsc::Receiver<OutboundFrame>,
        impl std::future::Future<Output = ()> + Send,
    ) {
        let (in_tx, in_rx) = mpsc::channel::<String>(Self::INBOUND_CAPACITY);
        // `mpsc::channel` panics on a zero capacity; clamp so a misconfigured
        // builder degrades to the smallest valid queue instead of crashing.
        let (out_tx, out_rx) = mpsc::channel::<OutboundFrame>(self.outbound_cap.max(1));
        let fut = run_loop(self.channel, in_rx, out_tx);
        (in_tx, out_rx, fut)
    }
}
async fn run_loop(
mut chan: LegacyNotificationChannel,
mut inbound: mpsc::Receiver<String>,
outbound: mpsc::Sender<OutboundFrame>,
) {
if outbound
.send(OutboundFrame::Text(PROTOCOL_GREETING.to_string()))
.await
.is_err()
{
return;
}
let heartbeat = chan.heartbeat_interval();
let mut ticker = tokio::time::interval(heartbeat);
ticker.tick().await;
loop {
tokio::select! {
maybe_line = inbound.recv() => {
let Some(line) = maybe_line else {
return;
};
for frame in handle_line(&mut chan, &line) {
if outbound.send(frame).await.is_err() {
return;
}
}
}
maybe_event = chan.next_event() => {
let Some(event) = maybe_event else { return; };
let uri = match &event {
StorageEvent::Created(p)
| StorageEvent::Updated(p)
| StorageEvent::Deleted(p) => p.clone(),
};
if chan.matches_subscription(&uri) {
if let Some(line) = LegacyNotificationChannel::to_legacy_line(&event) {
if outbound.try_send(OutboundFrame::Text(line)).is_err() {
tracing::warn!(
target: "solid_pod_rs::legacy_notifications",
"outbound queue saturated, dropping pub frame"
);
}
}
}
}
_ = ticker.tick() => {
if outbound
.send(OutboundFrame::Text(String::new()))
.await
.is_err()
{
return;
}
}
}
}
}
/// Interprets one client line and returns the frames to send in response.
///
/// Surrounding whitespace is ignored. The outcomes are:
///
/// * blank line – no frames;
/// * subscribe request – one text frame: the ack line on success, otherwise
///   the error line returned by the channel;
/// * unsubscribe request – the subscription is removed, no frames;
/// * anything else – one [`OutboundFrame::Close`] naming the first token.
pub fn handle_line(chan: &mut LegacyNotificationChannel, line: &str) -> Vec<OutboundFrame> {
    let line = line.trim();
    if line.is_empty() {
        return Vec::new();
    }

    // Subscribe is checked before unsubscribe, in that order.
    if let Some(target) = LegacyNotificationChannel::parse_subscribe(line) {
        let reply = match chan.subscribe(target.clone()) {
            Ok(()) => LegacyNotificationChannel::ack_line(&target),
            Err(rejection) => rejection,
        };
        return vec![OutboundFrame::Text(reply)];
    }

    if let Some(target) = LegacyNotificationChannel::parse_unsubscribe(line) {
        chan.unsubscribe(&target);
        return Vec::new();
    }

    let opcode = first_token(line);
    vec![OutboundFrame::Close(format!("unknown opcode: {}", opcode))]
}
/// Returns the first whitespace-delimited word of `s`, or `""` when `s` is
/// empty or contains only whitespace.
fn first_token(s: &str) -> &str {
    let rest = s.trim_start();
    // The token runs up to the next whitespace character, or to the end.
    let end = rest.find(char::is_whitespace).unwrap_or(rest.len());
    &rest[..end]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::notifications::legacy::AllowAllAuthorizer;
    use tokio::sync::broadcast;

    /// Builds a channel that authorizes every subscription. The broadcast
    /// sender is handed back so the caller can keep it alive for the whole test.
    fn open_channel() -> (broadcast::Sender<StorageEvent>, LegacyNotificationChannel) {
        let (tx, rx) = broadcast::channel::<StorageEvent>(16);
        let chan = LegacyNotificationChannel::new(rx).with_authorizer(Arc::new(AllowAllAuthorizer));
        (tx, chan)
    }

    #[test]
    fn handle_line_sub_emits_ack() {
        let (_tx, mut chan) = open_channel();

        let frames = handle_line(&mut chan, "sub https://p/x");

        let expected = vec![OutboundFrame::Text(String::from("ack https://p/x"))];
        assert_eq!(frames, expected);
        assert_eq!(chan.subscription_count(), 1);
    }

    #[test]
    fn handle_line_unsub_is_silent() {
        let (_tx, mut chan) = open_channel();
        chan.subscribe("https://p/x".into()).unwrap();

        let frames = handle_line(&mut chan, "unsub https://p/x");

        assert!(frames.is_empty());
        assert_eq!(chan.subscription_count(), 0);
    }

    #[test]
    fn handle_line_unknown_opcode_closes() {
        let (_tx, mut chan) = open_channel();

        let frames = handle_line(&mut chan, "wat foo");

        assert_eq!(frames.len(), 1);
        assert!(matches!(frames[0], OutboundFrame::Close(_)));
    }

    #[test]
    fn handle_line_blank_is_noop() {
        let (_tx, mut chan) = open_channel();

        for blank in ["", " "] {
            assert!(handle_line(&mut chan, blank).is_empty());
        }
    }

    #[test]
    fn handle_line_sub_over_cap_emits_err() {
        let (_tx, chan) = open_channel();
        let mut chan = chan.with_subscription_cap(1);

        // The first subscription fills the cap; the second must be rejected.
        let _ = handle_line(&mut chan, "sub https://p/a");
        let frames = handle_line(&mut chan, "sub https://p/b");

        assert_eq!(frames.len(), 1);
        match &frames[0] {
            OutboundFrame::Text(t) => {
                assert!(t.starts_with("err "));
                assert!(t.contains("subscription-limit"));
            }
            other => panic!("expected Text, got {other:?}"),
        }
    }
}