1use std::marker::PhantomData;
86
87pub struct Send<T, Next: Session> {
91 _phantom: PhantomData<(T, Next)>,
92}
93
94pub struct Recv<T, Next: Session> {
96 _phantom: PhantomData<(T, Next)>,
97}
98
99pub struct Choose<A: Session, B: Session> {
103 _phantom: PhantomData<(A, B)>,
104}
105
106pub struct Offer<A: Session, B: Session> {
110 _phantom: PhantomData<(A, B)>,
111}
112
/// Terminal protocol step: the session is complete.
///
/// `End` is its own dual. It derives the usual marker-type traits so it can
/// be debug-printed, compared, and default-constructed like any other
/// public value type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct End;
115
/// A communication protocol described at the type level.
///
/// Every session names its `Dual`: the protocol the opposite endpoint
/// follows. The `Dual = Self` bound on the associated type forces duality
/// to be an involution, so taking the dual twice returns the original
/// session. Implementors must be `Send + 'static`.
pub trait Session
where
    Self: std::marker::Send + 'static,
{
    /// The complementary protocol followed by the peer endpoint.
    type Dual: Session<Dual = Self>;
}
128
129impl Session for End {
132 type Dual = Self;
133}
134
135impl<T: std::marker::Send + 'static, Next: Session> Session for self::Send<T, Next> {
136 type Dual = Recv<T, Next::Dual>;
137}
138
139impl<T: std::marker::Send + 'static, Next: Session> Session for Recv<T, Next> {
140 type Dual = self::Send<T, Next::Dual>;
141}
142
143impl<A: Session, B: Session> Session for Choose<A, B> {
144 type Dual = Offer<A::Dual, B::Dual>;
145}
146
147impl<A: Session, B: Session> Session for Offer<A, B> {
148 type Dual = Choose<A::Dual, B::Dual>;
149}
150
151pub type Dual<S> = <S as Session>::Dual;
165
166pub struct Endpoint<S: Session> {
177 _session: PhantomData<S>,
178 tx: crate::channel::mpsc::Sender<Box<dyn std::any::Any + std::marker::Send>>,
180 rx: crate::channel::mpsc::Receiver<Box<dyn std::any::Any + std::marker::Send>>,
182}
183
/// Reasons a session operation can fail.
///
/// Implements `Display` and `std::error::Error` so it can be propagated with
/// `?` into generic error types such as `Box<dyn std::error::Error>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionError {
    /// A send to the peer failed, or a receive found the channel
    /// disconnected or empty.
    Disconnected,
    /// A received message did not have the type the protocol expected.
    TypeMismatch,
    /// A receive reported that it was cancelled.
    Cancelled,
}

impl std::fmt::Display for SessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Disconnected => "session peer disconnected",
            Self::TypeMismatch => "session message had an unexpected type",
            Self::Cancelled => "session operation was cancelled",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SessionError {}
