use crate::{
VarInt,
coding::{BufExt, BufMutExt, UnexpectedEnd},
frame::{FrameStruct, FrameType},
};
use bytes::{Buf, BufMut};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
/// Fields of an `ADD_ADDRESS` frame.
///
/// On the wire the frame type is `FrameType::ADD_ADDRESS_IPV4` or
/// `FrameType::ADD_ADDRESS_IPV6`, depending on the family of `address`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[allow(dead_code)] // NOTE(review): reason for allowing dead code is not stated here — confirm and remove once used
pub struct RfcAddAddress {
    /// Sequence number, encoded as a QUIC varint directly after the frame type.
    pub sequence_number: VarInt,
    /// Address carried by the frame; only the IP octets and the port are
    /// encoded (IPv6 flow info and scope ID are not transmitted).
    pub address: SocketAddr,
}
#[allow(dead_code)] // NOTE(review): reason for allowing dead code is not stated here — confirm and remove once used
impl RfcAddAddress {
    /// Writes the complete frame, frame-type prefix included, into `buf`.
    ///
    /// Wire layout:
    /// - frame type varint: `ADD_ADDRESS_IPV4` for an IPv4 address,
    ///   `ADD_ADDRESS_IPV6` for an IPv6 address;
    /// - `sequence_number` as a varint;
    /// - the IP address octets (4 for IPv4, 16 for IPv6);
    /// - the port as a big-endian `u16`.
    ///
    /// IPv6 flow info and scope ID are not written.
    pub fn encode<W: BufMut>(&self, buf: &mut W) {
        let frame_type = if self.address.is_ipv4() {
            FrameType::ADD_ADDRESS_IPV4
        } else {
            FrameType::ADD_ADDRESS_IPV6
        };
        buf.write_var_or_debug_assert(frame_type.0);
        buf.write_var_or_debug_assert(self.sequence_number.0);

        let port = self.address.port();
        match self.address {
            SocketAddr::V4(v4) => buf.put_slice(&v4.ip().octets()),
            SocketAddr::V6(v6) => buf.put_slice(&v6.ip().octets()),
        }
        buf.put_u16(port);
    }

    /// Parses the frame body that follows an `ADD_ADDRESS_*` frame type.
    ///
    /// The frame-type varint must already have been consumed from `r`; the
    /// caller passes `is_ipv6` to select the 16-byte (IPv6) or 4-byte (IPv4)
    /// address form. A decoded IPv6 address always has zero flow info and
    /// zero scope ID.
    ///
    /// # Errors
    ///
    /// Returns [`UnexpectedEnd`] if the sequence-number varint cannot be read,
    /// or if fewer bytes remain than the address octets plus the 2-byte port.
    pub fn decode<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
        let raw_sequence = r.get_var()?;
        let sequence_number = VarInt::from_u64(raw_sequence).map_err(|_| UnexpectedEnd)?;

        // Validate the fixed-size tail up front: `copy_to_slice` and `get_u16`
        // panic rather than fail when the buffer is too short.
        let ip_len = if is_ipv6 { 16 } else { 4 };
        if r.remaining() < ip_len + 2 {
            return Err(UnexpectedEnd);
        }

        let address = if is_ipv6 {
            let mut octets = [0u8; 16];
            r.copy_to_slice(&mut octets);
            let port = r.get_u16();
            SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
        } else {
            let mut octets = [0u8; 4];
            r.copy_to_slice(&mut octets);
            let port = r.get_u16();
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
        };

