atm0s_reverse_proxy_protocol/
cluster.rs

1use anyhow::anyhow;
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5#[derive(Debug, Serialize, Deserialize)]
6pub struct ClusterTunnelRequest {
7    pub domain: String,
8    pub handshake: Vec<u8>,
9}
10
11impl From<&ClusterTunnelRequest> for Vec<u8> {
12    fn from(resp: &ClusterTunnelRequest) -> Self {
13        bincode::serialize(resp).expect("Should serialize cluster tunnel request")
14    }
15}
16
17impl TryFrom<&[u8]> for ClusterTunnelRequest {
18    type Error = bincode::Error;
19
20    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
21        bincode::deserialize(buf)
22    }
23}
24
25#[derive(Debug, Serialize, Deserialize)]
26pub struct ClusterTunnelResponse {
27    pub success: bool,
28}
29
30impl From<&ClusterTunnelResponse> for Vec<u8> {
31    fn from(resp: &ClusterTunnelResponse) -> Self {
32        bincode::serialize(resp).expect("Should serialize cluster tunnel response")
33    }
34}
35
36impl TryFrom<&[u8]> for ClusterTunnelResponse {
37    type Error = bincode::Error;
38
39    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
40        bincode::deserialize(buf)
41    }
42}
43
44#[derive(Debug, Serialize, Deserialize)]
45pub struct AgentTunnelRequest {
46    pub service: Option<u16>,
47    pub tls: bool,
48    pub domain: String,
49}
50
51impl From<&AgentTunnelRequest> for Vec<u8> {
52    fn from(resp: &AgentTunnelRequest) -> Self {
53        bincode::serialize(resp).expect("Should serialize agent tunnel request")
54    }
55}
56
57impl TryFrom<&[u8]> for AgentTunnelRequest {
58    type Error = bincode::Error;
59
60    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
61        bincode::deserialize(buf)
62    }
63}
64
65pub async fn wait_object<R: AsyncRead + Unpin, O: DeserializeOwned, const MAX_SIZE: usize>(reader: &mut R) -> anyhow::Result<O> {
66    let mut len_buf = [0; 2];
67    let mut data_buf = [0; MAX_SIZE];
68    reader.read_exact(&mut len_buf).await?;
69    let handshake_len = u16::from_be_bytes([len_buf[0], len_buf[1]]) as usize;
70    if handshake_len > data_buf.len() {
71        return Err(anyhow!("packet to big {} vs {MAX_SIZE}", data_buf.len()));
72    }
73
74    reader.read_exact(&mut data_buf[0..handshake_len]).await?;
75
76    Ok(bincode::deserialize(&data_buf[0..handshake_len])?)
77}
78
79pub async fn write_object<W: AsyncWrite + Send + Unpin, O: Serialize, const MAX_SIZE: usize>(writer: &mut W, object: &O) -> anyhow::Result<()> {
80    let data_buf: Vec<u8> = bincode::serialize(object).expect("Should convert to binary");
81    if data_buf.len() > MAX_SIZE {
82        return Err(anyhow!("buffer to big {} vs {MAX_SIZE}", data_buf.len()));
83    }
84    let len_buf = (data_buf.len() as u16).to_be_bytes();
85
86    writer.write_all(&len_buf).await?;
87    writer.write_all(&data_buf).await?;
88    Ok(())
89}
90
91pub async fn wait_buf<R: AsyncRead + Unpin, const MAX_SIZE: usize>(reader: &mut R) -> anyhow::Result<Vec<u8>> {
92    let mut len_buf = [0; 2];
93    let mut data_buf = [0; MAX_SIZE];
94    reader.read_exact(&mut len_buf).await?;
95    let handshake_len = u16::from_be_bytes([len_buf[0], len_buf[1]]) as usize;
96    if handshake_len > data_buf.len() {
97        return Err(anyhow!("packet to big {} vs {MAX_SIZE}", data_buf.len()));
98    }
99
100    reader.read_exact(&mut data_buf[0..handshake_len]).await?;
101
102    Ok(data_buf[0..handshake_len].to_vec())
103}
104
105pub async fn write_buf<W: AsyncWrite + Send + Unpin, const MAX_SIZE: usize>(writer: &mut W, data_buf: &[u8]) -> anyhow::Result<()> {
106    if data_buf.len() > MAX_SIZE {
107        return Err(anyhow!("buffer to big {} vs {MAX_SIZE}", data_buf.len()));
108    }
109    let len_buf = (data_buf.len() as u16).to_be_bytes();
110
111    writer.write_all(&len_buf).await?;
112    writer.write_all(data_buf).await?;
113    Ok(())
114}