rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// Model name of the `gemini-2.5-pro-preview-06-05` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model name of the `gemini-2.5-pro-preview-05-06` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model name of the `gemini-2.5-pro-preview-03-25` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model name of the `gemini-2.5-flash-preview-04-17` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model name of the experimental `gemini-2.5-pro-exp-03-25` completion model.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model name of the `gemini-2.5-flash` completion model.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model name of the `gemini-2.0-flash-lite` completion model.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model name of the `gemini-2.0-flash` completion model.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27    AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32    OneOrMany,
33    completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37    Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::info_span;
42
43use super::Client;
44
45// =================================================================
46// Rig Implementation Types
47// =================================================================
48
49#[derive(Clone, Debug)]
50pub struct CompletionModel<T = reqwest::Client> {
51    pub(crate) client: Client<T>,
52    pub model: String,
53}
54
55impl<T> CompletionModel<T> {
56    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
57        Self {
58            client,
59            model: model.into(),
60        }
61    }
62
63    pub fn with_model(client: Client<T>, model: &str) -> Self {
64        Self {
65            client,
66            model: model.into(),
67        }
68    }
69}
70
71impl<T> completion::CompletionModel for CompletionModel<T>
72where
73    T: HttpClientExt + Clone + 'static,
74{
75    type Response = GenerateContentResponse;
76    type StreamingResponse = StreamingCompletionResponse;
77    type Client = super::Client<T>;
78
79    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
80        Self::new(client.clone(), model)
81    }
82
83    #[cfg_attr(feature = "worker", worker::send)]
84    async fn completion(
85        &self,
86        completion_request: CompletionRequest,
87    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88        let span = if tracing::Span::current().is_disabled() {
89            info_span!(
90                target: "rig::completions",
91                "generate_content",
92                gen_ai.operation.name = "generate_content",
93                gen_ai.provider.name = "gcp.gemini",
94                gen_ai.request.model = self.model,
95                gen_ai.system_instructions = &completion_request.preamble,
96                gen_ai.response.id = tracing::field::Empty,
97                gen_ai.response.model = tracing::field::Empty,
98                gen_ai.usage.output_tokens = tracing::field::Empty,
99                gen_ai.usage.input_tokens = tracing::field::Empty,
100                gen_ai.input.messages = tracing::field::Empty,
101                gen_ai.output.messages = tracing::field::Empty,
102            )
103        } else {
104            tracing::Span::current()
105        };
106
107        let request = create_request_body(completion_request)?;
108        span.record_model_input(&request.contents);
109
110        span.record_model_input(&request.contents);
111
112        tracing::trace!(
113            target: "rig::completions",
114            "Sending completion request to Gemini API {}",
115            serde_json::to_string_pretty(&request)?
116        );
117
118        let body = serde_json::to_vec(&request)?;
119
120        let path = format!("/v1beta/models/{}:generateContent", self.model);
121
122        let request = self
123            .client
124            .post(path.as_str())?
125            .body(body)
126            .map_err(|e| CompletionError::HttpError(e.into()))?;
127
128        let response = self.client.send::<_, Vec<u8>>(request).await?;
129
130        if response.status().is_success() {
131            let response_body = response
132                .into_body()
133                .await
134                .map_err(CompletionError::HttpError)?;
135
136            let response_text = String::from_utf8_lossy(&response_body).to_string();
137            tracing::debug!("Received raw response from Gemini API: {}", response_text);
138
139            let response: GenerateContentResponse = serde_json::from_slice(&response_body)
140                .map_err(|err| {
141                    tracing::error!(
142                        error = %err,
143                        body = %response_text,
144                        "Failed to deserialize Gemini completion response"
145                    );
146                    CompletionError::JsonError(err)
147                })?;
148
149            match response.usage_metadata {
150                Some(ref usage) => tracing::info!(target: "rig",
151                "Gemini completion token usage: {}",
152                usage
153                ),
154                None => tracing::info!(target: "rig",
155                    "Gemini completion token usage: n/a",
156                ),
157            }
158
159            let span = tracing::Span::current();
160            span.record_model_output(&response.candidates);
161            span.record_response_metadata(&response);
162            span.record_token_usage(&response.usage_metadata);
163
164            tracing::trace!(
165                "Received response from Gemini API: {}",
166                serde_json::to_string_pretty(&response)?
167            );
168
169            response.try_into()
170        } else {
171            let text = String::from_utf8_lossy(
172                &response
173                    .into_body()
174                    .await
175                    .map_err(CompletionError::HttpError)?,
176            )
177            .into();
178
179            Err(CompletionError::ProviderError(text))
180        }
181    }
182
183    #[cfg_attr(feature = "worker", worker::send)]
184    async fn stream(
185        &self,
186        request: CompletionRequest,
187    ) -> Result<
188        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
189        CompletionError,
190    > {
191        CompletionModel::stream(self, request).await
192    }
193}
194
195pub(crate) fn create_request_body(
196    completion_request: CompletionRequest,
197) -> Result<GenerateContentRequest, CompletionError> {
198    let mut full_history = Vec::new();
199    full_history.extend(completion_request.chat_history);
200
201    let additional_params = completion_request
202        .additional_params
203        .unwrap_or_else(|| Value::Object(Map::new()));
204
205    let AdditionalParameters {
206        mut generation_config,
207        additional_params,
208    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
209
210    generation_config = generation_config.map(|mut cfg| {
211        if let Some(temp) = completion_request.temperature {
212            cfg.temperature = Some(temp);
213        };
214
215        if let Some(max_tokens) = completion_request.max_tokens {
216            cfg.max_output_tokens = Some(max_tokens);
217        };
218
219        cfg
220    });
221
222    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
223        parts: vec![preamble.into()],
224        role: Some(Role::Model),
225    });
226
227    let tools = if completion_request.tools.is_empty() {
228        None
229    } else {
230        Some(Tool::try_from(completion_request.tools)?)
231    };
232
233    let tool_config = if let Some(cfg) = completion_request.tool_choice {
234        Some(ToolConfig {
235            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
236        })
237    } else {
238        None
239    };
240
241    let request = GenerateContentRequest {
242        contents: full_history
243            .into_iter()
244            .map(|msg| {
245                msg.try_into()
246                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
247            })
248            .collect::<Result<Vec<_>, _>>()?,
249        generation_config,
250        safety_settings: None,
251        tools,
252        tool_config,
253        system_instruction,
254        additional_params,
255    };
256
257    Ok(request)
258}
259
260impl TryFrom<completion::ToolDefinition> for Tool {
261    type Error = CompletionError;
262
263    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
264        let parameters: Option<Schema> =
265            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
266                None
267            } else {
268                Some(tool.parameters.try_into()?)
269            };
270
271        Ok(Self {
272            function_declarations: vec![FunctionDeclaration {
273                name: tool.name,
274                description: tool.description,
275                parameters,
276            }],
277            code_execution: None,
278        })
279    }
280}
281
282impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
283    type Error = CompletionError;
284
285    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
286        let mut function_declarations = Vec::new();
287
288        for tool in tools {
289            let parameters =
290                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
291                    None
292                } else {
293                    match tool.parameters.try_into() {
294                        Ok(schema) => Some(schema),
295                        Err(e) => {
296                            let emsg = format!(
297                                "Tool '{}' could not be converted to a schema: {:?}",
298                                tool.name, e,
299                            );
300                            return Err(CompletionError::ProviderError(emsg));
301                        }
302                    }
303                };
304
305            function_declarations.push(FunctionDeclaration {
306                name: tool.name,
307                description: tool.description,
308                parameters,
309            });
310        }
311
312        Ok(Self {
313            function_declarations,
314            code_execution: None,
315        })
316    }
317}
318
319impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
320    type Error = CompletionError;
321
322    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
323        let candidate = response.candidates.first().ok_or_else(|| {
324            CompletionError::ResponseError("No response candidates in response".into())
325        })?;
326
327        let content = candidate
328            .content
329            .as_ref()
330            .ok_or_else(|| {
331                let reason = candidate
332                    .finish_reason
333                    .as_ref()
334                    .map(|r| format!("finish_reason={r:?}"))
335                    .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
336                let message = candidate
337                    .finish_message
338                    .as_deref()
339                    .unwrap_or("no finish message provided");
340                CompletionError::ResponseError(format!(
341                    "Gemini candidate missing content ({reason}, finish_message={message})"
342                ))
343            })?
344            .parts
345            .iter()
346            .map(|Part { thought, part, .. }| {
347                Ok(match part {
348                    PartKind::Text(text) => {
349                        if let Some(thought) = thought
350                            && *thought
351                        {
352                            completion::AssistantContent::Reasoning(Reasoning::new(text))
353                        } else {
354                            completion::AssistantContent::text(text)
355                        }
356                    }
357                    PartKind::InlineData(inline_data) => {
358                        let mime_type = message::MediaType::from_mime_type(&inline_data.mime_type);
359
360                        match mime_type {
361                            Some(message::MediaType::Image(media_type)) => {
362                                message::AssistantContent::image_base64(
363                                    &inline_data.data,
364                                    Some(media_type),
365                                    Some(message::ImageDetail::default()),
366                                )
367                            }
368                            _ => {
369                                return Err(CompletionError::ResponseError(format!(
370                                    "Unsupported media type {mime_type:?}"
371                                )));
372                            }
373                        }
374                    }
375                    PartKind::FunctionCall(function_call) => {
376                        completion::AssistantContent::tool_call(
377                            &function_call.name,
378                            &function_call.name,
379                            function_call.args.clone(),
380                        )
381                    }
382                    _ => {
383                        return Err(CompletionError::ResponseError(
384                            "Response did not contain a message or tool call".into(),
385                        ));
386                    }
387                })
388            })
389            .collect::<Result<Vec<_>, _>>()?;
390
391        let choice = OneOrMany::many(content).map_err(|_| {
392            CompletionError::ResponseError(
393                "Response contained no message or tool call (empty)".to_owned(),
394            )
395        })?;
396
397        let usage = response
398            .usage_metadata
399            .as_ref()
400            .map(|usage| completion::Usage {
401                input_tokens: usage.prompt_token_count as u64,
402                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
403                total_tokens: usage.total_token_count as u64,
404            })
405            .unwrap_or_default();
406
407        Ok(completion::CompletionResponse {
408            choice,
409            usage,
410            raw_response: response,
411        })
412    }
413}
414
415pub mod gemini_api_types {
416    use crate::telemetry::ProviderResponseExt;
417    use std::{collections::HashMap, convert::Infallible, str::FromStr};
418
419    // =================================================================
420    // Gemini API Types
421    // =================================================================
422    use serde::{Deserialize, Serialize};
423    use serde_json::{Value, json};
424
425    use crate::completion::GetTokenUsage;
426    use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
427    use crate::{
428        OneOrMany,
429        completion::CompletionError,
430        message::{self, Reasoning, Text},
431        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
432    };
433
434    #[derive(Debug, Deserialize, Serialize, Default)]
435    #[serde(rename_all = "camelCase")]
436    pub struct AdditionalParameters {
437        /// Change your Gemini request configuration.
438        pub generation_config: Option<GenerationConfig>,
439        /// Any additional parameters that you want.
440        #[serde(flatten, skip_serializing_if = "Option::is_none")]
441        pub additional_params: Option<serde_json::Value>,
442    }
443
444    impl AdditionalParameters {
445        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
446            self.generation_config = Some(cfg);
447            self
448        }
449
450        pub fn with_params(mut self, params: serde_json::Value) -> Self {
451            self.additional_params = Some(params);
452            self
453        }
454    }
455
456    /// Response from the model supporting multiple candidate responses.
457    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
458    /// and for each candidate in finishReason and in safetyRatings.
459    /// The API:
460    ///     - Returns either all requested candidates or none of them
461    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
462    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
463    #[derive(Debug, Deserialize, Serialize)]
464    #[serde(rename_all = "camelCase")]
465    pub struct GenerateContentResponse {
466        pub response_id: String,
467        /// Candidate responses from the model.
468        pub candidates: Vec<ContentCandidate>,
469        /// Returns the prompt's feedback related to the content filters.
470        pub prompt_feedback: Option<PromptFeedback>,
471        /// Output only. Metadata on the generation requests' token usage.
472        pub usage_metadata: Option<UsageMetadata>,
473        pub model_version: Option<String>,
474    }
475
476    impl ProviderResponseExt for GenerateContentResponse {
477        type OutputMessage = ContentCandidate;
478        type Usage = UsageMetadata;
479
480        fn get_response_id(&self) -> Option<String> {
481            Some(self.response_id.clone())
482        }
483
484        fn get_response_model_name(&self) -> Option<String> {
485            None
486        }
487
488        fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
489            self.candidates.clone()
490        }
491
492        fn get_text_response(&self) -> Option<String> {
493            let str = self
494                .candidates
495                .iter()
496                .filter_map(|x| {
497                    let content = x.content.as_ref()?;
498                    if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
499                        return None;
500                    }
501
502                    let res = content
503                        .parts
504                        .iter()
505                        .filter_map(|part| {
506                            if let PartKind::Text(ref str) = part.part {
507                                Some(str.to_owned())
508                            } else {
509                                None
510                            }
511                        })
512                        .collect::<Vec<String>>()
513                        .join("\n");
514
515                    Some(res)
516                })
517                .collect::<Vec<String>>()
518                .join("\n");
519
520            if str.is_empty() { None } else { Some(str) }
521        }
522
523        fn get_usage(&self) -> Option<Self::Usage> {
524            self.usage_metadata.clone()
525        }
526    }
527
528    /// A response candidate generated from the model.
529    #[derive(Clone, Debug, Deserialize, Serialize)]
530    #[serde(rename_all = "camelCase")]
531    pub struct ContentCandidate {
532        /// Output only. Generated content returned from the model.
533        #[serde(skip_serializing_if = "Option::is_none")]
534        pub content: Option<Content>,
535        /// Optional. Output only. The reason why the model stopped generating tokens.
536        /// If empty, the model has not stopped generating tokens.
537        pub finish_reason: Option<FinishReason>,
538        /// List of ratings for the safety of a response candidate.
539        /// There is at most one rating per category.
540        pub safety_ratings: Option<Vec<SafetyRating>>,
541        /// Output only. Citation information for model-generated candidate.
542        /// This field may be populated with recitation information for any text included in the content.
543        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
544        pub citation_metadata: Option<CitationMetadata>,
545        /// Output only. Token count for this candidate.
546        pub token_count: Option<i32>,
547        /// Output only.
548        pub avg_logprobs: Option<f64>,
549        /// Output only. Log-likelihood scores for the response tokens and top tokens
550        pub logprobs_result: Option<LogprobsResult>,
551        /// Output only. Index of the candidate in the list of response candidates.
552        pub index: Option<i32>,
553        /// Output only. Additional information about why the model stopped generating tokens.
554        pub finish_message: Option<String>,
555    }
556
557    #[derive(Clone, Debug, Deserialize, Serialize)]
558    pub struct Content {
559        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
560        #[serde(default)]
561        pub parts: Vec<Part>,
562        /// The producer of the content. Must be either 'user' or 'model'.
563        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
564        pub role: Option<Role>,
565    }
566
567    impl TryFrom<message::Message> for Content {
568        type Error = message::MessageError;
569
570        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
571            Ok(match msg {
572                message::Message::User { content } => Content {
573                    parts: content
574                        .into_iter()
575                        .map(|c| c.try_into())
576                        .collect::<Result<Vec<_>, _>>()?,
577                    role: Some(Role::User),
578                },
579                message::Message::Assistant { content, .. } => Content {
580                    role: Some(Role::Model),
581                    parts: content
582                        .into_iter()
583                        .map(|content| content.try_into())
584                        .collect::<Result<Vec<_>, _>>()?,
585                },
586            })
587        }
588    }
589
590    impl TryFrom<Content> for message::Message {
591        type Error = message::MessageError;
592
593        fn try_from(content: Content) -> Result<Self, Self::Error> {
594            match content.role {
595                Some(Role::User) | None => {
596                    Ok(message::Message::User {
597                        content: {
598                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
599                            .map(|Part { part, .. }| {
600                                Ok(match part {
601                                    PartKind::Text(text) => message::UserContent::text(text),
602                                    PartKind::InlineData(inline_data) => {
603                                        let mime_type =
604                                            message::MediaType::from_mime_type(&inline_data.mime_type);
605
606                                        match mime_type {
607                                            Some(message::MediaType::Image(media_type)) => {
608                                                message::UserContent::image_base64(
609                                                    inline_data.data,
610                                                    Some(media_type),
611                                                    Some(message::ImageDetail::default()),
612                                                )
613                                            }
614                                            Some(message::MediaType::Document(media_type)) => {
615                                                message::UserContent::document(
616                                                    inline_data.data,
617                                                    Some(media_type),
618                                                )
619                                            }
620                                            Some(message::MediaType::Audio(media_type)) => {
621                                                message::UserContent::audio(
622                                                    inline_data.data,
623                                                    Some(media_type),
624                                                )
625                                            }
626                                            _ => {
627                                                return Err(message::MessageError::ConversionError(
628                                                    format!("Unsupported media type {mime_type:?}"),
629                                                ));
630                                            }
631                                        }
632                                    }
633                                    _ => {
634                                        return Err(message::MessageError::ConversionError(format!(
635                                            "Unsupported gemini content part type: {part:?}"
636                                        )));
637                                    }
638                                })
639                            })
640                            .collect();
641                            OneOrMany::many(user_content?).map_err(|_| {
642                                message::MessageError::ConversionError(
643                                    "Failed to create OneOrMany from user content".to_string(),
644                                )
645                            })?
646                        },
647                    })
648                }
649                Some(Role::Model) => Ok(message::Message::Assistant {
650                    id: None,
651                    content: {
652                        let assistant_content: Result<Vec<_>, _> = content
653                            .parts
654                            .into_iter()
655                            .map(|Part { thought, part, .. }| {
656                                Ok(match part {
657                                    PartKind::Text(text) => match thought {
658                                        Some(true) => message::AssistantContent::Reasoning(
659                                            Reasoning::new(&text),
660                                        ),
661                                        _ => message::AssistantContent::Text(Text { text }),
662                                    },
663
664                                    PartKind::FunctionCall(function_call) => {
665                                        message::AssistantContent::ToolCall(function_call.into())
666                                    }
667                                    _ => {
668                                        return Err(message::MessageError::ConversionError(
669                                            format!("Unsupported part type: {part:?}"),
670                                        ));
671                                    }
672                                })
673                            })
674                            .collect();
675                        OneOrMany::many(assistant_content?).map_err(|_| {
676                            message::MessageError::ConversionError(
677                                "Failed to create OneOrMany from assistant content".to_string(),
678                            )
679                        })?
680                    },
681                }),
682            }
683        }
684    }
685
686    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
687    #[serde(rename_all = "lowercase")]
688    pub enum Role {
689        User,
690        Model,
691    }
692
693    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
694    #[serde(rename_all = "camelCase")]
695    pub struct Part {
696        /// whether or not the part is a reasoning/thinking text or not
697        #[serde(skip_serializing_if = "Option::is_none")]
698        pub thought: Option<bool>,
699        /// an opaque sig for the thought so it can be reused - is a base64 string
700        #[serde(skip_serializing_if = "Option::is_none")]
701        pub thought_signature: Option<String>,
702        #[serde(flatten)]
703        pub part: PartKind,
704        #[serde(flatten, skip_serializing_if = "Option::is_none")]
705        pub additional_params: Option<Value>,
706    }
707
708    /// A datatype containing media that is part of a multi-part [Content] message.
709    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
710    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
711    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
712    #[serde(rename_all = "camelCase")]
713    pub enum PartKind {
714        Text(String),
715        InlineData(Blob),
716        FunctionCall(FunctionCall),
717        FunctionResponse(FunctionResponse),
718        FileData(FileData),
719        ExecutableCode(ExecutableCode),
720        CodeExecutionResult(CodeExecutionResult),
721    }
722
723    // This default instance is primarily so we can easily fill in the optional fields of `Part`
724    // So this instance for `PartKind` (and the allocation it would cause) should be optimized away
725    impl Default for PartKind {
726        fn default() -> Self {
727            Self::Text(String::new())
728        }
729    }
730
731    impl From<String> for Part {
732        fn from(text: String) -> Self {
733            Self {
734                thought: Some(false),
735                thought_signature: None,
736                part: PartKind::Text(text),
737                additional_params: None,
738            }
739        }
740    }
741
742    impl From<&str> for Part {
743        fn from(text: &str) -> Self {
744            Self::from(text.to_string())
745        }
746    }
747
748    impl FromStr for Part {
749        type Err = Infallible;
750
751        fn from_str(s: &str) -> Result<Self, Self::Err> {
752            Ok(s.into())
753        }
754    }
755
756    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
757        type Error = message::MessageError;
758        fn try_from(
759            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
760        ) -> Result<Self, Self::Error> {
761            let mime_type = mime_type.to_mime_type().to_string();
762            let part = match doc_src {
763                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
764                    mime_type: Some(mime_type),
765                    file_uri: url,
766                }),
767                DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
768                    PartKind::InlineData(Blob { mime_type, data })
769                }
770                DocumentSourceKind::Raw(_) => {
771                    return Err(message::MessageError::ConversionError(
772                        "Raw files not supported, encode as base64 first".into(),
773                    ));
774                }
775                DocumentSourceKind::Unknown => {
776                    return Err(message::MessageError::ConversionError(
777                        "Can't convert an unknown document source".to_string(),
778                    ));
779                }
780            };
781
782            Ok(part)
783        }
784    }
785
786    impl TryFrom<message::UserContent> for Part {
787        type Error = message::MessageError;
788
789        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
790            match content {
791                message::UserContent::Text(message::Text { text }) => Ok(Part {
792                    thought: Some(false),
793                    thought_signature: None,
794                    part: PartKind::Text(text),
795                    additional_params: None,
796                }),
797                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
798                    let content = match content.first() {
799                        message::ToolResultContent::Text(text) => text.text,
800                        message::ToolResultContent::Image(_) => {
801                            return Err(message::MessageError::ConversionError(
802                                "Tool result content must be text".to_string(),
803                            ));
804                        }
805                    };
806                    // Convert to JSON since this value may be a valid JSON value
807                    let result: serde_json::Value =
808                        serde_json::from_str(&content).unwrap_or_else(|error| {
809                            tracing::trace!(
810                                ?error,
811                                "Tool result is not a valid JSON, treat it as normal string"
812                            );
813                            json!(content)
814                        });
815                    Ok(Part {
816                        thought: Some(false),
817                        thought_signature: None,
818                        part: PartKind::FunctionResponse(FunctionResponse {
819                            name: id,
820                            response: Some(json!({ "result": result })),
821                        }),
822                        additional_params: None,
823                    })
824                }
825                message::UserContent::Image(message::Image {
826                    data, media_type, ..
827                }) => match media_type {
828                    Some(media_type) => match media_type {
829                        message::ImageMediaType::JPEG
830                        | message::ImageMediaType::PNG
831                        | message::ImageMediaType::WEBP
832                        | message::ImageMediaType::HEIC
833                        | message::ImageMediaType::HEIF => {
834                            let part = PartKind::try_from((media_type, data))?;
835                            Ok(Part {
836                                thought: Some(false),
837                                thought_signature: None,
838                                part,
839                                additional_params: None,
840                            })
841                        }
842                        _ => Err(message::MessageError::ConversionError(format!(
843                            "Unsupported image media type {media_type:?}"
844                        ))),
845                    },
846                    None => Err(message::MessageError::ConversionError(
847                        "Media type for image is required for Gemini".to_string(),
848                    )),
849                },
850                message::UserContent::Document(message::Document {
851                    data, media_type, ..
852                }) => {
853                    let Some(media_type) = media_type else {
854                        return Err(MessageError::ConversionError(
855                            "A mime type is required for document inputs to Gemini".to_string(),
856                        ));
857                    };
858
859                    if !media_type.is_code() {
860                        let mime_type = media_type.to_mime_type().to_string();
861
862                        let part = match data {
863                            DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
864                                mime_type: Some(mime_type),
865                                file_uri,
866                            }),
867                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
868                                PartKind::InlineData(Blob { mime_type, data })
869                            }
870                            DocumentSourceKind::Raw(_) => {
871                                return Err(message::MessageError::ConversionError(
872                                    "Raw files not supported, encode as base64 first".into(),
873                                ));
874                            }
875                            _ => {
876                                return Err(message::MessageError::ConversionError(
877                                    "Document has no body".to_string(),
878                                ));
879                            }
880                        };
881
882                        Ok(Part {
883                            thought: Some(false),
884                            part,
885                            ..Default::default()
886                        })
887                    } else {
888                        Err(message::MessageError::ConversionError(format!(
889                            "Unsupported document media type {media_type:?}"
890                        )))
891                    }
892                }
893
894                message::UserContent::Audio(message::Audio {
895                    data, media_type, ..
896                }) => {
897                    let Some(media_type) = media_type else {
898                        return Err(MessageError::ConversionError(
899                            "A mime type is required for audio inputs to Gemini".to_string(),
900                        ));
901                    };
902
903                    let mime_type = media_type.to_mime_type().to_string();
904
905                    let part = match data {
906                        DocumentSourceKind::Base64(data) => {
907                            PartKind::InlineData(Blob { data, mime_type })
908                        }
909
910                        DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
911                            mime_type: Some(mime_type),
912                            file_uri,
913                        }),
914                        DocumentSourceKind::String(_) => {
915                            return Err(message::MessageError::ConversionError(
916                                "Strings cannot be used as audio files!".into(),
917                            ));
918                        }
919                        DocumentSourceKind::Raw(_) => {
920                            return Err(message::MessageError::ConversionError(
921                                "Raw files not supported, encode as base64 first".into(),
922                            ));
923                        }
924                        DocumentSourceKind::Unknown => {
925                            return Err(message::MessageError::ConversionError(
926                                "Content has no body".to_string(),
927                            ));
928                        }
929                    };
930
931                    Ok(Part {
932                        thought: Some(false),
933                        part,
934                        ..Default::default()
935                    })
936                }
937                message::UserContent::Video(message::Video {
938                    data,
939                    media_type,
940                    additional_params,
941                    ..
942                }) => {
943                    let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
944
945                    let part = match data {
946                        DocumentSourceKind::Url(file_uri) => {
947                            if file_uri.starts_with("https://www.youtube.com") {
948                                PartKind::FileData(FileData {
949                                    mime_type,
950                                    file_uri,
951                                })
952                            } else {
953                                if mime_type.is_none() {
954                                    return Err(MessageError::ConversionError(
955                                        "A mime type is required for non-Youtube video file inputs to Gemini"
956                                            .to_string(),
957                                    ));
958                                }
959
960                                PartKind::FileData(FileData {
961                                    mime_type,
962                                    file_uri,
963                                })
964                            }
965                        }
966                        DocumentSourceKind::Base64(data) => {
967                            let Some(mime_type) = mime_type else {
968                                return Err(MessageError::ConversionError(
969                                    "A media type is expected for base64 encoded strings"
970                                        .to_string(),
971                                ));
972                            };
973                            PartKind::InlineData(Blob { mime_type, data })
974                        }
975                        DocumentSourceKind::String(_) => {
976                            return Err(message::MessageError::ConversionError(
977                                "Strings cannot be used as audio files!".into(),
978                            ));
979                        }
980                        DocumentSourceKind::Raw(_) => {
981                            return Err(message::MessageError::ConversionError(
982                                "Raw file data not supported, encode as base64 first".into(),
983                            ));
984                        }
985                        DocumentSourceKind::Unknown => {
986                            return Err(message::MessageError::ConversionError(
987                                "Media type for video is required for Gemini".to_string(),
988                            ));
989                        }
990                    };
991
992                    Ok(Part {
993                        thought: Some(false),
994                        thought_signature: None,
995                        part,
996                        additional_params,
997                    })
998                }
999            }
1000        }
1001    }
1002
1003    impl TryFrom<message::AssistantContent> for Part {
1004        type Error = message::MessageError;
1005
1006        fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1007            match content {
1008                message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1009                message::AssistantContent::Image(message::Image {
1010                    data, media_type, ..
1011                }) => match media_type {
1012                    Some(media_type) => match media_type {
1013                        message::ImageMediaType::JPEG
1014                        | message::ImageMediaType::PNG
1015                        | message::ImageMediaType::WEBP
1016                        | message::ImageMediaType::HEIC
1017                        | message::ImageMediaType::HEIF => {
1018                            let part = PartKind::try_from((media_type, data))?;
1019                            Ok(Part {
1020                                thought: Some(false),
1021                                thought_signature: None,
1022                                part,
1023                                additional_params: None,
1024                            })
1025                        }
1026                        _ => Err(message::MessageError::ConversionError(format!(
1027                            "Unsupported image media type {media_type:?}"
1028                        ))),
1029                    },
1030                    None => Err(message::MessageError::ConversionError(
1031                        "Media type for image is required for Gemini".to_string(),
1032                    )),
1033                },
1034                message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1035                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
1036                    Ok(Part {
1037                        thought: Some(true),
1038                        thought_signature: None,
1039                        part: PartKind::Text(
1040                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
1041                        ),
1042                        additional_params: None,
1043                    })
1044                }
1045            }
1046        }
1047    }
1048
1049    impl From<message::ToolCall> for Part {
1050        fn from(tool_call: message::ToolCall) -> Self {
1051            Self {
1052                thought: Some(false),
1053                thought_signature: None,
1054                part: PartKind::FunctionCall(FunctionCall {
1055                    name: tool_call.function.name,
1056                    args: tool_call.function.arguments,
1057                }),
1058                additional_params: None,
1059            }
1060        }
1061    }
1062
    /// Raw media bytes.
    /// Text should not be sent as raw bytes, use the 'text' field.
    ///
    /// Serialized with camelCase keys (`mimeType`, `data`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }
