Skip to main content

rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
// Gemini 2.5 Pro: preview and experimental snapshots.

/// Model id for the `gemini-2.5-pro-preview-06-05` preview snapshot.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model id for the `gemini-2.5-pro-preview-05-06` preview snapshot.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model id for the `gemini-2.5-pro-preview-03-25` preview snapshot.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model id for the experimental `gemini-2.5-pro-exp-03-25` snapshot.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";

// Gemini 2.5 Flash.

/// Model id for the `gemini-2.5-flash-preview-04-17` preview snapshot.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model id for `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";

// Gemini 2.0 Flash.

/// Model id for `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model id for `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27    AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32    OneOrMany,
33    completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37    Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
46// =================================================================
47// Rig Implementation Types
48// =================================================================
49
50#[derive(Clone, Debug)]
51pub struct CompletionModel<T = reqwest::Client> {
52    pub(crate) client: Client<T>,
53    pub model: String,
54}
55
56impl<T> CompletionModel<T> {
57    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58        Self {
59            client,
60            model: model.into(),
61        }
62    }
63
64    pub fn with_model(client: Client<T>, model: &str) -> Self {
65        Self {
66            client,
67            model: model.into(),
68        }
69    }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74    T: HttpClientExt + Clone + 'static,
75{
76    type Response = GenerateContentResponse;
77    type StreamingResponse = StreamingCompletionResponse;
78    type Client = super::Client<T>;
79
80    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81        Self::new(client.clone(), model)
82    }
83
84    async fn completion(
85        &self,
86        completion_request: CompletionRequest,
87    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88        let span = if tracing::Span::current().is_disabled() {
89            info_span!(
90                target: "rig::completions",
91                "generate_content",
92                gen_ai.operation.name = "generate_content",
93                gen_ai.provider.name = "gcp.gemini",
94                gen_ai.request.model = self.model,
95                gen_ai.system_instructions = &completion_request.preamble,
96                gen_ai.response.id = tracing::field::Empty,
97                gen_ai.response.model = tracing::field::Empty,
98                gen_ai.usage.output_tokens = tracing::field::Empty,
99                gen_ai.usage.input_tokens = tracing::field::Empty,
100            )
101        } else {
102            tracing::Span::current()
103        };
104
105        let request = create_request_body(completion_request)?;
106
107        if enabled!(Level::TRACE) {
108            tracing::trace!(
109                target: "rig::completions",
110                "Gemini completion request: {}",
111                serde_json::to_string_pretty(&request)?
112            );
113        }
114
115        let body = serde_json::to_vec(&request)?;
116
117        let path = format!("/v1beta/models/{}:generateContent", self.model);
118
119        let request = self
120            .client
121            .post(path.as_str())?
122            .body(body)
123            .map_err(|e| CompletionError::HttpError(e.into()))?;
124
125        async move {
126            let response = self.client.send::<_, Vec<u8>>(request).await?;
127
128            if response.status().is_success() {
129                let response_body = response
130                    .into_body()
131                    .await
132                    .map_err(CompletionError::HttpError)?;
133
134                let response_text = String::from_utf8_lossy(&response_body).to_string();
135
136                let response: GenerateContentResponse = serde_json::from_slice(&response_body)
137                    .map_err(|err| {
138                        tracing::error!(
139                            error = %err,
140                            body = %response_text,
141                            "Failed to deserialize Gemini completion response"
142                        );
143                        CompletionError::JsonError(err)
144                    })?;
145
146                let span = tracing::Span::current();
147                span.record_response_metadata(&response);
148                span.record_token_usage(&response.usage_metadata);
149
150                if enabled!(Level::TRACE) {
151                    tracing::trace!(
152                        target: "rig::completions",
153                        "Gemini completion response: {}",
154                        serde_json::to_string_pretty(&response)?
155                    );
156                }
157
158                response.try_into()
159            } else {
160                let text = String::from_utf8_lossy(
161                    &response
162                        .into_body()
163                        .await
164                        .map_err(CompletionError::HttpError)?,
165                )
166                .into();
167
168                Err(CompletionError::ProviderError(text))
169            }
170        }
171        .instrument(span)
172        .await
173    }
174
175    async fn stream(
176        &self,
177        request: CompletionRequest,
178    ) -> Result<
179        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
180        CompletionError,
181    > {
182        CompletionModel::stream(self, request).await
183    }
184}
185
186pub(crate) fn create_request_body(
187    completion_request: CompletionRequest,
188) -> Result<GenerateContentRequest, CompletionError> {
189    let mut full_history = Vec::new();
190    full_history.extend(completion_request.chat_history);
191
192    let additional_params = completion_request
193        .additional_params
194        .unwrap_or_else(|| Value::Object(Map::new()));
195
196    let AdditionalParameters {
197        mut generation_config,
198        additional_params,
199    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
200
201    generation_config = generation_config.map(|mut cfg| {
202        if let Some(temp) = completion_request.temperature {
203            cfg.temperature = Some(temp);
204        };
205
206        if let Some(max_tokens) = completion_request.max_tokens {
207            cfg.max_output_tokens = Some(max_tokens);
208        };
209
210        cfg
211    });
212
213    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
214        parts: vec![preamble.into()],
215        role: Some(Role::Model),
216    });
217
218    let tools = if completion_request.tools.is_empty() {
219        None
220    } else {
221        Some(vec![Tool::try_from(completion_request.tools)?])
222    };
223
224    let tool_config = if let Some(cfg) = completion_request.tool_choice {
225        Some(ToolConfig {
226            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
227        })
228    } else {
229        None
230    };
231
232    let request = GenerateContentRequest {
233        contents: full_history
234            .into_iter()
235            .map(|msg| {
236                msg.try_into()
237                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
238            })
239            .collect::<Result<Vec<_>, _>>()?,
240        generation_config,
241        safety_settings: None,
242        tools,
243        tool_config,
244        system_instruction,
245        additional_params,
246    };
247
248    Ok(request)
249}
250
251impl TryFrom<completion::ToolDefinition> for Tool {
252    type Error = CompletionError;
253
254    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
255        let parameters: Option<Schema> =
256            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
257                None
258            } else {
259                Some(tool.parameters.try_into()?)
260            };
261
262        Ok(Self {
263            function_declarations: vec![FunctionDeclaration {
264                name: tool.name,
265                description: tool.description,
266                parameters,
267            }],
268            code_execution: None,
269        })
270    }
271}
272
273impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
274    type Error = CompletionError;
275
276    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
277        let mut function_declarations = Vec::new();
278
279        for tool in tools {
280            let parameters =
281                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
282                    None
283                } else {
284                    match tool.parameters.try_into() {
285                        Ok(schema) => Some(schema),
286                        Err(e) => {
287                            let emsg = format!(
288                                "Tool '{}' could not be converted to a schema: {:?}",
289                                tool.name, e,
290                            );
291                            return Err(CompletionError::ProviderError(emsg));
292                        }
293                    }
294                };
295
296            function_declarations.push(FunctionDeclaration {
297                name: tool.name,
298                description: tool.description,
299                parameters,
300            });
301        }
302
303        Ok(Self {
304            function_declarations,
305            code_execution: None,
306        })
307    }
308}
309
310impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
311    type Error = CompletionError;
312
313    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
314        let candidate = response.candidates.first().ok_or_else(|| {
315            CompletionError::ResponseError("No response candidates in response".into())
316        })?;
317
318        let content = candidate
319            .content
320            .as_ref()
321            .ok_or_else(|| {
322                let reason = candidate
323                    .finish_reason
324                    .as_ref()
325                    .map(|r| format!("finish_reason={r:?}"))
326                    .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
327                let message = candidate
328                    .finish_message
329                    .as_deref()
330                    .unwrap_or("no finish message provided");
331                CompletionError::ResponseError(format!(
332                    "Gemini candidate missing content ({reason}, finish_message={message})"
333                ))
334            })?
335            .parts
336            .iter()
337            .map(
338                |Part {
339                     thought,
340                     thought_signature,
341                     part,
342                     ..
343                 }| {
344                    Ok(match part {
345                        PartKind::Text(text) => {
346                            if let Some(thought) = thought
347                                && *thought
348                            {
349                                completion::AssistantContent::Reasoning(Reasoning::new(text))
350                            } else {
351                                completion::AssistantContent::text(text)
352                            }
353                        }
354                        PartKind::InlineData(inline_data) => {
355                            let mime_type =
356                                message::MediaType::from_mime_type(&inline_data.mime_type);
357
358                            match mime_type {
359                                Some(message::MediaType::Image(media_type)) => {
360                                    message::AssistantContent::image_base64(
361                                        &inline_data.data,
362                                        Some(media_type),
363                                        Some(message::ImageDetail::default()),
364                                    )
365                                }
366                                _ => {
367                                    return Err(CompletionError::ResponseError(format!(
368                                        "Unsupported media type {mime_type:?}"
369                                    )));
370                                }
371                            }
372                        }
373                        PartKind::FunctionCall(function_call) => {
374                            completion::AssistantContent::ToolCall(
375                                message::ToolCall::new(
376                                    function_call.name.clone(),
377                                    message::ToolFunction::new(
378                                        function_call.name.clone(),
379                                        function_call.args.clone(),
380                                    ),
381                                )
382                                .with_signature(thought_signature.clone()),
383                            )
384                        }
385                        _ => {
386                            return Err(CompletionError::ResponseError(
387                                "Response did not contain a message or tool call".into(),
388                            ));
389                        }
390                    })
391                },
392            )
393            .collect::<Result<Vec<_>, _>>()?;
394
395        let choice = OneOrMany::many(content).map_err(|_| {
396            CompletionError::ResponseError(
397                "Response contained no message or tool call (empty)".to_owned(),
398            )
399        })?;
400
401        let usage = response
402            .usage_metadata
403            .as_ref()
404            .map(|usage| completion::Usage {
405                input_tokens: usage.prompt_token_count as u64,
406                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
407                total_tokens: usage.total_token_count as u64,
408                cached_input_tokens: 0,
409            })
410            .unwrap_or_default();
411
412        Ok(completion::CompletionResponse {
413            choice,
414            usage,
415            raw_response: response,
416        })
417    }
418}
419
420pub mod gemini_api_types {
421    use crate::telemetry::ProviderResponseExt;
422    use std::{collections::HashMap, convert::Infallible, str::FromStr};
423
424    // =================================================================
425    // Gemini API Types
426    // =================================================================
427    use serde::{Deserialize, Serialize};
428    use serde_json::{Value, json};
429
430    use crate::completion::GetTokenUsage;
431    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
432    use crate::{
433        completion::CompletionError,
434        message::{self},
435        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
436    };
437
438    #[derive(Debug, Deserialize, Serialize, Default)]
439    #[serde(rename_all = "camelCase")]
440    pub struct AdditionalParameters {
441        /// Change your Gemini request configuration.
442        pub generation_config: Option<GenerationConfig>,
443        /// Any additional parameters that you want.
444        #[serde(flatten, skip_serializing_if = "Option::is_none")]
445        pub additional_params: Option<serde_json::Value>,
446    }
447
448    impl AdditionalParameters {
449        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
450            self.generation_config = Some(cfg);
451            self
452        }
453
454        pub fn with_params(mut self, params: serde_json::Value) -> Self {
455            self.additional_params = Some(params);
456            self
457        }
458    }
459
460    /// Response from the model supporting multiple candidate responses.
461    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
462    /// and for each candidate in finishReason and in safetyRatings.
463    /// The API:
464    ///     - Returns either all requested candidates or none of them
465    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
466    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
467    #[derive(Debug, Deserialize, Serialize)]
468    #[serde(rename_all = "camelCase")]
469    pub struct GenerateContentResponse {
470        pub response_id: String,
471        /// Candidate responses from the model.
472        pub candidates: Vec<ContentCandidate>,
473        /// Returns the prompt's feedback related to the content filters.
474        pub prompt_feedback: Option<PromptFeedback>,
475        /// Output only. Metadata on the generation requests' token usage.
476        pub usage_metadata: Option<UsageMetadata>,
477        pub model_version: Option<String>,
478    }
479
480    impl ProviderResponseExt for GenerateContentResponse {
481        type OutputMessage = ContentCandidate;
482        type Usage = UsageMetadata;
483
484        fn get_response_id(&self) -> Option<String> {
485            Some(self.response_id.clone())
486        }
487
488        fn get_response_model_name(&self) -> Option<String> {
489            None
490        }
491
492        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
493            self.candidates.clone()
494        }
495
496        fn get_text_response(&self) -> Option<String> {
497            let str = self
498                .candidates
499                .iter()
500                .filter_map(|x| {
501                    let content = x.content.as_ref()?;
502                    if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
503                        return None;
504                    }
505
506                    let res = content
507                        .parts
508                        .iter()
509                        .filter_map(|part| {
510                            if let PartKind::Text(ref str) = part.part {
511                                Some(str.to_owned())
512                            } else {
513                                None
514                            }
515                        })
516                        .collect::<Vec<String>>()
517                        .join("\n");
518
519                    Some(res)
520                })
521                .collect::<Vec<String>>()
522                .join("\n");
523
524            if str.is_empty() { None } else { Some(str) }
525        }
526
527        fn get_usage(&self) -> Option<Self::Usage> {
528            self.usage_metadata.clone()
529        }
530    }
531
    /// A response candidate generated from the model.
    ///
    /// Every field is optional; absent JSON fields deserialize to `None`. Only `content` is
    /// skipped when serializing a `None` value.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        /// May be absent; the response conversion then reports `finish_reason` /
        /// `finish_message` in its error instead.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for the model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only. Presumably the average log-probability of the candidate's tokens
        /// (`avgLogprobs` in the API reference) — confirm against the API docs.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens.
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
        /// Output only. Additional information about why the model stopped generating tokens.
        pub finish_message: Option<String>,
    }
