1#[expect(clippy::disallowed_types)]
16use std::collections::HashMap;
17use std::{
18 cell::RefCell,
19 collections::{VecDeque, hash_map::Entry},
20 num::{NonZeroU16, NonZeroUsize},
21 time::SystemTime,
22};
23
24use amaru_kernel::NonEmptyBytes;
25use amaru_observability::trace;
26use amaru_ouroboros::ConnectionId;
27use anyhow::Context;
28use bytes::{Buf, BufMut, Bytes, BytesMut, TryGetError};
29use cbor_data::{Cbor, ErrorKind, ParseError};
30use pure_stage::{EPOCH, Effects, Instant, StageRef, TryInStage, Void};
31
32use crate::{
33 network_effects::{Network, NetworkOps},
34 protocol::{Erased, ProtocolId, Role, RoleT},
35};
36
37pub fn register_deserializers() -> pure_stage::DeserializerGuards {
38 vec![
39 pure_stage::register_data_deserializer::<MuxMessage>().boxed(),
40 pure_stage::register_data_deserializer::<NonEmptyBytes>().boxed(),
41 pure_stage::register_data_deserializer::<State>().boxed(),
42 pure_stage::register_data_deserializer::<HandlerMessage>().boxed(),
43 pure_stage::register_data_deserializer::<Sent>().boxed(),
44 pure_stage::register_data_deserializer::<Read>().boxed(),
45 ]
46}
47
/// Largest payload (in bytes) put into a single outgoing segment; equals `u16::MAX`, the
/// largest value the 16-bit length field of the segment header can carry.
const MAX_SEGMENT_SIZE: usize = 65535;
49
50#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
52pub struct Timestamp(u32);
53
54impl Timestamp {
55 pub fn now() -> Self {
56 #[expect(clippy::expect_used)]
57 Self(
58 SystemTime::now()
59 .duration_since(SystemTime::UNIX_EPOCH)
60 .expect("system time is not supposed to be before the UNIX epoch")
61 .as_micros() as u32,
62 )
63 }
64
65 fn encode(self, buffer: &mut BytesMut) {
66 buffer.put_u32(self.0);
67 }
68
69 pub fn from_instant(instant: Instant) -> Self {
70 Self(instant.saturating_since(*EPOCH).as_micros() as u32)
71 }
72
73 fn decode(buffer: &mut Bytes) -> Result<Self, TryGetError> {
74 Ok(Self(buffer.try_get_u32()?))
75 }
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
79pub enum Frame {
80 OneCborItem,
82 Buffer,
84}
85
86impl Frame {
87 pub fn try_consume(&self, data: &mut BytesMut) -> Result<Option<NonEmptyBytes>, ParseError> {
88 match self {
89 Frame::OneCborItem => match Cbor::checked_prefix(data) {
90 Ok((item, _rest)) => {
91 let item = data.copy_to_bytes(item.as_slice().len());
92 #[expect(clippy::expect_used)]
93 Ok(Some(item.try_into().expect("guaranteed by CBOR standard")))
94 }
95 Err(e) if matches!(e.kind(), ErrorKind::UnexpectedEof(_)) => Ok(None),
96 Err(e) => Err(e),
97 },
98 Frame::Buffer => Ok(None),
99 }
100 }
101}
102
/// Messages the muxer sends to the handler stage registered for a protocol.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum HandlerMessage {
    /// Acknowledges a `MuxMessage::Register` for the given protocol.
    Registered(ProtocolId<Erased>),
    /// One message, framed according to the protocol's [`Frame`], delivered in response to a
    /// `MuxMessage::WantNext`.
    FromNetwork(NonEmptyBytes),
}
108
/// Notification for the sender of a `MuxMessage::Send`: all of its bytes have been taken into
/// outgoing segments handed to the writer stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Sent;
111
/// Token sent to the reader stage; each token makes it read exactly one segment from the network.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Read;
114
/// Input messages of the muxer stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum MuxMessage {
    /// Registers (or re-registers) a protocol: `handler` is acknowledged with
    /// `HandlerMessage::Registered`, incoming bytes are cut into messages with `frame`, and at
    /// most `max_buffer` incoming bytes are buffered (0 means incoming data is ignored).
    Register { protocol: ProtocolId<Erased>, frame: Frame, handler: StageRef<HandlerMessage>, max_buffer: usize },
    /// Puts a protocol into buffering mode with the given byte limit (no handler; `Frame::Buffer`).
    /// A limit of 0 discards buffered and future incoming data; a limit below the amount already
    /// buffered is an error that terminates the muxer.
    Buffer(ProtocolId<Erased>, usize),
    /// Queues bytes for sending on a protocol; the given stage receives [`Sent`] once all of them
    /// have been taken into outgoing segments. The protocol must already be registered or buffered.
    Send(ProtocolId<Erased>, NonEmptyBytes, StageRef<Sent>),
    /// A segment payload read by the reader stage, with the protocol id as found on the wire.
    FromNetwork(Timestamp, ProtocolId<Erased>, NonEmptyBytes),
    /// The writer stage finished writing the previous segment.
    Written,
    /// The handler of this protocol is ready to receive its next message.
    WantNext(ProtocolId<Erased>),
    /// Supervision message used for the reader and writer stages; terminates the muxer.
    Terminate,
}
139
/// Persistent state of the muxer stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct State {
    /// Connection, before or after the reader/writer stages have been spawned.
    conn: Connection,
    /// Per-protocol buffers and the round-robin send schedule.
    muxer: Muxer,
    /// `true` while a segment has been handed to the writer and `MuxMessage::Written` is pending.
    sending: bool,
}
146
/// Lifecycle of the muxer's connection handling (see `State::init`).
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
enum Connection {
    /// Reader and writer stages not spawned yet.
    Unint(ConnectionId),
    /// Writer and reader stages spawned and wired up, in that order.
    Init(StageRef<NonEmptyBytes>, StageRef<Read>),
}
152
153impl State {
154 pub fn new(conn: ConnectionId, buffer: &[(ProtocolId<Erased>, usize)], role: Role) -> Self {
159 let mut muxer = Muxer::new(role);
160 for &(proto_id, limit) in buffer {
161 #[expect(clippy::expect_used)]
162 muxer.buffer(proto_id, limit).expect("no buffered data yet");
163 }
164 Self { conn: Connection::Unint(conn), muxer, sending: false }
165 }
166
167 pub async fn init(
168 &mut self,
169 eff: &mut Effects<MuxMessage>,
170 ) -> (&mut Muxer, &mut bool, &StageRef<NonEmptyBytes>, &StageRef<Read>) {
171 match &mut self.conn {
172 Connection::Unint(conn) => {
173 let writer = eff
174 .stage(
175 format!("writer-{}", conn),
176 move |(conn, muxer, role), data: NonEmptyBytes, eff| async move {
177 Network::new(&eff)
178 .send(conn, data)
179 .await
180 .or_terminate(
181 &eff,
182 async |err| tracing::error!(%err, %role, "failed to send data to network"),
183 )
184 .await;
185 eff.send(&muxer, MuxMessage::Written).await;
186 (conn, muxer, role)
187 },
188 )
189 .await;
190 let writer = eff.supervise(writer, MuxMessage::Terminate);
191 let writer = eff.wire_up(writer, (*conn, eff.me(), self.muxer.role())).await;
192 let reader = eff.stage(format!("reader-{}", conn), read_segment).await;
193 let reader = eff.supervise(reader, MuxMessage::Terminate);
194 let reader = eff.wire_up(reader, (*conn, eff.me(), self.muxer.role())).await;
195 eff.send(&reader, Read).await;
196 self.conn = Connection::Init(writer, reader);
197 }
198 Connection::Init(..) => {}
199 }
200 let Connection::Init(writer, reader) = &self.conn else { unreachable!() };
201 (&mut self.muxer, &mut self.sending, writer, reader)
202 }
203}
204
205pub async fn stage(mut state: State, msg: MuxMessage, mut eff: Effects<MuxMessage>) -> State {
206 let (muxer, sending, writer, reader) = state.init(&mut eff).await;
207
208 handle_msg(msg, &eff, muxer, sending, writer, reader)
209 .await
210 .or_terminate(&eff, async |error| {
211 use std::fmt::Write;
212 let mut err = String::new();
213 for error in error.chain() {
214 if !err.is_empty() {
215 err.push_str(" <- ");
216 }
217 write!(&mut err, "{}", error).ok();
218 }
219 tracing::error!(%err, role=%muxer.role(), "muxing error")
220 })
221 .await;
222
223 state
224}
225
226async fn handle_msg(
227 msg: MuxMessage,
228 eff: &Effects<MuxMessage>,
229 muxer: &mut Muxer,
230 sending: &mut bool,
231 writer: &StageRef<NonEmptyBytes>,
232 reader: &StageRef<Read>,
233) -> anyhow::Result<()> {
234 match msg {
235 MuxMessage::Register { protocol, frame, handler, max_buffer } => {
236 muxer.register(protocol, frame, max_buffer, handler, eff).await
237 }
238 MuxMessage::Buffer(proto_id, limit) => muxer.buffer(proto_id, limit),
239 MuxMessage::Send(proto_id, bytes, sent) => {
240 tracing::trace!(%proto_id, bytes = bytes.len(), "send");
241 muxer.outgoing(proto_id, bytes.into(), sent);
242 if !*sending && let Some((proto_id, bytes)) = muxer.next_segment(eff).await {
243 *sending = true;
244 let header = muxer.encode_header(eff, proto_id, &bytes).await;
245 eff.send(writer, header).await;
246 }
247 Ok(())
248 }
249 MuxMessage::FromNetwork(timestamp, proto_id, bytes) => {
250 tracing::trace!(%proto_id, bytes = bytes.len(), "received");
251 muxer
252 .received(timestamp, proto_id.opposite(), bytes.into(), eff)
253 .await
254 .with_context(|| format!("reading network message for protocol {}", proto_id))?;
255 eff.send(reader, Read).await;
256 Ok(())
257 }
258 MuxMessage::WantNext(proto_id) => {
259 muxer.want_next(proto_id, eff).await.with_context(|| format!("reading message for protocol {}", proto_id))
260 }
261 MuxMessage::Written => {
262 *sending = false;
263 if let Some((proto_id, bytes)) = muxer.next_segment(eff).await {
264 *sending = true;
265 let header = muxer.encode_header(eff, proto_id, &bytes).await;
266 eff.send(writer, header).await;
267 }
268 Ok(())
269 }
270 MuxMessage::Terminate => {
271 tracing::debug!(role=%muxer.role(), "terminating muxer due to read/write error");
272 eff.terminate::<Void>().await;
273 Ok(())
274 }
275 }
276}
277
278async fn read_segment(
279 (conn, muxer, role): (ConnectionId, StageRef<MuxMessage>, Role),
280 _token: Read,
281 eff: Effects<Read>,
282) -> (ConnectionId, StageRef<MuxMessage>, Role) {
283 let header = loop {
284 let data = Network::new(&eff)
285 .recv(conn, HEADER_LEN)
286 .await
287 .or_terminate(
288 &eff,
289 async |err| tracing::error!(%role, %err, "failed to receive segment header from network"),
290 )
291 .await;
292 let Some(header) = Header::decode(&mut data.into_inner())
293 .or_terminate(&eff, async |err| tracing::error!(%role, %err, "failed to decode segment header"))
294 .await
295 else {
296 tracing::info!(%role, "received empty segment header");
298 continue;
299 };
300 break header;
301 };
302
303 let data = Network::new(&eff)
304 .recv(conn, header.length.into())
305 .await
306 .or_terminate(&eff, async |err| tracing::error!(%role, %err, "failed to receive segment data from network"))
307 .await;
308
309 eff.send(&muxer, MuxMessage::FromNetwork(header.timestamp, header.proto_id, data)).await;
310 (conn, muxer, role)
311}
312
313struct Header {
318 timestamp: Timestamp,
319 proto_id: ProtocolId<Erased>,
320 length: NonZeroU16,
321}
322const HEADER_LEN: NonZeroUsize = NonZeroUsize::new(8).expect("8 is a valid non-zero size");
323
324impl Header {
325 pub fn encode<R: RoleT>(proto_id: ProtocolId<R>, bytes: impl AsRef<[u8]>, timestamp: Timestamp) -> NonEmptyBytes {
326 thread_local! {
327 static BUFFER: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(HEADER_LEN.get() + MAX_SEGMENT_SIZE));
328 }
329 let bytes = bytes.as_ref();
330 BUFFER.with_borrow_mut(move |buffer| {
331 buffer.clear();
332 timestamp.encode(buffer);
333 proto_id.encode(buffer);
334 buffer.put_u16(bytes.len() as u16);
335 buffer.extend_from_slice(bytes);
336 #[expect(clippy::expect_used)]
337 buffer.copy_to_bytes(buffer.remaining()).try_into().expect("guaranteed by writing to the buffer")
338 })
339 }
340
341 pub fn decode(buffer: &mut Bytes) -> Result<Option<Self>, TryGetError> {
342 let timestamp = Timestamp::decode(buffer)?;
343 let proto_id = ProtocolId::decode(buffer)?;
344 let length = buffer.try_get_u16()?;
345 Ok(NonZeroU16::new(length).map(|length| Self { timestamp, proto_id, length }))
346 }
347}
348
/// Per-protocol state keyed by (locally oriented) protocol id.
#[expect(clippy::disallowed_types)]
type Protocols = HashMap<ProtocolId<Erased>, PerProto>;
351
/// Multiplexing state: per-protocol buffers plus a round-robin schedule for outgoing segments.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Muxer {
    /// State of every registered or buffered protocol.
    protocols: Protocols,
    /// Protocols in first-registration order; the round-robin send order.
    outgoing: Vec<ProtocolId<Erased>>,
    /// Index into `outgoing` at which the next round-robin scan starts.
    next_out: usize,
    /// Role this muxer was created with.
    role: Role,
}
359
360impl Muxer {
361 pub fn new(role: Role) -> Self {
362 Self { protocols: Protocols::new(), outgoing: Vec::new(), next_out: 0, role }
363 }
364
365 pub fn role(&self) -> Role {
366 self.role
367 }
368
369 async fn encode_header<M>(
370 &mut self,
371 eff: &Effects<M>,
372 proto_id: ProtocolId<Erased>,
373 bytes: &Bytes,
374 ) -> NonEmptyBytes {
375 let instant = eff.clock().await;
376 let timestamp = Timestamp::from_instant(instant);
377 Header::encode(proto_id, bytes, timestamp)
378 }
379
380 #[trace(amaru::protocols::mux::REGISTER)]
381 pub async fn register<M>(
382 &mut self,
383 proto_id: ProtocolId<Erased>,
384 frame: Frame,
385 max_buffer: usize,
386 handler: StageRef<HandlerMessage>,
387 eff: &Effects<M>,
388 ) -> anyhow::Result<()> {
389 eff.send(&handler, HandlerMessage::Registered(proto_id)).await;
390 self.do_register(proto_id, frame, max_buffer, handler);
391 Ok(())
392 }
393
394 #[trace(amaru::protocols::mux::BUFFER)]
395 pub fn buffer(&mut self, proto_id: ProtocolId<Erased>, limit: usize) -> anyhow::Result<()> {
396 let pp = self.do_register(proto_id, Frame::Buffer, limit, StageRef::blackhole());
397 if limit == 0 {
398 tracing::trace!(buffer = pp.incoming.len(), "switching to ignoring mode");
399 pp.incoming.clear();
400 } else if pp.incoming.len() > limit {
401 tracing::warn!(buffer = pp.incoming.len(), limit, "reducing buffer killed the connection");
402 anyhow::bail!("reducing buffer ({}) leads to excess data ({})", limit, pp.incoming.len());
403 }
404 Ok(())
405 }
406
407 fn do_register(
408 &mut self,
409 proto_id: ProtocolId<Erased>,
410 frame: Frame,
411 max_buffer: usize,
412 handler: StageRef<HandlerMessage>,
413 ) -> &mut PerProto {
414 if !self.outgoing.contains(&proto_id) {
415 self.outgoing.push(proto_id);
416 }
417 match self.protocols.entry(proto_id) {
418 Entry::Occupied(pp) => {
419 let pp = pp.into_mut();
420 tracing::trace!(want = pp.wanted, "updating registration");
421 pp.frame = frame;
422 pp.max_buffer = max_buffer;
423 pp.handler = handler;
424 pp
425 }
426 Entry::Vacant(pp) => pp.insert(PerProto::new(handler, frame, max_buffer)),
427 }
428 }
429
430 #[trace(amaru::protocols::mux::OUTGOING, proto_id = proto_id, bytes = bytes.len() as u64)]
431 pub fn outgoing(&mut self, proto_id: ProtocolId<Erased>, bytes: Bytes, sent: StageRef<Sent>) {
432 tracing::trace!(%proto_id, bytes = bytes.len(), "enqueueing send");
433 #[allow(clippy::expect_used)]
434 self.protocols
435 .get_mut(&proto_id)
436 .ok_or_else(|| anyhow::anyhow!("protocol {} not registered", proto_id))
437 .expect("internal error")
438 .enqueue_send(bytes, sent);
439 }
440
441 #[trace(amaru::protocols::mux::NEXT_SEGMENT)]
442 pub async fn next_segment<M>(&mut self, eff: &Effects<M>) -> Option<(ProtocolId<Erased>, Bytes)> {
443 for idx in (self.next_out..self.outgoing.len()).chain(0..self.next_out) {
444 let proto_id = self.outgoing[idx];
445 #[allow(clippy::expect_used)]
446 let proto = self.protocols.get_mut(&proto_id).expect("invariant violation");
447 let Some(bytes) = proto.next_segment(eff).await else {
448 continue;
449 };
450 self.next_out = (idx + 1) % self.outgoing.len();
451 tracing::trace!(size = bytes.len(), %proto_id, next = self.next_out, "sending segment");
452 return Some((proto_id, bytes));
453 }
454 None
455 }
456
457 #[trace(amaru::protocols::mux::RECEIVED, bytes = bytes.len() as u64)]
458 pub async fn received<M>(
459 &mut self,
460 timestamp: Timestamp,
461 proto_id: ProtocolId<Erased>,
462 bytes: Bytes,
463 eff: &Effects<M>,
464 ) -> anyhow::Result<()> {
465 if let Some(proto) = self.protocols.get_mut(&proto_id) {
466 proto.received(timestamp, bytes, eff).await
467 } else {
468 anyhow::bail!("received data for unknown protocol {}", proto_id)
469 }
470 }
471
472 #[trace(amaru::protocols::mux::WANT_NEXT)]
473 pub async fn want_next<M>(&mut self, proto_id: ProtocolId<Erased>, eff: &Effects<M>) -> anyhow::Result<()> {
474 #[allow(clippy::expect_used)]
475 self.protocols
476 .get_mut(&proto_id)
477 .ok_or_else(|| anyhow::anyhow!("protocol {} not registered", proto_id))
478 .expect("internal error")
479 .want_next(eff)
480 .await?;
481 Ok(())
482 }
483}
484
/// Buffers and delivery bookkeeping for one protocol.
#[derive(PartialEq, serde::Serialize, serde::Deserialize)]
struct PerProto {
    /// Received bytes not yet delivered to the handler.
    incoming: BytesMut,
    /// Queued bytes not yet taken into an outgoing segment.
    outgoing: BytesMut,
    /// Total number of bytes taken into outgoing segments so far.
    sent_bytes: usize,
    /// Senders awaiting [`Sent`], each with the cumulative byte count at which it is notified.
    notifiers: VecDeque<(StageRef<Sent>, usize)>,
    /// Stage receiving framed incoming messages.
    handler: StageRef<HandlerMessage>,
    /// Number of requested (`want_next`) messages not yet delivered.
    wanted: usize,
    /// How incoming bytes are cut into messages.
    frame: Frame,
    /// Maximum number of incoming bytes to buffer; 0 means incoming data is ignored.
    max_buffer: usize,
}
496
497impl std::fmt::Debug for PerProto {
498 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
499 f.debug_struct("PerProto")
500 .field("incoming", &self.incoming.len())
501 .field("outgoing", &self.outgoing.len())
502 .field("sent_bytes", &self.sent_bytes)
503 .field("notifiers", &self.notifiers)
504 .field("handler", &self.handler)
505 .field("wanted", &self.wanted)
506 .field("frame", &self.frame)
507 .field("max_buffer", &self.max_buffer)
508 .finish()
509 }
510}
511
512impl PerProto {
513 pub fn new(handler: StageRef<HandlerMessage>, frame: Frame, max_buffer: usize) -> Self {
514 Self {
515 incoming: BytesMut::with_capacity(max_buffer),
516 outgoing: BytesMut::with_capacity(max_buffer),
517 sent_bytes: 0,
518 notifiers: VecDeque::new(),
519 handler,
520 wanted: 0,
521 frame,
522 max_buffer,
523 }
524 }
525
526 pub async fn received<M>(&mut self, _timestamp: Timestamp, bytes: Bytes, eff: &Effects<M>) -> anyhow::Result<()> {
527 if self.max_buffer == 0 {
528 tracing::debug!(size = bytes.len(), "ignoring bytes");
529 return Ok(());
530 }
531 tracing::trace!(wanted = self.wanted, "received bytes");
532 if self.incoming.len() + bytes.len() > self.max_buffer {
533 tracing::info!(buffered = self.incoming.len(), max_buffer = self.max_buffer, "message exceeds buffer");
534 anyhow::bail!(
535 "message (size {}) plus buffer (size {}) exceeds limit ({})",
536 bytes.len(),
537 self.incoming.len(),
538 self.max_buffer
539 );
540 }
541 self.incoming.extend(&bytes);
542 while self.wanted > 0
543 && let Some(bytes) = self.frame.try_consume(&mut self.incoming)?
544 {
545 tracing::trace!(len = bytes.len(), "extracted message");
546 eff.send(&self.handler, HandlerMessage::FromNetwork(bytes)).await;
547 self.wanted -= 1;
548 }
549 Ok(())
550 }
551
552 pub async fn want_next<M>(&mut self, eff: &Effects<M>) -> anyhow::Result<()> {
553 tracing::trace!(wanted = self.wanted, "wanting next");
554 if !self.incoming.is_empty()
555 && let Some(bytes) = self.frame.try_consume(&mut self.incoming)?
556 {
557 tracing::trace!(len = bytes.len(), "extracted message");
558 eff.send(&self.handler, HandlerMessage::FromNetwork(bytes)).await;
559 } else {
560 tracing::trace!("next delivery deferred");
561 self.wanted += 1;
562 }
563 Ok(())
564 }
565
566 pub fn enqueue_send(&mut self, bytes: Bytes, sent: StageRef<Sent>) {
567 self.outgoing.extend(&bytes);
568 self.notifiers.push_back((sent, self.sent_bytes + self.outgoing.len()));
569 }
570
571 pub async fn next_segment<M>(&mut self, eff: &Effects<M>) -> Option<Bytes> {
572 if self.outgoing.is_empty() {
573 return None;
574 }
575 let size = self.outgoing.len().min(MAX_SEGMENT_SIZE);
576 self.sent_bytes += size;
577 while let Some((_sent, size)) = self.notifiers.front() {
578 if self.sent_bytes >= *size {
579 #[expect(clippy::expect_used)]
580 let (sent, _) = self.notifiers.pop_front().expect("checked above");
581 eff.send(&sent, Sent).await;
582 } else {
583 break;
584 }
585 }
586 Some(self.outgoing.copy_to_bytes(size))
587 }
588}
589
#[cfg(test)]
mod tests {
    use std::{fmt, sync::Arc, time::Duration};

