rig/providers/gemini/completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter Gemini 1.5 variant)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model
///
/// NOTE: Google's published model list names the 8B Gemini 1.5 variant
/// `gemini-1.5-flash-8b` (see [`GEMINI_1_5_FLASH_8B`]). This identifier is kept only for
/// backward compatibility, and requests that use it may be rejected by the API.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33    OneOrMany,
34    completion::{self, CompletionError, CompletionRequest},
35};
36use gemini_api_types::{
37    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38    GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45// =================================================================
46// Rig Implementation Types
47// =================================================================
48
49#[derive(Clone)]
50pub struct CompletionModel {
51    pub(crate) client: Client,
52    pub model: String,
53}
54
55impl CompletionModel {
56    pub fn new(client: Client, model: &str) -> Self {
57        Self {
58            client,
59            model: model.to_string(),
60        }
61    }
62}
63
64impl completion::CompletionModel for CompletionModel {
65    type Response = GenerateContentResponse;
66    type StreamingResponse = StreamingCompletionResponse;
67
68    #[cfg_attr(feature = "worker", worker::send)]
69    async fn completion(
70        &self,
71        completion_request: CompletionRequest,
72    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73        let request = create_request_body(completion_request)?;
74
75        tracing::debug!(
76            "Sending completion request to Gemini API {}",
77            serde_json::to_string_pretty(&request)?
78        );
79
80        let response = self
81            .client
82            .post(&format!("/v1beta/models/{}:generateContent", self.model))
83            .json(&request)
84            .send()
85            .await?;
86
87        if response.status().is_success() {
88            let response = response.json::<GenerateContentResponse>().await?;
89            match response.usage_metadata {
90                Some(ref usage) => tracing::info!(target: "rig",
91                "Gemini completion token usage: {}",
92                usage
93                ),
94                None => tracing::info!(target: "rig",
95                    "Gemini completion token usage: n/a",
96                ),
97            }
98
99            tracing::debug!("Received response");
100
101            Ok(completion::CompletionResponse::try_from(response))
102        } else {
103            Err(CompletionError::ProviderError(response.text().await?))
104        }?
105    }
106
107    #[cfg_attr(feature = "worker", worker::send)]
108    async fn stream(
109        &self,
110        request: CompletionRequest,
111    ) -> Result<
112        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113        CompletionError,
114    > {
115        CompletionModel::stream(self, request).await
116    }
117}
118
119pub(crate) fn create_request_body(
120    completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122    let mut full_history = Vec::new();
123    full_history.extend(completion_request.chat_history);
124
125    let additional_params = completion_request
126        .additional_params
127        .unwrap_or_else(|| Value::Object(Map::new()));
128
129    let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131    if let Some(temp) = completion_request.temperature {
132        generation_config.temperature = Some(temp);
133    }
134
135    if let Some(max_tokens) = completion_request.max_tokens {
136        generation_config.max_output_tokens = Some(max_tokens);
137    }
138
139    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140        parts: OneOrMany::one(preamble.into()),
141        role: Some(Role::Model),
142    });
143
144    let request = GenerateContentRequest {
145        contents: full_history
146            .into_iter()
147            .map(|msg| {
148                msg.try_into()
149                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
150            })
151            .collect::<Result<Vec<_>, _>>()?,
152        generation_config: Some(generation_config),
153        safety_settings: None,
154        tools: Some(Tool::try_from(completion_request.tools)?),
155        tool_config: None,
156        system_instruction,
157    };
158
159    Ok(request)
160}
161
162impl TryFrom<completion::ToolDefinition> for Tool {
163    type Error = CompletionError;
164
165    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
166        let parameters: Option<Schema> =
167            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
168                None
169            } else {
170                Some(tool.parameters.try_into()?)
171            };
172        Ok(Self {
173            function_declarations: OneOrMany::one(FunctionDeclaration {
174                name: tool.name,
175                description: tool.description,
176                parameters,
177            }),
178            code_execution: None,
179        })
180    }
181}
182
183impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
184    type Error = CompletionError;
185
186    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
187        let mut functions = Vec::new();
188
189        for tool in tools {
190            let parameters =
191                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
192                    None
193                } else {
194                    match tool.parameters.try_into() {
195                        Ok(schema) => Some(schema),
196                        Err(e) => {
197                            let emsg = format!(
198                                "Tool '{}' could not be converted to a schema: {:?}",
199                                tool.name, e,
200                            );
201                            return Err(CompletionError::ProviderError(emsg));
202                        }
203                    }
204                };
205
206            functions.push(FunctionDeclaration {
207                name: tool.name,
208                description: tool.description,
209                parameters,
210            });
211        }
212
213        let function_declarations: OneOrMany<FunctionDeclaration> = OneOrMany::many(functions)
214            .map_err(|x| CompletionError::ProviderError(x.to_string()))?;
215
216        Ok(Self {
217            function_declarations,
218            code_execution: None,
219        })
220    }
221}
222
223impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
224    type Error = CompletionError;
225
226    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
227        let candidate = response.candidates.first().ok_or_else(|| {
228            CompletionError::ResponseError("No response candidates in response".into())
229        })?;
230
231        let content = candidate
232            .content
233            .parts
234            .iter()
235            .map(|part| {
236                Ok(match part {
237                    Part::Text(text) => completion::AssistantContent::text(text),
238                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
239                        &function_call.name,
240                        &function_call.name,
241                        function_call.args.clone(),
242                    ),
243                    _ => {
244                        return Err(CompletionError::ResponseError(
245                            "Response did not contain a message or tool call".into(),
246                        ));
247                    }
248                })
249            })
250            .collect::<Result<Vec<_>, _>>()?;
251
252        let choice = OneOrMany::many(content).map_err(|_| {
253            CompletionError::ResponseError(
254                "Response contained no message or tool call (empty)".to_owned(),
255            )
256        })?;
257
258        Ok(completion::CompletionResponse {
259            choice,
260            raw_response: response,
261        })
262    }
263}
264
265pub mod gemini_api_types {
266    use std::{collections::HashMap, convert::Infallible, str::FromStr};
267
268    // =================================================================
269    // Gemini API Types
270    // =================================================================
271    use serde::{Deserialize, Serialize};
272    use serde_json::{Value, json};
273
274    use crate::{
275        OneOrMany,
276        completion::CompletionError,
277        message::{self, MessageError, MimeType as _},
278        one_or_many::string_or_one_or_many,
279        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
280    };
281
282    /// Response from the model supporting multiple candidate responses.
283    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
284    /// and for each candidate in finishReason and in safetyRatings.
285    /// The API:
286    ///     - Returns either all requested candidates or none of them
287    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
288    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
289    #[derive(Debug, Deserialize)]
290    #[serde(rename_all = "camelCase")]
291    pub struct GenerateContentResponse {
292        /// Candidate responses from the model.
293        pub candidates: Vec<ContentCandidate>,
294        /// Returns the prompt's feedback related to the content filters.
295        pub prompt_feedback: Option<PromptFeedback>,
296        /// Output only. Metadata on the generation requests' token usage.
297        pub usage_metadata: Option<UsageMetadata>,
298        pub model_version: Option<String>,
299    }
300
    /// A response candidate generated from the model.
    ///
    /// Fields are read from camelCase JSON keys (e.g. `finishReason`, `safetyRatings`).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        // NOTE(review): this field is required, so a candidate that arrives without
        // `content` fails deserialization of the whole response. Confirm whether the API
        // can omit it (e.g. for safety-blocked candidates).
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only. Read from `avgLogprobs`; presumably the average log-probability of
        /// the candidate's tokens. TODO confirm against the API reference.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }
    /// A single multi-part message.
    ///
    /// In this file it is used for request history, for the system instruction, and for
    /// candidate output.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        // Deserialized through `string_or_one_or_many`, which (judging by its name) also
        // accepts a bare string in place of a part list. TODO confirm.
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        // There is no `skip_serializing_if`, so `None` is serialized as `"role": null`.
        pub role: Option<Role>,
    }
