use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::broadcast::{error::RecvError, Receiver};
use crate::storage::StorageEvent;
/// Default cap on the number of distinct subscriptions one connection may hold.
pub const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100;
/// Default maximum length, in bytes, of a subscription target URL.
pub const MAX_URL_LENGTH: usize = 2048;
/// Default heartbeat interval stored on a freshly constructed channel.
pub const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
/// Greeting line of the legacy `solid-0.1` protocol.
///
/// NOTE(review): not referenced elsewhere in this file; presumably sent by the
/// transport layer when a connection opens — confirm against the caller.
pub const PROTOCOL_GREETING: &str = "protocol solid-0.1";
/// Opcodes of the line-oriented legacy `solid-0.1` notification protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SolidZeroOp {
    /// `sub <uri>`
    Sub,
    /// `ack <uri>`
    Ack,
    /// `err <uri> <reason>`
    Err,
    /// `pub <uri>`
    Pub,
    /// `unsub <uri>`
    Unsub,
}

impl SolidZeroOp {
    /// Returns the lowercase keyword that represents this opcode on the wire.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Unsub => "unsub",
            Self::Pub => "pub",
            Self::Err => "err",
            Self::Ack => "ack",
            Self::Sub => "sub",
        }
    }
}
/// Reason a [`SubscriptionAuthorizer`] refused a subscription.
///
/// `LegacyNotificationChannel::subscribe` reports either variant to the client
/// with the reason `forbidden`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DenyReason {
/// The subject is not permitted to subscribe to the target.
Forbidden,
/// NOTE(review): never produced in this file; presumably for authorizers that
/// detect a target on a foreign origin — confirm with implementors.
CrossOrigin,
}
/// Policy hook consulted before a subscription is recorded.
///
/// `LegacyNotificationChannel::subscribe` calls [`check`](Self::check) with the
/// requested target and the channel's configured WebID (if any).
pub trait SubscriptionAuthorizer: Send + Sync {
/// Returns `Ok(())` to allow `subject` to subscribe to `target`, or the
/// reason for refusing. `subject` is `None` when no WebID is configured.
fn check(&self, target: &str, subject: Option<&str>) -> Result<(), DenyReason>;
}
/// Authorizer that accepts every subscription request.
pub struct AllowAllAuthorizer;

impl SubscriptionAuthorizer for AllowAllAuthorizer {
    /// Always succeeds; neither the target nor the subject is inspected.
    fn check(&self, _target: &str, _subject: Option<&str>) -> Result<(), DenyReason> {
        Ok(())
    }
}
/// Authorizer that refuses every subscription request.
///
/// `LegacyNotificationChannel::new` installs this by default.
pub struct DenyAllAuthorizer;

impl SubscriptionAuthorizer for DenyAllAuthorizer {
    /// Always fails with [`DenyReason::Forbidden`]; the arguments are ignored.
    fn check(&self, _target: &str, _subject: Option<&str>) -> Result<(), DenyReason> {
        let refusal = DenyReason::Forbidden;
        Err(refusal)
    }
}
/// Per-connection state for the legacy `solid-0.1` notification channel:
/// the set of subscribed targets plus the limits and policy applied to them.
pub struct LegacyNotificationChannel {
/// Broadcast receiver that `next_event` reads storage events from.
storage_events: Receiver<StorageEvent>,
/// Targets this connection is currently subscribed to.
subscriptions: HashSet<String>,
/// Maximum byte length of a target accepted by `subscribe`.
url_cap_bytes: usize,
/// Maximum number of distinct targets held in `subscriptions`.
max_subs_per_conn: usize,
/// Heartbeat interval exposed through `heartbeat_interval()`; this type only
/// stores it.
heartbeat_interval: Duration,
/// Policy consulted by `subscribe` before a target is recorded.
authorizer: Arc<dyn SubscriptionAuthorizer>,
/// When set, `subscribe` only accepts targets on this origin.
server_origin: Option<String>,
/// Subject passed to the authorizer; `None` when no WebID is configured.
web_id: Option<String>,
}
impl LegacyNotificationChannel {
    /// Creates a channel that reads storage events from `storage_events`.
    ///
    /// The channel starts with no subscriptions, the module-level default
    /// limits and heartbeat, no server origin, no WebID, and a
    /// [`DenyAllAuthorizer`] — so every `subscribe` is refused until an
    /// authorizer is installed with [`with_authorizer`](Self::with_authorizer).
    pub fn new(storage_events: Receiver<StorageEvent>) -> Self {
        Self {
            storage_events,
            subscriptions: HashSet::new(),
            url_cap_bytes: MAX_URL_LENGTH,
            max_subs_per_conn: MAX_SUBSCRIPTIONS_PER_CONNECTION,
            heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL,
            authorizer: Arc::new(DenyAllAuthorizer),
            server_origin: None,
            web_id: None,
        }
    }

