rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
// ---- Gemini 2.5 Pro -------------------------------------------------------

/// Model ID for the Gemini 2.5 Pro preview tagged `06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model ID for the Gemini 2.5 Pro preview tagged `05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model ID for the Gemini 2.5 Pro preview tagged `03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model ID for the experimental Gemini 2.5 Pro build tagged `03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";

// ---- Gemini 2.5 Flash -----------------------------------------------------

/// Model ID for the Gemini 2.5 Flash preview tagged `04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model ID for Gemini 2.5 Flash.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";

// ---- Gemini 2.0 -----------------------------------------------------------

/// Model ID for Gemini 2.0 Flash.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model ID for Gemini 2.0 Flash-Lite.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27    AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32    OneOrMany,
33    completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37    Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
46// =================================================================
47// Rig Implementation Types
48// =================================================================
49
/// Gemini completion model: a provider [`Client`] paired with the model name to call.
///
/// `T` is the HTTP transport type; the `completion::CompletionModel` impl requires
/// `T: HttpClientExt + Clone + 'static`. It defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests to the Gemini API.
    pub(crate) client: Client<T>,
    /// Model identifier, interpolated into `/v1beta/models/{model}:generateContent`.
    pub model: String,
}
55
56impl<T> CompletionModel<T> {
57    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58        Self {
59            client,
60            model: model.into(),
61        }
62    }
63
64    pub fn with_model(client: Client<T>, model: &str) -> Self {
65        Self {
66            client,
67            model: model.into(),
68        }
69    }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74    T: HttpClientExt + Clone + 'static,
75{
76    type Response = GenerateContentResponse;
77    type StreamingResponse = StreamingCompletionResponse;
78    type Client = super::Client<T>;
79
80    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81        Self::new(client.clone(), model)
82    }
83
84    async fn completion(
85        &self,
86        completion_request: CompletionRequest,
87    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88        let span = if tracing::Span::current().is_disabled() {
89            info_span!(
90                target: "rig::completions",
91                "generate_content",
92                gen_ai.operation.name = "generate_content",
93                gen_ai.provider.name = "gcp.gemini",
94                gen_ai.request.model = self.model,
95                gen_ai.system_instructions = &completion_request.preamble,
96                gen_ai.response.id = tracing::field::Empty,
97                gen_ai.response.model = tracing::field::Empty,
98                gen_ai.usage.output_tokens = tracing::field::Empty,
99                gen_ai.usage.input_tokens = tracing::field::Empty,
100            )
101        } else {
102            tracing::Span::current()
103        };
104
105        let request = create_request_body(completion_request)?;
106
107        if enabled!(Level::TRACE) {
108            tracing::trace!(
109                target: "rig::completions",
110                "Gemini completion request: {}",
111                serde_json::to_string_pretty(&request)?
112            );
113        }
114
115        let body = serde_json::to_vec(&request)?;
116
117        let path = format!("/v1beta/models/{}:generateContent", self.model);
118
119        let request = self
120            .client
121            .post(path.as_str())?
122            .body(body)
123            .map_err(|e| CompletionError::HttpError(e.into()))?;
124
125        async move {
126            let response = self.client.send::<_, Vec<u8>>(request).await?;
127
128            if response.status().is_success() {
129                let response_body = response
130                    .into_body()
131                    .await
132                    .map_err(CompletionError::HttpError)?;
133
134                let response_text = String::from_utf8_lossy(&response_body).to_string();
135
136                let response: GenerateContentResponse = serde_json::from_slice(&response_body)
137                    .map_err(|err| {
138                        tracing::error!(
139                            error = %err,
140                            body = %response_text,
141                            "Failed to deserialize Gemini completion response"
142                        );
143                        CompletionError::JsonError(err)
144                    })?;
145
146                let span = tracing::Span::current();
147                span.record_response_metadata(&response);
148                span.record_token_usage(&response.usage_metadata);
149
150                if enabled!(Level::TRACE) {
151                    tracing::trace!(
152                        target: "rig::completions",
153                        "Gemini completion response: {}",
154                        serde_json::to_string_pretty(&response)?
155                    );
156                }
157
158                response.try_into()
159            } else {
160                let text = String::from_utf8_lossy(
161                    &response
162                        .into_body()
163                        .await
164                        .map_err(CompletionError::HttpError)?,
165                )
166                .into();
167
168                Err(CompletionError::ProviderError(text))
169            }
170        }
171        .instrument(span)
172        .await
173    }
174
175    async fn stream(
176        &self,
177        request: CompletionRequest,
178    ) -> Result<
179        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
180        CompletionError,
181    > {
182        CompletionModel::stream(self, request).await
183    }
184}
185
186pub(crate) fn create_request_body(
187    completion_request: CompletionRequest,
188) -> Result<GenerateContentRequest, CompletionError> {
189    let mut full_history = Vec::new();
190    full_history.extend(completion_request.chat_history);
191
192    let additional_params = completion_request
193        .additional_params
194        .unwrap_or_else(|| Value::Object(Map::new()));
195
196    let AdditionalParameters {
197        mut generation_config,
198        additional_params,
199    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
200
201    generation_config = generation_config.map(|mut cfg| {
202        if let Some(temp) = completion_request.temperature {
203            cfg.temperature = Some(temp);
204        };
205
206        if let Some(max_tokens) = completion_request.max_tokens {
207            cfg.max_output_tokens = Some(max_tokens);
208        };
209
210        cfg
211    });
212
213    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
214        parts: vec![preamble.into()],
215        role: Some(Role::Model),
216    });
217
218    let tools = if completion_request.tools.is_empty() {
219        None
220    } else {
221        Some(Tool::try_from(completion_request.tools)?)
222    };
223
224    let tool_config = if let Some(cfg) = completion_request.tool_choice {
225        Some(ToolConfig {
226            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
227        })
228    } else {
229        None
230    };
231
232    let request = GenerateContentRequest {
233        contents: full_history
234            .into_iter()
235            .map(|msg| {
236                msg.try_into()
237                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
238            })
239            .collect::<Result<Vec<_>, _>>()?,
240        generation_config,
241        safety_settings: None,
242        tools,
243        tool_config,
244        system_instruction,
245        additional_params,
246    };
247
248    Ok(request)
249}
250
251impl TryFrom<completion::ToolDefinition> for Tool {
252    type Error = CompletionError;
253
254    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
255        let parameters: Option<Schema> =
256            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
257                None
258            } else {
259                Some(tool.parameters.try_into()?)
260            };
261
262        Ok(Self {
263            function_declarations: vec![FunctionDeclaration {
264                name: tool.name,
265                description: tool.description,
266                parameters,
267            }],
268            code_execution: None,
269        })
270    }
271}
272
273impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
274    type Error = CompletionError;
275
276    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
277        let mut function_declarations = Vec::new();
278
279        for tool in tools {
280            let parameters =
281                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
282                    None
283                } else {
284                    match tool.parameters.try_into() {
285                        Ok(schema) => Some(schema),
286                        Err(e) => {
287                            let emsg = format!(
288                                "Tool '{}' could not be converted to a schema: {:?}",
289                                tool.name, e,
290                            );
291                            return Err(CompletionError::ProviderError(emsg));
292                        }
293                    }
294                };
295
296            function_declarations.push(FunctionDeclaration {
297                name: tool.name,
298                description: tool.description,
299                parameters,
300            });
301        }
302
303        Ok(Self {
304            function_declarations,
305            code_execution: None,
306        })
307    }
308}
309
310impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
311    type Error = CompletionError;
312
313    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
314        let candidate = response.candidates.first().ok_or_else(|| {
315            CompletionError::ResponseError("No response candidates in response".into())
316        })?;
317
318        let content = candidate
319            .content
320            .as_ref()
321            .ok_or_else(|| {
322                let reason = candidate
323                    .finish_reason
324                    .as_ref()
325                    .map(|r| format!("finish_reason={r:?}"))
326                    .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
327                let message = candidate
328                    .finish_message
329                    .as_deref()
330                    .unwrap_or("no finish message provided");
331                CompletionError::ResponseError(format!(
332                    "Gemini candidate missing content ({reason}, finish_message={message})"
333                ))
334            })?
335            .parts
336            .iter()
337            .map(
338                |Part {
339                     thought,
340                     thought_signature,
341                     part,
342                     ..
343                 }| {
344                    Ok(match part {
345                        PartKind::Text(text) => {
346                            if let Some(thought) = thought
347                                && *thought
348                            {
349                                completion::AssistantContent::Reasoning(Reasoning::new(text))
350                            } else {
351                                completion::AssistantContent::text(text)
352                            }
353                        }
354                        PartKind::InlineData(inline_data) => {
355                            let mime_type =
356                                message::MediaType::from_mime_type(&inline_data.mime_type);
357
358                            match mime_type {
359                                Some(message::MediaType::Image(media_type)) => {
360                                    message::AssistantContent::image_base64(
361                                        &inline_data.data,
362                                        Some(media_type),
363                                        Some(message::ImageDetail::default()),
364                                    )
365                                }
366                                _ => {
367                                    return Err(CompletionError::ResponseError(format!(
368                                        "Unsupported media type {mime_type:?}"
369                                    )));
370                                }
371                            }
372                        }
373                        PartKind::FunctionCall(function_call) => {
374                            completion::AssistantContent::ToolCall(
375                                message::ToolCall::new(
376                                    function_call.name.clone(),
377                                    message::ToolFunction::new(
378                                        function_call.name.clone(),
379                                        function_call.args.clone(),
380                                    ),
381                                )
382                                .with_signature(thought_signature.clone()),
383                            )
384                        }
385                        _ => {
386                            return Err(CompletionError::ResponseError(
387                                "Response did not contain a message or tool call".into(),
388                            ));
389                        }
390                    })
391                },
392            )
393            .collect::<Result<Vec<_>, _>>()?;
394
395        let choice = OneOrMany::many(content).map_err(|_| {
396            CompletionError::ResponseError(
397                "Response contained no message or tool call (empty)".to_owned(),
398            )
399        })?;
400
401        let usage = response
402            .usage_metadata
403            .as_ref()
404            .map(|usage| completion::Usage {
405                input_tokens: usage.prompt_token_count as u64,
406                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
407                total_tokens: usage.total_token_count as u64,
408            })
409            .unwrap_or_default();
410
411        Ok(completion::CompletionResponse {
412            choice,
413            usage,
414            raw_response: response,
415        })
416    }
417}
418
419pub mod gemini_api_types {
420    use crate::telemetry::ProviderResponseExt;
421    use std::{collections::HashMap, convert::Infallible, str::FromStr};
422
423    // =================================================================
424    // Gemini API Types
425    // =================================================================
426    use serde::{Deserialize, Serialize};
427    use serde_json::{Value, json};
428
429    use crate::completion::GetTokenUsage;
430    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
431    use crate::{
432        completion::CompletionError,
433        message::{self},
434        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
435    };
436
    /// Typed view of the free-form `additional_params` JSON that Rig accepts for Gemini.
    ///
    /// `create_request_body` deserializes the request's `additional_params` into this
    /// type. `generationConfig` is pulled out into [`GenerationConfig`]. Every other
    /// top-level key is captured by the flattened `additional_params` field and passed
    /// through as `GenerateContentRequest::additional_params`.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Change your Gemini request configuration (`generationConfig`).
        /// The request's `temperature` / `max_tokens` override the matching fields.
        pub generation_config: Option<GenerationConfig>,
        /// Any additional parameters that you want: all remaining keys, captured via
        /// `#[serde(flatten)]` and omitted from serialization when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
446
447    impl AdditionalParameters {
448        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
449            self.generation_config = Some(cfg);
450            self
451        }
452
453        pub fn with_params(mut self, params: serde_json::Value) -> Self {
454            self.additional_params = Some(params);
455            self
456        }
457    }
458
459    /// Response from the model supporting multiple candidate responses.
460    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
461    /// and for each candidate in finishReason and in safetyRatings.
462    /// The API:
463    ///     - Returns either all requested candidates or none of them
464    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
465    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
466    #[derive(Debug, Deserialize, Serialize)]
467    #[serde(rename_all = "camelCase")]
468    pub struct GenerateContentResponse {
469        pub response_id: String,
470        /// Candidate responses from the model.
471        pub candidates: Vec<ContentCandidate>,
472        /// Returns the prompt's feedback related to the content filters.
473        pub prompt_feedback: Option<PromptFeedback>,
474        /// Output only. Metadata on the generation requests' token usage.
475        pub usage_metadata: Option<UsageMetadata>,
476        pub model_version: Option<String>,
477    }
478
479    impl ProviderResponseExt for GenerateContentResponse {
480        type OutputMessage = ContentCandidate;
481        type Usage = UsageMetadata;
482
483        fn get_response_id(&self) -> Option<String> {
484            Some(self.response_id.clone())
485        }
486
487        fn get_response_model_name(&self) -> Option<String> {
488            None
489        }
490
491        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
492            self.candidates.clone()
493        }
494
495        fn get_text_response(&self) -> Option<String> {
496            let str = self
497                .candidates
498                .iter()
499                .filter_map(|x| {
500                    let content = x.content.as_ref()?;
501                    if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
502                        return None;
503                    }
504
505                    let res = content
506                        .parts
507                        .iter()
508                        .filter_map(|part| {
509                            if let PartKind::Text(ref str) = part.part {
510                                Some(str.to_owned())
511                            } else {
512                                None
513                            }
514                        })
515                        .collect::<Vec<String>>()
516                        .join("\n");
517
518                    Some(res)
519                })
520                .collect::<Vec<String>>()
521                .join("\n");
522
523            if str.is_empty() { None } else { Some(str) }
524        }
525
526        fn get_usage(&self) -> Option<Self::Usage> {
527            self.usage_metadata.clone()
528        }
529    }
530
    /// A response candidate generated from the model.
    ///
    /// Converting into a Rig completion response only uses the first candidate.
    /// `get_output_messages` exposes all of them.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        ///
        /// May be absent. The response conversion then fails with an error that
        /// quotes `finish_reason` and `finish_message`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only. Presumably the candidate's average log-probability score
        /// (`avgLogprobs`); confirm against the API reference.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
        /// Output only. Additional information about why the model stopped generating tokens.
        pub finish_message: Option<String>,
    }
