rig/providers/gemini/
completion.rs

// ================================================================
//! Google Gemini Completion Integration
//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
// ================================================================

// Model identifiers accepted by the Gemini generateContent endpoint.
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
16
17use gemini_api_types::{
18    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
19    GenerationConfig, Part, Role, Tool,
20};
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23
24use crate::{
25    completion::{self, CompletionError, CompletionRequest},
26    OneOrMany,
27};
28
29use super::Client;
30
31// =================================================================
32// Rig Implementation Types
33// =================================================================
34
/// A Gemini completion model: a model name bound to a shared API client.
#[derive(Clone)]
pub struct CompletionModel {
    // Shared Gemini HTTP client (see `super::Client`); used to POST requests.
    client: Client,
    /// Model identifier sent in the request path, e.g. [`GEMINI_1_5_FLASH`].
    pub model: String,
}
40
41impl CompletionModel {
42    pub fn new(client: Client, model: &str) -> Self {
43        Self {
44            client,
45            model: model.to_string(),
46        }
47    }
48}
49
50impl completion::CompletionModel for CompletionModel {
51    type Response = GenerateContentResponse;
52
53    #[cfg_attr(feature = "worker", worker::send)]
54    async fn completion(
55        &self,
56        mut completion_request: CompletionRequest,
57    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
58        let mut full_history = Vec::new();
59        full_history.append(&mut completion_request.chat_history);
60
61        full_history.push(completion_request.prompt_with_context());
62
63        // Handle Gemini specific parameters
64        let additional_params = completion_request
65            .additional_params
66            .unwrap_or_else(|| Value::Object(Map::new()));
67        let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
68
69        // Set temperature from completion_request or additional_params
70        if let Some(temp) = completion_request.temperature {
71            generation_config.temperature = Some(temp);
72        }
73
74        // Set max_tokens from completion_request or additional_params
75        if let Some(max_tokens) = completion_request.max_tokens {
76            generation_config.max_output_tokens = Some(max_tokens);
77        }
78
79        let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
80            parts: OneOrMany::one(preamble.into()),
81            role: Some(Role::Model),
82        });
83
84        let request = GenerateContentRequest {
85            contents: full_history
86                .into_iter()
87                .map(|msg| {
88                    msg.try_into()
89                        .map_err(|e| CompletionError::RequestError(Box::new(e)))
90                })
91                .collect::<Result<Vec<_>, _>>()?,
92            generation_config: Some(generation_config),
93            safety_settings: None,
94            tools: Some(
95                completion_request
96                    .tools
97                    .into_iter()
98                    .map(Tool::try_from)
99                    .collect::<Result<Vec<_>, _>>()?,
100            ),
101            tool_config: None,
102            system_instruction,
103        };
104
105        tracing::debug!(
106            "Sending completion request to Gemini API {}",
107            serde_json::to_string_pretty(&request)?
108        );
109
110        let response = self
111            .client
112            .post(&format!("/v1beta/models/{}:generateContent", self.model))
113            .json(&request)
114            .send()
115            .await?;
116
117        if response.status().is_success() {
118            let response = response.json::<GenerateContentResponse>().await?;
119            match response.usage_metadata {
120                Some(ref usage) => tracing::info!(target: "rig",
121                "Gemini completion token usage: {}",
122                usage
123                ),
124                None => tracing::info!(target: "rig",
125                    "Gemini completion token usage: n/a",
126                ),
127            }
128
129            tracing::debug!("Received response");
130
131            Ok(completion::CompletionResponse::try_from(response))
132        } else {
133            Err(CompletionError::ProviderError(response.text().await?))
134        }?
135    }
136}
137
138impl TryFrom<completion::ToolDefinition> for Tool {
139    type Error = CompletionError;
140
141    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
142        Ok(Self {
143            function_declarations: FunctionDeclaration {
144                name: tool.name,
145                description: tool.description,
146                parameters: Some(tool.parameters.try_into()?),
147            },
148            code_execution: None,
149        })
150    }
151}
152
impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
    type Error = CompletionError;

    /// Extracts the rig-level choice (text and/or tool calls) from a raw
    /// Gemini response, keeping the full response in `raw_response`.
    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
        // Only the first candidate is converted; any additional candidates
        // remain accessible through `raw_response`.
        let candidate = response.candidates.first().ok_or_else(|| {
            CompletionError::ResponseError("No response candidates in response".into())
        })?;

        let content = candidate
            .content
            .parts
            .iter()
            .map(|part| {
                Ok(match part {
                    Part::Text(text) => completion::AssistantContent::text(text),
                    // Gemini function calls carry no call id, so the function
                    // name is deliberately used for both the id and the name
                    // (matches `From<FunctionCall> for message::ToolCall`).
                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
                        &function_call.name,
                        &function_call.name,
                        function_call.args.clone(),
                    ),
                    // Any other part kind (inline data, code execution, ...)
                    // has no rig assistant-content equivalent.
                    _ => {
                        return Err(CompletionError::ResponseError(
                            "Response did not contain a message or tool call".into(),
                        ))
                    }
                })
            })
            .collect::<Result<Vec<_>, _>>()?;

        // `OneOrMany::many` fails only when `content` is empty.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
