pub mod provider;
#[cfg(test)]
mod tests;
pub mod types;
use crate::config::Config;
use anyhow::Result;
use tiktoken_rs::cl100k_base;
pub use provider::{create_embedding_provider_from_parts, EmbeddingProvider};
pub use types::*;
/// Generate an embedding vector for a single piece of content.
///
/// Selects the configured code or text model based on `is_code`, resolves
/// the provider implementation for that model, and delegates the actual
/// embedding call to it.
///
/// # Errors
/// Returns an error if the provider cannot be constructed or the embedding
/// request fails.
pub async fn generate_embeddings(
    contents: &str,
    is_code: bool,
    config: &Config,
) -> Result<Vec<f32>> {
    // Pick the model spec for this content type, then resolve its provider.
    let spec = if is_code {
        &config.embedding.code_model
    } else {
        &config.embedding.text_model
    };
    let (provider_name, model_name) = parse_provider_model(spec);
    create_embedding_provider_from_parts(&provider_name, &model_name)
        .await?
        .generate_embedding(contents)
        .await
}
pub fn count_tokens(text: &str) -> usize {
let bpe = cl100k_base().expect("Failed to load cl100k_base tokenizer");
bpe.encode_with_special_tokens(text).len()
}
/// Truncate `output` so its token count stays near `max_tokens`, appending a
/// notice describing the truncation.
///
/// A `max_tokens` of 0 means "no limit" and returns the input unchanged, as
/// does any input already within the limit. Otherwise the text is cut at a
/// byte budget estimated as ~3 bytes per token, then trimmed back to the last
/// complete line so the result does not end mid-line.
pub fn truncate_output(output: &str, max_tokens: usize) -> String {
    // A limit of 0 disables truncation entirely.
    if max_tokens == 0 {
        return output.to_string();
    }
    let token_count = count_tokens(output);
    if token_count <= max_tokens {
        return output.to_string();
    }
    // Heuristic byte budget: assume roughly 3 bytes per token.
    let estimated_chars = max_tokens * 3;
    let truncated = if output.len() > estimated_chars {
        // Walk the cut point back to a UTF-8 char boundary; slicing at a raw
        // byte offset would panic inside a multi-byte character (the previous
        // implementation had exactly that bug).
        let mut end = estimated_chars;
        while end > 0 && !output.is_char_boundary(end) {
            end -= 1;
        }
        &output[..end]
    } else {
        output
    };
    // Prefer ending on a complete line; if there is no newline at all, keep
    // the whole truncated slice.
    let last_newline = truncated.rfind('\n').unwrap_or(truncated.len());
    let final_truncated = &truncated[..last_newline];
    format!(
        "{}\n\n[Output truncated - {} tokens estimated, max {} allowed. Use more specific queries to reduce output size]",
        final_truncated,
        token_count,
        max_tokens
    )
}
/// Greedily pack `texts` into batches, starting a new batch whenever adding
/// the next text would exceed either the element cap (`max_batch_size`) or
/// the token cap (`max_tokens_per_batch`).
///
/// Input order is preserved across batches. A single text larger than the
/// token cap still gets its own batch — texts are never split or dropped.
pub fn split_texts_into_token_limited_batches(
    texts: Vec<String>,
    max_batch_size: usize,
    max_tokens_per_batch: usize,
) -> Vec<Vec<String>> {
    let mut batches: Vec<Vec<String>> = Vec::new();
    let mut batch: Vec<String> = Vec::new();
    let mut batch_tokens = 0usize;
    for text in texts {
        let tokens = count_tokens(&text);
        let over_count = batch.len() >= max_batch_size;
        let over_tokens = batch_tokens + tokens > max_tokens_per_batch;
        // Flush the current batch before this text would push it over budget.
        if !batch.is_empty() && (over_count || over_tokens) {
            batches.push(std::mem::take(&mut batch));
            batch_tokens = 0;
        }
        batch.push(text);
        batch_tokens += tokens;
    }
    if !batch.is_empty() {
        batches.push(batch);
    }
    batches
}
/// Generate embeddings for many texts, batching requests to respect the
/// configured batch-size and token limits.
///
/// Batches are sent sequentially and the resulting vectors are returned in
/// the same order as the input `texts`.
///
/// # Errors
/// Returns an error if the provider cannot be constructed or any batch
/// request fails (later batches are not attempted).
pub async fn generate_embeddings_batch(
    texts: Vec<String>,
    is_code: bool,
    config: &Config,
    input_type: types::InputType,
) -> Result<Vec<Vec<f32>>> {
    // Resolve the model spec and its provider for this content type.
    let spec = if is_code {
        &config.embedding.code_model
    } else {
        &config.embedding.text_model
    };
    let (provider_name, model_name) = parse_provider_model(spec);
    let provider_impl = create_embedding_provider_from_parts(&provider_name, &model_name).await?;
    // Split the input so each request honors both the count and token caps.
    let batches = split_texts_into_token_limited_batches(
        texts,
        config.index.embeddings_batch_size,
        config.index.embeddings_max_tokens_per_batch,
    );
    let mut all_embeddings = Vec::new();
    for batch in batches {
        let embeddings = provider_impl
            .generate_embeddings_batch(batch, input_type.clone())
            .await?;
        all_embeddings.extend(embeddings);
    }
    Ok(all_embeddings)
}
/// SHA-256 of `contents` followed by `file_path`, as lowercase hex.
///
/// Including the path makes the digest unique per file even when two files
/// share identical contents.
pub fn calculate_unique_content_hash(contents: &str, file_path: &str) -> String {
    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    for part in [contents, file_path] {
        hasher.update(part.as_bytes());
    }
    format!("{:x}", hasher.finalize())
}
/// SHA-256 over `contents`, `file_path`, and the line span, as lowercase hex.
///
/// Folding in `start_line`/`end_line` (as decimal strings) lets different
/// chunks of the same file hash to different values.
pub fn calculate_content_hash_with_lines(
    contents: &str,
    file_path: &str,
    start_line: usize,
    end_line: usize,
) -> String {
    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    hasher.update(contents);
    hasher.update(file_path);
    hasher.update(start_line.to_string());
    hasher.update(end_line.to_string());
    format!("{:x}", hasher.finalize())
}
/// SHA-256 of the raw bytes of `contents`, rendered as lowercase hex.
pub fn calculate_content_hash(contents: &str) -> String {
    use sha2::{Digest, Sha256};
    format!("{:x}", Sha256::digest(contents.as_bytes()))
}
/// Embeddings produced for a search query, one optional slot per domain.
///
/// Which fields are populated depends on the mode passed to
/// `generate_search_embeddings`: "code" fills only `code_embeddings`,
/// "docs"/"text" fills only `text_embeddings`, and "all" fills both.
#[derive(Debug, Clone)]
pub struct SearchModeEmbeddings {
    // Query embedding from the configured code model, if that domain was requested.
    pub code_embeddings: Option<Vec<f32>>,
    // Query embedding from the configured text model, if that domain was requested.
    pub text_embeddings: Option<Vec<f32>>,
}
/// Produce the query embeddings required for the given search `mode`.
///
/// Accepted modes: "code", "docs"/"text", and "all". For "all", if the code
/// and text domains are configured with the same model, the query is embedded
/// once and the vector shared between both slots to avoid a duplicate
/// provider call.
///
/// # Errors
/// Returns an error for an unrecognized mode, or if any underlying embedding
/// call fails.
pub async fn generate_search_embeddings(
    query: &str,
    mode: &str,
    config: &Config,
) -> Result<SearchModeEmbeddings> {
    match mode {
        "code" => Ok(SearchModeEmbeddings {
            code_embeddings: Some(generate_embeddings(query, true, config).await?),
            text_embeddings: None,
        }),
        "docs" | "text" => Ok(SearchModeEmbeddings {
            code_embeddings: None,
            text_embeddings: Some(generate_embeddings(query, false, config).await?),
        }),
        "all" => {
            // Same model for both domains: embed once and share the vector.
            if config.embedding.code_model == config.embedding.text_model {
                let shared = generate_embeddings(query, true, config).await?;
                Ok(SearchModeEmbeddings {
                    code_embeddings: Some(shared.clone()),
                    text_embeddings: Some(shared),
                })
            } else {
                Ok(SearchModeEmbeddings {
                    code_embeddings: Some(generate_embeddings(query, true, config).await?),
                    text_embeddings: Some(generate_embeddings(query, false, config).await?),
                })
            }
        }
        other => Err(anyhow::anyhow!(
            "Invalid search mode '{}'. Use 'all', 'code', 'docs', or 'text'.",
            other
        )),
    }
}