use crate::mux::{
ForwardingVersionData, HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE,
PROTOCOL_TRACE_OBJECT, TraceAcceptorClient, version_table_v1,
};
use crate::server::config::{Address, Network};
use crate::server::datapoint::DataPointClient;
use crate::server::ekg::EkgPoller;
use crate::server::logging::LogWriter;
use crate::server::node::{NodeId, TracerState};
use crate::server::reforwarder::ReForwarder;
use crate::server::trace_handler::handle_traces;
use pallas_network::multiplexer::{Bearer, ChannelBuffer, Plexer};
use std::sync::Arc;
use tokio::net::{TcpListener, UnixListener};
use tokio::task::JoinSet;
use tokio::time::Duration;
use tracing::{debug, info, warn};
/// Entry point for the network side of the tracer.
///
/// Dispatches on the configured [`Network`] mode:
/// - `ConnectTo(addrs)`: dial out to every listed address, each with its own
///   reconnect loop.
/// - `AcceptAt(addr)`: listen on a single address and serve every peer that
///   connects.
///
/// The shared tracer state, log writer and optional re-forwarder are passed
/// through unchanged to whichever mode runs.
pub async fn run_network(
    network: &Network,
    state: Arc<TracerState>,
    writer: Arc<LogWriter>,
    reforwarder: Option<Arc<ReForwarder>>,
) -> anyhow::Result<()> {
    match network {
        Network::ConnectTo(targets) => {
            run_connect_clients(targets, state, writer, reforwarder).await
        }
        Network::AcceptAt(listen_addr) => {
            run_accept_server(listen_addr, state, writer, reforwarder).await
        }
    }
}
async fn run_accept_server(
addr: &Address,
state: Arc<TracerState>,
writer: Arc<LogWriter>,
reforwarder: Option<Arc<ReForwarder>>,
) -> anyhow::Result<()> {
match addr {
Address::LocalPipe(path) => {
let _ = std::fs::remove_file(path);
let listener = UnixListener::bind(path)?;
info!("Listening on Unix socket {}", path.display());
let mut counter = 0u64;
loop {
let (bearer, _) = Bearer::accept_unix(&listener).await?;
counter += 1;
let node_id = format!("unix-{}", counter);
spawn_handler(
bearer,
node_id,
state.clone(),
writer.clone(),
reforwarder.clone(),
);
}
}
Address::RemoteSocket(host, port) => {
let bind_addr = format!("{}:{}", host, port);
let listener = TcpListener::bind(&bind_addr).await?;
info!("Listening on TCP {}", bind_addr);
loop {
let (bearer, peer) = Bearer::accept_tcp(&listener).await?;
let node_id = peer.to_string();
spawn_handler(
bearer,
node_id,
state.clone(),
writer.clone(),
reforwarder.clone(),
);
}
}
}
}
/// Dials every address in `addrs` concurrently, one reconnecting task per address.
///
/// Each task runs `connect_with_retry`, whose loop never exits on its own, so in
/// practice this function keeps running. It returns `Ok(())` only once every task
/// has finished (for example by panicking). Task join results, including panics,
/// are not inspected.
async fn run_connect_clients(
    addrs: &[Address],
    state: Arc<TracerState>,
    writer: Arc<LogWriter>,
    reforwarder: Option<Arc<ReForwarder>>,
) -> anyhow::Result<()> {
    let mut clients = JoinSet::new();
    for target in addrs.iter().cloned() {
        // Each task owns its own handles to the shared resources.
        let (task_state, task_writer, task_rf) =
            (state.clone(), writer.clone(), reforwarder.clone());
        clients.spawn(async move {
            connect_with_retry(&target, task_state, task_writer, task_rf).await
        });
    }

    // Drain the set; finished tasks are simply discarded.
    while let Some(_finished) = clients.join_next().await {}
    Ok(())
}
async fn connect_with_retry(
addr: &Address,
state: Arc<TracerState>,
writer: Arc<LogWriter>,
reforwarder: Option<Arc<ReForwarder>>,
) {
let node_id = addr.to_node_id();
let mut delay = 1u64;
loop {
info!("Connecting to {}", node_id);
let bearer_result = match addr {
Address::LocalPipe(path) => Bearer::connect_unix(path).await.map_err(|e| e.into()),
Address::RemoteSocket(host, port) => {
let addr_str = format!("{}:{}", host, port);
Bearer::connect_tcp(&addr_str)
.await
.map_err(|e| anyhow::anyhow!(e))
}
};
match bearer_result {
Ok(bearer) => {
delay = 1; if let Err(e) = handle_connection(
bearer,
node_id.clone(),
state.clone(),
writer.clone(),
reforwarder.clone(),
true,
)
.await
{
warn!("Connection to {} ended: {}", node_id, e);
}
}
Err(e) => {
warn!(
"Failed to connect to {}: {}, retrying in {}s",
node_id, e, delay
);
}
}
tokio::time::sleep(Duration::from_secs(delay)).await;
delay = (delay * 2).min(45);
}
}
/// Serves an accepted (inbound) connection on a detached tokio task.
///
/// The connection is handled in the responder role (`is_initiator = false`).
/// If `handle_connection` returns an error, it is logged as a warning together
/// with the node id. Nothing is reported back to the caller.
fn spawn_handler(
    bearer: Bearer,
    node_id: NodeId,
    state: Arc<TracerState>,
    writer: Arc<LogWriter>,
    reforwarder: Option<Arc<ReForwarder>>,
) {
    tokio::spawn(async move {
        // Keep a copy of the id for the log line; the original moves into the handler.
        let label = node_id.clone();
        let outcome =
            handle_connection(bearer, node_id, state, writer, reforwarder, false).await;
        if let Err(err) = outcome {
            warn!("Connection handler for {} ended: {}", label, err);
        }
    });
}
/// Runs one multiplexed connection from handshake to teardown.
///
/// Steps:
/// 1. Subscribes the handshake, trace-object, EKG and data-point mini-protocols
///    on the client side (`is_initiator == true`, we dialled out) or the server
///    side (the peer dialled in), then spawns the plexer.
/// 2. Runs the version handshake. As initiator we propose `version_table_v1` and
///    expect `Accept`. As responder we pick the highest proposed version we also
///    support, or send `Refuse` with the offered versions.
/// 3. Asks the peer for its node name over the data-point channel (falling back
///    to `node_id`) and registers the node in `state`.
/// 4. Runs the trace-request loop, the EKG poll loop (only when `config.has_ekg`
///    is set) and the data-point idle loop. When the first of them finishes,
///    the rest are aborted and the node is deregistered.
///
/// The plexer is aborted on every exit path, including handshake failures, so
/// its tasks do not outlive this call.
///
/// # Errors
///
/// Returns an error if a handshake message cannot be sent or received, the peer
/// refuses or answers with an unexpected message, or no common version exists.
/// After the node has been registered the function always returns `Ok(())`.
async fn handle_connection(
    bearer: Bearer,
    node_id: NodeId,
    state: Arc<TracerState>,
    writer: Arc<LogWriter>,
    reforwarder: Option<Arc<ReForwarder>>,
    is_initiator: bool,
) -> anyhow::Result<()> {
    let config = state.config.clone();
    let mut plexer = Plexer::new(bearer);
    // All channels must be subscribed before the plexer is spawned; the side
    // (client/server) has to match our role on this connection.
    let (handshake_ch, trace_ch, ekg_ch, dp_ch) = if is_initiator {
        (
            plexer.subscribe_client(PROTOCOL_HANDSHAKE),
            plexer.subscribe_client(PROTOCOL_TRACE_OBJECT),
            plexer.subscribe_client(PROTOCOL_EKG),
            plexer.subscribe_client(PROTOCOL_DATA_POINT),
        )
    } else {
        (
            plexer.subscribe_server(PROTOCOL_HANDSHAKE),
            plexer.subscribe_server(PROTOCOL_TRACE_OBJECT),
            plexer.subscribe_server(PROTOCOL_EKG),
            plexer.subscribe_server(PROTOCOL_DATA_POINT),
        )
    };
    let plexer_handle = plexer.spawn();

    // Everything that runs while the plexer is live lives in this block, so every
    // exit (including `?` and `bail!` during the handshake) reaches the abort below
    // instead of returning past it.
    let outcome = async {
        let mut hs = ChannelBuffer::new(handshake_ch);
        let network_magic = config.network_magic as u64;
        let versions = version_table_v1(network_magic);
        if is_initiator {
            hs.send_msg_chunks(&HandshakeMessage::Propose(versions))
                .await?;
            let resp: HandshakeMessage = hs.recv_full_msg().await?;
            match resp {
                HandshakeMessage::Accept(ver, data) => {
                    info!(
                        "Handshake accepted v={} magic={} node={}",
                        ver, data.network_magic, node_id
                    );
                }
                HandshakeMessage::Refuse(_) => {
                    anyhow::bail!("Handshake refused by {}", node_id);
                }
                _ => anyhow::bail!("Unexpected handshake message from {}", node_id),
            }
        } else {
            let msg: HandshakeMessage = hs.recv_full_msg().await?;
            match msg {
                HandshakeMessage::Propose(proposed) => {
                    // Highest version both sides support.
                    let chosen = proposed
                        .keys()
                        .filter(|v| versions.contains_key(v))
                        .max()
                        .copied();
                    match chosen {
                        Some(ver) => {
                            let accept = HandshakeMessage::Accept(
                                ver,
                                ForwardingVersionData { network_magic },
                            );
                            hs.send_msg_chunks(&accept).await?;
                            debug!("Handshake accepted v={} for {}", ver, node_id);
                        }
                        None => {
                            let offered: Vec<u64> = proposed.into_keys().collect();
                            hs.send_msg_chunks(&HandshakeMessage::Refuse(offered))
                                .await?;
                            anyhow::bail!("No compatible version with {}", node_id);
                        }
                    }
                }
                other => anyhow::bail!("Expected Propose, got {:?} from {}", other, node_id),
            }
        }

        let mut dp_client = DataPointClient::new(dp_ch);
        let node_name = resolve_node_name(&mut dp_client, &node_id).await;
        let node = state.register(node_id.clone(), node_name).await;
        info!(
            "Node connected: {} name={} (slug={})",
            node_id, node.name, node.slug
        );

        let mut tasks: JoinSet<()> = JoinSet::new();
        {
            // Trace loop: request batches of traces until the channel errors.
            let node = node.clone();
            let writer = writer.clone();
            let rf = reforwarder.clone();
            let logging = config.logging.clone();
            let request_count = config.lo_request_num();
            tasks.spawn(async move {
                let mut client = TraceAcceptorClient::new(trace_ch);
                loop {
                    match client.request_traces(request_count).await {
                        Ok(traces) => {
                            debug!("Received {} traces from {}", traces.len(), node.id);
                            handle_traces(traces, &node, &writer, &logging, rf.as_deref())
                                .await;
                        }
                        Err(e) => {
                            info!("Trace loop ended for {}: {}", node.id, e);
                            return;
                        }
                    }
                }
            });
        }
        if config.has_ekg.is_some() {
            let node = node.clone();
            let config = config.clone();
            tasks.spawn(async move {
                let mut poller = EkgPoller::new(
                    ekg_ch,
                    node.clone(),
                    config.ekg_request_full.unwrap_or(false),
                );
                poller.run_poll_loop(config.ekg_request_freq()).await;
            });
        } else {
            drop(ekg_ch);
        }
        tasks.spawn(async move {
            dp_client.run_idle_loop().await;
        });

        // The session ends as soon as any protocol loop finishes; stop the others.
        tasks.join_next().await;
        tasks.abort_all();
        state.deregister(&node_id).await;
        info!("Node disconnected: {}", node_id);
        Ok::<(), anyhow::Error>(())
    }
    .await;

    // Previously the handle was just dropped. Per tokio `JoinHandle` semantics that
    // detaches rather than stops the mux tasks (assuming `RunningPlexer` has no
    // aborting `Drop` — confirm against the pallas version in use), leaving the
    // bearer open. Stop them explicitly on every path.
    plexer_handle.abort().await;
    outcome
}
/// Asks the peer for its `NodeInfo` data point and extracts the `niName` field.
///
/// Falls back to `fallback` in any of these cases: the request does not finish
/// within 5 seconds or fails; no item named `NodeInfo` is returned (only the
/// first such item is considered); that item carries no value; the value has no
/// string `niName`; or the name is empty.
///
/// NOTE(review): if the timeout fires after the request was sent, a late reply
/// may still arrive on this channel. Confirm `DataPointClient` tolerates that
/// before the idle loop reuses it.
async fn resolve_node_name(dp: &mut DataPointClient, fallback: &str) -> String {
    const NODE_INFO: &str = "NodeInfo";

    let reply = tokio::time::timeout(
        Duration::from_secs(5),
        dp.request(vec![NODE_INFO.to_string()]),
    )
    .await;

    // Both a timeout and a request error fall back to the caller-provided id.
    let items = match reply {
        Ok(Ok(items)) => items,
        _ => return fallback.to_string(),
    };

    let node_info = items
        .into_iter()
        .find(|(key, _)| key == NODE_INFO)
        .and_then(|(_, info)| info);
    let name = node_info.and_then(|info| info.get("niName")?.as_str().map(str::to_owned));

    match name {
        Some(name) if !name.is_empty() => name,
        _ => fallback.to_string(),
    }
}