use crate::ast::{
choreography_to_global, local_to_local_r, Choreography, ConversionError, DslAnnotationEntry,
LocalType, Protocol, Role,
};
use crate::compiler::parser::{parse_choreography_str, ParseError};
use crate::compiler::projection::{project, ProjectionError};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Which part of a protocol statement an annotation record was taken from.
///
/// Serialized with `snake_case` variant names (`"statement"`, `"sender"`,
/// `"receiver"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AnnotationScope {
    /// Annotation attached to the statement node as a whole.
    Statement,
    /// Annotation attached to the sending role's side of a send or broadcast.
    Sender,
    /// Annotation attached to the receiving role's side of a send.
    Receiver,
}
/// One flattened DSL annotation entry, tagged with where in the protocol
/// tree it was found.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProtocolAnnotationRecord {
    /// Dotted location of the annotated node, starting at `"root"`
    /// (e.g. `"root.continuation"`, `"root.branch[<label>]"`).
    pub path: String,
    /// Kind of the annotated node, e.g. `"send"`, `"broadcast"`, `"choice"`,
    /// `"extension"`.
    pub node_kind: String,
    /// Which part of the node the annotation came from.
    pub scope: AnnotationScope,
    /// Name of the role that owns the annotation, if the node has one.
    pub role: Option<String>,
    /// Names of the counterpart roles of the annotated node; an absent field
    /// deserializes as an empty list.
    #[serde(default)]
    pub peer_roles: Vec<String>,
    /// Annotation key as written in the DSL.
    pub key: String,
    /// Annotation value as written in the DSL.
    pub value: String,
}
/// A validated choreography together with the local type projected for each
/// of its roles.
#[derive(Debug)]
pub struct CompiledChoreography {
    /// The parsed and validated choreography AST.
    pub choreography: Choreography,
    /// One `(role, projected local type)` pair per role, in the order the
    /// roles appear in `choreography.roles`.
    pub local_types: Vec<(Role, LocalType)>,
}
impl CompiledChoreography {
#[must_use]
pub fn role_names(&self) -> Vec<String> {
self.choreography
.roles
.iter()
.map(|role| role.name().to_string())
.collect()
}
#[must_use]
pub fn local_type(&self, role_name: &str) -> Option<&LocalType> {
self.local_types
.iter()
.find_map(|(role, local_type)| (*role.name() == *role_name).then_some(local_type))
}
pub fn try_global_type(&self) -> Result<crate::ast::GlobalTypeCore, ConversionError> {
choreography_to_global(&self.choreography)
}
pub fn try_local_type_r_map(
&self,
) -> Result<BTreeMap<String, crate::ast::LocalTypeR>, ConversionError> {
let mut out = BTreeMap::new();
for (role, local_type) in &self.local_types {
out.insert(role.name().to_string(), local_to_local_r(local_type)?);
}
Ok(out)
}
pub fn global_type_json(&self) -> Result<String, CompileArtifactsError> {
let global = self
.try_global_type()
.map_err(CompileArtifactsError::Conversion)?;
serde_json::to_string(&global).map_err(CompileArtifactsError::Serialization)
}
pub fn local_type_r_json(&self) -> Result<String, CompileArtifactsError> {
let locals = self
.try_local_type_r_map()
.map_err(CompileArtifactsError::Conversion)?;
serde_json::to_string(&locals).map_err(CompileArtifactsError::Serialization)
}
#[must_use]
pub fn annotation_records(&self) -> Vec<ProtocolAnnotationRecord> {
collect_choreography_annotation_records(&self.choreography)
}
}
/// Failures that can occur while compiling a choreography or producing its
/// serialized artifacts.
#[derive(Debug, thiserror::Error)]
pub enum CompileArtifactsError {
    /// The DSL source could not be parsed.
    #[error("parse error: {0}")]
    Parse(#[from] ParseError),

    /// `Choreography::validate` rejected the AST; holds the rendered message.
    #[error("validation error: {0}")]
    Validation(String),

    /// Projection onto a specific role failed.
    #[error("projection failed for role {role}: {source}")]
    Projection {
        /// Name of the role whose projection failed.
        role: String,
        /// Underlying projection failure.
        #[source]
        source: ProjectionError,
    },

    /// Conversion to the theory (global / local-R) representation failed.
    #[error("theory conversion failed: {0}")]
    Conversion(#[from] ConversionError),

    /// JSON encoding of an artifact failed.
    #[error("serialization failed: {0}")]
    Serialization(#[from] serde_json::Error),
}
pub fn compile_choreography(input: &str) -> Result<CompiledChoreography, CompileArtifactsError> {
let choreography = parse_choreography_str(input)?;
compile_choreography_ast(choreography)
}
/// Validates an already-parsed choreography and projects it onto every role.
///
/// Roles are projected in the order they appear in `choreography.roles`;
/// the first failing projection aborts compilation.
///
/// # Errors
///
/// [`CompileArtifactsError::Validation`] if validation fails (the error is
/// rendered to a string), or [`CompileArtifactsError::Projection`] naming
/// the first role whose projection fails.
pub fn compile_choreography_ast(
    choreography: Choreography,
) -> Result<CompiledChoreography, CompileArtifactsError> {
    if let Err(err) = choreography.validate() {
        return Err(CompileArtifactsError::Validation(err.to_string()));
    }

    let local_types = choreography
        .roles
        .iter()
        .map(|role| {
            project(&choreography, role)
                .map(|local_type| (role.clone(), local_type))
                .map_err(|source| CompileArtifactsError::Projection {
                    role: role.name().to_string(),
                    source,
                })
        })
        .collect::<Result<Vec<_>, _>>()?;

    Ok(CompiledChoreography {
        choreography,
        local_types,
    })
}
/// Flattens the DSL annotations of a choreography's top-level protocol into
/// records (see [`collect_protocol_annotation_records`]).
#[must_use]
pub fn collect_choreography_annotation_records(
    choreography: &Choreography,
) -> Vec<ProtocolAnnotationRecord> {
    let protocol = &choreography.protocol;
    collect_protocol_annotation_records(protocol)
}
/// Walks `protocol` and returns one record per DSL annotation entry found,
/// in pre-order. Paths are rooted at `"root"`.
#[must_use]
pub fn collect_protocol_annotation_records(protocol: &Protocol) -> Vec<ProtocolAnnotationRecord> {
    const ROOT_PATH: &str = "root";
    let mut collected = Vec::new();
    collect_protocol_annotation_records_inner(protocol, ROOT_PATH, &mut collected);
    collected
}
/// Recursive worker behind [`collect_protocol_annotation_records`].
///
/// Visits `protocol` in pre-order and appends to `records`:
/// * `Send`: statement-level entries, then sender-side entries (if any), then
///   receiver-side entries (if any);
/// * `Broadcast`: statement-level entries, then sender-side entries (if any);
/// * `Choice` and `Extension`: statement-level entries only.
///
/// Every other variant adds no records of its own and is only traversed for
/// its children. A child's path is `path` plus a `.segment` naming how it was
/// reached (e.g. `.continuation`, `.branch[<label>]`, `.parallel[<index>]`).
fn collect_protocol_annotation_records_inner(
    protocol: &Protocol,
    path: &str,
    records: &mut Vec<ProtocolAnnotationRecord>,
) {
    // Recurse into `child`, whose path is `parent` followed by `.segment`.
    fn descend(
        child: &Protocol,
        parent: &str,
        segment: &str,
        records: &mut Vec<ProtocolAnnotationRecord>,
    ) {
        let child_path = format!("{parent}.{segment}");
        collect_protocol_annotation_records_inner(child, &child_path, records);
    }

    match protocol {
        Protocol::Send {
            from,
            to,
            continuation,
            ..
        } => {
            let receiver = std::slice::from_ref(to);
            let sender = std::slice::from_ref(from);
            push_annotation_records(
                records,
                path,
                "send",
                AnnotationScope::Statement,
                Some(from),
                receiver,
                protocol.get_annotations().dsl_entries(),
            );
            if let Some(sender_side) = protocol.get_from_annotations() {
                push_annotation_records(
                    records,
                    path,
                    "send",
                    AnnotationScope::Sender,
                    Some(from),
                    receiver,
                    sender_side.dsl_entries(),
                );
            }
            // Receiver-side records are owned by `to`, with `from` as the peer.
            if let Some(receiver_side) = protocol.get_to_annotations() {
                push_annotation_records(
                    records,
                    path,
                    "send",
                    AnnotationScope::Receiver,
                    Some(to),
                    sender,
                    receiver_side.dsl_entries(),
                );
            }
            descend(continuation, path, "continuation", records);
        }
        Protocol::Broadcast {
            from,
            to_all,
            continuation,
            ..
        } => {
            let receivers: Vec<_> = to_all.iter().cloned().collect();
            push_annotation_records(
                records,
                path,
                "broadcast",
                AnnotationScope::Statement,
                Some(from),
                &receivers,
                protocol.get_annotations().dsl_entries(),
            );
            if let Some(sender_side) = protocol.get_from_annotations() {
                push_annotation_records(
                    records,
                    path,
                    "broadcast",
                    AnnotationScope::Sender,
                    Some(from),
                    &receivers,
                    sender_side.dsl_entries(),
                );
            }
            descend(continuation, path, "continuation", records);
        }
        Protocol::Choice { role, branches, .. } => {
            push_annotation_records(
                records,
                path,
                "choice",
                AnnotationScope::Statement,
                Some(role),
                &[],
                protocol.get_annotations().dsl_entries(),
            );
            for branch in branches {
                let segment = format!("branch[{}]", branch.label);
                descend(&branch.protocol, path, &segment, records);
            }
        }
        Protocol::Loop { body, .. } => {
            descend(body, path, "body", records);
        }
        Protocol::Parallel { protocols } => {
            for (idx, branch) in protocols.iter().enumerate() {
                descend(branch, path, &format!("parallel[{idx}]"), records);
            }
        }
        Protocol::Rec { label, body } => {
            descend(body, path, &format!("rec[{label}]"), records);
        }
        Protocol::Timeout {
            body,
            on_timeout,
            on_cancel,
            ..
        } => {
            descend(body, path, "timeout.body", records);
            descend(on_timeout, path, "timeout.on_timeout", records);
            if let Some(cancel_branch) = on_cancel {
                descend(cancel_branch, path, "timeout.on_cancel", records);
            }
        }
        Protocol::Case { branches, .. } => {
            for branch in branches {
                let segment = format!("case[{}]", branch.pattern.constructor);
                descend(&branch.protocol, path, &segment, records);
            }
        }
        Protocol::Extension { continuation, .. } => {
            // Extension nodes have no owning role or peers in the record.
            push_annotation_records(
                records,
                path,
                "extension",
                AnnotationScope::Statement,
                None,
                &[],
                protocol.get_annotations().dsl_entries(),
            );
            descend(continuation, path, "continuation", records);
        }
        Protocol::Begin { continuation, .. }
        | Protocol::Await { continuation, .. }
        | Protocol::Resolve { continuation, .. }
        | Protocol::Invalidate { continuation, .. }
        | Protocol::Let { continuation, .. }
        | Protocol::Publish { continuation, .. }
        | Protocol::PublishAuthority { continuation, .. }
        | Protocol::Materialize { continuation, .. }
        | Protocol::Handoff { continuation, .. }
        | Protocol::DependentWork { continuation, .. } => {
            descend(continuation, path, "continuation", records);
        }
        Protocol::Var(_) | Protocol::End => {}
    }
}
/// Appends one [`ProtocolAnnotationRecord`] per entry in `entries`, in order.
///
/// Every appended record shares the same `path`, `node_kind`, `scope`, owning
/// `role` name and `peer_roles` names. Only `key` and `value` come from the
/// individual entry.
fn push_annotation_records(
    records: &mut Vec<ProtocolAnnotationRecord>,
    path: &str,
    node_kind: &str,
    scope: AnnotationScope,
    role: Option<&Role>,
    peer_roles: &[Role],
    entries: Vec<DslAnnotationEntry>,
) {
    let owner = role.map(|owner_role| owner_role.name().to_string());
    let peers: Vec<String> = peer_roles
        .iter()
        .map(|peer| peer.name().to_string())
        .collect();

    records.extend(entries.into_iter().map(|entry| ProtocolAnnotationRecord {
        path: path.to_owned(),
        node_kind: node_kind.to_owned(),
        scope,
        role: owner.clone(),
        peer_roles: peers.clone(),
        key: entry.key,
        value: entry.value,
    }));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sender-side annotations on a root-level send must come back in the
    /// order they were written in the DSL.
    #[test]
    fn ordered_annotation_records_preserve_sender_order() {
        let source = r#"
protocol Demo =
roles Alice, Bob
Alice { guard_capability : "chat:send", flow_cost : 10, leak : external } -> Bob : Msg
"#;
        let compiled = compile_choreography(source).expect("compile choreography");

        let sender_keys: Vec<String> = compiled
            .annotation_records()
            .into_iter()
            .filter(|record| record.path == "root")
            .filter(|record| record.scope == AnnotationScope::Sender)
            .filter(|record| record.role.as_deref() == Some("Alice"))
            .map(|record| record.key)
            .collect();

        assert_eq!(sender_keys, ["guard_capability", "flow_cost", "leak"]);
    }
}