Skip to main content

hermod/
acceptor.rs

1//! Trace acceptor — listener side of the trace-forward protocol
2//!
3//! Listens on a Unix socket, performs the Ouroboros Network mux handshake,
4//! and drives the request-reply loop, yielding received `TraceObject`s via
5//! an async channel.
6
7use crate::mux::{
8    ForwardingVersionData, HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE,
9    PROTOCOL_TRACE_OBJECT, TraceAcceptorClient, version_table_v1,
10};
11use crate::protocol::TraceObject;
12use pallas_network::multiplexer::{Bearer, ChannelBuffer, Plexer};
13use std::path::PathBuf;
14use tokio::net::UnixListener;
15use tokio::sync::mpsc;
16use tracing::{debug, error, info, warn};
17
/// Settings that control how a [`TraceAcceptor`] listens and pulls traces.
#[derive(Debug, Clone)]
pub struct AcceptorConfig {
    /// Filesystem path of the Unix socket the acceptor binds to.
    pub socket_path: PathBuf,

    /// Network magic used to build the handshake version table and sent back
    /// in the `Accept` reply; expected to match the forwarder's.
    pub network_magic: u64,

    /// How many traces to ask the forwarder for in each request.
    pub request_count: u16,

    /// Capacity of the bounded channel buffering traces for the [`AcceptorHandle`].
    pub channel_capacity: usize,
}

impl Default for AcceptorConfig {
    /// Socket `/tmp/hermod-tracer.sock`, network magic 764824073, 100 traces
    /// per request and a 1000-entry channel.
    fn default() -> Self {
        const SOCKET_PATH: &str = "/tmp/hermod-tracer.sock";
        const NETWORK_MAGIC: u64 = 764_824_073;
        const TRACES_PER_REQUEST: u16 = 100;
        const CHANNEL_SLOTS: usize = 1_000;

        Self {
            socket_path: PathBuf::from(SOCKET_PATH),
            network_magic: NETWORK_MAGIC,
            request_count: TRACES_PER_REQUEST,
            channel_capacity: CHANNEL_SLOTS,
        }
    }
}
44
45/// Handle for receiving traces from the acceptor
46pub struct AcceptorHandle {
47    rx: mpsc::Receiver<TraceObject>,
48}
49
50impl AcceptorHandle {
51    /// Receive the next trace object, or `None` if the acceptor has shut down
52    pub async fn recv(&mut self) -> Option<TraceObject> {
53        self.rx.recv().await
54    }
55}
56
57/// Trace acceptor that listens for forwarder connections
58pub struct TraceAcceptor {
59    config: AcceptorConfig,
60    tx: mpsc::Sender<TraceObject>,
61}
62
63impl TraceAcceptor {
64    /// Create a new acceptor; returns the acceptor and a handle for consuming traces
65    pub fn new(config: AcceptorConfig) -> (Self, AcceptorHandle) {
66        let (tx, rx) = mpsc::channel(config.channel_capacity);
67        let acceptor = Self { config, tx };
68        let handle = AcceptorHandle { rx };
69        (acceptor, handle)
70    }
71
72    /// Run the acceptor (binds the socket and loops accepting connections)
73    pub async fn run(self) -> anyhow::Result<()> {
74        let path = &self.config.socket_path;
75
76        // Clean up any stale socket from a previous run
77        let _ = std::fs::remove_file(path);
78
79        let listener = UnixListener::bind(path)?;
80        info!("Acceptor listening on {}", path.display());
81
82        loop {
83            let (bearer, _addr) = Bearer::accept_unix(&listener).await?;
84            let tx = self.tx.clone();
85            let config = self.config.clone();
86            tokio::spawn(async move {
87                if let Err(e) = Self::handle_connection(bearer, tx, config).await {
88                    warn!("Connection handler error: {}", e);
89                }
90            });
91        }
92    }
93
94    async fn handle_connection(
95        bearer: Bearer,
96        tx: mpsc::Sender<TraceObject>,
97        config: AcceptorConfig,
98    ) -> anyhow::Result<()> {
99        let mut plexer = Plexer::new(bearer);
100
101        // Acceptor mirrors the forwarder's subscriptions
102        let handshake_channel = plexer.subscribe_server(PROTOCOL_HANDSHAKE);
103        let trace_channel = plexer.subscribe_server(PROTOCOL_TRACE_OBJECT);
104        let _ekg_channel = plexer.subscribe_server(PROTOCOL_EKG);
105        let _datapoint_channel = plexer.subscribe_server(PROTOCOL_DATA_POINT);
106
107        let _plexer_handle = plexer.spawn();
108
109        // Handshake
110        let mut hs_buf = ChannelBuffer::new(handshake_channel);
111        let msg: HandshakeMessage = hs_buf.recv_full_msg().await?;
112
113        match msg {
114            HandshakeMessage::Propose(versions) => {
115                // Pick the highest version we both support
116                let our_versions = version_table_v1(config.network_magic);
117                let chosen = versions
118                    .keys()
119                    .filter(|v| our_versions.contains_key(v))
120                    .max()
121                    .copied();
122
123                match chosen {
124                    Some(version) => {
125                        let accept = HandshakeMessage::Accept(
126                            version,
127                            ForwardingVersionData {
128                                network_magic: config.network_magic,
129                            },
130                        );
131                        hs_buf.send_msg_chunks(&accept).await?;
132                        debug!("Handshake accepted version {}", version);
133                    }
134                    None => {
135                        let offered: Vec<u64> = versions.into_keys().collect();
136                        let refuse = HandshakeMessage::Refuse(offered);
137                        hs_buf.send_msg_chunks(&refuse).await?;
138                        error!("Handshake refused: no compatible version");
139                        return Ok(());
140                    }
141                }
142            }
143            other => {
144                error!("Expected Propose, got {:?}", other);
145                return Ok(());
146            }
147        }
148
149        // Trace request loop
150        let mut client = TraceAcceptorClient::new(trace_channel);
151        loop {
152            match client.request_traces(config.request_count).await {
153                Ok(traces) => {
154                    debug!("Received {} traces", traces.len());
155                    for trace in traces {
156                        if tx.send(trace).await.is_err() {
157                            // Receiver dropped — shut down gracefully
158                            return Ok(());
159                        }
160                    }
161                }
162                Err(e) => {
163                    info!("Trace request loop ended: {}", e);
164                    return Ok(());
165                }
166            }
167        }
168    }
169}