use crate::ast::{LocalType, Role};
use crate::compiler::projection::ProjectionError;
use std::any::{Any, TypeId};
use std::collections::BTreeMap;
use std::fmt::Debug;
/// Human-readable documentation attached to a grammar extension.
///
/// The `Default` value carries placeholder text for the two prose fields and
/// empty lists for everything else.
#[derive(Debug, Clone)]
pub struct ExtensionDocumentation {
    /// General description of the extension.
    pub overview: String,
    /// Explanation of the syntax the extension adds.
    pub syntax_guide: String,
    /// Situations in which the extension is intended to be used.
    pub use_cases: Vec<String>,
    /// Known restrictions of the extension.
    pub limitations: Vec<String>,
    /// Related extensions or documents.
    pub see_also: Vec<String>,
}

impl Default for ExtensionDocumentation {
    /// Placeholder documentation: "No … provided" text and empty lists.
    fn default() -> Self {
        let placeholder = |what: &str| format!("No {} provided", what);
        ExtensionDocumentation {
            overview: placeholder("documentation"),
            syntax_guide: placeholder("syntax guide"),
            use_cases: Vec::new(),
            limitations: Vec::new(),
            see_also: Vec::new(),
        }
    }
}
/// A worked example of an extension's syntax, as returned by
/// `DocumentedGrammarExtension::examples`.
#[derive(Debug, Clone)]
pub struct ExtensionExample {
    /// Short title of the example.
    pub title: String,
    /// Prose explanation of what the example shows.
    pub description: String,
    /// Source code of the example.
    pub code: String,
    /// Expected output, if the example has one.
    pub expected_output: Option<String>,
}
/// A pluggable grammar fragment that contributes statement rules to the
/// composed grammar built by `ExtensionRegistry`.
///
/// Implementations must be `Send + Sync + Debug`; the registry stores them as
/// boxed trait objects keyed by `extension_id()`.
pub trait GrammarExtension: Send + Sync + Debug {
    /// Grammar source text appended (after a newline) to the base grammar by
    /// `ExtensionRegistry::compose_grammar`.
    fn grammar_rules(&self) -> &'static str;
    /// Names of the statement rules this extension claims in the registry.
    fn statement_rules(&self) -> Vec<&'static str>;
    /// Priority used when two extensions claim the same rule: the higher value
    /// wins, and equal values are rejected at registration. Defaults to `100`.
    fn priority(&self) -> u32 {
        100
    }
    /// Identifier under which the registry stores this extension.
    fn extension_id(&self) -> &'static str;
}
/// Optional documentation hooks for a [`GrammarExtension`].
///
/// Every method has a default, so implementors only override what they have
/// content for.
pub trait DocumentedGrammarExtension: GrammarExtension {
    /// Prose documentation; defaults to `ExtensionDocumentation::default()`.
    fn documentation(&self) -> ExtensionDocumentation {
        Default::default()
    }

    /// Worked examples; defaults to none.
    fn examples(&self) -> Vec<ExtensionExample> {
        Vec::new()
    }

