1use serde::{Deserialize, Serialize};
2
3use super::VectorStoreError;
4
5#[derive(Clone, Serialize, Deserialize, Debug)]
7pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
8 query: String,
9 query_vector_name: Option<String>,
10 samples: u64,
11 threshold: Option<f64>,
12 additional_params: Option<serde_json::Value>,
13 filter: Option<F>,
14}
15
16impl<Filter> VectorSearchRequest<Filter> {
17 pub fn builder() -> VectorSearchRequestBuilder<Filter> {
18 VectorSearchRequestBuilder::<Filter>::default()
19 }
20
21 pub fn query(&self) -> &str {
22 &self.query
23 }
24
25 pub fn query_vector_name(&self) -> Option<&str> {
26 self.query_vector_name.as_deref()
27 }
28
29 pub fn samples(&self) -> u64 {
30 self.samples
31 }
32
33 pub fn threshold(&self) -> Option<f64> {
34 self.threshold
35 }
36
37 pub fn filter(&self) -> &Option<Filter> {
38 &self.filter
39 }
40
41 pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
42 where
43 F: Fn(Filter) -> T,
44 {
45 VectorSearchRequest {
46 query: self.query,
47 query_vector_name: self.query_vector_name,
48 samples: self.samples,
49 threshold: self.threshold,
50 additional_params: self.additional_params,
51 filter: self.filter.map(f),
52 }
53 }
54}
55
56#[derive(Debug, Clone, thiserror::Error)]
57pub enum FilterError {
58 #[error("Expected: {expected}, got: {got}")]
59 Expected { expected: String, got: String },
60 #[error("Cannot compile '{0}' to the backend's filter type")]
61 TypeError(String),
62 #[error("Missing field '{0}'")]
63 MissingField(String),
64 #[error("'{0}' must {1}")]
65 Must(String, String),
66 #[error("Filter serialization failed: {0}")]
67 Serialization(String),
68}
69
/// Constructors for a filter expression language.
///
/// [`Filter::interpret`] rebuilds a [`Filter`] tree by calling these methods,
/// so implementing this trait for a backend's native filter type is enough to
/// translate any [`Filter`] with the same value type into it.
pub trait SearchFilter {
    /// Type of the operand that keys are compared against.
    type Value;

    /// Builds a filter intended to match when `key` equals `value`.
    fn eq(key: String, value: Self::Value) -> Self;

    /// Builds a filter intended to match when `key` is greater than `value`.
    fn gt(key: String, value: Self::Value) -> Self;

    /// Builds a filter intended to match when `key` is less than `value`.
    fn lt(key: String, value: Self::Value) -> Self;

    /// Conjunction: intended to match only when both `self` and `other` match.
    fn and(self, other: Self) -> Self;

