use crate::config::Config;
use crate::embedding::count_tokens;
use crate::mcp::logging::log_performance_metrics;
use crate::store::{CodeBlock, DocumentBlock, Store, TextBlock};
use anyhow::Result;
/// Generates embeddings for a batch of code blocks and stores both.
///
/// The text embedded for each block is every entry of `block.symbols`
/// followed by `block.content`, joined with `"\n"`. When a block has no
/// symbols the text is just its content.
///
/// After the store call succeeds, the elapsed time and the number of blocks
/// are reported under the `"code_blocks_batch"` label.
///
/// # Errors
///
/// Propagates any error returned by embedding generation or by
/// `Store::store_code_blocks`; in that case no metrics are logged.
pub async fn process_code_blocks_batch(
    store: &Store,
    blocks: &[CodeBlock],
    config: &Config,
) -> Result<()> {
    let timer = std::time::Instant::now();

    // One embedding input per block: symbols first, then the code itself.
    let mut contents: Vec<String> = Vec::with_capacity(blocks.len());
    for block in blocks {
        let lines: Vec<&str> = block
            .symbols
            .iter()
            .map(String::as_str)
            .chain(std::iter::once(block.content.as_str()))
            .collect();
        contents.push(lines.join("\n"));
    }

    let embeddings = crate::embedding::generate_embeddings_batch(
        contents,
        true,
        config,
        crate::embedding::types::InputType::None,
    )
    .await?;

    store.store_code_blocks(blocks, &embeddings).await?;

    log_performance_metrics(
        "code_blocks_batch",
        timer.elapsed().as_millis() as u64,
        blocks.len(),
        None,
    );
    Ok(())
}
/// Generates embeddings for a batch of text blocks and stores both.
///
/// Each block's `content` is embedded as-is, using
/// `InputType::Document`. After the store call succeeds, the elapsed time
/// and the number of blocks are reported under the `"text_blocks_batch"`
/// label.
///
/// # Errors
///
/// Propagates any error returned by embedding generation or by
/// `Store::store_text_blocks`; in that case no metrics are logged.
pub async fn process_text_blocks_batch(
    store: &Store,
    blocks: &[TextBlock],
    config: &Config,
) -> Result<()> {
    let timer = std::time::Instant::now();

    // The embedding call takes ownership of its inputs, so copy the contents out.
    let mut contents: Vec<String> = Vec::with_capacity(blocks.len());
    for block in blocks {
        contents.push(block.content.clone());
    }

    let embeddings = crate::embedding::generate_embeddings_batch(
        contents,
        false,
        config,
        crate::embedding::types::InputType::Document,
    )
    .await?;

    store.store_text_blocks(blocks, &embeddings).await?;

    log_performance_metrics(
        "text_blocks_batch",
        timer.elapsed().as_millis() as u64,
        blocks.len(),
        None,
    );
    Ok(())
}
/// Generates embeddings for a batch of document blocks and stores both.
///
/// When a block has a non-empty `context`, the embedded text is the context
/// entries joined with `"\n"`, then a blank line, then the block's
/// `content`. Otherwise the content is embedded on its own. Inputs are
/// embedded using `InputType::Document`.
///
/// After the store call succeeds, the elapsed time and the number of blocks
/// are reported under the `"document_blocks_batch"` label.
///
/// # Errors
///
/// Propagates any error returned by embedding generation or by
/// `Store::store_document_blocks`; in that case no metrics are logged.
pub async fn process_document_blocks_batch(
    store: &Store,
    blocks: &[DocumentBlock],
    config: &Config,
) -> Result<()> {
    let timer = std::time::Instant::now();

    let mut contents: Vec<String> = Vec::with_capacity(blocks.len());
    for doc in blocks {
        // No context: the content alone is the embedding input.
        if doc.context.is_empty() {
            contents.push(doc.content.clone());
            continue;
        }
        // Context lines go first, separated from the content by a blank line.
        let header = doc.context.join("\n");
        contents.push(format!("{}\n\n{}", header, doc.content));
    }

    let embeddings = crate::embedding::generate_embeddings_batch(
        contents,
        false,
        config,
        crate::embedding::types::InputType::Document,
    )
    .await?;

    store.store_document_blocks(blocks, &embeddings).await?;

    log_performance_metrics(
        "document_blocks_batch",
        timer.elapsed().as_millis() as u64,
        blocks.len(),
        None,
    );
    Ok(())
}
/// Decides whether an accumulated batch should be processed now.
///
/// Returns `false` for an empty batch. Otherwise returns `true` when either
/// the number of items has reached `config.index.embeddings_batch_size`, or
/// the sum of `count_tokens` over every item's content has reached
/// `config.index.embeddings_max_tokens_per_batch`.
///
/// `get_content` extracts the text to be counted from each item. Tokens are
/// only counted when the item-count threshold has not already been met.
pub fn should_process_batch<T>(
    batch: &[T],
    get_content: impl Fn(&T) -> &str,
    config: &Config,
) -> bool {
    let limits = &config.index;
    match batch.len() {
        0 => false,
        len if len >= limits.embeddings_batch_size => true,
        _ => {
            let mut token_total: usize = 0;
            for item in batch {
                token_total += count_tokens(get_content(item));
            }
            token_total >= limits.embeddings_max_tokens_per_batch
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the embedding text for a code block the same way
    /// `process_code_blocks_batch` does: symbols first, then the content,
    /// joined with newlines.
    fn embedding_text(block: &CodeBlock) -> String {
        block
            .symbols
            .iter()
            .cloned()
            .chain(std::iter::once(block.content.clone()))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// A block with a symbol yields text containing both the symbol and the code.
    #[test]
    fn test_code_block_content_formatting() {
        let block = CodeBlock {
            path: "src/main.rs".to_string(),
            language: "rust".to_string(),
            content: "fn main() {\n println!(\"Hello, world!\");\n}".to_string(),
            symbols: vec!["main".to_string()],
            start_line: 1,
            end_line: 3,
            hash: "test_hash".to_string(),
            distance: None,
        };

        let formatted = embedding_text(&block);

        assert!(formatted.contains("main"));
        assert!(formatted.contains("fn main()"));
        assert!(formatted.contains("Hello, world!"));
    }

    /// A block without symbols yields exactly its content, with no extra lines.
    #[test]
    fn test_code_block_without_symbols() {
        let block = CodeBlock {
            path: "src/utils.rs".to_string(),
            language: "rust".to_string(),
            content: "const VERSION: &str = \"1.0.0\";".to_string(),
            symbols: vec![],
            start_line: 1,
            end_line: 1,
            hash: "test_hash2".to_string(),
            distance: None,
        };

        let formatted = embedding_text(&block);

        assert_eq!(formatted, block.content);
        assert!(formatted.contains("const VERSION"));
    }
}