use anyhow::Result;
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use crate::ast::rq::{
self, fold_transform, CId, Compute, Expr, Relation, RelationColumn, RelationKind, RqFold,
TableDecl, TableRef, Transform,
};
use super::{
context::{AnchorContext, ColumnDecl},
preprocess::{SqlFold, SqlTransform},
};
/// The part of a pipeline that was not absorbed by [`split_off_back`]: the leftover
/// transforms and the column ids that the split-off part requires but does not produce itself.
type RemainingPipeline = (Vec<SqlTransform>, Vec<CId>);
/// Splits off the trailing transforms of `pipeline` that can be combined together with a
/// final `Select` of `output`.
///
/// The pipeline is walked from the back. Each transform is absorbed until one is found that
/// either must be split off (see `is_split_required`) or computes a column whose complexity
/// exceeds what later transforms allow (see `can_materialize`); that transform and everything
/// before it stays in `pipeline`.
///
/// Returns `(remaining, split_off)`:
/// - `remaining` is `None` when every transform was absorbed; otherwise it holds the leftover
///   transforms and the required column ids that the absorbed transforms do not make available.
/// - `split_off` is the absorbed transforms in their original order, with `Select`s dropped and
///   a newly built `Select` placed first.
///
/// An empty `pipeline` yields `(None, vec![])` (no `Select` is produced).
pub(super) fn split_off_back(
    ctx: &mut AnchorContext,
    output: Vec<CId>,
    mut pipeline: Vec<SqlTransform>,
) -> (Option<RemainingPipeline>, Vec<SqlTransform>) {
    if pipeline.is_empty() {
        return (None, Vec::new());
    }
    log::debug!("traversing pipeline to obtain columns: {output:?}");
    // Names (`as_str`) of the transforms already absorbed, i.e. those that come later.
    let mut following_transforms: HashSet<String> = HashSet::new();
    // The requested output columns are required with no complexity cap and must be selected.
    let mut inputs_required = into_requirements(output.clone(), Complexity::highest(), true);
    // Columns produced by the absorbed transforms (computes, From/Join table columns).
    let mut inputs_avail = HashSet::new();
    // Absorbed transforms, collected back-to-front.
    let mut curr_pipeline_rev = Vec::new();
    'pipeline: while let Some(transform) = pipeline.pop() {
        let split = is_split_required(&transform, &mut following_transforms);
        if split {
            log::debug!("split required after {}", transform.as_str());
            log::debug!(".. following={:?}", following_transforms);
            // Put it back: it belongs to the remaining pipeline.
            pipeline.push(transform);
            break;
        }
        // NOTE(review): the requirements are recorded before the materialization checks
        // below, so a transform that is pushed back into `pipeline` still contributes to
        // `selected` / `missing`. Confirm this is intended.
        let required = get_requirements(&transform, &following_transforms);
        log::debug!("transform {} requires {:?}", transform.as_str(), required);
        inputs_required.extend(required);
        match &transform {
            SqlTransform::Super(Transform::Compute(compute)) => {
                // The compute may only be absorbed if it satisfies every cap placed on its id.
                if can_materialize(compute, &inputs_required) {
                    log::debug!("materializing {:?}", compute.id);
                    inputs_avail.insert(compute.id);
                } else {
                    pipeline.push(transform);
                    break;
                }
            }
            SqlTransform::Super(Transform::Aggregate { compute, .. }) => {
                // Only the aggregated columns declared as computes are checked.
                for cid in compute {
                    let decl = &ctx.column_decls[cid];
                    if let ColumnDecl::Compute(compute) = decl {
                        if !can_materialize(compute, &inputs_required) {
                            pipeline.push(transform);
                            break 'pipeline;
                        }
                    }
                }
            }
            SqlTransform::Super(Transform::From(with) | Transform::Join { with, .. }) => {
                // Every column of the referenced table instance becomes available.
                for (_, cid) in &with.columns {
                    inputs_avail.insert(*cid);
                }
            }
            _ => (),
        }
        // Existing Selects are dropped; a single new Select is built below.
        if !matches!(transform, SqlTransform::Super(Transform::Select(_))) {
            curr_pipeline_rev.push(transform);
        }
    }
    // Columns that some requirement marked as needing to be selected (output columns and
    // Sort columns, per `get_requirements`).
    let selected = inputs_required
        .iter()
        .filter(|r| r.selected)
        .map(|r| r.col)
        .collect_vec();
    log::debug!("finished table:");
    log::debug!(".. avail={inputs_avail:?}");
    let required = inputs_required
        .into_iter()
        .map(|r| r.col)
        .unique()
        .collect_vec();
    log::debug!(".. required={required:?}");
    // Required columns not produced by the absorbed transforms; they must come from the
    // remaining pipeline.
    let missing = required
        .into_iter()
        .filter(|i| !inputs_avail.contains(i))
        .collect_vec();
    log::debug!(".. missing={missing:?}");
    {
        // Build the Select: `output` followed by any selected column not already in it.
        let mut output = output;
        for c in selected {
            if !output.contains(&c) {
                output.push(c);
            }
        }
        // With nothing to select, fall back to one wildcard per input table.
        // NOTE(review): inputs are collected from `pipeline` (the remaining, not-absorbed
        // part), not from `curr_pipeline_rev`; when `pipeline` is empty this presumably yields
        // no tables and the Select stays empty. Confirm this is intended.
        let output = if output.is_empty() {
            let (input_tables, _) = ctx.collect_pipeline_inputs(&pipeline);
            input_tables
                .iter()
                .map(|tiid| ctx.register_wildcard(*tiid))
                .collect()
        } else {
            output
        };
        // Pushed last onto the reversed list, so after the reversal below it is the first element.
        curr_pipeline_rev.push(SqlTransform::Super(Transform::Select(output)));
    }
    let remaining_pipeline = if pipeline.is_empty() {
        None
    } else {
        Some((pipeline, missing))
    };
    // Restore the original (front-to-back) order.
    curr_pipeline_rev.reverse();
    (remaining_pipeline, curr_pipeline_rev)
}
/// Decides whether `compute` may be evaluated among the transforms being absorbed.
///
/// Every requirement that targets `compute.id` imposes a cap on the complexity its
/// expression may have; the tightest cap wins, and with no such requirement the cap is
/// [`Complexity::highest`]. Returns `true` when the inferred complexity of `compute` fits
/// under that cap.
fn can_materialize(compute: &Compute, inputs_required: &[Requirement]) -> bool {
    let complexity = infer_complexity(compute);

    // `highest()` is the largest level, so taking the minimum over the caps and falling
    // back to `highest()` matches folding `min` starting from `highest()`.
    let required_max = inputs_required
        .iter()
        .filter(|req| req.col == compute.id)
        .map(|req| req.max_complexity)
        .min()
        .unwrap_or_else(Complexity::highest);

    if complexity > required_max {
        log::debug!(
            "{:?} has complexity {complexity:?}, but is required to have max={required_max:?}",
            compute.id
        );
        return false;
    }
    true
}
/// Builds the second half of a pipeline that was split in two.
///
/// A new table id is generated and a table declaration named `first_table_name` is
/// registered in `ctx` with a placeholder relation (an empty s-string and no columns).
/// Every column in `cols_at_split` gets a fresh column id on that table: wildcard columns
/// stay wildcards, and all others become single columns carrying the old column's name (if any).
/// `second_pipeline` is then prefixed with a `From` of the new table, and every reference
/// to an old column id is rewritten to its new id.
///
/// # Panics
/// If a column in `cols_at_split` has no entry in `ctx.column_decls`, or if folding the
/// rewritten pipeline returns an error.
pub(super) fn anchor_split(
    ctx: &mut AnchorContext,
    first_table_name: &str,
    cols_at_split: &[CId],
    second_pipeline: Vec<SqlTransform>,
) -> Vec<SqlTransform> {
    let new_tid = ctx.tid.gen();
    log::debug!("split pipeline, first pipeline output: {cols_at_split:?}");

    // Old id -> new id, consumed by the redirector at the end.
    let mut cid_redirects = HashMap::<CId, CId>::new();
    let mut new_columns = Vec::new();
    for &old_cid in cols_at_split {
        let new_cid = ctx.cid.gen();

        let name = ctx.ensure_column_name(old_cid).cloned();
        if let Some(known) = &name {
            ctx.column_names.insert(new_cid, known.clone());
        }

        let is_wildcard = matches!(
            ctx.column_decls.get(&old_cid).unwrap(),
            ColumnDecl::RelationColumn(_, _, RelationColumn::Wildcard)
        );
        let column = if is_wildcard {
            RelationColumn::Wildcard
        } else {
            RelationColumn::Single(name)
        };

        cid_redirects.insert(old_cid, new_cid);
        new_columns.push((column, new_cid));
    }

    let table_name = Some(first_table_name.to_string());

    // The relation stored here is only a placeholder (empty s-string, no columns).
    ctx.table_decls.insert(
        new_tid,
        TableDecl {
            id: new_tid,
            name: table_name.clone(),
            relation: Relation {
                kind: RelationKind::SString(vec![]),
                columns: vec![],
            },
        },
    );

    let table_ref = TableRef {
        source: new_tid,
        name: table_name,
        columns: new_columns,
    };
    ctx.create_table_instance(table_ref.clone());

    // Read from the new table first, then run the second half unchanged.
    let second: Vec<SqlTransform> =
        std::iter::once(SqlTransform::Super(Transform::From(table_ref)))
            .chain(second_pipeline)
            .collect();

    CidRedirector { ctx, cid_redirects }
        .fold_sql_transforms(second)
        .unwrap()
}
/// Reports whether the pipeline has to be split right after `transform`.
///
/// `following` holds the names (`as_str`) of the transforms that come later in the part
/// currently being built. Each transform kind has a list of "blocking" kinds; if any of
/// them is already in `following`, `transform` cannot join that part and `true` is
/// returned. Otherwise the name of `transform` is added to `following` and `false` is returned.
///
/// A `Compute` marked as an aggregation never forces a split and is not recorded.
fn is_split_required(transform: &SqlTransform, following: &mut HashSet<String>) -> bool {
    use SqlTransform::*;
    use Transform::*;

    if matches!(transform, Super(Compute(decl)) if decl.is_aggregation) {
        return false;
    }

    let blockers: &[&str] = match transform {
        Super(From(_)) | Super(Join { .. }) => &["From"],
        Super(Aggregate { .. }) => &["From", "Join", "Aggregate"],
        Super(Filter(_)) => &["From", "Join"],
        Super(Compute(_)) => &["From", "Join", "Filter"],
        Super(Sort(_)) => &["From", "Join", "Compute", "Aggregate"],
        Super(Take(_)) => &["From", "Join", "Compute", "Filter", "Aggregate", "Sort"],
        Distinct => &[
            "From",
            "Join",
            "Compute",
            "Filter",
            "Aggregate",
            "Sort",
            "Take",
        ],
        Union { .. } | Except { .. } | Intersect { .. } => &[
            "From",
            "Join",
            "Compute",
            "Filter",
            "Aggregate",
            "Sort",
            "Take",
            "Distinct",
        ],
        _ => &[],
    };

    let split = blockers.iter().any(|kind| following.contains(*kind));
    if !split {
        following.insert(transform.as_str().to_string());
    }
    split
}
/// A column that a transform reads, together with the constraints on how it may be provided.
pub struct Requirement {
    /// The required column.
    pub col: CId,
    /// Upper bound on the complexity the expression behind `col` may have at the point of use.
    pub max_complexity: Complexity,
    /// Whether `col` must also be added to the `Select` being built.
    pub selected: bool,
}

