async_gemini/models/generate_content/
request.rs

1use crate::{
2    models::generate_content::HarmCategory,
3    util::{deserialize_obj_or_vec, deserialize_option_obj_or_vec},
4};
5
6use serde::{Deserialize, Serialize};
7
8use super::{Content, Part, Role};
9
10/// when deserilization:
11/// - google api support both camelCase and snake_case key, but we only support camel case.
12/// - google api allow trailling comma, but not here
13#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
14#[serde(rename_all = "camelCase")]
15pub struct GenerateContentRequest {
16    // contents must start with user and alternate between user and model, and end with user or function response
17    #[serde(deserialize_with = "deserialize_obj_or_vec")]
18    pub contents: Vec<Content>,
19    /// A piece of code that enables the system to interact with external systems to perform an action, or set of actions, outside of knowledge and scope of the model.
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub tools: Option<Vec<Tool>>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    #[serde(deserialize_with = "deserialize_option_obj_or_vec", default)]
24    pub safety_settings: Option<Vec<SafetySetting>>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub generation_config: Option<GenerateionConfig>,
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
30#[serde(rename_all = "camelCase")]
31pub struct Tool {
32    /// One or more function declarations. Each function declaration contains information about one function that includes the following:
33    /// name The name of the function to call. Must start with a letter or an underscore. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
34    /// description (optional). The description and purpose of the function. The model uses this to decide how and whether to call the function. For the best results, we recommend that you include a description.
35    /// parameters The parameters of this function in a format that's compatible with the OpenAPI schema format.
36    /// For more information, see Function calling.
37    function_declarations: Vec<FunctionTool>,
38}
39
40#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
41pub struct FunctionTool {
42    name: String,
43    description: Option<String>,
44    parameters: serde_json::Value,
45}
46
47#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
48pub struct SafetySetting {
49    category: HarmCategory,
50    threshold: SafetySettingThreshold,
51}
52
53/// The threshold for blocking responses that could belong to the specified safety category based on probability.
54#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
55#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
56pub enum SafetySettingThreshold {
57    BlockNone,
58    BlockLowAndAbove,
59    BlockMedAndAbove,
60    BlockOnlyHigh,
61}
62
63#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
64#[serde(rename_all = "camelCase")]
65pub struct GenerateionConfig {
66    /// The temperature is used for sampling during the response generation, which occurs when topP and topK are applied. Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a more deterministic and less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 is deterministic: the highest probability response is always selected.
67    /// Range: 0.0 - 1.0
68    /// Default for gemini-1.0-pro: 0.9
69    /// Default for gemini-1.0-pro-vision: 0.4
70    pub temperature: Option<f32>,
71    /// Top-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.
72    /// Specify a lower value for less random responses and a higher value for more random responses.
73    /// Range: 0.0 - 1.0
74    /// Default: 1.0
75    pub top_p: Option<f32>,
76    /// Top-K changes how the model selects tokens for output. A top-K of 1 means the next selected token is the most probable among all tokens in the model's vocabulary (also called greedy decoding), while a top-K of 3 means that the next token is selected from among the three most probable tokens by using temperature.
77    /// For each token selection step, the top-K tokens with the highest probabilities are sampled. Then tokens are further filtered based on top-P with the final token selected using temperature sampling.
78    /// Specify a lower value for less random responses and a higher value for more random responses.
79    /// Range: 1-40
80    /// Default for gemini-1.0-pro-vision: 32
81    /// Default for gemini-1.0-pro: none
82    pub top_k: Option<u32>,
83    /// The number of response variations to return.
84    /// This value must be 1.
85    pub candidate_count: Option<u32>,
86    /// Maximum number of tokens that can be generated in the response. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.
87    /// Specify a lower value for shorter responses and a higher value for potentially longer responses.
88    /// Range for gemini-1.0-pro: 1-8192 (default: 8192)
89    /// Range for gemini-1.0-pro-vision: 1-2048 (default: 2048)
90    pub max_output_tokens: Option<u32>,
91    /// Specifies a list of strings that tells the model to stop generating text if one of the strings is encountered in the response. If a string appears multiple times in the response, then the response truncates where it's first encountered. The strings are case-sensitive.
92    /// For example, if the following is the returned response when stopSequences isn't specified:
93    /// public static string reverse(string myString)
94    /// Then the returned response with stopSequences set to ["Str","reverse"] is:
95    /// public static string
96    /// Maximum 5 items in the list.
97    pub stop_sequences: Option<Vec<String>>,
98}
99
100/// Gemini require contents:
101/// 1. start with "user" role
102/// 2. alternate between "user" and "model" role
103/// 3. end with "user" role or function response
104pub fn process_contents(contents: &[Content]) -> Vec<Content> {
105    let mut filtered = Vec::with_capacity(contents.len());
106    if contents.is_empty() {
107        return filtered;
108    }
109    let mut prev_role: Option<Role> = None;
110    for content in contents {
111        if let Some(pr) = prev_role {
112            if pr == content.role {
113                if let Some(last) = filtered.last_mut() {
114                    last.parts.extend(content.parts.clone());
115                };
116                prev_role = Some(content.role);
117                continue;
118            }
119        }
120        filtered.push(content.clone());
121        prev_role = Some(content.role);
122    }
123
124    if let Some(first) = filtered.first() {
125        if first.role == Role::Model {
126            filtered.insert(
127                0,
128                Content {
129                    role: Role::User,
130                    parts: vec![Part::Text("Starting the conversation...".to_string())],
131                },
132            )
133        }
134    }
135
136    if let Some(last) = filtered.last() {
137        if last.role == Role::Model {
138            filtered.push(Content {
139                role: Role::User,
140                parts: vec![Part::Text("continue".to_string())],
141            });
142        }
143    }
144
145    filtered
146}
147
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::models::generate_content::{Content, FileData, Part, Role};

    use super::*;

    /// Builds a `Content` whose parts are the given text snippets, in order.
    fn text_content(role: Role, texts: &[&str]) -> Content {
        Content {
            role,
            parts: texts.iter().map(|t| Part::Text(t.to_string())).collect(),
        }
    }

    /// Round-trips each sample request: JSON -> struct must equal the expected value, and
    /// expected value -> JSON -> struct must reproduce it.
    #[test]
    fn serde() {
        let tests = vec![
            (
                "simple",
                r#"{"contents": {"role": "user","parts": {"text": "Give me a recipe for banana bread."}}}"#,
                GenerateContentRequest {
                    contents: vec![Content {
                        role: Role::User,
                        parts: vec![Part::Text(
                            "Give me a recipe for banana bread.".to_string(),
                        )],
                    }],
                    ..Default::default()
                },
            ),
            (
                "text",
                r#"{
                    "contents":
                    {
                        "role": "user",
                        "parts":
                        {
                            "text": "Give me a recipe for banana bread."
                        }
                    },
                    "safetySettings":
                    {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "threshold": "BLOCK_LOW_AND_ABOVE"
                    },
                    "generationConfig":
                    {
                        "temperature": 0.2,
                        "topP": 0.8,
                        "topK": 40
                    }
                }"#,
                GenerateContentRequest {
                    contents: vec![Content {
                        role: Role::User,
                        parts: vec![Part::Text(
                            "Give me a recipe for banana bread.".to_string(),
                        )],
                    }],
                    safety_settings: Some(vec![SafetySetting {
                        category: HarmCategory::SexuallyExplicit,
                        threshold: SafetySettingThreshold::BlockLowAndAbove,
                    }]),
                    generation_config: Some(GenerateionConfig {
                        temperature: Some(0.2),
                        top_p: Some(0.8),
                        top_k: Some(40),
                        ..Default::default()
                    }),
                    ..Default::default()
                },
            ),
            (
                "chat",
                r#"{
                    "contents": [
                      {
                        "role": "USER",
                        "parts": { "text": "Hello!" }
                      },
                      {
                        "role": "MODEL",
                        "parts": { "text": "Argh! What brings ye to my ship?" }
                      },
                      {
                        "role": "USER",
                        "parts": { "text": "Wow! You are a real-life priate!" }
                      }
                    ],
                    "safetySettings": {
                      "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                      "threshold": "BLOCK_LOW_AND_ABOVE"
                    },
                    "generationConfig": {
                      "temperature": 0.2,
                      "topP": 0.8,
                      "topK": 40,
                      "maxOutputTokens": 200
                    }
                  }"#,
                GenerateContentRequest {
                    contents: vec![
                        Content {
                            role: Role::User,
                            parts: vec![Part::Text("Hello!".to_string())],
                        },
                        Content {
                            role: Role::Model,
                            parts: vec![Part::Text(
                                "Argh! What brings ye to my ship?".to_string(),
                            )],
                        },
                        Content {
                            role: Role::User,
                            parts: vec![Part::Text(
                                "Wow! You are a real-life priate!".to_string(),
                            )],
                        },
                    ],
                    safety_settings: Some(vec![SafetySetting {
                        category: HarmCategory::SexuallyExplicit,
                        threshold: SafetySettingThreshold::BlockLowAndAbove,
                    }]),
                    generation_config: Some(GenerateionConfig {
                        temperature: Some(0.2),
                        top_p: Some(0.8),
                        top_k: Some(40),
                        max_output_tokens: Some(200),
                        ..Default::default()
                    }),
                    ..Default::default()
                },
            ),
            (
                "multimodal",
                r#"{
                    "contents": {
                      "role": "user",
                      "parts": [
                        {
                          "fileData": {
                            "mimeType": "image/jpeg",
                            "fileUri": "gs://cloud-samples-data/ai-platform/flowers/daisy/10559679065_50d2b16f6d.jpg"
                          }
                        },
                        {
                          "text": "Describe this picture."
                        }
                      ]
                    },
                    "safetySettings": {
                      "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                      "threshold": "BLOCK_LOW_AND_ABOVE"
                    },
                    "generationConfig": {
                      "temperature": 0.4,
                      "topP": 1.0,
                      "topK": 32,
                      "maxOutputTokens": 2048
                    }
                  }"#,
                GenerateContentRequest {
                    contents: vec![Content {
                        role: Role::User,
                        parts: vec![
                            Part::File(FileData {
                                mime_type: "image/jpeg".to_string(),
                                file_uri: "gs://cloud-samples-data/ai-platform/flowers/daisy/10559679065_50d2b16f6d.jpg".to_string(),
                                video_metadata: None,
                            }),
                            Part::Text("Describe this picture.".to_string()),
                        ],
                    }],
                    safety_settings: Some(vec![SafetySetting {
                        category: HarmCategory::SexuallyExplicit,
                        threshold: SafetySettingThreshold::BlockLowAndAbove,
                    }]),
                    generation_config: Some(GenerateionConfig {
                        temperature: Some(0.4),
                        top_p: Some(1.0),
                        top_k: Some(32),
                        max_output_tokens: Some(2048),
                        ..Default::default()
                    }),
                    ..Default::default()
                },
            ),
            (
                "function",
                r#"{
                    "contents": {
                      "role": "user",
                      "parts": {
                        "text": "Which theaters in Mountain View show Barbie movie?"
                      }
                    },
                    "tools": [
                      {
                        "functionDeclarations": [
                          {
                            "name": "find_movies",
                            "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
                            "parameters": {
                              "type": "object",
                              "properties": {
                                "location": {
                                  "type": "string",
                                  "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                },
                                "description": {
                                  "type": "string",
                                  "description": "Any kind of description including category or genre, title words, attributes, etc."
                                }
                              },
                              "required": [
                                "description"
                              ]
                            }
                          },
                          {
                            "name": "find_theaters",
                            "description": "find theaters based on location and optionally movie title which are is currently playing in theaters",
                            "parameters": {
                              "type": "object",
                              "properties": {
                                "location": {
                                  "type": "string",
                                  "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                },
                                "movie": {
                                  "type": "string",
                                  "description": "Any movie title"
                                }
                              },
                              "required": [
                                "location"
                              ]
                            }
                          },
                          {
                            "name": "get_showtimes",
                            "description": "Find the start times for movies playing in a specific theater",
                            "parameters": {
                              "type": "object",
                              "properties": {
                                "location": {
                                  "type": "string",
                                  "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                },
                                "movie": {
                                  "type": "string",
                                  "description": "Any movie title"
                                },
                                "theater": {
                                  "type": "string",
                                  "description": "Name of the theater"
                                },
                                "date": {
                                  "type": "string",
                                  "description": "Date for requested showtime"
                                }
                              },
                              "required": [
                                "location",
                                "movie",
                                "theater",
                                "date"
                              ]
                            }
                          }
                        ]
                      }
                    ]
                  }"#,
                GenerateContentRequest {
                    contents: vec![Content {
                        role: Role::User,
                        parts: vec![Part::Text(
                            "Which theaters in Mountain View show Barbie movie?".to_string(),
                        )],
                    }],
                    tools: Some(vec![Tool {
                        function_declarations: vec![
                            FunctionTool {
                                name: "find_movies".to_string(),
                                description: Some("find movie titles currently playing in theaters based on any description, genre, title words, etc.".to_string()),
                                parameters: json!({
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                        },
                                        "description": {
                                            "type": "string",
                                            "description": "Any kind of description including category or genre, title words, attributes, etc."
                                        }
                                    },
                                    "required": [
                                        "description"
                                    ]
                                }),
                            },
                            FunctionTool {
                                name: "find_theaters".to_string(),
                                description: Some("find theaters based on location and optionally movie title which are is currently playing in theaters".to_string()),
                                parameters: json!({
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                        },
                                        "movie": {
                                            "type": "string",
                                            "description": "Any movie title"
                                        }
                                    },
                                    "required": [
                                        "location"
                                    ]
                                }),
                            },
                            FunctionTool {
                                name: "get_showtimes".to_string(),
                                description: Some("Find the start times for movies playing in a specific theater".to_string()),
                                parameters: json!({
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                        },
                                        "movie": {
                                            "type": "string",
                                            "description": "Any movie title"
                                        },
                                        "theater": {
                                            "type": "string",
                                            "description": "Name of the theater"
                                        },
                                        "date": {
                                            "type": "string",
                                            "description": "Date for requested showtime"
                                        }
                                    },
                                    "required": [
                                        "location",
                                        "movie",
                                        "theater",
                                        "date"
                                    ]
                                }),
                            },
                        ],
                    }]),
                    ..Default::default()
                },
            ),
        ];
        for (name, json, expected) in tests {
            // deserialize
            let actual: GenerateContentRequest = serde_json::from_str(json).unwrap();
            assert_eq!(actual, expected, "deserialize test failed: {}", name);
            // serialize, then read back
            let serialized = serde_json::to_string(&expected).unwrap();
            let actual: GenerateContentRequest = serde_json::from_str(&serialized).unwrap();
            assert_eq!(actual, expected, "serialize test failed: {}", name);
        }
    }

    /// Unset optional request fields must be left out of the JSON entirely.
    #[test]
    fn serialize_omits_unset_optional_fields() {
        let request = GenerateContentRequest {
            contents: vec![text_content(Role::User, &["hi"])],
            ..Default::default()
        };
        let value = serde_json::to_value(&request).unwrap();
        let keys: Vec<&str> = value
            .as_object()
            .unwrap()
            .keys()
            .map(String::as_str)
            .collect();
        assert_eq!(keys, vec!["contents"]);
    }

    #[test]
    fn process() {
        let tests = vec![
            ("[]", vec![], vec![]),
            (
                "[(model, text)]",
                vec![text_content(Role::Model, &["hi"])],
                vec![
                    text_content(Role::User, &["Starting the conversation..."]),
                    text_content(Role::Model, &["hi"]),
                    text_content(Role::User, &["continue"]),
                ],
            ),
            (
                "[(user, text)]",
                vec![text_content(Role::User, &["hi"])],
                vec![text_content(Role::User, &["hi"])],
            ),
            (
                "[(user, text), (user, text)] merges into one turn",
                vec![
                    text_content(Role::User, &["a"]),
                    text_content(Role::User, &["b", "c"]),
                ],
                vec![text_content(Role::User, &["a", "b", "c"])],
            ),
            (
                "[(user, text), (model, text), (model, text)] merges and appends continue",
                vec![
                    text_content(Role::User, &["a"]),
                    text_content(Role::Model, &["b"]),
                    text_content(Role::Model, &["c"]),
                ],
                vec![
                    text_content(Role::User, &["a"]),
                    text_content(Role::Model, &["b", "c"]),
                    text_content(Role::User, &["continue"]),
                ],
            ),
            (
                "[(user, text), (model, text), (user, text)] is already valid",
                vec![
                    text_content(Role::User, &["a"]),
                    text_content(Role::Model, &["b"]),
                    text_content(Role::User, &["c"]),
                ],
                vec![
                    text_content(Role::User, &["a"]),
                    text_content(Role::Model, &["b"]),
                    text_content(Role::User, &["c"]),
                ],
            ),
        ];
        for (name, contents, want) in tests {
            let got = process_contents(&contents);
            assert_eq!(got, want, "test failed: {}", name)
        }
    }
}