use std::{
any::Any,
future::{Future, IntoFuture},
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::Poll,
};
use ed25519_dalek::{VerifyingKey, pkcs8::DecodePublicKey};
use futures_util::{FutureExt, future::Shared};
use iroh_base::{EndpointId, RelayUrl};
use n0_error::{e, stack_error};
use n0_future::{TryFutureExt, future::Boxed as BoxFuture, time::Duration};
use noq::WeakConnectionHandle;
use pin_project::pin_project;
use tracing::{event, warn};
use super::quic::DecryptedInitial;
use crate::{
Endpoint,
endpoint::{
AfterHandshakeOutcome,
quic::{
AcceptBi, AcceptUni, ConnectionError, ConnectionStats, Controller,
ExportKeyingMaterialError, OpenBi, OpenUni, PathId, ReadDatagram, SendDatagram,
SendDatagramError, ServerConfig, Side, VarInt,
},
},
socket::{
RemoteStateActorStoppedError,
remote_map::{PathInfo, PathWatchable, PathWatcher},
transports,
},
};
/// The remote address of an incoming connection, derived from the transport-level
/// address the endpoint reported for it.
///
/// Can be built from a [`transports::Addr`] and converted into an
/// [`iroh_base::TransportAddr`]; the latter conversion drops the relay variant's
/// `endpoint_id`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum IncomingAddr {
/// A direct IP socket address.
Ip(SocketAddr),
/// The connection arrived via a relay server.
Relay {
/// URL of the relay server.
url: RelayUrl,
/// Endpoint id carried alongside the relay URL by `transports::Addr::Relay`
/// (presumably the remote endpoint's id — confirm).
endpoint_id: EndpointId,
},
/// An address from a custom transport.
Custom(iroh_base::CustomAddr),
}
impl From<IncomingAddr> for iroh_base::TransportAddr {
fn from(addr: IncomingAddr) -> Self {
match addr {
IncomingAddr::Ip(addr) => Self::Ip(addr),
IncomingAddr::Relay { url, .. } => Self::Relay(url),
IncomingAddr::Custom(addr) => Self::Custom(addr),
}
}
}
impl From<transports::Addr> for IncomingAddr {
fn from(addr: transports::Addr) -> Self {
match addr {
transports::Addr::Ip(addr) => Self::Ip(addr),
transports::Addr::Relay(url, endpoint_id) => Self::Relay { url, endpoint_id },
transports::Addr::Custom(addr) => Self::Custom(addr),
}
}
}
/// Future resolving to the next [`Incoming`] connection attempt on an endpoint.
///
/// Wraps a `noq::Accept` and attaches a clone of the owning [`Endpoint`] to every
/// `Incoming` it yields. Resolves to `None` when the underlying `noq::Accept` does.
#[derive(derive_more::Debug)]
#[pin_project]
pub struct Accept<'a> {
/// The wrapped noq accept future (structurally pinned).
#[pin]
#[debug("noq::Accept")]
pub(crate) inner: noq::Accept<'a>,
/// Endpoint cloned into each yielded `Incoming`.
pub(crate) ep: Endpoint,
}
impl Future for Accept<'_> {
    type Output = Option<Incoming>;

    /// Polls the wrapped noq accept future. When an incoming attempt arrives, it is
    /// paired with the endpoint and an `iroh::_events::conn::incoming` debug event is
    /// emitted before it is returned.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let Some(noq_incoming) = std::task::ready!(projected.inner.poll(cx)) else {
            // The underlying accept stream is exhausted.
            return Poll::Ready(None);
        };
        let incoming = Incoming {
            inner: noq_incoming,
            ep: projected.ep.clone(),
        };
        event!(
            target: "iroh::_events::conn::incoming",
            tracing::Level::DEBUG,
            remote_addr = ?incoming.remote_addr(),
        );
        Poll::Ready(Some(incoming))
    }
}
/// An incoming connection attempt, paired with the [`Endpoint`] that received it.
///
/// Consume it with [`Incoming::accept`], [`Incoming::accept_with`],
/// [`Incoming::refuse`], [`Incoming::retry`] or [`Incoming::ignore`], or `.await` it
/// directly to accept it and obtain a [`Connection`].
#[derive(Debug)]
pub struct Incoming {
inner: noq::Incoming,
ep: Endpoint,
}
impl Incoming {
    /// Accepts the attempt using the endpoint's default server configuration.
    ///
    /// # Errors
    /// Returns the [`ConnectionError`] reported by `noq::Incoming::accept`.
    pub fn accept(self) -> Result<Accepting, ConnectionError> {
        let Self { inner, ep } = self;
        inner.accept().map(move |connecting| Accepting::new(connecting, ep))
    }

    /// Accepts the attempt using a custom [`ServerConfig`].
    ///
    /// # Errors
    /// Returns the [`ConnectionError`] reported by `noq::Incoming::accept_with`.
    pub fn accept_with(
        self,
        server_config: Arc<ServerConfig>,
    ) -> Result<Accepting, ConnectionError> {
        let Self { inner, ep } = self;
        inner
            .accept_with(server_config.to_inner_arc())
            .map(move |connecting| Accepting::new(connecting, ep))
    }

    /// Refuses the attempt (delegates to `noq::Incoming::refuse`).
    pub fn refuse(self) {
        let Self { inner, .. } = self;
        inner.refuse()
    }

    /// Asks the remote to retry with address validation.
    ///
    /// # Errors
    /// Fails with [`RetryError`] when noq refuses the retry; the error can be turned
    /// back into the `Incoming` via [`RetryError::into_incoming`].
    // `RetryError` carries the whole noq retry error plus the `Endpoint` so the caller
    // can recover the `Incoming`, which makes the `Err` variant large.
    #[allow(clippy::result_large_err)]
    pub fn retry(self) -> Result<(), RetryError> {
        let Self { inner, ep } = self;
        inner.retry().map_err(move |err| e!(RetryError { err, ep }))
    }

