rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// Model identifier string for the `gemini-2.5-pro-preview-06-05` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier string for the `gemini-2.5-pro-preview-05-06` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier string for the `gemini-2.5-pro-preview-03-25` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier string for the `gemini-2.5-flash-preview-05-20` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// Model identifier string for the `gemini-2.5-flash-preview-04-17` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier string for the experimental `gemini-2.5-pro-exp-03-25` completion model.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier string for the `gemini-2.0-flash-lite` completion model.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier string for the `gemini-2.0-flash` completion model.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model identifier string for the `gemini-1.5-flash` completion model.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model identifier string for the `gemini-1.5-pro` completion model.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model identifier string for the `gemini-1.5-pro-8b` completion model.
// NOTE(review): Google's published model list names `gemini-1.5-flash-8b`;
// confirm that `gemini-1.5-pro-8b` is an identifier the API actually accepts.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// Model identifier string for the `gemini-1.0-pro` completion model.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33    OneOrMany,
34    completion::{self, CompletionError, CompletionRequest},
35};
36use gemini_api_types::{
37    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38    GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45// =================================================================
46// Rig Implementation Types
47// =================================================================
48
/// Completion handle for one Gemini model.
///
/// Pairs the provider [`Client`] with the model identifier that is
/// interpolated into the `generateContent` request path.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to build and send requests.
    pub(crate) client: Client,
    /// Model identifier, e.g. [`GEMINI_2_0_FLASH`].
    pub model: String,
}
54
55impl CompletionModel {
56    pub fn new(client: Client, model: &str) -> Self {
57        Self {
58            client,
59            model: model.to_string(),
60        }
61    }
62}
63
64impl completion::CompletionModel for CompletionModel {
65    type Response = GenerateContentResponse;
66    type StreamingResponse = StreamingCompletionResponse;
67
68    #[cfg_attr(feature = "worker", worker::send)]
69    async fn completion(
70        &self,
71        completion_request: CompletionRequest,
72    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73        let request = create_request_body(completion_request)?;
74
75        tracing::debug!(
76            "Sending completion request to Gemini API {}",
77            serde_json::to_string_pretty(&request)?
78        );
79
80        let response = self
81            .client
82            .post(&format!("/v1beta/models/{}:generateContent", self.model))
83            .json(&request)
84            .send()
85            .await?;
86
87        if response.status().is_success() {
88            let response = response.json::<GenerateContentResponse>().await?;
89            match response.usage_metadata {
90                Some(ref usage) => tracing::info!(target: "rig",
91                "Gemini completion token usage: {}",
92                usage
93                ),
94                None => tracing::info!(target: "rig",
95                    "Gemini completion token usage: n/a",
96                ),
97            }
98
99            tracing::debug!("Received response");
100
101            Ok(completion::CompletionResponse::try_from(response))
102        } else {
103            Err(CompletionError::ProviderError(response.text().await?))
104        }?
105    }
106
107    #[cfg_attr(feature = "worker", worker::send)]
108    async fn stream(
109        &self,
110        request: CompletionRequest,
111    ) -> Result<
112        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113        CompletionError,
114    > {
115        CompletionModel::stream(self, request).await
116    }
117}
118
119pub(crate) fn create_request_body(
120    completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122    let mut full_history = Vec::new();
123    full_history.extend(completion_request.chat_history);
124
125    let additional_params = completion_request
126        .additional_params
127        .unwrap_or_else(|| Value::Object(Map::new()));
128
129    let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131    if let Some(temp) = completion_request.temperature {
132        generation_config.temperature = Some(temp);
133    }
134
135    if let Some(max_tokens) = completion_request.max_tokens {
136        generation_config.max_output_tokens = Some(max_tokens);
137    }
138
139    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140        parts: OneOrMany::one(preamble.into()),
141        role: Some(Role::Model),
142    });
143
144    let tools = if completion_request.tools.is_empty() {
145        None
146    } else {
147        Some(Tool::try_from(completion_request.tools)?)
148    };
149
150    let request = GenerateContentRequest {
151        contents: full_history
152            .into_iter()
153            .map(|msg| {
154                msg.try_into()
155                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
156            })
157            .collect::<Result<Vec<_>, _>>()?,
158        generation_config: Some(generation_config),
159        safety_settings: None,
160        tools,
161        tool_config: None,
162        system_instruction,
163    };
164
165    Ok(request)
166}
167
168impl TryFrom<completion::ToolDefinition> for Tool {
169    type Error = CompletionError;
170
171    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
172        let parameters: Option<Schema> =
173            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
174                None
175            } else {
176                Some(tool.parameters.try_into()?)
177            };
178
179        Ok(Self {
180            function_declarations: vec![FunctionDeclaration {
181                name: tool.name,
182                description: tool.description,
183                parameters,
184            }],
185            code_execution: None,
186        })
187    }
188}
189
190impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
191    type Error = CompletionError;
192
193    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
194        let mut function_declarations = Vec::new();
195
196        for tool in tools {
197            let parameters =
198                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
199                    None
200                } else {
201                    match tool.parameters.try_into() {
202                        Ok(schema) => Some(schema),
203                        Err(e) => {
204                            let emsg = format!(
205                                "Tool '{}' could not be converted to a schema: {:?}",
206                                tool.name, e,
207                            );
208                            return Err(CompletionError::ProviderError(emsg));
209                        }
210                    }
211                };
212
213            function_declarations.push(FunctionDeclaration {
214                name: tool.name,
215                description: tool.description,
216                parameters,
217            });
218        }
219
220        Ok(Self {
221            function_declarations,
222            code_execution: None,
223        })
224    }
225}
226
227impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
228    type Error = CompletionError;
229
230    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
231        let candidate = response.candidates.first().ok_or_else(|| {
232            CompletionError::ResponseError("No response candidates in response".into())
233        })?;
234
235        let content = candidate
236            .content
237            .parts
238            .iter()
239            .map(|part| {
240                Ok(match part {
241                    Part::Text(text) => completion::AssistantContent::text(text),
242                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
243                        &function_call.name,
244                        &function_call.name,
245                        function_call.args.clone(),
246                    ),
247                    _ => {
248                        return Err(CompletionError::ResponseError(
249                            "Response did not contain a message or tool call".into(),
250                        ));
251                    }
252                })
253            })
254            .collect::<Result<Vec<_>, _>>()?;
255
256        let choice = OneOrMany::many(content).map_err(|_| {
257            CompletionError::ResponseError(
258                "Response contained no message or tool call (empty)".to_owned(),
259            )
260        })?;
261
262        let usage = response
263            .usage_metadata
264            .as_ref()
265            .map(|usage| completion::Usage {
266                input_tokens: usage.prompt_token_count as u64,
267                output_tokens: usage.candidates_token_count as u64,
268                total_tokens: usage.total_token_count as u64,
269            })
270            .unwrap_or_default();
271
272        Ok(completion::CompletionResponse {
273            choice,
274            usage,
275            raw_response: response,
276        })
277    }
278}
279
280pub mod gemini_api_types {
281    use std::{collections::HashMap, convert::Infallible, str::FromStr};
282
283    // =================================================================
284    // Gemini API Types
285    // =================================================================
286    use serde::{Deserialize, Serialize};
287    use serde_json::{Value, json};
288
289    use crate::{
290        OneOrMany,
291        completion::CompletionError,
292        message::{self, MessageError, MimeType as _},
293        one_or_many::string_or_one_or_many,
294        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
295    };
296
297    /// Response from the model supporting multiple candidate responses.
298    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
299    /// and for each candidate in finishReason and in safetyRatings.
300    /// The API:
301    ///     - Returns either all requested candidates or none of them
302    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
303    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
304    #[derive(Debug, Deserialize, Serialize)]
305    #[serde(rename_all = "camelCase")]
306    pub struct GenerateContentResponse {
307        /// Candidate responses from the model.
308        pub candidates: Vec<ContentCandidate>,
309        /// Returns the prompt's feedback related to the content filters.
310        pub prompt_feedback: Option<PromptFeedback>,
311        /// Output only. Metadata on the generation requests' token usage.
312        pub usage_metadata: Option<UsageMetadata>,
313        pub model_version: Option<String>,
314    }
315
    /// A response candidate generated from the model.
    ///
    /// Field names are mapped to camelCase on the wire.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        // NOTE(review): this field is mandatory for deserialization; confirm the API
        // always includes `content`, e.g. for candidates stopped for safety reasons.
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for the model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only. Average log-probability value reported for the candidate.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens.
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }
    /// One message of a conversation: an ordered, non-empty list of parts plus the
    /// role that produced them.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        ///
        /// Deserialized through `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        pub role: Option<Role>,
    }
