1use std::collections::HashMap;
4
5use byteorder::{ByteOrder, NetworkEndian};
6use pallas_codec::{minicbor, Fragment};
7use thiserror::Error;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::task::JoinHandle;
10use tokio::time::Instant;
11use tokio::{select, sync::mpsc::error::SendError};
12use tracing::{debug, error, trace, warn};
13
14type IOResult<T> = tokio::io::Result<T>;
15
16use tokio::net as tcp;
17
18#[cfg(unix)]
19use tokio::net as unix;
20
21#[cfg(windows)]
22use tokio::net::windows::named_pipe::NamedPipeClient;
23
24#[cfg(windows)]
25use tokio::io::{ReadHalf, WriteHalf};
26
/// Size in bytes of the segment header on the wire
/// (4-byte timestamp + 2-byte protocol + 2-byte payload length).
const HEADER_LEN: usize = 8;

/// Transmission timestamp carried in each segment header, in microseconds
/// since the muxer's clock epoch (wraps at `u32::MAX`).
pub type Timestamp = u32;

/// Raw bytes of one segment payload.
pub type Payload = Vec<u8>;

/// Mini-protocol identifier. The 0x8000 bit distinguishes direction
/// (see `Plexer::subscribe_client` / `subscribe_server`).
pub type Protocol = u16;

/// Parsed multiplexer segment header.
#[derive(Debug)]
pub struct Header {
    pub protocol: Protocol,
    pub timestamp: Timestamp,
    pub payload_len: u16,
}

impl From<&[u8]> for Header {
    /// Parses a header from the first `HEADER_LEN` bytes of `value`.
    ///
    /// All fields are network byte order (big-endian). Uses the standard
    /// library's `from_be_bytes` instead of the `byteorder` crate.
    ///
    /// Panics if `value` holds fewer than `HEADER_LEN` bytes.
    fn from(value: &[u8]) -> Self {
        let timestamp = u32::from_be_bytes([value[0], value[1], value[2], value[3]]);
        let protocol = u16::from_be_bytes([value[4], value[5]]);
        let payload_len = u16::from_be_bytes([value[6], value[7]]);

        Self {
            timestamp,
            protocol,
            payload_len,
        }
    }
}

impl From<Header> for [u8; 8] {
    /// Serializes the header into its 8-byte big-endian wire layout.
    fn from(value: Header) -> Self {
        let mut out = [0u8; 8];
        out[0..4].copy_from_slice(&value.timestamp.to_be_bytes());
        out[4..6].copy_from_slice(&value.protocol.to_be_bytes());
        out[6..8].copy_from_slice(&value.payload_len.to_be_bytes());

        out
    }
}
66
/// One complete multiplexer segment: a header plus its payload bytes.
pub struct Segment {
    pub header: Header,
    pub payload: Payload,
}
71
/// Underlying transport carrying the multiplexed traffic.
pub enum Bearer {
    /// Plain TCP connection.
    Tcp(tcp::TcpStream),

    /// Unix domain socket (unix targets only).
    #[cfg(unix)]
    Unix(unix::UnixStream),

