rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
5
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
// NOTE(review): the published 8B variant appears to be `gemini-1.5-flash-8b`; confirm that
// `gemini-1.5-pro-8b` is an identifier the API actually accepts.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
16
17use gemini_api_types::{
18    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
19    GenerationConfig, Part, Role, Tool,
20};
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23
24use crate::{
25    completion::{self, CompletionError, CompletionRequest},
26    OneOrMany,
27};
28
29use self::gemini_api_types::Schema;
30
31use super::Client;
32
33// =================================================================
34// Rig Implementation Types
35// =================================================================
36
37#[derive(Clone)]
38pub struct CompletionModel {
39    pub(crate) client: Client,
40    pub model: String,
41}
42
43impl CompletionModel {
44    pub fn new(client: Client, model: &str) -> Self {
45        Self {
46            client,
47            model: model.to_string(),
48        }
49    }
50}
51
52impl completion::CompletionModel for CompletionModel {
53    type Response = GenerateContentResponse;
54
55    #[cfg_attr(feature = "worker", worker::send)]
56    async fn completion(
57        &self,
58        completion_request: CompletionRequest,
59    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
60        let request = create_request_body(completion_request)?;
61
62        tracing::debug!(
63            "Sending completion request to Gemini API {}",
64            serde_json::to_string_pretty(&request)?
65        );
66
67        let response = self
68            .client
69            .post(&format!("/v1beta/models/{}:generateContent", self.model))
70            .json(&request)
71            .send()
72            .await?;
73
74        if response.status().is_success() {
75            let response = response.json::<GenerateContentResponse>().await?;
76            match response.usage_metadata {
77                Some(ref usage) => tracing::info!(target: "rig",
78                "Gemini completion token usage: {}",
79                usage
80                ),
81                None => tracing::info!(target: "rig",
82                    "Gemini completion token usage: n/a",
83                ),
84            }
85
86            tracing::debug!("Received response");
87
88            Ok(completion::CompletionResponse::try_from(response))
89        } else {
90            Err(CompletionError::ProviderError(response.text().await?))
91        }?
92    }
93}
94
95pub(crate) fn create_request_body(
96    completion_request: CompletionRequest,
97) -> Result<GenerateContentRequest, CompletionError> {
98    let mut full_history = Vec::new();
99    full_history.extend(completion_request.chat_history);
100
101    let additional_params = completion_request
102        .additional_params
103        .unwrap_or_else(|| Value::Object(Map::new()));
104
105    let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
106
107    if let Some(temp) = completion_request.temperature {
108        generation_config.temperature = Some(temp);
109    }
110
111    if let Some(max_tokens) = completion_request.max_tokens {
112        generation_config.max_output_tokens = Some(max_tokens);
113    }
114
115    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
116        parts: OneOrMany::one(preamble.into()),
117        role: Some(Role::Model),
118    });
119
120    let request = GenerateContentRequest {
121        contents: full_history
122            .into_iter()
123            .map(|msg| {
124                msg.try_into()
125                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
126            })
127            .collect::<Result<Vec<_>, _>>()?,
128        generation_config: Some(generation_config),
129        safety_settings: None,
130        tools: Some(
131            completion_request
132                .tools
133                .into_iter()
134                .map(Tool::try_from)
135                .collect::<Result<Vec<_>, _>>()?,
136        ),
137        tool_config: None,
138        system_instruction,
139    };
140
141    Ok(request)
142}
143
144impl TryFrom<completion::ToolDefinition> for Tool {
145    type Error = CompletionError;
146
147    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
148        let parameters: Option<Schema> =
149            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
150                None
151            } else {
152                Some(tool.parameters.try_into()?)
153            };
154        Ok(Self {
155            function_declarations: FunctionDeclaration {
156                name: tool.name,
157                description: tool.description,
158                parameters,
159            },
160            code_execution: None,
161        })
162    }
163}
164
165impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
166    type Error = CompletionError;
167
168    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
169        let candidate = response.candidates.first().ok_or_else(|| {
170            CompletionError::ResponseError("No response candidates in response".into())
171        })?;
172
173        let content = candidate
174            .content
175            .parts
176            .iter()
177            .map(|part| {
178                Ok(match part {
179                    Part::Text(text) => completion::AssistantContent::text(text),
180                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
181                        &function_call.name,
182                        &function_call.name,
183                        function_call.args.clone(),
184                    ),
185                    _ => {
186                        return Err(CompletionError::ResponseError(
187                            "Response did not contain a message or tool call".into(),
188                        ))
189                    }
190                })
191            })
192            .collect::<Result<Vec<_>, _>>()?;
193
194        let choice = OneOrMany::many(content).map_err(|_| {
195            CompletionError::ResponseError(
196                "Response contained no message or tool call (empty)".to_owned(),
197            )
198        })?;
199
200        Ok(completion::CompletionResponse {
201            choice,
202            raw_response: response,
203        })
204    }
205}
206
207pub mod gemini_api_types {
208    use std::{collections::HashMap, convert::Infallible, str::FromStr};
209
210    // =================================================================
211    // Gemini API Types
212    // =================================================================
213    use serde::{Deserialize, Serialize};
214    use serde_json::Value;
215
216    use crate::{
217        completion::CompletionError,
218        message::{self, MimeType as _},
219        one_or_many::string_or_one_or_many,
220        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
221        OneOrMany,
222    };
223
    /// Response from the model supporting multiple candidate responses.
    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
    /// and for each candidate in finishReason and in safetyRatings.
    /// The API:
    ///     - Returns either all requested candidates or none of them
    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
    ///
    /// Field names are deserialized from their camelCase JSON spellings.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Candidate responses from the model.
        pub candidates: Vec<ContentCandidate>,
        /// Returns the prompt's feedback related to the content filters.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Output only. Metadata on the generation requests' token usage.
        pub usage_metadata: Option<UsageMetadata>,
        /// NOTE(review): presumably the version of the model that produced the response —
        /// confirm against the API reference.
        pub model_version: Option<String>,
    }
