use anyhow::{anyhow, bail, Result};
use itertools::Itertools;
use sqlformat::{format, FormatOptions, QueryParams};
use sqlparser::ast::{self as sql_ast, Select, SetExpr, TableWithJoins};
use crate::ast::pl::{DialectHandler, Literal};
use crate::ast::rq::{CId, Expr, ExprKind, IrFold, Query, Relation, TableDecl, Transform};
use crate::sql::anchor::materialize_inputs;
use crate::utils::{IntoOnly, Pluck, TableCounter};
use super::anchor;
use super::codegen::*;
use super::context::AnchorContext;
use super::distinct::{preprocess_distinct, preprocess_reorder};
/// Translation state shared across the SQL code-generation pass.
pub(super) struct Context {
    /// Dialect-specific behavior (e.g. whether to emit `TOP` or `LIMIT`).
    pub dialect: Box<dyn DialectHandler>,
    /// Table/column bookkeeping for the query being translated.
    pub anchor: AnchorContext,
    /// Set when the current atomic query reads from exactly one table,
    /// so identifiers can be emitted without a table-name prefix.
    pub omit_ident_prefix: bool,
    /// True while translating clauses evaluated before the projection is
    /// materialized (SELECT items, FROM, JOIN, WHERE, GROUP BY, HAVING);
    /// reset to false before translating TAKE and SORT.
    pub pre_projection: bool,
}
/// Translate a resolved RQ query into a formatted SQL string.
pub fn translate(query: Query) -> Result<String> {
    let ast = translate_query(query)?;

    let rendered = ast.to_string();
    let pretty = format(
        &rendered,
        &QueryParams::default(),
        FormatOptions::default(),
    );

    // The formatter inserts a space inside doubled braces; collapse
    // `{ {`/`} }` back to `{{`/`}}`.
    Ok(pretty.replace("{ {", "{{").replace("} }", "}}"))
}
/// Translate a resolved RQ query into a `sqlparser` AST query.
///
/// Each table's pipeline is preprocessed and split into atomic queries;
/// all but the last become CTEs attached to the main query's `WITH`.
pub fn translate_query(query: Query) -> Result<sql_ast::Query> {
    let dialect = query.def.dialect.handler();
    let (anchor, query) = AnchorContext::of(query);

    let mut context = Context {
        dialect,
        anchor,
        omit_ident_prefix: false,
        pre_projection: false,
    };

    // Append the main relation as an anonymous table after the declared ones.
    let tables = into_tables(query.relation, query.tables, &mut context)?;

    // Preprocess each pipeline and split it into atomic queries.
    let mut atomics = Vec::new();
    for table in tables {
        let pipeline = match table.relation {
            Relation::Pipeline(pipeline) => pipeline,
            // Non-pipeline relations (e.g. external tables) need no query.
            _ => continue,
        };

        let pipeline = preprocess_reorder(pipeline);
        let pipeline = preprocess_distinct(pipeline, &mut context)?;

        atomics.extend(split_into_atomics(table.name, pipeline, &mut context));
    }
    if atomics.is_empty() {
        bail!("No tables?");
    }

    // The last atomic query is the main SELECT; everything before it
    // becomes a CTE.
    let main_query = atomics
        .pop()
        .expect("atomics is non-empty; checked above");
    let ctes: Vec<_> = atomics
        .into_iter()
        .map(|t| table_to_sql_cte(t, &mut context))
        .try_collect()?;

    let mut main_query = sql_query_of_atomic_query(main_query.pipeline, &mut context)?;

    if !ctes.is_empty() {
        main_query.with = Some(sql_ast::With {
            cte_tables: ctes,
            recursive: false,
        });
    }
    Ok(main_query)
}
/// A pipeline that can be rendered as a single SQL SELECT statement —
/// produced by `split_into_atomics`.
#[derive(Debug)]
pub struct AtomicQuery {
    // CTE alias; `None` only for the final (main) query.
    name: Option<String>,
    // The transforms rendered into this SELECT.
    pipeline: Vec<Transform>,
}
/// Wrap the main pipeline as an anonymous table declaration and append
/// it after all declared tables.
fn into_tables(
    main_pipeline: Relation,
    tables: Vec<TableDecl>,
    context: &mut Context,
) -> Result<Vec<TableDecl>> {
    let main = TableDecl {
        id: context.anchor.tid.gen(),
        name: None,
        relation: main_pipeline,
    };

    let mut all_tables = tables;
    all_tables.push(main);
    Ok(all_tables)
}
/// Convert an atomic query into a named CTE (`alias AS (query)`).
///
/// The query must carry a name: `split_into_atomics` names every atomic
/// query except the final one, and `translate_query` never passes the
/// final query here.
fn table_to_sql_cte(table: AtomicQuery, context: &mut Context) -> Result<sql_ast::Cte> {
    let alias = sql_ast::TableAlias {
        // `expect` documents the invariant instead of a bare unwrap.
        name: translate_ident_part(
            table.name.expect("atomic query used as CTE must have a name"),
            context,
        ),
        columns: vec![],
    };
    Ok(sql_ast::Cte {
        alias,
        query: Box::new(sql_query_of_atomic_query(table.pipeline, context)?),
        from: None,
    })
}
/// Build a single SQL SELECT statement from one atomic pipeline.
///
/// Transforms are consumed from the pipeline clause by clause (via
/// `pluck`), in the order SQL evaluates them.
fn sql_query_of_atomic_query(
    pipeline: Vec<Transform>,
    context: &mut Context,
) -> Result<sql_ast::Query> {
    // Count referenced tables; with exactly one, identifiers can be
    // emitted without a table-name prefix.
    let mut counter = TableCounter::default();
    let mut pipeline = counter.fold_transforms(pipeline)?;
    context.omit_ident_prefix = counter.count() == 1;
    log::debug!("atomic query contains {} tables", counter.count());

    // Clauses up to HAVING/GROUP BY are translated "pre-projection";
    // the flag is reset below before TAKE/SORT.
    context.pre_projection = true;

    // SELECT list: at most one Select transform is expected here.
    let projection = pipeline
        .pluck(|t| t.into_select())
        .into_only()
        .unwrap_or_default()
        .into_iter()
        .map(|id| translate_select_item(id, context))
        .try_collect()?;

    // FROM clause.
    let mut from = pipeline
        .pluck(|t| t.into_from())
        .into_iter()
        .map(|source| TableWithJoins {
            relation: table_factor_of_tid(source, context),
            joins: vec![],
        })
        .collect::<Vec<_>>();

    // JOINs are attached to the last FROM item; a join with no FROM is
    // an error.
    let joins = pipeline
        .pluck(|t| t.into_join())
        .into_iter()
        .map(|j| translate_join(j, context))
        .collect::<Result<Vec<_>>>()?;
    if !joins.is_empty() {
        if let Some(from) = from.last_mut() {
            from.joins = joins;
        } else {
            return Err(anyhow!("Cannot use `join` without `from`"));
        }
    }

    // Filters before the Aggregate become WHERE; filters after it, HAVING.
    // With no Aggregate, everything lands in WHERE.
    let aggregate_position = pipeline
        .iter()
        .position(|t| matches!(t, Transform::Aggregate { .. }))
        .unwrap_or(pipeline.len());
    let (before, after) = pipeline.split_at(aggregate_position);
    let where_ = filter_of_pipeline(before, context)?;
    let having = filter_of_pipeline(after, context)?;

    // GROUP BY comes from the Aggregate's partition columns.
    let aggregate = pipeline.get(aggregate_position);
    let group_by: Vec<CId> = aggregate
        .map(|t| match t {
            Transform::Aggregate { partition, .. } => partition.clone(),
            // `aggregate_position` was found by matching Aggregate above.
            _ => unreachable!(),
        })
        .unwrap_or_default();
    let group_by = try_into_exprs(group_by, context)?;

    context.pre_projection = false;

    // Fold all TAKE transforms into one range, then split it into
    // OFFSET (start - 1, 1-based → 0-based) and LIMIT (end - offset).
    let takes = pipeline.pluck(|t| t.into_take());
    let ranges = takes.into_iter().map(|x| x.range).collect();
    let take = range_of_ranges(ranges)?;
    let offset = take.start.map(|s| s - 1).unwrap_or(0);
    let limit = take.end.map(|e| e - offset);

    // Omit `OFFSET 0` entirely.
    let offset = if offset == 0 {
        None
    } else {
        Some(sqlparser::ast::Offset {
            value: translate_expr_kind(ExprKind::Literal(Literal::Integer(offset)), context)?,
            rows: sqlparser::ast::OffsetRows::None,
        })
    };

    // Only the last Sort transform matters; earlier ones are overridden.
    let order_by = pipeline
        .pluck(|t| t.into_sort())
        .last()
        .map(|sorts| {
            sorts
                .iter()
                .map(|s| translate_column_sort(s, context))
                .try_collect()
        })
        .transpose()?
        .unwrap_or_default();

    let distinct = pipeline.iter().any(|t| matches!(t, Transform::Unique));

    Ok(sql_ast::Query {
        body: Box::new(SetExpr::Select(Box::new(Select {
            distinct,
            // Dialects with `use_top()` take the limit as `TOP n` here
            // instead of a trailing LIMIT clause.
            top: if context.dialect.use_top() {
                limit.map(|l| top_of_i64(l, context))
            } else {
                None
            },
            projection,
            into: None,
            from,
            lateral_views: vec![],
            selection: where_,
            group_by,
            cluster_by: vec![],
            distribute_by: vec![],
            sort_by: vec![],
            having,
            qualify: None,
        }))),
        order_by,
        with: None,
        limit: if context.dialect.use_top() {
            None
        } else {
            limit.map(expr_of_i64)
        },
        offset,
        fetch: None,
        lock: None,
    })
}
/// Split a pipeline into a chain of atomic queries, each renderable as
/// one SELECT. Splitting proceeds back-to-front via
/// `anchor::split_off_back`; every non-final part gets a generated table
/// name, and each following part is re-anchored onto it.
fn split_into_atomics(
    table_name: Option<String>,
    mut pipeline: Vec<Transform>,
    context: &mut Context,
) -> Vec<AtomicQuery> {
    materialize_inputs(&pipeline, &mut context.anchor);

    let mut output_cols = context.anchor.determine_select_columns(&pipeline);

    // Collect (pipeline-part, columns-at-split) pairs, last part first.
    let mut parts_rev = Vec::new();
    loop {
        let (preceding, split) =
            anchor::split_off_back(&mut context.anchor, output_cols.clone(), pipeline);

        if let Some((preceding, cols_at_split)) = preceding {
            log::debug!(
                "pipeline split after {}",
                preceding.last().unwrap().as_ref()
            );
            parts_rev.push((split, cols_at_split.clone()));
            pipeline = preceding;
            output_cols = cols_at_split;
        } else {
            // Nothing precedes this part: it is the first in the chain.
            parts_rev.push((split, Vec::new()));
            break;
        }
    }
    parts_rev.reverse();
    let mut parts = parts_rev;

    // Turn the parts into named atomic queries, chaining each onto the
    // previous part's generated table name.
    let mut atomics = Vec::with_capacity(parts.len());
    let last = parts.pop().unwrap();

    let last_pipeline = if parts.is_empty() {
        // Single part: no CTEs needed.
        last.0
    } else {
        // The first part keeps its pipeline as-is; it only needs a name.
        let first = parts.remove(0);
        let first_name = context.anchor.gen_table_name();
        atomics.push(AtomicQuery {
            name: Some(first_name.clone()),
            pipeline: first.0,
        });

        // Middle parts are re-anchored onto the previous part's table.
        let mut prev_name = first_name;
        for (pipeline, cols_before) in parts.into_iter() {
            let name = context.anchor.gen_table_name();

            let pipeline =
                anchor::anchor_split(&mut context.anchor, &prev_name, &cols_before, pipeline);

            atomics.push(AtomicQuery {
                name: Some(name.clone()),
                pipeline,
            });
            prev_name = name;
        }

        // The final part is re-anchored onto the last intermediate table.
        anchor::anchor_split(&mut context.anchor, &prev_name, &last.1, last.0)
    };
    // The final atomic query keeps the caller-provided table name
    // (`None` for the main query).
    atomics.push(AtomicQuery {
        name: table_name,
        pipeline: last_pipeline,
    });

    atomics
}
/// Collect all `Filter` transforms in the slice and hand them to
/// `filter_of_filters` to combine into a single SQL expression
/// (`None` when there are no filters).
fn filter_of_pipeline(
    pipeline: &[Transform],
    context: &mut Context,
) -> Result<Option<sql_ast::Expr>> {
    let mut filters = Vec::new();
    for transform in pipeline {
        if let Transform::Filter(filter) = transform {
            filters.push(filter.clone());
        }
    }

    filter_of_filters(filters, context)
}
impl From<Vec<Transform>> for AtomicQuery {
fn from(pipeline: Vec<Transform>) -> Self {
AtomicQuery {
name: None,
pipeline,
}
}
}
#[cfg(test)]
mod test {
    use insta::assert_snapshot;