    /// Windows named-pipe client (windows targets only).
    #[cfg(windows)]
    NamedPipe(NamedPipeClient),
}
81
82impl Bearer {
83 fn configure_tcp(stream: &tcp::TcpStream) -> IOResult<()> {
84 let sock_ref = socket2::SockRef::from(&stream);
85 let mut tcp_keepalive = socket2::TcpKeepalive::new();
86 tcp_keepalive = tcp_keepalive.with_time(tokio::time::Duration::from_secs(20));
87 tcp_keepalive = tcp_keepalive.with_interval(tokio::time::Duration::from_secs(20));
88 sock_ref.set_tcp_keepalive(&tcp_keepalive)?;
89 sock_ref.set_nodelay(true)?;
90 sock_ref.set_linger(Some(std::time::Duration::from_secs(0)))?;
91
92 Ok(())
93 }
94
95 pub async fn connect_tcp(addr: impl tcp::ToSocketAddrs) -> Result<Self, tokio::io::Error> {
96 let stream = tcp::TcpStream::connect(addr).await?;
97 Self::configure_tcp(&stream)?;
98 Ok(Self::Tcp(stream))
99 }
100
101 pub async fn connect_tcp_timeout(
102 addr: impl tcp::ToSocketAddrs,
103 timeout: std::time::Duration,
104 ) -> IOResult<Self> {
105 select! {
106 result = Self::connect_tcp(addr) => result,
107 _ = tokio::time::sleep(timeout) => Err(tokio::io::Error::new(tokio::io::ErrorKind::TimedOut, "connect timeout")),
108 }
109 }
110
111 pub async fn accept_tcp(listener: &tcp::TcpListener) -> IOResult<(Self, std::net::SocketAddr)> {
112 let (stream, addr) = listener.accept().await?;
113 Self::configure_tcp(&stream)?;
114 Ok((Self::Tcp(stream), addr))
115 }
116
117 #[cfg(unix)]
118 pub async fn connect_unix(path: impl AsRef<std::path::Path>) -> IOResult<Self> {
119 let stream = unix::UnixStream::connect(path).await?;
120 Ok(Self::Unix(stream))
121 }
122
123 #[cfg(unix)]
124 pub async fn accept_unix(
125 listener: &unix::UnixListener,
126 ) -> IOResult<(Self, unix::unix::SocketAddr)> {
127 let (stream, addr) = listener.accept().await?;
128 Ok((Self::Unix(stream), addr))
129 }
130
131 #[cfg(windows)]
132 pub fn connect_named_pipe(pipe_name: impl AsRef<std::ffi::OsStr>) -> IOResult<Self> {
133 let client = tokio::net::windows::named_pipe::ClientOptions::new().open(&pipe_name)?;
134 Ok(Self::NamedPipe(client))
135 }
136
137 pub fn into_split(self) -> (BearerReadHalf, BearerWriteHalf) {
138 match self {
139 Bearer::Tcp(x) => {
140 let (r, w) = x.into_split();
141 (BearerReadHalf::Tcp(r), BearerWriteHalf::Tcp(w))
142 }
143
144 #[cfg(unix)]
145 Bearer::Unix(x) => {
146 let (r, w) = x.into_split();
147 (BearerReadHalf::Unix(r), BearerWriteHalf::Unix(w))
148 }
149
150 #[cfg(windows)]
151 Bearer::NamedPipe(x) => {
152 let (read, write) = tokio::io::split(x);
153 let reader = BearerReadHalf::NamedPipe(read);
154 let writer = BearerWriteHalf::NamedPipe(write);
155
156 (reader, writer)
157 }
158 }
159 }
160}
161
/// Read side of a split [`Bearer`], consumed by the demuxer.
pub enum BearerReadHalf {
    Tcp(tcp::tcp::OwnedReadHalf),

    #[cfg(unix)]
    Unix(unix::unix::OwnedReadHalf),

    #[cfg(windows)]
    NamedPipe(ReadHalf<NamedPipeClient>),
}
171
172impl BearerReadHalf {
173 async fn read_exact(&mut self, buf: &mut [u8]) -> IOResult<usize> {
174 match self {
175 BearerReadHalf::Tcp(x) => x.read_exact(buf).await,
176
177 #[cfg(unix)]
178 BearerReadHalf::Unix(x) => x.read_exact(buf).await,
179
180 #[cfg(windows)]
181 BearerReadHalf::NamedPipe(x) => x.read_exact(buf).await,
182 }
183 }
184}
185
/// Write side of a split [`Bearer`], consumed by the muxer.
pub enum BearerWriteHalf {
    Tcp(tcp::tcp::OwnedWriteHalf),

    #[cfg(unix)]
    Unix(unix::unix::OwnedWriteHalf),

    #[cfg(windows)]
    NamedPipe(WriteHalf<NamedPipeClient>),
}
195
196impl BearerWriteHalf {
197 async fn write_all(&mut self, buf: &[u8]) -> IOResult<()> {
198 match self {
199 Self::Tcp(x) => x.write_all(buf).await,
200
201 #[cfg(unix)]
202 Self::Unix(x) => x.write_all(buf).await,
203
204 #[cfg(windows)]
205 Self::NamedPipe(x) => x.write_all(buf).await,
206 }
207 }
208
209 async fn flush(&mut self) -> IOResult<()> {
210 match self {
211 Self::Tcp(x) => x.flush().await,
212
213 #[cfg(unix)]
214 Self::Unix(x) => x.flush().await,
215
216 #[cfg(windows)]
217 Self::NamedPipe(x) => x.flush().await,
218 }
219 }
220}
221
222#[derive(Debug, Error)]
223pub enum Error {
224 #[error("no data available in bearer to complete segment")]
225 EmptyBearer,
226
227 #[error("bearer I/O error")]
228 BearerIo(tokio::io::Error),
229
230 #[error("failure to encode channel message")]
231 Decoding(String),
232
233 #[error("failure to decode channel message")]
234 Encoding(String),
235
236 #[error("agent failed to enqueue chunk for protocol {0}")]
237 AgentEnqueue(Protocol, Payload),
238
239 #[error("agent failed to dequeue chunk")]
240 AgentDequeue,
241
242 #[error("plexer failed to dumux chunk for protocol {0}")]
243 PlexerDemux(Protocol, Payload),
244
245 #[error("plexer failed to mux chunk")]
246 PlexerMux,
247
248 #[error("failure to abort the plexer threads")]
249 AbortFailure,
250}
251
/// Sending side of a per-protocol queue feeding a subscribed agent.
type EgressChannel = tokio::sync::mpsc::Sender<Payload>;
/// Registered per-protocol outbound queues, keyed by protocol id.
type Egress = HashMap<Protocol, EgressChannel>;

