use crate::protocol::{Message, MsgTraceObjectsReply, TraceObject};
use pallas_network::multiplexer;
use thiserror::Error;
use tracing::{debug, info};
#[derive(Error, Debug)]
pub enum ClientError {
#[error("multiplexer error: {0}")]
Multiplexer(#[from] multiplexer::Error),
#[error("invalid inbound message")]
InvalidInbound,
#[error("connection closed")]
ConnectionClosed,
}
/// Client endpoint that exchanges trace-forward [`Message`]s over a single multiplexer channel.
pub struct TraceForwardClient {
    /// Buffered view of the agent channel, used to send and receive whole messages.
    channel: multiplexer::ChannelBuffer,
}
impl TraceForwardClient {
    /// Creates a client that talks over `channel`.
    ///
    /// The channel is wrapped in a [`multiplexer::ChannelBuffer`] so that complete
    /// messages can be sent and received through it.
    pub fn new(channel: multiplexer::AgentChannel) -> Self {
        Self {
            channel: multiplexer::ChannelBuffer::new(channel),
        }
    }

    /// Sends a single message to the peer.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Multiplexer`] if the channel fails to send the message.
    pub async fn send_message(&mut self, msg: &Message) -> Result<(), ClientError> {
        debug!("Sending message");
        // `ClientError` is `From<multiplexer::Error>` (via `#[from]`), so `?` converts the error.
        self.channel.send_msg_chunks(msg).await?;
        Ok(())
    }

    /// Waits for the next complete message from the peer and returns it.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Multiplexer`] if the channel fails to deliver a message.
    pub async fn recv_message(&mut self) -> Result<Message, ClientError> {
        let msg: Message = self.channel.recv_full_msg().await?;
        debug!("Received message");
        Ok(msg)
    }

    /// Receives one inbound message and answers it.
    ///
    /// When the peer sends a [`Message::TraceObjectsRequest`], a
    /// [`Message::TraceObjectsReply`] is sent back containing at most
    /// `number_of_trace_objects` items taken from the front of `traces`.
    /// Any remaining items are dropped together with `traces`.
    ///
    /// # Errors
    ///
    /// * [`ClientError::ConnectionClosed`] if the peer sent [`Message::Done`].
    /// * [`ClientError::InvalidInbound`] if the peer sent any other message.
    /// * [`ClientError::Multiplexer`] if receiving the request or sending the reply fails.
    pub async fn handle_request(&mut self, traces: Vec<TraceObject>) -> Result<(), ClientError> {
        let inbound = self.recv_message().await?;
        match inbound {
            Message::TraceObjectsRequest(req) => {
                debug!(
                    "Received request for {} traces (blocking: {})",
                    req.number_of_trace_objects, req.blocking
                );
                // `take` expects a `usize`; never send more than the peer asked for.
                let trace_objects = traces
                    .into_iter()
                    .take(req.number_of_trace_objects as usize)
                    .collect();
                let reply = Message::TraceObjectsReply(MsgTraceObjectsReply { trace_objects });
                self.send_message(&reply).await
            }
            Message::Done => {
                info!("Received Done message");
                Err(ClientError::ConnectionClosed)
            }
            _ => Err(ClientError::InvalidInbound),
        }
    }

    /// Sends [`Message::Done`] to the peer.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Multiplexer`] if the channel fails to send the message.
    pub async fn send_done(&mut self) -> Result<(), ClientError> {
        self.send_message(&Message::Done).await
    }
}