194
195pub mod gemini_api_types {
196    use std::{collections::HashMap, convert::Infallible, str::FromStr};
197
198    // =================================================================
199    // Gemini API Types
200    // =================================================================
201    use serde::{Deserialize, Serialize};
202    use serde_json::Value;
203
204    use crate::{
205        completion::CompletionError,
206        message::{self, MimeType as _},
207        one_or_many::string_or_one_or_many,
208        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
209        OneOrMany,
210    };
211
    /// Response from the model supporting multiple candidate responses.
    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
    /// and for each candidate in finishReason and in safetyRatings.
    /// The API:
    ///     - Returns either all requested candidates or none of them
    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Candidate responses from the model.
        pub candidates: Vec<ContentCandidate>,
        /// Returns the prompt's feedback related to the content filters.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Output only. Metadata on the generation requests' token usage.
        pub usage_metadata: Option<UsageMetadata>,
        /// The model version that produced this response, when reported.
        pub model_version: Option<String>,
    }

    /// A response candidate generated from the model.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }

    /// Structured multi-part content of a single message.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        // Custom deserializer — presumably accepts a bare string as well as a
        // part list; see `one_or_many::string_or_one_or_many` for the contract.
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        pub role: Option<Role>,
    }
265
266    impl TryFrom<message::Message> for Content {
267        type Error = message::MessageError;
268
269        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
270            Ok(match msg {
271                message::Message::User { content } => Content {
272                    parts: content.try_map(|c| c.try_into())?,
273                    role: Some(Role::User),
274                },
275                message::Message::Assistant { content } => Content {
276                    role: Some(Role::Model),
277                    parts: content.map(|content| content.into()),
278                },
279            })
280        }
281    }
282
    impl TryFrom<Content> for message::Message {
        type Error = message::MessageError;

        /// Converts a Gemini `Content` back into a rig message.
        ///
        /// A missing role is treated as a user message. Fails when a part has
        /// no rig equivalent (e.g. executable code, or an unsupported media type).
        fn try_from(content: Content) -> Result<Self, Self::Error> {
            match content.role {
                // An absent role defaults to a user message.
                Some(Role::User) | None => Ok(message::Message::User {
                    content: content.parts.try_map(|part| {
                        Ok(match part {
                            Part::Text(text) => message::UserContent::text(text),
                            Part::InlineData(inline_data) => {
                                // Dispatch on the blob's MIME type to select
                                // the matching rig content constructor.
                                let mime_type =
                                    message::MediaType::from_mime_type(&inline_data.mime_type);

                                match mime_type {
                                    Some(message::MediaType::Image(media_type)) => {
                                        message::UserContent::image(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                            Some(message::ImageDetail::default()),
                                        )
                                    }
                                    Some(message::MediaType::Document(media_type)) => {
                                        message::UserContent::document(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                        )
                                    }
                                    Some(message::MediaType::Audio(media_type)) => {
                                        message::UserContent::audio(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                        )
                                    }
                                    // Unrecognized or unsupported MIME types.
                                    _ => {
                                        return Err(message::MessageError::ConversionError(
                                            format!("Unsupported media type {:?}", mime_type),
                                        ))
                                    }
                                }
                            }
                            // Function calls/responses etc. are not valid user content.
                            _ => {
                                return Err(message::MessageError::ConversionError(format!(
                                    "Unsupported gemini content part type: {:?}",
                                    part
                                )))
                            }
                        })
                    })?,
                }),
                Some(Role::Model) => Ok(message::Message::Assistant {
                    content: content.parts.try_map(|part| {
                        Ok(match part {
                            Part::Text(text) => message::AssistantContent::text(text),
                            Part::FunctionCall(function_call) => {
                                message::AssistantContent::ToolCall(function_call.into())
                            }
                            // Assistant messages may only contain text or tool calls.
                            _ => {
                                return Err(message::MessageError::ConversionError(format!(
                                    "Unsupported part type: {:?}",
                                    part
                                )))
                            }
                        })
                    })?,
                }),
            }
        }
    }
