use std::{
future::{Future, IntoFuture},
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use crate::{ConnectionError, ConnectionId, ServerConfig};
use thiserror::Error;
use tracing::error;
use super::{
connection::{Connecting, Connection},
endpoint::EndpointRef,
};
/// Handle for an incoming connection attempt that has not been acted on yet.
///
/// The wrapped state is held in an `Option` so that the consuming methods
/// (`accept`, `accept_with`, `refuse`, `retry`, `ignore`) and the `Drop`
/// implementation can each take it out at most once.
#[derive(Debug)]
pub struct Incoming(Option<State>);
impl Incoming {
pub(crate) fn new(inner: crate::Incoming, endpoint: EndpointRef) -> Self {
Self(Some(State { inner, endpoint }))
}
pub fn accept(mut self) -> Result<Connecting, ConnectionError> {
let state = self.0.take().ok_or_else(|| {
error!("Incoming connection state already consumed");
ConnectionError::LocallyClosed
})?;
state.endpoint.accept(state.inner, None)
}
pub fn accept_with(
mut self,
server_config: Arc<ServerConfig>,
) -> Result<Connecting, ConnectionError> {
let state = self.0.take().ok_or_else(|| {
error!("Incoming connection state already consumed");
ConnectionError::LocallyClosed
})?;
state.endpoint.accept(state.inner, Some(server_config))
}
pub fn refuse(mut self) {
if let Some(state) = self.0.take() {
state.endpoint.refuse(state.inner);
} else {
error!("Incoming connection state already consumed");
}
}
pub fn retry(mut self) -> Result<(), RetryError> {
let state = match self.0.take() {
Some(state) => state,
None => {
error!("Incoming connection state already consumed");
return Err(RetryError::incoming(self));
}
};
let State { inner, endpoint } = state;
match endpoint.retry(inner) {
Ok(()) => Ok(()),
Err(err) => Err(RetryError::incoming(Incoming::new(
err.into_incoming(),
endpoint,
))),
}
}
pub fn ignore(mut self) {
if let Some(state) = self.0.take() {
state.endpoint.ignore(state.inner);
} else {
error!("Incoming connection state already consumed");
}
}
pub fn local_ip(&self) -> Option<IpAddr> {
self.0.as_ref()?.inner.local_ip()
}
pub fn remote_address(&self) -> SocketAddr {
self.0
.as_ref()
.map(|state| state.inner.remote_address())
.unwrap_or_else(|| SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0))
}
pub fn remote_address_validated(&self) -> bool {
self.0
.as_ref()
.map(|state| state.inner.remote_address_validated())
.unwrap_or(false)
}
pub fn may_retry(&self) -> bool {
self.0
.as_ref()
.map(|state| state.inner.may_retry())
.unwrap_or(false)
}
pub fn orig_dst_cid(&self) -> ConnectionId {
self.0
.as_ref()
.map(|state| *state.inner.orig_dst_cid())
.unwrap_or_else(|| ConnectionId::new(&[]))
}
}
impl Drop for Incoming {
fn drop(&mut self) {
if let Some(state) = self.0.take() {
state.endpoint.refuse(state.inner);
}
}
}
/// Data held by an `Incoming` until one of its consuming methods takes it.
#[derive(Debug)]
struct State {
// The wrapped crate-level incoming connection attempt.
inner: crate::Incoming,
// Endpoint handle on which accept/refuse/retry/ignore are invoked.
endpoint: EndpointRef,
}
/// Error returned by `Incoming::retry`.
///
/// Carries the `Incoming` back to the caller; use `into_incoming` to recover it.
#[derive(Debug, Error)]
pub enum RetryError {
// Boxed `Incoming` handed back by the failed `retry` call.
#[error("retry() with invalid Incoming")]
Incoming(Box<Incoming>),
}
impl RetryError {
    /// Builds a `RetryError` that owns `incoming` (boxed).
    pub fn incoming(incoming: Incoming) -> Self {
        let boxed = Box::new(incoming);
        Self::Incoming(boxed)
    }

    /// Consumes the error and returns the `Incoming` it carries.
    pub fn into_incoming(self) -> Incoming {
        // The enum has a single variant, so this pattern is irrefutable.
        let Self::Incoming(boxed) = self;
        *boxed
    }
}
/// Future produced by awaiting an `Incoming` directly.
///
/// Holds the result of `Incoming::accept`: either the in-progress
/// `Connecting` to poll, or the error to report on first poll.
#[derive(Debug)]
pub struct IncomingFuture(Result<Connecting, ConnectionError>);
impl Future for IncomingFuture {
    type Output = Result<Connection, ConnectionError>;

    /// Polls the wrapped `Connecting`, or resolves immediately with a clone
    /// of the stored error when accepting failed.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let connecting = match self.0.as_mut() {
            Ok(connecting) => connecting,
            // The error is cloned rather than moved, so it stays in place.
            Err(err) => return Poll::Ready(Err(err.clone())),
        };
        Pin::new(connecting).poll(cx)
    }
}
impl IntoFuture for Incoming {
    type Output = Result<Connection, ConnectionError>;
    type IntoFuture = IncomingFuture;

    /// Calls `accept` on the handle and wraps its result so that it can be
    /// awaited.
    fn into_future(self) -> Self::IntoFuture {
        let accepted = self.accept();
        IncomingFuture(accepted)
    }
}
#[cfg(test)]
mod tests {
    use super::{Incoming, RetryError};

    /// `retry` on a handle with no state must fail and hand back a handle
    /// that still has no state.
    #[test]
    fn retry_on_consumed_incoming_returns_error() {
        let consumed = Incoming(None);
        let RetryError::Incoming(returned) = consumed.retry().unwrap_err();
        assert!(returned.0.is_none());
    }
}