        Ok(Self {
            sequence_number,
            address,
        })
    }
}
impl FrameStruct for RfcAddAddress {
    /// Largest possible encoded size of an `ADD_ADDRESS` frame, in bytes.
    const SIZE_BOUND: usize = {
        // NOTE(review): assumes the frame-type varint encodes in 4 bytes.
        let frame_type: usize = 4;
        // Longest QUIC varint form.
        let sequence_number: usize = 8;
        // IPv6 octets, the larger of the two address families.
        let ip: usize = 16;
        let port: usize = 2;
        frame_type + sequence_number + ip + port
    };
}
/// Fields of a `PUNCH_ME_NOW` frame.
///
/// On the wire the frame type is `FrameType::PUNCH_ME_NOW_IPV4` or
/// `FrameType::PUNCH_ME_NOW_IPV6`, depending on the family of `address`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[allow(dead_code)] // NOTE(review): reason for allowing dead code is not stated here — confirm and remove once used
pub struct RfcPunchMeNow {
    /// Round number, encoded as the first varint after the frame type.
    pub round: VarInt,
    /// Sequence number this frame is paired with, encoded as the second varint.
    /// NOTE(review): presumably refers to an address's sequence number — confirm.
    pub paired_with_sequence_number: VarInt,
    /// Address carried by the frame; only the IP octets and the port are
    /// encoded (IPv6 flow info and scope ID are not transmitted).
    pub address: SocketAddr,
}
#[allow(dead_code)]
impl RfcPunchMeNow {
pub fn encode<W: BufMut>(&self, buf: &mut W) {
match self.address {
SocketAddr::V4(_) => buf.write_var_or_debug_assert(FrameType::PUNCH_ME_NOW_IPV4.0),
SocketAddr::V6(_) => buf.write_var_or_debug_assert(FrameType::PUNCH_ME_NOW_IPV6.0),
}
buf.write_var_or_debug_assert(self.round.0);
buf.write_var_or_debug_assert(self.paired_with_sequence_number.0);
match self.address {
SocketAddr::V4(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
SocketAddr::V6(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
}
}
pub fn decode<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
let round = VarInt::from_u64(r.get_var()?).map_err(|_| UnexpectedEnd)?;
let paired_with_sequence_number =
VarInt::from_u64(r.get_var()?).map_err(|_| UnexpectedEnd)?;
let address = if is_ipv6 {
if r.remaining() < 16 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 16];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
} else {
if r.remaining() < 4 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 4];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
};
Ok(Self {
round,
paired_with_sequence_number,
address,
})
}
}
impl FrameStruct for RfcPunchMeNow {
    /// Largest possible encoded size of a `PUNCH_ME_NOW` frame, in bytes.
    const SIZE_BOUND: usize = {
        // NOTE(review): assumes the frame-type varint encodes in 4 bytes.
        let frame_type: usize = 4;
        // `round` and `paired_with_sequence_number`, each at the longest QUIC varint form.
        let round: usize = 8;
        let paired_with_sequence_number: usize = 8;
        // IPv6 octets, the larger of the two address families.
        let ip: usize = 16;
        let port: usize = 2;
        frame_type + round + paired_with_sequence_number + ip + port
    };
}
/// Fields of a `REMOVE_ADDRESS` frame (`FrameType::REMOVE_ADDRESS` on the wire).
#[derive(Clone, Debug, Eq, PartialEq)]
#[allow(dead_code)] // NOTE(review): reason for allowing dead code is not stated here — confirm and remove once used
pub struct RfcRemoveAddress {
    /// Sequence number, encoded as a QUIC varint directly after the frame type.
    /// NOTE(review): presumably matches an earlier ADD_ADDRESS sequence number — confirm.
    pub sequence_number: VarInt,
}
#[allow(dead_code)] // NOTE(review): reason for allowing dead code is not stated here — confirm and remove once used
impl RfcRemoveAddress {
    /// Writes the complete frame into `buf`: the `REMOVE_ADDRESS` frame-type
    /// varint followed by `sequence_number` as a varint.
    pub fn encode<W: BufMut>(&self, buf: &mut W) {
        let Self { sequence_number } = self;
        buf.write_var_or_debug_assert(FrameType::REMOVE_ADDRESS.0);
        buf.write_var_or_debug_assert(sequence_number.0);
    }

