rig/providers/gemini/
completion.rs

// ================================================================
//! Google Gemini Completion Integration
//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B variant of Gemini 1.5 Flash)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` model identifier.
///
/// Kept for backward compatibility only: Google publishes the 8B Gemini 1.5 variant as
/// `gemini-1.5-flash-8b`, so prefer [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::{
33    AdditionalParameters, FunctionCallingMode, ToolConfig,
34};
35use crate::providers::gemini::streaming::StreamingCompletionResponse;
36use crate::telemetry::SpanCombinator;
37use crate::{
38    OneOrMany,
39    completion::{self, CompletionError, CompletionRequest},
40};
41use gemini_api_types::{
42    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
43    Role, Tool,
44};
45use serde_json::{Map, Value};
46use std::convert::TryFrom;
47use tracing::info_span;
48
49use super::Client;
50
51// =================================================================
52// Rig Implementation Types
53// =================================================================
54
55#[derive(Clone)]
56pub struct CompletionModel<T = reqwest::Client> {
57    pub(crate) client: Client<T>,
58    pub model: String,
59}
60
61impl<T> CompletionModel<T> {
62    pub fn new(client: Client<T>, model: &str) -> Self {
63        Self {
64            client,
65            model: model.to_string(),
66        }
67    }
68}
69
70impl completion::CompletionModel for CompletionModel<reqwest::Client> {
71    type Response = GenerateContentResponse;
72    type StreamingResponse = StreamingCompletionResponse;
73
74    #[cfg_attr(feature = "worker", worker::send)]
75    async fn completion(
76        &self,
77        completion_request: CompletionRequest,
78    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
79        let span = if tracing::Span::current().is_disabled() {
80            info_span!(
81                target: "rig::completions",
82                "generate_content",
83                gen_ai.operation.name = "generate_content",
84                gen_ai.provider.name = "gcp.gemini",
85                gen_ai.request.model = self.model,
86                gen_ai.system_instructions = &completion_request.preamble,
87                gen_ai.response.id = tracing::field::Empty,
88                gen_ai.response.model = tracing::field::Empty,
89                gen_ai.usage.output_tokens = tracing::field::Empty,
90                gen_ai.usage.input_tokens = tracing::field::Empty,
91                gen_ai.input.messages = tracing::field::Empty,
92                gen_ai.output.messages = tracing::field::Empty,
93            )
94        } else {
95            tracing::Span::current()
96        };
97
98        let request = create_request_body(completion_request)?;
99        span.record_model_input(&request.contents);
100
101        span.record_model_input(&request.contents);
102
103        tracing::debug!(
104            "Sending completion request to Gemini API {}",
105            serde_json::to_string_pretty(&request)?
106        );
107
108        let body = serde_json::to_vec(&request)?;
109
110        let request = self
111            .client
112            .post(&format!("/v1beta/models/{}:generateContent", self.model))
113            .header("Content-Type", "application/json")
114            .body(body)
115            .map_err(|e| CompletionError::HttpError(e.into()))?;
116
117        let response = self.client.send::<_, Vec<u8>>(request).await?;
118
119        if response.status().is_success() {
120            let response_body = response
121                .into_body()
122                .await
123                .map_err(CompletionError::HttpError)?;
124
125            let response: GenerateContentResponse = serde_json::from_slice(&response_body)?;
126
127            match response.usage_metadata {
128                Some(ref usage) => tracing::info!(target: "rig",
129                "Gemini completion token usage: {}",
130                usage
131                ),
132                None => tracing::info!(target: "rig",
133                    "Gemini completion token usage: n/a",
134                ),
135            }
136
137            let span = tracing::Span::current();
138            span.record_model_output(&response.candidates);
139            span.record_response_metadata(&response);
140            span.record_token_usage(&response.usage_metadata);
141
142            tracing::debug!(
143                "Received response from Gemini API: {}",
144                serde_json::to_string_pretty(&response)?
145            );
146
147            response.try_into()
148        } else {
149            let text = String::from_utf8_lossy(
150                &response
151                    .into_body()
152                    .await
153                    .map_err(CompletionError::HttpError)?,
154            )
155            .into();
156
157            Err(CompletionError::ProviderError(text))
158        }
159    }
160
161    #[cfg_attr(feature = "worker", worker::send)]
162    async fn stream(
163        &self,
164        request: CompletionRequest,
165    ) -> Result<
166        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
167        CompletionError,
168    > {
169        CompletionModel::stream(self, request).await
170    }
171}
172
173pub(crate) fn create_request_body(
174    completion_request: CompletionRequest,
175) -> Result<GenerateContentRequest, CompletionError> {
176    let mut full_history = Vec::new();
177    full_history.extend(completion_request.chat_history);
178
179    let additional_params = completion_request
180        .additional_params
181        .unwrap_or_else(|| Value::Object(Map::new()));
182
183    let AdditionalParameters {
184        mut generation_config,
185        additional_params,
186    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
187
188    if let Some(temp) = completion_request.temperature {
189        generation_config.temperature = Some(temp);
190    }
191
192    if let Some(max_tokens) = completion_request.max_tokens {
193        generation_config.max_output_tokens = Some(max_tokens);
194    }
195
196    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
197        parts: vec![preamble.into()],
198        role: Some(Role::Model),
199    });
200
201    let tools = if completion_request.tools.is_empty() {
202        None
203    } else {
204        Some(Tool::try_from(completion_request.tools)?)
205    };
206
207    let tool_config = if let Some(cfg) = completion_request.tool_choice {
208        Some(ToolConfig {
209            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
210        })
211    } else {
212        None
213    };
214
215    let request = GenerateContentRequest {
216        contents: full_history
217            .into_iter()
218            .map(|msg| {
219                msg.try_into()
220                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
221            })
222            .collect::<Result<Vec<_>, _>>()?,
223        generation_config: Some(generation_config),
224        safety_settings: None,
225        tools,
226        tool_config,
227        system_instruction,
228        additional_params,
229    };
230
231    Ok(request)
232}
233
234impl TryFrom<completion::ToolDefinition> for Tool {
235    type Error = CompletionError;
236
237    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
238        let parameters: Option<Schema> =
239            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
240                None
241            } else {
242                Some(tool.parameters.try_into()?)
243            };
244
245        Ok(Self {
246            function_declarations: vec![FunctionDeclaration {
247                name: tool.name,
248                description: tool.description,
249                parameters,
250            }],
251            code_execution: None,
252        })
253    }
254}
255
256impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
257    type Error = CompletionError;
258
259    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
260        let mut function_declarations = Vec::new();
261
262        for tool in tools {
263            let parameters =
264                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
265                    None
266                } else {
267                    match tool.parameters.try_into() {
268                        Ok(schema) => Some(schema),
269                        Err(e) => {
270                            let emsg = format!(
271                                "Tool '{}' could not be converted to a schema: {:?}",
272                                tool.name, e,
273                            );
274                            return Err(CompletionError::ProviderError(emsg));
275                        }
276                    }
277                };
278
279            function_declarations.push(FunctionDeclaration {
280                name: tool.name,
281                description: tool.description,
282                parameters,
283            });
284        }
285
286        Ok(Self {
287            function_declarations,
288            code_execution: None,
289        })
290    }
291}
292
293impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
294    type Error = CompletionError;
295
296    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
297        let candidate = response.candidates.first().ok_or_else(|| {
298            CompletionError::ResponseError("No response candidates in response".into())
299        })?;
300
301        let content = candidate
302            .content
303            .parts
304            .iter()
305            .map(|Part { thought, part, .. }| {
306                Ok(match part {
307                    PartKind::Text(text) => {
308                        if let Some(thought) = thought
309                            && *thought
310                        {
311                            completion::AssistantContent::Reasoning(Reasoning::new(text))
312                        } else {
313                            completion::AssistantContent::text(text)
314                        }
315                    }
316                    PartKind::FunctionCall(function_call) => {
317                        completion::AssistantContent::tool_call(
318                            &function_call.name,
319                            &function_call.name,
320                            function_call.args.clone(),
321                        )
322                    }
323                    _ => {
324                        return Err(CompletionError::ResponseError(
325                            "Response did not contain a message or tool call".into(),
326                        ));
327                    }
328                })
329            })
330            .collect::<Result<Vec<_>, _>>()?;
331
332        let choice = OneOrMany::many(content).map_err(|_| {
333            CompletionError::ResponseError(
334                "Response contained no message or tool call (empty)".to_owned(),
335            )
336        })?;
337
338        let usage = response
339            .usage_metadata
340            .as_ref()
341            .map(|usage| completion::Usage {
342                input_tokens: usage.prompt_token_count as u64,
343                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
344                total_tokens: usage.total_token_count as u64,
345            })
346            .unwrap_or_default();
347
348        Ok(completion::CompletionResponse {
349            choice,
350            usage,
351            raw_response: response,
352        })
353    }
354}
355
356pub mod gemini_api_types {
357    use crate::telemetry::ProviderResponseExt;
358    use std::{collections::HashMap, convert::Infallible, str::FromStr};
359
360    // =================================================================
361    // Gemini API Types
362    // =================================================================
363    use serde::{Deserialize, Serialize};
364    use serde_json::{Value, json};
365
366    use crate::completion::GetTokenUsage;
367    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
368    use crate::{
369        OneOrMany,
370        completion::CompletionError,
371        message::{self, Reasoning, Text},
372        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
373    };
374
375    #[derive(Debug, Deserialize, Serialize, Default)]
376    #[serde(rename_all = "camelCase")]
377    pub struct AdditionalParameters {
378        /// Change your Gemini request configuration.
379        pub generation_config: GenerationConfig,
380        /// Any additional parameters that you want.
381        #[serde(flatten, skip_serializing_if = "Option::is_none")]
382        pub additional_params: Option<serde_json::Value>,
383    }
384
385    impl AdditionalParameters {
386        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
387            self.generation_config = cfg;
388            self
389        }
390
391        pub fn with_params(mut self, params: serde_json::Value) -> Self {
392            self.additional_params = Some(params);
393            self
394        }
395    }
396
397    /// Response from the model supporting multiple candidate responses.
398    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
399    /// and for each candidate in finishReason and in safetyRatings.
400    /// The API:
401    ///     - Returns either all requested candidates or none of them
402    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
403    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
404    #[derive(Debug, Deserialize, Serialize)]
405    #[serde(rename_all = "camelCase")]
406    pub struct GenerateContentResponse {
407        pub response_id: String,
408        /// Candidate responses from the model.
409        pub candidates: Vec<ContentCandidate>,
410        /// Returns the prompt's feedback related to the content filters.
411        pub prompt_feedback: Option<PromptFeedback>,
412        /// Output only. Metadata on the generation requests' token usage.
413        pub usage_metadata: Option<UsageMetadata>,
414        pub model_version: Option<String>,
415    }
416
417    impl ProviderResponseExt for GenerateContentResponse {
418        type OutputMessage = ContentCandidate;
419        type Usage = UsageMetadata;
420
421        fn get_response_id(&self) -> Option<String> {
422            Some(self.response_id.clone())
423        }
424
425        fn get_response_model_name(&self) -> Option<String> {
426            None
427        }
428
429        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
430            self.candidates.clone()
431        }
432
433        fn get_text_response(&self) -> Option<String> {
434            let str = self
435                .candidates
436                .iter()
437                .filter_map(|x| {
438                    if x.content.role.as_ref().is_none_or(|y| y != &Role::Model) {
439                        return None;
440                    }
441
442                    let res = x
443                        .content
444                        .parts
445                        .iter()
446                        .filter_map(|part| {
447                            if let PartKind::Text(ref str) = part.part {
448                                Some(str.to_owned())
449                            } else {
450                                None
451                            }
452                        })
453                        .collect::<Vec<String>>()
454                        .join("\n");
455
456                    Some(res)
457                })
458                .collect::<Vec<String>>()
459                .join("\n");
460
461            if str.is_empty() { None } else { Some(str) }
462        }
463
464        fn get_usage(&self) -> Option<Self::Usage> {
465            self.usage_metadata.clone()
466        }
467    }
468
469    /// A response candidate generated from the model.
470    #[derive(Clone, Debug, Deserialize, Serialize)]
471    #[serde(rename_all = "camelCase")]
472    pub struct ContentCandidate {
473        /// Output only. Generated content returned from the model.
474        pub content: Content,
475        /// Optional. Output only. The reason why the model stopped generating tokens.
476        /// If empty, the model has not stopped generating tokens.
477        pub finish_reason: Option<FinishReason>,
478        /// List of ratings for the safety of a response candidate.
479        /// There is at most one rating per category.
480        pub safety_ratings: Option<Vec<SafetyRating>>,
481        /// Output only. Citation information for model-generated candidate.
482        /// This field may be populated with recitation information for any text included in the content.
483        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
484        pub citation_metadata: Option<CitationMetadata>,
485        /// Output only. Token count for this candidate.
486        pub token_count: Option<i32>,
487        /// Output only.
488        pub avg_logprobs: Option<f64>,
489        /// Output only. Log-likelihood scores for the response tokens and top tokens
490        pub logprobs_result: Option<LogprobsResult>,
491        /// Output only. Index of the candidate in the list of response candidates.
492        pub index: Option<i32>,
493    }
494
495    #[derive(Clone, Debug, Deserialize, Serialize)]
496    pub struct Content {
497        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
498        #[serde(default)]
499        pub parts: Vec<Part>,
500        /// The producer of the content. Must be either 'user' or 'model'.
501        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
502        pub role: Option<Role>,
503    }
504
505    impl TryFrom<message::Message> for Content {
506        type Error = message::MessageError;
507
508        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
509            Ok(match msg {
510                message::Message::User { content } => Content {
511                    parts: content
512                        .into_iter()
513                        .map(|c| c.try_into())
514                        .collect::<Result<Vec<_>, _>>()?,
515                    role: Some(Role::User),
516                },
517                message::Message::Assistant { content, .. } => Content {
518                    role: Some(Role::Model),
519                    parts: content.into_iter().map(|content| content.into()).collect(),
520                },
521            })
522        }
523    }
524
525    impl TryFrom<Content> for message::Message {
526        type Error = message::MessageError;
527
528        fn try_from(content: Content) -> Result<Self, Self::Error> {
529            match content.role {
530                Some(Role::User) | None => {
531                    Ok(message::Message::User {
532                        content: {
533                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
534                            .map(|Part { part, .. }| {
535                                Ok(match part {
536                                    PartKind::Text(text) => message::UserContent::text(text),
537                                    PartKind::InlineData(inline_data) => {
538                                        let mime_type =
539                                            message::MediaType::from_mime_type(&inline_data.mime_type);
540
541                                        match mime_type {
542                                            Some(message::MediaType::Image(media_type)) => {
543                                                message::UserContent::image_base64(
544                                                    inline_data.data,
545                                                    Some(media_type),
546                                                    Some(message::ImageDetail::default()),
547                                                )
548                                            }
549                                            Some(message::MediaType::Document(media_type)) => {
550                                                message::UserContent::document(
551                                                    inline_data.data,
552                                                    Some(media_type),
553                                                )
554                                            }
555                                            Some(message::MediaType::Audio(media_type)) => {
556                                                message::UserContent::audio(
557                                                    inline_data.data,
558                                                    Some(media_type),
559                                                )
560                                            }
561                                            _ => {
562                                                return Err(message::MessageError::ConversionError(
563                                                    format!("Unsupported media type {mime_type:?}"),
564                                                ));
565                                            }
566                                        }
567                                    }
568                                    _ => {
569                                        return Err(message::MessageError::ConversionError(format!(
570                                            "Unsupported gemini content part type: {part:?}"
571                                        )));
572                                    }
573                                })
574                            })
575                            .collect();
576                            OneOrMany::many(user_content?).map_err(|_| {
577                                message::MessageError::ConversionError(
578                                    "Failed to create OneOrMany from user content".to_string(),
579                                )
580                            })?
581                        },
582                    })
583                }
584                Some(Role::Model) => Ok(message::Message::Assistant {
585                    id: None,
586                    content: {
587                        let assistant_content: Result<Vec<_>, _> = content
588                            .parts
589                            .into_iter()
590                            .map(|Part { thought, part, .. }| {
591                                Ok(match part {
592                                    PartKind::Text(text) => match thought {
593                                        Some(true) => message::AssistantContent::Reasoning(
594                                            Reasoning::new(&text),
595                                        ),
596                                        _ => message::AssistantContent::Text(Text { text }),
597                                    },
598
599                                    PartKind::FunctionCall(function_call) => {
600                                        message::AssistantContent::ToolCall(function_call.into())
601                                    }
602                                    _ => {
603                                        return Err(message::MessageError::ConversionError(
604                                            format!("Unsupported part type: {part:?}"),
605                                        ));
606                                    }
607                                })
608                            })
609                            .collect();
610                        OneOrMany::many(assistant_content?).map_err(|_| {
611                            message::MessageError::ConversionError(
612                                "Failed to create OneOrMany from assistant content".to_string(),
613                            )
614                        })?
615                    },
616                }),
617            }
618        }
619    }
620
621    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
622    #[serde(rename_all = "lowercase")]
623    pub enum Role {
624        User,
625        Model,
626    }
627
628    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
629    #[serde(rename_all = "camelCase")]
630    pub struct Part {
631        /// whether or not the part is a reasoning/thinking text or not
632        #[serde(skip_serializing_if = "Option::is_none")]
633        pub thought: Option<bool>,
634        /// an opaque sig for the thought so it can be reused - is a base64 string
635        #[serde(skip_serializing_if = "Option::is_none")]
636        pub thought_signature: Option<String>,
637        #[serde(flatten)]
638        pub part: PartKind,
639        #[serde(flatten, skip_serializing_if = "Option::is_none")]
640        pub additional_params: Option<Value>,
641    }
642
643    /// A datatype containing media that is part of a multi-part [Content] message.
644    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
645    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
646    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
647    #[serde(rename_all = "camelCase")]
648    pub enum PartKind {
649        Text(String),
650        InlineData(Blob),
651        FunctionCall(FunctionCall),
652        FunctionResponse(FunctionResponse),
653        FileData(FileData),
654        ExecutableCode(ExecutableCode),
655        CodeExecutionResult(CodeExecutionResult),
656    }
657
658    // This default instance is primarily so we can easily fill in the optional fields of `Part`
659    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
660    impl Default for PartKind {
661        fn default() -> Self {
662            Self::Text(String::new())
663        }
664    }
665
666    impl From<String> for Part {
667        fn from(text: String) -> Self {
668            Self {
669                thought: Some(false),
670                thought_signature: None,
671                part: PartKind::Text(text),
672                additional_params: None,
673            }
674        }
675    }
676
677    impl From<&str> for Part {
678        fn from(text: &str) -> Self {
679            Self::from(text.to_string())
680        }
681    }
682
683    impl FromStr for Part {
684        type Err = Infallible;
685
686        fn from_str(s: &str) -> Result<Self, Self::Err> {
687            Ok(s.into())
688        }
689    }
690
691    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
692        type Error = message::MessageError;
693        fn try_from(
694            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
695        ) -> Result<Self, Self::Error> {
696            let mime_type = mime_type.to_mime_type().to_string();
697            let part = match doc_src {
698                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
699                    mime_type: Some(mime_type),
700                    file_uri: url,
701                }),
702                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
703                    PartKind::InlineData(Blob { mime_type, data })
704                }
705                DocumentSourceKind::Raw(_) => {
706                    return Err(message::MessageError::ConversionError(
707                        "Raw files not supported, encode as base64 first".into(),
708                    ));
709                }
710                DocumentSourceKind::Unknown => {
711                    return Err(message::MessageError::ConversionError(
712                        "Can't convert an unknown document source".to_string(),
713                    ));
714                }
715            };
716
717            Ok(part)
718        }
719    }
720
721    impl TryFrom<message::UserContent> for Part {
722        type Error = message::MessageError;
723
724        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
725            match content {
726                message::UserContent::Text(message::Text { text }) => Ok(Part {
727                    thought: Some(false),
728                    thought_signature: None,
729                    part: PartKind::Text(text),
730                    additional_params: None,
731                }),
732                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
733                    let content = match content.first() {
734                        message::ToolResultContent::Text(text) => text.text,
735                        message::ToolResultContent::Image(_) => {
736                            return Err(message::MessageError::ConversionError(
737                                "Tool result content must be text".to_string(),
738                            ));
739                        }
740                    };
741                    // Convert to JSON since this value may be a valid JSON value
742                    let result: serde_json::Value =
743                        serde_json::from_str(&content).unwrap_or_else(|error| {
744                            tracing::trace!(
745                                ?error,
746                                "Tool result is not a valid JSON, treat it as normal string"
747                            );
748                            json!(content)
749                        });
750                    Ok(Part {
751                        thought: Some(false),
752                        thought_signature: None,
753                        part: PartKind::FunctionResponse(FunctionResponse {
754                            name: id,
755                            response: Some(json!({ "result": result })),
756                        }),
757                        additional_params: None,
758                    })
759                }
760                message::UserContent::Image(message::Image {
761                    data, media_type, ..
762                }) => match media_type {
763                    Some(media_type) => match media_type {
764                        message::ImageMediaType::JPEG
765                        | message::ImageMediaType::PNG
766                        | message::ImageMediaType::WEBP
767                        | message::ImageMediaType::HEIC
768                        | message::ImageMediaType::HEIF => {
769                            let part = PartKind::try_from((media_type, data))?;
770                            Ok(Part {
771                                thought: Some(false),
772                                thought_signature: None,
773                                part,
774                                additional_params: None,
775                            })
776                        }
777                        _ => Err(message::MessageError::ConversionError(format!(
778                            "Unsupported image media type {media_type:?}"
779                        ))),
780                    },
781                    None => Err(message::MessageError::ConversionError(
782                        "Media type for image is required for Gemini".to_string(),
783                    )),
784                },
785                message::UserContent::Document(message::Document {
786                    data, media_type, ..
787                }) => {
788                    let Some(media_type) = media_type else {
789                        return Err(MessageError::ConversionError(
790                            "A mime type is required for document inputs to Gemini".to_string(),
791                        ));
792                    };
793
794                    if !media_type.is_code() {
795                        let mime_type = media_type.to_mime_type().to_string();
796
797                        let part = match data {
798                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
799                                mime_type: Some(mime_type),
800                                file_uri,
801                            }),
802                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
803                                PartKind::InlineData(Blob { mime_type, data })
804                            }
805                            DocumentSourceKind::Raw(_) => {
806                                return Err(message::MessageError::ConversionError(
807                                    "Raw files not supported, encode as base64 first".into(),
808                                ));
809                            }
810                            _ => {
811                                return Err(message::MessageError::ConversionError(
812                                    "Document has no body".to_string(),
813                                ));
814                            }
815                        };
816
817                        Ok(Part {
818                            thought: Some(false),
819                            part,
820                            ..Default::default()
821                        })
822                    } else {
823                        Err(message::MessageError::ConversionError(format!(
824                            "Unsupported document media type {media_type:?}"
825                        )))
826                    }
827                }
828
829                message::UserContent::Audio(message::Audio {
830                    data, media_type, ..
831                }) => {
832                    let Some(media_type) = media_type else {
833                        return Err(MessageError::ConversionError(
834                            "A mime type is required for audio inputs to Gemini".to_string(),
835                        ));
836                    };
837
838                    let mime_type = media_type.to_mime_type().to_string();
839
840                    let part = match data {
841                        DocumentSourceKind::Base64(data) => {
842                            PartKind::InlineData(Blob { data, mime_type })
843                        }
844
845                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
846                            mime_type: Some(mime_type),
847                            file_uri,
848                        }),
849                        DocumentSourceKind::String(_) => {
850                            return Err(message::MessageError::ConversionError(
851                                "Strings cannot be used as audio files!".into(),
852                            ));
853                        }
854                        DocumentSourceKind::Raw(_) => {
855                            return Err(message::MessageError::ConversionError(
856                                "Raw files not supported, encode as base64 first".into(),
857                            ));
858                        }
859                        DocumentSourceKind::Unknown => {
860                            return Err(message::MessageError::ConversionError(
861                                "Content has no body".to_string(),
862                            ));
863                        }
864                    };
865
866                    Ok(Part {
867                        thought: Some(false),
868                        part,
869                        ..Default::default()
870                    })
871                }
872                message::UserContent::Video(message::Video {
873                    data,
874                    media_type,
875                    additional_params,
876                    ..
877                }) => {
878                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
879
880                    let part = match data {
881                        DocumentSourceKind::Url(file_uri) => {
882                            if file_uri.starts_with("https://www.youtube.com") {
883                                PartKind::FileData(FileData {
884                                    mime_type,
885                                    file_uri,
886                                })
887                            } else {
888                                if mime_type.is_none() {
889                                    return Err(MessageError::ConversionError(
890                                        "A mime type is required for non-Youtube video file inputs to Gemini"
891                                            .to_string(),
892                                    ));
893                                }
894
895                                PartKind::FileData(FileData {
896                                    mime_type,
897                                    file_uri,
898                                })
899                            }
900                        }
901                        DocumentSourceKind::Base64(data) => {
902                            let Some(mime_type) = mime_type else {
903                                return Err(MessageError::ConversionError(
904                                    "A media type is expected for base64 encoded strings"
905                                        .to_string(),
906                                ));
907                            };
908                            PartKind::InlineData(Blob { mime_type, data })
909                        }
910                        DocumentSourceKind::String(_) => {
911                            return Err(message::MessageError::ConversionError(
912                                "Strings cannot be used as audio files!".into(),
913                            ));
914                        }
915                        DocumentSourceKind::Raw(_) => {
916                            return Err(message::MessageError::ConversionError(
917                                "Raw file data not supported, encode as base64 first".into(),
918                            ));
919                        }
920                        DocumentSourceKind::Unknown => {
921                            return Err(message::MessageError::ConversionError(
922                                "Media type for video is required for Gemini".to_string(),
923                            ));
924                        }
925                    };
926
927                    Ok(Part {
928                        thought: Some(false),
929                        thought_signature: None,
930                        part,
931                        additional_params,
932                    })
933                }
934            }
935        }
936    }
937
938    impl From<message::AssistantContent> for Part {
939        fn from(content: message::AssistantContent) -> Self {
940            match content {
941                message::AssistantContent::Text(message::Text { text }) => text.into(),
942                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
943                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
944                    Part {
945                        thought: Some(true),
946                        thought_signature: None,
947                        part: PartKind::Text(
948                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
949                        ),
950                        additional_params: None,
951                    }
952                }
953            }
954        }
955    }
956
957    impl From<message::ToolCall> for Part {
958        fn from(tool_call: message::ToolCall) -> Self {
959            Self {
960                thought: Some(false),
961                thought_signature: None,
962                part: PartKind::FunctionCall(FunctionCall {
963                    name: tool_call.function.name,
964                    args: tool_call.function.arguments,
965                }),
966                additional_params: None,
967            }
968        }
969    }
970
    /// Raw media bytes sent inline with a request.
    ///
    /// Text should not be sent as raw bytes; use the `text` field of a part instead.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data, e.g. `image/png` or `image/jpeg`.
        /// If an unsupported MIME type is provided, the API returns an error.
        pub mime_type: String,
        /// Raw bytes for media formats, as a base64-encoded string.
        pub data: String,
    }

    /// A predicted function call returned by the model.
    ///
    /// It holds the `FunctionDeclaration.name` of the function to call together with the
    /// arguments and their values.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        // NOTE(review): the API documents this as optional, but the field has no
        // `#[serde(default)]`, so a response that omits `args` would fail to deserialize.
        // Confirm whether the API always sends it.
        pub args: serde_json::Value,
    }
