1use crate::forwarder::ForwarderHandle;
15use crate::mux::{
16 ForwardingVersionData, HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE,
17 PROTOCOL_TRACE_OBJECT, TraceForwardClient, version_table_v1,
18};
19use crate::protocol::TraceObject;
20use crate::server::config::Address;
21use pallas_network::multiplexer::{Bearer, Plexer};
22use std::sync::Arc;
23use tokio::net::{TcpListener, UnixListener};
24use tokio::sync::broadcast;
25use tokio::task::JoinSet;
26use tracing::{debug, info, warn};
27
28pub struct ReForwarder {
34 inner: ReForwarderInner,
35 namespace_filters: Option<Vec<Vec<String>>>,
37}
38
39enum ReForwarderInner {
40 Outbound(ForwarderHandle),
42 Inbound(broadcast::Sender<Arc<Vec<TraceObject>>>),
44}
45
46impl ReForwarder {
47 pub fn new(handle: ForwarderHandle, namespace_filters: Option<Vec<Vec<String>>>) -> Self {
49 ReForwarder {
50 inner: ReForwarderInner::Outbound(handle),
51 namespace_filters,
52 }
53 }
54
55 pub fn new_inbound(
60 tx: broadcast::Sender<Arc<Vec<TraceObject>>>,
61 namespace_filters: Option<Vec<Vec<String>>>,
62 ) -> Self {
63 ReForwarder {
64 inner: ReForwarderInner::Inbound(tx),
65 namespace_filters,
66 }
67 }
68
69 pub async fn forward(&self, traces: &[TraceObject]) {
71 let filtered: Vec<TraceObject> = traces
72 .iter()
73 .filter(|t| self.matches_filter(t))
74 .cloned()
75 .collect();
76 if filtered.is_empty() {
77 return;
78 }
79 match &self.inner {
80 ReForwarderInner::Outbound(handle) => {
81 for trace in filtered {
82 if let Err(e) = handle.send(trace).await {
83 warn!("ReForwarder send error: {}", e);
84 }
85 }
86 }
87 ReForwarderInner::Inbound(tx) => {
88 let _ = tx.send(Arc::new(filtered));
90 }
91 }
92 }
93
94 fn matches_filter(&self, trace: &TraceObject) -> bool {
95 let Some(filters) = &self.namespace_filters else {
96 return true; };
98 filters
99 .iter()
100 .any(|prefix| trace.to_namespace.starts_with(prefix))
101 }
102}
103
104pub async fn run_accepting_loop(
116 addrs: &[Address],
117 tx: broadcast::Sender<Arc<Vec<TraceObject>>>,
118 network_magic: u64,
119) {
120 let mut set = JoinSet::new();
121 for addr in addrs {
122 let addr = addr.clone();
123 let tx = tx.clone();
124 set.spawn(async move {
125 listen_and_accept(addr, tx, network_magic).await;
126 });
127 }
128 while set.join_next().await.is_some() {}
129}
130
131async fn listen_and_accept(
132 addr: Address,
133 tx: broadcast::Sender<Arc<Vec<TraceObject>>>,
134 network_magic: u64,
135) {
136 match &addr {
137 Address::LocalPipe(path) => {
138 let _ = std::fs::remove_file(path);
139 let listener = match UnixListener::bind(path) {
140 Ok(l) => l,
141 Err(e) => {
142 warn!(
143 "AcceptingReForwarder: failed to bind {}: {}",
144 path.display(),
145 e
146 );
147 return;
148 }
149 };
150 info!("AcceptingReForwarder: listening on {}", path.display());
151 loop {
152 match Bearer::accept_unix(&listener).await {
153 Ok((bearer, _)) => {
154 let rx = tx.subscribe();
155 tokio::spawn(handle_accepting_connection(bearer, rx, network_magic));
156 }
157 Err(e) => {
158 warn!("AcceptingReForwarder accept error: {}", e);
159 break;
160 }
161 }
162 }
163 }
164 Address::RemoteSocket(host, port) => {
165 let bind_addr = format!("{}:{}", host, port);
166 let listener = match TcpListener::bind(&bind_addr).await {
167 Ok(l) => l,
168 Err(e) => {
169 warn!(
170 "AcceptingReForwarder: failed to bind TCP {}: {}",
171 bind_addr, e
172 );
173 return;
174 }
175 };
176 info!("AcceptingReForwarder: listening on TCP {}", bind_addr);
177 loop {
178 match Bearer::accept_tcp(&listener).await {
179 Ok((bearer, _)) => {
180 let rx = tx.subscribe();
181 tokio::spawn(handle_accepting_connection(bearer, rx, network_magic));
182 }
183 Err(e) => {
184 warn!("AcceptingReForwarder TCP accept error: {}", e);
185 break;
186 }
187 }
188 }
189 }
190 }
191}
192
193async fn handle_accepting_connection(
199 bearer: Bearer,
200 mut rx: broadcast::Receiver<Arc<Vec<TraceObject>>>,
201 network_magic: u64,
202) {
203 let mut plexer = Plexer::new(bearer);
204
205 let hs_ch = plexer.subscribe_server(PROTOCOL_HANDSHAKE);
208 let trace_ch = plexer.subscribe_server(PROTOCOL_TRACE_OBJECT);
209 drop(plexer.subscribe_server(PROTOCOL_EKG));
211 drop(plexer.subscribe_server(PROTOCOL_DATA_POINT));
212 let _plexer_handle = plexer.spawn();
213
214 use pallas_network::multiplexer::ChannelBuffer;
216 let mut hs = ChannelBuffer::new(hs_ch);
217 let versions = version_table_v1(network_magic);
218 let msg: HandshakeMessage = match hs.recv_full_msg().await {
219 Ok(m) => m,
220 Err(e) => {
221 warn!("AcceptingReForwarder: handshake recv failed: {}", e);
222 return;
223 }
224 };
225 match msg {
226 HandshakeMessage::Propose(proposed) => {
227 let chosen = proposed
228 .keys()
229 .filter(|v| versions.contains_key(v))
230 .max()
231 .copied();
232 match chosen {
233 Some(ver) => {
234 let accept =
235 HandshakeMessage::Accept(ver, ForwardingVersionData { network_magic });
236 if let Err(e) = hs.send_msg_chunks(&accept).await {
237 warn!("AcceptingReForwarder: handshake accept send failed: {}", e);
238 return;
239 }
240 debug!("AcceptingReForwarder: handshake accepted v={}", ver);
241 }
242 None => {
243 let offered: Vec<u64> = proposed.into_keys().collect();
244 let _ = hs.send_msg_chunks(&HandshakeMessage::Refuse(offered)).await;
245 warn!("AcceptingReForwarder: no compatible version");
246 return;
247 }
248 }
249 }
250 other => {
251 warn!("AcceptingReForwarder: expected Propose, got {:?}", other);
252 return;
253 }
254 }
255
256 let mut client = TraceForwardClient::new(trace_ch);
258 loop {
259 let batch: Arc<Vec<TraceObject>> = loop {
261 match rx.recv().await {
262 Ok(b) => break b,
263 Err(broadcast::error::RecvError::Closed) => {
264 info!("AcceptingReForwarder: broadcast channel closed");
265 return;
266 }
267 Err(broadcast::error::RecvError::Lagged(n)) => {
268 warn!("AcceptingReForwarder: lagged by {} batches, skipping", n);
269 continue;
270 }
271 }
272 };
273
274 let mut traces: Vec<TraceObject> = (*batch).clone();
276 while let Ok(extra) = rx.try_recv() {
277 traces.extend_from_slice(&extra);
278 }
279
280 match client.handle_request(traces).await {
282 Ok(()) => {}
283 Err(crate::mux::ClientError::ConnectionClosed) => {
284 info!("AcceptingReForwarder: downstream sent Done");
285 return;
286 }
287 Err(e) => {
288 warn!("AcceptingReForwarder: trace error: {}", e);
289 return;
290 }
291 }
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::{DetailLevel, Severity, TraceObject};
    use chrono::Utc;

    /// Builds a minimal `Info`-level trace with the given namespace.
    fn make_trace(namespace: Vec<&str>) -> TraceObject {
        TraceObject {
            to_human: None,
            to_machine: "{}".to_string(),
            to_namespace: namespace.into_iter().map(str::to_string).collect(),
            to_severity: Severity::Info,
            to_details: DetailLevel::DNormal,
            to_timestamp: Utc::now(),
            to_hostname: "host".to_string(),
            to_thread_id: "1".to_string(),
        }
    }

    #[tokio::test]
    async fn no_filter_forwards_all_traces() {
        let (tx, mut rx) = broadcast::channel(16);
        let rf = ReForwarder::new_inbound(tx, None);
        let traces = vec![make_trace(vec!["A", "B"]), make_trace(vec!["C"])];
        rf.forward(&traces).await;
        let received = rx.recv().await.unwrap();
        assert_eq!(received.len(), 2);
    }

    #[tokio::test]
    async fn prefix_filter_blocks_non_matching_namespace() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = Some(vec![vec!["Cardano".to_string(), "Node".to_string()]]);
        let rf = ReForwarder::new_inbound(tx, filters);
        let traces = vec![
            make_trace(vec!["Cardano", "Node", "Peers"]),
            make_trace(vec!["Other", "Trace"]),
        ];
        rf.forward(&traces).await;
        let received = rx.recv().await.unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].to_namespace, vec!["Cardano", "Node", "Peers"]);
    }

    #[tokio::test]
    async fn prefix_filter_exact_match_passes() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = Some(vec![vec!["Cardano".to_string(), "Node".to_string()]]);
        let rf = ReForwarder::new_inbound(tx, filters);
        let traces = vec![make_trace(vec!["Cardano", "Node"])];
        rf.forward(&traces).await;
        let received = rx.recv().await.unwrap();
        assert_eq!(received.len(), 1);
    }

    #[tokio::test]
    async fn prefix_longer_than_namespace_does_not_match() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = Some(vec![vec!["Cardano".to_string(), "Node".to_string()]]);
        let rf = ReForwarder::new_inbound(tx, filters);
        rf.forward(&[make_trace(vec!["Cardano"])]).await;
        assert!(rx.try_recv().is_err(), "shorter namespace must not match");
    }

    #[tokio::test]
    async fn prefix_compares_whole_segments() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = Some(vec![vec!["Card".to_string()]]);
        let rf = ReForwarder::new_inbound(tx, filters);
        rf.forward(&[make_trace(vec!["Cardano", "Node"])]).await;
        assert!(rx.try_recv().is_err(), "partial segment must not match");
    }

    #[tokio::test]
    async fn empty_prefix_matches_every_namespace() {
        let (tx, mut rx) = broadcast::channel(16);
        let rf = ReForwarder::new_inbound(tx, Some(vec![Vec::new()]));
        let traces = vec![make_trace(vec!["A"]), make_trace(vec![])];
        rf.forward(&traces).await;
        assert_eq!(rx.recv().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn filter_all_out_sends_nothing() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = Some(vec![vec!["Cardano".to_string()]]);
        let rf = ReForwarder::new_inbound(tx, filters);
        let traces = vec![make_trace(vec!["Other"])];
        rf.forward(&traces).await;
        assert!(rx.try_recv().is_err(), "nothing should be broadcast");
    }

    #[tokio::test]
    async fn multiple_prefixes_any_match_passes() {
        let (tx, mut rx) = broadcast::channel(16);
        let filters = Some(vec![vec!["Cardano".to_string()], vec!["Node".to_string()]]);
        let rf = ReForwarder::new_inbound(tx, filters);
        let traces = vec![
            make_trace(vec!["Cardano", "X"]),
            make_trace(vec!["Node", "Y"]),
            make_trace(vec!["Other"]),
        ];
        rf.forward(&traces).await;
        let received = rx.recv().await.unwrap();
        assert_eq!(received.len(), 2);
    }

    #[tokio::test]
    async fn empty_input_sends_nothing() {
        let (tx, mut rx) = broadcast::channel(16);
        let rf = ReForwarder::new_inbound(tx, None);
        rf.forward(&[]).await;
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn inbound_with_no_receivers_does_not_panic() {
        let (tx, rx) = broadcast::channel::<Arc<Vec<TraceObject>>>(16);
        drop(rx);
        let rf = ReForwarder::new_inbound(tx, None);
        // With no subscribers the broadcast send fails; `forward` must ignore it.
        rf.forward(&[make_trace(vec!["A"])]).await;
    }

    #[tokio::test]
    async fn inbound_broadcasts_to_multiple_receivers() {
        let (tx, mut rx1) = broadcast::channel(16);
        let mut rx2 = tx.subscribe();
        let rf = ReForwarder::new_inbound(tx, None);
        rf.forward(&[make_trace(vec!["A"])]).await;
        assert_eq!(rx1.recv().await.unwrap().len(), 1);
        assert_eq!(rx2.recv().await.unwrap().len(), 1);
    }
}