use crate::{
common::{
load_tokenizer, load_tokenizer_hf_hub, normalize, Tokenizer, TokenizerFiles,
DEFAULT_CACHE_DIR,
},
models::text_embedding::models_list,
Embedding, EmbeddingModel, ModelInfo,
};
use anyhow::Result;
use hf_hub::{
api::sync::{ApiBuilder, ApiRepo},
Cache,
};
use ndarray::{s, Array};
use ort::{ExecutionProviderDispatch, GraphOptimizationLevel, Session, Value};
use rayon::{iter::ParallelIterator, slice::ParallelSlice};
use std::{
fmt::Display,
path::{Path, PathBuf},
thread::available_parallelism,
};
/// Number of texts sent through one ONNX inference call when the caller
/// does not specify a batch size in `TextEmbedding::embed`.
const DEFAULT_BATCH_SIZE: usize = 256;
/// Default maximum token length handed to the tokenizer loaders
/// (`load_tokenizer` / `load_tokenizer_hf_hub`).
const DEFAULT_MAX_LENGTH: usize = 512;
/// Model selected by `InitOptions::default()`.
const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallENV15;
/// Options for initializing a [`TextEmbedding`] whose model files are
/// fetched from the Hugging Face Hub (or read from the local cache).
#[derive(Debug, Clone)]
pub struct InitOptions {
/// Which supported embedding model to load.
pub model_name: EmbeddingModel,
/// ONNX Runtime execution providers to register on the session.
pub execution_providers: Vec<ExecutionProviderDispatch>,
/// Maximum token length passed to the tokenizer loader.
pub max_length: usize,
/// Directory used as the hf-hub download cache.
pub cache_dir: PathBuf,
/// Whether to show a progress bar while downloading model files.
pub show_download_progress: bool,
}
impl Default for InitOptions {
fn default() -> Self {
Self {
model_name: DEFAULT_EMBEDDING_MODEL,
execution_providers: Default::default(),
max_length: DEFAULT_MAX_LENGTH,
cache_dir: Path::new(DEFAULT_CACHE_DIR).to_path_buf(),
show_download_progress: true,
}
}
}
/// Options for initializing a [`TextEmbedding`] from user-supplied,
/// in-memory model files (no Hub download, so no cache/progress options).
#[derive(Debug, Clone)]
pub struct InitOptionsUserDefined {
/// ONNX Runtime execution providers to register on the session.
pub execution_providers: Vec<ExecutionProviderDispatch>,
/// Maximum token length passed to the tokenizer loader.
pub max_length: usize,
}
impl Default for InitOptionsUserDefined {
fn default() -> Self {
Self {
execution_providers: Default::default(),
max_length: DEFAULT_MAX_LENGTH,
}
}
}
impl From<InitOptions> for InitOptionsUserDefined {
fn from(options: InitOptions) -> Self {
InitOptionsUserDefined {
execution_providers: options.execution_providers,
max_length: options.max_length,
}
}
}
/// Model data supplied directly by the caller instead of fetched from the
/// Hub: raw ONNX bytes plus the matching tokenizer files.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserDefinedEmbeddingModel {
/// Contents of the `.onnx` model file.
pub onnx_file: Vec<u8>,
/// Tokenizer configuration files for this model.
pub tokenizer_files: TokenizerFiles,
}
/// A text-embedding pipeline: a tokenizer plus an ONNX Runtime session.
pub struct TextEmbedding {
/// Tokenizer used to encode input texts before inference.
pub tokenizer: Tokenizer,
// ONNX Runtime session that executes the embedding model.
session: Session,
// Whether the loaded graph declares a `token_type_ids` input
// (detected in `TextEmbedding::new`).
need_token_type_ids: bool,
}
impl Display for EmbeddingModel {
    /// Formats the model as its model code (the Hub repository id),
    /// looked up from the supported-models list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every `EmbeddingModel` variant must appear in the supported-models
        // list; a miss here is a crate-internal inconsistency, so state the
        // invariant instead of a bare unwrap.
        let model_info = TextEmbedding::list_supported_models()
            .into_iter()
            .find(|model| model.model == *self)
            .expect("EmbeddingModel missing from the supported models list");
        write!(f, "{}", model_info.model_code)
    }
}
impl TextEmbedding {
pub fn try_new(options: InitOptions) -> Result<Self> {
let InitOptions {
model_name,
execution_providers,
max_length,
cache_dir,
show_download_progress,
} = options;
let threads = available_parallelism()?.get();
let model_repo = TextEmbedding::retrieve_model(
model_name.clone(),
cache_dir.clone(),
show_download_progress,
)?;
let model_file_name = TextEmbedding::get_model_info(&model_name).model_file;
let model_file_reference = model_repo
.get(&model_file_name)
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
if model_name == EmbeddingModel::MultilingualE5Large {
model_repo
.get("model.onnx_data")
.expect("Failed to retrieve model.onnx_data.");
}
let session = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?
.commit_from_file(model_file_reference)?;
let tokenizer = load_tokenizer_hf_hub(model_repo, max_length)?;
Ok(Self::new(tokenizer, session))
}
pub fn try_new_from_user_defined(
model: UserDefinedEmbeddingModel,
options: InitOptionsUserDefined,
) -> Result<Self> {
let InitOptionsUserDefined {
execution_providers,
max_length,
} = options;
let threads = available_parallelism()?.get();
let session = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?
.commit_from_memory(&model.onnx_file)?;
let tokenizer = load_tokenizer(model.tokenizer_files, max_length)?;
Ok(Self::new(tokenizer, session))
}
fn new(tokenizer: Tokenizer, session: Session) -> Self {
let need_token_type_ids = session
.inputs
.iter()
.any(|input| input.name == "token_type_ids");
Self {
tokenizer,
session,
need_token_type_ids,
}
}
fn retrieve_model(
model: EmbeddingModel,
cache_dir: PathBuf,
show_download_progress: bool,
) -> Result<ApiRepo> {
let cache = Cache::new(cache_dir);
let api = ApiBuilder::from_cache(cache)
.with_progress(show_download_progress)
.build()
.unwrap();
let repo = api.model(model.to_string());
Ok(repo)
}
pub fn list_supported_models() -> Vec<ModelInfo> {
models_list()
}
pub fn get_model_info(model: &EmbeddingModel) -> ModelInfo {
TextEmbedding::list_supported_models()
.into_iter()
.find(|m| &m.model == model)
.expect("Model not found.")
}
pub fn embed<S: AsRef<str> + Send + Sync>(
&self,
texts: Vec<S>,
batch_size: Option<usize>,
) -> Result<Vec<Embedding>> {
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let output = texts
.par_chunks(batch_size)
.map(|batch| {
let inputs = batch.iter().map(|text| text.as_ref()).collect();
let encodings = self.tokenizer.encode_batch(inputs, true).unwrap();
let encoding_length = encodings[0].len();
let batch_size = batch.len();
let max_size = encoding_length * batch_size;
let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut typeids_array = Vec::with_capacity(max_size);
encodings.iter().for_each(|encoding| {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
let typeids = encoding.get_type_ids();
ids_array.extend(ids.iter().map(|x| *x as i64));
mask_array.extend(mask.iter().map(|x| *x as i64));
typeids_array.extend(typeids.iter().map(|x| *x as i64));
});
let inputs_ids_array =
Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
let attention_mask_array =
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
"attention_mask" => Value::from_array(attention_mask_array)?,
]?;
if self.need_token_type_ids {
session_inputs.push((
"token_type_ids".into(),
Value::from_array(token_type_ids_array)?.into(),
));
}
let outputs = self.session.run(session_inputs)?;
let last_hidden_state_key = match outputs.len() {
1 => outputs.keys().next().unwrap(),
_ => "last_hidden_state",
};
let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;
let embeddings: Vec<Vec<f32>> = output_data
.slice(s![.., 0, ..])
.rows()
.into_iter()
.map(|row| normalize(row.as_slice().unwrap()))
.collect();
Ok(embeddings)
})
.flat_map(|result: Result<Vec<Vec<f32>>, anyhow::Error>| result.unwrap())
.collect();
Ok(output)
}
}