/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-flash-8b` completion model (the 8B-parameter Gemini 1.5 variant).
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model id `gemini-1.5-pro-8b`.
///
/// NOTE: Google's published 8B Gemini 1.5 model is `gemini-1.5-flash-8b`
/// (see [`GEMINI_1_5_FLASH_8B`]); this id is presumably rejected by the API as an
/// unknown model — verify before relying on it. Kept so existing code still compiles.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::providers::gemini::streaming::StreamingCompletionResponse;
32use crate::{
33 OneOrMany,
34 completion::{self, CompletionError, CompletionRequest},
35};
36use gemini_api_types::{
37 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
38 GenerationConfig, Part, Role, Tool,
39};
40use serde_json::{Map, Value};
41use std::convert::TryFrom;
42
43use super::Client;
44
45#[derive(Clone)]
50pub struct CompletionModel {
51 pub(crate) client: Client,
52 pub model: String,
53}
54
55impl CompletionModel {
56 pub fn new(client: Client, model: &str) -> Self {
57 Self {
58 client,
59 model: model.to_string(),
60 }
61 }
62}
63
64impl completion::CompletionModel for CompletionModel {
65 type Response = GenerateContentResponse;
66 type StreamingResponse = StreamingCompletionResponse;
67
68 #[cfg_attr(feature = "worker", worker::send)]
69 async fn completion(
70 &self,
71 completion_request: CompletionRequest,
72 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
73 let request = create_request_body(completion_request)?;
74
75 tracing::debug!(
76 "Sending completion request to Gemini API {}",
77 serde_json::to_string_pretty(&request)?
78 );
79
80 let response = self
81 .client
82 .post(&format!("/v1beta/models/{}:generateContent", self.model))
83 .json(&request)
84 .send()
85 .await?;
86
87 if response.status().is_success() {
88 let response = response.json::<GenerateContentResponse>().await?;
89 match response.usage_metadata {
90 Some(ref usage) => tracing::info!(target: "rig",
91 "Gemini completion token usage: {}",
92 usage
93 ),
94 None => tracing::info!(target: "rig",
95 "Gemini completion token usage: n/a",
96 ),
97 }
98
99 tracing::debug!("Received response");
100
101 Ok(completion::CompletionResponse::try_from(response))
102 } else {
103 Err(CompletionError::ProviderError(response.text().await?))
104 }?
105 }
106
107 #[cfg_attr(feature = "worker", worker::send)]
108 async fn stream(
109 &self,
110 request: CompletionRequest,
111 ) -> Result<
112 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
113 CompletionError,
114 > {
115 CompletionModel::stream(self, request).await
116 }
117}
118
119pub(crate) fn create_request_body(
120 completion_request: CompletionRequest,
121) -> Result<GenerateContentRequest, CompletionError> {
122 let mut full_history = Vec::new();
123 full_history.extend(completion_request.chat_history);
124
125 let additional_params = completion_request
126 .additional_params
127 .unwrap_or_else(|| Value::Object(Map::new()));
128
129 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
130
131 if let Some(temp) = completion_request.temperature {
132 generation_config.temperature = Some(temp);
133 }
134
135 if let Some(max_tokens) = completion_request.max_tokens {
136 generation_config.max_output_tokens = Some(max_tokens);
137 }
138
139 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
140 parts: OneOrMany::one(preamble.into()),
141 role: Some(Role::Model),
142 });
143
144 let tools = if completion_request.tools.is_empty() {
145 None
146 } else {
147 Some(Tool::try_from(completion_request.tools)?)
148 };
149
150 let request = GenerateContentRequest {
151 contents: full_history
152 .into_iter()
153 .map(|msg| {
154 msg.try_into()
155 .map_err(|e| CompletionError::RequestError(Box::new(e)))
156 })
157 .collect::<Result<Vec<_>, _>>()?,
158 generation_config: Some(generation_config),
159 safety_settings: None,
160 tools,
161 tool_config: None,
162 system_instruction,
163 };
164
165 Ok(request)
166}
167
168impl TryFrom<completion::ToolDefinition> for Tool {
169 type Error = CompletionError;
170
171 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
172 let parameters: Option<Schema> =
173 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
174 None
175 } else {
176 Some(tool.parameters.try_into()?)
177 };
178
179 Ok(Self {
180 function_declarations: vec![FunctionDeclaration {
181 name: tool.name,
182 description: tool.description,
183 parameters,
184 }],
185 code_execution: None,
186 })
187 }
188}
189
190impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
191 type Error = CompletionError;
192
193 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
194 let mut function_declarations = Vec::new();
195
196 for tool in tools {
197 let parameters =
198 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
199 None
200 } else {
201 match tool.parameters.try_into() {
202 Ok(schema) => Some(schema),
203 Err(e) => {
204 let emsg = format!(
205 "Tool '{}' could not be converted to a schema: {:?}",
206 tool.name, e,
207 );
208 return Err(CompletionError::ProviderError(emsg));
209 }
210 }
211 };
212
213 function_declarations.push(FunctionDeclaration {
214 name: tool.name,
215 description: tool.description,
216 parameters,
217 });
218 }
219
220 Ok(Self {
221 function_declarations,
222 code_execution: None,
223 })
224 }
225}
226
227impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
228 type Error = CompletionError;
229
230 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
231 let candidate = response.candidates.first().ok_or_else(|| {
232 CompletionError::ResponseError("No response candidates in response".into())
233 })?;
234
235 let content = candidate
236 .content
237 .parts
238 .iter()
239 .map(|part| {
240 Ok(match part {
241 Part::Text(text) => completion::AssistantContent::text(text),
242 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
243 &function_call.name,
244 &function_call.name,
245 function_call.args.clone(),
246 ),
247 _ => {
248 return Err(CompletionError::ResponseError(
249 "Response did not contain a message or tool call".into(),
250 ));
251 }
252 })
253 })
254 .collect::<Result<Vec<_>, _>>()?;
255
256 let choice = OneOrMany::many(content).map_err(|_| {
257 CompletionError::ResponseError(
258 "Response contained no message or tool call (empty)".to_owned(),
259 )
260 })?;
261
262 let usage = response
263 .usage_metadata
264 .as_ref()
265 .map(|usage| completion::Usage {
266 input_tokens: usage.prompt_token_count as u64,
267 output_tokens: usage.candidates_token_count as u64,
268 total_tokens: usage.total_token_count as u64,
269 })
270 .unwrap_or_default();
271
272 Ok(completion::CompletionResponse {
273 choice,
274 usage,
275 raw_response: response,
276 })
277 }
278}
279
280pub mod gemini_api_types {
281 use std::{collections::HashMap, convert::Infallible, str::FromStr};
282
283 use serde::{Deserialize, Serialize};
287 use serde_json::{Value, json};
288
289 use crate::{
290 OneOrMany,
291 completion::CompletionError,
292 message::{self, MessageError, MimeType as _},
293 one_or_many::string_or_one_or_many,
294 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
295 };
296
297 #[derive(Debug, Deserialize, Serialize)]
305 #[serde(rename_all = "camelCase")]
306 pub struct GenerateContentResponse {
307 pub candidates: Vec<ContentCandidate>,
309 pub prompt_feedback: Option<PromptFeedback>,
311 pub usage_metadata: Option<UsageMetadata>,
313 pub model_version: Option<String>,
314 }
315
316 #[derive(Debug, Deserialize, Serialize)]
318 #[serde(rename_all = "camelCase")]
319 pub struct ContentCandidate {
320 pub content: Content,
322 pub finish_reason: Option<FinishReason>,
325 pub safety_ratings: Option<Vec<SafetyRating>>,
328 pub citation_metadata: Option<CitationMetadata>,
332 pub token_count: Option<i32>,
334 pub avg_logprobs: Option<f64>,
336 pub logprobs_result: Option<LogprobsResult>,
338 pub index: Option<i32>,
340 }
341 #[derive(Debug, Deserialize, Serialize)]
342 pub struct Content {
343 #[serde(deserialize_with = "string_or_one_or_many")]
345 pub parts: OneOrMany<Part>,
346 pub role: Option<Role>,
349 }
350
351 impl TryFrom<message::Message> for Content {
352 type Error = message::MessageError;
353
354 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
355 Ok(match msg {
356 message::Message::User { content } => Content {
357 parts: content.try_map(|c| c.try_into())?,
358 role: Some(Role::User),
359 },
360 message::Message::Assistant { content, .. } => Content {
361 role: Some(Role::Model),
362 parts: content.map(|content| content.into()),
363 },
364 })
365 }
366 }
367
368 impl TryFrom<Content> for message::Message {
369 type Error = message::MessageError;
370
371 fn try_from(content: Content) -> Result<Self, Self::Error> {
372 match content.role {
373 Some(Role::User) | None => Ok(message::Message::User {
374 content: content.parts.try_map(|part| {
375 Ok(match part {
376 Part::Text(text) => message::UserContent::text(text),
377 Part::InlineData(inline_data) => {
378 let mime_type =
379 message::MediaType::from_mime_type(&inline_data.mime_type);
380
381 match mime_type {
382 Some(message::MediaType::Image(media_type)) => {
383 message::UserContent::image(
384 inline_data.data,
385 Some(message::ContentFormat::default()),
386 Some(media_type),
387 Some(message::ImageDetail::default()),
388 )
389 }
390 Some(message::MediaType::Document(media_type)) => {
391 message::UserContent::document(
392 inline_data.data,
393 Some(message::ContentFormat::default()),
394 Some(media_type),
395 )
396 }
397 Some(message::MediaType::Audio(media_type)) => {
398 message::UserContent::audio(
399 inline_data.data,
400 Some(message::ContentFormat::default()),
401 Some(media_type),
402 )
403 }
404 _ => {
405 return Err(message::MessageError::ConversionError(
406 format!("Unsupported media type {mime_type:?}"),
407 ));
408 }
409 }
410 }
411 _ => {
412 return Err(message::MessageError::ConversionError(format!(
413 "Unsupported gemini content part type: {part:?}"
414 )));
415 }
416 })
417 })?,
418 }),
419 Some(Role::Model) => Ok(message::Message::Assistant {
420 id: None,
421 content: content.parts.try_map(|part| {
422 Ok(match part {
423 Part::Text(text) => message::AssistantContent::text(text),
424 Part::FunctionCall(function_call) => {
425 message::AssistantContent::ToolCall(function_call.into())
426 }
427 _ => {
428 return Err(message::MessageError::ConversionError(format!(
429 "Unsupported part type: {part:?}"
430 )));
431 }
432 })
433 })?,
434 }),
435 }
436 }
437 }
438
439 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
440 #[serde(rename_all = "lowercase")]
441 pub enum Role {
442 User,
443 Model,
444 }
445
446 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
450 #[serde(rename_all = "camelCase")]
451 pub enum Part {
452 Text(String),
453 InlineData(Blob),
454 FunctionCall(FunctionCall),
455 FunctionResponse(FunctionResponse),
456 FileData(FileData),
457 ExecutableCode(ExecutableCode),
458 CodeExecutionResult(CodeExecutionResult),
459 Thought { thoughts: Vec<String> },
460 }
461
462 impl From<String> for Part {
463 fn from(text: String) -> Self {
464 Self::Text(text)
465 }
466 }
467
468 impl From<&str> for Part {
469 fn from(text: &str) -> Self {
470 Self::Text(text.to_string())
471 }
472 }
473
474 impl FromStr for Part {
475 type Err = Infallible;
476
477 fn from_str(s: &str) -> Result<Self, Self::Err> {
478 Ok(s.into())
479 }
480 }
481
482 impl TryFrom<message::UserContent> for Part {
483 type Error = message::MessageError;
484
485 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
486 match content {
487 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
488 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
489 let content = match content.first() {
490 message::ToolResultContent::Text(text) => text.text,
491 message::ToolResultContent::Image(_) => {
492 return Err(message::MessageError::ConversionError(
493 "Tool result content must be text".to_string(),
494 ));
495 }
496 };
497 let result: serde_json::Value = serde_json::from_str(&content)
499 .map_err(|x| MessageError::ConversionError(x.to_string()))?;
500 Ok(Part::FunctionResponse(FunctionResponse {
501 name: id,
502 response: Some(json!({ "result": result })),
503 }))
504 }
505 message::UserContent::Image(message::Image {
506 data, media_type, ..
507 }) => match media_type {
508 Some(media_type) => match media_type {
509 message::ImageMediaType::JPEG
510 | message::ImageMediaType::PNG
511 | message::ImageMediaType::WEBP
512 | message::ImageMediaType::HEIC
513 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
514 mime_type: media_type.to_mime_type().to_owned(),
515 data,
516 })),
517 _ => Err(message::MessageError::ConversionError(format!(
518 "Unsupported image media type {media_type:?}"
519 ))),
520 },
521 None => Err(message::MessageError::ConversionError(
522 "Media type for image is required for Anthropic".to_string(),
523 )),
524 },
525 message::UserContent::Document(message::Document {
526 data, media_type, ..
527 }) => match media_type {
528 Some(media_type) => match media_type {
529 message::DocumentMediaType::PDF
530 | message::DocumentMediaType::TXT
531 | message::DocumentMediaType::RTF
532 | message::DocumentMediaType::HTML
533 | message::DocumentMediaType::CSS
534 | message::DocumentMediaType::MARKDOWN
535 | message::DocumentMediaType::CSV
536 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
537 mime_type: media_type.to_mime_type().to_owned(),
538 data,
539 })),
540 _ => Err(message::MessageError::ConversionError(format!(
541 "Unsupported document media type {media_type:?}"
542 ))),
543 },
544 None => Err(message::MessageError::ConversionError(
545 "Media type for document is required for Anthropic".to_string(),
546 )),
547 },
548 message::UserContent::Audio(message::Audio {
549 data, media_type, ..
550 }) => match media_type {
551 Some(media_type) => Ok(Self::InlineData(Blob {
552 mime_type: media_type.to_mime_type().to_owned(),
553 data,
554 })),
555 None => Err(message::MessageError::ConversionError(
556 "Media type for audio is required for Anthropic".to_string(),
557 )),
558 },
559 }
560 }
561 }
562
563 impl From<message::AssistantContent> for Part {
564 fn from(content: message::AssistantContent) -> Self {
565 match content {
566 message::AssistantContent::Text(message::Text { text }) => text.into(),
567 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
568 message::AssistantContent::Reasoning(message::Reasoning { reasoning }) => {
569 Part::Thought {
570 thoughts: vec![reasoning],
571 }
572 }
573 }
574 }
575 }
576
577 impl From<message::ToolCall> for Part {
578 fn from(tool_call: message::ToolCall) -> Self {
579 Self::FunctionCall(FunctionCall {
580 name: tool_call.function.name,
581 args: tool_call.function.arguments,
582 })
583 }
584 }
585
586 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
589 #[serde(rename_all = "camelCase")]
590 pub struct Blob {
591 pub mime_type: String,
594 pub data: String,
596 }
597
598 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
601 pub struct FunctionCall {
602 pub name: String,
605 pub args: serde_json::Value,
607 }
608
609 impl From<FunctionCall> for message::ToolCall {
610 fn from(function_call: FunctionCall) -> Self {
611 Self {
612 id: function_call.name.clone(),
613 call_id: None,
614 function: message::ToolFunction {
615 name: function_call.name,
616 arguments: function_call.args,
617 },
618 }
619 }
620 }
621
622 impl From<message::ToolCall> for FunctionCall {
623 fn from(tool_call: message::ToolCall) -> Self {
624 Self {
625 name: tool_call.function.name,
626 args: tool_call.function.arguments,
627 }
628 }
629 }
630
631 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
635 pub struct FunctionResponse {
636 pub name: String,
639 pub response: Option<serde_json::Value>,
641 }
642
643 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
645 #[serde(rename_all = "camelCase")]
646 pub struct FileData {
647 pub mime_type: Option<String>,
649 pub file_uri: String,
651 }
652
653 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
654 pub struct SafetyRating {
655 pub category: HarmCategory,
656 pub probability: HarmProbability,
657 }
658
659 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
660 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
661 pub enum HarmProbability {
662 HarmProbabilityUnspecified,
663 Negligible,
664 Low,
665 Medium,
666 High,
667 }
668
669 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
670 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
671 pub enum HarmCategory {
672 HarmCategoryUnspecified,
673 HarmCategoryDerogatory,
674 HarmCategoryToxicity,
675 HarmCategoryViolence,
676 HarmCategorySexually,
677 HarmCategoryMedical,
678 HarmCategoryDangerous,
679 HarmCategoryHarassment,
680 HarmCategoryHateSpeech,
681 HarmCategorySexuallyExplicit,
682 HarmCategoryDangerousContent,
683 HarmCategoryCivicIntegrity,
684 }
685
686 #[derive(Debug, Deserialize, Clone, Default, Serialize)]
687 #[serde(rename_all = "camelCase")]
688 pub struct UsageMetadata {
689 pub prompt_token_count: i32,
690 #[serde(skip_serializing_if = "Option::is_none")]
691 pub cached_content_token_count: Option<i32>,
692 pub candidates_token_count: i32,
693 pub total_token_count: i32,
694 }
695
696 impl std::fmt::Display for UsageMetadata {
697 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
698 write!(
699 f,
700 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
701 self.prompt_token_count,
702 match self.cached_content_token_count {
703 Some(count) => count.to_string(),
704 None => "n/a".to_string(),
705 },
706 self.candidates_token_count,
707 self.total_token_count
708 )
709 }
710 }
711
712 #[derive(Debug, Deserialize, Serialize)]
714 #[serde(rename_all = "camelCase")]
715 pub struct PromptFeedback {
716 pub block_reason: Option<BlockReason>,
718 pub safety_ratings: Option<Vec<SafetyRating>>,
720 }
721
722 #[derive(Debug, Deserialize, Serialize)]
724 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
725 pub enum BlockReason {
726 BlockReasonUnspecified,
728 Safety,
730 Other,
732 Blocklist,
734 ProhibitedContent,
736 }
737
738 #[derive(Debug, Deserialize, Serialize)]
739 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
740 pub enum FinishReason {
741 FinishReasonUnspecified,
743 Stop,
745 MaxTokens,
747 Safety,
749 Recitation,
751 Language,
753 Other,
755 Blocklist,
757 ProhibitedContent,
759 Spii,
761 MalformedFunctionCall,
763 }
764
765 #[derive(Debug, Deserialize, Serialize)]
766 #[serde(rename_all = "camelCase")]
767 pub struct CitationMetadata {
768 pub citation_sources: Vec<CitationSource>,
769 }
770
771 #[derive(Debug, Deserialize, Serialize)]
772 #[serde(rename_all = "camelCase")]
773 pub struct CitationSource {
774 #[serde(skip_serializing_if = "Option::is_none")]
775 pub uri: Option<String>,
776 #[serde(skip_serializing_if = "Option::is_none")]
777 pub start_index: Option<i32>,
778 #[serde(skip_serializing_if = "Option::is_none")]
779 pub end_index: Option<i32>,
780 #[serde(skip_serializing_if = "Option::is_none")]
781 pub license: Option<String>,
782 }
783
784 #[derive(Debug, Deserialize, Serialize)]
785 #[serde(rename_all = "camelCase")]
786 pub struct LogprobsResult {
787 pub top_candidate: Vec<TopCandidate>,
788 pub chosen_candidate: Vec<LogProbCandidate>,
789 }
790
791 #[derive(Debug, Deserialize, Serialize)]
792 pub struct TopCandidate {
793 pub candidates: Vec<LogProbCandidate>,
794 }
795
796 #[derive(Debug, Deserialize, Serialize)]
797 #[serde(rename_all = "camelCase")]
798 pub struct LogProbCandidate {
799 pub token: String,
800 pub token_id: String,
801 pub log_probability: f64,
802 }
803
804 #[derive(Debug, Deserialize, Serialize)]
809 #[serde(rename_all = "camelCase")]
810 pub struct GenerationConfig {
811 #[serde(skip_serializing_if = "Option::is_none")]
814 pub stop_sequences: Option<Vec<String>>,
815 #[serde(skip_serializing_if = "Option::is_none")]
821 pub response_mime_type: Option<String>,
822 #[serde(skip_serializing_if = "Option::is_none")]
826 pub response_schema: Option<Schema>,
827 #[serde(skip_serializing_if = "Option::is_none")]
830 pub candidate_count: Option<i32>,
831 #[serde(skip_serializing_if = "Option::is_none")]
834 pub max_output_tokens: Option<u64>,
835 #[serde(skip_serializing_if = "Option::is_none")]
838 pub temperature: Option<f64>,
839 #[serde(skip_serializing_if = "Option::is_none")]
846 pub top_p: Option<f64>,
847 #[serde(skip_serializing_if = "Option::is_none")]
853 pub top_k: Option<i32>,
854 #[serde(skip_serializing_if = "Option::is_none")]
860 pub presence_penalty: Option<f64>,
861 #[serde(skip_serializing_if = "Option::is_none")]
869 pub frequency_penalty: Option<f64>,
870 #[serde(skip_serializing_if = "Option::is_none")]
872 pub response_logprobs: Option<bool>,
873 #[serde(skip_serializing_if = "Option::is_none")]
876 pub logprobs: Option<i32>,
877 }
878
879 impl Default for GenerationConfig {
880 fn default() -> Self {
881 Self {
882 temperature: Some(1.0),
883 max_output_tokens: Some(4096),
884 stop_sequences: None,
885 response_mime_type: None,
886 response_schema: None,
887 candidate_count: None,
888 top_p: None,
889 top_k: None,
890 presence_penalty: None,
891 frequency_penalty: None,
892 response_logprobs: None,
893 logprobs: None,
894 }
895 }
896 }
897 #[derive(Debug, Deserialize, Serialize, Clone)]
901 pub struct Schema {
902 pub r#type: String,
903 #[serde(skip_serializing_if = "Option::is_none")]
904 pub format: Option<String>,
905 #[serde(skip_serializing_if = "Option::is_none")]
906 pub description: Option<String>,
907 #[serde(skip_serializing_if = "Option::is_none")]
908 pub nullable: Option<bool>,
909 #[serde(skip_serializing_if = "Option::is_none")]
910 pub r#enum: Option<Vec<String>>,
911 #[serde(skip_serializing_if = "Option::is_none")]
912 pub max_items: Option<i32>,
913 #[serde(skip_serializing_if = "Option::is_none")]
914 pub min_items: Option<i32>,
915 #[serde(skip_serializing_if = "Option::is_none")]
916 pub properties: Option<HashMap<String, Schema>>,
917 #[serde(skip_serializing_if = "Option::is_none")]
918 pub required: Option<Vec<String>>,
919 #[serde(skip_serializing_if = "Option::is_none")]
920 pub items: Option<Box<Schema>>,
921 }
922
923 impl TryFrom<Value> for Schema {
924 type Error = CompletionError;
925
926 fn try_from(value: Value) -> Result<Self, Self::Error> {
927 if let Some(obj) = value.as_object() {
928 Ok(Schema {
929 r#type: obj
930 .get("type")
931 .and_then(|v| {
932 if v.is_string() {
933 v.as_str().map(String::from)
934 } else if v.is_array() {
935 v.as_array()
936 .and_then(|arr| arr.first())
937 .and_then(|v| v.as_str().map(String::from))
938 } else {
939 None
940 }
941 })
942 .unwrap_or_default(),
943 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
944 description: obj
945 .get("description")
946 .and_then(|v| v.as_str())
947 .map(String::from),
948 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
949 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
950 arr.iter()
951 .filter_map(|v| v.as_str().map(String::from))
952 .collect()
953 }),
954 max_items: obj
955 .get("maxItems")
956 .and_then(|v| v.as_i64())
957 .map(|v| v as i32),
958 min_items: obj
959 .get("minItems")
960 .and_then(|v| v.as_i64())
961 .map(|v| v as i32),
962 properties: obj
963 .get("properties")
964 .and_then(|v| v.as_object())
965 .map(|map| {
966 map.iter()
967 .filter_map(|(k, v)| {
968 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
969 })
970 .collect()
971 }),
972 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
973 arr.iter()
974 .filter_map(|v| v.as_str().map(String::from))
975 .collect()
976 }),
977 items: obj
978 .get("items")
979 .map(|v| Box::new(v.clone().try_into().unwrap())),
980 })
981 } else {
982 Err(CompletionError::ResponseError(
983 "Expected a JSON object for Schema".into(),
984 ))
985 }
986 }
987 }
988
989 #[derive(Debug, Serialize)]
990 #[serde(rename_all = "camelCase")]
991 pub struct GenerateContentRequest {
992 pub contents: Vec<Content>,
993 #[serde(skip_serializing_if = "Option::is_none")]
994 pub tools: Option<Tool>,
995 pub tool_config: Option<ToolConfig>,
996 pub generation_config: Option<GenerationConfig>,
998 pub safety_settings: Option<Vec<SafetySetting>>,
1012 pub system_instruction: Option<Content>,
1015 }
1017
1018 #[derive(Debug, Serialize)]
1019 #[serde(rename_all = "camelCase")]
1020 pub struct Tool {
1021 pub function_declarations: Vec<FunctionDeclaration>,
1022 pub code_execution: Option<CodeExecution>,
1023 }
1024
1025 #[derive(Debug, Serialize, Clone)]
1026 #[serde(rename_all = "camelCase")]
1027 pub struct FunctionDeclaration {
1028 pub name: String,
1029 pub description: String,
1030 #[serde(skip_serializing_if = "Option::is_none")]
1031 pub parameters: Option<Schema>,
1032 }
1033
1034 #[derive(Debug, Serialize)]
1035 #[serde(rename_all = "camelCase")]
1036 pub struct ToolConfig {
1037 pub schema: Option<Schema>,
1038 }
1039
1040 #[derive(Debug, Serialize)]
1041 #[serde(rename_all = "camelCase")]
1042 pub struct CodeExecution {}
1043
1044 #[derive(Debug, Serialize)]
1045 #[serde(rename_all = "camelCase")]
1046 pub struct SafetySetting {
1047 pub category: HarmCategory,
1048 pub threshold: HarmBlockThreshold,
1049 }
1050
1051 #[derive(Debug, Serialize)]
1052 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1053 pub enum HarmBlockThreshold {
1054 HarmBlockThresholdUnspecified,
1055 BlockLowAndAbove,
1056 BlockMediumAndAbove,
1057 BlockOnlyHigh,
1058 BlockNone,
1059 Off,
1060 }
1061}
1062
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// A user `Content` holding one part of every kind deserializes part by part.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error reports the JSON path of a failure, pinpointing the
        // offending part.
        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));

        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        match &parts[0] {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }

        match &parts[1] {
            Part::InlineData(blob) => {
                assert_eq!(blob.mime_type, "image/png");
                assert_eq!(blob.data, "base64encodeddata");
            }
            _ => panic!("Expected inline data part"),
        }

        match &parts[2] {
            Part::FunctionCall(call) => {
                assert_eq!(call.name, "test_function");
                assert_eq!(call.args.as_object().unwrap().get("arg1").unwrap(), "value1");
            }
            _ => panic!("Expected function call part"),
        }

        match &parts[3] {
            Part::FunctionResponse(response) => {
                assert_eq!(response.name, "test_function");
                let payload = response.response.as_ref().unwrap();
                assert_eq!(payload.get("result").unwrap(), "success");
            }
            _ => panic!("Expected function response part"),
        }

        match &parts[4] {
            Part::FileData(file) => {
                assert_eq!(file.mime_type.as_deref(), Some("application/pdf"));
                assert_eq!(file.file_uri, "http://example.com/file.pdf");
            }
            _ => panic!("Expected file data part"),
        }

        match &parts[5] {
            Part::ExecutableCode(code) => assert_eq!(code.code, "print('Hello, world!')"),
            _ => panic!("Expected executable code part"),
        }

        match &parts[6] {
            Part::CodeExecutionResult(result) => {
                assert_eq!(result.clone().output.unwrap(), "Hello, world!");
            }
            _ => panic!("Expected code execution result part"),
        }
    }

    /// A model turn with a single text part round-trips from JSON.
    #[test]
    fn test_deserialize_message_model() {
        let content: Content = serde_json::from_value(json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        }))
        .unwrap();

        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A plain user message becomes a single user text part.
    #[test]
    fn test_message_conversion_user() {
        let content: Content = message::Message::user("Hello, world!").try_into().unwrap();

        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    /// A plain assistant message becomes a single model text part.
    #[test]
    fn test_message_conversion_model() {
        let content: Content = message::Message::assistant("Hello, user!")
            .try_into()
            .unwrap();

        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::Text(text) => assert_eq!(text, "Hello, user!"),
            _ => panic!("Expected text part"),
        }
    }

    /// An assistant tool call becomes a model function-call part keeping name and args.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();

        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        match &content.parts.first() {
            Part::FunctionCall(call) => {
                assert_eq!(call.name, "test_function");
                assert_eq!(call.args.as_object().unwrap().get("arg1").unwrap(), "value1");
            }
            _ => panic!("Expected function call part"),
        }
    }
}