use fastembed::{EmbeddingModel as FastEmbeddingModel, InitOptions, TextEmbedding};
use crate::error::TldrError;
use crate::semantic::similarity::normalize;
use crate::semantic::types::EmbeddingModel;
use crate::TldrResult;
/// Options describing how an embedding run should be performed.
///
/// `Default` is derived, so every field takes its own type's default; the
/// unit test in this file pins that to `EmbeddingModel::ArcticM` with both
/// flags set to `false`.
///
/// NOTE(review): nothing in this file reads these options; their consumers
/// live elsewhere in the crate.
#[derive(Debug, Clone, Default)]
pub struct EmbedOptions {
/// Embedding model variant to use.
pub model: EmbeddingModel,
/// Whether progress should be reported.
/// NOTE(review): presumably forwarded to `Embedder::embed_batch` — confirm at call sites.
pub show_progress: bool,
/// NOTE(review): presumably toggles a model-specific text prefix; not
/// consumed in this file — confirm against the caller.
pub use_prefix: bool,
}
/// Text embedder pairing a loaded fastembed `TextEmbedding` with the
/// `EmbeddingModel` variant it was constructed from.
pub struct Embedder {
/// Loaded fastembed model that produces the raw embeddings.
model: TextEmbedding,
/// Model variant chosen at construction; supplies the expected dimension count.
config: EmbeddingModel,
}
impl Embedder {
    /// Batch size handed to the backend when `show_progress` is set in
    /// `embed_batch`; otherwise `None` is passed and the backend picks its own.
    const PROGRESS_BATCH_SIZE: usize = 32;

    /// Loads `model` through fastembed and verifies it with a probe embedding.
    ///
    /// A notice naming the model and its approximate download size is written
    /// to stderr before loading starts. After loading, the string `"test"` is
    /// embedded once and the resulting vector length is compared against
    /// `model.dimensions()`.
    ///
    /// # Errors
    ///
    /// * `TldrError::ModelLoadError` if fastembed fails to initialize the model.
    /// * `TldrError::Embedding` if the probe embedding fails, returns no
    ///   vectors, or has an unexpected number of dimensions.
    pub fn new(model: EmbeddingModel) -> TldrResult<Self> {
        let fast_model = Self::to_fastembed_model(model);
        eprintln!(
            "Loading embedding model ({})... First run may download ~{}MB model.",
            model.model_name(),
            Self::model_size_mb(model)
        );

        let mut embedding = TextEmbedding::try_new(InitOptions::new(fast_model)).map_err(|e| {
            TldrError::ModelLoadError {
                model: model.model_name().to_string(),
                detail: e.to_string(),
            }
        })?;

        // Integrity probe: one tiny embedding proves the model actually runs.
        let probe = embedding
            .embed(vec!["test"], None)
            .map_err(|e| TldrError::Embedding(format!("Model integrity check failed: {}", e)))?;
        let actual_dims = probe.first().map(Vec::len).ok_or_else(|| {
            TldrError::Embedding("Model integrity check failed: empty result".to_string())
        })?;

        let expected_dims = model.dimensions();
        if actual_dims != expected_dims {
            return Err(TldrError::Embedding(format!(
                "Model integrity check failed: expected {} dimensions, got {}",
                expected_dims, actual_dims
            )));
        }

        Ok(Self {
            model: embedding,
            config: model,
        })
    }

    /// Maps the crate's model enum onto the matching fastembed model variant.
    fn to_fastembed_model(model: EmbeddingModel) -> FastEmbeddingModel {
        match model {
            EmbeddingModel::ArcticXS => FastEmbeddingModel::SnowflakeArcticEmbedXS,
            EmbeddingModel::ArcticS => FastEmbeddingModel::SnowflakeArcticEmbedS,
            EmbeddingModel::ArcticM => FastEmbeddingModel::SnowflakeArcticEmbedM,
            EmbeddingModel::ArcticMLong => FastEmbeddingModel::SnowflakeArcticEmbedMLong,
            EmbeddingModel::ArcticL => FastEmbeddingModel::SnowflakeArcticEmbedL,
        }
    }

