1use serde::{Deserialize, Serialize};
2
3use super::VectorStoreError;
4
/// A request to search a vector store.
///
/// Holds the query text, the number of samples requested, and an optional
/// threshold, optional extra parameters and an optional filter. Construct it with
/// [`VectorSearchRequest::builder`]; the fields are private and read through getters.
///
/// `F` is the filter type. It defaults to the backend-agnostic [`Filter`] over JSON values.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
    // The text to search for (required by the builder).
    query: String,
    // Number of samples requested (required by the builder).
    samples: u64,
    // Optional threshold; this type only stores it and does not interpret it.
    threshold: Option<f64>,
    // Extra parameters; the builder only accepts a JSON object here.
    additional_params: Option<serde_json::Value>,
    // Optional filter attached to the search.
    filter: Option<F>,
}
19
20impl<Filter> VectorSearchRequest<Filter> {
21 pub fn builder() -> VectorSearchRequestBuilder<Filter> {
23 VectorSearchRequestBuilder::<Filter>::default()
24 }
25
26 pub fn query(&self) -> &str {
28 &self.query
29 }
30
31 pub fn samples(&self) -> u64 {
33 self.samples
34 }
35
36 pub fn threshold(&self) -> Option<f64> {
37 self.threshold
38 }
39
40 pub fn filter(&self) -> &Option<Filter> {
41 &self.filter
42 }
43
44 pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
45 where
46 F: Fn(Filter) -> T,
47 {
48 VectorSearchRequest {
49 query: self.query,
50 samples: self.samples,
51 threshold: self.threshold,
52 additional_params: self.additional_params,
53 filter: self.filter.map(f),
54 }
55 }
56}
57
/// Errors related to filters.
///
/// NOTE(review): none of these variants is constructed in the code shown here. Judging by
/// the messages, they are meant for backends converting a [`Filter`] into their own
/// filter representation. Confirm against the callers.
#[derive(Debug, Clone, thiserror::Error)]
pub enum FilterError {
    /// A value had a different form than expected (`expected` vs. what was actually `got`).
    #[error("Expected: {expected}, got: {got}")]
    Expected { expected: String, got: String },
    /// The named item cannot be compiled to the backend's filter type.
    #[error("Cannot compile '{0}' to the backend's filter type")]
    TypeError(String),
    /// A required field with the given name is missing.
    #[error("Missing field '{0}'")]
    MissingField(String),
    /// A requirement was violated: rendered as "'{0}' must {1}", so `{1}` is a verb phrase.
    #[error("'{0}' must {1}")]
    Must(String, String),
    /// Serializing a filter failed; holds the underlying error text.
    #[error("Filter serialization failed: {0}")]
    Serialization(String),
}
73
/// Constructors for a filter expression type.
///
/// [`Filter::interpret`] uses these to rebuild a generic [`Filter`] tree in the
/// implementing type; [`Filter`] itself also implements this trait.
pub trait SearchFilter {
    /// Type of the values that keys are compared against.
    type Value;

