rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// Model identifier for Gemini 2.5 Pro preview (2025-06-05 snapshot).
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier for Gemini 2.5 Pro preview (2025-05-06 snapshot).
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier for Gemini 2.5 Pro preview (2025-03-25 snapshot).
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier for Gemini 2.5 Flash preview (2025-04-17 snapshot).
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier for the experimental Gemini 2.5 Pro (2025-03-25 snapshot).
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier for Gemini 2.5 Flash.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier for Gemini 2.0 Flash-Lite.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier for Gemini 2.0 Flash.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27    AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32    OneOrMany,
33    completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37    Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
46// =================================================================
47// Rig Implementation Types
48// =================================================================
49
50#[derive(Clone, Debug)]
51pub struct CompletionModel<T = reqwest::Client> {
52    pub(crate) client: Client<T>,
53    pub model: String,
54}
55
56impl<T> CompletionModel<T> {
57    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58        Self {
59            client,
60            model: model.into(),
61        }
62    }
63
64    pub fn with_model(client: Client<T>, model: &str) -> Self {
65        Self {
66            client,
67            model: model.into(),
68        }
69    }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74    T: HttpClientExt + Clone + 'static,
75{
76    type Response = GenerateContentResponse;
77    type StreamingResponse = StreamingCompletionResponse;
78    type Client = super::Client<T>;
79
80    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81        Self::new(client.clone(), model)
82    }
83
84    #[cfg_attr(feature = "worker", worker::send)]
85    async fn completion(
86        &self,
87        completion_request: CompletionRequest,
88    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
89        let span = if tracing::Span::current().is_disabled() {
90            info_span!(
91                target: "rig::completions",
92                "generate_content",
93                gen_ai.operation.name = "generate_content",
94                gen_ai.provider.name = "gcp.gemini",
95                gen_ai.request.model = self.model,
96                gen_ai.system_instructions = &completion_request.preamble,
97                gen_ai.response.id = tracing::field::Empty,
98                gen_ai.response.model = tracing::field::Empty,
99                gen_ai.usage.output_tokens = tracing::field::Empty,
100                gen_ai.usage.input_tokens = tracing::field::Empty,
101            )
102        } else {
103            tracing::Span::current()
104        };
105
106        let request = create_request_body(completion_request)?;
107
108        if enabled!(Level::TRACE) {
109            tracing::trace!(
110                target: "rig::completions",
111                "Gemini completion request: {}",
112                serde_json::to_string_pretty(&request)?
113            );
114        }
115
116        let body = serde_json::to_vec(&request)?;
117
118        let path = format!("/v1beta/models/{}:generateContent", self.model);
119
120        let request = self
121            .client
122            .post(path.as_str())?
123            .body(body)
124            .map_err(|e| CompletionError::HttpError(e.into()))?;
125
126        async move {
127            let response = self.client.send::<_, Vec<u8>>(request).await?;
128
129            if response.status().is_success() {
130                let response_body = response
131                    .into_body()
132                    .await
133                    .map_err(CompletionError::HttpError)?;
134
135                let response_text = String::from_utf8_lossy(&response_body).to_string();
136
137                let response: GenerateContentResponse = serde_json::from_slice(&response_body)
138                    .map_err(|err| {
139                        tracing::error!(
140                            error = %err,
141                            body = %response_text,
142                            "Failed to deserialize Gemini completion response"
143                        );
144                        CompletionError::JsonError(err)
145                    })?;
146
147                let span = tracing::Span::current();
148                span.record_response_metadata(&response);
149                span.record_token_usage(&response.usage_metadata);
150
151                if enabled!(Level::TRACE) {
152                    tracing::trace!(
153                        target: "rig::completions",
154                        "Gemini completion response: {}",
155                        serde_json::to_string_pretty(&response)?
156                    );
157                }
158
159                response.try_into()
160            } else {
161                let text = String::from_utf8_lossy(
162                    &response
163                        .into_body()
164                        .await
165                        .map_err(CompletionError::HttpError)?,
166                )
167                .into();
168
169                Err(CompletionError::ProviderError(text))
170            }
171        }
172        .instrument(span)
173        .await
174    }
175
176    #[cfg_attr(feature = "worker", worker::send)]
177    async fn stream(
178        &self,
179        request: CompletionRequest,
180    ) -> Result<
181        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
182        CompletionError,
183    > {
184        CompletionModel::stream(self, request).await
185    }
186}
187
188pub(crate) fn create_request_body(
189    completion_request: CompletionRequest,
190) -> Result<GenerateContentRequest, CompletionError> {
191    let mut full_history = Vec::new();
192    full_history.extend(completion_request.chat_history);
193
194    let additional_params = completion_request
195        .additional_params
196        .unwrap_or_else(|| Value::Object(Map::new()));
197
198    let AdditionalParameters {
199        mut generation_config,
200        additional_params,
201    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
202
203    generation_config = generation_config.map(|mut cfg| {
204        if let Some(temp) = completion_request.temperature {
205            cfg.temperature = Some(temp);
206        };
207
208        if let Some(max_tokens) = completion_request.max_tokens {
209            cfg.max_output_tokens = Some(max_tokens);
210        };
211
212        cfg
213    });
214
215    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
216        parts: vec![preamble.into()],
217        role: Some(Role::Model),
218    });
219
220    let tools = if completion_request.tools.is_empty() {
221        None
222    } else {
223        Some(Tool::try_from(completion_request.tools)?)
224    };
225
226    let tool_config = if let Some(cfg) = completion_request.tool_choice {
227        Some(ToolConfig {
228            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
229        })
230    } else {
231        None
232    };
233
234    let request = GenerateContentRequest {
235        contents: full_history
236            .into_iter()
237            .map(|msg| {
238                msg.try_into()
239                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
240            })
241            .collect::<Result<Vec<_>, _>>()?,
242        generation_config,
243        safety_settings: None,
244        tools,
245        tool_config,
246        system_instruction,
247        additional_params,
248    };
249
250    Ok(request)
251}
252
253impl TryFrom<completion::ToolDefinition> for Tool {
254    type Error = CompletionError;
255
256    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
257        let parameters: Option<Schema> =
258            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
259                None
260            } else {
261                Some(tool.parameters.try_into()?)
262            };
263
264        Ok(Self {
265            function_declarations: vec![FunctionDeclaration {
266                name: tool.name,
267                description: tool.description,
268                parameters,
269            }],
270            code_execution: None,
271        })
272    }
273}
274
275impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
276    type Error = CompletionError;
277
278    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
279        let mut function_declarations = Vec::new();
280
281        for tool in tools {
282            let parameters =
283                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
284                    None
285                } else {
286                    match tool.parameters.try_into() {
287                        Ok(schema) => Some(schema),
288                        Err(e) => {
289                            let emsg = format!(
290                                "Tool '{}' could not be converted to a schema: {:?}",
291                                tool.name, e,
292                            );
293                            return Err(CompletionError::ProviderError(emsg));
294                        }
295                    }
296                };
297
298            function_declarations.push(FunctionDeclaration {
299                name: tool.name,
300                description: tool.description,
301                parameters,
302            });
303        }
304
305        Ok(Self {
306            function_declarations,
307            code_execution: None,
308        })
309    }
310}
311
312impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
313    type Error = CompletionError;
314
315    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
316        let candidate = response.candidates.first().ok_or_else(|| {
317            CompletionError::ResponseError("No response candidates in response".into())
318        })?;
319
320        let content = candidate
321            .content
322            .as_ref()
323            .ok_or_else(|| {
324                let reason = candidate
325                    .finish_reason
326                    .as_ref()
327                    .map(|r| format!("finish_reason={r:?}"))
328                    .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
329                let message = candidate
330                    .finish_message
331                    .as_deref()
332                    .unwrap_or("no finish message provided");
333                CompletionError::ResponseError(format!(
334                    "Gemini candidate missing content ({reason}, finish_message={message})"
335                ))
336            })?
337            .parts
338            .iter()
339            .map(|Part { thought, part, .. }| {
340                Ok(match part {
341                    PartKind::Text(text) => {
342                        if let Some(thought) = thought
343                            && *thought
344                        {
345                            completion::AssistantContent::Reasoning(Reasoning::new(text))
346                        } else {
347                            completion::AssistantContent::text(text)
348                        }
349                    }
350                    PartKind::InlineData(inline_data) => {
351                        let mime_type = message::MediaType::from_mime_type(&inline_data.mime_type);
352
353                        match mime_type {
354                            Some(message::MediaType::Image(media_type)) => {
355                                message::AssistantContent::image_base64(
356                                    &inline_data.data,
357                                    Some(media_type),
358                                    Some(message::ImageDetail::default()),
359                                )
360                            }
361                            _ => {
362                                return Err(CompletionError::ResponseError(format!(
363                                    "Unsupported media type {mime_type:?}"
364                                )));
365                            }
366                        }
367                    }
368                    PartKind::FunctionCall(function_call) => {
369                        completion::AssistantContent::tool_call(
370                            &function_call.name,
371                            &function_call.name,
372                            function_call.args.clone(),
373                        )
374                    }
375                    _ => {
376                        return Err(CompletionError::ResponseError(
377                            "Response did not contain a message or tool call".into(),
378                        ));
379                    }
380                })
381            })
382            .collect::<Result<Vec<_>, _>>()?;
383
384        let choice = OneOrMany::many(content).map_err(|_| {
385            CompletionError::ResponseError(
386                "Response contained no message or tool call (empty)".to_owned(),
387            )
388        })?;
389
390        let usage = response
391            .usage_metadata
392            .as_ref()
393            .map(|usage| completion::Usage {
394                input_tokens: usage.prompt_token_count as u64,
395                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
396                total_tokens: usage.total_token_count as u64,
397            })
398            .unwrap_or_default();
399
400        Ok(completion::CompletionResponse {
401            choice,
402            usage,
403            raw_response: response,
404        })
405    }
406}
407
408pub mod gemini_api_types {
409    use crate::telemetry::ProviderResponseExt;
410    use std::{collections::HashMap, convert::Infallible, str::FromStr};
411
412    // =================================================================
413    // Gemini API Types
414    // =================================================================
415    use serde::{Deserialize, Serialize};
416    use serde_json::{Value, json};
417
418    use crate::completion::GetTokenUsage;
419    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
420    use crate::{
421        OneOrMany,
422        completion::CompletionError,
423        message::{self, Reasoning, Text},
424        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
425    };
426
427    #[derive(Debug, Deserialize, Serialize, Default)]
428    #[serde(rename_all = "camelCase")]
429    pub struct AdditionalParameters {
430        /// Change your Gemini request configuration.
431        pub generation_config: Option<GenerationConfig>,
432        /// Any additional parameters that you want.
433        #[serde(flatten, skip_serializing_if = "Option::is_none")]
434        pub additional_params: Option<serde_json::Value>,
435    }
436
437    impl AdditionalParameters {
438        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
439            self.generation_config = Some(cfg);
440            self
441        }
442
443        pub fn with_params(mut self, params: serde_json::Value) -> Self {
444            self.additional_params = Some(params);
445            self
446        }
447    }
448
449    /// Response from the model supporting multiple candidate responses.
450    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
451    /// and for each candidate in finishReason and in safetyRatings.
452    /// The API:
453    ///     - Returns either all requested candidates or none of them
454    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
455    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
456    #[derive(Debug, Deserialize, Serialize)]
457    #[serde(rename_all = "camelCase")]
458    pub struct GenerateContentResponse {
459        pub response_id: String,
460        /// Candidate responses from the model.
461        pub candidates: Vec<ContentCandidate>,
462        /// Returns the prompt's feedback related to the content filters.
463        pub prompt_feedback: Option<PromptFeedback>,
464        /// Output only. Metadata on the generation requests' token usage.
465        pub usage_metadata: Option<UsageMetadata>,
466        pub model_version: Option<String>,
467    }
468
469    impl ProviderResponseExt for GenerateContentResponse {
470        type OutputMessage = ContentCandidate;
471        type Usage = UsageMetadata;
472
473        fn get_response_id(&self) -> Option<String> {
474            Some(self.response_id.clone())
475        }
476
477        fn get_response_model_name(&self) -> Option<String> {
478            None
479        }
480
481        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
482            self.candidates.clone()
483        }
484
485        fn get_text_response(&self) -> Option<String> {
486            let str = self
487                .candidates
488                .iter()
489                .filter_map(|x| {
490                    let content = x.content.as_ref()?;
491                    if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
492                        return None;
493                    }
494
495                    let res = content
496                        .parts
497                        .iter()
498                        .filter_map(|part| {
499                            if let PartKind::Text(ref str) = part.part {
500                                Some(str.to_owned())
501                            } else {
502                                None
503                            }
504                        })
505                        .collect::<Vec<String>>()
506                        .join("\n");
507
508                    Some(res)
509                })
510                .collect::<Vec<String>>()
511                .join("\n");
512
513            if str.is_empty() { None } else { Some(str) }
514        }
515
516        fn get_usage(&self) -> Option<Self::Usage> {
517            self.usage_metadata.clone()
518        }
519    }
520
521    /// A response candidate generated from the model.
522    #[derive(Clone, Debug, Deserialize, Serialize)]
523    #[serde(rename_all = "camelCase")]
524    pub struct ContentCandidate {
525        /// Output only. Generated content returned from the model.
526        #[serde(skip_serializing_if = "Option::is_none")]
527        pub content: Option<Content>,
528        /// Optional. Output only. The reason why the model stopped generating tokens.
529        /// If empty, the model has not stopped generating tokens.
530        pub finish_reason: Option<FinishReason>,
531        /// List of ratings for the safety of a response candidate.
532        /// There is at most one rating per category.
533        pub safety_ratings: Option<Vec<SafetyRating>>,
534        /// Output only. Citation information for model-generated candidate.
535        /// This field may be populated with recitation information for any text included in the content.
536        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
537        pub citation_metadata: Option<CitationMetadata>,
538        /// Output only. Token count for this candidate.
539        pub token_count: Option<i32>,
540        /// Output only.
541        pub avg_logprobs: Option<f64>,
542        /// Output only. Log-likelihood scores for the response tokens and top tokens
543        pub logprobs_result: Option<LogprobsResult>,
544        /// Output only. Index of the candidate in the list of response candidates.
545        pub index: Option<i32>,
546        /// Output only. Additional information about why the model stopped generating tokens.
547        pub finish_message: Option<String>,
548    }
549
550    #[derive(Clone, Debug, Deserialize, Serialize)]
551    pub struct Content {
552        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
553        #[serde(default)]
554        pub parts: Vec<Part>,
555        /// The producer of the content. Must be either 'user' or 'model'.
556        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
557        pub role: Option<Role>,
558    }
559
560    impl TryFrom<message::Message> for Content {
561        type Error = message::MessageError;
562
563        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
564            Ok(match msg {
565                message::Message::User { content } => Content {
566                    parts: content
567                        .into_iter()
568                        .map(|c| c.try_into())
569                        .collect::<Result<Vec<_>, _>>()?,
570                    role: Some(Role::User),
571                },
572                message::Message::Assistant { content, .. } => Content {
573                    role: Some(Role::Model),
574                    parts: content
575                        .into_iter()
576                        .map(|content| content.try_into())
577                        .collect::<Result<Vec<_>, _>>()?,
578                },
579            })
580        }
581    }
582
583    impl TryFrom<Content> for message::Message {
584        type Error = message::MessageError;
585
586        fn try_from(content: Content) -> Result<Self, Self::Error> {
587            match content.role {
588                Some(Role::User) | None => {
589                    Ok(message::Message::User {
590                        content: {
591                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
592                            .map(|Part { part, .. }| {
593                                Ok(match part {
594                                    PartKind::Text(text) => message::UserContent::text(text),
595                                    PartKind::InlineData(inline_data) => {
596                                        let mime_type =
597                                            message::MediaType::from_mime_type(&inline_data.mime_type);
598
599                                        match mime_type {
600                                            Some(message::MediaType::Image(media_type)) => {
601                                                message::UserContent::image_base64(
602                                                    inline_data.data,
603                                                    Some(media_type),
604                                                    Some(message::ImageDetail::default()),
605                                                )
606                                            }
607                                            Some(message::MediaType::Document(media_type)) => {
608                                                message::UserContent::document(
609                                                    inline_data.data,
610                                                    Some(media_type),
611                                                )
612                                            }
613                                            Some(message::MediaType::Audio(media_type)) => {
614                                                message::UserContent::audio(
615                                                    inline_data.data,
616                                                    Some(media_type),
617                                                )
618                                            }
619                                            _ => {
620                                                return Err(message::MessageError::ConversionError(
621                                                    format!("Unsupported media type {mime_type:?}"),
622                                                ));
623                                            }
624                                        }
625                                    }
626                                    _ => {
627                                        return Err(message::MessageError::ConversionError(format!(
628                                            "Unsupported gemini content part type: {part:?}"
629                                        )));
630                                    }
631                                })
632                            })
633                            .collect();
634                            OneOrMany::many(user_content?).map_err(|_| {
635                                message::MessageError::ConversionError(
636                                    "Failed to create OneOrMany from user content".to_string(),
637                                )
638                            })?
639                        },
640                    })
641                }
642                Some(Role::Model) => Ok(message::Message::Assistant {
643                    id: None,
644                    content: {
645                        let assistant_content: Result<Vec<_>, _> = content
646                            .parts
647                            .into_iter()
648                            .map(|Part { thought, part, .. }| {
649                                Ok(match part {
650                                    PartKind::Text(text) => match thought {
651                                        Some(true) => message::AssistantContent::Reasoning(
652                                            Reasoning::new(&text),
653                                        ),
654                                        _ => message::AssistantContent::Text(Text { text }),
655                                    },
656
657                                    PartKind::FunctionCall(function_call) => {
658                                        message::AssistantContent::ToolCall(function_call.into())
659                                    }
660                                    _ => {
661                                        return Err(message::MessageError::ConversionError(
662                                            format!("Unsupported part type: {part:?}"),
663                                        ));
664                                    }
665                                })
666                            })
667                            .collect();
668                        OneOrMany::many(assistant_content?).map_err(|_| {
669                            message::MessageError::ConversionError(
670                                "Failed to create OneOrMany from assistant content".to_string(),
671                            )
672                        })?
673                    },
674                }),
675            }
676        }
677    }
678
679    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
680    #[serde(rename_all = "lowercase")]
681    pub enum Role {
682        User,
683        Model,
684    }
685
686    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
687    #[serde(rename_all = "camelCase")]
688    pub struct Part {
689        /// whether or not the part is a reasoning/thinking text or not
690        #[serde(skip_serializing_if = "Option::is_none")]
691        pub thought: Option<bool>,
692        /// an opaque sig for the thought so it can be reused - is a base64 string
693        #[serde(skip_serializing_if = "Option::is_none")]
694        pub thought_signature: Option<String>,
695        #[serde(flatten)]
696        pub part: PartKind,
697        #[serde(flatten, skip_serializing_if = "Option::is_none")]
698        pub additional_params: Option<Value>,
699    }
700
701    /// A datatype containing media that is part of a multi-part [Content] message.
702    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
703    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
704    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
705    #[serde(rename_all = "camelCase")]
706    pub enum PartKind {
707        Text(String),
708        InlineData(Blob),
709        FunctionCall(FunctionCall),
710        FunctionResponse(FunctionResponse),
711        FileData(FileData),
712        ExecutableCode(ExecutableCode),
713        CodeExecutionResult(CodeExecutionResult),
714    }
715
716    // This default instance is primarily so we can easily fill in the optional fields of `Part`
717    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
718    impl Default for PartKind {
719        fn default() -> Self {
720            Self::Text(String::new())
721        }
722    }
723
724    impl From<String> for Part {
725        fn from(text: String) -> Self {
726            Self {
727                thought: Some(false),
728                thought_signature: None,
729                part: PartKind::Text(text),
730                additional_params: None,
731            }
732        }
733    }
734
735    impl From<&str> for Part {
736        fn from(text: &str) -> Self {
737            Self::from(text.to_string())
738        }
739    }
740
741    impl FromStr for Part {
742        type Err = Infallible;
743
744        fn from_str(s: &str) -> Result<Self, Self::Err> {
745            Ok(s.into())
746        }
747    }
748
749    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
750        type Error = message::MessageError;
751        fn try_from(
752            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
753        ) -> Result<Self, Self::Error> {
754            let mime_type = mime_type.to_mime_type().to_string();
755            let part = match doc_src {
756                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
757                    mime_type: Some(mime_type),
758                    file_uri: url,
759                }),
760                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
761                    PartKind::InlineData(Blob { mime_type, data })
762                }
763                DocumentSourceKind::Raw(_) => {
764                    return Err(message::MessageError::ConversionError(
765                        "Raw files not supported, encode as base64 first".into(),
766                    ));
767                }
768                DocumentSourceKind::Unknown => {
769                    return Err(message::MessageError::ConversionError(
770                        "Can't convert an unknown document source".to_string(),
771                    ));
772                }
773            };
774
775            Ok(part)
776        }
777    }
778
779    impl TryFrom<message::UserContent> for Part {
780        type Error = message::MessageError;
781
782        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
783            match content {
784                message::UserContent::Text(message::Text { text }) => Ok(Part {
785                    thought: Some(false),
786                    thought_signature: None,
787                    part: PartKind::Text(text),
788                    additional_params: None,
789                }),
790                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
791                    let content = match content.first() {
792                        message::ToolResultContent::Text(text) => text.text,
793                        message::ToolResultContent::Image(_) => {
794                            return Err(message::MessageError::ConversionError(
795                                "Tool result content must be text".to_string(),
796                            ));
797                        }
798                    };
799                    // Convert to JSON since this value may be a valid JSON value
800                    let result: serde_json::Value =
801                        serde_json::from_str(&content).unwrap_or_else(|error| {
802                            tracing::trace!(
803                                ?error,
804                                "Tool result is not a valid JSON, treat it as normal string"
805                            );
806                            json!(content)
807                        });
808                    Ok(Part {
809                        thought: Some(false),
810                        thought_signature: None,
811                        part: PartKind::FunctionResponse(FunctionResponse {
812                            name: id,
813                            response: Some(json!({ "result": result })),
814                        }),
815                        additional_params: None,
816                    })
817                }
818                message::UserContent::Image(message::Image {
819                    data, media_type, ..
820                }) => match media_type {
821                    Some(media_type) => match media_type {
822                        message::ImageMediaType::JPEG
823                        | message::ImageMediaType::PNG
824                        | message::ImageMediaType::WEBP
825                        | message::ImageMediaType::HEIC
826                        | message::ImageMediaType::HEIF => {
827                            let part = PartKind::try_from((media_type, data))?;
828                            Ok(Part {
829                                thought: Some(false),
830                                thought_signature: None,
831                                part,
832                                additional_params: None,
833                            })
834                        }
835                        _ => Err(message::MessageError::ConversionError(format!(
836                            "Unsupported image media type {media_type:?}"
837                        ))),
838                    },
839                    None => Err(message::MessageError::ConversionError(
840                        "Media type for image is required for Gemini".to_string(),
841                    )),
842                },
843                message::UserContent::Document(message::Document {
844                    data, media_type, ..
845                }) => {
846                    let Some(media_type) = media_type else {
847                        return Err(MessageError::ConversionError(
848                            "A mime type is required for document inputs to Gemini".to_string(),
849                        ));
850                    };
851
852                    if !media_type.is_code() {
853                        let mime_type = media_type.to_mime_type().to_string();
854
855                        let part = match data {
856                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
857                                mime_type: Some(mime_type),
858                                file_uri,
859                            }),
860                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
861                                PartKind::InlineData(Blob { mime_type, data })
862                            }
863                            DocumentSourceKind::Raw(_) => {
864                                return Err(message::MessageError::ConversionError(
865                                    "Raw files not supported, encode as base64 first".into(),
866                                ));
867                            }
868                            _ => {
869                                return Err(message::MessageError::ConversionError(
870                                    "Document has no body".to_string(),
871                                ));
872                            }
873                        };
874
875                        Ok(Part {
876                            thought: Some(false),
877                            part,
878                            ..Default::default()
879                        })
880                    } else {
881                        Err(message::MessageError::ConversionError(format!(
882                            "Unsupported document media type {media_type:?}"
883                        )))
884                    }
885                }
886
887                message::UserContent::Audio(message::Audio {
888                    data, media_type, ..
889                }) => {
890                    let Some(media_type) = media_type else {
891                        return Err(MessageError::ConversionError(
892                            "A mime type is required for audio inputs to Gemini".to_string(),
893                        ));
894                    };
895
896                    let mime_type = media_type.to_mime_type().to_string();
897
898                    let part = match data {
899                        DocumentSourceKind::Base64(data) => {
900                            PartKind::InlineData(Blob { data, mime_type })
901                        }
902
903                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
904                            mime_type: Some(mime_type),
905                            file_uri,
906                        }),
907                        DocumentSourceKind::String(_) => {
908                            return Err(message::MessageError::ConversionError(
909                                "Strings cannot be used as audio files!".into(),
910                            ));
911                        }
912                        DocumentSourceKind::Raw(_) => {
913                            return Err(message::MessageError::ConversionError(
914                                "Raw files not supported, encode as base64 first".into(),
915                            ));
916                        }
917                        DocumentSourceKind::Unknown => {
918                            return Err(message::MessageError::ConversionError(
919                                "Content has no body".to_string(),
920                            ));
921                        }
922                    };
923
924                    Ok(Part {
925                        thought: Some(false),
926                        part,
927                        ..Default::default()
928                    })
929                }
930                message::UserContent::Video(message::Video {
931                    data,
932                    media_type,
933                    additional_params,
934                    ..
935                }) => {
936                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
937
938                    let part = match data {
939                        DocumentSourceKind::Url(file_uri) => {
940                            if file_uri.starts_with("https://www.youtube.com") {
941                                PartKind::FileData(FileData {
942                                    mime_type,
943                                    file_uri,
944                                })
945                            } else {
946                                if mime_type.is_none() {
947                                    return Err(MessageError::ConversionError(
948                                        "A mime type is required for non-Youtube video file inputs to Gemini"
949                                            .to_string(),
950                                    ));
951                                }
952
953                                PartKind::FileData(FileData {
954                                    mime_type,
955                                    file_uri,
956                                })
957                            }
958                        }
959                        DocumentSourceKind::Base64(data) => {
960                            let Some(mime_type) = mime_type else {
961                                return Err(MessageError::ConversionError(
962                                    "A media type is expected for base64 encoded strings"
963                                        .to_string(),
964                                ));
965                            };
966                            PartKind::InlineData(Blob { mime_type, data })
967                        }
968                        DocumentSourceKind::String(_) => {
969                            return Err(message::MessageError::ConversionError(
970                                "Strings cannot be used as audio files!".into(),
971                            ));
972                        }
973                        DocumentSourceKind::Raw(_) => {
974                            return Err(message::MessageError::ConversionError(
975                                "Raw file data not supported, encode as base64 first".into(),
976                            ));
977                        }
978                        DocumentSourceKind::Unknown => {
979                            return Err(message::MessageError::ConversionError(
980                                "Media type for video is required for Gemini".to_string(),
981                            ));
982                        }
983                    };
984
985                    Ok(Part {
986                        thought: Some(false),
987                        thought_signature: None,
988                        part,
989                        additional_params,
990                    })
991                }
992            }
993        }
994    }
995
996    impl TryFrom<message::AssistantContent> for Part {
997        type Error = message::MessageError;
998
999        fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1000            match content {
1001                message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1002                message::AssistantContent::Image(message::Image {
1003                    data, media_type, ..
1004                }) => match media_type {
1005                    Some(media_type) => match media_type {
1006                        message::ImageMediaType::JPEG
1007                        | message::ImageMediaType::PNG
1008                        | message::ImageMediaType::WEBP
1009                        | message::ImageMediaType::HEIC
1010                        | message::ImageMediaType::HEIF => {
1011                            let part = PartKind::try_from((media_type, data))?;
1012                            Ok(Part {
1013                                thought: Some(false),
1014                                thought_signature: None,
1015                                part,
1016                                additional_params: None,
1017                            })
1018                        }
1019                        _ => Err(message::MessageError::ConversionError(format!(
1020                            "Unsupported image media type {media_type:?}"
1021                        ))),
1022                    },
1023                    None => Err(message::MessageError::ConversionError(
1024                        "Media type for image is required for Gemini".to_string(),
1025                    )),
1026                },
1027                message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1028                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
1029                    Ok(Part {
1030                        thought: Some(true),
1031                        thought_signature: None,
1032                        part: PartKind::Text(
1033                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
1034                        ),
1035                        additional_params: None,
1036                    })
1037                }
1038            }
1039        }
1040    }
1041
1042    impl From<message::ToolCall> for Part {
1043        fn from(tool_call: message::ToolCall) -> Self {
1044            Self {
1045                thought: Some(false),
1046                thought_signature: None,
1047                part: PartKind::FunctionCall(FunctionCall {
1048                    name: tool_call.function.name,
1049                    args: tool_call.function.arguments,
1050                }),
1051                additional_params: None,
1052            }
1053        }
1054    }
1055
    /// Raw media bytes sent inline with the request (the payload of [`PartKind::InlineData`]).
    /// Text should not be sent as raw bytes; use the `text` field instead.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }
1067
1068    /// A predicted FunctionCall returned from the model that contains a string representing the
1069    /// FunctionDeclaration.name with the arguments and their values.
1070    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1071    pub struct FunctionCall {
1072        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
1073        /// and dashes, with a maximum length of 63.
1074        pub name: String,
1075        /// Optional. The function parameters and values in JSON object format.
1076        pub args: serde_json::Value,
1077    }
1078
1079    impl From<FunctionCall> for message::ToolCall {
1080        fn from(function_call: FunctionCall) -> Self {
1081            Self {
1082                id: function_call.name.clone(),
1083                call_id: None,
1084                function: message::ToolFunction {
1085                    name: function_call.name,
1086                    arguments: function_call.args,
1087                },
1088            }
1089        }
1090    }
1091
1092    impl From<message::ToolCall> for FunctionCall {
1093        fn from(tool_call: message::ToolCall) -> Self {
1094            Self {
1095                name: tool_call.function.name,
1096                args: tool_call.function.arguments,
1097            }
1098        }
1099    }
1100
    /// The result output from a FunctionCall, sent back to the model as context.
    ///
    /// It pairs a string naming the FunctionDeclaration with a structured JSON object holding
    /// the function's output. It should contain the result of a FunctionCall made based on the
    /// model's prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function that was called. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        /// NOTE(review): `None` is serialized as an explicit `null` (no `skip_serializing_if`).
        pub response: Option<serde_json::Value>,
    }
