Skip to main content

ripvec_core/
bm25.rs

1//! BM25 keyword search index for code chunks.
2//!
3//! Provides camelCase/snake_case-aware tokenization via [`CodeSplitFilter`]
4//! and an in-RAM tantivy index ([`Bm25Index`]) that supports per-field
5//! boosted queries so identifier sub-tokens (e.g. `json` from
6//! `parseJsonConfig`) are matched correctly.
7
8use tantivy::schema::{
9    Field, INDEXED, IndexRecordOption, STORED, Schema, TextFieldIndexing, TextOptions, Value,
10};
11use tantivy::tokenizer::{
12    LowerCaser, SimpleTokenizer, TextAnalyzer, Token, TokenFilter, TokenStream, Tokenizer,
13};
14use tantivy::{
15    Index, IndexReader, ReloadPolicy, TantivyDocument,
16    collector::TopDocs,
17    query::{BooleanQuery, BoostQuery, Occur, QueryParser},
18};
19
20use crate::chunk::CodeChunk;
21
22// ──────────────────────────────────────────────────────────────────────────────
23// Identifier splitting
24// ──────────────────────────────────────────────────────────────────────────────
25
/// Split a code identifier into its constituent sub-words.
///
/// Underscores always separate words (snake_case, SCREAMING_SNAKE_CASE). Inside
/// each underscore-delimited segment a new word begins:
/// * at an uppercase letter that follows a lowercase letter or an ASCII digit
///   (`parseJson` → `parse` | `Json`, `HTML5Parser` → `HTML5` | `Parser`), and
/// * at the last capital of an uppercase run that is followed by a lowercase
///   letter (`HTMLParser` → `HTML` | `Parser`).
///
/// Every word is lowercased. When fewer than two words come out (a single word,
/// an empty string, or one word wrapped in underscores such as `__init__`) an
/// empty `Vec` is returned, telling callers no expansion is needed.
///
/// # Examples
/// ```
/// # use ripvec_core::bm25::split_code_identifier;
/// assert_eq!(split_code_identifier("parseJsonConfig"), vec!["parse", "json", "config"]);
/// assert_eq!(split_code_identifier("my_func_name"),    vec!["my", "func", "name"]);
/// assert_eq!(split_code_identifier("HTML5Parser"),     vec!["html5", "parser"]);
/// assert_eq!(split_code_identifier("parser"),          Vec::<String>::new());
/// ```
#[must_use]
pub fn split_code_identifier(text: &str) -> Vec<String> {
    let mut words: Vec<String> = Vec::new();

    for segment in text.split('_').filter(|piece| !piece.is_empty()) {
        let letters: Vec<char> = segment.chars().collect();
        // Index of the first character of the word currently being built.
        let mut word_start = 0usize;

        for pos in 1..letters.len() {
            let before = letters[pos - 1];
            let here = letters[pos];

            // `parseJson`, `HTML5Parser`: the capital itself opens the new word.
            // Digits count as "lowercase" here so `HTML5` stays one word.
            let starts_after_lower_or_digit =
                here.is_uppercase() && (before.is_lowercase() || before.is_ascii_digit());
            // `HTMLParser`: the capital *before* the lowercase letter opens the
            // new word, closing the acronym just ahead of it.
            let ends_acronym = pos >= 2
                && here.is_lowercase()
                && before.is_uppercase()
                && letters[pos - 2].is_uppercase();

            let cut = if starts_after_lower_or_digit {
                pos
            } else if ends_acronym {
                pos - 1
            } else {
                continue;
            };
            words.push(letters[word_start..cut].iter().collect::<String>().to_lowercase());
            word_start = cut;
        }

        // Whatever is left of the segment forms the final word.
        if word_start < letters.len() {
            words.push(letters[word_start..].iter().collect::<String>().to_lowercase());
        }
    }

    // Zero or one word means there is nothing to expand.
    if words.len() < 2 {
        Vec::new()
    } else {
        words
    }
}
107
108// ──────────────────────────────────────────────────────────────────────────────
109// Tantivy token filter
110// ──────────────────────────────────────────────────────────────────────────────
111
112/// Token stream produced by [`CodeSplitFilterWrapper`].
113///
114/// For each token from the upstream stream the original token is emitted first,
115/// then any sub-tokens produced by [`split_code_identifier`].
116pub struct CodeSplitTokenStream<'a, T> {
117    /// Upstream token stream (already lowercased by `LowerCaser`).
118    tail: T,
119    /// Buffer of pending sub-tokens; filled in reverse so `pop()` yields them
120    /// in order.
121    pending: &'a mut Vec<Token>,
122}
123
124impl<T: TokenStream> TokenStream for CodeSplitTokenStream<'_, T> {
125    fn advance(&mut self) -> bool {
126        // Drain any buffered sub-tokens first.
127        if let Some(tok) = self.pending.pop() {
128            *self.tail.token_mut() = tok;
129            return true;
130        }
131
132        // Advance the upstream stream.
133        if !self.tail.advance() {
134            return false;
135        }
136
137        let upstream = self.tail.token().clone();
138        let sub_tokens = split_code_identifier(&upstream.text);
139
140        // Queue sub-tokens in reverse order so pop() gives them in order.
141        let position_offset = upstream.position;
142        for (idx, sub) in sub_tokens.iter().enumerate().rev() {
143            let mut t = upstream.clone();
144            t.text.clone_from(sub);
145            t.position = position_offset + idx + 1;
146            self.pending.push(t);
147        }
148
149        // The upstream token is already current — nothing extra needed.
150        true
151    }
152
153    fn token(&self) -> &Token {
154        self.tail.token()
155    }
156
157    fn token_mut(&mut self) -> &mut Token {
158        self.tail.token_mut()
159    }
160}
161
/// Tantivy [`TokenFilter`] that emits sub-tokens for camelCase/snake_case
/// identifiers in addition to the original token.
///
/// The filter itself carries no state, so it is freely copyable; see
/// [`code_analyzer`] for where it sits in the analysis pipeline.
#[derive(Debug, Clone, Copy, Default)]
pub struct CodeSplitFilter;
166
167impl TokenFilter for CodeSplitFilter {
168    type Tokenizer<T: Tokenizer> = CodeSplitFilterWrapper<T>;
169
170    fn transform<T: Tokenizer>(self, tokenizer: T) -> CodeSplitFilterWrapper<T> {
171        CodeSplitFilterWrapper {
172            inner: tokenizer,
173            pending: Vec::new(),
174        }
175    }
176}
177
178/// Wrapper tokenizer produced by [`CodeSplitFilter::transform`].
179#[derive(Clone)]
180pub struct CodeSplitFilterWrapper<T> {
181    inner: T,
182    pending: Vec<Token>,
183}
184
185impl<T: Tokenizer> Tokenizer for CodeSplitFilterWrapper<T> {
186    type TokenStream<'a> = CodeSplitTokenStream<'a, T::TokenStream<'a>>;
187
188    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
189        self.pending.clear();
190        CodeSplitTokenStream {
191            tail: self.inner.token_stream(text),
192            pending: &mut self.pending,
193        }
194    }
195}
196
197// ──────────────────────────────────────────────────────────────────────────────
198// Analyzer
199// ──────────────────────────────────────────────────────────────────────────────
200
201/// Build a tantivy [`TextAnalyzer`] that tokenizes, expands camelCase/snake_case
202/// identifiers into sub-tokens, then lowercases everything.
203///
204/// `CodeSplitFilter` must run **before** `LowerCaser` so that camelCase boundaries
205/// (uppercase letters) are still visible when splitting.
206#[must_use]
207pub fn code_analyzer() -> TextAnalyzer {
208    TextAnalyzer::builder(SimpleTokenizer::default())
209        .filter(CodeSplitFilter)
210        .filter(LowerCaser)
211        .build()
212}
213
214// ──────────────────────────────────────────────────────────────────────────────
215// Schema
216// ──────────────────────────────────────────────────────────────────────────────
217
218/// Handles to the tantivy schema fields used by [`Bm25Index`].
219pub struct BM25Fields {
220    /// Chunk name (function/struct name, high-value signal).
221    pub name: Field,
222    /// Relative file path of the source file.
223    pub file_path: Field,
224    /// Full chunk content (body text).
225    pub body: Field,
226    /// Monotonic index into the original `chunks` slice — stored for retrieval.
227    pub chunk_id: Field,
228}
229
230/// Construct the tantivy [`Schema`] and return field handles.
231#[must_use]
232pub fn build_schema() -> (Schema, BM25Fields) {
233    let mut builder = Schema::builder();
234
235    let code_indexing = TextFieldIndexing::default()
236        .set_tokenizer("code")
237        .set_index_option(IndexRecordOption::WithFreqsAndPositions);
238
239    let text_opts = TextOptions::default()
240        .set_indexing_options(code_indexing)
241        .set_stored();
242
243    let name = builder.add_text_field("name", text_opts.clone());
244    let file_path = builder.add_text_field("file_path", text_opts.clone());
245    let body = builder.add_text_field("body", text_opts);
246    let chunk_id = builder.add_u64_field("chunk_id", INDEXED | STORED);
247
248    let schema = builder.build();
249    (
250        schema,
251        BM25Fields {
252            name,
253            file_path,
254            body,
255            chunk_id,
256        },
257    )
258}
259
260// ──────────────────────────────────────────────────────────────────────────────
261// Bm25Index
262// ──────────────────────────────────────────────────────────────────────────────
263
264/// In-RAM BM25 index over a slice of [`CodeChunk`]s.
265///
266/// Built with [`Bm25Index::build`]; query with [`Bm25Index::search`].
267pub struct Bm25Index {
268    index: Index,
269    reader: IndexReader,
270    fields: BM25Fields,
271}
272
273impl Bm25Index {
274    /// Build a fresh in-RAM index from the given chunks.
275    ///
276    /// Registers the `"code"` tokenizer, indexes each chunk's `name`,
277    /// `file_path`, and `content`, then commits.
278    pub fn build(chunks: &[CodeChunk]) -> crate::Result<Self> {
279        let (schema, fields) = build_schema();
280
281        let index = Index::create_in_ram(schema.clone());
282
283        // Register our custom tokenizer under the name "code".
284        index.tokenizers().register("code", code_analyzer());
285
286        let mut writer = index
287            .writer(50_000_000)
288            .map_err(|e| crate::Error::Other(e.into()))?;
289
290        for (idx, chunk) in chunks.iter().enumerate() {
291            let mut doc = TantivyDocument::default();
292            doc.add_text(fields.name, &chunk.name);
293            doc.add_text(fields.file_path, &chunk.file_path);
294            doc.add_text(fields.body, &chunk.content);
295            doc.add_u64(fields.chunk_id, idx as u64);
296            writer
297                .add_document(doc)
298                .map_err(|e| crate::Error::Other(e.into()))?;
299        }
300
301        writer.commit().map_err(|e| crate::Error::Other(e.into()))?;
302
303        let reader = index
304            .reader_builder()
305            .reload_policy(ReloadPolicy::Manual)
306            .try_into()
307            .map_err(|e| crate::Error::Other(e.into()))?;
308
309        Ok(Self {
310            index,
311            reader,
312            fields,
313        })
314    }
315
316    /// Search the index for `query_text`, returning up to `top_k` results.
317    ///
318    /// Fields are boosted: `name` ×3.0, `file_path` ×1.5, `body` ×1.0.
319    ///
320    /// Returns a `Vec<(chunk_idx, bm25_score)>` sorted by descending score.
321    #[must_use]
322    pub fn search(&self, query_text: &str, top_k: usize) -> Vec<(usize, f32)> {
323        let searcher = self.reader.searcher();
324
325        // Build per-field boosted sub-queries and combine with BooleanQuery.
326        let make_sub = |field: Field, boost: f32| -> Box<dyn tantivy::query::Query> {
327            let mut parser = QueryParser::for_index(&self.index, vec![field]);
328            parser.set_field_boost(field, boost);
329            let q = parser.parse_query(query_text).unwrap_or_else(|_| {
330                // Fallback: empty query that matches nothing.
331                Box::new(tantivy::query::AllQuery)
332            });
333            Box::new(BoostQuery::new(q, boost))
334        };
335
336        let sub_queries: Vec<(Occur, Box<dyn tantivy::query::Query>)> = vec![
337            (Occur::Should, make_sub(self.fields.name, 3.0)),
338            (Occur::Should, make_sub(self.fields.file_path, 1.5)),
339            (Occur::Should, make_sub(self.fields.body, 1.0)),
340        ];
341
342        let combined = BooleanQuery::new(sub_queries);
343
344        let Ok(top_docs) = searcher.search(&combined, &TopDocs::with_limit(top_k).order_by_score())
345        else {
346            return vec![];
347        };
348
349        let mut results = Vec::with_capacity(top_docs.len());
350        for (score, doc_addr) in top_docs {
351            let Ok(doc) = searcher.doc::<TantivyDocument>(doc_addr) else {
352                continue;
353            };
354            let Some(id_val) = doc.get_first(self.fields.chunk_id) else {
355                continue;
356            };
357            let Some(id) = id_val.as_u64() else {
358                continue;
359            };
360            results.push((usize::try_from(id).unwrap_or(usize::MAX), score));
361        }
362
363        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
364        results
365    }
366}
367
368// ──────────────────────────────────────────────────────────────────────────────
369// Tests
370// ──────────────────────────────────────────────────────────────────────────────
371
#[cfg(test)]
mod tests {
    use super::*;

