#[cfg(not(feature = "sync"))]
use std::rc::Weak as WeakShared;
#[cfg(feature = "sync")]
use std::sync::Weak as WeakShared;
use std::{
collections::VecDeque,
fmt::Debug,
future::poll_fn,
io,
net::{IpAddr, SocketAddr, SocketAddrV6},
pin::{Pin, pin},
sync::Arc,
task::{Context, Poll, Waker},
time::{Duration, Instant},
};
use compio::buf::bytes::Bytes;
use compio::net::UdpSocket;
use compio::runtime::JoinHandle;
use compio_log::Instrument;
use flume::{Receiver, Sender};
use futures_util::{
FutureExt, StreamExt,
future::{Fuse, FusedFuture, LocalBoxFuture},
select, stream,
};
#[cfg(rustls)]
use noq_proto::crypto::rustls::HandshakeData;
use noq_proto::{
ConnectionHandle, ConnectionStats, Dir, EndpointEvent, FourTuple, NetworkChangeHint, PathError,
PathEvent, PathId, PathStats, PathStatus, Side, StreamEvent, StreamId, VarInt,
congestion::Controller, n0_nat_traversal,
};
use rustc_hash::FxHashMap as HashMap;
use thiserror::Error;
use crate::{
NatTraversalUpdates, ObservedExternalAddr, OpenPath, Path, PathEvents, RecvStream, SendStream,
Socket, SocketEntry, SocketSet,
event_stream::{Broadcast, Watch},
select_socket, spawn_recv_task,
sync::{
mutex_blocking::{Mutex, MutexGuard},
shared::Shared,
},
};
/// Events delivered to a connection's driver loop through its `events_rx` channel.
#[derive(Debug)]
pub(crate) enum ConnectionEvent {
/// Close the connection with the given error code and reason bytes.
Close(VarInt, Bytes),
/// A protocol-level event that the driver forwards to `Connection::handle_event`.
Proto(noq_proto::ConnectionEvent),
/// Reported to the protocol connection as a network change without a hint.
Rebind,
/// Reported to the protocol connection as a network change, with an optional hint.
LocalAddressChanged(Option<Arc<dyn NetworkChangeHint + Send + Sync>>),
}
/// Mutable per-connection state: the protocol state machine plus every waker and
/// notification channel that other handles register on it.
#[derive(Debug)]
pub(crate) struct ConnectionState {
/// The underlying protocol connection.
pub(crate) conn: noq_proto::Connection,
/// Set by `terminate`; once `Some`, `ConnectionInner::try_state` returns it as an error.
pub(crate) error: Option<ConnectionError>,
/// Set when the `Connected` event is seen; cleared again by `terminate`.
connected: bool,
/// Set when the `HandshakeConfirmed` event is seen.
handshake_confirmed: bool,
/// Join handle of the spawned driver task (see `Connecting::new`).
worker: Option<JoinHandle<()>>,
/// Waker of the driver loop; taken and woken by `wake`.
poller: Option<Waker>,
/// Single waker slot woken on `Connected` or on termination.
on_connected: Option<Waker>,
/// Single waker slot woken on `HandshakeDataReady` or on termination.
on_handshake_data: Option<Waker>,
/// Wakers waiting for the handshake to be confirmed.
on_handshake_confirmed: VecDeque<Waker>,
/// Wakers waiting for an incoming datagram.
datagram_received: VecDeque<Waker>,
/// Wakers waiting for datagram sending to become unblocked.
datagrams_unblocked: VecDeque<Waker>,
/// Wakers waiting for a peer-opened stream, indexed by `Dir as usize`.
stream_opened: [VecDeque<Waker>; 2],
/// Wakers waiting for capacity to open a stream, indexed by `Dir as usize`.
stream_available: [VecDeque<Waker>; 2],
/// Pending path-open requests; each receives the outcome of its path.
open_path: HashMap<PathId, Broadcast<Result<(), PathError>>>,
/// Broadcast of every path event polled from the protocol connection.
path_events: Broadcast<PathEvent>,
/// Latest address reported by a `PathEvent::ObservedAddr` event.
observed_external_addr: Watch<Option<SocketAddr>>,
/// Broadcast of NAT traversal events polled from the protocol connection.
nat_traversal_updates: Broadcast<n0_nat_traversal::Event>,
/// Senders that receive a `Closed` snapshot when the connection terminates.
on_closed: Vec<Sender<Closed>>,
/// Stats captured when a still-referenced path is discarded.
final_path_stats: HashMap<PathId, PathStats>,
/// Reference count per path id, maintained by `increment_path_refs`/`decrement_path_refs`.
path_refs: HashMap<PathId, usize>,
/// Per-stream wakers waiting for the stream to become writable.
pub(crate) writable: HashMap<StreamId, Waker>,
/// Per-stream wakers waiting for the stream to become readable.
pub(crate) readable: HashMap<StreamId, Waker>,
/// Per-stream wakers woken on the stream's `Finished` or `Stopped` events.
pub(crate) stopped: HashMap<StreamId, Waker>,
}
impl ConnectionState {
    /// Marks the connection as terminated with `reason` and wakes everything waiting on it.
    ///
    /// Stores the error, clears the `connected` flag, wakes all registered wakers, fails
    /// every pending path-open request with `PathError::ValidationFailed`, and delivers a
    /// `Closed` snapshot to each registered `on_closed` listener.
    fn terminate(&mut self, reason: ConnectionError) {
        self.error = Some(reason.clone());
        self.connected = false;

        // Single-slot wakers.
        if let Some(pending) = self.on_handshake_data.take() {
            pending.wake();
        }
        if let Some(pending) = self.on_connected.take() {
            pending.wake();
        }

        // Waker queues.
        for pending in self.on_handshake_confirmed.drain(..) {
            pending.wake();
        }
        for pending in self.datagram_received.drain(..) {
            pending.wake();
        }
        for pending in self.datagrams_unblocked.drain(..) {
            pending.wake();
        }
        for queue in self.stream_opened.iter_mut() {
            for pending in queue.drain(..) {
                pending.wake();
            }
        }
        for queue in self.stream_available.iter_mut() {
            for pending in queue.drain(..) {
                pending.wake();
            }
        }

        // Pending path-open requests can no longer succeed.
        for (_, result_tx) in self.open_path.drain() {
            result_tx.send(Err(PathError::ValidationFailed));
        }

        wake_all_streams(&mut self.writable);
        wake_all_streams(&mut self.readable);
        wake_all_streams(&mut self.stopped);

        // Only build the snapshot when somebody is listening for it.
        if self.on_closed.is_empty() {
            return;
        }
        let snapshot = Closed::new(self, reason);
        for listener in self.on_closed.drain(..) {
            let _ = listener.send(snapshot.clone());
        }
    }

    /// Closes the protocol connection locally, terminates with `LocallyClosed`, and
    /// wakes the driver loop.
    fn close(&mut self, error_code: VarInt, reason: Bytes) {
        let now = Instant::now();
        self.conn.close(now, error_code, reason);
        self.terminate(ConnectionError::LocallyClosed);
        self.wake();
    }

    /// Wakes the driver loop if it has a waker registered, consuming that waker.
    pub(crate) fn wake(&mut self) {
        match self.poller.take() {
            Some(driver) => driver.wake(),
            None => {}
        }
    }

    /// Returns the crypto session's handshake data, if it is available yet.
    ///
    /// # Panics
    ///
    /// Panics if the handshake data is not a `HandshakeData`.
    #[cfg(rustls)]
    fn handshake_data(&self) -> Option<Box<HandshakeData>> {
        let data = self.conn.crypto_session().handshake_data()?;
        Some(data.downcast::<HandshakeData>().unwrap())
    }

    /// Returns `true` when this side is the server, the connection is still handshaking,
    /// or 0-RTT was accepted.
    pub(crate) fn check_0rtt(&self) -> bool {
        let conn = &self.conn;
        conn.side().is_server() || conn.is_handshaking() || conn.accepted_0rtt()
    }

