rig/providers/gemini/embedding.rs

// ================================================================
//! Google Gemini embeddings integration.
//! See the [Gemini API reference](https://ai.google.dev/api/embeddings).
// ================================================================
5
6use serde_json::json;
7
8use crate::{
9    embeddings::{self, EmbeddingError},
10    http_client::HttpClientExt,
11    wasm_compat::WasmCompatSend,
12};
13
14use super::{Client, client::ApiResponse};
15
/// Model identifier for Gemini's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";

/// Model identifier for Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
20#[derive(Clone)]
21pub struct EmbeddingModel<T = reqwest::Client> {
22    client: Client<T>,
23    model: String,
24    ndims: Option<usize>,
25}
26
27impl<T> EmbeddingModel<T> {
28    pub fn new(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
29        Self {
30            client,
31            model: model.to_string(),
32            ndims,
33        }
34    }
35}
36
37impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
38where
39    T: Clone + HttpClientExt,
40{
41    const MAX_DOCUMENTS: usize = 1024;
42
43    fn ndims(&self) -> usize {
44        match self.model.as_str() {
45            EMBEDDING_001 | EMBEDDING_004 => 768,
46            _ => 0, // Default to 0 for unknown models
47        }
48    }
49
50    /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
51    #[cfg_attr(feature = "worker", worker::send)]
52    async fn embed_texts(
53        &self,
54        documents: impl IntoIterator<Item = String> + WasmCompatSend,
55    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
56        let documents: Vec<String> = documents.into_iter().collect();
57
58        // Google batch embed requests. See docstrings for API ref link.
59        let requests: Vec<_> = documents
60            .iter()
61            .map(|doc| {
62                json!({
63                    "model": format!("models/{}", self.model),
64                    "content": json!({
65                        "parts": [json!({
66                            "text": doc.to_string()
67                        })]
68                    }),
69                    "output_dimensionality": self.ndims,
70                })
71            })
72            .collect();
73
74        let request_body = json!({ "requests": requests  });
75
76        tracing::info!("{}", serde_json::to_string_pretty(&request_body).unwrap());
77
78        let request_body = serde_json::to_vec(&request_body)?;
79        let req = self
80            .client
81            .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
82            .header("Content-Type", "application/json")
83            .body(request_body)
84            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
85        let response = self.client.send::<_, Vec<u8>>(req).await?;
86
87        let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
88            serde_json::from_slice(&response.into_body().await?)?;
89
90        match response {
91            ApiResponse::Ok(response) => {
92                let docs = documents
93                    .into_iter()
94                    .zip(response.embeddings)
95                    .map(|(document, embedding)| embeddings::Embedding {
96                        document,
97                        vec: embedding.values,
98                    })
99                    .collect();
100
101                Ok(docs)
102            }
103            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
104        }
105    }
106}
107
108// =================================================================
109// Gemini API Types
110// =================================================================
111/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
112#[allow(dead_code)]
113mod gemini_api_types {
114    use serde::{Deserialize, Serialize};
115    use serde_json::Value;
116
117    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
118
119    #[derive(Serialize)]
120    #[serde(rename_all = "camelCase")]
121    pub struct EmbedContentRequest {
122        model: String,
123        content: EmbeddingContent,
124        task_type: TaskType,
125        title: String,
126        output_dimensionality: i32,
127    }
128
129    #[derive(Serialize)]
130    pub struct EmbeddingContent {
131        parts: Vec<EmbeddingContentPart>,
132        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
133        /// conversations, otherwise can be left blank or unset.
134        role: Option<String>,
135    }
136
137    /// A datatype containing media that is part of a multi-part Content message.
138    ///  - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
139    ///  - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
140    #[derive(Serialize)]
141    pub struct EmbeddingContentPart {
142        /// Inline text.
143        text: String,
144        /// Inline media bytes.
145        inline_data: Option<Blob>,
146        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
147        /// with the arguments and their values.
148        function_call: Option<FunctionCall>,
149        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
150        /// JSON object containing any output from the function is used as context to the model.
151        function_response: Option<FunctionResponse>,
152        /// URI based data.
153        file_data: Option<FileData>,
154        /// Code generated by the model that is meant to be executed.
155        executable_code: Option<ExecutableCode>,
156        /// Result of executing the ExecutableCode.
157        code_execution_result: Option<CodeExecutionResult>,
158    }
159
160    /// Raw media bytes.
161    /// Text should not be sent as raw bytes, use the 'text' field.
162    #[derive(Serialize)]
163    pub struct Blob {
164        /// Raw bytes for media formats.A base64-encoded string.
165        data: String,
166        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
167        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
168        mime_type: String,
169    }
170
171    #[derive(Serialize)]
172    pub struct FunctionCall {
173        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
174        name: String,
175        /// The function parameters and values in JSON object format.
176        args: Option<Value>,
177    }
178
179    #[derive(Serialize)]
180    pub struct FunctionResponse {
181        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
182        name: String,
183        /// The result of the function call in JSON object format.
184        result: Value,
185    }
186
187    #[derive(Serialize)]
188    #[serde(rename_all = "camelCase")]
189    pub struct FileData {
190        /// The URI of the file.
191        file_uri: String,
192        /// The IANA standard MIME type of the source data.
193        mime_type: String,
194    }
195
196    #[derive(Serialize)]
197    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
198    pub enum TaskType {
199        /// Unset value, which will default to one of the other enum values.
200        Unspecified,
201        /// Specifies the given text is a query in a search/retrieval setting.
202        RetrievalQuery,
203        /// Specifies the given text is a document from the corpus being searched.
204        RetrievalDocument,
205        /// Specifies the given text will be used for STS.
206        SemanticSimilarity,
207        /// Specifies that the given text will be classified.
208        Classification,
209        /// Specifies that the embeddings will be used for clustering.
210        Clustering,
211        /// Specifies that the given text will be used for question answering.
212        QuestionAnswering,
213        /// Specifies that the given text will be used for fact verification.
214        FactVerification,
215    }
216
217    #[derive(Debug, Deserialize)]
218    pub struct EmbeddingResponse {
219        pub embeddings: Vec<EmbeddingValues>,
220    }
221
222    #[derive(Debug, Deserialize)]
223    pub struct EmbeddingValues {
224        pub values: Vec<f64>,
225    }
226}