autoagents_core/vector_store/
request.rs1use serde::{Deserialize, Serialize};
2
3use super::VectorStoreError;
4
5#[derive(Clone, Serialize, Deserialize, Debug)]
7pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
8 query: String,
9 samples: u64,
10 threshold: Option<f64>,
11 additional_params: Option<serde_json::Value>,
12 filter: Option<F>,
13}
14
15impl<Filter> VectorSearchRequest<Filter> {
16 pub fn builder() -> VectorSearchRequestBuilder<Filter> {
17 VectorSearchRequestBuilder::<Filter>::default()
18 }
19
20 pub fn query(&self) -> &str {
21 &self.query
22 }
23
24 pub fn samples(&self) -> u64 {
25 self.samples
26 }
27
28 pub fn threshold(&self) -> Option<f64> {
29 self.threshold
30 }
31
32 pub fn filter(&self) -> &Option<Filter> {
33 &self.filter
34 }
35
36 pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
37 where
38 F: Fn(Filter) -> T,
39 {
40 VectorSearchRequest {
41 query: self.query,
42 samples: self.samples,
43 threshold: self.threshold,
44 additional_params: self.additional_params,
45 filter: self.filter.map(f),
46 }
47 }
48}
49
50#[derive(Debug, Clone, thiserror::Error)]
51pub enum FilterError {
52 #[error("Expected: {expected}, got: {got}")]
53 Expected { expected: String, got: String },
54 #[error("Cannot compile '{0}' to the backend's filter type")]
55 TypeError(String),
56 #[error("Missing field '{0}'")]
57 MissingField(String),
58 #[error("'{0}' must {1}")]
59 Must(String, String),
60 #[error("Filter serialization failed: {0}")]
61 Serialization(String),
62}
63
/// Constructors for a filter expression language.
///
/// An implementor supplies three comparison leaves (`eq`, `gt`, `lt`) and two boolean
/// combinators (`and`, `or`). [`Filter::interpret`] uses exactly these to translate a generic
/// [`Filter`] tree into the implementor's own representation.
pub trait SearchFilter {
    /// Type of the values that keys are compared against.
    type Value;

    /// Leaf predicate intended to match when `key` equals `value`.
    fn eq(key: String, value: Self::Value) -> Self;

    /// Leaf predicate intended to match when `key` is greater than `value`.
    fn gt(key: String, value: Self::Value) -> Self;

    /// Leaf predicate intended to match when `key` is less than `value`.
    fn lt(key: String, value: Self::Value) -> Self;

    /// Conjunction of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;

    /// Disjunction of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(rename_all = "lowercase")]
76pub enum Filter<V>
77where
78 V: std::fmt::Debug + Clone,
79{
80 Eq(String, V),
81 Gt(String, V),
82 Lt(String, V),
83 And(Box<Self>, Box<Self>),
84 Or(Box<Self>, Box<Self>),
85}
86
87impl<V> SearchFilter for Filter<V>
88where
89 V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
90{
91 type Value = V;
92
93 fn eq(key: String, value: Self::Value) -> Self {
94 Self::Eq(key, value)
95 }
96
97 fn gt(key: String, value: Self::Value) -> Self {
98 Self::Gt(key, value)
99 }
100
101 fn lt(key: String, value: Self::Value) -> Self {
102 Self::Lt(key, value)
103 }
104
105 fn and(self, rhs: Self) -> Self {
106 Self::And(self.into(), rhs.into())
107 }
108
109 fn or(self, rhs: Self) -> Self {
110 Self::Or(self.into(), rhs.into())
111 }
112}
113
114impl<V> Filter<V>
115where
116 V: std::fmt::Debug + Clone,
117{
118 pub fn interpret<F>(self) -> F
119 where
120 F: SearchFilter<Value = V>,
121 {
122 match self {
123 Self::Eq(key, val) => F::eq(key, val),
124 Self::Gt(key, val) => F::gt(key, val),
125 Self::Lt(key, val) => F::lt(key, val),
126 Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
127 Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
128 }
129 }
130}
131
132impl Filter<serde_json::Value> {
133 pub fn satisfies(&self, value: &serde_json::Value) -> bool {
134 use Filter::*;
135 use serde_json::{Value, Value::*, json};
136 use std::cmp::Ordering;
137
138 fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
139 match (l, r) {
140 (Number(l), Number(r)) => l
141 .as_f64()
142 .zip(r.as_f64())
143 .and_then(|(l, r)| l.partial_cmp(&r))
144 .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
145 .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
146 (String(l), String(r)) => Some(l.cmp(r)),
147 (Null, Null) => Some(std::cmp::Ordering::Equal),
148 (Bool(l), Bool(r)) => Some(l.cmp(r)),
149 _ => None,
150 }
151 }
152
153 match self {
154 Eq(k, v) => &json!({ k: v }) == value,
155 Gt(k, v) => {
156 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
157 }
158 Lt(k, v) => {
159 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
160 }
161 And(l, r) => l.satisfies(value) && r.satisfies(value),
162 Or(l, r) => l.satisfies(value) || r.satisfies(value),
163 }
164 }
165}
166
167#[derive(Clone, Serialize, Deserialize, Debug)]
168pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
169 query: Option<String>,
170 samples: Option<u64>,
171 threshold: Option<f64>,
172 additional_params: Option<serde_json::Value>,
173 filter: Option<F>,
174}
175
176impl<F> Default for VectorSearchRequestBuilder<F> {
177 fn default() -> Self {
178 Self {
179 query: None,
180 samples: None,
181 threshold: None,
182 additional_params: None,
183 filter: None,
184 }
185 }
186}
187
188impl<F> VectorSearchRequestBuilder<F>
189where
190 F: SearchFilter,
191{
192 pub fn query<T>(mut self, query: T) -> Self
193 where
194 T: Into<String>,
195 {
196 self.query = Some(query.into());
197 self
198 }
199
200 pub fn samples(mut self, samples: u64) -> Self {
201 self.samples = Some(samples);
202 self
203 }
204
205 pub fn threshold(mut self, threshold: f64) -> Self {
206 self.threshold = Some(threshold);
207 self
208 }
209
210 pub fn additional_params(
211 mut self,
212 params: serde_json::Value,
213 ) -> Result<Self, VectorStoreError> {
214 self.additional_params = Some(params);
215 Ok(self)
216 }
217
218 pub fn filter(mut self, filter: F) -> Self {
219 self.filter = Some(filter);
220 self
221 }
222
223 pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
224 let Some(query) = self.query else {
225 return Err(VectorStoreError::BuilderError(
226 "`query` is a required variable for building a vector search request".into(),
227 ));
228 };
229
230 let Some(samples) = self.samples else {
231 return Err(VectorStoreError::BuilderError(
232 "`samples` is a required variable for building a vector search request".into(),
233 ));
234 };
235
236 let additional_params = if let Some(params) = self.additional_params {
237 if !params.is_object() {
238 return Err(VectorStoreError::BuilderError(
239 "Expected JSON object for additional params, got something else".into(),
240 ));
241 }
242 Some(params)
243 } else {
244 None
245 };
246
247 Ok(VectorSearchRequest {
248 query,
249 samples,
250 threshold: self.threshold,
251 additional_params,
252 filter: self.filter,
253 })
254 }
255}