use crate::config::{ModelConfig, OpenClipConfig};
use crate::error::ClipError;
use crate::model_manager;
use crate::model_manager::get_default_base_folder;
use crate::onnx::OnnxSession;
use bon::bon;
use ndarray::Array2;
use ort::ep::ExecutionProviderDispatch;
use ort::value::Value;
use std::path::Path;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// Embeds text into CLIP's joint embedding space using an ONNX text encoder.
///
/// Construct via the builder entry points on the `impl` block
/// (`from_hf` / `from_local_id` / `from_local_dir`); the tokenizer is
/// configured at load time to pad and truncate to the model's context length.
#[derive(Debug)]
pub struct TextEmbedder {
// ONNX Runtime session wrapping the text encoder graph.
pub session: OnnxSession,
// Parsed `open_clip_config.json` (provides the text context length).
pub config: OpenClipConfig,
// Parsed `model_config.json` (pad id, lowercasing flag, ...).
pub model_config: ModelConfig,
// HF tokenizer, configured with fixed padding + truncation at load time.
tokenizer: Tokenizer,
// Name of the token-id input node of the ONNX graph.
id_name: String,
// Name of the attention-mask input node, if the exported graph has one.
mask_name: Option<String>,
}
#[bon]
impl TextEmbedder {
    /// Loads a text embedder for `model_id` from the Hugging Face Hub,
    /// downloading (or reusing a cached copy of) the model files.
    ///
    /// `with_execution_providers` optionally selects the ONNX Runtime
    /// execution providers; when `None`, ort's defaults are used.
    ///
    /// # Errors
    /// Returns [`ClipError`] if the download fails or the subsequent local
    /// load fails (see [`Self::from_local_dir`]).
    #[builder(finish_fn = build)]
    #[cfg(feature = "hf-hub")]
    pub async fn from_hf(
        #[builder(start_fn)] model_id: &str,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        let model_dir = model_manager::get_hf_model(model_id).await?;
        Self::from_local_dir(&model_dir)
            .maybe_with_execution_providers(with_execution_providers)
            .build()
    }

    /// Loads a text embedder from `<base_folder>/<model_id>` on disk.
    /// When `base_folder` is `None`, the crate's default base folder is used.
    ///
    /// # Errors
    /// See [`Self::from_local_dir`].
    #[builder(finish_fn = build)]
    pub fn from_local_id(
        #[builder(start_fn)] model_id: &str,
        base_folder: Option<&Path>,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        let base_folder = base_folder.map_or_else(get_default_base_folder, ToOwned::to_owned);
        Self::from_local_dir(&base_folder.join(model_id))
            .maybe_with_execution_providers(with_execution_providers)
            .build()
    }

