use crate::dispatcher::backend::datapoint::DataPointStore;
use crate::mux::{
HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE, PROTOCOL_TRACE_OBJECT,
TraceForwardClient, version_table_v1,
};
use crate::protocol::TraceObject;
use crate::server::datapoint::DataPointMessage;
use chrono::{DateTime, Utc};
use pallas_network::multiplexer::{Bearer, ChannelBuffer, Plexer};
use std::path::PathBuf;
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
/// Errors reported by the trace forwarder and its handles.
#[derive(Debug, Error)]
pub enum ForwarderError {
/// An I/O failure; converted automatically from `std::io::Error`.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A failure reported by the multiplexer; converted automatically from
/// `pallas_network::multiplexer::Error`.
#[error("Multiplexer error: {0}")]
Multiplexer(#[from] pallas_network::multiplexer::Error),
/// The peer answered the version proposal with `HandshakeMessage::Refuse`.
#[error("Handshake refused")]
HandshakeRefused,
/// The peer answered the version proposal with something other than
/// `Accept` or `Refuse`.
#[error("Unexpected handshake message")]
UnexpectedHandshake,
/// The connection stopped being usable while in use; `connect_and_run`
/// reports this when the trace client fails with an error other than its
/// own `ConnectionClosed`.
#[error("Connection closed unexpectedly")]
ConnectionClosed,
/// A trace could not be placed on the forwarder's internal queue.
#[error("Trace queue full, dropping traces")]
QueueFull,
}
/// Location of the tracer the forwarder connects to.
#[derive(Debug, Clone)]
pub enum ForwarderAddress {
    /// Path handed to the Unix-socket connector.
    Unix(PathBuf),
    /// Host and port, joined as `host:port` for the TCP connector.
    Tcp(String, u16),
}

impl Default for ForwarderAddress {
    /// Falls back to the Unix socket at `/tmp/hermod-tracer.sock`.
    fn default() -> Self {
        Self::Unix("/tmp/hermod-tracer.sock".into())
    }
}

impl std::fmt::Display for ForwarderAddress {
    /// Renders a Unix address as its path and a TCP address as `host:port`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (host, port) = match self {
            Self::Unix(socket_path) => return write!(f, "{}", socket_path.display()),
            Self::Tcp(host, port) => (host, port),
        };
        write!(f, "{host}:{port}")
    }
}
/// Settings for a `TraceForwarder`.
#[derive(Debug, Clone)]
pub struct ForwarderConfig {
/// Address the forwarder connects to.
pub address: ForwarderAddress,
/// Capacity of the bounded trace queue created in `TraceForwarder::new`.
pub queue_size: usize,
/// Upper bound, in seconds, for the reconnect back-off applied in
/// `TraceForwarder::run`.
pub max_reconnect_delay: u64,
/// Network magic passed to `version_table_v1` when proposing handshake
/// versions.
pub network_magic: u64,
/// Name reported as `niName` in the `NodeInfo` data point; when `None`, the
/// `NodeInfo` reply carries no value.
pub node_name: Option<String>,
}
impl Default for ForwarderConfig {
fn default() -> Self {
Self {
address: ForwarderAddress::default(),
queue_size: 1000,
max_reconnect_delay: 45,
network_magic: 764824073,
node_name: None,
}
}
}
/// Cloneable handle used to enqueue traces for a `TraceForwarder`.
#[derive(Clone)]
pub struct ForwarderHandle {
/// Sending half of the forwarder's bounded trace queue.
tx: mpsc::Sender<TraceObject>,
}
impl ForwarderHandle {
pub async fn send(&self, trace: TraceObject) -> Result<(), ForwarderError> {
self.tx
.send(trace)
.await
.map_err(|_| ForwarderError::QueueFull)
}
pub fn try_send(&self, trace: TraceObject) -> Result<(), ForwarderError> {
self.tx
.try_send(trace)
.map_err(|_| ForwarderError::QueueFull)
}
}
/// Drains queued traces and forwards them to the tracer over a multiplexed
/// connection, reconnecting when a session fails.
pub struct TraceForwarder {
/// Connection and queue settings.
config: ForwarderConfig,
/// Receiving half of the trace queue fed by `ForwarderHandle`s.
rx: mpsc::Receiver<TraceObject>,
/// Handle kept so that `handle()` can give out clones.
handle: ForwarderHandle,
/// Time captured in `new`; reported as `niStartTime` and
/// `niSystemStartTime` in the `NodeInfo` data point.
start_time: DateTime<Utc>,
/// Optional store consulted for data point requests other than `NodeInfo`.
datapoint_store: Option<DataPointStore>,
}
impl TraceForwarder {
    /// Delay, in seconds, used for the first reconnect attempt after a failure
    /// and restored every time a session completes its handshake.
    const INITIAL_RECONNECT_DELAY_SECS: u64 = 1;

