use futures::{
future::BoxFuture,
stream::{BoxStream, Stream, StreamExt},
Future, FutureExt,
};
use quinn::Connecting;
use std::{
fmt::Debug,
hash::Hash,
net::SocketAddr,
pin::Pin,
sync::{Arc, TryLockError},
task::{Context, Poll},
};
use tokio_stream::StreamMap;
use socket2;
use crate::common::protocol::tunnel::{BoxedTunnel, TunnelSide};
/// A QUIC listener that exposes incoming connection attempts as a `Stream`.
///
/// Each item produced by the `Stream` implementation is a `quinn::Connecting`
/// paired with `TunnelSide::Listen`.
pub struct QuinnListenEndpoint {
/// The address handed to the constructor; returned as-is by `bind_address`.
bind_addr: SocketAddr,
/// The underlying quinn endpoint; a clone of it is moved into the accept future.
endpoint: Pin<Box<quinn::Endpoint>>,
/// The accept future currently being polled, if one has been started and not yet resolved.
accepting: Option<BoxFuture<'static, Option<Connecting>>>,
/// Set once an accept future resolves to `None`; afterwards the stream only yields `None`.
is_terminated: bool,
}
impl QuinnListenEndpoint {
    /// Creates a server endpoint on `bind_addr` via `quinn::Endpoint::server`.
    ///
    /// # Errors
    /// Propagates the I/O error returned by quinn when the endpoint cannot be created.
    pub fn bind(
        bind_addr: SocketAddr,
        quinn_config: quinn::ServerConfig,
    ) -> Result<Self, std::io::Error> {
        let endpoint = quinn::Endpoint::server(quinn_config, bind_addr)?;
        Ok(Self::from_endpoint(bind_addr, endpoint))
    }

    /// Returns the address that was passed to the constructor.
    ///
    /// NOTE(review): this is the requested address, not one read back from the socket,
    /// so a port of `0` is reported as `0` — confirm callers do not rely on the real port.
    pub fn bind_address(&self) -> SocketAddr {
        self.bind_addr
    }

