1use std::collections::HashMap;
44
45use super::{QueryLimit, QueryParam, QueryParamValue, QueryString, SortBy, SortDirection};
46
/// A value that can be bound to a named `:placeholder` in a SQL string.
#[derive(Debug, Clone)]
pub enum SqlParam {
    /// Signed 64-bit integer; inlined into the SQL text as a bare number.
    Int(i64),
    /// 64-bit float; inlined into the SQL text using its `to_string()` form.
    Float(f64),
    /// Text; inlined as a single-quoted literal with embedded `'` doubled.
    Str(String),
    /// Raw bytes; never inlined — the `:name` placeholder stays in the SQL text.
    Bytes(Vec<u8>),
}
61
/// A SQL-like query string together with the named parameters bound to it.
#[derive(Debug, Clone)]
pub struct SQLQuery {
    /// SQL text, which may contain `:name` placeholders.
    sql: String,
    /// Bound values keyed by placeholder name (without the leading `:`).
    params: HashMap<String, SqlParam>,
}
87
88impl SQLQuery {
89 pub fn new(sql: impl Into<String>) -> Self {
91 Self {
92 sql: sql.into(),
93 params: HashMap::new(),
94 }
95 }
96
97 pub fn with_params(sql: impl Into<String>, params: HashMap<String, SqlParam>) -> Self {
99 Self {
100 sql: sql.into(),
101 params,
102 }
103 }
104
105 pub fn with_param(mut self, name: impl Into<String>, value: SqlParam) -> Self {
107 self.params.insert(name.into(), value);
108 self
109 }
110
111 pub fn sql(&self) -> &str {
113 &self.sql
114 }
115
116 pub fn params_map(&self) -> &HashMap<String, SqlParam> {
118 &self.params
119 }
120
121 pub fn substituted_sql(&self) -> String {
127 substitute_params(&self.sql, &self.params)
128 }
129
130 fn parsed(&self) -> Option<ParsedSelect> {
136 parse_select(&self.substituted_sql())
137 }
138
139 pub fn is_aggregate(&self) -> bool {
145 parse_aggregate(&self.substituted_sql()).is_some()
146 }
147
148 pub fn build_aggregate_cmd(&self, index_name: &str) -> Option<redis::Cmd> {
154 let parsed = parse_aggregate(&self.substituted_sql())?;
155 Some(parsed.build_cmd(index_name))
156 }
157
158 pub fn is_vector_query(&self) -> bool {
161 parse_vector_select(&self.substituted_sql(), &self.params).is_some()
162 }
163
164 pub fn is_geo_aggregate(&self) -> bool {
167 parse_geo_aggregate(&self.substituted_sql()).is_some()
168 }
169
170 pub fn build_geo_aggregate_cmd(&self, index_name: &str) -> Option<redis::Cmd> {
174 let parsed = parse_geo_aggregate(&self.substituted_sql())?;
175 Some(parsed.build_cmd(index_name))
176 }
177
178 fn parsed_vector(&self) -> Option<ParsedVectorSelect> {
180 parse_vector_select(&self.substituted_sql(), &self.params)
181 }
182
183 fn parsed_geo_where(&self) -> Option<ParsedGeoWhere> {
185 parse_geo_where(&self.substituted_sql())
186 }
187}
188
189impl QueryString for SQLQuery {
190 fn to_redis_query(&self) -> String {
191 if let Some(ref vq) = self.parsed_vector() {
193 return vq.to_knn_query_string();
194 }
195 if let Some(ref gw) = self.parsed_geo_where() {
197 return gw.filter_string();
198 }
199 if let Some(parsed) = self.parsed() {
200 parsed.filter_string()
201 } else {
202 self.substituted_sql()
204 }
205 }
206
207 fn params(&self) -> Vec<QueryParam> {
208 if let Some(ref vq) = self.parsed_vector() {
210 return vq.params();
211 }
212 Vec::new()
213 }
214
215 fn return_fields(&self) -> Vec<String> {
216 if let Some(ref vq) = self.parsed_vector() {
217 return vq.return_fields.clone();
218 }
219 if let Some(ref gw) = self.parsed_geo_where() {
220 return gw.return_fields.clone();
221 }
222 self.parsed().map(|p| p.return_fields).unwrap_or_default()
223 }
224
225 fn sort_by(&self) -> Option<SortBy> {
226 self.parsed().and_then(|p| p.sort_by)
227 }
228
229 fn limit(&self) -> Option<QueryLimit> {
230 if let Some(ref vq) = self.parsed_vector() {
231 return Some(QueryLimit {
232 offset: 0,
233 num: vq.knn_num,
234 });
235 }
236 self.parsed().and_then(|p| p.limit)
237 }
238
239 fn should_unpack_json(&self) -> bool {
240 self.parsed()
242 .map(|p| p.return_fields.is_empty())
243 .unwrap_or(false)
244 }
245
246 fn geofilter(&self) -> Option<super::GeoFilter> {
247 self.parsed_geo_where().map(|gw| gw.geofilter)
248 }
249}
250
251fn substitute_params(sql: &str, params: &HashMap<String, SqlParam>) -> String {
261 if params.is_empty() {
262 return sql.to_owned();
263 }
264
265 let mut result = String::with_capacity(sql.len());
267 let bytes = sql.as_bytes();
268 let len = bytes.len();
269 let mut i = 0;
270
271 while i < len {
272 if bytes[i] == b':' && i + 1 < len && is_ident_start(bytes[i + 1]) {
273 let start = i + 1;
275 let mut end = start;
276 while end < len && is_ident_continue(bytes[end]) {
277 end += 1;
278 }
279 let key = &sql[start..end];
280 if let Some(param) = params.get(key) {
281 match param {
282 SqlParam::Int(v) => {
283 result.push_str(&v.to_string());
284 }
285 SqlParam::Float(v) => {
286 result.push_str(&v.to_string());
287 }
288 SqlParam::Str(v) => {
289 result.push('\'');
290 result.push_str(&v.replace('\'', "''"));
291 result.push('\'');
292 }
293 SqlParam::Bytes(_) => {
294 result.push(':');
296 result.push_str(key);
297 }
298 }
299 } else {
300 result.push(':');
302 result.push_str(key);
303 }
304 i = end;
305 } else {
306 result.push(sql[i..].chars().next().unwrap());
307 i += sql[i..].chars().next().unwrap().len_utf8();
308 }
309 }
310
311 result
312}
313
/// True for bytes that may begin a placeholder name: ASCII letters and `_`.
fn is_ident_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
317
/// True for bytes that may continue a placeholder name: ASCII letters, digits and `_`.
fn is_ident_continue(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
321
322#[derive(Debug, Clone)]
328struct ParsedSelect {
329 return_fields: Vec<String>,
331 where_filter: Option<String>,
333 sort_by: Option<SortBy>,
335 limit: Option<QueryLimit>,
337}
338
339impl ParsedSelect {
340 fn filter_string(&self) -> String {
342 self.where_filter.clone().unwrap_or_else(|| "*".to_owned())
343 }
344}
345
346fn parse_select(sql: &str) -> Option<ParsedSelect> {
350 let tokens = tokenize(sql);
351 if tokens.is_empty() {
352 return None;
353 }
354 let mut pos = 0;
355
356 if !tok_eq(&tokens, pos, "SELECT") {
358 return None;
359 }
360 pos += 1;
361
362 for tok in &tokens {
364 let upper = tok.to_ascii_uppercase();
365 if matches!(
366 upper.as_str(),
367 "COUNT"
368 | "AVG"
369 | "SUM"
370 | "MIN"
371 | "MAX"
372 | "STDDEV"
373 | "QUANTILE"
374 | "COUNT_DISTINCT"
375 | "ARRAY_AGG"
376 | "FIRST_VALUE"
377 ) {
378 return None;
379 }
380 }
381
382 for tok in &tokens {
384 let lower = tok.to_ascii_lowercase();
385 if lower == "cosine_distance" || lower == "vector_distance" || lower == "geo_distance" {
386 return None;
387 }
388 }
389
390 let mut return_fields = Vec::new();
392 if tok_eq(&tokens, pos, "*") {
393 pos += 1;
394 } else {
395 loop {
396 if pos >= tokens.len() {
397 return None;
398 }
399 let field = &tokens[pos];
400 if field.eq_ignore_ascii_case("FROM") {
401 break;
402 }
403 if !field.eq_ignore_ascii_case(",") && !field.eq_ignore_ascii_case("AS") {
405 if pos > 0 && tokens[pos - 1].eq_ignore_ascii_case("AS") {
407 } else {
409 return_fields.push(field.to_string());
410 }
411 }
412 pos += 1;
413 }
414 }
415
416 if !tok_eq(&tokens, pos, "FROM") {
418 return None;
419 }
420 pos += 1;
421 if pos >= tokens.len() {
423 return None;
424 }
425 pos += 1;
426
427 let mut where_filter: Option<String> = None;
429 let mut sort_by: Option<SortBy> = None;
430 let mut limit: Option<QueryLimit> = None;
431
432 while pos < tokens.len() {
433 if tok_eq(&tokens, pos, "WHERE") {
434 pos += 1;
435 let (filter_str, next) = parse_where_clause(&tokens, pos)?;
436 where_filter = Some(filter_str);
437 pos = next;
438 } else if tok_eq(&tokens, pos, "ORDER") {
439 if !tok_eq(&tokens, pos + 1, "BY") {
440 return None;
441 }
442 pos += 2;
443 if pos >= tokens.len() {
444 return None;
445 }
446 let field = tokens[pos].clone();
447 pos += 1;
448 let direction = if tok_eq(&tokens, pos, "DESC") {
449 pos += 1;
450 SortDirection::Desc
451 } else {
452 if tok_eq(&tokens, pos, "ASC") {
453 pos += 1;
454 }
455 SortDirection::Asc
456 };
457 sort_by = Some(SortBy { field, direction });
458 } else if tok_eq(&tokens, pos, "LIMIT") {
459 pos += 1;
460 let num = parse_usize(&tokens, pos)?;
461 pos += 1;
462 let offset = if tok_eq(&tokens, pos, "OFFSET") {
463 pos += 1;
464 let off = parse_usize(&tokens, pos)?;
465 pos += 1;
466 off
467 } else {
468 0
469 };
470 limit = Some(QueryLimit { offset, num });
471 } else {
472 pos += 1;
474 }
475 }
476
477 Some(ParsedSelect {
478 return_fields,
479 where_filter,
480 sort_by,
481 limit,
482 })
483}
484
/// One aggregate function call from a SELECT list, in the form needed to emit
/// a `REDUCE` clause.
#[derive(Debug, Clone)]
struct AggReducer {
    /// Reducer name emitted right after `REDUCE` (e.g. `COUNT`, `SUM`, `TOLIST`).
    function: String,
    /// First argument of the call, if any; emitted with an `@` prefix.
    field: Option<String>,
    /// Name emitted after `AS`.
    alias: String,
    /// Numeric second argument of the call; only emitted for `QUANTILE`.
    extra_arg: Option<f64>,
}
501
502#[derive(Debug, Clone)]
504struct ParsedAggregate {
505 where_filter: Option<String>,
507 group_by_fields: Vec<String>,
509 reducers: Vec<AggReducer>,
511}
512
513impl ParsedAggregate {
514 fn build_cmd(&self, index_name: &str) -> redis::Cmd {
516 let mut cmd = redis::cmd("FT.AGGREGATE");
517 cmd.arg(index_name);
518
519 let filter = self.where_filter.as_deref().unwrap_or("*");
521 cmd.arg(filter);
522
523 if self.group_by_fields.is_empty() {
524 cmd.arg("GROUPBY").arg(0_u32);
527 for reducer in &self.reducers {
528 self.append_reducer(&mut cmd, reducer);
529 }
530 } else {
531 cmd.arg("GROUPBY").arg(self.group_by_fields.len());
533 for field in &self.group_by_fields {
534 cmd.arg(format!("@{}", field));
535 }
536 for reducer in &self.reducers {
537 self.append_reducer(&mut cmd, reducer);
538 }
539 }
540
541 cmd
542 }
543
544 fn append_reducer(&self, cmd: &mut redis::Cmd, reducer: &AggReducer) {
546 cmd.arg("REDUCE");
547 cmd.arg(&reducer.function);
548
549 match reducer.function.as_str() {
550 "COUNT" => {
551 cmd.arg(0_u32); }
553 "QUANTILE" => {
554 cmd.arg(2_u32);
556 if let Some(ref field) = reducer.field {
557 cmd.arg(format!("@{}", field));
558 }
559 if let Some(q) = reducer.extra_arg {
560 cmd.arg(format_num(q));
561 }
562 }
563 _ => {
564 cmd.arg(1_u32);
566 if let Some(ref field) = reducer.field {
567 cmd.arg(format!("@{}", field));
568 }
569 }
570 }
571
572 cmd.arg("AS").arg(&reducer.alias);
573 }
574}
575
576fn parse_aggregate(sql: &str) -> Option<ParsedAggregate> {
581 let tokens = tokenize(sql);
582 if tokens.is_empty() {
583 return None;
584 }
585 let mut pos = 0;
586
587 if !tok_eq(&tokens, pos, "SELECT") {
589 return None;
590 }
591 pos += 1;
592
593 let has_aggregate_fn = tokens.iter().any(|t| {
595 let upper = t.to_ascii_uppercase();
596 matches!(
597 upper.as_str(),
598 "COUNT"
599 | "AVG"
600 | "SUM"
601 | "MIN"
602 | "MAX"
603 | "STDDEV"
604 | "QUANTILE"
605 | "COUNT_DISTINCT"
606 | "ARRAY_AGG"
607 | "FIRST_VALUE"
608 )
609 });
610
611 let has_group_by = tokens
612 .windows(2)
613 .any(|w| w[0].eq_ignore_ascii_case("GROUP") && w[1].eq_ignore_ascii_case("BY"));
614
615 if !has_aggregate_fn && !has_group_by {
616 return None;
617 }
618
619 let mut reducers = Vec::new();
621 while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
623 if let Some((reducer, next)) = try_parse_aggregate_fn(&tokens, pos) {
624 reducers.push(reducer);
625 pos = next;
626 } else if tokens[pos] == "," {
627 pos += 1;
628 } else {
629 pos += 1;
632 }
633 }
634
635 if !tok_eq(&tokens, pos, "FROM") {
637 return None;
638 }
639 pos += 1;
640 if pos >= tokens.len() {
642 return None;
643 }
644 pos += 1;
645
646 let mut where_filter: Option<String> = None;
648 let mut group_by_fields = Vec::new();
649
650 while pos < tokens.len() {
651 if tok_eq(&tokens, pos, "WHERE") {
652 pos += 1;
653 let (filter_str, next) = parse_where_clause(&tokens, pos)?;
654 where_filter = Some(filter_str);
655 pos = next;
656 } else if tok_eq(&tokens, pos, "GROUP") {
657 if !tok_eq(&tokens, pos + 1, "BY") {
658 return None;
659 }
660 pos += 2;
661 while pos < tokens.len() {
663 let upper = tokens[pos].to_ascii_uppercase();
664 if matches!(upper.as_str(), "HAVING" | "ORDER" | "LIMIT") {
665 break;
666 }
667 if tokens[pos] == "," {
668 pos += 1;
669 continue;
670 }
671 group_by_fields.push(tokens[pos].clone());
672 pos += 1;
673 }
674 } else {
675 pos += 1;
676 }
677 }
678
679 if reducers.is_empty() {
681 return None;
682 }
683
684 Some(ParsedAggregate {
685 where_filter,
686 group_by_fields,
687 reducers,
688 })
689}
690
691fn try_parse_aggregate_fn(tokens: &[String], pos: usize) -> Option<(AggReducer, usize)> {
697 if pos >= tokens.len() {
698 return None;
699 }
700
701 let func_upper = tokens[pos].to_ascii_uppercase();
702
703 let redis_func = match func_upper.as_str() {
705 "COUNT" => "COUNT",
706 "SUM" => "SUM",
707 "AVG" => "AVG",
708 "MIN" => "MIN",
709 "MAX" => "MAX",
710 "STDDEV" => "STDDEV",
711 "COUNT_DISTINCT" => "COUNT_DISTINCT",
712 "QUANTILE" => "QUANTILE",
713 "ARRAY_AGG" => "TOLIST",
714 "FIRST_VALUE" => "FIRST_VALUE",
715 _ => return None,
716 };
717
718 let mut p = pos + 1;
719
720 if !tok_eq(tokens, p, "(") {
722 return None;
723 }
724 p += 1;
725
726 let mut field: Option<String> = None;
728 let mut extra_arg: Option<f64> = None;
729
730 if func_upper == "COUNT" && tok_eq(tokens, p, "*") {
731 p += 1;
733 } else if p < tokens.len() && tokens[p] != ")" {
734 field = Some(tokens[p].clone());
736 p += 1;
737
738 if tok_eq(tokens, p, ",") {
740 p += 1;
741 if p < tokens.len() && tokens[p] != ")" {
742 extra_arg = tokens[p].parse::<f64>().ok();
743 p += 1;
744 }
745 }
746 }
747
748 if !tok_eq(tokens, p, ")") {
750 return None;
751 }
752 p += 1;
753
754 let alias = if tok_eq(tokens, p, "AS") {
756 p += 1;
757 if p >= tokens.len() {
758 return None;
759 }
760 let a = tokens[p].clone();
761 p += 1;
762 a
763 } else {
764 func_upper.to_lowercase()
766 };
767
768 Some((
769 AggReducer {
770 function: redis_func.to_owned(),
771 field,
772 alias,
773 extra_arg,
774 },
775 p,
776 ))
777}
778
/// A parsed `vector_distance(field, :param) [AS alias]` (or `cosine_distance`) call.
#[derive(Debug, Clone)]
struct VectorFuncCall {
    /// First argument of the call: the field being searched.
    field: String,
    /// Second argument of the call with any leading `:` removed.
    param_name: String,
    /// Name given after `AS`, or `vector_distance` when no alias was written.
    alias: String,
}
793
794#[derive(Debug, Clone)]
796struct ParsedVectorSelect {
797 vector_fn: VectorFuncCall,
799 return_fields: Vec<String>,
801 where_filter: Option<String>,
803 knn_num: usize,
805 vector_blob: Option<Vec<u8>>,
807}
808
809impl ParsedVectorSelect {
810 fn to_knn_query_string(&self) -> String {
812 let base = self.where_filter.as_deref().unwrap_or("*");
813 format!(
814 "{}=>[KNN {} @{} $vector AS {}]",
815 base, self.knn_num, self.vector_fn.field, self.vector_fn.alias
816 )
817 }
818
819 fn params(&self) -> Vec<QueryParam> {
821 if let Some(ref blob) = self.vector_blob {
822 vec![QueryParam {
823 name: "vector".to_owned(),
824 value: QueryParamValue::Binary(blob.clone()),
825 }]
826 } else {
827 Vec::new()
828 }
829 }
830}
831
832fn parse_vector_select(
838 sql: &str,
839 params: &HashMap<String, SqlParam>,
840) -> Option<ParsedVectorSelect> {
841 let tokens = tokenize(sql);
842 if tokens.is_empty() {
843 return None;
844 }
845 let mut pos = 0;
846
847 if !tok_eq(&tokens, pos, "SELECT") {
849 return None;
850 }
851 pos += 1;
852
853 let mut vector_fn: Option<VectorFuncCall> = None;
855 let mut return_fields: Vec<String> = Vec::new();
856
857 while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
858 if tokens[pos] == "," {
859 pos += 1;
860 continue;
861 }
862
863 let lower = tokens[pos].to_ascii_lowercase();
865 if (lower == "vector_distance" || lower == "cosine_distance")
866 && tok_eq(&tokens, pos + 1, "(")
867 {
868 let parsed = try_parse_vector_fn_call(&tokens, pos)?;
869 vector_fn = Some(parsed.0);
870 pos = parsed.1;
871 continue;
872 }
873
874 if tokens[pos].eq_ignore_ascii_case("AS") {
876 pos += 1; if pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
878 pos += 1; }
880 continue;
881 }
882
883 if !tokens[pos].eq_ignore_ascii_case("*") {
885 return_fields.push(tokens[pos].clone());
886 }
887 pos += 1;
888 }
889
890 let vector_fn = vector_fn?; if !tok_eq(&tokens, pos, "FROM") {
894 return None;
895 }
896 pos += 1;
897 if pos >= tokens.len() {
898 return None;
899 }
900 pos += 1; let mut where_filter: Option<String> = None;
904 let mut knn_num: usize = 10; while pos < tokens.len() {
907 if tok_eq(&tokens, pos, "WHERE") {
908 pos += 1;
909 let (filter_str, next) = parse_where_clause(&tokens, pos)?;
910 where_filter = Some(filter_str);
911 pos = next;
912 } else if tok_eq(&tokens, pos, "ORDER") {
913 while pos < tokens.len()
915 && !tok_eq(&tokens, pos, "LIMIT")
916 && !tok_eq(&tokens, pos, "WHERE")
917 {
918 pos += 1;
919 }
920 } else if tok_eq(&tokens, pos, "LIMIT") {
921 pos += 1;
922 knn_num = parse_usize(&tokens, pos)?;
923 pos += 1;
924 if tok_eq(&tokens, pos, "OFFSET") {
926 pos += 2;
927 }
928 } else {
929 pos += 1;
930 }
931 }
932
933 let vector_blob = params.get(&vector_fn.param_name).and_then(|p| {
935 if let SqlParam::Bytes(b) = p {
936 Some(b.clone())
937 } else {
938 None
939 }
940 });
941
942 Some(ParsedVectorSelect {
943 vector_fn,
944 return_fields,
945 where_filter,
946 knn_num,
947 vector_blob,
948 })
949}
950
951fn try_parse_vector_fn_call(tokens: &[String], pos: usize) -> Option<(VectorFuncCall, usize)> {
954 if pos + 5 >= tokens.len() {
955 return None;
956 }
957
958 let _func_name = &tokens[pos]; let mut p = pos + 1;
960
961 if !tok_eq(tokens, p, "(") {
963 return None;
964 }
965 p += 1;
966
967 let field = tokens[p].clone();
969 p += 1;
970
971 if !tok_eq(tokens, p, ",") {
973 return None;
974 }
975 p += 1;
976
977 let param_tok = &tokens[p];
979 let param_name = if param_tok.starts_with(':') {
980 param_tok[1..].to_string()
981 } else {
982 param_tok.clone()
983 };
984 p += 1;
985
986 if !tok_eq(tokens, p, ")") {
988 return None;
989 }
990 p += 1;
991
992 let alias = if tok_eq(tokens, p, "AS") {
994 p += 1;
995 if p >= tokens.len() {
996 return None;
997 }
998 let a = tokens[p].clone();
999 p += 1;
1000 a
1001 } else {
1002 "vector_distance".to_string()
1003 };
1004
1005 Some((
1006 VectorFuncCall {
1007 field,
1008 param_name,
1009 alias,
1010 },
1011 p,
1012 ))
1013}
1014
1015#[derive(Debug, Clone)]
1021struct ParsedGeoWhere {
1022 geofilter: super::GeoFilter,
1024 non_geo_filter: Option<String>,
1026 return_fields: Vec<String>,
1028}
1029
1030impl ParsedGeoWhere {
1031 fn filter_string(&self) -> String {
1033 self.non_geo_filter
1034 .clone()
1035 .unwrap_or_else(|| "*".to_owned())
1036 }
1037}
1038
1039#[derive(Debug, Clone)]
1041struct ParsedGeoAggregate {
1042 geo_field: String,
1044 lon: f64,
1046 lat: f64,
1048 alias: String,
1050 where_filter: Option<String>,
1052}
1053
1054impl ParsedGeoAggregate {
1055 fn build_cmd(&self, index_name: &str) -> redis::Cmd {
1057 let mut cmd = redis::cmd("FT.AGGREGATE");
1058 cmd.arg(index_name);
1059 cmd.arg(self.where_filter.as_deref().unwrap_or("*"));
1060
1061 cmd.arg("LOAD")
1063 .arg(1_u32)
1064 .arg(format!("@{}", self.geo_field));
1065
1066 let expr = format!(
1068 "geodistance(@{}, {}, {})",
1069 self.geo_field, self.lon, self.lat
1070 );
1071 cmd.arg("APPLY").arg(expr).arg("AS").arg(&self.alias);
1072
1073 cmd
1074 }
1075}
1076
1077fn parse_geo_where(sql: &str) -> Option<ParsedGeoWhere> {
1081 let tokens = tokenize(sql);
1082 if tokens.is_empty() {
1083 return None;
1084 }
1085 let mut pos = 0;
1086
1087 if !tok_eq(&tokens, pos, "SELECT") {
1089 return None;
1090 }
1091 pos += 1;
1092
1093 let mut return_fields: Vec<String> = Vec::new();
1095 if tok_eq(&tokens, pos, "*") {
1096 pos += 1;
1097 } else {
1098 while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
1099 if tokens[pos] == "," || tokens[pos].eq_ignore_ascii_case("AS") {
1100 pos += 1;
1101 if pos > 1
1103 && tokens[pos - 1].eq_ignore_ascii_case("AS")
1104 && pos < tokens.len()
1105 && !tok_eq(&tokens, pos, "FROM")
1106 {
1107 pos += 1;
1108 }
1109 continue;
1110 }
1111 return_fields.push(tokens[pos].clone());
1112 pos += 1;
1113 }
1114 }
1115
1116 if !tok_eq(&tokens, pos, "FROM") {
1118 return None;
1119 }
1120 pos += 1;
1121 if pos >= tokens.len() {
1122 return None;
1123 }
1124 pos += 1; if !tok_eq(&tokens, pos, "WHERE") {
1128 return None;
1129 }
1130 pos += 1;
1131
1132 let mut non_geo_conditions: Vec<String> = Vec::new();
1135 let mut geofilter: Option<super::GeoFilter> = None;
1136
1137 loop {
1138 if pos >= tokens.len() {
1139 break;
1140 }
1141 let upper = tokens[pos].to_ascii_uppercase();
1142 if matches!(upper.as_str(), "ORDER" | "LIMIT" | "GROUP" | "HAVING") {
1143 break;
1144 }
1145 if upper == "AND" {
1146 pos += 1;
1147 continue;
1148 }
1149
1150 if tokens[pos].eq_ignore_ascii_case("geo_distance") && tok_eq(&tokens, pos + 1, "(") {
1152 let (gf, next) = parse_geo_distance_where(&tokens, pos)?;
1153 geofilter = Some(gf);
1154 pos = next;
1155 continue;
1156 }
1157
1158 let (filter, next) = parse_single_condition(&tokens, pos)?;
1160 non_geo_conditions.push(filter);
1161 pos = next;
1162 }
1163
1164 let geofilter = geofilter?; let non_geo_filter = if non_geo_conditions.is_empty() {
1167 None
1168 } else if non_geo_conditions.len() == 1 {
1169 Some(non_geo_conditions.into_iter().next().unwrap())
1170 } else {
1171 Some(format!("({})", non_geo_conditions.join(" ")))
1172 };
1173
1174 Some(ParsedGeoWhere {
1175 geofilter,
1176 non_geo_filter,
1177 return_fields,
1178 })
1179}
1180
1181fn parse_geo_distance_where(tokens: &[String], pos: usize) -> Option<(super::GeoFilter, usize)> {
1185 let mut p = pos;
1186
1187 if !tokens[p].eq_ignore_ascii_case("geo_distance") {
1189 return None;
1190 }
1191 p += 1;
1192
1193 if !tok_eq(tokens, p, "(") {
1195 return None;
1196 }
1197 p += 1;
1198
1199 let field = tokens[p].clone();
1201 p += 1;
1202
1203 if !tok_eq(tokens, p, ",") {
1205 return None;
1206 }
1207 p += 1;
1208
1209 let (lon, lat);
1211 if tokens[p].eq_ignore_ascii_case("POINT") {
1212 p += 1;
1213 if !tok_eq(tokens, p, "(") {
1215 return None;
1216 }
1217 p += 1;
1218 lon = tokens[p].parse::<f64>().ok()?;
1219 p += 1;
1220 if !tok_eq(tokens, p, ",") {
1222 return None;
1223 }
1224 p += 1;
1225 lat = tokens[p].parse::<f64>().ok()?;
1226 p += 1;
1227 if !tok_eq(tokens, p, ")") {
1229 return None;
1230 }
1231 p += 1;
1232 } else {
1233 lon = tokens[p].parse::<f64>().ok()?;
1234 p += 1;
1235 if tok_eq(tokens, p, ",") {
1236 p += 1;
1237 }
1238 lat = tokens[p].parse::<f64>().ok()?;
1239 p += 1;
1240 }
1241
1242 if !tok_eq(tokens, p, ",") {
1244 return None;
1245 }
1246 p += 1;
1247 let unit = unquote(&tokens[p]);
1248 p += 1;
1249
1250 if !tok_eq(tokens, p, ")") {
1252 return None;
1253 }
1254 p += 1;
1255
1256 if !tok_eq(tokens, p, "<") {
1258 return None;
1259 }
1260 p += 1;
1261 let radius = tokens[p].parse::<f64>().ok()?;
1262 p += 1;
1263
1264 Some((
1265 super::GeoFilter {
1266 field,
1267 lon,
1268 lat,
1269 radius,
1270 unit,
1271 },
1272 p,
1273 ))
1274}
1275
1276fn parse_geo_aggregate(sql: &str) -> Option<ParsedGeoAggregate> {
1281 let tokens = tokenize(sql);
1282 if tokens.is_empty() {
1283 return None;
1284 }
1285 let mut pos = 0;
1286
1287 if !tok_eq(&tokens, pos, "SELECT") {
1288 return None;
1289 }
1290 pos += 1;
1291
1292 let mut geo_field: Option<String> = None;
1293 let mut geo_lon: Option<f64> = None;
1294 let mut geo_lat: Option<f64> = None;
1295 let mut geo_alias: Option<String> = None;
1296
1297 while pos < tokens.len() && !tok_eq(&tokens, pos, "FROM") {
1299 if tokens[pos] == "," {
1300 pos += 1;
1301 continue;
1302 }
1303
1304 if tokens[pos].eq_ignore_ascii_case("geo_distance") && tok_eq(&tokens, pos + 1, "(") {
1305 pos += 2; let field = tokens[pos].clone();
1308 pos += 1;
1309 if !tok_eq(&tokens, pos, ",") {
1310 return None;
1311 }
1312 pos += 1;
1313
1314 let (lon, lat);
1316 if tokens[pos].eq_ignore_ascii_case("POINT") {
1317 pos += 1;
1318 if !tok_eq(&tokens, pos, "(") {
1319 return None;
1320 }
1321 pos += 1;
1322 lon = tokens[pos].parse::<f64>().ok()?;
1323 pos += 1;
1324 if tok_eq(&tokens, pos, ",") {
1325 pos += 1;
1326 }
1327 lat = tokens[pos].parse::<f64>().ok()?;
1328 pos += 1;
1329 if !tok_eq(&tokens, pos, ")") {
1330 return None;
1331 }
1332 pos += 1;
1333 } else {
1334 return None;
1335 }
1336
1337 if !tok_eq(&tokens, pos, ")") {
1339 return None;
1340 }
1341 pos += 1;
1342
1343 let alias = if tok_eq(&tokens, pos, "AS") {
1345 pos += 1;
1346 let a = tokens[pos].clone();
1347 pos += 1;
1348 a
1349 } else {
1350 "distance".to_string()
1351 };
1352
1353 geo_field = Some(field);
1354 geo_lon = Some(lon);
1355 geo_lat = Some(lat);
1356 geo_alias = Some(alias);
1357 continue;
1358 }
1359
1360 if tokens[pos].eq_ignore_ascii_case("AS") {
1362 pos += 1;
1363 if pos < tokens.len() {
1364 pos += 1; }
1366 continue;
1367 }
1368
1369 pos += 1;
1371 }
1372
1373 let geo_field = geo_field?;
1374 let lon = geo_lon?;
1375 let lat = geo_lat?;
1376 let alias = geo_alias.unwrap_or_else(|| "distance".to_string());
1377
1378 if !tok_eq(&tokens, pos, "FROM") {
1380 return None;
1381 }
1382 pos += 1;
1383 if pos >= tokens.len() {
1384 return None;
1385 }
1386 pos += 1; let mut where_filter: Option<String> = None;
1390 while pos < tokens.len() {
1391 if tok_eq(&tokens, pos, "WHERE") {
1392 pos += 1;
1393 let (filter_str, next) = parse_where_clause(&tokens, pos)?;
1394 where_filter = Some(filter_str);
1395 pos = next;
1396 } else {
1397 pos += 1;
1398 }
1399 }
1400
1401 Some(ParsedGeoAggregate {
1402 geo_field,
1403 lon,
1404 lat,
1405 alias,
1406 where_filter,
1407 })
1408}
1409
1410fn parse_where_clause(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
1416 let mut or_groups: Vec<Vec<String>> = Vec::new();
1418 let mut current_and_group: Vec<String> = Vec::new();
1419
1420 loop {
1421 if pos >= tokens.len() {
1422 break;
1423 }
1424 let upper = tokens[pos].to_ascii_uppercase();
1426 if matches!(upper.as_str(), "ORDER" | "LIMIT" | "GROUP" | "HAVING") {
1427 break;
1428 }
1429 if upper == "AND" {
1431 pos += 1;
1432 continue;
1433 }
1434 if upper == "OR" {
1436 pos += 1;
1437 or_groups.push(std::mem::take(&mut current_and_group));
1438 continue;
1439 }
1440
1441 let (filter, next) = parse_single_condition(tokens, pos)?;
1442 current_and_group.push(filter);
1443 pos = next;
1444 }
1445
1446 if !current_and_group.is_empty() {
1448 or_groups.push(current_and_group);
1449 }
1450
1451 if or_groups.is_empty() {
1452 return Some(("*".to_owned(), pos));
1453 }
1454
1455 let group_strs: Vec<String> = or_groups
1457 .into_iter()
1458 .map(|g| {
1459 if g.len() == 1 {
1460 g.into_iter().next().unwrap()
1461 } else {
1462 format!("({})", g.join(" "))
1463 }
1464 })
1465 .collect();
1466
1467 let filter = if group_strs.len() == 1 {
1468 group_strs.into_iter().next().unwrap()
1469 } else {
1470 format!("({})", group_strs.join(" | "))
1472 };
1473
1474 Some((filter, pos))
1475}
1476
1477fn parse_single_condition(tokens: &[String], mut pos: usize) -> Option<(String, usize)> {
1482 let field = &tokens[pos];
1483 pos += 1;
1484 if pos >= tokens.len() {
1485 return None;
1486 }
1487
1488 let op = &tokens[pos];
1489 pos += 1;
1490
1491 if op.eq_ignore_ascii_case("BETWEEN") {
1493 let lo = parse_numeric_or_date_literal(tokens, pos)?;
1494 pos += 1;
1495 if !tok_eq(tokens, pos, "AND") {
1496 return None;
1497 }
1498 pos += 1;
1499 let hi = parse_numeric_or_date_literal(tokens, pos)?;
1500 pos += 1;
1501 return Some((
1502 format!("@{}:[{} {}]", field, format_num(lo), format_num(hi)),
1503 pos,
1504 ));
1505 }
1506
1507 if op.eq_ignore_ascii_case("NOT") && tok_eq(tokens, pos, "IN") {
1509 pos += 1; if !tok_eq(tokens, pos, "(") {
1511 return None;
1512 }
1513 pos += 1;
1514 let mut vals = Vec::new();
1515 loop {
1516 if pos >= tokens.len() {
1517 return None;
1518 }
1519 if tokens[pos] == ")" {
1520 pos += 1;
1521 break;
1522 }
1523 if tokens[pos] == "," {
1524 pos += 1;
1525 continue;
1526 }
1527 vals.push(unquote(&tokens[pos]));
1528 pos += 1;
1529 }
1530 let escaped: Vec<String> = vals.iter().map(|v| escape_tag(v)).collect();
1531 return Some((format!("(-@{}:{{{}}})", field, escaped.join("|")), pos));
1532 }
1533
1534 if op.eq_ignore_ascii_case("IN") {
1536 if !tok_eq(tokens, pos, "(") {
1537 return None;
1538 }
1539 pos += 1;
1540 let mut vals = Vec::new();
1541 loop {
1542 if pos >= tokens.len() {
1543 return None;
1544 }
1545 if tokens[pos] == ")" {
1546 pos += 1;
1547 break;
1548 }
1549 if tokens[pos] == "," {
1550 pos += 1;
1551 continue;
1552 }
1553 vals.push(unquote(&tokens[pos]));
1554 pos += 1;
1555 }
1556 let escaped: Vec<String> = vals.iter().map(|v| escape_tag(v)).collect();
1557 return Some((format!("@{}:{{{}}}", field, escaped.join("|")), pos));
1558 }
1559
1560 if op.eq_ignore_ascii_case("LIKE") {
1562 if pos >= tokens.len() {
1563 return None;
1564 }
1565 let pattern = unquote(&tokens[pos]);
1566 pos += 1;
1567 let redis_pattern = sql_like_to_redis(&pattern);
1568 return Some((format!("@{}:({})", field, redis_pattern), pos));
1569 }
1570
1571 if op.eq_ignore_ascii_case("NOT") && tok_eq(tokens, pos, "LIKE") {
1573 pos += 1; if pos >= tokens.len() {
1575 return None;
1576 }
1577 let pattern = unquote(&tokens[pos]);
1578 pos += 1;
1579 let redis_pattern = sql_like_to_redis(&pattern);
1580 return Some((format!("(-@{}:({}))", field, redis_pattern), pos));
1581 }
1582
1583 if op == "!=" {
1585 if pos >= tokens.len() {
1586 return None;
1587 }
1588 let value = unquote(&tokens[pos]);
1589 pos += 1;
1590 if is_numeric_str(&value) {
1591 let n: f64 = value.parse().ok()?;
1592 return Some((
1593 format!("(-@{}:[{} {}])", field, format_num(n), format_num(n)),
1594 pos,
1595 ));
1596 }
1597 if let Some(ts) = try_parse_date(&value) {
1598 return Some((
1599 format!("(-@{}:[{} {}])", field, format_num(ts), format_num(ts)),
1600 pos,
1601 ));
1602 }
1603 return Some((format!("(-@{}:{{{}}})", field, escape_tag(&value)), pos));
1605 }
1606
1607 if pos >= tokens.len() {
1609 return None;
1610 }
1611
1612 let (real_op, value_str) = if (op == "<" || op == ">") && tokens[pos] == "=" {
1614 let combined = format!("{}=", op);
1615 pos += 1;
1616 if pos >= tokens.len() {
1617 return None;
1618 }
1619 let v = unquote(&tokens[pos]);
1620 pos += 1;
1621 (combined, v)
1622 } else {
1623 let v = unquote(&tokens[pos]);
1624 pos += 1;
1625 (op.clone(), v)
1626 };
1627
1628 let filter = match real_op.as_str() {
1629 "=" => {
1630 if is_numeric_str(&value_str) {
1631 let n: f64 = value_str.parse().ok()?;
1632 format!("@{}:[{} {}]", field, format_num(n), format_num(n))
1633 } else if let Some(ts) = try_parse_date(&value_str) {
1634 format!("@{}:[{} {}]", field, format_num(ts), format_num(ts))
1635 } else {
1636 let val = value_str.clone();
1639 if val.contains('*') || val.contains('%') {
1640 format!("@{}:({})", field, val)
1642 } else if val.contains(' ') {
1643 format!("@{}:(\"{}\")", field, val)
1645 } else {
1646 format!("@{}:{{{}}}", field, escape_tag(&val))
1648 }
1649 }
1650 }
1651 "<" => {
1652 let n = parse_num_or_date(&value_str)?;
1653 format!("@{}:[-inf ({}]", field, format_num(n))
1654 }
1655 ">" => {
1656 let n = parse_num_or_date(&value_str)?;
1657 format!("@{}:[({} +inf]", field, format_num(n))
1658 }
1659 "<=" => {
1660 let n = parse_num_or_date(&value_str)?;
1661 format!("@{}:[-inf {}]", field, format_num(n))
1662 }
1663 ">=" => {
1664 let n = parse_num_or_date(&value_str)?;
1665 format!("@{}:[{} +inf]", field, format_num(n))
1666 }
1667 _ => return None,
1668 };
1669
1670 Some((filter, pos))
1671}
1672
/// Splits SQL text into tokens.
///
/// * ASCII whitespace separates tokens and is dropped.
/// * A single-quoted literal is one token that keeps its quotes and any `''`
///   escapes; an unterminated literal still gets a closing quote appended.
/// * `:name` placeholders and identifiers (ASCII letters, digits, `_`) are
///   single tokens.
/// * A number is a run of digits and dots, with a leading `-` when that `-`
///   is immediately followed by a digit.
/// * `!=`, `<=` and `>=` are single tokens; any other character is a token
///   of its own.
fn tokenize(sql: &str) -> Vec<String> {
    fn is_word_start(c: char) -> bool {
        c.is_ascii_alphabetic() || c == '_'
    }
    fn is_word(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }
    fn is_number_part(c: char) -> bool {
        c.is_ascii_digit() || c == '.'
    }
    /// Index just past the run of `pred` characters starting at `from`.
    fn run_end(chars: &[char], from: usize, pred: fn(char) -> bool) -> usize {
        chars[from..]
            .iter()
            .position(|&c| !pred(c))
            .map_or(chars.len(), |offset| from + offset)
    }

    let chars: Vec<char> = sql.chars().collect();
    let mut tokens = Vec::new();
    let mut i = 0;

    while let Some(&c) = chars.get(i) {
        let next = chars.get(i + 1).copied();

        let end = match c {
            _ if c.is_ascii_whitespace() => {
                i += 1;
                continue;
            }
            '\'' => {
                let mut literal = String::from("'");
                let mut j = i + 1;
                loop {
                    match (chars.get(j).copied(), chars.get(j + 1).copied()) {
                        (Some('\''), Some('\'')) => {
                            literal.push_str("''");
                            j += 2;
                        }
                        (Some('\''), _) => {
                            j += 1; // closing quote
                            break;
                        }
                        (Some(other), _) => {
                            literal.push(other);
                            j += 1;
                        }
                        (None, _) => break,
                    }
                }
                literal.push('\'');
                tokens.push(literal);
                i = j;
                continue;
            }
            ':' if next.map_or(false, is_word_start) => run_end(&chars, i + 1, is_word),
            _ if is_word_start(c) => run_end(&chars, i, is_word),
            _ if c.is_ascii_digit() || (c == '-' && next.map_or(false, |n| n.is_ascii_digit())) => {
                let digits_from = if c == '-' { i + 1 } else { i };
                run_end(&chars, digits_from, is_number_part)
            }
            '!' | '<' | '>' if next == Some('=') => i + 2,
            _ => i + 1,
        };

        tokens.push(chars[i..end].iter().collect());
        i = end;
    }

    tokens
}
1770
/// Reports whether the token at `pos` equals `expected`, ignoring ASCII case.
///
/// Returns `false` when `pos` is past the end of `tokens`.
fn tok_eq(tokens: &[String], pos: usize, expected: &str) -> bool {
    match tokens.get(pos) {
        Some(token) => token.eq_ignore_ascii_case(expected),
        None => false,
    }
}
1781
/// Parses the token at `pos` as a `usize`.
///
/// Returns `None` when `pos` is out of range or the token is not a valid
/// non-negative integer.
fn parse_usize(tokens: &[String], pos: usize) -> Option<usize> {
    let token = tokens.get(pos)?;
    token.parse::<usize>().ok()
}
1786
1787fn parse_numeric_or_date_literal(tokens: &[String], pos: usize) -> Option<f64> {
1792 let tok = tokens.get(pos)?;
1793 let s = unquote(tok);
1794 if let Ok(n) = s.parse::<f64>() {
1795 Some(n)
1796 } else {
1797 try_parse_date(&s)
1798 }
1799}
1800
1801fn parse_num_or_date(s: &str) -> Option<f64> {
1803 if let Ok(n) = s.parse::<f64>() {
1804 Some(n)
1805 } else {
1806 try_parse_date(s)
1807 }
1808}
1809
/// Parses an ISO-8601-style date or datetime into seconds since the Unix epoch.
///
/// Accepted shapes:
///
/// * `YYYY-MM-DD` (exactly 10 bytes) — midnight of that day;
/// * `YYYY-MM-DDTHH:MM:SS` or `YYYY-MM-DD HH:MM:SS` — anything following the
///   seconds field is ignored, and no time-zone offset is applied.
///
/// Returns `None` for anything else: wrong separators, non-numeric fields,
/// out-of-range time components, or a day that does not exist in the given
/// month (for example `2024-02-30`). The function never panics, including on
/// non-ASCII input whose byte offsets fall inside a multi-byte character.
fn try_parse_date(s: &str) -> Option<f64> {
    /// Parses the byte range `range` of `s`; `None` if the range is out of
    /// bounds, not on a character boundary, or not a valid `T`.
    fn field<T: std::str::FromStr>(s: &str, range: std::ops::Range<usize>) -> Option<T> {
        s.get(range)?.parse().ok()
    }

    let bytes = s.as_bytes();
    let byte_is = |idx: usize, expected: u8| bytes.get(idx) == Some(&expected);

    let date_only = s.len() == 10;
    let date_time = s.len() >= 19 && (byte_is(10, b'T') || byte_is(10, b' '));
    if !(date_only || date_time) {
        return None;
    }

    // Both shapes share the `YYYY-MM-DD` prefix and must use `-` separators.
    if !(byte_is(4, b'-') && byte_is(7, b'-')) {
        return None;
    }
    let year: i32 = field(s, 0..4)?;
    let month: u32 = field(s, 5..7)?;
    let day: u32 = field(s, 8..10)?;
    // Rejects month 0 / 13+ and days outside the actual month length.
    let midnight = date_to_unix_timestamp(year, month, day)?;

    if date_only {
        return Some(midnight as f64);
    }

    if !(byte_is(13, b':') && byte_is(16, b':')) {
        return None;
    }
    let hour: u32 = field(s, 11..13)?;
    let min: u32 = field(s, 14..16)?;
    let sec: u32 = field(s, 17..19)?;
    if hour > 23 || min > 59 || sec > 59 {
        return None;
    }

    let ts = midnight + i64::from(hour) * 3600 + i64::from(min) * 60 + i64::from(sec);
    Some(ts as f64)
}