1112
    /// URI-based data: media referenced by location instead of being sent inline
    /// (the payload of [`PartKind::FileData`]).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        /// When `None` this is serialized as an explicit `null` (no `skip_serializing_if`).
        pub mime_type: Option<String>,
        /// Required. URI of the data, serialized as `fileUri`.
        pub file_uri: String,
    }
1122
    /// The safety rating of content for one harm category, e.g. as listed in
    /// [`PromptFeedback::safety_ratings`].
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category this rating applies to.
        pub category: HarmCategory,
        /// How likely the content is to be harmful in this category.
        pub probability: HarmProbability,
    }
1128
    /// Likelihood that content is harmful. (De)serialized in SCREAMING_SNAKE_CASE,
    /// e.g. `NEGLIGIBLE`, `HARM_PROBABILITY_UNSPECIFIED`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1138
    /// Category of harm used in a [`SafetyRating`]. (De)serialized in SCREAMING_SNAKE_CASE,
    /// e.g. `HARM_CATEGORY_HATE_SPEECH`.
    ///
    /// NOTE(review): there is no catch-all variant, so a category the API adds later makes
    /// deserialization of the containing value fail. Confirm the list is current.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1155
    /// Token accounting reported by Gemini, (de)serialized with camelCase keys.
    ///
    /// `prompt_token_count` and `total_token_count` are required when deserializing. The other
    /// counts are optional and are omitted from the serialized form when `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens coming from cached content (`cachedContentTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates (`candidatesTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count for the request (`totalTokenCount`).
        pub total_token_count: i32,
        /// Tokens spent on thinking/reasoning (`thoughtsTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1168
