use crate::config::Config;
use crate::embedding::count_tokens;
use crate::indexer::languages;
use crate::indexer::signature_extractor::{extract_signatures, SignatureItem};
use crate::mcp::logging::log_performance_metrics;
use crate::store::{CodeBlock, DocumentBlock, Store, TextBlock};
use anyhow::Result;
use std::collections::HashMap;
use tree_sitter::Parser;
/// Renders a code block together with its surrounding file context as a single
/// newline-joined string.
///
/// The output is laid out as:
/// - `File: <path>` and `Language: <language>` header lines,
/// - an optional `File Structure:` section with one `- ...` line per entry in
///   `signatures` (omitted when `signatures` is empty),
/// - an optional `Block Symbols:` section with one `- <symbol>` line per entry
///   in `block.symbols` (omitted when there are none),
/// - a `Code:` marker followed by the block's content.
fn format_code_block_with_context(block: &CodeBlock, signatures: &[SignatureItem]) -> String {
    let mut lines = vec![
        format!("File: {}", block.path),
        format!("Language: {}", block.language),
    ];

    if !signatures.is_empty() {
        // The leading "\n" yields a blank separator line once everything is joined.
        lines.push("\nFile Structure:".to_string());
        lines.extend(signatures.iter().map(|item| match item.kind.as_str() {
            // Callables get trailing parentheses.
            "function" | "method" => format!("- {} {}()", item.kind, item.name),
            // Type-like items are listed as "<kind> <name>".
            "class" | "struct" | "interface" | "trait" | "type" | "enum" => {
                format!("- {} {}", item.kind, item.name)
            }
            // Any other kind is listed by name only.
            _ => format!("- {}", item.name),
        }));
    }

    if !block.symbols.is_empty() {
        lines.push("\nBlock Symbols:".to_string());
        lines.extend(block.symbols.iter().map(|symbol| format!("- {}", symbol)));
    }

    lines.push("\nCode:".to_string());
    lines.push(block.content.clone());

    lines.join("\n")
}
/// Parses `contents` with the tree-sitter grammar registered for `language`
/// and returns the signatures extracted from the resulting syntax tree.
///
/// This is a best-effort helper: it returns an empty vector when no language
/// implementation is registered for `language`, when the grammar cannot be
/// loaded into the parser, or when the parser produces no tree.
///
/// `_file_path` is currently unused.
fn extract_file_signatures_for_context(
    _file_path: &str,
    contents: &str,
    language: &str,
) -> Vec<SignatureItem> {
    let lang_impl = match languages::get_language(language) {
        Some(lang_impl) => lang_impl,
        None => return Vec::new(),
    };

    let mut parser = Parser::new();
    if parser.set_language(&lang_impl.get_ts_language()).is_err() {
        return Vec::new();
    }

    // `parse` may yield no tree. Treat that like every other failure in this
    // function and return no signatures, rather than re-parsing an empty
    // string and unwrapping a result that could itself be `None` (a panic).
    let tree = match parser.parse(contents, None) {
        Some(tree) => tree,
        None => return Vec::new(),
    };

    extract_signatures(tree.root_node(), contents, lang_impl.as_ref())
}
pub async fn process_code_blocks_batch(
store: &Store,
blocks: &[CodeBlock],
config: &Config,
) -> Result<()> {
let start_time = std::time::Instant::now();
let mut blocks_by_file: HashMap<String, Vec<&CodeBlock>> = HashMap::new();
for block in blocks {
blocks_by_file
.entry(block.path.clone())
.or_default()
.push(block);
}
let mut enriched_contents = Vec::new();
for (file_path, file_blocks) in blocks_by_file {
let signatures = if let Ok(contents) = std::fs::read_to_string(&file_path) {
if let Some(first_block) = file_blocks.first() {
extract_file_signatures_for_context(&file_path, &contents, &first_block.language)
} else {
Vec::new()
}
} else {
Vec::new()
};
for block in file_blocks {
let enriched_content = format_code_block_with_context(block, &signatures);
enriched_contents.push(enriched_content);
}
}
let embeddings = crate::embedding::generate_embeddings_batch(
enriched_contents,
true,
config,
crate::embedding::types::InputType::Document,
)
.await?;
store.store_code_blocks(blocks, &embeddings).await?;
let duration_ms = start_time.elapsed().as_millis() as u64;
log_performance_metrics("code_blocks_batch", duration_ms, blocks.len(), None);
Ok(())
}
/// Generates embeddings for a batch of text blocks and stores the blocks
/// together with those embeddings.
///
/// Each block's `content` is embedded as-is, in the order the blocks appear in
/// `blocks`. The elapsed time and batch size are reported through
/// `log_performance_metrics` under the name `text_blocks_batch`.
///
/// # Errors
///
/// Returns an error if embedding generation or storing the blocks fails.
pub async fn process_text_blocks_batch(
    store: &Store,
    blocks: &[TextBlock],
    config: &Config,
) -> Result<()> {
    let started = std::time::Instant::now();

    let mut texts: Vec<String> = Vec::with_capacity(blocks.len());
    for block in blocks {
        texts.push(block.content.clone());
    }

    let vectors = crate::embedding::generate_embeddings_batch(
        texts,
        false,
        config,
        crate::embedding::types::InputType::Document,
    )
    .await?;

    store.store_text_blocks(blocks, &vectors).await?;

    let elapsed_ms = started.elapsed().as_millis() as u64;
    log_performance_metrics("text_blocks_batch", elapsed_ms, blocks.len(), None);
    Ok(())
}
/// Generates embeddings for a batch of document blocks and stores the blocks
/// together with those embeddings.
///
/// When a block has a non-empty `context`, the embedded text is the context
/// entries joined by newlines, followed by a blank line and the block's
/// `content`. Otherwise the `content` is embedded on its own. Texts are built
/// in the same order as `blocks`.
///
/// # Errors
///
/// Returns an error if embedding generation or storing the blocks fails.
pub async fn process_document_blocks_batch(
    store: &Store,
    blocks: &[DocumentBlock],
    config: &Config,
) -> Result<()> {
    let started = std::time::Instant::now();

    let mut texts: Vec<String> = Vec::with_capacity(blocks.len());
    for block in blocks {
        let text = if block.context.is_empty() {
            block.content.clone()
        } else {
            let header = block.context.join("\n");
            format!("{}\n\n{}", header, block.content)
        };
        texts.push(text);
    }

    let vectors = crate::embedding::generate_embeddings_batch(
        texts,
        false,
        config,
        crate::embedding::types::InputType::Document,
    )
    .await?;

    store.store_document_blocks(blocks, &vectors).await?;

    let elapsed_ms = started.elapsed().as_millis() as u64;
    log_performance_metrics("document_blocks_batch", elapsed_ms, blocks.len(), None);
    Ok(())
}
/// Decides whether a pending batch should be processed now.
///
/// Returns `false` for an empty batch. Otherwise returns `true` when the batch
/// holds at least `config.index.embeddings_batch_size` items, or when the
/// summed token count of every item's content (as measured by `count_tokens`)
/// reaches `config.index.embeddings_max_tokens_per_batch`.
///
/// `get_content` extracts the text to be counted from each item. Tokens are
/// only counted when the item-count limit has not already been reached.
pub fn should_process_batch<T>(
    batch: &[T],
    get_content: impl Fn(&T) -> &str,
    config: &Config,
) -> bool {
    let limits = &config.index;

    match batch.len() {
        0 => false,
        len if len >= limits.embeddings_batch_size => true,
        _ => {
            let mut token_total: usize = 0;
            for item in batch {
                token_total += count_tokens(get_content(item));
            }
            token_total >= limits.embeddings_max_tokens_per_batch
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A block with signatures and symbols renders every context section;
    /// without signatures the file-structure section is omitted.
    #[test]
    fn test_format_code_block_with_context() {
        let block = CodeBlock {
            path: "src/main.rs".to_string(),
            language: "rust".to_string(),
            content: "fn main() {\n println!(\"Hello, world!\");\n}".to_string(),
            symbols: vec!["main".to_string()],
            start_line: 1,
            end_line: 3,
            hash: "test_hash".to_string(),
            distance: None,
        };

        let main_fn = SignatureItem {
            kind: "function".to_string(),
            name: "main".to_string(),
            signature: "fn main()".to_string(),
            description: None,
            start_line: 1,
            end_line: 3,
        };
        let config_struct = SignatureItem {
            kind: "struct".to_string(),
            name: "Config".to_string(),
            signature: "struct Config".to_string(),
            description: None,
            start_line: 5,
            end_line: 10,
        };
        let signatures = vec![main_fn, config_struct];

        let formatted = format_code_block_with_context(&block, &signatures);
        let expected_fragments = [
            "File: src/main.rs",
            "Language: rust",
            "File Structure:",
            "- function main()",
            "- struct Config",
            "Block Symbols:",
            "- main",
            "Code:",
            "fn main()",
            "Hello, world!",
        ];
        for fragment in expected_fragments.iter() {
            assert!(
                formatted.contains(fragment),
                "expected output to contain {:?}",
                fragment
            );
        }

        // With no signatures the header stays but the structure section goes.
        let formatted_no_sigs = format_code_block_with_context(&block, &[]);
        assert!(formatted_no_sigs.contains("File: src/main.rs"));
        assert!(!formatted_no_sigs.contains("File Structure:"));
    }

    /// A block without symbols omits the symbols section but still renders the
    /// file header and the code itself.
    #[test]
    fn test_format_code_block_without_symbols() {
        let block = CodeBlock {
            path: "src/utils.rs".to_string(),
            language: "rust".to_string(),
            content: "const VERSION: &str = \"1.0.0\";".to_string(),
            symbols: vec![],
            start_line: 1,
            end_line: 1,
            hash: "test_hash2".to_string(),
            distance: None,
        };

        let formatted = format_code_block_with_context(&block, &[]);

        assert!(!formatted.contains("Block Symbols:"));
        for fragment in ["File: src/utils.rs", "const VERSION"].iter() {
            assert!(
                formatted.contains(fragment),
                "expected output to contain {:?}",
                fragment
            );
        }
    }
}