1use std::collections::HashMap;
8use std::fmt;
9use std::str::FromStr;
10
11use crate::chunk::CodeChunk;
12
/// Which retrieval strategy a search should use.
///
/// Round-trips through [`FromStr`] / [`fmt::Display`] as the lowercase
/// strings `"hybrid"`, `"semantic"`, and `"keyword"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SearchMode {
    /// Hybrid mode; this is the [`Default`] variant.
    #[default]
    Hybrid,
    /// Semantic-only mode.
    Semantic,
    /// Keyword-only mode.
    Keyword,
}
24
25impl fmt::Display for SearchMode {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 match self {
28 Self::Hybrid => f.write_str("hybrid"),
29 Self::Semantic => f.write_str("semantic"),
30 Self::Keyword => f.write_str("keyword"),
31 }
32 }
33}
34
/// Error returned when a string does not name a [`SearchMode`].
///
/// Holds the rejected input verbatim so the message can echo it back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseSearchModeError(String);

impl fmt::Display for ParseSearchModeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        write!(
            f,
            "unknown search mode {:?}; expected hybrid, semantic, or keyword",
            rejected
        )
    }
}

// No `source()`: the error is self-contained and wraps no underlying cause.
impl std::error::Error for ParseSearchModeError {}
50
51impl FromStr for SearchMode {
52 type Err = ParseSearchModeError;
53
54 fn from_str(s: &str) -> Result<Self, Self::Err> {
55 match s {
56 "hybrid" => Ok(Self::Hybrid),
57 "semantic" => Ok(Self::Semantic),
58 "keyword" => Ok(Self::Keyword),
59 other => Err(ParseSearchModeError(other.to_string())),
60 }
61 }
62}
63
64#[must_use]
65pub fn pagerank_boost_factor(percentile: f32, alpha: f32) -> f32 {
66 if percentile <= 0.0 || alpha <= 0.0 {
67 return 1.0;
68 }
69 let z = (percentile.clamp(0.0, 1.0) - 0.5) / PAGERANK_SIGMOID_STEEPNESS;
70 let sigmoid = 1.0 / (1.0 + (-z).exp());
71 1.0 + alpha * sigmoid
72}
73
74pub fn boost_with_pagerank<S: std::hash::BuildHasher>(
90 results: &mut [(usize, f32)],
91 chunks: &[CodeChunk],
92 pagerank_by_file: &HashMap<String, f32, S>,
93 alpha: f32,
94) {
95 for (idx, score) in results.iter_mut() {
101 if let Some(chunk) = chunks.get(*idx) {
102 let rank = lookup_rank(pagerank_by_file, &chunk.file_path, &chunk.name);
103 *score *= pagerank_boost_factor(rank, alpha);
104 }
105 }
106 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
107}
108
109#[must_use]
114pub fn pagerank_lookup(graph: &crate::repo_map::RepoGraph) -> HashMap<String, f32> {
115 let def_pct = make_percentile_fn(&graph.def_ranks);
128 let base_pct = make_percentile_fn(&graph.base_ranks);
129 let mut map = HashMap::new();
130 for (file_idx, file) in graph.files.iter().enumerate() {
131 for (def_idx, def) in file.defs.iter().enumerate() {
132 let flat = graph.def_offsets[file_idx] + def_idx;
133 if let Some(&rank) = graph.def_ranks.get(flat) {
134 let key = format!("{}::{}", file.path, def.name);
135 map.insert(key, def_pct(rank));
136 }
137 }
138 if file_idx < graph.base_ranks.len() {
139 map.insert(file.path.clone(), base_pct(graph.base_ranks[file_idx]));
140 }
141 }
142 map
143}
144
/// Builds a closure that maps a rank to its percentile within `values`.
///
/// The percentile is the fraction of finite entries of `values` that are
/// strictly less than the queried value, so it lies in `[0.0, 1.0]`:
/// - the smallest entry, and every entry tied with it, maps to `0.0`;
/// - a value above every entry maps to `1.0`;
/// - a NaN query maps to `0.0`.
///
/// Non-finite entries (NaN, ±inf) in `values` are ignored. If no finite
/// entries remain, every query yields `0.0`. The closure owns a sorted copy
/// and answers each query with a binary search.
fn make_percentile_fn(values: &[f32]) -> impl Fn(f32) -> f32 + '_ {
    let mut ranked: Vec<f32> = values.iter().filter(|v| v.is_finite()).copied().collect();
    ranked.sort_unstable_by(f32::total_cmp);
    move |value: f32| -> f32 {
        let total = ranked.len();
        if total == 0 {
            return 0.0;
        }
        // Index of the first entry that is not strictly below `value`,
        // i.e. the count of entries below it.
        let below = ranked.partition_point(|&candidate| candidate < value);
        #[expect(
            clippy::cast_precision_loss,
            reason = "rank counts well below f32 precision threshold"
        )]
        let fraction = below as f32 / total as f32;
        fraction
    }
}
169
/// Divisor applied to `percentile - 0.5` before the logistic function in
/// [`pagerank_boost_factor`].
///
/// Despite the name, this is an inverse steepness: a smaller value makes the
/// boost curve sharper around the 50th percentile. At `0.15` the logistic
/// input spans roughly `[-3.33, 3.33]` for percentiles in `[0, 1]`.
const PAGERANK_SIGMOID_STEEPNESS: f32 = 0.15_f32;
180
181#[must_use]
187pub(crate) fn lookup_rank_for_chunk<S: std::hash::BuildHasher>(
188 pr: &HashMap<String, f32, S>,
189 file_path: &str,
190 name: &str,
191) -> f32 {
192 lookup_rank(pr, file_path, name)
193}
194
/// Resolves the PageRank value for definition `name` in `file_path`.
///
/// Candidate paths are probed in order:
/// 1. `file_path` itself;
/// 2. each suffix of `file_path` after a `/`, longest first. This lets
///    absolute or repo-prefixed paths match relative keys. The walk stops at
///    an empty suffix, i.e. a trailing `/`.
///
/// For each candidate, the definition key `"{path}::{name}"` is tried before
/// the file key `path`. The first hit wins; `0.0` is returned when no key
/// matches.
#[must_use]
fn lookup_rank<S: std::hash::BuildHasher>(
    pr: &HashMap<String, f32, S>,
    file_path: &str,
    name: &str,
) -> f32 {
    // Definition-level entries take priority over file-level ones.
    let probe = |path: &str| -> Option<f32> {
        pr.get(format!("{path}::{name}").as_str())
            .or_else(|| pr.get(path))
            .copied()
    };

    let suffixes = file_path
        .match_indices('/')
        .map(|(slash, _)| &file_path[slash + 1..])
        .take_while(|suffix| !suffix.is_empty());

    std::iter::once(file_path)
        .chain(suffixes)
        .find_map(probe)
        .unwrap_or(0.0)
}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal function chunk. Only `file_path` and `name` affect
    /// the PageRank lookup; the other fields are fixed placeholders.
    fn make_chunk(file_path: &str, name: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.into(),
            name: name.into(),
            kind: "function".into(),
            start_line: 1,
            end_line: 10,
            content: String::new(),
            enriched_content: String::new(),
        }
    }

    /// Asserts that two floats agree within `tol`, reporting both on failure.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected ~{expected}, got {actual}"
        );
    }

    #[test]
    fn search_mode_roundtrip() {
        assert_eq!("hybrid".parse::<SearchMode>().unwrap(), SearchMode::Hybrid);
        assert_eq!(
            "semantic".parse::<SearchMode>().unwrap(),
            SearchMode::Semantic
        );
        assert_eq!(
            "keyword".parse::<SearchMode>().unwrap(),
            SearchMode::Keyword
        );

        let err = "invalid".parse::<SearchMode>();
        assert!(err.is_err(), "expected parse error for 'invalid'");
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("invalid"),
            "error message should echo the bad input"
        );
    }

    #[test]
    fn search_mode_display() {
        assert_eq!(SearchMode::Hybrid.to_string(), "hybrid");
        assert_eq!(SearchMode::Semantic.to_string(), "semantic");
        assert_eq!(SearchMode::Keyword.to_string(), "keyword");
    }

    #[test]
    fn pagerank_boost_amplifies_relevant() {
        let chunks = vec![make_chunk("important.rs", "a"), make_chunk("obscure.rs", "b")];
        let mut results = vec![(0, 0.8_f32), (1, 0.8)];
        let mut pr = HashMap::new();
        pr.insert("important.rs".to_string(), 1.0);
        pr.insert("obscure.rs".to_string(), 0.1);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert_eq!(
            results[0].0, 0,
            "important.rs should rank first after boost"
        );
        assert!(results[0].1 > results[1].1);

        // 0.8 * (1 + 0.3 * sigmoid((1.0 - 0.5) / 0.15)) ≈ 1.032
        assert!(
            (results[0].1 - 1.032).abs() < 0.01,
            "rank=1.0 boost: expected ~1.032, got {}",
            results[0].1
        );
        // 0.8 * (1 + 0.3 * sigmoid((0.1 - 0.5) / 0.15)) ≈ 0.816
        assert!(
            (results[1].1 - 0.816).abs() < 0.01,
            "rank=0.1 boost: expected ~0.816, got {}",
            results[1].1
        );
    }

    #[test]
    fn pagerank_boost_zero_relevance_stays_zero() {
        let chunks = vec![make_chunk("important.rs", "a")];
        let mut results = vec![(0, 0.0_f32)];
        let mut pr = HashMap::new();
        pr.insert("important.rs".to_string(), 1.0);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert!(results[0].1.abs() < f32::EPSILON);
    }

    #[test]
    fn pagerank_boost_unknown_file_no_effect() {
        let chunks = vec![make_chunk("unknown.rs", "a")];
        let mut results = vec![(0, 0.5_f32)];
        let pr: HashMap<String, f32> = HashMap::new();

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert!((results[0].1 - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn pagerank_boost_skips_out_of_range_indices() {
        let chunks = vec![make_chunk("important.rs", "a")];
        let mut results = vec![(5, 0.5_f32), (0, 0.5)];
        let mut pr = HashMap::new();
        pr.insert("important.rs".to_string(), 1.0);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        // Index 0 was boosted above 0.5; index 5 has no chunk and is untouched.
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1], (5, 0.5));
    }

    #[test]
    fn boost_factor_is_neutral_for_unranked_or_disabled() {
        assert_close(pagerank_boost_factor(0.0, 0.3), 1.0, f32::EPSILON);
        assert_close(pagerank_boost_factor(-0.5, 0.3), 1.0, f32::EPSILON);
        assert_close(pagerank_boost_factor(0.9, 0.0), 1.0, f32::EPSILON);
    }

    #[test]
    fn boost_factor_is_monotonic_in_percentile() {
        let low = pagerank_boost_factor(0.1, 0.3);
        let mid = pagerank_boost_factor(0.5, 0.3);
        let high = pagerank_boost_factor(1.0, 0.3);
        assert!(1.0 < low && low < mid && mid < high && high < 1.3);
        assert_close(mid, 1.15, 1e-6);
    }

    #[test]
    fn lookup_rank_prefers_definition_then_file_then_suffix() {
        let mut pr = HashMap::new();
        pr.insert("src/lib.rs::run".to_string(), 0.9_f32);
        pr.insert("src/lib.rs".to_string(), 0.4);
        pr.insert("util.rs".to_string(), 0.2);

        assert_close(lookup_rank(&pr, "src/lib.rs", "run"), 0.9, f32::EPSILON);
        assert_close(lookup_rank(&pr, "src/lib.rs", "other"), 0.4, f32::EPSILON);
        // Prefixed paths fall back to progressively shorter suffixes.
        assert_close(lookup_rank(&pr, "/repo/src/lib.rs", "run"), 0.9, f32::EPSILON);
        assert_close(lookup_rank(&pr, "crates/x/util.rs", "f"), 0.2, f32::EPSILON);
        assert_close(lookup_rank(&pr, "missing.rs", "f"), 0.0, f32::EPSILON);
        assert_close(lookup_rank(&pr, "dir/", "f"), 0.0, f32::EPSILON);
    }

    #[test]
    fn percentile_counts_strictly_smaller_finite_values() {
        let ranks = [0.4_f32, 0.1, f32::NAN, 0.3, 0.2];
        let pct = make_percentile_fn(&ranks);
        assert_close(pct(0.1), 0.0, f32::EPSILON);
        assert_close(pct(0.3), 0.5, f32::EPSILON);
        assert_close(pct(1.0), 1.0, f32::EPSILON);
    }
}