use std::borrow::Cow;
use bytes::{BufMut, Bytes};
use crate::filter::FilterExpression;
/// Value bound to a named query parameter.
#[derive(Debug, Clone)]
pub enum QueryParamValue {
    /// Textual value, sent as-is.
    String(String),
    /// Raw bytes, e.g. a packed vector blob.
    Binary(Vec<u8>),
}

/// A named parameter that a query string references as `$name`.
#[derive(Debug, Clone)]
pub struct QueryParam {
    /// Parameter name, without the leading `$`.
    pub name: String,
    /// Value substituted for the placeholder.
    pub value: QueryParamValue,
}
/// Ordering applied to a sort key.
#[derive(Debug, Clone, Copy)]
pub enum SortDirection {
    /// Smallest value first.
    Asc,
    /// Largest value first.
    Desc,
}

/// Sort specification: which field, and in which direction.
#[derive(Debug, Clone)]
pub struct SortBy {
    /// Field (or alias) to sort on.
    pub field: String,
    /// Direction of the sort.
    pub direction: SortDirection,
}

/// Result window: skip `offset` results, then return at most `num`.
#[derive(Debug, Clone, Copy)]
pub struct QueryLimit {
    /// Number of leading results to skip.
    pub offset: usize,
    /// Maximum number of results to return.
    pub num: usize,
}
/// Geographic radius restriction carried alongside a rendered query.
#[derive(Debug, Clone)]
pub struct GeoFilter {
    /// Name of the geo field to test.
    pub field: String,
    /// Centre longitude (presumably degrees — confirm with the consumer).
    pub lon: f64,
    /// Centre latitude (presumably degrees — confirm with the consumer).
    pub lat: f64,
    /// Search radius, expressed in `unit`.
    pub radius: f64,
    /// Distance unit string, passed through verbatim.
    pub unit: String,
}
/// Snapshot of every rendering hook of a [`QueryString`] implementor.
#[derive(Debug, Clone)]
pub struct QueryRender {
    /// The query expression itself.
    pub query_string: String,
    /// Named parameters referenced from `query_string` as `$name`.
    pub params: Vec<QueryParam>,
    /// Fields to return for each document.
    pub return_fields: Vec<String>,
    /// Optional sort specification.
    pub sort_by: Option<SortBy>,
    /// Optional result window.
    pub limit: Option<QueryLimit>,
    /// Query dialect number.
    pub dialect: u32,
    /// Whether the in-order flag is requested.
    pub in_order: bool,
    /// Whether only ids (no document content) are requested.
    pub no_content: bool,
    /// Optional scoring function name.
    pub scorer: Option<String>,
    /// Optional geographic radius filter.
    pub geofilter: Option<GeoFilter>,
}
/// What a query is expected to produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// Matching documents.
    Documents,
    /// Only the number of matches.
    Count,
}
/// Common interface for every query type that can be rendered for RediSearch.
///
/// Only [`QueryString::to_redis_query`] is required. Every other hook has a
/// neutral default that implementors override as needed.
pub trait QueryString {
    /// Returns the query expression.
    fn to_redis_query(&self) -> String;

    /// Gathers every hook into one [`QueryRender`].
    ///
    /// The hooks are evaluated in field order.
    fn render(&self) -> QueryRender {
        let query_string = self.to_redis_query();
        let params = self.params();
        let return_fields = self.return_fields();
        let sort_by = self.sort_by();
        let limit = self.limit();
        let dialect = self.dialect();
        let in_order = self.in_order();
        let no_content = self.no_content();
        let scorer = self.scorer();
        let geofilter = self.geofilter();
        QueryRender {
            query_string,
            params,
            return_fields,
            sort_by,
            limit,
            dialect,
            in_order,
            no_content,
            scorer,
            geofilter,
        }
    }

    /// Named parameters for `$name` placeholders. Default: none.
    fn params(&self) -> Vec<QueryParam> {
        vec![]
    }

    /// Fields to return. Default: none.
    fn return_fields(&self) -> Vec<String> {
        vec![]
    }

    /// Sort specification. Default: unsorted.
    fn sort_by(&self) -> Option<SortBy> {
        Option::None
    }

    /// Result window. Default: none requested.
    fn limit(&self) -> Option<QueryLimit> {
        Option::None
    }

    /// Query dialect. Default: `2`.
    fn dialect(&self) -> u32 {
        2
    }

    /// In-order flag. Default: off.
    fn in_order(&self) -> bool {
        false
    }

    /// No-content flag. Default: off.
    fn no_content(&self) -> bool {
        false
    }

    /// Scoring function name. Default: none.
    fn scorer(&self) -> Option<String> {
        Option::None
    }

    /// Kind of result produced. Default: [`QueryKind::Documents`].
    fn kind(&self) -> QueryKind {
        QueryKind::Documents
    }

    /// Whether results should be JSON-unpacked by the caller. Default: `false`.
    // NOTE(review): the consumer of this flag is outside this file — confirm semantics.
    fn should_unpack_json(&self) -> bool {
        false
    }

    /// Geographic radius filter. Default: none.
    fn geofilter(&self) -> Option<GeoFilter> {
        Option::None
    }
}
/// A query that can be cloned with a different result window.
pub trait PageableQuery: Clone + QueryString {
    /// Returns a copy of `self` whose window skips `offset` results and keeps `num`.
    fn paged(&self, offset: usize, num: usize) -> Self;
}
/// Result-shaping options shared by the document-returning query types.
#[derive(Debug, Clone)]
struct QueryOptions {
    return_fields: Vec<String>,
    limit: QueryLimit,
    dialect: u32,
    sort_by: Option<SortBy>,
    in_order: bool,
    scorer: Option<String>,
}

impl QueryOptions {
    /// Options for the first `num_results` hits.
    ///
    /// Uses dialect 2, with no return fields, no sort, no in-order flag and no scorer.
    fn with_num_results(num_results: usize) -> Self {
        let limit = QueryLimit {
            offset: 0,
            num: num_results,
        };
        QueryOptions {
            return_fields: vec![],
            limit,
            dialect: 2,
            sort_by: None,
            in_order: false,
            scorer: None,
        }
    }
}
/// A dense `f32` vector, either borrowed or owned.
#[derive(Debug, Clone)]
pub struct Vector<'a> {
    elements: Cow<'a, [f32]>,
}

impl<'a> Vector<'a> {
    /// Wraps anything convertible into `Cow<[f32]>` (for example `&[f32]` or `Vec<f32>`).
    pub fn new(elements: impl Into<Cow<'a, [f32]>>) -> Self {
        let elements = elements.into();
        Vector { elements }
    }

    /// Borrowed view of the components.
    pub fn elements(&self) -> &[f32] {
        self.elements.as_ref()
    }

    /// Packs the components as consecutive little-endian `f32` values.
    pub fn to_bytes(&self) -> Bytes {
        let byte_len = std::mem::size_of::<f32>() * self.elements.len();
        let mut packed = bytes::BytesMut::with_capacity(byte_len);
        self.elements
            .iter()
            .for_each(|&component| packed.put_f32_le(component));
        packed.freeze()
    }
}
/// Value of the `HYBRID_POLICY` vector-search attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HybridPolicy {
    /// `BATCHES`
    Batches,
    /// `ADHOC_BF`
    AdhocBf,
}

impl HybridPolicy {
    /// Wire token sent to Redis for this policy.
    pub fn as_str(&self) -> &'static str {
        match *self {
            HybridPolicy::AdhocBf => "ADHOC_BF",
            HybridPolicy::Batches => "BATCHES",
        }
    }
}
/// Value of the `USE_SEARCH_HISTORY` vector-search attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchHistoryMode {
    /// `OFF`
    Off,
    /// `ON`
    On,
    /// `AUTO`
    Auto,
}

impl SearchHistoryMode {
    /// Wire token sent to Redis for this mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SearchHistoryMode::Auto => "AUTO",
            SearchHistoryMode::On => "ON",
            SearchHistoryMode::Off => "OFF",
        }
    }
}
/// K-nearest-neighbour vector search, optionally pre-filtered by a [`FilterExpression`].
///
/// Every result carries the `vector_distance` alias. That alias is also the default
/// ascending sort key.
#[derive(Debug, Clone)]
pub struct VectorQuery<'a> {
    vector: Vector<'a>,
    vector_field_name: String,
    num_results: usize,
    filter_expression: Option<FilterExpression>,
    ef_runtime: Option<usize>,
    epsilon: Option<f64>,
    hybrid_policy: Option<HybridPolicy>,
    batch_size: Option<usize>,
    search_window_size: Option<usize>,
    use_search_history: Option<SearchHistoryMode>,
    search_buffer_capacity: Option<usize>,
    options: QueryOptions,
}

impl<'a> VectorQuery<'a> {
    /// Builds a KNN query for the `num_results` nearest neighbours of `vector`
    /// within `vector_field_name`.
    ///
    /// The result window defaults to the same `num_results`.
    pub fn new(
        vector: Vector<'a>,
        vector_field_name: impl Into<String>,
        num_results: usize,
    ) -> Self {
        let mut options = QueryOptions::with_num_results(num_results);
        options.return_fields.push(String::from("vector_distance"));
        options.sort_by = Some(SortBy {
            field: String::from("vector_distance"),
            direction: SortDirection::Asc,
        });
        VectorQuery {
            vector,
            vector_field_name: vector_field_name.into(),
            num_results,
            filter_expression: None,
            ef_runtime: None,
            epsilon: None,
            hybrid_policy: None,
            batch_size: None,
            search_window_size: None,
            use_search_history: None,
            search_buffer_capacity: None,
            options,
        }
    }

