1use serde::{Deserialize, Serialize};
8
9use super::VectorStoreError;
10
11#[derive(Clone, Serialize, Deserialize, Debug)]
16pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
17 query: String,
19 samples: u64,
21 threshold: Option<f64>,
23 additional_params: Option<serde_json::Value>,
25 filter: Option<F>,
27}
28
29impl<Filter> VectorSearchRequest<Filter> {
30 pub fn builder() -> VectorSearchRequestBuilder<Filter> {
32 VectorSearchRequestBuilder::<Filter>::default()
33 }
34
35 pub fn query(&self) -> &str {
37 &self.query
38 }
39
40 pub fn samples(&self) -> u64 {
42 self.samples
43 }
44
45 pub fn threshold(&self) -> Option<f64> {
47 self.threshold
48 }
49
50 pub fn filter(&self) -> &Option<Filter> {
52 &self.filter
53 }
54
55 pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
60 where
61 F: Fn(Filter) -> T,
62 {
63 VectorSearchRequest {
64 query: self.query,
65 samples: self.samples,
66 threshold: self.threshold,
67 additional_params: self.additional_params,
68 filter: self.filter.map(f),
69 }
70 }
71}
72
73#[derive(Debug, Clone, thiserror::Error)]
75pub enum FilterError {
76 #[error("Expected: {expected}, got: {got}")]
77 Expected { expected: String, got: String },
78
79 #[error("Cannot compile '{0}' to the backend's filter type")]
80 TypeError(String),
81
82 #[error("Missing field '{0}'")]
83 MissingField(String),
84
85 #[error("'{0}' must {1}")]
86 Must(String, String),
87
88 #[error("Filter serialization failed: {0}")]
90 Serialization(String),
91}
92
/// Constructors and combinators for a filter expression.
///
/// Implementors choose their own representation; `Value` is the type of the
/// operand compared against a field.
pub trait SearchFilter {
    /// Operand type used by the leaf constructors.
    type Value;

    /// Leaf condition: field `key` compared for equality with `value`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Leaf condition: field `key` compared as greater than `value`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Leaf condition: field `key` compared as less than `value`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Conjunction of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;

    /// Disjunction of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
113#[serde(rename_all = "lowercase")]
114pub enum Filter<V>
115where
116 V: std::fmt::Debug + Clone,
117{
118 Eq(String, V),
119 Gt(String, V),
120 Lt(String, V),
121 And(Box<Self>, Box<Self>),
122 Or(Box<Self>, Box<Self>),
123}
124
125impl<V> SearchFilter for Filter<V>
126where
127 V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
128{
129 type Value = V;
130
131 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
133 Self::Eq(key.as_ref().to_owned(), value)
134 }
135
136 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
138 Self::Gt(key.as_ref().to_owned(), value)
139 }
140
141 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
143 Self::Lt(key.as_ref().to_owned(), value)
144 }
145
146 fn and(self, rhs: Self) -> Self {
148 Self::And(self.into(), rhs.into())
149 }
150
151 fn or(self, rhs: Self) -> Self {
153 Self::Or(self.into(), rhs.into())
154 }
155}
156
157impl<V> Filter<V>
158where
159 V: std::fmt::Debug + Clone,
160{
161 pub fn interpret<F>(self) -> F
163 where
164 F: SearchFilter<Value = V>,
165 {
166 match self {
167 Self::Eq(key, val) => F::eq(key, val),
168 Self::Gt(key, val) => F::gt(key, val),
169 Self::Lt(key, val) => F::lt(key, val),
170 Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
171 Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
172 }
173 }
174}
175
176impl Filter<serde_json::Value> {
177 pub fn satisfies(&self, value: &serde_json::Value) -> bool {
179 use Filter::*;
180 use serde_json::{Value, Value::*, json};
181 use std::cmp::Ordering;
182
183 fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
184 match (l, r) {
185 (Number(l), Number(r)) => l
186 .as_f64()
187 .zip(r.as_f64())
188 .and_then(|(l, r)| l.partial_cmp(&r))
189 .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
190 .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
191 (String(l), String(r)) => Some(l.cmp(r)),
192 (Null, Null) => Some(std::cmp::Ordering::Equal),
193 (Bool(l), Bool(r)) => Some(l.cmp(r)),
194 _ => None,
195 }
196 }
197
198 match self {
199 Eq(k, v) => &json!({ k: v }) == value,
200 Gt(k, v) => {
201 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
202 }
203 Lt(k, v) => {
204 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
205 }
206 And(l, r) => l.satisfies(value) && r.satisfies(value),
207 Or(l, r) => l.satisfies(value) || r.satisfies(value),
208 }
209 }
210}
211
212#[derive(Clone, Serialize, Deserialize, Debug)]
214pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
215 query: Option<String>,
216 samples: Option<u64>,
217 threshold: Option<f64>,
218 additional_params: Option<serde_json::Value>,
219 filter: Option<F>,
220}
221
222impl<F> Default for VectorSearchRequestBuilder<F> {
223 fn default() -> Self {
224 Self {
225 query: None,
226 samples: None,
227 threshold: None,
228 additional_params: None,
229 filter: None,
230 }
231 }
232}
233
234impl<F> VectorSearchRequestBuilder<F>
235where
236 F: SearchFilter,
237{
238 pub fn query<T>(mut self, query: T) -> Self
240 where
241 T: Into<String>,
242 {
243 self.query = Some(query.into());
244 self
245 }
246
247 pub fn samples(mut self, samples: u64) -> Self {
249 self.samples = Some(samples);
250 self
251 }
252
253 pub fn threshold(mut self, threshold: f64) -> Self {
255 self.threshold = Some(threshold);
256 self
257 }
258
259 pub fn additional_params(
261 mut self,
262 params: serde_json::Value,
263 ) -> Result<Self, VectorStoreError> {
264 self.additional_params = Some(params);
265 Ok(self)
266 }
267
268 pub fn filter(mut self, filter: F) -> Self {
270 self.filter = Some(filter);
271 self
272 }
273
274 pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
276 let Some(query) = self.query else {
277 return Err(VectorStoreError::BuilderError(
278 "`query` is a required variable for building a vector search request".into(),
279 ));
280 };
281
282 let Some(samples) = self.samples else {
283 return Err(VectorStoreError::BuilderError(
284 "`samples` is a required variable for building a vector search request".into(),
285 ));
286 };
287
288 let additional_params = if let Some(params) = self.additional_params {
289 if !params.is_object() {
290 return Err(VectorStoreError::BuilderError(
291 "Expected JSON object for additional params, got something else".into(),
292 ));
293 }
294 Some(params)
295 } else {
296 None
297 };
298
299 Ok(VectorSearchRequest {
300 query,
301 samples,
302 threshold: self.threshold,
303 additional_params,
304 filter: self.filter,
305 })
306 }
307}