use std::collections::HashSet;
use std::str::FromStr;
use anyhow::{anyhow, Result};
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use sqlparser::ast::{
self as sql_ast, Ident, Select, SelectItem, SetExpr, TableAlias, TableFactor, TableWithJoins,
};
use crate::ast::pl::{BinOp, Literal, RelationLiteral};
use crate::ast::rq::{CId, Expr, ExprKind, Query, Relation, RelationKind, TableDecl, Transform};
use crate::error::{Error, Reason};
use crate::utils::{BreakUp, IntoOnly, Pluck};
use super::context::AnchorContext;
use super::gen_expr::*;
use super::gen_projection::*;
use super::preprocess::{self, SqlTransform};
use super::{anchor, Context, Dialect};
pub fn translate_query(query: Query, dialect: Option<Dialect>) -> Result<sql_ast::Query> {
let dialect = if let Some(dialect) = dialect {
dialect
} else {
let target = query.def.other.get("target");
target
.and_then(|target| target.strip_prefix("sql."))
.map(|dialect| {
super::Dialect::from_str(dialect).map_err(|_| {
Error::new(Reason::NotFound {
name: format!("{dialect:?}"),
namespace: "dialect".to_string(),
})
})
})
.transpose()?
.unwrap_or_default()
};
let dialect = dialect.handler();
let (anchor, query) = AnchorContext::of(query);
let mut context = Context {
dialect,
anchor,
omit_ident_prefix: false,
pre_projection: false,
};
let tables = into_tables(query.relation, query.tables, &mut context)?;
let mut atomics = Vec::new();
for table in tables {
let name = table
.name
.unwrap_or_else(|| context.anchor.table_name.gen());
match table.relation.kind {
RelationKind::Pipeline(pipeline) => {
let pipeline = Ok(pipeline)
.map(preprocess::normalize)
.map(preprocess::push_down_selects)
.map(preprocess::prune_inputs)
.map(preprocess::wrap)
.and_then(|p| preprocess::distinct(p, &mut context))
.map(preprocess::union)
.and_then(|p| preprocess::except(p, &context))
.and_then(|p| preprocess::intersect(p, &context))
.map(preprocess::reorder)?;
context.anchor.load_names(&pipeline, table.relation.columns);
let ats = split_into_atomics(name, pipeline, &mut context.anchor);
ensure_names(&ats, &mut context.anchor);
atomics.extend(ats);
}
RelationKind::Literal(_) | RelationKind::SString(_) => atomics.push(AtomicQuery {
name,
relation: SqlRelation::Super(table.relation.kind),
}),
RelationKind::ExternRef(_) => {
}
}
}
let main_query = atomics.remove(atomics.len() - 1);
let ctes = atomics;
let ctes: Vec<_> = ctes
.into_iter()
.map(|t| table_to_sql_cte(t, &mut context))
.try_collect()?;
let mut main_query = sql_query_of_relation(main_query.relation, &mut context)?;
if !ctes.is_empty() {
main_query.with = Some(sql_ast::With {
cte_tables: ctes,
recursive: false,
});
}
Ok(main_query)
}
/// A query that maps to a single SQL SELECT statement (possibly with set
/// operations on top): either one CTE or the main query of the output.
#[derive(Debug)]
pub struct AtomicQuery {
    // name under which this query is exposed (CTE alias, or the table name)
    name: String,
    relation: SqlRelation,
}
/// The source of rows for an [AtomicQuery].
#[derive(Debug, EnumAsInner)]
enum SqlRelation {
    // a relation kind carried over from RQ (literal, s-string, extern ref)
    Super(RelationKind),
    // a preprocessed pipeline of SQL transforms
    Pipeline(Vec<SqlTransform>),
}
/// Appends the main pipeline, wrapped as an anonymous table declaration, to
/// the list of tables, so all relations can be compiled uniformly.
fn into_tables(
    main_pipeline: Relation,
    tables: Vec<TableDecl>,
    context: &mut Context,
) -> Result<Vec<TableDecl>> {
    let mut all_tables = tables;
    all_tables.push(TableDecl {
        id: context.anchor.tid.gen(),
        name: None,
        relation: main_pipeline,
    });
    Ok(all_tables)
}
fn table_to_sql_cte(table: AtomicQuery, context: &mut Context) -> Result<sql_ast::Cte> {
let alias = sql_ast::TableAlias {
name: translate_ident_part(table.name, context),
columns: vec![],
};
Ok(sql_ast::Cte {
alias,
query: Box::new(sql_query_of_relation(table.relation, context)?),
from: None,
})
}
/// Dispatches a relation to the matching SQL query builder.
fn sql_query_of_relation(relation: SqlRelation, context: &mut Context) -> Result<sql_ast::Query> {
    use RelationKind::*;

    match relation {
        SqlRelation::Pipeline(pipeline) => sql_query_of_pipeline(pipeline, context),
        // `Ok(f(..)?)` was a needless re-wrap; return the Result directly
        SqlRelation::Super(Literal(lit)) => sql_of_sample_data(lit),
        SqlRelation::Super(SString(items)) => translate_query_sstring(items, context),
        // these kinds are handled before reaching this point
        SqlRelation::Super(ExternRef(_)) | SqlRelation::Super(Pipeline(_)) => unreachable!(),
    }
}
/// Splits the pipeline at the first set operation: the leading transforms
/// become a plain SELECT query, onto which the trailing set operations
/// (UNION / EXCEPT / INTERSECT) are then applied.
fn sql_query_of_pipeline(
    pipeline: Vec<SqlTransform>,
    context: &mut Context,
) -> Result<sql_ast::Query> {
    use SqlTransform::*;

    let (select_part, set_op_part) =
        pipeline.break_up(|t| matches!(t, Union { .. } | Except { .. } | Intersect { .. }));

    let select_query = sql_select_query_of_pipeline(select_part, context)?;

    sql_set_ops_of_pipeline(select_query, set_op_part, context)
}
/// Compiles the non-set-op part of a pipeline into one `SELECT` query.
///
/// Consumes the pipeline clause-by-clause (projection, FROM, JOIN, filters,
/// GROUP BY, ORDER BY, LIMIT/OFFSET, DISTINCT) via `pluck`, then assembles
/// the pieces into a `sql_ast::Query`. The order of the `pluck` calls and
/// the `pre_projection` flag toggling are significant.
fn sql_select_query_of_pipeline(
    mut pipeline: Vec<SqlTransform>,
    context: &mut Context,
) -> Result<sql_ast::Query> {
    let table_count = count_tables(&pipeline);
    log::debug!("atomic query contains {table_count} tables");
    // with a single table, column references need no table-name prefix
    context.omit_ident_prefix = table_count == 1;
    // expressions below are translated before the projection is applied
    context.pre_projection = true;

    // the preprocessed pipeline is expected to contain exactly one Select
    let projection = pipeline
        .pluck(|t| t.into_super_and(|t| t.into_select()))
        .into_only()
        .unwrap();
    let projection = translate_wildcards(&context.anchor, projection);
    let projection = translate_select_items(projection.0, projection.1, context)?;

    let mut from = pipeline
        .pluck(|t| t.into_super_and(|t| t.into_from()))
        .into_iter()
        .map(|source| TableWithJoins {
            relation: table_factor_of_tid(source, context),
            joins: vec![],
        })
        .collect::<Vec<_>>();

    let joins = pipeline
        .pluck(|t| t.into_super_and(|t| t.into_join()))
        .into_iter()
        .map(|j| translate_join(j, context))
        .collect::<Result<Vec<_>>>()?;
    if !joins.is_empty() {
        // joins attach to the last FROM entry; a join with no FROM is invalid
        if let Some(from) = from.last_mut() {
            from.joins = joins;
        } else {
            return Err(anyhow!("Cannot use `join` without `from`"));
        }
    }

    let sorts = pipeline.pluck(|t| t.into_super_and(|t| t.into_sort()));
    let takes = pipeline.pluck(|t| t.into_super_and(|t| t.into_take()));
    let distinct = pipeline.iter().any(|t| matches!(t, SqlTransform::Distinct));

    // split the pipeline at the aggregation: filters before it become WHERE,
    // filters after it become HAVING
    let (mut before_agg, mut after_agg) = pipeline.break_up(|t| {
        matches!(
            t,
            SqlTransform::Super(Transform::Aggregate { .. } | Transform::Append(_))
        )
    });
    let where_ = filter_of_conditions(
        before_agg.pluck(|t| t.into_super_and(|t| t.into_filter())),
        context,
    )?;
    let having = filter_of_conditions(
        after_agg.pluck(|t| t.into_super_and(|t| t.into_filter())),
        context,
    )?;

    // GROUP BY comes from the partition of the (at most one) aggregate
    let aggregate = after_agg
        .pluck(|t| t.into_super_and(|t| t.into_aggregate()))
        .into_iter()
        .next();
    let group_by: Vec<CId> = aggregate.map(|(part, _)| part).unwrap_or_default();
    let group_by = try_into_exprs(group_by, context, None)?;

    // everything from here on is translated after the projection
    context.pre_projection = false;

    // fold all `take` ranges into a single 1-based range, then convert it
    // into SQL's 0-based OFFSET + LIMIT
    let ranges = takes.into_iter().map(|x| x.range).collect();
    let take = range_of_ranges(ranges)?;
    let offset = take.start.map(|s| s - 1).unwrap_or(0);
    let limit = take.end.map(|e| e - offset);

    // OFFSET 0 is omitted entirely
    let offset = if offset == 0 {
        None
    } else {
        Some(sqlparser::ast::Offset {
            value: translate_expr_kind(ExprKind::Literal(Literal::Integer(offset)), context)?,
            rows: sqlparser::ast::OffsetRows::None,
        })
    };

    // only the last sort is effective; earlier ones are overridden
    let order_by = sorts
        .last()
        .map(|sorts| {
            sorts
                .iter()
                .map(|s| translate_column_sort(s, context))
                .try_collect()
        })
        .transpose()?
        .unwrap_or_default();

    // dialects like MSSQL express LIMIT as `SELECT TOP n`
    let (top, limit) = if context.dialect.use_top() {
        (limit.map(|l| top_of_i64(l, context)), None)
    } else {
        (None, limit.map(expr_of_i64))
    };

    Ok(sql_ast::Query {
        order_by,
        limit,
        offset,
        ..default_query(SetExpr::Select(Box::new(Select {
            distinct,
            top,
            projection,
            from,
            selection: where_,
            group_by,
            having,
            ..default_select()
        })))
    })
}
fn sql_set_ops_of_pipeline(
mut top: sql_ast::Query,
mut pipeline: Vec<SqlTransform>,
context: &mut Context,
) -> Result<sql_ast::Query, anyhow::Error> {
pipeline.reverse();
while let Some(transform) = pipeline.pop() {
use SqlTransform::*;
let op = match &transform {
Union { .. } => sql_ast::SetOperator::Union,
Except { .. } => sql_ast::SetOperator::Except,
Intersect { .. } => sql_ast::SetOperator::Intersect,
_ => unreachable!(),
};
let (distinct, bottom) = match transform {
Union { distinct, bottom }
| Except { distinct, bottom }
| Intersect { distinct, bottom } => (distinct, bottom),
_ => unreachable!(),
};
let top_is_simple = top.with.is_none()
&& top.order_by.is_empty()
&& top.limit.is_none()
&& top.offset.is_none()
&& top.fetch.is_none()
&& top.locks.is_empty();
let left = if top_is_simple {
top.body
} else {
Box::new(SetExpr::Select(Box::new(Select {
projection: vec![SelectItem::Wildcard(
sql_ast::WildcardAdditionalOptions::default(),
)],
from: vec![TableWithJoins {
relation: TableFactor::Derived {
lateral: false,
subquery: Box::new(top),
alias: Some(TableAlias {
name: Ident::new(context.anchor.table_name.gen()),
columns: Vec::new(),
}),
},
joins: vec![],
}],
..default_select()
})))
};
top = default_query(SetExpr::SetOperation {
left,
right: Box::new(SetExpr::Select(Box::new(Select {
projection: vec![SelectItem::Wildcard(
sql_ast::WildcardAdditionalOptions::default(),
)],
from: vec![TableWithJoins {
relation: table_factor_of_tid(bottom, context),
joins: vec![],
}],
..default_select()
}))),
set_quantifier: if distinct {
if context.dialect.set_ops_distinct() {
sql_ast::SetQuantifier::Distinct
} else {
sql_ast::SetQuantifier::None
}
} else {
sql_ast::SetQuantifier::All
},
op,
});
}
Ok(top)
}
fn sql_of_sample_data(data: RelationLiteral) -> Result<sql_ast::Query> {
let mut selects = Vec::with_capacity(data.rows.len());
for row in data.rows {
let body = sql_ast::SetExpr::Select(Box::new(Select {
projection: std::iter::zip(data.columns.clone(), row)
.map(|(col, value)| -> Result<_> {
Ok(SelectItem::ExprWithAlias {
expr: translate_literal(value)?,
alias: sql_ast::Ident::new(col),
})
})
.try_collect()?,
..default_select()
}));
selects.push(body)
}
let mut body = selects.remove(0);
for select in selects {
body = SetExpr::SetOperation {
op: sql_ast::SetOperator::Union,
set_quantifier: sql_ast::SetQuantifier::All,
left: Box::new(body),
right: Box::new(select),
}
}
Ok(default_query(body))
}
/// Splits a preprocessed pipeline into a chain of atomic queries, each of
/// which can be expressed as a single SELECT. Intermediate parts are linked
/// by generated table names; the final part keeps `name`.
fn split_into_atomics(
    name: String,
    mut pipeline: Vec<SqlTransform>,
    ctx: &mut AnchorContext,
) -> Vec<AtomicQuery> {
    // columns the whole pipeline must ultimately produce
    let outputs_cid = AnchorContext::determine_select_columns(&pipeline);
    let mut required_cols = outputs_cid.clone();

    // repeatedly split a compilable tail off the back of the pipeline;
    // parts are collected back-to-front
    let mut parts_rev = Vec::new();
    loop {
        let (preceding, split) = anchor::split_off_back(ctx, required_cols, pipeline);

        if let Some((preceding, cols_at_split)) = preceding {
            log::debug!(
                "pipeline split after {}",
                preceding.last().unwrap().as_str()
            );
            parts_rev.push((split, cols_at_split.clone()));

            pipeline = preceding;
            // the preceding part must provide the columns needed at the split
            required_cols = cols_at_split;
        } else {
            // nothing precedes this split: it is the first (front) part
            parts_rev.push((split, Vec::new()));
            break;
        }
    }
    parts_rev.reverse();
    let mut parts = parts_rev;

    // the last part may select more columns than the pipeline outputs
    // (e.g. helper columns); if so, append an extra part that narrows the
    // projection down to `outputs_cid`
    if let Some((pipeline, _)) = parts.last() {
        let select_cols = pipeline
            .first()
            .unwrap()
            .as_super()
            .unwrap()
            .as_select()
            .unwrap();

        if select_cols.iter().any(|c| !outputs_cid.contains(c)) {
            parts.push((
                vec![SqlTransform::Super(Transform::Select(outputs_cid))],
                select_cols.clone(),
            ));
        }
    }

    // turn the parts into queries; each intermediate part gets a generated
    // name, and each later part is anchored onto the previous one via
    // `anchor_split` (which inserts the appropriate FROM)
    let mut atomics = Vec::with_capacity(parts.len());
    let last = parts.pop().unwrap();

    let last_pipeline = if parts.is_empty() {
        // single part: no splitting needed
        last.0
    } else {
        // the first part stands alone (it already has its own FROM)
        let first = parts.remove(0);
        let first_name = ctx.table_name.gen();
        atomics.push(AtomicQuery {
            name: first_name.clone(),
            relation: SqlRelation::Pipeline(first.0),
        });

        // middle parts each read from the previous part
        let mut prev_name = first_name;
        for (pipeline, cols_before) in parts.into_iter() {
            let name = ctx.table_name.gen();
            let pipeline = anchor::anchor_split(ctx, &prev_name, &cols_before, pipeline);

            atomics.push(AtomicQuery {
                name: name.clone(),
                relation: SqlRelation::Pipeline(pipeline),
            });
            prev_name = name;
        }

        anchor::anchor_split(ctx, &prev_name, &last.1, last.0)
    };
    // the final atomic query keeps the table's original name
    atomics.push(AtomicQuery {
        name,
        relation: SqlRelation::Pipeline(last_pipeline),
    });
    atomics
}
/// Walks all atomic pipelines and makes sure columns referenced by Sort
/// transforms have names (so ORDER BY can reference them across queries).
fn ensure_names(atomics: &[AtomicQuery], ctx: &mut AnchorContext) {
    for a in atomics {
        let empty = HashSet::new();
        for t in a.relation.as_pipeline().unwrap() {
            match t {
                SqlTransform::Super(Transform::Sort(_)) => {
                    // every column a sort requires must have a usable name
                    for r in anchor::get_requirements(t, &empty) {
                        ctx.ensure_column_name(r.col);
                    }
                }
                SqlTransform::Super(Transform::Select(cids)) => {
                    for cid in cids {
                        // NOTE(review): this lookup has no visible effect other
                        // than panicking when a selected column has no decl —
                        // presumably an invariant check; confirm the intent
                        let _decl = &ctx.column_decls[cid];
                    }
                }
                _ => (),
            }
        }
    }
}
/// AND-folds filter conditions into a single SQL expression;
/// `None` when there are no conditions.
fn filter_of_conditions(exprs: Vec<Expr>, context: &mut Context) -> Result<Option<sql_ast::Expr>> {
    all(exprs)
        .map(|cond| translate_expr_kind(cond.kind, context))
        .transpose()
}
/// Combines expressions into a single right-nested chain of `AND`s
/// (`a AND (b AND c)`); returns `None` for an empty input.
fn all(exprs: Vec<Expr>) -> Option<Expr> {
    // fold from the back so the nesting matches the original construction
    exprs.into_iter().rev().reduce(|acc, expr| Expr {
        kind: ExprKind::Binary {
            op: BinOp::And,
            left: Box::new(expr),
            right: Box::new(acc),
        },
        span: None,
    })
}
/// A `sql_ast::Query` with the given body and every optional clause unset.
fn default_query(body: sql_ast::SetExpr) -> sql_ast::Query {
    sql_ast::Query {
        body: Box::new(body),
        with: None,
        order_by: vec![],
        limit: None,
        offset: None,
        fetch: None,
        locks: vec![],
    }
}
/// A `Select` with every clause empty/unset, intended for use with
/// struct-update syntax: `Select { ..., ..default_select() }`.
fn default_select() -> Select {
    Select {
        projection: vec![],
        from: vec![],
        selection: None,
        group_by: vec![],
        having: None,
        distinct: false,
        top: None,
        into: None,
        lateral_views: vec![],
        cluster_by: vec![],
        distribute_by: vec![],
        sort_by: vec![],
        qualify: None,
    }
}
/// Counts how many table references (FROM + JOIN) the pipeline contains.
fn count_tables(transforms: &[SqlTransform]) -> usize {
    // idiomatic filter().count() instead of a manual counter loop
    transforms
        .iter()
        .filter(|t| {
            matches!(
                t,
                SqlTransform::Super(Transform::Join { .. } | Transform::From(_))
            )
        })
        .count()
}
#[cfg(test)]
mod test {
    use insta::assert_snapshot;