560
    /// One multi-part message in a Gemini exchange.
    ///
    /// Used for request history entries, the system instruction, and each candidate's
    /// generated output.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        /// Defaults to an empty list when the `parts` key is missing from the JSON.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        pub role: Option<Role>,
    }
570
571    impl TryFrom<message::Message> for Content {
572        type Error = message::MessageError;
573
574        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
575            Ok(match msg {
576                message::Message::User { content } => Content {
577                    parts: content
578                        .into_iter()
579                        .map(|c| c.try_into())
580                        .collect::<Result<Vec<_>, _>>()?,
581                    role: Some(Role::User),
582                },
583                message::Message::Assistant { content, .. } => Content {
584                    role: Some(Role::Model),
585                    parts: content
586                        .into_iter()
587                        .map(|content| content.try_into())
588                        .collect::<Result<Vec<_>, _>>()?,
589                },
590            })
591        }
592    }
593
    /// Producer of a [`Content`] message, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content supplied by the caller (user turns, including tool results).
        User,
        /// Content produced by the model (this file also uses it for the system instruction).
        Model,
    }
600
    /// A single piece of a [`Content`] message: the payload plus optional thinking metadata.
    ///
    /// The payload (`part`) is flattened, so its variant key (e.g. `text`, `functionCall`)
    /// appears at the top level of the JSON object next to `thought` / `thoughtSignature`.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether the part is reasoning/thinking text; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// An opaque signature for the thought so it can be reused; a base64 string.
        /// Omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The actual payload of this part (flattened into the same JSON object).
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra fields carried alongside the part (flattened); omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
615
    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    ///
    /// Serialized as an externally tagged enum with camelCase keys (e.g. `{"text": "..."}`,
    /// `{"functionCall": {...}}`). It is flattened into [`Part`].
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text (`text`).
        Text(String),
        /// Base64-encoded bytes with their MIME type (`inlineData`).
        InlineData(Blob),
        /// A function call requested by the model (`functionCall`).
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model (`functionResponse`).
        FunctionResponse(FunctionResponse),
        /// A reference to a file by URI (`fileData`).
        FileData(FileData),
        /// Executable code (`executableCode`).
        ExecutableCode(ExecutableCode),
        /// Result of executing code (`codeExecutionResult`).
        CodeExecutionResult(CodeExecutionResult),
    }
