// ---- Gemini 2.5 family ----------------------------------------------------

/// Model name `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model name `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model name `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model name `gemini-2.5-flash-preview-05-20`.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// Model name `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model name `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";

// ---- Gemini 2.0 family ----------------------------------------------------

/// Model name `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model name `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";

// ---- Gemini 1.x family ----------------------------------------------------

/// Model name `gemini-1.5-flash`.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model name `gemini-1.5-pro`.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model name `gemini-1.5-pro-8b`.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// Model name `gemini-1.0-pro`.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33 OneOrMany,
34 completion::{self, CompletionError, CompletionRequest},
35};
36use gemini_api_types::{
37 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38 GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45#[derive(Clone)]
50pub struct CompletionModel {
51 pub(crate) client: Client,
52 pub model: String,
53}
54
55impl CompletionModel {
56 pub fn new(client: Client, model: &str) -> Self {
57 Self {
58 client,
59 model: model.to_string(),
60 }
61 }
62}
63
64impl completion::CompletionModel for CompletionModel {
65 type Response = GenerateContentResponse;
66 type StreamingResponse = StreamingCompletionResponse;
67
68 #[cfg_attr(feature = "worker", worker::send)]
69 async fn completion(
70 &self,
71 completion_request: CompletionRequest,
72 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73 let request = create_request_body(completion_request)?;
74
75 tracing::debug!(
76 "Sending completion request to Gemini API {}",
77 serde_json::to_string_pretty(&request)?
78 );
79
80 let response = self
81 .client
82 .post(&format!("/v1beta/models/{}:generateContent", self.model))
83 .json(&request)
84 .send()
85 .await?;
86
87 if response.status().is_success() {
88 let response = response.json::<GenerateContentResponse>().await?;
89 match response.usage_metadata {
90 Some(ref usage) => tracing::info!(target: "rig",
91 "Gemini completion token usage: {}",
92 usage
93 ),
94 None => tracing::info!(target: "rig",
95 "Gemini completion token usage: n/a",
96 ),
97 }
98
99 tracing::debug!("Received response");
100
101 Ok(completion::CompletionResponse::try_from(response))
102 } else {
103 Err(CompletionError::ProviderError(response.text().await?))
104 }?
105 }
106
107 #[cfg_attr(feature = "worker", worker::send)]
108 async fn stream(
109 &self,
110 request: CompletionRequest,
111 ) -> Result<
112 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113 CompletionError,
114 > {
115 CompletionModel::stream(self, request).await
116 }
117}
118
119pub(crate) fn create_request_body(
120 completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122 let mut full_history = Vec::new();
123 full_history.extend(completion_request.chat_history);
124
125 let additional_params = completion_request
126 .additional_params
127 .unwrap_or_else(|| Value::Object(Map::new()));
128
129 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131 if let Some(temp) = completion_request.temperature {
132 generation_config.temperature = Some(temp);
133 }
134
135 if let Some(max_tokens) = completion_request.max_tokens {
136 generation_config.max_output_tokens = Some(max_tokens);
137 }
138
139 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140 parts: OneOrMany::one(preamble.into()),
141 role: Some(Role::Model),
142 });
143
144 let request = GenerateContentRequest {
145 contents: full_history
146 .into_iter()
147 .map(|msg| {
148 msg.try_into()
149 .map_err(|e| CompletionError::RequestError(Box::new(e)))
150 })
151 .collect::<Result<Vec<_>, _>>()?,
152 generation_config: Some(generation_config),
153 safety_settings: None,
154 tools: Some(Tool::try_from(completion_request.tools)?),
155 tool_config: None,
156 system_instruction,
157 };
158
159 Ok(request)
160}
161
162impl TryFrom<completion::ToolDefinition> for Tool {
163 type Error = CompletionError;
164
165 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
166 let parameters: Option<Schema> =
167 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
168 None
169 } else {
170 Some(tool.parameters.try_into()?)
171 };
172
173 Ok(Self {
174 function_declarations: vec![FunctionDeclaration {
175 name: tool.name,
176 description: tool.description,
177 parameters,
178 }],
179 code_execution: None,
180 })
181 }
182}
183
184impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
185 type Error = CompletionError;
186
187 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
188 let mut function_declarations = Vec::new();
189
190 for tool in tools {
191 let parameters =
192 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
193 None
194 } else {
195 match tool.parameters.try_into() {
196 Ok(schema) => Some(schema),
197 Err(e) => {
198 let emsg = format!(
199 "Tool '{}' could not be converted to a schema: {:?}",
200 tool.name, e,
201 );
202 return Err(CompletionError::ProviderError(emsg));
203 }
204 }
205 };
206
207 function_declarations.push(FunctionDeclaration {
208 name: tool.name,
209 description: tool.description,
210 parameters,
211 });
212 }
213
214 Ok(Self {
215 function_declarations,
216 code_execution: None,
217 })
218 }
219}
220
221impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
222 type Error = CompletionError;
223
224 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
225 let candidate = response.candidates.first().ok_or_else(|| {
226 CompletionError::ResponseError("No response candidates in response".into())
227 })?;
228
229 let content = candidate
230 .content
231 .parts
232 .iter()
233 .map(|part| {
234 Ok(match part {
235 Part::Text(text) => completion::AssistantContent::text(text),
236 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
237 &function_call.name,
238 &function_call.name,
239 function_call.args.clone(),
240 ),
241 _ => {
242 return Err(CompletionError::ResponseError(
243 "Response did not contain a message or tool call".into(),
244 ));
245 }
246 })
247 })
248 .collect::<Result<Vec<_>, _>>()?;
249
250 let choice = OneOrMany::many(content).map_err(|_| {
251 CompletionError::ResponseError(
252 "Response contained no message or tool call (empty)".to_owned(),
253 )
254 })?;
255
256 let usage = response
257 .usage_metadata
258 .as_ref()
259 .map(|usage| completion::Usage {
260 input_tokens: usage.prompt_token_count as u64,
261 output_tokens: usage.candidates_token_count as u64,
262 total_tokens: usage.total_token_count as u64,
263 })
264 .unwrap_or_default();
265
266 Ok(completion::CompletionResponse {
267 choice,
268 usage,
269 raw_response: response,
270 })
271 }
272}
273
274pub mod gemini_api_types {
275 use std::{collections::HashMap, convert::Infallible, str::FromStr};
276
277 use serde::{Deserialize, Serialize};
281 use serde_json::{Value, json};
282
283 use crate::{
284 OneOrMany,
285 completion::CompletionError,
286 message::{self, MessageError, MimeType as _},
287 one_or_many::string_or_one_or_many,
288 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
289 };
290
291 #[derive(Debug, Deserialize)]
299 #[serde(rename_all = "camelCase")]
300 pub struct GenerateContentResponse {
301 pub candidates: Vec<ContentCandidate>,
303 pub prompt_feedback: Option<PromptFeedback>,
305 pub usage_metadata: Option<UsageMetadata>,
307 pub model_version: Option<String>,
308 }
309
310 #[derive(Debug, Deserialize)]
312 #[serde(rename_all = "camelCase")]
313 pub struct ContentCandidate {
314 pub content: Content,
316 pub finish_reason: Option<FinishReason>,
319 pub safety_ratings: Option<Vec<SafetyRating>>,
322 pub citation_metadata: Option<CitationMetadata>,
326 pub token_count: Option<i32>,
328 pub avg_logprobs: Option<f64>,
330 pub logprobs_result: Option<LogprobsResult>,
332 pub index: Option<i32>,
334 }
335 #[derive(Debug, Deserialize, Serialize)]
336 pub struct Content {
337 #[serde(deserialize_with = "string_or_one_or_many")]
339 pub parts: OneOrMany<Part>,
340 pub role: Option<Role>,
343 }
344
345 impl TryFrom<message::Message> for Content {
346 type Error = message::MessageError;
347
348 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
349 Ok(match msg {
350 message::Message::User { content } => Content {
351 parts: content.try_map(|c| c.try_into())?,
352 role: Some(Role::User),
353 },
354 message::Message::Assistant { content, .. } => Content {
355 role: Some(Role::Model),
356 parts: content.map(|content| content.into()),
357 },
358 })
359 }
360 }
361
362 impl TryFrom<Content> for message::Message {
363 type Error = message::MessageError;
364
365 fn try_from(content: Content) -> Result<Self, Self::Error> {
366 match content.role {
367 Some(Role::User) | None => Ok(message::Message::User {
368 content: content.parts.try_map(|part| {
369 Ok(match part {
370 Part::Text(text) => message::UserContent::text(text),
371 Part::InlineData(inline_data) => {
372 let mime_type =
373 message::MediaType::from_mime_type(&inline_data.mime_type);
374
375 match mime_type {
376 Some(message::MediaType::Image(media_type)) => {
377 message::UserContent::image(
378 inline_data.data,
379 Some(message::ContentFormat::default()),
380 Some(media_type),
381 Some(message::ImageDetail::default()),
382 )
383 }
384 Some(message::MediaType::Document(media_type)) => {
385 message::UserContent::document(
386 inline_data.data,
387 Some(message::ContentFormat::default()),
388 Some(media_type),
389 )
390 }
391 Some(message::MediaType::Audio(media_type)) => {
392 message::UserContent::audio(
393 inline_data.data,
394 Some(message::ContentFormat::default()),
395 Some(media_type),
396 )
397 }
398 _ => {
399 return Err(message::MessageError::ConversionError(
400 format!("Unsupported media type {mime_type:?}"),
401 ));
402 }
403 }
404 }
405 _ => {
406 return Err(message::MessageError::ConversionError(format!(
407 "Unsupported gemini content part type: {part:?}"
408 )));
409 }
410 })
411 })?,
412 }),
413 Some(Role::Model) => Ok(message::Message::Assistant {
414 id: None,
415 content: content.parts.try_map(|part| {
416 Ok(match part {
417 Part::Text(text) => message::AssistantContent::text(text),
418 Part::FunctionCall(function_call) => {
419 message::AssistantContent::ToolCall(function_call.into())
420 }
421 _ => {
422 return Err(message::MessageError::ConversionError(format!(
423 "Unsupported part type: {part:?}"
424 )));
425 }
426 })
427 })?,
428 }),
429 }
430 }
431 }
432
433 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
434 #[serde(rename_all = "lowercase")]
435 pub enum Role {
436 User,
437 Model,
438 }
439
440 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
444 #[serde(rename_all = "camelCase")]
445 pub enum Part {
446 Text(String),
447 InlineData(Blob),
448 FunctionCall(FunctionCall),
449 FunctionResponse(FunctionResponse),
450 FileData(FileData),
451 ExecutableCode(ExecutableCode),
452 CodeExecutionResult(CodeExecutionResult),
453 }
454
455 impl From<String> for Part {
456 fn from(text: String) -> Self {
457 Self::Text(text)
458 }
459 }
460
461 impl From<&str> for Part {
462 fn from(text: &str) -> Self {
463 Self::Text(text.to_string())
464 }
465 }
466
467 impl FromStr for Part {
468 type Err = Infallible;
469
470 fn from_str(s: &str) -> Result<Self, Self::Err> {
471 Ok(s.into())
472 }
473 }
474
475 impl TryFrom<message::UserContent> for Part {
476 type Error = message::MessageError;
477
478 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
479 match content {
480 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
481 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
482 let content = match content.first() {
483 message::ToolResultContent::Text(text) => text.text,
484 message::ToolResultContent::Image(_) => {
485 return Err(message::MessageError::ConversionError(
486 "Tool result content must be text".to_string(),
487 ));
488 }
489 };
490 let result: serde_json::Value = serde_json::from_str(&content)
492 .map_err(|x| MessageError::ConversionError(x.to_string()))?;
493 Ok(Part::FunctionResponse(FunctionResponse {
494 name: id,
495 response: Some(json!({ "result": result })),
496 }))
497 }
498 message::UserContent::Image(message::Image {
499 data, media_type, ..
500 }) => match media_type {
501 Some(media_type) => match media_type {
502 message::ImageMediaType::JPEG
503 | message::ImageMediaType::PNG
504 | message::ImageMediaType::WEBP
505 | message::ImageMediaType::HEIC
506 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
507 mime_type: media_type.to_mime_type().to_owned(),
508 data,
509 })),
510 _ => Err(message::MessageError::ConversionError(format!(
511 "Unsupported image media type {media_type:?}"
512 ))),
513 },
514 None => Err(message::MessageError::ConversionError(
515 "Media type for image is required for Anthropic".to_string(),
516 )),
517 },
518 message::UserContent::Document(message::Document {
519 data, media_type, ..
520 }) => match media_type {
521 Some(media_type) => match media_type {
522 message::DocumentMediaType::PDF
523 | message::DocumentMediaType::TXT
524 | message::DocumentMediaType::RTF
525 | message::DocumentMediaType::HTML
526 | message::DocumentMediaType::CSS
527 | message::DocumentMediaType::MARKDOWN
528 | message::DocumentMediaType::CSV
529 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
530 mime_type: media_type.to_mime_type().to_owned(),
531 data,
532 })),
533 _ => Err(message::MessageError::ConversionError(format!(
534 "Unsupported document media type {media_type:?}"
535 ))),
536 },
537 None => Err(message::MessageError::ConversionError(
538 "Media type for document is required for Anthropic".to_string(),
539 )),
540 },
541 message::UserContent::Audio(message::Audio {
542 data, media_type, ..
543 }) => match media_type {
544 Some(media_type) => Ok(Self::InlineData(Blob {
545 mime_type: media_type.to_mime_type().to_owned(),
546 data,
547 })),
548 None => Err(message::MessageError::ConversionError(
549 "Media type for audio is required for Anthropic".to_string(),
550 )),
551 },
552 }
553 }
554 }
555
556 impl From<message::AssistantContent> for Part {
557 fn from(content: message::AssistantContent) -> Self {
558 match content {
559 message::AssistantContent::Text(message::Text { text }) => text.into(),
560 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
561 }
562 }
563 }
564
565 impl From<message::ToolCall> for Part {
566 fn from(tool_call: message::ToolCall) -> Self {
567 Self::FunctionCall(FunctionCall {
568 name: tool_call.function.name,
569 args: tool_call.function.arguments,
570 })
571 }
572 }
573
574 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
577 #[serde(rename_all = "camelCase")]
578 pub struct Blob {
579 pub mime_type: String,
582 pub data: String,
584 }
585
586 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
589 pub struct FunctionCall {
590 pub name: String,
593 pub args: serde_json::Value,
595 }
596
597 impl From<FunctionCall> for message::ToolCall {
598 fn from(function_call: FunctionCall) -> Self {
599 Self {
600 id: function_call.name.clone(),
601 call_id: None,
602 function: message::ToolFunction {
603 name: function_call.name,
604 arguments: function_call.args,
605 },
606 }
607 }
608 }
609
610 impl From<message::ToolCall> for FunctionCall {
611 fn from(tool_call: message::ToolCall) -> Self {
612 Self {
613 name: tool_call.function.name,
614 args: tool_call.function.arguments,
615 }
616 }
617 }
618
619 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
623 pub struct FunctionResponse {
624 pub name: String,
627 pub response: Option<serde_json::Value>,
629 }
630
631 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
633 #[serde(rename_all = "camelCase")]
634 pub struct FileData {
635 pub mime_type: Option<String>,
637 pub file_uri: String,
639 }
640
641 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
642 pub struct SafetyRating {
643 pub category: HarmCategory,
644 pub probability: HarmProbability,
645 }
646
647 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
648 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
649 pub enum HarmProbability {
650 HarmProbabilityUnspecified,
651 Negligible,
652 Low,
653 Medium,
654 High,
655 }
656
657 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
658 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
659 pub enum HarmCategory {
660 HarmCategoryUnspecified,
661 HarmCategoryDerogatory,
662 HarmCategoryToxicity,
663 HarmCategoryViolence,
664 HarmCategorySexually,
665 HarmCategoryMedical,
666 HarmCategoryDangerous,
667 HarmCategoryHarassment,
668 HarmCategoryHateSpeech,
669 HarmCategorySexuallyExplicit,
670 HarmCategoryDangerousContent,
671 HarmCategoryCivicIntegrity,
672 }
673
674 #[derive(Debug, Deserialize, Clone, Default)]
675 #[serde(rename_all = "camelCase")]
676 pub struct UsageMetadata {
677 pub prompt_token_count: i32,
678 #[serde(skip_serializing_if = "Option::is_none")]
679 pub cached_content_token_count: Option<i32>,
680 pub candidates_token_count: i32,
681 pub total_token_count: i32,
682 }
683
684 impl std::fmt::Display for UsageMetadata {
685 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
686 write!(
687 f,
688 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
689 self.prompt_token_count,
690 match self.cached_content_token_count {
691 Some(count) => count.to_string(),
692 None => "n/a".to_string(),
693 },
694 self.candidates_token_count,
695 self.total_token_count
696 )
697 }
698 }
699
700 #[derive(Debug, Deserialize)]
702 #[serde(rename_all = "camelCase")]
703 pub struct PromptFeedback {
704 pub block_reason: Option<BlockReason>,
706 pub safety_ratings: Option<Vec<SafetyRating>>,
708 }
709
710 #[derive(Debug, Deserialize)]
712 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
713 pub enum BlockReason {
714 BlockReasonUnspecified,
716 Safety,
718 Other,
720 Blocklist,
722 ProhibitedContent,
724 }
725
726 #[derive(Debug, Deserialize)]
727 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
728 pub enum FinishReason {
729 FinishReasonUnspecified,
731 Stop,
733 MaxTokens,
735 Safety,
737 Recitation,
739 Language,
741 Other,
743 Blocklist,
745 ProhibitedContent,
747 Spii,
749 MalformedFunctionCall,
751 }
752
753 #[derive(Debug, Deserialize)]
754 #[serde(rename_all = "camelCase")]
755 pub struct CitationMetadata {
756 pub citation_sources: Vec<CitationSource>,
757 }
758
759 #[derive(Debug, Deserialize)]
760 #[serde(rename_all = "camelCase")]
761 pub struct CitationSource {
762 #[serde(skip_serializing_if = "Option::is_none")]
763 pub uri: Option<String>,
764 #[serde(skip_serializing_if = "Option::is_none")]
765 pub start_index: Option<i32>,
766 #[serde(skip_serializing_if = "Option::is_none")]
767 pub end_index: Option<i32>,
768 #[serde(skip_serializing_if = "Option::is_none")]
769 pub license: Option<String>,
770 }
771
772 #[derive(Debug, Deserialize)]
773 #[serde(rename_all = "camelCase")]
774 pub struct LogprobsResult {
775 pub top_candidate: Vec<TopCandidate>,
776 pub chosen_candidate: Vec<LogProbCandidate>,
777 }
778
779 #[derive(Debug, Deserialize)]
780 pub struct TopCandidate {
781 pub candidates: Vec<LogProbCandidate>,
782 }
783
784 #[derive(Debug, Deserialize)]
785 #[serde(rename_all = "camelCase")]
786 pub struct LogProbCandidate {
787 pub token: String,
788 pub token_id: String,
789 pub log_probability: f64,
790 }
791
792 #[derive(Debug, Deserialize, Serialize)]
797 #[serde(rename_all = "camelCase")]
798 pub struct GenerationConfig {
799 #[serde(skip_serializing_if = "Option::is_none")]
802 pub stop_sequences: Option<Vec<String>>,
803 #[serde(skip_serializing_if = "Option::is_none")]
809 pub response_mime_type: Option<String>,
810 #[serde(skip_serializing_if = "Option::is_none")]
814 pub response_schema: Option<Schema>,
815 #[serde(skip_serializing_if = "Option::is_none")]
818 pub candidate_count: Option<i32>,
819 #[serde(skip_serializing_if = "Option::is_none")]
822 pub max_output_tokens: Option<u64>,
823 #[serde(skip_serializing_if = "Option::is_none")]
826 pub temperature: Option<f64>,
827 #[serde(skip_serializing_if = "Option::is_none")]
834 pub top_p: Option<f64>,
835 #[serde(skip_serializing_if = "Option::is_none")]
841 pub top_k: Option<i32>,
842 #[serde(skip_serializing_if = "Option::is_none")]
848 pub presence_penalty: Option<f64>,
849 #[serde(skip_serializing_if = "Option::is_none")]
857 pub frequency_penalty: Option<f64>,
858 #[serde(skip_serializing_if = "Option::is_none")]
860 pub response_logprobs: Option<bool>,
861 #[serde(skip_serializing_if = "Option::is_none")]
864 pub logprobs: Option<i32>,
865 }
866
867 impl Default for GenerationConfig {
868 fn default() -> Self {
869 Self {
870 temperature: Some(1.0),
871 max_output_tokens: Some(4096),
872 stop_sequences: None,
873 response_mime_type: None,
874 response_schema: None,
875 candidate_count: None,
876 top_p: None,
877 top_k: None,
878 presence_penalty: None,
879 frequency_penalty: None,
880 response_logprobs: None,
881 logprobs: None,
882 }
883 }
884 }
885 #[derive(Debug, Deserialize, Serialize, Clone)]
889 pub struct Schema {
890 pub r#type: String,
891 #[serde(skip_serializing_if = "Option::is_none")]
892 pub format: Option<String>,
893 #[serde(skip_serializing_if = "Option::is_none")]
894 pub description: Option<String>,
895 #[serde(skip_serializing_if = "Option::is_none")]
896 pub nullable: Option<bool>,
897 #[serde(skip_serializing_if = "Option::is_none")]
898 pub r#enum: Option<Vec<String>>,
899 #[serde(skip_serializing_if = "Option::is_none")]
900 pub max_items: Option<i32>,
901 #[serde(skip_serializing_if = "Option::is_none")]
902 pub min_items: Option<i32>,
903 #[serde(skip_serializing_if = "Option::is_none")]
904 pub properties: Option<HashMap<String, Schema>>,
905 #[serde(skip_serializing_if = "Option::is_none")]
906 pub required: Option<Vec<String>>,
907 #[serde(skip_serializing_if = "Option::is_none")]
908 pub items: Option<Box<Schema>>,
909 }
910
911 impl TryFrom<Value> for Schema {
912 type Error = CompletionError;
913
914 fn try_from(value: Value) -> Result<Self, Self::Error> {
915 if let Some(obj) = value.as_object() {
916 Ok(Schema {
917 r#type: obj
918 .get("type")
919 .and_then(|v| {
920 if v.is_string() {
921 v.as_str().map(String::from)
922 } else if v.is_array() {
923 v.as_array()
924 .and_then(|arr| arr.first())
925 .and_then(|v| v.as_str().map(String::from))
926 } else {
927 None
928 }
929 })
930 .unwrap_or_default(),
931 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
932 description: obj
933 .get("description")
934 .and_then(|v| v.as_str())
935 .map(String::from),
936 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
937 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
938 arr.iter()
939 .filter_map(|v| v.as_str().map(String::from))
940 .collect()
941 }),
942 max_items: obj
943 .get("maxItems")
944 .and_then(|v| v.as_i64())
945 .map(|v| v as i32),
946 min_items: obj
947 .get("minItems")
948 .and_then(|v| v.as_i64())
949 .map(|v| v as i32),
950 properties: obj
951 .get("properties")
952 .and_then(|v| v.as_object())
953 .map(|map| {
954 map.iter()
955 .filter_map(|(k, v)| {
956 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
957 })
958 .collect()
959 }),
960 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
961 arr.iter()
962 .filter_map(|v| v.as_str().map(String::from))
963 .collect()
964 }),
965 items: obj
966 .get("items")
967 .map(|v| Box::new(v.clone().try_into().unwrap())),
968 })
969 } else {
970 Err(CompletionError::ResponseError(
971 "Expected a JSON object for Schema".into(),
972 ))
973 }
974 }
975 }
976
977 #[derive(Debug, Serialize)]
978 #[serde(rename_all = "camelCase")]
979 pub struct GenerateContentRequest {
980 pub contents: Vec<Content>,
981 pub tools: Option<Tool>,
982 pub tool_config: Option<ToolConfig>,
983 pub generation_config: Option<GenerationConfig>,
985 pub safety_settings: Option<Vec<SafetySetting>>,
999 pub system_instruction: Option<Content>,
1002 }
1004
1005 #[derive(Debug, Serialize)]
1006 #[serde(rename_all = "camelCase")]
1007 pub struct Tool {
1008 pub function_declarations: Vec<FunctionDeclaration>,
1009 pub code_execution: Option<CodeExecution>,
1010 }
1011
1012 #[derive(Debug, Serialize, Clone)]
1013 #[serde(rename_all = "camelCase")]
1014 pub struct FunctionDeclaration {
1015 pub name: String,
1016 pub description: String,
1017 #[serde(skip_serializing_if = "Option::is_none")]
1018 pub parameters: Option<Schema>,
1019 }
1020
1021 #[derive(Debug, Serialize)]
1022 #[serde(rename_all = "camelCase")]
1023 pub struct ToolConfig {
1024 pub schema: Option<Schema>,
1025 }
1026
1027 #[derive(Debug, Serialize)]
1028 #[serde(rename_all = "camelCase")]
1029 pub struct CodeExecution {}
1030
1031 #[derive(Debug, Serialize)]
1032 #[serde(rename_all = "camelCase")]
1033 pub struct SafetySetting {
1034 pub category: HarmCategory,
1035 pub threshold: HarmBlockThreshold,
1036 }
1037
1038 #[derive(Debug, Serialize)]
1039 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1040 pub enum HarmBlockThreshold {
1041 HarmBlockThresholdUnspecified,
1042 BlockLowAndAbove,
1043 BlockMediumAndAbove,
1044 BlockOnlyHigh,
1045 BlockNone,
1046 Off,
1047 }
1048}
1049
1050#[cfg(test)]
1051mod tests {
1052 use crate::message;
1053
1054 use super::*;
1055 use serde_json::json;
1056
1057 #[test]
1058 fn test_deserialize_message_user() {
1059 let raw_message = r#"{
1060 "parts": [
1061 {"text": "Hello, world!"},
1062 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1063 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1064 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1065 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1066 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1067 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1068 ],
1069 "role": "user"
1070 }"#;
1071
1072 let content: Content = {
1073 let jd = &mut serde_json::Deserializer::from_str(raw_message);
1074 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1075 panic!("Deserialization error at {}: {}", err.path(), err);
1076 })
1077 };
1078 assert_eq!(content.role, Some(Role::User));
1079 assert_eq!(content.parts.len(), 7);
1080
1081 let parts: Vec<Part> = content.parts.into_iter().collect();
1082
1083 if let Part::Text(text) = &parts[0] {
1084 assert_eq!(text, "Hello, world!");
1085 } else {
1086 panic!("Expected text part");
1087 }
1088
1089 if let Part::InlineData(inline_data) = &parts[1] {
1090 assert_eq!(inline_data.mime_type, "image/png");
1091 assert_eq!(inline_data.data, "base64encodeddata");
1092 } else {
1093 panic!("Expected inline data part");
1094 }
1095
1096 if let Part::FunctionCall(function_call) = &parts[2] {
1097 assert_eq!(function_call.name, "test_function");
1098 assert_eq!(
1099 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1100 "value1"
1101 );
1102 } else {
1103 panic!("Expected function call part");
1104 }
1105
1106 if let Part::FunctionResponse(function_response) = &parts[3] {
1107 assert_eq!(function_response.name, "test_function");
1108 assert_eq!(
1109 function_response
1110 .response
1111 .as_ref()
1112 .unwrap()
1113 .get("result")
1114 .unwrap(),
1115 "success"
1116 );
1117 } else {
1118 panic!("Expected function response part");
1119 }
1120
1121 if let Part::FileData(file_data) = &parts[4] {
1122 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1123 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1124 } else {
1125 panic!("Expected file data part");
1126 }
1127
1128 if let Part::ExecutableCode(executable_code) = &parts[5] {
1129 assert_eq!(executable_code.code, "print('Hello, world!')");
1130 } else {
1131 panic!("Expected executable code part");
1132 }
1133
1134 if let Part::CodeExecutionResult(code_execution_result) = &parts[6] {
1135 assert_eq!(
1136 code_execution_result.clone().output.unwrap(),
1137 "Hello, world!"
1138 );
1139 } else {
1140 panic!("Expected code execution result part");
1141 }
1142 }
1143
1144 #[test]
1145 fn test_deserialize_message_model() {
1146 let json_data = json!({
1147 "parts": [{"text": "Hello, user!"}],
1148 "role": "model"
1149 });
1150
1151 let content: Content = serde_json::from_value(json_data).unwrap();
1152 assert_eq!(content.role, Some(Role::Model));
1153 assert_eq!(content.parts.len(), 1);
1154 if let Part::Text(text) = &content.parts.first() {
1155 assert_eq!(text, "Hello, user!");
1156 } else {
1157 panic!("Expected text part");
1158 }
1159 }
1160
1161 #[test]
1162 fn test_message_conversion_user() {
1163 let msg = message::Message::user("Hello, world!");
1164 let content: Content = msg.try_into().unwrap();
1165 assert_eq!(content.role, Some(Role::User));
1166 assert_eq!(content.parts.len(), 1);
1167 if let Part::Text(text) = &content.parts.first() {
1168 assert_eq!(text, "Hello, world!");
1169 } else {
1170 panic!("Expected text part");
1171 }
1172 }
1173
1174 #[test]
1175 fn test_message_conversion_model() {
1176 let msg = message::Message::assistant("Hello, user!");
1177
1178 let content: Content = msg.try_into().unwrap();
1179 assert_eq!(content.role, Some(Role::Model));
1180 assert_eq!(content.parts.len(), 1);
1181 if let Part::Text(text) = &content.parts.first() {
1182 assert_eq!(text, "Hello, user!");
1183 } else {
1184 panic!("Expected text part");
1185 }
1186 }
1187
1188 #[test]
1189 fn test_message_conversion_tool_call() {
1190 let tool_call = message::ToolCall {
1191 id: "test_tool".to_string(),
1192 call_id: None,
1193 function: message::ToolFunction {
1194 name: "test_function".to_string(),
1195 arguments: json!({"arg1": "value1"}),
1196 },
1197 };
1198
1199 let msg = message::Message::Assistant {
1200 id: None,
1201 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1202 };
1203
1204 let content: Content = msg.try_into().unwrap();
1205 assert_eq!(content.role, Some(Role::Model));
1206 assert_eq!(content.parts.len(), 1);
1207 if let Part::FunctionCall(function_call) = &content.parts.first() {
1208 assert_eq!(function_call.name, "test_function");
1209 assert_eq!(
1210 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1211 "value1"
1212 );
1213 } else {
1214 panic!("Expected function call part");
1215 }
1216 }
1217}