use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use std::time::Duration;
use proto::{
ClosePathError, ClosedPath, PathError, PathEvent, PathId, PathStats, PathStatus,
SetPathStatusError, TransportErrorCode,
};
use tokio::sync::watch;
use tokio_stream::{Stream, wrappers::WatchStream};
use crate::connection::ConnectionRef;
use crate::{Runtime, WeakConnectionHandle};
/// Future that resolves once an attempt to open a path has an outcome.
///
/// Its [`Future`] output is `Result<Path, PathError>`. The wrapped [`OpenPathInner`] records
/// whether the attempt is still in progress, was rejected, or is already usable.
pub struct OpenPath(OpenPathInner);
/// Internal state of an [`OpenPath`] future.
enum OpenPathInner {
/// The outcome is not known yet; it is awaited on `opened`.
Ongoing {
/// Stream over changes of the watch channel that carries the open result.
opened: WatchStream<Result<(), PathError>>,
/// Identifier of the path being opened.
path_id: PathId,
/// Connection the path belongs to; cloned into the resulting [`Path`].
conn: ConnectionRef,
},
/// The attempt failed before the future was created; polling returns `err`.
Rejected {
/// Error returned when the future is polled.
err: PathError,
},
/// The path can be handed out right away; polling returns a [`Path`] immediately.
Ready {
/// Identifier of the path.
path_id: PathId,
/// Connection the path belongs to; cloned into the resulting [`Path`].
conn: ConnectionRef,
},
}
impl OpenPath {
    /// Creates a future for a path whose open attempt is still in progress.
    ///
    /// Only changes published on `opened` after this call are observed, because the receiver
    /// is wrapped with `WatchStream::from_changes`.
    pub(crate) fn new(
        path_id: PathId,
        opened: watch::Receiver<Result<(), PathError>>,
        conn: ConnectionRef,
    ) -> Self {
        let opened = WatchStream::from_changes(opened);
        let inner = OpenPathInner::Ongoing {
            opened,
            path_id,
            conn,
        };
        Self(inner)
    }

    /// Creates a future that yields a [`Path`] for `path_id` as soon as it is polled.
    pub(crate) fn ready(path_id: PathId, conn: ConnectionRef) -> Self {
        let inner = OpenPathInner::Ready { path_id, conn };
        Self(inner)
    }

    /// Creates a future that yields `Err(err)` as soon as it is polled.
    pub(crate) fn rejected(err: PathError) -> Self {
        let inner = OpenPathInner::Rejected { err };
        Self(inner)
    }

    /// Returns the identifier of the path this future refers to.
    ///
    /// `None` is returned only for a rejected attempt, which carries no path id.
    pub fn path_id(&self) -> Option<PathId> {
        match &self.0 {
            OpenPathInner::Ongoing { path_id, .. } | OpenPathInner::Ready { path_id, .. } => {
                Some(*path_id)
            }
            OpenPathInner::Rejected { .. } => None,
        }
    }
}
impl Future for OpenPath {
    type Output = Result<Path, PathError>;

