Skip to main content

pallet/
search.rs

1use crate::{err, Document, DocumentLike, Store};
2use std::path::PathBuf;
3use std::sync::Mutex;
4
5mod as_query;
6mod field_value;
7mod params;
8mod scored_ids;
9
10pub use as_query::AsQuery;
11pub use params::Params;
12pub use scored_ids::{ScoredId, ScoredIds};
13
14// For use primarily by `pallet_macros`.
15#[doc(hidden)]
16pub use field_value::FieldValue;
17
18#[doc(hidden)]
19// For use primarily by `pallet_macros`.
20pub struct FieldsContainer(pub Vec<tantivy::schema::Field>);
21
22/// Wrapper around `tantivy::Index`, with additional search data
23pub struct Index<T> {
24    pub id_field: tantivy::schema::Field,
25    pub fields: T,
26    default_search_fields: Vec<tantivy::schema::Field>,
27    inner: tantivy::Index,
28    pub(crate) writer: Mutex<Option<tantivy::IndexWriter>>,
29    writer_accessor:
30        Box<dyn Fn(&tantivy::Index) -> tantivy::Result<tantivy::IndexWriter> + Send + Sync>,
31}
32
33impl<T> Index<T> {
34    /// Create a new builder
35    pub fn builder() -> IndexBuilder<T> {
36        IndexBuilder::default()
37    }
38    /// Get the query parser associated with index and default search fields.
39    pub fn query_parser(&self) -> tantivy::query::QueryParser {
40        tantivy::query::QueryParser::for_index(&self.inner, self.default_search_fields.clone())
41    }
42
43    pub(crate) fn with_writer<F, S, E>(&self, cls: F) -> Result<S, E>
44    where
45        F: Fn(&mut tantivy::IndexWriter) -> Result<S, E>,
46        E: From<err::Error>,
47    {
48        let mut lock = self.writer.lock().map_err(err::custom).map_err(E::from)?;
49
50        let mut writer = match lock.take() {
51            Some(writer) => writer,
52            None => {
53                (self.writer_accessor)(&self.inner).map_err(err::Error::Tantivy).map_err(E::from)?
54            }
55        };
56
57        let out = cls(&mut writer)?;
58
59        *lock = Some(writer);
60
61        Ok(out)
62    }
63}
64
65/// Builder for an `Index`
66pub struct IndexBuilder<T> {
67    fields_builder: Option<Box<dyn Fn(&mut tantivy::schema::SchemaBuilder) -> err::Result<T>>>,
68    default_search_fields_builder: Option<Box<dyn Fn(&T) -> Vec<tantivy::schema::Field>>>,
69    writer_accessor:
70        Option<Box<dyn Fn(&tantivy::Index) -> tantivy::Result<tantivy::IndexWriter> + Send + Sync>>,
71    index_dir: Option<PathBuf>,
72    config: Option<Box<dyn Fn(&mut tantivy::Index) -> tantivy::Result<()>>>,
73    id_field_name: Option<String>,
74}
75
76impl<T> Default for IndexBuilder<T> {
77    fn default() -> Self {
78        IndexBuilder {
79            fields_builder: None,
80            default_search_fields_builder: None,
81            writer_accessor: None,
82            index_dir: None,
83            config: None,
84            id_field_name: None,
85        }
86    }
87}
88
89impl<T> IndexBuilder<T> {
90    pub(crate) fn merge(self, other: Self) -> Self {
91        let IndexBuilder {
92            fields_builder: a1,
93            default_search_fields_builder: a2,
94            writer_accessor: a3,
95            index_dir: a4,
96            config: a5,
97            id_field_name: a6,
98        } = self;
99
100        let IndexBuilder {
101            fields_builder: b1,
102            default_search_fields_builder: b2,
103            writer_accessor: b3,
104            index_dir: b4,
105            config: b5,
106            id_field_name: b6,
107        } = other;
108
109        IndexBuilder {
110            fields_builder: a1.or(b1),
111            default_search_fields_builder: a2.or(b2),
112            writer_accessor: a3.or(b3),
113            index_dir: a4.or(b4),
114            config: a5.or(b5),
115            id_field_name: a6.or(b6),
116        }
117    }
118
119    /// Use the given directory (must exist) for the `tantivy::Index`.
120    pub fn with_index_dir<I: Into<PathBuf>>(mut self, index_dir: I) -> Self {
121        self.index_dir = Some(index_dir.into());
122        self
123    }
124
125    /// Define a custom way to get the `tantivy::IndexWriter`.
126    ///
127    /// By default will use `tantivy_index.writer(128_000_000)`.
128    pub fn with_writer_accessor<F>(mut self, writer_accessor: F) -> Self
129    where
130        F: Fn(&tantivy::Index) -> tantivy::Result<tantivy::IndexWriter> + Send + Sync + 'static,
131    {
132        self.writer_accessor = Some(Box::new(writer_accessor));
133        self
134    }
135
136    /// Custom configuration for the `tantivy::Index`.
137    ///
138    /// By default will use `tantivy_index.set_default_multithread_executor()?`.
139    pub fn with_config<F>(mut self, config: F) -> Self
140    where
141        F: Fn(&mut tantivy::Index) -> tantivy::Result<()> + 'static,
142    {
143        self.config = Some(Box::new(config));
144        self
145    }
146
147    /// Set the field name to be used for the datastore `id`.
148    ///
149    /// By default will use `__id__`.
150    pub fn with_id_field_name<I: Into<String>>(mut self, id_field_name: I) -> Self {
151        self.id_field_name = Some(id_field_name.into());
152        self
153    }
154
155    /// Handler that adds fields to a schema, and returns them in the fields container
156    pub fn with_fields_builder<F>(mut self, fields_builder: F) -> Self
157    where
158        F: Fn(&mut tantivy::schema::SchemaBuilder) -> err::Result<T> + 'static,
159    {
160        self.fields_builder = Some(Box::new(fields_builder));
161        self
162    }
163
164    /// Given the fields container, return fields that should be used in default search.
165    pub fn with_default_search_fields_builder<F>(mut self, default_search_fields_builder: F) -> Self
166    where
167        F: Fn(&T) -> Vec<tantivy::schema::Field> + 'static,
168    {
169        self.default_search_fields_builder = Some(Box::new(default_search_fields_builder));
170        self
171    }
172
173    /// Convert into finished `Index`
174    pub fn finish(self) -> err::Result<Index<T>> {
175        let fields_builder =
176            self.fields_builder.ok_or_else(|| err::custom("`fields_builder` not set"))?;
177
178        let index_dir = self.index_dir.ok_or_else(|| err::custom("`index_dir` not set"))?;
179
180        let mut schema_builder = tantivy::schema::SchemaBuilder::default();
181
182        let fields = fields_builder(&mut schema_builder)?;
183
184        let id_field = match self.id_field_name.as_ref() {
185            Some(id_field_name) => schema_builder
186                .add_u64_field(id_field_name, tantivy::schema::INDEXED | tantivy::schema::FAST),
187            None => schema_builder
188                .add_u64_field("__id__", tantivy::schema::INDEXED | tantivy::schema::FAST),
189        };
190
191        let schema = schema_builder.build();
192
193        let mmap_dir = tantivy::directory::MmapDirectory::open(&index_dir)
194            .map_err(tantivy::TantivyError::from)?;
195
196        let mut index = tantivy::Index::open_or_create(mmap_dir, schema)?;
197
198        if let Some(config) = self.config {
199            config(&mut index)?;
200        } else {
201            index.set_default_multithread_executor()?;
202        }
203
204        let writer_accessor =
205            self.writer_accessor.unwrap_or_else(|| Box::new(|idx| idx.writer(128_000_000)));
206
207        let default_search_fields =
208            if let Some(default_search_fields_builder) = self.default_search_fields_builder {
209                default_search_fields_builder(&fields)
210            } else {
211                Vec::new()
212            };
213
214        // let writer = writer_accessor(&index)?;
215
216        Ok(Index {
217            default_search_fields,
218            inner: index,
219            id_field,
220            fields,
221            writer_accessor,
222            writer: Mutex::new(None),
223        })
224    }
225}
226
227/// `Document` wrapper that includes the search query score
228#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
229pub struct Hit<T> {
230    pub score: f32,
231    pub doc: Document<T>,
232}
233
234/// Search results container, contains the `count` of returned results
235#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
236pub struct Results<T> {
237    pub count: usize,
238    pub hits: Vec<Hit<T>>,
239}
240
241/// Items that function as search parameters
242pub trait Searcher<T: DocumentLike> {
243    type Item;
244    type Error: From<err::Error>;
245    fn search(&self, store: &Store<T>) -> Result<Self::Item, Self::Error>;
246}
247
248impl<Q, C, H, O, T, E> Searcher<T> for Params<Q, params::Collector<C>, params::Handler<H>>
249where
250    Q: AsQuery,
251    E: From<err::Error>,
252    C: tantivy::collector::Collector,
253    H: Fn(C::Fruit) -> Result<O, E>,
254    T: DocumentLike,
255{
256    type Item = O;
257    type Error = E;
258
259    fn search(&self, store: &Store<T>) -> Result<Self::Item, Self::Error> {
260        let Self {
261            query: ref query_like,
262            collector: params::Collector(ref collector),
263            handler: params::Handler(ref handler),
264            ..
265        } = self;
266
267        let reader = store.index.inner.reader().map_err(err::Error::from)?;
268
269        let searcher = reader.searcher();
270
271        let query = query_like.as_query(&store.index.inner, &store.index.default_search_fields)?;
272
273        let fruit = searcher.search(query.as_ref(), collector).map_err(err::Error::from)?;
274
275        handler(fruit)
276    }
277}
278
279impl<Q, T> Searcher<T> for Q
280where
281    Q: AsQuery,
282    T: DocumentLike + Send,
283    T::IndexFieldsType: Sync,
284{
285    type Item = Results<T>;
286    type Error = err::Error;
287
288    fn search(&self, store: &Store<T>) -> Result<Self::Item, Self::Error> {
289        use rayon::prelude::*;
290
291        let scored_ids_handle = ScoredIds { size_hint: None, id_field: store.index.id_field };
292        let count_handle = tantivy::collector::Count;
293
294        let query = self.as_query(&store.index.inner, &store.index.default_search_fields)?;
295
296        let search_params = Params::default()
297            .with_query(query)
298            .with_collector((count_handle, scored_ids_handle))
299            .with_handler(|(count, scored_ids)| -> Result<_, err::Error> {
300                let hits = scored_ids
301                    .into_par_iter()
302                    .map(|ScoredId { id, score }| {
303                        store.find(id).map(|opt_doc| opt_doc.map(|doc| Hit { doc, score }))
304                    })
305                    .filter_map(Result::transpose)
306                    .collect::<err::Result<Vec<_>>>()?;
307
308                Ok(Results { count, hits })
309            });
310        Ok(search_params.search(store)?)
311    }
312}