1169    impl std::fmt::Display for UsageMetadata {
1170        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1171            write!(
1172                f,
1173                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1174                self.prompt_token_count,
1175                match self.cached_content_token_count {
1176                    Some(count) => count.to_string(),
1177                    None => "n/a".to_string(),
1178                },
1179                match self.candidates_token_count {
1180                    Some(count) => count.to_string(),
1181                    None => "n/a".to_string(),
1182                },
1183                self.total_token_count
1184            )
1185        }
1186    }
1187
1188    impl GetTokenUsage for UsageMetadata {
1189        fn token_usage(&self) -> Option<crate::completion::Usage> {
1190            let mut usage = crate::completion::Usage::new();
1191
1192            usage.input_tokens = self.prompt_token_count as u64;
1193            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1194                + self.candidates_token_count.unwrap_or_default()
1195                + self.thoughts_token_count.unwrap_or_default())
1196                as u64;
1197            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1198
1199            Some(usage)
1200        }
1201    }
1202
    /// Feedback metadata about the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1212
    /// Reason why a prompt was blocked by the model. (De)serialized in SCREAMING_SNAKE_CASE.
    ///
    /// NOTE(review): there is no catch-all variant, so an unlisted reason from the API fails
    /// deserialization. Confirm the list is current.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
1228
    /// Why the model stopped generating a candidate. (De)serialized in SCREAMING_SNAKE_CASE,
    /// e.g. `MAX_TOKENS`.
    ///
    /// NOTE(review): there is no catch-all variant, so an unlisted reason from the API fails
    /// deserialization of the containing response. Confirm the list is current.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
