/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash-8b` completion model
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Legacy name kept for backward compatibility; prefer [`GEMINI_1_5_FLASH_8B`].
///
/// The Gemini API has no `gemini-1.5-pro-8b` model, so the former value
/// (`"gemini-1.5-pro-8b"`) could never be served. The 8B variant of Gemini 1.5
/// is `gemini-1.5-flash-8b`, which this constant now points to.
pub const GEMINI_1_5_PRO_8B: &str = GEMINI_1_5_FLASH_8B;
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
16
17use gemini_api_types::{
18 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
19 GenerationConfig, Part, Role, Tool,
20};
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23
24use crate::{
25 completion::{self, CompletionError, CompletionRequest},
26 OneOrMany,
27};
28
29use super::Client;
30
31#[derive(Clone)]
36pub struct CompletionModel {
37 client: Client,
38 pub model: String,
39}
40
41impl CompletionModel {
42 pub fn new(client: Client, model: &str) -> Self {
43 Self {
44 client,
45 model: model.to_string(),
46 }
47 }
48}
49
50impl completion::CompletionModel for CompletionModel {
51 type Response = GenerateContentResponse;
52
53 #[cfg_attr(feature = "worker", worker::send)]
54 async fn completion(
55 &self,
56 mut completion_request: CompletionRequest,
57 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
58 let mut full_history = Vec::new();
59 full_history.append(&mut completion_request.chat_history);
60
61 full_history.push(completion_request.prompt_with_context());
62
63 let additional_params = completion_request
65 .additional_params
66 .unwrap_or_else(|| Value::Object(Map::new()));
67 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
68
69 if let Some(temp) = completion_request.temperature {
71 generation_config.temperature = Some(temp);
72 }
73
74 if let Some(max_tokens) = completion_request.max_tokens {
76 generation_config.max_output_tokens = Some(max_tokens);
77 }
78
79 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
80 parts: OneOrMany::one(preamble.into()),
81 role: Some(Role::Model),
82 });
83
84 let request = GenerateContentRequest {
85 contents: full_history
86 .into_iter()
87 .map(|msg| {
88 msg.try_into()
89 .map_err(|e| CompletionError::RequestError(Box::new(e)))
90 })
91 .collect::<Result<Vec<_>, _>>()?,
92 generation_config: Some(generation_config),
93 safety_settings: None,
94 tools: Some(
95 completion_request
96 .tools
97 .into_iter()
98 .map(Tool::try_from)
99 .collect::<Result<Vec<_>, _>>()?,
100 ),
101 tool_config: None,
102 system_instruction,
103 };
104
105 tracing::debug!(
106 "Sending completion request to Gemini API {}",
107 serde_json::to_string_pretty(&request)?
108 );
109
110 let response = self
111 .client
112 .post(&format!("/v1beta/models/{}:generateContent", self.model))
113 .json(&request)
114 .send()
115 .await?;
116
117 if response.status().is_success() {
118 let response = response.json::<GenerateContentResponse>().await?;
119 match response.usage_metadata {
120 Some(ref usage) => tracing::info!(target: "rig",
121 "Gemini completion token usage: {}",
122 usage
123 ),
124 None => tracing::info!(target: "rig",
125 "Gemini completion token usage: n/a",
126 ),
127 }
128
129 tracing::debug!("Received response");
130
131 Ok(completion::CompletionResponse::try_from(response))
132 } else {
133 Err(CompletionError::ProviderError(response.text().await?))
134 }?
135 }
136}
137
138impl TryFrom<completion::ToolDefinition> for Tool {
139 type Error = CompletionError;
140
141 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
142 Ok(Self {
143 function_declarations: FunctionDeclaration {
144 name: tool.name,
145 description: tool.description,
146 parameters: Some(tool.parameters.try_into()?),
147 },
148 code_execution: None,
149 })
150 }
151}
152
153impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
154 type Error = CompletionError;
155
156 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
157 let candidate = response.candidates.first().ok_or_else(|| {
158 CompletionError::ResponseError("No response candidates in response".into())
159 })?;
160
161 let content = candidate
162 .content
163 .parts
164 .iter()
165 .map(|part| {
166 Ok(match part {
167 Part::Text(text) => completion::AssistantContent::text(text),
168 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
169 &function_call.name,
170 &function_call.name,
171 function_call.args.clone(),
172 ),
173 _ => {
174 return Err(CompletionError::ResponseError(
175 "Response did not contain a message or tool call".into(),
176 ))
177 }
178 })
179 })
180 .collect::<Result<Vec<_>, _>>()?;
181
182 let choice = OneOrMany::many(content).map_err(|_| {
183 CompletionError::ResponseError(
184 "Response contained no message or tool call (empty)".to_owned(),
185 )
186 })?;
187
188 Ok(completion::CompletionResponse {
189 choice,
190 raw_response: response,
191 })
192 }
193}
194
195pub mod gemini_api_types {
196 use std::{collections::HashMap, convert::Infallible, str::FromStr};
197
198 use serde::{Deserialize, Serialize};
202 use serde_json::Value;
203
204 use crate::{
205 completion::CompletionError,
206 message::{self, MimeType as _},
207 one_or_many::string_or_one_or_many,
208 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
209 OneOrMany,
210 };
211
212 #[derive(Debug, Deserialize)]
220 #[serde(rename_all = "camelCase")]
221 pub struct GenerateContentResponse {
222 pub candidates: Vec<ContentCandidate>,
224 pub prompt_feedback: Option<PromptFeedback>,
226 pub usage_metadata: Option<UsageMetadata>,
228 pub model_version: Option<String>,
229 }
230
231 #[derive(Debug, Deserialize)]
233 #[serde(rename_all = "camelCase")]
234 pub struct ContentCandidate {
235 pub content: Content,
237 pub finish_reason: Option<FinishReason>,
240 pub safety_ratings: Option<Vec<SafetyRating>>,
243 pub citation_metadata: Option<CitationMetadata>,
247 pub token_count: Option<i32>,
249 pub avg_logprobs: Option<f64>,
251 pub logprobs_result: Option<LogprobsResult>,
253 pub index: Option<i32>,
255 }
256 #[derive(Debug, Deserialize, Serialize)]
257 pub struct Content {
258 #[serde(deserialize_with = "string_or_one_or_many")]
260 pub parts: OneOrMany<Part>,
261 pub role: Option<Role>,
264 }
265
266 impl TryFrom<message::Message> for Content {
267 type Error = message::MessageError;
268
269 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
270 Ok(match msg {
271 message::Message::User { content } => Content {
272 parts: content.try_map(|c| c.try_into())?,
273 role: Some(Role::User),
274 },
275 message::Message::Assistant { content } => Content {
276 role: Some(Role::Model),
277 parts: content.map(|content| content.into()),
278 },
279 })
280 }
281 }
282
283 impl TryFrom<Content> for message::Message {
284 type Error = message::MessageError;
285
286 fn try_from(content: Content) -> Result<Self, Self::Error> {
287 match content.role {
288 Some(Role::User) | None => Ok(message::Message::User {
289 content: content.parts.try_map(|part| {
290 Ok(match part {
291 Part::Text(text) => message::UserContent::text(text),
292 Part::InlineData(inline_data) => {
293 let mime_type =
294 message::MediaType::from_mime_type(&inline_data.mime_type);
295
296 match mime_type {
297 Some(message::MediaType::Image(media_type)) => {
298 message::UserContent::image(
299 inline_data.data,
300 Some(message::ContentFormat::default()),
301 Some(media_type),
302 Some(message::ImageDetail::default()),
303 )
304 }
305 Some(message::MediaType::Document(media_type)) => {
306 message::UserContent::document(
307 inline_data.data,
308 Some(message::ContentFormat::default()),
309 Some(media_type),
310 )
311 }
312 Some(message::MediaType::Audio(media_type)) => {
313 message::UserContent::audio(
314 inline_data.data,
315 Some(message::ContentFormat::default()),
316 Some(media_type),
317 )
318 }
319 _ => {
320 return Err(message::MessageError::ConversionError(
321 format!("Unsupported media type {:?}", mime_type),
322 ))
323 }
324 }
325 }
326 _ => {
327 return Err(message::MessageError::ConversionError(format!(
328 "Unsupported gemini content part type: {:?}",
329 part
330 )))
331 }
332 })
333 })?,
334 }),
335 Some(Role::Model) => Ok(message::Message::Assistant {
336 content: content.parts.try_map(|part| {
337 Ok(match part {
338 Part::Text(text) => message::AssistantContent::text(text),
339 Part::FunctionCall(function_call) => {
340 message::AssistantContent::ToolCall(function_call.into())
341 }
342 _ => {
343 return Err(message::MessageError::ConversionError(format!(
344 "Unsupported part type: {:?}",
345 part
346 )))
347 }
348 })
349 })?,
350 }),
351 }
352 }
353 }
354
355 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
356 #[serde(rename_all = "lowercase")]
357 pub enum Role {
358 User,
359 Model,
360 }
361
362 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
366 #[serde(rename_all = "camelCase")]
367 pub enum Part {
368 Text(String),
369 InlineData(Blob),
370 FunctionCall(FunctionCall),
371 FunctionResponse(FunctionResponse),
372 FileData(FileData),
373 ExecutableCode(ExecutableCode),
374 CodeExecutionResult(CodeExecutionResult),
375 }
376
377 impl From<String> for Part {
378 fn from(text: String) -> Self {
379 Self::Text(text)
380 }
381 }
382
383 impl From<&str> for Part {
384 fn from(text: &str) -> Self {
385 Self::Text(text.to_string())
386 }
387 }
388
389 impl FromStr for Part {
390 type Err = Infallible;
391
392 fn from_str(s: &str) -> Result<Self, Self::Err> {
393 Ok(s.into())
394 }
395 }
396
397 impl TryFrom<message::UserContent> for Part {
398 type Error = message::MessageError;
399
400 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
401 match content {
402 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
403 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
404 let content = match content.first() {
405 message::ToolResultContent::Text(text) => text.text,
406 message::ToolResultContent::Image(_) => {
407 return Err(message::MessageError::ConversionError(
408 "Tool result content must be text".to_string(),
409 ))
410 }
411 };
412 Ok(Part::FunctionResponse(FunctionResponse {
413 name: id,
414 response: Some(serde_json::from_str(&content).map_err(|e| {
415 message::MessageError::ConversionError(format!(
416 "Failed to parse tool response: {}",
417 e
418 ))
419 })?),
420 }))
421 }
422 message::UserContent::Image(message::Image {
423 data, media_type, ..
424 }) => match media_type {
425 Some(media_type) => match media_type {
426 message::ImageMediaType::JPEG
427 | message::ImageMediaType::PNG
428 | message::ImageMediaType::WEBP
429 | message::ImageMediaType::HEIC
430 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
431 mime_type: media_type.to_mime_type().to_owned(),
432 data,
433 })),
434 _ => Err(message::MessageError::ConversionError(format!(
435 "Unsupported image media type {:?}",
436 media_type
437 ))),
438 },
439 None => Err(message::MessageError::ConversionError(
440 "Media type for image is required for Anthropic".to_string(),
441 )),
442 },
443 message::UserContent::Document(message::Document {
444 data, media_type, ..
445 }) => match media_type {
446 Some(media_type) => match media_type {
447 message::DocumentMediaType::PDF
448 | message::DocumentMediaType::TXT
449 | message::DocumentMediaType::RTF
450 | message::DocumentMediaType::HTML
451 | message::DocumentMediaType::CSS
452 | message::DocumentMediaType::MARKDOWN
453 | message::DocumentMediaType::CSV
454 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
455 mime_type: media_type.to_mime_type().to_owned(),
456 data,
457 })),
458 _ => Err(message::MessageError::ConversionError(format!(
459 "Unsupported document media type {:?}",
460 media_type
461 ))),
462 },
463 None => Err(message::MessageError::ConversionError(
464 "Media type for document is required for Anthropic".to_string(),
465 )),
466 },
467 message::UserContent::Audio(message::Audio {
468 data, media_type, ..
469 }) => match media_type {
470 Some(media_type) => Ok(Self::InlineData(Blob {
471 mime_type: media_type.to_mime_type().to_owned(),
472 data,
473 })),
474 None => Err(message::MessageError::ConversionError(
475 "Media type for audio is required for Anthropic".to_string(),
476 )),
477 },
478 }
479 }
480 }
481
482 impl From<message::AssistantContent> for Part {
483 fn from(content: message::AssistantContent) -> Self {
484 match content {
485 message::AssistantContent::Text(message::Text { text }) => text.into(),
486 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
487 }
488 }
489 }
490
491 impl From<message::ToolCall> for Part {
492 fn from(tool_call: message::ToolCall) -> Self {
493 Self::FunctionCall(FunctionCall {
494 name: tool_call.function.name,
495 args: tool_call.function.arguments,
496 })
497 }
498 }
499
500 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
503 #[serde(rename_all = "camelCase")]
504 pub struct Blob {
505 pub mime_type: String,
508 pub data: String,
510 }
511
512 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
515 pub struct FunctionCall {
516 pub name: String,
519 pub args: serde_json::Value,
521 }
522
523 impl From<FunctionCall> for message::ToolCall {
524 fn from(function_call: FunctionCall) -> Self {
525 Self {
526 id: function_call.name.clone(),
527 function: message::ToolFunction {
528 name: function_call.name,
529 arguments: function_call.args,
530 },
531 }
532 }
533 }
534
535 impl From<message::ToolCall> for FunctionCall {
536 fn from(tool_call: message::ToolCall) -> Self {
537 Self {
538 name: tool_call.function.name,
539 args: tool_call.function.arguments,
540 }
541 }
542 }
543
544 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
548 pub struct FunctionResponse {
549 pub name: String,
552 pub response: Option<HashMap<String, serde_json::Value>>,
554 }
555
556 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
558 #[serde(rename_all = "camelCase")]
559 pub struct FileData {
560 pub mime_type: Option<String>,
562 pub file_uri: String,
564 }
565
566 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
567 pub struct SafetyRating {
568 pub category: HarmCategory,
569 pub probability: HarmProbability,
570 }
571
572 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
573 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
574 pub enum HarmProbability {
575 HarmProbabilityUnspecified,
576 Negligible,
577 Low,
578 Medium,
579 High,
580 }
581
582 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
583 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
584 pub enum HarmCategory {
585 HarmCategoryUnspecified,
586 HarmCategoryDerogatory,
587 HarmCategoryToxicity,
588 HarmCategoryViolence,
589 HarmCategorySexually,
590 HarmCategoryMedical,
591 HarmCategoryDangerous,
592 HarmCategoryHarassment,
593 HarmCategoryHateSpeech,
594 HarmCategorySexuallyExplicit,
595 HarmCategoryDangerousContent,
596 HarmCategoryCivicIntegrity,
597 }
598
599 #[derive(Debug, Deserialize)]
600 #[serde(rename_all = "camelCase")]
601 pub struct UsageMetadata {
602 pub prompt_token_count: i32,
603 #[serde(skip_serializing_if = "Option::is_none")]
604 pub cached_content_token_count: Option<i32>,
605 pub candidates_token_count: i32,
606 pub total_token_count: i32,
607 }
608
609 impl std::fmt::Display for UsageMetadata {
610 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611 write!(
612 f,
613 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
614 self.prompt_token_count,
615 match self.cached_content_token_count {
616 Some(count) => count.to_string(),
617 None => "n/a".to_string(),
618 },
619 self.candidates_token_count,
620 self.total_token_count
621 )
622 }
623 }
624
625 #[derive(Debug, Deserialize)]
627 #[serde(rename_all = "camelCase")]
628 pub struct PromptFeedback {
629 pub block_reason: Option<BlockReason>,
631 pub safety_ratings: Option<Vec<SafetyRating>>,
633 }
634
635 #[derive(Debug, Deserialize)]
637 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
638 pub enum BlockReason {
639 BlockReasonUnspecified,
641 Safety,
643 Other,
645 Blocklist,
647 ProhibitedContent,
649 }
650
651 #[derive(Debug, Deserialize)]
652 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
653 pub enum FinishReason {
654 FinishReasonUnspecified,
656 Stop,
658 MaxTokens,
660 Safety,
662 Recitation,
664 Language,
666 Other,
668 Blocklist,
670 ProhibitedContent,
672 Spii,
674 MalformedFunctionCall,
676 }
677
678 #[derive(Debug, Deserialize)]
679 #[serde(rename_all = "camelCase")]
680 pub struct CitationMetadata {
681 pub citation_sources: Vec<CitationSource>,
682 }
683
684 #[derive(Debug, Deserialize)]
685 #[serde(rename_all = "camelCase")]
686 pub struct CitationSource {
687 #[serde(skip_serializing_if = "Option::is_none")]
688 pub uri: Option<String>,
689 #[serde(skip_serializing_if = "Option::is_none")]
690 pub start_index: Option<i32>,
691 #[serde(skip_serializing_if = "Option::is_none")]
692 pub end_index: Option<i32>,
693 #[serde(skip_serializing_if = "Option::is_none")]
694 pub license: Option<String>,
695 }
696
697 #[derive(Debug, Deserialize)]
698 #[serde(rename_all = "camelCase")]
699 pub struct LogprobsResult {
700 pub top_candidate: Vec<TopCandidate>,
701 pub chosen_candidate: Vec<LogProbCandidate>,
702 }
703
704 #[derive(Debug, Deserialize)]
705 pub struct TopCandidate {
706 pub candidates: Vec<LogProbCandidate>,
707 }
708
709 #[derive(Debug, Deserialize)]
710 #[serde(rename_all = "camelCase")]
711 pub struct LogProbCandidate {
712 pub token: String,
713 pub token_id: String,
714 pub log_probability: f64,
715 }
716
717 #[derive(Debug, Deserialize, Serialize)]
722 #[serde(rename_all = "camelCase")]
723 pub struct GenerationConfig {
724 #[serde(skip_serializing_if = "Option::is_none")]
727 pub stop_sequences: Option<Vec<String>>,
728 #[serde(skip_serializing_if = "Option::is_none")]
734 pub response_mime_type: Option<String>,
735 #[serde(skip_serializing_if = "Option::is_none")]
739 pub response_schema: Option<Schema>,
740 #[serde(skip_serializing_if = "Option::is_none")]
743 pub candidate_count: Option<i32>,
744 #[serde(skip_serializing_if = "Option::is_none")]
747 pub max_output_tokens: Option<u64>,
748 #[serde(skip_serializing_if = "Option::is_none")]
751 pub temperature: Option<f64>,
752 #[serde(skip_serializing_if = "Option::is_none")]
759 pub top_p: Option<f64>,
760 #[serde(skip_serializing_if = "Option::is_none")]
766 pub top_k: Option<i32>,
767 #[serde(skip_serializing_if = "Option::is_none")]
773 pub presence_penalty: Option<f64>,
774 #[serde(skip_serializing_if = "Option::is_none")]
782 pub frequency_penalty: Option<f64>,
783 #[serde(skip_serializing_if = "Option::is_none")]
785 pub response_logprobs: Option<bool>,
786 #[serde(skip_serializing_if = "Option::is_none")]
789 pub logprobs: Option<i32>,
790 }
791
792 impl Default for GenerationConfig {
793 fn default() -> Self {
794 Self {
795 temperature: Some(1.0),
796 max_output_tokens: Some(4096),
797 stop_sequences: None,
798 response_mime_type: None,
799 response_schema: None,
800 candidate_count: None,
801 top_p: None,
802 top_k: None,
803 presence_penalty: None,
804 frequency_penalty: None,
805 response_logprobs: None,
806 logprobs: None,
807 }
808 }
809 }
810 #[derive(Debug, Deserialize, Serialize)]
814 pub struct Schema {
815 pub r#type: String,
816 #[serde(skip_serializing_if = "Option::is_none")]
817 pub format: Option<String>,
818 #[serde(skip_serializing_if = "Option::is_none")]
819 pub description: Option<String>,
820 #[serde(skip_serializing_if = "Option::is_none")]
821 pub nullable: Option<bool>,
822 #[serde(skip_serializing_if = "Option::is_none")]
823 pub r#enum: Option<Vec<String>>,
824 #[serde(skip_serializing_if = "Option::is_none")]
825 pub max_items: Option<i32>,
826 #[serde(skip_serializing_if = "Option::is_none")]
827 pub min_items: Option<i32>,
828 #[serde(skip_serializing_if = "Option::is_none")]
829 pub properties: Option<HashMap<String, Schema>>,
830 #[serde(skip_serializing_if = "Option::is_none")]
831 pub required: Option<Vec<String>>,
832 #[serde(skip_serializing_if = "Option::is_none")]
833 pub items: Option<Box<Schema>>,
834 }
835
836 impl TryFrom<Value> for Schema {
837 type Error = CompletionError;
838
839 fn try_from(value: Value) -> Result<Self, Self::Error> {
840 if let Some(obj) = value.as_object() {
841 Ok(Schema {
842 r#type: obj
843 .get("type")
844 .and_then(|v| {
845 if v.is_string() {
846 v.as_str().map(String::from)
847 } else if v.is_array() {
848 v.as_array()
849 .and_then(|arr| arr.first())
850 .and_then(|v| v.as_str().map(String::from))
851 } else {
852 None
853 }
854 })
855 .unwrap_or_default(),
856 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
857 description: obj
858 .get("description")
859 .and_then(|v| v.as_str())
860 .map(String::from),
861 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
862 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
863 arr.iter()
864 .filter_map(|v| v.as_str().map(String::from))
865 .collect()
866 }),
867 max_items: obj
868 .get("maxItems")
869 .and_then(|v| v.as_i64())
870 .map(|v| v as i32),
871 min_items: obj
872 .get("minItems")
873 .and_then(|v| v.as_i64())
874 .map(|v| v as i32),
875 properties: obj
876 .get("properties")
877 .and_then(|v| v.as_object())
878 .map(|map| {
879 map.iter()
880 .filter_map(|(k, v)| {
881 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
882 })
883 .collect()
884 }),
885 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
886 arr.iter()
887 .filter_map(|v| v.as_str().map(String::from))
888 .collect()
889 }),
890 items: obj
891 .get("items")
892 .map(|v| Box::new(v.clone().try_into().unwrap())),
893 })
894 } else {
895 Err(CompletionError::ResponseError(
896 "Expected a JSON object for Schema".into(),
897 ))
898 }
899 }
900 }
901
902 #[derive(Debug, Serialize)]
903 #[serde(rename_all = "camelCase")]
904 pub struct GenerateContentRequest {
905 pub contents: Vec<Content>,
906 pub tools: Option<Vec<Tool>>,
907 pub tool_config: Option<ToolConfig>,
908 pub generation_config: Option<GenerationConfig>,
910 pub safety_settings: Option<Vec<SafetySetting>>,
924 pub system_instruction: Option<Content>,
927 }
929
930 #[derive(Debug, Serialize)]
931 #[serde(rename_all = "camelCase")]
932 pub struct Tool {
933 pub function_declarations: FunctionDeclaration,
934 pub code_execution: Option<CodeExecution>,
935 }
936
937 #[derive(Debug, Serialize)]
938 #[serde(rename_all = "camelCase")]
939 pub struct FunctionDeclaration {
940 pub name: String,
941 pub description: String,
942 #[serde(skip_serializing_if = "Option::is_none")]
943 pub parameters: Option<Schema>,
944 }
945
946 #[derive(Debug, Serialize)]
947 #[serde(rename_all = "camelCase")]
948 pub struct ToolConfig {
949 pub schema: Option<Schema>,
950 }
951
952 #[derive(Debug, Serialize)]
953 #[serde(rename_all = "camelCase")]
954 pub struct CodeExecution {}
955
956 #[derive(Debug, Serialize)]
957 #[serde(rename_all = "camelCase")]
958 pub struct SafetySetting {
959 pub category: HarmCategory,
960 pub threshold: HarmBlockThreshold,
961 }
962
963 #[derive(Debug, Serialize)]
964 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
965 pub enum HarmBlockThreshold {
966 HarmBlockThresholdUnspecified,
967 BlockLowAndAbove,
968 BlockMediumAndAbove,
969 BlockOnlyHigh,
970 BlockNone,
971 Off,
972 }
973}
974
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// Every Gemini part kind deserializes from its camelCase wire form.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error reports which part failed if deserialization breaks.
        let content: Content = {
            let mut deserializer = serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(&mut deserializer)
                .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        match &parts[0] {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }

        match &parts[1] {
            Part::InlineData(inline_data) => {
                assert_eq!(inline_data.mime_type, "image/png");
                assert_eq!(inline_data.data, "base64encodeddata");
            }
            _ => panic!("Expected inline data part"),
        }

        match &parts[2] {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                let args = function_call.args.as_object().unwrap();
                assert_eq!(args.get("arg1").unwrap(), "value1");
            }
            _ => panic!("Expected function call part"),
        }

        match &parts[3] {
            Part::FunctionResponse(function_response) => {
                assert_eq!(function_response.name, "test_function");
                let response = function_response.response.as_ref().unwrap();
                assert_eq!(response.get("result").unwrap(), "success");
            }
            _ => panic!("Expected function response part"),
        }

        match &parts[4] {
            Part::FileData(file_data) => {
                assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
                assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
            }
            _ => panic!("Expected file data part"),
        }

        match &parts[5] {
            Part::ExecutableCode(executable_code) => {
                assert_eq!(executable_code.code, "print('Hello, world!')");
            }
            _ => panic!("Expected executable code part"),
        }

        match &parts[6] {
            Part::CodeExecutionResult(code_execution_result) => {
                assert_eq!(
                    code_execution_result.clone().output.unwrap(),
                    "Hello, world!"
                );
            }
            _ => panic!("Expected code execution result part"),
        }
    }

    /// A model turn with a single text part deserializes with the `model` role.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig user message converts into `user` content with one text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig assistant message converts into `model` content with one text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// An assistant tool call converts into a function-call part keeping name and args.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                let args = function_call.args.as_object().unwrap();
                assert_eq!(args.get("arg1").unwrap(), "value1");
            }
            _ => panic!("Expected function call part"),
        }
    }
}