1use crate::ast::{
8 choreography_to_global, local_to_local_r, Choreography, ConversionError, DslAnnotationEntry,
9 LocalType, Protocol, Role,
10};
11use crate::compiler::parser::{parse_choreography_str, ParseError};
12use crate::compiler::projection::{project, ProjectionError};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub enum AnnotationScope {
20 Statement,
22 Sender,
24 Receiver,
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
30pub struct ProtocolAnnotationRecord {
31 pub path: String,
33 pub node_kind: String,
35 pub scope: AnnotationScope,
37 pub role: Option<String>,
39 #[serde(default)]
41 pub peer_roles: Vec<String>,
42 pub key: String,
44 pub value: String,
46}
47
48#[derive(Debug)]
50pub struct CompiledChoreography {
51 pub choreography: Choreography,
53 pub local_types: Vec<(Role, LocalType)>,
55}
56
57impl CompiledChoreography {
58 #[must_use]
60 pub fn role_names(&self) -> Vec<String> {
61 self.choreography
62 .roles
63 .iter()
64 .map(|role| role.name().to_string())
65 .collect()
66 }
67
68 #[must_use]
70 pub fn local_type(&self, role_name: &str) -> Option<&LocalType> {
71 self.local_types
72 .iter()
73 .find_map(|(role, local_type)| (*role.name() == *role_name).then_some(local_type))
74 }
75
76 pub fn try_global_type(&self) -> Result<crate::ast::GlobalTypeCore, ConversionError> {
81 choreography_to_global(&self.choreography)
82 }
83
84 pub fn try_local_type_r_map(
86 &self,
87 ) -> Result<BTreeMap<String, crate::ast::LocalTypeR>, ConversionError> {
88 let mut out = BTreeMap::new();
89 for (role, local_type) in &self.local_types {
90 out.insert(role.name().to_string(), local_to_local_r(local_type)?);
91 }
92 Ok(out)
93 }
94
95 pub fn global_type_json(&self) -> Result<String, CompileArtifactsError> {
97 let global = self
98 .try_global_type()
99 .map_err(CompileArtifactsError::Conversion)?;
100 serde_json::to_string(&global).map_err(CompileArtifactsError::Serialization)
101 }
102
103 pub fn local_type_r_json(&self) -> Result<String, CompileArtifactsError> {
105 let locals = self
106 .try_local_type_r_map()
107 .map_err(CompileArtifactsError::Conversion)?;
108 serde_json::to_string(&locals).map_err(CompileArtifactsError::Serialization)
109 }
110
111 #[must_use]
113 pub fn annotation_records(&self) -> Vec<ProtocolAnnotationRecord> {
114 collect_choreography_annotation_records(&self.choreography)
115 }
116}
117
118#[derive(Debug, thiserror::Error)]
120pub enum CompileArtifactsError {
121 #[error("parse error: {0}")]
122 Parse(#[from] ParseError),
123
124 #[error("validation error: {0}")]
125 Validation(String),
126
127 #[error("projection failed for role {role}: {source}")]
128 Projection {
129 role: String,
130 #[source]
131 source: ProjectionError,
132 },
133
134 #[error("theory conversion failed: {0}")]
135 Conversion(#[from] ConversionError),
136
137 #[error("serialization failed: {0}")]
138 Serialization(#[from] serde_json::Error),
139}
140
141pub fn compile_choreography(input: &str) -> Result<CompiledChoreography, CompileArtifactsError> {
143 let choreography = parse_choreography_str(input)?;
144 compile_choreography_ast(choreography)
145}
146
147pub fn compile_choreography_ast(
149 choreography: Choreography,
150) -> Result<CompiledChoreography, CompileArtifactsError> {
151 choreography
152 .validate()
153 .map_err(|err| CompileArtifactsError::Validation(err.to_string()))?;
154
155 let mut local_types = Vec::new();
156 for role in &choreography.roles {
157 let local_type =
158 project(&choreography, role).map_err(|source| CompileArtifactsError::Projection {
159 role: role.name().to_string(),
160 source,
161 })?;
162 local_types.push((role.clone(), local_type));
163 }
164
165 Ok(CompiledChoreography {
166 choreography,
167 local_types,
168 })
169}
170
171#[must_use]
173pub fn collect_choreography_annotation_records(
174 choreography: &Choreography,
175) -> Vec<ProtocolAnnotationRecord> {
176 collect_protocol_annotation_records(&choreography.protocol)
177}
178
179#[must_use]
181pub fn collect_protocol_annotation_records(protocol: &Protocol) -> Vec<ProtocolAnnotationRecord> {
182 let mut records = Vec::new();
183 collect_protocol_annotation_records_inner(protocol, "root", &mut records);
184 records
185}
186
187fn collect_protocol_annotation_records_inner(
188 protocol: &Protocol,
189 path: &str,
190 records: &mut Vec<ProtocolAnnotationRecord>,
191) {
192 match protocol {
193 Protocol::Send {
194 from,
195 to,
196 continuation,
197 ..
198 } => {
199 push_annotation_records(
200 records,
201 path,
202 "send",
203 AnnotationScope::Statement,
204 Some(from),
205 std::slice::from_ref(to),
206 protocol.get_annotations().dsl_entries(),
207 );
208 if let Some(from_annotations) = protocol.get_from_annotations() {
209 push_annotation_records(
210 records,
211 path,
212 "send",
213 AnnotationScope::Sender,
214 Some(from),
215 std::slice::from_ref(to),
216 from_annotations.dsl_entries(),
217 );
218 }
219 if let Some(to_annotations) = protocol.get_to_annotations() {
220 push_annotation_records(
221 records,
222 path,
223 "send",
224 AnnotationScope::Receiver,
225 Some(to),
226 std::slice::from_ref(from),
227 to_annotations.dsl_entries(),
228 );
229 }
230 collect_protocol_annotation_records_inner(
231 continuation,
232 &format!("{path}.continuation"),
233 records,
234 );
235 }
236 Protocol::Broadcast {
237 from,
238 to_all,
239 continuation,
240 ..
241 } => {
242 let peers = to_all.iter().cloned().collect::<Vec<_>>();
243 push_annotation_records(
244 records,
245 path,
246 "broadcast",
247 AnnotationScope::Statement,
248 Some(from),
249 &peers,
250 protocol.get_annotations().dsl_entries(),
251 );
252 if let Some(from_annotations) = protocol.get_from_annotations() {
253 push_annotation_records(
254 records,
255 path,
256 "broadcast",
257 AnnotationScope::Sender,
258 Some(from),
259 &peers,
260 from_annotations.dsl_entries(),
261 );
262 }
263 collect_protocol_annotation_records_inner(
264 continuation,
265 &format!("{path}.continuation"),
266 records,
267 );
268 }
269 Protocol::Choice { role, branches, .. } => {
270 push_annotation_records(
271 records,
272 path,
273 "choice",
274 AnnotationScope::Statement,
275 Some(role),
276 &[],
277 protocol.get_annotations().dsl_entries(),
278 );
279 for branch in branches {
280 collect_protocol_annotation_records_inner(
281 &branch.protocol,
282 &format!("{path}.branch[{}]", branch.label),
283 records,
284 );
285 }
286 }
287 Protocol::Loop { body, .. } => {
288 collect_protocol_annotation_records_inner(body, &format!("{path}.body"), records);
289 }
290 Protocol::Parallel { protocols } => {
291 for (idx, branch) in protocols.iter().enumerate() {
292 collect_protocol_annotation_records_inner(
293 branch,
294 &format!("{path}.parallel[{idx}]"),
295 records,
296 );
297 }
298 }
299 Protocol::Rec { label, body } => {
300 collect_protocol_annotation_records_inner(
301 body,
302 &format!("{path}.rec[{label}]"),
303 records,
304 );
305 }
306 Protocol::Timeout {
307 body,
308 on_timeout,
309 on_cancel,
310 ..
311 } => {
312 collect_protocol_annotation_records_inner(
313 body,
314 &format!("{path}.timeout.body"),
315 records,
316 );
317 collect_protocol_annotation_records_inner(
318 on_timeout,
319 &format!("{path}.timeout.on_timeout"),
320 records,
321 );
322 if let Some(on_cancel) = on_cancel {
323 collect_protocol_annotation_records_inner(
324 on_cancel,
325 &format!("{path}.timeout.on_cancel"),
326 records,
327 );
328 }
329 }
330 Protocol::Case { branches, .. } => {
331 for branch in branches {
332 collect_protocol_annotation_records_inner(
333 &branch.protocol,
334 &format!("{path}.case[{}]", branch.pattern.constructor),
335 records,
336 );
337 }
338 }
339 Protocol::Extension { continuation, .. } => {
340 push_annotation_records(
341 records,
342 path,
343 "extension",
344 AnnotationScope::Statement,
345 None,
346 &[],
347 protocol.get_annotations().dsl_entries(),
348 );
349 collect_protocol_annotation_records_inner(
350 continuation,
351 &format!("{path}.continuation"),
352 records,
353 );
354 }
355 Protocol::Begin { continuation, .. }
356 | Protocol::Await { continuation, .. }
357 | Protocol::Resolve { continuation, .. }
358 | Protocol::Invalidate { continuation, .. }
359 | Protocol::Let { continuation, .. }
360 | Protocol::Publish { continuation, .. }
361 | Protocol::PublishAuthority { continuation, .. }
362 | Protocol::Materialize { continuation, .. }
363 | Protocol::Handoff { continuation, .. }
364 | Protocol::DependentWork { continuation, .. } => {
365 collect_protocol_annotation_records_inner(
366 continuation,
367 &format!("{path}.continuation"),
368 records,
369 );
370 }
371 Protocol::Var(_) | Protocol::End => {}
372 }
373}
374
375fn push_annotation_records(
376 records: &mut Vec<ProtocolAnnotationRecord>,
377 path: &str,
378 node_kind: &str,
379 scope: AnnotationScope,
380 role: Option<&Role>,
381 peer_roles: &[Role],
382 entries: Vec<DslAnnotationEntry>,
383) {
384 let role = role.map(|role| role.name().to_string());
385 let peer_roles = peer_roles
386 .iter()
387 .map(|role| role.name().to_string())
388 .collect::<Vec<_>>();
389
390 for entry in entries {
391 records.push(ProtocolAnnotationRecord {
392 path: path.to_string(),
393 node_kind: node_kind.to_string(),
394 scope,
395 role: role.clone(),
396 peer_roles: peer_roles.clone(),
397 key: entry.key,
398 value: entry.value,
399 });
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Sender-scoped annotation records must come out in the order the annotations were
    /// written in the source.
    #[test]
    fn ordered_annotation_records_preserve_sender_order() {
        let source = r#"
protocol Demo =
  roles Alice, Bob
  Alice { guard_capability : "chat:send", flow_cost : 10, leak : external } -> Bob : Msg
"#;
        let compiled = compile_choreography(source).expect("compile choreography");

        // Keep only Alice's sender-scoped records on the top-level node.
        let sender_keys = compiled
            .annotation_records()
            .into_iter()
            .filter(|record| {
                record.path == "root"
                    && record.scope == AnnotationScope::Sender
                    && record.role.as_deref() == Some("Alice")
            })
            .map(|record| record.key)
            .collect::<Vec<_>>();

        assert_eq!(sender_keys, vec!["guard_capability", "flow_cost", "leak"]);
    }
}