use std::collections::VecDeque;
use std::net::SocketAddr;
use std::time::Instant;
use bytes::BytesMut;
use quinn_proto::{
ClientConfig, ConnectionHandle, DatagramEvent, Dir, Event, StreamEvent, StreamId,
};
use slab::Slab;
use crate::config::QuicConfig;
use crate::error::Error;
use crate::event::{QuicConnId, QuicEvent};
/// Sans-I/O QUIC endpoint built on `quinn_proto`.
///
/// Owns the protocol endpoint plus every connection it has created or accepted. Packets to
/// send are buffered in `send_queue`, and application-facing events in `events`; no socket
/// is touched by this type.
pub struct QuicEndpoint {
    /// The underlying quinn-proto endpoint state machine.
    endpoint: quinn_proto::Endpoint,
    /// Tracked connections; the slab key doubles as the public `QuicConnId`.
    connections: Slab<QuicConnection>,
    /// Indexed by `ConnectionHandle.0`; holds the slab key of that connection, or `None`
    /// once it has been removed.
    handle_map: Vec<Option<u32>>,
    /// Application-facing events waiting to be taken via `poll_event`.
    events: VecDeque<QuicEvent>,
    /// Datagrams waiting to be taken via `poll_send`; capped at `send_queue_capacity`.
    send_queue: VecDeque<OutgoingPacket>,
    /// Scratch buffer that connection `poll_transmit` calls write into.
    transmit_buf: Vec<u8>,
    /// Scratch buffer that endpoint-level replies (`handle` / `accept`) are written into.
    response_buf: Vec<u8>,
    /// Local address; its IP is passed to quinn-proto as the local IP of every datagram.
    local_addr: SocketAddr,
    /// Configuration for outbound `connect` calls; without it `connect` fails.
    client_config: Option<ClientConfig>,
    /// Maximum number of packets held in `send_queue`; packets beyond it are dropped.
    send_queue_capacity: usize,
}
/// Per-connection bookkeeping kept next to the quinn-proto connection.
struct QuicConnection {
    /// Handle the quinn-proto endpoint uses to identify this connection.
    handle: ConnectionHandle,
    /// The connection state machine itself.
    conn: quinn_proto::Connection,
    /// Set to `true` when `Event::Connected` is observed.
    // NOTE(review): written but never read in this file — consider removing or exposing it.
    established: bool,
    /// `true` for connections started via `connect`, `false` for accepted ones; selects
    /// between `QuicEvent::Connected` and `QuicEvent::NewConnection` on handshake completion.
    outbound: bool,
}
/// A datagram queued for the caller to put on the wire.
struct OutgoingPacket {
    /// Address the datagram must be sent to.
    destination: SocketAddr,
    /// Complete datagram payload.
    data: Vec<u8>,
}
impl QuicEndpoint {
    /// Creates a sans-I/O endpoint for `local_addr` from `config`.
    ///
    /// The endpoint never touches a socket. Feed received datagrams to
    /// [`handle_datagram`](Self::handle_datagram), send whatever
    /// [`poll_send`](Self::poll_send) yields, and call
    /// [`drive_timers`](Self::drive_timers) regularly — connection timers only fire there.
    pub fn new(config: QuicConfig, local_addr: SocketAddr) -> Self {
        let endpoint = quinn_proto::Endpoint::new(
            config.endpoint_config,
            config.server_config,
            config.allow_mtud,
            config.rng_seed,
        );
        Self {
            endpoint,
            connections: Slab::new(),
            handle_map: Vec::new(),
            events: VecDeque::new(),
            send_queue: VecDeque::new(),
            // 1500 presumably matches a typical Ethernet MTU; the Vecs grow if needed.
            transmit_buf: Vec::with_capacity(1500),
            response_buf: Vec::with_capacity(1500),
            local_addr,
            client_config: config.client_config,
            send_queue_capacity: config.send_queue_capacity,
        }
    }

