// rig/providers/gemini/completion.rs
// ================================================================
//! Google Gemini Completion Integration
//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter variant of Gemini 1.5)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
///
/// NOTE: the Gemini model list publishes the 8B variant of Gemini 1.5 as
/// `gemini-1.5-flash-8b` (see [`GEMINI_1_5_FLASH_8B`]); this constant is kept unchanged
/// only for backward compatibility.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33    completion::{self, CompletionError, CompletionRequest},
34    OneOrMany,
35};
36use gemini_api_types::{
37    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38    GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
// =================================================================
// Rig Implementation Types
// =================================================================
48
49#[derive(Clone)]
50pub struct CompletionModel {
51    pub(crate) client: Client,
52    pub model: String,
53}
54
55impl CompletionModel {
56    pub fn new(client: Client, model: &str) -> Self {
57        Self {
58            client,
59            model: model.to_string(),
60        }
61    }
62}
63
64impl completion::CompletionModel for CompletionModel {
65    type Response = GenerateContentResponse;
66    type StreamingResponse = StreamingCompletionResponse;
67
68    #[cfg_attr(feature = "worker", worker::send)]
69    async fn completion(
70        &self,
71        completion_request: CompletionRequest,
72    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73        let request = create_request_body(completion_request)?;
74
75        tracing::debug!(
76            "Sending completion request to Gemini API {}",
77            serde_json::to_string_pretty(&request)?
78        );
79
80        let response = self
81            .client
82            .post(&format!("/v1beta/models/{}:generateContent", self.model))
83            .json(&request)
84            .send()
85            .await?;
86
87        if response.status().is_success() {
88            let response = response.json::<GenerateContentResponse>().await?;
89            match response.usage_metadata {
90                Some(ref usage) => tracing::info!(target: "rig",
91                "Gemini completion token usage: {}",
92                usage
93                ),
94                None => tracing::info!(target: "rig",
95                    "Gemini completion token usage: n/a",
96                ),
97            }
98
99            tracing::debug!("Received response");
100
101            Ok(completion::CompletionResponse::try_from(response))
102        } else {
103            Err(CompletionError::ProviderError(response.text().await?))
104        }?
105    }
106
107    #[cfg_attr(feature = "worker", worker::send)]
108    async fn stream(
109        &self,
110        request: CompletionRequest,
111    ) -> Result<
112        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113        CompletionError,
114    > {
115        CompletionModel::stream(self, request).await
116    }
117}
118
119pub(crate) fn create_request_body(
120    completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122    let mut full_history = Vec::new();
123    full_history.extend(completion_request.chat_history);
124
125    let additional_params = completion_request
126        .additional_params
127        .unwrap_or_else(|| Value::Object(Map::new()));
128
129    let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131    if let Some(temp) = completion_request.temperature {
132        generation_config.temperature = Some(temp);
133    }
134
135    if let Some(max_tokens) = completion_request.max_tokens {
136        generation_config.max_output_tokens = Some(max_tokens);
137    }
138
139    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140        parts: OneOrMany::one(preamble.into()),
141        role: Some(Role::Model),
142    });
143
144    let request = GenerateContentRequest {
145        contents: full_history
146            .into_iter()
147            .map(|msg| {
148                msg.try_into()
149                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
150            })
151            .collect::<Result<Vec<_>, _>>()?,
152        generation_config: Some(generation_config),
153        safety_settings: None,
154        tools: Some(
155            completion_request
156                .tools
157                .into_iter()
158                .map(Tool::try_from)
159                .collect::<Result<Vec<_>, _>>()?,
160        ),
161        tool_config: None,
162        system_instruction,
163    };
164
165    Ok(request)
166}
167
168impl TryFrom<completion::ToolDefinition> for Tool {
169    type Error = CompletionError;
170
171    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
172        let parameters: Option<Schema> =
173            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
174                None
175            } else {
176                Some(tool.parameters.try_into()?)
177            };
178        Ok(Self {
179            function_declarations: FunctionDeclaration {
180                name: tool.name,
181                description: tool.description,
182                parameters,
183            },
184            code_execution: None,
185        })
186    }
187}
188
189impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
190    type Error = CompletionError;
191
192    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
193        let candidate = response.candidates.first().ok_or_else(|| {
194            CompletionError::ResponseError("No response candidates in response".into())
195        })?;
196
197        let content = candidate
198            .content
199            .parts
200            .iter()
201            .map(|part| {
202                Ok(match part {
203                    Part::Text(text) => completion::AssistantContent::text(text),
204                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
205                        &function_call.name,
206                        &function_call.name,
207                        function_call.args.clone(),
208                    ),
209                    _ => {
210                        return Err(CompletionError::ResponseError(
211                            "Response did not contain a message or tool call".into(),
212                        ))
213                    }
214                })
215            })
216            .collect::<Result<Vec<_>, _>>()?;
217
218        let choice = OneOrMany::many(content).map_err(|_| {
219            CompletionError::ResponseError(
220                "Response contained no message or tool call (empty)".to_owned(),
221            )
222        })?;
223
224        Ok(completion::CompletionResponse {
225            choice,
226            raw_response: response,
227        })
228    }
229}
230
231pub mod gemini_api_types {
232    use std::{collections::HashMap, convert::Infallible, str::FromStr};
233
    // =================================================================
    // Gemini API Types
    // =================================================================
237    use serde::{Deserialize, Serialize};
238    use serde_json::Value;
239
240    use crate::{
241        completion::CompletionError,
242        message::{self, MimeType as _},
243        one_or_many::string_or_one_or_many,
244        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
245        OneOrMany,
246    };
247
248    /// Response from the model supporting multiple candidate responses.
249    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
250    /// and for each candidate in finishReason and in safetyRatings.
251    /// The API:
252    ///     - Returns either all requested candidates or none of them
253    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
254    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
255    #[derive(Debug, Deserialize)]
256    #[serde(rename_all = "camelCase")]
257    pub struct GenerateContentResponse {
258        /// Candidate responses from the model.
259        pub candidates: Vec<ContentCandidate>,
260        /// Returns the prompt's feedback related to the content filters.
261        pub prompt_feedback: Option<PromptFeedback>,
262        /// Output only. Metadata on the generation requests' token usage.
263        pub usage_metadata: Option<UsageMetadata>,
264        pub model_version: Option<String>,
265    }
266
267    /// A response candidate generated from the model.
268    #[derive(Debug, Deserialize)]
269    #[serde(rename_all = "camelCase")]
270    pub struct ContentCandidate {
271        /// Output only. Generated content returned from the model.
272        pub content: Content,
273        /// Optional. Output only. The reason why the model stopped generating tokens.
274        /// If empty, the model has not stopped generating tokens.
275        pub finish_reason: Option<FinishReason>,
276        /// List of ratings for the safety of a response candidate.
277        /// There is at most one rating per category.
278        pub safety_ratings: Option<Vec<SafetyRating>>,
279        /// Output only. Citation information for model-generated candidate.
280        /// This field may be populated with recitation information for any text included in the content.
281        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
282        pub citation_metadata: Option<CitationMetadata>,
283        /// Output only. Token count for this candidate.
284        pub token_count: Option<i32>,
285        /// Output only.
286        pub avg_logprobs: Option<f64>,
287        /// Output only. Log-likelihood scores for the response tokens and top tokens
288        pub logprobs_result: Option<LogprobsResult>,
289        /// Output only. Index of the candidate in the list of response candidates.
290        pub index: Option<i32>,
291    }
292    #[derive(Debug, Deserialize, Serialize)]
293    pub struct Content {
294        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
295        #[serde(deserialize_with = "string_or_one_or_many")]
296        pub parts: OneOrMany<Part>,
297        /// The producer of the content. Must be either 'user' or 'model'.
298        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
299        pub role: Option<Role>,
300    }
301
302    impl TryFrom<message::Message> for Content {
303        type Error = message::MessageError;
304
305        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
306            Ok(match msg {
307                message::Message::User { content } => Content {
308                    parts: content.try_map(|c| c.try_into())?,
309                    role: Some(Role::User),
310                },
311                message::Message::Assistant { content } => Content {
312                    role: Some(Role::Model),
313                    parts: content.map(|content| content.into()),
314                },
315            })
316        }
317    }
318
319    impl TryFrom<Content> for message::Message {
320        type Error = message::MessageError;
321
322        fn try_from(content: Content) -> Result<Self, Self::Error> {
323            match content.role {
324                Some(Role::User) | None => Ok(message::Message::User {
325                    content: content.parts.try_map(|part| {
326                        Ok(match part {
327                            Part::Text(text) => message::UserContent::text(text),
328                            Part::InlineData(inline_data) => {
329                                let mime_type =
330                                    message::MediaType::from_mime_type(&inline_data.mime_type);
331
332                                match mime_type {
333                                    Some(message::MediaType::Image(media_type)) => {
334                                        message::UserContent::image(
335                                            inline_data.data,
336                                            Some(message::ContentFormat::default()),
337                                            Some(media_type),
338                                            Some(message::ImageDetail::default()),
339                                        )
340                                    }
341                                    Some(message::MediaType::Document(media_type)) => {
342                                        message::UserContent::document(
343                                            inline_data.data,
344                                            Some(message::ContentFormat::default()),
345                                            Some(media_type),
346                                        )
347                                    }
348                                    Some(message::MediaType::Audio(media_type)) => {
349                                        message::UserContent::audio(
350                                            inline_data.data,
351                                            Some(message::ContentFormat::default()),
352                                            Some(media_type),
353                                        )
354                                    }
355                                    _ => {
356                                        return Err(message::MessageError::ConversionError(
357                                            format!("Unsupported media type {mime_type:?}"),
358                                        ))
359                                    }
360                                }
361                            }
362                            _ => {
363                                return Err(message::MessageError::ConversionError(format!(
364                                    "Unsupported gemini content part type: {part:?}"
365                                )))
366                            }
367                        })
368                    })?,
369                }),
370                Some(Role::Model) => Ok(message::Message::Assistant {
371                    content: content.parts.try_map(|part| {
372                        Ok(match part {
373                            Part::Text(text) => message::AssistantContent::text(text),
374                            Part::FunctionCall(function_call) => {
375                                message::AssistantContent::ToolCall(function_call.into())
376                            }
377                            _ => {
378                                return Err(message::MessageError::ConversionError(format!(
379                                    "Unsupported part type: {part:?}"
380                                )))
381                            }
382                        })
383                    })?,
384                }),
385            }
386        }
387    }
388
389    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
390    #[serde(rename_all = "lowercase")]
391    pub enum Role {
392        User,
393        Model,
394    }
395
396    /// A datatype containing media that is part of a multi-part [Content] message.
397    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
398    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
399    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
400    #[serde(rename_all = "camelCase")]
401    pub enum Part {
402        Text(String),
403        InlineData(Blob),
404        FunctionCall(FunctionCall),
405        FunctionResponse(FunctionResponse),
406        FileData(FileData),
407        ExecutableCode(ExecutableCode),
408        CodeExecutionResult(CodeExecutionResult),
409    }
410
411    impl From<String> for Part {
412        fn from(text: String) -> Self {
413            Self::Text(text)
414        }
415    }
416
417    impl From<&str> for Part {
418        fn from(text: &str) -> Self {
419            Self::Text(text.to_string())
420        }
421    }
422
423    impl FromStr for Part {
424        type Err = Infallible;
425
426        fn from_str(s: &str) -> Result<Self, Self::Err> {
427            Ok(s.into())
428        }
429    }
430
431    impl TryFrom<message::UserContent> for Part {
432        type Error = message::MessageError;
433
434        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
435            match content {
436                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
437                message::UserContent::ToolResult(message::ToolResult { id, content }) => {
438                    let content = match content.first() {
439                        message::ToolResultContent::Text(text) => text.text,
440                        message::ToolResultContent::Image(_) => {
441                            return Err(message::MessageError::ConversionError(
442                                "Tool result content must be text".to_string(),
443                            ))
444                        }
445                    };
446                    Ok(Part::FunctionResponse(FunctionResponse {
447                        name: id,
448                        response: Some(serde_json::from_str(&content).map_err(|e| {
449                            message::MessageError::ConversionError(format!(
450                                "Failed to parse tool response: {e}"
451                            ))
452                        })?),
453                    }))
454                }
455                message::UserContent::Image(message::Image {
456                    data, media_type, ..
457                }) => match media_type {
458                    Some(media_type) => match media_type {
459                        message::ImageMediaType::JPEG
460                        | message::ImageMediaType::PNG
461                        | message::ImageMediaType::WEBP
462                        | message::ImageMediaType::HEIC
463                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
464                            mime_type: media_type.to_mime_type().to_owned(),
465                            data,
466                        })),
467                        _ => Err(message::MessageError::ConversionError(format!(
468                            "Unsupported image media type {media_type:?}"
469                        ))),
470                    },
471                    None => Err(message::MessageError::ConversionError(
472                        "Media type for image is required for Anthropic".to_string(),
473                    )),
474                },
475                message::UserContent::Document(message::Document {
476                    data, media_type, ..
477                }) => match media_type {
478                    Some(media_type) => match media_type {
479                        message::DocumentMediaType::PDF
480                        | message::DocumentMediaType::TXT
481                        | message::DocumentMediaType::RTF
482                        | message::DocumentMediaType::HTML
483                        | message::DocumentMediaType::CSS
484                        | message::DocumentMediaType::MARKDOWN
485                        | message::DocumentMediaType::CSV
486                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
487                            mime_type: media_type.to_mime_type().to_owned(),
488                            data,
489                        })),
490                        _ => Err(message::MessageError::ConversionError(format!(
491                            "Unsupported document media type {media_type:?}"
492                        ))),
493                    },
494                    None => Err(message::MessageError::ConversionError(
495                        "Media type for document is required for Anthropic".to_string(),
496                    )),
497                },
498                message::UserContent::Audio(message::Audio {
499                    data, media_type, ..
500                }) => match media_type {
501                    Some(media_type) => Ok(Self::InlineData(Blob {
502                        mime_type: media_type.to_mime_type().to_owned(),
503                        data,
504                    })),
505                    None => Err(message::MessageError::ConversionError(
506                        "Media type for audio is required for Anthropic".to_string(),
507                    )),
508                },
509            }
510        }
511    }
512
513    impl From<message::AssistantContent> for Part {
514        fn from(content: message::AssistantContent) -> Self {
515            match content {
516                message::AssistantContent::Text(message::Text { text }) => text.into(),
517                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
518            }
519        }
520    }
521
522    impl From<message::ToolCall> for Part {
523        fn from(tool_call: message::ToolCall) -> Self {
524            Self::FunctionCall(FunctionCall {
525                name: tool_call.function.name,
526                args: tool_call.function.arguments,
527            })
528        }
529    }
530
531    /// Raw media bytes.
532    /// Text should not be sent as raw bytes, use the 'text' field.
533    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
534    #[serde(rename_all = "camelCase")]
535    pub struct Blob {
536        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
537        /// If an unsupported MIME type is provided, an error will be returned.
538        pub mime_type: String,
539        /// Raw bytes for media formats. A base64-encoded string.
540        pub data: String,
541    }
542
543    /// A predicted FunctionCall returned from the model that contains a string representing the
544    /// FunctionDeclaration.name with the arguments and their values.
545    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
546    pub struct FunctionCall {
547        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
548        /// and dashes, with a maximum length of 63.
549        pub name: String,
550        /// Optional. The function parameters and values in JSON object format.
551        pub args: serde_json::Value,
552    }
553
554    impl From<FunctionCall> for message::ToolCall {
555        fn from(function_call: FunctionCall) -> Self {
556            Self {
557                id: function_call.name.clone(),
558                function: message::ToolFunction {
559                    name: function_call.name,
560                    arguments: function_call.args,
561                },
562            }
563        }
564    }
565
566    impl From<message::ToolCall> for FunctionCall {
567        fn from(tool_call: message::ToolCall) -> Self {
568            Self {
569                name: tool_call.function.name,
570                args: tool_call.function.arguments,
571            }
572        }
573    }
574
575    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
576    /// and a structured JSON object containing any output from the function is used as context to the model.
577    /// This should contain the result of aFunctionCall made based on model prediction.
578    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
579    pub struct FunctionResponse {
580        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
581        /// with a maximum length of 63.
582        pub name: String,
583        /// The function response in JSON object format.
584        pub response: Option<HashMap<String, serde_json::Value>>,
585    }
586
587    /// URI based data.
588    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
589    #[serde(rename_all = "camelCase")]
590    pub struct FileData {
591        /// Optional. The IANA standard MIME type of the source data.
592        pub mime_type: Option<String>,
593        /// Required. URI.
594        pub file_uri: String,
595    }
596
597    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
598    pub struct SafetyRating {
599        pub category: HarmCategory,
600        pub probability: HarmProbability,
601    }
602
603    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
604    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
605    pub enum HarmProbability {
606        HarmProbabilityUnspecified,
607        Negligible,
608        Low,
609        Medium,
610        High,
611    }
612
613    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
614    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
615    pub enum HarmCategory {
616        HarmCategoryUnspecified,
617        HarmCategoryDerogatory,
618        HarmCategoryToxicity,
619        HarmCategoryViolence,
620        HarmCategorySexually,
621        HarmCategoryMedical,
622        HarmCategoryDangerous,
623        HarmCategoryHarassment,
624        HarmCategoryHateSpeech,
625        HarmCategorySexuallyExplicit,
626        HarmCategoryDangerousContent,
627        HarmCategoryCivicIntegrity,
628    }
629
630    #[derive(Debug, Deserialize, Clone, Default)]
631    #[serde(rename_all = "camelCase")]
632    pub struct UsageMetadata {
633        pub prompt_token_count: i32,
634        #[serde(skip_serializing_if = "Option::is_none")]
635        pub cached_content_token_count: Option<i32>,
636        pub candidates_token_count: i32,
637        pub total_token_count: i32,
638    }
639
640    impl std::fmt::Display for UsageMetadata {
641        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
642            write!(
643                f,
644                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
645                self.prompt_token_count,
646                match self.cached_content_token_count {
647                    Some(count) => count.to_string(),
648                    None => "n/a".to_string(),
649                },
650                self.candidates_token_count,
651                self.total_token_count
652            )
653        }
654    }
655
656    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
657    #[derive(Debug, Deserialize)]
658    #[serde(rename_all = "camelCase")]
659    pub struct PromptFeedback {
660        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
661        pub block_reason: Option<BlockReason>,
662        /// Ratings for safety of the prompt. There is at most one rating per category.
663        pub safety_ratings: Option<Vec<SafetyRating>>,
664    }
665
666    /// Reason why a prompt was blocked by the model
667    #[derive(Debug, Deserialize)]
668    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
669    pub enum BlockReason {
670        /// Default value. This value is unused.
671        BlockReasonUnspecified,
672        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
673        Safety,
674        /// Prompt was blocked due to unknown reasons.
675        Other,
676        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
677        Blocklist,
678        /// Prompt was blocked due to prohibited content.
679        ProhibitedContent,
680    }
681
682    #[derive(Debug, Deserialize)]
683    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
684    pub enum FinishReason {
685        /// Default value. This value is unused.
686        FinishReasonUnspecified,
687        /// Natural stop point of the model or provided stop sequence.
688        Stop,
689        /// The maximum number of tokens as specified in the request was reached.
690        MaxTokens,
691        /// The response candidate content was flagged for safety reasons.
692        Safety,
693        /// The response candidate content was flagged for recitation reasons.
694        Recitation,
695        /// The response candidate content was flagged for using an unsupported language.
696        Language,
697        /// Unknown reason.
698        Other,
699        /// Token generation stopped because the content contains forbidden terms.
700        Blocklist,
701        /// Token generation stopped for potentially containing prohibited content.
702        ProhibitedContent,
703        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
704        Spii,
705        /// The function call generated by the model is invalid.
706        MalformedFunctionCall,
707    }
708
709    #[derive(Debug, Deserialize)]
710    #[serde(rename_all = "camelCase")]
711    pub struct CitationMetadata {
712        pub citation_sources: Vec<CitationSource>,
713    }
714
715    #[derive(Debug, Deserialize)]
716    #[serde(rename_all = "camelCase")]
717    pub struct CitationSource {
718        #[serde(skip_serializing_if = "Option::is_none")]
719        pub uri: Option<String>,
720        #[serde(skip_serializing_if = "Option::is_none")]
721        pub start_index: Option<i32>,
722        #[serde(skip_serializing_if = "Option::is_none")]
723        pub end_index: Option<i32>,
724        #[serde(skip_serializing_if = "Option::is_none")]
725        pub license: Option<String>,
726    }
727
728    #[derive(Debug, Deserialize)]
729    #[serde(rename_all = "camelCase")]
730    pub struct LogprobsResult {
731        pub top_candidate: Vec<TopCandidate>,
732        pub chosen_candidate: Vec<LogProbCandidate>,
733    }
734
    /// Candidates considered at a single decoding step (JSON key `candidates`).
    #[derive(Debug, Deserialize)]
    pub struct TopCandidate {
        /// Candidate tokens with their log probabilities.
        pub candidates: Vec<LogProbCandidate>,
    }