    use super::*;
    use crate::{ast::pl::GenericDialect, parse, semantic::resolve};

    /// Parse and resolve a PRQL string, returning the reordered main
    /// pipeline and a fresh translation context using the generic dialect.
    fn parse_and_resolve(prql: &str) -> Result<(Vec<Transform>, Context)> {
        let query = resolve(parse(prql)?)?;
        let (anchor, query) = AnchorContext::of(query);
        let context = Context {
            dialect: Box::new(GenericDialect {}),
            anchor,
            omit_ident_prefix: false,
            pre_projection: false,
        };
        let pipeline = query.relation.into_pipeline().unwrap();
        Ok((preprocess_reorder(pipeline), context))
    }

    /// Verify how many atomic queries (CTEs + main) a pipeline splits into.
    #[test]
    fn test_ctes_of_pipeline() {
        // Aggregate at the end: everything fits in one SELECT.
        let prql: &str = r###"
from employees
filter country == "USA"
aggregate [sal = average salary]
sort sal
take 20
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics(None, pipeline, &mut context);
        assert_eq!(queries.len(), 1);

        // `take` before the aggregate forces a split.
        let prql: &str = r###"
from employees
take 20
filter country == "USA"
aggregate [sal = average salary]
sort sal
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics(None, pipeline, &mut context);
        assert_eq!(queries.len(), 2);

        // Two aggregates: each needs its own SELECT, plus the take split.
        let prql: &str = r###"
from employees
take 20
filter country == "USA"
aggregate [sal = average salary]
aggregate [sal2 = average sal]
sort sal2
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics(None, pipeline, &mut context);
        assert_eq!(queries.len(), 3);

        // `take` followed by a plain select needs no split.
        let prql: &str = r###"
from employees
take 20
select first_name
"###;
        let (pipeline, mut context) = parse_and_resolve(prql).unwrap();
        let queries = split_into_atomics(None, pipeline, &mut context);
        assert_eq!(queries.len(), 1);
    }

