use async_trait::async_trait;
use futures::{
future::{self, BoxFuture},
stream::FuturesUnordered,
FutureExt, StreamExt,
};
use std::{
collections::HashSet,
fmt::{self, Debug},
future::IntoFuture,
io::{Error, Result},
iter,
sync::{Arc, Weak},
time::Duration,
};
use tokio::sync::{broadcast, mpsc, oneshot, watch, RwLock};
use tracing::Instrument;
use super::{BoxControl, BoxLink, BoxLinkError, BoxTask, LinkTag, LinkTagBox};
use crate::{
connect,
control::DisconnectReason,
exec,
exec::time::sleep,
io::{StreamBox, TxRxBox},
Cfg, Link, Outgoing,
};
/// A transport that can establish outgoing streams for link tags.
///
/// Registered with a [`Connector`] via `Connector::add`. The connector drives
/// `link_tags` to learn which tags are available and calls `connect` for each
/// tag that should be connected.
#[async_trait]
pub trait ConnectingTransport: Send + Sync + 'static {
    /// Name of the transport.
    ///
    /// Every tag published by this transport must report the same name via
    /// `LinkTag::transport_name`; otherwise the transport task fails with a
    /// "link tag transport name mismatch" error.
    fn name(&self) -> &str;
    /// Publishes the currently available link tags on `tx`.
    ///
    /// Completion of this future ends the transport: its result becomes the
    /// result reported through the transport's `ConnectingTransportHandle`.
    async fn link_tags(&self, tx: watch::Sender<HashSet<LinkTagBox>>) -> Result<()>;
    /// Opens a stream for `tag`.
    async fn connect(&self, tag: &dyn LinkTag) -> Result<StreamBox>;
    /// Decides whether the new link may be added alongside the existing links.
    ///
    /// Consulted by the connection's link filter for every still-alive
    /// transport; the link is rejected if any transport returns `false`.
    /// Defaults to accepting every link.
    async fn link_filter(&self, _new: &Link<LinkTagBox>, _existing: &[Link<LinkTagBox>]) -> bool {
        true
    }
    /// Called with the connection's current links (as returned by
    /// `control.links()`) each time the transport task reconciles its state.
    /// Default does nothing.
    async fn connected_links(&self, _links: &[Link<LinkTagBox>]) {}
}
/// Shared, type-erased connecting transport.
type ArcConnectingTransport = Arc<dyn ConnectingTransport>;
/// Wraps each stream returned by a transport before it is added to the
/// connection as a link.
///
/// Wrappers are applied in the order they were registered. A wrapping error
/// aborts that connection attempt and is reported as a link error.
#[async_trait]
pub trait ConnectingWrapper: Send + Sync + fmt::Debug + 'static {
    /// Name of the wrapper, used in log messages.
    fn name(&self) -> &str;
    /// Wraps `io`, returning the stream to use in its place.
    async fn wrap(&self, io: StreamBox) -> Result<StreamBox>;
}
/// Boxed, type-erased connecting wrapper.
type BoxConnectingWrapper = Box<dyn ConnectingWrapper>;
/// A transport queued for the connector task, together with the channel
/// ends that connect it to its `ConnectingTransportHandle`.
struct TransportPack {
    /// The transport itself.
    transport: ArcConnectingTransport,
    /// Delivers the transport's final result to the handle.
    result_tx: oneshot::Sender<Result<()>>,
    /// Resolves when `ConnectingTransportHandle::remove` is called.
    remove_rx: oneshot::Receiver<()>,
}
/// Builder for a [`Connector`].
///
/// Allows the connection task, the reconnect delay and the stream wrappers
/// to be configured before the connector is started by `build`.
#[derive(Debug)]
pub struct ConnectorBuilder {
    /// Connection task; spawned by `build`.
    task: BoxTask,
    /// Outgoing channel, later handed out by `Connector::channel`.
    outgoing: Outgoing,
    /// Control handle of the connection.
    control: BoxControl,
    /// Delay applied before a tag is retried after a failed attempt or a disconnect.
    reconnect_delay: Duration,
    /// Stream wrappers, applied in insertion order.
    wrappers: Vec<BoxConnectingWrapper>,
}
impl ConnectorBuilder {
pub fn new(cfg: Cfg) -> Self {
let (task, outgoing, control) = connect(cfg);
Self { task, outgoing, control, reconnect_delay: Duration::from_secs(10), wrappers: Vec::new() }
}
pub fn task(&mut self) -> &mut BoxTask {
&mut self.task
}
pub fn set_reconnect_delay(&mut self, reconnect_delay: Duration) {
self.reconnect_delay = reconnect_delay
}
pub fn wrap(&mut self, wrapper: impl ConnectingWrapper) {
self.wrappers.push(Box::new(wrapper))
}
pub fn build(self) -> Connector {
let Self { mut task, outgoing, control, reconnect_delay, wrappers } = self;
let active_transports = Arc::new(RwLock::new(Vec::<Weak<dyn ConnectingTransport>>::new()));
let active_transports_filter = active_transports.clone();
task.set_link_filter(move |link, others| {
let active_transports_filter = active_transports_filter.clone();
async move {
let transports = active_transports_filter.read_owned().await;
for transport in &*transports {
let Some(transport) = transport.upgrade() else { continue };
if !transport.link_filter(&link, &others).await {
return false;
}
}
true
}
});
exec::spawn(task.run().in_current_span());
let (transport_tx, transport_rx) = mpsc::unbounded_channel();
let (tags_tx, tags_rx) = watch::channel(HashSet::new());
let (error_tx, error_rx) = broadcast::channel(1024);
let (disabled_tags_tx, disabled_tags_rx) = watch::channel(HashSet::new());
exec::spawn(
Connector::task(
control.clone(),
active_transports,
transport_rx,
tags_tx,
disabled_tags_rx,
error_tx,
reconnect_delay,
wrappers,
)
.in_current_span(),
);
Connector { control, outgoing: Some(outgoing), transport_tx, tags_rx, error_rx, disabled_tags_tx }
}
}
/// Establishes and maintains the links of an outgoing connection using the
/// registered [`ConnectingTransport`]s.
pub struct Connector {
    /// Control handle of the connection.
    control: BoxControl,
    /// Outgoing channel; taken by `channel`, `None` afterwards.
    outgoing: Option<Outgoing>,
    /// Queue of transports added via `add`, consumed by the connector task.
    transport_tx: mpsc::UnboundedSender<TransportPack>,
    /// Union of the tags offered by all transports, published by the connector task.
    tags_rx: watch::Receiver<HashSet<LinkTagBox>>,
    /// Disabled tags, observed by every transport task.
    disabled_tags_tx: watch::Sender<HashSet<LinkTagBox>>,
    /// Never read directly; kept so `link_errors` can `resubscribe`.
    error_rx: broadcast::Receiver<BoxLinkError>,
}
impl fmt::Debug for Connector {
    /// Formats the connector as `Connector { id: … }`, using the connection id.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Connector");
        out.field("id", &self.control.id());
        out.finish()
    }
}
impl Default for Connector {
fn default() -> Self {
Self::new()
}
}
impl Connector {
    /// Creates a connector using the default configuration and no wrappers.
    ///
    /// Shorthand for `ConnectorBuilder::new(Cfg::default()).build()`.
    pub fn new() -> Self {
        ConnectorBuilder::new(Cfg::default()).build()
    }
    /// Creates a connector using the default configuration.
    ///
    /// Every stream is passed through `wrapper` before it is added as a link.
    pub fn wrapped(wrapper: impl ConnectingWrapper) -> Self {
        let mut builder = ConnectorBuilder::new(Cfg::default());
        builder.wrap(wrapper);
        builder.build()
    }
    /// Registers a transport with the connector task.
    ///
    /// The returned handle can be awaited for the transport's final result, or
    /// used to remove the transport. Dropping the handle does not remove the
    /// transport. If the connector task has already exited, the transport is
    /// dropped unused and awaiting the handle yields `Ok(())`.
    pub fn add(&self, transport: impl ConnectingTransport) -> ConnectingTransportHandle {
        let name = transport.name().to_string();
        let (result_tx, result_rx) = oneshot::channel();
        let (remove_tx, remove_rx) = oneshot::channel();
        let pack = TransportPack { transport: Arc::new(transport), result_tx, remove_rx };
        // A failed send means the connector task is gone; `result_tx` is dropped
        // with the pack, so the handle resolves to `Ok(())`.
        let _ = self.transport_tx.send(pack);
        ConnectingTransportHandle { name, result_rx, remove_tx }
    }
    /// Takes the outgoing channel of the connection.
    ///
    /// Returns `Some` only on the first call; later calls return `None`.
    pub fn channel(&mut self) -> Option<Outgoing> {
        self.outgoing.take()
    }
    /// Returns a clone of the connection's control handle.
    pub fn control(&self) -> BoxControl {
        self.control.clone()
    }
    /// Returns a snapshot of the link tags currently offered by all transports.
    pub fn available_tags(&self) -> HashSet<LinkTagBox> {
        self.tags_rx.borrow().clone()
    }
    /// Returns a watch receiver for the set of link tags offered by all transports.
    pub fn available_tags_watch(&self) -> watch::Receiver<HashSet<LinkTagBox>> {
        self.tags_rx.clone()
    }
    /// Replaces the set of disabled tags.
    ///
    /// Transports do not connect disabled tags. Existing links whose tag is
    /// disabled are asked to disconnect.
    pub fn set_disabled_tags(&self, disabled_tags: HashSet<LinkTagBox>) {
        self.disabled_tags_tx.send_replace(disabled_tags);
    }
    /// Subscribes to link errors.
    ///
    /// These cover failures to connect, wrap or add a link, as well as link
    /// disconnections. Per tokio's `resubscribe`, only errors sent after this
    /// call are received.
    pub fn link_errors(&self) -> broadcast::Receiver<BoxLinkError> {
        self.error_rx.resubscribe()
    }
    /// Connector supervisor task, spawned once by `ConnectorBuilder::build`.
    ///
    /// Accepts newly added transports from `transport_rx` and drives one
    /// `transport_task` per transport inside this future (they are not spawned
    /// separately). It publishes the union of all transports' tags on `tags_tx`.
    ///
    /// Runs until the connection is terminated; it keeps running after the
    /// `Connector` itself is dropped. When it exits, all transport tasks are
    /// dropped with it.
    #[allow(clippy::too_many_arguments)]
    #[tracing::instrument(name = "aggligator::connector", level = "info", skip_all, fields(conn_id =? control.id()))]
    async fn task(
        control: BoxControl, active_transports: Arc<RwLock<Vec<Weak<dyn ConnectingTransport>>>>,
        mut transport_rx: mpsc::UnboundedReceiver<TransportPack>, tags_tx: watch::Sender<HashSet<LinkTagBox>>,
        disabled_tags_rx: watch::Receiver<HashSet<LinkTagBox>>, link_error_tx: broadcast::Sender<BoxLinkError>,
        reconnect_delay: Duration, wrappers: Vec<BoxConnectingWrapper>,
    ) {
        // Shared by all transport tasks.
        let wrappers = Arc::new(wrappers);
        let mut transport_tasks = FuturesUnordered::new();
        // One tags receiver per running transport task.
        let mut transport_tags: Vec<watch::Receiver<HashSet<LinkTagBox>>> = Vec::new();
        loop {
            // Drop receivers whose transport task has ended: its sender is gone,
            // so `has_changed` returns an error.
            transport_tags.retain(|tt| tt.has_changed().is_ok());
            // Recompute the union of all transports' tags; publish only if it differs.
            let mut all_tags = HashSet::new();
            for tt in &mut transport_tags {
                let tags = tt.borrow_and_update();
                for tag in &*tags {
                    all_tags.insert(tag.clone());
                }
            }
            tags_tx.send_if_modified(|tags| {
                if *tags == all_tags {
                    false
                } else {
                    *tags = all_tags;
                    true
                }
            });
            // Completes when any transport's tags change or its sender closes.
            // The trailing pending future keeps `select_all` from being handed
            // an empty iterator when no transports are registered.
            let tags_changed = future::select_all(
                transport_tags
                    .iter_mut()
                    .map(|tt| tt.changed().boxed())
                    .chain(iter::once(future::pending().boxed())),
            );
            enum ConnectorEvent {
                TransportAdded(TransportPack),
                TagsChanged,
                TransportTerminated,
            }
            let event = tokio::select! {
                Some(transport_pack) = transport_rx.recv() => ConnectorEvent::TransportAdded(transport_pack),
                _ = tags_changed => ConnectorEvent::TagsChanged,
                Some(()) = transport_tasks.next() => ConnectorEvent::TransportTerminated,
                _ = control.terminated() => {
                    tracing::debug!("connection was terminated");
                    break;
                }
            };
            match event {
                ConnectorEvent::TransportAdded(transport_pack) => {
                    // Register the transport for link filtering, pruning entries
                    // whose transport has already been dropped.
                    let mut active_transports = active_transports.write().await;
                    active_transports.retain(|at| at.strong_count() > 0);
                    active_transports.push(Arc::downgrade(&transport_pack.transport));
                    let (transport_tags_tx, transport_tags_rx) = watch::channel(HashSet::new());
                    transport_tags.push(transport_tags_rx);
                    transport_tasks.push(Self::transport_task(
                        transport_pack,
                        control.clone(),
                        transport_tags_tx,
                        disabled_tags_rx.clone(),
                        link_error_tx.clone(),
                        reconnect_delay,
                        wrappers.clone(),
                    ));
                }
                // Nothing to do here: the loop head recomputes tags and prunes
                // receivers of finished transports.
                ConnectorEvent::TagsChanged => (),
                ConnectorEvent::TransportTerminated => (),
            }
        }
    }
    /// Drives a single transport until it fails, is removed, or the connection terminates.
    ///
    /// Runs the transport's `link_tags` future and forwards its tags to the
    /// connector task via `tags_fw_tx`. It starts a connection attempt for every
    /// tag that is not already linked, being connected, disabled, or blocked by
    /// the link filter.
    ///
    /// The final result is sent on the pack's `result_tx`. When this function
    /// returns, pending connection attempts are dropped, and their
    /// `DisconnectLink` guards start disconnecting any links they established.
    #[tracing::instrument(name = "transport", level = "info", skip_all, fields(name = transport_pack.transport.name()))]
    async fn transport_task(
        transport_pack: TransportPack, control: BoxControl, tags_fw_tx: watch::Sender<HashSet<LinkTagBox>>,
        mut disabled_tags_rx: watch::Receiver<HashSet<LinkTagBox>>,
        link_error_tx: broadcast::Sender<BoxLinkError>, reconnect_delay: Duration,
        wrappers: Arc<Vec<BoxConnectingWrapper>>,
    ) {
        let TransportPack { transport, result_tx, remove_rx } = transport_pack;
        // Fused so `select!` can keep polling it after it has resolved, e.g. with
        // an error because the handle was dropped without calling `remove`.
        let mut remove_rx = remove_rx.fuse();
        let conn_id = control.id();
        // Separate clone used only for `links_changed()`.
        // NOTE(review): presumably because that method needs `&mut` access while
        // `control` is shared with the connection attempts below — confirm.
        let mut changed_control = control.clone();
        // Tags reported by the transport; its `link_tags` future owns the sender.
        let (tags_tx, mut tags_rx) = watch::channel(HashSet::new());
        let mut tags_task = transport.link_tags(tags_tx);
        // Starts `true` so the initial tag set is forwarded on the first iteration.
        let mut tags_changed = true;
        // Tags with a connection attempt, live link, or reconnect wait owned by `connecting_tasks`.
        let mut connecting_tags = HashSet::new();
        let mut connecting_tasks = FuturesUnordered::new();
        // Tags whose last link was rejected by the link filter; skipped until cleared.
        let mut link_filter_rejected_tags = HashSet::new();
        let res = 'outer: loop {
            // Reconciliation pass: disconnect disabled links, forward tags,
            // and start connecting missing tags.
            {
                let links = control.links();
                transport.connected_links(&links).await;
                // This borrow guard lives until the end of the block.
                // No `.await` happens while it is held.
                let disabled_tags = disabled_tags_rx.borrow_and_update();
                for link in &links {
                    if disabled_tags.contains(link.tag()) {
                        link.start_disconnect();
                    }
                }
                let tags = tags_rx.borrow_and_update().clone();
                if tags_changed {
                    tracing::debug!(
                        "available tags: {}",
                        tags.iter().map(|tag| tag.to_string()).collect::<Vec<_>>().join(", ")
                    );
                    tags_fw_tx.send_replace(tags.clone());
                    tags_changed = false;
                }
                for tag in tags {
                    // Every tag must belong to this transport; otherwise the transport fails.
                    if tag.transport_name() != transport.name() {
                        break 'outer Err(Error::other("link tag transport name mismatch"));
                    }
                    // Skip tags that are in progress, disabled, blocked, or already linked.
                    if connecting_tags.contains(&tag)
                        || disabled_tags.contains(&tag)
                        || link_filter_rejected_tags.contains(&tag)
                        || links.iter().any(|link| link.tag() == &tag)
                    {
                        continue;
                    }
                    tracing::debug!(%tag, "connecting tag");
                    connecting_tags.insert(tag.clone());
                    // Connect, wrap and add the link, then wait for it to disconnect.
                    // Resolves to `(tag, None)` after a failure (following a
                    // `reconnect_delay` sleep), or to `(tag, Some(reason))` once an
                    // established link has disconnected.
                    let connect_task = async {
                        tracing::debug!(%tag, "establishing transport connection for tag");
                        let mut stream_box = match transport.connect(&*tag).await {
                            Ok(stream_box) => stream_box,
                            Err(err) => {
                                tracing::debug!(%tag, %err, "connecting transport for tag failed");
                                let _ = link_error_tx.send(BoxLinkError::outgoing(conn_id, &tag, err));
                                sleep(reconnect_delay).await;
                                return (tag, None);
                            }
                        };
                        // Apply wrappers in registration order.
                        for wrapper in &*wrappers {
                            let name = wrapper.name();
                            tracing::debug!(%tag, wrapper =% name, "wrapping tag");
                            match wrapper.wrap(stream_box).await {
                                Ok(wrapped) => stream_box = wrapped,
                                Err(err) => {
                                    tracing::debug!(%tag, wrapper =% name, %err, "wrapping tag failed");
                                    let _ = link_error_tx.send(BoxLinkError::outgoing(conn_id, &tag, err));
                                    sleep(reconnect_delay).await;
                                    return (tag, None);
                                }
                            }
                        }
                        tracing::debug!(%tag, "adding link to connection");
                        let TxRxBox { tx, rx } = stream_box.into_tx_rx();
                        let link = match control.add(tx, rx, tag.clone(), &tag.user_data()).await {
                            Ok(link) => link,
                            Err(err) => {
                                tracing::warn!(%tag, %err, "adding link to connection failed");
                                let _ = link_error_tx.send(BoxLinkError::outgoing(conn_id, &tag, err.into()));
                                sleep(reconnect_delay).await;
                                return (tag, None);
                            }
                        };
                        tracing::info!(link_id =? link.id(), %tag, "link connected");
                        // Guard that calls `start_disconnect` when dropped. It fires
                        // if this future is dropped while the link is up (e.g. the
                        // transport task ends), and also on normal completion.
                        struct DisconnectLink<'a>(&'a BoxLink);
                        impl Drop for DisconnectLink<'_> {
                            fn drop(&mut self) {
                                self.0.start_disconnect();
                            }
                        }
                        let _disconnect_link = DisconnectLink(&link);
                        // Created before waiting for the disconnect. NOTE(review):
                        // assuming the timer starts at creation (as with
                        // `tokio::time::sleep`), the reconnect delay counts from link
                        // establishment, not from disconnection — confirm for
                        // `exec::time::sleep`.
                        let sleep_until = sleep(reconnect_delay);
                        let reason = link.disconnected().await;
                        tracing::info!(link_id =? link.id(), %tag, %reason, "link disconnected");
                        let _ = link_error_tx.send(BoxLinkError::outgoing(conn_id, &tag, reason.clone().into()));
                        sleep_until.await;
                        (tag, Some(reason))
                    };
                    connecting_tasks.push(connect_task);
                }
            }
            tokio::select! {
                // `link_tags` finished: its result ends the transport.
                res = &mut tags_task => break res,
                // `ConnectingTransportHandle::remove` was called.
                Ok(()) = &mut remove_rx => break Ok(()),
                // The next three arms only trigger another reconciliation pass.
                Ok(()) = disabled_tags_rx.changed() => (),
                Ok(()) = tags_rx.changed() => tags_changed = true,
                () = changed_control.links_changed() => (),
                _ = control.terminated() => break Ok(()),
                // A connection attempt finished (failed, or its link disconnected).
                Some((tag, reason)) = connecting_tasks.next() => {
                    connecting_tags.remove(&tag);
                    match reason {
                        // The link filter rejected this link: stop retrying this tag.
                        Some(DisconnectReason::LinkFilter) => {
                            tracing::debug!(%tag, "blocking tag");
                            link_filter_rejected_tags.insert(tag);
                        }
                        // Any other disconnect clears the block list, so rejected
                        // tags are tried again.
                        Some(_) => {
                            tracing::debug!("clearing tag block list");
                            link_filter_rejected_tags.clear();
                        }
                        // The attempt failed before a link was established.
                        None => (),
                    }
                },
            }
        };
        match &res {
            Ok(()) => tracing::debug!("transport terminated"),
            Err(err) => tracing::warn!(%err, "transport failed"),
        }
        // The receiver is gone if the handle was dropped; the result is then discarded.
        let _ = result_tx.send(res);
    }
}
/// Handle to a transport added with `Connector::add`.
///
/// Await it (via `IntoFuture`) to obtain the transport's final result, or call
/// `remove` to stop the transport. Dropping the handle leaves the transport running.
pub struct ConnectingTransportHandle {
    /// Name reported by the transport when it was added.
    name: String,
    /// Receives the transport's final result.
    result_rx: oneshot::Receiver<Result<()>>,
    /// Signals the transport task to stop.
    remove_tx: oneshot::Sender<()>,
}
impl fmt::Debug for ConnectingTransportHandle {
    /// Formats the handle as `ConnectingTransportHandle { name: "…" }`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ConnectingTransportHandle");
        out.field("name", &self.name);
        out.finish()
    }
}
impl ConnectingTransportHandle {
    /// Returns the name reported by the transport when it was added.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Requests that the transport be stopped.
    ///
    /// The transport task then finishes with `Ok(())`. If the transport has
    /// already finished, the request is silently ignored.
    pub fn remove(self) {
        self.remove_tx.send(()).ok();
    }
}
impl IntoFuture for ConnectingTransportHandle {
type Output = Result<()>;
type IntoFuture = BoxFuture<'static, Result<()>>;
fn into_future(self) -> Self::IntoFuture {
let Self { result_rx, .. } = self;
async move {
match result_rx.await {
Ok(res) => res,
Err(_) => Ok(()),
}
}
.boxed()
}
}