rig/vector_store/
request.rs

1//! Types for constructing vector search queries.
2//!
3//! - [`VectorSearchRequest`]: Query parameters (text, result count, threshold, filters).
4//! - [`SearchFilter`]: Trait for backend-agnostic filter expressions.
5//! - [`Filter`]: Canonical, serializable filter representation.
6
7use serde::{Deserialize, Serialize};
8
9use super::VectorStoreError;
10
/// A vector search request for querying a [`super::VectorStoreIndex`].
///
/// The type parameter `F` specifies the filter type (defaults to [`Filter<serde_json::Value>`]).
/// Use [`VectorSearchRequest::builder()`] to construct instances.
///
/// All fields are private, so outside this module a request can only be
/// obtained through the builder or through deserialization.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct VectorSearchRequest<F = Filter<serde_json::Value>> {
    /// The query text to embed and search with.
    query: String,
    /// Maximum number of results to return.
    samples: u64,
    /// Minimum similarity score for results; `None` when no threshold was set.
    threshold: Option<f64>,
    /// Backend-specific parameters as a JSON object.
    ///
    /// The builder's `build` method rejects any non-object value, so a request
    /// produced by the builder holds either `None` or a JSON object here.
    additional_params: Option<serde_json::Value>,
    /// Filter expression to narrow results by metadata; `None` when unfiltered.
    filter: Option<F>,
}
28
29impl<Filter> VectorSearchRequest<Filter> {
30    /// Creates a [`VectorSearchRequestBuilder`] which you can use to instantiate this struct.
31    pub fn builder() -> VectorSearchRequestBuilder<Filter> {
32        VectorSearchRequestBuilder::<Filter>::default()
33    }
34
35    /// The query to be embedded and used in similarity search.
36    pub fn query(&self) -> &str {
37        &self.query
38    }
39
40    /// Returns the maximum number of results to return.
41    pub fn samples(&self) -> u64 {
42        self.samples
43    }
44
45    /// Returns the optional similarity threshold.
46    pub fn threshold(&self) -> Option<f64> {
47        self.threshold
48    }
49
50    /// Returns a reference to the optional filter expression.
51    pub fn filter(&self) -> &Option<Filter> {
52        &self.filter
53    }
54
55    /// Transforms the filter type using the provided function.
56    ///
57    /// This is useful for converting between filter representations, such as
58    /// translating the canonical [`super::request::Filter`] to a backend-specific filter type.
59    pub fn map_filter<T, F>(self, f: F) -> VectorSearchRequest<T>
60    where
61        F: Fn(Filter) -> T,
62    {
63        VectorSearchRequest {
64            query: self.query,
65            samples: self.samples,
66            threshold: self.threshold,
67            additional_params: self.additional_params,
68            filter: self.filter.map(f),
69        }
70    }
71}
72
/// Errors from constructing or converting filter expressions.
///
/// Every payload is a `String`, which keeps the enum `Clone`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum FilterError {
    /// One thing was expected but another was found; both are carried as
    /// already-formatted text.
    #[error("Expected: {expected}, got: {got}")]
    Expected { expected: String, got: String },

    /// The named item cannot be compiled to the backend's filter type.
    #[error("Cannot compile '{0}' to the backend's filter type")]
    TypeError(String),

    /// A required field, identified by name, was absent.
    #[error("Missing field '{0}'")]
    MissingField(String),

    /// A requirement was violated: the first string names the subject and the
    /// second completes the sentence "`<subject>` must ...".
    #[error("'{0}' must {1}")]
    Must(String, String),

    /// Serializing a filter failed; carries the underlying error's message.
    // NOTE: Uses String because `serde_json::Error` is not `Clone`.
    #[error("Filter serialization failed: {0}")]
    Serialization(String),
}
92
/// Trait for constructing filter expressions in vector search queries.
///
/// Uses [tagless final](https://nrinaudo.github.io/articles/tagless_final.html) encoding
/// for backend-agnostic filters. Use `SearchFilter::eq(...)` etc. directly and let
/// type inference resolve the concrete filter type.
pub trait SearchFilter {
    /// The type of the values that leaf predicates compare entries against.
    type Value;

