1use tantivy::schema::{
9 Field, INDEXED, IndexRecordOption, STORED, Schema, TextFieldIndexing, TextOptions, Value,
10};
11use tantivy::tokenizer::{
12 LowerCaser, SimpleTokenizer, TextAnalyzer, Token, TokenFilter, TokenStream, Tokenizer,
13};
14use tantivy::{
15 Index, IndexReader, ReloadPolicy, TantivyDocument,
16 collector::TopDocs,
17 query::{BooleanQuery, BoostQuery, Occur, QueryParser},
18};
19
20use crate::chunk::CodeChunk;
21
/// Splits a source-code identifier into its lowercase word parts.
///
/// The text is first cut on `_` (empty pieces are discarded, so leading,
/// trailing and repeated underscores are ignored). Each piece is then cut at
/// case boundaries:
///
/// * a lowercase letter or ASCII digit followed by an uppercase letter
///   (`parseJson` -> `parse` | `Json`, `utf8Decode` -> `utf8` | `Decode`);
/// * inside a run of uppercase letters that is followed by a lowercase
///   letter, the last uppercase letter starts the next part
///   (`HTTPServer` -> `HTTP` | `Server`).
///
/// Every part is lowercased. If the identifier yields fewer than two parts
/// (e.g. `parser`, `__init__`, or the empty string) an empty `Vec` is
/// returned, signalling "nothing to split".
#[must_use]
pub fn split_code_identifier(text: &str) -> Vec<String> {
    /// Collects a char slice into an owned, lowercased `String`.
    fn lower(chars: &[char]) -> String {
        chars.iter().collect::<String>().to_lowercase()
    }

    let mut pieces: Vec<String> = Vec::new();

    for word in text.split('_').filter(|w| !w.is_empty()) {
        let chars: Vec<char> = word.chars().collect();
        let mut begin = 0usize;

        // Index 0 can never be a boundary, so scanning starts at 1.
        for idx in 1..chars.len() {
            if idx <= begin {
                continue;
            }
            let before = chars[idx - 1];
            let here = chars[idx];

            if (before.is_lowercase() || before.is_ascii_digit()) && here.is_uppercase() {
                // camelCase boundary: the uppercase char opens a new part.
                pieces.push(lower(&chars[begin..idx]));
                begin = idx;
            } else if idx >= 2
                && before.is_uppercase()
                && here.is_lowercase()
                && chars[idx - 2].is_uppercase()
            {
                // End of an acronym run: its final capital belongs to the next word.
                pieces.push(lower(&chars[begin..idx - 1]));
                begin = idx - 1;
            }
        }

        if begin < chars.len() {
            pieces.push(lower(&chars[begin..]));
        }
    }

    if pieces.len() < 2 {
        Vec::new()
    } else {
        pieces
    }
}
107
108pub struct CodeSplitTokenStream<'a, T> {
117 tail: T,
119 pending: &'a mut Vec<Token>,
122}
123
124impl<T: TokenStream> TokenStream for CodeSplitTokenStream<'_, T> {
125 fn advance(&mut self) -> bool {
126 if let Some(tok) = self.pending.pop() {
128 *self.tail.token_mut() = tok;
129 return true;
130 }
131
132 if !self.tail.advance() {
134 return false;
135 }
136
137 let upstream = self.tail.token().clone();
138 let sub_tokens = split_code_identifier(&upstream.text);
139
140 let position_offset = upstream.position;
142 for (idx, sub) in sub_tokens.iter().enumerate().rev() {
143 let mut t = upstream.clone();
144 t.text.clone_from(sub);
145 t.position = position_offset + idx + 1;
146 self.pending.push(t);
147 }
148
149 true
151 }
152
153 fn token(&self) -> &Token {
154 self.tail.token()
155 }
156
157 fn token_mut(&mut self) -> &mut Token {
158 self.tail.token_mut()
159 }
160}
161
162#[derive(Clone)]
165pub struct CodeSplitFilter;
166
167impl TokenFilter for CodeSplitFilter {
168 type Tokenizer<T: Tokenizer> = CodeSplitFilterWrapper<T>;
169
170 fn transform<T: Tokenizer>(self, tokenizer: T) -> CodeSplitFilterWrapper<T> {
171 CodeSplitFilterWrapper {
172 inner: tokenizer,
173 pending: Vec::new(),
174 }
175 }
176}
177
178#[derive(Clone)]
180pub struct CodeSplitFilterWrapper<T> {
181 inner: T,
182 pending: Vec<Token>,
183}
184
185impl<T: Tokenizer> Tokenizer for CodeSplitFilterWrapper<T> {
186 type TokenStream<'a> = CodeSplitTokenStream<'a, T::TokenStream<'a>>;
187
188 fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
189 self.pending.clear();
190 CodeSplitTokenStream {
191 tail: self.inner.token_stream(text),
192 pending: &mut self.pending,
193 }
194 }
195}
196
197#[must_use]
207pub fn code_analyzer() -> TextAnalyzer {
208 TextAnalyzer::builder(SimpleTokenizer::default())
209 .filter(CodeSplitFilter)
210 .filter(LowerCaser)
211 .build()
212}
213
214pub struct BM25Fields {
220 pub name: Field,
222 pub file_path: Field,
224 pub body: Field,
226 pub chunk_id: Field,
228}
229
230#[must_use]
232pub fn build_schema() -> (Schema, BM25Fields) {
233 let mut builder = Schema::builder();
234
235 let code_indexing = TextFieldIndexing::default()
236 .set_tokenizer("code")
237 .set_index_option(IndexRecordOption::WithFreqsAndPositions);
238
239 let text_opts = TextOptions::default()
240 .set_indexing_options(code_indexing)
241 .set_stored();
242
243 let name = builder.add_text_field("name", text_opts.clone());
244 let file_path = builder.add_text_field("file_path", text_opts.clone());
245 let body = builder.add_text_field("body", text_opts);
246 let chunk_id = builder.add_u64_field("chunk_id", INDEXED | STORED);
247
248 let schema = builder.build();
249 (
250 schema,
251 BM25Fields {
252 name,
253 file_path,
254 body,
255 chunk_id,
256 },
257 )
258}
259
260pub struct Bm25Index {
268 index: Index,
269 reader: IndexReader,
270 fields: BM25Fields,
271}
272
273impl Bm25Index {
274 pub fn build(chunks: &[CodeChunk]) -> crate::Result<Self> {
279 let (schema, fields) = build_schema();
280
281 let index = Index::create_in_ram(schema.clone());
282
283 index.tokenizers().register("code", code_analyzer());
285
286 let mut writer = index
287 .writer(50_000_000)
288 .map_err(|e| crate::Error::Other(e.into()))?;
289
290 for (idx, chunk) in chunks.iter().enumerate() {
291 let mut doc = TantivyDocument::default();
292 doc.add_text(fields.name, &chunk.name);
293 doc.add_text(fields.file_path, &chunk.file_path);
294 doc.add_text(fields.body, &chunk.content);
295 doc.add_u64(fields.chunk_id, idx as u64);
296 writer
297 .add_document(doc)
298 .map_err(|e| crate::Error::Other(e.into()))?;
299 }
300
301 writer.commit().map_err(|e| crate::Error::Other(e.into()))?;
302
303 let reader = index
304 .reader_builder()
305 .reload_policy(ReloadPolicy::Manual)
306 .try_into()
307 .map_err(|e| crate::Error::Other(e.into()))?;
308
309 Ok(Self {
310 index,
311 reader,
312 fields,
313 })
314 }
315
316 #[must_use]
322 pub fn search(&self, query_text: &str, top_k: usize) -> Vec<(usize, f32)> {
323 let searcher = self.reader.searcher();
324
325 let doc_count = usize::try_from(searcher.num_docs()).unwrap_or(usize::MAX);
330 let effective_limit = top_k.min(doc_count).max(1);
331
332 let make_sub = |field: Field, boost: f32| -> Box<dyn tantivy::query::Query> {
334 let mut parser = QueryParser::for_index(&self.index, vec![field]);
335 parser.set_field_boost(field, boost);
336 let q = parser.parse_query(query_text).unwrap_or_else(|_| {
337 Box::new(tantivy::query::AllQuery)
339 });
340 Box::new(BoostQuery::new(q, boost))
341 };
342
343 let sub_queries: Vec<(Occur, Box<dyn tantivy::query::Query>)> = vec![
344 (Occur::Should, make_sub(self.fields.name, 3.0)),
345 (Occur::Should, make_sub(self.fields.file_path, 1.5)),
346 (Occur::Should, make_sub(self.fields.body, 1.0)),
347 ];
348
349 let combined = BooleanQuery::new(sub_queries);
350
351 let Ok(top_docs) = searcher.search(
352 &combined,
353 &TopDocs::with_limit(effective_limit).order_by_score(),
354 ) else {
355 return vec![];
356 };
357
358 let mut results = Vec::with_capacity(top_docs.len());
359 for (score, doc_addr) in top_docs {
360 let Ok(doc) = searcher.doc::<TantivyDocument>(doc_addr) else {
361 continue;
362 };
363 let Some(id_val) = doc.get_first(self.fields.chunk_id) else {
364 continue;
365 };
366 let Some(id) = id_val.as_u64() else {
367 continue;
368 };
369 results.push((usize::try_from(id).unwrap_or(usize::MAX), score));
370 }
371
372 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
373 results
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `CodeChunk` with the given name, path and content.
    fn make_chunk(name: &str, file_path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.to_string(),
            name: name.to_string(),
            kind: "function_item".to_string(),
            start_line: 1,
            end_line: 10,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// Two-chunk fixture shared by the index tests: chunk 0 is
    /// `parseJsonConfig`, chunk 1 is `renderHtml`.
    fn sample_chunks() -> Vec<CodeChunk> {
        vec![
            make_chunk(
                "parseJsonConfig",
                "src/config.rs",
                "fn parseJsonConfig(data: &str) -> Config { ... }",
            ),
            make_chunk(
                "renderHtml",
                "src/render.rs",
                "fn renderHtml(template: &str) -> String { ... }",
            ),
        ]
    }

    /// Returns the token texts `code_analyzer` produces for `text`, in order.
    fn analyze(text: &str) -> Vec<String> {
        let mut analyzer = code_analyzer();
        let mut stream = analyzer.token_stream(text);
        let mut texts = Vec::new();
        while stream.advance() {
            texts.push(stream.token().text.clone());
        }
        texts
    }

    #[test]
    fn split_camel_case() {
        let parts = split_code_identifier("parseJsonConfig");
        assert_eq!(parts, vec!["parse", "json", "config"]);
    }

    #[test]
    fn split_snake_case() {
        let parts = split_code_identifier("my_func_name");
        assert_eq!(parts, vec!["my", "func", "name"]);
    }

    #[test]
    fn split_screaming_snake() {
        let parts = split_code_identifier("MAX_BATCH_SIZE");
        assert_eq!(parts, vec!["max", "batch", "size"]);
    }

    #[test]
    fn split_mixed() {
        let parts = split_code_identifier("MetalDriver");
        assert_eq!(parts, vec!["metal", "driver"]);
    }

    #[test]
    fn split_acronym_run() {
        assert_eq!(split_code_identifier("HTTPServer"), vec!["http", "server"]);
        assert_eq!(
            split_code_identifier("getHTTPResponse"),
            vec!["get", "http", "response"]
        );
    }

    #[test]
    fn split_digit_boundary() {
        assert_eq!(split_code_identifier("utf8Decode"), vec!["utf8", "decode"]);
    }

    #[test]
    fn no_split_single_word() {
        let parts = split_code_identifier("parser");
        assert!(parts.is_empty(), "expected empty vec, got {parts:?}");
    }

    #[test]
    fn no_split_degenerate_input() {
        assert!(split_code_identifier("").is_empty());
        assert!(split_code_identifier("__init__").is_empty());
    }

    #[test]
    fn analyzer_emits_whole_token_then_parts() {
        assert_eq!(
            analyze("parseJsonConfig data"),
            vec!["parsejsonconfig", "parse", "json", "config", "data"]
        );
    }

    #[test]
    fn bm25_index_search() {
        let index = Bm25Index::build(&sample_chunks()).expect("index build failed");
        let results = index.search("parseJsonConfig", 5);

        assert!(
            !results.is_empty(),
            "expected at least one result, got {results:?}"
        );
        assert_eq!(results[0].0, 0, "chunk 0 should rank first: {results:?}");
    }

    #[test]
    fn bm25_camel_case_subtoken_match() {
        let index = Bm25Index::build(&sample_chunks()).expect("index build failed");
        // "json" only occurs inside the identifier `parseJsonConfig`.
        let results = index.search("json", 5);

        assert!(
            !results.is_empty(),
            "expected results for sub-token 'json', got {results:?}"
        );
        assert_eq!(
            results[0].0, 0,
            "parseJsonConfig chunk should match 'json': {results:?}"
        );
    }
}