rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter variant of Gemini 1.5)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
///
/// NOTE: Google does not publish a model with this id; the 8B variant of Gemini 1.5 is
/// [`GEMINI_1_5_FLASH_8B`]. This constant is kept unchanged for backward compatibility.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters;
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::{
35    OneOrMany,
36    completion::{self, CompletionError, CompletionRequest},
37};
38use gemini_api_types::{
39    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
40    Role, Tool,
41};
42use serde_json::{Map, Value};
43use std::convert::TryFrom;
44
45use super::Client;
46
47// =================================================================
48// Rig Implementation Types
49// =================================================================
50
51#[derive(Clone)]
52pub struct CompletionModel {
53    pub(crate) client: Client,
54    pub model: String,
55}
56
57impl CompletionModel {
58    pub fn new(client: Client, model: &str) -> Self {
59        Self {
60            client,
61            model: model.to_string(),
62        }
63    }
64}
65
66impl completion::CompletionModel for CompletionModel {
67    type Response = GenerateContentResponse;
68    type StreamingResponse = StreamingCompletionResponse;
69
70    #[cfg_attr(feature = "worker", worker::send)]
71    async fn completion(
72        &self,
73        completion_request: CompletionRequest,
74    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
75        let request = create_request_body(completion_request)?;
76
77        tracing::debug!(
78            "Sending completion request to Gemini API {}",
79            serde_json::to_string_pretty(&request)?
80        );
81
82        let response = self
83            .client
84            .post(&format!("/v1beta/models/{}:generateContent", self.model))
85            .json(&request)
86            .send()
87            .await?;
88
89        if response.status().is_success() {
90            let response = response.json::<GenerateContentResponse>().await?;
91            match response.usage_metadata {
92                Some(ref usage) => tracing::info!(target: "rig",
93                "Gemini completion token usage: {}",
94                usage
95                ),
96                None => tracing::info!(target: "rig",
97                    "Gemini completion token usage: n/a",
98                ),
99            }
100
101            tracing::debug!("Received response");
102
103            Ok(completion::CompletionResponse::try_from(response))
104        } else {
105            Err(CompletionError::ProviderError(response.text().await?))
106        }?
107    }
108
109    #[cfg_attr(feature = "worker", worker::send)]
110    async fn stream(
111        &self,
112        request: CompletionRequest,
113    ) -> Result<
114        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
115        CompletionError,
116    > {
117        CompletionModel::stream(self, request).await
118    }
119}
120
121pub(crate) fn create_request_body(
122    completion_request: CompletionRequest,
123) -> Result<GenerateContentRequest, CompletionError> {
124    let mut full_history = Vec::new();
125    full_history.extend(completion_request.chat_history);
126
127    let additional_params = completion_request
128        .additional_params
129        .unwrap_or_else(|| Value::Object(Map::new()));
130
131    let AdditionalParameters {
132        mut generation_config,
133        additional_params,
134    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
135
136    if let Some(temp) = completion_request.temperature {
137        generation_config.temperature = Some(temp);
138    }
139
140    if let Some(max_tokens) = completion_request.max_tokens {
141        generation_config.max_output_tokens = Some(max_tokens);
142    }
143
144    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
145        parts: vec![preamble.into()],
146        role: Some(Role::Model),
147    });
148
149    let tools = if completion_request.tools.is_empty() {
150        None
151    } else {
152        Some(Tool::try_from(completion_request.tools)?)
153    };
154
155    let request = GenerateContentRequest {
156        contents: full_history
157            .into_iter()
158            .map(|msg| {
159                msg.try_into()
160                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
161            })
162            .collect::<Result<Vec<_>, _>>()?,
163        generation_config: Some(generation_config),
164        safety_settings: None,
165        tools,
166        tool_config: None,
167        system_instruction,
168        additional_params,
169    };
170
171    Ok(request)
172}
173
174impl TryFrom<completion::ToolDefinition> for Tool {
175    type Error = CompletionError;
176
177    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
178        let parameters: Option<Schema> =
179            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
180                None
181            } else {
182                Some(tool.parameters.try_into()?)
183            };
184
185        Ok(Self {
186            function_declarations: vec![FunctionDeclaration {
187                name: tool.name,
188                description: tool.description,
189                parameters,
190            }],
191            code_execution: None,
192        })
193    }
194}
195
196impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
197    type Error = CompletionError;
198
199    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
200        let mut function_declarations = Vec::new();
201
202        for tool in tools {
203            let parameters =
204                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
205                    None
206                } else {
207                    match tool.parameters.try_into() {
208                        Ok(schema) => Some(schema),
209                        Err(e) => {
210                            let emsg = format!(
211                                "Tool '{}' could not be converted to a schema: {:?}",
212                                tool.name, e,
213                            );
214                            return Err(CompletionError::ProviderError(emsg));
215                        }
216                    }
217                };
218
219            function_declarations.push(FunctionDeclaration {
220                name: tool.name,
221                description: tool.description,
222                parameters,
223            });
224        }
225
226        Ok(Self {
227            function_declarations,
228            code_execution: None,
229        })
230    }
231}
232
233impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
234    type Error = CompletionError;
235
236    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
237        let candidate = response.candidates.first().ok_or_else(|| {
238            CompletionError::ResponseError("No response candidates in response".into())
239        })?;
240
241        let content = candidate
242            .content
243            .parts
244            .iter()
245            .map(|Part { thought, part, .. }| {
246                Ok(match part {
247                    PartKind::Text(text) => {
248                        if let Some(thought) = thought
249                            && *thought
250                        {
251                            completion::AssistantContent::Reasoning(Reasoning::new(text))
252                        } else {
253                            completion::AssistantContent::text(text)
254                        }
255                    }
256                    PartKind::FunctionCall(function_call) => {
257                        completion::AssistantContent::tool_call(
258                            &function_call.name,
259                            &function_call.name,
260                            function_call.args.clone(),
261                        )
262                    }
263                    _ => {
264                        return Err(CompletionError::ResponseError(
265                            "Response did not contain a message or tool call".into(),
266                        ));
267                    }
268                })
269            })
270            .collect::<Result<Vec<_>, _>>()?;
271
272        let choice = OneOrMany::many(content).map_err(|_| {
273            CompletionError::ResponseError(
274                "Response contained no message or tool call (empty)".to_owned(),
275            )
276        })?;
277
278        let usage = response
279            .usage_metadata
280            .as_ref()
281            .map(|usage| completion::Usage {
282                input_tokens: usage.prompt_token_count as u64,
283                output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
284                total_tokens: usage.total_token_count as u64,
285            })
286            .unwrap_or_default();
287
288        Ok(completion::CompletionResponse {
289            choice,
290            usage,
291            raw_response: response,
292        })
293    }
294}
295
296pub mod gemini_api_types {
297    use std::{collections::HashMap, convert::Infallible, str::FromStr};
298
299    // =================================================================
300    // Gemini API Types
301    // =================================================================
302    use serde::{Deserialize, Serialize};
303    use serde_json::{Value, json};
304
305    use crate::message::{ContentFormat, DocumentSourceKind, ImageMediaType};
306    use crate::{
307        OneOrMany,
308        completion::CompletionError,
309        message::{self, MimeType as _, Reasoning, Text},
310        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
311    };
312
313    #[derive(Debug, Deserialize, Serialize, Default)]
314    #[serde(rename_all = "camelCase")]
315    pub struct AdditionalParameters {
316        /// Change your Gemini request configuration.
317        pub generation_config: GenerationConfig,
318        /// Any additional parameters that you want.
319        #[serde(flatten, skip_serializing_if = "Option::is_none")]
320        pub additional_params: Option<serde_json::Value>,
321    }
322
323    impl AdditionalParameters {
324        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
325            self.generation_config = cfg;
326            self
327        }
328
329        pub fn with_params(mut self, params: serde_json::Value) -> Self {
330            self.additional_params = Some(params);
331            self
332        }
333    }
334
335    /// Response from the model supporting multiple candidate responses.
336    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
337    /// and for each candidate in finishReason and in safetyRatings.
338    /// The API:
339    ///     - Returns either all requested candidates or none of them
340    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
341    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
342    #[derive(Debug, Deserialize, Serialize)]
343    #[serde(rename_all = "camelCase")]
344    pub struct GenerateContentResponse {
345        /// Candidate responses from the model.
346        pub candidates: Vec<ContentCandidate>,
347        /// Returns the prompt's feedback related to the content filters.
348        pub prompt_feedback: Option<PromptFeedback>,
349        /// Output only. Metadata on the generation requests' token usage.
350        pub usage_metadata: Option<UsageMetadata>,
351        pub model_version: Option<String>,
352    }
353
    /// A response candidate generated from the model.
    ///
    /// NOTE(review): `content` has no `#[serde(default)]`, so a candidate that arrives
    /// without a `content` object makes the whole response fail to deserialize. Confirm
    /// whether the API can omit it (e.g. on safety-terminated candidates).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only. Average log probability of the candidate's tokens, as reported by the API.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }
379
    /// One conversation turn: an ordered list of parts plus the role that produced them.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        /// Deserializes to an empty list when the key is missing (`#[serde(default)]`).
        #[serde(default)]
        pub parts: Vec<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        /// NOTE(review): there is no `skip_serializing_if`, so `None` is sent as `"role": null`.
        pub role: Option<Role>,
    }
389
390    impl TryFrom<message::Message> for Content {
391        type Error = message::MessageError;
392
393        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
394            Ok(match msg {
395                message::Message::User { content } => Content {
396                    parts: content
397                        .into_iter()
398                        .map(|c| c.try_into())
399                        .collect::<Result<Vec<_>, _>>()?,
400                    role: Some(Role::User),
401                },
402                message::Message::Assistant { content, .. } => Content {
403                    role: Some(Role::Model),
404                    parts: content.into_iter().map(|content| content.into()).collect(),
405                },
406            })
407        }
408    }
409
410    impl TryFrom<Content> for message::Message {
411        type Error = message::MessageError;
412
413        fn try_from(content: Content) -> Result<Self, Self::Error> {
414            match content.role {
415                Some(Role::User) | None => {
416                    Ok(message::Message::User {
417                        content: {
418                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
419                            .map(|Part { part, .. }| {
420                                Ok(match part {
421                                    PartKind::Text(text) => message::UserContent::text(text),
422                                    PartKind::InlineData(inline_data) => {
423                                        let mime_type =
424                                            message::MediaType::from_mime_type(&inline_data.mime_type);
425
426                                        match mime_type {
427                                            Some(message::MediaType::Image(media_type)) => {
428                                                message::UserContent::image_base64(
429                                                    inline_data.data,
430                                                    Some(media_type),
431                                                    Some(message::ImageDetail::default()),
432                                                )
433                                            }
434                                            Some(message::MediaType::Document(media_type)) => {
435                                                message::UserContent::document(
436                                                    inline_data.data,
437                                                    Some(message::ContentFormat::default()),
438                                                    Some(media_type),
439                                                )
440                                            }
441                                            Some(message::MediaType::Audio(media_type)) => {
442                                                message::UserContent::audio(
443                                                    inline_data.data,
444                                                    Some(message::ContentFormat::default()),
445                                                    Some(media_type),
446                                                )
447                                            }
448                                            _ => {
449                                                return Err(message::MessageError::ConversionError(
450                                                    format!("Unsupported media type {mime_type:?}"),
451                                                ));
452                                            }
453                                        }
454                                    }
455                                    _ => {
456                                        return Err(message::MessageError::ConversionError(format!(
457                                            "Unsupported gemini content part type: {part:?}"
458                                        )));
459                                    }
460                                })
461                            })
462                            .collect();
463                            OneOrMany::many(user_content?).map_err(|_| {
464                                message::MessageError::ConversionError(
465                                    "Failed to create OneOrMany from user content".to_string(),
466                                )
467                            })?
468                        },
469                    })
470                }
471                Some(Role::Model) => Ok(message::Message::Assistant {
472                    id: None,
473                    content: {
474                        let assistant_content: Result<Vec<_>, _> = content
475                            .parts
476                            .into_iter()
477                            .map(|Part { thought, part, .. }| {
478                                Ok(match part {
479                                    PartKind::Text(text) => match thought {
480                                        Some(true) => message::AssistantContent::Reasoning(
481                                            Reasoning::new(&text),
482                                        ),
483                                        _ => message::AssistantContent::Text(Text { text }),
484                                    },
485
486                                    PartKind::FunctionCall(function_call) => {
487                                        message::AssistantContent::ToolCall(function_call.into())
488                                    }
489                                    _ => {
490                                        return Err(message::MessageError::ConversionError(
491                                            format!("Unsupported part type: {part:?}"),
492                                        ));
493                                    }
494                                })
495                            })
496                            .collect();
497                        OneOrMany::many(assistant_content?).map_err(|_| {
498                            message::MessageError::ConversionError(
499                                "Failed to create OneOrMany from assistant content".to_string(),
500                            )
501                        })?
502                    },
503                }),
504            }
505        }
506    }
507
    /// Producer of a [`Content`] turn, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Caller-authored content (prompts, and tool results sent back to the model).
        User,
        /// Content produced by the model.
        Model,
    }
