use std::{
collections::{HashMap, HashSet},
fmt::Display,
sync::{Arc, Mutex, MutexGuard},
time::Duration,
};
use tokio::{
sync::{Notify, mpsc},
time::Instant,
};
use crate::{
base::{SessionPath, Visibility},
identity::{self, AuthError},
protocol::{AdminOp, ChannelInfo, MachineInfo, Payload, ProtocolError, ProtocolMessage},
store::{ChannelRecord, Store},
};
use super::AclError;
type Outbound = mpsc::UnboundedSender<ProtocolMessage>;
type Reply = Result<ProtocolMessage, ProtocolError>;
struct SessionEntry {
user: String,
machine: String,
outbound: Outbound,
kill: Arc<Notify>,
channels: HashSet<String>,
last_seen: Instant,
}
/// Mutable hub bookkeeping, guarded by the mutex inside `Hub`.
#[derive(Default)]
struct HubState {
    /// Attached sessions, keyed by their path.
    sessions: std::collections::HashMap<SessionPath, SessionEntry>,
    /// Channel name -> paths of the sessions subscribed to it. Kept in step with each
    /// session's `channels` set; a channel whose set becomes empty is removed.
    subscriptions: std::collections::HashMap<String, std::collections::HashSet<SessionPath>>,
    /// Channel name -> banned usernames. Only held in memory by this module.
    bans: std::collections::HashMap<String, std::collections::HashSet<String>>,
}
/// Central router: persistent data lives in `store`, live sessions and subscriptions in
/// the mutex-guarded `state`.
pub(crate) struct Hub {
    /// Persistent users, machines, channels and invites.
    store: Store,
    /// Usernames with global administrator rights.
    admins: std::collections::HashSet<String>,
    /// In-memory session, subscription and ban tables.
    state: std::sync::Mutex<HubState>,
}
impl Hub {
    /// Creates a hub backed by `store`.
    ///
    /// Every username in `admins` is treated as a global administrator. The hub starts
    /// with no attached sessions, subscriptions or bans.
    pub(crate) fn new(store: Store, admins: HashSet<String>) -> Arc<Self> {
        Arc::new(Self {
            store,
            admins,
            state: Mutex::new(HubState::default()),
        })
    }

