rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the small 8B variant of Gemini 1.5 Flash)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
///
/// NOTE: Google's published model list has an 8B variant only for Flash
/// ([`GEMINI_1_5_FLASH_8B`]); this identifier is kept for backward compatibility,
/// but requests using it are likely to be rejected by the API.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::http_client::HttpClientExt;
32use crate::message::Reasoning;
33use crate::providers::gemini::completion::gemini_api_types::{
34    AdditionalParameters, FunctionCallingMode, ToolConfig,
35};
36use crate::providers::gemini::streaming::StreamingCompletionResponse;
37use crate::telemetry::SpanCombinator;
38use crate::{
39    OneOrMany,
40    completion::{self, CompletionError, CompletionRequest},
41};
42use gemini_api_types::{
43    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
44    Role, Tool,
45};
46use serde_json::{Map, Value};
47use std::convert::TryFrom;
48use tracing::info_span;
49
50use super::Client;
51
52// =================================================================
53// Rig Implementation Types
54// =================================================================
55
56#[derive(Clone)]
57pub struct CompletionModel<T = reqwest::Client> {
58    pub(crate) client: Client<T>,
59    pub model: String,
60}
61
62impl<T> CompletionModel<T> {
63    pub fn new(client: Client<T>, model: &str) -> Self {
64        Self {
65            client,
66            model: model.to_string(),
67        }
68    }
69}
70
71impl<T> completion::CompletionModel for CompletionModel<T>
72where
73    T: HttpClientExt + Clone + 'static,
74{
75    type Response = GenerateContentResponse;
76    type StreamingResponse = StreamingCompletionResponse;
77
78    #[cfg_attr(feature = "worker", worker::send)]
79    async fn completion(
80        &self,
81        completion_request: CompletionRequest,
82    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
83        let span = if tracing::Span::current().is_disabled() {
84            info_span!(
85                target: "rig::completions",
86                "generate_content",
87                gen_ai.operation.name = "generate_content",
88                gen_ai.provider.name = "gcp.gemini",
89                gen_ai.request.model = self.model,
90                gen_ai.system_instructions = &completion_request.preamble,
91                gen_ai.response.id = tracing::field::Empty,
92                gen_ai.response.model = tracing::field::Empty,
93                gen_ai.usage.output_tokens = tracing::field::Empty,
94                gen_ai.usage.input_tokens = tracing::field::Empty,
95                gen_ai.input.messages = tracing::field::Empty,
96                gen_ai.output.messages = tracing::field::Empty,
97            )
98        } else {
99            tracing::Span::current()
100        };
101
102        let request = create_request_body(completion_request)?;
103        span.record_model_input(&request.contents);
104
105        span.record_model_input(&request.contents);
106
107        tracing::debug!(
108            "Sending completion request to Gemini API {}",
109            serde_json::to_string_pretty(&request)?
110        );
111
112        let body = serde_json::to_vec(&request)?;
113
114        let request = self
115            .client
116            .post(&format!("/v1beta/models/{}:generateContent", self.model))
117            .header("Content-Type", "application/json")
118            .body(body)
119            .map_err(|e| CompletionError::HttpError(e.into()))?;
120
121        let response = self.client.send::<_, Vec<u8>>(request).await?;
122
123        if response.status().is_success() {
124            let response_body = response
125                .into_body()
126                .await
127                .map_err(CompletionError::HttpError)?;
128
129            let response_text = String::from_utf8_lossy(&response_body).to_string();
130            tracing::debug!("Received raw response from Gemini API: {}", response_text);
131
132            let response: GenerateContentResponse = serde_json::from_slice(&response_body)
133                .map_err(|err| {
134                    tracing::error!(
135                        error = %err,
136                        body = %response_text,
137                        "Failed to deserialize Gemini completion response"
138                    );
139                    CompletionError::JsonError(err)
140                })?;
141
142            match response.usage_metadata {
143                Some(ref usage) => tracing::info!(target: "rig",
144                "Gemini completion token usage: {}",
145                usage
146                ),
147                None => tracing::info!(target: "rig",
148                    "Gemini completion token usage: n/a",
149                ),
150            }
151
152            let span = tracing::Span::current();
153            span.record_model_output(&response.candidates);
154            span.record_response_metadata(&response);
155            span.record_token_usage(&response.usage_metadata);
156
157            tracing::debug!(
158                "Received response from Gemini API: {}",
159                serde_json::to_string_pretty(&response)?
160            );
161
162            response.try_into()
163        } else {
164            let text = String::from_utf8_lossy(
165                &response
166                    .into_body()
167                    .await
168                    .map_err(CompletionError::HttpError)?,
169            )
170            .into();
171
172            Err(CompletionError::ProviderError(text))
173        }
174    }
175
176    #[cfg_attr(feature = "worker", worker::send)]
177    async fn stream(
178        &self,
179        request: CompletionRequest,
180    ) -> Result<
181        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
182        CompletionError,
183    > {
184        CompletionModel::stream(self, request).await
185    }
186}
187
188pub(crate) fn create_request_body(
189    completion_request: CompletionRequest,
190) -> Result<GenerateContentRequest, CompletionError> {
191    let mut full_history = Vec::new();
192    full_history.extend(completion_request.chat_history);
193
194    let additional_params = completion_request
195        .additional_params
196        .unwrap_or_else(|| Value::Object(Map::new()));
197
198    let AdditionalParameters {
199        mut generation_config,
200        additional_params,
201    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
202
203    if let Some(temp) = completion_request.temperature {
204        generation_config.temperature = Some(temp);
205    }
206
207    if let Some(max_tokens) = completion_request.max_tokens {
208        generation_config.max_output_tokens = Some(max_tokens);
209    }
210
211    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
212        parts: vec![preamble.into()],
213        role: Some(Role::Model),
214    });
215
216    let tools = if completion_request.tools.is_empty() {
217        None
218    } else {
219        Some(Tool::try_from(completion_request.tools)?)
220    };
221
222    let tool_config = if let Some(cfg) = completion_request.tool_choice {
223        Some(ToolConfig {
224            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
225        })
226    } else {
227        None
228    };
229
230    let request = GenerateContentRequest {
231        contents: full_history
232            .into_iter()
233            .map(|msg| {
234                msg.try_into()
235                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
236            })
237            .collect::<Result<Vec<_>, _>>()?,
238        generation_config: Some(generation_config),
239        safety_settings: None,
240        tools,
241        tool_config,
242        system_instruction,
243        additional_params,
244    };
245
246    Ok(request)
247}
248
249impl TryFrom<completion::ToolDefinition> for Tool {
250    type Error = CompletionError;
251
252    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
253        let parameters: Option<Schema> =
254            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
255                None
256            } else {
257                Some(tool.parameters.try_into()?)
258            };
259
260        Ok(Self {
261            function_declarations: vec![FunctionDeclaration {
262                name: tool.name,
263                description: tool.description,
264                parameters,
265            }],
266            code_execution: None,
267        })
268    }
269}
270
271impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
272    type Error = CompletionError;
273
274    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
275        let mut function_declarations = Vec::new();
276
277        for tool in tools {
278            let parameters =
279                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
280                    None
281                } else {
282                    match tool.parameters.try_into() {
283                        Ok(schema) => Some(schema),
284                        Err(e) => {
285                            let emsg = format!(
286                                "Tool '{}' could not be converted to a schema: {:?}",
287                                tool.name, e,
288                            );
289                            return Err(CompletionError::ProviderError(emsg));
290                        }
291                    }
292                };
293
294            function_declarations.push(FunctionDeclaration {
295                name: tool.name,
296                description: tool.description,
297                parameters,
298            });
299        }
300
301        Ok(Self {
302            function_declarations,
303            code_execution: None,
304        })
305    }
306}
307
308impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
309    type Error = CompletionError;
310
311    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
312        let candidate = response.candidates.first().ok_or_else(|| {
313            CompletionError::ResponseError("No response candidates in response".into())
314        })?;
315
316        let content = candidate
317            .content
318            .as_ref()
319            .ok_or_else(|| {
320                let reason = candidate
321                    .finish_reason
322                    .as_ref()
323                    .map(|r| format!("finish_reason={r:?}"))
324                    .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
325                let message = candidate
326                    .finish_message
327                    .as_deref()
328                    .unwrap_or("no finish message provided");
329                CompletionError::ResponseError(format!(
330                    "Gemini candidate missing content ({reason}, finish_message={message})"
331                ))
332            })?
333            .parts
334            .iter()
335            .map(|Part { thought, part, .. }| {
336                Ok(match part {
337                    PartKind::Text(text) => {
338                        if let Some(thought) = thought
339                            && *thought
340                        {
341                            completion::AssistantContent::Reasoning(Reasoning::new(text))
342                        } else {
343                            completion::AssistantContent::text(text)
344                        }
345                    }
346                    PartKind::FunctionCall(function_call) => {
347                        completion::AssistantContent::tool_call(
348                            &function_call.name,
349                            &function_call.name,
350                            function_call.args.clone(),
351                        )
352                    }
353                    _ => {
354                        return Err(CompletionError::ResponseError(
355                            "Response did not contain a message or tool call".into(),
356                        ));
357                    }
358                })
359            })
360            .collect::<Result<Vec<_>, _>>()?;
361
362        let choice = OneOrMany::many(content).map_err(|_| {
363            CompletionError::ResponseError(
364                "Response contained no message or tool call (empty)".to_owned(),
365            )
366        })?;
367
368        let usage = response
369            .usage_metadata
370            .as_ref()
371            .map(|usage| completion::Usage {
372                input_tokens: usage.prompt_token_count as u64,
373                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
374                total_tokens: usage.total_token_count as u64,
375            })
376            .unwrap_or_default();
377
378        Ok(completion::CompletionResponse {
379            choice,
380            usage,
381            raw_response: response,
382        })
383    }
384}
385
386pub mod gemini_api_types {
387    use crate::telemetry::ProviderResponseExt;
388    use std::{collections::HashMap, convert::Infallible, str::FromStr};
389
390    // =================================================================
391    // Gemini API Types
392    // =================================================================
393    use serde::{Deserialize, Serialize};
394    use serde_json::{Value, json};
395
396    use crate::completion::GetTokenUsage;
397    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
398    use crate::{
399        OneOrMany,
400        completion::CompletionError,
401        message::{self, Reasoning, Text},
402        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
403    };
404
405    #[derive(Debug, Deserialize, Serialize, Default)]
406    #[serde(rename_all = "camelCase")]
407    pub struct AdditionalParameters {
408        /// Change your Gemini request configuration.
409        pub generation_config: GenerationConfig,
410        /// Any additional parameters that you want.
411        #[serde(flatten, skip_serializing_if = "Option::is_none")]
412        pub additional_params: Option<serde_json::Value>,
413    }
414
415    impl AdditionalParameters {
416        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
417            self.generation_config = cfg;
418            self
419        }
420
421        pub fn with_params(mut self, params: serde_json::Value) -> Self {
422            self.additional_params = Some(params);
423            self
424        }
425    }
426
427    /// Response from the model supporting multiple candidate responses.
428    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
429    /// and for each candidate in finishReason and in safetyRatings.
430    /// The API:
431    ///     - Returns either all requested candidates or none of them
432    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
433    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
434    #[derive(Debug, Deserialize, Serialize)]
435    #[serde(rename_all = "camelCase")]
436    pub struct GenerateContentResponse {
437        pub response_id: String,
438        /// Candidate responses from the model.
439        pub candidates: Vec<ContentCandidate>,
440        /// Returns the prompt's feedback related to the content filters.
441        pub prompt_feedback: Option<PromptFeedback>,
442        /// Output only. Metadata on the generation requests' token usage.
443        pub usage_metadata: Option<UsageMetadata>,
444        pub model_version: Option<String>,
445    }
446
447    impl ProviderResponseExt for GenerateContentResponse {
448        type OutputMessage = ContentCandidate;
449        type Usage = UsageMetadata;
450
451        fn get_response_id(&self) -> Option<String> {
452            Some(self.response_id.clone())
453        }
454
455        fn get_response_model_name(&self) -> Option<String> {
456            None
457        }
458
459        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
460            self.candidates.clone()
461        }
462
463        fn get_text_response(&self) -> Option<String> {
464            let str = self
465                .candidates
466                .iter()
467                .filter_map(|x| {
468                    let content = x.content.as_ref()?;
469                    if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
470                        return None;
471                    }
472
473                    let res = content
474                        .parts
475                        .iter()
476                        .filter_map(|part| {
477                            if let PartKind::Text(ref str) = part.part {
478                                Some(str.to_owned())
479                            } else {
480                                None
481                            }
482                        })
483                        .collect::<Vec<String>>()
484                        .join("\n");
485
486                    Some(res)
487                })
488                .collect::<Vec<String>>()
489                .join("\n");
490
491            if str.is_empty() { None } else { Some(str) }
492        }
493
494        fn get_usage(&self) -> Option<Self::Usage> {
495            self.usage_metadata.clone()
496        }
497    }
498
499    /// A response candidate generated from the model.
500    #[derive(Clone, Debug, Deserialize, Serialize)]
501    #[serde(rename_all = "camelCase")]
502    pub struct ContentCandidate {
503        /// Output only. Generated content returned from the model.
504        #[serde(skip_serializing_if = "Option::is_none")]
505        pub content: Option<Content>,
506        /// Optional. Output only. The reason why the model stopped generating tokens.
507        /// If empty, the model has not stopped generating tokens.
508        pub finish_reason: Option<FinishReason>,
509        /// List of ratings for the safety of a response candidate.
510        /// There is at most one rating per category.
511        pub safety_ratings: Option<Vec<SafetyRating>>,
512        /// Output only. Citation information for model-generated candidate.
513        /// This field may be populated with recitation information for any text included in the content.
514        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
515        pub citation_metadata: Option<CitationMetadata>,
516        /// Output only. Token count for this candidate.
517        pub token_count: Option<i32>,
518        /// Output only.
519        pub avg_logprobs: Option<f64>,
520        /// Output only. Log-likelihood scores for the response tokens and top tokens
521        pub logprobs_result: Option<LogprobsResult>,
522        /// Output only. Index of the candidate in the list of response candidates.
523        pub index: Option<i32>,
524        /// Output only. Additional information about why the model stopped generating tokens.
525        pub finish_message: Option<String>,
526    }
527
528    #[derive(Clone, Debug, Deserialize, Serialize)]
529    pub struct Content {
530        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
531        #[serde(default)]
532        pub parts: Vec<Part>,
533        /// The producer of the content. Must be either 'user' or 'model'.
534        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
535        pub role: Option<Role>,
536    }
537
538    impl TryFrom<message::Message> for Content {
539        type Error = message::MessageError;
540
541        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
542            Ok(match msg {
543                message::Message::User { content } => Content {
544                    parts: content
545                        .into_iter()
546                        .map(|c| c.try_into())
547                        .collect::<Result<Vec<_>, _>>()?,
548                    role: Some(Role::User),
549                },
550                message::Message::Assistant { content, .. } => Content {
551                    role: Some(Role::Model),
552                    parts: content.into_iter().map(|content| content.into()).collect(),
553                },
554            })
555        }
556    }
557
558    impl TryFrom<Content> for message::Message {
559        type Error = message::MessageError;
560
561        fn try_from(content: Content) -> Result<Self, Self::Error> {
562            match content.role {
563                Some(Role::User) | None => {
564                    Ok(message::Message::User {
565                        content: {
566                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
567                            .map(|Part { part, .. }| {
568                                Ok(match part {
569                                    PartKind::Text(text) => message::UserContent::text(text),
570                                    PartKind::InlineData(inline_data) => {
571                                        let mime_type =
572                                            message::MediaType::from_mime_type(&inline_data.mime_type);
573
574                                        match mime_type {
575                                            Some(message::MediaType::Image(media_type)) => {
576                                                message::UserContent::image_base64(
577                                                    inline_data.data,
578                                                    Some(media_type),
579                                                    Some(message::ImageDetail::default()),
580                                                )
581                                            }
582                                            Some(message::MediaType::Document(media_type)) => {
583                                                message::UserContent::document(
584                                                    inline_data.data,
585                                                    Some(media_type),
586                                                )
587                                            }
588                                            Some(message::MediaType::Audio(media_type)) => {
589                                                message::UserContent::audio(
590                                                    inline_data.data,
591                                                    Some(media_type),
592                                                )
593                                            }
594                                            _ => {
595                                                return Err(message::MessageError::ConversionError(
596                                                    format!("Unsupported media type {mime_type:?}"),
597                                                ));
598                                            }
599                                        }
600                                    }
601                                    _ => {
602                                        return Err(message::MessageError::ConversionError(format!(
603                                            "Unsupported gemini content part type: {part:?}"
604                                        )));
605                                    }
606                                })
607                            })
608                            .collect();
609                            OneOrMany::many(user_content?).map_err(|_| {
610                                message::MessageError::ConversionError(
611                                    "Failed to create OneOrMany from user content".to_string(),
612                                )
613                            })?
614                        },
615                    })
616                }
617                Some(Role::Model) => Ok(message::Message::Assistant {
618                    id: None,
619                    content: {
620                        let assistant_content: Result<Vec<_>, _> = content
621                            .parts
622                            .into_iter()
623                            .map(|Part { thought, part, .. }| {
624                                Ok(match part {
625                                    PartKind::Text(text) => match thought {
626                                        Some(true) => message::AssistantContent::Reasoning(
627                                            Reasoning::new(&text),
628                                        ),
629                                        _ => message::AssistantContent::Text(Text { text }),
630                                    },
631
632                                    PartKind::FunctionCall(function_call) => {
633                                        message::AssistantContent::ToolCall(function_call.into())
634                                    }
635                                    _ => {
636                                        return Err(message::MessageError::ConversionError(
637                                            format!("Unsupported part type: {part:?}"),
638                                        ));
639                                    }
640                                })
641                            })
642                            .collect();
643                        OneOrMany::many(assistant_content?).map_err(|_| {
644                            message::MessageError::ConversionError(
645                                "Failed to create OneOrMany from assistant content".to_string(),
646                            )
647                        })?
648                    },
649                }),
650            }
651        }
652    }
653
654    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
655    #[serde(rename_all = "lowercase")]
656    pub enum Role {
657        User,
658        Model,
659    }
660
661    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
662    #[serde(rename_all = "camelCase")]
663    pub struct Part {
664        /// whether or not the part is a reasoning/thinking text or not
665        #[serde(skip_serializing_if = "Option::is_none")]
666        pub thought: Option<bool>,
667        /// an opaque sig for the thought so it can be reused - is a base64 string
668        #[serde(skip_serializing_if = "Option::is_none")]
669        pub thought_signature: Option<String>,
670        #[serde(flatten)]
671        pub part: PartKind,
672        #[serde(flatten, skip_serializing_if = "Option::is_none")]
673        pub additional_params: Option<Value>,
674    }
675
676    /// A datatype containing media that is part of a multi-part [Content] message.
677    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
678    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
679    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
680    #[serde(rename_all = "camelCase")]
681    pub enum PartKind {
682        Text(String),
683        InlineData(Blob),
684        FunctionCall(FunctionCall),
685        FunctionResponse(FunctionResponse),
686        FileData(FileData),
687        ExecutableCode(ExecutableCode),
688        CodeExecutionResult(CodeExecutionResult),
689    }
690
691    // This default instance is primarily so we can easily fill in the optional fields of `Part`
692    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
693    impl Default for PartKind {
694        fn default() -> Self {
695            Self::Text(String::new())
696        }
697    }
698
699    impl From<String> for Part {
700        fn from(text: String) -> Self {
701            Self {
702                thought: Some(false),
703                thought_signature: None,
704                part: PartKind::Text(text),
705                additional_params: None,
706            }
707        }
708    }
709
710    impl From<&str> for Part {
711        fn from(text: &str) -> Self {
712            Self::from(text.to_string())
713        }
714    }
715
716    impl FromStr for Part {
717        type Err = Infallible;
718
719        fn from_str(s: &str) -> Result<Self, Self::Err> {
720            Ok(s.into())
721        }
722    }
723
724    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
725        type Error = message::MessageError;
726        fn try_from(
727            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
728        ) -> Result<Self, Self::Error> {
729            let mime_type = mime_type.to_mime_type().to_string();
730            let part = match doc_src {
731                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
732                    mime_type: Some(mime_type),
733                    file_uri: url,
734                }),
735                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
736                    PartKind::InlineData(Blob { mime_type, data })
737                }
738                DocumentSourceKind::Raw(_) => {
739                    return Err(message::MessageError::ConversionError(
740                        "Raw files not supported, encode as base64 first".into(),
741                    ));
742                }
743                DocumentSourceKind::Unknown => {
744                    return Err(message::MessageError::ConversionError(
745                        "Can't convert an unknown document source".to_string(),
746                    ));
747                }
748            };
749
750            Ok(part)
751        }
752    }
753
754    impl TryFrom<message::UserContent> for Part {
755        type Error = message::MessageError;
756
757        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
758            match content {
759                message::UserContent::Text(message::Text { text }) => Ok(Part {
760                    thought: Some(false),
761                    thought_signature: None,
762                    part: PartKind::Text(text),
763                    additional_params: None,
764                }),
765                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
766                    let content = match content.first() {
767                        message::ToolResultContent::Text(text) => text.text,
768                        message::ToolResultContent::Image(_) => {
769                            return Err(message::MessageError::ConversionError(
770                                "Tool result content must be text".to_string(),
771                            ));
772                        }
773                    };
774                    // Convert to JSON since this value may be a valid JSON value
775                    let result: serde_json::Value =
776                        serde_json::from_str(&content).unwrap_or_else(|error| {
777                            tracing::trace!(
778                                ?error,
779                                "Tool result is not a valid JSON, treat it as normal string"
780                            );
781                            json!(content)
782                        });
783                    Ok(Part {
784                        thought: Some(false),
785                        thought_signature: None,
786                        part: PartKind::FunctionResponse(FunctionResponse {
787                            name: id,
788                            response: Some(json!({ "result": result })),
789                        }),
790                        additional_params: None,
791                    })
792                }
793                message::UserContent::Image(message::Image {
794                    data, media_type, ..
795                }) => match media_type {
796                    Some(media_type) => match media_type {
797                        message::ImageMediaType::JPEG
798                        | message::ImageMediaType::PNG
799                        | message::ImageMediaType::WEBP
800                        | message::ImageMediaType::HEIC
801                        | message::ImageMediaType::HEIF => {
802                            let part = PartKind::try_from((media_type, data))?;
803                            Ok(Part {
804                                thought: Some(false),
805                                thought_signature: None,
806                                part,
807                                additional_params: None,
808                            })
809                        }
810                        _ => Err(message::MessageError::ConversionError(format!(
811                            "Unsupported image media type {media_type:?}"
812                        ))),
813                    },
814                    None => Err(message::MessageError::ConversionError(
815                        "Media type for image is required for Gemini".to_string(),
816                    )),
817                },
818                message::UserContent::Document(message::Document {
819                    data, media_type, ..
820                }) => {
821                    let Some(media_type) = media_type else {
822                        return Err(MessageError::ConversionError(
823                            "A mime type is required for document inputs to Gemini".to_string(),
824                        ));
825                    };
826
827                    if !media_type.is_code() {
828                        let mime_type = media_type.to_mime_type().to_string();
829
830                        let part = match data {
831                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
832                                mime_type: Some(mime_type),
833                                file_uri,
834                            }),
835                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
836                                PartKind::InlineData(Blob { mime_type, data })
837                            }
838                            DocumentSourceKind::Raw(_) => {
839                                return Err(message::MessageError::ConversionError(
840                                    "Raw files not supported, encode as base64 first".into(),
841                                ));
842                            }
843                            _ => {
844                                return Err(message::MessageError::ConversionError(
845                                    "Document has no body".to_string(),
846                                ));
847                            }
848                        };
849
850                        Ok(Part {
851                            thought: Some(false),
852                            part,
853                            ..Default::default()
854                        })
855                    } else {
856                        Err(message::MessageError::ConversionError(format!(
857                            "Unsupported document media type {media_type:?}"
858                        )))
859                    }
860                }
861
862                message::UserContent::Audio(message::Audio {
863                    data, media_type, ..
864                }) => {
865                    let Some(media_type) = media_type else {
866                        return Err(MessageError::ConversionError(
867                            "A mime type is required for audio inputs to Gemini".to_string(),
868                        ));
869                    };
870
871                    let mime_type = media_type.to_mime_type().to_string();
872
873                    let part = match data {
874                        DocumentSourceKind::Base64(data) => {
875                            PartKind::InlineData(Blob { data, mime_type })
876                        }
877
878                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
879                            mime_type: Some(mime_type),
880                            file_uri,
881                        }),
882                        DocumentSourceKind::String(_) => {
883                            return Err(message::MessageError::ConversionError(
884                                "Strings cannot be used as audio files!".into(),
885                            ));
886                        }
887                        DocumentSourceKind::Raw(_) => {
888                            return Err(message::MessageError::ConversionError(
889                                "Raw files not supported, encode as base64 first".into(),
890                            ));
891                        }
892                        DocumentSourceKind::Unknown => {
893                            return Err(message::MessageError::ConversionError(
894                                "Content has no body".to_string(),
895                            ));
896                        }
897                    };
898
899                    Ok(Part {
900                        thought: Some(false),
901                        part,
902                        ..Default::default()
903                    })
904                }
905                message::UserContent::Video(message::Video {
906                    data,
907                    media_type,
908                    additional_params,
909                    ..
910                }) => {
911                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
912
913                    let part = match data {
914                        DocumentSourceKind::Url(file_uri) => {
915                            if file_uri.starts_with("https://www.youtube.com") {
916                                PartKind::FileData(FileData {
917                                    mime_type,
918                                    file_uri,
919                                })
920                            } else {
921                                if mime_type.is_none() {
922                                    return Err(MessageError::ConversionError(
923                                        "A mime type is required for non-Youtube video file inputs to Gemini"
924                                            .to_string(),
925                                    ));
926                                }
927
928                                PartKind::FileData(FileData {
929                                    mime_type,
930                                    file_uri,
931                                })
932                            }
933                        }
934                        DocumentSourceKind::Base64(data) => {
935                            let Some(mime_type) = mime_type else {
936                                return Err(MessageError::ConversionError(
937                                    "A media type is expected for base64 encoded strings"
938                                        .to_string(),
939                                ));
940                            };
941                            PartKind::InlineData(Blob { mime_type, data })
942                        }
943                        DocumentSourceKind::String(_) => {
944                            return Err(message::MessageError::ConversionError(
945                                "Strings cannot be used as audio files!".into(),
946                            ));
947                        }
948                        DocumentSourceKind::Raw(_) => {
949                            return Err(message::MessageError::ConversionError(
950                                "Raw file data not supported, encode as base64 first".into(),
951                            ));
952                        }
953                        DocumentSourceKind::Unknown => {
954                            return Err(message::MessageError::ConversionError(
955                                "Media type for video is required for Gemini".to_string(),
956                            ));
957                        }
958                    };
959
960                    Ok(Part {
961                        thought: Some(false),
962                        thought_signature: None,
963                        part,
964                        additional_params,
965                    })
966                }
967            }
968        }
969    }
970
971    impl From<message::AssistantContent> for Part {
972        fn from(content: message::AssistantContent) -> Self {
973            match content {
974                message::AssistantContent::Text(message::Text { text }) => text.into(),
975                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
976                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
977                    Part {
978                        thought: Some(true),
979                        thought_signature: None,
980                        part: PartKind::Text(
981                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
982                        ),
983                        additional_params: None,
984                    }
985                }
986            }
987        }
988    }
989
990    impl From<message::ToolCall> for Part {
991        fn from(tool_call: message::ToolCall) -> Self {
992            Self {
993                thought: Some(false),
994                thought_signature: None,
995                part: PartKind::FunctionCall(FunctionCall {
996                    name: tool_call.function.name,
997                    args: tool_call.function.arguments,
998                }),
999                additional_params: None,
1000            }
1001        }
1002    }
1003
    /// Inline media carried as base64 data (the payload of `PartKind::InlineData`).
    /// Text should not be sent this way; use a text part instead.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data, e.g. `image/png` or `image/jpeg`.
        /// If an unsupported MIME type is provided, the API returns an error.
        pub mime_type: String,
        /// The media bytes, as a base64-encoded string.
        pub data: String,
    }