514
    /// A single element of a [`Content`] turn.
    ///
    /// Both `part` and `additional_params` are `#[serde(flatten)]`-ed, so their keys sit
    /// directly on the part's JSON object, e.g. `{"thought": false, "text": "hi"}`, rather
    /// than being nested.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether this part is reasoning ("thinking") text rather than regular output.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// An opaque signature for the thought so it can be reused; a base64 string.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The payload itself (text, inline data, function call, ...).
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra fields merged into the part's JSON object as-is.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
529
    /// The payload of a [`Part`]: exactly one of the variants below.
    ///
    /// The enum is externally tagged with camelCase variant names and flattened into
    /// [`Part`], so each variant appears as a single key on the part object (`text`,
    /// `inlineData`, `functionCall`, `functionResponse`, `fileData`, `executableCode`,
    /// `codeExecutionResult`). Per the Gemini API, raw bytes sent as inline data must carry
    /// an IANA MIME type identifying the media.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Inline media: a MIME type plus its data.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model.
        FunctionResponse(FunctionResponse),
        /// Media referenced by URI instead of being inlined.
        FileData(FileData),
        /// Code generated by the model for execution.
        ExecutableCode(ExecutableCode),
        /// The outcome of executing `ExecutableCode`.
        CodeExecutionResult(CodeExecutionResult),
    }
