gemini_rust/generation/builder.rs

use futures::TryStream;
use std::sync::Arc;

use crate::{
    cache::CachedContentHandle,
    client::{Error as ClientError, GeminiClient},
    generation::{GenerateContentRequest, SpeakerVoiceConfig, SpeechConfig, ThinkingConfig},
    tools::{FunctionCallingConfig, ToolConfig},
    Content, FunctionCallingMode, FunctionDeclaration, GenerationConfig, GenerationResponse,
    Message, Role, Tool,
};

/// Builder for content generation requests
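///
/// Chain the `with_*` methods to assemble a request, then send it with
/// [`execute`](Self::execute) or [`execute_stream`](Self::execute_stream).
///
/// A minimal usage sketch (the `client.generate_content()` entry point is assumed here;
/// builders are constructed inside the crate, since [`ContentBuilder::new`] is crate-private):
///
/// ```ignore
/// let response = client
///     .generate_content()
///     .with_system_prompt("You are a concise assistant.")
///     .with_user_message("Summarise Rust's ownership model in one sentence.")
///     .with_temperature(0.7)
///     .execute()
///     .await?;
/// ```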
pub struct ContentBuilder {
    client: Arc<GeminiClient>,
    pub contents: Vec<Content>,
    generation_config: Option<GenerationConfig>,
    tools: Option<Vec<Tool>>,
    tool_config: Option<ToolConfig>,
    system_instruction: Option<Content>,
    cached_content: Option<String>,
}

impl ContentBuilder {
    /// Create a new content builder
    pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
        Self {
            client,
            contents: Vec::new(),
            generation_config: None,
            tools: None,
            tool_config: None,
            system_instruction: None,
            cached_content: None,
        }
    }

    /// Add a system prompt to the request
    pub fn with_system_prompt(self, text: impl Into<String>) -> Self {
        // Delegates to `with_system_instruction`, which builds the Content for the
        // system_instruction field.
        self.with_system_instruction(text)
    }

    /// Set the system instruction directly (sent as the request's `system_instruction` field)
    pub fn with_system_instruction(mut self, text: impl Into<String>) -> Self {
        // Create a Content with text parts specifically for the system_instruction field
        let content = Content::text(text);
        self.system_instruction = Some(content);
        self
    }

    /// Add a user message to the request
    pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
        let message = Message::user(text);
        let content = message.content;
        self.contents.push(content);
        self
    }

    /// Add a model message to the request
    pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
        let message = Message::model(text);
        let content = message.content;
        self.contents.push(content);
        self
    }

    /// Add inline data (blob data) to the request
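    ///
    /// A sketch, assuming `data` is the base64-encoded bytes of the media (the encoding
    /// expected by the underlying `inline_data` part); `encoded_png` is a hypothetical
    /// variable produced elsewhere:
    ///
    /// ```ignore
    /// let builder = client
    ///     .generate_content()
    ///     .with_user_message("Describe this image")
    ///     .with_inline_data(encoded_png, "image/png");
    /// ```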
    pub fn with_inline_data(
        mut self,
        data: impl Into<String>,
        mime_type: impl Into<String>,
    ) -> Self {
        let content = Content::inline_data(mime_type, data).with_role(Role::User);
        self.contents.push(content);
        self
    }

    /// Add a function response to the request using a JSON value
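    ///
    /// A sketch using `serde_json::json!` to build the payload; the function name and
    /// fields are illustrative:
    ///
    /// ```ignore
    /// use serde_json::json;
    ///
    /// let builder = builder.with_function_response(
    ///     "get_weather",
    ///     json!({ "temperature_c": 21, "condition": "sunny" }),
    /// );
    /// ```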
    pub fn with_function_response(
        mut self,
        name: impl Into<String>,
        response: serde_json::Value,
    ) -> Self {
        let content = Content::function_response_json(name, response).with_role(Role::User);
        self.contents.push(content);
        self
    }

    /// Add a function response to the request using a JSON string
    pub fn with_function_response_str(
        mut self,
        name: impl Into<String>,
        response: impl Into<String>,
    ) -> std::result::Result<Self, serde_json::Error> {
        let response_str = response.into();
        let json = serde_json::from_str(&response_str)?;
        let content = Content::function_response_json(name, json).with_role(Role::User);
        self.contents.push(content);
        Ok(self)
    }

    /// Add a message to the request
    pub fn with_message(mut self, message: Message) -> Self {
        // Prefer the role already set on the content; fall back to the message's role.
        let role = message.content.role.clone().unwrap_or(message.role);
        self.contents.push(message.content.with_role(role));
        self
    }

    /// Use cached content for this request.
    /// This allows reusing previously cached system instructions and conversation history.
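    ///
    /// A sketch, assuming `cache_handle` is a `CachedContentHandle` obtained from the
    /// client's caching API:
    ///
    /// ```ignore
    /// let builder = client
    ///     .generate_content()
    ///     .with_cached_content(&cache_handle)
    ///     .with_user_message("Answer using the cached context.");
    /// ```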
    pub fn with_cached_content(mut self, cached_content: &CachedContentHandle) -> Self {
        self.cached_content = Some(cached_content.name().to_string());
        self
    }

    /// Add multiple messages to the request
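    ///
    /// For example, replaying an earlier exchange before asking a follow-up:
    ///
    /// ```ignore
    /// let history = vec![
    ///     Message::user("What is the capital of France?"),
    ///     Message::model("Paris."),
    ///     Message::user("And its population?"),
    /// ];
    /// let builder = builder.with_messages(history);
    /// ```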
    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
        for message in messages {
            self = self.with_message(message);
        }
        self
    }

    /// Set the generation config for the request
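    ///
    /// A sketch constructing a config up front (this assumes `GenerationConfig`'s fields
    /// are publicly constructible, as they are set directly in this module; the individual
    /// `with_*` setters below avoid that assumption):
    ///
    /// ```ignore
    /// let config = GenerationConfig {
    ///     temperature: Some(0.2),
    ///     max_output_tokens: Some(256),
    ///     ..Default::default()
    /// };
    /// let builder = builder.with_generation_config(config);
    /// ```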
    pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
        self.generation_config = Some(config);
        self
    }

    /// Set the temperature for the request
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.temperature = Some(temperature);
        }
        self
    }

    /// Set the top-p value for the request
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.top_p = Some(top_p);
        }
        self
    }

    /// Set the top-k value for the request
    pub fn with_top_k(mut self, top_k: i32) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.top_k = Some(top_k);
        }
        self
    }

    /// Set the maximum output tokens for the request
    pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.max_output_tokens = Some(max_output_tokens);
        }
        self
    }

    /// Set the candidate count for the request
    pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.candidate_count = Some(candidate_count);
        }
        self
    }

    /// Set the stop sequences for the request
    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.stop_sequences = Some(stop_sequences);
        }
        self
    }

    /// Set the response mime type for the request
    pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.response_mime_type = Some(mime_type.into());
        }
        self
    }

    /// Set the response schema for structured output
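    ///
    /// A sketch requesting structured JSON output; the schema itself is illustrative, and
    /// is paired with `with_response_mime_type("application/json")`:
    ///
    /// ```ignore
    /// use serde_json::json;
    ///
    /// let schema = json!({
    ///     "type": "object",
    ///     "properties": {
    ///         "title": { "type": "string" },
    ///         "rating": { "type": "integer" }
    ///     },
    ///     "required": ["title", "rating"]
    /// });
    /// let builder = builder
    ///     .with_response_mime_type("application/json")
    ///     .with_response_schema(schema);
    /// ```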
    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.response_schema = Some(schema);
        }
        self
    }

    /// Add a tool to the request
    pub fn with_tool(mut self, tool: Tool) -> Self {
        if self.tools.is_none() {
            self.tools = Some(Vec::new());
        }
        if let Some(tools) = &mut self.tools {
            tools.push(tool);
        }
        self
    }

    /// Add a function declaration as a tool
    pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
        let tool = Tool::new(function);
        self = self.with_tool(tool);
        self
    }

    /// Set the function calling mode for the request
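    ///
    /// A sketch, assuming `weather_fn` is a `FunctionDeclaration` built elsewhere and that
    /// `FunctionCallingMode` exposes an `Any`-style variant (check the actual enum for the
    /// available modes):
    ///
    /// ```ignore
    /// let builder = builder
    ///     .with_function(weather_fn)
    ///     .with_function_calling_mode(FunctionCallingMode::Any);
    /// ```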
    pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
        if self.tool_config.is_none() {
            self.tool_config = Some(ToolConfig {
                function_calling_config: Some(FunctionCallingConfig { mode }),
            });
        } else if let Some(tool_config) = &mut self.tool_config {
            tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
        }
        self
    }

    /// Set the thinking configuration for the request (Gemini 2.5 series only)
    pub fn with_thinking_config(mut self, thinking_config: ThinkingConfig) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.thinking_config = Some(thinking_config);
        }
        self
    }

    /// Set the thinking budget for the request (Gemini 2.5 series only)
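    ///
    /// A sketch combining a fixed budget with thought summaries;
    /// [`with_dynamic_thinking`](Self::with_dynamic_thinking) (budget `-1`) lets the model
    /// choose the budget instead:
    ///
    /// ```ignore
    /// let builder = builder
    ///     .with_thinking_budget(1024)
    ///     .with_thoughts_included(true);
    /// ```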
    pub fn with_thinking_budget(mut self, budget: i32) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            if config.thinking_config.is_none() {
                config.thinking_config = Some(ThinkingConfig::default());
            }
            if let Some(thinking_config) = &mut config.thinking_config {
                thinking_config.thinking_budget = Some(budget);
            }
        }
        self
    }

    /// Enable dynamic thinking (model decides the budget) (Gemini 2.5 series only)
    pub fn with_dynamic_thinking(self) -> Self {
        self.with_thinking_budget(-1)
    }

    /// Include thought summaries in the response (Gemini 2.5 series only)
    pub fn with_thoughts_included(mut self, include: bool) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            if config.thinking_config.is_none() {
                config.thinking_config = Some(ThinkingConfig::default());
            }
            if let Some(thinking_config) = &mut config.thinking_config {
                thinking_config.include_thoughts = Some(include);
            }
        }
        self
    }

    /// Enable audio output (text-to-speech)
    pub fn with_audio_output(mut self) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.response_modalities = Some(vec!["AUDIO".to_string()]);
        }
        self
    }

    /// Set speech configuration for text-to-speech generation
    pub fn with_speech_config(mut self, speech_config: SpeechConfig) -> Self {
        if self.generation_config.is_none() {
            self.generation_config = Some(GenerationConfig::default());
        }
        if let Some(config) = &mut self.generation_config {
            config.speech_config = Some(speech_config);
        }
        self
    }

    /// Set a single voice for text-to-speech generation
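    ///
    /// A sketch, assuming `"Kore"` is one of the service's prebuilt voice names:
    ///
    /// ```ignore
    /// let builder = client
    ///     .generate_content()
    ///     .with_user_message("Read this aloud, please.")
    ///     .with_voice("Kore");
    /// ```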
    pub fn with_voice(self, voice_name: impl Into<String>) -> Self {
        let speech_config = SpeechConfig::single_voice(voice_name);
        self.with_speech_config(speech_config).with_audio_output()
    }

    /// Set multi-speaker configuration for text-to-speech generation
    pub fn with_multi_speaker_config(self, speakers: Vec<SpeakerVoiceConfig>) -> Self {
        let speech_config = SpeechConfig::multi_speaker(speakers);
        self.with_speech_config(speech_config).with_audio_output()
    }

    /// Build the `GenerateContentRequest` without sending it
    pub fn build(self) -> GenerateContentRequest {
        GenerateContentRequest {
            contents: self.contents,
            generation_config: self.generation_config,
            safety_settings: None,
            tools: self.tools,
            tool_config: self.tool_config,
            system_instruction: self.system_instruction,
            cached_content: self.cached_content,
        }
    }

    /// Execute the request
    pub async fn execute(self) -> Result<GenerationResponse, ClientError> {
        let client = self.client.clone();
        let request = self.build();
        client.generate_content_raw(request).await
    }

    /// Execute the request with streaming
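    ///
    /// A sketch of consuming the stream with `futures::TryStreamExt`; the stream is pinned
    /// locally because the returned type is opaque:
    ///
    /// ```ignore
    /// use futures::TryStreamExt;
    ///
    /// let stream = client
    ///     .generate_content()
    ///     .with_user_message("Tell me a short story")
    ///     .execute_stream()
    ///     .await?;
    /// futures::pin_mut!(stream);
    ///
    /// while let Some(chunk) = stream.try_next().await? {
    ///     // Each `chunk` is a partial GenerationResponse; accumulate its text parts here.
    /// }
    /// ```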
    pub async fn execute_stream(
        self,
    ) -> Result<impl TryStream<Ok = GenerationResponse, Error = ClientError> + Send, ClientError>
    {
        let client = self.client.clone();
        let request = self.build();
        client.generate_content_stream(request).await
    }
}