194
195#[must_use]
209pub fn channel<S: Session>() -> (Endpoint<S>, Endpoint<Dual<S>>) {
210 let (tx1, rx1) = crate::channel::mpsc::channel(1);
211 let (tx2, rx2) = crate::channel::mpsc::channel(1);
212
213 let ep1 = Endpoint {
214 _session: PhantomData,
215 tx: tx1,
216 rx: rx2,
217 };
218 let ep2 = Endpoint {
219 _session: PhantomData,
220 tx: tx2,
221 rx: rx1,
222 };
223
224 (ep1, ep2)
225}
226
227impl<T, Next> Endpoint<self::Send<T, Next>>
228where
229 T: std::marker::Send + 'static,
230 Next: Session,
231{
232 pub async fn send(self, cx: &crate::cx::Cx, value: T) -> Result<Endpoint<Next>, SessionError> {
237 let Self { tx, rx, .. } = self;
238 let boxed: Box<dyn std::any::Any + std::marker::Send> = Box::new(value);
239 tx.send(cx, boxed)
240 .await
241 .map_err(|_| SessionError::Disconnected)?;
242 Ok(Endpoint {
243 _session: PhantomData,
244 tx,
245 rx,
246 })
247 }
248}
249
250impl<T, Next> Endpoint<Recv<T, Next>>
251where
252 T: std::marker::Send + 'static,
253 Next: Session,
254{
255 pub async fn recv(self, cx: &crate::cx::Cx) -> Result<(T, Endpoint<Next>), SessionError> {
259 let Self { tx, mut rx, .. } = self;
260 let boxed = rx.recv(cx).await.map_err(|e| match e {
261 crate::channel::mpsc::RecvError::Cancelled => SessionError::Cancelled,
262 crate::channel::mpsc::RecvError::Disconnected
263 | crate::channel::mpsc::RecvError::Empty => SessionError::Disconnected,
264 })?;
265 let value = boxed
266 .downcast::<T>()
267 .map_err(|_| SessionError::TypeMismatch)?;
268 Ok((
269 *value,
270 Endpoint {
271 _session: PhantomData,
272 tx,
273 rx,
274 },
275 ))
276 }
277}
278
279impl<A: Session, B: Session> Endpoint<Choose<A, B>> {
280 pub async fn choose_left(self, cx: &crate::cx::Cx) -> Result<Endpoint<A>, SessionError> {
284 let Self { tx, rx, .. } = self;
285 let boxed: Box<dyn std::any::Any + std::marker::Send> = Box::new(Branch::Left);
286 tx.send(cx, boxed)
287 .await
288 .map_err(|_| SessionError::Disconnected)?;
289 Ok(Endpoint {
290 _session: PhantomData,
291 tx,
292 rx,
293 })
294 }
295
296 pub async fn choose_right(self, cx: &crate::cx::Cx) -> Result<Endpoint<B>, SessionError> {
300 let Self { tx, rx, .. } = self;
301 let boxed: Box<dyn std::any::Any + std::marker::Send> = Box::new(Branch::Right);
302 tx.send(cx, boxed)
303 .await
304 .map_err(|_| SessionError::Disconnected)?;
305 Ok(Endpoint {
306 _session: PhantomData,
307 tx,
308 rx,
309 })
310 }
311}
312
313pub enum Offered<A: Session, B: Session> {
315 Left(Endpoint<A>),
317 Right(Endpoint<B>),
319}
320
321impl<A: Session, B: Session> Endpoint<Offer<A, B>> {
322 pub async fn offer(self, cx: &crate::cx::Cx) -> Result<Offered<A, B>, SessionError> {
326 let Self { tx, mut rx, .. } = self;
327 let boxed = rx.recv(cx).await.map_err(|e| match e {
328 crate::channel::mpsc::RecvError::Cancelled => SessionError::Cancelled,
329 crate::channel::mpsc::RecvError::Disconnected
330 | crate::channel::mpsc::RecvError::Empty => SessionError::Disconnected,
331 })?;
332 let branch = boxed
333 .downcast::<Branch>()
334 .map_err(|_| SessionError::TypeMismatch)?;
335 match *branch {
336 Branch::Left => Ok(Offered::Left(Endpoint {
337 _session: PhantomData,
338 tx,
339 rx,
340 })),
341 Branch::Right => Ok(Offered::Right(Endpoint {
342 _session: PhantomData,
343 tx,
344 rx,
345 })),
346 }
347 }
348}
349
350impl Endpoint<End> {
351 pub fn close(self) {
356 }
358}
359
/// Tag sent by `choose_left` / `choose_right` and decoded by `offer` to
/// select the continuation of a `Choose` / `Offer` step.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Branch {
    /// Continue with the first (left) protocol.
    Left,
    /// Continue with the second (right) protocol.
    Right,
}
370
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

    /// Initializes test logging and announces the test phase.
    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    /// Compile-time assertion that `S` and `Dual<S>` are mutual duals.
    ///
    /// The body is intentionally empty: the check lives entirely in the
    /// `where` clause, so a call only compiles if duality holds.
    fn assert_dual<S: Session>()
    where
        S::Dual: Session<Dual = S>,
    {
    }

    #[test]
    fn duality_end() {
        // Compile-time check: `Dual<End>` is exactly `End`.
        fn _check() -> Dual<End> {
            End
        }

        init_test("duality_end");

        assert_dual::<End>();
        crate::test_complete!("duality_end");
    }

    #[test]
    fn duality_send_recv() {
        init_test("duality_send_recv");

        assert_dual::<Send<String, End>>();
        assert_dual::<Recv<String, End>>();

        assert_dual::<Send<u64, Recv<bool, End>>>();

        crate::test_complete!("duality_send_recv");
    }

    #[test]
    fn duality_choose_offer() {
        init_test("duality_choose_offer");

        assert_dual::<Choose<End, End>>();
        assert_dual::<Offer<End, End>>();
        assert_dual::<Choose<Send<u8, End>, Recv<u8, End>>>();

        crate::test_complete!("duality_choose_offer");
    }

    #[test]
    fn duality_is_involutive() {
        // Each function below only type-checks if taking the dual twice
        // yields the original protocol.
        fn _roundtrip_end(_: Dual<Dual<End>>) -> End {
            End
        }

        fn _roundtrip_send(_: Dual<Dual<Send<u32, End>>>) -> Send<u32, End> {
            Send {
                _phantom: PhantomData,
            }
        }

        fn _roundtrip_choose(
            protocol: Dual<Dual<Choose<Send<u8, End>, Recv<u8, End>>>>,
        ) -> Choose<Send<u8, End>, Recv<u8, End>> {
            protocol
        }

        init_test("duality_is_involutive");

        crate::test_complete!("duality_is_involutive");
    }

    #[test]
    fn duality_complex_protocol() {
        // ATM-style protocol: the client sends a card, receives a pin prompt,
        // then either withdraws (send amount, receive cash) or checks the
        // balance.
        type Card = u64;
        type Pin = u32;
        type Amount = u64;
        type Cash = u64;
        type Balance = u64;

        type ClientProtocol =
            Send<Card, Recv<Pin, Choose<Send<Amount, Recv<Cash, End>>, Recv<Balance, End>>>>;

        type ServerProtocol = Dual<ClientProtocol>;

        // The dual, spelled out by hand: every send/recv flips and the
        // client's choice becomes the server's offer.
        type ExpectedServer =
            Recv<Card, Send<Pin, Offer<Recv<Amount, Send<Cash, End>>, Send<Balance, End>>>>;

        fn _accept_server(_: ServerProtocol) {}

        // Compiles only if `Dual<ClientProtocol>` normalizes to `ExpectedServer`.
        fn _server_matches_expected(protocol: ServerProtocol) -> ExpectedServer {
            protocol
        }

        init_test("duality_complex_protocol");

        assert_dual::<ClientProtocol>();
        assert_dual::<ServerProtocol>();

        crate::test_complete!("duality_complex_protocol");
    }

    #[test]
    fn channel_creates_dual_endpoints() {
        type P = Send<u32, Recv<bool, End>>;

        init_test("channel_creates_dual_endpoints");

        // The annotation pins the second endpoint to the hand-computed dual.
        let (_client, _server): (Endpoint<P>, Endpoint<Recv<u32, Send<bool, End>>>) =
            channel::<P>();

        crate::test_complete!("channel_creates_dual_endpoints");
    }

    #[test]
    fn endpoint_close_at_end() {
        init_test("endpoint_close_at_end");

        let (ep1, ep2) = channel::<End>();
        ep1.close();
        ep2.close();

        crate::test_complete!("endpoint_close_at_end");
    }

    #[test]
    fn branch_enum() {
        init_test("branch_enum");

        let left = Branch::Left;
        let right = Branch::Right;
        assert_ne!(left, right);
        assert_eq!(left, Branch::Left);
        assert_eq!(right, Branch::Right);

        crate::test_complete!("branch_enum");
    }

    #[test]
    fn session_send_recv_e2e() {
        type ClientP = Send<u64, Recv<u64, End>>;

        init_test("session_send_recv_e2e");

        let mut runtime = crate::lab::LabRuntime::new(crate::lab::LabConfig::default());
        let region = runtime
            .state
            .create_root_region(crate::types::Budget::INFINITE);

        let (client_ep, server_ep) = channel::<ClientP>();

        let client_result = Arc::new(AtomicU64::new(0));
        let server_result = Arc::new(AtomicU64::new(0));
        let cr = client_result.clone();
        let sr = server_result.clone();

        let (client_id, _) = runtime
            .state
            .create_task(region, crate::types::Budget::INFINITE, async move {
                let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                let ep = client_ep.send(&cx, 42).await.expect("client send");
                let (response, ep) = ep.recv(&cx).await.expect("client recv");
                cr.store(response, Ordering::SeqCst);
                ep.close();
            })
            .unwrap();

        let (server_id, _) = runtime
            .state
            .create_task(region, crate::types::Budget::INFINITE, async move {
                let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                let (request, ep) = server_ep.recv(&cx).await.expect("server recv");
                sr.store(request, Ordering::SeqCst);
                let ep = ep.send(&cx, request * 2).await.expect("server send");
                ep.close();
            })
            .unwrap();

        runtime.scheduler.lock().schedule(client_id, 0);
        runtime.scheduler.lock().schedule(server_id, 0);
        runtime.run_until_quiescent();

        assert_eq!(
            server_result.load(Ordering::SeqCst),
            42,
            "server received 42"
        );
        assert_eq!(
            client_result.load(Ordering::SeqCst),
            84,
            "client received 84"
        );

        crate::test_complete!("session_send_recv_e2e");
    }

    #[test]
    fn session_choose_offer_e2e() {
        type ClientP = Choose<Send<u64, End>, Recv<u64, End>>;

        init_test("session_choose_offer_e2e");

        let mut runtime = crate::lab::LabRuntime::new(crate::lab::LabConfig::default());
        let region = runtime
            .state
            .create_root_region(crate::types::Budget::INFINITE);

        let (client_ep, server_ep) = channel::<ClientP>();

        let left_taken = Arc::new(AtomicBool::new(false));
        let value_sent = Arc::new(AtomicU64::new(0));
        let lt = left_taken.clone();
        let vs = value_sent.clone();

        let (client_id, _) = runtime
            .state
            .create_task(region, crate::types::Budget::INFINITE, async move {
                let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                let ep = client_ep.choose_left(&cx).await.expect("choose left");
                let ep = ep.send(&cx, 99).await.expect("send on left");
                ep.close();
            })
            .unwrap();

        let (server_id, _) = runtime
            .state
            .create_task(region, crate::types::Budget::INFINITE, async move {
                let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                match server_ep.offer(&cx).await.expect("offer") {
                    Offered::Left(ep) => {
                        lt.store(true, Ordering::SeqCst);
                        let (val, ep) = ep.recv(&cx).await.expect("recv on left");
                        vs.store(val, Ordering::SeqCst);
                        ep.close();
                    }
                    Offered::Right(ep) => {
                        let ep = ep.send(&cx, 0).await.unwrap();
                        ep.close();
                    }
                }
            })
            .unwrap();

        runtime.scheduler.lock().schedule(client_id, 0);
        runtime.scheduler.lock().schedule(server_id, 0);
        runtime.run_until_quiescent();

        assert!(left_taken.load(Ordering::SeqCst), "server took left branch");
        assert_eq!(value_sent.load(Ordering::SeqCst), 99, "server received 99");

        crate::test_complete!("session_choose_offer_e2e");
    }

    #[test]
    fn session_choose_right_e2e() {
        // Exercises `choose_right` / `Offered::Right`, which the left-branch
        // test above never reaches.
        type ClientP = Choose<Send<u64, End>, Recv<u64, End>>;

        init_test("session_choose_right_e2e");

        let mut runtime = crate::lab::LabRuntime::new(crate::lab::LabConfig::default());
        let region = runtime
            .state
            .create_root_region(crate::types::Budget::INFINITE);

        let (client_ep, server_ep) = channel::<ClientP>();

        let right_taken = Arc::new(AtomicBool::new(false));
        let value_received = Arc::new(AtomicU64::new(0));
        let rt = right_taken.clone();
        let vr = value_received.clone();

        let (client_id, _) = runtime
            .state
            .create_task(region, crate::types::Budget::INFINITE, async move {
                let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                let ep = client_ep.choose_right(&cx).await.expect("choose right");
                let (val, ep) = ep.recv(&cx).await.expect("recv on right");
                vr.store(val, Ordering::SeqCst);
                ep.close();
            })
            .unwrap();

        let (server_id, _) = runtime
            .state
            .create_task(region, crate::types::Budget::INFINITE, async move {
                let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                match server_ep.offer(&cx).await.expect("offer") {
                    Offered::Left(ep) => {
                        let (_, ep) = ep.recv(&cx).await.expect("recv on left");
                        ep.close();
                    }
                    Offered::Right(ep) => {
                        rt.store(true, Ordering::SeqCst);
                        let ep = ep.send(&cx, 55).await.expect("send on right");
                        ep.close();
                    }
                }
            })
            .unwrap();

        runtime.scheduler.lock().schedule(client_id, 0);
        runtime.scheduler.lock().schedule(server_id, 0);
        runtime.run_until_quiescent();

        assert!(right_taken.load(Ordering::SeqCst), "server took right branch");
        assert_eq!(value_received.load(Ordering::SeqCst), 55, "client received 55");

        crate::test_complete!("session_choose_right_e2e");
    }

    #[test]
    fn session_deterministic() {
        fn run_protocol(seed: u64) -> u64 {
            type P = Send<u64, Recv<u64, End>>;

            let config = crate::lab::LabConfig::new(seed);
            let mut runtime = crate::lab::LabRuntime::new(config);
            let region = runtime
                .state
                .create_root_region(crate::types::Budget::INFINITE);
            let (client_ep, server_ep) = channel::<P>();

            let result = Arc::new(AtomicU64::new(0));
            let r = result.clone();

            let (cid, _) = runtime
                .state
                .create_task(region, crate::types::Budget::INFINITE, async move {
                    let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                    let ep = client_ep.send(&cx, 7).await.unwrap();
                    let (val, ep) = ep.recv(&cx).await.unwrap();
                    r.store(val, Ordering::SeqCst);
                    ep.close();
                })
                .unwrap();

            let (sid, _) = runtime
                .state
                .create_task(region, crate::types::Budget::INFINITE, async move {
                    let cx: crate::cx::Cx = crate::cx::Cx::for_testing();
                    let (v, ep) = server_ep.recv(&cx).await.unwrap();
                    let ep = ep.send(&cx, v + 100).await.unwrap();
                    ep.close();
                })
                .unwrap();

            runtime.scheduler.lock().schedule(cid, 0);
            runtime.scheduler.lock().schedule(sid, 0);
            runtime.run_until_quiescent();

            result.load(Ordering::SeqCst)
        }

        init_test("session_deterministic");

        let r1 = run_protocol(0xCAFE);
        let r2 = run_protocol(0xCAFE);
        assert_eq!(r1, r2, "deterministic replay");
        assert_eq!(r1, 107, "7 + 100 = 107");

        crate::test_complete!("session_deterministic");
    }

    #[test]
    fn session_error_debug() {
        init_test("session_error_debug");

        let e1 = SessionError::Disconnected;
        let e2 = SessionError::TypeMismatch;
        let e3 = SessionError::Cancelled;

        let dbg1 = format!("{e1:?}");
        let dbg2 = format!("{e2:?}");
        let dbg3 = format!("{e3:?}");

        assert!(dbg1.contains("Disconnected"));
        assert!(dbg2.contains("TypeMismatch"));
        assert!(dbg3.contains("Cancelled"));

        crate::test_complete!("session_error_debug");
    }

    #[test]
    fn branch_debug_copy() {
        init_test("branch_debug_copy");

        let left = Branch::Left;
        let right = Branch::Right;

        let dbg_l = format!("{left:?}");
        let dbg_r = format!("{right:?}");
        assert!(dbg_l.contains("Left"));
        assert!(dbg_r.contains("Right"));

        // `Branch` is `Copy`, so the originals stay usable after these moves.
        let left2 = left;
        assert_eq!(left, left2);

        let right2 = right;
        assert_eq!(right, right2);

        crate::test_complete!("branch_debug_copy");
    }
}