1015
    /// A function call predicted by the model.
    ///
    /// It holds the name of a declared function (`FunctionDeclaration.name`) and the
    /// arguments to call it with.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. It may contain a-z, A-Z, 0-9,
        /// underscores and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        // NOTE(review): the API marks this optional, but the field is a non-`Option` `Value`
        // with no `#[serde(default)]`, so a response that omits `args` would fail to
        // deserialize. Confirm whether Gemini ever omits it for zero-argument calls.
        pub args: serde_json::Value,
    }
1026
1027    impl From<FunctionCall> for message::ToolCall {
1028        fn from(function_call: FunctionCall) -> Self {
1029            Self {
1030                id: function_call.name.clone(),
1031                call_id: None,
1032                function: message::ToolFunction {
1033                    name: function_call.name,
1034                    arguments: function_call.args,
1035                },
1036            }
1037        }
1038    }
1039
1040    impl From<message::ToolCall> for FunctionCall {
1041        fn from(tool_call: message::ToolCall) -> Self {
1042            Self {
1043                name: tool_call.function.name,
1044                args: tool_call.function.arguments,
1045            }
1046        }
1047    }
1048
    /// The result of a FunctionCall, sent back to the model as context.
    ///
    /// It holds a string naming the FunctionDeclaration and a structured JSON object with
    /// any output from the function. It should contain the result of a FunctionCall made
    /// based on a model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function that was called. It may contain a-z, A-Z, 0-9,
        /// underscores and dashes, with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        pub response: Option<serde_json::Value>,
    }