559
    /// One conversational turn: an ordered list of [`Part`]s plus the role that
    /// produced them.
    ///
    /// Used for chat history, the system instruction, and candidate output.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        /// Defaults to an empty list when missing during deserialization.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        /// NOTE(review): there is no `skip_serializing_if`, so `None` is serialized as `null`.
        pub role: Option<Role>,
    }
569
570    impl TryFrom<message::Message> for Content {
571        type Error = message::MessageError;
572
573        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
574            Ok(match msg {
575                message::Message::User { content } => Content {
576                    parts: content
577                        .into_iter()
578                        .map(|c| c.try_into())
579                        .collect::<Result<Vec<_>, _>>()?,
580                    role: Some(Role::User),
581                },
582                message::Message::Assistant { content, .. } => Content {
583                    role: Some(Role::Model),
584                    parts: content
585                        .into_iter()
586                        .map(|content| content.try_into())
587                        .collect::<Result<Vec<_>, _>>()?,
588                },
589            })
590        }
591    }
592
    /// Author of a [`Content`] turn, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// The end user; Rig user messages (including tool results) map here.
        User,
        /// The model; Rig assistant messages map here, as does the system instruction
        /// built by `create_request_body`.
        Model,
    }
599
    /// One element of a [`Content`] turn: a [`PartKind`] payload plus optional
    /// thinking metadata.
    ///
    /// Both `part` and `additional_params` are `#[serde(flatten)]`ed. On the wire a
    /// part is therefore a single object keyed by the payload's camelCase variant
    /// name (e.g. `{"text": "..."}`), alongside `thought` / `thoughtSignature` when
    /// they are set.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether or not the part is reasoning/thinking text. Response text parts with
        /// `Some(true)` are surfaced as reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// An opaque signature for the thought so it can be reused; presumably a
        /// base64 string. Attached to tool calls when converting responses.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The payload itself (text, inline data, function call, ...).
        #[serde(flatten)]
        pub part: PartKind,
        /// Any extra keys, flattened into the same JSON object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
