rig/providers/gemini/embedding.rs
1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use crate::embeddings::{self, EmbeddingError};
9
10use super::{client::ApiResponse, Client};
11
/// Model identifier for Gemini's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";
/// Model identifier for Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
16#[derive(Clone)]
17pub struct EmbeddingModel {
18 client: Client,
19 model: String,
20 ndims: Option<usize>,
21}
22
23impl EmbeddingModel {
24 pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
25 Self {
26 client,
27 model: model.to_string(),
28 ndims,
29 }
30 }
31}
32
33impl embeddings::EmbeddingModel for EmbeddingModel {
34 const MAX_DOCUMENTS: usize = 1024;
35
36 fn ndims(&self) -> usize {
37 match self.model.as_str() {
38 EMBEDDING_001 => 768,
39 EMBEDDING_004 => 1024,
40 _ => 0, // Default to 0 for unknown models
41 }
42 }
43
44 #[cfg_attr(feature = "worker", worker::send)]
45 async fn embed_texts(
46 &self,
47 documents: impl IntoIterator<Item = String> + Send,
48 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
49 let documents: Vec<_> = documents.into_iter().collect();
50 let mut request_body = json!({
51 "model": format!("models/{}", self.model),
52 "content": {
53 "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
54 },
55 });
56
57 if let Some(ndims) = self.ndims {
58 request_body["output_dimensionality"] = json!(ndims);
59 }
60
61 let response = self
62 .client
63 .post(&format!("/v1beta/models/{}:embedContent", self.model))
64 .json(&request_body)
65 .send()
66 .await?
67 .error_for_status()?
68 .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
69 .await?;
70
71 match response {
72 ApiResponse::Ok(response) => {
73 let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
74 Ok(documents
75 .into_iter()
76 .zip(response.embedding.values.chunks(chunk_size))
77 .map(|(document, embedding)| embeddings::Embedding {
78 document,
79 vec: embedding.to_vec(),
80 })
81 .collect())
82 }
83 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
84 }
85 }
86}
87
88// =================================================================
89// Gemini API Types
90// =================================================================
91/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
92#[allow(dead_code)]
93mod gemini_api_types {
94 use serde::{Deserialize, Serialize};
95 use serde_json::Value;
96
97 use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
98
99 #[derive(Serialize)]
100 #[serde(rename_all = "camelCase")]
101 pub struct EmbedContentRequest {
102 model: String,
103 content: EmbeddingContent,
104 task_type: TaskType,
105 title: String,
106 output_dimensionality: i32,
107 }
108
109 #[derive(Serialize)]
110 pub struct EmbeddingContent {
111 parts: Vec<EmbeddingContentPart>,
112 /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
113 /// conversations, otherwise can be left blank or unset.
114 role: Option<String>,
115 }
116
117 /// A datatype containing media that is part of a multi-part Content message.
118 /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
119 /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
120 #[derive(Serialize)]
121 pub struct EmbeddingContentPart {
122 /// Inline text.
123 text: String,
124 /// Inline media bytes.
125 inline_data: Option<Blob>,
126 /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
127 /// with the arguments and their values.
128 function_call: Option<FunctionCall>,
129 /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
130 /// JSON object containing any output from the function is used as context to the model.
131 function_response: Option<FunctionResponse>,
132 /// URI based data.
133 file_data: Option<FileData>,
134 /// Code generated by the model that is meant to be executed.
135 executable_code: Option<ExecutableCode>,
136 /// Result of executing the ExecutableCode.
137 code_execution_result: Option<CodeExecutionResult>,
138 }
139
140 /// Raw media bytes.
141 /// Text should not be sent as raw bytes, use the 'text' field.
142 #[derive(Serialize)]
143 pub struct Blob {
144 /// Raw bytes for media formats.A base64-encoded string.
145 data: String,
146 /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
147 /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
148 mime_type: String,
149 }
150
151 #[derive(Serialize)]
152 pub struct FunctionCall {
153 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
154 name: String,
155 /// The function parameters and values in JSON object format.
156 args: Option<Value>,
157 }
158
159 #[derive(Serialize)]
160 pub struct FunctionResponse {
161 /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
162 name: String,
163 /// The result of the function call in JSON object format.
164 result: Value,
165 }
166
167 #[derive(Serialize)]
168 #[serde(rename_all = "camelCase")]
169 pub struct FileData {
170 /// The URI of the file.
171 file_uri: String,
172 /// The IANA standard MIME type of the source data.
173 mime_type: String,
174 }
175
176 #[derive(Serialize)]
177 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
178 pub enum TaskType {
179 /// Unset value, which will default to one of the other enum values.
180 Unspecified,
181 /// Specifies the given text is a query in a search/retrieval setting.
182 RetrievalQuery,
183 /// Specifies the given text is a document from the corpus being searched.
184 RetrievalDocument,
185 /// Specifies the given text will be used for STS.
186 SemanticSimilarity,
187 /// Specifies that the given text will be classified.
188 Classification,
189 /// Specifies that the embeddings will be used for clustering.
190 Clustering,
191 /// Specifies that the given text will be used for question answering.
192 QuestionAnswering,
193 /// Specifies that the given text will be used for fact verification.
194 FactVerification,
195 }
196
197 #[derive(Debug, Deserialize)]
198 pub struct EmbeddingResponse {
199 pub embedding: EmbeddingValues,
200 }
201
202 #[derive(Debug, Deserialize)]
203 pub struct EmbeddingValues {
204 pub values: Vec<f64>,
205 }
206}