/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash-8b` completion model (the 8B variant of Gemini 1.5)
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Legacy alias kept for backward compatibility.
///
/// The Gemini API has no `gemini-1.5-pro-8b` model, so the previous value of this constant
/// could never be served; the 8B Gemini 1.5 model is `gemini-1.5-flash-8b`, which this now
/// resolves to. Prefer [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = GEMINI_1_5_FLASH_8B;
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
16
17use gemini_api_types::{
18 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
19 GenerationConfig, Part, Role, Tool,
20};
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23
24use crate::{
25 completion::{self, CompletionError, CompletionRequest},
26 OneOrMany,
27};
28
29use self::gemini_api_types::Schema;
30
31use super::Client;
32
33#[derive(Clone)]
38pub struct CompletionModel {
39 pub(crate) client: Client,
40 pub model: String,
41}
42
43impl CompletionModel {
44 pub fn new(client: Client, model: &str) -> Self {
45 Self {
46 client,
47 model: model.to_string(),
48 }
49 }
50}
51
52impl completion::CompletionModel for CompletionModel {
53 type Response = GenerateContentResponse;
54
55 #[cfg_attr(feature = "worker", worker::send)]
56 async fn completion(
57 &self,
58 completion_request: CompletionRequest,
59 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
60 let request = create_request_body(completion_request)?;
61
62 tracing::debug!(
63 "Sending completion request to Gemini API {}",
64 serde_json::to_string_pretty(&request)?
65 );
66
67 let response = self
68 .client
69 .post(&format!("/v1beta/models/{}:generateContent", self.model))
70 .json(&request)
71 .send()
72 .await?;
73
74 if response.status().is_success() {
75 let response = response.json::<GenerateContentResponse>().await?;
76 match response.usage_metadata {
77 Some(ref usage) => tracing::info!(target: "rig",
78 "Gemini completion token usage: {}",
79 usage
80 ),
81 None => tracing::info!(target: "rig",
82 "Gemini completion token usage: n/a",
83 ),
84 }
85
86 tracing::debug!("Received response");
87
88 Ok(completion::CompletionResponse::try_from(response))
89 } else {
90 Err(CompletionError::ProviderError(response.text().await?))
91 }?
92 }
93}
94
95pub(crate) fn create_request_body(
96 completion_request: CompletionRequest,
97) -> Result<GenerateContentRequest, CompletionError> {
98 let mut full_history = Vec::new();
99 full_history.extend(completion_request.chat_history);
100
101 let additional_params = completion_request
102 .additional_params
103 .unwrap_or_else(|| Value::Object(Map::new()));
104
105 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
106
107 if let Some(temp) = completion_request.temperature {
108 generation_config.temperature = Some(temp);
109 }
110
111 if let Some(max_tokens) = completion_request.max_tokens {
112 generation_config.max_output_tokens = Some(max_tokens);
113 }
114
115 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
116 parts: OneOrMany::one(preamble.into()),
117 role: Some(Role::Model),
118 });
119
120 let request = GenerateContentRequest {
121 contents: full_history
122 .into_iter()
123 .map(|msg| {
124 msg.try_into()
125 .map_err(|e| CompletionError::RequestError(Box::new(e)))
126 })
127 .collect::<Result<Vec<_>, _>>()?,
128 generation_config: Some(generation_config),
129 safety_settings: None,
130 tools: Some(
131 completion_request
132 .tools
133 .into_iter()
134 .map(Tool::try_from)
135 .collect::<Result<Vec<_>, _>>()?,
136 ),
137 tool_config: None,
138 system_instruction,
139 };
140
141 Ok(request)
142}
143
144impl TryFrom<completion::ToolDefinition> for Tool {
145 type Error = CompletionError;
146
147 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
148 let parameters: Option<Schema> =
149 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
150 None
151 } else {
152 Some(tool.parameters.try_into()?)
153 };
154 Ok(Self {
155 function_declarations: FunctionDeclaration {
156 name: tool.name,
157 description: tool.description,
158 parameters,
159 },
160 code_execution: None,
161 })
162 }
163}
164
165impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
166 type Error = CompletionError;
167
168 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
169 let candidate = response.candidates.first().ok_or_else(|| {
170 CompletionError::ResponseError("No response candidates in response".into())
171 })?;
172
173 let content = candidate
174 .content
175 .parts
176 .iter()
177 .map(|part| {
178 Ok(match part {
179 Part::Text(text) => completion::AssistantContent::text(text),
180 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
181 &function_call.name,
182 &function_call.name,
183 function_call.args.clone(),
184 ),
185 _ => {
186 return Err(CompletionError::ResponseError(
187 "Response did not contain a message or tool call".into(),
188 ))
189 }
190 })
191 })
192 .collect::<Result<Vec<_>, _>>()?;
193
194 let choice = OneOrMany::many(content).map_err(|_| {
195 CompletionError::ResponseError(
196 "Response contained no message or tool call (empty)".to_owned(),
197 )
198 })?;
199
200 Ok(completion::CompletionResponse {
201 choice,
202 raw_response: response,
203 })
204 }
205}
206
207pub mod gemini_api_types {
208 use std::{collections::HashMap, convert::Infallible, str::FromStr};
209
210 use serde::{Deserialize, Serialize};
214 use serde_json::Value;
215
216 use crate::{
217 completion::CompletionError,
218 message::{self, MimeType as _},
219 one_or_many::string_or_one_or_many,
220 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
221 OneOrMany,
222 };
223
224 #[derive(Debug, Deserialize)]
232 #[serde(rename_all = "camelCase")]
233 pub struct GenerateContentResponse {
234 pub candidates: Vec<ContentCandidate>,
236 pub prompt_feedback: Option<PromptFeedback>,
238 pub usage_metadata: Option<UsageMetadata>,
240 pub model_version: Option<String>,
241 }
242
    /// One candidate response generated by the model (camelCase on the wire).
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content. This field is required, so a candidate without `content` fails to
        /// deserialize. NOTE(review): confirm whether Gemini omits `content` for candidates
        /// stopped by safety filters.
        pub content: Content,
        /// Why generation stopped, if reported.
        pub finish_reason: Option<FinishReason>,
        /// Safety ratings for this candidate, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation sources for this candidate, if reported.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate, if reported.
        pub token_count: Option<i32>,
        /// Average log probability of the candidate, if reported.
        pub avg_logprobs: Option<f64>,
        /// Per-token log-probability details, if reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Position of this candidate in the response's candidate list, if reported.
        pub index: Option<i32>,
    }
    /// A multi-part message exchanged with Gemini. Used for request history, the system
    /// instruction, and candidate output.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered parts of the message. Deserialized through `string_or_one_or_many`, which
        /// presumably also accepts a bare string (per its name) — confirm.
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// Author of the content. There is no `skip_serializing_if`, so `None` is sent as `null`;
        /// a missing role deserializes as `None`.
        pub role: Option<Role>,
    }
277
278 impl TryFrom<message::Message> for Content {
279 type Error = message::MessageError;
280
281 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
282 Ok(match msg {
283 message::Message::User { content } => Content {
284 parts: content.try_map(|c| c.try_into())?,
285 role: Some(Role::User),
286 },
287 message::Message::Assistant { content } => Content {
288 role: Some(Role::Model),
289 parts: content.map(|content| content.into()),
290 },
291 })
292 }
293 }
294
    impl TryFrom<Content> for message::Message {
        type Error = message::MessageError;

        /// Converts Gemini [`Content`] back into a rig [`message::Message`].
        ///
        /// A `user` role (or no role) yields a user message. Text parts map to text, and inline
        /// data maps to image, document, or audio content according to its MIME type. A `model`
        /// role yields an assistant message built from text and function-call parts.
        ///
        /// # Errors
        /// Returns [`message::MessageError::ConversionError`] for any other part type, or for
        /// inline data whose media type is not image, document, or audio.
        fn try_from(content: Content) -> Result<Self, Self::Error> {
            match content.role {
                // Content without a role is treated as user content.
                Some(Role::User) | None => Ok(message::Message::User {
                    content: content.parts.try_map(|part| {
                        Ok(match part {
                            Part::Text(text) => message::UserContent::text(text),
                            Part::InlineData(inline_data) => {
                                // Classify the blob by MIME type to choose the rig content kind.
                                let mime_type =
                                    message::MediaType::from_mime_type(&inline_data.mime_type);

                                // Data is tagged with the default ContentFormat.
                                // NOTE(review): presumably base64 — confirm.
                                match mime_type {
                                    Some(message::MediaType::Image(media_type)) => {
                                        message::UserContent::image(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                            Some(message::ImageDetail::default()),
                                        )
                                    }
                                    Some(message::MediaType::Document(media_type)) => {
                                        message::UserContent::document(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                        )
                                    }
                                    Some(message::MediaType::Audio(media_type)) => {
                                        message::UserContent::audio(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                        )
                                    }
                                    // Unrecognized MIME type or another media category.
                                    _ => {
                                        return Err(message::MessageError::ConversionError(
                                            format!("Unsupported media type {:?}", mime_type),
                                        ))
                                    }
                                }
                            }
                            // Function calls/responses, file data and code parts have no
                            // user-content equivalent here.
                            _ => {
                                return Err(message::MessageError::ConversionError(format!(
                                    "Unsupported gemini content part type: {:?}",
                                    part
                                )))
                            }
                        })
                    })?,
                }),
                Some(Role::Model) => Ok(message::Message::Assistant {
                    content: content.parts.try_map(|part| {
                        Ok(match part {
                            Part::Text(text) => message::AssistantContent::text(text),
                            Part::FunctionCall(function_call) => {
                                message::AssistantContent::ToolCall(function_call.into())
                            }
                            _ => {
                                return Err(message::MessageError::ConversionError(format!(
                                    "Unsupported part type: {:?}",
                                    part
                                )))
                            }
                        })
                    })?,
                }),
            }
        }
    }
366
    /// Author of a [`Content`] entry, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        User,
        Model,
    }
373
    /// One part of a [`Content`] message.
    ///
    /// Uses serde's default external tagging with camelCase variant names. For example,
    /// `Part::Text(s)` is `{"text": s}` and `Part::InlineData(b)` is `{"inlineData": {...}}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        Text(String),
        InlineData(Blob),
        FunctionCall(FunctionCall),
        FunctionResponse(FunctionResponse),
        FileData(FileData),
        ExecutableCode(ExecutableCode),
        CodeExecutionResult(CodeExecutionResult),
    }