1074
    /// A predicted FunctionCall returned from the model that contains a string representing the
    /// FunctionDeclaration.name with the arguments and their values.
    ///
    /// Converts to and from `message::ToolCall` through the `From` impls in this module.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        pub args: serde_json::Value,
    }
1085
1086    impl From<FunctionCall> for message::ToolCall {
1087        fn from(function_call: FunctionCall) -> Self {
1088            Self {
1089                id: function_call.name.clone(),
1090                call_id: None,
1091                function: message::ToolFunction {
1092                    name: function_call.name,
1093                    arguments: function_call.args,
1094                },
1095            }
1096        }
1097    }
1098
1099    impl From<message::ToolCall> for FunctionCall {
1100        fn from(tool_call: message::ToolCall) -> Self {
1101            Self {
1102                name: tool_call.function.name,
1103                args: tool_call.function.arguments,
1104            }
1105        }
1106    }
1107
    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        pub response: Option<serde_json::Value>,
    }
1119
    /// URI based data.
    ///
    /// Serialized with camelCase keys (`mimeType`, `fileUri`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }
1129
    /// A safety rating: a harm category paired with a harm probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category this rating applies to.
        pub category: HarmCategory,
        /// The harm probability reported for that category.
        pub probability: HarmProbability,
    }
1135
    /// Harm probability levels used in a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `HARM_PROBABILITY_UNSPECIFIED`
    /// and `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1145
    /// Harm categories used in a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `HARM_CATEGORY_HATE_SPEECH`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1162
