Skip to main content

palimpsest_dataflow/palimpsest/
upquery.rs

//! Upquery resolver.
//!
//! When a partial-materialization arrangement misses a key, the runtime
//! issues an *upquery* against the upstream Postgres so the consumer can
//! be served. The resolver walks an MIR plan backward, finds every base
//! table that contributes to the requested rows, and emits one upquery
//! request (renderable as `SELECT … WHERE pk IN (…)`) per base table.
//!
//! This module owns only the planning step. The actual SQL execution is
//! the responsibility of the gateway/postgres layer.
11
12use std::collections::BTreeMap;
13
14use palimpsest_sql::mir::{ColumnRef, MirEdgeKind, MirGraph, MirNodeKind};
15use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction};
16
/// One Postgres upquery directive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpqueryRequest {
    /// Postgres relation name.
    pub table: String,
    /// Primary-key column for the relation.
    pub primary_key: String,
    /// Primary-key values to fetch.
    pub keys: Vec<String>,
}

impl UpqueryRequest {
    /// Renders the request as `SELECT * FROM table WHERE pk IN (…)`.
    ///
    /// Each key is emitted as a single-quoted SQL string literal with embedded
    /// single quotes doubled (`b'c` becomes `'b''c'`). Backslashes are not
    /// treated specially, which matches Postgres with
    /// `standard_conforming_strings = on`.
    ///
    /// `table` and `primary_key` are interpolated verbatim, not quoted as
    /// identifiers. They must come from trusted metadata, never from user input.
    ///
    /// When `keys` is empty the statement is `SELECT * FROM table WHERE FALSE`.
    /// That is valid SQL matching no rows, whereas `IN ()` would be rejected as
    /// a syntax error.
    #[must_use]
    pub fn to_sql(&self) -> String {
        let table = &self.table;
        if self.keys.is_empty() {
            // An empty IN-list is not valid SQL; an empty key set selects nothing.
            return format!("SELECT * FROM {table} WHERE FALSE");
        }
        let values = self
            .keys
            .iter()
            .map(|key| format!("'{}'", key.replace('\'', "''")))
            .collect::<Vec<_>>()
            .join(", ");
        format!(
            "SELECT * FROM {table} WHERE {pk} IN ({values})",
            pk = self.primary_key
        )
    }
}
47
48/// Plan produced by walking an MIR backward from the requested node.
49#[derive(Debug, Clone, Default, PartialEq, Eq)]
50pub struct UpqueryPlan {
51    /// One entry per contributing base table.
52    pub requests: Vec<UpqueryRequest>,
53}
54
55impl UpqueryPlan {
56    /// Returns true when no base tables were reached.
57    #[must_use]
58    pub fn is_empty(&self) -> bool {
59        self.requests.is_empty()
60    }
61
62    /// Number of distinct base tables in the plan.
63    #[must_use]
64    pub fn len(&self) -> usize {
65        self.requests.len()
66    }
67}
68
/// Maps a relation name to its primary-key column.
pub trait PrimaryKeyResolver {
    /// Returns the primary-key column for `table`, or `None` when it is unknown.
    fn primary_key(&self, table: &str) -> Option<&str>;
}

/// Static `(table, primary_key)` lookup backed by a `BTreeMap`.
///
/// Build one with [`StaticPrimaryKeys::new`] plus [`StaticPrimaryKeys::insert`],
/// with [`StaticPrimaryKeys::from_iter`], or by `collect()`-ing an iterator of
/// `(table, primary_key)` pairs. A later pair for the same table overwrites an
/// earlier one.
#[derive(Debug, Clone, Default)]
pub struct StaticPrimaryKeys {
    keys: BTreeMap<String, String>,
}

impl StaticPrimaryKeys {
    /// Creates an empty resolver.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            keys: BTreeMap::new(),
        }
    }

    /// Inserts or updates a `(table, primary_key)` mapping.
    pub fn insert(&mut self, table: impl Into<String>, primary_key: impl Into<String>) {
        self.keys.insert(table.into(), primary_key.into());
    }

    /// Builds a resolver from an iterator of `(table, primary_key)` pairs.
    ///
    /// Kept as an inherent constructor for existing callers. It is equivalent to
    /// `pairs.into_iter().collect::<StaticPrimaryKeys>()`.
    #[must_use]
    pub fn from_iter<I, S, P>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (S, P)>,
        S: Into<String>,
        P: Into<String>,
    {
        let mut resolver = Self::new();
        resolver.extend(pairs);
        resolver
    }
}

