use std::io;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::LazyLock;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
/// 32-byte shared secret carried in [`ClientHello::secret`].
pub type Secret = [u8; 32];
/// Protocol version constant (the tests send it as `ClientHello::version`).
pub const PROTOCOL_VERSION: u8 = 0x01;
/// Maximum length in bytes of a domain name. Enforced when parsing an [`Address`]
/// from a string, and used to size [`UdpPacket::fragmented_overhead`].
pub const MAX_DOMAIN_LENGTH: usize = 255;
/// bincode configuration shared by [`encode`] and [`decode`]. Both ends of the
/// connection must use the same configuration for messages to round-trip.
static BINCODE_CONFIG: LazyLock<bincode::config::Configuration> =
LazyLock::new(bincode::config::standard);
/// Serializes `message` into a new [`Bytes`] buffer using the shared bincode
/// configuration.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind `InvalidInput` whose message wraps the
/// bincode encode error.
pub fn encode<T: Serialize>(message: &T) -> io::Result<Bytes> {
    match bincode::serde::encode_to_vec(message, *BINCODE_CONFIG) {
        Ok(buffer) => Ok(Bytes::from(buffer)),
        Err(e) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("encode error: {e}"),
        )),
    }
}
/// Deserializes a `T` from `bytes` using the shared bincode configuration.
///
/// The value may borrow from `bytes`. Any bytes after the first complete
/// value are ignored, because the consumed-length count is discarded.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind `InvalidData` whose message wraps the
/// bincode decode error.
pub fn decode<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> io::Result<T> {
    let decoded = bincode::serde::borrow_decode_from_slice(bytes, *BINCODE_CONFIG);
    match decoded {
        Ok((value, _bytes_read)) => Ok(value),
        Err(e) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("decode error: {e}"),
        )),
    }
}
/// Client greeting message: protocol version, shared secret and opaque option bytes.
///
/// Serialized with serde + bincode, so field order is part of the wire format.
/// Do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientHello {
/// Protocol version (the tests populate it with [`PROTOCOL_VERSION`]).
pub version: u8,
/// 32-byte shared secret.
pub secret: Secret,
/// Opaque option payload, written with `serialize_bytes` by the local `serde_bytes` helper.
#[serde(with = "serde_bytes")]
pub options: Bytes,
}
/// Carries the target [`Address`] of a client connect request.
/// NOTE(review): presumably answered by [`ServerConnectResponse`]; confirm with the caller.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientConnect {
/// Address the client asks to connect to.
pub address: Address,
}
/// Server's accept/reject answer; presumably the reply to [`ClientHello`] (confirm).
///
/// Variant order determines the bincode variant index. Do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerAuthResponse {
Ok,
Err,
}
/// A UDP datagram tagged with a session id, sent either whole or as one
/// fragment of a larger datagram.
///
/// Variant and field order define the bincode encoding. Do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UdpPacket {
/// A complete datagram.
Unfragmented {
/// Session the datagram belongs to.
session_id: u64,
/// Address associated with the datagram.
address: Address,
/// Full datagram payload.
#[serde(with = "serde_bytes")]
data: Bytes,
},
/// One piece of a datagram split by [`UdpPacket::split_packet`].
Fragmented {
/// Session the datagram belongs to.
session_id: u64,
/// Identifier shared by every fragment of the same datagram.
fragment_id: u32,
/// 0-based position of this fragment.
fragment_index: u16,
/// Total number of fragments for this `fragment_id`.
fragment_count: u16,
/// `Some` only on fragment 0 when produced by `split_packet`.
address: Option<Address>,
/// This fragment's slice of the payload.
#[serde(with = "serde_bytes")]
data: Bytes,
},
}
impl UdpPacket {
    /// Per-fragment byte budget reserved for everything in a `Fragmented`
    /// packet except its payload.
    ///
    /// The budget is `1 + 8 + 4 + 2 + 2` (a variant tag plus the fixed-width
    /// sizes of `session_id`, `fragment_id`, `fragment_index` and
    /// `fragment_count`). On top of that, `1 + 2 + MAX_DOMAIN_LENGTH + 2` is
    /// reserved for the largest optional address.
    ///
    /// NOTE(review): [`encode`] uses `bincode::config::standard()`, and that
    /// config encodes integers with variable length: a `u64` can need up to
    /// 9 bytes, a `u32` up to 5 and a `u16` up to 3. The `data` payload is also
    /// prefixed with its length. Neither effect appears to be counted here, so
    /// an encoded fragment may exceed this budget by a few bytes. Confirm
    /// against real encoded sizes before treating it as a hard MTU bound. The
    /// value is also pinned by `test_fragmented_overhead_value`.
    pub fn fragmented_overhead() -> usize {
        const FIXED_OVERHEAD: usize = 1 + 8 + 4 + 2 + 2;
        const MAX_ADDRESS_OVERHEAD: usize = 1 + 2 + MAX_DOMAIN_LENGTH + 2;
        FIXED_OVERHEAD + MAX_ADDRESS_OVERHEAD
    }

