use async_trait::async_trait;
use futures::{future, future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt};
use std::{
fmt,
future::IntoFuture,
io::{Error, ErrorKind, Result},
sync::{Arc, Weak},
time::Duration,
};
use tokio::sync::{broadcast, mpsc, oneshot, watch, Mutex, OwnedSemaphorePermit, RwLock, Semaphore};
use tracing::Instrument;
use super::{BoxControl, BoxLink, BoxLinkError, BoxListener, BoxServer, BoxTask, LinkError, LinkTag, LinkTagBox};
use crate::{
alc::Channel,
exec,
exec::time::{sleep_until, Instant},
io::{StreamBox, TxRxBox},
Cfg, Server,
};
/// A stream accepted by an [`AcceptingTransport`], together with the tag of the link it
/// belongs to.
pub struct AcceptedStreamBox {
    /// The accepted stream.
    pub stream: StreamBox,
    /// Link tag of the stream; its transport name must equal the accepting transport's name.
    pub tag: LinkTagBox,
}

impl AcceptedStreamBox {
    /// Bundles `stream` with `tag`, boxing the tag.
    pub fn new(stream: StreamBox, tag: impl LinkTag) -> Self {
        let boxed_tag: LinkTagBox = Box::new(tag);
        Self { stream, tag: boxed_tag }
    }
}
/// A transport that listens for incoming streams and hands them to an [`Acceptor`].
#[async_trait]
pub trait AcceptingTransport: Send + Sync + 'static {
    /// Name of the transport.
    ///
    /// Every [`AcceptedStreamBox`] sent by this transport must carry a tag whose transport
    /// name equals this name; otherwise the acceptor stops the transport with an error.
    fn name(&self) -> &str;
    /// Listens for incoming connections and sends each accepted stream through `tx`.
    ///
    /// When this returns, the transport ends: its result is delivered to the
    /// `AcceptingTransportHandle`, and links established through it are disconnected.
    async fn listen(&self, tx: mpsc::Sender<AcceptedStreamBox>) -> Result<()>;
    /// Decides whether the new link `_new` may be added, given the links in `_existing`
    /// (presumably the other links of the same connection — TODO confirm).
    ///
    /// Every active transport is consulted for every new link of an accepted connection,
    /// and the link is rejected if any of them returns `false`. The default accepts all links.
    async fn link_filter(&self, _new: &BoxLink, _existing: &[BoxLink]) -> bool {
        true
    }
}
/// Shared, type-erased accepting transport.
type ArcAcceptingTransport = Arc<dyn AcceptingTransport>;
/// Callback applied to each accepted connection's task before the task is spawned.
type TaskCfgFn = Box<dyn Fn(&mut BoxTask) + Send + Sync + 'static>;
/// Transforms every stream accepted by any transport before it is added as a link.
///
/// Wrappers run in the order they were registered. If one fails, the stream is discarded
/// and a link error is broadcast.
#[async_trait]
pub trait AcceptingWrapper: Send + Sync + fmt::Debug + 'static {
    /// Name of the wrapper, used in log messages.
    fn name(&self) -> &str;
    /// Wraps the accepted stream `io`, returning the wrapped stream.
    async fn wrap(&self, io: StreamBox) -> Result<StreamBox>;
}
/// Boxed, type-erased accepting wrapper.
type BoxAcceptingWrapper = Box<dyn AcceptingWrapper>;
/// Everything the acceptor's background task needs to run a newly added transport.
struct AcceptingTransportPack {
    /// The transport to listen on.
    transport: ArcAcceptingTransport,
    /// Sends the transport's final result to its `AcceptingTransportHandle`.
    result_tx: oneshot::Sender<Result<()>>,
    /// Receives a value when removal is requested through `AcceptingTransportHandle::remove`.
    remove_rx: oneshot::Receiver<()>,
    /// Held from `Acceptor::add` until the transport task ends, so that
    /// `Acceptor::is_empty` does not report an empty acceptor in between.
    _permit: OwnedSemaphorePermit,
}
/// Builder for an [`Acceptor`] with custom configuration.
pub struct AcceptorBuilder {
    server: BoxServer,
    task_cfg: TaskCfgFn,
    wrappers: Vec<BoxAcceptingWrapper>,
    no_transport_timeout: Duration,
}