/// Bound of each per-protocol egress queue before senders block.
const EGRESS_MSG_QUEUE_BUFFER: usize = 100;

/// Reads segments off the bearer's read half and routes each payload to
/// the subscriber registered for its protocol.
pub struct Demuxer(BearerReadHalf, Egress);
258
259impl Demuxer {
260 pub fn new(bearer: BearerReadHalf) -> Self {
261 let egress = HashMap::new();
262 Self(bearer, egress)
263 }
264
265 pub async fn read_segment(&mut self) -> Result<(Protocol, Payload), Error> {
266 trace!("waiting for segment header");
267 let mut buf = vec![0u8; HEADER_LEN];
268 self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?;
269 let header = Header::from(buf.as_slice());
270
271 trace!("waiting for full segment");
272 let segment_size = header.payload_len as usize;
273 let mut buf = vec![0u8; segment_size];
274 self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?;
275
276 Ok((header.protocol, buf))
277 }
278
279 async fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> {
280 let channel = self.1.get(&protocol);
281
282 if let Some(sender) = channel {
283 sender
284 .send(payload)
285 .await
286 .map_err(|err| Error::PlexerDemux(protocol, err.0))?;
287 } else {
288 warn!(protocol, "message for unregistered protocol");
289 }
290
291 Ok(())
292 }
293
294 pub fn subscribe(&mut self, protocol: Protocol) -> tokio::sync::mpsc::Receiver<Payload> {
295 let (sender, recv) = tokio::sync::mpsc::channel(EGRESS_MSG_QUEUE_BUFFER);
296
297 self.1.insert(protocol, sender);
299
300 recv
302 }
303
304 pub async fn tick(&mut self) -> Result<(), Error> {
305 let (protocol, payload) = self.read_segment().await?;
306 trace!(protocol, "demux happening");
307 self.demux(protocol, payload).await
308 }
309
310 pub async fn run(&mut self) -> Result<(), Error> {
311 loop {
312 if let Err(err) = self.tick().await {
313 break Err(err);
314 }
315 }
316 }
317}
318
/// Paired sender/receiver through which agents feed segments to the muxer.
type Ingress = (
    tokio::sync::mpsc::Sender<(Protocol, Payload)>,
    tokio::sync::mpsc::Receiver<(Protocol, Payload)>,
);

/// Epoch used to stamp outgoing segment headers.
type Clock = Instant;

/// Bound of the shared ingress queue before senders block.
const INGRESS_MSG_QUEUE_BUFFER: usize = 100;

/// Serializes (protocol, payload) pairs from the ingress queue into
/// segments written to the bearer's write half.
pub struct Muxer(BearerWriteHalf, Clock, Ingress);
329
330impl Muxer {
331 pub fn new(bearer: BearerWriteHalf) -> Self {
332 let ingress = tokio::sync::mpsc::channel(INGRESS_MSG_QUEUE_BUFFER);
333 let clock = Instant::now();
334 Self(bearer, clock, ingress)
335 }
336
337 async fn write_segment(&mut self, protocol: u16, payload: &[u8]) -> Result<(), std::io::Error> {
338 let header = Header {
339 protocol,
340 timestamp: self.1.elapsed().as_micros() as u32,
341 payload_len: payload.len() as u16,
342 };
343
344 let buf: [u8; 8] = header.into();
345 self.0.write_all(&buf).await?;
346 self.0.write_all(payload).await?;
347
348 self.0.flush().await?;
349
350 Ok(())
351 }
352
353 pub async fn mux(&mut self, msg: (Protocol, Payload)) -> Result<(), Error> {
354 self.write_segment(msg.0, &msg.1)
355 .await
356 .map_err(|_| Error::PlexerMux)?;
357
358 if tracing::event_enabled!(tracing::Level::TRACE) {
359 trace!(
360 protocol = msg.0,
361 data = hex::encode(&msg.1),
362 "write to bearer"
363 );
364 }
365
366 Ok(())
367 }
368
369 pub fn clone_sender(&self) -> tokio::sync::mpsc::Sender<(Protocol, Payload)> {
370 self.2 .0.clone()
371 }
372
373 pub async fn tick(&mut self) -> Result<(), Error> {
374 let msg = self.2 .1.recv().await;
375
376 if let Some(x) = msg {
377 trace!(protocol = x.0, "mux happening");
378 self.mux(x).await?
379 }
380
381 Ok(())
382 }
383
384 pub async fn run(&mut self) -> Result<(), Error> {
385 loop {
386 if let Err(err) = self.tick().await {
387 break Err(err);
388 }
389 }
390 }
391}
392
/// Agent-to-muxer queue: carries (protocol, payload) pairs to be written.
type ToPlexerPort = tokio::sync::mpsc::Sender<(Protocol, Payload)>;
/// Demuxer-to-agent queue: carries payloads routed to one protocol.
type FromPlexerPort = tokio::sync::mpsc::Receiver<Payload>;

