1use std::marker::PhantomData;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use facet::Facet;
6use facet_core::PtrConst;
7use facet_reflect::Peek;
8use roam_types::{
9 Conduit, ConduitRx, ConduitTx, ConduitTxPermit, Link, LinkRx, LinkTx, LinkTxPermit, MsgFamily,
10 SelfRef, WriteSlot,
11};
12use zerocopy::little_endian::U32 as LeU32;
13
14mod replay_buffer;
15use replay_buffer::{PacketAck, PacketSeq, ReplayBuffer};
16
/// Opaque 16-byte session token carried in both hello messages.
///
/// The acceptor generates a fresh random key (see `fresh_key`) when it answers a
/// `ClientHello`. The initiator stores the key it receives and presents it on its next
/// reconnect.
///
/// `#[repr(C)]` plus the `zerocopy` derives make this a plain byte array on the wire.
#[derive(
    Clone,
    Copy,
    zerocopy::FromBytes,
    zerocopy::IntoBytes,
    zerocopy::KnownLayout,
    zerocopy::Immutable,
)]
#[repr(C)]
struct ResumeKey([u8; 16]);
32
/// First field of every `ClientHello`: the ASCII bytes `ROCH` read as a little-endian `u32`.
const CLIENT_HELLO_MAGIC: u32 = u32::from_le_bytes(*b"ROCH");
/// First field of every `ServerHello`: the ASCII bytes `ROSH` read as a little-endian `u32`.
const SERVER_HELLO_MAGIC: u32 = u32::from_le_bytes(*b"ROSH");

/// `ClientHello.flags` bit: `resume_key` holds a key from a previous handshake.
const CH_HAS_RESUME_KEY: u8 = 0b0000_0001;
/// `ClientHello.flags` bit: `last_received` holds a valid sequence number.
const CH_HAS_LAST_RECEIVED: u8 = 0b0000_0010;