    /// Per-rule descriptions keyed by rule name; defaults to an empty map.
    fn rule_descriptions(&self) -> std::collections::HashMap<String, String> {
        Default::default()
    }
}
/// Parses statements for rules contributed by a grammar extension.
///
/// Parsers are registered with `ExtensionRegistry::register_parser` and
/// located through `ExtensionRegistry::find_parser`.
pub trait StatementParser: Send + Sync + Debug {
    /// Returns whether this parser handles `rule_name`.
    fn can_parse(&self, rule_name: &str) -> bool;
    /// Names of all rules this parser handles.
    fn supported_rules(&self) -> Vec<String>;
    /// Parses `content` matched by `rule_name` into a boxed protocol extension.
    ///
    /// `context` provides the declared roles and the input text.
    ///
    /// # Errors
    ///
    /// Returns a [`ParseError`] when `content` cannot be parsed.
    fn parse_statement(
        &self,
        rule_name: &str,
        content: &str,
        context: &ParseContext,
    ) -> Result<Box<dyn ProtocolExtension>, ParseError>;
}
/// A parsed extension node that participates in validation, projection and
/// code generation.
pub trait ProtocolExtension: Send + Sync + Debug {
    /// Static name identifying the concrete extension type.
    fn type_name(&self) -> &'static str;
    /// Returns whether this node refers to `role`.
    fn mentions_role(&self, role: &Role) -> bool;
    /// Validates this node against the protocol's `roles`.
    fn validate(&self, roles: &[Role]) -> Result<(), ExtensionValidationError>;
    /// Projects this node onto `role`, producing its local type.
    fn project(
        &self,
        role: &Role,
        context: &ProjectionContext,
    ) -> Result<LocalType, ProjectionError>;
    /// Emits the Rust tokens implementing this node.
    fn generate_code(&self, context: &CodegenContext) -> proc_macro2::TokenStream;
    /// Upcast to `&dyn Any`, enabling downcasts to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Mutable upcast to `&mut dyn Any`, enabling mutable downcasts.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// `TypeId` of the concrete extension type.
    ///
    /// NOTE(review): this shares its name with `Any::type_id`, and `Any` is in
    /// scope in this module. Method-call syntax on a `Box<dyn ProtocolExtension>`
    /// or a reference to one may resolve to `Any::type_id`, which returns the
    /// box's or reference's `TypeId`, or may be ambiguous. Prefer the fully
    /// qualified `ProtocolExtension::type_id(ext)`.
    fn type_id(&self) -> TypeId;
}
/// Registry of grammar extensions, their statement parsers, and the
/// bookkeeping needed to resolve which extension owns each statement rule.
#[derive(Debug, Default)]
pub struct ExtensionRegistry {
    /// Grammar extensions keyed by `GrammarExtension::extension_id`.
    grammar_extensions: BTreeMap<String, Box<dyn GrammarExtension>>,
    /// Statement parsers keyed by the id they were registered under.
    statement_parsers: BTreeMap<String, Box<dyn StatementParser>>,
    /// Statement rule name -> id of the grammar extension that currently owns
    /// it. `find_parser` reuses that id as the key into `statement_parsers`.
    rule_to_parser: BTreeMap<String, String>,
    /// Statement rule name -> ids of extensions that declared the rule but lost
    /// it to another extension during registration.
    rule_conflicts: BTreeMap<String, Vec<String>>,
    /// Dependent extension id -> ids of extensions it requires.
    extension_dependencies: BTreeMap<String, Vec<String>>,
    /// Extension id -> version string (`"0.1.0"` unless set beforehand).
    extension_versions: BTreeMap<String, String>,
}
impl ExtensionRegistry {
    /// Creates an empty registry with no grammar extensions or statement parsers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a grammar extension and resolves ownership of its statement rules.
    ///
    /// Each distinct rule returned by `extension.statement_rules()` is resolved
    /// against the extension that currently owns it, if any:
    ///
    /// * unowned rule: `extension` becomes its owner;
    /// * owner has a lower priority: `extension` takes the rule over and the
    ///   previous owner is recorded as a conflict for that rule;
    /// * owner has a higher priority: the owner keeps the rule and `extension`
    ///   is recorded as a conflict for that rule;
    /// * owner has the same priority: registration fails.
    ///
    /// Registration is all-or-nothing. Every rule is checked before the registry
    /// is modified, so on error the registry is left exactly as it was.
    ///
    /// On success the extension is stored under its `extension_id()`, replacing
    /// any extension previously stored under that id. Its version defaults to
    /// `"0.1.0"` unless a version was already recorded for that id.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::PriorityConflict`] for the first rule, in
    /// declaration order, that is already owned by an extension of equal
    /// priority. This includes re-registering an id that still owns one of the
    /// rules at the same priority.
    pub fn register_grammar<T: GrammarExtension + 'static>(
        &mut self,
        extension: T,
    ) -> Result<(), ParseError> {
        let id = extension.extension_id().to_string();
        let priority = extension.priority();

        // Drop repeated rule names (keeping declaration order) so an extension
        // that lists a rule twice is not reported as conflicting with itself.
        let mut rules: Vec<&'static str> = Vec::new();
        for rule in extension.statement_rules() {
            if !rules.contains(&rule) {
                rules.push(rule);
            }
        }

        // Phase 1: decide every rule's outcome without mutating `self`. A
        // conflict on a later rule must not leave earlier rules pointing at an
        // extension that never gets stored.
        let mut claimed: Vec<&'static str> = Vec::new();
        let mut overridden: Vec<(&'static str, String)> = Vec::new();
        let mut shadowed: Vec<&'static str> = Vec::new();
        for &rule in &rules {
            match self.rule_to_parser.get(rule) {
                None => claimed.push(rule),
                Some(existing_id) => {
                    let existing_priority = self
                        .grammar_extensions
                        .get(existing_id)
                        .map(|e| e.priority())
                        .unwrap_or(0);
                    match priority.cmp(&existing_priority) {
                        std::cmp::Ordering::Greater => {
                            overridden.push((rule, existing_id.clone()))
                        }
                        std::cmp::Ordering::Less => shadowed.push(rule),
                        std::cmp::Ordering::Equal => {
                            return Err(ParseError::PriorityConflict {
                                extension1: existing_id.clone(),
                                extension2: id.clone(),
                                priority1: existing_priority,
                                priority2: priority,
                                rule: rule.to_string(),
                            });
                        }
                    }
                }
            }
        }

        // Phase 2: no conflict was found, so commit every decision.
        for rule in claimed {
            self.rule_to_parser.insert(rule.to_string(), id.clone());
        }
        for (rule, displaced) in overridden {
            self.rule_conflicts
                .entry(rule.to_string())
                .or_default()
                .push(displaced);
            self.rule_to_parser.insert(rule.to_string(), id.clone());
        }
        // The existing higher-priority owner keeps these rules. Record the new
        // extension as the loser so conflict reports do not depend on the
        // order in which extensions were registered.
        for rule in shadowed {
            self.rule_conflicts
                .entry(rule.to_string())
                .or_default()
                .push(id.clone());
        }
        self.grammar_extensions
            .insert(id.clone(), Box::new(extension));
        self.extension_versions
            .entry(id)
            .or_insert_with(|| "0.1.0".to_string());
        Ok(())
    }

    /// Registers a statement parser under `parser_id`, replacing any parser
    /// previously registered under that id.
    ///
    /// `find_parser` looks parsers up by the id of the grammar extension that
    /// owns a rule, so `parser_id` should match that extension's id.
    pub fn register_parser<T: StatementParser + 'static>(&mut self, parser: T, parser_id: String) {
        self.statement_parsers.insert(parser_id, Box::new(parser));
    }

    /// Returns `base_grammar` followed by every extension's grammar rules.
    ///
    /// Each extension's rules are preceded by a newline. Extensions appear in
    /// descending priority order, with ties broken by ascending id so the output
    /// is deterministic.
    pub fn compose_grammar(&self, base_grammar: &str) -> String {
        let mut extensions: Vec<_> = self.grammar_extensions.iter().collect();
        extensions.sort_by(|(id_a, ext_a), (id_b, ext_b)| {
            ext_b
                .priority()
                .cmp(&ext_a.priority())
                .then_with(|| id_a.cmp(id_b))
        });
        let mut composed = base_grammar.to_string();
        for (_, extension) in extensions {
            composed.push('\n');
            composed.push_str(extension.grammar_rules());
        }
        composed
    }

    /// Looks up the statement parser responsible for `rule_name`.
    ///
    /// The rule is first mapped to the id of the grammar extension that owns
    /// it. That id is then used as the parser key. Returns `None` if the rule is
    /// unowned or no parser is registered under the owner's id.
    pub fn find_parser(&self, rule_name: &str) -> Option<&dyn StatementParser> {
        self.rule_to_parser
            .get(rule_name)
            .and_then(|parser_id| self.get_statement_parser(parser_id))
    }

    /// Returns whether some registered grammar extension owns `rule_name`.
    pub fn can_handle(&self, rule_name: &str) -> bool {
        self.rule_to_parser.contains_key(rule_name)
    }

    /// Returns whether any grammar extension or statement parser is registered.
    pub fn has_extensions(&self) -> bool {
        !self.grammar_extensions.is_empty() || !self.statement_parsers.is_empty()
    }

    /// Iterates over the registered grammar extensions in ascending id order.
    pub fn grammar_extensions(&self) -> impl Iterator<Item = &dyn GrammarExtension> {
        // `BTreeMap` already yields entries in key (id) order, so no extra
        // collect-and-sort is needed.
        self.grammar_extensions.values().map(|e| e.as_ref())
    }

    /// Returns whether a grammar extension is registered under `extension_id`.
    pub fn has_extension(&self, extension_id: &str) -> bool {
        self.grammar_extensions.contains_key(extension_id)
    }

    /// Returns the id of the grammar extension that owns `rule_name`, if any.
    pub fn get_parser_for_rule(&self, rule_name: &str) -> Option<&str> {
        self.rule_to_parser.get(rule_name).map(String::as_str)
    }

    /// Returns the statement parser registered under `parser_id`, if any.
    pub fn get_statement_parser(&self, parser_id: &str) -> Option<&dyn StatementParser> {
        self.statement_parsers.get(parser_id).map(|p| p.as_ref())
    }

    /// Records that extension `dependent` requires extension `required`.
    ///
    /// Recording the same pair more than once has no additional effect.
    pub fn add_dependency(&mut self, dependent: &str, required: &str) {
        let requirements = self
            .extension_dependencies
            .entry(dependent.to_string())
            .or_default();
        if !requirements
            .iter()
            .any(|existing| existing.as_str() == required)
        {
            requirements.push(required.to_string());
        }
    }

    /// Checks that every recorded dependency names a registered grammar extension.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::MissingDependency`] for the first missing
    /// requirement found. Dependents are visited in ascending id order.
    pub fn validate_dependencies(&self) -> Result<(), ParseError> {
        for (dependent, requirements) in &self.extension_dependencies {
            let missing = requirements
                .iter()
                .find(|required| !self.grammar_extensions.contains_key(required.as_str()));
            if let Some(required) = missing {
                return Err(ParseError::MissingDependency {
                    extension: dependent.clone(),
                    dependency: required.clone(),
                });
            }
        }
        Ok(())
    }

    /// Per rule, the ids of extensions that declared the rule but lost it to
    /// another extension during registration.
    pub fn get_conflicts(&self) -> &BTreeMap<String, Vec<String>> {
        &self.rule_conflicts
    }

    /// Human-readable descriptions of every recorded rule conflict.
    ///
    /// Rules are reported in ascending name order. Within a rule, losing
    /// extensions are reported in ascending id order. An extension that is no
    /// longer registered is reported with priority `0`.
    pub fn get_detailed_conflicts(&self) -> Vec<String> {
        let priority_of = |id: &str| {
            self.grammar_extensions
                .get(id)
                .map(|e| e.priority())
                .unwrap_or(0)
        };
        let mut details = Vec::new();
        // `BTreeMap` iteration is already sorted by rule name.
        for (rule, conflicting_extensions) in &self.rule_conflicts {
            let active_extension = self
                .rule_to_parser
                .get(rule)
                .map(String::as_str)
                .unwrap_or("unknown");
            let active_priority = priority_of(active_extension);
            let mut conflicting_extensions = conflicting_extensions.clone();
            conflicting_extensions.sort();
            for conflicting in &conflicting_extensions {
                let conflicting_priority = priority_of(conflicting.as_str());
                details.push(format!(
                    "Rule '{}': Extension '{}' (priority {}) overrode '{}' (priority {}). \
                     To resolve: 1) Adjust priorities, 2) Use different rule names, or 3) Merge functionality.",
                    rule, active_extension, active_priority, conflicting, conflicting_priority
                ));
            }
        }
        details
    }

    /// Checks that no two of the given extensions declare the same rule.
    ///
    /// Priorities are not considered. Ids that are not registered are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::IncompatibleExtensions`] for the first rule
    /// declared by two different extensions in `extension_ids`.
    pub fn check_compatibility(&self, extension_ids: &[&str]) -> Result<(), ParseError> {
        let mut rule_owners: BTreeMap<&str, &str> = BTreeMap::new();
        for &extension_id in extension_ids {
            if let Some(extension) = self.grammar_extensions.get(extension_id) {
                for rule in extension.statement_rules() {
                    if let Some(&existing) = rule_owners.get(rule) {
                        if existing != extension_id {
                            return Err(ParseError::IncompatibleExtensions {
                                details: format!(
                                    "Extensions '{}' and '{}' both define rule '{}'. Use different rule names or register extensions with different priorities.",
                                    existing, extension_id, rule
                                ),
                            });
                        }
                    }
                    rule_owners.insert(rule, extension_id);
                }
            }
        }
        Ok(())
    }

    /// Creates a registry preloaded with the built-in `timeout` grammar
    /// extension and its statement parser, registered under id `"timeout"`.
    ///
    /// # Panics
    ///
    /// Panics if the built-in extension fails to register, which would
    /// indicate a bug in the built-in extension set.
    pub fn with_builtin_extensions() -> Self {
        let mut registry = Self::new();
        registry
            .register_grammar(timeout::TimeoutGrammarExtension)
            .expect("builtin timeout extension should register successfully");
        registry.register_parser(timeout::TimeoutStatementParser, "timeout".to_string());
        registry
    }

    /// Creates an empty registry for third-party extensions. This is currently
    /// identical to [`ExtensionRegistry::new`].
    pub fn for_third_party() -> Self {
        Self::new()
    }

    /// Renders Markdown documentation for every registered grammar extension.
    ///
    /// The output lists each extension's priority, rules, version (when known)
    /// and grammar, in ascending id order.
    pub fn generate_docs(&self) -> String {
        let mut docs = String::from("# Extension Documentation\n\n");
        // `BTreeMap` iteration is already sorted by extension id.
        for (id, extension) in &self.grammar_extensions {
            docs.push_str(&format!("## {}\n\n", id));
            docs.push_str(&format!("**Priority:** {}\n\n", extension.priority()));
            docs.push_str(&format!(
                "**Rules:** {}\n\n",
                extension.statement_rules().join(", ")
            ));
            if let Some(version) = self.extension_versions.get(id) {
                docs.push_str(&format!("**Version:** {}\n\n", version));
            }
            docs.push_str("**Grammar:**\n```\n");
            docs.push_str(extension.grammar_rules());
            docs.push_str("\n```\n\n");
        }
        docs
    }
}
/// Context passed to `StatementParser::parse_statement`.
#[derive(Debug)]
pub struct ParseContext<'a> {
    /// Roles declared for the protocol being parsed.
    pub declared_roles: &'a [Role],
    /// The input text being parsed.
    pub input: &'a str,
}
/// Context passed to `ProtocolExtension::project`.
#[derive(Debug)]
pub struct ProjectionContext<'a> {
    /// Every role in the protocol.
    pub all_roles: &'a [Role],
    /// The role currently being projected onto.
    pub current_role: &'a Role,
}
/// Context passed to `ProtocolExtension::generate_code`.
#[derive(Debug)]
pub struct CodegenContext<'a> {
    /// Name of the choreography code is generated for.
    pub choreography_name: &'a str,
    /// Roles available to the generated code.
    pub roles: &'a [Role],
    /// Optional namespace for generated items.
    pub namespace: Option<&'a str>,
}

