use serde::{Deserialize, Serialize};
/// Compound filter made of up to three optional clause lists.
///
/// Any clause left as `None` is dropped from the serialized JSON entirely
/// (via `skip_serializing_if`). For example, a filter built with
/// [`QdrantFilter::must`] serializes as `{"must": [...]}` alone.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct QdrantFilter {
    /// `must` clause (in Qdrant's filter vocabulary: every condition has to hold).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub must: Option<Vec<QdrantCondition>>,
    /// `should` clause (in Qdrant's filter vocabulary: at least one condition has to hold).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub should: Option<Vec<QdrantCondition>>,
    /// `must_not` clause (in Qdrant's filter vocabulary: no condition may hold).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub must_not: Option<Vec<QdrantCondition>>,
}

impl QdrantFilter {
    /// Builds a filter whose only populated clause is `must`.
    pub fn must(conditions: Vec<QdrantCondition>) -> Self {
        Self {
            must: Some(conditions),
            ..Self::default()
        }
    }

    /// Builds a filter whose only populated clause is `should`.
    pub fn should(conditions: Vec<QdrantCondition>) -> Self {
        Self {
            should: Some(conditions),
            ..Self::default()
        }
    }

    /// Builds a filter whose only populated clause is `must_not`.
    pub fn must_not(conditions: Vec<QdrantCondition>) -> Self {
        Self {
            must_not: Some(conditions),
            ..Self::default()
        }
    }
}
/// A single filter condition.
///
/// Serialized as an internally tagged object. The variant name, converted to
/// `snake_case`, is written to a `"type"` field next to the variant's own
/// fields, e.g.
///
/// ```json
/// {"type": "match", "key": "topic", "match_value": "index"}
/// ```
///
/// NOTE(review): this is not Qdrant's native REST condition shape
/// (`{"key": "topic", "match": {"value": "index"}}`). Presumably another layer
/// translates it; confirm before sending it to Qdrant as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum QdrantCondition {
    /// Compares the payload field `key` against `match_value`.
    Match {
        key: String,
        match_value: QdrantMatchValue,
    },
    /// Numeric bounds on the payload field `key`.
    Range { key: String, range: QdrantRange },
    /// Bounds on the number of values stored under the payload field `key`.
    ValuesCount {
        key: String,
        values_count: QdrantValuesCount,
    },
    /// Geo bounding-box check on the payload field `key`.
    GeoBoundingBox {
        key: String,
        geo_bounding_box: QdrantGeoBoundingBox,
    },
    /// Geo radius check on the payload field `key`.
    GeoRadius {
        key: String,
        geo_radius: QdrantGeoRadius,
    },
    /// A complete sub-filter used as a single condition.
    ///
    /// Boxed because [`QdrantFilter`] itself holds `QdrantCondition`s, which
    /// makes the type recursive.
    Nested { filter: Box<QdrantFilter> },
}

impl QdrantCondition {
    /// `Match` condition comparing `key` against a string value.
    pub fn match_string(key: &str, value: &str) -> Self {
        Self::Match {
            key: key.to_string(),
            match_value: QdrantMatchValue::String(value.to_string()),
        }
    }

    /// `Match` condition comparing `key` against an integer value.
    pub fn match_integer(key: &str, value: i64) -> Self {
        Self::Match {
            key: key.to_string(),
            match_value: QdrantMatchValue::Integer(value),
        }
    }

    /// `Match` condition comparing `key` against a boolean value.
    pub fn match_bool(key: &str, value: bool) -> Self {
        Self::Match {
            key: key.to_string(),
            match_value: QdrantMatchValue::Bool(value),
        }
    }

    /// `Range` condition on `key`.
    pub fn range(key: &str, range: QdrantRange) -> Self {
        Self::Range {
            key: key.to_string(),
            range,
        }
    }

    /// `ValuesCount` condition on `key`.
    pub fn values_count(key: &str, values_count: QdrantValuesCount) -> Self {
        Self::ValuesCount {
            key: key.to_string(),
            values_count,
        }
    }

    /// `GeoBoundingBox` condition on `key`.
    pub fn geo_bounding_box(key: &str, geo_bounding_box: QdrantGeoBoundingBox) -> Self {
        Self::GeoBoundingBox {
            key: key.to_string(),
            geo_bounding_box,
        }
    }

    /// `GeoRadius` condition on `key`.
    pub fn geo_radius(key: &str, geo_radius: QdrantGeoRadius) -> Self {
        Self::GeoRadius {
            key: key.to_string(),
            geo_radius,
        }
    }

