Skip to main content

rhei_sync/
router.rs

1//! AST-based SQL query router that classifies statements as OLTP or OLAP.
2//!
3//! [`SqlParserRouter`] is the primary implementation.  It parses every
4//! statement with `sqlparser-rs` (SQLite dialect) and inspects the resulting
5//! AST to decide which backend should handle the query.
6//!
7//! [`HeuristicRouter`] is a thin compatibility wrapper that delegates to
8//! [`SqlParserRouter`]; new code should use [`SqlParserRouter`] directly.
9//!
10//! ## Routing rules
11//!
12//! | SQL pattern | Target |
13//! |-------------|--------|
14//! | INSERT / UPDATE / DELETE | OLTP |
15//! | DDL (CREATE / ALTER / DROP) | OLTP |
16//! | Transactions (BEGIN / COMMIT / ROLLBACK / SAVEPOINT) | OLTP |
17//! | Simple SELECT (no aggregates, no JOIN, no subquery) | OLTP |
18//! | SELECT with GROUP BY / HAVING | OLAP |
19//! | SELECT with aggregate functions (COUNT, SUM, AVG, …) | OLAP |
20//! | SELECT with window functions (OVER clause) | OLAP |
21//! | SELECT with JOINs | OLAP |
22//! | SELECT with subqueries in WHERE / FROM | OLAP |
23//! | CTEs (WITH …) | OLAP |
24//! | Set operations (UNION / INTERSECT / EXCEPT) | OLAP |
25//! | EXPLAIN wrapping a query (`EXPLAIN ` + inner statement) | Same as inner query |
26//! | EXPLAIN of a table (`EXPLAIN <table>`, SQLite-style) | OLTP |
27//! | Parse failure | Heuristic fallback (default: OLTP) |
28
29use rhei_core::types::QueryTarget;
30use rhei_core::QueryRouter;
31use sqlparser::ast::{
32    Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
33};
34use sqlparser::dialect::SQLiteDialect;
35use sqlparser::parser::Parser;
36use tracing::debug;
37
/// SQL parser-based query router that classifies SQL using a real AST.
///
/// Uses `sqlparser-rs` to parse the SQL into an AST, then inspects the
/// statement type and structure to determine whether it should go to OLTP
/// or OLAP.  Falls back to keyword-based heuristic routing if parsing fails, and
/// defaults to OLTP for unrecognised constructs (safety-first).
///
/// The router is a field-less unit type, so it derives the common traits
/// (`Debug`, `Clone`, `Copy`, `Default`): it can be logged, embedded in other
/// `Debug` types and passed by value freely.
///
/// See the [module-level documentation](self) for the full routing table.
#[derive(Debug, Clone, Copy, Default)]
pub struct SqlParserRouter;

