rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B variant of the Gemini 1.5 family)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
///
/// NOTE: Google does not publish a `gemini-1.5-pro-8b` model id; the 8B model of the
/// 1.5 family is `gemini-1.5-flash-8b` (see [`GEMINI_1_5_FLASH_8B`]). The value of this
/// constant is left unchanged for backward compatibility.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::http_client::HttpClientExt;
32use crate::message::Reasoning;
33use crate::providers::gemini::completion::gemini_api_types::{
34    AdditionalParameters, FunctionCallingMode, ToolConfig,
35};
36use crate::providers::gemini::streaming::StreamingCompletionResponse;
37use crate::telemetry::SpanCombinator;
38use crate::{
39    OneOrMany,
40    completion::{self, CompletionError, CompletionRequest},
41};
42use gemini_api_types::{
43    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
44    Role, Tool,
45};
46use serde_json::{Map, Value};
47use std::convert::TryFrom;
48use tracing::info_span;
49
50use super::Client;
51
52// =================================================================
53// Rig Implementation Types
54// =================================================================
55
56#[derive(Clone)]
57pub struct CompletionModel<T = reqwest::Client> {
58    pub(crate) client: Client<T>,
59    pub model: String,
60}
61
62impl<T> CompletionModel<T> {
63    pub fn new(client: Client<T>, model: &str) -> Self {
64        Self {
65            client,
66            model: model.to_string(),
67        }
68    }
69}
70
71impl<T> completion::CompletionModel for CompletionModel<T>
72where
73    T: HttpClientExt + Clone + 'static,
74{
75    type Response = GenerateContentResponse;
76    type StreamingResponse = StreamingCompletionResponse;
77
78    #[cfg_attr(feature = "worker", worker::send)]
79    async fn completion(
80        &self,
81        completion_request: CompletionRequest,
82    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
83        let span = if tracing::Span::current().is_disabled() {
84            info_span!(
85                target: "rig::completions",
86                "generate_content",
87                gen_ai.operation.name = "generate_content",
88                gen_ai.provider.name = "gcp.gemini",
89                gen_ai.request.model = self.model,
90                gen_ai.system_instructions = &completion_request.preamble,
91                gen_ai.response.id = tracing::field::Empty,
92                gen_ai.response.model = tracing::field::Empty,
93                gen_ai.usage.output_tokens = tracing::field::Empty,
94                gen_ai.usage.input_tokens = tracing::field::Empty,
95                gen_ai.input.messages = tracing::field::Empty,
96                gen_ai.output.messages = tracing::field::Empty,
97            )
98        } else {
99            tracing::Span::current()
100        };
101
102        let request = create_request_body(completion_request)?;
103        span.record_model_input(&request.contents);
104
105        span.record_model_input(&request.contents);
106
107        tracing::debug!(
108            "Sending completion request to Gemini API {}",
109            serde_json::to_string_pretty(&request)?
110        );
111
112        let body = serde_json::to_vec(&request)?;
113
114        let request = self
115            .client
116            .post(&format!("/v1beta/models/{}:generateContent", self.model))
117            .header("Content-Type", "application/json")
118            .body(body)
119            .map_err(|e| CompletionError::HttpError(e.into()))?;
120
121        let response = self.client.send::<_, Vec<u8>>(request).await?;
122
123        if response.status().is_success() {
124            let response_body = response
125                .into_body()
126                .await
127                .map_err(CompletionError::HttpError)?;
128
129            let response: GenerateContentResponse = serde_json::from_slice(&response_body)?;
130
131            match response.usage_metadata {
132                Some(ref usage) => tracing::info!(target: "rig",
133                "Gemini completion token usage: {}",
134                usage
135                ),
136                None => tracing::info!(target: "rig",
137                    "Gemini completion token usage: n/a",
138                ),
139            }
140
141            let span = tracing::Span::current();
142            span.record_model_output(&response.candidates);
143            span.record_response_metadata(&response);
144            span.record_token_usage(&response.usage_metadata);
145
146            tracing::debug!(
147                "Received response from Gemini API: {}",
148                serde_json::to_string_pretty(&response)?
149            );
150
151            response.try_into()
152        } else {
153            let text = String::from_utf8_lossy(
154                &response
155                    .into_body()
156                    .await
157                    .map_err(CompletionError::HttpError)?,
158            )
159            .into();
160
161            Err(CompletionError::ProviderError(text))
162        }
163    }
164
165    #[cfg_attr(feature = "worker", worker::send)]
166    async fn stream(
167        &self,
168        request: CompletionRequest,
169    ) -> Result<
170        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
171        CompletionError,
172    > {
173        CompletionModel::stream(self, request).await
174    }
175}
176
177pub(crate) fn create_request_body(
178    completion_request: CompletionRequest,
179) -> Result<GenerateContentRequest, CompletionError> {
180    let mut full_history = Vec::new();
181    full_history.extend(completion_request.chat_history);
182
183    let additional_params = completion_request
184        .additional_params
185        .unwrap_or_else(|| Value::Object(Map::new()));
186
187    let AdditionalParameters {
188        mut generation_config,
189        additional_params,
190    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
191
192    if let Some(temp) = completion_request.temperature {
193        generation_config.temperature = Some(temp);
194    }
195
196    if let Some(max_tokens) = completion_request.max_tokens {
197        generation_config.max_output_tokens = Some(max_tokens);
198    }
199
200    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
201        parts: vec![preamble.into()],
202        role: Some(Role::Model),
203    });
204
205    let tools = if completion_request.tools.is_empty() {
206        None
207    } else {
208        Some(Tool::try_from(completion_request.tools)?)
209    };
210
211    let tool_config = if let Some(cfg) = completion_request.tool_choice {
212        Some(ToolConfig {
213            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
214        })
215    } else {
216        None
217    };
218
219    let request = GenerateContentRequest {
220        contents: full_history
221            .into_iter()
222            .map(|msg| {
223                msg.try_into()
224                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
225            })
226            .collect::<Result<Vec<_>, _>>()?,
227        generation_config: Some(generation_config),
228        safety_settings: None,
229        tools,
230        tool_config,
231        system_instruction,
232        additional_params,
233    };
234
235    Ok(request)
236}
237
238impl TryFrom<completion::ToolDefinition> for Tool {
239    type Error = CompletionError;
240
241    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
242        let parameters: Option<Schema> =
243            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
244                None
245            } else {
246                Some(tool.parameters.try_into()?)
247            };
248
249        Ok(Self {
250            function_declarations: vec![FunctionDeclaration {
251                name: tool.name,
252                description: tool.description,
253                parameters,
254            }],
255            code_execution: None,
256        })
257    }
258}
259
260impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
261    type Error = CompletionError;
262
263    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
264        let mut function_declarations = Vec::new();
265
266        for tool in tools {
267            let parameters =
268                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
269                    None
270                } else {
271                    match tool.parameters.try_into() {
272                        Ok(schema) => Some(schema),
273                        Err(e) => {
274                            let emsg = format!(
275                                "Tool '{}' could not be converted to a schema: {:?}",
276                                tool.name, e,
277                            );
278                            return Err(CompletionError::ProviderError(emsg));
279                        }
280                    }
281                };
282
283            function_declarations.push(FunctionDeclaration {
284                name: tool.name,
285                description: tool.description,
286                parameters,
287            });
288        }
289
290        Ok(Self {
291            function_declarations,
292            code_execution: None,
293        })
294    }
295}
296
297impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
298    type Error = CompletionError;
299
300    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
301        let candidate = response.candidates.first().ok_or_else(|| {
302            CompletionError::ResponseError("No response candidates in response".into())
303        })?;
304
305        let content = candidate
306            .content
307            .parts
308            .iter()
309            .map(|Part { thought, part, .. }| {
310                Ok(match part {
311                    PartKind::Text(text) => {
312                        if let Some(thought) = thought
313                            && *thought
314                        {
315                            completion::AssistantContent::Reasoning(Reasoning::new(text))
316                        } else {
317                            completion::AssistantContent::text(text)
318                        }
319                    }
320                    PartKind::FunctionCall(function_call) => {
321                        completion::AssistantContent::tool_call(
322                            &function_call.name,
323                            &function_call.name,
324                            function_call.args.clone(),
325                        )
326                    }
327                    _ => {
328                        return Err(CompletionError::ResponseError(
329                            "Response did not contain a message or tool call".into(),
330                        ));
331                    }
332                })
333            })
334            .collect::<Result<Vec<_>, _>>()?;
335
336        let choice = OneOrMany::many(content).map_err(|_| {
337            CompletionError::ResponseError(
338                "Response contained no message or tool call (empty)".to_owned(),
339            )
340        })?;
341
342        let usage = response
343            .usage_metadata
344            .as_ref()
345            .map(|usage| completion::Usage {
346                input_tokens: usage.prompt_token_count as u64,
347                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
348                total_tokens: usage.total_token_count as u64,
349            })
350            .unwrap_or_default();
351
352        Ok(completion::CompletionResponse {
353            choice,
354            usage,
355            raw_response: response,
356        })
357    }
358}
359
360pub mod gemini_api_types {
361    use crate::telemetry::ProviderResponseExt;
362    use std::{collections::HashMap, convert::Infallible, str::FromStr};
363
364    // =================================================================
365    // Gemini API Types
366    // =================================================================
367    use serde::{Deserialize, Serialize};
368    use serde_json::{Value, json};
369
370    use crate::completion::GetTokenUsage;
371    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
372    use crate::{
373        OneOrMany,
374        completion::CompletionError,
375        message::{self, Reasoning, Text},
376        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
377    };
378
379    #[derive(Debug, Deserialize, Serialize, Default)]
380    #[serde(rename_all = "camelCase")]
381    pub struct AdditionalParameters {
382        /// Change your Gemini request configuration.
383        pub generation_config: GenerationConfig,
384        /// Any additional parameters that you want.
385        #[serde(flatten, skip_serializing_if = "Option::is_none")]
386        pub additional_params: Option<serde_json::Value>,
387    }
388
389    impl AdditionalParameters {
390        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
391            self.generation_config = cfg;
392            self
393        }
394
395        pub fn with_params(mut self, params: serde_json::Value) -> Self {
396            self.additional_params = Some(params);
397            self
398        }
399    }
400
401    /// Response from the model supporting multiple candidate responses.
402    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
403    /// and for each candidate in finishReason and in safetyRatings.
404    /// The API:
405    ///     - Returns either all requested candidates or none of them
406    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
407    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
408    #[derive(Debug, Deserialize, Serialize)]
409    #[serde(rename_all = "camelCase")]
410    pub struct GenerateContentResponse {
411        pub response_id: String,
412        /// Candidate responses from the model.
413        pub candidates: Vec<ContentCandidate>,
414        /// Returns the prompt's feedback related to the content filters.
415        pub prompt_feedback: Option<PromptFeedback>,
416        /// Output only. Metadata on the generation requests' token usage.
417        pub usage_metadata: Option<UsageMetadata>,
418        pub model_version: Option<String>,
419    }
420
421    impl ProviderResponseExt for GenerateContentResponse {
422        type OutputMessage = ContentCandidate;
423        type Usage = UsageMetadata;
424
425        fn get_response_id(&self) -> Option<String> {
426            Some(self.response_id.clone())
427        }
428
429        fn get_response_model_name(&self) -> Option<String> {
430            None
431        }
432
433        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
434            self.candidates.clone()
435        }
436
437        fn get_text_response(&self) -> Option<String> {
438            let str = self
439                .candidates
440                .iter()
441                .filter_map(|x| {
442                    if x.content.role.as_ref().is_none_or(|y| y != &Role::Model) {
443                        return None;
444                    }
445
446                    let res = x
447                        .content
448                        .parts
449                        .iter()
450                        .filter_map(|part| {
451                            if let PartKind::Text(ref str) = part.part {
452                                Some(str.to_owned())
453                            } else {
454                                None
455                            }
456                        })
457                        .collect::<Vec<String>>()
458                        .join("\n");
459
460                    Some(res)
461                })
462                .collect::<Vec<String>>()
463                .join("\n");
464
465            if str.is_empty() { None } else { Some(str) }
466        }
467
468        fn get_usage(&self) -> Option<Self::Usage> {
469            self.usage_metadata.clone()
470        }
471    }
472
473    /// A response candidate generated from the model.
474    #[derive(Clone, Debug, Deserialize, Serialize)]
475    #[serde(rename_all = "camelCase")]
476    pub struct ContentCandidate {
477        /// Output only. Generated content returned from the model.
478        pub content: Content,
479        /// Optional. Output only. The reason why the model stopped generating tokens.
480        /// If empty, the model has not stopped generating tokens.
481        pub finish_reason: Option<FinishReason>,
482        /// List of ratings for the safety of a response candidate.
483        /// There is at most one rating per category.
484        pub safety_ratings: Option<Vec<SafetyRating>>,
485        /// Output only. Citation information for model-generated candidate.
486        /// This field may be populated with recitation information for any text included in the content.
487        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
488        pub citation_metadata: Option<CitationMetadata>,
489        /// Output only. Token count for this candidate.
490        pub token_count: Option<i32>,
491        /// Output only.
492        pub avg_logprobs: Option<f64>,
493        /// Output only. Log-likelihood scores for the response tokens and top tokens
494        pub logprobs_result: Option<LogprobsResult>,
495        /// Output only. Index of the candidate in the list of response candidates.
496        pub index: Option<i32>,
497    }
498
499    #[derive(Clone, Debug, Deserialize, Serialize)]
500    pub struct Content {
501        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
502        #[serde(default)]
503        pub parts: Vec<Part>,
504        /// The producer of the content. Must be either 'user' or 'model'.
505        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
506        pub role: Option<Role>,
507    }
508
509    impl TryFrom<message::Message> for Content {
510        type Error = message::MessageError;
511
512        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
513            Ok(match msg {
514                message::Message::User { content } => Content {
515                    parts: content
516                        .into_iter()
517                        .map(|c| c.try_into())
518                        .collect::<Result<Vec<_>, _>>()?,
519                    role: Some(Role::User),
520                },
521                message::Message::Assistant { content, .. } => Content {
522                    role: Some(Role::Model),
523                    parts: content.into_iter().map(|content| content.into()).collect(),
524                },
525            })
526        }
527    }
528
529    impl TryFrom<Content> for message::Message {
530        type Error = message::MessageError;
531
532        fn try_from(content: Content) -> Result<Self, Self::Error> {
533            match content.role {
534                Some(Role::User) | None => {
535                    Ok(message::Message::User {
536                        content: {
537                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
538                            .map(|Part { part, .. }| {
539                                Ok(match part {
540                                    PartKind::Text(text) => message::UserContent::text(text),
541                                    PartKind::InlineData(inline_data) => {
542                                        let mime_type =
543                                            message::MediaType::from_mime_type(&inline_data.mime_type);
544
545                                        match mime_type {
546                                            Some(message::MediaType::Image(media_type)) => {
547                                                message::UserContent::image_base64(
548                                                    inline_data.data,
549                                                    Some(media_type),
550                                                    Some(message::ImageDetail::default()),
551                                                )
552                                            }
553                                            Some(message::MediaType::Document(media_type)) => {
554                                                message::UserContent::document(
555                                                    inline_data.data,
556                                                    Some(media_type),
557                                                )
558                                            }
559                                            Some(message::MediaType::Audio(media_type)) => {
560                                                message::UserContent::audio(
561                                                    inline_data.data,
562                                                    Some(media_type),
563                                                )
564                                            }
565                                            _ => {
566                                                return Err(message::MessageError::ConversionError(
567                                                    format!("Unsupported media type {mime_type:?}"),
568                                                ));
569                                            }
570                                        }
571                                    }
572                                    _ => {
573                                        return Err(message::MessageError::ConversionError(format!(
574                                            "Unsupported gemini content part type: {part:?}"
575                                        )));
576                                    }
577                                })
578                            })
579                            .collect();
580                            OneOrMany::many(user_content?).map_err(|_| {
581                                message::MessageError::ConversionError(
582                                    "Failed to create OneOrMany from user content".to_string(),
583                                )
584                            })?
585                        },
586                    })
587                }
588                Some(Role::Model) => Ok(message::Message::Assistant {
589                    id: None,
590                    content: {
591                        let assistant_content: Result<Vec<_>, _> = content
592                            .parts
593                            .into_iter()
594                            .map(|Part { thought, part, .. }| {
595                                Ok(match part {
596                                    PartKind::Text(text) => match thought {
597                                        Some(true) => message::AssistantContent::Reasoning(
598                                            Reasoning::new(&text),
599                                        ),
600                                        _ => message::AssistantContent::Text(Text { text }),
601                                    },
602
603                                    PartKind::FunctionCall(function_call) => {
604                                        message::AssistantContent::ToolCall(function_call.into())
605                                    }
606                                    _ => {
607                                        return Err(message::MessageError::ConversionError(
608                                            format!("Unsupported part type: {part:?}"),
609                                        ));
610                                    }
611                                })
612                            })
613                            .collect();
614                        OneOrMany::many(assistant_content?).map_err(|_| {
615                            message::MessageError::ConversionError(
616                                "Failed to create OneOrMany from assistant content".to_string(),
617                            )
618                        })?
619                    },
620                }),
621            }
622        }
623    }
624
625    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
626    #[serde(rename_all = "lowercase")]
627    pub enum Role {
628        User,
629        Model,
630    }
631
632    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
633    #[serde(rename_all = "camelCase")]
634    pub struct Part {
635        /// whether or not the part is a reasoning/thinking text or not
636        #[serde(skip_serializing_if = "Option::is_none")]
637        pub thought: Option<bool>,
638        /// an opaque sig for the thought so it can be reused - is a base64 string
639        #[serde(skip_serializing_if = "Option::is_none")]
640        pub thought_signature: Option<String>,
641        #[serde(flatten)]
642        pub part: PartKind,
643        #[serde(flatten, skip_serializing_if = "Option::is_none")]
644        pub additional_params: Option<Value>,
645    }
646
647    /// A datatype containing media that is part of a multi-part [Content] message.
648    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
649    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
650    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
651    #[serde(rename_all = "camelCase")]
652    pub enum PartKind {
653        Text(String),
654        InlineData(Blob),
655        FunctionCall(FunctionCall),
656        FunctionResponse(FunctionResponse),
657        FileData(FileData),
658        ExecutableCode(ExecutableCode),
659        CodeExecutionResult(CodeExecutionResult),
660    }
661
662    // This default instance is primarily so we can easily fill in the optional fields of `Part`
663    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
664    impl Default for PartKind {
665        fn default() -> Self {
666            Self::Text(String::new())
667        }
668    }
669
670    impl From<String> for Part {
671        fn from(text: String) -> Self {
672            Self {
673                thought: Some(false),
674                thought_signature: None,
675                part: PartKind::Text(text),
676                additional_params: None,
677            }
678        }
679    }
680
681    impl From<&str> for Part {
682        fn from(text: &str) -> Self {
683            Self::from(text.to_string())
684        }
685    }
686
687    impl FromStr for Part {
688        type Err = Infallible;
689
690        fn from_str(s: &str) -> Result<Self, Self::Err> {
691            Ok(s.into())
692        }
693    }
694
695    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
696        type Error = message::MessageError;
697        fn try_from(
698            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
699        ) -> Result<Self, Self::Error> {
700            let mime_type = mime_type.to_mime_type().to_string();
701            let part = match doc_src {
702                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
703                    mime_type: Some(mime_type),
704                    file_uri: url,
705                }),
706                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
707                    PartKind::InlineData(Blob { mime_type, data })
708                }
709                DocumentSourceKind::Raw(_) => {
710                    return Err(message::MessageError::ConversionError(
711                        "Raw files not supported, encode as base64 first".into(),
712                    ));
713                }
714                DocumentSourceKind::Unknown => {
715                    return Err(message::MessageError::ConversionError(
716                        "Can't convert an unknown document source".to_string(),
717                    ));
718                }
719            };
720
721            Ok(part)
722        }
723    }
724
725    impl TryFrom<message::UserContent> for Part {
726        type Error = message::MessageError;
727
728        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
729            match content {
730                message::UserContent::Text(message::Text { text }) => Ok(Part {
731                    thought: Some(false),
732                    thought_signature: None,
733                    part: PartKind::Text(text),
734                    additional_params: None,
735                }),
736                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
737                    let content = match content.first() {
738                        message::ToolResultContent::Text(text) => text.text,
739                        message::ToolResultContent::Image(_) => {
740                            return Err(message::MessageError::ConversionError(
741                                "Tool result content must be text".to_string(),
742                            ));
743                        }
744                    };
745                    // Convert to JSON since this value may be a valid JSON value
746                    let result: serde_json::Value =
747                        serde_json::from_str(&content).unwrap_or_else(|error| {
748                            tracing::trace!(
749                                ?error,
750                                "Tool result is not a valid JSON, treat it as normal string"
751                            );
752                            json!(content)
753                        });
754                    Ok(Part {
755                        thought: Some(false),
756                        thought_signature: None,
757                        part: PartKind::FunctionResponse(FunctionResponse {
758                            name: id,
759                            response: Some(json!({ "result": result })),
760                        }),
761                        additional_params: None,
762                    })
763                }
764                message::UserContent::Image(message::Image {
765                    data, media_type, ..
766                }) => match media_type {
767                    Some(media_type) => match media_type {
768                        message::ImageMediaType::JPEG
769                        | message::ImageMediaType::PNG
770                        | message::ImageMediaType::WEBP
771                        | message::ImageMediaType::HEIC
772                        | message::ImageMediaType::HEIF => {
773                            let part = PartKind::try_from((media_type, data))?;
774                            Ok(Part {
775                                thought: Some(false),
776                                thought_signature: None,
777                                part,
778                                additional_params: None,
779                            })
780                        }
781                        _ => Err(message::MessageError::ConversionError(format!(
782                            "Unsupported image media type {media_type:?}"
783                        ))),
784                    },
785                    None => Err(message::MessageError::ConversionError(
786                        "Media type for image is required for Gemini".to_string(),
787                    )),
788                },
789                message::UserContent::Document(message::Document {
790                    data, media_type, ..
791                }) => {
792                    let Some(media_type) = media_type else {
793                        return Err(MessageError::ConversionError(
794                            "A mime type is required for document inputs to Gemini".to_string(),
795                        ));
796                    };
797
798                    if !media_type.is_code() {
799                        let mime_type = media_type.to_mime_type().to_string();
800
801                        let part = match data {
802                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
803                                mime_type: Some(mime_type),
804                                file_uri,
805                            }),
806                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
807                                PartKind::InlineData(Blob { mime_type, data })
808                            }
809                            DocumentSourceKind::Raw(_) => {
810                                return Err(message::MessageError::ConversionError(
811                                    "Raw files not supported, encode as base64 first".into(),
812                                ));
813                            }
814                            _ => {
815                                return Err(message::MessageError::ConversionError(
816                                    "Document has no body".to_string(),
817                                ));
818                            }
819                        };
820
821                        Ok(Part {
822                            thought: Some(false),
823                            part,
824                            ..Default::default()
825                        })
826                    } else {
827                        Err(message::MessageError::ConversionError(format!(
828                            "Unsupported document media type {media_type:?}"
829                        )))
830                    }
831                }
832
833                message::UserContent::Audio(message::Audio {
834                    data, media_type, ..
835                }) => {
836                    let Some(media_type) = media_type else {
837                        return Err(MessageError::ConversionError(
838                            "A mime type is required for audio inputs to Gemini".to_string(),
839                        ));
840                    };
841
842                    let mime_type = media_type.to_mime_type().to_string();
843
844                    let part = match data {
845                        DocumentSourceKind::Base64(data) => {
846                            PartKind::InlineData(Blob { data, mime_type })
847                        }
848
849                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
850                            mime_type: Some(mime_type),
851                            file_uri,
852                        }),
853                        DocumentSourceKind::String(_) => {
854                            return Err(message::MessageError::ConversionError(
855                                "Strings cannot be used as audio files!".into(),
856                            ));
857                        }
858                        DocumentSourceKind::Raw(_) => {
859                            return Err(message::MessageError::ConversionError(
860                                "Raw files not supported, encode as base64 first".into(),
861                            ));
862                        }
863                        DocumentSourceKind::Unknown => {
864                            return Err(message::MessageError::ConversionError(
865                                "Content has no body".to_string(),
866                            ));
867                        }
868                    };
869
870                    Ok(Part {
871                        thought: Some(false),
872                        part,
873                        ..Default::default()
874                    })
875                }
876                message::UserContent::Video(message::Video {
877                    data,
878                    media_type,
879                    additional_params,
880                    ..
881                }) => {
882                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
883
884                    let part = match data {
885                        DocumentSourceKind::Url(file_uri) => {
886                            if file_uri.starts_with("https://www.youtube.com") {
887                                PartKind::FileData(FileData {
888                                    mime_type,
889                                    file_uri,
890                                })
891                            } else {
892                                if mime_type.is_none() {
893                                    return Err(MessageError::ConversionError(
894                                        "A mime type is required for non-Youtube video file inputs to Gemini"
895                                            .to_string(),
896                                    ));
897                                }
898
899                                PartKind::FileData(FileData {
900                                    mime_type,
901                                    file_uri,
902                                })
903                            }
904                        }
905                        DocumentSourceKind::Base64(data) => {
906                            let Some(mime_type) = mime_type else {
907                                return Err(MessageError::ConversionError(
908                                    "A media type is expected for base64 encoded strings"
909                                        .to_string(),
910                                ));
911                            };
912                            PartKind::InlineData(Blob { mime_type, data })
913                        }
914                        DocumentSourceKind::String(_) => {
915                            return Err(message::MessageError::ConversionError(
916                                "Strings cannot be used as audio files!".into(),
917                            ));
918                        }
919                        DocumentSourceKind::Raw(_) => {
920                            return Err(message::MessageError::ConversionError(
921                                "Raw file data not supported, encode as base64 first".into(),
922                            ));
923                        }
924                        DocumentSourceKind::Unknown => {
925                            return Err(message::MessageError::ConversionError(
926                                "Media type for video is required for Gemini".to_string(),
927                            ));
928                        }
929                    };
930
931                    Ok(Part {
932                        thought: Some(false),
933                        thought_signature: None,
934                        part,
935                        additional_params,
936                    })
937                }
938            }
939        }
940    }
941
942    impl From<message::AssistantContent> for Part {
943        fn from(content: message::AssistantContent) -> Self {
944            match content {
945                message::AssistantContent::Text(message::Text { text }) => text.into(),
946                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
947                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
948                    Part {
949                        thought: Some(true),
950                        thought_signature: None,
951                        part: PartKind::Text(
952                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
953                        ),
954                        additional_params: None,
955                    }
956                }
957            }
958        }
959    }
960
961    impl From<message::ToolCall> for Part {
962        fn from(tool_call: message::ToolCall) -> Self {
963            Self {
964                thought: Some(false),
965                thought_signature: None,
966                part: PartKind::FunctionCall(FunctionCall {
967                    name: tool_call.function.name,
968                    args: tool_call.function.arguments,
969                }),
970                additional_params: None,
971            }
972        }
973    }
974
975    /// Raw media bytes.
976    /// Text should not be sent as raw bytes, use the 'text' field.
977    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
978    #[serde(rename_all = "camelCase")]
979    pub struct Blob {
980        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
981        /// If an unsupported MIME type is provided, an error will be returned.
982        pub mime_type: String,
983        /// Raw bytes for media formats. A base64-encoded string.
984        pub data: String,
985    }
986
987    /// A predicted FunctionCall returned from the model that contains a string representing the
988    /// FunctionDeclaration.name with the arguments and their values.
989    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
990    pub struct FunctionCall {
991        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
992        /// and dashes, with a maximum length of 63.
993        pub name: String,
994        /// Optional. The function parameters and values in JSON object format.
995        pub args: serde_json::Value,
996    }
997
998    impl From<FunctionCall> for message::ToolCall {
999        fn from(function_call: FunctionCall) -> Self {
1000            Self {
1001                id: function_call.name.clone(),
1002                call_id: None,
1003                function: message::ToolFunction {
1004                    name: function_call.name,
1005                    arguments: function_call.args,
1006                },
1007            }
1008        }
1009    }
1010
1011    impl From<message::ToolCall> for FunctionCall {
1012        fn from(tool_call: message::ToolCall) -> Self {
1013            Self {
1014                name: tool_call.function.name,
1015                args: tool_call.function.arguments,
1016            }
1017        }
1018    }
1019
1020    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
1021    /// and a structured JSON object containing any output from the function is used as context to the model.
1022    /// This should contain the result of aFunctionCall made based on model prediction.
1023    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1024    pub struct FunctionResponse {
1025        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
1026        /// with a maximum length of 63.
1027        pub name: String,
1028        /// The function response in JSON object format.
1029        pub response: Option<serde_json::Value>,
1030    }
1031
1032    /// URI based data.
1033    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1034    #[serde(rename_all = "camelCase")]
1035    pub struct FileData {
1036        /// Optional. The IANA standard MIME type of the source data.
1037        pub mime_type: Option<String>,
1038        /// Required. URI.
1039        pub file_uri: String,
1040    }
1041
1042    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1043    pub struct SafetyRating {
1044        pub category: HarmCategory,
1045        pub probability: HarmProbability,
1046    }
1047
1048    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1049    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1050    pub enum HarmProbability {
1051        HarmProbabilityUnspecified,
1052        Negligible,
1053        Low,
1054        Medium,
1055        High,
1056    }
1057
1058    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1059    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1060    pub enum HarmCategory {
1061        HarmCategoryUnspecified,
1062        HarmCategoryDerogatory,
1063        HarmCategoryToxicity,
1064        HarmCategoryViolence,
1065        HarmCategorySexually,
1066        HarmCategoryMedical,
1067        HarmCategoryDangerous,
1068        HarmCategoryHarassment,
1069        HarmCategoryHateSpeech,
1070        HarmCategorySexuallyExplicit,
1071        HarmCategoryDangerousContent,
1072        HarmCategoryCivicIntegrity,
1073    }
1074
1075    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
1076    #[serde(rename_all = "camelCase")]
1077    pub struct UsageMetadata {
1078        pub prompt_token_count: i32,
1079        #[serde(skip_serializing_if = "Option::is_none")]
1080        pub cached_content_token_count: Option<i32>,
1081        #[serde(skip_serializing_if = "Option::is_none")]
1082        pub candidates_token_count: Option<i32>,
1083        pub total_token_count: i32,
1084        #[serde(skip_serializing_if = "Option::is_none")]
1085        pub thoughts_token_count: Option<i32>,
1086    }
1087
1088    impl std::fmt::Display for UsageMetadata {
1089        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1090            write!(
1091                f,
1092                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1093                self.prompt_token_count,
1094                match self.cached_content_token_count {
1095                    Some(count) => count.to_string(),
1096                    None => "n/a".to_string(),
1097                },
1098                match self.candidates_token_count {
1099                    Some(count) => count.to_string(),
1100                    None => "n/a".to_string(),
1101                },
1102                self.total_token_count
1103            )
1104        }
1105    }
1106
1107    impl GetTokenUsage for UsageMetadata {
1108        fn token_usage(&self) -> Option<crate::completion::Usage> {
1109            let mut usage = crate::completion::Usage::new();
1110
1111            usage.input_tokens = self.prompt_token_count as u64;
1112            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1113                + self.candidates_token_count.unwrap_or_default()
1114                + self.thoughts_token_count.unwrap_or_default())
1115                as u64;
1116            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1117
1118            Some(usage)
1119        }
1120    }
1121
1122    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
1123    #[derive(Debug, Deserialize, Serialize)]
1124    #[serde(rename_all = "camelCase")]
1125    pub struct PromptFeedback {
1126        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
1127        pub block_reason: Option<BlockReason>,
1128        /// Ratings for safety of the prompt. There is at most one rating per category.
1129        pub safety_ratings: Option<Vec<SafetyRating>>,
1130    }
1131
1132    /// Reason why a prompt was blocked by the model
1133    #[derive(Debug, Deserialize, Serialize)]
1134    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1135    pub enum BlockReason {
1136        /// Default value. This value is unused.
1137        BlockReasonUnspecified,
1138        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
1139        Safety,
1140        /// Prompt was blocked due to unknown reasons.
1141        Other,
1142        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
1143        Blocklist,
1144        /// Prompt was blocked due to prohibited content.
1145        ProhibitedContent,
1146    }
1147
1148    #[derive(Clone, Debug, Deserialize, Serialize)]
1149    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1150    pub enum FinishReason {
1151        /// Default value. This value is unused.
1152        FinishReasonUnspecified,
1153        /// Natural stop point of the model or provided stop sequence.
1154        Stop,
1155        /// The maximum number of tokens as specified in the request was reached.
1156        MaxTokens,
1157        /// The response candidate content was flagged for safety reasons.
1158        Safety,
1159        /// The response candidate content was flagged for recitation reasons.
1160        Recitation,
1161        /// The response candidate content was flagged for using an unsupported language.
1162        Language,
1163        /// Unknown reason.
1164        Other,
1165        /// Token generation stopped because the content contains forbidden terms.
1166        Blocklist,
1167        /// Token generation stopped for potentially containing prohibited content.
1168        ProhibitedContent,
1169        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
1170        Spii,
1171        /// The function call generated by the model is invalid.
1172        MalformedFunctionCall,
1173    }
1174
1175    #[derive(Clone, Debug, Deserialize, Serialize)]
1176    #[serde(rename_all = "camelCase")]
1177    pub struct CitationMetadata {
1178        pub citation_sources: Vec<CitationSource>,
1179    }
1180
1181    #[derive(Clone, Debug, Deserialize, Serialize)]
1182    #[serde(rename_all = "camelCase")]
1183    pub struct CitationSource {
1184        #[serde(skip_serializing_if = "Option::is_none")]
1185        pub uri: Option<String>,
1186        #[serde(skip_serializing_if = "Option::is_none")]
1187        pub start_index: Option<i32>,
1188        #[serde(skip_serializing_if = "Option::is_none")]
1189        pub end_index: Option<i32>,
1190        #[serde(skip_serializing_if = "Option::is_none")]
1191        pub license: Option<String>,
1192    }
1193
1194    #[derive(Clone, Debug, Deserialize, Serialize)]
1195    #[serde(rename_all = "camelCase")]
1196    pub struct LogprobsResult {
1197        pub top_candidate: Vec<TopCandidate>,
1198        pub chosen_candidate: Vec<LogProbCandidate>,
1199    }
1200
1201    #[derive(Clone, Debug, Deserialize, Serialize)]
1202    pub struct TopCandidate {
1203        pub candidates: Vec<LogProbCandidate>,
1204    }
1205
1206    #[derive(Clone, Debug, Deserialize, Serialize)]
1207    #[serde(rename_all = "camelCase")]
1208    pub struct LogProbCandidate {
1209        pub token: String,
1210        pub token_id: String,
1211        pub log_probability: f64,
1212    }
1213
1214    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
1215    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
1216    /// ### Rig Note:
1217    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
1218    #[derive(Debug, Deserialize, Serialize)]
1219    #[serde(rename_all = "camelCase")]
1220    pub struct GenerationConfig {
1221        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
1222        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
1223        #[serde(skip_serializing_if = "Option::is_none")]
1224        pub stop_sequences: Option<Vec<String>>,
1225        /// MIME type of the generated candidate text. Supported MIME types are:
1226        ///     - text/plain:  (default) Text output
1227        ///     - application/json: JSON response in the response candidates.
1228        ///     - text/x.enum: ENUM as a string response in the response candidates.
1229        /// Refer to the docs for a list of all supported text MIME types
1230        #[serde(skip_serializing_if = "Option::is_none")]
1231        pub response_mime_type: Option<String>,
1232        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
1233        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
1234        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
1235        #[serde(skip_serializing_if = "Option::is_none")]
1236        pub response_schema: Option<Schema>,
1237        /// Number of generated responses to return. Currently, this value can only be set to 1. If
1238        /// unset, this will default to 1.
1239        #[serde(skip_serializing_if = "Option::is_none")]
1240        pub candidate_count: Option<i32>,
1241        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
1242        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
1243        #[serde(skip_serializing_if = "Option::is_none")]
1244        pub max_output_tokens: Option<u64>,
1245        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
1246        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
1247        #[serde(skip_serializing_if = "Option::is_none")]
1248        pub temperature: Option<f64>,
1249        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
1250        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
1251        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
1252        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
1253        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
1254        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
1255        #[serde(skip_serializing_if = "Option::is_none")]
1256        pub top_p: Option<f64>,
1257        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
1258        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
1259        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
1260        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
1261        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
1262        #[serde(skip_serializing_if = "Option::is_none")]
1263        pub top_k: Option<i32>,
1264        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
1265        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
1266        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
1267        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
1268        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
1269        #[serde(skip_serializing_if = "Option::is_none")]
1270        pub presence_penalty: Option<f64>,
1271        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
1272        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
1273        /// used, proportional to the number of times the token has been used: The more a token is used, the more
1274        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
1275        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
1276        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
1277        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
1278        #[serde(skip_serializing_if = "Option::is_none")]
1279        pub frequency_penalty: Option<f64>,
1280        /// If true, export the logprobs results in response.
1281        #[serde(skip_serializing_if = "Option::is_none")]
1282        pub response_logprobs: Option<bool>,
1283        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
1284        /// [Candidate.logprobs_result].
1285        #[serde(skip_serializing_if = "Option::is_none")]
1286        pub logprobs: Option<i32>,
1287        /// Configuration for thinking/reasoning.
1288        #[serde(skip_serializing_if = "Option::is_none")]
1289        pub thinking_config: Option<ThinkingConfig>,
1290    }
1291
1292    impl Default for GenerationConfig {
1293        fn default() -> Self {
1294            Self {
1295                temperature: Some(1.0),
1296                max_output_tokens: Some(4096),
1297                stop_sequences: None,
1298                response_mime_type: None,
1299                response_schema: None,
1300                candidate_count: None,
1301                top_p: None,
1302                top_k: None,
1303                presence_penalty: None,
1304                frequency_penalty: None,
1305                response_logprobs: None,
1306                logprobs: None,
1307                thinking_config: None,
1308            }
1309        }
1310    }
1311
1312    #[derive(Debug, Deserialize, Serialize)]
1313    #[serde(rename_all = "camelCase")]
1314    pub struct ThinkingConfig {
1315        pub thinking_budget: u32,
1316        pub include_thoughts: Option<bool>,
1317    }
1318    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
1319    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
1320    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
1321    #[derive(Debug, Deserialize, Serialize, Clone)]
1322    pub struct Schema {
1323        pub r#type: String,
1324        #[serde(skip_serializing_if = "Option::is_none")]
1325        pub format: Option<String>,
1326        #[serde(skip_serializing_if = "Option::is_none")]
1327        pub description: Option<String>,
1328        #[serde(skip_serializing_if = "Option::is_none")]
1329        pub nullable: Option<bool>,
1330        #[serde(skip_serializing_if = "Option::is_none")]
1331        pub r#enum: Option<Vec<String>>,
1332        #[serde(skip_serializing_if = "Option::is_none")]
1333        pub max_items: Option<i32>,
1334        #[serde(skip_serializing_if = "Option::is_none")]
1335        pub min_items: Option<i32>,
1336        #[serde(skip_serializing_if = "Option::is_none")]
1337        pub properties: Option<HashMap<String, Schema>>,
1338        #[serde(skip_serializing_if = "Option::is_none")]
1339        pub required: Option<Vec<String>>,
1340        #[serde(skip_serializing_if = "Option::is_none")]
1341        pub items: Option<Box<Schema>>,
1342    }
1343
1344    /// Flattens a JSON schema by resolving all `$ref` references inline.
1345    /// It takes a JSON schema that may contain `$ref` references to definitions
1346    /// in `$defs` or `definitions` sections and returns a new schema with all references
1347    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1348    /// schema references.
1349    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1350        // extracting $defs if they exist
1351        let defs = if let Some(obj) = schema.as_object() {
1352            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1353        } else {
1354            None
1355        };
1356
1357        let Some(defs_value) = defs else {
1358            return Ok(schema);
1359        };
1360
1361        let Some(defs_obj) = defs_value.as_object() else {
1362            return Err(CompletionError::ResponseError(
1363                "$defs must be an object".into(),
1364            ));
1365        };
1366
1367        resolve_refs(&mut schema, defs_obj)?;
1368
1369        // removing $defs from the final schema because we have inlined everything
1370        if let Some(obj) = schema.as_object_mut() {
1371            obj.remove("$defs");
1372            obj.remove("definitions");
1373        }
1374
1375        Ok(schema)
1376    }
1377
1378    /// Recursively resolves all `$ref` references in a JSON value by
1379    /// replacing them with their definitions.
1380    fn resolve_refs(
1381        value: &mut Value,
1382        defs: &serde_json::Map<String, Value>,
1383    ) -> Result<(), CompletionError> {
1384        match value {
1385            Value::Object(obj) => {
1386                if let Some(ref_value) = obj.get("$ref")
1387                    && let Some(ref_str) = ref_value.as_str()
1388                {
1389                    // "#/$defs/Person" -> "Person"
1390                    let def_name = parse_ref_path(ref_str)?;
1391
1392                    let def = defs.get(&def_name).ok_or_else(|| {
1393                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1394                    })?;
1395
1396                    let mut resolved = def.clone();
1397                    resolve_refs(&mut resolved, defs)?;
1398                    *value = resolved;
1399                    return Ok(());
1400                }
1401
1402                for (_, v) in obj.iter_mut() {
1403                    resolve_refs(v, defs)?;
1404                }
1405            }
1406            Value::Array(arr) => {
1407                for item in arr.iter_mut() {
1408                    resolve_refs(item, defs)?;
1409                }
1410            }
1411            _ => {}
1412        }
1413
1414        Ok(())
1415    }
1416
1417    /// Parses a JSON Schema `$ref` path to extract the definition name.
1418    ///
1419    /// JSON Schema references use URI fragment syntax to point to definitions within
1420    /// the same document. This function extracts the definition name from common
1421    /// reference patterns used in JSON Schema.
1422    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1423        if let Some(fragment) = ref_str.strip_prefix('#') {
1424            if let Some(name) = fragment.strip_prefix("/$defs/") {
1425                Ok(name.to_string())
1426            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1427                Ok(name.to_string())
1428            } else {
1429                Err(CompletionError::ResponseError(format!(
1430                    "Unsupported reference format: {}",
1431                    ref_str
1432                )))
1433            }
1434        } else {
1435            Err(CompletionError::ResponseError(format!(
1436                "Only fragment references (#/...) are supported: {}",
1437                ref_str
1438            )))
1439        }
1440    }
1441
1442    /// Helper function to extract the type string from a JSON value.
1443    /// Handles both direct string types and array types (returns the first element).
1444    fn extract_type(type_value: &Value) -> Option<String> {
1445        if type_value.is_string() {
1446            type_value.as_str().map(String::from)
1447        } else if type_value.is_array() {
1448            type_value
1449                .as_array()
1450                .and_then(|arr| arr.first())
1451                .and_then(|v| v.as_str().map(String::from))
1452        } else {
1453            None
1454        }
1455    }
1456
1457    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1458    /// Returns the type of the first non-null schema found.
1459    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1460        composition.as_array().and_then(|arr| {
1461            arr.iter().find_map(|schema| {
1462                if let Some(obj) = schema.as_object() {
1463                    // Skip null types
1464                    if let Some(type_val) = obj.get("type")
1465                        && let Some(type_str) = type_val.as_str()
1466                        && type_str == "null"
1467                    {
1468                        return None;
1469                    }
1470                    // Extract type from this schema
1471                    obj.get("type").and_then(extract_type).or_else(|| {
1472                        if obj.contains_key("properties") {
1473                            Some("object".to_string())
1474                        } else {
1475                            None
1476                        }
1477                    })
1478                } else {
1479                    None
1480                }
1481            })
1482        })
1483    }
1484
1485    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1486    /// Returns the schema object that should be used for properties, required, etc.
1487    fn extract_schema_from_composition(
1488        composition: &Value,
1489    ) -> Option<serde_json::Map<String, Value>> {
1490        composition.as_array().and_then(|arr| {
1491            arr.iter().find_map(|schema| {
1492                if let Some(obj) = schema.as_object()
1493                    && let Some(type_val) = obj.get("type")
1494                    && let Some(type_str) = type_val.as_str()
1495                {
1496                    if type_str == "null" {
1497                        return None;
1498                    }
1499                    Some(obj.clone())
1500                } else {
1501                    None
1502                }
1503            })
1504        })
1505    }
1506
1507    /// Helper function to infer the type of a schema object.
1508    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1509    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1510        // First, try direct type field
1511        if let Some(type_val) = obj.get("type")
1512            && let Some(type_str) = extract_type(type_val)
1513        {
1514            return type_str;
1515        }
1516
1517        // Then try anyOf, oneOf, allOf (in that order)
1518        if let Some(any_of) = obj.get("anyOf")
1519            && let Some(type_str) = extract_type_from_composition(any_of)
1520        {
1521            return type_str;
1522        }
1523
1524        if let Some(one_of) = obj.get("oneOf")
1525            && let Some(type_str) = extract_type_from_composition(one_of)
1526        {
1527            return type_str;
1528        }
1529
1530        if let Some(all_of) = obj.get("allOf")
1531            && let Some(type_str) = extract_type_from_composition(all_of)
1532        {
1533            return type_str;
1534        }
1535
1536        // Finally, infer object type if properties are present
1537        if obj.contains_key("properties") {
1538            "object".to_string()
1539        } else {
1540            String::new()
1541        }
1542    }
1543
1544    impl TryFrom<Value> for Schema {
1545        type Error = CompletionError;
1546
1547        fn try_from(value: Value) -> Result<Self, Self::Error> {
1548            let flattened_val = flatten_schema(value)?;
1549            if let Some(obj) = flattened_val.as_object() {
1550                // Determine which object to use for extracting properties and required fields.
1551                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1552                let props_source = if obj.get("properties").is_none() {
1553                    if let Some(any_of) = obj.get("anyOf") {
1554                        extract_schema_from_composition(any_of)
1555                    } else if let Some(one_of) = obj.get("oneOf") {
1556                        extract_schema_from_composition(one_of)
1557                    } else if let Some(all_of) = obj.get("allOf") {
1558                        extract_schema_from_composition(all_of)
1559                    } else {
1560                        None
1561                    }
1562                    .unwrap_or(obj.clone())
1563                } else {
1564                    obj.clone()
1565                };
1566
1567                Ok(Schema {
1568                    r#type: infer_type(obj),
1569                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1570                    description: obj
1571                        .get("description")
1572                        .and_then(|v| v.as_str())
1573                        .map(String::from),
1574                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1575                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1576                        arr.iter()
1577                            .filter_map(|v| v.as_str().map(String::from))
1578                            .collect()
1579                    }),
1580                    max_items: obj
1581                        .get("maxItems")
1582                        .and_then(|v| v.as_i64())
1583                        .map(|v| v as i32),
1584                    min_items: obj
1585                        .get("minItems")
1586                        .and_then(|v| v.as_i64())
1587                        .map(|v| v as i32),
1588                    properties: props_source
1589                        .get("properties")
1590                        .and_then(|v| v.as_object())
1591                        .map(|map| {
1592                            map.iter()
1593                                .filter_map(|(k, v)| {
1594                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1595                                })
1596                                .collect()
1597                        }),
1598                    required: props_source
1599                        .get("required")
1600                        .and_then(|v| v.as_array())
1601                        .map(|arr| {
1602                            arr.iter()
1603                                .filter_map(|v| v.as_str().map(String::from))
1604                                .collect()
1605                        }),
1606                    items: obj
1607                        .get("items")
1608                        .and_then(|v| v.clone().try_into().ok())
1609                        .map(Box::new),
1610                })
1611            } else {
1612                Err(CompletionError::ResponseError(
1613                    "Expected a JSON object for Schema".into(),
1614                ))
1615            }
1616        }
1617    }
1618
1619    #[derive(Debug, Serialize)]
1620    #[serde(rename_all = "camelCase")]
1621    pub struct GenerateContentRequest {
1622        pub contents: Vec<Content>,
1623        #[serde(skip_serializing_if = "Option::is_none")]
1624        pub tools: Option<Tool>,
1625        pub tool_config: Option<ToolConfig>,
1626        /// Optional. Configuration options for model generation and outputs.
1627        pub generation_config: Option<GenerationConfig>,
1628        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1629        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1630        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1631        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1632        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1633        /// will use the default safety setting for that category. Harm categories:
1634        ///     - HARM_CATEGORY_HATE_SPEECH,
1635        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1636        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1637        ///     - HARM_CATEGORY_HARASSMENT
1638        /// are supported.
1639        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1640        /// to learn how to incorporate safety considerations in your AI applications.
1641        pub safety_settings: Option<Vec<SafetySetting>>,
1642        /// Optional. Developer set system instruction(s). Currently, text only.
1643        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1644        pub system_instruction: Option<Content>,
1645        // cachedContent: Optional<String>
1646        /// Additional parameters.
1647        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1648        pub additional_params: Option<serde_json::Value>,
1649    }
1650
1651    #[derive(Debug, Serialize)]
1652    #[serde(rename_all = "camelCase")]
1653    pub struct Tool {
1654        pub function_declarations: Vec<FunctionDeclaration>,
1655        pub code_execution: Option<CodeExecution>,
1656    }
1657
1658    #[derive(Debug, Serialize, Clone)]
1659    #[serde(rename_all = "camelCase")]
1660    pub struct FunctionDeclaration {
1661        pub name: String,
1662        pub description: String,
1663        #[serde(skip_serializing_if = "Option::is_none")]
1664        pub parameters: Option<Schema>,
1665    }
1666
1667    #[derive(Debug, Serialize, Deserialize)]
1668    #[serde(rename_all = "camelCase")]
1669    pub struct ToolConfig {
1670        pub function_calling_config: Option<FunctionCallingMode>,
1671    }
1672
1673    #[derive(Debug, Serialize, Deserialize, Default)]
1674    #[serde(tag = "mode", rename_all = "UPPERCASE")]
1675    pub enum FunctionCallingMode {
1676        #[default]
1677        Auto,
1678        None,
1679        Any {
1680            #[serde(skip_serializing_if = "Option::is_none")]
1681            allowed_function_names: Option<Vec<String>>,
1682        },
1683    }
1684
1685    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1686        type Error = CompletionError;
1687        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1688            let res = match value {
1689                message::ToolChoice::Auto => Self::Auto,
1690                message::ToolChoice::None => Self::None,
1691                message::ToolChoice::Required => Self::Any {
1692                    allowed_function_names: None,
1693                },
1694                message::ToolChoice::Specific { function_names } => Self::Any {
1695                    allowed_function_names: Some(function_names),
1696                },
1697            };
1698
1699            Ok(res)
1700        }
1701    }
1702
    /// Marker for the code-execution tool, placed in `Tool::code_execution`.
    ///
    /// It has no fields. It is declared with braces rather than as a unit struct so that
    /// serde serializes it as an empty JSON object `{}` instead of `null`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1705
