use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
/// Kind of channel, determined from the channel name's prefix by
/// [`ChannelType::from_name`].
///
/// Serialized with serde using lowercase variant names
/// (`"public"`, `"private"`, `"presence"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChannelType {
    /// Any channel name without a `private-` or `presence-` prefix.
    Public,
    /// Channel names starting with `private-`.
    Private,
    /// Channel names starting with `presence-`.
    Presence,
}
impl ChannelType {
    /// Classifies a channel from its name.
    ///
    /// This is a case-sensitive prefix match:
    /// - `private-…` is [`ChannelType::Private`];
    /// - `presence-…` is [`ChannelType::Presence`];
    /// - anything else is [`ChannelType::Public`]. This includes the bare
    ///   words `private` / `presence` without the trailing dash, and the
    ///   empty string.
    pub fn from_name(name: &str) -> Self {
        match name {
            n if n.starts_with("private-") => Self::Private,
            n if n.starts_with("presence-") => Self::Presence,
            _ => Self::Public,
        }
    }

    /// Returns `true` for private and presence channels and `false` for
    /// public ones.
    pub fn requires_auth(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Public => false,
            Self::Private | Self::Presence => true,
        }
    }
}
/// Per-channel state: the channel's name and type, the set of subscribed
/// socket ids, and the list of presence members.
#[derive(Debug, Clone)]
pub struct ChannelInfo {
    /// Full channel name, including any `private-` / `presence-` prefix.
    pub name: String,
    /// Set from `name` by [`ChannelInfo::new`] via [`ChannelType::from_name`].
    pub channel_type: ChannelType,
    /// Socket ids currently subscribed to this channel.
    pub subscribers: HashSet<String>,
    /// Presence member entries. [`ChannelInfo::add_member`] replaces any
    /// existing entry with the same `user_id`.
    pub members: Vec<PresenceMember>,
}
impl ChannelInfo {
    /// Creates an empty channel named `name`.
    ///
    /// The channel's type comes from [`ChannelType::from_name`].
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            // Borrow `name` to classify it before it is moved into the struct.
            channel_type: ChannelType::from_name(&name),
            name,
            subscribers: HashSet::new(),
            members: Vec::new(),
        }
    }

    /// Returns `true` when no socket is subscribed.
    ///
    /// Only `subscribers` is consulted; entries in `members` are ignored.
    pub fn is_empty(&self) -> bool {
        self.subscriber_count() == 0
    }

    /// Number of distinct socket ids currently subscribed.
    pub fn subscriber_count(&self) -> usize {
        self.subscribers.len()
    }

    /// Subscribes `socket_id`; returns `false` if it was already subscribed.
    pub fn add_subscriber(&mut self, socket_id: String) -> bool {
        self.subscribers.insert(socket_id)
    }

    /// Unsubscribes `socket_id`; returns `false` if it was not subscribed.
    ///
    /// Does not touch `members`; presence entries are removed with
    /// [`ChannelInfo::remove_member`].
    pub fn remove_subscriber(&mut self, socket_id: &str) -> bool {
        self.subscribers.remove(socket_id)
    }

    /// Adds `member` at the end of the member list.
    ///
    /// Any existing entry with the same `user_id` is dropped first.
    ///
    /// NOTE(review): entries are replaced by `user_id` here but looked up by
    /// `socket_id` in [`ChannelInfo::remove_member`]. If one user joins from
    /// two sockets, only the latest socket's entry is kept. Removing that
    /// socket then drops the user even if the earlier socket is still
    /// subscribed. Confirm this is the intended presence semantics.
    pub fn add_member(&mut self, member: PresenceMember) {
        let user = member.user_id.as_str();
        self.members.retain(|existing| existing.user_id.as_str() != user);
        self.members.push(member);
    }

    /// Removes and returns the first member whose `socket_id` matches.
    ///
    /// Returns `None` if no entry has that socket. Remaining members keep
    /// their relative order.
    pub fn remove_member(&mut self, socket_id: &str) -> Option<PresenceMember> {
        let idx = self
            .members
            .iter()
            .position(|candidate| candidate.socket_id == socket_id)?;
        Some(self.members.remove(idx))
    }

    /// Current presence members, in the order they were added.
    pub fn get_members(&self) -> &[PresenceMember] {
        self.members.as_slice()
    }
}
/// Member entry for a presence channel: a user id, the socket it is
/// connected on, and arbitrary JSON user data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresenceMember {
    /// Socket the member joined on; [`ChannelInfo::remove_member`] matches on this.
    pub socket_id: String,
    /// User identity; [`ChannelInfo::add_member`] replaces existing entries
    /// with the same id.
    pub user_id: String,
    /// Arbitrary JSON user data. [`PresenceMember::new`] sets it to
    /// `Value::Null`; [`PresenceMember::with_info`] replaces it.
    pub user_info: Value,
}
impl PresenceMember {
    /// Creates a member for `user_id` connected on `socket_id`.
    ///
    /// `user_info` starts as JSON `null`.
    pub fn new(socket_id: impl Into<String>, user_id: impl Into<String>) -> Self {
        let socket_id: String = socket_id.into();
        let user_id: String = user_id.into();
        Self {
            socket_id,
            user_id,
            user_info: Value::Null,
        }
    }

