/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b`.
///
/// NOTE: Google publishes the 8B variant of Gemini 1.5 as `gemini-1.5-flash-8b`
/// ([`GEMINI_1_5_FLASH_8B`]); this identifier is kept unchanged only for backward
/// compatibility. Prefer [`GEMINI_1_5_FLASH_8B`] for new code.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.5-flash-8b` completion model (the published id of the 1.5 8B variant).
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33 completion::{self, CompletionError, CompletionRequest},
34 OneOrMany,
35};
36use gemini_api_types::{
37 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38 GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45#[derive(Clone)]
50pub struct CompletionModel {
51 pub(crate) client: Client,
52 pub model: String,
53}
54
55impl CompletionModel {
56 pub fn new(client: Client, model: &str) -> Self {
57 Self {
58 client,
59 model: model.to_string(),
60 }
61 }
62}
63
64impl completion::CompletionModel for CompletionModel {
65 type Response = GenerateContentResponse;
66 type StreamingResponse = StreamingCompletionResponse;
67
68 #[cfg_attr(feature = "worker", worker::send)]
69 async fn completion(
70 &self,
71 completion_request: CompletionRequest,
72 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73 let request = create_request_body(completion_request)?;
74
75 tracing::debug!(
76 "Sending completion request to Gemini API {}",
77 serde_json::to_string_pretty(&request)?
78 );
79
80 let response = self
81 .client
82 .post(&format!("/v1beta/models/{}:generateContent", self.model))
83 .json(&request)
84 .send()
85 .await?;
86
87 if response.status().is_success() {
88 let response = response.json::<GenerateContentResponse>().await?;
89 match response.usage_metadata {
90 Some(ref usage) => tracing::info!(target: "rig",
91 "Gemini completion token usage: {}",
92 usage
93 ),
94 None => tracing::info!(target: "rig",
95 "Gemini completion token usage: n/a",
96 ),
97 }
98
99 tracing::debug!("Received response");
100
101 Ok(completion::CompletionResponse::try_from(response))
102 } else {
103 Err(CompletionError::ProviderError(response.text().await?))
104 }?
105 }
106
107 #[cfg_attr(feature = "worker", worker::send)]
108 async fn stream(
109 &self,
110 request: CompletionRequest,
111 ) -> Result<
112 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113 CompletionError,
114 > {
115 CompletionModel::stream(self, request).await
116 }
117}
118
119pub(crate) fn create_request_body(
120 completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122 let mut full_history = Vec::new();
123 full_history.extend(completion_request.chat_history);
124
125 let additional_params = completion_request
126 .additional_params
127 .unwrap_or_else(|| Value::Object(Map::new()));
128
129 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131 if let Some(temp) = completion_request.temperature {
132 generation_config.temperature = Some(temp);
133 }
134
135 if let Some(max_tokens) = completion_request.max_tokens {
136 generation_config.max_output_tokens = Some(max_tokens);
137 }
138
139 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140 parts: OneOrMany::one(preamble.into()),
141 role: Some(Role::Model),
142 });
143
144 let request = GenerateContentRequest {
145 contents: full_history
146 .into_iter()
147 .map(|msg| {
148 msg.try_into()
149 .map_err(|e| CompletionError::RequestError(Box::new(e)))
150 })
151 .collect::<Result<Vec<_>, _>>()?,
152 generation_config: Some(generation_config),
153 safety_settings: None,
154 tools: Some(
155 completion_request
156 .tools
157 .into_iter()
158 .map(Tool::try_from)
159 .collect::<Result<Vec<_>, _>>()?,
160 ),
161 tool_config: None,
162 system_instruction,
163 };
164
165 Ok(request)
166}
167
168impl TryFrom<completion::ToolDefinition> for Tool {
169 type Error = CompletionError;
170
171 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
172 let parameters: Option<Schema> =
173 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
174 None
175 } else {
176 Some(tool.parameters.try_into()?)
177 };
178 Ok(Self {
179 function_declarations: FunctionDeclaration {
180 name: tool.name,
181 description: tool.description,
182 parameters,
183 },
184 code_execution: None,
185 })
186 }
187}
188
189impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
190 type Error = CompletionError;
191
192 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
193 let candidate = response.candidates.first().ok_or_else(|| {
194 CompletionError::ResponseError("No response candidates in response".into())
195 })?;
196
197 let content = candidate
198 .content
199 .parts
200 .iter()
201 .map(|part| {
202 Ok(match part {
203 Part::Text(text) => completion::AssistantContent::text(text),
204 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
205 &function_call.name,
206 &function_call.name,
207 function_call.args.clone(),
208 ),
209 _ => {
210 return Err(CompletionError::ResponseError(
211 "Response did not contain a message or tool call".into(),
212 ))
213 }
214 })
215 })
216 .collect::<Result<Vec<_>, _>>()?;
217
218 let choice = OneOrMany::many(content).map_err(|_| {
219 CompletionError::ResponseError(
220 "Response contained no message or tool call (empty)".to_owned(),
221 )
222 })?;
223
224 Ok(completion::CompletionResponse {
225 choice,
226 raw_response: response,
227 })
228 }
229}
230
231pub mod gemini_api_types {
232 use std::{collections::HashMap, convert::Infallible, str::FromStr};
233
234 use serde::{Deserialize, Serialize};
238 use serde_json::Value;
239
240 use crate::{
241 completion::CompletionError,
242 message::{self, MimeType as _},
243 one_or_many::string_or_one_or_many,
244 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
245 OneOrMany,
246 };
247
    /// Top-level body of a Gemini `generateContent` response.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Candidate completions returned by the model.
        pub candidates: Vec<ContentCandidate>,
        /// Feedback about the prompt (e.g. a block reason), when the API includes it.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Token accounting for the call, when the API includes it.
        pub usage_metadata: Option<UsageMetadata>,
        /// Model version string, when the API includes it.
        pub model_version: Option<String>,
    }

    /// One candidate completion inside a [`GenerateContentResponse`].
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content. This field is not optional, so a candidate that arrives
        /// without `content` fails deserialization of the whole response.
        pub content: Content,
        /// Why generation stopped, if reported.
        pub finish_reason: Option<FinishReason>,
        /// Per-category safety ratings, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation information, if reported.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate, if reported.
        pub token_count: Option<i32>,
        /// Average log probability, if reported.
        pub avg_logprobs: Option<f64>,
        /// Detailed log-probability data, if reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Index of this candidate, if reported.
        pub index: Option<i32>,
    }
    /// A unit of conversation content: an optional role plus one or more parts.
    ///
    /// Used in requests (`contents`, `system_instruction`) and in responses.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// The content parts, deserialized through the `string_or_one_or_many` helper
        /// (presumably also accepting a bare string — see that helper).
        #[serde(deserialize_with = "string_or_one_or_many")]
        pub parts: OneOrMany<Part>,
        /// Author of this content; `None` is treated as a user turn when converting
        /// into a rig message.
        pub role: Option<Role>,
    }
