pub mod constants;
#[cfg(test)]
mod mod_test;
pub mod provider;
pub mod types;
use anyhow::Result;
use tiktoken_rs::cl100k_base;
pub use provider::{create_embedding_provider_from_parts, EmbeddingProvider};
pub use types::*;
/// Generates a single embedding vector for `contents`.
///
/// The provider and model are combined into a `provider:model` spec, resolved
/// through [`parse_provider_model`], and used to construct the concrete
/// embedding backend.
///
/// # Errors
/// Returns any error from provider construction or the embedding request.
pub async fn generate_embeddings(contents: &str, provider: &str, model: &str) -> Result<Vec<f32>> {
    let spec = format!("{}:{}", provider, model);
    let (provider_type, model_name) = parse_provider_model(&spec);
    let embedder = create_embedding_provider_from_parts(&provider_type, &model_name).await?;
    embedder.generate_embedding(contents).await
}
/// Counts the number of tokens in `text` under OpenAI's `cl100k_base` BPE
/// encoding, including special tokens.
///
/// # Panics
/// Panics if the bundled tokenizer data cannot be loaded (a packaging bug,
/// not a runtime condition).
pub fn count_tokens(text: &str) -> usize {
    let tokenizer = cl100k_base().expect("Failed to load cl100k_base tokenizer");
    let tokens = tokenizer.encode_with_special_tokens(text);
    tokens.len()
}
/// Truncates `output` to roughly `max_tokens` tokens, appending a notice when
/// content was removed. A `max_tokens` of 0 disables truncation entirely.
///
/// The cut point is estimated at ~3 characters per token, clamped to a valid
/// UTF-8 character boundary, then pulled back to the last newline so the
/// result never ends mid-line. Returns the input unchanged when it already
/// fits within the limit.
pub fn truncate_output(output: &str, max_tokens: usize) -> String {
    // 0 means "no limit".
    if max_tokens == 0 {
        return output.to_string();
    }
    let token_count = count_tokens(output);
    if token_count <= max_tokens {
        return output.to_string();
    }
    // Rough heuristic: ~3 bytes of text per token. saturating_mul guards
    // against overflow for absurdly large limits.
    let estimated_chars = max_tokens.saturating_mul(3);
    // Clamp to a char boundary: slicing a &str at an arbitrary byte index
    // panics when it lands inside a multi-byte UTF-8 character (the original
    // code did exactly that for non-ASCII output).
    let mut cut = estimated_chars.min(output.len());
    while !output.is_char_boundary(cut) {
        cut -= 1;
    }
    let truncated = &output[..cut];
    // Prefer ending at a line break; keep the whole slice if none exists.
    let last_newline = truncated.rfind('\n').unwrap_or(truncated.len());
    let final_truncated = &truncated[..last_newline];
    format!(
        "{}\n\n[Output truncated - {} tokens estimated, max {} allowed. Use more specific queries to reduce output size]",
        final_truncated,
        token_count,
        max_tokens
    )
}
/// Splits `texts` into batches, each holding at most `max_batch_size` items
/// and — where possible — at most `max_tokens_per_batch` tokens.
///
/// A single text that alone exceeds `max_tokens_per_batch` is still placed in
/// its own batch, so no input is ever dropped. Input order is preserved.
pub fn split_texts_into_token_limited_batches(
    texts: Vec<String>,
    max_batch_size: usize,
    max_tokens_per_batch: usize,
) -> Vec<Vec<String>> {
    let mut batches: Vec<Vec<String>> = Vec::new();
    let mut pending: Vec<String> = Vec::new();
    let mut pending_tokens: usize = 0;
    for text in texts {
        let tokens = count_tokens(&text);
        // Flush the pending batch first if adding this text would exceed
        // either limit. An empty batch always accepts the next text.
        let over_count = pending.len() >= max_batch_size;
        let over_tokens = pending_tokens + tokens > max_tokens_per_batch;
        if !pending.is_empty() && (over_count || over_tokens) {
            batches.push(std::mem::take(&mut pending));
            pending_tokens = 0;
        }
        pending.push(text);
        pending_tokens += tokens;
    }
    // Don't lose the trailing partial batch.
    if !pending.is_empty() {
        batches.push(pending);
    }
    batches
}
/// Generates embeddings for many texts, issuing one provider request per
/// batch so both the per-batch item count (`batch_size`) and token budget
/// (`max_tokens_per_batch`) are respected.
///
/// The returned vectors are in the same order as the input `texts`.
///
/// # Errors
/// Returns any error from provider construction or from an individual batch
/// request; a failing batch aborts the remaining ones.
pub async fn generate_embeddings_batch(
    texts: Vec<String>,
    provider: &str,
    model: &str,
    input_type: types::InputType,
    batch_size: usize,
    max_tokens_per_batch: usize,
) -> Result<Vec<Vec<f32>>> {
    let spec = format!("{}:{}", provider, model);
    let (provider_type, model_name) = parse_provider_model(&spec);
    let embedder = create_embedding_provider_from_parts(&provider_type, &model_name).await?;
    let mut results = Vec::new();
    for chunk in split_texts_into_token_limited_batches(texts, batch_size, max_tokens_per_batch) {
        let embeddings = embedder
            .generate_embeddings_batch(chunk, input_type.clone())
            .await?;
        results.extend(embeddings);
    }
    Ok(results)
}