1use std::time::Duration;
9
10use anyhow::Context as _;
11use futures_util::{SinkExt as _, StreamExt as _};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
14
15use crate::{
16 base::{Constant, Res, SessionPath},
17 identity::Identity,
18 protocol::{self, ProtocolMessage},
19};
20
/// Client-side WebSocket connection, carried over either plain TCP or TLS.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Upper bound on the whole `connect_async` call made by `connect`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// How long `recv` waits for the next protocol frame before failing.
const RESPONSE_TIMEOUT: Duration = Duration::from_secs(15);
/// Interval between `Ping` frames that `tail` sends while waiting for messages.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(20);
30
31pub async fn register(url: &str, identity: &Identity, username: &str, machine: &str, session: &str) -> Res<SessionPath> {
37 let mut ws = connect(url).await?;
38 let nonce = hello_challenge(&mut ws, session).await?;
39 let pubkey = identity.public_key().to_vec();
40 send(
41 &mut ws,
42 &ProtocolMessage::Register {
43 username: username.to_owned(),
44 machine: machine.to_owned(),
45 pubkey: pubkey.clone(),
46 },
47 )
48 .await?;
49 send(
50 &mut ws,
51 &ProtocolMessage::Auth {
52 pubkey,
53 signature: identity.sign(&nonce)?.to_vec(),
54 },
55 )
56 .await?;
57
58 match recv(&mut ws).await? {
59 ProtocolMessage::Established { path } => Ok(path),
60 ProtocolMessage::Error(err) => anyhow::bail!("registration rejected: {err}"),
61 other => anyhow::bail!("unexpected response to register: {other:?}"),
62 }
63}
64
65pub async fn one_shot(url: &str, identity: &Identity, session: &str, request: ProtocolMessage) -> Res<ProtocolMessage> {
72 let mut ws = connect(url).await?;
73 authenticate(&mut ws, identity, session).await?;
74 send(&mut ws, &request).await?;
75 recv(&mut ws).await
76}
77
78pub async fn send_message(url: &str, identity: &Identity, session: &str, channel: &str, text: &str) -> Res<()> {
85 let mut ws = connect(url).await?;
86 let from = authenticate(&mut ws, identity, session).await?;
87
88 send(&mut ws, &ProtocolMessage::Join { channel: channel.to_owned(), token: None }).await?;
89 match recv(&mut ws).await? {
90 ProtocolMessage::Joined { .. } => {}
91 ProtocolMessage::Error(err) => anyhow::bail!("join rejected: {err}"),
92 other => anyhow::bail!("unexpected response to join: {other:?}"),
93 }
94
95 send(
96 &mut ws,
97 &ProtocolMessage::ChannelMsg {
98 channel: channel.to_owned(),
99 from,
100 payload: protocol::Payload::Plain(text.to_owned()),
101 },
102 )
103 .await?;
104 match recv(&mut ws).await? {
105 ProtocolMessage::Ack { .. } => Ok(()),
106 ProtocolMessage::Error(err) => anyhow::bail!("send rejected: {err}"),
107 other => anyhow::bail!("unexpected response to send: {other:?}"),
108 }
109}
110
111pub async fn tail(url: &str, identity: &Identity, session: &str, channel: &str) -> Res<()> {
119 use std::io::Write as _;
120
121 let mut ws = connect(url).await?;
122 let path = authenticate(&mut ws, identity, session).await?;
123
124 send(&mut ws, &ProtocolMessage::Join { channel: channel.to_owned(), token: None }).await?;
125 match recv(&mut ws).await? {
126 ProtocolMessage::Joined { channel } => {
127 let mut out = std::io::stdout();
129 writeln!(out, "tailing #{channel} as {path} — Ctrl-C to stop")?;
130 out.flush()?;
131 }
132 ProtocolMessage::Error(err) => anyhow::bail!("join rejected: {err}"),
133 other => anyhow::bail!("unexpected response to join: {other:?}"),
134 }
135
136 let mut keepalive = tokio::time::interval(KEEPALIVE_INTERVAL);
137 keepalive.tick().await; loop {
139 tokio::select! {
140 _ = tokio::signal::ctrl_c() => return Ok(()),
141 _ = keepalive.tick() => send(&mut ws, &ProtocolMessage::Ping).await?,
142 frame = recv_frame(&mut ws) => {
143 let mut out = std::io::stdout();
144 match frame? {
145 ProtocolMessage::ChannelMsg { channel, from, payload } => writeln!(out, "[{channel}] {from}: {}", render_payload(&payload))?,
146 ProtocolMessage::Whisper { from, payload, .. } => writeln!(out, "[whisper] {from}: {}", render_payload(&payload))?,
147 _ => continue,
149 }
150 out.flush()?;
151 }
152 }
153 }
154}
155
156fn render_payload(payload: &protocol::Payload) -> &str {
158 match payload {
159 protocol::Payload::Plain(text) => text,
160 protocol::Payload::Encrypted(_) => "<end-to-end-encrypted payload>",
161 }
162}
163
164async fn authenticate(ws: &mut Ws, identity: &Identity, session: &str) -> Res<SessionPath> {
166 let nonce = hello_challenge(ws, session).await?;
167 send(
168 ws,
169 &ProtocolMessage::Auth {
170 pubkey: identity.public_key().to_vec(),
171 signature: identity.sign(&nonce)?.to_vec(),
172 },
173 )
174 .await?;
175
176 match recv(ws).await? {
177 ProtocolMessage::Established { path } => Ok(path),
178 ProtocolMessage::Error(err) => anyhow::bail!("authentication rejected: {err}"),
179 other => anyhow::bail!("unexpected response before request: {other:?}"),
180 }
181}
182
183async fn connect(url: &str) -> Res<Ws> {
184 connect_with_timeout(url, CONNECT_TIMEOUT).await
185}
186
187async fn connect_with_timeout(url: &str, timeout: Duration) -> Res<Ws> {
188 crate::base::ensure_tls_provider();
189 match tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url)).await {
190 Ok(result) => {
191 let (ws, _response) = result.with_context(|| format!("failed to connect to `{url}`"))?;
192 Ok(ws)
193 }
194 Err(_) => anyhow::bail!("timed out after {}s connecting to `{url}`", timeout.as_secs()),
195 }
196}
197
198async fn hello_challenge(ws: &mut Ws, session: &str) -> Res<Vec<u8>> {
199 send(
200 ws,
201 &ProtocolMessage::Hello {
202 protocol_version: Constant::PROTOCOL_VERSION,
203 session: session.to_owned(),
204 },
205 )
206 .await?;
207 match recv(ws).await? {
208 ProtocolMessage::Challenge { nonce } => Ok(nonce),
209 other => anyhow::bail!("expected a challenge, got {other:?}"),
210 }
211}
212
213async fn send(ws: &mut Ws, frame: &ProtocolMessage) -> Res<()> {
214 ws.send(Message::Binary(protocol::encode(frame)?.into())).await.context("failed to send control frame")?;
215 Ok(())
216}
217
218async fn recv(ws: &mut Ws) -> Res<ProtocolMessage> {
221 recv_with_timeout(ws, RESPONSE_TIMEOUT).await
222}
223
224async fn recv_with_timeout(ws: &mut Ws, timeout: Duration) -> Res<ProtocolMessage> {
225 match tokio::time::timeout(timeout, recv_frame(ws)).await {
226 Ok(result) => result,
227 Err(_) => anyhow::bail!("timed out after {}s waiting for a server response", timeout.as_secs()),
228 }
229}
230
231async fn recv_frame(ws: &mut Ws) -> Res<ProtocolMessage> {
232 loop {
233 match ws.next().await {
234 Some(Ok(Message::Binary(data))) => match protocol::decode(&data)? {
235 ProtocolMessage::ServerInfo { .. } | ProtocolMessage::Pong => {}
236 frame => return Ok(frame),
237 },
238 Some(Ok(Message::Close(_))) | None => anyhow::bail!("connection closed before a response arrived"),
239 Some(Ok(_)) => {}
240 Some(Err(err)) => anyhow::bail!("websocket error: {err}"),
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    // Test setup failures should simply panic the test.
    #![allow(clippy::unwrap_used)]

    use std::time::Duration;

    use tokio::net::TcpListener;

    use super::{connect_with_timeout, recv_with_timeout};

    /// Binds a listener on an ephemeral loopback port; returns it with its `ws://` URL.
    async fn local_listener() -> (TcpListener, String) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        (listener, url)
    }

    /// Asserts that `err`'s message reports a timeout (case-insensitive).
    fn assert_timed_out(err: &impl std::fmt::Display) {
        let message = err.to_string().to_lowercase();
        assert!(message.contains("timed out"), "expected a timeout error, got: {err}");
    }

    #[tokio::test]
    async fn control_timeout_connecting_to_a_silent_server() {
        let (listener, url) = local_listener().await;
        // Accept the TCP connection but never answer the WebSocket upgrade. The
        // accepted socket is kept alive so the client waits rather than erroring.
        tokio::spawn(async move {
            let _accepted = listener.accept().await;
            std::future::pending::<()>().await;
        });

        let err = connect_with_timeout(&url, Duration::from_millis(150))
            .await
            .expect_err("a silent server must time out");
        assert_timed_out(&err);
    }

    #[tokio::test]
    async fn control_timeout_waiting_for_a_reply_from_a_silent_server() {
        let (listener, url) = local_listener().await;
        // Complete the WebSocket handshake, then hold the connection open without
        // ever sending a frame.
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let _ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            std::future::pending::<()>().await;
        });

        let mut ws = connect_with_timeout(&url, Duration::from_secs(5)).await.unwrap();
        let err = recv_with_timeout(&mut ws, Duration::from_millis(150))
            .await
            .expect_err("a silent reply must time out");
        assert_timed_out(&err);
    }
}