335
336    impl TryFrom<message::Message> for Content {
337        type Error = message::MessageError;
338
339        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
340            Ok(match msg {
341                message::Message::User { content } => Content {
342                    parts: content.try_map(|c| c.try_into())?,
343                    role: Some(Role::User),
344                },
345                message::Message::Assistant { content, .. } => Content {
346                    role: Some(Role::Model),
347                    parts: content.map(|content| content.into()),
348                },
349            })
350        }
351    }
352
353    impl TryFrom<Content> for message::Message {
354        type Error = message::MessageError;
355
356        fn try_from(content: Content) -> Result<Self, Self::Error> {
357            match content.role {
358                Some(Role::User) | None => Ok(message::Message::User {
359                    content: content.parts.try_map(|part| {
360                        Ok(match part {
361                            Part::Text(text) => message::UserContent::text(text),
362                            Part::InlineData(inline_data) => {
363                                let mime_type =
364                                    message::MediaType::from_mime_type(&inline_data.mime_type);
365
366                                match mime_type {
367                                    Some(message::MediaType::Image(media_type)) => {
368                                        message::UserContent::image(
369                                            inline_data.data,
370                                            Some(message::ContentFormat::default()),
371                                            Some(media_type),
372                                            Some(message::ImageDetail::default()),
373                                        )
374                                    }
375                                    Some(message::MediaType::Document(media_type)) => {
376                                        message::UserContent::document(
377                                            inline_data.data,
378                                            Some(message::ContentFormat::default()),
379                                            Some(media_type),
380                                        )
381                                    }
382                                    Some(message::MediaType::Audio(media_type)) => {
383                                        message::UserContent::audio(
384                                            inline_data.data,
385                                            Some(message::ContentFormat::default()),
386                                            Some(media_type),
387                                        )
388                                    }
389                                    _ => {
390                                        return Err(message::MessageError::ConversionError(
391                                            format!("Unsupported media type {mime_type:?}"),
392                                        ));
393                                    }
394                                }
395                            }
396                            _ => {
397                                return Err(message::MessageError::ConversionError(format!(
398                                    "Unsupported gemini content part type: {part:?}"
399                                )));
400                            }
401                        })
402                    })?,
403                }),
404                Some(Role::Model) => Ok(message::Message::Assistant {
405                    id: None,
406                    content: content.parts.try_map(|part| {
407                        Ok(match part {
408                            Part::Text(text) => message::AssistantContent::text(text),
409                            Part::FunctionCall(function_call) => {
410                                message::AssistantContent::ToolCall(function_call.into())
411                            }
412                            _ => {
413                                return Err(message::MessageError::ConversionError(format!(
414                                    "Unsupported part type: {part:?}"
415                                )));
416                            }
417                        })
418                    })?,
419                }),
420            }
421        }
422    }
423
    /// Author of a [`Content`] entry, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// User turn. Rig user messages convert to and from this role.
        User,
        /// Model turn. Rig assistant messages convert to and from this role; it is also
        /// used for the request's system instruction.
        Model,
    }
