Skip to main content

mdql_core/
search.rs

1//! Full-text search on section content using Tantivy.
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use tantivy::collector::TopDocs;
7use tantivy::query::QueryParser;
8use tantivy::schema::{Schema as TantivySchema, Value as TantivyValue, STORED, TEXT};
9use tantivy::{doc, Index, IndexWriter, ReloadPolicy};
10
11use crate::errors::{MdqlError, Result};
12use crate::model::{Row, Value};
13
14/// Full-text search engine for a single table's section content.
15pub struct TableSearcher {
16    index: Index,
17    schema: TantivySchema,
18    path_field: tantivy::schema::Field,
19    section_fields: HashMap<String, tantivy::schema::Field>,
20}
21
22impl TableSearcher {
23    /// Build a Tantivy index from rows.
24    /// `section_names` are the section column names to index.
25    pub fn build(rows: &[Row], section_names: &[String]) -> Result<Self> {
26        let mut schema_builder = TantivySchema::builder();
27
28        let path_field = schema_builder.add_text_field("_path", STORED);
29        let mut section_fields = HashMap::new();
30        for name in section_names {
31            let field = schema_builder.add_text_field(name, TEXT | STORED);
32            section_fields.insert(name.clone(), field);
33        }
34        // A combined "all sections" field for unqualified searches
35        let all_field = schema_builder.add_text_field("_all", TEXT);
36
37        let schema = schema_builder.build();
38        let index = Index::create_in_ram(schema.clone());
39
40        let mut writer: IndexWriter = index
41            .writer(50_000_000)
42            .map_err(|e| MdqlError::General(format!("Tantivy writer error: {}", e)))?;
43
44        for row in rows {
45            let path = match row.get("path") {
46                Some(Value::String(p)) => p.clone(),
47                _ => continue,
48            };
49
50            let mut document = doc!(path_field => path);
51            let mut all_text = String::new();
52
53            for (name, &field) in &section_fields {
54                if let Some(Value::String(content)) = row.get(name) {
55                    document.add_text(field, content);
56                    all_text.push_str(content);
57                    all_text.push('\n');
58                }
59            }
60
61            document.add_text(all_field, &all_text);
62            writer
63                .add_document(document)
64                .map_err(|e| MdqlError::General(format!("Tantivy add error: {}", e)))?;
65        }
66
67        writer
68            .commit()
69            .map_err(|e| MdqlError::General(format!("Tantivy commit error: {}", e)))?;
70
71        Ok(TableSearcher {
72            index,
73            schema,
74            path_field,
75            section_fields,
76        })
77    }
78
79    /// Build from a table directory (stores index on disk for persistence).
80    pub fn build_on_disk(
81        rows: &[Row],
82        section_names: &[String],
83        index_dir: &Path,
84    ) -> Result<Self> {
85        std::fs::create_dir_all(index_dir)?;
86
87        let mut schema_builder = TantivySchema::builder();
88        let path_field = schema_builder.add_text_field("_path", STORED);
89        let mut section_fields = HashMap::new();
90        for name in section_names {
91            let field = schema_builder.add_text_field(name, TEXT | STORED);
92            section_fields.insert(name.clone(), field);
93        }
94        let all_field = schema_builder.add_text_field("_all", TEXT);
95
96        let schema = schema_builder.build();
97
98        // If an existing index exists, remove it and rebuild
99        let index = if Index::open_in_dir(index_dir).is_ok() {
100            // Remove old index
101            std::fs::remove_dir_all(index_dir)?;
102            std::fs::create_dir_all(index_dir)?;
103            Index::create_in_dir(index_dir, schema.clone())
104                .map_err(|e| MdqlError::General(format!("Tantivy create error: {}", e)))?
105        } else {
106            Index::create_in_dir(index_dir, schema.clone())
107                .map_err(|e| MdqlError::General(format!("Tantivy create error: {}", e)))?
108        };
109
110        let mut writer: IndexWriter = index
111            .writer(50_000_000)
112            .map_err(|e| MdqlError::General(format!("Tantivy writer error: {}", e)))?;
113
114        for row in rows {
115            let path = match row.get("path") {
116                Some(Value::String(p)) => p.clone(),
117                _ => continue,
118            };
119
120            let mut document = doc!(path_field => path);
121            let mut all_text = String::new();
122
123            for (name, &field) in &section_fields {
124                if let Some(Value::String(content)) = row.get(name) {
125                    document.add_text(field, content);
126                    all_text.push_str(content);
127                    all_text.push('\n');
128                }
129            }
130
131            document.add_text(all_field, &all_text);
132            writer
133                .add_document(document)
134                .map_err(|e| MdqlError::General(format!("Tantivy add error: {}", e)))?;
135        }
136
137        writer
138            .commit()
139            .map_err(|e| MdqlError::General(format!("Tantivy commit error: {}", e)))?;
140
141        Ok(TableSearcher {
142            index,
143            schema,
144            path_field,
145            section_fields,
146        })
147    }
148
149    /// Search for a term across all sections (or a specific section).
150    /// Returns matching file paths.
151    pub fn search(&self, query_str: &str, field: Option<&str>) -> Result<Vec<String>> {
152        let reader = self
153            .index
154            .reader_builder()
155            .reload_policy(ReloadPolicy::OnCommitWithDelay)
156            .try_into()
157            .map_err(|e| MdqlError::General(format!("Tantivy reader error: {}", e)))?;
158
159        let searcher = reader.searcher();
160
161        // Determine which fields to search
162        let search_fields: Vec<tantivy::schema::Field> = if let Some(field_name) = field {
163            if let Some(&f) = self.section_fields.get(field_name) {
164                vec![f]
165            } else {
166                return Ok(Vec::new());
167            }
168        } else {
169            // Search the combined _all field
170            let all_field = self.schema.get_field("_all")
171                .map_err(|e| MdqlError::General(format!("Missing _all field: {}", e)))?;
172            vec![all_field]
173        };
174
175        let parser = QueryParser::for_index(&self.index, search_fields);
176        let query = parser
177            .parse_query(query_str)
178            .map_err(|e| MdqlError::General(format!("Tantivy parse error: {}", e)))?;
179
180        let top_docs = searcher
181            .search(&query, &TopDocs::with_limit(10000))
182            .map_err(|e| MdqlError::General(format!("Tantivy search error: {}", e)))?;
183
184        let mut paths = Vec::new();
185        for (_score, doc_address) in top_docs {
186            let doc: tantivy::TantivyDocument = searcher.doc(doc_address)
187                .map_err(|e| MdqlError::General(format!("Tantivy doc error: {}", e)))?;
188            if let Some(path_value) = doc.get_first(self.path_field) {
189                if let Some(text) = path_value.as_str() {
190                    paths.push(text.to_string());
191                }
192            }
193        }
194
195        Ok(paths)
196    }
197
198    /// Rebuild the index from fresh rows.
199    pub fn rebuild(&mut self, rows: &[Row]) -> Result<()> {
200        let mut writer: IndexWriter = self
201            .index
202            .writer(50_000_000)
203            .map_err(|e| MdqlError::General(format!("Tantivy writer error: {}", e)))?;
204
205        writer
206            .delete_all_documents()
207            .map_err(|e| MdqlError::General(format!("Tantivy delete error: {}", e)))?;
208
209        let all_field = self.schema.get_field("_all")
210            .map_err(|e| MdqlError::General(format!("Missing _all field: {}", e)))?;
211
212        for row in rows {
213            let path = match row.get("path") {
214                Some(Value::String(p)) => p.clone(),
215                _ => continue,
216            };
217
218            let mut document = doc!(self.path_field => path);
219            let mut all_text = String::new();
220
221            for (name, &field) in &self.section_fields {
222                if let Some(Value::String(content)) = row.get(name) {
223                    document.add_text(field, content);
224                    all_text.push_str(content);
225                    all_text.push('\n');
226                }
227            }
228
229            document.add_text(all_field, &all_text);
230            writer
231                .add_document(document)
232                .map_err(|e| MdqlError::General(format!("Tantivy add error: {}", e)))?;
233        }
234
235        writer
236            .commit()
237            .map_err(|e| MdqlError::General(format!("Tantivy commit error: {}", e)))?;
238
239        Ok(())
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Row, Value};

    /// Wraps a literal in the string variant rows use for text columns.
    fn text(s: &'static str) -> Value {
        Value::String(s.into())
    }

    /// Owned section names, the form `TableSearcher::build` takes.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| s.to_string()).collect()
    }

    /// Three fixture rows: two mention neural networks, one covers databases,
    /// and the last has no `Details` section at all.
    fn test_rows() -> Vec<Row> {
        let machine_learning = Row::from([
            ("path".into(), text("a.md")),
            (
                "Summary".into(),
                text("This is about machine learning and neural networks"),
            ),
            (
                "Details".into(),
                text("Deep dive into backpropagation algorithms"),
            ),
        ]);
        let databases = Row::from([
            ("path".into(), text("b.md")),
            ("Summary".into(), text("A guide to database optimization")),
            (
                "Details".into(),
                text("Index tuning and query planning for PostgreSQL"),
            ),
        ]);
        let architectures = Row::from([
            ("path".into(), text("c.md")),
            (
                "Summary".into(),
                text("Introduction to neural network architectures"),
            ),
        ]);

        vec![machine_learning, databases, architectures]
    }

    /// Searcher over the fixture rows with both sections indexed.
    fn fixture_searcher() -> TableSearcher {
        TableSearcher::build(&test_rows(), &names(&["Summary", "Details"])).unwrap()
    }

    #[test]
    fn test_search_all_sections() {
        let hits = fixture_searcher().search("neural", None).unwrap();

        assert_eq!(hits.len(), 2);
        for expected in ["a.md", "c.md"] {
            assert!(hits.iter().any(|path| path == expected));
        }
    }

    #[test]
    fn test_search_specific_section() {
        let hits = fixture_searcher()
            .search("backpropagation", Some("Details"))
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0], "a.md");
    }

    #[test]
    fn test_search_no_results() {
        let hits = fixture_searcher().search("quantum", None).unwrap();

        assert!(hits.is_empty());
    }

    #[test]
    fn test_rebuild() {
        let mut searcher = TableSearcher::build(&test_rows(), &names(&["Summary"])).unwrap();

        // Swap the fixture rows for a single unrelated document.
        let replacement = vec![Row::from([
            ("path".into(), text("d.md")),
            ("Summary".into(), text("Quantum computing basics")),
        ])];
        searcher.rebuild(&replacement).unwrap();

        // The new document is searchable...
        let fresh = searcher.search("quantum", None).unwrap();
        assert_eq!(fresh.len(), 1);
        assert_eq!(fresh[0], "d.md");

        // ...and nothing from the original rows survives.
        let stale = searcher.search("neural", None).unwrap();
        assert!(stale.is_empty());
    }
}