rig/providers/gemini/completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter variant of Gemini 1.5 Flash)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` model name.
///
/// NOTE: Gemini's published 8B model is `gemini-1.5-flash-8b`; no `gemini-1.5-pro-8b`
/// model is listed, so requests using this name are likely to be rejected. Prefer
/// [`GEMINI_1_5_FLASH_8B`]. Kept unchanged for backward compatibility.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters;
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::{
35    OneOrMany,
36    completion::{self, CompletionError, CompletionRequest},
37};
38use gemini_api_types::{
39    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
40    Role, Tool,
41};
42use serde_json::{Map, Value};
43use std::convert::TryFrom;
44
45use super::Client;
46
47// =================================================================
48// Rig Implementation Types
49// =================================================================
50
51#[derive(Clone)]
52pub struct CompletionModel {
53    pub(crate) client: Client,
54    pub model: String,
55}
56
57impl CompletionModel {
58    pub fn new(client: Client, model: &str) -> Self {
59        Self {
60            client,
61            model: model.to_string(),
62        }
63    }
64}
65
66impl completion::CompletionModel for CompletionModel {
67    type Response = GenerateContentResponse;
68    type StreamingResponse = StreamingCompletionResponse;
69
70    #[cfg_attr(feature = "worker", worker::send)]
71    async fn completion(
72        &self,
73        completion_request: CompletionRequest,
74    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
75        let request = create_request_body(completion_request)?;
76
77        tracing::debug!(
78            "Sending completion request to Gemini API {}",
79            serde_json::to_string_pretty(&request)?
80        );
81
82        let response = self
83            .client
84            .post(&format!("/v1beta/models/{}:generateContent", self.model))
85            .json(&request)
86            .send()
87            .await?;
88
89        if response.status().is_success() {
90            let response = response.json::<GenerateContentResponse>().await?;
91            match response.usage_metadata {
92                Some(ref usage) => tracing::info!(target: "rig",
93                "Gemini completion token usage: {}",
94                usage
95                ),
96                None => tracing::info!(target: "rig",
97                    "Gemini completion token usage: n/a",
98                ),
99            }
100
101            tracing::debug!("Received response");
102
103            Ok(completion::CompletionResponse::try_from(response))
104        } else {
105            Err(CompletionError::ProviderError(response.text().await?))
106        }?
107    }
108
109    #[cfg_attr(feature = "worker", worker::send)]
110    async fn stream(
111        &self,
112        request: CompletionRequest,
113    ) -> Result<
114        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
115        CompletionError,
116    > {
117        CompletionModel::stream(self, request).await
118    }
119}
120
121pub(crate) fn create_request_body(
122    completion_request: CompletionRequest,
123) -> Result<GenerateContentRequest, CompletionError> {
124    let mut full_history = Vec::new();
125    full_history.extend(completion_request.chat_history);
126
127    let additional_params = completion_request
128        .additional_params
129        .unwrap_or_else(|| Value::Object(Map::new()));
130
131    let AdditionalParameters {
132        mut generation_config,
133        additional_params,
134    } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
135
136    if let Some(temp) = completion_request.temperature {
137        generation_config.temperature = Some(temp);
138    }
139
140    if let Some(max_tokens) = completion_request.max_tokens {
141        generation_config.max_output_tokens = Some(max_tokens);
142    }
143
144    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
145        parts: vec![preamble.into()],
146        role: Some(Role::Model),
147    });
148
149    let tools = if completion_request.tools.is_empty() {
150        None
151    } else {
152        Some(Tool::try_from(completion_request.tools)?)
153    };
154
155    let request = GenerateContentRequest {
156        contents: full_history
157            .into_iter()
158            .map(|msg| {
159                msg.try_into()
160                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
161            })
162            .collect::<Result<Vec<_>, _>>()?,
163        generation_config: Some(generation_config),
164        safety_settings: None,
165        tools,
166        tool_config: None,
167        system_instruction,
168        additional_params,
169    };
170
171    Ok(request)
172}
173
174impl TryFrom<completion::ToolDefinition> for Tool {
175    type Error = CompletionError;
176
177    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
178        let parameters: Option<Schema> =
179            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
180                None
181            } else {
182                Some(tool.parameters.try_into()?)
183            };
184
185        Ok(Self {
186            function_declarations: vec![FunctionDeclaration {
187                name: tool.name,
188                description: tool.description,
189                parameters,
190            }],
191            code_execution: None,
192        })
193    }
194}
195
196impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
197    type Error = CompletionError;
198
199    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
200        let mut function_declarations = Vec::new();
201
202        for tool in tools {
203            let parameters =
204                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
205                    None
206                } else {
207                    match tool.parameters.try_into() {
208                        Ok(schema) => Some(schema),
209                        Err(e) => {
210                            let emsg = format!(
211                                "Tool '{}' could not be converted to a schema: {:?}",
212                                tool.name, e,
213                            );
214                            return Err(CompletionError::ProviderError(emsg));
215                        }
216                    }
217                };
218
219            function_declarations.push(FunctionDeclaration {
220                name: tool.name,
221                description: tool.description,
222                parameters,
223            });
224        }
225
226        Ok(Self {
227            function_declarations,
228            code_execution: None,
229        })
230    }
231}
232
233impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
234    type Error = CompletionError;
235
236    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
237        let candidate = response.candidates.first().ok_or_else(|| {
238            CompletionError::ResponseError("No response candidates in response".into())
239        })?;
240
241        let content = candidate
242            .content
243            .parts
244            .iter()
245            .map(|Part { thought, part, .. }| {
246                Ok(match part {
247                    PartKind::Text(text) => {
248                        if let Some(thought) = thought
249                            && *thought
250                        {
251                            completion::AssistantContent::Reasoning(Reasoning::new(text))
252                        } else {
253                            completion::AssistantContent::text(text)
254                        }
255                    }
256                    PartKind::FunctionCall(function_call) => {
257                        completion::AssistantContent::tool_call(
258                            &function_call.name,
259                            &function_call.name,
260                            function_call.args.clone(),
261                        )
262                    }
263                    _ => {
264                        return Err(CompletionError::ResponseError(
265                            "Response did not contain a message or tool call".into(),
266                        ));
267                    }
268                })
269            })
270            .collect::<Result<Vec<_>, _>>()?;
271
272        let choice = OneOrMany::many(content).map_err(|_| {
273            CompletionError::ResponseError(
274                "Response contained no message or tool call (empty)".to_owned(),
275            )
276        })?;
277
278        let usage = response
279            .usage_metadata
280            .as_ref()
281            .map(|usage| completion::Usage {
282                input_tokens: usage.prompt_token_count as u64,
283                output_tokens: usage.candidates_token_count as u64,
284                total_tokens: usage.total_token_count as u64,
285            })
286            .unwrap_or_default();
287
288        Ok(completion::CompletionResponse {
289            choice,
290            usage,
291            raw_response: response,
292        })
293    }
294}
295
296pub mod gemini_api_types {
297    use std::{collections::HashMap, convert::Infallible, str::FromStr};
298
299    // =================================================================
300    // Gemini API Types
301    // =================================================================
302    use serde::{Deserialize, Serialize};
303    use serde_json::{Value, json};
304
305    use crate::message::ContentFormat;
306    use crate::{
307        OneOrMany,
308        completion::CompletionError,
309        message::{self, MimeType as _, Reasoning, Text},
310        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
311    };
312
313    #[derive(Debug, Deserialize, Serialize, Default)]
314    #[serde(rename_all = "camelCase")]
315    pub struct AdditionalParameters {
316        /// Change your Gemini request configuration.
317        pub generation_config: GenerationConfig,
318        /// Any additional parameters that you want.
319        #[serde(flatten, skip_serializing_if = "Option::is_none")]
320        pub additional_params: Option<serde_json::Value>,
321    }
322
323    impl AdditionalParameters {
324        pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
325            self.generation_config = cfg;
326            self
327        }
328
329        pub fn with_params(mut self, params: serde_json::Value) -> Self {
330            self.additional_params = Some(params);
331            self
332        }
333    }
334
335    /// Response from the model supporting multiple candidate responses.
336    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
337    /// and for each candidate in finishReason and in safetyRatings.
338    /// The API:
339    ///     - Returns either all requested candidates or none of them
340    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
341    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
342    #[derive(Debug, Deserialize, Serialize)]
343    #[serde(rename_all = "camelCase")]
344    pub struct GenerateContentResponse {
345        /// Candidate responses from the model.
346        pub candidates: Vec<ContentCandidate>,
347        /// Returns the prompt's feedback related to the content filters.
348        pub prompt_feedback: Option<PromptFeedback>,
349        /// Output only. Metadata on the generation requests' token usage.
350        pub usage_metadata: Option<UsageMetadata>,
351        pub model_version: Option<String>,
352    }
353
354    /// A response candidate generated from the model.
355    #[derive(Debug, Deserialize, Serialize)]
356    #[serde(rename_all = "camelCase")]
357    pub struct ContentCandidate {
358        /// Output only. Generated content returned from the model.
359        pub content: Content,
360        /// Optional. Output only. The reason why the model stopped generating tokens.
361        /// If empty, the model has not stopped generating tokens.
362        pub finish_reason: Option<FinishReason>,
363        /// List of ratings for the safety of a response candidate.
364        /// There is at most one rating per category.
365        pub safety_ratings: Option<Vec<SafetyRating>>,
366        /// Output only. Citation information for model-generated candidate.
367        /// This field may be populated with recitation information for any text included in the content.
368        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
369        pub citation_metadata: Option<CitationMetadata>,
370        /// Output only. Token count for this candidate.
371        pub token_count: Option<i32>,
372        /// Output only.
373        pub avg_logprobs: Option<f64>,
374        /// Output only. Log-likelihood scores for the response tokens and top tokens
375        pub logprobs_result: Option<LogprobsResult>,
376        /// Output only. Index of the candidate in the list of response candidates.
377        pub index: Option<i32>,
378    }
379
380    #[derive(Debug, Deserialize, Serialize)]
381    pub struct Content {
382        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
383        #[serde(default)]
384        pub parts: Vec<Part>,
385        /// The producer of the content. Must be either 'user' or 'model'.
386        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
387        pub role: Option<Role>,
388    }
389
390    impl TryFrom<message::Message> for Content {
391        type Error = message::MessageError;
392
393        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
394            Ok(match msg {
395                message::Message::User { content } => Content {
396                    parts: content
397                        .into_iter()
398                        .map(|c| c.try_into())
399                        .collect::<Result<Vec<_>, _>>()?,
400                    role: Some(Role::User),
401                },
402                message::Message::Assistant { content, .. } => Content {
403                    role: Some(Role::Model),
404                    parts: content.into_iter().map(|content| content.into()).collect(),
405                },
406            })
407        }
408    }
409
410    impl TryFrom<Content> for message::Message {
411        type Error = message::MessageError;
412
413        fn try_from(content: Content) -> Result<Self, Self::Error> {
414            match content.role {
415                Some(Role::User) | None => {
416                    Ok(message::Message::User {
417                        content: {
418                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
419                            .map(|Part { part, .. }| {
420                                Ok(match part {
421                                    PartKind::Text(text) => message::UserContent::text(text),
422                                    PartKind::InlineData(inline_data) => {
423                                        let mime_type =
424                                            message::MediaType::from_mime_type(&inline_data.mime_type);
425
426                                        match mime_type {
427                                            Some(message::MediaType::Image(media_type)) => {
428                                                message::UserContent::image(
429                                                    inline_data.data,
430                                                    Some(message::ContentFormat::default()),
431                                                    Some(media_type),
432                                                    Some(message::ImageDetail::default()),
433                                                )
434                                            }
435                                            Some(message::MediaType::Document(media_type)) => {
436                                                message::UserContent::document(
437                                                    inline_data.data,
438                                                    Some(message::ContentFormat::default()),
439                                                    Some(media_type),
440                                                )
441                                            }
442                                            Some(message::MediaType::Audio(media_type)) => {
443                                                message::UserContent::audio(
444                                                    inline_data.data,
445                                                    Some(message::ContentFormat::default()),
446                                                    Some(media_type),
447                                                )
448                                            }
449                                            _ => {
450                                                return Err(message::MessageError::ConversionError(
451                                                    format!("Unsupported media type {mime_type:?}"),
452                                                ));
453                                            }
454                                        }
455                                    }
456                                    _ => {
457                                        return Err(message::MessageError::ConversionError(format!(
458                                            "Unsupported gemini content part type: {part:?}"
459                                        )));
460                                    }
461                                })
462                            })
463                            .collect();
464                            OneOrMany::many(user_content?).map_err(|_| {
465                                message::MessageError::ConversionError(
466                                    "Failed to create OneOrMany from user content".to_string(),
467                                )
468                            })?
469                        },
470                    })
471                }
472                Some(Role::Model) => Ok(message::Message::Assistant {
473                    id: None,
474                    content: {
475                        let assistant_content: Result<Vec<_>, _> = content
476                            .parts
477                            .into_iter()
478                            .map(|Part { thought, part, .. }| {
479                                Ok(match part {
480                                    PartKind::Text(text) => match thought {
481                                        Some(true) => message::AssistantContent::Reasoning(
482                                            Reasoning::new(&text),
483                                        ),
484                                        _ => message::AssistantContent::Text(Text { text }),
485                                    },
486
487                                    PartKind::FunctionCall(function_call) => {
488                                        message::AssistantContent::ToolCall(function_call.into())
489                                    }
490                                    _ => {
491                                        return Err(message::MessageError::ConversionError(
492                                            format!("Unsupported part type: {part:?}"),
493                                        ));
494                                    }
495                                })
496                            })
497                            .collect();
498                        OneOrMany::many(assistant_content?).map_err(|_| {
499                            message::MessageError::ConversionError(
500                                "Failed to create OneOrMany from assistant content".to_string(),
501                            )
502                        })?
503                    },
504                }),
505            }
506        }
507    }
508
509    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
510    #[serde(rename_all = "lowercase")]
511    pub enum Role {
512        User,
513        Model,
514    }
515
516    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
517    #[serde(rename_all = "camelCase")]
518    pub struct Part {
519        /// whether or not the part is a reasoning/thinking text or not
520        #[serde(skip_serializing_if = "Option::is_none")]
521        pub thought: Option<bool>,
522        /// an opaque sig for the thought so it can be reused - is a base64 string
523        #[serde(skip_serializing_if = "Option::is_none")]
524        pub thought_signature: Option<String>,
525        #[serde(flatten)]
526        pub part: PartKind,
527        #[serde(flatten, skip_serializing_if = "Option::is_none")]
528        pub additional_params: Option<Value>,
529    }
530
531    /// A datatype containing media that is part of a multi-part [Content] message.
532    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
533    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
534    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
535    #[serde(rename_all = "camelCase")]
536    pub enum PartKind {
537        Text(String),
538        InlineData(Blob),
539        FunctionCall(FunctionCall),
540        FunctionResponse(FunctionResponse),
541        FileData(FileData),
542        ExecutableCode(ExecutableCode),
543        CodeExecutionResult(CodeExecutionResult),
544    }
545
546    impl From<String> for Part {
547        fn from(text: String) -> Self {
548            Self {
549                thought: Some(false),
550                thought_signature: None,
551                part: PartKind::Text(text),
552                additional_params: None,
553            }
554        }
555    }
556
557    impl From<&str> for Part {
558        fn from(text: &str) -> Self {
559            Self::from(text.to_string())
560        }
561    }
562
563    impl FromStr for Part {
564        type Err = Infallible;
565
566        fn from_str(s: &str) -> Result<Self, Self::Err> {
567            Ok(s.into())
568        }
569    }
570
571    impl TryFrom<message::UserContent> for Part {
572        type Error = message::MessageError;
573
574        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
575            match content {
576                message::UserContent::Text(message::Text { text }) => Ok(Part {
577                    thought: Some(false),
578                    thought_signature: None,
579                    part: PartKind::Text(text),
580                    additional_params: None,
581                }),
582                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
583                    let content = match content.first() {
584                        message::ToolResultContent::Text(text) => text.text,
585                        message::ToolResultContent::Image(_) => {
586                            return Err(message::MessageError::ConversionError(
587                                "Tool result content must be text".to_string(),
588                            ));
589                        }
590                    };
591                    // Convert to JSON since this value may be a valid JSON value
592                    let result: serde_json::Value =
593                        serde_json::from_str(&content).unwrap_or_else(|error| {
594                            tracing::trace!(
595                                ?error,
596                                "Tool result is not a valid JSON, treat it as normal string"
597                            );
598                            json!(content)
599                        });
600                    Ok(Part {
601                        thought: Some(false),
602                        thought_signature: None,
603                        part: PartKind::FunctionResponse(FunctionResponse {
604                            name: id,
605                            response: Some(json!({ "result": result })),
606                        }),
607                        additional_params: None,
608                    })
609                }
610                message::UserContent::Image(message::Image {
611                    data, media_type, ..
612                }) => match media_type {
613                    Some(media_type) => match media_type {
614                        message::ImageMediaType::JPEG
615                        | message::ImageMediaType::PNG
616                        | message::ImageMediaType::WEBP
617                        | message::ImageMediaType::HEIC
618                        | message::ImageMediaType::HEIF => Ok(Part {
619                            thought: Some(false),
620                            thought_signature: None,
621                            part: PartKind::InlineData(Blob {
622                                mime_type: media_type.to_mime_type().to_owned(),
623                                data,
624                            }),
625                            additional_params: None,
626                        }),
627                        _ => Err(message::MessageError::ConversionError(format!(
628                            "Unsupported image media type {media_type:?}"
629                        ))),
630                    },
631                    None => Err(message::MessageError::ConversionError(
632                        "Media type for image is required for Gemini".to_string(), // Fixed error message
633                    )),
634                },
635                message::UserContent::Document(message::Document {
636                    data, media_type, ..
637                }) => match media_type {
638                    Some(media_type) => match media_type {
639                        message::DocumentMediaType::PDF
640                        | message::DocumentMediaType::TXT
641                        | message::DocumentMediaType::RTF
642                        | message::DocumentMediaType::HTML
643                        | message::DocumentMediaType::CSS
644                        | message::DocumentMediaType::MARKDOWN
645                        | message::DocumentMediaType::CSV
646                        | message::DocumentMediaType::XML => Ok(Part {
647                            thought: Some(false),
648                            thought_signature: None,
649                            part: PartKind::InlineData(Blob {
650                                mime_type: media_type.to_mime_type().to_owned(),
651                                data,
652                            }),
653                            additional_params: None,
654                        }),
655                        _ => Err(message::MessageError::ConversionError(format!(
656                            "Unsupported document media type {media_type:?}"
657                        ))),
658                    },
659                    None => Err(message::MessageError::ConversionError(
660                        "Media type for document is required for Gemini".to_string(), // Fixed error message
661                    )),
662                },
663                message::UserContent::Audio(message::Audio {
664                    data, media_type, ..
665                }) => match media_type {
666                    Some(media_type) => Ok(Part {
667                        thought: Some(false),
668                        thought_signature: None,
669                        part: PartKind::InlineData(Blob {
670                            mime_type: media_type.to_mime_type().to_owned(),
671                            data,
672                        }),
673                        additional_params: None,
674                    }),
675                    None => Err(message::MessageError::ConversionError(
676                        "Media type for audio is required for Gemini".to_string(),
677                    )),
678                },
679                message::UserContent::Video(message::Video {
680                    data,
681                    media_type,
682                    format,
683                    additional_params,
684                }) => {
685                    let mime_type = media_type.map(|m| m.to_mime_type().to_owned());
686
687                    let data = match format {
688                        Some(ContentFormat::String) => PartKind::FileData(FileData {
689                            mime_type,
690                            file_uri: data,
691                        }),
692                        _ => match mime_type {
693                            Some(mime_type) => PartKind::InlineData(Blob { mime_type, data }),
694                            None => {
695                                return Err(message::MessageError::ConversionError(
696                                    "Media type for video is required for Gemini".to_string(),
697                                ));
698                            }
699                        },
700                    };
701
702                    Ok(Part {
703                        thought: Some(false),
704                        thought_signature: None,
705                        part: data,
706                        additional_params,
707                    })
708                }
709            }
710        }
711    }
712
713    impl From<message::AssistantContent> for Part {
714        fn from(content: message::AssistantContent) -> Self {
715            match content {
716                message::AssistantContent::Text(message::Text { text }) => text.into(),
717                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
718                message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
719                    Part {
720                        thought: Some(true),
721                        thought_signature: None,
722                        part: PartKind::Text(
723                            reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
724                        ),
725                        additional_params: None,
726                    }
727                }
728            }
729        }
730    }
731
732    impl From<message::ToolCall> for Part {
733        fn from(tool_call: message::ToolCall) -> Self {
734            Self {
735                thought: Some(false),
736                thought_signature: None,
737                part: PartKind::FunctionCall(FunctionCall {
738                    name: tool_call.function.name,
739                    args: tool_call.function.arguments,
740                }),
741                additional_params: None,
742            }
743        }
744    }
745
746    /// Raw media bytes.
747    /// Text should not be sent as raw bytes, use the 'text' field.
748    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
749    #[serde(rename_all = "camelCase")]
750    pub struct Blob {
751        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
752        /// If an unsupported MIME type is provided, an error will be returned.
753        pub mime_type: String,
754        /// Raw bytes for media formats. A base64-encoded string.
755        pub data: String,
756    }
757
758    /// A predicted FunctionCall returned from the model that contains a string representing the
759    /// FunctionDeclaration.name with the arguments and their values.
760    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
761    pub struct FunctionCall {
762        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
763        /// and dashes, with a maximum length of 63.
764        pub name: String,
765        /// Optional. The function parameters and values in JSON object format.
766        pub args: serde_json::Value,
767    }
768
769    impl From<FunctionCall> for message::ToolCall {
770        fn from(function_call: FunctionCall) -> Self {
771            Self {
772                id: function_call.name.clone(),
773                call_id: None,
774                function: message::ToolFunction {
775                    name: function_call.name,
776                    arguments: function_call.args,
777                },
778            }
779        }
780    }
781
782    impl From<message::ToolCall> for FunctionCall {
783        fn from(tool_call: message::ToolCall) -> Self {
784            Self {
785                name: tool_call.function.name,
786                args: tool_call.function.arguments,
787            }
788        }
789    }
790
791    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
792    /// and a structured JSON object containing any output from the function is used as context to the model.
793    /// This should contain the result of aFunctionCall made based on model prediction.
794    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
795    pub struct FunctionResponse {
796        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
797        /// with a maximum length of 63.
798        pub name: String,
799        /// The function response in JSON object format.
800        pub response: Option<serde_json::Value>,
801    }
802
803    /// URI based data.
804    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
805    #[serde(rename_all = "camelCase")]
806    pub struct FileData {
807        /// Optional. The IANA standard MIME type of the source data.
808        pub mime_type: Option<String>,
809        /// Required. URI.
810        pub file_uri: String,
811    }
812
813    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
814    pub struct SafetyRating {
815        pub category: HarmCategory,
816        pub probability: HarmProbability,
817    }
818
819    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
820    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
821    pub enum HarmProbability {
822        HarmProbabilityUnspecified,
823        Negligible,
824        Low,
825        Medium,
826        High,
827    }
828
829    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
830    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
831    pub enum HarmCategory {
832        HarmCategoryUnspecified,
833        HarmCategoryDerogatory,
834        HarmCategoryToxicity,
835        HarmCategoryViolence,
836        HarmCategorySexually,
837        HarmCategoryMedical,
838        HarmCategoryDangerous,
839        HarmCategoryHarassment,
840        HarmCategoryHateSpeech,
841        HarmCategorySexuallyExplicit,
842        HarmCategoryDangerousContent,
843        HarmCategoryCivicIntegrity,
844    }
845
846    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
847    #[serde(rename_all = "camelCase")]
848    pub struct UsageMetadata {
849        pub prompt_token_count: i32,
850        #[serde(skip_serializing_if = "Option::is_none")]
851        pub cached_content_token_count: Option<i32>,
852        pub candidates_token_count: i32,
853        pub total_token_count: i32,
854    }
855
856    impl std::fmt::Display for UsageMetadata {
857        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
858            write!(
859                f,
860                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
861                self.prompt_token_count,
862                match self.cached_content_token_count {
863                    Some(count) => count.to_string(),
864                    None => "n/a".to_string(),
865                },
866                self.candidates_token_count,
867                self.total_token_count
868            )
869        }
870    }
871
872    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
873    #[derive(Debug, Deserialize, Serialize)]
874    #[serde(rename_all = "camelCase")]
875    pub struct PromptFeedback {
876        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
877        pub block_reason: Option<BlockReason>,
878        /// Ratings for safety of the prompt. There is at most one rating per category.
879        pub safety_ratings: Option<Vec<SafetyRating>>,
880    }
881
882    /// Reason why a prompt was blocked by the model
883    #[derive(Debug, Deserialize, Serialize)]
884    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
885    pub enum BlockReason {
886        /// Default value. This value is unused.
887        BlockReasonUnspecified,
888        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
889        Safety,
890        /// Prompt was blocked due to unknown reasons.
891        Other,
892        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
893        Blocklist,
894        /// Prompt was blocked due to prohibited content.
895        ProhibitedContent,
896    }
897
898    #[derive(Debug, Deserialize, Serialize)]
899    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
900    pub enum FinishReason {
901        /// Default value. This value is unused.
902        FinishReasonUnspecified,
903        /// Natural stop point of the model or provided stop sequence.
904        Stop,
905        /// The maximum number of tokens as specified in the request was reached.
906        MaxTokens,
907        /// The response candidate content was flagged for safety reasons.
908        Safety,
909        /// The response candidate content was flagged for recitation reasons.
910        Recitation,
911        /// The response candidate content was flagged for using an unsupported language.
912        Language,
913        /// Unknown reason.
914        Other,
915        /// Token generation stopped because the content contains forbidden terms.
916        Blocklist,
917        /// Token generation stopped for potentially containing prohibited content.
918        ProhibitedContent,
919        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
920        Spii,
921        /// The function call generated by the model is invalid.
922        MalformedFunctionCall,
923    }
924
925    #[derive(Debug, Deserialize, Serialize)]
926    #[serde(rename_all = "camelCase")]
927    pub struct CitationMetadata {
928        pub citation_sources: Vec<CitationSource>,
929    }
930
931    #[derive(Debug, Deserialize, Serialize)]
932    #[serde(rename_all = "camelCase")]
933    pub struct CitationSource {
934        #[serde(skip_serializing_if = "Option::is_none")]
935        pub uri: Option<String>,
936        #[serde(skip_serializing_if = "Option::is_none")]
937        pub start_index: Option<i32>,
938        #[serde(skip_serializing_if = "Option::is_none")]
939        pub end_index: Option<i32>,
940        #[serde(skip_serializing_if = "Option::is_none")]
941        pub license: Option<String>,
942    }
943
944    #[derive(Debug, Deserialize, Serialize)]
945    #[serde(rename_all = "camelCase")]
946    pub struct LogprobsResult {
947        pub top_candidate: Vec<TopCandidate>,
948        pub chosen_candidate: Vec<LogProbCandidate>,
949    }
950
951    #[derive(Debug, Deserialize, Serialize)]
952    pub struct TopCandidate {
953        pub candidates: Vec<LogProbCandidate>,
954    }
955
956    #[derive(Debug, Deserialize, Serialize)]
957    #[serde(rename_all = "camelCase")]
958    pub struct LogProbCandidate {
959        pub token: String,
960        pub token_id: String,
961        pub log_probability: f64,
962    }
963
964    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
965    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
966    /// ### Rig Note:
967    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
968    #[derive(Debug, Deserialize, Serialize)]
969    #[serde(rename_all = "camelCase")]
970    pub struct GenerationConfig {
971        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
972        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
973        #[serde(skip_serializing_if = "Option::is_none")]
974        pub stop_sequences: Option<Vec<String>>,
975        /// MIME type of the generated candidate text. Supported MIME types are:
976        ///     - text/plain:  (default) Text output
977        ///     - application/json: JSON response in the response candidates.
978        ///     - text/x.enum: ENUM as a string response in the response candidates.
979        /// Refer to the docs for a list of all supported text MIME types
980        #[serde(skip_serializing_if = "Option::is_none")]
981        pub response_mime_type: Option<String>,
982        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
983        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
984        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
985        #[serde(skip_serializing_if = "Option::is_none")]
986        pub response_schema: Option<Schema>,
987        /// Number of generated responses to return. Currently, this value can only be set to 1. If
988        /// unset, this will default to 1.
989        #[serde(skip_serializing_if = "Option::is_none")]
990        pub candidate_count: Option<i32>,
991        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
992        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
993        #[serde(skip_serializing_if = "Option::is_none")]
994        pub max_output_tokens: Option<u64>,
995        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
996        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
997        #[serde(skip_serializing_if = "Option::is_none")]
998        pub temperature: Option<f64>,
999        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
1000        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
1001        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
1002        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
1003        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
1004        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
1005        #[serde(skip_serializing_if = "Option::is_none")]
1006        pub top_p: Option<f64>,
1007        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
1008        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
1009        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
1010        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
1011        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
1012        #[serde(skip_serializing_if = "Option::is_none")]
1013        pub top_k: Option<i32>,
1014        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
1015        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
1016        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
1017        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
1018        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
1019        #[serde(skip_serializing_if = "Option::is_none")]
1020        pub presence_penalty: Option<f64>,
1021        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
1022        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
1023        /// used, proportional to the number of times the token has been used: The more a token is used, the more
1024        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
1025        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
1026        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
1027        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
1028        #[serde(skip_serializing_if = "Option::is_none")]
1029        pub frequency_penalty: Option<f64>,
1030        /// If true, export the logprobs results in response.
1031        #[serde(skip_serializing_if = "Option::is_none")]
1032        pub response_logprobs: Option<bool>,
1033        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
1034        /// [Candidate.logprobs_result].
1035        #[serde(skip_serializing_if = "Option::is_none")]
1036        pub logprobs: Option<i32>,
1037        /// Configuration for thinking/reasoning.
1038        #[serde(skip_serializing_if = "Option::is_none")]
1039        pub thinking_config: Option<ThinkingConfig>,
1040    }
1041
1042    impl Default for GenerationConfig {
1043        fn default() -> Self {
1044            Self {
1045                temperature: Some(1.0),
1046                max_output_tokens: Some(4096),
1047                stop_sequences: None,
1048                response_mime_type: None,
1049                response_schema: None,
1050                candidate_count: None,
1051                top_p: None,
1052                top_k: None,
1053                presence_penalty: None,
1054                frequency_penalty: None,
1055                response_logprobs: None,
1056                logprobs: None,
1057                thinking_config: None,
1058            }
1059        }
1060    }
1061
1062    #[derive(Debug, Deserialize, Serialize)]
1063    #[serde(rename_all = "camelCase")]
1064    pub struct ThinkingConfig {
1065        pub thinking_budget: u32,
1066        pub include_thoughts: Option<bool>,
1067    }
1068    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
1069    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
1070    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
1071    #[derive(Debug, Deserialize, Serialize, Clone)]
1072    pub struct Schema {
1073        pub r#type: String,
1074        #[serde(skip_serializing_if = "Option::is_none")]
1075        pub format: Option<String>,
1076        #[serde(skip_serializing_if = "Option::is_none")]
1077        pub description: Option<String>,
1078        #[serde(skip_serializing_if = "Option::is_none")]
1079        pub nullable: Option<bool>,
1080        #[serde(skip_serializing_if = "Option::is_none")]
1081        pub r#enum: Option<Vec<String>>,
1082        #[serde(skip_serializing_if = "Option::is_none")]
1083        pub max_items: Option<i32>,
1084        #[serde(skip_serializing_if = "Option::is_none")]
1085        pub min_items: Option<i32>,
1086        #[serde(skip_serializing_if = "Option::is_none")]
1087        pub properties: Option<HashMap<String, Schema>>,
1088        #[serde(skip_serializing_if = "Option::is_none")]
1089        pub required: Option<Vec<String>>,
1090        #[serde(skip_serializing_if = "Option::is_none")]
1091        pub items: Option<Box<Schema>>,
1092    }
1093
1094    impl TryFrom<Value> for Schema {
1095        type Error = CompletionError;
1096
1097        fn try_from(value: Value) -> Result<Self, Self::Error> {
1098            if let Some(obj) = value.as_object() {
1099                Ok(Schema {
1100                    r#type: obj
1101                        .get("type")
1102                        .and_then(|v| {
1103                            if v.is_string() {
1104                                v.as_str().map(String::from)
1105                            } else if v.is_array() {
1106                                v.as_array()
1107                                    .and_then(|arr| arr.first())
1108                                    .and_then(|v| v.as_str().map(String::from))
1109                            } else {
1110                                None
1111                            }
1112                        })
1113                        .unwrap_or_default(),
1114                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1115                    description: obj
1116                        .get("description")
1117                        .and_then(|v| v.as_str())
1118                        .map(String::from),
1119                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1120                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1121                        arr.iter()
1122                            .filter_map(|v| v.as_str().map(String::from))
1123                            .collect()
1124                    }),
1125                    max_items: obj
1126                        .get("maxItems")
1127                        .and_then(|v| v.as_i64())
1128                        .map(|v| v as i32),
1129                    min_items: obj
1130                        .get("minItems")
1131                        .and_then(|v| v.as_i64())
1132                        .map(|v| v as i32),
1133                    properties: obj
1134                        .get("properties")
1135                        .and_then(|v| v.as_object())
1136                        .map(|map| {
1137                            map.iter()
1138                                .filter_map(|(k, v)| {
1139                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1140                                })
1141                                .collect()
1142                        }),
1143                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
1144                        arr.iter()
1145                            .filter_map(|v| v.as_str().map(String::from))
1146                            .collect()
1147                    }),
1148                    items: obj
1149                        .get("items")
1150                        .map(|v| Box::new(v.clone().try_into().unwrap())),
1151                })
1152            } else {
1153                Err(CompletionError::ResponseError(
1154                    "Expected a JSON object for Schema".into(),
1155                ))
1156            }
1157        }
1158    }
1159
1160    #[derive(Debug, Serialize)]
1161    #[serde(rename_all = "camelCase")]
1162    pub struct GenerateContentRequest {
1163        pub contents: Vec<Content>,
1164        #[serde(skip_serializing_if = "Option::is_none")]
1165        pub tools: Option<Tool>,
1166        pub tool_config: Option<ToolConfig>,
1167        /// Optional. Configuration options for model generation and outputs.
1168        pub generation_config: Option<GenerationConfig>,
1169        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
1170        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
1171        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
1172        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
1173        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
1174        /// will use the default safety setting for that category. Harm categories:
1175        ///     - HARM_CATEGORY_HATE_SPEECH,
1176        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
1177        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
1178        ///     - HARM_CATEGORY_HARASSMENT
1179        /// are supported.
1180        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
1181        /// to learn how to incorporate safety considerations in your AI applications.
1182        pub safety_settings: Option<Vec<SafetySetting>>,
1183        /// Optional. Developer set system instruction(s). Currently, text only.
1184        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1185        pub system_instruction: Option<Content>,
1186        // cachedContent: Optional<String>
1187        /// Additional parameters.
1188        #[serde(flatten, skip_serializing_if = "Option::is_none")]
1189        pub additional_params: Option<serde_json::Value>,
1190    }
1191
1192    #[derive(Debug, Serialize)]
1193    #[serde(rename_all = "camelCase")]
1194    pub struct Tool {
1195        pub function_declarations: Vec<FunctionDeclaration>,
1196        pub code_execution: Option<CodeExecution>,
1197    }
1198
1199    #[derive(Debug, Serialize, Clone)]
1200    #[serde(rename_all = "camelCase")]
1201    pub struct FunctionDeclaration {
1202        pub name: String,
1203        pub description: String,
1204        #[serde(skip_serializing_if = "Option::is_none")]
1205        pub parameters: Option<Schema>,
1206    }
1207
1208    #[derive(Debug, Serialize)]
1209    #[serde(rename_all = "camelCase")]
1210    pub struct ToolConfig {
1211        pub schema: Option<Schema>,
1212    }
1213
1214    #[derive(Debug, Serialize)]
1215    #[serde(rename_all = "camelCase")]
1216    pub struct CodeExecution {}
1217
1218    #[derive(Debug, Serialize)]
1219    #[serde(rename_all = "camelCase")]
1220    pub struct SafetySetting {
1221        pub category: HarmCategory,
1222        pub threshold: HarmBlockThreshold,
1223    }
1224
1225    #[derive(Debug, Serialize)]
1226    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1227    pub enum HarmBlockThreshold {
1228        HarmBlockThresholdUnspecified,
1229        BlockLowAndAbove,
1230        BlockMediumAndAbove,
1231        BlockOnlyHigh,
1232        BlockNone,
1233        Off,
1234    }
1235}
1236
// Unit tests for the Gemini wire types. They cover:
// - JSON -> `Content` deserialization.
// - rig `message::Message` -> `Content` conversion.
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    // Deserializes a user `Content` whose seven parts cover text, inline data, function call,
    // function response, file data, executable code, and code execution result. It then checks
    // the role, the part count, and the payload of each part in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // `serde_path_to_error` puts the JSON path of any failure into the panic message.
        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        // parts[0]: plain text.
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        // parts[1]: inline base64 data with its MIME type.
        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        // parts[2]: function call with name and JSON args.
        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        // parts[3]: function response with a JSON result.
        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        // parts[4]: file referenced by URI.
        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        // parts[5]: executable code.
        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        // parts[6]: code execution result output.
        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    // A single text part with role "model" deserializes to `Role::Model`.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    // A rig user text message converts to a user-role `Content` with one text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    // A rig assistant text message converts to a model-role `Content` with one text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    // An assistant tool call converts to a model-role `Content` holding one `FunctionCall`
    // part. The function name and arguments are preserved.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }
}