    /// Returns the stats of `path_id`, falling back to the stats recorded for it in
    /// `final_path_stats` when the protocol connection has none.
    pub(crate) fn path_stats(&mut self, path_id: PathId) -> Option<PathStats> {
        match self.conn.path_stats(path_id) {
            Some(live) => Some(live),
            None => self.final_path_stats.get(&path_id).copied(),
        }
    }

    /// Adds one reference to `path_id`.
    pub(crate) fn increment_path_refs(&mut self, path_id: PathId) {
        let count = self.path_refs.entry(path_id).or_default();
        *count += 1;
    }

    /// Drops one reference to `path_id`; when the count reaches zero, forgets both the
    /// count and any recorded final stats for that path.
    pub(crate) fn decrement_path_refs(&mut self, path_id: PathId) {
        let Some(count) = self.path_refs.get_mut(&path_id) else {
            return;
        };
        *count = count.saturating_sub(1);
        if *count != 0 {
            return;
        }
        self.path_refs.remove(&path_id);
        self.final_path_stats.remove(&path_id);
    }
}
/// Removes the waker registered for `stream`, if any, and wakes it.
fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
    match wakers.remove(&stream) {
        Some(pending) => pending.wake(),
        None => {}
    }
}
/// Empties `wakers`, waking every waker that was registered in it.
fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
    for (_, pending) in wakers.drain() {
        pending.wake();
    }
}
/// Empties `wakers`, waking each queued waker in front-to-back order.
fn wake_waiters(wakers: &mut VecDeque<Waker>) {
    for pending in wakers.drain(..) {
        pending.wake();
    }
}
/// Appends a clone of `waker` to `wakers` unless some queued waker would already wake
/// the same task (checked with `Waker::will_wake`, newest entries first).
fn push_waker_dedup(wakers: &mut VecDeque<Waker>, waker: &Waker) {
    let already_queued = wakers.iter().rev().any(|queued| queued.will_wake(waker));
    if !already_queued {
        wakers.push_back(waker.clone());
    }
}
/// Appends a clone of `waker` to `wakers` unless the most recently queued waker would
/// already wake the same task. Only the back of the queue is inspected.
fn push_waker_if_not_last(wakers: &mut VecDeque<Waker>, waker: &Waker) {
    let repeats_last = wakers
        .back()
        .map_or(false, |newest| newest.will_wake(waker));
    if repeats_last {
        return;
    }
    wakers.push_back(waker.clone());
}
/// Adjusts the remote address of `network_path` to the address family the connection
/// already uses.
///
/// The family is taken from the first existing path whose network path can be looked
/// up; with no such path the connection is treated as not using IPv6.
///
/// * IPv6 remote on a non-IPv6 connection: rejected with
///   `PathError::InvalidRemoteAddress`.
/// * IPv4 remote on an IPv6 connection: rewritten to its IPv4-mapped IPv6 form.
/// * Otherwise the remote address is kept as is.
fn normalize_network_path(
    state: &ConnectionState,
    network_path: FourTuple,
) -> Result<FourTuple, PathError> {
    let uses_ipv6 = state
        .conn
        .paths()
        .iter()
        .find_map(|id| state.conn.network_path(*id).ok())
        .map(|existing| existing.remote().is_ipv6())
        .unwrap_or_default();

    let remote = match (network_path.remote(), uses_ipv6) {
        (rejected @ SocketAddr::V6(_), false) => {
            return Err(PathError::InvalidRemoteAddress(rejected));
        }
        (SocketAddr::V4(v4), true) => {
            SocketAddr::V6(SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0))
        }
        (unchanged, _) => unchanged,
    };

    Ok(FourTuple::new(remote, network_path.local_ip()))
}
/// Shared core of a connection: the locked state plus the channels and sockets the
/// driver loop (`run`) works with.
pub(crate) struct ConnectionInner {
/// All mutable connection state, behind a blocking mutex.
state: Mutex<ConnectionState>,
/// Handle sent alongside every endpoint event on `events_tx`.
handle: ConnectionHandle,
/// Sockets available for transmitting.
pub(crate) sockets: SocketSet,
/// Outgoing endpoint events, tagged with this connection's handle.
events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
/// Incoming events consumed by the driver loop.
events_rx: Receiver<ConnectionEvent>,
}
/// Debug output shows only the connection handle; all other fields are elided.
impl std::fmt::Debug for ConnectionInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ConnectionInner");
        out.field("handle", &self.handle);
        out.finish_non_exhaustive()
    }
}
/// Closes the connection with error code 0 and an empty reason when exactly two strong
/// references to it remain.
// NOTE(review): presumably the two references are the handle being dropped and the
// clone moved into the driver task by `Connecting::new` — confirm.
fn implicit_close(this: &Shared<ConnectionInner>) {
    let holders = Shared::strong_count(this);
    if holders != 2 {
        return;
    }
    this.state().close(0u32.into(), Bytes::new());
}
impl ConnectionInner {
/// Builds the shared connection core around a fresh protocol connection.
///
/// All flags start cleared, every waker slot and queue starts empty, and the path-event
/// and NAT-traversal broadcasts are created with a capacity argument of 32.
fn new(
    handle: ConnectionHandle,
    conn: noq_proto::Connection,
    sockets: SocketSet,
    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    events_rx: Receiver<ConnectionEvent>,
) -> Self {
    let initial = ConnectionState {
        conn,
        connected: false,
        handshake_confirmed: false,
        error: None,
        worker: None,
        poller: None,
        on_connected: None,
        on_handshake_data: None,
        on_handshake_confirmed: VecDeque::new(),
        datagram_received: VecDeque::new(),
        datagrams_unblocked: VecDeque::new(),
        stream_opened: [VecDeque::new(), VecDeque::new()],
        stream_available: [VecDeque::new(), VecDeque::new()],
        open_path: HashMap::default(),
        path_events: Broadcast::new(32),
        observed_external_addr: Watch::new(None),
        nat_traversal_updates: Broadcast::new(32),
        on_closed: Vec::new(),
        final_path_stats: HashMap::default(),
        path_refs: HashMap::default(),
        writable: HashMap::default(),
        readable: HashMap::default(),
        stopped: HashMap::default(),
    };
    Self {
        state: Mutex::new(initial),
        handle,
        sockets,
        events_tx,
        events_rx,
    }
}
/// Locks and returns the connection state, regardless of whether the connection has
/// already terminated.
#[inline]
pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
self.state.lock()
}
/// Locks the connection state, failing with the stored error if the connection has
/// already terminated.
#[inline]
pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
    let guard = self.state();
    match guard.error.clone() {
        Some(error) => Err(error),
        None => Ok(guard),
    }
}
/// Returns a handle for path `id`, or `None` if the protocol connection cannot report
/// a status for it.
///
/// The path's reference count is incremented here while the lock is held, and the lock
/// is released before the handle is built with `Path::new_unchecked_without_ref`.
pub(crate) fn path(inner: &Shared<Self>, id: PathId) -> Option<Path> {
    {
        let mut guard = inner.state.lock();
        if guard.conn.path_status(id).is_err() {
            return None;
        }
        guard.increment_path_refs(id);
    }
    Some(Path::new_unchecked_without_ref(inner.clone(), id))
}
/// Returns a handle for every path whose status can currently be queried.
///
/// The ids are collected while the state lock is held; the handles are created after
/// the lock has been released.
pub(crate) fn paths(inner: &Shared<ConnectionInner>) -> Vec<Path> {
    let live_ids = {
        let guard = inner.state.lock();
        let candidates = guard.conn.paths();
        candidates
            .into_iter()
            .filter(|id| guard.conn.path_status(*id).is_ok())
            .collect::<Vec<_>>()
    };

    let mut handles = Vec::new();
    for id in live_ids {
        handles.push(Path::new_unchecked(inner.clone(), id));
    }
    handles
}
/// Subscribes to the connection's path-event broadcast.
pub(crate) fn subscribe_path_events(inner: &Shared<ConnectionInner>) -> PathEvents {
PathEvents::new(inner.state().path_events.subscribe())
}
/// Returns stats for every known path: the recorded final stats, overlaid with the
/// current stats of each path the protocol connection still reports.
pub(crate) fn all_path_stats(&self) -> HashMap<PathId, PathStats> {
    let mut guard = self.state();
    let mut merged = guard.final_path_stats.clone();
    for id in guard.conn.paths() {
        if let Some(live) = guard.conn.path_stats(id) {
            merged.insert(id, live);
        }
    }
    merged
}
pub(crate) fn live_path_stats(&self) -> HashMap<PathId, PathStats> {
let mut state = self.state();
state
.conn
.paths()
.into_iter()
.filter_map(|id| {
let stats = state.conn.path_stats(id)?;
Some((id, stats))
})
.collect()
}
pub(crate) fn all_path_status(&self) -> HashMap<PathId, PathStatus> {
let state = self.state.lock();
state
.conn
.paths()
.into_iter()
.filter_map(|id| {
let status = state.conn.path_status(id).ok()?;
Some((id, status))
})
.collect()
}
/// Driver loop of the connection.
///
/// Each iteration waits for one of: an explicit wake-up, the protocol timer, a batch of
/// incoming `ConnectionEvent`s, or completion of the in-flight transmit. It then starts
/// the next transmit if the send buffer is free, re-arms the timer, forwards endpoint
/// events, and dispatches every application-facing protocol event to the registered
/// wakers and broadcasts. The loop ends once the protocol connection reports that it is
/// drained.
async fn run(&self) {
// Yields an item whenever the stored `poller` waker has been taken (by
// `ConnectionState::wake`) since the previous poll; it always (re-)registers the
// current waker before returning.
let mut poller = stream::poll_fn(|cx| {
let mut state = self.state();
let ready = state.poller.is_none();
match &state.poller {
Some(waker) if waker.will_wake(cx.waker()) => {}
_ => state.poller = Some(cx.waker().clone()),
};
if ready {
Poll::Ready(Some(()))
} else {
Poll::Pending
}
})
.fuse();
let mut timer = Timer::new();
// Incoming events are handled in batches of up to 100.
let mut event_stream = self.events_rx.stream().ready_chunks(100);
// `send_buf` is `Some` while no transmit is in flight; the buffer is handed to the
// transmit future and returned when that future completes.
let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
let mut transmit_fut = pin!(Fuse::terminated());
loop {
// Every arm ends by yielding the locked state for the processing below.
let mut state = select! {
_ = poller.select_next_some() => self.state(),
_ = timer => {
timer.reset(None);
let mut state = self.state();
state.conn.handle_timeout(Instant::now());
state
}
events = event_stream.select_next_some() => {
let mut state = self.state();
for event in events {
match event {
ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
ConnectionEvent::Proto(event) => state.conn.handle_event(event),
ConnectionEvent::Rebind => {
state.conn.handle_network_change(None, Instant::now());
}
ConnectionEvent::LocalAddressChanged(hint) => {
state.conn.handle_network_change(
hint.as_deref().map(|hint| hint as &dyn NetworkChangeHint),
Instant::now(),
);
}
}
}
state
},
buf = transmit_fut => {
// The previous send finished: reclaim and clear its buffer.
let mut buf: Vec<_> = buf;
buf.clear();
send_buf = Some(buf);
self.state()
},
};
// Start the next transmit only when the buffer is not already in use.
if let Some(mut buf) = send_buf.take() {
let default_socket = select_socket(&self.sockets, None);
let default_max_gso = default_socket.max_gso_segments();
if let Some(transmit) =
state
.conn
.poll_transmit(Instant::now(), default_max_gso, &mut buf)
{
// A transmit that names a source IP gets the socket selected for that IP.
let socket = if transmit.src_ip.is_some() {
select_socket(&self.sockets, transmit.src_ip)
} else {
default_socket
};
transmit_fut.set(async move { socket.send(buf, &transmit).await }.fuse())
} else {
// Nothing to send: keep the buffer for the next iteration.
send_buf = Some(buf);
}
}
timer.reset(state.conn.poll_timeout());
// Forward endpoint events; a failed send is ignored.
while let Some(event) = state.conn.poll_endpoint_events() {
let _ = self.events_tx.send((self.handle, event));
}
// Dispatch application-facing events to whoever is waiting for them.
while let Some(event) = state.conn.poll() {
use noq_proto::Event::*;
match event {
HandshakeDataReady => {
if let Some(waker) = state.on_handshake_data.take() {
waker.wake()
}
}
Connected => {
state.connected = true;
if let Some(waker) = state.on_connected.take() {
waker.wake()
}
// On a client whose 0-RTT was not accepted, wake every stream waiter.
if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
wake_all_streams(&mut state.writable);
wake_all_streams(&mut state.readable);
wake_all_streams(&mut state.stopped);
}
}
ConnectionLost { reason } => state.terminate(reason.into()),
Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
Stream(StreamEvent::Stopped { id, .. }) => {
wake_stream(id, &mut state.stopped);
wake_stream(id, &mut state.writable);
}
Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
.drain(..)
.for_each(Waker::wake),
Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
.drain(..)
.for_each(Waker::wake),
DatagramReceived => wake_waiters(&mut state.datagram_received),
DatagramsUnblocked => wake_waiters(&mut state.datagrams_unblocked),
HandshakeConfirmed => {
state.handshake_confirmed = true;
wake_waiters(&mut state.on_handshake_confirmed);
}
Path(event) => {
// Update local bookkeeping first, then broadcast the event itself.
match &event {
PathEvent::ObservedAddr { addr, .. } => {
// Notify watchers only when the observed address actually changed.
state.observed_external_addr.send_if_modified(|value| {
let old = value.replace(*addr);
old != *value
});
}
PathEvent::Established { id, .. } => {
if let Some(tx) = state.open_path.remove(id) {
tx.send(Ok(()));
}
}
PathEvent::Abandoned { id, .. } => {
if let Some(tx) = state.open_path.remove(id) {
tx.send(Err(PathError::ValidationFailed));
}
}
PathEvent::Discarded { id, path_stats, .. } => {
// Keep final stats only for paths that are still referenced.
if state.path_refs.contains_key(id) {
state.final_path_stats.insert(*id, **path_stats);
}
}
PathEvent::RemoteStatus { .. } => {}
_ => {}
}
state.path_events.send(event);
}
NatTraversal(event) => {
state.nat_traversal_updates.send(event);
}
}
}
if state.conn.is_drained() {
break;
}
}
// If nobody took the join handle (see `Connection::closed`), detach it.
if let Some(worker) = self.state().worker.take() {
worker.detach();
}
}
}
// Generates the accessor methods shared by `Connecting` and `Connection`. Each method
// locks the state of the `Shared<ConnectionInner>` stored in field `0`.
macro_rules! conn_fn {
() => {
/// Returns which side of the connection this endpoint is.
pub fn side(&self) -> Side {
self.0.state().conn.side()
}
/// Returns the local IP of the first path whose network path can be looked up,
/// if that path has one.
pub fn local_ip(&self) -> Option<IpAddr> {
let state = self.0.state();
state
.conn
.paths()
.iter()
.filter_map(|id| state.conn.network_path(*id).ok())
.next()
.and_then(|path| path.local_ip())
}
/// Returns the remote address of the first path whose network path can be
/// looked up.
///
/// # Panics
///
/// Panics if there is no such path.
pub fn remote_address(&self) -> SocketAddr {
let state = self.0.state();
state
.conn
.paths()
.iter()
.filter_map(|id| state.conn.network_path(*id).ok())
.next()
.expect("remote_address called on a connection with no paths")
.remote()
}
/// Returns the round-trip time the protocol connection reports for `path_id`.
pub fn rtt(&self, path_id: PathId) -> Option<Duration> {
self.0.state().conn.rtt(path_id)
}
/// Returns the connection-wide statistics.
pub fn stats(&self) -> ConnectionStats {
self.0.state().conn.stats()
}
/// Returns a boxed clone of the congestion controller state for `path_id`.
pub fn congestion_state(&self, path_id: PathId) -> Option<Box<dyn Controller>> {
self.0
.state()
.conn
.congestion_state(path_id)
.map(|state| state.clone_box())
}
/// Returns the peer identity reported by the crypto session as a certificate
/// chain.
///
/// # Panics
///
/// Panics if the identity is not a `Vec<CertificateDer>`.
pub fn peer_identity(
&self,
) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
self.0
.state()
.conn
.crypto_session()
.peer_identity()
.map(|v| v.downcast().unwrap())
}
/// Returns the address of the shared connection core as an integer; handles
/// sharing the same core return the same value.
pub fn stable_id(&self) -> usize {
Shared::as_ptr(&self.0) as usize
}
/// Fills `output` with keying material exported from the crypto session for the
/// given `label` and `context`.
pub fn export_keying_material(
&self,
output: &mut [u8],
label: &[u8],
context: &[u8],
) -> Result<(), noq_proto::crypto::ExportKeyingMaterialError> {
self.0
.state()
.conn
.crypto_session()
.export_keying_material(output, label, context)
}
};
}
/// Snapshot delivered to `OnClosed` listeners when a connection terminates.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Closed {
/// Why the connection terminated.
pub reason: ConnectionError,
/// Connection-wide statistics at the time the snapshot was taken.
pub stats: ConnectionStats,
/// Stats of the paths still reported by the connection, followed by the recorded
/// final stats of discarded paths.
pub path_stats: Vec<(PathId, PathStats)>,
}
impl Closed {
    /// Captures the connection stats and per-path stats of `state` together with the
    /// termination `reason`.
    ///
    /// Live paths come first (paths without stats are skipped), followed by every entry
    /// of `final_path_stats`.
    fn new(state: &mut ConnectionState, reason: ConnectionError) -> Self {
        let stats = state.conn.stats();

        let mut path_stats = Vec::new();
        for id in state.conn.paths() {
            if let Some(live) = state.conn.path_stats(id) {
                path_stats.push((id, live));
            }
        }
        for (id, recorded) in state.final_path_stats.iter() {
            path_stats.push((*id, *recorded));
        }

        Self {
            reason,
            stats,
            path_stats,
        }
    }
}
/// Future returned by `Connection::on_closed`; resolves to a `Closed` snapshot.
#[derive(Debug)]
pub struct OnClosed {
/// Receives the snapshot sent when the connection terminates.
rx: Receiver<Closed>,
/// Weak handle used on drop to prune stale listeners from the connection.
conn: WeakConnectionHandle,
}
impl Future for OnClosed {
type Output = Closed;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let fut = self.rx.recv_async();
futures_util::pin_mut!(fut);
fut.poll(cx)
.map(|closed| closed.expect("on_closed sender is kept until connection termination"))
}
}
impl Drop for OnClosed {
    /// Prunes listeners whose receiving side is gone from the connection's `on_closed`
    /// list, if the connection is still alive.
    // NOTE(review): `self.rx` is still alive while this body runs, so this object's own
    // sender is not reported as disconnected here; it appears to be removed only by a
    // later prune or by termination — confirm this is intended.
    fn drop(&mut self) {
        let Some(inner) = self.conn.upgrade_inner() else {
            return;
        };
        inner
            .state()
            .on_closed
            .retain(|listener| !listener.is_disconnected());
    }
}
/// Future that resolves to whether 0-RTT was accepted once the connection is
/// established, or to `false` if the connection fails first.
#[derive(Debug, Clone)]
pub struct ZeroRttAccepted(Shared<ConnectionInner>);
impl Future for ZeroRttAccepted {
    type Output = bool;

