mod builder;
use std::{
fmt::{self, Debug, Formatter},
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
pub use builder::Builder;
use flume::{r#async::RecvStream, Sender};
use futures_channel::oneshot::{self, Receiver};
use futures_util::{
stream::{FusedStream, Stream},
StreamExt,
};
use quinn::{ClientConfig, ServerConfig, VarInt};
#[cfg(feature = "dns")]
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};
use super::Task;
use crate::{
certificate::{Certificate, PrivateKey},
Connecting, Error, Result,
};
/// A QUIC endpoint that can open outgoing connections and, when configured with a server
/// configuration, yields incoming connections through its [`Stream`] implementation.
///
/// NOTE(review): `Clone` clones the flume `RecvStream`, which presumably shares one channel
/// between all clones, so each incoming connection would reach only one clone — confirm this
/// is intended.
#[derive(Clone)]
pub struct Endpoint {
/// The wrapped `quinn` endpoint used for connecting, closing and idling.
endpoint: quinn::Endpoint,
/// Receives incoming [`Connecting`]s forwarded by `task` (see `Endpoint::incoming`).
receiver: RecvStream<'static, Connecting>,
/// Task forwarding incoming connections into `receiver`; `Task::empty()` when the endpoint
/// was created without a server configuration.
task: Task<()>,
}
impl Debug for Endpoint {
    /// Formats the endpoint for diagnostics.
    ///
    /// The struct is reported under its real name, `Endpoint`. The receiver is rendered as a
    /// placeholder string naming its element type, `Connecting`, instead of being formatted
    /// itself.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("Endpoint")
            .field("endpoint", &self.endpoint)
            .field("receiver", &"RecvStream<Connecting>")
            .field("task", &self.task)
            .finish()
    }
}
impl Endpoint {
#[must_use]
pub fn builder() -> Builder {
Builder::new()
}
fn new(
address: SocketAddr,
client: ClientConfig,
server: Option<ServerConfig>,
) -> Result<Self> {
let mut endpoint_builder = quinn::Endpoint::builder();
let _ = endpoint_builder.default_client_config(client);
let server = server.map_or(false, |server| {
let _ = endpoint_builder.listen(server);
true
});
let (endpoint, incoming) = endpoint_builder.bind(&address).map_err(Error::BindSocket)?;
let (sender, receiver) = flume::unbounded();
let receiver = receiver.into_stream();
let task = if server {
let (shutdown_sender, shutdown_receiver) = oneshot::channel();
Task::new(
Self::incoming(incoming, sender, shutdown_receiver),
shutdown_sender,
)
} else {
Task::empty()
};
Ok(Self {
endpoint,
receiver,
task,
})
}
pub fn new_client(ca: &Certificate) -> Result<Self> {
let mut builder = Builder::new();
let _ = builder.add_ca(ca)?;
builder.build().map_err(|(error, _)| error)
}
pub fn new_server(
port: u16,
certificate: &Certificate,
private_key: &PrivateKey,
) -> Result<Self> {
let mut builder = Builder::new();
#[cfg(not(feature = "test"))]
let _ = builder.set_address(([0; 8], port).into());
#[cfg(feature = "test")]
let _ = builder.set_address(([0, 0, 0, 0, 0, 0, 0, 1], port).into());
let _ = builder.add_key_pair(certificate, private_key)?;
builder.build().map_err(|(error, _)| error)
}
async fn incoming(
mut incoming: quinn::Incoming,
sender: Sender<Connecting>,
mut shutdown: Receiver<()>,
) {
while let Some(connecting) = allochronic_util::select! {
connecting: &mut incoming => connecting,
_ : &mut shutdown => None,
} {
if sender.send(Connecting::new(connecting)).is_err() {
break;
}
}
}
pub fn connect<D: AsRef<str>>(&self, address: SocketAddr, domain: D) -> Result<Connecting> {
let connecting = self
.endpoint
.connect(&address, domain.as_ref())
.map_err(Error::ConnectConfig)?;
Ok(Connecting::new(connecting))
}
#[cfg(feature = "dns")]
#[cfg_attr(doc, doc(cfg(feature = "dns")))]
pub async fn connect_with<S: AsRef<str>>(&self, port: u16, domain: S) -> Result<Connecting> {
let config = ResolverConfig::cloudflare_https();
let opts = ResolverOpts {
validate: true,
..ResolverOpts::default()
};
#[allow(box_pointers)]
let resolver = TokioAsyncResolver::tokio(config, opts)
.map_err(|error| Error::Resolve(Box::new(error)))?;
#[allow(box_pointers)]
let ip = resolver
.lookup_ip(domain.as_ref())
.await
.map_err(|error| Error::Resolve(Box::new(error)))?;
ip.into_iter().next().map_or(Err(Error::NoIp), |ip| {
self.connect(SocketAddr::from((ip, port)), domain.as_ref())
})
}
pub fn local_address(&self) -> Result<SocketAddr> {
self.endpoint.local_addr().map_err(Error::LocalAddress)
}
pub async fn close(&self) {
self.endpoint.close(VarInt::from_u32(0), &[]);
let _result = (&self.task).await;
}
pub async fn close_incoming(&self) -> Result<()> {
self.task.close(()).await
}
pub async fn wait_idle(&self) {
self.endpoint.wait_idle().await;
}
}
impl Stream for Endpoint {
    type Item = Connecting;

