use std::sync::Arc;
/// A scalar value used on the right-hand side of a filter condition and as the
/// element type of `CompileResult::bindings`.
#[derive(Debug, Clone)]
pub enum SqlValue {
/// Text value.
String(String),
/// 64-bit signed integer value.
Int(i64),
/// 64-bit floating-point value.
Float(f64),
/// Boolean value.
Bool(bool),
/// NOTE(review): presumably a raw SQL expression that is emitted as-is rather
/// than bound as a parameter — confirm against the code that consumes it.
Expression(String),
}
/// Kind of SQL join applied by a `JoinExpr`.
///
/// `Left` is the `Default` value.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum JoinType {
    /// Rendered as `LEFT JOIN`; this is the default variant.
    #[default]
    Left,
    /// Rendered as `INNER JOIN`.
    Inner,
    /// Rendered as `FULL OUTER JOIN`.
    Full,
    /// Rendered as `CROSS JOIN`.
    Cross,
}

impl JoinType {
    /// Returns the SQL keyword sequence that introduces this kind of join.
    pub fn sql_keyword(&self) -> &'static str {
        match *self {
            Self::Left => "LEFT JOIN",
            Self::Inner => "INNER JOIN",
            Self::Full => "FULL OUTER JOIN",
            Self::Cross => "CROSS JOIN",
        }
    }
}
/// Shareable callback that turns a `QueryIR` into a `CompileResult`.
///
/// The closure lives behind an `Arc` and is bounded by `Send + Sync`, so cloning
/// the wrapper only bumps the reference count.
#[derive(Clone)]
pub struct QueryBuilderFn(pub Arc<dyn Fn(&QueryIR) -> CompileResult + Send + Sync>);

