use crate::extensions::{ExtensionRegistry, GrammarExtension};
use std::collections::HashSet;
use std::fs;
use std::path::Path;
/// Builds a single pest grammar string from the bundled base grammar plus
/// the rules contributed by registered grammar extensions, caching the result.
pub struct GrammarComposer {
/// Text of the base grammar (loaded from `choreography.pest` by `new`).
base_grammar: String,
/// Registry that owns the registered grammar extensions.
extension_registry: ExtensionRegistry,
/// Last successfully composed grammar; `None` until `compose` succeeds or
/// after the cache has been invalidated.
cached_grammar: Option<String>,
/// Hash of the registered extensions (count, ids, priorities) recorded when
/// the cache was last invalidated or filled.
extension_hash: u64,
}
impl GrammarComposer {
    /// Creates a composer seeded with the bundled `choreography.pest` grammar
    /// and an empty extension registry.
    pub fn new() -> Self {
        let base_grammar = include_str!("choreography.pest").to_string();
        Self {
            base_grammar,
            extension_registry: ExtensionRegistry::new(),
            cached_grammar: None,
            extension_hash: 0,
        }
    }

    /// Registers a grammar extension with the underlying registry.
    ///
    /// The cached grammar is dropped regardless of whether registration
    /// succeeded, so the next `compose` call always rebuilds.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the registry's `register_grammar`.
    pub fn register_extension<T: GrammarExtension + 'static>(
        &mut self,
        extension: T,
    ) -> Result<(), crate::extensions::ParseError> {
        let result = self.extension_registry.register_grammar(extension);
        self.invalidate_cache();
        result
    }

    /// Clears the cached grammar and records the current extension hash.
    fn invalidate_cache(&mut self) {
        self.cached_grammar = None;
        self.extension_hash = self.compute_extension_hash();
    }

    /// Hashes the number of registered grammar extensions followed by each
    /// extension's id and priority.
    ///
    /// The rule text of an extension is not part of the hash.
    fn compute_extension_hash(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        self.extension_registry
            .grammar_extensions()
            .count()
            .hash(&mut hasher);
        for ext in self.extension_registry.grammar_extensions() {
            ext.extension_id().hash(&mut hasher);
            ext.priority().hash(&mut hasher);
        }
        hasher.finish()
    }

    /// Returns the composed grammar, reusing the cached copy when one exists
    /// and the extension hash has not changed since it was stored.
    ///
    /// # Errors
    ///
    /// Returns a [`GrammarCompositionError`] when the base grammar fails
    /// validation, the `statement` rule cannot be located, or the composed
    /// grammar contains duplicate rule names or unbalanced braces.
    pub fn compose(&mut self) -> Result<String, GrammarCompositionError> {
        let current_hash = self.compute_extension_hash();
        if current_hash == self.extension_hash {
            if let Some(cached) = &self.cached_grammar {
                return Ok(cached.clone());
            }
        }

        let composed = self.compose_uncached()?;
        self.cached_grammar = Some(composed.clone());
        self.extension_hash = current_hash;
        Ok(composed)
    }

    /// Validates the base grammar, merges in the extension rules (if any) and
    /// validates the result. Never touches the cache.
    fn compose_uncached(&self) -> Result<String, GrammarCompositionError> {
        let mut composed = self.base_grammar.clone();
        self.validate_base_grammar_cached(&composed)?;

        // NOTE(review): `compose_grammar("")` presumably returns only the rule
        // text contributed by extensions — confirm against `ExtensionRegistry`.
        let extension_rules = self.extension_registry.compose_grammar("");
        if !extension_rules.trim().is_empty() {
            composed = self.inject_extension_rules_optimized(composed, &extension_rules)?;
        }

        self.validate_composed_grammar_cached(&composed)?;
        Ok(composed)
    }

    /// Thin entry point that forwards to the line-based injection strategy.
    fn inject_extension_rules_optimized(
        &self,
        base_grammar: String,
        extension_rules: &str,
    ) -> Result<String, GrammarCompositionError> {
        self.inject_extension_rules_via_lines(base_grammar, extension_rules)
    }

    /// Adds every extension statement rule as a `| rule` alternative just
    /// before the closing brace of the base `statement` rule, then appends the
    /// raw extension rule text under a `// Extension Rules` marker.
    ///
    /// # Errors
    ///
    /// Fails when statement rules were found but the `statement` rule bounds
    /// cannot be located in the base grammar.
    fn inject_extension_rules_via_lines(
        &self,
        base_grammar: String,
        extension_rules: &str,
    ) -> Result<String, GrammarCompositionError> {
        let extension_statements = self.extract_extension_statements_optimized(extension_rules)?;
        let mut lines: Vec<String> = base_grammar.lines().map(|line| line.to_string()).collect();

        if !extension_statements.is_empty() {
            let (stmt_start, stmt_end) = find_statement_rule_bounds(&lines)?;
            let indent = find_statement_indent(&lines, stmt_start, stmt_end);
            let insert_lines: Vec<String> = extension_statements
                .iter()
                .map(|rule| format!("{indent}| {rule}"))
                .collect();
            // An empty range at `stmt_end` inserts in front of the closing brace line.
            lines.splice(stmt_end..stmt_end, insert_lines);
        }

        let mut composed = lines.join("\n");
        composed.push('\n');
        composed.push_str("// Extension Rules\n");
        composed.push_str(extension_rules);
        Ok(composed)
    }

    /// Collects the names of all `*_stmt` rules defined in `extension_rules`.
    ///
    /// A rule is recognised from the line that opens it, whether the body
    /// continues on the same line or not, and with or without a pest rule
    /// modifier (`_`, `@`, `$`, `!`) in front of the opening brace.
    fn extract_extension_statements_optimized(
        &self,
        extension_rules: &str,
    ) -> Result<Vec<String>, GrammarCompositionError> {
        let statements = extension_rules
            .lines()
            .filter_map(Self::statement_rule_name)
            .map(str::to_string)
            .collect();
        Ok(statements)
    }

    /// Returns the rule name when `line` opens a pest rule whose name ends in
    /// `_stmt`, e.g. `foo_stmt = {`, `foo_stmt = _{` or `foo_stmt={ "x" }`.
    fn statement_rule_name(line: &str) -> Option<&str> {
        const MODIFIERS: &[char] = &['_', '@', '$', '!'];

        let (name, definition) = line.trim().split_once('=')?;
        let name = name.trim_end();
        let is_identifier = !name.is_empty()
            && name
                .chars()
                .all(|ch| ch == '_' || ch.is_ascii_alphanumeric());
        if !is_identifier || !name.ends_with("_stmt") {
            return None;
        }

        // Skip one optional modifier, then require the opening brace.
        let definition = definition.trim_start();
        let definition = definition.strip_prefix(MODIFIERS).unwrap_or(definition);
        definition.trim_start().starts_with('{').then_some(name)
    }

    /// Checks that the base grammar mentions every fragment the composer
    /// relies on.
    ///
    /// # Errors
    ///
    /// Returns `InvalidBaseGrammar` naming the first missing fragment.
    fn validate_base_grammar_cached(&self, grammar: &str) -> Result<(), GrammarCompositionError> {
        const REQUIRED_PATTERNS: &[&str] = &["statement = _{", "send_stmt", "broadcast_stmt"];

        for &pattern in REQUIRED_PATTERNS {
            if !grammar.contains(pattern) {
                return Err(GrammarCompositionError::InvalidBaseGrammar(format!(
                    "Missing required rule: {}",
                    pattern
                )));
            }
        }
        Ok(())
    }

    /// Performs lightweight sanity checks on a composed grammar.
    ///
    /// # Errors
    ///
    /// Returns `DuplicateRule` for the first rule name defined twice, or
    /// `SyntaxError` when opening and closing brace counts differ.
    fn validate_composed_grammar_cached(
        &self,
        grammar: &str,
    ) -> Result<(), GrammarCompositionError> {
        let rule_names = collect_pest_rule_names(grammar);
        // Borrow the names so only the offending one is ever cloned.
        let mut seen: HashSet<&str> = HashSet::with_capacity(rule_names.len());
        for rule_name in &rule_names {
            if !seen.insert(rule_name.as_str()) {
                return Err(GrammarCompositionError::DuplicateRule(rule_name.clone()));
            }
        }

        let (open_braces, close_braces) = count_braces_outside_quotes(grammar);
        if open_braces != close_braces {
            return Err(GrammarCompositionError::SyntaxError(
                "Unbalanced braces in composed grammar".to_string(),
            ));
        }
        Ok(())
    }

    /// Reports whether the extension registry says it can handle `rule_name`.
    pub fn has_extension_rule(&self, rule_name: &str) -> bool {
        self.extension_registry.can_handle(rule_name)
    }

    /// Number of grammar extensions currently registered.
    pub fn extension_count(&self) -> usize {
        self.extension_registry.grammar_extensions().count()
    }

    /// Composes the grammar and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns any composition error, or `IoError` when the file cannot be
    /// written.
    pub fn write_composed_grammar<P: AsRef<Path>>(
        &mut self,
        path: P,
    ) -> Result<(), GrammarCompositionError> {
        let composed = self.compose()?;
        fs::write(path, composed).map_err(|e| {
            GrammarCompositionError::IoError(format!("Failed to write grammar: {}", e))
        })
    }
}
/// Lists, in source order, the identifier of every line that looks like the
/// start of a pest rule definition (`name = ...`).
///
/// A line qualifies when, after leading whitespace, it begins with `_` or an
/// ASCII letter, continues with `_`/ASCII alphanumerics, and the next
/// non-whitespace character is `=`. Duplicates are kept.
fn collect_pest_rule_names(grammar: &str) -> Vec<String> {
    let is_name_char = |ch: char| ch == '_' || ch.is_ascii_alphanumeric();

    grammar
        .lines()
        .filter_map(|raw| {
            let candidate = raw.trim_start();
            // Blank lines and `//` comments fail this check too: neither
            // starts with an identifier character.
            candidate
                .chars()
                .next()
                .filter(|lead| *lead == '_' || lead.is_ascii_alphabetic())?;

            let name_end = candidate
                .find(|ch: char| !is_name_char(ch))
                .unwrap_or(candidate.len());
            let (name, tail) = candidate.split_at(name_end);
            tail.trim_start()
                .starts_with('=')
                .then(|| name.to_string())
        })
        .collect()
}
/// Counts `{` and `}` characters that are part of the grammar structure.
///
/// Braces are ignored when they appear inside:
/// - double-quoted string literals (backslash escapes honoured),
/// - single-quoted char literals such as `'{'` or `'"'`,
/// - `//` line comments,
/// - `/* ... */` block comments, including nested ones.
///
/// Returns `(open_braces, close_braces)`.
fn count_braces_outside_quotes(grammar: &str) -> (usize, usize) {
    /// Consumes up to and including the unescaped closing `quote`.
    fn skip_quoted(chars: &mut impl Iterator<Item = char>, quote: char) {
        let mut escaped = false;
        for ch in chars {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == quote {
                return;
            }
        }
    }

    /// Consumes the rest of a block comment whose opening `/*` was already read.
    fn skip_block_comment(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
        let mut depth = 1usize;
        while let Some(ch) = chars.next() {
            match (ch, chars.peek().copied()) {
                ('*', Some('/')) => {
                    chars.next();
                    depth -= 1;
                    if depth == 0 {
                        return;
                    }
                }
                ('/', Some('*')) => {
                    chars.next();
                    depth += 1;
                }
                _ => {}
            }
        }
    }

    let mut open_braces = 0usize;
    let mut close_braces = 0usize;
    let mut chars = grammar.chars().peekable();

    while let Some(ch) = chars.next() {
        match (ch, chars.peek().copied()) {
            ('"', _) | ('\'', _) => skip_quoted(&mut chars, ch),
            ('/', Some('/')) => {
                // Line comment: discard everything through the end of the line.
                for skipped in chars.by_ref() {
                    if skipped == '\n' {
                        break;
                    }
                }
            }
            ('/', Some('*')) => {
                chars.next();
                skip_block_comment(&mut chars);
            }
            ('{', _) => open_braces += 1,
            ('}', _) => close_braces += 1,
            _ => {}
        }
    }

    (open_braces, close_braces)
}
/// Locates the `statement` rule inside the base grammar lines.
///
/// Returns `(start, end)` where `start` is the index of the first line that
/// begins (after indentation) with `statement = _{` and `end` is the index of
/// the first later line that begins with `}`.
///
/// # Errors
///
/// Returns `InvalidBaseGrammar` when either line cannot be found.
fn find_statement_rule_bounds(lines: &[String]) -> Result<(usize, usize), GrammarCompositionError> {
    let Some(opening) = lines
        .iter()
        .position(|line| line.trim_start().starts_with("statement = _{"))
    else {
        return Err(GrammarCompositionError::InvalidBaseGrammar(
            "Could not find statement rule in base grammar".to_string(),
        ));
    };

    lines
        .iter()
        .enumerate()
        .skip(opening + 1)
        .find(|(_, line)| line.trim_start().starts_with('}'))
        .map(|(closing, _)| (opening, closing))
        .ok_or_else(|| {
            GrammarCompositionError::InvalidBaseGrammar(
                "Could not find end of statement rule in base grammar".to_string(),
            )
        })
}
/// Determines the indentation used for alternatives inside the statement rule.
///
/// Scans the lines strictly between `start` and `end` and returns the leading
/// whitespace of the first one whose content begins with `|`. Falls back to a
/// default indent when no such line exists.
fn find_statement_indent(lines: &[String], start: usize, end: usize) -> String {
    lines
        .iter()
        .take(end)
        .skip(start + 1)
        .find_map(|line| {
            let content = line.trim_start();
            content
                .starts_with('|')
                .then(|| line[..line.len() - content.len()].to_string())
        })
        .unwrap_or_else(|| " ".to_string())
}
impl Default for GrammarComposer {
/// Equivalent to `GrammarComposer::new()`.
fn default() -> Self {
Self::new()
}
}
/// Errors reported while validating, composing or writing a grammar.
#[derive(Debug, thiserror::Error)]
pub enum GrammarCompositionError {
/// The base grammar lacks a required fragment, or its `statement` rule
/// (start or closing line) could not be located.
#[error("Invalid base grammar: {0}")]
InvalidBaseGrammar(String),
/// The composed grammar defines the contained rule name more than once.
#[error("Duplicate rule name: {0}")]
DuplicateRule(String),
/// The composed grammar failed a structural check (currently: unbalanced braces).
#[error("Syntax error in composed grammar: {0}")]
SyntaxError(String),
/// Conflict between extensions. NOTE(review): not constructed anywhere in
/// the code visible here — confirm where it is raised.
#[error("Extension conflict: {0}")]
ExtensionConflict(String),
/// Writing the composed grammar to disk failed; carries the formatted I/O error.
#[error("IO error: {0}")]
IoError(String),
}
/// Consuming builder that registers extensions on a fresh `GrammarComposer`.
pub struct GrammarComposerBuilder {
/// Composer under construction; handed out by `build`.
composer: GrammarComposer,
}
impl GrammarComposerBuilder {
    /// Starts a builder wrapping a freshly created composer.
    pub fn new() -> Self {
        let composer = GrammarComposer::new();
        Self { composer }
    }