/// Wraps each id in `cids` in a [`Requirement`], all sharing `max_complexity` and `selected`.
fn into_requirements(
    cids: Vec<CId>,
    max_complexity: Complexity,
    selected: bool,
) -> Vec<Requirement> {
    let mut requirements = Vec::with_capacity(cids.len());
    for col in cids {
        requirements.push(Requirement {
            col,
            max_complexity,
            selected,
        });
    }
    requirements
}

/// Compact `<col>-as-<max_complexity>` rendering used in debug logs; `selected` is not shown.
impl std::fmt::Debug for Requirement {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::fmt::Debug;

        Debug::fmt(&self.col, f)?;
        f.write_str("-as-")?;
        Debug::fmt(&self.max_complexity, f)
    }
}
/// Lists the columns that `transform` reads, each capped at the highest complexity its
/// expression may have when used by this transform.
///
/// `following` holds the names of the transforms that come later in the same part; a
/// `Filter` followed by an `Aggregate` only accepts plain expressions. Only `Sort`
/// marks its columns as `selected`. Transforms that read no columns through this
/// function (`Select`, `From`, `Distinct`, set operations) yield an empty list.
///
/// # Panics
/// On `Append`, which is not expected at this stage.
pub(super) fn get_requirements(
    transform: &SqlTransform,
    following: &HashSet<String>,
) -> Vec<Requirement> {
    use SqlTransform::*;
    use Transform::*;

    let (cids, max_complexity, selected): (Vec<CId>, Complexity, bool) = match transform {
        // Two different caps: grouping keys must be plain, aggregated values may aggregate.
        Super(Aggregate { partition, compute }) => {
            let mut requirements =
                into_requirements(partition.clone(), Complexity::Plain, false);
            requirements.extend(into_requirements(
                compute.clone(),
                Complexity::Aggregation,
                false,
            ));
            return requirements;
        }
        Super(Compute(decl)) => {
            let cids = CidCollector::collect(decl.expr.clone());
            let cap = if infer_complexity(decl) == Complexity::Plain {
                Complexity::Aggregation
            } else {
                Complexity::Plain
            };
            (cids, cap, false)
        }
        Super(Filter(expr)) => {
            let cids = CidCollector::collect(expr.clone());
            let cap = if following.contains("Aggregate") {
                Complexity::Plain
            } else {
                Complexity::Aggregation
            };
            (cids, cap, false)
        }
        Super(Join { filter, .. }) => (
            CidCollector::collect(filter.clone()),
            Complexity::Plain,
            false,
        ),
        Super(Sort(sorts)) => (
            sorts.iter().map(|s| s.column).collect(),
            Complexity::Aggregation,
            true,
        ),
        Super(Take(rq::Take { range, .. })) => {
            let cids = range
                .start
                .iter()
                .chain(range.end.iter())
                .flat_map(|bound| CidCollector::collect(bound.clone()))
                .collect();
            (cids, Complexity::Plain, false)
        }
        Super(Append(_)) => unreachable!(),
        Super(Select(_) | From(_))
        | Distinct
        | Union { .. }
        | Except { .. }
        | Intersect { .. } => return Vec::new(),
    };

    into_requirements(cids, max_complexity, selected)
}
/// How "heavy" a column expression is, from lightest to heaviest.
///
/// The derived ordering follows declaration order, so `Plain` compares lowest and
/// `Aggregation` highest; requirements cap columns with `<=` comparisons on this order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Complexity {
    /// No switch, window or aggregation involved.
    Plain,
    /// The expression contains a `Switch`.
    NonGroup,
    /// The compute has a window.
    Windowed,
    /// The compute is an aggregation.
    Aggregation,
}

