rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
5
/// Model identifier string `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model identifier string `gemini-1.5-flash`.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model identifier string `gemini-1.5-pro`.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model identifier string `gemini-1.5-pro-8b`.
// NOTE(review): confirm this identifier against the current Gemini model list.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// Model identifier string `gemini-1.0-pro`.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
16
17use gemini_api_types::{
18    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
19    GenerationConfig, Part, Role, Tool,
20};
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23
24use crate::{
25    completion::{self, CompletionError, CompletionRequest},
26    OneOrMany,
27};
28
29use super::Client;
30
31// =================================================================
32// Rig Implementation Types
33// =================================================================
34
35#[derive(Clone)]
36pub struct CompletionModel {
37    pub(crate) client: Client,
38    pub model: String,
39}
40
41impl CompletionModel {
42    pub fn new(client: Client, model: &str) -> Self {
43        Self {
44            client,
45            model: model.to_string(),
46        }
47    }
48}
49
50impl completion::CompletionModel for CompletionModel {
51    type Response = GenerateContentResponse;
52
53    #[cfg_attr(feature = "worker", worker::send)]
54    async fn completion(
55        &self,
56        completion_request: CompletionRequest,
57    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
58        let request = create_request_body(completion_request)?;
59
60        tracing::debug!(
61            "Sending completion request to Gemini API {}",
62            serde_json::to_string_pretty(&request)?
63        );
64
65        let response = self
66            .client
67            .post(&format!("/v1beta/models/{}:generateContent", self.model))
68            .json(&request)
69            .send()
70            .await?;
71
72        if response.status().is_success() {
73            let response = response.json::<GenerateContentResponse>().await?;
74            match response.usage_metadata {
75                Some(ref usage) => tracing::info!(target: "rig",
76                "Gemini completion token usage: {}",
77                usage
78                ),
79                None => tracing::info!(target: "rig",
80                    "Gemini completion token usage: n/a",
81                ),
82            }
83
84            tracing::debug!("Received response");
85
86            Ok(completion::CompletionResponse::try_from(response))
87        } else {
88            Err(CompletionError::ProviderError(response.text().await?))
89        }?
90    }
91}
92
93pub(crate) fn create_request_body(
94    mut completion_request: CompletionRequest,
95) -> Result<GenerateContentRequest, CompletionError> {
96    let mut full_history = Vec::new();
97    full_history.append(&mut completion_request.chat_history);
98    full_history.push(completion_request.prompt_with_context());
99
100    let additional_params = completion_request
101        .additional_params
102        .unwrap_or_else(|| Value::Object(Map::new()));
103
104    let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
105
106    if let Some(temp) = completion_request.temperature {
107        generation_config.temperature = Some(temp);
108    }
109
110    if let Some(max_tokens) = completion_request.max_tokens {
111        generation_config.max_output_tokens = Some(max_tokens);
112    }
113
114    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
115        parts: OneOrMany::one(preamble.into()),
116        role: Some(Role::Model),
117    });
118
119    let request = GenerateContentRequest {
120        contents: full_history
121            .into_iter()
122            .map(|msg| {
123                msg.try_into()
124                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
125            })
126            .collect::<Result<Vec<_>, _>>()?,
127        generation_config: Some(generation_config),
128        safety_settings: None,
129        tools: Some(
130            completion_request
131                .tools
132                .into_iter()
133                .map(Tool::try_from)
134                .collect::<Result<Vec<_>, _>>()?,
135        ),
136        tool_config: None,
137        system_instruction,
138    };
139
140    Ok(request)
141}
142
143impl TryFrom<completion::ToolDefinition> for Tool {
144    type Error = CompletionError;
145
146    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
147        Ok(Self {
148            function_declarations: FunctionDeclaration {
149                name: tool.name,
150                description: tool.description,
151                parameters: Some(tool.parameters.try_into()?),
152            },
153            code_execution: None,
154        })
155    }
156}
157
158impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
159    type Error = CompletionError;
160
161    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
162        let candidate = response.candidates.first().ok_or_else(|| {
163            CompletionError::ResponseError("No response candidates in response".into())
164        })?;
165
166        let content = candidate
167            .content
168            .parts
169            .iter()
170            .map(|part| {
171                Ok(match part {
172                    Part::Text(text) => completion::AssistantContent::text(text),
173                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
174                        &function_call.name,
175                        &function_call.name,
176                        function_call.args.clone(),
177                    ),
178                    _ => {
179                        return Err(CompletionError::ResponseError(
180                            "Response did not contain a message or tool call".into(),
181                        ))
182                    }
183                })
184            })
185            .collect::<Result<Vec<_>, _>>()?;
186
187        let choice = OneOrMany::many(content).map_err(|_| {
188            CompletionError::ResponseError(
189                "Response contained no message or tool call (empty)".to_owned(),
190            )
191        })?;
192
193        Ok(completion::CompletionResponse {
194            choice,
195            raw_response: response,
196        })
197    }
198}
199
200pub mod gemini_api_types {
201    use std::{collections::HashMap, convert::Infallible, str::FromStr};
202
203    // =================================================================
204    // Gemini API Types
205    // =================================================================
206    use serde::{Deserialize, Serialize};
207    use serde_json::Value;
208
209    use crate::{
210        completion::CompletionError,
211        message::{self, MimeType as _},
212        one_or_many::string_or_one_or_many,
213        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
214        OneOrMany,
215    };
216
217    /// Response from the model supporting multiple candidate responses.
218    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
219    /// and for each candidate in finishReason and in safetyRatings.
220    /// The API:
221    ///     - Returns either all requested candidates or none of them
222    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
223    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
224    #[derive(Debug, Deserialize)]
225    #[serde(rename_all = "camelCase")]
226    pub struct GenerateContentResponse {
227        /// Candidate responses from the model.
228        pub candidates: Vec<ContentCandidate>,
229        /// Returns the prompt's feedback related to the content filters.
230        pub prompt_feedback: Option<PromptFeedback>,
231        /// Output only. Metadata on the generation requests' token usage.
232        pub usage_metadata: Option<UsageMetadata>,
233        pub model_version: Option<String>,
234    }
235
236    /// A response candidate generated from the model.
237    #[derive(Debug, Deserialize)]
238    #[serde(rename_all = "camelCase")]
239    pub struct ContentCandidate {
240        /// Output only. Generated content returned from the model.
241        pub content: Content,
242        /// Optional. Output only. The reason why the model stopped generating tokens.
243        /// If empty, the model has not stopped generating tokens.
244        pub finish_reason: Option<FinishReason>,
245        /// List of ratings for the safety of a response candidate.
246        /// There is at most one rating per category.
247        pub safety_ratings: Option<Vec<SafetyRating>>,
248        /// Output only. Citation information for model-generated candidate.
249        /// This field may be populated with recitation information for any text included in the content.
250        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
251        pub citation_metadata: Option<CitationMetadata>,
252        /// Output only. Token count for this candidate.
253        pub token_count: Option<i32>,
254        /// Output only.
255        pub avg_logprobs: Option<f64>,
256        /// Output only. Log-likelihood scores for the response tokens and top tokens
257        pub logprobs_result: Option<LogprobsResult>,
258        /// Output only. Index of the candidate in the list of response candidates.
259        pub index: Option<i32>,
260    }
261    #[derive(Debug, Deserialize, Serialize)]
262    pub struct Content {
263        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
264        #[serde(deserialize_with = "string_or_one_or_many")]
265        pub parts: OneOrMany<Part>,
266        /// The producer of the content. Must be either 'user' or 'model'.
267        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
268        pub role: Option<Role>,
269    }
270
271    impl TryFrom<message::Message> for Content {
272        type Error = message::MessageError;
273
274        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
275            Ok(match msg {
276                message::Message::User { content } => Content {
277                    parts: content.try_map(|c| c.try_into())?,
278                    role: Some(Role::User),
279                },
280                message::Message::Assistant { content } => Content {
281                    role: Some(Role::Model),
282                    parts: content.map(|content| content.into()),
283                },
284            })
285        }
286    }
287
288    impl TryFrom<Content> for message::Message {
289        type Error = message::MessageError;
290
291        fn try_from(content: Content) -> Result<Self, Self::Error> {
292            match content.role {
293                Some(Role::User) | None => Ok(message::Message::User {
294                    content: content.parts.try_map(|part| {
295                        Ok(match part {
296                            Part::Text(text) => message::UserContent::text(text),
297                            Part::InlineData(inline_data) => {
298                                let mime_type =
299                                    message::MediaType::from_mime_type(&inline_data.mime_type);
300
301                                match mime_type {
302                                    Some(message::MediaType::Image(media_type)) => {
303                                        message::UserContent::image(
304                                            inline_data.data,
305                                            Some(message::ContentFormat::default()),
306                                            Some(media_type),
307                                            Some(message::ImageDetail::default()),
308                                        )
309                                    }
310                                    Some(message::MediaType::Document(media_type)) => {
311                                        message::UserContent::document(
312                                            inline_data.data,
313                                            Some(message::ContentFormat::default()),
314                                            Some(media_type),
315                                        )
316                                    }
317                                    Some(message::MediaType::Audio(media_type)) => {
318                                        message::UserContent::audio(
319                                            inline_data.data,
320                                            Some(message::ContentFormat::default()),
321                                            Some(media_type),
322                                        )
323                                    }
324                                    _ => {
325                                        return Err(message::MessageError::ConversionError(
326                                            format!("Unsupported media type {:?}", mime_type),
327                                        ))
328                                    }
329                                }
330                            }
331                            _ => {
332                                return Err(message::MessageError::ConversionError(format!(
333                                    "Unsupported gemini content part type: {:?}",
334                                    part
335                                )))
336                            }
337                        })
338                    })?,
339                }),
340                Some(Role::Model) => Ok(message::Message::Assistant {
341                    content: content.parts.try_map(|part| {
342                        Ok(match part {
343                            Part::Text(text) => message::AssistantContent::text(text),
344                            Part::FunctionCall(function_call) => {
345                                message::AssistantContent::ToolCall(function_call.into())
346                            }
347                            _ => {
348                                return Err(message::MessageError::ConversionError(format!(
349                                    "Unsupported part type: {:?}",
350                                    part
351                                )))
352                            }
353                        })
354                    })?,
355                }),
356            }
357        }
358    }
359
360    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
361    #[serde(rename_all = "lowercase")]
362    pub enum Role {
363        User,
364        Model,
365    }
366
367    /// A datatype containing media that is part of a multi-part [Content] message.
368    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
369    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
370    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
371    #[serde(rename_all = "camelCase")]
372    pub enum Part {
373        Text(String),
374        InlineData(Blob),
375        FunctionCall(FunctionCall),
376        FunctionResponse(FunctionResponse),
377        FileData(FileData),
378        ExecutableCode(ExecutableCode),
379        CodeExecutionResult(CodeExecutionResult),
380    }
381
382    impl From<String> for Part {
383        fn from(text: String) -> Self {
384            Self::Text(text)
385        }
386    }
387
388    impl From<&str> for Part {
389        fn from(text: &str) -> Self {
390            Self::Text(text.to_string())
391        }
392    }
393
394    impl FromStr for Part {
395        type Err = Infallible;
396
397        fn from_str(s: &str) -> Result<Self, Self::Err> {
398            Ok(s.into())
399        }
400    }
401
402    impl TryFrom<message::UserContent> for Part {
403        type Error = message::MessageError;
404
405        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
406            match content {
407                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
408                message::UserContent::ToolResult(message::ToolResult { id, content }) => {
409                    let content = match content.first() {
410                        message::ToolResultContent::Text(text) => text.text,
411                        message::ToolResultContent::Image(_) => {
412                            return Err(message::MessageError::ConversionError(
413                                "Tool result content must be text".to_string(),
414                            ))
415                        }
416                    };
417                    Ok(Part::FunctionResponse(FunctionResponse {
418                        name: id,
419                        response: Some(serde_json::from_str(&content).map_err(|e| {
420                            message::MessageError::ConversionError(format!(
421                                "Failed to parse tool response: {}",
422                                e
423                            ))
424                        })?),
425                    }))
426                }
427                message::UserContent::Image(message::Image {
428                    data, media_type, ..
429                }) => match media_type {
430                    Some(media_type) => match media_type {
431                        message::ImageMediaType::JPEG
432                        | message::ImageMediaType::PNG
433                        | message::ImageMediaType::WEBP
434                        | message::ImageMediaType::HEIC
435                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
436                            mime_type: media_type.to_mime_type().to_owned(),
437                            data,
438                        })),
439                        _ => Err(message::MessageError::ConversionError(format!(
440                            "Unsupported image media type {:?}",
441                            media_type
442                        ))),
443                    },
444                    None => Err(message::MessageError::ConversionError(
445                        "Media type for image is required for Anthropic".to_string(),
446                    )),
447                },
448                message::UserContent::Document(message::Document {
449                    data, media_type, ..
450                }) => match media_type {
451                    Some(media_type) => match media_type {
452                        message::DocumentMediaType::PDF
453                        | message::DocumentMediaType::TXT
454                        | message::DocumentMediaType::RTF
455                        | message::DocumentMediaType::HTML
456                        | message::DocumentMediaType::CSS
457                        | message::DocumentMediaType::MARKDOWN
458                        | message::DocumentMediaType::CSV
459                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
460                            mime_type: media_type.to_mime_type().to_owned(),
461                            data,
462                        })),
463                        _ => Err(message::MessageError::ConversionError(format!(
464                            "Unsupported document media type {:?}",
465                            media_type
466                        ))),
467                    },
468                    None => Err(message::MessageError::ConversionError(
469                        "Media type for document is required for Anthropic".to_string(),
470                    )),
471                },
472                message::UserContent::Audio(message::Audio {
473                    data, media_type, ..
474                }) => match media_type {
475                    Some(media_type) => Ok(Self::InlineData(Blob {
476                        mime_type: media_type.to_mime_type().to_owned(),
477                        data,
478                    })),
479                    None => Err(message::MessageError::ConversionError(
480                        "Media type for audio is required for Anthropic".to_string(),
481                    )),
482                },
483            }
484        }
485    }
486
487    impl From<message::AssistantContent> for Part {
488        fn from(content: message::AssistantContent) -> Self {
489            match content {
490                message::AssistantContent::Text(message::Text { text }) => text.into(),
491                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
492            }
493        }
494    }
495
496    impl From<message::ToolCall> for Part {
497        fn from(tool_call: message::ToolCall) -> Self {
498            Self::FunctionCall(FunctionCall {
499                name: tool_call.function.name,
500                args: tool_call.function.arguments,
501            })
502        }
503    }
504
505    /// Raw media bytes.
506    /// Text should not be sent as raw bytes, use the 'text' field.
507    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
508    #[serde(rename_all = "camelCase")]
509    pub struct Blob {
510        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
511        /// If an unsupported MIME type is provided, an error will be returned.
512        pub mime_type: String,
513        /// Raw bytes for media formats. A base64-encoded string.
514        pub data: String,
515    }
516
517    /// A predicted FunctionCall returned from the model that contains a string representing the
518    /// FunctionDeclaration.name with the arguments and their values.
519    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
520    pub struct FunctionCall {
521        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
522        /// and dashes, with a maximum length of 63.
523        pub name: String,
524        /// Optional. The function parameters and values in JSON object format.
525        pub args: serde_json::Value,
526    }
527
528    impl From<FunctionCall> for message::ToolCall {
529        fn from(function_call: FunctionCall) -> Self {
530            Self {
531                id: function_call.name.clone(),
532                function: message::ToolFunction {
533                    name: function_call.name,
534                    arguments: function_call.args,
535                },
536            }
537        }
538    }
539
540    impl From<message::ToolCall> for FunctionCall {
541        fn from(tool_call: message::ToolCall) -> Self {
542            Self {
543                name: tool_call.function.name,
544                args: tool_call.function.arguments,
545            }
546        }
547    }
548
549    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
550    /// and a structured JSON object containing any output from the function is used as context to the model.
551    /// This should contain the result of aFunctionCall made based on model prediction.
552    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
553    pub struct FunctionResponse {
554        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
555        /// with a maximum length of 63.
556        pub name: String,
557        /// The function response in JSON object format.
558        pub response: Option<HashMap<String, serde_json::Value>>,
559    }
560
561    /// URI based data.
562    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
563    #[serde(rename_all = "camelCase")]
564    pub struct FileData {
565        /// Optional. The IANA standard MIME type of the source data.
566        pub mime_type: Option<String>,
567        /// Required. URI.
568        pub file_uri: String,
569    }
570
571    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
572    pub struct SafetyRating {
573        pub category: HarmCategory,
574        pub probability: HarmProbability,
575    }
576
577    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
578    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
579    pub enum HarmProbability {
580        HarmProbabilityUnspecified,
581        Negligible,
582        Low,
583        Medium,
584        High,
585    }
586
587    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
588    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
589    pub enum HarmCategory {
590        HarmCategoryUnspecified,
591        HarmCategoryDerogatory,
592        HarmCategoryToxicity,
593        HarmCategoryViolence,
594        HarmCategorySexually,
595        HarmCategoryMedical,
596        HarmCategoryDangerous,
597        HarmCategoryHarassment,
598        HarmCategoryHateSpeech,
599        HarmCategorySexuallyExplicit,
600        HarmCategoryDangerousContent,
601        HarmCategoryCivicIntegrity,
602    }
603
604    #[derive(Debug, Deserialize)]
605    #[serde(rename_all = "camelCase")]
606    pub struct UsageMetadata {
607        pub prompt_token_count: i32,
608        #[serde(skip_serializing_if = "Option::is_none")]
609        pub cached_content_token_count: Option<i32>,
610        pub candidates_token_count: i32,
611        pub total_token_count: i32,
612    }
613
614    impl std::fmt::Display for UsageMetadata {
615        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616            write!(
617                f,
618                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
619                self.prompt_token_count,
620                match self.cached_content_token_count {
621                    Some(count) => count.to_string(),
622                    None => "n/a".to_string(),
623                },
624                self.candidates_token_count,
625                self.total_token_count
626            )
627        }
628    }
629
630    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
631    #[derive(Debug, Deserialize)]
632    #[serde(rename_all = "camelCase")]
633    pub struct PromptFeedback {
634        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
635        pub block_reason: Option<BlockReason>,
636        /// Ratings for safety of the prompt. There is at most one rating per category.
637        pub safety_ratings: Option<Vec<SafetyRating>>,
638    }
639
640    /// Reason why a prompt was blocked by the model
641    #[derive(Debug, Deserialize)]
642    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
643    pub enum BlockReason {
644        /// Default value. This value is unused.
645        BlockReasonUnspecified,
646        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
647        Safety,
648        /// Prompt was blocked due to unknown reasons.
649        Other,
650        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
651        Blocklist,
652        /// Prompt was blocked due to prohibited content.
653        ProhibitedContent,
654    }
655
656    #[derive(Debug, Deserialize)]
657    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
658    pub enum FinishReason {
659        /// Default value. This value is unused.
660        FinishReasonUnspecified,
661        /// Natural stop point of the model or provided stop sequence.
662        Stop,
663        /// The maximum number of tokens as specified in the request was reached.
664        MaxTokens,
665        /// The response candidate content was flagged for safety reasons.
666        Safety,
667        /// The response candidate content was flagged for recitation reasons.
668        Recitation,
669        /// The response candidate content was flagged for using an unsupported language.
670        Language,
671        /// Unknown reason.
672        Other,
673        /// Token generation stopped because the content contains forbidden terms.
674        Blocklist,
675        /// Token generation stopped for potentially containing prohibited content.
676        ProhibitedContent,
677        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
678        Spii,
679        /// The function call generated by the model is invalid.
680        MalformedFunctionCall,
681    }
682
683    #[derive(Debug, Deserialize)]
684    #[serde(rename_all = "camelCase")]
685    pub struct CitationMetadata {
686        pub citation_sources: Vec<CitationSource>,
687    }
688
689    #[derive(Debug, Deserialize)]
690    #[serde(rename_all = "camelCase")]
691    pub struct CitationSource {
692        #[serde(skip_serializing_if = "Option::is_none")]
693        pub uri: Option<String>,
694        #[serde(skip_serializing_if = "Option::is_none")]
695        pub start_index: Option<i32>,
696        #[serde(skip_serializing_if = "Option::is_none")]
697        pub end_index: Option<i32>,
698        #[serde(skip_serializing_if = "Option::is_none")]
699        pub license: Option<String>,
700    }
701
702    #[derive(Debug, Deserialize)]
703    #[serde(rename_all = "camelCase")]
704    pub struct LogprobsResult {
705        pub top_candidate: Vec<TopCandidate>,
706        pub chosen_candidate: Vec<LogProbCandidate>,
707    }
708
709    #[derive(Debug, Deserialize)]
710    pub struct TopCandidate {
711        pub candidates: Vec<LogProbCandidate>,
712    }
713
714    #[derive(Debug, Deserialize)]
715    #[serde(rename_all = "camelCase")]
716    pub struct LogProbCandidate {
717        pub token: String,
718        pub token_id: String,
719        pub log_probability: f64,
720    }
721
722    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
723    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
724    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
726    #[derive(Debug, Deserialize, Serialize)]
727    #[serde(rename_all = "camelCase")]
728    pub struct GenerationConfig {
729        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
730        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
731        #[serde(skip_serializing_if = "Option::is_none")]
732        pub stop_sequences: Option<Vec<String>>,
733        /// MIME type of the generated candidate text. Supported MIME types are:
734        ///     - text/plain:  (default) Text output
735        ///     - application/json: JSON response in the response candidates.
736        ///     - text/x.enum: ENUM as a string response in the response candidates.
737        /// Refer to the docs for a list of all supported text MIME types
738        #[serde(skip_serializing_if = "Option::is_none")]
739        pub response_mime_type: Option<String>,
740        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
741        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
742        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
743        #[serde(skip_serializing_if = "Option::is_none")]
744        pub response_schema: Option<Schema>,
745        /// Number of generated responses to return. Currently, this value can only be set to 1. If
746        /// unset, this will default to 1.
747        #[serde(skip_serializing_if = "Option::is_none")]
748        pub candidate_count: Option<i32>,
749        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
750        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
751        #[serde(skip_serializing_if = "Option::is_none")]
752        pub max_output_tokens: Option<u64>,
753        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
754        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
755        #[serde(skip_serializing_if = "Option::is_none")]
756        pub temperature: Option<f64>,
757        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
758        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
759        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
760        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
761        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
762        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
763        #[serde(skip_serializing_if = "Option::is_none")]
764        pub top_p: Option<f64>,
765        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
766        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
767        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
768        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
769        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
770        #[serde(skip_serializing_if = "Option::is_none")]
771        pub top_k: Option<i32>,
772        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
773        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
774        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
775        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
776        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
777        #[serde(skip_serializing_if = "Option::is_none")]
778        pub presence_penalty: Option<f64>,
779        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
780        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
781        /// used, proportional to the number of times the token has been used: The more a token is used, the more
782        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
783        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
784        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
785        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
786        #[serde(skip_serializing_if = "Option::is_none")]
787        pub frequency_penalty: Option<f64>,
788        /// If true, export the logprobs results in response.
789        #[serde(skip_serializing_if = "Option::is_none")]
790        pub response_logprobs: Option<bool>,
791        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
792        /// [Candidate.logprobs_result].
793        #[serde(skip_serializing_if = "Option::is_none")]
794        pub logprobs: Option<i32>,
795    }
796
797    impl Default for GenerationConfig {
798        fn default() -> Self {
799            Self {
800                temperature: Some(1.0),
801                max_output_tokens: Some(4096),
802                stop_sequences: None,
803                response_mime_type: None,
804                response_schema: None,
805                candidate_count: None,
806                top_p: None,
807                top_k: None,
808                presence_penalty: None,
809                frequency_penalty: None,
810                response_logprobs: None,
811                logprobs: None,
812            }
813        }
814    }
815    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
816    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
817    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
818    #[derive(Debug, Deserialize, Serialize)]
819    pub struct Schema {
820        pub r#type: String,
821        #[serde(skip_serializing_if = "Option::is_none")]
822        pub format: Option<String>,
823        #[serde(skip_serializing_if = "Option::is_none")]
824        pub description: Option<String>,
825        #[serde(skip_serializing_if = "Option::is_none")]
826        pub nullable: Option<bool>,
827        #[serde(skip_serializing_if = "Option::is_none")]
828        pub r#enum: Option<Vec<String>>,
829        #[serde(skip_serializing_if = "Option::is_none")]
830        pub max_items: Option<i32>,
831        #[serde(skip_serializing_if = "Option::is_none")]
832        pub min_items: Option<i32>,
833        #[serde(skip_serializing_if = "Option::is_none")]
834        pub properties: Option<HashMap<String, Schema>>,
835        #[serde(skip_serializing_if = "Option::is_none")]
836        pub required: Option<Vec<String>>,
837        #[serde(skip_serializing_if = "Option::is_none")]
838        pub items: Option<Box<Schema>>,
839    }
840
841    impl TryFrom<Value> for Schema {
842        type Error = CompletionError;
843
844        fn try_from(value: Value) -> Result<Self, Self::Error> {
845            if let Some(obj) = value.as_object() {
846                Ok(Schema {
847                    r#type: obj
848                        .get("type")
849                        .and_then(|v| {
850                            if v.is_string() {
851                                v.as_str().map(String::from)
852                            } else if v.is_array() {
853                                v.as_array()
854                                    .and_then(|arr| arr.first())
855                                    .and_then(|v| v.as_str().map(String::from))
856                            } else {
857                                None
858                            }
859                        })
860                        .unwrap_or_default(),
861                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
862                    description: obj
863                        .get("description")
864                        .and_then(|v| v.as_str())
865                        .map(String::from),
866                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
867                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
868                        arr.iter()
869                            .filter_map(|v| v.as_str().map(String::from))
870                            .collect()
871                    }),
872                    max_items: obj
873                        .get("maxItems")
874                        .and_then(|v| v.as_i64())
875                        .map(|v| v as i32),
876                    min_items: obj
877                        .get("minItems")
878                        .and_then(|v| v.as_i64())
879                        .map(|v| v as i32),
880                    properties: obj
881                        .get("properties")
882                        .and_then(|v| v.as_object())
883                        .map(|map| {
884                            map.iter()
885                                .filter_map(|(k, v)| {
886                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
887                                })
888                                .collect()
889                        }),
890                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
891                        arr.iter()
892                            .filter_map(|v| v.as_str().map(String::from))
893                            .collect()
894                    }),
895                    items: obj
896                        .get("items")
897                        .map(|v| Box::new(v.clone().try_into().unwrap())),
898                })
899            } else {
900                Err(CompletionError::ResponseError(
901                    "Expected a JSON object for Schema".into(),
902                ))
903            }
904        }
905    }
906
    /// Serializable request payload built from `Content`, `Tool`, `ToolConfig`,
    /// `GenerationConfig` and `SafetySetting` values.
    ///
    /// Keys are written in camelCase. None of the optional fields carries
    /// `skip_serializing_if`, so a `None` value is serialized as JSON `null` rather than omitted.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// The content entries sent to the model.
        pub contents: Vec<Content>,
        /// Optional. Tools made available for this request.
        pub tools: Option<Vec<Tool>>,
        /// Optional. Configuration applying to the tools of this request.
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. Harm categories:
        ///     - HARM_CATEGORY_HATE_SPEECH,
        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
        ///     - HARM_CATEGORY_HARASSMENT
        /// are supported.
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // cachedContent: Optional<String>
    }