993
994    impl From<FunctionCall> for message::ToolCall {
995        fn from(function_call: FunctionCall) -> Self {
996            Self {
997                id: function_call.name.clone(),
998                call_id: None,
999                function: message::ToolFunction {
1000                    name: function_call.name,
1001                    arguments: function_call.args,
1002                },
1003            }
1004        }
1005    }
1006
1007    impl From<message::ToolCall> for FunctionCall {
1008        fn from(tool_call: message::ToolCall) -> Self {
1009            Self {
1010                name: tool_call.function.name,
1011                args: tool_call.function.arguments,
1012            }
1013        }
1014    }
1015
    /// The result of a `FunctionCall`, sent back to the model as context.
    ///
    /// It contains a string matching the `FunctionDeclaration.name` and a structured JSON object
    /// with any output from the function. It should hold the result of a `FunctionCall` made
    /// based on a model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function that was called. Must be a-z, A-Z, 0-9, or contain
        /// underscores and dashes, with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        // No `skip_serializing_if`: `None` is serialized as an explicit `null`.
        pub response: Option<serde_json::Value>,
    }

    /// URI-based data, referenced by a file URI instead of being sent inline.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. The URI of the file.
        pub file_uri: String,
    }
1037
    /// A safety rating: a harm category paired with the probability reported for it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category this rating applies to.
        pub category: HarmCategory,
        /// The probability of harm reported for `category`.
        pub probability: HarmProbability,
    }

    /// Probability level of a harm category.
    ///
    /// Variants serialize as `SCREAMING_SNAKE_CASE`, e.g. `HARM_PROBABILITY_UNSPECIFIED` and
    /// `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1053
    /// Harm category used in safety ratings.
    ///
    /// Variants serialize as `SCREAMING_SNAKE_CASE`, e.g. `HARM_CATEGORY_HATE_SPEECH`.
    // NOTE(review): there is no fallback variant, so any category string not listed here makes
    // deserialization of the enclosing response fail. Consider a `#[serde(other)]` fallback if
    // the API adds categories.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1070
    /// Token usage metadata reported by the API for a generate-content call.
    ///
    /// Optional counts are omitted when serializing if they are `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt.
        pub prompt_token_count: i32,
        /// Number of tokens in the cached part of the prompt, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Number of tokens across the generated response candidates.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count reported by the API.
        pub total_token_count: i32,
        /// Number of tokens spent on thoughts (reasoning), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1083
