use super::functions::*;
use std::collections::{HashMap, HashSet, VecDeque};
/// A global session type: the protocol described from a neutral viewpoint,
/// naming the roles involved in every interaction.
#[derive(Debug, Clone)]
pub enum GType {
    /// `sender` transmits one `msg_ty` value to `receiver`, after which the
    /// protocol continues as `cont`.
    Comm {
        sender: Role,
        receiver: Role,
        msg_ty: BaseType,
        cont: Box<GType>,
    },
    /// `selector` picks one labelled branch and `receiver` is told which one.
    Choice {
        selector: Role,
        receiver: Role,
        /// Continuation of the protocol for each branch label.
        branches: HashMap<String, GType>,
    },
    /// The protocol has finished.
    End,
    /// Recursive protocol: the first field names the bound variable, the
    /// second is the body.
    Rec(String, Box<GType>),
    /// Occurrence of a recursion variable.
    Var(String),
}

impl GType {
    /// Returns every role that appears as a sender, receiver or selector
    /// anywhere inside this protocol.
    pub fn participants(&self) -> HashSet<Role> {
        let mut roles = HashSet::new();
        self.collect_roles(&mut roles);
        roles
    }

    /// Recursive worker for [`GType::participants`]; adds roles to `roles`.
    fn collect_roles(&self, roles: &mut HashSet<Role>) {
        match self {
            GType::Comm {
                sender,
                receiver,
                cont,
                ..
            } => {
                roles.insert(sender.clone());
                roles.insert(receiver.clone());
                cont.collect_roles(roles);
            }
            GType::Choice {
                selector,
                receiver,
                branches,
            } => {
                roles.insert(selector.clone());
                roles.insert(receiver.clone());
                for cont in branches.values() {
                    cont.collect_roles(roles);
                }
            }
            GType::End | GType::Var(_) => {}
            GType::Rec(_, body) => body.collect_roles(roles),
        }
    }

    /// Projects the global protocol onto `role`, producing that role's local
    /// view.
    ///
    /// * `Comm`: a send for the sender, a receive for the receiver, and just
    ///   the projected continuation for everybody else.
    /// * `Choice`: an internal choice for the selector, an external choice
    ///   for the receiver (branches ordered by label in both cases), and the
    ///   merge of all branch projections for uninvolved roles.
    /// * `End`, `Rec` and `Var` map to their local counterparts.
    pub fn project(&self, role: &Role) -> LType {
        match self {
            GType::Comm {
                sender,
                receiver,
                msg_ty,
                cont,
            } => {
                let cont_proj = cont.project(role);
                if sender == role {
                    LType::Send(receiver.clone(), msg_ty.clone(), Box::new(cont_proj))
                } else if receiver == role {
                    LType::Recv(sender.clone(), msg_ty.clone(), Box::new(cont_proj))
                } else {
                    cont_proj
                }
            }
            GType::Choice {
                selector,
                receiver,
                branches,
            } => {
                // Project every branch once and order the results by label.
                // `HashMap` iteration order is unspecified, so without the
                // sort the merge below could pick a different branch from one
                // run to the next.
                let mut proj_branches: Vec<(String, LType)> = branches
                    .iter()
                    .map(|(lbl, g)| (lbl.clone(), g.project(role)))
                    .collect();
                proj_branches.sort_by(|a, b| a.0.cmp(&b.0));

                if selector == role {
                    LType::IChoice(receiver.clone(), proj_branches)
                } else if receiver == role {
                    LType::EChoice(selector.clone(), proj_branches)
                } else {
                    let projs: Vec<LType> =
                        proj_branches.into_iter().map(|(_, local)| local).collect();
                    Self::merge_all(projs)
                }
            }
            GType::End => LType::End,
            GType::Rec(x, body) => LType::Rec(x.clone(), Box::new(body.project(role))),
            GType::Var(x) => LType::Var(x.clone()),
        }
    }

