pocket_cli/search/mod.rs

use crate::models::{Entry, SearchAlgorithm};
use crate::storage::StorageManager;
use anyhow::{Result, anyhow};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use rust_stemmers::{Algorithm, Stemmer};
use regex::Regex;

/// Result of a search operation
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The entry that matched
    pub entry: Entry,

    /// The content of the entry
    pub content: String,

    /// The similarity score (0.0 to 1.0)
    pub score: f64,

    /// Backpack name (None for the general pocket)
    pub backpack: Option<String>,

    /// Highlighted snippets
    pub highlights: Vec<String>,
}

/// Search engine for finding entries
pub struct SearchEngine {
    storage: StorageManager,
    index: Arc<RwLock<SearchIndex>>,
    stemmer: Stemmer,
    stopwords: HashSet<String>,
}

/// Index structure to speed up searches
struct SearchIndex {
    /// Maps terms to documents
    term_docs: HashMap<String, HashMap<String, f64>>,

    /// Maps document IDs to their backpack
    doc_backpack: HashMap<String, Option<String>>,

    /// Document frequencies (for IDF calculation)
    doc_frequencies: HashMap<String, usize>,

    /// Total document count (for IDF calculation)
    total_docs: usize,

    /// Term frequency in each document (for TF calculation)
    term_frequencies: HashMap<String, HashMap<String, usize>>,

    /// Average document length (for BM25)
    average_doc_length: f64,

    /// Document lengths (for BM25)
    doc_lengths: HashMap<String, usize>,

    /// Last time the index was updated
    last_updated: Instant,

    /// Whether the index needs a rebuild
    needs_rebuild: bool,
}

impl SearchEngine {
    /// Create a new search engine
    pub fn new(storage: StorageManager) -> Self {
        let index = Arc::new(RwLock::new(SearchIndex::new()));
        let stemmer = Stemmer::create(Algorithm::English);

        // Common English stopwords
        let stopwords: HashSet<String> = vec![
            "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "with",
            "by", "about", "as", "of", "from", "is", "are", "was", "were", "be", "been",
            "being", "have", "has", "had", "do", "does", "did", "will", "would", "shall",
            "should", "can", "could", "may", "might", "must", "this", "that", "these",
            "those", "i", "you", "he", "she", "it", "we", "they", "their", "my", "your",
            "his", "her", "its", "our", "not"
        ].into_iter().map(|s| s.to_string()).collect();

        let engine = Self {
            storage,
            index,
            stemmer,
            stopwords,
        };

        // Initialize the index in the background
        let index_clone = engine.index.clone();
        let storage_clone = engine.storage.clone();
        std::thread::spawn(move || {
            let _ = SearchIndex::build(index_clone, storage_clone);
        });

        engine
    }

    /// Search for entries matching a query
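    ///
    /// Results are sorted by descending score and truncated to `limit`.
    ///
    /// # Example
    ///
    /// A sketch of typical usage (assumes a configured `StorageManager`):
    ///
    /// ```ignore
    /// let engine = SearchEngine::new(storage);
    /// let results = engine.search("rust lifetimes", 10, None, SearchAlgorithm::Semantic)?;
    /// for result in &results {
    ///     println!("{:.2}  {}", result.score, result.entry.id);
    /// }
    /// ```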
    pub fn search(&self, query: &str, limit: usize, backpack: Option<&str>, algorithm: SearchAlgorithm) -> Result<Vec<SearchResult>> {
        // Check if we need to rebuild the index
        {
            let index = self.index.read().map_err(|_| anyhow!("Failed to acquire read lock on search index"))?;
            if index.needs_rebuild || index.last_updated.elapsed() > Duration::from_secs(300) {  // Rebuild every 5 minutes
                // Release the read lock before acquiring the write lock
                drop(index);

                // Rebuild the index
                SearchIndex::build(self.index.clone(), self.storage.clone())?;
            }
        }

        // Tokenize and process the query
        let processed_query = self.preprocess_text(query);

        let index = self.index.read().map_err(|_| anyhow!("Failed to acquire read lock on search index"))?;

        // Perform the search based on algorithm
        let mut results = match algorithm {
            SearchAlgorithm::Semantic => {
                // Use BM25 ranking for semantic search
                self.bm25_search(&processed_query, &index, backpack)?
            },
            SearchAlgorithm::Literal => {
                // Use fuzzy matching for literal search
                self.fuzzy_search(query, &index, backpack)?
            }
        };

        // Sort by score (highest first)
        results.sort_by(|a, b| {
            b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
        });

        // Limit the number of results (no-op if there are fewer than `limit`)
        results.truncate(limit);

        // Generate highlights
        results = self.generate_highlights(results, query)?;

        Ok(results)
    }

    /// Perform BM25 search
    fn bm25_search(&self, query: &str, index: &SearchIndex, backpack_filter: Option<&str>) -> Result<Vec<SearchResult>> {
        let query_terms = self.tokenize(query);
        let mut scores: HashMap<String, f64> = HashMap::new();

        // BM25 parameters
        let k1 = 1.2; // Controls term frequency saturation
        let b = 0.75; // Controls document length normalization
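
        // The scoring below follows the standard BM25 formulation:
        //
        //   score(t, d) = IDF(t) * tf(t, d) * (k1 + 1)
        //                 -----------------------------------------------
        //                 tf(t, d) + k1 * (1 - b + b * |d| / avg_doc_len)
        //
        //   IDF(t) = ln(1 + (N - df(t) + 0.5) / (df(t) + 0.5))
        //
        // where N is the total number of indexed documents, df(t) is the
        // number of documents containing t, and |d| is the document length
        // in tokens.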
        for term in query_terms {
            if self.stopwords.contains(&term) {
                continue;
            }

            let stemmed_term = self.stemmer.stem(&term).to_string();

            // Skip terms that are not in the index
            if !index.doc_frequencies.contains_key(&stemmed_term) {
                continue;
            }

            // Calculate IDF (inverse document frequency)
            let df = *index.doc_frequencies.get(&stemmed_term).unwrap_or(&1);
            let idf = (index.total_docs as f64 - df as f64 + 0.5) / (df as f64 + 0.5);
            let idf = (1.0 + idf).ln();

            // For each document containing this term
            if let Some(term_docs) = index.term_docs.get(&stemmed_term) {
                for doc_id in term_docs.keys() {
                    // Skip documents outside the requested backpack (including
                    // documents with no backpack record)
                    if let Some(filter) = backpack_filter {
                        match index.doc_backpack.get(doc_id) {
                            Some(Some(doc_backpack)) if doc_backpack == filter => {}
                            _ => continue,
                        }
                    }

                    // Get term frequency
                    let tf = *index.term_frequencies
                        .get(doc_id)
                        .and_then(|terms| terms.get(&stemmed_term))
                        .unwrap_or(&0);

                    // Get document length
                    let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1);

                    // BM25 formula
                    let numerator = tf as f64 * (k1 + 1.0);
                    let denominator = tf as f64 + k1 * (1.0 - b + b * doc_len as f64 / index.average_doc_length);
                    let score = idf * numerator / denominator;

                    *scores.entry(doc_id.clone()).or_insert(0.0) += score;
                }
            }
        }

        // Convert scores to results
        let mut results = Vec::new();
        for (doc_id, score) in scores {
            // Normalize the score to the 0-1 range (the divisor is a heuristic cap)
            let normalized_score = (score / 10.0).min(1.0);

            if normalized_score > 0.1 {  // Minimum threshold
                // Get the backpack for this document
                let backpack = index.doc_backpack.get(&doc_id).cloned().unwrap_or(None);

                // Load the entry and its content
                if let Ok((entry, content)) = self.storage.load_entry(&doc_id, backpack.as_deref()) {
                    results.push(SearchResult {
                        entry,
                        content,
                        score: normalized_score,
                        backpack,
                        highlights: Vec::new(),
                    });
                }
            }
        }

        Ok(results)
    }

    /// Perform fuzzy search using n-gram matching
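    ///
    /// Note: this path does not use the inverted index (`_index` is unused);
    /// it loads and scores every candidate entry, so its cost grows linearly
    /// with the number of stored entries.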
    fn fuzzy_search(&self, query: &str, _index: &SearchIndex, backpack_filter: Option<&str>) -> Result<Vec<SearchResult>> {
        // Collect all entries that might match
        let mut entries: Vec<(Entry, String, Option<String>)> = Vec::new();

        // If a backpack is specified, only search in that backpack
        if let Some(backpack) = backpack_filter {
            let backpack_entries = self.storage.list_entries(Some(backpack))?;
            for entry in backpack_entries {
                // Skip entries whose content cannot be loaded
                if let Ok((entry, content)) = self.storage.load_entry(&entry.id, Some(backpack)) {
                    entries.push((entry, content, Some(backpack.to_string())));
                }
            }
        } else {
            // Get entries from all backpacks
            let backpacks = self.storage.list_backpacks()?;
            for backpack in backpacks {
                let backpack_entries = self.storage.list_entries(Some(&backpack.name))?;
                for entry in backpack_entries {
                    if let Ok((entry, content)) = self.storage.load_entry(&entry.id, Some(&backpack.name)) {
                        entries.push((entry, content, Some(backpack.name.clone())));
                    }
                }
            }

            // Also get entries from the general pocket
            let general_entries = self.storage.list_entries(None)?;
            for entry in general_entries {
                if let Ok((entry, content)) = self.storage.load_entry(&entry.id, None) {
                    entries.push((entry, content, None));
                }
            }
        }

        // Score each entry by fuzzy similarity to the query
        let mut results = Vec::new();

        for (entry, content, backpack) in entries {
            let score = self.calculate_fuzzy_similarity(query, &content);

            if score > 0.2 {  // Higher threshold than BM25 to limit noise
                results.push(SearchResult {
                    entry,
                    content,
                    score,
                    backpack,
                    highlights: Vec::new(),
                });
            }
        }

        Ok(results)
    }

    /// Calculate fuzzy similarity between query and content using n-grams
    fn calculate_fuzzy_similarity(&self, query: &str, content: &str) -> f64 {
        if query.is_empty() || content.is_empty() {
            return 0.0;
        }

        // Generate n-grams for query and content (trigrams are common for fuzzy matching)
        let query_ngrams = self.generate_ngrams(query, 3);
        let content_ngrams = self.generate_ngrams(content, 3);

        // Calculate the Jaccard similarity coefficient
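        // J(A, B) = |A ∩ B| / |A ∪ B|, which is always in [0, 1]. For example,
        // "hello" {hel, ell, llo} vs "help" {hel, elp} share one trigram out
        // of four distinct trigrams, giving J = 1/4 = 0.25.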
        let intersection: HashSet<_> = query_ngrams.intersection(&content_ngrams).collect();
        let union: HashSet<_> = query_ngrams.union(&content_ngrams).collect();

        if union.is_empty() {
            return 0.0;
        }

        intersection.len() as f64 / union.len() as f64
    }

    /// Generate n-grams from text
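    ///
    /// For example, with `n = 3`, `"hello"` yields the character trigrams
    /// `{"hel", "ell", "llo"}`.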
    fn generate_ngrams(&self, text: &str, n: usize) -> HashSet<String> {
        let text = text.to_lowercase();
        let chars: Vec<char> = text.chars().collect();

        let mut ngrams = HashSet::new();
        if chars.len() < n {
            ngrams.insert(text);
            return ngrams;
        }

        // Slide a window of n characters over the text (counting characters,
        // not bytes, so multi-byte UTF-8 input is handled correctly)
        for window in chars.windows(n) {
            ngrams.insert(window.iter().collect());
        }

        ngrams
    }

    /// Tokenize text into terms
    fn tokenize(&self, text: &str) -> Vec<String> {
        // Regex to extract words
        let word_regex = Regex::new(r"\b[\w']+\b").unwrap();

        word_regex.find_iter(text.to_lowercase().as_str())
            .map(|m| m.as_str().to_string())
            .collect()
    }

    /// Preprocess text for indexing/searching
    fn preprocess_text(&self, text: &str) -> String {
        // Tokenize
        let tokens = self.tokenize(text);

        // Filter out stopwords and apply stemming
        tokens.iter()
            .filter(|token| !self.stopwords.contains(*token))
            .map(|token| self.stemmer.stem(token).to_string())
            .collect::<Vec<String>>()
            .join(" ")
    }

    /// Generate meaningful highlights for search results
    fn generate_highlights(&self, results: Vec<SearchResult>, query: &str) -> Result<Vec<SearchResult>> {
        let mut highlighted_results = Vec::new();

        // Create a regex for finding query terms (with word boundaries)
        let query_terms: Vec<&str> = query.split_whitespace().collect();
        if query_terms.is_empty() {
            // Nothing to highlight (an empty pattern would match everywhere)
            return Ok(results);
        }

        let regex_pattern = query_terms.iter()
            .map(|term| format!(r"\b{}\b", regex::escape(term)))
            .collect::<Vec<String>>()
            .join("|");

        let term_regex = Regex::new(&regex_pattern)
            .map_err(|e| anyhow!("Failed to create regex: {}", e))?;

        for result in results {
            let mut highlights = Vec::new();
            let content = &result.content;

            // Find the best context for each match
            let mut matches = term_regex.find_iter(content).peekable();

            if matches.peek().is_none() {
                // No exact matches; fall back to a snippet from the start of the content
                highlights.push(self.get_context_snippet(content, 0, 150));
            } else {
                // Generate snippets around each match, limiting to 3 highlights
                let mut current_pos = 0;
                let mut highlight_count = 0;

                for m in term_regex.find_iter(content) {
                    if highlight_count >= 3 {
                        break;
                    }

                    // Skip matches too close to the previous highlight
                    if m.start() < current_pos + 50 && current_pos > 0 {
                        continue;
                    }

                    // Get the surrounding context
                    let snippet = self.get_context_snippet(content, m.start(), 150);
                    highlights.push(snippet);

                    current_pos = m.end();
                    highlight_count += 1;
                }
            }

            // Create a copy with highlights
            highlighted_results.push(SearchResult {
                entry: result.entry,
                content: result.content,
                score: result.score,
                backpack: result.backpack,
                highlights,
            });
        }

        Ok(highlighted_results)
    }

    /// Get a context snippet around a byte position in the content
    fn get_context_snippet(&self, content: &str, position: usize, length: usize) -> String {
        let bytes = content.as_bytes();
        let content_len = bytes.len();

        // Start roughly half the snippet length before the match, clamped to
        // a valid byte index
        let start = position
            .saturating_sub(length / 2)
            .min(content_len.saturating_sub(1));

        // Walk back to a space so the snippet starts on a word boundary
        let mut start_pos = start;
        while start_pos > 0 && bytes[start_pos] != b' ' {
            start_pos -= 1;
        }

        // Extend to a space (or the end of content) so the snippet also ends
        // on a word boundary
        let mut end_pos = (start_pos + length).min(content_len);
        while end_pos < content_len && bytes[end_pos] != b' ' {
            end_pos += 1;
        }

        // Both positions are now an ASCII space, 0, or the content length,
        // all of which are valid UTF-8 boundaries, so the byte-range slice
        // below cannot panic on multi-byte characters (unlike the original
        // mix of character indices from `chars().nth` with byte slicing)
        let mut result = String::new();
        if start_pos > 0 {
            result.push_str("...");
        }

        result.push_str(&content[start_pos..end_pos]);

        if end_pos < content_len {
            result.push_str("...");
        }

        result
    }
}

impl SearchIndex {
    /// Create a new empty search index
    fn new() -> Self {
        Self {
            term_docs: HashMap::new(),
            doc_backpack: HashMap::new(),
            doc_frequencies: HashMap::new(),
            total_docs: 0,
            term_frequencies: HashMap::new(),
            average_doc_length: 0.0,
            doc_lengths: HashMap::new(),
            last_updated: Instant::now(),
            needs_rebuild: true,
        }
    }

    /// Tokenize content for indexing
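    ///
    /// Note: the stemmer and stopword list here mirror the ones built in
    /// `SearchEngine::new`; the two must be kept in sync so that indexing
    /// and query preprocessing agree.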
    fn tokenize_content(text: &str) -> Vec<String> {
        // Create stemmer and stopwords
        let stemmer = Stemmer::create(Algorithm::English);

        // Common English stopwords
        let stopwords: HashSet<String> = vec![
            "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "with",
            "by", "about", "as", "of", "from", "is", "are", "was", "were", "be", "been",
            "being", "have", "has", "had", "do", "does", "did", "will", "would", "shall",
            "should", "can", "could", "may", "might", "must", "this", "that", "these",
            "those", "i", "you", "he", "she", "it", "we", "they", "their", "my", "your",
            "his", "her", "its", "our", "not"
        ].into_iter().map(|s| s.to_string()).collect();

        // Tokenize the text
        let word_regex = Regex::new(r"\b[\w']+\b").unwrap();

        word_regex.find_iter(text.to_lowercase().as_str())
            .map(|m| m.as_str().to_string())
            .filter(|token| !stopwords.contains(token))
            .map(|token| stemmer.stem(&token).to_string())
            .collect()
    }

    /// Build or rebuild the search index
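    ///
    /// Note: the write lock is held for the entire rebuild, so concurrent
    /// searches block until indexing completes.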
    fn build(index: Arc<RwLock<SearchIndex>>, storage: StorageManager) -> Result<()> {
        let mut index_guard = index.write().map_err(|_| anyhow!("Failed to acquire write lock on search index"))?;

        // Clear the index
        index_guard.term_docs.clear();
        index_guard.doc_backpack.clear();
        index_guard.doc_frequencies.clear();
        index_guard.term_frequencies.clear();
        index_guard.doc_lengths.clear();

        // Get all entries from all backpacks
        let backpacks = storage.list_backpacks()?;
        let mut all_entries = Vec::new();

        // Add general entries
        for entry in storage.list_entries(None)? {
            all_entries.push((entry, None));
        }

        // Add backpack entries
        for backpack in backpacks {
            for entry in storage.list_entries(Some(&backpack.name))? {
                all_entries.push((entry, Some(backpack.name.clone())));
            }
        }

        // Process each entry
        let mut total_length = 0;
        for (entry, backpack) in all_entries {
            // Load the entry content
            let content = match storage.load_entry_content(&entry.id, backpack.as_deref()) {
                Ok(content) => content,
                Err(_) => continue, // Skip entries with missing content
            };

            // Get the summary if one is available
            let mut summary_text = String::new();
            if let Some(summary_json) = entry.get_metadata("summary") {
                if let Ok(summary) = crate::utils::SummaryMetadata::from_json(summary_json) {
                    summary_text = summary.summary;
                }
            }

            // Combine content and summary for indexing; the summary is repeated
            // so its terms carry more weight in the term frequencies
            let combined_text = if !summary_text.is_empty() {
                format!("{} {} {}", summary_text, summary_text, content)
            } else {
                content.clone()
            };

            // Tokenize and process the content
            let tokens = SearchIndex::tokenize_content(&combined_text);
            let doc_length = tokens.len();
            total_length += doc_length;

            // Store the document length
            index_guard.doc_lengths.insert(entry.id.clone(), doc_length);

            // Store the backpack info
            index_guard.doc_backpack.insert(entry.id.clone(), backpack.clone());

            // Count term occurrences and update the term-document index
            let mut term_counts = HashMap::new();
            for token in tokens {
                *term_counts.entry(token.clone()).or_insert(0) += 1;

                // Add to the term-document index
                let docs = index_guard.term_docs.entry(token).or_default();
                docs.insert(entry.id.clone(), 0.0); // Score is computed at query time
            }

            // Store term frequencies for this document
            index_guard.term_frequencies.insert(entry.id.clone(), term_counts);
        }

        // Calculate the average document length
        if !index_guard.doc_lengths.is_empty() {
            index_guard.average_doc_length = total_length as f64 / index_guard.doc_lengths.len() as f64;
        }

        // Calculate document frequencies
        let mut doc_frequencies = HashMap::new();
        for (term, docs) in &index_guard.term_docs {
            doc_frequencies.insert(term.clone(), docs.len());
        }

        // Update document frequencies
        index_guard.doc_frequencies = doc_frequencies;

        // Update the total document count
        index_guard.total_docs = index_guard.doc_lengths.len();

        // Mark the index as fresh
        index_guard.last_updated = Instant::now();
        index_guard.needs_rebuild = false;

        Ok(())
    }
}
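
// ---------------------------------------------------------------------------
// Minimal illustrative tests (a sketch, not part of the original module).
// `tokenize_content` is the only entry point with no storage dependency, so
// the assertions below are limited to stopword removal and stemming.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tokenize_content_removes_stopwords_and_stems() {
        let tokens = SearchIndex::tokenize_content("The cat is running");
        // "the" and "is" are in the stopword list and should be dropped
        assert!(!tokens.iter().any(|t| t == "the" || t == "is"));
        // "running" stems to "run" under the English stemmer
        assert!(tokens.iter().any(|t| t == "run"));
        // "cat" passes through unchanged
        assert!(tokens.iter().any(|t| t == "cat"));
    }
}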