    /// Builder form of [`Self::set_filter`].
    pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
        self.set_filter(filter_expression);
        self
    }

    /// Pre-filters candidate documents with `filter_expression`.
    pub fn set_filter(&mut self, filter_expression: FilterExpression) {
        self.filter_expression = Some(filter_expression);
    }

    /// Builder form of [`Self::set_ef_runtime`].
    pub fn with_ef_runtime(mut self, ef_runtime: usize) -> Self {
        self.set_ef_runtime(ef_runtime);
        self
    }

    /// Sets the `EF_RUNTIME` search attribute.
    pub fn set_ef_runtime(&mut self, ef_runtime: usize) {
        self.ef_runtime = Some(ef_runtime);
    }

    /// Configured `EF_RUNTIME`, if any.
    pub fn ef_runtime(&self) -> Option<usize> {
        self.ef_runtime
    }

    /// Builder form of [`Self::set_epsilon`].
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.set_epsilon(epsilon);
        self
    }

    /// Sets the `EPSILON` search attribute.
    pub fn set_epsilon(&mut self, epsilon: f64) {
        self.epsilon = Some(epsilon);
    }

    /// Configured `EPSILON`, if any.
    pub fn epsilon(&self) -> Option<f64> {
        self.epsilon
    }

    /// Builder form of [`Self::set_hybrid_policy`].
    pub fn with_hybrid_policy(mut self, policy: HybridPolicy) -> Self {
        self.set_hybrid_policy(policy);
        self
    }

    /// Sets the `HYBRID_POLICY` search attribute.
    pub fn set_hybrid_policy(&mut self, policy: HybridPolicy) {
        self.hybrid_policy = Some(policy);
    }

    /// Configured hybrid policy, if any.
    pub fn hybrid_policy(&self) -> Option<HybridPolicy> {
        self.hybrid_policy
    }

    /// Builder form of [`Self::set_batch_size`].
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.set_batch_size(batch_size);
        self
    }

    /// Sets the `BATCH_SIZE` used together with a hybrid policy.
    pub fn set_batch_size(&mut self, batch_size: usize) {
        self.batch_size = Some(batch_size);
    }

    /// Configured batch size, if any.
    pub fn batch_size(&self) -> Option<usize> {
        self.batch_size
    }

    /// Builder form of [`Self::set_search_window_size`].
    pub fn with_search_window_size(mut self, size: usize) -> Self {
        self.set_search_window_size(size);
        self
    }

    /// Sets the `SEARCH_WINDOW_SIZE` search attribute.
    pub fn set_search_window_size(&mut self, size: usize) {
        self.search_window_size = Some(size);
    }

    /// Configured search window size, if any.
    pub fn search_window_size(&self) -> Option<usize> {
        self.search_window_size
    }

    /// Builder form of [`Self::set_use_search_history`].
    pub fn with_use_search_history(mut self, mode: SearchHistoryMode) -> Self {
        self.set_use_search_history(mode);
        self
    }

    /// Sets the `USE_SEARCH_HISTORY` search attribute.
    pub fn set_use_search_history(&mut self, mode: SearchHistoryMode) {
        self.use_search_history = Some(mode);
    }

    /// Configured search-history mode, if any.
    pub fn use_search_history(&self) -> Option<SearchHistoryMode> {
        self.use_search_history
    }

    /// Builder form of [`Self::set_search_buffer_capacity`].
    pub fn with_search_buffer_capacity(mut self, capacity: usize) -> Self {
        self.set_search_buffer_capacity(capacity);
        self
    }

    /// Sets the `SEARCH_BUFFER_CAPACITY` search attribute.
    pub fn set_search_buffer_capacity(&mut self, capacity: usize) {
        self.search_buffer_capacity = Some(capacity);
    }

    /// Configured search buffer capacity, if any.
    pub fn search_buffer_capacity(&self) -> Option<usize> {
        self.search_buffer_capacity
    }

    /// Replaces the returned fields.
    ///
    /// `vector_distance` is appended when it is missing from the list.
    pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut fields: Vec<String> = return_fields.into_iter().map(Into::into).collect();
        let has_distance = fields.iter().any(|name| name == "vector_distance");
        if !has_distance {
            fields.push(String::from("vector_distance"));
        }
        self.options.return_fields = fields;
        self
    }

    /// Sets the result window to skip `offset` results and keep `num`.
    pub fn paging(mut self, offset: usize, num: usize) -> Self {
        self.options.limit = QueryLimit { offset, num };
        self
    }

    /// Overrides the default `vector_distance` ascending sort.
    pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
        let field = field.into();
        self.options.sort_by = Some(SortBy { field, direction });
        self
    }

    /// Sets the in-order flag.
    pub fn in_order(mut self, in_order: bool) -> Self {
        self.options.in_order = in_order;
        self
    }

    /// Sets the query dialect.
    pub fn with_dialect(mut self, dialect: u32) -> Self {
        self.options.dialect = dialect;
        self
    }

    /// The query vector.
    pub fn vector(&self) -> &Vector<'a> {
        &self.vector
    }
}
impl QueryString for VectorQuery<'_> {
fn to_redis_query(&self) -> String {
let base = self
.filter_expression
.as_ref()
.map_or_else(|| "*".to_owned(), FilterExpression::to_redis_syntax);
let mut query = format!(
"{}=>[KNN {} @{} $vector AS vector_distance",
base, self.num_results, self.vector_field_name
);
if self.ef_runtime.is_some() {
query.push_str(" EF_RUNTIME $EF");
}
if self.epsilon.is_some() {
query.push_str(" EPSILON $EPSILON");
}
if self.search_window_size.is_some() {
query.push_str(" SEARCH_WINDOW_SIZE $SEARCH_WINDOW_SIZE");
}
if self.use_search_history.is_some() {
query.push_str(" USE_SEARCH_HISTORY $USE_SEARCH_HISTORY");
}
if self.search_buffer_capacity.is_some() {
query.push_str(" SEARCH_BUFFER_CAPACITY $SEARCH_BUFFER_CAPACITY");
}
query.push(']');
if let Some(policy) = &self.hybrid_policy {
query.push_str(&format!(" HYBRID_POLICY {}", policy.as_str()));
if let Some(batch_size) = self.batch_size {
query.push_str(&format!(" BATCH_SIZE {}", batch_size));
}
}
query
}
fn params(&self) -> Vec<QueryParam> {
let mut params = vec![QueryParam {
name: "vector".to_owned(),
value: QueryParamValue::Binary(self.vector.to_bytes().to_vec()),
}];
if let Some(ef_runtime) = self.ef_runtime {
params.push(QueryParam {
name: "EF".to_owned(),
value: QueryParamValue::String(ef_runtime.to_string()),
});
}
if let Some(epsilon) = self.epsilon {
params.push(QueryParam {
name: "EPSILON".to_owned(),
value: QueryParamValue::String(epsilon.to_string()),
});
}
if let Some(size) = self.search_window_size {
params.push(QueryParam {
name: "SEARCH_WINDOW_SIZE".to_owned(),
value: QueryParamValue::String(size.to_string()),
});
}
if let Some(mode) = &self.use_search_history {
params.push(QueryParam {
name: "USE_SEARCH_HISTORY".to_owned(),
value: QueryParamValue::String(mode.as_str().to_owned()),
});
}
if let Some(capacity) = self.search_buffer_capacity {
params.push(QueryParam {
name: "SEARCH_BUFFER_CAPACITY".to_owned(),
value: QueryParamValue::String(capacity.to_string()),
});
}
params
}
fn return_fields(&self) -> Vec<String> {
self.options.return_fields.clone()
}
fn sort_by(&self) -> Option<SortBy> {
self.options.sort_by.clone()
}
fn limit(&self) -> Option<QueryLimit> {
Some(self.options.limit)
}
fn dialect(&self) -> u32 {
self.options.dialect
}
fn in_order(&self) -> bool {
self.options.in_order
}
}
impl PageableQuery for VectorQuery<'_> {
fn paged(&self, offset: usize, num: usize) -> Self {
self.clone().paging(offset, num)
}
}
/// Vector range search: every document whose vector lies within
/// `distance_threshold` of the query vector.
///
/// Results carry the `vector_distance` alias, which is the default ascending sort
/// key. The window defaults to the first 10 hits.
#[derive(Debug, Clone)]
pub struct VectorRangeQuery<'a> {
    vector: Vector<'a>,
    vector_field_name: String,
    distance_threshold: f32,
    filter_expression: Option<FilterExpression>,
    epsilon: Option<f64>,
    hybrid_policy: Option<HybridPolicy>,
    batch_size: Option<usize>,
    search_window_size: Option<usize>,
    use_search_history: Option<SearchHistoryMode>,
    search_buffer_capacity: Option<usize>,
    options: QueryOptions,
}

impl<'a> VectorRangeQuery<'a> {
    /// Builds a range query around `vector` on `vector_field_name`.
    pub fn new(
        vector: Vector<'a>,
        vector_field_name: impl Into<String>,
        distance_threshold: f32,
    ) -> Self {
        let mut options = QueryOptions::with_num_results(10);
        options.return_fields.push(String::from("vector_distance"));
        options.sort_by = Some(SortBy {
            field: String::from("vector_distance"),
            direction: SortDirection::Asc,
        });
        VectorRangeQuery {
            vector,
            vector_field_name: vector_field_name.into(),
            distance_threshold,
            filter_expression: None,
            epsilon: None,
            hybrid_policy: None,
            batch_size: None,
            search_window_size: None,
            use_search_history: None,
            search_buffer_capacity: None,
            options,
        }
    }

    /// Builder form of [`Self::set_filter`].
    pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
        self.set_filter(filter_expression);
        self
    }

    /// Combines the range clause with `filter_expression`.
    pub fn set_filter(&mut self, filter_expression: FilterExpression) {
        self.filter_expression = Some(filter_expression);
    }

    /// Current distance threshold.
    pub fn distance_threshold(&self) -> f32 {
        self.distance_threshold
    }

    /// Replaces the distance threshold.
    pub fn set_distance_threshold(&mut self, distance_threshold: f32) {
        self.distance_threshold = distance_threshold;
    }

    /// Builder form of [`Self::set_epsilon`].
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.set_epsilon(epsilon);
        self
    }

    /// Sets the `$EPSILON` range attribute.
    pub fn set_epsilon(&mut self, epsilon: f64) {
        self.epsilon = Some(epsilon);
    }

    /// Configured epsilon, if any.
    pub fn epsilon(&self) -> Option<f64> {
        self.epsilon
    }

    /// Builder form of [`Self::set_hybrid_policy`].
    pub fn with_hybrid_policy(mut self, policy: HybridPolicy) -> Self {
        self.set_hybrid_policy(policy);
        self
    }

    /// Sets the hybrid policy, which is sent as the `HYBRID_POLICY` parameter.
    pub fn set_hybrid_policy(&mut self, policy: HybridPolicy) {
        self.hybrid_policy = Some(policy);
    }

    /// Configured hybrid policy, if any.
    pub fn hybrid_policy(&self) -> Option<HybridPolicy> {
        self.hybrid_policy
    }

    /// Builder form of [`Self::set_batch_size`].
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.set_batch_size(batch_size);
        self
    }

    /// Sets the batch size, which is sent as the `BATCH_SIZE` parameter.
    pub fn set_batch_size(&mut self, batch_size: usize) {
        self.batch_size = Some(batch_size);
    }

    /// Configured batch size, if any.
    pub fn batch_size(&self) -> Option<usize> {
        self.batch_size
    }

    /// Builder form of [`Self::set_search_window_size`].
    pub fn with_search_window_size(mut self, size: usize) -> Self {
        self.set_search_window_size(size);
        self
    }

    /// Sets the `$SEARCH_WINDOW_SIZE` range attribute.
    pub fn set_search_window_size(&mut self, size: usize) {
        self.search_window_size = Some(size);
    }

    /// Configured search window size, if any.
    pub fn search_window_size(&self) -> Option<usize> {
        self.search_window_size
    }

    /// Builder form of [`Self::set_use_search_history`].
    pub fn with_use_search_history(mut self, mode: SearchHistoryMode) -> Self {
        self.set_use_search_history(mode);
        self
    }

    /// Sets the `$USE_SEARCH_HISTORY` range attribute.
    pub fn set_use_search_history(&mut self, mode: SearchHistoryMode) {
        self.use_search_history = Some(mode);
    }

    /// Configured search-history mode, if any.
    pub fn use_search_history(&self) -> Option<SearchHistoryMode> {
        self.use_search_history
    }

    /// Builder form of [`Self::set_search_buffer_capacity`].
    pub fn with_search_buffer_capacity(mut self, capacity: usize) -> Self {
        self.set_search_buffer_capacity(capacity);
        self
    }

    /// Sets the `$SEARCH_BUFFER_CAPACITY` range attribute.
    pub fn set_search_buffer_capacity(&mut self, capacity: usize) {
        self.search_buffer_capacity = Some(capacity);
    }

    /// Configured search buffer capacity, if any.
    pub fn search_buffer_capacity(&self) -> Option<usize> {
        self.search_buffer_capacity
    }

    /// Replaces the returned fields.
    ///
    /// `vector_distance` is appended when it is missing from the list.
    pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut fields: Vec<String> = return_fields.into_iter().map(Into::into).collect();
        let has_distance = fields.iter().any(|name| name == "vector_distance");
        if !has_distance {
            fields.push(String::from("vector_distance"));
        }
        self.options.return_fields = fields;
        self
    }

    /// Sets the result window to skip `offset` results and keep `num`.
    pub fn paging(mut self, offset: usize, num: usize) -> Self {
        self.options.limit = QueryLimit { offset, num };
        self
    }

    /// Overrides the default `vector_distance` ascending sort.
    pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
        let field = field.into();
        self.options.sort_by = Some(SortBy { field, direction });
        self
    }

    /// Sets the in-order flag.
    pub fn in_order(mut self, in_order: bool) -> Self {
        self.options.in_order = in_order;
        self
    }

    /// Sets the query dialect.
    pub fn with_dialect(mut self, dialect: u32) -> Self {
        self.options.dialect = dialect;
        self
    }

    /// The query vector.
    pub fn vector(&self) -> &Vector<'a> {
        &self.vector
    }
}
impl QueryString for VectorRangeQuery<'_> {
    /// Renders `@<field>:[VECTOR_RANGE $distance_threshold $vector]=>{attrs}`.
    ///
    /// When a non-`*` filter is present, the range clause and the filter are
    /// wrapped together in parentheses. Attribute values are inlined into the
    /// string rather than passed as parameters.
    fn to_redis_query(&self) -> String {
        let filter = match &self.filter_expression {
            Some(expression) => expression.to_redis_syntax(),
            None => "*".to_owned(),
        };
        let range_clause = format!(
            "@{}:[VECTOR_RANGE $distance_threshold $vector]",
            self.vector_field_name
        );
        let mut attributes = vec![String::from("$YIELD_DISTANCE_AS: vector_distance")];
        attributes.extend(self.epsilon.map(|epsilon| format!("$EPSILON: {}", epsilon)));
        attributes.extend(
            self.search_window_size
                .map(|size| format!("$SEARCH_WINDOW_SIZE: {}", size)),
        );
        attributes.extend(
            self.use_search_history
                .map(|mode| format!("$USE_SEARCH_HISTORY: {}", mode.as_str())),
        );
        attributes.extend(
            self.search_buffer_capacity
                .map(|capacity| format!("$SEARCH_BUFFER_CAPACITY: {}", capacity)),
        );
        let attr_section = format!("=>{{{}}}", attributes.join("; "));
        if filter == "*" {
            format!("{}{}", range_clause, attr_section)
        } else {
            format!("({}{} {})", range_clause, attr_section, filter)
        }
    }

    /// Binds `$vector` and `$distance_threshold`.
    ///
    /// When set, the hybrid policy and batch size are added as the `HYBRID_POLICY`
    /// and `BATCH_SIZE` parameters.
    // NOTE(review): the query string never references $HYBRID_POLICY/$BATCH_SIZE —
    // confirm whether Redis actually honours these for range queries.
    fn params(&self) -> Vec<QueryParam> {
        let string_param = |name: &str, value: String| QueryParam {
            name: name.to_owned(),
            value: QueryParamValue::String(value),
        };
        let mut params = vec![
            QueryParam {
                name: "vector".to_owned(),
                value: QueryParamValue::Binary(self.vector.to_bytes().to_vec()),
            },
            string_param("distance_threshold", self.distance_threshold.to_string()),
        ];
        params.extend(
            self.hybrid_policy
                .map(|policy| string_param("HYBRID_POLICY", policy.as_str().to_owned())),
        );
        params.extend(
            self.batch_size
                .map(|batch_size| string_param("BATCH_SIZE", batch_size.to_string())),
        );
        params
    }

    fn return_fields(&self) -> Vec<String> {
        self.options.return_fields.to_vec()
    }

    fn sort_by(&self) -> Option<SortBy> {
        self.options.sort_by.as_ref().cloned()
    }

    fn limit(&self) -> Option<QueryLimit> {
        Some(self.options.limit)
    }

    fn dialect(&self) -> u32 {
        self.options.dialect
    }

    fn in_order(&self) -> bool {
        self.options.in_order
    }
}