1163    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
1164    #[serde(rename_all = "camelCase")]
1165    pub struct UsageMetadata {
1166        pub prompt_token_count: i32,
1167        #[serde(skip_serializing_if = "Option::is_none")]
1168        pub cached_content_token_count: Option<i32>,
1169        #[serde(skip_serializing_if = "Option::is_none")]
1170        pub candidates_token_count: Option<i32>,
1171        pub total_token_count: i32,
1172        #[serde(skip_serializing_if = "Option::is_none")]
1173        pub thoughts_token_count: Option<i32>,
1174    }
1175
1176    impl std::fmt::Display for UsageMetadata {
1177        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1178            write!(
1179                f,
1180                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1181                self.prompt_token_count,
1182                match self.cached_content_token_count {
1183                    Some(count) => count.to_string(),
1184                    None => "n/a".to_string(),
1185                },
1186                match self.candidates_token_count {
1187                    Some(count) => count.to_string(),
1188                    None => "n/a".to_string(),
1189                },
1190                self.total_token_count
1191            )
1192        }
1193    }
1194
1195    impl GetTokenUsage for UsageMetadata {
1196        fn token_usage(&self) -> Option<crate::completion::Usage> {
1197            let mut usage = crate::completion::Usage::new();
1198
1199            usage.input_tokens = self.prompt_token_count as u64;
1200            usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1201                + self.candidates_token_count.unwrap_or_default()
1202                + self.thoughts_token_count.unwrap_or_default())
1203                as u64;
1204            usage.total_tokens = usage.input_tokens + usage.output_tokens;
1205
1206            Some(usage)
1207        }
1208    }
1209
    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    ///
    /// Serialized with camelCase keys (`blockReason`, `safetyRatings`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1219
    /// Reason why a prompt was blocked by the model
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `PROHIBITED_CONTENT`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
1235
    /// Reason why token generation stopped for a candidate.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `MAX_TOKENS`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
