embed_anything 0.6.7

Embed anything at lightning speed
Documentation
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use crate::{
    embeddings::{embed::EmbeddingResult, normalize_l2, select_device, utils::tokenize_batch},
    models::qwen3::{Config, Model},
};
use anyhow::Error;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};

use super::{
    colpali::hub_load_safetensors,
    pooling::{ModelOutput, PooledOutputType, Pooling},
};

/// Text-embedding interface for Qwen3-based embedders.
pub trait Qwen3Embed {
    /// Embeds every string in `text_batch` and returns one [`EmbeddingResult`] per input.
    ///
    /// # Arguments
    ///
    /// * `text_batch` - The texts to embed.
    /// * `batch_size` - Optional number of texts processed per forward pass; the default used
    ///   for `None` is chosen by the implementation.
    /// * `late_chunking` - Optional late-chunking flag. NOTE(review): the implementation in this
    ///   file ignores it — confirm whether other implementors are expected to honour it.
    ///
    /// # Errors
    ///
    /// Returns an `anyhow::Error` when the implementation fails to produce embeddings.
    fn embed(
        &self,
        text_batch: &[&str],
        batch_size: Option<usize>,
        late_chunking: Option<bool>,
    ) -> Result<Vec<EmbeddingResult>, anyhow::Error>;
}

/// A Qwen3 model bundled with the tokenizer and device it runs on.
pub struct Qwen3Embedder {
    /// The model, guarded by an `RwLock`; `embed` takes the write lock around each forward
    /// pass and the KV-cache reset that follows it.
    pub model: std::sync::RwLock<Model>,
    /// Tokenizer loaded from the repository's `tokenizer.json`; `new` configures it with
    /// left-side batch-longest padding and truncation at 1024 tokens.
    pub tokenizer: Tokenizer,
    /// Device returned by `select_device()` in `new`; the weights are loaded onto it and it is
    /// passed to `tokenize_batch` when embedding.
    pub device: Device,
}

impl Qwen3Embedder {
    /// Builds an embedder from a Hugging Face Hub model repository.
    ///
    /// Fetches `config.json`, `tokenizer.json` and the safetensors weights, configures the
    /// tokenizer (left-side batch-longest padding, truncation at 1024 tokens) and loads the
    /// model onto the device returned by `select_device()`.
    ///
    /// # Arguments
    ///
    /// * `model_id` - Hub repository id, e.g. `"Qwen/Qwen3-Embedding-0.6B"`.
    /// * `revision` - Optional repository revision; the repository default is used when `None`.
    /// * `token` - Optional Hub access token.
    /// * `dtype` - Weight dtype. `None`, or any variant other than `F16`/`F32`/`BF16`, falls
    ///   back to `F32`.
    ///
    /// # Errors
    ///
    /// Returns an error if the Hub client cannot be built, a required file cannot be fetched,
    /// the config or tokenizer cannot be parsed, or the weights cannot be loaded.
    pub fn new(
        model_id: &str,
        revision: Option<String>,
        token: Option<&str>,
        dtype: Option<crate::Dtype>,
    ) -> Result<Self, anyhow::Error> {
        // Propagate client construction failures instead of panicking: this constructor
        // already reports every other failure through its `Result`.
        let api = ApiBuilder::from_env()
            .with_token(token.map(|s| s.to_string()))
            .build()?;

        let repo = match revision {
            Some(rev) => api.repo(Repo::with_revision(
                model_id.to_string(),
                hf_hub::RepoType::Model,
                rev,
            )),
            None => api.repo(hf_hub::Repo::new(
                model_id.to_string(),
                hf_hub::RepoType::Model,
            )),
        };

        let config_filename = repo.get("config.json")?;
        let tokenizer_filename = repo.get("tokenizer.json")?;
        // Deliberately kept as a `Result`: a missing single-file checkpoint is not fatal
        // because the sharded layout is tried below.
        let weights_filename = repo.get("model.safetensors");

        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;

        let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(Error::msg)?;
        // Left padding keeps the final position of every row a real token, which is what the
        // last-token pooling in `embed` reads.
        let pp = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            direction: tokenizers::PaddingDirection::Left,
            ..Default::default()
        };
        let trunc = TruncationParams {
            strategy: tokenizers::TruncationStrategy::LongestFirst,
            max_length: 1024,
            ..Default::default()
        };

        tokenizer
            .with_padding(Some(pp))
            .with_truncation(Some(trunc))
            .map_err(Error::msg)?;

        let device = select_device();
        let dtype = match dtype {
            Some(crate::Dtype::F16) => DType::F16,
            Some(crate::Dtype::F32) => DType::F32,
            Some(crate::Dtype::BF16) => DType::BF16,
            // No dtype requested, or one this model does not map: default to full precision.
            _ => DType::F32,
        };

