use std::sync::Arc;
use tracing::instrument;
use super::model::{
BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
EmbedContentRequest, TaskType,
};
use crate::{
Content, Message,
client::{Error as ClientError, GeminiClient},
};
/// Builder that collects content and optional settings for an embedding
/// request and sends it through a shared [`GeminiClient`].
///
/// Configure it with the `with_*` methods, then finish with `execute`
/// (single request) or `execute_batch` (one request per content item).
pub struct EmbedBuilder {
/// Shared client that supplies the model and sends the request.
client: Arc<GeminiClient>,
/// Content items to embed, in the order they were added.
contents: Vec<Content>,
/// Optional task type, copied into every request that is built.
task_type: Option<TaskType>,
/// Optional title, copied into every request that is built.
title: Option<String>,
/// Optional output dimensionality, copied into every request as given
/// (no validation is performed by the builder).
output_dimensionality: Option<i32>,
}
impl EmbedBuilder {
    /// Creates an empty builder bound to `client`.
    ///
    /// No content is set and every optional setting starts as `None`.
    pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
        Self {
            client,
            contents: Vec::new(),
            task_type: None,
            title: None,
            output_dimensionality: None,
        }
    }

    /// Appends a single piece of text to the content list.
    ///
    /// The text is turned into a [`Content`] via [`Message::embed`].
    pub fn with_text(mut self, text: impl Into<String>) -> Self {
        self.contents.push(Message::embed(text).content);
        self
    }

    /// Appends every chunk, in order, to the content list.
    ///
    /// Each chunk is turned into a [`Content`] via [`Message::embed`], exactly
    /// as [`with_text`](Self::with_text) does for a single text.
    pub fn with_chunks(mut self, chunks: Vec<impl Into<String>>) -> Self {
        // `extend` reserves once from the iterator's size hint instead of
        // growing the vector push by push.
        self.contents
            .extend(chunks.into_iter().map(|chunk| Message::embed(chunk).content));
        self
    }

    /// Sets the task type sent with the request(s), replacing any earlier value.
    pub fn with_task_type(mut self, task_type: TaskType) -> Self {
        self.task_type = Some(task_type);
        self
    }

    /// Sets the title sent with the request(s), replacing any earlier value.
    pub fn with_title(mut self, title: String) -> Self {
        self.title = Some(title);
        self
    }

    /// Sets the output dimensionality sent with the request(s), replacing any
    /// earlier value. The number is forwarded as given; the builder does not
    /// validate it.
    pub fn with_output_dimensionality(mut self, output_dimensionality: i32) -> Self {
        self.output_dimensionality = Some(output_dimensionality);
        self
    }

    /// Sends a single embedding request for the **first** content item.
    ///
    /// Any further items added with [`with_text`](Self::with_text) or
    /// [`with_chunks`](Self::with_chunks) are ignored; use
    /// [`execute_batch`](Self::execute_batch) to embed all of them.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client's `embed_content` call reports.
    ///
    /// # Panics
    ///
    /// Panics with `"No content set"` if no content was added to the builder.
    #[instrument(skip_all, fields(
        task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
        title = self.title,
        output.dimensionality = self.output_dimensionality
    ))]
    pub async fn execute(self) -> Result<ContentEmbeddingResponse, ClientError> {
        // The builder is consumed, so the first item can be moved out rather
        // than cloned.
        // NOTE(review): an empty builder would be better reported as an `Err`,
        // but that needs a suitable `ClientError` variant — confirm one exists.
        let content = self
            .contents
            .into_iter()
            .next()
            .expect("No content set");
        let request = EmbedContentRequest {
            model: self.client.model.clone(),
            content,
            task_type: self.task_type,
            title: self.title,
            output_dimensionality: self.output_dimensionality,
        };
        self.client.embed_content(request).await
    }

    /// Sends one batch call containing a request per content item, in the
    /// order the items were added.
    ///
    /// Every request carries the same model, task type, title and output
    /// dimensionality. An empty builder produces a batch with no requests.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client's `embed_content_batch` call reports.
    #[instrument(skip_all, fields(
        batch.size = self.contents.len(),
        task.type = self.task_type.as_ref().map(AsRef::<str>::as_ref),
        title = self.title,
        output.dimensionality = self.output_dimensionality
    ))]
    pub async fn execute_batch(self) -> Result<BatchContentEmbeddingResponse, ClientError> {
        let mut requests = Vec::with_capacity(self.contents.len());
        // Iterating by value hands each `Content` over to its request, so no
        // per-item clone is needed.
        for content in self.contents {
            requests.push(EmbedContentRequest {
                model: self.client.model.clone(),
                content,
                task_type: self.task_type.clone(),
                title: self.title.clone(),
                output_dimensionality: self.output_dimensionality,
            });
        }
        let batch_request = BatchEmbedContentsRequest { requests };
        self.client.embed_content_batch(batch_request).await
    }
}