    /// Collapses the projections of all branches into a single local type.
    ///
    /// Equal neighbours collapse into one; when two projections differ the
    /// later one (in the order supplied) is kept. No structural merge is
    /// attempted. An empty input yields `LType::End`.
    fn merge_all(types: Vec<LType>) -> LType {
        types
            .into_iter()
            .reduce(|a, b| if a == b { a } else { b })
            .unwrap_or(LType::End)
    }
}
/// A session endpoint with separate outgoing and incoming message queues.
///
/// Sending only enqueues locally; messages reach a peer when
/// [`AsyncSessionEndpoint::flush_to`] is called.
pub struct AsyncSessionEndpoint {
    /// The part of the session type that has not been consumed yet.
    pub remaining: SType,
    outbox: VecDeque<Message>,
    inbox: VecDeque<Message>,
}

impl AsyncSessionEndpoint {
    /// Creates an endpoint that must follow `stype`, with both queues empty.
    pub fn new(stype: SType) -> Self {
        Self {
            remaining: stype,
            outbox: VecDeque::new(),
            inbox: VecDeque::new(),
        }
    }

    /// Queues `msg` in the outbox and advances past the leading `Send`.
    ///
    /// # Errors
    /// Returns a description of the mismatch when the remaining type does not
    /// start with `Send`; the endpoint is left untouched in that case.
    pub fn async_send(&mut self, msg: Message) -> Result<(), String> {
        let next = match &self.remaining {
            SType::Send(_, cont) => SType::clone(cont),
            other => return Err(format!("AsyncSend: expected Send, got {}", other)),
        };
        self.remaining = next;
        self.outbox.push_back(msg);
        Ok(())
    }

    /// Moves every queued outgoing message, in order, to the back of `peer`'s
    /// inbox and returns how many were moved.
    pub fn flush_to(&mut self, peer: &mut AsyncSessionEndpoint) -> usize {
        let delivered = self.outbox.len();
        peer.inbox.append(&mut self.outbox);
        delivered
    }

    /// Takes the oldest delivered message and advances past the leading
    /// `Recv`.
    ///
    /// # Errors
    /// Fails when the remaining type does not start with `Recv`, or when it
    /// does but the inbox is empty. The endpoint is unchanged on failure.
    pub fn async_recv(&mut self) -> Result<Message, String> {
        let cont = match &self.remaining {
            SType::Recv(_, cont) => cont,
            other => return Err(format!("AsyncRecv: expected Recv, got {}", other)),
        };
        let msg = self
            .inbox
            .pop_front()
            .ok_or_else(|| "AsyncRecv: inbox empty — message not yet delivered".to_string())?;
        // Advance only once a message has actually been obtained.
        self.remaining = SType::clone(cont);
        Ok(msg)
    }

    /// Number of messages waiting to be flushed.
    pub fn outbox_len(&self) -> usize {
        self.outbox.len()
    }

    /// Number of delivered messages not yet received.
    pub fn inbox_len(&self) -> usize {
        self.inbox.len()
    }
}
/// Name of a protocol participant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Role(pub String);

impl Role {
    /// Builds a role from anything that converts into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self(name)
    }
}
/// Fluent constructor for binary session types.
///
/// Protocols are assembled back to front: start from [`ProtocolBuilder::end`]
/// and each `then_*` call wraps what has been built so far, so the most
/// recent call becomes the first action of the resulting type.
pub struct ProtocolBuilder {
    current: SType,
}

impl ProtocolBuilder {
    /// Starts a builder holding the terminated protocol `SType::End`.
    pub fn end() -> Self {
        Self {
            current: SType::End,
        }
    }

    /// Prefixes the protocol built so far with a send of `ty`.
    pub fn then_send(self, ty: BaseType) -> Self {
        let ProtocolBuilder { current: rest } = self;
        Self {
            current: SType::Send(Box::new(ty), Box::new(rest)),
        }
    }

    /// Prefixes the protocol built so far with a receive of `ty`.
    pub fn then_recv(self, ty: BaseType) -> Self {
        let ProtocolBuilder { current: rest } = self;
        Self {
            current: SType::Recv(Box::new(ty), Box::new(rest)),
        }
    }

