use std::path::{Path, PathBuf};
use std::sync::Arc;
use ndarray::{Array1, Array2, Axis};
use ort::environment::Environment;
use ort::GraphOptimizationLevel;
use crate::common::Device;
use crate::error::Result;
use crate::hf_hub::hf_hub_download;
use crate::tokenizer::AutoTokenizer;
use crate::{Embedding, EmbeddingModel, PoolingStrategy};
/// Text-embedding pipeline that pairs a tokenizer with an [`EmbeddingModel`].
///
/// The lifetime `'a` is carried by the model; [`EmbeddingPipeline::new_from_memory`]
/// ties it to the borrowed ONNX model bytes.
pub struct EmbeddingPipeline<'a> {
/// Tokenizer that turns input text into token ids, attention masks and type ids.
tokenizer: AutoTokenizer,
/// Model that maps a tokenized batch to one embedding per sequence.
model: EmbeddingModel<'a>,
}
impl<'a> EmbeddingPipeline<'a> {
pub fn from_pretrained(
env: Arc<Environment>,
model_id: String,
pool_strategy: PoolingStrategy,
device: Device,
optimization_level: GraphOptimizationLevel,
) -> Result<Self> {
let model_dir = Path::new(&model_id);
if model_dir.exists() {
let model_path = model_dir.join("model.onnx");
let tokenizer_path = model_dir.join("tokenizer.json");
let mut special_tokens_path = model_dir.join("special_tokens_map.json");
if !special_tokens_path.exists() {
special_tokens_path = model_dir.join("config.json");
}
Self::new_from_files(
env,
model_path,
tokenizer_path,
special_tokens_path,
pool_strategy,
device,
optimization_level,
)
} else {
let model_path = hf_hub_download(&model_id, "model.onnx", None, None)?;
let tokenizer_path = hf_hub_download(&model_id, "tokenizer.json", None, None)?;
let mut special_tokens_path =
hf_hub_download(&model_id, "special_tokens_map.json", None, None);
if special_tokens_path.is_err() {
special_tokens_path = hf_hub_download(&model_id, "config.json", None, None);
}
Self::new_from_files(
env,
model_path,
tokenizer_path,
special_tokens_path?,
pool_strategy,
device,
optimization_level,
)
}
}
pub fn new_from_files(
environment: Arc<Environment>,
model_path: PathBuf,
tokenizer_config: PathBuf,
special_tokens_map: PathBuf,
pooling: PoolingStrategy,
device: Device,
optimization_level: GraphOptimizationLevel,
) -> Result<Self> {
let tokenizer = AutoTokenizer::new(tokenizer_config, special_tokens_map)?;
let model = EmbeddingModel::new_from_file(
environment,
model_path,
pooling,
device,
optimization_level,
)?;
Ok(Self { tokenizer, model })
}
pub fn new_from_memory(
environment: Arc<Environment>,
model: &'a [u8],
tokenizer_config: String,
special_tokens_map: String,
pooling: PoolingStrategy,
device: Device,
optimization_level: GraphOptimizationLevel,
) -> Result<Self> {
let tokenizer = AutoTokenizer::new_from_memory(tokenizer_config, special_tokens_map)?;
let model = EmbeddingModel::new_from_memory(
environment,
model,
pooling,
device,
optimization_level,
)?;
Ok(Self { tokenizer, model })
}
pub fn embed(&self, input: &str) -> Result<Embedding> {
let tokenized = self.tokenizer.tokenizer.encode(input, false)?;
let input_ids = Array1::from_iter(tokenized.get_ids().iter().map(|i| *i as u32));
let input_ids = input_ids.insert_axis(Axis(0));
let attention_mask =
Array1::from_iter(tokenized.get_attention_mask().iter().map(|i| *i as u32));
let attention_mask = attention_mask.insert_axis(Axis(0));
let token_type_ids = Array1::from_iter(tokenized.get_type_ids().iter().map(|i| *i as u32));
let token_type_ids = token_type_ids.insert_axis(Axis(0));
let mut output =
self.model
.forward(input_ids, Some(attention_mask), Some(token_type_ids))?;
Ok(output.pop().unwrap())
}
pub fn embed_batch(&self, inputs: Vec<String>) -> Result<Vec<Embedding>> {
let tokenized = self.tokenizer.tokenizer.encode_batch(inputs, false)?;
let input_ids = tokenized.iter().map(|t| t.get_ids()).collect::<Vec<_>>();
let input_ids =
Array2::from_shape_vec((input_ids.len(), input_ids[0].len()), input_ids.concat())?;
let attention_mask = tokenized
.iter()
.map(|t| t.get_attention_mask())
.collect::<Vec<_>>();
let attention_mask = Array2::from_shape_vec(
(attention_mask.len(), attention_mask[0].len()),
attention_mask.concat(),
)?;
let token_type_ids = tokenized
.iter()
.map(|t| t.get_type_ids())
.collect::<Vec<_>>();
let token_type_ids = Array2::from_shape_vec(
(token_type_ids.len(), token_type_ids[0].len()),
token_type_ids.concat(),
)?;
let output = self
.model
.forward(input_ids, Some(attention_mask), Some(token_type_ids))?;
Ok(output)
}
}
#[cfg(test)]
mod tests {
    use ort::LoggingLevel;
    use super::*;

    /// End-to-end check that single and batched embedding agree on identical text.
    ///
    /// Needs network access to fetch the model through `hf_hub_download`, unless a
    /// local directory named after the model id exists.
    #[test]
    fn test_embedding_pipeline() {
        let environment = Environment::builder()
            .with_name("embedding_pipeline")
            .with_log_level(LoggingLevel::Verbose)
            .build()
            .unwrap();
        let pipeline = EmbeddingPipeline::from_pretrained(
            environment.into_arc(),
            "optimum/all-MiniLM-L6-v2".to_string(),
            PoolingStrategy::Mean,
            Device::CPU,
            GraphOptimizationLevel::Level3,
        )
        .unwrap();
        let input = "This is a test";
        let input1 = "This is a test";
        let embedding = pipeline.embed(input).unwrap();
        let embeddings = pipeline
            .embed_batch(vec![input.to_string(), input1.to_string()])
            .unwrap();
        assert_eq!(embeddings.len(), 2);
        let sim1 = embedding.similarity(&embeddings[0]);
        let sim2 = embedding.similarity(&embeddings[1]);
        // Every input is the same text, so scores sit at the top of the range. Strict
        // bounds fail on a score of exactly 1.0 (or one nudged past it by float
        // rounding), so check the range with a small tolerance instead.
        assert!(
            sim1 >= -1.0 - 1e-4 && sim1 <= 1.0 + 1e-4,
            "similarity out of range: {}",
            sim1
        );
        assert!(
            sim2 >= -1.0 - 1e-4 && sim2 <= 1.0 + 1e-4,
            "similarity out of range: {}",
            sim2
        );
        // Both batch rows encode the same text and should score alike.
        assert!(
            (sim1 - sim2).abs() < 1e-4,
            "batch rows disagree: {} vs {}",
            sim1,
            sim2
        );
    }
}