Skip to main content

autoagents_core/vector_store/
request.rs

1use serde::{Deserialize, Serialize};
2
3use super::VectorStoreError;
4
5/// A vector search request - used in the [`super::VectorStoreIndex`] trait.
6#[derive(Clone, Serialize, Deserialize, Debug)]
7pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
8    query: String,
9    samples: u64,
10    threshold: Option<f64>,
11    additional_params: Option<serde_json::Value>,
12    filter: Option<F>,
13}
14
15impl<Filter> VectorSearchRequest<Filter> {
16    pub fn builder() -> VectorSearchRequestBuilder<Filter> {
17        VectorSearchRequestBuilder::<Filter>::default()
18    }
19
20    pub fn query(&self) -> &str {
21        &self.query
22    }
23
24    pub fn samples(&self) -> u64 {
25        self.samples
26    }
27
28    pub fn threshold(&self) -> Option<f64> {
29        self.threshold
30    }
31
32    pub fn filter(&self) -> &Option<Filter> {
33        &self.filter
34    }
35
36    pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
37    where
38        F: Fn(Filter) -> T,
39    {
40        VectorSearchRequest {
41            query: self.query,
42            samples: self.samples,
43            threshold: self.threshold,
44            additional_params: self.additional_params,
45            filter: self.filter.map(f),
46        }
47    }
48}
49
50#[derive(Debug, Clone, thiserror::Error)]
51pub enum FilterError {
52    #[error("Expected: {expected}, got: {got}")]
53    Expected { expected: String, got: String },
54    #[error("Cannot compile '{0}' to the backend's filter type")]
55    TypeError(String),
56    #[error("Missing field '{0}'")]
57    MissingField(String),
58    #[error("'{0}' must {1}")]
59    Must(String, String),
60    #[error("Filter serialization failed: {0}")]
61    Serialization(String),
62}
63
/// Constructors and combinators for a filter expression language.
///
/// Implement this for a backend's own filter type so that a generic [`Filter`] can be
/// translated into it via [`Filter::interpret`].
pub trait SearchFilter {
    /// The type of the right-hand-side values compared against fields.
    type Value;

    /// Builds an equality leaf on field `key`.
    fn eq(key: String, value: Self::Value) -> Self;

    /// Builds a "greater than" leaf on field `key`.
    fn gt(key: String, value: Self::Value) -> Self;

    /// Builds a "less than" leaf on field `key`.
    fn lt(key: String, value: Self::Value) -> Self;

    /// Conjunction of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;

    /// Disjunction of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(rename_all = "lowercase")]
76pub enum Filter<V>
77where
78    V: std::fmt::Debug + Clone,
79{
80    Eq(String, V),
81    Gt(String, V),
82    Lt(String, V),
83    And(Box<Self>, Box<Self>),
84    Or(Box<Self>, Box<Self>),
85}
86
87impl<V> SearchFilter for Filter<V>
88where
89    V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
90{
91    type Value = V;
92
93    fn eq(key: String, value: Self::Value) -> Self {
94        Self::Eq(key, value)
95    }
96
97    fn gt(key: String, value: Self::Value) -> Self {
98        Self::Gt(key, value)
99    }
100
101    fn lt(key: String, value: Self::Value) -> Self {
102        Self::Lt(key, value)
103    }
104
105    fn and(self, rhs: Self) -> Self {
106        Self::And(self.into(), rhs.into())
107    }
108
109    fn or(self, rhs: Self) -> Self {
110        Self::Or(self.into(), rhs.into())
111    }
112}
113
114impl<V> Filter<V>
115where
116    V: std::fmt::Debug + Clone,
117{
118    pub fn interpret<F>(self) -> F
119    where
120        F: SearchFilter<Value = V>,
121    {
122        match self {
123            Self::Eq(key, val) => F::eq(key, val),
124            Self::Gt(key, val) => F::gt(key, val),
125            Self::Lt(key, val) => F::lt(key, val),
126            Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
127            Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
128        }
129    }
130}
131
132impl Filter<serde_json::Value> {
133    pub fn satisfies(&self, value: &serde_json::Value) -> bool {
134        use Filter::*;
135        use serde_json::{Value, Value::*, json};
136        use std::cmp::Ordering;
137
138        fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
139            match (l, r) {
140                (Number(l), Number(r)) => l
141                    .as_f64()
142                    .zip(r.as_f64())
143                    .and_then(|(l, r)| l.partial_cmp(&r))
144                    .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
145                    .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
146                (String(l), String(r)) => Some(l.cmp(r)),
147                (Null, Null) => Some(std::cmp::Ordering::Equal),
148                (Bool(l), Bool(r)) => Some(l.cmp(r)),
149                _ => None,
150            }
151        }
152
153        match self {
154            Eq(k, v) => &json!({ k: v }) == value,
155            Gt(k, v) => {
156                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
157            }
158            Lt(k, v) => {
159                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
160            }
161            And(l, r) => l.satisfies(value) && r.satisfies(value),
162            Or(l, r) => l.satisfies(value) || r.satisfies(value),
163        }
164    }
165}
166
167#[derive(Clone, Serialize, Deserialize, Debug)]
168pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
169    query: Option<String>,
170    samples: Option<u64>,
171    threshold: Option<f64>,
172    additional_params: Option<serde_json::Value>,
173    filter: Option<F>,
174}
175
176impl<F> Default for VectorSearchRequestBuilder<F> {
177    fn default() -> Self {
178        Self {
179            query: None,
180            samples: None,
181            threshold: None,
182            additional_params: None,
183            filter: None,
184        }
185    }
186}
187
188impl<F> VectorSearchRequestBuilder<F>
189where
190    F: SearchFilter,
191{
192    pub fn query<T>(mut self, query: T) -> Self
193    where
194        T: Into<String>,
195    {
196        self.query = Some(query.into());
197        self
198    }
199
200    pub fn samples(mut self, samples: u64) -> Self {
201        self.samples = Some(samples);
202        self
203    }
204
205    pub fn threshold(mut self, threshold: f64) -> Self {
206        self.threshold = Some(threshold);
207        self
208    }
209
210    pub fn additional_params(
211        mut self,
212        params: serde_json::Value,
213    ) -> Result<Self, VectorStoreError> {
214        self.additional_params = Some(params);
215        Ok(self)
216    }
217
218    pub fn filter(mut self, filter: F) -> Self {
219        self.filter = Some(filter);
220        self
221    }
222
223    pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
224        let Some(query) = self.query else {
225            return Err(VectorStoreError::BuilderError(
226                "`query` is a required variable for building a vector search request".into(),
227            ));
228        };
229
230        let Some(samples) = self.samples else {
231            return Err(VectorStoreError::BuilderError(
232                "`samples` is a required variable for building a vector search request".into(),
233            ));
234        };
235
236        let additional_params = if let Some(params) = self.additional_params {
237            if !params.is_object() {
238                return Err(VectorStoreError::BuilderError(
239                    "Expected JSON object for additional params, got something else".into(),
240                ));
241            }
242            Some(params)
243        } else {
244            None
245        };
246
247        Ok(VectorSearchRequest {
248            query,
249            samples,
250            threshold: self.threshold,
251            additional_params,
252            filter: self.filter,
253        })
254    }
255}