1use std::time::Duration;
9
10use anyhow::Context as _;
11use futures_util::{SinkExt as _, StreamExt as _};
12use tokio::net::TcpStream;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
14
15use crate::{
16 base::{Constant, Res, SessionPath},
17 identity::Identity,
18 protocol::{self, ProtocolMessage},
19};
20
/// The client-side WebSocket stream type used by every helper in this module:
/// a TCP stream that may or may not be wrapped in TLS.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Upper bound on how long `connect` waits for the WebSocket connection to be established.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Upper bound on how long `recv` waits for a single protocol frame from the server.
const RESPONSE_TIMEOUT: Duration = Duration::from_secs(15);
28
29pub async fn register(url: &str, identity: &Identity, username: &str, machine: &str, session: &str) -> Res<SessionPath> {
35 let mut ws = connect(url).await?;
36 let nonce = hello_challenge(&mut ws, session).await?;
37 let pubkey = identity.public_key().to_vec();
38 send(
39 &mut ws,
40 &ProtocolMessage::Register {
41 username: username.to_owned(),
42 machine: machine.to_owned(),
43 pubkey: pubkey.clone(),
44 },
45 )
46 .await?;
47 send(
48 &mut ws,
49 &ProtocolMessage::Auth {
50 pubkey,
51 signature: identity.sign(&nonce)?.to_vec(),
52 },
53 )
54 .await?;
55
56 match recv(&mut ws).await? {
57 ProtocolMessage::Established { path } => Ok(path),
58 ProtocolMessage::Error(err) => anyhow::bail!("registration rejected: {err}"),
59 other => anyhow::bail!("unexpected response to register: {other:?}"),
60 }
61}
62
63pub async fn one_shot(url: &str, identity: &Identity, session: &str, request: ProtocolMessage) -> Res<ProtocolMessage> {
70 let mut ws = connect(url).await?;
71 let nonce = hello_challenge(&mut ws, session).await?;
72 send(
73 &mut ws,
74 &ProtocolMessage::Auth {
75 pubkey: identity.public_key().to_vec(),
76 signature: identity.sign(&nonce)?.to_vec(),
77 },
78 )
79 .await?;
80
81 match recv(&mut ws).await? {
82 ProtocolMessage::Established { .. } => {}
83 ProtocolMessage::Error(err) => anyhow::bail!("authentication rejected: {err}"),
84 other => anyhow::bail!("unexpected response before request: {other:?}"),
85 }
86
87 send(&mut ws, &request).await?;
88 recv(&mut ws).await
89}
90
/// Connects to the WebSocket server at `url`, waiting at most `CONNECT_TIMEOUT`.
///
/// Thin wrapper around `connect_with_timeout`; see it for the error cases.
async fn connect(url: &str) -> Res<Ws> {
    connect_with_timeout(url, CONNECT_TIMEOUT).await
}
94
95async fn connect_with_timeout(url: &str, timeout: Duration) -> Res<Ws> {
96 crate::base::ensure_tls_provider();
97 match tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url)).await {
98 Ok(result) => {
99 let (ws, _response) = result.with_context(|| format!("failed to connect to `{url}`"))?;
100 Ok(ws)
101 }
102 Err(_) => anyhow::bail!("timed out after {}s connecting to `{url}`", timeout.as_secs()),
103 }
104}
105
106async fn hello_challenge(ws: &mut Ws, session: &str) -> Res<Vec<u8>> {
107 send(
108 ws,
109 &ProtocolMessage::Hello {
110 protocol_version: Constant::PROTOCOL_VERSION,
111 session: session.to_owned(),
112 },
113 )
114 .await?;
115 match recv(ws).await? {
116 ProtocolMessage::Challenge { nonce } => Ok(nonce),
117 other => anyhow::bail!("expected a challenge, got {other:?}"),
118 }
119}
120
121async fn send(ws: &mut Ws, frame: &ProtocolMessage) -> Res<()> {
122 ws.send(Message::Binary(protocol::encode(frame)?.into())).await.context("failed to send control frame")?;
123 Ok(())
124}
125
/// Waits for the next protocol frame on `ws`, for at most `RESPONSE_TIMEOUT`.
///
/// Thin wrapper around `recv_with_timeout`; see it for the error cases.
async fn recv(ws: &mut Ws) -> Res<ProtocolMessage> {
    recv_with_timeout(ws, RESPONSE_TIMEOUT).await
}
131
132async fn recv_with_timeout(ws: &mut Ws, timeout: Duration) -> Res<ProtocolMessage> {
133 match tokio::time::timeout(timeout, recv_frame(ws)).await {
134 Ok(result) => result,
135 Err(_) => anyhow::bail!("timed out after {}s waiting for a server response", timeout.as_secs()),
136 }
137}
138
139async fn recv_frame(ws: &mut Ws) -> Res<ProtocolMessage> {
140 loop {
141 match ws.next().await {
142 Some(Ok(Message::Binary(data))) => match protocol::decode(&data)? {
143 ProtocolMessage::ServerInfo { .. } | ProtocolMessage::Pong => {}
144 frame => return Ok(frame),
145 },
146 Some(Ok(Message::Close(_))) | None => anyhow::bail!("connection closed before a response arrived"),
147 Some(Ok(_)) => {}
148 Some(Err(err)) => anyhow::bail!("websocket error: {err}"),
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    // Tests may unwrap freely: a failed setup step should abort the test.
    #![allow(clippy::unwrap_used)]

    use std::time::Duration;

    use tokio::net::TcpListener;

    use super::{connect_with_timeout, recv_with_timeout};

    /// A server that accepts the TCP connection but never answers the WebSocket
    /// handshake must make `connect_with_timeout` fail with a timeout error.
    #[tokio::test]
    async fn control_timeout_connecting_to_a_silent_server() {
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = server.local_addr().unwrap();
        tokio::spawn(async move {
            // Keep the accepted socket alive, but never speak on it.
            let _held = server.accept().await;
            std::future::pending::<()>().await;
        });

        let url = format!("ws://{addr}");
        let outcome = connect_with_timeout(&url, Duration::from_millis(150)).await;
        let err = outcome.expect_err("a silent server must time out");
        let lowered = err.to_string().to_lowercase();
        assert!(lowered.contains("timed out"), "expected a timeout error, got: {err}");
    }

    /// A server that completes the WebSocket handshake but then stays silent
    /// must make `recv_with_timeout` fail with a timeout error.
    #[tokio::test]
    async fn control_timeout_waiting_for_a_reply_from_a_silent_server() {
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = server.local_addr().unwrap();
        tokio::spawn(async move {
            let (socket, _) = server.accept().await.unwrap();
            // Finish the handshake, then hold the stream open without sending anything.
            let _held = tokio_tungstenite::accept_async(socket).await.unwrap();
            std::future::pending::<()>().await;
        });

        let url = format!("ws://{addr}");
        let mut ws = connect_with_timeout(&url, Duration::from_secs(5)).await.unwrap();
        let outcome = recv_with_timeout(&mut ws, Duration::from_millis(150)).await;
        let err = outcome.expect_err("a silent reply must time out");
        let lowered = err.to_string().to_lowercase();
        assert!(lowered.contains("timed out"), "expected a timeout error, got: {err}");
    }
}