1060
    /// Media referenced by URI rather than sent inline (the payload of `PartKind::FileData`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. The URI of the file.
        pub file_uri: String,
    }
1070
    /// A safety rating: the probability of harm for one content category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category being rated.
        pub category: HarmCategory,
        /// The probability that the content is harmful in this category.
        pub probability: HarmProbability,
    }
1076
    /// The probability that content is harmful, serialized as e.g. `"NEGLIGIBLE"` or `"HIGH"`.
    // NOTE(review): there is no `#[serde(other)]` catch-all, so any value the API adds
    // beyond these variants will fail deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1086
    /// The category a safety rating or setting applies to.
    ///
    /// Values are serialized in SCREAMING_SNAKE_CASE, e.g. `"HARM_CATEGORY_HATE_SPEECH"`.
    // NOTE(review): there is no `#[serde(other)]` catch-all, so any category the API adds
    // beyond these variants will fail deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1103
    /// Token accounting reported for a Gemini request.
    ///
    /// Fields are read from camelCase keys. The optional counts are omitted when the
    /// struct is re-serialized and they are `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens in cached content, if any (`cachedContentTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens across the generated candidates (`candidatesTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count as reported by the API (`totalTokenCount`).
        pub total_token_count: i32,
        /// Tokens spent on thinking, when reported (`thoughtsTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1116
