Skip to main content

rig/vector_store/
request.rs

1//! Types for constructing vector search queries.
2//!
3//! - [`VectorSearchRequest`]: Query parameters (text, result count, threshold, filters).
4//! - [`SearchFilter`]: Trait for backend-agnostic filter expressions.
5//! - [`Filter`]: Canonical, serializable filter representation.
6
7use serde::{Deserialize, Serialize};
8
9use super::VectorStoreError;
10use crate::markers::{Missing, Provided};
11
12/// A vector search request for querying a [`super::VectorStoreIndex`].
13///
14/// The type parameter `F` specifies the filter type (defaults to [`Filter<serde_json::Value>`]).
15/// Use [`VectorSearchRequest::builder()`] to construct instances.
16#[derive(Clone, Serialize, Deserialize, Debug)]
17pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
18    /// The query text to embed and search with.
19    query: String,
20    /// Maximum number of results to return.
21    samples: u64,
22    /// Minimum similarity score for results.
23    threshold: Option<f64>,
24    /// Backend-specific parameters as a JSON object.
25    additional_params: Option<serde_json::Value>,
26    /// Filter expression to narrow results by metadata.
27    filter: Option<F>,
28}
29
30impl<Filter> VectorSearchRequest<Filter> {
31    /// Creates a [`VectorSearchRequestBuilder`] which you can use to instantiate this struct.
32    pub fn builder() -> VectorSearchRequestBuilder<Filter> {
33        VectorSearchRequestBuilder::<Filter>::default()
34    }
35
36    /// The query to be embedded and used in similarity search.
37    pub fn query(&self) -> &str {
38        &self.query
39    }
40
41    /// Returns the maximum number of results to return.
42    pub fn samples(&self) -> u64 {
43        self.samples
44    }
45
46    /// Returns the optional similarity threshold.
47    pub fn threshold(&self) -> Option<f64> {
48        self.threshold
49    }
50
51    /// Returns a reference to the optional filter expression.
52    pub fn filter(&self) -> &Option<Filter> {
53        &self.filter
54    }
55
56    /// Transforms the filter type using the provided function.
57    ///
58    /// This is useful for converting between filter representations, such as
59    /// translating the canonical [`super::request::Filter`] to a backend-specific filter type.
60    pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
61    where
62        F: Fn(Filter) -> T,
63    {
64        VectorSearchRequest {
65            query: self.query,
66            samples: self.samples,
67            threshold: self.threshold,
68            additional_params: self.additional_params,
69            filter: self.filter.map(f),
70        }
71    }
72
73    /// Transforms the filter type using a provided function which can additionally return a result.
74    ///
75    /// Useful for converting between filter representations where the conversion can potentially fail (eg, unrepresentable or invalid values).
76    pub fn try_map_filter<T, F>(self, f: F) -> Result<VectorSearchRequest<T>, FilterError>
77    where
78        F: Fn(Filter) -> Result<T, FilterError>,
79    {
80        let filter = self.filter.map(f).transpose()?;
81
82        Ok(VectorSearchRequest {
83            query: self.query,
84            samples: self.samples,
85            threshold: self.threshold,
86            additional_params: self.additional_params,
87            filter,
88        })
89    }
90}
91
92/// Errors from constructing or converting filter expressions.
93#[derive(Debug, Clone, thiserror::Error)]
94pub enum FilterError {
95    #[error("Expected: {expected}, got: {got}")]
96    Expected { expected: String, got: String },
97
98    #[error("Cannot compile '{0}' to the backend's filter type")]
99    TypeError(String),
100
101    #[error("Missing field '{0}'")]
102    MissingField(String),
103
104    #[error("'{0}' must {1}")]
105    Must(String, String),
106
107    // NOTE: Uses String because `serde_json::Error` is not `Clone`.
108    #[error("Filter serialization failed: {0}")]
109    Serialization(String),
110}
111
/// Constructors and combinators for filter expressions used in vector search queries.
///
/// The trait follows a [tagless final](https://nrinaudo.github.io/articles/tagless_final.html)
/// encoding so that filters stay backend-agnostic: call `SearchFilter::eq(...)` and
/// friends directly and let type inference pick the concrete filter type.
pub trait SearchFilter {
    /// The type of the values that keys are compared against.
    type Value;

