rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter variant of Gemini 1.5 Flash)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
///
/// Note: this identifier does not appear in Google's published Gemini model list.
/// The 8B variant of Gemini 1.5 is published as `gemini-1.5-flash-8b`; prefer
/// [`GEMINI_1_5_FLASH_8B`]. This constant is kept unchanged for backward compatibility.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters;
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::{
35    OneOrMany,
36    completion::{self, CompletionError, CompletionRequest},
37};
38use gemini_api_types::{
39    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
40    Role, Tool,
41};
42use serde_json::{Map, Value};
43use std::convert::TryFrom;
44
45use super::Client;
46
47// =================================================================
48// Rig Implementation Types
49// =================================================================
50
/// Gemini completion model bound to one model identifier.
///
/// Non-streaming completions are posted to `/v1beta/models/{model}:generateContent`
/// through `client`.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests to the Gemini API.
    pub(crate) client: Client,
    /// Model identifier interpolated into the request path (e.g. one of the
    /// `GEMINI_*` constants in this module).
    pub model: String,
}
56
57impl CompletionModel {
58    pub fn new(client: Client, model: &str) -> Self {
59        Self {
60            client,
61            model: model.to_string(),
62        }
63    }
64}
65
66impl completion::CompletionModel for CompletionModel {
67    type Response = GenerateContentResponse;
68    type StreamingResponse = StreamingCompletionResponse;
69
70    #[cfg_attr(feature = "worker", worker::send)]
71    async fn completion(
72        &self,
73        completion_request: CompletionRequest,
74    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
75        let request = create_request_body(completion_request)?;
76
77        tracing::debug!(
78            "Sending completion request to Gemini API {}",
79            serde_json::to_string_pretty(&request)?
80        );
81
82        let response = self
83            .client
84            .post(&format!("/v1beta/models/{}:generateContent", self.model))
85            .json(&request)
86            .send()
87            .await?;
88
89        if response.status().is_success() {
90            let response = response.json::<GenerateContentResponse>().await?;
91            match response.usage_metadata {
92                Some(ref usage) => tracing::info!(target: "rig",
93                "Gemini completion token usage: {}",
94                usage
95                ),
96                None => tracing::info!(target: "rig",
97                    "Gemini completion token usage: n/a",
98                ),
99            }
100
101            tracing::debug!("Received response");
102
103            Ok(completion::CompletionResponse::try_from(response))
104        } else {
105            Err(CompletionError::ProviderError(response.text().await?))
106        }?
107    }
108
109    #[cfg_attr(feature = "worker", worker::send)]
110    async fn stream(
111        &self,
112        request: CompletionRequest,
113    ) -> Result<
114        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
115        CompletionError,
116    > {
117        CompletionModel::stream(self, request).await
118    }
119}
120
121pub(crate) fn create_request_body(
122    completion_request: CompletionRequest,
123) -> Result<GenerateContentRequest, CompletionError> {
124    let mut full_history = Vec::new();
125    full_history.extend(completion_request.chat_history);
126
127    let additional_params = completion_request
128        .additional_params
129        .unwrap_or_else(|| Value::Object(Map::new()));
130
131    let AdditionalParameters {
132        mut generation_config,
133        additional_params,
134    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
135
136    if let Some(temp) = completion_request.temperature {
137        generation_config.temperature = Some(temp);
138    }
139
140    if let Some(max_tokens) = completion_request.max_tokens {
141        generation_config.max_output_tokens = Some(max_tokens);
142    }
143
144    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
145        parts: vec![preamble.into()],
146        role: Some(Role::Model),
147    });
148
149    let tools = if completion_request.tools.is_empty() {
150        None
151    } else {
152        Some(Tool::try_from(completion_request.tools)?)
153    };
154
155    let request = GenerateContentRequest {
156        contents: full_history
157            .into_iter()
158            .map(|msg| {
159                msg.try_into()
160                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
161            })
162            .collect::<Result<Vec<_>, _>>()?,
163        generation_config: Some(generation_config),
164        safety_settings: None,
165        tools,
166        tool_config: None,
167        system_instruction,
168        additional_params,
169    };
170
171    Ok(request)
172}
173
174impl TryFrom<completion::ToolDefinition> for Tool {
175    type Error = CompletionError;
176
177    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
178        let parameters: Option<Schema> =
179            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
180                None
181            } else {
182                Some(tool.parameters.try_into()?)
183            };
184
185        Ok(Self {
186            function_declarations: vec![FunctionDeclaration {
187                name: tool.name,
188                description: tool.description,
189                parameters,
190            }],
191            code_execution: None,
192        })
193    }
194}
195
196impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
197    type Error = CompletionError;
198
199    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
200        let mut function_declarations = Vec::new();
201
202        for tool in tools {
203            let parameters =
204                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
205                    None
206                } else {
207                    match tool.parameters.try_into() {
208                        Ok(schema) => Some(schema),
209                        Err(e) => {
210                            let emsg = format!(
211                                "Tool '{}' could not be converted to a schema: {:?}",
212                                tool.name, e,
213                            );
214                            return Err(CompletionError::ProviderError(emsg));
215                        }
216                    }
217                };
218
219            function_declarations.push(FunctionDeclaration {
220                name: tool.name,
221                description: tool.description,
222                parameters,
223            });
224        }
225
226        Ok(Self {
227            function_declarations,
228            code_execution: None,
229        })
230    }
231}
232
233impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
234    type Error = CompletionError;
235
236    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
237        let candidate = response.candidates.first().ok_or_else(|| {
238            CompletionError::ResponseError("No response candidates in response".into())
239        })?;
240
241        let content = candidate
242            .content
243            .parts
244            .iter()
245            .map(|Part { thought, part, .. }| {
246                Ok(match part {
247                    PartKind::Text(text) => {
248                        if let Some(thought) = thought
249                            && *thought
250                        {
251                            completion::AssistantContent::Reasoning(Reasoning::new(text))
252                        } else {
253                            completion::AssistantContent::text(text)
254                        }
255                    }
256                    PartKind::FunctionCall(function_call) => {
257                        completion::AssistantContent::tool_call(
258                            &function_call.name,
259                            &function_call.name,
260                            function_call.args.clone(),
261                        )
262                    }
263                    _ => {
264                        return Err(CompletionError::ResponseError(
265                            "Response did not contain a message or tool call".into(),
266                        ));
267                    }
268                })
269            })
270            .collect::<Result<Vec<_>, _>>()?;
271
272        let choice = OneOrMany::many(content).map_err(|_| {
273            CompletionError::ResponseError(
274                "Response contained no message or tool call (empty)".to_owned(),
275            )
276        })?;
277
278        let usage = response
279            .usage_metadata
280            .as_ref()
281            .map(|usage| completion::Usage {
282                input_tokens: usage.prompt_token_count as u64,
283                output_tokens: usage.candidates_token_count as u64,
284                total_tokens: usage.total_token_count as u64,
285            })
286            .unwrap_or_default();
287
288        Ok(completion::CompletionResponse {
289            choice,
290            usage,
291            raw_response: response,
292        })
293    }
294}
295
296pub mod gemini_api_types {
297    use std::{collections::HashMap, convert::Infallible, str::FromStr};
298
299    // =================================================================
300    // Gemini API Types
301    // =================================================================
302    use serde::{Deserialize, Serialize};
303    use serde_json::{Value, json};
304
305    use crate::message::ContentFormat;
306    use crate::{
307        OneOrMany,
308        completion::CompletionError,
309        message::{self, MimeType as _, Reasoning, Text},
310        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
311    };
312
313    #[derive(Debug, Deserialize, Serialize, Default)]
314    #[serde(rename_all = "camelCase")]
315    pub struct AdditionalParameters {
316        /// Change your Gemini request configuration.
317        pub generation_config: GenerationConfig,
318        /// Any additional parameters that you want.
319        #[serde(flatten, skip_serializing_if = "Option::is_none")]
320        pub additional_params: Option<serde_json::Value>,
321    }
322
323    impl AdditionalParameters {
324        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
325            self.generation_config = cfg;
326            self
327        }
328
329        pub fn with_params(mut self, params: serde_json::Value) -> Self {
330            self.additional_params = Some(params);
331            self
332        }
333    }
334
335    /// Response from the model supporting multiple candidate responses.
336    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
337    /// and for each candidate in finishReason and in safetyRatings.
338    /// The API:
339    ///     - Returns either all requested candidates or none of them
340    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
341    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
342    #[derive(Debug, Deserialize, Serialize)]
343    #[serde(rename_all = "camelCase")]
344    pub struct GenerateContentResponse {
345        /// Candidate responses from the model.
346        pub candidates: Vec<ContentCandidate>,
347        /// Returns the prompt's feedback related to the content filters.
348        pub prompt_feedback: Option<PromptFeedback>,
349        /// Output only. Metadata on the generation requests' token usage.
350        pub usage_metadata: Option<UsageMetadata>,
351        pub model_version: Option<String>,
352    }
353
354    /// A response candidate generated from the model.
355    #[derive(Debug, Deserialize, Serialize)]
356    #[serde(rename_all = "camelCase")]
357    pub struct ContentCandidate {
358        /// Output only. Generated content returned from the model.
359        pub content: Content,
360        /// Optional. Output only. The reason why the model stopped generating tokens.
361        /// If empty, the model has not stopped generating tokens.
362        pub finish_reason: Option<FinishReason>,
363        /// List of ratings for the safety of a response candidate.
364        /// There is at most one rating per category.
365        pub safety_ratings: Option<Vec<SafetyRating>>,
366        /// Output only. Citation information for model-generated candidate.
367        /// This field may be populated with recitation information for any text included in the content.
368        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
369        pub citation_metadata: Option<CitationMetadata>,
370        /// Output only. Token count for this candidate.
371        pub token_count: Option<i32>,
372        /// Output only.
373        pub avg_logprobs: Option<f64>,
374        /// Output only. Log-likelihood scores for the response tokens and top tokens
375        pub logprobs_result: Option<LogprobsResult>,
376        /// Output only. Index of the candidate in the list of response candidates.
377        pub index: Option<i32>,
378    }
379
    /// One conversational turn: an ordered list of [`Part`]s plus the [`Role`] that
    /// produced them.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        ///
        /// Defaults to an empty list when the `parts` key is absent from the JSON.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        /// (No `skip_serializing_if` is set, so `None` serializes as `null`.)
        pub role: Option<Role>,
    }
