use std::time::Duration;
use anyhow::Context as _;
use futures_util::{SinkExt as _, StreamExt as _};
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
use crate::{
base::{Constant, Res, SessionPath},
identity::Identity,
protocol::{self, ProtocolMessage},
};
/// The WebSocket stream every helper in this module reads from and writes to
/// (plain or TLS, over TCP).
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// How long `connect` waits for the WebSocket handshake to complete.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// How long `recv` waits for the next meaningful server frame.
const RESPONSE_TIMEOUT: Duration = Duration::from_secs(15);
/// Period between the `Ping` frames `tail` sends while streaming.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(20);
/// Registers `username` and `machine` with the server under this identity's
/// public key and returns the session path from the `Established` reply.
///
/// After the hello/challenge handshake it sends `Register`, then, without
/// waiting for a reply in between, an `Auth` frame carrying the same public key
/// and a signature over the challenge nonce.
///
/// # Errors
/// Fails on connection, handshake or signing errors, when the server answers
/// with `Error` ("registration rejected"), or with any frame other than
/// `Established`.
pub async fn register(url: &str, identity: &Identity, username: &str, machine: &str, session: &str) -> Res<SessionPath> {
    let mut socket = connect(url).await?;
    let challenge = hello_challenge(&mut socket, session).await?;
    let public_key = identity.public_key().to_vec();

    let registration = ProtocolMessage::Register {
        username: username.to_owned(),
        machine: machine.to_owned(),
        pubkey: public_key.clone(),
    };
    send(&mut socket, &registration).await?;

    let proof = ProtocolMessage::Auth {
        pubkey: public_key,
        signature: identity.sign(&challenge)?.to_vec(),
    };
    send(&mut socket, &proof).await?;

    let reply = recv(&mut socket).await?;
    match reply {
        ProtocolMessage::Established { path } => Ok(path),
        ProtocolMessage::Error(err) => anyhow::bail!("registration rejected: {err}"),
        other => anyhow::bail!("unexpected response to register: {other:?}"),
    }
}
/// Opens a fresh authenticated connection, sends `request`, and returns the
/// first reply frame (`ServerInfo` and `Pong` frames are skipped by `recv`).
///
/// A server-side `Error` reply is returned as `Ok(ProtocolMessage::Error(..))`;
/// interpreting it is left to the caller.
///
/// # Errors
/// Fails on connection, authentication, send or receive errors (including the
/// response timeout).
pub async fn one_shot(url: &str, identity: &Identity, session: &str, request: ProtocolMessage) -> Res<ProtocolMessage> {
    let mut socket = connect(url).await?;
    // The session path produced by the handshake is not needed for a single request.
    authenticate(&mut socket, identity, session).await?;
    send(&mut socket, &request).await?;
    let reply = recv(&mut socket).await?;
    Ok(reply)
}
/// Posts `text` as a plain (unencrypted) message to `channel`.
///
/// Opens a fresh connection, authenticates, joins the channel, sends one
/// `ChannelMsg` whose `from` is the session path returned by authentication,
/// and waits for the server's `Ack`.
///
/// # Errors
/// Fails on connection or authentication errors, when the join or the send is
/// answered with `Error`, or when either gets an unexpected frame.
pub async fn send_message(url: &str, identity: &Identity, session: &str, channel: &str, text: &str) -> Res<()> {
    let mut socket = connect(url).await?;
    let sender = authenticate(&mut socket, identity, session).await?;

    let join_request = ProtocolMessage::Join { channel: channel.to_owned(), token: None };
    send(&mut socket, &join_request).await?;
    let join_reply = recv(&mut socket).await?;
    match join_reply {
        ProtocolMessage::Joined { .. } => {}
        ProtocolMessage::Error(err) => anyhow::bail!("join rejected: {err}"),
        other => anyhow::bail!("unexpected response to join: {other:?}"),
    }

    let message = ProtocolMessage::ChannelMsg {
        channel: channel.to_owned(),
        from: sender,
        payload: protocol::Payload::Plain(text.to_owned()),
    };
    send(&mut socket, &message).await?;
    let send_reply = recv(&mut socket).await?;
    match send_reply {
        ProtocolMessage::Ack { .. } => Ok(()),
        ProtocolMessage::Error(err) => anyhow::bail!("send rejected: {err}"),
        other => anyhow::bail!("unexpected response to send: {other:?}"),
    }
}
/// Initial delay before `tail` retries a dropped connection.
const TAIL_BACKOFF_BASE: Duration = Duration::from_secs(1);
/// Upper bound for the doubling reconnect delay used by `tail`.
const TAIL_BACKOFF_MAX: Duration = Duration::from_secs(30);
/// Milliseconds subtracted from the resume watermark before requesting history,
/// so a replay overlaps the last seen point instead of risking a gap. The
/// watermark mixes local-clock and server timestamps, so this presumably also
/// absorbs clock skew between the two.
const TAIL_RESUME_SLACK_MS: i64 = 5_000;