    /// Processes one UDP datagram `data` received from `peer` at `now`.
    ///
    /// Packets produced in reply are queued for [`poll_send`](Self::poll_send), and any
    /// application events for [`poll_event`](Self::poll_event). The IP of `local_addr` is
    /// reported to quinn-proto as the local IP of the datagram; no ECN information is passed.
    pub fn handle_datagram(&mut self, now: Instant, data: &[u8], peer: SocketAddr) {
        // Replies are read back as `response_buf[..size]`, so the buffer must start empty
        // instead of carrying bytes from an earlier reply.
        self.response_buf.clear();
        let event = self.endpoint.handle(
            now,
            peer,
            Some(self.local_addr.ip()),
            None,
            BytesMut::from(data),
            &mut self.response_buf,
        );
        match event {
            Some(DatagramEvent::ConnectionEvent(ch, event)) => {
                // Events for a handle we no longer track are dropped.
                if let Some(&Some(key)) = self.handle_map.get(ch.0) {
                    let key = key as usize;
                    self.connections[key].conn.handle_event(event);
                    self.poll_connection(key, now);
                }
            }
            Some(DatagramEvent::NewConnection(incoming)) => {
                self.response_buf.clear();
                match self
                    .endpoint
                    .accept(incoming, now, &mut self.response_buf, None)
                {
                    Ok((ch, conn)) => {
                        let key = self.insert_connection(ch, conn, false);
                        self.poll_connection(key, now);
                    }
                    Err(err) => {
                        // The attempt was rejected, but quinn-proto may still have written a
                        // reply for the peer into `response_buf`; it must not be discarded.
                        if let Some(transmit) = err.response {
                            self.queue_response(transmit.destination, transmit.size);
                        }
                    }
                }
            }
            Some(DatagramEvent::Response(transmit)) => {
                self.queue_response(transmit.destination, transmit.size);
            }
            None => {}
        }
    }

    /// Fires every connection timer whose deadline is at or before `now`, then flushes the
    /// packets and events that produces.
    pub fn drive_timers(&mut self, now: Instant) {
        // Snapshot the keys: `poll_connection` may remove entries while we iterate.
        let keys: Vec<usize> = self.connections.iter().map(|(key, _)| key).collect();
        for key in keys {
            // Defensive: skip anything that is no longer present.
            let Some(qc) = self.connections.get_mut(key) else {
                continue;
            };
            if qc.conn.poll_timeout().is_some_and(|deadline| deadline <= now) {
                qc.conn.handle_timeout(now);
                self.poll_connection(key, now);
            }
        }
    }

    /// Takes the oldest pending application event, if any.
    pub fn poll_event(&mut self) -> Option<QuicEvent> {
        self.events.pop_front()
    }

    /// Takes the oldest pending outgoing datagram as `(destination, payload)`, if any.
    pub fn poll_send(&mut self) -> Option<(SocketAddr, Vec<u8>)> {
        self.send_queue
            .pop_front()
            .map(|pkt| (pkt.destination, pkt.data))
    }

    /// Starts a client handshake with `peer`, using `server_name` as the server name.
    ///
    /// The first flight is queued for [`poll_send`](Self::poll_send) immediately.
    /// `QuicEvent::Connected` is emitted once the handshake completes.
    ///
    /// # Errors
    /// [`Error::ConnectionClosed`] when no client configuration was supplied; otherwise any
    /// error from `quinn_proto::Endpoint::connect`, converted into [`Error`].
    pub fn connect(
        &mut self,
        now: Instant,
        peer: SocketAddr,
        server_name: &str,
    ) -> Result<QuicConnId, Error> {
        let client_config = self.client_config.clone().ok_or(Error::ConnectionClosed)?;
        let (ch, conn) = self
            .endpoint
            .connect(now, client_config, peer, server_name)?;
        let key = self.insert_connection(ch, conn, true);
        self.drain_transmits(key, now);
        Ok(QuicConnId(key as u32))
    }

    /// Writes as much of `data` to `stream` as flow control currently allows.
    ///
    /// Returns the number of bytes accepted. Resulting packets are queued right away, so
    /// the data does not wait for an unrelated datagram or timer.
    ///
    /// # Errors
    /// [`Error::InvalidConnection`] for an unknown `conn`; quinn-proto write errors otherwise.
    pub fn stream_send(
        &mut self,
        conn: QuicConnId,
        stream: StreamId,
        data: &[u8],
    ) -> Result<usize, Error> {
        let written = self
            .get_conn_mut(conn)?
            .conn
            .send_stream(stream)
            .write(data)?;
        self.flush_connection(conn);
        Ok(written)
    }

    /// Reads in-order data from `stream` into `buf`.
    ///
    /// Returns `(bytes_read, finished)`. `finished` is `true` once the end of the stream has
    /// been reached; reading stops early (without error) when no more data is available.
    /// Any packets made sendable by consuming data are queued right away.
    ///
    /// # Errors
    /// [`Error::InvalidConnection`] for an unknown `conn`; quinn-proto read errors otherwise.
    pub fn stream_recv(
        &mut self,
        conn: QuicConnId,
        stream: StreamId,
        buf: &mut [u8],
    ) -> Result<(usize, bool), Error> {
        let outcome = Self::read_stream(&mut self.get_conn_mut(conn)?.conn, stream, buf);
        self.flush_connection(conn);
        outcome
    }

