open_clip_inference 0.1.0

Run OpenCLIP-compatible models in Rust via ONNX Runtime.
Documentation
use crate::config::{ModelConfig, OpenClipConfig};
use crate::error::ClipError;
use crate::onnx::OnnxSession;
use ndarray::Array2;
use ort::value::Value;
use std::path::Path;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

/// Text encoder for an OpenCLIP-compatible model run through ONNX Runtime.
///
/// Bundles the ONNX session for the text model, the parsed configuration files, and a
/// tokenizer that [`TextEmbedder::new`] configures to pad and truncate every sequence to the
/// model's text context length.
pub struct TextEmbedder {
    /// ONNX Runtime session for the text model (loaded from `text.onnx`).
    pub session: OnnxSession,
    /// Parsed `open_clip_config.json`; `model_cfg.text_cfg.context_length` is used as the
    /// token sequence length.
    pub config: OpenClipConfig,
    /// Parsed `model_config.json`; supplies the optional pad-id override and the
    /// `tokenizer_needs_lowercase` flag.
    pub model_config: ModelConfig,
    /// Tokenizer loaded from `tokenizer.json`, with fixed-length padding and truncation set.
    tokenizer: Tokenizer,
    /// Name of the ONNX input that receives the token ids (matched against `input_ids`).
    id_name: String,
    /// Name of the ONNX attention-mask input, or `None` when no `attention_mask` input was
    /// found; in that case only the token ids are fed to the model.
    mask_name: Option<String>,
}

