pallas_network/
multiplexer.rs

1//! A multiplexer of several mini-protocols through a single bearer
2
3use std::collections::HashMap;
4
5use byteorder::{ByteOrder, NetworkEndian};
6use pallas_codec::{minicbor, Fragment};
7use thiserror::Error;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::task::JoinHandle;
10use tokio::time::Instant;
11use tokio::{select, sync::mpsc::error::SendError};
12use tracing::{debug, error, trace, warn};
13
14type IOResult<T> = tokio::io::Result<T>;
15
16use tokio::net as tcp;
17
18#[cfg(unix)]
19use tokio::net as unix;
20
21#[cfg(windows)]
22use tokio::net::windows::named_pipe::NamedPipeClient;
23
24#[cfg(windows)]
25use tokio::io::{ReadHalf, WriteHalf};
26
/// Timestamp field of a segment header (32 bits on the wire).
pub type Timestamp = u32;

/// Bytes carried by a segment, or a chunk of a mini-protocol message.
pub type Payload = Vec<u8>;

/// Mini-protocol identifier field of a segment header (16 bits on the wire).
pub type Protocol = u16;

/// Size in bytes of an encoded segment header: timestamp + protocol + payload length.
const HEADER_LEN: usize = std::mem::size_of::<Timestamp>()
    + std::mem::size_of::<Protocol>()
    + std::mem::size_of::<u16>();