1117    impl std::fmt::Display for UsageMetadata {
1118        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1119            write!(
1120                f,
1121                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1122                self.prompt_token_count,
1123                match self.cached_content_token_count {
1124                    Some(count) => count.to_string(),
1125                    None => "n/a".to_string(),
1126                },
1127                match self.candidates_token_count {
1128                    Some(count) => count.to_string(),
1129                    None => "n/a".to_string(),
1130                },
1131                self.total_token_count
1132            )
1133        }
1134    }
1135
1136    impl GetTokenUsage for UsageMetadata {
1137        fn token_usage(&self) -> Option<crate::completion::Usage> {
1138            let mut usage = crate::completion::Usage::new();
1139
1140            usage.input_tokens = self.prompt_token_count as u64;
1141            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1142                + self.candidates_token_count.unwrap_or_default()
1143                + self.thoughts_token_count.unwrap_or_default())
1144                as u64;
1145            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1146
1147            Some(usage)
1148        }
1149    }
1150
    /// Feedback metadata about the prompt sent in
    /// [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1160
    /// Reason why a prompt was blocked by the model.
    // NOTE(review): there is no `#[serde(other)]` catch-all, so a block reason the API adds
    // beyond these variants will fail deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
1176
    /// Why the model stopped generating tokens for a candidate.
    // NOTE(review): there is no `#[serde(other)]` catch-all, so a finish reason the API adds
    // beyond these variants will fail deserialization of the enclosing response.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