    /// Splits one datagram into `Fragmented` packets, each carrying at most
    /// `max_payload_size` payload bytes.
    ///
    /// All fragments share `session_id`, `fragment_id` and `fragment_count`,
    /// and `fragment_index` counts up from 0. Only the first fragment carries
    /// `address`; the rest have `None`. Payloads are zero-copy slices of
    /// `data`. Empty `data` yields exactly one fragment with an empty payload,
    /// so the (empty) datagram and its address are still delivered.
    ///
    /// # Panics
    ///
    /// Panics if `max_payload_size` is 0. Also panics if `data` would need more
    /// than `u16::MAX` fragments, because the count does not fit the `u16`
    /// wire field.
    pub fn split_packet(
        session_id: u64,
        address: Address,
        data: Bytes,
        max_payload_size: usize,
        fragment_id: u32,
    ) -> impl Iterator<Item = UdpPacket> {
        assert!(max_payload_size > 0, "max_payload_size must be greater than 0");
        // `max(1)`: an empty datagram is still sent, as a single empty fragment.
        let chunk_count = data.len().div_ceil(max_payload_size).max(1);
        // Checked conversion: an `as u16` cast would silently wrap and emit
        // fragments whose indices and count disagree with each other.
        let fragment_count = u16::try_from(chunk_count).unwrap_or_else(|_| {
            panic!("data needs {chunk_count} fragments, more than u16::MAX")
        });
        (0..fragment_count).map(move |fragment_index| {
            let start = usize::from(fragment_index) * max_payload_size;
            // `start <= data.len()` holds for every index, so this cannot underflow.
            let end = start + max_payload_size.min(data.len() - start);
            UdpPacket::Fragmented {
                session_id,
                fragment_id,
                fragment_index,
                fragment_count,
                address: (fragment_index == 0).then(|| address.clone()),
                data: data.slice(start..end),
            }
        })
    }
}
/// Outcome of a connect request: success, or a failure category plus a
/// free-form message.
/// NOTE(review): presumably the reply to [`ClientConnect`]; confirm.
///
/// Variant and field order are part of the bincode wire format. Do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerConnectResponse {
Ok,
Err {
kind: ConnectErrorKind,
message: String,
},
}
/// Coarse category of a connect failure, derived from an [`io::ErrorKind`]
/// by [`ConnectErrorKind::from_io_error`].
///
/// Variant order determines the bincode variant index. Do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConnectErrorKind {
ConnectionRefused,
NetworkUnreachable,
HostUnreachable,
TimedOut,
/// Catch-all for every other error kind. Because of `#[serde(other)]`, an
/// unrecognized variant tag also deserializes to this variant.
#[serde(other)]
Other,
}
impl ConnectErrorKind {
pub fn from_io_error(error: &io::Error) -> Self {
match error.kind() {
io::ErrorKind::ConnectionRefused => ConnectErrorKind::ConnectionRefused,
io::ErrorKind::NetworkUnreachable => ConnectErrorKind::NetworkUnreachable,
io::ErrorKind::HostUnreachable => ConnectErrorKind::HostUnreachable,
io::ErrorKind::TimedOut => ConnectErrorKind::TimedOut,
_ => ConnectErrorKind::Other,
}
}
}
/// Network address carried in protocol messages: an IP socket address or a
/// not-yet-resolved domain name with a port.
///
/// Variant order determines the bincode variant index. Do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub enum Address {
/// IPv4 socket address.
SocketV4(SocketAddrV4),
/// IPv6 socket address.
/// NOTE(review): serde's binary (non-human-readable) `SocketAddrV6` encoding
/// carries only IP and port, so `flowinfo`/`scope_id` are presumably lost on
/// the wire. Confirm this is acceptable for link-local peers.
SocketV6(SocketAddrV6),
/// Raw domain-name bytes plus port. UTF-8 is only checked when resolved by
/// `to_socket_addr`.
Domain(#[serde(with = "serde_bytes")] Bytes, u16),
}
impl Address {
    /// Resolves this address to a concrete [`SocketAddr`].
    ///
    /// IP variants convert directly without I/O. A `Domain` is looked up with
    /// `tokio::net::lookup_host`, and the first address it yields is returned.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if the domain bytes are not valid UTF-8.
    /// - Any error reported by the resolver.
    /// - `NotFound` if the resolver yields no addresses.
    pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
        let (raw_domain, port) = match self {
            Self::SocketV4(v4) => return Ok(SocketAddr::V4(*v4)),
            Self::SocketV6(v6) => return Ok(SocketAddr::V6(*v6)),
            Self::Domain(raw, port) => (raw, *port),
        };
        let Ok(host) = std::str::from_utf8(raw_domain) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "domain name contains invalid utf-8 characters",
            ));
        };
        let mut resolved = tokio::net::lookup_host((host, port)).await?;
        match resolved.next() {
            Some(first) => Ok(first),
            None => Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("domain name '{}' could not be resolved", host),
            )),
        }
    }
}
impl From<SocketAddr> for Address {
fn from(value: SocketAddr) -> Self {
match value {
SocketAddr::V4(addr) => Self::SocketV4(addr),
SocketAddr::V6(addr) => Self::SocketV6(addr),
}
}
}
impl TryFrom<&str> for Address {
    type Error = io::Error;

