Skip to main content

mdvault_core/index/
search.rs

//! Contextual search beyond keyword matching.
//!
//! This module provides multi-modal search capabilities:
//! - Direct match: notes whose title or path contains the query string
//! - Graph neighbourhood: linked notes within N hops
//! - Temporal context: recent daily notes that reference the matches
//! - Co-occurrence: notes that appeared alongside the matches in daily notes
8
9use std::collections::{HashMap, HashSet};
10
11use super::IndexError;
12use super::db::IndexDb;
13use super::types::{IndexedNote, NoteType};
14
/// Search mode determining how direct matches are expanded into related notes.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SearchMode {
    /// Only return notes directly matching the query.
    #[default]
    Direct,
    /// Also include notes linked (in either direction) to a direct match,
    /// up to `hops` link steps away. `hops: 0` adds nothing.
    Neighbourhood { hops: u32 },
    /// Also include recent daily notes that reference a direct match;
    /// `days` is the recency window.
    Temporal { days: u32 },
    /// Also include notes that share at least `min_shared` daily notes with a
    /// direct match.
    Cooccurrence { min_shared: u32 },
    /// Combined expansion with fixed parameters: neighbourhood (2 hops),
    /// temporal (30 days) and co-occurrence (at least 2 shared dailies).
    Full,
}
30
31/// Search query parameters.
32#[derive(Debug, Clone, Default)]
33pub struct SearchQuery {
34    /// Text to search for (in title, path, or content).
35    pub text: Option<String>,
36    /// Filter by note type.
37    pub note_type: Option<NoteType>,
38    /// Path prefix filter.
39    pub path_prefix: Option<String>,
40    /// Search mode for result expansion.
41    pub mode: SearchMode,
42    /// Maximum results to return.
43    pub limit: Option<u32>,
44    /// Favour recently active notes.
45    pub temporal_boost: bool,
46}
47
48/// A search result with relevance information.
49#[derive(Debug, Clone)]
50pub struct SearchResult {
51    /// The matching note.
52    pub note: IndexedNote,
53    /// Relevance score (higher = more relevant).
54    pub score: f64,
55    /// How this result was found.
56    pub match_source: MatchSource,
57    /// Staleness score if available (lower = more active).
58    pub staleness: Option<f64>,
59}
60
/// Describes which search strategy surfaced a result.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MatchSource {
    /// The note itself matched the query text and filters.
    Direct,
    /// Reached by following links `hops` steps away from a direct match.
    Linked { hops: u32 },
    /// A daily note referencing a direct match, stored with its own path.
    Temporal { daily_path: String },
    /// Appears in `shared_dailies` daily notes alongside a direct match.
    Cooccurrence { shared_dailies: u32 },
}
73
74/// Search engine using the vault index.
75pub struct SearchEngine<'a> {
76    db: &'a IndexDb,
77}
78
79impl<'a> SearchEngine<'a> {
80    /// Create a new search engine.
81    pub fn new(db: &'a IndexDb) -> Self {
82        Self { db }
83    }
84
85    /// Execute a search query.
86    pub fn search(&self, query: &SearchQuery) -> Result<Vec<SearchResult>, IndexError> {
87        // Step 1: Find direct matches
88        let direct_matches = self.find_direct_matches(query)?;
89        let direct_ids: HashSet<i64> =
90            direct_matches.iter().filter_map(|n| n.id).collect();
91
92        let mut results: Vec<SearchResult> = direct_matches
93            .into_iter()
94            .map(|note| SearchResult {
95                staleness: self.get_staleness(note.id),
96                note,
97                score: 1.0,
98                match_source: MatchSource::Direct,
99            })
100            .collect();
101
102        // Step 2: Expand based on mode
103        match query.mode {
104            SearchMode::Direct => {}
105            SearchMode::Neighbourhood { hops } => {
106                let expanded = self.expand_neighbourhood(&direct_ids, hops)?;
107                results.extend(expanded);
108            }
109            SearchMode::Temporal { days } => {
110                let expanded = self.expand_temporal(&direct_ids, days)?;
111                results.extend(expanded);
112            }
113            SearchMode::Cooccurrence { min_shared } => {
114                let expanded = self.expand_cooccurrence(&direct_ids, min_shared)?;
115                results.extend(expanded);
116            }
117            SearchMode::Full => {
118                // Combine all expansion modes
119                let neighbourhood = self.expand_neighbourhood(&direct_ids, 2)?;
120                let temporal = self.expand_temporal(&direct_ids, 30)?;
121                let cooccurrence = self.expand_cooccurrence(&direct_ids, 2)?;
122                results.extend(neighbourhood);
123                results.extend(temporal);
124                results.extend(cooccurrence);
125            }
126        }
127
128        // Step 3: Apply temporal boost if requested
129        if query.temporal_boost {
130            for result in &mut results {
131                if let Some(staleness) = result.staleness {
132                    // Boost score based on freshness (1 - staleness)
133                    result.score *= 1.0 + (1.0 - staleness) * 0.5;
134                }
135            }
136        }
137
138        // Step 4: Deduplicate and sort by score
139        results = self.deduplicate_results(results);
140        results.sort_by(|a, b| {
141            b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)
142        });
143
144        // Step 5: Apply limit
145        if let Some(limit) = query.limit {
146            results.truncate(limit as usize);
147        }
148
149        Ok(results)
150    }
151
152    /// Find notes directly matching the query.
153    fn find_direct_matches(
154        &self,
155        query: &SearchQuery,
156    ) -> Result<Vec<IndexedNote>, IndexError> {
157        // Build a NoteQuery from SearchQuery
158        let note_query = super::types::NoteQuery {
159            note_type: query.note_type,
160            path_prefix: query.path_prefix.as_ref().map(Into::into),
161            limit: query.limit,
162            ..Default::default()
163        };
164
165        let notes = self.db.query_notes(&note_query)?;
166
167        // Filter by text if provided
168        if let Some(text) = &query.text {
169            let text_lower = text.to_lowercase();
170            Ok(notes
171                .into_iter()
172                .filter(|n| {
173                    n.title.to_lowercase().contains(&text_lower)
174                        || n.path.to_string_lossy().to_lowercase().contains(&text_lower)
175                })
176                .collect())
177        } else {
178            Ok(notes)
179        }
180    }
181
182    /// Expand results by following links up to N hops.
183    fn expand_neighbourhood(
184        &self,
185        seed_ids: &HashSet<i64>,
186        max_hops: u32,
187    ) -> Result<Vec<SearchResult>, IndexError> {
188        let mut results = Vec::new();
189        let mut visited: HashSet<i64> = seed_ids.clone();
190        let mut frontier: HashSet<i64> = seed_ids.clone();
191
192        for hop in 1..=max_hops {
193            let mut next_frontier = HashSet::new();
194
195            for &note_id in &frontier {
196                // Get outgoing links
197                let outlinks = self.db.get_outgoing_links(note_id)?;
198                for link in outlinks {
199                    if let Some(target_id) = link.target_id
200                        && !visited.contains(&target_id)
201                    {
202                        visited.insert(target_id);
203                        next_frontier.insert(target_id);
204
205                        if let Some(note) = self.db.get_note_by_id(target_id)? {
206                            results.push(SearchResult {
207                                staleness: self.get_staleness(note.id),
208                                note,
209                                score: 0.5 / (hop as f64), // Decay by distance
210                                match_source: MatchSource::Linked { hops: hop },
211                            });
212                        }
213                    }
214                }
215
216                // Get backlinks
217                let backlinks = self.db.get_backlinks(note_id)?;
218                for link in backlinks {
219                    if !visited.contains(&link.source_id) {
220                        visited.insert(link.source_id);
221                        next_frontier.insert(link.source_id);
222
223                        if let Some(note) = self.db.get_note_by_id(link.source_id)? {
224                            results.push(SearchResult {
225                                staleness: self.get_staleness(note.id),
226                                note,
227                                score: 0.5 / (hop as f64),
228                                match_source: MatchSource::Linked { hops: hop },
229                            });
230                        }
231                    }
232                }
233            }
234
235            frontier = next_frontier;
236            if frontier.is_empty() {
237                break;
238            }
239        }
240
241        Ok(results)
242    }
243
244    /// Expand results by finding recent dailies referencing matches.
245    fn expand_temporal(
246        &self,
247        seed_ids: &HashSet<i64>,
248        _days: u32,
249    ) -> Result<Vec<SearchResult>, IndexError> {
250        let mut results = Vec::new();
251        let mut seen_dailies: HashSet<i64> = HashSet::new();
252
253        for &note_id in seed_ids {
254            // Get backlinks to find dailies referencing this note
255            let backlinks = self.db.get_backlinks(note_id)?;
256            for link in backlinks {
257                if let Some(source_note) = self.db.get_note_by_id(link.source_id)?
258                    && source_note.note_type == NoteType::Daily
259                    && !seen_dailies.contains(&link.source_id)
260                    && !seed_ids.contains(&link.source_id)
261                {
262                    seen_dailies.insert(link.source_id);
263                    let path = source_note.path.to_string_lossy().to_string();
264                    results.push(SearchResult {
265                        staleness: self.get_staleness(source_note.id),
266                        note: source_note,
267                        score: 0.4,
268                        match_source: MatchSource::Temporal { daily_path: path },
269                    });
270                }
271            }
272        }
273
274        Ok(results)
275    }
276
277    /// Expand results by finding notes that cooccur with matches.
278    fn expand_cooccurrence(
279        &self,
280        seed_ids: &HashSet<i64>,
281        min_shared: u32,
282    ) -> Result<Vec<SearchResult>, IndexError> {
283        let mut results = Vec::new();
284        let mut seen: HashSet<i64> = seed_ids.clone();
285
286        for &note_id in seed_ids {
287            let cooccurrent = self.db.get_cooccurrent_notes(note_id, 10)?;
288            for (note, shared_count) in cooccurrent {
289                if let Some(id) = note.id
290                    && shared_count >= min_shared as i32
291                    && !seen.contains(&id)
292                {
293                    seen.insert(id);
294                    results.push(SearchResult {
295                        staleness: self.get_staleness(note.id),
296                        note,
297                        score: 0.3 * (shared_count as f64 / 10.0).min(1.0),
298                        match_source: MatchSource::Cooccurrence {
299                            shared_dailies: shared_count as u32,
300                        },
301                    });
302                }
303            }
304        }
305
306        Ok(results)
307    }
308
309    /// Get staleness score for a note.
310    fn get_staleness(&self, note_id: Option<i64>) -> Option<f64> {
311        note_id.and_then(|id| {
312            self.db
313                .get_activity_summary(id)
314                .ok()
315                .flatten()
316                .map(|s| s.staleness_score as f64)
317        })
318    }
319
320    /// Deduplicate results, keeping highest score for each note.
321    fn deduplicate_results(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
322        let mut best: HashMap<i64, SearchResult> = HashMap::new();
323
324        for result in results {
325            if let Some(id) = result.note.id {
326                best.entry(id)
327                    .and_modify(|existing| {
328                        if result.score > existing.score {
329                            *existing = result.clone();
330                        }
331                    })
332                    .or_insert(result);
333            }
334        }
335
336        best.into_values().collect()
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::path::PathBuf;

    /// Build a not-yet-inserted note (no id) with the given path, title and type.
    fn sample_note(path: &str, title: &str, note_type: NoteType) -> IndexedNote {
        IndexedNote {
            id: None,
            path: PathBuf::from(path),
            note_type,
            title: title.to_string(),
            created: Some(Utc::now()),
            modified: Utc::now(),
            frontmatter_json: None,
            content_hash: format!("hash-{}", path),
        }
    }

    /// Titles of `results`, sorted, for order-independent comparisons.
    fn titles(results: &[SearchResult]) -> Vec<String> {
        let mut titles: Vec<String> = results.iter().map(|r| r.note.title.clone()).collect();
        titles.sort();
        titles
    }

    #[test]
    fn test_direct_search() {
        let db = IndexDb::open_in_memory().unwrap();

        // Insert test notes
        db.insert_note(&sample_note(
            "tasks/task1.md",
            "Fix bug in parser",
            NoteType::Task,
        ))
        .unwrap();
        db.insert_note(&sample_note(
            "tasks/task2.md",
            "Write documentation",
            NoteType::Task,
        ))
        .unwrap();
        db.insert_note(&sample_note(
            "zettel/note1.md",
            "Parser internals",
            NoteType::Zettel,
        ))
        .unwrap();

        let engine = SearchEngine::new(&db);

        // Search for "parser" (matches both casings in the titles).
        let query = SearchQuery {
            text: Some("parser".to_string()),
            mode: SearchMode::Direct,
            ..Default::default()
        };

        let results = engine.search(&query).unwrap();
        assert_eq!(titles(&results), ["Fix bug in parser", "Parser internals"]);
        assert!(results.iter().all(|r| r.match_source == MatchSource::Direct));
        assert!(results.iter().all(|r| r.score == 1.0));
    }

    #[test]
    fn test_text_matches_path_case_insensitively() {
        let db = IndexDb::open_in_memory().unwrap();

        db.insert_note(&sample_note("projects/alpha.md", "Alpha", NoteType::Zettel))
            .unwrap();
        db.insert_note(&sample_note("zettel/beta.md", "Beta", NoteType::Zettel))
            .unwrap();

        let engine = SearchEngine::new(&db);

        let query = SearchQuery {
            text: Some("PROJECTS".to_string()),
            ..Default::default()
        };

        let results = engine.search(&query).unwrap();
        assert_eq!(titles(&results), ["Alpha"]);
    }

    #[test]
    fn test_no_match_returns_empty() {
        let db = IndexDb::open_in_memory().unwrap();

        db.insert_note(&sample_note("tasks/task1.md", "Task note", NoteType::Task))
            .unwrap();

        let engine = SearchEngine::new(&db);

        let query = SearchQuery {
            text: Some("nonexistent".to_string()),
            ..Default::default()
        };

        assert!(engine.search(&query).unwrap().is_empty());
    }

    #[test]
    fn test_limit_caps_result_count() {
        let db = IndexDb::open_in_memory().unwrap();

        for (path, title) in [
            ("tasks/a.md", "Task A"),
            ("tasks/b.md", "Task B"),
            ("tasks/c.md", "Task C"),
        ] {
            db.insert_note(&sample_note(path, title, NoteType::Task))
                .unwrap();
        }

        let engine = SearchEngine::new(&db);

        let query = SearchQuery {
            limit: Some(2),
            ..Default::default()
        };

        assert_eq!(engine.search(&query).unwrap().len(), 2);
    }

    #[test]
    fn test_type_filter() {
        let db = IndexDb::open_in_memory().unwrap();

        db.insert_note(&sample_note("tasks/task1.md", "Task note", NoteType::Task))
            .unwrap();
        db.insert_note(&sample_note("zettel/note1.md", "Zettel note", NoteType::Zettel))
            .unwrap();

        let engine = SearchEngine::new(&db);

        let query = SearchQuery {
            note_type: Some(NoteType::Task),
            mode: SearchMode::Direct,
            ..Default::default()
        };

        let results = engine.search(&query).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].note.note_type, NoteType::Task);
    }
}