Skip to main content

rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
// Gemini 2.5 model identifiers.

/// Model identifier `gemini-2.5-pro-preview-06-05` (preview release).
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier `gemini-2.5-pro-preview-05-06` (preview release).
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier `gemini-2.5-pro-preview-03-25` (preview release).
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier `gemini-2.5-flash-preview-04-17` (preview release).
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier `gemini-2.5-pro-exp-03-25` (experimental release).
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";

// Gemini 2.0 model identifiers.

/// Model identifier `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27    AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32    OneOrMany,
33    completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
37    GenerationConfig, Part, PartKind, Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
46// =================================================================
47// Rig Implementation Types
48// =================================================================
49
50#[derive(Clone, Debug)]
51pub struct CompletionModel<T = reqwest::Client> {
52    pub(crate) client: Client<T>,
53    pub model: String,
54}
55
56impl<T> CompletionModel<T> {
57    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58        Self {
59            client,
60            model: model.into(),
61        }
62    }
63
64    pub fn with_model(client: Client<T>, model: &str) -> Self {
65        Self {
66            client,
67            model: model.into(),
68        }
69    }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74    T: HttpClientExt + Clone + 'static,
75{
76    type Response = GenerateContentResponse;
77    type StreamingResponse = StreamingCompletionResponse;
78    type Client = super::Client<T>;
79
80    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81        Self::new(client.clone(), model)
82    }
83
84    async fn completion(
85        &self,
86        completion_request: CompletionRequest,
87    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88        let request_model = resolve_request_model(&self.model, &completion_request);
89        let span = if tracing::Span::current().is_disabled() {
90            info_span!(
91                target: "rig::completions",
92                "generate_content",
93                gen_ai.operation.name = "generate_content",
94                gen_ai.provider.name = "gcp.gemini",
95                gen_ai.request.model = &request_model,
96                gen_ai.system_instructions = &completion_request.preamble,
97                gen_ai.response.id = tracing::field::Empty,
98                gen_ai.response.model = tracing::field::Empty,
99                gen_ai.usage.output_tokens = tracing::field::Empty,
100                gen_ai.usage.input_tokens = tracing::field::Empty,
101            )
102        } else {
103            tracing::Span::current()
104        };
105
106        let request = create_request_body(completion_request)?;
107
108        if enabled!(Level::TRACE) {
109            tracing::trace!(
110                target: "rig::completions",
111                "Gemini completion request: {}",
112                serde_json::to_string_pretty(&request)?
113            );
114        }
115
116        let body = serde_json::to_vec(&request)?;
117
118        let path = completion_endpoint(&request_model);
119
120        let request = self
121            .client
122            .post(path.as_str())?
123            .body(body)
124            .map_err(|e| CompletionError::HttpError(e.into()))?;
125
126        async move {
127            let response = self.client.send::<_, Vec<u8>>(request).await?;
128
129            if response.status().is_success() {
130                let response_body = response
131                    .into_body()
132                    .await
133                    .map_err(CompletionError::HttpError)?;
134
135                let response_text = String::from_utf8_lossy(&response_body).to_string();
136
137                let response: GenerateContentResponse = serde_json::from_slice(&response_body)
138                    .map_err(|err| {
139                        tracing::error!(
140                            error = %err,
141                            body = %response_text,
142                            "Failed to deserialize Gemini completion response"
143                        );
144                        CompletionError::JsonError(err)
145                    })?;
146
147                let span = tracing::Span::current();
148                span.record_response_metadata(&response);
149                span.record_token_usage(&response.usage_metadata);
150
151                if enabled!(Level::TRACE) {
152                    tracing::trace!(
153                        target: "rig::completions",
154                        "Gemini completion response: {}",
155                        serde_json::to_string_pretty(&response)?
156                    );
157                }
158
159                response.try_into()
160            } else {
161                let text = String::from_utf8_lossy(
162                    &response
163                        .into_body()
164                        .await
165                        .map_err(CompletionError::HttpError)?,
166                )
167                .into();
168
169                Err(CompletionError::ProviderError(text))
170            }
171        }
172        .instrument(span)
173        .await
174    }
175
176    async fn stream(
177        &self,
178        request: CompletionRequest,
179    ) -> Result<
180        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
181        CompletionError,
182    > {
183        CompletionModel::stream(self, request).await
184    }
185}
186
187pub(crate) fn create_request_body(
188    completion_request: CompletionRequest,
189) -> Result<GenerateContentRequest, CompletionError> {
190    let mut full_history = Vec::new();
191    full_history.extend(completion_request.chat_history);
192
193    let additional_params = completion_request
194        .additional_params
195        .unwrap_or_else(|| Value::Object(Map::new()));
196
197    let AdditionalParameters {
198        mut generation_config,
199        additional_params,
200    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
201
202    // Apply output_schema to generation_config, creating one if needed
203    if let Some(schema) = completion_request.output_schema {
204        let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
205        cfg.response_mime_type = Some("application/json".to_string());
206        cfg.response_json_schema = Some(schema.to_value());
207    }
208
209    generation_config = generation_config.map(|mut cfg| {
210        if let Some(temp) = completion_request.temperature {
211            cfg.temperature = Some(temp);
212        };
213
214        if let Some(max_tokens) = completion_request.max_tokens {
215            cfg.max_output_tokens = Some(max_tokens);
216        };
217
218        cfg
219    });
220
221    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
222        parts: vec![preamble.into()],
223        role: Some(Role::Model),
224    });
225
226    let tools = if completion_request.tools.is_empty() {
227        None
228    } else {
229        Some(vec![Tool::try_from(completion_request.tools)?])
230    };
231
232    let tool_config = if let Some(cfg) = completion_request.tool_choice {
233        Some(ToolConfig {
234            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
235        })
236    } else {
237        None
238    };
239
240    let request = GenerateContentRequest {
241        contents: full_history
242            .into_iter()
243            .map(|msg| {
244                msg.try_into()
245                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
246            })
247            .collect::<Result<Vec<_>, _>>()?,
248        generation_config,
249        safety_settings: None,
250        tools,
251        tool_config,
252        system_instruction,
253        additional_params,
254    };
255
256    Ok(request)
257}
258
259pub(crate) fn resolve_request_model(
260    default_model: &str,
261    completion_request: &CompletionRequest,
262) -> String {
263    completion_request
264        .model
265        .clone()
266        .unwrap_or_else(|| default_model.to_string())
267}
268
/// Relative path of the non-streaming `generateContent` endpoint for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    let mut endpoint = String::from("/v1beta/models/");
    endpoint.push_str(model);
    endpoint.push_str(":generateContent");
    endpoint
}
272
/// Relative path of the `streamGenerateContent` endpoint for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":streamGenerateContent"].concat()
}
276
277impl TryFrom<completion::ToolDefinition> for Tool {
278    type Error = CompletionError;
279
280    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
281        let parameters: Option<Schema> =
282            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
283                None
284            } else {
285                Some(tool.parameters.try_into()?)
286            };
287
288        Ok(Self {
289            function_declarations: vec![FunctionDeclaration {
290                name: tool.name,
291                description: tool.description,
292                parameters,
293            }],
294            code_execution: None,
295        })
296    }
297}
298
299impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
300    type Error = CompletionError;
301
302    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
303        let mut function_declarations = Vec::new();
304
305        for tool in tools {
306            let parameters =
307                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
308                    None
309                } else {
310                    match tool.parameters.try_into() {
311                        Ok(schema) => Some(schema),
312                        Err(e) => {
313                            let emsg = format!(
314                                "Tool '{}' could not be converted to a schema: {:?}",
315                                tool.name, e,
316                            );
317                            return Err(CompletionError::ProviderError(emsg));
318                        }
319                    }
320                };
321
322            function_declarations.push(FunctionDeclaration {
323                name: tool.name,
324                description: tool.description,
325                parameters,
326            });
327        }
328
329        Ok(Self {
330            function_declarations,
331            code_execution: None,
332        })
333    }
334}
335
336impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
337    type Error = CompletionError;
338
339    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
340        let candidate = response.candidates.first().ok_or_else(|| {
341            CompletionError::ResponseError("No response candidates in response".into())
342        })?;
343
344        let content = candidate
345            .content
346            .as_ref()
347            .ok_or_else(|| {
348                let reason = candidate
349                    .finish_reason
350                    .as_ref()
351                    .map(|r| format!("finish_reason={r:?}"))
352                    .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
353                let message = candidate
354                    .finish_message
355                    .as_deref()
356                    .unwrap_or("no finish message provided");
357                CompletionError::ResponseError(format!(
358                    "Gemini candidate missing content ({reason}, finish_message={message})"
359                ))
360            })?
361            .parts
362            .iter()
363            .map(
364                |Part {
365                     thought,
366                     thought_signature,
367                     part,
368                     ..
369                 }| {
370                    Ok(match part {
371                        PartKind::Text(text) => {
372                            if let Some(thought) = thought
373                                && *thought
374                            {
375                                completion::AssistantContent::Reasoning(
376                                    Reasoning::new_with_signature(text, thought_signature.clone()),
377                                )
378                            } else {
379                                completion::AssistantContent::text(text)
380                            }
381                        }
382                        PartKind::InlineData(inline_data) => {
383                            let mime_type =
384                                message::MediaType::from_mime_type(&inline_data.mime_type);
385
386                            match mime_type {
387                                Some(message::MediaType::Image(media_type)) => {
388                                    message::AssistantContent::image_base64(
389                                        &inline_data.data,
390                                        Some(media_type),
391                                        Some(message::ImageDetail::default()),
392                                    )
393                                }
394                                _ => {
395                                    return Err(CompletionError::ResponseError(format!(
396                                        "Unsupported media type {mime_type:?}"
397                                    )));
398                                }
399                            }
400                        }
401                        PartKind::FunctionCall(function_call) => {
402                            completion::AssistantContent::ToolCall(
403                                message::ToolCall::new(
404                                    function_call.name.clone(),
405                                    message::ToolFunction::new(
406                                        function_call.name.clone(),
407                                        function_call.args.clone(),
408                                    ),
409                                )
410                                .with_signature(thought_signature.clone()),
411                            )
412                        }
413                        _ => {
414                            return Err(CompletionError::ResponseError(
415                                "Response did not contain a message or tool call".into(),
416                            ));
417                        }
418                    })
419                },
420            )
421            .collect::<Result<Vec<_>, _>>()?;
422
423        let choice = OneOrMany::many(content).map_err(|_| {
424            CompletionError::ResponseError(
425                "Response contained no message or tool call (empty)".to_owned(),
426            )
427        })?;
428
429        let usage = response
430            .usage_metadata
431            .as_ref()
432            .map(|usage| completion::Usage {
433                input_tokens: usage.prompt_token_count as u64,
434                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
435                total_tokens: usage.total_token_count as u64,
436                cached_input_tokens: 0,
437            })
438            .unwrap_or_default();
439
440        Ok(completion::CompletionResponse {
441            choice,
442            usage,
443            raw_response: response,
444            message_id: None,
445        })
446    }
447}
448
449pub mod gemini_api_types {
450    use crate::telemetry::ProviderResponseExt;
451    use std::{collections::HashMap, convert::Infallible, str::FromStr};
452
453    // =================================================================
454    // Gemini API Types
455    // =================================================================
456    use serde::{Deserialize, Serialize};
457    use serde_json::{Value, json};
458
459    use crate::completion::GetTokenUsage;
460    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
461    use crate::{
462        completion::CompletionError,
463        message::{self},
464        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
465    };
466
467    #[derive(Debug, Deserialize, Serialize, Default)]
468    #[serde(rename_all = "camelCase")]
469    pub struct AdditionalParameters {
470        /// Change your Gemini request configuration.
471        pub generation_config: Option<GenerationConfig>,
472        /// Any additional parameters that you want.
473        #[serde(flatten, skip_serializing_if = "Option::is_none")]
474        pub additional_params: Option<serde_json::Value>,
475    }
476
477    impl AdditionalParameters {
478        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
479            self.generation_config = Some(cfg);
480            self
481        }
482
483        pub fn with_params(mut self, params: serde_json::Value) -> Self {
484            self.additional_params = Some(params);
485            self
486        }
487    }
488
489    /// Response from the model supporting multiple candidate responses.
490    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
491    /// and for each candidate in finishReason and in safetyRatings.
492    /// The API:
493    ///     - Returns either all requested candidates or none of them
494    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
495    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
496    #[derive(Debug, Deserialize, Serialize)]
497    #[serde(rename_all = "camelCase")]
498    pub struct GenerateContentResponse {
499        pub response_id: String,
500        /// Candidate responses from the model.
501        pub candidates: Vec<ContentCandidate>,
502        /// Returns the prompt's feedback related to the content filters.
503        pub prompt_feedback: Option<PromptFeedback>,
504        /// Output only. Metadata on the generation requests' token usage.
505        pub usage_metadata: Option<UsageMetadata>,
506        pub model_version: Option<String>,
507    }
508
509    impl ProviderResponseExt for GenerateContentResponse {
510        type OutputMessage = ContentCandidate;
511        type Usage = UsageMetadata;
512
513        fn get_response_id(&self) -> Option<String> {
514            Some(self.response_id.clone())
515        }
516
517        fn get_response_model_name(&self) -> Option<String> {
518            None
519        }
520
521        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
522            self.candidates.clone()
523        }
524
525        fn get_text_response(&self) -> Option<String> {
526            let str = self
527                .candidates
528                .iter()
529                .filter_map(|x| {
530                    let content = x.content.as_ref()?;
531                    if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
532                        return None;
533                    }
534
535                    let res = content
536                        .parts
537                        .iter()
538                        .filter_map(|part| {
539                            if let PartKind::Text(ref str) = part.part {
540                                Some(str.to_owned())
541                            } else {
542                                None
543                            }
544                        })
545                        .collect::<Vec<String>>()
546                        .join("\n");
547
548                    Some(res)
549                })
550                .collect::<Vec<String>>()
551                .join("\n");
552
553            if str.is_empty() { None } else { Some(str) }
554        }
555
556        fn get_usage(&self) -> Option<Self::Usage> {
557            self.usage_metadata.clone()
558        }
559    }
560
561    /// A response candidate generated from the model.
562    #[derive(Clone, Debug, Deserialize, Serialize)]
563    #[serde(rename_all = "camelCase")]
564    pub struct ContentCandidate {
565        /// Output only. Generated content returned from the model.
566        #[serde(skip_serializing_if = "Option::is_none")]
567        pub content: Option<Content>,
568        /// Optional. Output only. The reason why the model stopped generating tokens.
569        /// If empty, the model has not stopped generating tokens.
570        pub finish_reason: Option<FinishReason>,
571        /// List of ratings for the safety of a response candidate.
572        /// There is at most one rating per category.
573        pub safety_ratings: Option<Vec<SafetyRating>>,
574        /// Output only. Citation information for model-generated candidate.
575        /// This field may be populated with recitation information for any text included in the content.
576        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
577        pub citation_metadata: Option<CitationMetadata>,
578        /// Output only. Token count for this candidate.
579        pub token_count: Option<i32>,
580        /// Output only.
581        pub avg_logprobs: Option<f64>,
582        /// Output only. Log-likelihood scores for the response tokens and top tokens
583        pub logprobs_result: Option<LogprobsResult>,
584        /// Output only. Index of the candidate in the list of response candidates.
585        pub index: Option<i32>,
586        /// Output only. Additional information about why the model stopped generating tokens.
587        pub finish_message: Option<String>,
588    }
589
590    #[derive(Clone, Debug, Deserialize, Serialize)]
591    pub struct Content {
592        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
593        #[serde(default)]
594        pub parts: Vec<Part>,
595        /// The producer of the content. Must be either 'user' or 'model'.
596        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
597        pub role: Option<Role>,
598    }
599
600    impl TryFrom<message::Message> for Content {
601        type Error = message::MessageError;
602
603        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
604            Ok(match msg {
605                message::Message::User { content } => Content {
606                    parts: content
607                        .into_iter()
608                        .map(|c| c.try_into())
609                        .collect::<Result<Vec<_>, _>>()?,
610                    role: Some(Role::User),
611                },
612                message::Message::Assistant { content, .. } => Content {
613                    role: Some(Role::Model),
614                    parts: content
615                        .into_iter()
616                        .map(|content| content.try_into())
617                        .collect::<Result<Vec<_>, _>>()?,
618                },
619            })
620        }
621    }
622
623    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
624    #[serde(rename_all = "lowercase")]
625    pub enum Role {
626        User,
627        Model,
628    }
629
630    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
631    #[serde(rename_all = "camelCase")]
632    pub struct Part {
633        /// whether or not the part is a reasoning/thinking text or not
634        #[serde(skip_serializing_if = "Option::is_none")]
635        pub thought: Option<bool>,
636        /// an opaque sig for the thought so it can be reused - is a base64 string
637        #[serde(skip_serializing_if = "Option::is_none")]
638        pub thought_signature: Option<String>,
639        #[serde(flatten)]
640        pub part: PartKind,
641        #[serde(flatten, skip_serializing_if = "Option::is_none")]
642        pub additional_params: Option<Value>,
643    }
644
645    /// A datatype containing media that is part of a multi-part [Content] message.
646    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
647    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
648    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
649    #[serde(rename_all = "camelCase")]
650    pub enum PartKind {
651        Text(String),
652        InlineData(Blob),
653        FunctionCall(FunctionCall),
654        FunctionResponse(FunctionResponse),
655        FileData(FileData),
656        ExecutableCode(ExecutableCode),
657        CodeExecutionResult(CodeExecutionResult),
658    }
659
660    // This default instance is primarily so we can easily fill in the optional fields of `Part`
661    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
662    impl Default for PartKind {
663        fn default() -> Self {
664            Self::Text(String::new())
665        }
666    }
667
668    impl From<String> for Part {
669        fn from(text: String) -> Self {
670            Self {
671                thought: Some(false),
672                thought_signature: None,
673                part: PartKind::Text(text),
674                additional_params: None,
675            }
676        }
677    }
678
679    impl From<&str> for Part {
680        fn from(text: &str) -> Self {
681            Self::from(text.to_string())
682        }
683    }
684
685    impl FromStr for Part {
686        type Err = Infallible;
687
688        fn from_str(s: &str) -> Result<Self, Self::Err> {
689            Ok(s.into())
690        }
691    }
692
693    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
694        type Error = message::MessageError;
695        fn try_from(
696            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
697        ) -> Result<Self, Self::Error> {
698            let mime_type = mime_type.to_mime_type().to_string();
699            let part = match doc_src {
700                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
701                    mime_type: Some(mime_type),
702                    file_uri: url,
703                }),
704                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
705                    PartKind::InlineData(Blob { mime_type, data })
706                }
707                DocumentSourceKind::Raw(_) => {
708                    return Err(message::MessageError::ConversionError(
709                        "Raw files not supported, encode as base64 first".into(),
710                    ));
711                }
712                DocumentSourceKind::Unknown => {
713                    return Err(message::MessageError::ConversionError(
714                        "Can't convert an unknown document source".to_string(),
715                    ));
716                }
717            };
718
719            Ok(part)
720        }
721    }
722
723    impl TryFrom<message::UserContent> for Part {
724        type Error = message::MessageError;
725
726        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
727            match content {
728                message::UserContent::Text(message::Text { text }) => Ok(Part {
729                    thought: Some(false),
730                    thought_signature: None,
731                    part: PartKind::Text(text),
732                    additional_params: None,
733                }),
734                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
735                    let mut response_json: Option<serde_json::Value> = None;
736                    let mut parts: Vec<FunctionResponsePart> = Vec::new();
737
738                    for item in content.iter() {
739                        match item {
740                            message::ToolResultContent::Text(text) => {
741                                let result: serde_json::Value =
742                                    serde_json::from_str(&text.text).unwrap_or_else(|error| {
743                                        tracing::trace!(
744                                            ?error,
745                                            "Tool result is not a valid JSON, treat it as normal string"
746                                        );
747                                        json!(&text.text)
748                                    });
749
750                                response_json = Some(match response_json {
751                                    Some(mut existing) => {
752                                        if let serde_json::Value::Object(ref mut map) = existing {
753                                            map.insert("text".to_string(), result);
754                                        }
755                                        existing
756                                    }
757                                    None => json!({ "result": result }),
758                                });
759                            }
760                            message::ToolResultContent::Image(image) => {
761                                let part = match &image.data {
762                                    DocumentSourceKind::Base64(b64) => {
763                                        let mime_type = image
764                                            .media_type
765                                            .as_ref()
766                                            .ok_or(message::MessageError::ConversionError(
767                                                "Image media type is required for Gemini tool results".to_string(),
768                                            ))?
769                                            .to_mime_type();
770
771                                        FunctionResponsePart {
772                                            inline_data: Some(FunctionResponseInlineData {
773                                                mime_type: mime_type.to_string(),
774                                                data: b64.clone(),
775                                                display_name: None,
776                                            }),
777                                            file_data: None,
778                                        }
779                                    }
780                                    DocumentSourceKind::Url(url) => {
781                                        let mime_type = image
782                                            .media_type
783                                            .as_ref()
784                                            .map(|mt| mt.to_mime_type().to_string());
785
786                                        FunctionResponsePart {
787                                            inline_data: None,
788                                            file_data: Some(FileData {
789                                                mime_type,
790                                                file_uri: url.clone(),
791                                            }),
792                                        }
793                                    }
794                                    _ => {
795                                        return Err(message::MessageError::ConversionError(
796                                            "Unsupported image source kind for tool results"
797                                                .to_string(),
798                                        ));
799                                    }
800                                };
801                                parts.push(part);
802                            }
803                        }
804                    }
805
806                    Ok(Part {
807                        thought: Some(false),
808                        thought_signature: None,
809                        part: PartKind::FunctionResponse(FunctionResponse {
810                            name: id,
811                            response: response_json,
812                            parts: if parts.is_empty() { None } else { Some(parts) },
813                        }),
814                        additional_params: None,
815                    })
816                }
817                message::UserContent::Image(message::Image {
818                    data, media_type, ..
819                }) => match media_type {
820                    Some(media_type) => match media_type {
821                        message::ImageMediaType::JPEG
822                        | message::ImageMediaType::PNG
823                        | message::ImageMediaType::WEBP
824                        | message::ImageMediaType::HEIC
825                        | message::ImageMediaType::HEIF => {
826                            let part = PartKind::try_from((media_type, data))?;
827                            Ok(Part {
828                                thought: Some(false),
829                                thought_signature: None,
830                                part,
831                                additional_params: None,
832                            })
833                        }
834                        _ => Err(message::MessageError::ConversionError(format!(
835                            "Unsupported image media type {media_type:?}"
836                        ))),
837                    },
838                    None => Err(message::MessageError::ConversionError(
839                        "Media type for image is required for Gemini".to_string(),
840                    )),
841                },
842                message::UserContent::Document(message::Document {
843                    data, media_type, ..
844                }) => {
845                    let Some(media_type) = media_type else {
846                        return Err(MessageError::ConversionError(
847                            "A mime type is required for document inputs to Gemini".to_string(),
848                        ));
849                    };
850
851                    if !media_type.is_code() {
852                        let mime_type = media_type.to_mime_type().to_string();
853
854                        let part = match data {
855                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
856                                mime_type: Some(mime_type),
857                                file_uri,
858                            }),
859                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
860                                PartKind::InlineData(Blob { mime_type, data })
861                            }
862                            DocumentSourceKind::Raw(_) => {
863                                return Err(message::MessageError::ConversionError(
864                                    "Raw files not supported, encode as base64 first".into(),
865                                ));
866                            }
867                            _ => {
868                                return Err(message::MessageError::ConversionError(
869                                    "Document has no body".to_string(),
870                                ));
871                            }
872                        };
873
874                        Ok(Part {
875                            thought: Some(false),
876                            part,
877                            ..Default::default()
878                        })
879                    } else {
880                        Err(message::MessageError::ConversionError(format!(
881                            "Unsupported document media type {media_type:?}"
882                        )))
883                    }
884                }
885
886                message::UserContent::Audio(message::Audio {
887                    data, media_type, ..
888                }) => {
889                    let Some(media_type) = media_type else {
890                        return Err(MessageError::ConversionError(
891                            "A mime type is required for audio inputs to Gemini".to_string(),
892                        ));
893                    };
894
895                    let mime_type = media_type.to_mime_type().to_string();
896
897                    let part = match data {
898                        DocumentSourceKind::Base64(data) => {
899                            PartKind::InlineData(Blob { data, mime_type })
900                        }
901
902                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
903                            mime_type: Some(mime_type),
904                            file_uri,
905                        }),
906                        DocumentSourceKind::String(_) => {
907                            return Err(message::MessageError::ConversionError(
908                                "Strings cannot be used as audio files!".into(),
909                            ));
910                        }
911                        DocumentSourceKind::Raw(_) => {
912                            return Err(message::MessageError::ConversionError(
913                                "Raw files not supported, encode as base64 first".into(),
914                            ));
915                        }
916                        DocumentSourceKind::Unknown => {
917                            return Err(message::MessageError::ConversionError(
918                                "Content has no body".to_string(),
919                            ));
920                        }
921                    };
922
923                    Ok(Part {
924                        thought: Some(false),
925                        part,
926                        ..Default::default()
927                    })
928                }
929                message::UserContent::Video(message::Video {
930                    data,
931                    media_type,
932                    additional_params,
933                    ..
934                }) => {
935                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
936
937                    let part = match data {
938                        DocumentSourceKind::Url(file_uri) => {
939                            if file_uri.starts_with("https://www.youtube.com") {
940                                PartKind::FileData(FileData {
941                                    mime_type,
942                                    file_uri,
943                                })
944                            } else {
945                                if mime_type.is_none() {
946                                    return Err(MessageError::ConversionError(
947                                        "A mime type is required for non-Youtube video file inputs to Gemini"
948                                            .to_string(),
949                                    ));
950                                }
951
952                                PartKind::FileData(FileData {
953                                    mime_type,
954                                    file_uri,
955                                })
956                            }
957                        }
958                        DocumentSourceKind::Base64(data) => {
959                            let Some(mime_type) = mime_type else {
960                                return Err(MessageError::ConversionError(
961                                    "A media type is expected for base64 encoded strings"
962                                        .to_string(),
963                                ));
964                            };
965                            PartKind::InlineData(Blob { mime_type, data })
966                        }
967                        DocumentSourceKind::String(_) => {
968                            return Err(message::MessageError::ConversionError(
969                                "Strings cannot be used as audio files!".into(),
970                            ));
971                        }
972                        DocumentSourceKind::Raw(_) => {
973                            return Err(message::MessageError::ConversionError(
974                                "Raw file data not supported, encode as base64 first".into(),
975                            ));
976                        }
977                        DocumentSourceKind::Unknown => {
978                            return Err(message::MessageError::ConversionError(
979                                "Media type for video is required for Gemini".to_string(),
980                            ));
981                        }
982                    };
983
984                    Ok(Part {
985                        thought: Some(false),
986                        thought_signature: None,
987                        part,
988                        additional_params,
989                    })
990                }
991            }
992        }
993    }
994
995    impl TryFrom<message::AssistantContent> for Part {
996        type Error = message::MessageError;
997
998        fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
999            match content {
1000                message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1001                message::AssistantContent::Image(message::Image {
1002                    data, media_type, ..
1003                }) => match media_type {
1004                    Some(media_type) => match media_type {
1005                        message::ImageMediaType::JPEG
1006                        | message::ImageMediaType::PNG
1007                        | message::ImageMediaType::WEBP
1008                        | message::ImageMediaType::HEIC
1009                        | message::ImageMediaType::HEIF => {
1010                            let part = PartKind::try_from((media_type, data))?;
1011                            Ok(Part {
1012                                thought: Some(false),
1013                                thought_signature: None,
1014                                part,
1015                                additional_params: None,
1016                            })
1017                        }
1018                        _ => Err(message::MessageError::ConversionError(format!(
1019                            "Unsupported image media type {media_type:?}"
1020                        ))),
1021                    },
1022                    None => Err(message::MessageError::ConversionError(
1023                        "Media type for image is required for Gemini".to_string(),
1024                    )),
1025                },
1026                message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1027                message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1028                    thought: Some(true),
1029                    thought_signature: reasoning.first_signature().map(str::to_owned),
1030                    part: PartKind::Text(reasoning.display_text()),
1031                    additional_params: None,
1032                }),
1033            }
1034        }
1035    }
1036
1037    impl From<message::ToolCall> for Part {
1038        fn from(tool_call: message::ToolCall) -> Self {
1039            Self {
1040                thought: Some(false),
1041                thought_signature: tool_call.signature,
1042                part: PartKind::FunctionCall(FunctionCall {
1043                    name: tool_call.function.name,
1044                    args: tool_call.function.arguments,
1045                }),
1046                additional_params: None,
1047            }
1048        }
1049    }
1050
    /// Inline media payload, used by `PartKind::InlineData`.
    ///
    /// Text should not be sent as raw bytes; use a text part instead.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data, e.g. `image/png` or `image/jpeg`.
        /// If an unsupported MIME type is provided, the API returns an error.
        pub mime_type: String,
        /// The media bytes, as a base64-encoded string.
        pub data: String,
    }
