Skip to main content

polyglot_sql/dialects/
risingwave.rs

1//! RisingWave Dialect
2//!
3//! RisingWave-specific transformations based on sqlglot patterns.
4//! RisingWave is PostgreSQL-compatible with streaming SQL extensions.
5
6use super::{DialectImpl, DialectType};
7use crate::error::Result;
8use crate::expressions::{AggFunc, Case, Cast, Expression, Function, VarArgFunc};
9use crate::generator::GeneratorConfig;
10use crate::tokens::TokenizerConfig;
11
/// RisingWave dialect (PostgreSQL-compatible streaming database).
///
/// A stateless unit struct: all dialect behaviour lives in its [`DialectImpl`]
/// implementation (tokenizer/generator configuration and expression rewrites).
pub struct RisingWaveDialect;
14
15impl DialectImpl for RisingWaveDialect {
16    fn dialect_type(&self) -> DialectType {
17        DialectType::RisingWave
18    }
19
20    fn tokenizer_config(&self) -> TokenizerConfig {
21        let mut config = TokenizerConfig::default();
22        // RisingWave uses double quotes for identifiers (PostgreSQL-style)
23        config.identifiers.insert('"', '"');
24        // PostgreSQL-style nested comments supported
25        config.nested_comments = true;
26        config
27    }
28
29    fn generator_config(&self) -> GeneratorConfig {
30        use crate::generator::IdentifierQuoteStyle;
31        GeneratorConfig {
32            identifier_quote: '"',
33            identifier_quote_style: IdentifierQuoteStyle::DOUBLE_QUOTE,
34            dialect: Some(DialectType::RisingWave),
35            ..Default::default()
36        }
37    }
38
39    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
40        match expr {
41            // IFNULL -> COALESCE in RisingWave
42            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
43                original_name: None,
44                expressions: vec![f.this, f.expression],
45            }))),
46
47            // NVL -> COALESCE in RisingWave
48            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
49                original_name: None,
50                expressions: vec![f.this, f.expression],
51            }))),
52
53            // Coalesce with original_name (e.g., IFNULL parsed as Coalesce) -> clear original_name
54            Expression::Coalesce(mut f) => {
55                f.original_name = None;
56                Ok(Expression::Coalesce(f))
57            }
58
59            // TryCast -> not directly supported, use CAST
60            Expression::TryCast(c) => Ok(Expression::Cast(c)),
61
62            // SafeCast -> CAST in RisingWave
63            Expression::SafeCast(c) => Ok(Expression::Cast(c)),
64
65            // ILIKE is native in RisingWave (PostgreSQL-style)
66            Expression::ILike(op) => Ok(Expression::ILike(op)),
67
68            // CountIf -> SUM(CASE WHEN condition THEN 1 ELSE 0 END)
69            Expression::CountIf(f) => {
70                let case_expr = Expression::Case(Box::new(Case {
71                    operand: None,
72                    whens: vec![(f.this.clone(), Expression::number(1))],
73                    else_: Some(Expression::number(0)),
74                    comments: Vec::new(),
75                }));
76                Ok(Expression::Sum(Box::new(AggFunc {
77                    ignore_nulls: None,
78                    having_max: None,
79                    this: case_expr,
80                    distinct: f.distinct,
81                    filter: f.filter,
82                    order_by: Vec::new(),
83                    name: None,
84                    limit: None,
85                })))
86            }
87
88            // RAND -> RANDOM in RisingWave (PostgreSQL-style)
89            Expression::Rand(r) => {
90                let _ = r.seed;
91                Ok(Expression::Random(crate::expressions::Random))
92            }
93
94            // Generic function transformations
95            Expression::Function(f) => self.transform_function(*f),
96
97            // Generic aggregate function transformations
98            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
99
100            // Cast transformations
101            Expression::Cast(c) => self.transform_cast(*c),
102
103            // Pass through everything else
104            _ => Ok(expr),
105        }
106    }
107}
108
109impl RisingWaveDialect {
110    fn transform_function(&self, f: Function) -> Result<Expression> {
111        let name_upper = f.name.to_uppercase();
112        match name_upper.as_str() {
113            // IFNULL -> COALESCE
114            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
115                original_name: None,
116                expressions: f.args,
117            }))),
118
119            // NVL -> COALESCE
120            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
121                original_name: None,
122                expressions: f.args,
123            }))),
124
125            // ISNULL -> COALESCE
126            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
127                original_name: None,
128                expressions: f.args,
129            }))),
130
131            // NOW is native in RisingWave
132            "NOW" => Ok(Expression::CurrentTimestamp(
133                crate::expressions::CurrentTimestamp {
134                    precision: None,
135                    sysdate: false,
136                },
137            )),
138
139            // GETDATE -> NOW
140            "GETDATE" => Ok(Expression::CurrentTimestamp(
141                crate::expressions::CurrentTimestamp {
142                    precision: None,
143                    sysdate: false,
144                },
145            )),
146
147            // RAND -> RANDOM
148            "RAND" => Ok(Expression::Random(crate::expressions::Random)),
149
150            // STRING_AGG is native in RisingWave (PostgreSQL-style)
151            "STRING_AGG" => Ok(Expression::Function(Box::new(f))),
152
153            // GROUP_CONCAT -> STRING_AGG
154            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
155                Function::new("STRING_AGG".to_string(), f.args),
156            ))),
157
158            // LISTAGG -> STRING_AGG
159            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
160                "STRING_AGG".to_string(),
161                f.args,
162            )))),
163
164            // SUBSTR -> SUBSTRING
165            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
166                "SUBSTRING".to_string(),
167                f.args,
168            )))),
169
170            // LENGTH is native in RisingWave
171            "LENGTH" => Ok(Expression::Function(Box::new(f))),
172
173            // LEN -> LENGTH
174            "LEN" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
175                "LENGTH".to_string(),
176                f.args,
177            )))),
178
179            // CHARINDEX -> STRPOS (with swapped args)
180            "CHARINDEX" if f.args.len() >= 2 => {
181                let mut args = f.args;
182                let substring = args.remove(0);
183                let string = args.remove(0);
184                Ok(Expression::Function(Box::new(Function::new(
185                    "STRPOS".to_string(),
186                    vec![string, substring],
187                ))))
188            }
189
190            // INSTR -> STRPOS
191            "INSTR" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
192                "STRPOS".to_string(),
193                f.args,
194            )))),
195
196            // LOCATE -> STRPOS (with swapped args)
197            "LOCATE" if f.args.len() >= 2 => {
198                let mut args = f.args;
199                let substring = args.remove(0);
200                let string = args.remove(0);
201                Ok(Expression::Function(Box::new(Function::new(
202                    "STRPOS".to_string(),
203                    vec![string, substring],
204                ))))
205            }
206
207            // STRPOS is native in RisingWave
208            "STRPOS" => Ok(Expression::Function(Box::new(f))),
209
210            // ARRAY_LENGTH is native in RisingWave
211            "ARRAY_LENGTH" => Ok(Expression::Function(Box::new(f))),
212
213            // SIZE -> ARRAY_LENGTH
214            "SIZE" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
215                "ARRAY_LENGTH".to_string(),
216                f.args,
217            )))),
218
219            // TO_CHAR is native in RisingWave
220            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
221
222            // DATE_FORMAT -> TO_CHAR
223            "DATE_FORMAT" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
224                Function::new("TO_CHAR".to_string(), f.args),
225            ))),
226
227            // strftime -> TO_CHAR
228            "STRFTIME" if f.args.len() >= 2 => {
229                let mut args = f.args;
230                let format = args.remove(0);
231                let date = args.remove(0);
232                Ok(Expression::Function(Box::new(Function::new(
233                    "TO_CHAR".to_string(),
234                    vec![date, format],
235                ))))
236            }
237
238            // JSON_EXTRACT_PATH_TEXT is native in RisingWave
239            "JSON_EXTRACT_PATH_TEXT" => Ok(Expression::Function(Box::new(f))),
240
241            // GET_JSON_OBJECT -> JSON_EXTRACT_PATH_TEXT
242            "GET_JSON_OBJECT" if f.args.len() == 2 => Ok(Expression::Function(Box::new(
243                Function::new("JSON_EXTRACT_PATH_TEXT".to_string(), f.args),
244            ))),
245
246            // JSON_EXTRACT -> JSON_EXTRACT_PATH_TEXT
247            "JSON_EXTRACT" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(
248                Function::new("JSON_EXTRACT_PATH_TEXT".to_string(), f.args),
249            ))),
250
251            // Pass through everything else
252            _ => Ok(Expression::Function(Box::new(f))),
253        }
254    }
255
256    fn transform_aggregate_function(
257        &self,
258        f: Box<crate::expressions::AggregateFunction>,
259    ) -> Result<Expression> {
260        let name_upper = f.name.to_uppercase();
261        match name_upper.as_str() {
262            // COUNT_IF -> SUM(CASE WHEN...)
263            "COUNT_IF" if !f.args.is_empty() => {
264                let condition = f.args.into_iter().next().unwrap();
265                let case_expr = Expression::Case(Box::new(Case {
266                    operand: None,
267                    whens: vec![(condition, Expression::number(1))],
268                    else_: Some(Expression::number(0)),
269                    comments: Vec::new(),
270                }));
271                Ok(Expression::Sum(Box::new(AggFunc {
272                    ignore_nulls: None,
273                    having_max: None,
274                    this: case_expr,
275                    distinct: f.distinct,
276                    filter: f.filter,
277                    order_by: Vec::new(),
278                    name: None,
279                    limit: None,
280                })))
281            }
282
283            // Pass through everything else
284            _ => Ok(Expression::AggregateFunction(f)),
285        }
286    }
287
288    fn transform_cast(&self, c: Cast) -> Result<Expression> {
289        // RisingWave type mappings are handled in the generator
290        Ok(Expression::Cast(Box::new(c)))
291    }
292}