// google_generative_ai_rs/v1/gemini.rs

1//! Handles the text interaction with the API
2use core::fmt;
3use serde::{Deserialize, Serialize};
4
5use self::request::{FileData, InlineData, VideoMetadata};
/// Defines the type of response expected from the API.
/// Used at the end of the API URL for the Gemini API.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum ResponseType {
    /// Standard, non-streaming content generation (the default).
    #[default]
    GenerateContent,
    /// Server-streamed content generation.
    StreamGenerateContent,
    /// Metadata for a single model; contributes no URL suffix because the
    /// model path is already in the URL (see the `Display` impl).
    GetModel,
    /// Metadata for all available models; contributes no URL suffix.
    GetModelList,
    /// Token count for a prompt.
    CountTokens,
    /// Embedding for a single piece of content.
    EmbedContent,
    /// Embeddings for a batch of content.
    BatchEmbedContents,
}
19impl fmt::Display for ResponseType {
20    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21        match self {
22            ResponseType::GenerateContent => f.write_str("generateContent"),
23            ResponseType::StreamGenerateContent => f.write_str("streamGenerateContent"),
24            ResponseType::GetModel => f.write_str(""), // No display as its already in the URL
25            ResponseType::GetModelList => f.write_str(""), // No display as its already in the URL
26            ResponseType::CountTokens => f.write_str("countTokens"),
27            ResponseType::EmbedContent => f.write_str("embedContent"),
28            ResponseType::BatchEmbedContents => f.write_str("batchEmbedContents"),
29        }
30    }
31}
/// Captures the information for a specific Google generative AI model.
///
/// ```json
/// {
///    "name": "models/gemini-pro",
///    "version": "001",
///    "displayName": "Gemini Pro",
///    "description": "The best model for scaling across a wide range of tasks",
///    "inputTokenLimit": 30720,
///    "outputTokenLimit": 2048,
///    "supportedGenerationMethods": [
///        "generateContent",
///        "countTokens"
///    ],
///    "temperature": 0.9,
///    "topP": 1,
///    "topK": 100,
/// }
/// ```
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(rename = "model")]
pub struct ModelInformation {
    /// Fully-qualified model name, e.g. "models/gemini-pro".
    pub name: String,
    /// Model version label, e.g. "001".
    pub version: String,
    /// Human-readable model name.
    pub display_name: String,
    /// Short description of the model.
    pub description: String,
    /// Maximum number of input tokens the model accepts.
    pub input_token_limit: i32,
    /// Maximum number of output tokens the model can produce.
    pub output_token_limit: i32,
    /// API methods this model supports, e.g. "generateContent".
    pub supported_generation_methods: Vec<String>,
    /// Default sampling temperature, when the model reports one.
    pub temperature: Option<f32>,
    // NOTE(review): `skip_serializing_if` is inert here — this struct only
    // derives `Deserialize`; the attribute only affects `Serialize`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
}
/// Lists the available models for the Gemini API.
#[derive(Debug, Default, Deserialize)]
#[serde(rename = "models")]
pub struct ModelInformationList {
    /// One entry per model returned by the list-models endpoint.
    pub models: Vec<ModelInformation>,
}
74
/// The Gemini model to target; the `Display` impl renders the identifier
/// used in API URLs.
// NOTE(review): the kebab-case serde rename produces names like
// `gemini1-0-pro`, which differ from the `Display` strings
// (e.g. `gemini-1.0-pro`) — confirm which form serialized values need.
#[derive(Debug, Clone, Default, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum Model {
    /// Gemini 1.0 Pro (the default).
    #[default]
    Gemini1_0Pro,
    /// Gemini 1.5 Pro, latest revision (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    Gemini1_5Pro,
    /// Gemini 1.5 Flash (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    Gemini1_5Flash,
    /// Gemini 1.5 Flash 8B (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    Gemini1_5Flash8B,
    /// Gemini 2.0 Flash, experimental (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    Gemini2_0Flash,
    /// Any other model name; written verbatim by the `Display` impl
    /// (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    Custom(String),
    // TODO: Embedding004
}
97impl fmt::Display for Model {
98    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99        match self {
100            Model::Gemini1_0Pro => write!(f, "gemini-1.0-pro"),
101
102            #[cfg(feature = "beta")]
103            #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
104            Model::Gemini1_5Pro => write!(f, "gemini-1.5-pro-latest"),
105            #[cfg(feature = "beta")]
106            #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
107            Model::Gemini1_5Flash => write!(f, "gemini-1.5-flash"),
108            #[cfg(feature = "beta")]
109            #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
110            Model::Gemini1_5Flash8B => write!(f, "gemini-1.5-flash-8b"),
111
112            #[cfg(feature = "beta")]
113            #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
114            Model::Gemini2_0Flash => write!(f, "gemini-2.0-flash-exp"),
115
116            #[cfg(feature = "beta")]
117            #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
118            Model::Custom(name) => write!(f, "{}", name),
119            // TODO: Model::Embedding004 => write!(f, "text-embedding-004"),
120        }
121    }
122}
123
/// A single conversation turn: who produced it plus its content parts.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Content {
    /// Author of this turn (user or model).
    pub role: Role,
    /// Content pieces; defaults to empty when absent from the JSON.
    #[serde(default)]
    pub parts: Vec<Part>,
}
130
/// One piece of content within a turn. Per the request-format doc below,
/// `text`, `inline_data` and `file_data` form a union — exactly one should
/// be populated; `video_metadata` may accompany video input.
// NOTE(review): the union constraint is not enforced by the type — all
// fields are independent `Option`s; confirm callers set only one.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    /// Plain-text content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Inline media (`mime_type` plus a data string).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inline_data: Option<InlineData>,
    /// Media referenced by URI.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<FileData>,
    /// Start/end offsets for video input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub video_metadata: Option<VideoMetadata>,
}
143
/// Identifies the author of a [`Content`] turn; serialized lowercase
/// ("user" / "model").
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// The end user issuing the prompt.
    User,
    /// The generative model's reply.
    Model,
}
150
151/// The request format follows the following structure:
152/// ```json
153/// {
154///   "contents": [
155///     {
156///       "role": string,
157///       "parts": [
158///         {
159///           /// Union field data can be only one of the following:
160///           "text": string,
161///           "inlineData": {
162///             "mimeType": string,
163///             "data": string
164///           },
165///           "fileData": {
166///             "mimeType": string,
167///             "fileUri": string
168///           },
169///           /// End of list of possible types for union field data.
170///           "videoMetadata": {
171///             "startOffset": {
172///               "seconds": integer,
173///               "nanos": integer
174///             },
175///             "endOffset": {
176///               "seconds": integer,
177///               "nanos": integer
178///             }
179///           }
180///         }
181///       ]
182///     }
183///   ],
184///   "tools": [
185///     {
186///       "functionDeclarations": [
187///         {
188///           "name": string,
189///           "description": string,
190///           "parameters": {
191///             object (OpenAPI Object Schema)
192///           }
193///         }
194///       ]
195///     }
196///   ],
197///   "safetySettings": [
198///     {
199///       "category": enum (HarmCategory),
200///       "threshold": enum (HarmBlockThreshold)
201///     }
202///   ],
203///   "generationConfig": {
204///     "temperature": number,
205///     "topP": number,
206///     "topK": number,
207///     "candidateCount": integer,
208///     "maxOutputTokens": integer,
209///     "stopSequences": [
210///       string
211///     ]
212///   }
213/// }
214/// ```
/// See <https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini>
216pub mod request {
217    use serde::{Deserialize, Serialize};
218
219    use super::{
220        safety::{HarmBlockThreshold, HarmCategory},
221        Content,
222    };
223
224    /// Holds the data to be used for a specific text request
225    #[derive(Debug, Clone, Deserialize, Serialize)]
226    pub struct Request {
227        pub contents: Vec<Content>,
228        #[serde(skip_serializing_if = "Vec::is_empty")]
229        pub tools: Vec<Tools>,
230        #[serde(skip_serializing_if = "Vec::is_empty")]
231        #[serde(default, rename = "safetySettings")]
232        pub safety_settings: Vec<SafetySettings>,
233        #[serde(skip_serializing_if = "Option::is_none")]
234        #[serde(default, rename = "generationConfig")]
235        pub generation_config: Option<GenerationConfig>,
236
237        #[cfg(feature = "beta")]
238        #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
239        #[serde(skip_serializing_if = "Option::is_none")]
240        #[serde(default, rename = "system_instruction")]
241        pub system_instruction: Option<SystemInstructionContent>,
242    }
243    impl Request {
244        pub fn new(
245            contents: Vec<Content>,
246            tools: Vec<Tools>,
247            safety_settings: Vec<SafetySettings>,
248            generation_config: Option<GenerationConfig>,
249        ) -> Self {
250            Request {
251                contents,
252                tools,
253                safety_settings,
254                generation_config,
255                #[cfg(feature = "beta")]
256                system_instruction: None,
257            }
258        }
259
260        #[cfg(feature = "beta")]
261        #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
262        pub fn set_system_instruction(&mut self, instruction: SystemInstructionContent) {
263            self.system_instruction = Some(instruction);
264        }
265
266        /// Gets the total character count of the prompt.
267        /// As per the Gemini API, "Text input is charged by every 1,000 characters of input (prompt).
268        ///     Characters are counted by UTF-8 code points and white space is excluded from the count."
269        /// See: https://cloud.google.com/vertex-ai/pricing
270        ///
271        /// Returns the total character count of the prompt as per the Gemini API.
272        pub fn get_prompt_character_count(&self) -> usize {
273            let mut text_count = 0;
274            for content in &self.contents {
275                for part in &content.parts {
276                    if let Some(text) = &part.text {
277                        // Exclude white space from the count
278                        let num_chars = bytecount::num_chars(text.as_bytes());
279                        let num_spaces = bytecount::count(text.as_bytes(), b' ');
280                        text_count += num_chars - num_spaces;
281                    }
282                }
283            }
284            text_count
285        }
286    }
    /// Inline media payload: a MIME type plus the data itself as a string.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct InlineData {
        pub mime_type: String,
        // NOTE(review): presumably base64-encoded, per the API convention for
        // inline media — confirm at the call sites that populate this.
        pub data: String,
    }
    /// Media referenced by URI rather than sent inline.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        pub mime_type: String,
        pub file_uri: String,
    }
    /// Start/end offsets bounding the segment of a video input to use.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct VideoMetadata {
        pub start_offset: StartOffset,
        pub end_offset: EndOffset,
    }
    /// Segment start offset, split into whole seconds and nanoseconds.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    pub struct StartOffset {
        pub seconds: i32,
        pub nanos: i32,
    }
    /// Segment end offset, split into whole seconds and nanoseconds.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    pub struct EndOffset {
        pub seconds: i32,
        pub nanos: i32,
    }
    /// One entry of the request's `tools` array: a set of function
    /// declarations the model may call.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    pub struct Tools {
        #[serde(rename = "functionDeclarations")]
        pub function_declarations: Vec<FunctionDeclaration>,
    }

    /// A single callable function exposed to the model.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    pub struct FunctionDeclaration {
        /// Name the model uses to reference the function.
        pub name: String,
        /// Natural-language description of what the function does.
        pub description: String,
        /// Parameter schema as an OpenAPI-style JSON object.
        pub parameters: serde_json::Value,
    }