1062
1063    /// A predicted FunctionCall returned from the model that contains a string representing the
1064    /// FunctionDeclaration.name with the arguments and their values.
1065    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1066    pub struct FunctionCall {
1067        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
1068        /// and dashes, with a maximum length of 63.
1069        pub name: String,
1070        /// Optional. The function parameters and values in JSON object format.
1071        pub args: serde_json::Value,
1072    }
1073
1074    impl From<message::ToolCall> for FunctionCall {
1075        fn from(tool_call: message::ToolCall) -> Self {
1076            Self {
1077                name: tool_call.function.name,
1078                args: tool_call.function.arguments,
1079            }
1080        }
1081    }
1082
    /// The result of a function call, sent back to the model as context.
    ///
    /// Contains the name of the function that was called and, optionally, a JSON object with its
    /// output and/or multimodal parts (e.g. images). This should contain the result of a
    /// `FunctionCall` made based on a model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function that was called. Must be a-z, A-Z, 0-9, or contain
        /// underscores and dashes, with a maximum length of 63.
        // NOTE(review): the `UserContent::ToolResult` conversion in this file populates this
        // from the tool result's `id`; confirm that id always equals the function name.
        pub name: String,
        /// The function response in JSON object format. Omitted from the request when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        /// Multimodal parts for the function response (e.g., images). Omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }
1098
    /// A part of a multimodal function response.
    ///
    /// Media is carried either inline (`inline_data`) or by URI reference (`file_data`); the
    /// conversion code in this file sets exactly one of the two.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        /// Inline data containing base64-encoded media content. Omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        /// File data containing a URI reference. Omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }
1110
    /// Inline (base64) media for a function response part.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// The IANA standard MIME type of the source data.
        pub mime_type: String,
        /// The media bytes, as a base64-encoded string.
        pub data: String,
        /// Optional display name for the content. Omitted from the request when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1123
    /// Media referenced by URI rather than sent inline.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        // NOTE(review): there is no `skip_serializing_if`, so `None` is sent as
        // `"mimeType": null` — confirm the API accepts that.
        pub mime_type: Option<String>,
        /// Required. The URI of the file.
        pub file_uri: String,
    }
1133
    /// A safety rating: the probability that content falls into a given harm category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category this rating applies to.
        pub category: HarmCategory,
        /// How likely the content is to be harmful in `category`.
        pub probability: HarmProbability,
    }
1139
    /// Likelihood that content is harmful, serialized as `SCREAMING_SNAKE_CASE`
    /// (e.g. `HARM_PROBABILITY_UNSPECIFIED`, `NEGLIGIBLE`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1149
    /// Harm category used in safety ratings, serialized as `SCREAMING_SNAKE_CASE`
    /// (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    // NOTE(review): there is no catch-all variant, so a category value not listed here
    // fails deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1166
    /// Token counts reported by Gemini for a request/response pair.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens attributed to cached content (`cachedContentTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates (`candidatesTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count as reported by the API (`totalTokenCount`).
        pub total_token_count: i32,
        /// Tokens spent on thinking (`thoughtsTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1179
