use crate::{Label, LocalTypeR};
use std::collections::BTreeMap;
use thiserror::Error;
/// Reasons why two local types cannot be merged.
#[derive(Debug, Clone, Error)]
pub enum MergeError {
/// Exactly one side of the merge is `End`; carries a clone of the non-`End` side.
#[error("cannot merge end with non-end type: {0:?}")]
EndMismatch(LocalTypeR),
/// Two sends (or two receives) name different partners. `expected` is the
/// partner of the first operand, `found` the partner of the second.
#[error("partner mismatch in merge: expected {expected}, found {found}")]
PartnerMismatch { expected: String, found: String },
/// One side is a send and the other is a receive.
#[error("direction mismatch in merge: cannot merge send with recv")]
DirectionMismatch,
/// The continuations under a shared label could not be merged. Also reported
/// when the two labels agree on name but differ in `sort`.
#[error("incompatible continuations for label '{label}'")]
IncompatibleContinuations { label: String },
/// A shared label carries different payload annotations on the two sides.
/// `left` and `right` hold the `Debug` rendering of each annotation.
#[error("payload annotation mismatch for label '{label}': left={left}, right={right}")]
PayloadAnnotationMismatch {
label: String,
left: String,
right: String,
},
/// After sorting both send branch lists by label name, the labels at the
/// same position differ.
#[error("send branch label mismatch: cannot merge sends with different labels '{left}' vs '{right}'")]
SendLabelMismatch { left: String, right: String },
/// The two sends offer a different number of branches.
#[error("send branch count mismatch: {left} labels vs {right} labels")]
SendBranchCountMismatch { left: usize, right: usize },
/// Two `Mu` types bind different variable names.
#[error("recursive variable mismatch: expected {expected}, found {found}")]
RecursiveVariableMismatch { expected: String, found: String },
/// Two `Var` types refer to different variable names.
#[error("type variable mismatch: expected {expected}, found {found}")]
VariableMismatch { expected: String, found: String },
/// Fallback for any other pair of shapes; also returned by `merge_all` when
/// given an empty slice.
#[error("cannot merge incompatible types")]
IncompatibleTypes,
}
/// Outcome of a merge: the merged [`LocalTypeR`] on success, a [`MergeError`] otherwise.
pub type MergeResult = Result<LocalTypeR, MergeError>;
/// Merges two local types into a single type.
///
/// Equal inputs are returned as a clone of `t1`. Otherwise both sides must
/// have the same shape and are merged by the matching helper:
/// send with send, receive with receive, `Mu` with `Mu`, `Var` with `Var`.
///
/// # Errors
///
/// * [`MergeError::EndMismatch`] when exactly one side is `End`.
/// * [`MergeError::DirectionMismatch`] when a send meets a receive.
/// * [`MergeError::IncompatibleTypes`] for any other pair of differing shapes.
/// * Whatever error the shape-specific helper reports.
pub fn merge(t1: &LocalTypeR, t2: &LocalTypeR) -> MergeResult {
    match (t1, t2) {
        // Fast path: equal types merge to themselves.
        (lhs, rhs) if lhs == rhs => Ok(lhs.clone()),
        (LocalTypeR::End, LocalTypeR::End) => Ok(LocalTypeR::End),
        (LocalTypeR::End, non_end) | (non_end, LocalTypeR::End) => {
            Err(MergeError::EndMismatch(non_end.clone()))
        }
        (LocalTypeR::Send { .. }, LocalTypeR::Recv { .. })
        | (LocalTypeR::Recv { .. }, LocalTypeR::Send { .. }) => {
            Err(MergeError::DirectionMismatch)
        }
        (
            LocalTypeR::Send {
                partner: left_partner,
                branches: left_branches,
            },
            LocalTypeR::Send {
                partner: right_partner,
                branches: right_branches,
            },
        ) => merge_send_pair(left_partner, left_branches, right_partner, right_branches),
        (
            LocalTypeR::Recv {
                partner: left_partner,
                branches: left_branches,
            },
            LocalTypeR::Recv {
                partner: right_partner,
                branches: right_branches,
            },
        ) => merge_recv_pair(left_partner, left_branches, right_partner, right_branches),
        (
            LocalTypeR::Mu {
                var: left_var,
                body: left_body,
            },
            LocalTypeR::Mu {
                var: right_var,
                body: right_body,
            },
        ) => merge_recursive_pair(left_var, left_body, right_var, right_body),
        (LocalTypeR::Var(left_var), LocalTypeR::Var(right_var)) => {
            merge_var_pair(left_var, right_var)
        }
        _ => Err(MergeError::IncompatibleTypes),
    }
}
/// Merges two send types.
///
/// Both sends must address the same partner; their branch lists are then
/// combined by `merge_send_branches`.
///
/// # Errors
///
/// Returns [`MergeError::PartnerMismatch`] when `p1` and `p2` differ, or
/// whatever error the branch merge reports.
fn merge_send_pair(
    p1: &str,
    b1: &[(Label, Option<crate::ValType>, LocalTypeR)],
    p2: &str,
    b2: &[(Label, Option<crate::ValType>, LocalTypeR)],
) -> MergeResult {
    if p1 == p2 {
        merge_send_branches(b1, b2).map(|branches| LocalTypeR::Send {
            partner: p1.to_owned(),
            branches,
        })
    } else {
        Err(MergeError::PartnerMismatch {
            expected: p1.to_owned(),
            found: p2.to_owned(),
        })
    }
}
/// Merges two receive types.
///
/// Both receives must listen to the same partner; their branch lists are
/// then combined by `merge_recv_branches`.
///
/// # Errors
///
/// Returns [`MergeError::PartnerMismatch`] when `p1` and `p2` differ, or
/// whatever error the branch merge reports.
fn merge_recv_pair(
    p1: &str,
    b1: &[(Label, Option<crate::ValType>, LocalTypeR)],
    p2: &str,
    b2: &[(Label, Option<crate::ValType>, LocalTypeR)],
) -> MergeResult {
    let same_partner = p1 == p2;
    if !same_partner {
        return Err(MergeError::PartnerMismatch {
            expected: String::from(p1),
            found: String::from(p2),
        });
    }

    let branches = merge_recv_branches(b1, b2)?;
    let partner = String::from(p1);
    Ok(LocalTypeR::Recv { partner, branches })
}
/// Merges two recursive (`Mu`) types.
///
/// The bound variable names must be identical; no renaming is attempted.
/// The bodies are merged with `merge` and re-wrapped under the same variable.
///
/// # Errors
///
/// Returns [`MergeError::RecursiveVariableMismatch`] when `v1` and `v2`
/// differ, or whatever error merging the bodies reports.
fn merge_recursive_pair(v1: &str, body1: &LocalTypeR, v2: &str, body2: &LocalTypeR) -> MergeResult {
    if v1 == v2 {
        let body = Box::new(merge(body1, body2)?);
        Ok(LocalTypeR::Mu {
            var: v1.to_owned(),
            body,
        })
    } else {
        Err(MergeError::RecursiveVariableMismatch {
            expected: v1.to_owned(),
            found: v2.to_owned(),
        })
    }
}
/// Merges two type variables, which succeeds only when the names are equal.
///
/// # Errors
///
/// Returns [`MergeError::VariableMismatch`] when `v1` and `v2` differ.
fn merge_var_pair(v1: &str, v2: &str) -> MergeResult {
    if v1 == v2 {
        Ok(LocalTypeR::Var(v1.to_owned()))
    } else {
        Err(MergeError::VariableMismatch {
            expected: v1.to_owned(),
            found: v2.to_owned(),
        })
    }
}
/// Merges the payload annotations attached to a shared label.
///
/// Annotations merge only when they are equal (including both being `None`);
/// the result is then a clone of `left`.
///
/// # Errors
///
/// Returns [`MergeError::PayloadAnnotationMismatch`] naming `label` and the
/// `Debug` rendering of both annotations when they differ.
fn merge_payload_annotations(
    label: &Label,
    left: &Option<crate::ValType>,
    right: &Option<crate::ValType>,
) -> Result<Option<crate::ValType>, MergeError> {
    if left != right {
        return Err(MergeError::PayloadAnnotationMismatch {
            label: label.name.clone(),
            left: format!("{left:?}"),
            right: format!("{right:?}"),
        });
    }
    Ok(left.clone())
}
/// Merges the branch lists of two sends.
///
/// Sends must offer exactly the same labels on both sides. Branches are
/// matched up after sorting each list by label name, so the order in which
/// the inputs list their branches does not matter; the result is ordered by
/// label name and carries the labels of `branches1`.
///
/// # Errors
///
/// * [`MergeError::SendBranchCountMismatch`] when the lists differ in length.
/// * [`MergeError::SendLabelMismatch`] when the sorted label names differ at
///   some position.
/// * [`MergeError::IncompatibleContinuations`] when a shared label differs in
///   `sort` or its continuations cannot be merged.
/// * [`MergeError::PayloadAnnotationMismatch`] when a shared label carries
///   different payload annotations.
fn merge_send_branches(
    branches1: &[(Label, Option<crate::ValType>, LocalTypeR)],
    branches2: &[(Label, Option<crate::ValType>, LocalTypeR)],
) -> Result<Vec<(Label, Option<crate::ValType>, LocalTypeR)>, MergeError> {
    // Reject mismatched counts before doing any allocation or sorting.
    if branches1.len() != branches2.len() {
        return Err(MergeError::SendBranchCountMismatch {
            left: branches1.len(),
            right: branches2.len(),
        });
    }

    // Sort references rather than deep-cloning every branch (continuations
    // can be arbitrarily large). The sort stays stable so branches with equal
    // names keep their input order.
    let mut sorted1: Vec<&(Label, Option<crate::ValType>, LocalTypeR)> =
        branches1.iter().collect();
    let mut sorted2: Vec<&(Label, Option<crate::ValType>, LocalTypeR)> =
        branches2.iter().collect();
    sorted1.sort_by(|a, b| a.0.name.cmp(&b.0.name));
    sorted2.sort_by(|a, b| a.0.name.cmp(&b.0.name));

    let mut result = Vec::with_capacity(sorted1.len());
    for ((label1, vt1, cont1), (label2, vt2, cont2)) in sorted1.into_iter().zip(sorted2) {
        if label1.name != label2.name {
            return Err(MergeError::SendLabelMismatch {
                left: label1.name.clone(),
                right: label2.name.clone(),
            });
        }
        // A differing label sort is reported under the continuation error.
        if label1.sort != label2.sort {
            return Err(MergeError::IncompatibleContinuations {
                label: label1.name.clone(),
            });
        }
        // The nested error is replaced by one that names the offending label.
        let merged_cont =
            merge(cont1, cont2).map_err(|_| MergeError::IncompatibleContinuations {
                label: label1.name.clone(),
            })?;
        let merged_vt = merge_payload_annotations(label1, vt1, vt2)?;
        result.push((label1.clone(), merged_vt, merged_cont));
    }
    Ok(result)
}
fn merge_recv_branches(
branches1: &[(Label, Option<crate::ValType>, LocalTypeR)],
branches2: &[(Label, Option<crate::ValType>, LocalTypeR)],
) -> Result<Vec<(Label, Option<crate::ValType>, LocalTypeR)>, MergeError> {
let mut result: BTreeMap<String, (Label, Option<crate::ValType>, LocalTypeR)> = BTreeMap::new();
for (label, vt, cont) in branches1 {
result.insert(
label.name.clone(),
(label.clone(), vt.clone(), cont.clone()),
);
}
for (label, vt, cont) in branches2 {
if let Some((existing_label, existing_vt, existing_cont)) = result.get(&label.name) {
let merged_cont =
merge(existing_cont, cont).map_err(|_| MergeError::IncompatibleContinuations {
label: label.name.clone(),
})?;
if existing_label.sort != label.sort {
return Err(MergeError::IncompatibleContinuations {
label: label.name.clone(),
});
}
let merged_vt = merge_payload_annotations(label, existing_vt, vt)?;
result.insert(label.name.clone(), (label.clone(), merged_vt, merged_cont));
} else {
result.insert(
label.name.clone(),
(label.clone(), vt.clone(), cont.clone()),
);
}
}
Ok(result.into_values().collect())
}
/// Merges a slice of local types left to right into a single type.
///
/// A single-element slice yields a clone of that element.
///
/// # Errors
///
/// Returns [`MergeError::IncompatibleTypes`] for an empty slice, or the
/// first error reported by [`merge`] while folding.
pub fn merge_all(types: &[LocalTypeR]) -> MergeResult {
    match types.split_first() {
        None => Err(MergeError::IncompatibleTypes),
        Some((head, tail)) => tail
            .iter()
            .try_fold(head.clone(), |acc, next| merge(&acc, next)),
    }
}
/// Reports whether [`merge`] succeeds on the two types, discarding the result.
#[must_use]
pub fn can_merge(t1: &LocalTypeR, t2: &LocalTypeR) -> bool {
    matches!(merge(t1, t2), Ok(_))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the merge operator: sends must agree on their labels,
    //! receives union their labels, and every other shape must match exactly.

    use super::*;
    use crate::ValType;
    use assert_matches::assert_matches;

    // ---- identical inputs -------------------------------------------------

    #[test]
    fn test_merge_identical_end() {
        let result = merge(&LocalTypeR::End, &LocalTypeR::End).unwrap();
        assert_eq!(result, LocalTypeR::End);
    }

    #[test]
    fn test_merge_identical_send() {
        let t = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let result = merge(&t, &t).unwrap();
        assert_eq!(result, t);
    }

    // ---- End --------------------------------------------------------------

    #[test]
    fn test_merge_end_with_non_end_fails() {
        let t = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        // The error carries the non-End side regardless of argument order.
        assert_matches!(merge(&LocalTypeR::End, &t), Err(MergeError::EndMismatch(other)) => {
            assert_eq!(other, t);
        });
        assert_matches!(merge(&t, &LocalTypeR::End), Err(MergeError::EndMismatch(other)) => {
            assert_eq!(other, t);
        });
    }

    // ---- sends ------------------------------------------------------------

    #[test]
    fn test_merge_sends_same_labels_succeeds() {
        let t1 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Send { partner, branches } => {
            assert_eq!(partner, "B");
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].0.name, "x");
        });
    }

    #[test]
    fn test_merge_sends_branch_order_independent() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![
                (Label::new("x"), None, LocalTypeR::End),
                (Label::new("y"), None, LocalTypeR::End),
            ],
        };
        let t2 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![
                (Label::new("y"), None, LocalTypeR::End),
                (Label::new("x"), None, LocalTypeR::End),
            ],
        };
        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Send { partner, branches } => {
            assert_eq!(partner, "B");
            let labels: Vec<_> = branches.iter().map(|(l, _, _)| l.name.as_str()).collect();
            assert_eq!(labels, vec!["x", "y"]);
        });
    }

    #[test]
    fn test_merge_sends_different_labels_fails() {
        let t1 = LocalTypeR::send("B", Label::new("yes"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("no"), LocalTypeR::End);
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::SendLabelMismatch { .. }));
    }

    #[test]
    fn test_merge_sends_different_count_fails() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![
                (Label::new("x"), None, LocalTypeR::End),
                (Label::new("y"), None, LocalTypeR::End),
            ],
        };
        let t2 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::SendBranchCountMismatch { .. }));
    }

    #[test]
    fn test_merge_sends_payload_annotation_mismatch_fails() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), Some(ValType::Bool), LocalTypeR::End)],
        };
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::PayloadAnnotationMismatch { .. }));
    }

    #[test]
    fn test_merge_sends_payload_annotation_none_some_mismatch_fails() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), None, LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::PayloadAnnotationMismatch { .. }));
    }

    #[test]
    fn test_merge_sends_different_partners_fails() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::send("C", Label::new("msg"), LocalTypeR::End);
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::PartnerMismatch { .. }));
    }

    // ---- receives ---------------------------------------------------------

    #[test]
    fn test_merge_recvs_different_labels_succeeds() {
        let t1 = LocalTypeR::recv("A", Label::new("x"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("A", Label::new("y"), LocalTypeR::End);
        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Recv { partner, branches } => {
            assert_eq!(partner, "A");
            assert_eq!(branches.len(), 2);
            let labels: Vec<_> = branches.iter().map(|(l, _, _)| l.name.as_str()).collect();
            assert!(labels.contains(&"x"));
            assert!(labels.contains(&"y"));
        });
    }

    #[test]
    fn test_merge_recvs_same_label_merges_continuations() {
        let t1 = LocalTypeR::recv(
            "A",
            Label::new("x"),
            LocalTypeR::send("B", Label::new("m"), LocalTypeR::End),
        );
        let t2 = LocalTypeR::recv(
            "A",
            Label::new("x"),
            LocalTypeR::send("B", Label::new("m"), LocalTypeR::End),
        );
        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Recv { branches, .. } => {
            assert_eq!(branches.len(), 1);
            assert_matches!(&branches[0].2, LocalTypeR::Send { partner, .. } => {
                assert_eq!(partner, "B");
            });
        });
    }

    #[test]
    fn test_merge_recvs_overlapping_labels() {
        let t1 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![
                (Label::new("x"), None, LocalTypeR::End),
                (Label::new("y"), None, LocalTypeR::End),
            ],
        };
        let t2 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![
                (Label::new("y"), None, LocalTypeR::End),
                (Label::new("z"), None, LocalTypeR::End),
            ],
        };
        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Recv { partner, branches } => {
            assert_eq!(partner, "A");
            assert_eq!(branches.len(), 3);
            let labels: Vec<_> = branches.iter().map(|(l, _, _)| l.name.as_str()).collect();
            assert!(labels.contains(&"x"));
            assert!(labels.contains(&"y"));
            assert!(labels.contains(&"z"));
        });
    }

    #[test]
    fn test_merge_recvs_overlapping_payload_annotation_mismatch_fails() {
        let t1 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Bool), LocalTypeR::End)],
        };
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::PayloadAnnotationMismatch { .. }));
    }

    #[test]
    fn test_merge_recvs_overlapping_payload_annotation_match_succeeds() {
        let t1 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let result = merge(&t1, &t2).expect("matching payload annotations should merge");
        assert_matches!(result, LocalTypeR::Recv { branches, .. } => {
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].1, Some(ValType::Nat));
        });
    }

    #[test]
    fn test_merge_recvs_different_partners_fails() {
        let t1 = LocalTypeR::recv("A", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("C", Label::new("msg"), LocalTypeR::End);
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::PartnerMismatch { .. }));
    }

    // ---- mixed shapes -----------------------------------------------------

    #[test]
    fn test_merge_send_recv_fails() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("B", Label::new("msg"), LocalTypeR::End);
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::DirectionMismatch));
    }

    #[test]
    fn test_merge_mu_with_var_fails() {
        let t1 = LocalTypeR::Mu {
            var: "X".to_string(),
            body: Box::new(LocalTypeR::End),
        };
        let t2 = LocalTypeR::Var("X".to_string());
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::IncompatibleTypes));
    }

    // ---- recursion and variables ------------------------------------------

    #[test]
    fn test_merge_mu_same_var_merges_bodies() {
        let t1 = LocalTypeR::Mu {
            var: "X".to_string(),
            body: Box::new(LocalTypeR::recv(
                "A",
                Label::new("x"),
                LocalTypeR::Var("X".to_string()),
            )),
        };
        let t2 = LocalTypeR::Mu {
            var: "X".to_string(),
            body: Box::new(LocalTypeR::recv(
                "A",
                Label::new("y"),
                LocalTypeR::Var("X".to_string()),
            )),
        };
        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Mu { var, body } => {
            assert_eq!(var, "X");
            assert_matches!(&*body, LocalTypeR::Recv { branches, .. } => {
                assert_eq!(branches.len(), 2);
            });
        });
    }

    #[test]
    fn test_merge_mu_different_vars_fails() {
        let t1 = LocalTypeR::Mu {
            var: "X".to_string(),
            body: Box::new(LocalTypeR::End),
        };
        let t2 = LocalTypeR::Mu {
            var: "Y".to_string(),
            body: Box::new(LocalTypeR::End),
        };
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::RecursiveVariableMismatch { .. }));
    }

    #[test]
    fn test_merge_vars_same_name_succeeds() {
        let t = LocalTypeR::Var("X".to_string());
        let result = merge(&t, &LocalTypeR::Var("X".to_string())).unwrap();
        assert_eq!(result, t);
    }

    #[test]
    fn test_merge_vars_different_names_fails() {
        let t1 = LocalTypeR::Var("X".to_string());
        let t2 = LocalTypeR::Var("Y".to_string());
        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::VariableMismatch { .. }));
    }

    // ---- merge_all --------------------------------------------------------

    #[test]
    fn test_merge_all_empty_fails() {
        let result = merge_all(&[]);
        assert_matches!(result, Err(MergeError::IncompatibleTypes));
    }

    #[test]
    fn test_merge_all_single_returns_clone() {
        let t = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let result = merge_all(std::slice::from_ref(&t)).unwrap();
        assert_eq!(result, t);
    }

    #[test]
    fn test_merge_all_sends_same_labels() {
        let types = vec![
            LocalTypeR::send("B", Label::new("x"), LocalTypeR::End),
            LocalTypeR::send("B", Label::new("x"), LocalTypeR::End),
            LocalTypeR::send("B", Label::new("x"), LocalTypeR::End),
        ];
        let result = merge_all(&types).unwrap();
        assert_matches!(result, LocalTypeR::Send { branches, .. } => {
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].0.name, "x");
        });
    }

    #[test]
    fn test_merge_all_sends_different_labels_fails() {
        let types = vec![
            LocalTypeR::send("B", Label::new("a"), LocalTypeR::End),
            LocalTypeR::send("B", Label::new("b"), LocalTypeR::End),
        ];
        let result = merge_all(&types);
        // Pin the specific failure rather than any error.
        assert_matches!(result, Err(MergeError::SendLabelMismatch { .. }));
    }

    #[test]
    fn test_merge_all_recvs_different_labels() {
        let types = vec![
            LocalTypeR::recv("A", Label::new("a"), LocalTypeR::End),
            LocalTypeR::recv("A", Label::new("b"), LocalTypeR::End),
            LocalTypeR::recv("A", Label::new("c"), LocalTypeR::End),
        ];
        let result = merge_all(&types).unwrap();
        assert_matches!(result, LocalTypeR::Recv { branches, .. } => {
            assert_eq!(branches.len(), 3);
        });
    }

    // ---- can_merge --------------------------------------------------------

    #[test]
    fn test_can_merge_sends_same_label() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        assert!(can_merge(&t1, &t2));
    }

    #[test]
    fn test_can_merge_sends_different_labels_false() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("other"), LocalTypeR::End);
        assert!(!can_merge(&t1, &t2));
    }

    #[test]
    fn test_can_merge_recvs_different_labels_true() {
        let t1 = LocalTypeR::recv("A", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("A", Label::new("other"), LocalTypeR::End);
        assert!(can_merge(&t1, &t2));
    }

    #[test]
    fn test_can_merge_send_recv_false() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("B", Label::new("msg"), LocalTypeR::End);
        assert!(!can_merge(&t1, &t2));
    }
}