544
545    impl From<String> for Part {
546        fn from(text: String) -> Self {
547            Self {
548                thought: Some(false),
549                thought_signature: None,
550                part: PartKind::Text(text),
551                additional_params: None,
552            }
553        }
554    }
555
556    impl From<&str> for Part {
557        fn from(text: &str) -> Self {
558            Self::from(text.to_string())
559        }
560    }
561
562    impl FromStr for Part {
563        type Err = Infallible;
564
565        fn from_str(s: &str) -> Result<Self, Self::Err> {
566            Ok(s.into())
567        }
568    }
569
570    impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
571        type Error = message::MessageError;
572        fn try_from(
573            (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
574        ) -> Result<Self, Self::Error> {
575            let mime_type = mime_type.to_mime_type().to_string();
576            let part = match doc_src {
577                DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
578                    mime_type: Some(mime_type),
579                    file_uri: url,
580                }),
581                DocumentSourceKind::Base64(data) => PartKind::InlineData(Blob { mime_type, data }),
582                DocumentSourceKind::Unknown => {
583                    return Err(message::MessageError::ConversionError(
584                        "Can't convert an unknown document source".to_string(),
585                    ));
586                }
587            };
588
589            Ok(part)
590        }
591    }
592
593    impl TryFrom<message::UserContent> for Part {
594        type Error = message::MessageError;
595
596        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
597            match content {
598                message::UserContent::Text(message::Text { text }) => Ok(Part {
599                    thought: Some(false),
600                    thought_signature: None,
601                    part: PartKind::Text(text),
602                    additional_params: None,
603                }),
604                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
605                    let content = match content.first() {
606                        message::ToolResultContent::Text(text) => text.text,
607                        message::ToolResultContent::Image(_) => {
608                            return Err(message::MessageError::ConversionError(
609                                "Tool result content must be text".to_string(),
610                            ));
611                        }
612                    };
613                    // Convert to JSON since this value may be a valid JSON value
614                    let result: serde_json::Value =
615                        serde_json::from_str(&content).unwrap_or_else(|error| {
616                            tracing::trace!(
617                                ?error,
618                                "Tool result is not a valid JSON, treat it as normal string"
619                            );
620                            json!(content)
621                        });
622                    Ok(Part {
623                        thought: Some(false),
624                        thought_signature: None,
625                        part: PartKind::FunctionResponse(FunctionResponse {
626                            name: id,
627                            response: Some(json!({ "result": result })),
628                        }),
629                        additional_params: None,
630                    })
631                }
632                message::UserContent::Image(message::Image {
633                    data, media_type, ..
634                }) => match media_type {
635                    Some(media_type) => match media_type {
636                        message::ImageMediaType::JPEG
637                        | message::ImageMediaType::PNG
638                        | message::ImageMediaType::WEBP
639                        | message::ImageMediaType::HEIC
640                        | message::ImageMediaType::HEIF => {
641                            let part = PartKind::try_from((media_type, data))?;
642                            Ok(Part {
643                                thought: Some(false),
644                                thought_signature: None,
645                                part,
646                                additional_params: None,
647                            })
648                        }
649                        _ => Err(message::MessageError::ConversionError(format!(
650                            "Unsupported image media type {media_type:?}"
651                        ))),
652                    },
653                    None => Err(message::MessageError::ConversionError(
654                        "Media type for image is required for Gemini".to_string(), // Fixed error message
655                    )),
656                },
657                message::UserContent::Document(message::Document {
658                    data, media_type, ..
659                }) => match media_type {
660                    Some(media_type) => match media_type {
661                        message::DocumentMediaType::PDF
662                        | message::DocumentMediaType::TXT
663                        | message::DocumentMediaType::RTF
664                        | message::DocumentMediaType::HTML
665                        | message::DocumentMediaType::CSS
666                        | message::DocumentMediaType::MARKDOWN
667                        | message::DocumentMediaType::CSV
668                        | message::DocumentMediaType::XML => Ok(Part {
669                            thought: Some(false),
670                            thought_signature: None,
671                            part: PartKind::InlineData(Blob {
672                                mime_type: media_type.to_mime_type().to_owned(),
673                                data,
674                            }),
675                            additional_params: None,
676                        }),
677                        _ => Err(message::MessageError::ConversionError(format!(
678                            "Unsupported document media type {media_type:?}"
679                        ))),
680                    },
681                    None => Err(message::MessageError::ConversionError(
682                        "Media type for document is required for Gemini".to_string(), // Fixed error message
683                    )),
684                },
685                message::UserContent::Audio(message::Audio {
686                    data, media_type, ..
687                }) => match media_type {
688                    Some(media_type) => Ok(Part {
689                        thought: Some(false),
690                        thought_signature: None,
691                        part: PartKind::InlineData(Blob {
692                            mime_type: media_type.to_mime_type().to_owned(),
693                            data,
694                        }),
695                        additional_params: None,
696                    }),
697                    None => Err(message::MessageError::ConversionError(
698                        "Media type for audio is required for Gemini".to_string(),
699                    )),
700                },
701                message::UserContent::Video(message::Video {
702                    data,
703                    media_type,
704                    format,
705                    additional_params,
706                }) => {
707                    let mime_type = media_type.map(|m| m.to_mime_type().to_owned());
708
709                    let data = match format {
710                        Some(ContentFormat::String) => PartKind::FileData(FileData {
711                            mime_type,
712                            file_uri: data,
713                        }),
714                        _ => match mime_type {
715                            Some(mime_type) => PartKind::InlineData(Blob { mime_type, data }),
716                            None => {
717                                return Err(message::MessageError::ConversionError(
718                                    "Media type for video is required for Gemini".to_string(),
719                                ));
720                            }
721                        },
722                    };
723
724                    Ok(Part {
725                        thought: Some(false),
726                        thought_signature: None,
727                        part: data,
728                        additional_params,
729                    })
730                }
731            }
732        }
733    }
734
735    impl From<message::AssistantContent> for Part {
736        fn from(content: message::AssistantContent) -> Self {
737            match content {
738                message::AssistantContent::Text(message::Text { text }) => text.into(),
739                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
740                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
741                    Part {
742                        thought: Some(true),
743                        thought_signature: None,
744                        part: PartKind::Text(
745                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
746                        ),
747                        additional_params: None,
748                    }
749                }
750            }
751        }
752    }
753
754    impl From<message::ToolCall> for Part {
755        fn from(tool_call: message::ToolCall) -> Self {
756            Self {
757                thought: Some(false),
758                thought_signature: None,
759                part: PartKind::FunctionCall(FunctionCall {
760                    name: tool_call.function.name,
761                    args: tool_call.function.arguments,
762                }),
763                additional_params: None,
764            }
765        }
766    }
767
    /// Raw media bytes, carried in a `Part` as `inlineData`.
    ///
    /// Text should not be sent as raw bytes; use a `text` part instead.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data (serialized as `mimeType`), e.g.
        /// `image/png` or `image/jpeg`. If an unsupported MIME type is provided, the API returns
        /// an error.
        pub mime_type: String,
        /// Raw bytes for media formats, as a base64-encoded string.
        pub data: String,
    }