    /// Overrides the stored heartbeat interval.
    pub fn with_heartbeat(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Overrides the maximum accepted target length, in bytes.
    pub fn with_url_cap(mut self, cap: usize) -> Self {
        self.url_cap_bytes = cap;
        self
    }

    /// Overrides the maximum number of distinct subscriptions.
    pub fn with_subscription_cap(mut self, cap: usize) -> Self {
        self.max_subs_per_conn = cap;
        self
    }

    /// Replaces the authorizer consulted by [`subscribe`](Self::subscribe).
    pub fn with_authorizer(mut self, authorizer: Arc<dyn SubscriptionAuthorizer>) -> Self {
        self.authorizer = authorizer;
        self
    }

    /// Restricts subscriptions to targets on `origin`
    /// (for example `https://pod.example.com`).
    pub fn with_server_origin(mut self, origin: String) -> Self {
        self.server_origin = Some(origin);
        self
    }

    /// Sets the WebID handed to the authorizer as the subject.
    pub fn with_web_id(mut self, web_id: Option<String>) -> Self {
        self.web_id = web_id;
        self
    }

    /// Returns the configured heartbeat interval.
    pub fn heartbeat_interval(&self) -> Duration {
        self.heartbeat_interval
    }

    /// Returns the maximum accepted target length, in bytes.
    pub fn url_cap(&self) -> usize {
        self.url_cap_bytes
    }

    /// Returns the number of targets currently subscribed.
    pub fn subscription_count(&self) -> usize {
        self.subscriptions.len()
    }

    /// Records a subscription to `target`.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// 1. length cap — `err <first 64 bytes of target> url-too-long`;
    /// 2. same-origin check, only when a server origin is configured —
    ///    `err <target> forbidden`;
    /// 3. subscription cap, skipped for a target that is already subscribed —
    ///    `err <target> subscription-limit`;
    /// 4. authorizer — `err <target> forbidden`.
    ///
    /// # Errors
    ///
    /// Returns the complete `err …` wire line describing the refusal.
    pub fn subscribe(&mut self, target: String) -> Result<(), String> {
        if target.len() > self.url_cap_bytes {
            // Echo only a bounded prefix so an oversized request cannot be
            // reflected back in full.
            return Err(Self::err_line(truncate(&target, 64), "url-too-long"));
        }

        if let Some(configured) = self.server_origin.as_deref() {
            // The target origin is derived from a parsed URL (lowercased host,
            // default port dropped). Normalise the configured origin the same
            // way so that a trailing slash, an explicit default port or host
            // casing in configuration does not reject every request. If the
            // configured value does not parse, fall back to comparing against
            // it verbatim.
            let normalized = Self::origin_of(configured);
            let expected = normalized.as_deref().unwrap_or(configured);
            match Self::origin_of(&target) {
                Some(actual) if actual == expected => {}
                // Unparseable targets and foreign origins are both refused.
                _ => return Err(Self::err_line(&target, "forbidden")),
            }
        }

        // Re-subscribing to an existing target never counts against the cap.
        let already_subscribed = self.subscriptions.contains(&target);
        if !already_subscribed && self.subscriptions.len() >= self.max_subs_per_conn {
            return Err(Self::err_line(&target, "subscription-limit"));
        }

        // Exhaustive on purpose: a new `DenyReason` variant must be handled
        // explicitly rather than silently allowed.
        match self.authorizer.check(&target, self.web_id.as_deref()) {
            Ok(()) => {}
            Err(DenyReason::Forbidden | DenyReason::CrossOrigin) => {
                return Err(Self::err_line(&target, "forbidden"));
            }
        }

        self.subscriptions.insert(target);
        Ok(())
    }

    /// Removes `target` from the subscription set; unknown targets are ignored.
    pub fn unsubscribe(&mut self, target: &str) {
        self.subscriptions.remove(target);
    }

    /// Returns `true` when `resource_uri` equals a subscribed target, or starts
    /// with a subscribed target that ends in `/` (a container subscription).
    pub fn matches_subscription(&self, resource_uri: &str) -> bool {
        self.subscriptions.iter().any(|sub| {
            sub == resource_uri || (sub.ends_with('/') && resource_uri.starts_with(sub.as_str()))
        })
    }

    /// Waits for the next storage event.
    ///
    /// Returns `None` once the broadcast channel is closed. When this receiver
    /// has lagged, the lag notice is skipped and the next available event is
    /// returned.
    pub async fn next_event(&mut self) -> Option<StorageEvent> {
        loop {
            match self.storage_events.recv().await {
                Ok(event) => break Some(event),
                Err(RecvError::Closed) => break None,
                // Missed events cannot be recovered; keep reading.
                Err(RecvError::Lagged(_)) => {}
            }
        }
    }

    /// Formats `event` as a `pub <uri>` line.
    ///
    /// Created, updated and deleted events all map to `pub`; the result is
    /// always `Some`.
    pub fn to_legacy_line(event: &StorageEvent) -> Option<String> {
        let uri = match event {
            StorageEvent::Created(p) | StorageEvent::Updated(p) | StorageEvent::Deleted(p) => p,
        };
        Some(format!("{} {}", SolidZeroOp::Pub.as_str(), uri))
    }

    /// Extracts the target of a `sub <target>` line, or `None` if the line is
    /// not a well-formed subscribe command.
    pub fn parse_subscribe(line: &str) -> Option<String> {
        parse_prefixed(line, "sub ")
    }

    /// Extracts the target of an `unsub <target>` line, or `None` if the line
    /// is not a well-formed unsubscribe command.
    pub fn parse_unsubscribe(line: &str) -> Option<String> {
        parse_prefixed(line, "unsub ")
    }

    /// Builds an `ack <target>` line.
    pub fn ack_line(target: &str) -> String {
        format!("{} {}", SolidZeroOp::Ack.as_str(), target)
    }

    /// Builds an `err <target> <reason>` line.
    pub fn err_line(target: &str, reason: &str) -> String {
        format!("{} {} {}", SolidZeroOp::Err.as_str(), target, reason)
    }

    /// Parses `raw` as a URL and renders its origin as `scheme://host[:port]`.
    ///
    /// The port is included only when the parsed URL reports one. Returns
    /// `None` when `raw` is not a valid absolute URL.
    fn origin_of(raw: &str) -> Option<String> {
        let parsed = url::Url::parse(raw).ok()?;
        let host = parsed.host_str().unwrap_or("");
        Some(match parsed.port() {
            Some(port) => format!("{}://{}:{}", parsed.scheme(), host, port),
            None => format!("{}://{}", parsed.scheme(), host),
        })
    }
}
/// Returns the non-empty, trimmed remainder of `line` after `prefix`.
///
/// Trailing CR/LF and leading whitespace are removed before the prefix is
/// matched. Yields `None` when the prefix is absent or nothing but whitespace
/// follows it.
fn parse_prefixed(line: &str, prefix: &str) -> Option<String> {
    line.trim_end_matches(['\r', '\n'])
        .trim_start()
        .strip_prefix(prefix)
        .map(str::trim)
        .filter(|target| !target.is_empty())
        .map(str::to_string)
}
/// Returns the longest prefix of `s` that is at most `max` bytes long and ends
/// on a UTF-8 character boundary.
fn truncate(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Walk back from `max` to the nearest boundary; offset 0 always qualifies.
    let cut = (0..=max)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0);
    &s[..cut]
}
// Unit tests for the line parsers, wire formatting and subscription
// bookkeeping of `LegacyNotificationChannel`.
#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::broadcast;
// A plain `sub <url>` line yields the URL.
#[test]
fn parse_subscribe_valid() {
let got = LegacyNotificationChannel::parse_subscribe("sub https://pod.example.com/x");
assert_eq!(got, Some("https://pod.example.com/x".to_string()));
}
// Trailing CRLF and surrounding whitespace are stripped from the target.
#[test]
fn parse_subscribe_trims_whitespace_and_crlf() {
let got = LegacyNotificationChannel::parse_subscribe("sub https://pod.example.com/x\r\n");
assert_eq!(got, Some("https://pod.example.com/x".to_string()));
let got = LegacyNotificationChannel::parse_subscribe(" sub https://pod.example.com/x ");
assert_eq!(got, Some("https://pod.example.com/x".to_string()));
}
// Missing targets, other verbs and empty input are all rejected.
#[test]
fn parse_subscribe_rejects_malformed() {
assert!(LegacyNotificationChannel::parse_subscribe("sub").is_none());
assert!(LegacyNotificationChannel::parse_subscribe("sub ").is_none());
assert!(LegacyNotificationChannel::parse_subscribe("subscribe foo").is_none());
assert!(LegacyNotificationChannel::parse_subscribe("pub foo").is_none());
assert!(LegacyNotificationChannel::parse_subscribe("").is_none());
}
// A plain `unsub <url>` line yields the URL.
#[test]
fn parse_unsubscribe_valid() {
let got = LegacyNotificationChannel::parse_unsubscribe("unsub https://p/x");
assert_eq!(got, Some("https://p/x".to_string()));
}
// A created event is rendered as `pub <uri>`.
#[test]
fn to_legacy_line_created() {
let ev = StorageEvent::Created("https://pod.example.com/x".into());
assert_eq!(
LegacyNotificationChannel::to_legacy_line(&ev),
Some("pub https://pod.example.com/x".to_string())
);
}
// Updated and deleted events use the same `pub` line as created ones.
#[test]
fn to_legacy_line_updated_and_deleted_also_map_to_pub() {
let u = StorageEvent::Updated("https://pod.example.com/x".into());
let d = StorageEvent::Deleted("https://pod.example.com/x".into());
assert_eq!(
LegacyNotificationChannel::to_legacy_line(&u),
Some("pub https://pod.example.com/x".to_string())
);
assert_eq!(
LegacyNotificationChannel::to_legacy_line(&d),
Some("pub https://pod.example.com/x".to_string())
);
}
// With a cap of 2, the third distinct target is refused and not recorded.
#[test]
fn subscription_cap_rejects_over_limit() {
let (_tx, rx) = broadcast::channel::<StorageEvent>(16);
let mut chan = LegacyNotificationChannel::new(rx)
.with_authorizer(Arc::new(AllowAllAuthorizer))
.with_subscription_cap(2);
assert!(chan.subscribe("https://p/a".into()).is_ok());
assert!(chan.subscribe("https://p/b".into()).is_ok());
let err = chan.subscribe("https://p/c".into()).unwrap_err();
assert!(err.starts_with("err "));
assert!(err.contains("subscription-limit"));
assert_eq!(chan.subscription_count(), 2);
}
// A target longer than the configured byte cap is refused.
#[test]
fn url_cap_rejects_over_limit() {
let (_tx, rx) = broadcast::channel::<StorageEvent>(16);
let mut chan = LegacyNotificationChannel::new(rx)
.with_authorizer(Arc::new(AllowAllAuthorizer))
.with_url_cap(16);
let err = chan
.subscribe("https://pod.example.com/really/long/path".into())
.unwrap_err();
assert!(err.contains("url-too-long"));
assert_eq!(chan.subscription_count(), 0);
}
// Container subscriptions (trailing `/`) match by prefix; others match
// only on exact equality.
#[test]
fn matches_subscription_prefix_and_exact() {
let (_tx, rx) = broadcast::channel::<StorageEvent>(16);
let mut chan =
LegacyNotificationChannel::new(rx).with_authorizer(Arc::new(AllowAllAuthorizer));
chan.subscribe("https://pod.example.com/foo/".into()).unwrap();
chan.subscribe("https://pod.example.com/bar.ttl".into()).unwrap();
assert!(chan.matches_subscription("https://pod.example.com/foo/"));
assert!(chan.matches_subscription("https://pod.example.com/foo/deep/nested"));
assert!(chan.matches_subscription("https://pod.example.com/bar.ttl"));
assert!(!chan.matches_subscription("https://pod.example.com/other"));
assert!(!chan.matches_subscription("https://pod.example.com/bar.ttl.backup"));
}
// Unsubscribing removes the target.
#[test]
fn unsubscribe_removes_target() {
let (_tx, rx) = broadcast::channel::<StorageEvent>(16);
let mut chan =
LegacyNotificationChannel::new(rx).with_authorizer(Arc::new(AllowAllAuthorizer));
chan.subscribe("https://p/x".into()).unwrap();
chan.unsubscribe("https://p/x");
assert_eq!(chan.subscription_count(), 0);
// Unsubscribing from a target that was never subscribed must not panic.
chan.unsubscribe("https://p/y"); }
// `ack` and `err` lines follow the `<verb> <target> [<reason>]` layout.
#[test]
fn ack_and_err_lines() {
assert_eq!(
LegacyNotificationChannel::ack_line("https://p/x"),
"ack https://p/x"
);
assert_eq!(
LegacyNotificationChannel::err_line("https://p/x", "forbidden"),
"err https://p/x forbidden"
);
}
// Pins the wire keyword of every opcode.
#[test]
fn opcode_wire_names() {
assert_eq!(SolidZeroOp::Sub.as_str(), "sub");
assert_eq!(SolidZeroOp::Ack.as_str(), "ack");
assert_eq!(SolidZeroOp::Err.as_str(), "err");
assert_eq!(SolidZeroOp::Pub.as_str(), "pub");
assert_eq!(SolidZeroOp::Unsub.as_str(), "unsub");
}
}
/// Asynchronous read-permission check used by `LegacyWebSocketSession`.
///
/// The session calls it before recording a subscription and again before
/// forwarding a change notification.
/// NOTE(review): the name suggests a WAC `acl:Read` evaluation; the actual
/// policy is defined by implementors outside this file.
#[async_trait]
pub trait LegacyWacRead: Send + Sync {
/// Returns `true` when `webid` (or an unidentified client, if `None`) may
/// read `resource_uri`.
async fn can_read(&self, webid: Option<&str>, resource_uri: &str) -> bool;
}
/// A single server-to-client frame of the legacy protocol.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LegacyFrame {
    /// Subscription acknowledged; carries the subscribed URI.
    Ack(String),
    /// Error; carries the message that follows the `err` keyword.
    Err(String),
    /// Change notification; carries the changed URI.
    Pub(String),
}

