1use serde::{Deserialize, Serialize};
8
9use super::VectorStoreError;
10
11#[derive(Clone, Serialize, Deserialize, Debug)]
16pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
17 query: String,
19 samples: u64,
21 threshold: Option<f64>,
23 additional_params: Option<serde_json::Value>,
25 filter: Option<F>,
27}
28
29impl<Filter> VectorSearchRequest<Filter> {
30 pub fn builder() -> VectorSearchRequestBuilder<Filter> {
32 VectorSearchRequestBuilder::<Filter>::default()
33 }
34
35 pub fn query(&self) -> &str {
37 &self.query
38 }
39
40 pub fn samples(&self) -> u64 {
42 self.samples
43 }
44
45 pub fn threshold(&self) -> Option<f64> {
47 self.threshold
48 }
49
50 pub fn filter(&self) -> &Option<Filter> {
52 &self.filter
53 }
54
55 pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
60 where
61 F: Fn(Filter) -> T,
62 {
63 VectorSearchRequest {
64 query: self.query,
65 samples: self.samples,
66 threshold: self.threshold,
67 additional_params: self.additional_params,
68 filter: self.filter.map(f),
69 }
70 }
71
72 pub fn try_map_filter<T, F>(self, f: F) -> Result<VectorSearchRequest<T>, FilterError>
76 where
77 F: Fn(Filter) -> Result<T, FilterError>,
78 {
79 let filter = self.filter.map(f).transpose()?;
80
81 Ok(VectorSearchRequest {
82 query: self.query,
83 samples: self.samples,
84 threshold: self.threshold,
85 additional_params: self.additional_params,
86 filter,
87 })
88 }
89}
90
91#[derive(Debug, Clone, thiserror::Error)]
93pub enum FilterError {
94 #[error("Expected: {expected}, got: {got}")]
95 Expected { expected: String, got: String },
96
97 #[error("Cannot compile '{0}' to the backend's filter type")]
98 TypeError(String),
99
100 #[error("Missing field '{0}'")]
101 MissingField(String),
102
103 #[error("'{0}' must {1}")]
104 Must(String, String),
105
106 #[error("Filter serialization failed: {0}")]
108 Serialization(String),
109}
110
/// Constructors and combinators for a search-filter representation.
///
/// Implementors provide leaf predicates (`eq`, `gt`, `lt`) that pair a key with a
/// [`SearchFilter::Value`], and binary combinators (`and`, `or`) that join two filters.
pub trait SearchFilter {
    /// Type of the operand the leaf predicates compare a key against.
    type Value;

    /// Builds an equality predicate on `key`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Builds a "greater than" predicate on `key`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Builds a "less than" predicate on `key`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Conjunction of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;

    /// Disjunction of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
131#[serde(rename_all = "lowercase")]
132pub enum Filter<V>
133where
134 V: std::fmt::Debug + Clone,
135{
136 Eq(String, V),
137 Gt(String, V),
138 Lt(String, V),
139 And(Box<Self>, Box<Self>),
140 Or(Box<Self>, Box<Self>),
141}
142
143impl<V> SearchFilter for Filter<V>
144where
145 V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
146{
147 type Value = V;
148
149 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
151 Self::Eq(key.as_ref().to_owned(), value)
152 }
153
154 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
156 Self::Gt(key.as_ref().to_owned(), value)
157 }
158
159 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
161 Self::Lt(key.as_ref().to_owned(), value)
162 }
163
164 fn and(self, rhs: Self) -> Self {
166 Self::And(self.into(), rhs.into())
167 }
168
169 fn or(self, rhs: Self) -> Self {
171 Self::Or(self.into(), rhs.into())
172 }
173}
174
175impl<V> Filter<V>
176where
177 V: std::fmt::Debug + Clone,
178{
179 pub fn interpret<F>(self) -> F
181 where
182 F: SearchFilter<Value = V>,
183 {
184 match self {
185 Self::Eq(key, val) => F::eq(key, val),
186 Self::Gt(key, val) => F::gt(key, val),
187 Self::Lt(key, val) => F::lt(key, val),
188 Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
189 Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
190 }
191 }
192}
193
194impl Filter<serde_json::Value> {
195 pub fn satisfies(&self, value: &serde_json::Value) -> bool {
197 use Filter::*;
198 use serde_json::{Value, Value::*, json};
199 use std::cmp::Ordering;
200
201 fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
202 match (l, r) {
203 (Number(l), Number(r)) => l
204 .as_f64()
205 .zip(r.as_f64())
206 .and_then(|(l, r)| l.partial_cmp(&r))
207 .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
208 .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
209 (String(l), String(r)) => Some(l.cmp(r)),
210 (Null, Null) => Some(std::cmp::Ordering::Equal),
211 (Bool(l), Bool(r)) => Some(l.cmp(r)),
212 _ => None,
213 }
214 }
215
216 match self {
217 Eq(k, v) => &json!({ k: v }) == value,
218 Gt(k, v) => {
219 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
220 }
221 Lt(k, v) => {
222 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
223 }
224 And(l, r) => l.satisfies(value) && r.satisfies(value),
225 Or(l, r) => l.satisfies(value) || r.satisfies(value),
226 }
227 }
228}
229
230#[derive(Clone, Serialize, Deserialize, Debug)]
232pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
233 query: Option<String>,
234 samples: Option<u64>,
235 threshold: Option<f64>,
236 additional_params: Option<serde_json::Value>,
237 filter: Option<F>,
238}
239
240impl<F> Default for VectorSearchRequestBuilder<F> {
241 fn default() -> Self {
242 Self {
243 query: None,
244 samples: None,
245 threshold: None,
246 additional_params: None,
247 filter: None,
248 }
249 }
250}
251
252impl<F> VectorSearchRequestBuilder<F>
253where
254 F: SearchFilter,
255{
256 pub fn query<T>(mut self, query: T) -> Self
258 where
259 T: Into<String>,
260 {
261 self.query = Some(query.into());
262 self
263 }
264
265 pub fn samples(mut self, samples: u64) -> Self {
267 self.samples = Some(samples);
268 self
269 }
270
271 pub fn threshold(mut self, threshold: f64) -> Self {
273 self.threshold = Some(threshold);
274 self
275 }
276
277 pub fn additional_params(
279 mut self,
280 params: serde_json::Value,
281 ) -> Result<Self, VectorStoreError> {
282 self.additional_params = Some(params);
283 Ok(self)
284 }
285
286 pub fn filter(mut self, filter: F) -> Self {
288 self.filter = Some(filter);
289 self
290 }
291
292 pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
294 let Some(query) = self.query else {
295 return Err(VectorStoreError::BuilderError(
296 "`query` is a required variable for building a vector search request".into(),
297 ));
298 };
299
300 let Some(samples) = self.samples else {
301 return Err(VectorStoreError::BuilderError(
302 "`samples` is a required variable for building a vector search request".into(),
303 ));
304 };
305
306 let additional_params = if let Some(params) = self.additional_params {
307 if !params.is_object() {
308 return Err(VectorStoreError::BuilderError(
309 "Expected JSON object for additional params, got something else".into(),
310 ));
311 }
312 Some(params)
313 } else {
314 None
315 };
316
317 Ok(VectorSearchRequest {
318 query,
319 samples,
320 threshold: self.threshold,
321 additional_params,
322 filter: self.filter,
323 })
324 }
325}