1#![allow(missing_docs)]
12
13use std::collections::HashMap;
14
15use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction, Graph};
16
/// Join flavour carried by [`MirNodeKind::Join`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinKind {
    /// Inner join.
    Inner,
    /// Left join (presumably left *outer*; confirm against the code that lowers it).
    Left,
}
22
/// Quantifier attached to the set operations [`MirNodeKind::Union`],
/// [`MirNodeKind::Except`] and [`MirNodeKind::Intersect`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SetQuantifierKind {
    /// The `DISTINCT` quantifier.
    Distinct,
    /// The `ALL` quantifier.
    All,
}
28
/// A reference to a column by name, optionally qualified by a relation name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnRef {
    /// Qualifying relation name, or `None` for an unqualified reference.
    pub relation: Option<String>,
    /// Column name.
    pub name: String,
}
34
/// One sort key of a [`MirNodeKind::TopK`] ordering.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OrderKey {
    /// The expression to sort by, stored as text.
    pub expression: String,
    /// `true` for descending order, `false` for ascending.
    pub descending: bool,
}
40
/// One aggregate computed by a [`MirNodeKind::Aggregate`] node.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AggExpr {
    /// Name of the aggregate function.
    pub function: String,
    /// Argument expressions, stored as text.
    pub args: Vec<String>,
    /// Output alias, if one was given.
    pub alias: Option<String>,
}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum MirNodeKind {
50 BaseTable {
51 table: String,
52 project: Vec<ColumnRef>,
53 },
54 Filter {
55 predicate: String,
56 },
57 Project {
58 columns: Vec<String>,
59 },
60 Join {
61 kind: JoinKind,
62 on: Vec<(ColumnRef, ColumnRef)>,
63 },
64 Aggregate {
65 group_by: Vec<ColumnRef>,
66 aggs: Vec<AggExpr>,
67 },
68 Distinct,
69 Union {
70 quantifier: SetQuantifierKind,
71 },
72 Except {
73 quantifier: SetQuantifierKind,
74 },
75 Intersect {
76 quantifier: SetQuantifierKind,
77 },
78 TopK {
79 order_by: Vec<OrderKey>,
80 limit: usize,
81 offset: usize,
82 },
83 CteRef {
84 cte: String,
85 },
86 Leaf {
87 name: String,
88 },
89}
90
/// Label on an edge of a [`MirGraph`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MirEdgeKind {
    /// Dataflow edge from an input node to the node that consumes it
    /// (created by `MirGraph::add_input` and `MirGraph::splice_above`).
    Input,
    /// Edge created by `MirGraph::add_cte_expansion`
    /// (presumably linking a CTE reference to its expansion — confirm).
    CteExpansion,
}
96
97#[derive(Debug, Clone)]
98pub struct MirGraph {
99 graph: Graph<MirNodeKind, MirEdgeKind>,
100 root: NodeIndex,
101}
102
103impl MirGraph {
104 #[must_use]
105 pub fn new(root: MirNodeKind) -> Self {
106 let mut graph = Graph::new();
107 let root = graph.add_node(root);
108 Self { graph, root }
109 }
110
111 #[must_use]
112 pub const fn root(&self) -> NodeIndex {
113 self.root
114 }
115
116 #[must_use]
117 pub const fn graph(&self) -> &Graph<MirNodeKind, MirEdgeKind> {
118 &self.graph
119 }
120
121 #[must_use]
122 pub fn node_count(&self) -> usize {
123 self.graph.node_count()
124 }
125
126 #[must_use]
127 pub fn root_kind(&self) -> &MirNodeKind {
128 &self.graph[self.root]
129 }
130
131 pub fn node_kinds(&self) -> impl Iterator<Item = &MirNodeKind> {
132 self.graph.node_weights()
133 }
134
135 pub fn set_root(&mut self, root: NodeIndex) {
136 self.root = root;
137 }
138
139 pub fn add_input(&mut self, from: NodeIndex, to: NodeIndex) {
140 self.graph.add_edge(from, to, MirEdgeKind::Input);
141 }
142
143 pub fn add_cte_expansion(&mut self, from: NodeIndex, to: NodeIndex) {
144 self.graph.add_edge(from, to, MirEdgeKind::CteExpansion);
145 }
146
147 pub fn add_node(&mut self, node: MirNodeKind) -> NodeIndex {
148 self.graph.add_node(node)
149 }
150
151 pub fn append_graph(&mut self, other: &Self) -> NodeIndex {
152 let mut node_map = HashMap::with_capacity(other.graph.node_count());
153
154 for source in other.graph.node_indices() {
155 let target = self.graph.add_node(other.graph[source].clone());
156 node_map.insert(source, target);
157 }
158
159 for edge in other.graph.edge_references() {
160 self.graph.add_edge(
161 node_map[&edge.source()],
162 node_map[&edge.target()],
163 *edge.weight(),
164 );
165 }
166
167 node_map[&other.root]
168 }
169
170 pub(crate) const fn from_graph(
171 graph: Graph<MirNodeKind, MirEdgeKind>,
172 root: NodeIndex,
173 ) -> Self {
174 Self { graph, root }
175 }
176
177 pub fn splice_above(&mut self, target: NodeIndex, kind: MirNodeKind) -> NodeIndex {
187 let inserted = self.graph.add_node(kind);
188
189 let outgoing: Vec<_> = self
190 .graph
191 .edges_directed(target, Direction::Outgoing)
192 .map(|edge| (edge.id(), edge.target(), *edge.weight()))
193 .collect();
194 for (edge_id, consumer, weight) in outgoing {
195 self.graph.remove_edge(edge_id);
196 self.graph.add_edge(inserted, consumer, weight);
197 }
198
199 self.graph.add_edge(target, inserted, MirEdgeKind::Input);
200
201 if self.root == target {
202 self.root = inserted;
203 }
204
205 inserted
206 }
207
208 #[must_use]
210 pub fn base_table_indices(&self) -> Vec<NodeIndex> {
211 self.graph
212 .node_indices()
213 .filter(|index| matches!(self.graph[*index], MirNodeKind::BaseTable { .. }))
214 .collect()
215 }
216
217 #[must_use]
219 pub fn node_kind(&self, index: NodeIndex) -> &MirNodeKind {
220 &self.graph[index]
221 }
222}