use std::collections::HashMap;
use byteorder::{ByteOrder, NetworkEndian};
use pallas_codec::{minicbor, Fragment};
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio::{select, sync::mpsc::error::SendError};
use tracing::{debug, error, trace, warn};
type IOResult<T> = tokio::io::Result<T>;
use tokio::net as tcp;
#[cfg(unix)]
use tokio::net as unix;
#[cfg(windows)]
use tokio::net::windows::named_pipe::NamedPipeClient;
#[cfg(windows)]
use tokio::io::{ReadHalf, WriteHalf};
/// Size in bytes of an encoded segment [`Header`].
const HEADER_LEN: usize = 8;
/// Timestamp value carried in a segment header (32 bits wide).
pub type Timestamp = u32;
/// Raw bytes carried by a segment or exchanged with an agent.
pub type Payload = Vec<u8>;
/// Protocol identifier carried in a segment header (16 bits wide).
pub type Protocol = u16;
/// Fixed-size header that precedes each segment payload on the bearer.
///
/// It is converted from and to its 8-byte wire form by the `From` impls
/// defined for this type.
#[derive(Debug)]
pub struct Header {
/// Identifier of the protocol the payload belongs to.
pub protocol: Protocol,
/// Timestamp value set by the sender of the segment.
pub timestamp: Timestamp,
/// Number of payload bytes that follow the header.
pub payload_len: u16,
}
impl From<&[u8]> for Header {
    /// Decodes a header from the first eight bytes of `value`.
    ///
    /// Wire layout, every field in network byte order:
    /// bytes `0..4` timestamp, bytes `4..6` protocol, bytes `6..8` payload length.
    ///
    /// # Panics
    ///
    /// Panics if `value` holds fewer than eight bytes.
    fn from(value: &[u8]) -> Self {
        Self {
            timestamp: NetworkEndian::read_u32(&value[0..4]),
            protocol: NetworkEndian::read_u16(&value[4..6]),
            payload_len: NetworkEndian::read_u16(&value[6..8]),
        }
    }
}
impl From<Header> for [u8; 8] {
fn from(value: Header) -> Self {
let mut out = [0u8; 8];
NetworkEndian::write_u32(&mut out[0..4], value.timestamp);
NetworkEndian::write_u16(&mut out[4..6], value.protocol);
NetworkEndian::write_u16(&mut out[6..8], value.payload_len);
out
}
}
/// A segment: a [`Header`] together with the payload bytes that accompany it.
pub struct Segment {
/// Header describing the payload.
pub header: Header,
/// Raw payload bytes.
pub payload: Payload,
}
/// Connection over which segments are exchanged, one variant per supported
/// transport.
pub enum Bearer {
/// TCP stream.
Tcp(tcp::TcpStream),
/// Unix socket stream (only compiled on unix targets).
#[cfg(unix)]
Unix(unix::UnixStream),
/// Client end of a named pipe (only compiled on windows targets).
#[cfg(windows)]
NamedPipe(NamedPipeClient),
}
impl Bearer {
fn configure_tcp(stream: &tcp::TcpStream) -> IOResult<()> {
let sock_ref = socket2::SockRef::from(&stream);
let mut tcp_keepalive = socket2::TcpKeepalive::new();
tcp_keepalive = tcp_keepalive.with_time(tokio::time::Duration::from_secs(20));
tcp_keepalive = tcp_keepalive.with_interval(tokio::time::Duration::from_secs(20));
sock_ref.set_tcp_keepalive(&tcp_keepalive)?;
sock_ref.set_nodelay(true)?;
sock_ref.set_linger(Some(std::time::Duration::from_secs(0)))?;
Ok(())
}
pub async fn connect_tcp(addr: impl tcp::ToSocketAddrs) -> Result<Self, tokio::io::Error> {
let stream = tcp::TcpStream::connect(addr).await?;
Self::configure_tcp(&stream)?;
Ok(Self::Tcp(stream))
}
pub async fn connect_tcp_timeout(
addr: impl tcp::ToSocketAddrs,
timeout: std::time::Duration,
) -> IOResult<Self> {
select! {
result = Self::connect_tcp(addr) => result,
_ = tokio::time::sleep(timeout) => Err(tokio::io::Error::new(tokio::io::ErrorKind::TimedOut, "connect timeout")),
}
}
pub async fn accept_tcp(listener: &tcp::TcpListener) -> IOResult<(Self, std::net::SocketAddr)> {
let (stream, addr) = listener.accept().await?;
Self::configure_tcp(&stream)?;
Ok((Self::Tcp(stream), addr))
}
#[cfg(unix)]
pub async fn connect_unix(path: impl AsRef<std::path::Path>) -> IOResult<Self> {
let stream = unix::UnixStream::connect(path).await?;
Ok(Self::Unix(stream))
}
#[cfg(unix)]
pub async fn accept_unix(
listener: &unix::UnixListener,
) -> IOResult<(Self, unix::unix::SocketAddr)> {
let (stream, addr) = listener.accept().await?;
Ok((Self::Unix(stream), addr))
}
#[cfg(windows)]
pub fn connect_named_pipe(pipe_name: impl AsRef<std::ffi::OsStr>) -> IOResult<Self> {
let client = tokio::net::windows::named_pipe::ClientOptions::new().open(&pipe_name)?;
Ok(Self::NamedPipe(client))
}
pub fn into_split(self) -> (BearerReadHalf, BearerWriteHalf) {
match self {
Bearer::Tcp(x) => {
let (r, w) = x.into_split();
(BearerReadHalf::Tcp(r), BearerWriteHalf::Tcp(w))
}
#[cfg(unix)]
Bearer::Unix(x) => {
let (r, w) = x.into_split();
(BearerReadHalf::Unix(r), BearerWriteHalf::Unix(w))
}
#[cfg(windows)]
Bearer::NamedPipe(x) => {
let (read, write) = tokio::io::split(x);
let reader = BearerReadHalf::NamedPipe(read);
let writer = BearerWriteHalf::NamedPipe(write);
(reader, writer)
}
}
}
}
/// Read half of a [`Bearer`], as produced by `Bearer::into_split`.
pub enum BearerReadHalf {
    /// Read half of a TCP stream.
    Tcp(tcp::tcp::OwnedReadHalf),
    /// Read half of a Unix socket stream.
    #[cfg(unix)]
    Unix(unix::unix::OwnedReadHalf),
    /// Read half of a named pipe client.
    #[cfg(windows)]
    NamedPipe(ReadHalf<NamedPipeClient>),
}