    /// Marks the send side of `stream` as finished and queues the resulting packets.
    ///
    /// # Errors
    /// [`Error::InvalidConnection`] for an unknown `conn`; any quinn-proto finish error is
    /// reported as [`Error::ConnectionClosed`].
    pub fn stream_finish(&mut self, conn: QuicConnId, stream: StreamId) -> Result<(), Error> {
        self.get_conn_mut(conn)?
            .conn
            .send_stream(stream)
            .finish()
            .map_err(|_| Error::ConnectionClosed)?;
        self.flush_connection(conn);
        Ok(())
    }

    /// Opens a locally-initiated bidirectional stream.
    ///
    /// Returns `Ok(None)` when quinn-proto declines to open another stream right now
    /// (presumably because the peer's stream limit is reached).
    pub fn open_bi(&mut self, conn: QuicConnId) -> Result<Option<StreamId>, Error> {
        let c = self.get_conn_mut(conn)?;
        Ok(c.conn.streams().open(Dir::Bi))
    }

    /// Opens a locally-initiated unidirectional stream; see [`open_bi`](Self::open_bi).
    pub fn open_uni(&mut self, conn: QuicConnId) -> Result<Option<StreamId>, Error> {
        let c = self.get_conn_mut(conn)?;
        Ok(c.conn.streams().open(Dir::Uni))
    }

    /// Closes `conn` with application error `code` and `reason`; unknown ids are ignored.
    ///
    /// The timestamp comes from `Instant::now()`. Whatever quinn-proto must send after
    /// closing is queued immediately. The connection stays tracked until it has drained.
    pub fn close_connection(&mut self, conn: QuicConnId, code: u32, reason: &[u8]) {
        let Ok(c) = self.get_conn_mut(conn) else {
            return;
        };
        c.conn.close(
            Instant::now(),
            quinn_proto::VarInt::from_u32(code),
            bytes::Bytes::copy_from_slice(reason),
        );
        self.flush_connection(conn);
    }

    /// Number of tracked connections, including ones still draining after close or loss.
    pub fn connection_count(&self) -> usize {
        self.connections.len()
    }

    /// Number of datagrams waiting in the send queue.
    pub fn send_queue_len(&self) -> usize {
        self.send_queue.len()
    }

    /// Current remote address of `conn`, or `None` for an unknown id.
    pub fn remote_addr(&self, conn: QuicConnId) -> Option<SocketAddr> {
        self.connections
            .get(conn.0 as usize)
            .map(|c| c.conn.remote_address())
    }

    /// Stores a new connection and records its handle→key mapping; returns the slab key.
    ///
    /// NOTE(review): slab keys are reused after removal, so a stale `QuicConnId` can later
    /// address a different connection.
    fn insert_connection(
        &mut self,
        ch: ConnectionHandle,
        conn: quinn_proto::Connection,
        outbound: bool,
    ) -> usize {
        let key = self.connections.insert(QuicConnection {
            handle: ch,
            conn,
            established: false,
            outbound,
        });
        let idx = ch.0;
        if idx >= self.handle_map.len() {
            self.handle_map.resize(idx + 1, None);
        }
        self.handle_map[idx] = Some(key as u32);
        key
    }

    /// Looks up a connection by public id, or fails with [`Error::InvalidConnection`].
    fn get_conn_mut(&mut self, conn: QuicConnId) -> Result<&mut QuicConnection, Error> {
        self.connections
            .get_mut(conn.0 as usize)
            .ok_or(Error::InvalidConnection)
    }

    /// Processes connection `conn` after a local API call changed its state.
    ///
    /// Without this, data or close frames queued by the application would stay inside
    /// quinn-proto until the next datagram or expired timer polled the connection.
    /// The caller must have just confirmed that `conn` exists.
    fn flush_connection(&mut self, conn: QuicConnId) {
        self.poll_connection(conn.0 as usize, Instant::now());
    }

    /// Copies readable data from `stream` into `buf`; shared body of `stream_recv`.
    fn read_stream(
        conn: &mut quinn_proto::Connection,
        stream: StreamId,
        buf: &mut [u8],
    ) -> Result<(usize, bool), Error> {
        let mut recv = conn.recv_stream(stream);
        let mut chunks = recv.read(true)?;
        let mut total = 0;
        let mut finished = false;
        while total < buf.len() {
            match chunks.next(buf.len() - total) {
                Ok(Some(chunk)) => {
                    let len = chunk.bytes.len();
                    buf[total..total + len].copy_from_slice(&chunk.bytes);
                    total += len;
                }
                Ok(None) => {
                    finished = true;
                    break;
                }
                // Nothing more buffered right now; return what was read so far.
                Err(quinn_proto::ReadError::Blocked) => break,
                Err(e) => {
                    let _ = chunks.finalize();
                    return Err(Error::Read(e));
                }
            }
        }
        // The "should transmit" hint is not needed: `stream_recv` flushes unconditionally.
        let _ = chunks.finalize();
        Ok((total, finished))
    }