934
    /// A tool entry, serialized with camelCase keys (`functionDeclarations`, `codeExecution`).
    ///
    /// NOTE(review): the Gemini API reference describes `functionDeclarations` as a list, while
    /// this struct serializes a single declaration object under that key — confirm that this is
    /// intended before relying on it.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// The function declaration exposed by this tool.
        pub function_declarations: FunctionDeclaration,
        /// Optional code-execution marker; written as JSON `null` when `None`.
        pub code_execution: Option<CodeExecution>,
    }
941
    /// Declaration of a single function: its name, a description and an optional parameter
    /// schema.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// The function's name.
        pub name: String,
        /// Description of the function.
        pub description: String,
        /// Schema of the function's parameters; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
950
    /// Tool configuration carried in `GenerateContentRequest::tool_config`.
    ///
    /// NOTE(review): the Gemini API reference documents a `functionCallingConfig` member for the
    /// tool configuration; verify that a `schema` key is something the API actually reads.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Optional schema; written as JSON `null` when `None`.
        pub schema: Option<Schema>,
    }
956
    /// Field-less marker used by `Tool::code_execution`; it serializes as an empty JSON object.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
960
    /// Pairs a harm category with the blocking threshold to use for it.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The threshold applied to that category.
        pub threshold: HarmBlockThreshold,
    }
