rumpsteak_types/
local.rs

1//! Local Session Types for Multiparty Protocols
2//!
3//! This module defines local types that describe protocols from a single participant's
4//! perspective. Local types are obtained by projecting global types onto specific roles.
5//!
6//! Based on: "A Very Gentle Introduction to Multiparty Session Types" (Yoshida & Gheri)
7//!
8//! # Lean Correspondence
9//!
10//! The core `LocalTypeR` enum mirrors `lean/Rumpsteak/Protocol/LocalTypeR.lean`:
11//! - `LocalTypeR::End` ↔ Lean's `LocalTypeR.end`
12//! - `LocalTypeR::Send` ↔ Lean's `LocalTypeR.send`
13//! - `LocalTypeR::Recv` ↔ Lean's `LocalTypeR.recv`
14//! - `LocalTypeR::Mu` ↔ Lean's `LocalTypeR.mu`
15//! - `LocalTypeR::Var` ↔ Lean's `LocalTypeR.var`
16
17use crate::Label;
18use serde::{Deserialize, Serialize};
19use std::collections::HashSet;
20
21/// Core local type matching Lean's `LocalTypeR`.
22///
23/// This is the minimal type used for validation and correspondence proofs.
24/// For code generation, see the extended `LocalType` in the DSL crate.
25///
26/// # Examples
27///
28/// ```
29/// use rumpsteak_types::{LocalTypeR, Label};
30///
31/// // Simple send: !B{msg.end}
32/// let lt = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
33/// assert!(lt.well_formed());
34///
35/// // Recursive type: μt. !B{msg.t}
36/// let rec = LocalTypeR::mu(
37///     "t",
38///     LocalTypeR::send("B", Label::new("msg"), LocalTypeR::var("t")),
39/// );
40/// assert!(rec.well_formed());
41/// ```
42#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
43pub enum LocalTypeR {
44    /// Protocol termination
45    End,
46    /// Internal choice: send to partner with choice of continuations
47    Send {
48        partner: String,
49        branches: Vec<(Label, LocalTypeR)>,
50    },
51    /// External choice: receive from partner with offered continuations
52    Recv {
53        partner: String,
54        branches: Vec<(Label, LocalTypeR)>,
55    },
56    /// Recursive type: μt.T binds variable t in body T
57    Mu { var: String, body: Box<LocalTypeR> },
58    /// Type variable: reference to enclosing μ-binder
59    Var(String),
60}
61
62impl LocalTypeR {
63    /// Create a simple send with one label
64    #[must_use]
65    pub fn send(partner: impl Into<String>, label: Label, cont: LocalTypeR) -> Self {
66        LocalTypeR::Send {
67            partner: partner.into(),
68            branches: vec![(label, cont)],
69        }
70    }
71
72    /// Create a send with multiple branches
73    #[must_use]
74    pub fn send_choice(partner: impl Into<String>, branches: Vec<(Label, LocalTypeR)>) -> Self {
75        LocalTypeR::Send {
76            partner: partner.into(),
77            branches,
78        }
79    }
80
81    /// Create a simple recv with one label
82    #[must_use]
83    pub fn recv(partner: impl Into<String>, label: Label, cont: LocalTypeR) -> Self {
84        LocalTypeR::Recv {
85            partner: partner.into(),
86            branches: vec![(label, cont)],
87        }
88    }
89
90    /// Create a recv with multiple branches
91    #[must_use]
92    pub fn recv_choice(partner: impl Into<String>, branches: Vec<(Label, LocalTypeR)>) -> Self {
93        LocalTypeR::Recv {
94            partner: partner.into(),
95            branches,
96        }
97    }
98
99    /// Create a recursive type
100    #[must_use]
101    pub fn mu(var: impl Into<String>, body: LocalTypeR) -> Self {
102        LocalTypeR::Mu {
103            var: var.into(),
104            body: Box::new(body),
105        }
106    }
107
108    /// Create a type variable
109    #[must_use]
110    pub fn var(name: impl Into<String>) -> Self {
111        LocalTypeR::Var(name.into())
112    }
113
114    /// Extract free type variables from a local type.
115    ///
116    /// Corresponds to Lean's `LocalTypeR.freeVars`.
117    #[must_use]
118    pub fn free_vars(&self) -> Vec<String> {
119        let mut result = Vec::new();
120        let mut bound = HashSet::new();
121        self.collect_free_vars(&mut result, &mut bound);
122        result
123    }
124
125    fn collect_free_vars(&self, free: &mut Vec<String>, bound: &mut HashSet<String>) {
126        match self {
127            LocalTypeR::End => {}
128            LocalTypeR::Send { branches, .. } | LocalTypeR::Recv { branches, .. } => {
129                for (_, cont) in branches {
130                    cont.collect_free_vars(free, bound);
131                }
132            }
133            LocalTypeR::Mu { var, body } => {
134                bound.insert(var.clone());
135                body.collect_free_vars(free, bound);
136                bound.remove(var);
137            }
138            LocalTypeR::Var(t) => {
139                if !bound.contains(t) && !free.contains(t) {
140                    free.push(t.clone());
141                }
142            }
143        }
144    }
145
146    /// Substitute a local type for a variable.
147    ///
148    /// Corresponds to Lean's `LocalTypeR.substitute`.
149    #[must_use]
150    pub fn substitute(&self, var_name: &str, replacement: &LocalTypeR) -> LocalTypeR {
151        match self {
152            LocalTypeR::End => LocalTypeR::End,
153            LocalTypeR::Send { partner, branches } => LocalTypeR::Send {
154                partner: partner.clone(),
155                branches: branches
156                    .iter()
157                    .map(|(l, cont)| (l.clone(), cont.substitute(var_name, replacement)))
158                    .collect(),
159            },
160            LocalTypeR::Recv { partner, branches } => LocalTypeR::Recv {
161                partner: partner.clone(),
162                branches: branches
163                    .iter()
164                    .map(|(l, cont)| (l.clone(), cont.substitute(var_name, replacement)))
165                    .collect(),
166            },
167            LocalTypeR::Mu { var, body } => {
168                if var == var_name {
169                    // Variable is shadowed by this binder
170                    LocalTypeR::Mu {
171                        var: var.clone(),
172                        body: body.clone(),
173                    }
174                } else {
175                    LocalTypeR::Mu {
176                        var: var.clone(),
177                        body: Box::new(body.substitute(var_name, replacement)),
178                    }
179                }
180            }
181            LocalTypeR::Var(t) => {
182                if t == var_name {
183                    replacement.clone()
184                } else {
185                    LocalTypeR::Var(t.clone())
186                }
187            }
188        }
189    }
190
191    /// Unfold one level of recursion: μt.T ↦ T[μt.T/t]
192    ///
193    /// Corresponds to Lean's `LocalTypeR.unfold`.
194    #[must_use]
195    pub fn unfold(&self) -> LocalTypeR {
196        match self {
197            LocalTypeR::Mu { var, body } => body.substitute(var, self),
198            _ => self.clone(),
199        }
200    }
201
202    /// Compute the dual of a local type (swap send/recv).
203    ///
204    /// The dual of role A's view is role B's view when A and B are the only participants.
205    /// Corresponds to Lean's `LocalTypeR.dual`.
206    #[must_use]
207    pub fn dual(&self) -> LocalTypeR {
208        match self {
209            LocalTypeR::End => LocalTypeR::End,
210            LocalTypeR::Send { partner, branches } => LocalTypeR::Recv {
211                partner: partner.clone(),
212                branches: branches
213                    .iter()
214                    .map(|(l, cont)| (l.clone(), cont.dual()))
215                    .collect(),
216            },
217            LocalTypeR::Recv { partner, branches } => LocalTypeR::Send {
218                partner: partner.clone(),
219                branches: branches
220                    .iter()
221                    .map(|(l, cont)| (l.clone(), cont.dual()))
222                    .collect(),
223            },
224            LocalTypeR::Mu { var, body } => LocalTypeR::Mu {
225                var: var.clone(),
226                body: Box::new(body.dual()),
227            },
228            LocalTypeR::Var(t) => LocalTypeR::Var(t.clone()),
229        }
230    }
231
232    /// Check if all recursion variables are bound.
233    ///
234    /// Corresponds to Lean's `LocalTypeR.allVarsBound`.
235    #[must_use]
236    pub fn all_vars_bound(&self) -> bool {
237        self.check_vars_bound(&HashSet::new())
238    }
239
240    fn check_vars_bound(&self, bound: &HashSet<String>) -> bool {
241        match self {
242            LocalTypeR::End => true,
243            LocalTypeR::Send { branches, .. } | LocalTypeR::Recv { branches, .. } => branches
244                .iter()
245                .all(|(_, cont)| cont.check_vars_bound(bound)),
246            LocalTypeR::Mu { var, body } => {
247                let mut new_bound = bound.clone();
248                new_bound.insert(var.clone());
249                body.check_vars_bound(&new_bound)
250            }
251            LocalTypeR::Var(t) => bound.contains(t),
252        }
253    }
254
255    /// Check if each choice has at least one branch.
256    ///
257    /// Corresponds to Lean's `LocalTypeR.allChoicesNonEmpty`.
258    #[must_use]
259    pub fn all_choices_non_empty(&self) -> bool {
260        match self {
261            LocalTypeR::End => true,
262            LocalTypeR::Send { branches, .. } | LocalTypeR::Recv { branches, .. } => {
263                !branches.is_empty()
264                    && branches
265                        .iter()
266                        .all(|(_, cont)| cont.all_choices_non_empty())
267            }
268            LocalTypeR::Mu { body, .. } => body.all_choices_non_empty(),
269            LocalTypeR::Var(_) => true,
270        }
271    }
272
273    /// Well-formedness predicate for local types.
274    ///
275    /// Corresponds to Lean's `LocalTypeR.wellFormed`.
276    /// A local type is well-formed if:
277    /// 1. All recursion variables are bound
278    /// 2. Each choice has at least one branch
279    /// 3. All recursion is guarded (no immediate recursion without communication)
280    #[must_use]
281    pub fn well_formed(&self) -> bool {
282        self.all_vars_bound() && self.all_choices_non_empty() && self.is_guarded()
283    }
284
285    /// Count the depth of a local type (for termination proofs).
286    ///
287    /// Corresponds to Lean's `LocalTypeR.depth`.
288    #[must_use]
289    pub fn depth(&self) -> usize {
290        match self {
291            LocalTypeR::End => 0,
292            LocalTypeR::Send { branches, .. } | LocalTypeR::Recv { branches, .. } => {
293                1 + branches.iter().map(|(_, t)| t.depth()).max().unwrap_or(0)
294            }
295            LocalTypeR::Mu { body, .. } => 1 + body.depth(),
296            LocalTypeR::Var(_) => 0,
297        }
298    }
299
300    /// Check if a local type is guarded (no immediate recursion).
301    ///
302    /// Corresponds to Lean's `LocalTypeR.isGuarded`.
303    #[must_use]
304    pub fn is_guarded(&self) -> bool {
305        match self {
306            LocalTypeR::End => true,
307            LocalTypeR::Send { branches, .. } | LocalTypeR::Recv { branches, .. } => {
308                branches.iter().all(|(_, cont)| cont.is_guarded())
309            }
310            LocalTypeR::Mu { body, .. } => match body.as_ref() {
311                LocalTypeR::Var(_) | LocalTypeR::Mu { .. } => false,
312                _ => body.is_guarded(),
313            },
314            LocalTypeR::Var(_) => true,
315        }
316    }
317
318    /// Extract all labels from a local type.
319    ///
320    /// Corresponds to Lean's `LocalTypeR.labels`.
321    #[must_use]
322    pub fn labels(&self) -> Vec<String> {
323        match self {
324            LocalTypeR::End | LocalTypeR::Var(_) => vec![],
325            LocalTypeR::Send { branches, .. } | LocalTypeR::Recv { branches, .. } => {
326                branches.iter().map(|(l, _)| l.name.clone()).collect()
327            }
328            LocalTypeR::Mu { body, .. } => body.labels(),
329        }
330    }
331
332    /// Extract all partners from a local type.
333    ///
334    /// Corresponds to Lean's `LocalTypeR.partners`.
335    #[must_use]
336    pub fn partners(&self) -> Vec<String> {
337        let mut result = HashSet::new();
338        self.collect_partners(&mut result);
339        result.into_iter().collect()
340    }
341
342    fn collect_partners(&self, partners: &mut HashSet<String>) {
343        match self {
344            LocalTypeR::End | LocalTypeR::Var(_) => {}
345            LocalTypeR::Send { partner, branches } | LocalTypeR::Recv { partner, branches } => {
346                partners.insert(partner.clone());
347                for (_, cont) in branches {
348                    cont.collect_partners(partners);
349                }
350            }
351            LocalTypeR::Mu { body, .. } => body.collect_partners(partners),
352        }
353    }
354
355    /// Check if a local type mentions a specific partner.
356    #[must_use]
357    pub fn mentions_partner(&self, role: &str) -> bool {
358        self.partners().contains(&role.to_string())
359    }
360
361    /// Check if this is an internal choice (send)
362    #[must_use]
363    pub fn is_send(&self) -> bool {
364        matches!(self, LocalTypeR::Send { .. })
365    }
366
367    /// Check if this is an external choice (recv)
368    #[must_use]
369    pub fn is_recv(&self) -> bool {
370        matches!(self, LocalTypeR::Recv { .. })
371    }
372
373    /// Check if this is a terminated type
374    #[must_use]
375    pub fn is_end(&self) -> bool {
376        matches!(self, LocalTypeR::End)
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PayloadSort;
    use assert_matches::assert_matches;

    /// Builds the recurring fixture `μt. !B{msg.t}`.
    fn msg_loop() -> LocalTypeR {
        let step = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::var("t"));
        LocalTypeR::mu("t", step)
    }

    #[test]
    fn test_simple_local_type() {
        // A single send and nothing else: !B{msg.end}
        let one_shot = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        assert!(one_shot.well_formed());
        assert_eq!(one_shot.partners().len(), 1);
        assert!(one_shot.mentions_partner("B"));
    }

    #[test]
    fn test_recursive_local_type() {
        let looping = msg_loop();
        assert!(looping.well_formed());
        assert!(looping.all_vars_bound());
    }

    #[test]
    fn test_dual() {
        // Flipping !B{msg.end} must give ?B{msg.end}.
        let flipped = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End).dual();

        assert_matches!(flipped, LocalTypeR::Recv { partner, branches } => {
            assert_eq!(partner, "B");
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].0.name, "msg");
        });
    }

    #[test]
    fn test_unfold() {
        // One unrolling of μt. !B{msg.t} is !B{msg.(μt. !B{msg.t})}.
        let unrolled = msg_loop().unfold();

        assert_matches!(unrolled, LocalTypeR::Send { partner, branches } => {
            assert_eq!(partner, "B");
            assert_matches!(branches[0].1, LocalTypeR::Mu { .. });
        });
    }

    #[test]
    fn test_substitute() {
        let t = LocalTypeR::var("t");
        let end = LocalTypeR::End;
        // Replacing the matching variable yields the replacement...
        assert_eq!(t.substitute("t", &end), end);
        // ...while a different variable name leaves the type as it was.
        assert_eq!(t.substitute("s", &end), t);
    }

    #[test]
    fn test_unbound_variable() {
        // !B{msg.t} with no enclosing μt.
        let dangling = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::var("t"));
        assert!(!dangling.all_vars_bound());
        assert!(!dangling.well_formed());
    }

    #[test]
    fn test_guarded() {
        // μt. t recurses without communicating, so it must be rejected.
        let spin = LocalTypeR::mu("t", LocalTypeR::var("t"));
        assert!(!spin.is_guarded());
        assert!(!spin.well_formed());

        // μt. !B{msg.t} sends before recursing, so it must be accepted.
        let guarded = msg_loop();
        assert!(guarded.is_guarded());
        assert!(guarded.well_formed());
    }

    #[test]
    fn test_free_vars() {
        // In μt. !B{msg.s} only s escapes the binder.
        let body = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::var("s"));
        let open = LocalTypeR::mu("t", body);
        assert_eq!(open.free_vars(), vec!["s"]);
    }

    #[test]
    fn test_choice_with_payload() {
        let accept = (
            Label::with_sort("accept", PayloadSort::Bool),
            LocalTypeR::End,
        );
        let data = (Label::with_sort("data", PayloadSort::Nat), LocalTypeR::End);
        let offer = LocalTypeR::recv_choice("A", vec![accept, data]);

        assert!(offer.well_formed());
        assert_eq!(offer.labels(), vec!["accept", "data"]);
    }

    #[test]
    fn test_depth() {
        // Two nested sends: !B{outer.!C{inner.end}}
        let inner = LocalTypeR::send("C", Label::new("inner"), LocalTypeR::End);
        let outer = LocalTypeR::send("B", Label::new("outer"), inner);
        assert_eq!(outer.depth(), 2);
    }

    #[test]
    fn test_is_send_recv() {
        let outgoing = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let incoming = LocalTypeR::recv("B", Label::new("msg"), LocalTypeR::End);

        assert!(outgoing.is_send() && !outgoing.is_recv());
        assert!(incoming.is_recv() && !incoming.is_send());
    }
}