Skip to main content

ombrac/
protocol.rs

1use std::io;
2use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
3use std::sync::LazyLock;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
/// Secret key type for authentication (32 bytes, 256 bits).
///
/// This is the type of [`ClientHello::secret`].
pub type Secret = [u8; 32];

/// Current protocol version.
pub const PROTOCOL_VERSION: u8 = 0x01;

/// Maximum domain name length in bytes (RFC 1035).
///
/// Enforced by `Address::try_from(&str)` and used by
/// [`UdpPacket::fragmented_overhead`] to size the worst-case address. Note that
/// the `From<(String, u16)>` / `From<(&str, u16)>` conversions do not check it.
pub const MAX_DOMAIN_LENGTH: usize = 255;

/// Bincode configuration for protocol message serialization.
///
/// Shared by [`encode`] and [`decode`] so both directions use the same wire format.
// NOTE(review): bincode documents `standard()` as little-endian with variable-length
// integer encoding and no size limit on decoding; any size arithmetic over encoded
// messages must account for varint widths — confirm against the pinned bincode version.
static BINCODE_CONFIG: LazyLock<bincode::config::Configuration> =
    LazyLock::new(bincode::config::standard);
20
21/// Encodes a protocol message into bytes.
22///
23/// # Errors
24///
25/// Returns an error if serialization fails.
26pub fn encode<T: Serialize>(message: &T) -> io::Result<Bytes> {
27    bincode::serde::encode_to_vec(message, *BINCODE_CONFIG)
28        .map(Bytes::from)
29        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("encode error: {e}")))
30}
31
32/// Decodes a protocol message from bytes.
33///
34/// # Errors
35///
36/// Returns an error if deserialization fails or the data is malformed.
37pub fn decode<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> io::Result<T> {
38    bincode::serde::borrow_decode_from_slice(bytes, *BINCODE_CONFIG)
39        .map(|(msg, _)| msg)
40        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("decode error: {e}")))
41}
42
/// Client authentication message containing credentials and configuration.
///
/// NOTE: this type is encoded with the crate's bincode-based `encode`/`decode`,
/// which writes struct fields positionally; reordering fields changes the wire
/// format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientHello {
    /// Protocol version the client supports.
    pub version: u8,
    /// Authentication secret (32-byte hash of the configured secret).
    pub secret: Secret,
    /// Optional protocol extensions and configuration (opaque to the protocol).
    #[serde(with = "serde_bytes")]
    pub options: Bytes,
}
54
/// Client connection request to establish a tunnel to a destination.
// NOTE(review): presumably answered by the server with a `ServerConnectResponse` —
// confirm against the handshake code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientConnect {
    /// Destination address to connect to (IP or domain name).
    pub address: Address,
}
61
/// Server reply to a client's authentication attempt (a [`ClientHello`]).
///
/// Neither variant carries a reason.
// NOTE(review): presumably `Ok` means the hello was accepted and `Err` that it was
// rejected — confirm against the server-side handshake.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerAuthResponse {
    /// Success reply.
    Ok,
    /// Failure reply; no further detail is transmitted.
    Err,
}
67
/// UDP packet representation with support for fragmentation.
///
/// A packet too large for one transport datagram can be split into
/// [`UdpPacket::Fragmented`] pieces with [`UdpPacket::split_packet`], sizing the
/// pieces with [`UdpPacket::fragmented_overhead`]. Reassembly is the receiver's
/// responsibility and is not implemented on this type.
///
/// NOTE: encoded with bincode, which identifies variants by index and fields by
/// position; reordering variants or fields changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UdpPacket {
    /// Complete unfragmented packet.
    Unfragmented {
        /// Unique session identifier for this UDP session.
        session_id: u64,
        /// Destination address for this packet.
        address: Address,
        /// Packet payload.
        #[serde(with = "serde_bytes")]
        data: Bytes,
    },
    /// Fragment of a larger packet.
    Fragmented {
        /// Unique session identifier for this UDP session.
        session_id: u64,
        /// Identifier shared by every fragment produced from one packet.
        fragment_id: u32,
        /// Zero-based index of this fragment within the packet.
        fragment_index: u16,
        /// Total number of fragments in this packet.
        fragment_count: u16,
        /// Destination address (only present in the first fragment).
        address: Option<Address>,
        /// Fragment payload.
        #[serde(with = "serde_bytes")]
        data: Bytes,
    },
}
102
103impl UdpPacket {
104    /// Calculates the overhead for a fragmented packet.
105    ///
106    /// This includes the fixed overhead (discriminant, session_id, fragment_id,
107    /// fragment_index, fragment_count) plus the maximum possible address overhead.
108    ///
109    /// Returns the maximum overhead in bytes for fragmentation calculations.
110    pub fn fragmented_overhead() -> usize {
111        // Discriminant (1 byte) + session_id (8 bytes) + fragment_id (4 bytes) +
112        // fragment_index (2 bytes) + fragment_count (2 bytes)
113        const FIXED_OVERHEAD: usize = 1 + 8 + 4 + 2 + 2;
114        // Maximum address overhead: discriminant (1 byte) + length (2 bytes) +
115        // max domain (255 bytes) + port (2 bytes)
116        const MAX_ADDRESS_OVERHEAD: usize = 1 + 2 + MAX_DOMAIN_LENGTH + 2;
117        FIXED_OVERHEAD + MAX_ADDRESS_OVERHEAD
118    }
119
120    /// Splits a large packet into fragments.
121    ///
122    /// # Arguments
123    ///
124    /// * `session_id` - The session identifier for this packet
125    /// * `address` - The destination address (included only in the first fragment)
126    /// * `data` - The packet data to fragment
127    /// * `max_payload_size` - Maximum payload size per fragment (after overhead)
128    /// * `fragment_id` - Unique identifier for this fragmentation operation
129    ///
130    /// # Returns
131    ///
132    /// An iterator over `UdpPacket::Fragmented` packets.
133    pub fn split_packet(
134        session_id: u64,
135        address: Address,
136        data: Bytes,
137        max_payload_size: usize,
138        fragment_id: u32,
139    ) -> impl Iterator<Item = UdpPacket> {
140        // Split data into chunks, ensuring each chunk fits within max_payload_size
141        let data_chunks: Vec<Bytes> = data
142            .chunks(max_payload_size)
143            .map(Bytes::copy_from_slice)
144            .collect();
145        let fragment_count = data_chunks.len() as u16;
146
147        // Ensure fragment_count fits in u16
148        assert!(fragment_count > 0, "fragment_count must be greater than 0");
149
150        data_chunks.into_iter().enumerate().map(move |(i, chunk)| {
151            let fragment_index = i as u16;
152            UdpPacket::Fragmented {
153                session_id,
154                fragment_id,
155                fragment_index,
156                fragment_count,
157                // Only include address in the first fragment to save bandwidth
158                address: if fragment_index == 0 {
159                    Some(address.clone())
160                } else {
161                    None
162                },
163                data: chunk,
164            }
165        })
166    }
167}
168
/// Response to a client's connection request (a [`ClientConnect`]).
///
/// This message is sent by the server after attempting to connect to the
/// destination address. It indicates whether the connection was successful
/// or failed, allowing the client to properly handle TCP state.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerConnectResponse {
    /// Connection to the destination was successful.
    Ok,
    /// Connection to the destination failed.
    ///
    /// The error message provides details about why the connection failed,
    /// which helps the client understand the failure context and avoid
    /// retry storms in application-layer protocols.
    Err {
        /// Error kind that categorizes the failure (see
        /// [`ConnectErrorKind::from_io_error`] for how I/O errors map onto it).
        kind: ConnectErrorKind,
        /// Human-readable error message
        message: String,
    },
}
190
/// Categorizes connection errors to help clients handle them appropriately.
///
/// NOTE: bincode identifies variants by index, so reordering or inserting
/// variants changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConnectErrorKind {
    /// Connection refused by the destination
    ConnectionRefused,
    /// Network unreachable
    NetworkUnreachable,
    /// Host unreachable
    HostUnreachable,
    /// Connection timed out
    TimedOut,
    /// Any other failure, including DNS resolution errors.
    ///
    /// `#[serde(other)]` asks serde to decode unrecognized variants as `Other`.
    // NOTE(review): confirm that this fallback also applies to bincode's
    // index-based variant encoding before relying on it for forward compatibility.
    #[serde(other)]
    Other,
}
205
206impl ConnectErrorKind {
207    /// Converts an `io::Error` to a `ConnectErrorKind` based on the error kind.
208    ///
209    /// This function maps standard IO error kinds to protocol error kinds,
210    /// ensuring consistent error handling across the codebase. DNS resolution
211    /// failures are categorized as `Other` since they can manifest with
212    /// different error kinds depending on the platform.
213    pub fn from_io_error(error: &io::Error) -> Self {
214        match error.kind() {
215            io::ErrorKind::ConnectionRefused => ConnectErrorKind::ConnectionRefused,
216            io::ErrorKind::NetworkUnreachable => ConnectErrorKind::NetworkUnreachable,
217            io::ErrorKind::HostUnreachable => ConnectErrorKind::HostUnreachable,
218            io::ErrorKind::TimedOut => ConnectErrorKind::TimedOut,
219            // All other errors, including DNS resolution failures (NotFound, etc.),
220            // are categorized as Other
221            _ => ConnectErrorKind::Other,
222        }
223    }
224}
225
/// Network address representation supporting IPv4, IPv6, and domain names.
///
/// This type is used throughout the protocol to specify destination addresses
/// for both TCP and UDP connections.
///
/// NOTE: bincode identifies variants by index, so reordering variants changes
/// the wire format.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub enum Address {
    /// IPv4 socket address.
    SocketV4(SocketAddrV4),
    /// IPv6 socket address.
    SocketV6(SocketAddrV6),
    /// Domain name bytes with port.
    ///
    /// `Address::try_from(&str)` limits the name to `MAX_DOMAIN_LENGTH` (255)
    /// bytes per RFC 1035. The `From<(String, u16)>` / `From<(&str, u16)>`
    /// conversions and deserialization do not check that limit. The bytes are
    /// only required to be UTF-8 when the address is resolved.
    Domain(#[serde(with = "serde_bytes")] Bytes, u16),
}
239
240impl Address {
241    /// Resolves this address to a `SocketAddr`.
242    ///
243    /// For IP addresses, this is a no-op. For domain names, this performs
244    /// asynchronous DNS resolution.
245    ///
246    /// # Errors
247    ///
248    /// Returns an error if:
249    /// - The domain name contains invalid UTF-8
250    /// - DNS resolution fails
251    /// - No addresses are found for the domain
252    pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
253        match self {
254            Self::SocketV4(addr) => Ok((*addr).into()),
255            Self::SocketV6(addr) => Ok((*addr).into()),
256            Self::Domain(domain, port) => {
257                let domain_str = std::str::from_utf8(domain).map_err(|_| {
258                    io::Error::new(
259                        io::ErrorKind::InvalidInput,
260                        "domain name contains invalid utf-8 characters",
261                    )
262                })?;
263
264                tokio::net::lookup_host((domain_str, *port))
265                    .await?
266                    .next()
267                    .ok_or_else(|| {
268                        io::Error::new(
269                            io::ErrorKind::NotFound,
270                            format!("domain name '{}' could not be resolved", domain_str),
271                        )
272                    })
273            }
274        }
275    }
276}
277
278impl From<SocketAddr> for Address {
279    fn from(value: SocketAddr) -> Self {
280        match value {
281            SocketAddr::V4(addr) => Self::SocketV4(addr),
282            SocketAddr::V6(addr) => Self::SocketV6(addr),
283        }
284    }
285}
286
287impl TryFrom<&str> for Address {
288    type Error = io::Error;
289
290    fn try_from(value: &str) -> Result<Self, Self::Error> {
291        if let Ok(addr) = value.parse::<SocketAddr>() {
292            return Ok(Address::from(addr));
293        }
294
295        if let Some((domain, port_str)) = value.rsplit_once(':')
296            && let Ok(port) = port_str.parse::<u16>()
297        {
298            if domain.is_empty() {
299                return Err(io::Error::new(
300                    io::ErrorKind::InvalidInput,
301                    "domain name cannot be empty",
302                ));
303            }
304
305            if domain.len() > MAX_DOMAIN_LENGTH {
306                return Err(io::Error::new(
307                    io::ErrorKind::InvalidInput,
308                    format!(
309                        "domain name is too long: {} bytes (max {})",
310                        domain.len(),
311                        MAX_DOMAIN_LENGTH
312                    ),
313                ));
314            }
315
316            return Ok(Address::Domain(
317                Bytes::copy_from_slice(domain.as_bytes()),
318                port,
319            ));
320        }
321
322        Err(io::Error::new(
323            io::ErrorKind::InvalidInput,
324            format!("invalid address format: {}", value),
325        ))
326    }
327}
328
329impl TryFrom<String> for Address {
330    type Error = io::Error;
331
332    fn try_from(value: String) -> Result<Self, Self::Error> {
333        Address::try_from(value.as_str())
334    }
335}
336
337impl From<(String, u16)> for Address {
338    fn from(value: (String, u16)) -> Self {
339        Address::Domain(Bytes::from(value.0), value.1)
340    }
341}
342
343impl From<(&str, u16)> for Address {
344    fn from(value: (&str, u16)) -> Self {
345        Address::Domain(Bytes::copy_from_slice(value.0.as_bytes()), value.1)
346    }
347}
348
349impl std::fmt::Display for Address {
350    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351        match self {
352            Self::Domain(domain, port) => {
353                write!(f, "{}:{}", String::from_utf8_lossy(domain), port)
354            }
355            Self::SocketV4(addr) => write!(f, "{}", addr),
356            Self::SocketV6(addr) => write!(f, "{}", addr),
357        }
358    }
359}
360
361mod serde_bytes {
362    use bytes::Bytes;
363    use serde::{Deserialize, Deserializer, Serializer};
364
365    pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
366    where
367        S: Serializer,
368    {
369        serializer.serialize_bytes(bytes)
370    }
371
372    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
373    where
374        D: Deserializer<'de>,
375    {
376        let vec: Vec<u8> = Vec::deserialize(deserializer)?;
377        Ok(Bytes::from(vec))
378    }
379}
380
/// Generates inherent `encode` / `decode` methods on a message type that forward
/// to the free `encode` / `decode` functions of this module.
///
/// The expansion refers to `io`, `Bytes`, `encode` and `decode` unqualified, so
/// those names must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_message_serde {
    ($ty:ident) => {
        impl $ty {
            /// Serializes `self` into protocol bytes.
            pub fn encode(&self) -> io::Result<Bytes> {
                encode(self)
            }

            /// Deserializes a value of this type from `bytes`.
            pub fn decode(bytes: &[u8]) -> io::Result<Self> {
                decode(bytes)
            }
        }
    };
}
395
// Give these message types inherent `encode`/`decode` methods. `ClientConnect`,
// `ServerAuthResponse` and `ServerConnectResponse` are not listed here; they can
// still go through the free `encode`/`decode` functions since they derive serde.
impl_message_serde!(ClientHello);
impl_message_serde!(UdpPacket);
impl_message_serde!(Address);