1use tantivy::schema::{
9 Field, INDEXED, IndexRecordOption, STORED, Schema, TextFieldIndexing, TextOptions, Value,
10};
11use tantivy::tokenizer::{
12 LowerCaser, SimpleTokenizer, TextAnalyzer, Token, TokenFilter, TokenStream, Tokenizer,
13};
14use tantivy::{
15 Index, IndexReader, ReloadPolicy, TantivyDocument,
16 collector::TopDocs,
17 query::{BooleanQuery, BoostQuery, Occur, QueryParser},
18};
19
20use crate::chunk::CodeChunk;
21
/// Splits a code identifier into lowercase sub-words.
///
/// Handles snake_case (`my_func_name`), camelCase (`parseJsonConfig`),
/// SCREAMING_SNAKE (`MAX_BATCH_SIZE`) and acronym runs (`XMLHttpRequest`
/// -> `xml`, `http`, `request`).
///
/// Returns an empty vec when the identifier yields at most one sub-word,
/// signalling to callers that no expansion is needed.
#[must_use]
pub fn split_code_identifier(text: &str) -> Vec<String> {
    // Sub-words collected across all underscore-delimited segments.
    let mut parts: Vec<String> = Vec::new();

    for segment in text.split('_').filter(|s| !s.is_empty()) {
        let chars: Vec<char> = segment.chars().collect();
        let len = chars.len();
        // Index where the current sub-word began.
        let mut begin = 0usize;

        for idx in 1..len {
            let left = chars[idx - 1];
            let right = chars[idx];

            // camelCase boundary: lowercase letter or ASCII digit followed
            // by an uppercase letter, e.g. "parseJson" -> "parse" | "Json".
            if (left.is_lowercase() || left.is_ascii_digit()) && right.is_uppercase() {
                parts.push(chars[begin..idx].iter().collect::<String>().to_lowercase());
                begin = idx;
                continue;
            }

            // End of an uppercase acronym run: "XMLHttp" -> "XML" | "Http".
            // The final uppercase letter belongs to the following word.
            let acronym_end = idx >= 2
                && left.is_uppercase()
                && right.is_lowercase()
                && chars[idx - 2].is_uppercase();
            if acronym_end {
                parts.push(chars[begin..idx - 1].iter().collect::<String>().to_lowercase());
                begin = idx - 1;
            }
        }

        // Flush whatever remains of the segment.
        if begin < len {
            parts.push(chars[begin..len].iter().collect::<String>().to_lowercase());
        }
    }

    // A single sub-word means there was nothing to split.
    if parts.len() > 1 { parts } else { Vec::new() }
}
107
/// Token stream that emits each upstream token followed by the lowercase
/// sub-words of its identifier (as produced by [`split_code_identifier`]).
pub struct CodeSplitTokenStream<'a, T> {
    /// Upstream token stream being wrapped.
    tail: T,
    /// Sub-tokens queued for emission before `tail` is advanced again.
    /// Borrowed from the owning [`CodeSplitFilterWrapper`] so the buffer's
    /// allocation is reused across `token_stream` calls.
    pending: &'a mut Vec<Token>,
}
123
124impl<T: TokenStream> TokenStream for CodeSplitTokenStream<'_, T> {
125 fn advance(&mut self) -> bool {
126 if let Some(tok) = self.pending.pop() {
128 *self.tail.token_mut() = tok;
129 return true;
130 }
131
132 if !self.tail.advance() {
134 return false;
135 }
136
137 let upstream = self.tail.token().clone();
138 let sub_tokens = split_code_identifier(&upstream.text);
139
140 let position_offset = upstream.position;
142 for (idx, sub) in sub_tokens.iter().enumerate().rev() {
143 let mut t = upstream.clone();
144 t.text.clone_from(sub);
145 t.position = position_offset + idx + 1;
146 self.pending.push(t);
147 }
148
149 true
151 }
152
153 fn token(&self) -> &Token {
154 self.tail.token()
155 }
156
157 fn token_mut(&mut self) -> &mut Token {
158 self.tail.token_mut()
159 }
160}
161
/// Zero-size [`TokenFilter`] that expands code identifiers
/// (camelCase / snake_case / acronym runs) into extra sub-word tokens.
#[derive(Clone)]
pub struct CodeSplitFilter;
166
impl TokenFilter for CodeSplitFilter {
    type Tokenizer<T: Tokenizer> = CodeSplitFilterWrapper<T>;

