use std::net::{IpAddr, SocketAddr};
use derive_more::Deref;
use super::{
EncodeSize, GetFrameType,
io::{WriteFrame, WriteFrameType},
};
use crate::{
net::{AddrFamily, Family, NatType, WriteSocketAddr, be_socket_addr},
varint::{VarInt, WriteVarInt, be_varint},
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deref)]
pub struct AddAddressFrame {
#[deref]
address: SocketAddr,
seq_num: VarInt,
tire: VarInt,
nat_type: NatType,
}
pub(crate) fn be_add_address_frame(
family: Family,
) -> impl Fn(&[u8]) -> nom::IResult<&[u8], AddAddressFrame> {
move |input| {
let (remain, seq_num) = be_varint(input)?;
let (remain, addr) = be_socket_addr(remain, family)?;
let (remain, tire) = be_varint(remain)?;
let (remain, nat_type) = be_varint(remain)?;
let nat_type = NatType::try_from(nat_type).map_err(|_| {
nom::Err::Error(nom::error::Error::new(
remain,
nom::error::ErrorKind::Verify,
))
})?;
Ok((
remain,
AddAddressFrame {
seq_num,
address: addr,
tire,
nat_type,
},
))
}
}
impl GetFrameType for AddAddressFrame {
fn frame_type(&self) -> super::FrameType {
super::FrameType::AddAddress(self.address.family())
}
}
impl EncodeSize for AddAddressFrame {
fn max_encoding_size(&self) -> usize {
4 + 8 + 2 + 16 + 8 + 8 }
fn encoding_size(&self) -> usize {
let addr_size = match self.address.ip() {
IpAddr::V4(_) => 2 + 4,
IpAddr::V6(_) => 2 + 16,
};
VarInt::from(self.frame_type()).encoding_size()
+ self.seq_num.encoding_size()
+ addr_size
+ self.tire.encoding_size()
+ VarInt::from(self.nat_type).encoding_size()
}
}
impl AddAddressFrame {
pub fn new(seq_num: u32, address: SocketAddr, tire: u32, nat_type: NatType) -> Self {
Self {
seq_num: VarInt::from_u32(seq_num),
address,
tire: VarInt::from_u32(tire),
nat_type,
}
}
pub fn seq_num(&self) -> u32 {
self.seq_num.into_inner() as u32
}
pub fn tire(&self) -> u32 {
self.tire.into_inner() as u32
}
pub fn nat_type(&self) -> NatType {
self.nat_type
}
}
impl<T: bytes::BufMut> WriteFrame<AddAddressFrame> for T {
fn put_frame(&mut self, frame: &AddAddressFrame) {
self.put_frame_type(frame.frame_type());
self.put_varint(&frame.seq_num);
self.put_socket_addr(&frame.address);
self.put_varint(&frame.tire);
self.put_varint(&VarInt::from(frame.nat_type));
}
}
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use bytes::BytesMut;
use super::*;
use crate::frame::{GetFrameType, be_frame_type, io::WriteFrame};
#[test]
fn test_add_address_frame() {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
let frame = AddAddressFrame {
seq_num: VarInt::from_u32(1u32),
address: addr,
tire: VarInt::from_u32(1u32),
nat_type: NatType::FullCone,
};
let mut buf = BytesMut::new();
buf.put_frame(&frame);
let (remain, frame_type) = be_frame_type(&buf).unwrap();
assert_eq!(frame_type, frame.frame_type());
let frame2 = be_add_address_frame(Family::V4)(remain).unwrap().1;
assert_eq!(frame, frame2);
}
}