1use qdrant_client::qdrant::{
2 Condition, FieldCondition, Filter, IsEmptyCondition, IsNullCondition, Match, Range,
3 condition::ConditionOneOf, r#match::MatchValue,
4};
5use rig_core::vector_store::request::{FilterError, SearchFilter};
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct QdrantFilter(serde_json::Value);
15
16impl SearchFilter for QdrantFilter {
17 type Value = serde_json::Value;
18
19 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
20 let key = key.as_ref().to_owned();
21
22 Self(json!({
23 "key": key,
24 "match": {
25 "value": value
26 }
27 }))
28 }
29
30 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
31 let key = key.as_ref().to_owned();
32
33 Self(json!({
34 "key": key,
35 "range": {
36 "gt": value
37 }
38 }))
39 }
40
41 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
42 let key = key.as_ref().to_owned();
43
44 Self(json!({
45 "key": key,
46 "range": {
47 "lt": value
48 }
49 }))
50 }
51
52 fn and(self, rhs: Self) -> Self {
53 Self(json!({ "must": [ self.0, rhs.0 ]}))
54 }
55
56 fn or(self, rhs: Self) -> Self {
57 Self(json!({ "should": [ self.0, rhs.0 ]}))
58 }
59}
60
61impl QdrantFilter {
62 #[allow(clippy::should_implement_trait)]
63 pub fn not(self) -> Self {
64 Self(json!({ "must_not": [ self.0 ]}))
65 }
66 pub fn into_inner(self) -> serde_json::Value {
67 self.0
68 }
69
70 pub fn exists(key: String) -> Self {
71 Self(json!({ "key": key, "is_null": { "value": false } }))
72 }
73
74 pub fn is_null(key: String) -> Self {
75 Self(json!({ "key": key, "is_null": { "value": true } }))
76 }
77
78 pub fn is_empty(key: String) -> Self {
79 Self(json!({ "is_empty": { "key": key } }))
80 }
81
82 pub fn range_exclusive(key: String, lo: serde_json::Value, hi: serde_json::Value) -> Self {
84 Self(json!({
85 "key": key,
86 "range": {
87 "gt": lo,
88 "lt": hi
89 }
90 }))
91 }
92
93 pub fn range_lower_inclusive(
95 key: String,
96 lo: serde_json::Value,
97 hi: serde_json::Value,
98 ) -> Self {
99 Self(json!({
100 "key": key,
101 "range": {
102 "gt": lo,
103 "lte": hi
104 }
105 }))
106 }
107
108 pub fn range_higher_inclusive(
110 key: String,
111 lo: serde_json::Value,
112 hi: serde_json::Value,
113 ) -> Self {
114 Self(json!({
115 "key": key,
116 "range": {
117 "gte": lo,
118 "lt": hi
119 }
120 }))
121 }
122
123 pub fn range_inclusive(key: String, lo: serde_json::Value, hi: serde_json::Value) -> Self {
125 Self(json!({
126 "key": key,
127 "range": {
128 "gte": lo,
129 "lte": hi
130 }
131 }))
132 }
133
134 pub fn interpret(self) -> Result<Option<Filter>, FilterError> {
135 use serde_json::Value::*;
136
137 let value = self.into_inner();
138
139 if let Null = value {
140 Ok(None)
141 } else if json!({}) == value {
142 Ok(None)
143 } else {
144 fn to_match(value: serde_json::Value) -> Result<MatchValue, FilterError> {
145 match value {
146 String(s) => Ok(MatchValue::Keyword(s)),
147 Bool(b) => Ok(MatchValue::Boolean(b)),
148 Number(n) => {
149 if let Some(as_int) = n.as_i64() {
150 Ok(MatchValue::Integer(as_int))
151 } else {
152 Err(FilterError::Expected {
153 expected: "Integer".into(),
154 got: n.to_string(),
155 })
156 }
157 }
158 _ => Err(FilterError::TypeError(value.to_string())),
159 }
160 }
161
162 fn to_condition(value: serde_json::Value) -> Result<Condition, FilterError> {
163 if let Some(is_empty) = value.get("is_empty") {
165 let key = is_empty
166 .get("key")
167 .and_then(|k| k.as_str())
168 .ok_or(FilterError::MissingField("key".into()))?
169 .to_string();
170
171 Ok(Condition {
172 condition_one_of: Some(
173 qdrant_client::qdrant::condition::ConditionOneOf::IsEmpty(
174 IsEmptyCondition { key },
175 ),
176 ),
177 })
178 } else if let Some(is_null) = value.get("is_null") {
179 let is_null_value =
180 is_null
181 .get("value")
182 .and_then(|v| v.as_bool())
183 .ok_or(FilterError::Must(
184 "is_null".into(),
185 "have a 'value' field".into(),
186 ))?;
187
188 let key = value
190 .get("key")
191 .and_then(|k| k.as_str())
192 .ok_or(FilterError::Must(
193 "is_null".into(),
194 "have a 'key' field".into(),
195 ))?
196 .to_string();
197
198 if is_null_value {
199 Ok(Condition {
200 condition_one_of: Some(
201 qdrant_client::qdrant::condition::ConditionOneOf::IsNull(
202 IsNullCondition { key },
203 ),
204 ),
205 })
206 } else {
207 let is_empty_condition = Condition {
208 condition_one_of: Some(ConditionOneOf::IsEmpty(IsEmptyCondition {
209 key,
210 })),
211 };
212
213 let filter = Filter {
214 must_not: vec![is_empty_condition],
215 ..Default::default()
216 };
217
218 Ok(Condition {
219 condition_one_of: Some(ConditionOneOf::Filter(filter)),
220 })
221 }
222 } else if value
223 .as_object()
224 .map(|o| {
225 o.contains_key("must")
226 || o.contains_key("must_not")
227 || o.contains_key("should")
228 })
229 .unwrap_or(false)
230 {
231 let filter = QdrantFilter(value).interpret()?;
232
233 Ok(Condition {
234 condition_one_of: filter.map(ConditionOneOf::Filter),
235 })
236 } else if let Some(key) = value.get("key").and_then(|k| k.as_str()) {
237 let mut field_condition = FieldCondition {
238 key: key.to_string(),
239 ..Default::default()
240 };
241
242 if let Some(match_obj) = value.get("match")
244 && let Some(val) = match_obj.get("value")
245 {
246 field_condition.r#match = Some(Match {
247 match_value: Some(to_match(val.clone())?),
248 });
249 }
250
251 if let Some(range_obj) = value.get("range") {
253 let mut range = Range::default();
254
255 if let Some(gt) = range_obj.get("gt") {
256 range.gt = gt.as_f64();
257 }
258 if let Some(gte) = range_obj.get("gte") {
259 range.gte = gte.as_f64();
260 }
261 if let Some(lt) = range_obj.get("lt") {
262 range.lt = lt.as_f64();
263 }
264 if let Some(lte) = range_obj.get("lte") {
265 range.lte = lte.as_f64();
266 }
267
268 field_condition.range = Some(range);
269 }
270
271 Ok(Condition {
272 condition_one_of: Some(
273 qdrant_client::qdrant::condition::ConditionOneOf::Field(
274 field_condition,
275 ),
276 ),
277 })
278 } else {
279 Err(FilterError::TypeError(value.to_string()))
280 }
281 }
282
283 fn to_filter(value: serde_json::Value) -> Result<Option<Filter>, FilterError> {
284 let mut filter = Filter::default();
285
286 if value.get("key").or(value.get("is_empty")).is_some() {
287 let condition = to_condition(value)?;
288 filter.must.push(condition);
289 Ok(Some(filter))
290 } else {
291 if let Some(must) = value.get("must")
292 && let Some(arr) = must.as_array()
293 {
294 let conditions: Vec<Condition> = arr
295 .iter()
296 .cloned()
297 .map(to_condition)
298 .collect::<Result<_, _>>()?;
299 filter.must.extend(conditions)
300 }
301
302 if let Some(should) = value.get("should")
303 && let Some(arr) = should.as_array()
304 {
305 let conditions: Vec<Condition> = arr
306 .iter()
307 .cloned()
308 .map(to_condition)
309 .collect::<Result<_, _>>()?;
310 filter.should.extend(conditions)
311 }
312
313 if let Some(must_not) = value.get("must_not")
314 && let Some(arr) = must_not.as_array()
315 {
316 let conditions: Vec<Condition> = arr
317 .iter()
318 .cloned()
319 .map(to_condition)
320 .collect::<Result<_, _>>()?;
321 filter.must_not.extend(conditions)
322 }
323
324 if filter.must.is_empty()
325 && filter.should.is_empty()
326 && filter.must_not.is_empty()
327 {
328 Ok(None)
329 } else {
330 Ok(Some(filter))
331 }
332 }
333 }
334
335 to_filter(value)
336 }
337 }
338}