    /// Drives the open attempt.
    ///
    /// * `Ongoing`: waits for the next item of the `opened` stream. `Ok(())` becomes a new
    ///   [`Path`], an error is forwarded unchanged, and a finished stream is reported as
    ///   [`PathError::ValidationFailed`].
    /// * `Ready`: resolves immediately with a new [`Path`].
    /// * `Rejected`: resolves immediately with the stored error.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match &mut this.0 {
            OpenPathInner::Rejected { err } => Poll::Ready(Err(*err)),
            OpenPathInner::Ready { path_id, conn } => {
                Poll::Ready(Ok(Path::new_unchecked(conn.clone(), *path_id)))
            }
            OpenPathInner::Ongoing {
                opened,
                path_id,
                conn,
            } => {
                let next = ready!(Pin::new(opened).poll_next(ctx));
                let outcome = match next {
                    Some(Ok(())) => Ok(Path::new_unchecked(conn.clone(), *path_id)),
                    Some(Err(err)) => Err(err),
                    // The stream ended without ever publishing a result.
                    None => Err(PathError::ValidationFailed),
                };
                Poll::Ready(outcome)
            }
        }
    }
}
/// Handle to a single path of a connection, identified by its [`PathId`].
///
/// Creating or cloning a handle increments a per-path reference count kept in the connection
/// state, and dropping a handle decrements it again.
#[derive(Debug)]
pub struct Path {
// Identifier of the path this handle refers to.
id: PathId,
// Strong reference to the connection that owns the path.
conn: ConnectionRef,
}
impl Clone for Path {
    /// Returns another handle to the same path.
    ///
    /// The per-path reference count is incremented before the connection reference is cloned.
    fn clone(&self) -> Self {
        let Self { id, conn } = self;
        // The guard is a temporary, so it is released at the end of this statement and is no
        // longer held when `conn.clone()` runs.
        conn.lock_without_waking("Path::clone")
            .increment_path_refs(*id);
        Self {
            id: *id,
            conn: conn.clone(),
        }
    }
}
impl Drop for Path {
    /// Gives back this handle's share of the per-path reference count.
    fn drop(&mut self) {
        self.conn
            .lock_without_waking("Path::drop")
            .decrement_path_refs(self.id);
    }
}
impl Path {
    /// Creates a handle for path `id` on `conn`.
    ///
    /// Returns `None` when the connection reports no status for `id`. On success the per-path
    /// reference count is incremented before the handle is built.
    pub(crate) fn new(conn: &ConnectionRef, id: PathId) -> Option<Self> {
        {
            // Scoped so the guard is released before `conn.clone()` below.
            let mut state = conn.lock_without_waking("Path::new");
            state.inner.path_status(id).ok()?;
            state.increment_path_refs(id);
        }
        Some(Self {
            id,
            conn: conn.clone(),
        })
    }

    /// Creates a handle for path `id` without checking that the path is known.
    ///
    /// The per-path reference count is incremented unconditionally.
    fn new_unchecked(conn: ConnectionRef, id: PathId) -> Self {
        conn.lock_without_waking("Path::new_unchecked")
            .increment_path_refs(id);
        Self { id, conn }
    }

    /// Returns a handle to this path that holds the connection only weakly.
    ///
    /// The weak handle takes its own share of the per-path reference count.
    pub fn weak_handle(&self) -> WeakPathHandle {
        self.conn
            .lock_without_waking("Path::weak_handle")
            .increment_path_refs(self.id);
        WeakPathHandle {
            id: self.id,
            conn: self.conn.weak_handle(),
        }
    }

    /// Returns the identifier of this path.
    pub fn id(&self) -> PathId {
        self.id
    }

    /// Returns the current status of this path as reported by the connection.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosedPath`] error returned by the connection.
    pub fn status(&self) -> Result<PathStatus, ClosedPath> {
        self.conn
            .lock_without_waking("path status")
            .inner
            .path_status(self.id)
    }

    /// Asks the connection to set the status of this path and wakes the connection driver.
    ///
    /// The value returned by the connection on success is discarded.
    ///
    /// # Errors
    ///
    /// Forwards the [`SetPathStatusError`] returned by the connection.
    pub fn set_status(&self, status: PathStatus) -> Result<(), SetPathStatusError> {
        self.conn
            .lock_and_wake("set path status")
            .inner
            .set_path_status(self.id, status)?;
        Ok(())
    }

    /// Returns the statistics the connection keeps for this path.
    ///
    /// # Panics
    ///
    /// Panics if the connection has no statistics for this path; the `expect` message states
    /// that this should not happen while a [`Path`] is alive.
    pub fn stats(&self) -> PathStats {
        self.conn
            .lock_without_waking("Path::stats")
            .path_stats(self.id)
            .expect("either path stats or discarded path stats are always set as long as Path is not dropped")
    }

    /// Closes this path with the `APPLICATION_ABANDON_PATH` error code and wakes the
    /// connection driver.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosePathError`] returned by the connection.
    pub fn close(&self) -> Result<(), ClosePathError> {
        let mut state = self.conn.lock_and_wake("close_path");
        // Use the runtime's clock, exactly like `set_max_idle_timeout`, so the close is
        // timestamped consistently with the rest of the connection's timers even when the
        // runtime supplies its own notion of time.
        let now = state.runtime.now();
        state.inner.close_path(
            now,
            self.id,
            TransportErrorCode::APPLICATION_ABANDON_PATH.into(),
        )
    }

