Skip to main content

autoagents_core/vector_store/
request.rs

1use serde::{Deserialize, Serialize};
2
3use super::VectorStoreError;
4
5/// A vector search request - used in the [`super::VectorStoreIndex`] trait.
6#[derive(Clone, Serialize, Deserialize, Debug)]
7pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
8    query: String,
9    query_vector_name: Option<String>,
10    samples: u64,
11    threshold: Option<f64>,
12    additional_params: Option<serde_json::Value>,
13    filter: Option<F>,
14}
15
16impl<Filter> VectorSearchRequest<Filter> {
17    pub fn builder() -> VectorSearchRequestBuilder<Filter> {
18        VectorSearchRequestBuilder::<Filter>::default()
19    }
20
21    pub fn query(&self) -> &str {
22        &self.query
23    }
24
25    pub fn query_vector_name(&self) -> Option<&str> {
26        self.query_vector_name.as_deref()
27    }
28
29    pub fn samples(&self) -> u64 {
30        self.samples
31    }
32
33    pub fn threshold(&self) -> Option<f64> {
34        self.threshold
35    }
36
37    pub fn filter(&self) -> &Option<Filter> {
38        &self.filter
39    }
40
41    pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
42    where
43        F: Fn(Filter) -> T,
44    {
45        VectorSearchRequest {
46            query: self.query,
47            query_vector_name: self.query_vector_name,
48            samples: self.samples,
49            threshold: self.threshold,
50            additional_params: self.additional_params,
51            filter: self.filter.map(f),
52        }
53    }
54}
55
56#[derive(Debug, Clone, thiserror::Error)]
57pub enum FilterError {
58    #[error("Expected: {expected}, got: {got}")]
59    Expected { expected: String, got: String },
60    #[error("Cannot compile '{0}' to the backend's filter type")]
61    TypeError(String),
62    #[error("Missing field '{0}'")]
63    MissingField(String),
64    #[error("'{0}' must {1}")]
65    Must(String, String),
66    #[error("Filter serialization failed: {0}")]
67    Serialization(String),
68}
69
/// Constructor vocabulary for a filter language.
///
/// A type implementing this trait can be the target of [`Filter::interpret`], which rebuilds a
/// [`Filter`] tree out of these constructors.
pub trait SearchFilter {
    /// Type of the operand each comparison is made against.
    type Value;

    /// Builds an equality condition on `key`.
    fn eq(key: String, value: Self::Value) -> Self;

    /// Builds a greater-than condition on `key`.
    fn gt(key: String, value: Self::Value) -> Self;

    /// Builds a less-than condition on `key`.
    fn lt(key: String, value: Self::Value) -> Self;

    /// Conjunction of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;

