Skip to main content

rig_gemini_grpc/
embedding.rs

1// ================================================================
2//! Google Gemini gRPC Embedding Integration
3// ================================================================
4
/// Model identifier for Google's `text-embedding-004` text embedding model,
/// suitable for passing as the `model` argument when building an [`EmbeddingModel`].
pub const EMBEDDING_004: &str = "text-embedding-004";
7
8use rig::embeddings::{self, EmbeddingError};
9
10use super::Client;
11use super::proto::{self, EmbedContentRequest};
12
13#[derive(Clone, Debug)]
14pub struct EmbeddingModel {
15    client: Client,
16    model: String,
17    ndims: usize,
18}
19
20impl EmbeddingModel {
21    pub fn new(client: Client, model: impl Into<String>, dims: Option<usize>) -> Self {
22        Self {
23            client,
24            model: model.into(),
25            ndims: dims.unwrap_or(768), // Default embedding size for text-embedding-004
26        }
27    }
28}
29
30impl embeddings::EmbeddingModel for EmbeddingModel {
31    const MAX_DOCUMENTS: usize = 100;
32
33    type Client = super::Client;
34
35    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
36        Self::new(client.clone(), model, dims)
37    }
38
39    fn ndims(&self) -> usize {
40        self.ndims
41    }
42
43    async fn embed_texts(
44        &self,
45        documents: impl IntoIterator<Item = String> + rig::wasm_compat::WasmCompatSend,
46    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
47        let documents_vec: Vec<String> = documents.into_iter().collect();
48        let mut embeddings = Vec::new();
49
50        let mut grpc_client = self
51            .client
52            .grpc_client()
53            .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?;
54
55        for doc in documents_vec {
56            let request = EmbedContentRequest {
57                model: format!("models/{}", self.model),
58                content: Some(proto::Content {
59                    parts: vec![proto::Part {
60                        data: Some(proto::part::Data::Text(doc.clone())),
61                        thought: false,
62                        thought_signature: Vec::new(),
63                        part_metadata: None,
64                    }],
65                    role: String::new(),
66                }),
67                task_type: None,
68                title: None,
69                output_dimensionality: Some(self.ndims as i32),
70            };
71
72            let response = grpc_client
73                .embed_content(request)
74                .await
75                .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?
76                .into_inner();
77
78            if let Some(embedding) = response.embedding {
79                embeddings.push(embeddings::Embedding {
80                    document: doc,
81                    vec: embedding.values.into_iter().map(|v| v as f64).collect(),
82                });
83            } else {
84                return Err(EmbeddingError::ResponseError(
85                    "No embedding in response".to_string(),
86                ));
87            }
88        }
89
90        Ok(embeddings)
91    }
92}