Skip to main content

amaru_protocols/
mux.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[expect(clippy::disallowed_types)]
16use std::collections::HashMap;
17use std::{
18    cell::RefCell,
19    collections::{VecDeque, hash_map::Entry},
20    num::{NonZeroU16, NonZeroUsize},
21    time::SystemTime,
22};
23
24use amaru_kernel::NonEmptyBytes;
25use amaru_observability::trace;
26use amaru_ouroboros::ConnectionId;
27use anyhow::Context;
28use bytes::{Buf, BufMut, Bytes, BytesMut, TryGetError};
29use cbor_data::{Cbor, ErrorKind, ParseError};
30use pure_stage::{EPOCH, Effects, Instant, StageRef, TryInStage, Void};
31
32use crate::{
33    network_effects::{Network, NetworkOps},
34    protocol::{Erased, ProtocolId, Role, RoleT},
35};
36
37pub fn register_deserializers() -> pure_stage::DeserializerGuards {
38    vec![
39        pure_stage::register_data_deserializer::<MuxMessage>().boxed(),
40        pure_stage::register_data_deserializer::<NonEmptyBytes>().boxed(),
41        pure_stage::register_data_deserializer::<State>().boxed(),
42        pure_stage::register_data_deserializer::<HandlerMessage>().boxed(),
43        pure_stage::register_data_deserializer::<Sent>().boxed(),
44        pure_stage::register_data_deserializer::<Read>().boxed(),
45    ]
46}
47
/// Largest payload a single segment can carry: the header's length field is a `u16`.
const MAX_SEGMENT_SIZE: usize = u16::MAX as usize;
49
50/// microseconds part of the wall clock time
51#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
52pub struct Timestamp(u32);
53
54impl Timestamp {
55    pub fn now() -> Self {
56        #[expect(clippy::expect_used)]
57        Self(
58            SystemTime::now()
59                .duration_since(SystemTime::UNIX_EPOCH)
60                .expect("system time is not supposed to be before the UNIX epoch")
61                .as_micros() as u32,
62        )
63    }
64
65    fn encode(self, buffer: &mut BytesMut) {
66        buffer.put_u32(self.0);
67    }
68
69    pub fn from_instant(instant: Instant) -> Self {
70        Self(instant.saturating_since(*EPOCH).as_micros() as u32)
71    }
72
73    fn decode(buffer: &mut Bytes) -> Result<Self, TryGetError> {
74        Ok(Self(buffer.try_get_u32()?))
75    }
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
79pub enum Frame {
80    /// Each message is a single CBOR item
81    OneCborItem,
82    /// No message parsing, just buffer the data
83    Buffer,
84}
85
86impl Frame {
87    pub fn try_consume(&self, data: &mut BytesMut) -> Result<Option<NonEmptyBytes>, ParseError> {
88        match self {
89            Frame::OneCborItem => match Cbor::checked_prefix(data) {
90                Ok((item, _rest)) => {
91                    let item = data.copy_to_bytes(item.as_slice().len());
92                    #[expect(clippy::expect_used)]
93                    Ok(Some(item.try_into().expect("guaranteed by CBOR standard")))
94                }
95                Err(e) if matches!(e.kind(), ErrorKind::UnexpectedEof(_)) => Ok(None),
96                Err(e) => Err(e),
97            },
98            Frame::Buffer => Ok(None),
99        }
100    }
101}
102
/// Messages the muxer sends to the handler of a protocol registered via `MuxMessage::Register`.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum HandlerMessage {
    /// Confirms that the given protocol ID has been registered with this handler.
    Registered(ProtocolId<Erased>),
    /// One message extracted from network data according to the registered `Frame`;
    /// only sent after the handler has asked for it with `MuxMessage::WantNext`.
    FromNetwork(NonEmptyBytes),
}
108
/// Notification to the `StageRef<Sent>` given with `MuxMessage::Send`, emitted once all bytes
/// of that message have been taken from the outgoing queue as segments.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Sent;

