Skip to main content

hermod/
forwarder.rs

1//! Trace forwarder client implementation
2//!
3//! This module implements the forwarder side of the trace-forward protocol,
4//! which connects to a hermod-tracer acceptor and sends trace objects via the
5//! Ouroboros Network multiplexer.
6
7use crate::dispatcher::backend::datapoint::DataPointStore;
8use crate::mux::{
9    HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE, PROTOCOL_TRACE_OBJECT,
10    TraceForwardClient, version_table_v1,
11};
12use crate::protocol::TraceObject;
13use crate::server::datapoint::DataPointMessage;
14use chrono::{DateTime, Utc};
15use pallas_network::multiplexer::{Bearer, ChannelBuffer, Plexer};
16use std::path::PathBuf;
17use thiserror::Error;
18use tokio::sync::mpsc;
19use tracing::{debug, error, info, warn};
20
21/// Errors that can occur in the forwarder
22#[derive(Debug, Error)]
23pub enum ForwarderError {
24    /// IO error
25    #[error("IO error: {0}")]
26    Io(#[from] std::io::Error),
27
28    /// Multiplexer error
29    #[error("Multiplexer error: {0}")]
30    Multiplexer(#[from] pallas_network::multiplexer::Error),
31
32    /// Handshake was refused by the acceptor
33    #[error("Handshake refused")]
34    HandshakeRefused,
35
36    /// Unexpected message during handshake
37    #[error("Unexpected handshake message")]
38    UnexpectedHandshake,
39
40    /// Connection closed unexpectedly
41    #[error("Connection closed unexpectedly")]
42    ConnectionClosed,
43
44    /// Queue full (traces being dropped)
45    #[error("Trace queue full, dropping traces")]
46    QueueFull,
47}
48
/// Endpoint of the hermod-tracer acceptor that the forwarder dials.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare addresses
/// (e.g. to detect a configuration change) or use them as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ForwarderAddress {
    /// Filesystem path of a Unix domain socket.
    Unix(PathBuf),
    /// TCP host name (or IP literal) and port.
    Tcp(String, u16),
}
57
58impl Default for ForwarderAddress {
59    fn default() -> Self {
60        ForwarderAddress::Unix(PathBuf::from("/tmp/hermod-tracer.sock"))
61    }
62}
63
64impl std::fmt::Display for ForwarderAddress {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        match self {
67            ForwarderAddress::Unix(p) => write!(f, "{}", p.display()),
68            ForwarderAddress::Tcp(host, port) => write!(f, "{}:{}", host, port),
69        }
70    }
71}
72
73/// Configuration for the trace forwarder
74#[derive(Debug, Clone)]
75pub struct ForwarderConfig {
76    /// Address to connect to (Unix socket path or TCP host:port)
77    pub address: ForwarderAddress,
78
79    /// Maximum number of traces to buffer before dropping
80    pub queue_size: usize,
81
82    /// Maximum reconnection delay in seconds
83    pub max_reconnect_delay: u64,
84
85    /// Cardano network magic (must match the acceptor's)
86    pub network_magic: u64,
87
88    /// Node display name advertised via the `NodeInfo` DataPoint.
89    ///
90    /// When `Some`, the forwarder responds to `"NodeInfo"` DataPoint requests
91    /// with `{\"niName\": name, ...}`, which `hermod-tracer` (and Haskell
92    /// `cardano-tracer`) use as the node's display name, Prometheus slug, and
93    /// log subdirectory name.
94    ///
95    /// When `None`, the acceptor falls back to the connection-address node ID
96    /// (e.g. `unix-1` for the first inbound Unix socket connection).
97    pub node_name: Option<String>,
98}
99
100impl Default for ForwarderConfig {
101    fn default() -> Self {
102        Self {
103            address: ForwarderAddress::default(),
104            queue_size: 1000,
105            max_reconnect_delay: 45,
106            network_magic: 764824073,
107            node_name: None,
108        }
109    }
110}
111
112/// Handle for sending traces to the forwarder
113#[derive(Clone)]
114pub struct ForwarderHandle {
115    tx: mpsc::Sender<TraceObject>,
116}
117
118impl ForwarderHandle {
119    /// Send a trace object
120    ///
121    /// Returns `Err(ForwarderError::QueueFull)` if the queue is full
122    pub async fn send(&self, trace: TraceObject) -> Result<(), ForwarderError> {
123        self.tx
124            .send(trace)
125            .await
126            .map_err(|_| ForwarderError::QueueFull)
127    }
128
129    /// Try to send a trace object without waiting
130    ///
131    /// Returns `Err(ForwarderError::QueueFull)` if the queue is full
132    pub fn try_send(&self, trace: TraceObject) -> Result<(), ForwarderError> {
133        self.tx
134            .try_send(trace)
135            .map_err(|_| ForwarderError::QueueFull)
136    }
137}
138
139/// Trace forwarder that connects to hermod-tracer
140pub struct TraceForwarder {
141    config: ForwarderConfig,
142    rx: mpsc::Receiver<TraceObject>,
143    handle: ForwarderHandle,
144    /// When this forwarder process started (used in `NodeInfo` DataPoint replies)
145    start_time: DateTime<Utc>,
146    /// Optional shared data-point store (serves named data points on request)
147    datapoint_store: Option<DataPointStore>,
148}
149
150impl TraceForwarder {
151    /// Create a new trace forwarder
152    pub fn new(config: ForwarderConfig) -> Self {
153        let (tx, rx) = mpsc::channel(config.queue_size);
154        let handle = ForwarderHandle { tx };
155        Self {
156            config,
157            rx,
158            handle,
159            start_time: Utc::now(),
160            datapoint_store: None,
161        }
162    }
163
164    /// Attach a [`DataPointStore`] so the forwarder can serve named data points
165    /// to the acceptor on request.
166    ///
167    /// The same store should be passed to [`DatapointBackend::with_store`] so
168    /// that dispatched trace objects are automatically stored and served.
169    pub fn with_datapoint_store(mut self, store: DataPointStore) -> Self {
170        self.datapoint_store = Some(store);
171        self
172    }
173
174    /// Get a handle for sending traces
175    pub fn handle(&self) -> ForwarderHandle {
176        self.handle.clone()
177    }
178
179    /// Run the forwarder (connects and handles protocol, reconnecting on error)
180    pub async fn run(mut self) -> Result<(), ForwarderError> {
181        info!("Starting trace forwarder");
182
183        let mut reconnect_delay = 1;
184
185        loop {
186            match self.connect_and_run().await {
187                Ok(()) => {
188                    info!("Forwarder connection closed gracefully");
189                    break Ok(());
190                }
191                Err(e) => {
192                    error!(
193                        "Forwarder error: {}, reconnecting in {}s",
194                        e, reconnect_delay
195                    );
196                    tokio::time::sleep(tokio::time::Duration::from_secs(reconnect_delay)).await;
197                    reconnect_delay = (reconnect_delay * 2).min(self.config.max_reconnect_delay);
198                }
199            }
200        }
201    }
202
203    async fn connect_and_run(&mut self) -> Result<(), ForwarderError> {
204        debug!("Connecting to {}", self.config.address);
205        let bearer = match &self.config.address {
206            ForwarderAddress::Unix(path) => Bearer::connect_unix(path).await?,
207            ForwarderAddress::Tcp(host, port) => {
208                let addr = format!("{}:{}", host, port);
209                Bearer::connect_tcp(&addr)
210                    .await
211                    .map_err(|e| std::io::Error::other(e.to_string()))?
212            }
213        };
214        info!("Connected to hermod-tracer at {}", self.config.address);
215
216        let mut plexer = Plexer::new(bearer);
217
218        let handshake_channel = plexer.subscribe_client(PROTOCOL_HANDSHAKE);
219        let trace_channel = plexer.subscribe_client(PROTOCOL_TRACE_OBJECT);
220        let _ekg_channel = plexer.subscribe_client(PROTOCOL_EKG);
221        let datapoint_channel = plexer.subscribe_client(PROTOCOL_DATA_POINT);
222
223        let _plexer_handle = plexer.spawn();
224
225        // Respond to DataPoint requests.
226        // The acceptor requests "NodeInfo" immediately after the handshake to
227        // resolve our display name.  We serialise a NodeInfo-compatible JSON
228        // object so the acceptor can extract `niName` and use it as the node's
229        // display name, Prometheus slug, and log subdirectory name.
230        //
231        // Any other named data point is looked up in the optional DataPointStore
232        // (set via `with_datapoint_store`).
233        let node_info_bytes: Option<Vec<u8>> = self.config.node_name.as_deref().map(|name| {
234            serde_json::json!({
235                "niName":            name,
236                "niProtocol":        "",
237                "niVersion":         env!("CARGO_PKG_VERSION"),
238                "niCommit":          "",
239                "niStartTime":       self.start_time,
240                "niSystemStartTime": self.start_time,
241            })
242            .to_string()
243            .into_bytes()
244        });
245
246        let dp_store = self.datapoint_store.clone();
247        tokio::spawn(async move {
248            let mut buf = ChannelBuffer::new(datapoint_channel);
249            while let Ok(DataPointMessage::Request(names)) =
250                buf.recv_full_msg::<DataPointMessage>().await
251            {
252                let reply = names
253                    .into_iter()
254                    .map(|n| {
255                        let val = if n == "NodeInfo" {
256                            node_info_bytes.clone()
257                        } else {
258                            dp_store.as_ref().and_then(|s| s.get(&n))
259                        };
260                        (n, val)
261                    })
262                    .collect();
263                if buf
264                    .send_msg_chunks(&DataPointMessage::Reply(reply))
265                    .await
266                    .is_err()
267                {
268                    break;
269                }
270            }
271        });
272
273        // Perform handshake
274        let mut hs_buf = ChannelBuffer::new(handshake_channel);
275        let versions = version_table_v1(self.config.network_magic);
276        hs_buf
277            .send_msg_chunks(&HandshakeMessage::Propose(versions))
278            .await?;
279        let response: HandshakeMessage = hs_buf.recv_full_msg().await?;
280        match response {
281            HandshakeMessage::Accept(version, data) => {
282                info!(
283                    "Handshake accepted: version={}, magic={}",
284                    version, data.network_magic
285                );
286            }
287            HandshakeMessage::Refuse(_) => {
288                return Err(ForwarderError::HandshakeRefused);
289            }
290            _ => {
291                return Err(ForwarderError::UnexpectedHandshake);
292            }
293        }
294
295        let mut client = TraceForwardClient::new(trace_channel);
296
297        loop {
298            // Wait for at least one trace (blocking)
299            let first = match self.rx.recv().await {
300                Some(t) => t,
301                None => return Ok(()), // channel closed, shut down
302            };
303
304            // Drain any additional pending traces
305            let mut traces = vec![first];
306            while let Ok(t) = self.rx.try_recv() {
307                traces.push(t);
308            }
309
310            debug!("Sending {} traces to acceptor", traces.len());
311
312            match client.handle_request(traces).await {
313                Ok(()) => {}
314                Err(crate::mux::ClientError::ConnectionClosed) => {
315                    info!("Acceptor sent Done, closing connection");
316                    return Ok(());
317                }
318                Err(e) => {
319                    warn!("Client error: {}", e);
320                    return Err(ForwarderError::ConnectionClosed);
321                }
322            }
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    use crate::protocol::types::{DetailLevel, Severity, TraceObject};
    use chrono::Utc;

    /// Build a minimal trace object for queue tests.
    fn make_trace() -> TraceObject {
        TraceObject {
            to_human: None,
            to_machine: String::from("{}"),
            to_namespace: vec![String::from("Test")],
            to_severity: Severity::Info,
            to_details: DetailLevel::DNormal,
            to_timestamp: Utc::now(),
            to_hostname: String::from("host"),
            to_thread_id: String::from("1"),
        }
    }

    /// Build a forwarder with the given queue capacity and otherwise default settings.
    fn forwarder_with_queue(queue_size: usize) -> TraceForwarder {
        TraceForwarder::new(ForwarderConfig {
            queue_size,
            ..ForwarderConfig::default()
        })
    }

    #[test]
    fn test_forwarder_config_default() {
        let ForwarderConfig {
            address,
            queue_size,
            max_reconnect_delay,
            node_name,
            ..
        } = ForwarderConfig::default();
        assert_eq!(queue_size, 1000);
        assert_eq!(max_reconnect_delay, 45);
        assert!(matches!(address, ForwarderAddress::Unix(_)));
        assert!(node_name.is_none());
    }

    #[test]
    fn test_forwarder_address_display() {
        let cases = [
            (
                ForwarderAddress::Unix(PathBuf::from("/tmp/test.sock")),
                "/tmp/test.sock",
            ),
            (
                ForwarderAddress::Tcp("127.0.0.1".to_string(), 9090),
                "127.0.0.1:9090",
            ),
        ];
        for (address, expected) in cases {
            assert_eq!(address.to_string(), expected);
        }
    }

    #[test]
    fn try_send_succeeds_when_queue_has_space() {
        let forwarder = forwarder_with_queue(10);
        let handle = forwarder.handle();
        let outcome = handle.try_send(make_trace());
        assert!(outcome.is_ok());
        // The forwarder owns the receiver, so it must outlive the send above.
        drop(forwarder);
    }

    #[test]
    fn try_send_returns_queue_full_when_channel_full() {
        let forwarder = forwarder_with_queue(1);
        let handle = forwarder.handle();
        // Occupy the only slot; the outcome of this send is not under test.
        let _ = handle.try_send(make_trace());
        // With the slot taken, the next non-blocking send must be rejected.
        let overflow = handle.try_send(make_trace());
        assert!(
            matches!(overflow, Err(ForwarderError::QueueFull)),
            "expected QueueFull, got {:?}",
            overflow
        );
        drop(forwarder);
    }

    #[test]
    fn forwarder_address_tcp_variant() {
        let address = ForwarderAddress::Tcp(String::from("localhost"), 3001);
        assert_eq!(address.to_string(), "localhost:3001");
    }
}