impl std::fmt::Debug for QueryBuilderFn {
    /// Writes a fixed placeholder, because the wrapped closure cannot be inspected.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "QueryBuilderFn(...)")
    }
}
/// Description of a single query. A `QueryBuilderFn` receives a `&QueryIR` and
/// produces a `CompileResult`.
///
/// NOTE(review): the per-field notes below are inferred from names and types
/// only; the code that consumes this struct is not visible here.
#[derive(Debug, Clone)]
pub struct QueryIR {
/// Name of the cube the query is built for (presumably — confirm with caller).
pub cube: String,
/// Schema of the primary table (presumably the database name — confirm).
pub schema: String,
/// Name of the primary table.
pub table: String,
/// Expressions placed in the select list.
pub selects: Vec<SelectExpr>,
/// Filter tree; presumably rendered as the `WHERE` clause — confirm.
pub filters: FilterNode,
/// Filter tree; presumably rendered as the `HAVING` clause — confirm.
pub having: FilterNode,
/// Grouping expressions, stored as strings.
pub group_by: Vec<String>,
/// Ordering terms, in order of precedence (assumed).
pub order_by: Vec<OrderExpr>,
/// Row limit. NOTE(review): not an `Option`; how "no limit" is expressed is
/// not visible here.
pub limit: u32,
/// Row offset.
pub offset: u32,
/// Optional `LIMIT … BY …` style restriction (see `LimitByExpr`).
pub limit_by: Option<LimitByExpr>,
/// NOTE(review): presumably requests the `FINAL` modifier on the primary
/// table — confirm.
pub use_final: bool,
/// Joins attached to the primary table.
pub joins: Vec<JoinExpr>,
/// Optional callback; presumably replaces the default compilation when set —
/// confirm against the compiler.
pub custom_query_builder: Option<QueryBuilderFn>,
/// Optional string; presumably SQL for a subquery used as the `FROM` source
/// instead of `schema`/`table` — confirm.
pub from_subquery: Option<String>,
}
/// One join attached to a `QueryIR`.
///
/// NOTE(review): the per-field notes below are inferred from names and types
/// only; the code that consumes this struct is not visible here.
#[derive(Debug, Clone)]
pub struct JoinExpr {
/// Schema of the joined table.
pub schema: String,
/// Name of the joined table.
pub table: String,
/// Alias under which the joined source is referenced.
pub alias: String,
/// Pairs of strings; presumably (left column, right column) pairs that are
/// compared for equality in the `ON` clause — confirm ordering with caller.
pub conditions: Vec<(String, String)>,
/// Select expressions taken from the joined source.
pub selects: Vec<SelectExpr>,
/// Grouping expressions for the joined source.
pub group_by: Vec<String>,
/// NOTE(review): presumably requests the `FINAL` modifier on the joined
/// table — confirm.
pub use_final: bool,
/// Flag; presumably marks the joined source as an aggregating subquery —
/// confirm.
pub is_aggregate: bool,
/// Name of the cube on the other side of the join (assumed).
pub target_cube: String,
/// Name of the field that declared the join (assumed).
pub join_field: String,
/// Kind of join; `JoinType::Left` is that type's default.
pub join_type: JoinType,
}
/// Selects the aggregation used by `SelectExpr::DimAggregate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimAggType {
/// Presumably maps to an `argMax`-style function (value at the maximum of the
/// comparison column) — confirm against the compiler.
ArgMax,
/// Presumably maps to an `argMin`-style function (value at the minimum of the
/// comparison column) — confirm against the compiler.
ArgMin,
}
/// One entry of a select list (used by `QueryIR::selects` and
/// `JoinExpr::selects`).
///
/// NOTE(review): how each variant is rendered to SQL is not visible here; the
/// notes below are inferred from field names and types.
#[derive(Debug, Clone)]
pub enum SelectExpr {
/// A plain column or expression with an optional alias.
Column {
column: String,
alias: Option<String>,
},
/// A call of the aggregate function named `function` on `column`, always
/// aliased. `condition`, when present, is presumably a predicate restricting
/// which rows are aggregated — confirm.
Aggregate {
function: String,
column: String,
alias: String,
condition: Option<String>,
},
/// An `ArgMax`/`ArgMin` aggregation: presumably yields `value_column` at the
/// extreme of `compare_column`, optionally restricted by `condition` — confirm.
DimAggregate {
agg_type: DimAggType,
value_column: String,
compare_column: String,
alias: String,
condition: Option<String>,
},
}
/// A node of a boolean filter tree (used by `QueryIR::filters` and
/// `QueryIR::having`).
#[derive(Debug, Clone)]
pub enum FilterNode {
/// Group of child nodes; by name, combined with logical AND.
And(Vec<FilterNode>),
/// Group of child nodes; by name, combined with logical OR.
Or(Vec<FilterNode>),
/// A single comparison of `column` against `value` using `op`.
Condition {
column: String,
op: CompareOp,
value: SqlValue,
},
/// A filter over array columns.
/// NOTE(review): presumably each inner `Vec<FilterNode>` lists the conditions
/// one array element must satisfy — confirm against the compiler.
ArrayIncludes {
array_columns: Vec<String>,
element_conditions: Vec<Vec<FilterNode>>,
},
/// No filter; the only variant for which `is_empty` returns `true`.
Empty,
}
/// Comparison operator used by `FilterNode::Condition`.
///
/// Several variants share one SQL operator (for example `Includes`,
/// `StartsWith` and `EndsWith` all render as `LIKE`); see [`CompareOp::sql_op`].
/// NOTE(review): how the pattern operand is built for those variants is not
/// visible in this file.
#[derive(Debug, Clone)]
pub enum CompareOp {
    Eq,
    Ne,
    Gt,
    Ge,
    Lt,
    Le,
    Like,
    NotLike,
    In,
    NotIn,
    Includes,
    NotIncludes,
    StartsWith,
    EndsWith,
    Ilike,
    NotIlike,
    IlikeIncludes,
    NotIlikeIncludes,
    IlikeStartsWith,
    IsNull,
    IsNotNull,
}