779
780    /// A predicted FunctionCall returned from the model that contains a string representing the
781    /// FunctionDeclaration.name with the arguments and their values.
782    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
783    pub struct FunctionCall {
784        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
785        /// and dashes, with a maximum length of 63.
786        pub name: String,
787        /// Optional. The function parameters and values in JSON object format.
788        pub args: serde_json::Value,
789    }
790
791    impl From<FunctionCall> for message::ToolCall {
792        fn from(function_call: FunctionCall) -> Self {
793            Self {
794                id: function_call.name.clone(),
795                call_id: None,
796                function: message::ToolFunction {
797                    name: function_call.name,
798                    arguments: function_call.args,
799                },
800            }
801        }
802    }
803
804    impl From<message::ToolCall> for FunctionCall {
805        fn from(tool_call: message::ToolCall) -> Self {
806            Self {
807                name: tool_call.function.name,
808                args: tool_call.function.arguments,
809            }
810        }
811    }
812
    /// The result of a `FunctionCall`, sent back to the model as context.
    ///
    /// Holds the `FunctionDeclaration.name` that was called and a structured JSON object with
    /// any output from the function. It should contain the result of a `FunctionCall` made based
    /// on a model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function that was called. Must be a-z, A-Z, 0-9, or contain
        /// underscores and dashes, with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        ///
        /// NOTE(review): there is no `skip_serializing_if`, so `None` serializes as `null`.
        pub response: Option<serde_json::Value>,
    }
824
    /// URI-based media data, carried in a `Part` as `fileData`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data (`mimeType`).
        /// `None` serializes as `null` because the field is not skipped.
        pub mime_type: Option<String>,
        /// Required. URI of the file (`fileUri`).
        pub file_uri: String,
    }
834
    /// Safety rating for a piece of content: a harm category paired with a harm probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category this rating applies to.
        pub category: HarmCategory,
        /// The probability of harm for `category`.
        pub probability: HarmProbability,
    }
840
    /// Probability that a piece of content is harmful.
    ///
    /// Variants use SCREAMING_SNAKE_CASE on the wire, e.g. `HARM_PROBABILITY_UNSPECIFIED`,
    /// `NEGLIGIBLE`, `HIGH`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
850
851    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
852    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
853    pub enum HarmCategory {
854        HarmCategoryUnspecified,
855        HarmCategoryDerogatory,
856        HarmCategoryToxicity,
857        HarmCategoryViolence,
858        HarmCategorySexually,
859        HarmCategoryMedical,
860        HarmCategoryDangerous,
861        HarmCategoryHarassment,
862        HarmCategoryHateSpeech,
863        HarmCategorySexuallyExplicit,
864        HarmCategoryDangerousContent,
865        HarmCategoryCivicIntegrity,
866    }
867
    /// Token usage counts reported in a Gemini response (`usageMetadata`).
    ///
    /// Fields are camelCase on the wire; the optional counts are omitted from serialized output
    /// when `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens served from cached content (`cachedContentTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens across the generated candidates (`candidatesTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count (`totalTokenCount`).
        pub total_token_count: i32,
        /// Thinking/reasoning tokens (`thoughtsTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