impl PageableQuery for VectorRangeQuery<'_> {
    /// Clone of this query with its result window replaced.
    fn paged(&self, offset: usize, num: usize) -> Self {
        let copy = self.clone();
        copy.paging(offset, num)
    }
}
/// Full-text query, optionally scoped to one text field and intersected with a
/// filter expression.
///
/// The window defaults to the first 10 hits.
#[derive(Debug, Clone)]
pub struct TextQuery {
    text: String,
    text_field_name: Option<String>,
    filter_expression: Option<FilterExpression>,
    // NOTE(review): set by `with_return_score` but never read in this file; the
    // `QueryString` hooks expose no score flag — confirm whether it should be wired up.
    return_score: bool,
    options: QueryOptions,
    stopwords: Option<std::collections::HashSet<String>>,
    text_weights: Option<std::collections::HashMap<String, f32>>,
}

impl TextQuery {
    /// Creates a query for `text` with `return_score` enabled and a 10-result window.
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            text_field_name: None,
            filter_expression: None,
            return_score: true,
            options: QueryOptions::with_num_results(10),
            stopwords: None,
            text_weights: None,
        }
    }

    /// Restricts matching to `@<text_field_name>:(...)`.
    pub fn for_field(mut self, text_field_name: impl Into<String>) -> Self {
        self.text_field_name = Some(text_field_name.into());
        self
    }

    /// Builder form of [`Self::set_filter`].
    pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
        self.filter_expression = Some(filter_expression);
        self
    }

    /// Intersects the text match with `filter_expression`.
    pub fn set_filter(&mut self, filter_expression: FilterExpression) {
        self.filter_expression = Some(filter_expression);
    }

    /// Stores the `return_score` flag (see the field note).
    pub fn with_return_score(mut self, return_score: bool) -> Self {
        self.return_score = return_score;
        self
    }

    /// Replaces the returned fields.
    pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.options.return_fields = return_fields.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the result window to skip `offset` results and keep `num`.
    pub fn paging(mut self, offset: usize, num: usize) -> Self {
        self.options.limit = QueryLimit { offset, num };
        self
    }

    /// Sorts results on `field` in `direction`.
    pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
        self.options.sort_by = Some(SortBy {
            field: field.into(),
            direction,
        });
        self
    }

    /// Sets the in-order flag.
    pub fn in_order(mut self, in_order: bool) -> Self {
        self.options.in_order = in_order;
        self
    }

    /// Sets the query dialect.
    pub fn with_dialect(mut self, dialect: u32) -> Self {
        self.options.dialect = dialect;
        self
    }

    /// Sets the scoring function name.
    pub fn with_scorer(mut self, scorer: impl Into<String>) -> Self {
        self.options.scorer = Some(scorer.into());
        self
    }

    /// Tokens whose lower-cased form is in `stopwords` are dropped from the text.
    pub fn with_stopwords(mut self, stopwords: std::collections::HashSet<String>) -> Self {
        self.stopwords = Some(stopwords);
        self
    }

    /// Builder form of [`Self::set_text_weights`].
    pub fn with_text_weights(mut self, weights: std::collections::HashMap<String, f32>) -> Self {
        self.text_weights = Some(weights);
        self
    }

    /// Per-token ranking weights, keyed by the token exactly as written in the text.
    pub fn set_text_weights(&mut self, weights: std::collections::HashMap<String, f32>) {
        self.text_weights = Some(weights);
    }

    /// Configured per-token weights, if any.
    pub fn text_weights(&self) -> Option<&std::collections::HashMap<String, f32>> {
        self.text_weights.as_ref()
    }

    /// Renders `term` with a RediSearch `$weight` query attribute.
    ///
    /// RediSearch attributes take the form `expr=>{$name: value}`. A bare value
    /// such as `term=>{2}` is not valid query syntax.
    fn weighted_term(term: &str, weight: f32) -> String {
        format!("{}=>{{$weight:{}}}", term, weight)
    }

    /// Applies stopword removal and per-token weights to the raw text.
    ///
    /// Each step runs only when its collection is configured and non-empty. A step
    /// that runs re-joins the tokens with single spaces. When neither step runs,
    /// the text is returned verbatim.
    fn build_query_text(&self) -> String {
        let mut text = self.text.clone();
        if let Some(stopwords) = self.stopwords.as_ref().filter(|s| !s.is_empty()) {
            let kept: Vec<&str> = text
                .split_whitespace()
                .filter(|word| !stopwords.contains(&word.to_lowercase()))
                .collect();
            text = kept.join(" ");
        }
        if let Some(weights) = self.text_weights.as_ref().filter(|w| !w.is_empty()) {
            let weighted: Vec<String> = text
                .split_whitespace()
                .map(|word| match weights.get(word) {
                    Some(weight) => Self::weighted_term(word, *weight),
                    None => word.to_owned(),
                })
                .collect();
            text = weighted.join(" ");
        }
        text
    }
}
impl QueryString for TextQuery {
fn to_redis_query(&self) -> String {
let processed_text = self.build_query_text();
let text_part = match &self.text_field_name {
Some(field) => format!("@{}:({})", field, processed_text),
None => processed_text,
};
match &self.filter_expression {
Some(filter) => {
let filter_str = filter.to_redis_syntax();
if filter_str == "*" {
text_part
} else {
format!("{} AND {}", text_part, filter_str)
}
}
None => text_part,
}
}
fn return_fields(&self) -> Vec<String> {
self.options.return_fields.clone()
}
fn sort_by(&self) -> Option<SortBy> {
self.options.sort_by.clone()
}
fn limit(&self) -> Option<QueryLimit> {
Some(self.options.limit)
}
fn dialect(&self) -> u32 {
self.options.dialect
}
fn in_order(&self) -> bool {
self.options.in_order
}
fn scorer(&self) -> Option<String> {
self.options.scorer.clone()
}
}
impl PageableQuery for TextQuery {
fn paged(&self, offset: usize, num: usize) -> Self {
self.clone().paging(offset, num)
}
}
/// Pure filter query: returns the documents matching a [`FilterExpression`].
///
/// The window defaults to the first 10 hits.
#[derive(Debug, Clone)]
pub struct FilterQuery {
    filter_expression: FilterExpression,
    options: QueryOptions,
}

impl FilterQuery {
    /// Wraps `filter_expression` with default options.
    pub fn new(filter_expression: FilterExpression) -> Self {
        let options = QueryOptions::with_num_results(10);
        FilterQuery {
            filter_expression,
            options,
        }
    }

    /// Replaces the filter.
    pub fn set_filter(&mut self, filter_expression: FilterExpression) {
        self.filter_expression = filter_expression;
    }

    /// Replaces the returned fields.
    pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let fields: Vec<String> = return_fields.into_iter().map(Into::into).collect();
        self.options.return_fields = fields;
        self
    }

    /// Sets the result window to skip `offset` results and keep `num`.
    pub fn paging(mut self, offset: usize, num: usize) -> Self {
        self.options.limit = QueryLimit { offset, num };
        self
    }

    /// Sorts results on `field` in `direction`.
    pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
        let field = field.into();
        self.options.sort_by = Some(SortBy { field, direction });
        self
    }

    /// Sets the in-order flag.
    pub fn in_order(mut self, in_order: bool) -> Self {
        self.options.in_order = in_order;
        self
    }

    /// Sets the query dialect.
    pub fn with_dialect(mut self, dialect: u32) -> Self {
        self.options.dialect = dialect;
        self
    }
}
impl QueryString for FilterQuery {
fn to_redis_query(&self) -> String {
self.filter_expression.to_redis_syntax()
}
fn return_fields(&self) -> Vec<String> {
self.options.return_fields.clone()
}
fn sort_by(&self) -> Option<SortBy> {
self.options.sort_by.clone()
}
fn limit(&self) -> Option<QueryLimit> {
Some(self.options.limit)
}
fn dialect(&self) -> u32 {
self.options.dialect
}
fn in_order(&self) -> bool {
self.options.in_order
}
fn should_unpack_json(&self) -> bool {
true
}
}
impl PageableQuery for FilterQuery {
fn paged(&self, offset: usize, num: usize) -> Self {
self.clone().paging(offset, num)
}
}
/// Counts documents matching an optional filter without fetching them.
///
/// It requests a `0/0` window with `no_content` set.
#[derive(Debug, Clone)]
pub struct CountQuery {
    filter_expression: Option<FilterExpression>,
    dialect: u32,
}

impl CountQuery {
    /// Counts every document (`*`) under dialect 2.
    pub fn new() -> Self {
        CountQuery {
            filter_expression: None,
            dialect: 2,
        }
    }

    /// Restricts the count to documents matching `filter_expression`.
    pub fn with_filter(self, filter_expression: FilterExpression) -> Self {
        Self {
            filter_expression: Some(filter_expression),
            ..self
        }
    }

    /// Sets the query dialect.
    pub fn with_dialect(self, dialect: u32) -> Self {
        Self { dialect, ..self }
    }
}

impl Default for CountQuery {
    fn default() -> Self {
        CountQuery::new()
    }
}

impl QueryString for CountQuery {
    /// The filter syntax, or `*` when unfiltered.
    fn to_redis_query(&self) -> String {
        match &self.filter_expression {
            Some(filter) => filter.to_redis_syntax(),
            None => String::from("*"),
        }
    }

    /// Always an empty window: only the total is wanted.
    fn limit(&self) -> Option<QueryLimit> {
        let empty = QueryLimit { offset: 0, num: 0 };
        Some(empty)
    }

    fn dialect(&self) -> u32 {
        self.dialect
    }

    fn no_content(&self) -> bool {
        true
    }

    fn kind(&self) -> QueryKind {
        QueryKind::Count
    }
}
/// Score-fusion method used by the `COMBINE` clause of `FT.HYBRID`.
#[derive(Debug, Clone, Copy)]
pub enum HybridCombinationMethod {
    /// Weighted linear combination (`LINEAR`).
    Linear,
    /// Reciprocal rank fusion (`RRF`).
    Rrf,
}

impl HybridCombinationMethod {
    /// Keyword sent to Redis for this method.
    pub fn redis_name(self) -> &'static str {
        match self {
            HybridCombinationMethod::Rrf => "RRF",
            HybridCombinationMethod::Linear => "LINEAR",
        }
    }
}