impl BearerReadHalf {
    /// Fills `buf` completely from the underlying transport by delegating to
    /// that transport's own `read_exact`.
    async fn read_exact(&mut self, buf: &mut [u8]) -> IOResult<usize> {
        match self {
            Self::Tcp(half) => half.read_exact(buf).await,
            #[cfg(unix)]
            Self::Unix(half) => half.read_exact(buf).await,
            #[cfg(windows)]
            Self::NamedPipe(half) => half.read_exact(buf).await,
        }
    }
}
/// Write half of a [`Bearer`], as produced by `Bearer::into_split`.
pub enum BearerWriteHalf {
    /// Write half of a TCP stream.
    Tcp(tcp::tcp::OwnedWriteHalf),
    /// Write half of a Unix socket stream.
    #[cfg(unix)]
    Unix(unix::unix::OwnedWriteHalf),
    /// Write half of a named pipe client.
    #[cfg(windows)]
    NamedPipe(WriteHalf<NamedPipeClient>),
}

impl BearerWriteHalf {
    /// Writes all of `buf` to the underlying transport by delegating to that
    /// transport's own `write_all`.
    async fn write_all(&mut self, buf: &[u8]) -> IOResult<()> {
        match self {
            BearerWriteHalf::Tcp(half) => half.write_all(buf).await,
            #[cfg(unix)]
            BearerWriteHalf::Unix(half) => half.write_all(buf).await,
            #[cfg(windows)]
            BearerWriteHalf::NamedPipe(half) => half.write_all(buf).await,
        }
    }

