rig_gemini_grpc/
/// Identifier of Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &'static str = "text-embedding-004";
7
8use rig::embeddings::{self, EmbeddingError};
9
10use super::Client;
11use super::proto::{self, EmbedContentRequest};
12
13#[derive(Clone, Debug)]
14pub struct EmbeddingModel {
15 client: Client,
16 model: String,
17 ndims: usize,
18}
19
20impl EmbeddingModel {
21 pub fn new(client: Client, model: impl Into<String>, dims: Option<usize>) -> Self {
22 Self {
23 client,
24 model: model.into(),
25 ndims: dims.unwrap_or(768), }
27 }
28}
29
30impl embeddings::EmbeddingModel for EmbeddingModel {
31 const MAX_DOCUMENTS: usize = 100;
32
33 type Client = super::Client;
34
35 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
36 Self::new(client.clone(), model, dims)
37 }
38
39 fn ndims(&self) -> usize {
40 self.ndims
41 }
42
43 async fn embed_texts(
44 &self,
45 documents: impl IntoIterator<Item = String> + rig::wasm_compat::WasmCompatSend,
46 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
47 let documents_vec: Vec<String> = documents.into_iter().collect();
48 let mut embeddings = Vec::new();
49
50 let mut grpc_client = self
51 .client
52 .grpc_client()
53 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
54
55 for doc in documents_vec {
56 let request = EmbedContentRequest {
57 model: format!("models/{}", self.model),
58 content: Some(proto::Content {
59 parts: vec![proto::Part {
60 data: Some(proto::part::Data::Text(doc.clone())),
61 thought: false,
62 thought_signature: Vec::new(),
63 part_metadata: None,
64 }],
65 role: String::new(),
66 }),
67 task_type: None,
68 title: None,
69 output_dimensionality: Some(self.ndims as i32),
70 };
71
72 let response = grpc_client
73 .embed_content(request)
74 .await
75 .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?
76 .into_inner();
77
78 if let Some(embedding) = response.embedding {
79 embeddings.push(embeddings::Embedding {
80 document: doc,
81 vec: embedding.values.into_iter().map(|v| v as f64).collect(),
82 });
83 } else {
84 return Err(EmbeddingError::ResponseError(
85 "No embedding in response".to_string(),
86 ));
87 }
88 }
89
90 Ok(embeddings)
91 }
92}