kotoba_rewrite/rewrite/
matcher.rs1use kotoba_core::{ir::*, types::*};
4use kotoba_graph::graph::*;
5use std::collections::HashMap;
6use kotoba_core::types::Result;
7
/// Matches the left-hand side of a rewrite rule against a graph.
///
/// The matcher is a unit struct and holds no state, so it is free to create and copy.
#[derive(Debug, Clone, Copy, Default)]
pub struct RuleMatcher;
11
12impl RuleMatcher {
13 pub fn new() -> Self {
14 Self
15 }
16
17 pub fn find_matches(&self, graph: &GraphRef, rule: &RuleIR, catalog: &Catalog) -> Result<Vec<Match>> {
19 let graph = graph.read();
20
21 let mut matches = Vec::new();
23
24 let initial_candidates = self.generate_initial_candidates(&graph, &rule.lhs);
26
27 for candidate in initial_candidates {
28 if self.match_lhs(&graph, &rule.lhs, &candidate, catalog) {
29 if self.check_nacs(&graph, &rule.nacs, &candidate, catalog) {
31 if self.check_guards(&graph, &rule.guards, &candidate, catalog) {
33 let match_score = self.calculate_match_score(&candidate);
34 matches.push(Match {
35 mapping: candidate,
36 score: match_score,
37 });
38 }
39 }
40 }
41 }
42
43 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
45
46 Ok(matches)
47 }
48
49 fn generate_initial_candidates(&self, graph: &Graph, lhs: &GraphPattern)
51 -> Vec<HashMap<String, VertexId>> {
52
53 if lhs.nodes.is_empty() {
54 return vec![HashMap::new()];
55 }
56
57 let mut candidates = Vec::new();
58
59 let first_node = &lhs.nodes[0];
61 let vertex_ids = if let Some(label) = &first_node.type_ {
62 graph.vertices_by_label(label)
63 } else {
64 graph.vertices.keys().cloned().collect()
65 };
66
67 for vertex_id in vertex_ids {
68 let mut mapping = HashMap::new();
69 mapping.insert(first_node.id.clone(), vertex_id);
70 candidates.push(mapping);
71 }
72
73 candidates
74 }
75
76 fn match_lhs(&self, graph: &Graph, lhs: &GraphPattern,
78 mapping: &HashMap<String, VertexId>, catalog: &Catalog) -> bool {
79
80 for node in &lhs.nodes {
82 if let Some(&vertex_id) = mapping.get(&node.id) {
83 if let Some(vertex) = graph.get_vertex(&vertex_id) {
84 if let Some(expected_label) = &node.type_ {
86 if !vertex.labels.contains(expected_label) {
87 return false;
88 }
89 }
90
91 if let Some(expected_props) = &node.props {
93 for (key, expected_value) in expected_props {
94 if let Some(actual_value) = vertex.props.get(key) {
95 if !self.values_match(actual_value, expected_value) {
96 return false;
97 }
98 } else {
99 return false;
100 }
101 }
102 }
103 } else {
104 return false;
105 }
106 }
107 }
108
109 for edge in &lhs.edges {
111 if let (Some(&src_id), Some(&dst_id)) = (mapping.get(&edge.src), mapping.get(&edge.dst)) {
112 if !graph.adj_out.get(&src_id)
114 .map(|neighbors| neighbors.contains(&dst_id))
115 .unwrap_or(false) {
116 return false;
117 }
118
119 }
122 }
123
124 true
125 }
126
127 fn check_nacs(&self, graph: &Graph, nacs: &[Nac],
129 mapping: &HashMap<String, VertexId>, catalog: &Catalog) -> bool {
130
131 for nac in nacs {
132 if self.match_nac(graph, nac, mapping, catalog) {
134 return false;
135 }
136 }
137
138 true
139 }
140
141 fn match_nac(&self, graph: &Graph, nac: &Nac,
143 mapping: &HashMap<String, VertexId>, _catalog: &Catalog) -> bool {
144
145 for node in &nac.nodes {
147 if let Some(&vertex_id) = mapping.get(&node.id) {
148 if let Some(vertex) = graph.get_vertex(&vertex_id) {
149 if let Some(expected_label) = &node.type_ {
151 if vertex.labels.contains(expected_label) {
152 return true;
153 }
154 }
155 }
156 }
157 }
158
159 for edge in &nac.edges {
161 if let (Some(&src_id), Some(&dst_id)) = (mapping.get(&edge.src), mapping.get(&edge.dst)) {
162 if graph.adj_out.get(&src_id)
163 .map(|neighbors| neighbors.contains(&dst_id))
164 .unwrap_or(false) {
165 return true;
166 }
167 }
168 }
169
170 false
171 }
172
173 fn check_guards(&self, graph: &Graph, guards: &[Guard],
175 mapping: &HashMap<String, VertexId>, _catalog: &Catalog) -> bool {
176
177 for guard in guards {
178 if !self.evaluate_guard(graph, guard, mapping, _catalog) {
179 return false;
180 }
181 }
182
183 true
184 }
185
186 fn evaluate_guard(&self, graph: &Graph, guard: &Guard,
188 mapping: &HashMap<String, VertexId>, _catalog: &Catalog) -> bool {
189
190 match guard.ref_.as_str() {
191 "deg_ge" => {
192 if let Some(Value::Int(k)) = guard.args.get("k") {
194 if let Some(Value::String(var)) = guard.args.get("var") {
195 if let Some(&vertex_id) = mapping.get(var) {
196 let degree = graph.degree(&vertex_id);
197 return degree >= *k as usize;
198 }
199 }
200 }
201 false
202 }
203 _ => {
204 true
206 }
207 }
208 }
209
210 fn values_match(&self, actual: &Value, expected: &Value) -> bool {
212 match (actual, expected) {
213 (Value::Null, Value::Null) => true,
214 (Value::Bool(a), Value::Bool(b)) => a == b,
215 (Value::Int(a), Value::Int(b)) => a == b,
216 (Value::String(a), Value::String(b)) => a == b,
217 _ => false,
218 }
219 }
220
221 fn calculate_match_score(&self, mapping: &HashMap<String, VertexId>) -> f64 {
223 mapping.len() as f64
225 }
226}