1262
    /// A collection of citation sources, serialized under the `citationSources` key.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The individual citation sources.
        pub citation_sources: Vec<CitationSource>,
    }
1268
    /// A single citation source. Every field is optional and is left out of the serialized
    /// form when it is `None`. Keys are camelCase (`startIndex`, `endIndex`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the source, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start index of the cited span, if any.
        // NOTE(review): presumably an offset into the response text; confirm the unit
        // (bytes vs. characters) against the Gemini API reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End index of the cited span, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1281
1282    #[derive(Clone, Debug, Deserialize, Serialize)]
1283    #[serde(rename_all = "camelCase")]
1284    pub struct LogprobsResult {
1285        pub top_candidate: Vec<TopCandidate>,
1286        pub chosen_candidate: Vec<LogProbCandidate>,
1287    }
1288
    /// A list of log-probability candidates, as held in `LogprobsResult::top_candidate`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// The candidates in this list.
        // NOTE(review): presumably sorted by log probability in descending order; confirm
        // against the Gemini API reference.
        pub candidates: Vec<LogProbCandidate>,
    }
1293
1294    #[derive(Clone, Debug, Deserialize, Serialize)]
1295    #[serde(rename_all = "camelCase")]
1296    pub struct LogProbCandidate {
1297        pub token: String,
1298        pub token_id: String,
1299        pub log_probability: f64,
1300    }
1301
    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    ///
    /// Every field is optional and is left out of the serialized request when it is `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain:  (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Optional. The output schema of the generated response.
        /// This is an alternative to responseSchema that accepts a standard JSON Schema.
        /// If this is set, responseSchema must be omitted.
        /// Compatible MIME type: application/json.
        /// Supported properties: $id, $defs, $ref, type, properties, etc.
        ///
        /// Sent on the wire under the key `_responseJsonSchema`.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Internal or alternative representation for `response_json_schema`.
        ///
        /// Sent on the wire under the key `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function. An empty
        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: The more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to start repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// [Candidate.logprobs_result].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Optional image settings (aspect ratio and image size, see `ImageConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1394
    /// Default generation settings: a temperature of `1.0`, an output cap of 4096 tokens,
    /// and every other option set to `None`.
    impl Default for GenerationConfig {
        fn default() -> Self {
            Self {
                // The only two options that receive concrete values by default.
                temperature: Some(1.0),
                max_output_tokens: Some(4096),
                // Everything below is left unset.
                stop_sequences: None,
                response_mime_type: None,
                response_schema: None,
                _response_json_schema: None,
                response_json_schema: None,
                candidate_count: None,
                top_p: None,
                top_k: None,
                presence_penalty: None,
                frequency_penalty: None,
                response_logprobs: None,
                logprobs: None,
                thinking_config: None,
                image_config: None,
            }
        }
    }