    /// Sets the maximum idle timeout of this path and wakes the connection driver.
    ///
    /// Returns whatever the connection returns on success.
    /// NOTE(review): presumably the previously configured timeout; confirm against the
    /// connection's `set_path_max_idle_timeout`.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosedPath`] error returned by the connection.
    pub fn set_max_idle_timeout(
        &self,
        timeout: Option<Duration>,
    ) -> Result<Option<Duration>, ClosedPath> {
        let mut state = self.conn.lock_and_wake("path_set_max_idle_timeout");
        let now = state.runtime.now();
        state.inner.set_path_max_idle_timeout(now, self.id, timeout)
    }

    /// Sets the keep-alive interval of this path and wakes the connection driver.
    ///
    /// Returns whatever the connection returns on success.
    /// NOTE(review): presumably the previously configured interval; confirm against the
    /// connection's `set_path_keep_alive_interval`.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosedPath`] error returned by the connection.
    pub fn set_keep_alive_interval(
        &self,
        interval: Option<Duration>,
    ) -> Result<Option<Duration>, ClosedPath> {
        let mut state = self.conn.lock_and_wake("path_set_keep_alive_interval");
        state.inner.set_path_keep_alive_interval(self.id, interval)
    }

    /// Returns a stream of the addresses observed for this path.
    ///
    /// The subscription to path events is taken while the connection lock is held and before
    /// the current value is read. The current observed address, if any, seeds the stream.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosedPath`] error returned when reading the current observed address.
    pub fn observed_external_addr(&self) -> Result<AddressDiscovery, ClosedPath> {
        let state = self.conn.lock_without_waking("per_path_observed_address");
        let path_events = state.path_events.subscribe();
        let initial_value = state.inner.path_observed_address(self.id)?;
        Ok(AddressDiscovery::new(
            self.id,
            path_events,
            initial_value,
            state.runtime.clone(),
        ))
    }

    /// Returns the remote address of this path's network path.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosedPath`] error returned by the connection.
    pub fn remote_address(&self) -> Result<SocketAddr, ClosedPath> {
        let state = self.conn.lock_without_waking("per_path_remote_address");
        Ok(state.inner.network_path(self.id)?.remote())
    }

    /// Requests a ping on this path and wakes the connection driver.
    ///
    /// # Errors
    ///
    /// Forwards the [`ClosedPath`] error returned by the connection.
    pub fn ping(&self) -> Result<(), ClosedPath> {
        let mut state = self.conn.lock_and_wake("ping");
        state.inner.ping_path(self.id)
    }
}
impl PartialEq for Path {
    /// Two handles are equal when they carry the same path id and their connections report
    /// the same stable id.
    fn eq(&self, other: &Self) -> bool {
        if self.id != other.id {
            return false;
        }
        self.conn.stable_id() == other.conn.stable_id()
    }
}
/// Handle to a path that refers to its connection through a [`WeakConnectionHandle`].
///
/// Use [`WeakPathHandle::upgrade`] to obtain a [`Path`]; this yields `None` when the
/// connection reference can no longer be upgraded.
#[derive(Debug)]
pub struct WeakPathHandle {
// Identifier of the path this handle refers to.
id: PathId,
// Weak handle to the connection that owns the path.
conn: WeakConnectionHandle,
}
impl Clone for WeakPathHandle {
    /// Returns another weak handle to the same path.
    ///
    /// If the connection can still be upgraded, the per-path reference count is incremented
    /// for the new handle; otherwise only the weak connection handle is cloned.
    fn clone(&self) -> Self {
        let id = self.id;
        let strong = self.conn.upgrade_to_ref();
        if let Some(conn) = &strong {
            conn.lock_without_waking("WeakPathHandle::clone")
                .increment_path_refs(id);
        }
        // Release the temporary strong reference before cloning the weak handle.
        drop(strong);
        Self {
            id,
            conn: self.conn.clone(),
        }
    }
}
impl PartialEq for WeakPathHandle {
    /// Two weak handles are equal when they carry the same path id and their weak connection
    /// handles refer to the same connection.
    fn eq(&self, other: &Self) -> bool {
        if self.id != other.id {
            return false;
        }
        self.conn.is_same_connection(&other.conn)
    }
}
// Marker impl declaring the `PartialEq` comparison of `WeakPathHandle` (path id plus
// same-connection check) to be a full equivalence relation.
impl Eq for WeakPathHandle {}
impl Drop for WeakPathHandle {
    /// Gives back this handle's share of the per-path reference count, provided the
    /// connection can still be upgraded; otherwise nothing is done.
    fn drop(&mut self) {
        let Some(conn) = self.conn.upgrade_to_ref() else {
            return;
        };
        conn.lock_without_waking("WeakPathHandle::drop")
            .decrement_path_refs(self.id);
    }
}
impl WeakPathHandle {
    /// Returns the identifier of the path this handle refers to.
    pub fn id(&self) -> PathId {
        self.id
    }