impl<S, P> Extend<(S, P)> for StaticPrimaryKeys
where
    S: Into<String>,
    P: Into<String>,
{
    /// Inserts every `(table, primary_key)` pair, overwriting existing tables.
    fn extend<I: IntoIterator<Item = (S, P)>>(&mut self, pairs: I) {
        for (table, primary_key) in pairs {
            self.insert(table, primary_key);
        }
    }
}

impl<S, P> std::iter::FromIterator<(S, P)> for StaticPrimaryKeys
where
    S: Into<String>,
    P: Into<String>,
{
    /// Collects `(table, primary_key)` pairs into a resolver.
    fn from_iter<I: IntoIterator<Item = (S, P)>>(pairs: I) -> Self {
        let mut resolver = Self::new();
        resolver.extend(pairs);
        resolver
    }
}

impl PrimaryKeyResolver for StaticPrimaryKeys {
    fn primary_key(&self, table: &str) -> Option<&str> {
        self.keys.get(table).map(String::as_str)
    }
}
116
117/// Walks the MIR backwards from its root and records the base tables that
118/// must be queried for `requested_keys`.
119///
120/// The same primary-key set is forwarded to every base table; callers that
121/// need per-base filtering should build a tailored plan up front (e.g., by
122/// resolving join keys to per-table identifiers in the SQL frontend).
123#[must_use]
124pub fn plan_upquery<R>(graph: &MirGraph, requested_keys: &[String], primary_keys: &R) -> UpqueryPlan
125where
126    R: PrimaryKeyResolver + ?Sized,
127{
128    let mut requests: BTreeMap<String, UpqueryRequest> = BTreeMap::new();
129    let mut stack = vec![graph.root()];
130    let mut visited = std::collections::BTreeSet::new();
131
132    while let Some(node) = stack.pop() {
133        if !visited.insert(node) {
134            continue;
135        }
136
137        match &graph.graph()[node] {
138            MirNodeKind::BaseTable { table, .. } => {
139                let Some(primary_key) = primary_keys.primary_key(table) else {
140                    continue;
141                };
142                requests
143                    .entry(table.clone())
144                    .or_insert_with(|| UpqueryRequest {
145                        table: table.clone(),
146                        primary_key: primary_key.to_owned(),
147                        keys: requested_keys.to_vec(),
148                    });
149            }
150            MirNodeKind::CteRef { .. } => {
151                stack.extend(input_nodes(graph, node, MirEdgeKind::CteExpansion));
152            }
153            _ => {
154                stack.extend(input_nodes(graph, node, MirEdgeKind::Input));
155            }
156        }
157    }
158
159    UpqueryPlan {
160        requests: requests.into_values().collect(),
161    }
162}
163
164/// Walks an MIR backward from `root` and returns every base table referenced.
165#[must_use]
166pub fn base_tables(graph: &MirGraph) -> Vec<String> {
167    let mut tables = Vec::new();
168    let mut stack = vec![graph.root()];
169    let mut visited = std::collections::BTreeSet::new();
170
171    while let Some(node) = stack.pop() {
172        if !visited.insert(node) {
173            continue;
174        }
175        match &graph.graph()[node] {
176            MirNodeKind::BaseTable { table, .. } => tables.push(table.clone()),
177            MirNodeKind::CteRef { .. } => {
178                stack.extend(input_nodes(graph, node, MirEdgeKind::CteExpansion));
179            }
180            _ => {
181                stack.extend(input_nodes(graph, node, MirEdgeKind::Input));
182            }
183        }
184    }
185
186    tables.sort();
187    tables.dedup();
188    tables
189}
190
191/// Returns every `(relation, column)` pair the MIR projects for `table`.
192///
193/// Useful for narrowing upqueries to only the columns the consumer reads.
194#[must_use]
195pub fn referenced_columns(graph: &MirGraph, table: &str) -> Vec<ColumnRef> {
196    let mut columns = Vec::new();
197    for node in graph.graph().node_weights() {
198        if let MirNodeKind::BaseTable {
199            table: name,
200            project,
201        } = node
202        {
203            if name == table {
204                columns.extend(project.iter().cloned());
205            }
206        }
207    }
208    columns.sort_by(|left, right| {
209        left.relation
210            .cmp(&right.relation)
211            .then_with(|| left.name.cmp(&right.name))
212    });
213    columns.dedup();
214    columns
215}
216
217fn input_nodes(graph: &MirGraph, node: NodeIndex, edge: MirEdgeKind) -> Vec<NodeIndex> {
218    graph
219        .graph()
220        .edges_directed(node, Direction::Incoming)
221        .filter(|candidate| *candidate.weight() == edge)
222        .map(|candidate| candidate.source())
223        .collect()
224}
225
#[cfg(test)]
mod tests {
    use palimpsest_sql::mir::{ColumnRef, JoinKind, MirGraph, MirNodeKind};

