embed_anything 0.4.17

Embed anything at lightning speed
Documentation
use candle_core::{Device, Tensor};
use embed_anything::{
    embed_image_directory, embed_query,
    embeddings::embed::{EmbedData, Embedder},
};
use std::{path::PathBuf, sync::Arc, time::Instant};

/// Finds the image in `test_files` that best matches a text query using CLIP embeddings.
///
/// The program:
/// 1. loads the `openai/clip-vit-base-patch32` model,
/// 2. embeds every image found under `test_files`,
/// 3. embeds the query `"Photo of a monkey"`,
/// 4. scores each image as the dot product of its embedding with the query embedding,
/// 5. prints the path of the highest-scoring image and the total elapsed time in seconds.
///
/// # Panics
///
/// Panics with a descriptive message if the model cannot be loaded, if embedding fails,
/// if an embedding cannot be converted to a dense vector, or if the tensor operations
/// fail (for example because the embeddings do not all share one dimension).
#[tokio::main]
async fn main() {
    let now = Instant::now();

    let model = Embedder::from_pretrained_hf("clip", "openai/clip-vit-base-patch32", None)
        .expect("failed to load the CLIP model");
    let model: Arc<Embedder> = Arc::new(model);

    let out = embed_image_directory(
        PathBuf::from("test_files"),
        &model,
        None,
        None::<fn(Vec<EmbedData>)>,
    )
    .await
    .expect("failed to embed the image directory")
    .expect("image embedding returned no data");

    // Nothing to rank: bail out instead of panicking on `out[0]` further down.
    if out.is_empty() {
        eprintln!("No image embeddings were produced for test_files");
        return;
    }

    let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], &model, None)
        .await
        .expect("failed to embed the query");

    // Densify every image embedding exactly once, borrowing instead of cloning.
    let image_vectors = out
        .iter()
        .map(|embed| {
            embed
                .embedding
                .to_dense()
                .expect("image embedding could not be converted to a dense vector")
        })
        .collect::<Vec<_>>();
    let n_vectors = image_vectors.len();
    // `out` is non-empty (checked above), so the first row exists.
    let image_dim = image_vectors[0].len();

    // Row-major (n_vectors, image_dim) matrix holding all image embeddings.
    let out_embeddings = Tensor::from_vec(
        image_vectors.into_iter().flatten().collect::<Vec<_>>(),
        (n_vectors, image_dim),
        &Device::Cpu,
    )
    .expect("image embeddings do not form an (n_vectors, dim) matrix");

    // A single query was submitted, so only the first result is used.
    let query_vector = query_emb_data
        .first()
        .expect("query embedding result is empty")
        .embedding
        .to_dense()
        .expect("query embedding could not be converted to a dense vector");
    let query_dim = query_vector.len();
    let query_embeddings = Tensor::from_vec(query_vector, (1, query_dim), &Device::Cpu)
        .expect("failed to build the query tensor");

    // (n_vectors, dim) x (dim, 1) -> (n_vectors, 1); squeeze to one score per image.
    let similarities = out_embeddings
        .matmul(
            &query_embeddings
                .transpose(0, 1)
                .expect("failed to transpose the query tensor"),
        )
        .expect("image and query embeddings have incompatible dimensions")
        .detach()
        .squeeze(1)
        .expect("failed to squeeze the similarity tensor")
        .to_vec1::<f32>()
        .expect("failed to read the similarity scores");

    // Sort image indices by descending similarity. `total_cmp` defines a total order
    // on f32, so a NaN score cannot panic the comparator.
    let mut indices: Vec<usize> = (0..similarities.len()).collect();
    indices.sort_by(|&a, &b| similarities[b].total_cmp(&similarities[a]));

    // Up to three best matches; `take` tolerates directories with fewer than three images.
    let top_3_image_paths = indices
        .iter()
        .take(3)
        .map(|&i| {
            out[i]
                .text
                .clone()
                .expect("image embedding is missing its source path")
        })
        .collect::<Vec<String>>();

    if let Some(similar_image) = top_3_image_paths.first() {
        println!("{:?}", similar_image);
    }

    let elapsed_time = now.elapsed();
    println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
}