/// Per-protocol endpoint handed to a mini-protocol agent.
pub struct AgentChannel {
    /// Protocol id stamped on every outgoing chunk.
    protocol: Protocol,
    /// Sending side into the muxer.
    to_plexer: ToPlexerPort,
    /// Receiving side from the demuxer.
    from_plexer: FromPlexerPort,
}
401
402impl AgentChannel {
403 fn for_client(
404 protocol: Protocol,
405 to_plexer: ToPlexerPort,
406 from_plexer: FromPlexerPort,
407 ) -> Self {
408 Self {
409 protocol,
410 from_plexer,
411 to_plexer,
412 }
413 }
414
415 fn for_server(
416 protocol: Protocol,
417 to_plexer: ToPlexerPort,
418 from_plexer: FromPlexerPort,
419 ) -> Self {
420 Self {
421 protocol,
422 from_plexer,
423 to_plexer,
424 }
425 }
426
427 pub async fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), Error> {
428 self.to_plexer
429 .send((self.protocol, chunk))
430 .await
431 .map_err(|SendError((protocol, payload))| Error::AgentEnqueue(protocol, payload))
432 }
433
434 pub async fn dequeue_chunk(&mut self) -> Result<Payload, Error> {
435 self.from_plexer.recv().await.ok_or(Error::AgentDequeue)
436 }
437}
438
/// Join handles for the spawned demuxer and muxer tasks.
pub struct RunningPlexer {
    demuxer: JoinHandle<Result<(), Error>>,
    muxer: JoinHandle<Result<(), Error>>,
}
443
impl RunningPlexer {
    /// Requests abortion of both plexer tasks.
    ///
    /// NOTE(review): declared `async` but performs no awaits — abortion is
    /// only requested, not awaited, so the tasks may still be winding down
    /// when this returns. Kept as-is for API stability.
    pub async fn abort(self) {
        self.demuxer.abort();
        self.muxer.abort();
    }
}
450
/// Demuxer/muxer pair built over a single split bearer, not yet spawned.
pub struct Plexer {
    demuxer: Demuxer,
    muxer: Muxer,
}
455
456impl Plexer {
457 pub fn new(bearer: Bearer) -> Self {
458 let (r, w) = bearer.into_split();
459
460 Self {
461 demuxer: Demuxer::new(r),
462 muxer: Muxer::new(w),
463 }
464 }
465
466 pub fn subscribe_client(&mut self, protocol: Protocol) -> AgentChannel {
467 let to_plexer = self.muxer.clone_sender();
468 let from_plexer = self.demuxer.subscribe(protocol ^ 0x8000);
469 AgentChannel::for_client(protocol, to_plexer, from_plexer)
470 }
471
472 pub fn subscribe_server(&mut self, protocol: Protocol) -> AgentChannel {
473 let to_plexer = self.muxer.clone_sender();
474 let from_plexer = self.demuxer.subscribe(protocol);
475 AgentChannel::for_server(protocol ^ 0x8000, to_plexer, from_plexer)
476 }
477
478 pub fn spawn(self) -> RunningPlexer {
479 let mut demuxer = self.demuxer;
480 let mut muxer = self.muxer;
481
482 let demuxer = tokio::spawn(async move { demuxer.run().await });
483 let muxer = tokio::spawn(async move { muxer.run().await });
484
485 RunningPlexer { demuxer, muxer }
486 }
487}
488
/// Largest payload a single segment can carry: u16::MAX, since the
/// header's `payload_len` field is 16 bits.
pub const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535;
491
492fn try_decode_message<M>(buffer: &mut Vec<u8>) -> Result<Option<M>, Error>
493where
494 M: Fragment,
495{
496 let mut decoder = minicbor::Decoder::new(buffer);
497 let maybe_msg = decoder.decode();
498
499 match maybe_msg {
500 Ok(msg) => {
501 let pos = decoder.position();
502 buffer.drain(0..pos);
503 Ok(Some(msg))
504 }
505 Err(err) if err.is_end_of_input() => Ok(None),
506 Err(err) => {
507 error!(?err);
508 trace!("{}", hex::encode(buffer));
509 Err(Error::Decoding(err.to_string()))
510 }
511 }
512}
513
/// Wraps an [`AgentChannel`], chunking outgoing CBOR messages and
/// reassembling incoming chunks into complete messages.
pub struct ChannelBuffer {
    channel: AgentChannel,
    /// Bytes received but not yet decoded into a full message.
    temp: Vec<u8>,
}
519
520impl ChannelBuffer {
521 pub fn new(channel: AgentChannel) -> Self {
522 Self {
523 channel,
524 temp: Vec::new(),
525 }
526 }
527
528 pub async fn send_msg_chunks<M>(&mut self, msg: &M) -> Result<(), Error>
530 where
531 M: Fragment,
532 {
533 let mut payload = Vec::new();
534 minicbor::encode(msg, &mut payload).map_err(|err| Error::Encoding(err.to_string()))?;
535
536 let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH);
537
538 for chunk in chunks {
539 self.channel.enqueue_chunk(Vec::from(chunk)).await?;
540 }
541
542 Ok(())
543 }
544
545 pub async fn recv_full_msg<M>(&mut self) -> Result<M, Error>
547 where
548 M: Fragment,
549 {
550 trace!(len = self.temp.len(), "waiting for full message");
551
552 if !self.temp.is_empty() {
553 trace!("buffer has data from previous payload");
554
555 if let Some(msg) = try_decode_message::<M>(&mut self.temp)? {
556 debug!("decoding done");
557 return Ok(msg);
558 }
559 }
560
561 loop {
562 let chunk = self.channel.dequeue_chunk().await?;
563 self.temp.extend(chunk);
564
565 if let Some(msg) = try_decode_message::<M>(&mut self.temp)? {
566 debug!("decoding done");
567 return Ok(msg);
568 }
569
570 trace!("not enough data");
571 }
572 }
573
574 pub fn unwrap(self) -> AgentChannel {
575 self.channel
576 }
577}
578
579impl From<AgentChannel> for ChannelBuffer {
580 fn from(channel: AgentChannel) -> Self {
581 ChannelBuffer::new(channel)
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use pallas_codec::minicbor;

    // Two CBOR messages encoded back-to-back in a single payload must both
    // be recoverable, with the leftover bytes carried between calls.
    #[tokio::test]
    async fn multiple_messages_in_same_payload() {
        let mut input = Vec::new();
        let in_part1 = (1u8, 2u8, 3u8);
        let in_part2 = (6u8, 5u8, 4u8);

        minicbor::encode(in_part1, &mut input).unwrap();
        minicbor::encode(in_part2, &mut input).unwrap();

        // Outbound side is unused here; only inbound routing is exercised.
        let (to_plexer, _) = tokio::sync::mpsc::channel(100);
        let (into_plexer, from_plexer) = tokio::sync::mpsc::channel(100);

        let channel = AgentChannel::for_client(0, to_plexer, from_plexer);

        // Deliver both messages as one chunk.
        into_plexer.send(input).await.unwrap();

        let mut buf = ChannelBuffer::new(channel);

        let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().await.unwrap();
        let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().await.unwrap();

        assert_eq!(in_part1, out_part1);
        assert_eq!(in_part2, out_part2);
    }

    // A message split across 2-byte chunks must be reassembled before it
    // can be decoded as a whole.
    #[tokio::test]
    async fn fragmented_message_in_multiple_payloads() {
        let mut input = Vec::new();
        let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
        minicbor::encode(msg, &mut input).unwrap();

        let (to_plexer, _) = tokio::sync::mpsc::channel(100);
        let (into_plexer, from_plexer) = tokio::sync::mpsc::channel(100);

        let channel = AgentChannel::for_client(0, to_plexer, from_plexer);

        // Feed the encoding to the channel two bytes at a time.
        while !input.is_empty() {
            let chunk = Vec::from(input.drain(0..2).as_slice());
            into_plexer.send(chunk).await.unwrap();
        }

        let mut buf = ChannelBuffer::new(channel);

        let out_msg = buf
            .recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>()
            .await
            .unwrap();

        assert_eq!(msg, out_msg);
    }
}