/// Fixed-size header preceding every segment on the bearer.
///
/// Wire layout (network byte order): `timestamp (4) | protocol (2) | payload_len (2)`.
/// It is a small plain value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Header {
    /// Mini-protocol the segment belongs to.
    pub protocol: Protocol,
    /// Sender-side timestamp of the segment.
    pub timestamp: Timestamp,
    /// Number of payload bytes following the header.
    pub payload_len: u16,
}
41
42impl From<&[u8]> for Header {
43    fn from(value: &[u8]) -> Self {
44        let timestamp = NetworkEndian::read_u32(&value[0..4]);
45        let protocol = NetworkEndian::read_u16(&value[4..6]);
46        let payload_len = NetworkEndian::read_u16(&value[6..8]);
47
48        Self {
49            timestamp,
50            protocol,
51            payload_len,
52        }
53    }
54}
55
56impl From<Header> for [u8; 8] {
57    fn from(value: Header) -> Self {
58        let mut out = [0u8; 8];
59        NetworkEndian::write_u32(&mut out[0..4], value.timestamp);
60        NetworkEndian::write_u16(&mut out[4..6], value.protocol);
61        NetworkEndian::write_u16(&mut out[6..8], value.payload_len);
62
63        out
64    }
65}
66
67pub struct Segment {
68    pub header: Header,
69    pub payload: Payload,
70}
71
72pub enum Bearer {
73    Tcp(tcp::TcpStream),
74
75    #[cfg(unix)]
76    Unix(unix::UnixStream),
77
78    #[cfg(windows)]
79    NamedPipe(NamedPipeClient),
80}
81
82impl Bearer {
83    fn configure_tcp(stream: &tcp::TcpStream) -> IOResult<()> {
84        let sock_ref = socket2::SockRef::from(&stream);
85        let mut tcp_keepalive = socket2::TcpKeepalive::new();
86        tcp_keepalive = tcp_keepalive.with_time(tokio::time::Duration::from_secs(20));
87        tcp_keepalive = tcp_keepalive.with_interval(tokio::time::Duration::from_secs(20));
88        sock_ref.set_tcp_keepalive(&tcp_keepalive)?;
89        sock_ref.set_nodelay(true)?;
90        sock_ref.set_linger(Some(std::time::Duration::from_secs(0)))?;
91
92        Ok(())
93    }
94
95    pub async fn connect_tcp(addr: impl tcp::ToSocketAddrs) -> Result<Self, tokio::io::Error> {
96        let stream = tcp::TcpStream::connect(addr).await?;
97        Self::configure_tcp(&stream)?;
98        Ok(Self::Tcp(stream))
99    }
100
101    pub async fn connect_tcp_timeout(
102        addr: impl tcp::ToSocketAddrs,
103        timeout: std::time::Duration,
104    ) -> IOResult<Self> {
105        select! {
106            result = Self::connect_tcp(addr) => result,
107            _ = tokio::time::sleep(timeout) => Err(tokio::io::Error::new(tokio::io::ErrorKind::TimedOut, "connect timeout")),
108        }
109    }
110
111    pub async fn accept_tcp(listener: &tcp::TcpListener) -> IOResult<(Self, std::net::SocketAddr)> {
112        let (stream, addr) = listener.accept().await?;
113        Self::configure_tcp(&stream)?;
114        Ok((Self::Tcp(stream), addr))
115    }
116
117    #[cfg(unix)]
118    pub async fn connect_unix(path: impl AsRef<std::path::Path>) -> IOResult<Self> {
119        let stream = unix::UnixStream::connect(path).await?;
120        Ok(Self::Unix(stream))
121    }
122
123    #[cfg(unix)]
124    pub async fn accept_unix(
125        listener: &unix::UnixListener,
126    ) -> IOResult<(Self, unix::unix::SocketAddr)> {
127        let (stream, addr) = listener.accept().await?;
128        Ok((Self::Unix(stream), addr))
129    }
130
131    #[cfg(windows)]
132    pub fn connect_named_pipe(pipe_name: impl AsRef<std::ffi::OsStr>) -> IOResult<Self> {
133        let client = tokio::net::windows::named_pipe::ClientOptions::new().open(&pipe_name)?;
134        Ok(Self::NamedPipe(client))
135    }
136
137    pub fn into_split(self) -> (BearerReadHalf, BearerWriteHalf) {
138        match self {
139            Bearer::Tcp(x) => {
140                let (r, w) = x.into_split();
141                (BearerReadHalf::Tcp(r), BearerWriteHalf::Tcp(w))
142            }
143
144            #[cfg(unix)]
145            Bearer::Unix(x) => {
146                let (r, w) = x.into_split();
147                (BearerReadHalf::Unix(r), BearerWriteHalf::Unix(w))
148            }
149
150            #[cfg(windows)]
151            Bearer::NamedPipe(x) => {
152                let (read, write) = tokio::io::split(x);
153                let reader = BearerReadHalf::NamedPipe(read);
154                let writer = BearerWriteHalf::NamedPipe(write);
155
156                (reader, writer)
157            }
158        }
159    }
160}
161
162pub enum BearerReadHalf {
163    Tcp(tcp::tcp::OwnedReadHalf),
164
165    #[cfg(unix)]
166    Unix(unix::unix::OwnedReadHalf),
167
168    #[cfg(windows)]
169    NamedPipe(ReadHalf<NamedPipeClient>),
170}
171
172impl BearerReadHalf {
173    async fn read_exact(&mut self, buf: &mut [u8]) -> IOResult<usize> {
174        match self {
175            BearerReadHalf::Tcp(x) => x.read_exact(buf).await,
176
177            #[cfg(unix)]
178            BearerReadHalf::Unix(x) => x.read_exact(buf).await,
179
180            #[cfg(windows)]
181            BearerReadHalf::NamedPipe(x) => x.read_exact(buf).await,
182        }
183    }
184}
185
186pub enum BearerWriteHalf {
187    Tcp(tcp::tcp::OwnedWriteHalf),
188
189    #[cfg(unix)]
190    Unix(unix::unix::OwnedWriteHalf),
191
192    #[cfg(windows)]
193    NamedPipe(WriteHalf<NamedPipeClient>),
194}
195
196impl BearerWriteHalf {
197    async fn write_all(&mut self, buf: &[u8]) -> IOResult<()> {
198        match self {
199            Self::Tcp(x) => x.write_all(buf).await,
200
201            #[cfg(unix)]
202            Self::Unix(x) => x.write_all(buf).await,
203
204            #[cfg(windows)]
205            Self::NamedPipe(x) => x.write_all(buf).await,
206        }
207    }
208
209    async fn flush(&mut self) -> IOResult<()> {
210        match self {
211            Self::Tcp(x) => x.flush().await,
212
213            #[cfg(unix)]
214            Self::Unix(x) => x.flush().await,
215
216            #[cfg(windows)]
217            Self::NamedPipe(x) => x.flush().await,
218        }
219    }
220}
221
222#[derive(Debug, Error)]
223pub enum Error {
224    #[error("no data available in bearer to complete segment")]
225    EmptyBearer,
226
227    #[error("bearer I/O error")]
228    BearerIo(tokio::io::Error),
229
230    #[error("failure to encode channel message")]
231    Decoding(String),
232
233    #[error("failure to decode channel message")]
234    Encoding(String),
235
236    #[error("agent failed to enqueue chunk for protocol {0}")]
237    AgentEnqueue(Protocol, Payload),
238
239    #[error("agent failed to dequeue chunk")]
240    AgentDequeue,
241
242    #[error("plexer failed to dumux chunk for protocol {0}")]
243    PlexerDemux(Protocol, Payload),
244
245    #[error("plexer failed to mux chunk")]
246    PlexerMux,
247
248    #[error("failure to abort the plexer threads")]
249    AbortFailure,
250}
251
252type EgressChannel = tokio::sync::mpsc::Sender<Payload>;
253type Egress = HashMap<Protocol, EgressChannel>;
254
255const EGRESS_MSG_QUEUE_BUFFER: usize = 100;
256
257pub struct Demuxer(BearerReadHalf, Egress);
258
259impl Demuxer {
260    pub fn new(bearer: BearerReadHalf) -> Self {
261        let egress = HashMap::new();
262        Self(bearer, egress)
263    }
264
265    pub async fn read_segment(&mut self) -> Result<(Protocol, Payload), Error> {
266        trace!("waiting for segment header");
267        let mut buf = vec![0u8; HEADER_LEN];
268        self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?;
269        let header = Header::from(buf.as_slice());
270
271        trace!("waiting for full segment");
272        let segment_size = header.payload_len as usize;
273        let mut buf = vec![0u8; segment_size];
274        self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?;
275
276        Ok((header.protocol, buf))
277    }
278
279    async fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> {
280        let channel = self.1.get(&protocol);
281
282        if let Some(sender) = channel {
283            sender
284                .send(payload)
285                .await
286                .map_err(|err| Error::PlexerDemux(protocol, err.0))?;
287        } else {
288            warn!(protocol, "message for unregistered protocol");
289        }
290
291        Ok(())
292    }
293
294    pub fn subscribe(&mut self, protocol: Protocol) -> tokio::sync::mpsc::Receiver<Payload> {
295        let (sender, recv) = tokio::sync::mpsc::channel(EGRESS_MSG_QUEUE_BUFFER);
296
297        // keep track of the sender
298        self.1.insert(protocol, sender);
299
300        // return the receiver for the agent
301        recv
302    }
303
304    pub async fn tick(&mut self) -> Result<(), Error> {
305        let (protocol, payload) = self.read_segment().await?;
306        trace!(protocol, "demux happening");
307        self.demux(protocol, payload).await
308    }
309
310    pub async fn run(&mut self) -> Result<(), Error> {
311        loop {
312            if let Err(err) = self.tick().await {
313                break Err(err);
314            }
315        }
316    }
317}
318
319type Ingress = (
320    tokio::sync::mpsc::Sender<(Protocol, Payload)>,
321    tokio::sync::mpsc::Receiver<(Protocol, Payload)>,
322);
323
324type Clock = Instant;
325
326const INGRESS_MSG_QUEUE_BUFFER: usize = 100;
327
328pub struct Muxer(BearerWriteHalf, Clock, Ingress);
329
330impl Muxer {
331    pub fn new(bearer: BearerWriteHalf) -> Self {
332        let ingress = tokio::sync::mpsc::channel(INGRESS_MSG_QUEUE_BUFFER);
333        let clock = Instant::now();
334        Self(bearer, clock, ingress)
335    }
336
337    async fn write_segment(&mut self, protocol: u16, payload: &[u8]) -> Result<(), std::io::Error> {
338        let header = Header {
339            protocol,
340            timestamp: self.1.elapsed().as_micros() as u32,
341            payload_len: payload.len() as u16,
342        };
343
344        let buf: [u8; 8] = header.into();
345        self.0.write_all(&buf).await?;
346        self.0.write_all(payload).await?;
347
348        self.0.flush().await?;
349
350        Ok(())
351    }
352
353    pub async fn mux(&mut self, msg: (Protocol, Payload)) -> Result<(), Error> {
354        self.write_segment(msg.0, &msg.1)
355            .await
356            .map_err(|_| Error::PlexerMux)?;
357
358        if tracing::event_enabled!(tracing::Level::TRACE) {
359            trace!(
360                protocol = msg.0,
361                data = hex::encode(&msg.1),
362                "write to bearer"
363            );
364        }
365
366        Ok(())
367    }
368
369    pub fn clone_sender(&self) -> tokio::sync::mpsc::Sender<(Protocol, Payload)> {
370        self.2 .0.clone()
371    }
372
373    pub async fn tick(&mut self) -> Result<(), Error> {
374        let msg = self.2 .1.recv().await;
375
376        if let Some(x) = msg {
377            trace!(protocol = x.0, "mux happening");
378            self.mux(x).await?
379        }
380
381        Ok(())
382    }
383
384    pub async fn run(&mut self) -> Result<(), Error> {
385        loop {
386            if let Err(err) = self.tick().await {
387                break Err(err);
388            }
389        }
390    }
391}
392
393type ToPlexerPort = tokio::sync::mpsc::Sender<(Protocol, Payload)>;
394type FromPlexerPort = tokio::sync::mpsc::Receiver<Payload>;
395
396pub struct AgentChannel {
397    protocol: Protocol,
398    to_plexer: ToPlexerPort,
399    from_plexer: FromPlexerPort,
400}
401
402impl AgentChannel {
403    fn for_client(
404        protocol: Protocol,
405        to_plexer: ToPlexerPort,
406        from_plexer: FromPlexerPort,
407    ) -> Self {
408        Self {
409            protocol,
410            from_plexer,
411            to_plexer,
412        }
413    }
414
415    fn for_server(
416        protocol: Protocol,
417        to_plexer: ToPlexerPort,
418        from_plexer: FromPlexerPort,
419    ) -> Self {
420        Self {
421            protocol,
422            from_plexer,
423            to_plexer,
424        }
425    }
426
427    pub async fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), Error> {
428        self.to_plexer
429            .send((self.protocol, chunk))
430            .await
431            .map_err(|SendError((protocol, payload))| Error::AgentEnqueue(protocol, payload))
432    }
433
434    pub async fn dequeue_chunk(&mut self) -> Result<Payload, Error> {
435        self.from_plexer.recv().await.ok_or(Error::AgentDequeue)
436    }
437}
438
439pub struct RunningPlexer {
440    demuxer: JoinHandle<Result<(), Error>>,
441    muxer: JoinHandle<Result<(), Error>>,
442}
443
444impl RunningPlexer {
445    pub async fn abort(self) {
446        self.demuxer.abort();
447        self.muxer.abort();
448    }
449}
450
451pub struct Plexer {
452    demuxer: Demuxer,
453    muxer: Muxer,
454}
455
456impl Plexer {
457    pub fn new(bearer: Bearer) -> Self {
458        let (r, w) = bearer.into_split();
459
460        Self {
461            demuxer: Demuxer::new(r),
462            muxer: Muxer::new(w),
463        }
464    }
465
466    pub fn subscribe_client(&mut self, protocol: Protocol) -> AgentChannel {
467        let to_plexer = self.muxer.clone_sender();
468        let from_plexer = self.demuxer.subscribe(protocol ^ 0x8000);
469        AgentChannel::for_client(protocol, to_plexer, from_plexer)
470    }
471
472    pub fn subscribe_server(&mut self, protocol: Protocol) -> AgentChannel {
473        let to_plexer = self.muxer.clone_sender();
474        let from_plexer = self.demuxer.subscribe(protocol);
475        AgentChannel::for_server(protocol ^ 0x8000, to_plexer, from_plexer)
476    }
477
478    pub fn spawn(self) -> RunningPlexer {
479        let mut demuxer = self.demuxer;
480        let mut muxer = self.muxer;
481
482        let demuxer = tokio::spawn(async move { demuxer.run().await });
483        let muxer = tokio::spawn(async move { muxer.run().await });
484
485        RunningPlexer { demuxer, muxer }
486    }
487}
488
/// Largest payload a single segment can carry.
///
/// The header's payload-length field is a `u16`, which caps this at 65535 bytes.
pub const MAX_SEGMENT_PAYLOAD_LENGTH: usize = u16::MAX as usize;
491
492fn try_decode_message<M>(buffer: &mut Vec<u8>) -> Result<Option<M>, Error>
493where
494    M: Fragment,
495{
496    let mut decoder = minicbor::Decoder::new(buffer);
497    let maybe_msg = decoder.decode();
498
499    match maybe_msg {
500        Ok(msg) => {
501            let pos = decoder.position();
502            buffer.drain(0..pos);
503            Ok(Some(msg))
504        }
505        Err(err) if err.is_end_of_input() => Ok(None),
506        Err(err) => {
507            error!(?err);
508            trace!("{}", hex::encode(buffer));
509            Err(Error::Decoding(err.to_string()))
510        }
511    }
512}
513
514/// A channel abstraction to hide the complexity of partial payloads
515pub struct ChannelBuffer {
516    channel: AgentChannel,
517    temp: Vec<u8>,
518}
519
520impl ChannelBuffer {
521    pub fn new(channel: AgentChannel) -> Self {
522        Self {
523            channel,
524            temp: Vec::new(),
525        }
526    }
527
528    /// Enqueues a msg as a sequence payload chunks
529    pub async fn send_msg_chunks<M>(&mut self, msg: &M) -> Result<(), Error>
530    where
531        M: Fragment,
532    {
533        let mut payload = Vec::new();
534        minicbor::encode(msg, &mut payload).map_err(|err| Error::Encoding(err.to_string()))?;
535
536        let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH);
537
538        for chunk in chunks {
539            self.channel.enqueue_chunk(Vec::from(chunk)).await?;
540        }
541
542        Ok(())
543    }
544
545    /// Reads from the channel until a complete message is found
546    pub async fn recv_full_msg<M>(&mut self) -> Result<M, Error>
547    where
548        M: Fragment,
549    {
550        trace!(len = self.temp.len(), "waiting for full message");
551
552        if !self.temp.is_empty() {
553            trace!("buffer has data from previous payload");
554
555            if let Some(msg) = try_decode_message::<M>(&mut self.temp)? {
556                debug!("decoding done");
557                return Ok(msg);
558            }
559        }
560
561        loop {
562            let chunk = self.channel.dequeue_chunk().await?;
563            self.temp.extend(chunk);
564
565            if let Some(msg) = try_decode_message::<M>(&mut self.temp)? {
566                debug!("decoding done");
567                return Ok(msg);
568            }
569
570            trace!("not enough data");
571        }
572    }
573
574    pub fn unwrap(self) -> AgentChannel {
575        self.channel
576    }
577}
578
579impl From<AgentChannel> for ChannelBuffer {
580    fn from(channel: AgentChannel) -> Self {
581        ChannelBuffer::new(channel)
582    }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use pallas_codec::minicbor;

    /// Builds a client agent channel plus a sender that stands in for the demuxer.
    /// Anything sent on the returned sender arrives as an inbound chunk.
    fn agent_with_feeder() -> (AgentChannel, tokio::sync::mpsc::Sender<Payload>) {
        let (to_plexer, _) = tokio::sync::mpsc::channel(100);
        let (feeder, from_plexer) = tokio::sync::mpsc::channel(100);
        (AgentChannel::for_client(0, to_plexer, from_plexer), feeder)
    }

    #[tokio::test]
    async fn multiple_messages_in_same_payload() {
        let first = (1u8, 2u8, 3u8);
        let second = (6u8, 5u8, 4u8);

        let mut payload = Vec::new();
        minicbor::encode(first, &mut payload).unwrap();
        minicbor::encode(second, &mut payload).unwrap();

        let (channel, feeder) = agent_with_feeder();
        // Both messages arrive packed into a single chunk.
        feeder.send(payload).await.unwrap();

        let mut buffer = ChannelBuffer::new(channel);
        let decoded_first: (u8, u8, u8) = buffer.recv_full_msg().await.unwrap();
        let decoded_second: (u8, u8, u8) = buffer.recv_full_msg().await.unwrap();

        assert_eq!(first, decoded_first);
        assert_eq!(second, decoded_second);
    }

    #[tokio::test]
    async fn fragmented_message_in_multiple_payloads() {
        let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
        let mut encoded = Vec::new();
        minicbor::encode(msg, &mut encoded).unwrap();

        let (channel, feeder) = agent_with_feeder();
        // Deliver the message two bytes at a time so it spans several chunks.
        for piece in encoded.chunks(2) {
            feeder.send(piece.to_vec()).await.unwrap();
        }

        let mut buffer = ChannelBuffer::new(channel);
        let decoded: (u8, u8, u8, u8, u8, u8, u8) = buffer.recv_full_msg().await.unwrap();

        assert_eq!(msg, decoded);
    }
}