    /// Ignores the attempt (delegates to `noq::Incoming::ignore`).
    pub fn ignore(self) {
        let Self { inner, .. } = self;
        inner.ignore()
    }

    /// Local IP the attempt was received on, if known.
    pub fn local_ip(&self) -> Option<IpAddr> {
        let inner = &self.inner;
        inner.local_ip()
    }

    /// Remote address of the attempt, mapped through the endpoint's transports.
    pub fn remote_addr(&self) -> IncomingAddr {
        let raw_addr = self.inner.remote_address();
        self.ep.to_transport_addr(raw_addr).into()
    }

    /// Whether noq considers the remote address validated.
    pub fn remote_addr_validated(&self) -> bool {
        let inner = &self.inner;
        inner.remote_address_validated()
    }

    /// Decrypts the initial packet of the attempt, if possible.
    pub fn decrypt(&self) -> Option<DecryptedInitial> {
        let inner = &self.inner;
        inner.decrypt()
    }
}
impl IntoFuture for Incoming {
    type Output = Result<Connection, ConnectingError>;
    type IntoFuture = IncomingFuture;

    /// Accepts the attempt with default settings. The future completes the QUIC
    /// handshake, then authenticates and registers the connection with the endpoint.
    fn into_future(self) -> Self::IntoFuture {
        let Self { inner, ep } = self;
        let handshake_then_register = async move {
            let noq_conn = inner.await?;
            conn_from_noq_conn(noq_conn, &ep)?.await
        };
        IncomingFuture(Box::pin(handshake_then_register))
    }
}
/// Error returned by [`Incoming::retry`].
///
/// As its message states, this is raised when `retry()` is called on an `Incoming`
/// whose remote address is already validated. Recover the `Incoming` with
/// [`RetryError::into_incoming`].
#[stack_error(derive, add_meta, from_sources)]
#[error("retry() with validated Incoming")]
pub struct RetryError {
/// The underlying noq error, which still owns the noq `Incoming`.
err: noq::RetryError,
/// Endpoint needed to rebuild the iroh `Incoming`.
ep: Endpoint,
}
impl RetryError {
    /// Reassembles the [`Incoming`] that could not be retried, so the caller can
    /// accept, refuse or ignore it instead.
    pub fn into_incoming(self) -> Incoming {
        let Self { err, ep, .. } = self;
        let inner = err.into_incoming();
        Incoming { inner, ep }
    }
}
/// Future obtained by `.await`ing an [`Incoming`]: completes the handshake and
/// registers the connection with the endpoint.
#[derive(derive_more::Debug)]
#[debug("IncomingFuture")]
pub struct IncomingFuture(BoxFuture<Result<Connection, ConnectingError>>);
impl Future for IncomingFuture {
    type Output = Result<Connection, ConnectingError>;

    /// Forwards polling to the boxed inner future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let boxed = &mut self.0;
        boxed.as_mut().poll(cx)
    }
}
/// Reads the negotiated ALPN from an established noq connection.
///
/// Returns `None` in three cases: no handshake data is available yet, the handshake
/// data is not rustls `HandshakeData`, or no protocol was recorded.
fn alpn_from_noq_conn(conn: &noq::Connection) -> Option<Vec<u8>> {
    conn.handshake_data()?
        .downcast::<noq::crypto::rustls::HandshakeData>()
        .ok()
        .and_then(|rustls_data| rustls_data.protocol)
}
/// Waits for handshake data on a connection still being established and returns
/// its negotiated ALPN.
///
/// # Errors
/// - [`AlpnError::ConnectionError`] if waiting for the handshake data fails.
/// - [`AlpnError::UnknownHandshake`] if the data is not rustls `HandshakeData`.
/// - [`AlpnError::Unavailable`] if no protocol was recorded.
async fn alpn_from_noq_connecting(conn: &mut noq::Connecting) -> Result<Vec<u8>, AlpnError> {
    let handshake = conn.handshake_data().await?;
    let Ok(rustls_data) = handshake.downcast::<noq::crypto::rustls::HandshakeData>() else {
        return Err(e!(AlpnError::UnknownHandshake));
    };
    rustls_data
        .protocol
        .ok_or_else(|| e!(AlpnError::Unavailable))
}
/// Failure to extract the authenticated remote identity or ALPN from a connection
/// after its handshake.
#[stack_error(add_meta, derive, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Clone)]
pub enum AuthenticationError {
/// The remote endpoint id could not be derived from the peer's certificate.
#[error(transparent)]
RemoteId { source: RemoteEndpointIdError },
/// No ALPN protocol could be read from the handshake data.
#[error("no ALPN provided")]
NoAlpn {},
}
/// Turns a handshaken noq connection into an iroh [`Connection`].
///
/// The synchronous part extracts the remote endpoint id and ALPN and emits an
/// `iroh::_events::conn::connected` debug event. The returned future then registers
/// the connection with the endpoint and runs the after-handshake hook.
///
/// # Errors
/// Returns immediately if the remote id or ALPN cannot be extracted. The error is
/// [`ConnectingError::ConnectionError`] when the connection already has a close
/// reason, and [`ConnectingError::HandshakeFailure`] otherwise. The future fails if
/// registration fails. It also fails with [`ConnectingError::LocallyRejected`]
/// (after closing the connection) when the hook rejects it.
fn conn_from_noq_conn(
    conn: noq::Connection,
    ep: &Endpoint,
) -> Result<
    impl Future<Output = Result<Connection, ConnectingError>> + Send + 'static,
    ConnectingError,
