Skip to main content

rig/vector_store/
request.rs

1//! Types for constructing vector search queries.
2//!
3//! - [`VectorSearchRequest`]: Query parameters (text, result count, threshold, filters).
4//! - [`SearchFilter`]: Trait for backend-agnostic filter expressions.
5//! - [`Filter`]: Canonical, serializable filter representation.
6
7use serde::{Deserialize, Serialize};
8
9use super::VectorStoreError;
10
11/// A vector search request for querying a [`super::VectorStoreIndex`].
12///
13/// The type parameter `F` specifies the filter type (defaults to [`Filter<serde_json::Value>`]).
14/// Use [`VectorSearchRequest::builder()`] to construct instances.
15#[derive(Clone, Serialize, Deserialize, Debug)]
16pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
17    /// The query text to embed and search with.
18    query: String,
19    /// Maximum number of results to return.
20    samples: u64,
21    /// Minimum similarity score for results.
22    threshold: Option<f64>,
23    /// Backend-specific parameters as a JSON object.
24    additional_params: Option<serde_json::Value>,
25    /// Filter expression to narrow results by metadata.
26    filter: Option<F>,
27}
28
29impl<Filter> VectorSearchRequest<Filter> {
30    /// Creates a [`VectorSearchRequestBuilder`] which you can use to instantiate this struct.
31    pub fn builder() -> VectorSearchRequestBuilder<Filter> {
32        VectorSearchRequestBuilder::<Filter>::default()
33    }
34
35    /// The query to be embedded and used in similarity search.
36    pub fn query(&self) -> &str {
37        &self.query
38    }
39
40    /// Returns the maximum number of results to return.
41    pub fn samples(&self) -> u64 {
42        self.samples
43    }
44
45    /// Returns the optional similarity threshold.
46    pub fn threshold(&self) -> Option<f64> {
47        self.threshold
48    }
49
50    /// Returns a reference to the optional filter expression.
51    pub fn filter(&self) -> &Option<Filter> {
52        &self.filter
53    }
54
55    /// Transforms the filter type using the provided function.
56    ///
57    /// This is useful for converting between filter representations, such as
58    /// translating the canonical [`super::request::Filter`] to a backend-specific filter type.
59    pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
60    where
61        F: Fn(Filter) -> T,
62    {
63        VectorSearchRequest {
64            query: self.query,
65            samples: self.samples,
66            threshold: self.threshold,
67            additional_params: self.additional_params,
68            filter: self.filter.map(f),
69        }
70    }
71
72    /// Transforms the filter type using a provided function which can additionally return a result.
73    ///
74    /// Useful for converting between filter representations where the conversion can potentially fail (eg, unrepresentable or invalid values).
75    pub fn try_map_filter<T, F>(self, f: F) -> Result<VectorSearchRequest<T>, FilterError>
76    where
77        F: Fn(Filter) -> Result<T, FilterError>,
78    {
79        let filter = self.filter.map(f).transpose()?;
80
81        Ok(VectorSearchRequest {
82            query: self.query,
83            samples: self.samples,
84            threshold: self.threshold,
85            additional_params: self.additional_params,
86            filter,
87        })
88    }
89}
90
91/// Errors from constructing or converting filter expressions.
92#[derive(Debug, Clone, thiserror::Error)]
93pub enum FilterError {
94    #[error("Expected: {expected}, got: {got}")]
95    Expected { expected: String, got: String },
96
97    #[error("Cannot compile '{0}' to the backend's filter type")]
98    TypeError(String),
99
100    #[error("Missing field '{0}'")]
101    MissingField(String),
102
103    #[error("'{0}' must {1}")]
104    Must(String, String),
105
106    // NOTE: Uses String because `serde_json::Error` is not `Clone`.
107    #[error("Filter serialization failed: {0}")]
108    Serialization(String),
109}
110
/// Backend-agnostic vocabulary for writing filter expressions in vector search queries.
///
/// Filters use the [tagless final](https://nrinaudo.github.io/articles/tagless_final.html)
/// encoding: each backend provides its own implementor, and queries are written against
/// this trait (e.g. `SearchFilter::eq(...)`) so that type inference selects the concrete
/// filter type.
pub trait SearchFilter {
    /// Type of the operand that comparisons are made against.
    type Value;

    /// Filter matching entries whose field `key` equals `value`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Filter matching entries whose field `key` is greater than `value`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Filter matching entries whose field `key` is less than `value`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self;

    /// Conjunction: matches only entries accepted by both `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;