301
302 impl TryFrom<message::Message> for Content {
303 type Error = message::MessageError;
304
305 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
306 Ok(match msg {
307 message::Message::User { content } => Content {
308 parts: content.try_map(|c| c.try_into())?,
309 role: Some(Role::User),
310 },
311 message::Message::Assistant { content } => Content {
312 role: Some(Role::Model),
313 parts: content.map(|content| content.into()),
314 },
315 })
316 }
317 }
318
    /// Converts Gemini content back into a rig message.
    ///
    /// - `Some(Role::User)` or a missing role produces a user message. Only text and
    ///   inline-data parts are accepted. Inline data is classified by its MIME type into
    ///   image, document or audio content.
    /// - `Some(Role::Model)` produces an assistant message. Only text and function-call
    ///   parts are accepted.
    ///
    /// # Errors
    /// [`message::MessageError::ConversionError`] for any other part kind, or for inline
    /// data whose MIME type is not recognized as image, document or audio.
    impl TryFrom<Content> for message::Message {
        type Error = message::MessageError;

        fn try_from(content: Content) -> Result<Self, Self::Error> {
            match content.role {
                // A missing role is treated the same as an explicit user role.
                Some(Role::User) | None => Ok(message::Message::User {
                    content: content.parts.try_map(|part| {
                        Ok(match part {
                            Part::Text(text) => message::UserContent::text(text),
                            Part::InlineData(inline_data) => {
                                let mime_type =
                                    message::MediaType::from_mime_type(&inline_data.mime_type);

                                // Route by media category. The content format is the rig
                                // default (presumably base64 — TODO confirm).
                                match mime_type {
                                    Some(message::MediaType::Image(media_type)) => {
                                        message::UserContent::image(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                            Some(message::ImageDetail::default()),
                                        )
                                    }
                                    Some(message::MediaType::Document(media_type)) => {
                                        message::UserContent::document(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                        )
                                    }
                                    Some(message::MediaType::Audio(media_type)) => {
                                        message::UserContent::audio(
                                            inline_data.data,
                                            Some(message::ContentFormat::default()),
                                            Some(media_type),
                                        )
                                    }
                                    // Unknown MIME type or an unsupported category.
                                    _ => {
                                        return Err(message::MessageError::ConversionError(
                                            format!("Unsupported media type {mime_type:?}"),
                                        ))
                                    }
                                }
                            }
                            // Function calls/responses, file data and code-execution parts
                            // have no user-content equivalent here.
                            _ => {
                                return Err(message::MessageError::ConversionError(format!(
                                    "Unsupported gemini content part type: {part:?}"
                                )))
                            }
                        })
                    })?,
                }),
                Some(Role::Model) => Ok(message::Message::Assistant {
                    content: content.parts.try_map(|part| {
                        Ok(match part {
                            Part::Text(text) => message::AssistantContent::text(text),
                            Part::FunctionCall(function_call) => {
                                message::AssistantContent::ToolCall(function_call.into())
                            }
                            _ => {
                                return Err(message::MessageError::ConversionError(format!(
                                    "Unsupported part type: {part:?}"
                                )))
                            }
                        })
                    })?,
                }),
            }
        }
    }