impl TextEmbedder {
    /// Loads a text embedder for `model_id`.
    ///
    /// The model directory is resolved with [`OnnxSession::get_model_dir`] and then loaded
    /// with [`TextEmbedder::new`].
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`TextEmbedder::new`].
    pub fn from_model_id(model_id: &str) -> Result<Self, ClipError> {
        let model_dir = OnnxSession::get_model_dir(model_id);
        Self::new(&model_dir)
    }

    /// Loads a text embedder from the files in `model_dir`.
    ///
    /// The directory must contain `text.onnx`, `open_clip_config.json`, `tokenizer.json` and
    /// `model_config.json`. The tokenizer is configured to pad (with the model's pad id) and
    /// truncate every sequence to exactly the configured text context length, so that
    /// [`TextEmbedder::tokenize`] always produces rectangular batches.
    ///
    /// The pad id comes from `model_config.json` when it sets one; otherwise the id of the
    /// tokenizer's `<pad>` token is used.
    ///
    /// # Errors
    ///
    /// - `model_dir` fails [`OnnxSession::verify_model_dir`];
    /// - any of the files above cannot be loaded or parsed;
    /// - no pad id is configured and the tokenizer has no `<pad>` token;
    /// - the tokenizer rejects the padding/truncation settings;
    /// - the ONNX model has no `input_ids` input.
    pub fn new(model_dir: &Path) -> Result<Self, ClipError> {
        OnnxSession::verify_model_dir(model_dir)?;
        let model_path = model_dir.join("text.onnx");
        let config_path = model_dir.join("open_clip_config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");
        let model_config_path = model_dir.join("model_config.json");

        let model_config = ModelConfig::from_file(model_config_path)?;
        let session = OnnxSession::new(model_path)?;
        let config = OpenClipConfig::from_file(config_path)?;
        let mut tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| ClipError::Tokenizer(e.to_string()))?;

        let pad_id = model_config
            .pad_id
            // `token_to_id` checks added tokens and the model vocabulary directly, instead of
            // materialising the whole vocabulary as a `HashMap` just to read one entry.
            .or_else(|| tokenizer.token_to_id("<pad>"))
            .ok_or_else(|| ClipError::Config("No pad token found in tokenizer".into()))?;
        let ctx_len = config.model_cfg.text_cfg.context_length;

        // Fixed padding plus truncation at the same length guarantees every encoding has
        // exactly `ctx_len` tokens, which `tokenize` relies on when building its arrays.
        tokenizer
            .with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::Fixed(ctx_len),
                pad_id,
                ..Default::default()
            }))
            .with_truncation(Some(TruncationParams {
                max_length: ctx_len,
                ..Default::default()
            }))
            .map_err(|e| ClipError::Tokenizer(e.to_string()))?;

        let id_name = session
            .find_input(&["input_ids"])
            .ok_or_else(|| ClipError::Config("Could not find text input node".into()))?;
        // Optional: models without an attention-mask input are fed token ids only.
        let mask_name = session.find_input(&["attention_mask"]);

        Ok(Self {
            session,
            config,
            model_config,
            tokenizer,
            id_name,
            mask_name,
        })
    }

    /// Tokenizes `texts` into `(input_ids, attention_mask)` arrays.
    ///
    /// Both arrays have shape `(texts.len(), context_length)`. Texts are lowercased first
    /// when `model_config.tokenizer_needs_lowercase` is set, and special tokens are added by
    /// the tokenizer. An empty `texts` slice yields arrays with zero rows.
    ///
    /// # Errors
    ///
    /// Returns [`ClipError::Tokenizer`] if encoding fails, or [`ClipError::Inference`] if the
    /// encodings do not fill a `(batch, context_length)` array exactly.
    pub fn tokenize<T: AsRef<str>>(
        &self,
        texts: &[T],
    ) -> Result<(Array2<i64>, Array2<i64>), ClipError> {
        let encodings = if self.model_config.tokenizer_needs_lowercase {
            let lowered = texts.iter().map(|s| s.as_ref().to_lowercase()).collect();
            self.tokenizer.encode_batch(lowered, true)
        } else {
            let texts = texts.iter().map(AsRef::as_ref).collect();
            self.tokenizer.encode_batch(texts, true)
        }
        .map_err(|e| ClipError::Tokenizer(e.to_string()))?;

        let batch_size = encodings.len();
        let seq_len = self.config.model_cfg.text_cfg.context_length;

        let ids: Vec<i64> = encodings
            .iter()
            .flat_map(|e| e.get_ids().iter().map(|&x| i64::from(x)))
            .collect();
        let mask: Vec<i64> = encodings
            .iter()
            .flat_map(|e| e.get_attention_mask().iter().map(|&x| i64::from(x)))
            .collect();

        let ids_array = Array2::from_shape_vec((batch_size, seq_len), ids)
            .map_err(|e| ClipError::Inference(e.to_string()))?;
        let mask_array = Array2::from_shape_vec((batch_size, seq_len), mask)
            .map_err(|e| ClipError::Inference(e.to_string()))?;

        Ok((ids_array, mask_array))
    }

    /// Embeds a single text and returns its embedding as a 1-D array.
    ///
    /// # Errors
    ///
    /// Returns any error from [`TextEmbedder::embed_texts`], or [`ClipError::Inference`] if
    /// the model does not return exactly one embedding row for the single input.
    pub fn embed_text(&mut self, text: &str) -> Result<ndarray::Array1<f32>, ClipError> {
        let embs = self.embed_texts(&[text])?;
        // One input must give one row; flattening a multi-row output would silently
        // concatenate unrelated embeddings into a single vector.
        let rows = embs.nrows();
        if rows != 1 {
            return Err(ClipError::Inference(format!(
                "expected 1 embedding row for a single text, got {rows}"
            )));
        }
        Ok(embs.index_axis_move(ndarray::Axis(0), 0))
    }

    /// Embeds a batch of texts with the ONNX text model.
    ///
    /// The texts are tokenized with [`TextEmbedder::tokenize`]; the attention mask is passed
    /// only when the model exposes an attention-mask input. The model's first output is
    /// returned as a 2-D `f32` array (presumably `(texts.len(), embed_dim)` — depends on the
    /// exported model).
    ///
    /// # Errors
    ///
    /// Returns tokenization errors, ONNX Runtime errors, or [`ClipError::Inference`] if the
    /// first output has a negative dimension or is not rank 2.
    pub fn embed_texts<T: AsRef<str>>(&mut self, texts: &[T]) -> Result<Array2<f32>, ClipError> {
        let (ids_tensor, mask_tensor) = self.tokenize(texts)?;

        let ort_ids = Value::from_array(ids_tensor)?;
        let outputs = if let Some(m_name) = &self.mask_name {
            let ort_mask = Value::from_array(mask_tensor)?;
            self.session
                .session
                .run(ort::inputs![&self.id_name => ort_ids, m_name => ort_mask])?
        } else {
            self.session
                .session
                .run(ort::inputs![&self.id_name => ort_ids])?
        };

        let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
        // ONNX reports dimensions as i64. An `as` cast would wrap a negative value to a huge
        // usize, so convert with a check and surface failures as an inference error.
        let dims = shape
            .iter()
            .map(|&d| usize::try_from(d))
            .collect::<Result<Vec<usize>, _>>()
            .map_err(|e| ClipError::Inference(format!("invalid output dimension: {e}")))?;
        let view = ndarray::ArrayView::from_shape(ndarray::IxDyn(&dims), data)
            .map_err(|e| ClipError::Inference(e.to_string()))?;
        Ok(view
            .into_dimensionality::<ndarray::Ix2>()
            .map_err(|e| ClipError::Inference(e.to_string()))?
            .to_owned())
    }
}