1#![allow(missing_docs)]
8
9use std::{
10 collections::{HashMap, HashSet},
11 hash::{DefaultHasher, Hash, Hasher},
12};
13
14use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction, Graph};
15
16use crate::mir::{MirEdgeKind, MirGraph, MirNodeKind};
17
/// Opaque 64-bit fingerprint of a MIR graph's canonical form.
///
/// Produced by [`canonical_key`]; within a single build, graphs whose
/// canonical forms are equal always yield equal keys.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct CanonicalKey(u64);

impl CanonicalKey {
    /// Returns the raw hash value wrapped by this key.
    #[must_use]
    pub const fn get(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}
27
28#[must_use]
29pub fn canonical_key(graph: &MirGraph) -> CanonicalKey {
30 let mut hasher = DefaultHasher::new();
31 canonical_form(graph).hash(&mut hasher);
32 CanonicalKey(hasher.finish())
33}
34
35#[must_use]
36pub fn canonical_form(graph: &MirGraph) -> String {
37 let mut stack = HashSet::new();
38 canonical_node(graph, graph.root(), &mut stack)
39}
40
41#[must_use]
42pub fn collapse_reused_subgraphs(graph: &MirGraph) -> MirGraph {
43 let mut rebuilt = Graph::new();
44 let mut interned = HashMap::new();
45 let mut stack = HashSet::new();
46 let (root, _) = rebuild_node(graph, graph.root(), &mut rebuilt, &mut interned, &mut stack);
47 MirGraph::from_graph(rebuilt, root)
48}
49
50fn rebuild_node(
51 source: &MirGraph,
52 node_index: NodeIndex,
53 rebuilt: &mut Graph<MirNodeKind, MirEdgeKind>,
54 interned: &mut HashMap<String, NodeIndex>,
55 stack: &mut HashSet<NodeIndex>,
56) -> (NodeIndex, String) {
57 assert!(stack.insert(node_index), "MIR graph contains a cycle");
58
59 let mut inputs = source
60 .graph()
61 .edges_directed(node_index, Direction::Incoming)
62 .map(|edge| {
63 let (child, child_signature) =
64 rebuild_node(source, edge.source(), rebuilt, interned, stack);
65 (*edge.weight(), child, child_signature)
66 })
67 .collect::<Vec<_>>();
68
69 if matches!(
70 source.graph()[node_index],
71 MirNodeKind::Union { .. } | MirNodeKind::Intersect { .. }
72 ) {
73 inputs.sort_by(|left, right| left.2.cmp(&right.2));
74 }
75
76 let input_signature = inputs
77 .iter()
78 .map(|(edge, _, child)| format!("{}:{child}", edge_kind_name(*edge)))
79 .collect::<Vec<_>>()
80 .join(",");
81 let signature = format!(
82 "{}[{input_signature}]",
83 canonical_node_kind(&source.graph()[node_index])
84 );
85
86 if let Some(existing) = interned.get(&signature) {
87 stack.remove(&node_index);
88 return (*existing, signature);
89 }
90
91 let rebuilt_node = rebuilt.add_node(source.graph()[node_index].clone());
92 for (edge, child, _) in inputs {
93 rebuilt.add_edge(child, rebuilt_node, edge);
94 }
95 interned.insert(signature.clone(), rebuilt_node);
96 stack.remove(&node_index);
97 (rebuilt_node, signature)
98}
99
100fn canonical_node(
101 graph: &MirGraph,
102 node_index: NodeIndex,
103 stack: &mut HashSet<NodeIndex>,
104) -> String {
105 assert!(stack.insert(node_index), "MIR graph contains a cycle");
106
107 let mut inputs = graph
108 .graph()
109 .edges_directed(node_index, Direction::Incoming)
110 .map(|edge| {
111 let edge_kind = edge_kind_name(*edge.weight());
112 format!(
113 "{edge_kind}:{}",
114 canonical_node(graph, edge.source(), stack)
115 )
116 })
117 .collect::<Vec<_>>();
118
119 if matches!(
120 graph.graph()[node_index],
121 MirNodeKind::Union { .. } | MirNodeKind::Intersect { .. }
122 ) {
123 inputs.sort();
124 }
125
126 let node = canonical_node_kind(&graph.graph()[node_index]);
127 stack.remove(&node_index);
128 format!("{node}[{}]", inputs.join(","))
129}
130
131const fn edge_kind_name(edge: MirEdgeKind) -> &'static str {
132 match edge {
133 MirEdgeKind::Input => "input",
134 MirEdgeKind::CteExpansion => "cte",
135 }
136}
137
138fn canonical_node_kind(node: &MirNodeKind) -> String {
139 match node {
140 MirNodeKind::BaseTable { table, project } => {
141 format!("base:{table}:{}", canonical_debug(project))
142 }
143 MirNodeKind::Filter { predicate } => format!("filter:{predicate}"),
144 MirNodeKind::Project { columns } => format!("project:{}", columns.join(",")),
145 MirNodeKind::Join { kind, on } => {
146 format!("join:{kind:?}:{}", canonical_debug(on))
147 }
148 MirNodeKind::Aggregate { group_by, aggs } => {
149 format!(
150 "aggregate:{}:{}",
151 canonical_debug(group_by),
152 canonical_debug(aggs)
153 )
154 }
155 MirNodeKind::Distinct => "distinct".to_owned(),
156 MirNodeKind::Union { quantifier } => format!("union:{quantifier:?}"),
157 MirNodeKind::Except { quantifier } => format!("except:{quantifier:?}"),
158 MirNodeKind::Intersect { quantifier } => format!("intersect:{quantifier:?}"),
159 MirNodeKind::TopK {
160 order_by,
161 limit,
162 offset,
163 } => format!("topk:{}:{limit}:{offset}", canonical_debug(order_by)),
164 MirNodeKind::CteRef { .. } => "cte-ref".to_owned(),
165 MirNodeKind::Leaf { name } => format!("leaf:{name}"),
166 }
167}
168
/// Renders any `Debug` value using its single-line `{:?}` representation.
fn canonical_debug<T>(item: &T) -> String
where
    T: core::fmt::Debug,
{
    format!("{item:?}")
}
172
#[cfg(test)]
mod tests {
    use crate::{
        canonical::{canonical_form, canonical_key, collapse_reused_subgraphs},
        lower::parse_and_lower,
    };