/// How the `VSIM` leg of an `FT.HYBRID` query searches: `KNN` or `RANGE`.
#[derive(Debug, Clone, Copy)]
pub enum VectorSearchMethod {
    /// K-nearest-neighbour search.
    Knn,
    /// Radius-based range search.
    Range,
}
/// Hybrid text + vector query rendered as an `FT.HYBRID` command.
#[derive(Debug, Clone)]
pub struct HybridQuery<'a> {
    text: String,
    text_field_name: String,
    vector: Vector<'a>,
    vector_field_name: String,
    vector_param_name: String,
    num_results: usize,
    text_scorer: Option<String>,
    yield_text_score_as: Option<String>,
    vector_search_method: Option<VectorSearchMethod>,
    knn_ef_runtime: Option<usize>,
    range_radius: Option<f32>,
    range_epsilon: Option<f32>,
    yield_vsim_score_as: Option<String>,
    filter_expression: Option<FilterExpression>,
    combination_method: Option<HybridCombinationMethod>,
    rrf_window: Option<usize>,
    rrf_constant: Option<usize>,
    linear_alpha: Option<f32>,
    yield_combined_score_as: Option<String>,
    return_fields: Vec<String>,
    stopwords: Option<std::collections::HashSet<String>>,
    text_weights: Option<std::collections::HashMap<String, f32>>,
}

impl<'a> HybridQuery<'a> {
    /// Creates a hybrid query matching `text` on `text_field_name` and `vector`
    /// on `vector_field_name`.
    ///
    /// Defaults: 10 results, vector parameter named `vector`.
    pub fn new(
        text: impl Into<String>,
        text_field_name: impl Into<String>,
        vector: Vector<'a>,
        vector_field_name: impl Into<String>,
    ) -> Self {
        Self {
            text: text.into(),
            text_field_name: text_field_name.into(),
            vector,
            vector_field_name: vector_field_name.into(),
            vector_param_name: "vector".to_owned(),
            num_results: 10,
            text_scorer: None,
            yield_text_score_as: None,
            vector_search_method: None,
            knn_ef_runtime: None,
            range_radius: None,
            range_epsilon: None,
            yield_vsim_score_as: None,
            filter_expression: None,
            combination_method: None,
            rrf_window: None,
            rrf_constant: None,
            linear_alpha: None,
            yield_combined_score_as: None,
            return_fields: Vec::new(),
            stopwords: None,
            text_weights: None,
        }
    }

    /// Sets both `K` (for KNN) and the `LIMIT` count.
    pub fn with_num_results(mut self, num_results: usize) -> Self {
        self.num_results = num_results;
        self
    }

    /// `SCORER` for the text leg.
    pub fn with_text_scorer(mut self, scorer: impl Into<String>) -> Self {
        self.text_scorer = Some(scorer.into());
        self
    }

    /// `YIELD_SCORE_AS` alias for the text leg.
    pub fn with_yield_text_score_as(mut self, alias: impl Into<String>) -> Self {
        self.yield_text_score_as = Some(alias.into());
        self
    }

    /// Uses KNN for the vector leg, with an optional `EF_RUNTIME`.
    pub fn with_knn(mut self, ef_runtime: Option<usize>) -> Self {
        self.vector_search_method = Some(VectorSearchMethod::Knn);
        self.knn_ef_runtime = ef_runtime;
        self
    }

    /// Uses a radius range search for the vector leg, with an optional `EPSILON`.
    pub fn with_range(mut self, radius: f32, epsilon: Option<f32>) -> Self {
        self.vector_search_method = Some(VectorSearchMethod::Range);
        self.range_radius = Some(radius);
        self.range_epsilon = epsilon;
        self
    }

    /// `YIELD_SCORE_AS` alias for the vector leg.
    pub fn with_yield_vsim_score_as(mut self, alias: impl Into<String>) -> Self {
        self.yield_vsim_score_as = Some(alias.into());
        self
    }

    /// `FILTER` applied to the vector leg (skipped when its syntax is `*`).
    pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
        self.filter_expression = Some(filter_expression);
        self
    }

    /// Selects the `COMBINE` method without setting its parameters.
    pub fn with_combination_method(mut self, method: HybridCombinationMethod) -> Self {
        self.combination_method = Some(method);
        self
    }

    /// `COMBINE RRF` with optional `WINDOW` and `CONSTANT`.
    pub fn with_rrf(mut self, window: Option<usize>, constant: Option<usize>) -> Self {
        self.combination_method = Some(HybridCombinationMethod::Rrf);
        self.rrf_window = window;
        self.rrf_constant = constant;
        self
    }

    /// `COMBINE LINEAR` with `ALPHA = alpha` and `BETA = 1 - alpha`.
    pub fn with_linear(mut self, alpha: f32) -> Self {
        self.combination_method = Some(HybridCombinationMethod::Linear);
        self.linear_alpha = Some(alpha);
        self
    }

    /// `YIELD_SCORE_AS` alias for the combined score.
    ///
    /// Emitted only when a combination method is set.
    pub fn with_yield_combined_score_as(mut self, alias: impl Into<String>) -> Self {
        self.yield_combined_score_as = Some(alias.into());
        self
    }

    /// Fields to `LOAD` (each emitted as `@field`).
    pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.return_fields = return_fields.into_iter().map(Into::into).collect();
        self
    }

    /// Tokens whose lower-cased form is in `stopwords` are dropped from the text.
    pub fn with_stopwords(mut self, stopwords: std::collections::HashSet<String>) -> Self {
        self.stopwords = Some(stopwords);
        self
    }

    /// Per-token ranking weights, keyed by the token exactly as written.
    pub fn with_text_weights(mut self, weights: std::collections::HashMap<String, f32>) -> Self {
        self.text_weights = Some(weights);
        self
    }

    /// Renames the `$` parameter that carries the vector blob.
    pub fn with_vector_param_name(mut self, name: impl Into<String>) -> Self {
        self.vector_param_name = name.into();
        self
    }

    /// The query vector.
    pub fn vector(&self) -> &Vector<'a> {
        &self.vector
    }

    /// Renders `term` with a RediSearch `$weight` query attribute.
    ///
    /// Attributes take the form `expr=>{$name: value}`. A bare `term=>{2}` is not
    /// valid query syntax.
    fn weighted_term(term: &str, weight: f32) -> String {
        format!("{}=>{{$weight:{}}}", term, weight)
    }

    /// Builds the `SEARCH` clause: `@<text_field>:(<processed text>)`.
    ///
    /// Stopword removal and weighting each run only when configured and non-empty.
    /// A step that runs re-joins the tokens with single spaces.
    fn build_query_string(&self) -> String {
        let mut text = self.text.clone();
        if let Some(stopwords) = self.stopwords.as_ref().filter(|s| !s.is_empty()) {
            let kept: Vec<&str> = text
                .split_whitespace()
                .filter(|word| !stopwords.contains(&word.to_lowercase()))
                .collect();
            text = kept.join(" ");
        }
        if let Some(weights) = self.text_weights.as_ref().filter(|w| !w.is_empty()) {
            let weighted: Vec<String> = text
                .split_whitespace()
                .map(|word| match weights.get(word) {
                    Some(weight) => Self::weighted_term(word, *weight),
                    None => word.to_owned(),
                })
                .collect();
            text = weighted.join(" ");
        }
        format!("@{}:({})", self.text_field_name, text)
    }

    /// Assembles the full `FT.HYBRID` command for `index_name`.
    ///
    /// Order: `SEARCH` [+scorer/alias], `VSIM` [+KNN|RANGE, FILTER, alias],
    /// [`COMBINE`], [`LOAD`], `LIMIT 0 <num_results>`, and `PARAMS` binding the vector blob.
    pub fn build_cmd(&self, index_name: &str) -> redis::Cmd {
        let mut cmd = redis::cmd("FT.HYBRID");
        cmd.arg(index_name);
        let query_string = self.build_query_string();
        cmd.arg("SEARCH").arg(&query_string);
        if let Some(scorer) = &self.text_scorer {
            cmd.arg("SCORER").arg(scorer);
        }
        if let Some(alias) = &self.yield_text_score_as {
            cmd.arg("YIELD_SCORE_AS").arg(alias);
        }
        cmd.arg("VSIM")
            .arg(format!("@{}", self.vector_field_name))
            .arg(format!("${}", self.vector_param_name));
        if let Some(method) = self.vector_search_method {
            match method {
                VectorSearchMethod::Knn => {
                    // The count argument is the number of tokens that follow:
                    // K plus the optional EF_RUNTIME, two tokens per pair.
                    let mut kv_count = 1_usize;
                    if self.knn_ef_runtime.is_some() {
                        kv_count += 1;
                    }
                    cmd.arg("KNN").arg(kv_count * 2);
                    cmd.arg("K").arg(self.num_results);
                    if let Some(ef) = self.knn_ef_runtime {
                        cmd.arg("EF_RUNTIME").arg(ef);
                    }
                }
                VectorSearchMethod::Range => {
                    let mut kv_count = 0_usize;
                    if self.range_radius.is_some() {
                        kv_count += 1;
                    }
                    if self.range_epsilon.is_some() {
                        kv_count += 1;
                    }
                    if kv_count > 0 {
                        cmd.arg("RANGE").arg(kv_count * 2);
                    } else {
                        cmd.arg("RANGE");
                    }
                    if let Some(radius) = self.range_radius {
                        cmd.arg("RADIUS").arg(radius);
                    }
                    if let Some(epsilon) = self.range_epsilon {
                        cmd.arg("EPSILON").arg(epsilon);
                    }
                }
            }
        }
        if let Some(filter) = &self.filter_expression {
            let filter_str = filter.to_redis_syntax();
            if filter_str != "*" {
                cmd.arg("FILTER").arg(&filter_str);
            }
        }
        if let Some(alias) = &self.yield_vsim_score_as {
            cmd.arg("YIELD_SCORE_AS").arg(alias);
        }
        if let Some(method) = &self.combination_method {
            cmd.arg("COMBINE").arg(method.redis_name());
            let mut kv_pairs: Vec<(String, String)> = Vec::new();
            match method {
                HybridCombinationMethod::Rrf => {
                    if let Some(window) = self.rrf_window {
                        kv_pairs.push(("WINDOW".to_owned(), window.to_string()));
                    }
                    if let Some(constant) = self.rrf_constant {
                        kv_pairs.push(("CONSTANT".to_owned(), constant.to_string()));
                    }
                }
                HybridCombinationMethod::Linear => {
                    if let Some(alpha) = self.linear_alpha {
                        kv_pairs.push(("ALPHA".to_owned(), alpha.to_string()));
                        kv_pairs.push(("BETA".to_owned(), (1.0 - alpha).to_string()));
                    }
                }
            }
            if let Some(alias) = &self.yield_combined_score_as {
                kv_pairs.push(("YIELD_SCORE_AS".to_owned(), alias.clone()));
            }
            if !kv_pairs.is_empty() {
                cmd.arg(kv_pairs.len() * 2);
                for (k, v) in &kv_pairs {
                    cmd.arg(k).arg(v);
                }
            }
        }
        if !self.return_fields.is_empty() {
            cmd.arg("LOAD");
            cmd.arg(self.return_fields.len());
            for field in &self.return_fields {
                cmd.arg(format!("@{}", field));
            }
        }
        cmd.arg("LIMIT").arg(0).arg(self.num_results);
        cmd.arg("PARAMS")
            .arg(2)
            .arg(&self.vector_param_name)
            .arg(self.vector.to_bytes().as_ref());
        cmd
    }
}
/// Hybrid text + vector query expressed as an `FT.AGGREGATE` pipeline.
///
/// Results are ranked by `(1 - alpha) * text_score + alpha * vector_similarity`.
#[derive(Debug, Clone)]
pub struct AggregateHybridQuery<'a> {
    text: String,
    text_field_name: String,
    vector: Vector<'a>,
    vector_field_name: String,
    alpha: f32,
    num_results: usize,
    text_scorer: String,
    filter_expression: Option<FilterExpression>,
    return_fields: Vec<String>,
    stopwords: Option<std::collections::HashSet<String>>,
    text_weights: Option<std::collections::HashMap<String, f32>>,
    dialect: u32,
}

impl<'a> AggregateHybridQuery<'a> {
    /// Creates the query.
    ///
    /// Defaults: alpha 0.7, 10 results, scorer `BM25STD`, dialect 2.
    ///
    /// # Errors
    /// Returns an error when `text` is empty or whitespace-only.
    pub fn new(
        text: impl Into<String>,
        text_field_name: impl Into<String>,
        vector: Vector<'a>,
        vector_field_name: impl Into<String>,
    ) -> std::result::Result<Self, String> {
        let text = text.into();
        if text.trim().is_empty() {
            return Err("text string cannot be empty".to_owned());
        }
        Ok(Self {
            text,
            text_field_name: text_field_name.into(),
            vector,
            vector_field_name: vector_field_name.into(),
            alpha: 0.7,
            num_results: 10,
            text_scorer: "BM25STD".to_owned(),
            filter_expression: None,
            return_fields: Vec::new(),
            stopwords: None,
            text_weights: None,
            dialect: 2,
        })
    }