350
351    impl TryFrom<message::Message> for Content {
352        type Error = message::MessageError;
353
354        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
355            Ok(match msg {
356                message::Message::User { content } => Content {
357                    parts: content.try_map(|c| c.try_into())?,
358                    role: Some(Role::User),
359                },
360                message::Message::Assistant { content, .. } => Content {
361                    role: Some(Role::Model),
362                    parts: content.map(|content| content.into()),
363                },
364            })
365        }
366    }
367
368    impl TryFrom<Content> for message::Message {
369        type Error = message::MessageError;
370
371        fn try_from(content: Content) -> Result<Self, Self::Error> {
372            match content.role {
373                Some(Role::User) | None => Ok(message::Message::User {
374                    content: content.parts.try_map(|part| {
375                        Ok(match part {
376                            Part::Text(text) => message::UserContent::text(text),
377                            Part::InlineData(inline_data) => {
378                                let mime_type =
379                                    message::MediaType::from_mime_type(&inline_data.mime_type);
380
381                                match mime_type {
382                                    Some(message::MediaType::Image(media_type)) => {
383                                        message::UserContent::image(
384                                            inline_data.data,
385                                            Some(message::ContentFormat::default()),
386                                            Some(media_type),
387                                            Some(message::ImageDetail::default()),
388                                        )
389                                    }
390                                    Some(message::MediaType::Document(media_type)) => {
391                                        message::UserContent::document(
392                                            inline_data.data,
393                                            Some(message::ContentFormat::default()),
394                                            Some(media_type),
395                                        )
396                                    }
397                                    Some(message::MediaType::Audio(media_type)) => {
398                                        message::UserContent::audio(
399                                            inline_data.data,
400                                            Some(message::ContentFormat::default()),
401                                            Some(media_type),
402                                        )
403                                    }
404                                    _ => {
405                                        return Err(message::MessageError::ConversionError(
406                                            format!("Unsupported media type {mime_type:?}"),
407                                        ));
408                                    }
409                                }
410                            }
411                            _ => {
412                                return Err(message::MessageError::ConversionError(format!(
413                                    "Unsupported gemini content part type: {part:?}"
414                                )));
415                            }
416                        })
417                    })?,
418                }),
419                Some(Role::Model) => Ok(message::Message::Assistant {
420                    id: None,
421                    content: content.parts.try_map(|part| {
422                        Ok(match part {
423                            Part::Text(text) => message::AssistantContent::text(text),
424                            Part::FunctionCall(function_call) => {
425                                message::AssistantContent::ToolCall(function_call.into())
426                            }
427                            _ => {
428                                return Err(message::MessageError::ConversionError(format!(
429                                    "Unsupported part type: {part:?}"
430                                )));
431                            }
432                        })
433                    })?,
434                }),
435            }
436        }
437    }
438
    /// Producer of a [`Content`] message.
    ///
    /// Variants are (de)serialized in lowercase: `"user"` and `"model"`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content authored by the user.
        User,
        /// Content authored by the model.
        Model,
    }
