xee_ir/
declaration_compiler.rs

1use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
2use rust_decimal::Decimal;
3use xee_interpreter::pattern::ModeId;
4use xee_xpath_ast::Pattern;
5
6use crate::function_compiler::Scopes;
7use crate::{ir, FunctionBuilder, FunctionCompiler};
8
9use xee_interpreter::{error, function, interpreter};
10use xee_xpath_ast::pattern::transform_pattern;
11
12#[derive(Debug, Clone)]
13pub(crate) struct RuleBuilder {
14    priority: Decimal,
15    declaration_order: i64,
16    pattern: Pattern<function::InlineFunctionId>,
17    function_id: function::InlineFunctionId,
18}
19
20impl RuleBuilder {
21    fn rule(
22        self,
23    ) -> (
24        Pattern<function::InlineFunctionId>,
25        function::InlineFunctionId,
26    ) {
27        (self.pattern, self.function_id)
28    }
29}
30
31pub type ModeIds = HashMap<ir::ApplyTemplatesModeValue, ModeId>;
32
33pub struct DeclarationCompiler<'a> {
34    program: &'a mut interpreter::Program,
35    scopes: Scopes,
36    rule_declaration_order: i64,
37    rule_builders: HashMap<ir::ModeValue, Vec<RuleBuilder>>,
38    mode_ids: ModeIds,
39}
40
41impl<'a> DeclarationCompiler<'a> {
42    pub fn new(program: &'a mut interpreter::Program) -> Self {
43        Self {
44            program,
45            scopes: Scopes::new(),
46            rule_declaration_order: 0,
47            rule_builders: HashMap::new(),
48            mode_ids: HashMap::new(),
49        }
50    }
51
52    fn function_compiler(&mut self) -> FunctionCompiler<'_> {
53        let function_builder = FunctionBuilder::new(self.program);
54        FunctionCompiler::new(function_builder, &mut self.scopes, &self.mode_ids)
55    }
56
57    pub fn compile_declarations(
58        &mut self,
59        declarations: &ir::Declarations,
60    ) -> error::SpannedResult<()> {
61        // first keep track of what modes exist, to create a ModeId for them. We do
62        // this early so any mode reference within apply-templates will resolve.
63        self.compile_modes(declarations);
64
65        for rule in &declarations.rules {
66            self.compile_rule(rule)?;
67        }
68        // now add compiled rules from builder to the program
69        self.add_rules();
70        let mut function_compiler = self.function_compiler();
71        function_compiler.compile_function_definition(&declarations.main, (0..0).into())
72    }
73
74    fn compile_modes(&mut self, declarations: &ir::Declarations) {
75        for rule in &declarations.rules {
76            for mode_value in &rule.modes {
77                // we don't register All modes
78                if matches!(mode_value, ir::ModeValue::All) {
79                    continue;
80                }
81                let apply_templates_mode_value = match mode_value {
82                    ir::ModeValue::All => continue,
83                    ir::ModeValue::Named(name) => ir::ApplyTemplatesModeValue::Named(name.clone()),
84                    ir::ModeValue::Unnamed => ir::ApplyTemplatesModeValue::Unnamed,
85                };
86                // we want the mode id to be unique and not overwritten
87                if self.mode_ids.contains_key(&apply_templates_mode_value) {
88                    continue;
89                }
90                let mode_id = ModeId::new(self.mode_ids.len());
91                self.mode_ids.insert(apply_templates_mode_value, mode_id);
92            }
93        }
94    }
95
96    fn compile_rule(&mut self, rule: &ir::Rule) -> error::SpannedResult<()> {
97        let mut function_compiler = self.function_compiler();
98        let function_id =
99            function_compiler.compile_function_id(&rule.function_definition, (0..0).into())?;
100
101        let pattern = transform_pattern(&rule.pattern, |function_definition| {
102            function_compiler.compile_function_id(function_definition, (0..0).into())
103        })?;
104
105        self.add_rule(&rule.modes, rule.priority, &pattern, function_id);
106        Ok(())
107    }
108
109    fn add_rule(
110        &mut self,
111        modes: &[ir::ModeValue],
112        priority: Decimal,
113        pattern: &Pattern<function::InlineFunctionId>,
114        function_id: function::InlineFunctionId,
115    ) {
116        // ensure there are no duplicate modes
117        let mut mode_seen = HashSet::new();
118
119        let declaration_order = self.rule_declaration_order;
120        self.rule_declaration_order += 1;
121        for mode in modes {
122            if mode_seen.contains(mode) {
123                continue;
124            }
125            mode_seen.insert(mode);
126            self.rule_builders
127                .entry(mode.clone())
128                .or_default()
129                .push(RuleBuilder {
130                    priority,
131                    declaration_order,
132                    pattern: pattern.clone(),
133                    function_id,
134                });
135        }
136    }
137
138    fn add_rules(&mut self) {
139        // we don't want to register #all normally
140        let all_rule_builders = self.rule_builders.remove(&ir::ModeValue::All);
141
142        // we add the all rule builders to each rule builders, as they apply to
143        // all modes. We do this before the final registration so we benefit
144        // from priority sorting later
145        if let Some(all_rule_builders) = all_rule_builders {
146            for rule_builders in self.rule_builders.values_mut() {
147                for all_rule_builder in &all_rule_builders {
148                    rule_builders.push(all_rule_builder.clone());
149                }
150            }
151        }
152
153        for (mode, mut rule_builders) in self.rule_builders.drain() {
154            // higher priorities first, same priorities last declaration order wins
155            rule_builders.sort_by_key(|rule_builder| {
156                (-rule_builder.priority, -rule_builder.declaration_order)
157            });
158            let rules = rule_builders
159                .drain(..)
160                .map(|rule_builder| rule_builder.rule())
161                .collect();
162            let apply_templates_mode_value = match mode {
163                ir::ModeValue::Named(name) => ir::ApplyTemplatesModeValue::Named(name),
164                ir::ModeValue::Unnamed => ir::ApplyTemplatesModeValue::Unnamed,
165                ir::ModeValue::All => {
166                    unreachable!()
167                }
168            };
169            let mode_id = self
170                .mode_ids
171                .get(&apply_templates_mode_value)
172                .cloned()
173                .expect("Mode should have been registered");
174            self.program
175                .declarations
176                .mode_lookup
177                .add_rules(mode_id, rules)
178        }
179    }
180}