impl AcceptorBuilder {
    /// Default time `Acceptor::accept` waits while no transport is present.
    const DEFAULT_NO_TRANSPORT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Capacity of the broadcast channel carrying link errors.
    const LINK_ERROR_CAPACITY: usize = 1024;

    /// Creates a builder for an acceptor whose server uses the configuration `cfg`.
    ///
    /// Defaults: no task configuration, no wrappers, and a 30 second no-transport timeout.
    pub fn new(cfg: Cfg) -> Self {
        let server = Server::new(cfg);
        let task_cfg: TaskCfgFn = Box::new(|_| ());
        Self { server, task_cfg, wrappers: Vec::new(), no_transport_timeout: Self::DEFAULT_NO_TRANSPORT_TIMEOUT }
    }

    /// Sets a function that is applied to the task of every accepted connection
    /// before that task is spawned.
    pub fn set_task_cfg(&mut self, task_cfg: impl Fn(&mut BoxTask) + Send + Sync + 'static) {
        self.task_cfg = Box::new(task_cfg);
    }

    /// Sets how long `Acceptor::accept` waits while no transport is present
    /// before it fails with [`ErrorKind::BrokenPipe`].
    pub fn set_no_transport_timeout(&mut self, no_transport_timeout: Duration) {
        self.no_transport_timeout = no_transport_timeout;
    }

    /// Appends a wrapper. Wrappers are applied to accepted streams in the order they were added.
    pub fn wrap(&mut self, wrapper: impl AcceptingWrapper) {
        self.wrappers.push(Box::new(wrapper))
    }

    /// Builds the acceptor and spawns its background task.
    ///
    /// # Panics
    /// Panics if the newly created server cannot start listening.
    pub fn build(self) -> Acceptor {
        let Self { server, task_cfg, wrappers, no_transport_timeout } = self;
        let active_transports = Arc::new(RwLock::new(Vec::<Weak<dyn AcceptingTransport>>::new()));
        let (transport_tx, transport_rx) = mpsc::unbounded_channel();
        // Starts as `true`. The background task publishes the actual state on its first
        // iteration, so `accept`'s no-transport countdown only begins after that.
        let (transports_present_tx, transports_present_rx) = watch::channel(true);
        let (error_tx, error_rx) = broadcast::channel(Self::LINK_ERROR_CAPACITY);
        // The server was created in `new` and is only listened on here.
        let listener = Mutex::new(server.listen().expect("newly created server failed to listen"));
        exec::spawn(
            Acceptor::task(
                server.clone(),
                active_transports.clone(),
                transport_rx,
                error_tx,
                transports_present_tx,
                wrappers,
            )
            .in_current_span(),
        );
        Acceptor {
            server,
            listener,
            task_cfg,
            transport_tx,
            transports_present_rx,
            transports_being_added: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
            error_rx,
            active_transports,
            no_transport_timeout,
        }
    }
}
/// Accepts incoming connections whose links arrive over the transports added to it.
pub struct Acceptor {
    /// Server that accepted links are added to.
    server: BoxServer,
    /// Yields incoming connections; locked by `accept`.
    listener: Mutex<BoxListener>,
    /// Applied to each accepted connection's task before it is spawned.
    task_cfg: TaskCfgFn,
    /// Hands newly added transports to the background task.
    transport_tx: mpsc::UnboundedSender<AcceptingTransportPack>,
    /// Published by the background task: whether any transport task is running.
    transports_present_rx: watch::Receiver<bool>,
    /// One permit is held per transport, from `add` until its transport task ends.
    transports_being_added: Arc<Semaphore>,
    /// Transports consulted by the link filter. These are weak references, so transports
    /// whose task has ended are skipped.
    active_transports: Arc<RwLock<Vec<Weak<dyn AcceptingTransport>>>>,
    /// Template receiver for link errors; see `link_errors`.
    error_rx: broadcast::Receiver<BoxLinkError>,
    /// How long `accept` waits while no transport is present.
    no_transport_timeout: Duration,
}

impl fmt::Debug for Acceptor {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let server_id = self.server.id();
        f.debug_struct("Acceptor").field("id", &server_id).finish()
    }
}