739
740    #[derive(Debug, Deserialize)]
741    #[serde(rename_all = "camelCase")]
742    pub struct LogProbCandidate {
743        pub token: String,
744        pub token_id: String,
745        pub log_probability: f64,
746    }
747
748    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
749    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
750    /// ### Rig Note:
751    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
752    #[derive(Debug, Deserialize, Serialize)]
753    #[serde(rename_all = "camelCase")]
754    pub struct GenerationConfig {
755        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
756        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
757        #[serde(skip_serializing_if = "Option::is_none")]
758        pub stop_sequences: Option<Vec<String>>,
759        /// MIME type of the generated candidate text. Supported MIME types are:
760        ///     - text/plain:  (default) Text output
761        ///     - application/json: JSON response in the response candidates.
762        ///     - text/x.enum: ENUM as a string response in the response candidates.
763        /// Refer to the docs for a list of all supported text MIME types
764        #[serde(skip_serializing_if = "Option::is_none")]
765        pub response_mime_type: Option<String>,
766        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
767        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
768        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
769        #[serde(skip_serializing_if = "Option::is_none")]
770        pub response_schema: Option<Schema>,
771        /// Number of generated responses to return. Currently, this value can only be set to 1. If
772        /// unset, this will default to 1.
773        #[serde(skip_serializing_if = "Option::is_none")]
774        pub candidate_count: Option<i32>,
775        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
776        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
777        #[serde(skip_serializing_if = "Option::is_none")]
778        pub max_output_tokens: Option<u64>,
779        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
780        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
781        #[serde(skip_serializing_if = "Option::is_none")]
782        pub temperature: Option<f64>,
783        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
784        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
785        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
786        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
787        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
788        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
789        #[serde(skip_serializing_if = "Option::is_none")]
790        pub top_p: Option<f64>,
791        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
792        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
793        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
794        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
795        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
796        #[serde(skip_serializing_if = "Option::is_none")]
797        pub top_k: Option<i32>,
798        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
799        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
800        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
801        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
802        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
803        #[serde(skip_serializing_if = "Option::is_none")]
804        pub presence_penalty: Option<f64>,
805        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
806        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
807        /// used, proportional to the number of times the token has been used: The more a token is used, the more
808        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
809        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
810        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
811        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
812        #[serde(skip_serializing_if = "Option::is_none")]
813        pub frequency_penalty: Option<f64>,
814        /// If true, export the logprobs results in response.
815        #[serde(skip_serializing_if = "Option::is_none")]
816        pub response_logprobs: Option<bool>,
817        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
818        /// [Candidate.logprobs_result].
819        #[serde(skip_serializing_if = "Option::is_none")]
820        pub logprobs: Option<i32>,
821    }
822
823    impl Default for GenerationConfig {
824        fn default() -> Self {
825            Self {
826                temperature: Some(1.0),
827                max_output_tokens: Some(4096),
828                stop_sequences: None,
829                response_mime_type: None,
830                response_schema: None,
831                candidate_count: None,
832                top_p: None,
833                top_k: None,
834                presence_penalty: None,
835                frequency_penalty: None,
836                response_logprobs: None,
837                logprobs: None,
838            }
839        }
840    }
841    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
842    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
843    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
844    #[derive(Debug, Deserialize, Serialize)]
845    pub struct Schema {
846        pub r#type: String,
847        #[serde(skip_serializing_if = "Option::is_none")]
848        pub format: Option<String>,
849        #[serde(skip_serializing_if = "Option::is_none")]
850        pub description: Option<String>,
851        #[serde(skip_serializing_if = "Option::is_none")]
852        pub nullable: Option<bool>,
853        #[serde(skip_serializing_if = "Option::is_none")]
854        pub r#enum: Option<Vec<String>>,
855        #[serde(skip_serializing_if = "Option::is_none")]
856        pub max_items: Option<i32>,
857        #[serde(skip_serializing_if = "Option::is_none")]
858        pub min_items: Option<i32>,
859        #[serde(skip_serializing_if = "Option::is_none")]
860        pub properties: Option<HashMap<String, Schema>>,
861        #[serde(skip_serializing_if = "Option::is_none")]
862        pub required: Option<Vec<String>>,
863        #[serde(skip_serializing_if = "Option::is_none")]
864        pub items: Option<Box<Schema>>,
865    }
866
867    impl TryFrom<Value> for Schema {
868        type Error = CompletionError;
869
870        fn try_from(value: Value) -> Result<Self, Self::Error> {
871            if let Some(obj) = value.as_object() {
872                Ok(Schema {
873                    r#type: obj
874                        .get("type")
875                        .and_then(|v| {
876                            if v.is_string() {
877                                v.as_str().map(String::from)
878                            } else if v.is_array() {
879                                v.as_array()
880                                    .and_then(|arr| arr.first())
881                                    .and_then(|v| v.as_str().map(String::from))
882                            } else {
883                                None
884                            }
885                        })
886                        .unwrap_or_default(),
887                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
888                    description: obj
889                        .get("description")
890                        .and_then(|v| v.as_str())
891                        .map(String::from),
892                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
893                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
894                        arr.iter()
895                            .filter_map(|v| v.as_str().map(String::from))
896                            .collect()
897                    }),
898                    max_items: obj
899                        .get("maxItems")
900                        .and_then(|v| v.as_i64())
901                        .map(|v| v as i32),
902                    min_items: obj
903                        .get("minItems")
904                        .and_then(|v| v.as_i64())
905                        .map(|v| v as i32),
906                    properties: obj
907                        .get("properties")
908                        .and_then(|v| v.as_object())
909                        .map(|map| {
910                            map.iter()
911                                .filter_map(|(k, v)| {
912                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
913                                })
914                                .collect()
915                        }),
916                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
917                        arr.iter()
918                            .filter_map(|v| v.as_str().map(String::from))
919                            .collect()
920                    }),
921                    items: obj
922                        .get("items")
923                        .map(|v| Box::new(v.clone().try_into().unwrap())),
924                })
925            } else {
926                Err(CompletionError::ResponseError(
927                    "Expected a JSON object for Schema".into(),
928                ))
929            }
930        }
931    }
932
    /// Request body for content generation, serialized with camelCase keys.
    ///
    /// The `Option` fields here carry no `skip_serializing_if`, so a `None` value is serialized
    /// as JSON `null` rather than omitted.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation contents sent to the model.
        pub contents: Vec<Content>,
        /// Optional. Tools the model may use (serialized as `tools`).
        pub tools: Option<Vec<Tool>>,
        /// Optional. Tool configuration (serialized as `toolConfig`).
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// `GenerateContentRequest.contents` and `GenerateContentResponse.candidates`. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. Harm categories:
        /// - HARM_CATEGORY_HATE_SPEECH
        /// - HARM_CATEGORY_SEXUALLY_EXPLICIT
        /// - HARM_CATEGORY_DANGEROUS_CONTENT
        /// - HARM_CATEGORY_HARASSMENT
        ///
        /// are supported.
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer-set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // Not modelled yet: `cachedContent` (an optional string in the API request).
    }