1255
    /// Source attributions for a piece of generated content.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources, serialized as `citationSources` (required when deserializing).
        pub citation_sources: Vec<CitationSource>,
    }
1261
    /// A single cited source. Every field is optional and omitted from the serialized form when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the output (`startIndex`).
        /// NOTE(review): unit (bytes vs. characters) not established here.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment of the output (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1274
1275    #[derive(Clone, Debug, Deserialize, Serialize)]
1276    #[serde(rename_all = "camelCase")]
1277    pub struct LogprobsResult {
1278        pub top_candidate: Vec<TopCandidate>,
1279        pub chosen_candidate: Vec<LogProbCandidate>,
1280    }
1281
    /// The candidate tokens considered at one decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens at this step.
        /// NOTE(review): the API presumably orders these by log probability; this type does not enforce it.
        pub candidates: Vec<LogProbCandidate>,
    }
1286
1287    #[derive(Clone, Debug, Deserialize, Serialize)]
1288    #[serde(rename_all = "camelCase")]
1289    pub struct LogProbCandidate {
1290        pub token: String,
1291        pub token_id: String,
1292        pub log_probability: f64,
1293    }
1294
    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Every field is optional and is left out of the serialized request when `None`. Field names
    /// are serialized in camelCase, except `_response_json_schema`, which is serialized as
    /// `_responseJsonSchema`.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        /// - `text/plain`: (default) Text output.
        /// - `application/json`: JSON response in the response candidates.
        /// - `text/x.enum`: ENUM as a string response in the response candidates.
        ///
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Optional. The output schema of the generated response.
        /// This is an alternative to responseSchema that accepts a standard JSON Schema.
        /// If this is set, responseSchema must be omitted.
        /// Compatible MIME type: application/json.
        /// Supported properties: $id, $defs, $ref, type, properties, etc.
        /// Serialized as `_responseJsonSchema`.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Optional. A standard JSON Schema for the response, serialized as `responseJsonSchema`.
        /// NOTE(review): this appears to be a second spelling of `_response_json_schema`;
        /// presumably at most one of the two should be set. Confirm against the API reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function. An empty
        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: The more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// [Candidate.logprobs_result].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image output options (aspect ratio / size), serialized as `imageConfig`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1387