    /// Registers `extension` on the wrapped composer and hands the builder
    /// back for further chaining.
    ///
    /// # Errors
    ///
    /// Returns the registration error; the builder is consumed in that case.
    pub fn with_extension<T: GrammarExtension + 'static>(
        mut self,
        extension: T,
    ) -> Result<Self, crate::extensions::ParseError> {
        self.composer
            .register_extension(extension)
            .map(|()| self)
    }

    /// Finishes building and returns the configured composer.
    pub fn build(self) -> GrammarComposer {
        let Self { composer } = self;
        composer
    }
}
impl Default for GrammarComposerBuilder {
/// Equivalent to `GrammarComposerBuilder::new()`.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::GrammarExtension;

    /// Minimal extension that contributes a single statement rule.
    #[derive(Debug)]
    struct TestExtension;

    impl GrammarExtension for TestExtension {
        fn grammar_rules(&self) -> &'static str {
            "test_timeout_audit_stmt = { \"audit\" ~ ident ~ block }"
        }

        fn statement_rules(&self) -> Vec<&'static str> {
            vec!["test_timeout_audit_stmt"]
        }

        fn extension_id(&self) -> &'static str {
            "test_timeout"
        }
    }

    /// A new composer has no extensions and carries the bundled base grammar.
    #[test]
    fn test_grammar_composer_creation() {
        let composer = GrammarComposer::new();
        assert_eq!(composer.extension_count(), 0);
        assert!(composer.base_grammar.contains("choreography"));
        assert!(composer.base_grammar.contains("statement = _{"));
    }

    /// Registering an extension makes it countable and discoverable by rule name.
    #[test]
    fn test_extension_registration() {
        let mut composer = GrammarComposer::new();
        composer
            .register_extension(TestExtension)
            .expect("extension should register");
        assert_eq!(composer.extension_count(), 1);
        assert!(composer.has_extension_rule("test_timeout_audit_stmt"));
    }

    /// The composed grammar contains both base and extension content.
    #[test]
    fn test_grammar_composition() {
        let mut composer = GrammarComposer::new();
        composer
            .register_extension(TestExtension)
            .expect("extension should register");
        let composed = composer
            .compose()
            .expect("Grammar composition should succeed");
        assert!(composed.contains("test_timeout_audit_stmt"));
        assert!(composed.contains("choreography"));
        assert!(composed.contains("// Extension Rules"));
    }

    /// The first `compose` fills the cache; the second returns the same text
    /// without changing the recorded extension hash.
    #[test]
    fn test_grammar_caching() {
        let mut composer = GrammarComposer::new();
        composer
            .register_extension(TestExtension)
            .expect("extension should register");
        assert!(
            composer.cached_grammar.is_none(),
            "registration should leave the cache empty"
        );

        let first = composer
            .compose()
            .expect("first composition should succeed");
        assert_eq!(
            composer.cached_grammar.as_deref(),
            Some(first.as_str()),
            "first composition should populate the cache"
        );
        let hash_after_first = composer.extension_hash;

        let second = composer
            .compose()
            .expect("cached composition should succeed");
        assert_eq!(first, second);
        assert_eq!(composer.extension_hash, hash_after_first);
        assert_eq!(composer.cached_grammar.as_deref(), Some(second.as_str()));
    }

    /// The builder produces a composer with the extension already registered.
    #[test]
    fn test_builder_pattern() {
        let composer = GrammarComposerBuilder::new()
            .with_extension(TestExtension)
            .expect("test extension should register")
            .build();
        assert_eq!(composer.extension_count(), 1);
        assert!(composer.has_extension_rule("test_timeout_audit_stmt"));
    }

    /// Both validators accept the bundled grammar with no extensions.
    #[test]
    fn test_validation() {
        let mut composer = GrammarComposer::new();
        let valid_result = composer.validate_base_grammar_cached(&composer.base_grammar);
        assert!(valid_result.is_ok(), "Base grammar should be valid");

        let composed = composer
            .compose()
            .expect("base grammar should compose without extensions");
        let validation_result = composer.validate_composed_grammar_cached(&composed);
        assert!(
            validation_result.is_ok(),
            "Composed grammar should be valid"
        );
    }
}