/// An error that must end `tail` instead of triggering a reconnect.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
struct TailFatal(String);

/// Streams `channel` to stdout until Ctrl-C, reconnecting on connection loss.
///
/// When `since_secs` is given, history from that many seconds ago is requested
/// right after joining. After a drop, the client reconnects with exponential
/// backoff (from `TAIL_BACKOFF_BASE` doubling up to `TAIL_BACKOFF_MAX`, and
/// starting over at the base whenever an attempt managed to join). It then asks
/// for history since the last point it saw, so messages from the gap are shown.
///
/// Returns `Ok(())` when the user presses Ctrl-C, including while a
/// (re)connect attempt or a backoff sleep is still in progress.
///
/// # Errors
/// Any error before the first successful join is returned as-is. After that,
/// only `TailFatal` errors (join rejected, unexpected join reply, a server
/// `Error` frame, or a failure writing to stdout) end the stream; every other
/// error triggers a reconnect.
pub async fn tail(url: &str, identity: &Identity, session: &str, channel: &str, since_secs: Option<u64>) -> Res<()> {
    let mut watermark_ms: Option<i64> = since_secs.map(|secs| {
        let lookback_ms = i64::try_from(secs).unwrap_or(i64::MAX).saturating_mul(1000);
        chrono::Utc::now().timestamp_millis().saturating_sub(lookback_ms)
    });
    let mut established = false;
    let mut backoff = TAIL_BACKOFF_BASE;
    loop {
        let mut joined = false;
        // Race the whole attempt against Ctrl-C. Once tokio has installed its SIGINT
        // handler, the default "terminate" action is gone for the rest of the process,
        // so a Ctrl-C during a slow or hanging reconnect must be observed here.
        // `Ok(())` disables the branch (rather than exiting) if the listener fails.
        let attempt = tokio::select! {
            Ok(()) = tokio::signal::ctrl_c() => return Ok(()),
            result = tail_once(url, identity, session, channel, &mut watermark_ms, &mut established, &mut joined) => result,
        };
        match attempt {
            Ok(()) => return Ok(()),
            Err(err) if !established || err.downcast_ref::<TailFatal>().is_some() => return Err(err),
            Err(_) => {
                if joined {
                    // The connection was healthy before it dropped: retry quickly again.
                    backoff = TAIL_BACKOFF_BASE;
                }
                eprintln!("⚠ connection to `{url}` lost — reconnecting (Ctrl-C to stop)");
                tokio::select! {
                    Ok(()) = tokio::signal::ctrl_c() => return Ok(()),
                    () = tokio::time::sleep(backoff) => {}
                }
                backoff = (backoff * 2).min(TAIL_BACKOFF_MAX);
            }
        }
    }
}