    use super::*;
    use crate::{parser::parse, semantic::resolve, sql::dialect::GenericDialect};

    // Parses and resolves a PRQL string into a preprocessed pipeline plus a
    // generic-dialect context, ready to be fed to `split_into_atomics`.
    fn parse_and_resolve(prql: &str) -> Result<(Vec<SqlTransform>, Context)> {
        let query = resolve(parse(prql)?)?;
        let (anchor, query) = AnchorContext::of(query);
        let context = Context {
            dialect: Box::new(GenericDialect {}),
            anchor,
            omit_ident_prefix: false,
            pre_projection: false,
        };
        let pipeline = query.relation.kind.into_pipeline().unwrap();
        Ok((preprocess::reorder(preprocess::wrap(pipeline)), context))
    }

    // Checks how many atomic queries (CTEs + main query) different pipeline
    // shapes split into.
    #[test]
    fn test_ctes_of_pipeline() {
        // aggregate followed by sort/take fits into one atomic query
        let prql: &str = r###"
from employees
filter country == "USA"
aggregate [sal = average salary]
sort sal
take 20
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics("".to_string(), pipeline, &mut context.anchor);
        assert_eq!(queries.len(), 1);

        // a take before filter/aggregate forces one split (two queries)
        let prql: &str = r###"
from employees
take 20
filter country == "USA"
aggregate [sal = average salary]
sort sal
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics("".to_string(), pipeline, &mut context.anchor);
        assert_eq!(queries.len(), 2);

        // two chained aggregates force a second split (three queries)
        let prql: &str = r###"
from employees
take 20
filter country == "USA"
aggregate [sal = average salary]
aggregate [sal2 = average sal]
sort sal2
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics("".to_string(), pipeline, &mut context.anchor);
        assert_eq!(queries.len(), 3);

        // take + select needs no split at all
        let prql: &str = r###"
from employees
take 20
select first_name
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics("".to_string(), pipeline, &mut context.anchor);
        assert_eq!(queries.len(), 1);
    }