614
    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    ///
    /// Serialized externally tagged with camelCase keys (`text`, `inlineData`, ...)
    /// and flattened into the enclosing [`Part`].
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Media sent inline together with its MIME type.
        InlineData(Blob),
        /// A function call issued by the model.
        FunctionCall(FunctionCall),
        /// A function's result sent back to the model; tool results map here.
        FunctionResponse(FunctionResponse),
        /// Media referenced by URI, with an optional MIME type.
        FileData(FileData),
        /// Code-execution payload; not converted into Rig content when parsing responses.
        ExecutableCode(ExecutableCode),
        /// Code-execution result; not converted into Rig content when parsing responses.
        CodeExecutionResult(CodeExecutionResult),
    }
629
630    // This default instance is primarily so we can easily fill in the optional fields of `Part`
631    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
632    impl Default for PartKind {
633        fn default() -> Self {
634            Self::Text(String::new())
635        }
636    }
637
638    impl From<String> for Part {
639        fn from(text: String) -> Self {
640            Self {
641                thought: Some(false),
642                thought_signature: None,
643                part: PartKind::Text(text),
644                additional_params: None,
645            }
646        }
647    }
648
649    impl From<&str> for Part {
650        fn from(text: &str) -> Self {
651            Self::from(text.to_string())
652        }
653    }
654
655    impl FromStr for Part {
656        type Err = Infallible;
657
658        fn from_str(s: &str) -> Result<Self, Self::Err> {
659            Ok(s.into())
660        }
661    }
662
663    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
664        type Error = message::MessageError;
665        fn try_from(
666            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
667        ) -> Result<Self, Self::Error> {
668            let mime_type = mime_type.to_mime_type().to_string();
669            let part = match doc_src {
670                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
671                    mime_type: Some(mime_type),
672                    file_uri: url,
673                }),
674                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
675                    PartKind::InlineData(Blob { mime_type, data })
676                }
677                DocumentSourceKind::Raw(_) => {
678                    return Err(message::MessageError::ConversionError(
679                        "Raw files not supported, encode as base64 first".into(),
680                    ));
681                }
682                DocumentSourceKind::Unknown => {
683                    return Err(message::MessageError::ConversionError(
684                        "Can't convert an unknown document source".to_string(),
685                    ));
686                }
687            };
688
689            Ok(part)
690        }
691    }
692
693    impl TryFrom<message::UserContent> for Part {
694        type Error = message::MessageError;
695
696        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
697            match content {
698                message::UserContent::Text(message::Text { text }) => Ok(Part {
699                    thought: Some(false),
700                    thought_signature: None,
701                    part: PartKind::Text(text),
702                    additional_params: None,
703                }),
704                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
705                    let content = match content.first() {
706                        message::ToolResultContent::Text(text) => text.text,
707                        message::ToolResultContent::Image(_) => {
708                            return Err(message::MessageError::ConversionError(
709                                "Tool result content must be text".to_string(),
710                            ));
711                        }
712                    };
713                    // Convert to JSON since this value may be a valid JSON value
714                    let result: serde_json::Value =
715                        serde_json::from_str(&content).unwrap_or_else(|error| {
716                            tracing::trace!(
717                                ?error,
718                                "Tool result is not a valid JSON, treat it as normal string"
719                            );
720                            json!(content)
721                        });
722                    Ok(Part {
723                        thought: Some(false),
724                        thought_signature: None,
725                        part: PartKind::FunctionResponse(FunctionResponse {
726                            name: id,
727                            response: Some(json!({ "result": result })),
728                        }),
729                        additional_params: None,
730                    })
731                }
732                message::UserContent::Image(message::Image {
733                    data, media_type, ..
734                }) => match media_type {
735                    Some(media_type) => match media_type {
736                        message::ImageMediaType::JPEG
737                        | message::ImageMediaType::PNG
738                        | message::ImageMediaType::WEBP
739                        | message::ImageMediaType::HEIC
740                        | message::ImageMediaType::HEIF => {
741                            let part = PartKind::try_from((media_type, data))?;
742                            Ok(Part {
743                                thought: Some(false),
744                                thought_signature: None,
745                                part,
746                                additional_params: None,
747                            })
748                        }
749                        _ => Err(message::MessageError::ConversionError(format!(
750                            "Unsupported image media type {media_type:?}"
751                        ))),
752                    },
753                    None => Err(message::MessageError::ConversionError(
754                        "Media type for image is required for Gemini".to_string(),
755                    )),
756                },
757                message::UserContent::Document(message::Document {
758                    data, media_type, ..
759                }) => {
760                    let Some(media_type) = media_type else {
761                        return Err(MessageError::ConversionError(
762                            "A mime type is required for document inputs to Gemini".to_string(),
763                        ));
764                    };
765
766                    if !media_type.is_code() {
767                        let mime_type = media_type.to_mime_type().to_string();
768
769                        let part = match data {
770                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
771                                mime_type: Some(mime_type),
772                                file_uri,
773                            }),
774                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
775                                PartKind::InlineData(Blob { mime_type, data })
776                            }
777                            DocumentSourceKind::Raw(_) => {
778                                return Err(message::MessageError::ConversionError(
779                                    "Raw files not supported, encode as base64 first".into(),
780                                ));
781                            }
782                            _ => {
783                                return Err(message::MessageError::ConversionError(
784                                    "Document has no body".to_string(),
785                                ));
786                            }
787                        };
788
789                        Ok(Part {
790                            thought: Some(false),
791                            part,
792                            ..Default::default()
793                        })
794                    } else {
795                        Err(message::MessageError::ConversionError(format!(
796                            "Unsupported document media type {media_type:?}"
797                        )))
798                    }
799                }
800
801                message::UserContent::Audio(message::Audio {
802                    data, media_type, ..
803                }) => {
804                    let Some(media_type) = media_type else {
805                        return Err(MessageError::ConversionError(
806                            "A mime type is required for audio inputs to Gemini".to_string(),
807                        ));
808                    };
809
810                    let mime_type = media_type.to_mime_type().to_string();
811
812                    let part = match data {
813                        DocumentSourceKind::Base64(data) => {
814                            PartKind::InlineData(Blob { data, mime_type })
815                        }
816
817                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
818                            mime_type: Some(mime_type),
819                            file_uri,
820                        }),
821                        DocumentSourceKind::String(_) => {
822                            return Err(message::MessageError::ConversionError(
823                                "Strings cannot be used as audio files!".into(),
824                            ));
825                        }
826                        DocumentSourceKind::Raw(_) => {
827                            return Err(message::MessageError::ConversionError(
828                                "Raw files not supported, encode as base64 first".into(),
829                            ));
830                        }
831                        DocumentSourceKind::Unknown => {
832                            return Err(message::MessageError::ConversionError(
833                                "Content has no body".to_string(),
834                            ));
835                        }
836                    };
837
838                    Ok(Part {
839                        thought: Some(false),
840                        part,
841                        ..Default::default()
842                    })
843                }
844                message::UserContent::Video(message::Video {
845                    data,
846                    media_type,
847                    additional_params,
848                    ..
849                }) => {
850                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
851
852                    let part = match data {
853                        DocumentSourceKind::Url(file_uri) => {
854                            if file_uri.starts_with("https://www.youtube.com") {
855                                PartKind::FileData(FileData {
856                                    mime_type,
857                                    file_uri,
858                                })
859                            } else {
860                                if mime_type.is_none() {
861                                    return Err(MessageError::ConversionError(
862                                        "A mime type is required for non-Youtube video file inputs to Gemini"
863                                            .to_string(),
864                                    ));
865                                }
866
867                                PartKind::FileData(FileData {
868                                    mime_type,
869                                    file_uri,
870                                })
871                            }
872                        }
873                        DocumentSourceKind::Base64(data) => {
874                            let Some(mime_type) = mime_type else {
875                                return Err(MessageError::ConversionError(
876                                    "A media type is expected for base64 encoded strings"
877                                        .to_string(),
878                                ));
879                            };
880                            PartKind::InlineData(Blob { mime_type, data })
881                        }
882                        DocumentSourceKind::String(_) => {
883                            return Err(message::MessageError::ConversionError(
884                                "Strings cannot be used as audio files!".into(),
885                            ));
886                        }
887                        DocumentSourceKind::Raw(_) => {
888                            return Err(message::MessageError::ConversionError(
889                                "Raw file data not supported, encode as base64 first".into(),
890                            ));
891                        }
892                        DocumentSourceKind::Unknown => {
893                            return Err(message::MessageError::ConversionError(
894                                "Media type for video is required for Gemini".to_string(),
895                            ));
896                        }
897                    };
898
899                    Ok(Part {
900                        thought: Some(false),
901                        thought_signature: None,
902                        part,
903                        additional_params,
904                    })
905                }
906            }
907        }
908    }
909
910    impl TryFrom<message::AssistantContent> for Part {
911        type Error = message::MessageError;
912
913        fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
914            match content {
915                message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
916                message::AssistantContent::Image(message::Image {
917                    data, media_type, ..
918                }) => match media_type {
919                    Some(media_type) => match media_type {
920                        message::ImageMediaType::JPEG
921                        | message::ImageMediaType::PNG
922                        | message::ImageMediaType::WEBP
923                        | message::ImageMediaType::HEIC
924                        | message::ImageMediaType::HEIF => {
925                            let part = PartKind::try_from((media_type, data))?;
926                            Ok(Part {
927                                thought: Some(false),
928                                thought_signature: None,
929                                part,
930                                additional_params: None,
931                            })
932                        }
933                        _ => Err(message::MessageError::ConversionError(format!(
934                            "Unsupported image media type {media_type:?}"
935                        ))),
936                    },
937                    None => Err(message::MessageError::ConversionError(
938                        "Media type for image is required for Gemini".to_string(),
939                    )),
940                },
941                message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
942                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
943                    Ok(Part {
944                        thought: Some(true),
945                        thought_signature: None,
946                        part: PartKind::Text(
947                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
948                        ),
949                        additional_params: None,
950                    })
951                }
952            }
953        }
954    }
955
956    impl From<message::ToolCall> for Part {
957        fn from(tool_call: message::ToolCall) -> Self {
958            Self {
959                thought: Some(false),
960                thought_signature: tool_call.signature,
961                part: PartKind::FunctionCall(FunctionCall {
962                    name: tool_call.function.name,
963                    args: tool_call.function.arguments,
964                }),
965                additional_params: None,
966            }
967        }
968    }
969
    /// Raw media bytes sent inline in a content part.
    /// Text should not be sent as raw bytes; use the `text` field instead.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data, e.g. `image/png` or `image/jpeg`.
        /// If an unsupported MIME type is provided, the API returns an error.
        pub mime_type: String,
        /// Raw bytes for media formats, as a base64-encoded string.
        pub data: String,
    }

    /// A predicted function call returned from the model. It holds the
    /// `FunctionDeclaration.name` to invoke together with the arguments and their values.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// The function parameters and values in JSON object format.
        /// NOTE(review): the API documents this as optional, but the field has no
        /// `#[serde(default)]`. A payload that omits `args` therefore likely fails to deserialize.
        /// Confirm whether a default (e.g. an empty object) is wanted.
        pub args: serde_json::Value,
    }
