bep/providers/gemini/completion.rs
// ================================================================
//! Google Gemini Completion Integration
//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
// ================================================================
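//!
//! A minimal usage sketch (the client construction shown is an assumption;
//! the `Client` type's actual constructor lives in the parent module):
//! ```ignore
//! use bep::providers::gemini::{self, completion::GEMINI_1_5_FLASH};
//!
//! let client = gemini::Client::from_env(); // assumed constructor
//! let model = gemini::completion::CompletionModel::new(client, GEMINI_1_5_FLASH);
//! ```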

/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash-8b` completion model
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";

use gemini_api_types::{
    Content, ContentCandidate, FunctionDeclaration, GenerateContentRequest,
    GenerateContentResponse, GenerationConfig, Part, Role, Schema, Tool,
};
use serde_json::{Map, Value};
use std::convert::TryFrom;

use crate::completion::{self, CompletionError, CompletionRequest};

use super::Client;

// =================================================================
// Bep Implementation Types
// =================================================================

#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}

impl CompletionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl completion::CompletionModel for CompletionModel {
    type Response = GenerateContentResponse;

    async fn completion(
        &self,
        mut completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
        let mut full_history = Vec::new();
        full_history.append(&mut completion_request.chat_history);

        let prompt_with_context = completion_request.prompt_with_context();

        full_history.push(completion::Message {
            role: "user".into(),
            content: prompt_with_context,
        });

        // Handle Gemini-specific parameters passed via `additional_params`.
        let additional_params = completion_request
            .additional_params
            .unwrap_or_else(|| Value::Object(Map::new()));
        let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;

        // `completion_request.temperature` takes precedence over any value in `additional_params`.
        if let Some(temp) = completion_request.temperature {
            generation_config.temperature = Some(temp);
        }

        // Likewise, `completion_request.max_tokens` overrides `additional_params`.
        if let Some(max_tokens) = completion_request.max_tokens {
            generation_config.max_output_tokens = Some(max_tokens);
        }

        let request = GenerateContentRequest {
            contents: full_history
                .into_iter()
                .map(|msg| Content {
                    parts: vec![Part {
                        text: Some(msg.content),
                        ..Default::default()
                    }],
                    role: match msg.role.as_str() {
                        // Gemini only knows `user` and `model` roles, so system
                        // and assistant messages are both sent as `model`.
                        "system" => Some(Role::Model),
                        "user" => Some(Role::User),
                        "assistant" => Some(Role::Model),
                        _ => None,
                    },
                })
                .collect(),
            generation_config: Some(generation_config),
            safety_settings: None,
            tools: Some(
                completion_request
                    .tools
                    .into_iter()
                    .map(Tool::from)
                    .collect(),
            ),
            tool_config: None,
            // Note: the system instruction text is currently a fixed placeholder.
            system_instruction: Some(Content {
                parts: vec![Part {
                    text: Some("system".to_string()),
                    ..Default::default()
                }],
                role: Some(Role::Model),
            }),
        };

        tracing::debug!("Sending completion request to Gemini API");

        let response = self
            .client
            .post(&format!("/v1beta/models/{}:generateContent", self.model))
            .json(&request)
            .send()
            .await?
            .error_for_status()?
            .json::<GenerateContentResponse>()
            .await?;

        match response.usage_metadata {
            Some(ref usage) => tracing::info!(target: "bep",
                "Gemini completion token usage: {}",
                usage
            ),
            None => tracing::info!(target: "bep",
                "Gemini completion token usage: n/a",
            ),
        }

        tracing::debug!("Received response");

        completion::CompletionResponse::try_from(response)
    }
}

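/// Maps a bep tool definition onto Gemini's function-calling type. A sketch of
/// the conversion (the `ToolDefinition` literal below assumes its `parameters`
/// field is a `serde_json::Value`):
/// ```ignore
/// let tool: Tool = completion::ToolDefinition {
///     name: "get_weather".to_string(),
///     description: "Get the current weather for a city".to_string(),
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": { "city": { "type": "string" } },
///         "required": ["city"]
///     }),
/// }
/// .into();
/// ```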
impl From<completion::ToolDefinition> for Tool {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            function_declarations: vec![FunctionDeclaration {
                name: tool.name,
                description: tool.description,
                // Best-effort mapping of the tool's JSON parameters onto the
                // Gemini `Schema` subset; a schema that fails to convert is
                // dropped rather than failing the whole conversion.
                parameters: Schema::try_from(tool.parameters).ok(),
            }],
            code_execution: None,
        }
    }
}

impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
    type Error = CompletionError;

    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
        match response.candidates.as_slice() {
            [ContentCandidate { content, .. }, ..] => Ok(completion::CompletionResponse {
                choice: match content.parts.first() {
                    Some(Part {
                        text: Some(text), ..
                    }) => completion::ModelChoice::Message(text.clone()),
                    Some(Part {
                        function_call: Some(function_call),
                        ..
                    }) => {
                        let args_value = serde_json::Value::Object(
                            function_call.args.clone().unwrap_or_default(),
                        );
                        completion::ModelChoice::ToolCall(function_call.name.clone(), args_value)
                    }
                    _ => {
                        return Err(CompletionError::ResponseError(
                            "Unsupported response content part returned by the model".into(),
                        ))
                    }
                },
                raw_response: response,
            }),
            _ => Err(CompletionError::ResponseError(
                "No candidates found in response".into(),
            )),
        }
    }
}

pub mod gemini_api_types {
    use std::collections::HashMap;

    // =================================================================
    // Gemini API Types
    // =================================================================
    use serde::{Deserialize, Serialize};
    use serde_json::{Map, Value};

    use crate::{
        completion::CompletionError,
        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
    };

    /// Response from the model supporting multiple candidate responses.
    /// Safety ratings and content filtering are reported both for the prompt in
    /// [GenerateContentResponse.prompt_feedback] and for each candidate in finishReason and in safetyRatings.
    /// The API:
    /// - Returns either all requested candidates or none of them
    /// - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
    /// - Reports feedback on each candidate in finishReason and safetyRatings.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Candidate responses from the model.
        pub candidates: Vec<ContentCandidate>,
        /// Returns the prompt's feedback related to the content filters.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Output only. Metadata on the generation requests' token usage.
        pub usage_metadata: Option<UsageMetadata>,
        pub model_version: Option<String>,
    }

    /// A response candidate generated from the model.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for the model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens.
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }

    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        pub parts: Vec<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        pub role: Option<Role>,
    }

    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        User,
        Model,
    }

    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
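    /// ### Example
    /// A text-only part, relying on `Default` for the remaining fields (a sketch):
    /// ```ignore
    /// let part = Part {
    ///     text: Some("Hello, Gemini!".to_string()),
    ///     ..Default::default()
    /// };
    /// ```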
    #[derive(Debug, Default, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        #[serde(skip_serializing_if = "Option::is_none")]
        pub text: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<Blob>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub function_call: Option<FunctionCall>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub function_response: Option<FunctionResponse>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub executable_code: Option<ExecutableCode>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub code_execution_result: Option<CodeExecutionResult>,
    }

    /// Raw media bytes.
    /// Text should not be sent as raw bytes; use the 'text' field.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: image/png, image/jpeg.
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }

    /// A predicted FunctionCall returned from the model that contains a string representing the
    /// FunctionDeclaration.name with the arguments and their values.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        pub args: Option<Map<String, Value>>,
    }

    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function; it is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        pub response: Option<HashMap<String, Value>>,
    }

    /// URI based data.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }

    #[derive(Debug, Deserialize, Serialize)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
    }

    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexual,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }

    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        pub prompt_token_count: i32,
        pub cached_content_token_count: Option<i32>,
        pub candidates_token_count: i32,
        pub total_token_count: i32,
    }

    impl std::fmt::Display for UsageMetadata {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
                self.prompt_token_count,
                match self.cached_content_token_count {
                    Some(count) => count.to_string(),
                    None => "n/a".to_string(),
                },
                self.candidates_token_count,
                self.total_token_count
            )
        }
    }

    /// A set of feedback metadata for the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for the safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Reason why a prompt was blocked by the model.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to terms included in the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }

    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }

    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        pub citation_sources: Vec<CitationSource>,
    }

    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        pub uri: Option<String>,
        pub start_index: Option<i32>,
        pub end_index: Option<i32>,
        pub license: Option<String>,
    }

    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogprobsResult {
        pub top_candidates: Vec<TopCandidate>,
        pub chosen_candidates: Vec<LogProbCandidate>,
    }

    #[derive(Debug, Deserialize)]
    pub struct TopCandidate {
        pub candidates: Vec<LogProbCandidate>,
    }

    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        pub token: String,
        pub token_id: i32,
        pub log_probability: f64,
    }

    /// Gemini API configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    /// ### Bep Note:
    /// Can be used to construct a typesafe `additional_params` in bep::[AgentBuilder](crate::agent::AgentBuilder).
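    /// ### Example
    /// A sketch of building a typesafe `additional_params` value (the field values
    /// below are illustrative, not recommended defaults):
    /// ```ignore
    /// let config = GenerationConfig {
    ///     temperature: Some(0.7),
    ///     max_output_tokens: Some(1024),
    ///     ..Default::default()
    /// };
    /// let additional_params = serde_json::to_value(config).expect("config serializes to JSON");
    /// ```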
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        /// - text/plain: (default) Text output
        /// - application/json: JSON response in the response candidates.
        /// - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types.
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range over [0.0, 2.0].
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined top-k and
        /// top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// nucleus sampling limits the number of tokens based on their cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function. An empty
        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use top-p (nucleus) sampling or a
        /// combination of top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow a topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: the more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in the response.
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// [Candidate.logprobs_result].
        pub logprobs: Option<i32>,
    }

    impl Default for GenerationConfig {
        fn default() -> Self {
            Self {
                temperature: Some(1.0),
                max_output_tokens: Some(4096),
                stop_sequences: None,
                response_mime_type: None,
                response_schema: None,
                candidate_count: None,
                top_p: None,
                top_k: None,
                presence_penalty: None,
                frequency_penalty: None,
                response_logprobs: None,
                logprobs: None,
            }
        }
    }

    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
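    /// ### Example
    /// A sketch of converting a JSON Schema-like `serde_json::Value` via the
    /// [TryFrom] impl below (only the subset of fields on this struct is mapped):
    /// ```ignore
    /// use serde_json::json;
    ///
    /// let schema = Schema::try_from(json!({
    ///     "type": "object",
    ///     "properties": { "location": { "type": "string" } },
    ///     "required": ["location"]
    /// }))
    /// .expect("valid schema object");
    /// ```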
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Schema {
        pub r#type: String,
        pub format: Option<String>,
        pub description: Option<String>,
        pub nullable: Option<bool>,
        pub r#enum: Option<Vec<String>>,
        pub max_items: Option<i32>,
        pub min_items: Option<i32>,
        pub properties: Option<HashMap<String, Schema>>,
        pub required: Option<Vec<String>>,
        pub items: Option<Box<Schema>>,
    }

    impl TryFrom<Value> for Schema {
        type Error = CompletionError;

        fn try_from(value: Value) -> Result<Self, Self::Error> {
            if let Some(obj) = value.as_object() {
                Ok(Schema {
                    r#type: obj
                        .get("type")
                        .and_then(|v| v.as_str())
                        .unwrap_or_default()
                        .to_string(),
                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
                    description: obj
                        .get("description")
                        .and_then(|v| v.as_str())
                        .map(String::from),
                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
                        arr.iter()
                            .filter_map(|v| v.as_str().map(String::from))
                            .collect()
                    }),
                    max_items: obj
                        .get("maxItems")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as i32),
                    min_items: obj
                        .get("minItems")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as i32),
                    properties: obj
                        .get("properties")
                        .and_then(|v| v.as_object())
                        .map(|map| {
                            map.iter()
                                .filter_map(|(k, v)| {
                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
                                })
                                .collect()
                        }),
                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
                        arr.iter()
                            .filter_map(|v| v.as_str().map(String::from))
                            .collect()
                    }),
                    // Avoid panicking on a malformed `items` value; skip it instead.
                    items: obj
                        .get("items")
                        .and_then(|v| v.clone().try_into().ok().map(Box::new)),
                })
            } else {
                Err(CompletionError::ResponseError(
                    "Expected a JSON object for Schema".into(),
                ))
            }
        }
    }

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        pub contents: Vec<Content>,
        pub tools: Option<Vec<Tool>>,
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. Harm categories:
        /// - HARM_CATEGORY_HATE_SPEECH
        /// - HARM_CATEGORY_SEXUALLY_EXPLICIT
        /// - HARM_CATEGORY_DANGEROUS_CONTENT
        /// - HARM_CATEGORY_HARASSMENT
        /// are supported.
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // cached_content: Option<String>
    }

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        pub function_declarations: Vec<FunctionDeclaration>,
        pub code_execution: Option<CodeExecution>,
    }

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        pub name: String,
        pub description: String,
        pub parameters: Option<Schema>,
    }

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        pub schema: Option<Schema>,
    }

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        pub category: HarmCategory,
        pub threshold: HarmBlockThreshold,
    }

    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
}