// gemini_rust/embedding/builder.rs

1use std::sync::Arc;
2
3use super::model::{
4    BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
5    EmbedContentRequest, TaskType,
6};
7use crate::{
8    client::{Error as ClientError, GeminiClient},
9    Content, Message,
10};
11
12/// Builder for embed generation requests
13pub struct EmbedBuilder {
14    client: Arc<GeminiClient>,
15    contents: Vec<Content>,
16    task_type: Option<TaskType>,
17    title: Option<String>,
18    output_dimensionality: Option<i32>,
19}
20
21impl EmbedBuilder {
22    /// Create a new embed builder
23    pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
24        Self {
25            client,
26            contents: Vec::new(),
27            task_type: None,
28            title: None,
29            output_dimensionality: None,
30        }
31    }
32
33    /// Add a vec of text to embed to the request
34    pub fn with_text(mut self, text: impl Into<String>) -> Self {
35        let message = Message::embed(text);
36        self.contents.push(message.content);
37        self
38    }
39
40    /// Add a vec of chunks to batch embed to the request
41    pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
42        //for each chunks
43        for chunk in chunks {
44            let message = Message::embed(chunk);
45            self.contents.push(message.content);
46        }
47        self
48    }
49
50    /// Specify embedding task type
51    pub fn with_task_type(mut self, task_type: TaskType) -> Self {
52        self.task_type = Some(task_type);
53        self
54    }
55
56    /// Specify document title
57    /// Supported by newer models since 2024 only !!
58    pub fn with_title(mut self, title: String) -> Self {
59        self.title = Some(title);
60        self
61    }
62
63    /// Specify output_dimensionality. If set, excessive values in the output embedding are truncated from the end
64    pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
65        self.output_dimensionality = Some(output_dimensionality);
66        self
67    }
68
69    /// Execute the request
70    pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
71        let request = EmbedContentRequest {
72            model: self.client.model.to_string(),
73            content: self.contents.first().expect("No content set").clone(),
74            task_type: self.task_type,
75            title: self.title,
76            output_dimensionality: self.output_dimensionality,
77        };
78
79        self.client.embed_content(request).await
80    }
81
82    /// Execute the request
83    pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
84        let mut batch_request = BatchEmbedContentsRequest {
85            requests: Vec::new(),
86        };
87
88        for content in self.contents {
89            let request = EmbedContentRequest {
90                model: self.client.model.to_string(),
91                content: content.clone(),
92                task_type: self.task_type.clone(),
93                title: self.title.clone(),
94                output_dimensionality: self.output_dimensionality,
95            };
96            batch_request.requests.push(request);
97        }
98
99        self.client.embed_content_batch(batch_request).await
100    }
101}