use crate::mux::{
ForwardingVersionData, HandshakeMessage, PROTOCOL_DATA_POINT, PROTOCOL_EKG, PROTOCOL_HANDSHAKE,
PROTOCOL_TRACE_OBJECT, TraceAcceptorClient, version_table_v1,
};
use crate::protocol::TraceObject;
use pallas_network::multiplexer::{Bearer, ChannelBuffer, Plexer};
use std::path::PathBuf;
use tokio::net::UnixListener;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
/// Settings for a [`TraceAcceptor`].
#[derive(Debug, Clone)]
pub struct AcceptorConfig {
    /// Filesystem path of the Unix socket the acceptor binds and listens on.
    pub socket_path: PathBuf,
    /// Network magic used to build the local handshake version table and
    /// echoed back in the handshake `Accept` message.
    pub network_magic: u64,
    /// Count handed to every `request_traces` call on a connection.
    pub request_count: u16,
    /// Capacity of the bounded mpsc channel between the acceptor and its handle.
    pub channel_capacity: usize,
}

impl Default for AcceptorConfig {
    /// Builds the stock configuration: socket at `/tmp/hermod-tracer.sock`,
    /// 100 traces per request and a 1000-slot channel.
    fn default() -> Self {
        // NOTE(review): 764824073 is presumably the Cardano mainnet magic — confirm.
        const DEFAULT_NETWORK_MAGIC: u64 = 764_824_073;
        const DEFAULT_REQUEST_COUNT: u16 = 100;
        const DEFAULT_CHANNEL_CAPACITY: usize = 1_000;

        let socket_path = PathBuf::from("/tmp/hermod-tracer.sock");
        Self {
            socket_path,
            network_magic: DEFAULT_NETWORK_MAGIC,
            request_count: DEFAULT_REQUEST_COUNT,
            channel_capacity: DEFAULT_CHANNEL_CAPACITY,
        }
    }
}
/// Consumer side of a [`TraceAcceptor`], returned alongside it by
/// `TraceAcceptor::new`. Trace objects forwarded by the acceptor's connection
/// handlers are read from here.
pub struct AcceptorHandle {
// Receiving end of the bounded mpsc channel created in `TraceAcceptor::new`.
rx: mpsc::Receiver<TraceObject>,
}
impl AcceptorHandle {
    /// Waits for the next trace object forwarded by the acceptor.
    ///
    /// This is a direct pass-through to the underlying tokio
    /// `mpsc::Receiver::recv`, so it yields `None` under the same conditions
    /// that call does (channel closed and no buffered items left).
    pub async fn recv(&mut self) -> Option<TraceObject> {
        let Self { rx } = self;
        rx.recv().await
    }
}
/// Unix-socket server that accepts connections, performs the handshake on
/// each one, and forwards the trace objects it receives into an mpsc channel.
pub struct TraceAcceptor {
// Listener and protocol settings; cloned into every connection task.
config: AcceptorConfig,
// Sending end of the trace channel; cloned into every connection task.
tx: mpsc::Sender<TraceObject>,
}
impl TraceAcceptor {
    /// Creates an acceptor together with the handle that receives its traces.
    ///
    /// The two halves are joined by a bounded tokio mpsc channel sized by
    /// `config.channel_capacity`. A capacity of `0` is raised to `1`, because
    /// `tokio::sync::mpsc::channel` panics when asked for a zero-sized buffer.
    pub fn new(config: AcceptorConfig) -> (Self, AcceptorHandle) {
        let capacity = config.channel_capacity.max(1);
        let (tx, rx) = mpsc::channel(capacity);
        (Self { config, tx }, AcceptorHandle { rx })
    }