/// `ServerHello.flags` bit: the server refused the client's session. The initiator
/// reports this as `StableConduitError::SessionLost`.
const SH_REJECTED: u8 = 0b0000_0001;
/// `ServerHello.flags` bit: `last_received` holds a valid sequence number.
const SH_HAS_LAST_RECEIVED: u8 = 0b0000_0010;
43
/// First handshake message, sent by the initiating side of a link.
///
/// It is transmitted verbatim as its `#[repr(C)]` bytes. The field order below *is* the
/// wire layout, and `IntoBytes` guarantees there is no padding between fields.
#[derive(
    Clone,
    Copy,
    zerocopy::FromBytes,
    zerocopy::IntoBytes,
    zerocopy::KnownLayout,
    zerocopy::Immutable,
)]
#[repr(C)]
struct ClientHello {
    // Always `CLIENT_HELLO_MAGIC`.
    magic: LeU32,
    // Bitset of `CH_*` flags describing which optional fields are meaningful.
    flags: u8,
    // Key from the previous handshake; all zeros unless `CH_HAS_RESUME_KEY` is set.
    resume_key: ResumeKey,
    // Highest sequence number the client has accepted; 0 unless `CH_HAS_LAST_RECEIVED` is set.
    last_received: LeU32,
}
61
/// Handshake reply, sent by the accepting side after it has read a `ClientHello`.
///
/// It is transmitted verbatim as its `#[repr(C)]` bytes. The field order below *is* the
/// wire layout, and `IntoBytes` guarantees there is no padding between fields.
#[derive(
    Clone,
    Copy,
    zerocopy::FromBytes,
    zerocopy::IntoBytes,
    zerocopy::KnownLayout,
    zerocopy::Immutable,
)]
#[repr(C)]
struct ServerHello {
    // Always `SERVER_HELLO_MAGIC`.
    magic: LeU32,
    // Bitset of `SH_*` flags.
    flags: u8,
    // Key the client must present on its next reconnect.
    resume_key: ResumeKey,
    // Highest sequence number the server has accepted; 0 unless `SH_HAS_LAST_RECEIVED` is set.
    last_received: LeU32,
}
80
/// Envelope around every application message sent over a stable conduit.
///
/// It is serialized with postcard via `facet`, so the field order determines the
/// encoding.
#[derive(Facet, Debug, Clone)]
struct Frame<T> {
    /// Sender-assigned sequence number. The receiver drops frames whose `seq` is not
    /// greater than the last one it accepted (re-sent replays after a reconnect).
    seq: PacketSeq,
    /// Acknowledgement of what the sender has accepted from its peer so far. The
    /// receiver uses it to trim its replay buffer. `None` until anything was received.
    ack: Option<PacketAck>,
    /// The application message itself.
    item: T,
}
93
94pub struct Attachment<L> {
103 link: L,
104 client_hello: Option<ClientHello>,
105}
106
107impl<L> Attachment<L> {
108 pub fn initiator(link: L) -> Self {
110 Self {
111 link,
112 client_hello: None,
113 }
114 }
115}
116
117pub struct SplitLink<Tx, Rx> {
122 tx: Tx,
123 rx: Rx,
124}
125
126impl<Tx, Rx> Link for SplitLink<Tx, Rx>
127where
128 Tx: LinkTx,
129 Rx: LinkRx,
130{
131 type Tx = Tx;
132 type Rx = Rx;
133
134 fn split(self) -> (Self::Tx, Self::Rx) {
135 (self.tx, self.rx)
136 }
137}
138
139pub async fn prepare_acceptor_attachment<L: Link>(
144 link: L,
145) -> Result<Attachment<SplitLink<L::Tx, L::Rx>>, StableConduitError> {
146 let (tx, mut rx) = link.split();
147 let client_hello = recv_handshake::<_, ClientHello>(&mut rx).await?;
148 Ok(Attachment {
149 link: SplitLink { tx, rx },
150 client_hello: Some(client_hello),
151 })
152}
153
/// Supplies successive links to a [`StableConduit`].
///
/// `StableConduit::new` asks for the initial link, and every reconnect attempt asks for
/// a replacement.
pub trait LinkSource: Send + 'static {
    /// The link type produced by this source.
    type Link: Link + Send;

    /// Produces the next link, wrapped in an [`Attachment`] that carries its handshake
    /// state.
    ///
    /// An error here aborts the current connection attempt. The caller (`new` or the
    /// reconnect that asked) reports it as `StableConduitError::Io`.
    fn next_link(
        &mut self,
    ) -> impl Future<Output = std::io::Result<Attachment<Self::Link>>> + Send + '_;
}
162
/// A [`Conduit`] that survives link failures.
///
/// When a link fails it reconnects through a [`LinkSource`] and re-sends frames the peer
/// has not acknowledged.
///
/// `F` is the message family carried by the conduit.
pub struct StableConduit<F: MsgFamily, LS: LinkSource> {
    shared: Arc<Shared<LS>>,
    // `fn(F) -> F` ties the struct to `F` (invariantly) without storing an `F`.
    _phantom: PhantomData<fn(F) -> F>,
}
173
/// State shared (via `Arc`) by the transmit and receive halves of a [`StableConduit`].
struct Shared<LS: LinkSource> {
    /// Link halves, sequencing and replay state. It is never held across an `.await`.
    inner: Mutex<Inner<LS>>,
    /// `true` while one task is running a reconnect attempt. Other tasks wait on
    /// `reconnected` instead of starting their own.
    reconnecting: AtomicBool,
    /// Woken with `notify_waiters` when a reconnect attempt finishes, successfully or not.
    reconnected: moire::sync::Notify,
}
179
/// Mutable connection state, guarded by `Shared::inner`.
struct Inner<LS: LinkSource> {
    /// Where replacement links come from. It is taken out (`None`) while a reconnect
    /// awaits the next link.
    source: Option<LS>,
    /// Bumped (wrapping) each time a reconnect installs new link halves. Permits and
    /// receivers remember the generation they started on, so they can detect that the
    /// link changed under them.
    link_generation: u64,
    /// Current transmit half. It is `None` after the conduit's `close`.
    tx: Option<<LS::Link as Link>::Tx>,
    /// Current receive half. It is `None` while a `recv` call has it checked out.
    rx: Option<<LS::Link as Link>::Rx>,
    /// Key to present on the next reconnect, as returned by the last handshake.
    resume_key: Option<ResumeKey>,
    /// Sequence number for the next outgoing frame.
    next_send_seq: PacketSeq,
    /// Highest sequence number accepted from the peer, if any frame was accepted yet.
    last_received: Option<PacketSeq>,
    /// Encoded outgoing frames, kept for re-sending after a reconnect. Trimmed as the
    /// peer's acks arrive.
    replay: ReplayBuffer,
}
195
196impl<F: MsgFamily, LS: LinkSource> StableConduit<F, LS> {
197 pub async fn new(mut source: LS) -> Result<Self, StableConduitError> {
198 let attachment = source.next_link().await.map_err(StableConduitError::Io)?;
199 let (link_tx, mut link_rx) = attachment.link.split();
200
201 let (resume_key, _peer_last_received) =
202 handshake::<LS::Link>(&link_tx, &mut link_rx, attachment.client_hello, None, None)
203 .await?;
204
205 let inner = Inner {
206 source: Some(source),
207 link_generation: 0,
208 tx: Some(link_tx),
209 rx: Some(link_rx),
210 resume_key: Some(resume_key),
211 next_send_seq: PacketSeq(0),
212 last_received: None,
213 replay: ReplayBuffer::new(),
214 };
215
216 Ok(Self {
217 shared: Arc::new(Shared {
218 inner: Mutex::new(inner),
219 reconnecting: AtomicBool::new(false),
220 reconnected: moire::sync::Notify::new("stable_conduit.reconnected"),
221 }),
222 _phantom: PhantomData,
223 })
224 }
225}
226
227impl<LS: LinkSource> Shared<LS> {
232 fn lock_inner(&self) -> Result<MutexGuard<'_, Inner<LS>>, StableConduitError> {
233 self.inner
234 .lock()
235 .map_err(|_| StableConduitError::Setup("stable conduit mutex poisoned".into()))
236 }
237
238 async fn ensure_reconnected(&self, generation: u64) -> Result<(), StableConduitError> {
239 loop {
240 {
241 let inner = self.lock_inner()?;
242 if inner.link_generation != generation {
243 return Ok(());
244 }
245 }
246
247 if self
248 .reconnecting
249 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
250 .is_ok()
251 {
252 let result = self.reconnect_once(generation).await;
253 self.reconnecting.store(false, Ordering::Release);
254 self.reconnected.notify_waiters();
255 return result;
256 }
257
258 self.reconnected.notified().await;
259 }
260 }
261
262 async fn reconnect_once(&self, generation: u64) -> Result<(), StableConduitError> {
269 let (mut source, resume_key, last_received, replay_frames) = {
270 let mut inner = self.lock_inner()?;
271 if inner.link_generation != generation {
272 return Ok(());
273 }
274 let source = inner
275 .source
276 .take()
277 .ok_or_else(|| StableConduitError::Setup("link source unavailable".into()))?;
278 let replay_frames = inner
279 .replay
280 .iter()
281 .map(|(seq, bytes)| (*seq, bytes.clone()))
282 .collect::<Vec<_>>();
283 (source, inner.resume_key, inner.last_received, replay_frames)
284 };
285
286 let reconnect_result = async {
287 let attachment = source.next_link().await.map_err(StableConduitError::Io)?;
288 let (new_tx, mut new_rx) = attachment.link.split();
289
290 let (new_resume_key, peer_last_received) = handshake::<LS::Link>(
291 &new_tx,
292 &mut new_rx,
293 attachment.client_hello,
294 resume_key,
295 last_received,
296 )
297 .await?;
298
299 for (seq, frame_bytes) in replay_frames {
303 if peer_last_received.is_some_and(|last| seq <= last) {
304 continue;
305 }
306 let permit = new_tx.reserve().await.map_err(StableConduitError::Io)?;
307 let mut slot = permit
308 .alloc(frame_bytes.len())
309 .map_err(StableConduitError::Io)?;
310 slot.as_mut_slice().copy_from_slice(&frame_bytes);
311 slot.commit();
312 }
313
314 Ok::<_, StableConduitError>((new_tx, new_rx, new_resume_key))
315 }
316 .await;
317
318 let mut inner = self.lock_inner()?;
319 inner.source = Some(source);
320
321 if inner.link_generation != generation {
322 return Ok(());
323 }
324
325 let (new_tx, new_rx, new_resume_key) = reconnect_result?;
326
327 inner.link_generation = inner.link_generation.wrapping_add(1);
328 inner.tx = Some(new_tx);
329 inner.rx = Some(new_rx);
330 inner.resume_key = Some(new_resume_key);
331
332 Ok(())
333 }
334}
335
336async fn handshake<L: Link>(
344 tx: &L::Tx,
345 rx: &mut L::Rx,
346 client_hello: Option<ClientHello>,
347 resume_key: Option<ResumeKey>,
348 last_received: Option<PacketSeq>,
349) -> Result<(ResumeKey, Option<PacketSeq>), StableConduitError> {
350 match client_hello {
351 None => {
352 let mut flags = 0u8;
354 if resume_key.is_some() {
355 flags |= CH_HAS_RESUME_KEY;
356 }
357 if last_received.is_some() {
358 flags |= CH_HAS_LAST_RECEIVED;
359 }
360 let hello = ClientHello {
361 magic: LeU32::new(CLIENT_HELLO_MAGIC),
362 flags,
363 resume_key: resume_key.unwrap_or(ResumeKey([0u8; 16])),
364 last_received: LeU32::new(last_received.map_or(0, |s| s.0)),
365 };
366 send_handshake(tx, &hello).await?;
367
368 let sh = recv_handshake::<_, ServerHello>(rx).await?;
369 if sh.magic.get() != SERVER_HELLO_MAGIC {
370 return Err(StableConduitError::Setup(
371 "ServerHello magic mismatch".into(),
372 ));
373 }
374 if sh.flags & SH_REJECTED != 0 {
376 return Err(StableConduitError::SessionLost);
377 }
378 let peer_last_received =
379 (sh.flags & SH_HAS_LAST_RECEIVED != 0).then(|| PacketSeq(sh.last_received.get()));
380 Ok((sh.resume_key, peer_last_received))
381 }
382 Some(ch) => {
383 let key = fresh_key()?;
385 let mut flags = 0u8;
386 if last_received.is_some() {
387 flags |= SH_HAS_LAST_RECEIVED;
388 }
389 let hello = ServerHello {
390 magic: LeU32::new(SERVER_HELLO_MAGIC),
391 flags,
392 resume_key: key,
393 last_received: LeU32::new(last_received.map_or(0, |s| s.0)),
394 };
395 send_handshake(tx, &hello).await?;
396
397 let peer_last_received =
398 (ch.flags & CH_HAS_LAST_RECEIVED != 0).then(|| PacketSeq(ch.last_received.get()));
399 Ok((key, peer_last_received))
400 }
401 }
402}
403
404async fn send_handshake<LTx: LinkTx, M: zerocopy::IntoBytes + zerocopy::Immutable>(
405 tx: <x,
406 msg: &M,
407) -> Result<(), StableConduitError> {
408 let bytes = msg.as_bytes();
409 let permit = tx.reserve().await.map_err(StableConduitError::Io)?;
410 let mut slot = permit.alloc(bytes.len()).map_err(StableConduitError::Io)?;
411 slot.as_mut_slice().copy_from_slice(bytes);
412 slot.commit();
413 Ok(())
414}
415
416async fn recv_handshake<
417 LRx: LinkRx,
418 M: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable,
419>(
420 rx: &mut LRx,
421) -> Result<M, StableConduitError> {
422 let backing = rx
423 .recv()
424 .await
425 .map_err(|_| StableConduitError::LinkDead)?
426 .ok_or(StableConduitError::LinkDead)?;
427 M::read_from_bytes(backing.as_bytes())
428 .map_err(|_| StableConduitError::Setup("handshake message size mismatch".into()))
429}
430
431fn fresh_key() -> Result<ResumeKey, StableConduitError> {
432 let mut key = ResumeKey([0u8; 16]);
433 getrandom::fill(&mut key.0)
434 .map_err(|e| StableConduitError::Setup(format!("failed to generate resume key: {e}")))?;
435 Ok(key)
436}
437
438impl<F: MsgFamily, LS: LinkSource> Conduit for StableConduit<F, LS>
443where
444 <LS::Link as Link>::Tx: Clone + Send + 'static,
445 <LS::Link as Link>::Rx: Send + 'static,
446 LS: Send + 'static,
447{
448 type Msg = F;
449 type Tx = StableConduitTx<F, LS>;
450 type Rx = StableConduitRx<F, LS>;
451
452 fn split(self) -> (Self::Tx, Self::Rx) {
453 (
454 StableConduitTx {
455 shared: Arc::clone(&self.shared),
456 _phantom: PhantomData,
457 },
458 StableConduitRx {
459 shared: Arc::clone(&self.shared),
460 _phantom: PhantomData,
461 },
462 )
463 }
464}
465
466pub struct StableConduitTx<F: MsgFamily, LS: LinkSource> {
471 shared: Arc<Shared<LS>>,
472 _phantom: PhantomData<fn(F)>,
473}
474
475impl<F: MsgFamily, LS: LinkSource> ConduitTx for StableConduitTx<F, LS>
476where
477 <LS::Link as Link>::Tx: Clone + Send + 'static,
478 <LS::Link as Link>::Rx: Send + 'static,
479 LS: Send + 'static,
480{
481 type Msg = F;
482 type Permit<'a>
483 = StableConduitPermit<F, LS>
484 where
485 Self: 'a;
486
487 async fn reserve(&self) -> std::io::Result<Self::Permit<'_>> {
488 loop {
489 let (tx, generation) = {
490 let inner = self
491 .shared
492 .lock_inner()
493 .map_err(|e| std::io::Error::other(e.to_string()))?;
494 (inner.tx.clone(), inner.link_generation)
495 };
496
497 let tx = match tx {
498 Some(tx) => tx,
499 None => {
500 self.shared
501 .ensure_reconnected(generation)
502 .await
503 .map_err(|e| std::io::Error::other(e.to_string()))?;
504 continue;
505 }
506 };
507
508 match tx.reserve().await {
509 Ok(link_permit) => {
510 return Ok(StableConduitPermit {
511 shared: Arc::clone(&self.shared),
512 link_permit,
513 generation,
514 _phantom: PhantomData,
515 });
516 }
517 Err(_) => {
518 self.shared
519 .ensure_reconnected(generation)
520 .await
521 .map_err(|e| std::io::Error::other(e.to_string()))?;
522 }
523 }
524 }
525 }
526
527 async fn close(self) -> std::io::Result<()> {
528 let tx = {
529 let mut inner = self
530 .shared
531 .lock_inner()
532 .map_err(|e| std::io::Error::other(e.to_string()))?;
533 inner.tx.take()
534 };
535 if let Some(tx) = tx {
536 tx.close().await?;
537 }
538 Ok(())
539 }
540}
541
542pub struct StableConduitPermit<F: MsgFamily, LS: LinkSource> {
547 shared: Arc<Shared<LS>>,
548 link_permit: <<LS::Link as Link>::Tx as LinkTx>::Permit,
549 generation: u64,
550 _phantom: PhantomData<fn(F)>,
551}
552
553impl<F: MsgFamily, LS: LinkSource> ConduitTxPermit for StableConduitPermit<F, LS> {
554 type Msg = F;
555 type Error = StableConduitError;
556
557 fn send(self, item: F::Msg<'_>) -> Result<(), StableConduitError> {
566 let StableConduitPermit {
567 shared,
568 link_permit,
569 generation,
570 _phantom: _,
571 } = self;
572
573 let (seq, ack) = {
574 let mut inner = shared.lock_inner()?;
575 if inner.link_generation != generation {
576 return Err(StableConduitError::LinkDead);
577 }
578 let seq = inner.next_send_seq;
579 inner.next_send_seq = PacketSeq(seq.0.wrapping_add(1));
580 let ack = inner
581 .last_received
582 .map(|max_delivered| PacketAck { max_delivered });
583 (seq, ack)
584 };
585
586 let frame = Frame { seq, ack, item };
587
588 #[allow(unsafe_code)]
591 let peek = unsafe {
592 Peek::unchecked_new(
593 PtrConst::new((&raw const frame).cast::<u8>()),
594 Frame::<F::Msg<'static>>::SHAPE,
595 )
596 };
597 let plan =
598 facet_postcard::peek_to_scatter_plan(peek).map_err(StableConduitError::Encode)?;
599
600 let mut slot = link_permit
601 .alloc(plan.total_size())
602 .map_err(StableConduitError::Io)?;
603 let slot_bytes = slot.as_mut_slice();
604 plan.write_into(slot_bytes)
605 .map_err(StableConduitError::Encode)?;
606
607 shared.lock_inner()?.replay.push(seq, slot_bytes.to_vec());
609 slot.commit();
610
611 Ok(())
612 }
613}
614
615pub struct StableConduitRx<F: MsgFamily, LS: LinkSource> {
620 shared: Arc<Shared<LS>>,
621 _phantom: PhantomData<fn() -> F>,
622}
623
624impl<F: MsgFamily, LS: LinkSource> ConduitRx for StableConduitRx<F, LS>
625where
626 <LS::Link as Link>::Tx: Send + 'static,
627 <LS::Link as Link>::Rx: Send + 'static,
628 LS: Send + 'static,
629{
630 type Msg = F;
631 type Error = StableConduitError;
632
633 #[moire::instrument]
634 async fn recv(&mut self) -> Result<Option<SelfRef<F::Msg<'static>>>, Self::Error> {
635 loop {
636 let (rx_opt, generation) = {
638 let mut inner = self.shared.lock_inner()?;
639 (inner.rx.take(), inner.link_generation)
640 }; let mut rx = match rx_opt {
642 Some(rx) => rx,
643 None => {
644 self.shared.ensure_reconnected(generation).await?;
645 continue;
646 }
647 };
648
649 let recv_result = rx.recv().await;
653
654 {
657 let mut inner = self.shared.lock_inner()?;
658 if inner.link_generation == generation && inner.rx.is_none() {
659 inner.rx = Some(rx);
660 }
661 }
662
663 let backing = match recv_result {
664 Ok(Some(b)) => b,
665 Ok(None) | Err(_) => {
666 self.shared.ensure_reconnected(generation).await?;
668 continue;
669 }
670 };
671
672 let frame: SelfRef<Frame<F::Msg<'static>>> =
674 crate::deserialize_postcard(backing).map_err(StableConduitError::Decode)?;
675
676 let is_dup = {
680 let mut inner = self.shared.lock_inner()?;
681
682 if let Some(ack) = frame.ack {
683 inner.replay.trim(ack);
684 }
685
686 let dup = inner.last_received.is_some_and(|prev| frame.seq <= prev);
687 if !dup {
688 inner.last_received = Some(frame.seq);
689 }
690 dup
691 };
692
693 if is_dup {
694 continue;
695 }
696
697 return Ok(Some(frame.map(|f| f.item)));
698 }
699 }
700}
701
702#[derive(Debug)]
707pub enum StableConduitError {
708 Encode(facet_postcard::SerializeError),
709 Decode(facet_format::DeserializeError),
710 Io(std::io::Error),
711 LinkDead,
712 Setup(String),
713 SessionLost,
716}
717
718impl std::fmt::Display for StableConduitError {
719 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
720 match self {
721 Self::Encode(e) => write!(f, "encode error: {e}"),
722 Self::Decode(e) => write!(f, "decode error: {e}"),
723 Self::Io(e) => write!(f, "io error: {e}"),
724 Self::LinkDead => write!(f, "link dead"),
725 Self::Setup(s) => write!(f, "setup error: {s}"),
726 Self::SessionLost => write!(f, "session lost: server rejected resume key"),
727 }
728 }
729}
730
731impl std::error::Error for StableConduitError {}
732
733#[cfg(test)]
738mod tests;