388
    /// Author of a [`Content`] turn, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        User,
        Model,
    }

    /// A single piece of content.
    ///
    /// Externally tagged with camelCase keys, e.g. `{"text": "..."}`,
    /// `{"inlineData": {...}}`, `{"functionCall": {...}}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum Part {
        /// Plain text.
        Text(String),
        /// Inline binary payload with a MIME type.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call.
        FunctionResponse(FunctionResponse),
        /// Content referenced by URI.
        FileData(FileData),
        /// Code produced for execution.
        ExecutableCode(ExecutableCode),
        /// Result of executing code.
        CodeExecutionResult(CodeExecutionResult),
    }
410
411 impl From<String> for Part {
412 fn from(text: String) -> Self {
413 Self::Text(text)
414 }
415 }
416
417 impl From<&str> for Part {
418 fn from(text: &str) -> Self {
419 Self::Text(text.to_string())
420 }
421 }
422
423 impl FromStr for Part {
424 type Err = Infallible;
425
426 fn from_str(s: &str) -> Result<Self, Self::Err> {
427 Ok(s.into())
428 }
429 }
430
431 impl TryFrom<message::UserContent> for Part {
432 type Error = message::MessageError;
433
434 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
435 match content {
436 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
437 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
438 let content = match content.first() {
439 message::ToolResultContent::Text(text) => text.text,
440 message::ToolResultContent::Image(_) => {
441 return Err(message::MessageError::ConversionError(
442 "Tool result content must be text".to_string(),
443 ))
444 }
445 };
446 Ok(Part::FunctionResponse(FunctionResponse {
447 name: id,
448 response: Some(serde_json::from_str(&content).map_err(|e| {
449 message::MessageError::ConversionError(format!(
450 "Failed to parse tool response: {e}"
451 ))
452 })?),
453 }))
454 }
455 message::UserContent::Image(message::Image {
456 data, media_type, ..
457 }) => match media_type {
458 Some(media_type) => match media_type {
459 message::ImageMediaType::JPEG
460 | message::ImageMediaType::PNG
461 | message::ImageMediaType::WEBP
462 | message::ImageMediaType::HEIC
463 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
464 mime_type: media_type.to_mime_type().to_owned(),
465 data,
466 })),
467 _ => Err(message::MessageError::ConversionError(format!(
468 "Unsupported image media type {media_type:?}"
469 ))),
470 },
471 None => Err(message::MessageError::ConversionError(
472 "Media type for image is required for Anthropic".to_string(),
473 )),
474 },
475 message::UserContent::Document(message::Document {
476 data, media_type, ..
477 }) => match media_type {
478 Some(media_type) => match media_type {
479 message::DocumentMediaType::PDF
480 | message::DocumentMediaType::TXT
481 | message::DocumentMediaType::RTF
482 | message::DocumentMediaType::HTML
483 | message::DocumentMediaType::CSS
484 | message::DocumentMediaType::MARKDOWN
485 | message::DocumentMediaType::CSV
486 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
487 mime_type: media_type.to_mime_type().to_owned(),
488 data,
489 })),
490 _ => Err(message::MessageError::ConversionError(format!(
491 "Unsupported document media type {media_type:?}"
492 ))),
493 },
494 None => Err(message::MessageError::ConversionError(
495 "Media type for document is required for Anthropic".to_string(),
496 )),
497 },
498 message::UserContent::Audio(message::Audio {
499 data, media_type, ..
500 }) => match media_type {
501 Some(media_type) => Ok(Self::InlineData(Blob {
502 mime_type: media_type.to_mime_type().to_owned(),
503 data,
504 })),
505 None => Err(message::MessageError::ConversionError(
506 "Media type for audio is required for Anthropic".to_string(),
507 )),
508 },
509 }
510 }
511 }
512
513 impl From<message::AssistantContent> for Part {
514 fn from(content: message::AssistantContent) -> Self {
515 match content {
516 message::AssistantContent::Text(message::Text { text }) => text.into(),
517 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
518 }
519 }
520 }
521
522 impl From<message::ToolCall> for Part {
523 fn from(tool_call: message::ToolCall) -> Self {
524 Self::FunctionCall(FunctionCall {
525 name: tool_call.function.name,
526 args: tool_call.function.arguments,
527 })
528 }
529 }
530
    /// Inline binary payload tagged with its MIME type.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of `data`, e.g. `image/png`.
        pub mime_type: String,
        /// Payload as a string. The conversions in this module pass rig media `data`
        /// through unchanged (presumably base64-encoded — TODO confirm).
        pub data: String,
    }

    /// A function invocation requested by the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Call arguments as an arbitrary JSON value.
        pub args: serde_json::Value,
    }