992
993    impl From<message::ToolCall> for FunctionCall {
994        fn from(tool_call: message::ToolCall) -> Self {
995            Self {
996                name: tool_call.function.name,
997                args: tool_call.function.arguments,
998            }
999        }
1000    }
1001
    /// The result of a `FunctionCall`, sent back to the model as context.
    /// It holds the `FunctionDeclaration.name` that was called and a structured JSON object
    /// with the function's output. It should contain the result of a `FunctionCall` made based
    /// on a model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function that was called. Must be a-z, A-Z, 0-9, or contain
        /// underscores and dashes, with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        /// There is no `skip_serializing_if`, so `None` is serialized as `"response": null`.
        pub response: Option<serde_json::Value>,
    }

    /// URI-based file data: the content is referenced by `file_uri` instead of being sent inline.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        /// `None` is serialized as `"mimeType": null` (no `skip_serializing_if`).
        pub mime_type: Option<String>,
        /// Required. URI of the file.
        pub file_uri: String,
    }
1023
    /// A safety rating attached to a prompt or a response candidate: a harm category paired
    /// with the probability reported for it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category this rating applies to.
        pub category: HarmCategory,
        /// The probability reported for `category`.
        pub probability: HarmProbability,
    }

    /// Probability level of a safety rating. On the wire the variants are
    /// `SCREAMING_SNAKE_CASE` (e.g. `NEGLIGIBLE`, `HIGH`).
    /// There is no `#[serde(other)]` fallback, so an unlisted value fails to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        /// `HARM_PROBABILITY_UNSPECIFIED`
        HarmProbabilityUnspecified,
        /// `NEGLIGIBLE`
        Negligible,
        /// `LOW`
        Low,
        /// `MEDIUM`
        Medium,
        /// `HIGH`
        High,
    }

    /// Harm category of a safety rating. On the wire the variants are
    /// `SCREAMING_SNAKE_CASE`, e.g. `HarmCategoryHateSpeech` is `HARM_CATEGORY_HATE_SPEECH`.
    /// There is no `#[serde(other)]` fallback, so a category not listed here fails to
    /// deserialize.
    /// NOTE(review): confirm this list is current against the Gemini API reference.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1056
    /// Token usage counters reported by the Gemini API for a request/response pair.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt. Per the Gemini API reference, this already includes
        /// any cached-content tokens.
        pub prompt_token_count: i32,
        /// Number of prompt tokens that came from cached content, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Number of tokens across the generated response candidates, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count as reported by the API.
        pub total_token_count: i32,
        /// Number of thinking tokens used by thinking models, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1069