    /// Approximate download size in megabytes, used only for the stderr
    /// notice printed by `new`.
    fn model_size_mb(model: EmbeddingModel) -> usize {
        match model {
            EmbeddingModel::ArcticXS => 30,
            EmbeddingModel::ArcticS => 90,
            EmbeddingModel::ArcticM | EmbeddingModel::ArcticMLong => 110,
            EmbeddingModel::ArcticL => 335,
        }
    }

    /// Embeds a single text and normalizes the result with `normalize`.
    ///
    /// An empty `text` never reaches the model: a zero vector of
    /// `self.dimensions()` elements is returned instead.
    ///
    /// # Errors
    ///
    /// Returns `TldrError::Embedding` if the backend fails or yields no vector.
    pub fn embed_text(&mut self, text: &str) -> TldrResult<Vec<f32>> {
        if text.is_empty() {
            return Ok(vec![0.0; self.config.dimensions()]);
        }

        let result = self
            .model
            .embed(vec![text], None)
            .map_err(|e| TldrError::Embedding(format!("Failed to embed text: {}", e)))?;
        let mut embedding = result
            .into_iter()
            .next()
            .ok_or_else(|| TldrError::Embedding("No embedding returned".to_string()))?;

        normalize(&mut embedding);
        Ok(embedding)
    }

    /// Embeds many texts at once, returning one vector per input in input order.
    ///
    /// Empty entries are handled exactly like `embed_text("")`: they yield a
    /// zero vector of `self.dimensions()` elements and are not sent to the
    /// model, so batch and single-text results agree for every input. All
    /// other vectors are normalized with `normalize`.
    ///
    /// When `show_progress` is `true` the backend is asked to work in batches
    /// of `PROGRESS_BATCH_SIZE`; otherwise the backend's default is used.
    ///
    /// # Errors
    ///
    /// Returns `TldrError::Embedding` if the backend fails or returns a
    /// different number of embeddings than non-empty texts supplied.
    pub fn embed_batch(
        &mut self,
        texts: Vec<&str>,
        show_progress: bool,
    ) -> TldrResult<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Only non-empty texts go to the model; empties get zero vectors below.
        let non_empty: Vec<&str> = texts.iter().copied().filter(|t| !t.is_empty()).collect();
        let expected = non_empty.len();

        let raw = if non_empty.is_empty() {
            Vec::new()
        } else {
            let batch_size = show_progress.then_some(Self::PROGRESS_BATCH_SIZE);
            self.model
                .embed(non_empty, batch_size)
                .map_err(|e| TldrError::Embedding(format!("Failed to embed batch: {}", e)))?
        };

        // Guard against a backend that drops or duplicates entries, which
        // would otherwise silently misalign vectors with their inputs.
        if raw.len() != expected {
            return Err(TldrError::Embedding(format!(
                "Failed to embed batch: expected {} embeddings, got {}",
                expected,
                raw.len()
            )));
        }

