provable_contracts/query/
index.rs1use std::collections::HashMap;
4use std::path::Path;
5
6use provable_contracts_macros::requires;
7
8use crate::schema::{Contract, parse_contract};
9use crate::scoring;
10
11use super::persist::{self, PersistedIndex};
12use super::types::ContractEntry;
13
14#[derive(Debug)]
16pub struct ContractIndex {
17 pub entries: Vec<ContractEntry>,
18 name_index: HashMap<String, usize>,
19 equation_index: HashMap<String, Vec<usize>>,
20 obligation_index: HashMap<String, Vec<usize>>,
21 score_cache: HashMap<String, f64>,
23 pagerank_cache: HashMap<String, f64>,
25 avg_dl: f64,
27 df: HashMap<String, usize>,
29}
30
31impl ContractIndex {
32 pub fn from_directory(dir: &Path) -> Result<Self, Box<dyn std::error::Error>> {
37 Self::from_directory_opts(dir, false)
38 }
39
40 pub fn from_directory_opts(
42 dir: &Path,
43 force_rebuild: bool,
44 ) -> Result<Self, Box<dyn std::error::Error>> {
45 if !force_rebuild {
47 if let Some(cached) = persist::load_cached(dir) {
48 let mut index = Self::from_entries(cached.entries);
49 index.score_cache = cached.score_cache;
50 index.pagerank_cache = cached.pagerank_cache;
51 return Ok(index);
52 }
53 }
54
55 let index = Self::build_from_directory(dir)?;
56
57 let _ = persist::save_cached(
59 dir,
60 &PersistedIndex {
61 entries: index.entries.clone(),
62 score_cache: index.score_cache.clone(),
63 pagerank_cache: index.pagerank_cache.clone(),
64 },
65 );
66
67 Ok(index)
68 }
69
70 pub fn build_from_directory(dir: &Path) -> Result<Self, Box<dyn std::error::Error>> {
72 let mut yaml_paths: Vec<_> = collect_yaml_files(dir)?;
73 yaml_paths.sort();
74
75 let mut entries = Vec::new();
76 let mut score_cache = HashMap::new();
77 for path in &yaml_paths {
78 let Ok(contract) = parse_contract(path) else {
79 continue;
80 };
81 let stem = path
82 .file_stem()
83 .and_then(|s| s.to_str())
84 .unwrap_or("unknown")
85 .to_string();
86 let path_str = path.display().to_string();
87 let score = scoring::score_contract(&contract, None, &stem);
88 score_cache.insert(stem.clone(), score.composite);
89 entries.push(build_entry(stem, path_str, &contract));
90 }
91
92 let mut index = Self::from_entries(entries);
93 index.score_cache = score_cache;
94 index.pagerank_cache = index.pagerank(20, 0.85);
95 Ok(index)
96 }
97
98 #[allow(clippy::cast_precision_loss)]
100 pub fn from_entries(entries: Vec<ContractEntry>) -> Self {
101 let mut name_index = HashMap::new();
102 let mut equation_index: HashMap<String, Vec<usize>> = HashMap::new();
103 let mut obligation_index: HashMap<String, Vec<usize>> = HashMap::new();
104 let mut df: HashMap<String, usize> = HashMap::new();
105 let mut total_len = 0usize;
106
107 for (i, entry) in entries.iter().enumerate() {
108 name_index.insert(entry.stem.clone(), i);
109 for eq in &entry.equations {
110 equation_index.entry(eq.clone()).or_default().push(i);
111 }
112 for ot in &entry.obligation_types {
113 obligation_index.entry(ot.clone()).or_default().push(i);
114 }
115
116 let terms = tokenize(&entry.corpus_text);
117 total_len += terms.len();
118 let mut seen = std::collections::HashSet::new();
119 for t in &terms {
120 if seen.insert(t.clone()) {
121 *df.entry(t.clone()).or_default() += 1;
122 }
123 }
124 }
125
126 let avg_dl = if entries.is_empty() {
127 1.0
128 } else {
129 total_len as f64 / entries.len() as f64
130 };
131
132 Self {
133 entries,
134 name_index,
135 equation_index,
136 obligation_index,
137 score_cache: HashMap::new(),
138 pagerank_cache: HashMap::new(),
139 avg_dl,
140 df,
141 }
142 }
143
144 pub fn get_by_stem(&self, stem: &str) -> Option<&ContractEntry> {
146 self.name_index.get(stem).map(|&i| &self.entries[i])
147 }
148
149 pub fn cached_score(&self, stem: &str) -> Option<f64> {
151 self.score_cache.get(stem).copied()
152 }
153
154 pub fn cached_pagerank(&self, stem: &str) -> Option<f64> {
156 self.pagerank_cache.get(stem).copied()
157 }
158
159 pub fn get_by_obligation(&self, ob_type: &str) -> Vec<&ContractEntry> {
161 self.obligation_index
162 .get(ob_type)
163 .map(|idxs| idxs.iter().map(|&i| &self.entries[i]).collect())
164 .unwrap_or_default()
165 }
166
167 pub fn get_by_equation(&self, eq: &str) -> Vec<&ContractEntry> {
169 self.equation_index
170 .get(eq)
171 .map(|idxs| idxs.iter().map(|&i| &self.entries[i]).collect())
172 .unwrap_or_default()
173 }
174
175 #[allow(clippy::cast_precision_loss)]
177 pub fn bm25_search(&self, query: &str) -> Vec<(usize, f64)> {
178 let query_terms = tokenize(query);
179 if query_terms.is_empty() {
180 return Vec::new();
181 }
182
183 let n = self.entries.len() as f64;
184 let k1 = 1.2;
185 let b = 0.75;
186
187 let mut scores: Vec<(usize, f64)> = self
188 .entries
189 .iter()
190 .enumerate()
191 .map(|(i, entry)| {
192 let doc_terms = tokenize(&entry.corpus_text);
193 let dl = doc_terms.len() as f64;
194
195 let tf_map = term_frequencies(&doc_terms);
196 let score: f64 = query_terms
197 .iter()
198 .map(|qt| {
199 let doc_freq = self.df.get(qt).copied().unwrap_or(0) as f64;
200 let idf = ((n - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
201 let tf = tf_map.get(qt).copied().unwrap_or(0) as f64;
202 idf * (tf * (k1 + 1.0)) / (tf + k1 * (1.0 - b + b * dl / self.avg_dl))
203 })
204 .sum();
205
206 (i, score)
207 })
208 .filter(|(_, s)| *s > 0.0)
209 .collect();
210
211 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212 scores
213 }
214
215 pub fn regex_search(&self, pattern: &str) -> Result<Vec<usize>, regex::Error> {
217 let re = regex::Regex::new(pattern)?;
218 Ok(self
219 .entries
220 .iter()
221 .enumerate()
222 .filter(|(_, e)| re.is_match(&e.corpus_text))
223 .map(|(i, _)| i)
224 .collect())
225 }
226
227 pub fn literal_search(&self, needle: &str, case_sensitive: bool) -> Vec<usize> {
229 let needle_lower = needle.to_lowercase();
230 self.entries
231 .iter()
232 .enumerate()
233 .filter(|(_, e)| {
234 if case_sensitive {
235 e.corpus_text.contains(needle)
236 } else {
237 e.corpus_text.to_lowercase().contains(&needle_lower)
238 }
239 })
240 .map(|(i, _)| i)
241 .collect()
242 }
243
244 pub fn depended_by(&self, stem: &str) -> Vec<&str> {
246 self.entries
247 .iter()
248 .filter(|e| e.depends_on.iter().any(|d| d == stem))
249 .map(|e| e.stem.as_str())
250 .collect()
251 }
252
253 #[allow(clippy::cast_precision_loss)]
258 #[requires(iterations > 0 && damping > 0.0 && damping < 1.0)]
259 pub fn pagerank(&self, iterations: usize, damping: f64) -> HashMap<String, f64> {
260 let n = self.entries.len();
261 if n == 0 {
262 return HashMap::new();
263 }
264 let n_f = n as f64;
265 let mut scores: Vec<f64> = vec![1.0 / n_f; n];
266
267 for _ in 0..iterations {
268 let mut new_scores = vec![(1.0 - damping) / n_f; n];
269 for (i, entry) in self.entries.iter().enumerate() {
270 let out_degree = entry.depends_on.len();
271 if out_degree == 0 {
272 let share = damping * scores[i] / n_f;
274 for s in &mut new_scores {
275 *s += share;
276 }
277 } else {
278 let share = damping * scores[i] / out_degree as f64;
279 for dep in &entry.depends_on {
280 if let Some(&j) = self.name_index.get(dep) {
281 new_scores[j] += share;
282 }
283 }
284 }
285 }
286 scores = new_scores;
287 }
288
289 self.entries
290 .iter()
291 .enumerate()
292 .map(|(i, e)| (e.stem.clone(), scores[i]))
293 .collect()
294 }
295}
296
297fn build_entry(stem: String, path: String, contract: &Contract) -> ContractEntry {
298 let equations: Vec<String> = contract.equations.keys().cloned().collect();
299 let obligation_types: Vec<String> = contract
300 .proof_obligations
301 .iter()
302 .map(|o| o.obligation_type.to_string())
303 .collect();
304 let properties: Vec<String> = contract
305 .proof_obligations
306 .iter()
307 .map(|o| o.property.clone())
308 .collect();
309 let references = contract.metadata.references.clone();
310 let depends_on = contract.metadata.depends_on.clone();
311 let mut corpus_parts = vec![stem.clone(), contract.metadata.description.clone()];
312 for (name, eq) in &contract.equations {
313 corpus_parts.push(name.clone());
314 corpus_parts.push(eq.formula.clone());
315 corpus_parts.extend(eq.invariants.iter().cloned());
316 }
317 for ob in &contract.proof_obligations {
318 corpus_parts.push(ob.property.clone());
319 if let Some(f) = &ob.formal {
320 corpus_parts.push(f.clone());
321 }
322 }
323 corpus_parts.extend(references.iter().cloned());
324 let corpus_text = corpus_parts.join(" ");
325
326 ContractEntry {
327 stem,
328 path,
329 description: contract.metadata.description.clone(),
330 equations,
331 obligation_types,
332 properties,
333 references,
334 depends_on,
335 is_registry: contract.is_registry(),
336 kind: contract.kind(),
337 obligation_count: contract.proof_obligations.len(),
338 falsification_count: contract.falsification_tests.len(),
339 kani_count: contract.kani_harnesses.len(),
340 corpus_text,
341 }
342}
343
/// Splits `text` into lowercase search terms.
///
/// A term is a maximal run of alphanumeric characters or `_`. Every other
/// character is a separator. Terms shorter than two characters (after
/// lowercasing) are dropped.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric() && c != '_')
        .map(str::to_lowercase)
        // Count characters, not bytes: a lone non-ASCII letter such as "é" is
        // two UTF-8 bytes and would otherwise slip through the length filter.
        .filter(|s| s.chars().nth(1).is_some())
        .collect()
}
351
/// Counts the occurrences of each distinct term in `terms`.
///
/// The keys borrow from the input slice, so the map cannot outlive it.
fn term_frequencies(terms: &[String]) -> HashMap<&String, usize> {
    terms.iter().fold(HashMap::new(), |mut counts, term| {
        *counts.entry(term).or_default() += 1;
        counts
    })
}
359
/// Recursively gathers every file below `dir` whose extension is exactly `yaml`.
///
/// Subdirectories are descended into depth-first as they are encountered, so
/// the result follows `read_dir` order and is unspecified. Callers that need a
/// stable order must sort.
///
/// # Errors
/// Propagates the first I/O error from reading a directory or one of its entries.
fn collect_yaml_files(dir: &Path) -> Result<Vec<std::path::PathBuf>, Box<dyn std::error::Error>> {
    /// Appends the yaml files under `current` to `found`, recursing into directories.
    fn walk(
        current: &Path,
        found: &mut Vec<std::path::PathBuf>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        for item in std::fs::read_dir(current)? {
            let candidate = item?.path();
            if candidate.is_dir() {
                walk(&candidate, found)?;
            } else if candidate.extension().and_then(|ext| ext.to_str()) == Some("yaml") {
                found.push(candidate);
            }
        }
        Ok(())
    }

    let mut found = Vec::new();
    walk(dir, &mut found)?;
    Ok(found)
}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// The repository's contract directory used by the integration-style tests.
    fn contracts_dir() -> std::path::PathBuf {
        std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../../contracts")
    }

    /// Builds a fresh, uncached index over [`contracts_dir`].
    fn contracts_index() -> ContractIndex {
        ContractIndex::build_from_directory(&contracts_dir()).unwrap()
    }

    #[test]
    fn tokenize_splits_correctly() {
        let tokens = tokenize("softmax-kernel_v1 numerical stability");
        assert!(tokens.contains(&"softmax".to_string()));
        assert!(tokens.contains(&"kernel_v1".to_string()));
        assert!(tokens.contains(&"numerical".to_string()));
        assert!(tokens.contains(&"stability".to_string()));
    }

    #[test]
    fn tokenize_filters_short() {
        let tokens = tokenize("a is ok");
        assert!(!tokens.iter().any(|t| t == "a"));
        assert!(tokens.contains(&"is".to_string()));
        assert!(tokens.contains(&"ok".to_string()));
    }

    #[test]
    fn index_from_contracts_dir() {
        let index = contracts_index();
        assert!(index.entries.len() > 10, "Should index many contracts");
        assert!(index.get_by_stem("softmax-kernel-v1").is_some());
    }

    #[test]
    fn bm25_ranks_relevant_first() {
        let index = contracts_index();
        let results = index.bm25_search("softmax numerical stability");
        assert!(!results.is_empty());
        let top = &index.entries[results[0].0];
        assert!(
            top.corpus_text.to_lowercase().contains("softmax"),
            "Top result corpus should mention softmax, got stem={}",
            top.stem,
        );
    }

    #[test]
    fn literal_search_finds_match() {
        let index = contracts_index();
        let matches = index.literal_search("RMSNorm", false);
        assert!(!matches.is_empty());
    }

    #[test]
    fn regex_search_finds_patterns() {
        let index = contracts_index();
        let matches = index.regex_search(r"(?i)softmax|log.softmax").unwrap();
        assert!(!matches.is_empty());
    }

    #[test]
    fn depended_by_returns_dependents() {
        let index = contracts_index();
        assert!(!index.entries.is_empty(), "Index should contain contracts");

        let target = "softmax-kernel-v1";
        let dependents = index.depended_by(target);
        let lists_target = |e: &ContractEntry| e.depends_on.iter().any(|d| d == target);

        // Soundness: every reported stem belongs to a contract that lists the target.
        for stem in &dependents {
            assert!(
                index
                    .entries
                    .iter()
                    .any(|e| e.stem == *stem && lists_target(e)),
                "{stem} was reported as depending on {target} but does not"
            );
        }
        // Completeness: every contract that lists the target is reported.
        for e in index.entries.iter().filter(|e| lists_target(e)) {
            assert!(
                dependents.contains(&e.stem.as_str()),
                "{} depends on {target} but was not reported",
                e.stem
            );
        }
    }

    #[test]
    fn pagerank_produces_valid_scores() {
        let index = contracts_index();
        let scores = index.pagerank(20, 0.85);
        // Duplicate stems collapse into one map key, so compare against unique stems.
        let unique_stems: std::collections::HashSet<_> =
            index.entries.iter().map(|e| &e.stem).collect();
        assert_eq!(scores.len(), unique_stems.len());
        for s in scores.values() {
            assert!(*s > 0.0, "PageRank should be positive");
        }
        let softmax = scores.get("softmax-kernel-v1").unwrap();
        #[allow(clippy::cast_precision_loss)] // small count
        let mean = scores.values().sum::<f64>() / scores.len() as f64;
        assert!(
            *softmax >= mean,
            "softmax ({softmax:.4}) should be >= mean ({mean:.4})"
        );
    }

    #[test]
    fn pagerank_empty_index() {
        let index = ContractIndex::from_entries(Vec::new());
        let scores = index.pagerank(20, 0.85);
        assert!(scores.is_empty());
    }

    #[test]
    fn from_directory_uses_cache() {
        let tmp = std::env::temp_dir().join("pv_from_dir_cache_test");
        let _ = std::fs::remove_dir_all(&tmp);
        std::fs::create_dir_all(&tmp).unwrap();

        let src = contracts_dir();
        for name in &["softmax-kernel-v1.yaml", "rmsnorm-kernel-v1.yaml"] {
            let content = std::fs::read_to_string(src.join(name)).unwrap();
            std::fs::write(tmp.join(name), content).unwrap();
        }

        // First load builds and persists; the second should come from the cache.
        let idx1 = ContractIndex::from_directory(&tmp).unwrap();
        assert!(idx1.entries.len() >= 2);

        let idx2 = ContractIndex::from_directory(&tmp).unwrap();
        assert_eq!(idx1.entries.len(), idx2.entries.len());

        let _ = std::fs::remove_dir_all(&tmp);
        let _ = std::fs::remove_dir_all(tmp.parent().unwrap().join(".pv"));
    }
}