354
    /// The producer of a piece of content in a conversation.
    /// Serialized in lowercase ("user" / "model") per the Gemini API.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        User,
        Model,
    }

    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        /// Plain text.
        Text(String),
        /// Raw media bytes plus a MIME type.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model.
        FunctionResponse(FunctionResponse),
        /// URI-based file data.
        FileData(FileData),
        ExecutableCode(ExecutableCode),
        CodeExecutionResult(CodeExecutionResult),
    }
376
377    impl From<String> for Part {
378        fn from(text: String) -> Self {
379            Self::Text(text)
380        }
381    }
382
383    impl From<&str> for Part {
384        fn from(text: &str) -> Self {
385            Self::Text(text.to_string())
386        }
387    }
388
389    impl FromStr for Part {
390        type Err = Infallible;
391
392        fn from_str(s: &str) -> Result<Self, Self::Err> {
393            Ok(s.into())
394        }
395    }
396
397    impl TryFrom<message::UserContent> for Part {
398        type Error = message::MessageError;
399
400        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
401            match content {
402                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
403                message::UserContent::ToolResult(message::ToolResult { id, content }) => {
404                    let content = match content.first() {
405                        message::ToolResultContent::Text(text) => text.text,
406                        message::ToolResultContent::Image(_) => {
407                            return Err(message::MessageError::ConversionError(
408                                "Tool result content must be text".to_string(),
409                            ))
410                        }
411                    };
412                    Ok(Part::FunctionResponse(FunctionResponse {
413                        name: id,
414                        response: Some(serde_json::from_str(&content).map_err(|e| {
415                            message::MessageError::ConversionError(format!(
416                                "Failed to parse tool response: {}",
417                                e
418                            ))
419                        })?),
420                    }))
421                }
422                message::UserContent::Image(message::Image {
423                    data, media_type, ..
424                }) => match media_type {
425                    Some(media_type) => match media_type {
426                        message::ImageMediaType::JPEG
427                        | message::ImageMediaType::PNG
428                        | message::ImageMediaType::WEBP
429                        | message::ImageMediaType::HEIC
430                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
431                            mime_type: media_type.to_mime_type().to_owned(),
432                            data,
433                        })),
434                        _ => Err(message::MessageError::ConversionError(format!(
435                            "Unsupported image media type {:?}",
436                            media_type
437                        ))),
438                    },
439                    None => Err(message::MessageError::ConversionError(
440                        "Media type for image is required for Anthropic".to_string(),
441                    )),
442                },
443                message::UserContent::Document(message::Document {
444                    data, media_type, ..
445                }) => match media_type {
446                    Some(media_type) => match media_type {
447                        message::DocumentMediaType::PDF
448                        | message::DocumentMediaType::TXT
449                        | message::DocumentMediaType::RTF
450                        | message::DocumentMediaType::HTML
451                        | message::DocumentMediaType::CSS
452                        | message::DocumentMediaType::MARKDOWN
453                        | message::DocumentMediaType::CSV
454                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
455                            mime_type: media_type.to_mime_type().to_owned(),
456                            data,
457                        })),
458                        _ => Err(message::MessageError::ConversionError(format!(
459                            "Unsupported document media type {:?}",
460                            media_type
461                        ))),
462                    },
463                    None => Err(message::MessageError::ConversionError(
464                        "Media type for document is required for Anthropic".to_string(),
465                    )),
466                },
467                message::UserContent::Audio(message::Audio {
468                    data, media_type, ..
469                }) => match media_type {
470                    Some(media_type) => Ok(Self::InlineData(Blob {
471                        mime_type: media_type.to_mime_type().to_owned(),
472                        data,
473                    })),
474                    None => Err(message::MessageError::ConversionError(
475                        "Media type for audio is required for Anthropic".to_string(),
476                    )),
477                },
478            }
479        }
480    }
481
482    impl From<message::AssistantContent> for Part {
483        fn from(content: message::AssistantContent) -> Self {
484            match content {
485                message::AssistantContent::Text(message::Text { text }) => text.into(),
486                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
487            }
488        }
489    }
490
491    impl From<message::ToolCall> for Part {
492        fn from(tool_call: message::ToolCall) -> Self {
493            Self::FunctionCall(FunctionCall {
494                name: tool_call.function.name,
495                args: tool_call.function.arguments,
496            })
497        }
498    }
499
    /// Raw media bytes.
    /// Text should not be sent as raw bytes, use the 'text' field.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }

    /// A predicted FunctionCall returned from the model that contains a string representing the
    /// FunctionDeclaration.name with the arguments and their values.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        pub args: serde_json::Value,
    }