    /// Weight of the vector similarity in the hybrid score; the text score gets `1 - alpha`.
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets both the KNN `K` and the `SORTBY ... MAX` cap.
    pub fn with_num_results(mut self, num_results: usize) -> Self {
        self.num_results = num_results;
        self
    }

    /// `SCORER` used for the text score.
    pub fn with_text_scorer(mut self, scorer: impl Into<String>) -> Self {
        self.text_scorer = scorer.into();
        self
    }

    /// Intersects the text clause with `filter_expression` (skipped when its syntax is `*`).
    pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
        self.filter_expression = Some(filter_expression);
        self
    }

    /// Fields to `LOAD`.
    pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.return_fields = return_fields.into_iter().map(Into::into).collect();
        self
    }

    /// Lower-cased tokens contained in `stopwords` are dropped.
    pub fn with_stopwords(mut self, stopwords: std::collections::HashSet<String>) -> Self {
        self.stopwords = Some(stopwords);
        self
    }

    /// Builder form of [`Self::set_text_weights`].
    pub fn with_text_weights(mut self, weights: std::collections::HashMap<String, f32>) -> Self {
        self.text_weights = Some(weights);
        self
    }

    /// Per-token weights, keyed by the *lower-cased* token.
    pub fn set_text_weights(&mut self, weights: std::collections::HashMap<String, f32>) {
        self.text_weights = Some(weights);
    }

    /// Sets the query dialect.
    pub fn with_dialect(mut self, dialect: u32) -> Self {
        self.dialect = dialect;
        self
    }

    /// The query vector.
    pub fn vector(&self) -> &Vector<'a> {
        &self.vector
    }

    /// Configured alpha.
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Raw query text.
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Builds the aggregate query expression.
    ///
    /// Steps:
    /// 1. Tokens are lower-cased and stopwords are dropped.
    /// 2. Weighted tokens become `token=>{$weight:w}`; a bare `token=>{w}` is not
    ///    valid attribute syntax.
    /// 3. The tokens are OR-ed with ` | ` into `(~@field:(...))`, with ` AND <filter>`
    ///    added inside the parentheses for a non-`*` filter.
    /// 4. A KNN clause aliasing `vector_distance` is appended.
    pub(crate) fn build_query_string(&self) -> String {
        let text = self
            .text
            .split_whitespace()
            .map(str::to_lowercase)
            .filter(|token| {
                !matches!(&self.stopwords, Some(stopwords) if stopwords.contains(token.as_str()))
            })
            .map(|token| match self.text_weights.as_ref().and_then(|w| w.get(&token)) {
                Some(weight) => format!("{}=>{{$weight:{}}}", token, weight),
                None => token,
            })
            .collect::<Vec<String>>()
            .join(" | ");
        let base = if let Some(filter) = &self.filter_expression {
            let filter_str = filter.to_redis_syntax();
            if filter_str == "*" {
                format!("(~@{}:({}))", self.text_field_name, text)
            } else {
                format!("(~@{}:({}) AND {})", self.text_field_name, text, filter_str)
            }
        } else {
            format!("(~@{}:({}))", self.text_field_name, text)
        };
        format!(
            "{}=>[KNN {} @{} $vector AS vector_distance]",
            base, self.num_results, self.vector_field_name,
        )
    }

    /// Assembles the `FT.AGGREGATE` command for `index_name`.
    ///
    /// The pipeline reads the text score from `@__score` (enabled by `ADDSCORES`)
    /// and maps `vector_distance` to `(2 - d)/2`. That mapping presumably assumes a
    /// cosine distance in `[0, 2]` — TODO confirm. The two scores are blended into
    /// `hybrid_score`, which is sorted descending and capped at `num_results`.
    pub fn build_aggregate_cmd(&self, index_name: &str) -> redis::Cmd {
        let query_string = self.build_query_string();
        let mut cmd = redis::cmd("FT.AGGREGATE");
        cmd.arg(index_name);
        cmd.arg(&query_string);
        cmd.arg("SCORER").arg(&self.text_scorer);
        cmd.arg("ADDSCORES");
        if !self.return_fields.is_empty() {
            cmd.arg("LOAD");
            cmd.arg(self.return_fields.len());
            for field in &self.return_fields {
                cmd.arg(field);
            }
        }
        cmd.arg("DIALECT").arg(self.dialect);
        cmd.arg("APPLY")
            .arg("(2 - @vector_distance)/2")
            .arg("AS")
            .arg("vector_similarity");
        cmd.arg("APPLY").arg("@__score").arg("AS").arg("text_score");
        let hybrid_expr = format!(
            "{}*@text_score + {}*@vector_similarity",
            1.0 - self.alpha,
            self.alpha
        );
        cmd.arg("APPLY")
            .arg(&hybrid_expr)
            .arg("AS")
            .arg("hybrid_score");
        cmd.arg("SORTBY")
            .arg(2)
            .arg("@hybrid_score")
            .arg("DESC")
            .arg("MAX")
            .arg(self.num_results);
        cmd.arg("PARAMS")
            .arg(2)
            .arg("vector")
            .arg(self.vector.to_bytes().as_ref());
        cmd
    }
}
/// Element type of a raw vector buffer.
///
/// Defaults to [`VectorDtype::Float32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum VectorDtype {
    /// 16-bit brain floating point (2 bytes per element).
    BFloat16,
    /// 16-bit IEEE 754 half-precision float (2 bytes per element).
    Float16,
    /// 32-bit float (4 bytes per element); the default.
    #[default]
    Float32,
    /// 64-bit float (8 bytes per element).
    Float64,
    /// Signed 8-bit integer (1 byte per element).
    Int8,
    /// Unsigned 8-bit integer (1 byte per element).
    Uint8,
}

impl VectorDtype {
    /// Number of bytes used to encode one element of this type.
    ///
    /// This is a `const fn`, so it can size buffers at compile time.
    #[must_use]
    pub const fn bytes_per_element(self) -> usize {
        match self {
            Self::BFloat16 | Self::Float16 => 2,
            Self::Float32 => 4,
            Self::Float64 => 8,
            Self::Int8 | Self::Uint8 => 1,
        }
    }
}
/// One query vector taking part in a multi-vector search.
#[derive(Debug, Clone)]
pub struct VectorInput<'a> {
    /// Encoded vector bytes, sent to Redis as a query parameter.
    pub vector: Cow<'a, [u8]>,
    /// Name of the indexed vector field searched with this vector.
    pub field_name: String,
    /// Multiplier applied to this vector's similarity when scores are combined.
    pub weight: f32,
    /// Element type of `vector`; not checked against the byte length.
    pub dtype: VectorDtype,
    /// Search radius for this vector's range clause.
    pub max_distance: f32,
}

impl<'a> VectorInput<'a> {
    /// Encodes `elements` as little-endian `f32` bytes for `field_name`.
    ///
    /// Uses weight `1.0`, dtype [`VectorDtype::Float32`] and max distance `2.0`.
    pub fn from_floats(elements: &[f32], field_name: impl Into<String>) -> Self {
        let encoded: Vec<u8> = elements
            .iter()
            .flat_map(|element| element.to_le_bytes())
            .collect();
        Self::from_bytes(encoded, field_name, VectorDtype::Float32)
    }

    /// Wraps already-encoded vector bytes, borrowed or owned, without copying them.
    ///
    /// Uses weight `1.0` and max distance `2.0`.
    pub fn from_bytes(
        bytes: impl Into<Cow<'a, [u8]>>,
        field_name: impl Into<String>,
        dtype: VectorDtype,
    ) -> Self {
        VectorInput {
            vector: bytes.into(),
            field_name: field_name.into(),
            dtype,
            weight: 1.0,
            max_distance: 2.0,
        }
    }

    /// Replaces the combination weight.
    pub fn with_weight(self, weight: f32) -> Self {
        Self { weight, ..self }
    }

    /// Replaces the declared element type; the stored bytes are left untouched.
    pub fn with_dtype(self, dtype: VectorDtype) -> Self {
        Self { dtype, ..self }
    }

    /// Replaces the search radius.
    ///
    /// # Panics
    ///
    /// Panics if `max_distance` is outside `[0.0, 2.0]` (including NaN).
    pub fn with_max_distance(self, max_distance: f32) -> Self {
        let within_bounds = (0.0..=2.0).contains(&max_distance);
        assert!(
            within_bounds,
            "max_distance must be in [0.0, 2.0], got {}",
            max_distance
        );
        Self {
            max_distance,
            ..self
        }
    }
}
/// Aggregation query that scores documents against several vectors at once.
///
/// Each [`VectorInput`] contributes a `VECTOR_RANGE` clause. Every per-vector distance
/// is turned into a similarity with `(2 - distance) / 2`, multiplied by that input's
/// weight, and summed into `combined_score`, which orders the results.
#[derive(Debug, Clone)]
pub struct MultiVectorQuery<'a> {
    /// Query vectors, bound as parameters `vector_0`, `vector_1`, ... in this order.
    vectors: Vec<VectorInput<'a>>,
    /// Optional filter placed in front of the range clauses.
    filter_expression: Option<FilterExpression>,
    /// Maximum number of rows kept by `SORTBY ... MAX`.
    num_results: usize,
    /// Fields loaded (as `@field`) into each result row.
    return_fields: Vec<String>,
    /// Query dialect passed via `DIALECT`.
    dialect: u32,
}

impl<'a> MultiVectorQuery<'a> {
    /// Creates a query over `vectors`.
    ///
    /// Defaults: 10 results, dialect 2, no filter and no extra return fields.
    pub fn new(vectors: Vec<VectorInput<'a>>) -> Self {
        Self {
            vectors,
            num_results: 10,
            dialect: 2,
            filter_expression: None,
            return_fields: Vec::new(),
        }
    }

    /// Sets the maximum number of rows returned.
    pub fn with_num_results(self, num_results: usize) -> Self {
        Self {
            num_results,
            ..self
        }
    }

    /// Sets the filter; a filter rendering as `*` is ignored when building the query.
    pub fn with_filter(self, filter: FilterExpression) -> Self {
        Self {
            filter_expression: Some(filter),
            ..self
        }
    }

