1use std::collections::HashMap;
4use std::path::Path;
5
6use tantivy::collector::TopDocs;
7use tantivy::query::QueryParser;
8use tantivy::schema::{Schema as TantivySchema, Value as TantivyValue, STORED, TEXT};
9use tantivy::{doc, Index, IndexWriter, ReloadPolicy};
10
11use crate::errors::{MdqlError, Result};
12use crate::model::{Row, Value};
13
14pub struct TableSearcher {
16 index: Index,
17 schema: TantivySchema,
18 path_field: tantivy::schema::Field,
19 section_fields: HashMap<String, tantivy::schema::Field>,
20}
21
22impl TableSearcher {
23 pub fn build(rows: &[Row], section_names: &[String]) -> Result<Self> {
26 let mut schema_builder = TantivySchema::builder();
27
28 let path_field = schema_builder.add_text_field("_path", STORED);
29 let mut section_fields = HashMap::new();
30 for name in section_names {
31 let field = schema_builder.add_text_field(name, TEXT | STORED);
32 section_fields.insert(name.clone(), field);
33 }
34 let all_field = schema_builder.add_text_field("_all", TEXT);
36
37 let schema = schema_builder.build();
38 let index = Index::create_in_ram(schema.clone());
39
40 let mut writer: IndexWriter = index
41 .writer(50_000_000)
42 .map_err(|e| MdqlError::General(format!("Tantivy writer error: {}", e)))?;
43
44 for row in rows {
45 let path = match row.get("path") {
46 Some(Value::String(p)) => p.clone(),
47 _ => continue,
48 };
49
50 let mut document = doc!(path_field => path);
51 let mut all_text = String::new();
52
53 for (name, &field) in §ion_fields {
54 if let Some(Value::String(content)) = row.get(name) {
55 document.add_text(field, content);
56 all_text.push_str(content);
57 all_text.push('\n');
58 }
59 }
60
61 document.add_text(all_field, &all_text);
62 writer
63 .add_document(document)
64 .map_err(|e| MdqlError::General(format!("Tantivy add error: {}", e)))?;
65 }
66
67 writer
68 .commit()
69 .map_err(|e| MdqlError::General(format!("Tantivy commit error: {}", e)))?;
70
71 Ok(TableSearcher {
72 index,
73 schema,
74 path_field,
75 section_fields,
76 })
77 }
78
79 pub fn build_on_disk(
81 rows: &[Row],
82 section_names: &[String],
83 index_dir: &Path,
84 ) -> Result<Self> {
85 std::fs::create_dir_all(index_dir)?;
86
87 let mut schema_builder = TantivySchema::builder();
88 let path_field = schema_builder.add_text_field("_path", STORED);
89 let mut section_fields = HashMap::new();
90 for name in section_names {
91 let field = schema_builder.add_text_field(name, TEXT | STORED);
92 section_fields.insert(name.clone(), field);
93 }
94 let all_field = schema_builder.add_text_field("_all", TEXT);
95
96 let schema = schema_builder.build();
97
98 let index = if Index::open_in_dir(index_dir).is_ok() {
100 std::fs::remove_dir_all(index_dir)?;
102 std::fs::create_dir_all(index_dir)?;
103 Index::create_in_dir(index_dir, schema.clone())
104 .map_err(|e| MdqlError::General(format!("Tantivy create error: {}", e)))?
105 } else {
106 Index::create_in_dir(index_dir, schema.clone())
107 .map_err(|e| MdqlError::General(format!("Tantivy create error: {}", e)))?
108 };
109
110 let mut writer: IndexWriter = index
111 .writer(50_000_000)
112 .map_err(|e| MdqlError::General(format!("Tantivy writer error: {}", e)))?;
113
114 for row in rows {
115 let path = match row.get("path") {
116 Some(Value::String(p)) => p.clone(),
117 _ => continue,
118 };
119
120 let mut document = doc!(path_field => path);
121 let mut all_text = String::new();
122
123 for (name, &field) in §ion_fields {
124 if let Some(Value::String(content)) = row.get(name) {
125 document.add_text(field, content);
126 all_text.push_str(content);
127 all_text.push('\n');
128 }
129 }
130
131 document.add_text(all_field, &all_text);
132 writer
133 .add_document(document)
134 .map_err(|e| MdqlError::General(format!("Tantivy add error: {}", e)))?;
135 }
136
137 writer
138 .commit()
139 .map_err(|e| MdqlError::General(format!("Tantivy commit error: {}", e)))?;
140
141 Ok(TableSearcher {
142 index,
143 schema,
144 path_field,
145 section_fields,
146 })
147 }
148
149 pub fn search(&self, query_str: &str, field: Option<&str>) -> Result<Vec<String>> {
152 let reader = self
153 .index
154 .reader_builder()
155 .reload_policy(ReloadPolicy::OnCommitWithDelay)
156 .try_into()
157 .map_err(|e| MdqlError::General(format!("Tantivy reader error: {}", e)))?;
158
159 let searcher = reader.searcher();
160
161 let search_fields: Vec<tantivy::schema::Field> = if let Some(field_name) = field {
163 if let Some(&f) = self.section_fields.get(field_name) {
164 vec![f]
165 } else {
166 return Ok(Vec::new());
167 }
168 } else {
169 let all_field = self.schema.get_field("_all")
171 .map_err(|e| MdqlError::General(format!("Missing _all field: {}", e)))?;
172 vec![all_field]
173 };
174
175 let parser = QueryParser::for_index(&self.index, search_fields);
176 let query = parser
177 .parse_query(query_str)
178 .map_err(|e| MdqlError::General(format!("Tantivy parse error: {}", e)))?;
179
180 let top_docs = searcher
181 .search(&query, &TopDocs::with_limit(10000))
182 .map_err(|e| MdqlError::General(format!("Tantivy search error: {}", e)))?;
183
184 let mut paths = Vec::new();
185 for (_score, doc_address) in top_docs {
186 let doc: tantivy::TantivyDocument = searcher.doc(doc_address)
187 .map_err(|e| MdqlError::General(format!("Tantivy doc error: {}", e)))?;
188 if let Some(path_value) = doc.get_first(self.path_field) {
189 if let Some(text) = path_value.as_str() {
190 paths.push(text.to_string());
191 }
192 }
193 }
194
195 Ok(paths)
196 }
197
198 pub fn rebuild(&mut self, rows: &[Row]) -> Result<()> {
200 let mut writer: IndexWriter = self
201 .index
202 .writer(50_000_000)
203 .map_err(|e| MdqlError::General(format!("Tantivy writer error: {}", e)))?;
204
205 writer
206 .delete_all_documents()
207 .map_err(|e| MdqlError::General(format!("Tantivy delete error: {}", e)))?;
208
209 let all_field = self.schema.get_field("_all")
210 .map_err(|e| MdqlError::General(format!("Missing _all field: {}", e)))?;
211
212 for row in rows {
213 let path = match row.get("path") {
214 Some(Value::String(p)) => p.clone(),
215 _ => continue,
216 };
217
218 let mut document = doc!(self.path_field => path);
219 let mut all_text = String::new();
220
221 for (name, &field) in &self.section_fields {
222 if let Some(Value::String(content)) = row.get(name) {
223 document.add_text(field, content);
224 all_text.push_str(content);
225 all_text.push('\n');
226 }
227 }
228
229 document.add_text(all_field, &all_text);
230 writer
231 .add_document(document)
232 .map_err(|e| MdqlError::General(format!("Tantivy add error: {}", e)))?;
233 }
234
235 writer
236 .commit()
237 .map_err(|e| MdqlError::General(format!("Tantivy commit error: {}", e)))?;
238
239 Ok(())
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Row, Value};

    /// Turns string literals into the owned section-name list the builders take.
    fn section_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Three rows: `a.md` and `b.md` have Summary + Details, `c.md` only Summary.
    fn test_rows() -> Vec<Row> {
        vec![
            Row::from([
                ("path".into(), Value::String("a.md".into())),
                (
                    "Summary".into(),
                    Value::String("This is about machine learning and neural networks".into()),
                ),
                (
                    "Details".into(),
                    Value::String("Deep dive into backpropagation algorithms".into()),
                ),
            ]),
            Row::from([
                ("path".into(), Value::String("b.md".into())),
                (
                    "Summary".into(),
                    Value::String("A guide to database optimization".into()),
                ),
                (
                    "Details".into(),
                    Value::String("Index tuning and query planning for PostgreSQL".into()),
                ),
            ]),
            Row::from([
                ("path".into(), Value::String("c.md".into())),
                (
                    "Summary".into(),
                    Value::String("Introduction to neural network architectures".into()),
                ),
            ]),
        ]
    }

    #[test]
    fn test_search_all_sections() {
        let sections = section_names(&["Summary", "Details"]);
        let searcher = TableSearcher::build(&test_rows(), &sections).unwrap();

        let results = searcher.search("neural", None).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.contains(&"a.md".to_string()));
        assert!(results.contains(&"c.md".to_string()));
    }

    #[test]
    fn test_search_specific_section() {
        let sections = section_names(&["Summary", "Details"]);
        let searcher = TableSearcher::build(&test_rows(), &sections).unwrap();

        let results = searcher.search("backpropagation", Some("Details")).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "a.md");
    }

    #[test]
    fn test_search_no_results() {
        let sections = section_names(&["Summary", "Details"]);
        let searcher = TableSearcher::build(&test_rows(), &sections).unwrap();

        let results = searcher.search("quantum", None).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_unknown_section_returns_empty() {
        let sections = section_names(&["Summary", "Details"]);
        let searcher = TableSearcher::build(&test_rows(), &sections).unwrap();

        // A section that was never configured is not an error, just no hits.
        let results = searcher.search("neural", Some("Missing")).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_rows_without_path_are_skipped() {
        let sections = section_names(&["Summary"]);
        let rows = vec![Row::from([(
            "Summary".into(),
            Value::String("Orphaned neural content".into()),
        )])];
        let searcher = TableSearcher::build(&rows, &sections).unwrap();

        // The only row has no "path", so nothing was indexed.
        let results = searcher.search("neural", None).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_rebuild() {
        let sections = section_names(&["Summary"]);
        let mut searcher = TableSearcher::build(&test_rows(), &sections).unwrap();

        let new_rows = vec![Row::from([
            ("path".into(), Value::String("d.md".into())),
            (
                "Summary".into(),
                Value::String("Quantum computing basics".into()),
            ),
        ])];

        searcher.rebuild(&new_rows).unwrap();
        let results = searcher.search("quantum", None).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "d.md");

        // Documents from before the rebuild must be gone.
        let old_results = searcher.search("neural", None).unwrap();
        assert!(old_results.is_empty());
    }
}