> {
    let info = static_info_from_conn(&conn).map_err(|auth_err| {
        // If the connection has already been closed, report its close reason rather
        // than the authentication error.
        match conn.close_reason() {
            Some(conn_err) => e!(ConnectingError::ConnectionError { source: conn_err }),
            None => e!(ConnectingError::HandshakeFailure { source: auth_err }),
        }
    })?;

    event!(
        target: "iroh::_events::conn::connected",
        tracing::Level::DEBUG,
        conn_id = conn.stable_id(),
        side = ?conn.side(),
        remote_id = %info.endpoint_id.fmt_short(),
        alpn = %String::from_utf8_lossy(&info.alpn),
    );

    // Registration is initiated now; its future is awaited inside the returned
    // future so the caller decides when to drive it.
    let registration = ep
        .inner
        .register_connection(info.endpoint_id, conn.weak_handle());
    let endpoint_inner = ep.inner.clone();

    Ok(async move {
        let paths = registration.await?;
        let connection = Connection {
            inner: conn,
            data: HandshakeCompletedData { info, paths },
        };
        if let AfterHandshakeOutcome::Reject { error_code, reason } = endpoint_inner
            .hooks
            .after_handshake(&connection.to_info())
            .await
        {
            connection.close(error_code, &reason);
            return Err(e!(ConnectingError::LocallyRejected));
        }
        Ok(connection)
    })
}
/// Collects the remote endpoint id and negotiated ALPN of a handshaken connection.
///
/// # Errors
/// [`AuthenticationError::RemoteId`] if the remote id cannot be derived, and
/// [`AuthenticationError::NoAlpn`] if no ALPN is available.
fn static_info_from_conn(conn: &noq::Connection) -> Result<StaticInfo, AuthenticationError> {
    let endpoint_id = remote_id_from_noq_conn(conn)?;
    match alpn_from_noq_conn(conn) {
        Some(alpn) => Ok(StaticInfo { endpoint_id, alpn }),
        None => Err(e!(AuthenticationError::NoAlpn)),
    }
}
/// Derives the remote [`EndpointId`] from the peer's TLS certificate chain.
///
/// The peer must have presented exactly one certificate, and its DER bytes must
/// decode (via `VerifyingKey::from_public_key_der`) into an ed25519 public key.
///
/// # Errors
/// Returns [`RemoteEndpointIdError`] when the peer identity is missing, has an
/// unexpected type, contains more or fewer than one certificate, or cannot be
/// decoded. Every failure is logged at `warn` level so it can be diagnosed.
fn remote_id_from_noq_conn(conn: &noq::Connection) -> Result<EndpointId, RemoteEndpointIdError> {
    let Some(data) = conn.peer_identity() else {
        warn!("no peer certificate found");
        return Err(RemoteEndpointIdError::new());
    };
    let certs = match data.downcast::<Vec<rustls::pki_types::CertificateDer>>() {
        Ok(certs) => certs,
        Err(err) => {
            warn!("invalid peer certificate: {:?}", err);
            return Err(RemoteEndpointIdError::new());
        }
    };
    let [cert] = certs.as_slice() else {
        warn!(
            "expected a single peer certificate, but {} found",
            certs.len()
        );
        return Err(RemoteEndpointIdError::new());
    };
    // Previously this failure was swallowed silently while every other path logged;
    // log it too so a malformed peer key is diagnosable.
    let verifying_key = VerifyingKey::from_public_key_der(cert).map_err(|err| {
        warn!("peer certificate does not contain a valid ed25519 public key: {err:?}");
        RemoteEndpointIdError::new()
    })?;
    Ok(EndpointId::from_verifying_key(verifying_key))
}
/// An outgoing connection whose handshake is still in progress.
///
/// Awaiting it drives the QUIC handshake, then authenticates and registers the
/// connection with the endpoint, yielding a [`Connection`]. Use
/// [`Connecting::into_0rtt`] to attempt sending 0-RTT data instead.
#[derive(derive_more::Debug)]
pub struct Connecting {
inner: noq::Connecting,
/// Registration future; `None` until the QUIC handshake has finished.
#[debug("{}", register_with_socket.as_ref().map(|_| "Some(RegisterWithSocketFut)").unwrap_or("None"))]
register_with_socket: Option<RegisterWithSocketFut>,
ep: Endpoint,
/// The endpoint id this connection was created for (as passed to `Connecting::new`).
remote_endpoint_id: EndpointId,
}
/// Boxed future that registers a handshaken connection with the endpoint and runs
/// the after-handshake hook.
type RegisterWithSocketFut = BoxFuture<Result<Connection, ConnectingError>>;
/// An accepted incoming connection whose handshake is still in progress.
///
/// Produced by [`Incoming::accept`] and [`Incoming::accept_with`]. Awaiting it yields
/// a registered [`Connection`]. Use [`Accepting::into_0rtt`] to use the connection
/// before the handshake completes.
#[derive(derive_more::Debug)]
pub struct Accepting {
inner: noq::Connecting,
/// Registration future; `None` until the QUIC handshake has finished.
#[debug("{}", register_with_socket.as_ref().map(|_| "Some(RegisterWithSocketFut)").unwrap_or("None"))]
register_with_socket: Option<RegisterWithSocketFut>,
ep: Endpoint,
}
/// Failure to read the negotiated ALPN from a connection that is still being
/// established.
#[stack_error(add_meta, derive, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum AlpnError {
/// Waiting for the handshake data failed with a connection error.
#[error(transparent)]
ConnectionError {
#[error(std_err)]
source: ConnectionError,
},
/// The handshake data did not contain an ALPN protocol.
#[error("No ALPN available")]
Unavailable,
/// The handshake data was not rustls `HandshakeData`.
#[error("Unknown handshake type")]
UnknownHandshake,
}
/// Errors that can occur while turning a connection attempt into a [`Connection`].
#[stack_error(add_meta, derive, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Clone)]
// NOTE(review): `private_interfaces` is presumably allowed because
// `RemoteStateActorStoppedError` is less visible than this enum — confirm.
#[allow(private_interfaces)]
pub enum ConnectingError {
/// The underlying QUIC connection reported an error.
#[error(transparent)]
ConnectionError {
#[error(std_err)]
source: ConnectionError,
},
/// The remote id or ALPN could not be extracted after the handshake.
#[error("Failure finalizing the handshake")]
HandshakeFailure { source: AuthenticationError },
/// Registering the connection failed because the remote state actor stopped
/// (presumably surfaced by `register_connection` — confirm).
#[error("internal consistency error")]
InternalConsistencyError {
source: RemoteStateActorStoppedError,
},
/// An after-handshake hook rejected the connection, which was then closed locally.
#[error("Connection was rejected locally")]
LocallyRejected,
}
impl Connecting {
/// Wraps an in-progress noq connection attempt towards `remote_endpoint_id`.
pub(crate) fn new(
inner: noq::Connecting,
ep: Endpoint,
remote_endpoint_id: EndpointId,
) -> Self {
Self {
inner,
ep,
remote_endpoint_id,
register_with_socket: None,
}
}
/// Attempts to convert this into a 0-RTT connection that can send data before the
/// handshake completes.
///
/// On success, [`Connection::handshake_completed`] on the returned connection
/// resolves to [`ZeroRttStatus::Accepted`] or [`ZeroRttStatus::Rejected`].
///
/// # Errors
/// If noq cannot convert this attempt to 0-RTT, `self` is returned unchanged so the
/// caller can await the full handshake instead.
// The `Err` variant is the whole `Connecting`, which is why it is large.
#[allow(clippy::result_large_err)]
pub fn into_0rtt(self) -> Result<OutgoingZeroRttConnection, Connecting> {
match self.inner.into_0rtt() {
Ok((noq_conn, zrtt_accepted)) => {
let accepted: BoxFuture<_> = Box::pin({
let noq_conn = noq_conn.clone();
async move {
// Wait for noq's verdict on the 0-RTT data before authenticating and
// registering the connection.
let accepted = zrtt_accepted.await;
let conn = conn_from_noq_conn(noq_conn, &self.ep)?.await?;
Ok(match accepted {
true => ZeroRttStatus::Accepted(conn),
false => ZeroRttStatus::Rejected(conn),
})
}
});
// Shared so every clone of the 0-RTT connection can await the same result.
let accepted = accepted.shared();
Ok(Connection {
inner: noq_conn,
data: OutgoingZeroRttData { accepted },
})
}
// 0-RTT not possible: rebuild `self` around the returned noq handle.
Err(inner) => Err(Self { inner, ..self }),
}
}
/// Waits for and returns the handshake data of the underlying noq connection.
pub async fn handshake_data(&mut self) -> Result<Box<dyn Any>, ConnectionError> {
self.inner.handshake_data().await
}
/// Waits for the handshake data and returns the negotiated ALPN.
pub async fn alpn(&mut self) -> Result<Vec<u8>, AlpnError> {
alpn_from_noq_connecting(&mut self.inner).await
}
/// The endpoint id this connection was created for.
pub fn remote_id(&self) -> EndpointId {
self.remote_endpoint_id
}
}
impl Future for Connecting {
    type Output = Result<Connection, ConnectingError>;

