Skip to main content

palimpsest_sql/
mir.rs

1// Copyright 2026 Thousand Birds Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Mid-level IR.
5//!
//! Each variant of [`MirNodeKind`] corresponds to a relational operator
7//! the dataflow engine knows how to instantiate. Internal — consumers
8//! should treat the graph as opaque and only inspect it through the
9//! helpers re-exported from the crate root.
10
11#![allow(missing_docs)]
12
13use std::collections::HashMap;
14
15use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction, Graph};
16
/// Join flavor carried by [`MirNodeKind::Join`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinKind {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
}
22
/// Quantifier attached to the set-operator variants of [`MirNodeKind`]
/// (`Union`, `Except`, `Intersect`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SetQuantifierKind {
    /// The `Distinct` quantifier.
    Distinct,
    /// The `All` quantifier.
    All,
}
28
/// A column name together with an optional relation qualifier.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnRef {
    /// Qualifying relation, or `None` when the reference is unqualified.
    pub relation: Option<String>,
    /// Name of the column.
    pub name: String,
}
34
/// One ordering key, as used by [`MirNodeKind::TopK`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OrderKey {
    /// Textual form of the expression to order by.
    pub expression: String,
    /// `true` when this key is flagged as descending.
    pub descending: bool,
}
40
/// One aggregate expression, as used by [`MirNodeKind::Aggregate`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AggExpr {
    /// Name of the aggregate function.
    pub function: String,
    /// Arguments passed to the function, in textual form.
    pub args: Vec<String>,
    /// Optional alias for the aggregate.
    pub alias: Option<String>,
}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum MirNodeKind {
50    BaseTable {
51        table: String,
52        project: Vec<ColumnRef>,
53    },
54    Filter {
55        predicate: String,
56    },
57    Project {
58        columns: Vec<String>,
59    },
60    Join {
61        kind: JoinKind,
62        on: Vec<(ColumnRef, ColumnRef)>,
63    },
64    Aggregate {
65        group_by: Vec<ColumnRef>,
66        aggs: Vec<AggExpr>,
67    },
68    Distinct,
69    Union {
70        quantifier: SetQuantifierKind,
71    },
72    Except {
73        quantifier: SetQuantifierKind,
74    },
75    Intersect {
76        quantifier: SetQuantifierKind,
77    },
78    TopK {
79        order_by: Vec<OrderKey>,
80        limit: usize,
81        offset: usize,
82    },
83    CteRef {
84        cte: String,
85    },
86    Leaf {
87        name: String,
88    },
89}
90
/// Label attached to every edge of a [`MirGraph`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MirEdgeKind {
    /// Edge kind added by [`MirGraph::add_input`].
    Input,
    /// Edge kind added by [`MirGraph::add_cte_expansion`].
    CteExpansion,
}
96
97#[derive(Debug, Clone)]
98pub struct MirGraph {
99    graph: Graph<MirNodeKind, MirEdgeKind>,
100    root: NodeIndex,
101}
102
103impl MirGraph {
104    #[must_use]
105    pub fn new(root: MirNodeKind) -> Self {
106        let mut graph = Graph::new();
107        let root = graph.add_node(root);
108        Self { graph, root }
109    }
110
111    #[must_use]
112    pub const fn root(&self) -> NodeIndex {
113        self.root
114    }
115
116    #[must_use]
117    pub const fn graph(&self) -> &Graph<MirNodeKind, MirEdgeKind> {
118        &self.graph
119    }
120
121    #[must_use]
122    pub fn node_count(&self) -> usize {
123        self.graph.node_count()
124    }
125
126    #[must_use]
127    pub fn root_kind(&self) -> &MirNodeKind {
128        &self.graph[self.root]
129    }
130
131    pub fn node_kinds(&self) -> impl Iterator<Item = &MirNodeKind> {
132        self.graph.node_weights()
133    }
134
135    pub fn set_root(&mut self, root: NodeIndex) {
136        self.root = root;
137    }
138
139    pub fn add_input(&mut self, from: NodeIndex, to: NodeIndex) {
140        self.graph.add_edge(from, to, MirEdgeKind::Input);
141    }
142
143    pub fn add_cte_expansion(&mut self, from: NodeIndex, to: NodeIndex) {
144        self.graph.add_edge(from, to, MirEdgeKind::CteExpansion);
145    }
146
147    pub fn add_node(&mut self, node: MirNodeKind) -> NodeIndex {
148        self.graph.add_node(node)
149    }
150
151    pub fn append_graph(&mut self, other: &Self) -> NodeIndex {
152        let mut node_map = HashMap::with_capacity(other.graph.node_count());
153
154        for source in other.graph.node_indices() {
155            let target = self.graph.add_node(other.graph[source].clone());
156            node_map.insert(source, target);
157        }
158
159        for edge in other.graph.edge_references() {
160            self.graph.add_edge(
161                node_map[&edge.source()],
162                node_map[&edge.target()],
163                *edge.weight(),
164            );
165        }
166
167        node_map[&other.root]
168    }
169
170    pub(crate) const fn from_graph(
171        graph: Graph<MirNodeKind, MirEdgeKind>,
172        root: NodeIndex,
173    ) -> Self {
174        Self { graph, root }
175    }
176
177    /// Inserts `kind` immediately above `target`, redirecting every
178    /// outgoing edge from `target` (i.e. each consumer that read from
179    /// `target`) to read from the newly-spliced node instead. The new
180    /// node receives a single incoming `Input` edge from `target`.
181    ///
182    /// If `target` was the graph root, the spliced node becomes the new
183    /// root.
184    ///
185    /// Returns the index of the newly-inserted node.
186    pub fn splice_above(&mut self, target: NodeIndex, kind: MirNodeKind) -> NodeIndex {
187        let inserted = self.graph.add_node(kind);
188
189        let outgoing: Vec<_> = self
190            .graph
191            .edges_directed(target, Direction::Outgoing)
192            .map(|edge| (edge.id(), edge.target(), *edge.weight()))
193            .collect();
194        for (edge_id, consumer, weight) in outgoing {
195            self.graph.remove_edge(edge_id);
196            self.graph.add_edge(inserted, consumer, weight);
197        }
198
199        self.graph.add_edge(target, inserted, MirEdgeKind::Input);
200
201        if self.root == target {
202            self.root = inserted;
203        }
204
205        inserted
206    }
207
208    /// Returns every node index whose payload is a `BaseTable`.
209    #[must_use]
210    pub fn base_table_indices(&self) -> Vec<NodeIndex> {
211        self.graph
212            .node_indices()
213            .filter(|index| matches!(self.graph[*index], MirNodeKind::BaseTable { .. }))
214            .collect()
215    }
216
217    /// Returns the kind stored at `index`.
218    #[must_use]
219    pub fn node_kind(&self, index: NodeIndex) -> &MirNodeKind {
220        &self.graph[index]
221    }
222}