open_clip_inference 0.4.0

Run OpenCLIP-compatible embedding models via ONNX Runtime
Documentation
use crate::config::{ModelConfig, OpenClipConfig};
use crate::error::ClipError;
use crate::model_manager;
use crate::model_manager::get_default_base_folder;
use crate::onnx::OnnxSession;
use bon::bon;
use ndarray::Array2;
use ort::ep::ExecutionProviderDispatch;
use ort::value::Value;
use std::path::{Path, PathBuf};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

/// Embeds text using an OpenCLIP text encoder exported to ONNX (`text.onnx`).
#[derive(Debug)]
pub struct TextEmbedder {
    /// ONNX Runtime session for the model directory's `text.onnx` graph.
    pub session: OnnxSession,
    /// Parsed `open_clip_config.json`; its text config supplies the context length.
    pub config: OpenClipConfig,
    /// Parsed `model_config.json` (e.g. optional pad id, lowercasing flag).
    pub model_config: ModelConfig,
    /// Directory the model files were loaded from; `duplicate` reloads from here.
    pub model_dir: PathBuf,
    /// Tokenizer from `tokenizer.json`, configured to pad and truncate to the context length.
    tokenizer: Tokenizer,
    /// Name of the ONNX graph input that receives token ids.
    id_name: String,
    /// Name of the attention-mask graph input, if the model declares one.
    mask_name: Option<String>,
}

#[bon]
impl TextEmbedder {
    /// Load a text embedder from a `HuggingFace` model ID.
    ///
    /// The model is fetched with `model_manager::get_hf_model` and then loaded via
    /// [`TextEmbedder::from_local_dir`]. Only available with the `hf-hub` feature.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be fetched, or if loading it from the
    /// downloaded directory fails.
    #[builder(finish_fn = build)]
    #[cfg(feature = "hf-hub")]
    pub async fn from_hf(
        #[builder(start_fn)] model_id: &str,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        let model_dir = model_manager::get_hf_model(model_id).await?;
        Self::from_local_dir(&model_dir)
            .maybe_with_execution_providers(with_execution_providers)
            .build()
    }

    /// Load a text embedder from a locally converted model ID.
    ///
    /// The model directory is `base_folder.join(model_id)`. When no `base_folder` is
    /// given, the folder returned by `get_default_base_folder` is used.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TextEmbedder::from_local_dir`].
    #[builder(finish_fn = build)]
    pub fn from_local_id(
        #[builder(start_fn)] model_id: &str,
        base_folder: Option<&Path>,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        let base_folder = base_folder.map_or_else(get_default_base_folder, ToOwned::to_owned);
        Self::from_local_dir(&base_folder.join(model_id))
            .maybe_with_execution_providers(with_execution_providers)
            .build()
    }

