Skip to main content

provable_contracts/query/
index.rs

1//! Contract index building and BM25 search.
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use provable_contracts_macros::requires;
7
8use crate::schema::{Contract, parse_contract};
9use crate::scoring;
10
11use super::persist::{self, PersistedIndex};
12use super::types::ContractEntry;
13
14/// In-memory contract index with inverted indexes for fast lookup.
15#[derive(Debug)]
16pub struct ContractIndex {
17    pub entries: Vec<ContractEntry>,
18    name_index: HashMap<String, usize>,
19    equation_index: HashMap<String, Vec<usize>>,
20    obligation_index: HashMap<String, Vec<usize>>,
21    /// Pre-computed composite scores for O(1) `--min-score` filtering.
22    score_cache: HashMap<String, f64>,
23    /// Pre-computed pagerank scores for importance-weighted ranking.
24    pagerank_cache: HashMap<String, f64>,
25    /// Average document length for BM25.
26    avg_dl: f64,
27    /// Document frequency per term.
28    df: HashMap<String, usize>,
29}
30
31impl ContractIndex {
32    /// Build an index from a directory of YAML contracts.
33    ///
34    /// Uses cached index from `.pv/contracts.idx` when fresh,
35    /// otherwise rebuilds and caches for next time.
36    pub fn from_directory(dir: &Path) -> Result<Self, Box<dyn std::error::Error>> {
37        Self::from_directory_opts(dir, false)
38    }
39
40    /// Build an index with option to force rebuild (skip cache).
41    pub fn from_directory_opts(
42        dir: &Path,
43        force_rebuild: bool,
44    ) -> Result<Self, Box<dyn std::error::Error>> {
45        // Try loading cached index first (unless force rebuild)
46        if !force_rebuild {
47            if let Some(cached) = persist::load_cached(dir) {
48                let mut index = Self::from_entries(cached.entries);
49                index.score_cache = cached.score_cache;
50                index.pagerank_cache = cached.pagerank_cache;
51                return Ok(index);
52            }
53        }
54
55        let index = Self::build_from_directory(dir)?;
56
57        // Save to cache (best-effort, don't fail on write errors)
58        let _ = persist::save_cached(
59            dir,
60            &PersistedIndex {
61                entries: index.entries.clone(),
62                score_cache: index.score_cache.clone(),
63                pagerank_cache: index.pagerank_cache.clone(),
64            },
65        );
66
67        Ok(index)
68    }
69
70    /// Build an index from a directory without cache.
71    pub fn build_from_directory(dir: &Path) -> Result<Self, Box<dyn std::error::Error>> {
72        let mut yaml_paths: Vec<_> = collect_yaml_files(dir)?;
73        yaml_paths.sort();
74
75        let mut entries = Vec::new();
76        let mut score_cache = HashMap::new();
77        for path in &yaml_paths {
78            let Ok(contract) = parse_contract(path) else {
79                continue;
80            };
81            let stem = path
82                .file_stem()
83                .and_then(|s| s.to_str())
84                .unwrap_or("unknown")
85                .to_string();
86            let path_str = path.display().to_string();
87            let score = scoring::score_contract(&contract, None, &stem);
88            score_cache.insert(stem.clone(), score.composite);
89            entries.push(build_entry(stem, path_str, &contract));
90        }
91
92        let mut index = Self::from_entries(entries);
93        index.score_cache = score_cache;
94        index.pagerank_cache = index.pagerank(20, 0.85);
95        Ok(index)
96    }
97
98    /// Build an index from pre-parsed entries.
99    #[allow(clippy::cast_precision_loss)]
100    pub fn from_entries(entries: Vec<ContractEntry>) -> Self {
101        let mut name_index = HashMap::new();
102        let mut equation_index: HashMap<String, Vec<usize>> = HashMap::new();
103        let mut obligation_index: HashMap<String, Vec<usize>> = HashMap::new();
104        let mut df: HashMap<String, usize> = HashMap::new();
105        let mut total_len = 0usize;
106
107        for (i, entry) in entries.iter().enumerate() {
108            name_index.insert(entry.stem.clone(), i);
109            for eq in &entry.equations {
110                equation_index.entry(eq.clone()).or_default().push(i);
111            }
112            for ot in &entry.obligation_types {
113                obligation_index.entry(ot.clone()).or_default().push(i);
114            }
115
116            let terms = tokenize(&entry.corpus_text);
117            total_len += terms.len();
118            let mut seen = std::collections::HashSet::new();
119            for t in &terms {
120                if seen.insert(t.clone()) {
121                    *df.entry(t.clone()).or_default() += 1;
122                }
123            }
124        }
125
126        let avg_dl = if entries.is_empty() {
127            1.0
128        } else {
129            total_len as f64 / entries.len() as f64
130        };
131
132        Self {
133            entries,
134            name_index,
135            equation_index,
136            obligation_index,
137            score_cache: HashMap::new(),
138            pagerank_cache: HashMap::new(),
139            avg_dl,
140            df,
141        }
142    }
143
144    /// Look up a contract by exact stem.
145    pub fn get_by_stem(&self, stem: &str) -> Option<&ContractEntry> {
146        self.name_index.get(stem).map(|&i| &self.entries[i])
147    }
148
149    /// Get the pre-computed composite score for a contract stem.
150    pub fn cached_score(&self, stem: &str) -> Option<f64> {
151        self.score_cache.get(stem).copied()
152    }
153
154    /// Get the pre-computed pagerank score for a contract stem.
155    pub fn cached_pagerank(&self, stem: &str) -> Option<f64> {
156        self.pagerank_cache.get(stem).copied()
157    }
158
159    /// Look up contracts by obligation type.
160    pub fn get_by_obligation(&self, ob_type: &str) -> Vec<&ContractEntry> {
161        self.obligation_index
162            .get(ob_type)
163            .map(|idxs| idxs.iter().map(|&i| &self.entries[i]).collect())
164            .unwrap_or_default()
165    }
166
167    /// Look up contracts by equation name.
168    pub fn get_by_equation(&self, eq: &str) -> Vec<&ContractEntry> {
169        self.equation_index
170            .get(eq)
171            .map(|idxs| idxs.iter().map(|&i| &self.entries[i]).collect())
172            .unwrap_or_default()
173    }
174
175    /// BM25 search across all entries. Returns (index, score) pairs sorted descending.
176    #[allow(clippy::cast_precision_loss)]
177    pub fn bm25_search(&self, query: &str) -> Vec<(usize, f64)> {
178        let query_terms = tokenize(query);
179        if query_terms.is_empty() {
180            return Vec::new();
181        }
182
183        let n = self.entries.len() as f64;
184        let k1 = 1.2;
185        let b = 0.75;
186
187        let mut scores: Vec<(usize, f64)> = self
188            .entries
189            .iter()
190            .enumerate()
191            .map(|(i, entry)| {
192                let doc_terms = tokenize(&entry.corpus_text);
193                let dl = doc_terms.len() as f64;
194
195                let tf_map = term_frequencies(&doc_terms);
196                let score: f64 = query_terms
197                    .iter()
198                    .map(|qt| {
199                        let doc_freq = self.df.get(qt).copied().unwrap_or(0) as f64;
200                        let idf = ((n - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
201                        let tf = tf_map.get(qt).copied().unwrap_or(0) as f64;
202                        idf * (tf * (k1 + 1.0)) / (tf + k1 * (1.0 - b + b * dl / self.avg_dl))
203                    })
204                    .sum();
205
206                (i, score)
207            })
208            .filter(|(_, s)| *s > 0.0)
209            .collect();
210
211        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212        scores
213    }
214
215    /// Regex search across all entries. Returns matching indices.
216    pub fn regex_search(&self, pattern: &str) -> Result<Vec<usize>, regex::Error> {
217        let re = regex::Regex::new(pattern)?;
218        Ok(self
219            .entries
220            .iter()
221            .enumerate()
222            .filter(|(_, e)| re.is_match(&e.corpus_text))
223            .map(|(i, _)| i)
224            .collect())
225    }
226
227    /// Literal substring search. Returns matching indices.
228    pub fn literal_search(&self, needle: &str, case_sensitive: bool) -> Vec<usize> {
229        let needle_lower = needle.to_lowercase();
230        self.entries
231            .iter()
232            .enumerate()
233            .filter(|(_, e)| {
234                if case_sensitive {
235                    e.corpus_text.contains(needle)
236                } else {
237                    e.corpus_text.to_lowercase().contains(&needle_lower)
238                }
239            })
240            .map(|(i, _)| i)
241            .collect()
242    }
243
244    /// Return reverse dependencies: contracts that depend on `stem`.
245    pub fn depended_by(&self, stem: &str) -> Vec<&str> {
246        self.entries
247            .iter()
248            .filter(|e| e.depends_on.iter().any(|d| d == stem))
249            .map(|e| e.stem.as_str())
250            .collect()
251    }
252
253    /// Compute pagerank scores over the contract dependency graph.
254    ///
255    /// Returns a map from stem to pagerank score. Higher scores indicate
256    /// more "important" contracts (more depended-upon by others).
257    #[allow(clippy::cast_precision_loss)]
258    #[requires(iterations > 0 && damping > 0.0 && damping < 1.0)]
259    pub fn pagerank(&self, iterations: usize, damping: f64) -> HashMap<String, f64> {
260        let n = self.entries.len();
261        if n == 0 {
262            return HashMap::new();
263        }
264        let n_f = n as f64;
265        let mut scores: Vec<f64> = vec![1.0 / n_f; n];
266
267        for _ in 0..iterations {
268            let mut new_scores = vec![(1.0 - damping) / n_f; n];
269            for (i, entry) in self.entries.iter().enumerate() {
270                let out_degree = entry.depends_on.len();
271                if out_degree == 0 {
272                    // Distribute rank equally to all (dangling node)
273                    let share = damping * scores[i] / n_f;
274                    for s in &mut new_scores {
275                        *s += share;
276                    }
277                } else {
278                    let share = damping * scores[i] / out_degree as f64;
279                    for dep in &entry.depends_on {
280                        if let Some(&j) = self.name_index.get(dep) {
281                            new_scores[j] += share;
282                        }
283                    }
284                }
285            }
286            scores = new_scores;
287        }
288
289        self.entries
290            .iter()
291            .enumerate()
292            .map(|(i, e)| (e.stem.clone(), scores[i]))
293            .collect()
294    }
295}
296
297fn build_entry(stem: String, path: String, contract: &Contract) -> ContractEntry {
298    let equations: Vec<String> = contract.equations.keys().cloned().collect();
299    let obligation_types: Vec<String> = contract
300        .proof_obligations
301        .iter()
302        .map(|o| o.obligation_type.to_string())
303        .collect();
304    let properties: Vec<String> = contract
305        .proof_obligations
306        .iter()
307        .map(|o| o.property.clone())
308        .collect();
309    let references = contract.metadata.references.clone();
310    let depends_on = contract.metadata.depends_on.clone();
311    let mut corpus_parts = vec![stem.clone(), contract.metadata.description.clone()];
312    for (name, eq) in &contract.equations {
313        corpus_parts.push(name.clone());
314        corpus_parts.push(eq.formula.clone());
315        corpus_parts.extend(eq.invariants.iter().cloned());
316    }
317    for ob in &contract.proof_obligations {
318        corpus_parts.push(ob.property.clone());
319        if let Some(f) = &ob.formal {
320            corpus_parts.push(f.clone());
321        }
322    }
323    corpus_parts.extend(references.iter().cloned());
324    let corpus_text = corpus_parts.join(" ");
325
326    ContractEntry {
327        stem,
328        path,
329        description: contract.metadata.description.clone(),
330        equations,
331        obligation_types,
332        properties,
333        references,
334        depends_on,
335        is_registry: contract.is_registry(),
336        kind: contract.kind(),
337        obligation_count: contract.proof_obligations.len(),
338        falsification_count: contract.falsification_tests.len(),
339        kani_count: contract.kani_harnesses.len(),
340        corpus_text,
341    }
342}
343
/// Split `text` into lowercase search terms.
///
/// Every character that is neither alphanumeric nor `_` is a separator, so
/// `kernel_v1` stays one term. After lowercasing, terms shorter than two bytes
/// are dropped; for ASCII input that removes single-character terms.
fn tokenize(text: &str) -> Vec<String> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');
    let mut terms = Vec::new();
    for raw in text.split(is_separator) {
        let term = raw.to_lowercase();
        if term.len() >= 2 {
            terms.push(term);
        }
    }
    terms
}
351
/// Count how many times each distinct term occurs in `terms`.
///
/// Keys borrow from the input slice, so the map cannot outlive `terms`.
fn term_frequencies(terms: &[String]) -> HashMap<&String, usize> {
    terms.iter().fold(HashMap::new(), |mut counts, term| {
        *counts.entry(term).or_default() += 1;
        counts
    })
}
359
/// Recursively collect every `*.yaml` file under `dir`.
///
/// Directory symlinks are followed (as `Path::is_dir` does). Each directory is
/// walked at most once, keyed by its canonical path. This prevents a symlink
/// cycle (e.g. `contracts/sub/loop -> ..`) from producing repeated copies of
/// the same files, and a directory reachable through two links is not collected
/// twice. Only the exact, case-sensitive extension `yaml` matches. Paths are
/// returned in directory-iteration order; callers sort if they need determinism.
///
/// # Errors
/// Returns an error if a visited directory or one of its entries cannot be read.
fn collect_yaml_files(dir: &Path) -> Result<Vec<std::path::PathBuf>, Box<dyn std::error::Error>> {
    let mut result = Vec::new();
    let mut visited = std::collections::HashSet::new();
    collect_yaml_files_into(dir, &mut visited, &mut result)?;
    Ok(result)
}

/// Depth-first walk behind [`collect_yaml_files`]; `visited` holds canonical
/// directories that have already been walked.
fn collect_yaml_files_into(
    dir: &Path,
    visited: &mut std::collections::HashSet<std::path::PathBuf>,
    out: &mut Vec<std::path::PathBuf>,
) -> Result<(), Box<dyn std::error::Error>> {
    // If canonicalization fails, key on the literal path; `read_dir` below then
    // reports any real problem with the directory.
    let key = std::fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
    if !visited.insert(key) {
        return Ok(());
    }
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_dir() {
            collect_yaml_files_into(&path, visited, out)?;
        } else if path.extension().and_then(|x| x.to_str()) == Some("yaml") {
            out.push(path);
        }
    }
    Ok(())
}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Location of the repository's shared contract corpus used by the
    /// integration-style tests below.
    fn contracts_dir() -> std::path::PathBuf {
        std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../../contracts")
    }

    #[test]
    fn tokenize_splits_correctly() {
        let tokens = tokenize("softmax-kernel_v1 numerical stability");
        assert!(tokens.contains(&"softmax".to_string()));
        assert!(tokens.contains(&"kernel_v1".to_string()));
        assert!(tokens.contains(&"numerical".to_string()));
        assert!(tokens.contains(&"stability".to_string()));
    }

    #[test]
    fn tokenize_filters_short() {
        let tokens = tokenize("a is ok");
        assert!(!tokens.iter().any(|t| t == "a"));
        assert!(tokens.contains(&"is".to_string()));
        assert!(tokens.contains(&"ok".to_string()));
    }

    #[test]
    fn index_from_contracts_dir() {
        let index = ContractIndex::build_from_directory(&contracts_dir()).unwrap();
        assert!(index.entries.len() > 10, "Should index many contracts");
        assert!(index.get_by_stem("softmax-kernel-v1").is_some());
    }

    #[test]
    fn bm25_ranks_relevant_first() {
        let index = ContractIndex::build_from_directory(&contracts_dir()).unwrap();
        let results = index.bm25_search("softmax numerical stability");
        assert!(!results.is_empty());
        // Top result should be related to softmax/cross-entropy (both reference softmax)
        let top = &index.entries[results[0].0];
        assert!(
            top.corpus_text.to_lowercase().contains("softmax"),
            "Top result corpus should mention softmax, got stem={}",
            top.stem,
        );
    }

    #[test]
    fn literal_search_finds_match() {
        let index = ContractIndex::build_from_directory(&contracts_dir()).unwrap();
        let matches = index.literal_search("RMSNorm", false);
        assert!(!matches.is_empty());
    }

    #[test]
    fn regex_search_finds_patterns() {
        let index = ContractIndex::build_from_directory(&contracts_dir()).unwrap();
        let matches = index.regex_search(r"(?i)softmax|log.softmax").unwrap();
        assert!(!matches.is_empty());
    }

    #[test]
    fn depended_by_returns_dependents() {
        let index = ContractIndex::build_from_directory(&contracts_dir()).unwrap();
        assert!(!index.entries.is_empty(), "Index should contain contracts");

        let target = "softmax-kernel-v1";
        let deps = index.depended_by(target);
        // Every reported dependent must be an entry that really lists the target.
        for stem in &deps {
            assert!(
                index
                    .entries
                    .iter()
                    .any(|e| e.stem == *stem && e.depends_on.iter().any(|d| d == target)),
                "{stem} was reported as depending on {target} but does not list it",
            );
        }
        // Nothing depends on a stem that does not exist.
        assert!(index.depended_by("no-such-contract-stem").is_empty());
    }

    #[test]
    fn pagerank_produces_valid_scores() {
        let index = ContractIndex::build_from_directory(&contracts_dir()).unwrap();
        let scores = index.pagerank(20, 0.85);
        // scores is a HashMap keyed by stem — duplicate stems collapse to one entry
        let unique_stems: std::collections::HashSet<_> =
            index.entries.iter().map(|e| &e.stem).collect();
        assert_eq!(scores.len(), unique_stems.len());
        // All scores should be positive
        for s in scores.values() {
            assert!(*s > 0.0, "PageRank should be positive");
        }
        // Softmax should rank relatively high (many things depend on it)
        let softmax = scores.get("softmax-kernel-v1").unwrap();
        #[allow(clippy::cast_precision_loss)] // small map size -> f64 for a mean
        let mean = scores.values().sum::<f64>() / scores.len() as f64;
        assert!(
            *softmax >= mean,
            "softmax ({softmax:.4}) should be >= mean ({mean:.4})"
        );
    }

    #[test]
    fn pagerank_empty_index() {
        let index = ContractIndex::from_entries(Vec::new());
        let scores = index.pagerank(20, 0.85);
        assert!(scores.is_empty());
    }

    #[test]
    fn from_directory_uses_cache() {
        // Removes the scratch tree on drop, so cleanup also happens if an assertion panics.
        struct ScratchDir(std::path::PathBuf);
        impl Drop for ScratchDir {
            fn drop(&mut self) {
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        // A per-process root holds the contracts dir, so any cache written in it or
        // next to it lives under one tree that is removed afterwards. This avoids
        // touching `.pv` in the shared system temp dir.
        let root = std::env::temp_dir()
            .join(format!("pv_from_dir_cache_test_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&root);
        let scratch = ScratchDir(root);
        let tmp = scratch.0.join("contracts");
        std::fs::create_dir_all(&tmp).unwrap();

        // Copy a few contracts to temp
        let src = contracts_dir();
        for name in &["softmax-kernel-v1.yaml", "rmsnorm-kernel-v1.yaml"] {
            let content = std::fs::read_to_string(src.join(name)).unwrap();
            std::fs::write(tmp.join(name), content).unwrap();
        }

        // First call builds + caches
        let idx1 = ContractIndex::from_directory(&tmp).unwrap();
        assert!(idx1.entries.len() >= 2);

        // Second call should hit cache
        let idx2 = ContractIndex::from_directory(&tmp).unwrap();
        assert_eq!(idx1.entries.len(), idx2.entries.len());
    }
}