880
881    impl std::fmt::Display for UsageMetadata {
882        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
883            write!(
884                f,
885                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
886                self.prompt_token_count,
887                match self.cached_content_token_count {
888                    Some(count) => count.to_string(),
889                    None => "n/a".to_string(),
890                },
891                match self.candidates_token_count {
892                    Some(count) => count.to_string(),
893                    None => "n/a".to_string(),
894                },
895                self.total_token_count
896            )
897        }
898    }
899
    /// Feedback metadata about the prompt given in [`GenerateContentRequest::contents`].
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned; rephrase the
        /// prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
909
910    /// Reason why a prompt was blocked by the model
911    #[derive(Debug, Deserialize, Serialize)]
912    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
913    pub enum BlockReason {
914        /// Default value. This value is unused.
915        BlockReasonUnspecified,
916        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
917        Safety,
918        /// Prompt was blocked due to unknown reasons.
919        Other,
920        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
921        Blocklist,
922        /// Prompt was blocked due to prohibited content.
923        ProhibitedContent,
924    }
925
926    #[derive(Debug, Deserialize, Serialize)]
927    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
928    pub enum FinishReason {
929        /// Default value. This value is unused.
930        FinishReasonUnspecified,
931        /// Natural stop point of the model or provided stop sequence.
932        Stop,
933        /// The maximum number of tokens as specified in the request was reached.
934        MaxTokens,
935        /// The response candidate content was flagged for safety reasons.
936        Safety,
937        /// The response candidate content was flagged for recitation reasons.
938        Recitation,
939        /// The response candidate content was flagged for using an unsupported language.
940        Language,
941        /// Unknown reason.
942        Other,
943        /// Token generation stopped because the content contains forbidden terms.
944        Blocklist,
945        /// Token generation stopped for potentially containing prohibited content.
946        ProhibitedContent,
947        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
948        Spii,
949        /// The function call generated by the model is invalid.
950        MalformedFunctionCall,
951    }
952
    /// Citation information attached to generated content (`citationMetadata`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Cited sources (`citationSources`).
        ///
        /// NOTE(review): there is no `#[serde(default)]`, so an object without
        /// `citationSources` fails to deserialize.
        pub citation_sources: Vec<CitationSource>,
    }
958
    /// A single cited source.
    ///
    /// Every field is optional: it deserializes to `None` when absent and is omitted from
    /// serialized output when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment (`startIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
971
972    #[derive(Debug, Deserialize, Serialize)]
973    #[serde(rename_all = "camelCase")]
974    pub struct LogprobsResult {
975        pub top_candidate: Vec<TopCandidate>,
976        pub chosen_candidate: Vec<LogProbCandidate>,
977    }
978
    /// The top candidate tokens considered at one decoding step.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens for this step; presumably ordered by log probability — confirm
        /// against the Gemini API reference.
        pub candidates: Vec<LogProbCandidate>,
    }
983
984    #[derive(Debug, Deserialize, Serialize)]
985    #[serde(rename_all = "camelCase")]
986    pub struct LogProbCandidate {
987        pub token: String,
988        pub token_id: String,
989        pub log_probability: f64,
990    }
991
    /// Gemini API configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    ///
    /// Every field is omitted from the serialized request when `None`, and field names are sent
    /// in camelCase (e.g. `maxOutputTokens`).
    ///
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        /// - `text/plain` (default): text output.
        /// - `application/json`: JSON response in the response candidates.
        /// - `text/x.enum`: ENUM as a string response in the response candidates.
        ///
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible `response_mime_type` must also be set. Compatible MIME
        /// types: `application/json` (schema for a JSON response). Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: the API-side default varies by
        /// model; see the `Model.output_token_limit` attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Values can range over [0.0, 2.0]. Note: the API-side default
        /// varies by model; see the `Model.temperature` attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// nucleus sampling limits the number of tokens based on the cumulative probability. Note: the default value
        /// varies by model and is specified by the `Model.top_p` attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow a topK setting. Note: the default value varies by model and
        /// is specified by the `Model.top_k` attribute returned from the getModel function. An empty topK attribute
        /// indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use `frequency_penalty` for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: the more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: a
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if `response_logprobs` is true. Sets the number of top logprobs to return at each decoding
        /// step in the candidate's `logprobsResult`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Configuration for thinking/reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
    }
1069
1070    impl Default for GenerationConfig {
1071        fn default() -> Self {
1072            Self {
1073                temperature: Some(1.0),
1074                max_output_tokens: Some(4096),
1075                stop_sequences: None,
1076                response_mime_type: None,
1077                response_schema: None,
1078                candidate_count: None,
1079                top_p: None,
1080                top_k: None,
1081                presence_penalty: None,
1082                frequency_penalty: None,
1083                response_logprobs: None,
1084                logprobs: None,
1085                thinking_config: None,
1086            }
1087        }
1088    }
1089
    /// Thinking/reasoning configuration, sent as `thinkingConfig` inside the generation config.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Token budget for thinking (`thinkingBudget`).
        pub thinking_budget: u32,
        /// Sent as `includeThoughts`. There is no `skip_serializing_if`, so `None` is sent as
        /// `null`.
        pub include_thoughts: Option<bool>,
    }
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    ///
    /// `None` fields are omitted when serialized. This struct has no `rename_all`, so multi-word fields serialize
    /// in snake_case (`max_items`, `min_items`). NOTE(review): presumably accepted by the API's proto-JSON
    /// parser — confirm.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Data type name, serialized as `type`.
        pub r#type: String,
        /// Optional format hint for the data type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description of the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values, serialized as `enum`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Object properties by name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required object properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1121