    /// Loads a text embedder from a local model directory containing
    /// `text.onnx`, `open_clip_config.json`, `tokenizer.json`, and
    /// `model_config.json`.
    ///
    /// The tokenizer is configured with fixed-length padding and truncation
    /// at the model's context length, so every encoded row has exactly that
    /// many tokens — [`Self::tokenize`] relies on this when reshaping.
    ///
    /// # Errors
    /// Returns [`ClipError`] if the directory fails verification, any file is
    /// missing or malformed, the ONNX session cannot be created, no pad token
    /// can be determined, or the `input_ids` input node is absent.
    #[builder(finish_fn = build)]
    pub fn from_local_dir(
        #[builder(start_fn)] model_dir: &Path,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        model_manager::verify_model_dir(model_dir)?;
        let model_path = model_dir.join("text.onnx");
        let config_path = model_dir.join("open_clip_config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");
        let model_config_path = model_dir.join("model_config.json");
        let execution_providers = with_execution_providers.unwrap_or_default();
        let model_config = ModelConfig::from_file(model_config_path)?;
        let session = OnnxSession::new(model_path, execution_providers)?;
        let config = OpenClipConfig::from_file(config_path)?;
        let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;
        // Prefer an explicitly configured pad id; otherwise fall back to the
        // tokenizer's "<pad>" token. `token_to_id` performs a single lookup
        // instead of materializing the whole vocab map as `get_vocab` would.
        let pad_id = model_config
            .pad_id
            .or_else(|| tokenizer.token_to_id("<pad>"))
            .ok_or_else(|| ClipError::Config("No pad token found in tokenizer".into()))?;
        let ctx_len = config.model_cfg.text_cfg.context_length;
        tokenizer
            .with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::Fixed(ctx_len),
                pad_id,
                ..Default::default()
            }))
            .with_truncation(Some(TruncationParams {
                max_length: ctx_len,
                ..Default::default()
            }))?;
        let id_name = session
            .find_input(&["input_ids"])?
            .ok_or_else(|| ClipError::Config("Could not find text input node".into()))?;
        // The attention-mask input is optional: some exported text encoders
        // take only token ids, so a missing node is not an error.
        let mask_name = session.find_input(&["attention_mask"])?;
        Ok(Self {
            session,
            config,
            model_config,
            tokenizer,
            id_name,
            mask_name,
        })
    }

    /// Tokenizes `texts` into `(input_ids, attention_mask)` arrays of shape
    /// `(batch, context_length)`, ready to feed to the ONNX text encoder.
    ///
    /// # Errors
    /// Returns [`ClipError`] if tokenization fails or the encodings do not
    /// fill the expected `(batch, context_length)` shape.
    pub fn tokenize<T: AsRef<str>>(
        &self,
        texts: &[T],
    ) -> Result<(Array2<i64>, Array2<i64>), ClipError> {
        // Some checkpoints expect lowercased input text.
        let encodings = if self.model_config.tokenizer_needs_lowercase {
            let lowered = texts.iter().map(|s| s.as_ref().to_lowercase()).collect();
            self.tokenizer.encode_batch(lowered, true)
        } else {
            let texts = texts.iter().map(AsRef::as_ref).collect();
            self.tokenizer.encode_batch(texts, true)
        }?;
        let batch_size = encodings.len();
        let seq_len = self.config.model_cfg.text_cfg.context_length;
        // Preallocate the flat row-major buffers up front: `flat_map` drops
        // the iterator size hint, so `collect` would reallocate repeatedly.
        // Fixed padding/truncation (set at load time) make every encoding
        // exactly `seq_len` long, so the final shapes are guaranteed.
        let mut ids: Vec<i64> = Vec::with_capacity(batch_size * seq_len);
        let mut mask: Vec<i64> = Vec::with_capacity(batch_size * seq_len);
        for encoding in &encodings {
            ids.extend(encoding.get_ids().iter().map(|&x| i64::from(x)));
            mask.extend(encoding.get_attention_mask().iter().map(|&x| i64::from(x)));
        }
        let ids_array = Array2::from_shape_vec((batch_size, seq_len), ids)?;
        let mask_array = Array2::from_shape_vec((batch_size, seq_len), mask)?;
        Ok((ids_array, mask_array))
    }

    /// Embeds a single text, returning a 1-D embedding vector.
    ///
    /// Convenience wrapper over [`Self::embed_texts`] that flattens the
    /// `(1, dim)` result to `(dim,)`.
    ///
    /// # Errors
    /// See [`Self::embed_texts`].
    pub fn embed_text(&self, text: &str) -> Result<ndarray::Array1<f32>, ClipError> {
        let embs = self.embed_texts(&[text])?;
        let len = embs.len();
        Ok(embs.into_shape_with_order(len)?)
    }

    /// Embeds a batch of texts, returning a `(batch, dim)` array of `f32`
    /// embeddings (not normalized here).
    ///
    /// The attention mask is only fed to the model when the exported graph
    /// declares an `attention_mask` input node.
    ///
    /// # Errors
    /// Returns [`ClipError`] if tokenization, the session lock, inference,
    /// or output extraction fails.
    #[allow(clippy::significant_drop_tightening)]
    pub fn embed_texts<T: AsRef<str>>(&self, texts: &[T]) -> Result<Array2<f32>, ClipError> {
        let (ids_tensor, mask_tensor) = self.tokenize(texts)?;
        let ort_ids = Value::from_array(ids_tensor)?;
        // Scope the write lock so it is released before returning.
        let array = {
            let mut session = self.session.session.write()?;
            let outputs = if let Some(m_name) = &self.mask_name {
                let ort_mask = Value::from_array(mask_tensor)?;
                session.run(ort::inputs![&self.id_name => ort_ids, m_name => ort_mask])?
            } else {
                session.run(ort::inputs![&self.id_name => ort_ids])?
            };
            // Copy the output tensor out of ort-owned memory into an owned
            // ndarray before the session lock (and `outputs`) are dropped.
            let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let shape_usize: Vec<usize> = shape.iter().map(|&x| x as usize).collect();
            let view = ndarray::ArrayView::from_shape(ndarray::IxDyn(&shape_usize), data)?;
            view.into_dimensionality::<ndarray::Ix2>()?.to_owned()
        };
        Ok(array)
    }
}