    /// Ready with the 0-RTT acceptance flag once connected, ready with `false` once the
    /// connection has an error, otherwise registers the waker in the `on_connected` slot.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.0.state();
        if state.connected {
            let accepted = state.conn.accepted_0rtt();
            return Poll::Ready(accepted);
        }
        if state.error.is_some() {
            return Poll::Ready(false);
        }

        // Replace the stored waker only when it would not already wake this task.
        let registered = state
            .on_connected
            .as_ref()
            .is_some_and(|stored| stored.will_wake(cx.waker()));
        if !registered {
            state.on_connected = Some(cx.waker().clone());
        }
        Poll::Pending
    }
}
/// A connection whose handshake is still in progress; awaiting it yields the
/// established `Connection` or the error that terminated it.
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting(Shared<ConnectionInner>);
impl Connecting {
    conn_fn!();

    /// Creates the shared connection core, spawns its driver loop on the compio runtime
    /// (in the current tracing span), and stores the task's join handle in the state.
    pub(crate) fn new(
        handle: ConnectionHandle,
        conn: noq_proto::Connection,
        sockets: SocketSet,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
        events_rx: Receiver<ConnectionEvent>,
    ) -> Self {
        let core = ConnectionInner::new(handle, conn, sockets, events_tx, events_rx);
        let inner = Shared::new(core);

        let driver = inner.clone();
        let worker =
            compio::runtime::spawn(async move { driver.run().await }.in_current_span());
        inner.state().worker = Some(worker);

        Self(inner)
    }