388
389 impl From<String> for Part {
390 fn from(text: String) -> Self {
391 Self::Text(text)
392 }
393 }
394
395 impl From<&str> for Part {
396 fn from(text: &str) -> Self {
397 Self::Text(text.to_string())
398 }
399 }
400
401 impl FromStr for Part {
402 type Err = Infallible;
403
404 fn from_str(s: &str) -> Result<Self, Self::Err> {
405 Ok(s.into())
406 }
407 }
408
409 impl TryFrom<message::UserContent> for Part {
410 type Error = message::MessageError;
411
412 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
413 match content {
414 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
415 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
416 let content = match content.first() {
417 message::ToolResultContent::Text(text) => text.text,
418 message::ToolResultContent::Image(_) => {
419 return Err(message::MessageError::ConversionError(
420 "Tool result content must be text".to_string(),
421 ))
422 }
423 };
424 Ok(Part::FunctionResponse(FunctionResponse {
425 name: id,
426 response: Some(serde_json::from_str(&content).map_err(|e| {
427 message::MessageError::ConversionError(format!(
428 "Failed to parse tool response: {}",
429 e
430 ))
431 })?),
432 }))
433 }
434 message::UserContent::Image(message::Image {
435 data, media_type, ..
436 }) => match media_type {
437 Some(media_type) => match media_type {
438 message::ImageMediaType::JPEG
439 | message::ImageMediaType::PNG
440 | message::ImageMediaType::WEBP
441 | message::ImageMediaType::HEIC
442 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
443 mime_type: media_type.to_mime_type().to_owned(),
444 data,
445 })),
446 _ => Err(message::MessageError::ConversionError(format!(
447 "Unsupported image media type {:?}",
448 media_type
449 ))),
450 },
451 None => Err(message::MessageError::ConversionError(
452 "Media type for image is required for Anthropic".to_string(),
453 )),
454 },
455 message::UserContent::Document(message::Document {
456 data, media_type, ..
457 }) => match media_type {
458 Some(media_type) => match media_type {
459 message::DocumentMediaType::PDF
460 | message::DocumentMediaType::TXT
461 | message::DocumentMediaType::RTF
462 | message::DocumentMediaType::HTML
463 | message::DocumentMediaType::CSS
464 | message::DocumentMediaType::MARKDOWN
465 | message::DocumentMediaType::CSV
466 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
467 mime_type: media_type.to_mime_type().to_owned(),
468 data,
469 })),
470 _ => Err(message::MessageError::ConversionError(format!(
471 "Unsupported document media type {:?}",
472 media_type
473 ))),
474 },
475 None => Err(message::MessageError::ConversionError(
476 "Media type for document is required for Anthropic".to_string(),
477 )),
478 },
479 message::UserContent::Audio(message::Audio {
480 data, media_type, ..
481 }) => match media_type {
482 Some(media_type) => Ok(Self::InlineData(Blob {
483 mime_type: media_type.to_mime_type().to_owned(),
484 data,
485 })),
486 None => Err(message::MessageError::ConversionError(
487 "Media type for audio is required for Anthropic".to_string(),
488 )),
489 },
490 }
491 }
492 }
493
494 impl From<message::AssistantContent> for Part {
495 fn from(content: message::AssistantContent) -> Self {
496 match content {
497 message::AssistantContent::Text(message::Text { text }) => text.into(),
498 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
499 }
500 }
501 }
502
503 impl From<message::ToolCall> for Part {
504 fn from(tool_call: message::ToolCall) -> Self {
505 Self::FunctionCall(FunctionCall {
506 name: tool_call.function.name,
507 args: tool_call.function.arguments,
508 })
509 }
510 }
511
    /// Media bytes sent inline, tagged with their MIME type (`inlineData` on the wire).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of the data, e.g. `image/png` (serialized as `mimeType`).
        pub mime_type: String,
        /// The media payload as a string. NOTE(review): presumably base64-encoded — confirm.
        pub data: String,
    }
