use std::collections::VecDeque;
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::u16;
use futures::prelude::*;
use generic_array::typenum::consts::U16;
use generic_array::GenericArray;
use num_traits::ToPrimitive;
use slog::{o, Logger};
use tokio::io::ReadBuf;
use tokio::net::UdpSocket;
use tsproto_packets::packets::*;
use tsproto_types::crypto::EccKeyPubP256;
use crate::packet_codec::PacketCodec;
use crate::resend::{PacketId, PartialPacketId, Resender, ResenderState};
use crate::{Error, Result, MAX_UDP_PACKET_LENGTH, UDP_SINK_CAPACITY};
/// Minimal datagram-socket interface a [`Connection`] needs.
///
/// A [`Connection`] stores its socket as `Box<dyn Socket + Send>`, so any type
/// implementing this trait can be plugged in as the transport.
pub trait Socket {
    /// Polls for an incoming datagram.
    ///
    /// The received bytes are placed into `buf`; when ready, the `Ok` value is
    /// the address returned by the underlying socket for that datagram.
    fn poll_recv_from(&self, cx: &mut Context, buf: &mut ReadBuf) -> Poll<io::Result<SocketAddr>>;

    /// Polls to send `buf` as a datagram to `target`.
    ///
    /// When ready, the `Ok` value is a byte count; callers in this file compare
    /// it with `buf.len()` to detect a partially sent datagram.
    fn poll_send_to(
        &self, cx: &mut Context, buf: &[u8], target: SocketAddr,
    ) -> Poll<io::Result<usize>>;