1180    impl std::fmt::Display for UsageMetadata {
1181        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1182            write!(
1183                f,
1184                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1185                self.prompt_token_count,
1186                match self.cached_content_token_count {
1187                    Some(count) => count.to_string(),
1188                    None => "n/a".to_string(),
1189                },
1190                match self.candidates_token_count {
1191                    Some(count) => count.to_string(),
1192                    None => "n/a".to_string(),
1193                },
1194                self.total_token_count
1195            )
1196        }
1197    }
1198
1199    impl GetTokenUsage for UsageMetadata {
1200        fn token_usage(&self) -> Option<crate::completion::Usage> {
1201            let mut usage = crate::completion::Usage::new();
1202
1203            usage.input_tokens = self.prompt_token_count as u64;
1204            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1205                + self.candidates_token_count.unwrap_or_default()
1206                + self.thoughts_token_count.unwrap_or_default())
1207                as u64;
1208            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1209
1210            Some(usage)
1211        }
1212    }
1213
    /// Feedback metadata for the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1223
    /// Reason why a prompt was blocked by the model, serialized as `SCREAMING_SNAKE_CASE`.
    // NOTE(review): there is no catch-all variant, so a reason value not listed here fails
    // deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
1239
    /// Why the model stopped generating a candidate, serialized as `SCREAMING_SNAKE_CASE`.
    // NOTE(review): there is no `#[serde(other)]` catch-all, so any finish reason the API adds
    // beyond this list makes the whole response fail to deserialize. Adding variants here is a
    // breaking change for exhaustive matches on this public enum; consider `#[non_exhaustive]`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
1266
    /// Source attributions for a candidate's content.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources (`citationSources`).
        pub citation_sources: Vec<CitationSource>,
    }
1272
    /// A single cited source. Every field is optional and omitted when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the response.
        // NOTE(review): presumably a character/byte index into the candidate text — confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1285
1286    #[derive(Clone, Debug, Deserialize, Serialize)]
1287    #[serde(rename_all = "camelCase")]
1288    pub struct LogprobsResult {
1289        pub top_candidate: Vec<TopCandidate>,
1290        pub chosen_candidate: Vec<LogProbCandidate>,
1291    }
1292
    /// The most likely candidate tokens for a single decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens for this step.
        pub candidates: Vec<LogProbCandidate>,
    }
1297
1298    #[derive(Clone, Debug, Deserialize, Serialize)]
1299    #[serde(rename_all = "camelCase")]
1300    pub struct LogProbCandidate {
1301        pub token: String,
1302        pub token_id: String,
1303        pub log_probability: f64,
1304    }
1305
    /// Gemini API configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Every field is optional and omitted from the serialized request when `None`. Note that the
    /// `Default` impl in this file sets `temperature` and `max_output_tokens`.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain:  (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Optional. The output schema of the generated response.
        /// This is an alternative to responseSchema that accepts a standard JSON Schema.
        /// If this is set, responseSchema must be omitted.
        /// Compatible MIME type: application/json.
        /// Supported properties: $id, $defs, $ref, type, properties, etc.
        ///
        /// Serialized under the key `_responseJsonSchema` (explicit rename below).
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Alternative representation of the JSON schema, serialized under `responseJsonSchema`.
        // NOTE(review): if both this and `_response_json_schema` are set, both keys are sent —
        // confirm which one the API honours.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: The more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// [Candidate.logprobs_result].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Configuration for image output (aspect ratio, size).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1398
