Skip to main content

telltale_language/
extensions.rs

1//! DSL Extension System for Telltale
2//!
3//! This module provides a clean, composable system for extending choreographic DSL syntax.
4//! Extensions can add new grammar rules, custom statement parsers, and protocol behaviors
5//! while maintaining compatibility with the core choreographic infrastructure.
6
7use crate::ast::{LocalType, Role};
8use crate::compiler::projection::ProjectionError;
9use std::any::{Any, TypeId};
10use std::collections::BTreeMap;
11use std::fmt::Debug;
12
/// Human-readable documentation for an extension.
///
/// The [`Default`] value carries placeholder text for `overview` and
/// `syntax_guide` and empty lists for every other field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtensionDocumentation {
    /// Free-form overview of what the extension adds.
    pub overview: String,
    /// Free-form description of the syntax the extension introduces.
    pub syntax_guide: String,
    /// Situations the extension is intended for.
    pub use_cases: Vec<String>,
    /// Known limitations of the extension.
    pub limitations: Vec<String>,
    /// Free-form references to related extensions or material.
    pub see_also: Vec<String>,
}

impl Default for ExtensionDocumentation {
    fn default() -> Self {
        Self {
            overview: String::from("No documentation provided"),
            syntax_guide: String::from("No syntax guide provided"),
            use_cases: Vec::new(),
            limitations: Vec::new(),
            see_also: Vec::new(),
        }
    }
}
34
/// A worked example showing how to use an extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtensionExample {
    /// Short title for the example.
    pub title: String,
    /// What the example demonstrates.
    pub description: String,
    /// DSL source code of the example.
    pub code: String,
    /// Expected result of running the example, when one is stated.
    pub expected_output: Option<String>,
}
43
/// Trait for adding new grammar rules to the choreographic DSL
///
/// Implementations are stored by [`ExtensionRegistry`] as
/// `Box<dyn GrammarExtension>` and keyed by [`extension_id`](Self::extension_id).
pub trait GrammarExtension: Send + Sync + Debug {
    /// Return the Pest grammar rules this extension provides
    ///
    /// [`ExtensionRegistry::compose_grammar`] appends this text, on its own
    /// line, after the base grammar.
    fn grammar_rules(&self) -> &'static str;

    /// List of statement rule names this extension handles
    ///
    /// On successful registration each name is mapped to this extension's id,
    /// unless a higher-priority extension already owns it.
    fn statement_rules(&self) -> Vec<&'static str>;

    /// Priority for conflict resolution (higher = more precedence)
    ///
    /// When two extensions declare the same statement rule, the one with the
    /// higher priority owns it; equal priorities make the later registration
    /// fail. Defaults to `100`.
    fn priority(&self) -> u32 {
        100
    }

    /// Extension identifier for debugging and registration
    ///
    /// The registry stores the extension under this id, so it should be unique
    /// among registered extensions.
    fn extension_id(&self) -> &'static str;
}
60
61/// Trait for self-documenting extensions
62pub trait DocumentedGrammarExtension: GrammarExtension {
63    /// Documentation for this extension
64    fn documentation(&self) -> ExtensionDocumentation {
65        ExtensionDocumentation::default()
66    }
67
68    /// Examples showing how to use this extension
69    fn examples(&self) -> Vec<ExtensionExample> {
70        vec![]
71    }
72
73    /// Grammar rules with human-readable descriptions
74    fn rule_descriptions(&self) -> std::collections::HashMap<String, String> {
75        std::collections::HashMap::new()
76    }
77}
78
/// Trait for parsing custom protocol statements
///
/// Parsers are registered with [`ExtensionRegistry::register_parser`] under a
/// string id and stored as `Box<dyn StatementParser>`.
pub trait StatementParser: Send + Sync + Debug {
    /// Check if this parser can handle the given rule name
    fn can_parse(&self, rule_name: &str) -> bool;

    /// Return all rules this parser supports
    fn supported_rules(&self) -> Vec<String>;