1084    impl std::fmt::Display for UsageMetadata {
1085        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1086            write!(
1087                f,
1088                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1089                self.prompt_token_count,
1090                match self.cached_content_token_count {
1091                    Some(count) => count.to_string(),
1092                    None => "n/a".to_string(),
1093                },
1094                match self.candidates_token_count {
1095                    Some(count) => count.to_string(),
1096                    None => "n/a".to_string(),
1097                },
1098                self.total_token_count
1099            )
1100        }
1101    }
1102
1103    impl GetTokenUsage for UsageMetadata {
1104        fn token_usage(&self) -> Option<crate::completion::Usage> {
1105            let mut usage = crate::completion::Usage::new();
1106
1107            usage.input_tokens = self.prompt_token_count as u64;
1108            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1109                + self.candidates_token_count.unwrap_or_default()
1110                + self.thoughts_token_count.unwrap_or_default())
1111                as u64;
1112            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1113
1114            Some(usage)
1115        }
1116    }
1117
    /// Feedback metadata for the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for the safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1127
1128    /// Reason why a prompt was blocked by the model
1129    #[derive(Debug, Deserialize, Serialize)]
1130    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1131    pub enum BlockReason {
1132        /// Default value. This value is unused.
1133        BlockReasonUnspecified,
1134        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
1135        Safety,
1136        /// Prompt was blocked due to unknown reasons.
1137        Other,
1138        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
1139        Blocklist,
1140        /// Prompt was blocked due to prohibited content.
1141        ProhibitedContent,
1142    }
1143
1144    #[derive(Clone, Debug, Deserialize, Serialize)]
1145    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1146    pub enum FinishReason {
1147        /// Default value. This value is unused.
1148        FinishReasonUnspecified,
1149        /// Natural stop point of the model or provided stop sequence.
1150        Stop,
1151        /// The maximum number of tokens as specified in the request was reached.
1152        MaxTokens,
1153        /// The response candidate content was flagged for safety reasons.
1154        Safety,
1155        /// The response candidate content was flagged for recitation reasons.
1156        Recitation,
1157        /// The response candidate content was flagged for using an unsupported language.
1158        Language,
1159        /// Unknown reason.
1160        Other,
1161        /// Token generation stopped because the content contains forbidden terms.
1162        Blocklist,
1163        /// Token generation stopped for potentially containing prohibited content.
1164        ProhibitedContent,
1165        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
1166        Spii,
1167        /// The function call generated by the model is invalid.
1168        MalformedFunctionCall,
1169    }
1170
    /// Source attributions for a piece of generated content.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Citations to sources for a specific response.
        pub citation_sources: Vec<CitationSource>,
    }

    /// A citation to a source for part of a response. All fields are optional and omitted
    /// when serializing if `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1189
