use super::loader::EmbeddingTable;
use crate::sentencepiece::tokenizer::SentencePieceProcessor;
/// Embeds `text` as the L2-normalised mean of its token embeddings.
///
/// The text is tokenized with `tokenizer`. Each resulting token id that is
/// below `table.vocab_size` is dequantized from `table` and accumulated. Ids
/// outside the table are skipped and do not count toward the mean.
///
/// Returns a vector of length `table.embed_dim`. The vector is all zeros when:
/// - tokenization fails,
/// - it yields no tokens, or
/// - none of the token ids fall inside the table.
///
/// Otherwise the mean is scaled to unit L2 norm. The exception is a mean whose
/// norm is at most `1e-8`, which is returned unscaled to avoid dividing by
/// (nearly) zero.
pub fn embed_text(
    text: &str,
    tokenizer: &SentencePieceProcessor,
    table: &EmbeddingTable,
) -> Vec<f32> {
    let dim = table.embed_dim;

    // Tokenization errors are deliberately treated as "no content" rather than
    // propagated, so callers always receive a vector of the expected width.
    let tokens = match tokenizer.encode(text) {
        Ok(t) => t,
        Err(_) => return vec![0.0; dim],
    };
    if tokens.is_empty() {
        return vec![0.0; dim];
    }

    let mut sum = vec![0.0f32; dim];
    // Scratch buffer reused across tokens to avoid a per-token allocation.
    let mut buf = vec![0.0f32; dim];
    // Only tokens that are actually summed are counted. Counting skipped
    // (out-of-table) ids as well would bias the mean toward zero.
    let mut counted = 0usize;
    for &token_id in &tokens {
        let tid = token_id as usize;
        if tid >= table.vocab_size {
            continue;
        }
        table.dequantize_into(tid, &mut buf);
        for (acc, &v) in sum.iter_mut().zip(buf.iter()) {
            *acc += v;
        }
        counted += 1;
    }

    // No in-table tokens: `sum` is still all zeros, which is the result.
    if counted == 0 {
        return sum;
    }

    let count = counted as f32;
    for x in &mut sum {
        *x /= count;
    }

    let norm = sum.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-8 {
        for x in &mut sum {
            *x /= norm;
        }
    }
    sum
}
// Unit tests for `embed_text`.
#[cfg(test)]
mod tests {
    // NOTE(review): this test is `#[ignore]`d and has an empty body, so even
    // `cargo test -- --ignored` passes it without checking anything.
    // TODO: build a tokenizer and an `EmbeddingTable` fixture and assert on
    // the output length and unit norm of `embed_text`. How to construct those
    // fixtures is not visible here; confirm against the loader/tokenizer APIs.
    #[test]
    #[ignore] fn test_embed_text() {
    }
}