use std::{
collections::{HashMap, HashSet},
fmt::Display,
sync::{Arc, Mutex, MutexGuard, OnceLock},
time::Duration,
};
use tokio::{
sync::{Notify, mpsc},
time::Instant,
};
use crate::{
base::{SessionPath, Visibility},
identity::{self, AuthError},
protocol::{self, AdminOp, ChannelInfo, HistoryMessage, MachineInfo, Payload, ProtocolError, ProtocolMessage},
store::{ChannelRecord, Store},
};
use super::AclError;
/// Sending half of one session's bounded outbound queue (tokio `mpsc::Sender`).
type Outbound = mpsc::Sender<ProtocolMessage>;
/// What an admin operation yields: a reply message or a protocol error.
type Reply = Result<ProtocolMessage, ProtocolError>;
/// Maximum number of history rows a single `read_since` call returns.
const HISTORY_PAGE_CAP: usize = 500;
/// Per-session record kept in `HubState::sessions`, keyed by the session path.
struct SessionEntry {
/// Username the session authenticated as.
user: String,
/// Machine name the session authenticated as.
machine: String,
/// Queue the hub pushes messages into with non-blocking `try_send`.
outbound: Outbound,
/// Kill switch shared with the connection task; fired to force a disconnect.
kill: Arc<Kill>,
/// Channels this session is subscribed to (mirrors `HubState::subscriptions`).
channels: HashSet<String>,
/// Time of the last `touch`/`attach`; compared against the idle timeout in `reap_idle`.
last_seen: Instant,
}
/// One-shot kill switch shared by the hub and a session's connection task.
///
/// The first reason passed to [`Kill::fire`] is kept; later calls still
/// signal but do not overwrite it.
pub(crate) struct Kill {
    notify: Notify,
    reason: OnceLock<&'static str>,
}
impl Kill {
    /// Creates an unfired switch with no reason recorded.
    fn new() -> Self {
        Kill {
            reason: OnceLock::new(),
            notify: Notify::new(),
        }
    }
    /// Records `reason` if none is recorded yet, then signals a waiter.
    ///
    /// `Notify::notify_one` stores a permit when nobody is currently waiting,
    /// so a later call to [`Kill::notified`] still completes.
    pub(crate) fn fire(&self, reason: &'static str) {
        // `set` returns Err when a reason is already stored; the first one wins.
        let _first_wins = self.reason.set(reason);
        self.notify.notify_one();
    }
    /// Waits for a `fire` signal (one permit per `fire`, per `Notify::notify_one`).
    pub(crate) async fn notified(&self) {
        let signal = self.notify.notified();
        signal.await;
    }
    /// The first reason passed to `fire`, or a generic fallback if it never fired.
    pub(crate) fn reason(&self) -> &'static str {
        match self.reason.get() {
            Some(&recorded) => recorded,
            None => "session terminated",
        }
    }
}
/// Mutable in-memory hub state, guarded by `Hub::state`.
///
/// `subscriptions[channel]` and each `SessionEntry::channels` are updated
/// together by `subscribe`/`unsubscribe_*`/`take_session` so that they mirror
/// each other.
#[derive(Default)]
struct HubState {
// Live sessions, keyed by session path.
sessions: HashMap<SessionPath, SessionEntry>,
// Channel name -> session paths currently subscribed to it.
subscriptions: HashMap<String, HashSet<SessionPath>>,
// Channel name -> banned usernames (loaded from the store in `Hub::new`).
bans: HashMap<String, HashSet<String>>,
}
/// Central coordinator: persistent data lives in `store`, live sessions,
/// subscriptions and bans live in `state`.
pub(crate) struct Hub {
// Persistent backing store (users, machines, channels, invites, history, bans).
store: Store,
// Server admin allowlist: username -> optional bound public key (base64).
admins: super::AdminAllowlist,
// Instance identifier loaded from the store at construction.
instance_id: String,
// In-memory state; locked only for short synchronous sections.
state: Mutex<HubState>,
}
/// Hub operations.
///
/// In-memory state (`HubState`) sits behind a single `std::sync::Mutex`.
/// Every method takes that lock only for short synchronous sections and
/// releases it before awaiting the store.
impl Hub {
    /// Builds a hub over `store`, preloading every persisted ban into memory
    /// so that ban checks never need a store round-trip.
    ///
    /// # Errors
    /// Propagates store errors from loading bans or the instance id.
    pub(crate) async fn new(store: Store, admins: super::AdminAllowlist) -> crate::base::Res<Arc<Self>> {
        let mut bans: HashMap<String, HashSet<String>> = HashMap::new();
        for (channel, user) in store.list_bans().await? {
            bans.entry(channel).or_default().insert(user);
        }
        let instance_id = store.instance_id().await?;
        Ok(Arc::new(Self {
            store,
            admins,
            instance_id,
            state: Mutex::new(HubState { bans, ..HubState::default() }),
        }))
    }

