rig/vector_store/
request.rs

1use serde::{Deserialize, Serialize};
2
3use super::VectorStoreError;
4
5/// A vector search request - used in the [`super::VectorStoreIndex`] trait.
6#[derive(Clone, Serialize, Deserialize, Debug)]
7pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
8    /// The query to be embedded and used in similarity search.
9    query: String,
10    /// The maximum number of samples that may be returned. If adding a similarity search threshold, you may receive less than the inputted number if there aren't enough results that satisfy the threshold.
11    samples: u64,
12    /// Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result.
13    threshold: Option<f64>,
14    /// Any additional parameters that are required by the vector store.
15    additional_params: Option<serde_json::Value>,
16    /// An expression used to filter samples
17    filter: Option<F>,
18}
19
20impl<Filter> VectorSearchRequest<Filter> {
21    /// Creates a [`VectorSearchRequestBuilder`] which you can use to instantiate this struct.
22    pub fn builder() -> VectorSearchRequestBuilder<Filter> {
23        VectorSearchRequestBuilder::<Filter>::default()
24    }
25
26    /// The query to be embedded and used in similarity search.
27    pub fn query(&self) -> &str {
28        &self.query
29    }
30
31    /// The maximum number of samples that may be returned. If adding a similarity search threshold, you may receive less than the inputted number if there aren't enough results that satisfy the threshold.
32    pub fn samples(&self) -> u64 {
33        self.samples
34    }
35
36    pub fn threshold(&self) -> Option<f64> {
37        self.threshold
38    }
39
40    pub fn filter(&self) -> &Option<Filter> {
41        &self.filter
42    }
43
44    pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
45    where
46        F: Fn(Filter) -> T,
47    {
48        VectorSearchRequest {
49            query: self.query,
50            samples: self.samples,
51            threshold: self.threshold,
52            additional_params: self.additional_params,
53            filter: self.filter.map(f),
54        }
55    }
56}
57
58#[derive(Debug, Clone, thiserror::Error)]
59pub enum FilterError {
60    #[error("Expected: {expected}, got: {got}")]
61    Expected { expected: String, got: String },
62    #[error("Cannot compile '{0}' to the backend's filter type")]
63    TypeError(String),
64    #[error("Missing field '{0}'")]
65    MissingField(String),
66    #[error("'{0}' must {1}")]
67    Must(String, String),
68    // NOTE: @FayCarsons - string because `serde_json::Error` is not `Clone`
69    // and we need this to be `Clone`
70    #[error("Filter serialization failed: {0}")]
71    Serialization(String),
72}
73
/// Constructors and combinators a filter type must offer so that filter
/// expressions can be built without knowing the concrete backend type.
pub trait SearchFilter {
    /// The type of the values that keys are compared against.
    type Value;

