gemini_rust/embedding/
builder.rs1use std::sync::Arc;
2use tracing::instrument;
3
4use super::model::{
5 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6 EmbedContentRequest, TaskType,
7};
8use crate::{
9 client::{Error as ClientError, GeminiClient},
10 Content, Message,
11};
12
13#[derive(Clone)]
15pub struct EmbedBuilder {
16 client: Arc<GeminiClient>,
17 contents: Vec<Content>,
18 task_type: Option<TaskType>,
19 title: Option<String>,
20 output_dimensionality: Option<i32>,
21}
22
23impl EmbedBuilder {
24 pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
26 Self {
27 client,
28 contents: Vec::new(),
29 task_type: None,
30 title: None,
31 output_dimensionality: None,
32 }
33 }
34
35 pub fn with_text(mut self, text: impl Into<String>) -> Self {
37 let message = Message::embed(text);
38 self.contents.push(message.content);
39 self
40 }
41
42 pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
44 for chunk in chunks {
46 let message = Message::embed(chunk);
47 self.contents.push(message.content);
48 }
49 self
50 }
51
52 pub fn with_task_type(mut self, task_type: TaskType) -> Self {
54 self.task_type = Some(task_type);
55 self
56 }
57
58 pub fn with_title(mut self, title: String) -> Self {
61 self.title = Some(title);
62 self
63 }
64
65 pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
67 self.output_dimensionality = Some(output_dimensionality);
68 self
69 }
70
71 #[instrument(skip_all, fields(
73 task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
74 title = self.title,
75 output.dimensionality = self.output_dimensionality
76 ))]
77 pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
78 let request = EmbedContentRequest {
79 model: self.client.model.clone(),
80 content: self.contents.first().expect("No content set").clone(),
81 task_type: self.task_type,
82 title: self.title,
83 output_dimensionality: self.output_dimensionality,
84 };
85
86 self.client.embed_content(request).await
87 }
88
89 #[instrument(skip_all, fields(
91 batch.size = self.contents.len(),
92 task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
93 title = self.title,
94 output.dimensionality = self.output_dimensionality
95 ))]
96 pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
97 let mut batch_request = BatchEmbedContentsRequest {
98 requests: Vec::new(),
99 };
100
101 for content in self.contents {
102 let request = EmbedContentRequest {
103 model: self.client.model.clone(),
104 content: content.clone(),
105 task_type: self.task_type.clone(),
106 title: self.title.clone(),
107 output_dimensionality: self.output_dimensionality,
108 };
109 batch_request.requests.push(request);
110 }
111
112 self.client.embed_content_batch(batch_request).await
113 }
114}