Skip to main content

rig/providers/gemini/
embedding.rs

1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10    embeddings::{self, EmbeddingError},
11    http_client::HttpClientExt,
12    wasm_compat::WasmCompatSend,
13};
14
/// Model identifier for `gemini-embedding-001` (3072 output dimensions by default).
pub const EMBEDDING_001: &str = "gemini-embedding-001";
/// Model identifier for `text-embedding-004` (768 output dimensions by default).
pub const EMBEDDING_004: &str = "text-embedding-004";

/// Looks up the default embedding width of a known Gemini embedding model.
///
/// Yields `None` for any model name that is not listed in the table below.
///
/// See <https://ai.google.dev/gemini-api/docs/models#gemini-embedding>
fn model_default_ndims(model: &str) -> Option<usize> {
    // (model name, default dimensionality) pairs for the models this module knows about.
    const KNOWN_DEFAULTS: [(&str, usize); 2] = [(EMBEDDING_001, 3072), (EMBEDDING_004, 768)];

    KNOWN_DEFAULTS
        .iter()
        .find(|(name, _)| *name == model)
        .map(|&(_, width)| width)
}
30
/// Embedding model handle for the Gemini provider.
///
/// Bundles the provider client with the model name and the embedding dimensionality
/// that is sent along with every embedding request.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Provider client used to build and send requests.
    client: Client<T>,
    /// Model name without the `models/` prefix (the prefix is added when requests are built).
    model: String,
    /// Output dimensionality reported by `ndims()` and sent as `output_dimensionality`.
    ndims: usize,
}
37
38impl<T> EmbeddingModel<T> {
39    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
40        Self {
41            client,
42            model: model.into(),
43            ndims,
44        }
45    }
46
47    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
48        Self {
49            client,
50            model: model.to_string(),
51            ndims,
52        }
53    }
54}
55
56impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
57where
58    T: Clone + HttpClientExt + 'static,
59{
60    type Client = Client<T>;
61
62    const MAX_DOCUMENTS: usize = 1024;
63
64    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
65        let model = model.into();
66        let ndims = dims.or_else(|| model_default_ndims(&model)).unwrap_or(768);
67        Self::new(client.clone(), model, ndims)
68    }
69
70    fn ndims(&self) -> usize {
71        self.ndims
72    }
73
74    /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
75    async fn embed_texts(
76        &self,
77        documents: impl IntoIterator<Item = String> + WasmCompatSend,
78    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
79        let documents: Vec<String> = documents.into_iter().collect();
80
81        // Google batch embed requests. See docstrings for API ref link.
82        let requests: Vec<_> = documents
83            .iter()
84            .map(|doc| {
85                json!({
86                    "model": format!("models/{}", self.model),
87                    "content": json!({
88                        "parts": [json!({
89                            "text": doc.to_string()
90                        })]
91                    }),
92                    "output_dimensionality": self.ndims,
93                })
94            })
95            .collect();
96
97        let request_body = json!({ "requests": requests  });
98
99        tracing::trace!(
100            target: "rig::embedding",
101            "Sending embedding request to Gemini API {}",
102            serde_json::to_string_pretty(&request_body).unwrap()
103        );
104
105        let request_body = serde_json::to_vec(&request_body)?;
106        let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
107        let req = self
108            .client
109            .post(path.as_str())?
110            .body(request_body)
111            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
112        let response = self.client.send::<_, Vec<u8>>(req).await?;
113
114        let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
115            serde_json::from_slice(&response.into_body().await?)?;
116
117        match response {
118            ApiResponse::Ok(response) => {
119                let docs = documents
120                    .into_iter()
121                    .zip(response.embeddings)
122                    .map(|(document, embedding)| embeddings::Embedding {
123                        document,
124                        vec: embedding
125                            .values
126                            .into_iter()
127                            .filter_map(|n| n.as_f64())
128                            .collect(),
129                    })
130                    .collect();
131
132                Ok(docs)
133            }
134            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
135        }
136    }
137}
138
139// =================================================================
140// Gemini API Types
141// =================================================================
142/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
143#[allow(dead_code)]
144mod gemini_api_types {
145    use serde::{Deserialize, Serialize};
146    use serde_json::Value;
147
148    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
149
150    #[derive(Serialize)]
151    #[serde(rename_all = "camelCase")]
152    pub struct EmbedContentRequest {
153        model: String,
154        content: EmbeddingContent,
155        task_type: TaskType,
156        title: String,
157        output_dimensionality: i32,
158    }
159
160    #[derive(Serialize)]
161    pub struct EmbeddingContent {
162        parts: Vec<EmbeddingContentPart>,
163        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
164        /// conversations, otherwise can be left blank or unset.
165        role: Option<String>,
166    }
167
168    /// A datatype containing media that is part of a multi-part Content message.
169    ///  - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
170    ///  - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
171    #[derive(Serialize)]
172    pub struct EmbeddingContentPart {
173        /// Inline text.
174        text: String,
175        /// Inline media bytes.
176        inline_data: Option<Blob>,
177        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
178        /// with the arguments and their values.
179        function_call: Option<FunctionCall>,
180        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
181        /// JSON object containing any output from the function is used as context to the model.
182        function_response: Option<FunctionResponse>,
183        /// URI based data.
184        file_data: Option<FileData>,
185        /// Code generated by the model that is meant to be executed.
186        executable_code: Option<ExecutableCode>,
187        /// Result of executing the ExecutableCode.
188        code_execution_result: Option<CodeExecutionResult>,
189    }
190
191    /// Raw media bytes.
192    /// Text should not be sent as raw bytes, use the 'text' field.
193    #[derive(Serialize)]
194    pub struct Blob {
195        /// Raw bytes for media formats.A base64-encoded string.
196        data: String,
197        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
198        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
199        mime_type: String,
200    }
201
202    #[derive(Serialize)]
203    pub struct FunctionCall {
204        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
205        name: String,
206        /// The function parameters and values in JSON object format.
207        args: Option<Value>,
208    }
209
210    #[derive(Serialize)]
211    pub struct FunctionResponse {
212        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
213        name: String,
214        /// The result of the function call in JSON object format.
215        result: Value,
216    }
217
218    #[derive(Serialize)]
219    #[serde(rename_all = "camelCase")]
220    pub struct FileData {
221        /// The URI of the file.
222        file_uri: String,
223        /// The IANA standard MIME type of the source data.
224        mime_type: String,
225    }
226
227    #[derive(Serialize)]
228    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
229    pub enum TaskType {
230        /// Unset value, which will default to one of the other enum values.
231        Unspecified,
232        /// Specifies the given text is a query in a search/retrieval setting.
233        RetrievalQuery,
234        /// Specifies the given text is a document from the corpus being searched.
235        RetrievalDocument,
236        /// Specifies the given text will be used for STS.
237        SemanticSimilarity,
238        /// Specifies that the given text will be classified.
239        Classification,
240        /// Specifies that the embeddings will be used for clustering.
241        Clustering,
242        /// Specifies that the given text will be used for question answering.
243        QuestionAnswering,
244        /// Specifies that the given text will be used for fact verification.
245        FactVerification,
246    }
247
248    #[derive(Debug, Deserialize)]
249    pub struct EmbeddingResponse {
250        pub embeddings: Vec<EmbeddingValues>,
251    }
252
253    #[derive(Debug, Deserialize)]
254    pub struct EmbeddingValues {
255        pub values: Vec<serde_json::Number>,
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model through the trait constructor and reports the dimensionality it resolved.
    fn resolved_ndims(model_name: &str, dims: Option<usize>) -> usize {
        let client = Client::new("test_key").unwrap();
        let built =
            <EmbeddingModel as embeddings::EmbeddingModel>::make(&client, model_name, dims);
        embeddings::EmbeddingModel::ndims(&built)
    }

    #[test]
    fn test_model_default_ndims_lookup() {
        let expectations = [
            (EMBEDDING_001, Some(3072)),
            (EMBEDDING_004, Some(768)),
            ("unknown-model", None),
        ];

        for (name, expected) in expectations {
            assert_eq!(model_default_ndims(name), expected);
        }
    }

    #[test]
    fn test_make_resolves_default_dims() {
        // Known models use their table default; anything else falls back to 768.
        let expectations = [
            (EMBEDDING_001, 3072),
            (EMBEDDING_004, 768),
            ("some-future-model", 768),
        ];

        for (name, expected) in expectations {
            assert_eq!(resolved_ndims(name, None), expected);
        }
    }

    #[test]
    fn test_make_respects_explicit_dims() {
        // An explicit dimensionality wins over the model's default.
        assert_eq!(resolved_ndims(EMBEDDING_001, Some(256)), 256);
    }

    #[test]
    fn test_new_uses_provided_ndims() {
        let client = Client::new("test_key").unwrap();
        let built = EmbeddingModel::new(client, EMBEDDING_001, 512);

        assert_eq!(embeddings::EmbeddingModel::ndims(&built), 512);
    }
}