1use chartml_core::spec::{
7 AggregateSpec, Dimension, FilterGroup, FilterRule, Measure,
8};
9use std::collections::HashMap;
10
/// A resolved field reference: the SQL text that stands in for a dimension or
/// measure name, plus whether that SQL is an aggregate expression.
#[derive(Debug, Clone)]
struct Symbol {
    /// SQL substituted wherever the field name is referenced (in calculated
    /// measures and in filter conditions).
    sql: String,
    /// `true` for aggregated and calculated measures; filter rules on such
    /// fields are routed to `HAVING` instead of `WHERE`.
    is_aggregated: bool,
}
17
/// Wraps `id` in double quotes for use as a SQL identifier, doubling any
/// embedded `"`.
///
/// Input containing `(` or `*` is treated as a raw SQL expression (a function
/// call, `*`, `a*b`, ...) and returned untouched. An empty identifier yields
/// the empty quoted identifier `""`.
fn quote_identifier(id: &str) -> String {
    if id.is_empty() {
        return String::from("\"\"");
    }

    let looks_like_expression = id.chars().any(|ch| matches!(ch, '(' | '*'));
    if looks_like_expression {
        return id.to_owned();
    }

    let escaped = id.replace('"', "\"\"");
    ["\"", escaped.as_str(), "\""].concat()
}
30
/// Escapes `s` for embedding inside a single-quoted SQL string literal by
/// doubling every `'`. The surrounding quotes are not added.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if ch == '\'' {
            escaped.push('\'');
        }
        escaped.push(ch);
    }
    escaped
}
35
36fn aggregation_to_sql(agg: &str, column: &str) -> Option<String> {
38 let quoted_col = quote_identifier(column);
39 let agg_lower = agg.to_lowercase();
40 match agg_lower.as_str() {
41 "sum" => Some(format!("SUM({})", quoted_col)),
42 "avg" => Some(format!("AVG({})", quoted_col)),
43 "count" => Some(format!("COUNT({})", quoted_col)),
44 "min" => Some(format!("MIN({})", quoted_col)),
45 "max" => Some(format!("MAX({})", quoted_col)),
46 "countdistinct" => Some(format!("COUNT(DISTINCT {})", quoted_col)),
47 "median" => Some(format!("MEDIAN({})", quoted_col)),
48 "stddev" => Some(format!("STDDEV({})", quoted_col)),
49 "variance" => Some(format!("VARIANCE({})", quoted_col)),
50 _ => {
51 if agg_lower.starts_with("percentile") {
53 let pct_str = agg_lower.strip_prefix("percentile")?;
54 let pct: u32 = pct_str.parse().ok()?;
55 let fraction = pct as f64 / 100.0;
56 Some(format!(
57 "PERCENTILE_CONT({}) WITHIN GROUP (ORDER BY {})",
58 fraction, quoted_col
59 ))
60 } else {
61 None
62 }
63 }
64 }
65}
66
67fn build_symbol_table(
70 dimensions: &[Dimension],
71 measures: &[Measure],
72) -> HashMap<String, Symbol> {
73 let mut symbols = HashMap::new();
74
75 for dim in dimensions {
77 match dim {
78 Dimension::Simple(name) => {
79 symbols.insert(
80 name.clone(),
81 Symbol {
82 sql: quote_identifier(name),
83 is_aggregated: false,
84 },
85 );
86 }
87 Dimension::Detailed(spec) => {
88 let field_name = spec.name.clone().unwrap_or_else(|| spec.column.clone());
89 let sql_expr = if spec.column.contains('(') {
91 spec.column.clone()
92 } else {
93 quote_identifier(&spec.column)
94 };
95 symbols.insert(
96 field_name,
97 Symbol {
98 sql: sql_expr,
99 is_aggregated: false,
100 },
101 );
102 }
103 }
104 }
105
106 let mut calculated: Vec<(String, String)> = Vec::new();
108 for measure in measures {
109 if let Some(ref agg) = measure.aggregation {
110 if let Some(ref col) = measure.column {
111 if let Some(sql_expr) = aggregation_to_sql(agg, col) {
112 symbols.insert(
113 measure.name.clone(),
114 Symbol {
115 sql: sql_expr,
116 is_aggregated: true,
117 },
118 );
119 }
120 }
121 } else if let Some(ref expr) = measure.expression {
122 calculated.push((measure.name.clone(), expr.clone()));
123 }
124 }
125
126 for (field_name, expression) in calculated {
128 let resolved = resolve_expression(&expression, &symbols);
129 symbols.insert(
130 field_name,
131 Symbol {
132 sql: resolved,
133 is_aggregated: true,
134 },
135 );
136 }
137
138 symbols
139}
140
141fn resolve_expression(expression: &str, symbols: &HashMap<String, Symbol>) -> String {
144 let mut resolved = expression.to_string();
145
146 let mut field_names: Vec<&String> = symbols.keys().collect();
148 field_names.sort_by_key(|b| std::cmp::Reverse(b.len()));
149
150 for field_name in field_names {
151 if let Some(symbol) = symbols.get(field_name) {
152 resolved = replace_whole_word(&resolved, field_name, &symbol.sql);
155 }
156 }
157
158 format!("({})", resolved)
159}
160
/// Replaces every whole-word occurrence of `target` in `text` with
/// `replacement`.
///
/// An occurrence is "whole-word" when the character immediately before it
/// and the one immediately after it (if any) are neither alphanumeric nor `_`.
/// An empty `target` leaves `text` unchanged.
///
/// Boundaries are judged against the full original text, not just the
/// not-yet-searched remainder. A candidate right after a rejected one therefore
/// still sees its real preceding character (`"aa"` with target `"a"` stays
/// `"aa"`). A rejected candidate only advances the search by one character, so
/// overlapping occurrences are still considered.
fn replace_whole_word(text: &str, target: &str, replacement: &str) -> String {
    if target.is_empty() {
        return text.to_string();
    }

    let is_word_char = |ch: char| ch.is_alphanumeric() || ch == '_';

    let mut result = String::with_capacity(text.len());
    // Byte offset up to which `text` has already been copied into `result`.
    let mut copied_up_to = 0;
    // Byte offset from which the next candidate is searched.
    let mut search_from = 0;

    while let Some(offset) = text[search_from..].find(target) {
        let start = search_from + offset;
        let end = start + target.len();

        let before_ok = text[..start]
            .chars()
            .next_back()
            .map_or(true, |ch| !is_word_char(ch));
        let after_ok = text[end..]
            .chars()
            .next()
            .map_or(true, |ch| !is_word_char(ch));

        if before_ok && after_ok {
            result.push_str(&text[copied_up_to..start]);
            result.push_str(replacement);
            copied_up_to = end;
            search_from = end;
        } else {
            // Step over one full character (not one byte) to stay on a char boundary.
            let step = text[start..].chars().next().map_or(1, char::len_utf8);
            search_from = start + step;
        }
    }

    result.push_str(&text[copied_up_to..]);
    result
}
205
206fn format_filter_value(value: &Option<serde_json::Value>, operator: &str) -> String {
208 if operator == "isNull" || operator == "isNotNull" {
210 return String::new();
211 }
212
213 let value = match value {
214 Some(v) => v,
215 None => return String::new(),
216 };
217
218 match operator {
219 "in" | "notIn" => {
220 let items = match value {
221 serde_json::Value::Array(arr) => arr.clone(),
222 other => vec![other.clone()],
223 };
224 let formatted: Vec<String> = items
225 .iter()
226 .map(|v| match v {
227 serde_json::Value::String(s) => format!("'{}'", escape_string(s)),
228 serde_json::Value::Number(n) => n.to_string(),
229 serde_json::Value::Bool(b) => b.to_string(),
230 _ => format!("'{}'", v),
231 })
232 .collect();
233 format!("({})", formatted.join(", "))
234 }
235 "between" => {
236 if let serde_json::Value::Array(arr) = value {
237 if arr.len() == 2 {
238 let v1 = format_scalar_value(&arr[0]);
239 let v2 = format_scalar_value(&arr[1]);
240 return format!("{} AND {}", v1, v2);
241 }
242 }
243 String::new()
244 }
245 "contains" => {
246 let s = value.as_str().unwrap_or("");
247 format!("'%{}%'", escape_string(s))
248 }
249 "startsWith" => {
250 let s = value.as_str().unwrap_or("");
251 format!("'{}%'", escape_string(s))
252 }
253 "endsWith" => {
254 let s = value.as_str().unwrap_or("");
255 format!("'%{}'", escape_string(s))
256 }
257 _ => format_scalar_value(value),
258 }
259}
260
261fn format_scalar_value(value: &serde_json::Value) -> String {
263 match value {
264 serde_json::Value::String(s) => format!("'{}'", escape_string(s)),
265 serde_json::Value::Number(n) => n.to_string(),
266 serde_json::Value::Bool(b) => b.to_string(),
267 serde_json::Value::Null => "NULL".to_string(),
268 _ => format!("'{}'", value),
269 }
270}
271
/// Maps a filter operator from the spec to its SQL spelling, or `None` when
/// the operator is not supported (matching is exact and case-sensitive).
///
/// The pattern operators `contains`, `startsWith` and `endsWith` all become
/// `LIKE`; the `%` wildcards are added to the operand when the value is
/// formatted.
fn operator_to_sql(op: &str) -> Option<&'static str> {
    let sql = match op {
        // Unary null tests.
        "isNull" => "IS NULL",
        "isNotNull" => "IS NOT NULL",
        // Set and range membership.
        "in" => "IN",
        "notIn" => "NOT IN",
        "between" => "BETWEEN",
        // Pattern matching.
        "contains" | "startsWith" | "endsWith" => "LIKE",
        // Plain comparisons; `==` is accepted as an alias for `=`.
        "=" | "==" => "=",
        "!=" => "!=",
        "<" => "<",
        ">" => ">",
        "<=" => "<=",
        ">=" => ">=",
        _ => return None,
    };
    Some(sql)
}
290
291fn build_filter_condition(rule: &FilterRule, symbols: &HashMap<String, Symbol>) -> String {
293 let sql_op = match operator_to_sql(&rule.operator) {
294 Some(op) => op,
295 None => return String::new(),
296 };
297
298 let sql_expr = if let Some(sym) = symbols.get(&rule.field) {
300 sym.sql.clone()
301 } else {
302 quote_identifier(&rule.field)
303 };
304
305 if rule.operator == "in" || rule.operator == "notIn" {
307 if let Some(serde_json::Value::Array(arr)) = &rule.value {
308 if arr.is_empty() {
309 return if rule.operator == "in" {
310 "(1=0)".to_string()
311 } else {
312 "(1=1)".to_string()
313 };
314 }
315 }
316 }
317
318 let formatted_value = format_filter_value(&rule.value, &rule.operator);
319
320 if rule.operator == "isNull" || rule.operator == "isNotNull" {
321 return format!("{} {}", sql_expr, sql_op);
322 }
323
324 format!("{} {} {}", sql_expr, sql_op, formatted_value)
325}
326
327fn build_filter_clause(
329 filter: &FilterGroup,
330 symbols: &HashMap<String, Symbol>,
331) -> String {
332 if filter.rules.is_empty() {
333 return String::new();
334 }
335
336 let combinator = filter
337 .combinator
338 .as_deref()
339 .unwrap_or("and")
340 .to_uppercase();
341
342 let conditions: Vec<String> = filter
343 .rules
344 .iter()
345 .map(|rule| build_filter_condition(rule, symbols))
346 .filter(|c| !c.is_empty())
347 .collect();
348
349 if conditions.is_empty() {
350 return String::new();
351 }
352
353 conditions.join(&format!(" {} ", combinator))
354}
355
356fn partition_filters(
359 filter: &FilterGroup,
360 symbols: &HashMap<String, Symbol>,
361) -> (Option<FilterGroup>, Option<FilterGroup>) {
362 let mut where_rules = Vec::new();
363 let mut having_rules = Vec::new();
364
365 for rule in &filter.rules {
366 let symbol = symbols.get(&rule.field);
367 if symbol.is_some_and(|s| s.is_aggregated) {
368 having_rules.push(rule.clone());
369 } else {
370 where_rules.push(rule.clone());
371 }
372 }
373
374 let combinator = filter.combinator.clone();
375
376 let where_filter = if where_rules.is_empty() {
377 None
378 } else {
379 Some(FilterGroup {
380 combinator: combinator.clone(),
381 rules: where_rules,
382 })
383 };
384
385 let having_filter = if having_rules.is_empty() {
386 None
387 } else {
388 Some(FilterGroup {
389 combinator,
390 rules: having_rules,
391 })
392 };
393
394 (where_filter, having_filter)
395}
396
397pub fn build_aggregate_sql(table_name: &str, spec: &AggregateSpec) -> String {
401 let is_passthrough = spec.dimensions.is_empty() && spec.measures.is_empty();
402
403 if is_passthrough {
404 let mut sql = format!("SELECT * FROM {}", table_name);
406
407 if let Some(ref filters) = spec.filters {
408 let symbols = HashMap::new();
409 let clause = build_filter_clause(filters, &symbols);
410 if !clause.is_empty() {
411 sql.push_str(&format!("\nWHERE {}", clause));
412 }
413 }
414
415 if let Some(ref sorts) = spec.sort {
416 if !sorts.is_empty() {
417 let order_clauses: Vec<String> = sorts
418 .iter()
419 .map(|s| {
420 let dir = s
421 .direction
422 .as_deref()
423 .unwrap_or("ASC")
424 .to_uppercase();
425 format!("{} {}", quote_identifier(&s.field), dir)
426 })
427 .collect();
428 sql.push_str(&format!("\nORDER BY {}", order_clauses.join(", ")));
429 }
430 }
431
432 if let Some(limit) = spec.limit {
433 sql.push_str(&format!("\nLIMIT {}", limit));
434 }
435
436 return sql;
437 }
438
439 let symbols = build_symbol_table(&spec.dimensions, &spec.measures);
440 let has_aggregation = spec
441 .measures
442 .iter()
443 .any(|m| m.aggregation.is_some() || m.expression.is_some());
444
445 let mut select_cols = Vec::new();
447 let mut group_by_cols = Vec::new();
448
449 for dim in &spec.dimensions {
450 match dim {
451 Dimension::Simple(name) => {
452 let quoted = quote_identifier(name);
453 select_cols.push(quoted.clone());
454 if has_aggregation {
455 group_by_cols.push(quoted);
456 }
457 }
458 Dimension::Detailed(dspec) => {
459 let field_name = dspec
460 .name
461 .clone()
462 .unwrap_or_else(|| dspec.column.clone());
463 let sql_expr = if dspec.column.contains('(') {
464 dspec.column.clone()
465 } else {
466 quote_identifier(&dspec.column)
467 };
468
469 if field_name == dspec.column && !dspec.column.contains('(') {
470 select_cols.push(quote_identifier(&field_name));
471 } else {
472 select_cols.push(format!(
473 "{} as {}",
474 sql_expr,
475 quote_identifier(&field_name)
476 ));
477 }
478
479 if has_aggregation {
480 group_by_cols.push(sql_expr);
481 }
482 }
483 }
484 }
485
486 for measure in &spec.measures {
488 if let Some(symbol) = symbols.get(&measure.name) {
489 if measure.name == symbol.sql {
490 select_cols.push(quote_identifier(&measure.name));
491 } else {
492 select_cols.push(format!(
493 "{} as {}",
494 symbol.sql,
495 quote_identifier(&measure.name)
496 ));
497 }
498 }
499 }
500
501 let mut where_clause = String::new();
503 let mut having_clause = String::new();
504
505 if let Some(ref filters) = spec.filters {
506 let (where_filter, having_filter) = partition_filters(filters, &symbols);
507
508 if let Some(ref wf) = where_filter {
509 let clause = build_filter_clause(wf, &symbols);
510 if !clause.is_empty() {
511 where_clause = format!("\nWHERE {}", clause);
512 }
513 }
514
515 if let Some(ref hf) = having_filter {
516 let clause = build_filter_clause(hf, &symbols);
517 if !clause.is_empty() {
518 having_clause = format!("\nHAVING {}", clause);
519 }
520 }
521 }
522
523 let group_by_clause = if has_aggregation && !group_by_cols.is_empty() {
525 format!("\nGROUP BY {}", group_by_cols.join(", "))
526 } else {
527 String::new()
528 };
529
530 let order_by_clause = if let Some(ref sorts) = spec.sort {
532 if sorts.is_empty() {
533 String::new()
534 } else {
535 let clauses: Vec<String> = sorts
536 .iter()
537 .map(|s| {
538 let dir = s
539 .direction
540 .as_deref()
541 .unwrap_or("ASC")
542 .to_uppercase();
543 format!("{} {}", quote_identifier(&s.field), dir)
544 })
545 .collect();
546 format!("\nORDER BY {}", clauses.join(", "))
547 }
548 } else {
549 String::new()
550 };
551
552 let limit_clause = if let Some(limit) = spec.limit {
554 format!("\nLIMIT {}", limit)
555 } else {
556 String::new()
557 };
558
559 let select_str = select_cols.join(",\n ");
560 format!(
561 "SELECT\n {}\nFROM {}{}{}{}{}{}",
562 select_str,
563 table_name,
564 where_clause,
565 group_by_clause,
566 having_clause,
567 order_by_clause,
568 limit_clause
569 )
570}
571
// Unit tests for `build_aggregate_sql` and `quote_identifier`. Most assertions
// use `contains` on the generated SQL, so they pin individual fragments
// (select items, clauses) rather than the exact whitespace layout.
#[cfg(test)]
mod tests {
    use super::*;
    use chartml_core::spec::*;