/// Runs one connect → authenticate → join → stream cycle for `tail`.
///
/// On a successful join it prints the banner the first time (and sets
/// `established`), or a "reconnected" notice afterwards, and sets `joined` for
/// this attempt. If a watermark is known it requests history since
/// `watermark - TAIL_RESUME_SLACK_MS`; otherwise it anchors the watermark at the
/// join time. It then prints channel messages, whispers and history batches,
/// sending a `Ping` every `KEEPALIVE_INTERVAL`, until Ctrl-C (`Ok`) or an error.
///
/// # Errors
/// `TailFatal` for a rejected or unexpected join reply, a server `Error` frame,
/// or a stdout write failure. Any other error (connect, authentication,
/// transport, decoding) is left for the caller to retry.
async fn tail_once(
    url: &str,
    identity: &Identity,
    session: &str,
    channel: &str,
    watermark_ms: &mut Option<i64>,
    established: &mut bool,
    joined: &mut bool,
) -> Res<()> {
    let mut ws = connect(url).await?;
    let path = authenticate(&mut ws, identity, session).await?;
    send(&mut ws, &ProtocolMessage::Join { channel: channel.to_owned(), token: None }).await?;
    match recv(&mut ws).await? {
        ProtocolMessage::Joined { channel } => {
            if *established {
                eprintln!("✓ reconnected; resuming #{channel}");
            } else {
                write_stdout(&format!("tailing #{channel} as {path} — Ctrl-C to stop\n"))?;
                *established = true;
            }
            *joined = true;
        }
        ProtocolMessage::Error(err) => return Err(TailFatal(format!("join rejected: {err}")).into()),
        other => return Err(TailFatal(format!("unexpected response to join: {other:?}")).into()),
    }
    match *watermark_ms {
        Some(since) => {
            send(
                &mut ws,
                &ProtocolMessage::ReadSince {
                    channel: channel.to_owned(),
                    since_ms: since.saturating_sub(TAIL_RESUME_SLACK_MS),
                },
            )
            .await?;
        }
        // Nothing seen yet (no `--since`, no messages). Anchor the watermark at join time
        // so a later reconnect replays what was posted while the connection was down,
        // instead of silently skipping it.
        None => *watermark_ms = Some(chrono::Utc::now().timestamp_millis()),
    }
    let mut keepalive = tokio::time::interval(KEEPALIVE_INTERVAL);
    // A tokio interval's first tick completes immediately; consume it so the first
    // `Ping` goes out one full interval after joining.
    keepalive.tick().await;
    loop {
        tokio::select! {
            Ok(()) = tokio::signal::ctrl_c() => return Ok(()),
            _ = keepalive.tick() => send(&mut ws, &ProtocolMessage::Ping).await?,
            frame = recv_frame(&mut ws) => match frame? {
                ProtocolMessage::ChannelMsg { channel, from, payload } => {
                    write_stdout(&format!("[{channel}] {from}: {}\n", render_payload(&payload)))?;
                    *watermark_ms = Some(chrono::Utc::now().timestamp_millis());
                }
                ProtocolMessage::Whisper { from, payload, .. } => {
                    write_stdout(&format!("[whisper] {from}: {}\n", render_payload(&payload)))?;
                }
                ProtocolMessage::History { channel, messages } => {
                    let rendered: String = messages
                        .iter()
                        .map(|message| format!("[{channel}] {}: {}\n", message.from, render_payload(&message.payload)))
                        .collect();
                    write_stdout(&rendered)?;
                    if let Some(newest) = messages.iter().map(|m| m.ts_ms).max() {
                        *watermark_ms = Some(watermark_ms.unwrap_or(newest).max(newest));
                    }
                }
                ProtocolMessage::Error(err) => {
                    return Err(TailFatal(format!("server terminated the stream: {err}")).into());
                }
                _ => {}
            },
        }
    }
}

