clickhouse_datafusion/
dialect.rs

1//! A custom [`UnparserDialect`] for `ClickHouse`.
2use datafusion::common::{plan_datafusion_err, plan_err};
3use datafusion::error::Result;
4use datafusion::prelude::*;
5use datafusion::scalar::ScalarValue;
6use datafusion::sql::sqlparser::parser::Parser;
7use datafusion::sql::sqlparser::tokenizer::Tokenizer;
8use datafusion::sql::sqlparser::{ast, dialect};
9use datafusion::sql::unparser::Unparser;
10use datafusion::sql::unparser::dialect::Dialect as UnparserDialect;
11
12use crate::udfs::apply::ClickHouseApplyRewriter;
13use crate::udfs::clickhouse::CLICKHOUSE_UDF_ALIASES;
14use crate::udfs::eval::CLICKHOUSE_EVAL_UDF_ALIASES;
15
/// Unparser dialect that emits `ClickHouse`-flavoured SQL.
///
/// This is a stateless unit type; all of its behavior comes from its [`UnparserDialect`]
/// implementation.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ClickHouseDialect;
19
20impl UnparserDialect for ClickHouseDialect {
21    fn identifier_quote_style(&self, _: &str) -> Option<char> { Some('`') }
22
23    fn scalar_function_to_sql_overrides(
24        &self,
25        unparser: &Unparser<'_>,
26        func_name: &str,
27        args: &[Expr],
28    ) -> Result<Option<ast::Expr>> {
29        // First check for `clickhouse`/lambda UDFs
30        if CLICKHOUSE_UDF_ALIASES.contains(&func_name) {
31            let Some(inner_expr) = args.first() else {
32                return plan_err!("`clickhouse` expects a first argument, no arg provided");
33            };
34
35            // If the inner expression is a "lambda" or "HOF", attempt to rewrite it
36            if let Ok(rewriter) = ClickHouseApplyRewriter::try_new(inner_expr) {
37                rewriter.rewrite_to_ast(unparser).map(Some)
38            } else {
39                unparser.expr_to_sql(inner_expr).map(Some)
40            }
41
42        // Then check for eval UDFs
43        } else if CLICKHOUSE_EVAL_UDF_ALIASES.contains(&func_name) {
44            if let Some(Expr::Literal(
45                ScalarValue::Utf8(Some(s))
46                | ScalarValue::Utf8View(Some(s))
47                | ScalarValue::LargeUtf8(Some(s)),
48                _,
49            )) = args.first()
50            {
51                if s.is_empty() {
52                    return plan_err!("`clickhouse_eval` syntax argument cannot be empty");
53                }
54
55                // Tokenize the string with ClickHouseDialect
56                let mut tokenizer = Tokenizer::new(&dialect::ClickHouseDialect {}, s);
57                let tokens = tokenizer.tokenize().map_err(|e| {
58                    plan_datafusion_err!("Failed to tokenize ClickHouse expression '{s}': {e}")
59                })?;
60                // Create a Parser instance
61                let mut parser = Parser::new(&dialect::ClickHouseDialect {}).with_tokens(tokens);
62                Ok(Some(parser.parse_expr().map_err(|e| {
63                    plan_datafusion_err!("Invalid ClickHouse expression '{s}': {e}")
64                })?))
65            } else {
66                plan_err!(
67                    "`clickhouse_eval` expects a string literal syntax argument, found: {:?}",
68                    args.first()
69                )
70            }
71
72        // No relevant functions
73        } else {
74            Ok(None)
75        }
76    }
77}
78
#[cfg(test)]
mod tests {
    use datafusion::scalar::ScalarValue;
    use datafusion::sql::unparser::Unparser;

    use super::*;

    /// Runs the dialect override for `func_name` with `args` against a fresh unparser.
    fn overrides(func_name: &str, args: &[Expr]) -> Result<Option<ast::Expr>> {
        let dialect = ClickHouseDialect;
        let unparser = Unparser::new(&dialect);
        dialect.scalar_function_to_sql_overrides(&unparser, func_name, args)
    }

    /// Asserts that `clickhouse_eval(args)` fails with a message containing `needle`.
    fn assert_eval_error(args: &[Expr], needle: &str) {
        let err = overrides("clickhouse_eval", args).unwrap_err().to_string();
        assert!(err.contains(needle), "expected error containing {needle:?}, got: {err}");
    }

    #[test]
    fn test_identifier_quote_style() {
        let dialect = ClickHouseDialect;
        assert_eq!(dialect.identifier_quote_style("test"), Some('`'));
        assert_eq!(dialect.identifier_quote_style(""), Some('`'));
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval() {
        // Test valid clickhouse_eval function with string literal
        let args = vec![Expr::Literal(ScalarValue::Utf8(Some("count()".to_string())), None)];
        let result = overrides("clickhouse_eval", &args);
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_utf8view() {
        // Test with Utf8View
        let args = vec![Expr::Literal(ScalarValue::Utf8View(Some("sum(x)".to_string())), None)];
        let result = overrides("clickhouse_eval", &args);
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_large_utf8() {
        // Test with LargeUtf8
        let args = vec![Expr::Literal(ScalarValue::LargeUtf8(Some("avg(y)".to_string())), None)];
        let result = overrides("clickhouse_eval", &args);
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_empty_string() {
        // Test empty string should return error
        let args = vec![Expr::Literal(ScalarValue::Utf8(Some(String::new())), None)];
        assert_eval_error(&args, "cannot be empty");
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_invalid_arg() {
        // Test non-string literal should return error
        let args = vec![Expr::Literal(ScalarValue::Int32(Some(42)), None)];
        assert_eval_error(&args, "expects a string literal");
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_null_literal() {
        // A NULL string literal carries no syntax and must be rejected like a non-string
        let args = vec![Expr::Literal(ScalarValue::Utf8(None), None)];
        assert_eval_error(&args, "expects a string literal");
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_missing_arg() {
        // Calling without any argument must be an error, not a panic
        assert_eval_error(&[], "expects a string literal");
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_clickhouse_eval_invalid_syntax() {
        // Test invalid ClickHouse syntax - should actually fail parsing
        let args = vec![Expr::Literal(ScalarValue::Utf8(Some("invalid(((".to_string())), None)];
        assert_eval_error(&args, "Invalid ClickHouse expression");
    }

    #[test]
    fn test_scalar_function_to_sql_overrides_unknown_function() {
        // Test unknown function should return None
        let args = vec![Expr::Literal(ScalarValue::Utf8(Some("test".to_string())), None)];
        let result = overrides("unknown_func", &args);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    // The derived impls are exercised on purpose here, so the lints that would normally
    // suggest the simpler spelling are silenced.
    #[allow(clippy::default_constructed_unit_structs, clippy::clone_on_copy)]
    fn test_clickhouse_dialect_debug_clone_default() {
        // Test Debug trait
        let debug_str = format!("{ClickHouseDialect:?}");
        assert_eq!(debug_str, "ClickHouseDialect");

        // Test Default, Clone and Copy; all must yield an equal value
        let dialect = ClickHouseDialect::default();
        let cloned = dialect.clone();
        let copied = dialect;
        assert_eq!(dialect, ClickHouseDialect);
        assert_eq!(cloned, dialect);
        assert_eq!(copied, dialect);
    }
}