    /// Select values where the entry at `key` is equal to `value`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Select values where the entry at `key` is greater than `value`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Select values where the entry at `key` is less than `value`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self;
    /// Select values where the entry satisfies `self` *and* `rhs`.
    fn and(self, rhs: Self) -> Self;
    /// Select values where the entry satisfies `self` *or* `rhs`.
    fn or(self, rhs: Self) -> Self;
}
107
/// Canonical, serializable filter representation.
///
/// Use for serialization, runtime inspection, or translating between backends via
/// [`Filter::interpret`]. Prefer [`SearchFilter`] trait methods for writing queries.
///
/// With `rename_all = "lowercase"`, serde uses the variant names `eq`, `gt`,
/// `lt`, `and` and `or` in the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Filter<V>
where
    V: std::fmt::Debug + Clone,
{
    /// Leaf predicate: the entry at the given key equals the value.
    Eq(String, V),
    /// Leaf predicate: the entry at the given key is greater than the value.
    Gt(String, V),
    /// Leaf predicate: the entry at the given key is less than the value.
    Lt(String, V),
    /// Conjunction: both sub-filters must hold. Boxed because the type is recursive.
    And(Box<Self>, Box<Self>),
    /// Disjunction: at least one sub-filter must hold. Boxed because the type is recursive.
    Or(Box<Self>, Box<Self>),
}
124
125impl<V> SearchFilter for Filter<V>
126where
127    V: std::fmt::Debug + Clone + Serialize + for<'de> Deserialize<'de>,
128{
129    type Value = V;
130
131    /// Select values where the entry at `key` is equal to `value`
132    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
133        Self::Eq(key.as_ref().to_owned(), value)
134    }
135
136    /// Select values where the entry at `key` is greater than `value`
137    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
138        Self::Gt(key.as_ref().to_owned(), value)
139    }
140
141    /// Select values where the entry at `key` is less than `value`
142    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
143        Self::Lt(key.as_ref().to_owned(), value)
144    }
145
146    /// Select values where the entry satisfies `self` *and* `rhs`
147    fn and(self, rhs: Self) -> Self {
148        Self::And(self.into(), rhs.into())
149    }
150
151    /// Select values where the entry satisfies `self` *or* `rhs`
152    fn or(self, rhs: Self) -> Self {
153        Self::Or(self.into(), rhs.into())
154    }
155}
156
157impl<V> Filter<V>
158where
159    V: std::fmt::Debug + Clone,
160{
161    /// Converts this filter into a backend-specific filter type.
162    pub fn interpret<F>(self) -> F
163    where
164        F: SearchFilter<Value = V>,
165    {
166        match self {
167            Self::Eq(key, val) => F::eq(key, val),
168            Self::Gt(key, val) => F::gt(key, val),
169            Self::Lt(key, val) => F::lt(key, val),
170            Self::And(lhs, rhs) => F::and(lhs.interpret(), rhs.interpret()),
171            Self::Or(lhs, rhs) => F::or(lhs.interpret(), rhs.interpret()),
172        }
173    }
174}
175
176impl Filter<serde_json::Value> {
177    /// Tests whether a JSON value satisfies this filter.
178    pub fn satisfies(&self, value: &serde_json::Value) -> bool {
179        use Filter::*;
180        use serde_json::{Value, Value::*, json};
181        use std::cmp::Ordering;
182
183        fn compare_pair(l: &Value, r: &Value) -> Option<std::cmp::Ordering> {
184            match (l, r) {
185                (Number(l), Number(r)) => l
186                    .as_f64()
187                    .zip(r.as_f64())
188                    .and_then(|(l, r)| l.partial_cmp(&r))
189                    .or(l.as_i64().zip(r.as_i64()).map(|(l, r)| l.cmp(&r)))
190                    .or(l.as_u64().zip(r.as_u64()).map(|(l, r)| l.cmp(&r))),
191                (String(l), String(r)) => Some(l.cmp(r)),
192                (Null, Null) => Some(std::cmp::Ordering::Equal),
193                (Bool(l), Bool(r)) => Some(l.cmp(r)),
194                _ => None,
195            }
196        }
197
198        match self {
199            Eq(k, v) => &json!({ k: v }) == value,
200            Gt(k, v) => {
201                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Greater)
202            }
203            Lt(k, v) => {
204                compare_pair(&json!({k: v}), value).is_some_and(|ord| ord == Ordering::Less)
205            }
206            And(l, r) => l.satisfies(value) && r.satisfies(value),
207            Or(l, r) => l.satisfies(value) || r.satisfies(value),
208        }
209    }
210}
211
/// Builder for [`VectorSearchRequest`]. Requires `query` and `samples`.
///
/// Each field mirrors the field of the same name on [`VectorSearchRequest`],
/// wrapped in `Option` where the request field is mandatory so that `build`
/// can detect values that were never set.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct VectorSearchRequestBuilder<F = Filter<serde_json::Value>> {
    /// Query text; must be set before `build`.
    query: Option<String>,
    /// Maximum number of results; must be set before `build`.
    samples: Option<u64>,
    /// Optional minimum similarity threshold.
    threshold: Option<f64>,
    /// Optional backend-specific parameters; `build` requires a JSON object.
    additional_params: Option<serde_json::Value>,
    /// Optional filter expression.
    filter: Option<F>,
}
221
222impl<F> Default for VectorSearchRequestBuilder<F> {
223    fn default() -> Self {
224        Self {
225            query: None,
226            samples: None,
227            threshold: None,
228            additional_params: None,
229            filter: None,
230        }
231    }
232}
233
234impl<F> VectorSearchRequestBuilder<F>
235where
236    F: SearchFilter,
237{
238    /// Sets the query text. Required.
239    pub fn query<T>(mut self, query: T) -> Self
240    where
241        T: Into<String>,
242    {
243        self.query = Some(query.into());
244        self
245    }
246
247    /// Sets the maximum number of results. Required.
248    pub fn samples(mut self, samples: u64) -> Self {
249        self.samples = Some(samples);
250        self
251    }
252
253    /// Sets the minimum similarity threshold.
254    pub fn threshold(mut self, threshold: f64) -> Self {
255        self.threshold = Some(threshold);
256        self
257    }
258
259    /// Sets backend-specific parameters.
260    pub fn additional_params(
261        mut self,
262        params: serde_json::Value,
263    ) -> Result<Self, VectorStoreError> {
264        self.additional_params = Some(params);
265        Ok(self)
266    }
267
268    /// Sets a filter expression.
269    pub fn filter(mut self, filter: F) -> Self {
270        self.filter = Some(filter);
271        self
272    }
273
274    /// Builds the request, returning an error if required fields are missing.
275    pub fn build(self) -> Result<VectorSearchRequest<F>, VectorStoreError> {
276        let Some(query) = self.query else {
277            return Err(VectorStoreError::BuilderError(
278                "`query` is a required variable for building a vector search request".into(),
279            ));
280        };
281
282        let Some(samples) = self.samples else {
283            return Err(VectorStoreError::BuilderError(
284                "`samples` is a required variable for building a vector search request".into(),
285            ));
286        };
287
288        let additional_params = if let Some(params) = self.additional_params {
289            if !params.is_object() {
290                return Err(VectorStoreError::BuilderError(
291                    "Expected JSON object for additional params, got something else".into(),
292                ));
293            }
294            Some(params)
295        } else {
296            None
297        };
298
299        Ok(VectorSearchRequest {
300            query,
301            samples,
302            threshold: self.threshold,
303            additional_params,
304            filter: self.filter,
305        })
306    }
307}