1417
    /// Thinking/reasoning settings carried by the `thinking_config` field of the generation
    /// config. Field names are (de)serialized in camelCase.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Wire name `thinkingBudget`.
        /// NOTE(review): presumably the token budget granted to thinking — confirm against the Gemini API docs.
        pub thinking_budget: u32,
        /// Wire name `includeThoughts`. There is no `skip_serializing_if` here, so a `None`
        /// value is written out as `null` rather than omitted.
        pub include_thoughts: Option<bool>,
    }
1424
    /// Image output settings carried by the `image_config` field of the generation config.
    /// Field names are (de)serialized in camelCase and unset fields are omitted on serialization.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Wire name `aspectRatio`.
        /// NOTE(review): the accepted value format is not visible here — confirm against the Gemini API docs.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Wire name `imageSize`.
        /// NOTE(review): the accepted value format is not visible here — confirm against the Gemini API docs.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1433
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// Optional fields are omitted from the serialized output when they are `None`.
    ///
    /// NOTE(review): this struct has no `rename_all`, so `max_items` / `min_items` serialize in
    /// snake_case, while the `TryFrom<Value>` conversion reads the camelCase keys `maxItems` /
    /// `minItems` — confirm the API accepts both spellings.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Data type name. The `TryFrom<Value>` conversion fills this with an empty string
        /// when no type can be determined.
        pub r#type: String,
        /// Optional format qualifier for the type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Optional human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values, as strings.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Upper bound on the number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Lower bound on the number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Sub-schemas of an object, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of the array elements; boxed because the type is recursive.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1459