    /// Returns the local address of the socket.
    fn local_addr(&self) -> io::Result<SocketAddr>;
}
/// A key/nonce pair stored together with a generation id.
///
/// The `Default` value uses `u32::max_value()` as generation id and all-zero
/// key and nonce.
#[derive(Debug)]
pub struct CachedKey {
    /// Generation id associated with this entry.
    pub generation_id: u32,
    /// 16-byte key.
    pub key: GenericArray<u8, U16>,
    /// 16-byte nonce.
    pub nonce: GenericArray<u8, U16>,
}
/// Parameters stored in [`Connection::params`]; that field is `None` for a
/// freshly created connection.
///
/// NOTE(review): field meanings below are partly inferred from names and
/// types — confirm against the code that fills and reads them.
pub struct ConnectedParams {
    /// Presumably the client id of this connection; initialised to `0` by
    /// [`ConnectedParams::new`].
    pub c_id: u16,
    /// Voice-encryption flag; initialised to `true` by [`ConnectedParams::new`].
    pub voice_encryption: bool,
    /// A P-256 public key (presumably the peer's — TODO confirm).
    pub public_key: EccKeyPubP256,
    /// 64 bytes of shared IV material.
    pub shared_iv: [u8; 64],
    /// 8 bytes of shared MAC material.
    pub shared_mac: [u8; 8],
    /// An 8 x 2 table of cached keys (presumably indexed by packet type and
    /// direction — TODO confirm).
    pub key_cache: [[CachedKey; 2]; 8],
}
/// Notification handed to the event listeners of a [`Connection`].
///
/// Every variant only borrows the packet it refers to, so listeners observe
/// packets without taking ownership.
#[derive(Debug)]
pub enum Event<'a> {
    /// A raw UDP packet from the peer was parsed successfully.
    ReceiveUdpPacket(&'a InUdpPacket<'a>),
    /// Carries a borrowed `InPacket`; not emitted within this file.
    ReceivePacket(&'a InPacket<'a>),
    /// A UDP packet was handed to the socket for sending.
    SendUdpPacket(&'a OutUdpPacket),
    /// A packet is about to be encoded and sent.
    SendPacket(&'a OutPacket),
}
/// Items produced by the `Stream` implementation of [`Connection`].
///
/// NOTE(review): except for `Error`, the variants are constructed outside this
/// file; their descriptions are inferred from the payload types — confirm
/// against the packet codec and resender.
#[derive(Debug)]
pub enum StreamItem {
    /// Carries an `InCommandBuf`.
    Command(InCommandBuf),
    /// Carries an `InAudioBuf`.
    Audio(InAudioBuf),
    /// Carries an `InC2SInitBuf`.
    C2SInit(InC2SInitBuf),
    /// Carries an `InS2CInitBuf`.
    S2CInit(InS2CInitBuf),
    /// Carries the `PacketId` of a packet (presumably one that was
    /// acknowledged by the peer).
    AckPacket(PacketId),
    /// Carries no data (presumably signals refreshed network statistics).
    NetworkStatsUpdated,
    /// A non-fatal error, e.g. a datagram from an unexpected address or a
    /// packet that failed to parse.
    Error(Error),
}
/// Boxed callback that is invoked with a borrowed [`Event`]; must be `Send`.
type EventListener = Box<dyn for<'a> Fn(&'a Event<'a>) + Send>;
/// State of a single connection to one remote address over a UDP socket.
///
/// The type implements `Stream`, yielding `Result<StreamItem>` values.
pub struct Connection {
    /// `true` if this side is the client; selects the direction used when
    /// parsing incoming packets.
    pub is_client: bool,
    /// Logger carrying `local_addr` and `remote_addr` as context.
    pub logger: Logger,
    /// Connection parameters; `None` directly after construction.
    pub params: Option<ConnectedParams>,
    /// Address of the peer. Datagrams from any other address are reported as
    /// `Error::WrongAddress`.
    pub address: SocketAddr,
    /// Resend state for this connection.
    pub resender: Resender,
    /// Packet codec state, including the incoming/outgoing packet id counters.
    pub codec: PacketCodec,
    /// The socket used for all sending and receiving.
    pub udp_socket: Box<dyn Socket + Send>,
    /// Receive buffer; it is moved out for every received datagram and can be
    /// returned through `hand_back_buffer`.
    udp_buffer: Vec<u8>,
    /// Items waiting to be yielded by the stream.
    pub(crate) stream_items: VecDeque<StreamItem>,
    /// Encoded packets waiting to be written to the socket directly (acks that
    /// could not be sent immediately and every packet type that does not go
    /// through the resender).
    acks_to_send: VecDeque<OutUdpPacket>,
    /// Callbacks notified about sent and received packets.
    pub event_listeners: Vec<EventListener>,
}
impl Socket for UdpSocket {
fn poll_recv_from(&self, cx: &mut Context, buf: &mut ReadBuf) -> Poll<io::Result<SocketAddr>> {
self.poll_recv_from(cx, buf)
}
fn poll_send_to(
&self, cx: &mut Context, buf: &[u8], target: SocketAddr,
) -> Poll<io::Result<usize>> {
self.poll_send_to(cx, buf, target)
}
fn local_addr(&self) -> io::Result<SocketAddr> { self.local_addr() }
}
impl Default for CachedKey {
fn default() -> Self {
CachedKey { generation_id: u32::max_value(), key: [0; 16].into(), nonce: [0; 16].into() }
}
}
impl ConnectedParams {
pub fn new(public_key: EccKeyPubP256, shared_iv: [u8; 64], shared_mac: [u8; 8]) -> Self {
Self {
c_id: 0,
voice_encryption: true,
public_key,
shared_iv,
shared_mac,
key_cache: Default::default(),
}
}
}
impl Connection {
/// Creates a connection to `address` that communicates over `udp_socket`.
///
/// The given `logger` is extended with the `local_addr` and `remote_addr`
/// keys. If the local address of the socket cannot be queried, a placeholder
/// containing the error is logged instead of panicking, because the value is
/// only used as log context.
///
/// The `Command` packet id counter starts at generation `0`, id `1`: for a
/// client this is applied to the outgoing counter, otherwise to the incoming
/// counter. All other state starts with its default value.
pub fn new(
    is_client: bool, logger: Logger, address: SocketAddr, udp_socket: Box<dyn Socket + Send>,
) -> Self {
    // `local_addr` is an I/O call and may fail; do not bring down the caller
    // just because a log key cannot be filled in.
    let local_addr = udp_socket
        .local_addr()
        .map(|addr| addr.to_string())
        .unwrap_or_else(|e| format!("unknown ({})", e));
    let logger = logger.new(o!("local_addr" => local_addr,
        "remote_addr" => address.to_string()));
    let mut res = Self {
        is_client,
        logger,
        params: None,
        address,
        resender: Default::default(),
        codec: Default::default(),
        udp_socket,
        udp_buffer: Default::default(),
        stream_items: Default::default(),
        acks_to_send: Default::default(),
        event_listeners: Default::default(),
    };

    // Converting a `PacketType` to an index is expected to always succeed.
    let command_i = PacketType::Command.to_usize().unwrap();
    let first_command = PartialPacketId { generation_id: 0, packet_id: 1 };
    if is_client {
        res.codec.outgoing_p_ids[command_i] = first_command;
    } else {
        res.codec.incoming_p_ids[command_i] = first_command;
    }
    res
}
/// Checks whether the packet id `p_id` lies inside the receive window of
/// `p_type`.
///
/// The window starts at the next expected incoming packet id and spans the
/// following `u16::MAX / 2` ids, wrapping around at the end of the `u16`
/// range.
///
/// Returns a tuple of
/// 1. whether `p_id` is inside the window,
/// 2. the generation id assigned to `p_id`,
/// 3. the first id of the window (the next expected id),
/// 4. the exclusive end of the window.
///
/// `Init` packets are always accepted and yield `(true, 0, 0, 0)`.
///
/// The generation arithmetic wraps instead of overflowing: an id beyond the
/// window is attributed to the previous generation, and with the current
/// generation at `0` a plain `gen - 1` would panic in overflow-checked builds
/// for a packet id that is fully controlled by the remote side.
pub(crate) fn in_receive_window(&self, p_type: PacketType, p_id: u16) -> (bool, u32, u16, u16) {
    if p_type == PacketType::Init {
        return (true, 0, 0, 0);
    }
    let type_i = p_type.to_usize().unwrap();
    let cur_next = self.codec.incoming_p_ids[type_i].packet_id;
    // `next_gen` is set when the window wraps past `u16::MAX`.
    let (limit, next_gen) = cur_next.overflowing_add(u16::MAX / 2);
    let gen = self.codec.incoming_p_ids[type_i].generation_id;

    let in_recv_win = (!next_gen && p_id >= cur_next && p_id < limit)
        || (next_gen && (p_id >= cur_next || p_id < limit));
    let gen_id = if in_recv_win {
        // Ids in the wrapped-around part of the window belong to the next
        // generation.
        if next_gen && p_id < limit { gen.wrapping_add(1) } else { gen }
    } else if p_id < cur_next {
        // Already passed id of the current generation.
        gen
    } else {
        // Id behind the window: counted as part of the previous generation.
        gen.wrapping_sub(1)
    };
    (in_recv_win, gen_id, cur_next, limit)
}
/// Calls every registered event listener with `event`, in registration order.
pub fn send_event(&self, event: &Event) {
    self.event_listeners.iter().for_each(|listener| listener(event));
}
/// Offers a buffer for reuse as the receive buffer.
///
/// `buffer` is only adopted if the current receive buffer has a capacity below
/// `MAX_UDP_PACKET_LENGTH` while `buffer` has at least that capacity;
/// otherwise it is dropped.
pub fn hand_back_buffer(&mut self, buffer: Vec<u8>) {
    let own_is_too_small = self.udp_buffer.capacity() < MAX_UDP_PACKET_LENGTH;
    let offered_is_large_enough = buffer.capacity() >= MAX_UDP_PACKET_LENGTH;
    if own_is_too_small && offered_is_large_enough {
        self.udp_buffer = buffer;
    }
}
/// Writes queued packets from `acks_to_send` to the socket, front first.
///
/// Stops without an error as soon as the queue is empty or the socket is not
/// ready. Every successfully sent packet is reported to the resender through
/// `handle_loss_outgoing` and removed from the queue.
///
/// # Errors
///
/// Returns the first send error; the failed packet stays in the queue.
fn poll_send_acks(&mut self, cx: &mut Context) -> Result<()> {
    loop {
        let next = match self.acks_to_send.front() {
            Some(packet) => packet,
            None => return Ok(()),
        };
        match self.poll_send_udp_packet(cx, next) {
            Poll::Pending => return Ok(()),
            Poll::Ready(sent) => sent?,
        }
        self.resender.handle_loss_outgoing(next);
        self.acks_to_send.pop_front();
    }
}
/// Receives datagrams from the socket until one of them produces a stream
/// item, an error occurs or the socket has nothing more to read.
///
/// Nothing is read while `acks_to_send` holds `UDP_SINK_CAPACITY` or more
/// packets; `Poll::Pending` is returned right away in that case.
///
/// For each datagram the receive buffer is moved out of `self`, cut down to the
/// received size and passed to `handle_udp_packet`.
fn poll_incoming_udp_packet(&mut self, cx: &mut Context) -> Poll<Result<StreamItem>> {
    if self.acks_to_send.len() >= UDP_SINK_CAPACITY {
        return Poll::Pending;
    }
    loop {
        // The buffer was either moved out by the previous iteration or handed
        // back with a different length, so bring it to the full size again.
        if self.udp_buffer.len() != MAX_UDP_PACKET_LENGTH {
            self.udp_buffer.resize(MAX_UDP_PACKET_LENGTH, 0);
        }
        let mut read_buf = ReadBuf::new(&mut self.udp_buffer);
        let addr = match self.udp_socket.poll_recv_from(cx, &mut read_buf) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(Error::Network(e))),
            Poll::Ready(Ok(addr)) => addr,
        };
        let size = read_buf.filled().len();

        let mut datagram = mem::take(&mut self.udp_buffer);
        datagram.truncate(size);
        if let Err(e) = self.handle_udp_packet(cx, datagram, addr) {
            return Poll::Ready(Err(e));
        }
        if let Some(item) = self.stream_items.pop_front() {
            return Poll::Ready(Ok(item));
        }
        // The datagram did not result in an item, try the next one.
    }
}
fn handle_udp_packet(
&mut self, cx: &mut Context, udp_buffer: Vec<u8>, addr: SocketAddr,
) -> Result<()> {
if addr != self.address {
self.stream_items.push_back(StreamItem::Error(Error::WrongAddress));
return Ok(());
}
let dir = if self.is_client { Direction::S2C } else { Direction::C2S };
let packet = InUdpPacket(match InPacket::try_new(dir, &udp_buffer) {
Ok(r) => r,
Err(e) => {
self.stream_items.push_back(StreamItem::Error(Error::PacketParse("udp", e)));
return Ok(());
}
});
let event = Event::ReceiveUdpPacket(&packet);
self.send_event(&event);
self.resender.received_packet();
PacketCodec::handle_udp_packet(self, cx, udp_buffer)?;
Ok(())
}
/// Encodes `packet` and tries to send it right away.
///
/// Listeners get an `Event::SendPacket` first. If the socket accepts the
/// encoded packet, the resender is informed through `handle_loss_outgoing`; if
/// the socket is not ready, the packet is appended to `acks_to_send` and
/// `Ok(())` is returned.
///
/// # Errors
///
/// Returns encoding errors and errors from sending on the socket.
///
/// # Panics
///
/// Panics if encoding yields anything but exactly one UDP packet.
pub(crate) fn send_ack_packet(&mut self, cx: &mut Context, packet: OutPacket) -> Result<()> {
    self.send_event(&Event::SendPacket(&packet));
    let mut encoded = PacketCodec::encode_packet(self, packet)?;
    assert_eq!(
        encoded.len(),
        1,
        "Encoding an ack packet should only yield a single packet"
    );
    let ack = encoded.pop().unwrap();

    match self.poll_send_udp_packet(cx, &ack) {
        Poll::Pending => {
            self.acks_to_send.push_back(ack);
            Ok(())
        }
        Poll::Ready(Ok(())) => {
            self.resender.handle_loss_outgoing(&ack);
            Ok(())
        }
        Poll::Ready(Err(e)) => Err(e),
    }
}
/// Encodes `packet` and queues all resulting UDP packets for sending.
///
/// Listeners get an `Event::SendPacket` before encoding. The returned
/// `PacketId` is derived from the last UDP packet produced by the encoder.
///
/// # Errors
///
/// Returns the error of `PacketCodec::encode_packet` if encoding fails.
///
/// # Panics
///
/// Panics if the encoder returns no UDP packet at all.
pub fn send_packet(&mut self, packet: OutPacket) -> Result<PacketId> {
    self.send_event(&Event::SendPacket(&packet));
    let udp_packets = PacketCodec::encode_packet(self, packet)?;
    let id: PacketId = udp_packets.last().unwrap().into();
    udp_packets.into_iter().for_each(|udp_packet| self.send_udp_packet(udp_packet));
    Ok(id)
}
/// Queues an already encoded UDP packet for sending.
///
/// `Init`, `Command` and `CommandLow` packets are handed to
/// `Resender::send_packet`; packets of every other type are appended to
/// `acks_to_send` and written to the socket when the connection is polled.
pub fn send_udp_packet(&mut self, packet: OutUdpPacket) {
    let goes_through_resender = matches!(
        packet.packet_type(),
        PacketType::Init | PacketType::Command | PacketType::CommandLow
    );
    if goes_through_resender {
        Resender::send_packet(self, packet);
    } else {
        self.acks_to_send.push_back(packet);
    }
}
pub fn poll_send_udp_packet(
&self, cx: &mut Context, packet: &OutUdpPacket,
) -> Poll<Result<()>> {
Self::static_poll_send_udp_packet(
&*self.udp_socket,
self.address,
&self.event_listeners,
cx,
packet,
)
}
pub fn static_poll_send_udp_packet(
udp_socket: &dyn Socket, address: SocketAddr, event_listeners: &[EventListener],
cx: &mut Context, packet: &OutUdpPacket,
) -> Poll<Result<()>> {
let data = packet.data().data();
match udp_socket.poll_send_to(cx, data, address).map_err(Error::Network)? {
Poll::Pending => Poll::Pending,
Poll::Ready(size) => {
let event = Event::SendUdpPacket(&packet);
for l in event_listeners {
l(&event)
}
if size != data.len() {
Poll::Ready(Err(Error::Network(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to send whole udp packet",
))))
} else {
Poll::Ready(Ok(()))
}
}
}
}
/// Returns whether the resender reports itself as full.
pub fn is_send_queue_full(&self) -> bool {
    self.resender.is_full()
}
/// Returns whether the resender reports itself as empty.
pub fn is_send_queue_empty(&self) -> bool {
    self.resender.is_empty()
}
}
/// Drives the connection and yields received items.
///
/// Each poll performs these steps in order and returns at the first one that
/// produces a result:
/// 1. write queued packets from `acks_to_send` to the socket,
/// 2. end the stream if the resender is disconnected and nothing is left to
///    write,
/// 3. let the resender resend packets and send pings,
/// 4. yield an already queued stream item,
/// 5. read from the socket.
impl Stream for Connection {
    type Item = Result<StreamItem>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        if let Err(e) = this.poll_send_acks(cx) {
            return Poll::Ready(Some(Err(e)));
        }

        let disconnected = this.resender.get_state() == ResenderState::Disconnected;
        if disconnected && this.acks_to_send.is_empty() {
            return Poll::Ready(None);
        }

        if let Err(e) = Resender::poll_resend(&mut *this, cx) {
            return Poll::Ready(Some(Err(e)));
        }
        if let Err(e) = Resender::poll_ping(&mut *this, cx) {
            return Poll::Ready(Some(Err(e)));
        }

        match this.stream_items.pop_front() {
            Some(item) => Poll::Ready(Some(Ok(item))),
            None => this.poll_incoming_udp_packet(cx).map(Some),
        }
    }
}