        let dims = self.config.dimensions();
        let mut raw = raw.into_iter();
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in &texts {
            if text.is_empty() {
                embeddings.push(vec![0.0; dims]);
            } else if let Some(mut vector) = raw.next() {
                // The count check above guarantees a vector for every non-empty text.
                normalize(&mut vector);
                embeddings.push(vector);
            }
        }
        Ok(embeddings)
    }

    /// Returns the model variant this embedder was constructed with.
    pub fn config(&self) -> EmbeddingModel {
        self.config
    }

    /// Returns the embedding dimension count reported by the configured model.
    pub fn dimensions(&self) -> usize {
        self.config.dimensions()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semantic::similarity::is_normalized;

    /// Asserts that two embeddings have the same length and agree
    /// element-wise within `tolerance`.
    ///
    /// The explicit length check matters because `zip` stops at the shorter
    /// input, so a truncated vector would otherwise pass the comparison.
    fn assert_embeddings_close(left: &[f32], right: &[f32], tolerance: f32, what: &str) {
        assert_eq!(
            left.len(),
            right.len(),
            "{}: embedding lengths differ",
            what
        );
        for (a, b) in left.iter().zip(right.iter()) {
            assert!((a - b).abs() < tolerance, "{}: {} vs {}", what, a, b);
        }
    }

    /// Default options select ArcticM with both flags disabled.
    #[test]
    fn embed_options_default_values() {
        let options = EmbedOptions::default();
        assert_eq!(options.model, EmbeddingModel::ArcticM);
        assert!(!options.show_progress);
        assert!(!options.use_prefix);
    }

    /// Construction succeeds and reports the requested model and its dimensions.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_new_initializes_model() {
        let model = EmbeddingModel::ArcticM;
        let embedder = Embedder::new(model);
        assert!(
            embedder.is_ok(),
            "Failed to initialize: {:?}",
            embedder.err()
        );
        let embedder = embedder.unwrap();
        assert_eq!(embedder.config(), model);
        assert_eq!(embedder.dimensions(), 768);
    }

    /// A single embedding from ArcticM has 768 elements.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_embed_text_returns_correct_dimensions() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticM).expect("Failed to init");
        let embedding = embedder
            .embed_text("fn process_data() { }")
            .expect("Failed to embed");
        assert_eq!(embedding.len(), 768, "Expected 768 dimensions for ArcticM");
    }

    /// A single embedding comes back normalized.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_embed_text_is_normalized() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticM).expect("Failed to init");
        let embedding = embedder
            .embed_text("fn process_data() { }")
            .expect("Failed to embed");
        assert!(
            is_normalized(&embedding),
            "Embedding should have L2 norm = 1.0"
        );
    }

    /// Batch embedding produces the same vectors as embedding each text alone.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_batch_embedding_matches_single() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticM).expect("Failed to init");
        let text1 = "fn foo() { }";
        let text2 = "fn bar() { }";
        let single1 = embedder.embed_text(text1).expect("Failed single embed 1");
        let single2 = embedder.embed_text(text2).expect("Failed single embed 2");
        let batch = embedder
            .embed_batch(vec![text1, text2], false)
            .expect("Failed batch embed");
        assert_eq!(batch.len(), 2);
        assert_embeddings_close(&single1, &batch[0], 1e-5, "Single vs batch mismatch");
        assert_embeddings_close(&single2, &batch[1], 1e-5, "Single vs batch mismatch");
    }

    /// Empty text yields an all-zero vector of the model's dimension.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_empty_input_returns_zero_vector() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticM).expect("Failed to init");
        let embedding = embedder.embed_text("").expect("Failed to embed empty");
        assert_eq!(embedding.len(), 768);
        assert!(
            embedding.iter().all(|&x| x == 0.0),
            "Empty input should produce zero vector"
        );
    }

    /// An empty batch yields an empty result.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_batch_empty_list_returns_empty() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticM).expect("Failed to init");
        let embeddings = embedder
            .embed_batch(vec![], false)
            .expect("Failed to embed empty batch");
        assert!(embeddings.is_empty());
    }

    /// The XS model produces normalized 384-element embeddings.
    #[test]
    #[ignore = "Requires model download (~30MB for XS)"]
    fn embedder_xs_model_dimensions() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticXS).expect("Failed to init XS");
        let embedding = embedder.embed_text("test").expect("Failed to embed");
        assert_eq!(embedding.len(), 384);
        assert!(is_normalized(&embedding));
    }

    /// Embedding the same text twice gives the same vector.
    #[test]
    #[ignore = "Requires model download (~110MB)"]
    fn embedder_deterministic_results() {
        let mut embedder = Embedder::new(EmbeddingModel::ArcticM).expect("Failed to init");
        let text = "fn process_data(input: &str) -> Result<Output>";
        let e1 = embedder.embed_text(text).expect("Failed embed 1");
        let e2 = embedder.embed_text(text).expect("Failed embed 2");
        assert_embeddings_close(&e1, &e2, 1e-6, "Embeddings should be deterministic");
    }
}