327
    /// Pairs a harm category with the probability threshold at which
    /// responses in that category are blocked.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    pub struct SafetySettings {
        pub category: HarmCategory,
        pub threshold: HarmBlockThreshold,
    }
333    #[derive(Debug, Clone, Deserialize, Serialize)]
334    #[serde(rename_all = "camelCase")]
335    pub struct GenerationConfig {
336        pub temperature: Option<f32>,
337        pub top_p: Option<f32>,
338        pub top_k: Option<i32>,
339        pub candidate_count: Option<i32>,
340        pub max_output_tokens: Option<i32>,
341        pub stop_sequences: Option<Vec<String>>,
342
343        #[cfg(feature = "beta")]
344        #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
345        pub response_mime_type: Option<String>,
346
347        #[cfg(feature = "beta")]
348        #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
349        pub response_schema: Option<serde_json::Value>,
350    }
351
    /// Content of a system instruction (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    #[derive(Debug, Clone, Deserialize, Serialize)]
    pub struct SystemInstructionContent {
        /// Instruction parts; defaults to empty when absent from the JSON.
        #[serde(default)]
        pub parts: Vec<SystemInstructionPart>,
    }

    /// One text part of a system instruction (beta feature only).
    #[cfg(feature = "beta")]
    #[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
    #[derive(Debug, Clone, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SystemInstructionPart {
        /// Instruction text; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub text: Option<String>,
    }
