use rhei_core::types::QueryTarget;
use rhei_core::QueryRouter;
use sqlparser::ast::{
Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use tracing::debug;
/// SQL router that decides between OLTP and OLAP by parsing the statement
/// with `sqlparser` (SQLite dialect) and inspecting the resulting AST.
///
/// The type is a stateless unit struct, so it is trivially cheap to copy.
/// `Default` is derived rather than hand-written; it yields the same value
/// as [`SqlParserRouter::new`].
#[derive(Debug, Clone, Copy, Default)]
pub struct SqlParserRouter;

impl SqlParserRouter {
    /// Creates a new router. Equivalent to `SqlParserRouter::default()`.
    pub fn new() -> Self {
        Self
    }
}
impl QueryRouter for SqlParserRouter {
fn route(&self, sql: &str) -> QueryTarget {
let trimmed = sql.trim();
if trimmed.is_empty() {
return QueryTarget::Oltp;
}
match Parser::parse_sql(&SQLiteDialect {}, trimmed) {
Ok(stmts) if !stmts.is_empty() => route_statement(&stmts[0]),
Ok(_) => QueryTarget::Oltp,
Err(e) => {
debug!(error = %e, sql = trimmed, "SQL parse failed, falling back to heuristic");
heuristic_route(trimmed)
}
}
}
}
/// Maps one parsed statement to its execution target.
///
/// Plain queries are analysed further by `route_query`, and `EXPLAIN <stmt>`
/// is routed according to the statement it wraps. Every other statement kind
/// goes to OLTP. The write/DDL/transaction variants are listed explicitly to
/// document intent, although the trailing wildcard arm yields the same result.
fn route_statement(stmt: &Statement) -> QueryTarget {
    match stmt {
        // Reads: the only statements that can end up on the OLAP side.
        Statement::Query(query) => route_query(query),
        Statement::Explain { statement, .. } => route_statement(statement),

        // Writes, DDL, transaction control and table description.
        Statement::ExplainTable { .. }
        | Statement::Insert(_)
        | Statement::Update { .. }
        | Statement::Delete(_)
        | Statement::CreateTable { .. }
        | Statement::CreateIndex { .. }
        | Statement::AlterTable { .. }
        | Statement::Drop { .. }
        | Statement::StartTransaction { .. }
        | Statement::Commit { .. }
        | Statement::Rollback { .. }
        | Statement::Savepoint { .. } => QueryTarget::Oltp,

        // Anything not recognised above defaults to OLTP.
        _ => QueryTarget::Oltp,
    }
}
fn route_query(query: &Query) -> QueryTarget {
if query.with.is_some() {
return QueryTarget::Olap;
}
match query.body.as_ref() {
SetExpr::Select(select) => route_select(select),
SetExpr::SetOperation { .. } => QueryTarget::Olap,
SetExpr::Query(inner) => route_query(inner),
_ => QueryTarget::Oltp,
}
}
/// Routes a single `SELECT` body.
///
/// Returns OLAP as soon as any of the following is found, checked in order:
///
/// 1. a non-empty `GROUP BY` (including `GROUP BY ALL`) or a `HAVING` clause;
/// 2. more than one entry in the `FROM` list, i.e. an implicit comma join
///    such as `FROM a, b`;
/// 3. an explicit join, or a derived table (subquery) in `FROM`;
/// 4. an analytical expression in the projection
///    (see `expr_has_analytical_pattern`);
/// 5. a subquery in the `WHERE` clause (see `expr_has_subquery`).
///
/// Everything else is considered a point/simple read and goes to OLTP.
fn route_select(select: &Select) -> QueryTarget {
    let has_group_by = match &select.group_by {
        GroupByExpr::All(_) => true,
        GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
    };
    if has_group_by || select.having.is_some() {
        return QueryTarget::Olap;
    }

    // `FROM a, b` parses as two separate FROM entries, each with an empty
    // `joins` list, so the per-table join check below cannot see it. It is
    // still a join, so route it the same way as an explicit one.
    if select.from.len() > 1 {
        return QueryTarget::Olap;
    }

    for table in &select.from {
        if !table.joins.is_empty() {
            return QueryTarget::Olap;
        }
        if matches!(&table.relation, TableFactor::Derived { .. }) {
            return QueryTarget::Olap;
        }
    }

    for item in &select.projection {
        // Wildcard projections carry no expression to inspect.
        if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
            if expr_has_analytical_pattern(expr) {
                return QueryTarget::Olap;
            }
        }
    }

    if let Some(selection) = &select.selection {
        if expr_has_subquery(selection) {
            return QueryTarget::Olap;
        }
    }

    QueryTarget::Oltp
}
/// Reports whether a projection expression looks analytical.
///
/// An expression qualifies when it is, or contains (through parentheses,
/// unary/binary operators, casts or `CASE` branches):
///
/// * a function call with an `OVER` clause,
/// * a call to one of the known aggregate/window function names
///   (compared ASCII case-insensitively), or
/// * a scalar or `IN` subquery that `route_query` itself sends to OLAP.
///
/// Expression kinds not listed here are not descended into and yield `false`.
fn expr_has_analytical_pattern(expr: &Expr) -> bool {
    /// Upper-cased names of functions treated as aggregate or window functions.
    const ANALYTICAL_FUNCTIONS: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "STDDEV",
        "VARIANCE",
        "ARRAY_AGG",
        "STRING_AGG",
        "GROUP_CONCAT",
        "MEDIAN",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "FIRST_VALUE",
        "LAST_VALUE",
        "NTH_VALUE",
        "ROW_NUMBER",
        "RANK",
        "DENSE_RANK",
        "NTILE",
        "LAG",
        "LEAD",
        "CUME_DIST",
        "PERCENT_RANK",
    ];

    match expr {
        Expr::Function(func) => {
            // A window clause is decisive on its own; the name is only
            // rendered and compared when there is no `OVER`.
            func.over.is_some() || {
                let upper_name = func.name.to_string().to_ascii_uppercase();
                ANALYTICAL_FUNCTIONS.contains(&upper_name.as_str())
            }
        }

        // Single-child wrappers: look through to the wrapped expression.
        Expr::Nested(inner)
        | Expr::UnaryOp { expr: inner, .. }
        | Expr::Cast { expr: inner, .. } => expr_has_analytical_pattern(inner),

        Expr::BinaryOp { left, right, .. } => {
            expr_has_analytical_pattern(left) || expr_has_analytical_pattern(right)
        }

        Expr::Case {
            operand,
            conditions,
            else_result,
            ..
        } => {
            let in_operand = || {
                operand
                    .iter()
                    .any(|candidate| expr_has_analytical_pattern(candidate))
            };
            let in_branches = || {
                conditions.iter().any(|branch| {
                    expr_has_analytical_pattern(&branch.condition)
                        || expr_has_analytical_pattern(&branch.result)
                })
            };
            let in_else = || {
                else_result
                    .iter()
                    .any(|candidate| expr_has_analytical_pattern(candidate))
            };
            in_operand() || in_branches() || in_else()
        }

        // A subquery counts only if it would itself be routed to OLAP.
        Expr::Subquery(nested_query)
        | Expr::InSubquery {
            subquery: nested_query,
            ..
        } => matches!(route_query(nested_query), QueryTarget::Olap),

        _ => false,
    }
}
/// Reports whether a predicate contains a subquery.
///
/// Scalar subqueries, `IN (subquery)` and `EXISTS` count as subqueries. The
/// search looks through parentheses and unary/binary operators only; other
/// expression kinds are not descended into and yield `false`.
fn expr_has_subquery(expr: &Expr) -> bool {
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            expr_has_subquery(left) || expr_has_subquery(right)
        }
        // Single-child wrappers: look through to the wrapped expression.
        Expr::Nested(inner) | Expr::UnaryOp { expr: inner, .. } => expr_has_subquery(inner),
        Expr::Subquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } => true,
        _ => false,
    }
}
/// Keyword-based fallback used when the SQL cannot be parsed.
///
/// The caller is expected to pass already-trimmed text, because matching is
/// done against the very first bytes. Statements starting with a write, DDL,
/// transaction or `PRAGMA` keyword go to OLTP. A statement starting with
/// `SELECT` goes to OLAP if it contains an aggregate call, `GROUP BY`,
/// `HAVING`, an `OVER` clause or a ` JOIN `. Everything else goes to OLTP.
///
/// This is plain substring matching (ASCII case-insensitive), so keywords
/// inside string literals are matched as well.
fn heuristic_route(sql: &str) -> QueryTarget {
    const WRITE_KEYWORDS: &[&str] = &[
        "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "BEGIN", "COMMIT", "ROLLBACK",
        "PRAGMA",
    ];
    const ANALYTICAL_MARKERS: &[&str] = &[
        // aggregate calls
        "COUNT(", "SUM(", "AVG(", "MIN(", "MAX(",
        // grouping
        "GROUP BY", "HAVING",
        // window clauses
        "OVER(", "OVER (",
        // explicit joins
        " JOIN ",
    ];

    let is_write = WRITE_KEYWORDS
        .iter()
        .any(|kw| starts_with_ignore_case(sql, kw));
    if is_write || !starts_with_ignore_case(sql, "SELECT") {
        return QueryTarget::Oltp;
    }

    let looks_analytical = ANALYTICAL_MARKERS
        .iter()
        .any(|marker| contains_ignore_case(sql, marker));
    if looks_analytical {
        QueryTarget::Olap
    } else {
        QueryTarget::Oltp
    }
}
/// Returns `true` if `haystack` begins with `needle`, comparing ASCII letters
/// case-insensitively.
///
/// `needle` must already be ASCII-uppercase; this is checked in debug builds
/// only. The comparison is byte-wise, so it never splits the haystack on a
/// character boundary. An empty `needle` matches every haystack.
fn starts_with_ignore_case(haystack: &str, needle: &str) -> bool {
    debug_assert!(needle.bytes().all(|b| b == b.to_ascii_uppercase()));
    let pattern = needle.as_bytes();
    // `get` yields `None` when the haystack is shorter than the pattern.
    haystack
        .as_bytes()
        .get(..pattern.len())
        .is_some_and(|prefix| {
            prefix
                .iter()
                .zip(pattern)
                .all(|(have, want)| have.to_ascii_uppercase() == *want)
        })
}
/// Returns `true` if `needle` occurs anywhere in `haystack`, comparing ASCII
/// letters case-insensitively.
///
/// `needle` must already be ASCII-uppercase; this is checked in debug builds
/// only. The comparison is byte-wise and allocation-free. An empty `needle`
/// matches every haystack, mirroring `str::contains`.
fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
    debug_assert!(needle.bytes().all(|b| b == b.to_ascii_uppercase()));
    let pattern = needle.as_bytes();
    // `slice::windows` panics on a window size of zero, so the empty needle
    // has to be answered before reaching it.
    if pattern.is_empty() {
        return true;
    }
    // When the pattern is longer than the haystack, `windows` yields nothing
    // and `any` returns `false`.
    haystack.as_bytes().windows(pattern.len()).any(|window| {
        window
            .iter()
            .zip(pattern)
            .all(|(have, want)| have.to_ascii_uppercase() == *want)
    })
}
/// Router that wraps a [`SqlParserRouter`] and forwards every routing
/// decision to it.
///
/// NOTE(review): presumably kept as a backward-compatible name for callers
/// that predate `SqlParserRouter` (the test module refers to it as
/// "backwards compat") — confirm before removing.
pub struct HeuristicRouter {
    // The parser-based router that performs the actual routing.
    inner: SqlParserRouter,
}
impl HeuristicRouter {
    /// Creates a router backed by a fresh [`SqlParserRouter`].
    pub fn new() -> Self {
        let inner = SqlParserRouter::new();
        Self { inner }
    }
}
impl Default for HeuristicRouter {
fn default() -> Self {
Self::new()
}
}
impl QueryRouter for HeuristicRouter {
    /// Routes `sql` exactly as the wrapped [`SqlParserRouter`] would.
    fn route(&self, sql: &str) -> QueryTarget {
        let Self { inner } = self;
        inner.route(sql)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `router` sends `sql` to the OLAP engine.
    ///
    /// `#[track_caller]` makes a failure point at the calling test line.
    #[track_caller]
    fn expect_olap(router: &impl QueryRouter, sql: &str) {
        assert_eq!(router.route(sql), QueryTarget::Olap);
    }

    /// Asserts that `router` sends `sql` to the OLTP engine.
    ///
    /// `#[track_caller]` makes a failure point at the calling test line.
    #[track_caller]
    fn expect_oltp(router: &impl QueryRouter, sql: &str) {
        assert_eq!(router.route(sql), QueryTarget::Oltp);
    }

    /// DML and DDL statements always stay on the OLTP side.
    #[test]
    fn test_write_operations_route_to_oltp() {
        let router = SqlParserRouter::new();
        expect_oltp(&router, "INSERT INTO users VALUES (1, 'Alice')");
        expect_oltp(&router, "UPDATE users SET name = 'Bob' WHERE id = 1");
        expect_oltp(&router, "DELETE FROM users WHERE id = 1");
        expect_oltp(&router, "CREATE TABLE users (id INTEGER)");
        expect_oltp(&router, "ALTER TABLE users ADD COLUMN email TEXT");
    }

    /// Aggregates, grouping and joins are analytical.
    #[test]
    fn test_analytical_queries_route_to_olap() {
        let router = SqlParserRouter::new();
        expect_olap(&router, "SELECT COUNT(*) FROM users");
        expect_olap(&router, "SELECT AVG(age) FROM users GROUP BY dept");
        expect_olap(
            &router,
            "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id",
        );
    }

    /// Plain single-table reads stay on the OLTP side.
    #[test]
    fn test_simple_selects_route_to_oltp() {
        let router = SqlParserRouter::new();
        expect_oltp(&router, "SELECT * FROM users WHERE id = 1");
        expect_oltp(&router, "SELECT name FROM users LIMIT 10");
    }

    /// Any function with an `OVER` clause is analytical.
    #[test]
    fn test_window_functions_route_to_olap() {
        let router = SqlParserRouter::new();
        expect_olap(
            &router,
            "SELECT id, ROW_NUMBER() OVER (ORDER BY id) FROM users",
        );
        expect_olap(
            &router,
            "SELECT id, SUM(age) OVER (PARTITION BY dept) FROM users",
        );
    }

    /// Subqueries in `WHERE` and derived tables in `FROM` are analytical.
    #[test]
    fn test_subqueries_route_to_olap() {
        let router = SqlParserRouter::new();
        expect_olap(
            &router,
            "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)",
        );
        expect_olap(
            &router,
            "SELECT * FROM (SELECT dept, COUNT(*) cnt FROM users GROUP BY dept) sub",
        );
    }

    /// A `WITH` clause is enough to route to OLAP.
    #[test]
    fn test_cte_routes_to_olap() {
        let router = SqlParserRouter::new();
        expect_olap(
            &router,
            "WITH active AS (SELECT * FROM users WHERE active = true) SELECT COUNT(*) FROM active",
        );
    }

    /// Set operations are analytical.
    #[test]
    fn test_union_routes_to_olap() {
        let router = SqlParserRouter::new();
        expect_olap(
            &router,
            "SELECT id FROM users UNION ALL SELECT id FROM admins",
        );
    }

    /// Keywords inside string literals must not influence AST-based routing.
    #[test]
    fn test_string_containing_keywords_not_misrouted() {
        let router = SqlParserRouter::new();
        expect_oltp(
            &router,
            "SELECT * FROM users WHERE note = 'COUNT(items) is 5'",
        );
    }

    /// `HeuristicRouter` routes the same way as the parser-based router.
    #[test]
    fn test_backwards_compat_heuristic_router() {
        let router = HeuristicRouter::new();
        expect_olap(&router, "SELECT COUNT(*) FROM users");
        expect_oltp(&router, "INSERT INTO users VALUES (1, 'Alice')");
    }

    /// `PRAGMA` statements are routed to OLTP.
    #[test]
    fn test_pragma_routes_to_oltp() {
        let router = SqlParserRouter::new();
        expect_oltp(&router, "PRAGMA table_info(users)");
    }
}