553
554 impl From<FunctionCall> for message::ToolCall {
555 fn from(function_call: FunctionCall) -> Self {
556 Self {
557 id: function_call.name.clone(),
558 function: message::ToolFunction {
559 name: function_call.name,
560 arguments: function_call.args,
561 },
562 }
563 }
564 }
565
566 impl From<message::ToolCall> for FunctionCall {
567 fn from(tool_call: message::ToolCall) -> Self {
568 Self {
569 name: tool_call.function.name,
570 args: tool_call.function.arguments,
571 }
572 }
573 }
574
    /// The result of a function call, sent back to the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name of the function that produced this result.
        pub name: String,
        /// Result payload as a JSON object (string keys to JSON values).
        pub response: Option<HashMap<String, serde_json::Value>>,
    }

    /// Content referenced by URI rather than sent inline.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Optional MIME type of the referenced file.
        pub mime_type: Option<String>,
        /// URI of the file.
        pub file_uri: String,
    }

    /// Safety rating of content for one harm category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
    }

    /// Harm probability level, serialized in SCREAMING_SNAKE_CASE (e.g. `NEGLIGIBLE`).
    ///
    /// There is no catch-all variant, so an unknown value fails deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    /// Harm category, serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    ///
    /// There is no catch-all variant, so an unknown value fails deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
629
630 #[derive(Debug, Deserialize, Clone, Default)]
631 #[serde(rename_all = "camelCase")]
632 pub struct UsageMetadata {
633 pub prompt_token_count: i32,
634 #[serde(skip_serializing_if = "Option::is_none")]
635 pub cached_content_token_count: Option<i32>,
636 pub candidates_token_count: i32,
637 pub total_token_count: i32,
638 }
639
640 impl std::fmt::Display for UsageMetadata {
641 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
642 write!(
643 f,
644 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
645 self.prompt_token_count,
646 match self.cached_content_token_count {
647 Some(count) => count.to_string(),
648 None => "n/a".to_string(),
649 },
650 self.candidates_token_count,
651 self.total_token_count
652 )
653 }
654 }
655
656 #[derive(Debug, Deserialize)]
658 #[serde(rename_all = "camelCase")]
659 pub struct PromptFeedback {
660 pub block_reason: Option<BlockReason>,
662 pub safety_ratings: Option<Vec<SafetyRating>>,
664 }
665
666 #[derive(Debug, Deserialize)]
668 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
669 pub enum BlockReason {
670 BlockReasonUnspecified,
672 Safety,
674 Other,
676 Blocklist,
678 ProhibitedContent,
680 }
681
682 #[derive(Debug, Deserialize)]
683 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
684 pub enum FinishReason {
685 FinishReasonUnspecified,
687 Stop,
689 MaxTokens,
691 Safety,
693 Recitation,
695 Language,
697 Other,
699 Blocklist,
701 ProhibitedContent,
703 Spii,
705 MalformedFunctionCall,
707 }
708
    /// Citation information attached to a candidate.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Cited sources; required whenever `citationMetadata` is present.
        pub citation_sources: Vec<CitationSource>,
    }

    /// A single cited source; every field is optional.
    ///
    /// The `skip_serializing_if` attributes have no effect here, because this type only
    /// derives `Deserialize`.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
