rig/providers/gemini/
completion.rs

1// ================================================================
2//! Google Gemini Completion Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
4// ================================================================
/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter Gemini 1.5 variant)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model.
///
/// NOTE: Google publishes the 8B Gemini 1.5 model as `gemini-1.5-flash-8b`
/// (see [`GEMINI_1_5_FLASH_8B`]). This identifier is kept unchanged for backward
/// compatibility, but requests using it are likely to be rejected by the API.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33    OneOrMany,
34    completion::{self, CompletionError, CompletionRequest},
35};
36use gemini_api_types::{
37    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38    GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45// =================================================================
46// Rig Implementation Types
47// =================================================================
48
49#[derive(Clone)]
50pub struct CompletionModel {
51    pub(crate) client: Client,
52    pub model: String,
53}
54
55impl CompletionModel {
56    pub fn new(client: Client, model: &str) -> Self {
57        Self {
58            client,
59            model: model.to_string(),
60        }
61    }
62}
63
64impl completion::CompletionModel for CompletionModel {
65    type Response = GenerateContentResponse;
66    type StreamingResponse = StreamingCompletionResponse;
67
68    #[cfg_attr(feature = "worker", worker::send)]
69    async fn completion(
70        &self,
71        completion_request: CompletionRequest,
72    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73        let request = create_request_body(completion_request)?;
74
75        tracing::debug!(
76            "Sending completion request to Gemini API {}",
77            serde_json::to_string_pretty(&request)?
78        );
79
80        let response = self
81            .client
82            .post(&format!("/v1beta/models/{}:generateContent", self.model))
83            .json(&request)
84            .send()
85            .await?;
86
87        if response.status().is_success() {
88            let response = response.json::<GenerateContentResponse>().await?;
89            match response.usage_metadata {
90                Some(ref usage) => tracing::info!(target: "rig",
91                "Gemini completion token usage: {}",
92                usage
93                ),
94                None => tracing::info!(target: "rig",
95                    "Gemini completion token usage: n/a",
96                ),
97            }
98
99            tracing::debug!("Received response");
100
101            Ok(completion::CompletionResponse::try_from(response))
102        } else {
103            Err(CompletionError::ProviderError(response.text().await?))
104        }?
105    }
106
107    #[cfg_attr(feature = "worker", worker::send)]
108    async fn stream(
109        &self,
110        request: CompletionRequest,
111    ) -> Result<
112        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113        CompletionError,
114    > {
115        CompletionModel::stream(self, request).await
116    }
117}
118
119pub(crate) fn create_request_body(
120    completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122    let mut full_history = Vec::new();
123    full_history.extend(completion_request.chat_history);
124
125    let additional_params = completion_request
126        .additional_params
127        .unwrap_or_else(|| Value::Object(Map::new()));
128
129    let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131    if let Some(temp) = completion_request.temperature {
132        generation_config.temperature = Some(temp);
133    }
134
135    if let Some(max_tokens) = completion_request.max_tokens {
136        generation_config.max_output_tokens = Some(max_tokens);
137    }
138
139    let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140        parts: OneOrMany::one(preamble.into()),
141        role: Some(Role::Model),
142    });
143
144    let request = GenerateContentRequest {
145        contents: full_history
146            .into_iter()
147            .map(|msg| {
148                msg.try_into()
149                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
150            })
151            .collect::<Result<Vec<_>, _>>()?,
152        generation_config: Some(generation_config),
153        safety_settings: None,
154        tools: Some(Tool::try_from(completion_request.tools)?),
155        tool_config: None,
156        system_instruction,
157    };
158
159    Ok(request)
160}
161
162impl TryFrom<completion::ToolDefinition> for Tool {
163    type Error = CompletionError;
164
165    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
166        let parameters: Option<Schema> =
167            if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
168                None
169            } else {
170                Some(tool.parameters.try_into()?)
171            };
172
173        Ok(Self {
174            function_declarations: vec![FunctionDeclaration {
175                name: tool.name,
176                description: tool.description,
177                parameters,
178            }],
179            code_execution: None,
180        })
181    }
182}
183
184impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
185    type Error = CompletionError;
186
187    fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
188        let mut function_declarations = Vec::new();
189
190        for tool in tools {
191            let parameters =
192                if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
193                    None
194                } else {
195                    match tool.parameters.try_into() {
196                        Ok(schema) => Some(schema),
197                        Err(e) => {
198                            let emsg = format!(
199                                "Tool '{}' could not be converted to a schema: {:?}",
200                                tool.name, e,
201                            );
202                            return Err(CompletionError::ProviderError(emsg));
203                        }
204                    }
205                };
206
207            function_declarations.push(FunctionDeclaration {
208                name: tool.name,
209                description: tool.description,
210                parameters,
211            });
212        }
213
214        Ok(Self {
215            function_declarations,
216            code_execution: None,
217        })
218    }
219}
220
221impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
222    type Error = CompletionError;
223
224    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
225        let candidate = response.candidates.first().ok_or_else(|| {
226            CompletionError::ResponseError("No response candidates in response".into())
227        })?;
228
229        let content = candidate
230            .content
231            .parts
232            .iter()
233            .map(|part| {
234                Ok(match part {
235                    Part::Text(text) => completion::AssistantContent::text(text),
236                    Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
237                        &function_call.name,
238                        &function_call.name,
239                        function_call.args.clone(),
240                    ),
241                    _ => {
242                        return Err(CompletionError::ResponseError(
243                            "Response did not contain a message or tool call".into(),
244                        ));
245                    }
246                })
247            })
248            .collect::<Result<Vec<_>, _>>()?;
249
250        let choice = OneOrMany::many(content).map_err(|_| {
251            CompletionError::ResponseError(
252                "Response contained no message or tool call (empty)".to_owned(),
253            )
254        })?;
255
256        let usage = response
257            .usage_metadata
258            .as_ref()
259            .map(|usage| completion::Usage {
260                input_tokens: usage.prompt_token_count as u64,
261                output_tokens: usage.candidates_token_count as u64,
262                total_tokens: usage.total_token_count as u64,
263            })
264            .unwrap_or_default();
265
266        Ok(completion::CompletionResponse {
267            choice,
268            usage,
269            raw_response: response,
270        })
271    }
272}
273
274pub mod gemini_api_types {
275    use std::{collections::HashMap, convert::Infallible, str::FromStr};
276
277    // =================================================================
278    // Gemini API Types
279    // =================================================================
280    use serde::{Deserialize, Serialize};
281    use serde_json::{Value, json};
282
283    use crate::{
284        OneOrMany,
285        completion::CompletionError,
286        message::{self, MessageError, MimeType as _},
287        one_or_many::string_or_one_or_many,
288        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
289    };
290
291    /// Response from the model supporting multiple candidate responses.
292    /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback
293    /// and for each candidate in finishReason and in safetyRatings.
294    /// The API:
295    ///     - Returns either all requested candidates or none of them
296    ///     - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
297    ///     - Reports feedback on each candidate in finishReason and safetyRatings.
298    #[derive(Debug, Deserialize)]
299    #[serde(rename_all = "camelCase")]
300    pub struct GenerateContentResponse {
301        /// Candidate responses from the model.
302        pub candidates: Vec<ContentCandidate>,
303        /// Returns the prompt's feedback related to the content filters.
304        pub prompt_feedback: Option<PromptFeedback>,
305        /// Output only. Metadata on the generation requests' token usage.
306        pub usage_metadata: Option<UsageMetadata>,
307        pub model_version: Option<String>,
308    }
309
    /// A response candidate generated from the model.
    ///
    /// NOTE(review): `content` is a required field here. A candidate returned without a
    /// `content` object would make the whole response fail to deserialize — confirm whether the
    /// API can do that (e.g. for safety-terminated candidates).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Output only. Generated content returned from the model.
        pub content: Content,
        /// Optional. Output only. The reason why the model stopped generating tokens.
        /// If empty, the model has not stopped generating tokens.
        pub finish_reason: Option<FinishReason>,
        /// List of ratings for the safety of a response candidate.
        /// There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Output only. Citation information for model-generated candidate.
        /// This field may be populated with recitation information for any text included in the content.
        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
        pub citation_metadata: Option<CitationMetadata>,
        /// Output only. Token count for this candidate.
        pub token_count: Option<i32>,
        /// Output only. Presumably the average log probability of the candidate's tokens
        /// (per the field name) — confirm against the API reference.
        pub avg_logprobs: Option<f64>,
        /// Output only. Log-likelihood scores for the response tokens and top tokens.
        pub logprobs_result: Option<LogprobsResult>,
        /// Output only. Index of the candidate in the list of response candidates.
        pub index: Option<i32>,
    }

    /// One conversational turn: an ordered, non-empty list of [`Part`]s plus the producing role.
    ///
    /// `role: None` is serialized as an explicit `null` (there is no `skip_serializing_if`).
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
        /// Deserialized through `string_or_one_or_many`. Judging by its name, that helper also
        /// accepts a bare string or a single part — TODO confirm.
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// The producer of the content. Must be either 'user' or 'model'.
        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
        pub role: Option<Role>,
    }