368}
369
370/// The response format follows the following structure:
371/// ```json
372/// {
373///   "candidates": [
374///     {
375///       "content": {
376///         "parts": [
377///           {
378///             "text": string
379///           }
380///         ]
381///       },
382///       "finishReason": enum (FinishReason),
383///       "safetyRatings": [
384///         {
385///           "category": enum (HarmCategory),
386///           "probability": enum (HarmProbability),
387///           "blocked": boolean
388///         }
389///       ],
390///       "citationMetadata": {
391///         "citations": [
392///           {
393///             "startIndex": integer,
394///             "endIndex": integer,
395///             "uri": string,
396///             "title": string,
397///             "license": string,
398///             "publicationDate": {
399///               "year": integer,
400///               "month": integer,
401///               "day": integer
402///             }
403///           }
404///         ]
405///       }
406///     }
407///   ],
408///   "usageMetadata": {
409///     "promptTokenCount": integer,
410///     "candidatesTokenCount": integer,
411///     "totalTokenCount": integer
412///   }
413/// }
414/// ```
415pub mod response {
416    use core::fmt;
417    use futures::Stream;
418    use reqwest_streams::error::StreamBodyError;
419    use serde::Deserialize;
420    use std::pin::Pin;
421
422    use super::{
423        safety::{HarmCategory, HarmProbability},
424        Content,
425    };
426
427    impl fmt::Debug for StreamedGeminiResponse {
428        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
429            write!(f, "StreamedGeminiResponse {{ /* stream values */ }}")
430        }
431    }
432
    /// Boxed, pinned stream of raw JSON values from a streaming response.
    type ResponseJsonStream =
        Pin<Box<dyn Stream<Item = Result<serde_json::Value, StreamBodyError>> + Send>>;

    /// The token count for a given prompt.
    #[derive(Debug, Default, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct TokenCount {
        pub total_tokens: u64,
    }

    /// The streamGenerateContent response: an optional stream of raw JSON
    /// chunks.
    #[derive(Default)]
    pub struct StreamedGeminiResponse {
        pub response_stream: Option<ResponseJsonStream>,
    }