    // A column produced by one aggregate and consumed by the next must
    // survive the split between atomic queries (snapshot stored externally).
    #[test]
    fn test_variable_after_aggregate() {
        let query = &r#"
from employees
group [title, emp_no] (
aggregate [emp_salary = average salary]
)
group [title] (
aggregate [avg_salary = average emp_salary]
)
"#;

        let sql_ast = crate::test::compile(query).unwrap();

        assert_snapshot!(sql_ast);
    }

    // A derive before a filter must be pushed into a CTE so the filter can
    // reference the derived (windowed) column.
    #[test]
    fn test_derive_filter() {
        let query = &r#"
from employees
derive global_rank = rank
filter country == "USA"
derive rank = rank
"#;

        let sql_ast = crate::test::compile(query).unwrap();

        assert_snapshot!(sql_ast, @r###"
WITH table_1 AS (
SELECT
*,
RANK() OVER () AS global_rank
FROM
employees
)
SELECT
*,
RANK() OVER () AS rank
FROM
table_1
WHERE
country = 'USA'
"###);
    }

    // Filtering on a windowed expression must materialize the window result
    // in a CTE before the WHERE clause can use it.
    #[test]
    fn test_filter_windowed() {
        let query = &r#"
from tbl1
filter (average bar) > 3
"#;

        assert_snapshot!(crate::test::compile(query).unwrap(), @r###"
WITH table_1 AS (
SELECT
*,
AVG(bar) OVER () AS _expr_0
FROM
tbl1
)
SELECT
*
FROM
table_1
WHERE
_expr_0 > 3
"###);
    }
}