523
    /// A function call requested by the model. It carries a name and JSON arguments, but no
    /// call id.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Call arguments as arbitrary JSON (typically an object keyed by parameter name).
        pub args: serde_json::Value,
    }
534
535 impl From<FunctionCall> for message::ToolCall {
536 fn from(function_call: FunctionCall) -> Self {
537 Self {
538 id: function_call.name.clone(),
539 function: message::ToolFunction {
540 name: function_call.name,
541 arguments: function_call.args,
542 },
543 }
544 }
545 }
546
547 impl From<message::ToolCall> for FunctionCall {
548 fn from(tool_call: message::ToolCall) -> Self {
549 Self {
550 name: tool_call.function.name,
551 args: tool_call.function.arguments,
552 }
553 }
554 }
555
    /// The result of a function call, sent back to the model (`functionResponse` on the wire).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name of the function that produced this response.
        pub name: String,
        /// Function output as a JSON object keyed by field name.
        pub response: Option<HashMap<String, serde_json::Value>>,
    }
567
    /// A reference to media stored at a URI rather than sent inline (`fileData` on the wire).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file, if known (serialized as `mimeType`).
        pub mime_type: Option<String>,
        /// URI of the file (serialized as `fileUri`).
        pub file_uri: String,
    }
577
    /// A safety rating: the probability that content falls into a harm category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
    }