727
728 #[derive(Debug, Deserialize)]
729 #[serde(rename_all = "camelCase")]
730 pub struct LogprobsResult {
731 pub top_candidate: Vec<TopCandidate>,
732 pub chosen_candidate: Vec<LogProbCandidate>,
733 }
734
735 #[derive(Debug, Deserialize)]
736 pub struct TopCandidate {
737 pub candidates: Vec<LogProbCandidate>,
738 }
739
740 #[derive(Debug, Deserialize)]
741 #[serde(rename_all = "camelCase")]
742 pub struct LogProbCandidate {
743 pub token: String,
744 pub token_id: String,
745 pub log_probability: f64,
746 }
747
    /// Generation options sent as `generationConfig`; keys use camelCase on the wire.
    ///
    /// `create_request_body` builds this by deserializing the request's
    /// `additional_params`, then overrides `temperature` / `max_output_tokens` with the
    /// request-level values when those are set. Unset options are omitted from the
    /// serialized request.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// `stopSequences`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// `responseMimeType`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// `responseSchema`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// `candidateCount`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// `maxOutputTokens`; overridden by the request's `max_tokens` when that is set.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// `temperature`; overridden by the request's `temperature` when that is set.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// `topP`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// `topK`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// `presencePenalty`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// `frequencyPenalty`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// `responseLogprobs`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// `logprobs`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
    }
822
823 impl Default for GenerationConfig {
824 fn default() -> Self {
825 Self {
826 temperature: Some(1.0),
827 max_output_tokens: Some(4096),
828 stop_sequences: None,
829 response_mime_type: None,
830 response_schema: None,
831 candidate_count: None,
832 top_p: None,
833 top_k: None,
834 presence_penalty: None,
835 frequency_penalty: None,
836 response_logprobs: None,
837 logprobs: None,
838 }
839 }
840 }
841 #[derive(Debug, Deserialize, Serialize)]
845 pub struct Schema {
846 pub r#type: String,
847 #[serde(skip_serializing_if = "Option::is_none")]
848 pub format: Option<String>,
849 #[serde(skip_serializing_if = "Option::is_none")]
850 pub description: Option<String>,
851 #[serde(skip_serializing_if = "Option::is_none")]
852 pub nullable: Option<bool>,
853 #[serde(skip_serializing_if = "Option::is_none")]
854 pub r#enum: Option<Vec<String>>,
855 #[serde(skip_serializing_if = "Option::is_none")]
856 pub max_items: Option<i32>,
857 #[serde(skip_serializing_if = "Option::is_none")]
858 pub min_items: Option<i32>,
859 #[serde(skip_serializing_if = "Option::is_none")]
860 pub properties: Option<HashMap<String, Schema>>,
861 #[serde(skip_serializing_if = "Option::is_none")]
862 pub required: Option<Vec<String>>,
863 #[serde(skip_serializing_if = "Option::is_none")]
864 pub items: Option<Box<Schema>>,
865 }
866
867 impl TryFrom<Value> for Schema {
868 type Error = CompletionError;
869
870 fn try_from(value: Value) -> Result<Self, Self::Error> {
871 if let Some(obj) = value.as_object() {
872 Ok(Schema {
873 r#type: obj
874 .get("type")
875 .and_then(|v| {
876 if v.is_string() {
877 v.as_str().map(String::from)
878 } else if v.is_array() {
879 v.as_array()
880 .and_then(|arr| arr.first())
881 .and_then(|v| v.as_str().map(String::from))
882 } else {
883 None
884 }
885 })
886 .unwrap_or_default(),
887 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
888 description: obj
889 .get("description")
890 .and_then(|v| v.as_str())
891 .map(String::from),
892 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
893 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
894 arr.iter()
895 .filter_map(|v| v.as_str().map(String::from))
896 .collect()
897 }),
898 max_items: obj
899 .get("maxItems")
900 .and_then(|v| v.as_i64())
901 .map(|v| v as i32),
902 min_items: obj
903 .get("minItems")
904 .and_then(|v| v.as_i64())
905 .map(|v| v as i32),
906 properties: obj
907 .get("properties")
908 .and_then(|v| v.as_object())
909 .map(|map| {
910 map.iter()
911 .filter_map(|(k, v)| {
912 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
913 })
914 .collect()
915 }),
916 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
917 arr.iter()
918 .filter_map(|v| v.as_str().map(String::from))
919 .collect()
920 }),
921 items: obj
922 .get("items")
923 .map(|v| Box::new(v.clone().try_into().unwrap())),
924 })
925 } else {
926 Err(CompletionError::ResponseError(
927 "Expected a JSON object for Schema".into(),
928 ))
929 }
930 }
931 }
932
    /// Body of a Gemini `generateContent` request; keys use camelCase on the wire.
    ///
    /// None of the `Option` fields skip serialization, so a `None` value is sent as an
    /// explicit JSON `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation turns.
        pub contents: Vec<Content>,
        /// Tools available to the model.
        pub tools: Option<Vec<Tool>>,
        /// Tool configuration.
        pub tool_config: Option<ToolConfig>,
        /// Generation options.
        pub generation_config: Option<GenerationConfig>,
        /// Per-category safety thresholds.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System prompt.
        pub system_instruction: Option<Content>,
    }

    /// A tool entry holding one function declaration.
    ///
    /// NOTE(review): the public API documents `functionDeclarations` as an array; this
    /// serializes a single object per tool. Confirm the API accepts this form.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        pub function_declarations: FunctionDeclaration,
        pub code_execution: Option<CodeExecution>,
    }

    /// A callable function exposed to the model. `parameters` is omitted when `None`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        pub name: String,
        pub description: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }

    /// Tool configuration. NOTE(review): the `schema` field does not obviously match the
    /// API's tool-config shape; it is always sent as `None` by `create_request_body`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        pub schema: Option<Schema>,
    }

    /// Empty marker that enables code execution; serializes as `{}`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CodeExecution {}

    /// Blocking threshold for one harm category.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        pub category: HarmCategory,
        pub threshold: HarmBlockThreshold,
    }

    /// Blocking threshold, serialized in SCREAMING_SNAKE_CASE (e.g. `BLOCK_ONLY_HIGH`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