1070    impl std::fmt::Display for UsageMetadata {
1071        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1072            write!(
1073                f,
1074                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1075                self.prompt_token_count,
1076                match self.cached_content_token_count {
1077                    Some(count) => count.to_string(),
1078                    None => "n/a".to_string(),
1079                },
1080                match self.candidates_token_count {
1081                    Some(count) => count.to_string(),
1082                    None => "n/a".to_string(),
1083                },
1084                self.total_token_count
1085            )
1086        }
1087    }
1088
1089    impl GetTokenUsage for UsageMetadata {
1090        fn token_usage(&self) -> Option<crate::completion::Usage> {
1091            let mut usage = crate::completion::Usage::new();
1092
1093            usage.input_tokens = self.prompt_token_count as u64;
1094            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1095                + self.candidates_token_count.unwrap_or_default()
1096                + self.thoughts_token_count.unwrap_or_default())
1097                as u64;
1098            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1099
1100            Some(usage)
1101        }
1102    }
1103
    /// Feedback metadata for the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Reason why a prompt was blocked by the model.
    /// There is no `#[serde(other)]` fallback, so a value not listed here fails deserialization
    /// of the enclosing response.
    /// NOTE(review): newer API versions may send additional reasons (e.g. `IMAGE_SAFETY`). Verify
    /// against the current API reference.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }

    /// Reason why the model stopped generating tokens for a candidate.
    /// There is no `#[serde(other)]` fallback, so a value not listed here fails deserialization
    /// of the enclosing response.
    /// NOTE(review): newer API versions may send additional reasons (e.g. image- or tool-call-related
    /// ones). Verify against the current API reference.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
1156
    /// Source attributions for a response candidate.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources (`citationSources` on the wire). The field is required here, so
        /// a `CitationMetadata` object without it fails to deserialize.
        pub citation_sources: Vec<CitationSource>,
    }

    /// A single attributed source. All fields are optional and omitted from serialized output
    /// when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the response, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment of the response, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1175
1176    #[derive(Clone, Debug, Deserialize, Serialize)]
1177    #[serde(rename_all = "camelCase")]
1178    pub struct LogprobsResult {
1179        pub top_candidate: Vec<TopCandidate>,
1180        pub chosen_candidate: Vec<LogProbCandidate>,
1181    }
1182
1183    #[derive(Clone, Debug, Deserialize, Serialize)]
1184    pub struct TopCandidate {
1185        pub candidates: Vec<LogProbCandidate>,
1186    }
1187
1188    #[derive(Clone, Debug, Deserialize, Serialize)]
1189    #[serde(rename_all = "camelCase")]
1190    pub struct LogProbCandidate {
1191        pub token: String,
1192        pub token_id: String,
1193        pub log_probability: f64,
1194    }
1195
    /// Gemini API configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Every field is optional and omitted from the serialized request when `None`.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types include:
        /// - `text/plain`: (default) text output.
        /// - `application/json`: JSON response in the response candidates.
        /// - `text/x.enum`: ENUM as a string response in the response candidates.
        ///
        /// Refer to the API docs for the full list of supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible `response_mime_type` must also be set. Compatible MIME
        /// types: `application/json` (schema for a JSON response). Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Optional. The output schema of the generated response, serialized under the key
        /// `_responseJsonSchema`.
        /// This is an alternative to `response_schema` that accepts a standard JSON Schema.
        /// If this is set, `response_schema` must be omitted.
        /// Compatible MIME type: `application/json`.
        /// Supported properties: `$id`, `$defs`, `$ref`, `type`, `properties`, etc.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Alternative to `_response_json_schema`, serialized under the key `responseJsonSchema`.
        /// NOTE(review): both fields serialize independently, so setting both sends both keys.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model; see
        /// the `Model.output_token_limit` attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model; see the `Model.temperature`
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the `Model.top_p` attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the `Model.top_k` attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use `frequency_penalty` for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: the more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if `response_logprobs` is true. Sets the number of top logprobs to return at each decoding step in
        /// `Candidate.logprobs_result`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Configuration for image output (aspect ratio and size).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1288
1289    impl Default for GenerationConfig {
1290        fn default() -> Self {
1291            Self {
1292                temperature: Some(1.0),
1293                max_output_tokens: Some(4096),
1294                stop_sequences: None,
1295                response_mime_type: None,
1296                response_schema: None,
1297                _response_json_schema: None,
1298                response_json_schema: None,
1299                candidate_count: None,
1300                top_p: None,
1301                top_k: None,
1302                presence_penalty: None,
1303                frequency_penalty: None,
1304                response_logprobs: None,
1305                logprobs: None,
1306                thinking_config: None,
1307                image_config: None,
1308            }
1309        }
1310    }
1311
    /// Thinking/reasoning configuration, sent as `generationConfig.thinkingConfig`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Thinking token budget (`thinkingBudget`).
        /// NOTE(review): the API reference describes `-1` as "dynamic thinking", which a `u32`
        /// cannot express. Confirm whether that mode is needed.
        pub thinking_budget: u32,
        /// Whether thoughts should be included in the response (`includeThoughts`).
        /// There is no `skip_serializing_if`, so `None` is serialized as `null`.
        pub include_thoughts: Option<bool>,
    }

    /// Image-output configuration, sent as `generationConfig.imageConfig`.
    /// Unset fields are omitted from the serialized request.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio of generated images (`aspectRatio`), passed through verbatim.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested size of generated images (`imageSize`), passed through verbatim.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1327
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// Unlike most types in this module there is no `rename_all`. Multi-word fields therefore
    /// serialize in snake_case (`max_items`, `min_items`).
    /// NOTE(review): the API documents `maxItems`/`minItems`. Confirm the snake_case spelling is
    /// accepted (protobuf JSON parsers usually accept both).
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// The data type, serialized as `type` (e.g. `"OBJECT"`, `"STRING"`).
        pub r#type: String,
        /// Optional format hint for the data type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description of the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values, serialized as `enum`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Object properties keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required object properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1353