389
390    impl TryFrom<message::Message> for Content {
391        type Error = message::MessageError;
392
393        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
394            Ok(match msg {
395                message::Message::User { content } => Content {
396                    parts: content
397                        .into_iter()
398                        .map(|c| c.try_into())
399                        .collect::<Result<Vec<_>, _>>()?,
400                    role: Some(Role::User),
401                },
402                message::Message::Assistant { content, .. } => Content {
403                    role: Some(Role::Model),
404                    parts: content.into_iter().map(|content| content.into()).collect(),
405                },
406            })
407        }
408    }
409
410    impl TryFrom<Content> for message::Message {
411        type Error = message::MessageError;
412
413        fn try_from(content: Content) -> Result<Self, Self::Error> {
414            match content.role {
415                Some(Role::User) | None => {
416                    Ok(message::Message::User {
417                        content: {
418                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
419                            .map(|Part { part, .. }| {
420                                Ok(match part {
421                                    PartKind::Text(text) => message::UserContent::text(text),
422                                    PartKind::InlineData(inline_data) => {
423                                        let mime_type =
424                                            message::MediaType::from_mime_type(&inline_data.mime_type);
425
426                                        match mime_type {
427                                            Some(message::MediaType::Image(media_type)) => {
428                                                message::UserContent::image(
429                                                    inline_data.data,
430                                                    Some(message::ContentFormat::default()),
431                                                    Some(media_type),
432                                                    Some(message::ImageDetail::default()),
433                                                )
434                                            }
435                                            Some(message::MediaType::Document(media_type)) => {
436                                                message::UserContent::document(
437                                                    inline_data.data,
438                                                    Some(message::ContentFormat::default()),
439                                                    Some(media_type),
440                                                )
441                                            }
442                                            Some(message::MediaType::Audio(media_type)) => {
443                                                message::UserContent::audio(
444                                                    inline_data.data,
445                                                    Some(message::ContentFormat::default()),
446                                                    Some(media_type),
447                                                )
448                                            }
449                                            _ => {
450                                                return Err(message::MessageError::ConversionError(
451                                                    format!("Unsupported media type {mime_type:?}"),
452                                                ));
453                                            }
454                                        }
455                                    }
456                                    _ => {
457                                        return Err(message::MessageError::ConversionError(format!(
458                                            "Unsupported gemini content part type: {part:?}"
459                                        )));
460                                    }
461                                })
462                            })
463                            .collect();
464                            OneOrMany::many(user_content?).map_err(|_| {
465                                message::MessageError::ConversionError(
466                                    "Failed to create OneOrMany from user content".to_string(),
467                                )
468                            })?
469                        },
470                    })
471                }
472                Some(Role::Model) => Ok(message::Message::Assistant {
473                    id: None,
474                    content: {
475                        let assistant_content: Result<Vec<_>, _> = content
476                            .parts
477                            .into_iter()
478                            .map(|Part { thought, part, .. }| {
479                                Ok(match part {
480                                    PartKind::Text(text) => match thought {
481                                        Some(true) => message::AssistantContent::Reasoning(
482                                            Reasoning::new(&text),
483                                        ),
484                                        _ => message::AssistantContent::Text(Text { text }),
485                                    },
486
487                                    PartKind::FunctionCall(function_call) => {
488                                        message::AssistantContent::ToolCall(function_call.into())
489                                    }
490                                    _ => {
491                                        return Err(message::MessageError::ConversionError(
492                                            format!("Unsupported part type: {part:?}"),
493                                        ));
494                                    }
495                                })
496                            })
497                            .collect();
498                        OneOrMany::many(assistant_content?).map_err(|_| {
499                            message::MessageError::ConversionError(
500                                "Failed to create OneOrMany from assistant content".to_string(),
501                            )
502                        })?
503                    },
504                }),
505            }
506        }
507    }
508
509    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
510    #[serde(rename_all = "lowercase")]
511    pub enum Role {
512        User,
513        Model,
514    }
515
    /// A single piece of a [`Content`] turn.
    ///
    /// The payload (`part`) and any `additional_params` are flattened into the same JSON
    /// object as the `thought` / `thoughtSignature` flags.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether this part is reasoning/thinking text (`Some(true)`) rather than regular
        /// content; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// An opaque signature for the thought so it can be reused (a base64 string);
        /// serialized as `thoughtSignature` and omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The actual payload. Its variant name becomes the JSON key on this object
        /// (e.g. `text`, `inlineData`, `functionCall`).
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra JSON fields flattened into this part's object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
530
    /// The payload of a multi-part [`Content`] message's [`Part`]; a part carries exactly
    /// one of these.
    ///
    /// [`Part::part`] is flattened, and this enum uses serde's default external tagging
    /// with `camelCase` names. So each variant appears as a single key (`text`,
    /// `inlineData`, `functionCall`, `functionResponse`, `fileData`, `executableCode`,
    /// `codeExecutionResult`) on the part object.
    ///
    /// When the payload is raw bytes (`inlineData`), it must carry a fixed IANA MIME type
    /// identifying the type and subtype of the media (see [`Blob::mime_type`]).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text (treated as reasoning when [`Part::thought`] is `Some(true)`).
        Text(String),
        /// Inline, base64-encoded media.
        InlineData(Blob),
        /// A function call predicted by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model.
        FunctionResponse(FunctionResponse),
        /// Media referenced by URI instead of being inlined.
        FileData(FileData),
        /// Executable code (code-execution feature).
        ExecutableCode(ExecutableCode),
        /// The result of executing code (code-execution feature).
        CodeExecutionResult(CodeExecutionResult),
    }
