// adze_ir/optimizer.rs
//! Grammar optimization passes that simplify and normalize rules.
//!
//! This module implements various optimization passes (unused-symbol removal,
//! simple-rule inlining, equivalent-token merging, left-recursion rewriting,
//! unit-rule elimination, and symbol renumbering) to improve parser
//! performance.
6#[cfg(test)]
7use crate::Token;
8use crate::{Grammar, ProductionId, Rule, Symbol, SymbolId, TokenPattern};
9use indexmap::IndexMap;
10use std::collections::{HashMap, HashSet};
11
// In release builds (`debug_assertions` off) the tracing macro expands to
// nothing: the arguments are discarded unevaluated, so call sites have zero
// runtime cost.
#[cfg(not(debug_assertions))]
macro_rules! debug_trace {
    ($($arg:tt)*) => {};
}
16
// In debug builds, forward the message to stderr whenever the RUST_LOG
// environment variable mentions "debug".
//
// NOTE: the environment variable is re-read on every invocation; callers on
// hot paths should not rely on this being free even when tracing is off.
#[cfg(debug_assertions)]
macro_rules! debug_trace {
    ($($arg:tt)*) => {
        // `is_ok_and` checks the var without allocating a default String for
        // the unset/non-unicode case (Err simply yields `false`).
        if std::env::var("RUST_LOG").is_ok_and(|v| v.contains("debug")) {
            eprintln!($($arg)*);
        }
    };
}
29
/// Grammar optimizer that applies various optimization passes.
///
/// Create one with [`GrammarOptimizer::new`], then call `optimize` once per
/// grammar; the analysis phase fills the sets below, and the later passes
/// consume them.
pub struct GrammarOptimizer {
    /// Symbols found to be referenced anywhere (rule LHS, RHS symbols,
    /// start symbol, `source_file`).
    used_symbols: HashSet<SymbolId>,
    /// Rules simple enough to inline: a single-symbol, non-recursive RHS.
    inlinable_rules: HashSet<SymbolId>,
    /// Rules whose first RHS symbol is their own LHS (direct left recursion),
    /// handled specially by the left-recursion pass.
    left_recursive_rules: HashSet<SymbolId>,
    /// The `source_file` symbol, if present; it is never inlined and always
    /// kept alive (Tree-sitter compatibility).
    source_file_id: Option<SymbolId>,
}
41
42impl Default for GrammarOptimizer {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl GrammarOptimizer {
49    /// Create a new optimizer
50    pub fn new() -> Self {
51        GrammarOptimizer {
52            used_symbols: HashSet::new(),
53            inlinable_rules: HashSet::new(),
54            left_recursive_rules: HashSet::new(),
55            source_file_id: None,
56        }
57    }
58
59    /// Optimize a grammar by applying all optimization passes
60    pub fn optimize(&mut self, grammar: &mut Grammar) -> OptimizationStats {
61        let mut stats = OptimizationStats::default();
62
63        // Check source_file status after each optimization
64        let check_source_file = |grammar: &Grammar, phase: &str| {
65            if let Some(sf_id) = grammar.find_symbol_by_name("source_file") {
66                let has_rules = grammar.rules.contains_key(&sf_id);
67                let rule_count = grammar.rules.get(&sf_id).map(|r| r.len()).unwrap_or(0);
68                debug_trace!(
69                    "Debug after {}: source_file is SymbolId({}), has_rules={}, rule_count={}",
70                    phase,
71                    sf_id.0,
72                    has_rules,
73                    rule_count
74                );
75            } else {
76                debug_trace!("Debug after {}: source_file not found!", phase);
77            }
78        };
79
80        // Phase 1: Analysis
81        self.analyze_grammar(grammar);
82        check_source_file(grammar, "analysis");
83
84        // Phase 2: Optimizations
85        stats.removed_unused_symbols = self.remove_unused_symbols(grammar);
86        check_source_file(grammar, "remove_unused_symbols");
87
88        stats.inlined_rules = self.inline_simple_rules(grammar);
89        check_source_file(grammar, "inline_simple_rules");
90
91        stats.merged_tokens = self.merge_equivalent_tokens(grammar);
92        check_source_file(grammar, "merge_equivalent_tokens");
93
94        stats.optimized_left_recursion = self.optimize_left_recursion(grammar);
95        check_source_file(grammar, "optimize_left_recursion");
96
97        stats.eliminated_unit_rules = self.eliminate_unit_rules(grammar);
98        check_source_file(grammar, "eliminate_unit_rules");
99
100        // Phase 3: Cleanup
101        self.renumber_symbols(grammar);
102        check_source_file(grammar, "renumber_symbols");
103
104        stats
105    }
106
107    /// Mark a symbol and all its sub-symbols as used
108    fn mark_used_in_symbol(&mut self, symbol: &Symbol) {
109        match symbol {
110            Symbol::Terminal(id) | Symbol::NonTerminal(id) | Symbol::External(id) => {
111                self.used_symbols.insert(*id);
112            }
113            Symbol::Optional(inner) | Symbol::Repeat(inner) | Symbol::RepeatOne(inner) => {
114                self.mark_used_in_symbol(inner);
115            }
116            Symbol::Choice(choices) => {
117                for s in choices {
118                    self.mark_used_in_symbol(s);
119                }
120            }
121            Symbol::Sequence(seq) => {
122                for s in seq {
123                    self.mark_used_in_symbol(s);
124                }
125            }
126            Symbol::Epsilon => {}
127        }
128    }
129
130    /// Analyze the grammar to collect information for optimization
131    fn analyze_grammar(&mut self, grammar: &Grammar) {
132        // Mark start symbol as used
133        if let Some(start_symbol) = grammar.start_symbol() {
134            self.used_symbols.insert(start_symbol);
135        }
136
137        // Always mark source_file as used if it exists (Tree-sitter compatibility)
138        if let Some(source_file_id) = grammar.find_symbol_by_name("source_file") {
139            self.source_file_id = Some(source_file_id);
140            self.used_symbols.insert(source_file_id);
141
142            // Also mark symbols referenced by source_file
143            if let Some(rules) = grammar.rules.get(&source_file_id) {
144                for rule in rules {
145                    for symbol in &rule.rhs {
146                        match symbol {
147                            Symbol::NonTerminal(id) => {
148                                self.used_symbols.insert(*id);
149                            }
150                            Symbol::Terminal(id) => {
151                                self.used_symbols.insert(*id);
152                            }
153                            _ => {
154                                // Other symbol types don't directly reference IDs that need marking
155                            }
156                        }
157                    }
158                }
159            }
160        }
161
162        // Also mark all rule LHS as used (they define the symbols)
163        for symbol_id in grammar.rules.keys() {
164            self.used_symbols.insert(*symbol_id);
165        }
166
167        // Analyze all rules
168        for rules in grammar.rules.values() {
169            for rule in rules {
170                // Mark symbols used in productions
171                for symbol in &rule.rhs {
172                    match symbol {
173                        Symbol::Terminal(id) | Symbol::NonTerminal(id) | Symbol::External(id) => {
174                            self.used_symbols.insert(*id);
175                        }
176                        Symbol::Optional(inner)
177                        | Symbol::Repeat(inner)
178                        | Symbol::RepeatOne(inner) => {
179                            self.mark_used_in_symbol(inner);
180                        }
181                        Symbol::Choice(choices) => {
182                            for s in choices {
183                                self.mark_used_in_symbol(s);
184                            }
185                        }
186                        Symbol::Sequence(seq) => {
187                            for s in seq {
188                                self.mark_used_in_symbol(s);
189                            }
190                        }
191                        Symbol::Epsilon => {}
192                    }
193                }
194
195                // Check if rule is inlinable (simple, non-recursive)
196                // Never inline source_file as it's the start symbol
197                if rule.rhs.len() == 1
198                    && !self.is_recursive_rule(rule, grammar)
199                    && Some(rule.lhs) != self.source_file_id
200                {
201                    self.inlinable_rules.insert(rule.lhs);
202                } else if Some(rule.lhs) == self.source_file_id && rule.rhs.len() == 1 {
203                    // source_file is not inlinable
204                }
205
206                // Check for left recursion
207                if self.is_left_recursive(rule) {
208                    self.left_recursive_rules.insert(rule.lhs);
209                }
210            }
211        }
212
213        // Note: We don't mark tokens as used here - they need to be referenced in rules
214    }
215
216    /// Remove symbols that are never referenced
217    fn remove_unused_symbols(&mut self, grammar: &mut Grammar) -> usize {
218        let mut removed = 0;
219
220        // Remove unused rules
221        let unused_rules: Vec<_> = grammar
222            .rules
223            .iter()
224            .filter(|(id, _)| !self.used_symbols.contains(id))
225            .map(|(id, _)| *id)
226            .collect();
227
228        for id in unused_rules {
229            grammar.rules.shift_remove(&id);
230            removed += 1;
231        }
232
233        // Remove unused tokens
234        let unused_tokens: Vec<_> = grammar
235            .tokens
236            .iter()
237            .filter(|(id, _)| !self.used_symbols.contains(id))
238            .map(|(id, _)| *id)
239            .collect();
240
241        for id in unused_tokens {
242            grammar.tokens.shift_remove(&id);
243            removed += 1;
244        }
245
246        removed
247    }
248
249    /// Inline simple rules that just reference another symbol
250    fn inline_simple_rules(&mut self, grammar: &mut Grammar) -> usize {
251        // Process inlinable rules
252
253        let mut inlined = 0;
254        let mut replacements = HashMap::new();
255
256        // Find rules to inline
257        for (symbol_id, rules) in &grammar.rules {
258            if self.inlinable_rules.contains(symbol_id) {
259                // Only inline if all rules for this symbol have exactly one RHS symbol
260                if rules.len() == 1
261                    && rules[0].rhs.len() == 1
262                    && let Some(target) = rules[0].rhs.first()
263                {
264                    replacements.insert(*symbol_id, target.clone());
265                }
266            } else if Some(*symbol_id) == self.source_file_id {
267                // source_file is not inlined
268            }
269        }
270
271        // Apply replacements
272        for rules in grammar.rules.values_mut() {
273            for rule in rules.iter_mut() {
274                let mut modified = false;
275                for symbol in &mut rule.rhs {
276                    if let Symbol::NonTerminal(id) = symbol
277                        && let Some(replacement) = replacements.get(id)
278                    {
279                        *symbol = replacement.clone();
280                        modified = true;
281                    }
282                }
283                if modified {
284                    inlined += 1;
285                }
286            }
287        }
288
289        // Remove inlined rules
290        for id in replacements.keys() {
291            grammar.rules.shift_remove(id);
292            grammar.inline_rules.push(*id);
293        }
294
295        inlined
296    }
297
298    /// Merge tokens with identical patterns
299    fn merge_equivalent_tokens(&mut self, grammar: &mut Grammar) -> usize {
300        let mut merged = 0;
301        let mut pattern_to_id: HashMap<String, SymbolId> = HashMap::new();
302        let mut replacements: HashMap<SymbolId, SymbolId> = HashMap::new();
303
304        // Find equivalent tokens
305        for (id, token) in &grammar.tokens {
306            let pattern_str = match &token.pattern {
307                TokenPattern::String(s) => s.clone(),
308                TokenPattern::Regex(r) => r.clone(),
309            };
310
311            if let Some(&existing_id) = pattern_to_id.get(&pattern_str) {
312                // Found duplicate
313                replacements.insert(*id, existing_id);
314                merged += 1;
315            } else {
316                pattern_to_id.insert(pattern_str, *id);
317            }
318        }
319
320        // Apply replacements in rules
321        for rules in grammar.rules.values_mut() {
322            for rule in rules.iter_mut() {
323                for symbol in &mut rule.rhs {
324                    if let Symbol::Terminal(id) = symbol
325                        && let Some(&new_id) = replacements.get(id)
326                    {
327                        *symbol = Symbol::Terminal(new_id);
328                    }
329                }
330            }
331        }
332
333        // Remove duplicate tokens
334        for old_id in replacements.keys() {
335            grammar.tokens.shift_remove(old_id);
336        }
337
338        merged
339    }
340
341    /// Optimize left-recursive rules by transforming them
342    fn optimize_left_recursion(&mut self, grammar: &mut Grammar) -> usize {
343        let mut optimized = 0;
344
345        // For each left-recursive rule, transform it
346        // A -> A α | β becomes:
347        // A -> β A'
348        // A' -> α A' | ε
349        let left_recursive: Vec<_> = self.left_recursive_rules.iter().cloned().collect();
350
351        for symbol in left_recursive {
352            if let Some(rules) = self.extract_rules_for_symbol(grammar, symbol) {
353                let (recursive_rules, base_rules) = self.partition_recursive_rules(&rules, symbol);
354
355                if !recursive_rules.is_empty() && !base_rules.is_empty() {
356                    // Create new symbol for the recursive part
357                    let new_symbol = self.create_new_symbol(grammar);
358
359                    // Transform the rules
360                    self.transform_left_recursion(
361                        grammar,
362                        symbol,
363                        new_symbol,
364                        recursive_rules,
365                        base_rules,
366                    );
367
368                    optimized += 1;
369                }
370            }
371        }
372
373        optimized
374    }
375
376    /// Eliminate unit rules (A -> B)
377    fn eliminate_unit_rules(&mut self, grammar: &mut Grammar) -> usize {
378        let mut eliminated = 0;
379        let mut unit_rules = Vec::new();
380
381        // Get the start symbol to prevent creating terminal productions for it
382        let start_symbol = grammar.start_symbol();
383
384        // Find unit rules
385        for rule in grammar.all_rules() {
386            if rule.rhs.len() == 1
387                && let Symbol::NonTerminal(_) = &rule.rhs[0]
388            {
389                unit_rules.push(rule.clone());
390            }
391        }
392
393        // For each unit rule A -> B, add rules A -> γ for each B -> γ
394        let mut new_rules = Vec::new();
395        for unit_rule in unit_rules {
396            if let Symbol::NonTerminal(target) = &unit_rule.rhs[0] {
397                if let Some(target_rules) = grammar.get_rules_for_symbol(*target) {
398                    for target_rule in target_rules {
399                        // Skip if this would create a terminal production for the start symbol
400                        if Some(unit_rule.lhs) == start_symbol
401                            && target_rule
402                                .rhs
403                                .iter()
404                                .any(|s| matches!(s, Symbol::Terminal(_)))
405                        {
406                            continue;
407                        }
408
409                        // Create new rule A -> γ
410                        let new_rule = Rule {
411                            lhs: unit_rule.lhs,
412                            rhs: target_rule.rhs.clone(),
413                            precedence: target_rule.precedence.or(unit_rule.precedence),
414                            associativity: target_rule.associativity.or(unit_rule.associativity),
415                            fields: target_rule.fields.clone(),
416                            production_id: self.create_new_production_id(grammar),
417                        };
418                        new_rules.push(new_rule);
419                        eliminated += 1;
420                    }
421                }
422                // Remove the unit rule from the appropriate symbol's rules
423                if let Some(symbol_rules) = grammar.rules.get_mut(&unit_rule.lhs) {
424                    symbol_rules.retain(|r| r.production_id != unit_rule.production_id);
425                    if symbol_rules.is_empty() {
426                        grammar.rules.shift_remove(&unit_rule.lhs);
427                    }
428                }
429            }
430        }
431
432        // Add all new rules
433        for rule in new_rules {
434            grammar.add_rule(rule);
435        }
436
437        eliminated
438    }
439
440    /// Check if a rule is recursive
441    fn is_recursive_rule(&self, rule: &Rule, grammar: &Grammar) -> bool {
442        let mut visited = HashSet::new();
443        self.contains_symbol_recursive(&rule.rhs, rule.lhs, grammar, &mut visited)
444    }
445
446    /// Check if a rule is left-recursive
447    fn is_left_recursive(&self, rule: &Rule) -> bool {
448        if let Some(Symbol::NonTerminal(id)) = rule.rhs.first() {
449            *id == rule.lhs
450        } else {
451            false
452        }
453    }
454
455    /// Recursively check if symbols contain a target symbol
456    #[allow(clippy::only_used_in_recursion)]
457    fn contains_symbol_recursive(
458        &self,
459        symbols: &[Symbol],
460        target: SymbolId,
461        grammar: &Grammar,
462        visited: &mut HashSet<SymbolId>,
463    ) -> bool {
464        for symbol in symbols {
465            match symbol {
466                Symbol::NonTerminal(id) if *id == target => return true,
467                Symbol::NonTerminal(id) if !visited.contains(id) => {
468                    visited.insert(*id);
469
470                    // Check all rules for this non-terminal
471                    if let Some(rules) = grammar.get_rules_for_symbol(*id) {
472                        for rule in rules {
473                            if self.contains_symbol_recursive(&rule.rhs, target, grammar, visited) {
474                                return true;
475                            }
476                        }
477                    }
478                }
479                _ => {}
480            }
481        }
482        false
483    }
484
485    /// Extract all rules for a given symbol
486    fn extract_rules_for_symbol(&self, grammar: &Grammar, symbol: SymbolId) -> Option<Vec<Rule>> {
487        grammar.get_rules_for_symbol(symbol).cloned()
488    }
489
490    /// Partition rules into recursive and non-recursive
491    fn partition_recursive_rules(
492        &self,
493        rules: &[Rule],
494        symbol: SymbolId,
495    ) -> (Vec<Rule>, Vec<Rule>) {
496        let mut recursive = Vec::new();
497        let mut non_recursive = Vec::new();
498
499        for rule in rules {
500            if let Some(Symbol::NonTerminal(id)) = rule.rhs.first() {
501                if *id == symbol {
502                    recursive.push(rule.clone());
503                } else {
504                    non_recursive.push(rule.clone());
505                }
506            } else {
507                non_recursive.push(rule.clone());
508            }
509        }
510
511        (recursive, non_recursive)
512    }
513
514    /// Create a new unique symbol ID
515    fn create_new_symbol(&self, grammar: &Grammar) -> SymbolId {
516        let max_id = grammar
517            .rules
518            .keys()
519            .chain(grammar.tokens.keys())
520            .map(|id| id.0)
521            .max()
522            .unwrap_or(0);
523
524        SymbolId(max_id + 1)
525    }
526
527    /// Create a new unique production ID
528    fn create_new_production_id(&self, grammar: &Grammar) -> ProductionId {
529        let max_id = grammar
530            .rules
531            .values()
532            .flat_map(|rules| rules.iter())
533            .map(|r| r.production_id.0)
534            .max()
535            .unwrap_or(0);
536
537        ProductionId(max_id + 1)
538    }
539
    /// Transform a directly left-recursive symbol into right-recursive form.
    ///
    /// Given the productions of `original_symbol` partitioned into
    /// `recursive_rules` (`A -> A α`) and `base_rules` (`A -> β`), rewrites
    /// them as `A -> β A'` and `A' -> α A' | ε`, where `A'` is `new_symbol`.
    /// Conflict declarations and rule names referencing `A` are extended to
    /// cover `A'` as well.
    fn transform_left_recursion(
        &mut self,
        grammar: &mut Grammar,
        original_symbol: SymbolId,
        new_symbol: SymbolId,
        recursive_rules: Vec<Rule>,
        base_rules: Vec<Rule>,
    ) {
        // Remove all original rules for the symbol using the current Grammar APIs
        grammar.rules.shift_remove(&original_symbol);

        // Any conflict declarations referencing the original symbol should also
        // reference the new helper symbol to preserve conflict metadata
        for conflict in &mut grammar.conflicts {
            if conflict.symbols.contains(&original_symbol)
                && !conflict.symbols.contains(&new_symbol)
            {
                conflict.symbols.push(new_symbol);
            }
        }

        // Give the new symbol a readable name if possible (e.g. "expr__rec")
        if let Some(name) = grammar.rule_names.get(&original_symbol).cloned() {
            grammar
                .rule_names
                .insert(new_symbol, format!("{}__rec", name));
        }

        // Add transformed base rules: A -> β A'
        // Precedence/associativity carry over; production IDs are freshly
        // allocated (safe here because each rule is added immediately, so the
        // observed max advances between calls).
        for base_rule in base_rules {
            let mut new_rhs = base_rule.rhs.clone();
            new_rhs.push(Symbol::NonTerminal(new_symbol));

            let new_rule = Rule {
                lhs: original_symbol,
                rhs: new_rhs,
                precedence: base_rule.precedence,
                associativity: base_rule.associativity,
                fields: base_rule.fields,
                production_id: self.create_new_production_id(grammar),
            };

            grammar.add_rule(new_rule);
        }

        // Add recursive rules: A' -> α A' | ε
        for recursive_rule in recursive_rules {
            // Remove the left-recursive symbol (the leading `A`), keep α, and
            // append the helper — turning `A -> A α` into `A' -> α A'`.
            let mut new_rhs: Vec<_> = recursive_rule.rhs[1..].to_vec();
            new_rhs.push(Symbol::NonTerminal(new_symbol));
            // Adjust field positions since we removed the first symbol.
            // A field at index 0 pointed at the removed `A` and is dropped.
            let adjusted_fields = recursive_rule
                .fields
                .iter()
                .filter_map(|(field_id, index)| {
                    if *index > 0 {
                        Some((*field_id, index - 1))
                    } else {
                        None
                    }
                })
                .collect();

            let new_rule = Rule {
                lhs: new_symbol,
                rhs: new_rhs,
                precedence: recursive_rule.precedence,
                associativity: recursive_rule.associativity,
                fields: adjusted_fields,
                production_id: self.create_new_production_id(grammar),
            };

            grammar.add_rule(new_rule);
        }

        // Add epsilon rule: A' -> ε (the empty RHS terminates the repetition)
        let epsilon_rule = Rule {
            lhs: new_symbol,
            rhs: Vec::new(),
            precedence: None,
            associativity: None,
            fields: Vec::new(),
            production_id: self.create_new_production_id(grammar),
        };

        grammar.add_rule(epsilon_rule);
    }
628
629    /// Helper to renumber a symbol recursively
630    #[allow(clippy::only_used_in_recursion)]
631    fn collect_symbol_ids(&self, symbol: &Symbol, ids: &mut HashSet<SymbolId>) {
632        match symbol {
633            Symbol::Terminal(id) | Symbol::NonTerminal(id) | Symbol::External(id) => {
634                ids.insert(*id);
635            }
636            Symbol::Optional(inner) | Symbol::Repeat(inner) | Symbol::RepeatOne(inner) => {
637                self.collect_symbol_ids(inner, ids);
638            }
639            Symbol::Choice(choices) => {
640                for s in choices {
641                    self.collect_symbol_ids(s, ids);
642                }
643            }
644            Symbol::Sequence(seq) => {
645                for s in seq {
646                    self.collect_symbol_ids(s, ids);
647                }
648            }
649            Symbol::Epsilon => {}
650        }
651    }
652
653    #[allow(clippy::only_used_in_recursion)]
654    fn renumber_symbol(&self, symbol: &mut Symbol, old_to_new: &HashMap<SymbolId, SymbolId>) {
655        match symbol {
656            Symbol::Terminal(id) | Symbol::NonTerminal(id) | Symbol::External(id) => {
657                if let Some(&new_id) = old_to_new.get(id) {
658                    *id = new_id;
659                }
660            }
661            Symbol::Optional(inner) | Symbol::Repeat(inner) | Symbol::RepeatOne(inner) => {
662                self.renumber_symbol(inner, old_to_new);
663            }
664            Symbol::Choice(choices) => {
665                for s in choices {
666                    self.renumber_symbol(s, old_to_new);
667                }
668            }
669            Symbol::Sequence(seq) => {
670                for s in seq {
671                    self.renumber_symbol(s, old_to_new);
672                }
673            }
674            Symbol::Epsilon => {}
675        }
676    }
677
678    /// Renumber symbols to be contiguous
679    fn renumber_symbols(&mut self, grammar: &mut Grammar) {
680        let mut old_to_new: HashMap<SymbolId, SymbolId> = HashMap::new();
681        let mut next_id = 1u16; // 0 is reserved for EOF
682
683        // Renumber symbols to be contiguous while preserving parse table ordering
684
685        // Collect all symbols
686        let mut token_symbols: HashSet<SymbolId> = HashSet::new();
687        let mut non_terminal_symbols: HashSet<SymbolId> = HashSet::new();
688        let mut external_symbols: HashSet<SymbolId> = HashSet::new();
689
690        // Categorize symbols
691        token_symbols.extend(grammar.tokens.keys().copied());
692
693        // Add all symbols from rules
694        for (symbol_id, _) in &grammar.rules {
695            if !token_symbols.contains(symbol_id) {
696                non_terminal_symbols.insert(*symbol_id);
697            }
698        }
699
700        // Add all symbols referenced in rule RHS
701        for rules in grammar.rules.values() {
702            for rule in rules {
703                for symbol in &rule.rhs {
704                    match symbol {
705                        Symbol::Terminal(id) => {
706                            token_symbols.insert(*id);
707                        }
708                        Symbol::NonTerminal(id) => {
709                            non_terminal_symbols.insert(*id);
710                        }
711                        Symbol::External(id) => {
712                            external_symbols.insert(*id);
713                        }
714                        _ => {
715                            let mut ids = HashSet::new();
716                            self.collect_symbol_ids(symbol, &mut ids);
717                            for id in ids {
718                                // Determine category based on existing knowledge
719                                if grammar.tokens.contains_key(&id) {
720                                    token_symbols.insert(id);
721                                } else if grammar.externals.iter().any(|e| e.symbol_id == id) {
722                                    external_symbols.insert(id);
723                                } else {
724                                    non_terminal_symbols.insert(id);
725                                }
726                            }
727                        }
728                    }
729                }
730            }
731        }
732
733        // Add external symbols
734        for external in &grammar.externals {
735            external_symbols.insert(external.symbol_id);
736        }
737
738        // Sort each category for deterministic ordering by symbol name
739        let mut token_vec: Vec<_> = token_symbols.into_iter().collect();
740        let mut non_terminal_vec: Vec<_> = non_terminal_symbols.into_iter().collect();
741        let mut external_vec: Vec<_> = external_symbols.into_iter().collect();
742
743        // Sort by symbol name for deterministic ordering
744        token_vec.sort_by_key(|id| {
745            grammar
746                .tokens
747                .get(id)
748                .map(|t| t.name.clone())
749                .unwrap_or_else(|| format!("_token_{}", id.0))
750        });
751        non_terminal_vec.sort_by_key(|id| {
752            // Try to find the symbol name from rule_names
753            grammar
754                .rule_names
755                .get(id)
756                .cloned()
757                .unwrap_or_else(|| format!("_nt_{}", id.0))
758        });
759        external_vec.sort_by_key(|id| {
760            grammar
761                .externals
762                .iter()
763                .find(|e| e.symbol_id == *id)
764                .map(|e| e.name.clone())
765                .unwrap_or_else(|| format!("_ext_{}", id.0))
766        });
767
768        // Assign new IDs preserving parse table ordering: tokens first, then non-terminals, then externals
769        debug_trace!("DEBUG renumber_symbols: Assigning new IDs");
770        debug_trace!("  Tokens: {:?}", token_vec);
771        debug_trace!("  Non-terminals: {:?}", non_terminal_vec);
772        debug_trace!("  Externals: {:?}", external_vec);
773
774        for old_id in token_vec {
775            if let std::collections::hash_map::Entry::Vacant(e) = old_to_new.entry(old_id) {
776                e.insert(SymbolId(next_id));
777                debug_trace!("  Token {:?} -> {:?}", old_id, SymbolId(next_id));
778                next_id += 1;
779            }
780        }
781
782        for old_id in non_terminal_vec {
783            if let std::collections::hash_map::Entry::Vacant(e) = old_to_new.entry(old_id) {
784                e.insert(SymbolId(next_id));
785                debug_trace!("  Non-terminal {:?} -> {:?}", old_id, SymbolId(next_id));
786                next_id += 1;
787            }
788        }
789
790        for old_id in external_vec {
791            if let std::collections::hash_map::Entry::Vacant(e) = old_to_new.entry(old_id) {
792                e.insert(SymbolId(next_id));
793                debug_trace!("  External {:?} -> {:?}", old_id, SymbolId(next_id));
794                next_id += 1;
795            }
796        }
797
798        // Apply renumbering mappings
799
800        // Update tokens
801        let mut new_tokens = IndexMap::new();
802        for (old_id, token) in grammar.tokens.drain(..) {
803            if let Some(&new_id) = old_to_new.get(&old_id) {
804                new_tokens.insert(new_id, token);
805            }
806        }
807        grammar.tokens = new_tokens;
808
809        // Update rules
810        let mut new_rules = IndexMap::new();
811        // Process rules
812
813        for (old_id, mut rules) in grammar.rules.drain(..) {
814            // Process rules for this symbol
815
816            // Update each rule in the vector
817            for rule in &mut rules {
818                // Update LHS
819                if let Some(&new_id) = old_to_new.get(&rule.lhs) {
820                    rule.lhs = new_id;
821                }
822
823                // Update RHS
824                for symbol in &mut rule.rhs {
825                    self.renumber_symbol(symbol, &old_to_new);
826                }
827            }
828
829            // Insert with possibly updated key
830            let new_key = if let Some(&new_id) = old_to_new.get(&old_id) {
831                // Renumber symbol
832                new_id
833            } else {
834                // Keep original ID
835                old_id
836            };
837            new_rules.insert(new_key, rules);
838        }
839
840        // Update grammar rules
841        grammar.rules = new_rules;
842
843        // Update source_file_id if it was renumbered
844        if let Some(sf_id) = self.source_file_id
845            && let Some(&new_id) = old_to_new.get(&sf_id)
846        {
847            // Update source_file_id
848            self.source_file_id = Some(new_id);
849        }
850
851        // Update rule_names
852        let mut new_rule_names = IndexMap::new();
853        for (old_id, name) in grammar.rule_names.drain(..) {
854            if let Some(&new_id) = old_to_new.get(&old_id) {
855                new_rule_names.insert(new_id, name);
856            }
857        }
858        grammar.rule_names = new_rule_names;
859
860        // Update other references
861        grammar.supertypes = grammar
862            .supertypes
863            .iter()
864            .filter_map(|id| old_to_new.get(id).copied())
865            .collect();
866
867        grammar.inline_rules = grammar
868            .inline_rules
869            .iter()
870            .filter_map(|id| old_to_new.get(id).copied())
871            .collect();
872
873        // Update external tokens
874        for external in &mut grammar.externals {
875            if let Some(&new_id) = old_to_new.get(&external.symbol_id) {
876                external.symbol_id = new_id;
877            }
878        }
879
880        // Update extras
881        debug_trace!("DEBUG renumber_symbols: Updating extras");
882        debug_trace!("  Old extras: {:?}", grammar.extras);
883        grammar.extras = grammar
884            .extras
885            .iter()
886            .filter_map(|&old_id| {
887                if let Some(&new_id) = old_to_new.get(&old_id) {
888                    debug_trace!("  Extra {:?} -> {:?}", old_id, new_id);
889                    Some(new_id)
890                } else {
891                    debug_trace!(
892                        "  WARNING: Extra {:?} not found in renumbering map!",
893                        old_id
894                    );
895                    None
896                }
897            })
898            .collect();
899        debug_trace!("  New extras: {:?}", grammar.extras);
900    }
901}
902
/// Statistics about optimizations performed
///
/// Every field is a plain `usize` counter, so the struct is cheap to copy;
/// it derives `Copy`/`Clone` plus `PartialEq`/`Eq` so callers can snapshot
/// and compare stats across optimization runs.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Number of unused symbols removed
    pub removed_unused_symbols: usize,
    /// Number of rules inlined
    pub inlined_rules: usize,
    /// Number of tokens merged
    pub merged_tokens: usize,
    /// Number of left-recursive rules optimized
    pub optimized_left_recursion: usize,
    /// Number of unit rules eliminated
    pub eliminated_unit_rules: usize,
}
917
918/// Convenience function to optimize a grammar
919pub fn optimize_grammar(mut grammar: Grammar) -> anyhow::Result<Grammar> {
920    let mut optimizer = GrammarOptimizer::new();
921    optimizer.optimize(&mut grammar);
922    Ok(grammar)
923}
924
925impl OptimizationStats {
926    /// Get total number of optimizations performed
927    pub fn total(&self) -> usize {
928        self.removed_unused_symbols
929            + self.inlined_rules
930            + self.merged_tokens
931            + self.optimized_left_recursion
932            + self.eliminated_unit_rules
933    }
934}
935
// Unit tests exercising the individual optimizer passes against small,
// hand-built grammars.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Associativity, ConflictDeclaration, ConflictResolution, FieldId, PrecedenceKind};

    /// Build a tiny arithmetic grammar shared by most tests below:
    ///   expr(3) -> expr '+'(1) term(4)   (left-recursive)
    ///   expr(3) -> term(4)
    ///   term(4) -> number(2)
    /// It also registers an intentionally unreferenced token, SymbolId(99),
    /// so the dead-symbol removal pass has something to prune.
    fn create_test_grammar() -> Grammar {
        let mut grammar = Grammar::new("test".to_string());

        // Add some tokens
        grammar.tokens.insert(
            SymbolId(1),
            Token {
                name: "plus".to_string(),
                pattern: TokenPattern::String("+".to_string()),
                fragile: false,
            },
        );

        grammar.tokens.insert(
            SymbolId(2),
            Token {
                name: "number".to_string(),
                pattern: TokenPattern::Regex(r"\d+".to_string()),
                fragile: false,
            },
        );

        // Add an unused token
        grammar.tokens.insert(
            SymbolId(99),
            Token {
                name: "unused".to_string(),
                pattern: TokenPattern::String("unused".to_string()),
                fragile: false,
            },
        );

        // Add rules
        let expr = SymbolId(3);
        let term = SymbolId(4);

        // expr -> expr + term (left recursive)
        grammar.add_rule(Rule {
            lhs: expr,
            rhs: vec![
                Symbol::NonTerminal(expr),
                Symbol::Terminal(SymbolId(1)),
                Symbol::NonTerminal(term),
            ],
            precedence: Some(PrecedenceKind::Static(1)),
            associativity: Some(Associativity::Left),
            fields: vec![],
            production_id: ProductionId(0),
        });

        // expr -> term
        grammar.add_rule(Rule {
            lhs: expr,
            rhs: vec![Symbol::NonTerminal(term)],
            precedence: None,
            associativity: None,
            fields: vec![],
            production_id: ProductionId(1),
        });

        // term -> number
        grammar.add_rule(Rule {
            lhs: term,
            rhs: vec![Symbol::Terminal(SymbolId(2))],
            precedence: None,
            associativity: None,
            fields: vec![],
            production_id: ProductionId(2),
        });

        grammar
    }

    /// The unused-symbol pass must drop the token no rule references
    /// (SymbolId(99)) while keeping the reachable tokens and rules intact.
    #[test]
    fn test_remove_unused_symbols() {
        let mut grammar = create_test_grammar();
        let mut optimizer = GrammarOptimizer::new();

        optimizer.analyze_grammar(&grammar);

        debug_trace!("Used symbols: {:?}", optimizer.used_symbols);
        debug_trace!(
            "Tokens before: {:?}",
            grammar.tokens.keys().collect::<Vec<_>>()
        );
        debug_trace!(
            "Rules: {:?}",
            grammar
                .all_rules()
                .map(|r| (r.lhs, &r.rhs))
                .collect::<Vec<_>>()
        );

        let removed = optimizer.remove_unused_symbols(&mut grammar);

        debug_trace!("Removed: {}", removed);
        debug_trace!(
            "Tokens after: {:?}",
            grammar.tokens.keys().collect::<Vec<_>>()
        );

        // We expect to remove: SymbolId(99) token, and the rule key symbols 5 and 6
        assert!(removed >= 1); // At least the unused token should be removed
        assert!(!grammar.tokens.contains_key(&SymbolId(99)));
    }

    /// Smoke test: eliminating unit rules on a grammar that may contain none
    /// must not panic. No count is asserted because the fixture grammar has
    /// no guaranteed unit rules.
    #[test]
    fn test_eliminate_unit_rules() {
        let mut grammar = create_test_grammar();
        let mut optimizer = GrammarOptimizer::new();

        optimizer.analyze_grammar(&grammar);
        let _eliminated = optimizer.eliminate_unit_rules(&mut grammar);

        // The test grammar may not have unit rules, which is fine
    }

    /// Running the full pipeline over the fixture grammar must report at
    /// least one optimization via the aggregated stats total.
    #[test]
    fn test_optimization_stats() {
        let mut grammar = create_test_grammar();
        let mut optimizer = GrammarOptimizer::new();

        let stats = optimizer.optimize(&mut grammar);

        assert!(stats.total() > 0);
        debug_trace!("Optimization stats: {:?}", stats);
    }

    /// `analyze_grammar` must flag `expr` (SymbolId(3)) as left-recursive,
    /// since its first rule begins with a reference to itself.
    #[test]
    fn test_left_recursion_detection() {
        let grammar = create_test_grammar();
        let mut optimizer = GrammarOptimizer::new();

        optimizer.analyze_grammar(&grammar);

        // The expr rule should be detected as left-recursive
        let expr = SymbolId(3);
        assert!(optimizer.left_recursive_rules.contains(&expr));
    }

    /// A rule referenced exactly once (main -> single_use -> a) should be
    /// simplified by the pipeline; we only assert that *some* optimization
    /// fired, since the exact pass that handles it may vary.
    #[test]
    fn test_inline_single_use_rules() {
        let mut grammar = Grammar::new("test".to_string());

        // Create a rule that's only used once
        let single_use = SymbolId(10);
        let main = SymbolId(11);
        let terminal = SymbolId(12);

        grammar.tokens.insert(
            terminal,
            Token {
                name: "a".to_string(),
                pattern: TokenPattern::String("a".to_string()),
                fragile: false,
            },
        );

        // main -> single_use
        grammar.add_rule(Rule {
            lhs: main,
            rhs: vec![Symbol::NonTerminal(single_use)],
            precedence: None,
            associativity: None,
            fields: vec![],
            production_id: ProductionId(0),
        });

        // single_use -> a
        grammar.add_rule(Rule {
            lhs: single_use,
            rhs: vec![Symbol::Terminal(terminal)],
            precedence: None,
            associativity: None,
            fields: vec![],
            production_id: ProductionId(1),
        });

        let mut optimizer = GrammarOptimizer::new();
        optimizer.analyze_grammar(&grammar);

        // The inline_simple_rules function eliminates unit rules, not general inlining
        // So we test that at least something was optimized
        let stats = optimizer.optimize(&mut grammar);

        // Either unit rules were eliminated or symbols were removed
        assert!(stats.total() > 0);
    }

    /// End-to-end check of the left-recursion rewrite (A -> A + b | b):
    /// the base rule becomes `A -> b A'`, the new symbol A' gets the
    /// recursive tail rule `A' -> + b A'` (carrying over precedence,
    /// associativity, and shifted field indices) plus an epsilon rule, and
    /// conflict declarations mentioning A are extended with A'.
    #[test]
    fn test_transform_left_recursion_rewrites_grammar() {
        let mut grammar = Grammar::new("lr".to_string());

        // Tokens used in the grammar
        grammar.tokens.insert(
            SymbolId(1),
            Token {
                name: "+".to_string(),
                pattern: TokenPattern::String("+".to_string()),
                fragile: false,
            },
        );
        grammar.tokens.insert(
            SymbolId(2),
            Token {
                name: "b".to_string(),
                pattern: TokenPattern::String("b".to_string()),
                fragile: false,
            },
        );

        // Field and rule name for the non-terminal
        grammar.fields.insert(FieldId(0), "b".to_string());
        let a = SymbolId(3);
        grammar.rule_names.insert(a, "A".to_string());

        // Left-recursive rule: A -> A + b
        grammar.add_rule(Rule {
            lhs: a,
            rhs: vec![
                Symbol::NonTerminal(a),
                Symbol::Terminal(SymbolId(1)),
                Symbol::Terminal(SymbolId(2)),
            ],
            precedence: Some(PrecedenceKind::Static(5)),
            associativity: Some(Associativity::Left),
            fields: vec![(FieldId(0), 2)],
            production_id: ProductionId(0),
        });

        // Base rule: A -> b
        grammar.add_rule(Rule {
            lhs: a,
            rhs: vec![Symbol::Terminal(SymbolId(2))],
            precedence: None,
            associativity: None,
            fields: vec![(FieldId(0), 0)],
            production_id: ProductionId(1),
        });

        // Conflict referencing original symbol
        grammar.conflicts.push(ConflictDeclaration {
            symbols: vec![a],
            resolution: ConflictResolution::GLR,
        });

        // Drive the individual transformation steps by hand rather than the
        // full pipeline, so each intermediate API is exercised.
        let mut optimizer = GrammarOptimizer::new();
        optimizer.analyze_grammar(&grammar);
        let rules = optimizer.extract_rules_for_symbol(&grammar, a).unwrap();
        let (recursive, base) = optimizer.partition_recursive_rules(&rules, a);
        let new_symbol = optimizer.create_new_symbol(&grammar);
        optimizer.transform_left_recursion(&mut grammar, a, new_symbol, recursive, base);

        // Verify base rule was rewritten
        // Tokens are looked up by name because the transform may renumber IDs.
        let b_id = grammar
            .tokens
            .iter()
            .find(|(_, t)| t.name == "b")
            .map(|(id, _)| *id)
            .unwrap();
        let a_rules = grammar.get_rules_for_symbol(a).unwrap();
        assert_eq!(a_rules.len(), 1);
        assert_eq!(
            a_rules[0].rhs,
            vec![Symbol::Terminal(b_id), Symbol::NonTerminal(new_symbol)]
        );
        assert_eq!(a_rules[0].fields, vec![(FieldId(0), 0)]);

        // Verify new symbol rules
        let plus_id = grammar
            .tokens
            .iter()
            .find(|(_, t)| t.name == "+")
            .map(|(id, _)| *id)
            .unwrap();
        let new_rules = grammar.get_rules_for_symbol(new_symbol).unwrap();
        assert_eq!(new_rules.len(), 2);
        let recursive_rule = new_rules.iter().find(|r| !r.rhs.is_empty()).unwrap();
        assert_eq!(
            recursive_rule.rhs,
            vec![
                Symbol::Terminal(plus_id),
                Symbol::Terminal(b_id),
                Symbol::NonTerminal(new_symbol),
            ]
        );
        assert_eq!(recursive_rule.fields, vec![(FieldId(0), 1)]);
        assert_eq!(recursive_rule.precedence, Some(PrecedenceKind::Static(5)));
        assert_eq!(recursive_rule.associativity, Some(Associativity::Left));

        // Ensure epsilon rule exists
        assert!(new_rules.iter().any(|r| r.rhs.is_empty()));

        // Conflicts should include new symbol
        assert!(grammar.conflicts[0].symbols.contains(&a));
        assert!(grammar.conflicts[0].symbols.contains(&new_symbol));
    }
}