    /// Consumes the builder and returns the assembled session type.
    pub fn build(self) -> SType {
        let ProtocolBuilder { current } = self;
        current
    }
}
/// A binary session type describing one endpoint's side of a conversation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SType {
    /// Send a value of the given type, then continue.
    Send(Box<BaseType>, Box<SType>),
    /// Receive a value of the given type, then continue.
    Recv(Box<BaseType>, Box<SType>),
    /// The session is over.
    End,
    /// This endpoint selects between a left and a right continuation.
    Choice(Box<SType>, Box<SType>),
    /// The peer selects; this endpoint offers a left and a right continuation.
    Branch(Box<SType>, Box<SType>),
    /// Recursive type binding the named variable in its body.
    Rec(String, Box<SType>),
    /// Occurrence of a recursion variable.
    Var(String),
}

impl SType {
    /// Returns the type of the opposite endpoint: sends and receives swap,
    /// `Choice` and `Branch` swap, everything else keeps its shape.
    pub fn dual(&self) -> SType {
        match self {
            Self::End => Self::End,
            Self::Var(name) => Self::Var(name.clone()),
            Self::Rec(name, body) => Self::Rec(name.clone(), Box::new(body.dual())),
            Self::Send(payload, rest) => Self::Recv(payload.clone(), Box::new(rest.dual())),
            Self::Recv(payload, rest) => Self::Send(payload.clone(), Box::new(rest.dual())),
            Self::Choice(left, right) => {
                Self::Branch(Box::new(left.dual()), Box::new(right.dual()))
            }
            Self::Branch(left, right) => {
                Self::Choice(Box::new(left.dual()), Box::new(right.dual()))
            }
        }
    }

    /// Unrolls one level of recursion: for `Rec(x, body)` returns `body` with
    /// every free `x` replaced by the whole `Rec`. Any other type is returned
    /// as a clone of itself.
    pub fn unfold(&self) -> SType {
        if let Self::Rec(var, body) = self {
            let mut unrolled = SType::clone(body);
            unrolled.subst_var(var, self);
            unrolled
        } else {
            self.clone()
        }
    }

    /// Replaces, in place, every free occurrence of variable `x` with
    /// `replacement`. An inner `Rec` that rebinds `x` is left untouched.
    fn subst_var(&mut self, x: &str, replacement: &SType) {
        match self {
            Self::End => {}
            Self::Var(name) => {
                if name.as_str() == x {
                    *self = replacement.clone();
                }
            }
            Self::Rec(bound, body) => {
                // A binder with the same name shadows `x` inside its body.
                if bound.as_str() != x {
                    body.subst_var(x, replacement);
                }
            }
            Self::Send(_, rest) | Self::Recv(_, rest) => rest.subst_var(x, replacement),
            Self::Choice(left, right) | Self::Branch(left, right) => {
                left.subst_var(x, replacement);
                right.subst_var(x, replacement);
            }
        }
    }

    /// True when the type is exactly `End`.
    pub fn is_end(&self) -> bool {
        matches!(self, Self::End)
    }

    /// True when the type starts with a send.
    pub fn is_send(&self) -> bool {
        matches!(self, Self::Send(..))
    }

    /// True when the type starts with a receive.
    pub fn is_recv(&self) -> bool {
        matches!(self, Self::Recv(..))
    }
}
/// Types of the values carried by individual messages.
///
/// Used as the payload type of `SType::Send`/`SType::Recv`, of
/// `GType::Comm`, and of `LType::Send`/`LType::Recv`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseType {
// Primitive payload kinds; presumably they correspond to the like-named
// `Message` variants — confirm against the code that pairs the two.
Nat,
Bool,
Str,
Unit,
/// A type referred to by name only.
Named(String),
/// Composite built from two component types.
Pair(Box<BaseType>, Box<BaseType>),
/// Composite built from two alternative types.
Sum(Box<BaseType>, Box<BaseType>),
}
/// Chooses among weighted session branches.
pub struct ProbSessionScheduler {
    branches: Vec<ProbBranch>,
}

