palimpsest_dataflow/palimpsest/
upquery.rs1use std::collections::BTreeMap;
13
14use palimpsest_sql::mir::{ColumnRef, MirEdgeKind, MirGraph, MirNodeKind};
15use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction};
16
/// A lookup of rows in one table, selected by primary-key value.
///
/// The request names the table, the column to filter on, and the key values
/// to fetch. It carries no connection or dialect information.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UpqueryRequest {
    /// Name of the table to read from.
    pub table: String,
    /// Name of the column the `keys` are matched against.
    pub primary_key: String,
    /// Key values to look up, kept in their string form.
    pub keys: Vec<String>,
}
27
28impl UpqueryRequest {
29 #[must_use]
32 pub fn to_sql(&self) -> String {
33 let placeholders = self
34 .keys
35 .iter()
36 .map(|key| format!("'{}'", key.replace('\'', "''")))
37 .collect::<Vec<_>>()
38 .join(", ");
39 format!(
40 "SELECT * FROM {table} WHERE {pk} IN ({values})",
41 table = self.table,
42 pk = self.primary_key,
43 values = placeholders
44 )
45 }
46}
47
48#[derive(Debug, Clone, Default, PartialEq, Eq)]
50pub struct UpqueryPlan {
51 pub requests: Vec<UpqueryRequest>,
53}
54
55impl UpqueryPlan {
56 #[must_use]
58 pub fn is_empty(&self) -> bool {
59 self.requests.is_empty()
60 }
61
62 #[must_use]
64 pub fn len(&self) -> usize {
65 self.requests.len()
66 }
67}
68
/// Source of primary-key column names, looked up by table name.
pub trait PrimaryKeyResolver {
    /// Returns the primary-key column of `table`, or `None` when the
    /// resolver has no entry for that table.
    fn primary_key(&self, table: &str) -> Option<&str>;
}
74
/// A fixed, in-memory table of primary-key column names.
///
/// Entries are stored in a `BTreeMap` keyed by table name; the default value
/// has no entries.
#[derive(Clone, Debug, Default)]
pub struct StaticPrimaryKeys {
    /// Table name mapped to the name of its primary-key column.
    keys: BTreeMap<String, String>,
}
80
81impl StaticPrimaryKeys {
82 #[must_use]
84 pub const fn new() -> Self {
85 Self {
86 keys: BTreeMap::new(),
87 }
88 }
89
90 pub fn insert(&mut self, table: impl Into<String>, primary_key: impl Into<String>) {
92 self.keys.insert(table.into(), primary_key.into());
93 }
94
95 #[must_use]
97 pub fn from_iter<I, S, P>(pairs: I) -> Self
98 where
99 I: IntoIterator<Item = (S, P)>,
100 S: Into<String>,
101 P: Into<String>,
102 {
103 let mut resolver = Self::new();
104 for (table, key) in pairs {
105 resolver.insert(table, key);
106 }
107 resolver
108 }
109}
110
111impl PrimaryKeyResolver for StaticPrimaryKeys {
112 fn primary_key(&self, table: &str) -> Option<&str> {
113 self.keys.get(table).map(String::as_str)
114 }
115}
116
117#[must_use]
124pub fn plan_upquery<R>(graph: &MirGraph, requested_keys: &[String], primary_keys: &R) -> UpqueryPlan
125where
126 R: PrimaryKeyResolver + ?Sized,
127{
128 let mut requests: BTreeMap<String, UpqueryRequest> = BTreeMap::new();
129 let mut stack = vec![graph.root()];
130 let mut visited = std::collections::BTreeSet::new();
131
132 while let Some(node) = stack.pop() {
133 if !visited.insert(node) {
134 continue;
135 }
136
137 match &graph.graph()[node] {
138 MirNodeKind::BaseTable { table, .. } => {
139 let Some(primary_key) = primary_keys.primary_key(table) else {
140 continue;
141 };
142 requests
143 .entry(table.clone())
144 .or_insert_with(|| UpqueryRequest {
145 table: table.clone(),
146 primary_key: primary_key.to_owned(),
147 keys: requested_keys.to_vec(),
148 });
149 }
150 MirNodeKind::CteRef { .. } => {
151 stack.extend(input_nodes(graph, node, MirEdgeKind::CteExpansion));
152 }
153 _ => {
154 stack.extend(input_nodes(graph, node, MirEdgeKind::Input));
155 }
156 }
157 }
158
159 UpqueryPlan {
160 requests: requests.into_values().collect(),
161 }
162}
163
164#[must_use]
166pub fn base_tables(graph: &MirGraph) -> Vec<String> {
167 let mut tables = Vec::new();
168 let mut stack = vec![graph.root()];
169 let mut visited = std::collections::BTreeSet::new();
170
171 while let Some(node) = stack.pop() {
172 if !visited.insert(node) {
173 continue;
174 }
175 match &graph.graph()[node] {
176 MirNodeKind::BaseTable { table, .. } => tables.push(table.clone()),
177 MirNodeKind::CteRef { .. } => {
178 stack.extend(input_nodes(graph, node, MirEdgeKind::CteExpansion));
179 }
180 _ => {
181 stack.extend(input_nodes(graph, node, MirEdgeKind::Input));
182 }
183 }
184 }
185
186 tables.sort();
187 tables.dedup();
188 tables
189}
190
191#[must_use]
195pub fn referenced_columns(graph: &MirGraph, table: &str) -> Vec<ColumnRef> {
196 let mut columns = Vec::new();
197 for node in graph.graph().node_weights() {
198 if let MirNodeKind::BaseTable {
199 table: name,
200 project,
201 } = node
202 {
203 if name == table {
204 columns.extend(project.iter().cloned());
205 }
206 }
207 }
208 columns.sort_by(|left, right| {
209 left.relation
210 .cmp(&right.relation)
211 .then_with(|| left.name.cmp(&right.name))
212 });
213 columns.dedup();
214 columns
215}
216
217fn input_nodes(graph: &MirGraph, node: NodeIndex, edge: MirEdgeKind) -> Vec<NodeIndex> {
218 graph
219 .graph()
220 .edges_directed(node, Direction::Incoming)
221 .filter(|candidate| *candidate.weight() == edge)
222 .map(|candidate| candidate.source())
223 .collect()
224}
225
#[cfg(test)]
mod tests {
    use palimpsest_sql::mir::{ColumnRef, JoinKind, MirGraph, MirNodeKind};

