use std::collections::BTreeMap;
use palimpsest_sql::mir::{ColumnRef, MirEdgeKind, MirGraph, MirNodeKind};
use petgraph::{graph::NodeIndex, visit::EdgeRef, Direction};
/// A primary-key lookup against a single base table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpqueryRequest {
    /// Name of the table to read from; interpolated verbatim into the SQL.
    pub table: String,
    /// Name of the column compared against `keys`; interpolated verbatim into the SQL.
    pub primary_key: String,
    /// Key values to fetch. Each one is rendered as a quoted SQL string literal.
    pub keys: Vec<String>,
}

impl UpqueryRequest {
    /// Renders the request as a `SELECT *` statement filtering on the primary key.
    ///
    /// Every key is emitted as a single-quoted literal, with embedded single
    /// quotes doubled. No other character is escaped, and `table` and
    /// `primary_key` are inserted as-is without identifier quoting, so callers
    /// must only pass trusted identifier names.
    ///
    /// When `keys` is empty the statement uses the always-false predicate
    /// `1 = 0` instead of an empty `IN ()` list, so the result is still valid
    /// SQL that selects no rows.
    #[must_use]
    pub fn to_sql(&self) -> String {
        if self.keys.is_empty() {
            // `IN ()` is rejected by common SQL parsers; an always-false
            // predicate expresses the same "match nothing" intent.
            return format!("SELECT * FROM {table} WHERE 1 = 0", table = self.table);
        }
        let literals = self
            .keys
            .iter()
            .map(|key| format!("'{}'", key.replace('\'', "''")))
            .collect::<Vec<_>>()
            .join(", ");
        format!(
            "SELECT * FROM {table} WHERE {pk} IN ({values})",
            table = self.table,
            pk = self.primary_key,
            values = literals
        )
    }
}
/// The set of base-table lookups produced for one upquery.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpqueryPlan {
    /// Individual per-table requests making up the plan.
    pub requests: Vec<UpqueryRequest>,
}

impl UpqueryPlan {
    /// Returns `true` when the plan holds no requests.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { requests } = self;
        requests.is_empty()
    }

    /// Returns how many requests the plan holds.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { requests } = self;
        requests.len()
    }
}
/// Lookup of the primary-key column name for a base table.
pub trait PrimaryKeyResolver {
/// Returns the primary-key column name registered for `table`, or `None`
/// when the resolver has no entry for that table.
fn primary_key(&self, table: &str) -> Option<&str>;
}
/// In-memory table of primary-key column names, keyed by table name.
#[derive(Debug, Clone, Default)]
pub struct StaticPrimaryKeys {
    /// Maps a table name to the name of its primary-key column.
    keys: BTreeMap<String, String>,
}

impl StaticPrimaryKeys {
    /// Creates an empty table.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            keys: BTreeMap::new(),
        }
    }

    /// Registers `primary_key` as the key column of `table`, replacing any
    /// earlier entry for the same table.
    pub fn insert(&mut self, table: impl Into<String>, primary_key: impl Into<String>) {
        self.keys.insert(table.into(), primary_key.into());
    }

    /// Builds a table from `(table, primary_key)` pairs.
    ///
    /// Kept as an inherent method so existing `StaticPrimaryKeys::from_iter(..)`
    /// call sites keep compiling; it simply defers to the [`FromIterator`]
    /// implementation. When a table appears more than once, the last pair wins.
    #[must_use]
    pub fn from_iter<I, S, P>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (S, P)>,
        S: Into<String>,
        P: Into<String>,
    {
        pairs.into_iter().collect()
    }
}