    /// Waits until the crypto session's handshake data is available and returns it.
    ///
    /// # Errors
    ///
    /// Returns the connection error if the connection terminates first.
    #[cfg(rustls)]
    pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
        poll_fn(|cx| {
            let mut state = self.0.try_state()?;
            if let Some(data) = state.handshake_data() {
                return Poll::Ready(Ok(data));
            }
            // Replace the stored waker only when it would not already wake this task.
            let registered = state
                .on_handshake_data
                .as_ref()
                .is_some_and(|stored| stored.will_wake(cx.waker()));
            if !registered {
                state.on_handshake_data = Some(cx.waker().clone());
            }
            Poll::Pending
        })
        .await
    }

    /// Converts into a usable `Connection` before the handshake completes.
    ///
    /// Succeeds when the protocol connection has 0-RTT available or this side is the
    /// server; otherwise `self` is handed back unchanged in `Err`.
    pub fn into_0rtt(self) -> Result<(Connection, ZeroRttAccepted), Self> {
        let eligible = {
            let state = self.0.state();
            state.conn.has_0rtt() || state.conn.side().is_server()
        };
        if !eligible {
            return Err(self);
        }
        let connection = Connection(self.0.clone());
        let accepted = ZeroRttAccepted(self.0.clone());
        Ok((connection, accepted))
    }
}
impl Future for Connecting {
    type Output = Result<Connection, ConnectionError>;

    /// Ready with the stored error if the connection terminated, ready with a
    /// `Connection` once connected, otherwise registers the waker in the `on_connected`
    /// slot.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.0.try_state()?;
        if state.connected {
            return Poll::Ready(Ok(Connection(self.0.clone())));
        }