445
    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    ///
    /// Variant names are mapped to camelCase on the wire (e.g. `inlineData`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        /// Plain text.
        Text(String),
        /// Media embedded directly in the message.
        InlineData(Blob),
        /// A function call with its name and arguments.
        FunctionCall(FunctionCall),
        /// The result of a function call.
        FunctionResponse(FunctionResponse),
        /// Media referenced by URI.
        FileData(FileData),
        /// Executable code payload.
        ExecutableCode(ExecutableCode),
        /// Result of a code execution.
        CodeExecutionResult(CodeExecutionResult),
        /// Reasoning text; produced here from assistant reasoning content.
        // NOTE(review): confirm this shape matches how the API represents thought parts.
        Thought { thoughts: Vec<String> },
    }
461
462    impl From<String> for Part {
463        fn from(text: String) -> Self {
464            Self::Text(text)
465        }
466    }
467
468    impl From<&str> for Part {
469        fn from(text: &str) -> Self {
470            Self::Text(text.to_string())
471        }
472    }
473
474    impl FromStr for Part {
475        type Err = Infallible;
476
477        fn from_str(s: &str) -> Result<Self, Self::Err> {
478            Ok(s.into())
479        }
480    }
481
482    impl TryFrom<message::UserContent> for Part {
483        type Error = message::MessageError;
484
485        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
486            match content {
487                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
488                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
489                    let content = match content.first() {
490                        message::ToolResultContent::Text(text) => text.text,
491                        message::ToolResultContent::Image(_) => {
492                            return Err(message::MessageError::ConversionError(
493                                "Tool result content must be text".to_string(),
494                            ));
495                        }
496                    };
497                    // Convert to JSON since this value may be a valid JSON value
498                    let result: serde_json::Value = serde_json::from_str(&content)
499                        .map_err(|x| MessageError::ConversionError(x.to_string()))?;
500                    Ok(Part::FunctionResponse(FunctionResponse {
501                        name: id,
502                        response: Some(json!({ "result": result })),
503                    }))
504                }
505                message::UserContent::Image(message::Image {
506                    data, media_type, ..
507                }) => match media_type {
508                    Some(media_type) => match media_type {
509                        message::ImageMediaType::JPEG
510                        | message::ImageMediaType::PNG
511                        | message::ImageMediaType::WEBP
512                        | message::ImageMediaType::HEIC
513                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
514                            mime_type: media_type.to_mime_type().to_owned(),
515                            data,
516                        })),
517                        _ => Err(message::MessageError::ConversionError(format!(
518                            "Unsupported image media type {media_type:?}"
519                        ))),
520                    },
521                    None => Err(message::MessageError::ConversionError(
522                        "Media type for image is required for Anthropic".to_string(),
523                    )),
524                },
525                message::UserContent::Document(message::Document {
526                    data, media_type, ..
527                }) => match media_type {
528                    Some(media_type) => match media_type {
529                        message::DocumentMediaType::PDF
530                        | message::DocumentMediaType::TXT
531                        | message::DocumentMediaType::RTF
532                        | message::DocumentMediaType::HTML
533                        | message::DocumentMediaType::CSS
534                        | message::DocumentMediaType::MARKDOWN
535                        | message::DocumentMediaType::CSV
536                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
537                            mime_type: media_type.to_mime_type().to_owned(),
538                            data,
539                        })),
540                        _ => Err(message::MessageError::ConversionError(format!(
541                            "Unsupported document media type {media_type:?}"
542                        ))),
543                    },
544                    None => Err(message::MessageError::ConversionError(
545                        "Media type for document is required for Anthropic".to_string(),
546                    )),
547                },
548                message::UserContent::Audio(message::Audio {
549                    data, media_type, ..
550                }) => match media_type {
551                    Some(media_type) => Ok(Self::InlineData(Blob {
552                        mime_type: media_type.to_mime_type().to_owned(),
553                        data,
554                    })),
555                    None => Err(message::MessageError::ConversionError(
556                        "Media type for audio is required for Anthropic".to_string(),
557                    )),
558                },
559            }
560        }
561    }
562
563    impl From<message::AssistantContent> for Part {
564        fn from(content: message::AssistantContent) -> Self {
565            match content {
566                message::AssistantContent::Text(message::Text { text }) => text.into(),
567                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
568                message::AssistantContent::Reasoning(message::Reasoning { reasoning }) => {
569                    Part::Thought {
570                        thoughts: vec![reasoning],
571                    }
572                }
573            }
574        }
575    }
576
577    impl From<message::ToolCall> for Part {
578        fn from(tool_call: message::ToolCall) -> Self {
579            Self::FunctionCall(FunctionCall {
580                name: tool_call.function.name,
581                args: tool_call.function.arguments,
582            })
583        }
584    }
585
    /// Raw media bytes carried inline in a message.
    /// Text should not be sent as raw bytes, use the 'text' field.
    ///
    /// Field names are mapped to camelCase on the wire (`mimeType`, `data`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }
597
    /// A predicted FunctionCall returned from the model that contains a string representing the
    /// FunctionDeclaration.name with the arguments and their values.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        // NOTE(review): documented as optional, but this field is required when
        // deserializing; confirm the API always sends `args`, or consider a
        // `#[serde(default)]` here.
        pub args: serde_json::Value,
    }
608
609    impl From<FunctionCall> for message::ToolCall {
610        fn from(function_call: FunctionCall) -> Self {
611            Self {
612                id: function_call.name.clone(),
613                call_id: None,
614                function: message::ToolFunction {
615                    name: function_call.name,
616                    arguments: function_call.args,
617                },
618            }
619        }
620    }
621
622    impl From<message::ToolCall> for FunctionCall {
623        fn from(tool_call: message::ToolCall) -> Self {
624            Self {
625                name: tool_call.function.name,
626                args: tool_call.function.arguments,
627            }
628        }
629    }
630
    /// The result output from a FunctionCall. It contains a string representing the
    /// FunctionDeclaration.name and a structured JSON object containing any output from the
    /// function, and is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        pub response: Option<serde_json::Value>,
    }
642
    /// URI based data: media referenced by location rather than embedded inline.
    ///
    /// Field names are mapped to camelCase on the wire (`mimeType`, `fileUri`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }
652
    /// Safety rating for a piece of content: a harm category paired with the
    /// probability assigned to it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The category this rating applies to.
        pub category: HarmCategory,
        /// The probability of harm for this category.
        pub probability: HarmProbability,
    }
658
    /// Probability level attached to a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_PROBABILITY_UNSPECIFIED` or `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        /// No probability was specified.
        HarmProbabilityUnspecified,
        /// Negligible probability.
        Negligible,
        /// Low probability.
        Low,
        /// Medium probability.
        Medium,
        /// High probability.
        High,
    }
668
    /// Category a [`SafetyRating`] refers to.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_CATEGORY_HATE_SPEECH`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        /// No category was specified.
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
685
686    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
687    #[serde(rename_all = "camelCase")]
688    pub struct UsageMetadata {
689        pub prompt_token_count: i32,
690        #[serde(skip_serializing_if = "Option::is_none")]
691        pub cached_content_token_count: Option<i32>,
692        pub candidates_token_count: i32,
693        pub total_token_count: i32,
694    }
695
696    impl std::fmt::Display for UsageMetadata {
697        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
698            write!(
699                f,
700                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
701                self.prompt_token_count,
702                match self.cached_content_token_count {
703                    Some(count) => count.to_string(),
704                    None => "n/a".to_string(),
705                },
706                self.candidates_token_count,
707                self.total_token_count
708            )
709        }
710    }
711
    /// Feedback metadata for the prompt specified in
    /// [GenerateContentRequest.contents](GenerateContentRequest).
    ///
    /// Field names are mapped to camelCase on the wire.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
721
    /// Reason why a prompt was blocked by the model.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `PROHIBITED_CONTENT`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked because it contains terms from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
737
    /// Reason why the model stopped generating tokens for a candidate.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `MAX_TOKENS`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
764
    /// Container for the citation sources attached to a piece of content.
    ///
    /// The single field is (de)serialized as `citationSources` and is required when
    /// deserializing.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The individual citation entries.
        pub citation_sources: Vec<CitationSource>,
    }
770
    /// A single citation entry inside [`CitationMetadata`].
    ///
    /// Every field is optional; a field that is `None` is omitted from the serialized
    /// output. Field names are (de)serialized in camelCase (`startIndex`, `endIndex`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the cited segment — presumably an offset into the response; TODO confirm units against the API reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the cited segment — presumably an offset into the response; TODO confirm units and inclusivity.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License string associated with the source, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
783
784    #[derive(Debug, Deserialize, Serialize)]
785    #[serde(rename_all = "camelCase")]
786    pub struct LogprobsResult {
787        pub top_candidate: Vec<TopCandidate>,
788        pub chosen_candidate: Vec<LogProbCandidate>,
789    }
790
    /// A group of log-probability candidates; the element type of
    /// `LogprobsResult::top_candidate`.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// The candidates in this group.
        pub candidates: Vec<LogProbCandidate>,
    }
795
    /// A token together with its identifier and log probability.
    ///
    /// Field names are (de)serialized in camelCase (`tokenId`, `logProbability`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        /// The token's string value.
        pub token: String,
        /// The token's identifier, held as a string.
        // NOTE(review): the Gemini API reference appears to list `tokenId` as an integer;
        // confirm that a numeric value deserializes into this `String` field.
        pub token_id: String,
        /// The token's log probability.
        pub log_probability: f64,
    }