1460    /// Flattens a JSON schema by resolving all `$ref` references inline.
1461    /// It takes a JSON schema that may contain `$ref` references to definitions
1462    /// in `$defs` or `definitions` sections and returns a new schema with all references
1463    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1464    /// schema references.
1465    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1466        // extracting $defs if they exist
1467        let defs = if let Some(obj) = schema.as_object() {
1468            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1469        } else {
1470            None
1471        };
1472
1473        let Some(defs_value) = defs else {
1474            return Ok(schema);
1475        };
1476
1477        let Some(defs_obj) = defs_value.as_object() else {
1478            return Err(CompletionError::ResponseError(
1479                "$defs must be an object".into(),
1480            ));
1481        };
1482
1483        resolve_refs(&mut schema, defs_obj)?;
1484
1485        // removing $defs from the final schema because we have inlined everything
1486        if let Some(obj) = schema.as_object_mut() {
1487            obj.remove("$defs");
1488            obj.remove("definitions");
1489        }
1490
1491        Ok(schema)
1492    }
1493
1494    /// Recursively resolves all `$ref` references in a JSON value by
1495    /// replacing them with their definitions.
1496    fn resolve_refs(
1497        value: &mut Value,
1498        defs: &serde_json::Map<String, Value>,
1499    ) -> Result<(), CompletionError> {
1500        match value {
1501            Value::Object(obj) => {
1502                if let Some(ref_value) = obj.get("$ref")
1503                    && let Some(ref_str) = ref_value.as_str()
1504                {
1505                    // "#/$defs/Person" -> "Person"
1506                    let def_name = parse_ref_path(ref_str)?;
1507
1508                    let def = defs.get(&def_name).ok_or_else(|| {
1509                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1510                    })?;
1511
1512                    let mut resolved = def.clone();
1513                    resolve_refs(&mut resolved, defs)?;
1514                    *value = resolved;
1515                    return Ok(());
1516                }
1517
1518                for (_, v) in obj.iter_mut() {
1519                    resolve_refs(v, defs)?;
1520                }
1521            }
1522            Value::Array(arr) => {
1523                for item in arr.iter_mut() {
1524                    resolve_refs(item, defs)?;
1525                }
1526            }
1527            _ => {}
1528        }
1529
1530        Ok(())
1531    }
1532
1533    /// Parses a JSON Schema `$ref` path to extract the definition name.
1534    ///
1535    /// JSON Schema references use URI fragment syntax to point to definitions within
1536    /// the same document. This function extracts the definition name from common
1537    /// reference patterns used in JSON Schema.
1538    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1539        if let Some(fragment) = ref_str.strip_prefix('#') {
1540            if let Some(name) = fragment.strip_prefix("/$defs/") {
1541                Ok(name.to_string())
1542            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1543                Ok(name.to_string())
1544            } else {
1545                Err(CompletionError::ResponseError(format!(
1546                    "Unsupported reference format: {}",
1547                    ref_str
1548                )))
1549            }
1550        } else {
1551            Err(CompletionError::ResponseError(format!(
1552                "Only fragment references (#/...) are supported: {}",
1553                ref_str
1554            )))
1555        }
1556    }
1557
1558    /// Helper function to extract the type string from a JSON value.
1559    /// Handles both direct string types and array types (returns the first element).
1560    fn extract_type(type_value: &Value) -> Option<String> {
1561        if type_value.is_string() {
1562            type_value.as_str().map(String::from)
1563        } else if type_value.is_array() {
1564            type_value
1565                .as_array()
1566                .and_then(|arr| arr.first())
1567                .and_then(|v| v.as_str().map(String::from))
1568        } else {
1569            None
1570        }
1571    }
1572
1573    /// Helper function to extract type from anyOf, oneOf, or allOf schemas.
1574    /// Returns the type of the first non-null schema found.
1575    fn extract_type_from_composition(composition: &Value) -> Option<String> {
1576        composition.as_array().and_then(|arr| {
1577            arr.iter().find_map(|schema| {
1578                if let Some(obj) = schema.as_object() {
1579                    // Skip null types
1580                    if let Some(type_val) = obj.get("type")
1581                        && let Some(type_str) = type_val.as_str()
1582                        && type_str == "null"
1583                    {
1584                        return None;
1585                    }
1586                    // Extract type from this schema
1587                    obj.get("type").and_then(extract_type).or_else(|| {
1588                        if obj.contains_key("properties") {
1589                            Some("object".to_string())
1590                        } else {
1591                            None
1592                        }
1593                    })
1594                } else {
1595                    None
1596                }
1597            })
1598        })
1599    }
1600
1601    /// Helper function to extract the first non-null schema from anyOf, oneOf, or allOf.
1602    /// Returns the schema object that should be used for properties, required, etc.
1603    fn extract_schema_from_composition(
1604        composition: &Value,
1605    ) -> Option<serde_json::Map<String, Value>> {
1606        composition.as_array().and_then(|arr| {
1607            arr.iter().find_map(|schema| {
1608                if let Some(obj) = schema.as_object()
1609                    && let Some(type_val) = obj.get("type")
1610                    && let Some(type_str) = type_val.as_str()
1611                {
1612                    if type_str == "null" {
1613                        return None;
1614                    }
1615                    Some(obj.clone())
1616                } else {
1617                    None
1618                }
1619            })
1620        })
1621    }
1622
1623    /// Helper function to infer the type of a schema object.
1624    /// Checks for explicit type, then anyOf/oneOf/allOf, then infers from properties.
1625    fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1626        // First, try direct type field
1627        if let Some(type_val) = obj.get("type")
1628            && let Some(type_str) = extract_type(type_val)
1629        {
1630            return type_str;
1631        }
1632
1633        // Then try anyOf, oneOf, allOf (in that order)
1634        if let Some(any_of) = obj.get("anyOf")
1635            && let Some(type_str) = extract_type_from_composition(any_of)
1636        {
1637            return type_str;
1638        }
1639
1640        if let Some(one_of) = obj.get("oneOf")
1641            && let Some(type_str) = extract_type_from_composition(one_of)
1642        {
1643            return type_str;
1644        }
1645
1646        if let Some(all_of) = obj.get("allOf")
1647            && let Some(type_str) = extract_type_from_composition(all_of)
1648        {
1649            return type_str;
1650        }
1651
1652        // Finally, infer object type if properties are present
1653        if obj.contains_key("properties") {
1654            "object".to_string()
1655        } else {
1656            String::new()
1657        }
1658    }
1659
1660    impl TryFrom<Value> for Schema {
1661        type Error = CompletionError;
1662
1663        fn try_from(value: Value) -> Result<Self, Self::Error> {
1664            let flattened_val = flatten_schema(value)?;
1665            if let Some(obj) = flattened_val.as_object() {
1666                // Determine which object to use for extracting properties and required fields.
1667                // If this object has anyOf/oneOf/allOf, we need to extract properties from the composition.
1668                let props_source = if obj.get("properties").is_none() {
1669                    if let Some(any_of) = obj.get("anyOf") {
1670                        extract_schema_from_composition(any_of)
1671                    } else if let Some(one_of) = obj.get("oneOf") {
1672                        extract_schema_from_composition(one_of)
1673                    } else if let Some(all_of) = obj.get("allOf") {
1674                        extract_schema_from_composition(all_of)
1675                    } else {
1676                        None
1677                    }
1678                    .unwrap_or(obj.clone())
1679                } else {
1680                    obj.clone()
1681                };
1682
1683                Ok(Schema {
1684                    r#type: infer_type(obj),
1685                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1686                    description: obj
1687                        .get("description")
1688                        .and_then(|v| v.as_str())
1689                        .map(String::from),
1690                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1691                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1692                        arr.iter()
1693                            .filter_map(|v| v.as_str().map(String::from))
1694                            .collect()
1695                    }),
1696                    max_items: obj
1697                        .get("maxItems")
1698                        .and_then(|v| v.as_i64())
1699                        .map(|v| v as i32),
1700                    min_items: obj
1701                        .get("minItems")
1702                        .and_then(|v| v.as_i64())
1703                        .map(|v| v as i32),
1704                    properties: props_source
1705                        .get("properties")
1706                        .and_then(|v| v.as_object())
1707                        .map(|map| {
1708                            map.iter()
1709                                .filter_map(|(k, v)| {
1710                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1711                                })
1712                                .collect()
1713                        }),
1714                    required: props_source
1715                        .get("required")
1716                        .and_then(|v| v.as_array())
1717                        .map(|arr| {
1718                            arr.iter()
1719                                .filter_map(|v| v.as_str().map(String::from))
1720                                .collect()
1721                        }),
1722                    items: obj
1723                        .get("items")
1724                        .and_then(|v| v.clone().try_into().ok())
1725                        .map(Box::new),
1726                })
1727            } else {
1728                Err(CompletionError::ResponseError(
1729                    "Expected a JSON object for Schema".into(),
1730                ))
1731            }
1732        }
1733    }
1734
    /// Request body for a generate-content call. Field names are serialized in camelCase.
    ///
    /// Only `tools` and `additional_params` are skipped when `None`; the remaining optional
    /// fields carry no `skip_serializing_if` and are therefore written out as `null` when unset.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// The content entries sent to the model.
        pub contents: Vec<Content>,
        /// Optional tool definitions; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        /// Optional tool configuration (function-calling mode).
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. Harm categories:
        ///     - HARM_CATEGORY_HATE_SPEECH,
        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
        ///     - HARM_CATEGORY_HARASSMENT
        /// are supported.
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // cachedContent: Optional<String>
        /// Additional parameters. Flattened into the top level of the request object, so the
        /// keys of this value appear next to the fields above; omitted entirely when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1766
    /// Tool definitions attached to a request. Field names are serialized in camelCase.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions the model may call (wire name `functionDeclarations`).
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Optional code-execution tool (wire name `codeExecution`). There is no
        /// `skip_serializing_if`, so `None` is written out as `null`.
        pub code_execution: Option<CodeExecution>,
    }