    /// Runs in two phases: first the QUIC handshake, then (once it finished) the
    /// registration future built from the handshaken connection.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.register_with_socket.as_mut() {
                Some(registration) => return registration.poll_unpin(cx).map_err(Into::into),
                None => {
                    let noq_conn = std::task::ready!(self.inner.poll_unpin(cx)?);
                    let registration = conn_from_noq_conn(noq_conn, &self.ep)?;
                    self.register_with_socket = Some(Box::pin(registration.err_into()));
                }
            }
        }
    }
}
impl Accepting {
/// Wraps an accepted, still-handshaking noq connection.
pub(crate) fn new(inner: noq::Connecting, ep: Endpoint) -> Self {
Self {
inner,
ep,
register_with_socket: None,
}
}
/// Converts into a 0-RTT connection usable before the handshake completes.
///
/// [`Connection::handshake_completed`] on the result resolves to the registered
/// [`Connection`] once the handshake finishes.
///
/// # Panics
/// Panics if noq refuses the conversion; the `expect` message states this is not
/// expected for incoming connections.
pub fn into_0rtt(self) -> IncomingZeroRttConnection {
let (noq_conn, zrtt_accepted) = self
.inner
.into_0rtt()
.expect("incoming connections can always be converted to 0-RTT");
let accepted: BoxFuture<_> = Box::pin({
let noq_conn = noq_conn.clone();
async move {
// Only wait on the 0-RTT verdict; its value is not used on the accepting side.
let _ = zrtt_accepted.await;
let conn = conn_from_noq_conn(noq_conn, &self.ep)?.await?;
Ok(conn)
}
});
// Shared so every clone of the 0-RTT connection can await the same result.
let accepted = accepted.shared();
IncomingZeroRttConnection {
inner: noq_conn,
data: IncomingZeroRttData { accepted },
}
}
/// Remote address of the connection, mapped through the endpoint's transports.
pub fn remote_addr(&self) -> IncomingAddr {
self.ep
.to_transport_addr(self.inner.remote_address())
.into()
}
/// Waits for and returns the handshake data of the underlying noq connection.
pub async fn handshake_data(&mut self) -> Result<Box<dyn Any>, ConnectionError> {
self.inner.handshake_data().await
}
/// Waits for the handshake data and returns the negotiated ALPN.
pub async fn alpn(&mut self) -> Result<Vec<u8>, AlpnError> {
alpn_from_noq_connecting(&mut self.inner).await
}
}
impl Future for Accepting {
    type Output = Result<Connection, ConnectingError>;