242
    /// A response candidate generated from the model.
    ///
    /// Field names are deserialized from their camelCase JSON spellings; every field except
    /// `content` is optional.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only.
        /// NOTE(review): presumably the average log-probability score of the candidate — confirm.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }
    /// A single message exchanged with the model: one or more [`Part`]s plus the optional
    /// [`Role`] of whoever produced them.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        // Deserialized through `string_or_one_or_many`; presumably this accepts a bare string
        // as well as a single part or a list of parts — verify against that helper.
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        pub role: Option<Role>,
    }
277
278    impl TryFrom<message::Message> for Content {
279        type Error = message::MessageError;
280
281        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
282            Ok(match msg {
283                message::Message::User { content } => Content {
284                    parts: content.try_map(|c| c.try_into())?,
285                    role: Some(Role::User),
286                },
287                message::Message::Assistant { content } => Content {
288                    role: Some(Role::Model),
289                    parts: content.map(|content| content.into()),
290                },
291            })
292        }
293    }
294
295    impl TryFrom<Content> for message::Message {
296        type Error = message::MessageError;
297
298        fn try_from(content: Content) -> Result<Self, Self::Error> {
299            match content.role {
300                Some(Role::User) | None => Ok(message::Message::User {
301                    content: content.parts.try_map(|part| {
302                        Ok(match part {
303                            Part::Text(text) => message::UserContent::text(text),
304                            Part::InlineData(inline_data) => {
305                                let mime_type =
306                                    message::MediaType::from_mime_type(&inline_data.mime_type);
307
308                                match mime_type {
309                                    Some(message::MediaType::Image(media_type)) => {
310                                        message::UserContent::image(
311                                            inline_data.data,
312                                            Some(message::ContentFormat::default()),
313                                            Some(media_type),
314                                            Some(message::ImageDetail::default()),
315                                        )
316                                    }
317                                    Some(message::MediaType::Document(media_type)) => {
318                                        message::UserContent::document(
319                                            inline_data.data,
320                                            Some(message::ContentFormat::default()),
321                                            Some(media_type),
322                                        )
323                                    }
324                                    Some(message::MediaType::Audio(media_type)) => {
325                                        message::UserContent::audio(
326                                            inline_data.data,
327                                            Some(message::ContentFormat::default()),
328                                            Some(media_type),
329                                        )
330                                    }
331                                    _ => {
332                                        return Err(message::MessageError::ConversionError(
333                                            format!("Unsupported media type {:?}", mime_type),
334                                        ))
335                                    }
336                                }
337                            }
338                            _ => {
339                                return Err(message::MessageError::ConversionError(format!(
340                                    "Unsupported gemini content part type: {:?}",
341                                    part
342                                )))
343                            }
344                        })
345                    })?,
346                }),
347                Some(Role::Model) => Ok(message::Message::Assistant {
348                    content: content.parts.try_map(|part| {
349                        Ok(match part {
350                            Part::Text(text) => message::AssistantContent::text(text),
351                            Part::FunctionCall(function_call) => {
352                                message::AssistantContent::ToolCall(function_call.into())
353                            }
354                            _ => {
355                                return Err(message::MessageError::ConversionError(format!(
356                                    "Unsupported part type: {:?}",
357                                    part
358                                )))
359                            }
360                        })
361                    })?,
362                }),
363            }
364        }
365    }
366
    /// Producer of a [`Content`] message. Serialized in lowercase (`user` / `model`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content produced by the user.
        User,
        /// Content produced by the model.
        Model,
    }
