gemini_rust/embedding/
builder.rs1use std::sync::Arc;
2use tracing::instrument;
3
4use super::model::{
5 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6 EmbedContentRequest, TaskType,
7};
8use crate::{
9 client::{Error as ClientError, GeminiClient},
10 Content, Message,
11};
12
13pub struct EmbedBuilder {
15 client: Arc<GeminiClient>,
16 contents: Vec<Content>,
17 task_type: Option<TaskType>,
18 title: Option<String>,
19 output_dimensionality: Option<i32>,
20}
21
22impl EmbedBuilder {
23 pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
25 Self {
26 client,
27 contents: Vec::new(),
28 task_type: None,
29 title: None,
30 output_dimensionality: None,
31 }
32 }
33
34 pub fn with_text(mut self, text: impl Into<String>) -> Self {
36 let message = Message::embed(text);
37 self.contents.push(message.content);
38 self
39 }
40
41 pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
43 for chunk in chunks {
45 let message = Message::embed(chunk);
46 self.contents.push(message.content);
47 }
48 self
49 }
50
51 pub fn with_task_type(mut self, task_type: TaskType) -> Self {
53 self.task_type = Some(task_type);
54 self
55 }
56
57 pub fn with_title(mut self, title: String) -> Self {
60 self.title = Some(title);
61 self
62 }
63
64 pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
66 self.output_dimensionality = Some(output_dimensionality);
67 self
68 }
69
70 #[instrument(skip_all, fields(
72 task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
73 title = self.title,
74 output.dimensionality = self.output_dimensionality
75 ))]
76 pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
77 let request = EmbedContentRequest {
78 model: self.client.model.clone(),
79 content: self.contents.first().expect("No content set").clone(),
80 task_type: self.task_type,
81 title: self.title,
82 output_dimensionality: self.output_dimensionality,
83 };
84
85 self.client.embed_content(request).await
86 }
87
88 #[instrument(skip_all, fields(
90 batch.size = self.contents.len(),
91 task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
92 title = self.title,
93 output.dimensionality = self.output_dimensionality
94 ))]
95 pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
96 let mut batch_request = BatchEmbedContentsRequest {
97 requests: Vec::new(),
98 };
99
100 for content in self.contents {
101 let request = EmbedContentRequest {
102 model: self.client.model.clone(),
103 content: content.clone(),
104 task_type: self.task_type.clone(),
105 title: self.title.clone(),
106 output_dimensionality: self.output_dimensionality,
107 };
108 batch_request.requests.push(request);
109 }
110
111 self.client.embed_content_batch(batch_request).await
112 }
113}