    /// Polls for the next incoming connection, yielding `None` once the receiver has
    /// terminated.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Guard clause: a terminated receiver is never polled again.
        if !self.receiver.is_terminated() {
            return self.receiver.poll_next_unpin(cx);
        }
        Poll::Ready(None)
    }
}
impl FusedStream for Endpoint {
    /// Reports termination exactly when the inner connection receiver has terminated.
    fn is_terminated(&self) -> bool {
        let Self { receiver, .. } = self;
        receiver.is_terminated()
    }
}
#[cfg(test)]
mod test {
// Tests that drive a real QUIC handshake between a client and a server endpoint, both built
// from one self-signed certificate generated for the domain "test".
use anyhow::Result;
use super::*;
/// Smoke test: `Endpoint::builder` returns a `Builder`.
#[test]
fn builder() {
let _builder: Builder = Endpoint::builder();
}
/// A client can connect to a server, and the server yields the matching incoming connection.
#[tokio::test]
async fn endpoint() -> Result<()> {
// NOTE(review): likely redundant with `use super::*`, which already brings `StreamExt`
// into scope — harmless either way.
use futures_util::StreamExt;
let (certificate, private_key) = crate::generate_self_signed("test");
let client = Endpoint::new_client(&certificate)?;
// Port 0: the OS picks the port, which is read back via `local_address`.
let mut server = Endpoint::new_server(0, &certificate, &private_key)?;
let _connection = client
.connect(server.local_address()?, "test")?
.accept::<()>()
.await?;
let _connection = server
.next()
.await
.expect("client dropped")
.accept::<()>()
.await?;
Ok(())
}
/// After `close`, new outgoing connections fail with `LocallyClosed` and the server's
/// incoming stream ends.
#[tokio::test]
async fn close() -> Result<()> {
use futures_util::StreamExt;
use quinn::ConnectionError;
let (certificate, private_key) = crate::generate_self_signed("test");
let client = Endpoint::new_client(&certificate)?;
let mut server = Endpoint::new_server(0, &certificate, &private_key)?;
let address = server.local_address()?;
let _connection = client.connect(address, "test")?.accept::<()>().await?;
let _connection = server
.next()
.await
.expect("client dropped")
.accept::<()>()
.await?;
client.close().await;
server.close().await;
// The closed client endpoint refuses to start new connections locally.
assert!(matches!(
client.connect(address, "test")?.accept::<()>().await,
Err(Error::Connecting(ConnectionError::LocallyClosed))
));
assert!(matches!(server.next().await, None));
client.wait_idle().await;
server.wait_idle().await;
Ok(())
}
/// `close_incoming` behaviour:
/// - it fails with `AlreadyClosed` on a client and on a second call;
/// - afterwards, new connections are refused with `CONNECTION_REFUSED` and the incoming
///   stream ends;
/// - the already-established connection can still open and finish a stream.
#[tokio::test]
async fn close_incoming() -> Result<()> {
use futures_util::StreamExt;
use quinn::{ConnectionClose, ConnectionError};
use quinn_proto::TransportErrorCode;
let (certificate, private_key) = crate::generate_self_signed("test");
let client = Endpoint::new_client(&certificate)?;
let mut server = Endpoint::new_server(0, &certificate, &private_key)?;
let address = server.local_address()?;
let client_connection = client.connect(address, "test")?.accept::<()>().await?;
let mut server_connection = server
.next()
.await
.expect("client dropped")
.accept::<()>()
.await?;
// The client never listened, so there is no incoming task to close.
assert!(matches!(
client.close_incoming().await,
Err(Error::AlreadyClosed)
));
server.close_incoming().await?;
// A second close on the server reports it was already closed.
assert!(matches!(
server.close_incoming().await,
Err(Error::AlreadyClosed)
));
// New handshakes are now rejected by the server with an empty reason.
assert!(matches!(
client.connect(address, "test")?.accept::<()>().await,
Err(Error::Connecting(ConnectionError::ConnectionClosed(
ConnectionClose {
error_code: TransportErrorCode::CONNECTION_REFUSED,
frame_type: None,
reason: bytes,
}
))) if bytes.is_empty()
));
assert!(matches!(server.next().await, None));
// The pre-existing connection still works after `close_incoming`.
{
let (sender, _) = client_connection.open_stream::<(), ()>(&()).await?;
let _server_stream = server_connection
.next()
.await
.expect("client dropped")
.accept::<(), ()>();
sender.finish().await?;
}
// Drop both connections so that `wait_idle` below can complete.
drop(client_connection);
drop(server_connection);
client.wait_idle().await;
server.wait_idle().await;
Ok(())
}
/// `wait_idle` returns on both sides once their connections have gone out of scope.
#[tokio::test]
async fn wait_idle() -> Result<()> {
use futures_util::StreamExt;
let (certificate, private_key) = crate::generate_self_signed("test");
let client = Endpoint::new_client(&certificate)?;
let mut server = Endpoint::new_server(0, &certificate, &private_key)?;
// Scope the connections so they are dropped before waiting for idleness.
{
let _connection = client
.connect(server.local_address()?, "test")?
.accept::<()>()
.await?;
let _connection = server
.next()
.await
.expect("client dropped")
.accept::<()>()
.await?;
}
client.wait_idle().await;
server.wait_idle().await;
Ok(())
}
}