    /// A column produced by one aggregate and consumed by a second must
    /// survive the split between the two SELECTs (snapshot stored on disk).
    #[test]
    fn test_variable_after_aggregate() {
        let query = &r#"
from employees
group [title, emp_no] (
aggregate [emp_salary = average salary]
)
group [title] (
aggregate [avg_salary = average emp_salary]
)
"#;

        let query = resolve(parse(query).unwrap()).unwrap();
        let sql_ast = translate(query).unwrap();

        assert_snapshot!(sql_ast);
    }

    /// A `derive` followed by `filter` splits into a CTE so the derived
    /// window-function column can be referenced in WHERE.
    #[test]
    fn test_derive_filter() {
        let query = &r#"
from employees
derive global_rank = rank
filter country == "USA"
derive rank = rank
"#;

        let query = resolve(parse(query).unwrap()).unwrap();
        let sql_ast = translate(query).unwrap();

        assert_snapshot!(sql_ast, @r###"
WITH table_0 AS (
SELECT
*,
RANK() OVER () AS global_rank,
country
FROM
employees
)
SELECT
*,
global_rank,
RANK() OVER () AS rank
FROM
table_0
WHERE
country = 'USA'
"###);
    }

    /// Filtering on a windowed expression moves the window computation
    /// into a CTE, filtering on the generated `_expr_0` column.
    #[test]
    fn test_filter_windowed() {
        let query = &r#"
from tbl1
filter (average bar) > 3
"#;

        assert_snapshot!(crate::compile(query).unwrap(), @r###"
WITH table_0 AS (
SELECT
*,
AVG(bar) OVER () AS _expr_0
FROM
tbl1
)
SELECT
*
FROM
table_0
WHERE
_expr_0 > 3
"###);
    }
}