    /// `Nested` condition wrapping a complete sub-filter.
    pub fn nested(filter: QdrantFilter) -> Self {
        Self::Nested {
            filter: Box::new(filter),
        }
    }
}
/// The value compared by a [`QdrantCondition::Match`] condition.
///
/// Because of `#[serde(untagged)]`, the value serializes as a bare JSON
/// string, integer or boolean. When deserializing, serde tries the variants in
/// declaration order and keeps the first that fits:
///
/// - a JSON string becomes `String`;
/// - a JSON integer within `i64` range becomes `Integer`;
/// - a JSON boolean becomes `Bool`.
///
/// Floats, out-of-range integers and other JSON types are rejected. Reordering
/// the variants would change which one wins, so keep the order stable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum QdrantMatchValue {
    String(String),
    Integer(i64),
    Bool(bool),
}
/// Numeric bounds for a [`QdrantCondition::Range`] condition.
///
/// Every bound is optional. Unset bounds are omitted from the serialized JSON
/// rather than written as `null`. Combine several bounds with struct-update
/// syntax, e.g. `QdrantRange { gte: Some(0.5), lt: Some(0.9), ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct QdrantRange {
    /// Exclusive lower bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gt: Option<f64>,
    /// Inclusive lower bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gte: Option<f64>,
    /// Exclusive upper bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lt: Option<f64>,
    /// Inclusive upper bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lte: Option<f64>,
}

impl QdrantRange {
    /// Range with only an exclusive lower bound (`> value`).
    pub fn gt(value: f64) -> Self {
        Self {
            gt: Some(value),
            ..Default::default()
        }
    }

    /// Range with only an inclusive lower bound (`>= value`).
    pub fn gte(value: f64) -> Self {
        Self {
            gte: Some(value),
            ..Default::default()
        }
    }

    /// Range with only an exclusive upper bound (`< value`).
    pub fn lt(value: f64) -> Self {
        Self {
            lt: Some(value),
            ..Default::default()
        }
    }

    /// Range with only an inclusive upper bound (`<= value`).
    pub fn lte(value: f64) -> Self {
        Self {
            lte: Some(value),
            ..Default::default()
        }
    }

    /// Closed interval `[min, max]`.
    ///
    /// `min <= max` is not checked; an inverted pair is passed through unchanged.
    pub fn between_inclusive(min: f64, max: f64) -> Self {
        Self {
            gte: Some(min),
            lte: Some(max),
            ..Default::default()
        }
    }
}
/// Bounds on how many values a payload field holds, for a
/// [`QdrantCondition::ValuesCount`] condition.
///
/// Mirrors [`QdrantRange`] but with integer bounds. Every bound is optional,
/// and unset bounds are omitted from the serialized JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct QdrantValuesCount {
    /// Exclusive lower bound on the value count.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gt: Option<u32>,
    /// Inclusive lower bound on the value count.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gte: Option<u32>,
    /// Exclusive upper bound on the value count.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lt: Option<u32>,
    /// Inclusive upper bound on the value count.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lte: Option<u32>,
}

impl QdrantValuesCount {
    /// Count strictly greater than `count`.
    pub fn gt(count: u32) -> Self {
        Self {
            gt: Some(count),
            ..Default::default()
        }
    }

    /// Count greater than or equal to `count`.
    pub fn gte(count: u32) -> Self {
        Self {
            gte: Some(count),
            ..Default::default()
        }
    }

    /// Count strictly less than `count`.
    pub fn lt(count: u32) -> Self {
        Self {
            lt: Some(count),
            ..Default::default()
        }
    }

    /// Count less than or equal to `count`.
    pub fn lte(count: u32) -> Self {
        Self {
            lte: Some(count),
            ..Default::default()
        }
    }

    /// Count within the closed interval `[min, max]`.
    ///
    /// `min <= max` is not checked; an inverted pair is passed through unchanged.
    pub fn between_inclusive(min: u32, max: u32) -> Self {
        Self {
            gte: Some(min),
            lte: Some(max),
            ..Default::default()
        }
    }
}
/// Rectangle described by two opposite corners, for a
/// [`QdrantCondition::GeoBoundingBox`] condition.
///
/// Field order is also the JSON key order: `top_right`, then `bottom_left`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantGeoBoundingBox {
    /// Top-right corner of the box.
    pub top_right: QdrantGeoPoint,
    /// Bottom-left corner of the box.
    pub bottom_left: QdrantGeoPoint,
}

impl QdrantGeoBoundingBox {
    /// Builds a bounding box from its top-right and bottom-left corners, in that order.
    pub fn new(top_right: QdrantGeoPoint, bottom_left: QdrantGeoPoint) -> Self {
        Self {
            top_right,
            bottom_left,
        }
    }
}
/// Circle around `center`, for a [`QdrantCondition::GeoRadius`] condition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantGeoRadius {
    /// Center of the circle.
    pub center: QdrantGeoPoint,
    /// Radius of the circle.
    ///
    /// NOTE(review): Qdrant interprets this in meters; confirm callers pass meters.
    pub radius: f64,
}