630
631    // This default instance is primarily so we can easily fill in the optional fields of `Part`
632    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
633    impl Default for PartKind {
634        fn default() -> Self {
635            Self::Text(String::new())
636        }
637    }
638
639    impl From<String> for Part {
640        fn from(text: String) -> Self {
641            Self {
642                thought: Some(false),
643                thought_signature: None,
644                part: PartKind::Text(text),
645                additional_params: None,
646            }
647        }
648    }
649
650    impl From<&str> for Part {
651        fn from(text: &str) -> Self {
652            Self::from(text.to_string())
653        }
654    }
655
656    impl FromStr for Part {
657        type Err = Infallible;
658
659        fn from_str(s: &str) -> Result<Self, Self::Err> {
660            Ok(s.into())
661        }
662    }
663
664    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
665        type Error = message::MessageError;
666        fn try_from(
667            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
668        ) -> Result<Self, Self::Error> {
669            let mime_type = mime_type.to_mime_type().to_string();
670            let part = match doc_src {
671                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
672                    mime_type: Some(mime_type),
673                    file_uri: url,
674                }),
675                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
676                    PartKind::InlineData(Blob { mime_type, data })
677                }
678                DocumentSourceKind::Raw(_) => {
679                    return Err(message::MessageError::ConversionError(
680                        "Raw files not supported, encode as base64 first".into(),
681                    ));
682                }
683                DocumentSourceKind::Unknown => {
684                    return Err(message::MessageError::ConversionError(
685                        "Can't convert an unknown document source".to_string(),
686                    ));
687                }
688            };
689
690            Ok(part)
691        }
692    }
693
694    impl TryFrom<message::UserContent> for Part {
695        type Error = message::MessageError;
696
697        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
698            match content {
699                message::UserContent::Text(message::Text { text }) => Ok(Part {
700                    thought: Some(false),
701                    thought_signature: None,
702                    part: PartKind::Text(text),
703                    additional_params: None,
704                }),
705                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
706                    let mut response_json: Option<serde_json::Value> = None;
707                    let mut parts: Vec<FunctionResponsePart> = Vec::new();
708
709                    for item in content.iter() {
710                        match item {
711                            message::ToolResultContent::Text(text) => {
712                                let result: serde_json::Value =
713                                    serde_json::from_str(&text.text).unwrap_or_else(|error| {
714                                        tracing::trace!(
715                                            ?error,
716                                            "Tool result is not a valid JSON, treat it as normal string"
717                                        );
718                                        json!(&text.text)
719                                    });
720
721                                response_json = Some(match response_json {
722                                    Some(mut existing) => {
723                                        if let serde_json::Value::Object(ref mut map) = existing {
724                                            map.insert("text".to_string(), result);
725                                        }
726                                        existing
727                                    }
728                                    None => json!({ "result": result }),
729                                });
730                            }
731                            message::ToolResultContent::Image(image) => {
732                                let part = match &image.data {
733                                    DocumentSourceKind::Base64(b64) => {
734                                        let mime_type = image
735                                            .media_type
736                                            .as_ref()
737                                            .ok_or(message::MessageError::ConversionError(
738                                                "Image media type is required for Gemini tool results".to_string(),
739                                            ))?
740                                            .to_mime_type();
741
742                                        FunctionResponsePart {
743                                            inline_data: Some(FunctionResponseInlineData {
744                                                mime_type: mime_type.to_string(),
745                                                data: b64.clone(),
746                                                display_name: None,
747                                            }),
748                                            file_data: None,
749                                        }
750                                    }
751                                    DocumentSourceKind::Url(url) => {
752                                        let mime_type = image
753                                            .media_type
754                                            .as_ref()
755                                            .map(|mt| mt.to_mime_type().to_string());
756
757                                        FunctionResponsePart {
758                                            inline_data: None,
759                                            file_data: Some(FileData {
760                                                mime_type,
761                                                file_uri: url.clone(),
762                                            }),
763                                        }
764                                    }
765                                    _ => {
766                                        return Err(message::MessageError::ConversionError(
767                                            "Unsupported image source kind for tool results"
768                                                .to_string(),
769                                        ));
770                                    }
771                                };
772                                parts.push(part);
773                            }
774                        }
775                    }
776
777                    Ok(Part {
778                        thought: Some(false),
779                        thought_signature: None,
780                        part: PartKind::FunctionResponse(FunctionResponse {
781                            name: id,
782                            response: response_json,
783                            parts: if parts.is_empty() { None } else { Some(parts) },
784                        }),
785                        additional_params: None,
786                    })
787                }
788                message::UserContent::Image(message::Image {
789                    data, media_type, ..
790                }) => match media_type {
791                    Some(media_type) => match media_type {
792                        message::ImageMediaType::JPEG
793                        | message::ImageMediaType::PNG
794                        | message::ImageMediaType::WEBP
795                        | message::ImageMediaType::HEIC
796                        | message::ImageMediaType::HEIF => {
797                            let part = PartKind::try_from((media_type, data))?;
798                            Ok(Part {
799                                thought: Some(false),
800                                thought_signature: None,
801                                part,
802                                additional_params: None,
803                            })
804                        }
805                        _ => Err(message::MessageError::ConversionError(format!(
806                            "Unsupported image media type {media_type:?}"
807                        ))),
808                    },
809                    None => Err(message::MessageError::ConversionError(
810                        "Media type for image is required for Gemini".to_string(),
811                    )),
812                },
813                message::UserContent::Document(message::Document {
814                    data, media_type, ..
815                }) => {
816                    let Some(media_type) = media_type else {
817                        return Err(MessageError::ConversionError(
818                            "A mime type is required for document inputs to Gemini".to_string(),
819                        ));
820                    };
821
822                    if !media_type.is_code() {
823                        let mime_type = media_type.to_mime_type().to_string();
824
825                        let part = match data {
826                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
827                                mime_type: Some(mime_type),
828                                file_uri,
829                            }),
830                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
831                                PartKind::InlineData(Blob { mime_type, data })
832                            }
833                            DocumentSourceKind::Raw(_) => {
834                                return Err(message::MessageError::ConversionError(
835                                    "Raw files not supported, encode as base64 first".into(),
836                                ));
837                            }
838                            _ => {
839                                return Err(message::MessageError::ConversionError(
840                                    "Document has no body".to_string(),
841                                ));
842                            }
843                        };
844
845                        Ok(Part {
846                            thought: Some(false),
847                            part,
848                            ..Default::default()
849                        })
850                    } else {
851                        Err(message::MessageError::ConversionError(format!(
852                            "Unsupported document media type {media_type:?}"
853                        )))
854                    }
855                }
856
857                message::UserContent::Audio(message::Audio {
858                    data, media_type, ..
859                }) => {
860                    let Some(media_type) = media_type else {
861                        return Err(MessageError::ConversionError(
862                            "A mime type is required for audio inputs to Gemini".to_string(),
863                        ));
864                    };
865
866                    let mime_type = media_type.to_mime_type().to_string();
867
868                    let part = match data {
869                        DocumentSourceKind::Base64(data) => {
870                            PartKind::InlineData(Blob { data, mime_type })
871                        }
872
873                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
874                            mime_type: Some(mime_type),
875                            file_uri,
876                        }),
877                        DocumentSourceKind::String(_) => {
878                            return Err(message::MessageError::ConversionError(
879                                "Strings cannot be used as audio files!".into(),
880                            ));
881                        }
882                        DocumentSourceKind::Raw(_) => {
883                            return Err(message::MessageError::ConversionError(
884                                "Raw files not supported, encode as base64 first".into(),
885                            ));
886                        }
887                        DocumentSourceKind::Unknown => {
888                            return Err(message::MessageError::ConversionError(
889                                "Content has no body".to_string(),
890                            ));
891                        }
892                    };
893
894                    Ok(Part {
895                        thought: Some(false),
896                        part,
897                        ..Default::default()
898                    })
899                }
900                message::UserContent::Video(message::Video {
901                    data,
902                    media_type,
903                    additional_params,
904                    ..
905                }) => {
906                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
907
908                    let part = match data {
909                        DocumentSourceKind::Url(file_uri) => {
910                            if file_uri.starts_with("https://www.youtube.com") {
911                                PartKind::FileData(FileData {
912                                    mime_type,
913                                    file_uri,
914                                })
915                            } else {
916                                if mime_type.is_none() {
917                                    return Err(MessageError::ConversionError(
918                                        "A mime type is required for non-Youtube video file inputs to Gemini"
919                                            .to_string(),
920                                    ));
921                                }
922
923                                PartKind::FileData(FileData {
924                                    mime_type,
925                                    file_uri,
926                                })
927                            }
928                        }
929                        DocumentSourceKind::Base64(data) => {
930                            let Some(mime_type) = mime_type else {
931                                return Err(MessageError::ConversionError(
932                                    "A media type is expected for base64 encoded strings"
933                                        .to_string(),
934                                ));
935                            };
936                            PartKind::InlineData(Blob { mime_type, data })
937                        }
938                        DocumentSourceKind::String(_) => {
939                            return Err(message::MessageError::ConversionError(
940                                "Strings cannot be used as audio files!".into(),
941                            ));
942                        }
943                        DocumentSourceKind::Raw(_) => {
944                            return Err(message::MessageError::ConversionError(
945                                "Raw file data not supported, encode as base64 first".into(),
946                            ));
947                        }
948                        DocumentSourceKind::Unknown => {
949                            return Err(message::MessageError::ConversionError(
950                                "Media type for video is required for Gemini".to_string(),
951                            ));
952                        }
953                    };
954
955                    Ok(Part {
956                        thought: Some(false),
957                        thought_signature: None,
958                        part,
959                        additional_params,
960                    })
961                }
962            }
963        }
964    }
965
966    impl TryFrom<message::AssistantContent> for Part {
967        type Error = message::MessageError;
968
969        fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
970            match content {
971                message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
972                message::AssistantContent::Image(message::Image {
973                    data, media_type, ..
974                }) => match media_type {
975                    Some(media_type) => match media_type {
976                        message::ImageMediaType::JPEG
977                        | message::ImageMediaType::PNG
978                        | message::ImageMediaType::WEBP
979                        | message::ImageMediaType::HEIC
980                        | message::ImageMediaType::HEIF => {
981                            let part = PartKind::try_from((media_type, data))?;
982                            Ok(Part {
983                                thought: Some(false),
984                                thought_signature: None,
985                                part,
986                                additional_params: None,
987                            })
988                        }
989                        _ => Err(message::MessageError::ConversionError(format!(
990                            "Unsupported image media type {media_type:?}"
991                        ))),
992                    },
993                    None => Err(message::MessageError::ConversionError(
994                        "Media type for image is required for Gemini".to_string(),
995                    )),
996                },
997                message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
998                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
999                    Ok(Part {
1000                        thought: Some(true),
1001                        thought_signature: None,
1002                        part: PartKind::Text(
1003                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
1004                        ),
1005                        additional_params: None,
1006                    })
1007                }
1008            }
1009        }
1010    }
1011
1012    impl From<message::ToolCall> for Part {
1013        fn from(tool_call: message::ToolCall) -> Self {
1014            Self {
1015                thought: Some(false),
1016                thought_signature: tool_call.signature,
1017                part: PartKind::FunctionCall(FunctionCall {
1018                    name: tool_call.function.name,
1019                    args: tool_call.function.arguments,
1020                }),
1021                additional_params: None,
1022            }
1023        }
1024    }
1025
    /// Raw media bytes, used as the payload of `PartKind::InlineData`.
    ///
    /// Text should not be sent as raw bytes; use a text part instead.
    /// Serializes with camelCase keys: `{"mimeType": ..., "data": ...}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }
1037
1038    /// A predicted FunctionCall returned from the model that contains a string representing the
1039    /// FunctionDeclaration.name with the arguments and their values.
1040    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1041    pub struct FunctionCall {
1042        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
1043        /// and dashes, with a maximum length of 63.
1044        pub name: String,
1045        /// Optional. The function parameters and values in JSON object format.
1046        pub args: serde_json::Value,
1047    }
1048
1049    impl From<message::ToolCall> for FunctionCall {
1050        fn from(tool_call: message::ToolCall) -> Self {
1051            Self {
1052                name: tool_call.function.name,
1053                args: tool_call.function.arguments,
1054            }
1055        }
1056    }
1057
    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    ///
    /// `response` and `parts` are omitted from the serialized JSON when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        /// Multimodal parts for the function response (e.g., images).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }
1073
    /// A part of a multimodal function response.
    ///
    /// The tool-result conversion in this module sets exactly one of `inline_data` (base64
    /// content) or `file_data` (URI reference). The type itself does not enforce this.
    /// Unset fields are omitted from the serialized JSON.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        /// Inline data containing base64-encoded media content.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        /// File data containing a URI reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }
1085
    /// Inline data for function response parts.
    ///
    /// Same shape as `Blob`, plus an optional `displayName` that is omitted when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// The IANA standard MIME type of the source data.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
        /// Optional display name for the content.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1098
    /// URI-based data: media referenced by location instead of being sent inline.
    ///
    /// Serializes as `{"mimeType": ..., "fileUri": ...}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        // NOTE(review): unlike the other optional fields in this module, this has no
        // `skip_serializing_if`, so `None` is sent as `"mimeType": null`. Confirm this is intended.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }
1108
    /// Safety rating for a piece of content: a harm category and the probability of harm in it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The category this rating applies to.
        pub category: HarmCategory,
        /// The probability of harm for this category.
        pub probability: HarmProbability,
    }
1114
    /// Probability that content is harmful, as reported in a `SafetyRating`.
    ///
    /// Variants (de)serialize as SCREAMING_SNAKE_CASE strings (e.g. `NEGLIGIBLE`). There is no
    /// catch-all variant, so a value not listed here fails deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        /// Probability is unspecified.
        HarmProbabilityUnspecified,
        /// Negligible probability of harm.
        Negligible,
        /// Low probability of harm.
        Low,
        /// Medium probability of harm.
        Medium,
        /// High probability of harm.
        High,
    }
1124
    /// Category of harm that a `SafetyRating` refers to.
    ///
    /// Variants (de)serialize as SCREAMING_SNAKE_CASE strings (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    /// There is no catch-all variant, so a category not listed here fails deserialization.
    // NOTE(review): confirm this list against the current API reference; a newly added category
    // would make the whole response fail to parse.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1141
    /// Token accounting reported by the Gemini API for a request.
    ///
    /// `prompt_token_count` and `total_token_count` are required when deserializing. The other
    /// counts are optional and omitted from serialized output when `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Number of tokens in the cached part of the prompt (`cachedContentTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Number of tokens across the generated candidates (`candidatesTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count reported by the API (`totalTokenCount`).
        pub total_token_count: i32,
        /// Number of tokens spent on thinking, for thinking models (`thoughtsTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1154
1155    impl std::fmt::Display for UsageMetadata {
1156        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1157            write!(
1158                f,
1159                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1160                self.prompt_token_count,
1161                match self.cached_content_token_count {
1162                    Some(count) => count.to_string(),
1163                    None => "n/a".to_string(),
1164                },
1165                match self.candidates_token_count {
1166                    Some(count) => count.to_string(),
1167                    None => "n/a".to_string(),
1168                },
1169                self.total_token_count
1170            )
1171        }
1172    }
1173
1174    impl GetTokenUsage for UsageMetadata {
1175        fn token_usage(&self) -> Option<crate::completion::Usage> {
1176            let mut usage = crate::completion::Usage::new();
1177
1178            usage.input_tokens = self.prompt_token_count as u64;
1179            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1180                + self.candidates_token_count.unwrap_or_default()
1181                + self.thoughts_token_count.unwrap_or_default())
1182                as u64;
1183            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1184
1185            Some(usage)
1186        }
1187    }
1188
    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    ///
    /// Both fields are optional; a missing field deserializes as `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1198
    /// Reason why a prompt was blocked by the model.
    ///
    /// Variants (de)serialize as SCREAMING_SNAKE_CASE strings (e.g. `PROHIBITED_CONTENT`). There
    /// is no catch-all variant, so a reason not listed here fails deserialization.
    // NOTE(review): confirm this list against the current API reference.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
1214
    /// Why the model stopped generating a candidate.
    ///
    /// Variants (de)serialize as SCREAMING_SNAKE_CASE strings (e.g. `MAX_TOKENS`). There is no
    /// catch-all variant, so a finish reason not listed here fails deserialization of the response.
    // NOTE(review): confirm this list against the current API reference; newer finish reasons
    // would make the whole response fail to parse.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
1241
    /// Source attributions for a candidate's content (`citationSources`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Citations to sources for a specific response.
        pub citation_sources: Vec<CitationSource>,
    }
1247
    /// A single citation to a source for part of a response.
    ///
    /// Every field is optional and omitted from serialized output when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the response (`startIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment of the response (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source, when reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1260
1261    #[derive(Clone, Debug, Deserialize, Serialize)]
1262    #[serde(rename_all = "camelCase")]
1263    pub struct LogprobsResult {
1264        pub top_candidate: Vec<TopCandidate>,
1265        pub chosen_candidate: Vec<LogProbCandidate>,
1266    }
1267
    /// The candidate tokens with the highest log probabilities at a single decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens for this step.
        pub candidates: Vec<LogProbCandidate>,
    }