1388    impl Default for GenerationConfig {
1389        fn default() -> Self {
1390            Self {
1391                temperature: Some(1.0),
1392                max_output_tokens: Some(4096),
1393                stop_sequences: None,
1394                response_mime_type: None,
1395                response_schema: None,
1396                _response_json_schema: None,
1397                response_json_schema: None,
1398                candidate_count: None,
1399                top_p: None,
1400                top_k: None,
1401                presence_penalty: None,
1402                frequency_penalty: None,
1403                response_logprobs: None,
1404                logprobs: None,
1405                thinking_config: None,
1406                image_config: None,
1407            }
1408        }
1409    }
1410
1411    #[derive(Debug, Deserialize, Serialize)]
1412    #[serde(rename_all = "camelCase")]
1413    pub struct ThinkingConfig {
1414        pub thinking_budget: u32,
1415        pub include_thoughts: Option<bool>,
1416    }
1417
    /// Image-output options carried in the `imageConfig` member of the generation config.
    /// Both fields are omitted from the serialized request when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio, serialized as `aspectRatio`.
        /// NOTE(review): the accepted values are defined by the Gemini API; nothing here validates them.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size, serialized as `imageSize`.
        /// NOTE(review): accepted values are API-defined and not validated here.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1426
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// Usually built from an arbitrary JSON Schema through `TryFrom<Value>`. Unlike most structs
    /// in this module it carries no `rename_all` attribute, so multi-word fields are
    /// (de)serialized under their Rust names (`max_items`, `min_items`).
    /// NOTE(review): the Gemini reference spells these `maxItems`/`minItems`; confirm the API
    /// accepts the snake_case form before relying on them.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name such as `"object"`, `"array"` or `"string"` (serialized as `type`).
        /// The `TryFrom<Value>` conversion leaves it empty when no type can be inferred.
        pub r#type: String,
        /// Optional format hint, copied from the JSON Schema `format` keyword.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description of the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether `null` is an allowed value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values (serialized as `enum`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of elements for array schemas.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of elements for array schemas.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Sub-schemas of an object's properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the properties that must be present.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of the elements of an array schema.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1452
