bep/providers/gemini/embedding.rs
// ================================================================
//! Google Gemini Embeddings Integration
//! Based on the [Gemini API Reference](https://ai.google.dev/api/embeddings)
// ================================================================

6use serde_json::json;
7
8use crate::embeddings::{self, EmbeddingError};
9
10use super::{client::ApiResponse, Client};
11
12/// `embedding-001` embedding model
13pub const EMBEDDING_001: &str = "embedding-001";
14/// `text-embedding-004` embedding model
15pub const EMBEDDING_004: &str = "text-embedding-004";
16#[derive(Clone)]
17pub struct EmbeddingModel {
18 client: Client,
19 model: String,
20 ndims: Option<usize>,
21}
22
23impl EmbeddingModel {
24 pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
25 Self {
26 client,
27 model: model.to_string(),
28 ndims,
29 }
30 }
31}
32
33impl embeddings::EmbeddingModel for EmbeddingModel {
34 const MAX_DOCUMENTS: usize = 1024;
35
36 fn ndims(&self) -> usize {
37 match self.model.as_str() {
38 EMBEDDING_001 => 768,
39 EMBEDDING_004 => 1024,
40 _ => 0, // Default to 0 for unknown models
41 }
42 }
43
44 async fn embed_texts(
45 &self,
46 documents: impl IntoIterator<Item = String> + Send,
47 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
48 let documents: Vec<_> = documents.into_iter().collect();
49 let mut request_body = json!({
50 "model": format!("models/{}", self.model),
51 "content": {
52 "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
53 },
54 });
55
56 if let Some(ndims) = self.ndims {
57 request_body["output_dimensionality"] = json!(ndims);
58 }
59
60 let response = self
61 .client
62 .post(&format!("/v1beta/models/{}:embedContent", self.model))
63 .json(&request_body)
64 .send()
65 .await?
66 .error_for_status()?
67 .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
68 .await?;
69
70 match response {
71 ApiResponse::Ok(response) => {
72 let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
73 Ok(documents
74 .into_iter()
75 .zip(response.embedding.values.chunks(chunk_size))
76 .map(|(document, embedding)| embeddings::Embedding {
77 document,
78 vec: embedding.to_vec(),
79 })
80 .collect())
81 }
82 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
83 }
84 }
85}
86
// =================================================================
// Gemini API Types
// =================================================================
90/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
91#[allow(dead_code)]
92mod gemini_api_types {
93 use serde::{Deserialize, Serialize};
94 use serde_json::Value;
95
96 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
97
98 #[derive(Serialize)]
99 #[serde(rename_all = "camelCase")]
100 pub struct EmbedContentRequest {
101 model: String,
102 content: EmbeddingContent,
103 task_type: TaskType,
104 title: String,
105 output_dimensionality: i32,
106 }
107
108 #[derive(Serialize)]
109 pub struct EmbeddingContent {
110 parts: Vec<EmbeddingContentPart>,
111 /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
112 /// conversations, otherwise can be left blank or unset.
113 role: Option<String>,
114 }
115
116 /// A datatype containing media that is part of a multi-part Content message.
117 /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
118 /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
119 #[derive(Serialize)]
120 pub struct EmbeddingContentPart {
121 /// Inline text.
122 text: String,
123 /// Inline media bytes.
124 inline_data: Option<Blob>,
125 /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
126 /// with the arguments and their values.
127 function_call: Option<FunctionCall>,
128 /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
129 /// JSON object containing any output from the function is used as context to the model.
130 function_response: Option<FunctionResponse>,
131 /// URI based data.
132 file_data: Option<FileData>,
133 /// Code generated by the model that is meant to be executed.
134 executable_code: Option<ExecutableCode>,
135 /// Result of executing the ExecutableCode.
136 code_execution_result: Option<CodeExecutionResult>,
137 }
138
139 /// Raw media bytes.
140 /// Text should not be sent as raw bytes, use the 'text' field.
141 #[derive(Serialize)]
142 pub struct Blob {
143 /// Raw bytes for media formats.A base64-encoded string.
144 data: String,
145 /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
146 /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
147 mime_type: String,
148 }
149
150 #[derive(Serialize)]
151 pub struct FunctionCall {
152 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
153 name: String,
154 /// The function parameters and values in JSON object format.
155 args: Option<Value>,
156 }
157
158 #[derive(Serialize)]
159 pub struct FunctionResponse {
160 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
161 name: String,
162 /// The result of the function call in JSON object format.
163 result: Value,
164 }
165
166 #[derive(Serialize)]
167 #[serde(rename_all = "camelCase")]
168 pub struct FileData {
169 /// The URI of the file.
170 file_uri: String,
171 /// The IANA standard MIME type of the source data.
172 mime_type: String,
173 }
174
175 #[derive(Serialize)]
176 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
177 pub enum TaskType {
178 /// Unset value, which will default to one of the other enum values.
179 Unspecified,
180 /// Specifies the given text is a query in a search/retrieval setting.
181 RetrievalQuery,
182 /// Specifies the given text is a document from the corpus being searched.
183 RetrievalDocument,
184 /// Specifies the given text will be used for STS.
185 SemanticSimilarity,
186 /// Specifies that the given text will be classified.
187 Classification,
188 /// Specifies that the embeddings will be used for clustering.
189 Clustering,
190 /// Specifies that the given text will be used for question answering.
191 QuestionAnswering,
192 /// Specifies that the given text will be used for fact verification.
193 FactVerification,
194 }
195
196 #[derive(Debug, Deserialize)]
197 pub struct EmbeddingResponse {
198 pub embedding: EmbeddingValues,
199 }
200
201 #[derive(Debug, Deserialize)]
202 pub struct EmbeddingValues {
203 pub values: Vec<f64>,
204 }
205}