xee_ir/
declaration_compiler.rs1use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
2use rust_decimal::Decimal;
3use xee_interpreter::pattern::ModeId;
4use xee_xpath_ast::Pattern;
5
6use crate::function_compiler::Scopes;
7use crate::{ir, FunctionBuilder, FunctionCompiler};
8
9use xee_interpreter::{error, function, interpreter};
10use xee_xpath_ast::pattern::transform_pattern;
11
12#[derive(Debug, Clone)]
13pub(crate) struct RuleBuilder {
14 priority: Decimal,
15 declaration_order: i64,
16 pattern: Pattern<function::InlineFunctionId>,
17 function_id: function::InlineFunctionId,
18}
19
20impl RuleBuilder {
21 fn rule(
22 self,
23 ) -> (
24 Pattern<function::InlineFunctionId>,
25 function::InlineFunctionId,
26 ) {
27 (self.pattern, self.function_id)
28 }
29}
30
31pub type ModeIds = HashMap<ir::ApplyTemplatesModeValue, ModeId>;
32
33pub struct DeclarationCompiler<'a> {
34 program: &'a mut interpreter::Program,
35 scopes: Scopes,
36 rule_declaration_order: i64,
37 rule_builders: HashMap<ir::ModeValue, Vec<RuleBuilder>>,
38 mode_ids: ModeIds,
39}
40
41impl<'a> DeclarationCompiler<'a> {
42 pub fn new(program: &'a mut interpreter::Program) -> Self {
43 Self {
44 program,
45 scopes: Scopes::new(),
46 rule_declaration_order: 0,
47 rule_builders: HashMap::new(),
48 mode_ids: HashMap::new(),
49 }
50 }
51
52 fn function_compiler(&mut self) -> FunctionCompiler<'_> {
53 let function_builder = FunctionBuilder::new(self.program);
54 FunctionCompiler::new(function_builder, &mut self.scopes, &self.mode_ids)
55 }
56
57 pub fn compile_declarations(
58 &mut self,
59 declarations: &ir::Declarations,
60 ) -> error::SpannedResult<()> {
61 self.compile_modes(declarations);
64
65 for rule in &declarations.rules {
66 self.compile_rule(rule)?;
67 }
68 self.add_rules();
70 let mut function_compiler = self.function_compiler();
71 function_compiler.compile_function_definition(&declarations.main, (0..0).into())
72 }
73
74 fn compile_modes(&mut self, declarations: &ir::Declarations) {
75 for rule in &declarations.rules {
76 for mode_value in &rule.modes {
77 if matches!(mode_value, ir::ModeValue::All) {
79 continue;
80 }
81 let apply_templates_mode_value = match mode_value {
82 ir::ModeValue::All => continue,
83 ir::ModeValue::Named(name) => ir::ApplyTemplatesModeValue::Named(name.clone()),
84 ir::ModeValue::Unnamed => ir::ApplyTemplatesModeValue::Unnamed,
85 };
86 if self.mode_ids.contains_key(&apply_templates_mode_value) {
88 continue;
89 }
90 let mode_id = ModeId::new(self.mode_ids.len());
91 self.mode_ids.insert(apply_templates_mode_value, mode_id);
92 }
93 }
94 }
95
96 fn compile_rule(&mut self, rule: &ir::Rule) -> error::SpannedResult<()> {
97 let mut function_compiler = self.function_compiler();
98 let function_id =
99 function_compiler.compile_function_id(&rule.function_definition, (0..0).into())?;
100
101 let pattern = transform_pattern(&rule.pattern, |function_definition| {
102 function_compiler.compile_function_id(function_definition, (0..0).into())
103 })?;
104
105 self.add_rule(&rule.modes, rule.priority, &pattern, function_id);
106 Ok(())
107 }
108
109 fn add_rule(
110 &mut self,
111 modes: &[ir::ModeValue],
112 priority: Decimal,
113 pattern: &Pattern<function::InlineFunctionId>,
114 function_id: function::InlineFunctionId,
115 ) {
116 let mut mode_seen = HashSet::new();
118
119 let declaration_order = self.rule_declaration_order;
120 self.rule_declaration_order += 1;
121 for mode in modes {
122 if mode_seen.contains(mode) {
123 continue;
124 }
125 mode_seen.insert(mode);
126 self.rule_builders
127 .entry(mode.clone())
128 .or_default()
129 .push(RuleBuilder {
130 priority,
131 declaration_order,
132 pattern: pattern.clone(),
133 function_id,
134 });
135 }
136 }
137
138 fn add_rules(&mut self) {
139 let all_rule_builders = self.rule_builders.remove(&ir::ModeValue::All);
141
142 if let Some(all_rule_builders) = all_rule_builders {
146 for rule_builders in self.rule_builders.values_mut() {
147 for all_rule_builder in &all_rule_builders {
148 rule_builders.push(all_rule_builder.clone());
149 }
150 }
151 }
152
153 for (mode, mut rule_builders) in self.rule_builders.drain() {
154 rule_builders.sort_by_key(|rule_builder| {
156 (-rule_builder.priority, -rule_builder.declaration_order)
157 });
158 let rules = rule_builders
159 .drain(..)
160 .map(|rule_builder| rule_builder.rule())
161 .collect();
162 let apply_templates_mode_value = match mode {
163 ir::ModeValue::Named(name) => ir::ApplyTemplatesModeValue::Named(name),
164 ir::ModeValue::Unnamed => ir::ApplyTemplatesModeValue::Unnamed,
165 ir::ModeValue::All => {
166 unreachable!()
167 }
168 };
169 let mode_id = self
170 .mode_ids
171 .get(&apply_templates_mode_value)
172 .cloned()
173 .expect("Mode should have been registered");
174 self.program
175 .declarations
176 .mode_lookup
177 .add_rules(mode_id, rules)
178 }
179 }
180}