1use crate::Label;
16use serde::{Deserialize, Serialize};
17use std::collections::HashSet;
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
36pub enum PayloadSort {
37 #[default]
39 Unit,
40 Nat,
42 Bool,
44 String,
46 Prod(Box<PayloadSort>, Box<PayloadSort>),
48}
49
50impl PayloadSort {
51 #[must_use]
53 pub fn prod(left: PayloadSort, right: PayloadSort) -> Self {
54 PayloadSort::Prod(Box::new(left), Box::new(right))
55 }
56
57 #[must_use]
59 pub fn is_simple(&self) -> bool {
60 !matches!(self, PayloadSort::Prod(_, _))
61 }
62}
63
64impl std::fmt::Display for PayloadSort {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 PayloadSort::Unit => write!(f, "Unit"),
68 PayloadSort::Nat => write!(f, "Nat"),
69 PayloadSort::Bool => write!(f, "Bool"),
70 PayloadSort::String => write!(f, "String"),
71 PayloadSort::Prod(l, r) => write!(f, "({} × {})", l, r),
72 }
73 }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
107pub enum GlobalType {
108 End,
110 Comm {
112 sender: String,
113 receiver: String,
114 branches: Vec<(Label, GlobalType)>,
115 },
116 Mu { var: String, body: Box<GlobalType> },
118 Var(String),
120}
121
122impl GlobalType {
123 #[must_use]
125 pub fn send(
126 sender: impl Into<String>,
127 receiver: impl Into<String>,
128 label: Label,
129 cont: GlobalType,
130 ) -> Self {
131 GlobalType::Comm {
132 sender: sender.into(),
133 receiver: receiver.into(),
134 branches: vec![(label, cont)],
135 }
136 }
137
138 #[must_use]
140 pub fn comm(
141 sender: impl Into<String>,
142 receiver: impl Into<String>,
143 branches: Vec<(Label, GlobalType)>,
144 ) -> Self {
145 GlobalType::Comm {
146 sender: sender.into(),
147 receiver: receiver.into(),
148 branches,
149 }
150 }
151
152 #[must_use]
154 pub fn mu(var: impl Into<String>, body: GlobalType) -> Self {
155 GlobalType::Mu {
156 var: var.into(),
157 body: Box::new(body),
158 }
159 }
160
161 #[must_use]
163 pub fn var(name: impl Into<String>) -> Self {
164 GlobalType::Var(name.into())
165 }
166
167 #[must_use]
171 pub fn roles(&self) -> Vec<String> {
172 let mut result = HashSet::new();
173 self.collect_roles(&mut result);
174 result.into_iter().collect()
175 }
176
177 fn collect_roles(&self, roles: &mut HashSet<String>) {
178 match self {
179 GlobalType::End => {}
180 GlobalType::Comm {
181 sender,
182 receiver,
183 branches,
184 } => {
185 roles.insert(sender.clone());
186 roles.insert(receiver.clone());
187 for (_, cont) in branches {
188 cont.collect_roles(roles);
189 }
190 }
191 GlobalType::Mu { body, .. } => body.collect_roles(roles),
192 GlobalType::Var(_) => {}
193 }
194 }
195
196 #[must_use]
200 pub fn free_vars(&self) -> Vec<String> {
201 let mut result = Vec::new();
202 let mut bound = HashSet::new();
203 self.collect_free_vars(&mut result, &mut bound);
204 result
205 }
206
207 fn collect_free_vars(&self, free: &mut Vec<String>, bound: &mut HashSet<String>) {
208 match self {
209 GlobalType::End => {}
210 GlobalType::Comm { branches, .. } => {
211 for (_, cont) in branches {
212 cont.collect_free_vars(free, bound);
213 }
214 }
215 GlobalType::Mu { var, body } => {
216 bound.insert(var.clone());
217 body.collect_free_vars(free, bound);
218 bound.remove(var);
219 }
220 GlobalType::Var(t) => {
221 if !bound.contains(t) && !free.contains(t) {
222 free.push(t.clone());
223 }
224 }
225 }
226 }
227
228 #[must_use]
232 pub fn substitute(&self, var_name: &str, replacement: &GlobalType) -> GlobalType {
233 match self {
234 GlobalType::End => GlobalType::End,
235 GlobalType::Comm {
236 sender,
237 receiver,
238 branches,
239 } => GlobalType::Comm {
240 sender: sender.clone(),
241 receiver: receiver.clone(),
242 branches: branches
243 .iter()
244 .map(|(l, cont)| (l.clone(), cont.substitute(var_name, replacement)))
245 .collect(),
246 },
247 GlobalType::Mu { var, body } => {
248 if var == var_name {
249 GlobalType::Mu {
251 var: var.clone(),
252 body: body.clone(),
253 }
254 } else {
255 GlobalType::Mu {
256 var: var.clone(),
257 body: Box::new(body.substitute(var_name, replacement)),
258 }
259 }
260 }
261 GlobalType::Var(t) => {
262 if t == var_name {
263 replacement.clone()
264 } else {
265 GlobalType::Var(t.clone())
266 }
267 }
268 }
269 }
270
271 #[must_use]
275 pub fn unfold(&self) -> GlobalType {
276 match self {
277 GlobalType::Mu { var, body } => body.substitute(var, self),
278 _ => self.clone(),
279 }
280 }
281
282 #[must_use]
286 pub fn all_vars_bound(&self) -> bool {
287 self.check_vars_bound(&HashSet::new())
288 }
289
290 fn check_vars_bound(&self, bound: &HashSet<String>) -> bool {
291 match self {
292 GlobalType::End => true,
293 GlobalType::Comm { branches, .. } => branches
294 .iter()
295 .all(|(_, cont)| cont.check_vars_bound(bound)),
296 GlobalType::Mu { var, body } => {
297 let mut new_bound = bound.clone();
298 new_bound.insert(var.clone());
299 body.check_vars_bound(&new_bound)
300 }
301 GlobalType::Var(t) => bound.contains(t),
302 }
303 }
304
305 #[must_use]
309 pub fn all_comms_non_empty(&self) -> bool {
310 match self {
311 GlobalType::End => true,
312 GlobalType::Comm { branches, .. } => {
313 !branches.is_empty() && branches.iter().all(|(_, cont)| cont.all_comms_non_empty())
314 }
315 GlobalType::Mu { body, .. } => body.all_comms_non_empty(),
316 GlobalType::Var(_) => true,
317 }
318 }
319
320 #[must_use]
324 pub fn no_self_comm(&self) -> bool {
325 match self {
326 GlobalType::End => true,
327 GlobalType::Comm {
328 sender,
329 receiver,
330 branches,
331 } => sender != receiver && branches.iter().all(|(_, cont)| cont.no_self_comm()),
332 GlobalType::Mu { body, .. } => body.no_self_comm(),
333 GlobalType::Var(_) => true,
334 }
335 }
336
337 #[must_use]
346 pub fn well_formed(&self) -> bool {
347 self.all_vars_bound()
348 && self.all_comms_non_empty()
349 && self.no_self_comm()
350 && self.is_guarded()
351 }
352
353 #[must_use]
357 pub fn mentions_role(&self, role: &str) -> bool {
358 self.roles().contains(&role.to_string())
359 }
360
361 #[must_use]
365 pub fn depth(&self) -> usize {
366 match self {
367 GlobalType::End => 0,
368 GlobalType::Comm { branches, .. } => {
369 1 + branches.iter().map(|(_, g)| g.depth()).max().unwrap_or(0)
370 }
371 GlobalType::Mu { body, .. } => 1 + body.depth(),
372 GlobalType::Var(_) => 0,
373 }
374 }
375
376 #[must_use]
380 pub fn is_guarded(&self) -> bool {
381 match self {
382 GlobalType::End => true,
383 GlobalType::Comm { branches, .. } => branches.iter().all(|(_, cont)| cont.is_guarded()),
384 GlobalType::Mu { body, .. } => match body.as_ref() {
385 GlobalType::Var(_) | GlobalType::Mu { .. } => false,
386 _ => body.is_guarded(),
387 },
388 GlobalType::Var(_) => true,
389 }
390 }
391
392 #[must_use]
398 pub fn consume(&self, sender: &str, receiver: &str, label: &str) -> Option<GlobalType> {
399 match self {
400 GlobalType::Comm {
401 sender: s,
402 receiver: r,
403 branches,
404 } => {
405 if s == sender && r == receiver {
406 branches
407 .iter()
408 .find(|(l, _)| l.name == label)
409 .map(|(_, cont)| cont.clone())
410 } else {
411 None
412 }
413 }
414 GlobalType::Mu { var, body } => {
415 body.substitute(var, self).consume(sender, receiver, label)
417 }
418 _ => None,
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    /// Fixture: `mu t. A -> B : msg . t`.
    fn looping_msg() -> GlobalType {
        let step = GlobalType::send("A", "B", Label::new("msg"), GlobalType::var("t"));
        GlobalType::mu("t", step)
    }

    #[test]
    fn test_simple_protocol() {
        let hello = GlobalType::send("A", "B", Label::new("hello"), GlobalType::End);

        assert!(hello.well_formed());
        assert_eq!(hello.roles().len(), 2);
        for role in ["A", "B"] {
            assert!(hello.mentions_role(role));
        }
    }

    #[test]
    fn test_recursive_protocol() {
        let recursive = looping_msg();

        assert!(recursive.well_formed());
        assert!(recursive.all_vars_bound());
    }

    #[test]
    fn test_unbound_variable() {
        // `t` is referenced but never bound by a `mu`.
        let dangling = GlobalType::send("A", "B", Label::new("msg"), GlobalType::var("t"));

        assert!(!dangling.all_vars_bound());
        assert!(!dangling.well_formed());
    }

    #[test]
    fn test_self_communication() {
        // Sender and receiver are the same role.
        let to_self = GlobalType::send("A", "A", Label::new("msg"), GlobalType::End);

        assert!(!to_self.no_self_comm());
        assert!(!to_self.well_formed());
    }

    #[test]
    fn test_unfold() {
        let unfolded = looping_msg().unfold();

        assert_matches!(unfolded, GlobalType::Comm { sender, receiver, branches } => {
            assert_eq!(sender, "A");
            assert_eq!(receiver, "B");
            assert_eq!(branches.len(), 1);
            // The variable was replaced by the recursive type itself.
            assert_matches!(branches[0].1, GlobalType::Mu { .. });
        });
    }

    #[test]
    fn test_substitute() {
        let t = GlobalType::var("t");
        let end = GlobalType::End;

        assert_eq!(t.substitute("t", &end), GlobalType::End);
        assert_eq!(t.substitute("s", &end), GlobalType::var("t"));
    }

    #[test]
    fn test_consume() {
        let choice = GlobalType::comm(
            "A",
            "B",
            vec![
                (Label::new("accept"), GlobalType::End),
                (Label::new("reject"), GlobalType::End),
            ],
        );

        // (sender, receiver, label, expected continuation)
        let cases = [
            ("A", "B", "accept", Some(GlobalType::End)),
            ("A", "B", "reject", Some(GlobalType::End)),
            ("A", "B", "unknown", None),
            ("B", "A", "accept", None),
        ];
        for (from, to, label, expected) in cases {
            assert_eq!(choice.consume(from, to, label), expected);
        }
    }

    #[test]
    fn test_payload_sort() {
        let pair = PayloadSort::prod(PayloadSort::Nat, PayloadSort::Bool);
        assert!(!pair.is_simple());

        let label = Label::with_sort("data", pair);
        assert_eq!(label.name, "data");
    }

    #[test]
    fn test_guarded() {
        // `mu t. t` has a bare variable as its body.
        let unguarded = GlobalType::mu("t", GlobalType::var("t"));
        assert!(!unguarded.is_guarded());
        assert!(!unguarded.well_formed());

        // `mu t. A -> B : msg . t` has a communication before the variable.
        let guarded = looping_msg();
        assert!(guarded.is_guarded());
        assert!(guarded.well_formed());
    }
}