1453    /// Flattens a JSON schema by resolving all `$ref` references inline.
1454    /// It takes a JSON schema that may contain `$ref` references to definitions
1455    /// in `$defs` or `definitions` sections and returns a new schema with all references
1456    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1457    /// schema references.
1458    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1459        // extracting $defs if they exist
1460        let defs = if let Some(obj) = schema.as_object() {
1461            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1462        } else {
1463            None
1464        };
1465
1466        let Some(defs_value) = defs else {
1467            return Ok(schema);
1468        };
1469
1470        let Some(defs_obj) = defs_value.as_object() else {
1471            return Err(CompletionError::ResponseError(
1472                "$defs must be an object".into(),
1473            ));
1474        };
1475
1476        resolve_refs(&mut schema, defs_obj)?;
1477
1478        // removing $defs from the final schema because we have inlined everything
1479        if let Some(obj) = schema.as_object_mut() {
1480            obj.remove("$defs");
1481            obj.remove("definitions");
1482        }
1483
1484        Ok(schema)
1485    }
1486
1487    /// Recursively resolves all `$ref` references in a JSON value by
1488    /// replacing them with their definitions.
1489    fn resolve_refs(
1490        value: &mut Value,
1491        defs: &serde_json::Map<String, Value>,
1492    ) -> Result<(), CompletionError> {
1493        match value {
1494            Value::Object(obj) => {
1495                if let Some(ref_value) = obj.get("$ref")
1496                    && let Some(ref_str) = ref_value.as_str()
1497                {
1498                    // "#/$defs/Person" -> "Person"
1499                    let def_name = parse_ref_path(ref_str)?;
1500
1501                    let def = defs.get(&def_name).ok_or_else(|| {
1502                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1503                    })?;
1504
1505                    let mut resolved = def.clone();
1506                    resolve_refs(&mut resolved, defs)?;
1507                    *value = resolved;
1508                    return Ok(());
1509                }
1510
1511                for (_, v) in obj.iter_mut() {
1512                    resolve_refs(v, defs)?;
1513                }
1514            }
1515            Value::Array(arr) => {
1516                for item in arr.iter_mut() {
1517                    resolve_refs(item, defs)?;
1518                }
1519            }
1520            _ => {}
1521        }
1522
1523        Ok(())
1524    }
1525
1526    /// Parses a JSON Schema `$ref` path to extract the definition name.
1527    ///
1528    /// JSON Schema references use URI fragment syntax to point to definitions within
1529    /// the same document. This function extracts the definition name from common
1530    /// reference patterns used in JSON Schema.
1531    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1532        if let Some(fragment) = ref_str.strip_prefix('#') {
1533            if let Some(name) = fragment.strip_prefix("/$defs/") {
1534                Ok(name.to_string())
1535            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1536                Ok(name.to_string())
1537            } else {
1538                Err(CompletionError::ResponseError(format!(
1539                    "Unsupported reference format: {}",
1540                    ref_str
1541                )))
1542            }
1543        } else {
1544            Err(CompletionError::ResponseError(format!(
1545                "Only fragment references (#/...) are supported: {}",
1546                ref_str
1547            )))
1548        }
1549    }
1550
1551    /// Helper function to extract the type string from a JSON value.
1552    /// Handles both direct string types and array types (returns the first element).
1553    fn extract_type(type_value: &Value) -> Option<String> {
1554        if type_value.is_string() {
1555            type_value.as_str().map(String::from)
1556        } else if type_value.is_array() {
1557            type_value
1558                .as_array()
1559                .and_then(|arr| arr.first())
1560                .and_then(|v| v.as_str().map(String::from))
1561        } else {
1562            None
1563        }
1564    }
1565
1566    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1567    /// Returns the type of the first non-null schema found.
1568    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1569        composition.as_array().and_then(|arr| {
1570            arr.iter().find_map(|schema| {
1571                if let Some(obj) = schema.as_object() {
1572                    // Skip null types
1573                    if let Some(type_val) = obj.get("type")
1574                        && let Some(type_str) = type_val.as_str()
1575                        && type_str == "null"
1576                    {
1577                        return None;
1578                    }
1579                    // Extract type from this schema
1580                    obj.get("type").and_then(extract_type).or_else(|| {
1581                        if obj.contains_key("properties") {
1582                            Some("object".to_string())
1583                        } else {
1584                            None
1585                        }
1586                    })
1587                } else {
1588                    None
1589                }
1590            })
1591        })
1592    }
1593
1594    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1595    /// Returns the schema object that should be used for properties, required, etc.
1596    fn extract_schema_from_composition(
1597        composition: &Value,
1598    ) -> Option<serde_json::Map<String, Value>> {
1599        composition.as_array().and_then(|arr| {
1600            arr.iter().find_map(|schema| {
1601                if let Some(obj) = schema.as_object()
1602                    && let Some(type_val) = obj.get("type")
1603                    && let Some(type_str) = type_val.as_str()
1604                {
1605                    if type_str == "null" {
1606                        return None;
1607                    }
1608                    Some(obj.clone())
1609                } else {
1610                    None
1611                }
1612            })
1613        })
1614    }
1615
1616    /// Helper function to infer the type of a schema object.
1617    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1618    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1619        // First, try direct type field
1620        if let Some(type_val) = obj.get("type")
1621            && let Some(type_str) = extract_type(type_val)
1622        {
1623            return type_str;
1624        }
1625
1626        // Then try anyOf, oneOf, allOf (in that order)
1627        if let Some(any_of) = obj.get("anyOf")
1628            && let Some(type_str) = extract_type_from_composition(any_of)
1629        {
1630            return type_str;
1631        }
1632
1633        if let Some(one_of) = obj.get("oneOf")
1634            && let Some(type_str) = extract_type_from_composition(one_of)
1635        {
1636            return type_str;
1637        }
1638
1639        if let Some(all_of) = obj.get("allOf")
1640            && let Some(type_str) = extract_type_from_composition(all_of)
1641        {
1642            return type_str;
1643        }
1644
1645        // Finally, infer object type if properties are present
1646        if obj.contains_key("properties") {
1647            "object".to_string()
1648        } else {
1649            String::new()
1650        }
1651    }
1652
1653    impl TryFrom<Value> for Schema {
1654        type Error = CompletionError;
1655
1656        fn try_from(value: Value) -> Result<Self, Self::Error> {
1657            let flattened_val = flatten_schema(value)?;
1658            if let Some(obj) = flattened_val.as_object() {
1659                // Determine which object to use for extracting properties and required fields.
1660                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1661                let props_source = if obj.get("properties").is_none() {
1662                    if let Some(any_of) = obj.get("anyOf") {
1663                        extract_schema_from_composition(any_of)
1664                    } else if let Some(one_of) = obj.get("oneOf") {
1665                        extract_schema_from_composition(one_of)
1666                    } else if let Some(all_of) = obj.get("allOf") {
1667                        extract_schema_from_composition(all_of)
1668                    } else {
1669                        None
1670                    }
1671                    .unwrap_or(obj.clone())
1672                } else {
1673                    obj.clone()
1674                };
1675
1676                Ok(Schema {
1677                    r#type: infer_type(obj),
1678                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1679                    description: obj
1680                        .get("description")
1681                        .and_then(|v| v.as_str())
1682                        .map(String::from),
1683                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1684                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1685                        arr.iter()
1686                            .filter_map(|v| v.as_str().map(String::from))
1687                            .collect()
1688                    }),
1689                    max_items: obj
1690                        .get("maxItems")
1691                        .and_then(|v| v.as_i64())
1692                        .map(|v| v as i32),
1693                    min_items: obj
1694                        .get("minItems")
1695                        .and_then(|v| v.as_i64())
1696                        .map(|v| v as i32),
1697                    properties: props_source
1698                        .get("properties")
1699                        .and_then(|v| v.as_object())
1700                        .map(|map| {
1701                            map.iter()
1702                                .filter_map(|(k, v)| {
1703                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1704                                })
1705                                .collect()
1706                        }),
1707                    required: props_source
1708                        .get("required")
1709                        .and_then(|v| v.as_array())
1710                        .map(|arr| {
1711                            arr.iter()
1712                                .filter_map(|v| v.as_str().map(String::from))
1713                                .collect()
1714                        }),
1715                    items: obj
1716                        .get("items")
1717                        .and_then(|v| v.clone().try_into().ok())
1718                        .map(Box::new),
1719                })
1720            } else {
1721                Err(CompletionError::ResponseError(
1722                    "Expected a JSON object for Schema".into(),
1723                ))
1724            }
1725        }
1726    }
1727
    /// Request body for Gemini's `generateContent` call, serialized with camelCase keys.
    ///
    /// Only `tools` and `additional_params` are skipped when `None`. The other optional
    /// fields (`tool_config`, `generation_config`, `safety_settings`, `system_instruction`)
    /// carry no `skip_serializing_if`, so they are serialized as explicit `null`s.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation turns sent to the model.
        pub contents: Vec<Content>,
        /// Tool declarations available to the model; omitted when `None`.
        /// NOTE(review): this is a single `Tool` object, while the Gemini reference describes
        /// `tools` as a list; confirm the API keeps accepting the object form.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        /// Function-calling configuration (serialized as `toolConfig`).
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. Harm categories:
        ///     - HARM_CATEGORY_HATE_SPEECH,
        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
        ///     - HARM_CATEGORY_HARASSMENT
        /// are supported.
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // cachedContent: Optional<String> (not yet supported by this struct)
        /// Additional parameters.
        /// `flatten` merges the keys of this value into the top level of the request body.
        /// NOTE(review): serde's flatten only works with a JSON object here; other JSON values
        /// will presumably fail to serialize.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1759
