bep/providers/gemini/
completion.rs

// ================================================================
//! Google Gemini Completion Integration
//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
// ================================================================
5
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter variant of Gemini 1.5 Flash).
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Legacy identifier kept only for backward compatibility.
///
/// NOTE(review): Gemini's published 8B model is `gemini-1.5-flash-8b`; a `gemini-1.5-pro-8b`
/// model does not appear in the model list, so prefer [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
14
15use gemini_api_types::{
16    Content, ContentCandidate, FunctionDeclaration, GenerateContentRequest,
17    GenerateContentResponse, GenerationConfig, Part, Role, Tool,
18};
19use serde_json::{Map, Value};
20use std::convert::TryFrom;
21
22use crate::completion::{self, CompletionError, CompletionRequest};
23
24use super::Client;
25
// =================================================================
// Bep Implementation Types
// =================================================================
29
30#[derive(Clone)]
31pub struct CompletionModel {
32    client: Client,
33    pub model: String,
34}
35
36impl CompletionModel {
37    pub fn new(client: Client, model: &str) -> Self {
38        Self {
39            client,
40            model: model.to_string(),
41        }
42    }
43}
44
45impl completion::CompletionModel for CompletionModel {
46    type Response = GenerateContentResponse;
47
48    async fn completion(
49        &self,
50        mut completion_request: CompletionRequest,
51    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
52        let mut full_history = Vec::new();
53        full_history.append(&mut completion_request.chat_history);
54
55        let prompt_with_context = completion_request.prompt_with_context();
56
57        full_history.push(completion::Message {
58            role: "user".into(),
59            content: prompt_with_context,
60        });
61
62        // Handle Gemini specific parameters
63        let additional_params = completion_request
64            .additional_params
65            .unwrap_or_else(|| Value::Object(Map::new()));
66        let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
67
68        // Set temperature from completion_request or additional_params
69        if let Some(temp) = completion_request.temperature {
70            generation_config.temperature = Some(temp);
71        }
72
73        // Set max_tokens from completion_request or additional_params
74        if let Some(max_tokens) = completion_request.max_tokens {
75            generation_config.max_output_tokens = Some(max_tokens);
76        }
77
78        let request = GenerateContentRequest {
79            contents: full_history
80                .into_iter()
81                .map(|msg| Content {
82                    parts: vec![Part {
83                        text: Some(msg.content),
84                        ..Default::default()
85                    }],
86                    role: match msg.role.as_str() {
87                        "system" => Some(Role::Model),
88                        "user" => Some(Role::User),
89                        "assistant" => Some(Role::Model),
90                        _ => None,
91                    },
92                })
93                .collect(),
94            generation_config: Some(generation_config),
95            safety_settings: None,
96            tools: Some(
97                completion_request
98                    .tools
99                    .into_iter()
100                    .map(Tool::from)
101                    .collect(),
102            ),
103            tool_config: None,
104            system_instruction: Some(Content {
105                parts: vec![Part {
106                    text: Some("system".to_string()),
107                    ..Default::default()
108                }],
109                role: Some(Role::Model),
110            }),
111        };
112
113        tracing::debug!("Sending completion request to Gemini API");
114
115        let response = self
116            .client
117            .post(&format!("/v1beta/models/{}:generateContent", self.model))
118            .json(&request)
119            .send()
120            .await?
121            .error_for_status()?
122            .json::<GenerateContentResponse>()
123            .await?;
124
125        match response.usage_metadata {
126            Some(ref usage) => tracing::info!(target: "bep",
127            "Gemini completion token usage: {}",
128            usage
129            ),
130            None => tracing::info!(target: "bep",
131                "Gemini completion token usage: n/a",
132            ),
133        }
134
135        tracing::debug!("Received response");
136
137        completion::CompletionResponse::try_from(response)
138    }
139}
140
141impl From<completion::ToolDefinition> for Tool {
142    fn from(tool: completion::ToolDefinition) -> Self {
143        Self {
144            function_declaration: FunctionDeclaration {
145                name: tool.name,
146                description: tool.description,
147                parameters: None, // tool.parameters, TODO: Map Gemini
148            },
149            code_execution: None,
150        }
151    }
152}
153
154impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
155    type Error = CompletionError;
156
157    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
158        match response.candidates.as_slice() {
159            [ContentCandidate { content, .. }, ..] => Ok(completion::CompletionResponse {
160                choice: match content.parts.first().unwrap() {
161                    Part {
162                        text: Some(text), ..
163                    } => completion::ModelChoice::Message(text.clone()),
164                    Part {
165                        function_call: Some(function_call),
166                        ..
167                    } => {
168                        let args_value = serde_json::Value::Object(
169                            function_call.args.clone().unwrap_or_default(),
170                        );
171                        completion::ModelChoice::ToolCall(function_call.name.clone(), args_value)
172                    }
173                    _ => {
174                        return Err(CompletionError::ResponseError(
175                            "Unsupported response by the model of type ".into(),
176                        ))
177                    }
178                },
179                raw_response: response,
180            }),
181            _ => Err(CompletionError::ResponseError(
182                "No candidates found in response".into(),
183            )),
184        }
185    }
186}
187
188pub mod gemini_api_types {
189    use std::collections::HashMap;
190
191    // =================================================================
192    // Gemini API Types
193    // =================================================================
194    use serde::{Deserialize, Serialize};
195    use serde_json::{Map, Value};
196
197    use crate::{
198        completion::CompletionError,
199        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
200    };
201
202    /// Response from the model supporting multiple candidate responses.
203    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
204    /// and for each candidate in finishReason and in safetyRatings.
205    /// The API:
206    ///     - Returns either all requested candidates or none of them
207    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
208    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
209    #[derive(Debug, Deserialize)]
210    #[serde(rename_all = "camelCase")]
211    pub struct GenerateContentResponse {
212        /// Candidate responses from the model.
213        pub candidates: Vec<ContentCandidate>,
214        /// Returns the prompt's feedback related to the content filters.
215        pub prompt_feedback: Option<PromptFeedback>,
216        /// Output only. Metadata on the generation requests' token usage.
217        pub usage_metadata: Option<UsageMetadata>,
218        pub model_version: Option<String>,
219    }
220
221    /// A response candidate generated from the model.
222    #[derive(Debug, Deserialize)]
223    #[serde(rename_all = "camelCase")]
224    pub struct ContentCandidate {
225        /// Output only. Generated content returned from the model.
226        pub content: Content,
227        /// Optional. Output only. The reason why the model stopped generating tokens.
228        /// If empty, the model has not stopped generating tokens.
229        pub finish_reason: Option<FinishReason>,
230        /// List of ratings for the safety of a response candidate.
231        /// There is at most one rating per category.
232        pub safety_ratings: Option<Vec<SafetyRating>>,
233        /// Output only. Citation information for model-generated candidate.
234        /// This field may be populated with recitation information for any text included in the content.
235        /// These are passages that are "recited" from copybephted material in the foundational LLM's training data.
236        pub citation_metadata: Option<CitationMetadata>,
237        /// Output only. Token count for this candidate.
238        pub token_count: Option<i32>,
239        /// Output only.
240        pub avg_logprobs: Option<f64>,
241        /// Output only. Log-likelihood scores for the response tokens and top tokens
242        pub logprobs_result: Option<LogprobsResult>,
243        /// Output only. Index of the candidate in the list of response candidates.
244        pub index: Option<i32>,
245    }
246    #[derive(Debug, Deserialize, Serialize)]
247    pub struct Content {
248        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
249        pub parts: Vec<Part>,
250        /// The producer of the content. Must be either 'user' or 'model'.
251        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
252        pub role: Option<Role>,
253    }
254
255    #[derive(Debug, Deserialize, Serialize)]
256    #[serde(rename_all = "lowercase")]
257    pub enum Role {
258        User,
259        Model,
260    }
261
262    /// A datatype containing media that is part of a multi-part [Content] message.
263    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
264    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
265    #[derive(Debug, Default, Deserialize, Serialize)]
266    #[serde(rename_all = "camelCase")]
267    pub struct Part {
268        #[serde(skip_serializing_if = "Option::is_none")]
269        pub text: Option<String>,
270        #[serde(skip_serializing_if = "Option::is_none")]
271        pub inline_data: Option<Blob>,
272        #[serde(skip_serializing_if = "Option::is_none")]
273        pub function_call: Option<FunctionCall>,
274        #[serde(skip_serializing_if = "Option::is_none")]
275        pub function_response: Option<FunctionResponse>,
276        #[serde(skip_serializing_if = "Option::is_none")]
277        pub file_data: Option<FileData>,
278        #[serde(skip_serializing_if = "Option::is_none")]
279        pub executable_code: Option<ExecutableCode>,
280        #[serde(skip_serializing_if = "Option::is_none")]
281        pub code_execution_result: Option<CodeExecutionResult>,
282    }
283
284    /// Raw media bytes.
285    /// Text should not be sent as raw bytes, use the 'text' field.
286    #[derive(Debug, Deserialize, Serialize)]
287    #[serde(rename_all = "camelCase")]
288    pub struct Blob {
289        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
290        /// If an unsupported MIME type is provided, an error will be returned.
291        pub mime_type: String,
292        /// Raw bytes for media formats. A base64-encoded string.
293        pub data: String,
294    }
295
296    /// A predicted FunctionCall returned from the model that contains a string representing the
297    /// FunctionDeclaration.name with the arguments and their values.
298    ///     #[derive(Debug, Deserialize, Serialize)]
299    #[derive(Debug, Deserialize, Serialize)]
300    pub struct FunctionCall {
301        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
302        /// and dashes, with a maximum length of 63.
303        pub name: String,
304        /// Optional. The function parameters and values in JSON object format.
305        pub args: Option<Map<String, Value>>,
306    }
307
308    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
309    /// and a structured JSON object containing any output from the function is used as context to the model.
310    /// This should contain the result of aFunctionCall made based on model prediction.
311    #[derive(Debug, Deserialize, Serialize)]
312    pub struct FunctionResponse {
313        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
314        /// with a maximum length of 63.
315        pub name: String,
316        /// The function response in JSON object format.
317        pub response: Option<HashMap<String, Value>>,
318    }
319
320    /// URI based data.
321    #[derive(Debug, Deserialize, Serialize)]
322    #[serde(rename_all = "camelCase")]
323    pub struct FileData {
324        /// Optional. The IANA standard MIME type of the source data.
325        pub mime_type: Option<String>,
326        /// Required. URI.
327        pub file_uri: String,
328    }
329
330    #[derive(Debug, Deserialize, Serialize)]
331    pub struct SafetyRating {
332        pub category: HarmCategory,
333        pub probability: HarmProbability,
334    }
335
336    #[derive(Debug, Deserialize, Serialize)]
337    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
338    pub enum HarmProbability {
339        HarmProbabilityUnspecified,
340        Negligible,
341        Low,
342        Medium,
343        High,
344    }
345
346    #[derive(Debug, Deserialize, Serialize)]
347    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
348    pub enum HarmCategory {
349        HarmCategoryUnspecified,
350        HarmCategoryDerogatory,
351        HarmCategoryToxicity,
352        HarmCategoryViolence,
353        HarmCategorySexually,
354        HarmCategoryMedical,
355        HarmCategoryDangerous,
356        HarmCategoryHarassment,
357        HarmCategoryHateSpeech,
358        HarmCategorySexuallyExplicit,
359        HarmCategoryDangerousContent,
360        HarmCategoryCivicIntegrity,
361    }
362
363    #[derive(Debug, Deserialize)]
364    #[serde(rename_all = "camelCase")]
365    pub struct UsageMetadata {
366        pub prompt_token_count: i32,
367        pub cached_content_token_count: Option<i32>,
368        pub candidates_token_count: i32,
369        pub total_token_count: i32,
370    }
371
372    impl std::fmt::Display for UsageMetadata {
373        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374            write!(
375                f,
376                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
377                self.prompt_token_count,
378                match self.cached_content_token_count {
379                    Some(count) => count.to_string(),
380                    None => "n/a".to_string(),
381                },
382                self.candidates_token_count,
383                self.total_token_count
384            )
385        }
386    }
387
388    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
389    #[derive(Debug, Deserialize)]
390    #[serde(rename_all = "camelCase")]
391    pub struct PromptFeedback {
392        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
393        pub block_reason: Option<BlockReason>,
394        /// Ratings for safety of the prompt. There is at most one rating per category.
395        pub safety_ratings: Option<Vec<SafetyRating>>,
396    }
397
398    /// Reason why a prompt was blocked by the model
399    #[derive(Debug, Deserialize)]
400    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
401    pub enum BlockReason {
402        /// Default value. This value is unused.
403        BlockReasonUnspecified,
404        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
405        Safety,
406        /// Prompt was blocked due to unknown reasons.
407        Other,
408        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
409        Blocklist,
410        /// Prompt was blocked due to prohibited content.
411        ProhibitedContent,
412    }
413
414    #[derive(Debug, Deserialize)]
415    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
416    pub enum FinishReason {
417        /// Default value. This value is unused.
418        FinishReasonUnspecified,
419        /// Natural stop point of the model or provided stop sequence.
420        Stop,
421        /// The maximum number of tokens as specified in the request was reached.
422        MaxTokens,
423        /// The response candidate content was flagged for safety reasons.
424        Safety,
425        /// The response candidate content was flagged for recitation reasons.
426        Recitation,
427        /// The response candidate content was flagged for using an unsupported language.
428        Language,
429        /// Unknown reason.
430        Other,
431        /// Token generation stopped because the content contains forbidden terms.
432        Blocklist,
433        /// Token generation stopped for potentially containing prohibited content.
434        ProhibitedContent,
435        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
436        Spii,
437        /// The function call generated by the model is invalid.
438        MalformedFunctionCall,
439    }
440
441    #[derive(Debug, Deserialize)]
442    #[serde(rename_all = "camelCase")]
443    pub struct CitationMetadata {
444        pub citation_sources: Vec<CitationSource>,
445    }
446
447    #[derive(Debug, Deserialize)]
448    #[serde(rename_all = "camelCase")]
449    pub struct CitationSource {
450        pub uri: Option<String>,
451        pub start_index: Option<i32>,
452        pub end_index: Option<i32>,
453        pub license: Option<String>,
454    }
455
456    #[derive(Debug, Deserialize)]
457    #[serde(rename_all = "camelCase")]
458    pub struct LogprobsResult {
459        pub top_candidate: Vec<TopCandidate>,
460        pub chosen_candidate: Vec<LogProbCandidate>,
461    }
462
463    #[derive(Debug, Deserialize)]
464    pub struct TopCandidate {
465        pub candidates: Vec<LogProbCandidate>,
466    }
467
468    #[derive(Debug, Deserialize)]
469    #[serde(rename_all = "camelCase")]
470    pub struct LogProbCandidate {
471        pub token: String,
472        pub token_id: String,
473        pub log_probability: f64,
474    }
475
476    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
477    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
478    /// ### Bep Note:
479    /// Can be used to cosntruct a typesafe `additional_params` in bep::[AgentBuilder](crate::agent::AgentBuilder).
480    #[derive(Debug, Deserialize, Serialize)]
481    #[serde(rename_all = "camelCase")]
482    pub struct GenerationConfig {
483        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
484        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
485        pub stop_sequences: Option<Vec<String>>,
486        /// MIME type of the generated candidate text. Supported MIME types are:
487        ///     - text/plain:  (default) Text output
488        ///     - application/json: JSON response in the response candidates.
489        ///     - text/x.enum: ENUM as a string response in the response candidates.
490        /// Refer to the docs for a list of all supported text MIME types
491        pub response_mime_type: Option<String>,
492        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
493        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
494        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
495        pub response_schema: Option<Schema>,
496        /// Number of generated responses to return. Currently, this value can only be set to 1. If
497        /// unset, this will default to 1.
498        pub candidate_count: Option<i32>,
499        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
500        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
501        pub max_output_tokens: Option<u64>,
502        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
503        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
504        pub temperature: Option<f64>,
505        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
506        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
507        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
508        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
509        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
510        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
511        pub top_p: Option<f64>,
512        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
513        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
514        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
515        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
516        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
517        pub top_k: Option<i32>,
518        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
519        /// This penalty is binary on/off and not dependant on the number of times the token is used (after the first).
520        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
521        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
522        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
523        pub presence_penalty: Option<f64>,
524        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
525        /// seen in the respponse so far. A positive penalty will discourage the use of tokens that have already been
526        /// used, proportional to the number of times the token has been used: The more a token is used, the more
527        /// dificult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
528        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
529        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
530        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
531        pub frequency_penalty: Option<f64>,
532        /// If true, export the logprobs results in response.
533        pub response_logprobs: Option<bool>,
534        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
535        /// [Candidate.logprobs_result].
536        pub logprobs: Option<i32>,
537    }
538
539    impl Default for GenerationConfig {
540        fn default() -> Self {
541            Self {
542                temperature: Some(1.0),
543                max_output_tokens: Some(4096),
544                stop_sequences: None,
545                response_mime_type: None,
546                response_schema: None,
547                candidate_count: None,
548                top_p: None,
549                top_k: None,
550                presence_penalty: None,
551                frequency_penalty: None,
552                response_logprobs: None,
553                logprobs: None,
554            }
555        }
556    }
557    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
558    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
559    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
560    #[derive(Debug, Deserialize, Serialize)]
561    pub struct Schema {
562        pub r#type: String,
563        pub format: Option<String>,
564        pub description: Option<String>,
565        pub nullable: Option<bool>,
566        pub r#enum: Option<Vec<String>>,
567        pub max_items: Option<i32>,
568        pub min_items: Option<i32>,
569        pub properties: Option<HashMap<String, Schema>>,
570        pub required: Option<Vec<String>>,
571        pub items: Option<Box<Schema>>,
572    }
573
574    impl TryFrom<Value> for Schema {
575        type Error = CompletionError;
576
577        fn try_from(value: Value) -> Result<Self, Self::Error> {
578            if let Some(obj) = value.as_object() {
579                Ok(Schema {
580                    r#type: obj
581                        .get("type")
582                        .and_then(|v| v.as_str())
583                        .unwrap_or_default()
584                        .to_string(),
585                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
586                    description: obj
587                        .get("description")
588                        .and_then(|v| v.as_str())
589                        .map(String::from),
590                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
591                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
592                        arr.iter()
593                            .filter_map(|v| v.as_str().map(String::from))
594                            .collect()
595                    }),
596                    max_items: obj
597                        .get("maxItems")
598                        .and_then(|v| v.as_i64())
599                        .map(|v| v as i32),
600                    min_items: obj
601                        .get("minItems")
602                        .and_then(|v| v.as_i64())
603                        .map(|v| v as i32),
604                    properties: obj
605                        .get("properties")
606                        .and_then(|v| v.as_object())
607                        .map(|map| {
608                            map.iter()
609                                .filter_map(|(k, v)| {
610                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
611                                })
612                                .collect()
613                        }),
614                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
615                        arr.iter()
616                            .filter_map(|v| v.as_str().map(String::from))
617                            .collect()
618                    }),
619                    items: obj
620                        .get("items")
621                        .map(|v| Box::new(v.clone().try_into().unwrap())),
622                })
623            } else {
624                Err(CompletionError::ResponseError(
625                    "Expected a JSON object for Schema".into(),
626                ))
627            }
628        }
629    }
630
631    #[derive(Debug, Serialize)]
632    #[serde(rename_all = "camelCase")]
633    pub struct GenerateContentRequest {
634        pub contents: Vec<Content>,
635        pub tools: Option<Vec<Tool>>,
636        pub tool_config: Option<ToolConfig>,
637        /// Optional. Configuration options for model generation and outputs.
638        pub generation_config: Option<GenerationConfig>,
639        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
640        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
641        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
642        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
643        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
644        /// will use the default safety setting for that category. Harm categories:
645        ///     - HARM_CATEGORY_HATE_SPEECH,
646        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
647        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
648        ///     - HARM_CATEGORY_HARASSMENT
649        /// are supported.
650        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
651        /// to learn how to incorporate safety considerations in your AI applications.
652        pub safety_settings: Option<Vec<SafetySetting>>,
653        /// Optional. Developer set system instruction(s). Currently, text only.
654        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
655        pub system_instruction: Option<Content>,
656        // cachedContent: Optional<String>
657    }
658
659    #[derive(Debug, Serialize)]
660    #[serde(rename_all = "camelCase")]
661    pub struct Tool {
662        pub function_declaration: FunctionDeclaration,
663        pub code_execution: Option<CodeExecution>,
664    }
665
666    #[derive(Debug, Serialize)]
667    #[serde(rename_all = "camelCase")]
668    pub struct FunctionDeclaration {
669        pub name: String,
670        pub description: String,
671        pub parameters: Option<Vec<Schema>>,
672    }
673
674    #[derive(Debug, Serialize)]
675    #[serde(rename_all = "camelCase")]
676    pub struct ToolConfig {
677        pub schema: Option<Schema>,
678    }
679
680    #[derive(Debug, Serialize)]
681    #[serde(rename_all = "camelCase")]
682    pub struct CodeExecution {}
683
684    #[derive(Debug, Serialize)]
685    #[serde(rename_all = "camelCase")]
686    pub struct SafetySetting {
687        pub category: HarmCategory,
688        pub threshold: HarmBlockThreshold,
689    }
690
691    #[derive(Debug, Serialize)]
692    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
693    pub enum HarmBlockThreshold {
694        HarmBlockThresholdUnspecified,
695        BlockLowAndAbove,
696        BlockMediumAndAbove,
697        BlockOnlyHigh,
698        BlockNone,
699        Off,
700    }
701}