522
523    impl From<FunctionCall> for message::ToolCall {
524        fn from(function_call: FunctionCall) -> Self {
525            Self {
526                id: function_call.name.clone(),
527                function: message::ToolFunction {
528                    name: function_call.name,
529                    arguments: function_call.args,
530                },
531            }
532        }
533    }
534
535    impl From<message::ToolCall> for FunctionCall {
536        fn from(tool_call: message::ToolCall) -> Self {
537            Self {
538                name: tool_call.function.name,
539                args: tool_call.function.arguments,
540            }
541        }
542    }
543
    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        pub response: Option<HashMap<String, serde_json::Value>>,
    }

    /// URI based data.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }

    /// A safety rating for a single harm category: the category paired with
    /// the model's assessed probability of harm.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
    }
571
    /// Probability that a piece of content is harmful, as assessed per category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    /// Categories of harmful content recognized by the Gemini safety system.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }

    /// Token accounting for a generateContent request, as reported by the API.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt.
        pub prompt_token_count: i32,
        /// Tokens served from the context cache, when reported.
        // NOTE(review): `skip_serializing_if` is a no-op here — this struct
        // only derives `Deserialize`. Harmless, but misleading.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens across the generated candidates.
        pub candidates_token_count: i32,
        /// Total token count reported for the request.
        pub total_token_count: i32,
    }
608
609    impl std::fmt::Display for UsageMetadata {
610        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611            write!(
612                f,
613                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
614                self.prompt_token_count,
615                match self.cached_content_token_count {
616                    Some(count) => count.to_string(),
617                    None => "n/a".to_string(),
618                },
619                self.candidates_token_count,
620                self.total_token_count
621            )
622        }
623    }
624
    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Reason why a prompt was blocked by the model
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }

    /// Reason a candidate stopped generating tokens.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
677
    /// Citation information for model-generated content.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        pub citation_sources: Vec<CitationSource>,
    }

    /// A single cited source, with the character span of the content it supports.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        // NOTE(review): the `skip_serializing_if` attributes below are no-ops
        // because this struct only derives `Deserialize`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }

    /// Log-likelihood scores for the response tokens and top alternative tokens.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogprobsResult {
        // NOTE(review): the Gemini API documents this field as `topCandidates`
        // (plural) — TODO confirm this singular name deserializes correctly.
        pub top_candidate: Vec<TopCandidate>,
        pub chosen_candidate: Vec<LogProbCandidate>,
    }

    /// The top-probability candidate tokens at a single decoding step.
    #[derive(Debug, Deserialize)]
    pub struct TopCandidate {
        pub candidates: Vec<LogProbCandidate>,
    }

    /// A single candidate token and its log probability.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        pub token: String,
        /// The token's id, represented as a string.
        pub token_id: String,
        pub log_probability: f64,
    }
