use std::path::PathBuf;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use crate::error::{Result, SQuaJLError};
use crate::quantized::{QuantizedVector, QuerySketch};
use crate::{SQuaJL, SQuaJLConfig};
pub struct FastEmbedQuantizer {
model: TextEmbedding,
codec: SQuaJL,
}
impl FastEmbedQuantizer {
pub fn new(model_name: EmbeddingModel, codec: SQuaJL) -> Result<Self> {
Self::with_options(model_name, codec, None, None, false)
}
pub fn with_options(
model_name: EmbeddingModel,
codec: SQuaJL,
cache_dir: Option<PathBuf>,
max_length: Option<usize>,
show_download_progress: bool,
) -> Result<Self> {
let mut options = InitOptions::new(model_name);
options.show_download_progress = show_download_progress;
if let Some(cache_dir) = cache_dir {
options.cache_dir = cache_dir;
}
if let Some(max_length) = max_length {
options.max_length = max_length;
}
let model = TextEmbedding::try_new(options)
.map_err(|error| SQuaJLError::Backend(error.to_string()))?;
Ok(Self { model, codec })
}
pub fn all_minilm_l6_v2(config: SQuaJLConfig) -> Result<Self> {
if config.input_dim != 384 {
return Err(SQuaJLError::InvalidConfig(
"all-MiniLM-L6-v2 emits 384-dimensional embeddings".to_owned(),
));
}
let codec = SQuaJL::new(config)?;
Self::new(EmbeddingModel::AllMiniLML6V2, codec)
}
pub fn codec(&self) -> &SQuaJL {
&self.codec
}
pub fn model(&self) -> &TextEmbedding {
&self.model
}
pub fn embed_texts<S>(
&mut self,
texts: &[S],
batch_size: Option<usize>,
) -> Result<Vec<Vec<f32>>>
where
S: AsRef<str> + Send + Sync,
{
self.model
.embed(texts, batch_size)
.map_err(|error| SQuaJLError::Backend(error.to_string()))
}
pub fn quantize_texts<S>(
&mut self,
texts: &[S],
batch_size: Option<usize>,
) -> Result<Vec<QuantizedVector>>
where
S: AsRef<str> + Send + Sync,
{
let embeddings = self.embed_texts(texts, batch_size)?;
embeddings
.iter()
.map(|embedding| self.codec.encode(embedding))
.collect()
}
pub fn embed_query(&mut self, text: &str) -> Result<QuerySketch> {
let embeddings = self.embed_texts(&[text], Some(1))?;
let embedding = embeddings
.into_iter()
.next()
.ok_or_else(|| SQuaJLError::Backend("fastembed returned no embedding".to_owned()))?;
self.codec.sketch_query(&embedding)
}
}