        // Replace the stored waker only when it would not already wake this task.
        let registered = state
            .on_connected
            .as_ref()
            .is_some_and(|stored| stored.will_wake(cx.waker()));
        if !registered {
            state.on_connected = Some(cx.waker().clone());
        }
        Poll::Pending
    }
}
/// Dropping a `Connecting` runs `implicit_close`, which closes the connection when only
/// two strong references to it remain.
impl Drop for Connecting {
fn drop(&mut self) {
implicit_close(&self.0)
}
}
/// Cloneable handle to a connection. Clones share the same underlying state, and two
/// handles compare equal when they point at the same state.
#[derive(Debug, Clone)]
pub struct Connection(Shared<ConnectionInner>);
/// Weak reference to a connection's shared state; it does not add a strong reference.
#[derive(Debug, Clone)]
pub struct WeakConnectionHandle(WeakShared<ConnectionInner>);
impl WeakConnectionHandle {
/// Creates a weak handle by downgrading the shared connection core.
pub(crate) fn new(conn: &Shared<ConnectionInner>) -> Self {
Self(Shared::downgrade(conn))
}
/// Upgrades to the shared connection core, or `None` if it has been dropped.
pub(crate) fn upgrade_inner(&self) -> Option<Shared<ConnectionInner>> {
self.0.upgrade()
}
/// Upgrades to a `Connection`, or `None` if the connection has been dropped.
pub fn upgrade(&self) -> Option<Connection> {
self.upgrade_inner().map(Connection)
}
/// Returns `true` when both handles point at the same allocation.
pub(crate) fn is_same_connection(&self, other: &Self) -> bool {
WeakShared::ptr_eq(&self.0, &other.0)
}
}
impl Connection {
// Shared accessors (`side`, `local_ip`, `remote_address`, `rtt`, `stats`, ...).
conn_fn!();
/// Returns a weak handle to this connection.
pub fn weak_handle(&self) -> WeakConnectionHandle {
WeakConnectionHandle::new(&self.0)
}
/// Returns a future that resolves to a `Closed` snapshot when the connection terminates.
///
/// If the connection has already terminated, the snapshot is placed in the channel
/// immediately; otherwise the sender is registered to be notified on termination.
pub fn on_closed(&self) -> OnClosed {
    let (tx, rx) = flume::bounded(1);
    {
        let mut state = self.0.state();
        match state.error.clone() {
            Some(reason) => {
                let _ = tx.send(Closed::new(&mut state, reason));
            }
            None => state.on_closed.push(tx),
        }
    }
    let conn = self.weak_handle();
    OnClosed { rx, conn }
}
/// Asks the protocol connection to force a key update.
pub fn force_key_update(&self) {
self.0.state().conn.force_key_update()
}
/// Returns the crypto session's handshake data.
///
/// # Errors
///
/// Returns the stored connection error if the connection has terminated, and
/// `ConnectionError::LocallyClosed` if no handshake data is available.
#[cfg(rustls)]
pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
self.0
.try_state()?
.handshake_data()
.ok_or(ConnectionError::LocallyClosed)
}
/// Returns the maximum datagram size reported by the protocol connection, if any.
pub fn max_datagram_size(&self) -> Option<usize> {
self.0.state().conn.datagrams().max_size()
}
/// Returns the datagram send-buffer space reported by the protocol connection.
pub fn datagram_send_buffer_space(&self) -> usize {
self.0.state().conn.datagrams().send_buffer_space()
}
/// Sets the maximum number of concurrent unidirectional streams and wakes the driver.
pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
let mut state = self.0.state();
state.conn.set_max_concurrent_streams(Dir::Uni, count);
state.wake();
}
/// Sets the connection's send window and wakes the driver.
pub fn set_send_window(&self, send_window: u64) {
let mut state = self.0.state();
state.conn.set_send_window(send_window);
state.wake();
}
/// Sets the connection's receive window and wakes the driver.
pub fn set_receive_window(&self, receive_window: VarInt) {
let mut state = self.0.state();
state.conn.set_receive_window(receive_window);
state.wake();
}
/// Sets the maximum number of concurrent bidirectional streams and wakes the driver.
pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
let mut state = self.0.state();
state.conn.set_max_concurrent_streams(Dir::Bi, count);
state.wake();
}
/// Closes the connection locally with `error_code` and a copy of `reason`.
///
/// All waiters are woken and the stored error becomes `ConnectionError::LocallyClosed`.
pub fn close(&self, error_code: VarInt, reason: &[u8]) {
self.0
.state()
.close(error_code, Bytes::copy_from_slice(reason));
}
/// Waits until the handshake has been confirmed.
///
/// # Errors
///
/// Returns the connection error if the connection terminates first.
pub async fn handshake_confirmed(&self) -> Result<(), ConnectionError> {
    poll_fn(|cx| {
        let mut state = self.0.try_state()?;
        if state.handshake_confirmed {
            return Poll::Ready(Ok(()));
        }

        // Queue this task's waker once, however often it is polled.
        let waker = cx.waker();
        let already_queued = state
            .on_handshake_confirmed
            .iter()
            .any(|queued| queued.will_wake(waker));
        if !already_queued {
            state.on_handshake_confirmed.push_back(waker.clone());
        }
        Poll::Pending
    })
    .await
}
/// Waits until the connection is closed and returns the reason.
///
/// The first caller takes the driver task's join handle out of the state, awaits the
/// task, and then reads the stored error. Callers that find no join handle wait on an
/// `on_closed` subscription instead.
///
/// # Panics
///
/// Panics if the driver task finished without an error having been stored.
// NOTE(review): if this future is dropped while awaiting `worker`, the join handle is
// dropped with it — confirm what compio's `JoinHandle` does to the task on drop.
pub async fn closed(&self) -> ConnectionError {
let worker = self.0.state().worker.take();
if let Some(worker) = worker {
let _ = worker.await;
return self.0.try_state().unwrap_err();
}
self.on_closed().await.reason
}
/// Returns the stored connection error, or `None` if the connection has not terminated.
pub fn close_reason(&self) -> Option<ConnectionError> {
self.0.try_state().err()
}
/// Opens a path to `network_path`, reusing an existing path if the protocol connection
/// already has one.
///
/// * Address rejected by `normalize_network_path`, or the protocol connection refuses:
///   a rejected `OpenPath`.
/// * Path already existed and an open request is still pending: subscribes to it.
/// * Path already existed with no pending request: an already-ready `OpenPath`.
/// * Path newly created: registers a pending request and wakes the driver.
pub fn open_path_ensure(
    &self,
    network_path: impl Into<FourTuple>,
    initial_status: PathStatus,
) -> OpenPath {
    let mut state = self.0.state();

    let network_path = match normalize_network_path(&state, network_path.into()) {
        Ok(normalized) => normalized,
        Err(err) => return OpenPath::rejected(err),
    };

    let (path_id, existed) =
        match state
            .conn
            .open_path_ensure(network_path, initial_status, Instant::now())
        {
            Ok(outcome) => outcome,
            Err(err) => return OpenPath::rejected(err),
        };

    if existed {
        return match state.open_path.get(&path_id) {
            Some(pending) => OpenPath::new(path_id, pending.subscribe(), self.0.clone()),
            None => OpenPath::ready(path_id, self.0.clone()),
        };
    }

    let tx = Broadcast::new(1);
    let rx = tx.subscribe();
    state.open_path.insert(path_id, tx);
    state.wake();
    OpenPath::new(path_id, rx, self.0.clone())
}
/// Opens a new path to `network_path`.
///
/// Returns a rejected `OpenPath` when `normalize_network_path` refuses the address or
/// the protocol connection refuses the path. Otherwise registers a pending open request
/// for the new path id and wakes the driver.
pub fn open_path(
    &self,
    network_path: impl Into<FourTuple>,
    initial_status: PathStatus,
) -> OpenPath {
    let mut state = self.0.state();

    let network_path = match normalize_network_path(&state, network_path.into()) {
        Ok(normalized) => normalized,
        Err(err) => return OpenPath::rejected(err),
    };

    let tx = Broadcast::new(1);
    let rx = tx.subscribe();

    let opened = state
        .conn
        .open_path(network_path, initial_status, Instant::now());
    let path_id = match opened {
        Ok(path_id) => path_id,
        Err(err) => return OpenPath::rejected(err),
    };

    state.open_path.insert(path_id, tx);
    state.wake();
    OpenPath::new(path_id, rx, self.0.clone())
}
/// Opens a new path to `addr` that uses the given UDP socket.
///
/// The socket is wrapped, added to the connection's socket set, and a receive task is
/// spawned for it before the path is opened. An unspecified local address is treated as
/// "no local IP".
///
/// # Errors
///
/// Returns an I/O error if wrapping the socket or reading its local address fails. A
/// rejected address or path is reported through a rejected `OpenPath` inside `Ok`.
// NOTE(review): when `open_path` fails, the socket registered above stays in the socket
// set and its receive task keeps running — confirm this is intended.
pub fn open_path_socket(
    &self,
    addr: SocketAddr,
    socket: UdpSocket,
    initial_status: PathStatus,
) -> Result<OpenPath, io::Error> {
    let mut state = self.0.state();

    let network_path = match normalize_network_path(&state, FourTuple::new(addr, None)) {
        Ok(normalized) => normalized,
        Err(err) => return Ok(OpenPath::rejected(err)),
    };

    let socket = Socket::new(socket)?;
    let local_ip = match socket.local_addr()?.ip() {
        IpAddr::V4(v4) if v4.is_unspecified() => None,
        IpAddr::V6(v6) if v6.is_unspecified() => None,
        bound => Some(bound),
    };

    // Register the socket and read the payload limit under one lock acquisition.
    let max_payload = {
        let mut shared = self.0.sockets.lock().unwrap();
        let max_payload = shared.max_payload_size;
        shared.sockets.push(SocketEntry {
            socket: socket.clone(),
            local_ip,
        });
        max_payload
    };
    spawn_recv_task(&self.0.sockets, &socket, max_payload);

    let tx = Broadcast::new(1);
    let rx = tx.subscribe();
    let opened = state.conn.open_path(
        FourTuple::new(network_path.remote(), local_ip),
        initial_status,
        Instant::now(),
    );
    let path_id = match opened {
        Ok(path_id) => path_id,
        Err(err) => return Ok(OpenPath::rejected(err)),
    };

    state.open_path.insert(path_id, tx);
    state.wake();
    Ok(OpenPath::new(path_id, rx, self.0.clone()))
}
/// Returns a handle for path `id`, or `None` if its status cannot be queried.
pub fn path(&self, id: PathId) -> Option<Path> {
ConnectionInner::path(&self.0, id)
}
/// Subscribes to the connection's path events.
pub fn path_events(&self) -> PathEvents {
PathEvents::new(self.0.state().path_events.subscribe())
}
/// Returns a handle for every path whose status can currently be queried.
pub fn paths(&self) -> Vec<Path> {
ConnectionInner::paths(&self.0)
}
/// Returns stats for live paths together with the recorded final stats of discarded
/// paths.
pub fn all_path_stats(&self) -> HashMap<PathId, PathStats> {
self.0.all_path_stats()
}
/// Returns stats only for the paths the protocol connection currently reports.
pub fn live_path_stats(&self) -> HashMap<PathId, PathStats> {
self.0.live_path_stats()
}
/// Returns the status of every path whose status can currently be queried.
pub fn all_path_status(&self) -> HashMap<PathId, PathStatus> {
self.0.all_path_status()
}
/// Subscribes to the connection's NAT traversal events.
pub fn nat_traversal_updates(&self) -> NatTraversalUpdates {
NatTraversalUpdates::new(self.0.state().nat_traversal_updates.subscribe())
}
/// Subscribes to updates of the externally observed address.
pub fn observed_external_addr(&self) -> ObservedExternalAddr {
ObservedExternalAddr::new(self.0.state().observed_external_addr.subscribe())
}
/// Returns the stats of `path_id`, falling back to its recorded final stats.
pub fn path_stats(&self, path_id: PathId) -> Option<PathStats> {
self.0.state().path_stats(path_id)
}
/// Returns whether the protocol connection reports multipath as negotiated.
pub fn is_multipath_enabled(&self) -> bool {
self.0.state().conn.is_multipath_negotiated()
}
/// Adds a local NAT traversal address and wakes the driver on success.
///
/// # Errors
///
/// Propagates the error returned by the protocol connection.
pub fn add_nat_traversal_address(
&self,
address: SocketAddr,
) -> Result<(), n0_nat_traversal::Error> {
let mut state = self.0.state();
state.conn.add_nat_traversal_address(address)?;
state.wake();
Ok(())
}
/// Removes a local NAT traversal address and wakes the driver on success.
///
/// # Errors
///
/// Propagates the error returned by the protocol connection.
pub fn remove_nat_traversal_address(
&self,
address: SocketAddr,
) -> Result<(), n0_nat_traversal::Error> {
let mut state = self.0.state();
state.conn.remove_nat_traversal_address(address)?;
state.wake();
Ok(())
}
/// Returns the local NAT traversal addresses known to the protocol connection.
pub fn get_local_nat_traversal_addresses(
&self,
) -> Result<Vec<SocketAddr>, n0_nat_traversal::Error> {
self.0.state().conn.get_local_nat_traversal_addresses()
}
/// Returns the remote NAT traversal addresses known to the protocol connection.
pub fn get_remote_nat_traversal_addresses(
&self,
) -> Result<Vec<SocketAddr>, n0_nat_traversal::Error> {
self.0.state().conn.get_remote_nat_traversal_addresses()
}
/// Starts a NAT traversal round, wakes the driver, and returns the addresses reported
/// by the protocol connection.
///
/// # Errors
///
/// Propagates the error returned by the protocol connection.
pub fn initiate_nat_traversal_round(&self) -> Result<Vec<SocketAddr>, n0_nat_traversal::Error> {
let mut state = self.0.state();
let addresses = state.conn.initiate_nat_traversal_round(Instant::now())?;
state.wake();
Ok(addresses)
}
/// Polls for an incoming datagram.
///
/// Ready with the stored error if the connection has terminated, ready with the
/// datagram if one is queued, otherwise registers the task's waker (once) and returns
/// `Pending`.
pub fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
    let mut state = self.0.try_state()?;
    let received = state.conn.datagrams().recv();
    match received {
        Some(bytes) => Poll::Ready(Ok(bytes)),
        None => {
            push_waker_dedup(&mut state.datagram_received, cx.waker());
            Poll::Pending
        }
    }
}
/// Returns a queued datagram without waiting, or `None` if there is none.
///
/// # Errors
///
/// Returns the stored connection error if the connection has terminated.
pub fn try_recv_datagram(&self) -> Result<Option<Bytes>, ConnectionError> {
let mut state = self.0.try_state()?;
Ok(state.conn.datagrams().recv())
}
/// Waits for the next incoming datagram.
///
/// # Errors
///
/// Returns the connection error if the connection terminates first.
pub async fn read_datagram(&self) -> Result<Bytes, ConnectionError> {
poll_fn(|cx| self.poll_recv_datagram(cx)).await
}
/// Attempts to queue `data` as a datagram and wakes the driver on success.
///
/// `cx` selects the mode: `None` is the non-waiting mode used by `send_datagram`,
/// `Some` is the waiting mode used by `send_datagram_wait`. The protocol layer's `send`
/// receives `cx.is_none()` as its second argument.
///
/// # Errors
///
/// * `Err(Ok(error))` — a definitive failure, including a terminated connection.
/// * `Err(Err(data))` — sending is blocked; the data is handed back and the task's
///   waker has been registered to be woken when datagrams are unblocked.
///
/// # Panics
///
/// Panics if the protocol layer reports `Blocked` although no context was supplied.
fn try_send_datagram(
    &self,
    cx: Option<&mut Context>,
    data: Bytes,
) -> Result<(), Result<SendDatagramError, Bytes>> {
    use noq_proto::SendDatagramError::*;
    let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
    state
        .conn
        .datagrams()
        .send(data, cx.is_none())
        .map_err(|err| match err {
            UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
            Disabled => Ok(SendDatagramError::Disabled),
            TooLarge => Ok(SendDatagramError::TooLarge),
            Blocked(data) => {
                let cx = cx.expect("blocked datagram sends are only possible when waiting");
                // Register each task at most once: a blocked send may be polled many
                // times before datagrams are unblocked, and pushing a fresh clone on
                // every poll would grow the queue without bound.
                let waker = cx.waker();
                let already_queued = state
                    .datagrams_unblocked
                    .iter()
                    .any(|queued| queued.will_wake(waker));
                if !already_queued {
                    state.datagrams_unblocked.push_back(waker.clone());
                }
                Err(data)
            }
        })?;
    state.wake();
    Ok(())
}
/// Sends a datagram without waiting.
///
/// Calls `try_send_datagram` without a context; the `unwrap` relies on that mode never
/// producing the "blocked" outcome.
pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
self.try_send_datagram(None, data).map_err(Result::unwrap)
}
/// Sends a datagram, waiting while sending is blocked.
///
/// The payload is kept in an `Option` between polls: each poll takes it out, and a
/// blocked attempt puts the returned bytes back for the next poll.
pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
let mut data = Some(data);
poll_fn(
|cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
Ok(()) => Poll::Ready(Ok(())),
Err(Ok(e)) => Poll::Ready(Err(e)),
Err(Err(b)) => {
// Blocked: keep the payload for the next poll.
data.replace(b);
Poll::Pending
}
},
)
.await
}
/// Tries to open a stream in direction `dir`.
///
/// On success returns the stream id and a flag that is `true` when this side is a
/// client that is still handshaking. When no stream can be opened, returns `Pending`;
/// if a context was supplied, its waker is registered to be woken when streams in that
/// direction become available.
///
/// # Errors
///
/// Returns the stored connection error if the connection has terminated.
fn poll_open_stream(
    &self,
    cx: Option<&mut Context>,
    dir: Dir,
) -> Poll<Result<(StreamId, bool), ConnectionError>> {
    let mut state = self.0.try_state()?;

    let opened = state.conn.streams().open(dir);
    if let Some(stream) = opened {
        let is_0rtt = state.conn.side().is_client() && state.conn.is_handshaking();
        return Poll::Ready(Ok((stream, is_0rtt)));
    }

    if let Some(cx) = cx {
        // Register each task at most once: repeated polls while no stream is
        // available would otherwise queue a fresh waker clone every time.
        let waker = cx.waker();
        let waiters = &mut state.stream_available[dir as usize];
        if !waiters.iter().any(|queued| queued.will_wake(waker)) {
            waiters.push_back(waker.clone());
        }
    }
    Poll::Pending
}
/// Opens a unidirectional stream without waiting.
///
/// # Errors
///
/// Returns `OpenStreamError::StreamsExhausted` when no stream can be opened right now,
/// or `OpenStreamError::ConnectionLost` if the connection has terminated.
pub fn try_open_uni(&self) -> Result<SendStream, OpenStreamError> {
    let Poll::Ready(opened) = self.poll_open_stream(None, Dir::Uni) else {
        return Err(OpenStreamError::StreamsExhausted);
    };
    let (stream, is_0rtt) = opened?;
    Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
}
/// Opens a unidirectional stream, waiting until one becomes available.
///
/// # Errors
///
/// Returns the connection error if the connection terminates first.
pub async fn open_uni(&self) -> Result<SendStream, ConnectionError> {
let (stream, is_0rtt) = poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
}
/// Opens a bidirectional stream without waiting.
///
/// # Errors
///
/// Returns `OpenStreamError::StreamsExhausted` when no stream can be opened right now,
/// or `OpenStreamError::ConnectionLost` if the connection has terminated.
pub fn try_open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
    let Poll::Ready(opened) = self.poll_open_stream(None, Dir::Bi) else {
        return Err(OpenStreamError::StreamsExhausted);
    };
    let (stream, is_0rtt) = opened?;
    let send = SendStream::new(self.0.clone(), stream, is_0rtt);
    let recv = RecvStream::new(self.0.clone(), stream, is_0rtt);
    Ok((send, recv))
}
/// Opens a bidirectional stream, waiting until one becomes available.
///
/// # Errors
///
/// Returns the connection error if the connection terminates first.
pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
let (stream, is_0rtt) = poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
Ok((
SendStream::new(self.0.clone(), stream, is_0rtt),
RecvStream::new(self.0.clone(), stream, is_0rtt),
))
}
/// Polls for a peer-opened stream in direction `dir`.
///
/// On success wakes the driver and returns the stream id together with whether the
/// connection is still handshaking. Otherwise registers the task's waker (unless it is
/// already the most recently queued one) and returns `Pending`.
///
/// # Errors
///
/// Returns the stored connection error if the connection has terminated.
fn poll_accept_stream(
    &self,
    cx: &mut Context,
    dir: Dir,
) -> Poll<Result<(StreamId, bool), ConnectionError>> {
    let mut state = self.0.try_state()?;
    let accepted = state.conn.streams().accept(dir);
    match accepted {
        Some(stream) => {
            state.wake();
            let handshaking = state.conn.is_handshaking();
            Poll::Ready(Ok((stream, handshaking)))
        }
        None => {
            push_waker_if_not_last(&mut state.stream_opened[dir as usize], cx.waker());
            Poll::Pending
        }
    }
}
/// Waits for the peer to open a unidirectional stream and returns its receive half.
///
/// # Errors
///
/// Returns the connection error if the connection terminates first.
pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
let (stream, is_0rtt) = poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
}
/// Waits for the peer to open a bidirectional stream and returns both halves.
///
/// # Errors
///
/// Returns the connection error if the connection terminates first.
pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
let (stream, is_0rtt) = poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
Ok((
SendStream::new(self.0.clone(), stream, is_0rtt),
RecvStream::new(self.0.clone(), stream, is_0rtt),
))
}
}
/// Two `Connection` handles are equal when they share the same underlying state
/// (pointer equality).
impl PartialEq for Connection {
fn eq(&self, other: &Self) -> bool {
Shared::ptr_eq(&self.0, &other.0)
}
}
impl Eq for Connection {}
/// Dropping a `Connection` runs `implicit_close`, which closes the connection when only
/// two strong references to it remain.
impl Drop for Connection {
fn drop(&mut self) {
implicit_close(&self.0)
}
}
/// Resettable timer future used by the driver loop.
struct Timer {
/// Deadline passed to the most recent `reset`; `None` when the timer is disarmed.
deadline: Option<Instant>,
/// Sleep future for the current deadline, or a terminated future when disarmed.
fut: Fuse<LocalBoxFuture<'static, ()>>,
}
impl Timer {
    /// Creates a disarmed timer.
    fn new() -> Self {
        Self {
            deadline: None,
            fut: Fuse::terminated(),
        }
    }