716
    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain: (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: The more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// [Candidate.logprobs_result].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
    }
791
792    impl Default for GenerationConfig {
793        fn default() -> Self {
794            Self {
795                temperature: Some(1.0),
796                max_output_tokens: Some(4096),
797                stop_sequences: None,
798                response_mime_type: None,
799                response_schema: None,
800                candidate_count: None,
801                top_p: None,
802                top_k: None,
803                presence_penalty: None,
804                frequency_penalty: None,
805                response_logprobs: None,
806                logprobs: None,
807            }
808        }
809    }
810    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
811    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
812    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
813    #[derive(Debug, Deserialize, Serialize)]
814    pub struct Schema {
815        pub r#type: String,
816        #[serde(skip_serializing_if = "Option::is_none")]
817        pub format: Option<String>,
818        #[serde(skip_serializing_if = "Option::is_none")]
819        pub description: Option<String>,
820        #[serde(skip_serializing_if = "Option::is_none")]
821        pub nullable: Option<bool>,
822        #[serde(skip_serializing_if = "Option::is_none")]
823        pub r#enum: Option<Vec<String>>,
824        #[serde(skip_serializing_if = "Option::is_none")]
825        pub max_items: Option<i32>,
826        #[serde(skip_serializing_if = "Option::is_none")]
827        pub min_items: Option<i32>,
828        #[serde(skip_serializing_if = "Option::is_none")]
829        pub properties: Option<HashMap<String, Schema>>,
830        #[serde(skip_serializing_if = "Option::is_none")]
831        pub required: Option<Vec<String>>,
832        #[serde(skip_serializing_if = "Option::is_none")]
833        pub items: Option<Box<Schema>>,
834    }
835
836    impl TryFrom<Value> for Schema {
837        type Error = CompletionError;
838
839        fn try_from(value: Value) -> Result<Self, Self::Error> {
840            if let Some(obj) = value.as_object() {
841                Ok(Schema {
842                    r#type: obj
843                        .get("type")
844                        .and_then(|v| {
845                            if v.is_string() {
846                                v.as_str().map(String::from)
847                            } else if v.is_array() {
848                                v.as_array()
849                                    .and_then(|arr| arr.first())
850                                    .and_then(|v| v.as_str().map(String::from))
851                            } else {
852                                None
853                            }
854                        })
855                        .unwrap_or_default(),
856                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
857                    description: obj
858                        .get("description")
859                        .and_then(|v| v.as_str())
860                        .map(String::from),
861                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
862                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
863                        arr.iter()
864                            .filter_map(|v| v.as_str().map(String::from))
865                            .collect()
866                    }),
867                    max_items: obj
868                        .get("maxItems")
869                        .and_then(|v| v.as_i64())
870                        .map(|v| v as i32),
871                    min_items: obj
872                        .get("minItems")
873                        .and_then(|v| v.as_i64())
874                        .map(|v| v as i32),
875                    properties: obj
876                        .get("properties")
877                        .and_then(|v| v.as_object())
878                        .map(|map| {
879                            map.iter()
880                                .filter_map(|(k, v)| {
881                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
882                                })
883                                .collect()
884                        }),
885                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
886                        arr.iter()
887                            .filter_map(|v| v.as_str().map(String::from))
888                            .collect()
889                    }),
890                    items: obj
891                        .get("items")
892                        .map(|v| Box::new(v.clone().try_into().unwrap())),
893                })
894            } else {
895                Err(CompletionError::ResponseError(
896                    "Expected a JSON object for Schema".into(),
897                ))
898            }
899        }
900    }
901
902    #[derive(Debug, Serialize)]
903    #[serde(rename_all = "camelCase")]
904    pub struct GenerateContentRequest {
905        pub contents: Vec<Content>,
906        pub tools: Option<Vec<Tool>>,
907        pub tool_config: Option<ToolConfig>,
908        /// Optional. Configuration options for model generation and outputs.
909        pub generation_config: Option<GenerationConfig>,
910        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
911        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
912        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
913        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
914        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
915        /// will use the default safety setting for that category. Harm categories:
916        ///     - HARM_CATEGORY_HATE_SPEECH,
917        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
918        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
919        ///     - HARM_CATEGORY_HARASSMENT
920        /// are supported.
921        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
922        /// to learn how to incorporate safety considerations in your AI applications.
923        pub safety_settings: Option<Vec<SafetySetting>>,
924        /// Optional. Developer set system instruction(s). Currently, text only.
925        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
926        pub system_instruction: Option<Content>,
927        // cachedContent: Optional<String>
928    }
929
930    #[derive(Debug, Serialize)]
931    #[serde(rename_all = "camelCase")]
932    pub struct Tool {
933        pub function_declarations: FunctionDeclaration,
934        pub code_execution: Option<CodeExecution>,
935    }
936
    /// A function declaration the model may generate calls for.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#FunctionDeclaration)
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// The name of the function, used by the model when emitting a function call.
        pub name: String,
        /// A brief description of what the function does.
        pub description: String,
        /// Schema describing the function's parameters; omitted from the payload when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
