// gemini_rust/embedding/builder.rs

1use std::sync::Arc;
2use tracing::instrument;
3
4use super::model::{
5    BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6    EmbedContentRequest, TaskType,
7};
8use crate::{
9    client::{Error as ClientError, GeminiClient},
10    Content, Message,
11};
12
13/// Builder for embed generation requests
14#[derive(Clone)]
15pub struct EmbedBuilder {
16    client: Arc<GeminiClient>,
17    contents: Vec<Content>,
18    task_type: Option<TaskType>,
19    title: Option<String>,
20    output_dimensionality: Option<i32>,
21}
22
23impl EmbedBuilder {
24    /// Create a new embed builder
25    pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
26        Self {
27            client,
28            contents: Vec::new(),
29            task_type: None,
30            title: None,
31            output_dimensionality: None,
32        }
33    }
34
35    /// Add a vec of text to embed to the request
36    pub fn with_text(mut self, text: impl Into<String>) -> Self {
37        let message = Message::embed(text);
38        self.contents.push(message.content);
39        self
40    }
41
42    /// Add a vec of chunks to batch embed to the request
43    pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
44        //for each chunks
45        for chunk in chunks {
46            let message = Message::embed(chunk);
47            self.contents.push(message.content);
48        }
49        self
50    }
51
52    /// Specify embedding task type
53    pub fn with_task_type(mut self, task_type: TaskType) -> Self {
54        self.task_type = Some(task_type);
55        self
56    }
57
58    /// Specify document title
59    /// Supported by newer models since 2024 only !!
60    pub fn with_title(mut self, title: String) -> Self {
61        self.title = Some(title);
62        self
63    }
64
65    /// Specify output_dimensionality. If set, excessive values in the output embedding are truncated from the end
66    pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
67        self.output_dimensionality = Some(output_dimensionality);
68        self
69    }
70
71    /// Execute the request
72    #[instrument(skip_all, fields(
73        task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
74        title = self.title,
75        output.dimensionality = self.output_dimensionality
76    ))]
77    pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
78        let request = EmbedContentRequest {
79            model: self.client.model.clone(),
80            content: self.contents.first().expect("No content set").clone(),
81            task_type: self.task_type,
82            title: self.title,
83            output_dimensionality: self.output_dimensionality,
84        };
85
86        self.client.embed_content(request).await
87    }
88
89    /// Execute the request
90    #[instrument(skip_all, fields(
91        batch.size = self.contents.len(),
92        task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
93        title = self.title,
94        output.dimensionality = self.output_dimensionality
95    ))]
96    pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
97        let mut batch_request = BatchEmbedContentsRequest {
98            requests: Vec::new(),
99        };
100
101        for content in self.contents {
102            let request = EmbedContentRequest {
103                model: self.client.model.clone(),
104                content: content.clone(),
105                task_type: self.task_type.clone(),
106                title: self.title.clone(),
107                output_dimensionality: self.output_dimensionality,
108            };
109            batch_request.requests.push(request);
110        }
111
112        self.client.embed_content_batch(batch_request).await
113    }
114}