impl QdrantGeoRadius {
    /// Builds a radius condition body from its center and radius.
    pub fn new(center: QdrantGeoPoint, radius: f64) -> Self {
        Self { center, radius }
    }
}
/// A geographic coordinate serialized as `{"lat": ..., "lon": ...}`.
///
/// NOTE(review): assumed to be decimal degrees, as Qdrant expects; confirm at call sites.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QdrantGeoPoint {
    /// Latitude.
    pub lat: f64,
    /// Longitude.
    pub lon: f64,
}

impl QdrantGeoPoint {
    /// Builds a point from latitude and longitude.
    ///
    /// Note the argument order: latitude first, then longitude. This is easy
    /// to swap by accident when the source data is `(lon, lat)`.
    pub fn new(lat: f64, lon: f64) -> Self {
        Self { lat, lon }
    }
}
#[cfg(test)]
mod tests {
    // `Value` was previously imported but never named; dropping it silences
    // an `unused_imports` warning in test builds.
    use serde_json::json;

    use super::*;

    #[test]
    fn match_string_wire_shape() {
        let cond = QdrantCondition::match_string("topic", "index");
        let v = serde_json::to_value(&cond).unwrap();
        assert_eq!(v["type"], "match");
        assert_eq!(v["key"], "topic");
        assert_eq!(v["match_value"], "index");
    }

    #[test]
    fn match_integer_wire_shape() {
        let cond = QdrantCondition::match_integer("count", 42);
        let v = serde_json::to_value(&cond).unwrap();
        assert_eq!(v["type"], "match");
        assert_eq!(v["key"], "count");
        assert_eq!(v["match_value"], 42);
    }

    #[test]
    fn match_bool_wire_shape() {
        let cond = QdrantCondition::match_bool("active", true);
        let v = serde_json::to_value(&cond).unwrap();
        assert_eq!(v["type"], "match");
        assert_eq!(v["key"], "active");
        assert_eq!(v["match_value"], true);
    }

    #[test]
    fn range_condition_wire_shape() {
        let cond = QdrantCondition::range(
            "score",
            QdrantRange {
                gte: Some(0.5),
                lt: Some(0.9),
                ..Default::default()
            },
        );
        let v = serde_json::to_value(&cond).unwrap();
        assert_eq!(v["type"], "range");
        assert_eq!(v["key"], "score");
        assert_eq!(v["range"]["gte"], 0.5);
        assert_eq!(v["range"]["lt"], 0.9);
        // `skip_serializing_if` must drop unset bounds entirely. An explicit
        // `null` would be a regression, so assert absence, not "absent or null".
        assert!(v["range"].get("gt").is_none());
        assert!(v["range"].get("lte").is_none());
    }

    #[test]
    fn values_count_wire_shape() {
        let cond = QdrantCondition::values_count(
            "tags",
            QdrantValuesCount {
                gte: Some(2),
                ..Default::default()
            },
        );
        let v = serde_json::to_value(&cond).unwrap();
        assert_eq!(v["type"], "values_count");
        assert_eq!(v["key"], "tags");
        assert_eq!(v["values_count"]["gte"], 2);
        assert!(v["values_count"].get("gt").is_none());
        assert!(v["values_count"].get("lt").is_none());
        assert!(v["values_count"].get("lte").is_none());
    }

    #[test]
    fn geo_conditions_wire_shape() {
        let bbox = QdrantCondition::GeoBoundingBox {
            key: "location".to_string(),
            geo_bounding_box: QdrantGeoBoundingBox {
                top_right: QdrantGeoPoint::new(52.6, 13.5),
                bottom_left: QdrantGeoPoint::new(52.4, 13.3),
            },
        };
        let v = serde_json::to_value(&bbox).unwrap();
        assert_eq!(v["type"], "geo_bounding_box");
        assert_eq!(v["key"], "location");
        assert_eq!(v["geo_bounding_box"]["top_right"]["lat"], 52.6);
        assert_eq!(v["geo_bounding_box"]["bottom_left"]["lon"], 13.3);

        let radius = QdrantCondition::GeoRadius {
            key: "location".to_string(),
            geo_radius: QdrantGeoRadius {
                center: QdrantGeoPoint::new(52.52, 13.405),
                radius: 1000.0,
            },
        };
        let v = serde_json::to_value(&radius).unwrap();
        assert_eq!(v["type"], "geo_radius");
        assert_eq!(v["geo_radius"]["center"]["lat"], 52.52);
        assert_eq!(v["geo_radius"]["center"]["lon"], 13.405);
        assert_eq!(v["geo_radius"]["radius"], 1000.0);
    }