    use super::{base_tables, plan_upquery, referenced_columns, StaticPrimaryKeys, UpqueryRequest};

    /// Builds a column reference qualified by `relation`.
    fn column(relation: &str, name: &str) -> ColumnRef {
        ColumnRef {
            relation: Some(relation.to_owned()),
            name: name.to_owned(),
        }
    }

    /// Builds a base-table node that projects only its `id` column.
    fn table_node(table: &str) -> MirNodeKind {
        MirNodeKind::BaseTable {
            table: table.to_owned(),
            project: vec![column(table, "id")],
        }
    }

    /// Converts string literals into owned key values.
    fn owned(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| (*value).to_owned()).collect()
    }

    /// Builds `posts INNER JOIN authors ON posts.author_id = authors.id`,
    /// with the join node as the graph root.
    fn join_graph() -> MirGraph {
        let mut graph = MirGraph::new(table_node("posts"));
        let posts = graph.root();
        let authors = graph.add_node(table_node("authors"));
        let join = graph.add_node(MirNodeKind::Join {
            kind: JoinKind::Inner,
            on: vec![(column("posts", "author_id"), column("authors", "id"))],
        });
        graph.add_input(posts, join);
        graph.add_input(authors, join);
        graph.set_root(join);
        graph
    }

    #[test]
    fn plan_upquery_emits_one_request_per_base_table() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::from_iter([("posts", "id"), ("authors", "id")]);

        let plan = plan_upquery(&graph, &owned(&["7", "9"]), &primary_keys);
        assert_eq!(plan.len(), 2);

        let mut requests = plan.requests;
        requests.sort_by(|left, right| left.table.cmp(&right.table));

        let expected = |table: &str| UpqueryRequest {
            table: table.to_owned(),
            primary_key: "id".to_owned(),
            keys: owned(&["7", "9"]),
        };
        assert_eq!(requests, [expected("authors"), expected("posts")]);
    }

    #[test]
    fn plan_upquery_skips_tables_with_unknown_primary_key() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::from_iter([("posts", "id")]);

        let plan = plan_upquery(&graph, &owned(&["1"]), &primary_keys);
        assert_eq!(plan.len(), 1);
        assert_eq!(plan.requests[0].table, "posts");
    }

    #[test]
    fn upquery_to_sql_quotes_values_and_escapes_inner_quotes() {
        let request = UpqueryRequest {
            table: "posts".to_owned(),
            primary_key: "id".to_owned(),
            keys: owned(&["a", "b'c"]),
        };
        let sql = request.to_sql();
        assert_eq!(sql, "SELECT * FROM posts WHERE id IN ('a', 'b''c')");
    }

    #[test]
    fn base_tables_walks_through_filter_and_join() {
        let tables = base_tables(&join_graph());
        assert_eq!(tables, ["authors", "posts"]);
    }

    #[test]
    fn referenced_columns_collects_projected_columns_for_base_table() {
        let columns = referenced_columns(&join_graph(), "posts");
        assert_eq!(columns, [column("posts", "id")]);
    }
}