async_gemini/models/generate_content/request.rs
1use crate::{
2 models::generate_content::HarmCategory,
3 util::{deserialize_obj_or_vec, deserialize_option_obj_or_vec},
4};
5
6use serde::{Deserialize, Serialize};
7
8use super::{Content, Part, Role};
9
10/// when deserilization:
11/// - google api support both camelCase and snake_case key, but we only support camel case.
12/// - google api allow trailling comma, but not here
13#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
14#[serde(rename_all = "camelCase")]
15pub struct GenerateContentRequest {
16 // contents must start with user and alternate between user and model, and end with user or function response
17 #[serde(deserialize_with = "deserialize_obj_or_vec")]
18 pub contents: Vec<Content>,
19 /// A piece of code that enables the system to interact with external systems to perform an action, or set of actions, outside of knowledge and scope of the model.
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub tools: Option<Vec<Tool>>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 #[serde(deserialize_with = "deserialize_option_obj_or_vec", default)]
24 pub safety_settings: Option<Vec<SafetySetting>>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub generation_config: Option<GenerateionConfig>,
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
30#[serde(rename_all = "camelCase")]
31pub struct Tool {
32 /// One or more function declarations. Each function declaration contains information about one function that includes the following:
33 /// name The name of the function to call. Must start with a letter or an underscore. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
34 /// description (optional). The description and purpose of the function. The model uses this to decide how and whether to call the function. For the best results, we recommend that you include a description.
35 /// parameters The parameters of this function in a format that's compatible with the OpenAPI schema format.
36 /// For more information, see Function calling.
37 function_declarations: Vec<FunctionTool>,
38}
39
40#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
41pub struct FunctionTool {
42 name: String,
43 description: Option<String>,
44 parameters: serde_json::Value,
45}
46
47#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
48pub struct SafetySetting {
49 category: HarmCategory,
50 threshold: SafetySettingThreshold,
51}
52
53/// The threshold for blocking responses that could belong to the specified safety category based on probability.
54#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
55#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
56pub enum SafetySettingThreshold {
57 BlockNone,
58 BlockLowAndAbove,
59 BlockMedAndAbove,
60 BlockOnlyHigh,
61}
62
63#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
64#[serde(rename_all = "camelCase")]
65pub struct GenerateionConfig {
66 /// The temperature is used for sampling during the response generation, which occurs when topP and topK are applied. Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a more deterministic and less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 is deterministic: the highest probability response is always selected.
67 /// Range: 0.0 - 1.0
68 /// Default for gemini-1.0-pro: 0.9
69 /// Default for gemini-1.0-pro-vision: 0.4
70 pub temperature: Option<f32>,
71 /// Top-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.
72 /// Specify a lower value for less random responses and a higher value for more random responses.
73 /// Range: 0.0 - 1.0
74 /// Default: 1.0
75 pub top_p: Option<f32>,
76 /// Top-K changes how the model selects tokens for output. A top-K of 1 means the next selected token is the most probable among all tokens in the model's vocabulary (also called greedy decoding), while a top-K of 3 means that the next token is selected from among the three most probable tokens by using temperature.
77 /// For each token selection step, the top-K tokens with the highest probabilities are sampled. Then tokens are further filtered based on top-P with the final token selected using temperature sampling.
78 /// Specify a lower value for less random responses and a higher value for more random responses.
79 /// Range: 1-40
80 /// Default for gemini-1.0-pro-vision: 32
81 /// Default for gemini-1.0-pro: none
82 pub top_k: Option<u32>,
83 /// The number of response variations to return.
84 /// This value must be 1.
85 pub candidate_count: Option<u32>,
86 /// Maximum number of tokens that can be generated in the response. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.
87 /// Specify a lower value for shorter responses and a higher value for potentially longer responses.
88 /// Range for gemini-1.0-pro: 1-8192 (default: 8192)
89 /// Range for gemini-1.0-pro-vision: 1-2048 (default: 2048)
90 pub max_output_tokens: Option<u32>,
91 /// Specifies a list of strings that tells the model to stop generating text if one of the strings is encountered in the response. If a string appears multiple times in the response, then the response truncates where it's first encountered. The strings are case-sensitive.
92 /// For example, if the following is the returned response when stopSequences isn't specified:
93 /// public static string reverse(string myString)
94 /// Then the returned response with stopSequences set to ["Str","reverse"] is:
95 /// public static string
96 /// Maximum 5 items in the list.
97 pub stop_sequences: Option<Vec<String>>,
98}
99
100/// Gemini require contents:
101/// 1. start with "user" role
102/// 2. alternate between "user" and "model" role
103/// 3. end with "user" role or function response
104pub fn process_contents(contents: &[Content]) -> Vec<Content> {
105 let mut filtered = Vec::with_capacity(contents.len());
106 if contents.is_empty() {
107 return filtered;
108 }
109 let mut prev_role: Option<Role> = None;
110 for content in contents {
111 if let Some(pr) = prev_role {
112 if pr == content.role {
113 if let Some(last) = filtered.last_mut() {
114 last.parts.extend(content.parts.clone());
115 };
116 prev_role = Some(content.role);
117 continue;
118 }
119 }
120 filtered.push(content.clone());
121 prev_role = Some(content.role);
122 }
123
124 if let Some(first) = filtered.first() {
125 if first.role == Role::Model {
126 filtered.insert(
127 0,
128 Content {
129 role: Role::User,
130 parts: vec![Part::Text("Starting the conversation...".to_string())],
131 },
132 )
133 }
134 }
135
136 if let Some(last) = filtered.last() {
137 if last.role == Role::Model {
138 filtered.push(Content {
139 role: Role::User,
140 parts: vec![Part::Text("continue".to_string())],
141 });
142 }
143 }
144
145 filtered
146}
147
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::models::generate_content::{Content, FileData, Part, Role};

    use super::*;

    /// Builds a turn that holds a single text part.
    fn text_turn(role: Role, text: &str) -> Content {
        Content {
            role,
            parts: vec![Part::Text(text.to_string())],
        }
    }

    /// Builds a turn that holds one text part per entry of `texts`.
    fn text_turns(role: Role, texts: &[&str]) -> Content {
        Content {
            role,
            parts: texts.iter().map(|t| Part::Text(t.to_string())).collect(),
        }
    }

    /// The safety setting shared by the sample requests below.
    fn sexually_explicit_low_and_above() -> Option<Vec<SafetySetting>> {
        Some(vec![SafetySetting {
            category: HarmCategory::SexuallyExplicit,
            threshold: SafetySettingThreshold::BlockLowAndAbove,
        }])
    }

    /// Checks that sample request bodies deserialize into the expected value
    /// and that the expected value survives a serialize/deserialize round trip.
    #[test]
    fn serde() {
        let cases = vec![
            (
                "simple",
                r#"{"contents": {"role": "user","parts": {"text": "Give me a recipe for banana bread."}}}"#,
                GenerateContentRequest {
                    contents: vec![text_turn(
                        Role::User,
                        "Give me a recipe for banana bread.",
                    )],
                    ..Default::default()
                },
            ),
            (
                "text",
                r#"{
                    "contents":
                    {
                        "role": "user",
                        "parts":
                        {
                            "text": "Give me a recipe for banana bread."
                        }
                    },
                    "safetySettings":
                    {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "threshold": "BLOCK_LOW_AND_ABOVE"
                    },
                    "generationConfig":
                    {
                        "temperature": 0.2,
                        "topP": 0.8,
                        "topK": 40
                    }
                }"#,
                GenerateContentRequest {
                    contents: vec![text_turn(
                        Role::User,
                        "Give me a recipe for banana bread.",
                    )],
                    safety_settings: sexually_explicit_low_and_above(),
                    generation_config: Some(GenerateionConfig {
                        temperature: Some(0.2),
                        top_p: Some(0.8),
                        top_k: Some(40),
                        ..Default::default()
                    }),
                    ..Default::default()
                },
            ),
            (
                "chat",
                r#"{
                    "contents": [
                        {
                            "role": "USER",
                            "parts": { "text": "Hello!" }
                        },
                        {
                            "role": "MODEL",
                            "parts": { "text": "Argh! What brings ye to my ship?" }
                        },
                        {
                            "role": "USER",
                            "parts": { "text": "Wow! You are a real-life priate!" }
                        }
                    ],
                    "safetySettings": {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "threshold": "BLOCK_LOW_AND_ABOVE"
                    },
                    "generationConfig": {
                        "temperature": 0.2,
                        "topP": 0.8,
                        "topK": 40,
                        "maxOutputTokens": 200
                    }
                }"#,
                GenerateContentRequest {
                    contents: vec![
                        text_turn(Role::User, "Hello!"),
                        text_turn(Role::Model, "Argh! What brings ye to my ship?"),
                        text_turn(Role::User, "Wow! You are a real-life priate!"),
                    ],
                    safety_settings: sexually_explicit_low_and_above(),
                    generation_config: Some(GenerateionConfig {
                        temperature: Some(0.2),
                        top_p: Some(0.8),
                        top_k: Some(40),
                        max_output_tokens: Some(200),
                        ..Default::default()
                    }),
                    ..Default::default()
                },
            ),
            (
                "multimodal",
                r#"{
                    "contents": {
                        "role": "user",
                        "parts": [
                            {
                                "fileData": {
                                    "mimeType": "image/jpeg",
                                    "fileUri": "gs://cloud-samples-data/ai-platform/flowers/daisy/10559679065_50d2b16f6d.jpg"
                                }
                            },
                            {
                                "text": "Describe this picture."
                            }
                        ]
                    },
                    "safetySettings": {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "threshold": "BLOCK_LOW_AND_ABOVE"
                    },
                    "generationConfig": {
                        "temperature": 0.4,
                        "topP": 1.0,
                        "topK": 32,
                        "maxOutputTokens": 2048
                    }
                }"#,
                GenerateContentRequest {
                    contents: vec![Content {
                        role: Role::User,
                        parts: vec![
                            Part::File(FileData {
                                mime_type: "image/jpeg".to_string(),
                                file_uri: "gs://cloud-samples-data/ai-platform/flowers/daisy/10559679065_50d2b16f6d.jpg".to_string(),
                                video_metadata: None,
                            }),
                            Part::Text("Describe this picture.".to_string()),
                        ],
                    }],
                    safety_settings: sexually_explicit_low_and_above(),
                    generation_config: Some(GenerateionConfig {
                        temperature: Some(0.4),
                        top_p: Some(1.0),
                        top_k: Some(32),
                        max_output_tokens: Some(2048),
                        ..Default::default()
                    }),
                    ..Default::default()
                },
            ),
            (
                "function",
                r#"{
                    "contents": {
                        "role": "user",
                        "parts": {
                            "text": "Which theaters in Mountain View show Barbie movie?"
                        }
                    },
                    "tools": [
                        {
                            "functionDeclarations": [
                                {
                                    "name": "find_movies",
                                    "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {
                                            "location": {
                                                "type": "string",
                                                "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                            },
                                            "description": {
                                                "type": "string",
                                                "description": "Any kind of description including category or genre, title words, attributes, etc."
                                            }
                                        },
                                        "required": [
                                            "description"
                                        ]
                                    }
                                },
                                {
                                    "name": "find_theaters",
                                    "description": "find theaters based on location and optionally movie title which are is currently playing in theaters",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {
                                            "location": {
                                                "type": "string",
                                                "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                            },
                                            "movie": {
                                                "type": "string",
                                                "description": "Any movie title"
                                            }
                                        },
                                        "required": [
                                            "location"
                                        ]
                                    }
                                },
                                {
                                    "name": "get_showtimes",
                                    "description": "Find the start times for movies playing in a specific theater",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {
                                            "location": {
                                                "type": "string",
                                                "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                            },
                                            "movie": {
                                                "type": "string",
                                                "description": "Any movie title"
                                            },
                                            "theater": {
                                                "type": "string",
                                                "description": "Name of the theater"
                                            },
                                            "date": {
                                                "type": "string",
                                                "description": "Date for requested showtime"
                                            }
                                        },
                                        "required": [
                                            "location",
                                            "movie",
                                            "theater",
                                            "date"
                                        ]
                                    }
                                }
                            ]
                        }
                    ]
                }"#,
                GenerateContentRequest {
                    contents: vec![text_turn(
                        Role::User,
                        "Which theaters in Mountain View show Barbie movie?",
                    )],
                    tools: Some(vec![Tool {
                        function_declarations: vec![
                            FunctionTool {
                                name: "find_movies".to_string(),
                                description: Some("find movie titles currently playing in theaters based on any description, genre, title words, etc.".to_string()),
                                parameters: json!({
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                        },
                                        "description": {
                                            "type": "string",
                                            "description": "Any kind of description including category or genre, title words, attributes, etc."
                                        }
                                    },
                                    "required": [
                                        "description"
                                    ]
                                }),
                            },
                            FunctionTool {
                                name: "find_theaters".to_string(),
                                description: Some("find theaters based on location and optionally movie title which are is currently playing in theaters".to_string()),
                                parameters: json!({
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                        },
                                        "movie": {
                                            "type": "string",
                                            "description": "Any movie title"
                                        }
                                    },
                                    "required": [
                                        "location"
                                    ]
                                }),
                            },
                            FunctionTool {
                                name: "get_showtimes".to_string(),
                                description: Some("Find the start times for movies playing in a specific theater".to_string()),
                                parameters: json!({
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                                        },
                                        "movie": {
                                            "type": "string",
                                            "description": "Any movie title"
                                        },
                                        "theater": {
                                            "type": "string",
                                            "description": "Name of the theater"
                                        },
                                        "date": {
                                            "type": "string",
                                            "description": "Date for requested showtime"
                                        }
                                    },
                                    "required": [
                                        "location",
                                        "movie",
                                        "theater",
                                        "date"
                                    ]
                                }),
                            },
                        ],
                    }]),
                    ..Default::default()
                },
            ),
        ];

        for (name, input, expected) in cases {
            // Deserialize the sample body.
            let parsed: GenerateContentRequest = serde_json::from_str(input).unwrap();
            assert_eq!(parsed, expected, "deserialize test failed: {}", name);

            // Serialize the expected value and read it back.
            let serialized = serde_json::to_string(&expected).unwrap();
            let reparsed: GenerateContentRequest = serde_json::from_str(&serialized).unwrap();
            assert_eq!(reparsed, expected, "serialize test failed: {}", name);
        }
    }

    /// Checks how `process_contents` repairs the role ordering, including the
    /// merge of consecutive same-role turns and the empty conversation.
    #[test]
    fn process() {
        let cases: Vec<(&str, Vec<Content>, Vec<Content>)> = vec![
            ("[]", vec![], vec![]),
            (
                "[(model, text)]",
                vec![text_turn(Role::Model, "hi")],
                vec![
                    text_turn(Role::User, "Starting the conversation..."),
                    text_turn(Role::Model, "hi"),
                    text_turn(Role::User, "continue"),
                ],
            ),
            (
                "[(user, text)]",
                vec![text_turn(Role::User, "hi")],
                vec![text_turn(Role::User, "hi")],
            ),
            (
                "[(user, text), (model, text), (user, text)]",
                vec![
                    text_turn(Role::User, "a"),
                    text_turn(Role::Model, "b"),
                    text_turn(Role::User, "c"),
                ],
                vec![
                    text_turn(Role::User, "a"),
                    text_turn(Role::Model, "b"),
                    text_turn(Role::User, "c"),
                ],
            ),
            (
                "[(user, text), (user, text), (model, text)]",
                vec![
                    text_turn(Role::User, "a"),
                    text_turn(Role::User, "b"),
                    text_turn(Role::Model, "c"),
                ],
                vec![
                    text_turns(Role::User, &["a", "b"]),
                    text_turn(Role::Model, "c"),
                    text_turn(Role::User, "continue"),
                ],
            ),
            (
                "[(model, text), (model, text), (user, text)]",
                vec![
                    text_turn(Role::Model, "a"),
                    text_turn(Role::Model, "b"),
                    text_turn(Role::User, "c"),
                ],
                vec![
                    text_turn(Role::User, "Starting the conversation..."),
                    text_turns(Role::Model, &["a", "b"]),
                    text_turn(Role::User, "c"),
                ],
            ),
        ];

        for (name, contents, want) in cases {
            let got = process_contents(&contents);
            assert_eq!(got, want, "test failed: {}", name)
        }
    }
}