1use std::sync::Arc;
2
/// A typed value that can appear as a filter operand or as a query binding.
///
/// Used as the right-hand side of `FilterNode::Condition` and collected in
/// `CompileResult::bindings`.
#[derive(Debug, Clone)]
pub enum SqlValue {
    /// Text value.
    String(String),
    /// Signed 64-bit integer value.
    Int(i64),
    /// 64-bit floating-point value.
    Float(f64),
    /// Boolean value.
    Bool(bool),
    /// Expression carried as text.
    /// NOTE(review): presumably rendered as raw SQL rather than bound/quoted — confirm in the compiler.
    Expression(String),
}
13
/// The kind of SQL join used for a `JoinExpr`.
///
/// Defaults to [`JoinType::Left`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum JoinType {
    /// `LEFT JOIN` (the default).
    #[default]
    Left,
    /// `INNER JOIN`.
    Inner,
    /// `FULL OUTER JOIN`.
    Full,
    /// `CROSS JOIN`.
    Cross,
}

impl JoinType {
    /// Returns the SQL keyword sequence that introduces this kind of join.
    pub fn sql_keyword(&self) -> &'static str {
        match *self {
            Self::Left => "LEFT JOIN",
            Self::Inner => "INNER JOIN",
            Self::Full => "FULL OUTER JOIN",
            Self::Cross => "CROSS JOIN",
        }
    }
}
34
35#[derive(Clone)]
39pub struct QueryBuilderFn(pub Arc<dyn Fn(&QueryIR) -> CompileResult + Send + Sync>);
40
41impl std::fmt::Debug for QueryBuilderFn {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.write_str("QueryBuilderFn(...)")
44 }
45}
46
47#[derive(Debug, Clone)]
49pub struct QueryIR {
50 pub cube: String,
51 pub schema: String,
52 pub table: String,
53 pub selects: Vec<SelectExpr>,
54 pub filters: FilterNode,
55 pub having: FilterNode,
56 pub group_by: Vec<String>,
57 pub order_by: Vec<OrderExpr>,
58 pub limit: u32,
59 pub offset: u32,
60 pub limit_by: Option<LimitByExpr>,
62 pub use_final: bool,
64 pub joins: Vec<JoinExpr>,
66 pub custom_query_builder: Option<QueryBuilderFn>,
68 pub from_subquery: Option<String>,
71}
72
73#[derive(Debug, Clone)]
75pub struct JoinExpr {
76 pub schema: String,
77 pub table: String,
78 pub alias: String,
80 pub conditions: Vec<(String, String)>,
82 pub selects: Vec<SelectExpr>,
84 pub group_by: Vec<String>,
86 pub use_final: bool,
88 pub is_aggregate: bool,
90 pub target_cube: String,
92 pub join_field: String,
94 pub join_type: JoinType,
96}
97
/// Selects the arg-extremum aggregate used by `SelectExpr::DimAggregate`.
///
/// NOTE(review): presumably rendered as `argMax(...)` / `argMin(...)` — confirm in the compiler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimAggType {
    /// Value at the row where the compare column is largest.
    ArgMax,
    /// Value at the row where the compare column is smallest.
    ArgMin,
}
105
106#[derive(Debug, Clone)]
107pub enum SelectExpr {
108 Column {
109 column: String,
110 alias: Option<String>,
111 },
112 Aggregate {
113 function: String,
114 column: String,
115 alias: String,
116 condition: Option<String>,
117 },
118 DimAggregate {
121 agg_type: DimAggType,
122 value_column: String,
123 compare_column: String,
124 alias: String,
125 condition: Option<String>,
126 },
127}
128
129#[derive(Debug, Clone)]
130pub enum FilterNode {
131 And(Vec<FilterNode>),
132 Or(Vec<FilterNode>),
133 Condition {
134 column: String,
135 op: CompareOp,
136 value: SqlValue,
137 },
138 ArrayIncludes {
141 array_columns: Vec<String>,
143 element_conditions: Vec<Vec<FilterNode>>,
146 },
147 Empty,
148}
149
/// Comparison operator of a `FilterNode::Condition`.
///
/// Several variants map to the same SQL operator. For example, `Includes`,
/// `StartsWith` and `EndsWith` all render as `LIKE`.
/// NOTE(review): presumably the caller distinguishes them by how it builds the
/// pattern value — confirm in the filter compiler.
#[derive(Debug, Clone)]
pub enum CompareOp {
    Eq,
    Ne,
    Gt,
    Ge,
    Lt,
    Le,
    Like,
    NotLike,
    In,
    NotIn,
    Includes,
    NotIncludes,
    StartsWith,
    EndsWith,
    Ilike,
    NotIlike,
    IlikeIncludes,
    NotIlikeIncludes,
    IlikeStartsWith,
    IsNull,
    IsNotNull,
}