373
    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    ///
    /// Variant names are (de)serialized in camelCase (for example `InlineData` as `inlineData`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        /// Inline text.
        Text(String),
        /// Inline media carried as a [`Blob`].
        InlineData(Blob),
        /// A function call, see [`FunctionCall`].
        FunctionCall(FunctionCall),
        /// The result of a function call, see [`FunctionResponse`].
        FunctionResponse(FunctionResponse),
        /// URI-based data, see [`FileData`].
        FileData(FileData),
        /// See [`ExecutableCode`].
        ExecutableCode(ExecutableCode),
        /// See [`CodeExecutionResult`].
        CodeExecutionResult(CodeExecutionResult),
    }
388
389    impl From<String> for Part {
390        fn from(text: String) -> Self {
391            Self::Text(text)
392        }
393    }
394
395    impl From<&str> for Part {
396        fn from(text: &str) -> Self {
397            Self::Text(text.to_string())
398        }
399    }
400
401    impl FromStr for Part {
402        type Err = Infallible;
403
404        fn from_str(s: &str) -> Result<Self, Self::Err> {
405            Ok(s.into())
406        }
407    }
408
409    impl TryFrom<message::UserContent> for Part {
410        type Error = message::MessageError;
411
412        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
413            match content {
414                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
415                message::UserContent::ToolResult(message::ToolResult { id, content }) => {
416                    let content = match content.first() {
417                        message::ToolResultContent::Text(text) => text.text,
418                        message::ToolResultContent::Image(_) => {
419                            return Err(message::MessageError::ConversionError(
420                                "Tool result content must be text".to_string(),
421                            ))
422                        }
423                    };
424                    Ok(Part::FunctionResponse(FunctionResponse {
425                        name: id,
426                        response: Some(serde_json::from_str(&content).map_err(|e| {
427                            message::MessageError::ConversionError(format!(
428                                "Failed to parse tool response: {}",
429                                e
430                            ))
431                        })?),
432                    }))
433                }
434                message::UserContent::Image(message::Image {
435                    data, media_type, ..
436                }) => match media_type {
437                    Some(media_type) => match media_type {
438                        message::ImageMediaType::JPEG
439                        | message::ImageMediaType::PNG
440                        | message::ImageMediaType::WEBP
441                        | message::ImageMediaType::HEIC
442                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
443                            mime_type: media_type.to_mime_type().to_owned(),
444                            data,
445                        })),
446                        _ => Err(message::MessageError::ConversionError(format!(
447                            "Unsupported image media type {:?}",
448                            media_type
449                        ))),
450                    },
451                    None => Err(message::MessageError::ConversionError(
452                        "Media type for image is required for Anthropic".to_string(),
453                    )),
454                },
455                message::UserContent::Document(message::Document {
456                    data, media_type, ..
457                }) => match media_type {
458                    Some(media_type) => match media_type {
459                        message::DocumentMediaType::PDF
460                        | message::DocumentMediaType::TXT
461                        | message::DocumentMediaType::RTF
462                        | message::DocumentMediaType::HTML
463                        | message::DocumentMediaType::CSS
464                        | message::DocumentMediaType::MARKDOWN
465                        | message::DocumentMediaType::CSV
466                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
467                            mime_type: media_type.to_mime_type().to_owned(),
468                            data,
469                        })),
470                        _ => Err(message::MessageError::ConversionError(format!(
471                            "Unsupported document media type {:?}",
472                            media_type
473                        ))),
474                    },
475                    None => Err(message::MessageError::ConversionError(
476                        "Media type for document is required for Anthropic".to_string(),
477                    )),
478                },
479                message::UserContent::Audio(message::Audio {
480                    data, media_type, ..
481                }) => match media_type {
482                    Some(media_type) => Ok(Self::InlineData(Blob {
483                        mime_type: media_type.to_mime_type().to_owned(),
484                        data,
485                    })),
486                    None => Err(message::MessageError::ConversionError(
487                        "Media type for audio is required for Anthropic".to_string(),
488                    )),
489                },
490            }
491        }
492    }
493
494    impl From<message::AssistantContent> for Part {
495        fn from(content: message::AssistantContent) -> Self {
496            match content {
497                message::AssistantContent::Text(message::Text { text }) => text.into(),
498                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
499            }
500        }
501    }
502
503    impl From<message::ToolCall> for Part {
504        fn from(tool_call: message::ToolCall) -> Self {
505            Self::FunctionCall(FunctionCall {
506                name: tool_call.function.name,
507                args: tool_call.function.arguments,
508            })
509        }
510    }
511
    /// Raw media bytes.
    /// Text should not be sent as raw bytes, use the 'text' field.
    ///
    /// Field names are (de)serialized in camelCase (`mimeType`, `data`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }
523
    /// A predicted FunctionCall returned from the model that contains a string representing the
    /// FunctionDeclaration.name with the arguments and their values.
    ///
    /// Converts to and from rig's `message::ToolCall` via the `From` impls that follow.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        // Not wrapped in `Option` and carries no serde default here.
        pub args: serde_json::Value,
    }
534
535    impl From<FunctionCall> for message::ToolCall {
536        fn from(function_call: FunctionCall) -> Self {
537            Self {
538                id: function_call.name.clone(),
539                function: message::ToolFunction {
540                    name: function_call.name,
541                    arguments: function_call.args,
542                },
543            }
544        }
545    }
546
547    impl From<message::ToolCall> for FunctionCall {
548        fn from(tool_call: message::ToolCall) -> Self {
549            Self {
550                name: tool_call.function.name,
551                args: tool_call.function.arguments,
552            }
553        }
554    }
555
    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        /// Represented as a map from top-level keys to arbitrary JSON values.
        pub response: Option<HashMap<String, serde_json::Value>>,
    }
567
    /// URI based data.
    ///
    /// Field names are (de)serialized in camelCase (`mimeType`, `fileUri`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }
577
    /// A safety rating: a harm category paired with the probability assigned to it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category being rated.
        pub category: HarmCategory,
        /// The probability level reported for `category`.
        pub probability: HarmProbability,
    }
583
    /// Probability level attached to a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `HarmProbabilityUnspecified`
    /// as `HARM_PROBABILITY_UNSPECIFIED` and `Negligible` as `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
593
    /// Harm category attached to a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `HarmCategoryHateSpeech` as
    /// `HARM_CATEGORY_HATE_SPEECH`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
610
    /// Token usage reported with a response (see `GenerateContentResponse::usage_metadata`).
    ///
    /// Field names are deserialized from camelCase. All counts except
    /// `cached_content_token_count` are required, so a payload that omits one of them fails
    /// to deserialize.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt.
        pub prompt_token_count: i32,
        /// Number of tokens in cached content, when reported.
        // `skip_serializing_if` has no effect here: the struct only derives `Deserialize`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Number of tokens across the generated candidates.
        pub candidates_token_count: i32,
        /// Total token count reported by the API.
        pub total_token_count: i32,
    }