430
    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    ///
    /// Uses serde's default externally tagged representation with camelCased variant
    /// names. Each part is therefore a single-key JSON object, such as `{"text": "..."}`
    /// or `{"inlineData": {...}}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        /// Plain text.
        Text(String),
        /// Inline media data, see [`Blob`].
        InlineData(Blob),
        /// A function call predicted by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model.
        FunctionResponse(FunctionResponse),
        /// Media referenced by URI, see [`FileData`].
        FileData(FileData),
        /// Code-execution payload (type imported from outside this view).
        ExecutableCode(ExecutableCode),
        /// Code-execution result (type imported from outside this view).
        CodeExecutionResult(CodeExecutionResult),
    }
445
446    impl From<String> for Part {
447        fn from(text: String) -> Self {
448            Self::Text(text)
449        }
450    }
451
452    impl From<&str> for Part {
453        fn from(text: &str) -> Self {
454            Self::Text(text.to_string())
455        }
456    }
457
458    impl FromStr for Part {
459        type Err = Infallible;
460
461        fn from_str(s: &str) -> Result<Self, Self::Err> {
462            Ok(s.into())
463        }
464    }
465
466    impl TryFrom<message::UserContent> for Part {
467        type Error = message::MessageError;
468
469        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
470            match content {
471                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
472                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
473                    let content = match content.first() {
474                        message::ToolResultContent::Text(text) => text.text,
475                        message::ToolResultContent::Image(_) => {
476                            return Err(message::MessageError::ConversionError(
477                                "Tool result content must be text".to_string(),
478                            ));
479                        }
480                    };
481                    // Convert to JSON since this value may be a valid JSON value
482                    let result: serde_json::Value = serde_json::from_str(&content)
483                        .map_err(|x| MessageError::ConversionError(x.to_string()))?;
484                    Ok(Part::FunctionResponse(FunctionResponse {
485                        name: id,
486                        response: Some(json!({ "result": result })),
487                    }))
488                }
489                message::UserContent::Image(message::Image {
490                    data, media_type, ..
491                }) => match media_type {
492                    Some(media_type) => match media_type {
493                        message::ImageMediaType::JPEG
494                        | message::ImageMediaType::PNG
495                        | message::ImageMediaType::WEBP
496                        | message::ImageMediaType::HEIC
497                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
498                            mime_type: media_type.to_mime_type().to_owned(),
499                            data,
500                        })),
501                        _ => Err(message::MessageError::ConversionError(format!(
502                            "Unsupported image media type {media_type:?}"
503                        ))),
504                    },
505                    None => Err(message::MessageError::ConversionError(
506                        "Media type for image is required for Anthropic".to_string(),
507                    )),
508                },
509                message::UserContent::Document(message::Document {
510                    data, media_type, ..
511                }) => match media_type {
512                    Some(media_type) => match media_type {
513                        message::DocumentMediaType::PDF
514                        | message::DocumentMediaType::TXT
515                        | message::DocumentMediaType::RTF
516                        | message::DocumentMediaType::HTML
517                        | message::DocumentMediaType::CSS
518                        | message::DocumentMediaType::MARKDOWN
519                        | message::DocumentMediaType::CSV
520                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
521                            mime_type: media_type.to_mime_type().to_owned(),
522                            data,
523                        })),
524                        _ => Err(message::MessageError::ConversionError(format!(
525                            "Unsupported document media type {media_type:?}"
526                        ))),
527                    },
528                    None => Err(message::MessageError::ConversionError(
529                        "Media type for document is required for Anthropic".to_string(),
530                    )),
531                },
532                message::UserContent::Audio(message::Audio {
533                    data, media_type, ..
534                }) => match media_type {
535                    Some(media_type) => Ok(Self::InlineData(Blob {
536                        mime_type: media_type.to_mime_type().to_owned(),
537                        data,
538                    })),
539                    None => Err(message::MessageError::ConversionError(
540                        "Media type for audio is required for Anthropic".to_string(),
541                    )),
542                },
543            }
544        }
545    }
546
547    impl From<message::AssistantContent> for Part {
548        fn from(content: message::AssistantContent) -> Self {
549            match content {
550                message::AssistantContent::Text(message::Text { text }) => text.into(),
551                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
552            }
553        }
554    }
555
556    impl From<message::ToolCall> for Part {
557        fn from(tool_call: message::ToolCall) -> Self {
558            Self::FunctionCall(FunctionCall {
559                name: tool_call.function.name,
560                args: tool_call.function.arguments,
561            })
562        }
563    }
564
    /// Raw media bytes.
    /// Text should not be sent as raw bytes, use the 'text' field.
    ///
    /// Serialized with camelCase keys: `{"mimeType": ..., "data": ...}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }
576
577    /// A predicted FunctionCall returned from the model that contains a string representing the
578    /// FunctionDeclaration.name with the arguments and their values.
579    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
580    pub struct FunctionCall {
581        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
582        /// and dashes, with a maximum length of 63.
583        pub name: String,
584        /// Optional. The function parameters and values in JSON object format.
585        pub args: serde_json::Value,
586    }
587
588    impl From<FunctionCall> for message::ToolCall {
589        fn from(function_call: FunctionCall) -> Self {
590            Self {
591                id: function_call.name.clone(),
592                call_id: None,
593                function: message::ToolFunction {
594                    name: function_call.name,
595                    arguments: function_call.args,
596                },
597            }
598        }
599    }
600
601    impl From<message::ToolCall> for FunctionCall {
602        fn from(tool_call: message::ToolCall) -> Self {
603            Self {
604                name: tool_call.function.name,
605                args: tool_call.function.arguments,
606            }
607        }
608    }
609
    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    ///
    /// In this file, rig tool results are converted into this type with `name` taken from
    /// the tool-call id and `response` wrapped as `{"result": ...}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        // There is no `skip_serializing_if`, so `None` is serialized as `"response": null`.
        pub response: Option<serde_json::Value>,
    }
621
    /// URI based data.
    ///
    /// Serialized with camelCase keys: `{"mimeType": ..., "fileUri": ...}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }
631
    /// Safety rating for one harm category, reported for the prompt or for a candidate.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// Harm category this rating applies to.
        pub category: HarmCategory,
        /// Probability that the content falls into `category`.
        pub probability: HarmProbability,
    }

    /// Harm probability level, serialized in SCREAMING_SNAKE_CASE (e.g. `"NEGLIGIBLE"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    /// Harm category, serialized in SCREAMING_SNAKE_CASE (e.g. `"HARM_CATEGORY_HATE_SPEECH"`).
    // NOTE(review): there is no catch-all variant, so a category string not listed here
    // fails deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
664
665    #[derive(Debug, Deserialize, Clone, Default)]
666    #[serde(rename_all = "camelCase")]
667    pub struct UsageMetadata {
668        pub prompt_token_count: i32,
669        #[serde(skip_serializing_if = "Option::is_none")]
670        pub cached_content_token_count: Option<i32>,
671        pub candidates_token_count: i32,
672        pub total_token_count: i32,
673    }
674
675    impl std::fmt::Display for UsageMetadata {
676        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677            write!(
678                f,
679                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
680                self.prompt_token_count,
681                match self.cached_content_token_count {
682                    Some(count) => count.to_string(),
683                    None => "n/a".to_string(),
684                },
685                self.candidates_token_count,
686                self.total_token_count
687            )
688        }
689    }
690
    /// Feedback metadata for the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    ///
    /// Read from camelCase keys (`blockReason`, `safetyRatings`).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Reason why a prompt was blocked by the model.
    ///
    /// Serialized in SCREAMING_SNAKE_CASE (e.g. `"PROHIBITED_CONTENT"`).
    // NOTE(review): there is no catch-all variant, so a block reason not listed here
    // fails deserialization of the enclosing response.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }
716
    /// Reason the model stopped generating tokens for a candidate (`finishReason`).
    ///
    /// Serialized in SCREAMING_SNAKE_CASE (e.g. `"MAX_TOKENS"`).
    // NOTE(review): there is no catch-all variant, so a finish reason not listed below
    // fails deserialization of the whole response. Confirm against current API values.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
743
    /// Source attributions for the content of a response candidate.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Citations to sources (`citationSources` on the wire).
        // NOTE(review): this field is required, so a `citationMetadata` object without
        // `citationSources` fails to deserialize — confirm the API always includes it.
        pub citation_sources: Vec<CitationSource>,
    }
749
750    #[derive(Debug, Deserialize)]
751    #[serde(rename_all = "camelCase")]
752    pub struct CitationSource {
753        #[serde(skip_serializing_if = "Option::is_none")]
754        pub uri: Option<String>,
755        #[serde(skip_serializing_if = "Option::is_none")]
756        pub start_index: Option<i32>,
757        #[serde(skip_serializing_if = "Option::is_none")]
758        pub end_index: Option<i32>,
759        #[serde(skip_serializing_if = "Option::is_none")]
760        pub license: Option<String>,
761    }
762
763    #[derive(Debug, Deserialize)]
764    #[serde(rename_all = "camelCase")]
765    pub struct LogprobsResult {
766        pub top_candidate: Vec<TopCandidate>,
767        pub chosen_candidate: Vec<LogProbCandidate>,
768    }
769
    /// The candidate tokens considered at one decoding step, as listed in
    /// [`LogprobsResult::top_candidate`].
    #[derive(Debug, Deserialize)]
    pub struct TopCandidate {
        /// Candidates at this step — presumably sorted by log probability, descending; confirm
        /// against the API reference.
        pub candidates: Vec<LogProbCandidate>,
    }
774
775    #[derive(Debug, Deserialize)]
776    #[serde(rename_all = "camelCase")]
777    pub struct LogProbCandidate {
778        pub token: String,
779        pub token_id: String,
780        pub log_probability: f64,
781    }
782
    /// Gemini API configuration options for model generation and outputs. Not all parameters are
    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
    /// ### Rig Note:
    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
    ///
    /// Every field is optional; fields left as `None` are omitted from the serialized request.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
        /// at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated candidate text. Supported MIME types are:
        ///     - text/plain:  (default) Text output
        ///     - application/json: JSON response in the response candidates.
        ///     - text/x.enum: ENUM as a string response in the response candidates.
        /// Refer to the docs for a list of all supported text MIME types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of generated responses to return. Currently, this value can only be set to 1. If
        /// unset, this will default to 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
        /// used, proportional to the number of times the token has been used: the more a token is used, the more
        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// If true, export the logprobs results in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
        /// the candidate's `logprobsResult`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
    }
