use crate::error::RealtimeError;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Maximum number of distinct channels a single [`ClientSubscriptions`] may hold.
///
/// Enforced by `ClientSubscriptions::subscribe`, which returns
/// `RealtimeError::SubscriptionLimit` carrying this value once the cap is reached.
pub const MAX_SUBSCRIPTIONS_PER_CLIENT: usize = 100;
/// A realtime channel such as `repo:alice/myrepo/prs`, `user:alice` or `org:acme`.
///
/// Values are normally built with [`Channel::parse`] and rendered back to the
/// `<type>:<identifier>[/<filter>]` string form via `Display`.
///
/// NOTE(review): the derived `Deserialize` fills the fields directly and does not
/// run `Channel::parse`'s validation, so deserialized values may be malformed.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Channel {
/// Which kind of entity the channel is scoped to.
pub channel_type: ChannelType,
/// `owner/name` for repository channels; the single bare name for user and
/// organization channels.
pub identifier: String,
/// Optional sub-topic (for example `prs`). `Channel::parse` only sets this for
/// repository channels.
pub filter: Option<String>,
}
impl Channel {
    /// Parses a channel string of the form `<type>:<identifier>[/<filter>]`.
    ///
    /// Accepted shapes:
    /// - `repo:<owner>/<name>` or `repo:<owner>/<name>/<filter>`. Everything after the
    ///   second `/`, including any further `/`, becomes the filter.
    /// - `user:<name>` and `org:<name>`. Only the text before the first `/` is kept
    ///   as the identifier, and these channels never carry a filter.
    ///
    /// # Errors
    ///
    /// Returns [`RealtimeError::InvalidChannel`] when:
    /// - the `:` separator is missing;
    /// - the type prefix is not `repo`, `user` or `org`;
    /// - a repository owner or name segment is missing or empty;
    /// - a repository filter segment is present but empty (trailing `/`);
    /// - a user/organization identifier is empty.
    pub fn parse(s: &str) -> Result<Self, RealtimeError> {
        let (prefix, rest) = s.split_once(':').ok_or_else(|| {
            RealtimeError::InvalidChannel(format!("missing channel type prefix: {}", s))
        })?;

        let channel_type = match prefix {
            "repo" => ChannelType::Repository,
            "user" => ChannelType::User,
            "org" => ChannelType::Organization,
            _ => {
                return Err(RealtimeError::InvalidChannel(format!(
                    "unknown channel type: {}",
                    prefix
                )))
            }
        };

        let (identifier, filter) = match channel_type {
            ChannelType::Repository => {
                let mut segments = rest.splitn(3, '/');
                // Owner and name must both be present AND non-empty, so inputs such as
                // `repo:/`, `repo:alice/` or `repo:/myrepo` are rejected.
                let (owner, name) = match (segments.next(), segments.next()) {
                    (Some(owner), Some(name)) if !owner.is_empty() && !name.is_empty() => {
                        (owner, name)
                    }
                    _ => {
                        return Err(RealtimeError::InvalidChannel(format!(
                            "repository channel requires owner/name format: {}",
                            rest
                        )))
                    }
                };
                let filter = match segments.next() {
                    None => None,
                    // A trailing `/` would otherwise produce `Some("")`, a filter that
                    // only ever matches other malformed channel strings.
                    Some("") => {
                        return Err(RealtimeError::InvalidChannel(format!(
                            "channel filter cannot be empty: {}",
                            s
                        )))
                    }
                    Some(filter) => Some(filter.to_string()),
                };
                (format!("{}/{}", owner, name), filter)
            }
            ChannelType::User | ChannelType::Organization => {
                // Only the first `/`-separated segment identifies the user/org; any
                // suffix is discarded. NOTE(review): confirm callers rely on this
                // rather than expecting `user:alice/x` to be rejected.
                let id = rest.split_once('/').map_or(rest, |(head, _)| head);
                if id.is_empty() {
                    return Err(RealtimeError::InvalidChannel(format!(
                        "channel identifier cannot be empty: {}",
                        s
                    )));
                }
                (id.to_string(), None)
            }
        };

        Ok(Channel {
            channel_type,
            identifier,
            filter,
        })
    }