    use amaru_network::connection::TokioConnections;
    use amaru_ouroboros::ConnectionsResource;
    use amaru_ouroboros_traits::ConnectionProvider;
    use futures_util::StreamExt;
    use pure_stage::{
        Effect, StageGraph,
        simulation::{Blocked, SimulationBuilder, SimulationRunning},
        tokio::TokioBuilder,
        trace_buffer::TraceBuffer,
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        runtime::Handle,
        time::timeout,
    };
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::{
        network_effects::{RecvEffect, SendEffect},
        protocol::{Initiator, PROTO_HANDSHAKE, PROTO_N2N_BLOCK_FETCH, PROTO_TEST, Responder},
    };

    /// How long `s` waits to conclude that a future does not complete.
    const SAFE_SLEEP: Duration = Duration::from_millis(400);
    /// Upper bound for futures that `t` expects to complete.
    const TIMEOUT: Duration = Duration::from_secs(1);

    /// Asserts that `f` does NOT complete within [`SAFE_SLEEP`].
    async fn s<F: Future>(f: F)
    where
        F::Output: fmt::Debug,
    {
        timeout(SAFE_SLEEP, f).await.unwrap_err();
    }

    /// Awaits `f`, panicking if it takes longer than [`TIMEOUT`].
    async fn t<F: Future>(f: F) -> F::Output {
        timeout(TIMEOUT, f).await.unwrap()
    }