1354    /// Flattens a JSON schema by resolving all `$ref` references inline.
1355    /// It takes a JSON schema that may contain `$ref` references to definitions
1356    /// in `$defs` or `definitions` sections and returns a new schema with all references
1357    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1358    /// schema references.
1359    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1360        // extracting $defs if they exist
1361        let defs = if let Some(obj) = schema.as_object() {
1362            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1363        } else {
1364            None
1365        };
1366
1367        let Some(defs_value) = defs else {
1368            return Ok(schema);
1369        };
1370
1371        let Some(defs_obj) = defs_value.as_object() else {
1372            return Err(CompletionError::ResponseError(
1373                "$defs must be an object".into(),
1374            ));
1375        };
1376
1377        resolve_refs(&mut schema, defs_obj)?;
1378
1379        // removing $defs from the final schema because we have inlined everything
1380        if let Some(obj) = schema.as_object_mut() {
1381            obj.remove("$defs");
1382            obj.remove("definitions");
1383        }
1384
1385        Ok(schema)
1386    }
1387
1388    /// Recursively resolves all `$ref` references in a JSON value by
1389    /// replacing them with their definitions.
1390    fn resolve_refs(
1391        value: &mut Value,
1392        defs: &serde_json::Map<String, Value>,
1393    ) -> Result<(), CompletionError> {
1394        match value {
1395            Value::Object(obj) => {
1396                if let Some(ref_value) = obj.get("$ref")
1397                    && let Some(ref_str) = ref_value.as_str()
1398                {
1399                    // "#/$defs/Person" -> "Person"
1400                    let def_name = parse_ref_path(ref_str)?;
1401
1402                    let def = defs.get(&def_name).ok_or_else(|| {
1403                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1404                    })?;
1405
1406                    let mut resolved = def.clone();
1407                    resolve_refs(&mut resolved, defs)?;
1408                    *value = resolved;
1409                    return Ok(());
1410                }
1411
1412                for (_, v) in obj.iter_mut() {
1413                    resolve_refs(v, defs)?;
1414                }
1415            }
1416            Value::Array(arr) => {
1417                for item in arr.iter_mut() {
1418                    resolve_refs(item, defs)?;
1419                }
1420            }
1421            _ => {}
1422        }
1423
1424        Ok(())
1425    }
1426
1427    /// Parses a JSON Schema `$ref` path to extract the definition name.
1428    ///
1429    /// JSON Schema references use URI fragment syntax to point to definitions within
1430    /// the same document. This function extracts the definition name from common
1431    /// reference patterns used in JSON Schema.
1432    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1433        if let Some(fragment) = ref_str.strip_prefix('#') {
1434            if let Some(name) = fragment.strip_prefix("/$defs/") {
1435                Ok(name.to_string())
1436            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1437                Ok(name.to_string())
1438            } else {
1439                Err(CompletionError::ResponseError(format!(
1440                    "Unsupported reference format: {}",
1441                    ref_str
1442                )))
1443            }
1444        } else {
1445            Err(CompletionError::ResponseError(format!(
1446                "Only fragment references (#/...) are supported: {}",
1447                ref_str
1448            )))
1449        }
1450    }
1451
1452    /// Helper function to extract the type string from a JSON value.
1453    /// Handles both direct string types and array types (returns the first element).
1454    fn extract_type(type_value: &Value) -> Option<String> {
1455        if type_value.is_string() {
1456            type_value.as_str().map(String::from)
1457        } else if type_value.is_array() {
1458            type_value
1459                .as_array()
1460                .and_then(|arr| arr.first())
1461                .and_then(|v| v.as_str().map(String::from))
1462        } else {
1463            None
1464        }
1465    }
1466
1467    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1468    /// Returns the type of the first non-null schema found.
1469    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1470        composition.as_array().and_then(|arr| {
1471            arr.iter().find_map(|schema| {
1472                if let Some(obj) = schema.as_object() {
1473                    // Skip null types
1474                    if let Some(type_val) = obj.get("type")
1475                        && let Some(type_str) = type_val.as_str()
1476                        && type_str == "null"
1477                    {
1478                        return None;
1479                    }
1480                    // Extract type from this schema
1481                    obj.get("type").and_then(extract_type).or_else(|| {
1482                        if obj.contains_key("properties") {
1483                            Some("object".to_string())
1484                        } else {
1485                            None
1486                        }
1487                    })
1488                } else {
1489                    None
1490                }
1491            })
1492        })
1493    }
1494
1495    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1496    /// Returns the schema object that should be used for properties, required, etc.
1497    fn extract_schema_from_composition(
1498        composition: &Value,
1499    ) -> Option<serde_json::Map<String, Value>> {
1500        composition.as_array().and_then(|arr| {
1501            arr.iter().find_map(|schema| {
1502                if let Some(obj) = schema.as_object()
1503                    && let Some(type_val) = obj.get("type")
1504                    && let Some(type_str) = type_val.as_str()
1505                {
1506                    if type_str == "null" {
1507                        return None;
1508                    }
1509                    Some(obj.clone())
1510                } else {
1511                    None
1512                }
1513            })
1514        })
1515    }
1516
1517    /// Helper function to infer the type of a schema object.
1518    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1519    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1520        // First, try direct type field
1521        if let Some(type_val) = obj.get("type")
1522            && let Some(type_str) = extract_type(type_val)
1523        {
1524            return type_str;
1525        }
1526
1527        // Then try anyOf, oneOf, allOf (in that order)
1528        if let Some(any_of) = obj.get("anyOf")
1529            && let Some(type_str) = extract_type_from_composition(any_of)
1530        {
1531            return type_str;
1532        }
1533
1534        if let Some(one_of) = obj.get("oneOf")
1535            && let Some(type_str) = extract_type_from_composition(one_of)
1536        {
1537            return type_str;
1538        }
1539
1540        if let Some(all_of) = obj.get("allOf")
1541            && let Some(type_str) = extract_type_from_composition(all_of)
1542        {
1543            return type_str;
1544        }
1545
1546        // Finally, infer object type if properties are present
1547        if obj.contains_key("properties") {
1548            "object".to_string()
1549        } else {
1550            String::new()
1551        }
1552    }
1553
1554    impl TryFrom<Value> for Schema {
1555        type Error = CompletionError;
1556
1557        fn try_from(value: Value) -> Result<Self, Self::Error> {
1558            let flattened_val = flatten_schema(value)?;
1559            if let Some(obj) = flattened_val.as_object() {
1560                // Determine which object to use for extracting properties and required fields.
1561                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1562                let props_source = if obj.get("properties").is_none() {
1563                    if let Some(any_of) = obj.get("anyOf") {
1564                        extract_schema_from_composition(any_of)
1565                    } else if let Some(one_of) = obj.get("oneOf") {
1566                        extract_schema_from_composition(one_of)
1567                    } else if let Some(all_of) = obj.get("allOf") {
1568                        extract_schema_from_composition(all_of)
1569                    } else {
1570                        None
1571                    }
1572                    .unwrap_or(obj.clone())
1573                } else {
1574                    obj.clone()
1575                };
1576
1577                Ok(Schema {
1578                    r#type: infer_type(obj),
1579                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1580                    description: obj
1581                        .get("description")
1582                        .and_then(|v| v.as_str())
1583                        .map(String::from),
1584                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1585                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1586                        arr.iter()
1587                            .filter_map(|v| v.as_str().map(String::from))
1588                            .collect()
1589                    }),
1590                    max_items: obj
1591                        .get("maxItems")
1592                        .and_then(|v| v.as_i64())
1593                        .map(|v| v as i32),
1594                    min_items: obj
1595                        .get("minItems")
1596                        .and_then(|v| v.as_i64())
1597                        .map(|v| v as i32),
1598                    properties: props_source
1599                        .get("properties")
1600                        .and_then(|v| v.as_object())
1601                        .map(|map| {
1602                            map.iter()
1603                                .filter_map(|(k, v)| {
1604                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1605                                })
1606                                .collect()
1607                        }),
1608                    required: props_source
1609                        .get("required")
1610                        .and_then(|v| v.as_array())
1611                        .map(|arr| {
1612                            arr.iter()
1613                                .filter_map(|v| v.as_str().map(String::from))
1614                                .collect()
1615                        }),
1616                    items: obj
1617                        .get("items")
1618                        .and_then(|v| v.clone().try_into().ok())
1619                        .map(Box::new),
1620                })
1621            } else {
1622                Err(CompletionError::ResponseError(
1623                    "Expected a JSON object for Schema".into(),
1624                ))
1625            }
1626        }
1627    }
1628
1629    #[derive(Debug, Serialize)]
1630    #[serde(rename_all = "camelCase")]
1631    pub struct GenerateContentRequest {
1632        pub contents: Vec<Content>,
1633        #[serde(skip_serializing_if = "Option::is_none")]
1634        pub tools: Option<Tool>,
1635        pub tool_config: Option<ToolConfig>,
1636        /// Optional. Configuration options for model generation and outputs.
1637        pub generation_config: Option<GenerationConfig>,
1638        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1639        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1640        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1641        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1642        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1643        /// will use the default safety setting for that category. Harm categories:
1644        ///     - HARM_CATEGORY_HATE_SPEECH,
1645        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1646        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1647        ///     - HARM_CATEGORY_HARASSMENT
1648        /// are supported.
1649        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1650        /// to learn how to incorporate safety considerations in your AI applications.
1651        pub safety_settings: Option<Vec<SafetySetting>>,
1652        /// Optional. Developer set system instruction(s). Currently, text only.
1653        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1654        pub system_instruction: Option<Content>,
1655        // cachedContent: Optional<String>
1656        /// Additional parameters.
1657        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1658        pub additional_params: Option<serde_json::Value>,
1659    }
1660
1661    #[derive(Debug, Serialize)]
1662    #[serde(rename_all = "camelCase")]
1663    pub struct Tool {
1664        pub function_declarations: Vec<FunctionDeclaration>,
1665        pub code_execution: Option<CodeExecution>,
1666    }
1667
1668    #[derive(Debug, Serialize, Clone)]
1669    #[serde(rename_all = "camelCase")]
1670    pub struct FunctionDeclaration {
1671        pub name: String,
1672        pub description: String,
1673        #[serde(skip_serializing_if = "Option::is_none")]
1674        pub parameters: Option<Schema>,
1675    }
1676
1677    #[derive(Debug, Serialize, Deserialize)]
1678    #[serde(rename_all = "camelCase")]
1679    pub struct ToolConfig {
1680        pub function_calling_config: Option<FunctionCallingMode>,
1681    }
1682
1683    #[derive(Debug, Serialize, Deserialize, Default)]
1684    #[serde(tag = "mode", rename_all = "UPPERCASE")]
1685    pub enum FunctionCallingMode {
1686        #[default]
1687        Auto,
1688        None,
1689        Any {
1690            #[serde(skip_serializing_if = "Option::is_none")]
1691            allowed_function_names: Option<Vec<String>>,
1692        },
1693    }
1694
1695    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1696        type Error = CompletionError;
1697        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1698            let res = match value {
1699                message::ToolChoice::Auto => Self::Auto,
1700                message::ToolChoice::None => Self::None,
1701                message::ToolChoice::Required => Self::Any {
1702                    allowed_function_names: None,
1703                },
1704                message::ToolChoice::Specific { function_names } => Self::Any {
1705                    allowed_function_names: Some(function_names),
1706                },
1707            };
1708
1709            Ok(res)
1710        }
1711    }
1712
1713    #[derive(Debug, Serialize)]
1714    pub struct CodeExecution {}
1715
1716    #[derive(Debug, Serialize)]
1717    #[serde(rename_all = "camelCase")]
1718    pub struct SafetySetting {
1719        pub category: HarmCategory,
1720        pub threshold: HarmBlockThreshold,
1721    }
1722
1723    #[derive(Debug, Serialize)]
1724    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1725    pub enum HarmBlockThreshold {
1726        HarmBlockThresholdUnspecified,
1727        BlockLowAndAbove,
1728        BlockMediumAndAbove,
1729        BlockOnlyHigh,
1730        BlockNone,
1731        Off,
1732    }
1733}
1734
1735#[cfg(test)]
1736mod tests {
1737    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1738
1739    use super::*;
1740    use serde_json::json;
1741
1742    #[test]
1743    fn test_deserialize_message_user() {
1744        let raw_message = r#"{
1745            "parts": [
1746                {"text": "Hello, world!"},
1747                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1748                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1749                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1750                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1751                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1752                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1753            ],
1754            "role": "user"
1755        }"#;
1756
1757        let content: Content = {
1758            let jd = &mut serde_json::Deserializer::from_str(raw_message);
1759            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1760                panic!("Deserialization error at {}: {}", err.path(), err);
1761            })
1762        };
1763        assert_eq!(content.role, Some(Role::User));
1764        assert_eq!(content.parts.len(), 7);
1765
1766        let parts: Vec<Part> = content.parts.into_iter().collect();
1767
1768        if let Part {
1769            part: PartKind::Text(text),
1770            ..
1771        } = &parts[0]
1772        {
1773            assert_eq!(text, "Hello, world!");
1774        } else {
1775            panic!("Expected text part");
1776        }
1777
1778        if let Part {
1779            part: PartKind::InlineData(inline_data),
1780            ..
1781        } = &parts[1]
1782        {
1783            assert_eq!(inline_data.mime_type, "image/png");
1784            assert_eq!(inline_data.data, "base64encodeddata");
1785        } else {
1786            panic!("Expected inline data part");
1787        }
1788
1789        if let Part {
1790            part: PartKind::FunctionCall(function_call),
1791            ..
1792        } = &parts[2]
1793        {
1794            assert_eq!(function_call.name, "test_function");
1795            assert_eq!(
1796                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1797                "value1"
1798            );
1799        } else {
1800            panic!("Expected function call part");
1801        }
1802
1803        if let Part {
1804            part: PartKind::FunctionResponse(function_response),
1805            ..
1806        } = &parts[3]
1807        {
1808            assert_eq!(function_response.name, "test_function");
1809            assert_eq!(
1810                function_response
1811                    .response
1812                    .as_ref()
1813                    .unwrap()
1814                    .get("result")
1815                    .unwrap(),
1816                "success"
1817            );
1818        } else {
1819            panic!("Expected function response part");
1820        }
1821
1822        if let Part {
1823            part: PartKind::FileData(file_data),
1824            ..
1825        } = &parts[4]
1826        {
1827            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1828            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1829        } else {
1830            panic!("Expected file data part");
1831        }
1832
1833        if let Part {
1834            part: PartKind::ExecutableCode(executable_code),
1835            ..
1836        } = &parts[5]
1837        {
1838            assert_eq!(executable_code.code, "print('Hello, world!')");
1839        } else {
1840            panic!("Expected executable code part");
1841        }
1842
1843        if let Part {
1844            part: PartKind::CodeExecutionResult(code_execution_result),
1845            ..
1846        } = &parts[6]
1847        {
1848            assert_eq!(
1849                code_execution_result.clone().output.unwrap(),
1850                "Hello, world!"
1851            );
1852        } else {
1853            panic!("Expected code execution result part");
1854        }
1855    }
1856
1857    #[test]
1858    fn test_deserialize_message_model() {
1859        let json_data = json!({
1860            "parts": [{"text": "Hello, user!"}],
1861            "role": "model"
1862        });
1863
1864        let content: Content = serde_json::from_value(json_data).unwrap();
1865        assert_eq!(content.role, Some(Role::Model));
1866        assert_eq!(content.parts.len(), 1);
1867        if let Some(Part {
1868            part: PartKind::Text(text),
1869            ..
1870        }) = content.parts.first()
1871        {
1872            assert_eq!(text, "Hello, user!");
1873        } else {
1874            panic!("Expected text part");
1875        }
1876    }
1877
1878    #[test]
1879    fn test_message_conversion_user() {
1880        let msg = message::Message::user("Hello, world!");
1881        let content: Content = msg.try_into().unwrap();
1882        assert_eq!(content.role, Some(Role::User));
1883        assert_eq!(content.parts.len(), 1);
1884        if let Some(Part {
1885            part: PartKind::Text(text),
1886            ..
1887        }) = &content.parts.first()
1888        {
1889            assert_eq!(text, "Hello, world!");
1890        } else {
1891            panic!("Expected text part");
1892        }
1893    }
1894
1895    #[test]
1896    fn test_message_conversion_model() {
1897        let msg = message::Message::assistant("Hello, user!");
1898
1899        let content: Content = msg.try_into().unwrap();
1900        assert_eq!(content.role, Some(Role::Model));
1901        assert_eq!(content.parts.len(), 1);
1902        if let Some(Part {
1903            part: PartKind::Text(text),
1904            ..
1905        }) = &content.parts.first()
1906        {
1907            assert_eq!(text, "Hello, user!");
1908        } else {
1909            panic!("Expected text part");
1910        }
1911    }
1912
1913    #[test]
1914    fn test_message_conversion_tool_call() {
1915        let tool_call = message::ToolCall {
1916            id: "test_tool".to_string(),
1917            call_id: None,
1918            function: message::ToolFunction {
1919                name: "test_function".to_string(),
1920                arguments: json!({"arg1": "value1"}),
1921            },
1922            signature: None,
1923            additional_params: None,
1924        };
1925
1926        let msg = message::Message::Assistant {
1927            id: None,
1928            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1929        };
1930
1931        let content: Content = msg.try_into().unwrap();
1932        assert_eq!(content.role, Some(Role::Model));
1933        assert_eq!(content.parts.len(), 1);
1934        if let Some(Part {
1935            part: PartKind::FunctionCall(function_call),
1936            ..
1937        }) = content.parts.first()
1938        {
1939            assert_eq!(function_call.name, "test_function");
1940            assert_eq!(
1941                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1942                "value1"
1943            );
1944        } else {
1945            panic!("Expected function call part");
1946        }
1947    }
1948
1949    #[test]
1950    fn test_vec_schema_conversion() {
1951        let schema_with_ref = json!({
1952            "type": "array",
1953            "items": {
1954                "$ref": "#/$defs/Person"
1955            },
1956            "$defs": {
1957                "Person": {
1958                    "type": "object",
1959                    "properties": {
1960                        "first_name": {
1961                            "type": ["string", "null"],
1962                            "description": "The person's first name, if provided (null otherwise)"
1963                        },
1964                        "last_name": {
1965                            "type": ["string", "null"],
1966                            "description": "The person's last name, if provided (null otherwise)"
1967                        },
1968                        "job": {
1969                            "type": ["string", "null"],
1970                            "description": "The person's job, if provided (null otherwise)"
1971                        }
1972                    },
1973                    "required": []
1974                }
1975            }
1976        });
1977
1978        let result: Result<Schema, _> = schema_with_ref.try_into();
1979
1980        match result {
1981            Ok(schema) => {
1982                assert_eq!(schema.r#type, "array");
1983
1984                if let Some(items) = schema.items {
1985                    println!("item types: {}", items.r#type);
1986
1987                    assert_ne!(items.r#type, "", "Items type should not be empty string!");
1988                    assert_eq!(items.r#type, "object", "Items should be object type");
1989                } else {
1990                    panic!("Schema should have items field for array type");
1991                }
1992            }
1993            Err(e) => println!("Schema conversion failed: {:?}", e),
1994        }
1995    }
1996
1997    #[test]
1998    fn test_object_schema() {
1999        let simple_schema = json!({
2000            "type": "object",
2001            "properties": {
2002                "name": {
2003                    "type": "string"
2004                }
2005            }
2006        });
2007
2008        let schema: Schema = simple_schema.try_into().unwrap();
2009        assert_eq!(schema.r#type, "object");
2010        assert!(schema.properties.is_some());
2011    }
2012
2013    #[test]
2014    fn test_array_with_inline_items() {
2015        let inline_schema = json!({
2016            "type": "array",
2017            "items": {
2018                "type": "object",
2019                "properties": {
2020                    "name": {
2021                        "type": "string"
2022                    }
2023                }
2024            }
2025        });
2026
2027        let schema: Schema = inline_schema.try_into().unwrap();
2028        assert_eq!(schema.r#type, "array");
2029
2030        if let Some(items) = schema.items {
2031            assert_eq!(items.r#type, "object");
2032            assert!(items.properties.is_some());
2033        } else {
2034            panic!("Schema should have items field");
2035        }
2036    }
2037    #[test]
2038    fn test_flattened_schema() {
2039        let ref_schema = json!({
2040            "type": "array",
2041            "items": {
2042                "$ref": "#/$defs/Person"
2043            },
2044            "$defs": {
2045                "Person": {
2046                    "type": "object",
2047                    "properties": {
2048                        "name": { "type": "string" }
2049                    }
2050                }
2051            }
2052        });
2053
2054        let flattened = flatten_schema(ref_schema).unwrap();
2055        let schema: Schema = flattened.try_into().unwrap();
2056
2057        assert_eq!(schema.r#type, "array");
2058
2059        if let Some(items) = schema.items {
2060            println!("Flattened items type: '{}'", items.r#type);
2061
2062            assert_eq!(items.r#type, "object");
2063            assert!(items.properties.is_some());
2064        }
2065    }
2066}