    /// Polls connection `key` for packets until it has none left, queueing each one.
    fn drain_transmits(&mut self, key: usize, now: Instant) {
        loop {
            self.transmit_buf.clear();
            // Ask for one datagram at a time so `size` bytes form exactly one packet.
            let Some(transmit) =
                self.connections[key]
                    .conn
                    .poll_transmit(now, 1, &mut self.transmit_buf)
            else {
                break;
            };
            let data = self.transmit_buf[..transmit.size].to_vec();
            self.queue_packet(transmit.destination, data);
        }
    }

    /// Runs one full processing pass over connection `key`.
    ///
    /// The pass forwards endpoint events, flushes packets and translates connection events
    /// into [`QuicEvent`]s. It then removes the connection once quinn-proto reports it drained.
    fn poll_connection(&mut self, key: usize, now: Instant) {
        self.forward_endpoint_events(key);
        self.drain_transmits(key, now);
        while let Some(event) = self.connections[key].conn.poll() {
            self.push_app_event(key, event);
        }
        // Forward anything raised meanwhile so the endpoint has seen every endpoint event
        // (including the final drained notification) before the connection is dropped.
        self.forward_endpoint_events(key);
        self.drain_transmits(key, now);
        if self.connections[key].conn.is_drained() {
            self.remove_connection(key);
        }
    }

    /// Hands connection `key`'s endpoint events to the endpoint and feeds back its replies.
    fn forward_endpoint_events(&mut self, key: usize) {
        let handle = self.connections[key].handle;
        while let Some(event) = self.connections[key].conn.poll_endpoint_events() {
            if let Some(reply) = self.endpoint.handle_event(handle, event) {
                self.connections[key].conn.handle_event(reply);
            }
        }
    }

    /// Translates one quinn-proto connection event into zero or more [`QuicEvent`]s.
    fn push_app_event(&mut self, key: usize, event: Event) {
        let conn_id = QuicConnId(key as u32);
        match event {
            Event::Connected => {
                self.connections[key].established = true;
                if self.connections[key].outbound {
                    self.events.push_back(QuicEvent::Connected(conn_id));
                } else {
                    self.events.push_back(QuicEvent::NewConnection(conn_id));
                }
            }
            Event::ConnectionLost { reason } => {
                // Report the loss now but keep the connection: it still has to drain, and
                // `poll_connection` removes it once `is_drained()` holds. Removing it here
                // would leave the endpoint's state for this handle behind.
                self.events.push_back(QuicEvent::ConnectionClosed {
                    conn: conn_id,
                    reason,
                });
            }
            Event::Stream(stream_event) => match stream_event {
                StreamEvent::Opened { dir } => {
                    while let Some(stream) = self.connections[key].conn.streams().accept(dir) {
                        self.events.push_back(QuicEvent::StreamOpened {
                            conn: conn_id,
                            stream,
                            bidi: dir == Dir::Bi,
                        });
                    }
                }
                StreamEvent::Readable { id } => {
                    self.events.push_back(QuicEvent::StreamReadable {
                        conn: conn_id,
                        stream: id,
                    });
                }
                StreamEvent::Writable { id } => {
                    self.events.push_back(QuicEvent::StreamWritable {
                        conn: conn_id,
                        stream: id,
                    });
                }
                StreamEvent::Finished { id } => {
                    self.events.push_back(QuicEvent::StreamFinished {
                        conn: conn_id,
                        stream: id,
                    });
                }
                // Not surfaced to the application.
                StreamEvent::Stopped { .. } | StreamEvent::Available { .. } => {}
            },
            // Not surfaced to the application.
            Event::HandshakeDataReady | Event::DatagramReceived | Event::DatagramsUnblocked => {}
        }
    }

    /// Drops connection `key` and clears its handle mapping; panics if `key` is absent.
    fn remove_connection(&mut self, key: usize) {
        let qc = self.connections.remove(key);
        let idx = qc.handle.0;
        if idx < self.handle_map.len() {
            self.handle_map[idx] = None;
        }
    }

    /// Queues the first `size` bytes of `response_buf` as a packet for `destination`.
    fn queue_response(&mut self, destination: SocketAddr, size: usize) {
        let data = self.response_buf[..size].to_vec();
        self.queue_packet(destination, data);
    }

    /// Appends a packet to the send queue, dropping it once `send_queue_capacity` is reached.
    ///
    /// NOTE(review): dropping presumably relies on QUIC's loss recovery as back-pressure;
    /// confirm that is the intended policy.
    fn queue_packet(&mut self, destination: SocketAddr, data: Vec<u8>) {
        if self.send_queue.len() < self.send_queue_capacity {
            self.send_queue
                .push_back(OutgoingPacket { destination, data });
        }
    }
}