omni_search 0.2.6

A unified Rust SDK for multimodal embedding and similarity search.
Documentation
use std::path::{Path, PathBuf};

use image::DynamicImage;
use ndarray::{Array, ArrayD, IxDyn};
use ort::value::TensorRef;
use tokenizers::{Encoding, Tokenizer};

use crate::backend::{
    EmbeddingBackend, LazySession, embeddings_from_output, load_tokenizer,
    runtime_snapshot_for_sessions, single_embedding,
};
use crate::bundle::{ModelBundle, ModelInfo};
use crate::config::{RuntimeConfig, SessionPolicy};
use crate::embedding::Embedding;
use crate::error::Error;
use crate::manifest::{ImagePreprocessConfig, TextInputConfig};
use crate::preprocess::clip_image::{
    ClipImagePreprocessConfig, preprocess_image, stack_image_batches,
};
use crate::preprocess::load_image_from_path;
use crate::runtime::{RuntimeSnapshot, RuntimeState};

pub(crate) struct ChineseClipBackend {
    info: ModelInfo,
    runtime: RuntimeConfig,
    normalize_output: bool,
    session_policy: SessionPolicy,
    tokenizer: Tokenizer,
    context_length: usize,
    text_output_name: String,
    input_ids_name: String,
    attention_mask_name: String,
    token_type_ids_name: Option<String>,
    image_output_name: String,
    image_preprocess: ClipImagePreprocessConfig,
    text_session: LazySession,
    image_session: LazySession,
}

impl ChineseClipBackend {
    pub(crate) fn new(bundle: ModelBundle, runtime: RuntimeConfig) -> Result<Self, Error> {
        let tokenizer = load_tokenizer(
            bundle.tokenizer_path(),
            bundle.manifest().text.context_length,
            "[PAD]",
        )?;
        let (input_ids_name, attention_mask_name, token_type_ids_name) =
            match &bundle.manifest().text.input {
                TextInputConfig::BertLike {
                    input_ids_name,
                    attention_mask_name,
                    token_type_ids_name,
                } => (
                    input_ids_name.clone(),
                    attention_mask_name.clone(),
                    token_type_ids_name.clone(),
                ),
                _ => {
                    return Err(Error::invalid_bundle(
                        "chinese clip bundle must use bert_like text input",
                    ));
                }
            };
        let image_preprocess = match &bundle.manifest().image.preprocess {
            ImagePreprocessConfig::ClipImage {
                image_size,
                resize_shortest_edge,
                crop,
                mean,
                std,
            } => ClipImagePreprocessConfig {
                image_size: *image_size,
                resize_shortest_edge: *resize_shortest_edge,
                crop: *crop,
                mean: [mean[0], mean[1], mean[2]],
                std: [std[0], std[1], std[2]],
            },
            _ => {
                return Err(Error::invalid_bundle(
                    "chinese clip bundle must use clip_image preprocess",
                ));
            }
        };
        Ok(Self {
            info: bundle.info().clone(),
            runtime: runtime.clone(),
            normalize_output: bundle.info().normalize_output,
            session_policy: runtime.session_policy,
            tokenizer,
            context_length: bundle.manifest().text.context_length,
            text_output_name: bundle.manifest().text.output_name.clone(),
            input_ids_name,
            attention_mask_name,
            token_type_ids_name,
            image_output_name: bundle.manifest().image.output_name.clone(),
            image_preprocess,
            text_session: LazySession::new(bundle.text_onnx_path().to_path_buf(), runtime.clone()),
            image_session: LazySession::new(bundle.image_onnx_path().to_path_buf(), runtime),
        })
    }

    fn maybe_unload_image(&self) {
        if self.session_policy == SessionPolicy::SingleActive {
            self.image_session.unload();
        }
    }

    fn maybe_unload_text(&self) {
        if self.session_policy == SessionPolicy::SingleActive {
            self.text_session.unload();
        }
    }

    fn encode_texts_internal(&self, texts: &[String]) -> Result<Vec<Embedding>, Error> {
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(Error::from_tokenizer)?;

        let input_ids = encodings_to_i64_array(&encodings, self.context_length, |encoding| {
            encoding
                .get_ids()
                .iter()
                .map(|id| i64::from(*id))
                .collect::<Vec<_>>()
        })?;
        let attention_mask = encodings_to_i64_array(&encodings, self.context_length, |encoding| {
            encoding
                .get_attention_mask()
                .iter()
                .map(|value| i64::from(*value))
                .collect::<Vec<_>>()
        })?;
        let token_type_ids = if self.token_type_ids_name.is_some() {
            Some(encodings_to_i64_array(
                &encodings,
                self.context_length,
                |encoding| {
                    encoding
                        .get_type_ids()
                        .iter()
                        .map(|value| i64::from(*value))
                        .collect::<Vec<_>>()
                },
            )?)
        } else {
            None
        };

        self.text_session.with_session(|session| {
            let input_ids =
                TensorRef::from_array_view(input_ids.view()).map_err(Error::from_ort)?;
            let attention_mask =
                TensorRef::from_array_view(attention_mask.view()).map_err(Error::from_ort)?;

            let outputs = if let (Some(token_type_ids_name), Some(token_type_ids)) =
                (&self.token_type_ids_name, &token_type_ids)
            {
                let token_type_ids =
                    TensorRef::from_array_view(token_type_ids.view()).map_err(Error::from_ort)?;
                session
                    .run(ort::inputs![
                        self.input_ids_name.as_str() => input_ids,
                        self.attention_mask_name.as_str() => attention_mask,
                        token_type_ids_name.as_str() => token_type_ids,
                    ])
                    .map_err(Error::from_ort)?
            } else {
                session
                    .run(ort::inputs![
                        self.input_ids_name.as_str() => input_ids,
                        self.attention_mask_name.as_str() => attention_mask,
                    ])
                    .map_err(Error::from_ort)?
            };

            let output = outputs.get(self.text_output_name.as_str()).ok_or_else(|| {
                Error::ort(format!(
                    "text output '{}' not found in chinese clip session",
                    self.text_output_name
                ))
            })?;
            let output = output
                .try_extract_array::<f32>()
                .map_err(Error::from_ort)?
                .to_owned()
                .into_dyn();
            embeddings_from_output(&self.info, output, self.normalize_output)
        })
    }

