use super::field::{Field, Value};
/// Operators that can appear in a where-style filter.
///
/// Every variant carries the [`Field`] it targets plus operator-specific
/// operands. Tuple variants hold the field first; struct variants name it
/// `field`. This file only declares the operators, reports their names and
/// validates their fields; how each operator is evaluated or rendered is
/// decided elsewhere.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum WhereOperator {
// --- Comparisons: target field and a single `Value` operand. ---
Eq(Field, Value),
Neq(Field, Value),
Gt(Field, Value),
Gte(Field, Value),
Lt(Field, Value),
Lte(Field, Value),
// --- Membership: target field and a list of candidate values. ---
In(Field, Vec<Value>),
Nin(Field, Vec<Value>),
// Takes a `String` operand; presumably a substring test — TODO confirm.
Contains(Field, String),
// --- Array operators. Note that `ArrayOverlaps` takes a list, the others a single `Value`. ---
ArrayContains(Field, Value),
ArrayContainedBy(Field, Value),
ArrayOverlaps(Field, Vec<Value>),
// --- Length comparisons: target field and a `usize` length operand. ---
LenEq(Field, usize),
LenGt(Field, usize),
LenGte(Field, usize),
LenLt(Field, usize),
LenLte(Field, usize),
// --- String matching: target field and a `String` operand. ---
// NOTE(review): the `I` prefix presumably means case-insensitive — confirm
// against the code that consumes these variants.
Icontains(Field, String),
Startswith(Field, String),
Istartswith(Field, String),
Endswith(Field, String),
Iendswith(Field, String),
Like(Field, String),
Ilike(Field, String),
// NOTE(review): the `bool` presumably selects "is null" vs "is not null" — confirm.
IsNull(Field, bool),
// --- Distance operators: target field, a comparison operand and a threshold. ---
// `validate` rejects any of these whose `threshold` is NaN or infinite.
L2Distance {
field: Field,
vector: Vec<f32>,
threshold: f32,
},
CosineDistance {
field: Field,
vector: Vec<f32>,
threshold: f32,
},
InnerProduct {
field: Field,
vector: Vec<f32>,
threshold: f32,
},
L1Distance {
field: Field,
vector: Vec<f32>,
threshold: f32,
},
HammingDistance {
field: Field,
vector: Vec<f32>,
threshold: f32,
},
// Unlike the variants above, this one compares against a set of strings.
JaccardDistance {
field: Field,
set: Vec<String>,
threshold: f32,
},
// --- Text-query operators: target field and a query string. ---
// `language` is optional where present; `PlainQuery` has no language field.
Matches {
field: Field,
query: String,
language: Option<String>,
},
PlainQuery {
field: Field,
query: String,
},
PhraseQuery {
field: Field,
query: String,
language: Option<String>,
},
WebsearchQuery {
field: Field,
query: String,
language: Option<String>,
},
// --- Network-address predicates that need no operand besides the field. ---
IsIPv4(Field),
IsIPv6(Field),
IsPrivate(Field),
IsPublic(Field),
IsLoopback(Field),
// --- Network-address operators with a string operand. ---
// The subnet / ip / range strings are stored as given; `validate` in this
// file does not parse or check them.
InSubnet {
field: Field,
subnet: String,
},
ContainsSubnet {
field: Field,
subnet: String,
},
ContainsIP {
field: Field,
ip: String,
},
IPRangeOverlap {
field: Field,
range: String,
},
StrictlyContains(Field, Value),
// --- Hierarchical-path operators. ---
// NOTE(review): the names (lquery, ltxtquery, lca) suggest PostgreSQL's
// `ltree` extension — confirm before relying on that.
AncestorOf {
field: Field,
path: String,
},
DescendantOf {
field: Field,
path: String,
},
MatchesLquery {
field: Field,
pattern: String,
},
MatchesLtxtquery {
field: Field,
query: String,
},
MatchesAnyLquery {
field: Field,
patterns: Vec<String>,
},
// --- Depth comparisons: target field and a `usize` depth operand. ---
DepthEq {
field: Field,
depth: usize,
},
DepthNeq {
field: Field,
depth: usize,
},
DepthGt {
field: Field,
depth: usize,
},
DepthGte {
field: Field,
depth: usize,
},
DepthLt {
field: Field,
depth: usize,
},
DepthLte {
field: Field,
depth: usize,
},
Lca {
field: Field,
paths: Vec<String>,
},
}
impl WhereOperator {
pub const fn name(&self) -> &'static str {
match self {
WhereOperator::Eq(_, _) => "Eq",
WhereOperator::Neq(_, _) => "Neq",
WhereOperator::Gt(_, _) => "Gt",
WhereOperator::Gte(_, _) => "Gte",
WhereOperator::Lt(_, _) => "Lt",
WhereOperator::Lte(_, _) => "Lte",
WhereOperator::In(_, _) => "In",
WhereOperator::Nin(_, _) => "Nin",
WhereOperator::Contains(_, _) => "Contains",
WhereOperator::ArrayContains(_, _) => "ArrayContains",
WhereOperator::ArrayContainedBy(_, _) => "ArrayContainedBy",
WhereOperator::ArrayOverlaps(_, _) => "ArrayOverlaps",
WhereOperator::LenEq(_, _) => "LenEq",
WhereOperator::LenGt(_, _) => "LenGt",
WhereOperator::LenGte(_, _) => "LenGte",
WhereOperator::LenLt(_, _) => "LenLt",
WhereOperator::LenLte(_, _) => "LenLte",
WhereOperator::Icontains(_, _) => "Icontains",
WhereOperator::Startswith(_, _) => "Startswith",
WhereOperator::Istartswith(_, _) => "Istartswith",
WhereOperator::Endswith(_, _) => "Endswith",
WhereOperator::Iendswith(_, _) => "Iendswith",
WhereOperator::Like(_, _) => "Like",
WhereOperator::Ilike(_, _) => "Ilike",
WhereOperator::IsNull(_, _) => "IsNull",
WhereOperator::L2Distance { .. } => "L2Distance",
WhereOperator::CosineDistance { .. } => "CosineDistance",
WhereOperator::InnerProduct { .. } => "InnerProduct",
WhereOperator::L1Distance { .. } => "L1Distance",
WhereOperator::HammingDistance { .. } => "HammingDistance",
WhereOperator::JaccardDistance { .. } => "JaccardDistance",
WhereOperator::Matches { .. } => "Matches",
WhereOperator::PlainQuery { .. } => "PlainQuery",
WhereOperator::PhraseQuery { .. } => "PhraseQuery",
WhereOperator::WebsearchQuery { .. } => "WebsearchQuery",
WhereOperator::IsIPv4(_) => "IsIPv4",
WhereOperator::IsIPv6(_) => "IsIPv6",
WhereOperator::IsPrivate(_) => "IsPrivate",
WhereOperator::IsPublic(_) => "IsPublic",
WhereOperator::IsLoopback(_) => "IsLoopback",
WhereOperator::InSubnet { .. } => "InSubnet",
WhereOperator::ContainsSubnet { .. } => "ContainsSubnet",
WhereOperator::ContainsIP { .. } => "ContainsIP",
WhereOperator::IPRangeOverlap { .. } => "IPRangeOverlap",
WhereOperator::StrictlyContains(_, _) => "StrictlyContains",
WhereOperator::AncestorOf { .. } => "AncestorOf",
WhereOperator::DescendantOf { .. } => "DescendantOf",
WhereOperator::MatchesLquery { .. } => "MatchesLquery",
WhereOperator::MatchesLtxtquery { .. } => "MatchesLtxtquery",
WhereOperator::MatchesAnyLquery { .. } => "MatchesAnyLquery",
WhereOperator::DepthEq { .. } => "DepthEq",
WhereOperator::DepthNeq { .. } => "DepthNeq",
WhereOperator::DepthGt { .. } => "DepthGt",
WhereOperator::DepthGte { .. } => "DepthGte",
WhereOperator::DepthLt { .. } => "DepthLt",
WhereOperator::DepthLte { .. } => "DepthLte",
WhereOperator::Lca { .. } => "Lca",
}
}
pub fn validate(&self) -> Result<(), String> {
match self {
WhereOperator::Eq(f, _)
| WhereOperator::Neq(f, _)
| WhereOperator::Gt(f, _)
| WhereOperator::Gte(f, _)
| WhereOperator::Lt(f, _)
| WhereOperator::Lte(f, _)
| WhereOperator::In(f, _)
| WhereOperator::Nin(f, _)
| WhereOperator::Contains(f, _)
| WhereOperator::ArrayContains(f, _)
| WhereOperator::ArrayContainedBy(f, _)
| WhereOperator::ArrayOverlaps(f, _)
| WhereOperator::LenEq(f, _)
| WhereOperator::LenGt(f, _)
| WhereOperator::LenGte(f, _)
| WhereOperator::LenLt(f, _)
| WhereOperator::LenLte(f, _)
| WhereOperator::Icontains(f, _)
| WhereOperator::Startswith(f, _)
| WhereOperator::Istartswith(f, _)
| WhereOperator::Endswith(f, _)
| WhereOperator::Iendswith(f, _)
| WhereOperator::Like(f, _)
| WhereOperator::Ilike(f, _)
| WhereOperator::IsNull(f, _)
| WhereOperator::StrictlyContains(f, _) => f.validate(),
WhereOperator::L2Distance {
field, threshold, ..
}
| WhereOperator::CosineDistance {
field, threshold, ..
}
| WhereOperator::InnerProduct {
field, threshold, ..
}
| WhereOperator::L1Distance {
field, threshold, ..
}
| WhereOperator::HammingDistance {
field, threshold, ..
}
| WhereOperator::JaccardDistance {
field, threshold, ..
} => {
if !threshold.is_finite() {
return Err(format!(
"Vector distance threshold must be a finite number, got {}",
threshold
));
}
field.validate()
}
WhereOperator::Matches { field, .. }
| WhereOperator::PlainQuery { field, .. }
| WhereOperator::PhraseQuery { field, .. }
| WhereOperator::WebsearchQuery { field, .. }
| WhereOperator::IsIPv4(field)
| WhereOperator::IsIPv6(field)
| WhereOperator::IsPrivate(field)
| WhereOperator::IsPublic(field)
| WhereOperator::IsLoopback(field)
| WhereOperator::InSubnet { field, .. }
| WhereOperator::ContainsSubnet { field, .. }
| WhereOperator::ContainsIP { field, .. }
| WhereOperator::IPRangeOverlap { field, .. }
| WhereOperator::AncestorOf { field, .. }
| WhereOperator::DescendantOf { field, .. }
| WhereOperator::MatchesLquery { field, .. }
| WhereOperator::MatchesLtxtquery { field, .. }
| WhereOperator::MatchesAnyLquery { field, .. }
| WhereOperator::DepthEq { field, .. }
| WhereOperator::DepthNeq { field, .. }
| WhereOperator::DepthGt { field, .. }
| WhereOperator::DepthGte { field, .. }
| WhereOperator::DepthLt { field, .. }
| WhereOperator::DepthLte { field, .. }
| WhereOperator::Lca { field, .. } => field.validate(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `name` in `Field::JsonbField`, the field kind used by every test here.
    fn jsonb(name: &str) -> Field {
        Field::JsonbField(name.to_owned())
    }

    /// `name()` reports the variant identifier for tuple variants.
    #[test]
    fn test_operator_names() {
        let equals = WhereOperator::Eq(jsonb("id"), Value::Number(1.0));
        assert_eq!(equals.name(), "Eq");

        let longer_than = WhereOperator::LenGt(jsonb("tags"), 5);
        assert_eq!(longer_than.name(), "LenGt");
    }

    /// `validate()` is expected to accept the field name `name` and reject
    /// the hyphenated `bad-name`.
    #[test]
    fn test_operator_validation() {
        let accepted = WhereOperator::Eq(jsonb("name"), Value::String("John".to_owned()));
        if let Err(e) = accepted.validate() {
            panic!("expected Ok for valid field 'name': {e}");
        }

        let rejected = WhereOperator::Eq(jsonb("bad-name"), Value::String("John".to_owned()));
        let result = rejected.validate();
        assert!(
            result.is_err(),
            "expected Err for invalid field 'bad-name', got: {result:?}"
        );
    }

    /// A struct-style distance variant can be built and reports its name.
    #[test]
    fn test_vector_operator_creation() {
        let distance = WhereOperator::L2Distance {
            field: jsonb("embedding"),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        };
        assert_eq!(distance.name(), "L2Distance");
    }

    /// A struct-style network variant can be built and reports its name.
    #[test]
    fn test_network_operator_creation() {
        let in_subnet = WhereOperator::InSubnet {
            field: jsonb("ip"),
            subnet: "192.168.0.0/24".to_owned(),
        };
        assert_eq!(in_subnet.name(), "InSubnet");
    }
}