google_ai_rs/embedding.rs
1use std::borrow::Cow;
2
3use tonic::IntoRequest;
4
5use crate::{
6 client::CClient,
7 content::{IntoContent, TryIntoContent},
8 error::status_into_error,
9 full_model_name,
10 proto::{BatchEmbedContentsResponse, Content, EmbedContentResponse, Model as Info, TaskType},
11};
12
13use super::{
14 client::Client,
15 error::{Error, ServiceError},
16 proto::{BatchEmbedContentsRequest, EmbedContentRequest},
17};
18
19/// A client for generating embeddings using Google's embedding service
20///
21/// Provides both single and batch embedding capabilities with configurable task types.
22///
23/// # Example
24/// ```
25/// use google_ai_rs::{Client, GenerativeModel};
26///
27/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
28/// # let auth = "YOUR-API-KEY";
29/// let client = Client::new(auth).await?;
30/// let embedding_model = client.embedding_model("embedding-001");
31///
32/// // Single embedding
33/// let embedding = embedding_model.embed_content("Hello world").await?;
34///
35/// // Batch embeddings
36/// let batch_response = embedding_model.new_batch()
37/// .add_content("First text")
38/// .add_content("Second text")
39/// .embed()
40/// .await?;
41/// # Ok(())
42/// # }
43/// ```
44#[derive(Debug)]
45pub struct Model<'c> {
46 /// Client for making API requests
47 client: CClient<'c>,
48 /// Fully qualified model name (e.g., "models/embedding-001")
49 name: Box<str>,
50 /// Optional task type specification for embedding generation
51 ///
52 /// Affects how embeddings are optimized:
53 /// - `None`: General purpose embeddings
54 /// - `TaskType::RetrievalDocument`: Optimized for document storage
55 /// - `TaskType::RetrievalQuery`: Optimized for query matching
56 pub task_type: Option<TaskType>,
57}
58
59impl<'c> Model<'c> {
60 /// Creates a new Model instance
61 ///
62 /// # Arguments
63 /// * `client` - Configured API client
64 /// * `name` - Model identifier (e.g., "embedding-001")
65 pub fn new(client: &'c Client, name: &str) -> Self {
66 Self::new_inner(client, name)
67 }
68
69 fn new_inner(client: impl Into<CClient<'c>>, name: &str) -> Self {
70 Self {
71 client: client.into(),
72 name: full_model_name(name).into(),
73 task_type: None,
74 }
75 }
76
77 /// Optional task type specification for embedding generation
78 ///
79 /// Affects how embeddings are optimized:
80 /// - `TaskType::RetrievalDocument`: Optimized for document storage
81 /// - `TaskType::RetrievalQuery`: Optimized for query matching
82 pub fn task_type(mut self, task_type: TaskType) -> Self {
83 self.task_type = Some(task_type);
84 self
85 }
86
87 /// Embeds content using the API's embedding service.
88 ///
89 /// Consider batch embedding for multiple contents
90 ///
91 /// # Example
92 /// ```
93 /// # use google_ai_rs::{Client, GenerativeModel};
94 /// # use google_ai_rs::Part;
95 /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
96 /// # let auth = "YOUR-API-KEY";
97 /// # let client = Client::new(auth).await?;
98 /// # let model = client.embedding_model("embedding-001");
99 /// // Single text embedding
100 /// let embedding = model.embed_content("Hello world").await?;
101 ///
102 /// # let image_data = vec![];
103 /// // Multi-modal embedding
104 /// model.embed_content((
105 /// "Query about this image",
106 /// Part::blob("image/jpeg", image_data)
107 /// )).await?;
108 /// # Ok(())
109 /// # }
110 /// ```
111 ///
112 /// # Errors
113 /// Returns `Error::Net` for transport-level errors or `Error::Service` for service errors
114 #[inline]
115 pub async fn embed_content<T: TryIntoContent>(
116 &self,
117 content: T,
118 ) -> Result<EmbedContentResponse, Error> {
119 self.embed_content_with_title("", content).await
120 }
121
122 /// Embeds content with optional title context
123 ///
124 /// # Arguments
125 /// * `title` - Optional document title for retrieval tasks
126 /// * `parts` - Content input that converts to parts
127 pub async fn embed_content_with_title<T>(
128 &self,
129 title: &str,
130 content: T,
131 ) -> Result<EmbedContentResponse, Error>
132 where
133 T: TryIntoContent,
134 {
135 let request = self
136 .build_request(title, content.try_into_content()?)
137 .await?;
138 self.client
139 .gc
140 .clone()
141 .embed_content(request)
142 .await
143 .map_err(status_into_error)
144 .map(|response| response.into_inner())
145 }
146
147 /// Creates a new batch embedding context
148 pub fn new_batch(&self) -> Batch<'_> {
149 Batch {
150 m: self,
151 req: BatchEmbedContentsRequest {
152 model: self.name.to_string(),
153 requests: Vec::new(),
154 },
155 }
156 }
157
158 /// Embeds multiple contents as separate content items
159 ///
160 /// # Example
161 /// ```
162 /// # use google_ai_rs::{Client, GenerativeModel};
163 /// # use google_ai_rs::Part;
164 /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
165 /// # let auth = "YOUR-API-KEY";
166 /// # let client = Client::new(auth).await?;
167 /// # let model = client.embedding_model("embedding-001");
168 /// let texts = vec!["First", "Second", "Third"];
169 /// let batch = model.embed_batch(texts).await?;
170 /// # Ok(())
171 /// # }
172 /// ```
173 pub async fn embed_batch<I, T>(&self, contents: I) -> Result<BatchEmbedContentsResponse, Error>
174 where
175 I: IntoIterator<Item = T>,
176 T: TryIntoContent,
177 {
178 let mut batch = self.new_batch();
179 for content in contents.into_iter() {
180 batch = batch.add_content(content.try_into_content()?);
181 }
182 batch.embed().await
183 }
184
185 /// returns information about the model.
186 pub async fn info(&self) -> Result<Info, Error> {
187 self.client.get_model(&self.name).await
188 }
189
190 #[inline(always)]
191 async fn build_request(
192 &self,
193 title: &str,
194 content: Content,
195 ) -> Result<tonic::Request<EmbedContentRequest>, Error> {
196 let request = self._build_request(title, content).into_request();
197 Ok(request)
198 }
199
200 fn _build_request(&self, title: &str, content: Content) -> EmbedContentRequest {
201 let title = if title.is_empty() {
202 None
203 } else {
204 Some(title.to_owned())
205 };
206
207 // A non-empty title overrides the task type.
208 let task_type = title
209 .as_ref()
210 .map(|_| TaskType::RetrievalDocument.into())
211 .or(self.task_type.map(Into::into));
212
213 EmbedContentRequest {
214 model: self.name.to_string(),
215 content: Some(content),
216 task_type,
217 title,
218 output_dimensionality: None,
219 }
220 }
221}
222
223/// Builder for batch embedding requests
224///
225/// Collects multiple embedding requests for efficient batch processing.
226///
227/// # Example
228/// ```
229/// # use google_ai_rs::{Client, GenerativeModel};
230/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
231/// # let auth = "YOUR-API-KEY";
232/// # let client = Client::new(auth).await?;
233/// # let embedding_model = client.embedding_model("embedding-001");
234/// let batch = embedding_model.new_batch()
235/// .add_content_with_title("Document 1", "Full text content...")
236/// .add_content_with_title("Document 2", "Another text...");
237/// # Ok(())
238/// # }
239/// ```
240#[derive(Debug)]
241pub struct Batch<'m> {
242 m: &'m Model<'m>,
243 req: BatchEmbedContentsRequest,
244}
245
246impl Batch<'_> {
247 /// Adds content to the batch
248 #[inline]
249 pub fn add_content<T: IntoContent>(self, content: T) -> Self {
250 self.add_content_with_title("", content)
251 }
252
253 /// Adds content with title to the batch
254 ///
255 /// # Argument
256 /// * `title` - Document title for retrieval context
257 pub fn add_content_with_title<T: IntoContent>(mut self, title: &str, content: T) -> Self {
258 self.req
259 .requests
260 .push(self.m._build_request(title, content.into_content()));
261 self
262 }
263
264 /// Executes the batch embedding request
265 pub async fn embed(self) -> Result<BatchEmbedContentsResponse, Error> {
266 let expected = self.req.requests.len();
267 let request = self.req.into_request();
268
269 let response = self
270 .m
271 .client
272 .gc
273 .clone()
274 .batch_embed_contents(request)
275 .await
276 .map_err(status_into_error)
277 .map(|response| response.into_inner())?;
278
279 if response.embeddings.len() != expected {
280 return Err(Error::Service(ServiceError::InvalidResponse(
281 format!(
282 "Expected {} embeddings, got {}",
283 expected,
284 response.embeddings.len()
285 )
286 .into(),
287 )));
288 }
289
290 Ok(response)
291 }
292}
293
294impl Client {
295 /// Creates a new embedding model interface
296 ///
297 /// Shorthand for `EmbeddingModel::new()`
298 pub fn embedding_model<'c>(&'c self, name: &str) -> Model<'c> {
299 Model::new(self, name)
300 }
301}