use std::collections::hash_map::RandomState;
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use super::anchor::{infer_complexity, CidCollector, Complexity};
use super::ast::*;
use crate::ir::generic::{ColumnSort, SortDirection, WindowFrame, WindowKind};
use crate::ir::pl::{JoinSide, Literal};
use crate::ir::rq::{
self, maybe_binop, new_binop, CId, Compute, Expr, ExprKind, RqFold, Transform, Window,
};
use crate::sql::Context;
use crate::{debug, Error, Result, WithErrorInfo};
use prqlc_parser::generic::{InterpolateItem, Range};
pub(in crate::sql) fn preprocess(
pipeline: Vec<Transform>,
ctx: &mut Context,
) -> Result<Vec<SqlTransform>> {
Ok(pipeline)
.and_then(normalize)
.and_then(|p| wrap(p, ctx))
.and_then(|p| prune_inputs(p, ctx))
.and_then(|p| distinct(p, ctx))
.and_then(|p| union(p, ctx))
.and_then(|p| except(p, ctx))
.and_then(|p| intersect(p, ctx))
.map(reorder)
.map(|p| {
debug::log_entry(|| debug::DebugEntryKind::ReprPqEarly(p.clone()));
p
})
}
/// Prunes unused columns from the relation instances referenced by `From` and
/// `Join` transforms.
///
/// Walks the pipeline back-to-front, accumulating the set of column ids that
/// later transforms reference, and retains only those columns on each
/// relation instance.
pub(in crate::sql) fn prune_inputs(
    pipeline: Vec<SqlTransform>,
    ctx: &mut Context,
) -> Result<Vec<SqlTransform>> {
    use SqlTransform::Super;

    let mut used_cids = HashSet::new();
    let mut pruned = Vec::with_capacity(pipeline.len());
    for mut transform in pipeline.into_iter().rev() {
        // Record the column ids this transform references.
        match transform {
            Super(t) => {
                let (t, cids) = CidCollector::collect_t(t);
                used_cids.extend(cids);
                transform = Super(t);
            }
            SqlTransform::Join { ref filter, .. } => {
                used_cids.extend(CidCollector::collect(filter.clone()));
            }
            SqlTransform::From(_) => {}
            // Only From/Join/Super exist at this stage (see `wrap`).
            _ => unreachable!(),
        }

        // Drop any columns of this relation that nothing downstream uses.
        if let SqlTransform::From(riid) | SqlTransform::Join { with: riid, .. } = &mut transform {
            let instance = ctx.anchor.relation_instances.get_mut(riid).unwrap();
            instance
                .table_ref
                .columns
                .retain(|(_, cid)| used_cids.contains(cid));
        }
        pruned.push(transform);
    }
    // We collected in reverse; restore original pipeline order.
    pruned.reverse();
    Ok(pruned)
}
/// Wraps RQ transforms into [SqlTransform]s, creating relation instances for
/// `From` and `Join` sources; everything else is wrapped in `Super`.
pub(in crate::sql) fn wrap(pipe: Vec<Transform>, ctx: &mut Context) -> Result<Vec<SqlTransform>> {
    let mut res = Vec::with_capacity(pipe.len());
    for transform in pipe {
        let wrapped = match transform {
            Transform::From(table_ref) => {
                let riid = ctx
                    .anchor
                    .create_relation_instance(table_ref, HashMap::new());
                SqlTransform::From(riid)
            }
            Transform::Join { with, side, filter } => {
                let riid = ctx.anchor.create_relation_instance(with, HashMap::new());
                SqlTransform::Join {
                    with: riid,
                    side,
                    filter,
                }
            }
            other => SqlTransform::Super(other),
        };
        res.push(wrapped);
    }
    Ok(res)
}
/// Returns true when `a` and `b` contain the same set of elements,
/// ignoring order and duplicates.
fn vecs_contain_same_elements<T: Eq + std::hash::Hash>(a: &[T], b: &[T]) -> bool {
    let left = a.iter().collect::<HashSet<&T, RandomState>>();
    let right = b.iter().collect::<HashSet<&T, RandomState>>();
    // Equal-sized sets where one is a subset of the other are equal.
    left.len() == right.len() && left.iter().all(|x| right.contains(x))
}
/// Rewrites partitioned `take` transforms into `DISTINCT`, `DISTINCT ON`, or
/// a `ROW_NUMBER()` window-function filter, depending on the take's shape and
/// what the target dialect supports.
pub(in crate::sql) fn distinct(
    pipeline: Vec<SqlTransform>,
    ctx: &mut Context,
) -> Result<Vec<SqlTransform>> {
    use SqlTransform::Super;
    use Transform::*;

    let mut res = Vec::new();
    for transform in pipeline.clone() {
        match transform {
            // A take without a partition is a plain LIMIT/OFFSET: pass through.
            Super(Take(rq::Take { ref partition, .. })) if partition.is_empty() => {
                res.push(transform);
            }
            Super(Take(rq::Take {
                range,
                partition,
                sort,
            })) => {
                // Range bounds must be integer literals.
                let range_int = range
                    .clone()
                    .try_map(as_int)
                    .map_err(|_| Error::new_simple("Invalid take arguments"))?;

                let take_only_first =
                    range_int.start.unwrap_or(1) == 1 && matches!(range_int.end, Some(1));

                // Fix: borrow the pipeline directly — the previous
                // `&pipeline.clone()` allocated a full copy only to borrow it.
                let columns_in_frame = ctx.anchor.determine_select_columns(&pipeline);
                let matching_columns = vecs_contain_same_elements(&columns_in_frame, &partition);

                if take_only_first && sort.is_empty() && matching_columns {
                    // `take 1` per partition over exactly the selected columns,
                    // with no ordering, is SELECT DISTINCT.
                    res.push(SqlTransform::Distinct);
                } else if ctx.dialect.supports_distinct_on() && range_int.end == Some(1) {
                    // DISTINCT ON needs the partition columns to lead ORDER BY.
                    let sort = if sort.is_empty() {
                        vec![]
                    } else {
                        [into_column_sort(&partition), sort].concat()
                    };
                    res.push(SqlTransform::Sort(sort));
                    res.push(SqlTransform::DistinctOn(partition));
                } else {
                    // Fallback: filter on a ROW_NUMBER() window function.
                    res.extend(create_filter_by_row_number(range, sort, partition, ctx));
                }
            }
            _ => {
                res.push(transform);
            }
        }
    }
    Ok(res)
}
/// Converts partition columns into ascending sort specifiers.
fn into_column_sort(partition: &[CId]) -> Vec<ColumnSort<CId>> {
    let mut sorts = Vec::with_capacity(partition.len());
    for &column in partition {
        sorts.push(ColumnSort {
            direction: SortDirection::Asc,
            column,
        });
    }
    sorts
}
/// Builds the window-function fallback for a partitioned `take`: a
/// `ROW_NUMBER()` compute windowed over (partition, sort), plus a filter
/// restricting that row number to `range`.
fn create_filter_by_row_number(
    range: Range<Expr>,
    sort: Vec<ColumnSort<CId>>,
    partition: Vec<CId>,
    ctx: &mut Context,
) -> Vec<SqlTransform> {
    // ROW_NUMBER() is emitted verbatim as an s-string.
    let expr = Expr {
        kind: ExprKind::SString(vec![InterpolateItem::String("ROW_NUMBER()".to_string())]),
        span: None,
    };

    let is_unsorted = sort.is_empty();
    let window = Window {
        // No sort: unbounded ROWS frame. With a sort: RANGE frame ending at
        // offset 0 (presumably "up to the current row" — the start bound is
        // left open).
        frame: if is_unsorted {
            WindowFrame {
                kind: WindowKind::Rows,
                range: Range::unbounded(),
            }
        } else {
            WindowFrame {
                kind: WindowKind::Range,
                range: Range {
                    start: None,
                    end: Some(int_expr(0)),
                },
            }
        },
        partition,
        sort,
    };

    let compute = Compute {
        id: ctx.anchor.cid.gen(),
        expr,
        window: Some(window),
        is_aggregation: false,
    };
    // Register the compute so later stages can resolve the new column id.
    ctx.anchor.register_compute(compute.clone());

    let col_ref = Expr {
        kind: ExprKind::ColumnRef(compute.id),
        span: None,
    };

    // The only caller (`distinct`) has already validated the range with
    // `as_int`, so this unwrap cannot fail.
    let range_int = range.try_map(as_int).unwrap();

    let compute = SqlTransform::Super(Transform::Compute(compute));
    // row_number == s when start == end; otherwise start <= row_number <= end,
    // dropping missing bounds (both missing folds to TRUE).
    let filter = SqlTransform::Super(Transform::Filter(match (range_int.start, range_int.end) {
        (Some(s), Some(e)) if s == e => new_binop(col_ref, "std.eq", int_expr(s)),
        (start, end) => {
            let start = start.map(|start| new_binop(col_ref.clone(), "std.gte", int_expr(start)));
            let end = end.map(|end| new_binop(col_ref, "std.lte", int_expr(end)));
            maybe_binop(start, "std.and", end).unwrap_or(Expr {
                kind: ExprKind::Literal(Literal::Boolean(true)),
                span: None,
            })
        }
    }));
    vec![compute, filter]
}
/// Extracts an integer literal from an expression; `Err(())` otherwise.
fn as_int(expr: Expr) -> Result<i64, ()> {
    expr.kind
        .as_literal()
        .and_then(|lit| lit.as_integer())
        .copied()
        .ok_or(())
}
/// Builds a span-less integer literal expression.
fn int_expr(i: i64) -> Expr {
    let kind = ExprKind::Literal(Literal::Integer(i));
    Expr { kind, span: None }
}
/// Rewrites `append` transforms into `SqlTransform::Union`, absorbing an
/// immediately-following `Distinct` (UNION vs UNION ALL).
pub(in crate::sql) fn union(
    pipeline: Vec<SqlTransform>,
    ctx: &mut Context,
) -> Result<Vec<SqlTransform>> {
    use SqlTransform::*;
    use Transform::*;

    let mut res = Vec::with_capacity(pipeline.len());
    let mut transforms = pipeline.into_iter().peekable();
    while let Some(transform) = transforms.next() {
        // Only `append` is rewritten; everything else passes through.
        let bottom = match transform {
            Super(Append(bottom)) => bottom,
            other => {
                res.push(other);
                continue;
            }
        };

        let bottom = ctx.anchor.create_relation_instance(bottom, HashMap::new());

        // A Distinct right after the append means UNION (distinct); absorb it.
        let distinct = matches!(transforms.peek(), Some(Distinct));
        if distinct {
            transforms.next();
        }

        res.push(SqlTransform::Union { bottom, distinct });
    }
    Ok(res)
}
/// Detects the anti-join pattern (LEFT JOIN + "bottom is null" filter) and
/// rewrites it into `SqlTransform::Except` where the dialect allows it.
pub(in crate::sql) fn except(
    pipeline: Vec<SqlTransform>,
    ctx: &mut Context,
) -> Result<Vec<SqlTransform>> {
    use SqlTransform::*;

    // Columns the whole pipeline selects; EXCEPT only applies when no column
    // of the joined (bottom) relation leaks into the output.
    let output = ctx.anchor.determine_select_columns(&pipeline);
    let output: HashSet<CId, RandomState> = HashSet::from_iter(output);

    let mut res = Vec::with_capacity(pipeline.len());
    for t in pipeline {
        res.push(t);
        // The pattern needs at least [.., Join, Filter] in the result so far.
        if res.len() < 2 {
            continue;
        }

        // Second-to-last transform must be a LEFT JOIN ...
        let SqlTransform::Join {
            side: JoinSide::Left,
            filter: join_cond,
            with,
        } = &res[res.len() - 2]
        else {
            continue;
        };
        // ... immediately followed by a Filter.
        let Super(Transform::Filter(filter)) = &res[res.len() - 1] else {
            continue;
        };

        let with = ctx.anchor.relation_instances.get(with).unwrap();

        // Columns produced before the join (top) and by the joined relation
        // (bottom).
        let top = ctx.anchor.determine_select_columns(&res[0..res.len() - 2]);
        let bottom = with.table_ref.columns.iter().map(|(_, c)| *c).collect_vec();

        // The join condition must be a conjunction of equalities covering all
        // top columns on the left and all bottom columns on the right.
        let (join_left, join_right) = collect_equals(join_cond)?;
        if !all_in(&top, join_left) || !all_in(&bottom, join_right) {
            continue;
        }

        // The filter must assert `bottom_col == null`, i.e. keep only rows
        // with no match — the anti-join half of the pattern.
        let (filter_left, filter_right) = collect_equals(filter)?;
        if !(all_in(&bottom, filter_left) && all_null(filter_right)) {
            continue;
        }

        // Bottom columns must not appear in the pipeline's output.
        if bottom.iter().any(|c| output.contains(c)) {
            continue;
        }

        // A Distinct just before the join selects EXCEPT (distinct) rather
        // than EXCEPT ALL.
        let mut distinct = false;
        if res.len() >= 3 {
            if let Distinct = &res[res.len() - 3] {
                distinct = true;
            }
        }

        if !distinct && !ctx.dialect.except_all() {
            if ctx.anchor.contains_wildcard(&top) || ctx.anchor.contains_wildcard(&bottom) {
                // With a wildcard we can neither emit EXCEPT ALL nor keep a
                // column-explicit anti-join, so this is an error.
                return Err(Error::new_simple(format!("The dialect {:?} does not support EXCEPT ALL", ctx.dialect))
                    .push_hint("providing more column information will allow the query to be translated to an anti-join."));
            } else {
                // Columns are fully known: leave the anti-join form in place.
                continue;
            }
        }

        // Drop the Filter and the Join; the join's relation becomes the
        // EXCEPT bottom.
        res.pop(); let join = res.pop(); let (_, with, _) = join.unwrap().into_join().unwrap();
        if distinct {
            // Absorb the preceding Distinct into the EXCEPT itself.
            if let Some(Distinct) = &res.last() {
                res.pop();
            }
        }
        res.push(SqlTransform::Except {
            bottom: with,
            distinct,
        });
    }
    Ok(res)
}
/// Detects an INNER JOIN equating all top columns with all bottom columns and
/// rewrites it into `SqlTransform::Intersect` where the dialect allows it.
pub(in crate::sql) fn intersect(
    pipeline: Vec<SqlTransform>,
    ctx: &mut Context,
) -> Result<Vec<SqlTransform>> {
    use SqlTransform::*;

    // Columns the whole pipeline selects; INTERSECT only applies when the
    // output contains top columns but none of the bottom relation's columns.
    let output = ctx.anchor.determine_select_columns(&pipeline);
    let output: HashSet<CId, RandomState> = HashSet::from_iter(output);

    let mut res = Vec::with_capacity(pipeline.len());
    let mut pipeline = pipeline.into_iter().peekable();
    while let Some(t) = pipeline.next() {
        res.push(t);
        // NOTE(review): this guard is dead — `res` cannot be empty right
        // after the push above.
        if res.is_empty() {
            continue;
        }

        // The transform just pushed must be an INNER JOIN.
        let Join {
            side: JoinSide::Inner,
            filter: join_cond,
            with,
        } = &res[res.len() - 1]
        else {
            continue;
        };

        let with = ctx.anchor.relation_instances.get_mut(with).unwrap();

        // Columns of the joined relation (bottom) and of everything before
        // the join (top).
        let bottom = with.table_ref.columns.iter().map(|(_, c)| *c).collect_vec();
        let top = ctx.anchor.determine_select_columns(&res[0..res.len() - 1]);

        // The join condition must equate all top columns with all bottom
        // columns.
        let (left, right) = collect_equals(join_cond)?;
        if !(all_in(&top, left) && all_in(&bottom, right)) {
            continue;
        }

        // Bottom columns must not appear in the output ...
        if bottom.iter().any(|c| output.contains(c)) {
            continue;
        }
        // ... but at least one top column must.
        if top.iter().all(|c| !output.contains(c)) {
            continue;
        }

        // A Distinct just before the join, or right after it, selects
        // INTERSECT (distinct) rather than INTERSECT ALL.
        let mut distinct = false;
        if res.len() > 1 {
            if let Distinct = &res[res.len() - 2] {
                distinct = true;
            }
        }
        if let Some(SqlTransform::Distinct) = pipeline.peek() {
            distinct = true;
        }

        if !distinct && !ctx.dialect.intersect_all() {
            if ctx.anchor.contains_wildcard(&top) || ctx.anchor.contains_wildcard(&bottom) {
                // With a wildcard we can neither emit INTERSECT ALL nor keep a
                // column-explicit join, so this is an error.
                return Err(Error::new_simple(format!("The dialect {:?} does not support INTERSECT ALL", ctx.dialect))
                    .push_hint("providing more column information will allow the query to be translated to an anti-join."));
            } else {
                // Columns are fully known: leave the plain join in place.
                continue;
            }
        }

        // Replace the join with an Intersect, absorbing adjacent Distincts on
        // both sides.
        let join = res.pop(); let (_, with, _) = join.unwrap().into_join().unwrap();
        if distinct {
            if let Some(Distinct) = &res.last() {
                res.pop();
            }
            if let Some(SqlTransform::Distinct) = pipeline.peek() {
                pipeline.next();
            }
        }
        res.push(SqlTransform::Intersect {
            bottom: with,
            distinct,
        });
    }
    Ok(res)
}
/// Returns true when every cid in `cids` is referenced by a column-ref
/// expression in `exprs`.
fn all_in(cids: &[CId], exprs: Vec<&Expr>) -> bool {
    let referenced = col_refs(exprs);
    for cid in cids {
        if !referenced.contains(cid) {
            return false;
        }
    }
    true
}
/// Returns true when every expression is a `null` literal.
fn all_null(exprs: Vec<&Expr>) -> bool {
    for expr in exprs {
        if !matches!(expr.kind, ExprKind::Literal(Literal::Null)) {
            return false;
        }
    }
    true
}
/// Flattens a conjunction of `std.eq` comparisons into parallel vectors of
/// left-hand and right-hand operands. Any non-matching sub-expression is
/// silently skipped, yielding empty vectors for it.
fn collect_equals(expr: &Expr) -> Result<(Vec<&Expr>, Vec<&Expr>)> {
    let mut lefts = Vec::new();
    let mut rights = Vec::new();

    if let ExprKind::Operator { name, args } = &expr.kind {
        if args.len() == 2 {
            match name.as_str() {
                "std.eq" => {
                    lefts.push(&args[0]);
                    rights.push(&args[1]);
                }
                "std.and" => {
                    // Recurse into both conjuncts and splice their operands in.
                    for arg in args {
                        let (l, r) = collect_equals(arg)?;
                        lefts.extend(l);
                        rights.extend(r);
                    }
                }
                _ => {}
            }
        }
    }

    Ok((lefts, rights))
}
/// Extracts the column ids of all column-ref expressions, skipping the rest.
fn col_refs(exprs: Vec<&Expr>) -> Vec<CId> {
    exprs
        .into_iter()
        .filter_map(|expr| expr.kind.as_column_ref().copied())
        .collect()
}
/// Bubbles `Compute` transforms towards the front of the pipeline, swapping
/// them past `Sort`s and — for computes inferred as `Complexity::Plain` —
/// past `Take`s, one slot at a time until something blocks them.
pub(in crate::sql) fn reorder(mut pipeline: Vec<SqlTransform>) -> Vec<SqlTransform> {
    use SqlTransform::Super;
    use Transform::*;

    for i in 1..pipeline.len() {
        // Only Compute transforms are moved.
        if !matches!(&pipeline[i], Super(Compute(_))) {
            continue;
        }

        // NOTE(review): `0..(i - 1)` stops the bubbling at index 1, so the
        // transform at index 0 is never examined. Harmless as long as the
        // pipeline starts with a `From` (which would stop the swap anyway) —
        // confirm that invariant holds for all callers.
        for j in 0..(i - 1) {
            // After j swaps the compute sits at i - j; compare with its
            // current predecessor.
            let compute_i = i - j;
            let prev_i = compute_i - 1;

            let compute = pipeline[compute_i]
                .as_super()
                .unwrap()
                .as_compute()
                .unwrap();
            let prev = &pipeline[prev_i];

            let should_swap = match prev {
                // Never move a compute above its sources or another compute.
                SqlTransform::From(_) | SqlTransform::Join { .. } | Super(Compute(_)) => false,
                Super(Sort(_)) => true,
                Super(Take(_)) if infer_complexity(compute) == Complexity::Plain => true,
                _ => false,
            };

            if should_swap {
                pipeline.swap(compute_i, prev_i);
            } else {
                // First blocker ends the bubbling for this compute.
                break;
            }
        }
    }

    pipeline
}
/// Normalizes expressions in the pipeline (see [Normalizer::fold_expr]).
pub(in crate::sql) fn normalize(pipeline: Vec<Transform>) -> Result<Vec<Transform>> {
    let mut normalizer = Normalizer {};
    normalizer.fold_transforms(pipeline)
}
/// Expression normalizer: rewrites `null == x` into `x == null` so later
/// passes (e.g. the anti-join detection in `except`) only need to check the
/// right-hand side for null literals.
struct Normalizer {}

impl RqFold for Normalizer {
    fn fold_expr(&mut self, expr: Expr) -> Result<Expr> {
        // Fold children first so nested comparisons are already normalized.
        let mut expr = Expr {
            kind: rq::fold_expr_kind(self, expr.kind)?,
            ..expr
        };

        // Fix: the previous implementation rebuilt the operator node with
        // cloned args on *every* `std.eq`, even when no swap was needed.
        // Swapping in place avoids all cloning and preserves the span.
        if let ExprKind::Operator { name, args } = &mut expr.kind {
            if name == "std.eq"
                && args.len() == 2
                && matches!(args[0].kind, ExprKind::Literal(Literal::Null))
            {
                // Move the null literal to the right-hand side.
                args.swap(0, 1);
            }
        }

        Ok(expr)
    }
}