    /// Creates a forwarder together with the bounded queue its handles feed.
    ///
    /// The queue holds `config.queue_size` traces. A size of `0` is raised to
    /// `1`, because `tokio::sync::mpsc::channel` panics on a zero capacity.
    pub fn new(config: ForwarderConfig) -> Self {
        let (tx, rx) = mpsc::channel(config.queue_size.max(1));
        Self {
            config,
            rx,
            handle: ForwarderHandle { tx },
            start_time: Utc::now(),
            datapoint_store: None,
        }
    }

    /// Attaches a store used to answer data point requests other than
    /// `NodeInfo`.
    pub fn with_datapoint_store(mut self, store: DataPointStore) -> Self {
        self.datapoint_store = Some(store);
        self
    }

    /// Returns a new handle that enqueues traces for this forwarder.
    pub fn handle(&self) -> ForwarderHandle {
        self.handle.clone()
    }

    /// Runs the forwarder until a session ends gracefully.
    ///
    /// Every failed session is logged and followed by a sleep before the next
    /// attempt. The sleep starts at one second, doubles after each failure up
    /// to `max_reconnect_delay` seconds, and goes back to one second once a
    /// session gets past the handshake, so an outage early in the process
    /// lifetime does not slow down reconnects much later.
    ///
    /// # Errors
    ///
    /// The loop only exits through the graceful path, so this currently never
    /// returns `Err`; session errors lead to a reconnect instead.
    pub async fn run(mut self) -> Result<(), ForwarderError> {
        info!("Starting trace forwarder");
        let mut reconnect_delay = Self::INITIAL_RECONNECT_DELAY_SECS;
        loop {
            match self.connect_and_run(&mut reconnect_delay).await {
                Ok(()) => {
                    info!("Forwarder connection closed gracefully");
                    break Ok(());
                }
                Err(e) => {
                    error!(
                        "Forwarder error: {}, reconnecting in {}s",
                        e, reconnect_delay
                    );
                    tokio::time::sleep(tokio::time::Duration::from_secs(reconnect_delay)).await;
                    reconnect_delay = reconnect_delay
                        .saturating_mul(2)
                        .min(self.config.max_reconnect_delay);
                }
            }
        }
    }