1706    #[derive(Debug, Serialize)]
1707    #[serde(rename_all = "camelCase")]
1708    pub struct SafetySetting {
1709        pub category: HarmCategory,
1710        pub threshold: HarmBlockThreshold,
1711    }
1712
1713    #[derive(Debug, Serialize)]
1714    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1715    pub enum HarmBlockThreshold {
1716        HarmBlockThresholdUnspecified,
1717        BlockLowAndAbove,
1718        BlockMediumAndAbove,
1719        BlockOnlyHigh,
1720        BlockNone,
1721        Off,
1722    }
1723}
1724
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A failed conversion must fail the test, not be printed and ignored.
        let schema: Schema = schema_with_ref
            .try_into()
            .expect("schema with $ref should convert");
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        let items = schema.items.expect("Schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // Missing items must fail the test instead of silently skipping the assertions.
        let items = schema
            .items
            .expect("flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flatten_schema_without_defs_is_unchanged() {
        let schema = json!({
            "type": "object",
            "properties": { "name": { "type": "string" } }
        });

        assert_eq!(flatten_schema(schema.clone()).unwrap(), schema);
    }

    #[test]
    fn test_flatten_schema_inlines_and_drops_defs() {
        let schema = json!({
            "type": "object",
            "properties": { "owner": { "$ref": "#/$defs/Person" } },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": { "name": { "type": "string" } }
                }
            }
        });

        let flattened = flatten_schema(schema).unwrap();
        assert!(flattened.get("$defs").is_none());
        assert_eq!(flattened["properties"]["owner"]["type"], "object");
        assert_eq!(
            flattened["properties"]["owner"]["properties"]["name"]["type"],
            "string"
        );
    }

    #[test]
    fn test_flatten_schema_missing_reference_is_error() {
        let schema = json!({
            "type": "object",
            "properties": { "friend": { "$ref": "#/$defs/Missing" } },
            "$defs": {}
        });

        assert!(flatten_schema(schema).is_err());
    }
}