impl Complexity {
    /// The largest level; used when no cap applies.
    const fn highest() -> Self {
        Complexity::Aggregation
    }
}
pub fn infer_complexity(compute: &Compute) -> Complexity {
use Complexity::*;
if compute.window.is_some() {
Windowed
} else if compute.is_aggregation {
Aggregation
} else {
infer_complexity_expr(&compute.expr)
}
}
/// Computes the complexity of an expression tree.
///
/// Leaves (column refs, literals, s-strings, f-strings) are `Plain`, a `Switch` is
/// `NonGroup`, and operators and built-in functions take the maximum over their operands
/// (`Plain` for a function with no arguments).
pub fn infer_complexity_expr(expr: &Expr) -> Complexity {
    match &expr.kind {
        rq::ExprKind::ColumnRef(_)
        | rq::ExprKind::Literal(_)
        | rq::ExprKind::SString(_)
        | rq::ExprKind::FString(_) => Complexity::Plain,
        rq::ExprKind::Switch(_) => Complexity::NonGroup,
        rq::ExprKind::Unary { expr: operand, .. } => infer_complexity_expr(operand),
        rq::ExprKind::Binary { left, right, .. } => {
            infer_complexity_expr(left).max(infer_complexity_expr(right))
        }
        // `Plain` is the lowest level, so it is a neutral starting point for `max`.
        rq::ExprKind::BuiltInFunction { args, .. } => args
            .iter()
            .map(infer_complexity_expr)
            .fold(Complexity::Plain, Complexity::max),
    }
}
/// A fold that records every column id it visits, in visiting order, duplicates included.
#[derive(Default)]
pub struct CidCollector {
    cids: Vec<CId>,
}