        let vb = match weights_filename {
            // SAFETY: the weights file is memory-mapped; this relies on the cached file not
            // being modified or truncated while the mapping is alive.
            Ok(weights) => unsafe {
                VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)?
            },
            // No single-file checkpoint available: resolve the shards through the index file.
            Err(_) => {
                let weights = hub_load_safetensors(&repo, "model.safetensors.index.json")?;
                // SAFETY: same reliance as above, for each shard file.
                unsafe { VarBuilder::from_mmaped_safetensors(&weights, dtype, &device)? }
            }
        };

        let model = Model::new(&config, vb)?;

        Ok(Self {
            model: std::sync::RwLock::new(model),
            tokenizer,
            device,
        })
    }
}

impl Qwen3Embed for Qwen3Embedder {
    /// Embeds `text_batch` in mini-batches and returns one dense vector per input text.
    ///
    /// Each mini-batch is tokenized, run through the model, reduced with last-token pooling,
    /// L2-normalized and converted to `f32` vectors.
    ///
    /// # Arguments
    ///
    /// * `text_batch` - The texts to embed; an empty slice yields an empty result.
    /// * `batch_size` - Number of texts per forward pass; defaults to 32 when `None`.
    /// * `_late_chunking` - Accepted for interface compatibility; ignored here.
    ///
    /// # Errors
    ///
    /// Returns an error if `batch_size` is `Some(0)`, if the model lock is poisoned, or if
    /// tokenization, the forward pass, pooling, normalization or tensor conversion fails.
    fn embed(
        &self,
        text_batch: &[&str],
        batch_size: Option<usize>,
        _late_chunking: Option<bool>,
    ) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
        let batch_size = batch_size.unwrap_or(32);
        // `slice::chunks` panics on a chunk size of zero; report it as an error instead.
        if batch_size == 0 {
            return Err(anyhow::anyhow!("batch_size must be greater than zero"));
        }

        let mut encodings: Vec<EmbeddingResult> = Vec::with_capacity(text_batch.len());

        for mini_text_batch in text_batch.chunks(batch_size) {
            let (token_ids, attention_mask) =
                tokenize_batch(&self.tokenizer, mini_text_batch, &self.device)?;

            let embeddings: Tensor = {
                let mut model = self
                    .model
                    .write()
                    .map_err(|e| anyhow::anyhow!("Lock poisoned: {}", e))?;
                let forward_result = model.forward(&token_ids, &attention_mask, 0);
                // Reset the cache before propagating any error so that a failed forward pass
                // cannot leave stale state behind for the next call.
                model.clear_kv_cache();
                let hidden_states = forward_result?.to_dtype(DType::F32)?;

                hidden_states
            };

            let attention_mask = PooledOutputType::from(attention_mask);
            let model_output = ModelOutput::Tensor(embeddings);
            let pooled_output = Pooling::LastToken.pool(&model_output, Some(&attention_mask))?;
            let pooled_output = pooled_output.to_tensor()?;
            let embeddings = normalize_l2(pooled_output)?;
            let batch_encodings = embeddings.to_vec2::<f32>()?;

            // Move each row into its result instead of copying it.
            encodings.extend(
                batch_encodings
                    .into_iter()
                    .map(EmbeddingResult::DenseVector),
            );
        }

        Ok(encodings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum allowed difference between the magnitudes of an actual and a reference component.
    const TOLERANCE: f32 = 1e-6;

    /// Prints `actual` and asserts that each component matches `expected` in magnitude.
    ///
    /// The comparison is on absolute values, so a sign difference alone does not fail it.
    fn assert_magnitudes_close(actual: &[f32], expected: &[f32]) {
        println!("{:?}", actual);
        assert_eq!(actual.len(), expected.len(), "component count mismatch");
        for (index, (a, b)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a.abs() - b.abs()).abs() < TOLERANCE,
                "component {} differs: got {}, expected {}",
                index,
                a,
                b
            );
        }
    }

    /// Embeds two sentences with `Qwen/Qwen3-Embedding-0.6B` and checks the first six
    /// components of each vector against reference values.
    ///
    /// Constructing the embedder fetches the model files through the Hugging Face Hub client.
    #[test]
    fn test_qwen3_embed() {
        let embedder = Qwen3Embedder::new(
            "Qwen/Qwen3-Embedding-0.6B",
            None,
            None,
            Some(crate::Dtype::F32),
        )
        .unwrap();
        let embeddings = embedder
            .embed(
                &["Hello, world!", "I am a rust programmer now"],
                Some(2),
                None,
            )
            .unwrap();

        let first_embeddings = embeddings[0].to_dense().unwrap()[0..6].to_vec();
        assert_magnitudes_close(
            &first_embeddings,
            &[
                0.00555867,
                0.00928946,
                -0.00985782,
                -0.06393453,
                0.00829317,
                0.00708855,
            ],
        );

        let second_embeddings = embeddings[1].to_dense().unwrap()[0..6].to_vec();
        assert_magnitudes_close(
            &second_embeddings,
            &[
                0.03579775,
                -0.04019123,
                -0.01412615,
                -0.05743032,
                0.04517555,
                -0.0193235,
            ],
        );
    }
}