1203
    /// Source attributions for a piece of generated content.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources (`citationSources`).
        pub citation_sources: Vec<CitationSource>,
    }
1209
    /// A citation to a source for part of a response. Every field is optional and is
    /// omitted when re-serialized if absent.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the response.
        // NOTE(review): presumably a byte offset as in the Gemini reference; confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End (exclusive) of the attributed segment of the response.
        // NOTE(review): exclusivity is taken from the Gemini reference; confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source, when one applies.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1222
1223    #[derive(Clone, Debug, Deserialize, Serialize)]
1224    #[serde(rename_all = "camelCase")]
1225    pub struct LogprobsResult {
1226        pub top_candidate: Vec<TopCandidate>,
1227        pub chosen_candidate: Vec<LogProbCandidate>,
1228    }
1229
    /// The candidate tokens considered at a single decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens for this step.
        // NOTE(review): the Gemini reference says these are sorted by log probability in
        // descending order; nothing here enforces that.
        pub candidates: Vec<LogProbCandidate>,
    }
1234
1235    #[derive(Clone, Debug, Deserialize, Serialize)]
1236    #[serde(rename_all = "camelCase")]
1237    pub struct LogProbCandidate {
1238        pub token: String,
1239        pub token_id: String,
1240        pub log_probability: f64,
1241    }
1242
    /// Gemini API configuration options for model generation and outputs.
    ///
    /// Not all parameters are configurable for every model. From
    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig).
    /// Every field is optional and omitted from the request when `None`.
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation.
        ///
        /// If specified, the API will stop at the first appearance of a stop sequence. The
        /// stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain: (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text.
        ///
        /// Schemas must be a subset of the OpenAPI schema and can be objects, primitives
        /// or arrays. If set, a compatible responseMimeType must also be set. Compatible
        /// MIME types: application/json: Schema for JSON response. Refer to the JSON text
        /// generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set
        /// to 1. If unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate.
        ///
        /// Note: the default value varies by model; see the Model.output_token_limit
        /// attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Values can range from [0.0, 2.0].
        ///
        /// Note: the default value varies by model; see the Model.temperature attribute of
        /// the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling.
        ///
        /// The model uses combined Top-k and Top-p (nucleus) sampling. Tokens are sorted
        /// based on their assigned probabilities so that only the most likely tokens are
        /// considered. Top-k sampling directly limits the maximum number of tokens to
        /// consider, while nucleus sampling limits the number of tokens based on the
        /// cumulative probability.
        ///
        /// Note: the default value varies by Model and is specified by the Model.top_p
        /// attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling.
        ///
        /// Gemini models use Top-p (nucleus) sampling or a combination of Top-k and nucleus
        /// sampling. Top-k sampling considers the set of topK most probable tokens. Models
        /// running with nucleus sampling don't allow topK setting.
        ///
        /// Note: the default value varies by Model and is specified by the Model.top_k
        /// attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on
        /// requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already
        /// been seen in the response.
        ///
        /// This penalty is binary on/off and not dependent on the number of times the token
        /// is used (after the first). Use frequencyPenalty for a penalty that increases with
        /// each use. A positive penalty will discourage the use of tokens that have already
        /// been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response,
        /// decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number
        /// of times each token has been seen in the response so far.
        ///
        /// A positive penalty will discourage the use of tokens that have already been used,
        /// proportional to the number of times the token has been used. The more a token is
        /// used, the more difficult it is for the model to use that token again, increasing
        /// the vocabulary of responses.
        ///
        /// Caution: a negative penalty will encourage the model to reuse tokens proportional
        /// to the number of times the token has been used. Small negative values will reduce
        /// the vocabulary of a response. Larger negative values will cause the model to
        /// repeat a common token until it hits the maxOutputTokens limit:
        /// "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// The number of top logprobs to return at each decoding step in
        /// [Candidate.logprobs_result]. Only valid if responseLogprobs=True.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
    }