    #[test]
    fn filter_must_omits_absent_clauses() {
        let filter = QdrantFilter::must(vec![QdrantCondition::match_string("topic", "index")]);
        let v = serde_json::to_value(&filter).unwrap();
        assert!(v.get("must").is_some());
        assert!(v.get("should").is_none());
        assert!(v.get("must_not").is_none());
    }

    #[test]
    fn compound_and_filter_round_trips() {
        let filter = QdrantFilter::must(vec![
            QdrantCondition::match_string("tier", "hot"),
            QdrantCondition::range("score", QdrantRange::gte(0.8)),
        ]);
        let serialised = serde_json::to_string(&filter).unwrap();
        let deserialised: QdrantFilter = serde_json::from_str(&serialised).unwrap();
        let must = deserialised.must.unwrap();
        assert_eq!(must.len(), 2);
        match &must[0] {
            QdrantCondition::Match {
                key,
                match_value: QdrantMatchValue::String(v),
            } => {
                assert_eq!(key, "tier");
                assert_eq!(v, "hot");
            }
            other => panic!("unexpected condition: {other:?}"),
        }
        match &must[1] {
            QdrantCondition::Range { key, range } => {
                assert_eq!(key, "score");
                assert_eq!(range.gte, Some(0.8));
            }
            other => panic!("unexpected condition: {other:?}"),
        }
    }

    #[test]
    fn match_values_round_trip_to_original_variant() {
        // `QdrantMatchValue` is untagged, so deserialization relies on variant
        // order. A numeric-looking string must stay a string, and a negative
        // integer and a boolean must not be mistaken for each other.
        let filter = QdrantFilter::must(vec![
            QdrantCondition::match_string("s", "42"),
            QdrantCondition::match_integer("i", -7),
            QdrantCondition::match_bool("b", false),
        ]);
        let serialised = serde_json::to_string(&filter).unwrap();
        let deserialised: QdrantFilter = serde_json::from_str(&serialised).unwrap();
        let must = deserialised.must.expect("must clause survives round trip");
        assert_eq!(must.len(), 3);
        assert!(matches!(
            &must[0],
            QdrantCondition::Match { key, match_value: QdrantMatchValue::String(s) }
                if key == "s" && s == "42"
        ));
        assert!(matches!(
            &must[1],
            QdrantCondition::Match { key, match_value: QdrantMatchValue::Integer(-7) }
                if key == "i"
        ));
        assert!(matches!(
            &must[2],
            QdrantCondition::Match { key, match_value: QdrantMatchValue::Bool(false) }
                if key == "b"
        ));
    }

    #[test]
    fn nested_filter_wire_shape() {
        let inner = QdrantFilter::must(vec![QdrantCondition::match_string("inner_key", "value")]);
        let outer = QdrantFilter::must(vec![QdrantCondition::nested(inner)]);
        let v = serde_json::to_value(&outer).unwrap();
        let nested_cond = &v["must"][0];
        assert_eq!(nested_cond["type"], "nested");
        assert!(nested_cond.get("filter").is_some());
        assert_eq!(nested_cond["filter"]["must"][0]["key"], "inner_key");
    }

    #[test]
    fn nested_filter_round_trips() {
        let inner = QdrantFilter::should(vec![QdrantCondition::match_integer("n", 1)]);
        let outer = QdrantFilter::must_not(vec![QdrantCondition::nested(inner)]);
        let serialised = serde_json::to_string(&outer).unwrap();
        let deserialised: QdrantFilter = serde_json::from_str(&serialised).unwrap();
        assert!(deserialised.must.is_none());
        assert!(deserialised.should.is_none());
        let must_not = deserialised.must_not.expect("must_not clause survives round trip");
        assert_eq!(must_not.len(), 1);
        match &must_not[0] {
            QdrantCondition::Nested { filter } => {
                assert!(filter.must.is_none());
                let should = filter.should.as_ref().expect("inner should clause");
                assert!(matches!(
                    &should[0],
                    QdrantCondition::Match { key, match_value: QdrantMatchValue::Integer(1) }
                        if key == "n"
                ));
            }
            other => panic!("unexpected condition: {other:?}"),
        }
    }

    #[test]
    fn filter_wrapped_in_request_body() {
        let filter = QdrantFilter::must(vec![QdrantCondition::match_string("topic", "index")]);
        let body = json!({ "filter": filter });
        let filter_obj = &body["filter"];
        assert!(filter_obj.get("must").is_some());
        assert_eq!(filter_obj["must"][0]["type"], "match");
        assert_eq!(filter_obj["must"][0]["key"], "topic");
        assert_eq!(filter_obj["must"][0]["match_value"], "index");
    }
}