945
946    #[derive(Debug, Serialize)]
947    #[serde(rename_all = "camelCase")]
948    pub struct ToolConfig {
949        pub schema: Option<Schema>,
950    }
951
    /// Tool that executes code generated by the model. Per the API reference it has no
    /// configurable fields; its presence in a request enables the feature.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#CodeExecution)
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
955
    /// Safety setting, affecting the safety-blocking behavior for a single harm category.
    /// From [Gemini API Reference](https://ai.google.dev/api/generate-content#safetysetting)
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The probability threshold at which content in this category is blocked.
        pub threshold: HarmBlockThreshold,
    }
962
    /// Blocks content at and beyond a specified harm probability. Serialized in
    /// SCREAMING_SNAKE_CASE (e.g. `BLOCK_LOW_AND_ABOVE`) as the API expects.
    /// From [Gemini API Reference](https://ai.google.dev/api/generate-content#HarmBlockThreshold)
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// Threshold is unspecified; the service's default applies.
        HarmBlockThresholdUnspecified,
        /// Content with NEGLIGIBLE probability is allowed.
        BlockLowAndAbove,
        /// Content with NEGLIGIBLE and LOW probability is allowed.
        BlockMediumAndAbove,
        /// Content with NEGLIGIBLE, LOW, and MEDIUM probability is allowed.
        BlockOnlyHigh,
        /// All content is allowed regardless of probability.
        BlockNone,
        /// The safety filter is turned off entirely.
        Off,
    }