    /// Re-arms the timer for `deadline`, or disarms it when `deadline` is `None`.
    ///
    /// The sleep future is rebuilt only when the deadline differs from the stored one,
    /// so resetting to an unchanged deadline keeps the existing future.
    fn reset(&mut self, deadline: Option<Instant>) {
        match deadline {
            None => self.fut = Fuse::terminated(),
            Some(at) if self.deadline != Some(at) => {
                self.fut = compio::runtime::time::sleep_until(at)
                    .boxed_local()
                    .fuse();
            }
            Some(_) => {}
        }
        self.deadline = deadline;
    }
}
/// Polling the timer polls its inner sleep future.
impl Future for Timer {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.fut.poll_unpin(cx)
}
}
/// The timer is terminated exactly when its inner fused future is, which includes the
/// disarmed state.
impl FusedFuture for Timer {
fn is_terminated(&self) -> bool {
self.fut.is_terminated()
}
}
/// Reasons a connection can be lost; mirrors `noq_proto::ConnectionError` variant by
/// variant.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
/// The peer doesn't implement any supported version.
#[error("peer doesn't implement any supported version")]
VersionMismatch,
/// A transport-level error, displayed transparently.
#[error(transparent)]
TransportError(#[from] noq_proto::TransportError),
/// The connection was aborted by the peer.
#[error("aborted by peer: {0}")]
ConnectionClosed(noq_proto::ConnectionClose),
/// The connection was closed by the peer application.
#[error("closed by peer: {0}")]
ApplicationClosed(noq_proto::ApplicationClose),
/// The connection was reset by the peer.
#[error("reset by peer")]
Reset,
/// The connection timed out.
#[error("timed out")]
TimedOut,
/// The connection was closed locally.
#[error("closed")]
LocallyClosed,
/// Connection IDs were exhausted.
#[error("CIDs exhausted")]
CidsExhausted,
}
/// One-to-one conversion from the protocol crate's error type.
impl From<noq_proto::ConnectionError> for ConnectionError {
fn from(value: noq_proto::ConnectionError) -> Self {
use noq_proto::ConnectionError::*;
// No wildcard arm, so a new upstream variant becomes a compile error here.
match value {
VersionMismatch => ConnectionError::VersionMismatch,
TransportError(e) => ConnectionError::TransportError(e),
ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
Reset => ConnectionError::Reset,
TimedOut => ConnectionError::TimedOut,
LocallyClosed => ConnectionError::LocallyClosed,
CidsExhausted => ConnectionError::CidsExhausted,
}
}
}
/// Errors returned when sending a datagram.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
/// The peer does not support datagrams.
#[error("datagrams not supported by peer")]
UnsupportedByPeer,
/// Datagram support is disabled.
#[error("datagram support disabled")]
Disabled,
/// The datagram is too large to send.
#[error("datagram too large")]
TooLarge,
/// The connection was lost.
#[error("connection lost")]
ConnectionLost(#[from] ConnectionError),
}
/// Errors returned by the non-waiting stream-open methods.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum OpenStreamError {
/// The connection was lost.
#[error("connection lost")]
ConnectionLost(#[from] ConnectionError),
/// No stream can be opened right now.
#[error("streams exhausted")]
StreamsExhausted,
}
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
use std::sync::Arc;
use compio::buf::bytes::Buf;
use futures_util::ready;
use h3::{
error::Code,
quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
};
use h3_datagram::{
datagram::EncodedDatagram,
quic_traits::{
DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
},
};
use super::*;
use crate::send_stream::h3_impl::SendStream;
impl From<ConnectionError> for ConnectionErrorIncoming {
fn from(e: ConnectionError) -> Self {
use ConnectionError::*;
match e {
ApplicationClosed(e) => Self::ApplicationClose {
error_code: e.error_code.into_inner(),
},
TimedOut => Self::Timeout,
e => Self::Undefined(Arc::new(e)),
}
}
}
impl From<ConnectionError> for StreamErrorIncoming {
fn from(e: ConnectionError) -> Self {
Self::ConnectionErrorIncoming {
connection_error: e.into(),
}
}
}
impl From<SendDatagramError> for SendDatagramErrorIncoming {
fn from(e: SendDatagramError) -> Self {
use SendDatagramError::*;
match e {
UnsupportedByPeer | Disabled => Self::NotAvailable,
TooLarge => Self::TooLarge,
ConnectionLost(e) => Self::ConnectionError(e.into()),
}
}
}
impl<B> SendDatagram<B> for Connection
where
    B: Buf,
{
    /// Flattens the encoded datagram into one contiguous `Bytes` value and
    /// hands it to the connection's own `send_datagram`.
    fn send_datagram<T: Into<EncodedDatagram<B>>>(
        &mut self,
        data: T,
    ) -> Result<(), SendDatagramErrorIncoming> {
        let mut encoded: EncodedDatagram<B> = data.into();
        let len = encoded.remaining();
        let payload = encoded.copy_to_bytes(len);
        // Use the path form so the inherent method is selected rather than
        // this trait method.
        Connection::send_datagram(self, payload).map_err(Into::into)
    }
}
impl RecvDatagram for Connection {
    type Buffer = Bytes;