    fn make_chunk(name: &str, file_path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.to_string(),
            name: name.to_string(),
            kind: "function_item".to_string(),
            start_line: 1,
            end_line: 10,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// Run `text` through [`code_analyzer`] and collect `(text, position)` pairs.
    fn analyze(text: &str) -> Vec<(String, usize)> {
        let mut analyzer = code_analyzer();
        let mut stream = analyzer.token_stream(text);
        let mut tokens = Vec::new();
        while stream.advance() {
            let token = stream.token();
            tokens.push((token.text.clone(), token.position));
        }
        tokens
    }

    #[test]
    fn split_camel_case() {
        let parts = split_code_identifier("parseJsonConfig");
        assert_eq!(parts, vec!["parse", "json", "config"]);
    }

    #[test]
    fn split_snake_case() {
        let parts = split_code_identifier("my_func_name");
        assert_eq!(parts, vec!["my", "func", "name"]);
    }

    #[test]
    fn split_screaming_snake() {
        let parts = split_code_identifier("MAX_BATCH_SIZE");
        assert_eq!(parts, vec!["max", "batch", "size"]);
    }

    #[test]
    fn split_mixed() {
        let parts = split_code_identifier("MetalDriver");
        assert_eq!(parts, vec!["metal", "driver"]);
    }

    #[test]
    fn split_acronym_followed_by_word() {
        assert_eq!(split_code_identifier("HTMLParser"), vec!["html", "parser"]);
    }

    #[test]
    fn split_after_digits() {
        assert_eq!(split_code_identifier("HTML5Parser"), vec!["html5", "parser"]);
        assert_eq!(split_code_identifier("Base64Encode"), vec!["base64", "encode"]);
    }

    #[test]
    fn no_split_single_word() {
        let parts = split_code_identifier("parser");
        assert!(parts.is_empty(), "expected empty vec, got {parts:?}");
    }

    #[test]
    fn no_split_single_word_wrapped_in_underscores() {
        assert!(split_code_identifier("__init__").is_empty());
        assert!(split_code_identifier("").is_empty());
    }

    #[test]
    fn analyzer_emits_original_then_lowercased_subtokens() {
        let tokens = analyze("parseJson value");
        let texts: Vec<&str> = tokens.iter().map(|(text, _)| text.as_str()).collect();
        assert_eq!(texts, vec!["parsejson", "parse", "json", "value"]);
        // Sub-tokens must not collide with, or run backwards past, the
        // position of the token that follows them.
        assert!(
            tokens.windows(2).all(|pair| pair[0].1 < pair[1].1),
            "positions not strictly increasing: {tokens:?}"
        );
    }

    #[test]
    fn bm25_index_search() {
        let chunks = vec![
            make_chunk(
                "parseJsonConfig",
                "src/config.rs",
                "fn parseJsonConfig(data: &str) -> Config { ... }",
            ),
            make_chunk(
                "renderHtml",
                "src/render.rs",
                "fn renderHtml(template: &str) -> String { ... }",
            ),
        ];

        let index = Bm25Index::build(&chunks).expect("index build failed");
        let results = index.search("parseJsonConfig", 5);

        println!("results: {results:?}");
        assert!(!results.is_empty(), "expected at least one result");
        assert_eq!(results[0].0, 0, "chunk 0 should rank first");
    }

    #[test]
    fn bm25_camel_case_subtoken_match() {
        let chunks = vec![
            make_chunk(
                "parseJsonConfig",
                "src/config.rs",
                "fn parseJsonConfig(data: &str) -> Config { ... }",
            ),
            make_chunk(
                "renderHtml",
                "src/render.rs",
                "fn renderHtml(template: &str) -> String { ... }",
            ),
        ];

        let index = Bm25Index::build(&chunks).expect("index build failed");
        // "json" is a sub-token of "parseJsonConfig" — should match chunk 0.
        let results = index.search("json", 5);

        println!("subtoken results: {results:?}");
        assert!(!results.is_empty(), "expected results for sub-token 'json'");
        assert_eq!(results[0].0, 0, "parseJsonConfig chunk should match 'json'");
    }

    #[test]
    fn bm25_term_repeated_as_subtoken_and_token() {
        // "value" appears both as a sub-token of `getValue` and as standalone
        // tokens in the same field; indexing must handle that and still match.
        let chunks = vec![make_chunk(
            "getValue",
            "src/value.rs",
            "fn getValue(value: u32) -> u32 { value }",
        )];

        let index = Bm25Index::build(&chunks).expect("index build failed");
        let results = index.search("value", 5);

        assert_eq!(results.first().map(|hit| hit.0), Some(0));
    }
}