gemini_rust/embedding/
builder.rs

1use std::sync::Arc;
2use tracing::instrument;
3
4use super::model::{
5    BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6    EmbedContentRequest, TaskType,
7};
8use crate::{
9    client::{Error as ClientError, GeminiClient},
10    Content, Message,
11};
12
13/// Builder for embed generation requests
14pub struct EmbedBuilder {
15    client: Arc<GeminiClient>,
16    contents: Vec<Content>,
17    task_type: Option<TaskType>,
18    title: Option<String>,
19    output_dimensionality: Option<i32>,
20}
21
22impl EmbedBuilder {
23    /// Create a new embed builder
24    pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
25        Self {
26            client,
27            contents: Vec::new(),
28            task_type: None,
29            title: None,
30            output_dimensionality: None,
31        }
32    }
33
34    /// Add a vec of text to embed to the request
35    pub fn with_text(mut self, text: impl Into<String>) -> Self {
36        let message = Message::embed(text);
37        self.contents.push(message.content);
38        self
39    }
40
41    /// Add a vec of chunks to batch embed to the request
42    pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
43        //for each chunks
44        for chunk in chunks {
45            let message = Message::embed(chunk);
46            self.contents.push(message.content);
47        }
48        self
49    }
50
51    /// Specify embedding task type
52    pub fn with_task_type(mut self, task_type: TaskType) -> Self {
53        self.task_type = Some(task_type);
54        self
55    }
56
57    /// Specify document title
58    /// Supported by newer models since 2024 only !!
59    pub fn with_title(mut self, title: String) -> Self {
60        self.title = Some(title);
61        self
62    }
63
64    /// Specify output_dimensionality. If set, excessive values in the output embedding are truncated from the end
65    pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
66        self.output_dimensionality = Some(output_dimensionality);
67        self
68    }
69
70    /// Execute the request
71    #[instrument(skip_all, fields(
72        task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
73        title = self.title,
74        output.dimensionality = self.output_dimensionality
75    ))]
76    pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
77        let request = EmbedContentRequest {
78            model: self.client.model.clone(),
79            content: self.contents.first().expect("No content set").clone(),
80            task_type: self.task_type,
81            title: self.title,
82            output_dimensionality: self.output_dimensionality,
83        };
84
85        self.client.embed_content(request).await
86    }
87
88    /// Execute the request
89    #[instrument(skip_all, fields(
90        batch.size = self.contents.len(),
91        task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
92        title = self.title,
93        output.dimensionality = self.output_dimensionality
94    ))]
95    pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
96        let mut batch_request = BatchEmbedContentsRequest {
97            requests: Vec::new(),
98        };
99
100        for content in self.contents {
101            let request = EmbedContentRequest {
102                model: self.client.model.clone(),
103                content: content.clone(),
104                task_type: self.task_type.clone(),
105                title: self.title.clone(),
106                output_dimensionality: self.output_dimensionality,
107            };
108            batch_request.requests.push(request);
109        }
110
111        self.client.embed_content_batch(batch_request).await
112    }
113}