    /// Drives the QUIC handshake, then polls the registration future built from
    /// the handshaken connection until it yields the final [`Connection`].
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.register_with_socket.as_mut() {
                Some(registration) => return registration.poll_unpin(cx).map_err(Into::into),
                None => {
                    let handshaken = std::task::ready!(self.inner.poll_unpin(cx)?);
                    let registration = conn_from_noq_conn(handshaken, &self.ep)?;
                    self.register_with_socket = Some(Box::pin(registration.err_into()));
                }
            }
        }
    }
}
/// An outgoing connection that may send 0-RTT data before its handshake completes.
pub type OutgoingZeroRttConnection = Connection<OutgoingZeroRtt>;
/// Outcome of an outgoing 0-RTT attempt, available once the handshake completed.
#[derive(Debug, Clone)]
pub enum ZeroRttStatus {
/// The 0-RTT data was accepted.
Accepted(Connection),
/// The 0-RTT data was rejected. The tests in this file re-open a stream and
/// re-send their data in this case.
Rejected(Connection),
}
/// An incoming connection usable before its handshake completes.
pub type IncomingZeroRttConnection = Connection<IncomingZeroRtt>;
/// A QUIC connection to a remote endpoint.
///
/// `State` records whether the handshake has completed ([`HandshakeCompleted`], the
/// default) or the connection is still in a 0-RTT phase ([`IncomingZeroRtt`] /
/// [`OutgoingZeroRtt`]). State-specific data lives in `data`.
#[derive(Debug, Clone)]
pub struct Connection<State: ConnectionState = HandshakeCompleted> {
inner: noq::Connection,
data: State::Data,
}
/// Data held by a connection whose handshake completed and which was registered
/// with the endpoint.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct HandshakeCompletedData {
info: StaticInfo,
/// Watchable set of network paths, returned by endpoint registration.
paths: PathWatchable,
}
/// Remote endpoint id and ALPN, extracted once after the handshake.
#[derive(Debug, Clone)]
struct StaticInfo {
endpoint_id: EndpointId,
alpn: Vec<u8>,
}
/// Data of an incoming 0-RTT connection: a shared future resolving to the fully
/// handshaken connection.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct IncomingZeroRttData {
accepted: Shared<BoxFuture<Result<Connection, ConnectingError>>>,
}
/// Data of an outgoing 0-RTT connection: a shared future resolving to whether the
/// 0-RTT data was accepted, together with the handshaken connection.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct OutgoingZeroRttData {
accepted: Shared<BoxFuture<Result<ZeroRttStatus, ConnectingError>>>,
}
// Sealing prevents implementations of `ConnectionState` outside this module.
mod sealed {
pub trait Sealed {}
}
/// Type-level state of a [`Connection`]; sealed, implemented only by the marker
/// types below.
pub trait ConnectionState: sealed::Sealed {
/// Extra data a connection carries in this state.
type Data: std::fmt::Debug + Clone;
}
/// Marker: the handshake has completed.
#[derive(Debug, Clone)]
pub struct HandshakeCompleted;
/// Marker: an incoming connection used before its handshake completed.
#[derive(Debug, Clone)]
pub struct IncomingZeroRtt;
/// Marker: an outgoing connection sending 0-RTT data before its handshake completed.
#[derive(Debug, Clone)]
pub struct OutgoingZeroRtt;
impl sealed::Sealed for HandshakeCompleted {}
impl ConnectionState for HandshakeCompleted {
type Data = HandshakeCompletedData;
}
impl sealed::Sealed for IncomingZeroRtt {}
impl ConnectionState for IncomingZeroRtt {
type Data = IncomingZeroRttData;
}
impl sealed::Sealed for OutgoingZeroRtt {}
impl ConnectionState for OutgoingZeroRtt {
type Data = OutgoingZeroRttData;
}
/// The remote endpoint id could not be derived from the peer's TLS certificates.
#[allow(missing_docs)]
#[stack_error(add_meta, derive)]
#[error("Protocol error: no remote id available")]
#[derive(Clone)]
pub struct RemoteEndpointIdError;
// Operations available in every connection state. Each one forwards directly to
// the method of the same name on the wrapped `noq::Connection`.
impl<T: ConnectionState> Connection<T> {
/// Opens a new unidirectional stream.
#[inline]
pub fn open_uni(&self) -> OpenUni<'_> {
self.inner.open_uni()
}
/// Opens a new bidirectional stream.
#[inline]
pub fn open_bi(&self) -> OpenBi<'_> {
self.inner.open_bi()
}
/// Accepts the next unidirectional stream opened by the peer.
#[inline]
pub fn accept_uni(&self) -> AcceptUni<'_> {
self.inner.accept_uni()
}
/// Accepts the next bidirectional stream opened by the peer.
#[inline]
pub fn accept_bi(&self) -> AcceptBi<'_> {
self.inner.accept_bi()
}
/// Receives the next datagram from the peer.
#[inline]
pub fn read_datagram(&self) -> ReadDatagram<'_> {
self.inner.read_datagram()
}
/// Waits until the connection is closed and returns the reason.
#[inline]
pub async fn closed(&self) -> ConnectionError {
self.inner.closed().await
}
/// The reason the connection was closed, if it has been.
#[inline]
pub fn close_reason(&self) -> Option<ConnectionError> {
self.inner.close_reason()
}
/// Closes the connection with an application error code and reason.
#[inline]
pub fn close(&self, error_code: VarInt, reason: &[u8]) {
self.inner.close(error_code, reason)
}
/// Sends an unreliable datagram (see `noq::Connection::send_datagram`).
#[inline]
pub fn send_datagram(&self, data: bytes::Bytes) -> Result<(), SendDatagramError> {
self.inner.send_datagram(data)
}
/// Sends a datagram, returning a future (see `noq::Connection::send_datagram_wait`).
#[inline]
pub fn send_datagram_wait(&self, data: bytes::Bytes) -> SendDatagram<'_> {
self.inner.send_datagram_wait(data)
}
/// Maximum datagram payload size, if datagrams are supported.
#[inline]
pub fn max_datagram_size(&self) -> Option<usize> {
self.inner.max_datagram_size()
}
/// Free space in the outgoing datagram buffer, in bytes.
#[inline]
pub fn datagram_send_buffer_space(&self) -> usize {
self.inner.datagram_send_buffer_space()
}
/// Round-trip time estimate for the given path, if known.
#[inline]
pub fn rtt(&self, path_id: PathId) -> Option<Duration> {
self.inner.rtt(path_id)
}
/// Connection statistics.
#[inline]
pub fn stats(&self) -> ConnectionStats {
self.inner.stats()
}
/// Congestion controller state for the given path, if available.
#[inline]
pub fn congestion_state(&self, path_id: PathId) -> Option<Box<dyn Controller>> {
self.inner.congestion_state(path_id)
}
/// Raw handshake data, if available.
#[inline]
pub fn handshake_data(&self) -> Option<Box<dyn Any>> {
self.inner.handshake_data()
}
/// Raw peer identity (certificates), if available.
#[inline]
pub fn peer_identity(&self) -> Option<Box<dyn Any>> {
self.inner.peer_identity()
}
/// Stable identifier of this connection.
#[inline]
pub fn stable_id(&self) -> usize {
self.inner.stable_id()
}
/// Derives keying material from the TLS session into `output`.
#[inline]
pub fn export_keying_material(
&self,
output: &mut [u8],
label: &[u8],
context: &[u8],
) -> Result<(), ExportKeyingMaterialError> {
self.inner.export_keying_material(output, label, context)
}
/// Sets the peer's limit on concurrently open unidirectional streams.
#[inline]
pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
self.inner.set_max_concurrent_uni_streams(count)
}
/// Sets the connection-level receive window.
#[inline]
pub fn set_receive_window(&self, receive_window: VarInt) {
self.inner.set_receive_window(receive_window)
}
/// Sets the peer's limit on concurrently open bidirectional streams.
#[inline]
pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
self.inner.set_max_concurrent_bi_streams(count)
}
}
impl Connection<HandshakeCompleted> {
    /// ALPN protocol negotiated during the handshake.
    pub fn alpn(&self) -> &[u8] {
        self.data.info.alpn.as_slice()
    }

