1use anyhow::Result;
2use rand::{distributions::Alphanumeric, Rng};
3use serde::{Deserialize, Serialize};
4use std::collections::hash_map::DefaultHasher;
5use std::collections::{HashMap, HashSet};
6use std::fs::{create_dir_all, File};
7use std::hash::{Hash, Hasher};
8use std::io::{Read, Write};
9use std::path::PathBuf;
10
11use probe_code::models::SearchResult;
12
/// Returns a lowercase hexadecimal digest of `query`, used to key the
/// on-disk session cache per query.
///
/// Equal queries always produce equal digests within one build.
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to be stable
/// across Rust releases, so cache files written by a binary built with a
/// different toolchain may no longer be found.
pub fn hash_query(query: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        Hash::hash(query, &mut state);
        state.finish()
    };
    format!("{digest:x}")
}
20
21#[derive(Debug, Serialize, Deserialize)]
23pub struct SessionCache {
24 pub session_id: String,
26 pub query_hash: String,
28 pub block_identifiers: HashSet<String>,
31}
32
33impl SessionCache {
34 pub fn new(session_id: String, query_hash: String) -> Self {
36 Self {
37 session_id,
38 query_hash,
39 block_identifiers: HashSet::new(),
40 }
41 }
42
43 pub fn load(session_id: &str, query_hash: &str) -> Result<Self> {
45 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
46 let cache_path = Self::get_cache_path(session_id, query_hash);
47
48 if !cache_path.exists() {
50 if debug_mode {
51 println!("DEBUG: Cache file does not exist at {cache_path:?}, creating new cache");
52 }
53 return Ok(Self::new(session_id.to_string(), query_hash.to_string()));
54 }
55
56 if debug_mode {
57 println!("DEBUG: Loading cache from {cache_path:?}");
58 }
59
60 let mut file = match File::open(&cache_path) {
62 Ok(f) => f,
63 Err(e) => {
64 if debug_mode {
65 println!("DEBUG: Error opening cache file: {e}");
66 }
67 return Ok(Self::new(session_id.to_string(), query_hash.to_string()));
68 }
69 };
70
71 let mut contents = String::new();
72 if let Err(e) = file.read_to_string(&mut contents) {
73 if debug_mode {
74 println!("DEBUG: Error reading cache file: {e}");
75 }
76 return Ok(Self::new(session_id.to_string(), query_hash.to_string()));
77 }
78
79 match serde_json::from_str(&contents) {
81 Ok(cache) => {
82 let cache: SessionCache = cache;
83 if debug_mode {
84 println!(
85 "DEBUG: Successfully loaded cache with {} entries",
86 cache.block_identifiers.len()
87 );
88 }
89 Ok(cache)
90 }
91 Err(e) => {
92 if debug_mode {
93 println!("DEBUG: Error parsing cache JSON: {e}");
94 }
95 Ok(Self::new(session_id.to_string(), query_hash.to_string()))
96 }
97 }
98 }
99
100 pub fn save(&self) -> Result<()> {
102 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
103 let cache_path = Self::get_cache_path(&self.session_id, &self.query_hash);
104
105 if debug_mode {
106 println!(
107 "DEBUG: Saving cache with {} entries to {:?}",
108 self.block_identifiers.len(),
109 cache_path
110 );
111 }
112
113 if let Some(parent) = cache_path.parent() {
115 if let Err(e) = create_dir_all(parent) {
116 if debug_mode {
117 println!("DEBUG: Error creating cache directory: {e}");
118 }
119 return Err(e.into());
120 }
121 }
122
123 let json = match serde_json::to_string_pretty(self) {
125 Ok(j) => j,
126 Err(e) => {
127 if debug_mode {
128 println!("DEBUG: Error serializing cache to JSON: {e}");
129 }
130 return Err(e.into());
131 }
132 };
133
134 match File::create(&cache_path) {
136 Ok(mut file) => {
137 if let Err(e) = file.write_all(json.as_bytes()) {
138 if debug_mode {
139 println!("DEBUG: Error writing to cache file: {e}");
140 }
141 return Err(e.into());
142 }
143 }
144 Err(e) => {
145 if debug_mode {
146 println!("DEBUG: Error creating cache file: {e}");
147 }
148 return Err(e.into());
149 }
150 }
151
152 if debug_mode {
153 println!("DEBUG: Successfully saved cache to disk");
154 }
155
156 Ok(())
157 }
158
159 pub fn is_cached(&self, block_id: &str) -> bool {
161 self.block_identifiers.contains(block_id)
162 }
163
164 pub fn add_to_cache(&mut self, block_id: String) {
166 self.block_identifiers.insert(block_id);
167 }
168
169 pub fn get_cache_path(session_id: &str, query_hash: &str) -> PathBuf {
171 let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
172 home_dir
173 .join(".cache")
174 .join("probe")
175 .join("sessions")
176 .join(format!("{session_id}_{query_hash}.json"))
177 }
178}
/// Normalizes a path string for use in cache keys by removing a single
/// leading `"./"`; any other input is returned unchanged as an owned string.
fn normalize_path(path: &str) -> String {
    path.strip_prefix("./").unwrap_or(path).to_owned()
}
191
192pub fn generate_cache_key(result: &SearchResult) -> String {
195 let normalized_path = normalize_path(&result.file);
196 format!("{normalized_path}:{}-{}", result.lines.0, result.lines.1)
197}
198
199pub fn filter_results_with_cache(
201 results: &[SearchResult],
202 session_id: &str,
203 query: &str,
204) -> Result<(Vec<SearchResult>, usize)> {
205 let query_hash = hash_query(query);
206 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
207
208 let cache_path = SessionCache::get_cache_path(session_id, &query_hash);
210 let is_new_session = !cache_path.exists();
211
212 if is_new_session {
214 if debug_mode {
215 println!("DEBUG: New session, not filtering results");
216 }
217 return Ok((results.to_vec(), 0));
219 }
220
221 let cache = SessionCache::load(session_id, &query_hash)?;
223
224 if cache.block_identifiers.is_empty() {
226 if debug_mode {
227 println!("DEBUG: Cache is empty, not filtering results");
228 }
229 return Ok((results.to_vec(), 0));
230 }
231
232 if debug_mode {
233 println!(
234 "DEBUG: Filtering {} results against {} cached blocks",
235 results.len(),
236 cache.block_identifiers.len()
237 );
238 }
239
240 let mut skipped_count = 0;
242
243 let filtered_results: Vec<SearchResult> = results
245 .iter()
246 .filter(|result| {
247 let cache_key = generate_cache_key(result);
248 let is_cached = cache.is_cached(&cache_key);
249
250 if is_cached {
251 if debug_mode && skipped_count < 5 {
252 println!("DEBUG: Skipping cached block: {cache_key}");
253 }
254 skipped_count += 1;
255 false
256 } else {
257 true
258 }
259 })
260 .cloned()
261 .collect();
262
263 if debug_mode {
264 println!(
265 "DEBUG: Filtered out {} cached blocks, returning {} results",
266 skipped_count,
267 filtered_results.len()
268 );
269 }
270
271 Ok((filtered_results, skipped_count))
272}
273
274pub fn filter_matched_lines_with_cache(
277 file_term_map: &mut HashMap<PathBuf, HashMap<usize, HashSet<usize>>>,
278 session_id: &str,
279 query: &str,
280) -> Result<usize> {
281 let query_hash = hash_query(query);
282 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
283
284 let cache_path = SessionCache::get_cache_path(session_id, &query_hash);
286 let is_new_session = !cache_path.exists();
287
288 if is_new_session {
290 if debug_mode {
291 println!("DEBUG: New session, not filtering matched lines");
292 }
293 return Ok(0);
294 }
295
296 let cache = SessionCache::load(session_id, &query_hash)?;
298
299 if cache.block_identifiers.is_empty() {
301 if debug_mode {
302 println!("DEBUG: Cache is empty, not filtering matched lines");
303 }
304 return Ok(0);
305 }
306
307 if debug_mode {
308 println!(
309 "DEBUG: Early filtering of matched lines against {} cached blocks",
310 cache.block_identifiers.len()
311 );
312 }
313
314 let mut skipped_count = 0;
316 let mut files_to_remove = Vec::new();
317
318 for (file_path, term_map) in file_term_map.iter_mut() {
320 if term_map.is_empty() {
321 continue;
322 }
323
324 let mut all_lines = HashSet::new();
326 for lineset in term_map.values() {
327 all_lines.extend(lineset.iter());
328 }
329
330 if debug_mode {
331 println!(
332 "DEBUG: File {:?} has {} matched lines before filtering",
333 file_path,
334 all_lines.len()
335 );
336 }
337
338 let mut lines_to_remove = HashSet::new();
340 for &line_num in &all_lines {
341 let path_str = file_path.to_string_lossy();
344 let normalized_path = normalize_path(&path_str);
345 let line_cache_key = format!("{normalized_path}:{line_num}");
346
347 let is_cached = cache.block_identifiers.iter().any(|block_id| {
349 if let Some(colon_pos) = block_id.find(':') {
351 if let Some(dash_pos) = block_id[colon_pos + 1..].find('-') {
352 let file_part = &block_id[..colon_pos];
353 let start_line_str = &block_id[colon_pos + 1..colon_pos + 1 + dash_pos];
354 let end_line_str = &block_id[colon_pos + 1 + dash_pos + 1..];
355
356 if let (Ok(start_line), Ok(end_line)) = (
357 start_line_str.parse::<usize>(),
358 end_line_str.parse::<usize>(),
359 ) {
360 let path_str = file_path.to_string_lossy();
362 let normalized_path = normalize_path(&path_str);
363 let normalized_file_part = normalize_path(file_part);
364
365 return normalized_file_part == normalized_path
366 && line_num >= start_line
367 && line_num <= end_line;
368 }
369 }
370 }
371 false
372 });
373
374 if is_cached {
375 if debug_mode && skipped_count < 5 {
376 println!("DEBUG: Skipping cached line: {line_cache_key}");
377 }
378 lines_to_remove.insert(line_num);
379 skipped_count += 1;
380 }
381 }
382
383 for term_lines in term_map.values_mut() {
385 for line in &lines_to_remove {
386 term_lines.remove(line);
387 }
388 }
389
390 term_map.retain(|_, lines| !lines.is_empty());
392
393 if term_map.is_empty() {
395 files_to_remove.push(file_path.clone());
396 }
397
398 if debug_mode {
399 let remaining_lines: HashSet<_> =
400 term_map.values().flat_map(|lines| lines.iter()).collect();
401 println!(
402 "DEBUG: File {:?} has {} matched lines after filtering",
403 file_path,
404 remaining_lines.len()
405 );
406 }
407 }
408
409 for file in files_to_remove {
411 file_term_map.remove(&file);
412 }
413
414 if debug_mode {
415 println!(
416 "DEBUG: Early filtering removed {} cached lines, {} files remain",
417 skipped_count,
418 file_term_map.len()
419 );
420 }
421
422 Ok(skipped_count)
423}
424
425pub fn add_results_to_cache(results: &[SearchResult], session_id: &str, query: &str) -> Result<()> {
427 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
428 let query_hash = hash_query(query);
429
430 let mut cache = SessionCache::load(session_id, &query_hash)?;
432
433 if debug_mode {
434 println!(
435 "DEBUG: Adding {} results to cache for session {}",
436 results.len(),
437 session_id
438 );
439 println!(
440 "DEBUG: Cache had {} entries before update",
441 cache.block_identifiers.len()
442 );
443 }
444
445 let mut new_entries = 0;
447 for result in results {
448 let cache_key = generate_cache_key(result);
449 if !cache.is_cached(&cache_key) {
450 new_entries += 1;
451 if debug_mode && new_entries <= 5 {
452 println!("DEBUG: Adding new cache entry: {cache_key}");
453 }
454 }
455 cache.add_to_cache(cache_key);
456 }
457
458 if debug_mode {
459 println!("DEBUG: Added {new_entries} new entries to cache");
460 println!(
461 "DEBUG: Cache now has {} entries",
462 cache.block_identifiers.len()
463 );
464 }
465
466 cache.save()?;
468
469 Ok(())
470}
471
472pub fn debug_print_cache(session_id: &str, query: &str) -> Result<()> {
474 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
475 if !debug_mode {
476 return Ok(());
477 }
478
479 let query_hash = hash_query(query);
480 let cache = SessionCache::load(session_id, &query_hash)?;
481
482 println!("DEBUG: Cache for session {session_id} with query hash {query_hash}");
483 println!(
484 "DEBUG: Contains {} cached blocks",
485 cache.block_identifiers.len()
486 );
487
488 for (i, block_id) in cache.block_identifiers.iter().enumerate().take(10) {
489 println!("DEBUG: Cached block {i}: {block_id}");
490 }
491
492 if cache.block_identifiers.len() > 10 {
493 let _remaining = cache.block_identifiers.len() - 10;
494 println!("DEBUG: ... and {} more", cache.block_identifiers.len() - 10);
495 }
496
497 Ok(())
498}
499
500pub fn generate_session_id() -> Result<(&'static str, bool)> {
503 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
504
505 if (0..10).next().is_some() {
507 let session_id: String = rand::thread_rng()
509 .sample_iter(&Alphanumeric)
510 .take(4)
511 .map(char::from)
512 .collect();
513
514 let session_id = session_id.to_lowercase();
516
517 if debug_mode {
518 println!("DEBUG: Generated session ID: {session_id}");
519 }
520
521 if debug_mode {
524 println!("DEBUG: Generated new session ID: {session_id}");
525 }
526 let static_id: &'static str = Box::leak(session_id.into_boxed_str());
528 return Ok((static_id, true));
529 }
530
531 Err(anyhow::anyhow!(
533 "Failed to generate a unique session ID after multiple attempts"
534 ))
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use probe_code::models::SearchResult;

    /// Builds a minimal `SearchResult` for `file` covering lines 10-20, with
    /// every optional field unset.
    fn sample_result(file: &str) -> SearchResult {
        SearchResult {
            file: file.to_string(),
            lines: (10, 20),
            node_type: "function".to_string(),
            code: String::new(),
            matched_by_filename: None,
            rank: None,
            score: None,
            tfidf_score: None,
            bm25_score: None,
            tfidf_rank: None,
            bm25_rank: None,
            new_score: None,
            hybrid2_rank: None,
            combined_score_rank: None,
            file_unique_terms: None,
            file_total_matches: None,
            file_match_rank: None,
            block_unique_terms: None,
            block_total_matches: None,
            parent_file_id: None,
            block_id: None,
            matched_keywords: None,
            tokenized_content: None,
        }
    }

    #[test]
    fn test_path_normalization() {
        let cases = [
            ("./path/to/file.rs", "path/to/file.rs"),
            ("path/to/file.rs", "path/to/file.rs"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_path(input), expected);
        }
    }

    #[test]
    fn test_query_hashing() {
        let first = hash_query("query1");
        assert_ne!(first, hash_query("query2"), "distinct queries must hash differently");
        assert_eq!(first, hash_query("query1"), "hashing must be deterministic");
    }

    #[test]
    fn test_cache_key_generation_with_different_path_formats() {
        let dotted = generate_cache_key(&sample_result("./path/to/file.rs"));
        let plain = generate_cache_key(&sample_result("path/to/file.rs"));

        assert_eq!(dotted, plain);
        assert_eq!(dotted, "path/to/file.rs:10-20");
    }

    #[test]
    fn test_session_cache_with_query_hash() {
        let session_id = "test_session";
        let (hash1, hash2) = (hash_query("query1"), hash_query("query2"));

        // Different queries in the same session must use different cache files.
        assert_ne!(
            SessionCache::get_cache_path(session_id, &hash1),
            SessionCache::get_cache_path(session_id, &hash2)
        );

        let first = SessionCache::new(session_id.to_owned(), hash1);
        let second = SessionCache::new(session_id.to_owned(), hash2);
        assert_ne!(first.query_hash, second.query_hash);
    }
}