    /// Disjunction: matches entries accepted by `self`, by `rhs`, or by both.
    fn or(self, rhs: Self) -> Self;
}
125
126/// Canonical, serializable filter representation.
127///
128/// Use for serialization, runtime inspection, or translating between backends via
129/// [`Filter::interpret`]. Prefer [`SearchFilter`] trait methods for writing queries.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131#[serde(rename_all = "lowercase")]
132pub enum Filter<V>
133where
134    V: std::fmt::Debug + Clone,
135{
136    Eq(String, V),
137    Gt(String, V),
138    Lt(String, V),
139    And(Box<Self>, Box<Self>),
140    Or(Box<Self>, Box<Self>),
141}
142
143impl<V> SearchFilter for Filter<V>
144where
145    V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
146{
147    type Value = V;
148
149    /// Select values where the entry at `key` is equal to `value`
150    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
151        Self::Eq(key.as_ref().to_owned(), value)
152    }
153
154    /// Select values where the entry at `key` is greater than `value`
155    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
156        Self::Gt(key.as_ref().to_owned(), value)
157    }
158
159    /// Select values where the entry at `key` is less than `value`
160    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
161        Self::Lt(key.as_ref().to_owned(), value)
162    }
163
164    /// Select values where the entry satisfies `self` *and* `rhs`
165    fn and(self, rhs: Self) -> Self {
166        Self::And(self.into(), rhs.into())
167    }
168
169    /// Select values where the entry satisfies `self` *or* `rhs`
170    fn or(self, rhs: Self) -> Self {
171        Self::Or(self.into(), rhs.into())
172    }
173}
174
175impl<V> Filter<V>
176where
177    V: std::fmt::Debug + Clone,
178{
179    /// Converts this filter into a backend-specific filter type.
180    pub fn interpret<F>(self) -> F
181    where
182        F: SearchFilter<Value = V>,
183    {
184        match self {
185            Self::Eq(key, val) => F::eq(key, val),
186            Self::Gt(key, val) => F::gt(key, val),
187            Self::Lt(key, val) => F::lt(key, val),
188            Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
189            Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
190        }
191    }
192}
193
194impl Filter<serde_json::Value> {
195    /// Tests whether a JSON value satisfies this filter.
196    pub fn satisfies(&self, value: &serde_json::Value) -> bool {
197        use Filter::*;
198        use serde_json::{Value, Value::*, json};
199        use std::cmp::Ordering;
200
201        fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
202            match (l, r) {
203                (Number(l), Number(r)) => l
204                    .as_f64()
205                    .zip(r.as_f64())
206                    .and_then(|(l, r)| l.partial_cmp(&r))
207                    .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
208                    .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
209                (String(l), String(r)) => Some(l.cmp(r)),
210                (Null, Null) => Some(std::cmp::Ordering::Equal),
211                (Bool(l), Bool(r)) => Some(l.cmp(r)),
212                _ => None,
213            }
214        }
215
216        match self {
217            Eq(k, v) => &json!({ k: v }) == value,
218            Gt(k, v) => {
219                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
220            }
221            Lt(k, v) => {
222                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
223            }
224            And(l, r) => l.satisfies(value) && r.satisfies(value),
225            Or(l, r) => l.satisfies(value) || r.satisfies(value),
226        }
227    }
228}
229
230/// Builder for [`VectorSearchRequest`]. Requires `query` and `samples`.
231#[derive(Clone, Serialize, Deserialize, Debug)]
232pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
233    query: Option<String>,
234    samples: Option<u64>,
235    threshold: Option<f64>,
236    additional_params: Option<serde_json::Value>,
237    filter: Option<F>,
238}
239
240impl<F> Default for VectorSearchRequestBuilder<F> {
241    fn default() -> Self {
242        Self {
243            query: None,
244            samples: None,
245            threshold: None,
246            additional_params: None,
247            filter: None,
248        }
249    }
250}
251
252impl<F> VectorSearchRequestBuilder<F>
253where
254    F: SearchFilter,
255{
256    /// Sets the query text. Required.
257    pub fn query<T>(mut self, query: T) -> Self
258    where
259        T: Into<String>,
260    {
261        self.query = Some(query.into());
262        self
263    }
264
265    /// Sets the maximum number of results. Required.
266    pub fn samples(mut self, samples: u64) -> Self {
267        self.samples = Some(samples);
268        self
269    }
270
271    /// Sets the minimum similarity threshold.
272    pub fn threshold(mut self, threshold: f64) -> Self {
273        self.threshold = Some(threshold);
274        self
275    }
276
277    /// Sets backend-specific parameters.
278    pub fn additional_params(
279        mut self,
280        params: serde_json::Value,
281    ) -> Result<Self, VectorStoreError> {
282        self.additional_params = Some(params);
283        Ok(self)
284    }
285
286    /// Sets a filter expression.
287    pub fn filter(mut self, filter: F) -> Self {
288        self.filter = Some(filter);
289        self
290    }
291
292    /// Builds the request, returning an error if required fields are missing.
293    pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
294        let Some(query) = self.query else {
295            return Err(VectorStoreError::BuilderError(
296                "`query` is a required variable for building a vector search request".into(),
297            ));
298        };
299
300        let Some(samples) = self.samples else {
301            return Err(VectorStoreError::BuilderError(
302                "`samples` is a required variable for building a vector search request".into(),
303            ));
304        };
305
306        let additional_params = if let Some(params) = self.additional_params {
307            if !params.is_object() {
308                return Err(VectorStoreError::BuilderError(
309                    "Expected JSON object for additional params, got something else".into(),
310                ));
311            }
312            Some(params)
313        } else {
314            None
315        };
316
317        Ok(VectorSearchRequest {
318            query,
319            samples,
320            threshold: self.threshold,
321            additional_params,
322            filter: self.filter,
323        })
324    }
325}