impl LegacyFrame {
    /// Renders the frame as `<keyword> <payload>`.
    pub fn to_wire(&self) -> String {
        let (keyword, payload) = match self {
            Self::Ack(uri) => ("ack", uri),
            Self::Err(message) => ("err", message),
            Self::Pub(uri) => ("pub", uri),
        };
        format!("{keyword} {payload}")
    }
}
/// Frames to send back to the client in reply to one incoming message.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct LegacyResponse {
    /// Frames in the order they should be written; may be empty.
    pub frames: Vec<LegacyFrame>,
}

impl LegacyResponse {
    /// Response carrying exactly one frame.
    fn one(frame: LegacyFrame) -> Self {
        let mut frames = Vec::with_capacity(1);
        frames.push(frame);
        Self { frames }
    }

    /// Response carrying no frames.
    fn empty() -> Self {
        Self { frames: Vec::new() }
    }
}
/// Per-connection protocol state machine: parses client lines, tracks
/// subscriptions and decides which change notifications to emit.
pub struct LegacyWebSocketSession {
/// URIs this connection is subscribed to.
subs: HashSet<String>,
/// Maximum number of distinct URIs held in `subs`.
max_subs: usize,
/// Maximum byte length of a URI accepted by a `sub` command.
max_uri_bytes: usize,
/// Read-permission check applied on subscribe and on every notification.
wac_check: Arc<dyn LegacyWacRead>,
/// WebID passed to `wac_check`; `None` when the session has no WebID.
subscriber_webid: Option<String>,
}
impl LegacyWebSocketSession {
    /// Creates a session with no subscriptions and the module-level default
    /// limits.
    ///
    /// `wac_check` decides read access; `subscriber_webid` is the identity
    /// passed to it (`None` when the session has no WebID).
    pub fn new(wac_check: Arc<dyn LegacyWacRead>, subscriber_webid: Option<String>) -> Self {
        Self {
            subs: HashSet::new(),
            max_subs: MAX_SUBSCRIPTIONS_PER_CONNECTION,
            max_uri_bytes: MAX_URL_LENGTH,
            wac_check,
            subscriber_webid,
        }
    }

