Skip to main content

palimpsest_sql/
canonical.rs

1// Copyright 2026 Thousand Birds Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Canonical-form fingerprint and reused-subgraph collapsing. Internal
5//! to MIR processing — exposed for callers that compare/dedupe graphs.
6
7#![allow(missing_docs)]
8
9use std::{
10    collections::{HashMap, HashSet},
11    hash::{DefaultHasher, Hash, Hasher},
12};
13
14use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction, Graph};
15
16use crate::mir::{MirEdgeKind, MirGraph, MirNodeKind};
17
/// Opaque 64-bit fingerprint of a MIR graph's canonical form.
///
/// Produced by `canonical_key`: graphs with equal canonical forms always map
/// to equal keys, while distinct forms may (rarely) collide because the key
/// is only 64 bits wide.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CanonicalKey(u64);

impl CanonicalKey {
    /// Returns the raw 64-bit hash value wrapped by this key.
    #[must_use]
    pub const fn get(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}
27
28#[must_use]
29pub fn canonical_key(graph: &MirGraph) -> CanonicalKey {
30    let mut hasher = DefaultHasher::new();
31    canonical_form(graph).hash(&mut hasher);
32    CanonicalKey(hasher.finish())
33}
34
35#[must_use]
36pub fn canonical_form(graph: &MirGraph) -> String {
37    let mut stack = HashSet::new();
38    canonical_node(graph, graph.root(), &mut stack)
39}
40
41#[must_use]
42pub fn collapse_reused_subgraphs(graph: &MirGraph) -> MirGraph {
43    let mut rebuilt = Graph::new();
44    let mut interned = HashMap::new();
45    let mut stack = HashSet::new();
46    let (root, _) = rebuild_node(graph, graph.root(), &mut rebuilt, &mut interned, &mut stack);
47    MirGraph::from_graph(rebuilt, root)
48}
49
50fn rebuild_node(
51    source: &MirGraph,
52    node_index: NodeIndex,
53    rebuilt: &mut Graph<MirNodeKind, MirEdgeKind>,
54    interned: &mut HashMap<String, NodeIndex>,
55    stack: &mut HashSet<NodeIndex>,
56) -> (NodeIndex, String) {
57    assert!(stack.insert(node_index), "MIR graph contains a cycle");
58
59    let mut inputs = source
60        .graph()
61        .edges_directed(node_index, Direction::Incoming)
62        .map(|edge| {
63            let (child, child_signature) =
64                rebuild_node(source, edge.source(), rebuilt, interned, stack);
65            (*edge.weight(), child, child_signature)
66        })
67        .collect::<Vec<_>>();
68
69    if matches!(
70        source.graph()[node_index],
71        MirNodeKind::Union { .. } | MirNodeKind::Intersect { .. }
72    ) {
73        inputs.sort_by(|left, right| left.2.cmp(&right.2));
74    }
75
76    let input_signature = inputs
77        .iter()
78        .map(|(edge, _, child)| format!("{}:{child}", edge_kind_name(*edge)))
79        .collect::<Vec<_>>()
80        .join(",");
81    let signature = format!(
82        "{}[{input_signature}]",
83        canonical_node_kind(&source.graph()[node_index])
84    );
85
86    if let Some(existing) = interned.get(&signature) {
87        stack.remove(&node_index);
88        return (*existing, signature);
89    }
90
91    let rebuilt_node = rebuilt.add_node(source.graph()[node_index].clone());
92    for (edge, child, _) in inputs {
93        rebuilt.add_edge(child, rebuilt_node, edge);
94    }
95    interned.insert(signature.clone(), rebuilt_node);
96    stack.remove(&node_index);
97    (rebuilt_node, signature)
98}
99
100fn canonical_node(
101    graph: &MirGraph,
102    node_index: NodeIndex,
103    stack: &mut HashSet<NodeIndex>,
104) -> String {
105    assert!(stack.insert(node_index), "MIR graph contains a cycle");
106
107    let mut inputs = graph
108        .graph()
109        .edges_directed(node_index, Direction::Incoming)
110        .map(|edge| {
111            let edge_kind = edge_kind_name(*edge.weight());
112            format!(
113                "{edge_kind}:{}",
114                canonical_node(graph, edge.source(), stack)
115            )
116        })
117        .collect::<Vec<_>>();
118
119    if matches!(
120        graph.graph()[node_index],
121        MirNodeKind::Union { .. } | MirNodeKind::Intersect { .. }
122    ) {
123        inputs.sort();
124    }
125
126    let node = canonical_node_kind(&graph.graph()[node_index]);
127    stack.remove(&node_index);
128    format!("{node}[{}]", inputs.join(","))
129}
130
131const fn edge_kind_name(edge: MirEdgeKind) -> &'static str {
132    match edge {
133        MirEdgeKind::Input => "input",
134        MirEdgeKind::CteExpansion => "cte",
135    }
136}
137
138fn canonical_node_kind(node: &MirNodeKind) -> String {
139    match node {
140        MirNodeKind::BaseTable { table, project } => {
141            format!("base:{table}:{}", canonical_debug(project))
142        }
143        MirNodeKind::Filter { predicate } => format!("filter:{predicate}"),
144        MirNodeKind::Project { columns } => format!("project:{}", columns.join(",")),
145        MirNodeKind::Join { kind, on } => {
146            format!("join:{kind:?}:{}", canonical_debug(on))
147        }
148        MirNodeKind::Aggregate { group_by, aggs } => {
149            format!(
150                "aggregate:{}:{}",
151                canonical_debug(group_by),
152                canonical_debug(aggs)
153            )
154        }
155        MirNodeKind::Distinct => "distinct".to_owned(),
156        MirNodeKind::Union { quantifier } => format!("union:{quantifier:?}"),
157        MirNodeKind::Except { quantifier } => format!("except:{quantifier:?}"),
158        MirNodeKind::Intersect { quantifier } => format!("intersect:{quantifier:?}"),
159        MirNodeKind::TopK {
160            order_by,
161            limit,
162            offset,
163        } => format!("topk:{}:{limit}:{offset}", canonical_debug(order_by)),
164        MirNodeKind::CteRef { .. } => "cte-ref".to_owned(),
165        MirNodeKind::Leaf { name } => format!("leaf:{name}"),
166    }
167}
168
/// Formats `value` using its `Debug` representation.
fn canonical_debug<T>(value: &T) -> String
where
    T: core::fmt::Debug,
{
    format!("{:?}", value)
}
172
// Unit tests for canonical forms/keys and subgraph collapsing. Each test lowers
// SQL with `parse_and_lower` and compares the resulting graphs.
#[cfg(test)]
mod tests {
    use crate::{
        canonical::{canonical_form, canonical_key, collapse_reused_subgraphs},
        lower::parse_and_lower,
    };

