batching_embeddings/main.rs
main.rs1use anyhow::Result;
2use mistralrs::{EmbeddingModelBuilder, EmbeddingRequest};
3
4#[tokio::main]
5async fn main() -> Result<()> {
6 let model = EmbeddingModelBuilder::new("google/embeddinggemma-300m")
7 .with_logging()
8 .build()
9 .await?;
10
11 let a = model
12 .generate_embeddings(
13 EmbeddingRequest::builder()
14 .add_prompt("task: search result | query: What is graphene?"),
15 )
16 .await?;
17 let b =
18 model
19 .generate_embeddings(EmbeddingRequest::builder().add_prompt(
20 "task: search result | query: What is an apple's significance to gravity?",
21 ))
22 .await?;
23
24 let batched = model
25 .generate_embeddings(EmbeddingRequest::builder().add_prompts((0..100).map(|i| {
26 if i % 2 == 0 {
27 "task: search result | query: What is graphene?"
28 } else {
29 "task: search result | query: What is an apple's significance to gravity?"
30 }
31 })))
32 .await?;
33
34 for (i, embedding) in batched.into_iter().enumerate() {
35 if i % 2 == 0 {
36 assert_eq!(embedding, a[0]);
37 } else {
38 assert_eq!(embedding, b[0]);
39 }
40 }
41
42 Ok(())
43}