// rig/providers/gemini/embedding.rs

// ================================================================
//! Google Gemini Embeddings Integration
//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
// ================================================================
5
6use serde_json::json;
7
8use crate::embeddings::{self, EmbeddingError};
9
10use super::{client::ApiResponse, Client};
11
/// Model identifier for Google's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";
/// Model identifier for Google's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
16#[derive(Clone)]
17pub struct EmbeddingModel {
18    client: Client,
19    model: String,
20    ndims: Option<usize>,
21}
22
23impl EmbeddingModel {
24    pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
25        Self {
26            client,
27            model: model.to_string(),
28            ndims,
29        }
30    }
31}
32
33impl embeddings::EmbeddingModel for EmbeddingModel {
34    const MAX_DOCUMENTS: usize = 1024;
35
36    fn ndims(&self) -> usize {
37        match self.model.as_str() {
38            EMBEDDING_001 => 768,
39            EMBEDDING_004 => 1024,
40            _ => 0, // Default to 0 for unknown models
41        }
42    }
43
44    /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
45    #[cfg_attr(feature = "worker", worker::send)]
46    async fn embed_texts(
47        &self,
48        documents: impl IntoIterator<Item = String> + Send,
49    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
50        let documents: Vec<String> = documents.into_iter().collect();
51
52        // Google batch embed requests. See docstrings for API ref link.
53        let requests: Vec<_> = documents
54            .iter()
55            .map(|doc| {
56                json!({
57                    "model": format!("models/{}", self.model),
58                    "content": json!({
59                        "parts": [json!({
60                            "text": doc.to_string()
61                        })]
62                    }),
63                    "output_dimensionality": self.ndims,
64                })
65            })
66            .collect();
67
68        let request_body = json!({ "requests": requests  });
69
70        tracing::info!("{}", serde_json::to_string_pretty(&request_body).unwrap());
71
72        let response = self
73            .client
74            .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
75            .json(&request_body)
76            .send()
77            .await?
78            .error_for_status()?
79            .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
80            .await?;
81
82        match response {
83            ApiResponse::Ok(response) => {
84                let docs = documents
85                    .into_iter()
86                    .zip(response.embeddings)
87                    .map(|(document, embedding)| embeddings::Embedding {
88                        document,
89                        vec: embedding.values,
90                    })
91                    .collect();
92
93                Ok(docs)
94            }
95            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
96        }
97    }
98}
99
// =================================================================
// Gemini API Types
// =================================================================
103/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
104#[allow(dead_code)]
105mod gemini_api_types {
106    use serde::{Deserialize, Serialize};
107    use serde_json::Value;
108
109    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
110
111    #[derive(Serialize)]
112    #[serde(rename_all = "camelCase")]
113    pub struct EmbedContentRequest {
114        model: String,
115        content: EmbeddingContent,
116        task_type: TaskType,
117        title: String,
118        output_dimensionality: i32,
119    }
120
121    #[derive(Serialize)]
122    pub struct EmbeddingContent {
123        parts: Vec<EmbeddingContentPart>,
124        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
125        /// conversations, otherwise can be left blank or unset.
126        role: Option<String>,
127    }
128
129    /// A datatype containing media that is part of a multi-part Content message.
130    ///  - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
131    ///  - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
132    #[derive(Serialize)]
133    pub struct EmbeddingContentPart {
134        /// Inline text.
135        text: String,
136        /// Inline media bytes.
137        inline_data: Option<Blob>,
138        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
139        /// with the arguments and their values.
140        function_call: Option<FunctionCall>,
141        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
142        /// JSON object containing any output from the function is used as context to the model.
143        function_response: Option<FunctionResponse>,
144        /// URI based data.
145        file_data: Option<FileData>,
146        /// Code generated by the model that is meant to be executed.
147        executable_code: Option<ExecutableCode>,
148        /// Result of executing the ExecutableCode.
149        code_execution_result: Option<CodeExecutionResult>,
150    }
151
152    /// Raw media bytes.
153    /// Text should not be sent as raw bytes, use the 'text' field.
154    #[derive(Serialize)]
155    pub struct Blob {
156        /// Raw bytes for media formats.A base64-encoded string.
157        data: String,
158        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
159        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
160        mime_type: String,
161    }
162
163    #[derive(Serialize)]
164    pub struct FunctionCall {
165        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
166        name: String,
167        /// The function parameters and values in JSON object format.
168        args: Option<Value>,
169    }
170
171    #[derive(Serialize)]
172    pub struct FunctionResponse {
173        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
174        name: String,
175        /// The result of the function call in JSON object format.
176        result: Value,
177    }
178
179    #[derive(Serialize)]
180    #[serde(rename_all = "camelCase")]
181    pub struct FileData {
182        /// The URI of the file.
183        file_uri: String,
184        /// The IANA standard MIME type of the source data.
185        mime_type: String,
186    }
187
188    #[derive(Serialize)]
189    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
190    pub enum TaskType {
191        /// Unset value, which will default to one of the other enum values.
192        Unspecified,
193        /// Specifies the given text is a query in a search/retrieval setting.
194        RetrievalQuery,
195        /// Specifies the given text is a document from the corpus being searched.
196        RetrievalDocument,
197        /// Specifies the given text will be used for STS.
198        SemanticSimilarity,
199        /// Specifies that the given text will be classified.
200        Classification,
201        /// Specifies that the embeddings will be used for clustering.
202        Clustering,
203        /// Specifies that the given text will be used for question answering.
204        QuestionAnswering,
205        /// Specifies that the given text will be used for fact verification.
206        FactVerification,
207    }
208
209    #[derive(Debug, Deserialize)]
210    pub struct EmbeddingResponse {
211        pub embeddings: Vec<EmbeddingValues>,
212    }
213
214    #[derive(Debug, Deserialize)]
215    pub struct EmbeddingValues {
216        pub values: Vec<f64>,
217    }
218}