1773
    /// Declaration of a single callable function. Field names are serialized in camelCase.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Name of the function.
        pub name: String,
        /// Description of the function.
        pub description: String,
        /// Schema of the function's parameters; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1782
    /// Tool configuration for a request. Field names are (de)serialized in camelCase.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// How the model may use function calling (wire name `functionCallingConfig`).
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1788
    /// Function-calling mode. Serialized as an internally tagged object whose `mode` key holds
    /// the upper-cased variant name (`AUTO`, `NONE`, `ANY`). Defaults to [`Self::Auto`].
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// Default mode.
        #[default]
        Auto,
        /// Produced from `ToolChoice::None`.
        None,
        /// Produced from `ToolChoice::Required` (no name list) and `ToolChoice::Specific`
        /// (with a name list).
        Any {
            /// Optional list of function names; omitted from the output when `None`.
            /// NOTE(review): the enum-level `rename_all` renames variants only, so this field
            /// presumably stays `allowed_function_names` on the wire — confirm the API accepts it.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1800
1801    impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1802        type Error = CompletionError;
1803        fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1804            let res = match value {
1805                message::ToolChoice::Auto => Self::Auto,
1806                message::ToolChoice::None => Self::None,
1807                message::ToolChoice::Required => Self::Any {
1808                    allowed_function_names: None,
1809                },
1810                message::ToolChoice::Specific { function_names } => Self::Any {
1811                    allowed_function_names: Some(function_names),
1812                },
1813            };
1814
1815            Ok(res)
1816        }
1817    }
1818
    /// Field-less marker for the code-execution tool; serializes as an empty object.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1821
    /// Pairs a harm category with the blocking threshold to apply to it.
    /// Field names are serialized in camelCase.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
1828
    /// Blocking threshold of a [`SafetySetting`]. Variant names are serialized in
    /// SCREAMING_SNAKE_CASE, e.g. `BlockLowAndAbove` becomes `BLOCK_LOW_AND_ABOVE`.
    /// NOTE(review): the per-variant meanings are not visible here — see the Gemini API
    /// documentation for the exact semantics.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        // Serialized as `HARM_BLOCK_THRESHOLD_UNSPECIFIED`.
        HarmBlockThresholdUnspecified,
        // Serialized as `BLOCK_LOW_AND_ABOVE`.
        BlockLowAndAbove,
        // Serialized as `BLOCK_MEDIUM_AND_ABOVE`.
        BlockMediumAndAbove,
        // Serialized as `BLOCK_ONLY_HIGH`.
        BlockOnlyHigh,
        // Serialized as `BLOCK_NONE`.
        BlockNone,
        // Serialized as `OFF`.
        Off,
    }
