1use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10 embeddings::{self, EmbeddingError},
11 http_client::HttpClientExt,
12 wasm_compat::WasmCompatSend,
13};
14
/// Model identifier for the `gemini-embedding-001` embedding model
/// (3072 output dimensions by default, see `model_default_ndims`).
pub const EMBEDDING_001: &str = "gemini-embedding-001";
/// Model identifier for the `text-embedding-004` embedding model
/// (768 output dimensions by default, see `model_default_ndims`).
pub const EMBEDDING_004: &str = "text-embedding-004";

/// Looks up the built-in default output dimensionality for `model`.
///
/// Returns `None` for any model name other than `EMBEDDING_001` / `EMBEDDING_004`,
/// leaving the fallback decision to the caller.
fn model_default_ndims(model: &str) -> Option<usize> {
    let ndims = match model {
        EMBEDDING_001 => 3072,
        EMBEDDING_004 => 768,
        _ => return None,
    };
    Some(ndims)
}
30
/// Gemini embedding model handle: an API client, the model name, and the output
/// dimensionality requested for every embedding it produces.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Client used to send the `batchEmbedContents` request.
    client: Client<T>,
    /// Bare model name; the `models/` prefix is added when requests are built.
    model: String,
    /// Dimensionality sent as `output_dimensionality` and reported by `ndims()`.
    ndims: usize,
}
37
38impl<T> EmbeddingModel<T> {
39 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
40 Self {
41 client,
42 model: model.into(),
43 ndims,
44 }
45 }
46
47 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
48 Self {
49 client,
50 model: model.to_string(),
51 ndims,
52 }
53 }
54}
55
56impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
57where
58 T: Clone + HttpClientExt + 'static,
59{
60 type Client = Client<T>;
61
62 const MAX_DOCUMENTS: usize = 1024;
63
64 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
65 let model = model.into();
66 let ndims = dims.or_else(|| model_default_ndims(&model)).unwrap_or(768);
67 Self::new(client.clone(), model, ndims)
68 }
69
70 fn ndims(&self) -> usize {
71 self.ndims
72 }
73
74 async fn embed_texts(
76 &self,
77 documents: impl IntoIterator<Item = String> + WasmCompatSend,
78 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
79 let documents: Vec<String> = documents.into_iter().collect();
80
81 let requests: Vec<_> = documents
83 .iter()
84 .map(|doc| {
85 json!({
86 "model": format!("models/{}", self.model),
87 "content": json!({
88 "parts": [json!({
89 "text": doc.to_string()
90 })]
91 }),
92 "output_dimensionality": self.ndims,
93 })
94 })
95 .collect();
96
97 let request_body = json!({ "requests": requests });
98
99 tracing::trace!(
100 target: "rig::embedding",
101 "Sending embedding request to Gemini API {}",
102 serde_json::to_string_pretty(&request_body).unwrap()
103 );
104
105 let request_body = serde_json::to_vec(&request_body)?;
106 let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
107 let req = self
108 .client
109 .post(path.as_str())?
110 .body(request_body)
111 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
112 let response = self.client.send::<_, Vec<u8>>(req).await?;
113
114 let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
115 serde_json::from_slice(&response.into_body().await?)?;
116
117 match response {
118 ApiResponse::Ok(response) => {
119 let docs = documents
120 .into_iter()
121 .zip(response.embeddings)
122 .map(|(document, embedding)| embeddings::Embedding {
123 document,
124 vec: embedding
125 .values
126 .into_iter()
127 .filter_map(|n| n.as_f64())
128 .collect(),
129 })
130 .collect();
131
132 Ok(docs)
133 }
134 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
135 }
136 }
137}
138
139#[allow(dead_code)]
144mod gemini_api_types {
145 use serde::{Deserialize, Serialize};
146 use serde_json::Value;
147
148 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
149
150 #[derive(Serialize)]
151 #[serde(rename_all = "camelCase")]
152 pub struct EmbedContentRequest {
153 model: String,
154 content: EmbeddingContent,
155 task_type: TaskType,
156 title: String,
157 output_dimensionality: i32,
158 }
159
160 #[derive(Serialize)]
161 pub struct EmbeddingContent {
162 parts: Vec<EmbeddingContentPart>,
163 role: Option<String>,
166 }
167
168 #[derive(Serialize)]
172 pub struct EmbeddingContentPart {
173 text: String,
175 inline_data: Option<Blob>,
177 function_call: Option<FunctionCall>,
180 function_response: Option<FunctionResponse>,
183 file_data: Option<FileData>,
185 executable_code: Option<ExecutableCode>,
187 code_execution_result: Option<CodeExecutionResult>,
189 }
190
191 #[derive(Serialize)]
194 pub struct Blob {
195 data: String,
197 mime_type: String,
200 }
201
202 #[derive(Serialize)]
203 pub struct FunctionCall {
204 name: String,
206 args: Option<Value>,
208 }
209
210 #[derive(Serialize)]
211 pub struct FunctionResponse {
212 name: String,
214 result: Value,
216 }
217
218 #[derive(Serialize)]
219 #[serde(rename_all = "camelCase")]
220 pub struct FileData {
221 file_uri: String,
223 mime_type: String,
225 }
226
227 #[derive(Serialize)]
228 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
229 pub enum TaskType {
230 Unspecified,
232 RetrievalQuery,
234 RetrievalDocument,
236 SemanticSimilarity,
238 Classification,
240 Clustering,
242 QuestionAnswering,
244 FactVerification,
246 }
247
248 #[derive(Debug, Deserialize)]
249 pub struct EmbeddingResponse {
250 pub embeddings: Vec<EmbeddingValues>,
251 }
252
253 #[derive(Debug, Deserialize)]
254 pub struct EmbeddingValues {
255 pub values: Vec<serde_json::Number>,
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model through the trait's `make` constructor and returns the
    /// dimensionality it reports.
    fn ndims_via_make(model: &str, dims: Option<usize>) -> usize {
        let client = Client::new("test_key").unwrap();
        let built = <EmbeddingModel as embeddings::EmbeddingModel>::make(&client, model, dims);
        embeddings::EmbeddingModel::ndims(&built)
    }

    #[test]
    fn test_model_default_ndims_lookup() {
        let cases = [
            (EMBEDDING_001, Some(3072)),
            (EMBEDDING_004, Some(768)),
            ("unknown-model", None),
        ];
        for (model, expected) in cases {
            assert_eq!(model_default_ndims(model), expected, "model = {model}");
        }
    }

    #[test]
    fn test_make_resolves_default_dims() {
        assert_eq!(ndims_via_make(EMBEDDING_001, None), 3072);
        assert_eq!(ndims_via_make(EMBEDDING_004, None), 768);
        // Unrecognised models fall back to 768.
        assert_eq!(ndims_via_make("some-future-model", None), 768);
    }

    #[test]
    fn test_make_respects_explicit_dims() {
        assert_eq!(ndims_via_make(EMBEDDING_001, Some(256)), 256);
    }

    #[test]
    fn test_new_uses_provided_ndims() {
        let client = Client::new("test_key").unwrap();
        let model = EmbeddingModel::new(client, EMBEDDING_001, 512);
        assert_eq!(embeddings::EmbeddingModel::ndims(&model), 512);
    }
}