    /// Authenticated endpoint id of the remote peer.
    pub fn remote_id(&self) -> EndpointId {
        let info = &self.data.info;
        info.endpoint_id
    }

    /// Watcher over the network paths of this connection.
    pub fn paths(&self) -> PathWatcher {
        let watchable = &self.data.paths;
        watchable.watch()
    }

    /// Whether this side is the client or server of the connection.
    pub fn side(&self) -> Side {
        let inner = &self.inner;
        inner.side()
    }

    /// Builds a [`ConnectionInfo`] that shares this connection's data but only holds
    /// a weak handle to the connection itself.
    pub fn to_info(&self) -> ConnectionInfo {
        let side = self.side();
        let weak = self.inner.weak_handle();
        ConnectionInfo {
            side,
            inner: weak,
            data: self.data.clone(),
        }
    }
}
impl Connection<IncomingZeroRtt> {
pub fn alpn(&self) -> Option<Vec<u8>> {
alpn_from_noq_conn(&self.inner)
}
pub async fn handshake_completed(&self) -> Result<Connection, ConnectingError> {
self.data.accepted.clone().await
}
pub fn remote_id(&self) -> Result<EndpointId, RemoteEndpointIdError> {
remote_id_from_noq_conn(&self.inner)
}
}
impl Connection<OutgoingZeroRtt> {
pub fn alpn(&self) -> Option<Vec<u8>> {
alpn_from_noq_conn(&self.inner)
}
pub async fn handshake_completed(&self) -> Result<ZeroRttStatus, ConnectingError> {
self.data.accepted.clone().await
}
pub fn remote_id(&self) -> Result<EndpointId, RemoteEndpointIdError> {
remote_id_from_noq_conn(&self.inner)
}
}
/// Information about a handshaken connection that holds only a weak handle to it.
///
/// Created by [`Connection::to_info`]. Whether the connection is still alive can be
/// checked with [`ConnectionInfo::is_alive`].
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
side: Side,
data: HandshakeCompletedData,
inner: WeakConnectionHandle,
}
#[allow(missing_docs)]
impl ConnectionInfo {
    /// ALPN protocol negotiated during the handshake.
    pub fn alpn(&self) -> &[u8] {
        self.data.info.alpn.as_slice()
    }

    /// Authenticated endpoint id of the remote peer.
    pub fn remote_id(&self) -> EndpointId {
        let info = &self.data.info;
        info.endpoint_id
    }

    /// Whether the weak handle can still be upgraded to the connection.
    pub fn is_alive(&self) -> bool {
        matches!(self.inner.upgrade(), Some(_))
    }

    /// Watcher over the network paths of the connection.
    pub fn paths(&self) -> PathWatcher {
        let watchable = &self.data.paths;
        watchable.watch()
    }