    /// Flushes the underlying transport by delegating to its own `flush`.
    async fn flush(&mut self) -> IOResult<()> {
        match self {
            BearerWriteHalf::Tcp(half) => half.flush().await,
            #[cfg(unix)]
            BearerWriteHalf::Unix(half) => half.flush().await,
            #[cfg(windows)]
            BearerWriteHalf::NamedPipe(half) => half.flush().await,
        }
    }
}
#[derive(Debug, Error)]
pub enum Error {
#[error("no data available in bearer to complete segment")]
EmptyBearer,
#[error("bearer I/O error")]
BearerIo(tokio::io::Error),
#[error("failure to encode channel message")]
Decoding(String),
#[error("failure to decode channel message")]
Encoding(String),
#[error("agent failed to enqueue chunk for protocol {0}")]
AgentEnqueue(Protocol, Payload),
#[error("agent failed to dequeue chunk")]
AgentDequeue,
#[error("plexer failed to dumux chunk for protocol {0}")]
PlexerDemux(Protocol, Payload),
#[error("plexer failed to mux chunk")]
PlexerMux,
#[error("failure to abort the plexer threads")]
AbortFailure,
}
/// Sender used to hand demuxed payloads to the subscriber of one protocol.
type EgressChannel = tokio::sync::mpsc::Sender<Payload>;
/// Subscriber senders keyed by the protocol id they receive payloads for.
type Egress = HashMap<Protocol, EgressChannel>;
/// Capacity of each subscriber queue created by `Demuxer::subscribe`.
const EGRESS_MSG_QUEUE_BUFFER: usize = 100;
/// Reads segments from the read half of a bearer (field 0) and forwards each
/// payload to the subscriber registered for its protocol (field 1).
pub struct Demuxer(BearerReadHalf, Egress);
impl Demuxer {
pub fn new(bearer: BearerReadHalf) -> Self {
let egress = HashMap::new();
Self(bearer, egress)
}
pub async fn read_segment(&mut self) -> Result<(Protocol, Payload), Error> {
trace!("waiting for segment header");
let mut buf = vec![0u8; HEADER_LEN];
self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?;
let header = Header::from(buf.as_slice());
trace!("waiting for full segment");
let segment_size = header.payload_len as usize;
let mut buf = vec![0u8; segment_size];
self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?;
Ok((header.protocol, buf))
}
async fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> {
let channel = self.1.get(&protocol);
if let Some(sender) = channel {
sender
.send(payload)
.await
.map_err(|err| Error::PlexerDemux(protocol, err.0))?;
} else {
warn!(protocol, "message for unregistered protocol");
}
Ok(())
}
pub fn subscribe(&mut self, protocol: Protocol) -> tokio::sync::mpsc::Receiver<Payload> {
let (sender, recv) = tokio::sync::mpsc::channel(EGRESS_MSG_QUEUE_BUFFER);
self.1.insert(protocol, sender);
recv
}
pub async fn tick(&mut self) -> Result<(), Error> {
let (protocol, payload) = self.read_segment().await?;
trace!(protocol, "demux happening");
self.demux(protocol, payload).await
}
pub async fn run(&mut self) -> Result<(), Error> {
loop {
if let Err(err) = self.tick().await {
break Err(err);
}
}
}
}
/// Both ends of the queue through which `(protocol, payload)` pairs reach the
/// muxer: the sender is cloned out by `Muxer::clone_sender`, the receiver is
/// drained by `Muxer::tick`.
type Ingress = (
tokio::sync::mpsc::Sender<(Protocol, Payload)>,
tokio::sync::mpsc::Receiver<(Protocol, Payload)>,
);
/// Reference instant from which segment timestamps are measured.
type Clock = Instant;
/// Capacity of the ingress queue created by `Muxer::new`.
const INGRESS_MSG_QUEUE_BUFFER: usize = 100;
/// Writes queued payloads to the write half of a bearer as segments.
///
/// Fields, in order: the bearer write half, the clock captured at
/// construction, and the ingress queue.
pub struct Muxer(BearerWriteHalf, Clock, Ingress);
impl Muxer {
pub fn new(bearer: BearerWriteHalf) -> Self {
let ingress = tokio::sync::mpsc::channel(INGRESS_MSG_QUEUE_BUFFER);
let clock = Instant::now();
Self(bearer, clock, ingress)
}
async fn write_segment(&mut self, protocol: u16, payload: &[u8]) -> Result<(), std::io::Error> {
let header = Header {
protocol,
timestamp: self.1.elapsed().as_micros() as u32,
payload_len: payload.len() as u16,
};
let buf: [u8; 8] = header.into();
self.0.write_all(&buf).await?;
self.0.write_all(payload).await?;
self.0.flush().await?;
Ok(())
}
pub async fn mux(&mut self, msg: (Protocol, Payload)) -> Result<(), Error> {
self.write_segment(msg.0, &msg.1)
.await
.map_err(|_| Error::PlexerMux)?;
if tracing::event_enabled!(tracing::Level::TRACE) {
trace!(
protocol = msg.0,
data = hex::encode(&msg.1),
"write to bearer"
);
}
Ok(())
}
pub fn clone_sender(&self) -> tokio::sync::mpsc::Sender<(Protocol, Payload)> {
self.2 .0.clone()
}
pub async fn tick(&mut self) -> Result<(), Error> {
let msg = self.2 .1.recv().await;
if let Some(x) = msg {
trace!(protocol = x.0, "mux happening");
self.mux(x).await?
}
Ok(())
}
pub async fn run(&mut self) -> Result<(), Error> {
loop {
if let Err(err) = self.tick().await {
break Err(err);
}
}
}
}
/// Sender through which an agent hands `(protocol, payload)` pairs to the muxer.
type ToPlexerPort = tokio::sync::mpsc::Sender<(Protocol, Payload)>;
/// Receiver on which an agent gets the payloads routed to it by the demuxer.
type FromPlexerPort = tokio::sync::mpsc::Receiver<Payload>;
/// An agent's endpoint on the plexer for a single protocol.
pub struct AgentChannel {
/// Protocol id attached to every chunk this channel enqueues.
protocol: Protocol,
/// Queue towards the muxer.
to_plexer: ToPlexerPort,
/// Queue of payloads coming from the demuxer.
from_plexer: FromPlexerPort,
}
impl AgentChannel {
fn for_client(
protocol: Protocol,
to_plexer: ToPlexerPort,
from_plexer: FromPlexerPort,
) -> Self {
Self {
protocol,
from_plexer,
to_plexer,
}
}
fn for_server(
protocol: Protocol,
to_plexer: ToPlexerPort,
from_plexer: FromPlexerPort,
) -> Self {
Self {
protocol,
from_plexer,
to_plexer,
}
}
pub async fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), Error> {
self.to_plexer
.send((self.protocol, chunk))
.await
.map_err(|SendError((protocol, payload))| Error::AgentEnqueue(protocol, payload))
}
pub async fn dequeue_chunk(&mut self) -> Result<Payload, Error> {
self.from_plexer.recv().await.ok_or(Error::AgentDequeue)
}
}
/// Handles of the two tasks started by `Plexer::spawn`.
pub struct RunningPlexer {
    /// Task running the demuxer loop.
    demuxer: JoinHandle<Result<(), Error>>,
    /// Task running the muxer loop.
    muxer: JoinHandle<Result<(), Error>>,
}

