rig/providers/gemini/embedding.rs
1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use crate::{
9 embeddings::{self, EmbeddingError},
10 http_client::HttpClientExt,
11 wasm_compat::WasmCompatSend,
12};
13
14use super::{Client, client::ApiResponse};
15
/// Model identifier for Gemini's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";
/// Model identifier for Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
20#[derive(Clone)]
21pub struct EmbeddingModel<T = reqwest::Client> {
22 client: Client<T>,
23 model: String,
24 ndims: Option<usize>,
25}
26
27impl<T> EmbeddingModel<T> {
28 pub fn new(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
29 Self {
30 client,
31 model: model.to_string(),
32 ndims,
33 }
34 }
35}
36
37impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
38where
39 T: Clone + HttpClientExt,
40{
41 const MAX_DOCUMENTS: usize = 1024;
42
43 fn ndims(&self) -> usize {
44 match self.model.as_str() {
45 EMBEDDING_001 | EMBEDDING_004 => 768,
46 _ => 0, // Default to 0 for unknown models
47 }
48 }
49
50 /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
51 #[cfg_attr(feature = "worker", worker::send)]
52 async fn embed_texts(
53 &self,
54 documents: impl IntoIterator<Item = String> + WasmCompatSend,
55 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
56 let documents: Vec<String> = documents.into_iter().collect();
57
58 // Google batch embed requests. See docstrings for API ref link.
59 let requests: Vec<_> = documents
60 .iter()
61 .map(|doc| {
62 json!({
63 "model": format!("models/{}", self.model),
64 "content": json!({
65 "parts": [json!({
66 "text": doc.to_string()
67 })]
68 }),
69 "output_dimensionality": self.ndims,
70 })
71 })
72 .collect();
73
74 let request_body = json!({ "requests": requests });
75
76 tracing::info!("{}", serde_json::to_string_pretty(&request_body).unwrap());
77
78 let request_body = serde_json::to_vec(&request_body)?;
79 let req = self
80 .client
81 .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model))
82 .header("Content-Type", "application/json")
83 .body(request_body)
84 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
85 let response = self.client.send::<_, Vec<u8>>(req).await?;
86
87 let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
88 serde_json::from_slice(&response.into_body().await?)?;
89
90 match response {
91 ApiResponse::Ok(response) => {
92 let docs = documents
93 .into_iter()
94 .zip(response.embeddings)
95 .map(|(document, embedding)| embeddings::Embedding {
96 document,
97 vec: embedding.values,
98 })
99 .collect();
100
101 Ok(docs)
102 }
103 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
104 }
105 }
106}
107
108// =================================================================
109// Gemini API Types
110// =================================================================
111/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
112#[allow(dead_code)]
113mod gemini_api_types {
114 use serde::{Deserialize, Serialize};
115 use serde_json::Value;
116
117 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
118
119 #[derive(Serialize)]
120 #[serde(rename_all = "camelCase")]
121 pub struct EmbedContentRequest {
122 model: String,
123 content: EmbeddingContent,
124 task_type: TaskType,
125 title: String,
126 output_dimensionality: i32,
127 }
128
129 #[derive(Serialize)]
130 pub struct EmbeddingContent {
131 parts: Vec<EmbeddingContentPart>,
132 /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
133 /// conversations, otherwise can be left blank or unset.
134 role: Option<String>,
135 }
136
137 /// A datatype containing media that is part of a multi-part Content message.
138 /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
139 /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
140 #[derive(Serialize)]
141 pub struct EmbeddingContentPart {
142 /// Inline text.
143 text: String,
144 /// Inline media bytes.
145 inline_data: Option<Blob>,
146 /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
147 /// with the arguments and their values.
148 function_call: Option<FunctionCall>,
149 /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
150 /// JSON object containing any output from the function is used as context to the model.
151 function_response: Option<FunctionResponse>,
152 /// URI based data.
153 file_data: Option<FileData>,
154 /// Code generated by the model that is meant to be executed.
155 executable_code: Option<ExecutableCode>,
156 /// Result of executing the ExecutableCode.
157 code_execution_result: Option<CodeExecutionResult>,
158 }
159
160 /// Raw media bytes.
161 /// Text should not be sent as raw bytes, use the 'text' field.
162 #[derive(Serialize)]
163 pub struct Blob {
164 /// Raw bytes for media formats.A base64-encoded string.
165 data: String,
166 /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
167 /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
168 mime_type: String,
169 }
170
171 #[derive(Serialize)]
172 pub struct FunctionCall {
173 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
174 name: String,
175 /// The function parameters and values in JSON object format.
176 args: Option<Value>,
177 }
178
179 #[derive(Serialize)]
180 pub struct FunctionResponse {
181 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
182 name: String,
183 /// The result of the function call in JSON object format.
184 result: Value,
185 }
186
187 #[derive(Serialize)]
188 #[serde(rename_all = "camelCase")]
189 pub struct FileData {
190 /// The URI of the file.
191 file_uri: String,
192 /// The IANA standard MIME type of the source data.
193 mime_type: String,
194 }
195
196 #[derive(Serialize)]
197 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
198 pub enum TaskType {
199 /// Unset value, which will default to one of the other enum values.
200 Unspecified,
201 /// Specifies the given text is a query in a search/retrieval setting.
202 RetrievalQuery,
203 /// Specifies the given text is a document from the corpus being searched.
204 RetrievalDocument,
205 /// Specifies the given text will be used for STS.
206 SemanticSimilarity,
207 /// Specifies that the given text will be classified.
208 Classification,
209 /// Specifies that the embeddings will be used for clustering.
210 Clustering,
211 /// Specifies that the given text will be used for question answering.
212 QuestionAnswering,
213 /// Specifies that the given text will be used for fact verification.
214 FactVerification,
215 }
216
217 #[derive(Debug, Deserialize)]
218 pub struct EmbeddingResponse {
219 pub embeddings: Vec<EmbeddingValues>,
220 }
221
222 #[derive(Debug, Deserialize)]
223 pub struct EmbeddingValues {
224 pub values: Vec<f64>,
225 }
226}