impl ProbSessionScheduler {
    /// Wraps the given weighted branches.
    pub fn new(branches: Vec<ProbBranch>) -> Self {
        Self { branches }
    }

    /// Returns each branch's weight divided by the sum of all weights, in
    /// branch order. When the weights sum to zero every entry is `0.0`.
    pub fn probabilities(&self) -> Vec<f64> {
        let weights = || self.branches.iter().map(|branch| branch.weight);
        let total: f64 = weights().sum();
        if total == 0.0 {
            vec![0.0; self.branches.len()]
        } else {
            weights().map(|w| w / total).collect()
        }
    }

    /// Index of the branch with the largest weight, or `None` when there are
    /// no branches. Incomparable weights are treated as equal, and among
    /// equal candidates the last one wins.
    pub fn greedy_choice(&self) -> Option<usize> {
        let heaviest = self
            .branches
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| {
                lhs.weight
                    .partial_cmp(&rhs.weight)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
        heaviest.map(|(position, _)| position)
    }

    /// Sum of the probabilities of branches whose continuation is not
    /// `SType::End`: such a branch costs one round, an ended one costs none.
    pub fn expected_rounds(&self) -> f64 {
        self.probabilities()
            .into_iter()
            .zip(&self.branches)
            .map(|(prob, branch)| {
                let cost = if matches!(branch.cont, SType::End) {
                    0.0
                } else {
                    1.0
                };
                prob * cost
            })
            .sum()
    }
}
/// A runtime value travelling through an endpoint's message queue.
#[derive(Debug, Clone)]
pub enum Message {
/// Unsigned 64-bit number payload.
Nat(u64),
/// Boolean payload.
Bool(bool),
/// Text payload.
Str(String),
/// Message carrying no data.
Unit,
// NOTE(review): `Left`/`Right` presumably signal which side of a
// `Choice`/`Branch` was selected — nothing in view constructs them, confirm.
Left,
Right,
}
/// Decides subtyping between session types, using a set of assumed pairs so
/// that recursive types terminate.
pub struct SessionSubtypeChecker {
    /// Pairs (rendered as text) that are assumed during a query or were
    /// established by an earlier successful query.
    decided: HashSet<(String, String)>,
}

impl SessionSubtypeChecker {
    /// Creates a checker with no established pairs.
    pub fn new() -> Self {
        SessionSubtypeChecker {
            decided: HashSet::new(),
        }
    }

    /// Returns `true` when `sub` is a subtype of `sup`.
    ///
    /// Constructors must match pairwise, payload types must be equal, and a
    /// `Rec` on either side is unfolded before comparing. A pair met again
    /// while it is still being checked is assumed to hold.
    ///
    /// Pairs recorded by a query that ends up failing are forgotten again.
    /// Previously they stayed in the set, so asking about a refuted pair a
    /// second time answered `true`.
    pub fn is_subtype(&mut self, sub: &SType, sup: &SType) -> bool {
        let mut assumed: Vec<(String, String)> = Vec::new();
        let holds = self.check_assuming(sub, sup, &mut assumed);
        if !holds {
            // Every rule is a conjunction, so one refuted pair refutes the
            // whole query; none of the assumptions made along the way can be
            // trusted afterwards.
            for key in &assumed {
                self.decided.remove(key);
            }
        }
        holds
    }

