1use super::functions::*;
6use std::collections::{HashMap, HashSet, VecDeque};
7
8#[derive(Debug, Clone)]
10pub enum GType {
11 Comm {
13 sender: Role,
15 receiver: Role,
17 msg_ty: BaseType,
19 cont: Box<GType>,
21 },
22 Choice {
24 selector: Role,
26 receiver: Role,
28 branches: HashMap<String, GType>,
30 },
31 End,
33 Rec(String, Box<GType>),
35 Var(String),
37}
38impl GType {
39 pub fn participants(&self) -> HashSet<Role> {
41 let mut roles = HashSet::new();
42 self.collect_roles(&mut roles);
43 roles
44 }
45 fn collect_roles(&self, roles: &mut HashSet<Role>) {
46 match self {
47 GType::Comm {
48 sender,
49 receiver,
50 cont,
51 ..
52 } => {
53 roles.insert(sender.clone());
54 roles.insert(receiver.clone());
55 cont.collect_roles(roles);
56 }
57 GType::Choice {
58 selector,
59 receiver,
60 branches,
61 } => {
62 roles.insert(selector.clone());
63 roles.insert(receiver.clone());
64 for cont in branches.values() {
65 cont.collect_roles(roles);
66 }
67 }
68 GType::End | GType::Var(_) => {}
69 GType::Rec(_, body) => body.collect_roles(roles),
70 }
71 }
72 pub fn project(&self, role: &Role) -> LType {
74 match self {
75 GType::Comm {
76 sender,
77 receiver,
78 msg_ty,
79 cont,
80 } => {
81 let cont_proj = cont.project(role);
82 if sender == role {
83 LType::Send(receiver.clone(), msg_ty.clone(), Box::new(cont_proj))
84 } else if receiver == role {
85 LType::Recv(sender.clone(), msg_ty.clone(), Box::new(cont_proj))
86 } else {
87 cont_proj
88 }
89 }
90 GType::Choice {
91 selector,
92 receiver,
93 branches,
94 } => {
95 if selector == role {
96 let mut proj_branches: Vec<(String, LType)> = branches
97 .iter()
98 .map(|(lbl, g)| (lbl.clone(), g.project(role)))
99 .collect();
100 proj_branches.sort_by(|a, b| a.0.cmp(&b.0));
101 LType::IChoice(receiver.clone(), proj_branches)
102 } else if receiver == role {
103 let mut proj_branches: Vec<(String, LType)> = branches
104 .iter()
105 .map(|(lbl, g)| (lbl.clone(), g.project(role)))
106 .collect();
107 proj_branches.sort_by(|a, b| a.0.cmp(&b.0));
108 LType::EChoice(selector.clone(), proj_branches)
109 } else {
110 let projs: Vec<LType> = branches.values().map(|g| g.project(role)).collect();
111 Self::merge_all(projs)
112 }
113 }
114 GType::End => LType::End,
115 GType::Rec(x, body) => LType::Rec(x.clone(), Box::new(body.project(role))),
116 GType::Var(x) => LType::Var(x.clone()),
117 }
118 }
119 fn merge_all(types: Vec<LType>) -> LType {
120 types
121 .into_iter()
122 .reduce(|a, b| if a == b { a } else { b })
123 .unwrap_or(LType::End)
124 }
125}
126pub struct AsyncSessionEndpoint {
132 pub remaining: SType,
134 outbox: VecDeque<Message>,
136 inbox: VecDeque<Message>,
138}
139impl AsyncSessionEndpoint {
140 pub fn new(stype: SType) -> Self {
142 AsyncSessionEndpoint {
143 remaining: stype,
144 outbox: VecDeque::new(),
145 inbox: VecDeque::new(),
146 }
147 }
148 pub fn async_send(&mut self, msg: Message) -> Result<(), String> {
150 match &self.remaining.clone() {
151 SType::Send(_, cont) => {
152 self.remaining = *cont.clone();
153 self.outbox.push_back(msg);
154 Ok(())
155 }
156 other => Err(format!("AsyncSend: expected Send, got {}", other)),
157 }
158 }
159 pub fn flush_to(&mut self, peer: &mut AsyncSessionEndpoint) -> usize {
162 let count = self.outbox.len();
163 while let Some(msg) = self.outbox.pop_front() {
164 peer.inbox.push_back(msg);
165 }
166 count
167 }
168 pub fn async_recv(&mut self) -> Result<Message, String> {
170 match &self.remaining.clone() {
171 SType::Recv(_, cont) => {
172 if let Some(msg) = self.inbox.pop_front() {
173 self.remaining = *cont.clone();
174 Ok(msg)
175 } else {
176 Err("AsyncRecv: inbox empty — message not yet delivered".to_string())
177 }
178 }
179 other => Err(format!("AsyncRecv: expected Recv, got {}", other)),
180 }
181 }
182 pub fn outbox_len(&self) -> usize {
184 self.outbox.len()
185 }
186 pub fn inbox_len(&self) -> usize {
188 self.inbox.len()
189 }
190}
/// A named participant in a multiparty protocol; equality and hashing use the name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Role(pub String);

