1use std::time::Duration;
9
10use anyhow::Context as _;
11use futures_util::{SinkExt as _, StreamExt as _};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
14
15use crate::{
16 base::{Constant, Res, SessionPath},
17 identity::Identity,
18 protocol::{self, ProtocolMessage},
19};
20
21type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
23
/// Upper bound on establishing a WebSocket connection (used by `connect`).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Upper bound on waiting for a single server reply (used by `recv`).
const RESPONSE_TIMEOUT: Duration = Duration::from_secs(15);
/// Period between `Ping` frames while `tail` is streaming.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(20);
30
31pub async fn register(url: &str, identity: &Identity, username: &str, machine: &str, session: &str) -> Res<SessionPath> {
37 let mut ws = connect(url).await?;
38 let nonce = hello_challenge(&mut ws, session).await?;
39 let pubkey = identity.public_key().to_vec();
40 send(
41 &mut ws,
42 &ProtocolMessage::Register {
43 username: username.to_owned(),
44 machine: machine.to_owned(),
45 pubkey: pubkey.clone(),
46 },
47 )
48 .await?;
49 send(
50 &mut ws,
51 &ProtocolMessage::Auth {
52 pubkey,
53 signature: identity.sign(&nonce)?.to_vec(),
54 },
55 )
56 .await?;
57
58 match recv(&mut ws).await? {
59 ProtocolMessage::Established { path } => Ok(path),
60 ProtocolMessage::Error(err) => anyhow::bail!("registration rejected: {err}"),
61 other => anyhow::bail!("unexpected response to register: {other:?}"),
62 }
63}
64
65pub async fn one_shot(url: &str, identity: &Identity, session: &str, request: ProtocolMessage) -> Res<ProtocolMessage> {
72 let mut ws = connect(url).await?;
73 authenticate(&mut ws, identity, session).await?;
74 send(&mut ws, &request).await?;
75 recv(&mut ws).await
76}
77
78pub async fn send_message(url: &str, identity: &Identity, session: &str, channel: &str, text: &str) -> Res<()> {
85 let mut ws = connect(url).await?;
86 let from = authenticate(&mut ws, identity, session).await?;
87
88 send(&mut ws, &ProtocolMessage::Join { channel: channel.to_owned(), token: None }).await?;
89 match recv(&mut ws).await? {
90 ProtocolMessage::Joined { .. } => {}
91 ProtocolMessage::Error(err) => anyhow::bail!("join rejected: {err}"),
92 other => anyhow::bail!("unexpected response to join: {other:?}"),
93 }
94
95 send(
96 &mut ws,
97 &ProtocolMessage::ChannelMsg {
98 channel: channel.to_owned(),
99 from,
100 payload: protocol::Payload::Plain(text.to_owned()),
101 },
102 )
103 .await?;
104 match recv(&mut ws).await? {
105 ProtocolMessage::Ack { .. } => Ok(()),
106 ProtocolMessage::Error(err) => anyhow::bail!("send rejected: {err}"),
107 other => anyhow::bail!("unexpected response to send: {other:?}"),
108 }
109}
110
/// First delay `tail` waits before reconnecting after losing the connection.
const TAIL_BACKOFF_BASE: Duration = Duration::from_secs(1);
/// Ceiling for the doubling reconnect delay in `tail`.
const TAIL_BACKOFF_MAX: Duration = Duration::from_secs(30);
/// Milliseconds subtracted from the resume watermark when `tail` asks for history.
const TAIL_RESUME_SLACK_MS: i64 = 5_000;
117
118#[derive(Debug, thiserror::Error)]
120#[error("{0}")]
121struct TailFatal(String);
122
123pub async fn tail(url: &str, identity: &Identity, session: &str, channel: &str, since_secs: Option<u64>) -> Res<()> {
136 let mut watermark_ms: Option<i64> = since_secs.map(|secs| chrono::Utc::now().timestamp_millis().saturating_sub(i64::try_from(secs).unwrap_or(i64::MAX).saturating_mul(1000)));
139 let mut established = false;
140 let mut backoff = TAIL_BACKOFF_BASE;
141
142 loop {
143 match tail_once(url, identity, session, channel, &mut watermark_ms, &mut established).await {
144 Ok(()) => return Ok(()),
145 Err(err) if !established || err.downcast_ref::<TailFatal>().is_some() => return Err(err),
146 Err(_) => {
147 eprintln!("⚠ connection to `{url}` lost — reconnecting (Ctrl-C to stop)");
149 tokio::select! {
150 _ = tokio::signal::ctrl_c() => return Ok(()),
151 () = tokio::time::sleep(backoff) => {}
152 }
153 backoff = (backoff * 2).min(TAIL_BACKOFF_MAX);
154 }
155 }
156 }
157}
158
159async fn tail_once(url: &str, identity: &Identity, session: &str, channel: &str, watermark_ms: &mut Option<i64>, established: &mut bool) -> Res<()> {
162 use std::io::Write as _;
163
164 let mut ws = connect(url).await?;
165 let path = authenticate(&mut ws, identity, session).await?;
166
167 send(&mut ws, &ProtocolMessage::Join { channel: channel.to_owned(), token: None }).await?;
168 match recv(&mut ws).await? {
169 ProtocolMessage::Joined { channel } => {
170 if *established {
171 eprintln!("✓ reconnected; resuming #{channel}");
172 } else {
173 let mut out = std::io::stdout();
175 writeln!(out, "tailing #{channel} as {path} — Ctrl-C to stop")?;
176 out.flush()?;
177 *established = true;
178 }
179 }
180 ProtocolMessage::Error(err) => return Err(TailFatal(format!("join rejected: {err}")).into()),
181 other => return Err(TailFatal(format!("unexpected response to join: {other:?}")).into()),
182 }
183
184 if let Some(since) = *watermark_ms {
187 send(
188 &mut ws,
189 &ProtocolMessage::ReadSince {
190 channel: channel.to_owned(),
191 since_ms: since.saturating_sub(TAIL_RESUME_SLACK_MS),
192 },
193 )
194 .await?;
195 }
196
197 let mut keepalive = tokio::time::interval(KEEPALIVE_INTERVAL);
198 keepalive.tick().await; loop {
200 tokio::select! {
201 _ = tokio::signal::ctrl_c() => return Ok(()),
202 _ = keepalive.tick() => send(&mut ws, &ProtocolMessage::Ping).await?,
203 frame = recv_frame(&mut ws) => {
204 let mut out = std::io::stdout();
205 match frame? {
206 ProtocolMessage::ChannelMsg { channel, from, payload } => {
207 writeln!(out, "[{channel}] {from}: {}", render_payload(&payload))?;
208 *watermark_ms = Some(chrono::Utc::now().timestamp_millis());
209 }
210 ProtocolMessage::Whisper { from, payload, .. } => writeln!(out, "[whisper] {from}: {}", render_payload(&payload))?,
211 ProtocolMessage::History { channel, messages } => {
212 for message in &messages {
213 writeln!(out, "[{channel}] {}: {}", message.from, render_payload(&message.payload))?;
214 }
215 if let Some(newest) = messages.iter().map(|m| m.ts_ms).max() {
216 *watermark_ms = Some(watermark_ms.unwrap_or(newest).max(newest));
217 }
218 }
219 ProtocolMessage::Error(err) => return Err(TailFatal(format!("server terminated the stream: {err}")).into()),
222 _ => continue,
224 }
225 out.flush()?;
226 }
227 }
228 }
229}
230
231fn render_payload(payload: &protocol::Payload) -> &str {
233 match payload {
234 protocol::Payload::Plain(text) => text,
235 protocol::Payload::Encrypted(_) => "<end-to-end-encrypted payload>",
236 }
237}
238
239async fn authenticate(ws: &mut Ws, identity: &Identity, session: &str) -> Res<SessionPath> {
241 let nonce = hello_challenge(ws, session).await?;
242 send(
243 ws,
244 &ProtocolMessage::Auth {
245 pubkey: identity.public_key().to_vec(),
246 signature: identity.sign(&nonce)?.to_vec(),
247 },
248 )
249 .await?;
250
251 match recv(ws).await? {
252 ProtocolMessage::Established { path } => Ok(path),
253 ProtocolMessage::Error(err) => anyhow::bail!("authentication rejected: {err}"),
254 other => anyhow::bail!("unexpected response before request: {other:?}"),
255 }
256}
257
258async fn connect(url: &str) -> Res<Ws> {
259 connect_with_timeout(url, CONNECT_TIMEOUT).await
260}
261
262async fn connect_with_timeout(url: &str, timeout: Duration) -> Res<Ws> {
263 crate::base::ensure_tls_provider();
264 match tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url)).await {
265 Ok(result) => {
266 let (ws, _response) = result.with_context(|| format!("failed to connect to `{url}`"))?;
267 Ok(ws)
268 }
269 Err(_) => anyhow::bail!("timed out after {}s connecting to `{url}`", timeout.as_secs()),
270 }
271}
272
273async fn hello_challenge(ws: &mut Ws, session: &str) -> Res<Vec<u8>> {
274 send(
275 ws,
276 &ProtocolMessage::Hello {
277 protocol_version: Constant::PROTOCOL_VERSION,
278 session: session.to_owned(),
279 },
280 )
281 .await?;
282 match recv(ws).await? {
283 ProtocolMessage::Challenge { nonce } => Ok(nonce),
284 other => anyhow::bail!("expected a challenge, got {other:?}"),
285 }
286}
287
288async fn send(ws: &mut Ws, frame: &ProtocolMessage) -> Res<()> {
289 ws.send(Message::Binary(protocol::encode(frame)?.into())).await.context("failed to send control frame")?;
290 Ok(())
291}
292
293async fn recv(ws: &mut Ws) -> Res<ProtocolMessage> {
296 recv_with_timeout(ws, RESPONSE_TIMEOUT).await
297}
298
299async fn recv_with_timeout(ws: &mut Ws, timeout: Duration) -> Res<ProtocolMessage> {
300 match tokio::time::timeout(timeout, recv_frame(ws)).await {
301 Ok(result) => result,
302 Err(_) => anyhow::bail!("timed out after {}s waiting for a server response", timeout.as_secs()),
303 }
304}
305
306async fn recv_frame(ws: &mut Ws) -> Res<ProtocolMessage> {
307 loop {
308 match ws.next().await {
309 Some(Ok(Message::Binary(data))) => match protocol::decode(&data)? {
310 ProtocolMessage::ServerInfo { .. } | ProtocolMessage::Pong => {}
311 frame => return Ok(frame),
312 },
313 Some(Ok(Message::Close(_))) | None => anyhow::bail!("connection closed before a response arrived"),
314 Some(Ok(_)) => {}
315 Some(Err(err)) => anyhow::bail!("websocket error: {err}"),
316 }
317 }
318}
319
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use std::time::Duration;

    use tokio::net::TcpListener;

    use super::{connect_with_timeout, recv_with_timeout};

    /// Fails the test unless `err` reads as a timeout (case-insensitive).
    fn assert_timed_out(err: &impl std::fmt::Display) {
        let text = err.to_string().to_lowercase();
        assert!(text.contains("timed out"), "expected a timeout error, got: {err}");
    }

    /// A peer that accepts the TCP connection but never answers the WebSocket upgrade
    /// must trip the connect timeout.
    #[tokio::test]
    async fn control_timeout_connecting_to_a_silent_server() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        tokio::spawn(async move {
            // Keep the accepted socket alive so the client is left waiting rather than refused.
            let _accepted = listener.accept().await;
            std::future::pending::<()>().await;
        });

        let err = connect_with_timeout(&url, Duration::from_millis(150))
            .await
            .expect_err("a silent server must time out");
        assert_timed_out(&err);
    }

    /// A peer that completes the handshake but never sends a frame must trip the
    /// response timeout.
    #[tokio::test]
    async fn control_timeout_waiting_for_a_reply_from_a_silent_server() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        tokio::spawn(async move {
            let (stream, _peer) = listener.accept().await.unwrap();
            // Finish the upgrade, then hold the socket open without ever replying.
            let _ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            std::future::pending::<()>().await;
        });

        let mut ws = connect_with_timeout(&url, Duration::from_secs(5)).await.unwrap();
        let err = recv_with_timeout(&mut ws, Duration::from_millis(150))
            .await
            .expect_err("a silent reply must time out");
        assert_timed_out(&err);
    }
}