    /// Overrides the maximum number of distinct subscriptions.
    pub fn with_max_subs(mut self, cap: usize) -> Self {
        self.max_subs = cap;
        self
    }

    /// Overrides the maximum accepted URI length, in bytes.
    pub fn with_max_uri_bytes(mut self, cap: usize) -> Self {
        self.max_uri_bytes = cap;
        self
    }

    /// Returns the number of URIs currently subscribed.
    pub fn subscription_count(&self) -> usize {
        self.subs.len()
    }

    /// Returns `true` when `uri` itself is subscribed (exact match only).
    pub fn is_subscribed(&self, uri: &str) -> bool {
        self.subs.contains(uri)
    }

    /// Handles one client line and returns the frames to send back.
    ///
    /// * blank line — no frames;
    /// * `sub <uri>` — an `ack` or `err` frame (see `handle_sub`);
    /// * `unsub <uri>` — removes the subscription, no frames;
    /// * anything else — `err unknown command`.
    pub async fn handle_message(&mut self, msg: &str) -> LegacyResponse {
        // `trim` already strips CR/LF along with other surrounding whitespace.
        let line = msg.trim();
        if line.is_empty() {
            return LegacyResponse::empty();
        }
        if let Some(uri) = line.strip_prefix("sub ") {
            return self.handle_sub(uri.trim()).await;
        }
        if let Some(uri) = line.strip_prefix("unsub ") {
            self.handle_unsub(uri.trim());
            return LegacyResponse::empty();
        }
        LegacyResponse::one(LegacyFrame::Err("unknown command".to_string()))
    }