    /// Current statistics, or `None` if the connection is gone.
    pub fn stats(&self) -> Option<ConnectionStats> {
        let conn = self.inner.upgrade()?;
        Some(conn.stats())
    }

    /// Whether this side is the client or server of the connection.
    pub fn side(&self) -> Side {
        self.side
    }

    /// Waits for the connection to close and returns the close error plus final
    /// stats, or `None` if the connection is already gone.
    pub async fn closed(&self) -> Option<(ConnectionError, ConnectionStats)> {
        // Obtain the close future, then release the upgraded handle before awaiting
        // (presumably so that waiting does not keep the connection alive).
        let on_closed = self.inner.upgrade().map(|conn| conn.on_closed())?;
        Some(on_closed.await)
    }

    /// The currently selected network path, if any.
    pub fn selected_path(&self) -> Option<PathInfo> {
        let mut candidates = self.paths().into_iter();
        candidates.find(|candidate| candidate.is_selected())
    }
}
#[cfg(all(test, with_crypto_provider))]
mod tests {
    use std::time::Duration;

    use iroh_base::{EndpointAddr, SecretKey};
    use iroh_relay::tls::CaRootsConfig;
    use n0_error::{Result, StackResultExt, StdResultExt};
    use n0_future::StreamExt;
    use n0_tracing_test::traced_test;
    use n0_watcher::Watcher;
    use rand::{RngExt, SeedableRng};
    use tracing::{Instrument, error_span, info, info_span, trace_span};

    use super::Endpoint;
    use crate::{
        RelayMode,
        endpoint::{ConnectOptions, Incoming, PathInfo, PathInfoList, ZeroRttStatus, presets},
        test_utils::run_relay_server,
    };

    const TEST_ALPN: &[u8] = b"n0/iroh/test";

    /// Spawns an echo server that turns every accepted connection into a 0-RTT
    /// connection, echoes the first bidirectional stream back, and waits for the
    /// connection to close. Connections are handled one at a time.
    async fn spawn_0rtt_server(secret_key: SecretKey, log_span: tracing::Span) -> Result<Endpoint> {
        let server = Endpoint::builder(presets::Minimal)
            .secret_key(secret_key)
            .alpns(vec![TEST_ALPN.to_vec()])
            .bind()
            .instrument(log_span.clone())
            .await?;

        async fn handle_incoming(incoming: Incoming) -> Result {
            let accepting = incoming
                .accept()
                .std_context("Failed to accept incoming connection")?;
            let zrtt_conn = accepting.into_0rtt();
            let (mut send, mut recv) = zrtt_conn
                .accept_bi()
                .await
                .std_context("failed to accept bi stream")?;
            let data = recv
                .read_to_end(10_000_000)
                .await
                .std_context("Failed to read data")?;
            send.write_all(&data)
                .await
                .std_context("Failed to write data")?;
            send.finish().std_context("Failed to finish send")?;
            zrtt_conn.closed().await;
            Ok(())
        }

        tokio::spawn({
            let server = server.clone();
            async move {
                tracing::trace!("Server accept loop started");
                while let Some(incoming) = server.accept().await {
                    tracing::trace!("Server received incoming connection");
                    if let Err(e) = handle_incoming(incoming).await {
                        tracing::warn!("Failure while handling connection: {e:#}");
                    }
                    tracing::trace!("Connection closed, ready for next");
                }
                tracing::trace!("Server accept loop exiting");
                n0_error::Ok(())
            }
            .instrument(log_span)
        });

        Ok(server)
    }

    /// Connects to `server_addr` expecting 0-RTT to be unavailable, then performs a
    /// regular echo round-trip over the fully handshaken connection.
    async fn connect_client_0rtt_expect_err(
        client: &Endpoint,
        server_addr: EndpointAddr,
    ) -> Result {
        let conn = client
            .connect_with_opts(server_addr, TEST_ALPN, ConnectOptions::new())
            .await?
            .into_0rtt()
            .expect_err("expected 0-RTT to fail")
            .await?;
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        send.write_all(b"hello").await.anyerr()?;
        send.finish().anyerr()?;
        let received = recv.read_to_end(1_000).await.anyerr()?;
        assert_eq!(&received, b"hello");
        conn.close(0u32.into(), b"thx");
        Ok(())
    }

