use futures::TryStream;
use std::sync::Arc;
use tracing::instrument;
use crate::{
Content, FunctionCallingMode, FunctionDeclaration, GenerationConfig, GenerationResponse,
Message, Role, Tool,
cache::CachedContentHandle,
client::{Error as ClientError, GeminiClient},
generation::{GenerateContentRequest, SpeakerVoiceConfig, SpeechConfig, ThinkingConfig},
tools::{FunctionCallingConfig, ToolConfig},
};
/// Fluent builder for a single [`GenerateContentRequest`].
///
/// Each `with_*` method consumes the builder and returns it updated, so calls can be chained.
/// Finish with [`ContentBuilder::build`] to obtain the raw request, or with
/// [`ContentBuilder::execute`] / [`ContentBuilder::execute_stream`] to send it through the
/// client the builder was created with.
#[must_use = "a ContentBuilder does nothing until `build`, `execute`, or `execute_stream` is called"]
pub struct ContentBuilder {
    /// Client used by `execute` / `execute_stream` to send the built request.
    client: Arc<GeminiClient>,
    /// Conversation contents, in the order they were added.
    pub contents: Vec<Content>,
    /// Optional generation parameters; individual setters create a default config on demand.
    generation_config: Option<GenerationConfig>,
    /// Tools offered to the model; `None` until the first tool/function is added.
    tools: Option<Vec<Tool>>,
    /// Tool configuration (e.g. function-calling mode).
    tool_config: Option<ToolConfig>,
    /// Optional system instruction content.
    system_instruction: Option<Content>,
    /// Name of a cached-content resource the request should reference, if any.
    cached_content: Option<String>,
}
impl ContentBuilder {
pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
Self {
client,
contents: Vec::new(),
generation_config: None,
tools: None,
tool_config: None,
system_instruction: None,
cached_content: None,
}
}
pub fn with_system_prompt(self, text: impl Into<String>) -> Self {
self.with_system_instruction(text)
}
pub fn with_system_instruction(mut self, text: impl Into<String>) -> Self {
let content = Content::text(text);
self.system_instruction = Some(content);
self
}
pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
let message = Message::user(text);
self.contents.push(message.content);
self
}
pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
let message = Message::model(text);
self.contents.push(message.content);
self
}
pub fn with_inline_data(
mut self,
data: impl Into<String>,
mime_type: impl Into<String>,
) -> Self {
let content = Content::inline_data(mime_type, data).with_role(Role::User);
self.contents.push(content);
self
}
pub fn with_function_response<Response>(
mut self,
name: impl Into<String>,
response: Response,
) -> std::result::Result<Self, serde_json::Error>
where
Response: serde::Serialize,
{
let content = Content::function_response_json(name, serde_json::to_value(response)?)
.with_role(Role::User);
self.contents.push(content);
Ok(self)
}
pub fn with_function_response_str(
mut self,
name: impl Into<String>,
response: impl Into<String>,
) -> std::result::Result<Self, serde_json::Error> {
let response_str = response.into();
let json = serde_json::from_str(&response_str)?;
let content = Content::function_response_json(name, json).with_role(Role::User);
self.contents.push(content);
Ok(self)
}
pub fn with_message(mut self, message: Message) -> Self {
let content = message.content.clone();
let role = content.role.clone().unwrap_or(message.role);
self.contents.push(content.with_role(role));
self
}
pub fn with_cached_content(mut self, cached_content: &CachedContentHandle) -> Self {
self.cached_content = Some(cached_content.name().to_string());
self
}
pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
for message in messages {
self = self.with_message(message);
}
self
}
pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
self.generation_config = Some(config);
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.generation_config.get_or_insert_with(Default::default).temperature = Some(temperature);
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.generation_config.get_or_insert_with(Default::default).top_p = Some(top_p);
self
}
pub fn with_top_k(mut self, top_k: i32) -> Self {
self.generation_config.get_or_insert_with(Default::default).top_k = Some(top_k);
self
}
pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
self.generation_config.get_or_insert_with(Default::default).max_output_tokens =
Some(max_output_tokens);
self
}
pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
self.generation_config.get_or_insert_with(Default::default).candidate_count =
Some(candidate_count);
self
}
pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
self.generation_config.get_or_insert_with(Default::default).stop_sequences =
Some(stop_sequences);
self
}
pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
self.generation_config.get_or_insert_with(Default::default).response_mime_type =
Some(mime_type.into());
self
}
pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
self.generation_config.get_or_insert_with(Default::default).response_schema = Some(schema);
self
}
pub fn with_tool(mut self, tool: Tool) -> Self {
self.tools.get_or_insert_with(Vec::new).push(tool);
self
}
pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
let tool = Tool::new(function);
self = self.with_tool(tool);
self
}
pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
self.tool_config.get_or_insert_with(Default::default).function_calling_config =
Some(FunctionCallingConfig { mode });
self
}
pub fn with_thinking_config(mut self, thinking_config: ThinkingConfig) -> Self {
self.generation_config.get_or_insert_with(Default::default).thinking_config =
Some(thinking_config);
self
}
pub fn with_thinking_budget(mut self, budget: i32) -> Self {
self.generation_config
.get_or_insert_with(Default::default)
.thinking_config
.get_or_insert_with(Default::default)
.thinking_budget = Some(budget);
self
}
pub fn with_dynamic_thinking(self) -> Self {
self.with_thinking_budget(-1)
}
pub fn with_thoughts_included(mut self, include: bool) -> Self {
self.generation_config
.get_or_insert_with(Default::default)
.thinking_config
.get_or_insert_with(Default::default)
.include_thoughts = Some(include);
self
}
pub fn with_audio_output(mut self) -> Self {
self.generation_config.get_or_insert_with(Default::default).response_modalities =
Some(vec!["AUDIO".to_string()]);
self
}
pub fn with_speech_config(mut self, speech_config: SpeechConfig) -> Self {
self.generation_config.get_or_insert_with(Default::default).speech_config =
Some(speech_config);
self
}
pub fn with_voice(self, voice_name: impl Into<String>) -> Self {
let speech_config = SpeechConfig::single_voice(voice_name);
self.with_speech_config(speech_config).with_audio_output()
}
pub fn with_multi_speaker_config(self, speakers: Vec<SpeakerVoiceConfig>) -> Self {
let speech_config = SpeechConfig::multi_speaker(speakers);
self.with_speech_config(speech_config).with_audio_output()
}
pub fn build(self) -> GenerateContentRequest {
GenerateContentRequest {
contents: self.contents,
generation_config: self.generation_config,
safety_settings: None,
tools: self.tools,
tool_config: self.tool_config,
system_instruction: self.system_instruction,
cached_content: self.cached_content,
}
}
#[instrument(skip_all, fields(
messages.parts.count = self.contents.len(),
tools.present = self.tools.is_some(),
system.instruction.present = self.system_instruction.is_some(),
cached.content.present = self.cached_content.is_some(),
))]
pub async fn execute(self) -> Result<GenerationResponse, ClientError> {
let client = self.client.clone();
let request = self.build();
client.generate_content_raw(request).await
}
#[instrument(skip_all, fields(
messages.parts.count = self.contents.len(),
tools.present = self.tools.is_some(),
system.instruction.present = self.system_instruction.is_some(),
cached.content.present = self.cached_content.is_some(),
))]
pub async fn execute_stream(
self,
) -> Result<impl TryStream<Ok = GenerationResponse, Error = ClientError> + Send, ClientError>
{
let client = self.client.clone();
let request = self.build();
client.generate_content_stream(request).await
}
}