use std::{ops::ControlFlow, sync::Arc, time::Duration};
use tokio::sync::mpsc;
/// Time budget for the whole handshake (`Hello` through `Established`/`ServerInfo`).
///
/// `run_session` wraps `handshake` in this timeout; a client that has not finished by then is
/// sent an `Unauthorized("handshake timed out")` error and the session ends.
pub(crate) const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);
use crate::{
base::SessionPath,
identity,
protocol::{ProtocolError, ProtocolMessage, negotiate_version},
};
use super::hub::{Hub, Kill};
/// Frames decoded from the client, consumed by the session loop.
type Inbound = mpsc::UnboundedReceiver<ProtocolMessage>;
/// Frames queued for delivery to the client.
///
/// The session only ever enqueues with `try_send`, so a frame is silently dropped when the
/// queue is full or the receiving side is gone.
type Outbound = mpsc::Sender<ProtocolMessage>;
/// Producer side of [`Inbound`], fed by the test-only reader pump.
#[cfg(test)]
type InboundTx = mpsc::UnboundedSender<ProtocolMessage>;
/// Consumer side of [`Outbound`], drained by the test-only writer pump.
#[cfg(test)]
type OutboundRx = mpsc::Receiver<ProtocolMessage>;
/// Capacity of the outbound frame queue (used by `handle_connection` when it builds the
/// channel; presumably by the production connection setup as well).
pub(crate) const OUTBOUND_CAPACITY: usize = 1_024;
/// Per-connection state produced by a successful handshake.
struct SessionCtx {
    /// `user/machine/session` path this connection is attached under in the hub.
    path: SessionPath,
    /// Handle returned by `Hub::attach`.
    ///
    /// `run_session` ends the session when it is notified, and hands it back to `Hub::detach`
    /// on teardown.
    kill: Arc<Kill>,
}
pub(crate) async fn run_session(hub: Arc<Hub>, mut inbound: Inbound, outbound: Outbound) {
let ctx = match tokio::time::timeout(HANDSHAKE_TIMEOUT, handshake(&hub, &mut inbound, &outbound)).await {
Ok(Some(ctx)) => ctx,
Ok(None) => return,
Err(_elapsed) => {
let _ = outbound.try_send(err(ProtocolError::Unauthorized("handshake timed out".to_owned())));
return;
}
};
let kill = Arc::clone(&ctx.kill);
loop {
tokio::select! {
() = kill.notified() => {
let _ = outbound.try_send(err(ProtocolError::Unauthorized(kill.reason().to_owned())));
break;
}
frame = inbound.recv() => {
let Some(frame) = frame else { break };
hub.touch(&ctx.path);
if handle_frame(&hub, &ctx, &outbound, frame).await.is_break() {
break;
}
}
}
}
hub.detach(&ctx.path, &ctx.kill);
}
/// Test harness: serves one session over an arbitrary duplex byte stream.
///
/// The function proceeds in four steps:
/// 1. Split `stream` into read and write halves.
/// 2. Bridge the halves to the session's channels with `read_pump` / `write_pump` tasks.
/// 3. Run the session to completion.
/// 4. Abort the reader and wait for the writer task to finish.
///
/// The writer stops once every outbound sender is dropped or a write fails.
#[cfg(test)]
pub(crate) async fn handle_connection<S>(hub: Arc<Hub>, stream: S)
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    let (read_half, write_half) = tokio::io::split(stream);
    let (to_session, from_client) = mpsc::unbounded_channel();
    let (to_client, from_session) = mpsc::channel(OUTBOUND_CAPACITY);

    let reader = tokio::spawn(read_pump(read_half, to_session));
    let writer = tokio::spawn(write_pump(write_half, from_session));

    run_session(hub, from_client, to_client).await;

    reader.abort();
    let _ = writer.await;
}
/// Test-only pump: decodes frames from `reader` and forwards them to the session.
///
/// Stops at the first `recv_message` error, or once the session side has dropped its
/// receiver.
#[cfg(test)]
async fn read_pump<R>(mut reader: R, inbound: InboundTx)
where
    R: tokio::io::AsyncRead + Unpin,
{
    use crate::protocol::ProtocolRead as _;
    loop {
        let Ok(frame) = reader.recv_message().await else {
            break;
        };
        if inbound.send(frame).is_err() {
            break;
        }
    }
}
/// Test-only pump: drains queued frames and encodes them onto `writer`.
///
/// Stops when every sender of `outbound` has been dropped, or at the first write error.
#[cfg(test)]
async fn write_pump<W>(mut writer: W, mut outbound: OutboundRx)
where
    W: tokio::io::AsyncWrite + Unpin,
{
    use crate::protocol::ProtocolWrite as _;
    loop {
        let Some(frame) = outbound.recv().await else {
            break;
        };
        if writer.send_message(&frame).await.is_err() {
            break;
        }
    }
}
/// Authenticates a new connection and attaches it to the hub.
///
/// Expected client sequence:
/// 1. `Hello { protocol_version, session }`. The version must pass [`negotiate_version`], and
///    `session` must be a valid path component.
/// 2. The server answers with `Challenge { nonce }`.
/// 3. One of:
///    - `Register { username, machine, pubkey }` immediately followed by an `Auth` for the
///      same key (new identity, recorded via `hub.register`);
///    - a lone `Auth { pubkey, signature }` (identity looked up via `hub.resolve`).
///
///    Either way the signature must verify against the nonce.
///
/// On success the session is attached with `Hub::attach`, `Established` and `ServerInfo`
/// are queued, and the session context is returned.
///
/// On any rejection an `Error` frame is queued and `None` is returned. `None` is also
/// returned, without a reply, if the inbound channel closes mid-handshake. All replies are
/// best-effort `try_send`s.
async fn handshake(hub: &Arc<Hub>, inbound: &mut Inbound, outbound: &Outbound) -> Option<SessionCtx> {
    let ProtocolMessage::Hello { protocol_version, session } = inbound.recv().await? else {
        let _ = outbound.try_send(err(ProtocolError::MalformedFrame("expected Hello".to_owned())));
        return None;
    };
    if let Err(mismatch) = negotiate_version(protocol_version) {
        let _ = outbound.try_send(err(mismatch));
        return None;
    }
    // Validate the session handle before anything with side effects. The Register branch
    // below calls `hub.register`, which must not run for a handshake that will be rejected
    // for a malformed handle anyway.
    if !accept_component(outbound, "session handle", &session) {
        return None;
    }

    let nonce = match identity::generate_challenge() {
        Ok(nonce) => nonce,
        Err(e) => {
            let _ = outbound.try_send(err(ProtocolError::Internal(e.to_string())));
            return None;
        }
    };
    let _ = outbound.try_send(ProtocolMessage::Challenge { nonce: nonce.to_vec() });

    let (user, machine) = match inbound.recv().await? {
        ProtocolMessage::Register { username, machine, pubkey } => {
            // Registration must be proven with a signature from the very key being registered.
            let ProtocolMessage::Auth { pubkey: auth_pubkey, signature } = inbound.recv().await? else {
                let _ = outbound.try_send(err(ProtocolError::MalformedFrame("expected Auth after Register".to_owned())));
                return None;
            };
            if auth_pubkey != pubkey {
                let _ = outbound.try_send(err(ProtocolError::Unauthorized("auth key does not match the registered key".to_owned())));
                return None;
            }
            if let Err(e) = identity::verify(&auth_pubkey, &nonce, &signature) {
                let _ = outbound.try_send(err(e.into()));
                return None;
            }
            if !accept_component(outbound, "username", &username) || !accept_component(outbound, "machine name", &machine) {
                return None;
            }
            if let Err(e) = hub.register(&username, &machine, &pubkey).await {
                let _ = outbound.try_send(err(e));
                return None;
            }
            (username, machine)
        }
        ProtocolMessage::Auth { pubkey, signature } => {
            if let Err(e) = identity::verify(&pubkey, &nonce, &signature) {
                let _ = outbound.try_send(err(e.into()));
                return None;
            }
            match hub.resolve(&pubkey).await {
                Ok(resolved) => resolved,
                Err(e) => {
                    let _ = outbound.try_send(err(e));
                    return None;
                }
            }
        }
        _ => {
            let _ = outbound.try_send(err(ProtocolError::MalformedFrame("expected Register or Auth".to_owned())));
            return None;
        }
    };

    let path = SessionPath::new(user.clone(), machine.clone(), session);
    let kill = hub.attach(&path, &user, &machine, outbound.clone());
    let _ = outbound.try_send(ProtocolMessage::Established { path: path.clone() });
    let _ = outbound.try_send(ProtocolMessage::ServerInfo { admin: hub.is_admin(&user) });
    Some(SessionCtx { path, kill })
}
async fn handle_frame(hub: &Arc<Hub>, ctx: &SessionCtx, outbound: &Outbound, frame: ProtocolMessage) -> ControlFlow<()> {
let user = &ctx.path.user;
match frame {
ProtocolMessage::Ping => {
let _ = outbound.try_send(ProtocolMessage::Pong);
}
ProtocolMessage::Join { channel, token } => match hub.join(user, &ctx.path, &channel, token.as_deref()).await {
Ok(()) => {
let _ = outbound.try_send(ProtocolMessage::Joined { channel });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::Leave { channel } => {
hub.leave(&ctx.path, &channel);
let _ = outbound.try_send(ProtocolMessage::Ack { detail: Some(channel) });
}
ProtocolMessage::Who { channel } => match hub.who(user, channel.as_deref()).await {
Ok(sessions) => {
let _ = outbound.try_send(ProtocolMessage::Presence { channel, sessions });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::ListChannels => match hub.list_channels(user).await {
Ok(channels) => {
let _ = outbound.try_send(ProtocolMessage::ChannelList { channels });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::ListMachines => match hub.list_machines(user).await {
Ok(machines) => {
let _ = outbound.try_send(ProtocolMessage::MachineList { machines });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::ListUsers => match hub.list_users(user).await {
Ok(users) => {
let _ = outbound.try_send(ProtocolMessage::UserList { users });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::Admin(op) => match hub.admin(user, op).await {
Ok(reply) => {
let _ = outbound.try_send(reply);
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::ChannelMsg { channel, payload, .. } => match hub.post(&ctx.path, &channel, payload) {
Ok(()) => {
let _ = outbound.try_send(ProtocolMessage::Ack { detail: None });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
ProtocolMessage::Whisper { target, payload, .. } => match hub.whisper(&ctx.path, &target, payload) {
Ok(()) => {
let _ = outbound.try_send(ProtocolMessage::Ack { detail: None });
}
Err(e) => {
let _ = outbound.try_send(err(e));
}
},
_ => {
let _ = outbound.try_send(err(ProtocolError::MalformedFrame("unexpected frame from client".to_owned())));
}
}
ControlFlow::Continue(())
}
/// Checks one client-supplied path component with `SessionPath::validate_component`.
///
/// Returns `true` when the component is acceptable. Otherwise it queues a `MalformedFrame`
/// error naming the field (`label`) and echoing the rejected `value`, and returns `false`.
fn accept_component(outbound: &Outbound, label: &str, value: &str) -> bool {
    let valid = SessionPath::validate_component(value).is_ok();
    if !valid {
        let message = format!("invalid {label}: `{value}`");
        let _ = outbound.try_send(err(ProtocolError::MalformedFrame(message)));
    }
    valid
}
/// Wraps a [`ProtocolError`] in the `Error` frame that is sent back to the client.
///
/// Also usable as a plain function value, e.g. `result.unwrap_or_else(err)`.
fn err(error: ProtocolError) -> ProtocolMessage {
    ProtocolMessage::Error(error)
}