973}
974
975#[cfg(test)]
976mod tests {
977    use crate::message;
978
979    use super::*;
980    use serde_json::json;
981
982    #[test]
983    fn test_deserialize_message_user() {
984        let raw_message = r#"{
985            "parts": [
986                {"text": "Hello, world!"},
987                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
988                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
989                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
990                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
991                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
992                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
993            ],
994            "role": "user"
995        }"#;
996
997        let content: Content = {
998            let jd = &mut serde_json::Deserializer::from_str(raw_message);
999            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1000                panic!("Deserialization error at {}: {}", err.path(), err);
1001            })
1002        };
1003        assert_eq!(content.role, Some(Role::User));
1004        assert_eq!(content.parts.len(), 7);
1005
1006        let parts: Vec<Part> = content.parts.into_iter().collect();
1007
1008        if let Part::Text(text) = &parts[0] {
1009            assert_eq!(text, "Hello, world!");
1010        } else {
1011            panic!("Expected text part");
1012        }
1013
1014        if let Part::InlineData(inline_data) = &parts[1] {
1015            assert_eq!(inline_data.mime_type, "image/png");
1016            assert_eq!(inline_data.data, "base64encodeddata");
1017        } else {
1018            panic!("Expected inline data part");
1019        }
1020
1021        if let Part::FunctionCall(function_call) = &parts[2] {
1022            assert_eq!(function_call.name, "test_function");
1023            assert_eq!(
1024                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1025                "value1"
1026            );
1027        } else {
1028            panic!("Expected function call part");
1029        }
1030
1031        if let Part::FunctionResponse(function_response) = &parts[3] {
1032            assert_eq!(function_response.name, "test_function");
1033            assert_eq!(
1034                function_response
1035                    .response
1036                    .as_ref()
1037                    .unwrap()
1038                    .get("result")
1039                    .unwrap(),
1040                "success"
1041            );
1042        } else {
1043            panic!("Expected function response part");
1044        }
1045
1046        if let Part::FileData(file_data) = &parts[4] {
1047            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1048            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1049        } else {
1050            panic!("Expected file data part");
1051        }
1052
1053        if let Part::ExecutableCode(executable_code) = &parts[5] {
1054            assert_eq!(executable_code.code, "print('Hello, world!')");
1055        } else {
1056            panic!("Expected executable code part");
1057        }
1058
1059        if let Part::CodeExecutionResult(code_execution_result) = &parts[6] {
1060            assert_eq!(
1061                code_execution_result.clone().output.unwrap(),
1062                "Hello, world!"
1063            );
1064        } else {
1065            panic!("Expected code execution result part");
1066        }
1067    }
1068
1069    #[test]
1070    fn test_deserialize_message_model() {
1071        let json_data = json!({
1072            "parts": [{"text": "Hello, user!"}],
1073            "role": "model"
1074        });
1075
1076        let content: Content = serde_json::from_value(json_data).unwrap();
1077        assert_eq!(content.role, Some(Role::Model));
1078        assert_eq!(content.parts.len(), 1);
1079        if let Part::Text(text) = &content.parts.first() {
1080            assert_eq!(text, "Hello, user!");
1081        } else {
1082            panic!("Expected text part");
1083        }
1084    }
1085
1086    #[test]
1087    fn test_message_conversion_user() {
1088        let msg = message::Message::user("Hello, world!");
1089        let content: Content = msg.try_into().unwrap();
1090        assert_eq!(content.role, Some(Role::User));
1091        assert_eq!(content.parts.len(), 1);
1092        if let Part::Text(text) = &content.parts.first() {
1093            assert_eq!(text, "Hello, world!");
1094        } else {
1095            panic!("Expected text part");
1096        }
1097    }
1098
1099    #[test]
1100    fn test_message_conversion_model() {
1101        let msg = message::Message::assistant("Hello, user!");
1102
1103        let content: Content = msg.try_into().unwrap();
1104        assert_eq!(content.role, Some(Role::Model));
1105        assert_eq!(content.parts.len(), 1);
1106        if let Part::Text(text) = &content.parts.first() {
1107            assert_eq!(text, "Hello, user!");
1108        } else {
1109            panic!("Expected text part");
1110        }
1111    }
1112
1113    #[test]
1114    fn test_message_conversion_tool_call() {
1115        let tool_call = message::ToolCall {
1116            id: "test_tool".to_string(),
1117            function: message::ToolFunction {
1118                name: "test_function".to_string(),
1119                arguments: json!({"arg1": "value1"}),
1120            },
1121        };
1122
1123        let msg = message::Message::Assistant {
1124            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1125        };
1126
1127        let content: Content = msg.try_into().unwrap();
1128        assert_eq!(content.role, Some(Role::Model));
1129        assert_eq!(content.parts.len(), 1);
1130        if let Part::FunctionCall(function_call) = &content.parts.first() {
1131            assert_eq!(function_call.name, "test_function");
1132            assert_eq!(
1133                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1134                "value1"
1135            );
1136        } else {
1137            panic!("Expected function call part");
1138        }
1139    }
1140}