use crate::model::Model;
use crate::{context::Context, error::MullamaError, sys, token::TokenId};
use std::sync::Arc;
/// A flat buffer of embedding vectors that all share one dimension.
///
/// Vectors are stored back to back in `data`; vector `i` occupies
/// `data[i * dimension..(i + 1) * dimension]`. Both fields are public, so the
/// accessors below never assume that `data.len()` is a multiple of `dimension`.
#[derive(Debug, Clone)]
pub struct Embeddings {
    /// Concatenated embedding values.
    pub data: Vec<f32>,
    /// Number of values per embedding vector.
    pub dimension: usize,
}

impl Embeddings {
    /// Wraps an already-flattened buffer; no validation is performed.
    pub fn new(data: Vec<f32>, dimension: usize) -> Self {
        Self { data, dimension }
    }

    /// Returns the `index`-th complete vector.
    ///
    /// Returns `None` when `dimension` is zero, when the index arithmetic would
    /// overflow, or when the requested vector is not fully contained in `data`
    /// (including a trailing partial vector). This keeps `get` consistent with
    /// [`Embeddings::len`]: `get(i)` is `Some` exactly for `i < len()`.
    pub fn get(&self, index: usize) -> Option<&[f32]> {
        if self.dimension == 0 {
            return None;
        }
        let start = index.checked_mul(self.dimension)?;
        let end = start.checked_add(self.dimension)?;
        // `slice::get` with a range is `None` when `end` is past the buffer, so a
        // trailing partial vector can no longer trigger an out-of-bounds panic.
        self.data.get(start..end)
    }

    /// Number of complete vectors stored (0 when `dimension` is zero).
    pub fn len(&self) -> usize {
        if self.dimension == 0 {
            0
        } else {
            self.data.len() / self.dimension
        }
    }