impl Default for CodegenContext<'_> {
    /// A context named `"Default"` with no roles and no namespace.
    fn default() -> Self {
        CodegenContext {
            namespace: None,
            roles: &[],
            choreography_name: "Default",
        }
    }
}
/// Errors raised while registering extensions or parsing extension statements.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// Generic syntax error with a free-form message.
    #[error("Syntax error: {message}")]
    Syntax { message: String },
    /// A statement referenced a role that is not known.
    #[error("Unknown role '{role}' used in extension")]
    UnknownRole { role: String },
    /// Extension-specific syntax was malformed.
    #[error("Invalid extension syntax: {details}")]
    InvalidSyntax { details: String },
    /// Generic extension conflict with a free-form message.
    #[error("Extension conflict: {message}")]
    Conflict { message: String },
    /// Returned by `ExtensionRegistry::register_grammar` when two extensions
    /// claim the same rule with equal priority.
    #[error("Extension priority conflict: Extension '{extension1}' (priority {priority1}) conflicts with '{extension2}' (priority {priority2}) for rule '{rule}'. Consider adjusting priorities or using different rule names.")]
    PriorityConflict {
        /// Id of the extension that already owned the rule.
        extension1: String,
        /// Id of the extension being registered.
        extension2: String,
        /// Priority of `extension1`.
        priority1: u32,
        /// Priority of `extension2`.
        priority2: u32,
        /// The contested rule name.
        rule: String,
    },
    /// Returned by `ExtensionRegistry::validate_dependencies` when a recorded
    /// dependency is not registered.
    #[error("Missing dependency: Extension '{extension}' requires '{dependency}' which is not registered. Please register the required extension first.")]
    MissingDependency {
        /// Id of the dependent extension.
        extension: String,
        /// Id of the missing required extension.
        dependency: String,
    },
    /// An extension could not be registered for the given rule.
    #[error("Extension registration failed: Extension '{extension}' with rule '{rule}' cannot be registered. {details}")]
    RegistrationFailed {
        extension: String,
        rule: String,
        details: String,
    },
    /// Returned by `ExtensionRegistry::check_compatibility` when two extensions
    /// declare the same rule.
    #[error("Incompatible extensions: {details}")]
    IncompatibleExtensions { details: String },
}
/// Errors returned by `ProtocolExtension::validate`.
#[derive(Debug, thiserror::Error)]
pub enum ExtensionValidationError {
    /// The extension referenced a role that was not declared.
    #[error("Role '{role}' not declared")]
    UndeclaredRole { role: String },
    /// The extension's protocol structure is invalid.
    #[error("Invalid protocol structure: {reason}")]
    InvalidStructure { reason: String },
    /// Any other extension-specific validation failure.
    #[error("Extension validation failed: {message}")]
    ExtensionFailed { message: String },
}
/// Registers a grammar extension with a registry and returns the outcome.
///
/// Evaluates `$extension` first, then calls
/// `$registry.register_grammar(extension)`. The macro evaluates to the
/// `Result` that `register_grammar` returns, so registration errors (for
/// example a priority conflict) can be propagated with `?` instead of being
/// silently discarded.
///
/// ```ignore
/// register_extension!(registry, MyExtension)?;
/// ```
#[macro_export]
macro_rules! register_extension {
    ($registry:expr, $extension:expr) => {{
        // Bind first to keep the original evaluation order
        // (extension before registry).
        let ext = $extension;
        $registry.register_grammar(ext)
    }};
}
/// Implemented by types that bundle a set of extensions and register them all
/// with an [`ExtensionRegistry`] in one call.
pub trait RegisterExtension {
    /// Registers this bundle's extensions (and any parsers) into `registry`.
    fn register_all(registry: &mut ExtensionRegistry);
}
pub mod discovery;
pub mod timeout;
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal grammar extension claiming a single `timeout_stmt` rule at the
    // default priority (100).
    #[derive(Debug)]
    struct MockGrammarExtension;
    impl GrammarExtension for MockGrammarExtension {
        fn grammar_rules(&self) -> &'static str {
            "timeout_stmt = { \"timeout\" ~ integer ~ protocol_block }"
        }
        fn statement_rules(&self) -> Vec<&'static str> {
            vec!["timeout_stmt"]
        }
        fn extension_id(&self) -> &'static str {
            "mock_timeout"
        }
    }

    // Registration makes the rule handleable, and composition keeps the base
    // grammar while appending the extension's rules.
    #[test]
    fn test_extension_registry() {
        let mut registry = ExtensionRegistry::new();
        registry
            .register_grammar(MockGrammarExtension)
            .expect("extension registration should succeed");
        assert!(registry.can_handle("timeout_stmt"));
        assert!(!registry.can_handle("unknown_rule"));
        let base = "basic_rule = { \"test\" }";
        let composed = registry.compose_grammar(base);
        assert!(composed.contains("basic_rule"));
        assert!(composed.contains("timeout_stmt"));
    }

    // The error messages include actionable guidance for users.
    #[test]
    fn test_enhanced_error_messages() {
        use crate::extensions::ParseError;
        let err = ParseError::PriorityConflict {
            extension1: "ext1".to_string(),
            extension2: "ext2".to_string(),
            priority1: 100,
            priority2: 100,
            rule: "test_rule".to_string(),
        };
        assert!(err.to_string().contains("Consider adjusting priorities"));
        let err = ParseError::MissingDependency {
            extension: "dependent_ext".to_string(),
            dependency: "required_ext".to_string(),
        };
        assert!(err
            .to_string()
            .contains("Please register the required extension first"));
        let err = ParseError::IncompatibleExtensions {
            details: "Test incompatibility".to_string(),
        };
        assert!(err.to_string().contains("Incompatible extensions"));
    }

    // A higher-priority extension registered after a lower-priority one takes
    // over the shared rule, and the override is reported.
    #[test]
    fn test_detailed_conflicts() {
        #[derive(Debug)]
        struct TestExt1;
        impl GrammarExtension for TestExt1 {
            fn grammar_rules(&self) -> &'static str {
                "rule1 = { \"test1\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["rule1"]
            }
            fn priority(&self) -> u32 {
                200
            }
            fn extension_id(&self) -> &'static str {
                "test_ext1"
            }
        }
        #[derive(Debug)]
        struct TestExt2;
        impl GrammarExtension for TestExt2 {
            fn grammar_rules(&self) -> &'static str {
                "rule1 = { \"test2\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["rule1"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "test_ext2"
            }
        }
        let mut registry = ExtensionRegistry::new();
        registry
            .register_grammar(TestExt2)
            .expect("lower priority extension should register");
        registry
            .register_grammar(TestExt1)
            .expect("higher priority extension should override");
        let conflicts = registry.get_detailed_conflicts();
        assert!(!conflicts.is_empty());
        assert!(conflicts[0].contains("overrode"));
        assert!(conflicts[0].contains("priority"));
    }

    // A version recorded before registration is kept; registration only fills
    // in the default version when none exists.
    #[test]
    fn test_documentation_system() {
        let mut registry = ExtensionRegistry::new();
        registry
            .extension_versions
            .insert("mock_timeout".to_string(), "1.0.0".to_string());
        registry
            .register_grammar(MockGrammarExtension)
            .expect("grammar extension should register");
        let docs = registry.generate_docs();
        assert!(docs.contains("# Extension Documentation"));
        assert!(docs.contains("mock_timeout"));
        assert!(docs.contains("**Priority:** 100"));
        assert!(docs.contains("**Version:** 1.0.0"));
        assert_eq!(
            registry.extension_versions.get("mock_timeout"),
            Some(&"1.0.0".to_string())
        );
    }

    // Equal-priority extensions are composed in id order regardless of the
    // order in which they were registered.
    #[test]
    fn test_compose_grammar_is_stable_for_equal_priorities() {
        #[derive(Debug)]
        struct AlphaExt;
        impl GrammarExtension for AlphaExt {
            fn grammar_rules(&self) -> &'static str {
                "alpha_stmt = { \"alpha\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["alpha_stmt"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "alpha_ext"
            }
        }
        #[derive(Debug)]
        struct BetaExt;
        impl GrammarExtension for BetaExt {
            fn grammar_rules(&self) -> &'static str {
                "beta_stmt = { \"beta\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["beta_stmt"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "beta_ext"
            }
        }
        let mut registry = ExtensionRegistry::new();
        registry.register_grammar(BetaExt).unwrap();
        registry.register_grammar(AlphaExt).unwrap();
        let composed = registry.compose_grammar("base = { \"x\" }");
        let alpha_idx = composed.find("alpha_stmt").unwrap();
        let beta_idx = composed.find("beta_stmt").unwrap();
        assert!(alpha_idx < beta_idx);
    }

    // ParseContext exposes the roles and input it was built with.
    #[test]
    fn test_parse_context() {
        use proc_macro2::Span;
        let roles = vec![
            Role::new(proc_macro2::Ident::new("Alice", Span::call_site())).unwrap(),
            Role::new(proc_macro2::Ident::new("Bob", Span::call_site())).unwrap(),
        ];
        let context = ParseContext {
            declared_roles: &roles,
            input: "test input",
        };
        assert_eq!(context.declared_roles.len(), 2);
        assert_eq!(context.input, "test input");
    }
}