1use rhei_core::types::QueryTarget;
30use rhei_core::QueryRouter;
31use sqlparser::ast::{
32 Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
33};
34use sqlparser::dialect::SQLiteDialect;
35use sqlparser::parser::Parser;
36use tracing::debug;
37
/// Query router that classifies SQL by parsing it into an AST.
///
/// The type is a stateless unit struct: all routing logic lives in its
/// `QueryRouter` implementation, so values are free to create, copy and
/// share.
#[derive(Debug, Clone, Copy)]
pub struct SqlParserRouter;
47
48impl SqlParserRouter {
49 pub fn new() -> Self {
51 Self
52 }
53}
54
55impl Default for SqlParserRouter {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl QueryRouter for SqlParserRouter {
62 fn route(&self, sql: &str) -> QueryTarget {
63 let trimmed = sql.trim();
64 if trimmed.is_empty() {
65 return QueryTarget::Oltp;
66 }
67
68 match Parser::parse_sql(&SQLiteDialect {}, trimmed) {
69 Ok(stmts) if !stmts.is_empty() => route_statement(&stmts[0]),
70 Ok(_) => QueryTarget::Oltp,
71 Err(e) => {
72 debug!(error = %e, sql = trimmed, "SQL parse failed, falling back to heuristic");
73 heuristic_route(trimmed)
74 }
75 }
76 }
77}
78
79fn route_statement(stmt: &Statement) -> QueryTarget {
81 match stmt {
82 Statement::Insert(_)
84 | Statement::Update { .. }
85 | Statement::Delete(_)
86 | Statement::CreateTable { .. }
87 | Statement::CreateIndex { .. }
88 | Statement::AlterTable { .. }
89 | Statement::Drop { .. }
90 | Statement::StartTransaction { .. }
91 | Statement::Commit { .. }
92 | Statement::Rollback { .. }
93 | Statement::Savepoint { .. } => QueryTarget::Oltp,
94
95 Statement::Query(query) => route_query(query),
97
98 Statement::ExplainTable { .. } => QueryTarget::Oltp,
100 Statement::Explain { statement, .. } => route_statement(statement),
102
103 _ => QueryTarget::Oltp,
105 }
106}
107
108fn route_query(query: &Query) -> QueryTarget {
110 if query.with.is_some() {
112 return QueryTarget::Olap;
113 }
114
115 match query.body.as_ref() {
117 SetExpr::Select(select) => route_select(select),
118 SetExpr::SetOperation { .. } => QueryTarget::Olap,
119 SetExpr::Query(inner) => route_query(inner),
120 _ => QueryTarget::Oltp,
121 }
122}
123
124fn route_select(select: &Select) -> QueryTarget {
126 let has_group_by = match &select.group_by {
128 GroupByExpr::All(_) => true,
129 GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
130 };
131 if has_group_by || select.having.is_some() {
132 return QueryTarget::Olap;
133 }
134
135 for table in &select.from {
137 if !table.joins.is_empty() {
138 return QueryTarget::Olap;
139 }
140 if matches!(&table.relation, TableFactor::Derived { .. }) {
142 return QueryTarget::Olap;
143 }
144 }
145
146 for item in &select.projection {
148 if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
149 if expr_has_analytical_pattern(expr) {
150 return QueryTarget::Olap;
151 }
152 }
153 }
154
155 if let Some(selection) = &select.selection {
157 if expr_has_subquery(selection) {
158 return QueryTarget::Olap;
159 }
160 }
161
162 QueryTarget::Oltp
164}
165
166fn expr_has_analytical_pattern(expr: &Expr) -> bool {
168 match expr {
169 Expr::Function(func) => {
170 if func.over.is_some() {
172 return true;
173 }
174 let name = func.name.to_string().to_ascii_uppercase();
176 matches!(
177 name.as_str(),
178 "COUNT"
179 | "SUM"
180 | "AVG"
181 | "MIN"
182 | "MAX"
183 | "STDDEV"
184 | "VARIANCE"
185 | "ARRAY_AGG"
186 | "STRING_AGG"
187 | "GROUP_CONCAT"
188 | "MEDIAN"
189 | "PERCENTILE_CONT"
190 | "PERCENTILE_DISC"
191 | "FIRST_VALUE"
192 | "LAST_VALUE"
193 | "NTH_VALUE"
194 | "ROW_NUMBER"
195 | "RANK"
196 | "DENSE_RANK"
197 | "NTILE"
198 | "LAG"
199 | "LEAD"
200 | "CUME_DIST"
201 | "PERCENT_RANK"
202 )
203 }
204 Expr::Nested(inner) => expr_has_analytical_pattern(inner),
205 Expr::BinaryOp { left, right, .. } => {
206 expr_has_analytical_pattern(left) || expr_has_analytical_pattern(right)
207 }
208 Expr::UnaryOp { expr, .. } => expr_has_analytical_pattern(expr),
209 Expr::Cast { expr, .. } => expr_has_analytical_pattern(expr),
210 Expr::Case {
211 operand,
212 conditions,
213 else_result,
214 ..
215 } => {
216 operand
217 .as_ref()
218 .is_some_and(|e| expr_has_analytical_pattern(e))
219 || conditions.iter().any(|cw| {
220 expr_has_analytical_pattern(&cw.condition)
221 || expr_has_analytical_pattern(&cw.result)
222 })
223 || else_result
224 .as_ref()
225 .is_some_and(|e| expr_has_analytical_pattern(e))
226 }
227 Expr::Subquery(q) => matches!(route_query(q), QueryTarget::Olap),
228 Expr::InSubquery { subquery, .. } => matches!(route_query(subquery), QueryTarget::Olap),
229 _ => false,
230 }
231}
232
233fn expr_has_subquery(expr: &Expr) -> bool {
235 match expr {
236 Expr::Subquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } => true,
237 Expr::Nested(inner) => expr_has_subquery(inner),
238 Expr::BinaryOp { left, right, .. } => expr_has_subquery(left) || expr_has_subquery(right),
239 Expr::UnaryOp { expr, .. } => expr_has_subquery(expr),
240 _ => false,
241 }
242}
243
244fn heuristic_route(sql: &str) -> QueryTarget {
251 let trimmed = sql;
252
253 const WRITE_KEYWORDS: &[&str] = &[
254 "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "BEGIN", "COMMIT", "ROLLBACK",
255 "PRAGMA",
256 ];
257 for kw in WRITE_KEYWORDS {
258 if starts_with_ignore_case(trimmed, kw) {
259 return QueryTarget::Oltp;
260 }
261 }
262
263 if starts_with_ignore_case(trimmed, "SELECT") {
264 const AGGREGATE_FNS: &[&str] = &["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("];
265 let has_aggregate = AGGREGATE_FNS
266 .iter()
267 .any(|agg| contains_ignore_case(trimmed, agg));
268 let has_grouping =
269 contains_ignore_case(trimmed, "GROUP BY") || contains_ignore_case(trimmed, "HAVING");
270 let has_window =
271 contains_ignore_case(trimmed, "OVER(") || contains_ignore_case(trimmed, "OVER (");
272 let has_join = contains_ignore_case(trimmed, " JOIN ");
273
274 if has_aggregate || has_grouping || has_window || has_join {
275 return QueryTarget::Olap;
276 }
277 }
278
279 QueryTarget::Oltp
280}
281
/// Returns `true` when `haystack` begins with `needle`, comparing bytes
/// ASCII case-insensitively.
///
/// `needle` must not contain ASCII lowercase letters; this is checked in
/// debug builds only. The comparison works on raw bytes, so a haystack with
/// multi-byte UTF-8 characters cannot cause a slicing panic.
fn starts_with_ignore_case(haystack: &str, needle: &str) -> bool {
    debug_assert!(needle.bytes().all(|b| !b.is_ascii_lowercase()));

    // `get` yields `None` when the haystack is shorter than the needle.
    match haystack.as_bytes().get(..needle.len()) {
        Some(prefix) => prefix
            .iter()
            .zip(needle.bytes())
            .all(|(&have, want)| have.to_ascii_uppercase() == want),
        None => false,
    }
}
290
/// Returns `true` when `needle` occurs anywhere in `haystack`, comparing
/// bytes ASCII case-insensitively.
///
/// `needle` must not contain ASCII lowercase letters; this is checked in
/// debug builds only. An empty needle is contained in every haystack,
/// including the empty one. The comparison works on raw bytes, so multi-byte
/// UTF-8 characters in the haystack are handled without panicking.
fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
    debug_assert!(needle.bytes().all(|b| b == b.to_ascii_uppercase()));

    let needle = needle.as_bytes();
    // `slice::windows` panics on a window size of zero, so answer the
    // empty-needle case up front.
    if needle.is_empty() {
        return true;
    }

    // `windows` yields nothing when the haystack is shorter than the needle,
    // so no separate length check is needed.
    haystack.as_bytes().windows(needle.len()).any(|window| {
        window
            .iter()
            .zip(needle)
            .all(|(h, n)| h.to_ascii_uppercase() == *n)
    })
}
303
/// Router that forwards every routing decision to a [`SqlParserRouter`].
///
/// NOTE(review): presumably kept so existing users of the `HeuristicRouter`
/// name keep working (the test module calls this "backwards compat") —
/// confirm before removing.
pub struct HeuristicRouter {
    /// Parser-based router that makes the actual decision.
    inner: SqlParserRouter,
}
311
312impl HeuristicRouter {
313 pub fn new() -> Self {
315 Self {
316 inner: SqlParserRouter::new(),
317 }
318 }
319}
320
321impl Default for HeuristicRouter {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
327impl QueryRouter for HeuristicRouter {
328 fn route(&self, sql: &str) -> QueryTarget {
329 self.inner.route(sql)
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `router` sends every statement in `queries` to `expected`.
    fn assert_routes<R: QueryRouter>(router: &R, expected: &QueryTarget, queries: &[&str]) {
        for sql in queries {
            assert_eq!(&router.route(sql), expected, "query: {sql}");
        }
    }

    #[test]
    fn test_write_operations_route_to_oltp() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Oltp,
            &[
                "INSERT INTO users VALUES (1, 'Alice')",
                "UPDATE users SET name = 'Bob' WHERE id = 1",
                "DELETE FROM users WHERE id = 1",
                "CREATE TABLE users (id INTEGER)",
                "ALTER TABLE users ADD COLUMN email TEXT",
            ],
        );
    }

    #[test]
    fn test_analytical_queries_route_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Olap,
            &[
                "SELECT COUNT(*) FROM users",
                "SELECT AVG(age) FROM users GROUP BY dept",
                "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id",
            ],
        );
    }

    #[test]
    fn test_simple_selects_route_to_oltp() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Oltp,
            &[
                "SELECT * FROM users WHERE id = 1",
                "SELECT name FROM users LIMIT 10",
            ],
        );
    }

    #[test]
    fn test_window_functions_route_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Olap,
            &[
                "SELECT id, ROW_NUMBER() OVER (ORDER BY id) FROM users",
                "SELECT id, SUM(age) OVER (PARTITION BY dept) FROM users",
            ],
        );
    }

    #[test]
    fn test_subqueries_route_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Olap,
            &[
                "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)",
                "SELECT * FROM (SELECT dept, COUNT(*) cnt FROM users GROUP BY dept) sub",
            ],
        );
    }

    #[test]
    fn test_cte_routes_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Olap,
            &[
                "WITH active AS (SELECT * FROM users WHERE active = true) SELECT COUNT(*) FROM active",
            ],
        );
    }

    #[test]
    fn test_union_routes_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Olap,
            &["SELECT id FROM users UNION ALL SELECT id FROM admins"],
        );
    }

    #[test]
    fn test_string_containing_keywords_not_misrouted() {
        // An aggregate-looking token inside a string literal must not count:
        // the parser sees a plain filtered select.
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Oltp,
            &["SELECT * FROM users WHERE note = 'COUNT(items) is 5'"],
        );
    }

    #[test]
    fn test_backwards_compat_heuristic_router() {
        let router = HeuristicRouter::new();
        assert_routes(&router, &QueryTarget::Olap, &["SELECT COUNT(*) FROM users"]);
        assert_routes(
            &router,
            &QueryTarget::Oltp,
            &["INSERT INTO users VALUES (1, 'Alice')"],
        );
    }

    #[test]
    fn test_pragma_routes_to_oltp() {
        assert_routes(
            &SqlParserRouter::new(),
            &QueryTarget::Oltp,
            &["PRAGMA table_info(users)"],
        );
    }
}