/// `gemini-2.5-pro-preview-06-05` model identifier.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` model identifier.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` model identifier.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` model identifier.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` model identifier.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` model identifier.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` model identifier.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` model identifier.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` model identifier.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` model identifier.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash-8b` model identifier (the 8B variant of Gemini 1.5).
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Legacy identifier kept only for backward compatibility.
///
/// Google's published model list has no `gemini-1.5-pro-8b`; the 8B variant of
/// Gemini 1.5 is `gemini-1.5-flash-8b`. Prefer [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` model identifier.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33 OneOrMany,
34 completion::{self, CompletionError, CompletionRequest},
35};
36use gemini_api_types::{
37 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38 GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45#[derive(Clone)]
50pub struct CompletionModel {
51 pub(crate) client: Client,
52 pub model: String,
53}
54
55impl CompletionModel {
56 pub fn new(client: Client, model: &str) -> Self {
57 Self {
58 client,
59 model: model.to_string(),
60 }
61 }
62}
63
64impl completion::CompletionModel for CompletionModel {
65 type Response = GenerateContentResponse;
66 type StreamingResponse = StreamingCompletionResponse;
67
68 #[cfg_attr(feature = "worker", worker::send)]
69 async fn completion(
70 &self,
71 completion_request: CompletionRequest,
72 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73 let request = create_request_body(completion_request)?;
74
75 tracing::debug!(
76 "Sending completion request to Gemini API {}",
77 serde_json::to_string_pretty(&request)?
78 );
79
80 let response = self
81 .client
82 .post(&format!("/v1beta/models/{}:generateContent", self.model))
83 .json(&request)
84 .send()
85 .await?;
86
87 if response.status().is_success() {
88 let response = response.json::<GenerateContentResponse>().await?;
89 match response.usage_metadata {
90 Some(ref usage) => tracing::info!(target: "rig",
91 "Gemini completion token usage: {}",
92 usage
93 ),
94 None => tracing::info!(target: "rig",
95 "Gemini completion token usage: n/a",
96 ),
97 }
98
99 tracing::debug!("Received response");
100
101 Ok(completion::CompletionResponse::try_from(response))
102 } else {
103 Err(CompletionError::ProviderError(response.text().await?))
104 }?
105 }
106
107 #[cfg_attr(feature = "worker", worker::send)]
108 async fn stream(
109 &self,
110 request: CompletionRequest,
111 ) -> Result<
112 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113 CompletionError,
114 > {
115 CompletionModel::stream(self, request).await
116 }
117}
118
119pub(crate) fn create_request_body(
120 completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122 let mut full_history = Vec::new();
123 full_history.extend(completion_request.chat_history);
124
125 let additional_params = completion_request
126 .additional_params
127 .unwrap_or_else(|| Value::Object(Map::new()));
128
129 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131 if let Some(temp) = completion_request.temperature {
132 generation_config.temperature = Some(temp);
133 }
134
135 if let Some(max_tokens) = completion_request.max_tokens {
136 generation_config.max_output_tokens = Some(max_tokens);
137 }
138
139 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140 parts: OneOrMany::one(preamble.into()),
141 role: Some(Role::Model),
142 });
143
144 let request = GenerateContentRequest {
145 contents: full_history
146 .into_iter()
147 .map(|msg| {
148 msg.try_into()
149 .map_err(|e| CompletionError::RequestError(Box::new(e)))
150 })
151 .collect::<Result<Vec<_>, _>>()?,
152 generation_config: Some(generation_config),
153 safety_settings: None,
154 tools: Some(Tool::try_from(completion_request.tools)?),
155 tool_config: None,
156 system_instruction,
157 };
158
159 Ok(request)
160}
161
162impl TryFrom<completion::ToolDefinition> for Tool {
163 type Error = CompletionError;
164
165 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
166 let parameters: Option<Schema> =
167 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
168 None
169 } else {
170 Some(tool.parameters.try_into()?)
171 };
172 Ok(Self {
173 function_declarations: OneOrMany::one(FunctionDeclaration {
174 name: tool.name,
175 description: tool.description,
176 parameters,
177 }),
178 code_execution: None,
179 })
180 }
181}
182
183impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
184 type Error = CompletionError;
185
186 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
187 let mut functions = Vec::new();
188
189 for tool in tools {
190 let parameters =
191 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
192 None
193 } else {
194 match tool.parameters.try_into() {
195 Ok(schema) => Some(schema),
196 Err(e) => {
197 let emsg = format!(
198 "Tool '{}' could not be converted to a schema: {:?}",
199 tool.name, e,
200 );
201 return Err(CompletionError::ProviderError(emsg));
202 }
203 }
204 };
205
206 functions.push(FunctionDeclaration {
207 name: tool.name,
208 description: tool.description,
209 parameters,
210 });
211 }
212
213 let function_declarations: OneOrMany<FunctionDeclaration> = OneOrMany::many(functions)
214 .map_err(|x| CompletionError::ProviderError(x.to_string()))?;
215
216 Ok(Self {
217 function_declarations,
218 code_execution: None,
219 })
220 }
221}
222
223impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
224 type Error = CompletionError;
225
226 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
227 let candidate = response.candidates.first().ok_or_else(|| {
228 CompletionError::ResponseError("No response candidates in response".into())
229 })?;
230
231 let content = candidate
232 .content
233 .parts
234 .iter()
235 .map(|part| {
236 Ok(match part {
237 Part::Text(text) => completion::AssistantContent::text(text),
238 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
239 &function_call.name,
240 &function_call.name,
241 function_call.args.clone(),
242 ),
243 _ => {
244 return Err(CompletionError::ResponseError(
245 "Response did not contain a message or tool call".into(),
246 ));
247 }
248 })
249 })
250 .collect::<Result<Vec<_>, _>>()?;
251
252 let choice = OneOrMany::many(content).map_err(|_| {
253 CompletionError::ResponseError(
254 "Response contained no message or tool call (empty)".to_owned(),
255 )
256 })?;
257
258 Ok(completion::CompletionResponse {
259 choice,
260 raw_response: response,
261 })
262 }
263}
264
265pub mod gemini_api_types {
266 use std::{collections::HashMap, convert::Infallible, str::FromStr};
267
268 use serde::{Deserialize, Serialize};
272 use serde_json::{Value, json};
273
274 use crate::{
275 OneOrMany,
276 completion::CompletionError,
277 message::{self, MessageError, MimeType as _},
278 one_or_many::string_or_one_or_many,
279 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
280 };
281
282 #[derive(Debug, Deserialize)]
290 #[serde(rename_all = "camelCase")]
291 pub struct GenerateContentResponse {
292 pub candidates: Vec<ContentCandidate>,
294 pub prompt_feedback: Option<PromptFeedback>,
296 pub usage_metadata: Option<UsageMetadata>,
298 pub model_version: Option<String>,
299 }
300
301 #[derive(Debug, Deserialize)]
303 #[serde(rename_all = "camelCase")]
304 pub struct ContentCandidate {
305 pub content: Content,
307 pub finish_reason: Option<FinishReason>,
310 pub safety_ratings: Option<Vec<SafetyRating>>,
313 pub citation_metadata: Option<CitationMetadata>,
317 pub token_count: Option<i32>,
319 pub avg_logprobs: Option<f64>,
321 pub logprobs_result: Option<LogprobsResult>,
323 pub index: Option<i32>,
325 }
326 #[derive(Debug, Deserialize, Serialize)]
327 pub struct Content {
328 #[serde(deserialize_with = "string_or_one_or_many")]
330 pub parts: OneOrMany<Part>,
331 pub role: Option<Role>,
334 }
335
336 impl TryFrom<message::Message> for Content {
337 type Error = message::MessageError;
338
339 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
340 Ok(match msg {
341 message::Message::User { content } => Content {
342 parts: content.try_map(|c| c.try_into())?,
343 role: Some(Role::User),
344 },
345 message::Message::Assistant { content, .. } => Content {
346 role: Some(Role::Model),
347 parts: content.map(|content| content.into()),
348 },
349 })
350 }
351 }
352
353 impl TryFrom<Content> for message::Message {
354 type Error = message::MessageError;
355
356 fn try_from(content: Content) -> Result<Self, Self::Error> {
357 match content.role {
358 Some(Role::User) | None => Ok(message::Message::User {
359 content: content.parts.try_map(|part| {
360 Ok(match part {
361 Part::Text(text) => message::UserContent::text(text),
362 Part::InlineData(inline_data) => {
363 let mime_type =
364 message::MediaType::from_mime_type(&inline_data.mime_type);
365
366 match mime_type {
367 Some(message::MediaType::Image(media_type)) => {
368 message::UserContent::image(
369 inline_data.data,
370 Some(message::ContentFormat::default()),
371 Some(media_type),
372 Some(message::ImageDetail::default()),
373 )
374 }
375 Some(message::MediaType::Document(media_type)) => {
376 message::UserContent::document(
377 inline_data.data,
378 Some(message::ContentFormat::default()),
379 Some(media_type),
380 )
381 }
382 Some(message::MediaType::Audio(media_type)) => {
383 message::UserContent::audio(
384 inline_data.data,
385 Some(message::ContentFormat::default()),
386 Some(media_type),
387 )
388 }
389 _ => {
390 return Err(message::MessageError::ConversionError(
391 format!("Unsupported media type {mime_type:?}"),
392 ));
393 }
394 }
395 }
396 _ => {
397 return Err(message::MessageError::ConversionError(format!(
398 "Unsupported gemini content part type: {part:?}"
399 )));
400 }
401 })
402 })?,
403 }),
404 Some(Role::Model) => Ok(message::Message::Assistant {
405 id: None,
406 content: content.parts.try_map(|part| {
407 Ok(match part {
408 Part::Text(text) => message::AssistantContent::text(text),
409 Part::FunctionCall(function_call) => {
410 message::AssistantContent::ToolCall(function_call.into())
411 }
412 _ => {
413 return Err(message::MessageError::ConversionError(format!(
414 "Unsupported part type: {part:?}"
415 )));
416 }
417 })
418 })?,
419 }),
420 }
421 }
422 }
423
424 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
425 #[serde(rename_all = "lowercase")]
426 pub enum Role {
427 User,
428 Model,
429 }
430
431 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
435 #[serde(rename_all = "camelCase")]
436 pub enum Part {
437 Text(String),
438 InlineData(Blob),
439 FunctionCall(FunctionCall),
440 FunctionResponse(FunctionResponse),
441 FileData(FileData),
442 ExecutableCode(ExecutableCode),
443 CodeExecutionResult(CodeExecutionResult),
444 }
445
446 impl From<String> for Part {
447 fn from(text: String) -> Self {
448 Self::Text(text)
449 }
450 }
451
452 impl From<&str> for Part {
453 fn from(text: &str) -> Self {
454 Self::Text(text.to_string())
455 }
456 }
457
458 impl FromStr for Part {
459 type Err = Infallible;
460
461 fn from_str(s: &str) -> Result<Self, Self::Err> {
462 Ok(s.into())
463 }
464 }
465
466 impl TryFrom<message::UserContent> for Part {
467 type Error = message::MessageError;
468
469 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
470 match content {
471 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
472 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
473 let content = match content.first() {
474 message::ToolResultContent::Text(text) => text.text,
475 message::ToolResultContent::Image(_) => {
476 return Err(message::MessageError::ConversionError(
477 "Tool result content must be text".to_string(),
478 ));
479 }
480 };
481 let result: serde_json::Value = serde_json::from_str(&content)
483 .map_err(|x| MessageError::ConversionError(x.to_string()))?;
484 Ok(Part::FunctionResponse(FunctionResponse {
485 name: id,
486 response: Some(json!({ "result": result })),
487 }))
488 }
489 message::UserContent::Image(message::Image {
490 data, media_type, ..
491 }) => match media_type {
492 Some(media_type) => match media_type {
493 message::ImageMediaType::JPEG
494 | message::ImageMediaType::PNG
495 | message::ImageMediaType::WEBP
496 | message::ImageMediaType::HEIC
497 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
498 mime_type: media_type.to_mime_type().to_owned(),
499 data,
500 })),
501 _ => Err(message::MessageError::ConversionError(format!(
502 "Unsupported image media type {media_type:?}"
503 ))),
504 },
505 None => Err(message::MessageError::ConversionError(
506 "Media type for image is required for Anthropic".to_string(),
507 )),
508 },
509 message::UserContent::Document(message::Document {
510 data, media_type, ..
511 }) => match media_type {
512 Some(media_type) => match media_type {
513 message::DocumentMediaType::PDF
514 | message::DocumentMediaType::TXT
515 | message::DocumentMediaType::RTF
516 | message::DocumentMediaType::HTML
517 | message::DocumentMediaType::CSS
518 | message::DocumentMediaType::MARKDOWN
519 | message::DocumentMediaType::CSV
520 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
521 mime_type: media_type.to_mime_type().to_owned(),
522 data,
523 })),
524 _ => Err(message::MessageError::ConversionError(format!(
525 "Unsupported document media type {media_type:?}"
526 ))),
527 },
528 None => Err(message::MessageError::ConversionError(
529 "Media type for document is required for Anthropic".to_string(),
530 )),
531 },
532 message::UserContent::Audio(message::Audio {
533 data, media_type, ..
534 }) => match media_type {
535 Some(media_type) => Ok(Self::InlineData(Blob {
536 mime_type: media_type.to_mime_type().to_owned(),
537 data,
538 })),
539 None => Err(message::MessageError::ConversionError(
540 "Media type for audio is required for Anthropic".to_string(),
541 )),
542 },
543 }
544 }
545 }
546
547 impl From<message::AssistantContent> for Part {
548 fn from(content: message::AssistantContent) -> Self {
549 match content {
550 message::AssistantContent::Text(message::Text { text }) => text.into(),
551 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
552 }
553 }
554 }
555
556 impl From<message::ToolCall> for Part {
557 fn from(tool_call: message::ToolCall) -> Self {
558 Self::FunctionCall(FunctionCall {
559 name: tool_call.function.name,
560 args: tool_call.function.arguments,
561 })
562 }
563 }
564
565 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
568 #[serde(rename_all = "camelCase")]
569 pub struct Blob {
570 pub mime_type: String,
573 pub data: String,
575 }
576
577 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
580 pub struct FunctionCall {
581 pub name: String,
584 pub args: serde_json::Value,
586 }
587
588 impl From<FunctionCall> for message::ToolCall {
589 fn from(function_call: FunctionCall) -> Self {
590 Self {
591 id: function_call.name.clone(),
592 call_id: None,
593 function: message::ToolFunction {
594 name: function_call.name,
595 arguments: function_call.args,
596 },
597 }
598 }
599 }
600
601 impl From<message::ToolCall> for FunctionCall {
602 fn from(tool_call: message::ToolCall) -> Self {
603 Self {
604 name: tool_call.function.name,
605 args: tool_call.function.arguments,
606 }
607 }
608 }
609
610 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
614 pub struct FunctionResponse {
615 pub name: String,
618 pub response: Option<serde_json::Value>,
620 }
621
622 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
624 #[serde(rename_all = "camelCase")]
625 pub struct FileData {
626 pub mime_type: Option<String>,
628 pub file_uri: String,
630 }
631
632 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
633 pub struct SafetyRating {
634 pub category: HarmCategory,
635 pub probability: HarmProbability,
636 }
637
638 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
639 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
640 pub enum HarmProbability {
641 HarmProbabilityUnspecified,
642 Negligible,
643 Low,
644 Medium,
645 High,
646 }
647
648 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
649 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
650 pub enum HarmCategory {
651 HarmCategoryUnspecified,
652 HarmCategoryDerogatory,
653 HarmCategoryToxicity,
654 HarmCategoryViolence,
655 HarmCategorySexually,
656 HarmCategoryMedical,
657 HarmCategoryDangerous,
658 HarmCategoryHarassment,
659 HarmCategoryHateSpeech,
660 HarmCategorySexuallyExplicit,
661 HarmCategoryDangerousContent,
662 HarmCategoryCivicIntegrity,
663 }
664
665 #[derive(Debug, Deserialize, Clone, Default)]
666 #[serde(rename_all = "camelCase")]
667 pub struct UsageMetadata {
668 pub prompt_token_count: i32,
669 #[serde(skip_serializing_if = "Option::is_none")]
670 pub cached_content_token_count: Option<i32>,
671 pub candidates_token_count: i32,
672 pub total_token_count: i32,
673 }
674
675 impl std::fmt::Display for UsageMetadata {
676 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677 write!(
678 f,
679 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
680 self.prompt_token_count,
681 match self.cached_content_token_count {
682 Some(count) => count.to_string(),
683 None => "n/a".to_string(),
684 },
685 self.candidates_token_count,
686 self.total_token_count
687 )
688 }
689 }
690
691 #[derive(Debug, Deserialize)]
693 #[serde(rename_all = "camelCase")]
694 pub struct PromptFeedback {
695 pub block_reason: Option<BlockReason>,
697 pub safety_ratings: Option<Vec<SafetyRating>>,
699 }
700
701 #[derive(Debug, Deserialize)]
703 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
704 pub enum BlockReason {
705 BlockReasonUnspecified,
707 Safety,
709 Other,
711 Blocklist,
713 ProhibitedContent,
715 }
716
717 #[derive(Debug, Deserialize)]
718 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
719 pub enum FinishReason {
720 FinishReasonUnspecified,
722 Stop,
724 MaxTokens,
726 Safety,
728 Recitation,
730 Language,
732 Other,
734 Blocklist,
736 ProhibitedContent,
738 Spii,
740 MalformedFunctionCall,
742 }
743
744 #[derive(Debug, Deserialize)]
745 #[serde(rename_all = "camelCase")]
746 pub struct CitationMetadata {
747 pub citation_sources: Vec<CitationSource>,
748 }
749
750 #[derive(Debug, Deserialize)]
751 #[serde(rename_all = "camelCase")]
752 pub struct CitationSource {
753 #[serde(skip_serializing_if = "Option::is_none")]
754 pub uri: Option<String>,
755 #[serde(skip_serializing_if = "Option::is_none")]
756 pub start_index: Option<i32>,
757 #[serde(skip_serializing_if = "Option::is_none")]
758 pub end_index: Option<i32>,
759 #[serde(skip_serializing_if = "Option::is_none")]
760 pub license: Option<String>,
761 }
762
763 #[derive(Debug, Deserialize)]
764 #[serde(rename_all = "camelCase")]
765 pub struct LogprobsResult {
766 pub top_candidate: Vec<TopCandidate>,
767 pub chosen_candidate: Vec<LogProbCandidate>,
768 }
769
770 #[derive(Debug, Deserialize)]
771 pub struct TopCandidate {
772 pub candidates: Vec<LogProbCandidate>,
773 }
774
775 #[derive(Debug, Deserialize)]
776 #[serde(rename_all = "camelCase")]
777 pub struct LogProbCandidate {
778 pub token: String,
779 pub token_id: String,
780 pub log_probability: f64,
781 }
782
783 #[derive(Debug, Deserialize, Serialize)]
788 #[serde(rename_all = "camelCase")]
789 pub struct GenerationConfig {
790 #[serde(skip_serializing_if = "Option::is_none")]
793 pub stop_sequences: Option<Vec<String>>,
794 #[serde(skip_serializing_if = "Option::is_none")]
800 pub response_mime_type: Option<String>,
801 #[serde(skip_serializing_if = "Option::is_none")]
805 pub response_schema: Option<Schema>,
806 #[serde(skip_serializing_if = "Option::is_none")]
809 pub candidate_count: Option<i32>,
810 #[serde(skip_serializing_if = "Option::is_none")]
813 pub max_output_tokens: Option<u64>,
814 #[serde(skip_serializing_if = "Option::is_none")]
817 pub temperature: Option<f64>,
818 #[serde(skip_serializing_if = "Option::is_none")]
825 pub top_p: Option<f64>,
826 #[serde(skip_serializing_if = "Option::is_none")]
832 pub top_k: Option<i32>,
833 #[serde(skip_serializing_if = "Option::is_none")]
839 pub presence_penalty: Option<f64>,
840 #[serde(skip_serializing_if = "Option::is_none")]
848 pub frequency_penalty: Option<f64>,
849 #[serde(skip_serializing_if = "Option::is_none")]
851 pub response_logprobs: Option<bool>,
852 #[serde(skip_serializing_if = "Option::is_none")]
855 pub logprobs: Option<i32>,
856 }
857
858 impl Default for GenerationConfig {
859 fn default() -> Self {
860 Self {
861 temperature: Some(1.0),
862 max_output_tokens: Some(4096),
863 stop_sequences: None,
864 response_mime_type: None,
865 response_schema: None,
866 candidate_count: None,
867 top_p: None,
868 top_k: None,
869 presence_penalty: None,
870 frequency_penalty: None,
871 response_logprobs: None,
872 logprobs: None,
873 }
874 }
875 }
876 #[derive(Debug, Deserialize, Serialize, Clone)]
880 pub struct Schema {
881 pub r#type: String,
882 #[serde(skip_serializing_if = "Option::is_none")]
883 pub format: Option<String>,
884 #[serde(skip_serializing_if = "Option::is_none")]
885 pub description: Option<String>,
886 #[serde(skip_serializing_if = "Option::is_none")]
887 pub nullable: Option<bool>,
888 #[serde(skip_serializing_if = "Option::is_none")]
889 pub r#enum: Option<Vec<String>>,
890 #[serde(skip_serializing_if = "Option::is_none")]
891 pub max_items: Option<i32>,
892 #[serde(skip_serializing_if = "Option::is_none")]
893 pub min_items: Option<i32>,
894 #[serde(skip_serializing_if = "Option::is_none")]
895 pub properties: Option<HashMap<String, Schema>>,
896 #[serde(skip_serializing_if = "Option::is_none")]
897 pub required: Option<Vec<String>>,
898 #[serde(skip_serializing_if = "Option::is_none")]
899 pub items: Option<Box<Schema>>,
900 }
901
902 impl TryFrom<Value> for Schema {
903 type Error = CompletionError;
904
905 fn try_from(value: Value) -> Result<Self, Self::Error> {
906 if let Some(obj) = value.as_object() {
907 Ok(Schema {
908 r#type: obj
909 .get("type")
910 .and_then(|v| {
911 if v.is_string() {
912 v.as_str().map(String::from)
913 } else if v.is_array() {
914 v.as_array()
915 .and_then(|arr| arr.first())
916 .and_then(|v| v.as_str().map(String::from))
917 } else {
918 None
919 }
920 })
921 .unwrap_or_default(),
922 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
923 description: obj
924 .get("description")
925 .and_then(|v| v.as_str())
926 .map(String::from),
927 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
928 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
929 arr.iter()
930 .filter_map(|v| v.as_str().map(String::from))
931 .collect()
932 }),
933 max_items: obj
934 .get("maxItems")
935 .and_then(|v| v.as_i64())
936 .map(|v| v as i32),
937 min_items: obj
938 .get("minItems")
939 .and_then(|v| v.as_i64())
940 .map(|v| v as i32),
941 properties: obj
942 .get("properties")
943 .and_then(|v| v.as_object())
944 .map(|map| {
945 map.iter()
946 .filter_map(|(k, v)| {
947 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
948 })
949 .collect()
950 }),
951 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
952 arr.iter()
953 .filter_map(|v| v.as_str().map(String::from))
954 .collect()
955 }),
956 items: obj
957 .get("items")
958 .map(|v| Box::new(v.clone().try_into().unwrap())),
959 })
960 } else {
961 Err(CompletionError::ResponseError(
962 "Expected a JSON object for Schema".into(),
963 ))
964 }
965 }
966 }
967
968 #[derive(Debug, Serialize)]
969 #[serde(rename_all = "camelCase")]
970 pub struct GenerateContentRequest {
971 pub contents: Vec<Content>,
972 pub tools: Option<Tool>,
973 pub tool_config: Option<ToolConfig>,
974 pub generation_config: Option<GenerationConfig>,
976 pub safety_settings: Option<Vec<SafetySetting>>,
990 pub system_instruction: Option<Content>,
993 }
995
996 #[derive(Debug, Serialize)]
997 #[serde(rename_all = "camelCase")]
998 pub struct Tool {
999 pub function_declarations: OneOrMany<FunctionDeclaration>,
1000 pub code_execution: Option<CodeExecution>,
1001 }
1002
1003 #[derive(Debug, Serialize, Clone)]
1004 #[serde(rename_all = "camelCase")]
1005 pub struct FunctionDeclaration {
1006 pub name: String,
1007 pub description: String,
1008 #[serde(skip_serializing_if = "Option::is_none")]
1009 pub parameters: Option<Schema>,
1010 }
1011
1012 #[derive(Debug, Serialize)]
1013 #[serde(rename_all = "camelCase")]
1014 pub struct ToolConfig {
1015 pub schema: Option<Schema>,
1016 }
1017
1018 #[derive(Debug, Serialize)]
1019 #[serde(rename_all = "camelCase")]
1020 pub struct CodeExecution {}
1021
1022 #[derive(Debug, Serialize)]
1023 #[serde(rename_all = "camelCase")]
1024 pub struct SafetySetting {
1025 pub category: HarmCategory,
1026 pub threshold: HarmBlockThreshold,
1027 }
1028
1029 #[derive(Debug, Serialize)]
1030 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1031 pub enum HarmBlockThreshold {
1032 HarmBlockThresholdUnspecified,
1033 BlockLowAndAbove,
1034 BlockMediumAndAbove,
1035 BlockOnlyHigh,
1036 BlockNone,
1037 Off,
1038 }
1039}
1040
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user turn with one part of every supported kind deserializes in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        let Part::Text(text) = &parts[0] else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, world!");

        let Part::InlineData(inline_data) = &parts[1] else {
            panic!("Expected inline data part");
        };
        assert_eq!(inline_data.mime_type, "image/png");
        assert_eq!(inline_data.data, "base64encodeddata");

        let Part::FunctionCall(function_call) = &parts[2] else {
            panic!("Expected function call part");
        };
        assert_eq!(function_call.name, "test_function");
        assert_eq!(function_call.args["arg1"], "value1");

        let Part::FunctionResponse(function_response) = &parts[3] else {
            panic!("Expected function response part");
        };
        assert_eq!(function_response.name, "test_function");
        let Some(response) = &function_response.response else {
            panic!("Expected function response payload");
        };
        assert_eq!(response["result"], "success");

        let Part::FileData(file_data) = &parts[4] else {
            panic!("Expected file data part");
        };
        assert_eq!(file_data.mime_type.as_deref(), Some("application/pdf"));
        assert_eq!(file_data.file_uri, "http://example.com/file.pdf");

        let Part::ExecutableCode(executable_code) = &parts[5] else {
            panic!("Expected executable code part");
        };
        assert_eq!(executable_code.code, "print('Hello, world!')");

        let Part::CodeExecutionResult(code_execution_result) = &parts[6] else {
            panic!("Expected code execution result part");
        };
        assert_eq!(
            code_execution_result.output.as_deref(),
            Some("Hello, world!")
        );
    }

    /// A model turn with a single text part deserializes with `Role::Model`.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let Part::Text(text) = content.parts.first() else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, user!");
    }

    /// A plain rig user message becomes a single-text user turn.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);

        let Part::Text(text) = content.parts.first() else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, world!");
    }

    /// A plain rig assistant message becomes a single-text model turn.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let Part::Text(text) = content.parts.first() else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello, user!");
    }

    /// An assistant tool call becomes a model turn with a function-call part.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        let Part::FunctionCall(function_call) = content.parts.first() else {
            panic!("Expected function call part");
        };
        assert_eq!(function_call.name, "test_function");
        assert_eq!(function_call.args["arg1"], "value1");
    }
}