1320
1321    impl Default for GenerationConfig {
1322        fn default() -> Self {
1323            Self {
1324                temperature: Some(1.0),
1325                max_output_tokens: Some(4096),
1326                stop_sequences: None,
1327                response_mime_type: None,
1328                response_schema: None,
1329                candidate_count: None,
1330                top_p: None,
1331                top_k: None,
1332                presence_penalty: None,
1333                frequency_penalty: None,
1334                response_logprobs: None,
1335                logprobs: None,
1336                thinking_config: None,
1337            }
1338        }
1339    }
1340
1341    #[derive(Debug, Deserialize, Serialize)]
1342    #[serde(rename_all = "camelCase")]
1343    pub struct ThinkingConfig {
1344        pub thinking_budget: u32,
1345        pub include_thoughts: Option<bool>,
1346    }
    /// The Schema object allows the definition of input and output data types.
    ///
    /// These types can be objects, but also primitives and arrays. Represents a select
    /// subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema).
    ///
    /// All optional fields are omitted from the serialized form when `None`.
    // NOTE(review): there is no `rename_all`, so multi-word fields serialize as snake_case
    // (`max_items`, `min_items`). The Gemini reference spells them `maxItems`/`minItems`;
    // confirm the API accepts the snake_case form before relying on these fields.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// The data type of this schema node (serialized as `type`).
        pub r#type: String,
        /// Optional format hint for the type (e.g. for numbers or strings).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description of the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values, for enum-like strings (serialized as `enum`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Schemas for each named property of an object.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the properties that must be present.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema for the elements of an array.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1372