620
621    impl std::fmt::Display for UsageMetadata {
622        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
623            write!(
624                f,
625                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
626                self.prompt_token_count,
627                match self.cached_content_token_count {
628                    Some(count) => count.to_string(),
629                    None => "n/a".to_string(),
630                },
631                self.candidates_token_count,
632                self.total_token_count
633            )
634        }
635    }
636
    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    ///
    /// Field names are deserialized from camelCase (`blockReason`, `safetyRatings`).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
646
    /// Reason why a prompt was blocked by the model
    ///
    /// Variants are deserialized from SCREAMING_SNAKE_CASE, e.g. `ProhibitedContent` from
    /// `PROHIBITED_CONTENT`.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
662
    /// Reason why the model stopped generating tokens for a candidate
    /// (see `ContentCandidate::finish_reason`).
    ///
    /// Variants are deserialized from SCREAMING_SNAKE_CASE, e.g. `MaxTokens` from
    /// `MAX_TOKENS`. There is no catch-all variant, so a value not listed here fails to
    /// deserialize.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
689
    /// Citation information attached to a candidate
    /// (see `ContentCandidate::citation_metadata`).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources; deserialized from `citationSources` and required to be present.
        pub citation_sources: Vec<CitationSource>,
    }
695
    /// A single cited source inside [`CitationMetadata`]. Every field is optional.
    // The `skip_serializing_if` attributes have no effect here: the struct only derives
    // `Deserialize`.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the source, when provided.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// NOTE(review): presumably the start offset of the cited segment in the response —
        /// confirm units against the API reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// NOTE(review): presumably the end offset of the cited segment — confirm whether it
        /// is inclusive or exclusive.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source, when provided.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
708
709    #[derive(Debug, Deserialize)]
710    #[serde(rename_all = "camelCase")]
711    pub struct LogprobsResult {
712        pub top_candidate: Vec<TopCandidate>,
713        pub chosen_candidate: Vec<LogProbCandidate>,
714    }
715
    /// One entry of `LogprobsResult::top_candidate`: a list of candidate tokens with their
    /// log probabilities.
    #[derive(Debug, Deserialize)]
    pub struct TopCandidate {
        /// The candidate tokens.
        /// NOTE(review): presumably sorted by descending log probability — confirm against
        /// the API reference.
        pub candidates: Vec<LogProbCandidate>,
    }
720
    /// A token together with its id and log probability.
    ///
    /// Field names are deserialized from camelCase (`token`, `tokenId`, `logProbability`).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        /// The token's string value.
        pub token: String,
        /// The token's id.
        // NOTE(review): the API reference appears to document `tokenId` as an integer; a JSON
        // number will not deserialize into `String` — confirm and adjust the type if needed.
        pub token_id: String,
        /// The token's log probability.
        pub log_probability: f64,
    }