impl SqlParserRouter {
    /// Create a new [`SqlParserRouter`].
    ///
    /// Equivalent to [`SqlParserRouter::default`]; usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
60
61impl QueryRouter for SqlParserRouter {
62    fn route(&self, sql: &str) -> QueryTarget {
63        let trimmed = sql.trim();
64        if trimmed.is_empty() {
65            return QueryTarget::Oltp;
66        }
67
68        match Parser::parse_sql(&SQLiteDialect {}, trimmed) {
69            Ok(stmts) if !stmts.is_empty() => route_statement(&stmts[0]),
70            Ok(_) => QueryTarget::Oltp,
71            Err(e) => {
72                debug!(error = %e, sql = trimmed, "SQL parse failed, falling back to heuristic");
73                heuristic_route(trimmed)
74            }
75        }
76    }
77}
78
79/// Route based on the parsed AST statement.
80fn route_statement(stmt: &Statement) -> QueryTarget {
81    match stmt {
82        // Write operations → OLTP
83        Statement::Insert(_)
84        | Statement::Update { .. }
85        | Statement::Delete(_)
86        | Statement::CreateTable { .. }
87        | Statement::CreateIndex { .. }
88        | Statement::AlterTable { .. }
89        | Statement::Drop { .. }
90        | Statement::StartTransaction { .. }
91        | Statement::Commit { .. }
92        | Statement::Rollback { .. }
93        | Statement::Savepoint { .. } => QueryTarget::Oltp,
94
95        // SELECT queries: inspect for analytical patterns
96        Statement::Query(query) => route_query(query),
97
98        // `EXPLAIN <table>` (SQLite-style table introspection): always OLTP.
99        Statement::ExplainTable { .. } => QueryTarget::Oltp,
100        // `EXPLAIN <statement>`: route like the inner query.
101        Statement::Explain { statement, .. } => route_statement(statement),
102
103        // Default: OLTP for safety
104        _ => QueryTarget::Oltp,
105    }
106}
107
108/// Route a SELECT query based on its structure.
109fn route_query(query: &Query) -> QueryTarget {
110    // CTEs (WITH) → analytical
111    if query.with.is_some() {
112        return QueryTarget::Olap;
113    }
114
115    // UNION / INTERSECT / EXCEPT → analytical
116    match query.body.as_ref() {
117        SetExpr::Select(select) => route_select(select),
118        SetExpr::SetOperation { .. } => QueryTarget::Olap,
119        SetExpr::Query(inner) => route_query(inner),
120        _ => QueryTarget::Oltp,
121    }
122}
123
124/// Route a SELECT body.
125fn route_select(select: &Select) -> QueryTarget {
126    // GROUP BY or HAVING → analytical
127    let has_group_by = match &select.group_by {
128        GroupByExpr::All(_) => true,
129        GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
130    };
131    if has_group_by || select.having.is_some() {
132        return QueryTarget::Olap;
133    }
134
135    // JOINs → analytical
136    for table in &select.from {
137        if !table.joins.is_empty() {
138            return QueryTarget::Olap;
139        }
140        // Subqueries in FROM → analytical
141        if matches!(&table.relation, TableFactor::Derived { .. }) {
142            return QueryTarget::Olap;
143        }
144    }
145
146    // Check projection for aggregate functions or window functions
147    for item in &select.projection {
148        if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
149            if expr_has_analytical_pattern(expr) {
150                return QueryTarget::Olap;
151            }
152        }
153    }
154
155    // Check WHERE for subqueries
156    if let Some(selection) = &select.selection {
157        if expr_has_subquery(selection) {
158            return QueryTarget::Olap;
159        }
160    }
161
162    // Default: point lookup → OLTP
163    QueryTarget::Oltp
164}
165
166/// Check if an expression contains aggregate functions or window functions.
167fn expr_has_analytical_pattern(expr: &Expr) -> bool {
168    match expr {
169        Expr::Function(func) => {
170            // Check for window function (OVER clause)
171            if func.over.is_some() {
172                return true;
173            }
174            // Check for known aggregate function names
175            let name = func.name.to_string().to_ascii_uppercase();
176            matches!(
177                name.as_str(),
178                "COUNT"
179                    | "SUM"
180                    | "AVG"
181                    | "MIN"
182                    | "MAX"
183                    | "STDDEV"
184                    | "VARIANCE"
185                    | "ARRAY_AGG"
186                    | "STRING_AGG"
187                    | "GROUP_CONCAT"
188                    | "MEDIAN"
189                    | "PERCENTILE_CONT"
190                    | "PERCENTILE_DISC"
191                    | "FIRST_VALUE"
192                    | "LAST_VALUE"
193                    | "NTH_VALUE"
194                    | "ROW_NUMBER"
195                    | "RANK"
196                    | "DENSE_RANK"
197                    | "NTILE"
198                    | "LAG"
199                    | "LEAD"
200                    | "CUME_DIST"
201                    | "PERCENT_RANK"
202            )
203        }
204        Expr::Nested(inner) => expr_has_analytical_pattern(inner),
205        Expr::BinaryOp { left, right, .. } => {
206            expr_has_analytical_pattern(left) || expr_has_analytical_pattern(right)
207        }
208        Expr::UnaryOp { expr, .. } => expr_has_analytical_pattern(expr),
209        Expr::Cast { expr, .. } => expr_has_analytical_pattern(expr),
210        Expr::Case {
211            operand,
212            conditions,
213            else_result,
214            ..
215        } => {
216            operand
217                .as_ref()
218                .is_some_and(|e| expr_has_analytical_pattern(e))
219                || conditions.iter().any(|cw| {
220                    expr_has_analytical_pattern(&cw.condition)
221                        || expr_has_analytical_pattern(&cw.result)
222                })
223                || else_result
224                    .as_ref()
225                    .is_some_and(|e| expr_has_analytical_pattern(e))
226        }
227        Expr::Subquery(q) => matches!(route_query(q), QueryTarget::Olap),
228        Expr::InSubquery { subquery, .. } => matches!(route_query(subquery), QueryTarget::Olap),
229        _ => false,
230    }
231}
232
233/// Check if an expression contains a subquery.
234fn expr_has_subquery(expr: &Expr) -> bool {
235    match expr {
236        Expr::Subquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } => true,
237        Expr::Nested(inner) => expr_has_subquery(inner),
238        Expr::BinaryOp { left, right, .. } => expr_has_subquery(left) || expr_has_subquery(right),
239        Expr::UnaryOp { expr, .. } => expr_has_subquery(expr),
240        _ => false,
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Heuristic fallback (used when sqlparser fails)
246// ---------------------------------------------------------------------------
247
248/// Heuristic-based routing as fallback when the parser cannot handle the SQL.
249/// `sql` is expected to be already trimmed by the caller.
250fn heuristic_route(sql: &str) -> QueryTarget {
251    let trimmed = sql;
252
253    const WRITE_KEYWORDS: &[&str] = &[
254        "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "BEGIN", "COMMIT", "ROLLBACK",
255        "PRAGMA",
256    ];
257    for kw in WRITE_KEYWORDS {
258        if starts_with_ignore_case(trimmed, kw) {
259            return QueryTarget::Oltp;
260        }
261    }
262
263    if starts_with_ignore_case(trimmed, "SELECT") {
264        const AGGREGATE_FNS: &[&str] = &["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("];
265        let has_aggregate = AGGREGATE_FNS
266            .iter()
267            .any(|agg| contains_ignore_case(trimmed, agg));
268        let has_grouping =
269            contains_ignore_case(trimmed, "GROUP BY") || contains_ignore_case(trimmed, "HAVING");
270        let has_window =
271            contains_ignore_case(trimmed, "OVER(") || contains_ignore_case(trimmed, "OVER (");
272        let has_join = contains_ignore_case(trimmed, " JOIN ");
273
274        if has_aggregate || has_grouping || has_window || has_join {
275            return QueryTarget::Olap;
276        }
277    }
278
279    QueryTarget::Oltp
280}
281
/// ASCII case-insensitive prefix test.
///
/// `needle` must already be upper case (checked in debug builds); only the
/// bytes of `haystack` are upper-cased for the comparison.  The comparison is
/// byte-wise, so it never panics on multi-byte UTF-8 input.
fn starts_with_ignore_case(haystack: &str, needle: &str) -> bool {
    debug_assert!(needle.bytes().all(|b| b == b.to_ascii_uppercase()));
    match haystack.as_bytes().get(..needle.len()) {
        // `prefix` has exactly `needle.len()` bytes, so `eq` compares pairwise.
        Some(prefix) => prefix
            .iter()
            .map(u8::to_ascii_uppercase)
            .eq(needle.bytes()),
        // Haystack is shorter than the needle.
        None => false,
    }
}
290
/// ASCII case-insensitive substring test.
///
/// `needle` must already be upper case (checked in debug builds); only the
/// bytes of `haystack` are upper-cased for the comparison.  The search is
/// byte-wise, so it never panics on multi-byte UTF-8 input.
///
/// An empty `needle` is contained in every haystack, mirroring
/// `str::contains("")`.
fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
    debug_assert!(needle.bytes().all(|b| b == b.to_ascii_uppercase()));
    let needle = needle.as_bytes();

    // `slice::windows` panics when the window size is 0, so the empty needle
    // has to be answered before building the window iterator.
    if needle.is_empty() {
        return true;
    }

    // When the needle is longer than the haystack `windows` yields nothing
    // and `any` returns false.
    haystack.as_bytes().windows(needle.len()).any(|window| {
        window
            .iter()
            .map(u8::to_ascii_uppercase)
            .eq(needle.iter().copied())
    })
}
303
304/// Backwards-compatible query router that delegates to [`SqlParserRouter`].
305///
306/// This type exists solely for API compatibility.  New code should use
307/// [`SqlParserRouter`] directly; both produce identical routing decisions.
308pub struct HeuristicRouter {
309    inner: SqlParserRouter,
310}
311
312impl HeuristicRouter {
313    /// Create a new [`HeuristicRouter`] backed by a [`SqlParserRouter`].
314    pub fn new() -> Self {
315        Self {
316            inner: SqlParserRouter::new(),
317        }
318    }
319}
320
321impl Default for HeuristicRouter {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
327impl QueryRouter for HeuristicRouter {
328    fn route(&self, sql: &str) -> QueryTarget {
329        self.inner.route(sql)
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Route every `(sql, expected)` pair through `router` and assert the
    /// target, naming the offending statement on failure.
    fn assert_routes<R: QueryRouter>(router: &R, cases: &[(&str, QueryTarget)]) {
        for (sql, expected) in cases {
            assert_eq!(
                &router.route(sql),
                expected,
                "unexpected target for `{sql}`"
            );
        }
    }

    #[test]
    fn test_write_operations_route_to_oltp() {
        assert_routes(
            &SqlParserRouter::new(),
            &[
                ("INSERT INTO users VALUES (1, 'Alice')", QueryTarget::Oltp),
                (
                    "UPDATE users SET name = 'Bob' WHERE id = 1",
                    QueryTarget::Oltp,
                ),
                ("DELETE FROM users WHERE id = 1", QueryTarget::Oltp),
                ("CREATE TABLE users (id INTEGER)", QueryTarget::Oltp),
                (
                    "ALTER TABLE users ADD COLUMN email TEXT",
                    QueryTarget::Oltp,
                ),
            ],
        );
    }

    #[test]
    fn test_analytical_queries_route_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &[
                ("SELECT COUNT(*) FROM users", QueryTarget::Olap),
                (
                    "SELECT AVG(age) FROM users GROUP BY dept",
                    QueryTarget::Olap,
                ),
                (
                    "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id",
                    QueryTarget::Olap,
                ),
            ],
        );
    }

    #[test]
    fn test_simple_selects_route_to_oltp() {
        assert_routes(
            &SqlParserRouter::new(),
            &[
                ("SELECT * FROM users WHERE id = 1", QueryTarget::Oltp),
                ("SELECT name FROM users LIMIT 10", QueryTarget::Oltp),
            ],
        );
    }

    #[test]
    fn test_window_functions_route_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &[
                (
                    "SELECT id, ROW_NUMBER() OVER (ORDER BY id) FROM users",
                    QueryTarget::Olap,
                ),
                (
                    "SELECT id, SUM(age) OVER (PARTITION BY dept) FROM users",
                    QueryTarget::Olap,
                ),
            ],
        );
    }

    #[test]
    fn test_subqueries_route_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &[
                (
                    "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)",
                    QueryTarget::Olap,
                ),
                (
                    "SELECT * FROM (SELECT dept, COUNT(*) cnt FROM users GROUP BY dept) sub",
                    QueryTarget::Olap,
                ),
            ],
        );
    }

    #[test]
    fn test_cte_routes_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &[(
                "WITH active AS (SELECT * FROM users WHERE active = true) SELECT COUNT(*) FROM active",
                QueryTarget::Olap,
            )],
        );
    }

    #[test]
    fn test_union_routes_to_olap() {
        assert_routes(
            &SqlParserRouter::new(),
            &[(
                "SELECT id FROM users UNION ALL SELECT id FROM admins",
                QueryTarget::Olap,
            )],
        );
    }

    #[test]
    fn test_string_containing_keywords_not_misrouted() {
        // With a real parser, "COUNT(" inside a string literal is just data
        // and must not trigger OLAP.
        assert_routes(
            &SqlParserRouter::new(),
            &[(
                "SELECT * FROM users WHERE note = 'COUNT(items) is 5'",
                QueryTarget::Oltp,
            )],
        );
    }

    #[test]
    fn test_backwards_compat_heuristic_router() {
        assert_routes(
            &HeuristicRouter::new(),
            &[
                ("SELECT COUNT(*) FROM users", QueryTarget::Olap),
                ("INSERT INTO users VALUES (1, 'Alice')", QueryTarget::Oltp),
            ],
        );
    }

    #[test]
    fn test_pragma_routes_to_oltp() {
        // PRAGMA may not parse in sqlparser; it must still end up on OLTP,
        // whether through the parser or through the heuristic fallback.
        assert_routes(
            &SqlParserRouter::new(),
            &[("PRAGMA table_info(users)", QueryTarget::Oltp)],
        );
    }
}