    /// Builds a filter matching entries whose value at `key` equals `value`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Builds a filter matching entries whose value at `key` is greater than `value`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Builds a filter matching entries whose value at `key` is less than `value`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Combines two filters so that both `self` and `rhs` must hold.
    fn and(self, rhs: Self) -> Self;
    /// Combines two filters so that at least one of `self` and `rhs` must hold.
    fn or(self, rhs: Self) -> Self;
}
126
127/// Canonical, serializable filter representation.
128///
129/// Use for serialization, runtime inspection, or translating between backends via
130/// [`Filter::interpret`]. Prefer [`SearchFilter`] trait methods for writing queries.
131#[derive(Debug, Clone, Serialize, Deserialize)]
132#[serde(rename_all = "lowercase")]
133pub enum Filter<V>
134where
135    V: std::fmt::Debug + Clone,
136{
137    Eq(String, V),
138    Gt(String, V),
139    Lt(String, V),
140    And(Box<Self>, Box<Self>),
141    Or(Box<Self>, Box<Self>),
142}
143
144impl<V> SearchFilter for Filter<V>
145where
146    V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
147{
148    type Value = V;
149
150    /// Select values where the entry at `key` is equal to `value`
151    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
152        Self::Eq(key.as_ref().to_owned(), value)
153    }
154
155    /// Select values where the entry at `key` is greater than `value`
156    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
157        Self::Gt(key.as_ref().to_owned(), value)
158    }
159
160    /// Select values where the entry at `key` is less than `value`
161    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
162        Self::Lt(key.as_ref().to_owned(), value)
163    }
164
165    /// Select values where the entry satisfies `self` *and* `rhs`
166    fn and(self, rhs: Self) -> Self {
167        Self::And(self.into(), rhs.into())
168    }
169
170    /// Select values where the entry satisfies `self` *or* `rhs`
171    fn or(self, rhs: Self) -> Self {
172        Self::Or(self.into(), rhs.into())
173    }
174}
175
176impl<V> Filter<V>
177where
178    V: std::fmt::Debug + Clone,
179{
180    /// Converts this filter into a backend-specific filter type.
181    pub fn interpret<F>(self) -> F
182    where
183        F: SearchFilter<Value = V>,
184    {
185        match self {
186            Self::Eq(key, val) => F::eq(key, val),
187            Self::Gt(key, val) => F::gt(key, val),
188            Self::Lt(key, val) => F::lt(key, val),
189            Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
190            Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
191        }
192    }
193}
194
195impl Filter<serde_json::Value> {
196    /// Tests whether a JSON value satisfies this filter.
197    pub fn satisfies(&self, value: &serde_json::Value) -> bool {
198        use Filter::*;
199        use serde_json::{Value, Value::*, json};
200        use std::cmp::Ordering;
201
202        fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
203            match (l, r) {
204                (Number(l), Number(r)) => l
205                    .as_f64()
206                    .zip(r.as_f64())
207                    .and_then(|(l, r)| l.partial_cmp(&r))
208                    .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
209                    .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
210                (String(l), String(r)) => Some(l.cmp(r)),
211                (Null, Null) => Some(std::cmp::Ordering::Equal),
212                (Bool(l), Bool(r)) => Some(l.cmp(r)),
213                _ => None,
214            }
215        }
216
217        match self {
218            Eq(k, v) => &json!({ k: v }) == value,
219            Gt(k, v) => {
220                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
221            }
222            Lt(k, v) => {
223                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
224            }
225            And(l, r) => l.satisfies(value) && r.satisfies(value),
226            Or(l, r) => l.satisfies(value) || r.satisfies(value),
227        }
228    }
229}
230
231/// Builder for [`VectorSearchRequest`]. Requires `query` and `samples`.
232#[derive(Clone, Serialize, Deserialize, Debug)]
233pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>, Q = Missing, S = Missing> {
234    query: Q,
235    samples: S,
236    threshold: Option<f64>,
237    additional_params: Option<serde_json::Value>,
238    filter: Option<F>,
239}
240
241impl<F> Default for VectorSearchRequestBuilder<F, Missing, Missing> {
242    fn default() -> Self {
243        Self {
244            query: Missing,
245            samples: Missing,
246            threshold: None,
247            additional_params: None,
248            filter: None,
249        }
250    }
251}
252
253impl<F, Q, S> VectorSearchRequestBuilder<F, Q, S>
254where
255    F: SearchFilter,
256{
257    /// Sets the query text. Required.
258    pub fn query<T>(self, query: T) -> VectorSearchRequestBuilder<F, Provided<String>, S>
259    where
260        T: Into<String>,
261    {
262        VectorSearchRequestBuilder {
263            query: Provided(query.into()),
264            samples: self.samples,
265            threshold: self.threshold,
266            additional_params: self.additional_params,
267            filter: self.filter,
268        }
269    }
270
271    /// Sets the maximum number of results. Required.
272    pub fn samples(self, samples: u64) -> VectorSearchRequestBuilder<F, Q, Provided<u64>> {
273        VectorSearchRequestBuilder {
274            query: self.query,
275            samples: Provided(samples),
276            threshold: self.threshold,
277            additional_params: self.additional_params,
278            filter: self.filter,
279        }
280    }
281
282    /// Sets the minimum similarity threshold.
283    pub fn threshold(mut self, threshold: f64) -> Self {
284        self.threshold = Some(threshold);
285        self
286    }
287
288    /// Sets backend-specific parameters.
289    pub fn additional_params(
290        mut self,
291        params: serde_json::Value,
292    ) -> Result<Self, VectorStoreError> {
293        self.additional_params = Some(params);
294        Ok(self)
295    }
296
297    /// Sets a filter expression.
298    pub fn filter(mut self, filter: F) -> Self {
299        self.filter = Some(filter);
300        self
301    }
302}
303
304/// Only implement `build()` when both `query` and `samples` have been provided.
305impl<F> VectorSearchRequestBuilder<F, Provided<String>, Provided<u64>> {
306    /// Builds the request
307    pub fn build(self) -> VectorSearchRequest<F> {
308        VectorSearchRequest {
309            query: self.query.0,
310            samples: self.samples.0,
311            threshold: self.threshold,
312            additional_params: self.additional_params,
313            filter: self.filter,
314        }
315    }
316}