    /// Reordering the conjuncts of a `WHERE` clause must not change the key.
    #[test]
    fn equivalent_filter_conjunctions_have_same_key() {
        let left = parse_and_lower(
            "SELECT id FROM posts
             WHERE author_id = 42 AND id = 7",
        )
        .expect("query should lower");
        let right = parse_and_lower(
            "SELECT id FROM posts
             WHERE id = 7 AND author_id = 42",
        )
        .expect("query should lower");

        assert_eq!(canonical_form(&left), canonical_form(&right));
        assert_eq!(canonical_key(&left), canonical_key(&right));
    }

    /// Queries that differ only in a literal value must get different keys.
    #[test]
    fn different_queries_have_different_keys() {
        let left = parse_and_lower("SELECT id FROM posts WHERE author_id = 42")
            .expect("query should lower");
        let right = parse_and_lower("SELECT id FROM posts WHERE author_id = 43")
            .expect("query should lower");

        assert_ne!(canonical_key(&left), canonical_key(&right));
    }

    /// Literals that spell the same value (leading zeros, `E'...'` strings)
    /// must canonicalize identically.
    #[test]
    fn normalized_literals_have_same_key() {
        let left = parse_and_lower("SELECT id FROM posts WHERE author_id = 00042")
            .expect("query should lower");
        let right = parse_and_lower("SELECT id FROM posts WHERE author_id = 42")
            .expect("query should lower");
        let escaped = parse_and_lower("SELECT id FROM posts WHERE title = E'hello'")
            .expect("query should lower");
        let quoted = parse_and_lower("SELECT id FROM posts WHERE title = 'hello'")
            .expect("query should lower");

        assert_eq!(canonical_form(&left), canonical_form(&right));
        assert_eq!(canonical_key(&escaped), canonical_key(&quoted));
    }

    /// Renaming a CTE must not change the key of the query that uses it.
    #[test]
    fn cte_names_do_not_affect_canonical_key() {
        let left = parse_and_lower(
            "WITH recent_posts AS (
                 SELECT id FROM posts WHERE author_id = 42
             )
             SELECT id FROM recent_posts",
        )
        .expect("query should lower");
        let right = parse_and_lower(
            "WITH visible_posts AS (
                 SELECT id FROM posts WHERE author_id = 42
             )
             SELECT id FROM visible_posts",
        )
        .expect("query should lower");

        assert_eq!(canonical_key(&left), canonical_key(&right));
    }

    /// Identical branches of a union collapse into shared nodes, and
    /// collapsing must not change the canonical key.
    #[test]
    fn collapse_reused_subgraphs_shares_duplicate_branches() {
        let graph = parse_and_lower(
            "SELECT id FROM posts
             UNION ALL
             SELECT id FROM posts",
        )
        .expect("query should lower");

        let collapsed = collapse_reused_subgraphs(&graph);

        assert!(collapsed.node_count() < graph.node_count());
        assert_eq!(canonical_key(&graph), canonical_key(&collapsed));
    }
}