impl RunningPlexer {
    /// Requests abortion of both tasks, demuxer first.
    ///
    /// The handles are dropped without being awaited.
    pub async fn abort(self) {
        let Self { demuxer, muxer } = self;

        demuxer.abort();
        muxer.abort();
    }
}
/// A demuxer and a muxer sharing the two halves of one bearer.
pub struct Plexer {
    /// Reads incoming segments and routes them to subscribers.
    demuxer: Demuxer,
    /// Writes outgoing segments.
    muxer: Muxer,
}

impl Plexer {
    /// Splits `bearer` and builds the demuxer and muxer over its two halves.
    pub fn new(bearer: Bearer) -> Self {
        let (read_half, write_half) = bearer.into_split();

        Self {
            demuxer: Demuxer::new(read_half),
            muxer: Muxer::new(write_half),
        }
    }

    /// Creates the channel for a client-side agent of `protocol`.
    ///
    /// Outgoing chunks are tagged with `protocol` as given; incoming payloads
    /// are those whose protocol id is `protocol` with bit `0x8000` flipped.
    pub fn subscribe_client(&mut self, protocol: Protocol) -> AgentChannel {
        let outbound = self.muxer.clone_sender();
        let inbound = self.demuxer.subscribe(protocol ^ 0x8000);

        AgentChannel::for_client(protocol, outbound, inbound)
    }

    /// Creates the channel for a server-side agent of `protocol`.
    ///
    /// Incoming payloads are those tagged with `protocol` as given; outgoing
    /// chunks are tagged with `protocol` with bit `0x8000` flipped.
    pub fn subscribe_server(&mut self, protocol: Protocol) -> AgentChannel {
        let outbound = self.muxer.clone_sender();
        let inbound = self.demuxer.subscribe(protocol);

        AgentChannel::for_server(protocol ^ 0x8000, outbound, inbound)
    }