impl CidCollector {
    /// Returns the ids of all columns referenced by `expr`.
    ///
    /// Panics if the fold returns an error.
    pub fn collect(expr: Expr) -> Vec<CId> {
        let mut visitor = CidCollector { cids: Vec::new() };
        visitor.fold_expr(expr).unwrap();
        visitor.cids
    }

    /// Returns `t` as rebuilt by the fold, together with the ids of all columns it references.
    ///
    /// Panics if the fold returns an error.
    pub fn collect_t(t: Transform) -> (Transform, Vec<CId>) {
        let mut visitor = CidCollector { cids: Vec::new() };
        let folded = visitor.fold_transform(t).unwrap();
        (folded, visitor.cids)
    }
}

impl RqFold for CidCollector {
    /// Records `cid` and hands it back unchanged.
    fn fold_cid(&mut self, cid: CId) -> Result<CId> {
        self.cids.push(cid);
        Ok(cid)
    }
}
/// A fold that rewrites column ids through `cid_redirects` and hands every folded
/// `Compute` to `ctx.register_compute`.
struct CidRedirector<'a> {
    ctx: &'a mut AnchorContext,
    /// Old id -> new id; ids not present are left unchanged.
    cid_redirects: HashMap<CId, CId>,
}

impl RqFold for CidRedirector<'_> {
    fn fold_cid(&mut self, cid: CId) -> Result<CId> {
        let target = match self.cid_redirects.get(&cid) {
            Some(&redirected) => redirected,
            None => cid,
        };
        Ok(target)
    }

    fn fold_transform(&mut self, transform: Transform) -> Result<Transform> {
        match transform {
            Transform::Compute(compute) => {
                // Register the rewritten declaration so the context sees the new ids.
                let folded = self.fold_compute(compute)?;
                self.ctx.register_compute(folded.clone());
                Ok(Transform::Compute(folded))
            }
            other => fold_transform(self, other),
        }
    }
}

impl SqlFold for CidRedirector<'_> {}