    use super::{base_tables, plan_upquery, referenced_columns, StaticPrimaryKeys, UpqueryRequest};

    /// Builds `posts INNER JOIN authors ON posts.author_id = authors.id`.
    ///
    /// The join node is the root, and each base table projects only its `id`
    /// column.
    fn join_graph() -> MirGraph {
        let mut graph = MirGraph::new(MirNodeKind::BaseTable {
            table: "posts".to_owned(),
            project: vec![ColumnRef {
                relation: Some("posts".to_owned()),
                name: "id".to_owned(),
            }],
        });
        let posts = graph.root();
        let authors = graph.add_node(MirNodeKind::BaseTable {
            table: "authors".to_owned(),
            project: vec![ColumnRef {
                relation: Some("authors".to_owned()),
                name: "id".to_owned(),
            }],
        });
        let join = graph.add_node(MirNodeKind::Join {
            kind: JoinKind::Inner,
            on: vec![(
                ColumnRef {
                    relation: Some("posts".to_owned()),
                    name: "author_id".to_owned(),
                },
                ColumnRef {
                    relation: Some("authors".to_owned()),
                    name: "id".to_owned(),
                },
            )],
        });
        graph.add_input(posts, join);
        graph.add_input(authors, join);
        graph.set_root(join);
        graph
    }

    #[test]
    fn plan_upquery_emits_one_request_per_base_table() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::from_iter([("posts", "id"), ("authors", "id")]);

        let plan = plan_upquery(&graph, &["7".to_owned(), "9".to_owned()], &primary_keys);
        assert_eq!(plan.len(), 2);

        let mut requests = plan.requests;
        requests.sort_by(|left, right| left.table.cmp(&right.table));
        assert_eq!(
            requests,
            [
                UpqueryRequest {
                    table: "authors".to_owned(),
                    primary_key: "id".to_owned(),
                    keys: vec!["7".to_owned(), "9".to_owned()],
                },
                UpqueryRequest {
                    table: "posts".to_owned(),
                    primary_key: "id".to_owned(),
                    keys: vec!["7".to_owned(), "9".to_owned()],
                },
            ]
        );
    }

    #[test]
    fn plan_upquery_skips_tables_with_unknown_primary_key() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::from_iter([("posts", "id")]);

        let plan = plan_upquery(&graph, &["1".to_owned()], &primary_keys);
        assert_eq!(plan.len(), 1);
        assert_eq!(plan.requests[0].table, "posts");
    }

    #[test]
    fn plan_upquery_is_empty_when_no_primary_keys_are_known() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::new();

        let plan = plan_upquery(&graph, &["1".to_owned()], &primary_keys);
        assert!(plan.is_empty());
        assert_eq!(plan.len(), 0);
    }

    #[test]
    fn upquery_to_sql_quotes_values_and_escapes_inner_quotes() {
        let request = UpqueryRequest {
            table: "posts".to_owned(),
            primary_key: "id".to_owned(),
            keys: vec!["a".to_owned(), "b'c".to_owned()],
        };
        assert_eq!(
            request.to_sql(),
            "SELECT * FROM posts WHERE id IN ('a', 'b''c')"
        );
    }

    // The fixture has no filter node; this exercises traversal through a join.
    #[test]
    fn base_tables_walks_through_join() {
        let graph = join_graph();
        assert_eq!(base_tables(&graph), ["authors", "posts"]);
    }

    #[test]
    fn referenced_columns_collects_projected_columns_for_base_table() {
        let graph = join_graph();
        assert_eq!(
            referenced_columns(&graph, "posts"),
            [ColumnRef {
                relation: Some("posts".to_owned()),
                name: "id".to_owned(),
            }]
        );
    }
}