ombrac/
protocol.rs

1use std::io;
2use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
3use std::sync::LazyLock;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
/// Fixed-size secret carried by [`ClientHello`]: exactly 32 raw bytes.
pub type Secret = [u8; 32];

/// Protocol version constant for this implementation (currently 1).
///
/// NOTE(review): presumably the value clients place in `ClientHello::version` and
/// servers check against — confirm at the call sites.
pub const PROTOCOLS_VERSION: u8 = 1;
11
12static BINCODE_CONFIG: LazyLock<bincode::config::Configuration> =
13    LazyLock::new(bincode::config::standard);
14
15pub fn encode<T: Serialize>(message: &T) -> io::Result<Bytes> {
16    bincode::serde::encode_to_vec(message, *BINCODE_CONFIG)
17        .map(Bytes::from)
18        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
19}
20
21pub fn decode<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> io::Result<T> {
22    bincode::serde::borrow_decode_from_slice(bytes, *BINCODE_CONFIG)
23        .map(|(msg, _)| msg)
24        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28pub struct ClientHello {
29    pub version: u8,
30    pub secret: Secret,
31    #[serde(with = "serde_bytes")]
32    pub options: Bytes,
33}
34
/// Client message naming a destination [`Address`].
///
/// NOTE(review): presumably sent after a successful handshake to open a connection to
/// `address` — confirm against the client/server code. Field order is part of the
/// bincode wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientConnect {
    /// Destination, either an IP socket address or a domain name plus port.
    pub address: Address,
}
39
/// Server's handshake verdict.
///
/// NOTE(review): presumably the reply to a [`ClientHello`] — confirm. bincode encodes the
/// variant index, so variant order is part of the wire format; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerHandshakeResponse {
    /// Success; carries no data.
    Ok,
    /// Failure, carrying the reason as a [`HandshakeError`].
    Err(HandshakeError),
}
45
/// A UDP datagram framed for transport. It is sent either whole (`Unfragmented`) or as
/// one piece of a datagram split by [`UdpPacket::split_packet`] (`Fragmented`).
///
/// Encoded with bincode (see [`encode`]), which is positional. Variant and field order
/// define the wire format, so do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UdpPacket {
    /// An entire datagram in a single packet.
    Unfragmented {
        session_id: u64,
        /// Remote endpoint associated with `data`.
        /// NOTE(review): whether this is source or destination depends on the sender — confirm.
        address: Address,
        /// Datagram payload, serialized as a raw byte string.
        #[serde(with = "serde_bytes")]
        data: Bytes,
    },
    /// One fragment of a larger datagram.
    Fragmented {
        session_id: u64,
        /// Identical for every fragment produced by one `split_packet` call.
        fragment_id: u32,
        /// Zero-based position of this fragment.
        fragment_index: u16,
        /// Total number of fragments in the datagram.
        fragment_count: u16,
        /// `Some` only on fragment 0 when produced by `split_packet`; `None` otherwise.
        address: Option<Address>,
        /// This fragment's slice of the payload, serialized as a raw byte string.
        #[serde(with = "serde_bytes")]
        data: Bytes,
    },
}
64
65impl UdpPacket {
66    pub fn fragmented_overhead() -> usize {
67        // Type + u64 + u32 + u16 + u16
68        let fixed_overhead = 1 + 8 + 4 + 2 + 2;
69        // 1 byte tag + 2 bytes len + 255 bytes domain + 2 bytes port
70        const MAX_ADDRESS_OVERHEAD: usize = 260;
71        fixed_overhead + MAX_ADDRESS_OVERHEAD
72    }
73
74    pub fn split_packet(
75        session_id: u64,
76        address: Address,
77        data: Bytes,
78        max_payload_size: usize,
79        fragment_id: u32,
80    ) -> impl Iterator<Item = UdpPacket> {
81        let data_chunks: Vec<Bytes> = data
82            .chunks(max_payload_size)
83            .map(Bytes::copy_from_slice)
84            .collect();
85        let fragment_count = data_chunks.len() as u16;
86
87        data_chunks.into_iter().enumerate().map(move |(i, chunk)| {
88            let fragment_index = i as u16;
89            UdpPacket::Fragmented {
90                session_id,
91                fragment_id,
92                fragment_index,
93                fragment_count,
94                address: if fragment_index == 0 {
95                    Some(address.clone())
96                } else {
97                    None
98                },
99                data: chunk,
100            }
101        })
102    }
103}
104
/// Reason code carried by [`ServerHandshakeResponse::Err`].
///
/// bincode encodes the variant index, so variant order is part of the wire format.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
    /// The client's protocol version was not accepted.
    /// NOTE(review): per the name; presumably a mismatch with [`PROTOCOLS_VERSION`] — confirm.
    UnsupportedVersion,
    /// The client's secret was not accepted (per the name — confirm server logic).
    InvalidSecret,
    /// A server-side failure (per the name — confirm where this is produced).
    InternalServerError,
}
111
/// A destination: an IPv4/IPv6 socket address, or a domain name plus port.
///
/// `Domain` stores raw bytes. UTF-8 is only checked when resolving (`to_socket_addr`),
/// and `Display` renders the name lossily. `TryFrom<&str>` caps names at 255 bytes,
/// but the `From<(String, u16)>` / `From<(&str, u16)>` impls store the name unchecked.
/// bincode encodes the variant index, so variant order is part of the wire format.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub enum Address {
    SocketV4(SocketAddrV4),
    SocketV6(SocketAddrV6),
    /// Domain-name bytes (serialized as a raw byte string) and port.
    Domain(#[serde(with = "serde_bytes")] Bytes, u16),
}
118
119impl Address {
120    pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
121        match self {
122            Self::SocketV4(addr) => Ok((*addr).into()),
123            Self::SocketV6(addr) => Ok((*addr).into()),
124            Self::Domain(domain, port) => {
125                let domain_str = std::str::from_utf8(domain).map_err(|_| {
126                    io::Error::new(
127                        io::ErrorKind::InvalidInput,
128                        "Domain name contains invalid UTF-8 characters",
129                    )
130                })?;
131
132                match tokio::net::lookup_host((domain_str, *port)).await?.next() {
133                    Some(addr) => Ok(addr),
134                    None => Err(io::Error::new(
135                        io::ErrorKind::NotFound,
136                        format!("Domain name '{}' could not be resolved", domain_str),
137                    )),
138                }
139            }
140        }
141    }
142}
143
144impl From<SocketAddr> for Address {
145    fn from(value: SocketAddr) -> Self {
146        match value {
147            SocketAddr::V4(addr) => Self::SocketV4(addr),
148            SocketAddr::V6(addr) => Self::SocketV6(addr),
149        }
150    }
151}
152
153impl TryFrom<&str> for Address {
154    type Error = io::Error;
155
156    fn try_from(value: &str) -> Result<Self, Self::Error> {
157        if let Ok(addr) = value.parse::<SocketAddr>() {
158            return Ok(Address::from(addr));
159        }
160
161        if let Some((domain, port_str)) = value.rsplit_once(':')
162            && let Ok(port) = port_str.parse::<u16>()
163        {
164            if domain.is_empty() {
165                return Err(io::Error::new(
166                    io::ErrorKind::InvalidInput,
167                    "Domain name cannot be empty",
168                ));
169            }
170
171            if domain.len() > 255 {
172                return Err(io::Error::new(
173                    io::ErrorKind::InvalidInput,
174                    format!("Domain name is too long: {} bytes (max 255)", domain.len()),
175                ));
176            }
177
178            return Ok(Address::Domain(
179                Bytes::copy_from_slice(domain.as_bytes()),
180                port,
181            ));
182        }
183
184        Err(io::Error::new(
185            io::ErrorKind::InvalidInput,
186            format!("Invalid address format: {}", value),
187        ))
188    }
189}
190
191impl TryFrom<String> for Address {
192    type Error = io::Error;
193
194    fn try_from(value: String) -> Result<Self, Self::Error> {
195        Address::try_from(value.as_str())
196    }
197}
198
199impl From<(String, u16)> for Address {
200    fn from(value: (String, u16)) -> Self {
201        Address::Domain(Bytes::from(value.0), value.1)
202    }
203}
204
205impl From<(&str, u16)> for Address {
206    fn from(value: (&str, u16)) -> Self {
207        Address::Domain(Bytes::copy_from_slice(value.0.as_bytes()), value.1)
208    }
209}
210
211impl std::fmt::Display for Address {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        match self {
214            Self::Domain(domain, port) => {
215                write!(f, "{}:{}", String::from_utf8_lossy(domain), port)
216            }
217            Self::SocketV4(addr) => write!(f, "{}", addr),
218            Self::SocketV6(addr) => write!(f, "{}", addr),
219        }
220    }
221}
222
223mod serde_bytes {
224    use bytes::Bytes;
225    use serde::{Deserialize, Deserializer, Serializer};
226
227    pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
228    where
229        S: Serializer,
230    {
231        serializer.serialize_bytes(bytes)
232    }
233
234    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
235    where
236        D: Deserializer<'de>,
237    {
238        let vec: Vec<u8> = Vec::deserialize(deserializer)?;
239        Ok(Bytes::from(vec))
240    }
241}
242
/// Gives `$message` inherent `encode`/`decode` methods that delegate to the free
/// [`encode`] / [`decode`] functions.
///
/// The expansion names `io`, `Bytes`, `encode` and `decode` unqualified, so all four
/// must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_message_serde {
    ($message:ident) => {
        impl $message {
            /// Serializes this message with the shared bincode configuration.
            pub fn encode(&self) -> io::Result<Bytes> {
                encode::<Self>(self)
            }

            /// Deserializes a message of this type from `bytes`.
            pub fn decode(bytes: &[u8]) -> io::Result<Self> {
                decode::<Self>(bytes)
            }
        }
    };
}
257
258impl_message_serde!(ClientHello);
259impl_message_serde!(UdpPacket);
260impl_message_serde!(Address);