    /// Returns `true` if an event published on `event_channel` should be delivered
    /// to a subscriber of `self`.
    ///
    /// The channel type and identifier must be equal. A subscription without a
    /// filter receives every event for that identifier, filtered or not. A filtered
    /// subscription receives only events carrying exactly the same filter. Event
    /// strings that `Channel::parse` rejects never match.
    pub fn matches(&self, event_channel: &str) -> bool {
        let event = match Channel::parse(event_channel) {
            Ok(event) => event,
            Err(_) => return false,
        };
        self.channel_type == event.channel_type
            && self.identifier == event.identifier
            && (self.filter.is_none() || self.filter == event.filter)
    }
}
impl std::fmt::Display for Channel {
    /// Writes the channel as `<type>:<identifier>` followed by `/<filter>` when a
    /// filter is present, using the short prefixes `repo`, `user` and `org`, which
    /// are the prefixes `Channel::parse` recognizes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_prefix = match self.channel_type {
            ChannelType::Repository => "repo",
            ChannelType::User => "user",
            ChannelType::Organization => "org",
        };
        f.write_str(type_prefix)?;
        f.write_str(":")?;
        f.write_str(&self.identifier)?;
        if let Some(topic) = self.filter.as_deref() {
            f.write_str("/")?;
            f.write_str(topic)?;
        }
        Ok(())
    }
}
/// The kind of entity a [`Channel`] is scoped to.
///
/// With `rename_all = "lowercase"`, serde (de)serializes the variants as
/// `"repository"`, `"user"` and `"organization"`. These differ from the short
/// `repo:`/`org:` prefixes used in channel strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChannelType {
/// A repository channel, identified by `owner/name` (`repo:` prefix).
Repository,
/// A user channel (`user:` prefix).
User,
/// An organization channel (`org:` prefix).
Organization,
}
/// The set of channels a client is subscribed to, capped at
/// [`MAX_SUBSCRIPTIONS_PER_CLIENT`] entries by `subscribe`.
#[derive(Debug, Default)]
pub struct ClientSubscriptions {
/// Distinct subscribed channels; duplicates collapse because this is a set.
channels: HashSet<Channel>,
}
impl ClientSubscriptions {
    /// Creates an empty subscription set.
    pub fn new() -> Self {
        Self {
            channels: HashSet::new(),
        }
    }

    /// Adds `channel` to the subscription set.
    ///
    /// Returns `Ok(true)` if the channel was newly added, or `Ok(false)` if it was
    /// already subscribed. Re-subscribing is always accepted, even at the cap,
    /// because it does not grow the set.
    ///
    /// # Errors
    ///
    /// Returns [`RealtimeError::SubscriptionLimit`] when adding a *new* channel
    /// would exceed [`MAX_SUBSCRIPTIONS_PER_CLIENT`].
    pub fn subscribe(&mut self, channel: Channel) -> Result<bool, RealtimeError> {
        // Check membership first: an idempotent re-subscribe must not be reported
        // as a limit violation just because the set happens to be full.
        if self.channels.contains(&channel) {
            return Ok(false);
        }
        if self.channels.len() >= MAX_SUBSCRIPTIONS_PER_CLIENT {
            return Err(RealtimeError::SubscriptionLimit(
                MAX_SUBSCRIPTIONS_PER_CLIENT,
            ));
        }
        Ok(self.channels.insert(channel))
    }

    /// Removes `channel`; returns `true` if it was subscribed.
    pub fn unsubscribe(&mut self, channel: &Channel) -> bool {
        self.channels.remove(channel)
    }

    /// Returns `true` if exactly `channel` (including its filter) is subscribed.
    pub fn is_subscribed(&self, channel: &Channel) -> bool {
        self.channels.contains(channel)
    }

    /// Returns `true` if any subscription matches `event_channel`, as decided by
    /// `Channel::matches`.
    pub fn matches_event(&self, event_channel: &str) -> bool {
        self.channels.iter().any(|c| c.matches(event_channel))
    }

    /// Iterates over the subscribed channels in unspecified (hash set) order.
    pub fn channels(&self) -> impl Iterator<Item = &Channel> {
        self.channels.iter()
    }

    /// Number of distinct subscribed channels.
    pub fn count(&self) -> usize {
        self.channels.len()
    }