    /// Disjunction of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81#[serde(rename_all = "lowercase")]
82pub enum Filter<V>
83where
84    V: std::fmt::Debug + Clone,
85{
86    Eq(String, V),
87    Gt(String, V),
88    Lt(String, V),
89    And(Box<Self>, Box<Self>),
90    Or(Box<Self>, Box<Self>),
91}
92
93impl<V> SearchFilter for Filter<V>
94where
95    V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
96{
97    type Value = V;
98
99    fn eq(key: String, value: Self::Value) -> Self {
100        Self::Eq(key, value)
101    }
102
103    fn gt(key: String, value: Self::Value) -> Self {
104        Self::Gt(key, value)
105    }
106
107    fn lt(key: String, value: Self::Value) -> Self {
108        Self::Lt(key, value)
109    }
110
111    fn and(self, rhs: Self) -> Self {
112        Self::And(self.into(), rhs.into())
113    }
114
115    fn or(self, rhs: Self) -> Self {
116        Self::Or(self.into(), rhs.into())
117    }
118}
119
120impl<V> Filter<V>
121where
122    V: std::fmt::Debug + Clone,
123{
124    pub fn interpret<F>(self) -> F
125    where
126        F: SearchFilter<Value = V>,
127    {
128        match self {
129            Self::Eq(key, val) => F::eq(key, val),
130            Self::Gt(key, val) => F::gt(key, val),
131            Self::Lt(key, val) => F::lt(key, val),
132            Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
133            Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
134        }
135    }
136}
137
138impl Filter<serde_json::Value> {
139    pub fn satisfies(&self, value: &serde_json::Value) -> bool {
140        use Filter::*;
141        use serde_json::{Value, Value::*, json};
142        use std::cmp::Ordering;
143
144        fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
145            match (l, r) {
146                (Number(l), Number(r)) => l
147                    .as_f64()
148                    .zip(r.as_f64())
149                    .and_then(|(l, r)| l.partial_cmp(&r))
150                    .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
151                    .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
152                (String(l), String(r)) => Some(l.cmp(r)),
153                (Null, Null) => Some(std::cmp::Ordering::Equal),
154                (Bool(l), Bool(r)) => Some(l.cmp(r)),
155                _ => None,
156            }
157        }
158
159        match self {
160            Eq(k, v) => &json!({ k: v }) == value,
161            Gt(k, v) => {
162                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
163            }
164            Lt(k, v) => {
165                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
166            }
167            And(l, r) => l.satisfies(value) && r.satisfies(value),
168            Or(l, r) => l.satisfies(value) || r.satisfies(value),
169        }
170    }
171}
172
173#[derive(Clone, Serialize, Deserialize, Debug)]
174pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
175    query: Option<String>,
176    query_vector_name: Option<String>,
177    samples: Option<u64>,
178    threshold: Option<f64>,
179    additional_params: Option<serde_json::Value>,
180    filter: Option<F>,
181}
182
183impl<F> Default for VectorSearchRequestBuilder<F> {
184    fn default() -> Self {
185        Self {
186            query: None,
187            query_vector_name: None,
188            samples: None,
189            threshold: None,
190            additional_params: None,
191            filter: None,
192        }
193    }
194}
195
196impl<F> VectorSearchRequestBuilder<F>
197where
198    F: SearchFilter,
199{
200    pub fn query<T>(mut self, query: T) -> Self
201    where
202        T: Into<String>,
203    {
204        self.query = Some(query.into());
205        self
206    }
207
208    pub fn samples(mut self, samples: u64) -> Self {
209        self.samples = Some(samples);
210        self
211    }
212
213    pub fn query_vector_name<T>(mut self, name: T) -> Self
214    where
215        T: Into<String>,
216    {
217        self.query_vector_name = Some(name.into());
218        self
219    }
220
221    pub fn threshold(mut self, threshold: f64) -> Self {
222        self.threshold = Some(threshold);
223        self
224    }
225
226    pub fn additional_params(
227        mut self,
228        params: serde_json::Value,
229    ) -> Result<Self, VectorStoreError> {
230        self.additional_params = Some(params);
231        Ok(self)
232    }
233
234    pub fn filter(mut self, filter: F) -> Self {
235        self.filter = Some(filter);
236        self
237    }
238
239    pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
240        let Some(query) = self.query else {
241            return Err(VectorStoreError::BuilderError(
242                "`query` is a required variable for building a vector search request".into(),
243            ));
244        };
245
246        let Some(samples) = self.samples else {
247            return Err(VectorStoreError::BuilderError(
248                "`samples` is a required variable for building a vector search request".into(),
249            ));
250        };
251
252        let additional_params = if let Some(params) = self.additional_params {
253            if !params.is_object() {
254                return Err(VectorStoreError::BuilderError(
255                    "Expected JSON object for additional params, got something else".into(),
256                ));
257            }
258            Some(params)
259        } else {
260            None
261        };
262
263        Ok(VectorSearchRequest {
264            query,
265            query_vector_name: self.query_vector_name,
266            samples,
267            threshold: self.threshold,
268            additional_params,
269            filter: self.filter,
270        })
271    }
272}
273
274#[cfg(test)]
275mod tests {
276    use super::*;
277    use serde_json::json;
278
279    #[test]
280    fn test_builder_missing_query() {
281        let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
282            .samples(10)
283            .build();
284        assert!(result.is_err());
285        assert!(result.unwrap_err().to_string().contains("query"));
286    }
287
288    #[test]
289    fn test_builder_missing_samples() {
290        let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
291            .query("test")
292            .build();
293        assert!(result.is_err());
294        assert!(result.unwrap_err().to_string().contains("samples"));
295    }
296
297    #[test]
298    fn test_builder_non_object_additional_params() {
299        let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
300            .query("test")
301            .samples(5)
302            .additional_params(json!("not an object"))
303            .unwrap()
304            .build();
305        assert!(result.is_err());
306        assert!(result.unwrap_err().to_string().contains("JSON object"));
307    }
308
309    #[test]
310    fn test_builder_success_with_all_options() {
311        let filter = Filter::eq("color".to_string(), json!("red"));
312        let result = VectorSearchRequest::builder()
313            .query("search query")
314            .samples(10)
315            .threshold(0.8)
316            .additional_params(json!({"key": "value"}))
317            .unwrap()
318            .filter(filter)
319            .build();
320        assert!(result.is_ok());
321        let req = result.unwrap();
322        assert_eq!(req.query(), "search query");
323        assert_eq!(req.samples(), 10);
324        assert_eq!(req.threshold(), Some(0.8));
325        assert!(req.filter().is_some());
326    }
327
328    #[test]
329    fn test_builder_minimal_success() {
330        let result = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
331            .query("q")
332            .samples(1)
333            .build();
334        assert!(result.is_ok());
335        let req = result.unwrap();
336        assert_eq!(req.query(), "q");
337        assert_eq!(req.samples(), 1);
338        assert_eq!(req.threshold(), None);
339        assert!(req.filter().is_none());
340    }
341
342    #[test]
343    fn test_filter_constructors() {
344        let eq: Filter<serde_json::Value> = SearchFilter::eq("k".to_string(), json!("v"));
345        assert!(matches!(eq, Filter::Eq(_, _)));
346
347        let gt: Filter<serde_json::Value> = SearchFilter::gt("k".to_string(), json!(10));
348        assert!(matches!(gt, Filter::Gt(_, _)));
349
350        let lt: Filter<serde_json::Value> = SearchFilter::lt("k".to_string(), json!(5));
351        assert!(matches!(lt, Filter::Lt(_, _)));
352    }
353
354    #[test]
355    fn test_filter_and_or() {
356        let f1: Filter<serde_json::Value> = SearchFilter::eq("a".to_string(), json!(1));
357        let f2: Filter<serde_json::Value> = SearchFilter::eq("b".to_string(), json!(2));
358        let combined = SearchFilter::and(f1, f2);
359        assert!(matches!(combined, Filter::And(_, _)));
360
361        let f3: Filter<serde_json::Value> = SearchFilter::eq("c".to_string(), json!(3));
362        let f4: Filter<serde_json::Value> = SearchFilter::eq("d".to_string(), json!(4));
363        let either = SearchFilter::or(f3, f4);
364        assert!(matches!(either, Filter::Or(_, _)));
365    }
366
367    #[test]
368    fn test_filter_satisfies_eq_match() {
369        let filter = Filter::Eq("color".to_string(), json!("red"));
370        assert!(filter.satisfies(&json!({"color": "red"})));
371    }
372
373    #[test]
374    fn test_filter_satisfies_eq_mismatch() {
375        let filter = Filter::Eq("color".to_string(), json!("red"));
376        assert!(!filter.satisfies(&json!({"color": "blue"})));
377    }
378
379    #[test]
380    fn test_filter_satisfies_and() {
381        let f = Filter::And(
382            Box::new(Filter::Eq("a".to_string(), json!(1))),
383            Box::new(Filter::Eq("b".to_string(), json!(2))),
384        );
385        // Note: satisfies checks json!({k:v}) == value, so both must match same value
386        // This won't match a single object with both - the Eq check is per-key
387        assert!(!f.satisfies(&json!({"a": 1})));
388    }
389
390    #[test]
391    fn test_filter_satisfies_or() {
392        let f = Filter::Or(
393            Box::new(Filter::Eq("a".to_string(), json!(1))),
394            Box::new(Filter::Eq("b".to_string(), json!(2))),
395        );
396        assert!(f.satisfies(&json!({"a": 1})));
397        assert!(f.satisfies(&json!({"b": 2})));
398        assert!(!f.satisfies(&json!({"c": 3})));
399    }
400
401    #[test]
402    fn test_filter_satisfies_gt_and_lt() {
403        let gt = Filter::Gt("score".to_string(), json!(5));
404        assert!(!gt.satisfies(&json!({"score": 3})));
405        assert!(!gt.satisfies(&json!({"score": 7})));
406
407        let lt = Filter::Lt("score".to_string(), json!(5));
408        assert!(!lt.satisfies(&json!({"score": 7})));
409        assert!(!lt.satisfies(&json!({"score": 3})));
410    }
411
412    #[test]
413    fn test_filter_satisfies_incompatible_types() {
414        let gt = Filter::Gt("score".to_string(), json!("high"));
415        assert!(!gt.satisfies(&json!({"score": 3})));
416    }
417
418    #[test]
419    fn test_filter_interpret_roundtrip() {
420        let original: Filter<serde_json::Value> = Filter::Eq("key".to_string(), json!("value"));
421        let interpreted: Filter<serde_json::Value> = original.interpret();
422        assert!(matches!(interpreted, Filter::Eq(ref k, _) if k == "key"));
423    }
424
425    #[test]
426    fn test_filter_interpret_compound() {
427        let f: Filter<serde_json::Value> = Filter::And(
428            Box::new(Filter::Gt("x".to_string(), json!(10))),
429            Box::new(Filter::Lt("y".to_string(), json!(20))),
430        );
431        let interpreted: Filter<serde_json::Value> = f.interpret();
432        assert!(matches!(interpreted, Filter::And(_, _)));
433    }
434
435    #[test]
436    fn test_map_filter() {
437        let req = VectorSearchRequest::<Filter<serde_json::Value>>::builder()
438            .query("q")
439            .samples(5)
440            .filter(Filter::Eq("k".to_string(), json!("v")))
441            .build()
442            .unwrap();
443
444        let mapped = req.map_filter(|f| format!("{f:?}"));
445        assert_eq!(mapped.query(), "q");
446        assert_eq!(mapped.samples(), 5);
447        assert!(mapped.filter().is_some());
448    }
449
450    #[test]
451    fn test_filter_serialize_deserialize() {
452        let filter: Filter<serde_json::Value> = Filter::Eq("name".to_string(), json!("test"));
453        let json = serde_json::to_string(&filter).unwrap();
454        let deserialized: Filter<serde_json::Value> = serde_json::from_str(&json).unwrap();
455        assert!(matches!(deserialized, Filter::Eq(ref k, _) if k == "name"));
456    }
457}