967
    /// Blocking threshold used in a `SafetySetting`.
    ///
    /// Variants are serialized in SCREAMING_SNAKE_CASE, e.g. `BlockLowAndAbove` becomes
    /// `"BLOCK_LOW_AND_ABOVE"`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// Serialized as `HARM_BLOCK_THRESHOLD_UNSPECIFIED`.
        HarmBlockThresholdUnspecified,
        /// Serialized as `BLOCK_LOW_AND_ABOVE`.
        BlockLowAndAbove,
        /// Serialized as `BLOCK_MEDIUM_AND_ABOVE`.
        BlockMediumAndAbove,
        /// Serialized as `BLOCK_ONLY_HIGH`.
        BlockOnlyHigh,
        /// Serialized as `BLOCK_NONE`.
        BlockNone,
        /// Serialized as `OFF`.
        Off,
    }
978}
979
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user `Content` holding one part of every kind deserializes into the matching `Part`
    /// variants, in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // `serde_path_to_error` reports the JSON path of the failing element.
        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = match serde_path_to_error::deserialize(&mut deserializer) {
            Ok(parsed) => parsed,
            Err(err) => panic!("Deserialization error at {}: {}", err.path(), err),
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        match &parts[0] {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }

        match &parts[1] {
            Part::InlineData(inline_data) => {
                assert_eq!(inline_data.mime_type, "image/png");
                assert_eq!(inline_data.data, "base64encodeddata");
            }
            _ => panic!("Expected inline data part"),
        }

        match &parts[2] {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                let args = function_call.args.as_object().unwrap();
                assert_eq!(args.get("arg1").unwrap(), "value1");
            }
            _ => panic!("Expected function call part"),
        }

        match &parts[3] {
            Part::FunctionResponse(function_response) => {
                assert_eq!(function_response.name, "test_function");
                let payload = function_response.response.as_ref().unwrap();
                assert_eq!(payload.get("result").unwrap(), "success");
            }
            _ => panic!("Expected function response part"),
        }

        match &parts[4] {
            Part::FileData(file_data) => {
                assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
                assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
            }
            _ => panic!("Expected file data part"),
        }

        match &parts[5] {
            Part::ExecutableCode(executable_code) => {
                assert_eq!(executable_code.code, "print('Hello, world!')");
            }
            _ => panic!("Expected executable code part"),
        }

        match &parts[6] {
            Part::CodeExecutionResult(code_execution_result) => {
                assert_eq!(
                    code_execution_result.clone().output.unwrap(),
                    "Hello, world!"
                );
            }
            _ => panic!("Expected code execution result part"),
        }
    }

    /// A `model`-role `Content` with a single text part deserializes as expected.
    #[test]
    fn test_deserialize_message_model() {
        let payload = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(payload).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig user message converts into a `User` content with one text part.
    #[test]
    fn test_message_conversion_user() {
        let user_message = message::Message::user("Hello, world!");
        let content: Content = user_message.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig assistant message converts into a `Model` content with one text part.
    #[test]
    fn test_message_conversion_model() {
        let assistant_message = message::Message::assistant("Hello, user!");

        let content: Content = assistant_message.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// An assistant message holding a tool call converts into a `Model` content whose only part
    /// is the corresponding function call.
    #[test]
    fn test_message_conversion_tool_call() {
        let function = message::ToolFunction {
            name: "test_function".to_string(),
            arguments: json!({"arg1": "value1"}),
        };
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            function,
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = assistant_message.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                let args = function_call.args.as_object().unwrap();
                assert_eq!(args.get("arg1").unwrap(), "value1");
            }
            _ => panic!("Expected function call part"),
        }
    }
}