use std::net::SocketAddr;
use derive_more::Deref;
use super::{
EncodeSize, GetFrameType,
io::{WriteFrame, WriteFrameType},
};
use crate::{
net::{AddrFamily, Family, NatType, be_socket_addr},
varint::{VarInt, WriteVarInt, be_varint},
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deref)]
pub struct PunchMeNowFrame {
local_seq: VarInt,
remote_seq: VarInt,
#[deref]
address: SocketAddr,
tire: VarInt,
nat_type: NatType,
}
pub(crate) fn be_punch_me_now_frame(
family: Family,
) -> impl Fn(&[u8]) -> nom::IResult<&[u8], PunchMeNowFrame> {
move |input| {
let (remain, local_seq) = be_varint(input)?;
let (remain, remote_seq) = be_varint(remain)?;
let (remain, address) = be_socket_addr(remain, family)?;
let (remain, tire) = be_varint(remain)?;
let (remain, nat_type) = be_varint(remain)?;
let nat_type = NatType::try_from(nat_type).map_err(|_| {
nom::Err::Error(nom::error::Error::new(
remain,
nom::error::ErrorKind::Verify,
))
})?;
Ok((
remain,
PunchMeNowFrame {
local_seq,
remote_seq,
address,
tire,
nat_type,
},
))
}
}
impl GetFrameType for PunchMeNowFrame {
fn frame_type(&self) -> super::FrameType {
super::FrameType::PunchMeNow(self.address.family())
}
}
impl EncodeSize for PunchMeNowFrame {
fn max_encoding_size(&self) -> usize {
4 + 8 + 8 + self.address.max_encoding_size() + 8 + 8
}
fn encoding_size(&self) -> usize {
VarInt::from(self.frame_type()).encoding_size()
+ self.local_seq.encoding_size()
+ self.remote_seq.encoding_size()
+ self.address.encoding_size()
+ self.tire.encoding_size()
+ VarInt::from(self.nat_type).encoding_size()
}
}
impl PunchMeNowFrame {
pub fn new(
local_seq: u32,
remote_seq: u32,
address: SocketAddr,
tire: u32,
nat_type: NatType,
) -> Self {
Self {
local_seq: VarInt::from_u32(local_seq),
remote_seq: VarInt::from_u32(remote_seq),
address,
tire: VarInt::from_u32(tire),
nat_type,
}
}
pub fn local_seq(&self) -> u32 {
self.local_seq.into_inner() as u32
}
pub fn remote_seq(&self) -> u32 {
self.remote_seq.into_inner() as u32
}
pub fn nat_type(&self) -> NatType {
self.nat_type
}
pub fn set_addr(&mut self, addr: SocketAddr) {
self.address = addr;
}
pub fn address(&self) -> SocketAddr {
self.address
}
pub fn tire(&self) -> u32 {
self.tire.into_inner() as u32
}
}
impl<T: bytes::BufMut> WriteFrame<PunchMeNowFrame> for T {
fn put_frame(&mut self, frame: &PunchMeNowFrame) {
self.put_frame_type(frame.frame_type());
self.put_varint(&frame.local_seq);
self.put_varint(&frame.remote_seq);
self.put_u16(frame.address.port());
match frame.address.ip() {
std::net::IpAddr::V4(ipv4) => self.put_slice(&ipv4.octets()),
std::net::IpAddr::V6(ipv6) => self.put_slice(&ipv6.octets()),
}
self.put_varint(&frame.tire);
self.put_varint(&VarInt::from(frame.nat_type));
}
}
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use super::*;
use crate::frame::{GetFrameType, be_frame_type, io::WriteFrame};
#[test]
fn test_punch_me_now_frame() {
let frame = PunchMeNowFrame {
local_seq: VarInt::from_u32(1),
remote_seq: VarInt::from_u32(2),
address: "127.0.0.1:12345".parse().unwrap(),
tire: VarInt::from_u32(0x01u32),
nat_type: NatType::FullCone,
};
let mut buf = BytesMut::with_capacity(frame.max_encoding_size());
buf.put_frame(&frame);
let (remain, frame_type) = be_frame_type(&buf).unwrap();
assert_eq!(frame_type, frame.frame_type());
let frame2 = be_punch_me_now_frame(Family::V4)(remain).unwrap().1;
assert_eq!(frame, frame2);
}
}