448
    /// Top-level success payload of a generateContent call.
    #[derive(Debug, Clone, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GeminiResponse {
        /// The generated candidates.
        pub candidates: Vec<Candidate>,
        /// Safety feedback on the prompt itself, when provided.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Token accounting, when provided.
        pub usage_metadata: Option<UsageMetadata>,
    }
    /// Top-level error payload, deserialized from the `"error"` key
    /// (camelCase rename of the `Error` variant).
    #[derive(Debug, Clone, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub enum GeminiErrorResponse {
        Error {
            /// Numeric error code.
            code: u16,
            /// Human-readable error description.
            message: String,
            /// Canonical status string.
            status: String,
        },
    }
465
466    impl GeminiResponse {
467        /// Returns the total character count of the response as per the Gemini API.
468        pub fn get_response_character_count(&self) -> usize {
469            let mut text_count = 0;
470            for candidate in &self.candidates {
471                for content in &candidate.content.parts {
472                    if let Some(text) = &content.text {
473                        // Exclude white space from the count
474                        let num_chars = bytecount::num_chars(text.as_bytes());
475                        let num_spaces = bytecount::count(text.as_bytes(), b' ');
476                        text_count += num_chars - num_spaces;
477                    }
478                }
479            }
480            text_count
481        }
482    }
    /// One generated candidate and its metadata.
    #[derive(Debug, Clone, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Candidate {
        /// The generated content (role plus parts).
        pub content: Content,
        /// Raw finish-reason string, when the model has stopped.
        pub finish_reason: Option<String>,
        /// Candidate index, when provided.
        pub index: Option<i32>,
        /// Safety ratings; empty when absent from the JSON.
        #[serde(default)]
        pub safety_ratings: Vec<SafetyRating>,
    }
