1use serde::{Deserialize, Serialize};
8
9use super::VectorStoreError;
10use crate::markers::{Missing, Provided};
11
/// A vector search request, generic over the filter representation `F`.
///
/// `F` defaults to the backend-agnostic [`Filter`] over `serde_json::Value`.
/// Values are assembled with [`VectorSearchRequestBuilder`] (see
/// [`VectorSearchRequest::builder`]).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
    /// The query string to search with.
    query: String,
    /// Requested sample count.
    /// NOTE(review): presumably the maximum number of results to return; confirm against the store implementations.
    samples: u64,
    /// Optional threshold.
    /// NOTE(review): presumably a minimum similarity score; confirm against the store implementations.
    threshold: Option<f64>,
    /// Optional free-form JSON parameters carried alongside the request.
    additional_params: Option<serde_json::Value>,
    /// Optional filter expression restricting the search.
    filter: Option<F>,
}
29
30impl<Filter> VectorSearchRequest<Filter> {
31 pub fn builder() -> VectorSearchRequestBuilder<Filter> {
33 VectorSearchRequestBuilder::<Filter>::default()
34 }
35
36 pub fn query(&self) -> &str {
38 &self.query
39 }
40
41 pub fn samples(&self) -> u64 {
43 self.samples
44 }
45
46 pub fn threshold(&self) -> Option<f64> {
48 self.threshold
49 }
50
51 pub fn filter(&self) -> &Option<Filter> {
53 &self.filter
54 }
55
56 pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
61 where
62 F: Fn(Filter) -> T,
63 {
64 VectorSearchRequest {
65 query: self.query,
66 samples: self.samples,
67 threshold: self.threshold,
68 additional_params: self.additional_params,
69 filter: self.filter.map(f),
70 }
71 }
72
73 pub fn try_map_filter<T, F>(self, f: F) -> Result<VectorSearchRequest<T>, FilterError>
77 where
78 F: Fn(Filter) -> Result<T, FilterError>,
79 {
80 let filter = self.filter.map(f).transpose()?;
81
82 Ok(VectorSearchRequest {
83 query: self.query,
84 samples: self.samples,
85 threshold: self.threshold,
86 additional_params: self.additional_params,
87 filter,
88 })
89 }
90}
91
/// Errors that can arise while building or converting a filter expression.
#[derive(Debug, Clone, thiserror::Error)]
pub enum FilterError {
    /// A value of one kind was expected but another was found.
    #[error("Expected: {expected}, got: {got}")]
    Expected { expected: String, got: String },

    /// The named item cannot be compiled to the backend's filter type.
    #[error("Cannot compile '{0}' to the backend's filter type")]
    TypeError(String),

    /// A required field (named by the payload) is missing.
    #[error("Missing field '{0}'")]
    MissingField(String),

    /// A constraint was violated: the first payload names the subject, the
    /// second describes what it must satisfy.
    #[error("'{0}' must {1}")]
    Must(String, String),

    /// Serializing the filter failed; the payload carries the message.
    #[error("Filter serialization failed: {0}")]
    Serialization(String),
}
111
/// Constructors and combinators for a filter expression type.
///
/// Any implementor can be produced from the generic [`Filter`] tree via
/// [`Filter::interpret`], provided its `Value` type matches.
pub trait SearchFilter {
    /// The type of the operand a key is compared against.
    type Value;

