Skip to main content

redis_vl/query/
sql.rs

1//! SQL-like query type for Redis Search parity with Python `redisvl.query.SQLQuery`.
2//!
3//! This module mirrors the upstream Python `SQLQuery` class that accepts SQL-like
4//! syntax and optional parameters. It implements a lightweight SQL-to-Redis-Search
5//! translation layer that converts `SELECT` statements into `FT.SEARCH` filter
6//! syntax, and aggregate queries (`COUNT`, `GROUP BY`, etc.) into `FT.AGGREGATE`
7//! commands.
8//!
9//! ## What is implemented
10//!
11//! - `SQLQuery` type holding the raw SQL string and optional parameters
12//! - `SqlParam` enum for typed parameter values (string, numeric, binary)
13//! - Token-based parameter substitution that prevents partial-match bugs
14//!   (`:id` won't clobber `:product_id`) and escapes single quotes in strings
15//! - `QueryString` trait implementation so `SQLQuery` can be passed to
16//!   `SearchIndex::query()` / `AsyncSearchIndex::query()`
17//! - SQL→Redis translation for non-aggregate `SELECT` queries:
18//!   - `WHERE` clauses with tag `=`, `!=`, `IN`, `NOT IN`; numeric `=`, `!=`,
19//!     `<`, `>`, `<=`, `>=`, `BETWEEN`; text `=`, `!=`; `LIKE` / `NOT LIKE`
20//!   - `AND` and `OR` combinators with correct precedence (AND binds tighter)
21//!   - ISO date literal parsing (`'2024-01-01'`, `'2024-01-15T10:30:00'`)
22//!     in comparison and `BETWEEN` clauses
23//!   - `ORDER BY field ASC|DESC`
24//!   - `LIMIT n [OFFSET m]`
25//!   - `SELECT field1, field2` → `RETURN` field projection
26//! - Aggregate SQL → `FT.AGGREGATE` translation:
27//!   - `COUNT(*)`, `SUM(field)`, `AVG(field)`, `MIN(field)`, `MAX(field)`
28//!   - `STDDEV(field)`, `COUNT_DISTINCT(field)`, `QUANTILE(field, q)`
29//!   - `ARRAY_AGG(field)` → `TOLIST`, `FIRST_VALUE(field)`
30//!   - `GROUP BY field` with multiple reducers
31//!   - `WHERE` filters combined with `GROUP BY`
32//!   - Global aggregation (no `GROUP BY`)
33//!
34//! ## What is out of scope (not yet implemented)
35//!
//! - Vector search functions (`cosine_distance()`, `vector_distance()`) and GEO
//!   functions (`geo_distance()`) beyond the forms recognised by this module's
//!   dedicated vector / geo parsers
38//! - Date functions (`YEAR()` in SELECT, `GROUP BY YEAR()`)
39//! - `IS NULL` / `IS NOT NULL`
40//! - `HAVING` clause
41//! - Phrase-level stopword handling
42
43use std::collections::HashMap;
44
45use super::{QueryLimit, QueryParam, QueryParamValue, QueryString, SortBy, SortDirection};
46
/// A typed value bound to a named `:placeholder` in a SQL query.
///
/// `PartialEq` is derived so parameter values can be compared directly (for
/// example in assertions). Note that, as with any `f64` comparison,
/// `Float(NaN)` is never equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlParam {
    /// An integer parameter.
    Int(i64),
    /// A floating-point parameter.
    Float(f64),
    /// A UTF-8 string parameter.
    Str(String),
    /// A binary blob (typically a serialized vector). Binary params are **not**
    /// substituted into the SQL text; they are kept as placeholders for the
    /// downstream executor.
    Bytes(Vec<u8>),
}
61
62/// SQL-like query for Redis Search.
63///
64/// Holds a SQL `SELECT` statement and optional named parameters. Parameter
65/// placeholders use the `:name` syntax (e.g. `:id`, `:product_id`).
66///
67/// When used with `SearchIndex::query()`, the SQL is parsed and translated
68/// into a Redis Search `FT.SEARCH` filter string. Non-aggregate queries with
69/// `WHERE`, `ORDER BY`, `LIMIT`, and `OFFSET` clauses are supported.
70///
71/// # Example
72///
73/// ```
74/// use redis_vl::{SQLQuery, SqlParam};
75///
76/// let query = SQLQuery::new("SELECT * FROM idx WHERE price > :min_price")
77///     .with_param("min_price", SqlParam::Float(99.99));
78///
79/// assert!(!query.substituted_sql().contains(":min_price"));
80/// assert!(query.substituted_sql().contains("99.99"));
81/// ```
82#[derive(Debug, Clone)]
83pub struct SQLQuery {
84    sql: String,
85    params: HashMap<String, SqlParam>,
86}
87
88impl SQLQuery {
89    /// Creates an SQL query wrapper with no parameters.
90    pub fn new(sql: impl Into<String>) -> Self {
91        Self {
92            sql: sql.into(),
93            params: HashMap::new(),
94        }
95    }
96
97    /// Creates an SQL query with pre-populated parameters.
98    pub fn with_params(sql: impl Into<String>, params: HashMap<String, SqlParam>) -> Self {
99        Self {
100            sql: sql.into(),
101            params,
102        }
103    }
104
105    /// Adds a single named parameter.
106    pub fn with_param(mut self, name: impl Into<String>, value: SqlParam) -> Self {
107        self.params.insert(name.into(), value);
108        self
109    }
110
111    /// Returns the raw SQL string.
112    pub fn sql(&self) -> &str {
113        &self.sql
114    }
115
116    /// Returns the parameter map.
117    pub fn params_map(&self) -> &HashMap<String, SqlParam> {
118        &self.params
119    }
120
121    /// Returns the SQL string with non-binary parameters substituted.
122    ///
123    /// Uses a token-based approach: splits the SQL on `:param` boundaries to
124    /// prevent partial matching (`:id` inside `:product_id` stays intact).
125    /// Single quotes in string values are SQL-escaped (`'` → `''`).
126    pub fn substituted_sql(&self) -> String {
127        substitute_params(&self.sql, &self.params)
128    }
129
130    /// Parses the SQL statement into a [`ParsedSelect`].
131    ///
132    /// Returns `None` if the SQL cannot be parsed (e.g. aggregate or
133    /// unsupported syntax). In that case, the raw substituted SQL is used
134    /// as the Redis query string (fallback behaviour).
135    fn parsed(&self) -> Option<ParsedSelect> {
136        parse_select(&self.substituted_sql())
137    }
138
139    /// Returns `true` if this SQL query is an aggregate query (contains
140    /// aggregate functions like `COUNT`, `SUM`, `AVG`, etc., or `GROUP BY`).
141    ///
142    /// Aggregate queries are translated to `FT.AGGREGATE` rather than
143    /// `FT.SEARCH`.
144    pub fn is_aggregate(&self) -> bool {
145        parse_aggregate(&self.substituted_sql()).is_some()
146    }
147
148    /// Builds an `FT.AGGREGATE` command for aggregate SQL queries.
149    ///
150    /// Returns `None` if the SQL is not an aggregate query.
151    ///
152    /// The `index_name` parameter is the Redis Search index to query.
153    pub fn build_aggregate_cmd(&self, index_name: &str) -> Option<redis::Cmd> {
154        let parsed = parse_aggregate(&self.substituted_sql())?;
155        Some(parsed.build_cmd(index_name))
156    }
157
158    /// Returns `true` if this SQL query contains vector search functions
159    /// (`vector_distance()` or `cosine_distance()`).
160    pub fn is_vector_query(&self) -> bool {
161        parse_vector_select(&self.substituted_sql(), &self.params).is_some()
162    }
163
164    /// Returns `true` if this SQL query contains `geo_distance()` in the
165    /// SELECT clause (which generates `FT.AGGREGATE` with `APPLY geodistance`).
166    pub fn is_geo_aggregate(&self) -> bool {
167        parse_geo_aggregate(&self.substituted_sql()).is_some()
168    }
169
170    /// Builds an `FT.AGGREGATE` command for geo_distance in SELECT.
171    ///
172    /// Returns `None` if the SQL doesn't have `geo_distance()` in SELECT.
173    pub fn build_geo_aggregate_cmd(&self, index_name: &str) -> Option<redis::Cmd> {
174        let parsed = parse_geo_aggregate(&self.substituted_sql())?;
175        Some(parsed.build_cmd(index_name))
176    }
177
178    /// Parses a vector SQL query for internal use by `QueryString`.
179    fn parsed_vector(&self) -> Option<ParsedVectorSelect> {
180        parse_vector_select(&self.substituted_sql(), &self.params)
181    }
182
183    /// Parses a geo WHERE filter for internal use by `QueryString`.
184    fn parsed_geo_where(&self) -> Option<ParsedGeoWhere> {
185        parse_geo_where(&self.substituted_sql())
186    }
187}
188
189impl QueryString for SQLQuery {
190    fn to_redis_query(&self) -> String {
191        // Vector queries: generate KNN query string.
192        if let Some(ref vq) = self.parsed_vector() {
193            return vq.to_knn_query_string();
194        }
195        // Geo WHERE queries: generate filter + GEOFILTER handled separately.
196        if let Some(ref gw) = self.parsed_geo_where() {
197            return gw.filter_string();
198        }
199        if let Some(parsed) = self.parsed() {
200            parsed.filter_string()
201        } else {
202            // Fallback: return raw substituted SQL (backwards-compatible).
203            self.substituted_sql()
204        }
205    }
206
207    fn params(&self) -> Vec<QueryParam> {
208        // Vector queries need binary vector params.
209        if let Some(ref vq) = self.parsed_vector() {
210            return vq.params();
211        }
212        Vec::new()
213    }
214
215    fn return_fields(&self) -> Vec<String> {
216        if let Some(ref vq) = self.parsed_vector() {
217            return vq.return_fields.clone();
218        }
219        if let Some(ref gw) = self.parsed_geo_where() {
220            return gw.return_fields.clone();
221        }
222        self.parsed().map(|p| p.return_fields).unwrap_or_default()
223    }
224
225    fn sort_by(&self) -> Option<SortBy> {
226        self.parsed().and_then(|p| p.sort_by)
227    }
228
229    fn limit(&self) -> Option<QueryLimit> {
230        if let Some(ref vq) = self.parsed_vector() {
231            return Some(QueryLimit {
232                offset: 0,
233                num: vq.knn_num,
234            });
235        }
236        self.parsed().and_then(|p| p.limit)
237    }
238
239    fn should_unpack_json(&self) -> bool {
240        // Unpack JSON when no explicit field projection (SELECT *).
241        self.parsed()
242            .map(|p| p.return_fields.is_empty())
243            .unwrap_or(false)
244    }
245
246    fn geofilter(&self) -> Option<super::GeoFilter> {
247        self.parsed_geo_where().map(|gw| gw.geofilter)
248    }
249}
250
251/// Substitutes `:name` parameter placeholders in `sql` using `params`.
252///
253/// Uses token-based splitting on `:identifier` boundaries so that `:id`
254/// placeholders are never partially matched inside `:product_id`.
255///
256/// - `Int` and `Float` values are stringified directly.
257/// - `Str` values are wrapped in single quotes with `'` escaped to `''`.
258/// - `Bytes` values are left as their original placeholder (for downstream
259///   executor handling).
260fn substitute_params(sql: &str, params: &HashMap<String, SqlParam>) -> String {
261    if params.is_empty() {
262        return sql.to_owned();
263    }
264
265    // Split on `:identifier` tokens, keeping delimiters.
266    let mut result = String::with_capacity(sql.len());
267    let bytes = sql.as_bytes();
268    let len = bytes.len();
269    let mut i = 0;
270
271    while i < len {
272        if bytes[i] == b':' && i + 1 < len && is_ident_start(bytes[i + 1]) {
273            // Found a potential parameter placeholder — consume the identifier.
274            let start = i + 1;
275            let mut end = start;
276            while end < len && is_ident_continue(bytes[end]) {
277                end += 1;
278            }
279            let key = &sql[start..end];
280            if let Some(param) = params.get(key) {
281                match param {
282                    SqlParam::Int(v) => {
283                        result.push_str(&v.to_string());
284                    }
285                    SqlParam::Float(v) => {
286                        result.push_str(&v.to_string());
287                    }
288                    SqlParam::Str(v) => {
289                        result.push('\'');
290                        result.push_str(&v.replace('\'', "''"));
291                        result.push('\'');
292                    }
293                    SqlParam::Bytes(_) => {
294                        // Keep the original placeholder for binary params.
295                        result.push(':');
296                        result.push_str(key);
297                    }
298                }
299            } else {
300                // Unknown placeholder — keep as-is.
301                result.push(':');
302                result.push_str(key);
303            }
304            i = end;
305        } else {
306            result.push(sql[i..].chars().next().unwrap());
307            i += sql[i..].chars().next().unwrap().len_utf8();
308        }
309    }
310
311    result
312}
313
/// Returns `true` if `b` may begin a placeholder identifier: an ASCII letter
/// or an underscore.
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
317
/// Returns `true` if `b` may appear after the first character of a
/// placeholder identifier: an ASCII letter, an ASCII digit, or an underscore.
fn is_ident_continue(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
321
322// ---------------------------------------------------------------------------
323// Lightweight SQL SELECT parser → Redis Search translation
324// ---------------------------------------------------------------------------
325
326/// A parsed SQL `SELECT` statement.
327#[derive(Debug, Clone)]
328struct ParsedSelect {
329    /// Field names to project (empty = `SELECT *`).
330    return_fields: Vec<String>,
331    /// Redis Search filter string derived from the `WHERE` clause.
332    where_filter: Option<String>,
333    /// Sort specification from `ORDER BY`.
334    sort_by: Option<SortBy>,
335    /// Limit specification from `LIMIT … [OFFSET …]`.
336    limit: Option<QueryLimit>,
337}
338
339impl ParsedSelect {
340    /// Returns the Redis Search query string used by `FT.SEARCH`.
341    fn filter_string(&self) -> String {
342        self.where_filter.clone().unwrap_or_else(|| "*".to_owned())
343    }
344}
345
346/// Tokenise and parse a SQL `SELECT` statement.
347///
348/// Returns `None` for unsupported syntax (aggregates, sub-queries, etc.).
349fn parse_select(sql: &str) -> Option<ParsedSelect> {
350    let tokens = tokenize(sql);
351    if tokens.is_empty() {
352        return None;
353    }
354    let mut pos = 0;
355
356    // SELECT
357    if !tok_eq(&tokens, pos, "SELECT") {
358        return None;
359    }
360    pos += 1;
361
362    // Bail on aggregate functions in SELECT list.
363    for tok in &tokens {
364        let upper = tok.to_ascii_uppercase();
365        if matches!(
366            upper.as_str(),
367            "COUNT"
368                | "AVG"
369                | "SUM"
370                | "MIN"
371                | "MAX"
372                | "STDDEV"
373                | "QUANTILE"
374                | "COUNT_DISTINCT"
375                | "ARRAY_AGG"
376                | "FIRST_VALUE"
377        ) {
378            return None;
379        }
380    }
381
382    // Bail on vector/geo functions.
383    for tok in &tokens {
384        let lower = tok.to_ascii_lowercase();
385        if lower == "cosine_distance" || lower == "vector_distance" || lower == "geo_distance" {
386            return None;
387        }
388    }
389
390    // Parse field list.
391    let mut return_fields = Vec::new();
392    if tok_eq(&tokens, pos, "*") {
393        pos += 1;
394    } else {
395        loop {
396            if pos >= tokens.len() {
397                return None;
398            }
399            let field = &tokens[pos];
400            if field.eq_ignore_ascii_case("FROM") {
401                break;
402            }
403            // Skip aliases: field AS alias
404            if !field.eq_ignore_ascii_case(",") && !field.eq_ignore_ascii_case("AS") {
405                // Check if the previous token was AS (skip alias names)
406                if pos > 0 && tokens[pos - 1].eq_ignore_ascii_case("AS") {
407                    // This is an alias name, skip it
408                } else {
409                    return_fields.push(field.to_string());
410                }
411            }
412            pos += 1;
413        }
414    }
415
416    // FROM
417    if !tok_eq(&tokens, pos, "FROM") {
418        return None;
419    }
420    pos += 1;
421    // Skip table name.
422    if pos >= tokens.len() {
423        return None;
424    }
425    pos += 1;
426
427    // WHERE, ORDER BY, LIMIT, OFFSET — all optional.
428    let mut where_filter: Option<String> = None;
429    let mut sort_by: Option<SortBy> = None;
430    let mut limit: Option<QueryLimit> = None;
431
432    while pos < tokens.len() {
433        if tok_eq(&tokens, pos, "WHERE") {
434            pos += 1;
435            let (filter_str, next) = parse_where_clause(&tokens, pos)?;
436            where_filter = Some(filter_str);
437            pos = next;
438        } else if tok_eq(&tokens, pos, "ORDER") {
439            if !tok_eq(&tokens, pos + 1, "BY") {
440                return None;
441            }
442            pos += 2;
443            if pos >= tokens.len() {
444                return None;
445            }
446            let field = tokens[pos].clone();
447            pos += 1;
448            let direction = if tok_eq(&tokens, pos, "DESC") {
449                pos += 1;
450                SortDirection::Desc
451            } else {
452                if tok_eq(&tokens, pos, "ASC") {
453                    pos += 1;
454                }
455                SortDirection::Asc
456            };
457            sort_by = Some(SortBy { field, direction });
458        } else if tok_eq(&tokens, pos, "LIMIT") {
459            pos += 1;
460            let num = parse_usize(&tokens, pos)?;
461            pos += 1;
462            let offset = if tok_eq(&tokens, pos, "OFFSET") {
463                pos += 1;
464                let off = parse_usize(&tokens, pos)?;
465                pos += 1;
466                off
467            } else {
468                0
469            };
470            limit = Some(QueryLimit { offset, num });
471        } else {
472            // Unknown clause — skip.
473            pos += 1;
474        }
475    }
476
477    Some(ParsedSelect {
478        return_fields,
479        where_filter,
480        sort_by,
481        limit,
482    })
483}
484
485// ---------------------------------------------------------------------------
486// Aggregate SQL parser → FT.AGGREGATE command builder
487// ---------------------------------------------------------------------------
488
/// One aggregate reducer taken from the SELECT list, such as `COUNT(*)` or
/// `SUM(price)`.
#[derive(Debug, Clone)]
struct AggReducer {
    /// Name of the Redis reducer function (`COUNT`, `SUM`, `AVG`, …).
    function: String,
    /// Field the reducer operates on; `None` for `COUNT(*)`.
    field: Option<String>,
    /// Output name for the reduced value (the `AS alias` part).
    alias: String,
    /// Additional numeric argument, e.g. the quantile in
    /// `QUANTILE(field, 0.5)`.
    extra_arg: Option<f64>,
}
501
502/// A parsed aggregate SQL statement.
503#[derive(Debug, Clone)]
504struct ParsedAggregate {
505    /// Redis Search filter from the WHERE clause.
506    where_filter: Option<String>,
507    /// GROUP BY field names (empty for global aggregation).
508    group_by_fields: Vec<String>,
509    /// Aggregate reducers from the SELECT list.
510    reducers: Vec<AggReducer>,
511}
512
513impl ParsedAggregate {
514    /// Builds an `FT.AGGREGATE` command for this parsed aggregate query.
515    fn build_cmd(&self, index_name: &str) -> redis::Cmd {
516        let mut cmd = redis::cmd("FT.AGGREGATE");
517        cmd.arg(index_name);
518
519        // Query filter (WHERE clause or wildcard).
520        let filter = self.where_filter.as_deref().unwrap_or("*");
521        cmd.arg(filter);
522
523        if self.group_by_fields.is_empty() {
524            // Global aggregation (no GROUP BY).
525            // Use GROUPBY 0 with reducers.
526            cmd.arg("GROUPBY").arg(0_u32);
527            for reducer in &self.reducers {
528                self.append_reducer(&mut cmd, reducer);
529            }
530        } else {
531            // GROUP BY with fields.
532            cmd.arg("GROUPBY").arg(self.group_by_fields.len());
533            for field in &self.group_by_fields {
534                cmd.arg(format!("@{}", field));
535            }
536            for reducer in &self.reducers {
537                self.append_reducer(&mut cmd, reducer);
538            }
539        }
540
541        cmd
542    }
543
544    /// Appends a single REDUCE clause to the command.
545    fn append_reducer(&self, cmd: &mut redis::Cmd, reducer: &AggReducer) {
546        cmd.arg("REDUCE");
547        cmd.arg(&reducer.function);
548
549        match reducer.function.as_str() {
550            "COUNT" => {
551                cmd.arg(0_u32); // COUNT takes 0 arguments
552            }
553            "QUANTILE" => {
554                // QUANTILE takes 2 arguments: field and quantile value
555                cmd.arg(2_u32);
556                if let Some(ref field) = reducer.field {
557                    cmd.arg(format!("@{}", field));
558                }
559                if let Some(q) = reducer.extra_arg {
560                    cmd.arg(format_num(q));
561                }
562            }
563            _ => {
564                // Most reducers take 1 argument: the field
565                cmd.arg(1_u32);
566                if let Some(ref field) = reducer.field {
567                    cmd.arg(format!("@{}", field));
568                }
569            }
570        }
571
572        cmd.arg("AS").arg(&reducer.alias);
573    }
574}
575
576/// Try to parse an aggregate SQL statement.
577///
578/// Returns `Some(ParsedAggregate)` if the SQL contains aggregate functions
579/// (COUNT, SUM, AVG, etc.) or GROUP BY; `None` otherwise.
580fn parse_aggregate(sql: &str) -> Option<ParsedAggregate> {
581    let tokens = tokenize(sql);
582    if tokens.is_empty() {
583        return None;
584    }
585    let mut pos = 0;
586
587    // SELECT
588    if !tok_eq(&tokens, pos, "SELECT") {
589        return None;
590    }
591    pos += 1;
592
593    // Check if this query has aggregate functions in the SELECT list.
594    let has_aggregate_fn = tokens.iter().any(|t| {
595        let upper = t.to_ascii_uppercase();
596        matches!(
597            upper.as_str(),
598            "COUNT"
599                | "AVG"
600                | "SUM"
601                | "MIN"
602                | "MAX"
603                | "STDDEV"
604                | "QUANTILE"
605                | "COUNT_DISTINCT"
606                | "ARRAY_AGG"
607                | "FIRST_VALUE"
608        )
609    });
610
611    let has_group_by = tokens
612        .windows(2)
613        .any(|w| w[0].eq_ignore_ascii_case("GROUP") && w[1].eq_ignore_ascii_case("BY"));
614
615    if !has_aggregate_fn && !has_group_by {
616        return None;
617    }
618
619    // Parse SELECT list for aggregate functions.
620    let mut reducers = Vec::new();
621    // Consume tokens until FROM.
622    while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
623        if let Some((reducer, next)) = try_parse_aggregate_fn(&tokens, pos) {
624            reducers.push(reducer);
625            pos = next;
626        } else if tokens[pos] == "," {
627            pos += 1;
628        } else {
629            // Non-aggregate field in SELECT (e.g. the GROUP BY field).
630            // Skip it—GROUP BY fields are handled separately.
631            pos += 1;
632        }
633    }
634
635    // FROM
636    if !tok_eq(&tokens, pos, "FROM") {
637        return None;
638    }
639    pos += 1;
640    // Skip table name.
641    if pos >= tokens.len() {
642        return None;
643    }
644    pos += 1;
645
646    // Parse WHERE, GROUP BY.
647    let mut where_filter: Option<String> = None;
648    let mut group_by_fields = Vec::new();
649
650    while pos < tokens.len() {
651        if tok_eq(&tokens, pos, "WHERE") {
652            pos += 1;
653            let (filter_str, next) = parse_where_clause(&tokens, pos)?;
654            where_filter = Some(filter_str);
655            pos = next;
656        } else if tok_eq(&tokens, pos, "GROUP") {
657            if !tok_eq(&tokens, pos + 1, "BY") {
658                return None;
659            }
660            pos += 2;
661            // Parse group by fields.
662            while pos < tokens.len() {
663                let upper = tokens[pos].to_ascii_uppercase();
664                if matches!(upper.as_str(), "HAVING" | "ORDER" | "LIMIT") {
665                    break;
666                }
667                if tokens[pos] == "," {
668                    pos += 1;
669                    continue;
670                }
671                group_by_fields.push(tokens[pos].clone());
672                pos += 1;
673            }
674        } else {
675            pos += 1;
676        }
677    }
678
679    // Need at least one reducer to be a valid aggregate query.
680    if reducers.is_empty() {
681        return None;
682    }
683
684    Some(ParsedAggregate {
685        where_filter,
686        group_by_fields,
687        reducers,
688    })
689}
690
691/// Try to parse an aggregate function call at position `pos`.
692///
693/// Handles: `COUNT(*)`, `SUM(field)`, `AVG(field)`, `MIN(field)`, `MAX(field)`,
694/// `STDDEV(field)`, `COUNT_DISTINCT(field)`, `QUANTILE(field, q)`,
695/// `ARRAY_AGG(field)`, `FIRST_VALUE(field)` — all with optional `AS alias`.
696fn try_parse_aggregate_fn(tokens: &[String], pos: usize) -> Option<(AggReducer, usize)> {
697    if pos >= tokens.len() {
698        return None;
699    }
700
701    let func_upper = tokens[pos].to_ascii_uppercase();
702
703    // Map SQL function names to Redis REDUCE function names.
704    let redis_func = match func_upper.as_str() {
705        "COUNT" => "COUNT",
706        "SUM" => "SUM",
707        "AVG" => "AVG",
708        "MIN" => "MIN",
709        "MAX" => "MAX",
710        "STDDEV" => "STDDEV",
711        "COUNT_DISTINCT" => "COUNT_DISTINCT",
712        "QUANTILE" => "QUANTILE",
713        "ARRAY_AGG" => "TOLIST",
714        "FIRST_VALUE" => "FIRST_VALUE",
715        _ => return None,
716    };
717
718    let mut p = pos + 1;
719
720    // Expect '('
721    if !tok_eq(tokens, p, "(") {
722        return None;
723    }
724    p += 1;
725
726    // Parse arguments.
727    let mut field: Option<String> = None;
728    let mut extra_arg: Option<f64> = None;
729
730    if func_upper == "COUNT" && tok_eq(tokens, p, "*") {
731        // COUNT(*)
732        p += 1;
733    } else if p < tokens.len() && tokens[p] != ")" {
734        // First argument: field name
735        field = Some(tokens[p].clone());
736        p += 1;
737
738        // Check for second argument (QUANTILE has 2 args)
739        if tok_eq(tokens, p, ",") {
740            p += 1;
741            if p < tokens.len() && tokens[p] != ")" {
742                extra_arg = tokens[p].parse::<f64>().ok();
743                p += 1;
744            }
745        }
746    }
747
748    // Expect ')'
749    if !tok_eq(tokens, p, ")") {
750        return None;
751    }
752    p += 1;
753
754    // Parse optional AS alias.
755    let alias = if tok_eq(tokens, p, "AS") {
756        p += 1;
757        if p >= tokens.len() {
758            return None;
759        }
760        let a = tokens[p].clone();
761        p += 1;
762        a
763    } else {
764        // Default alias: use the lowercase function name.
765        func_upper.to_lowercase()
766    };
767
768    Some((
769        AggReducer {
770            function: redis_func.to_owned(),
771            field,
772            alias,
773            extra_arg,
774        },
775        p,
776    ))
777}
778
779// ---------------------------------------------------------------------------
780// Vector SQL parser → KNN FT.SEARCH command
781// ---------------------------------------------------------------------------
782
/// Details of a vector function call found in the SELECT list.
#[derive(Debug, Clone)]
struct VectorFuncCall {
    /// Name of the vector field being searched (e.g. `embedding`).
    field: String,
    /// Name of the parameter carrying the binary vector (e.g. `vec`).
    param_name: String,
    /// Name under which the distance is returned (e.g. `score`,
    /// `vector_distance`).
    alias: String,
}
793
794/// A parsed SQL SELECT with vector search function.
795#[derive(Debug, Clone)]
796struct ParsedVectorSelect {
797    /// The vector function call details.
798    vector_fn: VectorFuncCall,
799    /// Non-vector fields to return (from SELECT list).
800    return_fields: Vec<String>,
801    /// Redis Search filter string from WHERE clause (without vector function).
802    where_filter: Option<String>,
803    /// KNN N value (from LIMIT).
804    knn_num: usize,
805    /// The binary vector blob.
806    vector_blob: Option<Vec<u8>>,
807}
808
809impl ParsedVectorSelect {
810    /// Generates a KNN query string: `(filter)=>[KNN N @field $vector AS alias]`
811    fn to_knn_query_string(&self) -> String {
812        let base = self.where_filter.as_deref().unwrap_or("*");
813        format!(
814            "{}=>[KNN {} @{} $vector AS {}]",
815            base, self.knn_num, self.vector_fn.field, self.vector_fn.alias
816        )
817    }
818
819    /// Returns `QueryParam` entries for the vector blob.
820    fn params(&self) -> Vec<QueryParam> {
821        if let Some(ref blob) = self.vector_blob {
822            vec![QueryParam {
823                name: "vector".to_owned(),
824                value: QueryParamValue::Binary(blob.clone()),
825            }]
826        } else {
827            Vec::new()
828        }
829    }
830}
831
832/// Try to parse a SQL SELECT with vector_distance or cosine_distance.
833///
834/// Detects patterns like:
835/// - `SELECT title, vector_distance(embedding, :vec) AS score FROM idx LIMIT 3`
836/// - `SELECT title, cosine_distance(embedding, :vec) AS dist FROM idx WHERE genre = 'x' LIMIT 3`
837fn parse_vector_select(
838    sql: &str,
839    params: &HashMap<String, SqlParam>,
840) -> Option<ParsedVectorSelect> {
841    let tokens = tokenize(sql);
842    if tokens.is_empty() {
843        return None;
844    }
845    let mut pos = 0;
846
847    // SELECT
848    if !tok_eq(&tokens, pos, "SELECT") {
849        return None;
850    }
851    pos += 1;
852
853    // Scan SELECT list for vector function calls.
854    let mut vector_fn: Option<VectorFuncCall> = None;
855    let mut return_fields: Vec<String> = Vec::new();
856
857    while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
858        if tokens[pos] == "," {
859            pos += 1;
860            continue;
861        }
862
863        // Check for vector_distance(...) or cosine_distance(...)
864        let lower = tokens[pos].to_ascii_lowercase();
865        if (lower == "vector_distance" || lower == "cosine_distance")
866            && tok_eq(&tokens, pos + 1, "(")
867        {
868            let parsed = try_parse_vector_fn_call(&tokens, pos)?;
869            vector_fn = Some(parsed.0);
870            pos = parsed.1;
871            continue;
872        }
873
874        // Skip AS alias (for non-vector fields)
875        if tokens[pos].eq_ignore_ascii_case("AS") {
876            pos += 1; // skip "AS"
877            if pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
878                pos += 1; // skip alias
879            }
880            continue;
881        }
882
883        // Regular field
884        if !tokens[pos].eq_ignore_ascii_case("*") {
885            return_fields.push(tokens[pos].clone());
886        }
887        pos += 1;
888    }
889
890    let vector_fn = vector_fn?; // Must have a vector function
891
892    // FROM
893    if !tok_eq(&tokens, pos, "FROM") {
894        return None;
895    }
896    pos += 1;
897    if pos >= tokens.len() {
898        return None;
899    }
900    pos += 1; // skip table name
901
902    // Parse WHERE, ORDER BY, LIMIT.
903    let mut where_filter: Option<String> = None;
904    let mut knn_num: usize = 10; // default
905
906    while pos < tokens.len() {
907        if tok_eq(&tokens, pos, "WHERE") {
908            pos += 1;
909            let (filter_str, next) = parse_where_clause(&tokens, pos)?;
910            where_filter = Some(filter_str);
911            pos = next;
912        } else if tok_eq(&tokens, pos, "ORDER") {
913            // Skip ORDER BY for vector queries (ordering is by vector distance).
914            while pos < tokens.len()
915                && !tok_eq(&tokens, pos, "LIMIT")
916                && !tok_eq(&tokens, pos, "WHERE")
917            {
918                pos += 1;
919            }
920        } else if tok_eq(&tokens, pos, "LIMIT") {
921            pos += 1;
922            knn_num = parse_usize(&tokens, pos)?;
923            pos += 1;
924            // Skip OFFSET if present
925            if tok_eq(&tokens, pos, "OFFSET") {
926                pos += 2;
927            }
928        } else {
929            pos += 1;
930        }
931    }
932
933    // Look up the binary vector blob from params.
934    let vector_blob = params.get(&vector_fn.param_name).and_then(|p| {
935        if let SqlParam::Bytes(b) = p {
936            Some(b.clone())
937        } else {
938            None
939        }
940    });
941
942    Some(ParsedVectorSelect {
943        vector_fn,
944        return_fields,
945        where_filter,
946        knn_num,
947        vector_blob,
948    })
949}
950
951/// Parse a vector function call: `vector_distance(field, :param)` or
952/// `cosine_distance(field, :param)`, optionally followed by `AS alias`.
953fn try_parse_vector_fn_call(tokens: &[String], pos: usize) -> Option<(VectorFuncCall, usize)> {
954    if pos + 5 >= tokens.len() {
955        return None;
956    }
957
958    let _func_name = &tokens[pos]; // vector_distance or cosine_distance
959    let mut p = pos + 1;
960
961    // Expect '('
962    if !tok_eq(tokens, p, "(") {
963        return None;
964    }
965    p += 1;
966
967    // Field name
968    let field = tokens[p].clone();
969    p += 1;
970
971    // Expect ','
972    if !tok_eq(tokens, p, ",") {
973        return None;
974    }
975    p += 1;
976
977    // Parameter reference: :param_name
978    let param_tok = &tokens[p];
979    let param_name = if param_tok.starts_with(':') {
980        param_tok[1..].to_string()
981    } else {
982        param_tok.clone()
983    };
984    p += 1;
985
986    // Expect ')'
987    if !tok_eq(tokens, p, ")") {
988        return None;
989    }
990    p += 1;
991
992    // Optional AS alias
993    let alias = if tok_eq(tokens, p, "AS") {
994        p += 1;
995        if p >= tokens.len() {
996            return None;
997        }
998        let a = tokens[p].clone();
999        p += 1;
1000        a
1001    } else {
1002        "vector_distance".to_string()
1003    };
1004
1005    Some((
1006        VectorFuncCall {
1007            field,
1008            param_name,
1009            alias,
1010        },
1011        p,
1012    ))
1013}
1014
1015// ---------------------------------------------------------------------------
1016// Geo SQL parser → GEOFILTER and FT.AGGREGATE APPLY geodistance
1017// ---------------------------------------------------------------------------
1018
1019/// Parsed geo_distance() call in a WHERE clause.
1020#[derive(Debug, Clone)]
1021struct ParsedGeoWhere {
1022    /// The GEOFILTER specification.
1023    geofilter: super::GeoFilter,
1024    /// Non-geo Redis Search filter from the WHERE clause.
1025    non_geo_filter: Option<String>,
1026    /// Return fields from SELECT.
1027    return_fields: Vec<String>,
1028}
1029
1030impl ParsedGeoWhere {
1031    /// Returns the filter string (non-geo part, or wildcard).
1032    fn filter_string(&self) -> String {
1033        self.non_geo_filter
1034            .clone()
1035            .unwrap_or_else(|| "*".to_owned())
1036    }
1037}
1038
1039/// Parsed geo_distance() call in SELECT (generates FT.AGGREGATE).
1040#[derive(Debug, Clone)]
1041struct ParsedGeoAggregate {
1042    /// The geo field name.
1043    geo_field: String,
1044    /// Longitude of the reference point.
1045    lon: f64,
1046    /// Latitude of the reference point.
1047    lat: f64,
1048    /// Output alias.
1049    alias: String,
1050    /// Filter from WHERE clause.
1051    where_filter: Option<String>,
1052}
1053
1054impl ParsedGeoAggregate {
1055    /// Builds an `FT.AGGREGATE` command with `APPLY geodistance(...)`.
1056    fn build_cmd(&self, index_name: &str) -> redis::Cmd {
1057        let mut cmd = redis::cmd("FT.AGGREGATE");
1058        cmd.arg(index_name);
1059        cmd.arg(self.where_filter.as_deref().unwrap_or("*"));
1060
1061        // LOAD field (ensure the geo field is available for APPLY)
1062        cmd.arg("LOAD")
1063            .arg(1_u32)
1064            .arg(format!("@{}", self.geo_field));
1065
1066        // APPLY geodistance(@field, lon, lat) AS alias
1067        let expr = format!(
1068            "geodistance(@{}, {}, {})",
1069            self.geo_field, self.lon, self.lat
1070        );
1071        cmd.arg("APPLY").arg(expr).arg("AS").arg(&self.alias);
1072
1073        cmd
1074    }
1075}
1076
1077/// Try to parse a SQL SELECT with geo_distance() in the WHERE clause.
1078///
1079/// Pattern: `WHERE geo_distance(location, POINT(lon, lat), 'unit') < radius`
1080fn parse_geo_where(sql: &str) -> Option<ParsedGeoWhere> {
1081    let tokens = tokenize(sql);
1082    if tokens.is_empty() {
1083        return None;
1084    }
1085    let mut pos = 0;
1086
1087    // SELECT
1088    if !tok_eq(&tokens, pos, "SELECT") {
1089        return None;
1090    }
1091    pos += 1;
1092
1093    // Parse SELECT list.
1094    let mut return_fields: Vec<String> = Vec::new();
1095    if tok_eq(&tokens, pos, "*") {
1096        pos += 1;
1097    } else {
1098        while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
1099            if tokens[pos] == "," || tokens[pos].eq_ignore_ascii_case("AS") {
1100                pos += 1;
1101                // Skip alias name after AS
1102                if pos > 1
1103                    && tokens[pos - 1].eq_ignore_ascii_case("AS")
1104                    && pos < tokens.len()
1105                    && !tok_eq(&tokens, pos, "FROM")
1106                {
1107                    pos += 1;
1108                }
1109                continue;
1110            }
1111            return_fields.push(tokens[pos].clone());
1112            pos += 1;
1113        }
1114    }
1115
1116    // FROM table
1117    if !tok_eq(&tokens, pos, "FROM") {
1118        return None;
1119    }
1120    pos += 1;
1121    if pos >= tokens.len() {
1122        return None;
1123    }
1124    pos += 1; // skip table name
1125
1126    // WHERE
1127    if !tok_eq(&tokens, pos, "WHERE") {
1128        return None;
1129    }
1130    pos += 1;
1131
1132    // Look for geo_distance(...) in the WHERE clause.
1133    // Collect non-geo conditions and geo conditions.
1134    let mut non_geo_conditions: Vec<String> = Vec::new();
1135    let mut geofilter: Option<super::GeoFilter> = None;
1136
1137    loop {
1138        if pos >= tokens.len() {
1139            break;
1140        }
1141        let upper = tokens[pos].to_ascii_uppercase();
1142        if matches!(upper.as_str(), "ORDER" | "LIMIT" | "GROUP" | "HAVING") {
1143            break;
1144        }
1145        if upper == "AND" {
1146            pos += 1;
1147            continue;
1148        }
1149
1150        // Check for geo_distance function
1151        if tokens[pos].eq_ignore_ascii_case("geo_distance") && tok_eq(&tokens, pos + 1, "(") {
1152            let (gf, next) = parse_geo_distance_where(&tokens, pos)?;
1153            geofilter = Some(gf);
1154            pos = next;
1155            continue;
1156        }
1157
1158        // Regular condition
1159        let (filter, next) = parse_single_condition(&tokens, pos)?;
1160        non_geo_conditions.push(filter);
1161        pos = next;
1162    }
1163
1164    let geofilter = geofilter?; // Must have a geo_distance call
1165
1166    let non_geo_filter = if non_geo_conditions.is_empty() {
1167        None
1168    } else if non_geo_conditions.len() == 1 {
1169        Some(non_geo_conditions.into_iter().next().unwrap())
1170    } else {
1171        Some(format!("({})", non_geo_conditions.join(" ")))
1172    };
1173
1174    Some(ParsedGeoWhere {
1175        geofilter,
1176        non_geo_filter,
1177        return_fields,
1178    })
1179}
1180
1181/// Parse `geo_distance(field, POINT(lon, lat), 'unit') < radius` from WHERE.
1182///
1183/// Returns a `GeoFilter` and the position after the comparison.
1184fn parse_geo_distance_where(tokens: &[String], pos: usize) -> Option<(super::GeoFilter, usize)> {
1185    let mut p = pos;
1186
1187    // geo_distance
1188    if !tokens[p].eq_ignore_ascii_case("geo_distance") {
1189        return None;
1190    }
1191    p += 1;
1192
1193    // (
1194    if !tok_eq(tokens, p, "(") {
1195        return None;
1196    }
1197    p += 1;
1198
1199    // field name
1200    let field = tokens[p].clone();
1201    p += 1;
1202
1203    // ,
1204    if !tok_eq(tokens, p, ",") {
1205        return None;
1206    }
1207    p += 1;
1208
1209    // POINT(lon, lat) or just lon, lat
1210    let (lon, lat);
1211    if tokens[p].eq_ignore_ascii_case("POINT") {
1212        p += 1;
1213        // (
1214        if !tok_eq(tokens, p, "(") {
1215            return None;
1216        }
1217        p += 1;
1218        lon = tokens[p].parse::<f64>().ok()?;
1219        p += 1;
1220        // ,
1221        if !tok_eq(tokens, p, ",") {
1222            return None;
1223        }
1224        p += 1;
1225        lat = tokens[p].parse::<f64>().ok()?;
1226        p += 1;
1227        // )
1228        if !tok_eq(tokens, p, ")") {
1229            return None;
1230        }
1231        p += 1;
1232    } else {
1233        lon = tokens[p].parse::<f64>().ok()?;
1234        p += 1;
1235        if tok_eq(tokens, p, ",") {
1236            p += 1;
1237        }
1238        lat = tokens[p].parse::<f64>().ok()?;
1239        p += 1;
1240    }
1241
1242    // , 'unit'
1243    if !tok_eq(tokens, p, ",") {
1244        return None;
1245    }
1246    p += 1;
1247    let unit = unquote(&tokens[p]);
1248    p += 1;
1249
1250    // )
1251    if !tok_eq(tokens, p, ")") {
1252        return None;
1253    }
1254    p += 1;
1255
1256    // < radius
1257    if !tok_eq(tokens, p, "<") {
1258        return None;
1259    }
1260    p += 1;
1261    let radius = tokens[p].parse::<f64>().ok()?;
1262    p += 1;
1263
1264    Some((
1265        super::GeoFilter {
1266            field,
1267            lon,
1268            lat,
1269            radius,
1270            unit,
1271        },
1272        p,
1273    ))
1274}
1275
1276/// Try to parse a SQL SELECT with geo_distance() in the SELECT clause.
1277///
1278/// Pattern: `SELECT name, geo_distance(location, POINT(lon, lat)) AS distance FROM idx`
1279/// → FT.AGGREGATE with APPLY geodistance.
1280fn parse_geo_aggregate(sql: &str) -> Option<ParsedGeoAggregate> {
1281    let tokens = tokenize(sql);
1282    if tokens.is_empty() {
1283        return None;
1284    }
1285    let mut pos = 0;
1286
1287    if !tok_eq(&tokens, pos, "SELECT") {
1288        return None;
1289    }
1290    pos += 1;
1291
1292    let mut geo_field: Option<String> = None;
1293    let mut geo_lon: Option<f64> = None;
1294    let mut geo_lat: Option<f64> = None;
1295    let mut geo_alias: Option<String> = None;
1296
1297    // Parse SELECT list for geo_distance function call.
1298    while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
1299        if tokens[pos] == "," {
1300            pos += 1;
1301            continue;
1302        }
1303
1304        if tokens[pos].eq_ignore_ascii_case("geo_distance") && tok_eq(&tokens, pos + 1, "(") {
1305            // Parse geo_distance(field, POINT(lon, lat))
1306            pos += 2; // skip "geo_distance" and "("
1307            let field = tokens[pos].clone();
1308            pos += 1;
1309            if !tok_eq(&tokens, pos, ",") {
1310                return None;
1311            }
1312            pos += 1;
1313
1314            // POINT(lon, lat)
1315            let (lon, lat);
1316            if tokens[pos].eq_ignore_ascii_case("POINT") {
1317                pos += 1;
1318                if !tok_eq(&tokens, pos, "(") {
1319                    return None;
1320                }
1321                pos += 1;
1322                lon = tokens[pos].parse::<f64>().ok()?;
1323                pos += 1;
1324                if tok_eq(&tokens, pos, ",") {
1325                    pos += 1;
1326                }
1327                lat = tokens[pos].parse::<f64>().ok()?;
1328                pos += 1;
1329                if !tok_eq(&tokens, pos, ")") {
1330                    return None;
1331                }
1332                pos += 1;
1333            } else {
1334                return None;
1335            }
1336
1337            // )
1338            if !tok_eq(&tokens, pos, ")") {
1339                return None;
1340            }
1341            pos += 1;
1342
1343            // AS alias
1344            let alias = if tok_eq(&tokens, pos, "AS") {
1345                pos += 1;
1346                let a = tokens[pos].clone();
1347                pos += 1;
1348                a
1349            } else {
1350                "distance".to_string()
1351            };
1352
1353            geo_field = Some(field);
1354            geo_lon = Some(lon);
1355            geo_lat = Some(lat);
1356            geo_alias = Some(alias);
1357            continue;
1358        }
1359
1360        // Skip AS alias for non-geo fields
1361        if tokens[pos].eq_ignore_ascii_case("AS") {
1362            pos += 1;
1363            if pos < tokens.len() {
1364                pos += 1; // skip alias
1365            }
1366            continue;
1367        }
1368
1369        // Skip non-geo field
1370        pos += 1;
1371    }
1372
1373    let geo_field = geo_field?;
1374    let lon = geo_lon?;
1375    let lat = geo_lat?;
1376    let alias = geo_alias.unwrap_or_else(|| "distance".to_string());
1377
1378    // FROM
1379    if !tok_eq(&tokens, pos, "FROM") {
1380        return None;
1381    }
1382    pos += 1;
1383    if pos >= tokens.len() {
1384        return None;
1385    }
1386    pos += 1; // skip table
1387
1388    // Optional WHERE
1389    let mut where_filter: Option<String> = None;
1390    while pos < tokens.len() {
1391        if tok_eq(&tokens, pos, "WHERE") {
1392            pos += 1;
1393            let (filter_str, next) = parse_where_clause(&tokens, pos)?;
1394            where_filter = Some(filter_str);
1395            pos = next;
1396        } else {
1397            pos += 1;
1398        }
1399    }
1400
1401    Some(ParsedGeoAggregate {
1402        geo_field,
1403        lon,
1404        lat,
1405        alias,
1406        where_filter,
1407    })
1408}
1409
1410/// Parse a WHERE clause starting at `pos`. Returns the Redis filter string and
1411/// the position after the last consumed token.
1412///
1413/// Supports `AND` and `OR` combinators with correct precedence: `AND` binds
1414/// tighter than `OR`, so `a AND b OR c AND d` is parsed as `(a b) | (c d)`.
1415fn parse_where_clause(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
1416    // We collect OR-separated groups of AND-joined conditions.
1417    let mut or_groups: Vec<Vec<String>> = Vec::new();
1418    let mut current_and_group: Vec<String> = Vec::new();
1419
1420    loop {
1421        if pos >= tokens.len() {
1422            break;
1423        }
1424        // Stop at ORDER / LIMIT / GROUP (not part of WHERE).
1425        let upper = tokens[pos].to_ascii_uppercase();
1426        if matches!(upper.as_str(), "ORDER" | "LIMIT" | "GROUP" | "HAVING") {
1427            break;
1428        }
1429        // AND combinator — continue in current group.
1430        if upper == "AND" {
1431            pos += 1;
1432            continue;
1433        }
1434        // OR combinator — start a new group.
1435        if upper == "OR" {
1436            pos += 1;
1437            or_groups.push(std::mem::take(&mut current_and_group));
1438            continue;
1439        }
1440
1441        let (filter, next) = parse_single_condition(tokens, pos)?;
1442        current_and_group.push(filter);
1443        pos = next;
1444    }
1445
1446    // Push the last AND group.
1447    if !current_and_group.is_empty() {
1448        or_groups.push(current_and_group);
1449    }
1450
1451    if or_groups.is_empty() {
1452        return Some(("*".to_owned(), pos));
1453    }
1454
1455    // Build the filter string.
1456    let group_strs: Vec<String> = or_groups
1457        .into_iter()
1458        .map(|g| {
1459            if g.len() == 1 {
1460                g.into_iter().next().unwrap()
1461            } else {
1462                format!("({})", g.join(" "))
1463            }
1464        })
1465        .collect();
1466
1467    let filter = if group_strs.len() == 1 {
1468        group_strs.into_iter().next().unwrap()
1469    } else {
1470        // OR-combine: (a | b) in Redis Search syntax.
1471        format!("({})", group_strs.join(" | "))
1472    };
1473
1474    Some((filter, pos))
1475}
1476
1477/// Parse a single WHERE condition starting at `pos`.
1478///
1479/// Returns the Redis filter string for this condition and the position after
1480/// the last consumed token.
1481fn parse_single_condition(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
1482    let field = &tokens[pos];
1483    pos += 1;
1484    if pos >= tokens.len() {
1485        return None;
1486    }
1487
1488    let op = &tokens[pos];
1489    pos += 1;
1490
1491    // BETWEEN handling: field BETWEEN lo AND hi
1492    if op.eq_ignore_ascii_case("BETWEEN") {
1493        let lo = parse_numeric_or_date_literal(tokens, pos)?;
1494        pos += 1;
1495        if !tok_eq(tokens, pos, "AND") {
1496            return None;
1497        }
1498        pos += 1;
1499        let hi = parse_numeric_or_date_literal(tokens, pos)?;
1500        pos += 1;
1501        return Some((
1502            format!("@{}:[{} {}]", field, format_num(lo), format_num(hi)),
1503            pos,
1504        ));
1505    }
1506
1507    // NOT IN handling: field NOT IN ('a', 'b')
1508    if op.eq_ignore_ascii_case("NOT") && tok_eq(tokens, pos, "IN") {
1509        pos += 1; // skip "IN"
1510        if !tok_eq(tokens, pos, "(") {
1511            return None;
1512        }
1513        pos += 1;
1514        let mut vals = Vec::new();
1515        loop {
1516            if pos >= tokens.len() {
1517                return None;
1518            }
1519            if tokens[pos] == ")" {
1520                pos += 1;
1521                break;
1522            }
1523            if tokens[pos] == "," {
1524                pos += 1;
1525                continue;
1526            }
1527            vals.push(unquote(&tokens[pos]));
1528            pos += 1;
1529        }
1530        let escaped: Vec<String> = vals.iter().map(|v| escape_tag(v)).collect();
1531        return Some((format!("(-@{}:{{{}}})", field, escaped.join("|")), pos));
1532    }
1533
1534    // IN handling: field IN ('a', 'b')
1535    if op.eq_ignore_ascii_case("IN") {
1536        if !tok_eq(tokens, pos, "(") {
1537            return None;
1538        }
1539        pos += 1;
1540        let mut vals = Vec::new();
1541        loop {
1542            if pos >= tokens.len() {
1543                return None;
1544            }
1545            if tokens[pos] == ")" {
1546                pos += 1;
1547                break;
1548            }
1549            if tokens[pos] == "," {
1550                pos += 1;
1551                continue;
1552            }
1553            vals.push(unquote(&tokens[pos]));
1554            pos += 1;
1555        }
1556        let escaped: Vec<String> = vals.iter().map(|v| escape_tag(v)).collect();
1557        return Some((format!("@{}:{{{}}}", field, escaped.join("|")), pos));
1558    }
1559
1560    // LIKE handling: field LIKE 'pattern'
1561    if op.eq_ignore_ascii_case("LIKE") {
1562        if pos >= tokens.len() {
1563            return None;
1564        }
1565        let pattern = unquote(&tokens[pos]);
1566        pos += 1;
1567        let redis_pattern = sql_like_to_redis(&pattern);
1568        return Some((format!("@{}:({})", field, redis_pattern), pos));
1569    }
1570
1571    // NOT LIKE handling: field NOT LIKE 'pattern'
1572    if op.eq_ignore_ascii_case("NOT") && tok_eq(tokens, pos, "LIKE") {
1573        pos += 1; // skip "LIKE"
1574        if pos >= tokens.len() {
1575            return None;
1576        }
1577        let pattern = unquote(&tokens[pos]);
1578        pos += 1;
1579        let redis_pattern = sql_like_to_redis(&pattern);
1580        return Some((format!("(-@{}:({}))", field, redis_pattern), pos));
1581    }
1582
1583    // !=
1584    if op == "!=" {
1585        if pos >= tokens.len() {
1586            return None;
1587        }
1588        let value = unquote(&tokens[pos]);
1589        pos += 1;
1590        if is_numeric_str(&value) {
1591            let n: f64 = value.parse().ok()?;
1592            return Some((
1593                format!("(-@{}:[{} {}])", field, format_num(n), format_num(n)),
1594                pos,
1595            ));
1596        }
1597        if let Some(ts) = try_parse_date(&value) {
1598            return Some((
1599                format!("(-@{}:[{} {}])", field, format_num(ts), format_num(ts)),
1600                pos,
1601            ));
1602        }
1603        // Tag or text negation.
1604        return Some((format!("(-@{}:{{{}}})", field, escape_tag(&value)), pos));
1605    }
1606
1607    // Comparison operators: =, <, >, <=, >=
1608    if pos >= tokens.len() {
1609        return None;
1610    }
1611
1612    // Handle two-character ops: <=, >=
1613    let (real_op, value_str) = if (op == "<" || op == ">") && tokens[pos] == "=" {
1614        let combined = format!("{}=", op);
1615        pos += 1;
1616        if pos >= tokens.len() {
1617            return None;
1618        }
1619        let v = unquote(&tokens[pos]);
1620        pos += 1;
1621        (combined, v)
1622    } else {
1623        let v = unquote(&tokens[pos]);
1624        pos += 1;
1625        (op.clone(), v)
1626    };
1627
1628    let filter = match real_op.as_str() {
1629        "=" => {
1630            if is_numeric_str(&value_str) {
1631                let n: f64 = value_str.parse().ok()?;
1632                format!("@{}:[{} {}]", field, format_num(n), format_num(n))
1633            } else if let Some(ts) = try_parse_date(&value_str) {
1634                format!("@{}:[{} {}]", field, format_num(ts), format_num(ts))
1635            } else {
1636                // Could be tag or text. Use tag syntax for simple values.
1637                // For text with wildcards or multi-word, use text syntax.
1638                let val = value_str.clone();
1639                if val.contains('*') || val.contains('%') {
1640                    // Wildcard/fuzzy → text field search.
1641                    format!("@{}:({})", field, val)
1642                } else if val.contains(' ') {
1643                    // Multi-word → phrase search.
1644                    format!("@{}:(\"{}\")", field, val)
1645                } else {
1646                    // Single term → tag match.
1647                    format!("@{}:{{{}}}", field, escape_tag(&val))
1648                }
1649            }
1650        }
1651        "<" => {
1652            let n = parse_num_or_date(&value_str)?;
1653            format!("@{}:[-inf ({}]", field, format_num(n))
1654        }
1655        ">" => {
1656            let n = parse_num_or_date(&value_str)?;
1657            format!("@{}:[({} +inf]", field, format_num(n))
1658        }
1659        "<=" => {
1660            let n = parse_num_or_date(&value_str)?;
1661            format!("@{}:[-inf {}]", field, format_num(n))
1662        }
1663        ">=" => {
1664            let n = parse_num_or_date(&value_str)?;
1665            format!("@{}:[{} +inf]", field, format_num(n))
1666        }
1667        _ => return None,
1668    };
1669
1670    Some((filter, pos))
1671}
1672
1673// ---------------------------------------------------------------------------
1674// SQL tokenizer
1675// ---------------------------------------------------------------------------
1676
/// Splits a SQL string into tokens.
///
/// Token kinds, tried in this order at each position:
/// - ASCII whitespace is skipped
/// - single-quoted string literals, kept with their quotes; a doubled `''`
///   inside stays doubled, and an unterminated literal still receives a
///   closing quote
/// - parameter references such as `:vec` (a `:` followed by a letter or `_`)
/// - identifiers / keywords (`[A-Za-z_][A-Za-z0-9_]*`)
/// - numbers: digits and `.`, with an optional leading `-` when a digit
///   follows it directly
/// - the two-character operators `!=`, `<=`, `>=`
/// - any other character as a one-character token
fn tokenize(sql: &str) -> Vec<String> {
    fn is_word_start(c: char) -> bool {
        c.is_ascii_alphabetic() || c == '_'
    }

    fn is_word_part(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    /// Index of the first character at or after `from` that `accept` rejects
    /// (or the end of `src`).
    fn run_end(src: &[char], from: usize, accept: impl Fn(char) -> bool) -> usize {
        src[from..]
            .iter()
            .position(|&c| !accept(c))
            .map_or(src.len(), |offset| from + offset)
    }

    let src: Vec<char> = sql.chars().collect();
    let mut out: Vec<String> = Vec::new();
    let mut cur = 0;

    while let Some(&ch) = src.get(cur) {
        let peek = src.get(cur + 1).copied();
        match ch {
            c if c.is_ascii_whitespace() => cur += 1,
            '\'' => {
                let mut literal = String::from("'");
                cur += 1;
                while let Some(&c) = src.get(cur) {
                    if c != '\'' {
                        literal.push(c);
                        cur += 1;
                    } else if src.get(cur + 1) == Some(&'\'') {
                        // Escaped quote: keep both characters.
                        literal.push_str("''");
                        cur += 2;
                    } else {
                        break;
                    }
                }
                literal.push('\'');
                // Step over the closing quote when the literal was terminated.
                if cur < src.len() {
                    cur += 1;
                }
                out.push(literal);
            }
            ':' if peek.map_or(false, is_word_start) => {
                let end = run_end(&src, cur + 1, is_word_part);
                out.push(src[cur..end].iter().collect());
                cur = end;
            }
            c if is_word_start(c) => {
                let end = run_end(&src, cur, is_word_part);
                out.push(src[cur..end].iter().collect());
                cur = end;
            }
            c if c.is_ascii_digit()
                || (c == '-' && peek.map_or(false, |next| next.is_ascii_digit())) =>
            {
                let digits_from = if c == '-' { cur + 1 } else { cur };
                let end = run_end(&src, digits_from, |d| d.is_ascii_digit() || d == '.');
                out.push(src[cur..end].iter().collect());
                cur = end;
            }
            '!' | '<' | '>' if peek == Some('=') => {
                out.push(src[cur..cur + 2].iter().collect());
                cur += 2;
            }
            other => {
                out.push(other.to_string());
                cur += 1;
            }
        }
    }
    out
}
1770
1771// ---------------------------------------------------------------------------
1772// Helpers
1773// ---------------------------------------------------------------------------
1774
/// Returns `true` when a token exists at `pos` and equals `expected`,
/// ignoring ASCII case. An out-of-range `pos` yields `false`.
fn tok_eq(tokens: &[String], pos: usize, expected: &str) -> bool {
    match tokens.get(pos) {
        Some(token) => token.eq_ignore_ascii_case(expected),
        None => false,
    }
}
1781
/// Reads the token at `pos` as a `usize`.
///
/// Returns `None` when `pos` is out of range or the token is not a valid
/// `usize`.
fn parse_usize(tokens: &[String], pos: usize) -> Option<usize> {
    tokens
        .get(pos)
        .and_then(|token| token.parse::<usize>().ok())
}
1786
1787/// Parse a numeric literal or ISO date string from a token at `pos`.
1788///
1789/// This extends `parse_numeric_literal` to handle date strings like
1790/// `'2024-01-01'` by converting them to Unix timestamps.
1791fn parse_numeric_or_date_literal(tokens: &[String], pos: usize) -> Option<f64> {
1792    let tok = tokens.get(pos)?;
1793    let s = unquote(tok);
1794    if let Ok(n) = s.parse::<f64>() {
1795        Some(n)
1796    } else {
1797        try_parse_date(&s)
1798    }
1799}
1800
1801/// Try to parse a string as a number; if that fails, try as an ISO date.
1802fn parse_num_or_date(s: &str) -> Option<f64> {
1803    if let Ok(n) = s.parse::<f64>() {
1804        Some(n)
1805    } else {
1806        try_parse_date(s)
1807    }
1808}
1809
/// Try to parse an ISO 8601 date string (`YYYY-MM-DD` or `YYYY-MM-DDTHH:MM:SS`)
/// and return the Unix timestamp (seconds, treating the value as UTC) as `f64`.
///
/// Accepted shapes:
/// - exactly 10 bytes with `-` at byte offsets 4 and 7 (date only, midnight)
/// - at least 19 bytes with `T` or a space at byte offset 10 (date and time);
///   anything after the seconds field is ignored
///
/// Returns `None` for any other input, for out-of-range fields (including a
/// day that does not exist in the given month, e.g. `2023-02-29`), and for
/// strings whose fixed byte offsets do not fall on UTF-8 character
/// boundaries — arbitrary user text never causes a panic.
///
/// This mirrors the upstream Python `sql-redis` library's date literal handling.
fn try_parse_date(s: &str) -> Option<f64> {
    /// Parses the byte range `range` of `s` as a number. Yields `None` when
    /// the range is out of bounds, splits a multi-byte character, or does not
    /// hold a valid number.
    fn num<T: std::str::FromStr>(s: &str, range: std::ops::Range<usize>) -> Option<T> {
        s.get(range)?.parse().ok()
    }

    let bytes = s.as_bytes();
    let date_only = s.len() == 10 && bytes.get(4) == Some(&b'-') && bytes.get(7) == Some(&b'-');
    let date_time = s.len() >= 19 && matches!(bytes.get(10), Some(&b'T') | Some(&b' '));
    if !date_only && !date_time {
        return None;
    }

    let year: i32 = num(s, 0..4)?;
    let month: u32 = num(s, 5..7)?;
    let day: u32 = num(s, 8..10)?;
    // Month and day ranges are validated by `date_to_unix_timestamp`.
    let midnight = date_to_unix_timestamp(year, month, day)?;
    if date_only {
        return Some(midnight as f64);
    }

    let hour: u32 = num(s, 11..13)?;
    let min: u32 = num(s, 14..16)?;
    let sec: u32 = num(s, 17..19)?;
    if hour > 23 || min > 59 || sec > 59 {
        return None;
    }
    let ts = midnight + i64::from(hour) * 3600 + i64::from(min) * 60 + i64::from(sec);
    Some(ts as f64)
}

/// Convert a date (year, month, day) to a Unix timestamp (seconds since 1970-01-01 UTC).
///
/// Uses the proleptic Gregorian calendar. Returns `None` when `month` is not
/// in `1..=12` or `day` does not exist in that month of that year.
fn date_to_unix_timestamp(year: i32, month: u32, day: u32) -> Option<i64> {
    // Days in months (non-leap).
    const DAYS_IN_MONTH: [u32; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];

    fn is_leap(y: i32) -> bool {
        (y % 4 == 0 && y % 100 != 0) || y % 400 == 0
    }

    /// Leap-year count up to and including `y`, relative to a fixed origin;
    /// only differences between two results are meaningful. Floor division
    /// keeps the count correct for years before 1.
    fn leap_years_through(y: i64) -> i64 {
        y.div_euclid(4) - y.div_euclid(100) + y.div_euclid(400)
    }

    if !(1..=12).contains(&month) {
        return None;
    }
    let month_idx = (month - 1) as usize;
    let month_len = DAYS_IN_MONTH[month_idx] + u32::from(month == 2 && is_leap(year));
    if !(1..=month_len).contains(&day) {
        return None;
    }

    // Whole years between 1970-01-01 and January 1st of `year` (negative
    // before 1970): 365 days each plus one per leap year in between.
    let y = i64::from(year);
    let mut days = 365 * (y - 1970) + leap_years_through(y - 1) - leap_years_through(1969);

    // Whole months already elapsed in `year`.
    days += DAYS_IN_MONTH[..month_idx]
        .iter()
        .map(|&len| i64::from(len))
        .sum::<i64>();
    if month > 2 && is_leap(year) {
        days += 1;
    }

    // Days (1-based, so day 1 = 0 extra days)
    days += i64::from(day) - 1;

    Some(days * 86400)
}
1885
/// Rewrites a SQL `LIKE` pattern for Redis Search text queries.
///
/// Every `%` becomes `*`; all other characters, including `_`, are copied
/// unchanged.
///
/// Examples:
/// - `laptop%` → `laptop*`
/// - `%laptop` → `*laptop`
/// - `%laptop%` → `*laptop*`
fn sql_like_to_redis(pattern: &str) -> String {
    pattern
        .chars()
        .map(|ch| if ch == '%' { '*' } else { ch })
        .collect()
}
1898
/// Strips the surrounding single quotes from a SQL string literal and turns
/// each doubled quote (`''`) inside it back into a single `'`.
///
/// Input that is not wrapped in a pair of single quotes is returned unchanged.
fn unquote(s: &str) -> String {
    let inner = s
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''));
    match inner {
        Some(body) => body.replace("''", "'"),
        None => s.to_string(),
    }
}
1909
/// Backslash-escape characters in a tag value for use inside Redis Search
/// `@field:{value}` syntax.
///
/// The escaped set is: space, `$`, `:`, `&`, `/`, `-`, `.` and `*`. Every
/// other character is copied through untouched.
fn escape_tag(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            ' ' | '$' | ':' | '&' | '/' | '-' | '.' | '*' => {
                escaped.push('\\');
                escaped.push(ch);
            }
            _ => escaped.push(ch),
        }
    }
    escaped
}
1923
/// Report whether `s` parses as an `f64`.
fn is_numeric_str(s: &str) -> bool {
    matches!(s.parse::<f64>(), Ok(_))
}
1928
/// Render a number for a query string: values with no fractional part are
/// printed without decimals, everything else uses the default `f64` display.
fn format_num(n: f64) -> String {
    if n.fract() != 0.0 {
        return n.to_string();
    }
    format!("{:.0}", n)
}
1937
1938#[cfg(test)]
1939mod tests {
1940    use super::*;
1941
1942    // ---- Parameter substitution: partial matching prevention ----
1943
1944    #[test]
1945    fn similar_param_names_no_partial_match() {
1946        let query = SQLQuery::with_params(
1947            "SELECT * FROM idx WHERE id = :id AND product_id = :product_id",
1948            HashMap::from([
1949                ("id".to_owned(), SqlParam::Int(123)),
1950                ("product_id".to_owned(), SqlParam::Int(456)),
1951            ]),
1952        );
1953        let substituted = query.substituted_sql();
1954        assert!(substituted.contains("id = 123"));
1955        assert!(substituted.contains("product_id = 456"));
1956        assert!(!substituted.contains("product_123"));
1957    }
1958
1959    #[test]
1960    fn prefix_param_names() {
1961        let query = SQLQuery::with_params(
1962            "SELECT * FROM idx WHERE user = :user AND user_id = :user_id AND user_name = :user_name",
1963            HashMap::from([
1964                ("user".to_owned(), SqlParam::Str("alice".to_owned())),
1965                ("user_id".to_owned(), SqlParam::Int(42)),
1966                (
1967                    "user_name".to_owned(),
1968                    SqlParam::Str("Alice Smith".to_owned()),
1969                ),
1970            ]),
1971        );
1972        let substituted = query.substituted_sql();
1973        assert!(substituted.contains("user = 'alice'"));
1974        assert!(substituted.contains("user_id = 42"));
1975        assert!(substituted.contains("user_name = 'Alice Smith'"));
1976        assert!(!substituted.contains("'alice'_id"));
1977        assert!(!substituted.contains("'alice'_name"));
1978    }
1979
1980    #[test]
1981    fn suffix_param_names() {
1982        let query = SQLQuery::with_params(
1983            "SELECT * FROM idx WHERE vec = :vec AND query_vec = :query_vec",
1984            HashMap::from([
1985                ("vec".to_owned(), SqlParam::Float(1.0)),
1986                ("query_vec".to_owned(), SqlParam::Float(2.0)),
1987            ]),
1988        );
1989        let substituted = query.substituted_sql();
1990        assert!(substituted.contains("vec = 1") || substituted.contains("vec = 1.0"));
1991        assert!(substituted.contains("query_vec = 2") || substituted.contains("query_vec = 2.0"));
1992    }
1993
1994    // ---- Parameter substitution: quote escaping ----
1995
1996    #[test]
1997    fn single_quote_in_value() {
1998        let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
1999            .with_param("name", SqlParam::Str("O'Brien".to_owned()));
2000        let substituted = query.substituted_sql();
2001        assert!(substituted.contains("name = 'O''Brien'"));
2002    }
2003
2004    #[test]
2005    fn multiple_quotes_in_value() {
2006        let query = SQLQuery::new("SELECT * FROM idx WHERE phrase = :phrase")
2007            .with_param("phrase", SqlParam::Str("It's a 'test' string".to_owned()));
2008        let substituted = query.substituted_sql();
2009        assert!(substituted.contains("phrase = 'It''s a ''test'' string'"));
2010    }
2011
2012    #[test]
2013    fn apostrophe_names() {
2014        let cases = [
2015            ("McDonald's", "'McDonald''s'"),
2016            ("O'Reilly", "'O''Reilly'"),
2017            ("D'Angelo", "'D''Angelo'"),
2018        ];
2019        for (name, expected) in cases {
2020            let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
2021                .with_param("name", SqlParam::Str(name.to_owned()));
2022            let substituted = query.substituted_sql();
2023            assert!(
2024                substituted.contains(&format!("name = {expected}")),
2025                "Failed for {name}: got {substituted}"
2026            );
2027        }
2028    }
2029
2030    // ---- Edge cases ----
2031
2032    #[test]
2033    fn multiple_occurrences_same_param() {
2034        let query = SQLQuery::new("SELECT * FROM idx WHERE category = :cat OR subcategory = :cat")
2035            .with_param("cat", SqlParam::Str("electronics".to_owned()));
2036        let substituted = query.substituted_sql();
2037        assert_eq!(substituted.matches("'electronics'").count(), 2);
2038    }
2039
2040    #[test]
2041    fn empty_string_value() {
2042        let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
2043            .with_param("name", SqlParam::Str(String::new()));
2044        let substituted = query.substituted_sql();
2045        assert!(substituted.contains("name = ''"));
2046    }
2047
2048    #[test]
2049    fn numeric_types() {
2050        let query = SQLQuery::with_params(
2051            "SELECT * FROM idx WHERE count = :count AND price = :price",
2052            HashMap::from([
2053                ("count".to_owned(), SqlParam::Int(42)),
2054                ("price".to_owned(), SqlParam::Float(99.99)),
2055            ]),
2056        );
2057        let substituted = query.substituted_sql();
2058        assert!(substituted.contains("count = 42"));
2059        assert!(substituted.contains("price = 99.99"));
2060    }
2061
2062    #[test]
2063    fn bytes_param_not_substituted() {
2064        let query = SQLQuery::new("SELECT * FROM idx WHERE embedding = :vec")
2065            .with_param("vec", SqlParam::Bytes(vec![0x00, 0x01, 0x02, 0x03]));
2066        let substituted = query.substituted_sql();
2067        assert!(substituted.contains(":vec"));
2068    }
2069
2070    #[test]
2071    fn special_characters_in_value() {
2072        let specials = [
2073            "hello@world.com",
2074            "path/to/file",
2075            "price: $100",
2076            "regex.*pattern",
2077            "back\\slash",
2078        ];
2079        for value in specials {
2080            let query = SQLQuery::new("SELECT * FROM idx WHERE field = :field")
2081                .with_param("field", SqlParam::Str(value.to_owned()));
2082            let substituted = query.substituted_sql();
2083            assert!(
2084                !substituted.contains(":field"),
2085                "Failed to substitute for value: {value}"
2086            );
2087        }
2088    }
2089
2090    #[test]
2091    fn no_params_returns_original() {
2092        let query = SQLQuery::new("SELECT * FROM idx");
2093        assert_eq!(query.substituted_sql(), "SELECT * FROM idx");
2094    }
2095
2096    #[test]
2097    fn unknown_placeholder_kept() {
2098        let query = SQLQuery::new("SELECT * FROM idx WHERE x = :unknown")
2099            .with_param("other", SqlParam::Int(1));
2100        assert!(query.substituted_sql().contains(":unknown"));
2101    }
2102
2103    #[test]
2104    fn with_param_builder_pattern() {
2105        let query = SQLQuery::new("SELECT * FROM idx WHERE a = :a AND b = :b")
2106            .with_param("a", SqlParam::Int(1))
2107            .with_param("b", SqlParam::Str("hello".to_owned()));
2108        let sub = query.substituted_sql();
2109        assert!(sub.contains("a = 1"));
2110        assert!(sub.contains("b = 'hello'"));
2111    }
2112
2113    #[test]
2114    fn sql_accessor() {
2115        let query = SQLQuery::new("SELECT 1");
2116        assert_eq!(query.sql(), "SELECT 1");
2117    }
2118
2119    #[test]
2120    fn params_map_accessor() {
2121        let query = SQLQuery::new("SELECT 1").with_param("x", SqlParam::Int(42));
2122        assert_eq!(query.params_map().len(), 1);
2123    }
2124
2125    // ---- SQL→Redis translation tests ----
2126
2127    #[test]
2128    fn select_star_no_where_produces_wildcard() {
2129        let query = SQLQuery::new("SELECT * FROM products");
2130        assert_eq!(query.to_redis_query(), "*");
2131    }
2132
2133    #[test]
2134    fn select_specific_fields_sets_return_fields() {
2135        let query = SQLQuery::new("SELECT title, price FROM products");
2136        assert_eq!(query.to_redis_query(), "*");
2137        assert_eq!(query.return_fields(), vec!["title", "price"]);
2138    }
2139
2140    #[test]
2141    fn where_tag_equals() {
2142        let query = SQLQuery::new("SELECT * FROM products WHERE category = 'electronics'");
2143        assert_eq!(query.to_redis_query(), "@category:{electronics}");
2144    }
2145
2146    #[test]
2147    fn where_tag_not_equals() {
2148        let query = SQLQuery::new("SELECT * FROM products WHERE category != 'electronics'");
2149        assert_eq!(query.to_redis_query(), "(-@category:{electronics})");
2150    }
2151
2152    #[test]
2153    fn where_tag_in() {
2154        let query =
2155            SQLQuery::new("SELECT * FROM products WHERE category IN ('books', 'accessories')");
2156        assert_eq!(query.to_redis_query(), "@category:{books|accessories}");
2157    }
2158
2159    #[test]
2160    fn where_numeric_less_than() {
2161        let query = SQLQuery::new("SELECT * FROM products WHERE price < 50");
2162        assert_eq!(query.to_redis_query(), "@price:[-inf (50]");
2163    }
2164
2165    #[test]
2166    fn where_numeric_greater_than() {
2167        let query = SQLQuery::new("SELECT * FROM products WHERE price > 100");
2168        assert_eq!(query.to_redis_query(), "@price:[(100 +inf]");
2169    }
2170
2171    #[test]
2172    fn where_numeric_equals() {
2173        let query = SQLQuery::new("SELECT * FROM products WHERE price = 45");
2174        assert_eq!(query.to_redis_query(), "@price:[45 45]");
2175    }
2176
2177    #[test]
2178    fn where_numeric_not_equals() {
2179        let query = SQLQuery::new("SELECT * FROM products WHERE price != 45");
2180        assert_eq!(query.to_redis_query(), "(-@price:[45 45])");
2181    }
2182
2183    #[test]
2184    fn where_numeric_lte() {
2185        let query = SQLQuery::new("SELECT * FROM products WHERE price <= 50");
2186        assert_eq!(query.to_redis_query(), "@price:[-inf 50]");
2187    }
2188
2189    #[test]
2190    fn where_numeric_gte() {
2191        let query = SQLQuery::new("SELECT * FROM products WHERE price >= 25");
2192        assert_eq!(query.to_redis_query(), "@price:[25 +inf]");
2193    }
2194
2195    #[test]
2196    fn where_between() {
2197        let query = SQLQuery::new("SELECT * FROM products WHERE price BETWEEN 40 AND 60");
2198        assert_eq!(query.to_redis_query(), "@price:[40 60]");
2199    }
2200
2201    #[test]
2202    fn where_combined_and() {
2203        let query =
2204            SQLQuery::new("SELECT * FROM products WHERE category = 'electronics' AND price < 100");
2205        assert_eq!(
2206            query.to_redis_query(),
2207            "(@category:{electronics} @price:[-inf (100])"
2208        );
2209    }
2210
2211    #[test]
2212    fn order_by_asc() {
2213        let query = SQLQuery::new("SELECT title, price FROM products ORDER BY price ASC");
2214        let sb = query.sort_by().expect("sort_by should be set");
2215        assert_eq!(sb.field, "price");
2216        assert!(matches!(sb.direction, SortDirection::Asc));
2217    }
2218
2219    #[test]
2220    fn order_by_desc() {
2221        let query = SQLQuery::new("SELECT title, price FROM products ORDER BY price DESC");
2222        let sb = query.sort_by().expect("sort_by should be set");
2223        assert_eq!(sb.field, "price");
2224        assert!(matches!(sb.direction, SortDirection::Desc));
2225    }
2226
2227    #[test]
2228    fn limit_clause() {
2229        let query = SQLQuery::new("SELECT title FROM products LIMIT 3");
2230        let lim = query.limit().expect("limit should be set");
2231        assert_eq!(lim.num, 3);
2232        assert_eq!(lim.offset, 0);
2233    }
2234
2235    #[test]
2236    fn limit_with_offset() {
2237        let query = SQLQuery::new("SELECT title FROM products ORDER BY price ASC LIMIT 3 OFFSET 3");
2238        let lim = query.limit().expect("limit should be set");
2239        assert_eq!(lim.num, 3);
2240        assert_eq!(lim.offset, 3);
2241    }
2242
2243    #[test]
2244    fn where_with_order_and_limit() {
2245        let query = SQLQuery::new(
2246            "SELECT title, price FROM products WHERE category = 'electronics' ORDER BY price ASC LIMIT 5",
2247        );
2248        assert_eq!(query.to_redis_query(), "@category:{electronics}");
2249        assert_eq!(query.return_fields(), vec!["title", "price"]);
2250        let sb = query.sort_by().expect("sort_by");
2251        assert_eq!(sb.field, "price");
2252        let lim = query.limit().expect("limit");
2253        assert_eq!(lim.num, 5);
2254    }
2255
2256    #[test]
2257    fn aggregate_query_returns_raw_sql_fallback() {
2258        // Aggregate queries are not translated—they fall back to the raw SQL.
2259        let query = SQLQuery::new("SELECT COUNT(*) as total FROM products");
2260        let result = query.to_redis_query();
2261        // Parsed as None → fallback to substituted_sql.
2262        assert!(result.contains("COUNT"));
2263    }
2264
2265    #[test]
2266    fn text_equality_single_word() {
2267        let query = SQLQuery::new("SELECT * FROM products WHERE title = 'laptop'");
2268        assert_eq!(query.to_redis_query(), "@title:{laptop}");
2269    }
2270
2271    #[test]
2272    fn text_equality_phrase() {
2273        let query = SQLQuery::new("SELECT * FROM products WHERE title = 'gaming laptop'");
2274        assert_eq!(query.to_redis_query(), "@title:(\"gaming laptop\")");
2275    }
2276
2277    #[test]
2278    fn numeric_range_with_and() {
2279        let query = SQLQuery::new("SELECT * FROM products WHERE price >= 25 AND price <= 50");
2280        assert_eq!(
2281            query.to_redis_query(),
2282            "(@price:[25 +inf] @price:[-inf 50])"
2283        );
2284    }
2285
2286    #[test]
2287    fn should_unpack_json_for_select_star() {
2288        let query = SQLQuery::new("SELECT * FROM products");
2289        assert!(query.should_unpack_json());
2290    }
2291
2292    #[test]
2293    fn should_not_unpack_json_for_field_projection() {
2294        let query = SQLQuery::new("SELECT title, price FROM products");
2295        assert!(!query.should_unpack_json());
2296    }
2297
2298    #[test]
2299    fn with_param_where_tag() {
2300        let query = SQLQuery::new("SELECT * FROM products WHERE category = :cat")
2301            .with_param("cat", SqlParam::Str("electronics".to_owned()));
2302        assert_eq!(query.to_redis_query(), "@category:{electronics}");
2303    }
2304
2305    #[test]
2306    fn with_param_where_numeric() {
2307        let query = SQLQuery::new("SELECT * FROM products WHERE price > :min_price")
2308            .with_param("min_price", SqlParam::Float(99.99));
2309        assert_eq!(query.to_redis_query(), "@price:[(99.99 +inf]");
2310    }
2311
2312    // ---- OR support ----
2313
2314    #[test]
2315    fn where_simple_or() {
2316        let query = SQLQuery::new(
2317            "SELECT * FROM products WHERE category = 'electronics' OR category = 'books'",
2318        );
2319        assert_eq!(
2320            query.to_redis_query(),
2321            "(@category:{electronics} | @category:{books})"
2322        );
2323    }
2324
2325    #[test]
2326    fn where_or_with_three_branches() {
2327        let query = SQLQuery::new(
2328            "SELECT * FROM products WHERE category = 'electronics' OR category = 'books' OR category = 'accessories'",
2329        );
2330        assert_eq!(
2331            query.to_redis_query(),
2332            "(@category:{electronics} | @category:{books} | @category:{accessories})"
2333        );
2334    }
2335
2336    #[test]
2337    fn where_and_binds_tighter_than_or() {
2338        // a AND b OR c AND d → (a b) | (c d)
2339        let query = SQLQuery::new(
2340            "SELECT * FROM products WHERE category = 'electronics' AND price > 100 OR category = 'books' AND price < 50",
2341        );
2342        assert_eq!(
2343            query.to_redis_query(),
2344            "((@category:{electronics} @price:[(100 +inf]) | (@category:{books} @price:[-inf (50]))"
2345        );
2346    }
2347
2348    #[test]
2349    fn where_or_with_single_conditions() {
2350        let query = SQLQuery::new("SELECT * FROM products WHERE price < 20 OR price > 1000");
2351        assert_eq!(
2352            query.to_redis_query(),
2353            "(@price:[-inf (20] | @price:[(1000 +inf])"
2354        );
2355    }
2356
2357    #[test]
2358    fn where_or_preserves_order_limit() {
2359        let query = SQLQuery::new(
2360            "SELECT title FROM products WHERE category = 'a' OR category = 'b' ORDER BY price ASC LIMIT 5",
2361        );
2362        assert_eq!(query.to_redis_query(), "(@category:{a} | @category:{b})");
2363        assert!(query.sort_by().is_some());
2364        assert_eq!(query.limit().unwrap().num, 5);
2365    }
2366
2367    // ---- NOT IN support ----
2368
2369    #[test]
2370    fn where_not_in() {
2371        let query =
2372            SQLQuery::new("SELECT * FROM products WHERE category NOT IN ('electronics', 'books')");
2373        assert_eq!(query.to_redis_query(), "(-@category:{electronics|books})");
2374    }
2375
2376    #[test]
2377    fn where_not_in_combined_with_and() {
2378        let query = SQLQuery::new(
2379            "SELECT * FROM products WHERE category NOT IN ('electronics') AND price > 50",
2380        );
2381        assert_eq!(
2382            query.to_redis_query(),
2383            "((-@category:{electronics}) @price:[(50 +inf])"
2384        );
2385    }
2386
2387    // ---- LIKE support ----
2388
2389    #[test]
2390    fn where_like_prefix() {
2391        let query = SQLQuery::new("SELECT * FROM products WHERE title LIKE 'laptop%'");
2392        assert_eq!(query.to_redis_query(), "@title:(laptop*)");
2393    }
2394
2395    #[test]
2396    fn where_like_suffix() {
2397        let query = SQLQuery::new("SELECT * FROM products WHERE title LIKE '%laptop'");
2398        assert_eq!(query.to_redis_query(), "@title:(*laptop)");
2399    }
2400
2401    #[test]
2402    fn where_like_contains() {
2403        let query = SQLQuery::new("SELECT * FROM products WHERE title LIKE '%laptop%'");
2404        assert_eq!(query.to_redis_query(), "@title:(*laptop*)");
2405    }
2406
2407    #[test]
2408    fn where_not_like() {
2409        let query = SQLQuery::new("SELECT * FROM products WHERE title NOT LIKE 'laptop%'");
2410        assert_eq!(query.to_redis_query(), "(-@title:(laptop*))");
2411    }
2412
2413    #[test]
2414    fn where_like_combined_with_and() {
2415        let query =
2416            SQLQuery::new("SELECT * FROM products WHERE title LIKE 'lap%' AND price < 1000");
2417        assert_eq!(
2418            query.to_redis_query(),
2419            "(@title:(lap*) @price:[-inf (1000])"
2420        );
2421    }
2422
2423    // ---- Date literal parsing ----
2424
2425    #[test]
2426    fn where_date_greater_than() {
2427        let query = SQLQuery::new("SELECT * FROM events WHERE created_at > '2024-01-01'");
2428        let result = query.to_redis_query();
2429        // 2024-01-01 00:00:00 UTC = 1704067200
2430        assert_eq!(result, "@created_at:[(1704067200 +inf]");
2431    }
2432
2433    #[test]
2434    fn where_date_less_than() {
2435        let query = SQLQuery::new("SELECT * FROM events WHERE created_at < '2024-03-31'");
2436        let result = query.to_redis_query();
2437        // 2024-03-31 00:00:00 UTC = 1711843200
2438        assert_eq!(result, "@created_at:[-inf (1711843200]");
2439    }
2440
2441    #[test]
2442    fn where_date_between() {
2443        let query = SQLQuery::new(
2444            "SELECT * FROM events WHERE created_at BETWEEN '2024-01-01' AND '2024-03-31'",
2445        );
2446        let result = query.to_redis_query();
2447        assert_eq!(result, "@created_at:[1704067200 1711843200]");
2448    }
2449
2450    #[test]
2451    fn where_date_gte() {
2452        let query = SQLQuery::new("SELECT * FROM events WHERE created_at >= '2024-06-15'");
2453        let result = query.to_redis_query();
2454        // 2024-06-15 = 1718409600
2455        assert_eq!(result, "@created_at:[1718409600 +inf]");
2456    }
2457
2458    #[test]
2459    fn where_date_combined_with_tag() {
2460        let query = SQLQuery::new(
2461            "SELECT * FROM events WHERE category = 'meeting' AND created_at > '2024-01-01'",
2462        );
2463        let result = query.to_redis_query();
2464        assert_eq!(
2465            result,
2466            "(@category:{meeting} @created_at:[(1704067200 +inf])"
2467        );
2468    }
2469
2470    #[test]
2471    fn where_datetime_with_time() {
2472        let query = SQLQuery::new("SELECT * FROM events WHERE created_at > '2024-01-15T10:30:00'");
2473        let result = query.to_redis_query();
2474        // 2024-01-15 00:00:00 UTC = 1705276800, + 10*3600 + 30*60 = 37800 → 1705314600
2475        assert_eq!(result, "@created_at:[(1705314600 +inf]");
2476    }
2477
2478    #[test]
2479    fn date_to_timestamp_known_values() {
2480        // 1970-01-01 → 0
2481        assert_eq!(try_parse_date("1970-01-01"), Some(0.0));
2482        // 2000-01-01 → 946684800
2483        assert_eq!(try_parse_date("2000-01-01"), Some(946_684_800.0));
2484        // 2024-01-01 → 1704067200
2485        assert_eq!(try_parse_date("2024-01-01"), Some(1_704_067_200.0));
2486    }
2487
2488    #[test]
2489    fn invalid_date_returns_none() {
2490        assert_eq!(try_parse_date("not-a-date"), None);
2491        assert_eq!(try_parse_date("2024-13-01"), None); // invalid month
2492        assert_eq!(try_parse_date("2024-00-01"), None); // month 0
2493        assert_eq!(try_parse_date("2024-01-32"), None); // day 32
2494    }
2495
2496    // ---- OR combined with other new features ----
2497
2498    #[test]
2499    fn where_or_with_like() {
2500        let query = SQLQuery::new(
2501            "SELECT * FROM products WHERE title LIKE 'laptop%' OR title LIKE 'phone%'",
2502        );
2503        assert_eq!(
2504            query.to_redis_query(),
2505            "(@title:(laptop*) | @title:(phone*))"
2506        );
2507    }
2508
2509    #[test]
2510    fn where_or_with_date() {
2511        let query = SQLQuery::new(
2512            "SELECT * FROM events WHERE created_at < '2024-01-01' OR created_at > '2024-12-31'",
2513        );
2514        let result = query.to_redis_query();
2515        // 2024-12-31 = 1735603200
2516        assert_eq!(
2517            result,
2518            "(@created_at:[-inf (1704067200] | @created_at:[(1735603200 +inf])"
2519        );
2520    }
2521
2522    // ---- Aggregate SQL tests ----
2523
2524    /// Helper: builds an aggregate command and returns its args as strings.
2525    fn agg_cmd_args(sql: &str, index_name: &str) -> Vec<String> {
2526        let q = SQLQuery::new(sql);
2527        assert!(q.is_aggregate(), "expected aggregate for: {sql}");
2528        let cmd = q.build_aggregate_cmd(index_name).unwrap();
2529        // Convert redis::Cmd to packed args for inspection.
2530        let packed = cmd.get_packed_command();
2531        parse_resp_args(&packed)
2532    }
2533
2534    /// Minimal RESP2 inline arg parser for test inspection.
2535    fn parse_resp_args(data: &[u8]) -> Vec<String> {
2536        let s = String::from_utf8_lossy(data);
2537        let mut args = Vec::new();
2538        let mut remaining = &s[..];
2539        while let Some(dollar) = remaining.find('$') {
2540            remaining = &remaining[dollar + 1..];
2541            let crlf = remaining.find("\r\n").unwrap();
2542            let len: usize = remaining[..crlf].parse().unwrap();
2543            remaining = &remaining[crlf + 2..];
2544            let val = &remaining[..len];
2545            args.push(val.to_string());
2546            remaining = &remaining[len + 2..]; // skip \r\n
2547        }
2548        args
2549    }
2550
2551    #[test]
2552    fn aggregate_count_star() {
2553        let args = agg_cmd_args("SELECT COUNT(*) AS total FROM products", "idx");
2554        assert_eq!(args[0], "FT.AGGREGATE");
2555        assert_eq!(args[1], "idx");
2556        assert_eq!(args[2], "*"); // no WHERE filter
2557        assert_eq!(args[3], "GROUPBY");
2558        assert_eq!(args[4], "0");
2559        assert_eq!(args[5], "REDUCE");
2560        assert_eq!(args[6], "COUNT");
2561        assert_eq!(args[7], "0"); // COUNT takes 0 args
2562        assert_eq!(args[8], "AS");
2563        assert_eq!(args[9], "total");
2564    }
2565
2566    #[test]
2567    fn aggregate_count_star_default_alias() {
2568        let args = agg_cmd_args("SELECT COUNT(*) FROM products", "idx");
2569        assert_eq!(args[9], "count"); // default alias
2570    }
2571
2572    #[test]
2573    fn aggregate_sum() {
2574        let args = agg_cmd_args("SELECT SUM(price) AS total_price FROM products", "idx");
2575        assert_eq!(args[5], "REDUCE");
2576        assert_eq!(args[6], "SUM");
2577        assert_eq!(args[7], "1"); // SUM takes 1 arg
2578        assert_eq!(args[8], "@price");
2579        assert_eq!(args[9], "AS");
2580        assert_eq!(args[10], "total_price");
2581    }
2582
2583    #[test]
2584    fn aggregate_avg() {
2585        let args = agg_cmd_args("SELECT AVG(score) AS avg_score FROM products", "idx");
2586        assert_eq!(args[6], "AVG");
2587        assert_eq!(args[8], "@score");
2588        assert_eq!(args[10], "avg_score");
2589    }
2590
2591    #[test]
2592    fn aggregate_min_max() {
2593        let args = agg_cmd_args("SELECT MIN(price) AS min_price FROM products", "idx");
2594        assert_eq!(args[6], "MIN");
2595        assert_eq!(args[8], "@price");
2596        assert_eq!(args[10], "min_price");
2597
2598        let args = agg_cmd_args("SELECT MAX(price) AS max_price FROM products", "idx");
2599        assert_eq!(args[6], "MAX");
2600        assert_eq!(args[8], "@price");
2601        assert_eq!(args[10], "max_price");
2602    }
2603
2604    #[test]
2605    fn aggregate_stddev() {
2606        let args = agg_cmd_args("SELECT STDDEV(price) AS price_sd FROM products", "idx");
2607        assert_eq!(args[6], "STDDEV");
2608        assert_eq!(args[8], "@price");
2609        assert_eq!(args[10], "price_sd");
2610    }
2611
2612    #[test]
2613    fn aggregate_count_distinct() {
2614        let args = agg_cmd_args(
2615            "SELECT COUNT_DISTINCT(brand) AS unique_brands FROM products",
2616            "idx",
2617        );
2618        assert_eq!(args[6], "COUNT_DISTINCT");
2619        assert_eq!(args[8], "@brand");
2620        assert_eq!(args[10], "unique_brands");
2621    }
2622
2623    #[test]
2624    fn aggregate_quantile() {
2625        let args = agg_cmd_args("SELECT QUANTILE(price, 0.95) AS p95 FROM products", "idx");
2626        assert_eq!(args[6], "QUANTILE");
2627        assert_eq!(args[7], "2"); // QUANTILE takes 2 args
2628        assert_eq!(args[8], "@price");
2629        assert_eq!(args[9], "0.95");
2630        assert_eq!(args[10], "AS");
2631        assert_eq!(args[11], "p95");
2632    }
2633
2634    #[test]
2635    fn aggregate_array_agg_to_tolist() {
2636        let args = agg_cmd_args("SELECT ARRAY_AGG(name) AS names FROM products", "idx");
2637        assert_eq!(args[6], "TOLIST");
2638        assert_eq!(args[8], "@name");
2639        assert_eq!(args[10], "names");
2640    }
2641
2642    #[test]
2643    fn aggregate_first_value() {
2644        let args = agg_cmd_args(
2645            "SELECT FIRST_VALUE(name) AS first_name FROM products",
2646            "idx",
2647        );
2648        assert_eq!(args[6], "FIRST_VALUE");
2649        assert_eq!(args[8], "@name");
2650        assert_eq!(args[10], "first_name");
2651    }
2652
2653    #[test]
2654    fn aggregate_group_by_single_field() {
2655        let args = agg_cmd_args(
2656            "SELECT category, COUNT(*) AS cnt FROM products GROUP BY category",
2657            "idx",
2658        );
2659        assert_eq!(args[0], "FT.AGGREGATE");
2660        assert_eq!(args[1], "idx");
2661        assert_eq!(args[2], "*");
2662        assert_eq!(args[3], "GROUPBY");
2663        assert_eq!(args[4], "1");
2664        assert_eq!(args[5], "@category");
2665        assert_eq!(args[6], "REDUCE");
2666        assert_eq!(args[7], "COUNT");
2667        assert_eq!(args[8], "0");
2668        assert_eq!(args[9], "AS");
2669        assert_eq!(args[10], "cnt");
2670    }
2671
2672    #[test]
2673    fn aggregate_group_by_with_where() {
2674        let args = agg_cmd_args(
2675            "SELECT category, AVG(price) AS avg_price FROM products WHERE price > 10 GROUP BY category",
2676            "idx",
2677        );
2678        assert_eq!(args[2], "@price:[(10 +inf]"); // WHERE filter
2679        assert_eq!(args[3], "GROUPBY");
2680        assert_eq!(args[4], "1");
2681        assert_eq!(args[5], "@category");
2682        assert_eq!(args[6], "REDUCE");
2683        assert_eq!(args[7], "AVG");
2684    }
2685
2686    #[test]
2687    fn aggregate_multiple_reducers() {
2688        let args = agg_cmd_args(
2689            "SELECT category, COUNT(*) AS cnt, AVG(price) AS avg_price FROM products GROUP BY category",
2690            "idx",
2691        );
2692        assert_eq!(args[3], "GROUPBY");
2693        assert_eq!(args[4], "1");
2694        assert_eq!(args[5], "@category");
2695        // First reducer: COUNT
2696        assert_eq!(args[6], "REDUCE");
2697        assert_eq!(args[7], "COUNT");
2698        assert_eq!(args[8], "0");
2699        assert_eq!(args[9], "AS");
2700        assert_eq!(args[10], "cnt");
2701        // Second reducer: AVG
2702        assert_eq!(args[11], "REDUCE");
2703        assert_eq!(args[12], "AVG");
2704        assert_eq!(args[13], "1");
2705        assert_eq!(args[14], "@price");
2706        assert_eq!(args[15], "AS");
2707        assert_eq!(args[16], "avg_price");
2708    }
2709
2710    #[test]
2711    fn aggregate_group_by_multiple_fields() {
2712        let args = agg_cmd_args(
2713            "SELECT category, brand, SUM(price) AS total FROM products GROUP BY category, brand",
2714            "idx",
2715        );
2716        assert_eq!(args[3], "GROUPBY");
2717        assert_eq!(args[4], "2");
2718        assert_eq!(args[5], "@category");
2719        assert_eq!(args[6], "@brand");
2720        assert_eq!(args[7], "REDUCE");
2721        assert_eq!(args[8], "SUM");
2722    }
2723
2724    #[test]
2725    fn non_aggregate_is_not_detected_as_aggregate() {
2726        let q = SQLQuery::new("SELECT * FROM products WHERE price > 10");
2727        assert!(!q.is_aggregate());
2728        assert!(q.build_aggregate_cmd("idx").is_none());
2729    }
2730
2731    #[test]
2732    fn aggregate_query_returns_raw_sql_for_search() {
2733        // An aggregate query should still work via to_redis_query but fall back
2734        // to the raw SQL since it's not a standard SELECT.
2735        let q = SQLQuery::new("SELECT COUNT(*) AS total FROM products");
2736        assert!(q.is_aggregate());
2737        // to_redis_query falls back to raw substituted SQL
2738        let redis_q = q.to_redis_query();
2739        assert!(redis_q.contains("COUNT"));
2740    }
2741
2742    // ---- Vector SQL tests ----
2743
2744    #[test]
2745    fn vector_distance_basic() {
2746        let blob = vec![0u8; 12]; // 3 x f32
2747        let q = SQLQuery::new(
2748            "SELECT title, vector_distance(embedding, :vec) AS score FROM idx LIMIT 3",
2749        )
2750        .with_param("vec", SqlParam::Bytes(blob.clone()));
2751        assert!(q.is_vector_query());
2752        let query_str = q.to_redis_query();
2753        assert_eq!(query_str, "*=>[KNN 3 @embedding $vector AS score]");
2754        let params = q.params();
2755        assert_eq!(params.len(), 1);
2756        assert_eq!(params[0].name, "vector");
2757        if let QueryParamValue::Binary(ref b) = params[0].value {
2758            assert_eq!(b, &blob);
2759        } else {
2760            panic!("Expected Binary param");
2761        }
2762    }
2763
2764    #[test]
2765    fn cosine_distance_basic() {
2766        let blob = vec![0u8; 12];
2767        let q = SQLQuery::new(
2768            "SELECT title, cosine_distance(embedding, :vec) AS dist FROM idx LIMIT 5",
2769        )
2770        .with_param("vec", SqlParam::Bytes(blob));
2771        assert!(q.is_vector_query());
2772        let query_str = q.to_redis_query();
2773        assert_eq!(query_str, "*=>[KNN 5 @embedding $vector AS dist]");
2774    }
2775
2776    #[test]
2777    fn vector_distance_with_where_filter() {
2778        let blob = vec![0u8; 12];
2779        let q = SQLQuery::new(
2780            "SELECT title, vector_distance(embedding, :vec) AS score FROM idx WHERE genre = 'sci-fi' LIMIT 3",
2781        )
2782        .with_param("vec", SqlParam::Bytes(blob));
2783        let query_str = q.to_redis_query();
2784        assert_eq!(
2785            query_str,
2786            "@genre:{sci\\-fi}=>[KNN 3 @embedding $vector AS score]"
2787        );
2788    }
2789
2790    #[test]
2791    fn vector_distance_default_alias() {
2792        let blob = vec![0u8; 12];
2793        let q = SQLQuery::new("SELECT vector_distance(embedding, :vec) FROM idx LIMIT 10")
2794            .with_param("vec", SqlParam::Bytes(blob));
2795        let query_str = q.to_redis_query();
2796        assert_eq!(
2797            query_str,
2798            "*=>[KNN 10 @embedding $vector AS vector_distance]"
2799        );
2800    }
2801
2802    #[test]
2803    fn vector_query_return_fields() {
2804        let blob = vec![0u8; 12];
2805        let q = SQLQuery::new(
2806            "SELECT title, author, vector_distance(embedding, :vec) AS score FROM idx LIMIT 5",
2807        )
2808        .with_param("vec", SqlParam::Bytes(blob));
2809        let fields = q.return_fields();
2810        assert_eq!(fields, vec!["title", "author"]);
2811    }
2812
2813    #[test]
2814    fn vector_query_limit_as_knn() {
2815        let blob = vec![0u8; 12];
2816        let q = SQLQuery::new("SELECT vector_distance(embedding, :vec) AS score FROM idx LIMIT 7")
2817            .with_param("vec", SqlParam::Bytes(blob));
2818        let limit = q.limit().expect("should have limit");
2819        assert_eq!(limit.num, 7);
2820        assert_eq!(limit.offset, 0);
2821    }
2822
2823    #[test]
2824    fn non_vector_query_not_detected_as_vector() {
2825        let q = SQLQuery::new("SELECT * FROM products WHERE price > 10");
2826        assert!(!q.is_vector_query());
2827    }
2828
2829    // ---- Geo WHERE tests (GEOFILTER) ----
2830
2831    #[test]
2832    fn geo_distance_where_basic() {
2833        let q = SQLQuery::new(
2834            "SELECT * FROM locations WHERE geo_distance(location, POINT(-122.4194, 37.7749), 'km') < 50",
2835        );
2836        let gf = q.geofilter().expect("should have geofilter");
2837        assert_eq!(gf.field, "location");
2838        assert!((gf.lon - (-122.4194)).abs() < 0.0001);
2839        assert!((gf.lat - 37.7749).abs() < 0.0001);
2840        assert!((gf.radius - 50.0).abs() < 0.001);
2841        assert_eq!(gf.unit, "km");
2842        // Query string should be wildcard (no additional filter).
2843        assert_eq!(q.to_redis_query(), "*");
2844    }
2845
2846    #[test]
2847    fn geo_distance_where_with_other_conditions() {
2848        let q = SQLQuery::new(
2849            "SELECT name FROM locations WHERE category = 'restaurant' AND geo_distance(location, POINT(-122.4194, 37.7749), 'mi') < 10",
2850        );
2851        let gf = q.geofilter().expect("should have geofilter");
2852        assert_eq!(gf.field, "location");
2853        assert!((gf.radius - 10.0).abs() < 0.001);
2854        assert_eq!(gf.unit, "mi");
2855        // Query string should have the non-geo filter.
2856        assert_eq!(q.to_redis_query(), "@category:{restaurant}");
2857    }
2858
2859    #[test]
2860    fn non_geo_query_no_geofilter() {
2861        let q = SQLQuery::new("SELECT * FROM products WHERE price > 10");
2862        assert!(q.geofilter().is_none());
2863    }
2864
2865    // ---- Geo aggregate (SELECT geo_distance) tests ----
2866
2867    #[test]
2868    fn geo_distance_select_aggregate() {
2869        let q = SQLQuery::new(
2870            "SELECT name, geo_distance(location, POINT(-122.4194, 37.7749)) AS distance FROM locations",
2871        );
2872        assert!(q.is_geo_aggregate());
2873        let cmd = q.build_geo_aggregate_cmd("idx").expect("should build cmd");
2874        let packed = cmd.get_packed_command();
2875        let args = parse_resp_args(&packed);
2876        assert_eq!(args[0], "FT.AGGREGATE");
2877        assert_eq!(args[1], "idx");
2878        assert_eq!(args[2], "*");
2879        assert_eq!(args[3], "LOAD");
2880        assert_eq!(args[4], "1");
2881        assert_eq!(args[5], "@location");
2882        assert_eq!(args[6], "APPLY");
2883        assert!(args[7].contains("geodistance"));
2884        assert!(args[7].contains("@location"));
2885        assert_eq!(args[8], "AS");
2886        assert_eq!(args[9], "distance");
2887    }
2888
2889    #[test]
2890    fn geo_distance_select_with_where() {
2891        let q = SQLQuery::new(
2892            "SELECT name, geo_distance(location, POINT(-73.9857, 40.7484)) AS dist FROM places WHERE category = 'cafe'",
2893        );
2894        assert!(q.is_geo_aggregate());
2895        let cmd = q.build_geo_aggregate_cmd("idx").expect("should build cmd");
2896        let packed = cmd.get_packed_command();
2897        let args = parse_resp_args(&packed);
2898        assert_eq!(args[0], "FT.AGGREGATE");
2899        assert_eq!(args[2], "@category:{cafe}");
2900    }
2901
2902    #[test]
2903    fn non_geo_not_detected_as_geo_aggregate() {
2904        let q = SQLQuery::new("SELECT * FROM products WHERE price > 10");
2905        assert!(!q.is_geo_aggregate());
2906        assert!(q.build_geo_aggregate_cmd("idx").is_none());
2907    }
2908
2909    // ---- Tokenizer test for :param ----
2910
2911    #[test]
2912    fn tokenizer_handles_colon_param() {
2913        let tokens = tokenize("SELECT vector_distance(embedding, :vec) AS score FROM idx");
2914        assert!(tokens.contains(&":vec".to_owned()));
2915    }
2916}