492    #[derive(Debug, Clone, Deserialize)]
493    #[serde(rename_all = "camelCase")]
494    pub struct UsageMetadata {
495        pub prompt_token_count: u64,
496        pub candidates_token_count: u64,
497    }
    /// Safety feedback about the prompt itself.
    #[derive(Debug, Clone, Deserialize)]
    pub struct PromptFeedback {
        #[serde(rename = "safetyRatings")]
        pub safety_ratings: Vec<SafetyRating>,
    }

    /// A single category/probability safety rating.
    #[derive(Debug, Clone, Deserialize)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
        /// Whether content was blocked for this category; `false` when the
        /// key is absent.
        #[serde(default)]
        pub blocked: bool,
    }
511
    /// The reason why the model stopped generating tokens. If empty, the model has not stopped generating the tokens.
    // NOTE(review): with SCREAMING_SNAKE_CASE these variants deserialize
    // from strings like "FINISH_REASON_STOP"; confirm this matches the wire
    // values the targeted API version actually returns.
    #[derive(Debug, Clone, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// The finish reason is unspecified.
        FinishReasonUnspecified,
        /// Natural stop point of the model or provided stop sequence.
        FinishReasonStop,
        /// The maximum number of tokens as specified in the request was reached.
        FinishReasonMaxTokens,
        /// The token generation was stopped as the response was flagged for
        /// safety reasons. Note that [`Candidate`].content is empty if
        /// content filters block the output.
        FinishReasonSafety,
        /// The token generation was stopped as the response was flagged for
        /// unauthorized citations.
        FinishReasonRecitation,
        /// All other reasons that stopped the token generation.
        FinishReasonOther,
    }
523    #[cfg(test)]
524    mod tests {}
525}
526
527/// The safety data for HarmCategory, HarmBlockThreshold and HarmProbability
528pub mod safety {
529    use serde::{Deserialize, Serialize};
530
    /// The safety category to configure a threshold for.
    /// Serialized in SCREAMING_SNAKE_CASE, e.g. `HARM_CATEGORY_HATE_SPEECH`.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategorySexuallyExplicit,
        HarmCategoryHateSpeech,
        HarmCategoryHarassment,
        HarmCategoryDangerousContent,
    }
    /// For a request: the safety category to configure a threshold for. For a response: the harm probability levels in the content.
    /// Serialized in SCREAMING_SNAKE_CASE, e.g. `NEGLIGIBLE`.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
550    /// The threshold for blocking responses that could belong to the specified safety category based on probability.
551    #[derive(Debug, Clone, Deserialize, Serialize)]
552    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
553    pub enum HarmBlockThreshold {
554        BlockNone,
555        BlockLowAndAbove,
556        BlockMedAndAbove,
557        BlockHighAndAbove,
558    }
559}