1399    impl Default for GenerationConfig {
1400        fn default() -> Self {
1401            Self {
1402                temperature: Some(1.0),
1403                max_output_tokens: Some(4096),
1404                stop_sequences: None,
1405                response_mime_type: None,
1406                response_schema: None,
1407                _response_json_schema: None,
1408                response_json_schema: None,
1409                candidate_count: None,
1410                top_p: None,
1411                top_k: None,
1412                presence_penalty: None,
1413                frequency_penalty: None,
1414                response_logprobs: None,
1415                logprobs: None,
1416                thinking_config: None,
1417                image_config: None,
1418            }
1419        }
1420    }
1421
1422    #[derive(Debug, Deserialize, Serialize)]
1423    #[serde(rename_all = "camelCase")]
1424    pub struct ThinkingConfig {
1425        pub thinking_budget: u32,
1426        pub include_thoughts: Option<bool>,
1427    }
1428
    /// Image-output settings sent inside `generationConfig.imageConfig`.
    ///
    /// Fields serialize in camelCase and are omitted when `None`. The string values are
    /// passed through unvalidated; their accepted formats are defined by the Gemini API.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio (`aspectRatio`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size (`imageSize`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1437
1438    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
1439    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
1440    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
1441    #[derive(Debug, Deserialize, Serialize, Clone)]
1442    pub struct Schema {
1443        pub r#type: String,
1444        #[serde(skip_serializing_if = "Option::is_none")]
1445        pub format: Option<String>,
1446        #[serde(skip_serializing_if = "Option::is_none")]
1447        pub description: Option<String>,
1448        #[serde(skip_serializing_if = "Option::is_none")]
1449        pub nullable: Option<bool>,
1450        #[serde(skip_serializing_if = "Option::is_none")]
1451        pub r#enum: Option<Vec<String>>,
1452        #[serde(skip_serializing_if = "Option::is_none")]
1453        pub max_items: Option<i32>,
1454        #[serde(skip_serializing_if = "Option::is_none")]
1455        pub min_items: Option<i32>,
1456        #[serde(skip_serializing_if = "Option::is_none")]
1457        pub properties: Option<HashMap<String, Schema>>,
1458        #[serde(skip_serializing_if = "Option::is_none")]
1459        pub required: Option<Vec<String>>,
1460        #[serde(skip_serializing_if = "Option::is_none")]
1461        pub items: Option<Box<Schema>>,
1462    }
1463
1464    /// Flattens a JSON schema by resolving all `$ref` references inline.
1465    /// It takes a JSON schema that may contain `$ref` references to definitions
1466    /// in `$defs` or `definitions` sections and returns a new schema with all references
1467    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1468    /// schema references.
1469    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1470        // extracting $defs if they exist
1471        let defs = if let Some(obj) = schema.as_object() {
1472            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1473        } else {
1474            None
1475        };
1476
1477        let Some(defs_value) = defs else {
1478            return Ok(schema);
1479        };
1480
1481        let Some(defs_obj) = defs_value.as_object() else {
1482            return Err(CompletionError::ResponseError(
1483                "$defs must be an object".into(),
1484            ));
1485        };
1486
1487        resolve_refs(&mut schema, defs_obj)?;
1488
1489        // removing $defs from the final schema because we have inlined everything
1490        if let Some(obj) = schema.as_object_mut() {
1491            obj.remove("$defs");
1492            obj.remove("definitions");
1493        }
1494
1495        Ok(schema)
1496    }
1497
1498    /// Recursively resolves all `$ref` references in a JSON value by
1499    /// replacing them with their definitions.
1500    fn resolve_refs(
1501        value: &mut Value,
1502        defs: &serde_json::Map<String, Value>,
1503    ) -> Result<(), CompletionError> {
1504        match value {
1505            Value::Object(obj) => {
1506                if let Some(ref_value) = obj.get("$ref")
1507                    && let Some(ref_str) = ref_value.as_str()
1508                {
1509                    // "#/$defs/Person" -> "Person"
1510                    let def_name = parse_ref_path(ref_str)?;
1511
1512                    let def = defs.get(&def_name).ok_or_else(|| {
1513                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1514                    })?;
1515
1516                    let mut resolved = def.clone();
1517                    resolve_refs(&mut resolved, defs)?;
1518                    *value = resolved;
1519                    return Ok(());
1520                }
1521
1522                for (_, v) in obj.iter_mut() {
1523                    resolve_refs(v, defs)?;
1524                }
1525            }
1526            Value::Array(arr) => {
1527                for item in arr.iter_mut() {
1528                    resolve_refs(item, defs)?;
1529                }
1530            }
1531            _ => {}
1532        }
1533
1534        Ok(())
1535    }
1536
1537    /// Parses a JSON Schema `$ref` path to extract the definition name.
1538    ///
1539    /// JSON Schema references use URI fragment syntax to point to definitions within
1540    /// the same document. This function extracts the definition name from common
1541    /// reference patterns used in JSON Schema.
1542    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1543        if let Some(fragment) = ref_str.strip_prefix('#') {
1544            if let Some(name) = fragment.strip_prefix("/$defs/") {
1545                Ok(name.to_string())
1546            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1547                Ok(name.to_string())
1548            } else {
1549                Err(CompletionError::ResponseError(format!(
1550                    "Unsupported reference format: {}",
1551                    ref_str
1552                )))
1553            }
1554        } else {
1555            Err(CompletionError::ResponseError(format!(
1556                "Only fragment references (#/...) are supported: {}",
1557                ref_str
1558            )))
1559        }
1560    }
1561
1562    /// Helper function to extract the type string from a JSON value.
1563    /// Handles both direct string types and array types (returns the first element).
1564    fn extract_type(type_value: &Value) -> Option<String> {
1565        if type_value.is_string() {
1566            type_value.as_str().map(String::from)
1567        } else if type_value.is_array() {
1568            type_value
1569                .as_array()
1570                .and_then(|arr| arr.first())
1571                .and_then(|v| v.as_str().map(String::from))
1572        } else {
1573            None
1574        }
1575    }
1576
1577    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1578    /// Returns the type of the first non-null schema found.
1579    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1580        composition.as_array().and_then(|arr| {
1581            arr.iter().find_map(|schema| {
1582                if let Some(obj) = schema.as_object() {
1583                    // Skip null types
1584                    if let Some(type_val) = obj.get("type")
1585                        && let Some(type_str) = type_val.as_str()
1586                        && type_str == "null"
1587                    {
1588                        return None;
1589                    }
1590                    // Extract type from this schema
1591                    obj.get("type").and_then(extract_type).or_else(|| {
1592                        if obj.contains_key("properties") {
1593                            Some("object".to_string())
1594                        } else {
1595                            None
1596                        }
1597                    })
1598                } else {
1599                    None
1600                }
1601            })
1602        })
1603    }
1604
1605    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1606    /// Returns the schema object that should be used for properties, required, etc.
1607    fn extract_schema_from_composition(
1608        composition: &Value,
1609    ) -> Option<serde_json::Map<String, Value>> {
1610        composition.as_array().and_then(|arr| {
1611            arr.iter().find_map(|schema| {
1612                if let Some(obj) = schema.as_object()
1613                    && let Some(type_val) = obj.get("type")
1614                    && let Some(type_str) = type_val.as_str()
1615                {
1616                    if type_str == "null" {
1617                        return None;
1618                    }
1619                    Some(obj.clone())
1620                } else {
1621                    None
1622                }
1623            })
1624        })
1625    }
1626
1627    /// Helper function to infer the type of a schema object.
1628    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1629    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1630        // First, try direct type field
1631        if let Some(type_val) = obj.get("type")
1632            && let Some(type_str) = extract_type(type_val)
1633        {
1634            return type_str;
1635        }
1636
1637        // Then try anyOf, oneOf, allOf (in that order)
1638        if let Some(any_of) = obj.get("anyOf")
1639            && let Some(type_str) = extract_type_from_composition(any_of)
1640        {
1641            return type_str;
1642        }
1643
1644        if let Some(one_of) = obj.get("oneOf")
1645            && let Some(type_str) = extract_type_from_composition(one_of)
1646        {
1647            return type_str;
1648        }
1649
1650        if let Some(all_of) = obj.get("allOf")
1651            && let Some(type_str) = extract_type_from_composition(all_of)
1652        {
1653            return type_str;
1654        }
1655
1656        // Finally, infer object type if properties are present
1657        if obj.contains_key("properties") {
1658            "object".to_string()
1659        } else {
1660            String::new()
1661        }
1662    }
1663
1664    impl TryFrom<Value> for Schema {
1665        type Error = CompletionError;
1666
1667        fn try_from(value: Value) -> Result<Self, Self::Error> {
1668            let flattened_val = flatten_schema(value)?;
1669            if let Some(obj) = flattened_val.as_object() {
1670                // Determine which object to use for extracting properties and required fields.
1671                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1672                let props_source = if obj.get("properties").is_none() {
1673                    if let Some(any_of) = obj.get("anyOf") {
1674                        extract_schema_from_composition(any_of)
1675                    } else if let Some(one_of) = obj.get("oneOf") {
1676                        extract_schema_from_composition(one_of)
1677                    } else if let Some(all_of) = obj.get("allOf") {
1678                        extract_schema_from_composition(all_of)
1679                    } else {
1680                        None
1681                    }
1682                    .unwrap_or(obj.clone())
1683                } else {
1684                    obj.clone()
1685                };
1686
1687                Ok(Schema {
1688                    r#type: infer_type(obj),
1689                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1690                    description: obj
1691                        .get("description")
1692                        .and_then(|v| v.as_str())
1693                        .map(String::from),
1694                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1695                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1696                        arr.iter()
1697                            .filter_map(|v| v.as_str().map(String::from))
1698                            .collect()
1699                    }),
1700                    max_items: obj
1701                        .get("maxItems")
1702                        .and_then(|v| v.as_i64())
1703                        .map(|v| v as i32),
1704                    min_items: obj
1705                        .get("minItems")
1706                        .and_then(|v| v.as_i64())
1707                        .map(|v| v as i32),
1708                    properties: props_source
1709                        .get("properties")
1710                        .and_then(|v| v.as_object())
1711                        .map(|map| {
1712                            map.iter()
1713                                .filter_map(|(k, v)| {
1714                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1715                                })
1716                                .collect()
1717                        }),
1718                    required: props_source
1719                        .get("required")
1720                        .and_then(|v| v.as_array())
1721                        .map(|arr| {
1722                            arr.iter()
1723                                .filter_map(|v| v.as_str().map(String::from))
1724                                .collect()
1725                        }),
1726                    items: obj
1727                        .get("items")
1728                        .and_then(|v| v.clone().try_into().ok())
1729                        .map(Box::new),
1730                })
1731            } else {
1732                Err(CompletionError::ResponseError(
1733                    "Expected a JSON object for Schema".into(),
1734                ))
1735            }
1736        }
1737    }
1738
1739    #[derive(Debug, Serialize)]
1740    #[serde(rename_all = "camelCase")]
1741    pub struct GenerateContentRequest {
1742        pub contents: Vec<Content>,
1743        #[serde(skip_serializing_if = "Option::is_none")]
1744        pub tools: Option<Vec<Tool>>,
1745        pub tool_config: Option<ToolConfig>,
1746        /// Optional. Configuration options for model generation and outputs.
1747        pub generation_config: Option<GenerationConfig>,
1748        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1749        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1750        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1751        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1752        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1753        /// will use the default safety setting for that category. Harm categories:
1754        ///     - HARM_CATEGORY_HATE_SPEECH,
1755        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1756        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1757        ///     - HARM_CATEGORY_HARASSMENT
1758        /// are supported.
1759        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1760        /// to learn how to incorporate safety considerations in your AI applications.
1761        pub safety_settings: Option<Vec<SafetySetting>>,
1762        /// Optional. Developer set system instruction(s). Currently, text only.
1763        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1764        pub system_instruction: Option<Content>,
1765        // cachedContent: Optional<String>
1766        /// Additional parameters.
1767        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1768        pub additional_params: Option<serde_json::Value>,
1769    }
1770
1771    #[derive(Debug, Serialize)]
1772    #[serde(rename_all = "camelCase")]
1773    pub struct Tool {
1774        pub function_declarations: Vec<FunctionDeclaration>,
1775        pub code_execution: Option<CodeExecution>,
1776    }
1777
1778    #[derive(Debug, Serialize, Clone)]
1779    #[serde(rename_all = "camelCase")]
1780    pub struct FunctionDeclaration {
1781        pub name: String,
1782        pub description: String,
1783        #[serde(skip_serializing_if = "Option::is_none")]
1784        pub parameters: Option<Schema>,
1785    }
1786
1787    #[derive(Debug, Serialize, Deserialize)]
1788    #[serde(rename_all = "camelCase")]
1789    pub struct ToolConfig {
1790        pub function_calling_config: Option<FunctionCallingMode>,
1791    }
1792
1793    #[derive(Debug, Serialize, Deserialize, Default)]
1794    #[serde(tag = "mode", rename_all = "UPPERCASE")]
1795    pub enum FunctionCallingMode {
1796        #[default]
1797        Auto,
1798        None,
1799        Any {
1800            #[serde(skip_serializing_if = "Option::is_none")]
1801            allowed_function_names: Option<Vec<String>>,
1802        },
1803    }
1804
1805    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1806        type Error = CompletionError;
1807        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1808            let res = match value {
1809                message::ToolChoice::Auto => Self::Auto,
1810                message::ToolChoice::None => Self::None,
1811                message::ToolChoice::Required => Self::Any {
1812                    allowed_function_names: None,
1813                },
1814                message::ToolChoice::Specific { function_names } => Self::Any {
1815                    allowed_function_names: Some(function_names),
1816                },
1817            };
1818
1819            Ok(res)
1820        }
1821    }
1822
1823    #[derive(Debug, Serialize)]
1824    pub struct CodeExecution {}
1825
1826    #[derive(Debug, Serialize)]
1827    #[serde(rename_all = "camelCase")]
1828    pub struct SafetySetting {
1829        pub category: HarmCategory,
1830        pub threshold: HarmBlockThreshold,
1831    }
1832
1833    #[derive(Debug, Serialize)]
1834    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1835    pub enum HarmBlockThreshold {
1836        HarmBlockThresholdUnspecified,
1837        BlockLowAndAbove,
1838        BlockMediumAndAbove,
1839        BlockOnlyHigh,
1840        BlockNone,
1841        Off,
1842    }
1843}
1844
1845#[cfg(test)]
1846mod tests {
1847    use crate::{
1848        message,
1849        providers::gemini::completion::gemini_api_types::{
1850            ContentCandidate, FinishReason, flatten_schema,
1851        },
1852    };
1853
1854    use super::*;
1855    use serde_json::json;
1856
1857    #[test]
1858    fn test_resolve_request_model_uses_override() {
1859        let request = CompletionRequest {
1860            model: Some("gemini-2.5-flash".to_string()),
1861            preamble: None,
1862            chat_history: crate::OneOrMany::one("Hello".into()),
1863            documents: vec![],
1864            tools: vec![],
1865            temperature: None,
1866            max_tokens: None,
1867            tool_choice: None,
1868            additional_params: None,
1869            output_schema: None,
1870        };
1871
1872        let request_model = resolve_request_model("gemini-2.0-flash", &request);
1873        assert_eq!(request_model, "gemini-2.5-flash");
1874        assert_eq!(
1875            completion_endpoint(&request_model),
1876            "/v1beta/models/gemini-2.5-flash:generateContent"
1877        );
1878        assert_eq!(
1879            streaming_endpoint(&request_model),
1880            "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
1881        );
1882    }
1883
1884    #[test]
1885    fn test_resolve_request_model_uses_default_when_unset() {
1886        let request = CompletionRequest {
1887            model: None,
1888            preamble: None,
1889            chat_history: crate::OneOrMany::one("Hello".into()),
1890            documents: vec![],
1891            tools: vec![],
1892            temperature: None,
1893            max_tokens: None,
1894            tool_choice: None,
1895            additional_params: None,
1896            output_schema: None,
1897        };
1898
1899        assert_eq!(
1900            resolve_request_model("gemini-2.0-flash", &request),
1901            "gemini-2.0-flash"
1902        );
1903    }
1904
1905    #[test]
1906    fn test_deserialize_message_user() {
1907        let raw_message = r#"{
1908            "parts": [
1909                {"text": "Hello, world!"},
1910                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1911                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1912                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1913                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1914                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1915                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1916            ],
1917            "role": "user"
1918        }"#;
1919
1920        let content: Content = {
1921            let jd = &mut serde_json::Deserializer::from_str(raw_message);
1922            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1923                panic!("Deserialization error at {}: {}", err.path(), err);
1924            })
1925        };
1926        assert_eq!(content.role, Some(Role::User));
1927        assert_eq!(content.parts.len(), 7);
1928
1929        let parts: Vec<Part> = content.parts.into_iter().collect();
1930
1931        if let Part {
1932            part: PartKind::Text(text),
1933            ..
1934        } = &parts[0]
1935        {
1936            assert_eq!(text, "Hello, world!");
1937        } else {
1938            panic!("Expected text part");
1939        }
1940
1941        if let Part {
1942            part: PartKind::InlineData(inline_data),
1943            ..
1944        } = &parts[1]
1945        {
1946            assert_eq!(inline_data.mime_type, "image/png");
1947            assert_eq!(inline_data.data, "base64encodeddata");
1948        } else {
1949            panic!("Expected inline data part");
1950        }
1951
1952        if let Part {
1953            part: PartKind::FunctionCall(function_call),
1954            ..
1955        } = &parts[2]
1956        {
1957            assert_eq!(function_call.name, "test_function");
1958            assert_eq!(
1959                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1960                "value1"
1961            );
1962        } else {
1963            panic!("Expected function call part");
1964        }
1965
1966        if let Part {
1967            part: PartKind::FunctionResponse(function_response),
1968            ..
1969        } = &parts[3]
1970        {
1971            assert_eq!(function_response.name, "test_function");
1972            assert_eq!(
1973                function_response
1974                    .response
1975                    .as_ref()
1976                    .unwrap()
1977                    .get("result")
1978                    .unwrap(),
1979                "success"
1980            );
1981        } else {
1982            panic!("Expected function response part");
1983        }
1984
1985        if let Part {
1986            part: PartKind::FileData(file_data),
1987            ..
1988        } = &parts[4]
1989        {
1990            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1991            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1992        } else {
1993            panic!("Expected file data part");
1994        }
1995
1996        if let Part {
1997            part: PartKind::ExecutableCode(executable_code),
1998            ..
1999        } = &parts[5]
2000        {
2001            assert_eq!(executable_code.code, "print('Hello, world!')");
2002        } else {
2003            panic!("Expected executable code part");
2004        }
2005
2006        if let Part {
2007            part: PartKind::CodeExecutionResult(code_execution_result),
2008            ..
2009        } = &parts[6]
2010        {
2011            assert_eq!(
2012                code_execution_result.clone().output.unwrap(),
2013                "Hello, world!"
2014            );
2015        } else {
2016            panic!("Expected code execution result part");
2017        }
2018    }
2019
2020    #[test]
2021    fn test_deserialize_message_model() {
2022        let json_data = json!({
2023            "parts": [{"text": "Hello, user!"}],
2024            "role": "model"
2025        });
2026
2027        let content: Content = serde_json::from_value(json_data).unwrap();
2028        assert_eq!(content.role, Some(Role::Model));
2029        assert_eq!(content.parts.len(), 1);
2030        if let Some(Part {
2031            part: PartKind::Text(text),
2032            ..
2033        }) = content.parts.first()
2034        {
2035            assert_eq!(text, "Hello, user!");
2036        } else {
2037            panic!("Expected text part");
2038        }
2039    }
2040
2041    #[test]
2042    fn test_message_conversion_user() {
2043        let msg = message::Message::user("Hello, world!");
2044        let content: Content = msg.try_into().unwrap();
2045        assert_eq!(content.role, Some(Role::User));
2046        assert_eq!(content.parts.len(), 1);
2047        if let Some(Part {
2048            part: PartKind::Text(text),
2049            ..
2050        }) = &content.parts.first()
2051        {
2052            assert_eq!(text, "Hello, world!");
2053        } else {
2054            panic!("Expected text part");
2055        }
2056    }
2057
2058    #[test]
2059    fn test_message_conversion_model() {
2060        let msg = message::Message::assistant("Hello, user!");
2061
2062        let content: Content = msg.try_into().unwrap();
2063        assert_eq!(content.role, Some(Role::Model));
2064        assert_eq!(content.parts.len(), 1);
2065        if let Some(Part {
2066            part: PartKind::Text(text),
2067            ..
2068        }) = &content.parts.first()
2069        {
2070            assert_eq!(text, "Hello, user!");
2071        } else {
2072            panic!("Expected text part");
2073        }
2074    }
2075
2076    #[test]
2077    fn test_thought_signature_is_preserved_from_response_reasoning_part() {
2078        let response = GenerateContentResponse {
2079            response_id: "resp_1".to_string(),
2080            candidates: vec![ContentCandidate {
2081                content: Some(Content {
2082                    parts: vec![Part {
2083                        thought: Some(true),
2084                        thought_signature: Some("thought_sig_123".to_string()),
2085                        part: PartKind::Text("thinking text".to_string()),
2086                        additional_params: None,
2087                    }],
2088                    role: Some(Role::Model),
2089                }),
2090                finish_reason: Some(FinishReason::Stop),
2091                safety_ratings: None,
2092                citation_metadata: None,
2093                token_count: None,
2094                avg_logprobs: None,
2095                logprobs_result: None,
2096                index: Some(0),
2097                finish_message: None,
2098            }],
2099            prompt_feedback: None,
2100            usage_metadata: None,
2101            model_version: None,
2102        };
2103
2104        let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2105            response.try_into().expect("convert response");
2106        let first = converted.choice.first();
2107        assert!(matches!(
2108            first,
2109            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2110                if matches!(
2111                    content.first(),
2112                    Some(message::ReasoningContent::Text {
2113                        text,
2114                        signature: Some(signature)
2115                    }) if text == "thinking text" && signature == "thought_sig_123"
2116                )
2117        ));
2118    }
2119
2120    #[test]
2121    fn test_reasoning_signature_is_emitted_in_gemini_part() {
2122        let msg = message::Message::Assistant {
2123            id: None,
2124            content: OneOrMany::one(message::AssistantContent::Reasoning(
2125                message::Reasoning::new_with_signature(
2126                    "structured thought",
2127                    Some("reuse_sig_456".to_string()),
2128                ),
2129            )),
2130        };
2131
2132        let converted: Content = msg.try_into().expect("convert message");
2133        let first = converted.parts.first().expect("reasoning part");
2134        assert_eq!(first.thought, Some(true));
2135        assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
2136        assert!(matches!(
2137            &first.part,
2138            PartKind::Text(text) if text == "structured thought"
2139        ));
2140    }
2141
2142    #[test]
2143    fn test_message_conversion_tool_call() {
2144        let tool_call = message::ToolCall {
2145            id: "test_tool".to_string(),
2146            call_id: None,
2147            function: message::ToolFunction {
2148                name: "test_function".to_string(),
2149                arguments: json!({"arg1": "value1"}),
2150            },
2151            signature: None,
2152            additional_params: None,
2153        };
2154
2155        let msg = message::Message::Assistant {
2156            id: None,
2157            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2158        };
2159
2160        let content: Content = msg.try_into().unwrap();
2161        assert_eq!(content.role, Some(Role::Model));
2162        assert_eq!(content.parts.len(), 1);
2163        if let Some(Part {
2164            part: PartKind::FunctionCall(function_call),
2165            ..
2166        }) = content.parts.first()
2167        {
2168            assert_eq!(function_call.name, "test_function");
2169            assert_eq!(
2170                function_call.args.as_object().unwrap().get("arg1").unwrap(),
2171                "value1"
2172            );
2173        } else {
2174            panic!("Expected function call part");
2175        }
2176    }
2177
2178    #[test]
2179    fn test_vec_schema_conversion() {
2180        let schema_with_ref = json!({
2181            "type": "array",
2182            "items": {
2183                "$ref": "#/$defs/Person"
2184            },
2185            "$defs": {
2186                "Person": {
2187                    "type": "object",
2188                    "properties": {
2189                        "first_name": {
2190                            "type": ["string", "null"],
2191                            "description": "The person's first name, if provided (null otherwise)"
2192                        },
2193                        "last_name": {
2194                            "type": ["string", "null"],
2195                            "description": "The person's last name, if provided (null otherwise)"
2196                        },
2197                        "job": {
2198                            "type": ["string", "null"],
2199                            "description": "The person's job, if provided (null otherwise)"
2200                        }
2201                    },
2202                    "required": []
2203                }
2204            }
2205        });
2206
2207        let result: Result<Schema, _> = schema_with_ref.try_into();
2208
2209        match result {
2210            Ok(schema) => {
2211                assert_eq!(schema.r#type, "array");
2212
2213                if let Some(items) = schema.items {
2214                    println!("item types: {}", items.r#type);
2215
2216                    assert_ne!(items.r#type, "", "Items type should not be empty string!");
2217                    assert_eq!(items.r#type, "object", "Items should be object type");
2218                } else {
2219                    panic!("Schema should have items field for array type");
2220                }
2221            }
2222            Err(e) => println!("Schema conversion failed: {:?}", e),
2223        }
2224    }
2225
2226    #[test]
2227    fn test_object_schema() {
2228        let simple_schema = json!({
2229            "type": "object",
2230            "properties": {
2231                "name": {
2232                    "type": "string"
2233                }
2234            }
2235        });
2236
2237        let schema: Schema = simple_schema.try_into().unwrap();
2238        assert_eq!(schema.r#type, "object");
2239        assert!(schema.properties.is_some());
2240    }
2241
2242    #[test]
2243    fn test_array_with_inline_items() {
2244        let inline_schema = json!({
2245            "type": "array",
2246            "items": {
2247                "type": "object",
2248                "properties": {
2249                    "name": {
2250                        "type": "string"
2251                    }
2252                }
2253            }
2254        });
2255
2256        let schema: Schema = inline_schema.try_into().unwrap();
2257        assert_eq!(schema.r#type, "array");
2258
2259        if let Some(items) = schema.items {
2260            assert_eq!(items.r#type, "object");
2261            assert!(items.properties.is_some());
2262        } else {
2263            panic!("Schema should have items field");
2264        }
2265    }
2266    #[test]
2267    fn test_flattened_schema() {
2268        let ref_schema = json!({
2269            "type": "array",
2270            "items": {
2271                "$ref": "#/$defs/Person"
2272            },
2273            "$defs": {
2274                "Person": {
2275                    "type": "object",
2276                    "properties": {
2277                        "name": { "type": "string" }
2278                    }
2279                }
2280            }
2281        });
2282
2283        let flattened = flatten_schema(ref_schema).unwrap();
2284        let schema: Schema = flattened.try_into().unwrap();
2285
2286        assert_eq!(schema.r#type, "array");
2287
2288        if let Some(items) = schema.items {
2289            println!("Flattened items type: '{}'", items.r#type);
2290
2291            assert_eq!(items.r#type, "object");
2292            assert!(items.properties.is_some());
2293        }
2294    }
2295
2296    #[test]
2297    fn test_tool_result_with_image_content() {
2298        // Test that a ToolResult with image content converts correctly to Gemini's Part format
2299        use crate::OneOrMany;
2300        use crate::message::{
2301            DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2302        };
2303
2304        // Create a tool result with both text and image content
2305        let tool_result = ToolResult {
2306            id: "test_tool".to_string(),
2307            call_id: None,
2308            content: OneOrMany::many(vec![
2309                ToolResultContent::Text(message::Text {
2310                    text: r#"{"status": "success"}"#.to_string(),
2311                }),
2312                ToolResultContent::Image(Image {
2313                    data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
2314                    media_type: Some(ImageMediaType::PNG),
2315                    detail: None,
2316                    additional_params: None,
2317                }),
2318            ]).expect("Should create OneOrMany with multiple items"),
2319        };
2320
2321        let user_content = message::UserContent::ToolResult(tool_result);
2322        let msg = message::Message::User {
2323            content: OneOrMany::one(user_content),
2324        };
2325
2326        // Convert to Gemini Content
2327        let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2328
2329        assert_eq!(content.role, Some(Role::User));
2330        assert_eq!(content.parts.len(), 1);
2331
2332        // Verify the part is a FunctionResponse with both response and parts
2333        if let Some(Part {
2334            part: PartKind::FunctionResponse(function_response),
2335            ..
2336        }) = content.parts.first()
2337        {
2338            assert_eq!(function_response.name, "test_tool");
2339
2340            // Check that response JSON is present
2341            assert!(function_response.response.is_some());
2342            let response = function_response.response.as_ref().unwrap();
2343            assert!(response.get("result").is_some());
2344
2345            // Check that parts with image data are present
2346            assert!(function_response.parts.is_some());
2347            let parts = function_response.parts.as_ref().unwrap();
2348            assert_eq!(parts.len(), 1);
2349
2350            let image_part = &parts[0];
2351            assert!(image_part.inline_data.is_some());
2352            let inline_data = image_part.inline_data.as_ref().unwrap();
2353            assert_eq!(inline_data.mime_type, "image/png");
2354            assert!(!inline_data.data.is_empty());
2355        } else {
2356            panic!("Expected FunctionResponse part");
2357        }
2358    }
2359
2360    #[test]
2361    fn test_tool_result_with_url_image() {
2362        // Test that a ToolResult with a URL-based image converts to file_data
2363        use crate::OneOrMany;
2364        use crate::message::{
2365            DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2366        };
2367
2368        let tool_result = ToolResult {
2369            id: "screenshot_tool".to_string(),
2370            call_id: None,
2371            content: OneOrMany::one(ToolResultContent::Image(Image {
2372                data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
2373                media_type: Some(ImageMediaType::PNG),
2374                detail: None,
2375                additional_params: None,
2376            })),
2377        };
2378
2379        let user_content = message::UserContent::ToolResult(tool_result);
2380        let msg = message::Message::User {
2381            content: OneOrMany::one(user_content),
2382        };
2383
2384        let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2385
2386        assert_eq!(content.role, Some(Role::User));
2387        assert_eq!(content.parts.len(), 1);
2388
2389        if let Some(Part {
2390            part: PartKind::FunctionResponse(function_response),
2391            ..
2392        }) = content.parts.first()
2393        {
2394            assert_eq!(function_response.name, "screenshot_tool");
2395
2396            // URL images should have parts with file_data
2397            assert!(function_response.parts.is_some());
2398            let parts = function_response.parts.as_ref().unwrap();
2399            assert_eq!(parts.len(), 1);
2400
2401            let image_part = &parts[0];
2402            assert!(image_part.file_data.is_some());
2403            let file_data = image_part.file_data.as_ref().unwrap();
2404            assert_eq!(file_data.file_uri, "https://example.com/image.png");
2405            assert_eq!(file_data.mime_type.as_ref().unwrap(), "image/png");
2406        } else {
2407            panic!("Expected FunctionResponse part");
2408        }
2409    }
2410
2411    #[test]
2412    fn test_from_tool_output_parses_image_json() {
2413        // Test the ToolResultContent::from_tool_output helper with image JSON
2414        use crate::message::{DocumentSourceKind, ToolResultContent};
2415
2416        // Test simple image JSON format
2417        let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
2418        let result = ToolResultContent::from_tool_output(image_json);
2419
2420        assert_eq!(result.len(), 1);
2421        if let ToolResultContent::Image(img) = result.first() {
2422            assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2423            if let DocumentSourceKind::Base64(data) = &img.data {
2424                assert_eq!(data, "base64data==");
2425            }
2426            assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
2427        } else {
2428            panic!("Expected Image content");
2429        }
2430    }
2431
2432    #[test]
2433    fn test_from_tool_output_parses_hybrid_json() {
2434        // Test the ToolResultContent::from_tool_output helper with hybrid response/parts format
2435        use crate::message::{DocumentSourceKind, ToolResultContent};
2436
2437        let hybrid_json = r#"{
2438            "response": {"status": "ok", "count": 42},
2439            "parts": [
2440                {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
2441                {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
2442            ]
2443        }"#;
2444
2445        let result = ToolResultContent::from_tool_output(hybrid_json);
2446
2447        // Should have 3 items: 1 text (response) + 2 images (parts)
2448        assert_eq!(result.len(), 3);
2449
2450        let items: Vec<_> = result.iter().collect();
2451
2452        // First should be text with the response JSON
2453        if let ToolResultContent::Text(text) = &items[0] {
2454            assert!(text.text.contains("status"));
2455            assert!(text.text.contains("ok"));
2456        } else {
2457            panic!("Expected Text content first");
2458        }
2459
2460        // Second should be base64 image
2461        if let ToolResultContent::Image(img) = &items[1] {
2462            assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2463        } else {
2464            panic!("Expected Image content second");
2465        }
2466
2467        // Third should be URL image
2468        if let ToolResultContent::Image(img) = &items[2] {
2469            assert!(matches!(img.data, DocumentSourceKind::Url(_)));
2470        } else {
2471            panic!("Expected Image content third");
2472        }
2473    }
2474
2475    /// E2E test that verifies Gemini can process tool results containing images.
2476    /// This test creates an agent with a tool that returns an image, invokes it,
2477    /// and verifies that Gemini can interpret the image in the tool result.
2478    #[tokio::test]
2479    #[ignore = "requires GEMINI_API_KEY environment variable"]
2480    async fn test_gemini_agent_with_image_tool_result_e2e() {
2481        use crate::completion::{Prompt, ToolDefinition};
2482        use crate::prelude::*;
2483        use crate::providers::gemini;
2484        use crate::tool::Tool;
2485        use serde::{Deserialize, Serialize};
2486
2487        /// A tool that returns a small red 1x1 pixel PNG image
2488        #[derive(Debug, Serialize, Deserialize)]
2489        struct ImageGeneratorTool;
2490
2491        #[derive(Debug, thiserror::Error)]
2492        #[error("Image generation error")]
2493        struct ImageToolError;
2494
2495        impl Tool for ImageGeneratorTool {
2496            const NAME: &'static str = "generate_test_image";
2497            type Error = ImageToolError;
2498            type Args = serde_json::Value;
2499            // Return the image in the format that from_tool_output expects
2500            type Output = String;
2501
2502            async fn definition(&self, _prompt: String) -> ToolDefinition {
2503                ToolDefinition {
2504                    name: "generate_test_image".to_string(),
2505                    description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
2506                    parameters: json!({
2507                        "type": "object",
2508                        "properties": {},
2509                        "required": []
2510                    }),
2511                }
2512            }
2513
2514            async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
2515                // Return a JSON object that from_tool_output will parse as an image
2516                // This is a 1x1 red PNG pixel
2517                Ok(json!({
2518                    "type": "image",
2519                    "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
2520                    "mimeType": "image/png"
2521                }).to_string())
2522            }
2523        }
2524
2525        let client = gemini::Client::from_env();
2526
2527        let agent = client
2528            .agent("gemini-3-flash-preview")
2529            .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
2530            .tool(ImageGeneratorTool)
2531            .build();
2532
2533        // This prompt should trigger the tool, which returns an image that Gemini should process
2534        let response = agent
2535            .prompt("Please generate a test image and tell me what color the pixel is.")
2536            .await;
2537
2538        // The test passes if Gemini successfully processes the request without errors.
2539        // The image is a 1x1 red pixel, so Gemini should be able to describe it.
2540        assert!(
2541            response.is_ok(),
2542            "Gemini should successfully process tool result with image: {:?}",
2543            response.err()
2544        );
2545
2546        let response_text = response.unwrap();
2547        println!("Response: {response_text}");
2548        // Gemini should have been able to see the image and potentially describe its color
2549        assert!(!response_text.is_empty(), "Response should not be empty");
2550    }
2551}