1use std::io;
2use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
3use std::sync::LazyLock;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
/// Fixed-size, 32-byte secret carried in [`ClientHello::secret`].
pub type Secret = [u8; 32];

/// Protocol version number of this wire format (currently `1`).
pub const PROTOCOL_VERSION: u8 = 1;

/// Maximum length, in bytes, accepted for the domain part when parsing an
/// [`Address`] from text (`u8::MAX`, i.e. 255).
pub const MAX_DOMAIN_LENGTH: usize = u8::MAX as usize;
16
17static BINCODE_CONFIG: LazyLock<bincode::config::Configuration> =
19 LazyLock::new(bincode::config::standard);
20
21pub fn encode<T: Serialize>(message: &T) -> io::Result<Bytes> {
27 bincode::serde::encode_to_vec(message, *BINCODE_CONFIG)
28 .map(Bytes::from)
29 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("encode error: {e}")))
30}
31
32pub fn decode<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> io::Result<T> {
38 bincode::serde::borrow_decode_from_slice(bytes, *BINCODE_CONFIG)
39 .map(|(msg, _)| msg)
40 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("decode error: {e}")))
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub struct ClientHello {
46 pub version: u8,
48 pub secret: Secret,
50 #[serde(with = "serde_bytes")]
52 pub options: Bytes,
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct ClientConnect {
58 pub address: Address,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
63pub enum ServerAuthResponse {
64 Ok,
65 Err,
66}
67
68#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
74pub enum UdpPacket {
75 Unfragmented {
77 session_id: u64,
79 address: Address,
81 #[serde(with = "serde_bytes")]
83 data: Bytes,
84 },
85 Fragmented {
87 session_id: u64,
89 fragment_id: u32,
91 fragment_index: u16,
93 fragment_count: u16,
95 address: Option<Address>,
97 #[serde(with = "serde_bytes")]
99 data: Bytes,
100 },
101}
102
103impl UdpPacket {
104 pub fn fragmented_overhead() -> usize {
111 const FIXED_OVERHEAD: usize = 1 + 8 + 4 + 2 + 2;
114 const MAX_ADDRESS_OVERHEAD: usize = 1 + 2 + MAX_DOMAIN_LENGTH + 2;
117 FIXED_OVERHEAD + MAX_ADDRESS_OVERHEAD
118 }
119
120 pub fn split_packet(
134 session_id: u64,
135 address: Address,
136 data: Bytes,
137 max_payload_size: usize,
138 fragment_id: u32,
139 ) -> impl Iterator<Item = UdpPacket> {
140 let data_chunks: Vec<Bytes> = data
142 .chunks(max_payload_size)
143 .map(Bytes::copy_from_slice)
144 .collect();
145 let fragment_count = data_chunks.len() as u16;
146
147 assert!(fragment_count > 0, "fragment_count must be greater than 0");
149
150 data_chunks.into_iter().enumerate().map(move |(i, chunk)| {
151 let fragment_index = i as u16;
152 UdpPacket::Fragmented {
153 session_id,
154 fragment_id,
155 fragment_index,
156 fragment_count,
157 address: if fragment_index == 0 {
159 Some(address.clone())
160 } else {
161 None
162 },
163 data: chunk,
164 }
165 })
166 }
167}
168
169#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
175pub enum ServerConnectResponse {
176 Ok,
178 Err {
184 kind: ConnectErrorKind,
186 message: String,
188 },
189}
190
191#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
193pub enum ConnectErrorKind {
194 ConnectionRefused,
196 NetworkUnreachable,
198 HostUnreachable,
200 TimedOut,
202 #[serde(other)]
203 Other,
204}
205
206impl ConnectErrorKind {
207 pub fn from_io_error(error: &io::Error) -> Self {
214 match error.kind() {
215 io::ErrorKind::ConnectionRefused => ConnectErrorKind::ConnectionRefused,
216 io::ErrorKind::NetworkUnreachable => ConnectErrorKind::NetworkUnreachable,
217 io::ErrorKind::HostUnreachable => ConnectErrorKind::HostUnreachable,
218 io::ErrorKind::TimedOut => ConnectErrorKind::TimedOut,
219 _ => ConnectErrorKind::Other,
222 }
223 }
224}
225
226#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
231pub enum Address {
232 SocketV4(SocketAddrV4),
234 SocketV6(SocketAddrV6),
236 Domain(#[serde(with = "serde_bytes")] Bytes, u16),
238}
239
240impl Address {
241 pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
253 match self {
254 Self::SocketV4(addr) => Ok((*addr).into()),
255 Self::SocketV6(addr) => Ok((*addr).into()),
256 Self::Domain(domain, port) => {
257 let domain_str = std::str::from_utf8(domain).map_err(|_| {
258 io::Error::new(
259 io::ErrorKind::InvalidInput,
260 "domain name contains invalid utf-8 characters",
261 )
262 })?;
263
264 tokio::net::lookup_host((domain_str, *port))
265 .await?
266 .next()
267 .ok_or_else(|| {
268 io::Error::new(
269 io::ErrorKind::NotFound,
270 format!("domain name '{}' could not be resolved", domain_str),
271 )
272 })
273 }
274 }
275 }
276}
277
278impl From<SocketAddr> for Address {
279 fn from(value: SocketAddr) -> Self {
280 match value {
281 SocketAddr::V4(addr) => Self::SocketV4(addr),
282 SocketAddr::V6(addr) => Self::SocketV6(addr),
283 }
284 }
285}
286
287impl TryFrom<&str> for Address {
288 type Error = io::Error;
289
290 fn try_from(value: &str) -> Result<Self, Self::Error> {
291 if let Ok(addr) = value.parse::<SocketAddr>() {
292 return Ok(Address::from(addr));
293 }
294
295 if let Some((domain, port_str)) = value.rsplit_once(':')
296 && let Ok(port) = port_str.parse::<u16>()
297 {
298 if domain.is_empty() {
299 return Err(io::Error::new(
300 io::ErrorKind::InvalidInput,
301 "domain name cannot be empty",
302 ));
303 }
304
305 if domain.len() > MAX_DOMAIN_LENGTH {
306 return Err(io::Error::new(
307 io::ErrorKind::InvalidInput,
308 format!(
309 "domain name is too long: {} bytes (max {})",
310 domain.len(),
311 MAX_DOMAIN_LENGTH
312 ),
313 ));
314 }
315
316 return Ok(Address::Domain(
317 Bytes::copy_from_slice(domain.as_bytes()),
318 port,
319 ));
320 }
321
322 Err(io::Error::new(
323 io::ErrorKind::InvalidInput,
324 format!("invalid address format: {}", value),
325 ))
326 }
327}
328
329impl TryFrom<String> for Address {
330 type Error = io::Error;
331
332 fn try_from(value: String) -> Result<Self, Self::Error> {
333 Address::try_from(value.as_str())
334 }
335}
336
337impl From<(String, u16)> for Address {
338 fn from(value: (String, u16)) -> Self {
339 Address::Domain(Bytes::from(value.0), value.1)
340 }
341}
342
343impl From<(&str, u16)> for Address {
344 fn from(value: (&str, u16)) -> Self {
345 Address::Domain(Bytes::copy_from_slice(value.0.as_bytes()), value.1)
346 }
347}
348
349impl std::fmt::Display for Address {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 match self {
352 Self::Domain(domain, port) => {
353 write!(f, "{}:{}", String::from_utf8_lossy(domain), port)
354 }
355 Self::SocketV4(addr) => write!(f, "{}", addr),
356 Self::SocketV6(addr) => write!(f, "{}", addr),
357 }
358 }
359}
360
361mod serde_bytes {
362 use bytes::Bytes;
363 use serde::{Deserialize, Deserializer, Serializer};
364
365 pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
366 where
367 S: Serializer,
368 {
369 serializer.serialize_bytes(bytes)
370 }
371
372 pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
373 where
374 D: Deserializer<'de>,
375 {
376 let vec: Vec<u8> = Vec::deserialize(deserializer)?;
377 Ok(Bytes::from(vec))
378 }
379}
380
/// Adds inherent `encode` / `decode` methods to a message type, forwarding
/// to the free `encode` / `decode` functions.
///
/// The expansion names `io`, `Bytes`, `encode` and `decode` unqualified, so
/// they must be in scope at the invocation site.
#[macro_export]
macro_rules! impl_message_serde {
    ($message:ident) => {
        impl $message {
            /// Decodes a value of this type from `bytes`.
            pub fn decode(bytes: &[u8]) -> io::Result<Self> {
                decode(bytes)
            }

            /// Encodes this value into a byte buffer.
            pub fn encode(&self) -> io::Result<Bytes> {
                encode(self)
            }
        }
    };
}
395
396impl_message_serde!(ClientHello);
397impl_message_serde!(UdpPacket);
398impl_message_serde!(Address);