803
    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Field names are (de)serialized in camelCase, and any field that is `None` is omitted from the
    /// serialized output. The `Default` implementation sets `temperature` to `1.0` and
    /// `max_output_tokens` to `4096`, leaving every other option unset.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        /// - `text/plain`: (default) Text output
        /// - `application/json`: JSON response in the response candidates.
        /// - `text/x.enum`: ENUM as a string response in the response candidates.
        ///
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function. An empty
        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_p attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: The more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to start repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// `Candidate.logprobs_result`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
    }
878
879    impl Default for GenerationConfig {
880        fn default() -> Self {
881            Self {
882                temperature: Some(1.0),
883                max_output_tokens: Some(4096),
884                stop_sequences: None,
885                response_mime_type: None,
886                response_schema: None,
887                candidate_count: None,
888                top_p: None,
889                top_k: None,
890                presence_penalty: None,
891                frequency_penalty: None,
892                response_logprobs: None,
893                logprobs: None,
894            }
895        }
896    }
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// Every field except `type` is optional and is omitted from the serialized output when `None`.
    // NOTE(review): this struct has no `rename_all`, so `max_items`/`min_items` are (de)serialized
    // with snake_case keys, whereas the `TryFrom<Value>` impl reads the camelCase keys
    // `maxItems`/`minItems` — confirm the API accepts the snake_case spelling.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// The data type name, stored as a plain string.
        pub r#type: String,
        /// Optional format qualifier for the type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Optional human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Optional flag marking whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Optional list of allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Optional upper bound on the number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Optional lower bound on the number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Optional map from property name to that property's schema.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Optional list of required property names.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Optional schema of the array elements (boxed because the type is recursive).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
922
923    impl TryFrom<Value> for Schema {
924        type Error = CompletionError;
925
926        fn try_from(value: Value) -> Result<Self, Self::Error> {
927            if let Some(obj) = value.as_object() {
928                Ok(Schema {
929                    r#type: obj
930                        .get("type")
931                        .and_then(|v| {
932                            if v.is_string() {
933                                v.as_str().map(String::from)
934                            } else if v.is_array() {
935                                v.as_array()
936                                    .and_then(|arr| arr.first())
937                                    .and_then(|v| v.as_str().map(String::from))
938                            } else {
939                                None
940                            }
941                        })
942                        .unwrap_or_default(),
943                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
944                    description: obj
945                        .get("description")
946                        .and_then(|v| v.as_str())
947                        .map(String::from),
948                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
949                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
950                        arr.iter()
951                            .filter_map(|v| v.as_str().map(String::from))
952                            .collect()
953                    }),
954                    max_items: obj
955                        .get("maxItems")
956                        .and_then(|v| v.as_i64())
957                        .map(|v| v as i32),
958                    min_items: obj
959                        .get("minItems")
960                        .and_then(|v| v.as_i64())
961                        .map(|v| v as i32),
962                    properties: obj
963                        .get("properties")
964                        .and_then(|v| v.as_object())
965                        .map(|map| {
966                            map.iter()
967                                .filter_map(|(k, v)| {
968                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
969                                })
970                                .collect()
971                        }),
972                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
973                        arr.iter()
974                            .filter_map(|v| v.as_str().map(String::from))
975                            .collect()
976                    }),
977                    items: obj
978                        .get("items")
979                        .map(|v| Box::new(v.clone().try_into().unwrap())),
980                })
981            } else {
982                Err(CompletionError::ResponseError(
983                    "Expected a JSON object for Schema".into(),
984                ))
985            }
986        }
987    }
988
    /// Request body for content generation.
    ///
    /// Field names are serialized in camelCase. Only `tools` is omitted from the output when
    /// `None`; the other optional fields carry no `skip_serializing_if` attribute and are
    /// therefore emitted as `null` when unset.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// The content entries sent to the model.
        pub contents: Vec<Content>,
        /// Optional tool definition made available to the model.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        /// Optional tool configuration.
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// `GenerateContentRequest.contents` and `GenerateContentResponse.candidates`. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. The supported harm categories are:
        /// - `HARM_CATEGORY_HATE_SPEECH`
        /// - `HARM_CATEGORY_SEXUALLY_EXPLICIT`
        /// - `HARM_CATEGORY_DANGEROUS_CONTENT`
        /// - `HARM_CATEGORY_HARASSMENT`
        ///
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // cachedContent: Optional<String>
    }