344
345    impl TryFrom<message::Message> for Content {
346        type Error = message::MessageError;
347
348        fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
349            Ok(match msg {
350                message::Message::User { content } => Content {
351                    parts: content.try_map(|c| c.try_into())?,
352                    role: Some(Role::User),
353                },
354                message::Message::Assistant { content, .. } => Content {
355                    role: Some(Role::Model),
356                    parts: content.map(|content| content.into()),
357                },
358            })
359        }
360    }
361
362    impl TryFrom<Content> for message::Message {
363        type Error = message::MessageError;
364
365        fn try_from(content: Content) -> Result<Self, Self::Error> {
366            match content.role {
367                Some(Role::User) | None => Ok(message::Message::User {
368                    content: content.parts.try_map(|part| {
369                        Ok(match part {
370                            Part::Text(text) => message::UserContent::text(text),
371                            Part::InlineData(inline_data) => {
372                                let mime_type =
373                                    message::MediaType::from_mime_type(&inline_data.mime_type);
374
375                                match mime_type {
376                                    Some(message::MediaType::Image(media_type)) => {
377                                        message::UserContent::image(
378                                            inline_data.data,
379                                            Some(message::ContentFormat::default()),
380                                            Some(media_type),
381                                            Some(message::ImageDetail::default()),
382                                        )
383                                    }
384                                    Some(message::MediaType::Document(media_type)) => {
385                                        message::UserContent::document(
386                                            inline_data.data,
387                                            Some(message::ContentFormat::default()),
388                                            Some(media_type),
389                                        )
390                                    }
391                                    Some(message::MediaType::Audio(media_type)) => {
392                                        message::UserContent::audio(
393                                            inline_data.data,
394                                            Some(message::ContentFormat::default()),
395                                            Some(media_type),
396                                        )
397                                    }
398                                    _ => {
399                                        return Err(message::MessageError::ConversionError(
400                                            format!("Unsupported media type {mime_type:?}"),
401                                        ));
402                                    }
403                                }
404                            }
405                            _ => {
406                                return Err(message::MessageError::ConversionError(format!(
407                                    "Unsupported gemini content part type: {part:?}"
408                                )));
409                            }
410                        })
411                    })?,
412                }),
413                Some(Role::Model) => Ok(message::Message::Assistant {
414                    id: None,
415                    content: content.parts.try_map(|part| {
416                        Ok(match part {
417                            Part::Text(text) => message::AssistantContent::text(text),
418                            Part::FunctionCall(function_call) => {
419                                message::AssistantContent::ToolCall(function_call.into())
420                            }
421                            _ => {
422                                return Err(message::MessageError::ConversionError(format!(
423                                    "Unsupported part type: {part:?}"
424                                )));
425                            }
426                        })
427                    })?,
428                }),
429            }
430        }
431    }
432
    /// Producer of a [`Content`] turn, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content supplied by the caller.
        User,
        /// Content produced by the model.
        Model,
    }

    /// A datatype containing media that is part of a multi-part [Content] message.
    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
    ///
    /// Uses serde's default externally tagged representation with camelCase keys, e.g.
    /// `{"text": "..."}` or `{"inlineData": {...}}`.
    /// NOTE(review): with that representation, a part object carrying extra sibling keys next to
    /// its data key will not deserialize — confirm against current API responses.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        /// Plain text (`"text"`).
        Text(String),
        /// Base64 media bytes with a MIME type (`"inlineData"`).
        InlineData(Blob),
        /// A function call predicted by the model (`"functionCall"`).
        FunctionCall(FunctionCall),
        /// The result of a function call sent back to the model (`"functionResponse"`).
        FunctionResponse(FunctionResponse),
        /// URI-referenced data (`"fileData"`).
        FileData(FileData),
        /// Code emitted for the code-execution tool (`"executableCode"`).
        ExecutableCode(ExecutableCode),
        /// Output of executed code (`"codeExecutionResult"`).
        CodeExecutionResult(CodeExecutionResult),
    }
