1pub mod reconnect;
2
3use std::{io, net::SocketAddr};
4
5use bytes::BytesMut;
6use chrono::DateTime;
7use serde::{Deserialize, Serialize};
8use tokio_util::codec::{Encoder, Decoder};
9use uuid::Uuid;
10
/// Handshake protocol version of this build.
///
/// NOTE(review): presumably the value clients put in `ClientHelloV2::version`,
/// with a mismatch answered by `ServerHelloV2::VersionMismatch` — confirm at the call sites.
pub const CLIENT_HELLO_VERSION: u16 = 2;
12
13#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
14#[serde(transparent)]
15pub struct StreamId(Uuid);
16
17impl std::fmt::Display for StreamId {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 write!(f, "stream_{}", self.0)
20 }
21}
22
23impl StreamId {
24 pub fn new() -> Self {
25 Self(Uuid::new_v4())
26 }
27}
28
29#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
30#[serde(transparent)]
31pub struct ClientId(Uuid);
32
33impl std::fmt::Display for ClientId {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "client_{}", self.0)
36 }
37}
38
39impl ClientId {
40 pub fn new() -> Self {
41 Self(Uuid::new_v4())
42 }
43}
44
45#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
46#[serde(transparent)]
47pub struct EndpointId(Uuid);
48impl std::fmt::Display for EndpointId {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "endpoint_{}", self.0)
51 }
52}
53impl EndpointId {
54 pub fn new() -> Self {
55 Self(Uuid::new_v4())
56 }
57}
58
/// Transport protocol of an endpoint.
///
/// The explicit discriminants equal the IANA IP protocol numbers for TCP (6) and UDP (17).
/// NOTE(review): serde's derive encodes unit variants by name/index, not by these
/// discriminant values, so they presumably only matter for `as` casts elsewhere — confirm.
/// Do not rename or reorder variants without checking wire compatibility.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    TCP = 6,
    UDP = 17,
}
64
65impl std::fmt::Display for Protocol {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 Protocol::TCP => write!(f, "tcp"),
69 Protocol::UDP => write!(f, "udp"),
70 }
71 }
72}
73
74
/// A port mapping listed in `ClientHelloV2::endpoint_claims`.
///
/// NOTE(review): presumably a mapping the client asks the server to expose; the exact
/// meaning of `local_port` vs `remote_port` is not visible here — confirm.
/// Field order may be positional in the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub struct EndpointClaim {
    pub protocol: Protocol,
    pub local_port: u16,
    pub remote_port: u16,
}

/// List of endpoint claims carried in a client hello.
pub type EndpointClaims = Vec<EndpointClaim>;
83
/// Client side of the V2 handshake.
///
/// NOTE(review): field order may be positional in the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClientHelloV2 {
    /// Handshake version; presumably expected to match `CLIENT_HELLO_VERSION` — confirm.
    pub version: u16,
    /// NOTE(review): presumably an authentication/session token — confirm with the server code.
    pub token: String,
    /// Endpoints the client claims.
    pub endpoint_claims: EndpointClaims,
    /// Whether this hello authenticates a new session or reconnects (see `ClientType`).
    pub client_type: ClientType,
}
91
/// Kind of connection announced in `ClientHelloV2::client_type`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ClientType {
    /// NOTE(review): presumably a fresh, token-authenticated session — confirm.
    Auth,
    /// NOTE(review): presumably re-attaching to an existing session (cf. the `reconnect` module) — confirm.
    Reconnect,
}
97
/// Messages exchanged over the V2 control channel.
///
/// Encoded as MessagePack by `ControlPacketV2Codec` via `rmp_serde::to_vec`. The fields
/// of each variant are positional in that encoding, so do not reorder them. Reordering or
/// renaming variants may also break compatibility, depending on how `rmp_serde` encodes
/// variant tags — confirm before changing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ControlPacketV2 {
    /// Stream id, endpoint id and remote-peer info; presumably announces a new stream.
    Init(StreamId, EndpointId, RemoteInfo),
    /// Stream id and a payload of raw bytes.
    Data(StreamId, Vec<u8>),
    /// Stream id; presumably signals that the stream was refused.
    Refused(StreamId),
    /// Stream id; presumably signals that the stream has ended.
    End(StreamId),
    /// Sequence number, UTC timestamp and an optional token; presumably a liveness probe.
    Ping(u32,DateTime<chrono::Utc>,Option<String>),
    /// Sequence number and UTC timestamp; presumably the reply to a `Ping` — confirm.
    Pong(u32,DateTime<chrono::Utc>),
}
107
108impl std::fmt::Display for ControlPacketV2 {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 match self {
111 ControlPacketV2::Init(sid, eid, remote_info) => write!(f, "ControlPacket::Init(sid={}, eid={}, remote_info={})", sid, eid, remote_info),
112 ControlPacketV2::Data(sid, data) => write!(f, "ControlPacket::Data(sid={}, data_len={})", sid, data.len()),
113 ControlPacketV2::Refused(sid) => write!(f, "ControlPacket::Refused(sid={})", sid),
114 ControlPacketV2::End(sid) => write!(f, "ControlPacket::End(sid={})", sid),
115 ControlPacketV2::Ping(seq, datetime, Some(token)) => write!(f, "ControlPacket::Ping(seq={}, datetime={}, token={})", seq, datetime, token),
116 ControlPacketV2::Ping(seq, datetime, None) => write!(f, "ControlPacket::Ping(seq={}, datetime={})", seq, datetime),
117 ControlPacketV2::Pong(seq, datetime) => write!(f, "ControlPacket::Pong(seq={}, datetime={})", seq, datetime),
118 }
119 }
120}
121
/// An endpoint identified by an `EndpointId`, with its protocol and ports.
///
/// NOTE(review): presumably the server-confirmed counterpart of an `EndpointClaim`
/// (returned in `ServerHelloV2::Success`) — confirm. Field order may be positional in the
/// serialized form; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub struct Endpoint {
    pub id: EndpointId,
    pub protocol: Protocol,
    pub local_port: u16,
    pub remote_port: u16
}

