1use crate::dispatcher::backend::datapoint::DataPointStore;
8use crate::mux::{
9 HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE, PROTOCOL_TRACE_OBJECT,
10 TraceForwardClient, version_table_v1,
11};
12use crate::protocol::TraceObject;
13use crate::server::datapoint::DataPointMessage;
14use chrono::{DateTime, Utc};
15use pallas_network::multiplexer::{Bearer, ChannelBuffer, Plexer};
16use std::path::PathBuf;
17use thiserror::Error;
18use tokio::sync::mpsc;
19use tracing::{debug, error, info, warn};
20
21#[derive(Debug, Error)]
23pub enum ForwarderError {
24 #[error("IO error: {0}")]
26 Io(#[from] std::io::Error),
27
28 #[error("Multiplexer error: {0}")]
30 Multiplexer(#[from] pallas_network::multiplexer::Error),
31
32 #[error("Handshake refused")]
34 HandshakeRefused,
35
36 #[error("Unexpected handshake message")]
38 UnexpectedHandshake,
39
40 #[error("Connection closed unexpectedly")]
42 ConnectionClosed,
43
44 #[error("Trace queue full, dropping traces")]
46 QueueFull,
47}
48
/// Location of the hermod-tracer acceptor the forwarder connects to.
#[derive(Debug, Clone)]
pub enum ForwarderAddress {
    /// A Unix domain socket at the given filesystem path.
    Unix(PathBuf),
    /// A TCP endpoint, given as a host and a port.
    Tcp(String, u16),
}
57
58impl Default for ForwarderAddress {
59 fn default() -> Self {
60 ForwarderAddress::Unix(PathBuf::from("/tmp/hermod-tracer.sock"))
61 }
62}
63
64impl std::fmt::Display for ForwarderAddress {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 ForwarderAddress::Unix(p) => write!(f, "{}", p.display()),
68 ForwarderAddress::Tcp(host, port) => write!(f, "{}:{}", host, port),
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct ForwarderConfig {
76 pub address: ForwarderAddress,
78
79 pub queue_size: usize,
81
82 pub max_reconnect_delay: u64,
84
85 pub network_magic: u64,
87
88 pub node_name: Option<String>,
98}
99
100impl Default for ForwarderConfig {
101 fn default() -> Self {
102 Self {
103 address: ForwarderAddress::default(),
104 queue_size: 1000,
105 max_reconnect_delay: 45,
106 network_magic: 764824073,
107 node_name: None,
108 }
109 }
110}
111
112#[derive(Clone)]
114pub struct ForwarderHandle {
115 tx: mpsc::Sender<TraceObject>,
116}
117
118impl ForwarderHandle {
119 pub async fn send(&self, trace: TraceObject) -> Result<(), ForwarderError> {
123 self.tx
124 .send(trace)
125 .await
126 .map_err(|_| ForwarderError::QueueFull)
127 }
128
129 pub fn try_send(&self, trace: TraceObject) -> Result<(), ForwarderError> {
133 self.tx
134 .try_send(trace)
135 .map_err(|_| ForwarderError::QueueFull)
136 }
137}
138
139pub struct TraceForwarder {
141 config: ForwarderConfig,
142 rx: mpsc::Receiver<TraceObject>,
143 handle: ForwarderHandle,
144 start_time: DateTime<Utc>,
146 datapoint_store: Option<DataPointStore>,
148}
149
150impl TraceForwarder {
151 pub fn new(config: ForwarderConfig) -> Self {
153 let (tx, rx) = mpsc::channel(config.queue_size);
154 let handle = ForwarderHandle { tx };
155 Self {
156 config,
157 rx,
158 handle,
159 start_time: Utc::now(),
160 datapoint_store: None,
161 }
162 }
163
164 pub fn with_datapoint_store(mut self, store: DataPointStore) -> Self {
170 self.datapoint_store = Some(store);
171 self
172 }
173
174 pub fn handle(&self) -> ForwarderHandle {
176 self.handle.clone()
177 }
178
179 pub async fn run(mut self) -> Result<(), ForwarderError> {
181 info!("Starting trace forwarder");
182
183 let mut reconnect_delay = 1;
184
185 loop {
186 match self.connect_and_run().await {
187 Ok(()) => {
188 info!("Forwarder connection closed gracefully");
189 break Ok(());
190 }
191 Err(e) => {
192 error!(
193 "Forwarder error: {}, reconnecting in {}s",
194 e, reconnect_delay
195 );
196 tokio::time::sleep(tokio::time::Duration::from_secs(reconnect_delay)).await;
197 reconnect_delay = (reconnect_delay * 2).min(self.config.max_reconnect_delay);
198 }
199 }
200 }
201 }
202
203 async fn connect_and_run(&mut self) -> Result<(), ForwarderError> {
204 debug!("Connecting to {}", self.config.address);
205 let bearer = match &self.config.address {
206 ForwarderAddress::Unix(path) => Bearer::connect_unix(path).await?,
207 ForwarderAddress::Tcp(host, port) => {
208 let addr = format!("{}:{}", host, port);
209 Bearer::connect_tcp(&addr)
210 .await
211 .map_err(|e| std::io::Error::other(e.to_string()))?
212 }
213 };
214 info!("Connected to hermod-tracer at {}", self.config.address);
215
216 let mut plexer = Plexer::new(bearer);
217
218 let handshake_channel = plexer.subscribe_client(PROTOCOL_HANDSHAKE);
219 let trace_channel = plexer.subscribe_client(PROTOCOL_TRACE_OBJECT);
220 let _ekg_channel = plexer.subscribe_client(PROTOCOL_EKG);
221 let datapoint_channel = plexer.subscribe_client(PROTOCOL_DATA_POINT);
222
223 let _plexer_handle = plexer.spawn();
224
225 let node_info_bytes: Option<Vec<u8>> = self.config.node_name.as_deref().map(|name| {
234 serde_json::json!({
235 "niName": name,
236 "niProtocol": "",
237 "niVersion": env!("CARGO_PKG_VERSION"),
238 "niCommit": "",
239 "niStartTime": self.start_time,
240 "niSystemStartTime": self.start_time,
241 })
242 .to_string()
243 .into_bytes()
244 });
245
246 let dp_store = self.datapoint_store.clone();
247 tokio::spawn(async move {
248 let mut buf = ChannelBuffer::new(datapoint_channel);
249 while let Ok(DataPointMessage::Request(names)) =
250 buf.recv_full_msg::<DataPointMessage>().await
251 {
252 let reply = names
253 .into_iter()
254 .map(|n| {
255 let val = if n == "NodeInfo" {
256 node_info_bytes.clone()
257 } else {
258 dp_store.as_ref().and_then(|s| s.get(&n))
259 };
260 (n, val)
261 })
262 .collect();
263 if buf
264 .send_msg_chunks(&DataPointMessage::Reply(reply))
265 .await
266 .is_err()
267 {
268 break;
269 }
270 }
271 });
272
273 let mut hs_buf = ChannelBuffer::new(handshake_channel);
275 let versions = version_table_v1(self.config.network_magic);
276 hs_buf
277 .send_msg_chunks(&HandshakeMessage::Propose(versions))
278 .await?;
279 let response: HandshakeMessage = hs_buf.recv_full_msg().await?;
280 match response {
281 HandshakeMessage::Accept(version, data) => {
282 info!(
283 "Handshake accepted: version={}, magic={}",
284 version, data.network_magic
285 );
286 }
287 HandshakeMessage::Refuse(_) => {
288 return Err(ForwarderError::HandshakeRefused);
289 }
290 _ => {
291 return Err(ForwarderError::UnexpectedHandshake);
292 }
293 }
294
295 let mut client = TraceForwardClient::new(trace_channel);
296
297 loop {
298 let first = match self.rx.recv().await {
300 Some(t) => t,
301 None => return Ok(()), };
303
304 let mut traces = vec![first];
306 while let Ok(t) = self.rx.try_recv() {
307 traces.push(t);
308 }
309
310 debug!("Sending {} traces to acceptor", traces.len());
311
312 match client.handle_request(traces).await {
313 Ok(()) => {}
314 Err(crate::mux::ClientError::ConnectionClosed) => {
315 info!("Acceptor sent Done, closing connection");
316 return Ok(());
317 }
318 Err(e) => {
319 warn!("Client error: {}", e);
320 return Err(ForwarderError::ConnectionClosed);
321 }
322 }
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    use crate::protocol::types::{DetailLevel, Severity, TraceObject};
    use chrono::Utc;

    /// A minimal trace used to fill the forwarder's queue.
    fn make_trace() -> TraceObject {
        TraceObject {
            to_human: None,
            to_machine: String::from("{}"),
            to_namespace: vec![String::from("Test")],
            to_severity: Severity::Info,
            to_details: DetailLevel::DNormal,
            to_timestamp: Utc::now(),
            to_hostname: String::from("host"),
            to_thread_id: String::from("1"),
        }
    }

    /// Builds a forwarder with the given queue capacity and default settings.
    fn forwarder_with_queue(queue_size: usize) -> TraceForwarder {
        TraceForwarder::new(ForwarderConfig {
            queue_size,
            ..ForwarderConfig::default()
        })
    }

    #[test]
    fn test_forwarder_config_default() {
        let ForwarderConfig {
            address,
            queue_size,
            max_reconnect_delay,
            node_name,
            ..
        } = ForwarderConfig::default();
        assert_eq!(queue_size, 1000);
        assert_eq!(max_reconnect_delay, 45);
        assert!(matches!(address, ForwarderAddress::Unix(_)));
        assert!(node_name.is_none());
    }

    #[test]
    fn test_forwarder_address_display() {
        let cases = [
            (
                ForwarderAddress::Unix(PathBuf::from("/tmp/test.sock")),
                "/tmp/test.sock",
            ),
            (
                ForwarderAddress::Tcp("127.0.0.1".to_string(), 9090),
                "127.0.0.1:9090",
            ),
        ];
        for (addr, expected) in cases {
            assert_eq!(addr.to_string(), expected);
        }
    }

    #[test]
    fn try_send_succeeds_when_queue_has_space() {
        let forwarder = forwarder_with_queue(10);
        let handle = forwarder.handle();
        assert!(handle.try_send(make_trace()).is_ok());
        // The forwarder owns the receiver; keep it alive past the send.
        drop(forwarder);
    }

    #[test]
    fn try_send_returns_queue_full_when_channel_full() {
        let forwarder = forwarder_with_queue(1);
        let handle = forwarder.handle();
        // Occupy the only slot; the outcome of this first send is not under test.
        let _ = handle.try_send(make_trace());
        let overflow = handle.try_send(make_trace());
        assert!(
            matches!(overflow, Err(ForwarderError::QueueFull)),
            "expected QueueFull, got {:?}",
            overflow
        );
        drop(forwarder);
    }

    #[test]
    fn forwarder_address_tcp_variant() {
        let addr = ForwarderAddress::Tcp(String::from("localhost"), 3001);
        assert_eq!(format!("{addr}"), "localhost:3001");
    }
}