1use std::io;
2use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
3use std::sync::LazyLock;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
/// Fixed-size shared secret carried in `ClientHello::secret`.
pub type Secret = [u8; 32];

/// Protocol version of this implementation.
///
/// NOTE(review): presumably compared against `ClientHello::version` and
/// answered with `HandshakeError::UnsupportedVersion` on mismatch — confirm.
pub const PROTOCOLS_VERSION: u8 = 0x01;

/// Bincode configuration shared by `encode` and `decode`; both sides of the
/// connection must use the same configuration.
///
/// NOTE(review): `bincode::config::standard` appears to be a `const fn` in
/// bincode 2, so this could become a plain `const` (and drop the `*` derefs in
/// `encode`/`decode`) — verify against the bincode version in use.
static BINCODE_CONFIG: LazyLock<bincode::config::Configuration> =
    LazyLock::new(bincode::config::standard);
14
15pub fn encode<T: Serialize>(message: &T) -> io::Result<Bytes> {
16 bincode::serde::encode_to_vec(message, *BINCODE_CONFIG)
17 .map(Bytes::from)
18 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
19}
20
21pub fn decode<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> io::Result<T> {
22 bincode::serde::borrow_decode_from_slice(bytes, *BINCODE_CONFIG)
23 .map(|(msg, _)| msg)
24 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28pub struct ClientHello {
29 pub version: u8,
30 pub secret: Secret,
31 #[serde(with = "serde_bytes")]
32 pub options: Bytes,
33}
34
/// Client request naming a destination `address`.
///
/// NOTE(review): presumably sent after a successful handshake to ask the
/// server to open a connection to `address` — confirm against the caller.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientConnect {
    /// Destination the client wants to reach.
    pub address: Address,
}
39
/// Server reply to a handshake: success, or the reason it was rejected.
///
/// Variant order determines the encoded discriminant; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerHandshakeResponse {
    /// Handshake accepted.
    Ok,
    /// Handshake rejected for the given reason.
    Err(HandshakeError),
}
45
/// A UDP payload belonging to a session, sent either whole or as one
/// fragment of a larger payload.
///
/// Variant and field order determine the encoded layout (which
/// `UdpPacket::fragmented_overhead` budgets for); do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UdpPacket {
    /// A payload carried in a single packet.
    Unfragmented {
        /// Session this payload belongs to.
        session_id: u64,
        /// Address the payload is sent to or received from (direction is not
        /// visible in this module).
        address: Address,
        /// Raw payload, encoded as a length-prefixed byte string.
        #[serde(with = "serde_bytes")]
        data: Bytes,
    },
    /// One piece of a payload split by `UdpPacket::split_packet`.
    Fragmented {
        /// Session this payload belongs to.
        session_id: u64,
        /// Shared by every fragment produced from the same payload.
        fragment_id: u32,
        /// Zero-based position of this fragment within the payload.
        fragment_index: u16,
        /// Total number of fragments sharing this `fragment_id`.
        fragment_count: u16,
        /// `split_packet` sets this only on fragment 0; later fragments carry
        /// `None`.
        address: Option<Address>,
        /// This fragment's slice of the payload.
        #[serde(with = "serde_bytes")]
        data: Bytes,
    },
}
64
65impl UdpPacket {
66 pub fn fragmented_overhead() -> usize {
67 let fixed_overhead = 1 + 8 + 4 + 2 + 2;
69 const MAX_ADDRESS_OVERHEAD: usize = 260;
71 fixed_overhead + MAX_ADDRESS_OVERHEAD
72 }
73
74 pub fn split_packet(
75 session_id: u64,
76 address: Address,
77 data: Bytes,
78 max_payload_size: usize,
79 fragment_id: u32,
80 ) -> impl Iterator<Item = UdpPacket> {
81 let data_chunks: Vec<Bytes> = data
82 .chunks(max_payload_size)
83 .map(Bytes::copy_from_slice)
84 .collect();
85 let fragment_count = data_chunks.len() as u16;
86
87 data_chunks.into_iter().enumerate().map(move |(i, chunk)| {
88 let fragment_index = i as u16;
89 UdpPacket::Fragmented {
90 session_id,
91 fragment_id,
92 fragment_index,
93 fragment_count,
94 address: if fragment_index == 0 {
95 Some(address.clone())
96 } else {
97 None
98 },
99 data: chunk,
100 }
101 })
102 }
103}
104
/// Reasons a handshake can be rejected, carried in
/// `ServerHandshakeResponse::Err`.
///
/// Variant order determines the encoded discriminant; add new variants only
/// at the end.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
    /// Version not accepted (presumably `ClientHello::version` vs
    /// `PROTOCOLS_VERSION` — confirm on the server side).
    UnsupportedVersion,
    /// Secret not accepted (presumably `ClientHello::secret`).
    InvalidSecret,
    /// Server-side failure.
    InternalServerError,
}
111
/// Network address: a literal IPv4/IPv6 socket address, or a domain name and
/// port that `Address::to_socket_addr` resolves.
///
/// Variant order determines the encoded discriminant; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub enum Address {
    /// Literal IPv4 socket address.
    SocketV4(SocketAddrV4),
    /// Literal IPv6 socket address.
    ///
    /// NOTE(review): serde's non-human-readable `SocketAddrV6` encoding appears
    /// to carry only IP and port, so flowinfo/scope_id would not survive a
    /// round trip — verify if link-local peers matter.
    SocketV6(SocketAddrV6),
    /// Raw domain-name bytes and port. The bytes must be UTF-8 for
    /// `to_socket_addr` to succeed. `TryFrom<&str>` caps the name at 255 bytes;
    /// the `From<(String, u16)>` / `From<(&str, u16)>` impls do not.
    Domain(#[serde(with = "serde_bytes")] Bytes, u16),
}
118
119impl Address {
120 pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
121 match self {
122 Self::SocketV4(addr) => Ok((*addr).into()),
123 Self::SocketV6(addr) => Ok((*addr).into()),
124 Self::Domain(domain, port) => {
125 let domain_str = std::str::from_utf8(domain).map_err(|_| {
126 io::Error::new(
127 io::ErrorKind::InvalidInput,
128 "Domain name contains invalid UTF-8 characters",
129 )
130 })?;
131
132 match tokio::net::lookup_host((domain_str, *port)).await?.next() {
133 Some(addr) => Ok(addr),
134 None => Err(io::Error::new(
135 io::ErrorKind::NotFound,
136 format!("Domain name '{}' could not be resolved", domain_str),
137 )),
138 }
139 }
140 }
141 }
142}
143
144impl From<SocketAddr> for Address {
145 fn from(value: SocketAddr) -> Self {
146 match value {
147 SocketAddr::V4(addr) => Self::SocketV4(addr),
148 SocketAddr::V6(addr) => Self::SocketV6(addr),
149 }
150 }
151}
152
153impl TryFrom<&str> for Address {
154 type Error = io::Error;
155
156 fn try_from(value: &str) -> Result<Self, Self::Error> {
157 if let Ok(addr) = value.parse::<SocketAddr>() {
158 return Ok(Address::from(addr));
159 }
160
161 if let Some((domain, port_str)) = value.rsplit_once(':')
162 && let Ok(port) = port_str.parse::<u16>()
163 {
164 if domain.is_empty() {
165 return Err(io::Error::new(
166 io::ErrorKind::InvalidInput,
167 "Domain name cannot be empty",
168 ));
169 }
170
171 if domain.len() > 255 {
172 return Err(io::Error::new(
173 io::ErrorKind::InvalidInput,
174 format!("Domain name is too long: {} bytes (max 255)", domain.len()),
175 ));
176 }
177
178 return Ok(Address::Domain(
179 Bytes::copy_from_slice(domain.as_bytes()),
180 port,
181 ));
182 }
183
184 Err(io::Error::new(
185 io::ErrorKind::InvalidInput,
186 format!("Invalid address format: {}", value),
187 ))
188 }
189}
190
191impl TryFrom<String> for Address {
192 type Error = io::Error;
193
194 fn try_from(value: String) -> Result<Self, Self::Error> {
195 Address::try_from(value.as_str())
196 }
197}
198
199impl From<(String, u16)> for Address {
200 fn from(value: (String, u16)) -> Self {
201 Address::Domain(Bytes::from(value.0), value.1)
202 }
203}
204
205impl From<(&str, u16)> for Address {
206 fn from(value: (&str, u16)) -> Self {
207 Address::Domain(Bytes::copy_from_slice(value.0.as_bytes()), value.1)
208 }
209}
210
211impl std::fmt::Display for Address {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 match self {
214 Self::Domain(domain, port) => {
215 write!(f, "{}:{}", String::from_utf8_lossy(domain), port)
216 }
217 Self::SocketV4(addr) => write!(f, "{}", addr),
218 Self::SocketV6(addr) => write!(f, "{}", addr),
219 }
220 }
221}
222
223mod serde_bytes {
224 use bytes::Bytes;
225 use serde::{Deserialize, Deserializer, Serializer};
226
227 pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
228 where
229 S: Serializer,
230 {
231 serializer.serialize_bytes(bytes)
232 }
233
234 pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
235 where
236 D: Deserializer<'de>,
237 {
238 let vec: Vec<u8> = Vec::deserialize(deserializer)?;
239 Ok(Bytes::from(vec))
240 }
241}
242
/// Generates inherent `encode` / `decode` methods on a message type that
/// forward to the module-level `encode` and `decode` functions.
///
/// `std::io` and `bytes::Bytes` are referenced by absolute path, so callers do
/// not need them in scope. The expansion still calls `encode` and `decode` by
/// bare name, so those two functions must be in scope wherever the macro is
/// invoked.
#[macro_export]
macro_rules! impl_message_serde {
    ($struct_name:ident) => {
        impl $struct_name {
            /// Serializes `self` with the shared wire encoding.
            pub fn encode(&self) -> ::std::io::Result<::bytes::Bytes> {
                encode(self)
            }

            /// Deserializes a value of this type from the start of `bytes`.
            pub fn decode(bytes: &[u8]) -> ::std::io::Result<Self> {
                decode(bytes)
            }
        }
    };
}
257
258impl_message_serde!(ClientHello);
259impl_message_serde!(UdpPacket);
260impl_message_serde!(Address);