    /// Builds an equality condition for `key` against `value`.
    fn eq(key: String, value: Self::Value) -> Self;
    /// Builds a "greater than" condition for `key` against `value`.
    fn gt(key: String, value: Self::Value) -> Self;
    /// Builds a "less than" condition for `key` against `value`.
    fn lt(key: String, value: Self::Value) -> Self;
    /// Combines `self` and `rhs` into a conjunction.
    fn and(self, rhs: Self) -> Self;
    /// Combines `self` and `rhs` into a disjunction.
    fn or(self, rhs: Self) -> Self;
}
83
84/// A canonical, serializable retpresentation of filter expressions.
85/// This serves as an intermediary form whenever you need to inspect,
86/// store, or translate between specific vector store backends
87#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(rename_all = "lowercase")]
89pub enum Filter<V>
90where
91    V: std::fmt::Debug + Clone,
92{
93    Eq(String, V),
94    Gt(String, V),
95    Lt(String, V),
96    And(Box<Self>, Box<Self>),
97    Or(Box<Self>, Box<Self>),
98}
99
100impl<V> SearchFilter for Filter<V>
101where
102    V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
103{
104    type Value = V;
105
106    fn eq(key: String, value: Self::Value) -> Self {
107        Self::Eq(key, value)
108    }
109
110    fn gt(key: String, value: Self::Value) -> Self {
111        Self::Gt(key, value)
112    }
113
114    fn lt(key: String, value: Self::Value) -> Self {
115        Self::Lt(key, value)
116    }
117
118    fn and(self, rhs: Self) -> Self {
119        Self::And(self.into(), rhs.into())
120    }
121
122    fn or(self, rhs: Self) -> Self {
123        Self::Or(self.into(), rhs.into())
124    }
125}
126
127impl<V> Filter<V>
128where
129    V: std::fmt::Debug + Clone,
130{
131    pub fn interpret<F>(self) -> F
132    where
133        F: SearchFilter<Value = V>,
134    {
135        match self {
136            Self::Eq(key, val) => F::eq(key, val),
137            Self::Gt(key, val) => F::gt(key, val),
138            Self::Lt(key, val) => F::lt(key, val),
139            Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
140            Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
141        }
142    }
143}
144
145impl Filter<serde_json::Value> {
146    pub fn satisfies(&self, value: &serde_json::Value) -> bool {
147        use Filter::*;
148        use serde_json::{Value, Value::*, json};
149        use std::cmp::Ordering;
150
151        fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
152            match (l, r) {
153                (Number(l), Number(r)) => l
154                    .as_f64()
155                    .zip(r.as_f64())
156                    .and_then(|(l, r)| l.partial_cmp(&r))
157                    .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
158                    .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
159                (String(l), String(r)) => Some(l.cmp(r)),
160                (Null, Null) => Some(std::cmp::Ordering::Equal),
161                (Bool(l), Bool(r)) => Some(l.cmp(r)),
162                _ => None,
163            }
164        }
165
166        match self {
167            Eq(k, v) => &json!({ k: v }) == value,
168            Gt(k, v) => {
169                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
170            }
171            Lt(k, v) => {
172                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
173            }
174            And(l, r) => l.satisfies(value) && r.satisfies(value),
175            Or(l, r) => l.satisfies(value) || r.satisfies(value),
176        }
177    }
178}
179
180/// The builder struct to instantiate [`VectorSearchRequest`].
181#[derive(Clone, Serialize, Deserialize, Debug)]
182pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
183    query: Option<String>,
184    samples: Option<u64>,
185    threshold: Option<f64>,
186    additional_params: Option<serde_json::Value>,
187    filter: Option<F>,
188}
189
190impl<F> Default for VectorSearchRequestBuilder<F> {
191    fn default() -> Self {
192        Self {
193            query: None,
194            samples: None,
195            threshold: None,
196            additional_params: None,
197            filter: None,
198        }
199    }
200}
201
202impl<F> VectorSearchRequestBuilder<F>
203where
204    F: SearchFilter,
205{
206    /// Set the query (that will then be embedded )
207    pub fn query<T>(mut self, query: T) -> Self
208    where
209        T: Into<String>,
210    {
211        self.query = Some(query.into());
212        self
213    }
214
215    pub fn samples(mut self, samples: u64) -> Self {
216        self.samples = Some(samples);
217        self
218    }
219
220    pub fn threshold(mut self, threshold: f64) -> Self {
221        self.threshold = Some(threshold);
222        self
223    }
224
225    pub fn additional_params(
226        mut self,
227        params: serde_json::Value,
228    ) -> Result<Self, VectorStoreError> {
229        self.additional_params = Some(params);
230        Ok(self)
231    }
232
233    pub fn filter(mut self, filter: F) -> Self {
234        self.filter = Some(filter);
235        self
236    }
237
238    pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
239        let Some(query) = self.query else {
240            return Err(VectorStoreError::BuilderError(
241                "`query` is a required variable for building a vector search request".into(),
242            ));
243        };
244
245        let Some(samples) = self.samples else {
246            return Err(VectorStoreError::BuilderError(
247                "`samples` is a required variable for building a vector search request".into(),
248            ));
249        };
250
251        let additional_params = if let Some(params) = self.additional_params {
252            if !params.is_object() {
253                return Err(VectorStoreError::BuilderError(
254                    "Expected JSON object for additional params, got something else".into(),
255                ));
256            }
257            Some(params)
258        } else {
259            None
260        };
261
262        Ok(VectorSearchRequest {
263            query,
264            samples,
265            threshold: self.threshold,
266            additional_params,
267            filter: self.filter,
268        })
269    }
270}