impl CompareOp {
    /// Returns the SQL operator text for this comparison.
    ///
    /// Arms are grouped by the operator they emit. The case-insensitive
    /// variants emit a lower-case `ilike`.
    pub fn sql_op(&self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::In => "IN",
            Self::NotIn => "NOT IN",
            Self::Like | Self::Includes | Self::StartsWith | Self::EndsWith => "LIKE",
            Self::NotLike | Self::NotIncludes => "NOT LIKE",
            Self::Ilike | Self::IlikeIncludes | Self::IlikeStartsWith => "ilike",
            Self::NotIlike | Self::NotIlikeIncludes => "NOT ilike",
            Self::IsNull => "IS NULL",
            Self::IsNotNull => "IS NOT NULL",
        }
    }

    /// Returns `true` for the operators that take no right-hand operand
    /// (`IS NULL` and `IS NOT NULL`).
    pub fn is_unary(&self) -> bool {
        matches!(*self, Self::IsNull | Self::IsNotNull)
    }
}
/// One ordering term of a query (element of `QueryIR::order_by`).
#[derive(Debug, Clone)]
pub struct OrderExpr {
/// Column or expression to sort by.
pub column: String,
/// `true` for descending order; `false` presumably means ascending.
pub descending: bool,
}
/// Parameters of a per-group row limit (see `QueryIR::limit_by`).
///
/// NOTE(review): presumably rendered as `LIMIT count OFFSET offset BY columns`
/// — confirm against the compiler.
#[derive(Debug, Clone)]
pub struct LimitByExpr {
/// Maximum number of rows kept per distinct combination of `columns` (assumed).
pub count: u32,
/// Number of rows skipped per group before `count` applies (assumed).
pub offset: u32,
/// Columns that define the groups.
pub columns: Vec<String>,
}
impl FilterNode {
pub fn is_empty(&self) -> bool {
matches!(self, FilterNode::Empty)
}
}
/// Base names of the functions treated as aggregates.
///
/// Lookups lower-case the candidate name and strip an optional `If` suffix
/// first, so every entry here must be lower-case and suffix-free.
const AGGREGATE_FUNCTIONS: &[&str] = &[
    "count", "sum", "avg", "min", "max", "any",
    "uniq", "uniqexact", "uniqcombined", "uniqhll12",
    "argmax", "argmin",
    "quantile", "quantiles", "quantileexact", "quantiletiming",
    "median",
    "grouparray", "groupuniqarray", "groupbitand", "groupbitor", "groupbitxor",
    "topk", "entropy", "varpop", "varsamp", "stddevpop", "stddevsamp",
    "covarsamp", "covarpop", "corr",
];

/// Returns `true` when `name` (a bare function name, without parentheses) is
/// recognised as an aggregate function.
///
/// Matching is case-insensitive. A name is accepted when it
/// * ends in `Merge` or `MergeState` (any prefix), or
/// * is listed in [`AGGREGATE_FUNCTIONS`], optionally followed by an `If` suffix.
fn is_aggregate_func_name(name: &str) -> bool {
    let lower = name.to_lowercase();
    if lower.ends_with("merge") || lower.ends_with("mergestate") {
        return true;
    }
    // `countIf` and friends: drop the suffix and look up the base function.
    let base = lower.strip_suffix("if").unwrap_or(&lower);
    AGGREGATE_FUNCTIONS.contains(&base)
}

/// Returns `true` when `column` *starts with* an aggregate call, i.e. the text
/// before the first `(` (ignoring surrounding whitespace) is an aggregate
/// function name.
///
/// Returns `false` when `column` contains no `(` at all.
pub fn is_aggregate_expr(column: &str) -> bool {
    match column.split_once('(') {
        Some((prefix, _)) => is_aggregate_func_name(prefix.trim()),
        None => false,
    }
}

/// Returns `true` when an aggregate call appears anywhere in `column`,
/// including nested inside other calls (e.g. `toString(sum(x))`).
///
/// The scan is purely lexical. For every `(` it takes the identifier
/// (alphanumerics and `_`) immediately before it and checks that name.
/// Whitespace between the identifier and the `(` is skipped, matching the
/// behaviour of [`is_aggregate_expr`], so `toString(count (x))` is detected as
/// well. Parentheses inside string literals are not treated specially.
pub fn contains_aggregate_expr(column: &str) -> bool {
    if is_aggregate_expr(column) {
        return true;
    }
    column.match_indices('(').any(|(paren_pos, _)| {
        // Trim first so `sum (x)` yields `sum` rather than an empty trailing
        // segment after the space.
        let func_name = column[..paren_pos]
            .trim_end()
            .rsplit(|c: char| !c.is_alphanumeric() && c != '_')
            .next()
            .unwrap_or("");
        !func_name.is_empty() && is_aggregate_func_name(func_name)
    })
}
/// Output of compiling a `QueryIR` (the return type of `QueryBuilderFn`).
pub struct CompileResult {
/// Generated SQL text.
pub sql: String,
/// Values accompanying `sql`; presumably bound to its placeholders in order —
/// confirm against the executor.
pub bindings: Vec<SqlValue>,
/// Pairs of strings; presumably (alias used in the SQL, name expected by the
/// caller) — NOTE(review): direction of the mapping is not visible here.
pub alias_remap: Vec<(String, String)>,
}