// google_ai_rs/embedding.rs

1use std::borrow::Cow;
2
3use tonic::IntoRequest;
4
5use crate::{
6    client::CClient,
7    content::{IntoContent, TryIntoContent},
8    error::status_into_error,
9    full_model_name,
10    proto::{BatchEmbedContentsResponse, Content, EmbedContentResponse, Model as Info, TaskType},
11};
12
13use super::{
14    client::Client,
15    error::{Error, ServiceError},
16    proto::{BatchEmbedContentsRequest, EmbedContentRequest},
17};
18
19/// A client for generating embeddings using Google's embedding service
20///
21/// Provides both single and batch embedding capabilities with configurable task types.
22///
23/// # Example
24/// ```
25/// use google_ai_rs::{Client, GenerativeModel};
26///
27/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
28/// # let auth = "YOUR-API-KEY";
29/// let client = Client::new(auth).await?;
30/// let embedding_model = client.embedding_model("embedding-001");
31///
32/// // Single embedding
33/// let embedding = embedding_model.embed_content("Hello world").await?;
34///
35/// // Batch embeddings
36/// let batch_response = embedding_model.new_batch()
37///     .add_content("First text")
38///     .add_content("Second text")
39///     .embed()
40///     .await?;
41/// # Ok(())
42/// # }
43/// ```
44#[derive(Debug)]
45pub struct Model<'c> {
46    /// Client for making API requests
47    client: CClient<'c>,
48    /// Fully qualified model name (e.g., "models/embedding-001")
49    name: Box<str>,
50    /// Optional task type specification for embedding generation
51    ///
52    /// Affects how embeddings are optimized:
53    /// - `None`: General purpose embeddings
54    /// - `TaskType::RetrievalDocument`: Optimized for document storage
55    /// - `TaskType::RetrievalQuery`: Optimized for query matching
56    pub task_type: Option<TaskType>,
57}
58
59impl<'c> Model<'c> {
60    /// Creates a new Model instance
61    ///
62    /// # Arguments
63    /// * `client` - Configured API client
64    /// * `name` - Model identifier (e.g., "embedding-001")
65    pub fn new(client: &'c Client, name: &str) -> Self {
66        Self::new_inner(client, name)
67    }
68
69    fn new_inner(client: impl Into<CClient<'c>>, name: &str) -> Self {
70        Self {
71            client: client.into(),
72            name: full_model_name(name).into(),
73            task_type: None,
74        }
75    }
76
77    /// Optional task type specification for embedding generation
78    ///
79    /// Affects how embeddings are optimized:
80    /// - `TaskType::RetrievalDocument`: Optimized for document storage
81    /// - `TaskType::RetrievalQuery`: Optimized for query matching
82    pub fn task_type(mut self, task_type: TaskType) -> Self {
83        self.task_type = Some(task_type);
84        self
85    }
86
87    /// Embeds content using the API's embedding service.
88    ///
89    /// Consider batch embedding for multiple contents
90    ///
91    /// # Example
92    /// ```
93    /// # use google_ai_rs::{Client, GenerativeModel};
94    /// # use google_ai_rs::Part;
95    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
96    /// # let auth = "YOUR-API-KEY";
97    /// # let client = Client::new(auth).await?;
98    /// # let model = client.embedding_model("embedding-001");
99    /// // Single text embedding
100    /// let embedding = model.embed_content("Hello world").await?;
101    ///
102    /// # let image_data = vec![];
103    /// // Multi-modal embedding
104    /// model.embed_content((
105    ///     "Query about this image",
106    ///     Part::blob("image/jpeg", image_data)
107    /// )).await?;
108    /// # Ok(())
109    /// # }
110    /// ```
111    ///
112    /// # Errors
113    /// Returns `Error::Net` for transport-level errors or `Error::Service` for service errors
114    #[inline]
115    pub async fn embed_content<T: TryIntoContent>(
116        &self,
117        content: T,
118    ) -> Result<EmbedContentResponse, Error> {
119        self.embed_content_with_title("", content).await
120    }
121
122    /// Embeds content with optional title context
123    ///
124    /// # Arguments
125    /// * `title` - Optional document title for retrieval tasks
126    /// * `parts` - Content input that converts to parts
127    pub async fn embed_content_with_title<T>(
128        &self,
129        title: &str,
130        content: T,
131    ) -> Result<EmbedContentResponse, Error>
132    where
133        T: TryIntoContent,
134    {
135        let request = self
136            .build_request(title, content.try_into_content()?)
137            .await?;
138        self.client
139            .gc
140            .clone()
141            .embed_content(request)
142            .await
143            .map_err(status_into_error)
144            .map(|response| response.into_inner())
145    }
146
147    /// Creates a new batch embedding context
148    pub fn new_batch(&self) -> Batch<'_> {
149        Batch {
150            m: self,
151            req: BatchEmbedContentsRequest {
152                model: self.name.to_string(),
153                requests: Vec::new(),
154            },
155        }
156    }
157
158    /// Embeds multiple contents as separate content items
159    ///
160    /// # Example
161    /// ```
162    /// # use google_ai_rs::{Client, GenerativeModel};
163    /// # use google_ai_rs::Part;
164    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
165    /// # let auth = "YOUR-API-KEY";
166    /// # let client = Client::new(auth).await?;
167    /// # let model = client.embedding_model("embedding-001");
168    /// let texts = vec!["First", "Second", "Third"];
169    /// let batch = model.embed_batch(texts).await?;
170    /// # Ok(())
171    /// # }
172    /// ```
173    pub async fn embed_batch<I, T>(&self, contents: I) -> Result<BatchEmbedContentsResponse, Error>
174    where
175        I: IntoIterator<Item = T>,
176        T: TryIntoContent,
177    {
178        let mut batch = self.new_batch();
179        for content in contents.into_iter() {
180            batch = batch.add_content(content.try_into_content()?);
181        }
182        batch.embed().await
183    }
184
185    /// returns information about the model.
186    pub async fn info(&self) -> Result<Info, Error> {
187        self.client.get_model(&self.name).await
188    }
189
190    #[inline(always)]
191    async fn build_request(
192        &self,
193        title: &str,
194        content: Content,
195    ) -> Result<tonic::Request<EmbedContentRequest>, Error> {
196        let request = self._build_request(title, content).into_request();
197        Ok(request)
198    }
199
200    fn _build_request(&self, title: &str, content: Content) -> EmbedContentRequest {
201        let title = if title.is_empty() {
202            None
203        } else {
204            Some(title.to_owned())
205        };
206
207        // A non-empty title overrides the task type.
208        let task_type = title
209            .as_ref()
210            .map(|_| TaskType::RetrievalDocument.into())
211            .or(self.task_type.map(Into::into));
212
213        EmbedContentRequest {
214            model: self.name.to_string(),
215            content: Some(content),
216            task_type,
217            title,
218            output_dimensionality: None,
219        }
220    }
221}
222
223/// Builder for batch embedding requests
224///
225/// Collects multiple embedding requests for efficient batch processing.
226///
227/// # Example
228/// ```
229/// # use google_ai_rs::{Client, GenerativeModel};
230/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
231/// # let auth = "YOUR-API-KEY";
232/// # let client = Client::new(auth).await?;
233/// # let embedding_model = client.embedding_model("embedding-001");
234/// let batch = embedding_model.new_batch()
235///     .add_content_with_title("Document 1", "Full text content...")
236///     .add_content_with_title("Document 2", "Another text...");
237/// # Ok(())
238/// # }
239/// ```
240#[derive(Debug)]
241pub struct Batch<'m> {
242    m: &'m Model<'m>,
243    req: BatchEmbedContentsRequest,
244}
245
246impl Batch<'_> {
247    /// Adds content to the batch
248    #[inline]
249    pub fn add_content<T: IntoContent>(self, content: T) -> Self {
250        self.add_content_with_title("", content)
251    }
252
253    /// Adds content with title to the batch
254    ///
255    /// # Argument
256    /// * `title` - Document title for retrieval context
257    pub fn add_content_with_title<T: IntoContent>(mut self, title: &str, content: T) -> Self {
258        self.req
259            .requests
260            .push(self.m._build_request(title, content.into_content()));
261        self
262    }
263
264    /// Executes the batch embedding request
265    pub async fn embed(self) -> Result<BatchEmbedContentsResponse, Error> {
266        let expected = self.req.requests.len();
267        let request = self.req.into_request();
268
269        let response = self
270            .m
271            .client
272            .gc
273            .clone()
274            .batch_embed_contents(request)
275            .await
276            .map_err(status_into_error)
277            .map(|response| response.into_inner())?;
278
279        if response.embeddings.len() != expected {
280            return Err(Error::Service(ServiceError::InvalidResponse(
281                format!(
282                    "Expected {} embeddings, got {}",
283                    expected,
284                    response.embeddings.len()
285                )
286                .into(),
287            )));
288        }
289
290        Ok(response)
291    }
292}
293
294impl Client {
295    /// Creates a new embedding model interface
296    ///
297    /// Shorthand for `EmbeddingModel::new()`
298    pub fn embedding_model<'c>(&'c self, name: &str) -> Model<'c> {
299        Model::new(self, name)
300    }
301}