    /// Creates a server endpoint on `bind_addr` over a UDP socket whose kernel buffer
    /// sizes are adjusted first.
    ///
    /// A size of `0` leaves the corresponding buffer at whatever the socket was created
    /// with. Non-zero sizes are applied through `socket2` before the socket is handed
    /// to quinn.
    ///
    /// # Errors
    /// Returns an error if binding the socket fails, if either buffer size cannot be
    /// set, if quinn finds no async runtime, or if quinn fails to create the endpoint.
    pub fn bind_with_buffer_sizes(
        bind_addr: SocketAddr,
        quinn_config: quinn::ServerConfig,
        recv_socket_buffer_size: usize,
        send_socket_buffer_size: usize,
    ) -> Result<Self, std::io::Error> {
        let udp_socket = std::net::UdpSocket::bind(bind_addr)?;
        {
            // Borrowed view of the socket; scoped so the borrow ends before the socket
            // is moved into the endpoint below.
            let sock_ref = socket2::SockRef::from(&udp_socket);
            if recv_socket_buffer_size != 0 {
                sock_ref.set_recv_buffer_size(recv_socket_buffer_size)?;
            }
            if send_socket_buffer_size != 0 {
                sock_ref.set_send_buffer_size(send_socket_buffer_size)?;
            }
        }
        let runtime = match quinn::default_runtime() {
            Some(runtime) => runtime,
            None => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "no async runtime found",
                ))
            }
        };
        let endpoint = quinn::Endpoint::new(
            quinn::EndpointConfig::default(),
            Some(quinn_config),
            udp_socket,
            runtime,
        )?;
        Ok(Self::from_endpoint(bind_addr, endpoint))
    }

    /// Wraps an already-created endpoint in a listener with no accept in flight.
    fn from_endpoint(bind_addr: SocketAddr, endpoint: quinn::Endpoint) -> Self {
        Self {
            bind_addr,
            endpoint: Box::pin(endpoint),
            accepting: None,
            is_terminated: false,
        }
    }
}
/// Yields one `(Connecting, TunnelSide::Listen)` pair per accepted connection attempt.
///
/// At most one accept future is kept in flight between polls. Once an accept future
/// resolves to `None` (presumably because the endpoint was closed — confirm against
/// quinn's documentation) the stream is marked terminated and yields `None` from then on.
impl Stream for QuinnListenEndpoint
where
    Self: Send + Unpin,
{
    type Item = (quinn::Connecting, TunnelSide);

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `Self: Unpin`, so plain mutable access to the fields is fine.
        let this = self.get_mut();

        if this.is_terminated {
            // Make sure no stale accept future outlives termination.
            this.accepting = None;
            return Poll::Ready(None);
        }

        // Only clone the endpoint handle when a new accept future actually has to be
        // created, rather than on every poll.
        let endpoint = &this.endpoint;
        let accepting = this.accepting.get_or_insert_with(|| {
            let endpoint = endpoint.clone();
            async move { endpoint.accept().await }.boxed()
        });

        // Returns `Poll::Pending` early and keeps the in-flight future for the next poll.
        let accepted = futures::ready!(Future::poll(accepting.as_mut(), cx));

        // The future has completed; it must not be polled again.
        this.accepting = None;

        match accepted {
            Some(connecting) => Poll::Ready(Some((connecting, TunnelSide::Listen))),
            None => {
                this.is_terminated = true;
                Poll::Ready(None)
            }
        }
    }
}
/// A boxed stream of `StreamItem` values tagged with an identifier of type `Id`.
pub struct NamedBoxedStream<Id, StreamItem> {
/// Identifier attached to the stream; shown in the `Debug` output.
id: Id,
/// The wrapped, type-erased stream that produces the items.
stream: BoxStream<'static, StreamItem>,
}
impl<Id, StreamItem> NamedBoxedStream<Id, StreamItem> {
    /// Boxes `stream` and tags it with `id`.
    ///
    /// Only `Send + 'static` is required of the stream, which is all that
    /// `StreamExt::boxed` needs; streams that are also `Sync` are still accepted.
    pub fn new<TStream>(id: Id, stream: TStream) -> Self
    where
        TStream: Stream<Item = StreamItem> + Send + 'static,
    {
        Self::new_pre_boxed(id, stream.boxed())
    }