    /// Validates and records a subscription to `uri`.
    ///
    /// Checks run in order: non-empty, length cap, subscription cap (skipped
    /// for a URI that is already subscribed), then the read-permission check.
    /// Returns an `ack <uri>` frame on success and an `err` frame otherwise.
    async fn handle_sub(&mut self, uri: &str) -> LegacyResponse {
        if uri.is_empty() {
            return LegacyResponse::one(LegacyFrame::Err("unknown command".to_string()));
        }
        if uri.len() > self.max_uri_bytes {
            return LegacyResponse::one(LegacyFrame::Err("uri too long".to_string()));
        }
        if !self.subs.contains(uri) && self.subs.len() >= self.max_subs {
            return LegacyResponse::one(LegacyFrame::Err(
                "subscription limit reached".to_string(),
            ));
        }
        let allowed = self
            .wac_check
            .can_read(self.subscriber_webid.as_deref(), uri)
            .await;
        if !allowed {
            return LegacyResponse::one(LegacyFrame::Err(format!("{uri} forbidden")));
        }
        self.subs.insert(uri.to_string());
        LegacyResponse::one(LegacyFrame::Ack(uri.to_string()))
    }

    /// Drops the subscription to `uri`; empty or unknown URIs are ignored.
    fn handle_unsub(&mut self, uri: &str) {
        if !uri.is_empty() {
            self.subs.remove(uri);
        }
    }