    // One simple dimension plus one SUM measure: expect an aliased aggregate
    // and a GROUP BY on the dimension.
    #[test]
    fn test_aggregate_sql_basic() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(sql.contains("SELECT"), "SQL: {}", sql);
        assert!(sql.contains("\"region\""), "SQL: {}", sql);
        assert!(sql.contains("SUM(\"revenue\") as \"total_revenue\""), "SQL: {}", sql);
        assert!(sql.contains("FROM source"), "SQL: {}", sql);
        assert!(sql.contains("GROUP BY \"region\""), "SQL: {}", sql);
    }

    // A filter on a plain column must land in WHERE, while a filter on an
    // aggregated measure must land in HAVING with the measure's SQL inlined.
    #[test]
    fn test_aggregate_sql_with_filters() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: Some(FilterGroup {
                combinator: None,
                rules: vec![
                    FilterRule {
                        field: "category".to_string(),
                        operator: "=".to_string(),
                        value: Some(serde_json::json!("Electronics")),
                    },
                    FilterRule {
                        field: "total_revenue".to_string(),
                        operator: ">=".to_string(),
                        value: Some(serde_json::json!(50000)),
                    },
                ],
            }),
            sort: None,
            limit: None,
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(sql.contains("WHERE"), "SQL should have WHERE: {}", sql);
        // Non-aggregated field: quoted column compared to an escaped string literal.
        assert!(
            sql.contains("\"category\" = 'Electronics'"),
            "WHERE should filter category: {}",
            sql
        );
        assert!(sql.contains("HAVING"), "SQL should have HAVING: {}", sql);
        // Aggregated field: the measure name is replaced by its aggregate SQL.
        assert!(
            sql.contains("SUM(\"revenue\") >= 50000"),
            "HAVING should filter total_revenue: {}",
            sql
        );
    }

    // A calculated measure referencing two aggregated measures should have
    // both references inlined and the whole expression parenthesised.
    #[test]
    fn test_aggregate_sql_with_expressions() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![
                Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_revenue".to_string(),
                    expression: None,
                },
                Measure {
                    column: Some("units".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_units".to_string(),
                    expression: None,
                },
                Measure {
                    column: None,
                    aggregation: None,
                    name: "avg_price".to_string(),
                    expression: Some("total_revenue / total_units".to_string()),
                },
            ],
            filters: None,
            sort: None,
            limit: None,
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(sql.contains("SUM(\"revenue\") as \"total_revenue\""), "SQL: {}", sql);
        assert!(sql.contains("SUM(\"units\") as \"total_units\""), "SQL: {}", sql);
        // Expression measure: field names replaced by their aggregate SQL.
        assert!(
            sql.contains("(SUM(\"revenue\") / SUM(\"units\")) as \"avg_price\""),
            "Expression measure should be inlined: {}",
            sql
        );
    }

    // No dimensions and no measures: pass-through SELECT * that still honours
    // sort and limit.
    #[test]
    fn test_aggregate_sql_passthrough() {
        let spec = AggregateSpec {
            dimensions: vec![],
            measures: vec![],
            filters: None,
            sort: Some(vec![SortSpec {
                field: "name".to_string(),
                direction: Some("asc".to_string()),
            }]),
            limit: Some(10),
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(sql.contains("SELECT * FROM source"), "SQL: {}", sql);
        assert!(sql.contains("ORDER BY"), "SQL: {}", sql);
        assert!(sql.contains("LIMIT 10"), "SQL: {}", sql);
    }

    // Plain names are quoted; expressions (containing `(` or `*`) pass through.
    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("region"), "\"region\"");
        assert_eq!(
            quote_identifier("DATE_TRUNC(sale_date, 'MONTH')"),
            "DATE_TRUNC(sale_date, 'MONTH')"
        );
        assert_eq!(quote_identifier("*"), "*");
        assert_eq!(quote_identifier(""), "\"\"");
    }

    // Sorting by a measure uses its quoted output alias; direction is upper-cased.
    #[test]
    fn test_aggregate_sql_sort_and_limit() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: Some(vec![SortSpec {
                field: "total_revenue".to_string(),
                direction: Some("desc".to_string()),
            }]),
            limit: Some(5),
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(sql.contains("ORDER BY \"total_revenue\" DESC"), "SQL: {}", sql);
        assert!(sql.contains("LIMIT 5"), "SQL: {}", sql);
    }

    // `countdistinct` maps to COUNT(DISTINCT ...).
    #[test]
    fn test_aggregate_sql_count_distinct() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("customer_id".to_string()),
                aggregation: Some("countdistinct".to_string()),
                name: "unique_customers".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(
            sql.contains("COUNT(DISTINCT \"customer_id\") as \"unique_customers\""),
            "SQL: {}",
            sql
        );
    }

    // A detailed dimension whose column is a SQL expression is selected with
    // an alias and grouped by the raw expression.
    #[test]
    fn test_aggregate_sql_detailed_dimension() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Detailed(DimensionSpec {
                column: "DATE_TRUNC(sale_date, 'MONTH')".to_string(),
                name: Some("month".to_string()),
                dim_type: None,
            })],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let sql = build_aggregate_sql("source", &spec);
        assert!(
            sql.contains("DATE_TRUNC(sale_date, 'MONTH') as \"month\""),
            "SQL: {}",
            sql
        );
        assert!(
            sql.contains("GROUP BY DATE_TRUNC(sale_date, 'MONTH')"),
            "SQL: {}",
            sql
        );
    }
}