857
858    impl Default for GenerationConfig {
859        fn default() -> Self {
860            Self {
861                temperature: Some(1.0),
862                max_output_tokens: Some(4096),
863                stop_sequences: None,
864                response_mime_type: None,
865                response_schema: None,
866                candidate_count: None,
867                top_p: None,
868                top_k: None,
869                presence_penalty: None,
870                frequency_penalty: None,
871                response_logprobs: None,
872                logprobs: None,
873            }
874        }
875    }
    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
    // NOTE(review): unlike the other request types in this module, this struct has no
    // `rename_all = "camelCase"`, so `max_items`/`min_items` are serialized (and deserialized) in
    // snake_case while the API reference spells them `maxItems`/`minItems`. Proto3 JSON parsers
    // normally accept both spellings — confirm before changing, as it would alter the serde format.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Data type name, e.g. `"string"` or `"object"`.
        pub r#type: String,
        /// Optional format hint for the data type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Optional human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values for an enumerated string.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of elements of an array schema.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of elements of an array schema.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Properties of an object schema, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the required properties of an object schema.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of the elements of an array schema.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
901
902    impl TryFrom<Value> for Schema {
903        type Error = CompletionError;
904
905        fn try_from(value: Value) -> Result<Self, Self::Error> {
906            if let Some(obj) = value.as_object() {
907                Ok(Schema {
908                    r#type: obj
909                        .get("type")
910                        .and_then(|v| {
911                            if v.is_string() {
912                                v.as_str().map(String::from)
913                            } else if v.is_array() {
914                                v.as_array()
915                                    .and_then(|arr| arr.first())
916                                    .and_then(|v| v.as_str().map(String::from))
917                            } else {
918                                None
919                            }
920                        })
921                        .unwrap_or_default(),
922                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
923                    description: obj
924                        .get("description")
925                        .and_then(|v| v.as_str())
926                        .map(String::from),
927                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
928                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
929                        arr.iter()
930                            .filter_map(|v| v.as_str().map(String::from))
931                            .collect()
932                    }),
933                    max_items: obj
934                        .get("maxItems")
935                        .and_then(|v| v.as_i64())
936                        .map(|v| v as i32),
937                    min_items: obj
938                        .get("minItems")
939                        .and_then(|v| v.as_i64())
940                        .map(|v| v as i32),
941                    properties: obj
942                        .get("properties")
943                        .and_then(|v| v.as_object())
944                        .map(|map| {
945                            map.iter()
946                                .filter_map(|(k, v)| {
947                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
948                                })
949                                .collect()
950                        }),
951                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
952                        arr.iter()
953                            .filter_map(|v| v.as_str().map(String::from))
954                            .collect()
955                    }),
956                    items: obj
957                        .get("items")
958                        .map(|v| Box::new(v.clone().try_into().unwrap())),
959                })
960            } else {
961                Err(CompletionError::ResponseError(
962                    "Expected a JSON object for Schema".into(),
963                ))
964            }
965        }
966    }
967
    /// Request body for the Gemini `generateContent` endpoint.
    /// From [Gemini API Reference](https://ai.google.dev/api/generate-content)
    ///
    /// None of the fields have `skip_serializing_if`, so a `None` field is serialized as `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation contents sent to the model.
        pub contents: Vec<Content>,
        /// Tools available to the model.
        // NOTE(review): the API reference documents `tools` as a list of Tool objects, while this
        // serializes a single object — confirm the endpoint accepts this form.
        pub tools: Option<Tool>,
        /// Optional tool configuration.
        pub tool_config: Option<ToolConfig>,
        /// Optional. Configuration options for model generation and outputs.
        pub generation_config: Option<GenerationConfig>,
        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
        /// will use the default safety setting for that category. Harm categories:
        ///     - HARM_CATEGORY_HATE_SPEECH
        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
        ///     - HARM_CATEGORY_HARASSMENT
        /// are supported.
        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
        /// to learn how to incorporate safety considerations in your AI applications.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// Optional. Developer set system instruction(s). Currently, text only.
        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
        pub system_instruction: Option<Content>,
        // Not yet supported: `cachedContent: Option<String>`.
    }