    /// Recursive worker: `assumed` collects the pairs this query added to
    /// `decided`, so the caller can roll them back on failure.
    fn check_assuming(
        &mut self,
        sub: &SType,
        sup: &SType,
        assumed: &mut Vec<(String, String)>,
    ) -> bool {
        let key = (format!("{}", sub), format!("{}", sup));
        if self.decided.contains(&key) {
            return true;
        }
        self.decided.insert(key.clone());
        assumed.push(key);
        match (sub, sup) {
            (SType::End, SType::End) => true,
            (SType::Send(t1, s1), SType::Send(t2, s2)) => {
                t1 == t2 && self.check_assuming(s1, s2, assumed)
            }
            (SType::Recv(t1, s1), SType::Recv(t2, s2)) => {
                t1 == t2 && self.check_assuming(s1, s2, assumed)
            }
            (SType::Choice(l1, r1), SType::Choice(l2, r2)) => {
                self.check_assuming(l1, l2, assumed) && self.check_assuming(r1, r2, assumed)
            }
            (SType::Branch(l1, r1), SType::Branch(l2, r2)) => {
                self.check_assuming(l1, l2, assumed) && self.check_assuming(r1, r2, assumed)
            }
            (SType::Rec(_, _), _) => self.check_assuming(&sub.unfold(), sup, assumed),
            (_, SType::Rec(_, _)) => self.check_assuming(sub, &sup.unfold(), assumed),
            _ => false,
        }
    }
}
/// One entry in the trace recorded by `ChoreographyEngine::execute`.
#[derive(Debug, Clone)]
pub enum ChoreographyStep {
/// A message exchange; the fields hold the role names and the message type
/// rendered as text.
Comm {
sender: String,
receiver: String,
msg_ty: String,
},
/// A branch selection; `branch` is the label of the branch that was
/// followed.
Choice {
selector: String,
receiver: String,
branch: String,
},
/// The protocol reached `GType::End`.
End,
}
pub struct ChoreographyEngine {
pub trace: Vec<ChoreographyStep>,
}
impl ChoreographyEngine {
pub fn new() -> Self {
ChoreographyEngine { trace: vec![] }
}
pub fn execute(&mut self, gtype: >ype) -> Result<(), String> {
match gtype {
GType::Comm {
sender,
receiver,
msg_ty,
cont,
} => {
self.trace.push(ChoreographyStep::Comm {
sender: sender.0.clone(),
receiver: receiver.0.clone(),
msg_ty: format!("{}", msg_ty),
});
self.execute(cont)
}
GType::Choice {
selector,
receiver,
branches,
} => {
let mut sorted: Vec<(&String, >ype)> = branches.iter().collect();
sorted.sort_by_key(|(k, _)| k.as_str());
if let Some((label, cont)) = sorted.first() {
self.trace.push(ChoreographyStep::Choice {
selector: selector.0.clone(),
receiver: receiver.0.clone(),
branch: (*label).clone(),
});
self.execute(cont)
} else {
Err("GType::Choice has no branches".to_string())
}
}
GType::End => {
self.trace.push(ChoreographyStep::End);
Ok(())
}
GType::Rec(_, body) => self.execute(body),
GType::Var(x) => Err(format!("Unresolved recursion variable: {}", x)),
}
}
pub fn comm_count(&self) -> usize {
self.trace
.iter()
.filter(|s| !matches!(s, ChoreographyStep::End))
.count()
}
}
/// A session endpoint that checks each operation against its remaining
/// session type.
///
/// `send` enqueues on this endpoint's own buffer and `recv` dequeues from
/// that same buffer.
pub struct SessionEndpoint {
    /// The part of the session type that has not been consumed yet.
    pub remaining: SType,
    buffer: VecDeque<Message>,
    closed: bool,
}

impl SessionEndpoint {
    /// Creates an open endpoint that must follow `stype`.
    pub fn new(stype: SType) -> Self {
        Self {
            remaining: stype,
            buffer: VecDeque::new(),
            closed: false,
        }
    }

    /// Buffers `msg` and advances past the leading `Send`.
    ///
    /// # Errors
    /// Fails, leaving the endpoint unchanged, when the remaining type does
    /// not start with `Send`.
    pub fn send(&mut self, msg: Message) -> Result<(), String> {
        let next = match &self.remaining {
            SType::Send(_, continuation) => SType::clone(continuation),
            other => return Err(format!("Expected Send, got {}", other)),
        };
        self.remaining = next;
        self.buffer.push_back(msg);
        Ok(())
    }

