use crate::models::{Entry, SearchAlgorithm};
use crate::storage::StorageManager;
use anyhow::{Result, anyhow};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use rust_stemmers::{Algorithm, Stemmer};
use regex::Regex;

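/// A single search hit: the matched entry, its content, a relevance score, and display highlights.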
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The entry that matched the query.
    pub entry: Entry,

    /// Full content of the matched entry.
    pub content: String,

    /// Relevance score, normalized to the range 0.0..=1.0.
    pub score: f64,

    /// Name of the backpack the entry belongs to, if any.
    pub backpack: Option<String>,

    /// Context snippets around query matches, used for display.
    pub highlights: Vec<String>,
}

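/// Full-text search over stored entries, combining a BM25 index with a trigram-based fallback.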
pub struct SearchEngine {
    storage: StorageManager,
    index: Arc<RwLock<SearchIndex>>,
    stemmer: Stemmer,
    stopwords: HashSet<String>,
}

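/// Inverted index and per-document statistics backing BM25 scoring.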
struct SearchIndex {
    /// Term -> (document id -> weight) postings map.
    term_docs: HashMap<String, HashMap<String, f64>>,

    /// Document id -> backpack name (None for the general store).
    doc_backpack: HashMap<String, Option<String>>,

    /// Term -> number of documents containing it.
    doc_frequencies: HashMap<String, usize>,

    /// Total number of indexed documents.
    total_docs: usize,

    /// Document id -> (term -> occurrence count).
    term_frequencies: HashMap<String, HashMap<String, usize>>,

    /// Average document length in tokens, used for BM25 length normalization.
    average_doc_length: f64,

    /// Document id -> length in tokens.
    doc_lengths: HashMap<String, usize>,

    /// When the index was last rebuilt.
    last_updated: Instant,

    /// Whether the index must be rebuilt before the next search.
    needs_rebuild: bool,
}

impl SearchEngine {
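    /// Creates a new search engine and starts building the index on a background thread.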
    pub fn new(storage: StorageManager) -> Self {
        let index = Arc::new(RwLock::new(SearchIndex::new()));
        let stemmer = Stemmer::create(Algorithm::English);

        let stopwords: HashSet<String> = vec![
            "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "with",
            "by", "about", "as", "of", "from", "is", "are", "was", "were", "be", "been",
            "being", "have", "has", "had", "do", "does", "did", "will", "would", "shall",
            "should", "can", "could", "may", "might", "must", "this", "that", "these",
            "those", "i", "you", "he", "she", "it", "we", "they", "their", "my", "your",
            "his", "her", "its", "our", "not"
        ].into_iter().map(|s| s.to_string()).collect();

        let engine = Self {
            storage,
            index,
            stemmer,
            stopwords,
        };

        // Build the initial index in the background so construction returns immediately.
        let index_clone = engine.index.clone();
        let storage_clone = engine.storage.clone();
        std::thread::spawn(move || {
            let _ = SearchIndex::build(index_clone, storage_clone);
        });

        engine
    }

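    /// Searches for `query`, returning up to `limit` results, optionally restricted to one
    /// backpack, using either BM25 (semantic) or trigram (literal) matching.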
    pub fn search(&self, query: &str, limit: usize, backpack: Option<&str>, algorithm: SearchAlgorithm) -> Result<Vec<SearchResult>> {
        // Rebuild the index first if it is flagged or older than five minutes.
        {
            let index = self.index.read().map_err(|_| anyhow!("Failed to acquire read lock on search index"))?;
            if index.needs_rebuild || index.last_updated.elapsed() > Duration::from_secs(300) {
                drop(index);
                SearchIndex::build(self.index.clone(), self.storage.clone())?;
            }
        }

        let processed_query = self.preprocess_text(query);

        let index = self.index.read().map_err(|_| anyhow!("Failed to acquire read lock on search index"))?;

        let mut results = match algorithm {
            SearchAlgorithm::Semantic => {
                self.bm25_search(&processed_query, &index, backpack)?
            },
            SearchAlgorithm::Literal => {
                self.fuzzy_search(query, &index, backpack)?
            }
        };

        // Sort by descending score.
        results.sort_by(|a, b| {
            b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
        });

        if results.len() > limit {
            results.truncate(limit);
        }

        results = self.generate_highlights(results, query)?;

        Ok(results)
    }

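    /// Scores documents against the preprocessed query with BM25 and loads the matching entries.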
    fn bm25_search(&self, query: &str, index: &SearchIndex, backpack_filter: Option<&str>) -> Result<Vec<SearchResult>> {
        let query_terms = self.tokenize(query);
        let mut scores: HashMap<String, f64> = HashMap::new();

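        // Standard BM25 parameters: k1 controls term-frequency saturation and
        // b controls document-length normalization.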
        let k1 = 1.2;
        let b = 0.75;

        for term in query_terms {
            if self.stopwords.contains(&term) {
                continue;
            }

            let stemmed_term = self.stemmer.stem(&term).to_string();

            if !index.doc_frequencies.contains_key(&stemmed_term) {
                continue;
            }

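            // Smoothed inverse document frequency; ln(1 + x) keeps the weight non-negative.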
            let df = *index.doc_frequencies.get(&stemmed_term).unwrap_or(&1);
            let idf = (index.total_docs as f64 - df as f64 + 0.5) / (df as f64 + 0.5);
            let idf = (1.0 + idf).ln();

            if let Some(term_docs) = index.term_docs.get(&stemmed_term) {
                for (doc_id, _) in term_docs {
                    // Respect the optional backpack filter.
                    if let Some(filter) = backpack_filter {
                        if let Some(doc_backpack) = index.doc_backpack.get(doc_id) {
                            if *doc_backpack != Some(filter.to_string()) {
                                continue;
                            }
                        }
                    }

                    let tf = *index.term_frequencies
                        .get(doc_id)
                        .and_then(|terms| terms.get(&stemmed_term))
                        .unwrap_or(&0);

                    let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1);

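                    // BM25 term contribution: saturated term frequency scaled by length normalization.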
                    let numerator = tf as f64 * (k1 + 1.0);
                    let denominator = tf as f64 + k1 * (1.0 - b + b * doc_len as f64 / index.average_doc_length);
                    let score = idf * numerator / denominator;

                    *scores.entry(doc_id.clone()).or_insert(0.0) += score;
                }
            }
        }

        let mut results = Vec::new();
        for (doc_id, score) in scores {
            // Normalize scores into 0.0..=1.0 and keep only reasonably strong matches.
            let normalized_score = (score / 10.0).min(1.0);

            if normalized_score > 0.1 {
                let backpack = index.doc_backpack.get(&doc_id).cloned().unwrap_or(None);

                if let Ok((entry, content)) = self.storage.load_entry(&doc_id, backpack.as_deref()) {
                    results.push(SearchResult {
                        entry,
                        content,
                        score: normalized_score,
                        backpack,
                        highlights: Vec::new(),
                    });
                }
            }
        }

        Ok(results)
    }

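    /// Literal search: loads every candidate entry and scores it by trigram similarity to the raw query.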
    fn fuzzy_search(&self, query: &str, _index: &SearchIndex, backpack_filter: Option<&str>) -> Result<Vec<SearchResult>> {
        // Collect candidate entries, either from the requested backpack or from everything.
        let mut entries: Vec<(Entry, String, Option<String>)> = Vec::new();

        if let Some(backpack) = backpack_filter {
            let backpack_entries = self.storage.list_entries(Some(backpack))?;
            for entry in backpack_entries {
                let (entry, content) = self.storage.load_entry(&entry.id, Some(backpack))?;
                entries.push((entry, content, Some(backpack.to_string())));
            }
        } else {
            let backpacks = self.storage.list_backpacks()?;
            for backpack in backpacks {
                let backpack_entries = self.storage.list_entries(Some(&backpack.name))?;
                for entry in backpack_entries {
                    let (entry, content) = self.storage.load_entry(&entry.id, Some(&backpack.name))?;
                    entries.push((entry, content, Some(backpack.name.clone())));
                }
            }

            let general_entries = self.storage.list_entries(None)?;
            for entry in general_entries {
                let (entry, content) = self.storage.load_entry(&entry.id, None)?;
                entries.push((entry, content, None));
            }
        }

        let mut results = Vec::new();

        for (entry, content, backpack) in entries {
            let score = self.calculate_fuzzy_similarity(query, &content);

            // Keep entries whose similarity clears a minimal threshold.
            if score > 0.2 {
                results.push(SearchResult {
                    entry,
                    content,
                    score,
                    backpack,
                    highlights: Vec::new(),
                });
            }
        }

        Ok(results)
    }

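    /// Jaccard similarity between the character trigram sets of the query and the content.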
    fn calculate_fuzzy_similarity(&self, query: &str, content: &str) -> f64 {
        if query.is_empty() || content.is_empty() {
            return 0.0;
        }

        let query_ngrams = self.generate_ngrams(query, 3);
        let content_ngrams = self.generate_ngrams(content, 3);

        let intersection: HashSet<_> = query_ngrams.intersection(&content_ngrams).collect();
        let union: HashSet<_> = query_ngrams.union(&content_ngrams).collect();

        if union.is_empty() {
            return 0.0;
        }

        intersection.len() as f64 / union.len() as f64
    }

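    /// Returns the set of lowercase character n-grams of `text`; inputs shorter than `n` yield a single n-gram.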
    fn generate_ngrams(&self, text: &str, n: usize) -> HashSet<String> {
        let text = text.to_lowercase();
        let chars: Vec<char> = text.chars().collect();

        let mut ngrams = HashSet::new();
        if chars.len() < n {
            ngrams.insert(text);
            return ngrams;
        }

        // Slide a window of n characters (not bytes) so multi-byte text is handled correctly.
        for window in chars.windows(n) {
            ngrams.insert(window.iter().collect());
        }

        ngrams
    }

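    /// Splits text into lowercase word tokens (letters, digits, underscores, and apostrophes).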
    fn tokenize(&self, text: &str) -> Vec<String> {
        let word_regex = Regex::new(r"\b[\w']+\b").unwrap();

        word_regex.find_iter(text.to_lowercase().as_str())
            .map(|m| m.as_str().to_string())
            .collect()
    }

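    /// Lowercases the query, drops stopwords, and stems each remaining token.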
    fn preprocess_text(&self, text: &str) -> String {
        let tokens = self.tokenize(text);

        tokens.iter()
            .filter(|token| !self.stopwords.contains(*token))
            .map(|token| self.stemmer.stem(token).to_string())
            .collect::<Vec<String>>()
            .join(" ")
    }

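    /// Attaches up to three context snippets per result around query-term matches.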
    fn generate_highlights(&self, results: Vec<SearchResult>, query: &str) -> Result<Vec<SearchResult>> {
        let mut highlighted_results = Vec::new();

        // Build one regex that matches any of the query terms as whole words.
        let query_terms: Vec<&str> = query.split_whitespace().collect();
        let regex_pattern = query_terms.iter()
            .map(|term| format!(r"\b{}\b", regex::escape(term)))
            .collect::<Vec<String>>()
            .join("|");

        let term_regex = Regex::new(&regex_pattern)
            .map_err(|e| anyhow!("Failed to create regex: {}", e))?;

        for result in results {
            let mut highlights = Vec::new();
            let content = &result.content;

            let mut matches = term_regex.find_iter(content).peekable();

            if matches.peek().is_none() {
                // No term matches: fall back to a snippet from the start of the content.
                highlights.push(self.get_context_snippet(content, 0, 150));
            } else {
                let mut current_pos = 0;
                let mut highlight_count = 0;

                for m in term_regex.find_iter(content) {
                    if highlight_count >= 3 {
                        break;
                    }

                    // Skip matches that would overlap the previous snippet.
                    if m.start() < current_pos + 50 && current_pos > 0 {
                        continue;
                    }

                    let snippet = self.get_context_snippet(content, m.start(), 150);
                    highlights.push(snippet);

                    current_pos = m.end();
                    highlight_count += 1;
                }
            }

            highlighted_results.push(SearchResult {
                entry: result.entry,
                content: result.content,
                score: result.score,
                backpack: result.backpack,
                highlights,
            });
        }

        Ok(highlighted_results)
    }

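    /// Extracts a snippet of roughly `length` bytes around `position`, expanded outward
    /// to the nearest spaces and wrapped in ellipses when truncated.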
    fn get_context_snippet(&self, content: &str, position: usize, length: usize) -> String {
        let content_len = content.len();
        let bytes = content.as_bytes();

        // Back up roughly half the snippet length before the match position.
        let start = if position > length / 2 {
            position - length / 2
        } else {
            0
        };

        // Walk back to the previous space so the snippet starts on a word (and UTF-8) boundary.
        let mut start_pos = start;
        while start_pos > 0 && bytes[start_pos] != b' ' {
            start_pos -= 1;
        }

        let end = (start_pos + length).min(content_len);

        // Walk forward to the next space so the snippet ends on a word (and UTF-8) boundary.
        let mut end_pos = end;
        while end_pos < content_len && bytes[end_pos] != b' ' {
            end_pos += 1;
        }

        let mut result = String::new();
        if start_pos > 0 {
            result.push_str("...");
        }

        result.push_str(&content[start_pos..end_pos]);

        if end_pos < content_len {
            result.push_str("...");
        }

        result
    }
}

impl SearchIndex {
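    /// Creates an empty index that is flagged for a full rebuild.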
    fn new() -> Self {
        Self {
            term_docs: HashMap::new(),
            doc_backpack: HashMap::new(),
            doc_frequencies: HashMap::new(),
            total_docs: 0,
            term_frequencies: HashMap::new(),
            average_doc_length: 0.0,
            doc_lengths: HashMap::new(),
            last_updated: Instant::now(),
            needs_rebuild: true,
        }
    }

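    /// Tokenizes document content for indexing: lowercases, drops stopwords, and stems each token.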
    fn tokenize_content(text: &str) -> Vec<String> {
        let stemmer = Stemmer::create(Algorithm::English);

        // Mirrors the stopword list used by SearchEngine.
        let stopwords: HashSet<String> = vec![
            "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "with",
            "by", "about", "as", "of", "from", "is", "are", "was", "were", "be", "been",
            "being", "have", "has", "had", "do", "does", "did", "will", "would", "shall",
            "should", "can", "could", "may", "might", "must", "this", "that", "these",
            "those", "i", "you", "he", "she", "it", "we", "they", "their", "my", "your",
            "his", "her", "its", "our", "not"
        ].into_iter().map(|s| s.to_string()).collect();

        let word_regex = Regex::new(r"\b[\w']+\b").unwrap();

        word_regex.find_iter(text.to_lowercase().as_str())
            .map(|m| m.as_str().to_string())
            .filter(|token| !stopwords.contains(token))
            .map(|token| stemmer.stem(&token).to_string())
            .collect()
    }

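    /// Rebuilds the whole index from storage: walks every entry (general and per-backpack),
    /// recomputes term and document statistics, and records when the rebuild finished.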
    fn build(index: Arc<RwLock<SearchIndex>>, storage: StorageManager) -> Result<()> {
        let mut index_guard = index.write().map_err(|_| anyhow!("Failed to acquire write lock on search index"))?;

        // Start from a clean slate.
        index_guard.term_docs.clear();
        index_guard.doc_backpack.clear();
        index_guard.doc_frequencies.clear();
        index_guard.term_frequencies.clear();
        index_guard.doc_lengths.clear();

        // Gather every entry, from the general store and from each backpack.
        let backpacks = storage.list_backpacks()?;
        let mut all_entries = Vec::new();

        for entry in storage.list_entries(None)? {
            all_entries.push((entry, None));
        }

        for backpack in backpacks {
            for entry in storage.list_entries(Some(&backpack.name))? {
                all_entries.push((entry, Some(backpack.name.clone())));
            }
        }

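        // Tokenize each entry and accumulate per-document and per-term statistics.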
        let mut total_length = 0;
        for (entry, backpack) in all_entries {
            let content = match storage.load_entry_content(&entry.id, backpack.as_deref()) {
                Ok(content) => content,
                Err(_) => continue,
            };

            // Pull the stored summary out of the entry metadata, if present.
            let mut summary_text = String::new();
            if let Some(summary_json) = entry.get_metadata("summary") {
                if let Ok(summary) = crate::utils::SummaryMetadata::from_json(summary_json) {
                    summary_text = summary.summary;
                }
            }

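            // Index the summary twice alongside the content so summary terms carry extra weight.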
            let combined_text = if !summary_text.is_empty() {
                format!("{} {} {}", summary_text, summary_text, content)
            } else {
                content.clone()
            };

            let tokens = SearchIndex::tokenize_content(&combined_text);
            let doc_length = tokens.len();
            total_length += doc_length;

            index_guard.doc_lengths.insert(entry.id.clone(), doc_length);

            index_guard.doc_backpack.insert(entry.id.clone(), backpack.clone());

            let mut term_counts = HashMap::new();
            for token in tokens {
                *term_counts.entry(token.clone()).or_insert(0) += 1;

                let docs = index_guard.term_docs.entry(token.clone()).or_insert_with(HashMap::new);
                docs.insert(entry.id.clone(), 0.0);
            }

            index_guard.term_frequencies.insert(entry.id.clone(), term_counts);
        }

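        // Finalize collection-wide statistics used by BM25.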
        if !index_guard.doc_lengths.is_empty() {
            index_guard.average_doc_length = total_length as f64 / index_guard.doc_lengths.len() as f64;
        }

        let mut doc_frequencies = HashMap::new();
        for (term, docs) in &index_guard.term_docs {
            doc_frequencies.insert(term.clone(), docs.len());
        }

        index_guard.doc_frequencies = doc_frequencies;

        index_guard.total_docs = index_guard.doc_lengths.len();

        index_guard.last_updated = Instant::now();
        index_guard.needs_rebuild = false;

        Ok(())
    }
}