    /// Removes every subscription.
    pub fn clear(&mut self) {
        self.channels.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` and checks every field of the resulting channel.
    fn assert_parses_to(raw: &str, kind: ChannelType, identifier: &str, filter: Option<&str>) {
        let parsed = Channel::parse(raw).unwrap();
        assert_eq!(parsed.channel_type, kind);
        assert_eq!(parsed.identifier, identifier);
        assert_eq!(parsed.filter.as_deref(), filter);
    }

    /// Builds the `alice/myrepo` repository channel directly, bypassing `parse`.
    fn alice_repo(filter: Option<&str>) -> Channel {
        Channel {
            channel_type: ChannelType::Repository,
            identifier: "alice/myrepo".to_string(),
            filter: filter.map(str::to_string),
        }
    }

    #[test]
    fn test_channel_parse_repo() {
        assert_parses_to("repo:alice/myrepo", ChannelType::Repository, "alice/myrepo", None);
    }

    #[test]
    fn test_channel_parse_repo_with_filter() {
        assert_parses_to(
            "repo:alice/myrepo/prs",
            ChannelType::Repository,
            "alice/myrepo",
            Some("prs"),
        );
    }

    #[test]
    fn test_channel_parse_user() {
        assert_parses_to("user:alice", ChannelType::User, "alice", None);
    }

    #[test]
    fn test_channel_parse_org() {
        assert_parses_to("org:acme", ChannelType::Organization, "acme", None);
    }

    #[test]
    fn test_channel_parse_invalid() {
        for raw in &["invalid", "unknown:test", "repo:", "repo:onlyname"] {
            assert!(Channel::parse(raw).is_err(), "expected {} to be rejected", raw);
        }
    }

    #[test]
    fn test_channel_to_string() {
        assert_eq!(alice_repo(None).to_string(), "repo:alice/myrepo");
        assert_eq!(alice_repo(Some("prs")).to_string(), "repo:alice/myrepo/prs");
    }

    #[test]
    fn test_channel_matches() {
        let subscription = Channel::parse("repo:alice/myrepo").unwrap();
        for event in &[
            "repo:alice/myrepo",
            "repo:alice/myrepo/prs",
            "repo:alice/myrepo/issues",
        ] {
            assert!(subscription.matches(event), "should match {}", event);
        }
        for event in &["repo:bob/otherrepo", "user:alice"] {
            assert!(!subscription.matches(event), "should not match {}", event);
        }
    }

    #[test]
    fn test_channel_matches_with_filter() {
        let subscription = Channel::parse("repo:alice/myrepo/prs").unwrap();
        assert!(subscription.matches("repo:alice/myrepo/prs"));
        for event in &["repo:alice/myrepo", "repo:alice/myrepo/issues"] {
            assert!(!subscription.matches(event), "should not match {}", event);
        }
    }

    #[test]
    fn test_client_subscriptions() {
        let channel = Channel::parse("repo:alice/myrepo").unwrap();
        let mut subs = ClientSubscriptions::new();

        let added = subs.subscribe(channel.clone()).unwrap();
        assert!(added);
        assert!(subs.is_subscribed(&channel));
        assert_eq!(subs.count(), 1);

        let added_again = subs.subscribe(channel.clone()).unwrap();
        assert!(!added_again, "duplicate subscribe must report no change");
        assert_eq!(subs.count(), 1);

        assert!(subs.unsubscribe(&channel));
        assert!(!subs.is_subscribed(&channel));
        assert_eq!(subs.count(), 0);
    }

    #[test]
    fn test_client_subscriptions_limit() {
        let mut subs = ClientSubscriptions::new();
        let fillers = (0..MAX_SUBSCRIPTIONS_PER_CLIENT)
            .map(|i| Channel::parse(&format!("user:user{}", i)).unwrap());
        for channel in fillers {
            subs.subscribe(channel).unwrap();
        }

        let overflow = subs.subscribe(Channel::parse("user:extra").unwrap());
        assert!(matches!(overflow, Err(RealtimeError::SubscriptionLimit(_))));
    }

    #[test]
    fn test_matches_event() {
        let mut subs = ClientSubscriptions::new();
        for raw in &["repo:alice/myrepo", "user:alice"] {
            subs.subscribe(Channel::parse(raw).unwrap()).unwrap();
        }

        for event in &["repo:alice/myrepo", "repo:alice/myrepo/prs", "user:alice"] {
            assert!(subs.matches_event(event), "should match {}", event);
        }
        for event in &["repo:bob/otherrepo", "user:bob"] {
            assert!(!subs.matches_event(event), "should not match {}", event);
        }
    }
}