    /// Parses the frame body that follows the `REMOVE_ADDRESS` frame type.
    ///
    /// The frame-type varint must already have been consumed from `r`.
    ///
    /// # Errors
    ///
    /// Returns [`UnexpectedEnd`] if the sequence-number varint cannot be read.
    pub fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
        let raw = r.get_var()?;
        VarInt::from_u64(raw)
            .map(|sequence_number| Self { sequence_number })
            .map_err(|_| UnexpectedEnd)
    }
}
impl FrameStruct for RfcRemoveAddress {
    /// Largest possible encoded size of a `REMOVE_ADDRESS` frame, in bytes.
    const SIZE_BOUND: usize = {
        // NOTE(review): assumes the frame-type varint encodes in 4 bytes.
        let frame_type: usize = 4;
        // Longest QUIC varint form.
        let sequence_number: usize = 8;
        frame_type + sequence_number
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Consumes the frame-type varint written by `encode` and returns its value,
    /// so tests verify the frame type instead of assuming a fixed prefix width.
    fn take_frame_type(buf: &mut BytesMut) -> u64 {
        buf.get_var()
            .expect("encoded frame should start with a frame-type varint")
    }

    #[test]
    fn test_rfc_add_address_roundtrip() {
        let frame = RfcAddAddress {
            sequence_number: VarInt::from_u32(42),
            address: "192.168.1.100:8080".parse().unwrap(),
        };
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::ADD_ADDRESS_IPV4.0);
        let decoded = RfcAddAddress::decode(&mut buf, false).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty(), "decode should consume the whole frame body");
    }

    #[test]
    fn test_rfc_add_address_ipv6_roundtrip() {
        let frame = RfcAddAddress {
            sequence_number: VarInt::from_u32(7),
            address: "[2001:db8::1]:9000".parse().unwrap(),
        };
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::ADD_ADDRESS_IPV6.0);
        let decoded = RfcAddAddress::decode(&mut buf, true).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty(), "decode should consume the whole frame body");
    }

    #[test]
    fn test_rfc_punch_me_now_roundtrip() {
        let frame = RfcPunchMeNow {
            round: VarInt::from_u32(5),
            paired_with_sequence_number: VarInt::from_u32(42),
            address: "[2001:db8::1]:9000".parse().unwrap(),
        };
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::PUNCH_ME_NOW_IPV6.0);
        let decoded = RfcPunchMeNow::decode(&mut buf, true).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty(), "decode should consume the whole frame body");
    }

    #[test]
    fn test_rfc_punch_me_now_ipv4_roundtrip() {
        let frame = RfcPunchMeNow {
            round: VarInt::from_u32(1),
            paired_with_sequence_number: VarInt::from_u32(3),
            address: "10.0.0.1:4433".parse().unwrap(),
        };
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::PUNCH_ME_NOW_IPV4.0);
        let decoded = RfcPunchMeNow::decode(&mut buf, false).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty(), "decode should consume the whole frame body");
    }

    #[test]
    fn test_rfc_remove_address_roundtrip() {
        let frame = RfcRemoveAddress {
            sequence_number: VarInt::from_u32(42),
        };
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::REMOVE_ADDRESS.0);
        let decoded = RfcRemoveAddress::decode(&mut buf).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty(), "decode should consume the whole frame body");
    }

    #[test]
    fn test_decode_rejects_truncated_input() {
        // Every strict prefix of a valid body must fail cleanly, never panic.
        let add = RfcAddAddress {
            sequence_number: VarInt::from_u32(7),
            address: "[2001:db8::2]:443".parse().unwrap(),
        };
        let mut buf = BytesMut::new();
        add.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::ADD_ADDRESS_IPV6.0);
        let body = buf.to_vec();
        for len in 0..body.len() {
            let mut short = &body[..len];
            assert!(
                RfcAddAddress::decode(&mut short, true).is_err(),
                "ADD_ADDRESS decoded from {len} of {} bytes",
                body.len()
            );
        }

        let punch = RfcPunchMeNow {
            round: VarInt::from_u32(2),
            paired_with_sequence_number: VarInt::from_u32(9),
            address: "192.0.2.1:5000".parse().unwrap(),
        };
        let mut buf = BytesMut::new();
        punch.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::PUNCH_ME_NOW_IPV4.0);
        let body = buf.to_vec();
        for len in 0..body.len() {
            let mut short = &body[..len];
            assert!(
                RfcPunchMeNow::decode(&mut short, false).is_err(),
                "PUNCH_ME_NOW decoded from {len} of {} bytes",
                body.len()
            );
        }
    }

    #[test]
    fn test_encoded_len_within_size_bound() {
        // u32::MAX needs the 8-byte varint form, exercising the worst case.
        let big = || VarInt::from_u32(u32::MAX);
        let v6: SocketAddr = "[2001:db8::1]:65535".parse().unwrap();

        let mut buf = BytesMut::new();
        RfcAddAddress {
            sequence_number: big(),
            address: v6,
        }
        .encode(&mut buf);
        assert!(buf.len() <= RfcAddAddress::SIZE_BOUND);

        buf.clear();
        RfcPunchMeNow {
            round: big(),
            paired_with_sequence_number: big(),
            address: v6,
        }
        .encode(&mut buf);
        assert!(buf.len() <= RfcPunchMeNow::SIZE_BOUND);

        buf.clear();
        RfcRemoveAddress {
            sequence_number: big(),
        }
        .encode(&mut buf);
        assert!(buf.len() <= RfcRemoveAddress::SIZE_BOUND);
    }
}