    /// Parse a statement into a protocol extension
    ///
    /// # Arguments
    /// * `rule_name` - The grammar rule name being parsed
    /// * `content` - The matched content as a string
    /// * `context` - Parsing context with declared roles
    ///
    /// # Returns
    /// A boxed protocol extension representing the parsed statement
    ///
    /// # Errors
    /// A [`ParseError`] describing why `content` could not be parsed.
    fn parse_statement(
        &self,
        rule_name: &str,
        content: &str,
        context: &ParseContext,
    ) -> Result<Box<dyn ProtocolExtension>, ParseError>;
}
103
/// Trait for custom protocol behaviors that can be projected and validated
///
/// Values are handled as `Box<dyn ProtocolExtension>`; boxed values are
/// cloneable through [`clone_box`](Self::clone_box).
pub trait ProtocolExtension: Send + Sync + Debug {
    /// Unique identifier for this protocol extension type
    fn type_name(&self) -> &'static str;

    /// Check if this protocol mentions a specific role
    fn mentions_role(&self, role: &Role) -> bool;

    /// Validate this protocol against declared roles
    ///
    /// # Errors
    /// An [`ExtensionValidationError`] describing the first problem found.
    fn validate(&self, roles: &[Role]) -> Result<(), ExtensionValidationError>;

    /// Project this protocol to a local type for a specific role
    ///
    /// # Errors
    /// A [`ProjectionError`] if no local type can be produced for `role`.
    fn project(
        &self,
        role: &Role,
        context: &ProjectionContext,
    ) -> Result<LocalType, ProjectionError>;

    /// Generate code for this protocol extension
    fn generate_code(&self, context: &CodegenContext) -> proc_macro2::TokenStream;