1760    #[derive(Debug, Serialize)]
1761    #[serde(rename_all = "camelCase")]
1762    pub struct Tool {
1763        pub function_declarations: Vec<FunctionDeclaration>,
1764        pub code_execution: Option<CodeExecution>,
1765    }
1766
    /// Declaration of one callable function exposed to the model.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name the model uses when calling it.
        pub name: String,
        /// Natural-language description of what the function does.
        pub description: String,
        /// Parameter schema; omitted from the request when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1775
    /// Tool-usage configuration sent as the request's `toolConfig`.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Serialized as `functionCallingConfig`. There is no `skip_serializing_if`, so
        /// `None` is sent as `null`.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1781
    /// Function-calling mode, internally tagged by a `mode` key whose value is the upper-cased
    /// variant name (`"AUTO"`, `"NONE"`, `"ANY"`).
    ///
    /// NOTE(review): `rename_all` on an enum renames variants, not the fields inside struct
    /// variants, so `allowed_function_names` is emitted in snake_case. The Gemini reference
    /// spells it `allowedFunctionNames`; confirm the API accepts both spellings.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// Serialized as `{"mode": "AUTO"}`; this is the `Default`.
        #[default]
        Auto,
        /// Serialized as `{"mode": "NONE"}`.
        None,
        /// Serialized as `{"mode": "ANY"}`, optionally with the list of permitted function names.
        Any {
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1793
1794    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1795        type Error = CompletionError;
1796        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1797            let res = match value {
1798                message::ToolChoice::Auto => Self::Auto,
1799                message::ToolChoice::None => Self::None,
1800                message::ToolChoice::Required => Self::Any {
1801                    allowed_function_names: None,
1802                },
1803                message::ToolChoice::Specific { function_names } => Self::Any {
1804                    allowed_function_names: Some(function_names),
1805                },
1806            };
1807
1808            Ok(res)
1809        }
1810    }
1811
    /// Marker that enables the code-execution tool; it has no fields and serializes as `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1814
    /// One safety rule: block content in `category` at or above `threshold`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Harm category the threshold applies to.
        pub category: HarmCategory,
        /// Blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
1821
    /// Blocking threshold for a [`SafetySetting`]. Variants serialize in SCREAMING_SNAKE_CASE,
    /// e.g. `BlockLowAndAbove` becomes `"BLOCK_LOW_AND_ABOVE"`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// `"HARM_BLOCK_THRESHOLD_UNSPECIFIED"`
        HarmBlockThresholdUnspecified,
        /// `"BLOCK_LOW_AND_ABOVE"`
        BlockLowAndAbove,
        /// `"BLOCK_MEDIUM_AND_ABOVE"`
        BlockMediumAndAbove,
        /// `"BLOCK_ONLY_HIGH"`
        BlockOnlyHigh,
        /// `"BLOCK_NONE"`
        BlockNone,
        /// `"OFF"`
        Off,
    }
1832}
1833
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion error must fail the test; it used to be printed and ignored.
        let schema = Schema::try_from(schema_with_ref)
            .unwrap_or_else(|e| panic!("Schema conversion failed: {:?}", e));
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        if let Some(items) = schema.items {
            assert_eq!(items.r#type, "object");
            assert!(items.properties.is_some());
        } else {
            panic!("Schema should have items field");
        }
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // Missing items must fail the test instead of silently skipping the assertions.
        let items = schema
            .items
            .expect("flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flatten_schema_without_defs_is_identity() {
        let plain = json!({
            "type": "object",
            "properties": { "name": { "type": "string" } }
        });

        assert_eq!(flatten_schema(plain.clone()).unwrap(), plain);
    }

    #[test]
    fn test_flatten_schema_resolves_legacy_definitions() {
        let legacy = json!({
            "type": "object",
            "properties": { "pet": { "$ref": "#/definitions/Pet" } },
            "definitions": { "Pet": { "type": "string" } }
        });

        let flattened = flatten_schema(legacy).unwrap();
        assert_eq!(
            flattened,
            json!({
                "type": "object",
                "properties": { "pet": { "type": "string" } }
            })
        );
    }

    #[test]
    fn test_flatten_schema_missing_reference_errors() {
        let dangling = json!({
            "type": "array",
            "items": { "$ref": "#/$defs/Missing" },
            "$defs": {}
        });

        assert!(flatten_schema(dangling).is_err());
    }

    #[test]
    fn test_non_object_schema_is_rejected() {
        assert!(Schema::try_from(json!("string")).is_err());
    }
}