rig/providers/gemini/
embedding.rs

1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10    embeddings::{self, EmbeddingError},
11    http_client::HttpClientExt,
12    wasm_compat::WasmCompatSend,
13};
14
/// Model identifier for Gemini's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";

/// Model identifier for Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
19
20#[derive(Clone)]
21pub struct EmbeddingModel<T = reqwest::Client> {
22    client: Client<T>,
23    model: String,
24    ndims: Option<usize>,
25}
26
27impl<T> EmbeddingModel<T> {
28    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
29        Self {
30            client,
31            model: model.into(),
32            ndims,
33        }
34    }
35
36    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
37        Self {
38            client,
39            model: model.to_string(),
40            ndims,
41        }
42    }
43}
44
45impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
46where
47    T: Clone + HttpClientExt,
48{
49    type Client = Client<T>;
50
51    const MAX_DOCUMENTS: usize = 1024;
52
53    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
54        Self::new(client.clone(), model, dims)
55    }
56
57    fn ndims(&self) -> usize {
58        768
59    }
60
61    /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
62    #[cfg_attr(feature = "worker", worker::send)]
63    async fn embed_texts(
64        &self,
65        documents: impl IntoIterator<Item = String> + WasmCompatSend,
66    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
67        let documents: Vec<String> = documents.into_iter().collect();
68
69        // Google batch embed requests. See docstrings for API ref link.
70        let requests: Vec<_> = documents
71            .iter()
72            .map(|doc| {
73                json!({
74                    "model": format!("models/{}", self.model),
75                    "content": json!({
76                        "parts": [json!({
77                            "text": doc.to_string()
78                        })]
79                    }),
80                    "output_dimensionality": self.ndims,
81                })
82            })
83            .collect();
84
85        let request_body = json!({ "requests": requests  });
86
87        tracing::trace!(
88            target: "rig::embedding",
89            "Sending embedding request to Gemini API {}",
90            serde_json::to_string_pretty(&request_body).unwrap()
91        );
92
93        let request_body = serde_json::to_vec(&request_body)?;
94        let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
95        let req = self
96            .client
97            .post(path.as_str())?
98            .body(request_body)
99            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
100        let response = self.client.send::<_, Vec<u8>>(req).await?;
101
102        let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
103            serde_json::from_slice(&response.into_body().await?)?;
104
105        match response {
106            ApiResponse::Ok(response) => {
107                let docs = documents
108                    .into_iter()
109                    .zip(response.embeddings)
110                    .map(|(document, embedding)| embeddings::Embedding {
111                        document,
112                        vec: embedding.values,
113                    })
114                    .collect();
115
116                Ok(docs)
117            }
118            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
119        }
120    }
121}
122
123// =================================================================
124// Gemini API Types
125// =================================================================
126/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
127#[allow(dead_code)]
128mod gemini_api_types {
129    use serde::{Deserialize, Serialize};
130    use serde_json::Value;
131
132    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
133
134    #[derive(Serialize)]
135    #[serde(rename_all = "camelCase")]
136    pub struct EmbedContentRequest {
137        model: String,
138        content: EmbeddingContent,
139        task_type: TaskType,
140        title: String,
141        output_dimensionality: i32,
142    }
143
144    #[derive(Serialize)]
145    pub struct EmbeddingContent {
146        parts: Vec<EmbeddingContentPart>,
147        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
148        /// conversations, otherwise can be left blank or unset.
149        role: Option<String>,
150    }
151
152    /// A datatype containing media that is part of a multi-part Content message.
153    ///  - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
154    ///  - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
155    #[derive(Serialize)]
156    pub struct EmbeddingContentPart {
157        /// Inline text.
158        text: String,
159        /// Inline media bytes.
160        inline_data: Option<Blob>,
161        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
162        /// with the arguments and their values.
163        function_call: Option<FunctionCall>,
164        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
165        /// JSON object containing any output from the function is used as context to the model.
166        function_response: Option<FunctionResponse>,
167        /// URI based data.
168        file_data: Option<FileData>,
169        /// Code generated by the model that is meant to be executed.
170        executable_code: Option<ExecutableCode>,
171        /// Result of executing the ExecutableCode.
172        code_execution_result: Option<CodeExecutionResult>,
173    }
174
175    /// Raw media bytes.
176    /// Text should not be sent as raw bytes, use the 'text' field.
177    #[derive(Serialize)]
178    pub struct Blob {
179        /// Raw bytes for media formats.A base64-encoded string.
180        data: String,
181        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
182        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
183        mime_type: String,
184    }
185
186    #[derive(Serialize)]
187    pub struct FunctionCall {
188        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
189        name: String,
190        /// The function parameters and values in JSON object format.
191        args: Option<Value>,
192    }
193
194    #[derive(Serialize)]
195    pub struct FunctionResponse {
196        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
197        name: String,
198        /// The result of the function call in JSON object format.
199        result: Value,
200    }
201
202    #[derive(Serialize)]
203    #[serde(rename_all = "camelCase")]
204    pub struct FileData {
205        /// The URI of the file.
206        file_uri: String,
207        /// The IANA standard MIME type of the source data.
208        mime_type: String,
209    }
210
211    #[derive(Serialize)]
212    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
213    pub enum TaskType {
214        /// Unset value, which will default to one of the other enum values.
215        Unspecified,
216        /// Specifies the given text is a query in a search/retrieval setting.
217        RetrievalQuery,
218        /// Specifies the given text is a document from the corpus being searched.
219        RetrievalDocument,
220        /// Specifies the given text will be used for STS.
221        SemanticSimilarity,
222        /// Specifies that the given text will be classified.
223        Classification,
224        /// Specifies that the embeddings will be used for clustering.
225        Clustering,
226        /// Specifies that the given text will be used for question answering.
227        QuestionAnswering,
228        /// Specifies that the given text will be used for fact verification.
229        FactVerification,
230    }
231
232    #[derive(Debug, Deserialize)]
233    pub struct EmbeddingResponse {
234        pub embeddings: Vec<EmbeddingValues>,
235    }
236
237    #[derive(Debug, Deserialize)]
238    pub struct EmbeddingValues {
239        pub values: Vec<f64>,
240    }
241}