Skip to main content

polyglot_sql/dialects/
risingwave.rs

1//! RisingWave Dialect
2//!
3//! RisingWave-specific transformations based on sqlglot patterns.
4//! RisingWave is PostgreSQL-compatible with streaming SQL extensions.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// RisingWave dialect (PostgreSQL-compatible streaming database).
///
/// This is a field-less marker type: it carries no state, and all of its
/// behaviour is provided by its `DialectImpl` implementation and the private
/// transformation helpers defined in this module.
pub struct RisingWaveDialect;
14
15impl DialectImpl for RisingWaveDialect {
16    fn dialect_type(&self) -> DialectType {
17        DialectType::RisingWave
18    }
19
20    fn tokenizer_config(&self) -> TokenizerConfig {
21        let mut config = TokenizerConfig::default();
22        // RisingWave uses double quotes for identifiers (PostgreSQL-style)
23        config.identifiers.insert('"', '"');
24        // PostgreSQL-style nested comments supported
25        config.nested_comments = true;
26        config
27    }
28
29    fn generator_config(&self) -> GeneratorConfig {
30        use crate::generator::IdentifierQuoteStyle;
31        GeneratorConfig {
32            identifier_quote: '"',
33            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
34            dialect: Some(DialectType::RisingWave),
35            ..Default::default()
36        }
37    }
38
39    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
40        match expr {
41            // IFNULL -> COALESCE in RisingWave
42            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
43                expressions: vec![f.this, f.expression],
44            }))),
45
46            // NVL -> COALESCE in RisingWave
47            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
48                expressions: vec![f.this, f.expression],
49            }))),
50
51            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
52            Expression::Coalesce(mut f) => {
53                f.original_name = None;
54                Ok(Expression::Coalesce(f))
55            }
56
57            // TryCast -> not directly supported, use CAST
58            Expression::TryCast(c) => Ok(Expression::Cast(c)),
59
60            // SafeCast -> CAST in RisingWave
61            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
62
63            // ILIKE is native in RisingWave (PostgreSQL-style)
64            Expression::ILike(op) => Ok(Expression::ILike(op)),
65
66            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
67            Expression::CountIf(f) => {
68                let case_expr = Expression::Case(Box::new(Case {
69                    operand: None,
70                    whens: vec![(f.this.clone(), Expression::number(1))],
71                    else_: Some(Expression::number(0)),
72                }));
73                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
74                    this: case_expr,
75                    distinct: f.distinct,
76                    filter: f.filter,
77                    order_by: Vec::new(),
78                name: None,
79                limit: None,
80                })))
81            }
82
83            // RAND -> RANDOM in RisingWave (PostgreSQL-style)
84            Expression::Rand(r) => {
85                let _ = r.seed;
86                Ok(Expression::Random(crate::expressions::Random))
87            }
88
89            // Generic function transformations
90            Expression::Function(f) => self.transform_function(*f),
91
92            // Generic aggregate function transformations
93            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
94
95            // Cast transformations
96            Expression::Cast(c) => self.transform_cast(*c),
97
98            // Pass through everything else
99            _ => Ok(expr),
100        }
101    }
102}
103
104impl RisingWaveDialect {
105    fn transform_function(&self, f: Function) -> Result<Expression> {
106        let name_upper = f.name.to_uppercase();
107        match name_upper.as_str() {
108            // IFNULL -> COALESCE
109            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
110                expressions: f.args,
111            }))),
112
113            // NVL -> COALESCE
114            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
115                expressions: f.args,
116            }))),
117
118            // ISNULL -> COALESCE
119            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
120                expressions: f.args,
121            }))),
122
123            // NOW is native in RisingWave
124            "NOW" => Ok(Expression::CurrentTimestamp(
125                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
126            )),
127
128            // GETDATE -> NOW
129            "GETDATE" => Ok(Expression::CurrentTimestamp(
130                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
131            )),
132
133            // RAND -> RANDOM
134            "RAND" => Ok(Expression::Random(crate::expressions::Random)),
135
136            // STRING_AGG is native in RisingWave (PostgreSQL-style)
137            "STRING_AGG" => Ok(Expression::Function(Box::new(f))),
138
139            // GROUP_CONCAT -> STRING_AGG
140            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
141                Function::new("STRING_AGG".to_string(), f.args),
142            ))),
143
144            // LISTAGG -> STRING_AGG
145            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
146                "STRING_AGG".to_string(),
147                f.args,
148            )))),
149
150            // SUBSTR -> SUBSTRING
151            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
152                "SUBSTRING".to_string(),
153                f.args,
154            )))),
155
156            // LENGTH is native in RisingWave
157            "LENGTH" => Ok(Expression::Function(Box::new(f))),
158
159            // LEN -> LENGTH
160            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
161                "LENGTH".to_string(),
162                f.args,
163            )))),
164
165            // CHARINDEX -> STRPOS (with swapped args)
166            "CHARINDEX" if f.args.len() >= 2 => {
167                let mut args = f.args;
168                let substring = args.remove(0);
169                let string = args.remove(0);
170                Ok(Expression::Function(Box::new(Function::new(
171                    "STRPOS".to_string(),
172                    vec![string, substring],
173                ))))
174            }
175
176            // INSTR -> STRPOS
177            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
178                "STRPOS".to_string(),
179                f.args,
180            )))),
181
182            // LOCATE -> STRPOS (with swapped args)
183            "LOCATE" if f.args.len() >= 2 => {
184                let mut args = f.args;
185                let substring = args.remove(0);
186                let string = args.remove(0);
187                Ok(Expression::Function(Box::new(Function::new(
188                    "STRPOS".to_string(),
189                    vec![string, substring],
190                ))))
191            }
192
193            // STRPOS is native in RisingWave
194            "STRPOS" => Ok(Expression::Function(Box::new(f))),
195
196            // ARRAY_LENGTH is native in RisingWave
197            "ARRAY_LENGTH" => Ok(Expression::Function(Box::new(f))),
198
199            // SIZE -> ARRAY_LENGTH
200            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
201                "ARRAY_LENGTH".to_string(),
202                f.args,
203            )))),
204
205            // TO_CHAR is native in RisingWave
206            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
207
208            // DATE_FORMAT -> TO_CHAR
209            "DATE_FORMAT" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
210                "TO_CHAR".to_string(),
211                f.args,
212            )))),
213
214            // strftime -> TO_CHAR
215            "STRFTIME" if f.args.len() >= 2 => {
216                let mut args = f.args;
217                let format = args.remove(0);
218                let date = args.remove(0);
219                Ok(Expression::Function(Box::new(Function::new(
220                    "TO_CHAR".to_string(),
221                    vec![date, format],
222                ))))
223            }
224
225            // JSON_EXTRACT_PATH_TEXT is native in RisingWave
226            "JSON_EXTRACT_PATH_TEXT" => Ok(Expression::Function(Box::new(f))),
227
228            // GET_JSON_OBJECT -> JSON_EXTRACT_PATH_TEXT
229            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
230                Function::new("JSON_EXTRACT_PATH_TEXT".to_string(), f.args),
231            ))),
232
233            // JSON_EXTRACT -> JSON_EXTRACT_PATH_TEXT
234            "JSON_EXTRACT" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
235                Function::new("JSON_EXTRACT_PATH_TEXT".to_string(), f.args),
236            ))),
237
238            // Pass through everything else
239            _ => Ok(Expression::Function(Box::new(f))),
240        }
241    }
242
243    fn transform_aggregate_function(
244        &self,
245        f: Box<crate::expressions::AggregateFunction>,
246    ) -> Result<Expression> {
247        let name_upper = f.name.to_uppercase();
248        match name_upper.as_str() {
249            // COUNT_IF -> SUM(CASE WHEN...)
250            "COUNT_IF" if !f.args.is_empty() => {
251                let condition = f.args.into_iter().next().unwrap();
252                let case_expr = Expression::Case(Box::new(Case {
253                    operand: None,
254                    whens: vec![(condition, Expression::number(1))],
255                    else_: Some(Expression::number(0)),
256                }));
257                Ok(Expression::Sum(Box::new(AggFunc { ignore_nulls: None, having_max: None,
258                    this: case_expr,
259                    distinct: f.distinct,
260                    filter: f.filter,
261                    order_by: Vec::new(),
262                name: None,
263                limit: None,
264                })))
265            }
266
267            // Pass through everything else
268            _ => Ok(Expression::AggregateFunction(f)),
269        }
270    }
271
272    fn transform_cast(&self, c: Cast) -> Result<Expression> {
273        // RisingWave type mappings are handled in the generator
274        Ok(Expression::Cast(Box::new(c)))
275    }
276}