impl Role {
    /// Builds a role from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        Self(owned)
    }
}
200pub struct ProtocolBuilder {
202 current: SType,
203}
204impl ProtocolBuilder {
205 pub fn end() -> Self {
207 ProtocolBuilder {
208 current: SType::End,
209 }
210 }
211 pub fn then_send(self, ty: BaseType) -> Self {
213 ProtocolBuilder {
214 current: SType::Send(Box::new(ty), Box::new(self.current)),
215 }
216 }
217 pub fn then_recv(self, ty: BaseType) -> Self {
219 ProtocolBuilder {
220 current: SType::Recv(Box::new(ty), Box::new(self.current)),
221 }
222 }
223 pub fn build(self) -> SType {
225 self.current
226 }
227}
228#[derive(Debug, Clone, PartialEq, Eq)]
230pub enum SType {
231 Send(Box<BaseType>, Box<SType>),
233 Recv(Box<BaseType>, Box<SType>),
235 End,
237 Choice(Box<SType>, Box<SType>),
239 Branch(Box<SType>, Box<SType>),
241 Rec(String, Box<SType>),
243 Var(String),
245}
246impl SType {
247 pub fn dual(&self) -> SType {
249 match self {
250 SType::Send(t, s) => SType::Recv(t.clone(), Box::new(s.dual())),
251 SType::Recv(t, s) => SType::Send(t.clone(), Box::new(s.dual())),
252 SType::End => SType::End,
253 SType::Choice(s1, s2) => SType::Branch(Box::new(s1.dual()), Box::new(s2.dual())),
254 SType::Branch(s1, s2) => SType::Choice(Box::new(s1.dual()), Box::new(s2.dual())),
255 SType::Rec(x, s) => SType::Rec(x.clone(), Box::new(s.dual())),
256 SType::Var(x) => SType::Var(x.clone()),
257 }
258 }
259 pub fn unfold(&self) -> SType {
261 match self {
262 SType::Rec(x, body) => {
263 let mut body = (**body).clone();
264 body.subst_var(x, self);
265 body
266 }
267 other => other.clone(),
268 }
269 }
270 fn subst_var(&mut self, x: &str, replacement: &SType) {
272 match self {
273 SType::Send(_, s) | SType::Recv(_, s) => s.subst_var(x, replacement),
274 SType::End => {}
275 SType::Choice(s1, s2) | SType::Branch(s1, s2) => {
276 s1.subst_var(x, replacement);
277 s2.subst_var(x, replacement);
278 }
279 SType::Rec(y, s) => {
280 if y != x {
281 s.subst_var(x, replacement);
282 }
283 }
284 SType::Var(y) => {
285 if y == x {
286 *self = replacement.clone();
287 }
288 }
289 }
290 }
291 pub fn is_end(&self) -> bool {
293 matches!(self, SType::End)
294 }
295 pub fn is_send(&self) -> bool {
297 matches!(self, SType::Send(_, _))
298 }
299 pub fn is_recv(&self) -> bool {
301 matches!(self, SType::Recv(_, _))
302 }
303}
/// Payload types that a single message in a session type can carry.
///
/// Equality is structural (derived), so `Named` types are equal exactly when their
/// names are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseType {
    /// Natural numbers.
    Nat,
    /// Booleans.
    Bool,
    /// Strings.
    Str,
    /// The unit type.
    Unit,
    /// A user-named type, identified only by its name.
    Named(String),
    /// Product of two base types.
    Pair(Box<BaseType>, Box<BaseType>),
    /// Sum of two base types.
    Sum(Box<BaseType>, Box<BaseType>),
}
322pub struct ProbSessionScheduler {
325 branches: Vec<ProbBranch>,
327}
328impl ProbSessionScheduler {
329 pub fn new(branches: Vec<ProbBranch>) -> Self {
331 ProbSessionScheduler { branches }
332 }
333 pub fn probabilities(&self) -> Vec<f64> {
335 let total: f64 = self.branches.iter().map(|b| b.weight).sum();
336 if total == 0.0 {
337 return vec![0.0; self.branches.len()];
338 }
339 self.branches.iter().map(|b| b.weight / total).collect()
340 }
341 pub fn greedy_choice(&self) -> Option<usize> {
344 self.branches
345 .iter()
346 .enumerate()
347 .max_by(|a, b| {
348 a.1.weight
349 .partial_cmp(&b.1.weight)
350 .unwrap_or(std::cmp::Ordering::Equal)
351 })
352 .map(|(idx, _)| idx)
353 }
354 pub fn expected_rounds(&self) -> f64 {
357 let probs = self.probabilities();
358 probs
359 .iter()
360 .zip(self.branches.iter())
361 .map(|(p, b)| {
362 let cost = if b.cont == SType::End { 0.0 } else { 1.0 };
363 p * cost
364 })
365 .sum()
366 }
367}
/// A runtime value exchanged through a session endpoint.
///
/// The endpoints in this file queue `Message`s without checking them against the
/// `BaseType` recorded in the session type.
#[derive(Debug, Clone)]
pub enum Message {
    /// A natural number.
    Nat(u64),
    /// A boolean.
    Bool(bool),
    /// A string.
    Str(String),
    /// The unit value.
    Unit,
    /// Left tag — presumably a left injection / left selection marker; TODO confirm.
    Left,
    /// Right tag — presumably a right injection / right selection marker; TODO confirm.
    Right,
}
384pub struct SessionSubtypeChecker {
393 decided: HashSet<(String, String)>,
395}
396impl SessionSubtypeChecker {
397 pub fn new() -> Self {
399 SessionSubtypeChecker {
400 decided: HashSet::new(),
401 }
402 }
403 pub fn is_subtype(&mut self, sub: &SType, sup: &SType) -> bool {
405 let key = (format!("{}", sub), format!("{}", sup));
406 if self.decided.contains(&key) {
407 return true;
408 }
409 self.decided.insert(key);
410 match (sub, sup) {
411 (SType::End, SType::End) => true,
412 (SType::Send(t1, s1), SType::Send(t2, s2)) => t1 == t2 && self.is_subtype(s1, s2),
413 (SType::Recv(t1, s1), SType::Recv(t2, s2)) => t1 == t2 && self.is_subtype(s1, s2),
414 (SType::Choice(l1, r1), SType::Choice(l2, r2)) => {
415 self.is_subtype(l1, l2) && self.is_subtype(r1, r2)
416 }
417 (SType::Branch(l1, r1), SType::Branch(l2, r2)) => {
418 self.is_subtype(l1, l2) && self.is_subtype(r1, r2)
419 }
420 (SType::Rec(_, _), _) => self.is_subtype(&sub.unfold(), sup),
421 (_, SType::Rec(_, _)) => self.is_subtype(sub, &sup.unfold()),
422 _ => false,
423 }
424 }
425}
426#[derive(Debug, Clone)]
428pub enum ChoreographyStep {
429 Comm {
431 sender: String,
433 receiver: String,
435 msg_ty: String,
437 },
438 Choice {
440 selector: String,
442 receiver: String,
444 branch: String,
446 },
447 End,
449}
450pub struct ChoreographyEngine {
452 pub trace: Vec<ChoreographyStep>,
454}
455impl ChoreographyEngine {
456 pub fn new() -> Self {
458 ChoreographyEngine { trace: vec![] }
459 }
460 pub fn execute(&mut self, gtype: >ype) -> Result<(), String> {
464 match gtype {
465 GType::Comm {
466 sender,
467 receiver,
468 msg_ty,
469 cont,
470 } => {
471 self.trace.push(ChoreographyStep::Comm {
472 sender: sender.0.clone(),
473 receiver: receiver.0.clone(),
474 msg_ty: format!("{}", msg_ty),
475 });
476 self.execute(cont)
477 }
478 GType::Choice {
479 selector,
480 receiver,
481 branches,
482 } => {
483 let mut sorted: Vec<(&String, >ype)> = branches.iter().collect();
484 sorted.sort_by_key(|(k, _)| k.as_str());
485 if let Some((label, cont)) = sorted.first() {
486 self.trace.push(ChoreographyStep::Choice {
487 selector: selector.0.clone(),
488 receiver: receiver.0.clone(),
489 branch: (*label).clone(),
490 });
491 self.execute(cont)
492 } else {
493 Err("GType::Choice has no branches".to_string())
494 }
495 }
496 GType::End => {
497 self.trace.push(ChoreographyStep::End);
498 Ok(())
499 }
500 GType::Rec(_, body) => self.execute(body),
501 GType::Var(x) => Err(format!("Unresolved recursion variable: {}", x)),
502 }
503 }
504 pub fn comm_count(&self) -> usize {
506 self.trace
507 .iter()
508 .filter(|s| !matches!(s, ChoreographyStep::End))
509 .count()
510 }
511}
512pub struct SessionEndpoint {
517 pub remaining: SType,
519 buffer: VecDeque<Message>,
521 closed: bool,
523}
524impl SessionEndpoint {
525 pub fn new(stype: SType) -> Self {
527 SessionEndpoint {
528 remaining: stype,
529 buffer: VecDeque::new(),
530 closed: false,
531 }
532 }
533 pub fn send(&mut self, msg: Message) -> Result<(), String> {
535 match &self.remaining.clone() {
536 SType::Send(_, continuation) => {
537 self.remaining = *continuation.clone();
538 self.buffer.push_back(msg);
539 Ok(())
540 }
541 other => Err(format!("Expected Send, got {}", other)),
542 }
543 }
544 pub fn recv(&mut self) -> Result<Message, String> {
546 match &self.remaining.clone() {
547 SType::Recv(_, continuation) => {
548 if let Some(msg) = self.buffer.pop_front() {
549 self.remaining = *continuation.clone();
550 Ok(msg)
551 } else {
552 Err("No message available".to_string())
553 }
554 }
555 other => Err(format!("Expected Recv, got {}", other)),
556 }
557 }
558 pub fn select_left(&mut self) -> Result<(), String> {
560 match &self.remaining.clone() {
561 SType::Choice(left, _) => {
562 self.remaining = *left.clone();
563 Ok(())
564 }
565 other => Err(format!("Expected Choice, got {}", other)),
566 }
567 }
568 pub fn select_right(&mut self) -> Result<(), String> {
570 match &self.remaining.clone() {
571 SType::Choice(_, right) => {
572 self.remaining = *right.clone();
573 Ok(())
574 }
575 other => Err(format!("Expected Choice, got {}", other)),
576 }
577 }
578 pub fn close(&mut self) -> Result<(), String> {
580 if self.remaining == SType::End {
581 self.closed = true;
582 Ok(())
583 } else {
584 Err(format!("Expected End, got {}", self.remaining))
585 }
586 }
587 pub fn is_complete(&self) -> bool {
589 self.closed
590 }
591}
592#[allow(clippy::too_many_arguments)]
597pub struct ProbBranch {
598 pub label: String,
600 pub weight: f64,
602 pub cont: SType,
604}
605pub struct DeadlockChecker {
610 wait_edges: Vec<(String, String, String)>,
612}
613impl DeadlockChecker {
614 pub fn new() -> Self {
616 DeadlockChecker { wait_edges: vec![] }
617 }
618 pub fn add_wait(
620 &mut self,
621 channel: impl Into<String>,
622 waiter: impl Into<String>,
623 provider: impl Into<String>,
624 ) {
625 self.wait_edges
626 .push((channel.into(), waiter.into(), provider.into()));
627 }
628 pub fn is_deadlock_free(&self) -> bool {
630 let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
631 for (_, waiter, provider) in &self.wait_edges {
632 adj.entry(waiter.as_str())
633 .or_default()
634 .push(provider.as_str());
635 }
636 let mut visited: HashSet<&str> = HashSet::new();
637 let mut in_stack: HashSet<&str> = HashSet::new();
638 let nodes: Vec<&str> = adj.keys().copied().collect();
639 for &node in &nodes {
640 if !visited.contains(node) && Self::has_cycle(node, &adj, &mut visited, &mut in_stack) {
641 return false;
642 }
643 }
644 true
645 }
646 fn has_cycle<'a>(
647 node: &'a str,
648 adj: &HashMap<&'a str, Vec<&'a str>>,
649 visited: &mut HashSet<&'a str>,
650 in_stack: &mut HashSet<&'a str>,
651 ) -> bool {
652 visited.insert(node);
653 in_stack.insert(node);
654 if let Some(neighbors) = adj.get(node) {
655 for &nb in neighbors {
656 if !visited.contains(nb) {
657 if Self::has_cycle(nb, adj, visited, in_stack) {
658 return true;
659 }
660 } else if in_stack.contains(nb) {
661 return true;
662 }
663 }
664 }
665 in_stack.remove(node);
666 false
667 }
668}
/// A local session type: one role's view of a global protocol, as produced by
/// `GType::project`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LType {
    /// Send a value of the given type to the given role, then continue.
    Send(Role, BaseType, Box<LType>),
    /// Receive a value of the given type from the given role, then continue.
    Recv(Role, BaseType, Box<LType>),
    /// Internal choice: this role picks a label and tells the given role.
    /// `GType::project` produces the branch list sorted by label.
    IChoice(Role, Vec<(String, LType)>),
    /// External choice: the given role picks a label and this role follows it.
    /// `GType::project` produces the branch list sorted by label.
    EChoice(Role, Vec<(String, LType)>),
    /// Nothing left to do.
    End,
    /// `rec X. T` — binds the recursion variable `X` in the body.
    Rec(String, Box<LType>),
    /// A recursion variable.
    Var(String),
}
687#[derive(Debug, Clone)]
689pub enum SessionOp {
690 Send(BaseType),
692 Recv(BaseType),
694 SelectLeft,
696 SelectRight,
698 Close,
700}
701impl SessionOp {
702 pub fn check_step(&self, stype: SType) -> Result<SType, String> {
704 match (self, &stype) {
705 (SessionOp::Send(t), SType::Send(expected, cont)) => {
706 if t == expected.as_ref() {
707 Ok(*cont.clone())
708 } else {
709 Err(format!(
710 "Type mismatch: sent {:?} but expected {:?}",
711 t, expected
712 ))
713 }
714 }
715 (SessionOp::Recv(t), SType::Recv(expected, cont)) => {
716 if t == expected.as_ref() {
717 Ok(*cont.clone())
718 } else {
719 Err(format!(
720 "Type mismatch: recv {:?} but expected {:?}",
721 t, expected
722 ))
723 }
724 }
725 (SessionOp::SelectLeft, SType::Choice(left, _)) => Ok(*left.clone()),
726 (SessionOp::SelectRight, SType::Choice(_, right)) => Ok(*right.clone()),
727 (SessionOp::Close, SType::End) => Ok(SType::End),
728 _ => Err(format!(
729 "Operation {:?} incompatible with session type {}",
730 self, stype
731 )),
732 }
733 }
734}
/// Outcome of one check performed by `GradualSessionMonitor`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MonitorResult {
    /// The action matched the expected type exactly.
    Ok,
    /// The action was accepted but needed a cast; carries the cast description.
    CastInserted(String),
    /// The action does not fit the expected session type; carries the reason.
    Failure(String),
}
745pub struct SessionChecker {
747 channels: HashMap<String, SType>,
749}
750impl SessionChecker {
751 pub fn new() -> Self {
753 SessionChecker {
754 channels: HashMap::new(),
755 }
756 }
757 pub fn register_channel(&mut self, name: impl Into<String>, stype: SType) {
759 self.channels.insert(name.into(), stype);
760 }
761 pub fn check_usage(&self, channel: &str, ops: &[SessionOp]) -> Result<SType, String> {
764 let stype = self
765 .channels
766 .get(channel)
767 .ok_or_else(|| format!("Unknown channel: {}", channel))?;
768 let mut current = stype.clone();
769 for op in ops {
770 current = op.check_step(current)?;
771 }
772 Ok(current)
773 }
774}
775pub struct GradualSessionMonitor {
780 expected: SType,
782 pub violations: Vec<String>,
784 pub casts: Vec<String>,
786}
787impl GradualSessionMonitor {
788 pub fn new(expected: SType) -> Self {
790 GradualSessionMonitor {
791 expected,
792 violations: vec![],
793 casts: vec![],
794 }
795 }
796 pub fn check_send(&mut self, actual_ty: &BaseType) -> MonitorResult {
800 match self.expected.clone() {
801 SType::Send(expected_ty, cont) => {
802 self.expected = *cont;
803 if actual_ty == expected_ty.as_ref() {
804 MonitorResult::Ok
805 } else {
806 let msg = format!("cast {:?} → {:?}", actual_ty, expected_ty);
807 self.casts.push(msg.clone());
808 MonitorResult::CastInserted(msg)
809 }
810 }
811 SType::Var(ref s) if s == "?" => {
812 let msg = format!("dynamic send {:?}", actual_ty);
813 self.casts.push(msg.clone());
814 MonitorResult::CastInserted(msg)
815 }
816 other => {
817 let msg = format!("expected Send, got {}", other);
818 self.violations.push(msg.clone());
819 MonitorResult::Failure(msg)
820 }
821 }
822 }
823 pub fn check_recv(&mut self, actual_ty: &BaseType) -> MonitorResult {
825 match self.expected.clone() {
826 SType::Recv(expected_ty, cont) => {
827 self.expected = *cont;
828 if actual_ty == expected_ty.as_ref() {
829 MonitorResult::Ok
830 } else {
831 let msg = format!("cast {:?} → {:?}", actual_ty, expected_ty);
832 self.casts.push(msg.clone());
833 MonitorResult::CastInserted(msg)
834 }
835 }
836 SType::Var(ref s) if s == "?" => {
837 let msg = format!("dynamic recv {:?}", actual_ty);
838 self.casts.push(msg.clone());
839 MonitorResult::CastInserted(msg)
840 }
841 other => {
842 let msg = format!("expected Recv, got {}", other);
843 self.violations.push(msg.clone());
844 MonitorResult::Failure(msg)
845 }
846 }
847 }
848 pub fn is_safe(&self) -> bool {
850 self.violations.is_empty()
851 }
852}