    /// Builder-style setter that serializes `info` into `user_info`.
    ///
    /// If serialization fails, `user_info` falls back to JSON `null` instead
    /// of reporting an error.
    pub fn with_info(self, info: impl Serialize) -> Self {
        // `Value::default()` is `Value::Null`, matching the previous fallback.
        let user_info = serde_json::to_value(info).unwrap_or_default();
        Self { user_info, ..self }
    }
}
// Unit tests for channel classification, subscriber bookkeeping and
// presence-member handling.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_channel_type_from_name() {
        assert_eq!(ChannelType::from_name("orders"), ChannelType::Public);
        assert_eq!(
            ChannelType::from_name("private-orders.1"),
            ChannelType::Private
        );
        assert_eq!(
            ChannelType::from_name("presence-chat.1"),
            ChannelType::Presence
        );
    }

    #[test]
    fn test_channel_type_requires_dash_and_exact_case() {
        // Only the full `private-` / `presence-` prefixes count.
        assert_eq!(ChannelType::from_name("private"), ChannelType::Public);
        assert_eq!(ChannelType::from_name("presence"), ChannelType::Public);
        assert_eq!(ChannelType::from_name("Private-orders"), ChannelType::Public);
        assert_eq!(ChannelType::from_name(""), ChannelType::Public);
    }

    #[test]
    fn test_channel_requires_auth() {
        assert!(!ChannelType::Public.requires_auth());
        assert!(ChannelType::Private.requires_auth());
        assert!(ChannelType::Presence.requires_auth());
    }

    #[test]
    fn test_channel_info() {
        let mut channel = ChannelInfo::new("private-orders.1");
        assert_eq!(channel.channel_type, ChannelType::Private);
        assert!(channel.is_empty());
        channel.add_subscriber("socket_1".into());
        assert!(!channel.is_empty());
        assert_eq!(channel.subscriber_count(), 1);
        channel.remove_subscriber("socket_1");
        assert!(channel.is_empty());
    }

    #[test]
    fn test_subscriber_add_remove_report_changes() {
        let mut channel = ChannelInfo::new("orders");
        assert!(channel.add_subscriber("socket_1".into()));
        // Re-adding the same socket is a no-op and reports `false`.
        assert!(!channel.add_subscriber("socket_1".into()));
        assert_eq!(channel.subscriber_count(), 1);
        assert!(!channel.remove_subscriber("socket_2"));
        assert!(channel.remove_subscriber("socket_1"));
        assert!(!channel.remove_subscriber("socket_1"));
    }

    #[test]
    fn test_presence_members() {
        let mut channel = ChannelInfo::new("presence-chat.1");
        let member = PresenceMember::new("socket_1", "user_1")
            .with_info(serde_json::json!({"name": "Alice"}));
        channel.add_subscriber("socket_1".into());
        channel.add_member(member);
        assert_eq!(channel.get_members().len(), 1);
        assert_eq!(channel.get_members()[0].user_id, "user_1");
    }

    #[test]
    fn test_presence_member_replaced_per_user() {
        let mut channel = ChannelInfo::new("presence-chat.1");
        channel.add_member(PresenceMember::new("socket_1", "user_1"));
        channel.add_member(PresenceMember::new("socket_2", "user_1"));
        // Same user_id: the later entry replaces the earlier one.
        assert_eq!(channel.get_members().len(), 1);
        assert_eq!(channel.get_members()[0].socket_id, "socket_2");
        // The superseded socket no longer owns an entry.
        assert!(channel.remove_member("socket_1").is_none());
        let removed = channel
            .remove_member("socket_2")
            .expect("socket_2 is a member");
        assert_eq!(removed.user_id, "user_1");
        assert!(channel.get_members().is_empty());
    }

    #[test]
    fn test_presence_member_user_info() {
        let member = PresenceMember::new("socket_1", "user_1")
            .with_info(serde_json::json!({"name": "Alice"}));
        assert_eq!(member.user_info["name"], "Alice");
        assert_eq!(
            PresenceMember::new("socket_2", "user_2").user_info,
            serde_json::Value::Null
        );
    }
}