545
546    impl From<String> for Part {
547        fn from(text: String) -> Self {
548            Self {
549                thought: Some(false),
550                thought_signature: None,
551                part: PartKind::Text(text),
552                additional_params: None,
553            }
554        }
555    }
556
557    impl From<&str> for Part {
558        fn from(text: &str) -> Self {
559            Self::from(text.to_string())
560        }
561    }
562
563    impl FromStr for Part {
564        type Err = Infallible;
565
566        fn from_str(s: &str) -> Result<Self, Self::Err> {
567            Ok(s.into())
568        }
569    }
570
571    impl TryFrom<message::UserContent> for Part {
572        type Error = message::MessageError;
573
574        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
575            match content {
576                message::UserContent::Text(message::Text { text }) => Ok(Part {
577                    thought: Some(false),
578                    thought_signature: None,
579                    part: PartKind::Text(text),
580                    additional_params: None,
581                }),
582                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
583                    let content = match content.first() {
584                        message::ToolResultContent::Text(text) => text.text,
585                        message::ToolResultContent::Image(_) => {
586                            return Err(message::MessageError::ConversionError(
587                                "Tool result content must be text".to_string(),
588                            ));
589                        }
590                    };
591                    // Convert to JSON since this value may be a valid JSON value
592                    let result: serde_json::Value =
593                        serde_json::from_str(&content).unwrap_or_else(|error| {
594                            tracing::trace!(
595                                ?error,
596                                "Tool result is not a valid JSON, treat it as normal string"
597                            );
598                            json!(content)
599                        });
600                    Ok(Part {
601                        thought: Some(false),
602                        thought_signature: None,
603                        part: PartKind::FunctionResponse(FunctionResponse {
604                            name: id,
605                            response: Some(json!({ "result": result })),
606                        }),
607                        additional_params: None,
608                    })
609                }
610                message::UserContent::Image(message::Image {
611                    data, media_type, ..
612                }) => match media_type {
613                    Some(media_type) => match media_type {
614                        message::ImageMediaType::JPEG
615                        | message::ImageMediaType::PNG
616                        | message::ImageMediaType::WEBP
617                        | message::ImageMediaType::HEIC
618                        | message::ImageMediaType::HEIF => Ok(Part {
619                            thought: Some(false),
620                            thought_signature: None,
621                            part: PartKind::InlineData(Blob {
622                                mime_type: media_type.to_mime_type().to_owned(),
623                                data,
624                            }),
625                            additional_params: None,
626                        }),
627                        _ => Err(message::MessageError::ConversionError(format!(
628                            "Unsupported image media type {media_type:?}"
629                        ))),
630                    },
631                    None => Err(message::MessageError::ConversionError(
632                        "Media type for image is required for Gemini".to_string(), // Fixed error message
633                    )),
634                },
635                message::UserContent::Document(message::Document {
636                    data, media_type, ..
637                }) => match media_type {
638                    Some(media_type) => match media_type {
639                        message::DocumentMediaType::PDF
640                        | message::DocumentMediaType::TXT
641                        | message::DocumentMediaType::RTF
642                        | message::DocumentMediaType::HTML
643                        | message::DocumentMediaType::CSS
644                        | message::DocumentMediaType::MARKDOWN
645                        | message::DocumentMediaType::CSV
646                        | message::DocumentMediaType::XML => Ok(Part {
647                            thought: Some(false),
648                            thought_signature: None,
649                            part: PartKind::InlineData(Blob {
650                                mime_type: media_type.to_mime_type().to_owned(),
651                                data,
652                            }),
653                            additional_params: None,
654                        }),
655                        _ => Err(message::MessageError::ConversionError(format!(
656                            "Unsupported document media type {media_type:?}"
657                        ))),
658                    },
659                    None => Err(message::MessageError::ConversionError(
660                        "Media type for document is required for Gemini".to_string(), // Fixed error message
661                    )),
662                },
663                message::UserContent::Audio(message::Audio {
664                    data, media_type, ..
665                }) => match media_type {
666                    Some(media_type) => Ok(Part {
667                        thought: Some(false),
668                        thought_signature: None,
669                        part: PartKind::InlineData(Blob {
670                            mime_type: media_type.to_mime_type().to_owned(),
671                            data,
672                        }),
673                        additional_params: None,
674                    }),
675                    None => Err(message::MessageError::ConversionError(
676                        "Media type for audio is required for Gemini".to_string(),
677                    )),
678                },
679                message::UserContent::Video(message::Video {
680                    data,
681                    media_type,
682                    format,
683                    additional_params,
684                }) => {
685                    let mime_type = media_type.map(|m| m.to_mime_type().to_owned());
686
687                    let data = match format {
688                        Some(ContentFormat::String) => PartKind::FileData(FileData {
689                            mime_type,
690                            file_uri: data,
691                        }),
692                        _ => match mime_type {
693                            Some(mime_type) => PartKind::InlineData(Blob { mime_type, data }),
694                            None => {
695                                return Err(message::MessageError::ConversionError(
696                                    "Media type for video is required for Gemini".to_string(),
697                                ));
698                            }
699                        },
700                    };
701
702                    Ok(Part {
703                        thought: Some(false),
704                        thought_signature: None,
705                        part: data,
706                        additional_params,
707                    })
708                }
709            }
710        }
711    }
712
713    impl From<message::AssistantContent> for Part {
714        fn from(content: message::AssistantContent) -> Self {
715            match content {
716                message::AssistantContent::Text(message::Text { text }) => text.into(),
717                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
718                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
719                    Part {
720                        thought: Some(true),
721                        thought_signature: None,
722                        part: PartKind::Text(
723                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
724                        ),
725                        additional_params: None,
726                    }
727                }
728            }
729        }
730    }
731
732    impl From<message::ToolCall> for Part {
733        fn from(tool_call: message::ToolCall) -> Self {
734            Self {
735                thought: Some(false),
736                thought_signature: None,
737                part: PartKind::FunctionCall(FunctionCall {
738                    name: tool_call.function.name,
739                    args: tool_call.function.arguments,
740                }),
741                additional_params: None,
742            }
743        }
744    }
745
746    /// Raw media bytes.
747    /// Text should not be sent as raw bytes, use the 'text' field.
748    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
749    #[serde(rename_all = "camelCase")]
750    pub struct Blob {
751        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
752        /// If an unsupported MIME type is provided, an error will be returned.
753        pub mime_type: String,
754        /// Raw bytes for media formats. A base64-encoded string.
755        pub data: String,
756    }
757
758    /// A predicted FunctionCall returned from the model that contains a string representing the
759    /// FunctionDeclaration.name with the arguments and their values.
760    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
761    pub struct FunctionCall {
762        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
763        /// and dashes, with a maximum length of 63.
764        pub name: String,
765        /// Optional. The function parameters and values in JSON object format.
766        pub args: serde_json::Value,
767    }
768
769    impl From<FunctionCall> for message::ToolCall {
770        fn from(function_call: FunctionCall) -> Self {
771            Self {
772                id: function_call.name.clone(),
773                call_id: None,
774                function: message::ToolFunction {
775                    name: function_call.name,
776                    arguments: function_call.args,
777                },
778            }
779        }
780    }
781
782    impl From<message::ToolCall> for FunctionCall {
783        fn from(tool_call: message::ToolCall) -> Self {
784            Self {
785                name: tool_call.function.name,
786                args: tool_call.function.arguments,
787            }
788        }
789    }
790
791    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
792    /// and a structured JSON object containing any output from the function is used as context to the model.
793    /// This should contain the result of aFunctionCall made based on model prediction.
794    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
795    pub struct FunctionResponse {
796        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
797        /// with a maximum length of 63.
798        pub name: String,
799        /// The function response in JSON object format.
800        pub response: Option<serde_json::Value>,
801    }
802
803    /// URI based data.
804    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
805    #[serde(rename_all = "camelCase")]
806    pub struct FileData {
807        /// Optional. The IANA standard MIME type of the source data.
808        pub mime_type: Option<String>,
809        /// Required. URI.
810        pub file_uri: String,
811    }
812
813    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
814    pub struct SafetyRating {
815        pub category: HarmCategory,
816        pub probability: HarmProbability,
817    }
818
819    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
820    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
821    pub enum HarmProbability {
822        HarmProbabilityUnspecified,
823        Negligible,
824        Low,
825        Medium,
826        High,
827    }
828
829    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
830    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
831    pub enum HarmCategory {
832        HarmCategoryUnspecified,
833        HarmCategoryDerogatory,
834        HarmCategoryToxicity,
835        HarmCategoryViolence,
836        HarmCategorySexually,
837        HarmCategoryMedical,
838        HarmCategoryDangerous,
839        HarmCategoryHarassment,
840        HarmCategoryHateSpeech,
841        HarmCategorySexuallyExplicit,
842        HarmCategoryDangerousContent,
843        HarmCategoryCivicIntegrity,
844    }
845
846    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
847    #[serde(rename_all = "camelCase")]
848    pub struct UsageMetadata {
849        pub prompt_token_count: i32,
850        #[serde(skip_serializing_if = "Option::is_none")]
851        pub cached_content_token_count: Option<i32>,
852        pub candidates_token_count: i32,
853        pub total_token_count: i32,
854        #[serde(skip_serializing_if = "Option::is_none")]
855        pub thoughts_token_count: Option<i32>,
856    }
857
858    impl std::fmt::Display for UsageMetadata {
859        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
860            write!(
861                f,
862                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
863                self.prompt_token_count,
864                match self.cached_content_token_count {
865                    Some(count) => count.to_string(),
866                    None => "n/a".to_string(),
867                },
868                self.candidates_token_count,
869                self.total_token_count
870            )
871        }
872    }
873
874    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
875    #[derive(Debug, Deserialize, Serialize)]
876    #[serde(rename_all = "camelCase")]
877    pub struct PromptFeedback {
878        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
879        pub block_reason: Option<BlockReason>,
880        /// Ratings for safety of the prompt. There is at most one rating per category.
881        pub safety_ratings: Option<Vec<SafetyRating>>,
882    }
883
884    /// Reason why a prompt was blocked by the model
885    #[derive(Debug, Deserialize, Serialize)]
886    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
887    pub enum BlockReason {
888        /// Default value. This value is unused.
889        BlockReasonUnspecified,
890        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
891        Safety,
892        /// Prompt was blocked due to unknown reasons.
893        Other,
894        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
895        Blocklist,
896        /// Prompt was blocked due to prohibited content.
897        ProhibitedContent,
898    }
899
900    #[derive(Debug, Deserialize, Serialize)]
901    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
902    pub enum FinishReason {
903        /// Default value. This value is unused.
904        FinishReasonUnspecified,
905        /// Natural stop point of the model or provided stop sequence.
906        Stop,
907        /// The maximum number of tokens as specified in the request was reached.
908        MaxTokens,
909        /// The response candidate content was flagged for safety reasons.
910        Safety,
911        /// The response candidate content was flagged for recitation reasons.
912        Recitation,
913        /// The response candidate content was flagged for using an unsupported language.
914        Language,
915        /// Unknown reason.
916        Other,
917        /// Token generation stopped because the content contains forbidden terms.
918        Blocklist,
919        /// Token generation stopped for potentially containing prohibited content.
920        ProhibitedContent,
921        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
922        Spii,
923        /// The function call generated by the model is invalid.
924        MalformedFunctionCall,
925    }
926
927    #[derive(Debug, Deserialize, Serialize)]
928    #[serde(rename_all = "camelCase")]
929    pub struct CitationMetadata {
930        pub citation_sources: Vec<CitationSource>,
931    }
932
933    #[derive(Debug, Deserialize, Serialize)]
934    #[serde(rename_all = "camelCase")]
935    pub struct CitationSource {
936        #[serde(skip_serializing_if = "Option::is_none")]
937        pub uri: Option<String>,
938        #[serde(skip_serializing_if = "Option::is_none")]
939        pub start_index: Option<i32>,
940        #[serde(skip_serializing_if = "Option::is_none")]
941        pub end_index: Option<i32>,
942        #[serde(skip_serializing_if = "Option::is_none")]
943        pub license: Option<String>,
944    }
945
946    #[derive(Debug, Deserialize, Serialize)]
947    #[serde(rename_all = "camelCase")]
948    pub struct LogprobsResult {
949        pub top_candidate: Vec<TopCandidate>,
950        pub chosen_candidate: Vec<LogProbCandidate>,
951    }
952
953    #[derive(Debug, Deserialize, Serialize)]
954    pub struct TopCandidate {
955        pub candidates: Vec<LogProbCandidate>,
956    }
957
958    #[derive(Debug, Deserialize, Serialize)]
959    #[serde(rename_all = "camelCase")]
960    pub struct LogProbCandidate {
961        pub token: String,
962        pub token_id: String,
963        pub log_probability: f64,
964    }
965
966    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
967    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
968    /// ### Rig Note:
969    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
970    #[derive(Debug, Deserialize, Serialize)]
971    #[serde(rename_all = "camelCase")]
972    pub struct GenerationConfig {
973        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
974        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
975        #[serde(skip_serializing_if = "Option::is_none")]
976        pub stop_sequences: Option<Vec<String>>,
977        /// MIME type of the generated candidate text. Supported MIME types are:
978        ///     - text/plain:  (default) Text output
979        ///     - application/json: JSON response in the response candidates.
980        ///     - text/x.enum: ENUM as a string response in the response candidates.
981        /// Refer to the docs for a list of all supported text MIME types
982        #[serde(skip_serializing_if = "Option::is_none")]
983        pub response_mime_type: Option<String>,
984        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
985        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
986        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
987        #[serde(skip_serializing_if = "Option::is_none")]
988        pub response_schema: Option<Schema>,
989        /// Number of generated responses to return. Currently, this value can only be set to 1. If
990        /// unset, this will default to 1.
991        #[serde(skip_serializing_if = "Option::is_none")]
992        pub candidate_count: Option<i32>,
993        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
994        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
995        #[serde(skip_serializing_if = "Option::is_none")]
996        pub max_output_tokens: Option<u64>,
997        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
998        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
999        #[serde(skip_serializing_if = "Option::is_none")]
1000        pub temperature: Option<f64>,
1001        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
1002        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
1003        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
1004        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
1005        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
1006        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
1007        #[serde(skip_serializing_if = "Option::is_none")]
1008        pub top_p: Option<f64>,
1009        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
1010        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
1011        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
1012        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
1013        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
1014        #[serde(skip_serializing_if = "Option::is_none")]
1015        pub top_k: Option<i32>,
1016        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
1017        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
1018        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
1019        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
1020        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
1021        #[serde(skip_serializing_if = "Option::is_none")]
1022        pub presence_penalty: Option<f64>,
1023        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
1024        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
1025        /// used, proportional to the number of times the token has been used: The more a token is used, the more
1026        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
1027        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
1028        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
1029        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
1030        #[serde(skip_serializing_if = "Option::is_none")]
1031        pub frequency_penalty: Option<f64>,
1032        /// If true, export the logprobs results in response.
1033        #[serde(skip_serializing_if = "Option::is_none")]
1034        pub response_logprobs: Option<bool>,
1035        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
1036        /// [Candidate.logprobs_result].
1037        #[serde(skip_serializing_if = "Option::is_none")]
1038        pub logprobs: Option<i32>,
1039        /// Configuration for thinking/reasoning.
1040        #[serde(skip_serializing_if = "Option::is_none")]
1041        pub thinking_config: Option<ThinkingConfig>,
1042    }
1043
1044    impl Default for GenerationConfig {
1045        fn default() -> Self {
1046            Self {
1047                temperature: Some(1.0),
1048                max_output_tokens: Some(4096),
1049                stop_sequences: None,
1050                response_mime_type: None,
1051                response_schema: None,
1052                candidate_count: None,
1053                top_p: None,
1054                top_k: None,
1055                presence_penalty: None,
1056                frequency_penalty: None,
1057                response_logprobs: None,
1058                logprobs: None,
1059                thinking_config: None,
1060            }
1061        }
1062    }
1063
1064    #[derive(Debug, Deserialize, Serialize)]
1065    #[serde(rename_all = "camelCase")]
1066    pub struct ThinkingConfig {
1067        pub thinking_budget: u32,
1068        pub include_thoughts: Option<bool>,
1069    }
1070    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
1071    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
1072    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
1073    #[derive(Debug, Deserialize, Serialize, Clone)]
1074    pub struct Schema {
1075        pub r#type: String,
1076        #[serde(skip_serializing_if = "Option::is_none")]
1077        pub format: Option<String>,
1078        #[serde(skip_serializing_if = "Option::is_none")]
1079        pub description: Option<String>,
1080        #[serde(skip_serializing_if = "Option::is_none")]
1081        pub nullable: Option<bool>,
1082        #[serde(skip_serializing_if = "Option::is_none")]
1083        pub r#enum: Option<Vec<String>>,
1084        #[serde(skip_serializing_if = "Option::is_none")]
1085        pub max_items: Option<i32>,
1086        #[serde(skip_serializing_if = "Option::is_none")]
1087        pub min_items: Option<i32>,
1088        #[serde(skip_serializing_if = "Option::is_none")]
1089        pub properties: Option<HashMap<String, Schema>>,
1090        #[serde(skip_serializing_if = "Option::is_none")]
1091        pub required: Option<Vec<String>>,
1092        #[serde(skip_serializing_if = "Option::is_none")]
1093        pub items: Option<Box<Schema>>,
1094    }
1095
1096    /// Flattens a JSON schema by resolving all `$ref` references inline.
1097    /// It takes a JSON schema that may contain `$ref` references to definitions
1098    /// in `$defs` or `definitions` sections and returns a new schema with all references
1099    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1100    /// schema references.
1101    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1102        // extracting $defs if they exist
1103        let defs = if let Some(obj) = schema.as_object() {
1104            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1105        } else {
1106            None
1107        };
1108
1109        let Some(defs_value) = defs else {
1110            return Ok(schema);
1111        };
1112
1113        let Some(defs_obj) = defs_value.as_object() else {
1114            return Err(CompletionError::ResponseError(
1115                "$defs must be an object".into(),
1116            ));
1117        };
1118
1119        resolve_refs(&mut schema, defs_obj)?;
1120
1121        // removing $defs from the final schema because we have inlined everything
1122        if let Some(obj) = schema.as_object_mut() {
1123            obj.remove("$defs");
1124            obj.remove("definitions");
1125        }
1126
1127        Ok(schema)
1128    }
1129
1130    /// Recursively resolves all `$ref` references in a JSON value by
1131    /// replacing them with their definitions.
1132    fn resolve_refs(
1133        value: &mut Value,
1134        defs: &serde_json::Map<String, Value>,
1135    ) -> Result<(), CompletionError> {
1136        match value {
1137            Value::Object(obj) => {
1138                if let Some(ref_value) = obj.get("$ref")
1139                    && let Some(ref_str) = ref_value.as_str()
1140                {
1141                    // "#/$defs/Person" -> "Person"
1142                    let def_name = parse_ref_path(ref_str)?;
1143
1144                    let def = defs.get(&def_name).ok_or_else(|| {
1145                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1146                    })?;
1147
1148                    let mut resolved = def.clone();
1149                    resolve_refs(&mut resolved, defs)?;
1150                    *value = resolved;
1151                    return Ok(());
1152                }
1153
1154                for (_, v) in obj.iter_mut() {
1155                    resolve_refs(v, defs)?;
1156                }
1157            }
1158            Value::Array(arr) => {
1159                for item in arr.iter_mut() {
1160                    resolve_refs(item, defs)?;
1161                }
1162            }
1163            _ => {}
1164        }
1165
1166        Ok(())
1167    }
1168
1169    /// Parses a JSON Schema `$ref` path to extract the definition name.
1170    ///
1171    /// JSON Schema references use URI fragment syntax to point to definitions within
1172    /// the same document. This function extracts the definition name from common
1173    /// reference patterns used in JSON Schema.
1174    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1175        if let Some(fragment) = ref_str.strip_prefix('#') {
1176            if let Some(name) = fragment.strip_prefix("/$defs/") {
1177                Ok(name.to_string())
1178            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1179                Ok(name.to_string())
1180            } else {
1181                Err(CompletionError::ResponseError(format!(
1182                    "Unsupported reference format: {}",
1183                    ref_str
1184                )))
1185            }
1186        } else {
1187            Err(CompletionError::ResponseError(format!(
1188                "Only fragment references (#/...) are supported: {}",
1189                ref_str
1190            )))
1191        }
1192    }
1193
1194    impl TryFrom<Value> for Schema {
1195        type Error = CompletionError;
1196
1197        fn try_from(value: Value) -> Result<Self, Self::Error> {
1198            let flattened_val = flatten_schema(value)?;
1199            if let Some(obj) = flattened_val.as_object() {
1200                Ok(Schema {
1201                    r#type: obj
1202                        .get("type")
1203                        .and_then(|v| {
1204                            if v.is_string() {
1205                                v.as_str().map(String::from)
1206                            } else if v.is_array() {
1207                                v.as_array()
1208                                    .and_then(|arr| arr.first())
1209                                    .and_then(|v| v.as_str().map(String::from))
1210                            } else {
1211                                None
1212                            }
1213                        })
1214                        .unwrap_or_default(),
1215                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1216                    description: obj
1217                        .get("description")
1218                        .and_then(|v| v.as_str())
1219                        .map(String::from),
1220                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1221                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1222                        arr.iter()
1223                            .filter_map(|v| v.as_str().map(String::from))
1224                            .collect()
1225                    }),
1226                    max_items: obj
1227                        .get("maxItems")
1228                        .and_then(|v| v.as_i64())
1229                        .map(|v| v as i32),
1230                    min_items: obj
1231                        .get("minItems")
1232                        .and_then(|v| v.as_i64())
1233                        .map(|v| v as i32),
1234                    properties: obj
1235                        .get("properties")
1236                        .and_then(|v| v.as_object())
1237                        .map(|map| {
1238                            map.iter()
1239                                .filter_map(|(k, v)| {
1240                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1241                                })
1242                                .collect()
1243                        }),
1244                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
1245                        arr.iter()
1246                            .filter_map(|v| v.as_str().map(String::from))
1247                            .collect()
1248                    }),
1249                    items: obj
1250                        .get("items")
1251                        .map(|v| Box::new(v.clone().try_into().unwrap())),
1252                })
1253            } else {
1254                Err(CompletionError::ResponseError(
1255                    "Expected a JSON object for Schema".into(),
1256                ))
1257            }
1258        }
1259    }
1260
1261    #[derive(Debug, Serialize)]
1262    #[serde(rename_all = "camelCase")]
1263    pub struct GenerateContentRequest {
1264        pub contents: Vec<Content>,
1265        #[serde(skip_serializing_if = "Option::is_none")]
1266        pub tools: Option<Tool>,
1267        pub tool_config: Option<ToolConfig>,
1268        /// Optional. Configuration options for model generation and outputs.
1269        pub generation_config: Option<GenerationConfig>,
1270        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1271        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1272        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1273        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1274        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1275        /// will use the default safety setting for that category. Harm categories:
1276        ///     - HARM_CATEGORY_HATE_SPEECH,
1277        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1278        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1279        ///     - HARM_CATEGORY_HARASSMENT
1280        /// are supported.
1281        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1282        /// to learn how to incorporate safety considerations in your AI applications.
1283        pub safety_settings: Option<Vec<SafetySetting>>,
1284        /// Optional. Developer set system instruction(s). Currently, text only.
1285        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1286        pub system_instruction: Option<Content>,
1287        // cachedContent: Optional<String>
1288        /// Additional parameters.
1289        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1290        pub additional_params: Option<serde_json::Value>,
1291    }
1292
1293    #[derive(Debug, Serialize)]
1294    #[serde(rename_all = "camelCase")]
1295    pub struct Tool {
1296        pub function_declarations: Vec<FunctionDeclaration>,
1297        pub code_execution: Option<CodeExecution>,
1298    }
1299
1300    #[derive(Debug, Serialize, Clone)]
1301    #[serde(rename_all = "camelCase")]
1302    pub struct FunctionDeclaration {
1303        pub name: String,
1304        pub description: String,
1305        #[serde(skip_serializing_if = "Option::is_none")]
1306        pub parameters: Option<Schema>,
1307    }
1308
1309    #[derive(Debug, Serialize)]
1310    #[serde(rename_all = "camelCase")]
1311    pub struct ToolConfig {
1312        pub schema: Option<Schema>,
1313    }
1314
1315    #[derive(Debug, Serialize)]
1316    #[serde(rename_all = "camelCase")]
1317    pub struct CodeExecution {}
1318
1319    #[derive(Debug, Serialize)]
1320    #[serde(rename_all = "camelCase")]
1321    pub struct SafetySetting {
1322        pub category: HarmCategory,
1323        pub threshold: HarmBlockThreshold,
1324    }
1325
1326    #[derive(Debug, Serialize)]
1327    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1328    pub enum HarmBlockThreshold {
1329        HarmBlockThresholdUnspecified,
1330        BlockLowAndAbove,
1331        BlockMediumAndAbove,
1332        BlockOnlyHigh,
1333        BlockNone,
1334        Off,
1335    }
1336}
1337
1338#[cfg(test)]
1339mod tests {
1340    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1341
1342    use super::*;
1343    use serde_json::json;
1344
1345    #[test]
1346    fn test_deserialize_message_user() {
1347        let raw_message = r#"{
1348            "parts": [
1349                {"text": "Hello, world!"},
1350                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1351                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1352                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1353                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1354                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1355                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1356            ],
1357            "role": "user"
1358        }"#;
1359
1360        let content: Content = {
1361            let jd = &mut serde_json::Deserializer::from_str(raw_message);
1362            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1363                panic!("Deserialization error at {}: {}", err.path(), err);
1364            })
1365        };
1366        assert_eq!(content.role, Some(Role::User));
1367        assert_eq!(content.parts.len(), 7);
1368
1369        let parts: Vec<Part> = content.parts.into_iter().collect();
1370
1371        if let Part {
1372            part: PartKind::Text(text),
1373            ..
1374        } = &parts[0]
1375        {
1376            assert_eq!(text, "Hello, world!");
1377        } else {
1378            panic!("Expected text part");
1379        }
1380
1381        if let Part {
1382            part: PartKind::InlineData(inline_data),
1383            ..
1384        } = &parts[1]
1385        {
1386            assert_eq!(inline_data.mime_type, "image/png");
1387            assert_eq!(inline_data.data, "base64encodeddata");
1388        } else {
1389            panic!("Expected inline data part");
1390        }
1391
1392        if let Part {
1393            part: PartKind::FunctionCall(function_call),
1394            ..
1395        } = &parts[2]
1396        {
1397            assert_eq!(function_call.name, "test_function");
1398            assert_eq!(
1399                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1400                "value1"
1401            );
1402        } else {
1403            panic!("Expected function call part");
1404        }
1405
1406        if let Part {
1407            part: PartKind::FunctionResponse(function_response),
1408            ..
1409        } = &parts[3]
1410        {
1411            assert_eq!(function_response.name, "test_function");
1412            assert_eq!(
1413                function_response
1414                    .response
1415                    .as_ref()
1416                    .unwrap()
1417                    .get("result")
1418                    .unwrap(),
1419                "success"
1420            );
1421        } else {
1422            panic!("Expected function response part");
1423        }
1424
1425        if let Part {
1426            part: PartKind::FileData(file_data),
1427            ..
1428        } = &parts[4]
1429        {
1430            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1431            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1432        } else {
1433            panic!("Expected file data part");
1434        }
1435
1436        if let Part {
1437            part: PartKind::ExecutableCode(executable_code),
1438            ..
1439        } = &parts[5]
1440        {
1441            assert_eq!(executable_code.code, "print('Hello, world!')");
1442        } else {
1443            panic!("Expected executable code part");
1444        }
1445
1446        if let Part {
1447            part: PartKind::CodeExecutionResult(code_execution_result),
1448            ..
1449        } = &parts[6]
1450        {
1451            assert_eq!(
1452                code_execution_result.clone().output.unwrap(),
1453                "Hello, world!"
1454            );
1455        } else {
1456            panic!("Expected code execution result part");
1457        }
1458    }
1459
1460    #[test]
1461    fn test_deserialize_message_model() {
1462        let json_data = json!({
1463            "parts": [{"text": "Hello, user!"}],
1464            "role": "model"
1465        });
1466
1467        let content: Content = serde_json::from_value(json_data).unwrap();
1468        assert_eq!(content.role, Some(Role::Model));
1469        assert_eq!(content.parts.len(), 1);
1470        if let Some(Part {
1471            part: PartKind::Text(text),
1472            ..
1473        }) = content.parts.first()
1474        {
1475            assert_eq!(text, "Hello, user!");
1476        } else {
1477            panic!("Expected text part");
1478        }
1479    }
1480
1481    #[test]
1482    fn test_message_conversion_user() {
1483        let msg = message::Message::user("Hello, world!");
1484        let content: Content = msg.try_into().unwrap();
1485        assert_eq!(content.role, Some(Role::User));
1486        assert_eq!(content.parts.len(), 1);
1487        if let Some(Part {
1488            part: PartKind::Text(text),
1489            ..
1490        }) = &content.parts.first()
1491        {
1492            assert_eq!(text, "Hello, world!");
1493        } else {
1494            panic!("Expected text part");
1495        }
1496    }
1497
1498    #[test]
1499    fn test_message_conversion_model() {
1500        let msg = message::Message::assistant("Hello, user!");
1501
1502        let content: Content = msg.try_into().unwrap();
1503        assert_eq!(content.role, Some(Role::Model));
1504        assert_eq!(content.parts.len(), 1);
1505        if let Some(Part {
1506            part: PartKind::Text(text),
1507            ..
1508        }) = &content.parts.first()
1509        {
1510            assert_eq!(text, "Hello, user!");
1511        } else {
1512            panic!("Expected text part");
1513        }
1514    }
1515
1516    #[test]
1517    fn test_message_conversion_tool_call() {
1518        let tool_call = message::ToolCall {
1519            id: "test_tool".to_string(),
1520            call_id: None,
1521            function: message::ToolFunction {
1522                name: "test_function".to_string(),
1523                arguments: json!({"arg1": "value1"}),
1524            },
1525        };
1526
1527        let msg = message::Message::Assistant {
1528            id: None,
1529            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1530        };
1531
1532        let content: Content = msg.try_into().unwrap();
1533        assert_eq!(content.role, Some(Role::Model));
1534        assert_eq!(content.parts.len(), 1);
1535        if let Some(Part {
1536            part: PartKind::FunctionCall(function_call),
1537            ..
1538        }) = content.parts.first()
1539        {
1540            assert_eq!(function_call.name, "test_function");
1541            assert_eq!(
1542                function_call.args.as_object().unwrap().get("arg1").unwrap(),
1543                "value1"
1544            );
1545        } else {
1546            panic!("Expected function call part");
1547        }
1548    }
1549
1550    #[test]
1551    fn test_vec_schema_conversion() {
1552        let schema_with_ref = json!({
1553            "type": "array",
1554            "items": {
1555                "$ref": "#/$defs/Person"
1556            },
1557            "$defs": {
1558                "Person": {
1559                    "type": "object",
1560                    "properties": {
1561                        "first_name": {
1562                            "type": ["string", "null"],
1563                            "description": "The person's first name, if provided (null otherwise)"
1564                        },
1565                        "last_name": {
1566                            "type": ["string", "null"],
1567                            "description": "The person's last name, if provided (null otherwise)"
1568                        },
1569                        "job": {
1570                            "type": ["string", "null"],
1571                            "description": "The person's job, if provided (null otherwise)"
1572                        }
1573                    },
1574                    "required": []
1575                }
1576            }
1577        });
1578
1579        let result: Result<Schema, _> = schema_with_ref.try_into();
1580
1581        match result {
1582            Ok(schema) => {
1583                assert_eq!(schema.r#type, "array");
1584
1585                if let Some(items) = schema.items {
1586                    println!("item types: {}", items.r#type);
1587
1588                    assert_ne!(items.r#type, "", "Items type should not be empty string!");
1589                    assert_eq!(items.r#type, "object", "Items should be object type");
1590                } else {
1591                    panic!("Schema should have items field for array type");
1592                }
1593            }
1594            Err(e) => println!("Schema conversion failed: {:?}", e),
1595        }
1596    }
1597
1598    #[test]
1599    fn test_object_schema() {
1600        let simple_schema = json!({
1601            "type": "object",
1602            "properties": {
1603                "name": {
1604                    "type": "string"
1605                }
1606            }
1607        });
1608
1609        let schema: Schema = simple_schema.try_into().unwrap();
1610        assert_eq!(schema.r#type, "object");
1611        assert!(schema.properties.is_some());
1612    }
1613
1614    #[test]
1615    fn test_array_with_inline_items() {
1616        let inline_schema = json!({
1617            "type": "array",
1618            "items": {
1619                "type": "object",
1620                "properties": {
1621                    "name": {
1622                        "type": "string"
1623                    }
1624                }
1625            }
1626        });
1627
1628        let schema: Schema = inline_schema.try_into().unwrap();
1629        assert_eq!(schema.r#type, "array");
1630
1631        if let Some(items) = schema.items {
1632            assert_eq!(items.r#type, "object");
1633            assert!(items.properties.is_some());
1634        } else {
1635            panic!("Schema should have items field");
1636        }
1637    }
1638    #[test]
1639    fn test_flattened_schema() {
1640        let ref_schema = json!({
1641            "type": "array",
1642            "items": {
1643                "$ref": "#/$defs/Person"
1644            },
1645            "$defs": {
1646                "Person": {
1647                    "type": "object",
1648                    "properties": {
1649                        "name": { "type": "string" }
1650                    }
1651                }
1652            }
1653        });
1654
1655        let flattened = flatten_schema(ref_schema).unwrap();
1656        let schema: Schema = flattened.try_into().unwrap();
1657
1658        assert_eq!(schema.r#type, "array");
1659
1660        if let Some(items) = schema.items {
1661            println!("Flattened items type: '{}'", items.r#type);
1662
1663            assert_eq!(items.r#type, "object");
1664            assert!(items.properties.is_some());
1665        }
1666    }
1667}