use crate::forwarder::ForwarderHandle;
use crate::mux::{
ForwardingVersionData, HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE,
PROTOCOL_TRACE_OBJECT, TraceForwardClient, version_table_v1,
};
use crate::protocol::TraceObject;
use crate::server::config::Address;
use pallas_network::multiplexer::{Bearer, Plexer};
use std::sync::Arc;
use tokio::net::{TcpListener, UnixListener};
use tokio::sync::broadcast;
use tokio::task::JoinSet;
use tracing::{debug, info, warn};
/// Re-forwards trace batches to a single destination, optionally keeping only the traces
/// whose namespace starts with one of a set of configured prefixes.
pub struct ReForwarder {
    /// Destination that receives the traces surviving the filter.
    inner: ReForwarderInner,
    /// Namespace prefixes; a trace passes if it matches any of them. `None` lets every
    /// trace through.
    namespace_filters: Option<Vec<Vec<String>>>,
}
/// Delivery target of a [`ReForwarder`].
enum ReForwarderInner {
    /// Send each surviving trace individually through a forwarder handle.
    Outbound(ForwarderHandle),
    /// Publish all surviving traces of one call as a single batch on a broadcast channel.
    Inbound(broadcast::Sender<Arc<Vec<TraceObject>>>),
}
impl ReForwarder {
pub fn new(handle: ForwarderHandle, namespace_filters: Option<Vec<Vec<String>>>) -> Self {
ReForwarder {
inner: ReForwarderInner::Outbound(handle),
namespace_filters,
}
}
pub fn new_inbound(
tx: broadcast::Sender<Arc<Vec<TraceObject>>>,
namespace_filters: Option<Vec<Vec<String>>>,
) -> Self {
ReForwarder {
inner: ReForwarderInner::Inbound(tx),
namespace_filters,
}
}
pub async fn forward(&self, traces: &[TraceObject]) {
let filtered: Vec<TraceObject> = traces
.iter()
.filter(|t| self.matches_filter(t))
.cloned()
.collect();
if filtered.is_empty() {
return;
}
match &self.inner {
ReForwarderInner::Outbound(handle) => {
for trace in filtered {
if let Err(e) = handle.send(trace).await {
warn!("ReForwarder send error: {}", e);
}
}
}
ReForwarderInner::Inbound(tx) => {
let _ = tx.send(Arc::new(filtered));
}
}
}
fn matches_filter(&self, trace: &TraceObject) -> bool {
let Some(filters) = &self.namespace_filters else {
return true; };
filters
.iter()
.any(|prefix| trace.to_namespace.starts_with(prefix))
}
}
pub async fn run_accepting_loop(
addrs: &[Address],
tx: broadcast::Sender<Arc<Vec<TraceObject>>>,
network_magic: u64,
) {
let mut set = JoinSet::new();
for addr in addrs {
let addr = addr.clone();
let tx = tx.clone();
set.spawn(async move {
listen_and_accept(addr, tx, network_magic).await;
});
}
while set.join_next().await.is_some() {}
}
/// Binds one listening endpoint and serves every connection it accepts.
///
/// For [`Address::LocalPipe`], whatever file already exists at the path is removed first,
/// with errors ignored, so that a leftover socket file does not make `bind` fail.
/// NOTE(review): this also removes a non-socket file at that path; confirm callers never
/// point it at one. For [`Address::RemoteSocket`], the listener binds `host:port` over TCP.
///
/// Each accepted connection gets its own `tx.subscribe()` receiver and is served by
/// [`handle_accepting_connection`] on a separate task. Accept errors that only concern the
/// incoming connection (see `is_transient_accept_error`) are logged, and the listener keeps
/// accepting. Any other accept error is logged and stops this listener. A bind failure is
/// logged and the function returns.
async fn listen_and_accept(
    addr: Address,
    tx: broadcast::Sender<Arc<Vec<TraceObject>>>,
    network_magic: u64,
) {
    match &addr {
        Address::LocalPipe(path) => {
            let _ = std::fs::remove_file(path);
            let listener = match UnixListener::bind(path) {
                Ok(l) => l,
                Err(e) => {
                    warn!(
                        "AcceptingReForwarder: failed to bind {}: {}",
                        path.display(),
                        e
                    );
                    return;
                }
            };
            info!("AcceptingReForwarder: listening on {}", path.display());
            loop {
                match Bearer::accept_unix(&listener).await {
                    Ok((bearer, _)) => {
                        let rx = tx.subscribe();
                        tokio::spawn(handle_accepting_connection(bearer, rx, network_magic));
                    }
                    Err(e) => {
                        warn!("AcceptingReForwarder accept error: {}", e);
                        // A single failed handshake must not take the whole listener down.
                        if !is_transient_accept_error(&e) {
                            break;
                        }
                    }
                }
            }
        }
        Address::RemoteSocket(host, port) => {
            let bind_addr = format!("{}:{}", host, port);
            let listener = match TcpListener::bind(&bind_addr).await {
                Ok(l) => l,
                Err(e) => {
                    warn!(
                        "AcceptingReForwarder: failed to bind TCP {}: {}",
                        bind_addr, e
                    );
                    return;
                }
            };
            info!("AcceptingReForwarder: listening on TCP {}", bind_addr);
            loop {
                match Bearer::accept_tcp(&listener).await {
                    Ok((bearer, _)) => {
                        let rx = tx.subscribe();
                        tokio::spawn(handle_accepting_connection(bearer, rx, network_magic));
                    }
                    Err(e) => {
                        warn!("AcceptingReForwarder TCP accept error: {}", e);
                        // A single failed handshake must not take the whole listener down.
                        if !is_transient_accept_error(&e) {
                            break;
                        }
                    }
                }
            }
        }
    }
}

