Skip to main content

flusso_query/
multi.rs

//! Combined search: one query over several indexes, hits ranked together.
//!
//! Where [`Client::msearch`](crate::Client::msearch) runs *independent*
//! searches in one round-trip (separate result lists), a [`MultiSearch`] runs
//! **one** query across every index a [`FlussoMultiDocument`] union spans and
//! returns a single, blended, relevance-ranked result list. Each hit decodes
//! into the union variant matching its physical `_index` — the sink writes
//! exactly `{INDEX}_{SCHEMA_HASH}`, so dispatch is precise, no alias involved.
//!
//! The union enum is yours: one single-field variant per document type, named
//! after the search surface it serves. `#[derive(FlussoMultiDocument)]` (the
//! `derive` feature) writes the impl; without it, a hand-written impl is two
//! short members — see the trait docs.
//!
//! Root-scope queries already compose across document types ([`Query<Root>`]
//! carries no document type), so any handle mix works in the builder. A field
//! unmapped in one of the indexes simply doesn't match there — but **sorting**
//! on it errors on the OpenSearch side unless the sort carries an
//! `unmapped_type`; prefer sorting on fields all indexes share (or relevance).
//!
//! [`Query<Root>`]: crate::Query
22
23use std::marker::PhantomData;
24use std::time::Duration;
25
26use serde::Deserialize;
27use serde_json::{Map, Value};
28
29use crate::Client;
30use crate::error::Result;
31use crate::handles::Sort;
32use crate::query::{AsQuery, BoolBuilder, Root};
33use crate::search::{Hit, RawCount, SearchResponse};
34
35/// A union of [`FlussoDocument`](crate::FlussoDocument) types searched
36/// together — one query, one blended result list, each hit decoded into the
37/// variant matching its index.
38///
39/// `#[derive(FlussoMultiDocument)]` (the `derive` feature) implements it for
40/// an enum with one single-field variant per document type. Without the
41/// derive, the impl is written by hand — exactly what the derive generates:
42///
43/// ```no_run
44/// use flusso_query::{FlussoDocument, FlussoIndex, FlussoMultiDocument, Error, Result, Segment};
45/// use serde_json::Value;
46/// # #[derive(serde::Deserialize)] struct User { email: String }
47/// # impl FlussoDocument for User { const PATH: &'static [Segment] = &[]; }
48/// # impl FlussoIndex for User {
49/// #     const INDEX: &'static str = "users";
50/// #     const SCHEMA_HASH: &'static str = "xxxxxx";
51/// # }
52/// # #[derive(serde::Deserialize)] struct Order { status: String }
53/// # impl FlussoDocument for Order { const PATH: &'static [Segment] = &[]; }
54/// # impl FlussoIndex for Order {
55/// #     const INDEX: &'static str = "orders";
56/// #     const SCHEMA_HASH: &'static str = "yyyyyy";
57/// # }
58///
59/// /// One item in the storefront's blended search — name it after the
60/// /// surface it serves, like your document structs.
61/// enum StoreItem {
62///     User(User),
63///     Order(Order),
64/// }
65///
66/// impl FlussoMultiDocument for StoreItem {
67///     const TARGETS: &'static [(&'static str, &'static str)] = &[
68///         (User::INDEX, User::SCHEMA_HASH),
69///         (Order::INDEX, Order::SCHEMA_HASH),
70///     ];
71///
72///     fn decode(physical_index: &str, source: Value) -> Result<Self> {
73///         if physical_index == User::physical_index() {
74///             return Ok(Self::User(serde_json::from_value(source)?));
75///         }
76///         if physical_index == Order::physical_index() {
77///             return Ok(Self::Order(serde_json::from_value(source)?));
78///         }
79///         Err(Error::UnexpectedIndex { index: physical_index.to_owned() })
80///     }
81/// }
82/// ```
83pub trait FlussoMultiDocument: Sized {
84    /// The `(logical index, schema hash)` pair of every document type in the
85    /// union, in variant order — each is that type's
86    /// [`INDEX`](crate::FlussoIndex::INDEX) /
87    /// [`SCHEMA_HASH`](crate::FlussoIndex::SCHEMA_HASH).
88    const TARGETS: &'static [(&'static str, &'static str)];
89
90    /// Decode one hit's `_source` into the right variant, dispatching on the
91    /// hit's physical index name. A hit from an index no variant claims is
92    /// [`Error::UnexpectedIndex`](crate::Error::UnexpectedIndex).
93    fn decode(physical_index: &str, source: Value) -> Result<Self>;
94
95    /// Start a typed query across all of this union's indexes. Like
96    /// [`FlussoIndex::query`](crate::FlussoIndex::query), the returned
97    /// builder is a plain client-free value.
98    fn query() -> MultiSearch<Self> {
99        MultiSearch::new()
100    }
101}
102
103/// A typed query across every index of a [`FlussoMultiDocument`] union — the
104/// blended counterpart of [`Search`](crate::Search), with the same clause
105/// builder and the same client-free shape.
106///
107/// Hits come back in **one** relevance-ranked list; `from`/`size` page that
108/// blended list, not each index. Terminals: [`send`](Self::send) for a typed
109/// page of union values, [`count`](Self::count) for the total matches across
110/// all the indexes.
111#[derive(Debug, Clone)]
112pub struct MultiSearch<U> {
113    /// The comma-joined physical index list the request addresses.
114    path: String,
115    bool_query: BoolBuilder,
116    raw: Option<Value>,
117    sort: Vec<Sort>,
118    from: Option<u64>,
119    size: Option<u64>,
120    _marker: PhantomData<fn() -> U>,
121}
122
123impl<U: FlussoMultiDocument> MultiSearch<U> {
124    /// Start a query across the union's indexes (usually via
125    /// [`FlussoMultiDocument::query`]).
126    #[must_use]
127    pub fn new() -> Self {
128        let path = U::TARGETS
129            .iter()
130            .map(|(index, hash)| format!("{index}_{hash}"))
131            .collect::<Vec<_>>()
132            .join(",");
133        Self {
134            path,
135            bool_query: BoolBuilder::default(),
136            raw: None,
137            sort: Vec::new(),
138            from: None,
139            size: None,
140            _marker: PhantomData,
141        }
142    }
143
144    /// A scoring clause (`bool.must`). Root-scope queries from *any* of the
145    /// union's document types compose here; a field unmapped in one index
146    /// simply doesn't match there. An absent clause adds nothing.
147    #[must_use]
148    pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
149        if let Some(query) = query.into_query() {
150            self.bool_query.push_must(query.into_inner());
151        }
152        self
153    }
154
155    /// A non-scoring, cacheable clause (`bool.filter`). An absent clause adds
156    /// nothing — so `filter(opt.map(|v| handle.eq(v)))` is a conditional filter.
157    #[must_use]
158    pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
159        if let Some(query) = query.into_query() {
160            self.bool_query.push_filter(query.into_inner());
161        }
162        self
163    }
164
165    /// An exclusion clause (`bool.must_not`). An absent clause excludes nothing.
166    #[must_use]
167    pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
168        if let Some(query) = query.into_query() {
169            self.bool_query.push_must_not(query.into_inner());
170        }
171        self
172    }
173
174    /// An optional, scoring clause (`bool.should`). An absent clause adds nothing.
175    #[must_use]
176    pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
177        if let Some(query) = query.into_query() {
178            self.bool_query.push_should(query.into_inner());
179        }
180        self
181    }
182
183    /// Append a sort key. It applies to the **blended** list, so the field
184    /// must exist in every index of the union (or carry an `unmapped_type` in
185    /// its options) — OpenSearch rejects a sort on a field one index lacks.
186    /// Relevance (no sort) is always safe.
187    #[must_use]
188    pub fn sort(mut self, sort: Sort) -> Self {
189        self.sort.push(sort);
190        self
191    }
192
193    /// Append several sort keys at once — e.g. from a
194    /// [`SortBuilder`](crate::SortBuilder). Equivalent to repeated [`sort`](Self::sort).
195    #[must_use]
196    pub fn sorts(mut self, sorts: impl IntoIterator<Item = Sort>) -> Self {
197        self.sort.extend(sorts);
198        self
199    }
200
201    /// Offset of the first hit to return, in the blended list.
202    #[must_use]
203    pub fn from(mut self, from: u64) -> Self {
204        self.from = Some(from);
205        self
206    }
207
208    /// Maximum number of hits to return, across all the indexes combined.
209    #[must_use]
210    pub fn size(mut self, size: u64) -> Self {
211        self.size = Some(size);
212        self
213    }
214
215    /// Replace the query body with a raw OpenSearch query DSL value. The
216    /// pressure-release valve, as on [`Search`](crate::Search); hits still
217    /// decode into the union.
218    #[must_use]
219    pub fn raw(mut self, query: Value) -> Self {
220        self.raw = Some(query);
221        self
222    }
223
224    /// The comma-joined physical index list this query addresses — one
225    /// `{index}_{hash}` per union variant. For logging and debugging.
226    #[must_use]
227    pub fn physical_path(&self) -> &str {
228        &self.path
229    }
230
231    /// The accumulated query alone: the raw override, the bool clauses, or
232    /// `match_all` when nothing was added.
233    fn query_value(&self) -> Value {
234        match &self.raw {
235            Some(raw) => raw.clone(),
236            None if self.bool_query.is_empty() => crate::handles::match_all_value(),
237            None => self.bool_query.to_value(),
238        }
239    }
240
241    /// The request body this search will POST to `_search`. Pure — useful for
242    /// tests and debugging.
243    #[must_use]
244    pub fn body(&self) -> Value {
245        let mut root = Map::new();
246        root.insert("query".to_string(), self.query_value());
247        if !self.sort.is_empty() {
248            let keys = self.sort.iter().map(Sort::to_value).collect();
249            root.insert("sort".to_string(), Value::Array(keys));
250        }
251        if let Some(from) = self.from {
252            root.insert("from".to_string(), Value::from(from));
253        }
254        if let Some(size) = self.size {
255            root.insert("size".to_string(), Value::from(size));
256        }
257        Value::Object(root)
258    }
259
260    /// The request body [`count`](Self::count) will POST to `_count`: just
261    /// the query (as on [`Search::count_body`](crate::Search::count_body)).
262    #[must_use]
263    pub fn count_body(&self) -> Value {
264        let mut root = Map::new();
265        root.insert("query".to_string(), self.query_value());
266        Value::Object(root)
267    }
268
269    /// Execute the search and decode the blended hits into the union.
270    #[tracing::instrument(
271        name = "search.multi",
272        skip_all,
273        fields(
274            path = %self.path,
275            from = ?self.from,
276            size = ?self.size,
277            total = tracing::field::Empty,
278            took_ms = tracing::field::Empty,
279        ),
280        err,
281    )]
282    pub async fn send(&self, client: &Client) -> Result<SearchResponse<U>> {
283        let body = self.body();
284        let response = client.search_at(&self.path, &body).await?;
285        let page = decode_response::<U>(response, &client.index_prefix)?;
286        let span = tracing::Span::current();
287        span.record("total", page.total);
288        span.record("took_ms", page.took.as_millis() as u64);
289        tracing::debug!(
290            total = page.total,
291            hits = page.hits.len(),
292            "combined search completed"
293        );
294        Ok(page)
295    }
296
297    /// Count the matches across all the union's indexes, without fetching
298    /// any hits.
299    #[tracing::instrument(
300        name = "search.multi_count",
301        skip_all,
302        fields(path = %self.path, count = tracing::field::Empty),
303        err,
304    )]
305    pub async fn count(&self, client: &Client) -> Result<u64> {
306        let body = self.count_body();
307        let response = client.count_at(&self.path, &body).await?;
308        let raw: RawCount = serde_json::from_value(response)?;
309        tracing::Span::current().record("count", raw.count);
310        tracing::debug!(count = raw.count, "combined count completed");
311        Ok(raw.count)
312    }
313}
314
315impl<U: FlussoMultiDocument> Default for MultiSearch<U> {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321/// Decode a combined `_search` response: the usual envelope, but each hit's
322/// `_source` is dispatched by the hit's `_index` into the union. `prefix` (the
323/// client's index prefix) is stripped from each hit's `_index` first, so
324/// dispatch matches the union's unprefixed `physical_index()` — empty for an
325/// unprefixed deployment.
326pub(crate) fn decode_response<U: FlussoMultiDocument>(
327    value: Value,
328    prefix: &str,
329) -> Result<SearchResponse<U>> {
330    let raw: RawMultiResponse = serde_json::from_value(value)?;
331    let hits = raw
332        .hits
333        .hits
334        .into_iter()
335        .map(|hit| {
336            let index = hit.index.strip_prefix(prefix).unwrap_or(&hit.index);
337            Ok(Hit {
338                id: hit.id,
339                score: hit.score.unwrap_or(0.0),
340                source: U::decode(index, hit.source)?,
341            })
342        })
343        .collect::<Result<Vec<_>>>()?;
344    Ok(SearchResponse {
345        total: raw.hits.total.value,
346        max_score: raw.hits.max_score,
347        hits,
348        took: Duration::from_millis(raw.took),
349    })
350}
351
352#[derive(Deserialize)]
353struct RawMultiResponse {
354    #[serde(default)]
355    took: u64,
356    hits: RawMultiHits,
357}
358
359#[derive(Deserialize)]
360struct RawMultiHits {
361    total: RawMultiTotal,
362    #[serde(default)]
363    max_score: Option<f32>,
364    hits: Vec<RawMultiHit>,
365}
366
367#[derive(Deserialize)]
368struct RawMultiTotal {
369    value: u64,
370}
371
372#[derive(Deserialize)]
373struct RawMultiHit {
374    #[serde(rename = "_index")]
375    index: String,
376    #[serde(rename = "_id")]
377    id: String,
378    #[serde(rename = "_score", default)]
379    score: Option<f32>,
380    #[serde(rename = "_source")]
381    source: Value,
382}