/// List of endpoints.
pub type Endpoints = Vec<Endpoint>;
131
/// Information about the remote peer of a stream, carried in `ControlPacketV2::Init`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub struct RemoteInfo {
    /// Socket address of the remote peer.
    pub remote_peer_addr: SocketAddr,
}
136
137impl RemoteInfo {
138 pub fn new(remote_peer_addr: SocketAddr) -> Self {
139 Self { remote_peer_addr }
140 }
141}
142
143impl std::fmt::Display for RemoteInfo {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(f, "RemoteInfo(remote_peer_addr={}", self.remote_peer_addr)
146 }
147}
148
/// Server side of the V2 handshake; presumably the reply to a `ClientHelloV2`.
///
/// `rename_all = "snake_case"` makes serde use snake_case variant names
/// (`success`, `bad_request`, `service_temporary_unavailable`, `illegal_host`,
/// `internal_server_error`, `version_mismatch`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ServerHelloV2 {
    /// Handshake accepted.
    Success {
        /// Identifier assigned to this client.
        client_id: ClientId,
        /// NOTE(review): presumably the public host name assigned to the client — confirm.
        host: String,
        /// Endpoints established for the client.
        endpoints: Endpoints,
    },
    BadRequest,
    ServiceTemporaryUnavailable,
    IllegalHost,
    InternalServerError,
    /// NOTE(review): presumably sent when `ClientHelloV2::version` is not supported — confirm.
    VersionMismatch,
}
163
164
/// Stateless MessagePack [`Encoder`]/[`Decoder`] for [`ControlPacketV2`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
pub struct ControlPacketV2Codec {}

impl ControlPacketV2Codec {
    /// Creates a codec; equivalent to `ControlPacketV2Codec::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
173impl Encoder<ControlPacketV2> for ControlPacketV2Codec {
174 type Error = io::Error;
175
176 fn encode(&mut self, item: ControlPacketV2, dst: &mut BytesMut) -> Result<(), Self::Error> {
177 let encoded = rmp_serde::to_vec(&item).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
178 dst.extend_from_slice(&encoded);
179 Ok(())
180 }
181}
182
183impl Decoder for ControlPacketV2Codec {
184 type Item = ControlPacketV2;
185 type Error = io::Error;
186
187 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
188 if !src.is_empty() {
189 let decoded = rmp_serde::from_slice(src).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
190 Ok(Some(decoded))
191 } else {
192 Ok(None)
193 }
194 }
195}
196
197
198
#[cfg(test)]
mod control_packet_test {
    use bytes::BytesMut;

    use super::*;

    /// Round-trips an `Init` packet through the codec and checks it decodes unchanged.
    #[test]
    fn test_control_packet_init() -> Result<(), Box<dyn std::error::Error>> {
        let original = ControlPacketV2::Init(
            StreamId::default(),
            EndpointId::default(),
            RemoteInfo::new("127.0.0.1:8080".parse()?),
        );

        let mut codec = ControlPacketV2Codec::new();
        let mut buf = BytesMut::new();
        codec.encode(original.clone(), &mut buf)?;

        let decoded = codec.decode(&mut buf)?.unwrap();
        assert_eq!(original, decoded);
        Ok(())
    }
}