1122    /// Flattens a JSON schema by resolving all `$ref` references inline.
1123    /// It takes a JSON schema that may contain `$ref` references to definitions
1124    /// in `$defs` or `definitions` sections and returns a new schema with all references
1125    /// resolved and inlined. This is necessary for APIs like Gemini that don't support
1126    /// schema references.
1127    pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1128        // extracting $defs if they exist
1129        let defs = if let Some(obj) = schema.as_object() {
1130            obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1131        } else {
1132            None
1133        };
1134
1135        let Some(defs_value) = defs else {
1136            return Ok(schema);
1137        };
1138
1139        let Some(defs_obj) = defs_value.as_object() else {
1140            return Err(CompletionError::ResponseError(
1141                "$defs must be an object".into(),
1142            ));
1143        };
1144
1145        resolve_refs(&mut schema, defs_obj)?;
1146
1147        // removing $defs from the final schema because we have inlined everything
1148        if let Some(obj) = schema.as_object_mut() {
1149            obj.remove("$defs");
1150            obj.remove("definitions");
1151        }
1152
1153        Ok(schema)
1154    }
1155
1156    /// Recursively resolves all `$ref` references in a JSON value by
1157    /// replacing them with their definitions.
1158    fn resolve_refs(
1159        value: &mut Value,
1160        defs: &serde_json::Map<String, Value>,
1161    ) -> Result<(), CompletionError> {
1162        match value {
1163            Value::Object(obj) => {
1164                if let Some(ref_value) = obj.get("$ref")
1165                    && let Some(ref_str) = ref_value.as_str()
1166                {
1167                    // "#/$defs/Person" -> "Person"
1168                    let def_name = parse_ref_path(ref_str)?;
1169
1170                    let def = defs.get(&def_name).ok_or_else(|| {
1171                        CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1172                    })?;
1173
1174                    let mut resolved = def.clone();
1175                    resolve_refs(&mut resolved, defs)?;
1176                    *value = resolved;
1177                    return Ok(());
1178                }
1179
1180                for (_, v) in obj.iter_mut() {
1181                    resolve_refs(v, defs)?;
1182                }
1183            }
1184            Value::Array(arr) => {
1185                for item in arr.iter_mut() {
1186                    resolve_refs(item, defs)?;
1187                }
1188            }
1189            _ => {}
1190        }
1191
1192        Ok(())
1193    }
1194
1195    /// Parses a JSON Schema `$ref` path to extract the definition name.
1196    ///
1197    /// JSON Schema references use URI fragment syntax to point to definitions within
1198    /// the same document. This function extracts the definition name from common
1199    /// reference patterns used in JSON Schema.
1200    fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1201        if let Some(fragment) = ref_str.strip_prefix('#') {
1202            if let Some(name) = fragment.strip_prefix("/$defs/") {
1203                Ok(name.to_string())
1204            } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1205                Ok(name.to_string())
1206            } else {
1207                Err(CompletionError::ResponseError(format!(
1208                    "Unsupported reference format: {}",
1209                    ref_str
1210                )))
1211            }
1212        } else {
1213            Err(CompletionError::ResponseError(format!(
1214                "Only fragment references (#/...) are supported: {}",
1215                ref_str
1216            )))
1217        }
1218    }
1219
1220    impl TryFrom<Value> for Schema {
1221        type Error = CompletionError;
1222
1223        fn try_from(value: Value) -> Result<Self, Self::Error> {
1224            let flattened_val = flatten_schema(value)?;
1225            if let Some(obj) = flattened_val.as_object() {
1226                Ok(Schema {
1227                    r#type: obj
1228                        .get("type")
1229                        .and_then(|v| {
1230                            if v.is_string() {
1231                                v.as_str().map(String::from)
1232                            } else if v.is_array() {
1233                                v.as_array()
1234                                    .and_then(|arr| arr.first())
1235                                    .and_then(|v| v.as_str().map(String::from))
1236                            } else {
1237                                None
1238                            }
1239                        })
1240                        .unwrap_or_default(),
1241                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1242                    description: obj
1243                        .get("description")
1244                        .and_then(|v| v.as_str())
1245                        .map(String::from),
1246                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1247                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1248                        arr.iter()
1249                            .filter_map(|v| v.as_str().map(String::from))
1250                            .collect()
1251                    }),
1252                    max_items: obj
1253                        .get("maxItems")
1254                        .and_then(|v| v.as_i64())
1255                        .map(|v| v as i32),
1256                    min_items: obj
1257                        .get("minItems")
1258                        .and_then(|v| v.as_i64())
1259                        .map(|v| v as i32),
1260                    properties: obj
1261                        .get("properties")
1262                        .and_then(|v| v.as_object())
1263                        .map(|map| {
1264                            map.iter()
1265                                .filter_map(|(k, v)| {
1266                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1267                                })
1268                                .collect()
1269                        }),
1270                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
1271                        arr.iter()
1272                            .filter_map(|v| v.as_str().map(String::from))
1273                            .collect()
1274                    }),
1275                    items: obj
1276                        .get("items")
1277                        .map(|v| Box::new(v.clone().try_into().unwrap())),
1278                })
1279            } else {
1280                Err(CompletionError::ResponseError(
1281                    "Expected a JSON object for Schema".into(),
1282                ))
1283            }
1284        }
1285    }
1286
    /// Request body for a Gemini generate-content call. Field names are serialized in camelCase.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Required. The content turns sent to the model.
        pub contents: Vec<Content>,
        /// Optional. Tools the model may use; omitted from the request when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        /// Optional. Tool configuration.
        ///
        /// NOTE(review): unlike `tools`, this field and the other `Option` fields below (except
        /// `additional_params`) have no `skip_serializing_if`, so `None` is sent as `null`.
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique `SafetySetting` instances for blocking unsafe content. This will be enforced on
        /// the request `contents` and the response `candidates`. There should not be more than one setting for each
        /// `SafetyCategory` type. The API will block any contents and responses that fail to meet the thresholds set
        /// by these settings.
        ///
        /// This list overrides the default settings for each `SafetyCategory` specified in `safetySettings`. If there
        /// is no `SafetySetting` for a given `SafetyCategory` in the list, the API will use the default safety
        /// setting for that category.
        ///
        /// Supported harm categories:
        /// - `HARM_CATEGORY_HATE_SPEECH`
        /// - `HARM_CATEGORY_SEXUALLY_EXPLICIT`
        /// - `HARM_CATEGORY_DANGEROUS_CONTENT`
        /// - `HARM_CATEGORY_HARASSMENT`
        ///
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // Not modelled yet: `cachedContent` (optional string).
        /// Additional parameters. The JSON object's keys are flattened into the top level of the
        /// request body; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1318
    /// Tools made available to the model in a request.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions the model may call (`functionDeclarations`).
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Code-execution tool configuration (`codeExecution`).
        ///
        /// NOTE(review): there is no `skip_serializing_if`, so `None` is sent as `null`.
        pub code_execution: Option<CodeExecution>,
    }