/// Lets `(table, primary_key)` pairs be gathered with `Iterator::collect`.
impl<S, P> std::iter::FromIterator<(S, P)> for StaticPrimaryKeys
where
    S: Into<String>,
    P: Into<String>,
{
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (S, P)>,
    {
        // Insert one pair at a time so that a repeated table name is
        // overwritten by its latest occurrence.
        let mut resolver = Self::new();
        for (table, key) in iter {
            resolver.insert(table, key);
        }
        resolver
    }
}
impl PrimaryKeyResolver for StaticPrimaryKeys {
    /// Looks `table` up in the stored map and returns its key column, if any.
    fn primary_key(&self, table: &str) -> Option<&str> {
        let column = self.keys.get(table)?;
        Some(column.as_str())
    }
}
/// Plans one lookup per base table reachable from the root of `graph`.
///
/// The graph is walked from its root towards its inputs. A `CteRef` node is
/// followed through its `CteExpansion` edges, a `BaseTable` node ends the walk
/// on that path, and every other node is followed through its `Input` edges.
///
/// Each base table whose primary key is known to `primary_keys` contributes a
/// single request carrying a copy of `requested_keys`; tables without a known
/// primary key are left out. Requests are returned ordered by table name.
#[must_use]
pub fn plan_upquery<R>(graph: &MirGraph, requested_keys: &[String], primary_keys: &R) -> UpqueryPlan
where
    R: PrimaryKeyResolver + ?Sized,
{
    let mut by_table: BTreeMap<String, UpqueryRequest> = BTreeMap::new();
    let mut seen = std::collections::BTreeSet::new();
    let mut pending = vec![graph.root()];

    while let Some(current) = pending.pop() {
        let first_visit = seen.insert(current);
        if !first_visit {
            continue;
        }

        // Decide which kind of edge leads further upstream from this node.
        let follow = match &graph.graph()[current] {
            MirNodeKind::BaseTable { table, .. } => {
                if let Some(primary_key) = primary_keys.primary_key(table) {
                    // Only the first sighting of a table creates its request.
                    by_table
                        .entry(table.clone())
                        .or_insert_with(|| UpqueryRequest {
                            table: table.clone(),
                            primary_key: primary_key.to_owned(),
                            keys: requested_keys.to_vec(),
                        });
                }
                continue;
            }
            MirNodeKind::CteRef { .. } => MirEdgeKind::CteExpansion,
            _ => MirEdgeKind::Input,
        };
        pending.extend(input_nodes(graph, current, follow));
    }

    let requests = by_table.into_values().collect();
    UpqueryPlan { requests }
}
/// Lists the names of the base tables reachable from the root of `graph`.
///
/// `CteRef` nodes are followed through `CteExpansion` edges and all other
/// non-table nodes through `Input` edges. The result is sorted and free of
/// duplicates.
#[must_use]
pub fn base_tables(graph: &MirGraph) -> Vec<String> {
    // An ordered set yields the names already sorted and de-duplicated.
    let mut names = std::collections::BTreeSet::new();
    let mut seen = std::collections::BTreeSet::new();
    let mut pending = vec![graph.root()];

    while let Some(current) = pending.pop() {
        if !seen.insert(current) {
            continue;
        }
        let follow = match &graph.graph()[current] {
            MirNodeKind::BaseTable { table, .. } => {
                names.insert(table.clone());
                continue;
            }
            MirNodeKind::CteRef { .. } => MirEdgeKind::CteExpansion,
            _ => MirEdgeKind::Input,
        };
        pending.extend(input_nodes(graph, current, follow));
    }

    names.into_iter().collect()
}
/// Collects the columns projected by every `BaseTable` node named `table`.
///
/// All nodes of the graph are inspected, not only those reachable from the
/// root. The result is ordered by relation and then by column name, with
/// adjacent equal entries collapsed.
#[must_use]
pub fn referenced_columns(graph: &MirGraph, table: &str) -> Vec<ColumnRef> {
    let mut found: Vec<ColumnRef> = graph
        .graph()
        .node_weights()
        .filter_map(|weight| match weight {
            MirNodeKind::BaseTable {
                table: name,
                project,
            } if name == table => Some(project),
            _ => None,
        })
        .flat_map(|project| project.iter().cloned())
        .collect();

    found.sort_by(|left, right| {
        (&left.relation, &left.name).cmp(&(&right.relation, &right.name))
    });
    found.dedup();
    found
}
/// Returns the source node of every edge of kind `edge` that points into `node`.
fn input_nodes(graph: &MirGraph, node: NodeIndex, edge: MirEdgeKind) -> Vec<NodeIndex> {
    let mut sources = Vec::new();
    for incoming in graph.graph().edges_directed(node, Direction::Incoming) {
        if *incoming.weight() == edge {
            sources.push(incoming.source());
        }
    }
    sources
}
#[cfg(test)]
mod tests {
    use palimpsest_sql::mir::{ColumnRef, JoinKind, MirGraph, MirNodeKind};