454
455    impl From<String> for Part {
456        fn from(text: String) -> Self {
457            Self::Text(text)
458        }
459    }
460
461    impl From<&str> for Part {
462        fn from(text: &str) -> Self {
463            Self::Text(text.to_string())
464        }
465    }
466
467    impl FromStr for Part {
468        type Err = Infallible;
469
470        fn from_str(s: &str) -> Result<Self, Self::Err> {
471            Ok(s.into())
472        }
473    }
474
475    impl TryFrom<message::UserContent> for Part {
476        type Error = message::MessageError;
477
478        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
479            match content {
480                message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
481                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
482                    let content = match content.first() {
483                        message::ToolResultContent::Text(text) => text.text,
484                        message::ToolResultContent::Image(_) => {
485                            return Err(message::MessageError::ConversionError(
486                                "Tool result content must be text".to_string(),
487                            ));
488                        }
489                    };
490                    // Convert to JSON since this value may be a valid JSON value
491                    let result: serde_json::Value = serde_json::from_str(&content)
492                        .map_err(|x| MessageError::ConversionError(x.to_string()))?;
493                    Ok(Part::FunctionResponse(FunctionResponse {
494                        name: id,
495                        response: Some(json!({ "result": result })),
496                    }))
497                }
498                message::UserContent::Image(message::Image {
499                    data, media_type, ..
500                }) => match media_type {
501                    Some(media_type) => match media_type {
502                        message::ImageMediaType::JPEG
503                        | message::ImageMediaType::PNG
504                        | message::ImageMediaType::WEBP
505                        | message::ImageMediaType::HEIC
506                        | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
507                            mime_type: media_type.to_mime_type().to_owned(),
508                            data,
509                        })),
510                        _ => Err(message::MessageError::ConversionError(format!(
511                            "Unsupported image media type {media_type:?}"
512                        ))),
513                    },
514                    None => Err(message::MessageError::ConversionError(
515                        "Media type for image is required for Anthropic".to_string(),
516                    )),
517                },
518                message::UserContent::Document(message::Document {
519                    data, media_type, ..
520                }) => match media_type {
521                    Some(media_type) => match media_type {
522                        message::DocumentMediaType::PDF
523                        | message::DocumentMediaType::TXT
524                        | message::DocumentMediaType::RTF
525                        | message::DocumentMediaType::HTML
526                        | message::DocumentMediaType::CSS
527                        | message::DocumentMediaType::MARKDOWN
528                        | message::DocumentMediaType::CSV
529                        | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
530                            mime_type: media_type.to_mime_type().to_owned(),
531                            data,
532                        })),
533                        _ => Err(message::MessageError::ConversionError(format!(
534                            "Unsupported document media type {media_type:?}"
535                        ))),
536                    },
537                    None => Err(message::MessageError::ConversionError(
538                        "Media type for document is required for Anthropic".to_string(),
539                    )),
540                },
541                message::UserContent::Audio(message::Audio {
542                    data, media_type, ..
543                }) => match media_type {
544                    Some(media_type) => Ok(Self::InlineData(Blob {
545                        mime_type: media_type.to_mime_type().to_owned(),
546                        data,
547                    })),
548                    None => Err(message::MessageError::ConversionError(
549                        "Media type for audio is required for Anthropic".to_string(),
550                    )),
551                },
552            }
553        }
554    }
555
556    impl From<message::AssistantContent> for Part {
557        fn from(content: message::AssistantContent) -> Self {
558            match content {
559                message::AssistantContent::Text(message::Text { text }) => text.into(),
560                message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
561            }
562        }
563    }
564
565    impl From<message::ToolCall> for Part {
566        fn from(tool_call: message::ToolCall) -> Self {
567            Self::FunctionCall(FunctionCall {
568                name: tool_call.function.name,
569                args: tool_call.function.arguments,
570            })
571        }
572    }
573
    /// Raw media bytes.
    /// Text should not be sent as raw bytes, use the 'text' field.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
        /// If an unsupported MIME type is provided, an error will be returned.
        pub mime_type: String,
        /// Raw bytes for media formats. A base64-encoded string.
        pub data: String,
    }

    /// A predicted FunctionCall returned from the model that contains a string representing the
    /// FunctionDeclaration.name with the arguments and their values.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
        /// and dashes, with a maximum length of 63.
        pub name: String,
        /// Optional. The function parameters and values in JSON object format.
        /// NOTE(review): the API documents this as optional, but the field is required here (no
        /// `Option`/`#[serde(default)]`). A call sent without `args` would fail to deserialize.
        pub args: serde_json::Value,
    }