    /// Polls the connection for the next datagram, converting any connection
    /// error into its `h3` counterpart.
    fn poll_incoming_datagram(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
        self.poll_recv_datagram(cx).map_err(Into::into)
    }
}
impl<B: Buf> DatagramConnectionExt<B> for Connection {
    type SendDatagramHandler = Self;
    type RecvDatagramHandler = Self;

    /// The connection itself sends datagrams, so the handler is a clone of it.
    fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
        Self::clone(self)
    }

    /// The connection itself receives datagrams, so the handler is a clone of it.
    fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
        Self::clone(self)
    }
}
/// A bidirectional stream for `h3`, made of a send half and a receive half
/// that are both constructed for the same stream id (see `BidiStream::new`).
pub struct BidiStream<B> {
// Send half; generic over the buffer type `B`.
send: SendStream<B>,
// Receive half.
recv: RecvStream,
}
impl<B> BidiStream<B> {
    /// Builds both halves of a bidirectional stream over the same connection
    /// handle and stream id, passing `is_0rtt` through to each half.
    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
        // The send half gets a clone of the handle; the receive half takes
        // ownership of the original.
        let send = SendStream::new(conn.clone(), stream, is_0rtt);
        let recv = RecvStream::new(conn, stream, is_0rtt);
        Self { send, recv }
    }
}
impl<B> quic::BidiStream<B> for BidiStream<B>
where
    B: Buf,
{
    type SendStream = SendStream<B>;
    type RecvStream = RecvStream;

    /// Consumes the stream and returns its send and receive halves.
    fn split(self) -> (Self::SendStream, Self::RecvStream) {
        let Self { send, recv } = self;
        (send, recv)
    }
}
/// Receive-side `h3` stream operations for [`BidiStream`]; every method
/// forwards to the `recv` half.
impl<B> quic::RecvStream for BidiStream<B>
where
B: Buf,
{
type Buf = Bytes;
/// Polls the receive half for the next chunk of data.
fn poll_data(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
self.recv.poll_data(cx)
}
/// Forwards `error_code` to the receive half's `stop_sending`.
fn stop_sending(&mut self, error_code: u64) {
self.recv.stop_sending(error_code)
}
/// Returns the stream id reported by the receive half.
fn recv_id(&self) -> quic::StreamId {
self.recv.recv_id()
}
}
/// Send-side `h3` stream operations for [`BidiStream`]; every method forwards
/// to the `send` half.
impl<B> quic::SendStream<B> for BidiStream<B>
where
B: Buf,
{
/// Polls the send half's `poll_ready`.
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
self.send.poll_ready(cx)
}
/// Hands `data` to the send half's `send_data`.
fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
self.send.send_data(data)
}
/// Polls the send half's `poll_finish`.
fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
self.send.poll_finish(cx)
}
/// Forwards `reset_code` to the send half's `reset`.
fn reset(&mut self, reset_code: u64) {
self.send.reset(reset_code)
}
/// Returns the stream id reported by the send half.
fn send_id(&self) -> quic::StreamId {
self.send.send_id()
}
}
/// Unframed sending for [`BidiStream`]; forwards to the `send` half.
impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
where
B: Buf,
{
/// Polls the send half's `poll_send` with the caller's buffer and returns
/// its result unchanged.
fn poll_send<D: Buf>(
&mut self,
cx: &mut Context<'_>,
buf: &mut D,
) -> Poll<Result<usize, StreamErrorIncoming>> {
self.send.poll_send(cx, buf)
}
}
#[derive(Clone)]
pub struct OpenStreams(Connection);
impl<B> quic::OpenStreams<B> for OpenStreams
where
B: Buf,
{
type BidiStream = BidiStream<B>;
type SendStream = SendStream<B>;
fn poll_open_bidi(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?;
Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt)))
}
fn poll_open_send(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?;
Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt)))
}
fn close(&mut self, code: Code, reason: &[u8]) {
self.0
.close(code.value().try_into().expect("invalid code"), reason)
}
}
impl<B> quic::OpenStreams<B> for Connection
where
B: Buf,
{
type BidiStream = BidiStream<B>;
type SendStream = SendStream<B>;
fn poll_open_bidi(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
}
fn poll_open_send(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
}
fn close(&mut self, code: Code, reason: &[u8]) {
Connection::close(self, code.value().try_into().expect("invalid code"), reason)
}
}
impl<B> quic::Connection<B> for Connection
where
    B: Buf,
{
    type RecvStream = RecvStream;
    type OpenStreams = OpenStreams;

    /// Polls for an incoming unidirectional stream and wraps it as a
    /// receive stream.
    fn poll_accept_recv(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
        let accepted = ready!(self.poll_accept_stream(cx, Dir::Uni));
        Poll::Ready(match accepted {
            Ok((id, zero_rtt)) => Ok(RecvStream::new(self.0.clone(), id, zero_rtt)),
            Err(lost) => Err(lost.into()),
        })
    }

    /// Polls for an incoming bidirectional stream and wraps both halves.
    fn poll_accept_bidi(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
        let accepted = ready!(self.poll_accept_stream(cx, Dir::Bi));
        Poll::Ready(match accepted {
            Ok((id, zero_rtt)) => Ok(BidiStream::new(self.0.clone(), id, zero_rtt)),
            Err(lost) => Err(lost.into()),
        })
    }

    /// Returns a handle that can open outgoing streams on a clone of this
    /// connection.
    fn opener(&self) -> Self::OpenStreams {
        let handle = self.clone();
        OpenStreams(handle)
    }
}
}