probe_code/search/
cache.rs

1use anyhow::Result;
2use rand::{distributions::Alphanumeric, Rng};
3use serde::{Deserialize, Serialize};
4use std::collections::hash_map::DefaultHasher;
5use std::collections::{HashMap, HashSet};
6use std::fs::{create_dir_all, File};
7use std::hash::{Hash, Hasher};
8use std::io::{Read, Write};
9use std::path::PathBuf;
10
11use probe_code::models::SearchResult;
12
/// Hash a query string into a short lowercase-hex identifier.
///
/// The result is embedded in the session cache file name, so each distinct
/// query gets its own cache file. It is deterministic within a given build.
///
/// NOTE(review): `DefaultHasher`'s algorithm is explicitly not guaranteed to be
/// stable across Rust releases, so upgrading the toolchain may change these
/// hashes and orphan existing cache files.
pub fn hash_query(query: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        Hash::hash(query, &mut state);
        state.finish()
    };
    format!("{digest:x}")
}
20
21/// Structure to hold cache data for a session
22#[derive(Debug, Serialize, Deserialize)]
23pub struct SessionCache {
24    /// Session identifier
25    pub session_id: String,
26    /// Query hash for this cache
27    pub query_hash: String,
28    /// Set of block identifiers that have been seen in this session
29    /// Format: "file.rs:23-45" (file path with start-end line numbers)
30    pub block_identifiers: HashSet<String>,
31}
32
33impl SessionCache {
34    /// Create a new session cache with the given ID and query hash
35    pub fn new(session_id: String, query_hash: String) -> Self {
36        Self {
37            session_id,
38            query_hash,
39            block_identifiers: HashSet::new(),
40        }
41    }
42
43    /// Load a session cache from disk
44    pub fn load(session_id: &str, query_hash: &str) -> Result<Self> {
45        let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
46        let cache_path = Self::get_cache_path(session_id, query_hash);
47
48        // If the cache file doesn't exist, create a new empty cache
49        if !cache_path.exists() {
50            if debug_mode {
51                println!("DEBUG: Cache file does not exist at {cache_path:?}, creating new cache");
52            }
53            return Ok(Self::new(session_id.to_string(), query_hash.to_string()));
54        }
55
56        if debug_mode {
57            println!("DEBUG: Loading cache from {cache_path:?}");
58        }
59
60        // Read the cache file
61        let mut file = match File::open(&cache_path) {
62            Ok(f) => f,
63            Err(e) => {
64                if debug_mode {
65                    println!("DEBUG: Error opening cache file: {e}");
66                }
67                return Ok(Self::new(session_id.to_string(), query_hash.to_string()));
68            }
69        };
70
71        let mut contents = String::new();
72        if let Err(e) = file.read_to_string(&mut contents) {
73            if debug_mode {
74                println!("DEBUG: Error reading cache file: {e}");
75            }
76            return Ok(Self::new(session_id.to_string(), query_hash.to_string()));
77        }
78
79        // Parse the JSON
80        match serde_json::from_str(&contents) {
81            Ok(cache) => {
82                let cache: SessionCache = cache;
83                if debug_mode {
84                    println!(
85                        "DEBUG: Successfully loaded cache with {} entries",
86                        cache.block_identifiers.len()
87                    );
88                }
89                Ok(cache)
90            }
91            Err(e) => {
92                if debug_mode {
93                    println!("DEBUG: Error parsing cache JSON: {e}");
94                }
95                Ok(Self::new(session_id.to_string(), query_hash.to_string()))
96            }
97        }
98    }
99
100    /// Save the session cache to disk
101    pub fn save(&self) -> Result<()> {
102        let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
103        let cache_path = Self::get_cache_path(&self.session_id, &self.query_hash);
104
105        if debug_mode {
106            println!(
107                "DEBUG: Saving cache with {} entries to {:?}",
108                self.block_identifiers.len(),
109                cache_path
110            );
111        }
112
113        // Ensure the cache directory exists
114        if let Some(parent) = cache_path.parent() {
115            if let Err(e) = create_dir_all(parent) {
116                if debug_mode {
117                    println!("DEBUG: Error creating cache directory: {e}");
118                }
119                return Err(e.into());
120            }
121        }
122
123        // Serialize the cache to JSON
124        let json = match serde_json::to_string_pretty(self) {
125            Ok(j) => j,
126            Err(e) => {
127                if debug_mode {
128                    println!("DEBUG: Error serializing cache to JSON: {e}");
129                }
130                return Err(e.into());
131            }
132        };
133
134        // Write to the cache file
135        match File::create(&cache_path) {
136            Ok(mut file) => {
137                if let Err(e) = file.write_all(json.as_bytes()) {
138                    if debug_mode {
139                        println!("DEBUG: Error writing to cache file: {e}");
140                    }
141                    return Err(e.into());
142                }
143            }
144            Err(e) => {
145                if debug_mode {
146                    println!("DEBUG: Error creating cache file: {e}");
147                }
148                return Err(e.into());
149            }
150        }
151
152        if debug_mode {
153            println!("DEBUG: Successfully saved cache to disk");
154        }
155
156        Ok(())
157    }
158
159    /// Check if a block identifier is in the cache
160    pub fn is_cached(&self, block_id: &str) -> bool {
161        self.block_identifiers.contains(block_id)
162    }
163
164    /// Add a block identifier to the cache
165    pub fn add_to_cache(&mut self, block_id: String) {
166        self.block_identifiers.insert(block_id);
167    }
168
169    /// Get the path to the cache file
170    pub fn get_cache_path(session_id: &str, query_hash: &str) -> PathBuf {
171        let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
172        home_dir
173            .join(".cache")
174            .join("probe")
175            .join("sessions")
176            .join(format!("{session_id}_{query_hash}.json"))
177    }
178}
/// Normalize a file path so equivalent spellings yield the same cache key.
///
/// Only a single leading `"./"` is stripped; any other path is returned
/// unchanged as an owned `String`.
fn normalize_path(path: &str) -> String {
    path.strip_prefix("./").unwrap_or(path).to_owned()
}
191
192/// Generate a cache key for a search result
193/// Format: "file.rs:23-45" (file path with start-end line numbers)
194pub fn generate_cache_key(result: &SearchResult) -> String {
195    let normalized_path = normalize_path(&result.file);
196    format!("{normalized_path}:{}-{}", result.lines.0, result.lines.1)
197}
198
199/// Filter search results using the cache without adding to the cache
200pub fn filter_results_with_cache(
201    results: &[SearchResult],
202    session_id: &str,
203    query: &str,
204) -> Result<(Vec<SearchResult>, usize)> {
205    let query_hash = hash_query(query);
206    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
207
208    // Check if this is a new session by looking for the cache file
209    let cache_path = SessionCache::get_cache_path(session_id, &query_hash);
210    let is_new_session = !cache_path.exists();
211
212    // For a new session, don't skip any results
213    if is_new_session {
214        if debug_mode {
215            println!("DEBUG: New session, not filtering results");
216        }
217        // Return all results with no skipped blocks
218        return Ok((results.to_vec(), 0));
219    }
220
221    // Load the cache
222    let cache = SessionCache::load(session_id, &query_hash)?;
223
224    // If the cache is empty, don't skip any results
225    if cache.block_identifiers.is_empty() {
226        if debug_mode {
227            println!("DEBUG: Cache is empty, not filtering results");
228        }
229        return Ok((results.to_vec(), 0));
230    }
231
232    if debug_mode {
233        println!(
234            "DEBUG: Filtering {} results against {} cached blocks",
235            results.len(),
236            cache.block_identifiers.len()
237        );
238    }
239
240    // Count of skipped blocks
241    let mut skipped_count = 0;
242
243    // For existing sessions, filter the results
244    let filtered_results: Vec<SearchResult> = results
245        .iter()
246        .filter(|result| {
247            let cache_key = generate_cache_key(result);
248            let is_cached = cache.is_cached(&cache_key);
249
250            if is_cached {
251                if debug_mode && skipped_count < 5 {
252                    println!("DEBUG: Skipping cached block: {cache_key}");
253                }
254                skipped_count += 1;
255                false
256            } else {
257                true
258            }
259        })
260        .cloned()
261        .collect();
262
263    if debug_mode {
264        println!(
265            "DEBUG: Filtered out {} cached blocks, returning {} results",
266            skipped_count,
267            filtered_results.len()
268        );
269    }
270
271    Ok((filtered_results, skipped_count))
272}
273
274/// Filter matched lines using the cache to skip already cached blocks
275/// This is applied early in the search process, right after ripgrep results
276pub fn filter_matched_lines_with_cache(
277    file_term_map: &mut HashMap<PathBuf, HashMap<usize, HashSet<usize>>>,
278    session_id: &str,
279    query: &str,
280) -> Result<usize> {
281    let query_hash = hash_query(query);
282    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
283
284    // Check if this is a new session by looking for the cache file
285    let cache_path = SessionCache::get_cache_path(session_id, &query_hash);
286    let is_new_session = !cache_path.exists();
287
288    // For a new session, don't skip any lines
289    if is_new_session {
290        if debug_mode {
291            println!("DEBUG: New session, not filtering matched lines");
292        }
293        return Ok(0);
294    }
295
296    // Load the cache
297    let cache = SessionCache::load(session_id, &query_hash)?;
298
299    // If the cache is empty, don't skip any lines
300    if cache.block_identifiers.is_empty() {
301        if debug_mode {
302            println!("DEBUG: Cache is empty, not filtering matched lines");
303        }
304        return Ok(0);
305    }
306
307    if debug_mode {
308        println!(
309            "DEBUG: Early filtering of matched lines against {} cached blocks",
310            cache.block_identifiers.len()
311        );
312    }
313
314    // Count of skipped lines
315    let mut skipped_count = 0;
316    let mut files_to_remove = Vec::new();
317
318    // For each file in the map
319    for (file_path, term_map) in file_term_map.iter_mut() {
320        if term_map.is_empty() {
321            continue;
322        }
323
324        // Get all matched lines for this file
325        let mut all_lines = HashSet::new();
326        for lineset in term_map.values() {
327            all_lines.extend(lineset.iter());
328        }
329
330        if debug_mode {
331            println!(
332                "DEBUG: File {:?} has {} matched lines before filtering",
333                file_path,
334                all_lines.len()
335            );
336        }
337
338        // Check each line against the cache
339        let mut lines_to_remove = HashSet::new();
340        for &line_num in &all_lines {
341            // Create a simple cache key for this line
342            // Format: "file.rs:line_num"
343            let path_str = file_path.to_string_lossy();
344            let normalized_path = normalize_path(&path_str);
345            let line_cache_key = format!("{normalized_path}:{line_num}");
346
347            // Check if this line is part of a cached block
348            let is_cached = cache.block_identifiers.iter().any(|block_id| {
349                // Parse the block ID to get file and line range
350                if let Some(colon_pos) = block_id.find(':') {
351                    if let Some(dash_pos) = block_id[colon_pos + 1..].find('-') {
352                        let file_part = &block_id[..colon_pos];
353                        let start_line_str = &block_id[colon_pos + 1..colon_pos + 1 + dash_pos];
354                        let end_line_str = &block_id[colon_pos + 1 + dash_pos + 1..];
355
356                        if let (Ok(start_line), Ok(end_line)) = (
357                            start_line_str.parse::<usize>(),
358                            end_line_str.parse::<usize>(),
359                        ) {
360                            // Check if this line is within a cached block from the same file
361                            let path_str = file_path.to_string_lossy();
362                            let normalized_path = normalize_path(&path_str);
363                            let normalized_file_part = normalize_path(file_part);
364
365                            return normalized_file_part == normalized_path
366                                && line_num >= start_line
367                                && line_num <= end_line;
368                        }
369                    }
370                }
371                false
372            });
373
374            if is_cached {
375                if debug_mode && skipped_count < 5 {
376                    println!("DEBUG: Skipping cached line: {line_cache_key}");
377                }
378                lines_to_remove.insert(line_num);
379                skipped_count += 1;
380            }
381        }
382
383        // Remove cached lines from each term's line set
384        for term_lines in term_map.values_mut() {
385            for line in &lines_to_remove {
386                term_lines.remove(line);
387            }
388        }
389
390        // Remove terms with empty line sets
391        term_map.retain(|_, lines| !lines.is_empty());
392
393        // Mark file for removal if all terms have been removed
394        if term_map.is_empty() {
395            files_to_remove.push(file_path.clone());
396        }
397
398        if debug_mode {
399            let remaining_lines: HashSet<_> =
400                term_map.values().flat_map(|lines| lines.iter()).collect();
401            println!(
402                "DEBUG: File {:?} has {} matched lines after filtering",
403                file_path,
404                remaining_lines.len()
405            );
406        }
407    }
408
409    // Remove files with no remaining terms
410    for file in files_to_remove {
411        file_term_map.remove(&file);
412    }
413
414    if debug_mode {
415        println!(
416            "DEBUG: Early filtering removed {} cached lines, {} files remain",
417            skipped_count,
418            file_term_map.len()
419        );
420    }
421
422    Ok(skipped_count)
423}
424
425/// Add search results to the cache
426pub fn add_results_to_cache(results: &[SearchResult], session_id: &str, query: &str) -> Result<()> {
427    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
428    let query_hash = hash_query(query);
429
430    // Load or create the cache
431    let mut cache = SessionCache::load(session_id, &query_hash)?;
432
433    if debug_mode {
434        println!(
435            "DEBUG: Adding {} results to cache for session {}",
436            results.len(),
437            session_id
438        );
439        println!(
440            "DEBUG: Cache had {} entries before update",
441            cache.block_identifiers.len()
442        );
443    }
444
445    // Add all results to the cache
446    let mut new_entries = 0;
447    for result in results {
448        let cache_key = generate_cache_key(result);
449        if !cache.is_cached(&cache_key) {
450            new_entries += 1;
451            if debug_mode && new_entries <= 5 {
452                println!("DEBUG: Adding new cache entry: {cache_key}");
453            }
454        }
455        cache.add_to_cache(cache_key);
456    }
457
458    if debug_mode {
459        println!("DEBUG: Added {new_entries} new entries to cache");
460        println!(
461            "DEBUG: Cache now has {} entries",
462            cache.block_identifiers.len()
463        );
464    }
465
466    // Save the updated cache
467    cache.save()?;
468
469    Ok(())
470}
471
472/// Debug function to print cache contents (only used when DEBUG=1)
473pub fn debug_print_cache(session_id: &str, query: &str) -> Result<()> {
474    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
475    if !debug_mode {
476        return Ok(());
477    }
478
479    let query_hash = hash_query(query);
480    let cache = SessionCache::load(session_id, &query_hash)?;
481
482    println!("DEBUG: Cache for session {session_id} with query hash {query_hash}");
483    println!(
484        "DEBUG: Contains {} cached blocks",
485        cache.block_identifiers.len()
486    );
487
488    for (i, block_id) in cache.block_identifiers.iter().enumerate().take(10) {
489        println!("DEBUG: Cached block {i}: {block_id}");
490    }
491
492    if cache.block_identifiers.len() > 10 {
493        let _remaining = cache.block_identifiers.len() - 10;
494        println!("DEBUG: ... and {} more", cache.block_identifiers.len() - 10);
495    }
496
497    Ok(())
498}
499
500/// Generate a unique 4-character alphanumeric session ID
501/// Returns a tuple of (session_id, is_new) where is_new indicates if this is a newly generated ID
502pub fn generate_session_id() -> Result<(&'static str, bool)> {
503    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
504
505    // Generate a single session ID instead of looping
506    if (0..10).next().is_some() {
507        // Generate a random 4-character alphanumeric string
508        let session_id: String = rand::thread_rng()
509            .sample_iter(&Alphanumeric)
510            .take(4)
511            .map(char::from)
512            .collect();
513
514        // Convert to lowercase for consistency
515        let session_id = session_id.to_lowercase();
516
517        if debug_mode {
518            println!("DEBUG: Generated session ID: {session_id}");
519        }
520
521        // We don't check for existing cache files here since we're just generating a session ID
522        // The actual cache file will be created with both session ID and query hash
523        if debug_mode {
524            println!("DEBUG: Generated new session ID: {session_id}");
525        }
526        // Convert to a static string (this leaks memory, but it's a small amount and only happens once per session)
527        let static_id: &'static str = Box::leak(session_id.into_boxed_str());
528        return Ok((static_id, true));
529    }
530
531    // If we couldn't generate a unique ID after 10 attempts, return an error
532    Err(anyhow::anyhow!(
533        "Failed to generate a unique session ID after multiple attempts"
534    ))
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use probe_code::models::SearchResult;

    /// Minimal `SearchResult` for `file` spanning lines 10-20, with every
    /// optional field left as `None`.
    fn result_at(file: &str) -> SearchResult {
        SearchResult {
            file: file.to_string(),
            lines: (10, 20),
            node_type: "function".to_string(),
            code: String::new(),
            matched_by_filename: None,
            rank: None,
            score: None,
            tfidf_score: None,
            bm25_score: None,
            tfidf_rank: None,
            bm25_rank: None,
            new_score: None,
            hybrid2_rank: None,
            combined_score_rank: None,
            file_unique_terms: None,
            file_total_matches: None,
            file_match_rank: None,
            block_unique_terms: None,
            block_total_matches: None,
            parent_file_id: None,
            block_id: None,
            matched_keywords: None,
            tokenized_content: None,
        }
    }

    #[test]
    fn test_path_normalization() {
        // A leading "./" is dropped; other paths pass through unchanged.
        for (input, expected) in [
            ("./path/to/file.rs", "path/to/file.rs"),
            ("path/to/file.rs", "path/to/file.rs"),
        ] {
            assert_eq!(normalize_path(input), expected);
        }
    }

    #[test]
    fn test_query_hashing() {
        let first = hash_query("query1");
        // Distinct queries hash differently...
        assert_ne!(first, hash_query("query2"));
        // ...and hashing is deterministic.
        assert_eq!(first, hash_query("query1"));
    }

    #[test]
    fn test_cache_key_generation_with_different_path_formats() {
        // "./x" and "x" refer to the same file and must share a cache key.
        let dotted = generate_cache_key(&result_at("./path/to/file.rs"));
        let plain = generate_cache_key(&result_at("path/to/file.rs"));

        assert_eq!(dotted, plain);
        assert_eq!(plain, "path/to/file.rs:10-20");
    }

    #[test]
    fn test_session_cache_with_query_hash() {
        let session_id = "test_session";
        let (hash1, hash2) = (hash_query("query1"), hash_query("query2"));

        // Each query within a session gets its own cache file.
        assert_ne!(
            SessionCache::get_cache_path(session_id, &hash1),
            SessionCache::get_cache_path(session_id, &hash2)
        );

        let cache1 = SessionCache::new(session_id.to_string(), hash1);
        let cache2 = SessionCache::new(session_id.to_string(), hash2);
        assert_ne!(cache1.query_hash, cache2.query_hash);
    }
}