    /// Connects with 0-RTT, sends data before the handshake completes, and checks
    /// whether the server accepted it as expected. On rejection the data is re-sent
    /// over a fresh stream before verifying the echo.
    async fn connect_client_0rtt_expect_ok(
        client: &Endpoint,
        server_addr: EndpointAddr,
        expect_server_accepts: bool,
    ) -> Result {
        tracing::trace!(?server_addr, "Client connecting with 0-RTT");
        let zrtt_conn = client
            .connect_with_opts(server_addr, TEST_ALPN, ConnectOptions::new())
            .await
            .context("connect")?
            .into_0rtt()
            .ok()
            .context("into_0rtt")?;

        tracing::trace!("Client established 0-RTT connection");
        let (mut send, mut recv) = zrtt_conn.open_bi().await.anyerr()?;
        send.write_all(b"hello").await.anyerr()?;
        send.finish().anyerr()?;
        tracing::trace!("Client sent 0-RTT data, waiting for server response");

        let zrtt_res = zrtt_conn.handshake_completed().await;
        tracing::trace!(?zrtt_res, "Server responded to 0-RTT");
        let zrtt_res = zrtt_res.context("handshake completed")?;
        let conn = match zrtt_res {
            ZeroRttStatus::Accepted(conn) => {
                assert!(expect_server_accepts);
                conn
            }
            ZeroRttStatus::Rejected(conn) => {
                assert!(!expect_server_accepts);
                // The 0-RTT data was rejected: send it again on a new stream.
                let (mut send, r) = conn.open_bi().await.anyerr()?;
                send.write_all(b"hello").await.anyerr()?;
                send.finish().anyerr()?;
                recv = r;
                conn
            }
        };
        let received = recv.read_to_end(1_000).await.anyerr()?;
        assert_eq!(&received, b"hello");
        conn.close(0u32.into(), b"thx");
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_0rtt() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let client = Endpoint::bind(presets::Minimal).await?;
        let server =
            spawn_0rtt_server(SecretKey::from_bytes(&rng.random()), info_span!("server")).await?;

        connect_client_0rtt_expect_err(&client, server.addr()).await?;
        connect_client_0rtt_expect_ok(&client, server.addr(), true).await?;

        client.close().await;
        server.close().await;

        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_0rtt_non_consecutive() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let client = Endpoint::bind(presets::Minimal).await?;
        let server =
            spawn_0rtt_server(SecretKey::from_bytes(&rng.random()), info_span!("server")).await?;

        connect_client_0rtt_expect_err(&client, server.addr()).await?;

        // Connecting to a different server in between must not break 0-RTT towards
        // the first one.
        let another =
            spawn_0rtt_server(SecretKey::from_bytes(&rng.random()), info_span!("another")).await?;
        connect_client_0rtt_expect_err(&client, another.addr()).await?;
        another.close().await;

        connect_client_0rtt_expect_ok(&client, server.addr(), true).await?;

        client.close().await;
        server.close().await;

        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_0rtt_after_server_restart() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let client = Endpoint::builder(presets::Minimal)
            .bind()
            .instrument(info_span!("client"))
            .await?;
        let server_key = SecretKey::from_bytes(&rng.random());
        let server = spawn_0rtt_server(server_key.clone(), info_span!("server-initial")).await?;

        connect_client_0rtt_expect_err(&client, server.addr())
            .instrument(trace_span!("connect1"))
            .await
            .context("client connect 1")?;
        connect_client_0rtt_expect_ok(&client, server.addr(), true)
            .instrument(trace_span!("connect2"))
            .await
            .context("client connect 2")?;

        server.close().await;

        // After a restart with the same key, the client still attempts 0-RTT but the
        // server is expected to reject it.
        let server = spawn_0rtt_server(server_key, info_span!("server-restart")).await?;

        connect_client_0rtt_expect_ok(&client, server.addr(), false)
            .instrument(trace_span!("connect3"))
            .await
            .context("client connect 3")?;

        tokio::join!(client.close(), server.close());

        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_paths_watcher() -> Result {
        const ALPN: &[u8] = b"test";
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let (relay_map, _relay_map, _guard) = run_relay_server().await?;
        let server = Endpoint::builder(presets::Minimal)
            .relay_mode(RelayMode::Custom(relay_map.clone()))
            .secret_key(SecretKey::from_bytes(&rng.random()))
            .ca_roots_config(CaRootsConfig::insecure_skip_verify())
            .alpns(vec![ALPN.to_vec()])
            .bind()
            .await?;

        let client = Endpoint::builder(presets::Minimal)
            .relay_mode(RelayMode::Custom(relay_map.clone()))
            .secret_key(SecretKey::from_bytes(&rng.random()))
            .ca_roots_config(CaRootsConfig::insecure_skip_verify())
            .bind()
            .await?;

        server.online().await;
        let server_addr = server.addr();
        info!("server addr: {server_addr:?}");

        let (conn_client, conn_server) = tokio::join!(
            async { client.connect(server_addr, ALPN).await.unwrap() },
            async { server.accept().await.unwrap().await.unwrap() }
        );
        info!("connected");
        let mut paths_client = conn_client.paths().stream();
        let mut paths_server = conn_server.paths().stream();

        /// Waits until the watched path list contains both a relay and an IP path.
        async fn wait_for_paths(
            stream: &mut n0_watcher::Stream<impl n0_watcher::Watcher<Value = PathInfoList> + Unpin>,
        ) {
            loop {
                let paths = stream.next().await.expect("paths stream ended");
                info!(?paths, "paths");
                if paths.len() >= 2
                    && paths.iter().any(PathInfo::is_relay)
                    && paths.iter().any(PathInfo::is_ip)
                {
                    info!("break");
                    return;
                }
            }
        }

        tokio::join!(
            async {
                tokio::time::timeout(Duration::from_secs(1), wait_for_paths(&mut paths_server))
                    .instrument(error_span!("paths-server"))
                    .await
                    .unwrap()
            },
            async {
                tokio::time::timeout(Duration::from_secs(1), wait_for_paths(&mut paths_client))
                    .instrument(error_span!("paths-client"))
                    .await
                    .unwrap()
            }
        );

        // From here on virtual time is used, so the 1s budgets below auto-advance
        // when the runtime is idle instead of sleeping for real.
        tokio::time::pause();
        info!("close client conn");
        conn_client.close(0u32.into(), b"");

        // The client closed the connection locally: its paths watcher must end.
        // (Timeout was previously 1ns, contradicting the "within 1s" message.)
        tokio::time::timeout(Duration::from_secs(1), async {
            while paths_client.next().await.is_some() {}
        })
        .await
        .expect("client paths watcher did not close within 1s of connection close");

        // The server learns about the close from the peer: its paths watcher must end
        // too. Previously this drained `paths_client` again, so the server-side
        // watcher was never actually checked.
        tokio::time::timeout(Duration::from_secs(1), async {
            while paths_server.next().await.is_some() {}
        })
        .await
        .expect("server paths watcher did not close within 1s of connection close");

        server.close().await;
        client.close().await;

        Ok(())
    }
}