1373    /// Flattens a JSON schema by resolving all `$ref` references inline.
1374    /// It takes a JSON schema that may contain `$ref` references to definitions
1375    /// in `$defs` or `definitions` sections and returns a new schema with all references
1376    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1377    /// schema references.
1378    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1379        // extracting $defs if they exist
1380        let defs = if let Some(obj) = schema.as_object() {
1381            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1382        } else {
1383            None
1384        };
1385
1386        let Some(defs_value) = defs else {
1387            return Ok(schema);
1388        };
1389
1390        let Some(defs_obj) = defs_value.as_object() else {
1391            return Err(CompletionError::ResponseError(
1392                "$defs must be an object".into(),
1393            ));
1394        };
1395
1396        resolve_refs(&mut schema, defs_obj)?;
1397
1398        // removing $defs from the final schema because we have inlined everything
1399        if let Some(obj) = schema.as_object_mut() {
1400            obj.remove("$defs");
1401            obj.remove("definitions");
1402        }
1403
1404        Ok(schema)
1405    }
1406
1407    /// Recursively resolves all `$ref` references in a JSON value by
1408    /// replacing them with their definitions.
1409    fn resolve_refs(
1410        value: &mut Value,
1411        defs: &serde_json::Map<String, Value>,
1412    ) -> Result<(), CompletionError> {
1413        match value {
1414            Value::Object(obj) => {
1415                if let Some(ref_value) = obj.get("$ref")
1416                    && let Some(ref_str) = ref_value.as_str()
1417                {
1418                    // "#/$defs/Person" -> "Person"
1419                    let def_name = parse_ref_path(ref_str)?;
1420
1421                    let def = defs.get(&def_name).ok_or_else(|| {
1422                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1423                    })?;
1424
1425                    let mut resolved = def.clone();
1426                    resolve_refs(&mut resolved, defs)?;
1427                    *value = resolved;
1428                    return Ok(());
1429                }
1430
1431                for (_, v) in obj.iter_mut() {
1432                    resolve_refs(v, defs)?;
1433                }
1434            }
1435            Value::Array(arr) => {
1436                for item in arr.iter_mut() {
1437                    resolve_refs(item, defs)?;
1438                }
1439            }
1440            _ => {}
1441        }
1442
1443        Ok(())
1444    }
1445
1446    /// Parses a JSON Schema `$ref` path to extract the definition name.
1447    ///
1448    /// JSON Schema references use URI fragment syntax to point to definitions within
1449    /// the same document. This function extracts the definition name from common
1450    /// reference patterns used in JSON Schema.
1451    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1452        if let Some(fragment) = ref_str.strip_prefix('#') {
1453            if let Some(name) = fragment.strip_prefix("/$defs/") {
1454                Ok(name.to_string())
1455            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1456                Ok(name.to_string())
1457            } else {
1458                Err(CompletionError::ResponseError(format!(
1459                    "Unsupported reference format: {}",
1460                    ref_str
1461                )))
1462            }
1463        } else {
1464            Err(CompletionError::ResponseError(format!(
1465                "Only fragment references (#/...) are supported: {}",
1466                ref_str
1467            )))
1468        }
1469    }
1470
1471    /// Helper function to extract the type string from a JSON value.
1472    /// Handles both direct string types and array types (returns the first element).
1473    fn extract_type(type_value: &Value) -> Option<String> {
1474        if type_value.is_string() {
1475            type_value.as_str().map(String::from)
1476        } else if type_value.is_array() {
1477            type_value
1478                .as_array()
1479                .and_then(|arr| arr.first())
1480                .and_then(|v| v.as_str().map(String::from))
1481        } else {
1482            None
1483        }
1484    }
1485
1486    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1487    /// Returns the type of the first non-null schema found.
1488    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1489        composition.as_array().and_then(|arr| {
1490            arr.iter().find_map(|schema| {
1491                if let Some(obj) = schema.as_object() {
1492                    // Skip null types
1493                    if let Some(type_val) = obj.get("type")
1494                        && let Some(type_str) = type_val.as_str()
1495                        && type_str == "null"
1496                    {
1497                        return None;
1498                    }
1499                    // Extract type from this schema
1500                    obj.get("type").and_then(extract_type).or_else(|| {
1501                        if obj.contains_key("properties") {
1502                            Some("object".to_string())
1503                        } else {
1504                            None
1505                        }
1506                    })
1507                } else {
1508                    None
1509                }
1510            })
1511        })
1512    }
1513
1514    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1515    /// Returns the schema object that should be used for properties, required, etc.
1516    fn extract_schema_from_composition(
1517        composition: &Value,
1518    ) -> Option<serde_json::Map<String, Value>> {
1519        composition.as_array().and_then(|arr| {
1520            arr.iter().find_map(|schema| {
1521                if let Some(obj) = schema.as_object()
1522                    && let Some(type_val) = obj.get("type")
1523                    && let Some(type_str) = type_val.as_str()
1524                {
1525                    if type_str == "null" {
1526                        return None;
1527                    }
1528                    Some(obj.clone())
1529                } else {
1530                    None
1531                }
1532            })
1533        })
1534    }
1535
1536    /// Helper function to infer the type of a schema object.
1537    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1538    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1539        // First, try direct type field
1540        if let Some(type_val) = obj.get("type")
1541            && let Some(type_str) = extract_type(type_val)
1542        {
1543            return type_str;
1544        }
1545
1546        // Then try anyOf, oneOf, allOf (in that order)
1547        if let Some(any_of) = obj.get("anyOf")
1548            && let Some(type_str) = extract_type_from_composition(any_of)
1549        {
1550            return type_str;
1551        }
1552
1553        if let Some(one_of) = obj.get("oneOf")
1554            && let Some(type_str) = extract_type_from_composition(one_of)
1555        {
1556            return type_str;
1557        }
1558
1559        if let Some(all_of) = obj.get("allOf")
1560            && let Some(type_str) = extract_type_from_composition(all_of)
1561        {
1562            return type_str;
1563        }
1564
1565        // Finally, infer object type if properties are present
1566        if obj.contains_key("properties") {
1567            "object".to_string()
1568        } else {
1569            String::new()
1570        }
1571    }
1572
1573    impl TryFrom<Value> for Schema {
1574        type Error = CompletionError;
1575
1576        fn try_from(value: Value) -> Result<Self, Self::Error> {
1577            let flattened_val = flatten_schema(value)?;
1578            if let Some(obj) = flattened_val.as_object() {
1579                // Determine which object to use for extracting properties and required fields.
1580                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1581                let props_source = if obj.get("properties").is_none() {
1582                    if let Some(any_of) = obj.get("anyOf") {
1583                        extract_schema_from_composition(any_of)
1584                    } else if let Some(one_of) = obj.get("oneOf") {
1585                        extract_schema_from_composition(one_of)
1586                    } else if let Some(all_of) = obj.get("allOf") {
1587                        extract_schema_from_composition(all_of)
1588                    } else {
1589                        None
1590                    }
1591                    .unwrap_or(obj.clone())
1592                } else {
1593                    obj.clone()
1594                };
1595
1596                Ok(Schema {
1597                    r#type: infer_type(obj),
1598                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1599                    description: obj
1600                        .get("description")
1601                        .and_then(|v| v.as_str())
1602                        .map(String::from),
1603                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1604                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1605                        arr.iter()
1606                            .filter_map(|v| v.as_str().map(String::from))
1607                            .collect()
1608                    }),
1609                    max_items: obj
1610                        .get("maxItems")
1611                        .and_then(|v| v.as_i64())
1612                        .map(|v| v as i32),
1613                    min_items: obj
1614                        .get("minItems")
1615                        .and_then(|v| v.as_i64())
1616                        .map(|v| v as i32),
1617                    properties: props_source
1618                        .get("properties")
1619                        .and_then(|v| v.as_object())
1620                        .map(|map| {
1621                            map.iter()
1622                                .filter_map(|(k, v)| {
1623                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1624                                })
1625                                .collect()
1626                        }),
1627                    required: props_source
1628                        .get("required")
1629                        .and_then(|v| v.as_array())
1630                        .map(|arr| {
1631                            arr.iter()
1632                                .filter_map(|v| v.as_str().map(String::from))
1633                                .collect()
1634                        }),
1635                    items: obj
1636                        .get("items")
1637                        .and_then(|v| v.clone().try_into().ok())
1638                        .map(Box::new),
1639                })
1640            } else {
1641                Err(CompletionError::ResponseError(
1642                    "Expected a JSON object for Schema".into(),
1643                ))
1644            }
1645        }
1646    }
1647
1648    #[derive(Debug, Serialize)]
1649    #[serde(rename_all = "camelCase")]
1650    pub struct GenerateContentRequest {
1651        pub contents: Vec<Content>,
1652        #[serde(skip_serializing_if = "Option::is_none")]
1653        pub tools: Option<Tool>,
1654        pub tool_config: Option<ToolConfig>,
1655        /// Optional. Configuration options for model generation and outputs.
1656        pub generation_config: Option<GenerationConfig>,
1657        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1658        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1659        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1660        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1661        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1662        /// will use the default safety setting for that category. Harm categories:
1663        ///     - HARM_CATEGORY_HATE_SPEECH,
1664        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1665        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1666        ///     - HARM_CATEGORY_HARASSMENT
1667        /// are supported.
1668        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1669        /// to learn how to incorporate safety considerations in your AI applications.
1670        pub safety_settings: Option<Vec<SafetySetting>>,
1671        /// Optional. Developer set system instruction(s). Currently, text only.
1672        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1673        pub system_instruction: Option<Content>,
1674        // cachedContent: Optional<String>
1675        /// Additional parameters.
1676        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1677        pub additional_params: Option<serde_json::Value>,
1678    }
1679
1680    #[derive(Debug, Serialize)]
1681    #[serde(rename_all = "camelCase")]
1682    pub struct Tool {
1683        pub function_declarations: Vec<FunctionDeclaration>,
1684        pub code_execution: Option<CodeExecution>,
1685    }
1686
    /// Declaration of a single function the model may call, serialized with camelCase keys.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name the model uses when it emits a call.
        pub name: String,
        /// Free-text description of the function.
        pub description: String,
        /// Parameter schema; the key is omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1695
    /// Tool configuration, carried in the `toolConfig` field of [`GenerateContentRequest`].
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Serialized as `functionCallingConfig`. When `None` it is written as `null`, since
        /// there is no `skip_serializing_if` here.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1701
    /// Function-calling mode, (de)serialized as an internally tagged object keyed by `mode`:
    /// `{"mode": "AUTO"}`, `{"mode": "NONE"}` or `{"mode": "ANY", ...}`.
    ///
    /// NOTE(review): `rename_all` on an enum renames only the variants, so the `Any` field is
    /// emitted as `allowed_function_names` (snake_case). Confirm the API accepts that spelling.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// `"AUTO"`; the default mode.
        #[default]
        Auto,
        /// `"NONE"`.
        None,
        /// `"ANY"`; `allowed_function_names`, when `Some`, lists the permitted functions and is
        /// omitted from the JSON when `None`.
        Any {
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1713
1714    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1715        type Error = CompletionError;
1716        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1717            let res = match value {
1718                message::ToolChoice::Auto => Self::Auto,
1719                message::ToolChoice::None => Self::None,
1720                message::ToolChoice::Required => Self::Any {
1721                    allowed_function_names: None,
1722                },
1723                message::ToolChoice::Specific { function_names } => Self::Any {
1724                    allowed_function_names: Some(function_names),
1725                },
1726            };
1727
1728            Ok(res)
1729        }
1730    }
1731
    /// Marker for the code-execution tool. It has no fields and serializes as an empty
    /// object `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1734
    /// A blocking threshold for one harm category, sent in the request's `safetySettings`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Harm category this setting applies to.
        pub category: HarmCategory,
        /// Threshold at which content in `category` is blocked.
        pub threshold: HarmBlockThreshold,
    }