    /// Computes the frames to send when `changed_uri` changes.
    ///
    /// Nothing is emitted unless `changed_uri` or one of its ancestor
    /// containers is subscribed. Otherwise read access to `changed_uri` is
    /// re-checked: a `pub <uri>` frame is returned when allowed, an
    /// `err <uri> forbidden` frame when not.
    pub async fn on_resource_change(&self, changed_uri: &str) -> Vec<LegacyFrame> {
        // Only whether *some* subscription covers the change matters, so stop
        // at the first hit instead of collecting every matching URI.
        let is_covered = self.subs.contains(changed_uri)
            || ancestor_containers(changed_uri)
                .iter()
                .any(|container| self.subs.contains(container));
        if !is_covered {
            return Vec::new();
        }

        let allowed = self
            .wac_check
            .can_read(self.subscriber_webid.as_deref(), changed_uri)
            .await;
        if allowed {
            vec![LegacyFrame::Pub(changed_uri.to_string())]
        } else {
            // NOTE(review): this tells a container subscriber that a child it
            // cannot read has changed, including the child's URI. Kept as-is
            // for wire compatibility — confirm the disclosure is intended.
            vec![LegacyFrame::Err(format!("{changed_uri} forbidden"))]
        }
    }

    /// Computes the frames to send for a storage event; created, updated and
    /// deleted events are all treated as a change to the carried URI.
    pub async fn on_storage_event(&self, ev: &StorageEvent) -> Vec<LegacyFrame> {
        let uri = match ev {
            StorageEvent::Created(p) | StorageEvent::Updated(p) | StorageEvent::Deleted(p) => p,
        };
        self.on_resource_change(uri).await
    }
}
/// Lists the container URIs above `uri`, nearest first, each ending in `/`.
///
/// Trailing slashes on `uri` are ignored, so `/a/b/` and `/a/b` give the same
/// result. For an absolute URL the climb stops at the origin root
/// (`scheme://authority/`); for a path it stops at `/`. A root or empty input
/// has no ancestors.
pub fn ancestor_containers(uri: &str) -> Vec<String> {
    let path = uri.trim_end_matches('/');
    // Never search for a slash to the left of this offset, so the `//` of the
    // scheme separator is not mistaken for a path separator.
    let floor = find_origin_end(path);

    let mut ancestors = Vec::new();
    let mut remaining = path;
    while let Some(rel) = remaining[floor..].rfind('/') {
        let slash = floor + rel;
        ancestors.push(remaining[..=slash].to_string());
        if slash == floor {
            // Reached the root container; nothing lies above it.
            break;
        }
        remaining = &remaining[..slash];
    }
    ancestors
}