    /// Binds the configured Unix socket and serves connections until an error occurs.
    ///
    /// Any file already present at the socket path is removed first. Each
    /// accepted connection is handled on its own spawned task; a failure
    /// inside a handler is logged and does not stop the accept loop.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be bound or if accepting a
    /// connection fails. The loop has no other exit.
    pub async fn run(self) -> anyhow::Result<()> {
        let path = &self.config.socket_path;

        // Clearing a stale socket is best-effort: a missing file is the normal
        // case, and any other failure is reported here and will then surface
        // as a bind error just below.
        match std::fs::remove_file(path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => warn!(
                "Could not remove existing socket file {}: {}",
                path.display(),
                e
            ),
        }

        let listener = UnixListener::bind(path)?;
        info!("Acceptor listening on {}", path.display());

        loop {
            let (bearer, _addr) = Bearer::accept_unix(&listener).await?;
            let tx = self.tx.clone();
            let config = self.config.clone();
            tokio::spawn(async move {
                if let Err(e) = Self::handle_connection(bearer, tx, config).await {
                    warn!("Connection handler error: {}", e);
                }
            });
        }
    }

    /// Serves a single connection: handshake first, then an endless
    /// request/forward loop for trace objects.
    ///
    /// Returns `Ok(())` when the peer sends something other than `Propose`,
    /// when no proposed version is supported (after sending `Refuse`), when a
    /// trace request fails, or when the receiving handle has gone away.
    ///
    /// # Errors
    ///
    /// Returns an error only if receiving or sending a handshake message fails.
    async fn handle_connection(
        bearer: Bearer,
        tx: mpsc::Sender<TraceObject>,
        config: AcceptorConfig,
    ) -> anyhow::Result<()> {
        let mut plexer = Plexer::new(bearer);
        let handshake_channel = plexer.subscribe_server(PROTOCOL_HANDSHAKE);
        let trace_channel = plexer.subscribe_server(PROTOCOL_TRACE_OBJECT);
        // The EKG and data-point protocols are subscribed but never read here.
        // NOTE(review): presumably so the multiplexer has a registered channel
        // for those protocol ids — confirm.
        let _ekg_channel = plexer.subscribe_server(PROTOCOL_EKG);
        let _datapoint_channel = plexer.subscribe_server(PROTOCOL_DATA_POINT);
        // Bound to a name (not `_`) so the handle lives until this function returns.
        let _plexer_handle = plexer.spawn();

        // --- Handshake -----------------------------------------------------
        let mut hs_buf = ChannelBuffer::new(handshake_channel);
        let msg: HandshakeMessage = hs_buf.recv_full_msg().await?;
        let versions = match msg {
            HandshakeMessage::Propose(versions) => versions,
            other => {
                error!("Expected Propose, got {:?}", other);
                return Ok(());
            }
        };

        // Pick the highest proposed version that is also in our own table.
        let our_versions = version_table_v1(config.network_magic);
        let chosen = versions
            .keys()
            .filter(|v| our_versions.contains_key(v))
            .max()
            .copied();
        let Some(version) = chosen else {
            let offered: Vec<u64> = versions.into_keys().collect();
            let refuse = HandshakeMessage::Refuse(offered);
            hs_buf.send_msg_chunks(&refuse).await?;
            error!("Handshake refused: no compatible version");
            return Ok(());
        };

        let accept = HandshakeMessage::Accept(
            version,
            ForwardingVersionData {
                network_magic: config.network_magic,
            },
        );
        hs_buf.send_msg_chunks(&accept).await?;
        debug!("Handshake accepted version {}", version);

        // --- Trace forwarding ----------------------------------------------
        let mut client = TraceAcceptorClient::new(trace_channel);
        loop {
            let traces = match client.request_traces(config.request_count).await {
                Ok(traces) => traces,
                Err(e) => {
                    info!("Trace request loop ended: {}", e);
                    return Ok(());
                }
            };
            debug!("Received {} traces", traces.len());
            for trace in traces {
                // A send error means the receiving half of the channel is
                // gone, so nothing can consume further traces.
                if tx.send(trace).await.is_err() {
                    debug!("Trace receiver dropped; closing connection handler");
                    return Ok(());
                }
            }
        }
    }
}