Skip to main content

ripvec_core/
bm25.rs

1//! BM25 keyword search index for code chunks.
2//!
3//! Provides camelCase/snake_case-aware tokenization via [`CodeSplitFilter`]
4//! and an in-RAM tantivy index ([`Bm25Index`]) that supports per-field
5//! boosted queries so identifier sub-tokens (e.g. `json` from
6//! `parseJsonConfig`) are matched correctly.
7
8use tantivy::schema::{
9    Field, INDEXED, IndexRecordOption, STORED, Schema, TextFieldIndexing, TextOptions, Value,
10};
11use tantivy::tokenizer::{
12    LowerCaser, SimpleTokenizer, TextAnalyzer, Token, TokenFilter, TokenStream, Tokenizer,
13};
14use tantivy::{
15    Index, IndexReader, ReloadPolicy, TantivyDocument,
16    collector::TopDocs,
17    query::{BooleanQuery, BoostQuery, Occur, QueryParser},
18};
19
20use crate::chunk::CodeChunk;
21
22// ──────────────────────────────────────────────────────────────────────────────
23// Identifier splitting
24// ──────────────────────────────────────────────────────────────────────────────
25
/// Split a code identifier into its constituent sub-words.
///
/// Recognises camelCase, PascalCase, snake_case, SCREAMING_SNAKE_CASE and
/// mixed forms such as `HTML5Parser`. Every returned word is lowercased.
///
/// Underscores always separate words; empty pieces produced by leading,
/// trailing or repeated underscores are dropped. Inside each piece a new word
/// begins:
/// - at an uppercase letter preceded by a lowercase letter or an ASCII digit
///   (`parseJson` → `parse` | `Json`, `HTML5Parser` → `HTML5` | `Parser`);
/// - at the last letter of a run of two or more uppercase letters that is
///   followed by a lowercase letter (`HTMLParser` → `HTML` | `Parser`).
///
/// If fewer than two words come out, there is nothing to expand and an empty
/// `Vec` is returned.
///
/// # Examples
/// ```
/// # use ripvec_core::bm25::split_code_identifier;
/// assert_eq!(split_code_identifier("parseJsonConfig"), vec!["parse", "json", "config"]);
/// assert_eq!(split_code_identifier("my_func_name"),    vec!["my", "func", "name"]);
/// assert_eq!(split_code_identifier("HTML5Parser"),     vec!["html5", "parser"]);
/// assert_eq!(split_code_identifier("parser"),          Vec::<String>::new());
/// ```
#[must_use]
pub fn split_code_identifier(text: &str) -> Vec<String> {
    // Turn a run of chars into one lowercased word.
    let lower = |run: &[char]| run.iter().collect::<String>().to_lowercase();
    let mut words: Vec<String> = Vec::new();

    for piece in text.split('_').filter(|p| !p.is_empty()) {
        let chars: Vec<char> = piece.chars().collect();
        let mut word_start = 0usize;

        for idx in 1..chars.len() {
            let before = chars[idx - 1];
            let here = chars[idx];

            if (before.is_lowercase() || before.is_ascii_digit()) && here.is_uppercase() {
                // "parseJson" / "HTML5Parser": the uppercase char opens a new word.
                words.push(lower(&chars[word_start..idx]));
                word_start = idx;
            } else if idx >= 2
                && chars[idx - 2].is_uppercase()
                && before.is_uppercase()
                && here.is_lowercase()
            {
                // "HTMLParser": the last capital of the acronym run opens a new word.
                words.push(lower(&chars[word_start..idx - 1]));
                word_start = idx - 1;
            }
        }

        // `piece` is non-empty and `word_start` never passes the last char,
        // so the trailing word always has at least one char.
        words.push(lower(&chars[word_start..]));
    }

    if words.len() < 2 {
        Vec::new()
    } else {
        words
    }
}
107
108// ──────────────────────────────────────────────────────────────────────────────
109// Tantivy token filter
110// ──────────────────────────────────────────────────────────────────────────────
111
112/// Token stream produced by [`CodeSplitFilterWrapper`].
113///
114/// For each token from the upstream stream the original token is emitted first,
115/// then any sub-tokens produced by [`split_code_identifier`].
116pub struct CodeSplitTokenStream<'a, T> {
117    /// Upstream token stream (already lowercased by `LowerCaser`).
118    tail: T,
119    /// Buffer of pending sub-tokens; filled in reverse so `pop()` yields them
120    /// in order.
121    pending: &'a mut Vec<Token>,
122}
123
124impl<T: TokenStream> TokenStream for CodeSplitTokenStream<'_, T> {
125    fn advance(&mut self) -> bool {
126        // Drain any buffered sub-tokens first.
127        if let Some(tok) = self.pending.pop() {
128            *self.tail.token_mut() = tok;
129            return true;
130        }
131
132        // Advance the upstream stream.
133        if !self.tail.advance() {
134            return false;
135        }
136
137        let upstream = self.tail.token().clone();
138        let sub_tokens = split_code_identifier(&upstream.text);
139
140        // Queue sub-tokens in reverse order so pop() gives them in order.
141        let position_offset = upstream.position;
142        for (idx, sub) in sub_tokens.iter().enumerate().rev() {
143            let mut t = upstream.clone();
144            t.text.clone_from(sub);
145            t.position = position_offset + idx + 1;
146            self.pending.push(t);
147        }
148
149        // The upstream token is already current — nothing extra needed.
150        true
151    }
152
153    fn token(&self) -> &Token {
154        self.tail.token()
155    }
156
157    fn token_mut(&mut self) -> &mut Token {
158        self.tail.token_mut()
159    }
160}
161
162/// Tantivy [`TokenFilter`] that emits sub-tokens for camelCase/snake_case
163/// identifiers in addition to the original token.
164#[derive(Clone)]
165pub struct CodeSplitFilter;
166
167impl TokenFilter for CodeSplitFilter {
168    type Tokenizer<T: Tokenizer> = CodeSplitFilterWrapper<T>;
169
170    fn transform<T: Tokenizer>(self, tokenizer: T) -> CodeSplitFilterWrapper<T> {
171        CodeSplitFilterWrapper {
172            inner: tokenizer,
173            pending: Vec::new(),
174        }
175    }
176}
177
178/// Wrapper tokenizer produced by [`CodeSplitFilter::transform`].
179#[derive(Clone)]
180pub struct CodeSplitFilterWrapper<T> {
181    inner: T,
182    pending: Vec<Token>,
183}
184
185impl<T: Tokenizer> Tokenizer for CodeSplitFilterWrapper<T> {
186    type TokenStream<'a> = CodeSplitTokenStream<'a, T::TokenStream<'a>>;
187
188    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
189        self.pending.clear();
190        CodeSplitTokenStream {
191            tail: self.inner.token_stream(text),
192            pending: &mut self.pending,
193        }
194    }
195}
196
197// ──────────────────────────────────────────────────────────────────────────────
198// Analyzer
199// ──────────────────────────────────────────────────────────────────────────────
200
201/// Build a tantivy [`TextAnalyzer`] that tokenizes, expands camelCase/snake_case
202/// identifiers into sub-tokens, then lowercases everything.
203///
204/// `CodeSplitFilter` must run **before** `LowerCaser` so that camelCase boundaries
205/// (uppercase letters) are still visible when splitting.
206#[must_use]
207pub fn code_analyzer() -> TextAnalyzer {
208    TextAnalyzer::builder(SimpleTokenizer::default())
209        .filter(CodeSplitFilter)
210        .filter(LowerCaser)
211        .build()
212}
213
214// ──────────────────────────────────────────────────────────────────────────────
215// Schema
216// ──────────────────────────────────────────────────────────────────────────────
217
218/// Handles to the tantivy schema fields used by [`Bm25Index`].
219pub struct BM25Fields {
220    /// Chunk name (function/struct name, high-value signal).
221    pub name: Field,
222    /// Relative file path of the source file.
223    pub file_path: Field,
224    /// Full chunk content (body text).
225    pub body: Field,
226    /// Monotonic index into the original `chunks` slice — stored for retrieval.
227    pub chunk_id: Field,
228}
229
230/// Construct the tantivy [`Schema`] and return field handles.
231#[must_use]
232pub fn build_schema() -> (Schema, BM25Fields) {
233    let mut builder = Schema::builder();
234
235    let code_indexing = TextFieldIndexing::default()
236        .set_tokenizer("code")
237        .set_index_option(IndexRecordOption::WithFreqsAndPositions);
238
239    let text_opts = TextOptions::default()
240        .set_indexing_options(code_indexing)
241        .set_stored();
242
243    let name = builder.add_text_field("name", text_opts.clone());
244    let file_path = builder.add_text_field("file_path", text_opts.clone());
245    let body = builder.add_text_field("body", text_opts);
246    let chunk_id = builder.add_u64_field("chunk_id", INDEXED | STORED);
247
248    let schema = builder.build();
249    (
250        schema,
251        BM25Fields {
252            name,
253            file_path,
254            body,
255            chunk_id,
256        },
257    )
258}
259
260// ──────────────────────────────────────────────────────────────────────────────
261// Bm25Index
262// ──────────────────────────────────────────────────────────────────────────────
263
264/// In-RAM BM25 index over a slice of [`CodeChunk`]s.
265///
266/// Built with [`Bm25Index::build`]; query with [`Bm25Index::search`].
267pub struct Bm25Index {
268    index: Index,
269    reader: IndexReader,
270    fields: BM25Fields,
271}
272
273impl Bm25Index {
274    /// Build a fresh in-RAM index from the given chunks.
275    ///
276    /// Registers the `"code"` tokenizer, indexes each chunk's `name`,
277    /// `file_path`, and `content`, then commits.
278    pub fn build(chunks: &[CodeChunk]) -> crate::Result<Self> {
279        let (schema, fields) = build_schema();
280
281        let index = Index::create_in_ram(schema.clone());
282
283        // Register our custom tokenizer under the name "code".
284        index.tokenizers().register("code", code_analyzer());
285
286        let mut writer = index
287            .writer(50_000_000)
288            .map_err(|e| crate::Error::Other(e.into()))?;
289
290        for (idx, chunk) in chunks.iter().enumerate() {
291            let mut doc = TantivyDocument::default();
292            doc.add_text(fields.name, &chunk.name);
293            doc.add_text(fields.file_path, &chunk.file_path);
294            doc.add_text(fields.body, &chunk.content);
295            doc.add_u64(fields.chunk_id, idx as u64);
296            writer
297                .add_document(doc)
298                .map_err(|e| crate::Error::Other(e.into()))?;
299        }
300
301        writer.commit().map_err(|e| crate::Error::Other(e.into()))?;
302
303        let reader = index
304            .reader_builder()
305            .reload_policy(ReloadPolicy::Manual)
306            .try_into()
307            .map_err(|e| crate::Error::Other(e.into()))?;
308
309        Ok(Self {
310            index,
311            reader,
312            fields,
313        })
314    }
315
316    /// Search the index for `query_text`, returning up to `top_k` results.
317    ///
318    /// Fields are boosted: `name` ×3.0, `file_path` ×1.5, `body` ×1.0.
319    ///
320    /// Returns a `Vec<(chunk_idx, bm25_score)>` sorted by descending score.
321    #[must_use]
322    pub fn search(&self, query_text: &str, top_k: usize) -> Vec<(usize, f32)> {
323        let searcher = self.reader.searcher();
324
325        // Build per-field boosted sub-queries and combine with BooleanQuery.
326        let make_sub = |field: Field, boost: f32| -> Box<dyn tantivy::query::Query> {
327            let mut parser = QueryParser::for_index(&self.index, vec![field]);
328            parser.set_field_boost(field, boost);
329            let q = parser.parse_query(query_text).unwrap_or_else(|_| {
330                // Fallback: empty query that matches nothing.
331                Box::new(tantivy::query::AllQuery)
332            });
333            Box::new(BoostQuery::new(q, boost))
334        };
335
336        let sub_queries: Vec<(Occur, Box<dyn tantivy::query::Query>)> = vec![
337            (Occur::Should, make_sub(self.fields.name, 3.0)),
338            (Occur::Should, make_sub(self.fields.file_path, 1.5)),
339            (Occur::Should, make_sub(self.fields.body, 1.0)),
340        ];
341
342        let combined = BooleanQuery::new(sub_queries);
343
344        let Ok(top_docs) = searcher.search(&combined, &TopDocs::with_limit(top_k)) else {
345            return vec![];
346        };
347
348        let mut results = Vec::with_capacity(top_docs.len());
349        for (score, doc_addr) in top_docs {
350            let Ok(doc) = searcher.doc::<TantivyDocument>(doc_addr) else {
351                continue;
352            };
353            let Some(id_val) = doc.get_first(self.fields.chunk_id) else {
354                continue;
355            };
356            let Some(id) = id_val.as_u64() else {
357                continue;
358            };
359            results.push((usize::try_from(id).unwrap_or(usize::MAX), score));
360        }
361
362        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
363        results
364    }
365}
366
367// ──────────────────────────────────────────────────────────────────────────────
368// Tests
369// ──────────────────────────────────────────────────────────────────────────────
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `CodeChunk` for tests; `enriched_content` simply mirrors `content`.
    fn make_chunk(name: &str, file_path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.to_string(),
            name: name.to_string(),
            kind: "function_item".to_string(),
            start_line: 1,
            end_line: 10,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// Two-chunk corpus shared by the search tests.
    fn sample_chunks() -> Vec<CodeChunk> {
        vec![
            make_chunk(
                "parseJsonConfig",
                "src/config.rs",
                "fn parseJsonConfig(data: &str) -> Config { ... }",
            ),
            make_chunk(
                "renderHtml",
                "src/render.rs",
                "fn renderHtml(template: &str) -> String { ... }",
            ),
        ]
    }

    #[test]
    fn split_camel_case() {
        let parts = split_code_identifier("parseJsonConfig");
        assert_eq!(parts, vec!["parse", "json", "config"]);
    }

    #[test]
    fn split_snake_case() {
        let parts = split_code_identifier("my_func_name");
        assert_eq!(parts, vec!["my", "func", "name"]);
    }

    #[test]
    fn split_screaming_snake() {
        let parts = split_code_identifier("MAX_BATCH_SIZE");
        assert_eq!(parts, vec!["max", "batch", "size"]);
    }

    #[test]
    fn split_mixed() {
        let parts = split_code_identifier("MetalDriver");
        assert_eq!(parts, vec!["metal", "driver"]);
    }

    #[test]
    fn split_acronym_before_word() {
        assert_eq!(
            split_code_identifier("XMLHttpRequest"),
            vec!["xml", "http", "request"]
        );
        assert_eq!(split_code_identifier("IOError"), vec!["io", "error"]);
    }

    #[test]
    fn split_digit_stays_with_preceding_acronym() {
        assert_eq!(split_code_identifier("HTML5Parser"), vec!["html5", "parser"]);
    }

    #[test]
    fn no_split_single_word() {
        let parts = split_code_identifier("parser");
        assert!(parts.is_empty(), "expected empty vec, got {parts:?}");
    }

    #[test]
    fn no_split_when_underscores_leave_one_word() {
        assert!(split_code_identifier("__init__").is_empty());
        assert!(split_code_identifier("").is_empty());
    }

    #[test]
    fn bm25_index_search() {
        let index = Bm25Index::build(&sample_chunks()).expect("index build failed");
        let results = index.search("parseJsonConfig", 5);

        assert!(!results.is_empty(), "expected at least one result");
        assert_eq!(results[0].0, 0, "chunk 0 should rank first");
    }

    #[test]
    fn bm25_camel_case_subtoken_match() {
        let index = Bm25Index::build(&sample_chunks()).expect("index build failed");
        // "json" is a sub-token of "parseJsonConfig" — should match chunk 0.
        let results = index.search("json", 5);

        assert!(!results.is_empty(), "expected results for sub-token 'json'");
        assert_eq!(results[0].0, 0, "parseJsonConfig chunk should match 'json'");
    }

    #[test]
    fn bm25_no_match_returns_empty() {
        let index = Bm25Index::build(&sample_chunks()).expect("index build failed");
        let results = index.search("nonexistentidentifier", 5);
        assert!(results.is_empty(), "expected no results, got {results:?}");
    }

    #[test]
    fn bm25_respects_top_k() {
        let index = Bm25Index::build(&sample_chunks()).expect("index build failed");
        // Both bodies contain "fn", but only one result was requested.
        let results = index.search("fn", 1);
        assert_eq!(results.len(), 1, "got {results:?}");
    }
}