1use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10 embeddings::{self, EmbeddingError},
11 http_client::HttpClientExt,
12 wasm_compat::WasmCompatSend,
13};
14
/// Model name for Gemini's `gemini-embedding-001` embedding model.
pub const EMBEDDING_001: &str = "gemini-embedding-001";
/// Model name for the `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";

/// Default output dimensionality for each model this module knows about.
///
/// Used when a caller does not request an explicit dimensionality.
const KNOWN_MODEL_DIMS: [(&str, usize); 2] = [(EMBEDDING_001, 3072), (EMBEDDING_004, 768)];

/// Looks up the default embedding dimensionality for `model`.
///
/// Returns `None` for any model name not listed in `KNOWN_MODEL_DIMS`, so the
/// caller can pick its own fallback.
fn model_default_ndims(model: &str) -> Option<usize> {
    KNOWN_MODEL_DIMS
        .iter()
        .find_map(|&(name, dims)| (name == model).then_some(dims))
}
30
31#[derive(Clone)]
32pub struct EmbeddingModel<T = reqwest::Client> {
33 client: Client<T>,
34 model: String,
35 ndims: usize,
36}
37
38impl<T> EmbeddingModel<T> {
39 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
40 Self {
41 client,
42 model: model.into(),
43 ndims,
44 }
45 }
46
47 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
48 Self {
49 client,
50 model: model.to_string(),
51 ndims,
52 }
53 }
54}
55
56impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
57where
58 T: Clone + HttpClientExt + 'static,
59{
60 type Client = Client<T>;
61
62 const MAX_DOCUMENTS: usize = 1024;
63
64 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
65 let model = model.into();
66 let ndims = dims.or_else(|| model_default_ndims(&model)).unwrap_or(768);
67 Self::new(client.clone(), model, ndims)
68 }
69
70 fn ndims(&self) -> usize {
71 self.ndims
72 }
73
74 async fn embed_texts(
76 &self,
77 documents: impl IntoIterator<Item = String> + WasmCompatSend,
78 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
79 let documents: Vec<String> = documents.into_iter().collect();
80
81 let requests: Vec<_> = documents
83 .iter()
84 .map(|doc| {
85 json!({
86 "model": format!("models/{}", self.model),
87 "content": json!({
88 "parts": [json!({
89 "text": doc.to_string()
90 })]
91 }),
92 "output_dimensionality": self.ndims,
93 })
94 })
95 .collect();
96
97 let request_body = json!({ "requests": requests });
98
99 if let Ok(pretty_body) = serde_json::to_string_pretty(&request_body) {
100 tracing::trace!(
101 target: "rig::embedding",
102 "Sending embedding request to Gemini API {pretty_body}"
103 );
104 }
105
106 let request_body = serde_json::to_vec(&request_body)?;
107 let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
108 let req = self
109 .client
110 .post(path.as_str())?
111 .body(request_body)
112 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
113 let response = self.client.send::<_, Vec<u8>>(req).await?;
114
115 let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
116 serde_json::from_slice(&response.into_body().await?)?;
117
118 match response {
119 ApiResponse::Ok(response) => {
120 let docs = documents
121 .into_iter()
122 .zip(response.embeddings)
123 .map(|(document, embedding)| embeddings::Embedding {
124 document,
125 vec: embedding
126 .values
127 .into_iter()
128 .filter_map(|n| n.as_f64())
129 .collect(),
130 })
131 .collect();
132
133 Ok(docs)
134 }
135 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
136 }
137 }
138}
139
140#[allow(dead_code)]
145mod gemini_api_types {
146 use serde::{Deserialize, Serialize};
147 use serde_json::Value;
148
149 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
150
151 #[derive(Serialize)]
152 #[serde(rename_all = "camelCase")]
153 pub struct EmbedContentRequest {
154 model: String,
155 content: EmbeddingContent,
156 task_type: TaskType,
157 title: String,
158 output_dimensionality: i32,
159 }
160
161 #[derive(Serialize)]
162 pub struct EmbeddingContent {
163 parts: Vec<EmbeddingContentPart>,
164 role: Option<String>,
167 }
168
169 #[derive(Serialize)]
173 pub struct EmbeddingContentPart {
174 text: String,
176 inline_data: Option<Blob>,
178 function_call: Option<FunctionCall>,
181 function_response: Option<FunctionResponse>,
184 file_data: Option<FileData>,
186 executable_code: Option<ExecutableCode>,
188 code_execution_result: Option<CodeExecutionResult>,
190 }
191
192 #[derive(Serialize)]
195 pub struct Blob {
196 data: String,
198 mime_type: String,
201 }
202
203 #[derive(Serialize)]
204 pub struct FunctionCall {
205 name: String,
207 args: Option<Value>,
209 }
210
211 #[derive(Serialize)]
212 pub struct FunctionResponse {
213 name: String,
215 result: Value,
217 }
218
219 #[derive(Serialize)]
220 #[serde(rename_all = "camelCase")]
221 pub struct FileData {
222 file_uri: String,
224 mime_type: String,
226 }
227
228 #[derive(Serialize)]
229 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
230 pub enum TaskType {
231 Unspecified,
233 RetrievalQuery,
235 RetrievalDocument,
237 SemanticSimilarity,
239 Classification,
241 Clustering,
243 QuestionAnswering,
245 FactVerification,
247 }
248
249 #[derive(Debug, Deserialize)]
250 pub struct EmbeddingResponse {
251 pub embeddings: Vec<EmbeddingValues>,
252 }
253
254 #[derive(Debug, Deserialize)]
255 pub struct EmbeddingValues {
256 pub values: Vec<serde_json::Number>,
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model through the trait's `make` constructor and reports its `ndims`.
    fn ndims_via_make(
        client: &<EmbeddingModel as embeddings::EmbeddingModel>::Client,
        model: &str,
        dims: Option<usize>,
    ) -> usize {
        let built = <EmbeddingModel as embeddings::EmbeddingModel>::make(client, model, dims);
        embeddings::EmbeddingModel::ndims(&built)
    }

    #[test]
    fn test_model_default_ndims_lookup() {
        let cases = [
            (EMBEDDING_001, Some(3072)),
            (EMBEDDING_004, Some(768)),
            ("unknown-model", None),
        ];
        for (name, expected) in cases {
            assert_eq!(model_default_ndims(name), expected);
        }
    }

    #[test]
    fn test_make_resolves_default_dims() {
        let client = Client::new("test_key").unwrap();

        assert_eq!(ndims_via_make(&client, EMBEDDING_001, None), 3072);
        assert_eq!(ndims_via_make(&client, EMBEDDING_004, None), 768);
        // Unknown models fall back to 768.
        assert_eq!(ndims_via_make(&client, "some-future-model", None), 768);
    }

    #[test]
    fn test_make_respects_explicit_dims() {
        let client = Client::new("test_key").unwrap();

        assert_eq!(ndims_via_make(&client, EMBEDDING_001, Some(256)), 256);
    }

    #[test]
    fn test_new_uses_provided_ndims() {
        let client = Client::new("test_key").unwrap();
        let built = EmbeddingModel::new(client, EMBEDDING_001, 512);

        assert_eq!(embeddings::EmbeddingModel::ndims(&built), 512);
    }
}