    // Swapping the operands of an `AND` in the WHERE clause must not change
    // the canonical form or the canonical key.
    #[test]
    fn equivalent_filter_conjunctions_have_same_key() {
        let left = parse_and_lower(
            "SELECT id FROM posts
             WHERE author_id = 42 AND id = 7",
        )
        .expect("query should lower");
        let right = parse_and_lower(
            "SELECT id FROM posts
             WHERE id = 7 AND author_id = 42",
        )
        .expect("query should lower");

        assert_eq!(canonical_form(&left), canonical_form(&right));
        assert_eq!(canonical_key(&left), canonical_key(&right));
    }

    // A different literal in the predicate must yield a different key.
    #[test]
    fn different_queries_have_different_keys() {
        let left = parse_and_lower("SELECT id FROM posts WHERE author_id = 42")
            .expect("query should lower");
        let right = parse_and_lower("SELECT id FROM posts WHERE author_id = 43")
            .expect("query should lower");

        assert_ne!(canonical_key(&left), canonical_key(&right));
    }

    // Equivalent spellings of the same literal (`00042` vs `42`, and
    // `E'hello'` vs `'hello'`) must canonicalize identically.
    #[test]
    fn normalized_literals_have_same_key() {
        let left = parse_and_lower("SELECT id FROM posts WHERE author_id = 00042")
            .expect("query should lower");
        let right = parse_and_lower("SELECT id FROM posts WHERE author_id = 42")
            .expect("query should lower");
        let escaped = parse_and_lower("SELECT id FROM posts WHERE title = E'hello'")
            .expect("query should lower");
        let quoted = parse_and_lower("SELECT id FROM posts WHERE title = 'hello'")
            .expect("query should lower");

        assert_eq!(canonical_form(&left), canonical_form(&right));
        assert_eq!(canonical_key(&escaped), canonical_key(&quoted));
    }

    // Renaming a CTE whose body is identical must not change the key.
    #[test]
    fn cte_names_do_not_affect_canonical_key() {
        let left = parse_and_lower(
            "WITH recent_posts AS (
                SELECT id FROM posts WHERE author_id = 42
             )
             SELECT id FROM recent_posts",
        )
        .expect("query should lower");
        let right = parse_and_lower(
            "WITH visible_posts AS (
                SELECT id FROM posts WHERE author_id = 42
             )
             SELECT id FROM visible_posts",
        )
        .expect("query should lower");

        assert_eq!(canonical_key(&left), canonical_key(&right));
    }

    // Two identical UNION ALL branches should be shared after collapsing (fewer
    // nodes), and collapsing must not change the canonical key.
    // NOTE(review): this only covers a commutative `Union`; operand-order
    // preservation for non-commutative nodes (e.g. `Except`) is not exercised.
    #[test]
    fn collapse_reused_subgraphs_shares_duplicate_branches() {
        let graph = parse_and_lower(
            "SELECT id FROM posts
             UNION ALL
             SELECT id FROM posts",
        )
        .expect("query should lower");

        let collapsed = collapse_reused_subgraphs(&graph);

        assert!(collapsed.node_count() < graph.node_count());
        assert_eq!(canonical_key(&graph), canonical_key(&collapsed));
    }
}