    /// Drives the muxer in the simulation against a real TCP connection.
    #[tokio::test]
    async fn test_tcp() {
        let _guard = pure_stage::register_data_deserializer::<MuxMessage>();
        let _guard = pure_stage::register_data_deserializer::<NonEmptyBytes>();
        let _guard = pure_stage::register_effect_deserializer::<SendEffect>();
        let _guard = pure_stage::register_effect_deserializer::<RecvEffect>();
        let _guard = pure_stage::register_data_deserializer::<State>();

        // `tcp` is the raw server end; the muxer uses `conn_id`, the client end.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move { listener.accept().await.unwrap().0 });

        let network = TokioConnections::new(65536);
        let conn_id = t(network.connect(vec![server_addr], Duration::from_secs(5))).await.unwrap();
        let mut tcp = t(server_task).await.unwrap();

        let trace_buffer = TraceBuffer::new_shared(1000, 1000000);
        let trace_guard = TraceBuffer::drop_guard(&trace_buffer);
        let mut graph = SimulationBuilder::default().with_trace_buffer(trace_buffer);

        // PROTO_TEST starts in buffering mode with limit 0, i.e. incoming data is ignored.
        let mux = graph.stage("mux", super::stage);
        let mux = graph.wire_up(mux, State::new(conn_id, &[(PROTO_TEST.erase(), 0)], Role::Initiator));

        let (output, mut rx) = graph.output::<HandlerMessage>("output", 10);
        let (sent, mut sent_rx) = graph.output::<Sent>("sent", 10);
        let input = graph.input(&mux);

        graph.resources().put::<ConnectionsResource>(Arc::new(network));

        // Run the simulation in the background, servicing external (network) effects, until a
        // stage terminates; the task yields the terminated stage's name.
        let mut running = graph.run();
        let join_handle = tokio::spawn(async move {
            loop {
                let blocked = running.run_until_blocked();
                eprintln!("{blocked:?}");
                match blocked {
                    Blocked::Idle => running.await_external_input().await,
                    Blocked::Sleeping { .. } => unreachable!(),
                    Blocked::Deadlock(send_blocks) => panic!("deadlock: {:?}", send_blocks),
                    Blocked::Breakpoint(..) => unreachable!(),
                    Blocked::Busy { external_effects, .. } => {
                        assert!(external_effects > 0);
                        running.await_external_effect().await;
                    }
                    Blocked::Terminated(name) => return name,
                };
            }
        });

        // Sending works even though PROTO_TEST ignores incoming data.
        input
            .send(MuxMessage::Send(PROTO_TEST.erase(), Bytes::copy_from_slice(&[1, 24, 33]).try_into().unwrap(), sent))
            .await
            .unwrap();
        // 8-byte header (4 timestamp, 2 protocol id, 2 length) plus 3 payload bytes.
        let mut buf = [0u8; 11];
        assert_eq!(t(tcp.read_exact(&mut buf)).await.unwrap(), 11);
        t(sent_rx.next()).await.unwrap();
        // Skip the timestamp (bytes 0..4) and check protocol id, length and payload.
        assert_eq!(&buf[4..], [1, 1, 0, 3, 1, 24, 33]);

        // Switch PROTO_TEST to delivering one CBOR item per `WantNext`.
        input
            .send(MuxMessage::Register {
                protocol: PROTO_TEST.erase(),
                frame: Frame::OneCborItem,
                handler: output,
                max_buffer: 100,
            })
            .await
            .unwrap();
        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::Registered(PROTO_TEST.erase()));

        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();