1272
1273    #[derive(Clone, Debug, Deserialize, Serialize)]
1274    #[serde(rename_all = "camelCase")]
1275    pub struct LogProbCandidate {
1276        pub token: String,
1277        pub token_id: String,
1278        pub log_probability: f64,
1279    }
1280
    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Every field is optional and omitted from the serialized request when `None`.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain:  (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Optional. The output schema of the generated response.
        /// This is an alternative to responseSchema that accepts a standard JSON Schema.
        /// If this is set, responseSchema must be omitted.
        /// Compatible MIME type: application/json.
        /// Supported properties: $id, $defs, $ref, type, properties, etc.
        /// Serialized as `_responseJsonSchema`.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Standard JSON Schema for the output, serialized as `responseJsonSchema`. It is a
        /// separate wire field from `_response_json_schema` above.
        // NOTE(review): presumably at most one of the two JSON-schema fields should be set; confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: The more a token is used, the more
        /// difficult it is for the model to use that token again increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// `Candidate.logprobs_result`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image generation options (aspect ratio and size); see `ImageConfig`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1373
1374    impl Default for GenerationConfig {
1375        fn default() -> Self {
1376            Self {
1377                temperature: Some(1.0),
1378                max_output_tokens: Some(4096),
1379                stop_sequences: None,
1380                response_mime_type: None,
1381                response_schema: None,
1382                _response_json_schema: None,
1383                response_json_schema: None,
1384                candidate_count: None,
1385                top_p: None,
1386                top_k: None,
1387                presence_penalty: None,
1388                frequency_penalty: None,
1389                response_logprobs: None,
1390                logprobs: None,
1391                thinking_config: None,
1392                image_config: None,
1393            }
1394        }
1395    }
1396
1397    #[derive(Debug, Deserialize, Serialize)]
1398    #[serde(rename_all = "camelCase")]
1399    pub struct ThinkingConfig {
1400        pub thinking_budget: u32,
1401        pub include_thoughts: Option<bool>,
1402    }
1403
    /// Image-output options, carried by `GenerationConfig::image_config`.
    ///
    /// Both fields are plain strings passed through without validation; each is omitted
    /// from the serialized JSON when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio (serialized as `aspectRatio`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size (serialized as `imageSize`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1412
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// Usually built from a JSON Schema value through `TryFrom<Value>`, which first inlines
    /// all `$ref`s with [`flatten_schema`].
    ///
    /// NOTE(review): unlike neighbouring request types (e.g. `FunctionDeclaration`,
    /// `ToolConfig`) this struct has no `rename_all = "camelCase"`, so `max_items` and
    /// `min_items` serialize in snake_case, while the Gemini reference spells them
    /// `maxItems`/`minItems`. Confirm the API accepts both spellings.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name such as `"object"`, `"array"` or `"string"`. `TryFrom<Value>` leaves it
        /// empty when no type can be inferred.
        pub r#type: String,
        /// Optional format hint for the type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description of the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether `null` is an accepted value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Property schemas of an object type, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of the elements of an array type (boxed because `Schema` is recursive).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1438
1439    /// Flattens a JSON schema by resolving all `$ref` references inline.
1440    /// It takes a JSON schema that may contain `$ref` references to definitions
1441    /// in `$defs` or `definitions` sections and returns a new schema with all references
1442    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1443    /// schema references.
1444    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1445        // extracting $defs if they exist
1446        let defs = if let Some(obj) = schema.as_object() {
1447            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1448        } else {
1449            None
1450        };
1451
1452        let Some(defs_value) = defs else {
1453            return Ok(schema);
1454        };
1455
1456        let Some(defs_obj) = defs_value.as_object() else {
1457            return Err(CompletionError::ResponseError(
1458                "$defs must be an object".into(),
1459            ));
1460        };
1461
1462        resolve_refs(&mut schema, defs_obj)?;
1463
1464        // removing $defs from the final schema because we have inlined everything
1465        if let Some(obj) = schema.as_object_mut() {
1466            obj.remove("$defs");
1467            obj.remove("definitions");
1468        }
1469
1470        Ok(schema)
1471    }
1472
1473    /// Recursively resolves all `$ref` references in a JSON value by
1474    /// replacing them with their definitions.
1475    fn resolve_refs(
1476        value: &mut Value,
1477        defs: &serde_json::Map<String, Value>,
1478    ) -> Result<(), CompletionError> {
1479        match value {
1480            Value::Object(obj) => {
1481                if let Some(ref_value) = obj.get("$ref")
1482                    && let Some(ref_str) = ref_value.as_str()
1483                {
1484                    // "#/$defs/Person" -> "Person"
1485                    let def_name = parse_ref_path(ref_str)?;
1486
1487                    let def = defs.get(&def_name).ok_or_else(|| {
1488                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1489                    })?;
1490
1491                    let mut resolved = def.clone();
1492                    resolve_refs(&mut resolved, defs)?;
1493                    *value = resolved;
1494                    return Ok(());
1495                }
1496
1497                for (_, v) in obj.iter_mut() {
1498                    resolve_refs(v, defs)?;
1499                }
1500            }
1501            Value::Array(arr) => {
1502                for item in arr.iter_mut() {
1503                    resolve_refs(item, defs)?;
1504                }
1505            }
1506            _ => {}
1507        }
1508
1509        Ok(())
1510    }
1511
1512    /// Parses a JSON Schema `$ref` path to extract the definition name.
1513    ///
1514    /// JSON Schema references use URI fragment syntax to point to definitions within
1515    /// the same document. This function extracts the definition name from common
1516    /// reference patterns used in JSON Schema.
1517    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1518        if let Some(fragment) = ref_str.strip_prefix('#') {
1519            if let Some(name) = fragment.strip_prefix("/$defs/") {
1520                Ok(name.to_string())
1521            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1522                Ok(name.to_string())
1523            } else {
1524                Err(CompletionError::ResponseError(format!(
1525                    "Unsupported reference format: {}",
1526                    ref_str
1527                )))
1528            }
1529        } else {
1530            Err(CompletionError::ResponseError(format!(
1531                "Only fragment references (#/...) are supported: {}",
1532                ref_str
1533            )))
1534        }
1535    }
1536
1537    /// Helper function to extract the type string from a JSON value.
1538    /// Handles both direct string types and array types (returns the first element).
1539    fn extract_type(type_value: &Value) -> Option<String> {
1540        if type_value.is_string() {
1541            type_value.as_str().map(String::from)
1542        } else if type_value.is_array() {
1543            type_value
1544                .as_array()
1545                .and_then(|arr| arr.first())
1546                .and_then(|v| v.as_str().map(String::from))
1547        } else {
1548            None
1549        }
1550    }
1551
1552    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1553    /// Returns the type of the first non-null schema found.
1554    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1555        composition.as_array().and_then(|arr| {
1556            arr.iter().find_map(|schema| {
1557                if let Some(obj) = schema.as_object() {
1558                    // Skip null types
1559                    if let Some(type_val) = obj.get("type")
1560                        && let Some(type_str) = type_val.as_str()
1561                        && type_str == "null"
1562                    {
1563                        return None;
1564                    }
1565                    // Extract type from this schema
1566                    obj.get("type").and_then(extract_type).or_else(|| {
1567                        if obj.contains_key("properties") {
1568                            Some("object".to_string())
1569                        } else {
1570                            None
1571                        }
1572                    })
1573                } else {
1574                    None
1575                }
1576            })
1577        })
1578    }
1579
1580    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1581    /// Returns the schema object that should be used for properties, required, etc.
1582    fn extract_schema_from_composition(
1583        composition: &Value,
1584    ) -> Option<serde_json::Map<String, Value>> {
1585        composition.as_array().and_then(|arr| {
1586            arr.iter().find_map(|schema| {
1587                if let Some(obj) = schema.as_object()
1588                    && let Some(type_val) = obj.get("type")
1589                    && let Some(type_str) = type_val.as_str()
1590                {
1591                    if type_str == "null" {
1592                        return None;
1593                    }
1594                    Some(obj.clone())
1595                } else {
1596                    None
1597                }
1598            })
1599        })
1600    }
1601
1602    /// Helper function to infer the type of a schema object.
1603    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1604    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1605        // First, try direct type field
1606        if let Some(type_val) = obj.get("type")
1607            && let Some(type_str) = extract_type(type_val)
1608        {
1609            return type_str;
1610        }
1611
1612        // Then try anyOf, oneOf, allOf (in that order)
1613        if let Some(any_of) = obj.get("anyOf")
1614            && let Some(type_str) = extract_type_from_composition(any_of)
1615        {
1616            return type_str;
1617        }
1618
1619        if let Some(one_of) = obj.get("oneOf")
1620            && let Some(type_str) = extract_type_from_composition(one_of)
1621        {
1622            return type_str;
1623        }
1624
1625        if let Some(all_of) = obj.get("allOf")
1626            && let Some(type_str) = extract_type_from_composition(all_of)
1627        {
1628            return type_str;
1629        }
1630
1631        // Finally, infer object type if properties are present
1632        if obj.contains_key("properties") {
1633            "object".to_string()
1634        } else {
1635            String::new()
1636        }
1637    }
1638
1639    impl TryFrom<Value> for Schema {
1640        type Error = CompletionError;
1641
1642        fn try_from(value: Value) -> Result<Self, Self::Error> {
1643            let flattened_val = flatten_schema(value)?;
1644            if let Some(obj) = flattened_val.as_object() {
1645                // Determine which object to use for extracting properties and required fields.
1646                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1647                let props_source = if obj.get("properties").is_none() {
1648                    if let Some(any_of) = obj.get("anyOf") {
1649                        extract_schema_from_composition(any_of)
1650                    } else if let Some(one_of) = obj.get("oneOf") {
1651                        extract_schema_from_composition(one_of)
1652                    } else if let Some(all_of) = obj.get("allOf") {
1653                        extract_schema_from_composition(all_of)
1654                    } else {
1655                        None
1656                    }
1657                    .unwrap_or(obj.clone())
1658                } else {
1659                    obj.clone()
1660                };
1661
1662                Ok(Schema {
1663                    r#type: infer_type(obj),
1664                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1665                    description: obj
1666                        .get("description")
1667                        .and_then(|v| v.as_str())
1668                        .map(String::from),
1669                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1670                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1671                        arr.iter()
1672                            .filter_map(|v| v.as_str().map(String::from))
1673                            .collect()
1674                    }),
1675                    max_items: obj
1676                        .get("maxItems")
1677                        .and_then(|v| v.as_i64())
1678                        .map(|v| v as i32),
1679                    min_items: obj
1680                        .get("minItems")
1681                        .and_then(|v| v.as_i64())
1682                        .map(|v| v as i32),
1683                    properties: props_source
1684                        .get("properties")
1685                        .and_then(|v| v.as_object())
1686                        .map(|map| {
1687                            map.iter()
1688                                .filter_map(|(k, v)| {
1689                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1690                                })
1691                                .collect()
1692                        }),
1693                    required: props_source
1694                        .get("required")
1695                        .and_then(|v| v.as_array())
1696                        .map(|arr| {
1697                            arr.iter()
1698                                .filter_map(|v| v.as_str().map(String::from))
1699                                .collect()
1700                        }),
1701                    items: obj
1702                        .get("items")
1703                        .and_then(|v| v.clone().try_into().ok())
1704                        .map(Box::new),
1705                })
1706            } else {
1707                Err(CompletionError::ResponseError(
1708                    "Expected a JSON object for Schema".into(),
1709                ))
1710            }
1711        }
1712    }
1713
1714    #[derive(Debug, Serialize)]
1715    #[serde(rename_all = "camelCase")]
1716    pub struct GenerateContentRequest {
1717        pub contents: Vec<Content>,
1718        #[serde(skip_serializing_if = "Option::is_none")]
1719        pub tools: Option<Vec<Tool>>,
1720        pub tool_config: Option<ToolConfig>,
1721        /// Optional. Configuration options for model generation and outputs.
1722        pub generation_config: Option<GenerationConfig>,
1723        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1724        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1725        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1726        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1727        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1728        /// will use the default safety setting for that category. Harm categories:
1729        ///     - HARM_CATEGORY_HATE_SPEECH,
1730        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1731        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1732        ///     - HARM_CATEGORY_HARASSMENT
1733        /// are supported.
1734        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1735        /// to learn how to incorporate safety considerations in your AI applications.
1736        pub safety_settings: Option<Vec<SafetySetting>>,
1737        /// Optional. Developer set system instruction(s). Currently, text only.
1738        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1739        pub system_instruction: Option<Content>,
1740        // cachedContent: Optional<String>
1741        /// Additional parameters.
1742        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1743        pub additional_params: Option<serde_json::Value>,
1744    }
1745
1746    #[derive(Debug, Serialize)]
1747    #[serde(rename_all = "camelCase")]
1748    pub struct Tool {
1749        pub function_declarations: Vec<FunctionDeclaration>,
1750        pub code_execution: Option<CodeExecution>,
1751    }
1752
    /// Declaration of a function the model may call, listed in
    /// `Tool::function_declarations`.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Name of the function.
        pub name: String,
        /// Description of the function, sent to the model.
        pub description: String,
        /// Schema of the function's parameters; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1761