596
597    impl From<FunctionCall> for message::ToolCall {
598        fn from(function_call: FunctionCall) -> Self {
599            Self {
600                id: function_call.name.clone(),
601                call_id: None,
602                function: message::ToolFunction {
603                    name: function_call.name,
604                    arguments: function_call.args,
605                },
606            }
607        }
608    }
609
610    impl From<message::ToolCall> for FunctionCall {
611        fn from(tool_call: message::ToolCall) -> Self {
612            Self {
613                name: tool_call.function.name,
614                args: tool_call.function.arguments,
615            }
616        }
617    }
618
    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
    /// and a structured JSON object containing any output from the function is used as context to the model.
    /// This should contain the result of a FunctionCall made based on model prediction.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
        /// with a maximum length of 63.
        pub name: String,
        /// The function response in JSON object format.
        /// `None` serializes as an explicit `null` (there is no `skip_serializing_if`).
        pub response: Option<serde_json::Value>,
    }

    /// URI based data.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional. The IANA standard MIME type of the source data.
        pub mime_type: Option<String>,
        /// Required. URI.
        pub file_uri: String,
    }
640
    /// Safety rating for one harm category, as reported for a prompt or a candidate.
    ///
    /// Any other fields the API sends alongside these are ignored (no `deny_unknown_fields`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// Category the rating applies to.
        pub category: HarmCategory,
        /// Probability that the content is harmful in this category.
        pub probability: HarmProbability,
    }

    /// Harm probability level, serialized in SCREAMING_SNAKE_CASE (e.g. `"NEGLIGIBLE"`).
    ///
    /// NOTE(review): there is no catch-all variant. A value not listed here makes the enclosing
    /// response fail to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    /// Harm category, serialized in SCREAMING_SNAKE_CASE (e.g. `"HARM_CATEGORY_HATE_SPEECH"`).
    ///
    /// NOTE(review): there is no catch-all variant. A category not listed here makes the
    /// enclosing response fail to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