995
    /// Tools the model may use while generating a response.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions the model may call (`functionDeclarations` on the wire).
        pub function_declarations: OneOrMany<FunctionDeclaration>,
        /// Enables the code-execution tool when `Some`. It is serialized as `null` when `None`,
        /// since there is no `skip_serializing_if`.
        pub code_execution: Option<CodeExecution>,
    }
1002
    /// Declaration of a function the model may call, with parameters described by a [`Schema`].
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name, as the model will refer to it in a function call.
        pub name: String,
        /// Human-readable description of the function, sent to the model.
        pub description: String,
        /// Parameter schema; omitted from the request when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1011
    /// Tool configuration sent with a request.
    // NOTE(review): the Gemini API reference documents `toolConfig` as carrying
    // `functionCallingConfig`. A `schema` field does not appear there — confirm before relying on
    // this type.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Serialized as `null` when `None`, since there is no `skip_serializing_if`.
        pub schema: Option<Schema>,
    }
1017
    /// Marker that enables the code-execution tool. It has no options and serializes as `{}`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
1021
    /// A safety setting: the blocking threshold applied to one harm category.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
1028
    /// Blocking threshold used in a [`SafetySetting`].
    ///
    /// Serialized in `SCREAMING_SNAKE_CASE` (e.g. `BLOCK_MEDIUM_AND_ABOVE`) to match the API's enum
    /// values. See the Gemini API reference for the exact meaning of each level.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
1039}
1040
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user message with one part of every kind must deserialize into the matching `Part`
    /// variants, in order, with their payloads intact.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error pinpoints which JSON path failed if deserialization breaks.
        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        let Part::Text(text) = &parts[0] else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, world!");

        let Part::InlineData(inline_data) = &parts[1] else {
            panic!("Expected inline data part");
        };
        assert_eq!(inline_data.mime_type, "image/png");
        assert_eq!(inline_data.data, "base64encodeddata");

        let Part::FunctionCall(function_call) = &parts[2] else {
            panic!("Expected function call part");
        };
        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.args.as_object().unwrap().get("arg1").unwrap(),
            "value1"
        );

        let Part::FunctionResponse(function_response) = &parts[3] else {
            panic!("Expected function response part");
        };
        assert_eq!(function_response.name, "test_function");
        assert_eq!(
            function_response
                .response
                .as_ref()
                .unwrap()
                .get("result")
                .unwrap(),
            "success"
        );

        let Part::FileData(file_data) = &parts[4] else {
            panic!("Expected file data part");
        };
        assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
        assert_eq!(file_data.file_uri, "http://example.com/file.pdf");

        let Part::ExecutableCode(executable_code) = &parts[5] else {
            panic!("Expected executable code part");
        };
        assert_eq!(executable_code.code, "print('Hello, world!')");

        let Part::CodeExecutionResult(code_execution_result) = &parts[6] else {
            panic!("Expected code execution result part");
        };
        assert_eq!(
            code_execution_result.clone().output.unwrap(),
            "Hello, world!"
        );
    }

    /// A `"model"` role message deserializes with `Role::Model` and a single text part.
    #[test]
    fn test_deserialize_message_model() {
        let content: Content = serde_json::from_value(json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        }))
        .unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let first = content.parts.first();
        let Part::Text(text) = &first else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, user!");
    }

    /// A rig user message converts into Gemini content with the user role and one text part.
    #[test]
    fn test_message_conversion_user() {
        let content: Content = message::Message::user("Hello, world!").try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);

        let first = content.parts.first();
        let Part::Text(text) = &first else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, world!");
    }

    /// A rig assistant message converts into Gemini content with the model role.
    #[test]
    fn test_message_conversion_model() {
        let content: Content = message::Message::assistant("Hello, user!")
            .try_into()
            .unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let first = content.parts.first();
        let Part::Text(text) = &first else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, user!");
    }

    /// An assistant tool call converts into a Gemini function-call part carrying its arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let first = content.parts.first();
        let Part::FunctionCall(function_call) = &first else {
            panic!("Expected function call part");
        };
        assert_eq!(function_call.name, "test_function");
        assert_eq!(
            function_call.args.as_object().unwrap().get("arg1").unwrap(),
            "value1"
        );
    }
}