1839}
1840
1841#[cfg(test)]
1842mod tests {
1843    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1844
1845    use super::*;
1846    use serde_json::json;
1847
1848    #[test]
1849    fn test_deserialize_message_user() {
1850        let raw_message = r#"{
1851            "parts": [
1852                {"text": "Hello, world!"},
1853                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1854                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1855                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1856                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1857                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1858                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1859            ],
1860            "role": "user"
1861        }"#;
1862
1863        let content: Content = {
1864            let jd = &mut serde_json::Deserializer::from_str(raw_message);
1865            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1866                panic!("Deserialization error at {}: {}", err.path(), err);
1867            })
1868        };
1869        assert_eq!(content.role, Some(Role::User));
1870        assert_eq!(content.parts.len(), 7);
1871
1872        let parts: Vec<Part> = content.parts.into_iter().collect();
1873
1874        if let Part {
1875            part: PartKind::Text(text),
1876            ..
1877        } = &parts[0]
1878        {
1879            assert_eq!(text, "Hello, world!");
1880        } else {
1881            panic!("Expected text part");
1882        }
1883
1884        if let Part {
1885            part: PartKind::InlineData(inline_data),
1886            ..
1887        } = &parts[1]
1888        {
1889            assert_eq!(inline_data.mime_type, "image/png");
1890            assert_eq!(inline_data.data, "base64encodeddata");
1891        } else {
1892            panic!("Expected inline data part");
1893        }
1894
1895        if let Part {
1896            part: PartKind::FunctionCall(function_call),
1897            ..
1898        } = &parts[2]
1899        {
1900            assert_eq!(function_call.name, "test_function");
1901            assert_eq!(
1902                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1903                "value1"
1904            );
1905        } else {
1906            panic!("Expected function call part");
1907        }
1908
1909        if let Part {
1910            part: PartKind::FunctionResponse(function_response),
1911            ..
1912        } = &parts[3]
1913        {
1914            assert_eq!(function_response.name, "test_function");
1915            assert_eq!(
1916                function_response
1917                    .response
1918                    .as_ref()
1919                    .unwrap()
1920                    .get("result")
1921                    .unwrap(),
1922                "success"
1923            );
1924        } else {
1925            panic!("Expected function response part");
1926        }
1927
1928        if let Part {
1929            part: PartKind::FileData(file_data),
1930            ..
1931        } = &parts[4]
1932        {
1933            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1934            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1935        } else {
1936            panic!("Expected file data part");
1937        }
1938
1939        if let Part {
1940            part: PartKind::ExecutableCode(executable_code),
1941            ..
1942        } = &parts[5]
1943        {
1944            assert_eq!(executable_code.code, "print('Hello, world!')");
1945        } else {
1946            panic!("Expected executable code part");
1947        }
1948
1949        if let Part {
1950            part: PartKind::CodeExecutionResult(code_execution_result),
1951            ..
1952        } = &parts[6]
1953        {
1954            assert_eq!(
1955                code_execution_result.clone().output.unwrap(),
1956                "Hello, world!"
1957            );
1958        } else {
1959            panic!("Expected code execution result part");
1960        }
1961    }
1962
1963    #[test]
1964    fn test_deserialize_message_model() {
1965        let json_data = json!({
1966            "parts": [{"text": "Hello, user!"}],
1967            "role": "model"
1968        });
1969
1970        let content: Content = serde_json::from_value(json_data).unwrap();
1971        assert_eq!(content.role, Some(Role::Model));
1972        assert_eq!(content.parts.len(), 1);
1973        if let Some(Part {
1974            part: PartKind::Text(text),
1975            ..
1976        }) = content.parts.first()
1977        {
1978            assert_eq!(text, "Hello, user!");
1979        } else {
1980            panic!("Expected text part");
1981        }
1982    }
1983
1984    #[test]
1985    fn test_message_conversion_user() {
1986        let msg = message::Message::user("Hello, world!");
1987        let content: Content = msg.try_into().unwrap();
1988        assert_eq!(content.role, Some(Role::User));
1989        assert_eq!(content.parts.len(), 1);
1990        if let Some(Part {
1991            part: PartKind::Text(text),
1992            ..
1993        }) = &content.parts.first()
1994        {
1995            assert_eq!(text, "Hello, world!");
1996        } else {
1997            panic!("Expected text part");
1998        }
1999    }
2000
2001    #[test]
2002    fn test_message_conversion_model() {
2003        let msg = message::Message::assistant("Hello, user!");
2004
2005        let content: Content = msg.try_into().unwrap();
2006        assert_eq!(content.role, Some(Role::Model));
2007        assert_eq!(content.parts.len(), 1);
2008        if let Some(Part {
2009            part: PartKind::Text(text),
2010            ..
2011        }) = &content.parts.first()
2012        {
2013            assert_eq!(text, "Hello, user!");
2014        } else {
2015            panic!("Expected text part");
2016        }
2017    }
2018
2019    #[test]
2020    fn test_message_conversion_tool_call() {
2021        let tool_call = message::ToolCall {
2022            id: "test_tool".to_string(),
2023            call_id: None,
2024            function: message::ToolFunction {
2025                name: "test_function".to_string(),
2026                arguments: json!({"arg1": "value1"}),
2027            },
2028        };
2029
2030        let msg = message::Message::Assistant {
2031            id: None,
2032            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2033        };
2034
2035        let content: Content = msg.try_into().unwrap();
2036        assert_eq!(content.role, Some(Role::Model));
2037        assert_eq!(content.parts.len(), 1);
2038        if let Some(Part {
2039            part: PartKind::FunctionCall(function_call),
2040            ..
2041        }) = content.parts.first()
2042        {
2043            assert_eq!(function_call.name, "test_function");
2044            assert_eq!(
2045                function_call.args.as_object().unwrap().get("arg1").unwrap(),
2046                "value1"
2047            );
2048        } else {
2049            panic!("Expected function call part");
2050        }
2051    }
2052
2053    #[test]
2054    fn test_vec_schema_conversion() {
2055        let schema_with_ref = json!({
2056            "type": "array",
2057            "items": {
2058                "$ref": "#/$defs/Person"
2059            },
2060            "$defs": {
2061                "Person": {
2062                    "type": "object",
2063                    "properties": {
2064                        "first_name": {
2065                            "type": ["string", "null"],
2066                            "description": "The person's first name, if provided (null otherwise)"
2067                        },
2068                        "last_name": {
2069                            "type": ["string", "null"],
2070                            "description": "The person's last name, if provided (null otherwise)"
2071                        },
2072                        "job": {
2073                            "type": ["string", "null"],
2074                            "description": "The person's job, if provided (null otherwise)"
2075                        }
2076                    },
2077                    "required": []
2078                }
2079            }
2080        });
2081
2082        let result: Result<Schema, _> = schema_with_ref.try_into();
2083
2084        match result {
2085            Ok(schema) => {
2086                assert_eq!(schema.r#type, "array");
2087
2088                if let Some(items) = schema.items {
2089                    println!("item types: {}", items.r#type);
2090
2091                    assert_ne!(items.r#type, "", "Items type should not be empty string!");
2092                    assert_eq!(items.r#type, "object", "Items should be object type");
2093                } else {
2094                    panic!("Schema should have items field for array type");
2095                }
2096            }
2097            Err(e) => println!("Schema conversion failed: {:?}", e),
2098        }
2099    }
2100
2101    #[test]
2102    fn test_object_schema() {
2103        let simple_schema = json!({
2104            "type": "object",
2105            "properties": {
2106                "name": {
2107                    "type": "string"
2108                }
2109            }
2110        });
2111
2112        let schema: Schema = simple_schema.try_into().unwrap();
2113        assert_eq!(schema.r#type, "object");
2114        assert!(schema.properties.is_some());
2115    }
2116
2117    #[test]
2118    fn test_array_with_inline_items() {
2119        let inline_schema = json!({
2120            "type": "array",
2121            "items": {
2122                "type": "object",
2123                "properties": {
2124                    "name": {
2125                        "type": "string"
2126                    }
2127                }
2128            }
2129        });
2130
2131        let schema: Schema = inline_schema.try_into().unwrap();
2132        assert_eq!(schema.r#type, "array");
2133
2134        if let Some(items) = schema.items {
2135            assert_eq!(items.r#type, "object");
2136            assert!(items.properties.is_some());
2137        } else {
2138            panic!("Schema should have items field");
2139        }
2140    }
2141    #[test]
2142    fn test_flattened_schema() {
2143        let ref_schema = json!({
2144            "type": "array",
2145            "items": {
2146                "$ref": "#/$defs/Person"
2147            },
2148            "$defs": {
2149                "Person": {
2150                    "type": "object",
2151                    "properties": {
2152                        "name": { "type": "string" }
2153                    }
2154                }
2155            }
2156        });
2157
2158        let flattened = flatten_schema(ref_schema).unwrap();
2159        let schema: Schema = flattened.try_into().unwrap();
2160
2161        assert_eq!(schema.r#type, "array");
2162
2163        if let Some(items) = schema.items {
2164            println!("Flattened items type: '{}'", items.r#type);
2165
2166            assert_eq!(items.r#type, "object");
2167            assert!(items.properties.is_some());
2168        }
2169    }
2170}