/// Converts a proleptic Gregorian calendar date to seconds since
/// 1970-01-01T00:00:00 (midnight of the given day).
///
/// Returns `None` when `month` is not in `1..=12` or `day` is not a valid day
/// of that month in that year (leap years are honoured for February). Dates
/// before 1970 yield negative timestamps.
fn date_to_unix_timestamp(year: i32, month: u32, day: u32) -> Option<i64> {
    const DAYS_IN_MONTH: [u32; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    const SECONDS_PER_DAY: i64 = 86_400;

    fn is_leap(y: i32) -> bool {
        (y % 4 == 0 && y % 100 != 0) || y % 400 == 0
    }
    fn days_in_year(y: i32) -> i64 {
        if is_leap(y) {
            366
        } else {
            365
        }
    }

    if !(1..=12).contains(&month) {
        return None;
    }
    // Only called with a month already known to be in 1..=12.
    let month_len =
        |m: u32| -> u32 { DAYS_IN_MONTH[(m - 1) as usize] + u32::from(m == 2 && is_leap(year)) };
    if day < 1 || day > month_len(month) {
        return None;
    }

    // Whole years between the epoch and `year` (negative when before 1970).
    let year_days: i64 = if year >= 1970 {
        (1970..year).map(days_in_year).sum()
    } else {
        -(year..1970).map(days_in_year).sum::<i64>()
    };
    // Whole months already elapsed in `year`.
    let month_days: i64 = (1..month).map(|m| i64::from(month_len(m))).sum();

    let days = year_days + month_days + i64::from(day) - 1;
    Some(days * SECONDS_PER_DAY)
}
1885
/// Translates a SQL `LIKE` pattern into a RediSearch wildcard pattern by turning
/// every `%` into `*`. All other characters (including `_`) pass through unchanged.
fn sql_like_to_redis(pattern: &str) -> String {
    pattern
        .chars()
        .map(|c| if c == '%' { '*' } else { c })
        .collect()
}
1898
/// Strips one pair of surrounding single quotes from `s` and collapses every
/// doubled quote (`''`) inside into a single `'`.
///
/// Input that is not wrapped in a matching pair of single quotes is returned
/// unchanged.
fn unquote(s: &str) -> String {
    let inner = s
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''));
    match inner {
        Some(body) => body.replace("''", "'"),
        None => s.to_string(),
    }
}
1909
/// Backslash-escapes characters in `value` that would otherwise be interpreted
/// by the RediSearch query parser when the value is embedded in a `{...}` tag
/// clause.
///
/// Besides whitespace and the punctuation `$ : & / - . *`, this also escapes
/// `\`, `{`, `}` and `|`: left unescaped, those would respectively swallow the
/// next character, open/close the tag clause early, or split the value into
/// several alternatives.
fn escape_tag(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        let needs_escape = matches!(
            ch,
            ' ' | '$' | ':' | '&' | '/' | '-' | '.' | '*' | '\\' | '{' | '}' | '|'
        );
        if needs_escape {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
1923
/// Reports whether `s` parses as an `f64`.
///
/// This follows Rust's `f64` parsing rules, so exponent notation and the
/// special spellings such as `inf` are accepted as numeric too.
fn is_numeric_str(s: &str) -> bool {
    <f64 as std::str::FromStr>::from_str(s).is_ok()
}
1928
/// Renders `n` for use in a query string: values without a fractional part are
/// printed with zero decimal places, everything else uses the default `f64`
/// formatting.
fn format_num(n: f64) -> String {
    let has_fraction = n.fract() != 0.0;
    if has_fraction {
        return n.to_string();
    }
    format!("{n:.0}")
}
1937
1938#[cfg(test)]
1939mod tests {
1940 use super::*;
1941
1942 #[test]
1945 fn similar_param_names_no_partial_match() {
1946 let query = SQLQuery::with_params(
1947 "SELECT * FROM idx WHERE id = :id AND product_id = :product_id",
1948 HashMap::from([
1949 ("id".to_owned(), SqlParam::Int(123)),
1950 ("product_id".to_owned(), SqlParam::Int(456)),
1951 ]),
1952 );
1953 let substituted = query.substituted_sql();
1954 assert!(substituted.contains("id = 123"));
1955 assert!(substituted.contains("product_id = 456"));
1956 assert!(!substituted.contains("product_123"));
1957 }
1958
1959 #[test]
1960 fn prefix_param_names() {
1961 let query = SQLQuery::with_params(
1962 "SELECT * FROM idx WHERE user = :user AND user_id = :user_id AND user_name = :user_name",
1963 HashMap::from([
1964 ("user".to_owned(), SqlParam::Str("alice".to_owned())),
1965 ("user_id".to_owned(), SqlParam::Int(42)),
1966 (
1967 "user_name".to_owned(),
1968 SqlParam::Str("Alice Smith".to_owned()),
1969 ),
1970 ]),
1971 );
1972 let substituted = query.substituted_sql();
1973 assert!(substituted.contains("user = 'alice'"));
1974 assert!(substituted.contains("user_id = 42"));
1975 assert!(substituted.contains("user_name = 'Alice Smith'"));
1976 assert!(!substituted.contains("'alice'_id"));
1977 assert!(!substituted.contains("'alice'_name"));
1978 }
1979
1980 #[test]
1981 fn suffix_param_names() {
1982 let query = SQLQuery::with_params(
1983 "SELECT * FROM idx WHERE vec = :vec AND query_vec = :query_vec",
1984 HashMap::from([
1985 ("vec".to_owned(), SqlParam::Float(1.0)),
1986 ("query_vec".to_owned(), SqlParam::Float(2.0)),
1987 ]),
1988 );
1989 let substituted = query.substituted_sql();
1990 assert!(substituted.contains("vec = 1") || substituted.contains("vec = 1.0"));
1991 assert!(substituted.contains("query_vec = 2") || substituted.contains("query_vec = 2.0"));
1992 }
1993
1994 #[test]
1997 fn single_quote_in_value() {
1998 let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
1999 .with_param("name", SqlParam::Str("O'Brien".to_owned()));
2000 let substituted = query.substituted_sql();
2001 assert!(substituted.contains("name = 'O''Brien'"));
2002 }
2003
2004 #[test]
2005 fn multiple_quotes_in_value() {
2006 let query = SQLQuery::new("SELECT * FROM idx WHERE phrase = :phrase")
2007 .with_param("phrase", SqlParam::Str("It's a 'test' string".to_owned()));
2008 let substituted = query.substituted_sql();
2009 assert!(substituted.contains("phrase = 'It''s a ''test'' string'"));
2010 }
2011
2012 #[test]
2013 fn apostrophe_names() {
2014 let cases = [
2015 ("McDonald's", "'McDonald''s'"),
2016 ("O'Reilly", "'O''Reilly'"),
2017 ("D'Angelo", "'D''Angelo'"),
2018 ];
2019 for (name, expected) in cases {
2020 let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
2021 .with_param("name", SqlParam::Str(name.to_owned()));
2022 let substituted = query.substituted_sql();
2023 assert!(
2024 substituted.contains(&format!("name = {expected}")),
2025 "Failed for {name}: got {substituted}"
2026 );
2027 }
2028 }
2029
2030 #[test]
2033 fn multiple_occurrences_same_param() {
2034 let query = SQLQuery::new("SELECT * FROM idx WHERE category = :cat OR subcategory = :cat")
2035 .with_param("cat", SqlParam::Str("electronics".to_owned()));
2036 let substituted = query.substituted_sql();
2037 assert_eq!(substituted.matches("'electronics'").count(), 2);
2038 }
2039
2040 #[test]
2041 fn empty_string_value() {
2042 let query = SQLQuery::new("SELECT * FROM idx WHERE name = :name")
2043 .with_param("name", SqlParam::Str(String::new()));
2044 let substituted = query.substituted_sql();
2045 assert!(substituted.contains("name = ''"));
2046 }
2047
2048 #[test]
2049 fn numeric_types() {
2050 let query = SQLQuery::with_params(
2051 "SELECT * FROM idx WHERE count = :count AND price = :price",
2052 HashMap::from([
2053 ("count".to_owned(), SqlParam::Int(42)),
2054 ("price".to_owned(), SqlParam::Float(99.99)),
2055 ]),
2056 );
2057 let substituted = query.substituted_sql();
2058 assert!(substituted.contains("count = 42"));
2059 assert!(substituted.contains("price = 99.99"));
2060 }
2061
2062 #[test]
2063 fn bytes_param_not_substituted() {
2064 let query = SQLQuery::new("SELECT * FROM idx WHERE embedding = :vec")
2065 .with_param("vec", SqlParam::Bytes(vec![0x00, 0x01, 0x02, 0x03]));
2066 let substituted = query.substituted_sql();
2067 assert!(substituted.contains(":vec"));
2068 }
2069
2070 #[test]
2071 fn special_characters_in_value() {
2072 let specials = [
2073 "hello@world.com",
2074 "path/to/file",
2075 "price: $100",
2076 "regex.*pattern",
2077 "back\\slash",
2078 ];
2079 for value in specials {
2080 let query = SQLQuery::new("SELECT * FROM idx WHERE field = :field")
2081 .with_param("field", SqlParam::Str(value.to_owned()));
2082 let substituted = query.substituted_sql();
2083 assert!(
2084 !substituted.contains(":field"),
2085 "Failed to substitute for value: {value}"
2086 );
2087 }
2088 }
2089
2090 #[test]
2091 fn no_params_returns_original() {
2092 let query = SQLQuery::new("SELECT * FROM idx");
2093 assert_eq!(query.substituted_sql(), "SELECT * FROM idx");
2094 }
2095
2096 #[test]
2097 fn unknown_placeholder_kept() {
2098 let query = SQLQuery::new("SELECT * FROM idx WHERE x = :unknown")
2099 .with_param("other", SqlParam::Int(1));
2100 assert!(query.substituted_sql().contains(":unknown"));
2101 }
2102
2103 #[test]
2104 fn with_param_builder_pattern() {
2105 let query = SQLQuery::new("SELECT * FROM idx WHERE a = :a AND b = :b")
2106 .with_param("a", SqlParam::Int(1))
2107 .with_param("b", SqlParam::Str("hello".to_owned()));
2108 let sub = query.substituted_sql();
2109 assert!(sub.contains("a = 1"));
2110 assert!(sub.contains("b = 'hello'"));
2111 }
2112
2113 #[test]
2114 fn sql_accessor() {
2115 let query = SQLQuery::new("SELECT 1");
2116 assert_eq!(query.sql(), "SELECT 1");
2117 }
2118
2119 #[test]
2120 fn params_map_accessor() {
2121 let query = SQLQuery::new("SELECT 1").with_param("x", SqlParam::Int(42));
2122 assert_eq!(query.params_map().len(), 1);
2123 }
2124
2125 #[test]
2128 fn select_star_no_where_produces_wildcard() {
2129 let query = SQLQuery::new("SELECT * FROM products");
2130 assert_eq!(query.to_redis_query(), "*");
2131 }
2132
2133 #[test]
2134 fn select_specific_fields_sets_return_fields() {
2135 let query = SQLQuery::new("SELECT title, price FROM products");
2136 assert_eq!(query.to_redis_query(), "*");
2137 assert_eq!(query.return_fields(), vec!["title", "price"]);
2138 }
2139
2140 #[test]
2141 fn where_tag_equals() {
2142 let query = SQLQuery::new("SELECT * FROM products WHERE category = 'electronics'");
2143 assert_eq!(query.to_redis_query(), "@category:{electronics}");
2144 }
2145
2146 #[test]
2147 fn where_tag_not_equals() {
2148 let query = SQLQuery::new("SELECT * FROM products WHERE category != 'electronics'");
2149 assert_eq!(query.to_redis_query(), "(-@category:{electronics})");
2150 }
2151
2152 #[test]
2153 fn where_tag_in() {
2154 let query =
2155 SQLQuery::new("SELECT * FROM products WHERE category IN ('books', 'accessories')");
2156 assert_eq!(query.to_redis_query(), "@category:{books|accessories}");
2157 }
2158
2159 #[test]
2160 fn where_numeric_less_than() {
2161 let query = SQLQuery::new("SELECT * FROM products WHERE price < 50");
2162 assert_eq!(query.to_redis_query(), "@price:[-inf (50]");
2163 }
2164
2165 #[test]
2166 fn where_numeric_greater_than() {
2167 let query = SQLQuery::new("SELECT * FROM products WHERE price > 100");
2168 assert_eq!(query.to_redis_query(), "@price:[(100 +inf]");
2169 }
2170
2171 #[test]
2172 fn where_numeric_equals() {
2173 let query = SQLQuery::new("SELECT * FROM products WHERE price = 45");
2174 assert_eq!(query.to_redis_query(), "@price:[45 45]");
2175 }
2176
2177 #[test]
2178 fn where_numeric_not_equals() {
2179 let query = SQLQuery::new("SELECT * FROM products WHERE price != 45");
2180 assert_eq!(query.to_redis_query(), "(-@price:[45 45])");
2181 }
2182
2183 #[test]
2184 fn where_numeric_lte() {
2185 let query = SQLQuery::new("SELECT * FROM products WHERE price <= 50");
2186 assert_eq!(query.to_redis_query(), "@price:[-inf 50]");
2187 }
2188
2189 #[test]
2190 fn where_numeric_gte() {
2191 let query = SQLQuery::new("SELECT * FROM products WHERE price >= 25");
2192 assert_eq!(query.to_redis_query(), "@price:[25 +inf]");
2193 }
2194
2195 #[test]
2196 fn where_between() {
2197 let query = SQLQuery::new("SELECT * FROM products WHERE price BETWEEN 40 AND 60");
2198 assert_eq!(query.to_redis_query(), "@price:[40 60]");
2199 }
2200
2201 #[test]
2202 fn where_combined_and() {
2203 let query =
2204 SQLQuery::new("SELECT * FROM products WHERE category = 'electronics' AND price < 100");
2205 assert_eq!(
2206 query.to_redis_query(),
2207 "(@category:{electronics} @price:[-inf (100])"
2208 );
2209 }
2210
2211 #[test]
2212 fn order_by_asc() {
2213 let query = SQLQuery::new("SELECT title, price FROM products ORDER BY price ASC");
2214 let sb = query.sort_by().expect("sort_by should be set");
2215 assert_eq!(sb.field, "price");
2216 assert!(matches!(sb.direction, SortDirection::Asc));
2217 }
2218
2219 #[test]
2220 fn order_by_desc() {
2221 let query = SQLQuery::new("SELECT title, price FROM products ORDER BY price DESC");
2222 let sb = query.sort_by().expect("sort_by should be set");
2223 assert_eq!(sb.field, "price");
2224 assert!(matches!(sb.direction, SortDirection::Desc));
2225 }
2226
2227 #[test]
2228 fn limit_clause() {
2229 let query = SQLQuery::new("SELECT title FROM products LIMIT 3");
2230 let lim = query.limit().expect("limit should be set");
2231 assert_eq!(lim.num, 3);
2232 assert_eq!(lim.offset, 0);
2233 }
2234
2235 #[test]
2236 fn limit_with_offset() {
2237 let query = SQLQuery::new("SELECT title FROM products ORDER BY price ASC LIMIT 3 OFFSET 3");
2238 let lim = query.limit().expect("limit should be set");
2239 assert_eq!(lim.num, 3);
2240 assert_eq!(lim.offset, 3);
2241 }
2242
2243 #[test]
2244 fn where_with_order_and_limit() {
2245 let query = SQLQuery::new(
2246 "SELECT title, price FROM products WHERE category = 'electronics' ORDER BY price ASC LIMIT 5",
2247 );
2248 assert_eq!(query.to_redis_query(), "@category:{electronics}");
2249 assert_eq!(query.return_fields(), vec!["title", "price"]);
2250 let sb = query.sort_by().expect("sort_by");
2251 assert_eq!(sb.field, "price");
2252 let lim = query.limit().expect("limit");
2253 assert_eq!(lim.num, 5);
2254 }
2255
2256 #[test]
2257 fn aggregate_query_returns_raw_sql_fallback() {
2258 let query = SQLQuery::new("SELECT COUNT(*) as total FROM products");
2260 let result = query.to_redis_query();
2261 assert!(result.contains("COUNT"));
2263 }
2264
2265 #[test]
2266 fn text_equality_single_word() {
2267 let query = SQLQuery::new("SELECT * FROM products WHERE title = 'laptop'");
2268 assert_eq!(query.to_redis_query(), "@title:{laptop}");
2269 }
2270
2271 #[test]
2272 fn text_equality_phrase() {
2273 let query = SQLQuery::new("SELECT * FROM products WHERE title = 'gaming laptop'");
2274 assert_eq!(query.to_redis_query(), "@title:(\"gaming laptop\")");
2275 }
2276
2277 #[test]
2278 fn numeric_range_with_and() {
2279 let query = SQLQuery::new("SELECT * FROM products WHERE price >= 25 AND price <= 50");
2280 assert_eq!(
2281 query.to_redis_query(),
2282 "(@price:[25 +inf] @price:[-inf 50])"
2283 );
2284 }
2285
2286 #[test]
2287 fn should_unpack_json_for_select_star() {
2288 let query = SQLQuery::new("SELECT * FROM products");
2289 assert!(query.should_unpack_json());
2290 }
2291
2292 #[test]
2293 fn should_not_unpack_json_for_field_projection() {
2294 let query = SQLQuery::new("SELECT title, price FROM products");
2295 assert!(!query.should_unpack_json());
2296 }
2297
2298 #[test]
2299 fn with_param_where_tag() {
2300 let query = SQLQuery::new("SELECT * FROM products WHERE category = :cat")
2301 .with_param("cat", SqlParam::Str("electronics".to_owned()));
2302 assert_eq!(query.to_redis_query(), "@category:{electronics}");
2303 }
2304
2305 #[test]
2306 fn with_param_where_numeric() {
2307 let query = SQLQuery::new("SELECT * FROM products WHERE price > :min_price")
2308 .with_param("min_price", SqlParam::Float(99.99));
2309 assert_eq!(query.to_redis_query(), "@price:[(99.99 +inf]");
2310 }
2311
2312 #[test]
2315 fn where_simple_or() {
2316 let query = SQLQuery::new(
2317 "SELECT * FROM products WHERE category = 'electronics' OR category = 'books'",
2318 );
2319 assert_eq!(
2320 query.to_redis_query(),
2321 "(@category:{electronics} | @category:{books})"
2322 );
2323 }
2324
2325 #[test]
2326 fn where_or_with_three_branches() {
2327 let query = SQLQuery::new(
2328 "SELECT * FROM products WHERE category = 'electronics' OR category = 'books' OR category = 'accessories'",
2329 );
2330 assert_eq!(
2331 query.to_redis_query(),
2332 "(@category:{electronics} | @category:{books} | @category:{accessories})"
2333 );
2334 }
2335
2336 #[test]
2337 fn where_and_binds_tighter_than_or() {
2338 let query = SQLQuery::new(
2340 "SELECT * FROM products WHERE category = 'electronics' AND price > 100 OR category = 'books' AND price < 50",
2341 );
2342 assert_eq!(
2343 query.to_redis_query(),
2344 "((@category:{electronics} @price:[(100 +inf]) | (@category:{books} @price:[-inf (50]))"
2345 );
2346 }
2347
2348 #[test]
2349 fn where_or_with_single_conditions() {
2350 let query = SQLQuery::new("SELECT * FROM products WHERE price < 20 OR price > 1000");
2351 assert_eq!(
2352 query.to_redis_query(),
2353 "(@price:[-inf (20] | @price:[(1000 +inf])"
2354 );
2355 }
2356
2357 #[test]
2358 fn where_or_preserves_order_limit() {
2359 let query = SQLQuery::new(
2360 "SELECT title FROM products WHERE category = 'a' OR category = 'b' ORDER BY price ASC LIMIT 5",
2361 );
2362 assert_eq!(query.to_redis_query(), "(@category:{a} | @category:{b})");
2363 assert!(query.sort_by().is_some());
2364 assert_eq!(query.limit().unwrap().num, 5);
2365 }
2366
2367 #[test]
2370 fn where_not_in() {
2371 let query =
2372 SQLQuery::new("SELECT * FROM products WHERE category NOT IN ('electronics', 'books')");
2373 assert_eq!(query.to_redis_query(), "(-@category:{electronics|books})");
2374 }
2375
2376 #[test]
2377 fn where_not_in_combined_with_and() {
2378 let query = SQLQuery::new(
2379 "SELECT * FROM products WHERE category NOT IN ('electronics') AND price > 50",
2380 );
2381 assert_eq!(
2382 query.to_redis_query(),
2383 "((-@category:{electronics}) @price:[(50 +inf])"
2384 );
2385 }
2386
2387 #[test]
2390 fn where_like_prefix() {
2391 let query = SQLQuery::new("SELECT * FROM products WHERE title LIKE 'laptop%'");
2392 assert_eq!(query.to_redis_query(), "@title:(laptop*)");
2393 }
2394
2395 #[test]
2396 fn where_like_suffix() {
2397 let query = SQLQuery::new("SELECT * FROM products WHERE title LIKE '%laptop'");
2398 assert_eq!(query.to_redis_query(), "@title:(*laptop)");
2399 }
2400
2401 #[test]
2402 fn where_like_contains() {
2403 let query = SQLQuery::new("SELECT * FROM products WHERE title LIKE '%laptop%'");
2404 assert_eq!(query.to_redis_query(), "@title:(*laptop*)");
2405 }
2406
2407 #[test]
2408 fn where_not_like() {
2409 let query = SQLQuery::new("SELECT * FROM products WHERE title NOT LIKE 'laptop%'");
2410 assert_eq!(query.to_redis_query(), "(-@title:(laptop*))");
2411 }
2412
2413 #[test]
2414 fn where_like_combined_with_and() {
2415 let query =
2416 SQLQuery::new("SELECT * FROM products WHERE title LIKE 'lap%' AND price < 1000");
2417 assert_eq!(
2418 query.to_redis_query(),
2419 "(@title:(lap*) @price:[-inf (1000])"
2420 );
2421 }
2422
2423 #[test]
2426 fn where_date_greater_than() {
2427 let query = SQLQuery::new("SELECT * FROM events WHERE created_at > '2024-01-01'");
2428 let result = query.to_redis_query();
2429 assert_eq!(result, "@created_at:[(1704067200 +inf]");
2431 }
2432
2433 #[test]
2434 fn where_date_less_than() {
2435 let query = SQLQuery::new("SELECT * FROM events WHERE created_at < '2024-03-31'");
2436 let result = query.to_redis_query();
2437 assert_eq!(result, "@created_at:[-inf (1711843200]");
2439 }
2440
2441 #[test]
2442 fn where_date_between() {
2443 let query = SQLQuery::new(
2444 "SELECT * FROM events WHERE created_at BETWEEN '2024-01-01' AND '2024-03-31'",
2445 );
2446 let result = query.to_redis_query();
2447 assert_eq!(result, "@created_at:[1704067200 1711843200]");
2448 }
2449
2450 #[test]
2451 fn where_date_gte() {
2452 let query = SQLQuery::new("SELECT * FROM events WHERE created_at >= '2024-06-15'");
2453 let result = query.to_redis_query();
2454 assert_eq!(result, "@created_at:[1718409600 +inf]");
2456 }
2457
2458 #[test]
2459 fn where_date_combined_with_tag() {
2460 let query = SQLQuery::new(
2461 "SELECT * FROM events WHERE category = 'meeting' AND created_at > '2024-01-01'",
2462 );
2463 let result = query.to_redis_query();
2464 assert_eq!(
2465 result,
2466 "(@category:{meeting} @created_at:[(1704067200 +inf])"
2467 );
2468 }
2469
2470 #[test]
2471 fn where_datetime_with_time() {
2472 let query = SQLQuery::new("SELECT * FROM events WHERE created_at > '2024-01-15T10:30:00'");
2473 let result = query.to_redis_query();
2474 assert_eq!(result, "@created_at:[(1705314600 +inf]");
2476 }
2477
2478 #[test]
2479 fn date_to_timestamp_known_values() {
2480 assert_eq!(try_parse_date("1970-01-01"), Some(0.0));
2482 assert_eq!(try_parse_date("2000-01-01"), Some(946_684_800.0));
2484 assert_eq!(try_parse_date("2024-01-01"), Some(1_704_067_200.0));
2486 }
2487
2488 #[test]
2489 fn invalid_date_returns_none() {
2490 assert_eq!(try_parse_date("not-a-date"), None);
2491 assert_eq!(try_parse_date("2024-13-01"), None); assert_eq!(try_parse_date("2024-00-01"), None); assert_eq!(try_parse_date("2024-01-32"), None); }
2495
2496 #[test]
2499 fn where_or_with_like() {
2500 let query = SQLQuery::new(
2501 "SELECT * FROM products WHERE title LIKE 'laptop%' OR title LIKE 'phone%'",
2502 );
2503 assert_eq!(
2504 query.to_redis_query(),
2505 "(@title:(laptop*) | @title:(phone*))"
2506 );
2507 }
2508
2509 #[test]
2510 fn where_or_with_date() {
2511 let query = SQLQuery::new(
2512 "SELECT * FROM events WHERE created_at < '2024-01-01' OR created_at > '2024-12-31'",
2513 );
2514 let result = query.to_redis_query();
2515 assert_eq!(
2517 result,
2518 "(@created_at:[-inf (1704067200] | @created_at:[(1735603200 +inf])"
2519 );
2520 }
2521
2522 fn agg_cmd_args(sql: &str, index_name: &str) -> Vec<String> {
2526 let q = SQLQuery::new(sql);
2527 assert!(q.is_aggregate(), "expected aggregate for: {sql}");
2528 let cmd = q.build_aggregate_cmd(index_name).unwrap();
2529 let packed = cmd.get_packed_command();
2531 parse_resp_args(&packed)
2532 }
2533
2534 fn parse_resp_args(data: &[u8]) -> Vec<String> {
2536 let s = String::from_utf8_lossy(data);
2537 let mut args = Vec::new();
2538 let mut remaining = &s[..];
2539 while let Some(dollar) = remaining.find('$') {
2540 remaining = &remaining[dollar + 1..];
2541 let crlf = remaining.find("\r\n").unwrap();
2542 let len: usize = remaining[..crlf].parse().unwrap();
2543 remaining = &remaining[crlf + 2..];
2544 let val = &remaining[..len];
2545 args.push(val.to_string());
2546 remaining = &remaining[len + 2..]; }
2548 args
2549 }
2550
2551 #[test]
2552 fn aggregate_count_star() {
2553 let args = agg_cmd_args("SELECT COUNT(*) AS total FROM products", "idx");
2554 assert_eq!(args[0], "FT.AGGREGATE");
2555 assert_eq!(args[1], "idx");
2556 assert_eq!(args[2], "*"); assert_eq!(args[3], "GROUPBY");
2558 assert_eq!(args[4], "0");
2559 assert_eq!(args[5], "REDUCE");
2560 assert_eq!(args[6], "COUNT");
2561 assert_eq!(args[7], "0"); assert_eq!(args[8], "AS");
2563 assert_eq!(args[9], "total");
2564 }
2565
2566 #[test]
2567 fn aggregate_count_star_default_alias() {
2568 let args = agg_cmd_args("SELECT COUNT(*) FROM products", "idx");
2569 assert_eq!(args[9], "count"); }
2571
2572 #[test]
2573 fn aggregate_sum() {
2574 let args = agg_cmd_args("SELECT SUM(price) AS total_price FROM products", "idx");
2575 assert_eq!(args[5], "REDUCE");
2576 assert_eq!(args[6], "SUM");
2577 assert_eq!(args[7], "1"); assert_eq!(args[8], "@price");
2579 assert_eq!(args[9], "AS");
2580 assert_eq!(args[10], "total_price");
2581 }
2582
2583 #[test]
2584 fn aggregate_avg() {
2585 let args = agg_cmd_args("SELECT AVG(score) AS avg_score FROM products", "idx");
2586 assert_eq!(args[6], "AVG");
2587 assert_eq!(args[8], "@score");
2588 assert_eq!(args[10], "avg_score");
2589 }
2590
2591 #[test]
2592 fn aggregate_min_max() {
2593 let args = agg_cmd_args("SELECT MIN(price) AS min_price FROM products", "idx");
2594 assert_eq!(args[6], "MIN");
2595 assert_eq!(args[8], "@price");
2596 assert_eq!(args[10], "min_price");
2597
2598 let args = agg_cmd_args("SELECT MAX(price) AS max_price FROM products", "idx");
2599 assert_eq!(args[6], "MAX");
2600 assert_eq!(args[8], "@price");
2601 assert_eq!(args[10], "max_price");
2602 }
2603
2604 #[test]
2605 fn aggregate_stddev() {
2606 let args = agg_cmd_args("SELECT STDDEV(price) AS price_sd FROM products", "idx");
2607 assert_eq!(args[6], "STDDEV");
2608 assert_eq!(args[8], "@price");
2609 assert_eq!(args[10], "price_sd");
2610 }
2611
2612 #[test]
2613 fn aggregate_count_distinct() {
2614 let args = agg_cmd_args(
2615 "SELECT COUNT_DISTINCT(brand) AS unique_brands FROM products",
2616 "idx",
2617 );
2618 assert_eq!(args[6], "COUNT_DISTINCT");
2619 assert_eq!(args[8], "@brand");
2620 assert_eq!(args[10], "unique_brands");
2621 }
2622
2623 #[test]
2624 fn aggregate_quantile() {
2625 let args = agg_cmd_args("SELECT QUANTILE(price, 0.95) AS p95 FROM products", "idx");
2626 assert_eq!(args[6], "QUANTILE");
2627 assert_eq!(args[7], "2"); assert_eq!(args[8], "@price");
2629 assert_eq!(args[9], "0.95");
2630 assert_eq!(args[10], "AS");
2631 assert_eq!(args[11], "p95");
2632 }
2633
2634 #[test]
2635 fn aggregate_array_agg_to_tolist() {
2636 let args = agg_cmd_args("SELECT ARRAY_AGG(name) AS names FROM products", "idx");
2637 assert_eq!(args[6], "TOLIST");
2638 assert_eq!(args[8], "@name");
2639 assert_eq!(args[10], "names");
2640 }
2641
2642 #[test]
2643 fn aggregate_first_value() {
2644 let args = agg_cmd_args(
2645 "SELECT FIRST_VALUE(name) AS first_name FROM products",
2646 "idx",
2647 );
2648 assert_eq!(args[6], "FIRST_VALUE");
2649 assert_eq!(args[8], "@name");
2650 assert_eq!(args[10], "first_name");
2651 }
2652
2653 #[test]
2654 fn aggregate_group_by_single_field() {
2655 let args = agg_cmd_args(
2656 "SELECT category, COUNT(*) AS cnt FROM products GROUP BY category",
2657 "idx",
2658 );
2659 assert_eq!(args[0], "FT.AGGREGATE");
2660 assert_eq!(args[1], "idx");
2661 assert_eq!(args[2], "*");
2662 assert_eq!(args[3], "GROUPBY");
2663 assert_eq!(args[4], "1");
2664 assert_eq!(args[5], "@category");
2665 assert_eq!(args[6], "REDUCE");
2666 assert_eq!(args[7], "COUNT");
2667 assert_eq!(args[8], "0");
2668 assert_eq!(args[9], "AS");
2669 assert_eq!(args[10], "cnt");
2670 }
2671
2672 #[test]
2673 fn aggregate_group_by_with_where() {
2674 let args = agg_cmd_args(
2675 "SELECT category, AVG(price) AS avg_price FROM products WHERE price > 10 GROUP BY category",
2676 "idx",
2677 );
2678 assert_eq!(args[2], "@price:[(10 +inf]"); assert_eq!(args[3], "GROUPBY");
2680 assert_eq!(args[4], "1");
2681 assert_eq!(args[5], "@category");
2682 assert_eq!(args[6], "REDUCE");
2683 assert_eq!(args[7], "AVG");
2684 }
2685
2686 #[test]
2687 fn aggregate_multiple_reducers() {
2688 let args = agg_cmd_args(
2689 "SELECT category, COUNT(*) AS cnt, AVG(price) AS avg_price FROM products GROUP BY category",
2690 "idx",
2691 );
2692 assert_eq!(args[3], "GROUPBY");
2693 assert_eq!(args[4], "1");
2694 assert_eq!(args[5], "@category");
2695 assert_eq!(args[6], "REDUCE");
2697 assert_eq!(args[7], "COUNT");
2698 assert_eq!(args[8], "0");
2699 assert_eq!(args[9], "AS");
2700 assert_eq!(args[10], "cnt");
2701 assert_eq!(args[11], "REDUCE");
2703 assert_eq!(args[12], "AVG");
2704 assert_eq!(args[13], "1");
2705 assert_eq!(args[14], "@price");
2706 assert_eq!(args[15], "AS");
2707 assert_eq!(args[16], "avg_price");
2708 }
2709
2710 #[test]
2711 fn aggregate_group_by_multiple_fields() {
2712 let args = agg_cmd_args(
2713 "SELECT category, brand, SUM(price) AS total FROM products GROUP BY category, brand",
2714 "idx",
2715 );
2716 assert_eq!(args[3], "GROUPBY");
2717 assert_eq!(args[4], "2");
2718 assert_eq!(args[5], "@category");
2719 assert_eq!(args[6], "@brand");
2720 assert_eq!(args[7], "REDUCE");
2721 assert_eq!(args[8], "SUM");
2722 }
2723
2724 #[test]
2725 fn non_aggregate_is_not_detected_as_aggregate() {
2726 let q = SQLQuery::new("SELECT * FROM products WHERE price > 10");
2727 assert!(!q.is_aggregate());
2728 assert!(q.build_aggregate_cmd("idx").is_none());
2729 }
2730
2731 #[test]
2732 fn aggregate_query_returns_raw_sql_for_search() {
2733 let q = SQLQuery::new("SELECT COUNT(*) AS total FROM products");
2736 assert!(q.is_aggregate());
2737 let redis_q = q.to_redis_query();
2739 assert!(redis_q.contains("COUNT"));
2740 }
2741
2742 #[test]
2745 fn vector_distance_basic() {
2746 let blob = vec![0u8; 12]; let q = SQLQuery::new(
2748 "SELECT title, vector_distance(embedding, :vec) AS score FROM idx LIMIT 3",
2749 )
2750 .with_param("vec", SqlParam::Bytes(blob.clone()));
2751 assert!(q.is_vector_query());
2752 let query_str = q.to_redis_query();
2753 assert_eq!(query_str, "*=>[KNN 3 @embedding $vector AS score]");
2754 let params = q.params();
2755 assert_eq!(params.len(), 1);
2756 assert_eq!(params[0].name, "vector");
2757 if let QueryParamValue::Binary(ref b) = params[0].value {
2758 assert_eq!(b, &blob);
2759 } else {
2760 panic!("Expected Binary param");
2761 }
2762 }
2763
2764 #[test]
2765 fn cosine_distance_basic() {
2766 let blob = vec![0u8; 12];
2767 let q = SQLQuery::new(
2768 "SELECT title, cosine_distance(embedding, :vec) AS dist FROM idx LIMIT 5",
2769 )
2770 .with_param("vec", SqlParam::Bytes(blob));
2771 assert!(q.is_vector_query());
2772 let query_str = q.to_redis_query();
2773 assert_eq!(query_str, "*=>[KNN 5 @embedding $vector AS dist]");
2774 }
2775
2776 #[test]
2777 fn vector_distance_with_where_filter() {
2778 let blob = vec![0u8; 12];
2779 let q = SQLQuery::new(
2780 "SELECT title, vector_distance(embedding, :vec) AS score FROM idx WHERE genre = 'sci-fi' LIMIT 3",
2781 )
2782 .with_param("vec", SqlParam::Bytes(blob));
2783 let query_str = q.to_redis_query();
2784 assert_eq!(
2785 query_str,
2786 "@genre:{sci\\-fi}=>[KNN 3 @embedding $vector AS score]"
2787 );
2788 }
2789
2790 #[test]
2791 fn vector_distance_default_alias() {
2792 let blob = vec![0u8; 12];
2793 let q = SQLQuery::new("SELECT vector_distance(embedding, :vec) FROM idx LIMIT 10")
2794 .with_param("vec", SqlParam::Bytes(blob));
2795 let query_str = q.to_redis_query();
2796 assert_eq!(
2797 query_str,
2798 "*=>[KNN 10 @embedding $vector AS vector_distance]"
2799 );
2800 }
2801
2802 #[test]
2803 fn vector_query_return_fields() {
2804 let blob = vec![0u8; 12];
2805 let q = SQLQuery::new(
2806 "SELECT title, author, vector_distance(embedding, :vec) AS score FROM idx LIMIT 5",
2807 )
2808 .with_param("vec", SqlParam::Bytes(blob));
2809 let fields = q.return_fields();
2810 assert_eq!(fields, vec!["title", "author"]);
2811 }
2812
2813 #[test]
2814 fn vector_query_limit_as_knn() {
2815 let blob = vec![0u8; 12];
2816 let q = SQLQuery::new("SELECT vector_distance(embedding, :vec) AS score FROM idx LIMIT 7")
2817 .with_param("vec", SqlParam::Bytes(blob));
2818 let limit = q.limit().expect("should have limit");
2819 assert_eq!(limit.num, 7);
2820 assert_eq!(limit.offset, 0);
2821 }
2822
2823 #[test]
2824 fn non_vector_query_not_detected_as_vector() {
2825 let q = SQLQuery::new("SELECT * FROM products WHERE price > 10");
2826 assert!(!q.is_vector_query());
2827 }
2828
2829 #[test]
2832 fn geo_distance_where_basic() {
2833 let q = SQLQuery::new(
2834 "SELECT * FROM locations WHERE geo_distance(location, POINT(-122.4194, 37.7749), 'km') < 50",
2835 );
2836 let gf = q.geofilter().expect("should have geofilter");
2837 assert_eq!(gf.field, "location");
2838 assert!((gf.lon - (-122.4194)).abs() < 0.0001);
2839 assert!((gf.lat - 37.7749).abs() < 0.0001);
2840 assert!((gf.radius - 50.0).abs() < 0.001);
2841 assert_eq!(gf.unit, "km");
2842 assert_eq!(q.to_redis_query(), "*");
2844 }
2845
/// When `geo_distance(...)` is AND-ed with another condition, the geofilter is
/// still extracted and the `category` condition alone forms the query text.
#[test]
fn geo_distance_where_with_other_conditions() {
    let sql = "SELECT name FROM locations WHERE category = 'restaurant' AND geo_distance(location, POINT(-122.4194, 37.7749), 'mi') < 10";
    let query = SQLQuery::new(sql);

    let filter = query.geofilter().expect("should have geofilter");
    assert_eq!(filter.field, "location");
    let radius_err = (filter.radius - 10.0).abs();
    assert!(radius_err < 0.001);
    assert_eq!(filter.unit, "mi");

    let rendered = query.to_redis_query();
    assert_eq!(rendered, "@category:{restaurant}");
}
2858
/// A query without any `geo_distance(...)` predicate produces no geofilter.
#[test]
fn non_geo_query_no_geofilter() {
    let plain = SQLQuery::new("SELECT * FROM products WHERE price > 10");
    let filter = plain.geofilter();
    assert!(filter.is_none());
}
2864
/// `geo_distance(...)` in the SELECT list makes the query a geo aggregate: the
/// built command is `FT.AGGREGATE` on the given index, loading `@location`
/// and applying a `geodistance` expression under the SQL alias.
#[test]
fn geo_distance_select_aggregate() {
    let q = SQLQuery::new(
        "SELECT name, geo_distance(location, POINT(-122.4194, 37.7749)) AS distance FROM locations",
    );
    assert!(q.is_geo_aggregate());
    let cmd = q.build_geo_aggregate_cmd("idx").expect("should build cmd");
    let packed = cmd.get_packed_command();
    let args = parse_resp_args(&packed);

    // Check the length up front so a truncated command reports its actual
    // contents instead of an opaque index-out-of-bounds panic below.
    assert!(
        args.len() >= 10,
        "expected at least 10 command arguments, got {:?}",
        args
    );

    // Command name, index name and the match-all query.
    assert_eq!(args[0], "FT.AGGREGATE");
    assert_eq!(args[1], "idx");
    assert_eq!(args[2], "*");
    // The geo field is loaded so the APPLY expression can reference it.
    assert_eq!(args[3], "LOAD");
    assert_eq!(args[4], "1");
    assert_eq!(args[5], "@location");
    // The distance expression is applied under the alias from the SQL text.
    assert_eq!(args[6], "APPLY");
    assert!(args[7].contains("geodistance"));
    assert!(args[7].contains("@location"));
    assert_eq!(args[8], "AS");
    assert_eq!(args[9], "distance");
}
2888
/// A geo aggregate with a `WHERE` clause carries the translated filter as the
/// query argument of `FT.AGGREGATE` in place of `*`.
#[test]
fn geo_distance_select_with_where() {
    let q = SQLQuery::new(
        "SELECT name, geo_distance(location, POINT(-73.9857, 40.7484)) AS dist FROM places WHERE category = 'cafe'",
    );
    assert!(q.is_geo_aggregate());
    let cmd = q.build_geo_aggregate_cmd("idx").expect("should build cmd");
    let packed = cmd.get_packed_command();
    let args = parse_resp_args(&packed);

    // Check the length up front so a truncated command reports its actual
    // contents instead of an opaque index-out-of-bounds panic below.
    assert!(
        args.len() >= 3,
        "expected at least 3 command arguments, got {:?}",
        args
    );

    assert_eq!(args[0], "FT.AGGREGATE");
    // Index 2 is the query string, which holds the translated WHERE clause.
    assert_eq!(args[2], "@category:{cafe}");
}
2901
/// A query with no `geo_distance(...)` projection is not a geo aggregate and
/// yields no geo aggregate command.
#[test]
fn non_geo_not_detected_as_geo_aggregate() {
    let plain = SQLQuery::new("SELECT * FROM products WHERE price > 10");

    let detected = plain.is_geo_aggregate();
    assert!(!detected);

    let cmd = plain.build_geo_aggregate_cmd("idx");
    assert!(cmd.is_none());
}
2908
/// The tokenizer emits a `:name` placeholder as one token that includes the
/// leading colon.
#[test]
fn tokenizer_handles_colon_param() {
    let sql = "SELECT vector_distance(embedding, :vec) AS score FROM idx";
    let has_placeholder = tokenize(sql).iter().any(|tok| tok == ":vec");
    assert!(has_placeholder);
}
2916}