1use std::time::Duration;
4
5use anyhow::{Context, Result};
6use futures_util::{SinkExt, StreamExt};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use tokio::io::{self, AsyncRead, AsyncWrite};
9use tokio::time::timeout;
10use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
11use tracing::trace;
12use uuid::Uuid;
13
/// Port number used for the control connection.
pub const CONTROL_PORT: u16 = 7835;

/// Upper bound, in bytes, on a single frame accepted by the decoder of [`Delimited`].
pub const MAX_FRAME_LENGTH: usize = 256;

/// Network timeout; [`Delimited::recv_timeout`] uses it to bound the wait for a message.
pub const NETWORK_TIMEOUT: Duration = Duration::from_millis(3_000);
22
23#[derive(Debug, Serialize, Deserialize)]
25pub enum ClientMessage {
26 Authenticate(String),
28
29 Hello(u16),
31
32 Accept(Uuid),
34}
35
36#[derive(Debug, Serialize, Deserialize)]
38pub enum ServerMessage {
39 Challenge(Uuid),
41
42 Hello(u16),
44
45 Heartbeat,
47
48 Connection(Uuid),
50
51 Error(String),
53}
54
/// A stream wrapper that exchanges JSON messages as NUL-delimited frames.
///
/// The single field is a [`Framed`] stream driven by an [`AnyDelimiterCodec`];
/// the methods in the `impl` block access it as `self.0`.
pub struct Delimited<U>(Framed<U, AnyDelimiterCodec>);
57
58impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
59 pub fn new(stream: U) -> Self {
61 let codec = AnyDelimiterCodec::new_with_max_length(vec![0], vec![0], MAX_FRAME_LENGTH);
62 Self(Framed::new(stream, codec))
63 }
64
65 pub async fn recv<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
67 trace!("waiting to receive json message");
68 if let Some(next_message) = self.0.next().await {
69 let byte_message = next_message.context("frame error, invalid byte length")?;
70 let serialized_obj =
71 serde_json::from_slice(&byte_message).context("unable to parse message")?;
72 Ok(serialized_obj)
73 } else {
74 Ok(None)
75 }
76 }
77
78 pub async fn recv_timeout<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
83 timeout(NETWORK_TIMEOUT, self.recv())
84 .await
85 .context("timed out waiting for initial message")?
86 }
87
88 pub async fn send<T: Serialize>(&mut self, msg: T) -> Result<()> {
90 trace!("sending json message");
91 self.0.send(serde_json::to_string(&msg)?).await?;
92 Ok(())
93 }
94
95 pub fn into_parts(self) -> FramedParts<U, AnyDelimiterCodec> {
97 self.0.into_parts()
98 }
99}
100
101pub async fn proxy<S1, S2>(stream1: S1, stream2: S2) -> io::Result<()>
103where
104 S1: AsyncRead + AsyncWrite + Unpin,
105 S2: AsyncRead + AsyncWrite + Unpin,
106{
107 let (mut s1_read, mut s1_write) = io::split(stream1);
108 let (mut s2_read, mut s2_write) = io::split(stream2);
109 tokio::select! {
110 res = io::copy(&mut s1_read, &mut s2_write) => res,
111 res = io::copy(&mut s2_read, &mut s1_write) => res,
112 }?;
113 Ok(())
114}