1190    #[derive(Clone, Debug, Deserialize, Serialize)]
1191    #[serde(rename_all = "camelCase")]
1192    pub struct LogprobsResult {
1193        pub top_candidate: Vec<TopCandidate>,
1194        pub chosen_candidate: Vec<LogProbCandidate>,
1195    }
1196
    /// The candidates considered at a single decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens for this step; presumably sorted by log probability (confirm
        /// against the API reference).
        pub candidates: Vec<LogProbCandidate>,
    }
1201
1202    #[derive(Clone, Debug, Deserialize, Serialize)]
1203    #[serde(rename_all = "camelCase")]
1204    pub struct LogProbCandidate {
1205        pub token: String,
1206        pub token_id: String,
1207        pub log_probability: f64,
1208    }
1209
    /// Gemini API configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Every field is optional and is omitted from the serialized request when `None`.
    /// [`Default`] is implemented by hand for this type, so its default value is not all-`None`.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain:  (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the `Model.output_token_limit` attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the `Model.temperature`
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the `Model.top_p` attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the `Model.top_k` attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, in proportion to the number of times the token has been used: the more a token is used, the more
        /// difficult it is for the model to use that token again, which increases the vocabulary of responses. Caution:
        /// a negative penalty will encourage the model to reuse tokens in proportion to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, return the logprobs results in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if `responseLogprobs` is true. Sets the number of top logprobs to return at each decoding step
        /// in `Candidate.logprobs_result`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
    }