    /// Sets the fields loaded for every result row.
    pub fn with_return_fields<I, S>(self, fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            return_fields: fields.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Sets the query dialect.
    pub fn with_dialect(self, dialect: u32) -> Self {
        Self { dialect, ..self }
    }

    /// Returns the query vectors in parameter order.
    pub fn vectors(&self) -> &[VectorInput<'a>] {
        &self.vectors
    }

    /// Builds the query string.
    ///
    /// The per-vector range clauses are AND-joined, each yielding its distance as
    /// `distance_<i>`. A non-wildcard filter is prepended in parentheses.
    pub fn build_query_string(&self) -> String {
        let ranges = self
            .vectors
            .iter()
            .enumerate()
            .map(|(idx, input)| {
                format!(
                    "@{field}:[VECTOR_RANGE {radius} $vector_{idx}]=>{{$YIELD_DISTANCE_AS: distance_{idx}}}",
                    field = input.field_name,
                    radius = input.max_distance,
                    idx = idx,
                )
            })
            .collect::<Vec<_>>()
            .join(" AND ");

        let filter = match &self.filter_expression {
            Some(expression) => expression.to_redis_syntax(),
            None => return ranges,
        };
        if filter == "*" {
            ranges
        } else {
            format!("({}) {}", filter, ranges)
        }
    }

    /// Builds the full `FT.AGGREGATE` command for `index_name`.
    ///
    /// Pipeline:
    /// 1. Convert each `distance_<i>` into `score_<i> = (2 - distance_<i>) / 2`.
    /// 2. Sum the weighted scores into `combined_score`.
    /// 3. Sort by `combined_score` descending, keeping at most `num_results` rows.
    /// 4. `LOAD` the return fields, if any.
    ///
    /// Every vector's bytes are bound as the `vector_<i>` parameter.
    pub fn build_aggregate_cmd(&self, index_name: &str) -> redis::Cmd {
        let query = self.build_query_string();
        let mut cmd = redis::cmd("FT.AGGREGATE");
        cmd.arg(index_name)
            .arg(&query)
            .arg("SCORER")
            .arg("TFIDF")
            .arg("DIALECT")
            .arg(self.dialect);

        // Emit the per-vector APPLY steps and collect the weighted terms in one pass;
        // the combined APPLY must come after all per-vector ones.
        let mut weighted_terms = Vec::with_capacity(self.vectors.len());
        for (idx, input) in self.vectors.iter().enumerate() {
            cmd.arg("APPLY")
                .arg(format!("(2 - @distance_{})/2", idx))
                .arg("AS")
                .arg(format!("score_{}", idx));
            weighted_terms.push(format!("@score_{} * {}", idx, input.weight));
        }
        cmd.arg("APPLY")
            .arg(weighted_terms.join(" + "))
            .arg("AS")
            .arg("combined_score");

        cmd.arg("SORTBY")
            .arg(2)
            .arg("@combined_score")
            .arg("DESC")
            .arg("MAX")
            .arg(self.num_results);

        if !self.return_fields.is_empty() {
            cmd.arg("LOAD").arg(self.return_fields.len());
            self.return_fields.iter().for_each(|field| {
                cmd.arg(format!("@{}", field));
            });
        }

        // Each vector contributes a name and a value to PARAMS.
        cmd.arg("PARAMS").arg(self.vectors.len() * 2);
        for (idx, input) in self.vectors.iter().enumerate() {
            cmd.arg(format!("vector_{}", idx))
                .arg(input.vector.as_ref());
        }
        cmd
    }
}
/// A raw query string is sent to Redis verbatim.
impl QueryString for str {
    fn to_redis_query(&self) -> String {
        String::from(self)
    }
}

/// A borrowed raw query string defers to the `str` implementation.
impl QueryString for &str {
    fn to_redis_query(&self) -> String {
        <str as QueryString>::to_redis_query(self)
    }
}

/// An owned raw query string is copied verbatim.
impl QueryString for String {
    fn to_redis_query(&self) -> String {
        self.as_str().to_owned()
    }
}
#[cfg(feature = "sql")]
mod sql;
#[cfg(feature = "sql")]
pub use sql::{SQLQuery, SqlParam};
#[cfg(test)]
mod tests {
use super::{
AggregateHybridQuery, CountQuery, FilterQuery, HybridCombinationMethod, HybridPolicy,
HybridQuery, MultiVectorQuery, PageableQuery, QueryString, SearchHistoryMode,
SortDirection, TextQuery, Vector, VectorDtype, VectorInput, VectorQuery, VectorRangeQuery,
};
use crate::filter::{Num, Tag, Text};
#[test]
fn vector_query_should_render_knn() {
let query = VectorQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 5)
.with_return_fields(["field1", "field2"])
.with_dialect(3);
assert!(query.to_redis_query().contains("KNN 5"));
assert_eq!(query.vector().to_bytes().len(), 12);
assert_eq!(
query.render().return_fields,
vec!["field1", "field2", "vector_distance"]
);
assert_eq!(query.render().dialect, 3);
}
#[test]
fn hybrid_query_should_build_ft_hybrid_cmd_like_python_hybrid_query() {
let query = HybridQuery::new(
"a medical professional",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.with_num_results(10)
.with_combination_method(HybridCombinationMethod::Rrf)
.with_yield_combined_score_as("hybrid_score")
.with_return_fields(["user", "age", "job"]);
let cmd = query.build_cmd("my_index");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FT.HYBRID"));
assert!(cmd_str.contains("my_index"));
assert!(cmd_str.contains("@description:(a medical professional)"));
assert!(cmd_str.contains("COMBINE"));
assert!(cmd_str.contains("RRF"));
assert!(cmd_str.contains("YIELD_SCORE_AS"));
assert!(cmd_str.contains("hybrid_score"));
}
#[test]
fn hybrid_query_with_rrf_params_like_python_hybrid_query_rrf() {
let query = HybridQuery::new(
"search text",
"content",
Vector::new(vec![0.5, 0.5]),
"vec_field",
)
.with_rrf(Some(100), Some(10));
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("COMBINE"));
assert!(cmd_str.contains("RRF"));
assert!(cmd_str.contains("WINDOW"));
assert!(cmd_str.contains("CONSTANT"));
}
#[test]
fn hybrid_query_with_linear_alpha_like_python_hybrid_query_linear() {
let query =
HybridQuery::new("query text", "body", Vector::new(vec![1.0]), "vec").with_linear(0.3);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("COMBINE"));
assert!(cmd_str.contains("LINEAR"));
assert!(cmd_str.contains("ALPHA"));
}
#[test]
fn hybrid_query_with_filter_like_python_hybrid_query_filter() {
let filter = Tag::new("status").eq("active");
let query = HybridQuery::new("doctors", "description", Vector::new(vec![1.0, 2.0]), "vec")
.with_filter(filter);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FILTER"));
assert!(cmd_str.contains("@status:{active}"));
}
#[test]
fn hybrid_query_with_stopwords_and_weights_like_python_hybrid_query() {
use std::collections::{HashMap, HashSet};
let mut stopwords = HashSet::new();
stopwords.insert("the".to_owned());
stopwords.insert("a".to_owned());
let mut weights = HashMap::new();
weights.insert("doctor".to_owned(), 2.0_f32);
let query = HybridQuery::new(
"a doctor in the house",
"description",
Vector::new(vec![1.0]),
"vec",
)
.with_stopwords(stopwords)
.with_text_weights(weights);
let query_string = query.build_query_string();
assert!(!query_string.contains(" a "));
assert!(!query_string.contains(" the "));
assert!(query_string.contains("doctor"));
assert!(query_string.contains("doctor=>{2}"));
}
#[test]
fn hybrid_query_with_text_scorer_like_python_hybrid_query() {
let query = HybridQuery::new("test", "body", Vector::new(vec![1.0]), "vec")
.with_text_scorer("BM25STD")
.with_yield_text_score_as("text_score");
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("SCORER"));
assert!(cmd_str.contains("BM25STD"));
assert!(cmd_str.contains("YIELD_SCORE_AS"));
assert!(cmd_str.contains("text_score"));
}
#[test]
fn filter_query_should_track_paging_and_sort_like_python_test_query_types() {
let query = FilterQuery::new(Tag::new("brand").eq("Nike"))
.with_return_fields(["brand", "price"])
.paging(5, 7)
.sort_by("price", SortDirection::Asc)
.in_order(true)
.with_dialect(2);
let render = query.render();
assert_eq!(render.return_fields, vec!["brand", "price"]);
assert_eq!(render.limit.expect("limit").offset, 5);
assert_eq!(render.limit.expect("limit").num, 7);
assert!(render.sort_by.is_some());
assert!(render.in_order);
assert_eq!(render.dialect, 2);
}
#[test]
fn count_query_should_use_nocontent_and_zero_limit_like_python_test_query_types() {
let render = CountQuery::new()
.with_filter(Tag::new("brand").eq("Nike"))
.render();
assert!(render.no_content);
assert_eq!(render.limit.expect("limit").num, 0);
assert_eq!(render.dialect, 2);
}
#[test]
fn text_query_should_track_return_fields_and_limit_like_python_test_query_types() {
let render = TextQuery::new("basketball")
.for_field("description")
.with_return_fields(["title", "genre", "rating"])
.paging(5, 7)
.render();
assert_eq!(render.return_fields, vec!["title", "genre", "rating"]);
assert_eq!(render.limit.expect("limit").offset, 5);
assert!(render.query_string.contains("@description:(basketball)"));
}
#[test]
fn range_query_should_include_vector_params_like_python_test_query_types() {
let render = VectorRangeQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 0.2)
.with_return_fields(["field1"])
.render();
assert_eq!(render.params.len(), 2);
assert_eq!(render.params[0].name, "vector");
assert_eq!(render.params[1].name, "distance_threshold");
assert_eq!(render.return_fields, vec!["field1", "vector_distance"]);
}
#[test]
fn vector_range_query_should_update_distance_threshold_like_python_integration_test_query() {
let mut query = VectorRangeQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 0.2);
assert_eq!(query.distance_threshold(), 0.2);
query.set_distance_threshold(0.1);
assert_eq!(query.distance_threshold(), 0.1);
assert!(
query
.to_redis_query()
.contains("VECTOR_RANGE $distance_threshold")
);
}
#[test]
fn vector_query_should_replace_filter_in_place_like_python_integration_test_query() {
let mut query = VectorQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 5);
query.set_filter(Tag::new("brand").eq("Nike"));
assert!(query.to_redis_query().starts_with("@brand:{Nike}"));
query.set_filter(Num::new("price").gte(10.0));
assert!(query.to_redis_query().starts_with("@price:[10 +inf]"));
}
#[test]
fn pageable_queries_should_clone_updated_limits_for_pagination() {
let query = FilterQuery::new(Tag::new("brand").eq("Nike")).paging(0, 5);
let paged = query.paged(10, 3);
assert_eq!(paged.render().limit.expect("limit").offset, 10);
assert_eq!(paged.render().limit.expect("limit").num, 3);
assert_eq!(query.render().limit.expect("limit").offset, 0);
}
#[test]
fn raw_string_queries_should_render_directly_for_python_style_batch_search() {
let render = "@test:{foo}".render();
assert_eq!(render.query_string, "@test:{foo}");
assert!(render.params.is_empty());
}
#[test]
fn aggregate_hybrid_query_should_reject_empty_text() {
let result = AggregateHybridQuery::new("", "desc", Vector::new(vec![1.0]), "vec");
assert!(result.is_err());
}
#[test]
fn aggregate_hybrid_query_should_build_query_string_like_python_aggregate_hybrid() {
let query = AggregateHybridQuery::new(
"a medical professional with expertise in lung cancer",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap()
.with_num_results(10);
let qs = query.build_query_string();
assert!(
qs.contains("~@description:("),
"should use ~ (optional) prefix: {qs}"
);
assert!(qs.contains(" | "), "tokens should be OR-joined: {qs}");
assert!(qs.contains("=>[KNN 10 @user_embedding $vector AS vector_distance]"));
}
#[test]
fn aggregate_hybrid_query_should_build_ft_aggregate_cmd_like_python() {
let query = AggregateHybridQuery::new(
"medical professional",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap()
.with_alpha(0.5)
.with_num_results(3)
.with_text_scorer("BM25STD")
.with_return_fields(["user", "age", "job"]);
let cmd = query.build_aggregate_cmd("my_index");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FT.AGGREGATE"));
assert!(cmd_str.contains("my_index"));
assert!(cmd_str.contains("SCORER"));
assert!(cmd_str.contains("BM25STD"));
assert!(cmd_str.contains("ADDSCORES"));
assert!(cmd_str.contains("vector_similarity"));
assert!(cmd_str.contains("text_score"));
assert!(cmd_str.contains("hybrid_score"));
assert!(cmd_str.contains("SORTBY"));
assert!(cmd_str.contains("LOAD"));
assert!(cmd_str.contains("DIALECT"));
assert!(cmd_str.contains("PARAMS"));
}
#[test]
fn aggregate_hybrid_query_with_filter_like_python_aggregate_filter() {
let filter = Tag::new("credit_score").eq("high") & Num::new("age").gt(30.0);
let query = AggregateHybridQuery::new(
"medical professional",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap()
.with_filter(filter);
let qs = query.build_query_string();
assert!(qs.contains("@credit_score:{high}"));
assert!(qs.contains("@age:[(30"));
}
#[test]
fn aggregate_hybrid_query_with_stopwords_like_python_aggregate_stopwords() {
use std::collections::HashSet;
let mut stopwords = HashSet::new();
stopwords.insert("medical".to_owned());
stopwords.insert("expertise".to_owned());
let query = AggregateHybridQuery::new(
"a medical professional with expertise in lung cancer",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap()
.with_stopwords(stopwords);
let qs = query.build_query_string();
assert!(!qs.contains("medical"));
assert!(!qs.contains("expertise"));
}
#[test]
fn aggregate_hybrid_query_with_text_weights_like_python_aggregate_word_weights() {
use std::collections::HashMap;
let mut weights = HashMap::new();
weights.insert("medical".to_owned(), 3.4_f32);
weights.insert("cancers".to_owned(), 5.0_f32);
let query = AggregateHybridQuery::new(
"a medical professional with expertise in lung cancers",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap()
.with_text_weights(weights);
let qs = query.build_query_string();
assert!(qs.contains("medical=>{3.4}"));
assert!(qs.contains("cancers=>{5}"));
}
#[test]
fn aggregate_hybrid_query_set_text_weights_should_match_constructor_weights() {
use std::collections::HashMap;
let mut weights = HashMap::new();
weights.insert("medical".to_owned(), 5.0_f32);
let query1 = AggregateHybridQuery::new(
"a medical professional",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap()
.with_text_weights(weights.clone());
let mut query2 = AggregateHybridQuery::new(
"a medical professional",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
)
.unwrap();
query2.set_text_weights(weights);
assert_eq!(query1.build_query_string(), query2.build_query_string());
}
#[test]
fn multi_vector_query_should_build_vector_range_query_like_python() {
let v1 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "text embedding")
.with_weight(0.2)
.with_max_distance(0.7);
let v2 = VectorInput::from_floats(&[0.5, 0.5], "image embedding")
.with_weight(0.7)
.with_max_distance(1.8);
let query = MultiVectorQuery::new(vec![v1, v2]);
let qs = query.build_query_string();
assert!(qs.contains("@text embedding:[VECTOR_RANGE 0.7 $vector_0]"));
assert!(qs.contains("YIELD_DISTANCE_AS: distance_0"));
assert!(qs.contains("@image embedding:[VECTOR_RANGE 1.8 $vector_1]"));
assert!(qs.contains("YIELD_DISTANCE_AS: distance_1"));
assert!(qs.contains("AND"));
}
#[test]
fn multi_vector_query_default_properties_like_python() {
let vi = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "field_1");
assert_eq!(vi.weight, 1.0);
assert_eq!(vi.dtype, VectorDtype::Float32);
assert_eq!(vi.max_distance, 2.0);
let query = MultiVectorQuery::new(vec![vi]);
assert!(query.filter_expression.is_none());
assert_eq!(query.num_results, 10);
assert!(query.return_fields.is_empty());
assert_eq!(query.dialect, 2);
}
#[test]
fn multi_vector_query_should_accept_multiple_vectors_like_python() {
let v1 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "field_1")
.with_weight(0.2)
.with_max_distance(2.0);
let v2 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "field_2")
.with_weight(0.5)
.with_max_distance(1.5);
let v3 = VectorInput::from_floats(&[0.5, 0.5], "field_3")
.with_weight(0.6)
.with_max_distance(0.4);
let v4 = VectorInput::from_floats(&[0.1, 0.1, 0.1], "field_4")
.with_weight(0.1)
.with_max_distance(0.01);
let query = MultiVectorQuery::new(vec![v1, v2, v3, v4]);
assert_eq!(query.vectors().len(), 4);
}
#[test]
fn multi_vector_query_overrides_like_python() {
let vi = VectorInput::from_floats(&[0.1, 0.2], "field_1");
let filter = Tag::new("user group").one_of(["group A", "group C"]);
let query = MultiVectorQuery::new(vec![vi])
.with_filter(filter)
.with_num_results(5)
.with_return_fields(["field_1", "user name", "address"])
.with_dialect(4);
assert!(query.filter_expression.is_some());
assert_eq!(query.num_results, 5);
assert_eq!(query.return_fields, vec!["field_1", "user name", "address"]);
assert_eq!(query.dialect, 4);
}
#[test]
fn multi_vector_query_aggregate_cmd_like_python() {
let v1 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "text embedding")
.with_weight(0.2)
.with_max_distance(0.7);
let v2 = VectorInput::from_floats(&[0.5, 0.5], "image embedding")
.with_weight(0.7)
.with_max_distance(1.8);
let query = MultiVectorQuery::new(vec![v1, v2]);
let cmd = query.build_aggregate_cmd("my_index");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FT.AGGREGATE"));
assert!(cmd_str.contains("my_index"));
assert!(cmd_str.contains("SCORER"));
assert!(cmd_str.contains("TFIDF"));
assert!(cmd_str.contains("score_0"));
assert!(cmd_str.contains("score_1"));
assert!(cmd_str.contains("combined_score"));
assert!(cmd_str.contains("SORTBY"));
assert!(cmd_str.contains("PARAMS"));
}
#[test]
fn multi_vector_query_with_filter_like_python() {
let v1 = VectorInput::from_floats(&[0.1, 0.1, 0.5], "user_embedding");
let v2 = VectorInput::from_floats(&[0.3, 0.4, 0.7, 0.2, -0.3], "image_embedding");
let filter = Text::new("description").eq("medical");
let query = MultiVectorQuery::new(vec![v1, v2]).with_filter(filter);
let qs = query.build_query_string();
assert!(qs.contains("@description"));
assert!(qs.contains("medical"));
}
#[test]
#[should_panic(expected = "max_distance must be in [0.0, 2.0]")]
fn vector_input_should_reject_invalid_max_distance_like_python() {
VectorInput::from_floats(&[0.1, 0.2], "field").with_max_distance(2.001);
}
#[test]
#[should_panic(expected = "max_distance must be in [0.0, 2.0]")]
fn vector_input_should_reject_negative_max_distance_like_python() {
VectorInput::from_floats(&[0.1, 0.2], "field").with_max_distance(-0.1);
}
#[test]
fn vector_input_from_bytes_like_python() {
let floats = [0.1_f32, 0.2, 0.3, 0.4];
let mut expected_bytes = Vec::new();
for &f in &floats {
expected_bytes.extend_from_slice(&f.to_le_bytes());
}
let vi = VectorInput::from_floats(&floats, "field_1");
assert_eq!(vi.vector.as_ref(), expected_bytes.as_slice());
let vi2 = VectorInput::from_bytes(expected_bytes.clone(), "field_1", VectorDtype::Float32);
assert_eq!(vi2.vector.as_ref(), expected_bytes.as_slice());
}
#[test]
fn aggregate_hybrid_query_reject_stopword_only_text_like_python() {
let result = AggregateHybridQuery::new(
"",
"description",
Vector::new(vec![0.1, 0.1, 0.5]),
"user_embedding",
);
assert!(result.is_err());
}
#[test]
fn aggregate_hybrid_query_with_string_filter_like_python() {
use crate::filter::FilterExpression;
let filter_str = "@category:{tech|science|engineering}";
let filter = FilterExpression::raw(filter_str);
let query = AggregateHybridQuery::new(
"search for document 12345",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.unwrap()
.with_filter(filter);
let qs = query.build_query_string();
assert!(
qs.contains("~@description:(search | for | document | 12345)"),
"tokens should be OR-joined with ~ prefix: {qs}"
);
assert!(
qs.contains("AND @category:{tech|science|engineering}"),
"filter should be AND-joined: {qs}"
);
}
#[test]
fn aggregate_hybrid_query_wildcard_filter_is_ignored_like_python() {
use crate::filter::FilterExpression;
let filter = FilterExpression::raw("*");
let query = AggregateHybridQuery::new(
"search text",
"description",
Vector::new(vec![0.1]),
"embedding",
)
.unwrap()
.with_filter(filter);
let qs = query.build_query_string();
assert!(!qs.contains("AND"));
}
#[test]
fn aggregate_hybrid_query_text_weights_validation_like_python() {
use std::collections::HashMap;
let q1 = AggregateHybridQuery::new(
"sample text query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.unwrap()
.with_text_weights(HashMap::new());
assert!(q1.build_query_string().contains("sample"));
let mut weights = HashMap::new();
weights.insert("alpha".to_owned(), 0.2_f32);
weights.insert("bravo".to_owned(), 0.4_f32);
let q2 = AggregateHybridQuery::new(
"sample text query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.unwrap()
.with_text_weights(weights);
let qs = q2.build_query_string();
assert!(qs.contains("sample"));
}
#[test]
fn hybrid_query_without_filter_like_python() {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(!cmd_str.contains("FILTER"));
assert!(cmd_str.contains("@description:(test query)"));
}
#[test]
fn hybrid_query_vector_search_method_knn_like_python() {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_knn(Some(100))
.with_num_results(10);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("KNN"));
assert!(cmd_str.contains("EF_RUNTIME"));
}
#[test]
fn hybrid_query_vector_search_method_range_like_python() {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_range(10.0, Some(0.1));
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("RANGE"));
assert!(cmd_str.contains("RADIUS"));
assert!(cmd_str.contains("EPSILON"));
}
#[test]
fn hybrid_query_without_vector_search_method_like_python() {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("VSIM"));
assert!(!cmd_str.contains("KNN"));
assert!(!cmd_str.contains("RANGE"));
}
#[test]
fn hybrid_query_rrf_with_both_params_like_python() {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_rrf(Some(20), Some(50))
.with_yield_combined_score_as("rrf_score");
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("RRF"));
assert!(cmd_str.contains("WINDOW"));
assert!(cmd_str.contains("CONSTANT"));
assert!(cmd_str.contains("YIELD_SCORE_AS"));
assert!(cmd_str.contains("rrf_score"));
}
#[test]
fn hybrid_query_linear_with_alpha_like_python() {
for alpha in [0.1_f32, 0.5, 0.9] {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_linear(alpha);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("LINEAR"));
assert!(cmd_str.contains("ALPHA"));
assert!(cmd_str.contains("BETA"));
}
}
#[test]
fn hybrid_query_without_combination_method_like_python() {
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(!cmd_str.contains("COMBINE"));
}
#[test]
fn hybrid_query_with_combined_filters_like_python() {
let filter = Tag::new("genre").eq("comedy") & Num::new("rating").gt(7.0);
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_filter(filter);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FILTER"));
assert!(cmd_str.contains("genre"));
assert!(cmd_str.contains("comedy"));
assert!(cmd_str.contains("rating"));
}
#[test]
fn hybrid_query_with_numeric_filter_like_python() {
let filter = Num::new("age").gt(30.0);
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_filter(filter);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FILTER"));
assert!(cmd_str.contains("@age:[(30"));
}
#[test]
fn hybrid_query_with_text_filter_like_python() {
let filter = Text::new("job").eq("engineer");
let query = HybridQuery::new(
"test query",
"description",
Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
"embedding",
)
.with_filter(filter);
let cmd = query.build_cmd("idx");
let packed = cmd.get_packed_command();
let cmd_str = String::from_utf8_lossy(&packed);
assert!(cmd_str.contains("FILTER"));
assert!(cmd_str.contains("@job"));
assert!(cmd_str.contains("engineer"));
}
#[test]
fn vector_query_hybrid_policy_like_python_test_query_types() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_hybrid_policy(HybridPolicy::Batches);
assert_eq!(query.hybrid_policy(), Some(HybridPolicy::Batches));
assert!(query.to_redis_query().contains("HYBRID_POLICY BATCHES"));
}
#[test]
fn vector_query_hybrid_policy_with_batch_size_like_python() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_hybrid_policy(HybridPolicy::Batches)
.with_batch_size(50);
let qs = query.to_redis_query();
assert!(qs.contains("HYBRID_POLICY BATCHES BATCH_SIZE 50"));
}
#[test]
fn vector_query_adhoc_bf_policy_like_python() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_hybrid_policy(HybridPolicy::AdhocBf);
assert!(query.to_redis_query().contains("HYBRID_POLICY ADHOC_BF"));
}
#[test]
fn vector_query_epsilon_like_python_test_query_types() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_epsilon(0.05);
assert_eq!(query.epsilon(), Some(0.05));
let qs = query.to_redis_query();
assert!(qs.contains("EPSILON $EPSILON"));
let params = query.params();
assert!(params.iter().any(|p| p.name == "EPSILON"));
}
#[test]
fn vector_query_ef_runtime_params_like_python_test_query_types() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_ef_runtime(100);
assert_eq!(query.ef_runtime(), Some(100));
let qs = query.to_redis_query();
assert!(qs.contains("EF_RUNTIME $EF"));
let params = query.params();
assert!(params.iter().any(|p| p.name == "EF"));
}
#[test]
fn vector_query_search_window_size_like_python() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_search_window_size(40);
assert_eq!(query.search_window_size(), Some(40));
let qs = query.to_redis_query();
assert!(qs.contains("SEARCH_WINDOW_SIZE $SEARCH_WINDOW_SIZE"));
}
#[test]
fn vector_query_use_search_history_like_python() {
for mode in [
SearchHistoryMode::Off,
SearchHistoryMode::On,
SearchHistoryMode::Auto,
] {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_use_search_history(mode);
assert_eq!(query.use_search_history(), Some(mode));
let qs = query.to_redis_query();
assert!(qs.contains("USE_SEARCH_HISTORY $USE_SEARCH_HISTORY"));
}
}
#[test]
fn vector_query_search_buffer_capacity_like_python() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_search_buffer_capacity(50);
assert_eq!(query.search_buffer_capacity(), Some(50));
let qs = query.to_redis_query();
assert!(qs.contains("SEARCH_BUFFER_CAPACITY $SEARCH_BUFFER_CAPACITY"));
}
#[test]
fn vector_query_all_runtime_params_like_python() {
let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
.with_ef_runtime(100)
.with_epsilon(0.05)
.with_search_window_size(40)
.with_use_search_history(SearchHistoryMode::On)
.with_search_buffer_capacity(50);
let qs = query.to_redis_query();
assert!(qs.contains("EF_RUNTIME $EF"));
assert!(qs.contains("EPSILON $EPSILON"));
assert!(qs.contains("SEARCH_WINDOW_SIZE $SEARCH_WINDOW_SIZE"));
assert!(qs.contains("USE_SEARCH_HISTORY $USE_SEARCH_HISTORY"));
assert!(qs.contains("SEARCH_BUFFER_CAPACITY $SEARCH_BUFFER_CAPACITY"));
let params = query.params();
assert!(params.iter().any(|p| p.name == "EF"));
assert!(params.iter().any(|p| p.name == "EPSILON"));
assert!(params.iter().any(|p| p.name == "SEARCH_WINDOW_SIZE"));
assert!(params.iter().any(|p| p.name == "USE_SEARCH_HISTORY"));
assert!(params.iter().any(|p| p.name == "SEARCH_BUFFER_CAPACITY"));
}
#[test]
fn vector_query_set_methods_like_python() {
    // A freshly constructed query carries none of the optional tuning knobs.
    let mut q = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10);
    assert_eq!(q.ef_runtime(), None);
    assert_eq!(q.epsilon(), None);
    assert_eq!(q.hybrid_policy(), None);

    // Each in-place setter is observable through its matching accessor.
    q.set_ef_runtime(200);
    assert_eq!(q.ef_runtime(), Some(200));

    q.set_epsilon(0.1);
    assert_eq!(q.epsilon(), Some(0.1));

    q.set_hybrid_policy(HybridPolicy::Batches);
    assert_eq!(q.hybrid_policy(), Some(HybridPolicy::Batches));

    q.set_batch_size(100);
    assert_eq!(q.batch_size(), Some(100));
}
#[test]
fn range_query_epsilon_like_python_test_query_types() {
    // For range queries EPSILON is written into the query text with its literal value.
    let embedding = Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
    let query = VectorRangeQuery::new(embedding, "vector_field", 0.3).with_epsilon(0.05);
    assert_eq!(query.epsilon(), Some(0.05));
    let rendered = query.to_redis_query();
    assert!(rendered.contains("$EPSILON: 0.05"));
}
#[test]
fn range_query_construction_like_python() {
    // Basic range query: a VECTOR_RANGE clause plus the yielded distance alias,
    // with no HYBRID_POLICY text because none was requested.
    let basic_rendered =
        VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
            .with_return_fields(["user", "credit_score"])
            .to_redis_query();
    assert!(basic_rendered.contains("VECTOR_RANGE $distance_threshold $vector"));
    assert!(basic_rendered.contains("$YIELD_DISTANCE_AS: vector_distance"));
    assert!(!basic_rendered.contains("HYBRID_POLICY"));

    // Epsilon is rendered inline and can be read back through its accessor.
    let with_eps =
        VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
            .with_epsilon(0.05);
    assert!(with_eps.to_redis_query().contains("$EPSILON: 0.05"));
    assert_eq!(with_eps.epsilon(), Some(0.05));
}
#[test]
fn range_query_hybrid_policy_in_params_not_query_string_like_python() {
    // For range queries the hybrid policy travels only as a bound parameter;
    // it must never appear in the query text itself.
    let query = VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
        .with_hybrid_policy(HybridPolicy::Batches);
    let leaked_into_text = query.to_redis_query().contains("HYBRID_POLICY");
    assert!(!leaked_into_text);
    assert_eq!(query.hybrid_policy(), Some(HybridPolicy::Batches));
    let is_bound = query.params().iter().any(|p| p.name == "HYBRID_POLICY");
    assert!(is_bound);
}
#[test]
fn range_query_hybrid_policy_with_batch_size_in_params_like_python() {
    // Both hybrid knobs stay out of the query text and are bound as parameters.
    let query = VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
        .with_hybrid_policy(HybridPolicy::Batches)
        .with_batch_size(50);
    let text = query.to_redis_query();
    for absent in ["HYBRID_POLICY", "BATCH_SIZE"] {
        assert!(!text.contains(absent));
    }
    let names: Vec<String> = query.params().into_iter().map(|p| p.name).collect();
    for present in ["HYBRID_POLICY", "BATCH_SIZE"] {
        assert!(names.iter().any(|n| n.as_str() == present));
    }
}
#[test]
fn range_query_setter_methods_like_python() {
    let mut rq =
        VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "user_embedding", 0.2);

    // Nothing optional is configured on a freshly built range query.
    assert_eq!(rq.epsilon(), None);
    assert_eq!(rq.hybrid_policy(), None);
    assert_eq!(rq.batch_size(), None);

    // Epsilon is both readable and rendered inline once set.
    rq.set_epsilon(0.1);
    assert_eq!(rq.epsilon(), Some(0.1));
    let rendered = rq.to_redis_query();
    assert!(rendered.contains("$EPSILON: 0.1"));

    rq.set_hybrid_policy(HybridPolicy::Batches);
    assert_eq!(rq.hybrid_policy(), Some(HybridPolicy::Batches));

    rq.set_batch_size(25);
    assert_eq!(rq.batch_size(), Some(25));
}
#[test]
fn range_query_search_window_size_like_python() {
    // SEARCH_WINDOW_SIZE is rendered inline with its literal value for range queries.
    let embedding = Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
    let query =
        VectorRangeQuery::new(embedding, "vector_field", 0.3).with_search_window_size(40);
    assert_eq!(query.search_window_size(), Some(40));
    let rendered = query.to_redis_query();
    assert!(rendered.contains("$SEARCH_WINDOW_SIZE: 40"));
}
#[test]
fn range_query_use_search_history_like_python() {
    // Each history mode is rendered inline using its upper-case keyword.
    let cases = [
        (SearchHistoryMode::Off, "OFF"),
        (SearchHistoryMode::On, "ON"),
        (SearchHistoryMode::Auto, "AUTO"),
    ];
    for &(mode, keyword) in cases.iter() {
        let query =
            VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
                .with_use_search_history(mode);
        assert_eq!(query.use_search_history(), Some(mode));
        let expected = format!("$USE_SEARCH_HISTORY: {}", keyword);
        let rendered = query.to_redis_query();
        assert!(
            rendered.contains(&expected),
            "query string should contain USE_SEARCH_HISTORY for {:?}",
            mode,
        );
    }
}
#[test]
fn range_query_search_buffer_capacity_like_python() {
    // SEARCH_BUFFER_CAPACITY is rendered inline with its literal value.
    let embedding = Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
    let query =
        VectorRangeQuery::new(embedding, "vector_field", 0.3).with_search_buffer_capacity(50);
    assert_eq!(query.search_buffer_capacity(), Some(50));
    let rendered = query.to_redis_query();
    assert!(rendered.contains("$SEARCH_BUFFER_CAPACITY: 50"));
}
#[test]
fn range_query_all_svs_params_like_python() {
    // With every SVS-related attribute set, each appears inline with its value.
    let rendered =
        VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
            .with_epsilon(0.05)
            .with_search_window_size(40)
            .with_use_search_history(SearchHistoryMode::On)
            .with_search_buffer_capacity(50)
            .to_redis_query();
    for attribute in [
        "$EPSILON: 0.05",
        "$SEARCH_WINDOW_SIZE: 40",
        "$USE_SEARCH_HISTORY: ON",
        "$SEARCH_BUFFER_CAPACITY: 50",
    ] {
        assert!(rendered.contains(attribute));
    }
}
#[test]
fn text_query_with_filter_expression_like_python() {
    // A tag filter is joined onto the field-scoped full-text clause with AND.
    let rendered = TextQuery::new("basketball")
        .for_field("description")
        .with_filter(Tag::new("genre").eq("comedy"))
        .to_redis_query();
    assert!(rendered.contains("@description:(basketball)"));
    assert!(rendered.contains("AND @genre:{comedy}"));
}
#[test]
fn text_query_without_filter_like_python() {
    // With no filter attached, only the field clause is emitted (no AND conjunction).
    let rendered = TextQuery::new("basketball")
        .for_field("description")
        .to_redis_query();
    assert!(rendered.contains("@description:(basketball)"));
    assert!(!rendered.contains("AND"));
}
#[test]
fn text_query_set_filter_like_python() {
    // `set_filter` is the in-place counterpart of the `with_filter` builder.
    let mut query = TextQuery::new("basketball").for_field("description");
    let sports = Tag::new("category").eq("sports");
    query.set_filter(sports);
    assert!(query.to_redis_query().contains("AND @category:{sports}"));
}
#[test]
fn text_query_with_stopwords_removes_words() {
    use std::collections::HashSet;
    let mut stopwords = HashSet::new();
    stopwords.insert("the".to_owned());
    stopwords.insert("a".to_owned());
    let query = TextQuery::new("a doctor in the house")
        .for_field("description")
        .with_stopwords(stopwords);
    let qs = query.to_redis_query();
    // Compare whole tokens instead of substrings such as " a ": the terms sit
    // inside `@field:(...)`, so the first word follows "(" rather than a space
    // and a substring check would pass even if the stopword were kept.
    let tokens: Vec<&str> = qs
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .collect();
    assert!(!tokens.contains(&"a"), "stopword \"a\" survived in {:?}", qs);
    assert!(!tokens.contains(&"the"), "stopword \"the\" survived in {:?}", qs);
    assert!(tokens.contains(&"doctor"));
    assert!(tokens.contains(&"house"));
}
#[test]
fn text_query_with_text_weights_applies_weight_syntax() {
    use std::collections::HashMap;
    // Boost "doctor"; unweighted words such as "house" must still be present.
    let weights = HashMap::from([("doctor".to_owned(), 2.0_f32)]);
    let rendered = TextQuery::new("a doctor in the house")
        .for_field("description")
        .with_text_weights(weights)
        .to_redis_query();
    assert!(rendered.contains("doctor=>{2}"));
    assert!(rendered.contains("house"));
}
#[test]
fn text_query_with_stopwords_and_weights_combined() {
    use std::collections::{HashMap, HashSet};
    let mut stopwords = HashSet::new();
    stopwords.insert("the".to_owned());
    stopwords.insert("a".to_owned());
    let mut weights = HashMap::new();
    weights.insert("doctor".to_owned(), 2.0_f32);
    let query = TextQuery::new("a doctor in the house")
        .for_field("description")
        .with_stopwords(stopwords)
        .with_text_weights(weights);
    let qs = query.to_redis_query();
    // Compare whole tokens instead of substrings such as " a ": the first term
    // follows the "(" of `@field:(...)`, so a substring check could not detect
    // a leading stopword that was left in place.
    let tokens: Vec<&str> = qs
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .collect();
    assert!(!tokens.contains(&"a"), "stopword \"a\" survived in {:?}", qs);
    assert!(!tokens.contains(&"the"), "stopword \"the\" survived in {:?}", qs);
    assert!(qs.contains("doctor=>{2}"));
}
#[test]
fn text_query_set_text_weights_mirrors_builder() {
    use std::collections::HashMap;
    // Builder-style and setter-style weight assignment must render identically.
    let weights = HashMap::from([("medical".to_owned(), 5.0_f32)]);
    let via_builder = TextQuery::new("a medical professional")
        .for_field("description")
        .with_text_weights(weights.clone());
    let mut via_setter = TextQuery::new("a medical professional").for_field("description");
    via_setter.set_text_weights(weights);
    let builder_rendered = via_builder.to_redis_query();
    let setter_rendered = via_setter.to_redis_query();
    assert_eq!(builder_rendered, setter_rendered);
}
#[test]
fn text_query_text_weights_accessor() {
    use std::collections::HashMap;
    // The accessor hands back exactly the map that was supplied to the builder.
    let expected = HashMap::from([
        ("alpha".to_owned(), 0.2_f32),
        ("bravo".to_owned(), 0.4_f32),
    ]);
    let query = TextQuery::new("sample text query")
        .for_field("description")
        .with_text_weights(expected.clone());
    assert_eq!(query.text_weights(), Some(&expected));
}
#[test]
fn text_query_no_weights_returns_none() {
    // Without `with_text_weights`, the accessor reports no weights at all.
    let unweighted = TextQuery::new("sample text query").for_field("description");
    assert_eq!(unweighted.text_weights(), None);
}
}