583
    /// Probability level in a [`SafetyRating`], serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `NEGLIGIBLE`). There is no catch-all variant, so an unknown value fails
    /// deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
593
    /// Harm category used by safety ratings and safety settings, serialized in
    /// SCREAMING_SNAKE_CASE (e.g. `HARM_CATEGORY_HATE_SPEECH`). There is no catch-all variant,
    /// so a category not listed here fails deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
610
611 #[derive(Debug, Deserialize)]
612 #[serde(rename_all = "camelCase")]
613 pub struct UsageMetadata {
614 pub prompt_token_count: i32,
615 #[serde(skip_serializing_if = "Option::is_none")]
616 pub cached_content_token_count: Option<i32>,
617 pub candidates_token_count: i32,
618 pub total_token_count: i32,
619 }
620
621 impl std::fmt::Display for UsageMetadata {
622 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
623 write!(
624 f,
625 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
626 self.prompt_token_count,
627 match self.cached_content_token_count {
628 Some(count) => count.to_string(),
629 None => "n/a".to_string(),
630 },
631 self.candidates_token_count,
632 self.total_token_count
633 )
634 }
635 }
636
    /// Content-filter feedback about the prompt itself.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked, if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
646
    /// Reason a prompt was blocked, deserialized from SCREAMING_SNAKE_CASE (e.g. `SAFETY`).
    /// There is no catch-all variant, so an unlisted reason fails deserialization.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