1762    #[derive(Debug, Serialize, Deserialize)]
1763    #[serde(rename_all = "camelCase")]
1764    pub struct ToolConfig {
1765        pub function_calling_config: Option<FunctionCallingMode>,
1766    }
1767
    /// Gemini function-calling mode.
    ///
    /// Serialized as an object whose `mode` key holds the upper-cased variant name
    /// (`"AUTO"`, `"NONE"` or `"ANY"`). The fields of `Any` are written next to that key.
    ///
    /// NOTE(review): `rename_all` on an enum renames variants only, so
    /// `allowed_function_names` is serialized in snake_case, while the Gemini reference
    /// spells it `allowedFunctionNames`. Confirm the API accepts the snake_case form.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// The default mode; `ToolChoice::Auto` maps here.
        #[default]
        Auto,
        /// `ToolChoice::None` maps here.
        None,
        /// `ToolChoice::Required` maps here without a name list; `ToolChoice::Specific`
        /// maps here with one.
        Any {
            /// Function names to restrict to when `Some`; omitted from the JSON when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1779
1780    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1781        type Error = CompletionError;
1782        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1783            let res = match value {
1784                message::ToolChoice::Auto => Self::Auto,
1785                message::ToolChoice::None => Self::None,
1786                message::ToolChoice::Required => Self::Any {
1787                    allowed_function_names: None,
1788                },
1789                message::ToolChoice::Specific { function_names } => Self::Any {
1790                    allowed_function_names: Some(function_names),
1791                },
1792            };
1793
1794            Ok(res)
1795        }
1796    }
1797
    /// Empty marker used in `Tool::code_execution`; serializes as `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1800
    /// Blocking threshold for one harm category, listed in
    /// `GenerateContentRequest::safety_settings`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Harm category this setting applies to.
        pub category: HarmCategory,
        /// Threshold applied to `category`.
        pub threshold: HarmBlockThreshold,
    }
1807
    /// Blocking threshold of a [`SafetySetting`].
    ///
    /// Variants are serialized in SCREAMING_SNAKE_CASE, e.g. `BlockMediumAndAbove` becomes
    /// `"BLOCK_MEDIUM_AND_ABOVE"`. Exact filtering semantics are defined by the Gemini
    /// API. NOTE(review): see its safety-settings reference.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// `"HARM_BLOCK_THRESHOLD_UNSPECIFIED"`
        HarmBlockThresholdUnspecified,
        /// `"BLOCK_LOW_AND_ABOVE"`
        BlockLowAndAbove,
        /// `"BLOCK_MEDIUM_AND_ABOVE"`
        BlockMediumAndAbove,
        /// `"BLOCK_ONLY_HIGH"`
        BlockOnlyHigh,
        /// `"BLOCK_NONE"`
        BlockNone,
        /// `"OFF"`
        Off,
    }