/// Returns `true` for accept errors that concern only the one incoming connection: the
/// peer aborted, reset or refused it before it was accepted, or the call was interrupted.
/// The listening socket remains usable after such errors, so the accept loop should
/// continue.
fn is_transient_accept_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
            | std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::Interrupted
    )
}
/// Serves a single connection accepted by [`listen_and_accept`].
///
/// Sets up a multiplexer over `bearer`, then runs the handshake on [`PROTOCOL_HANDSHAKE`].
/// It accepts the highest version the peer proposes that is also present in
/// `version_table_v1(network_magic)`, and sends `Refuse` when there is none. After the
/// handshake it repeatedly waits for a batch from `rx`, drains any further batches that are
/// already queued, and hands the combined traces to `TraceForwardClient::handle_request` on
/// [`PROTOCOL_TRACE_OBJECT`]. The EKG and data-point channels are subscribed and dropped
/// immediately; nothing here services them.
///
/// The function returns when any of these happens:
/// - the handshake fails;
/// - the downstream closes the trace protocol or reports an error;
/// - the broadcast channel is closed.
///
/// Lagging behind the broadcast channel is logged and the missed batches are skipped.
///
/// On every exit path the multiplexer tasks are aborted explicitly. Dropping the handle
/// returned by `Plexer::spawn` only detaches them, because dropping a tokio `JoinHandle`
/// does not cancel the task. NOTE(review): based on pallas-network's `RunningPlexer`;
/// confirm against the pinned version.
async fn handle_accepting_connection(
    bearer: Bearer,
    mut rx: broadcast::Receiver<Arc<Vec<TraceObject>>>,
    network_magic: u64,
) {
    use pallas_network::multiplexer::ChannelBuffer;

    let mut plexer = Plexer::new(bearer);
    let hs_ch = plexer.subscribe_server(PROTOCOL_HANDSHAKE);
    let trace_ch = plexer.subscribe_server(PROTOCOL_TRACE_OBJECT);
    // Subscribed and discarded immediately; this function never reads these protocols.
    drop(plexer.subscribe_server(PROTOCOL_EKG));
    drop(plexer.subscribe_server(PROTOCOL_DATA_POINT));
    let running = plexer.spawn();

    // Every early exit is a `return` out of this async block, not out of the function,
    // so the abort below runs on every path.
    let session = async move {
        let mut hs = ChannelBuffer::new(hs_ch);
        let versions = version_table_v1(network_magic);
        let msg: HandshakeMessage = match hs.recv_full_msg().await {
            Ok(m) => m,
            Err(e) => {
                warn!("AcceptingReForwarder: handshake recv failed: {}", e);
                return;
            }
        };
        match msg {
            HandshakeMessage::Propose(proposed) => {
                // Highest version offered by the peer that we also support.
                let chosen = proposed
                    .keys()
                    .filter(|v| versions.contains_key(v))
                    .max()
                    .copied();
                match chosen {
                    Some(ver) => {
                        let accept =
                            HandshakeMessage::Accept(ver, ForwardingVersionData { network_magic });
                        if let Err(e) = hs.send_msg_chunks(&accept).await {
                            warn!("AcceptingReForwarder: handshake accept send failed: {}", e);
                            return;
                        }
                        debug!("AcceptingReForwarder: handshake accepted v={}", ver);
                    }
                    None => {
                        let offered: Vec<u64> = proposed.into_keys().collect();
                        let _ = hs.send_msg_chunks(&HandshakeMessage::Refuse(offered)).await;
                        warn!("AcceptingReForwarder: no compatible version");
                        return;
                    }
                }
            }
            other => {
                warn!("AcceptingReForwarder: expected Propose, got {:?}", other);
                return;
            }
        }

        let mut client = TraceForwardClient::new(trace_ch);
        loop {
            let batch: Arc<Vec<TraceObject>> = loop {
                match rx.recv().await {
                    Ok(b) => break b,
                    Err(broadcast::error::RecvError::Closed) => {
                        info!("AcceptingReForwarder: broadcast channel closed");
                        return;
                    }
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        warn!("AcceptingReForwarder: lagged by {} batches, skipping", n);
                        continue;
                    }
                }
            };
            // Coalesce everything already queued so one `handle_request` call carries it all.
            let mut traces: Vec<TraceObject> = (*batch).clone();
            while let Ok(extra) = rx.try_recv() {
                traces.extend_from_slice(&extra);
            }
            match client.handle_request(traces).await {
                Ok(()) => {}
                Err(crate::mux::ClientError::ConnectionClosed) => {
                    info!("AcceptingReForwarder: downstream sent Done");
                    return;
                }
                Err(e) => {
                    warn!("AcceptingReForwarder: trace error: {}", e);
                    return;
                }
            }
        }
    };
    session.await;

    running.abort().await;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::{DetailLevel, Severity, TraceObject};
    use chrono::Utc;

    /// Builds a placeholder `TraceObject` whose only meaningful field is its namespace.
    fn make_trace(namespace: Vec<&str>) -> TraceObject {
        let to_namespace: Vec<String> = namespace.iter().map(|part| part.to_string()).collect();
        TraceObject {
            to_namespace,
            to_human: None,
            to_machine: String::from("{}"),
            to_severity: Severity::Info,
            to_details: DetailLevel::DNormal,
            to_timestamp: Utc::now(),
            to_hostname: String::from("host"),
            to_thread_id: String::from("1"),
        }
    }

    /// Converts a borrowed namespace prefix into the owned form `ReForwarder` expects.
    fn prefix(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    #[tokio::test]
    async fn no_filter_forwards_all_traces() {
        let (tx, mut rx) = broadcast::channel(16);
        let forwarder = ReForwarder::new_inbound(tx, None);
        forwarder
            .forward(&[make_trace(vec!["A", "B"]), make_trace(vec!["C"])])
            .await;
        let batch = rx.recv().await.unwrap();
        assert_eq!(batch.len(), 2);
    }

    #[tokio::test]
    async fn prefix_filter_blocks_non_matching_namespace() {
        let (tx, mut rx) = broadcast::channel(16);
        let forwarder = ReForwarder::new_inbound(tx, Some(vec![prefix(&["Cardano", "Node"])]));
        forwarder
            .forward(&[
                make_trace(vec!["Cardano", "Node", "Peers"]),
                make_trace(vec!["Other", "Trace"]),
            ])
            .await;
        let batch = rx.recv().await.unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].to_namespace, vec!["Cardano", "Node", "Peers"]);
    }

    #[tokio::test]
    async fn prefix_filter_exact_match_passes() {
        let (tx, mut rx) = broadcast::channel(16);
        let forwarder = ReForwarder::new_inbound(tx, Some(vec![prefix(&["Cardano", "Node"])]));
        forwarder.forward(&[make_trace(vec!["Cardano", "Node"])]).await;
        assert_eq!(rx.recv().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn filter_all_out_sends_nothing() {
        let (tx, mut rx) = broadcast::channel(16);
        let forwarder = ReForwarder::new_inbound(tx, Some(vec![prefix(&["Cardano"])]));
        forwarder.forward(&[make_trace(vec!["Other"])]).await;
        assert!(rx.try_recv().is_err(), "nothing should be broadcast");
    }

    #[tokio::test]
    async fn multiple_prefixes_any_match_passes() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = vec![prefix(&["Cardano"]), prefix(&["Node"])];
        let forwarder = ReForwarder::new_inbound(tx, Some(filters));
        forwarder
            .forward(&[
                make_trace(vec!["Cardano", "X"]),
                make_trace(vec!["Node", "Y"]),
                make_trace(vec!["Other"]),
            ])
            .await;
        assert_eq!(rx.recv().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn empty_input_sends_nothing() {
        let (tx, mut rx) = broadcast::channel(16);
        let forwarder = ReForwarder::new_inbound(tx, None);
        forwarder.forward(&[]).await;
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn inbound_with_no_receivers_does_not_panic() {
        let (tx, rx) = broadcast::channel::<Arc<Vec<TraceObject>>>(16);
        drop(rx);
        let forwarder = ReForwarder::new_inbound(tx, None);
        forwarder.forward(&[make_trace(vec!["A"])]).await;
    }

    #[tokio::test]
    async fn inbound_broadcasts_to_multiple_receivers() {
        let (tx, mut first) = broadcast::channel(16);
        let mut second = tx.subscribe();
        let forwarder = ReForwarder::new_inbound(tx, None);
        forwarder.forward(&[make_trace(vec!["A"])]).await;
        for rx in [&mut first, &mut second] {
            assert_eq!(rx.recv().await.unwrap().len(), 1);
        }
    }
}