use std::collections::HashMap;
use super::{QueryLimit, QueryParam, QueryParamValue, QueryString, SortBy, SortDirection};
/// A value bound to a `:name` placeholder of an [`SQLQuery`].
///
/// `Int`, `Float` and `Str` are spliced into the SQL text by `substitute_params`
/// (strings are single-quoted with embedded `'` doubled). `Bytes` is never spliced
/// into the text; the placeholder stays as `:name` and the bytes are only picked up
/// by the vector-search parser as the query vector.
#[derive(Debug, Clone)]
pub enum SqlParam {
/// Integer, written in decimal.
Int(i64),
/// Floating-point number, written with its `Display` form.
Float(f64),
/// Text, written as a quoted SQL string literal.
Str(String),
/// Raw bytes (e.g. a vector blob); not substituted into the SQL text.
Bytes(Vec<u8>),
}
/// A SQL `SELECT` statement plus named `:placeholder` parameters.
///
/// Nothing is parsed at construction time: each accessor substitutes the parameters
/// and runs the relevant parser on demand.
#[derive(Debug, Clone)]
pub struct SQLQuery {
    /// Raw SQL text, possibly containing `:name` placeholders.
    sql: String,
    /// Placeholder values keyed by name (without the leading `:`).
    params: HashMap<String, SqlParam>,
}
impl SQLQuery {
    /// Creates a query with no bound parameters.
    pub fn new(sql: impl Into<String>) -> Self {
        Self::with_params(sql, HashMap::new())
    }
    /// Creates a query with an initial set of bound parameters.
    pub fn with_params(sql: impl Into<String>, params: HashMap<String, SqlParam>) -> Self {
        let sql = sql.into();
        Self { sql, params }
    }
    /// Binds (or rebinds) one parameter; builder style.
    pub fn with_param(mut self, name: impl Into<String>, value: SqlParam) -> Self {
        let key = name.into();
        self.params.insert(key, value);
        self
    }
    /// The SQL text exactly as supplied (placeholders not substituted).
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }
    /// All bound parameters.
    pub fn params_map(&self) -> &HashMap<String, SqlParam> {
        &self.params
    }
    /// The SQL text with every non-`Bytes` placeholder replaced by its literal value.
    pub fn substituted_sql(&self) -> String {
        substitute_params(&self.sql, &self.params)
    }
    /// Plain `SELECT` interpretation, if the statement is one.
    fn parsed(&self) -> Option<ParsedSelect> {
        let sql = self.substituted_sql();
        parse_select(&sql)
    }
    /// Whether the statement parses as an `FT.AGGREGATE` query.
    pub fn is_aggregate(&self) -> bool {
        let sql = self.substituted_sql();
        parse_aggregate(&sql).is_some()
    }
    /// Builds the `FT.AGGREGATE` command for `index_name`, if the statement is an aggregate.
    pub fn build_aggregate_cmd(&self, index_name: &str) -> Option<redis::Cmd> {
        let sql = self.substituted_sql();
        parse_aggregate(&sql).map(|aggregate| aggregate.build_cmd(index_name))
    }
    /// Whether the statement is a KNN vector search.
    pub fn is_vector_query(&self) -> bool {
        self.parsed_vector().is_some()
    }
    /// Whether the statement selects a `geo_distance(...)` computed column.
    pub fn is_geo_aggregate(&self) -> bool {
        let sql = self.substituted_sql();
        parse_geo_aggregate(&sql).is_some()
    }
    /// Builds the geo-distance `FT.AGGREGATE` command for `index_name`, if applicable.
    pub fn build_geo_aggregate_cmd(&self, index_name: &str) -> Option<redis::Cmd> {
        let sql = self.substituted_sql();
        parse_geo_aggregate(&sql).map(|geo| geo.build_cmd(index_name))
    }
    /// KNN interpretation; needs the raw params to find the `Bytes` query vector.
    fn parsed_vector(&self) -> Option<ParsedVectorSelect> {
        let sql = self.substituted_sql();
        parse_vector_select(&sql, &self.params)
    }
    /// Interpretation as a search filtered by a `geo_distance(...) < r` condition.
    fn parsed_geo_where(&self) -> Option<ParsedGeoWhere> {
        let sql = self.substituted_sql();
        parse_geo_where(&sql)
    }
}
impl QueryString for SQLQuery {
    /// Query text for `FT.SEARCH`, trying interpretations in priority order: KNN syntax
    /// for vector selects, the non-geo filter for geo-distance selects, the WHERE filter
    /// for plain selects, and finally the substituted SQL text verbatim.
    fn to_redis_query(&self) -> String {
        self.parsed_vector()
            .map(|vq| vq.to_knn_query_string())
            .or_else(|| self.parsed_geo_where().map(|gw| gw.filter_string()))
            .or_else(|| self.parsed().map(|p| p.filter_string()))
            .unwrap_or_else(|| self.substituted_sql())
    }
    /// Binary query parameters; only vector searches carry one (the query vector).
    fn params(&self) -> Vec<QueryParam> {
        match self.parsed_vector() {
            Some(vq) => vq.params(),
            None => Vec::new(),
        }
    }
    /// Fields to return, using the same priority order as `to_redis_query`.
    fn return_fields(&self) -> Vec<String> {
        self.parsed_vector()
            .map(|vq| vq.return_fields)
            .or_else(|| self.parsed_geo_where().map(|gw| gw.return_fields))
            .or_else(|| self.parsed().map(|p| p.return_fields))
            .unwrap_or_default()
    }
    /// `ORDER BY` of a plain select.
    fn sort_by(&self) -> Option<SortBy> {
        self.parsed()?.sort_by
    }
    /// Result window: `0..k` for vector searches, otherwise the plain select's `LIMIT`.
    fn limit(&self) -> Option<QueryLimit> {
        match self.parsed_vector() {
            Some(vq) => Some(QueryLimit {
                offset: 0,
                num: vq.knn_num,
            }),
            None => self.parsed()?.limit,
        }
    }
    /// True only for a plain select that lists no explicit fields (`SELECT *`).
    fn should_unpack_json(&self) -> bool {
        self.parsed()
            .map_or(false, |p| p.return_fields.is_empty())
    }
    /// `GEOFILTER` extracted from a `geo_distance(...) < r` WHERE condition.
    fn geofilter(&self) -> Option<super::GeoFilter> {
        let gw = self.parsed_geo_where()?;
        Some(gw.geofilter)
    }
}
/// Replaces `:name` placeholders in `sql` with the SQL-literal form of the matching
/// parameter.
///
/// * `Int` / `Float` values are written with their `Display` form.
/// * `Str` values are wrapped in single quotes with embedded `'` doubled (`''`).
/// * `Bytes` values and unknown names are left as `:name`, so a vector blob can still
///   be looked up by name later.
///
/// Placeholder names are ASCII identifiers (`[A-Za-z_][A-Za-z0-9_]*`) and always match
/// in full, so `:id` never matches a prefix of `:id2`. Text inside single-quoted SQL
/// literals is copied verbatim (using the same `''` escape rule as `tokenize`), so a
/// literal such as `'a:b'` is never mistaken for containing a placeholder.
fn substitute_params(sql: &str, params: &HashMap<String, SqlParam>) -> String {
    if params.is_empty() {
        return sql.to_owned();
    }
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut result = String::with_capacity(len);
    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' => {
                let start = i;
                i += 1;
                while i < len {
                    if bytes[i] != b'\'' {
                        i += 1;
                    } else if i + 1 < len && bytes[i + 1] == b'\'' {
                        // Escaped quote inside the literal.
                        i += 2;
                    } else {
                        // Closing quote.
                        i += 1;
                        break;
                    }
                }
                // `start` and `i` both sit on ASCII bytes (or the end), so this slice
                // is on char boundaries. An unterminated literal runs to the end.
                result.push_str(&sql[start..i]);
            }
            b':' if i + 1 < len && is_ident_start(bytes[i + 1]) => {
                let start = i + 1;
                let end = start
                    + bytes[start..]
                        .iter()
                        .take_while(|&&b| is_ident_continue(b))
                        .count();
                let key = &sql[start..end];
                match params.get(key) {
                    Some(SqlParam::Int(v)) => result.push_str(&v.to_string()),
                    Some(SqlParam::Float(v)) => result.push_str(&v.to_string()),
                    Some(SqlParam::Str(v)) => {
                        result.push('\'');
                        result.push_str(&v.replace('\'', "''"));
                        result.push('\'');
                    }
                    Some(SqlParam::Bytes(_)) | None => {
                        result.push(':');
                        result.push_str(key);
                    }
                }
                i = end;
            }
            _ => {
                // Copy one whole (possibly multi-byte) character.
                let ch = sql[i..]
                    .chars()
                    .next()
                    .expect("index always sits on a char boundary");
                result.push(ch);
                i += ch.len_utf8();
            }
        }
    }
    result
}
/// Whether `b` may begin a `:placeholder` name (ASCII letter or underscore).
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
/// Whether `b` may continue a `:placeholder` name (ASCII letter, digit or underscore).
fn is_ident_continue(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
/// A plain `SELECT` translated into `FT.SEARCH` pieces.
#[derive(Debug, Clone)]
struct ParsedSelect {
    /// Selected field names; empty for `SELECT *`.
    return_fields: Vec<String>,
    /// Translated WHERE filter; `None` when there is no WHERE clause.
    where_filter: Option<String>,
    /// `ORDER BY` field and direction, if present.
    sort_by: Option<SortBy>,
    /// `LIMIT n [OFFSET m]`, if present.
    limit: Option<QueryLimit>,
}
impl ParsedSelect {
    /// The query text: the WHERE filter, or `*` to match every document.
    fn filter_string(&self) -> String {
        match &self.where_filter {
            Some(filter) => filter.clone(),
            None => String::from("*"),
        }
    }
}
/// Parses a plain `SELECT <fields|*> FROM <index> [WHERE ...] [ORDER BY f [ASC|DESC]]
/// [LIMIT n [OFFSET m]]` statement into `FT.SEARCH` pieces.
///
/// Returns `None` when the statement calls an aggregate (`COUNT(...)`, `SUM(...)`, ...),
/// vector (`vector_distance(...)`, `cosine_distance(...)`) or geo (`geo_distance(...)`)
/// function, because those are handled by dedicated parsers. Only a name immediately
/// followed by `(` counts as a call, so columns that merely share a function's name
/// (`count`, `min`, ...) are still accepted. Also returns `None` for statements this
/// parser cannot follow (missing FROM, bad ORDER BY / LIMIT, unparsable WHERE).
fn parse_select(sql: &str) -> Option<ParsedSelect> {
    /// Whether `name` is a function that routes the statement to another parser.
    fn is_special_fn(name: &str) -> bool {
        matches!(
            name.to_ascii_uppercase().as_str(),
            "COUNT"
                | "AVG"
                | "SUM"
                | "MIN"
                | "MAX"
                | "STDDEV"
                | "QUANTILE"
                | "COUNT_DISTINCT"
                | "ARRAY_AGG"
                | "FIRST_VALUE"
                | "COSINE_DISTANCE"
                | "VECTOR_DISTANCE"
                | "GEO_DISTANCE"
        )
    }
    let tokens = tokenize(sql);
    if !tok_eq(&tokens, 0, "SELECT") {
        return None;
    }
    if tokens
        .windows(2)
        .any(|pair| pair[1] == "(" && is_special_fn(&pair[0]))
    {
        return None;
    }
    let mut pos = 1;
    let mut return_fields = Vec::new();
    if tok_eq(&tokens, pos, "*") {
        pos += 1;
    } else {
        loop {
            let token = tokens.get(pos)?;
            if token.eq_ignore_ascii_case("FROM") {
                break;
            }
            let is_separator = token == "," || token.eq_ignore_ascii_case("AS");
            // The token right after `AS` is an alias, not a field to return.
            let is_alias = tokens[pos - 1].eq_ignore_ascii_case("AS");
            if !is_separator && !is_alias {
                return_fields.push(token.clone());
            }
            pos += 1;
        }
    }
    if !tok_eq(&tokens, pos, "FROM") {
        return None;
    }
    pos += 1;
    // Skip the index name; this parser does not use it.
    if pos >= tokens.len() {
        return None;
    }
    pos += 1;
    let mut where_filter = None;
    let mut sort_by = None;
    let mut limit = None;
    while pos < tokens.len() {
        if tok_eq(&tokens, pos, "WHERE") {
            let (filter_str, next) = parse_where_clause(&tokens, pos + 1)?;
            where_filter = Some(filter_str);
            pos = next;
        } else if tok_eq(&tokens, pos, "ORDER") {
            if !tok_eq(&tokens, pos + 1, "BY") {
                return None;
            }
            pos += 2;
            let field = tokens.get(pos)?.clone();
            pos += 1;
            let direction = if tok_eq(&tokens, pos, "DESC") {
                pos += 1;
                SortDirection::Desc
            } else {
                if tok_eq(&tokens, pos, "ASC") {
                    pos += 1;
                }
                SortDirection::Asc
            };
            sort_by = Some(SortBy { field, direction });
        } else if tok_eq(&tokens, pos, "LIMIT") {
            let num = parse_usize(&tokens, pos + 1)?;
            pos += 2;
            let offset = if tok_eq(&tokens, pos, "OFFSET") {
                let off = parse_usize(&tokens, pos + 1)?;
                pos += 2;
                off
            } else {
                0
            };
            limit = Some(QueryLimit { offset, num });
        } else {
            // Unrecognised tokens (e.g. parts of a hyphenated index name) are skipped.
            pos += 1;
        }
    }
    Some(ParsedSelect {
        return_fields,
        where_filter,
        sort_by,
        limit,
    })
}
/// One `REDUCE` clause of an `FT.AGGREGATE` command, parsed from a SQL aggregate call.
#[derive(Debug, Clone)]
struct AggReducer {
/// RediSearch reducer name (e.g. `COUNT`, `SUM`; SQL `ARRAY_AGG` becomes `TOLIST`).
function: String,
/// Field argument as written in SQL, without `@`; `None` for `COUNT(*)` or an empty call.
field: Option<String>,
/// Output name: the `AS` alias, or the lowercased SQL function name by default.
alias: String,
/// Optional numeric second argument; only emitted for `QUANTILE`.
extra_arg: Option<f64>,
}
/// An aggregate statement translated into `FT.AGGREGATE` pieces.
#[derive(Debug, Clone)]
struct ParsedAggregate {
/// Translated WHERE filter; `None` is sent as `*` (match everything).
where_filter: Option<String>,
/// `GROUP BY` field names without `@`; empty means `GROUPBY 0` (one global group).
group_by_fields: Vec<String>,
/// Reducers in select-list order.
reducers: Vec<AggReducer>,
}
impl ParsedAggregate {
    /// Builds `FT.AGGREGATE <index> <filter> GROUPBY <n> [@field ...] REDUCE ...`.
    ///
    /// With no `GROUP BY` fields this emits `GROUPBY 0`, which reduces the whole result
    /// set into a single row.
    fn build_cmd(&self, index_name: &str) -> redis::Cmd {
        let mut cmd = redis::cmd("FT.AGGREGATE");
        cmd.arg(index_name);
        cmd.arg(self.where_filter.as_deref().unwrap_or("*"));
        cmd.arg("GROUPBY").arg(self.group_by_fields.len());
        for field in &self.group_by_fields {
            cmd.arg(format!("@{}", field));
        }
        for reducer in &self.reducers {
            self.append_reducer(&mut cmd, reducer);
        }
        cmd
    }
    /// Appends `REDUCE <function> <nargs> [args...] AS <alias>` for one reducer.
    ///
    /// `nargs` is the number of arguments actually emitted. A reducer whose field or
    /// quantile is missing therefore still yields a well-formed command (RediSearch
    /// reports the missing argument), instead of a count that makes the server consume
    /// the following `AS` as an argument.
    fn append_reducer(&self, cmd: &mut redis::Cmd, reducer: &AggReducer) {
        let mut args: Vec<String> = Vec::new();
        // COUNT takes no arguments; the field of a SQL `COUNT(field)` is dropped.
        if reducer.function != "COUNT" {
            if let Some(ref field) = reducer.field {
                args.push(format!("@{}", field));
            }
            // Only QUANTILE takes the second (numeric) argument.
            if reducer.function == "QUANTILE" {
                if let Some(q) = reducer.extra_arg {
                    args.push(format_num(q));
                }
            }
        }
        cmd.arg("REDUCE").arg(&reducer.function).arg(args.len());
        for arg in &args {
            cmd.arg(arg);
        }
        cmd.arg("AS").arg(&reducer.alias);
    }
}
/// Parses an aggregate statement into `FT.AGGREGATE` pieces:
/// `SELECT <items> FROM <index> [WHERE ...] [GROUP BY f1, f2, ...]`.
///
/// Returns `None` unless all of the following hold:
/// * the statement starts with `SELECT`;
/// * it names an aggregate function or contains `GROUP BY`;
/// * it has a `FROM <index>` part;
/// * the select list yields at least one reducer call (see `try_parse_aggregate_fn`).
///
/// Non-call items in the select list (plain fields, aliases) are skipped. After the
/// GROUP BY list, `HAVING`, `ORDER BY` and `LIMIT` tokens are consumed without effect.
fn parse_aggregate(sql: &str) -> Option<ParsedAggregate> {
let tokens = tokenize(sql);
if tokens.is_empty() {
return None;
}
let mut pos = 0;
if !tok_eq(&tokens, pos, "SELECT") {
return None;
}
pos += 1;
// Cheap routing pre-check on bare names. The reducer loop below additionally
// requires a `(` after the name before it counts as a call.
let has_aggregate_fn = tokens.iter().any(|t| {
let upper = t.to_ascii_uppercase();
matches!(
upper.as_str(),
"COUNT"
| "AVG"
| "SUM"
| "MIN"
| "MAX"
| "STDDEV"
| "QUANTILE"
| "COUNT_DISTINCT"
| "ARRAY_AGG"
| "FIRST_VALUE"
)
});
let has_group_by = tokens
.windows(2)
.any(|w| w[0].eq_ignore_ascii_case("GROUP") && w[1].eq_ignore_ascii_case("BY"));
if !has_aggregate_fn && !has_group_by {
return None;
}
// Collect reducer calls from the select list; every other token before FROM is skipped.
let mut reducers = Vec::new();
while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
if let Some((reducer, next)) = try_parse_aggregate_fn(&tokens, pos) {
reducers.push(reducer);
pos = next;
} else if tokens[pos] == "," {
pos += 1;
} else {
pos += 1;
}
}
if !tok_eq(&tokens, pos, "FROM") {
return None;
}
pos += 1;
if pos >= tokens.len() {
return None;
}
// Skip the index name that follows FROM.
pos += 1;
let mut where_filter: Option<String> = None;
let mut group_by_fields = Vec::new();
while pos < tokens.len() {
if tok_eq(&tokens, pos, "WHERE") {
pos += 1;
let (filter_str, next) = parse_where_clause(&tokens, pos)?;
where_filter = Some(filter_str);
pos = next;
} else if tok_eq(&tokens, pos, "GROUP") {
if !tok_eq(&tokens, pos + 1, "BY") {
return None;
}
pos += 2;
// GROUP BY fields run until HAVING / ORDER / LIMIT or the end of input.
while pos < tokens.len() {
let upper = tokens[pos].to_ascii_uppercase();
if matches!(upper.as_str(), "HAVING" | "ORDER" | "LIMIT") {
break;
}
if tokens[pos] == "," {
pos += 1;
continue;
}
group_by_fields.push(tokens[pos].clone());
pos += 1;
}
} else {
// Anything else (HAVING, ORDER BY, LIMIT, ...) is consumed and ignored.
pos += 1;
}
}
// A GROUP BY without any reducer call is not supported.
if reducers.is_empty() {
return None;
}
Some(ParsedAggregate {
where_filter,
group_by_fields,
reducers,
})
}
/// Tries to read one aggregate call `FN([arg [, extra]]) [AS alias]` starting at
/// `tokens[pos]`.
///
/// Recognised names are COUNT, SUM, AVG, MIN, MAX, STDDEV, COUNT_DISTINCT, QUANTILE,
/// ARRAY_AGG (emitted as `TOLIST`) and FIRST_VALUE, matched case-insensitively. `COUNT(*)`
/// has no field. A non-numeric second argument is silently dropped. Without `AS`, the
/// alias is the lowercased SQL name. Returns the reducer and the position after the
/// call, or `None` when `tokens[pos]` does not start a well-formed call.
fn try_parse_aggregate_fn(tokens: &[String], pos: usize) -> Option<(AggReducer, usize)> {
    let name = tokens.get(pos)?.to_ascii_uppercase();
    // SQL name -> RediSearch reducer name; only ARRAY_AGG is renamed.
    let function = match name.as_str() {
        "ARRAY_AGG" => "TOLIST",
        same @ ("COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "STDDEV" | "COUNT_DISTINCT"
        | "QUANTILE" | "FIRST_VALUE") => same,
        _ => return None,
    };
    if !tok_eq(tokens, pos + 1, "(") {
        return None;
    }
    let mut cursor = pos + 2;
    let mut field: Option<String> = None;
    let mut extra_arg: Option<f64> = None;
    if name == "COUNT" && tok_eq(tokens, cursor, "*") {
        cursor += 1;
    } else if let Some(first) = tokens.get(cursor).filter(|t| t.as_str() != ")") {
        field = Some(first.clone());
        cursor += 1;
        if tok_eq(tokens, cursor, ",") {
            cursor += 1;
            if let Some(second) = tokens.get(cursor).filter(|t| t.as_str() != ")") {
                extra_arg = second.parse::<f64>().ok();
                cursor += 1;
            }
        }
    }
    if !tok_eq(tokens, cursor, ")") {
        return None;
    }
    cursor += 1;
    let alias = if tok_eq(tokens, cursor, "AS") {
        let alias = tokens.get(cursor + 1)?.clone();
        cursor += 2;
        alias
    } else {
        name.to_lowercase()
    };
    Some((
        AggReducer {
            function: function.to_owned(),
            field,
            alias,
            extra_arg,
        },
        cursor,
    ))
}
/// A `vector_distance(field, :param) [AS alias]` call found in a select list.
#[derive(Debug, Clone)]
struct VectorFuncCall {
    /// Vector field to search.
    field: String,
    /// Placeholder name holding the query vector, without the leading `:`.
    param_name: String,
    /// Name of the distance column (default `vector_distance`).
    alias: String,
}
/// A KNN vector search translated into `FT.SEARCH` pieces.
#[derive(Debug, Clone)]
struct ParsedVectorSelect {
    vector_fn: VectorFuncCall,
    /// Plain fields listed next to the vector call.
    return_fields: Vec<String>,
    /// Pre-filter from WHERE; `None` means `*`.
    where_filter: Option<String>,
    /// Number of neighbours (K), taken from LIMIT.
    knn_num: usize,
    /// Query vector bytes, when the placeholder was bound to `SqlParam::Bytes`.
    vector_blob: Option<Vec<u8>>,
}
impl ParsedVectorSelect {
    /// Renders `<filter>=>[KNN <k> @<field> $vector AS <alias>]`.
    fn to_knn_query_string(&self) -> String {
        let prefilter = match &self.where_filter {
            Some(filter) => filter.as_str(),
            None => "*",
        };
        let VectorFuncCall { field, alias, .. } = &self.vector_fn;
        format!(
            "{prefilter}=>[KNN {} @{field} $vector AS {alias}]",
            self.knn_num
        )
    }
    /// The `$vector` query parameter, if a blob was bound.
    fn params(&self) -> Vec<QueryParam> {
        self.vector_blob
            .as_ref()
            .map(|blob| {
                vec![QueryParam {
                    name: "vector".to_owned(),
                    value: QueryParamValue::Binary(blob.clone()),
                }]
            })
            .unwrap_or_default()
    }
}
/// Parses a KNN vector search of the form
/// `SELECT <fields>, vector_distance(<field>, :<param>) [AS <alias>] FROM <index>
/// [WHERE ...] [ORDER BY ...] [LIMIT k [OFFSET m]]`
/// (`cosine_distance` is accepted as a synonym).
///
/// Returns `None` when there is no vector-distance call or it is malformed.
/// * `LIMIT k` becomes the KNN `k` (default 10).
/// * The OFFSET value is skipped and not recorded.
/// * The ORDER BY clause is skipped up to the next `LIMIT` or `WHERE`.
/// * The query vector is taken from `params[<param>]` only when that entry is
///   `SqlParam::Bytes`.
fn parse_vector_select(
sql: &str,
params: &HashMap<String, SqlParam>,
) -> Option<ParsedVectorSelect> {
let tokens = tokenize(sql);
if tokens.is_empty() {
return None;
}
let mut pos = 0;
if !tok_eq(&tokens, pos, "SELECT") {
return None;
}
pos += 1;
let mut vector_fn: Option<VectorFuncCall> = None;
let mut return_fields: Vec<String> = Vec::new();
while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
if tokens[pos] == "," {
pos += 1;
continue;
}
let lower = tokens[pos].to_ascii_lowercase();
// A malformed vector call aborts the whole parse (note the `?`).
if (lower == "vector_distance" || lower == "cosine_distance")
&& tok_eq(&tokens, pos + 1, "(")
{
let parsed = try_parse_vector_fn_call(&tokens, pos)?;
vector_fn = Some(parsed.0);
pos = parsed.1;
continue;
}
// Skip `AS` together with the alias token that follows it.
if tokens[pos].eq_ignore_ascii_case("AS") {
pos += 1; if pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
pos += 1; }
continue;
}
// `*` is not recorded as a return field.
if !tokens[pos].eq_ignore_ascii_case("*") {
return_fields.push(tokens[pos].clone());
}
pos += 1;
}
let vector_fn = vector_fn?;
if !tok_eq(&tokens, pos, "FROM") {
return None;
}
pos += 1;
if pos >= tokens.len() {
return None;
}
// Skip the index name that follows FROM.
pos += 1;
let mut where_filter: Option<String> = None;
let mut knn_num: usize = 10;
while pos < tokens.len() {
if tok_eq(&tokens, pos, "WHERE") {
pos += 1;
let (filter_str, next) = parse_where_clause(&tokens, pos)?;
where_filter = Some(filter_str);
pos = next;
} else if tok_eq(&tokens, pos, "ORDER") {
// Ordering is implied by KNN distance, so the clause is skipped.
while pos < tokens.len()
&& !tok_eq(&tokens, pos, "LIMIT")
&& !tok_eq(&tokens, pos, "WHERE")
{
pos += 1;
}
} else if tok_eq(&tokens, pos, "LIMIT") {
pos += 1;
knn_num = parse_usize(&tokens, pos)?;
pos += 1;
// OFFSET and its value are skipped; KNN has no offset here.
if tok_eq(&tokens, pos, "OFFSET") {
pos += 2;
}
} else {
pos += 1;
}
}
let vector_blob = params.get(&vector_fn.param_name).and_then(|p| {
if let SqlParam::Bytes(b) = p {
Some(b.clone())
} else {
None
}
});
Some(ParsedVectorSelect {
vector_fn,
return_fields,
where_filter,
knn_num,
vector_blob,
})
}
/// Reads `<name>(<field>, <param>) [AS <alias>]` starting at `tokens[pos]`.
///
/// The function name itself is not checked here; the caller already matched it. A
/// leading `:` is stripped from the parameter token. The alias defaults to
/// `vector_distance`. Returns the call and the position after it, or `None` when the
/// punctuation does not match or the input is truncated.
fn try_parse_vector_fn_call(tokens: &[String], pos: usize) -> Option<(VectorFuncCall, usize)> {
    // The shortest accepted call is six tokens: name ( field , param )
    if tokens.len() <= pos + 5 {
        return None;
    }
    let punctuation_ok = tok_eq(tokens, pos + 1, "(")
        && tok_eq(tokens, pos + 3, ",")
        && tok_eq(tokens, pos + 5, ")");
    if !punctuation_ok {
        return None;
    }
    let field = tokens[pos + 2].clone();
    let raw_param = &tokens[pos + 4];
    let param_name = raw_param
        .strip_prefix(':')
        .unwrap_or(raw_param.as_str())
        .to_string();
    let mut next = pos + 6;
    let alias = if tok_eq(tokens, next, "AS") {
        let alias = tokens.get(next + 1)?.clone();
        next += 2;
        alias
    } else {
        "vector_distance".to_string()
    };
    Some((
        VectorFuncCall {
            field,
            param_name,
            alias,
        },
        next,
    ))
}
/// A search whose WHERE clause contains a `geo_distance(...) < r` condition.
#[derive(Debug, Clone)]
struct ParsedGeoWhere {
    /// The geo condition, sent separately as a GEOFILTER.
    geofilter: super::GeoFilter,
    /// The remaining WHERE conditions; `None` when there are none.
    non_geo_filter: Option<String>,
    /// Selected field names; empty for `SELECT *`.
    return_fields: Vec<String>,
}
impl ParsedGeoWhere {
    /// The query text: the non-geo conditions, or `*` when there are none.
    fn filter_string(&self) -> String {
        match self.non_geo_filter.as_deref() {
            Some(filter) => filter.to_owned(),
            None => String::from("*"),
        }
    }
}
/// A select of a `geo_distance(field, POINT(lon, lat))` computed column.
#[derive(Debug, Clone)]
struct ParsedGeoAggregate {
    /// Geo field to measure from.
    geo_field: String,
    /// Reference point longitude.
    lon: f64,
    /// Reference point latitude.
    lat: f64,
    /// Name of the computed distance column.
    alias: String,
    /// Translated WHERE filter; `None` means `*`.
    where_filter: Option<String>,
}
impl ParsedGeoAggregate {
    /// Builds `FT.AGGREGATE <index> <filter> LOAD 1 @<field>
    /// APPLY geodistance(@<field>, <lon>, <lat>) AS <alias>`.
    fn build_cmd(&self, index_name: &str) -> redis::Cmd {
        let filter = self.where_filter.as_deref().unwrap_or("*");
        let loaded_field = format!("@{}", self.geo_field);
        let distance_expr = format!("geodistance({}, {}, {})", loaded_field, self.lon, self.lat);
        let mut cmd = redis::cmd("FT.AGGREGATE");
        cmd.arg(index_name).arg(filter);
        cmd.arg("LOAD").arg(1_u32).arg(loaded_field);
        cmd.arg("APPLY")
            .arg(distance_expr)
            .arg("AS")
            .arg(&self.alias);
        cmd
    }
}
/// Parses `SELECT <fields|*> FROM <index> WHERE ...` whose WHERE clause contains a
/// `geo_distance(...) < radius` condition. The geo condition becomes a GEOFILTER; the
/// other conditions are kept as a query filter.
///
/// Returns `None` without a WHERE clause, without a geo condition, or when any
/// condition fails to parse.
/// * Only `AND` is recognised as a connective; any other token starts a new condition.
/// * If several geo conditions appear, the last one wins.
/// * Parsing stops at ORDER / LIMIT / GROUP / HAVING. NOTE(review): those clauses are
///   not captured in `ParsedGeoWhere`, so ORDER BY / LIMIT are dropped for geo searches.
fn parse_geo_where(sql: &str) -> Option<ParsedGeoWhere> {
let tokens = tokenize(sql);
if tokens.is_empty() {
return None;
}
let mut pos = 0;
if !tok_eq(&tokens, pos, "SELECT") {
return None;
}
pos += 1;
let mut return_fields: Vec<String> = Vec::new();
if tok_eq(&tokens, pos, "*") {
pos += 1;
} else {
// Collect select-list fields, skipping commas and `AS <alias>` pairs.
while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
if tokens[pos] == "," || tokens[pos].eq_ignore_ascii_case("AS") {
pos += 1;
if pos > 1
&& tokens[pos - 1].eq_ignore_ascii_case("AS")
&& pos < tokens.len()
&& !tok_eq(&tokens, pos, "FROM")
{
pos += 1;
}
continue;
}
return_fields.push(tokens[pos].clone());
pos += 1;
}
}
if !tok_eq(&tokens, pos, "FROM") {
return None;
}
pos += 1;
if pos >= tokens.len() {
return None;
}
// Skip the index name that follows FROM; a WHERE must come right after it.
pos += 1;
if !tok_eq(&tokens, pos, "WHERE") {
return None;
}
pos += 1;
let mut non_geo_conditions: Vec<String> = Vec::new();
let mut geofilter: Option<super::GeoFilter> = None;
loop {
if pos >= tokens.len() {
break;
}
let upper = tokens[pos].to_ascii_uppercase();
if matches!(upper.as_str(), "ORDER" | "LIMIT" | "GROUP" | "HAVING") {
break;
}
if upper == "AND" {
pos += 1;
continue;
}
if tokens[pos].eq_ignore_ascii_case("geo_distance") && tok_eq(&tokens, pos + 1, "(") {
let (gf, next) = parse_geo_distance_where(&tokens, pos)?;
geofilter = Some(gf);
pos = next;
continue;
}
let (filter, next) = parse_single_condition(&tokens, pos)?;
non_geo_conditions.push(filter);
pos = next;
}
let geofilter = geofilter?;
// Multiple remaining conditions are intersected (space-joined) inside parentheses.
let non_geo_filter = if non_geo_conditions.is_empty() {
None
} else if non_geo_conditions.len() == 1 {
Some(non_geo_conditions.into_iter().next().unwrap())
} else {
Some(format!("({})", non_geo_conditions.join(" ")))
};
Some(ParsedGeoWhere {
geofilter,
non_geo_filter,
return_fields,
})
}
/// Parses the WHERE condition
/// `geo_distance(<field>, POINT(<lon>, <lat>), '<unit>') < <radius>` into a GEOFILTER.
/// The bare form `geo_distance(<field>, <lon>[,] <lat>, '<unit>')` is accepted too.
///
/// Both `<` and `<=` are accepted; GEOFILTER has no strict/inclusive distinction, so
/// both map to the same filter. Returns the filter and the position after the radius,
/// or `None` when the condition is malformed. Truncated input also returns `None`;
/// every token read goes through `get`, so it can no longer panic.
fn parse_geo_distance_where(tokens: &[String], pos: usize) -> Option<(super::GeoFilter, usize)> {
    /// Parses `tokens[i]` as a number; `None` if missing or not numeric.
    fn num_at(tokens: &[String], i: usize) -> Option<f64> {
        tokens.get(i)?.parse().ok()
    }
    let mut p = pos;
    if !tok_eq(tokens, p, "geo_distance") || !tok_eq(tokens, p + 1, "(") {
        return None;
    }
    p += 2;
    let field = tokens.get(p)?.clone();
    p += 1;
    if !tok_eq(tokens, p, ",") {
        return None;
    }
    p += 1;
    let (lon, lat);
    if tok_eq(tokens, p, "POINT") {
        if !tok_eq(tokens, p + 1, "(") {
            return None;
        }
        p += 2;
        lon = num_at(tokens, p)?;
        p += 1;
        if !tok_eq(tokens, p, ",") {
            return None;
        }
        p += 1;
        lat = num_at(tokens, p)?;
        p += 1;
        if !tok_eq(tokens, p, ")") {
            return None;
        }
        p += 1;
    } else {
        lon = num_at(tokens, p)?;
        p += 1;
        if tok_eq(tokens, p, ",") {
            p += 1;
        }
        lat = num_at(tokens, p)?;
        p += 1;
    }
    if !tok_eq(tokens, p, ",") {
        return None;
    }
    p += 1;
    let unit = unquote(tokens.get(p)?);
    p += 1;
    if !tok_eq(tokens, p, ")") {
        return None;
    }
    p += 1;
    if !(tok_eq(tokens, p, "<") || tok_eq(tokens, p, "<=")) {
        return None;
    }
    p += 1;
    let radius = num_at(tokens, p)?;
    p += 1;
    Some((
        super::GeoFilter {
            field,
            lon,
            lat,
            radius,
            unit,
        },
        p,
    ))
}
/// Parses `SELECT geo_distance(<field>, POINT(<lon>, <lat>)) [AS <alias>] ... FROM <index>
/// [WHERE ...]` into an `FT.AGGREGATE ... APPLY geodistance(...)` description.
///
/// Returns `None` in these cases:
/// * the select list has no `geo_distance(` call;
/// * the call is not in the `POINT(lon, lat)` form;
/// * any part is malformed or truncated. Every token read uses `get`, so truncated
///   input no longer panics.
///
/// Other select-list items are ignored. The alias defaults to `distance`. If several
/// `geo_distance` calls appear, the last one wins. After the index name only `WHERE`
/// is interpreted; other tokens are skipped.
fn parse_geo_aggregate(sql: &str) -> Option<ParsedGeoAggregate> {
    let tokens = tokenize(sql);
    if !tok_eq(&tokens, 0, "SELECT") {
        return None;
    }
    let mut pos = 1;
    // (field, lon, lat, alias) of the most recent geo_distance call.
    let mut geo: Option<(String, f64, f64, String)> = None;
    while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
        if tokens[pos] == "," {
            pos += 1;
            continue;
        }
        if tokens[pos].eq_ignore_ascii_case("geo_distance") && tok_eq(&tokens, pos + 1, "(") {
            pos += 2;
            let field = tokens.get(pos)?.clone();
            pos += 1;
            if !tok_eq(&tokens, pos, ",")
                || !tok_eq(&tokens, pos + 1, "POINT")
                || !tok_eq(&tokens, pos + 2, "(")
            {
                return None;
            }
            pos += 3;
            let lon = tokens.get(pos)?.parse::<f64>().ok()?;
            pos += 1;
            // The comma between longitude and latitude is optional here.
            if tok_eq(&tokens, pos, ",") {
                pos += 1;
            }
            let lat = tokens.get(pos)?.parse::<f64>().ok()?;
            pos += 1;
            // One `)` closes POINT(...), the next closes geo_distance(...).
            if !tok_eq(&tokens, pos, ")") || !tok_eq(&tokens, pos + 1, ")") {
                return None;
            }
            pos += 2;
            let alias = if tok_eq(&tokens, pos, "AS") {
                let alias = tokens.get(pos + 1)?.clone();
                pos += 2;
                alias
            } else {
                "distance".to_string()
            };
            geo = Some((field, lon, lat, alias));
            continue;
        }
        if tokens[pos].eq_ignore_ascii_case("AS") {
            // Skip `AS` and the alias that follows it.
            pos += 1;
            if pos < tokens.len() {
                pos += 1;
            }
            continue;
        }
        pos += 1;
    }
    let (geo_field, lon, lat, alias) = geo?;
    if !tok_eq(&tokens, pos, "FROM") {
        return None;
    }
    pos += 1;
    if pos >= tokens.len() {
        return None;
    }
    // Skip the index name that follows FROM.
    pos += 1;
    let mut where_filter: Option<String> = None;
    while pos < tokens.len() {
        if tok_eq(&tokens, pos, "WHERE") {
            let (filter_str, next) = parse_where_clause(&tokens, pos + 1)?;
            where_filter = Some(filter_str);
            pos = next;
        } else {
            pos += 1;
        }
    }
    Some(ParsedGeoAggregate {
        geo_field,
        lon,
        lat,
        alias,
        where_filter,
    })
}
/// Translates a WHERE clause body, starting at `tokens[pos]`, into a RediSearch filter.
///
/// * Conditions joined by `AND` become a space-separated intersection.
/// * `OR` splits the clause into alternatives joined with `|`, so `AND` binds tighter
///   than `OR`, as in SQL.
/// * A group with several conditions is wrapped in parentheses, and so is the whole
///   expression when it has several alternatives.
/// * Stray `OR`s (leading, trailing or doubled) are ignored rather than producing an
///   empty `()` alternative, which RediSearch rejects.
///
/// Parsing stops, without consuming, at `ORDER`, `LIMIT`, `GROUP` or `HAVING`.
/// Parenthesised sub-expressions are not supported.
///
/// Returns the filter (`*` when no condition was found) and the position of the first
/// unconsumed token. Returns `None` if any condition fails to parse.
fn parse_where_clause(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
    let mut or_groups: Vec<Vec<String>> = Vec::new();
    let mut current_and_group: Vec<String> = Vec::new();
    while let Some(token) = tokens.get(pos) {
        match token.to_ascii_uppercase().as_str() {
            "ORDER" | "LIMIT" | "GROUP" | "HAVING" => break,
            "AND" => pos += 1,
            "OR" => {
                pos += 1;
                if !current_and_group.is_empty() {
                    or_groups.push(std::mem::take(&mut current_and_group));
                }
            }
            _ => {
                let (filter, next) = parse_single_condition(tokens, pos)?;
                current_and_group.push(filter);
                pos = next;
            }
        }
    }
    if !current_and_group.is_empty() {
        or_groups.push(current_and_group);
    }
    let mut alternatives: Vec<String> = or_groups
        .into_iter()
        .map(|mut group| {
            if group.len() == 1 {
                group.swap_remove(0)
            } else {
                format!("({})", group.join(" "))
            }
        })
        .collect();
    let filter = match alternatives.len() {
        0 => "*".to_owned(),
        1 => alternatives.swap_remove(0),
        _ => format!("({})", alternatives.join(" | ")),
    };
    Some((filter, pos))
}
/// Translates one comparison, starting at the field name in `tokens[pos]`, into a
/// RediSearch filter expression.
///
/// Returns the filter and the position after the condition, or `None` when the
/// condition is malformed, truncated, or uses an unsupported operator. Values may be
/// quoted; numbers and ISO dates become numeric ranges.
///
/// * `f BETWEEN a AND b` becomes `@f:[a b]`.
/// * `f IN (..)` becomes `@f:{a|b}`; `f NOT IN (..)` becomes `(-@f:{a|b})` (tag match).
/// * `f LIKE p` becomes `@f:(p)`; `f NOT LIKE p` becomes `(-@f:(p))`. `%` maps to `*`.
/// * `f != v` and `f <> v` become a negated numeric range or a negated tag.
/// * `f = v` becomes a numeric range, a wildcard or phrase text match, or a tag.
/// * `f < v`, `f > v`, `f <= v`, `f >= v` become open or closed numeric ranges.
fn parse_single_condition(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
    let field = tokens.get(pos)?;
    pos += 1;
    let mut op: &str = tokens.get(pos)?;
    pos += 1;
    // The standard SQL `<>` is tokenized as `<` followed by `>`; treat it as `!=`.
    if op == "<" && tok_eq(tokens, pos, ">") {
        op = "!=";
        pos += 1;
    }
    if op.eq_ignore_ascii_case("BETWEEN") {
        let lo = parse_numeric_or_date_literal(tokens, pos)?;
        pos += 1;
        if !tok_eq(tokens, pos, "AND") {
            return None;
        }
        pos += 1;
        let hi = parse_numeric_or_date_literal(tokens, pos)?;
        pos += 1;
        return Some((
            format!("@{}:[{} {}]", field, format_num(lo), format_num(hi)),
            pos,
        ));
    }
    let negated = op.eq_ignore_ascii_case("NOT");
    if negated && tok_eq(tokens, pos, "IN") {
        let (tags, next) = parse_paren_tag_list(tokens, pos + 1)?;
        return Some((format!("(-@{}:{{{}}})", field, tags), next));
    }
    if op.eq_ignore_ascii_case("IN") {
        let (tags, next) = parse_paren_tag_list(tokens, pos)?;
        return Some((format!("@{}:{{{}}}", field, tags), next));
    }
    if op.eq_ignore_ascii_case("LIKE") {
        let pattern = sql_like_to_redis(&unquote(tokens.get(pos)?));
        return Some((format!("@{}:({})", field, pattern), pos + 1));
    }
    if negated && tok_eq(tokens, pos, "LIKE") {
        let pattern = sql_like_to_redis(&unquote(tokens.get(pos + 1)?));
        return Some((format!("(-@{}:({}))", field, pattern), pos + 2));
    }
    if op == "!=" {
        let value = unquote(tokens.get(pos)?);
        pos += 1;
        let as_number = if is_numeric_str(&value) {
            Some(value.parse::<f64>().ok()?)
        } else {
            try_parse_date(&value)
        };
        let filter = match as_number {
            Some(n) => format!("(-@{}:[{} {}])", field, format_num(n), format_num(n)),
            None => format!("(-@{}:{{{}}})", field, escape_tag(&value)),
        };
        return Some((filter, pos));
    }
    // `< =` / `> =` written with a space arrive as two tokens; `<=` / `>=` arrive as one.
    let next_tok = tokens.get(pos)?;
    let (real_op, value_str) = if (op == "<" || op == ">") && next_tok == "=" {
        let value = unquote(tokens.get(pos + 1)?);
        pos += 2;
        (format!("{}=", op), value)
    } else {
        pos += 1;
        (op.to_string(), unquote(next_tok))
    };
    let filter = match real_op.as_str() {
        "=" => {
            if is_numeric_str(&value_str) {
                let n: f64 = value_str.parse().ok()?;
                format!("@{}:[{} {}]", field, format_num(n), format_num(n))
            } else if let Some(ts) = try_parse_date(&value_str) {
                format!("@{}:[{} {}]", field, format_num(ts), format_num(ts))
            } else if value_str.contains('*') || value_str.contains('%') {
                // Wildcard-looking values are passed through as a text query.
                format!("@{}:({})", field, value_str)
            } else if value_str.contains(' ') {
                // Multi-word values become an exact phrase.
                format!("@{}:(\"{}\")", field, value_str)
            } else {
                format!("@{}:{{{}}}", field, escape_tag(&value_str))
            }
        }
        "<" => format!(
            "@{}:[-inf ({}]",
            field,
            format_num(parse_num_or_date(&value_str)?)
        ),
        ">" => format!(
            "@{}:[({} +inf]",
            field,
            format_num(parse_num_or_date(&value_str)?)
        ),
        "<=" => format!(
            "@{}:[-inf {}]",
            field,
            format_num(parse_num_or_date(&value_str)?)
        ),
        ">=" => format!(
            "@{}:[{} +inf]",
            field,
            format_num(parse_num_or_date(&value_str)?)
        ),
        _ => return None,
    };
    Some((filter, pos))
}
/// Parses `( v1 , v2 , ... )` starting at the opening parenthesis. Returns the values
/// unquoted, tag-escaped and joined with `|`, plus the position after `)`. Returns
/// `None` if `(` is missing or the list is unterminated.
fn parse_paren_tag_list(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
    if !tok_eq(tokens, pos, "(") {
        return None;
    }
    pos += 1;
    let mut escaped = Vec::new();
    loop {
        let token = tokens.get(pos)?;
        pos += 1;
        match token.as_str() {
            ")" => break,
            "," => continue,
            _ => escaped.push(escape_tag(&unquote(token))),
        }
    }
    Some((escaped.join("|"), pos))
}
/// Splits SQL text into tokens. ASCII whitespace separates tokens and is dropped.
///
/// Token kinds, in the order they are tried:
/// * single-quoted string literals, kept with their quotes and with `''` escapes
///   preserved as-is (an unterminated literal still gets a closing quote);
/// * `:name` placeholders, including the colon;
/// * ASCII identifiers / keywords (`[A-Za-z_][A-Za-z0-9_]*`);
/// * numbers: an optional `-` immediately followed by a digit, then digits and dots;
/// * the two-character operators `!=`, `<=`, `>=`;
/// * any other single character (including each non-ASCII character).
///
/// NOTE(review): a `-` directly before a digit always starts a negative number, so
/// `a-1` tokenizes as `a`, `-1`; `<>` comes out as `<`, `>`.
fn tokenize(sql: &str) -> Vec<String> {
let mut tokens = Vec::new();
let chars: Vec<char> = sql.chars().collect();
let len = chars.len();
let mut i = 0;
while i < len {
if chars[i].is_ascii_whitespace() {
i += 1;
continue;
}
// Quoted string literal; `''` inside it is an escaped quote.
if chars[i] == '\'' {
let mut s = String::new();
s.push('\'');
i += 1;
while i < len {
if chars[i] == '\'' {
if i + 1 < len && chars[i + 1] == '\'' {
s.push('\'');
s.push('\'');
i += 2;
} else {
break;
}
} else {
s.push(chars[i]);
i += 1;
}
}
s.push('\'');
// Consume the closing quote if there was one.
if i < len {
i += 1;
}
tokens.push(s);
continue;
}
// `:name` placeholder left unsubstituted (e.g. a Bytes vector parameter).
if chars[i] == ':'
&& i + 1 < len
&& (chars[i + 1].is_ascii_alphabetic() || chars[i + 1] == '_')
{
let start = i;
i += 1; while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
i += 1;
}
tokens.push(chars[start..i].iter().collect());
continue;
}
// Identifier or keyword.
if chars[i].is_ascii_alphabetic() || chars[i] == '_' {
let start = i;
while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
i += 1;
}
tokens.push(chars[start..i].iter().collect());
continue;
}
// Number, optionally negative; no exponent syntax.
if chars[i].is_ascii_digit()
|| (chars[i] == '-' && i + 1 < len && chars[i + 1].is_ascii_digit())
{
let start = i;
if chars[i] == '-' {
i += 1;
}
while i < len && (chars[i].is_ascii_digit() || chars[i] == '.') {
i += 1;
}
tokens.push(chars[start..i].iter().collect());
continue;
}
// Two-character comparison operators.
if i + 1 < len {
let two: String = chars[i..i + 2].iter().collect();
if two == "!=" || two == "<=" || two == ">=" {
tokens.push(two);
i += 2;
continue;
}
}
// Anything else is a one-character token.
tokens.push(chars[i].to_string());
i += 1;
}
tokens
}
/// ASCII case-insensitive comparison of `tokens[pos]` with `expected`; `false` when
/// `pos` is past the end.
fn tok_eq(tokens: &[String], pos: usize, expected: &str) -> bool {
    matches!(tokens.get(pos), Some(token) if token.eq_ignore_ascii_case(expected))
}
/// Parses `tokens[pos]` as a `usize`; `None` when missing or not a non-negative integer.
fn parse_usize(tokens: &[String], pos: usize) -> Option<usize> {
    tokens.get(pos).and_then(|token| token.parse::<usize>().ok())
}
/// Reads `tokens[pos]` (quoted or not) as a number, falling back to an ISO date
/// converted to Unix seconds. `None` when missing or neither.
fn parse_numeric_or_date_literal(tokens: &[String], pos: usize) -> Option<f64> {
    let literal = unquote(tokens.get(pos)?);
    literal
        .parse::<f64>()
        .ok()
        .or_else(|| try_parse_date(&literal))
}
/// Reads `s` as a number, falling back to an ISO date converted to Unix seconds.
fn parse_num_or_date(s: &str) -> Option<f64> {
    s.parse::<f64>().ok().or_else(|| try_parse_date(s))
}
fn try_parse_date(s: &str) -> Option<f64> {
if s.len() == 10 && s.as_bytes().get(4) == Some(&b'-') && s.as_bytes().get(7) == Some(&b'-') {
let year: i32 = s[0..4].parse().ok()?;
let month: u32 = s[5..7].parse().ok()?;
let day: u32 = s[8..10].parse().ok()?;
if !(1..=12).contains(&month) || !(1..=31).contains(&day) {
return None;
}
let ts = date_to_unix_timestamp(year, month, day)?;
return Some(ts as f64);
}
if s.len() >= 19 && (s.as_bytes().get(10) == Some(&b'T') || s.as_bytes().get(10) == Some(&b' '))
{
let year: i32 = s[0..4].parse().ok()?;
let month: u32 = s[5..7].parse().ok()?;
let day: u32 = s[8..10].parse().ok()?;
let hour: u32 = s[11..13].parse().ok()?;
let min: u32 = s[14..16].parse().ok()?;
let sec: u32 = s[17..19].parse().ok()?;
if !(1..=12).contains(&month) || !(1..=31).contains(&day) {
return None;
}
if hour > 23 || min > 59 || sec > 59 {
return None;
}
let day_ts = date_to_unix_timestamp(year, month, day)?;
let ts = day_ts + (hour as i64) * 3600 + (min as i64) * 60 + (sec as i64);
return Some(ts as f64);
}
None
}
/// Converts a proleptic-Gregorian calendar date (UTC midnight) into Unix seconds.
///
/// Returns `None` when `month` is outside `1..=12` or `day` does not exist in that
/// month (e.g. Feb 30, or Feb 29 in a non-leap year). Previously such dates silently
/// rolled over into the next month, and an out-of-range month could index past the
/// month table.
fn date_to_unix_timestamp(year: i32, month: u32, day: u32) -> Option<i64> {
    fn is_leap(y: i32) -> bool {
        (y % 4 == 0 && y % 100 != 0) || y % 400 == 0
    }
    /// Length of `month` (1-based, assumed valid) in `year`.
    fn days_in_month(year: i32, month: u32) -> u32 {
        const DAYS_IN_MONTH: [u32; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
        let base = DAYS_IN_MONTH[(month - 1) as usize];
        if month == 2 && is_leap(year) {
            base + 1
        } else {
            base
        }
    }
    if !(1..=12).contains(&month) {
        return None;
    }
    if day == 0 || day > days_in_month(year, month) {
        return None;
    }
    // Whole years between 1970 and `year` (negative for years before the epoch).
    let year_days: i64 = if year >= 1970 {
        (1970..year)
            .map(|y| if is_leap(y) { 366 } else { 365 })
            .sum()
    } else {
        -(year..1970)
            .map(|y| if is_leap(y) { 366_i64 } else { 365 })
            .sum::<i64>()
    };
    let month_days: i64 = (1..month).map(|m| i64::from(days_in_month(year, m))).sum();
    let days = year_days + month_days + i64::from(day) - 1;
    Some(days * 86_400)
}
/// Converts a SQL `LIKE` pattern to RediSearch wildcard syntax by mapping every `%`
/// to `*`. Other characters, including `_`, are left unchanged.
fn sql_like_to_redis(pattern: &str) -> String {
    pattern
        .chars()
        .map(|c| if c == '%' { '*' } else { c })
        .collect()
}
/// Strips one pair of surrounding single quotes and collapses `''` escapes to `'`.
/// Text that is not wrapped in quotes (including a lone `'`) is returned unchanged.
fn unquote(s: &str) -> String {
    match s.strip_prefix('\'').and_then(|rest| rest.strip_suffix('\'')) {
        Some(inner) => inner.replace("''", "'"),
        None => s.to_string(),
    }
}
/// Escapes a value for use inside a RediSearch tag filter `@field:{...}`.
///
/// Every ASCII punctuation character except `_`, and every whitespace character, gets
/// a backslash. Escaping only a handful of characters left `|` (the tag OR separator),
/// `{`/`}`, `@`, `(`, `\` and friends raw, so a value could split into several tags or
/// terminate the tag set early.
fn escape_tag(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if (ch.is_ascii_punctuation() && ch != '_') || ch.is_whitespace() {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Whether `s` is a finite decimal number.
///
/// `f64::from_str` also accepts `inf`, `infinity` and `NaN` (in any case). Those are
/// rejected here, so string values such as `'nan'` are treated as text instead of
/// becoming a meaningless numeric range like `[NaN NaN]`.
fn is_numeric_str(s: &str) -> bool {
    s.parse::<f64>().map_or(false, f64::is_finite)
}
/// Formats a number for a RediSearch filter. Integral values have no fractional part
/// (`3`, not `3.0`); everything else (fractions, infinities, NaN) uses `Display`.
fn format_num(n: f64) -> String {
    if n.fract() != 0.0 {
        n.to_string()
    } else {
        format!("{:.0}", n)
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn similar_param_names_no_partial_match() {
let query = SQLQuery::with_params(
"SELECT * FROM idx WHERE id = :id AND product_id = :product_id",
HashMap::from([
("id".to_owned(), SqlParam::Int(123)),
("product_id".to_owned(), SqlParam::Int(456)),
]),
);
let substituted = query.substituted_sql();
assert!(substituted.contains("id = 123"));
assert!(substituted.contains("product_id = 456"));
assert!(!substituted.contains("product_123"));
}
#[test]
fn prefix_param_names() {
let query = SQLQuery::with_params(
"SELECT * FROM idx WHERE user = :user AND user_id = :user_id AND user_name = :user_name",
HashMap::from([
("user".to_owned(), SqlParam::Str("alice".to_owned())),
("user_id".to_owned(), SqlParam::Int(42)),
(
"user_name".to_owned(),
SqlParam::Str("Alice Smith".to_owned()),
),
]),
);
let substituted = query.substituted_sql();
assert!(substituted.contains("user = 'alice'"));
assert!(substituted.contains("user_id = 42"));
assert!(substituted.contains("user_name = 'Alice Smith'"));
assert!(!substituted.contains("'alice'_id"));
assert!(!substituted.contains("'alice'_name"));
}
#[test]
fn suffix_param_names() {
let query = SQLQuery::with_params(
"SELECT * FROM idx WHERE vec = :vec AND query_vec = :query_vec",
HashMap::from([
("vec".to_owned(), SqlParam::Float(1.0)),
("query_vec".to_owned(), SqlParam::Float(2.0)),
]),
);
let substituted = query.substituted_sql();
assert!(substituted.contains("vec = 1") || substituted.contains("vec = 1.0"));
assert!(substituted.contains("query_vec = 2") || substituted.contains("query_vec = 2.0"));
}
#[test]
fn single_quote_in_value() {
let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
.with_param("name", SqlParam::Str("O'Brien".to_owned()));
let substituted = query.substituted_sql();
assert!(substituted.contains("name = 'O''Brien'"));
}
#[test]
fn multiple_quotes_in_value() {
let query = SQLQuery::new("SELECT * FROM idx WHERE phrase = :phrase")
.with_param("phrase", SqlParam::Str("It's a 'test' string".to_owned()));
let substituted = query.substituted_sql();
assert!(substituted.contains("phrase = 'It''s a ''test'' string'"));
}
#[test]
fn apostrophe_names() {
let cases = [
("McDonald's", "'McDonald''s'"),
("O'Reilly", "'O''Reilly'"),
("D'Angelo", "'D''Angelo'"),
];
for (name, expected) in cases {
let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
.with_param("name", SqlParam::Str(name.to_owned()));
let substituted = query.substituted_sql();
assert!(
substituted.contains(&format!("name = {expected}")),
"Failed for {name}: got {substituted}"
);
}
}
#[test]
fn multiple_occurrences_same_param() {
let query = SQLQuery::new("SELECT * FROM idx WHERE category = :cat OR subcategory = :cat")
.with_param("cat", SqlParam::Str("electronics".to_owned()));
let substituted = query.substituted_sql();
assert_eq!(substituted.matches("'electronics'").count(), 2);
}
/// An empty string value becomes an empty quoted literal (`''`).
#[test]
fn empty_string_value() {
    let rendered = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
        .with_param("name", SqlParam::Str("".to_owned()))
        .substituted_sql();
    assert!(rendered.contains("name = ''"));
}
/// Integer and float parameters are inlined unquoted.
#[test]
fn numeric_types() {
    let params = HashMap::from([
        ("count".to_owned(), SqlParam::Int(42)),
        ("price".to_owned(), SqlParam::Float(99.99)),
    ]);
    let q = SQLQuery::with_params(
        "SELECT * FROM idx WHERE count = :count AND price = :price",
        params,
    );
    let rendered = q.substituted_sql();
    for fragment in ["count = 42", "price = 99.99"] {
        assert!(rendered.contains(fragment));
    }
}
/// Binary parameters are not inlined into the SQL text; their placeholder is left in place.
#[test]
fn bytes_param_not_substituted() {
    let blob: Vec<u8> = (0u8..4).collect();
    let rendered = SQLQuery::new("SELECT * FROM idx WHERE embedding = :vec")
        .with_param("vec", SqlParam::Bytes(blob))
        .substituted_sql();
    assert!(rendered.contains(":vec"));
}
/// Values containing punctuation, regex metacharacters or backslashes still replace
/// their placeholder.
#[test]
fn special_characters_in_value() {
    let specials = [
        "hello@world.com",
        "path/to/file",
        "price: $100",
        "regex.*pattern",
        "back\\slash",
    ];
    for &value in specials.iter() {
        let substituted = SQLQuery::new("SELECT * FROM idx WHERE field = :field")
            .with_param("field", SqlParam::Str(String::from(value)))
            .substituted_sql();
        let still_has_placeholder = substituted.contains(":field");
        assert!(
            !still_has_placeholder,
            "Failed to substitute for value: {value}"
        );
    }
}
/// With no parameters bound, substitution returns the SQL text unchanged.
#[test]
fn no_params_returns_original() {
    let raw = "SELECT * FROM idx";
    let rendered = SQLQuery::new(raw).substituted_sql();
    assert_eq!(rendered, raw);
}
/// A placeholder with no matching bound parameter is left verbatim.
#[test]
fn unknown_placeholder_kept() {
    let q = SQLQuery::new("SELECT * FROM idx WHERE x = :unknown");
    let q = q.with_param("other", SqlParam::Int(1));
    let rendered = q.substituted_sql();
    assert!(rendered.contains(":unknown"));
}
/// Chained `with_param` calls accumulate parameters of different kinds.
#[test]
fn with_param_builder_pattern() {
    let rendered = SQLQuery::new("SELECT * FROM idx WHERE a = :a AND b = :b")
        .with_param("a", SqlParam::Int(1))
        .with_param("b", SqlParam::Str(String::from("hello")))
        .substituted_sql();
    for fragment in ["a = 1", "b = 'hello'"] {
        assert!(rendered.contains(fragment));
    }
}
/// `sql()` exposes the raw statement text passed to the constructor.
#[test]
fn sql_accessor() {
    let text = "SELECT 1";
    assert_eq!(SQLQuery::new(text).sql(), text);
}
/// `params_map()` reflects parameters added through the builder.
#[test]
fn params_map_accessor() {
    let q = SQLQuery::new("SELECT 1").with_param("x", SqlParam::Int(42));
    let bound = q.params_map();
    assert_eq!(bound.len(), 1);
}
/// A `SELECT *` without a `WHERE` clause matches everything (`*`).
#[test]
fn select_star_no_where_produces_wildcard() {
    let translated = SQLQuery::new("SELECT * FROM products").to_redis_query();
    assert_eq!(translated, "*");
}
/// Projected columns become return fields while the filter stays the wildcard.
#[test]
fn select_specific_fields_sets_return_fields() {
    let q = SQLQuery::new("SELECT title, price FROM products");
    let filter = q.to_redis_query();
    assert_eq!(filter, "*");
    let projected = q.return_fields();
    assert_eq!(projected, vec!["title", "price"]);
}
/// String equality on a tag field becomes a `@field:{value}` tag filter.
#[test]
fn where_tag_equals() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE category = 'electronics'").to_redis_query();
    assert_eq!(translated, "@category:{electronics}");
}
/// `!=` on a tag field becomes a parenthesised negated tag filter.
#[test]
fn where_tag_not_equals() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE category != 'electronics'").to_redis_query();
    assert_eq!(translated, "(-@category:{electronics})");
}
/// `IN (...)` on a tag field joins the alternatives with `|` in a single tag filter.
#[test]
fn where_tag_in() {
    let sql = "SELECT * FROM products WHERE category IN ('books', 'accessories')";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "@category:{books|accessories}");
}
/// `<` maps to a numeric range with an exclusive upper bound.
#[test]
fn where_numeric_less_than() {
    let translated = SQLQuery::new("SELECT * FROM products WHERE price < 50").to_redis_query();
    assert_eq!(translated, "@price:[-inf (50]");
}
/// `>` maps to a numeric range with an exclusive lower bound.
#[test]
fn where_numeric_greater_than() {
    let translated = SQLQuery::new("SELECT * FROM products WHERE price > 100").to_redis_query();
    assert_eq!(translated, "@price:[(100 +inf]");
}
/// Numeric equality maps to a degenerate closed range `[n n]`.
#[test]
fn where_numeric_equals() {
    let translated = SQLQuery::new("SELECT * FROM products WHERE price = 45").to_redis_query();
    assert_eq!(translated, "@price:[45 45]");
}
/// Numeric `!=` negates the degenerate equality range.
#[test]
fn where_numeric_not_equals() {
    let translated = SQLQuery::new("SELECT * FROM products WHERE price != 45").to_redis_query();
    assert_eq!(translated, "(-@price:[45 45])");
}
/// `<=` maps to a numeric range with an inclusive upper bound.
#[test]
fn where_numeric_lte() {
    let translated = SQLQuery::new("SELECT * FROM products WHERE price <= 50").to_redis_query();
    assert_eq!(translated, "@price:[-inf 50]");
}
/// `>=` maps to a numeric range with an inclusive lower bound.
#[test]
fn where_numeric_gte() {
    let translated = SQLQuery::new("SELECT * FROM products WHERE price >= 25").to_redis_query();
    assert_eq!(translated, "@price:[25 +inf]");
}
/// `BETWEEN a AND b` maps to an inclusive numeric range `[a b]`.
#[test]
fn where_between() {
    let sql = "SELECT * FROM products WHERE price BETWEEN 40 AND 60";
    assert_eq!(SQLQuery::new(sql).to_redis_query(), "@price:[40 60]");
}
/// `AND` joins its operands by juxtaposition inside one parenthesised group.
#[test]
fn where_combined_and() {
    let sql = "SELECT * FROM products WHERE category = 'electronics' AND price < 100";
    let translated = SQLQuery::new(sql).to_redis_query();
    let expected = "(@category:{electronics} @price:[-inf (100])";
    assert_eq!(translated, expected);
}
/// `ORDER BY ... ASC` populates `sort_by` with the field and an ascending direction.
#[test]
fn order_by_asc() {
    let q = SQLQuery::new("SELECT title, price FROM products ORDER BY price ASC");
    let sort = q.sort_by().expect("sort_by should be set");
    assert_eq!(sort.field, "price");
    let ascending = matches!(sort.direction, SortDirection::Asc);
    assert!(ascending);
}
/// `ORDER BY ... DESC` populates `sort_by` with the field and a descending direction.
#[test]
fn order_by_desc() {
    let q = SQLQuery::new("SELECT title, price FROM products ORDER BY price DESC");
    let sort = q.sort_by().expect("sort_by should be set");
    assert_eq!(sort.field, "price");
    let descending = matches!(sort.direction, SortDirection::Desc);
    assert!(descending);
}
/// A bare `LIMIT n` yields that page size with a zero offset.
#[test]
fn limit_clause() {
    let q = SQLQuery::new("SELECT title FROM products LIMIT 3");
    let page = q.limit().expect("limit should be set");
    assert_eq!((page.num, page.offset), (3, 0));
}
/// `LIMIT n OFFSET m` yields both the page size and the offset.
#[test]
fn limit_with_offset() {
    let q = SQLQuery::new("SELECT title FROM products ORDER BY price ASC LIMIT 3 OFFSET 3");
    let page = q.limit().expect("limit should be set");
    assert_eq!((page.num, page.offset), (3, 3));
}
/// Filter, projection, sort and limit are all extracted from one statement.
#[test]
fn where_with_order_and_limit() {
    let q = SQLQuery::new(
        "SELECT title, price FROM products WHERE category = 'electronics' ORDER BY price ASC LIMIT 5",
    );

    let filter = q.to_redis_query();
    assert_eq!(filter, "@category:{electronics}");

    let projected = q.return_fields();
    assert_eq!(projected, vec!["title", "price"]);

    let sort = q.sort_by().expect("sort_by");
    assert_eq!(sort.field, "price");

    let page = q.limit().expect("limit");
    assert_eq!(page.num, 5);
}
/// `to_redis_query` on an aggregate statement falls back to text that still contains the
/// SQL function name. This overlaps with `aggregate_query_returns_raw_sql_for_search`,
/// which additionally checks `is_aggregate`.
#[test]
fn aggregate_query_returns_raw_sql_fallback() {
    let translated = SQLQuery::new("SELECT COUNT(*) as total FROM products").to_redis_query();
    assert!(translated.contains("COUNT"));
}
/// Single-word equality on `title` produces a `{...}` filter, as for tag fields.
#[test]
fn text_equality_single_word() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE title = 'laptop'").to_redis_query();
    assert_eq!(translated, "@title:{laptop}");
}
/// Multi-word equality becomes a quoted phrase match.
#[test]
fn text_equality_phrase() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE title = 'gaming laptop'").to_redis_query();
    assert_eq!(translated, "@title:(\"gaming laptop\")");
}
/// Two bounds on the same field are emitted as two ranges ANDed together, not merged.
#[test]
fn numeric_range_with_and() {
    let sql = "SELECT * FROM products WHERE price >= 25 AND price <= 50";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "(@price:[25 +inf] @price:[-inf 50])");
}
/// `SELECT *` asks the caller to unpack whole JSON documents.
#[test]
fn should_unpack_json_for_select_star() {
    let unpack = SQLQuery::new("SELECT * FROM products").should_unpack_json();
    assert!(unpack);
}
/// An explicit column projection does not request JSON unpacking.
#[test]
fn should_not_unpack_json_for_field_projection() {
    let unpack = SQLQuery::new("SELECT title, price FROM products").should_unpack_json();
    assert!(!unpack);
}
/// A bound string parameter is substituted before translation into a tag filter.
#[test]
fn with_param_where_tag() {
    let value = SqlParam::Str(String::from("electronics"));
    let q = SQLQuery::new("SELECT * FROM products WHERE category = :cat").with_param("cat", value);
    assert_eq!(q.to_redis_query(), "@category:{electronics}");
}
/// A bound float parameter is substituted before translation into a numeric range.
#[test]
fn with_param_where_numeric() {
    let value = SqlParam::Float(99.99);
    let q = SQLQuery::new("SELECT * FROM products WHERE price > :min_price")
        .with_param("min_price", value);
    assert_eq!(q.to_redis_query(), "@price:[(99.99 +inf]");
}
/// `OR` joins its branches with ` | ` inside one parenthesised group.
#[test]
fn where_simple_or() {
    let sql = "SELECT * FROM products WHERE category = 'electronics' OR category = 'books'";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "(@category:{electronics} | @category:{books})");
}
/// A chain of three `OR`s is flattened into a single group.
#[test]
fn where_or_with_three_branches() {
    let sql = "SELECT * FROM products WHERE category = 'electronics' OR category = 'books' OR category = 'accessories'";
    let expected = "(@category:{electronics} | @category:{books} | @category:{accessories})";
    assert_eq!(SQLQuery::new(sql).to_redis_query(), expected);
}
/// Standard precedence applies: each `AND` pair is grouped before the `OR`.
#[test]
fn where_and_binds_tighter_than_or() {
    let sql = "SELECT * FROM products WHERE category = 'electronics' AND price > 100 OR category = 'books' AND price < 50";
    let expected =
        "((@category:{electronics} @price:[(100 +inf]) | (@category:{books} @price:[-inf (50]))";
    assert_eq!(SQLQuery::new(sql).to_redis_query(), expected);
}
/// `OR` over two numeric comparisons on the same field.
#[test]
fn where_or_with_single_conditions() {
    let sql = "SELECT * FROM products WHERE price < 20 OR price > 1000";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "(@price:[-inf (20] | @price:[(1000 +inf])");
}
/// An `OR` filter does not prevent `ORDER BY` and `LIMIT` from being extracted.
#[test]
fn where_or_preserves_order_limit() {
    let q = SQLQuery::new(
        "SELECT title FROM products WHERE category = 'a' OR category = 'b' ORDER BY price ASC LIMIT 5",
    );
    assert_eq!(q.to_redis_query(), "(@category:{a} | @category:{b})");
    assert!(q.sort_by().is_some());
    let page = q.limit().unwrap();
    assert_eq!(page.num, 5);
}
/// `NOT IN (...)` becomes a negated multi-value tag filter.
#[test]
fn where_not_in() {
    let sql = "SELECT * FROM products WHERE category NOT IN ('electronics', 'books')";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "(-@category:{electronics|books})");
}
/// A negated `NOT IN` term keeps its own parentheses when ANDed with another condition.
#[test]
fn where_not_in_combined_with_and() {
    let sql = "SELECT * FROM products WHERE category NOT IN ('electronics') AND price > 50";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "((-@category:{electronics}) @price:[(50 +inf])");
}
/// A trailing `%` in `LIKE` becomes a trailing `*` wildcard.
#[test]
fn where_like_prefix() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE title LIKE 'laptop%'").to_redis_query();
    assert_eq!(translated, "@title:(laptop*)");
}
/// A leading `%` in `LIKE` becomes a leading `*` wildcard.
#[test]
fn where_like_suffix() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE title LIKE '%laptop'").to_redis_query();
    assert_eq!(translated, "@title:(*laptop)");
}
/// `%` on both sides becomes `*` on both sides.
#[test]
fn where_like_contains() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE title LIKE '%laptop%'").to_redis_query();
    assert_eq!(translated, "@title:(*laptop*)");
}
/// `NOT LIKE` wraps the wildcard match in a negation.
#[test]
fn where_not_like() {
    let translated =
        SQLQuery::new("SELECT * FROM products WHERE title NOT LIKE 'laptop%'").to_redis_query();
    assert_eq!(translated, "(-@title:(laptop*))");
}
/// A `LIKE` term combines with a numeric range under `AND`.
#[test]
fn where_like_combined_with_and() {
    let sql = "SELECT * FROM products WHERE title LIKE 'lap%' AND price < 1000";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "(@title:(lap*) @price:[-inf (1000])");
}
/// A date literal is converted to Unix seconds; `2024-01-01` is 1704067200 (UTC midnight).
#[test]
fn where_date_greater_than() {
    let translated =
        SQLQuery::new("SELECT * FROM events WHERE created_at > '2024-01-01'").to_redis_query();
    assert_eq!(translated, "@created_at:[(1704067200 +inf]");
}
/// `<` with a date literal yields an exclusive upper bound in Unix seconds.
#[test]
fn where_date_less_than() {
    let translated =
        SQLQuery::new("SELECT * FROM events WHERE created_at < '2024-03-31'").to_redis_query();
    assert_eq!(translated, "@created_at:[-inf (1711843200]");
}
/// `BETWEEN` two dates yields an inclusive range of Unix seconds.
#[test]
fn where_date_between() {
    let sql = "SELECT * FROM events WHERE created_at BETWEEN '2024-01-01' AND '2024-03-31'";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "@created_at:[1704067200 1711843200]");
}
/// `>=` with a date literal yields an inclusive lower bound in Unix seconds.
#[test]
fn where_date_gte() {
    let translated =
        SQLQuery::new("SELECT * FROM events WHERE created_at >= '2024-06-15'").to_redis_query();
    assert_eq!(translated, "@created_at:[1718409600 +inf]");
}
/// A date comparison composes with a tag filter under `AND`.
#[test]
fn where_date_combined_with_tag() {
    let sql = "SELECT * FROM events WHERE category = 'meeting' AND created_at > '2024-01-01'";
    let translated = SQLQuery::new(sql).to_redis_query();
    let expected = "(@category:{meeting} @created_at:[(1704067200 +inf])";
    assert_eq!(translated, expected);
}
/// An ISO date-time without an offset is converted as UTC: `2024-01-15T10:30:00` → 1705314600.
#[test]
fn where_datetime_with_time() {
    let sql = "SELECT * FROM events WHERE created_at > '2024-01-15T10:30:00'";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "@created_at:[(1705314600 +inf]");
}
/// `try_parse_date` returns Unix-epoch seconds for known UTC-midnight dates.
#[test]
fn date_to_timestamp_known_values() {
    let known = [
        ("1970-01-01", 0.0),
        ("2000-01-01", 946_684_800.0),
        ("2024-01-01", 1_704_067_200.0),
    ];
    for (date, epoch_secs) in known {
        assert_eq!(try_parse_date(date), Some(epoch_secs));
    }
}
/// Non-date text and out-of-range month/day components are rejected with `None`.
#[test]
fn invalid_date_returns_none() {
    let rejected = [
        "not-a-date",
        "2024-13-01", // month 13
        "2024-00-01", // month 0
        "2024-01-32", // day 32
    ];
    for input in rejected {
        assert_eq!(try_parse_date(input), None);
    }
}
/// `LIKE` terms can be branches of an `OR`.
#[test]
fn where_or_with_like() {
    let sql = "SELECT * FROM products WHERE title LIKE 'laptop%' OR title LIKE 'phone%'";
    let translated = SQLQuery::new(sql).to_redis_query();
    assert_eq!(translated, "(@title:(laptop*) | @title:(phone*))");
}
/// Date comparisons can be branches of an `OR`.
#[test]
fn where_or_with_date() {
    let sql = "SELECT * FROM events WHERE created_at < '2024-01-01' OR created_at > '2024-12-31'";
    let translated = SQLQuery::new(sql).to_redis_query();
    let expected = "(@created_at:[-inf (1704067200] | @created_at:[(1735603200 +inf])";
    assert_eq!(translated, expected);
}
/// Builds the aggregate command for `sql` against `index_name` and returns its
/// arguments decoded from the packed command bytes.
///
/// Panics if `sql` is not recognised as an aggregate query.
fn agg_cmd_args(sql: &str, index_name: &str) -> Vec<String> {
    let query = SQLQuery::new(sql);
    assert!(query.is_aggregate(), "expected aggregate for: {sql}");
    let packed = query
        .build_aggregate_cmd(index_name)
        .unwrap()
        .get_packed_command();
    parse_resp_args(&packed)
}
/// Decodes the bulk-string arguments of a packed RESP command
/// (`*N\r\n$len\r\narg\r\n...`) into `String`s.
///
/// The scan works on the raw bytes, so each `$len` prefix is honoured exactly even when
/// an argument is not valid UTF-8 (for example a byte-blob parameter). Converting the
/// whole buffer with `from_utf8_lossy` first would let replacement characters change
/// the byte offsets and panic or misalign the slices. Each argument is decoded lossily
/// only after it has been cut out. A `$` inside an argument is skipped correctly because
/// the argument is consumed by length.
///
/// Panics on malformed input, which is acceptable for a test helper.
fn parse_resp_args(data: &[u8]) -> Vec<String> {
    let mut args = Vec::new();
    let mut rest = data;
    while let Some(dollar) = rest.iter().position(|&b| b == b'$') {
        rest = &rest[dollar + 1..];
        let crlf = rest
            .windows(2)
            .position(|w| w == b"\r\n")
            .expect("bulk-string length must be terminated by CRLF");
        let len: usize = std::str::from_utf8(&rest[..crlf])
            .expect("bulk-string length must be ASCII")
            .parse()
            .expect("bulk-string length must be a decimal integer");
        rest = &rest[crlf + 2..];
        args.push(String::from_utf8_lossy(&rest[..len]).into_owned());
        // Skip the payload plus its trailing CRLF.
        rest = &rest[len + 2..];
    }
    args
}
/// `COUNT(*)` with no `GROUP BY` yields a `GROUPBY 0` with a zero-argument `COUNT` reducer.
#[test]
fn aggregate_count_star() {
    let args = agg_cmd_args("SELECT COUNT(*) AS total FROM products", "idx");
    // Index 2 is the filter; `*` because there is no WHERE clause.
    let expected = [
        "FT.AGGREGATE", "idx", "*", "GROUPBY", "0", "REDUCE", "COUNT", "0", "AS", "total",
    ];
    for (i, want) in expected.iter().enumerate() {
        assert_eq!(args[i], *want);
    }
}
/// Without an explicit `AS`, the `COUNT(*)` reducer is aliased `count`.
#[test]
fn aggregate_count_star_default_alias() {
    let args = agg_cmd_args("SELECT COUNT(*) FROM products", "idx");
    // Same layout as `aggregate_count_star`: the alias follows `AS` at index 9.
    let alias = &args[9];
    assert_eq!(alias, "count");
}
/// `SUM(field)` becomes a one-argument `SUM` reducer on `@field`.
#[test]
fn aggregate_sum() {
    let args = agg_cmd_args("SELECT SUM(price) AS total_price FROM products", "idx");
    // Index 7 is the reducer's argument count.
    assert_eq!(
        &args[5..11],
        ["REDUCE", "SUM", "1", "@price", "AS", "total_price"]
    );
}
/// `AVG(field)` becomes an `AVG` reducer on `@field` with the requested alias.
#[test]
fn aggregate_avg() {
    let args = agg_cmd_args("SELECT AVG(score) AS avg_score FROM products", "idx");
    for (i, want) in [(6, "AVG"), (8, "@score"), (10, "avg_score")] {
        assert_eq!(args[i], want);
    }
}
/// `MIN` and `MAX` map to reducers of the same name.
#[test]
fn aggregate_min_max() {
    let cases = [
        ("SELECT MIN(price) AS min_price FROM products", "MIN", "min_price"),
        ("SELECT MAX(price) AS max_price FROM products", "MAX", "max_price"),
    ];
    for (sql, reducer, alias) in cases {
        let args = agg_cmd_args(sql, "idx");
        assert_eq!(args[6], reducer);
        assert_eq!(args[8], "@price");
        assert_eq!(args[10], alias);
    }
}
/// `STDDEV(field)` maps to the `STDDEV` reducer.
#[test]
fn aggregate_stddev() {
    let args = agg_cmd_args("SELECT STDDEV(price) AS price_sd FROM products", "idx");
    for (i, want) in [(6, "STDDEV"), (8, "@price"), (10, "price_sd")] {
        assert_eq!(args[i], want);
    }
}
/// `COUNT_DISTINCT(field)` maps to the `COUNT_DISTINCT` reducer.
#[test]
fn aggregate_count_distinct() {
    let sql = "SELECT COUNT_DISTINCT(brand) AS unique_brands FROM products";
    let args = agg_cmd_args(sql, "idx");
    for (i, want) in [(6, "COUNT_DISTINCT"), (8, "@brand"), (10, "unique_brands")] {
        assert_eq!(args[i], want);
    }
}
/// `QUANTILE(field, q)` becomes a two-argument reducer: the field, then the quantile.
#[test]
fn aggregate_quantile() {
    let args = agg_cmd_args("SELECT QUANTILE(price, 0.95) AS p95 FROM products", "idx");
    assert_eq!(
        &args[6..12],
        ["QUANTILE", "2", "@price", "0.95", "AS", "p95"]
    );
}
/// SQL `ARRAY_AGG` is renamed to the `TOLIST` reducer.
#[test]
fn aggregate_array_agg_to_tolist() {
    let args = agg_cmd_args("SELECT ARRAY_AGG(name) AS names FROM products", "idx");
    for (i, want) in [(6, "TOLIST"), (8, "@name"), (10, "names")] {
        assert_eq!(args[i], want);
    }
}
/// `FIRST_VALUE(field)` is passed through as the `FIRST_VALUE` reducer.
#[test]
fn aggregate_first_value() {
    let sql = "SELECT FIRST_VALUE(name) AS first_name FROM products";
    let args = agg_cmd_args(sql, "idx");
    for (i, want) in [(6, "FIRST_VALUE"), (8, "@name"), (10, "first_name")] {
        assert_eq!(args[i], want);
    }
}
/// `GROUP BY field` emits `GROUPBY 1 @field` before the reducers.
#[test]
fn aggregate_group_by_single_field() {
    let args = agg_cmd_args(
        "SELECT category, COUNT(*) AS cnt FROM products GROUP BY category",
        "idx",
    );
    assert_eq!(
        &args[..11],
        [
            "FT.AGGREGATE",
            "idx",
            "*",
            "GROUPBY",
            "1",
            "@category",
            "REDUCE",
            "COUNT",
            "0",
            "AS",
            "cnt",
        ]
    );
}
/// A `WHERE` clause is translated into the aggregate's query argument (index 2).
#[test]
fn aggregate_group_by_with_where() {
    let args = agg_cmd_args(
        "SELECT category, AVG(price) AS avg_price FROM products WHERE price > 10 GROUP BY category",
        "idx",
    );
    assert_eq!(
        &args[2..8],
        ["@price:[(10 +inf]", "GROUPBY", "1", "@category", "REDUCE", "AVG"]
    );
}
/// Each aggregate in the projection becomes its own `REDUCE` clause, in SELECT order.
#[test]
fn aggregate_multiple_reducers() {
    let args = agg_cmd_args(
        "SELECT category, COUNT(*) AS cnt, AVG(price) AS avg_price FROM products GROUP BY category",
        "idx",
    );
    let group = ["GROUPBY", "1", "@category"];
    let count = ["REDUCE", "COUNT", "0", "AS", "cnt"];
    let avg = ["REDUCE", "AVG", "1", "@price", "AS", "avg_price"];
    assert_eq!(&args[3..6], group);
    assert_eq!(&args[6..11], count);
    assert_eq!(&args[11..17], avg);
}
/// Multiple `GROUP BY` columns are listed in order after their count.
#[test]
fn aggregate_group_by_multiple_fields() {
    let args = agg_cmd_args(
        "SELECT category, brand, SUM(price) AS total FROM products GROUP BY category, brand",
        "idx",
    );
    assert_eq!(
        &args[3..9],
        ["GROUPBY", "2", "@category", "@brand", "REDUCE", "SUM"]
    );
}
/// A plain filtered SELECT is not treated as an aggregate, and no aggregate command is built.
#[test]
fn non_aggregate_is_not_detected_as_aggregate() {
    let plain = SQLQuery::new("SELECT * FROM products WHERE price > 10");
    let detected = plain.is_aggregate();
    assert!(!detected);
    let cmd = plain.build_aggregate_cmd("idx");
    assert!(cmd.is_none());
}
/// An aggregate statement is detected as such, and its search-query form still carries
/// the SQL function name.
#[test]
fn aggregate_query_returns_raw_sql_for_search() {
    let agg = SQLQuery::new("SELECT COUNT(*) AS total FROM products");
    assert!(agg.is_aggregate());
    let search_text = agg.to_redis_query();
    assert!(search_text.contains("COUNT"));
}
/// `vector_distance(field, :param)` with `LIMIT k` becomes a KNN query whose vector is
/// passed as a binary parameter named `vector`.
#[test]
fn vector_distance_basic() {
    // 12 zero bytes; the test only checks that the bytes round-trip unchanged.
    let blob = vec![0u8; 12];
    let q = SQLQuery::new(
        "SELECT title, vector_distance(embedding, :vec) AS score FROM idx LIMIT 3",
    )
    .with_param("vec", SqlParam::Bytes(blob.clone()));

    assert!(q.is_vector_query());
    assert_eq!(q.to_redis_query(), "*=>[KNN 3 @embedding $vector AS score]");

    let params = q.params();
    assert_eq!(params.len(), 1);
    assert_eq!(params[0].name, "vector");
    match &params[0].value {
        QueryParamValue::Binary(b) => assert_eq!(b, &blob),
        _ => panic!("Expected Binary param"),
    }
}
/// `cosine_distance` is accepted as an alias of `vector_distance` for KNN queries.
#[test]
fn cosine_distance_basic() {
    let q = SQLQuery::new(
        "SELECT title, cosine_distance(embedding, :vec) AS dist FROM idx LIMIT 5",
    )
    .with_param("vec", SqlParam::Bytes(vec![0u8; 12]));
    assert!(q.is_vector_query());
    let knn = q.to_redis_query();
    assert_eq!(knn, "*=>[KNN 5 @embedding $vector AS dist]");
}
/// A `WHERE` clause becomes the KNN pre-filter; the tag value's `-` is backslash-escaped.
#[test]
fn vector_distance_with_where_filter() {
    let q = SQLQuery::new(
        "SELECT title, vector_distance(embedding, :vec) AS score FROM idx WHERE genre = 'sci-fi' LIMIT 3",
    )
    .with_param("vec", SqlParam::Bytes(vec![0u8; 12]));
    let expected = "@genre:{sci\\-fi}=>[KNN 3 @embedding $vector AS score]";
    assert_eq!(q.to_redis_query(), expected);
}
/// Without `AS`, the KNN score field is named `vector_distance`.
#[test]
fn vector_distance_default_alias() {
    let q = SQLQuery::new("SELECT vector_distance(embedding, :vec) FROM idx LIMIT 10")
        .with_param("vec", SqlParam::Bytes(vec![0u8; 12]));
    let expected = "*=>[KNN 10 @embedding $vector AS vector_distance]";
    assert_eq!(q.to_redis_query(), expected);
}
/// Plain columns of a vector query are returned; the distance expression is not a field.
#[test]
fn vector_query_return_fields() {
    let q = SQLQuery::new(
        "SELECT title, author, vector_distance(embedding, :vec) AS score FROM idx LIMIT 5",
    )
    .with_param("vec", SqlParam::Bytes(vec![0u8; 12]));
    assert_eq!(q.return_fields(), vec!["title", "author"]);
}
/// The `LIMIT` of a vector query is also exposed as the result page (offset 0).
#[test]
fn vector_query_limit_as_knn() {
    let q = SQLQuery::new("SELECT vector_distance(embedding, :vec) AS score FROM idx LIMIT 7")
        .with_param("vec", SqlParam::Bytes(vec![0u8; 12]));
    let page = q.limit().expect("should have limit");
    assert_eq!((page.num, page.offset), (7, 0));
}
/// A query without a distance function is not a vector query.
#[test]
fn non_vector_query_not_detected_as_vector() {
    let is_vector = SQLQuery::new("SELECT * FROM products WHERE price > 10").is_vector_query();
    assert!(!is_vector);
}
/// `geo_distance(field, POINT(lon, lat), unit) < r` in WHERE becomes a geofilter and
/// leaves the text query as `*`.
#[test]
fn geo_distance_where_basic() {
    let q = SQLQuery::new(
        "SELECT * FROM locations WHERE geo_distance(location, POINT(-122.4194, 37.7749), 'km') < 50",
    );
    let geo = q.geofilter().expect("should have geofilter");
    assert_eq!(geo.field, "location");
    // POINT takes longitude first, then latitude.
    assert!((geo.lon - (-122.4194)).abs() < 0.0001);
    assert!((geo.lat - 37.7749).abs() < 0.0001);
    assert!((geo.radius - 50.0).abs() < 0.001);
    assert_eq!(geo.unit, "km");
    assert_eq!(q.to_redis_query(), "*");
}
/// The geo predicate is split out into the geofilter; the remaining conditions form the query.
#[test]
fn geo_distance_where_with_other_conditions() {
    let q = SQLQuery::new(
        "SELECT name FROM locations WHERE category = 'restaurant' AND geo_distance(location, POINT(-122.4194, 37.7749), 'mi') < 10",
    );
    let geo = q.geofilter().expect("should have geofilter");
    assert_eq!(geo.field, "location");
    assert!((geo.radius - 10.0).abs() < 0.001);
    assert_eq!(geo.unit, "mi");
    assert_eq!(q.to_redis_query(), "@category:{restaurant}");
}
/// A query without `geo_distance` yields no geofilter.
#[test]
fn non_geo_query_no_geofilter() {
    let filter = SQLQuery::new("SELECT * FROM products WHERE price > 10").geofilter();
    assert!(filter.is_none());
}
/// `geo_distance` in the projection becomes an aggregate that `LOAD`s the field and
/// `APPLY`s a `geodistance` expression under the requested alias.
#[test]
fn geo_distance_select_aggregate() {
    let q = SQLQuery::new(
        "SELECT name, geo_distance(location, POINT(-122.4194, 37.7749)) AS distance FROM locations",
    );
    assert!(q.is_geo_aggregate());
    let packed = q
        .build_geo_aggregate_cmd("idx")
        .expect("should build cmd")
        .get_packed_command();
    let args = parse_resp_args(&packed);

    assert_eq!(
        &args[..7],
        ["FT.AGGREGATE", "idx", "*", "LOAD", "1", "@location", "APPLY"]
    );
    let apply_expr = &args[7];
    assert!(apply_expr.contains("geodistance"));
    assert!(apply_expr.contains("@location"));
    assert_eq!(&args[8..10], ["AS", "distance"]);
}
/// A `WHERE` clause on a geo-distance projection becomes the aggregate's query argument.
#[test]
fn geo_distance_select_with_where() {
    let q = SQLQuery::new(
        "SELECT name, geo_distance(location, POINT(-73.9857, 40.7484)) AS dist FROM places WHERE category = 'cafe'",
    );
    assert!(q.is_geo_aggregate());
    let cmd = q.build_geo_aggregate_cmd("idx").expect("should build cmd");
    let args = parse_resp_args(&cmd.get_packed_command());
    assert_eq!(args[0], "FT.AGGREGATE");
    assert_eq!(args[2], "@category:{cafe}");
}
/// A plain filtered SELECT is not a geo aggregate, and no geo aggregate command is built.
#[test]
fn non_geo_not_detected_as_geo_aggregate() {
    let plain = SQLQuery::new("SELECT * FROM products WHERE price > 10");
    let detected = plain.is_geo_aggregate();
    assert!(!detected);
    let cmd = plain.build_geo_aggregate_cmd("idx");
    assert!(cmd.is_none());
}
/// The tokenizer keeps a `:name` placeholder together as a single token.
#[test]
fn tokenizer_handles_colon_param() {
    let tokens = tokenize("SELECT vector_distance(embedding, :vec) AS score FROM idx");
    let found = tokens.iter().any(|t| t == ":vec");
    assert!(found);
}
}