use crate::model::{InferenceContext, Model, ModelConfig};
use crate::tokenizer::Tokenizer;
/// Settings that control how a text embedding is produced.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Model layer to read hidden states from.
    // NOTE(review): this field is never read anywhere in this file —
    // presumably -1 means "last layer"; confirm against the model code.
    pub layer: i32,
    /// How per-token vectors are collapsed into one embedding.
    pub pooling: PoolingStrategy,
    /// Whether the pooled embedding is scaled to unit L2 norm.
    pub normalize: bool,
    /// Maximum number of tokens kept before running the model.
    pub max_length: usize,
    /// Which part of an over-long token sequence is discarded.
    pub truncation: TruncationStrategy,
}
impl Default for EmbeddingConfig {
    /// Defaults: layer -1, mean pooling, L2 normalization on,
    /// 512-token limit, right-side truncation.
    fn default() -> Self {
        Self {
            layer: -1,
            pooling: PoolingStrategy::Mean,
            normalize: true,
            max_length: 512,
            truncation: TruncationStrategy::Right,
        }
    }
}
/// Strategy for reducing a sequence of per-token vectors to a single
/// fixed-size embedding (see `EmbeddingExtractor::pool_embeddings`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Use the last token's vector.
    Last,
    /// Use the first token's vector.
    First,
    /// Element-wise average over all tokens.
    Mean,
    /// Element-wise maximum over all tokens.
    Max,
    /// Average weighted by 1-based token position, so later tokens
    /// contribute more.
    WeightedMean,
}
/// Which tokens are dropped when the input exceeds `max_length`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Keep the beginning, drop the end.
    Right,
    /// Keep the end, drop the beginning.
    Left,
    /// Keep both ends, drop the middle.
    Middle,
}
/// Turns text into fixed-size embedding vectors using a model's outputs.
pub struct EmbeddingExtractor {
    // Extraction settings: pooling, truncation, normalization, max length.
    config: EmbeddingConfig,
    // Width of the produced embeddings, taken from the model's hidden size.
    hidden_dim: usize,
}
impl EmbeddingExtractor {
    /// Creates an extractor with the given settings, sizing embeddings to the
    /// model's hidden dimension.
    pub fn new(config: EmbeddingConfig, model_config: &ModelConfig) -> Self {
        Self {
            config,
            hidden_dim: model_config.hidden_size,
        }
    }

    /// Computes one embedding for `text`.
    ///
    /// Pipeline: tokenize -> truncate to `config.max_length` -> collect a
    /// per-token vector from the model -> pool -> optionally L2-normalize.
    ///
    /// # Errors
    /// Propagates tokenizer and model failures as [`EmbeddingError`].
    pub fn embed_text(
        &self,
        model: &dyn Model,
        tokenizer: &Tokenizer,
        ctx: &mut InferenceContext,
        text: &str,
    ) -> Result<Vec<f32>, EmbeddingError> {
        let tokens = tokenizer.encode(text, false)?;
        let tokens = self.truncate_tokens(&tokens);
        // NOTE(review): an empty token sequence yields an all-zero vector
        // rather than `EmbeddingError::EmptyInput` (which is never
        // constructed); preserved as-is since callers may rely on it.
        let embeddings = self.get_token_embeddings(model, ctx, &tokens)?;
        let pooled = self.pool_embeddings(&embeddings, tokens.len());
        if self.config.normalize {
            Ok(self.normalize_embedding(&pooled))
        } else {
            Ok(pooled)
        }
    }