1818}
1819
1820#[cfg(test)]
1821mod tests {
1822    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1823
1824    use super::*;
1825    use serde_json::json;
1826
1827    #[test]
1828    fn test_deserialize_message_user() {
1829        let raw_message = r#"{
1830            "parts": [
1831                {"text": "Hello, world!"},
1832                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1833                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1834                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1835                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1836                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1837                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1838            ],
1839            "role": "user"
1840        }"#;
1841
1842        let content: Content = {
1843            let jd = &mut serde_json::Deserializer::from_str(raw_message);
1844            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1845                panic!("Deserialization error at {}: {}", err.path(), err);
1846            })
1847        };
1848        assert_eq!(content.role, Some(Role::User));
1849        assert_eq!(content.parts.len(), 7);
1850
1851        let parts: Vec<Part> = content.parts.into_iter().collect();
1852
1853        if let Part {
1854            part: PartKind::Text(text),
1855            ..
1856        } = &parts[0]
1857        {
1858            assert_eq!(text, "Hello, world!");
1859        } else {
1860            panic!("Expected text part");
1861        }
1862
1863        if let Part {
1864            part: PartKind::InlineData(inline_data),
1865            ..
1866        } = &parts[1]
1867        {
1868            assert_eq!(inline_data.mime_type, "image/png");
1869            assert_eq!(inline_data.data, "base64encodeddata");
1870        } else {
1871            panic!("Expected inline data part");
1872        }
1873
1874        if let Part {
1875            part: PartKind::FunctionCall(function_call),
1876            ..
1877        } = &parts[2]
1878        {
1879            assert_eq!(function_call.name, "test_function");
1880            assert_eq!(
1881                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1882                "value1"
1883            );
1884        } else {
1885            panic!("Expected function call part");
1886        }
1887
1888        if let Part {
1889            part: PartKind::FunctionResponse(function_response),
1890            ..
1891        } = &parts[3]
1892        {
1893            assert_eq!(function_response.name, "test_function");
1894            assert_eq!(
1895                function_response
1896                    .response
1897                    .as_ref()
1898                    .unwrap()
1899                    .get("result")
1900                    .unwrap(),
1901                "success"
1902            );
1903        } else {
1904            panic!("Expected function response part");
1905        }
1906
1907        if let Part {
1908            part: PartKind::FileData(file_data),
1909            ..
1910        } = &parts[4]
1911        {
1912            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1913            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1914        } else {
1915            panic!("Expected file data part");
1916        }
1917
1918        if let Part {
1919            part: PartKind::ExecutableCode(executable_code),
1920            ..
1921        } = &parts[5]
1922        {
1923            assert_eq!(executable_code.code, "print('Hello, world!')");
1924        } else {
1925            panic!("Expected executable code part");
1926        }
1927
1928        if let Part {
1929            part: PartKind::CodeExecutionResult(code_execution_result),
1930            ..
1931        } = &parts[6]
1932        {
1933            assert_eq!(
1934                code_execution_result.clone().output.unwrap(),
1935                "Hello, world!"
1936            );
1937        } else {
1938            panic!("Expected code execution result part");
1939        }
1940    }
1941
1942    #[test]
1943    fn test_deserialize_message_model() {
1944        let json_data = json!({
1945            "parts": [{"text": "Hello, user!"}],
1946            "role": "model"
1947        });
1948
1949        let content: Content = serde_json::from_value(json_data).unwrap();
1950        assert_eq!(content.role, Some(Role::Model));
1951        assert_eq!(content.parts.len(), 1);
1952        if let Some(Part {
1953            part: PartKind::Text(text),
1954            ..
1955        }) = content.parts.first()
1956        {
1957            assert_eq!(text, "Hello, user!");
1958        } else {
1959            panic!("Expected text part");
1960        }
1961    }
1962
1963    #[test]
1964    fn test_message_conversion_user() {
1965        let msg = message::Message::user("Hello, world!");
1966        let content: Content = msg.try_into().unwrap();
1967        assert_eq!(content.role, Some(Role::User));
1968        assert_eq!(content.parts.len(), 1);
1969        if let Some(Part {
1970            part: PartKind::Text(text),
1971            ..
1972        }) = &content.parts.first()
1973        {
1974            assert_eq!(text, "Hello, world!");
1975        } else {
1976            panic!("Expected text part");
1977        }
1978    }
1979
1980    #[test]
1981    fn test_message_conversion_model() {
1982        let msg = message::Message::assistant("Hello, user!");
1983
1984        let content: Content = msg.try_into().unwrap();
1985        assert_eq!(content.role, Some(Role::Model));
1986        assert_eq!(content.parts.len(), 1);
1987        if let Some(Part {
1988            part: PartKind::Text(text),
1989            ..
1990        }) = &content.parts.first()
1991        {
1992            assert_eq!(text, "Hello, user!");
1993        } else {
1994            panic!("Expected text part");
1995        }
1996    }
1997
1998    #[test]
1999    fn test_message_conversion_tool_call() {
2000        let tool_call = message::ToolCall {
2001            id: "test_tool".to_string(),
2002            call_id: None,
2003            function: message::ToolFunction {
2004                name: "test_function".to_string(),
2005                arguments: json!({"arg1": "value1"}),
2006            },
2007            signature: None,
2008            additional_params: None,
2009        };
2010
2011        let msg = message::Message::Assistant {
2012            id: None,
2013            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2014        };
2015
2016        let content: Content = msg.try_into().unwrap();
2017        assert_eq!(content.role, Some(Role::Model));
2018        assert_eq!(content.parts.len(), 1);
2019        if let Some(Part {
2020            part: PartKind::FunctionCall(function_call),
2021            ..
2022        }) = content.parts.first()
2023        {
2024            assert_eq!(function_call.name, "test_function");
2025            assert_eq!(
2026                function_call.args.as_object().unwrap().get("arg1").unwrap(),
2027                "value1"
2028            );
2029        } else {
2030            panic!("Expected function call part");
2031        }
2032    }
2033
2034    #[test]
2035    fn test_vec_schema_conversion() {
2036        let schema_with_ref = json!({
2037            "type": "array",
2038            "items": {
2039                "$ref": "#/$defs/Person"
2040            },
2041            "$defs": {
2042                "Person": {
2043                    "type": "object",
2044                    "properties": {
2045                        "first_name": {
2046                            "type": ["string", "null"],
2047                            "description": "The person's first name, if provided (null otherwise)"
2048                        },
2049                        "last_name": {
2050                            "type": ["string", "null"],
2051                            "description": "The person's last name, if provided (null otherwise)"
2052                        },
2053                        "job": {
2054                            "type": ["string", "null"],
2055                            "description": "The person's job, if provided (null otherwise)"
2056                        }
2057                    },
2058                    "required": []
2059                }
2060            }
2061        });
2062
2063        let result: Result<Schema, _> = schema_with_ref.try_into();
2064
2065        match result {
2066            Ok(schema) => {
2067                assert_eq!(schema.r#type, "array");
2068
2069                if let Some(items) = schema.items {
2070                    println!("item types: {}", items.r#type);
2071
2072                    assert_ne!(items.r#type, "", "Items type should not be empty string!");
2073                    assert_eq!(items.r#type, "object", "Items should be object type");
2074                } else {
2075                    panic!("Schema should have items field for array type");
2076                }
2077            }
2078            Err(e) => println!("Schema conversion failed: {:?}", e),
2079        }
2080    }
2081
2082    #[test]
2083    fn test_object_schema() {
2084        let simple_schema = json!({
2085            "type": "object",
2086            "properties": {
2087                "name": {
2088                    "type": "string"
2089                }
2090            }
2091        });
2092
2093        let schema: Schema = simple_schema.try_into().unwrap();
2094        assert_eq!(schema.r#type, "object");
2095        assert!(schema.properties.is_some());
2096    }
2097
2098    #[test]
2099    fn test_array_with_inline_items() {
2100        let inline_schema = json!({
2101            "type": "array",
2102            "items": {
2103                "type": "object",
2104                "properties": {
2105                    "name": {
2106                        "type": "string"
2107                    }
2108                }
2109            }
2110        });
2111
2112        let schema: Schema = inline_schema.try_into().unwrap();
2113        assert_eq!(schema.r#type, "array");
2114
2115        if let Some(items) = schema.items {
2116            assert_eq!(items.r#type, "object");
2117            assert!(items.properties.is_some());
2118        } else {
2119            panic!("Schema should have items field");
2120        }
2121    }
2122    #[test]
2123    fn test_flattened_schema() {
2124        let ref_schema = json!({
2125            "type": "array",
2126            "items": {
2127                "$ref": "#/$defs/Person"
2128            },
2129            "$defs": {
2130                "Person": {
2131                    "type": "object",
2132                    "properties": {
2133                        "name": { "type": "string" }
2134                    }
2135                }
2136            }
2137        });
2138
2139        let flattened = flatten_schema(ref_schema).unwrap();
2140        let schema: Schema = flattened.try_into().unwrap();
2141
2142        assert_eq!(schema.r#type, "array");
2143
2144        if let Some(items) = schema.items {
2145            println!("Flattened items type: '{}'", items.r#type);
2146
2147            assert_eq!(items.r#type, "object");
2148            assert!(items.properties.is_some());
2149        }
2150    }
2151
2152    #[test]
2153    fn test_tool_result_with_image_content() {
2154        // Test that a ToolResult with image content converts correctly to Gemini's Part format
2155        use crate::OneOrMany;
2156        use crate::message::{
2157            DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2158        };
2159
2160        // Create a tool result with both text and image content
2161        let tool_result = ToolResult {
2162            id: "test_tool".to_string(),
2163            call_id: None,
2164            content: OneOrMany::many(vec![
2165                ToolResultContent::Text(message::Text {
2166                    text: r#"{"status": "success"}"#.to_string(),
2167                }),
2168                ToolResultContent::Image(Image {
2169                    data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
2170                    media_type: Some(ImageMediaType::PNG),
2171                    detail: None,
2172                    additional_params: None,
2173                }),
2174            ]).expect("Should create OneOrMany with multiple items"),
2175        };
2176
2177        let user_content = message::UserContent::ToolResult(tool_result);
2178        let msg = message::Message::User {
2179            content: OneOrMany::one(user_content),
2180        };
2181
2182        // Convert to Gemini Content
2183        let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2184
2185        assert_eq!(content.role, Some(Role::User));
2186        assert_eq!(content.parts.len(), 1);
2187
2188        // Verify the part is a FunctionResponse with both response and parts
2189        if let Some(Part {
2190            part: PartKind::FunctionResponse(function_response),
2191            ..
2192        }) = content.parts.first()
2193        {
2194            assert_eq!(function_response.name, "test_tool");
2195
2196            // Check that response JSON is present
2197            assert!(function_response.response.is_some());
2198            let response = function_response.response.as_ref().unwrap();
2199            assert!(response.get("result").is_some());
2200
2201            // Check that parts with image data are present
2202            assert!(function_response.parts.is_some());
2203            let parts = function_response.parts.as_ref().unwrap();
2204            assert_eq!(parts.len(), 1);
2205
2206            let image_part = &parts[0];
2207            assert!(image_part.inline_data.is_some());
2208            let inline_data = image_part.inline_data.as_ref().unwrap();
2209            assert_eq!(inline_data.mime_type, "image/png");
2210            assert!(!inline_data.data.is_empty());
2211        } else {
2212            panic!("Expected FunctionResponse part");
2213        }
2214    }
2215
2216    #[test]
2217    fn test_tool_result_with_url_image() {
2218        // Test that a ToolResult with a URL-based image converts to file_data
2219        use crate::OneOrMany;
2220        use crate::message::{
2221            DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2222        };
2223
2224        let tool_result = ToolResult {
2225            id: "screenshot_tool".to_string(),
2226            call_id: None,
2227            content: OneOrMany::one(ToolResultContent::Image(Image {
2228                data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
2229                media_type: Some(ImageMediaType::PNG),
2230                detail: None,
2231                additional_params: None,
2232            })),
2233        };
2234
2235        let user_content = message::UserContent::ToolResult(tool_result);
2236        let msg = message::Message::User {
2237            content: OneOrMany::one(user_content),
2238        };
2239
2240        let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2241
2242        assert_eq!(content.role, Some(Role::User));
2243        assert_eq!(content.parts.len(), 1);
2244
2245        if let Some(Part {
2246            part: PartKind::FunctionResponse(function_response),
2247            ..
2248        }) = content.parts.first()
2249        {
2250            assert_eq!(function_response.name, "screenshot_tool");
2251
2252            // URL images should have parts with file_data
2253            assert!(function_response.parts.is_some());
2254            let parts = function_response.parts.as_ref().unwrap();
2255            assert_eq!(parts.len(), 1);
2256
2257            let image_part = &parts[0];
2258            assert!(image_part.file_data.is_some());
2259            let file_data = image_part.file_data.as_ref().unwrap();
2260            assert_eq!(file_data.file_uri, "https://example.com/image.png");
2261            assert_eq!(file_data.mime_type.as_ref().unwrap(), "image/png");
2262        } else {
2263            panic!("Expected FunctionResponse part");
2264        }
2265    }
2266
2267    #[test]
2268    fn test_from_tool_output_parses_image_json() {
2269        // Test the ToolResultContent::from_tool_output helper with image JSON
2270        use crate::message::{DocumentSourceKind, ToolResultContent};
2271
2272        // Test simple image JSON format
2273        let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
2274        let result = ToolResultContent::from_tool_output(image_json);
2275
2276        assert_eq!(result.len(), 1);
2277        if let ToolResultContent::Image(img) = result.first() {
2278            assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2279            if let DocumentSourceKind::Base64(data) = &img.data {
2280                assert_eq!(data, "base64data==");
2281            }
2282            assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
2283        } else {
2284            panic!("Expected Image content");
2285        }
2286    }
2287
2288    #[test]
2289    fn test_from_tool_output_parses_hybrid_json() {
2290        // Test the ToolResultContent::from_tool_output helper with hybrid response/parts format
2291        use crate::message::{DocumentSourceKind, ToolResultContent};
2292
2293        let hybrid_json = r#"{
2294            "response": {"status": "ok", "count": 42},
2295            "parts": [
2296                {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
2297                {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
2298            ]
2299        }"#;
2300
2301        let result = ToolResultContent::from_tool_output(hybrid_json);
2302
2303        // Should have 3 items: 1 text (response) + 2 images (parts)
2304        assert_eq!(result.len(), 3);
2305
2306        let items: Vec<_> = result.iter().collect();
2307
2308        // First should be text with the response JSON
2309        if let ToolResultContent::Text(text) = &items[0] {
2310            assert!(text.text.contains("status"));
2311            assert!(text.text.contains("ok"));
2312        } else {
2313            panic!("Expected Text content first");
2314        }
2315
2316        // Second should be base64 image
2317        if let ToolResultContent::Image(img) = &items[1] {
2318            assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2319        } else {
2320            panic!("Expected Image content second");
2321        }
2322
2323        // Third should be URL image
2324        if let ToolResultContent::Image(img) = &items[2] {
2325            assert!(matches!(img.data, DocumentSourceKind::Url(_)));
2326        } else {
2327            panic!("Expected Image content third");
2328        }
2329    }
2330
2331    /// E2E test that verifies Gemini can process tool results containing images.
2332    /// This test creates an agent with a tool that returns an image, invokes it,
2333    /// and verifies that Gemini can interpret the image in the tool result.
2334    #[tokio::test]
2335    #[ignore = "requires GEMINI_API_KEY environment variable"]
2336    async fn test_gemini_agent_with_image_tool_result_e2e() {
2337        use crate::completion::{Prompt, ToolDefinition};
2338        use crate::prelude::*;
2339        use crate::providers::gemini;
2340        use crate::tool::Tool;
2341        use serde::{Deserialize, Serialize};
2342
2343        /// A tool that returns a small red 1x1 pixel PNG image
2344        #[derive(Debug, Serialize, Deserialize)]
2345        struct ImageGeneratorTool;
2346
2347        #[derive(Debug, thiserror::Error)]
2348        #[error("Image generation error")]
2349        struct ImageToolError;
2350
2351        impl Tool for ImageGeneratorTool {
2352            const NAME: &'static str = "generate_test_image";
2353            type Error = ImageToolError;
2354            type Args = serde_json::Value;
2355            // Return the image in the format that from_tool_output expects
2356            type Output = String;
2357
2358            async fn definition(&self, _prompt: String) -> ToolDefinition {
2359                ToolDefinition {
2360                    name: "generate_test_image".to_string(),
2361                    description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
2362                    parameters: json!({
2363                        "type": "object",
2364                        "properties": {},
2365                        "required": []
2366                    }),
2367                }
2368            }
2369
2370            async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
2371                // Return a JSON object that from_tool_output will parse as an image
2372                // This is a 1x1 red PNG pixel
2373                Ok(json!({
2374                    "type": "image",
2375                    "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
2376                    "mimeType": "image/png"
2377                }).to_string())
2378            }
2379        }
2380
2381        let client = gemini::Client::from_env();
2382
2383        let agent = client
2384            .agent("gemini-3-flash-preview")
2385            .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
2386            .tool(ImageGeneratorTool)
2387            .build();
2388
2389        // This prompt should trigger the tool, which returns an image that Gemini should process
2390        let response = agent
2391            .prompt("Please generate a test image and tell me what color the pixel is.")
2392            .await;
2393
2394        // The test passes if Gemini successfully processes the request without errors.
2395        // The image is a 1x1 red pixel, so Gemini should be able to describe it.
2396        assert!(
2397            response.is_ok(),
2398            "Gemini should successfully process tool result with image: {:?}",
2399            response.err()
2400        );
2401
2402        let response_text = response.unwrap();
2403        println!("Response: {response_text}");
2404        // Gemini should have been able to see the image and potentially describe its color
2405        assert!(!response_text.is_empty(), "Response should not be empty");
2406    }
2407}