    /// Starts the demuxer and muxer loops as two separate tokio tasks and
    /// returns their handles.
    pub fn spawn(self) -> RunningPlexer {
        let Self {
            mut demuxer,
            mut muxer,
        } = self;

        RunningPlexer {
            demuxer: tokio::spawn(async move { demuxer.run().await }),
            muxer: tokio::spawn(async move { muxer.run().await }),
        }
    }
}
/// Largest number of payload bytes placed in a single chunk by
/// `ChannelBuffer::send_msg_chunks`; equal to `u16::MAX`, the largest value
/// the `payload_len` field of [`Header`] can hold.
pub const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535;
fn try_decode_message<M>(buffer: &mut Vec<u8>) -> Result<Option<M>, Error>
where
M: Fragment,
{
let mut decoder = minicbor::Decoder::new(buffer);
let maybe_msg = decoder.decode();
match maybe_msg {
Ok(msg) => {
let pos = decoder.position();
buffer.drain(0..pos);
Ok(Some(msg))
}
Err(err) if err.is_end_of_input() => Ok(None),
Err(err) => {
error!(?err);
trace!("{}", hex::encode(buffer));
Err(Error::Decoding(err.to_string()))
}
}
}
/// Wraps an [`AgentChannel`] so that whole messages, rather than raw chunks,
/// can be sent and received.
pub struct ChannelBuffer {
/// Channel used to exchange chunks with the plexer.
channel: AgentChannel,
/// Received bytes that have not yet been consumed by a decoded message.
temp: Vec<u8>,
}
impl ChannelBuffer {
pub fn new(channel: AgentChannel) -> Self {
Self {
channel,
temp: Vec::new(),
}
}
pub async fn send_msg_chunks<M>(&mut self, msg: &M) -> Result<(), Error>
where
M: Fragment,
{
let mut payload = Vec::new();
minicbor::encode(msg, &mut payload).map_err(|err| Error::Encoding(err.to_string()))?;
let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH);
for chunk in chunks {
self.channel.enqueue_chunk(Vec::from(chunk)).await?;
}
Ok(())
}
pub async fn recv_full_msg<M>(&mut self) -> Result<M, Error>
where
M: Fragment,
{
trace!(len = self.temp.len(), "waiting for full message");
if !self.temp.is_empty() {
trace!("buffer has data from previous payload");
if let Some(msg) = try_decode_message::<M>(&mut self.temp)? {
debug!("decoding done");
return Ok(msg);
}
}
loop {
let chunk = self.channel.dequeue_chunk().await?;
self.temp.extend(chunk);
if let Some(msg) = try_decode_message::<M>(&mut self.temp)? {
debug!("decoding done");
return Ok(msg);
}
trace!("not enough data");
}
}
pub fn unwrap(self) -> AgentChannel {
self.channel
}
}
impl From<AgentChannel> for ChannelBuffer {
fn from(channel: AgentChannel) -> Self {
ChannelBuffer::new(channel)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pallas_codec::minicbor;

    /// Two messages delivered in a single chunk must come out one per call.
    #[tokio::test]
    async fn multiple_messages_in_same_payload() {
        let first = (1u8, 2u8, 3u8);
        let second = (6u8, 5u8, 4u8);

        // Encode both messages back to back into the same byte buffer.
        let mut wire = Vec::new();
        minicbor::encode(first, &mut wire).unwrap();
        minicbor::encode(second, &mut wire).unwrap();

        let (to_plexer, _) = tokio::sync::mpsc::channel(100);
        let (feeder, from_plexer) = tokio::sync::mpsc::channel(100);
        let channel = AgentChannel::for_client(0, to_plexer, from_plexer);

        feeder.send(wire).await.unwrap();

        let mut buffer = ChannelBuffer::new(channel);
        let decoded_first = buffer.recv_full_msg::<(u8, u8, u8)>().await.unwrap();
        let decoded_second = buffer.recv_full_msg::<(u8, u8, u8)>().await.unwrap();

        assert_eq!(first, decoded_first);
        assert_eq!(second, decoded_second);
    }

    /// A message split across several two-byte chunks must be reassembled.
    #[tokio::test]
    async fn fragmented_message_in_multiple_payloads() {
        let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);

        let mut wire = Vec::new();
        minicbor::encode(msg, &mut wire).unwrap();

        let (to_plexer, _) = tokio::sync::mpsc::channel(100);
        let (feeder, from_plexer) = tokio::sync::mpsc::channel(100);
        let channel = AgentChannel::for_client(0, to_plexer, from_plexer);

        // Feed the encoded bytes two at a time.
        while !wire.is_empty() {
            let piece: Vec<u8> = wire.drain(0..2).collect();
            feeder.send(piece).await.unwrap();
        }

        let mut buffer = ChannelBuffer::new(channel);
        let decoded = buffer
            .recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>()
            .await
            .unwrap();

        assert_eq!(msg, decoded);
    }
}