    /// Embeds each text independently, resetting `ctx` before every text so
    /// state from one input cannot leak into the next.
    ///
    /// # Errors
    /// Fails fast: the first text that errors aborts the whole batch.
    pub fn embed_batch(
        &self,
        model: &dyn Model,
        tokenizer: &Tokenizer,
        ctx: &mut InferenceContext,
        texts: &[&str],
    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            ctx.reset();
            results.push(self.embed_text(model, tokenizer, ctx, text)?);
        }
        Ok(results)
    }

    /// Returns at most `config.max_length` tokens, dropping from the side the
    /// configured strategy dictates. Always returns an owned copy.
    fn truncate_tokens(&self, tokens: &[u32]) -> Vec<u32> {
        let max = self.config.max_length;
        if tokens.len() <= max {
            return tokens.to_vec();
        }
        match self.config.truncation {
            TruncationStrategy::Right => tokens[..max].to_vec(),
            TruncationStrategy::Left => tokens[tokens.len() - max..].to_vec(),
            TruncationStrategy::Middle => {
                // Keep `max - tail` head tokens and `tail` tail tokens so the
                // result is exactly `max` long. The previous code kept
                // `max / 2` on both sides, losing one token of budget
                // whenever `max` was odd (off-by-one).
                let tail = max / 2;
                let head = max - tail;
                let mut kept = tokens[..head].to_vec();
                kept.extend_from_slice(&tokens[tokens.len() - tail..]);
                kept
            }
        }
    }

    /// Runs the model one token at a time and collects a vector per token.
    ///
    /// NOTE(review): this reads the model's `forward` output (logits) as the
    /// token representation, clipped to `hidden_dim`, and never consults
    /// `config.layer`. Confirm this matches the intended hidden-state
    /// extraction.
    fn get_token_embeddings(
        &self,
        model: &dyn Model,
        ctx: &mut InferenceContext,
        tokens: &[u32],
    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut embeddings = Vec::with_capacity(tokens.len());
        for &token in tokens {
            let output = model.forward(&[token], ctx)?;
            let data = output.as_f32()?;
            // Guard against model outputs shorter than the configured width.
            let dim = self.hidden_dim.min(data.len());
            embeddings.push(data[..dim].to_vec());
        }
        Ok(embeddings)
    }

    /// Collapses per-token vectors into one vector of the same width using
    /// the configured pooling strategy. Empty input yields a zero vector of
    /// `hidden_dim` width.
    fn pool_embeddings(&self, embeddings: &[Vec<f32>], _seq_len: usize) -> Vec<f32> {
        if embeddings.is_empty() {
            return vec![0.0; self.hidden_dim];
        }
        // All rows are assumed to share the first row's width.
        let dim = embeddings[0].len();
        match self.config.pooling {
            PoolingStrategy::Last => embeddings.last().cloned().unwrap_or_else(|| vec![0.0; dim]),
            PoolingStrategy::First => embeddings
                .first()
                .cloned()
                .unwrap_or_else(|| vec![0.0; dim]),
            PoolingStrategy::Mean => {
                // Element-wise arithmetic mean over tokens.
                let mut mean = vec![0.0f32; dim];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        mean[i] += v;
                    }
                }
                let n = embeddings.len() as f32;
                for v in &mut mean {
                    *v /= n;
                }
                mean
            }
            PoolingStrategy::Max => {
                // Element-wise maximum over tokens.
                let mut max = vec![f32::NEG_INFINITY; dim];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        max[i] = max[i].max(v);
                    }
                }
                max
            }
            PoolingStrategy::WeightedMean => {
                // Position-weighted mean: token at position `pos` gets weight
                // `pos + 1`, so later tokens contribute more.
                let mut weighted = vec![0.0f32; dim];
                let mut total_weight = 0.0f32;
                for (pos, emb) in embeddings.iter().enumerate() {
                    let weight = (pos + 1) as f32;
                    total_weight += weight;
                    for (i, &v) in emb.iter().enumerate() {
                        weighted[i] += v * weight;
                    }
                }
                // `total_weight` is >= 1.0 here since `embeddings` is non-empty.
                for v in &mut weighted {
                    *v /= total_weight;
                }
                weighted
            }
        }
    }

    /// Scales `embedding` to unit L2 norm; an all-zero vector is returned
    /// unchanged to avoid dividing by zero.
    fn normalize_embedding(&self, embedding: &[f32]) -> Vec<f32> {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter().map(|x| x / norm).collect()
        } else {
            embedding.to_vec()
        }
    }

    /// Width of the vectors this extractor produces.
    pub fn embedding_dim(&self) -> usize {
        self.hidden_dim
    }
}
/// Errors that can occur while extracting embeddings.
#[derive(thiserror::Error, Debug)]
pub enum EmbeddingError {
    /// The tokenizer failed to encode the input text.
    #[error("Tokenization error: {0}")]
    Tokenization(#[from] crate::tokenizer::TokenizerError),
    /// The model's forward pass failed.
    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),
    /// The model output could not be read as f32 data.
    #[error("Tensor error: {0}")]
    Tensor(#[from] crate::tensor::TensorError),
    /// Input contained no tokens.
    // NOTE(review): this variant is never constructed in this file —
    // empty input currently produces an all-zero embedding instead.
    #[error("Empty input")]
    EmptyInput,
}
/// Cosine similarity between two vectors.
///
/// Returns 0.0 when the lengths differ or either vector has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
/// Euclidean (L2) distance between two vectors.
///
/// Returns `f32::INFINITY` when the lengths differ.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    let mut sum_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
/// Dot product of two vectors; 0.0 when the lengths differ.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b.iter()).fold(0.0, |acc, (&x, &y)| acc + x * y)
    } else {
        0.0
    }
}
/// Returns the indices and cosine-similarity scores of the (at most) `k`
/// embeddings closest to `query`, ordered best-first.
pub fn find_nearest(query: &[f32], embeddings: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scored = Vec::with_capacity(embeddings.len());
    for (idx, emb) in embeddings.iter().enumerate() {
        scored.push((idx, cosine_similarity(query, emb)));
    }
    // Sort descending by score; incomparable values (NaN) are treated as equal.
    scored.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    scored
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPS: f32 = 0.001;

    #[test]
    fn test_embedding_config_default() {
        let cfg = EmbeddingConfig::default();
        assert_eq!(cfg.layer, -1);
        assert!(cfg.normalize);
        assert_eq!(cfg.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&x_axis, &same) - 1.0).abs() < EPS);
        let y_axis = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < EPS);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = vec![0.0, 0.0];
        let point = vec![3.0, 4.0];
        assert!((euclidean_distance(&origin, &point) - 5.0).abs() < EPS);
    }

    #[test]
    fn test_find_nearest() {
        let query = vec![1.0, 0.0];
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.7, 0.7]];
        let hits = find_nearest(&query, &corpus, 2);
        assert_eq!(hits.len(), 2);
        // The identical vector must rank first.
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn test_normalize() {
        let extractor = EmbeddingExtractor {
            config: EmbeddingConfig::default(),
            hidden_dim: 3,
        };
        let result = extractor.normalize_embedding(&[3.0, 4.0, 0.0]);
        let length: f32 = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < EPS);
    }

    #[test]
    fn test_pooling_mean() {
        let extractor = EmbeddingExtractor {
            config: EmbeddingConfig {
                pooling: PoolingStrategy::Mean,
                ..Default::default()
            },
            hidden_dim: 2,
        };
        let token_vecs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let mean = extractor.pool_embeddings(&token_vecs, 2);
        assert!((mean[0] - 0.5).abs() < EPS);
        assert!((mean[1] - 0.5).abs() < EPS);
    }
}