    /// Locks the in-memory session/subscription/ban state.
    ///
    /// This is a `std::sync::Mutex`, so the guard must be dropped before any `.await`;
    /// every caller in this impl only holds it across synchronous code.
    ///
    /// # Panics
    ///
    /// Panics if the mutex was poisoned by a thread that panicked while holding it.
    fn state(&self) -> MutexGuard<'_, HubState> {
        self.state.lock().expect("hub state mutex poisoned")
    }

    /// Returns `true` when `user` is one of the global administrators given to `new`.
    pub(crate) fn is_admin(&self, user: &str) -> bool {
        self.admins.contains(user)
    }

    /// Lists the machines enrolled under `user`.
    ///
    /// # Errors
    ///
    /// Store failures are reported as `ProtocolError::Internal`.
    pub(crate) async fn list_machines(&self, user: &str) -> Result<Vec<MachineInfo>, ProtocolError> {
        let machines = self.store.list_machines(user).await.map_err(internal)?;
        Ok(machines
            .into_iter()
            .map(|m| MachineInfo {
                name: m.name,
                pubkey: m.pubkey,
                added_at: m.added_at,
            })
            .collect())
    }

    /// Lists every registered username. Only global administrators may call this.
    ///
    /// # Errors
    ///
    /// `AclError::NotAdmin` if `caller` is not a global admin; store failures as
    /// `ProtocolError::Internal`.
    pub(crate) async fn list_users(&self, caller: &str) -> Result<Vec<String>, ProtocolError> {
        if !self.is_admin(caller) {
            return Err(AclError::NotAdmin.into());
        }
        let users = self.store.list_users().await.map_err(internal)?;
        Ok(users.into_iter().map(|u| u.username).collect())
    }

    /// Creates the account `username` together with its first machine `machine`,
    /// enrolled under `pubkey`.
    ///
    /// If the machine row cannot be created, the freshly created user row is removed
    /// again (best effort) so the name is not left reserved by an account that could
    /// never authenticate.
    ///
    /// # Errors
    ///
    /// * `AuthError::UsernameTaken` if `username` already exists.
    /// * `AuthError::Malformed` if `pubkey` is already enrolled on some machine.
    /// * `ProtocolError::Internal` on store failure.
    pub(crate) async fn register(&self, username: &str, machine: &str, pubkey: &[u8]) -> Result<(), ProtocolError> {
        // NOTE(review): check-then-insert; concurrent registrations rely on the store's
        // own uniqueness constraints to reject the loser.
        if self.store.get_user(username).await.map_err(internal)?.is_some() {
            return Err(AuthError::UsernameTaken(username.to_owned()).into());
        }
        let pubkey_b64 = identity::encode_key(pubkey);
        if self.store.get_machine_by_pubkey(&pubkey_b64).await.map_err(internal)?.is_some() {
            return Err(AuthError::Malformed("public key is already enrolled".to_owned()).into());
        }
        self.store.create_user(username).await.map_err(internal)?;
        if let Err(err) = self.store.create_machine(username, machine, &pubkey_b64).await {
            // Roll back the user row; a failure here is ignored on purpose because the
            // original error is the one the caller needs to see.
            let _ = self.store.delete_user(username).await;
            return Err(internal(err));
        }
        Ok(())
    }

    /// Maps a public key to the `(user, machine)` pair it is enrolled under.
    ///
    /// # Errors
    ///
    /// `AuthError::UnknownKey` if no machine uses `pubkey`; store failures as
    /// `ProtocolError::Internal`.
    pub(crate) async fn resolve(&self, pubkey: &[u8]) -> Result<(String, String), ProtocolError> {
        let pubkey_b64 = identity::encode_key(pubkey);
        let machine = self
            .store
            .get_machine_by_pubkey(&pubkey_b64)
            .await
            .map_err(internal)?
            .ok_or_else(|| ProtocolError::from(AuthError::UnknownKey))?;
        Ok((machine.user, machine.name))
    }

    /// Registers a live session at `path`, owned by `user` on `machine`, whose frames
    /// are delivered through `outbound`.
    ///
    /// Returns the session's kill handle. The hub notifies it when it forcibly drops
    /// the session (idle reaping, user removal, machine removal).
    ///
    /// # Errors
    ///
    /// `AuthError::HandleCollision` if a session is already attached at `path`.
    pub(crate) fn attach(&self, path: &SessionPath, user: &str, machine: &str, outbound: Outbound) -> Result<Arc<Notify>, ProtocolError> {
        let mut st = self.state();
        if st.sessions.contains_key(path) {
            return Err(AuthError::HandleCollision(path.session.clone()).into());
        }
        let kill = Arc::new(Notify::new());
        st.sessions.insert(
            path.clone(),
            SessionEntry {
                user: user.to_owned(),
                machine: machine.to_owned(),
                outbound,
                kill: Arc::clone(&kill),
                channels: HashSet::new(),
                last_seen: Instant::now(),
            },
        );
        Ok(kill)
    }

    /// Removes the session at `path` and all of its subscriptions, without notifying
    /// its kill handle. Does nothing if no such session is attached.
    pub(crate) fn detach(&self, path: &SessionPath) {
        let mut st = self.state();
        Self::take_session(&mut st, path);
    }

    /// Records activity on the session at `path`, resetting its idle timer.
    /// Does nothing if no such session is attached.
    pub(crate) fn touch(&self, path: &SessionPath) {
        if let Some(entry) = self.state().sessions.get_mut(path) {
            entry.last_seen = Instant::now();
        }
    }

    /// Drops every session whose last activity is at least `timeout` ago, notifying
    /// each one's kill handle. Returns the number of sessions reaped.
    pub(crate) fn reap_idle(&self, timeout: Duration) -> usize {
        let now = Instant::now();
        let mut st = self.state();
        let stale: Vec<SessionPath> = st
            .sessions
            .iter()
            .filter(|(_, e)| now.saturating_duration_since(e.last_seen) >= timeout)
            .map(|(p, _)| p.clone())
            .collect();
        for path in &stale {
            Self::kill_locked(&mut st, path);
        }
        stale.len()
    }

    /// Subscribes the session at `path`, owned by `user`, to `channel`.
    ///
    /// Banned users are refused. Public and unlisted channels admit anyone. A private
    /// channel admits users already on its ACL. Anyone else must present an invite
    /// `token`, which is redeemed and the user appended to the ACL. If `path` is not
    /// attached, no subscription is recorded but the call still succeeds.
    ///
    /// # Errors
    ///
    /// * `ProtocolError::NotFound` for an unknown channel.
    /// * `AclError::ChannelPrivate` when banned, or when a required invite is missing,
    ///   unknown, expired or spent.
    /// * `ProtocolError::Internal` on store failure.
    pub(crate) async fn join(&self, user: &str, path: &SessionPath, channel: &str, token: Option<&str>) -> Result<(), ProtocolError> {
        let mut record = self
            .store
            .get_channel(channel)
            .await
            .map_err(internal)?
            .ok_or_else(|| ProtocolError::NotFound(format!("channel `{channel}`")))?;
        if self.is_banned(channel, user) {
            return Err(AclError::ChannelPrivate(channel.to_owned()).into());
        }
        let already_member = record.acl.iter().any(|u| u == user);
        if !already_member {
            match parse_visibility(&record.visibility) {
                Visibility::Public | Visibility::Unlisted => {}
                Visibility::Private => {
                    let token = token.ok_or_else(|| ProtocolError::from(AclError::ChannelPrivate(channel.to_owned())))?;
                    self.redeem_invite(channel, token).await?;
                    // NOTE(review): read-modify-write of the whole ACL; two concurrent
                    // joins can each write back a list missing the other's entry.
                    record.acl.push(user.to_owned());
                    self.store.set_channel_acl(channel, &record.acl).await.map_err(internal)?;
                }
            }
        }
        self.subscribe(path, channel);
        Ok(())
    }

    /// Unsubscribes the session at `path` from `channel`. Does nothing if it was not
    /// subscribed.
    pub(crate) fn leave(&self, path: &SessionPath, channel: &str) {
        self.unsubscribe_session(path, channel);
    }

    /// Lists the channels visible to `user`: every public channel, plus any channel
    /// whose ACL contains `user`. Each entry reports whether `user` is on the ACL.
    ///
    /// # Errors
    ///
    /// Store failures are reported as `ProtocolError::Internal`.
    pub(crate) async fn list_channels(&self, user: &str) -> Result<Vec<ChannelInfo>, ProtocolError> {
        let channels = self.store.list_channels().await.map_err(internal)?;
        let infos = channels
            .into_iter()
            .filter_map(|c| {
                let visibility = parse_visibility(&c.visibility);
                let member = c.acl.iter().any(|u| u == user);
                let visible = matches!(visibility, Visibility::Public) || member;
                visible.then_some(ChannelInfo { name: c.name, visibility, member })
            })
            .collect();
        Ok(infos)
    }

    /// Lists session paths.
    ///
    /// With no `channel`, returns every attached session. Otherwise returns the
    /// sessions subscribed to `channel`, which requires `user` to be on the channel's
    /// ACL or to be a global admin.
    ///
    /// # Errors
    ///
    /// `AclError::ChannelNotFound` for an unknown channel; `AclError::NotMember` when
    /// the caller lacks access; store failures as `ProtocolError::Internal`.
    pub(crate) async fn who(&self, user: &str, channel: Option<&str>) -> Result<Vec<SessionPath>, ProtocolError> {
        let Some(channel) = channel else {
            return Ok(self.present_paths());
        };
        let record = self
            .store
            .get_channel(channel)
            .await
            .map_err(internal)?
            .ok_or_else(|| ProtocolError::from(AclError::ChannelNotFound(channel.to_owned())))?;
        if !record.acl.iter().any(|u| u == user) && !self.is_admin(user) {
            return Err(AclError::NotMember(channel.to_owned()).into());
        }
        let st = self.state();
        Ok(st.subscriptions.get(channel).map(|subs| subs.iter().cloned().collect()).unwrap_or_default())
    }

    /// Broadcasts `payload` from `from` to every other session subscribed to `channel`.
    ///
    /// Recipients are collected under the lock and messaged after it is released.
    /// Delivery is best effort: a send to a closed queue is ignored.
    ///
    /// # Errors
    ///
    /// `AclError::NotMember` if `from` is not subscribed to `channel`.
    pub(crate) fn post(&self, from: &SessionPath, channel: &str, payload: Payload) -> Result<(), ProtocolError> {
        let targets: Vec<Outbound> = {
            let st = self.state();
            let subs = st
                .subscriptions
                .get(channel)
                .ok_or_else(|| ProtocolError::from(AclError::NotMember(channel.to_owned())))?;
            if !subs.contains(from) {
                return Err(AclError::NotMember(channel.to_owned()).into());
            }
            subs.iter()
                .filter(|p| *p != from)
                .filter_map(|p| st.sessions.get(p).map(|e| e.outbound.clone()))
                .collect()
        };
        let msg = ProtocolMessage::ChannelMsg {
            channel: channel.to_owned(),
            from: from.clone(),
            payload,
        };
        for tx in targets {
            let _ = tx.send(msg.clone());
        }
        Ok(())
    }

    /// Sends `payload` from `from` directly to the session at `target`.
    ///
    /// Delivery is best effort: a send to a closed queue is ignored.
    ///
    /// # Errors
    ///
    /// `ProtocolError::NotFound` if no session is attached at `target`.
    pub(crate) fn whisper(&self, from: &SessionPath, target: &SessionPath, payload: Payload) -> Result<(), ProtocolError> {
        let outbound = self.state().sessions.get(target).map(|e| e.outbound.clone());
        let Some(outbound) = outbound else {
            return Err(ProtocolError::NotFound(format!("session `{target}` is not online")));
        };
        let msg = ProtocolMessage::Whisper {
            from: from.clone(),
            target: target.clone(),
            payload,
        };
        let _ = outbound.send(msg);
        Ok(())
    }

    /// Executes administrative operation `op` on behalf of `user`.
    ///
    /// Authorization depends on the operation:
    /// * Channel-scoped operations (delete, rename, visibility, ACL edits, invites,
    ///   kick, ban) require `user` to have created the channel or be a global admin.
    /// * `UserRemove` requires a global admin.
    /// * `CreateChannel` is open to any user.
    /// * `MachineAdd` and `MachineRemove` only touch `user`'s own machines.
    ///
    /// Returns an `Ack` reply, or an `InviteToken` reply for `InviteCreate`.
    ///
    /// # Errors
    ///
    /// `AclError::*` / `AuthError::*` for authorization and validation failures; store
    /// failures as `ProtocolError::Internal`.
    pub(crate) async fn admin(&self, user: &str, op: AdminOp) -> Reply {
        match op {
            AdminOp::CreateChannel { name, visibility } => {
                self.store.create_channel(&name, visibility, user).await.map_err(internal)?;
                Ok(ack(name))
            }
            AdminOp::DeleteChannel { name } => {
                self.authorize_channel_admin(&name, user).await?;
                self.store.delete_channel(&name).await.map_err(internal)?;
                self.drop_channel(&name);
                Ok(ack(name))
            }
            AdminOp::RenameChannel { name, new_name } => {
                self.authorize_channel_admin(&name, user).await?;
                self.store.rename_channel(&name, &new_name).await.map_err(internal)?;
                self.rename_channel_subscriptions(&name, &new_name);
                Ok(ack(new_name))
            }
            AdminOp::SetVisibility { name, visibility } => {
                self.authorize_channel_admin(&name, user).await?;
                self.store.set_channel_visibility(&name, visibility).await.map_err(internal)?;
                Ok(ack(name))
            }
            AdminOp::AclAdd { channel, user: target } => {
                let mut record = self.authorize_channel_admin(&channel, user).await?;
                if !record.acl.iter().any(|u| u == &target) {
                    record.acl.push(target.clone());
                    self.store.set_channel_acl(&channel, &record.acl).await.map_err(internal)?;
                }
                // Re-admitting a user lifts any ban on them.
                self.remove_ban(&channel, &target);
                Ok(ack(target))
            }
            AdminOp::AclRemove { channel, user: target } => {
                let mut record = self.authorize_channel_admin(&channel, user).await?;
                record.acl.retain(|u| u != &target);
                self.store.set_channel_acl(&channel, &record.acl).await.map_err(internal)?;
                self.unsubscribe_user(&target, &channel);
                Ok(ack(target))
            }
            AdminOp::InviteCreate { channel, uses, expires_in_secs } => {
                // Largest magnitude `chrono::Duration::seconds` accepts without panicking.
                const MAX_DELTA_SECS: i64 = i64::MAX / 1000;
                self.authorize_channel_admin(&channel, user).await?;
                let token = identity::generate_token().map_err(internal)?;
                // The lifetime arrives in the client's AdminOp. Clamp it into the range
                // chrono accepts and add with overflow checking, so an absurdly long
                // lifetime means "no expiry" instead of a panic.
                let expires_at = expires_in_secs.and_then(|secs| {
                    let secs = i64::try_from(secs).unwrap_or(i64::MAX).clamp(-MAX_DELTA_SECS, MAX_DELTA_SECS);
                    chrono::Utc::now()
                        .checked_add_signed(chrono::Duration::seconds(secs))
                        .map(|at| at.to_rfc3339())
                });
                self.store.create_invite(&channel, &token, uses.map(i64::from), expires_at, user).await.map_err(internal)?;
                Ok(ProtocolMessage::InviteToken { token })
            }
            AdminOp::InviteRevoke { token } => {
                // Revoking an unknown token is acknowledged as a no-op.
                if let Some(invite) = self.store.get_invite(&token).await.map_err(internal)? {
                    self.authorize_channel_admin(&invite.channel, user).await?;
                    self.store.delete_invite(&token).await.map_err(internal)?;
                }
                Ok(ack(token))
            }
            AdminOp::Kick { channel, target } => {
                self.authorize_channel_admin(&channel, user).await?;
                // A parseable session path kicks that one session; anything else is
                // treated as a username and kicks all of that user's sessions.
                if let Ok(path) = target.parse::<SessionPath>() {
                    self.unsubscribe_session(&path, &channel);
                } else {
                    self.unsubscribe_user(&target, &channel);
                }
                Ok(ack(target))
            }
            AdminOp::Ban { channel, user: target } => {
                let mut record = self.authorize_channel_admin(&channel, user).await?;
                record.acl.retain(|u| u != &target);
                self.store.set_channel_acl(&channel, &record.acl).await.map_err(internal)?;
                self.add_ban(&channel, &target);
                self.unsubscribe_user(&target, &channel);
                Ok(ack(target))
            }
            AdminOp::UserRemove { username } => {
                if !self.is_admin(user) {
                    return Err(AclError::NotAdmin.into());
                }
                // NOTE(review): the removed name stays in channel ACLs, ban lists and
                // `created_by`. An account later re-registered under the same name would
                // inherit those rights; consider scrubbing them here.
                for machine in self.store.list_machines(&username).await.map_err(internal)? {
                    self.store.delete_machine(&username, &machine.name).await.map_err(internal)?;
                }
                self.store.delete_user(&username).await.map_err(internal)?;
                self.force_drop_user(&username);
                Ok(ack(username))
            }
            AdminOp::MachineRemove { name } => {
                self.store.delete_machine(user, &name).await.map_err(internal)?;
                self.force_drop_machine(user, &name);
                Ok(ack(name))
            }
            AdminOp::MachineAdd { name, pubkey } => {
                let pubkey_b64 = identity::encode_key(&pubkey);
                // Same uniqueness rule as `register`: a key enrolled twice would make
                // `resolve` ambiguous about which account it authenticates.
                if self.store.get_machine_by_pubkey(&pubkey_b64).await.map_err(internal)?.is_some() {
                    return Err(AuthError::Malformed("public key is already enrolled".to_owned()).into());
                }
                self.store.create_machine(user, &name, &pubkey_b64).await.map_err(internal)?;
                Ok(ack(name))
            }
        }
    }

    /// Loads `channel` and checks that `user` created it or is a global admin.
    /// Returns the record so callers can modify and write it back.
    ///
    /// # Errors
    ///
    /// `AclError::ChannelNotFound` for an unknown channel; `AclError::NotAdmin` when
    /// `user` lacks rights; store failures as `ProtocolError::Internal`.
    async fn authorize_channel_admin(&self, channel: &str, user: &str) -> Result<ChannelRecord, ProtocolError> {
        let record = self
            .store
            .get_channel(channel)
            .await
            .map_err(internal)?
            .ok_or_else(|| ProtocolError::from(AclError::ChannelNotFound(channel.to_owned())))?;
        if record.created_by == user || self.is_admin(user) {
            Ok(record)
        } else {
            Err(AclError::NotAdmin.into())
        }
    }

    /// Consumes one use of invite `token` for `channel`.
    ///
    /// Expired invites and invites with no uses left are deleted and refused. An
    /// invite on its last use is deleted after being accepted. An invite without a use
    /// limit is left untouched.
    ///
    /// # Errors
    ///
    /// `AclError::ChannelPrivate` if the token is unknown, belongs to another channel,
    /// has expired or is spent; store failures as `ProtocolError::Internal`.
    async fn redeem_invite(&self, channel: &str, token: &str) -> Result<(), ProtocolError> {
        let invite = self
            .store
            .get_invite(token)
            .await
            .map_err(internal)?
            .filter(|inv| inv.channel == channel)
            .ok_or_else(|| ProtocolError::from(AclError::ChannelPrivate(channel.to_owned())))?;
        if invite.expires_at.as_deref().is_some_and(is_expired) {
            self.store.delete_invite(token).await.map_err(internal)?;
            return Err(AclError::ChannelPrivate(channel.to_owned()).into());
        }
        // NOTE(review): the get/decrement pair is not atomic; concurrent redemptions
        // of the last use can both succeed.
        match invite.uses_remaining {
            // A zero (or negative) budget grants nothing. Previously this fell into
            // the "last use" arm and admitted one joiner.
            Some(remaining) if remaining <= 0 => {
                self.store.delete_invite(token).await.map_err(internal)?;
                return Err(AclError::ChannelPrivate(channel.to_owned()).into());
            }
            Some(1) => self.store.delete_invite(token).await.map_err(internal)?,
            Some(remaining) => self.store.set_invite_uses(token, remaining - 1).await.map_err(internal)?,
            None => {}
        }
        Ok(())
    }

    /// Records `path` as subscribed to `channel` in both the session's channel set and
    /// the channel index. Ignored if no session is attached at `path`.
    fn subscribe(&self, path: &SessionPath, channel: &str) {
        let mut st = self.state();
        let Some(entry) = st.sessions.get_mut(path) else {
            return;
        };
        entry.channels.insert(channel.to_owned());
        st.subscriptions.entry(channel.to_owned()).or_default().insert(path.clone());
    }

    /// Removes the subscription of `path` to `channel` from both indices, dropping the
    /// channel's index entry once it has no subscribers left.
    fn unsubscribe_session(&self, path: &SessionPath, channel: &str) {
        let mut st = self.state();
        if let Some(subs) = st.subscriptions.get_mut(channel) {
            subs.remove(path);
            if subs.is_empty() {
                st.subscriptions.remove(channel);
            }
        }
        if let Some(entry) = st.sessions.get_mut(path) {
            entry.channels.remove(channel);
        }
    }

    /// Unsubscribes every session owned by `user` from `channel`.
    fn unsubscribe_user(&self, user: &str, channel: &str) {
        let mut st = self.state();
        let Some(subs) = st.subscriptions.get(channel) else {
            return;
        };
        let paths: Vec<SessionPath> = subs
            .iter()
            .filter(|p| st.sessions.get(*p).is_some_and(|e| e.user == user))
            .cloned()
            .collect();
        for path in paths {
            if let Some(subs) = st.subscriptions.get_mut(channel) {
                subs.remove(&path);
                if subs.is_empty() {
                    st.subscriptions.remove(channel);
                }
            }
            if let Some(entry) = st.sessions.get_mut(&path) {
                entry.channels.remove(channel);
            }
        }
    }

    /// Forgets every in-memory trace of `channel`: its subscriptions and its ban list.
    fn drop_channel(&self, channel: &str) {
        let mut st = self.state();
        if let Some(subs) = st.subscriptions.remove(channel) {
            for path in subs {
                if let Some(entry) = st.sessions.get_mut(&path) {
                    entry.channels.remove(channel);
                }
            }
        }
        st.bans.remove(channel);
    }

    /// Moves the in-memory subscriptions and ban list of channel `old` to `new`.
    fn rename_channel_subscriptions(&self, old: &str, new: &str) {
        let mut st = self.state();
        if let Some(subs) = st.subscriptions.remove(old) {
            for path in &subs {
                if let Some(entry) = st.sessions.get_mut(path) {
                    entry.channels.remove(old);
                    entry.channels.insert(new.to_owned());
                }
            }
            st.subscriptions.insert(new.to_owned(), subs);
        }
        if let Some(banned) = st.bans.remove(old) {
            st.bans.insert(new.to_owned(), banned);
        }
    }

    /// Kills every session owned by `user`.
    fn force_drop_user(&self, user: &str) {
        let mut st = self.state();
        let paths: Vec<SessionPath> = st.sessions.iter().filter(|(_, e)| e.user == user).map(|(p, _)| p.clone()).collect();
        for path in &paths {
            Self::kill_locked(&mut st, path);
        }
    }

    /// Kills every session `user` attached from `machine`.
    fn force_drop_machine(&self, user: &str, machine: &str) {
        let mut st = self.state();
        let paths: Vec<SessionPath> = st
            .sessions
            .iter()
            .filter(|(_, e)| e.user == user && e.machine == machine)
            .map(|(p, _)| p.clone())
            .collect();
        for path in &paths {
            Self::kill_locked(&mut st, path);
        }
    }

    /// Returns `true` if `user` is on the in-memory ban list of `channel`.
    fn is_banned(&self, channel: &str, user: &str) -> bool {
        self.state().bans.get(channel).is_some_and(|banned| banned.contains(user))
    }

    /// Adds `user` to the in-memory ban list of `channel`.
    fn add_ban(&self, channel: &str, user: &str) {
        self.state().bans.entry(channel.to_owned()).or_default().insert(user.to_owned());
    }

    /// Removes `user` from the in-memory ban list of `channel`, if present.
    fn remove_ban(&self, channel: &str, user: &str) {
        if let Some(banned) = self.state().bans.get_mut(channel) {
            banned.remove(user);
        }
    }

    /// Removes the session at `path` from the session table and every channel index it
    /// appears in. Returns the removed entry, or `None` if nothing was attached there.
    fn take_session(st: &mut HubState, path: &SessionPath) -> Option<SessionEntry> {
        let entry = st.sessions.remove(path)?;
        for channel in &entry.channels {
            if let Some(subs) = st.subscriptions.get_mut(channel) {
                subs.remove(path);
                if subs.is_empty() {
                    st.subscriptions.remove(channel);
                }
            }
        }
        Some(entry)
    }

    /// Removes the session at `path` (the caller already holds the lock) and notifies
    /// its kill handle.
    fn kill_locked(st: &mut HubState, path: &SessionPath) {
        if let Some(entry) = Self::take_session(st, path) {
            entry.kill.notify_one();
        }
    }

    /// Returns the paths of all currently attached sessions.
    pub(crate) fn present_paths(&self) -> Vec<SessionPath> {
        self.state().sessions.keys().cloned().collect()
    }

    /// Test helper: whether a session is attached at `path`.
    #[cfg(test)]
    pub(crate) fn is_present(&self, path: &SessionPath) -> bool {
        self.state().sessions.contains_key(path)
    }

    /// Test helper: the sessions currently subscribed to `channel`.
    #[cfg(test)]
    pub(crate) fn subscribers(&self, channel: &str) -> Vec<SessionPath> {
        self.state().subscriptions.get(channel).map(|subs| subs.iter().cloned().collect()).unwrap_or_default()
    }
}
/// Converts any displayable error, such as a store failure, into
/// `ProtocolError::Internal`, keeping only its rendered message.
fn internal<E: Display>(err: E) -> ProtocolError {
    let rendered = err.to_string();
    ProtocolError::Internal(rendered)
}
/// Builds an `Ack` reply carrying `detail`.
fn ack(detail: impl Into<String>) -> ProtocolMessage {
    let text: String = detail.into();
    ProtocolMessage::Ack { detail: Some(text) }
}
/// Decodes a stored channel-visibility token.
///
/// An unrecognized token falls back to `Visibility::Private`. Visibility decides who
/// may join and list a channel, so a corrupted or unknown value must fail closed
/// (invite/ACL required) rather than silently open the channel to everyone.
fn parse_visibility(token: &str) -> Visibility {
    token.parse().unwrap_or(Visibility::Private)
}
/// Returns `true` if the RFC 3339 timestamp `rfc3339` is at or before the current time.
///
/// A timestamp that does not parse is treated as expired, so a malformed expiry never
/// keeps an invite alive.
fn is_expired(rfc3339: &str) -> bool {
    match chrono::DateTime::parse_from_rfc3339(rfc3339) {
        Ok(parsed) => parsed.with_timezone(&chrono::Utc) <= chrono::Utc::now(),
        Err(_) => true,
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::store::Store;

    /// Builds a hub over an in-memory store holding one private channel `ops`, created
    /// by `aaron`, and one invite `tok` for it with the given use budget and expiry.
    async fn hub_with_private_channel(token_uses: Option<i64>, expires_at: Option<String>) -> Arc<Hub> {
        let store = Store::open_in_memory().await.unwrap();
        store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        store.create_invite("ops", "tok", token_uses, expires_at, "aaron").await.unwrap();
        Hub::new(store, HashSet::new())
    }

    /// Attaches a session for `user` and returns its path. The receiving half of the
    /// outbound queue is dropped when this helper returns, so nothing reads its frames.
    fn attach_session(hub: &Arc<Hub>, user: &str) -> SessionPath {
        let path = SessionPath::new(user, "machine", "session");
        let (tx, _rx) = mpsc::unbounded_channel();
        hub.attach(&path, user, "machine", tx).unwrap();
        path
    }

    #[tokio::test]
    async fn invite_single_use_is_consumed_after_one_redeem() {
        let hub = hub_with_private_channel(Some(1), None).await;
        let david = attach_session(&hub, "david");
        hub.join("david", &david, "ops", Some("tok")).await.unwrap();
        assert!(hub.subscribers("ops").contains(&david), "redeeming a valid invite must subscribe + add to the ACL");
        let carol = attach_session(&hub, "carol");
        assert!(hub.join("carol", &carol, "ops", Some("tok")).await.is_err(), "a spent single-use invite must be refused");
    }

    #[tokio::test]
    async fn invite_multi_use_allows_several_then_exhausts() {
        let hub = hub_with_private_channel(Some(2), None).await;
        for user in ["david", "carol"] {
            let path = attach_session(&hub, user);
            hub.join(user, &path, "ops", Some("tok")).await.unwrap();
        }
        let evan = attach_session(&hub, "evan");
        assert!(hub.join("evan", &evan, "ops", Some("tok")).await.is_err(), "an exhausted invite must be refused");
    }

    #[tokio::test]
    async fn invite_expiry_refuses_an_expired_token() {
        let hub = hub_with_private_channel(None, Some("2000-01-01T00:00:00+00:00".to_owned())).await;
        let david = attach_session(&hub, "david");
        assert!(hub.join("david", &david, "ops", Some("tok")).await.is_err(), "an expired token must be refused");
    }

    #[tokio::test]
    async fn invite_revoked_token_is_refused() {
        let hub = hub_with_private_channel(None, None).await;
        hub.admin("aaron", AdminOp::InviteRevoke { token: "tok".to_owned() }).await.unwrap();
        let david = attach_session(&hub, "david");
        assert!(hub.join("david", &david, "ops", Some("tok")).await.is_err(), "a revoked token must be refused");
    }

    #[tokio::test]
    async fn invite_unknown_token_is_refused() {
        let hub = hub_with_private_channel(None, None).await;
        let david = attach_session(&hub, "david");
        assert!(hub.join("david", &david, "ops", Some("nope")).await.is_err(), "an unknown token must be refused");
    }

    #[tokio::test]
    async fn invite_for_another_channel_is_refused() {
        let store = Store::open_in_memory().await.unwrap();
        store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        store.create_channel("dev", Visibility::Private, "aaron").await.unwrap();
        store.create_invite("dev", "devtok", None, None, "aaron").await.unwrap();
        let hub = Hub::new(store, HashSet::new());
        let david = attach_session(&hub, "david");
        assert!(
            hub.join("david", &david, "ops", Some("devtok")).await.is_err(),
            "a token issued for a different channel must be refused"
        );
        assert!(!hub.subscribers("ops").contains(&david), "a refused join must not subscribe the session");
    }
}