    /// Returns true when the underlying buffer holds no values at all.
    ///
    /// Note that this inspects `data` only; a non-empty buffer with a zero
    /// `dimension` reports `len() == 0` but is not "empty" by this definition.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Borrows the whole flat buffer.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    /// Copies every complete vector into its own `Vec`.
    pub fn to_vecs(&self) -> Vec<Vec<f32>> {
        (0..self.len())
            .filter_map(|i| self.get(i).map(<[f32]>::to_vec))
            .collect()
    }
}
/// Selects how `EmbeddingGenerator` turns a decoded token sequence into one vector.
///
/// `EmbeddingGenerator::new` also maps each variant to a llama pooling type for the
/// context it creates.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PoolingStrategy {
/// Read the per-token embedding at the last token position
/// (context created with `LLAMA_POOLING_TYPE_LAST`).
Last,
/// Average the per-token embeddings (context created with `LLAMA_POOLING_TYPE_MEAN`).
Mean,
/// Read the per-token embedding at position 0
/// (context created with `LLAMA_POOLING_TYPE_CLS`).
First,
/// Element-wise maximum over the per-token embeddings
/// (context created with `LLAMA_POOLING_TYPE_UNSPECIFIED`).
Max,
/// Read the sequence embedding for sequence 0 as produced by the context
/// (context created with `LLAMA_POOLING_TYPE_UNSPECIFIED`).
Native,
}
/// Settings that control how `EmbeddingGenerator` produces vectors.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
/// Strategy used to reduce a token sequence to a single vector.
pub pooling: PoolingStrategy,
/// When true, `EmbeddingGenerator::embed_tokens` passes the pooled vector through
/// `EmbeddingUtil::normalize` before returning it.
pub normalize: bool,
/// Batch size setting.
/// NOTE(review): nothing in this file reads this field (`embed_batch` embeds texts
/// one at a time) — confirm whether another module consumes it.
pub batch_size: usize,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
pooling: PoolingStrategy::Native,
normalize: true,
batch_size: 32,
}
}
}
/// Produces embedding vectors for text or token sequences using a dedicated context.
pub struct EmbeddingGenerator {
/// Shared model; used for tokenization and to query the embedding width.
model: Arc<Model>,
/// Context created in `new` with embeddings enabled; decoded into on every request.
context: Context,
/// Pooling and normalization settings applied by `embed_tokens`.
config: EmbeddingConfig,
}
impl EmbeddingGenerator {
/// Builds a generator backed by a fresh context created for `model`.
///
/// The context is created with embeddings enabled and a llama pooling type derived
/// from `config.pooling`; every other context parameter keeps its default value.
///
/// # Errors
/// Propagates any error returned by `Context::new`.
pub fn new(model: Arc<Model>, config: EmbeddingConfig) -> Result<Self, MullamaError> {
    use crate::context::ContextParams;

    // `Max` has no dedicated llama pooling type here, so it shares
    // `UNSPECIFIED` with `Native`.
    let pooling_type = match config.pooling {
        PoolingStrategy::Native | PoolingStrategy::Max => {
            sys::llama_pooling_type::LLAMA_POOLING_TYPE_UNSPECIFIED
        }
        PoolingStrategy::First => sys::llama_pooling_type::LLAMA_POOLING_TYPE_CLS,
        PoolingStrategy::Mean => sys::llama_pooling_type::LLAMA_POOLING_TYPE_MEAN,
        PoolingStrategy::Last => sys::llama_pooling_type::LLAMA_POOLING_TYPE_LAST,
    };

    let context = Context::new(
        Arc::clone(&model),
        ContextParams {
            embeddings: true,
            pooling_type,
            ..Default::default()
        },
    )?;

    Ok(Self {
        model,
        context,
        config,
    })
}
/// Tokenizes `text` with the generator's model and embeds the resulting tokens.
///
/// # Errors
/// Propagates tokenization errors and any error from `embed_tokens`.
pub fn embed_text(&mut self, text: &str) -> Result<Vec<f32>, MullamaError> {
// NOTE(review): the two boolean flags are forwarded to `Model::tokenize`; presumably
// "add BOS" = true and "parse special tokens" = false — confirm against its signature.
let tokens = self.model.tokenize(text, true, false)?;
self.embed_tokens(&tokens)
}
/// Decodes `tokens` in a cleared context and returns one pooled embedding vector.
///
/// The KV cache is cleared before decoding, so each call is independent of earlier
/// ones. The pooled vector is selected according to `config.pooling` and, when
/// `config.normalize` is set, passed through `EmbeddingUtil::normalize`.
///
/// NOTE(review): `Last`, `First` and `Mean` read per-token embeddings even though
/// `new` configured a pooled context for them — confirm the backend still exposes
/// per-token rows in that mode.
///
/// # Errors
/// * `MullamaError::InvalidInput` when `tokens` is empty.
/// * Any error from `Context::decode` or from the pooling helpers.
/// * `MullamaError::EmbeddingError` when the requested embedding is unavailable.
pub fn embed_tokens(&mut self, tokens: &[TokenId]) -> Result<Vec<f32>, MullamaError> {
    if tokens.is_empty() {
        let reason = "Cannot embed empty token sequence".to_string();
        return Err(MullamaError::InvalidInput(reason));
    }

    self.context.kv_cache_clear();
    self.context.decode(tokens)?;

    let dim = self.model.n_embd() as usize;
    let count = tokens.len();
    let unavailable = |what: &str| MullamaError::EmbeddingError(what.to_string());

    let pooled = match self.config.pooling {
        PoolingStrategy::Mean => self.pool_mean(count, dim)?,
        PoolingStrategy::Max => self.pool_max(count, dim)?,
        PoolingStrategy::Native => match self.context.get_embeddings_seq(0) {
            Some(seq) => seq.to_vec(),
            None => return Err(unavailable("Failed to get sequence embeddings")),
        },
        PoolingStrategy::First => match self.context.get_embeddings_ith(0) {
            Some(first) => first.to_vec(),
            None => return Err(unavailable("Failed to get first token embedding")),
        },
        PoolingStrategy::Last => {
            // `tokens` is non-empty here, so the subtraction cannot underflow.
            let last_idx = (count - 1) as i32;
            match self.context.get_embeddings_ith(last_idx) {
                Some(last) => last.to_vec(),
                None => return Err(unavailable("Failed to get last token embedding")),
            }
        }
    };

    if !self.config.normalize {
        return Ok(pooled);
    }
    Ok(EmbeddingUtil::normalize(&pooled))
}
/// Embeds every text in order, one at a time, via `embed_text`.
///
/// Processing stops at the first failure and that error is returned; no partial
/// results are produced. `config.batch_size` is not consulted here.
pub fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>, MullamaError> {
    texts.iter().map(|text| self.embed_text(text)).collect()
}
fn pool_mean(&self, n_tokens: usize, n_embd: usize) -> Result<Vec<f32>, MullamaError> {
let mut mean = vec![0.0f32; n_embd];
for i in 0..n_tokens {
if let Some(emb) = self.context.get_embeddings_ith(i as i32) {
for (j, &val) in emb.iter().enumerate() {
if j < n_embd {
mean[j] += val;
}
}
}
}
let n = n_tokens as f32;
for val in &mut mean {
*val /= n;
}
Ok(mean)
}
fn pool_max(&self, n_tokens: usize, n_embd: usize) -> Result<Vec<f32>, MullamaError> {
let mut max = vec![f32::NEG_INFINITY; n_embd];
for i in 0..n_tokens {
if let Some(emb) = self.context.get_embeddings_ith(i as i32) {
for (j, &val) in emb.iter().enumerate() {
if j < n_embd && val > max[j] {
max[j] = val;
}
}
}
}
Ok(max)
}
/// Returns the model's embedding width (`Model::n_embd`) as a `usize`.
pub fn embedding_dim(&self) -> usize {
self.model.n_embd() as usize
}
/// Borrows the shared model handle this generator was created with.
pub fn model(&self) -> &Arc<Model> {
&self.model
}
}
/// Namespace type for embedding helpers (similarity, normalization, averaging);
/// all of its functions are associated functions that take no `self`.
pub struct EmbeddingUtil;
impl EmbeddingUtil {
/// Decodes `tokens` in `context` and collects the resulting embeddings.
///
/// Embeddings are switched on for the context and its KV cache is cleared before
/// decoding. If the context exposes a sequence embedding for sequence 0, that single
/// vector is returned; otherwise the per-token embeddings are concatenated in token
/// order.
///
/// # Errors
/// * `MullamaError::InvalidInput` when `tokens` is empty.
/// * Any error from `Context::decode`.
/// * `MullamaError::EmbeddingError` when a per-token embedding is missing.
pub fn generate_embeddings(
    context: &mut Context,
    tokens: &[TokenId],
) -> Result<Embeddings, MullamaError> {
    if tokens.is_empty() {
        let reason = "Cannot embed empty token sequence".to_string();
        return Err(MullamaError::InvalidInput(reason));
    }

    context.set_embeddings(true);
    context.kv_cache_clear();
    context.decode(tokens)?;

    let width = context.model().n_embd() as usize;

    // A sequence-level embedding takes precedence over per-token rows.
    if let Some(pooled) = context.get_embeddings_seq(0) {
        return Ok(Embeddings::new(pooled.to_vec(), width));
    }

    let mut flat = Vec::with_capacity(tokens.len() * width);
    for position in 0..tokens.len() {
        match context.get_embeddings_ith(position as i32) {
            Some(row) => flat.extend_from_slice(row),
            None => {
                return Err(MullamaError::EmbeddingError(format!(
                    "Failed to get embedding for token at index {}",
                    position
                )));
            }
        }
    }
    Ok(Embeddings::new(flat, width))
}
/// Cosine of the angle between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let (norm_a, norm_b) = (norm(a), norm(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Returns `f32::MAX` when the slices differ in length.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
/// Dot product of `a` and `b`; `0.0` when the slices differ in length.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let products = a.iter().zip(b).map(|(x, y)| x * y);
        products.sum()
    } else {
        0.0
    }
}
/// Returns a copy of `embedding` scaled to unit L2 length.
///
/// A zero-magnitude input is returned unchanged (as a copy).
pub fn normalize(embedding: &[f32]) -> Vec<f32> {
    let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm == 0.0 {
        return embedding.to_vec();
    }

    let mut unit = Vec::with_capacity(embedding.len());
    for &component in embedding {
        unit.push(component / norm);
    }
    unit
}
/// Scales `embedding` to unit L2 length in place.
///
/// The slice is left untouched unless its magnitude is strictly positive.
pub fn normalize_inplace(embedding: &mut [f32]) {
    let squared: f32 = embedding.iter().map(|v| v * v).sum();
    let norm = squared.sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|v| *v /= norm);
    }
}
/// Finds the candidate with the highest cosine similarity to `query`.
///
/// Returns `(index, similarity)` of the first best match, or `None` when
/// `embeddings` is empty. Ties keep the earliest index because only a strictly
/// greater score replaces the current best.
pub fn find_most_similar(query: &[f32], embeddings: &[Vec<f32>]) -> Option<(usize, f32)> {
    if embeddings.is_empty() {
        return None;
    }

    let best = embeddings.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |best, (idx, candidate)| {
            let score = Self::cosine_similarity(query, candidate);
            if score > best.1 {
                (idx, score)
            } else {
                best
            }
        },
    );
    Some(best)
}
/// Ranks all candidates by cosine similarity to `query` and returns the best `k`.
///
/// The result holds `(index, similarity)` pairs in descending similarity order and
/// contains at most `k` entries. Scores that cannot be ordered are treated as equal,
/// and the sort is stable, so such entries keep their input order.
pub fn find_top_k(query: &[f32], embeddings: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut ranked = Vec::with_capacity(embeddings.len());
    for (idx, candidate) in embeddings.iter().enumerate() {
        ranked.push((idx, Self::cosine_similarity(query, candidate)));
    }

    // Compare right-to-left to obtain descending order.
    ranked.sort_by(|lhs, rhs| match rhs.1.partial_cmp(&lhs.1) {
        Some(order) => order,
        None => std::cmp::Ordering::Equal,
    });
    ranked.truncate(k);
    ranked
}
/// Component-wise arithmetic mean of `embeddings`.
///
/// Returns `None` when the input is empty or when the vectors do not all share the
/// length of the first one.
pub fn average(embeddings: &[Vec<f32>]) -> Option<Vec<f32>> {
    let dim = embeddings.first()?.len();
    if embeddings.iter().any(|emb| emb.len() != dim) {
        return None;
    }

    let mut total = vec![0.0f32; dim];
    for emb in embeddings {
        for (slot, val) in total.iter_mut().zip(emb) {
            *slot += *val;
        }
    }

    let count = embeddings.len() as f32;
    Some(total.into_iter().map(|sum| sum / count).collect())
}
/// Component-wise weighted mean of `embeddings`, with one weight per vector.
///
/// Returns `None` when the input is empty, when the number of weights differs from
/// the number of vectors, when the weights sum to zero, or when the vectors do not
/// all share the length of the first one.
pub fn weighted_average(embeddings: &[Vec<f32>], weights: &[f32]) -> Option<Vec<f32>> {
    let dim = embeddings.first()?.len();
    if embeddings.len() != weights.len() {
        return None;
    }

    let total_weight: f32 = weights.iter().sum();
    if total_weight == 0.0 {
        return None;
    }
    if embeddings.iter().any(|emb| emb.len() != dim) {
        return None;
    }

    let mut acc = vec![0.0f32; dim];
    for (emb, &weight) in embeddings.iter().zip(weights) {
        for (slot, &val) in acc.iter_mut().zip(emb) {
            *slot += val * weight;
        }
    }
    Some(acc.into_iter().map(|sum| sum / total_weight).collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by the floating-point comparisons below.
    const EPS: f32 = 0.001;

    /// Returns true when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Row access and length on a 2x3 buffer.
    #[test]
    fn test_embeddings_struct() {
        let store = Embeddings::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3);
        assert_eq!(store.len(), 2);
        assert!(!store.is_empty());
        assert_eq!(store.get(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(store.get(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(store.get(2), None);
    }

    /// Identical, orthogonal and opposite unit vectors.
    #[test]
    fn test_cosine_similarity() {
        let base = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        let orthogonal = vec![0.0, 1.0, 0.0];
        let opposite = vec![-1.0, 0.0, 0.0];

        assert!(close(EmbeddingUtil::cosine_similarity(&base, &same), 1.0));
        assert!(close(
            EmbeddingUtil::cosine_similarity(&base, &orthogonal),
            0.0
        ));
        assert!(close(
            EmbeddingUtil::cosine_similarity(&base, &opposite),
            -1.0
        ));
    }

    /// A 3-4-5 triangle normalizes to (0.6, 0.8) with unit length.
    #[test]
    fn test_normalize() {
        let unit = EmbeddingUtil::normalize(&[3.0, 4.0]);
        assert!(close(unit[0], 0.6));
        assert!(close(unit[1], 0.8));

        let length: f32 = unit.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(close(length, 1.0));
    }

    /// Distance from the origin to (3, 4) is 5.
    #[test]
    fn test_euclidean_distance() {
        let origin = vec![0.0, 0.0];
        let point = vec![3.0, 4.0];
        assert!(close(
            EmbeddingUtil::euclidean_distance(&origin, &point),
            5.0
        ));
    }

    /// The exact match wins over orthogonal and diagonal candidates.
    #[test]
    fn test_find_most_similar() {
        let query = vec![1.0, 0.0];
        let candidates = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.7, 0.7]];

        let found = EmbeddingUtil::find_most_similar(&query, &candidates);
        assert!(found.is_some());
        let (best_index, _score) = found.unwrap();
        assert_eq!(best_index, 0);
    }

    /// Component-wise mean of two vectors.
    #[test]
    fn test_average() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let mean = EmbeddingUtil::average(&inputs).unwrap();
        assert!(close(mean[0], 2.0));
        assert!(close(mean[1], 3.0));
    }
}