1325
    /// Declaration of a function the model may call.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// The function's name.
        pub name: String,
        /// Description of what the function does.
        pub description: String,
        /// Schema of the function's parameters; omitted from the request when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1334
    /// Tool configuration sent as `toolConfig`.
    ///
    /// NOTE(review): the public Gemini API documents `ToolConfig` with a `functionCallingConfig`
    /// field; confirm that `schema` is accepted here.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Optional schema; sent as `null` when `None` (no `skip_serializing_if`).
        pub schema: Option<Schema>,
    }
1340
    /// Configuration for the code-execution tool. It has no options and serializes as `{}`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
1344
    /// A safety setting pairing a harm category with the threshold at which content is blocked.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The blocking threshold for `category`.
        pub threshold: HarmBlockThreshold,
    }
1351
    /// Blocking threshold used in a `SafetySetting`.
    ///
    /// Variants use SCREAMING_SNAKE_CASE on the wire, e.g. `BLOCK_MEDIUM_AND_ABOVE`, `OFF`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
1362}
1363
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    /// A user `Content` containing one part of each `PartKind` deserializes
    /// from camelCase JSON with every payload intact and in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error pinpoints which part failed if deserialization breaks.
        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            // Borrow the output instead of cloning the whole result.
            assert_eq!(code_execution_result.output.as_ref().unwrap(), "Hello, world!");
        } else {
            panic!("Expected code execution result part");
        }
    }

    /// A `"model"` role with a single text part deserializes as `Role::Model`.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// A generic user message converts into a user `Content` with one text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    /// A generic assistant message converts into a model `Content` with one text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// An assistant tool call converts into a model `Content` carrying a
    /// `FunctionCall` part with the same name and arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    /// An array schema whose `items` is a `$ref` into `$defs` must convert
    /// successfully, with the item schema resolved to a concrete object type.
    ///
    /// Conversion failure is a test failure: previously an `Err` was only
    /// printed, which let this test pass even when conversion broke.
    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        let result: Result<Schema, _> = schema_with_ref.try_into();
        let schema = result.unwrap_or_else(|e| panic!("Schema conversion failed: {e:?}"));

        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    /// A plain object schema keeps its type and properties through conversion.
    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    /// An array with inline (non-`$ref`) object items converts both levels.
    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        if let Some(items) = schema.items {
            assert_eq!(items.r#type, "object");
            assert!(items.properties.is_some());
        } else {
            panic!("Schema should have items field");
        }
    }

    /// Explicitly flattening a `$ref`-based array schema and then converting
    /// it yields resolved object items.
    ///
    /// A missing `items` field is a failure; previously it was silently ignored.
    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Flattened array schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}