rig/providers/gemini/embedding.rs
1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use crate::embeddings::{self, EmbeddingError};
9
10use super::{client::ApiResponse, Client};
11
/// Model identifier for Gemini's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";
/// Model identifier for Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
16#[derive(Clone)]
17pub struct EmbeddingModel {
18 client: Client,
19 model: String,
20 ndims: Option<usize>,
21}
22
23impl EmbeddingModel {
24 pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
25 Self {
26 client,
27 model: model.to_string(),
28 ndims,
29 }
30 }
31}
32
33impl embeddings::EmbeddingModel for EmbeddingModel {
34 const MAX_DOCUMENTS: usize = 1024;
35
36 fn ndims(&self) -> usize {
37 match self.model.as_str() {
38 EMBEDDING_001 => 768,
39 EMBEDDING_004 => 1024,
40 _ => 0, // Default to 0 for unknown models
41 }
42 }
43
44 /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
45 #[cfg_attr(feature = "worker", worker::send)]
46 async fn embed_texts(
47 &self,
48 documents: impl IntoIterator<Item = String> + Send,
49 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
50 let documents: Vec<String> = documents.into_iter().collect();
51
52 // Google batch embed requests. See docstrings for API ref link.
53 let requests: Vec<_> = documents
54 .iter()
55 .map(|doc| {
56 json!({
57 "model": format!("models/{}", self.model),
58 "content": json!({
59 "parts": [json!({
60 "text": doc.to_string()
61 })]
62 }),
63 "output_dimensionality": self.ndims,
64 })
65 })
66 .collect();
67
68 let request_body = json!({ "requests": requests });
69
70 tracing::info!("{}", serde_json::to_string_pretty(&request_body).unwrap());
71
72 let response = self
73 .client
74 .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
75 .json(&request_body)
76 .send()
77 .await?
78 .error_for_status()?
79 .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
80 .await?;
81
82 match response {
83 ApiResponse::Ok(response) => {
84 let docs = documents
85 .into_iter()
86 .zip(response.embeddings)
87 .map(|(document, embedding)| embeddings::Embedding {
88 document,
89 vec: embedding.values,
90 })
91 .collect();
92
93 Ok(docs)
94 }
95 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
96 }
97 }
98}
99
100// =================================================================
101// Gemini API Types
102// =================================================================
103/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
104#[allow(dead_code)]
105mod gemini_api_types {
106 use serde::{Deserialize, Serialize};
107 use serde_json::Value;
108
109 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
110
111 #[derive(Serialize)]
112 #[serde(rename_all = "camelCase")]
113 pub struct EmbedContentRequest {
114 model: String,
115 content: EmbeddingContent,
116 task_type: TaskType,
117 title: String,
118 output_dimensionality: i32,
119 }
120
121 #[derive(Serialize)]
122 pub struct EmbeddingContent {
123 parts: Vec<EmbeddingContentPart>,
124 /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
125 /// conversations, otherwise can be left blank or unset.
126 role: Option<String>,
127 }
128
129 /// A datatype containing media that is part of a multi-part Content message.
130 /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
131 /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
132 #[derive(Serialize)]
133 pub struct EmbeddingContentPart {
134 /// Inline text.
135 text: String,
136 /// Inline media bytes.
137 inline_data: Option<Blob>,
138 /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
139 /// with the arguments and their values.
140 function_call: Option<FunctionCall>,
141 /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
142 /// JSON object containing any output from the function is used as context to the model.
143 function_response: Option<FunctionResponse>,
144 /// URI based data.
145 file_data: Option<FileData>,
146 /// Code generated by the model that is meant to be executed.
147 executable_code: Option<ExecutableCode>,
148 /// Result of executing the ExecutableCode.
149 code_execution_result: Option<CodeExecutionResult>,
150 }
151
152 /// Raw media bytes.
153 /// Text should not be sent as raw bytes, use the 'text' field.
154 #[derive(Serialize)]
155 pub struct Blob {
156 /// Raw bytes for media formats.A base64-encoded string.
157 data: String,
158 /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
159 /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
160 mime_type: String,
161 }
162
163 #[derive(Serialize)]
164 pub struct FunctionCall {
165 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
166 name: String,
167 /// The function parameters and values in JSON object format.
168 args: Option<Value>,
169 }
170
171 #[derive(Serialize)]
172 pub struct FunctionResponse {
173 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
174 name: String,
175 /// The result of the function call in JSON object format.
176 result: Value,
177 }
178
179 #[derive(Serialize)]
180 #[serde(rename_all = "camelCase")]
181 pub struct FileData {
182 /// The URI of the file.
183 file_uri: String,
184 /// The IANA standard MIME type of the source data.
185 mime_type: String,
186 }
187
188 #[derive(Serialize)]
189 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
190 pub enum TaskType {
191 /// Unset value, which will default to one of the other enum values.
192 Unspecified,
193 /// Specifies the given text is a query in a search/retrieval setting.
194 RetrievalQuery,
195 /// Specifies the given text is a document from the corpus being searched.
196 RetrievalDocument,
197 /// Specifies the given text will be used for STS.
198 SemanticSimilarity,
199 /// Specifies that the given text will be classified.
200 Classification,
201 /// Specifies that the embeddings will be used for clustering.
202 Clustering,
203 /// Specifies that the given text will be used for question answering.
204 QuestionAnswering,
205 /// Specifies that the given text will be used for fact verification.
206 FactVerification,
207 }
208
209 #[derive(Debug, Deserialize)]
210 pub struct EmbeddingResponse {
211 pub embeddings: Vec<EmbeddingValues>,
212 }
213
214 #[derive(Debug, Deserialize)]
215 pub struct EmbeddingValues {
216 pub values: Vec<f64>,
217 }
218}