Skip to main content

rig_qdrant/
filter.rs

1use qdrant_client::qdrant::{
2    Condition, FieldCondition, Filter, IsEmptyCondition, IsNullCondition, Match, Range,
3    condition::ConditionOneOf, r#match::MatchValue,
4};
5use rig_core::vector_store::request::{FilterError, SearchFilter};
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8
9/// Qdrant-compatible metadata filter for vector search requests.
10///
11/// Use this as the filter type for [`rig_core::vector_store::request::VectorSearchRequest`]
12/// when querying [`crate::QdrantVectorStore`].
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct QdrantFilter(serde_json::Value);
15
16impl SearchFilter for QdrantFilter {
17    type Value = serde_json::Value;
18
19    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
20        let key = key.as_ref().to_owned();
21
22        Self(json!({
23            "key": key,
24            "match": {
25                "value": value
26            }
27        }))
28    }
29
30    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
31        let key = key.as_ref().to_owned();
32
33        Self(json!({
34            "key": key,
35            "range": {
36                "gt": value
37            }
38        }))
39    }
40
41    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
42        let key = key.as_ref().to_owned();
43
44        Self(json!({
45            "key": key,
46            "range": {
47                "lt": value
48            }
49        }))
50    }
51
52    fn and(self, rhs: Self) -> Self {
53        Self(json!({ "must": [ self.0, rhs.0 ]}))
54    }
55
56    fn or(self, rhs: Self) -> Self {
57        Self(json!({ "should": [ self.0, rhs.0 ]}))
58    }
59}
60
61impl QdrantFilter {
62    #[allow(clippy::should_implement_trait)]
63    pub fn not(self) -> Self {
64        Self(json!({ "must_not": [ self.0 ]}))
65    }
66    pub fn into_inner(self) -> serde_json::Value {
67        self.0
68    }
69
70    pub fn exists(key: String) -> Self {
71        Self(json!({ "key": key, "is_null": { "value": false } }))
72    }
73
74    pub fn is_null(key: String) -> Self {
75        Self(json!({ "key": key, "is_null": { "value": true } }))
76    }
77
78    pub fn is_empty(key: String) -> Self {
79        Self(json!({ "is_empty": { "key": key } }))
80    }
81
82    /// Construct a range filter `(lo .. hi)`
83    pub fn range_exclusive(key: String, lo: serde_json::Value, hi: serde_json::Value) -> Self {
84        Self(json!({
85            "key": key,
86            "range": {
87                "gt": lo,
88                "lt": hi
89            }
90        }))
91    }
92
93    /// Construct a range filter `[lo .. hi)`
94    pub fn range_lower_inclusive(
95        key: String,
96        lo: serde_json::Value,
97        hi: serde_json::Value,
98    ) -> Self {
99        Self(json!({
100            "key": key,
101            "range": {
102                "gt": lo,
103                "lte": hi
104            }
105        }))
106    }
107
108    /// Construct a range filter `(lo .. hi]`
109    pub fn range_higher_inclusive(
110        key: String,
111        lo: serde_json::Value,
112        hi: serde_json::Value,
113    ) -> Self {
114        Self(json!({
115            "key": key,
116            "range": {
117                "gte": lo,
118                "lt": hi
119            }
120        }))
121    }
122
123    /// Construct a range filter `[lo .. hi]`
124    pub fn range_inclusive(key: String, lo: serde_json::Value, hi: serde_json::Value) -> Self {
125        Self(json!({
126            "key": key,
127            "range": {
128                "gte": lo,
129                "lte": hi
130            }
131        }))
132    }
133
134    pub fn interpret(self) -> Result<Option<Filter>, FilterError> {
135        use serde_json::Value::*;
136
137        let value = self.into_inner();
138
139        if let Null = value {
140            Ok(None)
141        } else if json!({}) == value {
142            Ok(None)
143        } else {
144            fn to_match(value: serde_json::Value) -> Result<MatchValue, FilterError> {
145                match value {
146                    String(s) => Ok(MatchValue::Keyword(s)),
147                    Bool(b) => Ok(MatchValue::Boolean(b)),
148                    Number(n) => {
149                        if let Some(as_int) = n.as_i64() {
150                            Ok(MatchValue::Integer(as_int))
151                        } else {
152                            Err(FilterError::Expected {
153                                expected: "Integer".into(),
154                                got: n.to_string(),
155                            })
156                        }
157                    }
158                    _ => Err(FilterError::TypeError(value.to_string())),
159                }
160            }
161
162            fn to_condition(value: serde_json::Value) -> Result<Condition, FilterError> {
163                // Handle is_empty condition
164                if let Some(is_empty) = value.get("is_empty") {
165                    let key = is_empty
166                        .get("key")
167                        .and_then(|k| k.as_str())
168                        .ok_or(FilterError::MissingField("key".into()))?
169                        .to_string();
170
171                    Ok(Condition {
172                        condition_one_of: Some(
173                            qdrant_client::qdrant::condition::ConditionOneOf::IsEmpty(
174                                IsEmptyCondition { key },
175                            ),
176                        ),
177                    })
178                } else if let Some(is_null) = value.get("is_null") {
179                    let is_null_value =
180                        is_null
181                            .get("value")
182                            .and_then(|v| v.as_bool())
183                            .ok_or(FilterError::Must(
184                                "is_null".into(),
185                                "have a 'value' field".into(),
186                            ))?;
187
188                    // Get the key from the parent object
189                    let key = value
190                        .get("key")
191                        .and_then(|k| k.as_str())
192                        .ok_or(FilterError::Must(
193                            "is_null".into(),
194                            "have a 'key' field".into(),
195                        ))?
196                        .to_string();
197
198                    if is_null_value {
199                        Ok(Condition {
200                            condition_one_of: Some(
201                                qdrant_client::qdrant::condition::ConditionOneOf::IsNull(
202                                    IsNullCondition { key },
203                                ),
204                            ),
205                        })
206                    } else {
207                        let is_empty_condition = Condition {
208                            condition_one_of: Some(ConditionOneOf::IsEmpty(IsEmptyCondition {
209                                key,
210                            })),
211                        };
212
213                        let filter = Filter {
214                            must_not: vec![is_empty_condition],
215                            ..Default::default()
216                        };
217
218                        Ok(Condition {
219                            condition_one_of: Some(ConditionOneOf::Filter(filter)),
220                        })
221                    }
222                } else if value
223                    .as_object()
224                    .map(|o| {
225                        o.contains_key("must")
226                            || o.contains_key("must_not")
227                            || o.contains_key("should")
228                    })
229                    .unwrap_or(false)
230                {
231                    let filter = QdrantFilter(value).interpret()?;
232
233                    Ok(Condition {
234                        condition_one_of: filter.map(ConditionOneOf::Filter),
235                    })
236                } else if let Some(key) = value.get("key").and_then(|k| k.as_str()) {
237                    let mut field_condition = FieldCondition {
238                        key: key.to_string(),
239                        ..Default::default()
240                    };
241
242                    // Handle match condition
243                    if let Some(match_obj) = value.get("match")
244                        && let Some(val) = match_obj.get("value")
245                    {
246                        field_condition.r#match = Some(Match {
247                            match_value: Some(to_match(val.clone())?),
248                        });
249                    }
250
251                    // Handle range condition
252                    if let Some(range_obj) = value.get("range") {
253                        let mut range = Range::default();
254
255                        if let Some(gt) = range_obj.get("gt") {
256                            range.gt = gt.as_f64();
257                        }
258                        if let Some(gte) = range_obj.get("gte") {
259                            range.gte = gte.as_f64();
260                        }
261                        if let Some(lt) = range_obj.get("lt") {
262                            range.lt = lt.as_f64();
263                        }
264                        if let Some(lte) = range_obj.get("lte") {
265                            range.lte = lte.as_f64();
266                        }
267
268                        field_condition.range = Some(range);
269                    }
270
271                    Ok(Condition {
272                        condition_one_of: Some(
273                            qdrant_client::qdrant::condition::ConditionOneOf::Field(
274                                field_condition,
275                            ),
276                        ),
277                    })
278                } else {
279                    Err(FilterError::TypeError(value.to_string()))
280                }
281            }
282
283            fn to_filter(value: serde_json::Value) -> Result<Option<Filter>, FilterError> {
284                let mut filter = Filter::default();
285
286                if value.get("key").or(value.get("is_empty")).is_some() {
287                    let condition = to_condition(value)?;
288                    filter.must.push(condition);
289                    Ok(Some(filter))
290                } else {
291                    if let Some(must) = value.get("must")
292                        && let Some(arr) = must.as_array()
293                    {
294                        let conditions: Vec<Condition> = arr
295                            .iter()
296                            .cloned()
297                            .map(to_condition)
298                            .collect::<Result<_, _>>()?;
299                        filter.must.extend(conditions)
300                    }
301
302                    if let Some(should) = value.get("should")
303                        && let Some(arr) = should.as_array()
304                    {
305                        let conditions: Vec<Condition> = arr
306                            .iter()
307                            .cloned()
308                            .map(to_condition)
309                            .collect::<Result<_, _>>()?;
310                        filter.should.extend(conditions)
311                    }
312
313                    if let Some(must_not) = value.get("must_not")
314                        && let Some(arr) = must_not.as_array()
315                    {
316                        let conditions: Vec<Condition> = arr
317                            .iter()
318                            .cloned()
319                            .map(to_condition)
320                            .collect::<Result<_, _>>()?;
321                        filter.must_not.extend(conditions)
322                    }
323
324                    if filter.must.is_empty()
325                        && filter.should.is_empty()
326                        && filter.must_not.is_empty()
327                    {
328                        Ok(None)
329                    } else {
330                        Ok(Some(filter))
331                    }
332                }
333            }
334
335            to_filter(value)
336        }
337    }
338}