1use std::collections::{HashMap, HashSet};
10
11use super::IndexError;
12use super::db::IndexDb;
13use super::types::{IndexedNote, NoteType};
14
/// Strategy used to widen a search beyond the notes that match the query itself.
///
/// Every mode starts from the direct matches; the non-`Direct` modes add
/// related notes discovered from those matches.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SearchMode {
    /// Return only the notes that match the query directly.
    #[default]
    Direct,
    /// Also follow outgoing links and backlinks, up to `hops` steps away
    /// from a direct match.
    Neighbourhood { hops: u32 },
    /// Also include daily notes that link to a direct match.
    ///
    /// `days` is passed to the temporal expansion. NOTE(review): the
    /// expansion in this file does not consult the value yet.
    Temporal { days: u32 },
    /// Also include notes reported as co-occurring with a direct match at
    /// least `min_shared` times.
    Cooccurrence { min_shared: u32 },
    /// Run all three expansions with built-in settings (2 hops, 30 days,
    /// at least 2 shared occurrences).
    Full,
}
30
31#[derive(Debug, Clone, Default)]
33pub struct SearchQuery {
34 pub text: Option<String>,
36 pub note_type: Option<NoteType>,
38 pub path_prefix: Option<String>,
40 pub mode: SearchMode,
42 pub limit: Option<u32>,
44 pub temporal_boost: bool,
46}
47
48#[derive(Debug, Clone)]
50pub struct SearchResult {
51 pub note: IndexedNote,
53 pub score: f64,
55 pub match_source: MatchSource,
57 pub staleness: Option<f64>,
59}
60
/// Records which search step produced a [`SearchResult`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatchSource {
    /// The note matched the query filters itself.
    Direct,
    /// The note was reached by following links from a direct match;
    /// `hops` is the number of link steps taken.
    Linked { hops: u32 },
    /// The note is a daily note that links to a direct match;
    /// `daily_path` is that daily note's path.
    Temporal { daily_path: String },
    /// The note co-occurs with a direct match; `shared_dailies` is the
    /// count reported by the index (presumably shared daily notes —
    /// confirm against the index implementation).
    Cooccurrence { shared_dailies: u32 },
}
73
74pub struct SearchEngine<'a> {
76 db: &'a IndexDb,
77}
78
79impl<'a> SearchEngine<'a> {
80 pub fn new(db: &'a IndexDb) -> Self {
82 Self { db }
83 }
84
85 pub fn search(&self, query: &SearchQuery) -> Result<Vec<SearchResult>, IndexError> {
87 let direct_matches = self.find_direct_matches(query)?;
89 let direct_ids: HashSet<i64> =
90 direct_matches.iter().filter_map(|n| n.id).collect();
91
92 let mut results: Vec<SearchResult> = direct_matches
93 .into_iter()
94 .map(|note| SearchResult {
95 staleness: self.get_staleness(note.id),
96 note,
97 score: 1.0,
98 match_source: MatchSource::Direct,
99 })
100 .collect();
101
102 match query.mode {
104 SearchMode::Direct => {}
105 SearchMode::Neighbourhood { hops } => {
106 let expanded = self.expand_neighbourhood(&direct_ids, hops)?;
107 results.extend(expanded);
108 }
109 SearchMode::Temporal { days } => {
110 let expanded = self.expand_temporal(&direct_ids, days)?;
111 results.extend(expanded);
112 }
113 SearchMode::Cooccurrence { min_shared } => {
114 let expanded = self.expand_cooccurrence(&direct_ids, min_shared)?;
115 results.extend(expanded);
116 }
117 SearchMode::Full => {
118 let neighbourhood = self.expand_neighbourhood(&direct_ids, 2)?;
120 let temporal = self.expand_temporal(&direct_ids, 30)?;
121 let cooccurrence = self.expand_cooccurrence(&direct_ids, 2)?;
122 results.extend(neighbourhood);
123 results.extend(temporal);
124 results.extend(cooccurrence);
125 }
126 }
127
128 if query.temporal_boost {
130 for result in &mut results {
131 if let Some(staleness) = result.staleness {
132 result.score *= 1.0 + (1.0 - staleness) * 0.5;
134 }
135 }
136 }
137
138 results = self.deduplicate_results(results);
140 results.sort_by(|a, b| {
141 b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)
142 });
143
144 if let Some(limit) = query.limit {
146 results.truncate(limit as usize);
147 }
148
149 Ok(results)
150 }
151
152 fn find_direct_matches(
154 &self,
155 query: &SearchQuery,
156 ) -> Result<Vec<IndexedNote>, IndexError> {
157 let note_query = super::types::NoteQuery {
159 note_type: query.note_type,
160 path_prefix: query.path_prefix.as_ref().map(Into::into),
161 limit: query.limit,
162 ..Default::default()
163 };
164
165 let notes = self.db.query_notes(¬e_query)?;
166
167 if let Some(text) = &query.text {
169 let text_lower = text.to_lowercase();
170 Ok(notes
171 .into_iter()
172 .filter(|n| {
173 n.title.to_lowercase().contains(&text_lower)
174 || n.path.to_string_lossy().to_lowercase().contains(&text_lower)
175 })
176 .collect())
177 } else {
178 Ok(notes)
179 }
180 }
181
182 fn expand_neighbourhood(
184 &self,
185 seed_ids: &HashSet<i64>,
186 max_hops: u32,
187 ) -> Result<Vec<SearchResult>, IndexError> {
188 let mut results = Vec::new();
189 let mut visited: HashSet<i64> = seed_ids.clone();
190 let mut frontier: HashSet<i64> = seed_ids.clone();
191
192 for hop in 1..=max_hops {
193 let mut next_frontier = HashSet::new();
194
195 for ¬e_id in &frontier {
196 let outlinks = self.db.get_outgoing_links(note_id)?;
198 for link in outlinks {
199 if let Some(target_id) = link.target_id
200 && !visited.contains(&target_id)
201 {
202 visited.insert(target_id);
203 next_frontier.insert(target_id);
204
205 if let Some(note) = self.db.get_note_by_id(target_id)? {
206 results.push(SearchResult {
207 staleness: self.get_staleness(note.id),
208 note,
209 score: 0.5 / (hop as f64), match_source: MatchSource::Linked { hops: hop },
211 });
212 }
213 }
214 }
215
216 let backlinks = self.db.get_backlinks(note_id)?;
218 for link in backlinks {
219 if !visited.contains(&link.source_id) {
220 visited.insert(link.source_id);
221 next_frontier.insert(link.source_id);
222
223 if let Some(note) = self.db.get_note_by_id(link.source_id)? {
224 results.push(SearchResult {
225 staleness: self.get_staleness(note.id),
226 note,
227 score: 0.5 / (hop as f64),
228 match_source: MatchSource::Linked { hops: hop },
229 });
230 }
231 }
232 }
233 }
234
235 frontier = next_frontier;
236 if frontier.is_empty() {
237 break;
238 }
239 }
240
241 Ok(results)
242 }
243
244 fn expand_temporal(
246 &self,
247 seed_ids: &HashSet<i64>,
248 _days: u32,
249 ) -> Result<Vec<SearchResult>, IndexError> {
250 let mut results = Vec::new();
251 let mut seen_dailies: HashSet<i64> = HashSet::new();
252
253 for ¬e_id in seed_ids {
254 let backlinks = self.db.get_backlinks(note_id)?;
256 for link in backlinks {
257 if let Some(source_note) = self.db.get_note_by_id(link.source_id)?
258 && source_note.note_type == NoteType::Daily
259 && !seen_dailies.contains(&link.source_id)
260 && !seed_ids.contains(&link.source_id)
261 {
262 seen_dailies.insert(link.source_id);
263 let path = source_note.path.to_string_lossy().to_string();
264 results.push(SearchResult {
265 staleness: self.get_staleness(source_note.id),
266 note: source_note,
267 score: 0.4,
268 match_source: MatchSource::Temporal { daily_path: path },
269 });
270 }
271 }
272 }
273
274 Ok(results)
275 }
276
277 fn expand_cooccurrence(
279 &self,
280 seed_ids: &HashSet<i64>,
281 min_shared: u32,
282 ) -> Result<Vec<SearchResult>, IndexError> {
283 let mut results = Vec::new();
284 let mut seen: HashSet<i64> = seed_ids.clone();
285
286 for ¬e_id in seed_ids {
287 let cooccurrent = self.db.get_cooccurrent_notes(note_id, 10)?;
288 for (note, shared_count) in cooccurrent {
289 if let Some(id) = note.id
290 && shared_count >= min_shared as i32
291 && !seen.contains(&id)
292 {
293 seen.insert(id);
294 results.push(SearchResult {
295 staleness: self.get_staleness(note.id),
296 note,
297 score: 0.3 * (shared_count as f64 / 10.0).min(1.0),
298 match_source: MatchSource::Cooccurrence {
299 shared_dailies: shared_count as u32,
300 },
301 });
302 }
303 }
304 }
305
306 Ok(results)
307 }
308
309 fn get_staleness(&self, note_id: Option<i64>) -> Option<f64> {
311 note_id.and_then(|id| {
312 self.db
313 .get_activity_summary(id)
314 .ok()
315 .flatten()
316 .map(|s| s.staleness_score as f64)
317 })
318 }
319
320 fn deduplicate_results(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
322 let mut best: HashMap<i64, SearchResult> = HashMap::new();
323
324 for result in results {
325 if let Some(id) = result.note.id {
326 best.entry(id)
327 .and_modify(|existing| {
328 if result.score > existing.score {
329 *existing = result.clone();
330 }
331 })
332 .or_insert(result);
333 }
334 }
335
336 best.into_values().collect()
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::path::PathBuf;

    /// Builds an unsaved note (no id) with the given path, title and type.
    fn sample_note(path: &str, title: &str, note_type: NoteType) -> IndexedNote {
        IndexedNote {
            id: None,
            path: PathBuf::from(path),
            note_type,
            title: title.to_owned(),
            created: Some(Utc::now()),
            modified: Utc::now(),
            frontmatter_json: None,
            content_hash: format!("hash-{path}"),
        }
    }

    /// Opens an in-memory index and inserts one note per `(path, title, type)`.
    fn db_with_notes(notes: &[(&str, &str, NoteType)]) -> IndexDb {
        let db = IndexDb::open_in_memory().unwrap();
        for &(path, title, note_type) in notes {
            db.insert_note(&sample_note(path, title, note_type)).unwrap();
        }
        db
    }

    /// A text query matches titles case-insensitively and reports every hit
    /// as a direct match.
    #[test]
    fn test_direct_search() {
        let db = db_with_notes(&[
            ("tasks/task1.md", "Fix bug in parser", NoteType::Task),
            ("tasks/task2.md", "Write documentation", NoteType::Task),
            ("zettel/note1.md", "Parser internals", NoteType::Zettel),
        ]);
        let engine = SearchEngine::new(&db);

        let query = SearchQuery {
            text: Some(String::from("parser")),
            mode: SearchMode::Direct,
            ..Default::default()
        };
        let results = engine.search(&query).unwrap();

        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.match_source == MatchSource::Direct));
    }

    /// A note-type filter returns only notes of that type.
    #[test]
    fn test_type_filter() {
        let db = db_with_notes(&[
            ("tasks/task1.md", "Task note", NoteType::Task),
            ("zettel/note1.md", "Zettel note", NoteType::Zettel),
        ]);
        let engine = SearchEngine::new(&db);

        let query = SearchQuery {
            note_type: Some(NoteType::Task),
            mode: SearchMode::Direct,
            ..Default::default()
        };
        let results = engine.search(&query).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].note.note_type, NoteType::Task);
    }
}