    /// Parses `"ipv4:port"`, `"[ipv6]:port"` or `"domain:port"`.
    ///
    /// Strings accepted by [`SocketAddr`]'s parser become IP variants. Otherwise
    /// the input is split at its last `:`. If the suffix is a valid `u16` port,
    /// the prefix becomes a `Domain`, stored as-is.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` in any of these cases:
    /// - no `:` with a parsable port is found;
    /// - the domain part is empty;
    /// - the domain part is longer than [`MAX_DOMAIN_LENGTH`] bytes.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        if let Ok(socket) = value.parse::<SocketAddr>() {
            return Ok(socket.into());
        }

        let split = value
            .rsplit_once(':')
            .and_then(|(host, port_text)| port_text.parse::<u16>().ok().map(|port| (host, port)));
        let Some((host, port)) = split else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid address format: {}", value),
            ));
        };

        if host.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "domain name cannot be empty",
            ));
        }
        let host_len = host.len();
        if host_len > MAX_DOMAIN_LENGTH {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "domain name is too long: {} bytes (max {})",
                    host_len, MAX_DOMAIN_LENGTH
                ),
            ));
        }

        Ok(Self::Domain(Bytes::copy_from_slice(host.as_bytes()), port))
    }
}
impl TryFrom<String> for Address {
type Error = io::Error;
fn try_from(value: String) -> Result<Self, Self::Error> {
Address::try_from(value.as_str())
}
}
impl From<(String, u16)> for Address {
fn from(value: (String, u16)) -> Self {
Address::Domain(Bytes::from(value.0), value.1)
}
}
impl From<(&str, u16)> for Address {
    /// Builds an [`Address::Domain`] by copying a borrowed host string.
    ///
    /// Like the owned-`String` conversion, the host is taken verbatim: there is
    /// no IP-literal detection and no length check.
    fn from(value: (&str, u16)) -> Self {
        let (host, port) = value;
        Self::Domain(Bytes::from(host.as_bytes().to_vec()), port)
    }
}
impl std::fmt::Display for Address {
    /// Renders the address as `host:port`.
    ///
    /// IP variants use the std socket-address `Display` (IPv6 in brackets).
    /// Domain bytes are decoded lossily, so invalid UTF-8 appears as U+FFFD
    /// rather than failing. Width/fill flags on the outer formatter are not
    /// forwarded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SocketV4(v4) => write!(f, "{v4}"),
            Self::SocketV6(v6) => write!(f, "{v6}"),
            Self::Domain(raw, port) => {
                let host = String::from_utf8_lossy(raw);
                write!(f, "{host}:{port}")
            }
        }
    }
}
mod serde_bytes {
    //! Adapter for `#[serde(with = "serde_bytes")]` that (de)serializes a
    //! [`Bytes`] field as a byte string.
    use bytes::Bytes;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes the buffer with `Serializer::serialize_bytes`.
    pub fn serialize<S: Serializer>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_bytes(bytes)
    }

    /// Reads the field back through `Vec<u8>`'s own `Deserialize` impl, then
    /// hands the buffer to `Bytes` without copying.
    ///
    /// NOTE(review): this pairs `serialize_bytes` on the write side with a
    /// `u8`-sequence read. That agrees for bincode as used in this file, but
    /// formats that distinguish byte strings from sequences may reject it.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Bytes, D::Error> {
        let buffer = <Vec<u8> as Deserialize<'de>>::deserialize(deserializer)?;
        Ok(buffer.into())
    }
}
/// Adds inherent `encode`/`decode` convenience methods to one or more message
/// types, e.g. `impl_message_serde!(ClientHello);` or
/// `impl_message_serde!(ClientHello, UdpPacket);`.
///
/// The generated methods call the unqualified functions `encode` and `decode`,
/// so those must be in scope where the macro is invoked. Return types use
/// absolute paths (`::std::io`, `::bytes::Bytes`), so the invoking module does
/// not also need `io` and `Bytes` imported.
#[macro_export]
macro_rules! impl_message_serde {
    ($($struct_name:ident),+ $(,)?) => {
        $(
            impl $struct_name {
                /// Serializes `self` by delegating to the in-scope `encode`.
                pub fn encode(&self) -> ::std::io::Result<::bytes::Bytes> {
                    encode(self)
                }

                /// Deserializes a value of this type by delegating to the
                /// in-scope `decode`.
                pub fn decode(bytes: &[u8]) -> ::std::io::Result<Self> {
                    decode(bytes)
                }
            }
        )+
    };
}
// Give these wire-level types inherent `encode`/`decode` helpers.
impl_message_serde!(Address);
impl_message_serde!(ClientHello);
impl_message_serde!(UdpPacket);
#[cfg(test)]
mod tests {
    // Unit tests for address parsing/resolution, bincode round-trips of the
    // wire messages, io::Error classification and UDP fragmentation.
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    use super::*;