960
    /// A tool entry in [`GenerateContentRequest::tools`], serialized with camelCase keys.
    ///
    /// NOTE(review): this serializes a single `FunctionDeclaration` object under
    /// `functionDeclarations`, while the Gemini API reference describes that key as a list.
    /// Confirm the API accepts the single-object form.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Function the model may call (serialized as `functionDeclarations`).
        pub function_declarations: FunctionDeclaration,
        /// Code-execution tool marker (serialized as `codeExecution`). `None` is written as `null`
        /// because there is no `skip_serializing_if` on this field.
        pub code_execution: Option<CodeExecution>,
    }
967
    /// Declaration of a callable function exposed to the model.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name the model uses when calling it.
        pub name: String,
        /// Description of what the function does.
        pub description: String,
        /// Parameter schema; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
976
    /// Tool configuration sent as `toolConfig` in the request.
    ///
    /// NOTE(review): the Gemini API reference documents `functionCallingConfig` for ToolConfig;
    /// this type only carries a `schema` key — confirm this is intended.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Optional schema; serialized as `null` when `None` (no `skip_serializing_if`).
        pub schema: Option<Schema>,
    }
982
    /// Marker enabling the code-execution tool. As a braced struct with no fields it serializes
    /// to an empty JSON object (`{}`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
