use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use crate::VarInt;
use crate::coding::{self, Codec};
/// Capsule type codepoint for a [`CompressionAssign`] payload (matched in `Capsule::decode`).
/// NOTE(review): presumably the COMPRESSION_ASSIGN capsule of the CONNECT-UDP listen
/// extension — confirm the codepoint against the spec.
pub const CAPSULE_COMPRESSION_ASSIGN: u64 = 0x11;
/// Capsule type codepoint for a [`CompressionAck`] payload.
pub const CAPSULE_COMPRESSION_ACK: u64 = 0x12;
/// Capsule type codepoint for a [`CompressionClose`] payload.
pub const CAPSULE_COMPRESSION_CLOSE: u64 = 0x13;
/// Payload of a compression-assign capsule.
///
/// Binds `context_id` either to the uncompressed context (`ip_version == 0`, no
/// address or port) or to a fixed IPv4/IPv6 target address and UDP port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionAssign {
    /// Context identifier being assigned.
    pub context_id: VarInt,
    /// IP version byte as carried on the wire: 0 (uncompressed), 4, or 6.
    pub ip_version: u8,
    /// Target IP address; `None` for the uncompressed context.
    pub ip_address: Option<IpAddr>,
    /// Target UDP port; `None` for the uncompressed context.
    pub udp_port: Option<u16>,
}
impl CompressionAssign {
    /// Creates an assignment for the uncompressed context: version 0, no target.
    pub fn uncompressed(context_id: VarInt) -> Self {
        Self {
            udp_port: None,
            ip_address: None,
            ip_version: 0,
            context_id,
        }
    }
    /// Creates an assignment binding `context_id` to the IPv4 target `addr:port`.
    pub fn compressed_v4(context_id: VarInt, addr: Ipv4Addr, port: u16) -> Self {
        let ip_address = Some(IpAddr::V4(addr));
        Self {
            context_id,
            ip_version: 4,
            ip_address,
            udp_port: Some(port),
        }
    }
    /// Creates an assignment binding `context_id` to the IPv6 target `addr:port`.
    pub fn compressed_v6(context_id: VarInt, addr: Ipv6Addr, port: u16) -> Self {
        let ip_address = Some(IpAddr::V6(addr));
        Self {
            context_id,
            ip_version: 6,
            ip_address,
            udp_port: Some(port),
        }
    }
    /// Returns `true` when the version byte marks the uncompressed context.
    pub fn is_uncompressed(&self) -> bool {
        matches!(self.ip_version, 0)
    }
    /// Returns the target as a socket address, or `None` unless both the address
    /// and the port are present.
    pub fn target(&self) -> Option<std::net::SocketAddr> {
        let (ip, port) = self.ip_address.zip(self.udp_port)?;
        Some(std::net::SocketAddr::new(ip, port))
    }
}
impl Codec for CompressionAssign {
    /// Decodes `context_id`, a one-byte IP version and, for versions 4 and 6, the
    /// address octets followed by a big-endian UDP port.
    ///
    /// Fails with `coding::UnexpectedEnd` on truncated input and on any version
    /// other than 0, 4, or 6.
    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
        let context_id = VarInt::decode(buf)?;
        if !buf.has_remaining() {
            return Err(coding::UnexpectedEnd);
        }
        let ip_version = buf.get_u8();
        let ip_address = match ip_version {
            0 => None,
            4 => {
                if buf.remaining() < 4 {
                    return Err(coding::UnexpectedEnd);
                }
                let mut octets = [0u8; 4];
                buf.copy_to_slice(&mut octets);
                Some(IpAddr::V4(Ipv4Addr::from(octets)))
            }
            6 => {
                if buf.remaining() < 16 {
                    return Err(coding::UnexpectedEnd);
                }
                let mut octets = [0u8; 16];
                buf.copy_to_slice(&mut octets);
                Some(IpAddr::V6(Ipv6Addr::from(octets)))
            }
            // Unrecognised version: malformed. Reported as `UnexpectedEnd`, the only
            // error the decoders in this module return.
            _ => return Err(coding::UnexpectedEnd),
        };
        // A port follows exactly when an address was present.
        let udp_port = match ip_address {
            None => None,
            Some(_) => {
                if buf.remaining() < 2 {
                    return Err(coding::UnexpectedEnd);
                }
                Some(buf.get_u16())
            }
        };
        Ok(Self {
            context_id,
            ip_version,
            ip_address,
            udp_port,
        })
    }
    /// Encodes `context_id`, then the version byte and (if present) the target.
    ///
    /// The version byte is derived from the address actually written rather than
    /// copied from `ip_version`. This keeps the output parseable by `decode` even if
    /// the public fields were set inconsistently (e.g. version 4 with an IPv6
    /// address, or a non-zero version with no address). Anything lacking both an
    /// address and a port is encoded as the uncompressed context (version 0).
    fn encode<B: BufMut>(&self, buf: &mut B) {
        self.context_id.encode(buf);
        match (self.ip_address, self.udp_port) {
            (Some(IpAddr::V4(v4)), Some(port)) => {
                buf.put_u8(4);
                buf.put_slice(&v4.octets());
                buf.put_u16(port);
            }
            (Some(IpAddr::V6(v6)), Some(port)) => {
                buf.put_u8(6);
                buf.put_slice(&v6.octets());
                buf.put_u16(port);
            }
            _ => buf.put_u8(0),
        }
    }
}
/// Payload of a compression-ack capsule: the acknowledged context identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionAck {
    /// Context identifier being acknowledged.
    pub context_id: VarInt,
}
impl CompressionAck {
    /// Wraps `context_id` in an ack payload.
    pub fn new(context_id: VarInt) -> Self {
        Self { context_id }
    }
}
impl Codec for CompressionAck {
    /// The wire form is just the context identifier varint.
    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
        VarInt::decode(buf).map(Self::new)
    }
    /// Writes the context identifier varint.
    fn encode<B: BufMut>(&self, buf: &mut B) {
        let Self { context_id } = self;
        context_id.encode(buf);
    }
}
/// Payload of a compression-close capsule: the identifier of the context to close.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionClose {
    /// Context identifier being closed.
    pub context_id: VarInt,
}
impl CompressionClose {
    /// Wraps `context_id` in a close payload.
    pub fn new(context_id: VarInt) -> Self {
        Self { context_id }
    }
}
impl Codec for CompressionClose {
    /// The wire form is just the context identifier varint.
    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
        let id = VarInt::decode(buf)?;
        Ok(Self::new(id))
    }
    /// Writes the context identifier varint.
    fn encode<B: BufMut>(&self, buf: &mut B) {
        let Self { context_id } = self;
        context_id.encode(buf);
    }
}
/// A single capsule: framed on the wire as a type varint, a length varint, and a
/// payload of that many bytes.
///
/// Capsule types this module does not recognise are preserved verbatim in
/// [`Capsule::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Capsule {
    /// Capsule type `CAPSULE_COMPRESSION_ASSIGN`.
    CompressionAssign(CompressionAssign),
    /// Capsule type `CAPSULE_COMPRESSION_ACK`.
    CompressionAck(CompressionAck),
    /// Capsule type `CAPSULE_COMPRESSION_CLOSE`.
    CompressionClose(CompressionClose),
    /// Any other capsule type, kept with its raw payload.
    Unknown {
        /// The capsule type as received.
        capsule_type: VarInt,
        /// The undecoded payload bytes.
        data: Vec<u8>,
    },
}
impl Capsule {
    /// Decodes one capsule (type varint, length varint, payload) from the front of `buf`.
    ///
    /// The payload is decoded from exactly `length` bytes. On success, `buf` is
    /// therefore positioned at the start of the next capsule regardless of payload
    /// contents, and a known payload can never read past its own capsule.
    ///
    /// # Errors
    ///
    /// Returns `coding::UnexpectedEnd` in any of these cases:
    /// - the header is truncated;
    /// - `buf` holds fewer than `length` bytes after the header;
    /// - a known payload is malformed;
    /// - a known payload leaves bytes unconsumed within its declared length.
    pub fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
        let capsule_type = VarInt::decode(buf)?;
        let length = VarInt::decode(buf)?.into_inner();
        // Compare in u64: casting `length` to usize first would truncate on 32-bit
        // targets and could let an oversized length slip past this check.
        if length > buf.remaining() as u64 {
            return Err(coding::UnexpectedEnd);
        }
        // Lossless: `length` is bounded by `remaining()`, which is a usize.
        let mut payload = buf.copy_to_bytes(length as usize);
        let capsule = match capsule_type.into_inner() {
            CAPSULE_COMPRESSION_ASSIGN => {
                Capsule::CompressionAssign(CompressionAssign::decode(&mut payload)?)
            }
            CAPSULE_COMPRESSION_ACK => {
                Capsule::CompressionAck(CompressionAck::decode(&mut payload)?)
            }
            CAPSULE_COMPRESSION_CLOSE => {
                Capsule::CompressionClose(CompressionClose::decode(&mut payload)?)
            }
            _ => {
                return Ok(Capsule::Unknown {
                    capsule_type,
                    data: payload.to_vec(),
                })
            }
        };
        // A known payload that does not exactly fill its declared length is malformed.
        if payload.has_remaining() {
            return Err(coding::UnexpectedEnd);
        }
        Ok(capsule)
    }
    /// Serialises this capsule as its type varint, payload-length varint, then payload.
    ///
    /// # Panics
    ///
    /// Only if the type or payload length falls outside the `VarInt` range. That
    /// cannot happen here:
    /// - known types are small constants;
    /// - `Unknown` types are already `VarInt`s;
    /// - an in-memory payload is far below the varint maximum (assumed to be QUIC's
    ///   2^62 − 1 — TODO confirm for `VarInt`).
    pub fn encode(&self) -> Bytes {
        let mut payload = BytesMut::new();
        match self {
            Capsule::CompressionAssign(c) => c.encode(&mut payload),
            Capsule::CompressionAck(c) => c.encode(&mut payload),
            Capsule::CompressionClose(c) => c.encode(&mut payload),
            Capsule::Unknown { data, .. } => payload.put_slice(data),
        }
        // Skipping a header field on conversion failure would emit an unparseable
        // frame, so treat failure as a broken invariant instead.
        let capsule_type = VarInt::from_u64(self.capsule_type())
            .expect("capsule type always fits in a varint");
        let length = VarInt::from_u64(payload.len() as u64)
            .expect("in-memory payload length always fits in a varint");
        let mut buf = BytesMut::new();
        capsule_type.encode(&mut buf);
        length.encode(&mut buf);
        buf.put_slice(&payload);
        buf.freeze()
    }
    /// Returns the numeric capsule type this value is encoded with.
    pub fn capsule_type(&self) -> u64 {
        match self {
            Capsule::CompressionAssign(_) => CAPSULE_COMPRESSION_ASSIGN,
            Capsule::CompressionAck(_) => CAPSULE_COMPRESSION_ACK,
            Capsule::CompressionClose(_) => CAPSULE_COMPRESSION_CLOSE,
            Capsule::Unknown { capsule_type, .. } => capsule_type.into_inner(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_compression_assign_uncompressed_roundtrip() {
        let original = CompressionAssign::uncompressed(VarInt::from_u32(2));
        let mut buf = BytesMut::new();
        original.encode(&mut buf);
        let decoded = CompressionAssign::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
        assert!(decoded.is_uncompressed());
        assert!(decoded.target().is_none());
    }
    #[test]
    fn test_compression_assign_ipv4_roundtrip() {
        let addr = Ipv4Addr::new(192, 168, 1, 100);
        let original = CompressionAssign::compressed_v4(VarInt::from_u32(4), addr, 8080);
        let mut buf = BytesMut::new();
        original.encode(&mut buf);
        let decoded = CompressionAssign::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
        assert!(!decoded.is_uncompressed());
        assert_eq!(
            decoded.target(),
            Some(std::net::SocketAddr::new(IpAddr::V4(addr), 8080))
        );
    }
    #[test]
    fn test_compression_assign_ipv6_roundtrip() {
        let addr = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let original = CompressionAssign::compressed_v6(VarInt::from_u32(6), addr, 443);
        let mut buf = BytesMut::new();
        original.encode(&mut buf);
        let decoded = CompressionAssign::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
        assert_eq!(decoded.ip_version, 6);
    }
    #[test]
    fn test_compression_assign_rejects_unknown_ip_version() {
        let mut buf = BytesMut::new();
        VarInt::from_u32(1).encode(&mut buf);
        buf.put_u8(5);
        assert!(CompressionAssign::decode(&mut buf.freeze()).is_err());
    }
    #[test]
    fn test_compression_assign_rejects_truncated_address() {
        let original =
            CompressionAssign::compressed_v6(VarInt::from_u32(6), Ipv6Addr::LOCALHOST, 443);
        let mut buf = BytesMut::new();
        original.encode(&mut buf);
        let mut truncated = buf.freeze();
        // Drop the 2-byte port and one address octet.
        truncated.truncate(truncated.len() - 3);
        assert!(CompressionAssign::decode(&mut truncated).is_err());
    }
    #[test]
    fn test_compression_ack_roundtrip() {
        let original = CompressionAck::new(VarInt::from_u32(42));
        let mut buf = BytesMut::new();
        original.encode(&mut buf);
        let decoded = CompressionAck::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
    }
    #[test]
    fn test_compression_close_roundtrip() {
        let original = CompressionClose::new(VarInt::from_u32(99));
        let mut buf = BytesMut::new();
        original.encode(&mut buf);
        let decoded = CompressionClose::decode(&mut buf.freeze()).unwrap();
        assert_eq!(original, decoded);
    }
    #[test]
    fn test_capsule_wrapper_encoding() {
        let assign =
            CompressionAssign::compressed_v4(VarInt::from_u32(2), Ipv4Addr::new(10, 0, 0, 1), 9000);
        let capsule = Capsule::CompressionAssign(assign.clone());
        let encoded = capsule.encode();
        let mut buf = encoded;
        let decoded = Capsule::decode(&mut buf).unwrap();
        match decoded {
            Capsule::CompressionAssign(c) => assert_eq!(c, assign),
            _ => panic!("Expected CompressionAssign capsule"),
        }
    }
    #[test]
    fn test_capsule_sequential_ack_and_close() {
        let mut stream = BytesMut::new();
        stream.extend_from_slice(
            &Capsule::CompressionAck(CompressionAck::new(VarInt::from_u32(7))).encode(),
        );
        stream.extend_from_slice(
            &Capsule::CompressionClose(CompressionClose::new(VarInt::from_u32(8))).encode(),
        );
        let mut buf = stream.freeze();
        match Capsule::decode(&mut buf).unwrap() {
            Capsule::CompressionAck(c) => assert_eq!(c.context_id, VarInt::from_u32(7)),
            other => panic!("Expected CompressionAck capsule, got {:?}", other),
        }
        match Capsule::decode(&mut buf).unwrap() {
            Capsule::CompressionClose(c) => assert_eq!(c.context_id, VarInt::from_u32(8)),
            other => panic!("Expected CompressionClose capsule, got {:?}", other),
        }
        assert!(!buf.has_remaining());
    }
    #[test]
    fn test_capsule_unknown_roundtrip() {
        let capsule = Capsule::Unknown {
            capsule_type: VarInt::from_u32(0x2a),
            data: vec![1, 2, 3],
        };
        assert_eq!(capsule.capsule_type(), 0x2a);
        let mut buf = capsule.encode();
        match Capsule::decode(&mut buf).unwrap() {
            Capsule::Unknown { capsule_type, data } => {
                assert_eq!(capsule_type, VarInt::from_u32(0x2a));
                assert_eq!(data, vec![1, 2, 3]);
            }
            other => panic!("Expected Unknown capsule, got {:?}", other),
        }
        assert!(!buf.has_remaining());
    }
    #[test]
    fn test_capsule_rejects_truncated_payload() {
        let encoded = Capsule::CompressionAck(CompressionAck::new(VarInt::from_u32(42))).encode();
        let mut truncated = encoded.slice(..encoded.len() - 1);
        assert!(Capsule::decode(&mut truncated).is_err());
    }
    #[test]
    fn test_capsule_type_identifiers() {
        assert_eq!(
            Capsule::CompressionAssign(CompressionAssign::uncompressed(VarInt::from_u32(1)))
                .capsule_type(),
            CAPSULE_COMPRESSION_ASSIGN
        );
        assert_eq!(
            Capsule::CompressionAck(CompressionAck::new(VarInt::from_u32(1))).capsule_type(),
            CAPSULE_COMPRESSION_ACK
        );
        assert_eq!(
            Capsule::CompressionClose(CompressionClose::new(VarInt::from_u32(1))).capsule_type(),
            CAPSULE_COMPRESSION_CLOSE
        );
    }
}