    /// Disjunction: intended to match when either `self` or `other` matches.
    fn or(self, other: Self) -> Self;
}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81#[serde(rename_all = "lowercase")]
82pub enum Filter<V>
83where
84 V: std::fmt::Debug + Clone,
85{
86 Eq(String, V),
87 Gt(String, V),
88 Lt(String, V),
89 And(Box<Self>, Box<Self>),
90 Or(Box<Self>, Box<Self>),
91}
92
93impl<V> SearchFilter for Filter<V>
94where
95 V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
96{
97 type Value = V;
98
99 fn eq(key: String, value: Self::Value) -> Self {
100 Self::Eq(key, value)
101 }
102
103 fn gt(key: String, value: Self::Value) -> Self {
104 Self::Gt(key, value)
105 }
106
107 fn lt(key: String, value: Self::Value) -> Self {
108 Self::Lt(key, value)
109 }
110
111 fn and(self, rhs: Self) -> Self {
112 Self::And(self.into(), rhs.into())
113 }
114
115 fn or(self, rhs: Self) -> Self {
116 Self::Or(self.into(), rhs.into())
117 }
118}
119
120impl<V> Filter<V>
121where
122 V: std::fmt::Debug + Clone,
123{
124 pub fn interpret<F>(self) -> F
125 where
126 F: SearchFilter<Value = V>,
127 {
128 match self {
129 Self::Eq(key, val) => F::eq(key, val),
130 Self::Gt(key, val) => F::gt(key, val),
131 Self::Lt(key, val) => F::lt(key, val),
132 Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
133 Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
134 }
135 }
136}
137
138impl Filter<serde_json::Value> {
139 pub fn satisfies(&self, value: &serde_json::Value) -> bool {
140 use Filter::*;
141 use serde_json::{Value, Value::*, json};
142 use std::cmp::Ordering;
143
144 fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
145 match (l, r) {
146 (Number(l), Number(r)) => l
147 .as_f64()
148 .zip(r.as_f64())
149 .and_then(|(l, r)| l.partial_cmp(&r))
150 .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
151 .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
152 (String(l), String(r)) => Some(l.cmp(r)),
153 (Null, Null) => Some(std::cmp::Ordering::Equal),
154 (Bool(l), Bool(r)) => Some(l.cmp(r)),
155 _ => None,
156 }
157 }
158
159 match self {
160 Eq(k, v) => &json!({ k: v }) == value,
161 Gt(k, v) => {
162 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
163 }
164 Lt(k, v) => {
165 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
166 }
167 And(l, r) => l.satisfies(value) && r.satisfies(value),
168 Or(l, r) => l.satisfies(value) || r.satisfies(value),
169 }
170 }
171}
172
173#[derive(Clone, Serialize, Deserialize, Debug)]
174pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
175 query: Option<String>,
176 query_vector_name: Option<String>,
177 samples: Option<u64>,
178 threshold: Option<f64>,
179 additional_params: Option<serde_json::Value>,
180 filter: Option<F>,
181}
182
183impl<F> Default for VectorSearchRequestBuilder<F> {
184 fn default() -> Self {
185 Self {
186 query: None,
187 query_vector_name: None,
188 samples: None,
189 threshold: None,
190 additional_params: None,
191 filter: None,
192 }
193 }
194}
195
196impl<F> VectorSearchRequestBuilder<F>
197where
198 F: SearchFilter,
199{
200 pub fn query<T>(mut self, query: T) -> Self
201 where
202 T: Into<String>,
203 {
204 self.query = Some(query.into());
205 self
206 }
207
208 pub fn samples(mut self, samples: u64) -> Self {
209 self.samples = Some(samples);
210 self
211 }
212
213 pub fn query_vector_name<T>(mut self, name: T) -> Self
214 where
215 T: Into<String>,
216 {
217 self.query_vector_name = Some(name.into());
218 self
219 }
220
221 pub fn threshold(mut self, threshold: f64) -> Self {
222 self.threshold = Some(threshold);
223 self
224 }
225
226 pub fn additional_params(
227 mut self,
228 params: serde_json::Value,
229 ) -> Result<Self, VectorStoreError> {
230 self.additional_params = Some(params);
231 Ok(self)
232 }
233
234 pub fn filter(mut self, filter: F) -> Self {
235 self.filter = Some(filter);
236 self
237 }
238
239 pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
240 let Some(query) = self.query else {
241 return Err(VectorStoreError::BuilderError(
242 "`query` is a required variable for building a vector search request".into(),
243 ));
244 };
245
246 let Some(samples) = self.samples else {
247 return Err(VectorStoreError::BuilderError(
248 "`samples` is a required variable for building a vector search request".into(),
249 ));
250 };
251
252 let additional_params = if let Some(params) = self.additional_params {
253 if !params.is_object() {
254 return Err(VectorStoreError::BuilderError(
255 "Expected JSON object for additional params, got something else".into(),
256 ));
257 }
258 Some(params)
259 } else {
260 None
261 };
262
263 Ok(VectorSearchRequest {
264 query,
265 query_vector_name: self.query_vector_name,
266 samples,
267 threshold: self.threshold,
268 additional_params,
269 filter: self.filter,
270 })
271 }
272}
273
274#[cfg(test)]
275mod tests {
276 use super::*;
277 use serde_json::json;
278
279 #[test]
280 fn test_builder_missing_query() {
281 let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
282 .samples(10)
283 .build();
284 assert!(result.is_err());
285 assert!(result.unwrap_err().to_string().contains("query"));
286 }
287
288 #[test]
289 fn test_builder_missing_samples() {
290 let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
291 .query("test")
292 .build();
293 assert!(result.is_err());
294 assert!(result.unwrap_err().to_string().contains("samples"));
295 }
296
297 #[test]
298 fn test_builder_non_object_additional_params() {
299 let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
300 .query("test")
301 .samples(5)
302 .additional_params(json!("not an object"))
303 .unwrap()
304 .build();
305 assert!(result.is_err());
306 assert!(result.unwrap_err().to_string().contains("JSON object"));
307 }
308
309 #[test]
310 fn test_builder_success_with_all_options() {
311 let filter = Filter::eq("color".to_string(), json!("red"));
312 let result = VectorSearchRequest::builder()
313 .query("search query")
314 .samples(10)
315 .threshold(0.8)
316 .additional_params(json!({"key": "value"}))
317 .unwrap()
318 .filter(filter)
319 .build();
320 assert!(result.is_ok());
321 let req = result.unwrap();
322 assert_eq!(req.query(), "search query");
323 assert_eq!(req.samples(), 10);
324 assert_eq!(req.threshold(), Some(0.8));
325 assert!(req.filter().is_some());
326 }
327
328 #[test]
329 fn test_builder_minimal_success() {
330 let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
331 .query("q")
332 .samples(1)
333 .build();
334 assert!(result.is_ok());
335 let req = result.unwrap();
336 assert_eq!(req.query(), "q");
337 assert_eq!(req.samples(), 1);
338 assert_eq!(req.threshold(), None);
339 assert!(req.filter().is_none());
340 }
341
342 #[test]
343 fn test_filter_constructors() {
344 let eq: Filter<serde_json::Value> = SearchFilter::eq("k".to_string(), json!("v"));
345 assert!(matches!(eq, Filter::Eq(_, _)));
346
347 let gt: Filter<serde_json::Value> = SearchFilter::gt("k".to_string(), json!(10));
348 assert!(matches!(gt, Filter::Gt(_, _)));
349
350 let lt: Filter<serde_json::Value> = SearchFilter::lt("k".to_string(), json!(5));
351 assert!(matches!(lt, Filter::Lt(_, _)));
352 }
353
354 #[test]
355 fn test_filter_and_or() {
356 let f1: Filter<serde_json::Value> = SearchFilter::eq("a".to_string(), json!(1));
357 let f2: Filter<serde_json::Value> = SearchFilter::eq("b".to_string(), json!(2));
358 let combined = SearchFilter::and(f1, f2);
359 assert!(matches!(combined, Filter::And(_, _)));
360
361 let f3: Filter<serde_json::Value> = SearchFilter::eq("c".to_string(), json!(3));
362 let f4: Filter<serde_json::Value> = SearchFilter::eq("d".to_string(), json!(4));
363 let either = SearchFilter::or(f3, f4);
364 assert!(matches!(either, Filter::Or(_, _)));
365 }
366
367 #[test]
368 fn test_filter_satisfies_eq_match() {
369 let filter = Filter::Eq("color".to_string(), json!("red"));
370 assert!(filter.satisfies(&json!({"color": "red"})));
371 }
372
373 #[test]
374 fn test_filter_satisfies_eq_mismatch() {
375 let filter = Filter::Eq("color".to_string(), json!("red"));
376 assert!(!filter.satisfies(&json!({"color": "blue"})));
377 }
378
379 #[test]
380 fn test_filter_satisfies_and() {
381 let f = Filter::And(
382 Box::new(Filter::Eq("a".to_string(), json!(1))),
383 Box::new(Filter::Eq("b".to_string(), json!(2))),
384 );
385 assert!(!f.satisfies(&json!({"a": 1})));
388 }
389
390 #[test]
391 fn test_filter_satisfies_or() {
392 let f = Filter::Or(
393 Box::new(Filter::Eq("a".to_string(), json!(1))),
394 Box::new(Filter::Eq("b".to_string(), json!(2))),
395 );
396 assert!(f.satisfies(&json!({"a": 1})));
397 assert!(f.satisfies(&json!({"b": 2})));
398 assert!(!f.satisfies(&json!({"c": 3})));
399 }
400
401 #[test]
402 fn test_filter_satisfies_gt_and_lt() {
403 let gt = Filter::Gt("score".to_string(), json!(5));
404 assert!(!gt.satisfies(&json!({"score": 3})));
405 assert!(!gt.satisfies(&json!({"score": 7})));
406
407 let lt = Filter::Lt("score".to_string(), json!(5));
408 assert!(!lt.satisfies(&json!({"score": 7})));
409 assert!(!lt.satisfies(&json!({"score": 3})));
410 }
411
412 #[test]
413 fn test_filter_satisfies_incompatible_types() {
414 let gt = Filter::Gt("score".to_string(), json!("high"));
415 assert!(!gt.satisfies(&json!({"score": 3})));
416 }
417
418 #[test]
419 fn test_filter_interpret_roundtrip() {
420 let original: Filter<serde_json::Value> = Filter::Eq("key".to_string(), json!("value"));
421 let interpreted: Filter<serde_json::Value> = original.interpret();
422 assert!(matches!(interpreted, Filter::Eq(ref k, _) if k == "key"));
423 }
424
425 #[test]
426 fn test_filter_interpret_compound() {
427 let f: Filter<serde_json::Value> = Filter::And(
428 Box::new(Filter::Gt("x".to_string(), json!(10))),
429 Box::new(Filter::Lt("y".to_string(), json!(20))),
430 );
431 let interpreted: Filter<serde_json::Value> = f.interpret();
432 assert!(matches!(interpreted, Filter::And(_, _)));
433 }
434
435 #[test]
436 fn test_map_filter() {
437 let req = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
438 .query("q")
439 .samples(5)
440 .filter(Filter::Eq("k".to_string(), json!("v")))
441 .build()
442 .unwrap();
443
444 let mapped = req.map_filter(|f| format!("{f:?}"));
445 assert_eq!(mapped.query(), "q");
446 assert_eq!(mapped.samples(), 5);
447 assert!(mapped.filter().is_some());
448 }
449
450 #[test]
451 fn test_filter_serialize_deserialize() {
452 let filter: Filter<serde_json::Value> = Filter::Eq("name".to_string(), json!("test"));
453 let json = serde_json::to_string(&filter).unwrap();
454 let deserialized: Filter<serde_json::Value> = serde_json::from_str(&json).unwrap();
455 assert!(matches!(deserialized, Filter::Eq(ref k, _) if k == "name"));
456 }
457}