1017
    /// Tool definition sent with a [`GenerateContentRequest`].
    ///
    /// Field names are serialized in camelCase (`functionDeclarations`, `codeExecution`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// The functions the model is allowed to call.
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Optional code-execution marker; serialized as `null` when `None` because no
        /// `skip_serializing_if` attribute is set.
        pub code_execution: Option<CodeExecution>,
    }
1024
    /// Declaration of a single callable function, listed in `Tool::function_declarations`.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// The function's name.
        pub name: String,
        /// Description of the function.
        pub description: String,
        /// Optional schema describing the function's parameters; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1033
    /// Tool configuration carried by `GenerateContentRequest::tool_config`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Optional schema; serialized as `null` when `None` because no `skip_serializing_if`
        /// attribute is set.
        pub schema: Option<Schema>,
    }
1039
    /// Field-less marker used by `Tool::code_execution`; serializes as an empty JSON object.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
1043
    /// Pairs a harm category with the blocking threshold to apply to it; the element type of
    /// `GenerateContentRequest::safety_settings`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
1050
    /// Blocking threshold used by `SafetySetting::threshold`.
    ///
    /// Variants are serialized in SCREAMING_SNAKE_CASE, e.g. `BlockLowAndAbove` becomes
    /// `BLOCK_LOW_AND_ABOVE` and `Off` becomes `OFF`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// Serialized as `HARM_BLOCK_THRESHOLD_UNSPECIFIED`.
        HarmBlockThresholdUnspecified,
        /// Serialized as `BLOCK_LOW_AND_ABOVE`.
        BlockLowAndAbove,
        /// Serialized as `BLOCK_MEDIUM_AND_ABOVE`.
        BlockMediumAndAbove,
        /// Serialized as `BLOCK_ONLY_HIGH`.
        BlockOnlyHigh,
        /// Serialized as `BLOCK_NONE`.
        BlockNone,
        /// Serialized as `OFF`.
        Off,
    }
1061}
1062
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user `Content` holding one part of every kind deserializes into the matching
    /// `Part` variants, in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // `serde_path_to_error` reports the JSON path of a failure, which is echoed in the panic.
        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = match serde_path_to_error::deserialize(&mut deserializer) {
            Ok(content) => content,
            Err(err) => panic!("Deserialization error at {}: {}", err.path(), err),
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        match &parts[0] {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }

        match &parts[1] {
            Part::InlineData(inline_data) => {
                assert_eq!(inline_data.mime_type, "image/png");
                assert_eq!(inline_data.data, "base64encodeddata");
            }
            _ => panic!("Expected inline data part"),
        }

        match &parts[2] {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                let arg1 = function_call.args.as_object().unwrap().get("arg1").unwrap();
                assert_eq!(arg1, "value1");
            }
            _ => panic!("Expected function call part"),
        }

        match &parts[3] {
            Part::FunctionResponse(function_response) => {
                assert_eq!(function_response.name, "test_function");
                let result = function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap();
                assert_eq!(result, "success");
            }
            _ => panic!("Expected function response part"),
        }

        match &parts[4] {
            Part::FileData(file_data) => {
                assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
                assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
            }
            _ => panic!("Expected file data part"),
        }

        match &parts[5] {
            Part::ExecutableCode(executable_code) => {
                assert_eq!(executable_code.code, "print('Hello, world!')");
            }
            _ => panic!("Expected executable code part"),
        }

        match &parts[6] {
            Part::CodeExecutionResult(code_execution_result) => {
                let output = code_execution_result.clone().output.unwrap();
                assert_eq!(output, "Hello, world!");
            }
            _ => panic!("Expected code execution result part"),
        }
    }

    /// A `Content` with role `model` and a single text part deserializes from a JSON value.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig user message converts into a `Content` with role `User` and one text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);

        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig assistant message converts into a `Content` with role `Model` and one text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// An assistant message carrying a tool call converts into a single `FunctionCall` part
    /// that keeps the function name and arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let function = message::ToolFunction {
            name: "test_function".to_string(),
            arguments: json!({"arg1": "value1"}),
        };
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function,
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        match &content.parts.first() {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                let arg1 = function_call.args.as_object().unwrap().get("arg1").unwrap();
                assert_eq!(arg1, "value1");
            }
            _ => panic!("Expected function call part"),
        }
    }
}