    // ---- Address parsing ----

    #[test]
    fn test_address_try_from_ipv4() {
        let addr = Address::try_from("1.2.3.4:80").unwrap();
        match addr {
            Address::SocketV4(a) => {
                assert_eq!(a.ip(), &Ipv4Addr::new(1, 2, 3, 4));
                assert_eq!(a.port(), 80);
            }
            _ => panic!("expected SocketV4"),
        }
    }

    #[test]
    fn test_address_try_from_ipv6() {
        let addr = Address::try_from("[::1]:443").unwrap();
        match addr {
            Address::SocketV6(a) => {
                assert_eq!(a.ip(), &Ipv6Addr::LOCALHOST);
                assert_eq!(a.port(), 443);
            }
            _ => panic!("expected SocketV6"),
        }
    }

    #[test]
    fn test_address_try_from_domain() {
        let addr = Address::try_from("example.com:8080").unwrap();
        match addr {
            Address::Domain(bytes, port) => {
                assert_eq!(bytes.as_ref(), b"example.com");
                assert_eq!(port, 8080);
            }
            _ => panic!("expected Domain"),
        }
    }

    #[test]
    fn test_address_try_from_empty_string() {
        let err = Address::try_from("").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_address_try_from_empty_domain() {
        let err = Address::try_from(":80").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(err.to_string().contains("empty"));
    }

    #[test]
    fn test_address_try_from_domain_too_long() {
        let long_domain = format!("{}:80", "a".repeat(MAX_DOMAIN_LENGTH + 1));
        let err = Address::try_from(long_domain.as_str()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(err.to_string().contains("too long"));
    }

    #[test]
    fn test_address_try_from_invalid_port() {
        let err = Address::try_from("example.com:notaport").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_address_try_from_port_overflow() {
        let err = Address::try_from("example.com:65536").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    // Owned-string parsing must follow the same rules as `&str` parsing.
    #[test]
    fn test_address_try_from_owned_string() {
        let addr = Address::try_from(String::from("1.2.3.4:80")).unwrap();
        assert_eq!(
            addr,
            Address::SocketV4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 80))
        );
    }

    // The tuple conversions always produce `Domain`, copying the host verbatim.
    #[test]
    fn test_address_from_str_tuple_is_domain() {
        let addr = Address::from(("example.com", 80u16));
        assert_eq!(addr, Address::Domain(Bytes::from_static(b"example.com"), 80));
    }

    #[test]
    fn test_address_from_string_tuple_is_domain() {
        let addr = Address::from((String::from("example.com"), 80u16));
        assert_eq!(addr, Address::Domain(Bytes::from_static(b"example.com"), 80));
    }

    // ---- Address resolution ----

    #[tokio::test]
    async fn test_address_to_socket_addr_ipv4() {
        let addr = Address::SocketV4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 9999));
        let result = addr.to_socket_addr().await.unwrap();
        assert_eq!(result.port(), 9999);
        assert_eq!(result.ip().to_string(), "127.0.0.1");
    }

