Skip to main content

batching_embeddings/
main.rs

1use anyhow::Result;
2use mistralrs::{EmbeddingModelBuilder, EmbeddingRequest};
3
4#[tokio::main]
5async fn main() -> Result<()> {
6    let model = EmbeddingModelBuilder::new("google/embeddinggemma-300m")
7        .with_logging()
8        .build()
9        .await?;
10
11    let a = model
12        .generate_embeddings(
13            EmbeddingRequest::builder()
14                .add_prompt("task: search result | query: What is graphene?"),
15        )
16        .await?;
17    let b =
18        model
19            .generate_embeddings(EmbeddingRequest::builder().add_prompt(
20                "task: search result | query: What is an apple's significance to gravity?",
21            ))
22            .await?;
23
24    let batched = model
25        .generate_embeddings(EmbeddingRequest::builder().add_prompts((0..100).map(|i| {
26            if i % 2 == 0 {
27                "task: search result | query: What is graphene?"
28            } else {
29                "task: search result | query: What is an apple's significance to gravity?"
30            }
31        })))
32        .await?;
33
34    for (i, embedding) in batched.into_iter().enumerate() {
35        if i % 2 == 0 {
36            assert_eq!(embedding, a[0]);
37        } else {
38            assert_eq!(embedding, b[0]);
39        }
40    }
41
42    Ok(())
43}