    /// Borrow `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Mutable counterpart of [`as_any`](Self::as_any).
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// [`TypeId`] of the concrete implementing type.
    ///
    /// NOTE: this shares its name with `std::any::Any::type_id`. With `Any` in
    /// scope, `boxed.type_id()` on a `Box<dyn ProtocolExtension>` can resolve
    /// to `Any::type_id` of the box itself; call
    /// `ProtocolExtension::type_id(&*boxed)` to be explicit.
    fn type_id(&self) -> TypeId;
    /// Clone into a new box; backs `Clone for Box<dyn ProtocolExtension>`.
    fn clone_box(&self) -> Box<dyn ProtocolExtension>;
}
131
132impl Clone for Box<dyn ProtocolExtension> {
133    fn clone(&self) -> Self {
134        self.clone_box()
135    }
136}
137
/// Registry for managing DSL extensions with conflict resolution
///
/// All maps are `BTreeMap`s, so iteration over them is ordered by key.
#[derive(Debug, Default)]
pub struct ExtensionRegistry {
    /// Grammar extensions keyed by their `extension_id`.
    grammar_extensions: BTreeMap<String, Box<dyn GrammarExtension>>,
    /// Statement parsers keyed by the id passed to `register_parser`.
    statement_parsers: BTreeMap<String, Box<dyn StatementParser>>,
    /// Statement rule name -> id of the grammar extension that owns the rule.
    /// `find_parser` reuses this id to look up `statement_parsers`, so a parser
    /// is only found when it is registered under its grammar extension's id.
    rule_to_parser: BTreeMap<String, String>,
    /// Track rule conflicts for resolution: rule name -> ids of extensions
    /// that lost the rule to a higher-priority extension.
    rule_conflicts: BTreeMap<String, Vec<String>>,
    /// Extension dependencies: dependent extension id -> required extension ids
    extension_dependencies: BTreeMap<String, Vec<String>>,
    /// Extension version information for compatibility checking, keyed by
    /// extension id
    extension_versions: BTreeMap<String, String>,
}
151
152impl ExtensionRegistry {
153    /// Create a new empty extension registry
154    pub fn new() -> Self {
155        Self::default()
156    }
157
158    /// Register a grammar extension with conflict detection
159    pub fn register_grammar<T: GrammarExtension + 'static>(
160        &mut self,
161        extension: T,
162    ) -> Result<(), ParseError> {
163        let id = extension.extension_id().to_string();
164        let rules = extension.statement_rules();
165        let priority = extension.priority();
166
167        // Check for conflicts and resolve by priority
168        for rule in &rules {
169            if let Some(existing_id) = self.rule_to_parser.get(*rule) {
170                let existing_priority = self
171                    .grammar_extensions
172                    .get(existing_id)
173                    .map(|e| e.priority())
174                    .unwrap_or(0);
175
176                if priority > existing_priority {
177                    // New extension wins, record conflict
178                    self.rule_conflicts
179                        .entry((*rule).to_string())
180                        .or_default()
181                        .push(existing_id.clone());
182                    self.rule_to_parser.insert((*rule).to_string(), id.clone());
183                } else if priority == existing_priority {
184                    // Equal priority - this is a conflict
185                    return Err(ParseError::PriorityConflict {
186                        extension1: existing_id.clone(),
187                        extension2: id.clone(),
188                        priority1: existing_priority,
189                        priority2: priority,
190                        rule: (*rule).to_string(),
191                    });
192                }
193                // Lower priority - existing extension wins
194            } else {
195                self.rule_to_parser.insert((*rule).to_string(), id.clone());
196            }
197        }
198
199        self.grammar_extensions
200            .insert(id.clone(), Box::new(extension));
201        // Set default version if not specified
202        self.extension_versions
203            .entry(id)
204            .or_insert_with(|| "0.1.0".to_string());
205        Ok(())
206    }
207
208    /// Register a statement parser
209    pub fn register_parser<T: StatementParser + 'static>(&mut self, parser: T, parser_id: String) {
210        self.statement_parsers.insert(parser_id, Box::new(parser));
211    }
212
213    /// Get all grammar rules from registered extensions
214    pub fn compose_grammar(&self, base_grammar: &str) -> String {
215        let mut composed = base_grammar.to_string();
216
217        // Sort extensions by priority (highest first)
218        let mut extensions: Vec<_> = self.grammar_extensions.iter().collect();
219        extensions.sort_by(|(id_a, ext_a), (id_b, ext_b)| {
220            std::cmp::Reverse(ext_a.priority())
221                .cmp(&std::cmp::Reverse(ext_b.priority()))
222                .then_with(|| id_a.cmp(id_b))
223        });
224
225        for (_, extension) in extensions {
226            composed.push('\n');
227            composed.push_str(extension.grammar_rules());
228        }
229
230        composed
231    }
232
233    /// Find parser for a given rule name
234    pub fn find_parser(&self, rule_name: &str) -> Option<&dyn StatementParser> {
235        if let Some(parser_id) = self.rule_to_parser.get(rule_name) {
236            self.statement_parsers.get(parser_id).map(|p| p.as_ref())
237        } else {
238            None
239        }
240    }
241
242    /// Check if a rule is handled by an extension
243    pub fn can_handle(&self, rule_name: &str) -> bool {
244        self.rule_to_parser.contains_key(rule_name)
245    }
246
247    /// Check if any extensions are registered
248    pub fn has_extensions(&self) -> bool {
249        !self.grammar_extensions.is_empty() || !self.statement_parsers.is_empty()
250    }
251
252    /// Get all grammar extensions
253    pub fn grammar_extensions(&self) -> impl Iterator<Item = &dyn GrammarExtension> {
254        let mut ordered: Vec<_> = self.grammar_extensions.iter().collect();
255        ordered.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
256        ordered.into_iter().map(|(_, e)| e.as_ref())
257    }
258
259    /// Check if a specific extension is registered
260    pub fn has_extension(&self, extension_id: &str) -> bool {
261        self.grammar_extensions.contains_key(extension_id)
262    }
263
264    /// Get parser for a rule name
265    pub fn get_parser_for_rule(&self, rule_name: &str) -> Option<&str> {
266        self.rule_to_parser.get(rule_name).map(String::as_str)
267    }
268
269    /// Get statement parser by ID
270    pub fn get_statement_parser(&self, parser_id: &str) -> Option<&dyn StatementParser> {
271        self.statement_parsers.get(parser_id).map(|p| p.as_ref())
272    }
273
274    /// Get the number of registered statement parsers.
275    pub fn statement_parser_count(&self) -> usize {
276        self.statement_parsers.len()
277    }
278
279    /// Get the registered extension statement rules in stable order.
280    pub fn statement_rules(&self) -> Vec<&str> {
281        let mut rules: Vec<_> = self.rule_to_parser.keys().map(String::as_str).collect();
282        rules.sort_unstable();
283        rules
284    }
285
286    /// Add dependency between extensions
287    pub fn add_dependency(&mut self, dependent: &str, required: &str) {
288        self.extension_dependencies
289            .entry(dependent.to_string())
290            .or_default()
291            .push(required.to_string());
292    }
293
294    /// Validate all extension dependencies are satisfied
295    pub fn validate_dependencies(&self) -> Result<(), ParseError> {
296        for (dependent, requirements) in &self.extension_dependencies {
297            for required in requirements {
298                if !self.grammar_extensions.contains_key(required) {
299                    return Err(ParseError::MissingDependency {
300                        extension: dependent.clone(),
301                        dependency: required.clone(),
302                    });
303                }
304            }
305        }
306        Ok(())
307    }
308
309    /// Get all rule conflicts for debugging
310    pub fn get_conflicts(&self) -> &BTreeMap<String, Vec<String>> {
311        &self.rule_conflicts
312    }
313
314    /// Get detailed conflict information with resolution suggestions
315    pub fn get_detailed_conflicts(&self) -> Vec<String> {
316        let mut details = Vec::new();
317        let unknown_ext = "unknown".to_string();
318
319        let mut conflicts: Vec<_> = self.rule_conflicts.iter().collect();
320        conflicts.sort_by(|(rule_a, _), (rule_b, _)| rule_a.cmp(rule_b));
321
322        for (rule, conflicting_extensions) in conflicts {
323            if !conflicting_extensions.is_empty() {
324                let active_extension = self.rule_to_parser.get(rule).unwrap_or(&unknown_ext);
325                let active_priority = self
326                    .grammar_extensions
327                    .get(active_extension)
328                    .map(|e| e.priority())
329                    .unwrap_or(0);
330
331                let mut conflicting_extensions = conflicting_extensions.clone();
332                conflicting_extensions.sort();
333
334                for conflicting in &conflicting_extensions {
335                    let conflicting_priority = self
336                        .grammar_extensions
337                        .get(conflicting)
338                        .map(|e| e.priority())
339                        .unwrap_or(0);
340
341                    details.push(format!(
342                        "Rule '{}': Extension '{}' (priority {}) overrode '{}' (priority {}). \
343                         To resolve: 1) Adjust priorities, 2) Use different rule names, or 3) Merge functionality.",
344                        rule, active_extension, active_priority, conflicting, conflicting_priority
345                    ));
346                }
347            }
348        }
349
350        details
351    }
352
353    /// Check extension compatibility
354    pub fn check_compatibility(&self, extension_ids: &[&str]) -> Result<(), ParseError> {
355        // Check for direct conflicts between the specified extensions
356        let mut rules_used = BTreeMap::new();
357
358        for &extension_id in extension_ids {
359            if let Some(extension) = self.grammar_extensions.get(extension_id) {
360                for rule in extension.statement_rules() {
361                    if let Some(existing) = rules_used.get(rule) {
362                        if existing != &extension_id {
363                            return Err(ParseError::IncompatibleExtensions {
364                                details: format!(
365                                    "Extensions '{}' and '{}' both define rule '{}'. Use different rule names or register extensions with different priorities.",
366                                    existing, extension_id, rule
367                                ),
368                            });
369                        }
370                    }
371                    rules_used.insert(rule.to_string(), extension_id);
372                }
373            }
374        }
375        Ok(())
376    }
377
378    /// Create a registry with built-in extensions
379    pub fn with_builtin_extensions() -> Self {
380        let mut registry = Self::new();
381
382        // Register timeout extension
383        registry
384            .register_grammar(timeout::TimeoutGrammarExtension)
385            .expect("builtin timeout extension should register successfully");
386        registry.register_parser(timeout::TimeoutStatementParser, "timeout".to_string());
387
388        registry
389    }
390
391    /// Create a minimal registry for 3rd party integration
392    pub fn for_third_party() -> Self {
393        Self::new()
394    }
395
396    /// Generate basic documentation for all registered extensions
397    pub fn generate_docs(&self) -> String {
398        let mut docs = String::from("# Extension Documentation\n\n");
399
400        let mut entries: Vec<_> = self.grammar_extensions.iter().collect();
401        entries.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
402
403        for (id, extension) in entries {
404            docs.push_str(&format!("## {}\n\n", id));
405            docs.push_str(&format!("**Priority:** {}\n\n", extension.priority()));
406            docs.push_str(&format!(
407                "**Rules:** {}\n\n",
408                extension.statement_rules().join(", ")
409            ));
410
411            if let Some(version) = self.extension_versions.get(id) {
412                docs.push_str(&format!("**Version:** {}\n\n", version));
413            }
414
415            docs.push_str("**Grammar:**\n```\n");
416            docs.push_str(extension.grammar_rules());
417            docs.push_str("\n```\n\n");
418        }
419
420        docs
421    }
422}
423
/// Context provided during statement parsing
///
/// Passed by reference to [`StatementParser::parse_statement`]. It only
/// borrows its data.
#[derive(Debug)]
pub struct ParseContext<'a> {
    /// Roles declared in the choreography
    pub declared_roles: &'a [Role],
    /// Original input string for error reporting
    pub input: &'a str,
}
432
/// Context provided during projection
///
/// Passed by reference to [`ProtocolExtension::project`]. It only borrows its
/// data.
#[derive(Debug)]
pub struct ProjectionContext<'a> {
    /// All roles in the choreography
    pub all_roles: &'a [Role],
    /// Current role being projected
    pub current_role: &'a Role,
}
441
442/// Context provided during code generation
443#[derive(Debug)]
444pub struct CodegenContext<'a> {
445    /// The choreography being generated
446    pub choreography_name: &'a str,
447    /// All roles in the choreography
448    pub roles: &'a [Role],
449    /// Namespace for generated code
450    pub namespace: Option<&'a str>,
451}
452
453impl<'a> Default for CodegenContext<'a> {
454    fn default() -> Self {
455        Self {
456            choreography_name: "Default",
457            roles: &[],
458            namespace: None,
459        }
460    }
461}
462
/// Errors that can occur during extension parsing
///
/// [`ExtensionRegistry`] also uses this type for registration,
/// compatibility and dependency failures.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// Generic syntax error with a free-form message.
    #[error("Syntax error: {message}")]
    Syntax { message: String },

    /// A role was used that is not known to the extension.
    #[error("Unknown role '{role}' used in extension")]
    UnknownRole { role: String },

    /// Extension-specific syntax was malformed.
    #[error("Invalid extension syntax: {details}")]
    InvalidSyntax { details: String },

    /// Generic extension conflict with a free-form message.
    #[error("Extension conflict: {message}")]
    Conflict { message: String },

    /// Two extensions with the same priority declare the same statement rule;
    /// returned by `ExtensionRegistry::register_grammar`. `extension1` is the
    /// existing owner and `extension2` the extension being registered.
    #[error("Extension priority conflict: Extension '{extension1}' (priority {priority1}) conflicts with '{extension2}' (priority {priority2}) for rule '{rule}'. Consider adjusting priorities or using different rule names.")]
    PriorityConflict {
        extension1: String,
        extension2: String,
        priority1: u32,
        priority2: u32,
        rule: String,
    },

    /// A recorded dependency is not a registered grammar extension; returned
    /// by `ExtensionRegistry::validate_dependencies`.
    #[error("Missing dependency: Extension '{extension}' requires '{dependency}' which is not registered. Please register the required extension first.")]
    MissingDependency {
        extension: String,
        dependency: String,
    },

    /// An extension could not be registered for the given rule.
    #[error("Extension registration failed: Extension '{extension}' with rule '{rule}' cannot be registered. {details}")]
    RegistrationFailed {
        extension: String,
        rule: String,
        details: String,
    },

    /// Two of the checked extensions share a statement rule; returned by
    /// `ExtensionRegistry::check_compatibility`.
    #[error("Incompatible extensions: {details}")]
    IncompatibleExtensions { details: String },
}
503
/// Validation errors for protocol extensions
///
/// Returned by [`ProtocolExtension::validate`].
#[derive(Debug, thiserror::Error)]
pub enum ExtensionValidationError {
    /// The protocol refers to a role that was not declared.
    #[error("Role '{role}' not declared")]
    UndeclaredRole { role: String },