    #[tokio::test]
    async fn test_address_to_socket_addr_ipv6() {
        let addr = Address::SocketV6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 443, 0, 0));
        let result = addr.to_socket_addr().await.unwrap();
        assert_eq!(result.port(), 443);
    }

    #[tokio::test]
    async fn test_address_to_socket_addr_invalid_utf8() {
        let addr = Address::Domain(Bytes::from_static(b"\xff\xfe"), 80);
        let err = addr.to_socket_addr().await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(err.to_string().contains("utf-8"));
    }

    // Relies on the host resolver knowing "localhost".
    #[tokio::test]
    async fn test_address_to_socket_addr_localhost_domain() {
        let addr = Address::Domain(Bytes::from_static(b"localhost"), 80);
        let result = addr.to_socket_addr().await.unwrap();
        assert_eq!(result.port(), 80);
    }

    // ---- Message encode/decode round-trips ----

    #[test]
    fn test_encode_decode_client_hello() {
        let original = ClientHello {
            version: PROTOCOL_VERSION,
            secret: [0xab; 32],
            options: Bytes::from_static(b"opts"),
        };
        let bytes = encode(&original).unwrap();
        let decoded: ClientHello = decode(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_decode_client_hello_empty_options() {
        let original = ClientHello {
            version: 1,
            secret: [0u8; 32],
            options: Bytes::new(),
        };
        let bytes = encode(&original).unwrap();
        let decoded: ClientHello = decode(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_decode_client_connect_ipv4() {
        let original = ClientConnect {
            address: Address::SocketV4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 443)),
        };
        let bytes = encode(&original).unwrap();
        let decoded: ClientConnect = decode(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_decode_client_connect_domain() {
        let original = ClientConnect {
            address: Address::Domain(Bytes::from_static(b"example.com"), 8080),
        };
        let bytes = encode(&original).unwrap();
        let decoded: ClientConnect = decode(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_decode_udp_packet_unfragmented() {
        let original = UdpPacket::Unfragmented {
            session_id: 42,
            address: Address::SocketV4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 53)),
            data: Bytes::from_static(b"hello"),
        };
        let bytes = original.encode().unwrap();
        let decoded = UdpPacket::decode(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_decode_udp_packet_fragmented() {
        let original = UdpPacket::Fragmented {
            session_id: 1,
            fragment_id: 7,
            fragment_index: 1,
            fragment_count: 3,
            address: None,
            data: Bytes::from_static(b"chunk"),
        };
        let bytes = original.encode().unwrap();
        let decoded = UdpPacket::decode(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    // Server replies were previously not round-tripped anywhere.
    #[test]
    fn test_encode_decode_server_connect_response() {
        for original in [
            ServerConnectResponse::Ok,
            ServerConnectResponse::Err {
                kind: ConnectErrorKind::TimedOut,
                message: String::from("connect timed out"),
            },
        ] {
            let bytes = encode(&original).unwrap();
            let decoded: ServerConnectResponse = decode(&bytes).unwrap();
            assert_eq!(decoded, original);
        }
    }

    #[test]
    fn test_encode_decode_server_auth_response() {
        for original in [ServerAuthResponse::Ok, ServerAuthResponse::Err] {
            let bytes = encode(&original).unwrap();
            let decoded: ServerAuthResponse = decode(&bytes).unwrap();
            assert_eq!(decoded, original);
        }
    }

    // Exercises the macro-generated inherent methods on `Address`.
    #[test]
    fn test_address_encode_decode_methods() {
        for original in [
            Address::SocketV6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 443, 0, 0)),
            Address::Domain(Bytes::from_static(b"example.com"), 8080),
        ] {
            let bytes = original.encode().unwrap();
            assert_eq!(Address::decode(&bytes).unwrap(), original);
        }
    }

    #[test]
    fn test_decode_rejects_garbage() {
        let err = decode::<ClientHello>(&[0xde, 0xad, 0xbe, 0xef]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // ---- io::Error classification ----

    #[test]
    fn test_connect_error_kind_connection_refused() {
        let e = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "");
        assert_eq!(
            ConnectErrorKind::from_io_error(&e),
            ConnectErrorKind::ConnectionRefused
        );
    }

    #[test]
    fn test_connect_error_kind_network_unreachable() {
        let e = std::io::Error::new(std::io::ErrorKind::NetworkUnreachable, "");
        assert_eq!(
            ConnectErrorKind::from_io_error(&e),
            ConnectErrorKind::NetworkUnreachable
        );
    }

    #[test]
    fn test_connect_error_kind_host_unreachable() {
        let e = std::io::Error::new(std::io::ErrorKind::HostUnreachable, "");
        assert_eq!(
            ConnectErrorKind::from_io_error(&e),
            ConnectErrorKind::HostUnreachable
        );
    }

    #[test]
    fn test_connect_error_kind_timed_out() {
        let e = std::io::Error::new(std::io::ErrorKind::TimedOut, "");
        assert_eq!(
            ConnectErrorKind::from_io_error(&e),
            ConnectErrorKind::TimedOut
        );
    }

    #[test]
    fn test_connect_error_kind_other_not_found() {
        let e = std::io::Error::new(std::io::ErrorKind::NotFound, "dns");
        assert_eq!(ConnectErrorKind::from_io_error(&e), ConnectErrorKind::Other);
    }

    #[test]
    fn test_connect_error_kind_other_permission_denied() {
        let e = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "");
        assert_eq!(ConnectErrorKind::from_io_error(&e), ConnectErrorKind::Other);
    }

    // ---- UDP fragmentation ----

    fn make_ipv4_addr() -> Address {
        Address::SocketV4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 5000))
    }

    #[test]
    fn test_split_packet_basic() {
        let frags: Vec<_> =
            UdpPacket::split_packet(1, make_ipv4_addr(), Bytes::from(vec![0u8; 300]), 100, 42)
                .collect();
        assert_eq!(frags.len(), 3);
        for (i, frag) in frags.iter().enumerate() {
            match frag {
                UdpPacket::Fragmented {
                    session_id,
                    fragment_id,
                    fragment_index,
                    fragment_count,
                    address,
                    ..
                } => {
                    assert_eq!(*session_id, 1);
                    assert_eq!(*fragment_id, 42);
                    assert_eq!(*fragment_index, i as u16);
                    assert_eq!(*fragment_count, 3);
                    if i == 0 {
                        assert!(address.is_some());
                    } else {
                        assert!(address.is_none());
                    }
                }
                _ => panic!("expected Fragmented"),
            }
        }
    }

    #[test]
    fn test_split_packet_single_fragment() {
        let frags: Vec<_> =
            UdpPacket::split_packet(5, make_ipv4_addr(), Bytes::from(vec![1u8; 50]), 100, 1)
                .collect();
        assert_eq!(frags.len(), 1);
        match &frags[0] {
            UdpPacket::Fragmented {
                fragment_index,
                fragment_count,
                address,
                ..
            } => {
                assert_eq!(*fragment_index, 0);
                assert_eq!(*fragment_count, 1);
                assert!(address.is_some());
            }
            _ => panic!("expected Fragmented"),
        }
    }

    #[test]
    fn test_split_packet_exact_boundary() {
        let frags: Vec<_> =
            UdpPacket::split_packet(2, make_ipv4_addr(), Bytes::from(vec![0u8; 100]), 100, 0)
                .collect();
        assert_eq!(frags.len(), 1);
    }

    #[test]
    fn test_split_packet_one_byte_over() {
        let frags: Vec<_> =
            UdpPacket::split_packet(3, make_ipv4_addr(), Bytes::from(vec![7u8; 101]), 100, 0)
                .collect();
        assert_eq!(frags.len(), 2);
        match &frags[1] {
            UdpPacket::Fragmented { data, .. } => assert_eq!(data.len(), 1),
            _ => panic!("expected Fragmented"),
        }
    }

    #[test]
    fn test_split_packet_data_integrity() {
        let original: Vec<u8> = (0u16..500).map(|i| (i % 256) as u8).collect();
        let frags: Vec<_> = UdpPacket::split_packet(
            9,
            make_ipv4_addr(),
            Bytes::from(original.clone()),
            100,
            5,
        )
        .collect();
        let reassembled: Vec<u8> = frags
            .iter()
            .flat_map(|f| match f {
                UdpPacket::Fragmented { data, .. } => data.to_vec(),
                _ => panic!("expected Fragmented"),
            })
            .collect();
        assert_eq!(reassembled, original);
    }

    #[test]
    fn test_fragmented_overhead_value() {
        assert_eq!(UdpPacket::fragmented_overhead(), 277);
    }

    // ---- Display ----

    #[test]
    fn test_address_display_ipv4() {
        let addr = Address::SocketV4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
        assert_eq!(format!("{addr}"), "127.0.0.1:8080");
    }

    #[test]
    fn test_address_display_ipv6() {
        let addr = Address::SocketV6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 443, 0, 0));
        assert_eq!(addr.to_string(), "[::1]:443");
    }

    #[test]
    fn test_address_display_domain() {
        let addr = Address::Domain(Bytes::from_static(b"example.com"), 443);
        assert_eq!(format!("{addr}"), "example.com:443");
    }
}