    /// Wraps `tokenizer`, attaching an initially empty sub-token buffer
    /// that the wrapper reuses across token streams.
    fn transform<T: Tokenizer>(self, tokenizer: T) -> CodeSplitFilterWrapper<T> {
        CodeSplitFilterWrapper {
            inner: tokenizer,
            pending: Vec::new(),
        }
    }
}
177
/// Tokenizer produced by [`CodeSplitFilter::transform`]: wraps an inner
/// tokenizer and owns the sub-token buffer lent to each token stream.
#[derive(Clone)]
pub struct CodeSplitFilterWrapper<T> {
    /// The wrapped tokenizer.
    inner: T,
    /// Reusable sub-token buffer; cleared at the start of each stream.
    pending: Vec<Token>,
}
184
impl<T: Tokenizer> Tokenizer for CodeSplitFilterWrapper<T> {
    type TokenStream<'a> = CodeSplitTokenStream<'a, T::TokenStream<'a>>;

    /// Builds a sub-word-expanding stream over `text`.
    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
        // Reuse the buffer's allocation across calls; drop any leftovers
        // from a previous, possibly partially consumed, stream.
        self.pending.clear();
        CodeSplitTokenStream {
            tail: self.inner.token_stream(text),
            pending: &mut self.pending,
        }
    }
}
196
/// Builds the "code" [`TextAnalyzer`]: [`SimpleTokenizer`] word splitting,
/// [`CodeSplitFilter`] identifier sub-word expansion, then [`LowerCaser`].
///
/// Note: the sub-tokens emitted by [`CodeSplitFilter`] are already
/// lowercase; [`LowerCaser`] lowercases the original whole tokens.
#[must_use]
pub fn code_analyzer() -> TextAnalyzer {
    TextAnalyzer::builder(SimpleTokenizer::default())
        .filter(CodeSplitFilter)
        .filter(LowerCaser)
        .build()
}
213
/// Handles to the schema fields used by the BM25 index.
pub struct BM25Fields {
    /// Chunk (symbol) name, e.g. a function name.
    pub name: Field,
    /// Path of the source file the chunk came from.
    pub file_path: Field,
    /// Full chunk text.
    pub body: Field,
    /// Position of the chunk in the indexed slice (stored for retrieval).
    pub chunk_id: Field,
}
229
230#[must_use]
232pub fn build_schema() -> (Schema, BM25Fields) {
233 let mut builder = Schema::builder();
234
235 let code_indexing = TextFieldIndexing::default()
236 .set_tokenizer("code")
237 .set_index_option(IndexRecordOption::WithFreqsAndPositions);
238
239 let text_opts = TextOptions::default()
240 .set_indexing_options(code_indexing)
241 .set_stored();
242
243 let name = builder.add_text_field("name", text_opts.clone());
244 let file_path = builder.add_text_field("file_path", text_opts.clone());
245 let body = builder.add_text_field("body", text_opts);
246 let chunk_id = builder.add_u64_field("chunk_id", INDEXED | STORED);
247
248 let schema = builder.build();
249 (
250 schema,
251 BM25Fields {
252 name,
253 file_path,
254 body,
255 chunk_id,
256 },
257 )
258}
259
/// In-memory BM25 search index over a slice of [`CodeChunk`]s.
pub struct Bm25Index {
    /// Underlying tantivy index, created in RAM.
    index: Index,
    /// Reader with manual reload policy; searcher snapshots come from here.
    reader: IndexReader,
    /// Field handles resolved when the schema was built.
    fields: BM25Fields,
}
272
273impl Bm25Index {
274 pub fn build(chunks: &[CodeChunk]) -> crate::Result<Self> {
279 let (schema, fields) = build_schema();
280
281 let index = Index::create_in_ram(schema.clone());
282
283 index.tokenizers().register("code", code_analyzer());
285
286 let mut writer = index
287 .writer(50_000_000)
288 .map_err(|e| crate::Error::Other(e.into()))?;
289
290 for (idx, chunk) in chunks.iter().enumerate() {
291 let mut doc = TantivyDocument::default();
292 doc.add_text(fields.name, &chunk.name);
293 doc.add_text(fields.file_path, &chunk.file_path);
294 doc.add_text(fields.body, &chunk.content);
295 doc.add_u64(fields.chunk_id, idx as u64);
296 writer
297 .add_document(doc)
298 .map_err(|e| crate::Error::Other(e.into()))?;
299 }
300
301 writer.commit().map_err(|e| crate::Error::Other(e.into()))?;
302
303 let reader = index
304 .reader_builder()
305 .reload_policy(ReloadPolicy::Manual)
306 .try_into()
307 .map_err(|e| crate::Error::Other(e.into()))?;
308
309 Ok(Self {
310 index,
311 reader,
312 fields,
313 })
314 }
315
316 #[must_use]
322 pub fn search(&self, query_text: &str, top_k: usize) -> Vec<(usize, f32)> {
323 let searcher = self.reader.searcher();
324
325 let make_sub = |field: Field, boost: f32| -> Box<dyn tantivy::query::Query> {
327 let mut parser = QueryParser::for_index(&self.index, vec![field]);
328 parser.set_field_boost(field, boost);
329 let q = parser.parse_query(query_text).unwrap_or_else(|_| {
330 Box::new(tantivy::query::AllQuery)
332 });
333 Box::new(BoostQuery::new(q, boost))
334 };
335
336 let sub_queries: Vec<(Occur, Box<dyn tantivy::query::Query>)> = vec![
337 (Occur::Should, make_sub(self.fields.name, 3.0)),
338 (Occur::Should, make_sub(self.fields.file_path, 1.5)),
339 (Occur::Should, make_sub(self.fields.body, 1.0)),
340 ];
341
342 let combined = BooleanQuery::new(sub_queries);
343
344 let Ok(top_docs) = searcher.search(&combined, &TopDocs::with_limit(top_k)) else {
345 return vec![];
346 };
347
348 let mut results = Vec::with_capacity(top_docs.len());
349 for (score, doc_addr) in top_docs {
350 let Ok(doc) = searcher.doc::<TantivyDocument>(doc_addr) else {
351 continue;
352 };
353 let Some(id_val) = doc.get_first(self.fields.chunk_id) else {
354 continue;
355 };
356 let Some(id) = id_val.as_u64() else {
357 continue;
358 };
359 results.push((usize::try_from(id).unwrap_or(usize::MAX), score));
360 }
361
362 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
363 results
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CodeChunk` with fixed kind and line numbers; the enriched
    /// content mirrors the plain content.
    fn make_chunk(name: &str, file_path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.to_string(),
            name: name.to_string(),
            kind: "function_item".to_string(),
            start_line: 1,
            end_line: 10,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    // camelCase splits on lower-to-upper boundaries.
    #[test]
    fn split_camel_case() {
        let parts = split_code_identifier("parseJsonConfig");
        assert_eq!(parts, vec!["parse", "json", "config"]);
    }

    // snake_case splits on underscores.
    #[test]
    fn split_snake_case() {
        let parts = split_code_identifier("my_func_name");
        assert_eq!(parts, vec!["my", "func", "name"]);
    }

    // SCREAMING_SNAKE splits on underscores and lowercases each part.
    #[test]
    fn split_screaming_snake() {
        let parts = split_code_identifier("MAX_BATCH_SIZE");
        assert_eq!(parts, vec!["max", "batch", "size"]);
    }

    // PascalCase also splits on the internal case boundary.
    #[test]
    fn split_mixed() {
        let parts = split_code_identifier("MetalDriver");
        assert_eq!(parts, vec!["metal", "driver"]);
    }

    // A single word yields no sub-tokens (empty vec, not a one-element vec).
    #[test]
    fn no_split_single_word() {
        let parts = split_code_identifier("parser");
        assert!(parts.is_empty(), "expected empty vec, got {parts:?}");
    }

    // End-to-end: an exact identifier query ranks the matching chunk first.
    #[test]
    fn bm25_index_search() {
        let chunks = vec![
            make_chunk(
                "parseJsonConfig",
                "src/config.rs",
                "fn parseJsonConfig(data: &str) -> Config { ... }",
            ),
            make_chunk(
                "renderHtml",
                "src/render.rs",
                "fn renderHtml(template: &str) -> String { ... }",
            ),
        ];

        let index = Bm25Index::build(&chunks).expect("index build failed");
        let results = index.search("parseJsonConfig", 5);

        println!("results: {results:?}");
        assert!(!results.is_empty(), "expected at least one result");
        assert_eq!(results[0].0, 0, "chunk 0 should rank first");
    }

    // End-to-end: a camelCase sub-word ("json") matches thanks to the
    // sub-token expansion performed by the "code" analyzer.
    #[test]
    fn bm25_camel_case_subtoken_match() {
        let chunks = vec![
            make_chunk(
                "parseJsonConfig",
                "src/config.rs",
                "fn parseJsonConfig(data: &str) -> Config { ... }",
            ),
            make_chunk(
                "renderHtml",
                "src/render.rs",
                "fn renderHtml(template: &str) -> String { ... }",
            ),
        ];

        let index = Bm25Index::build(&chunks).expect("index build failed");
        let results = index.search("json", 5);

        println!("subtoken results: {results:?}");
        assert!(!results.is_empty(), "expected results for sub-token 'json'");
        assert_eq!(results[0].0, 0, "parseJsonConfig chunk should match 'json'");
    }
}