use liter_llm::{EmbeddingInput, EmbeddingRequest, LlmClient};
use crate::core::config::LlmConfig;
/// Embeds `texts` through the LLM provider described by `config`.
///
/// The returned vectors are ordered by the `index` field of each response
/// item, and the response is verified to contain exactly one item per input.
/// The second tuple element is whatever
/// `super::usage::extract_usage_from_embedding` yields for the response.
///
/// When `normalize` is `true`, every vector is passed through `normalize_l2`
/// before being returned.
///
/// An empty `texts` slice short-circuits to `(Vec::new(), None)` without
/// creating a client or issuing a request.
///
/// # Errors
///
/// Returns an error if the client cannot be created, if the embedding request
/// fails, or if the number of embeddings in the response differs from the
/// number of input texts.
pub async fn embed_via_llm<T: AsRef<str>>(
    texts: &[T],
    config: &LlmConfig,
    normalize: bool,
) -> crate::Result<(Vec<Vec<f32>>, Option<crate::types::LlmUsage>)> {
    if texts.is_empty() {
        return Ok((Vec::new(), None));
    }

    let client = super::client::create_client(config)?;

    let input_strings: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
    // A lone text is sent as a single-string input, anything else as a batch.
    let input = if input_strings.len() == 1 {
        EmbeddingInput::Single(input_strings.into_iter().next().expect("checked non-empty"))
    } else {
        EmbeddingInput::Multiple(input_strings)
    };

    let request = EmbeddingRequest {
        model: config.model.clone(),
        input,
        encoding_format: None,
        dimensions: None,
        user: None,
    };

    let response = client.embed(request).await.map_err(|e| {
        crate::KreuzbergError::embedding(format!("LLM embedding request failed (model={}): {e}", config.model))
    })?;

    // Usage is read from the whole response before `data` is moved out of it.
    let usage = super::usage::extract_usage_from_embedding(&response, "embeddings");

    let mut data = response.data;

    // Callers pair the output positionally with `texts`; a response with a
    // different item count would silently misalign every vector after the gap.
    if data.len() != texts.len() {
        return Err(crate::KreuzbergError::embedding(format!(
            "LLM embedding response returned {} embeddings for {} inputs (model={})",
            data.len(),
            texts.len(),
            config.model
        )));
    }

    // Order by the reported index rather than trusting arrival order.
    data.sort_by_key(|obj| obj.index);

    let mut embeddings: Vec<Vec<f32>> = data
        .into_iter()
        .map(|obj| obj.embedding.into_iter().map(|v| v as f32).collect())
        .collect();

    if normalize {
        for embedding in &mut embeddings {
            normalize_l2(embedding);
        }
    }

    Ok((embeddings, usage))
}
/// Rescales `embedding` in place so that its Euclidean (L2) norm is 1.
///
/// The slice is left untouched when its norm is not strictly greater than
/// `f32::EPSILON`. That covers empty and all-zero slices, and also a NaN norm.
fn normalize_l2(embedding: &mut [f32]) {
    let norm = embedding
        .iter()
        .map(|component| component * component)
        .sum::<f32>()
        .sqrt();

    // Guard clause: nothing to scale for (near-)zero or NaN norms.
    if norm.is_nan() || norm <= f32::EPSILON {
        return;
    }

    // Multiply by the reciprocal once instead of dividing every component.
    let scale = norm.recip();
    for component in embedding.iter_mut() {
        *component *= scale;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A vector that already has unit length must come back unchanged.
    #[test]
    fn test_normalize_l2_unit_vector() {
        let mut unit = [1.0f32, 0.0, 0.0];
        normalize_l2(&mut unit);

        let first_error = (unit[0] - 1.0).abs();
        let second_error = unit[1].abs();
        assert!(first_error < f32::EPSILON);
        assert!(second_error < f32::EPSILON);
    }

    // After normalization the Euclidean norm of a 3-4-5 vector must be ~1.
    #[test]
    fn test_normalize_l2_arbitrary_vector() {
        let mut sample = [3.0f32, 4.0];
        normalize_l2(&mut sample);

        let norm = sample.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    // The zero vector has no direction, so it must be left as all zeros.
    #[test]
    fn test_normalize_l2_zero_vector() {
        let mut zeros = [0.0f32; 3];
        normalize_l2(&mut zeros);

        for component in zeros {
            assert!(component == 0.0);
        }
    }
}