    /// Builds an "equals" leaf for `key` and `value`.
    fn eq(key: String, value: Self::Value) -> Self;
    /// Builds a "greater than" leaf for `key` and `value`.
    fn gt(key: String, value: Self::Value) -> Self;
    /// Builds a "less than" leaf for `key` and `value`.
    fn lt(key: String, value: Self::Value) -> Self;
    /// Builds the conjunction of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;
    /// Builds the disjunction of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
}
83
/// A backend-agnostic filter expression tree.
///
/// Leaves pair a key (the `String`) with a value of type `V`. `Eq` means "equal to",
/// `Gt` means "greater than" and `Lt` means "less than". `And`/`Or` combine two
/// sub-filters.
///
/// Translate a tree into another filter type with [`Filter::interpret`]. For
/// `Filter<serde_json::Value>`, evaluate it against JSON with [`Filter::satisfies`].
///
/// Serialized with lowercase variant names (`eq`, `gt`, `lt`, `and`, `or`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Filter<V>
where
    V: std::fmt::Debug + Clone,
{
    Eq(String, V),
    Gt(String, V),
    Lt(String, V),
    // Children are boxed because the enum is recursive.
    And(Box<Self>, Box<Self>),
    Or(Box<Self>, Box<Self>),
}
99
100impl<V> SearchFilter for Filter<V>
101where
102 V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
103{
104 type Value = V;
105
106 fn eq(key: String, value: Self::Value) -> Self {
107 Self::Eq(key, value)
108 }
109
110 fn gt(key: String, value: Self::Value) -> Self {
111 Self::Gt(key, value)
112 }
113
114 fn lt(key: String, value: Self::Value) -> Self {
115 Self::Lt(key, value)
116 }
117
118 fn and(self, rhs: Self) -> Self {
119 Self::And(self.into(), rhs.into())
120 }
121
122 fn or(self, rhs: Self) -> Self {
123 Self::Or(self.into(), rhs.into())
124 }
125}
126
127impl<V> Filter<V>
128where
129 V: std::fmt::Debug + Clone,
130{
131 pub fn interpret<F>(self) -> F
132 where
133 F: SearchFilter<Value = V>,
134 {
135 match self {
136 Self::Eq(key, val) => F::eq(key, val),
137 Self::Gt(key, val) => F::gt(key, val),
138 Self::Lt(key, val) => F::lt(key, val),
139 Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
140 Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
141 }
142 }
143}
144
145impl Filter<serde_json::Value> {
146 pub fn satisfies(&self, value: &serde_json::Value) -> bool {
147 use Filter::*;
148 use serde_json::{Value, Value::*, json};
149 use std::cmp::Ordering;
150
151 fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
152 match (l, r) {
153 (Number(l), Number(r)) => l
154 .as_f64()
155 .zip(r.as_f64())
156 .and_then(|(l, r)| l.partial_cmp(&r))
157 .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
158 .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
159 (String(l), String(r)) => Some(l.cmp(r)),
160 (Null, Null) => Some(std::cmp::Ordering::Equal),
161 (Bool(l), Bool(r)) => Some(l.cmp(r)),
162 _ => None,
163 }
164 }
165
166 match self {
167 Eq(k, v) => &json!({ k: v }) == value,
168 Gt(k, v) => {
169 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
170 }
171 Lt(k, v) => {
172 compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
173 }
174 And(l, r) => l.satisfies(value) && r.satisfies(value),
175 Or(l, r) => l.satisfies(value) || r.satisfies(value),
176 }
177 }
178}
179
/// Builder for [`VectorSearchRequest`].
///
/// All fields start unset (see its `Default` impl). `query` and `samples` must be set
/// before `build` succeeds; everything else is optional.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
    // Required: the text to search for.
    query: Option<String>,
    // Required: number of samples requested.
    samples: Option<u64>,
    // Optional threshold, passed through unchanged.
    threshold: Option<f64>,
    // Optional extra parameters; `build` rejects anything that is not a JSON object.
    additional_params: Option<serde_json::Value>,
    // Optional filter, passed through unchanged.
    filter: Option<F>,
}
189
190impl<F> Default for VectorSearchRequestBuilder<F> {
191 fn default() -> Self {
192 Self {
193 query: None,
194 samples: None,
195 threshold: None,
196 additional_params: None,
197 filter: None,
198 }
199 }
200}
201
202impl<F> VectorSearchRequestBuilder<F>
203where
204 F: SearchFilter,
205{
206 pub fn query<T>(mut self, query: T) -> Self
208 where
209 T: Into<String>,
210 {
211 self.query = Some(query.into());
212 self
213 }
214
215 pub fn samples(mut self, samples: u64) -> Self {
216 self.samples = Some(samples);
217 self
218 }
219
220 pub fn threshold(mut self, threshold: f64) -> Self {
221 self.threshold = Some(threshold);
222 self
223 }
224
225 pub fn additional_params(
226 mut self,
227 params: serde_json::Value,
228 ) -> Result<Self, VectorStoreError> {
229 self.additional_params = Some(params);
230 Ok(self)
231 }
232
233 pub fn filter(mut self, filter: F) -> Self {
234 self.filter = Some(filter);
235 self
236 }
237
238 pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
239 let Some(query) = self.query else {
240 return Err(VectorStoreError::BuilderError(
241 "`query` is a required variable for building a vector search request".into(),
242 ));
243 };
244
245 let Some(samples) = self.samples else {
246 return Err(VectorStoreError::BuilderError(
247 "`samples` is a required variable for building a vector search request".into(),
248 ));
249 };
250
251 let additional_params = if let Some(params) = self.additional_params {
252 if !params.is_object() {
253 return Err(VectorStoreError::BuilderError(
254 "Expected JSON object for additional params, got something else".into(),
255 ));
256 }
257 Some(params)
258 } else {
259 None
260 };
261
262 Ok(VectorSearchRequest {
263 query,
264 samples,
265 threshold: self.threshold,
266 additional_params,
267 filter: self.filter,
268 })
269 }
270}