986
    /// A single safety setting: the harm category it applies to and the blocking threshold.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Harm category this setting controls.
        pub category: HarmCategory,
        /// Threshold at which content in `category` is blocked.
        pub threshold: HarmBlockThreshold,
    }
993
    /// Blocking threshold for a [`SafetySetting`].
    ///
    /// Variants serialize in SCREAMING_SNAKE_CASE, e.g. `BlockLowAndAbove` becomes
    /// `"BLOCK_LOW_AND_ABOVE"`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
1004}
1005
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// Deserializes a user `Content` holding one part of each kind (text, inline data,
    /// function call/response, file data, executable code, code execution result) and
    /// checks each part's variant and key fields in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error reports the JSON path of the field that failed to deserialize.
        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part::Text(text) = &parts[0] {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part::InlineData(inline_data) = &parts[1] {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part::FunctionCall(function_call) = &parts[2] {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part::FunctionResponse(function_response) = &parts[3] {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part::FileData(file_data) = &parts[4] {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part::ExecutableCode(executable_code) = &parts[5] {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part::CodeExecutionResult(code_execution_result) = &parts[6] {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    /// Deserializes a model-role `Content` with a single text part.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Part::Text(text) = &content.parts.first() {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// Converts a rig user text message into Gemini `Content` with the `User` role.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Part::Text(text) = &content.parts.first() {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    /// Converts a rig assistant text message into Gemini `Content` with the `Model` role.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Part::Text(text) = &content.parts.first() {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// Converts an assistant tool call into a `FunctionCall` part, preserving the
    /// function name and arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Part::FunctionCall(function_call) = &content.parts.first() {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }
}