    /// Load a text embedder from a specific model directory.
    ///
    /// The directory is expected to contain `text.onnx`, `open_clip_config.json`,
    /// `tokenizer.json` and `model_config.json`. The tokenizer is configured to pad
    /// and truncate every sequence to the model's text context length, so all
    /// tokenized batches have a fixed sequence dimension.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory fails verification, any of the files cannot
    /// be loaded, no pad token id can be determined, or the ONNX graph has no
    /// `input_ids` input.
    #[builder(finish_fn = build)]
    pub fn from_local_dir(
        #[builder(start_fn)] model_dir: &Path,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        model_manager::verify_model_dir(model_dir)?;
        let model_path = model_dir.join("text.onnx");
        let config_path = model_dir.join("open_clip_config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");
        let model_config_path = model_dir.join("model_config.json");
        let execution_providers = with_execution_providers.unwrap_or_default();

        let model_config = ModelConfig::from_file(model_config_path)?;
        let session = OnnxSession::new(model_path, execution_providers)?;
        let config = OpenClipConfig::from_file(config_path)?;
        let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;

        // Prefer an explicit pad id from model_config.json. Otherwise look up `<pad>`
        // directly. `token_to_id` checks added tokens and then the model vocab, the
        // same as `get_vocab(true)`, without building the whole vocabulary map.
        let pad_id = model_config
            .pad_id
            .or_else(|| tokenizer.token_to_id("<pad>"))
            .ok_or_else(|| ClipError::Config("No pad token found in tokenizer".into()))?;
        let ctx_len = config.model_cfg.text_cfg.context_length;

        tokenizer
            .with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::Fixed(ctx_len),
                pad_id,
                ..Default::default()
            }))
            .with_truncation(Some(TruncationParams {
                max_length: ctx_len,
                ..Default::default()
            }))?;

        let id_name = session
            .find_input(&["input_ids"])?
            .ok_or_else(|| ClipError::Config("Could not find text input node".into()))?;
        // The attention mask is optional: some exported graphs only take token ids.
        let mask_name = session.find_input(&["attention_mask"])?;

        Ok(Self {
            session,
            config,
            model_config,
            tokenizer,
            id_name,
            mask_name,
            model_dir: model_dir.to_path_buf(),
        })
    }

    /// Create an independent instance of the model.
    ///
    /// Reloads everything from `self.model_dir` using the same execution providers
    /// as the current session.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TextEmbedder::from_local_dir`].
    pub fn duplicate(&self) -> Result<Self, ClipError> {
        Self::from_local_dir(&self.model_dir)
            .with_execution_providers(&self.session.execution_providers)
            .build()
    }

    /// Tokenize a batch of texts.
    ///
    /// Texts are lowercased first when `model_config.tokenizer_needs_lowercase` is set.
    /// Returns `(input_ids, attention_mask)`, each shaped
    /// `(texts.len(), context_length)`.
    ///
    /// # Errors
    ///
    /// Returns an error if tokenization fails, or if the encodings do not fill a
    /// `(batch, context_length)` array.
    pub fn tokenize<T: AsRef<str>>(
        &self,
        texts: &[T],
    ) -> Result<(Array2<i64>, Array2<i64>), ClipError> {
        let encodings = if self.model_config.tokenizer_needs_lowercase {
            let lowered = texts.iter().map(|s| s.as_ref().to_lowercase()).collect();
            self.tokenizer.encode_batch(lowered, true)
        } else {
            let texts = texts.iter().map(AsRef::as_ref).collect();
            self.tokenizer.encode_batch(texts, true)
        }?;

        let batch_size = encodings.len();
        let seq_len = self.config.model_cfg.text_cfg.context_length;

        let ids: Vec<i64> = encodings
            .iter()
            .flat_map(|e| e.get_ids().iter().map(|&x| i64::from(x)))
            .collect();
        let mask: Vec<i64> = encodings
            .iter()
            .flat_map(|e| e.get_attention_mask().iter().map(|&x| i64::from(x)))
            .collect();

        let ids_array = Array2::from_shape_vec((batch_size, seq_len), ids)?;
        let mask_array = Array2::from_shape_vec((batch_size, seq_len), mask)?;

        Ok((ids_array, mask_array))
    }

    /// Embed a single text, returning its embedding as a 1-D array.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TextEmbedder::embed_texts`].
    pub fn embed_text(&self, text: &str) -> Result<ndarray::Array1<f32>, ClipError> {
        let embs = self.embed_texts(&[text])?;
        let len = embs.len();
        Ok(embs.into_shape_with_order(len)?)
    }

    /// Embed a batch of texts.
    ///
    /// Returns the model's first output as a 2-D array, one row per input text.
    ///
    /// # Errors
    ///
    /// Returns an error if tokenization, tensor conversion, acquiring the session
    /// lock, or inference fails, or if the first output is not a 2-D `f32` tensor.
    // The extracted view borrows from `outputs`, which comes from the locked session
    // (ort 2.x ties session outputs to the session borrow). The guard therefore has to
    // live until the data is copied into an owned array.
    #[allow(clippy::significant_drop_tightening)]
    pub fn embed_texts<T: AsRef<str>>(&self, texts: &[T]) -> Result<Array2<f32>, ClipError> {
        let (ids_tensor, mask_tensor) = self.tokenize(texts)?;

        // Build every input tensor before taking the session write lock, so the lock
        // is held only for inference and output extraction.
        let ort_ids = Value::from_array(ids_tensor)?;
        let mask_input = self
            .mask_name
            .as_ref()
            .map(|m_name| Value::from_array(mask_tensor).map(|ort_mask| (m_name, ort_mask)))
            .transpose()?;

        let array = {
            let mut session = self.session.session.write()?;
            let outputs = if let Some((m_name, ort_mask)) = mask_input {
                session.run(ort::inputs![&self.id_name => ort_ids, m_name => ort_mask])?
            } else {
                session.run(ort::inputs![&self.id_name => ort_ids])?
            };
            let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
            // ONNX Runtime reports dimensions as i64. Dimensions of a concrete output
            // tensor are non-negative, so the cast does not lose information in practice.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let shape_usize: Vec<usize> = shape.iter().map(|&x| x as usize).collect();
            let view = ndarray::ArrayView::from_shape(ndarray::IxDyn(&shape_usize), data)?;
            view.into_dimensionality::<ndarray::Ix2>()?.to_owned()
        };
        Ok(array)
    }
}