    /// Locks the in-memory hub state.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned (a previous holder panicked).
    fn state(&self) -> MutexGuard<'_, HubState> {
        self.state.lock().expect("hub state mutex poisoned")
    }

    /// Whether `user` is on the server admin allowlist.
    pub(crate) fn is_admin(&self, user: &str) -> bool {
        self.admins.contains_key(user)
    }

    /// This server's instance identifier, loaded from the store at startup.
    pub(crate) fn instance_id(&self) -> &str {
        &self.instance_id
    }

    /// Lists the machines enrolled under `user`.
    ///
    /// # Errors
    /// `ProtocolError::Internal` on store failure.
    pub(crate) async fn list_machines(&self, user: &str) -> Result<Vec<MachineInfo>, ProtocolError> {
        let machines = self.store.list_machines(user).await.map_err(internal)?;
        Ok(machines
            .into_iter()
            .map(|m| MachineInfo {
                name: m.name,
                pubkey: m.pubkey,
                added_at: m.added_at,
            })
            .collect())
    }

    /// Lists every registered username. Restricted to server admins.
    ///
    /// # Errors
    /// `AclError::NotAdmin` for non-admin callers; `Internal` on store failure.
    pub(crate) async fn list_users(&self, caller: &str) -> Result<Vec<String>, ProtocolError> {
        if !self.is_admin(caller) {
            return Err(AclError::NotAdmin.into());
        }
        let users = self.store.list_users().await.map_err(internal)?;
        Ok(users.into_iter().map(|u| u.username).collect())
    }

    /// Registers a new user `username` whose first machine `machine` is keyed by `pubkey`.
    ///
    /// An allowlisted admin whose entry carries a bound key may only register
    /// with exactly that key.
    ///
    /// # Errors
    /// - `AuthError::Reserved` when an admin name is claimed with a different key.
    /// - `AuthError::UsernameTaken` when the user already exists.
    /// - `AuthError::Malformed` when the key is already enrolled.
    /// - `Internal` on store failure.
    #[tracing::instrument(level = "debug", skip(self, pubkey))]
    pub(crate) async fn register(&self, username: &str, machine: &str, pubkey: &[u8]) -> Result<(), ProtocolError> {
        let pubkey_b64 = identity::encode_key(pubkey);
        if let Some(Some(bound)) = self.admins.get(username)
            && &pubkey_b64 != bound
        {
            return Err(AuthError::Reserved(username.to_owned()).into());
        }
        if self.store.get_user(username).await.map_err(internal)?.is_some() {
            return Err(AuthError::UsernameTaken(username.to_owned()).into());
        }
        self.ensure_key_unenrolled(&pubkey_b64).await?;
        // NOTE(review): the checks above and the inserts below are not atomic;
        // uniqueness under concurrent registration relies on the store's own
        // constraints — confirm against the schema.
        self.store.create_user(username).await.map_err(internal)?;
        self.store.create_machine(username, machine, &pubkey_b64).await.map_err(internal)?;
        Ok(())
    }

    /// Maps an enrolled public key to its `(user, machine)` pair.
    ///
    /// # Errors
    /// `AuthError::UnknownKey` when no machine carries the key; `Internal` on store failure.
    pub(crate) async fn resolve(&self, pubkey: &[u8]) -> Result<(String, String), ProtocolError> {
        let pubkey_b64 = identity::encode_key(pubkey);
        let machine = self
            .store
            .get_machine_by_pubkey(&pubkey_b64)
            .await
            .map_err(internal)?
            .ok_or_else(|| ProtocolError::from(AuthError::UnknownKey))?;
        Ok((machine.user, machine.name))
    }

    /// Registers a live session at `path` and returns its kill switch.
    ///
    /// Any session already attached at the same path is first unsubscribed
    /// from all of its channels and has its kill switch fired with a
    /// "superseded" reason, so at most one session exists per path.
    pub(crate) fn attach(&self, path: &SessionPath, user: &str, machine: &str, outbound: Outbound) -> Arc<Kill> {
        let mut st = self.state();
        if let Some(previous) = Self::take_session(&mut st, path) {
            previous.kill.fire("session superseded by a newer connection for the same session path");
            tracing::info!(%path, "session superseded by a newer connection");
        }
        tracing::info!(%path, user, machine, "session established");
        let kill = Arc::new(Kill::new());
        st.sessions.insert(
            path.clone(),
            SessionEntry {
                user: user.to_owned(),
                machine: machine.to_owned(),
                outbound,
                kill: Arc::clone(&kill),
                channels: HashSet::new(),
                last_seen: Instant::now(),
            },
        );
        kill
    }

    /// Removes the session at `path`, but only if it is still the session that
    /// owns `kill`.
    ///
    /// A connection that was superseded and detaches late must not tear down
    /// its replacement.
    pub(crate) fn detach(&self, path: &SessionPath, kill: &Arc<Kill>) {
        let mut st = self.state();
        if st.sessions.get(path).is_some_and(|e| Arc::ptr_eq(&e.kill, kill)) {
            Self::take_session(&mut st, path);
            tracing::debug!(%path, "session detached");
        }
    }

    /// Refreshes the idle timer of the session at `path`, if it is attached.
    pub(crate) fn touch(&self, path: &SessionPath) {
        if let Some(entry) = self.state().sessions.get_mut(path) {
            entry.last_seen = Instant::now();
        }
    }

    /// Kills and removes every session idle for at least `timeout`.
    ///
    /// Returns the number of sessions reaped.
    pub(crate) fn reap_idle(&self, timeout: Duration) -> usize {
        let now = Instant::now();
        self.kill_where("idle timeout: no heartbeat", |e| now.saturating_duration_since(e.last_seen) >= timeout)
    }

    /// Subscribes the session at `path` (owned by `user`) to `channel`.
    ///
    /// Public and unlisted channels are open. A private channel requires
    /// existing membership or a valid invite `token`; a redeemed invite adds
    /// `user` to the channel ACL.
    ///
    /// # Errors
    /// - `AclError::ChannelNotFound` when the channel is missing, or private and unreachable.
    /// - `AclError::ChannelPrivate` when the user is banned, or the session is no longer attached.
    /// - `Internal` on store failure.
    #[tracing::instrument(level = "debug", skip(self, token), fields(path = %path))]
    pub(crate) async fn join(&self, user: &str, path: &SessionPath, channel: &str, token: Option<&str>) -> Result<(), ProtocolError> {
        // A missing channel and a private channel the caller cannot enter yield
        // the same error, so `join` is not an existence oracle for private
        // channels (matching `who`).
        let not_found = || ProtocolError::from(AclError::ChannelNotFound(channel.to_owned()));
        let record = self.store.get_channel(channel).await.map_err(internal)?.ok_or_else(not_found)?;
        if self.is_banned(channel, user) {
            return Err(AclError::ChannelPrivate(channel.to_owned()).into());
        }
        let already_member = self.store.is_channel_member(channel, user).await.map_err(internal)?;
        if !already_member {
            match parse_visibility(&record.visibility) {
                Visibility::Public | Visibility::Unlisted => {}
                Visibility::Private => {
                    let token = token.ok_or_else(not_found)?;
                    self.redeem_invite(channel, token).await?;
                    self.store.add_channel_member(channel, user).await.map_err(internal)?;
                }
            }
        }
        // `subscribe` re-checks the ban under the state lock, which closes the
        // race with a ban issued after the check above.
        if !self.subscribe(path, channel) {
            return Err(AclError::ChannelPrivate(channel.to_owned()).into());
        }
        Ok(())
    }

    /// Unsubscribes the session at `path` from `channel`. Does nothing if it
    /// was not subscribed.
    pub(crate) fn leave(&self, path: &SessionPath, channel: &str) {
        self.unsubscribe_session(path, channel);
    }

    /// Lists channels visible to `user`.
    ///
    /// Server admins see every channel. Other users see public channels plus
    /// the channels they are members of.
    ///
    /// # Errors
    /// `Internal` on store failure.
    pub(crate) async fn list_channels(&self, user: &str) -> Result<Vec<ChannelInfo>, ProtocolError> {
        let channels = self.store.list_channels().await.map_err(internal)?;
        let memberships: HashSet<String> = self.store.list_user_memberships(user).await.map_err(internal)?.into_iter().collect();
        let admin = self.is_admin(user);
        let infos = channels
            .into_iter()
            .filter_map(|c| {
                let visibility = parse_visibility(&c.visibility);
                let member = memberships.contains(&c.name);
                let visible = admin || matches!(visibility, Visibility::Public) || member;
                visible.then_some(ChannelInfo { name: c.name, visibility, member })
            })
            .collect();
        Ok(infos)
    }

    /// Lists session paths.
    ///
    /// With no `channel`, returns every attached session. With a `channel`,
    /// returns its current subscribers. Unlisted and private channels are only
    /// answered for members and server admins; for everyone else they look
    /// identical to a missing channel.
    ///
    /// # Errors
    /// `AclError::ChannelNotFound` when the channel is missing or not visible; `Internal` on store failure.
    pub(crate) async fn who(&self, user: &str, channel: Option<&str>) -> Result<Vec<SessionPath>, ProtocolError> {
        let Some(channel) = channel else {
            return Ok(self.present_paths());
        };
        let not_found = || ProtocolError::from(AclError::ChannelNotFound(channel.to_owned()));
        let record = self.store.get_channel(channel).await.map_err(internal)?.ok_or_else(not_found)?;
        let allowed = match parse_visibility(&record.visibility) {
            Visibility::Public => true,
            // The in-memory admin check goes first, which spares a store round-trip for admins.
            Visibility::Unlisted | Visibility::Private => {
                self.is_admin(user) || self.store.is_channel_member(channel, user).await.map_err(internal)?
            }
        };
        if !allowed {
            return Err(not_found());
        }
        let st = self.state();
        Ok(st.subscriptions.get(channel).map(|subs| subs.iter().cloned().collect()).unwrap_or_default())
    }

    /// Broadcasts `payload` from session `from` to every other subscriber of `channel`.
    ///
    /// The message is also appended to channel history. Failing to encode or
    /// store it is logged and does not block live delivery. Delivery never
    /// waits on a recipient (see `deliver`).
    ///
    /// # Errors
    /// `AclError::NotMember` when `from` is not subscribed to `channel`.
    #[tracing::instrument(level = "debug", skip(self, payload), fields(from = %from))]
    pub(crate) async fn post(&self, from: &SessionPath, channel: &str, payload: Payload) -> Result<(), ProtocolError> {
        // Snapshot the recipients under the lock, then release it before awaiting the store.
        let targets: Vec<(Arc<Kill>, Outbound)> = {
            let st = self.state();
            let subs = st.subscriptions.get(channel).ok_or_else(|| ProtocolError::from(AclError::NotMember(channel.to_owned())))?;
            if !subs.contains(from) {
                return Err(AclError::NotMember(channel.to_owned()).into());
            }
            subs.iter()
                .filter(|p| *p != from)
                .filter_map(|p| st.sessions.get(p).map(|e| (Arc::clone(&e.kill), e.outbound.clone())))
                .collect()
        };
        match protocol::encode_payload(&payload) {
            Ok(bytes) => {
                let ts_ms = chrono::Utc::now().timestamp_millis();
                if let Err(err) = self.store.append_message(channel, &from.to_string(), &bytes, ts_ms).await {
                    tracing::warn!(%channel, error = %err, "failed to retain channel history; delivering live only");
                }
            }
            Err(err) => tracing::warn!(%channel, error = %err, "failed to encode payload for retention; delivering live only"),
        }
        let msg = ProtocolMessage::ChannelMsg {
            channel: channel.to_owned(),
            from: from.clone(),
            payload,
        };
        for (kill, tx) in targets {
            Self::deliver(&kill, &tx, msg.clone());
        }
        Ok(())
    }

    /// Returns up to `HISTORY_PAGE_CAP` retained messages of `channel` newer than `since_ms`.
    ///
    /// Rows whose payload or sender fails to decode are skipped. The number
    /// skipped is logged, never the content.
    ///
    /// # Errors
    /// `AclError::NotMember` when `caller` is not subscribed; `Internal` on store failure.
    #[tracing::instrument(level = "debug", skip(self), fields(caller = %caller))]
    pub(crate) async fn read_since(&self, caller: &SessionPath, channel: &str, since_ms: i64) -> Result<ProtocolMessage, ProtocolError> {
        {
            let st = self.state();
            let subs = st.subscriptions.get(channel).ok_or_else(|| ProtocolError::from(AclError::NotMember(channel.to_owned())))?;
            if !subs.contains(caller) {
                return Err(AclError::NotMember(channel.to_owned()).into());
            }
        }
        let rows = self.store.read_messages_since(channel, since_ms, HISTORY_PAGE_CAP).await.map_err(internal)?;
        let mut skipped = 0usize;
        let messages = rows
            .into_iter()
            .filter_map(|row| {
                let decoded = protocol::decode_payload(&row.payload).ok().zip(row.from.parse::<SessionPath>().ok());
                match decoded {
                    Some((payload, from)) => Some(HistoryMessage { from, ts_ms: row.ts_ms, payload }),
                    None => {
                        skipped += 1;
                        None
                    }
                }
            })
            .collect();
        if skipped > 0 {
            tracing::warn!(%channel, skipped, "dropped undecodable rows from channel history");
        }
        Ok(ProtocolMessage::History { channel: channel.to_owned(), messages })
    }

    /// Sends `payload` directly from `from` to the single session `target`.
    ///
    /// # Errors
    /// `ProtocolError::NotFound` when `target` is not attached.
    #[tracing::instrument(level = "debug", skip(self, payload), fields(from = %from, target = %target))]
    pub(crate) fn whisper(&self, from: &SessionPath, target: &SessionPath, payload: Payload) -> Result<(), ProtocolError> {
        let target_entry = self.state().sessions.get(target).map(|e| (Arc::clone(&e.kill), e.outbound.clone()));
        let Some((kill, outbound)) = target_entry else {
            return Err(ProtocolError::NotFound(format!("session `{target}` is not online")));
        };
        let msg = ProtocolMessage::Whisper {
            from: from.clone(),
            target: target.clone(),
            payload,
        };
        Self::deliver(&kill, &outbound, msg);
        Ok(())
    }

    /// Executes an administrative operation on behalf of `user`.
    ///
    /// Channel-scoped operations require the channel creator or a server admin
    /// (`authorize_channel_admin`). `UserRemove` requires a server admin.
    /// Machine operations act on the caller's own account.
    ///
    /// # Errors
    /// `AclError::NotAdmin` / `AclError::ChannelNotFound` on authorization failure,
    /// `AuthError::Malformed` for an already-enrolled machine key, `Internal` on store failure.
    #[tracing::instrument(level = "debug", skip(self, op), fields(op = op.name()))]
    pub(crate) async fn admin(&self, user: &str, op: AdminOp) -> Reply {
        match op {
            AdminOp::CreateChannel { name, visibility } => {
                self.store.create_channel(&name, visibility, user).await.map_err(internal)?;
                Ok(ack(name))
            }
            AdminOp::DeleteChannel { name } => {
                self.authorize_channel_admin(&name, user).await?;
                self.store.delete_channel(&name).await.map_err(internal)?;
                self.drop_channel(&name);
                Ok(ack(name))
            }
            AdminOp::RenameChannel { name, new_name } => {
                self.authorize_channel_admin(&name, user).await?;
                self.store.rename_channel(&name, &new_name).await.map_err(internal)?;
                self.rename_channel_subscriptions(&name, &new_name);
                Ok(ack(new_name))
            }
            AdminOp::SetVisibility { name, visibility } => {
                self.authorize_channel_admin(&name, user).await?;
                self.store.set_channel_visibility(&name, visibility).await.map_err(internal)?;
                Ok(ack(name))
            }
            AdminOp::AclAdd { channel, user: target } => {
                self.authorize_channel_admin(&channel, user).await?;
                self.store.add_channel_member(&channel, &target).await.map_err(internal)?;
                // Granting membership also lifts any ban on that user.
                self.remove_ban(&channel, &target).await?;
                Ok(ack(target))
            }
            AdminOp::AclRemove { channel, user: target } => {
                self.authorize_channel_admin(&channel, user).await?;
                self.store.remove_channel_member(&channel, &target).await.map_err(internal)?;
                self.unsubscribe_user(&target, &channel);
                Ok(ack(target))
            }
            AdminOp::AclList { channel } => {
                self.authorize_channel_admin(&channel, user).await?;
                let users = self.store.list_channel_members(&channel).await.map_err(internal)?;
                Ok(ProtocolMessage::UserList { users })
            }
            AdminOp::Unban { channel, user: target } => {
                self.authorize_channel_admin(&channel, user).await?;
                self.remove_ban(&channel, &target).await?;
                Ok(ack(target))
            }
            AdminOp::BanList { channel } => {
                self.authorize_channel_admin(&channel, user).await?;
                let users = self.store.list_channel_bans(&channel).await.map_err(internal)?;
                Ok(ProtocolMessage::UserList { users })
            }
            AdminOp::InviteList { channel } => self.list_channel_invites(user, &channel).await,
            AdminOp::InviteCreate { channel, uses, expires_in_secs } => self.create_invite(user, &channel, uses, expires_in_secs).await,
            AdminOp::InviteRevoke { token } => {
                // Same Ack whether the token is unknown or the caller may not
                // revoke it, so revoke does not reveal token existence. Store
                // errors do propagate: a real admin must never get an Ack for a
                // revoke that silently did not happen.
                if let Some(invite) = self.store.get_invite(&token).await.map_err(internal)?
                    && self.is_channel_admin(&invite.channel, user).await?
                {
                    self.store.delete_invite(&token).await.map_err(internal)?;
                }
                Ok(ack(token))
            }
            AdminOp::Kick { channel, target } => {
                self.authorize_channel_admin(&channel, user).await?;
                // A full session path kicks one session; anything else is treated as a username.
                if let Ok(path) = target.parse::<SessionPath>() {
                    self.unsubscribe_session(&path, &channel);
                } else {
                    self.unsubscribe_user(&target, &channel);
                }
                Ok(ack(target))
            }
            AdminOp::Ban { channel, user: target } => {
                self.authorize_channel_admin(&channel, user).await?;
                self.store.remove_channel_member(&channel, &target).await.map_err(internal)?;
                self.add_ban(&channel, &target).await?;
                self.unsubscribe_user(&target, &channel);
                Ok(ack(target))
            }
            AdminOp::UserRemove { username } => {
                if !self.is_admin(user) {
                    return Err(AclError::NotAdmin.into());
                }
                self.remove_user(&username).await?;
                Ok(ack(username))
            }
            AdminOp::MachineRemove { name } => {
                self.store.delete_machine(user, &name).await.map_err(internal)?;
                self.force_drop_machine(user, &name);
                Ok(ack(name))
            }
            AdminOp::MachineAdd { name, pubkey } => {
                let pubkey_b64 = identity::encode_key(&pubkey);
                // Same uniqueness rule as `register`: one key, one machine.
                self.ensure_key_unenrolled(&pubkey_b64).await?;
                self.store.create_machine(user, &name, &pubkey_b64).await.map_err(internal)?;
                Ok(ack(name))
            }
        }
    }

    /// Creates an invite token for `channel` with optional use and lifetime limits.
    ///
    /// # Errors
    /// Authorization errors, `MalformedFrame` for an unrepresentable expiry, `Internal` on store failure.
    async fn create_invite(&self, user: &str, channel: &str, max_uses: Option<u32>, lifetime_secs: Option<u64>) -> Reply {
        self.authorize_channel_admin(channel, user).await?;
        let token = identity::generate_token().map_err(internal)?;
        let expires_at = match lifetime_secs {
            Some(secs) => Some(invite_expiry(secs)?),
            None => None,
        };
        self.store.create_invite(channel, &token, max_uses.map(i64::from), expires_at, user).await.map_err(internal)?;
        Ok(ProtocolMessage::InviteToken { token })
    }

    /// Lists the outstanding invites of `channel` for a channel admin.
    async fn list_channel_invites(&self, user: &str, channel: &str) -> Reply {
        self.authorize_channel_admin(channel, user).await?;
        let invites = self
            .store
            .list_invites(channel)
            .await
            .map_err(internal)?
            .into_iter()
            .map(|i| crate::protocol::InviteInfo {
                token: i.token,
                uses_remaining: i.uses_remaining,
                expires_at: i.expires_at,
            })
            .collect();
        Ok(ProtocolMessage::InviteList { invites })
    }

    /// Purges `username`: the channels they created, their memberships, their
    /// machines and the user row. Afterwards their live sessions are killed.
    ///
    /// NOTE(review): bans naming this user are not cleared here — confirm
    /// whether the store cascades them.
    async fn remove_user(&self, username: &str) -> Result<(), ProtocolError> {
        for channel in self.store.list_channels_created_by(username).await.map_err(internal)? {
            self.store.delete_channel(&channel).await.map_err(internal)?;
            self.drop_channel(&channel);
        }
        self.store.delete_user_memberships(username).await.map_err(internal)?;
        for machine in self.store.list_machines(username).await.map_err(internal)? {
            self.store.delete_machine(username, &machine.name).await.map_err(internal)?;
        }
        self.store.delete_user(username).await.map_err(internal)?;
        self.force_drop_user(username);
        Ok(())
    }

    /// Loads `channel` and checks that `user` is its creator or a server admin.
    ///
    /// # Errors
    /// `AclError::ChannelNotFound` when the channel is missing, `AclError::NotAdmin` otherwise.
    async fn authorize_channel_admin(&self, channel: &str, user: &str) -> Result<ChannelRecord, ProtocolError> {
        let record = self
            .store
            .get_channel(channel)
            .await
            .map_err(internal)?
            .ok_or_else(|| ProtocolError::from(AclError::ChannelNotFound(channel.to_owned())))?;
        if record.created_by == user || self.is_admin(user) {
            Ok(record)
        } else {
            Err(AclError::NotAdmin.into())
        }
    }

    /// Whether `user` may administer `channel`. Returns `false` when the channel does not exist.
    ///
    /// # Errors
    /// `Internal` on store failure. Failures are reported, not treated as "not an admin".
    async fn is_channel_admin(&self, channel: &str, user: &str) -> Result<bool, ProtocolError> {
        let record = self.store.get_channel(channel).await.map_err(internal)?;
        Ok(record.is_some_and(|r| r.created_by == user || self.is_admin(user)))
    }

    /// Validates invite `token` for `channel` and consumes one use when the invite is use-limited.
    ///
    /// An expired invite is deleted. Every refusal reports `ChannelNotFound`.
    async fn redeem_invite(&self, channel: &str, token: &str) -> Result<(), ProtocolError> {
        let invite = self
            .store
            .get_invite(token)
            .await
            .map_err(internal)?
            .filter(|inv| inv.channel == channel)
            .ok_or_else(|| ProtocolError::from(AclError::ChannelNotFound(channel.to_owned())))?;
        if invite.expires_at.as_deref().is_some_and(is_expired) {
            self.store.delete_invite(token).await.map_err(internal)?;
            return Err(AclError::ChannelNotFound(channel.to_owned()).into());
        }
        // `try_consume_invite_use` returns false when no use could be taken
        // (presumably exhausted); the concurrency test relies on it being atomic.
        if invite.uses_remaining.is_some() && !self.store.try_consume_invite_use(token).await.map_err(internal)? {
            return Err(AclError::ChannelNotFound(channel.to_owned()).into());
        }
        Ok(())
    }

    /// Rejects a public key (base64) that is already enrolled on any machine.
    ///
    /// # Errors
    /// `AuthError::Malformed` when the key is taken; `Internal` on store failure.
    async fn ensure_key_unenrolled(&self, pubkey_b64: &str) -> Result<(), ProtocolError> {
        if self.store.get_machine_by_pubkey(pubkey_b64).await.map_err(internal)?.is_some() {
            return Err(AuthError::Malformed("public key is already enrolled".to_owned()).into());
        }
        Ok(())
    }

    /// Non-blocking delivery of `msg` into one session's outbound queue.
    ///
    /// The hub never waits on a recipient. A full queue fires the session's
    /// kill switch as a slow consumer. A closed queue (receiver already gone)
    /// also fires it, with a distinct reason, so the stale session can be torn down.
    fn deliver(kill: &Kill, outbound: &Outbound, msg: ProtocolMessage) {
        match outbound.try_send(msg) {
            Ok(()) => {}
            Err(mpsc::error::TrySendError::Full(_)) => kill.fire("slow consumer: outbound queue overflowed"),
            Err(mpsc::error::TrySendError::Closed(_)) => kill.fire("session outbound queue closed"),
        }
    }

    /// Subscribes the session at `path` to `channel` under a single lock acquisition.
    ///
    /// Returns `false`, and subscribes nothing, if the session's user is banned
    /// from the channel or the session is not attached.
    fn subscribe(&self, path: &SessionPath, channel: &str) -> bool {
        let mut st = self.state();
        if st.bans.get(channel).is_some_and(|banned| banned.contains(&path.user)) {
            return false;
        }
        let Some(entry) = st.sessions.get_mut(path) else {
            return false;
        };
        entry.channels.insert(channel.to_owned());
        st.subscriptions.entry(channel.to_owned()).or_default().insert(path.clone());
        true
    }

    /// Removes one session's subscription to `channel`.
    fn unsubscribe_session(&self, path: &SessionPath, channel: &str) {
        let mut st = self.state();
        Self::unsubscribe_locked(&mut st, path, channel);
    }

    /// Removes every session belonging to `user` from `channel`'s subscribers.
    fn unsubscribe_user(&self, user: &str, channel: &str) {
        let mut st = self.state();
        let Some(subs) = st.subscriptions.get(channel) else {
            return;
        };
        let paths: Vec<SessionPath> = subs.iter().filter(|p| st.sessions.get(*p).is_some_and(|e| e.user == user)).cloned().collect();
        for path in &paths {
            Self::unsubscribe_locked(&mut st, path, channel);
        }
    }

    /// Forgets all in-memory subscriptions and bans of a deleted channel.
    fn drop_channel(&self, channel: &str) {
        let mut st = self.state();
        if let Some(subs) = st.subscriptions.remove(channel) {
            for path in subs {
                if let Some(entry) = st.sessions.get_mut(&path) {
                    entry.channels.remove(channel);
                }
            }
        }
        st.bans.remove(channel);
    }

    /// Moves in-memory subscriptions and bans from `old` to `new` after a rename.
    fn rename_channel_subscriptions(&self, old: &str, new: &str) {
        let mut st = self.state();
        if let Some(subs) = st.subscriptions.remove(old) {
            for path in &subs {
                if let Some(entry) = st.sessions.get_mut(path) {
                    entry.channels.remove(old);
                    entry.channels.insert(new.to_owned());
                }
            }
            st.subscriptions.insert(new.to_owned(), subs);
        }
        if let Some(banned) = st.bans.remove(old) {
            st.bans.insert(new.to_owned(), banned);
        }
    }

    /// Kills every live session belonging to `user`.
    fn force_drop_user(&self, user: &str) {
        self.kill_where("user removed from this server", |e| e.user == user);
    }

    /// Kills every live session of `user` on machine `machine`.
    fn force_drop_machine(&self, user: &str, machine: &str) {
        self.kill_where("machine key revoked", |e| e.user == user && e.machine == machine);
    }

    /// Removes and kills every session for which `doomed` returns true, under one lock.
    ///
    /// Returns how many sessions were killed.
    fn kill_where(&self, reason: &'static str, doomed: impl Fn(&SessionEntry) -> bool) -> usize {
        let mut st = self.state();
        let paths: Vec<SessionPath> = st.sessions.iter().filter(|&(_, entry)| doomed(entry)).map(|(p, _)| p.clone()).collect();
        for path in &paths {
            Self::kill_locked(&mut st, path, reason);
        }
        paths.len()
    }

    /// Whether `user` is banned from `channel` (in-memory check).
    fn is_banned(&self, channel: &str, user: &str) -> bool {
        self.state().bans.get(channel).is_some_and(|banned| banned.contains(user))
    }

    /// Persists a ban first, then mirrors it in memory.
    async fn add_ban(&self, channel: &str, user: &str) -> Result<(), ProtocolError> {
        self.store.add_ban(channel, user).await.map_err(internal)?;
        self.state().bans.entry(channel.to_owned()).or_default().insert(user.to_owned());
        Ok(())
    }

    /// Removes a ban from the store first, then from memory.
    async fn remove_ban(&self, channel: &str, user: &str) -> Result<(), ProtocolError> {
        self.store.remove_ban(channel, user).await.map_err(internal)?;
        if let Some(banned) = self.state().bans.get_mut(channel) {
            banned.remove(user);
        }
        Ok(())
    }

    /// Drops `path` from `channel`'s subscriber set, and drops the channel from
    /// the session's own set. An emptied subscriber set is removed entirely.
    fn unsubscribe_locked(st: &mut HubState, path: &SessionPath, channel: &str) {
        if let Some(subs) = st.subscriptions.get_mut(channel) {
            subs.remove(path);
            if subs.is_empty() {
                st.subscriptions.remove(channel);
            }
        }
        if let Some(entry) = st.sessions.get_mut(path) {
            entry.channels.remove(channel);
        }
    }

    /// Removes the session at `path` and all of its subscriptions. Returns its entry if it existed.
    fn take_session(st: &mut HubState, path: &SessionPath) -> Option<SessionEntry> {
        let entry = st.sessions.remove(path)?;
        for channel in &entry.channels {
            Self::unsubscribe_locked(st, path, channel);
        }
        Some(entry)
    }

    /// Removes the session at `path` and fires its kill switch with `reason`.
    fn kill_locked(st: &mut HubState, path: &SessionPath, reason: &'static str) {
        if let Some(entry) = Self::take_session(st, path) {
            entry.kill.fire(reason);
        }
    }

    /// Paths of every currently attached session.
    pub(crate) fn present_paths(&self) -> Vec<SessionPath> {
        self.state().sessions.keys().cloned().collect()
    }

    /// Test helper: whether a session is attached at `path`.
    #[cfg(test)]
    pub(crate) fn is_present(&self, path: &SessionPath) -> bool {
        self.state().sessions.contains_key(path)
    }

    /// Test helper: current subscribers of `channel`.
    #[cfg(test)]
    pub(crate) fn subscribers(&self, channel: &str) -> Vec<SessionPath> {
        self.state().subscriptions.get(channel).map(|subs| subs.iter().cloned().collect()).unwrap_or_default()
    }
}
/// Wraps any displayable error (e.g. from the store) as `ProtocolError::Internal`,
/// keeping only its rendered message.
fn internal<E: Display>(err: E) -> ProtocolError {
    let message = format!("{err}");
    ProtocolError::Internal(message)
}
/// Builds a successful acknowledgement carrying `detail`.
fn ack(detail: impl Into<String>) -> ProtocolMessage {
    let detail: String = detail.into();
    ProtocolMessage::Ack { detail: Some(detail) }
}
fn parse_visibility(token: &str) -> Visibility {
token.parse().unwrap_or(Visibility::Public)
}
/// Whether an RFC 3339 timestamp lies at or before the current UTC time.
///
/// An unparseable timestamp counts as expired (fail closed).
fn is_expired(rfc3339: &str) -> bool {
    match chrono::DateTime::parse_from_rfc3339(rfc3339) {
        Ok(deadline) => deadline.with_timezone(&chrono::Utc) <= chrono::Utc::now(),
        Err(_) => true,
    }
}
/// Converts a lifetime of `secs` seconds into an RFC 3339 expiry timestamp from now.
///
/// # Errors
/// `ProtocolError::MalformedFrame` when the lifetime does not fit an `i64`, a
/// `TimeDelta`, or the representable date range.
fn invite_expiry(secs: u64) -> Result<String, ProtocolError> {
    let expiry = i64::try_from(secs)
        .ok()
        .and_then(chrono::TimeDelta::try_seconds)
        .and_then(|delta| chrono::Utc::now().checked_add_signed(delta));
    expiry
        .map(|at| at.to_rfc3339())
        .ok_or_else(|| ProtocolError::MalformedFrame("invite expiry is too far in the future".to_owned()))
}
// Unit tests for the hub: tracing hygiene, slow-consumer fan-out, ACL/ban/invite
// gating, user removal, and invite redemption under concurrency.
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use super::*;
use crate::store::Store;
// Builds a hub whose in-memory store holds private channel "ops" (created by
// "aaron") and one invite "tok" with the given use limit and expiry.
async fn hub_with_private_channel(token_uses: Option<i64>, expires_at: Option<String>) -> Arc<Hub> {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
store.create_invite("ops", "tok", token_uses, expires_at, "aaron").await.unwrap();
Hub::new(store, HashMap::new()).await.unwrap()
}
// Attaches a session at "<user>/machine/session". The `_rx` binding is dropped
// when this helper returns, so the session's outbound queue is closed.
fn attach_session(hub: &Arc<Hub>, user: &str) -> SessionPath {
let path = SessionPath::new(user, "machine", "session");
let (tx, _rx) = mpsc::channel(super::super::session::OUTBOUND_CAPACITY);
hub.attach(&path, user, "machine", tx);
path
}
// Shared in-memory writer used to capture tracing output.
#[derive(Clone, Default)]
struct Buf(Arc<Mutex<Vec<u8>>>);
impl std::io::Write for Buf {
fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
self.0.lock().unwrap().extend_from_slice(data);
Ok(data.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
impl tracing_subscriber::fmt::MakeWriter<'_> for Buf {
type Writer = Buf;
fn make_writer(&self) -> Buf {
self.clone()
}
}
// join/post/read_since must emit spans carrying the path and channel, and
// payload text must never reach telemetry.
#[tokio::test]
async fn hub_request_paths_emit_spans_with_path_and_channel_fields() {
let buf = Buf::default();
let subscriber = tracing_subscriber::fmt()
.with_writer(buf.clone())
.with_ansi(false)
.with_max_level(tracing::Level::DEBUG)
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::NEW)
.finish();
let _guard = tracing::subscriber::set_default(subscriber);
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Public, "aaron").await.unwrap();
let hub = Hub::new(store, HashMap::new()).await.unwrap();
let aaron = attach_session(&hub, "aaron");
hub.join("aaron", &aaron, "ops", None).await.unwrap();
hub.post(&aaron, "ops", Payload::Plain("observable".to_owned())).await.unwrap();
hub.read_since(&aaron, "ops", 0).await.unwrap();
let output = String::from_utf8(buf.0.lock().unwrap().clone()).unwrap();
assert!(
output.contains("session established") && output.contains("aaron/machine/session"),
"attach must log establishment with the path: {output}"
);
for span in ["join", "post", "read_since"] {
assert!(output.contains(span), "the `{span}` path must be instrumented: {output}");
}
assert!(output.contains("ops"), "spans must carry the channel: {output}");
assert!(!output.contains("observable"), "payload bodies must never reach telemetry: {output}");
}
// A recipient whose capacity-1 queue is never drained must have its kill
// switch fired once a second message overflows it.
#[tokio::test]
async fn fanout_drops_a_slow_consumer_instead_of_growing_its_queue() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Public, "aaron").await.unwrap();
let hub = Hub::new(store, HashMap::new()).await.unwrap();
let aaron = attach_session(&hub, "aaron");
hub.join("aaron", &aaron, "ops", None).await.unwrap();
let bob = SessionPath::new("bob", "machine", "session");
let (b_tx, _b_rx) = mpsc::channel(1);
let b_kill = hub.attach(&bob, "bob", "machine", b_tx);
hub.join("bob", &bob, "ops", None).await.unwrap();
hub.post(&aaron, "ops", Payload::Plain("one".to_owned())).await.unwrap();
hub.post(&aaron, "ops", Payload::Plain("two".to_owned())).await.unwrap();
assert!(
tokio::time::timeout(Duration::from_millis(500), b_kill.notified()).await.is_ok(),
"a consumer that fills its bounded queue must be force-dropped",
);
}
// For a non-admin caller, InviteRevoke acks identically for existing and
// absent tokens, and leaves the existing token usable.
#[tokio::test]
async fn invite_revoke_gives_a_uniform_ack_and_no_delete_for_non_admins() {
let hub = hub_with_private_channel(None, None).await;
let existing = hub.admin("mallory", AdminOp::InviteRevoke { token: "tok".to_owned() }).await.unwrap();
let absent = hub.admin("mallory", AdminOp::InviteRevoke { token: "ghost".to_owned() }).await.unwrap();
assert!(
matches!(existing, ProtocolMessage::Ack { .. }),
"revoking an existing token as a non-admin must ack, not error: {existing:?}"
);
assert!(matches!(absent, ProtocolMessage::Ack { .. }), "revoking an absent token must ack identically: {absent:?}");
let carol = attach_session(&hub, "carol");
assert!(hub.join("carol", &carol, "ops", Some("tok")).await.is_ok(), "a non-admin revoke must not delete the token");
}
// A server admin ("root") sees private channels; an unrelated user does not.
#[tokio::test]
async fn list_channels_shows_a_server_admin_everything() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("secret", Visibility::Private, "alice").await.unwrap();
let hub = Hub::new(store, HashMap::from([("root".to_owned(), None)])).await.unwrap();
let admin_view: Vec<String> = hub.list_channels("root").await.unwrap().into_iter().map(|c| c.name).collect();
assert!(admin_view.contains(&"secret".to_owned()), "a server admin must see private channels: {admin_view:?}");
let bob_view: Vec<String> = hub.list_channels("bob").await.unwrap().into_iter().map(|c| c.name).collect();
assert!(!bob_view.contains(&"secret".to_owned()), "a private channel must stay hidden from non-members: {bob_view:?}");
}
// AclList answers the channel creator (the expected list includes the
// creator plus the added member) and errors for anyone else.
#[tokio::test]
async fn acl_list_returns_members_to_channel_admins_only() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
let hub = Hub::new(store, HashMap::new()).await.unwrap();
hub.admin(
"aaron",
AdminOp::AclAdd {
channel: "ops".to_owned(),
user: "david".to_owned(),
},
)
.await
.unwrap();
match hub.admin("aaron", AdminOp::AclList { channel: "ops".to_owned() }).await.unwrap() {
ProtocolMessage::UserList { mut users } => {
users.sort();
assert_eq!(users, vec!["aaron".to_owned(), "david".to_owned()]);
}
other => panic!("expected a UserList, got {other:?}"),
}
assert!(
hub.admin("mallory", AdminOp::AclList { channel: "ops".to_owned() }).await.is_err(),
"a non-admin must not list a channel's members",
);
}
// InviteList shows outstanding tokens and their remaining uses to the
// channel creator only.
#[tokio::test]
async fn invite_list_shows_outstanding_tokens_to_channel_admins_only() {
let hub = hub_with_private_channel(Some(2), None).await;
match hub.admin("aaron", AdminOp::InviteList { channel: "ops".to_owned() }).await.unwrap() {
ProtocolMessage::InviteList { invites } => {
assert_eq!(invites.len(), 1);
assert_eq!(invites[0].token, "tok");
assert_eq!(invites[0].uses_remaining, Some(2));
}
other => panic!("expected an InviteList, got {other:?}"),
}
assert!(hub.admin("mallory", AdminOp::InviteList { channel: "ops".to_owned() }).await.is_err());
}
// Ban, BanList and Unban are gated on channel admin. Unban is durable, does
// not grant ACL membership, and lets the user rejoin.
#[tokio::test]
async fn ban_visibility_list_and_unban_are_channel_admin_gated() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Public, "aaron").await.unwrap();
let hub = Hub::new(store.clone(), HashMap::new()).await.unwrap();
hub.admin(
"aaron",
AdminOp::Ban {
channel: "ops".to_owned(),
user: "bob".to_owned(),
},
)
.await
.unwrap();
match hub.admin("aaron", AdminOp::BanList { channel: "ops".to_owned() }).await.unwrap() {
ProtocolMessage::UserList { users } => assert_eq!(users, vec!["bob".to_owned()]),
other => panic!("expected a UserList of bans, got {other:?}"),
}
assert!(hub.admin("mallory", AdminOp::BanList { channel: "ops".to_owned() }).await.is_err());
assert!(
hub.admin(
"mallory",
AdminOp::Unban {
channel: "ops".to_owned(),
user: "bob".to_owned(),
}
)
.await
.is_err()
);
hub.admin(
"aaron",
AdminOp::Unban {
channel: "ops".to_owned(),
user: "bob".to_owned(),
},
)
.await
.unwrap();
match hub.admin("aaron", AdminOp::BanList { channel: "ops".to_owned() }).await.unwrap() {
ProtocolMessage::UserList { users } => assert!(users.is_empty(), "unban must lift the ban: {users:?}"),
other => panic!("expected a UserList, got {other:?}"),
}
assert!(store.list_bans().await.unwrap().is_empty(), "unban must be durable");
assert!(!store.is_channel_member("ops", "bob").await.unwrap(), "unban must not grant ACL membership");
let bob = attach_session(&hub, "bob");
assert!(hub.join("bob", &bob, "ops", None).await.is_ok(), "an unbanned user may rejoin");
}
// A ban written through one Hub is reloaded by a new Hub over the same store.
#[tokio::test]
async fn bans_survive_a_server_restart() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Public, "aaron").await.unwrap();
let hub = Hub::new(store.clone(), HashMap::new()).await.unwrap();
let ack = hub
.admin(
"aaron",
AdminOp::Ban {
channel: "ops".to_owned(),
user: "bob".to_owned(),
},
)
.await
.unwrap();
assert!(matches!(ack, ProtocolMessage::Ack { .. }));
drop(hub);
let hub = Hub::new(store, HashMap::new()).await.unwrap();
let bob = attach_session(&hub, "bob");
assert!(hub.join("bob", &bob, "ops", None).await.is_err(), "a persisted ban must survive a server restart");
}
// `subscribe` itself refuses a banned user, independently of `join`'s earlier check.
#[tokio::test]
async fn subscribe_re_checks_the_ban_atomically() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Public, "aaron").await.unwrap();
let hub = Hub::new(store, HashMap::new()).await.unwrap();
let bob = attach_session(&hub, "bob");
hub.add_ban("ops", "bob").await.unwrap();
assert!(!hub.subscribe(&bob, "ops"), "subscribe must refuse a banned user");
assert!(!hub.subscribers("ops").contains(&bob), "a banned user must not end up subscribed");
}
// A single-use invite admits the first redeemer and refuses the second.
#[tokio::test]
async fn invite_single_use_is_consumed_after_one_redeem() {
let hub = hub_with_private_channel(Some(1), None).await;
let david = attach_session(&hub, "david");
hub.join("david", &david, "ops", Some("tok")).await.unwrap();
assert!(hub.subscribers("ops").contains(&david), "redeeming a valid invite must subscribe + add to the ACL");
let carol = attach_session(&hub, "carol");
assert!(hub.join("carol", &carol, "ops", Some("tok")).await.is_err(), "a spent single-use invite must be refused");
}
// A two-use invite admits two redeemers and refuses the third.
#[tokio::test]
async fn invite_multi_use_allows_several_then_exhausts() {
let hub = hub_with_private_channel(Some(2), None).await;
for user in ["david", "carol"] {
let path = attach_session(&hub, user);
hub.join(user, &path, "ops", Some("tok")).await.unwrap();
}
let evan = attach_session(&hub, "evan");
assert!(hub.join("evan", &evan, "ops", Some("tok")).await.is_err(), "an exhausted invite must be refused");
}
// An invite whose expiry lies in the past is refused.
#[tokio::test]
async fn invite_expiry_refuses_an_expired_token() {
let hub = hub_with_private_channel(None, Some("2000-01-01T00:00:00+00:00".to_owned())).await;
let david = attach_session(&hub, "david");
assert!(hub.join("david", &david, "ops", Some("tok")).await.is_err(), "an expired token must be refused");
}
// An invite revoked by the channel creator can no longer be redeemed.
#[tokio::test]
async fn invite_revoked_token_is_refused() {
let hub = hub_with_private_channel(None, None).await;
hub.admin("aaron", AdminOp::InviteRevoke { token: "tok".to_owned() }).await.unwrap();
let david = attach_session(&hub, "david");
assert!(hub.join("david", &david, "ops", Some("tok")).await.is_err(), "a revoked token must be refused");
}
// A token that does not exist ("nope") is refused.
#[tokio::test]
async fn invite_wrong_channel_token_is_refused() {
let hub = hub_with_private_channel(None, None).await;
let david = attach_session(&hub, "david");
assert!(hub.join("david", &david, "ops", Some("nope")).await.is_err(), "an unknown token must be refused");
}
// 16 concurrent redeemers of a single-use invite: exactly one may get in.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn invite_single_use_admits_exactly_one_under_concurrent_redeem() {
const RACERS: usize = 16;
let hub = hub_with_private_channel(Some(1), None).await;
let mut tasks = Vec::with_capacity(RACERS);
for i in 0..RACERS {
let user = format!("user{i}");
let path = attach_session(&hub, &user);
let hub = Arc::clone(&hub);
tasks.push(tokio::spawn(async move { hub.join(&user, &path, "ops", Some("tok")).await.is_ok() }));
}
let mut admitted = 0;
for task in tasks {
if task.await.unwrap() {
admitted += 1;
}
}
assert_eq!(admitted, 1, "a single-use invite must admit exactly one redeemer under concurrency, admitted {admitted}");
}
// Concurrent joins through an unlimited invite must each leave an ACL entry
// (JOINERS plus one more member in the expected count).
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_joins_on_unlimited_invite_preserve_every_acl_entry() {
const JOINERS: usize = 12;
let hub = hub_with_private_channel(None, None).await;
let mut tasks = Vec::with_capacity(JOINERS);
for i in 0..JOINERS {
let user = format!("user{i}");
let path = attach_session(&hub, &user);
let hub = Arc::clone(&hub);
tasks.push(tokio::spawn(async move { hub.join(&user, &path, "ops", Some("tok")).await }));
}
for task in tasks {
task.await.unwrap().unwrap();
}
let members = hub.store.list_channel_members("ops").await.unwrap();
assert_eq!(members.len(), JOINERS + 1, "concurrent joins must not lose a membership, got {members:?}");
}
// UserRemove by a server admin deletes the user's created channels and their
// memberships elsewhere.
#[tokio::test]
async fn user_remove_purges_memberships_and_created_channels() {
let store = Store::open_in_memory().await.unwrap();
store.create_user("victim").await.unwrap();
store.create_channel("victim-ops", Visibility::Private, "victim").await.unwrap();
store.create_channel("lobby", Visibility::Public, "aaron").await.unwrap();
store.add_channel_member("lobby", "victim").await.unwrap();
let hub = Hub::new(store, HashMap::from([("root".to_owned(), None)])).await.unwrap();
hub.admin("root", AdminOp::UserRemove { username: "victim".to_owned() }).await.unwrap();
assert!(hub.store.get_channel("victim-ops").await.unwrap().is_none(), "a removed user's created channels must be deleted");
assert!(!hub.store.is_channel_member("lobby", "victim").await.unwrap(), "a removed user's memberships must be purged");
}
// An expiry of u64::MAX seconds must produce an error, not a panic.
#[tokio::test]
async fn invite_create_with_absurd_expiry_errors_instead_of_panicking() {
let store = Store::open_in_memory().await.unwrap();
store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
let hub = Hub::new(store, HashMap::new()).await.unwrap();
let result = hub
.admin(
"aaron",
AdminOp::InviteCreate {
channel: "ops".to_owned(),
uses: None,
expires_in_secs: Some(u64::MAX),
},
)
.await;
assert!(result.is_err(), "an absurd expiry must return an error, got {result:?}");
}
}