rumpsteak_types/
global.rs

1//! Global Types for Multiparty Session Type Protocols
2//!
3//! This module defines global types that describe protocols from a bird's-eye view.
4//! Global types specify the complete interaction pattern between all participants,
5//! including message exchanges, choices, and recursive behavior.
6//!
7//! Based on: "A Very Gentle Introduction to Multiparty Session Types" (Yoshida & Gheri)
8//!
9//! # Lean Correspondence
10//!
11//! This module mirrors the definitions in `lean/Rumpsteak/Protocol/GlobalType.lean`:
12//! - `PayloadSort` ↔ Lean's `PayloadSort`
13//! - `GlobalType` ↔ Lean's `GlobalType`
14
15use crate::Label;
16use serde::{Deserialize, Serialize};
17use std::collections::HashSet;
18
19/// Payload sort types for message payloads.
20///
21/// Corresponds to Lean's `PayloadSort` inductive type.
22/// These represent the data types that can be sent in messages.
23///
24/// # Examples
25///
26/// ```
27/// use rumpsteak_types::PayloadSort;
28///
29/// let unit = PayloadSort::Unit;
30/// assert!(unit.is_simple());
31///
32/// let pair = PayloadSort::prod(PayloadSort::Nat, PayloadSort::Bool);
33/// assert!(!pair.is_simple());
34/// ```
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
36pub enum PayloadSort {
37    /// Unit type (no payload)
38    #[default]
39    Unit,
40    /// Natural numbers
41    Nat,
42    /// Booleans
43    Bool,
44    /// Strings
45    String,
46    /// Product types (pairs)
47    Prod(Box<PayloadSort>, Box<PayloadSort>),
48}
49
50impl PayloadSort {
51    /// Create a product sort
52    #[must_use]
53    pub fn prod(left: PayloadSort, right: PayloadSort) -> Self {
54        PayloadSort::Prod(Box::new(left), Box::new(right))
55    }
56
57    /// Check if this is a simple (non-product) sort
58    #[must_use]
59    pub fn is_simple(&self) -> bool {
60        !matches!(self, PayloadSort::Prod(_, _))
61    }
62}
63
64impl std::fmt::Display for PayloadSort {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        match self {
67            PayloadSort::Unit => write!(f, "Unit"),
68            PayloadSort::Nat => write!(f, "Nat"),
69            PayloadSort::Bool => write!(f, "Bool"),
70            PayloadSort::String => write!(f, "String"),
71            PayloadSort::Prod(l, r) => write!(f, "({} × {})", l, r),
72        }
73    }
74}
75
76/// Global types describe protocols from the bird's-eye view.
77///
78/// Corresponds to Lean's `GlobalType` inductive type.
79///
80/// # Syntax
81///
82/// - `End`: Protocol termination
83/// - `Comm { sender, receiver, branches }`: Communication with labeled branches
84/// - `Mu { var, body }`: Recursive type μt.G
85/// - `Var(t)`: Type variable reference
86///
87/// The `Comm` variant models `p → q : {l₁(S₁).G₁, l₂(S₂).G₂, ...}`
88/// where the sender p chooses which branch to take.
89///
90/// # Examples
91///
92/// ```
93/// use rumpsteak_types::{GlobalType, Label};
94///
95/// // Simple protocol: A -> B: hello. end
96/// let g = GlobalType::send("A", "B", Label::new("hello"), GlobalType::End);
97/// assert!(g.well_formed());
98///
99/// // Recursive protocol: μt. A -> B: msg. t
100/// let rec = GlobalType::mu(
101///     "t",
102///     GlobalType::send("A", "B", Label::new("msg"), GlobalType::var("t")),
103/// );
104/// assert!(rec.well_formed());
105/// ```
106#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
107pub enum GlobalType {
108    /// Protocol termination
109    End,
110    /// Communication: sender → receiver with choice of labeled continuations
111    Comm {
112        sender: String,
113        receiver: String,
114        branches: Vec<(Label, GlobalType)>,
115    },
116    /// Recursive type: μt.G binds variable t in body G
117    Mu { var: String, body: Box<GlobalType> },
118    /// Type variable: reference to enclosing μ-binder
119    Var(String),
120}
121
122impl GlobalType {
123    /// Create a simple send without choice
124    #[must_use]
125    pub fn send(
126        sender: impl Into<String>,
127        receiver: impl Into<String>,
128        label: Label,
129        cont: GlobalType,
130    ) -> Self {
131        GlobalType::Comm {
132            sender: sender.into(),
133            receiver: receiver.into(),
134            branches: vec![(label, cont)],
135        }
136    }
137
138    /// Create a communication with multiple branches
139    #[must_use]
140    pub fn comm(
141        sender: impl Into<String>,
142        receiver: impl Into<String>,
143        branches: Vec<(Label, GlobalType)>,
144    ) -> Self {
145        GlobalType::Comm {
146            sender: sender.into(),
147            receiver: receiver.into(),
148            branches,
149        }
150    }
151
152    /// Create a recursive type
153    #[must_use]
154    pub fn mu(var: impl Into<String>, body: GlobalType) -> Self {
155        GlobalType::Mu {
156            var: var.into(),
157            body: Box::new(body),
158        }
159    }
160
161    /// Create a type variable
162    #[must_use]
163    pub fn var(name: impl Into<String>) -> Self {
164        GlobalType::Var(name.into())
165    }
166
167    /// Extract all role names from this global type.
168    ///
169    /// Corresponds to Lean's `GlobalType.roles`.
170    #[must_use]
171    pub fn roles(&self) -> Vec<String> {
172        let mut result = HashSet::new();
173        self.collect_roles(&mut result);
174        result.into_iter().collect()
175    }
176
177    fn collect_roles(&self, roles: &mut HashSet<String>) {
178        match self {
179            GlobalType::End => {}
180            GlobalType::Comm {
181                sender,
182                receiver,
183                branches,
184            } => {
185                roles.insert(sender.clone());
186                roles.insert(receiver.clone());
187                for (_, cont) in branches {
188                    cont.collect_roles(roles);
189                }
190            }
191            GlobalType::Mu { body, .. } => body.collect_roles(roles),
192            GlobalType::Var(_) => {}
193        }
194    }
195
196    /// Extract free type variables from this global type.
197    ///
198    /// Corresponds to Lean's `GlobalType.freeVars`.
199    #[must_use]
200    pub fn free_vars(&self) -> Vec<String> {
201        let mut result = Vec::new();
202        let mut bound = HashSet::new();
203        self.collect_free_vars(&mut result, &mut bound);
204        result
205    }
206
207    fn collect_free_vars(&self, free: &mut Vec<String>, bound: &mut HashSet<String>) {
208        match self {
209            GlobalType::End => {}
210            GlobalType::Comm { branches, .. } => {
211                for (_, cont) in branches {
212                    cont.collect_free_vars(free, bound);
213                }
214            }
215            GlobalType::Mu { var, body } => {
216                bound.insert(var.clone());
217                body.collect_free_vars(free, bound);
218                bound.remove(var);
219            }
220            GlobalType::Var(t) => {
221                if !bound.contains(t) && !free.contains(t) {
222                    free.push(t.clone());
223                }
224            }
225        }
226    }
227
228    /// Substitute a global type for a variable.
229    ///
230    /// Corresponds to Lean's `GlobalType.substitute`.
231    #[must_use]
232    pub fn substitute(&self, var_name: &str, replacement: &GlobalType) -> GlobalType {
233        match self {
234            GlobalType::End => GlobalType::End,
235            GlobalType::Comm {
236                sender,
237                receiver,
238                branches,
239            } => GlobalType::Comm {
240                sender: sender.clone(),
241                receiver: receiver.clone(),
242                branches: branches
243                    .iter()
244                    .map(|(l, cont)| (l.clone(), cont.substitute(var_name, replacement)))
245                    .collect(),
246            },
247            GlobalType::Mu { var, body } => {
248                if var == var_name {
249                    // Variable is shadowed by this binder
250                    GlobalType::Mu {
251                        var: var.clone(),
252                        body: body.clone(),
253                    }
254                } else {
255                    GlobalType::Mu {
256                        var: var.clone(),
257                        body: Box::new(body.substitute(var_name, replacement)),
258                    }
259                }
260            }
261            GlobalType::Var(t) => {
262                if t == var_name {
263                    replacement.clone()
264                } else {
265                    GlobalType::Var(t.clone())
266                }
267            }
268        }
269    }
270
271    /// Unfold one level of recursion: μt.G ↦ G[μt.G/t]
272    ///
273    /// Corresponds to Lean's `GlobalType.unfold`.
274    #[must_use]
275    pub fn unfold(&self) -> GlobalType {
276        match self {
277            GlobalType::Mu { var, body } => body.substitute(var, self),
278            _ => self.clone(),
279        }
280    }
281
282    /// Check if all recursion variables are bound.
283    ///
284    /// Corresponds to Lean's `GlobalType.allVarsBound`.
285    #[must_use]
286    pub fn all_vars_bound(&self) -> bool {
287        self.check_vars_bound(&HashSet::new())
288    }
289
290    fn check_vars_bound(&self, bound: &HashSet<String>) -> bool {
291        match self {
292            GlobalType::End => true,
293            GlobalType::Comm { branches, .. } => branches
294                .iter()
295                .all(|(_, cont)| cont.check_vars_bound(bound)),
296            GlobalType::Mu { var, body } => {
297                let mut new_bound = bound.clone();
298                new_bound.insert(var.clone());
299                body.check_vars_bound(&new_bound)
300            }
301            GlobalType::Var(t) => bound.contains(t),
302        }
303    }
304
305    /// Check if each communication has at least one branch.
306    ///
307    /// Corresponds to Lean's `GlobalType.allCommsNonEmpty`.
308    #[must_use]
309    pub fn all_comms_non_empty(&self) -> bool {
310        match self {
311            GlobalType::End => true,
312            GlobalType::Comm { branches, .. } => {
313                !branches.is_empty() && branches.iter().all(|(_, cont)| cont.all_comms_non_empty())
314            }
315            GlobalType::Mu { body, .. } => body.all_comms_non_empty(),
316            GlobalType::Var(_) => true,
317        }
318    }
319
320    /// Check if sender and receiver are different in each communication.
321    ///
322    /// Corresponds to Lean's `GlobalType.noSelfComm`.
323    #[must_use]
324    pub fn no_self_comm(&self) -> bool {
325        match self {
326            GlobalType::End => true,
327            GlobalType::Comm {
328                sender,
329                receiver,
330                branches,
331            } => sender != receiver && branches.iter().all(|(_, cont)| cont.no_self_comm()),
332            GlobalType::Mu { body, .. } => body.no_self_comm(),
333            GlobalType::Var(_) => true,
334        }
335    }
336
337    /// Well-formedness predicate for global types.
338    ///
339    /// Corresponds to Lean's `GlobalType.wellFormed`.
340    /// A global type is well-formed if:
341    /// 1. All recursion variables are bound
342    /// 2. Each communication has at least one branch
343    /// 3. Sender ≠ receiver in each communication
344    /// 4. All recursion is guarded (no immediate recursion without communication)
345    #[must_use]
346    pub fn well_formed(&self) -> bool {
347        self.all_vars_bound()
348            && self.all_comms_non_empty()
349            && self.no_self_comm()
350            && self.is_guarded()
351    }
352
353    /// Check if a role participates in the global type.
354    ///
355    /// Corresponds to Lean's `GlobalType.mentionsRole`.
356    #[must_use]
357    pub fn mentions_role(&self, role: &str) -> bool {
358        self.roles().contains(&role.to_string())
359    }
360
361    /// Count the depth of a global type (for termination proofs).
362    ///
363    /// Corresponds to Lean's `GlobalType.depth`.
364    #[must_use]
365    pub fn depth(&self) -> usize {
366        match self {
367            GlobalType::End => 0,
368            GlobalType::Comm { branches, .. } => {
369                1 + branches.iter().map(|(_, g)| g.depth()).max().unwrap_or(0)
370            }
371            GlobalType::Mu { body, .. } => 1 + body.depth(),
372            GlobalType::Var(_) => 0,
373        }
374    }
375
376    /// Check if a global type is guarded (no immediate recursion without communication).
377    ///
378    /// Corresponds to Lean's `GlobalType.isGuarded`.
379    #[must_use]
380    pub fn is_guarded(&self) -> bool {
381        match self {
382            GlobalType::End => true,
383            GlobalType::Comm { branches, .. } => branches.iter().all(|(_, cont)| cont.is_guarded()),
384            GlobalType::Mu { body, .. } => match body.as_ref() {
385                GlobalType::Var(_) | GlobalType::Mu { .. } => false,
386                _ => body.is_guarded(),
387            },
388            GlobalType::Var(_) => true,
389        }
390    }
391
392    /// Consume a communication from a global type.
393    ///
394    /// Corresponds to Lean's `GlobalType.consume`.
395    /// G \ p →ℓ q represents the global type after the communication p →ℓ q
396    /// has been performed.
397    #[must_use]
398    pub fn consume(&self, sender: &str, receiver: &str, label: &str) -> Option<GlobalType> {
399        match self {
400            GlobalType::Comm {
401                sender: s,
402                receiver: r,
403                branches,
404            } => {
405                if s == sender && r == receiver {
406                    branches
407                        .iter()
408                        .find(|(l, _)| l.name == label)
409                        .map(|(_, cont)| cont.clone())
410                } else {
411                    None
412                }
413            }
414            GlobalType::Mu { var, body } => {
415                // Unfold and try again
416                body.substitute(var, self).consume(sender, receiver, label)
417            }
418            _ => None,
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    /// `μt. A → B : msg. t` — the guarded loop shared by several tests.
    fn msg_loop() -> GlobalType {
        GlobalType::mu(
            "t",
            GlobalType::send("A", "B", Label::new("msg"), GlobalType::var("t")),
        )
    }

    #[test]
    fn test_simple_protocol() {
        // A → B : hello. end
        let hello = GlobalType::send("A", "B", Label::new("hello"), GlobalType::End);
        assert!(hello.well_formed());
        assert_eq!(hello.roles().len(), 2);
        for role in ["A", "B"] {
            assert!(hello.mentions_role(role));
        }
    }

    #[test]
    fn test_recursive_protocol() {
        let looped = msg_loop();
        assert!(looped.well_formed());
        assert!(looped.all_vars_bound());
    }

    #[test]
    fn test_unbound_variable() {
        // A → B : msg. t, where nothing binds t.
        let dangling = GlobalType::send("A", "B", Label::new("msg"), GlobalType::var("t"));
        assert!(!dangling.all_vars_bound());
        assert!(!dangling.well_formed());
    }

    #[test]
    fn test_self_communication() {
        // A → A : msg. end, where a role talks to itself.
        let loopback = GlobalType::send("A", "A", Label::new("msg"), GlobalType::End);
        assert!(!loopback.no_self_comm());
        assert!(!loopback.well_formed());
    }

    #[test]
    fn test_unfold() {
        // One unfolding yields A → B : msg. (μt. A → B : msg. t).
        let once = msg_loop().unfold();
        assert_matches!(once, GlobalType::Comm { sender, receiver, branches } => {
            assert_eq!(sender, "A");
            assert_eq!(receiver, "B");
            assert_eq!(branches.len(), 1);
            // The continuation is the recursive type we started from.
            assert_matches!(branches[0].1, GlobalType::Mu { .. });
        });
    }

    #[test]
    fn test_substitute() {
        let t = GlobalType::var("t");
        let end = GlobalType::End;
        assert_eq!(t.substitute("t", &end), GlobalType::End);
        assert_eq!(t.substitute("s", &end), GlobalType::var("t"));
    }

    #[test]
    fn test_consume() {
        let decision = GlobalType::comm(
            "A",
            "B",
            vec![
                (Label::new("accept"), GlobalType::End),
                (Label::new("reject"), GlobalType::End),
            ],
        );

        for label in ["accept", "reject"] {
            assert_eq!(decision.consume("A", "B", label), Some(GlobalType::End));
        }
        assert_eq!(decision.consume("A", "B", "unknown"), None);
        assert_eq!(decision.consume("B", "A", "accept"), None);
    }

    #[test]
    fn test_payload_sort() {
        let pair = PayloadSort::prod(PayloadSort::Nat, PayloadSort::Bool);
        assert!(!pair.is_simple());

        let data = Label::with_sort("data", pair);
        assert_eq!(data.name, "data");
    }

    #[test]
    fn test_guarded() {
        // μt. t recurses immediately, with no communication in between.
        let spin = GlobalType::mu("t", GlobalType::var("t"));
        assert!(!spin.is_guarded());
        assert!(!spin.well_formed());

        // μt. A → B : msg. t passes through a communication on every iteration.
        let looped = msg_loop();
        assert!(looped.is_guarded());
        assert!(looped.well_formed());
    }
}