    /// Takes the oldest buffered message and advances past the leading
    /// `Recv`.
    ///
    /// # Errors
    /// Fails when the remaining type does not start with `Recv`, or when it
    /// does but the buffer is empty. The endpoint is unchanged on failure.
    pub fn recv(&mut self) -> Result<Message, String> {
        let continuation = match &self.remaining {
            SType::Recv(_, continuation) => continuation,
            other => return Err(format!("Expected Recv, got {}", other)),
        };
        let msg = self
            .buffer
            .pop_front()
            .ok_or_else(|| "No message available".to_string())?;
        self.remaining = SType::clone(continuation);
        Ok(msg)
    }

    /// Commits to the left alternative of a leading `Choice`.
    ///
    /// # Errors
    /// Fails when the remaining type is not a `Choice`.
    pub fn select_left(&mut self) -> Result<(), String> {
        let chosen = match &self.remaining {
            SType::Choice(left, _) => SType::clone(left),
            other => return Err(format!("Expected Choice, got {}", other)),
        };
        self.remaining = chosen;
        Ok(())
    }

    /// Commits to the right alternative of a leading `Choice`.
    ///
    /// # Errors
    /// Fails when the remaining type is not a `Choice`.
    pub fn select_right(&mut self) -> Result<(), String> {
        let chosen = match &self.remaining {
            SType::Choice(_, right) => SType::clone(right),
            other => return Err(format!("Expected Choice, got {}", other)),
        };
        self.remaining = chosen;
        Ok(())
    }

    /// Marks the endpoint as closed.
    ///
    /// # Errors
    /// Fails when the remaining type is anything other than `End`.
    pub fn close(&mut self) -> Result<(), String> {
        match self.remaining {
            SType::End => {
                self.closed = true;
                Ok(())
            }
            ref other => Err(format!("Expected End, got {}", other)),
        }
    }

    /// True once `close` has succeeded.
    pub fn is_complete(&self) -> bool {
        self.closed
    }
}
/// One weighted alternative handled by `ProbSessionScheduler`.
// NOTE(review): `too_many_arguments` is a lint about function parameters, so
// this `allow` has no visible target on a struct — confirm it can be dropped.
#[allow(clippy::too_many_arguments)]
pub struct ProbBranch {
/// Name of the branch.
pub label: String,
/// Relative weight; the scheduler divides it by the sum of all weights.
pub weight: f64,
/// Session type followed when this branch is taken.
pub cont: SType,
}
/// Detects circular waiting among participants.
pub struct DeadlockChecker {
    /// Recorded dependencies as `(channel, waiter, provider)` triples.
    wait_edges: Vec<(String, String, String)>,
}

impl DeadlockChecker {
    /// Creates a checker with no recorded dependencies.
    pub fn new() -> Self {
        Self {
            wait_edges: Vec::new(),
        }
    }

    /// Records that `waiter` is waiting on `provider` over `channel`.
    pub fn add_wait(
        &mut self,
        channel: impl Into<String>,
        waiter: impl Into<String>,
        provider: impl Into<String>,
    ) {
        let edge = (channel.into(), waiter.into(), provider.into());
        self.wait_edges.push(edge);
    }

    /// Returns `true` when the waiter-to-provider graph has no cycle.
    ///
    /// Channel names play no part in the check; a participant waiting on
    /// itself counts as a cycle.
    pub fn is_deadlock_free(&self) -> bool {
        let mut graph: HashMap<&str, Vec<&str>> = HashMap::new();
        for (_channel, waiter, provider) in &self.wait_edges {
            graph
                .entry(waiter.as_str())
                .or_default()
                .push(provider.as_str());
        }

        let mut seen: HashSet<&str> = HashSet::new();
        let mut on_path: HashSet<&str> = HashSet::new();
        // Only waiters have outgoing edges, so they are the only useful
        // starting points for the search.
        for &start in graph.keys() {
            if seen.contains(start) {
                continue;
            }
            if Self::has_cycle(start, &graph, &mut seen, &mut on_path) {
                return false;
            }
        }
        true
    }