    /// The protocol's structure is invalid for the stated reason.
    #[error("Invalid protocol structure: {reason}")]
    InvalidStructure { reason: String },

    /// Extension-specific validation failure with a free-form message.
    #[error("Extension validation failed: {message}")]
    ExtensionFailed { message: String },
}
516
/// Convenience macro for registering a grammar extension.
///
/// Expands to `$registry.register_grammar($extension)` and evaluates to its
/// `Result`, so registration errors (for example a `ParseError::PriorityConflict`)
/// can be propagated with `?` or inspected instead of being silently dropped.
/// Each argument is evaluated exactly once.
///
/// ```ignore
/// register_extension!(registry, MyExtension)?;
/// ```
#[macro_export]
macro_rules! register_extension {
    ($registry:expr, $extension:expr) => {
        $registry.register_grammar($extension)
    };
}
526
/// Utility trait for easy extension registration
///
/// Implementors register everything they provide into `registry` in one call.
/// The method returns `()`, so implementations must handle any error returned
/// by `ExtensionRegistry::register_grammar` themselves.
pub trait RegisterExtension {
    fn register_all(registry: &mut ExtensionRegistry);
}
531
532pub mod discovery;
533/// Built-in extensions
534pub mod timeout;
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Grammar extension whose every property is fixed at construction time,
    /// so each test can describe its fixtures as plain data.
    #[derive(Debug)]
    struct FixedGrammar {
        id: &'static str,
        grammar: &'static str,
        rules: &'static [&'static str],
        priority: u32,
    }

    impl GrammarExtension for FixedGrammar {
        fn grammar_rules(&self) -> &'static str {
            self.grammar
        }

        fn statement_rules(&self) -> Vec<&'static str> {
            self.rules.to_vec()
        }

        fn priority(&self) -> u32 {
            self.priority
        }

        fn extension_id(&self) -> &'static str {
            self.id
        }
    }

    /// Stand-in for the timeout extension at the default priority (100).
    const MOCK_TIMEOUT: FixedGrammar = FixedGrammar {
        id: "mock_timeout",
        grammar: "timeout_stmt = { \"timeout\" ~ integer ~ protocol_block }",
        rules: &["timeout_stmt"],
        priority: 100,
    };

    #[test]
    fn test_extension_registry() {
        let mut registry = ExtensionRegistry::new();
        registry
            .register_grammar(MOCK_TIMEOUT)
            .expect("extension registration should succeed");

        // Only the declared rule is claimed.
        assert!(registry.can_handle("timeout_stmt"));
        assert!(!registry.can_handle("unknown_rule"));

        // The composed grammar keeps the base and appends the extension.
        let grammar = registry.compose_grammar("basic_rule = { \"test\" }");
        for needle in ["basic_rule", "timeout_stmt"] {
            assert!(grammar.contains(needle));
        }
    }

    #[test]
    fn test_enhanced_error_messages() {
        use crate::extensions::ParseError;

        let cases = [
            (
                ParseError::PriorityConflict {
                    extension1: "ext1".to_string(),
                    extension2: "ext2".to_string(),
                    priority1: 100,
                    priority2: 100,
                    rule: "test_rule".to_string(),
                },
                "Consider adjusting priorities",
            ),
            (
                ParseError::MissingDependency {
                    extension: "dependent_ext".to_string(),
                    dependency: "required_ext".to_string(),
                },
                "Please register the required extension first",
            ),
            (
                ParseError::IncompatibleExtensions {
                    details: "Test incompatibility".to_string(),
                },
                "Incompatible extensions",
            ),
        ];

        for (error, expected_fragment) in cases {
            assert!(error.to_string().contains(expected_fragment));
        }
    }

    #[test]
    fn test_detailed_conflicts() {
        let low = FixedGrammar {
            id: "test_ext2",
            grammar: "rule1 = { \"test2\" }",
            rules: &["rule1"],
            priority: 100,
        };
        let high = FixedGrammar {
            id: "test_ext1",
            grammar: "rule1 = { \"test1\" }",
            rules: &["rule1"],
            priority: 200,
        };

        let mut registry = ExtensionRegistry::new();
        // The lower-priority extension claims the rule first...
        registry
            .register_grammar(low)
            .expect("lower priority extension should register");
        // ...and is then overridden by the higher-priority one.
        registry
            .register_grammar(high)
            .expect("higher priority extension should override");

        let report = registry.get_detailed_conflicts();
        assert!(!report.is_empty());
        let first = &report[0];
        assert!(first.contains("overrode"));
        assert!(first.contains("priority"));
    }

    #[test]
    fn test_documentation_system() {
        let mut registry = ExtensionRegistry::new();

        // A version recorded before registration must survive it.
        registry
            .extension_versions
            .insert("mock_timeout".to_string(), "1.0.0".to_string());
        registry
            .register_grammar(MOCK_TIMEOUT)
            .expect("grammar extension should register");

        let docs = registry.generate_docs();
        for fragment in [
            "# Extension Documentation",
            "mock_timeout",
            "**Priority:** 100",
            "**Version:** 1.0.0",
        ] {
            assert!(docs.contains(fragment));
        }

        assert_eq!(
            registry
                .extension_versions
                .get("mock_timeout")
                .map(String::as_str),
            Some("1.0.0")
        );
    }

    #[test]
    fn test_compose_grammar_is_stable_for_equal_priorities() {
        let mut registry = ExtensionRegistry::new();
        // Registered in reverse alphabetical order on purpose.
        registry
            .register_grammar(FixedGrammar {
                id: "beta_ext",
                grammar: "beta_stmt = { \"beta\" }",
                rules: &["beta_stmt"],
                priority: 100,
            })
            .unwrap();
        registry
            .register_grammar(FixedGrammar {
                id: "alpha_ext",
                grammar: "alpha_stmt = { \"alpha\" }",
                rules: &["alpha_stmt"],
                priority: 100,
            })
            .unwrap();

        let composed = registry.compose_grammar("base = { \"x\" }");
        let offset_of = |needle: &str| composed.find(needle).unwrap();
        assert!(offset_of("alpha_stmt") < offset_of("beta_stmt"));
    }

    #[test]
    fn test_parse_context() {
        use proc_macro2::{Ident, Span};

        let roles: Vec<Role> = ["Alice", "Bob"]
            .iter()
            .map(|&name| Role::new(Ident::new(name, Span::call_site())).unwrap())
            .collect();

        let context = ParseContext {
            declared_roles: &roles,
            input: "test input",
        };

        assert_eq!(context.declared_roles.len(), 2);
        assert_eq!(context.input, "test input");
    }
}