    /// Tags an already-boxed stream with `id` without re-boxing it.
    pub fn new_pre_boxed(id: Id, stream: BoxStream<'static, StreamItem>) -> Self {
        Self { id, stream }
    }
}
/// Forwards polling and size hints straight to the wrapped boxed stream; the `id`
/// plays no part in the produced items.
impl<Id, StreamItem> Stream for NamedBoxedStream<Id, StreamItem>
where
    Id: Unpin,
{
    type Item = StreamItem;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Id: Unpin` makes the whole wrapper `Unpin`, so it can be unpinned safely.
        let inner = &mut self.get_mut().stream;
        inner.as_mut().poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let inner = &self.stream;
        inner.size_hint()
    }
}
/// Debug output shows only the `id`; the boxed stream is opaque and is elided
/// (rendered as `..`).
impl<Id, StreamItem> std::fmt::Debug for NamedBoxedStream<Id, StreamItem>
where
    Id: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report the real type name so debug logs point at the right type.
        f.debug_struct("NamedBoxedStream")
            .field("id", &self.id)
            .finish_non_exhaustive()
    }
}
/// A `DynamicStreamSet` whose items are tunnels; the item type defaults to
/// `BoxedTunnel<'static>`.
pub type DynamicConnectionSet<Id, TunnelType = BoxedTunnel<'static>> =
DynamicStreamSet<Id, TunnelType>;
/// A set of identified streams that can be attached and detached through a shared
/// reference and polled as a single stream of `(Id, item)` pairs.
pub struct DynamicStreamSet<Id, TStream> {
/// Map of attached streams keyed by id, shared with any handles created from this set.
streams: Arc<std::sync::Mutex<StreamMap<Id, NamedBoxedStream<Id, TStream>>>>,
}
/// A handle that shares the stream map of a `DynamicStreamSet` and can be polled as
/// a stream of `(Id, item)` pairs.
pub struct DynamicStreamSetHandle<Id, TStream> {
/// The same shared map that the originating `DynamicStreamSet` holds.
streams: Arc<std::sync::Mutex<StreamMap<Id, NamedBoxedStream<Id, TStream>>>>,
}
impl<Id, StreamItem> DynamicStreamSet<Id, StreamItem> {
    /// Creates an empty set with no streams attached.
    pub fn new() -> Self {
        Self {
            streams: Arc::new(std::sync::Mutex::new(StreamMap::new())),
        }
    }

    /// Attaches `source` under its own id, returning the stream previously registered
    /// under that id, if any.
    ///
    /// NOTE(review): nothing here wakes a task that is already waiting on the set, so
    /// a newly attached stream appears to be first polled on the next wake-up — confirm
    /// this is acceptable for callers.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn attach(
        &self,
        source: NamedBoxedStream<Id, StreamItem>,
    ) -> Option<NamedBoxedStream<Id, StreamItem>>
    where
        Id: Clone + Hash + Eq,
    {
        // The id is stored both as the map key and inside the value, hence one clone.
        let key = source.id.clone();
        let mut streams = self.streams.lock().expect("Mutex poisoned");
        streams.insert(key, source)
    }

    /// Wraps an already-boxed stream with `id` and attaches it, returning the stream
    /// previously registered under that id, if any.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn attach_stream(
        &self,
        id: Id,
        source: BoxStream<'static, StreamItem>,
    ) -> Option<NamedBoxedStream<Id, StreamItem>>
    where
        Id: Clone + Hash + Eq,
    {
        // `id` is owned and not needed afterwards, so it is moved rather than cloned.
        self.attach(NamedBoxedStream::new_pre_boxed(id, source))
    }

    /// Removes and returns the stream registered under `id`, or `None` if there is no
    /// such stream (including when it was already removed after running to completion).
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn detach(&self, id: &Id) -> Option<NamedBoxedStream<Id, StreamItem>>
    where
        Id: Hash + Eq,
    {
        let mut streams = self.streams.lock().expect("Mutex poisoned");
        streams.remove(id)
    }

    /// Creates a pollable handle that shares this set's streams.
    pub fn handle(&self) -> DynamicStreamSetHandle<Id, StreamItem> {
        DynamicStreamSetHandle {
            streams: Arc::clone(&self.streams),
        }
    }

    /// Converts this set into a pollable handle over the same streams.
    pub fn into_handle(self) -> DynamicStreamSetHandle<Id, StreamItem> {
        DynamicStreamSetHandle {
            streams: self.streams,
        }
    }

    /// Shared polling routine behind the `Stream` impls of the set and its handle.
    ///
    /// The mutex is only tried, never blocked on: if it is currently held elsewhere,
    /// the task wakes itself and returns `Poll::Pending` so that the poll is retried.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn poll_next(
        streams: &std::sync::Mutex<StreamMap<Id, NamedBoxedStream<Id, StreamItem>>>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(Id, StreamItem)>>
    where
        Id: Clone + Unpin,
    {
        let mut streams = match streams.try_lock() {
            Ok(guard) => guard,
            Err(TryLockError::WouldBlock) => {
                // No waker can be registered with a std mutex; request an immediate re-poll.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            // Same message that `expect("Lock poisoned")` on the error would produce.
            Err(TryLockError::Poisoned(poison)) => panic!("Lock poisoned: {:?}", poison),
        };
        Stream::poll_next(Pin::new(&mut *streams), cx)
    }
}
impl<Id, StreamItem> Stream for DynamicStreamSet<Id, StreamItem>
where
Id: Clone + Unpin,
{
type Item = (Id, StreamItem);
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Self::poll_next(&*self.streams, cx)
}
}
/// Polls the streams of the originating set through the shared map, yielding
/// `(id, item)` pairs.
impl<Id, StreamItem> Stream for DynamicStreamSetHandle<Id, StreamItem>
where
    Id: Clone + Unpin,
{
    type Item = (Id, StreamItem);

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Deref-coerce the `Arc` to the mutex it guards and reuse the set's polling helper.
        let shared: &std::sync::Mutex<_> = &self.streams;
        DynamicStreamSet::<Id, StreamItem>::poll_next(shared, cx)
    }
}
#[cfg(test)]
mod tests {
    use super::{DynamicStreamSet, QuinnListenEndpoint};
    use crate::common::protocol::tunnel::{quinn_tunnel::QuinnTunnel, IntoTunnel};
    use futures::{stream, FutureExt, StreamExt};
    use std::collections::HashSet;
    use std::iter::FromIterator;