    /// Depth-first search from `node`; reports whether an edge leads back to
    /// a node on the current search path. `in_stack` holds that path.
    fn has_cycle<'a>(
        node: &'a str,
        adj: &HashMap<&'a str, Vec<&'a str>>,
        visited: &mut HashSet<&'a str>,
        in_stack: &mut HashSet<&'a str>,
    ) -> bool {
        visited.insert(node);
        in_stack.insert(node);
        for &next in adj.get(node).into_iter().flatten() {
            let closes_cycle = if visited.contains(next) {
                in_stack.contains(next)
            } else {
                Self::has_cycle(next, adj, visited, in_stack)
            };
            if closes_cycle {
                return true;
            }
        }
        in_stack.remove(node);
        false
    }
}
/// A local session type: one role's view of a global protocol, as produced by
/// `GType::project`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LType {
/// Send a value of the given type to the named role, then continue.
Send(Role, BaseType, Box<LType>),
/// Receive a value of the given type from the named role, then continue.
Recv(Role, BaseType, Box<LType>),
/// Internal choice: this role selects a labelled branch and the named role
/// is the receiver of that selection.
IChoice(Role, Vec<(String, LType)>),
/// External choice: the named role selects and this role follows the
/// labelled branch that was picked.
EChoice(Role, Vec<(String, LType)>),
/// Nothing left to do.
End,
/// Recursive local type binding the named variable in its body.
Rec(String, Box<LType>),
/// Occurrence of a recursion variable.
Var(String),
}
/// A single action performed on a channel, to be checked against its session
/// type.
#[derive(Debug, Clone)]
pub enum SessionOp {
    /// Send a value of the given type.
    Send(BaseType),
    /// Receive a value of the given type.
    Recv(BaseType),
    /// Select the left alternative of a choice.
    SelectLeft,
    /// Select the right alternative of a choice.
    SelectRight,
    /// Close a finished session.
    Close,
}

impl SessionOp {
    /// Checks this operation against `stype` and returns the session type
    /// that remains afterwards.
    ///
    /// # Errors
    /// Fails when the payload type differs from the one the session type
    /// requires, or when the operation does not fit the leading constructor
    /// of `stype` at all.
    pub fn check_step(&self, stype: SType) -> Result<SType, String> {
        match (self, stype) {
            (SessionOp::Send(actual), SType::Send(wanted, rest)) => {
                if *actual != *wanted {
                    return Err(format!(
                        "Type mismatch: sent {:?} but expected {:?}",
                        actual, wanted
                    ));
                }
                Ok(*rest)
            }
            (SessionOp::Recv(actual), SType::Recv(wanted, rest)) => {
                if *actual != *wanted {
                    return Err(format!(
                        "Type mismatch: recv {:?} but expected {:?}",
                        actual, wanted
                    ));
                }
                Ok(*rest)
            }
            (SessionOp::SelectLeft, SType::Choice(left, _)) => Ok(*left),
            (SessionOp::SelectRight, SType::Choice(_, right)) => Ok(*right),
            (SessionOp::Close, SType::End) => Ok(SType::End),
            (_, unexpected) => Err(format!(
                "Operation {:?} incompatible with session type {}",
                self, unexpected
            )),
        }
    }
}
/// Outcome of one check performed by `GradualSessionMonitor`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MonitorResult {
/// The action and its payload type matched the expected session type.
Ok,
/// The action was accepted with a cast; the string describes the cast.
CastInserted(String),
/// The action did not fit the expected session type; the string says why.
Failure(String),
}
/// Checks sequences of operations against the session types declared for
/// named channels.
pub struct SessionChecker {
    channels: HashMap<String, SType>,
}

impl SessionChecker {
    /// Creates a checker with no registered channels.
    pub fn new() -> Self {
        Self {
            channels: HashMap::new(),
        }
    }