    fn encode_images_internal(&self, images: &[DynamicImage]) -> Result<Vec<Embedding>, Error> {
        let tensors = images
            .iter()
            .map(|image| preprocess_image(image, &self.image_preprocess))
            .collect::<Result<Vec<_>, _>>()?;
        let input = stack_image_batches(&tensors, self.image_preprocess.image_size)?;

        self.image_session.with_session(|session| {
            let pixel_values = TensorRef::from_array_view(input.view()).map_err(Error::from_ort)?;
            let outputs = session
                .run(ort::inputs!["pixel_values" => pixel_values])
                .map_err(Error::from_ort)?;
            let output = outputs
                .get(self.image_output_name.as_str())
                .ok_or_else(|| {
                    Error::ort(format!(
                        "image output '{}' not found in chinese clip session",
                        self.image_output_name
                    ))
                })?;
            let output = output
                .try_extract_array::<f32>()
                .map_err(Error::from_ort)?
                .to_owned()
                .into_dyn();
            embeddings_from_output(&self.info, output, self.normalize_output)
        })
    }
}

impl EmbeddingBackend for ChineseClipBackend {
    fn embed_text(&self, text: &str) -> Result<Embedding, Error> {
        let batch = vec![text.to_owned()];
        self.maybe_unload_image();
        single_embedding(self.encode_texts_internal(&batch)?, "chinese clip text")
    }

    fn embed_texts(&self, texts: &[String]) -> Result<Vec<Embedding>, Error> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        self.maybe_unload_image();
        self.encode_texts_internal(texts)
    }

    fn embed_image_path(&self, path: &Path) -> Result<Embedding, Error> {
        self.maybe_unload_text();
        let image = load_image_from_path(path)?;
        single_embedding(self.encode_images_internal(&[image])?, "chinese clip image")
    }

    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Embedding, Error> {
        self.maybe_unload_text();
        let image = image::load_from_memory(bytes)
            .map_err(|error| Error::image_preprocess(error.to_string()))?;
        single_embedding(self.encode_images_internal(&[image])?, "chinese clip image")
    }

    fn embed_image_paths(&self, paths: &[PathBuf]) -> Result<Vec<Embedding>, Error> {
        if paths.is_empty() {
            return Ok(Vec::new());
        }
        self.maybe_unload_text();
        let images = paths
            .iter()
            .map(|path| load_image_from_path(path))
            .collect::<Result<Vec<_>, _>>()?;
        self.encode_images_internal(&images)
    }

    fn preload_text(&self) -> Result<(), Error> {
        self.maybe_unload_image();
        self.text_session.ensure_loaded()
    }

    fn preload_image(&self) -> Result<(), Error> {
        self.maybe_unload_text();
        self.image_session.ensure_loaded()
    }

    fn unload_text(&self) -> bool {
        self.text_session.unload()
    }

    fn unload_image(&self) -> bool {
        self.image_session.unload()
    }

    fn runtime_state(&self) -> RuntimeState {
        RuntimeState {
            text_loaded: self.text_session.is_loaded(),
            image_loaded: self.image_session.is_loaded(),
            last_text_used_at: self.text_session.last_used_at(),
            last_image_used_at: self.image_session.last_used_at(),
        }
    }

    fn runtime_snapshot(&self) -> RuntimeSnapshot {
        runtime_snapshot_for_sessions(&self.runtime, &self.text_session, &self.image_session)
    }
}

fn encodings_to_i64_array(
    encodings: &[Encoding],
    context_length: usize,
    extractor: impl Fn(&Encoding) -> Vec<i64>,
) -> Result<ArrayD<i64>, Error> {
    let mut values = Vec::with_capacity(encodings.len() * context_length);
    for encoding in encodings {
        let row = extractor(encoding);
        if row.len() != context_length {
            return Err(Error::tokenizer(format!(
                "tokenized length {} does not match configured context length {}",
                row.len(),
                context_length
            )));
        }
        values.extend(row);
    }

    Array::from_shape_vec(IxDyn(&[encodings.len(), context_length]), values)
        .map_err(|error| Error::tokenizer(error.to_string()))
}