rig/providers/gemini/
embedding.rs1use serde_json::json;
7
8use super::{Client, client::ApiResponse};
9use crate::{
10 embeddings::{self, EmbeddingError},
11 http_client::HttpClientExt,
12 wasm_compat::WasmCompatSend,
13};
14
/// Model identifier string `embedding-001`.
pub const EMBEDDING_001: &str = "embedding-001";
/// Model identifier string `text-embedding-004`.
pub const EMBEDDING_004: &str = "text-embedding-004";
19
/// Handle used to request embeddings for a single Gemini model.
///
/// `T` is the HTTP client type wrapped by [`Client`]; it defaults to
/// `reqwest::Client`.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Provider client used to build and send the embedding requests.
    client: Client<T>,
    /// Model name; it is interpolated into the request path and into the
    /// `models/{model}` value of each request.
    model: String,
    /// Optional output dimensionality, forwarded with every request as
    /// `output_dimensionality`.
    ndims: Option<usize>,
}
26
27impl<T> EmbeddingModel<T> {
28 pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
29 Self {
30 client,
31 model: model.into(),
32 ndims,
33 }
34 }
35
36 pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
37 Self {
38 client,
39 model: model.to_string(),
40 ndims,
41 }
42 }
43}
44
45impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
46where
47 T: Clone + HttpClientExt,
48{
49 type Client = Client<T>;
50
51 const MAX_DOCUMENTS: usize = 1024;
52
53 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
54 Self::new(client.clone(), model, dims)
55 }
56
57 fn ndims(&self) -> usize {
58 768
59 }
60
61 #[cfg_attr(feature = "worker", worker::send)]
63 async fn embed_texts(
64 &self,
65 documents: impl IntoIterator<Item = String> + WasmCompatSend,
66 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
67 let documents: Vec<String> = documents.into_iter().collect();
68
69 let requests: Vec<_> = documents
71 .iter()
72 .map(|doc| {
73 json!({
74 "model": format!("models/{}", self.model),
75 "content": json!({
76 "parts": [json!({
77 "text": doc.to_string()
78 })]
79 }),
80 "output_dimensionality": self.ndims,
81 })
82 })
83 .collect();
84
85 let request_body = json!({ "requests": requests });
86
87 tracing::trace!(
88 target: "rig::embedding",
89 "Sending embedding request to Gemini API {}",
90 serde_json::to_string_pretty(&request_body).unwrap()
91 );
92
93 let request_body = serde_json::to_vec(&request_body)?;
94 let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
95 let req = self
96 .client
97 .post(path.as_str())?
98 .body(request_body)
99 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
100 let response = self.client.send::<_, Vec<u8>>(req).await?;
101
102 let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
103 serde_json::from_slice(&response.into_body().await?)?;
104
105 match response {
106 ApiResponse::Ok(response) => {
107 let docs = documents
108 .into_iter()
109 .zip(response.embeddings)
110 .map(|(document, embedding)| embeddings::Embedding {
111 document,
112 vec: embedding.values,
113 })
114 .collect();
115
116 Ok(docs)
117 }
118 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
119 }
120 }
121}
122
/// Serde payload types for the Gemini embedding endpoint.
///
/// `embed_texts` deserializes [`EmbeddingResponse`]; the request-side types
/// are not constructed by the code visible in this file (the request body is
/// assembled with `json!`), which is presumably why `dead_code` is allowed.
#[allow(dead_code)]
mod gemini_api_types {
    use serde::{Deserialize, Serialize};
    use serde_json::Value;

    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};

    /// Serializable embedding request; field names are emitted in camelCase.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct EmbedContentRequest {
        model: String,
        content: EmbeddingContent,
        task_type: TaskType,
        title: String,
        output_dimensionality: i32,
    }

    /// Content of an embedding request: a list of parts plus an optional role.
    #[derive(Serialize)]
    pub struct EmbeddingContent {
        parts: Vec<EmbeddingContentPart>,
        // No `skip_serializing_if`, so `None` is serialized as `null`.
        role: Option<String>,
    }

    /// One part of an [`EmbeddingContent`]: text plus optional payloads.
    // NOTE(review): unlike `FileData`, this struct has no
    // `rename_all = "camelCase"`, so `inline_data`, `function_call`, etc. are
    // serialized in snake_case — confirm that is intended.
    #[derive(Serialize)]
    pub struct EmbeddingContentPart {
        text: String,
        inline_data: Option<Blob>,
        function_call: Option<FunctionCall>,
        function_response: Option<FunctionResponse>,
        file_data: Option<FileData>,
        executable_code: Option<ExecutableCode>,
        code_execution_result: Option<CodeExecutionResult>,
    }

    /// Inline data together with its MIME type.
    // NOTE(review): `mime_type` is serialized in snake_case here but in
    // camelCase in `FileData` — confirm which form the API expects.
    #[derive(Serialize)]
    pub struct Blob {
        // presumably an encoded payload (e.g. base64) — TODO confirm
        data: String,
        mime_type: String,
    }

    /// A named function call with optional JSON arguments.
    #[derive(Serialize)]
    pub struct FunctionCall {
        name: String,
        args: Option<Value>,
    }

    /// The JSON result of a named function call.
    #[derive(Serialize)]
    pub struct FunctionResponse {
        name: String,
        result: Value,
    }

    /// Reference to file data by URI; field names are emitted in camelCase.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        file_uri: String,
        mime_type: String,
    }

    /// Task type of an embedding request.
    ///
    /// Variants are serialized in SCREAMING_SNAKE_CASE, e.g. `RetrievalQuery`
    /// becomes `RETRIEVAL_QUERY`.
    // NOTE(review): `Unspecified` serializes as `UNSPECIFIED`; verify this
    // matches the wire name the API expects for its unspecified task type.
    #[derive(Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TaskType {
        Unspecified,
        RetrievalQuery,
        RetrievalDocument,
        SemanticSimilarity,
        Classification,
        Clustering,
        QuestionAnswering,
        FactVerification,
    }

    /// Successful response body of the batch embedding call.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingResponse {
        /// Returned embeddings; `embed_texts` pairs them with the input
        /// documents by position.
        pub embeddings: Vec<EmbeddingValues>,
    }

    /// A single embedding vector.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingValues {
        /// Vector components; `embed_texts` moves them into the `vec` field of
        /// the resulting `Embedding`.
        pub values: Vec<f64>,
    }
}