    /// Tries to turn this weak handle into a [`Path`].
    ///
    /// Returns `None` when the connection reference cannot be upgraded. No check is made that
    /// the path itself is still known to the connection.
    pub fn upgrade(&self) -> Option<Path> {
        self.conn
            .upgrade_to_ref()
            .map(|conn| Path::new_unchecked(conn, self.id))
    }
}
/// Stream of [`SocketAddr`] values observed for a single path.
///
/// Items come from a watch channel that is fed by a task spawned in `AddressDiscovery::new`.
pub struct AddressDiscovery {
// Stream over the watch channel holding the most recently observed address.
watcher: WatchStream<SocketAddr>,
}
impl AddressDiscovery {
    /// Builds the stream for path `path_id` and spawns the task that feeds it.
    ///
    /// * `path_events`: broadcast receiver of path events; only `ObservedAddr` and `Discarded`
    ///   events whose id equals `path_id` are acted upon.
    /// * `initial_value`: address already known for the path, if any. When present it is the
    ///   first item of the stream; otherwise the stream only yields later changes.
    /// * `runtime`: used to spawn the detached filter task.
    ///
    /// The filter task ends when the path is discarded, when receiving from `path_events`
    /// fails, or once it notices, on the next received event, that the returned stream has
    /// been dropped.
    pub(super) fn new(
        path_id: PathId,
        mut path_events: tokio::sync::broadcast::Receiver<PathEvent>,
        initial_value: Option<SocketAddr>,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        // A watch channel needs a starting value. Without a real one a placeholder is stored;
        // it is never yielded because the stream is then built with `from_changes` below.
        let (tx, rx) = watch::channel(
            initial_value.unwrap_or_else(|| SocketAddr::new([0, 0, 0, 0].into(), 0)),
        );
        let filter = async move {
            loop {
                let event = path_events.recv().await;
                // `send_if_modified` does not report missing receivers, so check explicitly:
                // once the stream side is gone there is nobody left to notify and the task
                // would otherwise keep draining events for the lifetime of the path.
                if tx.is_closed() {
                    break;
                }
                match event {
                    Ok(PathEvent::ObservedAddr { id, addr: observed }) if id == path_id => {
                        // Store the new address; only notify when it actually changed.
                        tx.send_if_modified(|addr| {
                            let old = std::mem::replace(addr, observed);
                            old != *addr
                        });
                    }
                    // The path is gone: dropping `tx` ends the stream.
                    Ok(PathEvent::Discarded { id, .. }) if id == path_id => break,
                    // Events for other paths or of other kinds are ignored.
                    Ok(_) => {}
                    // Any receive error terminates the stream.
                    Err(_) => break,
                }
            }
        };
        let watcher = if initial_value.is_some() {
            WatchStream::new(rx)
        } else {
            WatchStream::from_changes(rx)
        };
        runtime.spawn(Box::pin(filter));
        Self { watcher }
    }
}
impl Stream for AddressDiscovery {
    type Item = SocketAddr;

    /// Delegates to the inner watch stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.watcher).poll_next(cx)
    }
}