/// Writes `text` to stdout and flushes it.
///
/// Failures are reported as `TailFatal`. A broken stdout (for example, the reader
/// of a pipe exited) cannot be fixed by reconnecting. Treating it as a lost
/// connection would make `tail` retry forever.
fn write_stdout(text: &str) -> Res<()> {
    use std::io::Write as _;
    let stdout = std::io::stdout();
    let mut out = stdout.lock();
    if let Err(err) = out.write_all(text.as_bytes()).and_then(|()| out.flush()) {
        return Err(TailFatal(format!("failed to write to stdout: {err}")).into());
    }
    Ok(())
}
fn render_payload(payload: &protocol::Payload) -> &str {
match payload {
protocol::Payload::Plain(text) => text,
protocol::Payload::Encrypted(_) => "<end-to-end-encrypted payload>",
}
}
/// Runs the hello/challenge handshake and proves possession of `identity` by
/// signing the nonce. Returns the session path from the server's
/// `Established` reply.
///
/// # Errors
/// Fails on transport or signing errors, when the server answers with `Error`
/// ("authentication rejected"), or with any other unexpected frame.
async fn authenticate(ws: &mut Ws, identity: &Identity, session: &str) -> Res<SessionPath> {
    let nonce = hello_challenge(ws, session).await?;
    let pubkey = identity.public_key().to_vec();
    let signature = identity.sign(&nonce)?.to_vec();
    send(ws, &ProtocolMessage::Auth { pubkey, signature }).await?;

    let reply = recv(ws).await?;
    match reply {
        ProtocolMessage::Established { path } => Ok(path),
        ProtocolMessage::Error(err) => anyhow::bail!("authentication rejected: {err}"),
        unexpected => anyhow::bail!("unexpected response before request: {unexpected:?}"),
    }
}
/// Connects to `url`, bounded by the module-wide `CONNECT_TIMEOUT`.
async fn connect(url: &str) -> Res<Ws> {
    let limit = CONNECT_TIMEOUT;
    connect_with_timeout(url, limit).await
}
/// Opens a WebSocket connection to `url`, giving up after `timeout`.
///
/// Calls `crate::base::ensure_tls_provider()` before connecting.
///
/// # Errors
/// Returns a "timed out after …" error if the handshake does not finish within
/// `timeout`. Otherwise returns the underlying connect error with
/// "failed to connect to `url`" as context.
async fn connect_with_timeout(url: &str, timeout: Duration) -> Res<Ws> {
    crate::base::ensure_tls_provider();
    match tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url)).await {
        Ok(result) => {
            let (ws, _response) = result.with_context(|| format!("failed to connect to `{url}`"))?;
            Ok(ws)
        }
        // Debug-format the duration so sub-second timeouts read e.g. "150ms" instead of
        // being truncated to "0s"; whole seconds still render as e.g. "10s".
        Err(_) => anyhow::bail!("timed out after {timeout:?} connecting to `{url}`"),
    }
}
/// Opens the handshake by sending `Hello` (this client's protocol version plus
/// `session`) and returns the nonce from the server's `Challenge`.
///
/// # Errors
/// Fails on transport errors or when the reply is anything but `Challenge`.
async fn hello_challenge(ws: &mut Ws, session: &str) -> Res<Vec<u8>> {
    let greeting = ProtocolMessage::Hello {
        protocol_version: Constant::PROTOCOL_VERSION,
        session: session.to_owned(),
    };
    send(ws, &greeting).await?;

    let reply = recv(ws).await?;
    match reply {
        ProtocolMessage::Challenge { nonce } => Ok(nonce),
        unexpected => anyhow::bail!("expected a challenge, got {unexpected:?}"),
    }
}
/// Encodes `frame` with `protocol::encode` and sends it as a binary WebSocket message.
///
/// # Errors
/// Fails if encoding fails or the socket write fails ("failed to send control frame").
async fn send(ws: &mut Ws, frame: &ProtocolMessage) -> Res<()> {
    let encoded = protocol::encode(frame)?;
    ws.send(Message::Binary(encoded.into()))
        .await
        .context("failed to send control frame")?;
    Ok(())
}
/// Waits for the next meaningful frame, bounded by the module-wide `RESPONSE_TIMEOUT`.
async fn recv(ws: &mut Ws) -> Res<ProtocolMessage> {
    let limit = RESPONSE_TIMEOUT;
    recv_with_timeout(ws, limit).await
}
/// Waits up to `timeout` for the next meaningful frame (as filtered by `recv_frame`).
///
/// # Errors
/// Returns a "timed out after …" error when nothing arrives in time. Otherwise
/// returns whatever error `recv_frame` produced.
async fn recv_with_timeout(ws: &mut Ws, timeout: Duration) -> Res<ProtocolMessage> {
    match tokio::time::timeout(timeout, recv_frame(ws)).await {
        Ok(result) => result,
        // Debug-format the duration so sub-second timeouts read e.g. "150ms" instead of
        // being truncated to "0s"; whole seconds still render as e.g. "15s".
        Err(_) => anyhow::bail!("timed out after {timeout:?} waiting for a server response"),
    }
}
async fn recv_frame(ws: &mut Ws) -> Res<ProtocolMessage> {
loop {
match ws.next().await {
Some(Ok(Message::Binary(data))) => match protocol::decode(&data)? {
ProtocolMessage::ServerInfo { .. } | ProtocolMessage::Pong => {}
frame => return Ok(frame),
},
Some(Ok(Message::Close(_))) | None => anyhow::bail!("connection closed before a response arrived"),
Some(Ok(_)) => {}
Some(Err(err)) => anyhow::bail!("websocket error: {err}"),
}
}
}
#[cfg(test)]
mod tests {
    // Unwrapping is fine here: a panic is exactly how a failed setup step should surface.
    #![allow(clippy::unwrap_used)]
    use std::time::Duration;

    use tokio::net::TcpListener;

    use super::{connect_with_timeout, recv_with_timeout};

    /// Binds an ephemeral localhost listener and returns it with its `ws://` URL.
    async fn local_listener() -> (TcpListener, String) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        (listener, url)
    }

    /// Fails the test unless `err`'s message mentions a timeout.
    fn assert_timed_out(err: &impl std::fmt::Display) {
        assert!(
            err.to_string().to_lowercase().contains("timed out"),
            "expected a timeout error, got: {err}"
        );
    }

    #[tokio::test]
    async fn control_timeout_connecting_to_a_silent_server() {
        let (listener, url) = local_listener().await;
        tokio::spawn(async move {
            // Keep the accepted TCP stream open but never answer the WebSocket handshake.
            let _stream = listener.accept().await;
            std::future::pending::<()>().await;
        });
        let err = connect_with_timeout(&url, Duration::from_millis(150))
            .await
            .expect_err("a silent server must time out");
        assert_timed_out(&err);
    }

    #[tokio::test]
    async fn control_timeout_waiting_for_a_reply_from_a_silent_server() {
        let (listener, url) = local_listener().await;
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            // Complete the handshake, keep the socket alive, and never send a frame.
            let _ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            std::future::pending::<()>().await;
        });
        let mut ws = connect_with_timeout(&url, Duration::from_secs(5)).await.unwrap();
        let err = recv_with_timeout(&mut ws, Duration::from_millis(150))
            .await
            .expect_err("a silent reply must time out");
        assert_timed_out(&err);
    }
}