impl Default for Acceptor {
    /// Equivalent to [`Acceptor::new`].
    fn default() -> Self {
        Acceptor::new()
    }
}
impl Acceptor {
pub fn new() -> Self {
AcceptorBuilder::new(Cfg::default()).build()
}
pub fn wrapped(wrapper: impl AcceptingWrapper) -> Self {
let mut builder = AcceptorBuilder::new(Cfg::default());
builder.wrap(wrapper);
builder.build()
}
pub fn add(&self, transport: impl AcceptingTransport) -> AcceptingTransportHandle {
let name = transport.name().to_string();
let (result_tx, result_rx) = oneshot::channel();
let (remove_tx, remove_rx) = oneshot::channel();
let pack = AcceptingTransportPack {
transport: Arc::new(transport),
result_tx,
remove_rx,
_permit: self.transports_being_added.clone().try_acquire_owned().unwrap(),
};
let _ = self.transport_tx.send(pack);
AcceptingTransportHandle { name, result_rx, remove_tx }
}
pub fn is_empty(&self) -> bool {
!*self.transports_present_rx.borrow()
&& self.transports_being_added.available_permits() == Semaphore::MAX_PERMITS
}
pub async fn accept(&self) -> Result<(Channel, BoxControl)> {
let mut transports_present_rx = self.transports_present_rx.clone();
let no_transport_timeout = self.no_transport_timeout;
let timeout = async move {
let mut until = None;
loop {
if *transports_present_rx.borrow_and_update() {
until = None;
} else if until.is_none() {
until = Instant::now().checked_add(no_transport_timeout);
}
let sleep_task = async {
match until {
Some(until) => sleep_until(until).await,
None => future::pending().await,
}
};
tokio::select! {
() = sleep_task => return Error::new(ErrorKind::BrokenPipe, "no listening transports available"),
res = transports_present_rx.changed() => {
if res.is_err() {
return Error::new(ErrorKind::BrokenPipe, "listener was terminated");
}
}
}
}
};
pin_mut!(timeout);
let mut listener = self.listener.lock().await;
let (mut task, channel, control) = tokio::select! {
res = listener.accept() => res?,
err = &mut timeout => return Err(err),
};
(self.task_cfg)(&mut task);
let active_transports = self.active_transports.clone();
task.set_link_filter(move |link, others| {
let active_transports = active_transports.clone();
async move {
let transports = active_transports.read_owned().await;
for transport in &*transports {
let Some(transport) = transport.upgrade() else { continue };
if !transport.link_filter(&link, &others).await {
return false;
}
}
true
}
});
exec::spawn(task.run().in_current_span());
tracing::debug!(conn_id =? control.id(), "accepted incoming connection");
Ok((channel, control))
}
pub fn link_errors(&self) -> broadcast::Receiver<BoxLinkError> {
self.error_rx.resubscribe()
}
#[tracing::instrument(name = "aggligator::acceptor", level = "info", skip_all, fields(server_id =? server.id()))]
async fn task(
server: BoxServer, active_transports: Arc<RwLock<Vec<Weak<dyn AcceptingTransport>>>>,
mut transport_rx: mpsc::UnboundedReceiver<AcceptingTransportPack>,
link_error_tx: broadcast::Sender<BoxLinkError>, transports_present_tx: watch::Sender<bool>,
wrappers: Vec<BoxAcceptingWrapper>,
) {
let wrappers = Arc::new(wrappers);
let mut transport_tasks = FuturesUnordered::new();
loop {
transports_present_tx.send_if_modified(|v| {
let present = !transport_tasks.is_empty();
if *v != present {
*v = present;
true
} else {
false
}
});
enum ListenerEvent {
TransportAdded(AcceptingTransportPack),
TaskEnded,
}
let event = tokio::select! {
res = transport_rx.recv() => {
match res {
Some(transport_pack) => ListenerEvent::TransportAdded(transport_pack),
None => break,
}
}
Some(()) = transport_tasks.next() => ListenerEvent::TaskEnded,
};
match event {
ListenerEvent::TransportAdded(transport_pack) => {
let mut active_transports = active_transports.write().await;
active_transports.retain(|at| at.strong_count() > 0);
active_transports.push(Arc::downgrade(&transport_pack.transport));
transport_tasks.push(Self::transport_task(
server.clone(),
transport_pack,
link_error_tx.clone(),
wrappers.clone(),
));
}
ListenerEvent::TaskEnded => (),
}
}
}
#[tracing::instrument(name = "transport", level = "info", skip_all, fields(name = transport.transport.name()))]
async fn transport_task(
server: BoxServer, transport: AcceptingTransportPack, link_error_tx: broadcast::Sender<BoxLinkError>,
wrappers: Arc<Vec<BoxAcceptingWrapper>>,
) {
let AcceptingTransportPack { transport, result_tx, remove_rx, _permit: _ } = transport;
let mut remove_rx = remove_rx.fuse();
let (tx, mut rx) = mpsc::channel(128);
let mut listener = transport.listen(tx);
let mut accepting_tasks = FuturesUnordered::new();
let res = loop {
let AcceptedStreamBox { stream: mut stream_box, tag } = tokio::select! {
Some(accepted) = rx.recv() => accepted,
Some(()) = accepting_tasks.next() => continue,
res = &mut listener => break res,
Ok(()) = &mut remove_rx => break Ok(()),
};
tracing::debug!(%tag, "accepted transport connection");
if tag.transport_name() != transport.name() {
break Err(Error::other("link tag transport name mismatch"));
}
let wrappers = &*wrappers;
let server = &server;
let link_error_tx = &link_error_tx;
let task = async move {
for wrapper in wrappers {
let name = wrapper.name();
tracing::debug!(%tag, wrapper =% name, "wrapping");
match wrapper.wrap(stream_box).await {
Ok(wrapped) => stream_box = wrapped,
Err(err) => {
tracing::debug!(%tag, wrapper =% name, %err, "wrapping failed");
let _ = link_error_tx.send(BoxLinkError::incoming(&tag, err));
return;
}
}
}
tracing::debug!(%tag, "adding link to connection");
let user_data = tag.user_data();
let TxRxBox { tx, rx } = stream_box.into_tx_rx();
let link = match server.add_incoming(tx, rx, tag.clone(), &user_data).await {
Ok(link) => link,
Err(err) => {
tracing::warn!(%tag, %err, "adding link to connection failed");
let _ = link_error_tx.send(LinkError::incoming(&tag, err.into()));
return;
}
};
tracing::info!(link_id =? link.id(), %tag, conn_id =? link.conn_id(), "link connected");
struct DisconnectLink<'a>(&'a BoxLink);
impl Drop for DisconnectLink<'_> {
fn drop(&mut self) {
self.0.start_disconnect();
}
}
let _disconnect_link = DisconnectLink(&link);
let reason = link.disconnected().await;
tracing::info!(link_id =? link.id(), %tag, %reason, "link disconnected");
let _ = link_error_tx.send(BoxLinkError::incoming(&tag, reason.into()));
};
accepting_tasks.push(task);
};
if let Err(err) = &res {
tracing::warn!(%err, "transport failed");
}
let _ = result_tx.send(res);
}
}
/// Handle to a transport added to an [`Acceptor`].
///
/// Awaiting the handle yields the transport's result. Dropping the handle leaves the
/// transport running; call [`remove`](Self::remove) to stop it.
pub struct AcceptingTransportHandle {
    name: String,
    result_rx: oneshot::Receiver<Result<()>>,
    remove_tx: oneshot::Sender<()>,
}

impl fmt::Debug for AcceptingTransportHandle {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("AcceptingTransportHandle");
        dbg.field("name", &self.name);
        dbg.finish()
    }
}

impl AcceptingTransportHandle {
    /// Name the transport reported when it was added.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Requests removal of the transport, which causes links established through it
    /// to be disconnected.
    pub fn remove(self) {
        // The transport task may already be gone; then there is nobody to notify.
        let _ = self.remove_tx.send(());
    }
}

impl IntoFuture for AcceptingTransportHandle {
    type Output = Result<()>;
    type IntoFuture = BoxFuture<'static, Result<()>>;

    /// Waits for the transport to finish and yields its result.
    ///
    /// Resolves to `Ok(())` if the transport task went away without reporting a result.
    fn into_future(self) -> Self::IntoFuture {
        let result_rx = self.result_rx;
        async move { result_rx.await.unwrap_or(Ok(())) }.boxed()
    }
}