#[cfg(feature = "hf-hub")]
use crate::common::load_tokenizer_hf_hub;
use crate::{
common::load_tokenizer,
models::{text_embedding::models_list, ModelTrait},
pooling::Pooling,
Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, OutputKey, QuantizationMode,
SingleBatchOutput,
};
#[cfg(feature = "hf-hub")]
use anyhow::Context;
use anyhow::Result;
#[cfg(feature = "hf-hub")]
use hf_hub::api::sync::ApiRepo;
use ndarray::Array;
use ort::{
session::{builder::GraphOptimizationLevel, Session},
value::Value,
};
#[cfg(feature = "hf-hub")]
use std::path::PathBuf;
use std::thread::available_parallelism;
use tokenizers::Tokenizer;
#[cfg(feature = "hf-hub")]
use super::TextInitOptions;
use super::{
output, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel, DEFAULT_BATCH_SIZE,
};
impl TextEmbedding {
/// Creates a [`TextEmbedding`] by pulling the model and tokenizer files from
/// the Hugging Face Hub (or the local cache) as described by `options`.
///
/// # Errors
/// Returns an error when the model files cannot be retrieved, the ONNX
/// session cannot be built, or the tokenizer fails to load.
#[cfg(feature = "hf-hub")]
pub fn try_new(options: TextInitOptions) -> Result<Self> {
    let TextInitOptions {
        max_length,
        model_name,
        execution_providers,
        cache_dir,
        show_download_progress,
    } = options;
    // Use every available core for intra-op parallelism.
    let threads = available_parallelism()?.get();
    let model_repo = TextEmbedding::retrieve_model(
        model_name.clone(),
        cache_dir.clone(),
        show_download_progress,
    )?;
    let model_info = TextEmbedding::get_model_info(&model_name)?;
    let model_file_name = &model_info.model_file;
    // `with_context` defers the message formatting to the error path only.
    let model_file_reference = model_repo
        .get(model_file_name)
        .with_context(|| format!("Failed to retrieve {}", model_file_name))?;
    // Fetch any companion files (e.g. external weight data) into the same
    // cache so the runtime can resolve them next to the main graph file.
    // Iterating an empty list is a no-op, so no emptiness guard is needed.
    for file in &model_info.additional_files {
        model_repo
            .get(file)
            .with_context(|| format!("Failed to retrieve {}", file))?;
    }
    let post_processing = TextEmbedding::get_default_pooling_method(&model_name);
    // When DirectML is among the requested execution providers, memory
    // patterns and parallel execution are disabled on the session below.
    #[cfg(feature = "directml")]
    let has_directml = execution_providers
        .iter()
        .any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
    #[cfg(not(feature = "directml"))]
    let has_directml = false;
    let mut builder = Session::builder()?
        .with_execution_providers(execution_providers)?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_intra_threads(threads)?;
    if has_directml {
        builder = builder
            .with_memory_pattern(false)?
            .with_parallel_execution(false)?;
    }
    let session = builder.commit_from_file(model_file_reference)?;
    let tokenizer = load_tokenizer_hf_hub(model_repo, max_length)?;
    Ok(Self::new(
        tokenizer,
        session,
        post_processing,
        TextEmbedding::get_quantization_mode(&model_name),
        model_info.output_key.clone(),
    ))
}
/// Builds a [`TextEmbedding`] from model and tokenizer bytes supplied by the
/// caller, without touching the network or the Hugging Face Hub.
///
/// Pooling, quantization mode, and output key are taken from `model` as-is.
///
/// # Errors
/// Returns an error if the ONNX session cannot be created from the provided
/// bytes or the tokenizer files fail to load.
pub fn try_new_from_user_defined(
model: UserDefinedEmbeddingModel,
options: InitOptionsUserDefined,
) -> Result<Self> {
let InitOptionsUserDefined {
execution_providers,
max_length,
} = options;
// Use every available core for intra-op parallelism.
let threads = available_parallelism()?.get();
// When DirectML is among the requested execution providers, memory
// patterns and parallel execution are disabled on the session below.
#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;
let session = {
let mut session_builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?;
if has_directml {
session_builder = session_builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}
// Register any external-initializer blobs so the in-memory graph can
// resolve weight data stored outside the main ONNX file.
for external_initializer_file in model.external_initializers {
session_builder = session_builder.with_external_initializer_file_in_memory(
external_initializer_file.file_name,
external_initializer_file.buffer.into(),
)?;
}
session_builder.commit_from_memory(&model.onnx_file)?
};
let tokenizer = load_tokenizer(model.tokenizer_files, max_length)?;
Ok(Self::new(
tokenizer,
session,
model.pooling,
model.quantization,
model.output_key,
))
}
/// Internal constructor shared by the public `try_new*` entry points.
///
/// Probes the model's declared inputs once so `transform` knows whether a
/// `token_type_ids` tensor must be supplied on every run.
fn new(
    tokenizer: Tokenizer,
    session: Session,
    post_process: Option<Pooling>,
    quantization: QuantizationMode,
    output_key: Option<OutputKey>,
) -> Self {
    let mut need_token_type_ids = false;
    for input in session.inputs().iter() {
        if input.name() == "token_type_ids" {
            need_token_type_ids = true;
            break;
        }
    }
    Self {
        tokenizer,
        session,
        need_token_type_ids,
        pooling: post_process,
        quantization,
        output_key,
    }
}
/// Resolves `model` to its Hugging Face repository handle, downloading into
/// `cache_dir` when the files are not already cached locally.
#[cfg(feature = "hf-hub")]
fn retrieve_model(
    model: EmbeddingModel,
    cache_dir: PathBuf,
    show_download_progress: bool,
) -> anyhow::Result<ApiRepo> {
    use crate::common::pull_from_hf;

    let repo_id = TextEmbedding::get_model_info(&model)?.model_code.clone();
    pull_from_hf(repo_id, cache_dir, show_download_progress)
}
pub fn get_default_pooling_method(model_name: &EmbeddingModel) -> Option<Pooling> {
match model_name {
EmbeddingModel::AllMiniLML6V2 => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML6V2Q => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML12V2 => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML12V2Q => Some(Pooling::Mean),
EmbeddingModel::BGEBaseENV15 => Some(Pooling::Cls),
EmbeddingModel::BGEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::BGELargeENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGESmallENV15 => Some(Pooling::Cls),
EmbeddingModel::BGESmallENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGESmallZHV15 => Some(Pooling::Cls),
EmbeddingModel::BGELargeZHV15 => Some(Pooling::Cls),
EmbeddingModel::BGEM3 => Some(Pooling::Cls),
EmbeddingModel::NomicEmbedTextV1 => Some(Pooling::Mean),
EmbeddingModel::NomicEmbedTextV15 => Some(Pooling::Mean),
EmbeddingModel::NomicEmbedTextV15Q => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMiniLML12V2 => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean),
EmbeddingModel::AllMpnetBaseV2 => Some(Pooling::Mean),
EmbeddingModel::ModernBertEmbedLarge => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Base => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Small => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Large => Some(Pooling::Mean),
EmbeddingModel::MxbaiEmbedLargeV1 => Some(Pooling::Cls),
EmbeddingModel::MxbaiEmbedLargeV1Q => Some(Pooling::Cls),
EmbeddingModel::GTEBaseENV15 => Some(Pooling::Cls),
EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),
EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean),
EmbeddingModel::JinaEmbeddingsV2BaseEN => Some(Pooling::Mean),
EmbeddingModel::EmbeddingGemma300M => Some(Pooling::Mean),
EmbeddingModel::SnowflakeArcticEmbedXS => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedXSQ => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedS => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedSQ => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedM => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedMQ => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedMLong => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedMLongQ => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedL => Some(Pooling::Cls),
EmbeddingModel::SnowflakeArcticEmbedLQ => Some(Pooling::Cls),
}
}
pub fn get_quantization_mode(model_name: &EmbeddingModel) -> QuantizationMode {
match model_name {
EmbeddingModel::AllMiniLML6V2Q => QuantizationMode::Dynamic,
EmbeddingModel::AllMiniLML12V2Q => QuantizationMode::Dynamic,
EmbeddingModel::BGEBaseENV15Q => QuantizationMode::Static,
EmbeddingModel::BGELargeENV15Q => QuantizationMode::Static,
EmbeddingModel::BGESmallENV15Q => QuantizationMode::Static,
EmbeddingModel::NomicEmbedTextV15Q => QuantizationMode::Dynamic,
EmbeddingModel::ParaphraseMLMiniLML12V2Q => QuantizationMode::Static,
EmbeddingModel::MxbaiEmbedLargeV1Q => QuantizationMode::Dynamic,
EmbeddingModel::GTEBaseENV15Q => QuantizationMode::Dynamic,
EmbeddingModel::GTELargeENV15Q => QuantizationMode::Dynamic,
EmbeddingModel::SnowflakeArcticEmbedXSQ => QuantizationMode::Dynamic,
EmbeddingModel::SnowflakeArcticEmbedSQ => QuantizationMode::Dynamic,
EmbeddingModel::SnowflakeArcticEmbedMQ => QuantizationMode::Dynamic,
EmbeddingModel::SnowflakeArcticEmbedMLongQ => QuantizationMode::Dynamic,
EmbeddingModel::SnowflakeArcticEmbedLQ => QuantizationMode::Dynamic,
_ => QuantizationMode::None,
}
}
/// Lists the metadata of every text-embedding model supported by this build.
pub fn list_supported_models() -> Vec<ModelInfo<EmbeddingModel>> {
models_list()
}
pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> {
EmbeddingModel::get_model_info(model).ok_or_else(|| {
anyhow::Error::msg(format!(
"Model {model:?} not found. Please check if the model is supported \
by the current version."
))
})
}
/// Tokenizes `texts`, runs the ONNX session batch by batch, and returns the
/// raw per-batch session outputs (no pooling or normalization applied).
///
/// `batch_size` falls back to `DEFAULT_BATCH_SIZE` when `None`; dynamically
/// quantized models must process everything in a single batch.
///
/// # Errors
/// Fails when batching is requested for a dynamically quantized model, when
/// tokenization fails, or when an ONNX session run fails.
pub fn transform<S: AsRef<str> + Send + Sync>(
&mut self,
texts: impl AsRef<[S]>,
batch_size: Option<usize>,
) -> Result<EmbeddingOutput> {
let texts = texts.as_ref();
// Dynamic quantization rescales per batch, so embeddings from different
// batches are not comparable: force a single batch covering all texts.
let batch_size = match self.quantization {
QuantizationMode::Dynamic => {
if let Some(batch_size) = batch_size {
if batch_size < texts.len() {
Err(anyhow::Error::msg(
"Dynamic quantization cannot be used with batching. \
This is due to the dynamic quantization process adjusting \
the data range to fit each batch, making the embeddings \
incompatible across batches. Try specifying a batch size \
of `None`, or use a model with static or no quantization.",
))
} else {
Ok(texts.len())
}
} else {
Ok(texts.len())
}
}
_ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
}?;
let batches = texts
.chunks(batch_size)
.map(|batch| {
// Tokenize the whole chunk at once (`true` = add special tokens).
let inputs = batch.iter().map(|text| text.as_ref()).collect();
let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
})?;
// The first encoding's length defines the tensor width; the 2-D
// reshape below assumes the tokenizer padded every encoding in the
// batch to this same length — TODO confirm padding config upstream.
let encoding_length = encodings
.first()
.ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))?
.len();
let batch_size = batch.len();
let max_size = encoding_length * batch_size;
// Flatten ids/mask/type-ids row-major into preallocated buffers,
// widening the tokenizer's integers to the i64 the session expects.
let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut type_ids_array = Vec::with_capacity(max_size);
encodings.iter().for_each(|encoding| {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
let type_ids = encoding.get_type_ids();
ids_array.extend(ids.iter().map(|x| *x as i64));
mask_array.extend(mask.iter().map(|x| *x as i64));
type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
});
let inputs_ids_array =
Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
// Cloned because the mask is both fed to the session and kept in the
// batch output (pooling needs it later).
let attention_mask_array =
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
"attention_mask" => Value::from_array(attention_mask_array.clone())?,
];
// Only feed `token_type_ids` when the model declares that input
// (detected once in `Self::new`).
if self.need_token_type_ids {
session_inputs.push((
"token_type_ids".into(),
Value::from_array(token_type_ids_array)?.into(),
));
}
// Keep every named output; downstream transformers pick the one
// they need by key.
let outputs_map = self
.session
.run(session_inputs)
.map_err(anyhow::Error::new)?
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect();
Ok(SingleBatchOutput {
outputs: outputs_map,
attention_mask_array,
})
})
// Short-circuits on the first failing batch.
.collect::<Result<Vec<_>>>()?;
Ok(EmbeddingOutput::new(batches))
}
/// Convenience wrapper over [`TextEmbedding::transform`] that reduces the raw
/// session outputs to one embedding vector per input text.
///
/// # Errors
/// Propagates any failure from `transform` or from the export step.
pub fn embed<S: AsRef<str> + Send + Sync>(
    &mut self,
    texts: impl AsRef<[S]>,
    batch_size: Option<usize>,
) -> Result<Vec<Embedding>> {
    let batches = self.transform(texts.as_ref(), batch_size)?;
    // A model-specific output key takes precedence; otherwise fall back to
    // the generic output-name precedence list.
    match &self.output_key {
        Some(output_key) => batches.export_with_transformer(
            output::transformer_with_precedence(output_key, self.pooling.clone()),
        ),
        None => batches.export_with_transformer(output::transformer_with_precedence(
            output::OUTPUT_TYPE_PRECEDENCE,
            self.pooling.clone(),
        )),
    }
}
}