    /// Builds an "equals" predicate for `key` and `value`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Builds a "greater than" predicate for `key` and `value`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Builds a "less than" predicate for `key` and `value`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Combines `self` and `rhs` into a conjunction.
    fn and(self, rhs: Self) -> Self;
    /// Combines `self` and `rhs` into a disjunction.
    fn or(self, rhs: Self) -> Self;
}
126
/// A backend-agnostic filter expression tree over operand type `V`.
///
/// Variant names are serialized in lowercase (`eq`, `gt`, `lt`, `and`, `or`)
/// because of `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Filter<V>
where
    V: std::fmt::Debug + Clone,
{
    /// "Equals" predicate: key and operand.
    Eq(String, V),
    /// "Greater than" predicate: key and operand.
    Gt(String, V),
    /// "Less than" predicate: key and operand.
    Lt(String, V),
    /// Conjunction of two sub-expressions.
    And(Box<Self>, Box<Self>),
    /// Disjunction of two sub-expressions.
    Or(Box<Self>, Box<Self>),
}
143
144impl<V> SearchFilter for Filter<V>
145where
146 V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
147{
148 type Value = V;
149
150 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
152 Self::Eq(key.as_ref().to_owned(), value)
153 }
154
155 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
157 Self::Gt(key.as_ref().to_owned(), value)
158 }
159
160 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
162 Self::Lt(key.as_ref().to_owned(), value)
163 }
164
165 fn and(self, rhs: Self) -> Self {
167 Self::And(self.into(), rhs.into())
168 }
169
170 fn or(self, rhs: Self) -> Self {
172 Self::Or(self.into(), rhs.into())
173 }
174}
175
176impl<V> Filter<V>
177where
178 V: std::fmt::Debug + Clone,
179{
180 pub fn interpret<F>(self) -> F
182 where
183 F: SearchFilter<Value = V>,
184 {
185 match self {
186 Self::Eq(key, val) => F::eq(key, val),
187 Self::Gt(key, val) => F::gt(key, val),
188 Self::Lt(key, val) => F::lt(key, val),
189 Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
190 Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
191 }
192 }
193}
194
195impl Filter<serde_json::Value> {
196 pub fn satisfies(&self, value: &serde_json::Value) -> bool {
198 use Filter::*;
199 use serde_json::{Value, Value::*, json};
200 use std::cmp::Ordering;
201
202 fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
203 match (l, r) {
204 (Number(l), Number(r)) => l
205 .as_f64()
206 .zip(r.as_f64())
207 .and_then(|(l, r)| l.partial_cmp(&r))
208 .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
209 .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
210 (String(l), String(r)) => Some(l.cmp(r)),
211 (Null, Null) => Some(std::cmp::Ordering::Equal),
212 (Bool(l), Bool(r)) => Some(l.cmp(r)),
213 _ => None,
214 }
215 }
216
217 match self {
218 Eq(k, v) => &json!({ k: v }) == value,
219 Gt(k, v) => {
220 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
221 }
222 Lt(k, v) => {
223 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
224 }
225 And(l, r) => l.satisfies(value) && r.satisfies(value),
226 Or(l, r) => l.satisfies(value) || r.satisfies(value),
227 }
228 }
229}
230
/// Builder for [`VectorSearchRequest`].
///
/// `Q` and `S` are type-level markers for the `query` and `samples` fields:
/// they start as `Missing` and become `Provided<String>` / `Provided<u64>`
/// once the corresponding setter is called. `build` is only implemented for
/// the state in which both have been provided.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>, Q = Missing, S = Missing> {
    /// Query marker: `Missing` or `Provided<String>`.
    query: Q,
    /// Sample-count marker: `Missing` or `Provided<u64>`.
    samples: S,
    /// Optional threshold to copy into the request.
    threshold: Option<f64>,
    /// Optional free-form JSON parameters to copy into the request.
    additional_params: Option<serde_json::Value>,
    /// Optional filter to copy into the request.
    filter: Option<F>,
}
240
241impl<F> Default for VectorSearchRequestBuilder<F, Missing, Missing> {
242 fn default() -> Self {
243 Self {
244 query: Missing,
245 samples: Missing,
246 threshold: None,
247 additional_params: None,
248 filter: None,
249 }
250 }
251}
252
253impl<F, Q, S> VectorSearchRequestBuilder<F, Q, S>
254where
255 F: SearchFilter,
256{
257 pub fn query<T>(self, query: T) -> VectorSearchRequestBuilder<F, Provided<String>, S>
259 where
260 T: Into<String>,
261 {
262 VectorSearchRequestBuilder {
263 query: Provided(query.into()),
264 samples: self.samples,
265 threshold: self.threshold,
266 additional_params: self.additional_params,
267 filter: self.filter,
268 }
269 }
270
271 pub fn samples(self, samples: u64) -> VectorSearchRequestBuilder<F, Q, Provided<u64>> {
273 VectorSearchRequestBuilder {
274 query: self.query,
275 samples: Provided(samples),
276 threshold: self.threshold,
277 additional_params: self.additional_params,
278 filter: self.filter,
279 }
280 }
281
282 pub fn threshold(mut self, threshold: f64) -> Self {
284 self.threshold = Some(threshold);
285 self
286 }
287
288 pub fn additional_params(
290 mut self,
291 params: serde_json::Value,
292 ) -> Result<Self, VectorStoreError> {
293 self.additional_params = Some(params);
294 Ok(self)
295 }
296
297 pub fn filter(mut self, filter: F) -> Self {
299 self.filter = Some(filter);
300 self
301 }
302}
303
304impl<F> VectorSearchRequestBuilder<F, Provided<String>, Provided<u64>> {
306 pub fn build(self) -> VectorSearchRequest<F> {
308 VectorSearchRequest {
309 query: self.query.0,
310 samples: self.samples.0,
311 threshold: self.threshold,
312 additional_params: self.additional_params,
313 filter: self.filter,
314 }
315 }
316}