gemini_rust/embedding/
builder.rs1use std::sync::Arc;
2
3use super::model::{
4 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
5 EmbedContentRequest, TaskType,
6};
7use crate::{
8 client::{Error as ClientError, GeminiClient},
9 Content, Message,
10};
11
12pub struct EmbedBuilder {
14 client: Arc<GeminiClient>,
15 contents: Vec<Content>,
16 task_type: Option<TaskType>,
17 title: Option<String>,
18 output_dimensionality: Option<i32>,
19}
20
21impl EmbedBuilder {
22 pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
24 Self {
25 client,
26 contents: Vec::new(),
27 task_type: None,
28 title: None,
29 output_dimensionality: None,
30 }
31 }
32
33 pub fn with_text(mut self, text: impl Into<String>) -> Self {
35 let message = Message::embed(text);
36 self.contents.push(message.content);
37 self
38 }
39
40 pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
42 for chunk in chunks {
44 let message = Message::embed(chunk);
45 self.contents.push(message.content);
46 }
47 self
48 }
49
50 pub fn with_task_type(mut self, task_type: TaskType) -> Self {
52 self.task_type = Some(task_type);
53 self
54 }
55
56 pub fn with_title(mut self, title: String) -> Self {
59 self.title = Some(title);
60 self
61 }
62
63 pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
65 self.output_dimensionality = Some(output_dimensionality);
66 self
67 }
68
69 pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
71 let request = EmbedContentRequest {
72 model: self.client.model.to_string(),
73 content: self.contents.first().expect("No content set").clone(),
74 task_type: self.task_type,
75 title: self.title,
76 output_dimensionality: self.output_dimensionality,
77 };
78
79 self.client.embed_content(request).await
80 }
81
82 pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
84 let mut batch_request = BatchEmbedContentsRequest {
85 requests: Vec::new(),
86 };
87
88 for content in self.contents {
89 let request = EmbedContentRequest {
90 model: self.client.model.to_string(),
91 content: content.clone(),
92 task_type: self.task_type.clone(),
93 title: self.title.clone(),
94 output_dimensionality: self.output_dimensionality,
95 };
96 batch_request.requests.push(request);
97 }
98
99 self.client.embed_content_batch(batch_request).await
100 }
101}