    /// Declares (or redeclares) the session type of channel `name`.
    pub fn register_channel(&mut self, name: impl Into<String>, stype: SType) {
        let name: String = name.into();
        self.channels.insert(name, stype);
    }

    /// Applies `ops` in order to the declared type of `channel` and returns
    /// the session type left over afterwards.
    ///
    /// # Errors
    /// Fails when `channel` was never registered, or with the error of the
    /// first operation that does not fit.
    pub fn check_usage(&self, channel: &str, ops: &[SessionOp]) -> Result<SType, String> {
        match self.channels.get(channel) {
            None => Err(format!("Unknown channel: {}", channel)),
            Some(declared) => ops
                .iter()
                .try_fold(declared.clone(), |state, op| op.check_step(state)),
        }
    }
}
/// Runtime monitor that tolerates payload mismatches by recording casts and
/// records protocol-shape mismatches as violations.
pub struct GradualSessionMonitor {
    expected: SType,
    /// Descriptions of actions that did not fit the expected session type.
    pub violations: Vec<String>,
    /// Descriptions of casts inserted so far.
    pub casts: Vec<String>,
}

impl GradualSessionMonitor {
    /// Creates a monitor that expects the session to follow `expected`.
    pub fn new(expected: SType) -> Self {
        Self {
            expected,
            violations: Vec::new(),
            casts: Vec::new(),
        }
    }

    /// Checks a send of a value of type `actual_ty`.
    ///
    /// * Expected type starts with `Send`: the monitor advances. The result
    ///   is `Ok` on equal payload types, otherwise a cast is recorded.
    /// * Expected type is the variable `"?"`: a dynamic cast is recorded and
    ///   the monitor does not advance.
    /// * Anything else: a violation is recorded and the monitor does not
    ///   advance.
    pub fn check_send(&mut self, actual_ty: &BaseType) -> MonitorResult {
        match &self.expected {
            SType::Send(expected_ty, cont) => {
                let cast = if *actual_ty == **expected_ty {
                    None
                } else {
                    Some(format!("cast {:?} → {:?}", actual_ty, expected_ty))
                };
                let next = SType::clone(cont);
                self.expected = next;
                match cast {
                    None => MonitorResult::Ok,
                    Some(msg) => {
                        self.casts.push(msg.clone());
                        MonitorResult::CastInserted(msg)
                    }
                }
            }
            SType::Var(name) if name.as_str() == "?" => {
                let msg = format!("dynamic send {:?}", actual_ty);
                self.casts.push(msg.clone());
                MonitorResult::CastInserted(msg)
            }
            other => {
                let msg = format!("expected Send, got {}", other);
                self.violations.push(msg.clone());
                MonitorResult::Failure(msg)
            }
        }
    }

    /// Checks a receive of a value of type `actual_ty`; mirrors
    /// [`GradualSessionMonitor::check_send`] with `Recv` as the expected
    /// leading constructor.
    pub fn check_recv(&mut self, actual_ty: &BaseType) -> MonitorResult {
        match &self.expected {
            SType::Recv(expected_ty, cont) => {
                let cast = if *actual_ty == **expected_ty {
                    None
                } else {
                    Some(format!("cast {:?} → {:?}", actual_ty, expected_ty))
                };
                let next = SType::clone(cont);
                self.expected = next;
                match cast {
                    None => MonitorResult::Ok,
                    Some(msg) => {
                        self.casts.push(msg.clone());
                        MonitorResult::CastInserted(msg)
                    }
                }
            }
            SType::Var(name) if name.as_str() == "?" => {
                let msg = format!("dynamic recv {:?}", actual_ty);
                self.casts.push(msg.clone());
                MonitorResult::CastInserted(msg)
            }
            other => {
                let msg = format!("expected Recv, got {}", other);
                self.violations.push(msg.clone());
                MonitorResult::Failure(msg)
            }
        }
    }

    /// True while no violation has been recorded; casts do not count.
    pub fn is_safe(&self) -> bool {
        self.violations.is_empty()
    }
}