1287
1288    impl Default for GenerationConfig {
1289        fn default() -> Self {
1290            Self {
1291                temperature: Some(1.0),
1292                max_output_tokens: Some(4096),
1293                stop_sequences: None,
1294                response_mime_type: None,
1295                response_schema: None,
1296                candidate_count: None,
1297                top_p: None,
1298                top_k: None,
1299                presence_penalty: None,
1300                frequency_penalty: None,
1301                response_logprobs: None,
1302                logprobs: None,
1303                thinking_config: None,
1304            }
1305        }
1306    }
1307
1308    #[derive(Debug, Deserialize, Serialize)]
1309    #[serde(rename_all = "camelCase")]
1310    pub struct ThinkingConfig {
1311        pub thinking_budget: u32,
1312        pub include_thoughts: Option<bool>,
1313    }
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// Optional fields are omitted when serializing if they are `None`.
    // NOTE(review): there is no `rename_all`, so `max_items`/`min_items` serialize in snake_case,
    // while the API reference spells them `maxItems`/`minItems`. Confirm that the endpoint
    // accepts the snake_case spelling.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// The data type, serialized as `type`.
        pub r#type: String,
        /// Optional format hint for the data type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description of the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values, serialized as `enum`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Object properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the required object properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of the array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1339
1340    /// Flattens a JSON schema by resolving all `$ref` references inline.
1341    /// It takes a JSON schema that may contain `$ref` references to definitions
1342    /// in `$defs` or `definitions` sections and returns a new schema with all references
1343    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1344    /// schema references.
1345    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1346        // extracting $defs if they exist
1347        let defs = if let Some(obj) = schema.as_object() {
1348            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1349        } else {
1350            None
1351        };
1352
1353        let Some(defs_value) = defs else {
1354            return Ok(schema);
1355        };
1356
1357        let Some(defs_obj) = defs_value.as_object() else {
1358            return Err(CompletionError::ResponseError(
1359                "$defs must be an object".into(),
1360            ));
1361        };
1362
1363        resolve_refs(&mut schema, defs_obj)?;
1364
1365        // removing $defs from the final schema because we have inlined everything
1366        if let Some(obj) = schema.as_object_mut() {
1367            obj.remove("$defs");
1368            obj.remove("definitions");
1369        }
1370
1371        Ok(schema)
1372    }
1373
1374    /// Recursively resolves all `$ref` references in a JSON value by
1375    /// replacing them with their definitions.
1376    fn resolve_refs(
1377        value: &mut Value,
1378        defs: &serde_json::Map<String, Value>,
1379    ) -> Result<(), CompletionError> {
1380        match value {
1381            Value::Object(obj) => {
1382                if let Some(ref_value) = obj.get("$ref")
1383                    && let Some(ref_str) = ref_value.as_str()
1384                {
1385                    // "#/$defs/Person" -> "Person"
1386                    let def_name = parse_ref_path(ref_str)?;
1387
1388                    let def = defs.get(&def_name).ok_or_else(|| {
1389                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1390                    })?;
1391
1392                    let mut resolved = def.clone();
1393                    resolve_refs(&mut resolved, defs)?;
1394                    *value = resolved;
1395                    return Ok(());
1396                }
1397
1398                for (_, v) in obj.iter_mut() {
1399                    resolve_refs(v, defs)?;
1400                }
1401            }
1402            Value::Array(arr) => {
1403                for item in arr.iter_mut() {
1404                    resolve_refs(item, defs)?;
1405                }
1406            }
1407            _ => {}
1408        }
1409
1410        Ok(())
1411    }
1412
1413    /// Parses a JSON Schema `$ref` path to extract the definition name.
1414    ///
1415    /// JSON Schema references use URI fragment syntax to point to definitions within
1416    /// the same document. This function extracts the definition name from common
1417    /// reference patterns used in JSON Schema.
1418    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1419        if let Some(fragment) = ref_str.strip_prefix('#') {
1420            if let Some(name) = fragment.strip_prefix("/$defs/") {
1421                Ok(name.to_string())
1422            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1423                Ok(name.to_string())
1424            } else {
1425                Err(CompletionError::ResponseError(format!(
1426                    "Unsupported reference format: {}",
1427                    ref_str
1428                )))
1429            }
1430        } else {
1431            Err(CompletionError::ResponseError(format!(
1432                "Only fragment references (#/...) are supported: {}",
1433                ref_str
1434            )))
1435        }
1436    }
1437
1438    /// Helper function to extract the type string from a JSON value.
1439    /// Handles both direct string types and array types (returns the first element).
1440    fn extract_type(type_value: &Value) -> Option<String> {
1441        if type_value.is_string() {
1442            type_value.as_str().map(String::from)
1443        } else if type_value.is_array() {
1444            type_value
1445                .as_array()
1446                .and_then(|arr| arr.first())
1447                .and_then(|v| v.as_str().map(String::from))
1448        } else {
1449            None
1450        }
1451    }
1452
1453    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1454    /// Returns the type of the first non-null schema found.
1455    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1456        composition.as_array().and_then(|arr| {
1457            arr.iter().find_map(|schema| {
1458                if let Some(obj) = schema.as_object() {
1459                    // Skip null types
1460                    if let Some(type_val) = obj.get("type")
1461                        && let Some(type_str) = type_val.as_str()
1462                        && type_str == "null"
1463                    {
1464                        return None;
1465                    }
1466                    // Extract type from this schema
1467                    obj.get("type").and_then(extract_type).or_else(|| {
1468                        if obj.contains_key("properties") {
1469                            Some("object".to_string())
1470                        } else {
1471                            None
1472                        }
1473                    })
1474                } else {
1475                    None
1476                }
1477            })
1478        })
1479    }
1480
1481    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1482    /// Returns the schema object that should be used for properties, required, etc.
1483    fn extract_schema_from_composition(
1484        composition: &Value,
1485    ) -> Option<serde_json::Map<String, Value>> {
1486        composition.as_array().and_then(|arr| {
1487            arr.iter().find_map(|schema| {
1488                if let Some(obj) = schema.as_object()
1489                    && let Some(type_val) = obj.get("type")
1490                    && let Some(type_str) = type_val.as_str()
1491                {
1492                    if type_str == "null" {
1493                        return None;
1494                    }
1495                    Some(obj.clone())
1496                } else {
1497                    None
1498                }
1499            })
1500        })
1501    }
1502
1503    /// Helper function to infer the type of a schema object.
1504    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1505    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1506        // First, try direct type field
1507        if let Some(type_val) = obj.get("type")
1508            && let Some(type_str) = extract_type(type_val)
1509        {
1510            return type_str;
1511        }
1512
1513        // Then try anyOf, oneOf, allOf (in that order)
1514        if let Some(any_of) = obj.get("anyOf")
1515            && let Some(type_str) = extract_type_from_composition(any_of)
1516        {
1517            return type_str;
1518        }
1519
1520        if let Some(one_of) = obj.get("oneOf")
1521            && let Some(type_str) = extract_type_from_composition(one_of)
1522        {
1523            return type_str;
1524        }
1525
1526        if let Some(all_of) = obj.get("allOf")
1527            && let Some(type_str) = extract_type_from_composition(all_of)
1528        {
1529            return type_str;
1530        }
1531
1532        // Finally, infer object type if properties are present
1533        if obj.contains_key("properties") {
1534            "object".to_string()
1535        } else {
1536            String::new()
1537        }
1538    }
1539
1540    impl TryFrom<Value> for Schema {
1541        type Error = CompletionError;
1542
1543        fn try_from(value: Value) -> Result<Self, Self::Error> {
1544            let flattened_val = flatten_schema(value)?;
1545            if let Some(obj) = flattened_val.as_object() {
1546                // Determine which object to use for extracting properties and required fields.
1547                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1548                let props_source = if obj.get("properties").is_none() {
1549                    if let Some(any_of) = obj.get("anyOf") {
1550                        extract_schema_from_composition(any_of)
1551                    } else if let Some(one_of) = obj.get("oneOf") {
1552                        extract_schema_from_composition(one_of)
1553                    } else if let Some(all_of) = obj.get("allOf") {
1554                        extract_schema_from_composition(all_of)
1555                    } else {
1556                        None
1557                    }
1558                    .unwrap_or(obj.clone())
1559                } else {
1560                    obj.clone()
1561                };
1562
1563                Ok(Schema {
1564                    r#type: infer_type(obj),
1565                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1566                    description: obj
1567                        .get("description")
1568                        .and_then(|v| v.as_str())
1569                        .map(String::from),
1570                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1571                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1572                        arr.iter()
1573                            .filter_map(|v| v.as_str().map(String::from))
1574                            .collect()
1575                    }),
1576                    max_items: obj
1577                        .get("maxItems")
1578                        .and_then(|v| v.as_i64())
1579                        .map(|v| v as i32),
1580                    min_items: obj
1581                        .get("minItems")
1582                        .and_then(|v| v.as_i64())
1583                        .map(|v| v as i32),
1584                    properties: props_source
1585                        .get("properties")
1586                        .and_then(|v| v.as_object())
1587                        .map(|map| {
1588                            map.iter()
1589                                .filter_map(|(k, v)| {
1590                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1591                                })
1592                                .collect()
1593                        }),
1594                    required: props_source
1595                        .get("required")
1596                        .and_then(|v| v.as_array())
1597                        .map(|arr| {
1598                            arr.iter()
1599                                .filter_map(|v| v.as_str().map(String::from))
1600                                .collect()
1601                        }),
1602                    items: obj
1603                        .get("items")
1604                        .and_then(|v| v.clone().try_into().ok())
1605                        .map(Box::new),
1606                })
1607            } else {
1608                Err(CompletionError::ResponseError(
1609                    "Expected a JSON object for Schema".into(),
1610                ))
1611            }
1612        }
1613    }
1614
1615    #[derive(Debug, Serialize)]
1616    #[serde(rename_all = "camelCase")]
1617    pub struct GenerateContentRequest {
1618        pub contents: Vec<Content>,
1619        #[serde(skip_serializing_if = "Option::is_none")]
1620        pub tools: Option<Tool>,
1621        pub tool_config: Option<ToolConfig>,
1622        /// Optional. Configuration options for model generation and outputs.
1623        pub generation_config: Option<GenerationConfig>,
1624        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1625        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1626        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1627        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1628        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1629        /// will use the default safety setting for that category. Harm categories:
1630        ///     - HARM_CATEGORY_HATE_SPEECH,
1631        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1632        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1633        ///     - HARM_CATEGORY_HARASSMENT
1634        /// are supported.
1635        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1636        /// to learn how to incorporate safety considerations in your AI applications.
1637        pub safety_settings: Option<Vec<SafetySetting>>,
1638        /// Optional. Developer set system instruction(s). Currently, text only.
1639        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1640        pub system_instruction: Option<Content>,
1641        // cachedContent: Optional<String>
1642        /// Additional parameters.
1643        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1644        pub additional_params: Option<serde_json::Value>,
1645    }
1646
1647    #[derive(Debug, Serialize)]
1648    #[serde(rename_all = "camelCase")]
1649    pub struct Tool {
1650        pub function_declarations: Vec<FunctionDeclaration>,
1651        pub code_execution: Option<CodeExecution>,
1652    }
1653
1654    #[derive(Debug, Serialize, Clone)]
1655    #[serde(rename_all = "camelCase")]
1656    pub struct FunctionDeclaration {
1657        pub name: String,
1658        pub description: String,
1659        #[serde(skip_serializing_if = "Option::is_none")]
1660        pub parameters: Option<Schema>,
1661    }
1662
1663    #[derive(Debug, Serialize, Deserialize)]
1664    #[serde(rename_all = "camelCase")]
1665    pub struct ToolConfig {
1666        pub function_calling_config: Option<FunctionCallingMode>,
1667    }
1668
1669    #[derive(Debug, Serialize, Deserialize, Default)]
1670    #[serde(tag = "mode", rename_all = "UPPERCASE")]
1671    pub enum FunctionCallingMode {
1672        #[default]
1673        Auto,
1674        None,
1675        Any {
1676            #[serde(skip_serializing_if = "Option::is_none")]
1677            allowed_function_names: Option<Vec<String>>,
1678        },
1679    }
1680
1681    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1682        type Error = CompletionError;
1683        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1684            let res = match value {
1685                message::ToolChoice::Auto => Self::Auto,
1686                message::ToolChoice::None => Self::None,
1687                message::ToolChoice::Required => Self::Any {
1688                    allowed_function_names: None,
1689                },
1690                message::ToolChoice::Specific { function_names } => Self::Any {
1691                    allowed_function_names: Some(function_names),
1692                },
1693            };
1694
1695            Ok(res)
1696        }
1697    }
1698
1699    #[derive(Debug, Serialize)]
1700    pub struct CodeExecution {}
1701
1702    #[derive(Debug, Serialize)]
1703    #[serde(rename_all = "camelCase")]
1704    pub struct SafetySetting {
1705        pub category: HarmCategory,
1706        pub threshold: HarmBlockThreshold,
1707    }
1708
1709    #[derive(Debug, Serialize)]
1710    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1711    pub enum HarmBlockThreshold {
1712        HarmBlockThresholdUnspecified,
1713        BlockLowAndAbove,
1714        BlockMediumAndAbove,
1715        BlockOnlyHigh,
1716        BlockNone,
1717        Off,
1718    }
1719}
1720
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion failure must fail the test, not merely be printed.
        let schema: Schema = schema_with_ref
            .try_into()
            .expect("schema with a $ref should convert");
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");

        let first_name = items
            .properties
            .as_ref()
            .and_then(|props| props.get("first_name"))
            .expect("inlined Person should keep its properties");
        assert_eq!(first_name.r#type, "string");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        if let Some(items) = schema.items {
            assert_eq!(items.r#type, "object");
            assert!(items.properties.is_some());
        } else {
            panic!("Schema should have items field");
        }
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        assert!(flattened.get("$defs").is_none(), "$defs should be removed");

        let schema: Schema = flattened.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        // Missing items must fail the test instead of being skipped.
        let items = schema
            .items
            .expect("flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_definitions_refs_are_inlined() {
        let legacy_schema = json!({
            "type": "object",
            "properties": {
                "owner": { "$ref": "#/definitions/Person" }
            },
            "definitions": {
                "Person": {
                    "type": "object",
                    "properties": { "name": { "type": "string" } }
                }
            }
        });

        let flattened = flatten_schema(legacy_schema).unwrap();
        assert!(flattened.get("definitions").is_none());
        assert_eq!(flattened["properties"]["owner"]["type"], "object");
    }

    #[test]
    fn test_missing_reference_is_an_error() {
        let dangling = json!({
            "type": "object",
            "properties": { "owner": { "$ref": "#/$defs/Missing" } },
            "$defs": {}
        });

        assert!(flatten_schema(dangling).is_err());
    }

    #[test]
    fn test_optional_object_uses_non_null_any_of_branch() {
        let optional_object = json!({
            "anyOf": [
                {
                    "type": "object",
                    "properties": { "a": { "type": "string" } },
                    "required": ["a"]
                },
                { "type": "null" }
            ]
        });

        let schema: Schema = optional_object.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(
            schema
                .properties
                .as_ref()
                .is_some_and(|props| props.contains_key("a"))
        );
        assert_eq!(schema.required, Some(vec!["a".to_string()]));
    }
}