Skip to main content

lean_ctx/core/
splade_retrieval.rs

//! SPLADE-style sparse expansion layered on BM25 candidate retrieval.
//!
//! Stage 1: BM25 top-100. Stage 2: programming-term synonym / association expansion.
//! Stage 3: BM25-like scoring of the expansion terms, combined additively with the stage-1 scores.
5
6use std::collections::{HashMap, HashSet};
7
8use crate::core::bm25_index::{tokenize_for_index, BM25Index};
9
/// One ranked hit returned by [`hybrid_retrieve`].
///
/// Besides the final blended score, both raw score components are kept so a
/// caller can see how the ranking of a row came about.
#[derive(Clone, Debug, PartialEq)]
pub struct SpladeResult {
    /// Position of the chunk inside the index's `chunks` list.
    pub chunk_idx: usize,
    /// File path copied from the indexed chunk.
    pub file_path: String,
    /// Symbol name copied from the indexed chunk.
    pub symbol_name: String,
    /// Final ranking score: weighted sum of the max-normalized BM25 and
    /// expansion scores.
    pub combined_score: f64,
    /// Raw stage-1 BM25 score of the chunk.
    pub bm25_score: f64,
    /// Raw score contributed by expansion terms only.
    pub expansion_score: f64,
}
20
/// Builds the programming-term association table used for query expansion.
///
/// Each key is a query token; its value lists related tokens together with the
/// weight given to that association. A fresh map is built on every call.
fn expansion_dictionary() -> HashMap<&'static str, Vec<(&'static str, f64)>> {
    // Source-of-truth table: (query token, [(related token, weight), ...]).
    const ASSOCIATIONS: &[(&str, &[(&str, f64)])] = &[
        (
            "auth",
            &[
                ("authentication", 1.0),
                ("token", 0.9),
                ("jwt", 0.85),
                ("login", 0.8),
                ("session", 0.85),
                ("oauth", 0.75),
                ("credential", 0.7),
            ],
        ),
        (
            "async",
            &[
                ("await", 1.0),
                ("future", 0.85),
                ("promise", 0.75),
                ("tokio", 0.65),
                ("concurrent", 0.7),
            ],
        ),
        (
            "error",
            &[
                ("err", 0.95),
                ("result", 0.75),
                ("panic", 0.55),
                ("exception", 0.65),
                ("unwrap", 0.5),
            ],
        ),
        (
            "http",
            &[
                ("request", 0.85),
                ("response", 0.85),
                ("rest", 0.65),
                ("json", 0.7),
                ("header", 0.6),
            ],
        ),
        (
            "db",
            &[
                ("database", 1.0),
                ("sql", 0.85),
                ("query", 0.75),
                ("transaction", 0.65),
                ("migration", 0.55),
            ],
        ),
        (
            "test",
            &[
                ("mock", 0.75),
                ("fixture", 0.6),
                ("assert", 0.85),
                ("expect", 0.65),
            ],
        ),
        (
            "config",
            &[
                ("configuration", 1.0),
                ("env", 0.75),
                ("setting", 0.65),
                ("toml", 0.55),
            ],
        ),
        (
            "cache",
            &[
                ("memo", 0.65),
                ("redis", 0.55),
                ("ttl", 0.6),
                ("invalidate", 0.55),
            ],
        ),
    ];

    ASSOCIATIONS
        .iter()
        .map(|&(term, related)| (term, related.to_vec()))
        .collect()
}
105
106fn build_expanded_weights(query_tokens: &[String]) -> HashMap<String, f64> {
107    let dict = expansion_dictionary();
108    let mut out: HashMap<String, f64> = HashMap::new();
109
110    for t in query_tokens {
111        let lower = t.to_lowercase();
112        let entry = out.entry(lower.clone()).or_insert(1.0);
113        *entry = (*entry).max(1.0);
114
115        if let Some(rel) = dict.get(lower.as_str()) {
116            for (syn, w) in rel {
117                let le = out.entry((*syn).to_string()).or_insert(0.0);
118                *le = (*le).max(*w);
119            }
120        }
121    }
122    out
123}
124
125fn expansion_bm25_for_chunk(
126    index: &BM25Index,
127    chunk_idx: usize,
128    expanded: &HashMap<String, f64>,
129    original_terms: &HashSet<String>,
130) -> f64 {
131    if index.doc_count == 0 {
132        return 0.0;
133    }
134
135    let doc_len = index.chunks[chunk_idx].token_count as f64;
136    let norm_len = doc_len / index.avg_doc_len.max(1.0);
137
138    const K1: f64 = 1.2;
139    const B: f64 = 0.75;
140
141    let mut sum = 0.0;
142    for (term, ew) in expanded {
143        if original_terms.contains(term) {
144            continue;
145        }
146        let df = *index.doc_freqs.get(term).unwrap_or(&0) as f64;
147        if df == 0.0 {
148            continue;
149        }
150
151        let idf = ((index.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
152        let tf = index.inverted.get(term).map_or(0.0, |postings| {
153            postings
154                .iter()
155                .filter(|(idx, _)| *idx == chunk_idx)
156                .map(|(_, w)| *w)
157                .sum::<f64>()
158        });
159
160        if tf == 0.0 {
161            continue;
162        }
163
164        let bm25_t = idf * (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * norm_len));
165        sum += ew * bm25_t;
166    }
167    sum
168}
169
170/// BM25 top-100 → SPLADE-like expansion → combined re-rank.
171pub fn hybrid_retrieve(query: &str, bm25_index: &BM25Index, top_k: usize) -> Vec<SpladeResult> {
172    if bm25_index.doc_count == 0 || top_k == 0 {
173        return Vec::new();
174    }
175
176    let query_tokens = tokenize_for_index(query);
177    if query_tokens.is_empty() {
178        return Vec::new();
179    }
180
181    let original_terms: HashSet<String> = query_tokens.iter().map(|s| s.to_lowercase()).collect();
182
183    let expanded = build_expanded_weights(&query_tokens);
184
185    let stage1 = bm25_index.search(query, 100.min(bm25_index.doc_count.max(1)));
186
187    let bm25_by_chunk: HashMap<usize, f64> =
188        stage1.iter().map(|r| (r.chunk_idx, r.score)).collect();
189
190    let chunk_indices: Vec<usize> = bm25_by_chunk.keys().copied().collect();
191    if chunk_indices.is_empty() {
192        return Vec::new();
193    }
194
195    let mut expansion_scores: HashMap<usize, f64> = HashMap::new();
196    for &idx in &chunk_indices {
197        let es = expansion_bm25_for_chunk(bm25_index, idx, &expanded, &original_terms);
198        expansion_scores.insert(idx, es);
199    }
200
201    let max_bm25 = bm25_by_chunk.values().copied().fold(0.0_f64, f64::max);
202    let max_exp = expansion_scores.values().copied().fold(0.0_f64, f64::max);
203
204    let norm_bm25 = |s: f64| -> f64 {
205        if max_bm25 > 1e-12 {
206            s / max_bm25
207        } else {
208            0.0
209        }
210    };
211    let norm_exp = |s: f64| -> f64 {
212        if max_exp > 1e-12 {
213            s / max_exp
214        } else {
215            0.0
216        }
217    };
218
219    const W_BM25: f64 = 0.65;
220    const W_EXP: f64 = 0.35;
221
222    let mut results: Vec<SpladeResult> = chunk_indices
223        .into_iter()
224        .map(|chunk_idx| {
225            let chunk = &bm25_index.chunks[chunk_idx];
226            let bm25_score = bm25_by_chunk.get(&chunk_idx).copied().unwrap_or(0.0);
227            let expansion_score = expansion_scores.get(&chunk_idx).copied().unwrap_or(0.0);
228            let combined_score = W_BM25 * norm_bm25(bm25_score) + W_EXP * norm_exp(expansion_score);
229
230            SpladeResult {
231                chunk_idx,
232                file_path: chunk.file_path.clone(),
233                symbol_name: chunk.symbol_name.clone(),
234                combined_score,
235                bm25_score,
236                expansion_score,
237            }
238        })
239        .collect();
240
241    results.sort_by(|a, b| {
242        b.combined_score
243            .partial_cmp(&a.combined_score)
244            .unwrap_or(std::cmp::Ordering::Equal)
245            .then_with(|| {
246                b.bm25_score
247                    .partial_cmp(&a.bm25_score)
248                    .unwrap_or(std::cmp::Ordering::Equal)
249            })
250            .then_with(|| a.file_path.cmp(&b.file_path))
251            .then_with(|| a.symbol_name.cmp(&b.symbol_name))
252    });
253
254    results.truncate(top_k);
255    results
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Builds a function chunk spanning lines 1-3 with the given identity and body text.
    ///
    /// `tokens` and `token_count` are left empty / zero.
    /// NOTE(review): presumably `from_chunks_for_test` derives them from `content` — confirm.
    fn function_chunk(file_path: &str, symbol_name: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.into(),
            symbol_name: symbol_name.into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 3,
            content: content.into(),
            tokens: vec![],
            token_count: 0,
        }
    }

    /// Three-chunk fixture: two authentication-flavoured chunks and one cache chunk.
    fn sample_index() -> BM25Index {
        BM25Index::from_chunks_for_test(vec![
            function_chunk(
                "login.rs",
                "login_user",
                "pub fn login_user() { session_token authentication jwt oauth login credential }",
            ),
            function_chunk(
                "cache.rs",
                "memo_cache",
                "pub fn memo_cache() { redis ttl invalidate memo cache }",
            ),
            function_chunk(
                "auth.rs",
                "oauth_flow",
                "pub fn oauth_flow() { credential authentication token jwt session }",
            ),
        ])
    }

    /// The best hit for an auth-related query must be one of the two auth chunks.
    #[test]
    fn hybrid_prefers_expansion_overlap() {
        let index = sample_index();
        let hits = hybrid_retrieve("auth login", &index, 5);

        assert!(!hits.is_empty());
        let top_path = &hits[0].file_path;
        let is_auth_chunk = top_path.ends_with("login.rs") || top_path.ends_with("auth.rs");
        assert!(
            is_auth_chunk,
            "expected auth-expanded chunk first, got {hits:?}"
        );
    }

    /// A query mixing a literal match with an expandable term yields hits with
    /// a non-negative expansion score.
    #[test]
    fn expansion_boosts_related_terms() {
        let index = sample_index();
        // "jwt" matches BM25 stage 1; "auth" drives SPLADE expansion toward login/oauth chunks.
        let hybrid = hybrid_retrieve("jwt auth", &index, 10);

        assert!(!hybrid.is_empty());
        let top = &hybrid[0];
        assert!(
            top.expansion_score >= 0.0,
            "expansion_score should be non-negative"
        );
    }

    /// An empty query produces no results.
    #[test]
    fn empty_query_returns_empty() {
        let index = sample_index();
        let hits = hybrid_retrieve("", &index, 5);
        assert!(hits.is_empty());
    }

    /// Every field of the top result is filled in and all scores are finite.
    #[test]
    fn splade_result_fields_populated() {
        let index = sample_index();
        let hybrid = hybrid_retrieve("auth session", &index, 3);

        let top = &hybrid[0];
        assert!(top.chunk_idx < index.chunks.len());
        assert!(!top.file_path.is_empty());
        for score in [top.combined_score, top.bm25_score, top.expansion_score] {
            assert!(score.is_finite());
        }
    }
}
343}