662
    /// Why the model stopped generating a candidate, deserialized from SCREAMING_SNAKE_CASE
    /// (e.g. `MAX_TOKENS`). There is no catch-all variant, so a finish reason not listed here
    /// makes the enclosing candidate fail to deserialize.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
689
    /// Citation sources attached to a candidate.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        pub citation_sources: Vec<CitationSource>,
    }
695
    /// One cited source. Every field is optional.
    ///
    /// The `skip_serializing_if` attributes have no effect today because the type only
    /// derives `Deserialize`.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the cited span. NOTE(review): presumably an index into the candidate's output — confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the cited span (same caveat as `start_index`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
708
709 #[derive(Debug, Deserialize)]
710 #[serde(rename_all = "camelCase")]
711 pub struct LogprobsResult {
712 pub top_candidate: Vec<TopCandidate>,
713 pub chosen_candidate: Vec<LogProbCandidate>,
714 }
715
    /// Candidate tokens considered at one decoding step.
    #[derive(Debug, Deserialize)]
    pub struct TopCandidate {
        pub candidates: Vec<LogProbCandidate>,
    }
720
721 #[derive(Debug, Deserialize)]
722 #[serde(rename_all = "camelCase")]
723 pub struct LogProbCandidate {
724 pub token: String,
725 pub token_id: String,
726 pub log_probability: f64,
727 }
728
    /// Generation settings sent as `generationConfig`.
    ///
    /// Fields use camelCase on the wire (`maxOutputTokens`, `topP`, ...), and unset fields are
    /// omitted from the request. `create_request_body` deserializes this struct from a request's
    /// `additional_params`, so those params must use the camelCase names; snake_case keys such
    /// as `top_p` are silently ignored. There is no container-level `#[serde(default)]`, so keys
    /// missing from that JSON become `None`, not the values from the [`Default`] impl.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Gemini `stopSequences` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// Gemini `responseMimeType` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Gemini `responseSchema` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Gemini `candidateCount` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Output-token cap (`maxOutputTokens`); overridden by the request's `max_tokens`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature; overridden by the request's `temperature`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Gemini `topP` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Gemini `topK` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Gemini `presencePenalty` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Gemini `frequencyPenalty` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Gemini `responseLogprobs` flag; when set, candidates may carry a `logprobsResult`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Gemini `logprobs` setting.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
    }
