1use anyhow::Context as _;
9use futures_util::{SinkExt as _, StreamExt as _};
10use tokio::net::TcpStream;
11use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
12
13use crate::{
14 base::{Constant, Res, SessionPath},
15 identity::Identity,
16 protocol::{self, ProtocolMessage},
17};
18
/// Client WebSocket connection used by the helpers in this file: a
/// `WebSocketStream` over a `MaybeTlsStream<TcpStream>`, as returned by
/// `tokio_tungstenite::connect_async`.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22pub async fn register(url: &str, identity: &Identity, username: &str, machine: &str, session: &str) -> Res<SessionPath> {
28 let mut ws = connect(url).await?;
29 let nonce = hello_challenge(&mut ws, session).await?;
30 let pubkey = identity.public_key().to_vec();
31 send(
32 &mut ws,
33 &ProtocolMessage::Register {
34 username: username.to_owned(),
35 machine: machine.to_owned(),
36 pubkey: pubkey.clone(),
37 },
38 )
39 .await?;
40 send(
41 &mut ws,
42 &ProtocolMessage::Auth {
43 pubkey,
44 signature: identity.sign(&nonce)?.to_vec(),
45 },
46 )
47 .await?;
48
49 match recv(&mut ws).await? {
50 ProtocolMessage::Established { path } => Ok(path),
51 ProtocolMessage::Error(err) => anyhow::bail!("registration rejected: {err}"),
52 other => anyhow::bail!("unexpected response to register: {other:?}"),
53 }
54}
55
56pub async fn one_shot(url: &str, identity: &Identity, session: &str, request: ProtocolMessage) -> Res<ProtocolMessage> {
63 let mut ws = connect(url).await?;
64 let nonce = hello_challenge(&mut ws, session).await?;
65 send(
66 &mut ws,
67 &ProtocolMessage::Auth {
68 pubkey: identity.public_key().to_vec(),
69 signature: identity.sign(&nonce)?.to_vec(),
70 },
71 )
72 .await?;
73
74 match recv(&mut ws).await? {
75 ProtocolMessage::Established { .. } => {}
76 ProtocolMessage::Error(err) => anyhow::bail!("authentication rejected: {err}"),
77 other => anyhow::bail!("unexpected response before request: {other:?}"),
78 }
79
80 send(&mut ws, &request).await?;
81 recv(&mut ws).await
82}
83
84async fn connect(url: &str) -> Res<Ws> {
85 let (ws, _response) = tokio_tungstenite::connect_async(url).await.with_context(|| format!("failed to connect to `{url}`"))?;
86 Ok(ws)
87}
88
89async fn hello_challenge(ws: &mut Ws, session: &str) -> Res<Vec<u8>> {
90 send(
91 ws,
92 &ProtocolMessage::Hello {
93 protocol_version: Constant::PROTOCOL_VERSION,
94 session: session.to_owned(),
95 },
96 )
97 .await?;
98 match recv(ws).await? {
99 ProtocolMessage::Challenge { nonce } => Ok(nonce),
100 other => anyhow::bail!("expected a challenge, got {other:?}"),
101 }
102}
103
104async fn send(ws: &mut Ws, frame: &ProtocolMessage) -> Res<()> {
105 ws.send(Message::Binary(protocol::encode(frame)?.into())).await.context("failed to send control frame")?;
106 Ok(())
107}
108
109async fn recv(ws: &mut Ws) -> Res<ProtocolMessage> {
111 loop {
112 match ws.next().await {
113 Some(Ok(Message::Binary(data))) => match protocol::decode(&data)? {
114 ProtocolMessage::ServerInfo { .. } | ProtocolMessage::Pong => {}
115 frame => return Ok(frame),
116 },
117 Some(Ok(Message::Close(_))) | None => anyhow::bail!("connection closed before a response arrived"),
118 Some(Ok(_)) => {}
119 Some(Err(err)) => anyhow::bail!("websocket error: {err}"),
120 }
121 }
122}