1741
    /// Blocking threshold for a [`SafetySetting`]. Variants serialize in SCREAMING_SNAKE_CASE,
    /// e.g. `BlockMediumAndAbove` becomes `"BLOCK_MEDIUM_AND_ABOVE"`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// `"HARM_BLOCK_THRESHOLD_UNSPECIFIED"`.
        HarmBlockThresholdUnspecified,
        /// `"BLOCK_LOW_AND_ABOVE"`.
        BlockLowAndAbove,
        /// `"BLOCK_MEDIUM_AND_ABOVE"`.
        BlockMediumAndAbove,
        /// `"BLOCK_ONLY_HIGH"`.
        BlockOnlyHigh,
        /// `"BLOCK_NONE"`.
        BlockNone,
        /// `"OFF"`.
        Off,
    }
1752}
1753
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A failed conversion must fail the test, not just print a message.
        let schema: Schema = schema_with_ref
            .try_into()
            .unwrap_or_else(|e| panic!("Schema conversion failed: {:?}", e));

        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        if let Some(items) = schema.items {
            assert_eq!(items.r#type, "object");
            assert!(items.properties.is_some());
        } else {
            panic!("Schema should have items field");
        }
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // A missing `items` field must fail the test instead of being skipped silently.
        let items = schema
            .items
            .expect("Flattened array schema should have items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}