        // Echo the segment back with the high bit of the protocol id set; presumably this marks
        // the peer's direction, which the muxer maps back with `opposite()` — confirm in protocol.
        buf[4] |= 0x80;

        t(tcp.write_all(&buf)).await.unwrap();
        t(tcp.flush()).await.unwrap();
        // The payload [1, 24, 33] holds two CBOR items, `1` and `24 33`; only one was wanted.
        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[1]).unwrap()));
        s(rx.next()).await;
        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();
        assert_eq!(
            t(rx.next()).await.unwrap(),
            HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[24, 33]).unwrap())
        );

        // Change the protocol id to an unregistered one: the muxer must terminate.
        buf[5] += 1;
        t(tcp.write_all(&buf)).await.unwrap();
        t(tcp.flush()).await.unwrap();
        assert_eq!(&t(join_handle).await.unwrap(), mux.name());

        trace_guard.defuse();
    }

    /// Steps the muxer effect by effect in the simulation, checking segmenting, round-robin
    /// scheduling, `Sent` notifications, CBOR framing and buffer-limit termination.
    #[test]
    fn test_muxing() {
        let _ = tracing_subscriber::fmt().with_env_filter(EnvFilter::from_default_env()).with_test_writer().try_init();

        let _guard = pure_stage::register_data_deserializer::<MuxMessage>();
        let _guard = pure_stage::register_data_deserializer::<NonEmptyBytes>();
        let _guard = pure_stage::register_effect_deserializer::<SendEffect>();
        let _guard = pure_stage::register_effect_deserializer::<RecvEffect>();
        let _guard = pure_stage::register_data_deserializer::<State>();

        let trace_buffer = TraceBuffer::new_shared(100, 1_000_000);
        let drop_guard = TraceBuffer::drop_guard(&trace_buffer);
        let mut network = SimulationBuilder::default().with_trace_buffer(trace_buffer);
        let mux = network.stage("mux", super::stage);
        let conn_id = ConnectionId::initial();
        let mux = network.wire_up(
            mux,
            State::new(
                conn_id,
                // This order fixes the round-robin send order: TEST, BLOCK_FETCH, HANDSHAKE.
                &[(PROTO_TEST.erase(), 1024), (PROTO_N2N_BLOCK_FETCH.erase(), 0), (PROTO_HANDSHAKE.erase(), 1)],
                Role::Initiator,
            ),
        );

        let mut running = network.run();
        let running = &mut running;

        // Stop at every network send/receive and at every stage wire-up.
        running.breakpoint("send", |eff| matches!(eff, Effect::External { effect, .. } if effect.is::<SendEffect>()));
        running.breakpoint("recv", |eff| matches!(eff, Effect::External { effect, .. } if effect.is::<RecvEffect>()));
        running.breakpoint("spawn", |eff| matches!(eff, Effect::WireStage { .. }));

        let chain_sync = StageRef::named_for_tests("chain_sync");
        running.enqueue_msg(
            &mux,
            [MuxMessage::Register {
                protocol: PROTO_TEST.erase(),
                frame: Frame::OneCborItem,
                handler: chain_sync.clone(),
                max_buffer: 1024,
            }],
        );
        // The first message makes `State::init` wire up the writer, then the reader.
        let spawn1 = running.run_until_blocked().assert_breakpoint("spawn");
        let writer = spawn1.extract_wire_stage(&mux, (conn_id, (*mux).clone(), Role::Initiator)).clone();
        running.handle_effect(spawn1);

        let spawn2 = running.run_until_blocked().assert_breakpoint("spawn");
        let reader = spawn2.extract_wire_stage(&mux, (conn_id, (*mux).clone(), Role::Initiator)).clone();
        running.handle_effect(spawn2);

        {
            let mux_name = mux.name().clone();
            let writer = writer.clone();
            let reader = reader.clone();
            // Also stop at every message the mux sends to anyone but its reader and writer.
            running.breakpoint(
                "mux",
                move |eff| matches!(eff, Effect::Send { from, to, .. } if from == &mux_name && to != &writer && to != &reader),
            );
        }

        // The reader immediately waits for the first segment header.
        running
            .run_until_blocked()
            .assert_breakpoint("recv")
            .assert_external(&reader, &RecvEffect { conn: conn_id, bytes: HEADER_LEN });
        let registered = running.run_until_blocked().assert_breakpoint("mux");
        registered.assert_send(&mux, &chain_sync, HandlerMessage::Registered(PROTO_TEST.erase()));
        running.handle_effect(registered);
        running.enqueue_msg(&mux, [MuxMessage::WantNext(PROTO_TEST.erase())]);
        running.run_until_blocked().assert_busy([&reader]);

        // Queues `len` copies of byte `msg` on `proto_id`; returns the stage expecting `Sent`.
        let send_msg = |running: &mut SimulationRunning,
                        id: u64,
                        msg: u8,
                        len: usize,
                        proto_id: ProtocolId<Initiator>| {
            let bytes = vec![msg; len];
            let sent = StageRef::named_for_tests(&format!("sent_{id}"));
            running.enqueue_msg(
                &mux,
                [MuxMessage::Send(proto_id.erase(), Bytes::copy_from_slice(&bytes).try_into().unwrap(), sent.clone())],
            );
            sent
        };

        // Expects the writer to send one segment for `proto_id` with payload runs `data`.
        let assert_send = |running: &mut SimulationRunning, data: &[(usize, u8)], proto_id: ProtocolId<Initiator>| {
            running.run_until_blocked().assert_breakpoint("send").extract_external::<SendEffect>(&writer).assert_frame(
                conn_id,
                proto_id.erase(),
                data,
            );
        };
        // Completes the writer's pending network send successfully.
        let resume_send = |running: &mut SimulationRunning| {
            running.resume_external::<SendEffect>(&writer, Ok(())).unwrap();
        };
        let assert_and_resume_send =
            |running: &mut SimulationRunning, data: &[(usize, u8)], proto_id: ProtocolId<Initiator>| {
                assert_send(running, data, proto_id);
                resume_send(running);
            };
        // Expects the mux to notify `sent` with `Sent`.
        let assert_respond = |running: &mut SimulationRunning, sent: &StageRef<Sent>| {
            let mux_sent = running.run_until_blocked().assert_breakpoint("mux");
            mux_sent.assert_send(&mux, sent, Sent);
            running.handle_effect(mux_sent);
        };

        // Nothing is in flight: the segment is taken at once (notifying cr1) and handed to the writer.
        let cr1 = send_msg(running, 101, 1, 1024, PROTO_TEST);
        assert_respond(running, &cr1);
        assert_send(running, &[(1024, 1)], PROTO_TEST);

        // While that send is pending, queue more data; the 66000-byte messages exceed
        // MAX_SEGMENT_SIZE and are split into 65535 + 465 bytes.
        let cr2 = send_msg(running, 102, 2, 1024, PROTO_TEST);
        let cr3 = send_msg(running, 103, 3, 10, PROTO_TEST);
        let cr4 = send_msg(running, 104, 4, 66000, PROTO_HANDSHAKE);
        let cr5 = send_msg(running, 105, 5, 66000, PROTO_N2N_BLOCK_FETCH);

        // Completing the first send releases segments round-robin after PROTO_TEST. cr2 and cr3
        // share one PROTO_TEST segment, so both are notified before it is sent.
        resume_send(running);
        assert_and_resume_send(running, &[(65535, 5)], PROTO_N2N_BLOCK_FETCH);
        assert_and_resume_send(running, &[(65535, 4)], PROTO_HANDSHAKE);
        assert_respond(running, &cr2);
        assert_respond(running, &cr3);
        assert_and_resume_send(running, &[(1024, 2), (10, 3)], PROTO_TEST);
        assert_respond(running, &cr5);
        assert_and_resume_send(running, &[(465, 5)], PROTO_N2N_BLOCK_FETCH);
        assert_respond(running, &cr4);
        assert_and_resume_send(running, &[(465, 4)], PROTO_HANDSHAKE);

        // Feeds one segment with payload `bytes` to the reader, then checks `recv` in order:
        // a non-empty entry is a message the mux must deliver to `chain_sync` (after which the
        // next `WantNext` is enqueued); an empty entry expects the reader to wait for a header.
        let recv_header = RecvEffect { conn: conn_id, bytes: HEADER_LEN };
        let recv_msg =
            |running: &mut SimulationRunning, proto_id: ProtocolId<Responder>, bytes: &[u8], recv: &[&[u8]]| {
                let mut msg = Header::encode(proto_id, bytes, Timestamp::now()).into_inner();
                running
                    .resume_external::<RecvEffect>(&reader, Ok(msg.split_to(HEADER_LEN.get()).try_into().unwrap()))
                    .unwrap();
                let msg = NonEmptyBytes::new(msg).unwrap();
                running
                    .run_until_blocked()
                    .assert_breakpoint("recv")
                    .assert_external(&reader, &RecvEffect { conn: conn_id, bytes: msg.len() });
                running.resume_external::<RecvEffect>(&reader, Ok(msg)).unwrap();
                for recv in recv {
                    if recv.is_empty() {
                        running.run_until_blocked().assert_breakpoint("recv").assert_external(&reader, &recv_header);
                        continue;
                    }
                    running.run_until_blocked().assert_breakpoint("mux").assert_send(
                        &mux,
                        &chain_sync,
                        HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(recv).unwrap()),
                    );
                    running.resume_send(&mux, &chain_sync, None).unwrap();
                    running.enqueue_msg(&mux, [MuxMessage::WantNext(proto_id.initiator().erase())]);
                }
            };

        // `1` is a complete CBOR item; `24` announces a one-byte follow-up and stays buffered.
        recv_msg(running, PROTO_TEST.responder(), &[1, 24], &[&[1], &[]]);
        // `25` completes the item `24 25`; `3` is a further item delivered on the next want.
        recv_msg(running, PROTO_TEST.responder(), &[25, 3], &[&[24, 25], &[], &[3]]);

        // PROTO_HANDSHAKE buffers at most 1 byte: 3 bytes exceed the limit and kill the muxer.
        recv_msg(running, PROTO_HANDSHAKE.responder(), &[1, 2, 3], &[]);
        running.run_until_blocked().assert_terminated(mux.name());

        drop_guard.defuse();
    }

    /// Checks a captured network send against an expected segment.
    trait AssertBytes {
        /// `data` lists consecutive payload runs as `(length, byte value)`.
        fn assert_frame(&self, conn: ConnectionId, proto_id: ProtocolId<Erased>, data: &[(usize, u8)]);
    }
    impl AssertBytes for SendEffect {
        fn assert_frame(&self, conn: ConnectionId, proto_id: ProtocolId<Erased>, data: &[(usize, u8)]) {
            assert_eq!(self.conn, conn);
            let mut header = self.data.slice(..HEADER_LEN.get());
            let header = Header::decode(&mut header).unwrap().unwrap();
            assert_eq!(header.proto_id, proto_id);
            assert_eq!(header.length.get() as usize, data.iter().map(|(len, _)| len).sum::<usize>());
            let mut bytes = self.data.slice(HEADER_LEN.get()..);
            for &(len, msg) in data {
                assert_eq!(&bytes.split_to(len), &vec![msg; len]);
            }
        }
    }

    /// Same scenario as `test_tcp`, but running on the tokio executor instead of the simulation.
    #[tokio::test]
    async fn test_tokio() {
        let _guard = pure_stage::register_data_deserializer::<MuxMessage>();
        let _guard = pure_stage::register_data_deserializer::<NonEmptyBytes>();
        let _guard = pure_stage::register_effect_deserializer::<SendEffect>();
        let _guard = pure_stage::register_effect_deserializer::<RecvEffect>();
        let _guard = pure_stage::register_data_deserializer::<State>();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move { listener.accept().await.unwrap().0 });

        let network = TokioConnections::new(65536);
        let conn_id = t(network.connect(vec![server_addr], Duration::from_secs(5))).await.unwrap();
        let mut tcp = t(server_task).await.unwrap();

        let trace_buffer = TraceBuffer::new_shared(1000, 1000000);
        let trace_guard = TraceBuffer::drop_guard(&trace_buffer);
        let mut graph = TokioBuilder::default().with_trace_buffer(trace_buffer);

        let mux = graph.stage("mux", super::stage);
        let mux = graph.wire_up(mux, State::new(conn_id, &[(PROTO_TEST.erase(), 0)], Role::Initiator));

        let (output, mut rx) = graph.output::<HandlerMessage>("output", 10);
        let (sent, mut sent_rx) = graph.output::<Sent>("sent", 10);
        let input = graph.input(&mux);

        graph.resources().put::<ConnectionsResource>(Arc::new(network));

        let running = graph.run(Handle::current());

        input
            .send(MuxMessage::Send(PROTO_TEST.erase(), Bytes::copy_from_slice(&[1, 24, 33]).try_into().unwrap(), sent))
            .await
            .unwrap();
        // 8-byte header plus 3 payload bytes; the timestamp (bytes 0..4) is not checked.
        let mut buf = [0u8; 11];
        assert_eq!(t(tcp.read_exact(&mut buf)).await.unwrap(), 11);
        t(sent_rx.next()).await.unwrap();
        assert_eq!(&buf[4..], [1, 1, 0, 3, 1, 24, 33]);

        input
            .send(MuxMessage::Register {
                protocol: PROTO_TEST.erase(),
                frame: Frame::OneCborItem,
                handler: output,
                max_buffer: 100,
            })
            .await
            .unwrap();
        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::Registered(PROTO_TEST.erase()));

        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();

        // Echo the segment back with the high bit of the protocol id set (see `test_tcp`).
        buf[4] |= 0x80;

        t(tcp.write_all(&buf)).await.unwrap();
        t(tcp.flush()).await.unwrap();
        assert_eq!(t(rx.next()).await.unwrap(), HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[1]).unwrap()));
        s(rx.next()).await;
        input.send(MuxMessage::WantNext(PROTO_TEST.erase())).await.unwrap();
        assert_eq!(
            t(rx.next()).await.unwrap(),
            HandlerMessage::FromNetwork(NonEmptyBytes::from_slice(&[24, 33]).unwrap())
        );

        // Unregistered protocol id: the muxer terminates and the running graph finishes.
        buf[5] += 1;
        t(tcp.write_all(&buf)).await.unwrap();
        t(tcp.flush()).await.unwrap();
        t(running.join()).await;

        trace_guard.defuse();
    }
}