673
674    #[derive(Debug, Deserialize, Clone, Default)]
675    #[serde(rename_all = "camelCase")]
676    pub struct UsageMetadata {
677        pub prompt_token_count: i32,
678        #[serde(skip_serializing_if = "Option::is_none")]
679        pub cached_content_token_count: Option<i32>,
680        pub candidates_token_count: i32,
681        pub total_token_count: i32,
682    }
683
684    impl std::fmt::Display for UsageMetadata {
685        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
686            write!(
687                f,
688                "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
689                self.prompt_token_count,
690                match self.cached_content_token_count {
691                    Some(count) => count.to_string(),
692                    None => "n/a".to_string(),
693                },
694                self.candidates_token_count,
695                self.total_token_count
696            )
697        }
698    }
699
    /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
        pub block_reason: Option<BlockReason>,
        /// Ratings for safety of the prompt. There is at most one rating per category.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Reason why a prompt was blocked by the model.
    ///
    /// NOTE(review): there is no catch-all variant. An unlisted reason makes the enclosing
    /// response fail to deserialize.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// Default value. This value is unused.
        BlockReasonUnspecified,
        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
        Safety,
        /// Prompt was blocked due to unknown reasons.
        Other,
        /// Prompt was blocked due to the terms which are included from the terminology blocklist.
        Blocklist,
        /// Prompt was blocked due to prohibited content.
        ProhibitedContent,
    }

    /// Why the model stopped generating a candidate, serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `"MAX_TOKENS"`).
    ///
    /// NOTE(review): there is no catch-all variant. A finish reason not listed here makes the
    /// whole response fail to deserialize — confirm the list is current.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// Default value. This value is unused.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        Stop,
        /// The maximum number of tokens as specified in the request was reached.
        MaxTokens,
        /// The response candidate content was flagged for safety reasons.
        Safety,
        /// The response candidate content was flagged for recitation reasons.
        Recitation,
        /// The response candidate content was flagged for using an unsupported language.
        Language,
        /// Unknown reason.
        Other,
        /// Token generation stopped because the content contains forbidden terms.
        Blocklist,
        /// Token generation stopped for potentially containing prohibited content.
        ProhibitedContent,
        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
        Spii,
        /// The function call generated by the model is invalid.
        MalformedFunctionCall,
    }
