/// `gemini-2.5-pro-preview-06-05` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash-8b` completion model (the 8B variant in Google's model list).
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Kept for backwards compatibility only.
///
/// NOTE(review): `gemini-1.5-pro-8b` does not appear in Google's published model list;
/// the 8B model is `gemini-1.5-flash-8b` — prefer [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.0-pro` completion model.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters;
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::{
35 OneOrMany,
36 completion::{self, CompletionError, CompletionRequest},
37};
38use gemini_api_types::{
39 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
40 Role, Tool,
41};
42use serde_json::{Map, Value};
43use std::convert::TryFrom;
44
45use super::Client;
46
47#[derive(Clone)]
52pub struct CompletionModel {
53 pub(crate) client: Client,
54 pub model: String,
55}
56
57impl CompletionModel {
58 pub fn new(client: Client, model: &str) -> Self {
59 Self {
60 client,
61 model: model.to_string(),
62 }
63 }
64}
65
66impl completion::CompletionModel for CompletionModel {
67 type Response = GenerateContentResponse;
68 type StreamingResponse = StreamingCompletionResponse;
69
70 #[cfg_attr(feature = "worker", worker::send)]
71 async fn completion(
72 &self,
73 completion_request: CompletionRequest,
74 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
75 let request = create_request_body(completion_request)?;
76
77 tracing::debug!(
78 "Sending completion request to Gemini API {}",
79 serde_json::to_string_pretty(&request)?
80 );
81
82 let response = self
83 .client
84 .post(&format!("/v1beta/models/{}:generateContent", self.model))
85 .json(&request)
86 .send()
87 .await?;
88
89 if response.status().is_success() {
90 let response = response.json::<GenerateContentResponse>().await?;
91 match response.usage_metadata {
92 Some(ref usage) => tracing::info!(target: "rig",
93 "Gemini completion token usage: {}",
94 usage
95 ),
96 None => tracing::info!(target: "rig",
97 "Gemini completion token usage: n/a",
98 ),
99 }
100
101 tracing::debug!("Received response");
102
103 Ok(completion::CompletionResponse::try_from(response))
104 } else {
105 Err(CompletionError::ProviderError(response.text().await?))
106 }?
107 }
108
109 #[cfg_attr(feature = "worker", worker::send)]
110 async fn stream(
111 &self,
112 request: CompletionRequest,
113 ) -> Result<
114 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
115 CompletionError,
116 > {
117 CompletionModel::stream(self, request).await
118 }
119}
120
121pub(crate) fn create_request_body(
122 completion_request: CompletionRequest,
123) -> Result<GenerateContentRequest, CompletionError> {
124 let mut full_history = Vec::new();
125 full_history.extend(completion_request.chat_history);
126
127 let additional_params = completion_request
128 .additional_params
129 .unwrap_or_else(|| Value::Object(Map::new()));
130
131 let AdditionalParameters {
132 mut generation_config,
133 additional_params,
134 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
135
136 if let Some(temp) = completion_request.temperature {
137 generation_config.temperature = Some(temp);
138 }
139
140 if let Some(max_tokens) = completion_request.max_tokens {
141 generation_config.max_output_tokens = Some(max_tokens);
142 }
143
144 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
145 parts: vec![preamble.into()],
146 role: Some(Role::Model),
147 });
148
149 let tools = if completion_request.tools.is_empty() {
150 None
151 } else {
152 Some(Tool::try_from(completion_request.tools)?)
153 };
154
155 let request = GenerateContentRequest {
156 contents: full_history
157 .into_iter()
158 .map(|msg| {
159 msg.try_into()
160 .map_err(|e| CompletionError::RequestError(Box::new(e)))
161 })
162 .collect::<Result<Vec<_>, _>>()?,
163 generation_config: Some(generation_config),
164 safety_settings: None,
165 tools,
166 tool_config: None,
167 system_instruction,
168 additional_params,
169 };
170
171 Ok(request)
172}
173
174impl TryFrom<completion::ToolDefinition> for Tool {
175 type Error = CompletionError;
176
177 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
178 let parameters: Option<Schema> =
179 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
180 None
181 } else {
182 Some(tool.parameters.try_into()?)
183 };
184
185 Ok(Self {
186 function_declarations: vec![FunctionDeclaration {
187 name: tool.name,
188 description: tool.description,
189 parameters,
190 }],
191 code_execution: None,
192 })
193 }
194}
195
196impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
197 type Error = CompletionError;
198
199 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
200 let mut function_declarations = Vec::new();
201
202 for tool in tools {
203 let parameters =
204 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
205 None
206 } else {
207 match tool.parameters.try_into() {
208 Ok(schema) => Some(schema),
209 Err(e) => {
210 let emsg = format!(
211 "Tool '{}' could not be converted to a schema: {:?}",
212 tool.name, e,
213 );
214 return Err(CompletionError::ProviderError(emsg));
215 }
216 }
217 };
218
219 function_declarations.push(FunctionDeclaration {
220 name: tool.name,
221 description: tool.description,
222 parameters,
223 });
224 }
225
226 Ok(Self {
227 function_declarations,
228 code_execution: None,
229 })
230 }
231}
232
233impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
234 type Error = CompletionError;
235
236 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
237 let candidate = response.candidates.first().ok_or_else(|| {
238 CompletionError::ResponseError("No response candidates in response".into())
239 })?;
240
241 let content = candidate
242 .content
243 .parts
244 .iter()
245 .map(|Part { thought, part, .. }| {
246 Ok(match part {
247 PartKind::Text(text) => {
248 if let Some(thought) = thought
249 && *thought
250 {
251 completion::AssistantContent::Reasoning(Reasoning::new(text))
252 } else {
253 completion::AssistantContent::text(text)
254 }
255 }
256 PartKind::FunctionCall(function_call) => {
257 completion::AssistantContent::tool_call(
258 &function_call.name,
259 &function_call.name,
260 function_call.args.clone(),
261 )
262 }
263 _ => {
264 return Err(CompletionError::ResponseError(
265 "Response did not contain a message or tool call".into(),
266 ));
267 }
268 })
269 })
270 .collect::<Result<Vec<_>, _>>()?;
271
272 let choice = OneOrMany::many(content).map_err(|_| {
273 CompletionError::ResponseError(
274 "Response contained no message or tool call (empty)".to_owned(),
275 )
276 })?;
277
278 let usage = response
279 .usage_metadata
280 .as_ref()
281 .map(|usage| completion::Usage {
282 input_tokens: usage.prompt_token_count as u64,
283 output_tokens: usage.candidates_token_count as u64,
284 total_tokens: usage.total_token_count as u64,
285 })
286 .unwrap_or_default();
287
288 Ok(completion::CompletionResponse {
289 choice,
290 usage,
291 raw_response: response,
292 })
293 }
294}
295
296pub mod gemini_api_types {
297 use std::{collections::HashMap, convert::Infallible, str::FromStr};
298
299 use serde::{Deserialize, Serialize};
303 use serde_json::{Value, json};
304
305 use crate::message::ContentFormat;
306 use crate::{
307 OneOrMany,
308 completion::CompletionError,
309 message::{self, MimeType as _, Reasoning, Text},
310 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
311 };
312
313 #[derive(Debug, Deserialize, Serialize, Default)]
314 #[serde(rename_all = "camelCase")]
315 pub struct AdditionalParameters {
316 pub generation_config: GenerationConfig,
318 #[serde(flatten, skip_serializing_if = "Option::is_none")]
320 pub additional_params: Option<serde_json::Value>,
321 }
322
323 impl AdditionalParameters {
324 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
325 self.generation_config = cfg;
326 self
327 }
328
329 pub fn with_params(mut self, params: serde_json::Value) -> Self {
330 self.additional_params = Some(params);
331 self
332 }
333 }
334
335 #[derive(Debug, Deserialize, Serialize)]
343 #[serde(rename_all = "camelCase")]
344 pub struct GenerateContentResponse {
345 pub candidates: Vec<ContentCandidate>,
347 pub prompt_feedback: Option<PromptFeedback>,
349 pub usage_metadata: Option<UsageMetadata>,
351 pub model_version: Option<String>,
352 }
353
354 #[derive(Debug, Deserialize, Serialize)]
356 #[serde(rename_all = "camelCase")]
357 pub struct ContentCandidate {
358 pub content: Content,
360 pub finish_reason: Option<FinishReason>,
363 pub safety_ratings: Option<Vec<SafetyRating>>,
366 pub citation_metadata: Option<CitationMetadata>,
370 pub token_count: Option<i32>,
372 pub avg_logprobs: Option<f64>,
374 pub logprobs_result: Option<LogprobsResult>,
376 pub index: Option<i32>,
378 }
379
380 #[derive(Debug, Deserialize, Serialize)]
381 pub struct Content {
382 #[serde(default)]
384 pub parts: Vec<Part>,
385 pub role: Option<Role>,
388 }
389
390 impl TryFrom<message::Message> for Content {
391 type Error = message::MessageError;
392
393 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
394 Ok(match msg {
395 message::Message::User { content } => Content {
396 parts: content
397 .into_iter()
398 .map(|c| c.try_into())
399 .collect::<Result<Vec<_>, _>>()?,
400 role: Some(Role::User),
401 },
402 message::Message::Assistant { content, .. } => Content {
403 role: Some(Role::Model),
404 parts: content.into_iter().map(|content| content.into()).collect(),
405 },
406 })
407 }
408 }
409
410 impl TryFrom<Content> for message::Message {
411 type Error = message::MessageError;
412
413 fn try_from(content: Content) -> Result<Self, Self::Error> {
414 match content.role {
415 Some(Role::User) | None => {
416 Ok(message::Message::User {
417 content: {
418 let user_content: Result<Vec<_>, _> = content.parts.into_iter()
419 .map(|Part { part, .. }| {
420 Ok(match part {
421 PartKind::Text(text) => message::UserContent::text(text),
422 PartKind::InlineData(inline_data) => {
423 let mime_type =
424 message::MediaType::from_mime_type(&inline_data.mime_type);
425
426 match mime_type {
427 Some(message::MediaType::Image(media_type)) => {
428 message::UserContent::image(
429 inline_data.data,
430 Some(message::ContentFormat::default()),
431 Some(media_type),
432 Some(message::ImageDetail::default()),
433 )
434 }
435 Some(message::MediaType::Document(media_type)) => {
436 message::UserContent::document(
437 inline_data.data,
438 Some(message::ContentFormat::default()),
439 Some(media_type),
440 )
441 }
442 Some(message::MediaType::Audio(media_type)) => {
443 message::UserContent::audio(
444 inline_data.data,
445 Some(message::ContentFormat::default()),
446 Some(media_type),
447 )
448 }
449 _ => {
450 return Err(message::MessageError::ConversionError(
451 format!("Unsupported media type {mime_type:?}"),
452 ));
453 }
454 }
455 }
456 _ => {
457 return Err(message::MessageError::ConversionError(format!(
458 "Unsupported gemini content part type: {part:?}"
459 )));
460 }
461 })
462 })
463 .collect();
464 OneOrMany::many(user_content?).map_err(|_| {
465 message::MessageError::ConversionError(
466 "Failed to create OneOrMany from user content".to_string(),
467 )
468 })?
469 },
470 })
471 }
472 Some(Role::Model) => Ok(message::Message::Assistant {
473 id: None,
474 content: {
475 let assistant_content: Result<Vec<_>, _> = content
476 .parts
477 .into_iter()
478 .map(|Part { thought, part, .. }| {
479 Ok(match part {
480 PartKind::Text(text) => match thought {
481 Some(true) => message::AssistantContent::Reasoning(
482 Reasoning::new(&text),
483 ),
484 _ => message::AssistantContent::Text(Text { text }),
485 },
486
487 PartKind::FunctionCall(function_call) => {
488 message::AssistantContent::ToolCall(function_call.into())
489 }
490 _ => {
491 return Err(message::MessageError::ConversionError(
492 format!("Unsupported part type: {part:?}"),
493 ));
494 }
495 })
496 })
497 .collect();
498 OneOrMany::many(assistant_content?).map_err(|_| {
499 message::MessageError::ConversionError(
500 "Failed to create OneOrMany from assistant content".to_string(),
501 )
502 })?
503 },
504 }),
505 }
506 }
507 }
508
509 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
510 #[serde(rename_all = "lowercase")]
511 pub enum Role {
512 User,
513 Model,
514 }
515
516 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
517 #[serde(rename_all = "camelCase")]
518 pub struct Part {
519 #[serde(skip_serializing_if = "Option::is_none")]
521 pub thought: Option<bool>,
522 #[serde(skip_serializing_if = "Option::is_none")]
524 pub thought_signature: Option<String>,
525 #[serde(flatten)]
526 pub part: PartKind,
527 #[serde(flatten, skip_serializing_if = "Option::is_none")]
528 pub additional_params: Option<Value>,
529 }
530
531 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
535 #[serde(rename_all = "camelCase")]
536 pub enum PartKind {
537 Text(String),
538 InlineData(Blob),
539 FunctionCall(FunctionCall),
540 FunctionResponse(FunctionResponse),
541 FileData(FileData),
542 ExecutableCode(ExecutableCode),
543 CodeExecutionResult(CodeExecutionResult),
544 }
545
546 impl From<String> for Part {
547 fn from(text: String) -> Self {
548 Self {
549 thought: Some(false),
550 thought_signature: None,
551 part: PartKind::Text(text),
552 additional_params: None,
553 }
554 }
555 }
556
557 impl From<&str> for Part {
558 fn from(text: &str) -> Self {
559 Self::from(text.to_string())
560 }
561 }
562
563 impl FromStr for Part {
564 type Err = Infallible;
565
566 fn from_str(s: &str) -> Result<Self, Self::Err> {
567 Ok(s.into())
568 }
569 }
570
571 impl TryFrom<message::UserContent> for Part {
572 type Error = message::MessageError;
573
574 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
575 match content {
576 message::UserContent::Text(message::Text { text }) => Ok(Part {
577 thought: Some(false),
578 thought_signature: None,
579 part: PartKind::Text(text),
580 additional_params: None,
581 }),
582 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
583 let content = match content.first() {
584 message::ToolResultContent::Text(text) => text.text,
585 message::ToolResultContent::Image(_) => {
586 return Err(message::MessageError::ConversionError(
587 "Tool result content must be text".to_string(),
588 ));
589 }
590 };
591 let result: serde_json::Value =
593 serde_json::from_str(&content).unwrap_or_else(|error| {
594 tracing::trace!(
595 ?error,
596 "Tool result is not a valid JSON, treat it as normal string"
597 );
598 json!(content)
599 });
600 Ok(Part {
601 thought: Some(false),
602 thought_signature: None,
603 part: PartKind::FunctionResponse(FunctionResponse {
604 name: id,
605 response: Some(json!({ "result": result })),
606 }),
607 additional_params: None,
608 })
609 }
610 message::UserContent::Image(message::Image {
611 data, media_type, ..
612 }) => match media_type {
613 Some(media_type) => match media_type {
614 message::ImageMediaType::JPEG
615 | message::ImageMediaType::PNG
616 | message::ImageMediaType::WEBP
617 | message::ImageMediaType::HEIC
618 | message::ImageMediaType::HEIF => Ok(Part {
619 thought: Some(false),
620 thought_signature: None,
621 part: PartKind::InlineData(Blob {
622 mime_type: media_type.to_mime_type().to_owned(),
623 data,
624 }),
625 additional_params: None,
626 }),
627 _ => Err(message::MessageError::ConversionError(format!(
628 "Unsupported image media type {media_type:?}"
629 ))),
630 },
631 None => Err(message::MessageError::ConversionError(
632 "Media type for image is required for Gemini".to_string(), )),
634 },
635 message::UserContent::Document(message::Document {
636 data, media_type, ..
637 }) => match media_type {
638 Some(media_type) => match media_type {
639 message::DocumentMediaType::PDF
640 | message::DocumentMediaType::TXT
641 | message::DocumentMediaType::RTF
642 | message::DocumentMediaType::HTML
643 | message::DocumentMediaType::CSS
644 | message::DocumentMediaType::MARKDOWN
645 | message::DocumentMediaType::CSV
646 | message::DocumentMediaType::XML => Ok(Part {
647 thought: Some(false),
648 thought_signature: None,
649 part: PartKind::InlineData(Blob {
650 mime_type: media_type.to_mime_type().to_owned(),
651 data,
652 }),
653 additional_params: None,
654 }),
655 _ => Err(message::MessageError::ConversionError(format!(
656 "Unsupported document media type {media_type:?}"
657 ))),
658 },
659 None => Err(message::MessageError::ConversionError(
660 "Media type for document is required for Gemini".to_string(), )),
662 },
663 message::UserContent::Audio(message::Audio {
664 data, media_type, ..
665 }) => match media_type {
666 Some(media_type) => Ok(Part {
667 thought: Some(false),
668 thought_signature: None,
669 part: PartKind::InlineData(Blob {
670 mime_type: media_type.to_mime_type().to_owned(),
671 data,
672 }),
673 additional_params: None,
674 }),
675 None => Err(message::MessageError::ConversionError(
676 "Media type for audio is required for Gemini".to_string(),
677 )),
678 },
679 message::UserContent::Video(message::Video {
680 data,
681 media_type,
682 format,
683 additional_params,
684 }) => {
685 let mime_type = media_type.map(|m| m.to_mime_type().to_owned());
686
687 let data = match format {
688 Some(ContentFormat::String) => PartKind::FileData(FileData {
689 mime_type,
690 file_uri: data,
691 }),
692 _ => match mime_type {
693 Some(mime_type) => PartKind::InlineData(Blob { mime_type, data }),
694 None => {
695 return Err(message::MessageError::ConversionError(
696 "Media type for video is required for Gemini".to_string(),
697 ));
698 }
699 },
700 };
701
702 Ok(Part {
703 thought: Some(false),
704 thought_signature: None,
705 part: data,
706 additional_params,
707 })
708 }
709 }
710 }
711 }
712
713 impl From<message::AssistantContent> for Part {
714 fn from(content: message::AssistantContent) -> Self {
715 match content {
716 message::AssistantContent::Text(message::Text { text }) => text.into(),
717 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
718 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
719 Part {
720 thought: Some(true),
721 thought_signature: None,
722 part: PartKind::Text(
723 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
724 ),
725 additional_params: None,
726 }
727 }
728 }
729 }
730 }
731
732 impl From<message::ToolCall> for Part {
733 fn from(tool_call: message::ToolCall) -> Self {
734 Self {
735 thought: Some(false),
736 thought_signature: None,
737 part: PartKind::FunctionCall(FunctionCall {
738 name: tool_call.function.name,
739 args: tool_call.function.arguments,
740 }),
741 additional_params: None,
742 }
743 }
744 }
745
746 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
749 #[serde(rename_all = "camelCase")]
750 pub struct Blob {
751 pub mime_type: String,
754 pub data: String,
756 }
757
758 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
761 pub struct FunctionCall {
762 pub name: String,
765 pub args: serde_json::Value,
767 }
768
769 impl From<FunctionCall> for message::ToolCall {
770 fn from(function_call: FunctionCall) -> Self {
771 Self {
772 id: function_call.name.clone(),
773 call_id: None,
774 function: message::ToolFunction {
775 name: function_call.name,
776 arguments: function_call.args,
777 },
778 }
779 }
780 }
781
782 impl From<message::ToolCall> for FunctionCall {
783 fn from(tool_call: message::ToolCall) -> Self {
784 Self {
785 name: tool_call.function.name,
786 args: tool_call.function.arguments,
787 }
788 }
789 }
790
791 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
795 pub struct FunctionResponse {
796 pub name: String,
799 pub response: Option<serde_json::Value>,
801 }
802
803 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
805 #[serde(rename_all = "camelCase")]
806 pub struct FileData {
807 pub mime_type: Option<String>,
809 pub file_uri: String,
811 }
812
813 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
814 pub struct SafetyRating {
815 pub category: HarmCategory,
816 pub probability: HarmProbability,
817 }
818
819 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
820 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
821 pub enum HarmProbability {
822 HarmProbabilityUnspecified,
823 Negligible,
824 Low,
825 Medium,
826 High,
827 }
828
829 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
830 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
831 pub enum HarmCategory {
832 HarmCategoryUnspecified,
833 HarmCategoryDerogatory,
834 HarmCategoryToxicity,
835 HarmCategoryViolence,
836 HarmCategorySexually,
837 HarmCategoryMedical,
838 HarmCategoryDangerous,
839 HarmCategoryHarassment,
840 HarmCategoryHateSpeech,
841 HarmCategorySexuallyExplicit,
842 HarmCategoryDangerousContent,
843 HarmCategoryCivicIntegrity,
844 }
845
846 #[derive(Debug, Deserialize, Clone, Default, Serialize)]
847 #[serde(rename_all = "camelCase")]
848 pub struct UsageMetadata {
849 pub prompt_token_count: i32,
850 #[serde(skip_serializing_if = "Option::is_none")]
851 pub cached_content_token_count: Option<i32>,
852 pub candidates_token_count: i32,
853 pub total_token_count: i32,
854 }
855
856 impl std::fmt::Display for UsageMetadata {
857 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
858 write!(
859 f,
860 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
861 self.prompt_token_count,
862 match self.cached_content_token_count {
863 Some(count) => count.to_string(),
864 None => "n/a".to_string(),
865 },
866 self.candidates_token_count,
867 self.total_token_count
868 )
869 }
870 }
871
872 #[derive(Debug, Deserialize, Serialize)]
874 #[serde(rename_all = "camelCase")]
875 pub struct PromptFeedback {
876 pub block_reason: Option<BlockReason>,
878 pub safety_ratings: Option<Vec<SafetyRating>>,
880 }
881
882 #[derive(Debug, Deserialize, Serialize)]
884 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
885 pub enum BlockReason {
886 BlockReasonUnspecified,
888 Safety,
890 Other,
892 Blocklist,
894 ProhibitedContent,
896 }
897
898 #[derive(Debug, Deserialize, Serialize)]
899 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
900 pub enum FinishReason {
901 FinishReasonUnspecified,
903 Stop,
905 MaxTokens,
907 Safety,
909 Recitation,
911 Language,
913 Other,
915 Blocklist,
917 ProhibitedContent,
919 Spii,
921 MalformedFunctionCall,
923 }
924
925 #[derive(Debug, Deserialize, Serialize)]
926 #[serde(rename_all = "camelCase")]
927 pub struct CitationMetadata {
928 pub citation_sources: Vec<CitationSource>,
929 }
930
931 #[derive(Debug, Deserialize, Serialize)]
932 #[serde(rename_all = "camelCase")]
933 pub struct CitationSource {
934 #[serde(skip_serializing_if = "Option::is_none")]
935 pub uri: Option<String>,
936 #[serde(skip_serializing_if = "Option::is_none")]
937 pub start_index: Option<i32>,
938 #[serde(skip_serializing_if = "Option::is_none")]
939 pub end_index: Option<i32>,
940 #[serde(skip_serializing_if = "Option::is_none")]
941 pub license: Option<String>,
942 }
943
944 #[derive(Debug, Deserialize, Serialize)]
945 #[serde(rename_all = "camelCase")]
946 pub struct LogprobsResult {
947 pub top_candidate: Vec<TopCandidate>,
948 pub chosen_candidate: Vec<LogProbCandidate>,
949 }
950
951 #[derive(Debug, Deserialize, Serialize)]
952 pub struct TopCandidate {
953 pub candidates: Vec<LogProbCandidate>,
954 }
955
956 #[derive(Debug, Deserialize, Serialize)]
957 #[serde(rename_all = "camelCase")]
958 pub struct LogProbCandidate {
959 pub token: String,
960 pub token_id: String,
961 pub log_probability: f64,
962 }
963
964 #[derive(Debug, Deserialize, Serialize)]
969 #[serde(rename_all = "camelCase")]
970 pub struct GenerationConfig {
971 #[serde(skip_serializing_if = "Option::is_none")]
974 pub stop_sequences: Option<Vec<String>>,
975 #[serde(skip_serializing_if = "Option::is_none")]
981 pub response_mime_type: Option<String>,
982 #[serde(skip_serializing_if = "Option::is_none")]
986 pub response_schema: Option<Schema>,
987 #[serde(skip_serializing_if = "Option::is_none")]
990 pub candidate_count: Option<i32>,
991 #[serde(skip_serializing_if = "Option::is_none")]
994 pub max_output_tokens: Option<u64>,
995 #[serde(skip_serializing_if = "Option::is_none")]
998 pub temperature: Option<f64>,
999 #[serde(skip_serializing_if = "Option::is_none")]
1006 pub top_p: Option<f64>,
1007 #[serde(skip_serializing_if = "Option::is_none")]
1013 pub top_k: Option<i32>,
1014 #[serde(skip_serializing_if = "Option::is_none")]
1020 pub presence_penalty: Option<f64>,
1021 #[serde(skip_serializing_if = "Option::is_none")]
1029 pub frequency_penalty: Option<f64>,
1030 #[serde(skip_serializing_if = "Option::is_none")]
1032 pub response_logprobs: Option<bool>,
1033 #[serde(skip_serializing_if = "Option::is_none")]
1036 pub logprobs: Option<i32>,
1037 #[serde(skip_serializing_if = "Option::is_none")]
1039 pub thinking_config: Option<ThinkingConfig>,
1040 }
1041
1042 impl Default for GenerationConfig {
1043 fn default() -> Self {
1044 Self {
1045 temperature: Some(1.0),
1046 max_output_tokens: Some(4096),
1047 stop_sequences: None,
1048 response_mime_type: None,
1049 response_schema: None,
1050 candidate_count: None,
1051 top_p: None,
1052 top_k: None,
1053 presence_penalty: None,
1054 frequency_penalty: None,
1055 response_logprobs: None,
1056 logprobs: None,
1057 thinking_config: None,
1058 }
1059 }
1060 }
1061
1062 #[derive(Debug, Deserialize, Serialize)]
1063 #[serde(rename_all = "camelCase")]
1064 pub struct ThinkingConfig {
1065 pub thinking_budget: u32,
1066 pub include_thoughts: Option<bool>,
1067 }
1068 #[derive(Debug, Deserialize, Serialize, Clone)]
1072 pub struct Schema {
1073 pub r#type: String,
1074 #[serde(skip_serializing_if = "Option::is_none")]
1075 pub format: Option<String>,
1076 #[serde(skip_serializing_if = "Option::is_none")]
1077 pub description: Option<String>,
1078 #[serde(skip_serializing_if = "Option::is_none")]
1079 pub nullable: Option<bool>,
1080 #[serde(skip_serializing_if = "Option::is_none")]
1081 pub r#enum: Option<Vec<String>>,
1082 #[serde(skip_serializing_if = "Option::is_none")]
1083 pub max_items: Option<i32>,
1084 #[serde(skip_serializing_if = "Option::is_none")]
1085 pub min_items: Option<i32>,
1086 #[serde(skip_serializing_if = "Option::is_none")]
1087 pub properties: Option<HashMap<String, Schema>>,
1088 #[serde(skip_serializing_if = "Option::is_none")]
1089 pub required: Option<Vec<String>>,
1090 #[serde(skip_serializing_if = "Option::is_none")]
1091 pub items: Option<Box<Schema>>,
1092 }
1093
1094 impl TryFrom<Value> for Schema {
1095 type Error = CompletionError;
1096
1097 fn try_from(value: Value) -> Result<Self, Self::Error> {
1098 if let Some(obj) = value.as_object() {
1099 Ok(Schema {
1100 r#type: obj
1101 .get("type")
1102 .and_then(|v| {
1103 if v.is_string() {
1104 v.as_str().map(String::from)
1105 } else if v.is_array() {
1106 v.as_array()
1107 .and_then(|arr| arr.first())
1108 .and_then(|v| v.as_str().map(String::from))
1109 } else {
1110 None
1111 }
1112 })
1113 .unwrap_or_default(),
1114 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1115 description: obj
1116 .get("description")
1117 .and_then(|v| v.as_str())
1118 .map(String::from),
1119 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1120 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1121 arr.iter()
1122 .filter_map(|v| v.as_str().map(String::from))
1123 .collect()
1124 }),
1125 max_items: obj
1126 .get("maxItems")
1127 .and_then(|v| v.as_i64())
1128 .map(|v| v as i32),
1129 min_items: obj
1130 .get("minItems")
1131 .and_then(|v| v.as_i64())
1132 .map(|v| v as i32),
1133 properties: obj
1134 .get("properties")
1135 .and_then(|v| v.as_object())
1136 .map(|map| {
1137 map.iter()
1138 .filter_map(|(k, v)| {
1139 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1140 })
1141 .collect()
1142 }),
1143 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
1144 arr.iter()
1145 .filter_map(|v| v.as_str().map(String::from))
1146 .collect()
1147 }),
1148 items: obj
1149 .get("items")
1150 .map(|v| Box::new(v.clone().try_into().unwrap())),
1151 })
1152 } else {
1153 Err(CompletionError::ResponseError(
1154 "Expected a JSON object for Schema".into(),
1155 ))
1156 }
1157 }
1158 }
1159
1160 #[derive(Debug, Serialize)]
1161 #[serde(rename_all = "camelCase")]
1162 pub struct GenerateContentRequest {
1163 pub contents: Vec<Content>,
1164 #[serde(skip_serializing_if = "Option::is_none")]
1165 pub tools: Option<Tool>,
1166 pub tool_config: Option<ToolConfig>,
1167 pub generation_config: Option<GenerationConfig>,
1169 pub safety_settings: Option<Vec<SafetySetting>>,
1183 pub system_instruction: Option<Content>,
1186 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1189 pub additional_params: Option<serde_json::Value>,
1190 }
1191
1192 #[derive(Debug, Serialize)]
1193 #[serde(rename_all = "camelCase")]
1194 pub struct Tool {
1195 pub function_declarations: Vec<FunctionDeclaration>,
1196 pub code_execution: Option<CodeExecution>,
1197 }
1198
1199 #[derive(Debug, Serialize, Clone)]
1200 #[serde(rename_all = "camelCase")]
1201 pub struct FunctionDeclaration {
1202 pub name: String,
1203 pub description: String,
1204 #[serde(skip_serializing_if = "Option::is_none")]
1205 pub parameters: Option<Schema>,
1206 }
1207
1208 #[derive(Debug, Serialize)]
1209 #[serde(rename_all = "camelCase")]
1210 pub struct ToolConfig {
1211 pub schema: Option<Schema>,
1212 }
1213
1214 #[derive(Debug, Serialize)]
1215 #[serde(rename_all = "camelCase")]
1216 pub struct CodeExecution {}
1217
1218 #[derive(Debug, Serialize)]
1219 #[serde(rename_all = "camelCase")]
1220 pub struct SafetySetting {
1221 pub category: HarmCategory,
1222 pub threshold: HarmBlockThreshold,
1223 }
1224
1225 #[derive(Debug, Serialize)]
1226 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1227 pub enum HarmBlockThreshold {
1228 HarmBlockThresholdUnspecified,
1229 BlockLowAndAbove,
1230 BlockMediumAndAbove,
1231 BlockOnlyHigh,
1232 BlockNone,
1233 Off,
1234 }
1235}
1236
1237#[cfg(test)]
1238mod tests {
1239 use crate::message;
1240
1241 use super::*;
1242 use serde_json::json;
1243
1244 #[test]
1245 fn test_deserialize_message_user() {
1246 let raw_message = r#"{
1247 "parts": [
1248 {"text": "Hello, world!"},
1249 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1250 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1251 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1252 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1253 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1254 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1255 ],
1256 "role": "user"
1257 }"#;
1258
1259 let content: Content = {
1260 let jd = &mut serde_json::Deserializer::from_str(raw_message);
1261 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1262 panic!("Deserialization error at {}: {}", err.path(), err);
1263 })
1264 };
1265 assert_eq!(content.role, Some(Role::User));
1266 assert_eq!(content.parts.len(), 7);
1267
1268 let parts: Vec<Part> = content.parts.into_iter().collect();
1269
1270 if let Part {
1271 part: PartKind::Text(text),
1272 ..
1273 } = &parts[0]
1274 {
1275 assert_eq!(text, "Hello, world!");
1276 } else {
1277 panic!("Expected text part");
1278 }
1279
1280 if let Part {
1281 part: PartKind::InlineData(inline_data),
1282 ..
1283 } = &parts[1]
1284 {
1285 assert_eq!(inline_data.mime_type, "image/png");
1286 assert_eq!(inline_data.data, "base64encodeddata");
1287 } else {
1288 panic!("Expected inline data part");
1289 }
1290
1291 if let Part {
1292 part: PartKind::FunctionCall(function_call),
1293 ..
1294 } = &parts[2]
1295 {
1296 assert_eq!(function_call.name, "test_function");
1297 assert_eq!(
1298 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1299 "value1"
1300 );
1301 } else {
1302 panic!("Expected function call part");
1303 }
1304
1305 if let Part {
1306 part: PartKind::FunctionResponse(function_response),
1307 ..
1308 } = &parts[3]
1309 {
1310 assert_eq!(function_response.name, "test_function");
1311 assert_eq!(
1312 function_response
1313 .response
1314 .as_ref()
1315 .unwrap()
1316 .get("result")
1317 .unwrap(),
1318 "success"
1319 );
1320 } else {
1321 panic!("Expected function response part");
1322 }
1323
1324 if let Part {
1325 part: PartKind::FileData(file_data),
1326 ..
1327 } = &parts[4]
1328 {
1329 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1330 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1331 } else {
1332 panic!("Expected file data part");
1333 }
1334
1335 if let Part {
1336 part: PartKind::ExecutableCode(executable_code),
1337 ..
1338 } = &parts[5]
1339 {
1340 assert_eq!(executable_code.code, "print('Hello, world!')");
1341 } else {
1342 panic!("Expected executable code part");
1343 }
1344
1345 if let Part {
1346 part: PartKind::CodeExecutionResult(code_execution_result),
1347 ..
1348 } = &parts[6]
1349 {
1350 assert_eq!(
1351 code_execution_result.clone().output.unwrap(),
1352 "Hello, world!"
1353 );
1354 } else {
1355 panic!("Expected code execution result part");
1356 }
1357 }
1358
1359 #[test]
1360 fn test_deserialize_message_model() {
1361 let json_data = json!({
1362 "parts": [{"text": "Hello, user!"}],
1363 "role": "model"
1364 });
1365
1366 let content: Content = serde_json::from_value(json_data).unwrap();
1367 assert_eq!(content.role, Some(Role::Model));
1368 assert_eq!(content.parts.len(), 1);
1369 if let Some(Part {
1370 part: PartKind::Text(text),
1371 ..
1372 }) = content.parts.first()
1373 {
1374 assert_eq!(text, "Hello, user!");
1375 } else {
1376 panic!("Expected text part");
1377 }
1378 }
1379
1380 #[test]
1381 fn test_message_conversion_user() {
1382 let msg = message::Message::user("Hello, world!");
1383 let content: Content = msg.try_into().unwrap();
1384 assert_eq!(content.role, Some(Role::User));
1385 assert_eq!(content.parts.len(), 1);
1386 if let Some(Part {
1387 part: PartKind::Text(text),
1388 ..
1389 }) = &content.parts.first()
1390 {
1391 assert_eq!(text, "Hello, world!");
1392 } else {
1393 panic!("Expected text part");
1394 }
1395 }
1396
1397 #[test]
1398 fn test_message_conversion_model() {
1399 let msg = message::Message::assistant("Hello, user!");
1400
1401 let content: Content = msg.try_into().unwrap();
1402 assert_eq!(content.role, Some(Role::Model));
1403 assert_eq!(content.parts.len(), 1);
1404 if let Some(Part {
1405 part: PartKind::Text(text),
1406 ..
1407 }) = &content.parts.first()
1408 {
1409 assert_eq!(text, "Hello, user!");
1410 } else {
1411 panic!("Expected text part");
1412 }
1413 }
1414
1415 #[test]
1416 fn test_message_conversion_tool_call() {
1417 let tool_call = message::ToolCall {
1418 id: "test_tool".to_string(),
1419 call_id: None,
1420 function: message::ToolFunction {
1421 name: "test_function".to_string(),
1422 arguments: json!({"arg1": "value1"}),
1423 },
1424 };
1425
1426 let msg = message::Message::Assistant {
1427 id: None,
1428 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1429 };
1430
1431 let content: Content = msg.try_into().unwrap();
1432 assert_eq!(content.role, Some(Role::Model));
1433 assert_eq!(content.parts.len(), 1);
1434 if let Some(Part {
1435 part: PartKind::FunctionCall(function_call),
1436 ..
1437 }) = content.parts.first()
1438 {
1439 assert_eq!(function_call.name, "test_function");
1440 assert_eq!(
1441 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1442 "value1"
1443 );
1444 } else {
1445 panic!("Expected function call part");
1446 }
1447 }
1448}