/// Token that makes the reader stage read one segment from the network; the muxer sends the
/// next token only after the previous segment has been processed successfully.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Read;
114
/// Input messages of the muxer stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum MuxMessage {
    /// Register the given protocol with its ID so that data will be fed into it
    ///
    /// The handler first receives `HandlerMessage::Registered`. Incoming data are split into
    /// messages according to `frame`; at most `max_buffer` bytes of undelivered data are kept,
    /// and exceeding that limit terminates the muxer.
    ///
    /// Note that the handler explicitly needs to request each network message by sending `WantNext`.
    /// This is necessary to allow proper handling of TCP simultaneous open in the handshake protocol.
    Register { protocol: ProtocolId<Erased>, frame: Frame, handler: StageRef<HandlerMessage>, max_buffer: usize },
    /// Buffer incoming data for this protocol ID up to the given limit
    /// (this should be followed by Register eventually, to then consume the data)
    ///
    /// Setting the size to zero means that data are dropped without being buffered
    /// and without tearing down the connection. Lowering the limit below the amount of
    /// data already buffered is an error and terminates the muxer.
    Buffer(ProtocolId<Erased>, usize),
    /// Send the given message on the protocol ID and notify when enqueued in TCP buffer
    ///
    /// NOTE(review): the `Sent` notification is currently emitted when the message's last
    /// segment is taken from the outgoing queue, i.e. before the writer reports `Written`.
    /// The protocol must have been registered or buffered before, otherwise the muxer panics.
    Send(ProtocolId<Erased>, NonEmptyBytes, StageRef<Sent>),
    /// internal message coming from the TCP stream reader
    /// (the protocol ID carries the peer's role and is flipped with `opposite()` on receipt)
    FromNetwork(Timestamp, ProtocolId<Erased>, NonEmptyBytes),
    /// Notify that the segment has been written to the TCP stream
    Written,
    /// Permit the next invocation of the Protocol with data from the network.
    WantNext(ProtocolId<Erased>),
    /// Reading or writing error occurred
    Terminate,
}
139
/// State of the muxer stage for a single connection.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct State {
    // connection ID before the reader/writer stages exist, their stage references afterwards
    conn: Connection,
    // per-protocol buffers and round-robin scheduling of outgoing data
    muxer: Muxer,
    // true while a segment has been handed to the writer and its `Written` is still outstanding
    sending: bool,
}
146
/// Network side of the muxer; the reader and writer stages are created lazily by `State::init`.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
enum Connection {
    // Not yet initialised, only the connection ID is known.
    // NOTE(review): "Unint" looks like a misspelling of "Uninit"; the variant name is presumably
    // part of the serde representation, so renaming it needs care.
    Unint(ConnectionId),
    // Writer stage (takes encoded segments) and reader stage (reads one segment per `Read`).
    Init(StageRef<NonEmptyBytes>, StageRef<Read>),
}
152
153impl State {
154    /// Create a new state with the given connection ID and buffering the given protocols.
155    ///
156    /// Note that upon receiving the first message, the stage will start reading from the network.
157    /// Any data received for unregistered protocols will lead to stage termination.
158    pub fn new(conn: ConnectionId, buffer: &[(ProtocolId<Erased>, usize)], role: Role) -> Self {
159        let mut muxer = Muxer::new(role);
160        for &(proto_id, limit) in buffer {
161            #[expect(clippy::expect_used)]
162            muxer.buffer(proto_id, limit).expect("no buffered data yet");
163        }
164        Self { conn: Connection::Unint(conn), muxer, sending: false }
165    }
166
167    pub async fn init(
168        &mut self,
169        eff: &mut Effects<MuxMessage>,
170    ) -> (&mut Muxer, &mut bool, &StageRef<NonEmptyBytes>, &StageRef<Read>) {
171        match &mut self.conn {
172            Connection::Unint(conn) => {
173                let writer = eff
174                    .stage(
175                        format!("writer-{}", conn),
176                        move |(conn, muxer, role), data: NonEmptyBytes, eff| async move {
177                            Network::new(&eff)
178                                .send(conn, data)
179                                .await
180                                .or_terminate(
181                                    &eff,
182                                    async |err| tracing::error!(%err, %role, "failed to send data to network"),
183                                )
184                                .await;
185                            eff.send(&muxer, MuxMessage::Written).await;
186                            (conn, muxer, role)
187                        },
188                    )
189                    .await;
190                let writer = eff.supervise(writer, MuxMessage::Terminate);
191                let writer = eff.wire_up(writer, (*conn, eff.me(), self.muxer.role())).await;
192                let reader = eff.stage(format!("reader-{}", conn), read_segment).await;
193                let reader = eff.supervise(reader, MuxMessage::Terminate);
194                let reader = eff.wire_up(reader, (*conn, eff.me(), self.muxer.role())).await;
195                eff.send(&reader, Read).await;
196                self.conn = Connection::Init(writer, reader);
197            }
198            Connection::Init(..) => {}
199        }
200        let Connection::Init(writer, reader) = &self.conn else { unreachable!() };
201        (&mut self.muxer, &mut self.sending, writer, reader)
202    }
203}
204
205pub async fn stage(mut state: State, msg: MuxMessage, mut eff: Effects<MuxMessage>) -> State {
206    let (muxer, sending, writer, reader) = state.init(&mut eff).await;
207
208    handle_msg(msg, &eff, muxer, sending, writer, reader)
209        .await
210        .or_terminate(&eff, async |error| {
211            use std::fmt::Write;
212            let mut err = String::new();
213            for error in error.chain() {
214                if !err.is_empty() {
215                    err.push_str(" <- ");
216                }
217                write!(&mut err, "{}", error).ok();
218            }
219            tracing::error!(%err, role=%muxer.role(), "muxing error")
220        })
221        .await;
222
223    state
224}
225
226async fn handle_msg(
227    msg: MuxMessage,
228    eff: &Effects<MuxMessage>,
229    muxer: &mut Muxer,
230    sending: &mut bool,
231    writer: &StageRef<NonEmptyBytes>,
232    reader: &StageRef<Read>,
233) -> anyhow::Result<()> {
234    match msg {
235        MuxMessage::Register { protocol, frame, handler, max_buffer } => {
236            muxer.register(protocol, frame, max_buffer, handler, eff).await
237        }
238        MuxMessage::Buffer(proto_id, limit) => muxer.buffer(proto_id, limit),
239        MuxMessage::Send(proto_id, bytes, sent) => {
240            tracing::trace!(%proto_id, bytes = bytes.len(), "send");
241            muxer.outgoing(proto_id, bytes.into(), sent);
242            if !*sending && let Some((proto_id, bytes)) = muxer.next_segment(eff).await {
243                *sending = true;
244                let header = muxer.encode_header(eff, proto_id, &bytes).await;
245                eff.send(writer, header).await;
246            }
247            Ok(())
248        }
249        MuxMessage::FromNetwork(timestamp, proto_id, bytes) => {
250            tracing::trace!(%proto_id, bytes = bytes.len(), "received");
251            muxer
252                .received(timestamp, proto_id.opposite(), bytes.into(), eff)
253                .await
254                .with_context(|| format!("reading network message for protocol {}", proto_id))?;
255            eff.send(reader, Read).await;
256            Ok(())
257        }
258        MuxMessage::WantNext(proto_id) => {
259            muxer.want_next(proto_id, eff).await.with_context(|| format!("reading message for protocol {}", proto_id))
260        }
261        MuxMessage::Written => {
262            *sending = false;
263            if let Some((proto_id, bytes)) = muxer.next_segment(eff).await {
264                *sending = true;
265                let header = muxer.encode_header(eff, proto_id, &bytes).await;
266                eff.send(writer, header).await;
267            }
268            Ok(())
269        }
270        MuxMessage::Terminate => {
271            tracing::debug!(role=%muxer.role(), "terminating muxer due to read/write error");
272            eff.terminate::<Void>().await;
273            Ok(())
274        }
275    }
276}
277
278async fn read_segment(
279    (conn, muxer, role): (ConnectionId, StageRef<MuxMessage>, Role),
280    _token: Read,
281    eff: Effects<Read>,
282) -> (ConnectionId, StageRef<MuxMessage>, Role) {
283    let header = loop {
284        let data = Network::new(&eff)
285            .recv(conn, HEADER_LEN)
286            .await
287            .or_terminate(
288                &eff,
289                async |err| tracing::error!(%role, %err, "failed to receive segment header from network"),
290            )
291            .await;
292        let Some(header) = Header::decode(&mut data.into_inner())
293            .or_terminate(&eff, async |err| tracing::error!(%role, %err, "failed to decode segment header"))
294            .await
295        else {
296            // sending frames without payload data is not explicitly forbidden, so we just ignore them
297            tracing::info!(%role, "received empty segment header");
298            continue;
299        };
300        break header;
301    };
302
303    let data = Network::new(&eff)
304        .recv(conn, header.length.into())
305        .await
306        .or_terminate(&eff, async |err| tracing::error!(%role, %err, "failed to receive segment data from network"))
307        .await;
308
309    eff.send(&muxer, MuxMessage::FromNetwork(header.timestamp, header.proto_id, data)).await;
310    (conn, muxer, role)
311}
312
/// A header for a segment of data.
///
/// While the network spec doesn't explicitly forbid sending frames without payload data,
/// we never do that and our code will just ignore such frames.
///
/// Encoded as `HEADER_LEN` bytes: timestamp (4 bytes), protocol ID (2 bytes), payload length (2 bytes).
struct Header {
    // sender's 32-bit microsecond timestamp
    timestamp: Timestamp,
    // protocol ID as it appears on the wire, i.e. including the sender's role bit
    proto_id: ProtocolId<Erased>,
    // payload length in bytes; headers with length zero are never turned into a `Header`
    length: NonZeroU16,
}
/// Encoded header size: 4-byte timestamp + 2-byte protocol ID + 2-byte payload length.
const HEADER_LEN: NonZeroUsize = NonZeroUsize::new(4 + 2 + 2).expect("8 is a valid non-zero size");
323
324impl Header {
325    pub fn encode<R: RoleT>(proto_id: ProtocolId<R>, bytes: impl AsRef<[u8]>, timestamp: Timestamp) -> NonEmptyBytes {
326        thread_local! {
327            static BUFFER: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(HEADER_LEN.get() + MAX_SEGMENT_SIZE));
328        }
329        let bytes = bytes.as_ref();
330        BUFFER.with_borrow_mut(move |buffer| {
331            buffer.clear();
332            timestamp.encode(buffer);
333            proto_id.encode(buffer);
334            buffer.put_u16(bytes.len() as u16);
335            buffer.extend_from_slice(bytes);
336            #[expect(clippy::expect_used)]
337            buffer.copy_to_bytes(buffer.remaining()).try_into().expect("guaranteed by writing to the buffer")
338        })
339    }
340
341    pub fn decode(buffer: &mut Bytes) -> Result<Option<Self>, TryGetError> {
342        let timestamp = Timestamp::decode(buffer)?;
343        let proto_id = ProtocolId::decode(buffer)?;
344        let length = buffer.try_get_u16()?;
345        Ok(NonZeroU16::new(length).map(|length| Self { timestamp, proto_id, length }))
346    }
347}
348
/// State of each registered or buffered protocol, keyed by its local protocol ID.
#[expect(clippy::disallowed_types)]
type Protocols = HashMap<ProtocolId<Erased>, PerProto>;
351
/// Protocol multiplexer: buffers incoming data per protocol and schedules outgoing data
/// round-robin across protocols.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Muxer {
    // state of every registered or buffered protocol
    protocols: Protocols,
    // protocol IDs in first-registration order; this is the round-robin order for sending
    outgoing: Vec<ProtocolId<Erased>>,
    // index into `outgoing` where the next round-robin scan starts
    next_out: usize,
    // local role on this connection
    role: Role,
}
359
360impl Muxer {
361    pub fn new(role: Role) -> Self {
362        Self { protocols: Protocols::new(), outgoing: Vec::new(), next_out: 0, role }
363    }
364
365    pub fn role(&self) -> Role {
366        self.role
367    }
368
369    async fn encode_header<M>(
370        &mut self,
371        eff: &Effects<M>,
372        proto_id: ProtocolId<Erased>,
373        bytes: &Bytes,
374    ) -> NonEmptyBytes {
375        let instant = eff.clock().await;
376        let timestamp = Timestamp::from_instant(instant);
377        Header::encode(proto_id, bytes, timestamp)
378    }
379
380    #[trace(amaru::protocols::mux::REGISTER)]
381    pub async fn register<M>(
382        &mut self,
383        proto_id: ProtocolId<Erased>,
384        frame: Frame,
385        max_buffer: usize,
386        handler: StageRef<HandlerMessage>,
387        eff: &Effects<M>,
388    ) -> anyhow::Result<()> {
389        eff.send(&handler, HandlerMessage::Registered(proto_id)).await;
390        self.do_register(proto_id, frame, max_buffer, handler);
391        Ok(())
392    }
393
394    #[trace(amaru::protocols::mux::BUFFER)]
395    pub fn buffer(&mut self, proto_id: ProtocolId<Erased>, limit: usize) -> anyhow::Result<()> {
396        let pp = self.do_register(proto_id, Frame::Buffer, limit, StageRef::blackhole());
397        if limit == 0 {
398            tracing::trace!(buffer = pp.incoming.len(), "switching to ignoring mode");
399            pp.incoming.clear();
400        } else if pp.incoming.len() > limit {
401            tracing::warn!(buffer = pp.incoming.len(), limit, "reducing buffer killed the connection");
402            anyhow::bail!("reducing buffer ({}) leads to excess data ({})", limit, pp.incoming.len());
403        }
404        Ok(())
405    }
406
407    fn do_register(
408        &mut self,
409        proto_id: ProtocolId<Erased>,
410        frame: Frame,
411        max_buffer: usize,
412        handler: StageRef<HandlerMessage>,
413    ) -> &mut PerProto {
414        if !self.outgoing.contains(&proto_id) {
415            self.outgoing.push(proto_id);
416        }
417        match self.protocols.entry(proto_id) {
418            Entry::Occupied(pp) => {
419                let pp = pp.into_mut();
420                tracing::trace!(want = pp.wanted, "updating registration");
421                pp.frame = frame;
422                pp.max_buffer = max_buffer;
423                pp.handler = handler;
424                pp
425            }
426            Entry::Vacant(pp) => pp.insert(PerProto::new(handler, frame, max_buffer)),
427        }
428    }
429
430    #[trace(amaru::protocols::mux::OUTGOING, proto_id = proto_id, bytes = bytes.len() as u64)]
431    pub fn outgoing(&mut self, proto_id: ProtocolId<Erased>, bytes: Bytes, sent: StageRef<Sent>) {
432        tracing::trace!(%proto_id, bytes = bytes.len(), "enqueueing send");
433        #[allow(clippy::expect_used)]
434        self.protocols
435            .get_mut(&proto_id)
436            .ok_or_else(|| anyhow::anyhow!("protocol {} not registered", proto_id))
437            .expect("internal error")
438            .enqueue_send(bytes, sent);
439    }
440
441    #[trace(amaru::protocols::mux::NEXT_SEGMENT)]
442    pub async fn next_segment<M>(&mut self, eff: &Effects<M>) -> Option<(ProtocolId<Erased>, Bytes)> {
443        for idx in (self.next_out..self.outgoing.len()).chain(0..self.next_out) {
444            let proto_id = self.outgoing[idx];
445            #[allow(clippy::expect_used)]
446            let proto = self.protocols.get_mut(&proto_id).expect("invariant violation");
447            let Some(bytes) = proto.next_segment(eff).await else {
448                continue;
449            };
450            self.next_out = (idx + 1) % self.outgoing.len();
451            tracing::trace!(size = bytes.len(), %proto_id, next = self.next_out, "sending segment");
452            return Some((proto_id, bytes));
453        }
454        None
455    }
456
457    #[trace(amaru::protocols::mux::RECEIVED, bytes = bytes.len() as u64)]
458    pub async fn received<M>(
459        &mut self,
460        timestamp: Timestamp,
461        proto_id: ProtocolId<Erased>,
462        bytes: Bytes,
463        eff: &Effects<M>,
464    ) -> anyhow::Result<()> {
465        if let Some(proto) = self.protocols.get_mut(&proto_id) {
466            proto.received(timestamp, bytes, eff).await
467        } else {
468            anyhow::bail!("received data for unknown protocol {}", proto_id)
469        }
470    }
471
472    #[trace(amaru::protocols::mux::WANT_NEXT)]
473    pub async fn want_next<M>(&mut self, proto_id: ProtocolId<Erased>, eff: &Effects<M>) -> anyhow::Result<()> {
474        #[allow(clippy::expect_used)]
475        self.protocols
476            .get_mut(&proto_id)
477            .ok_or_else(|| anyhow::anyhow!("protocol {} not registered", proto_id))
478            .expect("internal error")
479            .want_next(eff)
480            .await?;
481        Ok(())
482    }
483}
484
/// Buffers and delivery bookkeeping for a single protocol.
#[derive(PartialEq, serde::Serialize, serde::Deserialize)]
struct PerProto {
    // received bytes not yet delivered to the handler
    incoming: BytesMut,
    // bytes queued for sending but not yet handed out as segments
    outgoing: BytesMut,
    // total number of bytes handed out as segments so far
    sent_bytes: usize,
    // senders to notify, each with the cumulative `sent_bytes` value at which its message is complete
    notifiers: VecDeque<(StageRef<Sent>, usize)>,
    // stage receiving `HandlerMessage::FromNetwork`
    handler: StageRef<HandlerMessage>,
    // number of messages requested via `WantNext` but not yet delivered
    wanted: usize,
    // how incoming bytes are split into messages
    frame: Frame,
    // maximum number of buffered incoming bytes; zero means incoming data are discarded
    max_buffer: usize,
}
496
497impl std::fmt::Debug for PerProto {
498    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
499        f.debug_struct("PerProto")
500            .field("incoming", &self.incoming.len())
501            .field("outgoing", &self.outgoing.len())
502            .field("sent_bytes", &self.sent_bytes)
503            .field("notifiers", &self.notifiers)
504            .field("handler", &self.handler)
505            .field("wanted", &self.wanted)
506            .field("frame", &self.frame)
507            .field("max_buffer", &self.max_buffer)
508            .finish()
509    }
510}
511
512impl PerProto {
513    pub fn new(handler: StageRef<HandlerMessage>, frame: Frame, max_buffer: usize) -> Self {
514        Self {
515            incoming: BytesMut::with_capacity(max_buffer),
516            outgoing: BytesMut::with_capacity(max_buffer),
517            sent_bytes: 0,
518            notifiers: VecDeque::new(),
519            handler,
520            wanted: 0,
521            frame,
522            max_buffer,
523        }
524    }
525
526    pub async fn received<M>(&mut self, _timestamp: Timestamp, bytes: Bytes, eff: &Effects<M>) -> anyhow::Result<()> {
527        if self.max_buffer == 0 {
528            tracing::debug!(size = bytes.len(), "ignoring bytes");
529            return Ok(());
530        }
531        tracing::trace!(wanted = self.wanted, "received bytes");
532        if self.incoming.len() + bytes.len() > self.max_buffer {
533            tracing::info!(buffered = self.incoming.len(), max_buffer = self.max_buffer, "message exceeds buffer");
534            anyhow::bail!(
535                "message (size {}) plus buffer (size {}) exceeds limit ({})",
536                bytes.len(),
537                self.incoming.len(),
538                self.max_buffer
539            );
540        }
541        self.incoming.extend(&bytes);
542        while self.wanted > 0
543            && let Some(bytes) = self.frame.try_consume(&mut self.incoming)?
544        {
545            tracing::trace!(len = bytes.len(), "extracted message");
546            eff.send(&self.handler, HandlerMessage::FromNetwork(bytes)).await;
547            self.wanted -= 1;
548        }
549        Ok(())
550    }
551
552    pub async fn want_next<M>(&mut self, eff: &Effects<M>) -> anyhow::Result<()> {
553        tracing::trace!(wanted = self.wanted, "wanting next");
554        if !self.incoming.is_empty()
555            && let Some(bytes) = self.frame.try_consume(&mut self.incoming)?
556        {
557            tracing::trace!(len = bytes.len(), "extracted message");
558            eff.send(&self.handler, HandlerMessage::FromNetwork(bytes)).await;
559        } else {
560            tracing::trace!("next delivery deferred");
561            self.wanted += 1;
562        }
563        Ok(())
564    }
565
566    pub fn enqueue_send(&mut self, bytes: Bytes, sent: StageRef<Sent>) {
567        self.outgoing.extend(&bytes);
568        self.notifiers.push_back((sent, self.sent_bytes + self.outgoing.len()));
569    }
570
571    pub async fn next_segment<M>(&mut self, eff: &Effects<M>) -> Option<Bytes> {
572        if self.outgoing.is_empty() {
573            return None;
574        }
575        let size = self.outgoing.len().min(MAX_SEGMENT_SIZE);
576        self.sent_bytes += size;
577        while let Some((_sent, size)) = self.notifiers.front() {
578            if self.sent_bytes >= *size {
579                #[expect(clippy::expect_used)]
580                let (sent, _) = self.notifiers.pop_front().expect("checked above");
581                eff.send(&sent, Sent).await;
582            } else {
583                break;
584            }
585        }
586        Some(self.outgoing.copy_to_bytes(size))
587    }
588}
589
590#[cfg(test)]
591mod tests {
592    use std::{fmt, sync::Arc, time::Duration};
593
594    use amaru_network::connection::TokioConnections;
595    use amaru_ouroboros::ConnectionsResource;
596    use amaru_ouroboros_traits::ConnectionProvider;
597    use futures_util::StreamExt;
598    use pure_stage::{
599        Effect, StageGraph,
600        simulation::{Blocked, SimulationBuilder, SimulationRunning},
601        tokio::TokioBuilder,
602        trace_buffer::TraceBuffer,
603    };
604    use tokio::{
605        io::{AsyncReadExt, AsyncWriteExt},
606        net::TcpListener,
607        runtime::Handle,
608        time::timeout,
609    };
610    use tracing_subscriber::EnvFilter;
611
612    use super::*;
613    use crate::{
614        network_effects::{RecvEffect, SendEffect},
615        protocol::{Initiator, PROTO_HANDSHAKE, PROTO_N2N_BLOCK_FETCH, PROTO_TEST, Responder},
616    };
617
618    /// Tests with real async behaviour unfortunately need real wall clock sleep time to allow
619    /// things to propagate or assert that something doesn’t get propagated. If tests below are
620    /// flaky then this value may be too small for the machine running the test.
621    const SAFE_SLEEP: Duration = Duration::from_millis(400);
622    const TIMEOUT: Duration = Duration::from_secs(1);
623
624    async fn s<F: Future>(f: F)
625    where
626        F::Output: fmt::Debug,
627    {
628        timeout(SAFE_SLEEP, f).await.unwrap_err();
629    }
630
631    async fn t<F: Future>(f: F) -> F::Output {
632        timeout(TIMEOUT, f).await.unwrap()
633    }
634
635    #[tokio::test]
636    async fn test_tcp() {
637        let _guard = pure_stage::register_data_deserializer::<MuxMessage>();
638        let _guard = pure_stage::register_data_deserializer::<NonEmptyBytes>();
639        let _guard = pure_stage::register_effect_deserializer::<SendEffect>();
640        let _guard = pure_stage::register_effect_deserializer::<RecvEffect>();
641        let _guard = pure_stage::register_data_deserializer::<State>();
642
643        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
644        let server_addr = listener.local_addr().unwrap();
645        let server_task = tokio::spawn(async move { listener.accept().await.unwrap().0 });
646
647        let network = TokioConnections::new(65536);
648        let conn_id = t(network.connect(vec![server_addr], Duration::from_secs(5))).await.unwrap();
649        let mut tcp = t(server_task).await.unwrap();
650
651        let trace_buffer = TraceBuffer::new_shared(1000, 1000000);
652        let trace_guard = TraceBuffer::drop_guard(&trace_buffer);
653        let mut graph = SimulationBuilder::default().with_trace_buffer(trace_buffer);
654
655        let mux = graph.stage("mux", super::stage);
656        let mux = graph.wire_up(mux, State::new(conn_id, &[(PROTO_TEST.erase(), 0)], Role::Initiator));
657
658        let (output, mut rx) = graph.output::<HandlerMessage>("output", 10);
659        let (sent, mut sent_rx) = graph.output::<Sent>("sent", 10);
660        let input = graph.input(&mux);
661
662        graph.resources().put::<ConnectionsResource>(Arc::new(network));
663
664        let mut running = graph.run();
665        let join_handle = tokio::spawn(async move {
666            loop {
667                let blocked = running.run_until_blocked();
668                eprintln!("{blocked:?}");
669                match blocked {
670                    Blocked::Idle => running.await_external_input().await,
671                    Blocked::Sleeping { .. } => unreachable!(),
672                    Blocked::Deadlock(send_blocks) => panic!("deadlock: {:?}", send_blocks),
673                    Blocked::Breakpoint(..) => unreachable!(),
674                    Blocked::Busy { external_effects, .. } => {
675                        assert!(external_effects > 0);
676                        running.await_external_effect().await;
677                    }
678                    Blocked::Terminated(name) => return name,
679                };
680            }
681        });
682
683        input
684            .send(MuxMessage::Send(PROTO_TEST.erase(), Bytes::copy_from_slice(&[1, 24, 33]).try_into().unwrap(), sent))
685            .await
686            .unwrap();
687        let mut buf = [0u8; 11];
688        assert_eq!(t(tcp.read_exact(&mut buf)).await.unwrap(), 11);
689        t(sent_rx.next()).await.unwrap();
690        // first four bytes are timestamp; proto ID is 257 (0x0101), length is 3
691        assert_eq!(&buf[4..], [1, 1, 0, 3, 1, 24, 33]);
692
693        input
694            .send(MuxMessage::Register {
695                protocol: PROTO_TEST.erase(),
696                frame: Frame::OneCborItem,
697                handler: output,
698                max_buffer: 100,
699            })
700            .await
701            .unwrap();
702        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::Registered(PROTO_TEST.erase()));
703
704        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();
705
706        // need to flip role bit before sending as responses
707        buf[4] |= 0x80;
708
709        t(tcp.write_all(&buf)).await.unwrap();
710        t(tcp.flush()).await.unwrap();
711        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[1]).unwrap()));
712        s(rx.next()).await;
713        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();
714        assert_eq!(
715            t(rx.next()).await.unwrap(),
716            HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[24, 33]).unwrap())
717        );
718
719        // wrong protocol ID
720        buf[5] += 1;
721        t(tcp.write_all(&buf)).await.unwrap();
722        t(tcp.flush()).await.unwrap();
723        assert_eq!(&t(join_handle).await.unwrap(), mux.name());
724
725        trace_guard.defuse();
726    }
727
728    #[test]
729    fn test_muxing() {
730        let _ = tracing_subscriber::fmt().with_env_filter(EnvFilter::from_default_env()).with_test_writer().try_init();
731
732        let _guard = pure_stage::register_data_deserializer::<MuxMessage>();
733        let _guard = pure_stage::register_data_deserializer::<NonEmptyBytes>();
734        let _guard = pure_stage::register_effect_deserializer::<SendEffect>();
735        let _guard = pure_stage::register_effect_deserializer::<RecvEffect>();
736        let _guard = pure_stage::register_data_deserializer::<State>();
737
738        let trace_buffer = TraceBuffer::new_shared(100, 1_000_000);
739        let drop_guard = TraceBuffer::drop_guard(&trace_buffer);
740        let mut network = SimulationBuilder::default().with_trace_buffer(trace_buffer);
741        let mux = network.stage("mux", super::stage);
742        let conn_id = ConnectionId::initial();
743        let mux = network.wire_up(
744            mux,
745            State::new(
746                conn_id,
747                // sequence of registration is the sequence of round-robin
748                &[(PROTO_TEST.erase(), 1024), (PROTO_N2N_BLOCK_FETCH.erase(), 0), (PROTO_HANDSHAKE.erase(), 1)],
749                Role::Initiator,
750            ),
751        );
752
753        let mut running = network.run();
754        let running = &mut running;
755
756        // set breakpoints to capture interactions with outside world
757        running.breakpoint("send", |eff| matches!(eff, Effect::External { effect, .. } if effect.is::<SendEffect>()));
758        running.breakpoint("recv", |eff| matches!(eff, Effect::External { effect, .. } if effect.is::<RecvEffect>()));
759        running.breakpoint("spawn", |eff| matches!(eff, Effect::WireStage { .. }));
760
761        // send a message to trigger creation of the writer and reader stages
762        let chain_sync = StageRef::named_for_tests("chain_sync");
763        running.enqueue_msg(
764            &mux,
765            [MuxMessage::Register {
766                protocol: PROTO_TEST.erase(),
767                frame: Frame::OneCborItem,
768                handler: chain_sync.clone(),
769                max_buffer: 1024,
770            }],
771        );
772        let spawn1 = running.run_until_blocked().assert_breakpoint("spawn");
773        let writer = spawn1.extract_wire_stage(&mux, (conn_id, (*mux).clone(), Role::Initiator)).clone();
774        running.handle_effect(spawn1);
775
776        let spawn2 = running.run_until_blocked().assert_breakpoint("spawn");
777        let reader = spawn2.extract_wire_stage(&mux, (conn_id, (*mux).clone(), Role::Initiator)).clone();
778        running.handle_effect(spawn2);
779
780        {
781            let mux_name = mux.name().clone();
782            let writer = writer.clone();
783            let reader = reader.clone();
784            running.breakpoint(
785                "mux",
786                move |eff| matches!(eff, Effect::Send { from, to, .. } if from == &mux_name && to != &writer && to != &reader),
787            );
788        }
789
790        running
791            .run_until_blocked()
792            .assert_breakpoint("recv")
793            .assert_external(&reader, &RecvEffect { conn: conn_id, bytes: HEADER_LEN });
794        let registered = running.run_until_blocked().assert_breakpoint("mux");
795        registered.assert_send(&mux, &chain_sync, HandlerMessage::Registered(PROTO_TEST.erase()));
796        running.handle_effect(registered);
797        running.enqueue_msg(&mux, [MuxMessage::WantNext(PROTO_TEST.erase())]);
798        running.run_until_blocked().assert_busy([&reader]);
799
800        // send a message towards the network
801        let send_msg = |running: &mut SimulationRunning,
802                        id: u64,
803                        msg: u8,
804                        len: usize,
805                        proto_id: ProtocolId<Initiator>| {
806            let bytes = vec![msg; len];
807            let sent = StageRef::named_for_tests(&format!("sent_{id}"));
808            running.enqueue_msg(
809                &mux,
810                [MuxMessage::Send(proto_id.erase(), Bytes::copy_from_slice(&bytes).try_into().unwrap(), sent.clone())],
811            );
812            sent
813        };
814
815        let assert_send = |running: &mut SimulationRunning, data: &[(usize, u8)], proto_id: ProtocolId<Initiator>| {
816            running.run_until_blocked().assert_breakpoint("send").extract_external::<SendEffect>(&writer).assert_frame(
817                conn_id,
818                proto_id.erase(),
819                data,
820            );
821        };
822        let resume_send = |running: &mut SimulationRunning| {
823            running.resume_external::<SendEffect>(&writer, Ok(())).unwrap();
824        };
825        let assert_and_resume_send =
826            |running: &mut SimulationRunning, data: &[(usize, u8)], proto_id: ProtocolId<Initiator>| {
827                assert_send(running, data, proto_id);
828                resume_send(running);
829            };
830        let assert_respond = |running: &mut SimulationRunning, sent: &StageRef<Sent>| {
831            let mux_sent = running.run_until_blocked().assert_breakpoint("mux");
832            mux_sent.assert_send(&mux, sent, Sent);
833            running.handle_effect(mux_sent);
834        };
835
836        // start write but don't let the writer finish yet
837        let cr1 = send_msg(running, 101, 1, 1024, PROTO_TEST);
838        assert_respond(running, &cr1);
839        assert_send(running, &[(1024, 1)], PROTO_TEST);
840
841        // put 1024 bytes into the proto buffer
842        let cr2 = send_msg(running, 102, 2, 1024, PROTO_TEST);
843        // put 10 bytes into the proto buffer
844        let cr3 = send_msg(running, 103, 3, 10, PROTO_TEST);
845        // the above are for checking correct responses via the CallRefs
846
847        // fill segments for other two protocols
848        let cr4 = send_msg(running, 104, 4, 66000, PROTO_HANDSHAKE);
849        let cr5 = send_msg(running, 105, 5, 66000, PROTO_N2N_BLOCK_FETCH);
850
851        resume_send(running);
852        assert_and_resume_send(running, &[(65535, 5)], PROTO_N2N_BLOCK_FETCH);
853        assert_and_resume_send(running, &[(65535, 4)], PROTO_HANDSHAKE);
854        assert_respond(running, &cr2);
855        assert_respond(running, &cr3);
856        assert_and_resume_send(running, &[(1024, 2), (10, 3)], PROTO_TEST);
857        assert_respond(running, &cr5);
858        assert_and_resume_send(running, &[(465, 5)], PROTO_N2N_BLOCK_FETCH);
859        assert_respond(running, &cr4);
860        assert_and_resume_send(running, &[(465, 4)], PROTO_HANDSHAKE);
861
862        let recv_header = RecvEffect { conn: conn_id, bytes: HEADER_LEN };
863        let recv_msg =
864            |running: &mut SimulationRunning, proto_id: ProtocolId<Responder>, bytes: &[u8], recv: &[&[u8]]| {
865                let mut msg = Header::encode(proto_id, bytes, Timestamp::now()).into_inner();
866                running
867                    .resume_external::<RecvEffect>(&reader, Ok(msg.split_to(HEADER_LEN.get()).try_into().unwrap()))
868                    .unwrap();
869                let msg = NonEmptyBytes::new(msg).unwrap();
870                running
871                    .run_until_blocked()
872                    .assert_breakpoint("recv")
873                    .assert_external(&reader, &RecvEffect { conn: conn_id, bytes: msg.len() });
874                running.resume_external::<RecvEffect>(&reader, Ok(msg)).unwrap();
875                for recv in recv {
876                    if recv.is_empty() {
877                        running.run_until_blocked().assert_breakpoint("recv").assert_external(&reader, &recv_header);
878                        continue;
879                    }
880                    running.run_until_blocked().assert_breakpoint("mux").assert_send(
881                        &mux,
882                        &chain_sync,
883                        HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(recv).unwrap()),
884                    );
885                    running.resume_send(&mux, &chain_sync, None).unwrap();
886                    running.enqueue_msg(&mux, [MuxMessage::WantNext(proto_id.initiator().erase())]);
887                }
888                // running.run_until_blocked().assert_busy([&reader]);
889            };
890
891        // send CBOR 1 followed by incomplete CBOR; "recv" effect always happens second
892        recv_msg(running, PROTO_TEST.responder(), &[1, 24], &[&[1], &[]]);
893        // send CBOR 25 continuation followed by CBOR 3
894        recv_msg(running, PROTO_TEST.responder(), &[25, 3], &[&[24, 25], &[], &[3]]);
895
896        // test buffer size violation
897        recv_msg(running, PROTO_HANDSHAKE.responder(), &[1, 2, 3], &[]);
898        running.run_until_blocked().assert_terminated(mux.name());
899
900        drop_guard.defuse();
901    }
902
903    trait AssertBytes {
904        fn assert_frame(&self, conn: ConnectionId, proto_id: ProtocolId<Erased>, data: &[(usize, u8)]);
905    }
906    impl AssertBytes for SendEffect {
907        fn assert_frame(&self, conn: ConnectionId, proto_id: ProtocolId<Erased>, data: &[(usize, u8)]) {
908            assert_eq!(self.conn, conn);
909            let mut header = self.data.slice(..HEADER_LEN.get());
910            let header = Header::decode(&mut header).unwrap().unwrap();
911            assert_eq!(header.proto_id, proto_id);
912            assert_eq!(header.length.get() as usize, data.iter().map(|(len, _)| len).sum::<usize>());
913            let mut bytes = self.data.slice(HEADER_LEN.get()..);
914            for &(len, msg) in data {
915                assert_eq!(&bytes.split_to(len), &vec![msg; len]);
916            }
917        }
918    }
919
920    #[tokio::test]
921    async fn test_tokio() {
922        let _guard = pure_stage::register_data_deserializer::<MuxMessage>();
923        let _guard = pure_stage::register_data_deserializer::<NonEmptyBytes>();
924        let _guard = pure_stage::register_effect_deserializer::<SendEffect>();
925        let _guard = pure_stage::register_effect_deserializer::<RecvEffect>();
926        let _guard = pure_stage::register_data_deserializer::<State>();
927
928        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
929        let server_addr = listener.local_addr().unwrap();
930        let server_task = tokio::spawn(async move { listener.accept().await.unwrap().0 });
931
932        let network = TokioConnections::new(65536);
933        let conn_id = t(network.connect(vec![server_addr], Duration::from_secs(5))).await.unwrap();
934        let mut tcp = t(server_task).await.unwrap();
935
936        let trace_buffer = TraceBuffer::new_shared(1000, 1000000);
937        let trace_guard = TraceBuffer::drop_guard(&trace_buffer);
938        let mut graph = TokioBuilder::default().with_trace_buffer(trace_buffer);
939
940        let mux = graph.stage("mux", super::stage);
941        let mux = graph.wire_up(mux, State::new(conn_id, &[(PROTO_TEST.erase(), 0)], Role::Initiator));
942
943        let (output, mut rx) = graph.output::<HandlerMessage>("output", 10);
944        let (sent, mut sent_rx) = graph.output::<Sent>("sent", 10);
945        let input = graph.input(&mux);
946
947        graph.resources().put::<ConnectionsResource>(Arc::new(network));
948
949        let running = graph.run(Handle::current());
950
951        input
952            .send(MuxMessage::Send(PROTO_TEST.erase(), Bytes::copy_from_slice(&[1, 24, 33]).try_into().unwrap(), sent))
953            .await
954            .unwrap();
955        let mut buf = [0u8; 11];
956        assert_eq!(t(tcp.read_exact(&mut buf)).await.unwrap(), 11);
957        t(sent_rx.next()).await.unwrap();
958        // first four bytes are timestamp; proto ID is 257 (0x0101), length is 3
959        assert_eq!(&buf[4..], [1, 1, 0, 3, 1, 24, 33]);
960
961        input
962            .send(MuxMessage::Register {
963                protocol: PROTO_TEST.erase(),
964                frame: Frame::OneCborItem,
965                handler: output,
966                max_buffer: 100,
967            })
968            .await
969            .unwrap();
970        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::Registered(PROTO_TEST.erase()));
971
972        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();
973
974        // need to flip role bit before sending as responses
975        buf[4] |= 0x80;
976
977        t(tcp.write_all(&buf)).await.unwrap();
978        t(tcp.flush()).await.unwrap();
979        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[1]).unwrap()));
980        s(rx.next()).await;
981        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();
982        assert_eq!(
983            t(rx.next()).await.unwrap(),
984            HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[24, 33]).unwrap())
985        );
986
987        // wrong protocol ID
988        buf[5] += 1;
989        t(tcp.write_all(&buf)).await.unwrap();
990        t(tcp.flush()).await.unwrap();
991        t(running.join()).await;
992
993        trace_guard.defuse();
994    }
995}