/// Returns the byte offset at which the path of `uri` begins.
///
/// For `scheme://authority/path` this is the index of the first `/` after the
/// authority, or `uri.len()` when there is no path. Without a `://` separator
/// the whole input is treated as a path and `0` is returned.
fn find_origin_end(uri: &str) -> usize {
    const SEPARATOR: &str = "://";
    let Some(scheme_end) = uri.find(SEPARATOR) else {
        return 0;
    };
    let authority_start = scheme_end + SEPARATOR.len();
    uri[authority_start..]
        .find('/')
        .map_or(uri.len(), |rel| authority_start + rel)
}
// Unit tests for `ancestor_containers` and `LegacyFrame::to_wire`.
#[cfg(test)]
mod session_tests {
use super::*;
// The root path and the empty string have no ancestors.
#[test]
fn ancestor_containers_root_path_has_none() {
assert!(ancestor_containers("/").is_empty());
assert!(ancestor_containers("").is_empty());
}
// A path climbs container by container up to `/`, nearest first.
#[test]
fn ancestor_containers_relative_path_climbs() {
let got = ancestor_containers("/a/b/c");
assert_eq!(got, vec!["/a/b/".to_string(), "/a/".to_string(), "/".to_string()]);
}
// An absolute URL stops at `scheme://authority/` and never climbs into
// the scheme separator.
#[test]
fn ancestor_containers_absolute_url_stops_at_origin_root() {
let got = ancestor_containers("https://pod.example/a/b/c");
assert_eq!(
got,
vec![
"https://pod.example/a/b/".to_string(),
"https://pod.example/a/".to_string(),
"https://pod.example/".to_string(),
]
);
}
// A trailing slash does not change the result.
#[test]
fn ancestor_containers_trailing_slash_treated_as_container() {
let a = ancestor_containers("/a/b/");
let b = ancestor_containers("/a/b");
assert_eq!(a, b);
assert_eq!(a, vec!["/a/".to_string(), "/".to_string()]);
}
// Pins the wire rendering of each frame variant.
#[test]
fn legacy_frame_to_wire_roundtrip() {
assert_eq!(LegacyFrame::Ack("/x".into()).to_wire(), "ack /x");
assert_eq!(LegacyFrame::Err("forbidden".into()).to_wire(), "err forbidden");
assert_eq!(LegacyFrame::Pub("/x".into()).to_wire(), "pub /x");
}
}