embed_anything 0.6.7

Embed anything at lightning speed
Documentation
use candle_core::{Device, Tensor};
use embed_anything::{
    embed_image_directory, embed_query,
    embeddings::embed::{Embedder, EmbedderBuilder},
};
use std::{path::PathBuf, sync::Arc, time::Instant};

#[tokio::main]
async fn main() {
    let now = Instant::now();

    let model = EmbedderBuilder::new()
        .model_id(Some("google/siglip-base-patch16-224"))
        .revision(None)
        .token(None)
        .from_pretrained_hf()
        .unwrap();
    let model: Arc<Embedder> = Arc::new(model);
    let out = embed_image_directory(PathBuf::from("test_files"), &model, None, None)
        .await
        .unwrap()
        .unwrap();

    let query_emb_data = embed_query(&["Photo of a monkey?"], &model, None)
        .await
        .unwrap();
    let n_vectors = out.len();

    let vector = out
        .iter()
        .map(|embed| embed.embedding.clone())
        .collect::<Vec<_>>()
        .into_iter()
        .flat_map(|x| x.to_dense().unwrap())
        .collect::<Vec<_>>();

    let out_embeddings = Tensor::from_vec(
        vector,
        (n_vectors, out[0].embedding.to_dense().unwrap().len()),
        &Device::Cpu,
    )
    .unwrap();

    let image_paths = out
        .iter()
        .map(|embed| embed.text.clone().unwrap())
        .collect::<Vec<_>>();

    let query_embeddings = Tensor::from_vec(
        query_emb_data
            .iter()
            .map(|embed| embed.embedding.clone())
            .collect::<Vec<_>>()
            .into_iter()
            .flat_map(|x| x.to_dense().unwrap())
            .collect::<Vec<_>>(),
        (1, query_emb_data[0].embedding.to_dense().unwrap().len()),
        &Device::Cpu,
    )
    .unwrap();

    let similarities = out_embeddings
        .matmul(&query_embeddings.transpose(0, 1).unwrap())
        .unwrap()
        .detach()
        .squeeze(1)
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

    let mut indices: Vec<usize> = (0..similarities.len()).collect();
    indices.sort_by(|a, b| similarities[*b].partial_cmp(&similarities[*a]).unwrap());

    println!("Descending order of similarity: ");
    for idx in &indices {
        println!("{}", image_paths[*idx]);
    }

    println!("-----------");

    println!("Most similar image: {}", image_paths[indices[0]]);

    let elapsed_time = now.elapsed();
    println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
}