    use super::{base_tables, plan_upquery, referenced_columns, StaticPrimaryKeys, UpqueryRequest};

    /// Builds a column reference qualified with `relation`.
    fn column(relation: &str, name: &str) -> ColumnRef {
        ColumnRef {
            relation: Some(relation.to_owned()),
            name: name.to_owned(),
        }
    }

    /// Builds a base-table node that projects only the table's `id` column.
    fn table_node(table: &str) -> MirNodeKind {
        MirNodeKind::BaseTable {
            table: table.to_owned(),
            project: vec![column(table, "id")],
        }
    }

    /// Converts string literals into owned key values.
    fn owned(keys: &[&str]) -> Vec<String> {
        keys.iter().map(|key| (*key).to_owned()).collect()
    }

    /// Fixture: `posts` inner-joined with `authors`, rooted at the join node.
    fn join_graph() -> MirGraph {
        let mut graph = MirGraph::new(table_node("posts"));
        let posts = graph.root();
        let authors = graph.add_node(table_node("authors"));
        let join = graph.add_node(MirNodeKind::Join {
            kind: JoinKind::Inner,
            on: vec![(column("posts", "author_id"), column("authors", "id"))],
        });
        graph.add_input(posts, join);
        graph.add_input(authors, join);
        graph.set_root(join);
        graph
    }

    #[test]
    fn plan_upquery_emits_one_request_per_base_table() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::from_iter([("posts", "id"), ("authors", "id")]);
        let wanted = owned(&["7", "9"]);

        let plan = plan_upquery(&graph, &wanted, &primary_keys);
        assert_eq!(plan.len(), 2);

        let expected_for = |table: &str| UpqueryRequest {
            table: table.to_owned(),
            primary_key: "id".to_owned(),
            keys: owned(&["7", "9"]),
        };
        let mut requests = plan.requests;
        requests.sort_by(|left, right| left.table.cmp(&right.table));
        assert_eq!(requests, [expected_for("authors"), expected_for("posts")]);
    }

    #[test]
    fn plan_upquery_skips_tables_with_unknown_primary_key() {
        let graph = join_graph();
        let primary_keys = StaticPrimaryKeys::from_iter([("posts", "id")]);

        let plan = plan_upquery(&graph, &owned(&["1"]), &primary_keys);

        assert_eq!(plan.len(), 1);
        assert_eq!(plan.requests[0].table, "posts");
    }

    #[test]
    fn upquery_to_sql_quotes_values_and_escapes_inner_quotes() {
        let request = UpqueryRequest {
            table: "posts".to_owned(),
            primary_key: "id".to_owned(),
            keys: owned(&["a", "b'c"]),
        };
        let sql = request.to_sql();
        assert_eq!(sql, "SELECT * FROM posts WHERE id IN ('a', 'b''c')");
    }

    // The fixture contains only a join above the two base tables.
    #[test]
    fn base_tables_walks_through_filter_and_join() {
        let graph = join_graph();
        let tables = base_tables(&graph);
        assert_eq!(tables, ["authors", "posts"]);
    }

    #[test]
    fn referenced_columns_collects_projected_columns_for_base_table() {
        let graph = join_graph();
        let columns = referenced_columns(&graph, "posts");
        assert_eq!(columns, [column("posts", "id")]);
    }
}