752
753    #[derive(Debug, Deserialize)]
754    #[serde(rename_all = "camelCase")]
755    pub struct CitationMetadata {
756        pub citation_sources: Vec<CitationSource>,
757    }
758
759    #[derive(Debug, Deserialize)]
760    #[serde(rename_all = "camelCase")]
761    pub struct CitationSource {
762        #[serde(skip_serializing_if = "Option::is_none")]
763        pub uri: Option<String>,
764        #[serde(skip_serializing_if = "Option::is_none")]
765        pub start_index: Option<i32>,
766        #[serde(skip_serializing_if = "Option::is_none")]
767        pub end_index: Option<i32>,
768        #[serde(skip_serializing_if = "Option::is_none")]
769        pub license: Option<String>,
770    }
771
772    #[derive(Debug, Deserialize)]
773    #[serde(rename_all = "camelCase")]
774    pub struct LogprobsResult {
775        pub top_candidate: Vec<TopCandidate>,
776        pub chosen_candidate: Vec<LogProbCandidate>,
777    }
778
779    #[derive(Debug, Deserialize)]
780    pub struct TopCandidate {
781        pub candidates: Vec<LogProbCandidate>,
782    }
783
784    #[derive(Debug, Deserialize)]
785    #[serde(rename_all = "camelCase")]
786    pub struct LogProbCandidate {
787        pub token: String,
788        pub token_id: String,
789        pub log_probability: f64,
790    }
791
792    /// Gemini API Configuration options for model generation and outputs. Not all parameters are
793    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
794    /// ### Rig Note:
795    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
796    #[derive(Debug, Deserialize, Serialize)]
797    #[serde(rename_all = "camelCase")]
798    pub struct GenerationConfig {
799        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
800        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
801        #[serde(skip_serializing_if = "Option::is_none")]
802        pub stop_sequences: Option<Vec<String>>,
803        /// MIME type of the generated candidate text. Supported MIME types are:
804        ///     - text/plain:  (default) Text output
805        ///     - application/json: JSON response in the response candidates.
806        ///     - text/x.enum: ENUM as a string response in the response candidates.
807        /// Refer to the docs for a list of all supported text MIME types
808        #[serde(skip_serializing_if = "Option::is_none")]
809        pub response_mime_type: Option<String>,
810        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
811        /// objects, primitives or arrays. If set, a compatible responseMimeType must also  be set. Compatible MIME
812        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
813        #[serde(skip_serializing_if = "Option::is_none")]
814        pub response_schema: Option<Schema>,
815        /// Number of generated responses to return. Currently, this value can only be set to 1. If
816        /// unset, this will default to 1.
817        #[serde(skip_serializing_if = "Option::is_none")]
818        pub candidate_count: Option<i32>,
819        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
820        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
821        #[serde(skip_serializing_if = "Option::is_none")]
822        pub max_output_tokens: Option<u64>,
823        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
824        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
825        #[serde(skip_serializing_if = "Option::is_none")]
826        pub temperature: Option<f64>,
827        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
828        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
829        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
830        /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
831        /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty
832        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
833        #[serde(skip_serializing_if = "Option::is_none")]
834        pub top_p: Option<f64>,
835        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
836        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
837        /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is
838        /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates
839        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
840        #[serde(skip_serializing_if = "Option::is_none")]
841        pub top_k: Option<i32>,
842        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
843        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
844        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
845        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
846        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
847        #[serde(skip_serializing_if = "Option::is_none")]
848        pub presence_penalty: Option<f64>,
849        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
850        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
851        /// used, proportional to the number of times the token has been used: The more a token is used, the more
852        /// difficult it is for the  model to use that token again increasing the vocabulary of responses. Caution: A
853        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
854        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
855        /// the model to  repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...".
856        #[serde(skip_serializing_if = "Option::is_none")]
857        pub frequency_penalty: Option<f64>,
858        /// If true, export the logprobs results in response.
859        #[serde(skip_serializing_if = "Option::is_none")]
860        pub response_logprobs: Option<bool>,
861        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
862        /// [Candidate.logprobs_result].
863        #[serde(skip_serializing_if = "Option::is_none")]
864        pub logprobs: Option<i32>,
865    }
866
867    impl Default for GenerationConfig {
868        fn default() -> Self {
869            Self {
870                temperature: Some(1.0),
871                max_output_tokens: Some(4096),
872                stop_sequences: None,
873                response_mime_type: None,
874                response_schema: None,
875                candidate_count: None,
876                top_p: None,
877                top_k: None,
878                presence_penalty: None,
879                frequency_penalty: None,
880                response_logprobs: None,
881                logprobs: None,
882            }
883        }
884    }
885    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
886    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
887    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
888    #[derive(Debug, Deserialize, Serialize, Clone)]
889    pub struct Schema {
890        pub r#type: String,
891        #[serde(skip_serializing_if = "Option::is_none")]
892        pub format: Option<String>,
893        #[serde(skip_serializing_if = "Option::is_none")]
894        pub description: Option<String>,
895        #[serde(skip_serializing_if = "Option::is_none")]
896        pub nullable: Option<bool>,
897        #[serde(skip_serializing_if = "Option::is_none")]
898        pub r#enum: Option<Vec<String>>,
899        #[serde(skip_serializing_if = "Option::is_none")]
900        pub max_items: Option<i32>,
901        #[serde(skip_serializing_if = "Option::is_none")]
902        pub min_items: Option<i32>,
903        #[serde(skip_serializing_if = "Option::is_none")]
904        pub properties: Option<HashMap<String, Schema>>,
905        #[serde(skip_serializing_if = "Option::is_none")]
906        pub required: Option<Vec<String>>,
907        #[serde(skip_serializing_if = "Option::is_none")]
908        pub items: Option<Box<Schema>>,
909    }
910
911    impl TryFrom<Value> for Schema {
912        type Error = CompletionError;
913
914        fn try_from(value: Value) -> Result<Self, Self::Error> {
915            if let Some(obj) = value.as_object() {
916                Ok(Schema {
917                    r#type: obj
918                        .get("type")
919                        .and_then(|v| {
920                            if v.is_string() {
921                                v.as_str().map(String::from)
922                            } else if v.is_array() {
923                                v.as_array()
924                                    .and_then(|arr| arr.first())
925                                    .and_then(|v| v.as_str().map(String::from))
926                            } else {
927                                None
928                            }
929                        })
930                        .unwrap_or_default(),
931                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
932                    description: obj
933                        .get("description")
934                        .and_then(|v| v.as_str())
935                        .map(String::from),
936                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
937                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
938                        arr.iter()
939                            .filter_map(|v| v.as_str().map(String::from))
940                            .collect()
941                    }),
942                    max_items: obj
943                        .get("maxItems")
944                        .and_then(|v| v.as_i64())
945                        .map(|v| v as i32),
946                    min_items: obj
947                        .get("minItems")
948                        .and_then(|v| v.as_i64())
949                        .map(|v| v as i32),
950                    properties: obj
951                        .get("properties")
952                        .and_then(|v| v.as_object())
953                        .map(|map| {
954                            map.iter()
955                                .filter_map(|(k, v)| {
956                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
957                                })
958                                .collect()
959                        }),
960                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
961                        arr.iter()
962                            .filter_map(|v| v.as_str().map(String::from))
963                            .collect()
964                    }),
965                    items: obj
966                        .get("items")
967                        .map(|v| Box::new(v.clone().try_into().unwrap())),
968                })
969            } else {
970                Err(CompletionError::ResponseError(
971                    "Expected a JSON object for Schema".into(),
972                ))
973            }
974        }
975    }
976
977    #[derive(Debug, Serialize)]
978    #[serde(rename_all = "camelCase")]
979    pub struct GenerateContentRequest {
980        pub contents: Vec<Content>,
981        pub tools: Option<Tool>,
982        pub tool_config: Option<ToolConfig>,
983        /// Optional. Configuration options for model generation and outputs.
984        pub generation_config: Option<GenerationConfig>,
985        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
986        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
987        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
988        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
989        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
990        /// will use the default safety setting for that category. Harm categories:
991        ///     - HARM_CATEGORY_HATE_SPEECH,
992        ///     - HARM_CATEGORY_SEXUALLY_EXPLICIT
993        ///     - HARM_CATEGORY_DANGEROUS_CONTENT
994        ///     - HARM_CATEGORY_HARASSMENT
995        /// are supported.
996        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
997        /// to learn how to incorporate safety considerations in your AI applications.
998        pub safety_settings: Option<Vec<SafetySetting>>,
999        /// Optional. Developer set system instruction(s). Currently, text only.
1000        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
1001        pub system_instruction: Option<Content>,
1002        // cachedContent: Optional<String>
1003    }
1004
1005    #[derive(Debug, Serialize)]
1006    #[serde(rename_all = "camelCase")]
1007    pub struct Tool {
1008        pub function_declarations: Vec<FunctionDeclaration>,
1009        pub code_execution: Option<CodeExecution>,
1010    }
1011
1012    #[derive(Debug, Serialize, Clone)]
1013    #[serde(rename_all = "camelCase")]
1014    pub struct FunctionDeclaration {
1015        pub name: String,
1016        pub description: String,
1017        #[serde(skip_serializing_if = "Option::is_none")]
1018        pub parameters: Option<Schema>,
1019    }
1020
1021    #[derive(Debug, Serialize)]
1022    #[serde(rename_all = "camelCase")]
1023    pub struct ToolConfig {
1024        pub schema: Option<Schema>,
1025    }
1026
1027    #[derive(Debug, Serialize)]
1028    #[serde(rename_all = "camelCase")]
1029    pub struct CodeExecution {}
1030
1031    #[derive(Debug, Serialize)]
1032    #[serde(rename_all = "camelCase")]
1033    pub struct SafetySetting {
1034        pub category: HarmCategory,
1035        pub threshold: HarmBlockThreshold,
1036    }
1037
1038    #[derive(Debug, Serialize)]
1039    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1040    pub enum HarmBlockThreshold {
1041        HarmBlockThresholdUnspecified,
1042        BlockLowAndAbove,
1043        BlockMediumAndAbove,
1044        BlockOnlyHigh,
1045        BlockNone,
1046        Off,
1047    }
1048}
1049
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user message with one part of every kind deserializes into the matching `Part`
    /// variants, in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // Deserialize through serde_path_to_error so a failure names the offending JSON path.
        let content: Content = {
            let deserializer = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(deserializer)
                .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        let Part::Text(text) = &parts[0] else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, world!");

        let Part::InlineData(inline_data) = &parts[1] else {
            panic!("Expected inline data part");
        };
        assert_eq!(inline_data.mime_type, "image/png");
        assert_eq!(inline_data.data, "base64encodeddata");

        let Part::FunctionCall(function_call) = &parts[2] else {
            panic!("Expected function call part");
        };
        assert_eq!(function_call.name, "test_function");
        let call_args = function_call.args.as_object().unwrap();
        assert_eq!(call_args.get("arg1").unwrap(), "value1");

        let Part::FunctionResponse(function_response) = &parts[3] else {
            panic!("Expected function response part");
        };
        assert_eq!(function_response.name, "test_function");
        let response_body = function_response.response.as_ref().unwrap();
        assert_eq!(response_body.get("result").unwrap(), "success");

        let Part::FileData(file_data) = &parts[4] else {
            panic!("Expected file data part");
        };
        assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
        assert_eq!(file_data.file_uri, "http://example.com/file.pdf");

        let Part::ExecutableCode(executable_code) = &parts[5] else {
            panic!("Expected executable code part");
        };
        assert_eq!(executable_code.code, "print('Hello, world!')");

        let Part::CodeExecutionResult(code_execution_result) = &parts[6] else {
            panic!("Expected code execution result part");
        };
        assert_eq!(
            code_execution_result.clone().output.unwrap(),
            "Hello, world!"
        );
    }

    /// A `model`-role payload deserializes with `Role::Model` and its single text part.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let first_part = content.parts.first();
        let Part::Text(text) = &first_part else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, user!");
    }

    /// A rig user message converts to a `Content` with `Role::User`.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);

        let first_part = content.parts.first();
        let Part::Text(text) = &first_part else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, world!");
    }

    /// A rig assistant message converts to a `Content` with `Role::Model`.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let first_part = content.parts.first();
        let Part::Text(text) = &first_part else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, user!");
    }

    /// An assistant tool call converts to a `FunctionCall` part that keeps the name and arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let first_part = content.parts.first();
        let Part::FunctionCall(function_call) = &first_part else {
            panic!("Expected function call part");
        };
        assert_eq!(function_call.name, "test_function");
        let call_args = function_call.args.as_object().unwrap();
        assert_eq!(call_args.get("arg1").unwrap(), "value1");
    }
}