Skip to main content

rig/providers/gemini/
embedding.rs

1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10    embeddings::{self, EmbeddingError},
11    http_client::HttpClientExt,
12    wasm_compat::WasmCompatSend,
13};
14
/// Model identifier for `gemini-embedding-001`, whose default output has 3072 dimensions.
pub const EMBEDDING_001: &str = "gemini-embedding-001";
/// Model identifier for `text-embedding-004`, whose default output has 768 dimensions.
pub const EMBEDDING_004: &str = "text-embedding-004";

/// Looks up the default output dimensionality of a known Gemini embedding model.
///
/// Returns `None` when `model` is not one of the identifiers listed in this file, so the
/// caller can decide on its own fallback.
///
/// See <https://ai.google.dev/gemini-api/docs/models#gemini-embedding>
fn model_default_ndims(model: &str) -> Option<usize> {
    // One row per known model: (identifier, default dimensionality).
    const KNOWN_DEFAULTS: [(&str, usize); 2] = [(EMBEDDING_001, 3072), (EMBEDDING_004, 768)];

    KNOWN_DEFAULTS
        .iter()
        .find(|(name, _)| *name == model)
        .map(|&(_, dims)| dims)
}
30
31#[derive(Clone)]
32pub struct EmbeddingModel<T = reqwest::Client> {
33    client: Client<T>,
34    model: String,
35    ndims: usize,
36}
37
38impl<T> EmbeddingModel<T> {
39    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
40        Self {
41            client,
42            model: model.into(),
43            ndims,
44        }
45    }
46
47    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
48        Self {
49            client,
50            model: model.to_string(),
51            ndims,
52        }
53    }
54}
55
56impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
57where
58    T: Clone + HttpClientExt + 'static,
59{
60    type Client = Client<T>;
61
62    const MAX_DOCUMENTS: usize = 1024;
63
64    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
65        let model = model.into();
66        let ndims = dims.or_else(|| model_default_ndims(&model)).unwrap_or(768);
67        Self::new(client.clone(), model, ndims)
68    }
69
70    fn ndims(&self) -> usize {
71        self.ndims
72    }
73
74    /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
75    async fn embed_texts(
76        &self,
77        documents: impl IntoIterator<Item = String> + WasmCompatSend,
78    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
79        let documents: Vec<String> = documents.into_iter().collect();
80
81        // Google batch embed requests. See docstrings for API ref link.
82        let requests: Vec<_> = documents
83            .iter()
84            .map(|doc| {
85                json!({
86                    "model": format!("models/{}", self.model),
87                    "content": json!({
88                        "parts": [json!({
89                            "text": doc.to_string()
90                        })]
91                    }),
92                    "output_dimensionality": self.ndims,
93                })
94            })
95            .collect();
96
97        let request_body = json!({ "requests": requests  });
98
99        if let Ok(pretty_body) = serde_json::to_string_pretty(&request_body) {
100            tracing::trace!(
101                target: "rig::embedding",
102                "Sending embedding request to Gemini API {pretty_body}"
103            );
104        }
105
106        let request_body = serde_json::to_vec(&request_body)?;
107        let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
108        let req = self
109            .client
110            .post(path.as_str())?
111            .body(request_body)
112            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
113        let response = self.client.send::<_, Vec<u8>>(req).await?;
114
115        let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
116            serde_json::from_slice(&response.into_body().await?)?;
117
118        match response {
119            ApiResponse::Ok(response) => {
120                let docs = documents
121                    .into_iter()
122                    .zip(response.embeddings)
123                    .map(|(document, embedding)| embeddings::Embedding {
124                        document,
125                        vec: embedding
126                            .values
127                            .into_iter()
128                            .filter_map(|n| n.as_f64())
129                            .collect(),
130                    })
131                    .collect();
132
133                Ok(docs)
134            }
135            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
136        }
137    }
138}
139
140// =================================================================
141// Gemini API Types
142// =================================================================
143/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
144#[allow(dead_code)]
145mod gemini_api_types {
146    use serde::{Deserialize, Serialize};
147    use serde_json::Value;
148
149    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
150
151    #[derive(Serialize)]
152    #[serde(rename_all = "camelCase")]
153    pub struct EmbedContentRequest {
154        model: String,
155        content: EmbeddingContent,
156        task_type: TaskType,
157        title: String,
158        output_dimensionality: i32,
159    }
160
161    #[derive(Serialize)]
162    pub struct EmbeddingContent {
163        parts: Vec<EmbeddingContentPart>,
164        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
165        /// conversations, otherwise can be left blank or unset.
166        role: Option<String>,
167    }
168
169    /// A datatype containing media that is part of a multi-part Content message.
170    ///  - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
171    ///  - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
172    #[derive(Serialize)]
173    pub struct EmbeddingContentPart {
174        /// Inline text.
175        text: String,
176        /// Inline media bytes.
177        inline_data: Option<Blob>,
178        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
179        /// with the arguments and their values.
180        function_call: Option<FunctionCall>,
181        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
182        /// JSON object containing any output from the function is used as context to the model.
183        function_response: Option<FunctionResponse>,
184        /// URI based data.
185        file_data: Option<FileData>,
186        /// Code generated by the model that is meant to be executed.
187        executable_code: Option<ExecutableCode>,
188        /// Result of executing the ExecutableCode.
189        code_execution_result: Option<CodeExecutionResult>,
190    }
191
192    /// Raw media bytes.
193    /// Text should not be sent as raw bytes, use the 'text' field.
194    #[derive(Serialize)]
195    pub struct Blob {
196        /// Raw bytes for media formats.A base64-encoded string.
197        data: String,
198        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
199        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
200        mime_type: String,
201    }
202
203    #[derive(Serialize)]
204    pub struct FunctionCall {
205        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
206        name: String,
207        /// The function parameters and values in JSON object format.
208        args: Option<Value>,
209    }
210
211    #[derive(Serialize)]
212    pub struct FunctionResponse {
213        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
214        name: String,
215        /// The result of the function call in JSON object format.
216        result: Value,
217    }
218
219    #[derive(Serialize)]
220    #[serde(rename_all = "camelCase")]
221    pub struct FileData {
222        /// The URI of the file.
223        file_uri: String,
224        /// The IANA standard MIME type of the source data.
225        mime_type: String,
226    }
227
228    #[derive(Serialize)]
229    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
230    pub enum TaskType {
231        /// Unset value, which will default to one of the other enum values.
232        Unspecified,
233        /// Specifies the given text is a query in a search/retrieval setting.
234        RetrievalQuery,
235        /// Specifies the given text is a document from the corpus being searched.
236        RetrievalDocument,
237        /// Specifies the given text will be used for STS.
238        SemanticSimilarity,
239        /// Specifies that the given text will be classified.
240        Classification,
241        /// Specifies that the embeddings will be used for clustering.
242        Clustering,
243        /// Specifies that the given text will be used for question answering.
244        QuestionAnswering,
245        /// Specifies that the given text will be used for fact verification.
246        FactVerification,
247    }
248
249    #[derive(Debug, Deserialize)]
250    pub struct EmbeddingResponse {
251        pub embeddings: Vec<EmbeddingValues>,
252    }
253
254    #[derive(Debug, Deserialize)]
255    pub struct EmbeddingValues {
256        pub values: Vec<serde_json::Number>,
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model through the trait's `make` constructor and returns the
    /// dimensionality it reports.
    fn ndims_via_make(model_name: &str, dims: Option<usize>) -> usize {
        let client = Client::new("test_key").unwrap();
        let model =
            <EmbeddingModel as embeddings::EmbeddingModel>::make(&client, model_name, dims);
        embeddings::EmbeddingModel::ndims(&model)
    }

    #[test]
    fn test_model_default_ndims_lookup() {
        let expectations = [
            (EMBEDDING_001, Some(3072)),
            (EMBEDDING_004, Some(768)),
            ("unknown-model", None),
        ];

        for &(name, expected) in &expectations {
            assert_eq!(model_default_ndims(name), expected);
        }
    }

    #[test]
    fn test_make_resolves_default_dims() {
        // EMBEDDING_001 defaults to 3072
        assert_eq!(ndims_via_make(EMBEDDING_001, None), 3072);

        // EMBEDDING_004 defaults to 768
        assert_eq!(ndims_via_make(EMBEDDING_004, None), 768);

        // Unknown model falls back to 768
        assert_eq!(ndims_via_make("some-future-model", None), 768);
    }

    #[test]
    fn test_make_respects_explicit_dims() {
        // An explicit dimensionality wins over the model's default.
        assert_eq!(ndims_via_make(EMBEDDING_001, Some(256)), 256);
    }

    #[test]
    fn test_new_uses_provided_ndims() {
        let client = Client::new("test_key").unwrap();
        let model = EmbeddingModel::new(client, EMBEDDING_001, 512);

        assert_eq!(embeddings::EmbeddingModel::ndims(&model), 512);
    }
}
310}