803
804 impl Default for GenerationConfig {
805 fn default() -> Self {
806 Self {
807 temperature: Some(1.0),
808 max_output_tokens: Some(4096),
809 stop_sequences: None,
810 response_mime_type: None,
811 response_schema: None,
812 candidate_count: None,
813 top_p: None,
814 top_k: None,
815 presence_penalty: None,
816 frequency_penalty: None,
817 response_logprobs: None,
818 logprobs: None,
819 }
820 }
821 }
    /// OpenAPI-style schema used for function `parameters`, `responseSchema`, and
    /// [`ToolConfig`].
    ///
    /// NOTE(review): there is no `rename_all`, so multi-word fields serialize under their Rust
    /// names (`max_items`, `min_items`) rather than camelCase. Confirm the API accepts these
    /// (proto JSON parsers usually accept both forms).
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Schema {
        /// Type name, e.g. `"object"` or `"string"`; empty when the source schema had none.
        pub r#type: String,
        /// Optional format hint for the type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Sub-schemas of an object's properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Element schema of an array.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
847
848 impl TryFrom<Value> for Schema {
849 type Error = CompletionError;
850
851 fn try_from(value: Value) -> Result<Self, Self::Error> {
852 if let Some(obj) = value.as_object() {
853 Ok(Schema {
854 r#type: obj
855 .get("type")
856 .and_then(|v| {
857 if v.is_string() {
858 v.as_str().map(String::from)
859 } else if v.is_array() {
860 v.as_array()
861 .and_then(|arr| arr.first())
862 .and_then(|v| v.as_str().map(String::from))
863 } else {
864 None
865 }
866 })
867 .unwrap_or_default(),
868 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
869 description: obj
870 .get("description")
871 .and_then(|v| v.as_str())
872 .map(String::from),
873 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
874 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
875 arr.iter()
876 .filter_map(|v| v.as_str().map(String::from))
877 .collect()
878 }),
879 max_items: obj
880 .get("maxItems")
881 .and_then(|v| v.as_i64())
882 .map(|v| v as i32),
883 min_items: obj
884 .get("minItems")
885 .and_then(|v| v.as_i64())
886 .map(|v| v as i32),
887 properties: obj
888 .get("properties")
889 .and_then(|v| v.as_object())
890 .map(|map| {
891 map.iter()
892 .filter_map(|(k, v)| {
893 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
894 })
895 .collect()
896 }),
897 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
898 arr.iter()
899 .filter_map(|v| v.as_str().map(String::from))
900 .collect()
901 }),
902 items: obj
903 .get("items")
904 .map(|v| Box::new(v.clone().try_into().unwrap())),
905 })
906 } else {
907 Err(CompletionError::ResponseError(
908 "Expected a JSON object for Schema".into(),
909 ))
910 }
911 }
912 }
913
914 #[derive(Debug, Serialize)]
915 #[serde(rename_all = "camelCase")]
916 pub struct GenerateContentRequest {
917 pub contents: Vec<Content>,
918 pub tools: Option<Vec<Tool>>,
919 pub tool_config: Option<ToolConfig>,
920 pub generation_config: Option<GenerationConfig>,
922 pub safety_settings: Option<Vec<SafetySetting>>,
936 pub system_instruction: Option<Content>,
939 }
941
942 #[derive(Debug, Serialize)]
943 #[serde(rename_all = "camelCase")]
944 pub struct Tool {
945 pub function_declarations: FunctionDeclaration,
946 pub code_execution: Option<CodeExecution>,
947 }
948
    /// Declaration of a callable function.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        pub name: String,
        pub description: String,
        /// Parameter schema; omitted from the JSON when the function takes no parameters.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
957
    /// Tool configuration sent as `toolConfig`.
    ///
    /// NOTE(review): Gemini documents `toolConfig` with a `functionCallingConfig` field, not
    /// `schema`. `create_request_body` always sends `None` here — confirm the shape before
    /// populating it.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        pub schema: Option<Schema>,
    }
963
    /// Field-less marker for the code-execution tool; serializes as `{}`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}
967
    /// A blocking threshold applied to one harm category.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        pub category: HarmCategory,
        pub threshold: HarmBlockThreshold,
    }
974
    /// Threshold at which content in a harm category is blocked, serialized in
    /// SCREAMING_SNAKE_CASE (e.g. `BLOCK_MEDIUM_AND_ABOVE`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
985}
986
#[cfg(test)]
mod tests {
    //! Round-trip tests for the Gemini wire types and the rig <-> Gemini message conversions.

    use crate::message;

    use super::*;
    use serde_json::json;

    // A user-role `Content` containing one of every `Part` variant deserializes into the
    // matching variants, in order. `serde_path_to_error` reports the JSON path on failure.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part::Text(text) = &parts[0] {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part::InlineData(inline_data) = &parts[1] {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part::FunctionCall(function_call) = &parts[2] {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part::FunctionResponse(function_response) = &parts[3] {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part::FileData(file_data) = &parts[4] {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part::ExecutableCode(executable_code) = &parts[5] {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part::CodeExecutionResult(code_execution_result) = &parts[6] {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    // A model-role `Content` with a single text part deserializes with `Role::Model`.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Part::Text(text) = &content.parts.first() {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    // A rig user text message converts to user-role content with one text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Part::Text(text) = &content.parts.first() {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    // A rig assistant text message converts to model-role content with one text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Part::Text(text) = &content.parts.first() {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    // An assistant tool call converts to a `FunctionCall` part that keeps the function name
    // and arguments (the tool-call id is not part of the Gemini part).
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Part::FunctionCall(function_call) = &content.parts.first() {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }
}