impl CompareOp {
    /// Returns the SQL operator text for this comparison.
    ///
    /// Case-insensitive variants are emitted with a lower-case `ilike`.
    pub fn sql_op(&self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Like | Self::Includes | Self::StartsWith | Self::EndsWith => "LIKE",
            Self::NotLike | Self::NotIncludes => "NOT LIKE",
            Self::In => "IN",
            Self::NotIn => "NOT IN",
            Self::Ilike | Self::IlikeIncludes | Self::IlikeStartsWith => "ilike",
            Self::NotIlike | Self::NotIlikeIncludes => "NOT ilike",
            Self::IsNull => "IS NULL",
            Self::IsNotNull => "IS NOT NULL",
        }
    }

    /// Returns `true` for operators with no right-hand operand (`IS NULL`, `IS NOT NULL`).
    pub fn is_unary(&self) -> bool {
        matches!(*self, Self::IsNull | Self::IsNotNull)
    }
}
206
/// One ordering term: an expression plus its sort direction.
#[derive(Debug, Clone)]
pub struct OrderExpr {
    /// Expression to sort by.
    pub column: String,
    /// `true` for descending order, `false` for ascending.
    pub descending: bool,
}
212
/// A per-group row limit over `columns`.
///
/// NOTE(review): presumably rendered as ClickHouse `LIMIT [offset,] count BY columns` — confirm.
#[derive(Debug, Clone)]
pub struct LimitByExpr {
    /// Maximum rows kept per group.
    pub count: u32,
    /// Rows skipped per group.
    pub offset: u32,
    /// Expressions that define the groups.
    pub columns: Vec<String>,
}
219
220impl FilterNode {
221 pub fn is_empty(&self) -> bool {
222 matches!(self, FilterNode::Empty)
223 }
224}
225
/// Lower-cased names of the base aggregate functions recognised by
/// [`is_aggregate_func_name`].
///
/// Combinator variants (`sumIf`, `uniqState`, `maxArray`, ...) are derived from
/// these names and do not need their own entries.
const AGGREGATE_FUNCTIONS: &[&str] = &[
    "count", "sum", "avg", "min", "max", "any",
    "uniq", "uniqexact", "uniqcombined", "uniqhll12",
    "argmax", "argmin",
    "quantile", "quantiles", "quantileexact", "quantiletiming",
    "median",
    "grouparray", "groupuniqarray", "groupbitand", "groupbitor", "groupbitxor",
    "topk", "entropy", "varpop", "varsamp", "stddevpop", "stddevsamp",
    "covarsamp", "covarpop", "corr",
];

/// Lower-cased ClickHouse aggregate-function combinator suffixes whose result is
/// still an aggregate function (e.g. `sumIf`, `uniqState`, `avgOrNull`).
///
/// Every matching suffix is tried, so overlapping entries such as `simplestate`
/// and `state` are both handled.
const AGGREGATE_COMBINATOR_SUFFIXES: &[&str] = &[
    "if", "array", "map", "foreach", "distinct", "ornull", "ordefault", "resample",
    "simplestate", "state",
];

/// Returns `true` if `name` (case-insensitive) is an aggregate function.
///
/// A name is considered an aggregate when:
/// - it ends in `merge` or `mergestate` (any base), or
/// - it is listed in [`AGGREGATE_FUNCTIONS`], or
/// - it is such a name followed by any chain of combinator suffixes from
///   [`AGGREGATE_COMBINATOR_SUFFIXES`] (e.g. `uniqExactIfState`, `sumMergeIf`).
///
/// List membership is checked before any suffix is stripped. This keeps names
/// such as `groupArray` from being misread as `group` + `Array`.
fn is_aggregate_func_name(name: &str) -> bool {
    is_aggregate_lowercase(&name.to_lowercase())
}

/// Recursive worker for [`is_aggregate_func_name`]; `lower` must already be lower-case.
fn is_aggregate_lowercase(lower: &str) -> bool {
    // `-Merge` / `-MergeState` are accepted for any base name.
    if lower.ends_with("merge") || lower.ends_with("mergestate") {
        return true;
    }
    if AGGREGATE_FUNCTIONS.contains(&lower) {
        return true;
    }
    // Peel one combinator suffix and retry. An empty base (e.g. the plain `if(...)`
    // function) is never an aggregate. Recursion depth is bounded by the name length.
    AGGREGATE_COMBINATOR_SUFFIXES.iter().copied().any(|suffix| {
        matches!(
            lower.strip_suffix(suffix),
            Some(base) if !base.is_empty() && is_aggregate_lowercase(base)
        )
    })
}
246
247pub fn is_aggregate_expr(column: &str) -> bool {
251 let Some(paren_pos) = column.find('(') else {
252 return false;
253 };
254 let func_name = column[..paren_pos].trim();
255 is_aggregate_func_name(func_name)
256}
257
258pub fn contains_aggregate_expr(column: &str) -> bool {
262 if !column.contains('(') {
263 return false;
264 }
265 if is_aggregate_expr(column) {
266 return true;
267 }
268 for (i, _) in column.match_indices('(') {
269 let before = &column[..i];
270 let func_name = before.rsplit(|c: char| !c.is_alphanumeric() && c != '_')
271 .next()
272 .unwrap_or("");
273 if !func_name.is_empty() && is_aggregate_func_name(func_name) {
274 return true;
275 }
276 }
277 false
278}
279
280pub struct CompileResult {
282 pub sql: String,
283 pub bindings: Vec<SqlValue>,
284 pub alias_remap: Vec<(String, String)>,
287}