728
729    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
730    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
731    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
733    #[derive(Debug, Deserialize, Serialize)]
734    #[serde(rename_all = "camelCase")]
735    pub struct GenerationConfig {
736        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
737        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
738        #[serde(skip_serializing_if = "Option::is_none")]
739        pub stop_sequences: Option<Vec<String>>,
740        /// MIME type of the generated candidate text. Supported MIME types are:
741        ///     - text/plain:  (default) Text output
742        ///     - application/json: JSON response in the response candidates.
743        ///     - text/x.enum: ENUM as a string response in the response candidates.
744        /// Refer to the docs for a list of all supported text MIME types
745        #[serde(skip_serializing_if = "Option::is_none")]
746        pub response_mime_type: Option<String>,
747        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
748        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
749        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
750        #[serde(skip_serializing_if = "Option::is_none")]
751        pub response_schema: Option<Schema>,
752        /// Number of generated responses to return. Currently, this value can only be set to 1. If
753        /// unset, this will default to 1.
754        #[serde(skip_serializing_if = "Option::is_none")]
755        pub candidate_count: Option<i32>,
756        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
757        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
758        #[serde(skip_serializing_if = "Option::is_none")]
759        pub max_output_tokens: Option<u64>,
760        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
761        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
762        #[serde(skip_serializing_if = "Option::is_none")]
763        pub temperature: Option<f64>,
764        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
765        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
766        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
767        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
768        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
769        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
770        #[serde(skip_serializing_if = "Option::is_none")]
771        pub top_p: Option<f64>,
772        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
773        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
774        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
775        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
776        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
777        #[serde(skip_serializing_if = "Option::is_none")]
778        pub top_k: Option<i32>,
779        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
780        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
781        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
782        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
783        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
784        #[serde(skip_serializing_if = "Option::is_none")]
785        pub presence_penalty: Option<f64>,
786        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
787        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
788        /// used, proportional to the number of times the token has been used: The more a token is used, the more
789        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
790        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
791        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
792        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
793        #[serde(skip_serializing_if = "Option::is_none")]
794        pub frequency_penalty: Option<f64>,
795        /// If true, export the logprobs results in response.
796        #[serde(skip_serializing_if = "Option::is_none")]
797        pub response_logprobs: Option<bool>,
798        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
799        /// [Candidate.logprobs_result].
800        #[serde(skip_serializing_if = "Option::is_none")]
801        pub logprobs: Option<i32>,
802    }
803
804    impl Default for GenerationConfig {
805        fn default() -> Self {
806            Self {
807                temperature: Some(1.0),
808                max_output_tokens: Some(4096),
809                stop_sequences: None,
810                response_mime_type: None,
811                response_schema: None,
812                candidate_count: None,
813                top_p: None,
814                top_k: None,
815                presence_penalty: None,
816                frequency_penalty: None,
817                response_logprobs: None,
818                logprobs: None,
819            }
820        }
821    }
822    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
823    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
824    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
825    #[derive(Debug, Deserialize, Serialize)]
826    pub struct Schema {
827        pub r#type: String,
828        #[serde(skip_serializing_if = "Option::is_none")]
829        pub format: Option<String>,
830        #[serde(skip_serializing_if = "Option::is_none")]
831        pub description: Option<String>,
832        #[serde(skip_serializing_if = "Option::is_none")]
833        pub nullable: Option<bool>,
834        #[serde(skip_serializing_if = "Option::is_none")]
835        pub r#enum: Option<Vec<String>>,
836        #[serde(skip_serializing_if = "Option::is_none")]
837        pub max_items: Option<i32>,
838        #[serde(skip_serializing_if = "Option::is_none")]
839        pub min_items: Option<i32>,
840        #[serde(skip_serializing_if = "Option::is_none")]
841        pub properties: Option<HashMap<String, Schema>>,
842        #[serde(skip_serializing_if = "Option::is_none")]
843        pub required: Option<Vec<String>>,
844        #[serde(skip_serializing_if = "Option::is_none")]
845        pub items: Option<Box<Schema>>,
846    }
847
848    impl TryFrom<Value> for Schema {
849        type Error = CompletionError;
850
851        fn try_from(value: Value) -> Result<Self, Self::Error> {
852            if let Some(obj) = value.as_object() {
853                Ok(Schema {
854                    r#type: obj
855                        .get("type")
856                        .and_then(|v| {
857                            if v.is_string() {
858                                v.as_str().map(String::from)
859                            } else if v.is_array() {
860                                v.as_array()
861                                    .and_then(|arr| arr.first())
862                                    .and_then(|v| v.as_str().map(String::from))
863                            } else {
864                                None
865                            }
866                        })
867                        .unwrap_or_default(),
868                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
869                    description: obj
870                        .get("description")
871                        .and_then(|v| v.as_str())
872                        .map(String::from),
873                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
874                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
875                        arr.iter()
876                            .filter_map(|v| v.as_str().map(String::from))
877                            .collect()
878                    }),
879                    max_items: obj
880                        .get("maxItems")
881                        .and_then(|v| v.as_i64())
882                        .map(|v| v as i32),
883                    min_items: obj
884                        .get("minItems")
885                        .and_then(|v| v.as_i64())
886                        .map(|v| v as i32),
887                    properties: obj
888                        .get("properties")
889                        .and_then(|v| v.as_object())
890                        .map(|map| {
891                            map.iter()
892                                .filter_map(|(k, v)| {
893                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
894                                })
895                                .collect()
896                        }),
897                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
898                        arr.iter()
899                            .filter_map(|v| v.as_str().map(String::from))
900                            .collect()
901                    }),
902                    items: obj
903                        .get("items")
904                        .map(|v| Box::new(v.clone().try_into().unwrap())),
905                })
906            } else {
907                Err(CompletionError::ResponseError(
908                    "Expected a JSON object for Schema".into(),
909                ))
910            }
911        }
912    }
913
914    #[derive(Debug, Serialize)]
915    #[serde(rename_all = "camelCase")]
916    pub struct GenerateContentRequest {
917        pub contents: Vec<Content>,
918        pub tools: Option<Vec<Tool>>,
919        pub tool_config: Option<ToolConfig>,
920        /// Optional. Configuration options for model generation and outputs.
921        pub generation_config: Option<GenerationConfig>,
922        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
923        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
924        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
925        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
926        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
927        /// will use the default safety setting for that category. Harm categories:
928        ///     - HARM_CATEGORY_HATE_SPEECH,
929        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
930        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
931        ///     - HARM_CATEGORY_HARASSMENT
932        /// are supported.
933        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
934        /// to learn how to incorporate safety considerations in your AI applications.
935        pub safety_settings: Option<Vec<SafetySetting>>,
936        /// Optional. Developer set system instruction(s). Currently, text only.
937        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
938        pub system_instruction: Option<Content>,
939        // cachedContent: Optional<String>
940    }
941
942    #[derive(Debug, Serialize)]
943    #[serde(rename_all = "camelCase")]
944    pub struct Tool {
945        pub function_declarations: FunctionDeclaration,
946        pub code_execution: Option<CodeExecution>,
947    }
948
949    #[derive(Debug, Serialize)]
950    #[serde(rename_all = "camelCase")]
951    pub struct FunctionDeclaration {
952        pub name: String,
953        pub description: String,
954        #[serde(skip_serializing_if = "Option::is_none")]
955        pub parameters: Option<Schema>,
956    }
957
958    #[derive(Debug, Serialize)]
959    #[serde(rename_all = "camelCase")]
960    pub struct ToolConfig {
961        pub schema: Option<Schema>,
962    }
963
964    #[derive(Debug, Serialize)]
965    #[serde(rename_all = "camelCase")]
966    pub struct CodeExecution {}
967
968    #[derive(Debug, Serialize)]
969    #[serde(rename_all = "camelCase")]
970    pub struct SafetySetting {
971        pub category: HarmCategory,
972        pub threshold: HarmBlockThreshold,
973    }
974
975    #[derive(Debug, Serialize)]
976    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
977    pub enum HarmBlockThreshold {
978        HarmBlockThresholdUnspecified,
979        BlockLowAndAbove,
980        BlockMediumAndAbove,
981        BlockOnlyHigh,
982        BlockNone,
983        Off,
984    }
985}
986
987#[cfg(test)]
988mod tests {
989    use crate::message;
990
991    use super::*;
992    use serde_json::json;
993
994    #[test]
995    fn test_deserialize_message_user() {
996        let raw_message = r#"{
997            "parts": [
998                {"text": "Hello, world!"},
999                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1000                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1001                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1002                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1003                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1004                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1005            ],
1006            "role": "user"
1007        }"#;
1008
1009        let content: Content = {
1010            let jd = &mut serde_json::Deserializer::from_str(raw_message);
1011            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1012                panic!("Deserialization error at {}: {}", err.path(), err);
1013            })
1014        };
1015        assert_eq!(content.role, Some(Role::User));
1016        assert_eq!(content.parts.len(), 7);
1017
1018        let parts: Vec<Part> = content.parts.into_iter().collect();
1019
1020        if let Part::Text(text) = &parts[0] {
1021            assert_eq!(text, "Hello, world!");
1022        } else {
1023            panic!("Expected text part");
1024        }
1025
1026        if let Part::InlineData(inline_data) = &parts[1] {
1027            assert_eq!(inline_data.mime_type, "image/png");
1028            assert_eq!(inline_data.data, "base64encodeddata");
1029        } else {
1030            panic!("Expected inline data part");
1031        }
1032
1033        if let Part::FunctionCall(function_call) = &parts[2] {
1034            assert_eq!(function_call.name, "test_function");
1035            assert_eq!(
1036                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1037                "value1"
1038            );
1039        } else {
1040            panic!("Expected function call part");
1041        }
1042
1043        if let Part::FunctionResponse(function_response) = &parts[3] {
1044            assert_eq!(function_response.name, "test_function");
1045            assert_eq!(
1046                function_response
1047                    .response
1048                    .as_ref()
1049                    .unwrap()
1050                    .get("result")
1051                    .unwrap(),
1052                "success"
1053            );
1054        } else {
1055            panic!("Expected function response part");
1056        }
1057
1058        if let Part::FileData(file_data) = &parts[4] {
1059            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1060            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1061        } else {
1062            panic!("Expected file data part");
1063        }
1064
1065        if let Part::ExecutableCode(executable_code) = &parts[5] {
1066            assert_eq!(executable_code.code, "print('Hello, world!')");
1067        } else {
1068            panic!("Expected executable code part");
1069        }
1070
1071        if let Part::CodeExecutionResult(code_execution_result) = &parts[6] {
1072            assert_eq!(
1073                code_execution_result.clone().output.unwrap(),
1074                "Hello, world!"
1075            );
1076        } else {
1077            panic!("Expected code execution result part");
1078        }
1079    }
1080
1081    #[test]
1082    fn test_deserialize_message_model() {
1083        let json_data = json!({
1084            "parts": [{"text": "Hello, user!"}],
1085            "role": "model"
1086        });
1087
1088        let content: Content = serde_json::from_value(json_data).unwrap();
1089        assert_eq!(content.role, Some(Role::Model));
1090        assert_eq!(content.parts.len(), 1);
1091        if let Part::Text(text) = &content.parts.first() {
1092            assert_eq!(text, "Hello, user!");
1093        } else {
1094            panic!("Expected text part");
1095        }
1096    }
1097
1098    #[test]
1099    fn test_message_conversion_user() {
1100        let msg = message::Message::user("Hello, world!");
1101        let content: Content = msg.try_into().unwrap();
1102        assert_eq!(content.role, Some(Role::User));
1103        assert_eq!(content.parts.len(), 1);
1104        if let Part::Text(text) = &content.parts.first() {
1105            assert_eq!(text, "Hello, world!");
1106        } else {
1107            panic!("Expected text part");
1108        }
1109    }
1110
1111    #[test]
1112    fn test_message_conversion_model() {
1113        let msg = message::Message::assistant("Hello, user!");
1114
1115        let content: Content = msg.try_into().unwrap();
1116        assert_eq!(content.role, Some(Role::Model));
1117        assert_eq!(content.parts.len(), 1);
1118        if let Part::Text(text) = &content.parts.first() {
1119            assert_eq!(text, "Hello, user!");
1120        } else {
1121            panic!("Expected text part");
1122        }
1123    }
1124
1125    #[test]
1126    fn test_message_conversion_tool_call() {
1127        let tool_call = message::ToolCall {
1128            id: "test_tool".to_string(),
1129            function: message::ToolFunction {
1130                name: "test_function".to_string(),
1131                arguments: json!({"arg1": "value1"}),
1132            },
1133        };
1134
1135        let msg = message::Message::Assistant {
1136            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1137        };
1138
1139        let content: Content = msg.try_into().unwrap();
1140        assert_eq!(content.role, Some(Role::Model));
1141        assert_eq!(content.parts.len(), 1);
1142        if let Part::FunctionCall(function_call) = &content.parts.first() {
1143            assert_eq!(function_call.name, "test_function");
1144            assert_eq!(
1145                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1146                "value1"
1147            );
1148        } else {
1149            panic!("Expected function call part");
1150        }
1151    }
1152}