use std::collections::HashSet;
/// Optimizer hints extracted from `/*+ ... */` comments in a SQL string.
///
/// Table names are stored lower-cased; [`QueryHints::should_broadcast`] and
/// [`QueryHints::should_shuffle`] lower-case their argument too, so lookups
/// are case-insensitive.
#[derive(Debug, Default)]
pub struct QueryHints {
    /// Tables named in `BROADCAST(...)` hints (lower-cased).
    pub broadcast_tables: HashSet<String>,
    /// Tables named in `SHUFFLE(...)` hints (lower-cased).
    pub shuffle_tables: HashSet<String>,
}

impl QueryHints {
    /// Parses every `/*+ ... */` hint comment found in `sql`.
    ///
    /// Recognised directives are `BROADCAST(...)` and `SHUFFLE(...)`, matched
    /// ASCII-case-insensitively; unknown directives are ignored. If a hint
    /// comment has no closing `*/`, it and everything after it are ignored.
    pub fn parse(sql: &str) -> Self {
        let mut hints = Self::default();
        let mut rest = sql;
        while let Some(open) = rest.find("/*+") {
            let after_open = &rest[open + 3..];
            if let Some(close) = after_open.find("*/") {
                Self::parse_body(&after_open[..close], &mut hints);
                rest = &after_open[close + 2..];
            } else {
                // Unterminated hint comment: stop scanning.
                break;
            }
        }
        hints
    }

    /// Parses the text between `/*+` and `*/`, adding recognised hints to `hints`.
    ///
    /// Each hint has the form `DIRECTIVE(table[, table ...])`. Consecutive hints
    /// may be separated by whitespace and/or commas, e.g.
    /// `BROADCAST(a), SHUFFLE(b, c)`.
    fn parse_body(body: &str, hints: &mut QueryHints) {
        let mut rest = body;
        while let Some(open) = rest.find('(') {
            // The directive is only the last word before '('; earlier text
            // (separating commas, unrecognised words) must not be glued onto it.
            let directive = rest[..open]
                .trim_end()
                .rsplit(|c: char| c.is_whitespace() || c == ',')
                .next()
                .unwrap_or("");

            let after_open = &rest[open + 1..];
            // An argument list without ')' extends to the end of the body.
            let (args, next) = match after_open.find(')') {
                Some(close) => (&after_open[..close], &after_open[close + 1..]),
                None => (after_open, ""),
            };

            let target = if directive.eq_ignore_ascii_case("BROADCAST") {
                Some(&mut hints.broadcast_tables)
            } else if directive.eq_ignore_ascii_case("SHUFFLE") {
                Some(&mut hints.shuffle_tables)
            } else {
                None
            };

            if let Some(tables) = target {
                // Lower-case the original text rather than an upper-cased copy:
                // upper->lower is not a Unicode round-trip (e.g. 'ß' -> "SS" -> "ss"),
                // which would make `should_*` lookups miss.
                tables.extend(
                    args.split(',')
                        .map(|name| name.trim().to_lowercase())
                        .filter(|name| !name.is_empty()),
                );
            }
            rest = next;
        }
    }

    /// Returns `true` when no recognised hint was found.
    pub fn is_empty(&self) -> bool {
        self.broadcast_tables.is_empty() && self.shuffle_tables.is_empty()
    }

    /// Returns `true` if `table` (compared case-insensitively) was named in a
    /// `BROADCAST` hint.
    pub fn should_broadcast(&self, table: &str) -> bool {
        self.broadcast_tables.contains(&table.to_lowercase())
    }

    /// Returns `true` if `table` (compared case-insensitively) was named in a
    /// `SHUFFLE` hint.
    pub fn should_shuffle(&self, table: &str) -> bool {
        self.shuffle_tables.contains(&table.to_lowercase())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand so each case reads as "given this SQL, expect these hints".
    fn hints_for(sql: &str) -> QueryHints {
        QueryHints::parse(sql)
    }

    #[test]
    fn parse_broadcast_hint() {
        let parsed = hints_for(
            "SELECT /*+ BROADCAST(small_table) */ * FROM large JOIN small_table ON id = id",
        );
        assert_eq!(parsed.should_broadcast("small_table"), true);
        assert_eq!(parsed.should_shuffle("small_table"), false);
    }

    #[test]
    fn parse_shuffle_hint() {
        let parsed =
            hints_for("SELECT /*+ SHUFFLE(orders) */ * FROM orders JOIN products ON product_id = id");
        assert_eq!(parsed.should_shuffle("orders"), true);
    }

    #[test]
    fn parse_multiple_hints() {
        let parsed = hints_for("SELECT /*+ BROADCAST(a) SHUFFLE(b) */ * FROM a JOIN b ON x = y");
        assert_eq!(parsed.should_broadcast("a"), true);
        assert_eq!(parsed.should_shuffle("b"), true);
    }

    #[test]
    fn no_hints() {
        assert_eq!(hints_for("SELECT * FROM users WHERE id = 1").is_empty(), true);
    }

    #[test]
    fn case_insensitive() {
        let parsed = hints_for("SELECT /*+ broadcast(MyTable) */ * FROM mytable");
        assert_eq!(parsed.should_broadcast("mytable"), true);
    }
}