1004}
1005
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user `Content` containing one part of every kind deserializes, and each part
    /// keeps its payload.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error reports the JSON path of any failure.
        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));

        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        match &parts[0] {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }

        match &parts[1] {
            Part::InlineData(inline_data) => {
                assert_eq!(inline_data.mime_type, "image/png");
                assert_eq!(inline_data.data, "base64encodeddata");
            }
            _ => panic!("Expected inline data part"),
        }

        match &parts[2] {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                assert_eq!(
                    function_call.args.as_object().unwrap().get("arg1").unwrap(),
                    "value1"
                );
            }
            _ => panic!("Expected function call part"),
        }

        match &parts[3] {
            Part::FunctionResponse(function_response) => {
                assert_eq!(function_response.name, "test_function");
                let response = function_response.response.as_ref().unwrap();
                assert_eq!(response.get("result").unwrap(), "success");
            }
            _ => panic!("Expected function response part"),
        }

        match &parts[4] {
            Part::FileData(file_data) => {
                assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
                assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
            }
            _ => panic!("Expected file data part"),
        }

        match &parts[5] {
            Part::ExecutableCode(executable_code) => {
                assert_eq!(executable_code.code, "print('Hello, world!')");
            }
            _ => panic!("Expected executable code part"),
        }

        match &parts[6] {
            Part::CodeExecutionResult(code_execution_result) => {
                assert_eq!(
                    code_execution_result.clone().output.unwrap(),
                    "Hello, world!"
                );
            }
            _ => panic!("Expected code execution result part"),
        }
    }

    /// A model turn with one text part deserializes with the `Model` role.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig user message converts to user-role content with a single text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        match content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A rig assistant message converts to model-role content with a single text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// An assistant tool call converts to a function-call part that keeps the function
    /// name and arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match content.parts.first() {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                assert_eq!(
                    function_call.args.as_object().unwrap().get("arg1").unwrap(),
                    "value1"
                );
            }
            _ => panic!("Expected function call part"),
        }
    }
}