rig/providers/gemini/embedding.rs
1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10 embeddings::{self, EmbeddingError},
11 http_client::HttpClientExt,
12 wasm_compat::WasmCompatSend,
13};
14
/// Model name of the `embedding-001` embedding model.
///
/// Intended to be passed as the `model` argument when constructing an [`EmbeddingModel`].
pub const EMBEDDING_001: &str = "embedding-001";
/// Model name of the `text-embedding-004` embedding model.
///
/// Intended to be passed as the `model` argument when constructing an [`EmbeddingModel`].
pub const EMBEDDING_004: &str = "text-embedding-004";
19
/// A Gemini embedding model: a client together with the model name and an
/// optional output dimensionality.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Gemini client used to build and send the embedding requests.
    client: Client<T>,
    /// Model name; interpolated as `models/{model}` into both the request
    /// body and the request path.
    model: String,
    /// Optional embedding size, forwarded to the API as
    /// `output_dimensionality` on every request.
    ndims: Option<usize>,
}
26
27impl<T> EmbeddingModel<T> {
28 pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
29 Self {
30 client,
31 model: model.into(),
32 ndims,
33 }
34 }
35
36 pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
37 Self {
38 client,
39 model: model.to_string(),
40 ndims,
41 }
42 }
43}
44
45impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
46where
47 T: Clone + HttpClientExt + 'static,
48{
49 type Client = Client<T>;
50
51 const MAX_DOCUMENTS: usize = 1024;
52
53 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
54 Self::new(client.clone(), model, dims)
55 }
56
57 fn ndims(&self) -> usize {
58 768
59 }
60
61 /// <https://ai.google.dev/api/embeddings#batch_embed_contents-SHELL>
62 async fn embed_texts(
63 &self,
64 documents: impl IntoIterator<Item = String> + WasmCompatSend,
65 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
66 let documents: Vec<String> = documents.into_iter().collect();
67
68 // Google batch embed requests. See docstrings for API ref link.
69 let requests: Vec<_> = documents
70 .iter()
71 .map(|doc| {
72 json!({
73 "model": format!("models/{}", self.model),
74 "content": json!({
75 "parts": [json!({
76 "text": doc.to_string()
77 })]
78 }),
79 "output_dimensionality": self.ndims,
80 })
81 })
82 .collect();
83
84 let request_body = json!({ "requests": requests });
85
86 tracing::trace!(
87 target: "rig::embedding",
88 "Sending embedding request to Gemini API {}",
89 serde_json::to_string_pretty(&request_body).unwrap()
90 );
91
92 let request_body = serde_json::to_vec(&request_body)?;
93 let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
94 let req = self
95 .client
96 .post(path.as_str())?
97 .body(request_body)
98 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
99 let response = self.client.send::<_, Vec<u8>>(req).await?;
100
101 let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
102 serde_json::from_slice(&response.into_body().await?)?;
103
104 match response {
105 ApiResponse::Ok(response) => {
106 let docs = documents
107 .into_iter()
108 .zip(response.embeddings)
109 .map(|(document, embedding)| embeddings::Embedding {
110 document,
111 vec: embedding.values,
112 })
113 .collect();
114
115 Ok(docs)
116 }
117 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
118 }
119 }
120}
121
122// =================================================================
123// Gemini API Types
124// =================================================================
125/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
126#[allow(dead_code)]
127mod gemini_api_types {
128 use serde::{Deserialize, Serialize};
129 use serde_json::Value;
130
131 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
132
133 #[derive(Serialize)]
134 #[serde(rename_all = "camelCase")]
135 pub struct EmbedContentRequest {
136 model: String,
137 content: EmbeddingContent,
138 task_type: TaskType,
139 title: String,
140 output_dimensionality: i32,
141 }
142
143 #[derive(Serialize)]
144 pub struct EmbeddingContent {
145 parts: Vec<EmbeddingContentPart>,
146 /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
147 /// conversations, otherwise can be left blank or unset.
148 role: Option<String>,
149 }
150
151 /// A datatype containing media that is part of a multi-part Content message.
152 /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
153 /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
154 #[derive(Serialize)]
155 pub struct EmbeddingContentPart {
156 /// Inline text.
157 text: String,
158 /// Inline media bytes.
159 inline_data: Option<Blob>,
160 /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
161 /// with the arguments and their values.
162 function_call: Option<FunctionCall>,
163 /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
164 /// JSON object containing any output from the function is used as context to the model.
165 function_response: Option<FunctionResponse>,
166 /// URI based data.
167 file_data: Option<FileData>,
168 /// Code generated by the model that is meant to be executed.
169 executable_code: Option<ExecutableCode>,
170 /// Result of executing the ExecutableCode.
171 code_execution_result: Option<CodeExecutionResult>,
172 }
173
174 /// Raw media bytes.
175 /// Text should not be sent as raw bytes, use the 'text' field.
176 #[derive(Serialize)]
177 pub struct Blob {
178 /// Raw bytes for media formats.A base64-encoded string.
179 data: String,
180 /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
181 /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
182 mime_type: String,
183 }
184
185 #[derive(Serialize)]
186 pub struct FunctionCall {
187 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
188 name: String,
189 /// The function parameters and values in JSON object format.
190 args: Option<Value>,
191 }
192
193 #[derive(Serialize)]
194 pub struct FunctionResponse {
195 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
196 name: String,
197 /// The result of the function call in JSON object format.
198 result: Value,
199 }
200
201 #[derive(Serialize)]
202 #[serde(rename_all = "camelCase")]
203 pub struct FileData {
204 /// The URI of the file.
205 file_uri: String,
206 /// The IANA standard MIME type of the source data.
207 mime_type: String,
208 }
209
210 #[derive(Serialize)]
211 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
212 pub enum TaskType {
213 /// Unset value, which will default to one of the other enum values.
214 Unspecified,
215 /// Specifies the given text is a query in a search/retrieval setting.
216 RetrievalQuery,
217 /// Specifies the given text is a document from the corpus being searched.
218 RetrievalDocument,
219 /// Specifies the given text will be used for STS.
220 SemanticSimilarity,
221 /// Specifies that the given text will be classified.
222 Classification,
223 /// Specifies that the embeddings will be used for clustering.
224 Clustering,
225 /// Specifies that the given text will be used for question answering.
226 QuestionAnswering,
227 /// Specifies that the given text will be used for fact verification.
228 FactVerification,
229 }
230
231 #[derive(Debug, Deserialize)]
232 pub struct EmbeddingResponse {
233 pub embeddings: Vec<EmbeddingValues>,
234 }
235
236 #[derive(Debug, Deserialize)]
237 pub struct EmbeddingValues {
238 pub values: Vec<f64>,
239 }
240}