    /// Runs one session: connect, handshake, then forward queued traces.
    ///
    /// `reconnect_delay` is the caller's back-off state; it is reset to the
    /// initial delay as soon as the handshake is accepted.
    ///
    /// Returns `Ok(())` when the trace queue is closed or when the trace client
    /// reports `ClientError::ConnectionClosed`; any other failure is returned
    /// as an error.
    async fn connect_and_run(&mut self, reconnect_delay: &mut u64) -> Result<(), ForwarderError> {
        debug!("Connecting to {}", self.config.address);
        let bearer = match &self.config.address {
            ForwarderAddress::Unix(path) => Bearer::connect_unix(path).await?,
            ForwarderAddress::Tcp(host, port) => {
                let addr = format!("{}:{}", host, port);
                Bearer::connect_tcp(&addr)
                    .await
                    .map_err(|e| std::io::Error::other(e.to_string()))?
            }
        };
        info!("Connected to hermod-tracer at {}", self.config.address);

        // All channels are subscribed before the plexer is spawned.
        let mut plexer = Plexer::new(bearer);
        let handshake_channel = plexer.subscribe_client(PROTOCOL_HANDSHAKE);
        let trace_channel = plexer.subscribe_client(PROTOCOL_TRACE_OBJECT);
        // The EKG channel is subscribed but nothing in this function services it.
        let _ekg_channel = plexer.subscribe_client(PROTOCOL_EKG);
        let datapoint_channel = plexer.subscribe_client(PROTOCOL_DATA_POINT);
        // NOTE(review): the handle is only kept alive for the duration of this
        // call; confirm whether dropping it stops the plexer tasks.
        let _plexer_handle = plexer.spawn();

        // JSON payload served for the "NodeInfo" data point; absent when no
        // node name is configured.
        let node_info_bytes: Option<Vec<u8>> = self.config.node_name.as_deref().map(|name| {
            serde_json::json!({
                "niName": name,
                "niProtocol": "",
                "niVersion": env!("CARGO_PKG_VERSION"),
                "niCommit": "",
                "niStartTime": self.start_time,
                "niSystemStartTime": self.start_time,
            })
            .to_string()
            .into_bytes()
        });

        // Background task answering data point requests. It stops when a
        // receive fails, when a message other than `Request` arrives, or when
        // a reply cannot be sent.
        let dp_store = self.datapoint_store.clone();
        tokio::spawn(async move {
            let mut buf = ChannelBuffer::new(datapoint_channel);
            while let Ok(DataPointMessage::Request(names)) =
                buf.recv_full_msg::<DataPointMessage>().await
            {
                let reply = names
                    .into_iter()
                    .map(|n| {
                        let val = if n == "NodeInfo" {
                            node_info_bytes.clone()
                        } else {
                            dp_store.as_ref().and_then(|s| s.get(&n))
                        };
                        (n, val)
                    })
                    .collect();
                if buf
                    .send_msg_chunks(&DataPointMessage::Reply(reply))
                    .await
                    .is_err()
                {
                    break;
                }
            }
        });

        // Propose our versions and require an explicit acceptance.
        let mut hs_buf = ChannelBuffer::new(handshake_channel);
        let versions = version_table_v1(self.config.network_magic);
        hs_buf
            .send_msg_chunks(&HandshakeMessage::Propose(versions))
            .await?;
        let response: HandshakeMessage = hs_buf.recv_full_msg().await?;
        match response {
            HandshakeMessage::Accept(version, data) => {
                info!(
                    "Handshake accepted: version={}, magic={}",
                    version, data.network_magic
                );
            }
            HandshakeMessage::Refuse(_) => {
                return Err(ForwarderError::HandshakeRefused);
            }
            _ => {
                return Err(ForwarderError::UnexpectedHandshake);
            }
        }

        // The session is established, so the next failure starts a fresh
        // back-off sequence.
        *reconnect_delay = Self::INITIAL_RECONNECT_DELAY_SECS;

        let mut client = TraceForwardClient::new(trace_channel);
        loop {
            // Wait for at least one trace, then take whatever else is already
            // queued so it goes out in the same batch.
            // NOTE(review): `self.handle` keeps a sender alive, so the `None`
            // case looks unreachable while the forwarder exists — confirm.
            let Some(first) = self.rx.recv().await else {
                return Ok(());
            };
            let mut traces = vec![first];
            while let Ok(t) = self.rx.try_recv() {
                traces.push(t);
            }
            debug!("Sending {} traces to acceptor", traces.len());
            // The batch is moved into the client; when the call fails this
            // function does not re-queue it.
            match client.handle_request(traces).await {
                Ok(()) => {}
                Err(crate::mux::ClientError::ConnectionClosed) => {
                    info!("Acceptor sent Done, closing connection");
                    return Ok(());
                }
                Err(e) => {
                    warn!("Client error: {}", e);
                    return Err(ForwarderError::ConnectionClosed);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::{DetailLevel, Severity, TraceObject};
    use chrono::Utc;

    /// Builds a minimal trace object suitable for queueing.
    fn make_trace() -> TraceObject {
        TraceObject {
            to_human: None,
            to_machine: String::from("{}"),
            to_namespace: vec![String::from("Test")],
            to_severity: Severity::Info,
            to_details: DetailLevel::DNormal,
            to_timestamp: Utc::now(),
            to_hostname: String::from("host"),
            to_thread_id: String::from("1"),
        }
    }

    /// Builds a forwarder whose queue has the given capacity and otherwise
    /// default settings.
    fn forwarder_with_queue(queue_size: usize) -> TraceForwarder {
        let config = ForwarderConfig {
            queue_size,
            ..Default::default()
        };
        TraceForwarder::new(config)
    }

    #[test]
    fn test_forwarder_config_default() {
        let ForwarderConfig {
            address,
            queue_size,
            max_reconnect_delay,
            node_name,
            ..
        } = ForwarderConfig::default();
        assert_eq!(queue_size, 1000);
        assert_eq!(max_reconnect_delay, 45);
        assert!(matches!(address, ForwarderAddress::Unix(_)));
        assert!(node_name.is_none());
    }

    #[test]
    fn test_forwarder_address_display() {
        let unix = ForwarderAddress::Unix(PathBuf::from("/tmp/test.sock"));
        let tcp = ForwarderAddress::Tcp("127.0.0.1".to_string(), 9090);
        assert_eq!(unix.to_string(), "/tmp/test.sock");
        assert_eq!(tcp.to_string(), "127.0.0.1:9090");
    }

    #[test]
    fn try_send_succeeds_when_queue_has_space() {
        // The forwarder owns the receiver, so it must outlive the send.
        let forwarder = forwarder_with_queue(10);
        let outcome = forwarder.handle().try_send(make_trace());
        assert!(outcome.is_ok());
        drop(forwarder);
    }

    #[test]
    fn try_send_returns_queue_full_when_channel_full() {
        let forwarder = forwarder_with_queue(1);
        let handle = forwarder.handle();
        // The first trace occupies the only slot; its outcome is irrelevant here.
        let _ = handle.try_send(make_trace());
        let result = handle.try_send(make_trace());
        assert!(
            matches!(result, Err(ForwarderError::QueueFull)),
            "expected QueueFull, got {:?}",
            result
        );
        drop(forwarder);
    }

    #[test]
    fn forwarder_address_tcp_variant() {
        let rendered = ForwarderAddress::Tcp("localhost".to_string(), 3001).to_string();
        assert_eq!(rendered, "localhost:3001");
    }
}