    /// Builds a boxed stream that yields exactly one character.
    fn one_char(c: char) -> stream::BoxStream<'static, char> {
        stream::iter(vec![c]).boxed()
    }

    /// Compile-time check only: an established connection taken from the endpoint's
    /// items, paired with its side, must satisfy `IntoTunnel<Tunnel = QuinnTunnel>`.
    #[allow(dead_code)]
    fn static_test_endpoint_items_assign_tunnel_id(
        mut endpoint: QuinnListenEndpoint,
    ) -> Option<impl IntoTunnel<Tunnel = QuinnTunnel>> {
        let (connecting, side) = endpoint.next().now_or_never()??;
        connecting
            .now_or_never()?
            .ok()
            .map(|connection| (connection, side))
    }

    /// Attaching, replacing and detaching must hand back the expected streams.
    #[tokio::test]
    async fn add_and_remove() {
        let set = DynamicStreamSet::<u32, char>::new();

        let first = set.attach_stream(1u32, one_char('a'));
        assert!(first.is_none(), "Must attach to blank");
        let second = set.attach_stream(2u32, one_char('b'));
        assert!(second.is_none(), "Must attach to non-blank with new key");

        let replaced_b = set
            .attach_stream(2u32, one_char('c'))
            .expect("Must overwrite keys and return an old one");
        let detached_a = set.detach(&1u32).expect("Must detach fresh keys by ID");
        let detached_c = set.detach(&2u32).expect("Must detach replaced keys by ID");

        let expectations = vec![
            (detached_a, 1u32, 'a', "Fresh-key stream identity mismatch"),
            (replaced_b, 2u32, 'b', "Replaced stream identity mismatch"),
            (detached_c, 2u32, 'c', "Replacement stream identity mismatch"),
        ];
        for (mut named, expected_id, expected_item, message) in expectations {
            assert_eq!(named.id, expected_id);
            assert_eq!(
                named.stream.next().await.expect("Must have item"),
                expected_item,
                "{}",
                message
            );
        }
    }

    /// Polling the set yields the items of the streams currently attached.
    #[tokio::test]
    async fn poll_contents() {
        let set = DynamicStreamSet::<u32, char>::new();

        let first = set.attach_stream(1u32, one_char('a'));
        assert!(first.is_none(), "Must attach to blank");
        let second = set.attach_stream(2u32, one_char('b'));
        assert!(second.is_none(), "Must attach to non-blank with new key");
        set.attach_stream(2u32, one_char('c'))
            .expect("Must replace existing keys");

        let expected = HashSet::from_iter(vec![(1, 'a'), (2, 'c')]);
        let results: HashSet<_> = set.collect().await;
        assert_eq!(results, expected);
    }

    /// A stream polled to completion through a handle is no longer attached afterwards.
    #[tokio::test]
    async fn end_of_stream_removal() {
        use std::sync::Arc;

        let set = Arc::new(DynamicStreamSet::<u32, i32>::new());
        let numbers = stream::iter(vec![1, 2, 3]).boxed();
        assert!(
            set.attach_stream(1u32, numbers).is_none(),
            "Must attach to blank"
        );

        let collected: Vec<_> = set.handle().collect().await;
        assert_eq!(collected, vec![(1, 1), (1, 2), (1, 3)]);

        let leftover = set.detach(&1u32);
        assert!(
            leftover.is_none(),
            "Must have already detached if polled to empty"
        );
    }
}