/// `gemini-2.5-pro-preview-06-05` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` completion model.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` model identifier.
///
/// NOTE: Google's published model list names the 8B Gemini 1.5 variant
/// `gemini-1.5-flash-8b`, so this string is unlikely to be accepted by the API. It is kept
/// unchanged for backward compatibility; prefer [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.5-flash-8b` completion model (the 8B Gemini 1.5 variant).
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.0-pro` completion model.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters;
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::{
35 OneOrMany,
36 completion::{self, CompletionError, CompletionRequest},
37};
38use gemini_api_types::{
39 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
40 Role, Tool,
41};
42use serde_json::{Map, Value};
43use std::convert::TryFrom;
44
45use super::Client;
46
47#[derive(Clone)]
52pub struct CompletionModel {
53 pub(crate) client: Client,
54 pub model: String,
55}
56
57impl CompletionModel {
58 pub fn new(client: Client, model: &str) -> Self {
59 Self {
60 client,
61 model: model.to_string(),
62 }
63 }
64}
65
66impl completion::CompletionModel for CompletionModel {
67 type Response = GenerateContentResponse;
68 type StreamingResponse = StreamingCompletionResponse;
69
70 #[cfg_attr(feature = "worker", worker::send)]
71 async fn completion(
72 &self,
73 completion_request: CompletionRequest,
74 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
75 let request = create_request_body(completion_request)?;
76
77 tracing::debug!(
78 "Sending completion request to Gemini API {}",
79 serde_json::to_string_pretty(&request)?
80 );
81
82 let response = self
83 .client
84 .post(&format!("/v1beta/models/{}:generateContent", self.model))
85 .json(&request)
86 .send()
87 .await?;
88
89 if response.status().is_success() {
90 let response = response.json::<GenerateContentResponse>().await?;
91 match response.usage_metadata {
92 Some(ref usage) => tracing::info!(target: "rig",
93 "Gemini completion token usage: {}",
94 usage
95 ),
96 None => tracing::info!(target: "rig",
97 "Gemini completion token usage: n/a",
98 ),
99 }
100
101 tracing::debug!("Received response");
102
103 Ok(completion::CompletionResponse::try_from(response))
104 } else {
105 Err(CompletionError::ProviderError(response.text().await?))
106 }?
107 }
108
109 #[cfg_attr(feature = "worker", worker::send)]
110 async fn stream(
111 &self,
112 request: CompletionRequest,
113 ) -> Result<
114 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
115 CompletionError,
116 > {
117 CompletionModel::stream(self, request).await
118 }
119}
120
121pub(crate) fn create_request_body(
122 completion_request: CompletionRequest,
123) -> Result<GenerateContentRequest, CompletionError> {
124 let mut full_history = Vec::new();
125 full_history.extend(completion_request.chat_history);
126
127 let additional_params = completion_request
128 .additional_params
129 .unwrap_or_else(|| Value::Object(Map::new()));
130
131 let AdditionalParameters {
132 mut generation_config,
133 additional_params,
134 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
135
136 if let Some(temp) = completion_request.temperature {
137 generation_config.temperature = Some(temp);
138 }
139
140 if let Some(max_tokens) = completion_request.max_tokens {
141 generation_config.max_output_tokens = Some(max_tokens);
142 }
143
144 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
145 parts: vec![preamble.into()],
146 role: Some(Role::Model),
147 });
148
149 let tools = if completion_request.tools.is_empty() {
150 None
151 } else {
152 Some(Tool::try_from(completion_request.tools)?)
153 };
154
155 let request = GenerateContentRequest {
156 contents: full_history
157 .into_iter()
158 .map(|msg| {
159 msg.try_into()
160 .map_err(|e| CompletionError::RequestError(Box::new(e)))
161 })
162 .collect::<Result<Vec<_>, _>>()?,
163 generation_config: Some(generation_config),
164 safety_settings: None,
165 tools,
166 tool_config: None,
167 system_instruction,
168 additional_params,
169 };
170
171 Ok(request)
172}
173
174impl TryFrom<completion::ToolDefinition> for Tool {
175 type Error = CompletionError;
176
177 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
178 let parameters: Option<Schema> =
179 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
180 None
181 } else {
182 Some(tool.parameters.try_into()?)
183 };
184
185 Ok(Self {
186 function_declarations: vec![FunctionDeclaration {
187 name: tool.name,
188 description: tool.description,
189 parameters,
190 }],
191 code_execution: None,
192 })
193 }
194}
195
196impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
197 type Error = CompletionError;
198
199 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
200 let mut function_declarations = Vec::new();
201
202 for tool in tools {
203 let parameters =
204 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
205 None
206 } else {
207 match tool.parameters.try_into() {
208 Ok(schema) => Some(schema),
209 Err(e) => {
210 let emsg = format!(
211 "Tool '{}' could not be converted to a schema: {:?}",
212 tool.name, e,
213 );
214 return Err(CompletionError::ProviderError(emsg));
215 }
216 }
217 };
218
219 function_declarations.push(FunctionDeclaration {
220 name: tool.name,
221 description: tool.description,
222 parameters,
223 });
224 }
225
226 Ok(Self {
227 function_declarations,
228 code_execution: None,
229 })
230 }
231}
232
233impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
234 type Error = CompletionError;
235
236 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
237 let candidate = response.candidates.first().ok_or_else(|| {
238 CompletionError::ResponseError("No response candidates in response".into())
239 })?;
240
241 let content = candidate
242 .content
243 .parts
244 .iter()
245 .map(|Part { thought, part, .. }| {
246 Ok(match part {
247 PartKind::Text(text) => {
248 if let Some(thought) = thought
249 && *thought
250 {
251 completion::AssistantContent::Reasoning(Reasoning::new(text))
252 } else {
253 completion::AssistantContent::text(text)
254 }
255 }
256 PartKind::FunctionCall(function_call) => {
257 completion::AssistantContent::tool_call(
258 &function_call.name,
259 &function_call.name,
260 function_call.args.clone(),
261 )
262 }
263 _ => {
264 return Err(CompletionError::ResponseError(
265 "Response did not contain a message or tool call".into(),
266 ));
267 }
268 })
269 })
270 .collect::<Result<Vec<_>, _>>()?;
271
272 let choice = OneOrMany::many(content).map_err(|_| {
273 CompletionError::ResponseError(
274 "Response contained no message or tool call (empty)".to_owned(),
275 )
276 })?;
277
278 let usage = response
279 .usage_metadata
280 .as_ref()
281 .map(|usage| completion::Usage {
282 input_tokens: usage.prompt_token_count as u64,
283 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
284 total_tokens: usage.total_token_count as u64,
285 })
286 .unwrap_or_default();
287
288 Ok(completion::CompletionResponse {
289 choice,
290 usage,
291 raw_response: response,
292 })
293 }
294}
295
296pub mod gemini_api_types {
297 use std::{collections::HashMap, convert::Infallible, str::FromStr};
298
299 use serde::{Deserialize, Serialize};
303 use serde_json::{Value, json};
304
305 use crate::message::{ContentFormat, DocumentSourceKind, ImageMediaType};
306 use crate::{
307 OneOrMany,
308 completion::CompletionError,
309 message::{self, MimeType as _, Reasoning, Text},
310 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
311 };
312
313 #[derive(Debug, Deserialize, Serialize, Default)]
314 #[serde(rename_all = "camelCase")]
315 pub struct AdditionalParameters {
316 pub generation_config: GenerationConfig,
318 #[serde(flatten, skip_serializing_if = "Option::is_none")]
320 pub additional_params: Option<serde_json::Value>,
321 }
322
323 impl AdditionalParameters {
324 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
325 self.generation_config = cfg;
326 self
327 }
328
329 pub fn with_params(mut self, params: serde_json::Value) -> Self {
330 self.additional_params = Some(params);
331 self
332 }
333 }
334
335 #[derive(Debug, Deserialize, Serialize)]
343 #[serde(rename_all = "camelCase")]
344 pub struct GenerateContentResponse {
345 pub candidates: Vec<ContentCandidate>,
347 pub prompt_feedback: Option<PromptFeedback>,
349 pub usage_metadata: Option<UsageMetadata>,
351 pub model_version: Option<String>,
352 }
353
354 #[derive(Debug, Deserialize, Serialize)]
356 #[serde(rename_all = "camelCase")]
357 pub struct ContentCandidate {
358 pub content: Content,
360 pub finish_reason: Option<FinishReason>,
363 pub safety_ratings: Option<Vec<SafetyRating>>,
366 pub citation_metadata: Option<CitationMetadata>,
370 pub token_count: Option<i32>,
372 pub avg_logprobs: Option<f64>,
374 pub logprobs_result: Option<LogprobsResult>,
376 pub index: Option<i32>,
378 }
379
380 #[derive(Debug, Deserialize, Serialize)]
381 pub struct Content {
382 #[serde(default)]
384 pub parts: Vec<Part>,
385 pub role: Option<Role>,
388 }
389
390 impl TryFrom<message::Message> for Content {
391 type Error = message::MessageError;
392
393 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
394 Ok(match msg {
395 message::Message::User { content } => Content {
396 parts: content
397 .into_iter()
398 .map(|c| c.try_into())
399 .collect::<Result<Vec<_>, _>>()?,
400 role: Some(Role::User),
401 },
402 message::Message::Assistant { content, .. } => Content {
403 role: Some(Role::Model),
404 parts: content.into_iter().map(|content| content.into()).collect(),
405 },
406 })
407 }
408 }
409
410 impl TryFrom<Content> for message::Message {
411 type Error = message::MessageError;
412
413 fn try_from(content: Content) -> Result<Self, Self::Error> {
414 match content.role {
415 Some(Role::User) | None => {
416 Ok(message::Message::User {
417 content: {
418 let user_content: Result<Vec<_>, _> = content.parts.into_iter()
419 .map(|Part { part, .. }| {
420 Ok(match part {
421 PartKind::Text(text) => message::UserContent::text(text),
422 PartKind::InlineData(inline_data) => {
423 let mime_type =
424 message::MediaType::from_mime_type(&inline_data.mime_type);
425
426 match mime_type {
427 Some(message::MediaType::Image(media_type)) => {
428 message::UserContent::image_base64(
429 inline_data.data,
430 Some(media_type),
431 Some(message::ImageDetail::default()),
432 )
433 }
434 Some(message::MediaType::Document(media_type)) => {
435 message::UserContent::document(
436 inline_data.data,
437 Some(message::ContentFormat::default()),
438 Some(media_type),
439 )
440 }
441 Some(message::MediaType::Audio(media_type)) => {
442 message::UserContent::audio(
443 inline_data.data,
444 Some(message::ContentFormat::default()),
445 Some(media_type),
446 )
447 }
448 _ => {
449 return Err(message::MessageError::ConversionError(
450 format!("Unsupported media type {mime_type:?}"),
451 ));
452 }
453 }
454 }
455 _ => {
456 return Err(message::MessageError::ConversionError(format!(
457 "Unsupported gemini content part type: {part:?}"
458 )));
459 }
460 })
461 })
462 .collect();
463 OneOrMany::many(user_content?).map_err(|_| {
464 message::MessageError::ConversionError(
465 "Failed to create OneOrMany from user content".to_string(),
466 )
467 })?
468 },
469 })
470 }
471 Some(Role::Model) => Ok(message::Message::Assistant {
472 id: None,
473 content: {
474 let assistant_content: Result<Vec<_>, _> = content
475 .parts
476 .into_iter()
477 .map(|Part { thought, part, .. }| {
478 Ok(match part {
479 PartKind::Text(text) => match thought {
480 Some(true) => message::AssistantContent::Reasoning(
481 Reasoning::new(&text),
482 ),
483 _ => message::AssistantContent::Text(Text { text }),
484 },
485
486 PartKind::FunctionCall(function_call) => {
487 message::AssistantContent::ToolCall(function_call.into())
488 }
489 _ => {
490 return Err(message::MessageError::ConversionError(
491 format!("Unsupported part type: {part:?}"),
492 ));
493 }
494 })
495 })
496 .collect();
497 OneOrMany::many(assistant_content?).map_err(|_| {
498 message::MessageError::ConversionError(
499 "Failed to create OneOrMany from assistant content".to_string(),
500 )
501 })?
502 },
503 }),
504 }
505 }
506 }
507
508 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
509 #[serde(rename_all = "lowercase")]
510 pub enum Role {
511 User,
512 Model,
513 }
514
515 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
516 #[serde(rename_all = "camelCase")]
517 pub struct Part {
518 #[serde(skip_serializing_if = "Option::is_none")]
520 pub thought: Option<bool>,
521 #[serde(skip_serializing_if = "Option::is_none")]
523 pub thought_signature: Option<String>,
524 #[serde(flatten)]
525 pub part: PartKind,
526 #[serde(flatten, skip_serializing_if = "Option::is_none")]
527 pub additional_params: Option<Value>,
528 }
529
530 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
534 #[serde(rename_all = "camelCase")]
535 pub enum PartKind {
536 Text(String),
537 InlineData(Blob),
538 FunctionCall(FunctionCall),
539 FunctionResponse(FunctionResponse),
540 FileData(FileData),
541 ExecutableCode(ExecutableCode),
542 CodeExecutionResult(CodeExecutionResult),
543 }
544
545 impl From<String> for Part {
546 fn from(text: String) -> Self {
547 Self {
548 thought: Some(false),
549 thought_signature: None,
550 part: PartKind::Text(text),
551 additional_params: None,
552 }
553 }
554 }
555
556 impl From<&str> for Part {
557 fn from(text: &str) -> Self {
558 Self::from(text.to_string())
559 }
560 }
561
562 impl FromStr for Part {
563 type Err = Infallible;
564
565 fn from_str(s: &str) -> Result<Self, Self::Err> {
566 Ok(s.into())
567 }
568 }
569
570 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
571 type Error = message::MessageError;
572 fn try_from(
573 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
574 ) -> Result<Self, Self::Error> {
575 let mime_type = mime_type.to_mime_type().to_string();
576 let part = match doc_src {
577 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
578 mime_type: Some(mime_type),
579 file_uri: url,
580 }),
581 DocumentSourceKind::Base64(data) => PartKind::InlineData(Blob { mime_type, data }),
582 DocumentSourceKind::Unknown => {
583 return Err(message::MessageError::ConversionError(
584 "Can't convert an unknown document source".to_string(),
585 ));
586 }
587 };
588
589 Ok(part)
590 }
591 }
592
593 impl TryFrom<message::UserContent> for Part {
594 type Error = message::MessageError;
595
596 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
597 match content {
598 message::UserContent::Text(message::Text { text }) => Ok(Part {
599 thought: Some(false),
600 thought_signature: None,
601 part: PartKind::Text(text),
602 additional_params: None,
603 }),
604 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
605 let content = match content.first() {
606 message::ToolResultContent::Text(text) => text.text,
607 message::ToolResultContent::Image(_) => {
608 return Err(message::MessageError::ConversionError(
609 "Tool result content must be text".to_string(),
610 ));
611 }
612 };
613 let result: serde_json::Value =
615 serde_json::from_str(&content).unwrap_or_else(|error| {
616 tracing::trace!(
617 ?error,
618 "Tool result is not a valid JSON, treat it as normal string"
619 );
620 json!(content)
621 });
622 Ok(Part {
623 thought: Some(false),
624 thought_signature: None,
625 part: PartKind::FunctionResponse(FunctionResponse {
626 name: id,
627 response: Some(json!({ "result": result })),
628 }),
629 additional_params: None,
630 })
631 }
632 message::UserContent::Image(message::Image {
633 data, media_type, ..
634 }) => match media_type {
635 Some(media_type) => match media_type {
636 message::ImageMediaType::JPEG
637 | message::ImageMediaType::PNG
638 | message::ImageMediaType::WEBP
639 | message::ImageMediaType::HEIC
640 | message::ImageMediaType::HEIF => {
641 let part = PartKind::try_from((media_type, data))?;
642 Ok(Part {
643 thought: Some(false),
644 thought_signature: None,
645 part,
646 additional_params: None,
647 })
648 }
649 _ => Err(message::MessageError::ConversionError(format!(
650 "Unsupported image media type {media_type:?}"
651 ))),
652 },
653 None => Err(message::MessageError::ConversionError(
654 "Media type for image is required for Gemini".to_string(), )),
656 },
657 message::UserContent::Document(message::Document {
658 data, media_type, ..
659 }) => match media_type {
660 Some(media_type) => match media_type {
661 message::DocumentMediaType::PDF
662 | message::DocumentMediaType::TXT
663 | message::DocumentMediaType::RTF
664 | message::DocumentMediaType::HTML
665 | message::DocumentMediaType::CSS
666 | message::DocumentMediaType::MARKDOWN
667 | message::DocumentMediaType::CSV
668 | message::DocumentMediaType::XML => Ok(Part {
669 thought: Some(false),
670 thought_signature: None,
671 part: PartKind::InlineData(Blob {
672 mime_type: media_type.to_mime_type().to_owned(),
673 data,
674 }),
675 additional_params: None,
676 }),
677 _ => Err(message::MessageError::ConversionError(format!(
678 "Unsupported document media type {media_type:?}"
679 ))),
680 },
681 None => Err(message::MessageError::ConversionError(
682 "Media type for document is required for Gemini".to_string(), )),
684 },
685 message::UserContent::Audio(message::Audio {
686 data, media_type, ..
687 }) => match media_type {
688 Some(media_type) => Ok(Part {
689 thought: Some(false),
690 thought_signature: None,
691 part: PartKind::InlineData(Blob {
692 mime_type: media_type.to_mime_type().to_owned(),
693 data,
694 }),
695 additional_params: None,
696 }),
697 None => Err(message::MessageError::ConversionError(
698 "Media type for audio is required for Gemini".to_string(),
699 )),
700 },
701 message::UserContent::Video(message::Video {
702 data,
703 media_type,
704 format,
705 additional_params,
706 }) => {
707 let mime_type = media_type.map(|m| m.to_mime_type().to_owned());
708
709 let data = match format {
710 Some(ContentFormat::String) => PartKind::FileData(FileData {
711 mime_type,
712 file_uri: data,
713 }),
714 _ => match mime_type {
715 Some(mime_type) => PartKind::InlineData(Blob { mime_type, data }),
716 None => {
717 return Err(message::MessageError::ConversionError(
718 "Media type for video is required for Gemini".to_string(),
719 ));
720 }
721 },
722 };
723
724 Ok(Part {
725 thought: Some(false),
726 thought_signature: None,
727 part: data,
728 additional_params,
729 })
730 }
731 }
732 }
733 }
734
735 impl From<message::AssistantContent> for Part {
736 fn from(content: message::AssistantContent) -> Self {
737 match content {
738 message::AssistantContent::Text(message::Text { text }) => text.into(),
739 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
740 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
741 Part {
742 thought: Some(true),
743 thought_signature: None,
744 part: PartKind::Text(
745 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
746 ),
747 additional_params: None,
748 }
749 }
750 }
751 }
752 }
753
754 impl From<message::ToolCall> for Part {
755 fn from(tool_call: message::ToolCall) -> Self {
756 Self {
757 thought: Some(false),
758 thought_signature: None,
759 part: PartKind::FunctionCall(FunctionCall {
760 name: tool_call.function.name,
761 args: tool_call.function.arguments,
762 }),
763 additional_params: None,
764 }
765 }
766 }
767
768 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
771 #[serde(rename_all = "camelCase")]
772 pub struct Blob {
773 pub mime_type: String,
776 pub data: String,
778 }
779
780 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
783 pub struct FunctionCall {
784 pub name: String,
787 pub args: serde_json::Value,
789 }
790
791 impl From<FunctionCall> for message::ToolCall {
792 fn from(function_call: FunctionCall) -> Self {
793 Self {
794 id: function_call.name.clone(),
795 call_id: None,
796 function: message::ToolFunction {
797 name: function_call.name,
798 arguments: function_call.args,
799 },
800 }
801 }
802 }
803
804 impl From<message::ToolCall> for FunctionCall {
805 fn from(tool_call: message::ToolCall) -> Self {
806 Self {
807 name: tool_call.function.name,
808 args: tool_call.function.arguments,
809 }
810 }
811 }
812
813 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
817 pub struct FunctionResponse {
818 pub name: String,
821 pub response: Option<serde_json::Value>,
823 }
824
825 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
827 #[serde(rename_all = "camelCase")]
828 pub struct FileData {
829 pub mime_type: Option<String>,
831 pub file_uri: String,
833 }
834
835 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
836 pub struct SafetyRating {
837 pub category: HarmCategory,
838 pub probability: HarmProbability,
839 }
840
841 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
842 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
843 pub enum HarmProbability {
844 HarmProbabilityUnspecified,
845 Negligible,
846 Low,
847 Medium,
848 High,
849 }
850
851 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
852 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
853 pub enum HarmCategory {
854 HarmCategoryUnspecified,
855 HarmCategoryDerogatory,
856 HarmCategoryToxicity,
857 HarmCategoryViolence,
858 HarmCategorySexually,
859 HarmCategoryMedical,
860 HarmCategoryDangerous,
861 HarmCategoryHarassment,
862 HarmCategoryHateSpeech,
863 HarmCategorySexuallyExplicit,
864 HarmCategoryDangerousContent,
865 HarmCategoryCivicIntegrity,
866 }
867
868 #[derive(Debug, Deserialize, Clone, Default, Serialize)]
869 #[serde(rename_all = "camelCase")]
870 pub struct UsageMetadata {
871 pub prompt_token_count: i32,
872 #[serde(skip_serializing_if = "Option::is_none")]
873 pub cached_content_token_count: Option<i32>,
874 #[serde(skip_serializing_if = "Option::is_none")]
875 pub candidates_token_count: Option<i32>,
876 pub total_token_count: i32,
877 #[serde(skip_serializing_if = "Option::is_none")]
878 pub thoughts_token_count: Option<i32>,
879 }
880
881 impl std::fmt::Display for UsageMetadata {
882 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
883 write!(
884 f,
885 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
886 self.prompt_token_count,
887 match self.cached_content_token_count {
888 Some(count) => count.to_string(),
889 None => "n/a".to_string(),
890 },
891 match self.candidates_token_count {
892 Some(count) => count.to_string(),
893 None => "n/a".to_string(),
894 },
895 self.total_token_count
896 )
897 }
898 }
899
900 #[derive(Debug, Deserialize, Serialize)]
902 #[serde(rename_all = "camelCase")]
903 pub struct PromptFeedback {
904 pub block_reason: Option<BlockReason>,
906 pub safety_ratings: Option<Vec<SafetyRating>>,
908 }
909
910 #[derive(Debug, Deserialize, Serialize)]
912 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
913 pub enum BlockReason {
914 BlockReasonUnspecified,
916 Safety,
918 Other,
920 Blocklist,
922 ProhibitedContent,
924 }
925
926 #[derive(Debug, Deserialize, Serialize)]
927 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
928 pub enum FinishReason {
929 FinishReasonUnspecified,
931 Stop,
933 MaxTokens,
935 Safety,
937 Recitation,
939 Language,
941 Other,
943 Blocklist,
945 ProhibitedContent,
947 Spii,
949 MalformedFunctionCall,
951 }
952
953 #[derive(Debug, Deserialize, Serialize)]
954 #[serde(rename_all = "camelCase")]
955 pub struct CitationMetadata {
956 pub citation_sources: Vec<CitationSource>,
957 }
958
959 #[derive(Debug, Deserialize, Serialize)]
960 #[serde(rename_all = "camelCase")]
961 pub struct CitationSource {
962 #[serde(skip_serializing_if = "Option::is_none")]
963 pub uri: Option<String>,
964 #[serde(skip_serializing_if = "Option::is_none")]
965 pub start_index: Option<i32>,
966 #[serde(skip_serializing_if = "Option::is_none")]
967 pub end_index: Option<i32>,
968 #[serde(skip_serializing_if = "Option::is_none")]
969 pub license: Option<String>,
970 }
971
972 #[derive(Debug, Deserialize, Serialize)]
973 #[serde(rename_all = "camelCase")]
974 pub struct LogprobsResult {
975 pub top_candidate: Vec<TopCandidate>,
976 pub chosen_candidate: Vec<LogProbCandidate>,
977 }
978
979 #[derive(Debug, Deserialize, Serialize)]
980 pub struct TopCandidate {
981 pub candidates: Vec<LogProbCandidate>,
982 }
983
984 #[derive(Debug, Deserialize, Serialize)]
985 #[serde(rename_all = "camelCase")]
986 pub struct LogProbCandidate {
987 pub token: String,
988 pub token_id: String,
989 pub log_probability: f64,
990 }
991
992 #[derive(Debug, Deserialize, Serialize)]
997 #[serde(rename_all = "camelCase")]
998 pub struct GenerationConfig {
999 #[serde(skip_serializing_if = "Option::is_none")]
1002 pub stop_sequences: Option<Vec<String>>,
1003 #[serde(skip_serializing_if = "Option::is_none")]
1009 pub response_mime_type: Option<String>,
1010 #[serde(skip_serializing_if = "Option::is_none")]
1014 pub response_schema: Option<Schema>,
1015 #[serde(skip_serializing_if = "Option::is_none")]
1018 pub candidate_count: Option<i32>,
1019 #[serde(skip_serializing_if = "Option::is_none")]
1022 pub max_output_tokens: Option<u64>,
1023 #[serde(skip_serializing_if = "Option::is_none")]
1026 pub temperature: Option<f64>,
1027 #[serde(skip_serializing_if = "Option::is_none")]
1034 pub top_p: Option<f64>,
1035 #[serde(skip_serializing_if = "Option::is_none")]
1041 pub top_k: Option<i32>,
1042 #[serde(skip_serializing_if = "Option::is_none")]
1048 pub presence_penalty: Option<f64>,
1049 #[serde(skip_serializing_if = "Option::is_none")]
1057 pub frequency_penalty: Option<f64>,
1058 #[serde(skip_serializing_if = "Option::is_none")]
1060 pub response_logprobs: Option<bool>,
1061 #[serde(skip_serializing_if = "Option::is_none")]
1064 pub logprobs: Option<i32>,
1065 #[serde(skip_serializing_if = "Option::is_none")]
1067 pub thinking_config: Option<ThinkingConfig>,
1068 }
1069
1070 impl Default for GenerationConfig {
1071 fn default() -> Self {
1072 Self {
1073 temperature: Some(1.0),
1074 max_output_tokens: Some(4096),
1075 stop_sequences: None,
1076 response_mime_type: None,
1077 response_schema: None,
1078 candidate_count: None,
1079 top_p: None,
1080 top_k: None,
1081 presence_penalty: None,
1082 frequency_penalty: None,
1083 response_logprobs: None,
1084 logprobs: None,
1085 thinking_config: None,
1086 }
1087 }
1088 }
1089
1090 #[derive(Debug, Deserialize, Serialize)]
1091 #[serde(rename_all = "camelCase")]
1092 pub struct ThinkingConfig {
1093 pub thinking_budget: u32,
1094 pub include_thoughts: Option<bool>,
1095 }
1096 #[derive(Debug, Deserialize, Serialize, Clone)]
1100 pub struct Schema {
1101 pub r#type: String,
1102 #[serde(skip_serializing_if = "Option::is_none")]
1103 pub format: Option<String>,
1104 #[serde(skip_serializing_if = "Option::is_none")]
1105 pub description: Option<String>,
1106 #[serde(skip_serializing_if = "Option::is_none")]
1107 pub nullable: Option<bool>,
1108 #[serde(skip_serializing_if = "Option::is_none")]
1109 pub r#enum: Option<Vec<String>>,
1110 #[serde(skip_serializing_if = "Option::is_none")]
1111 pub max_items: Option<i32>,
1112 #[serde(skip_serializing_if = "Option::is_none")]
1113 pub min_items: Option<i32>,
1114 #[serde(skip_serializing_if = "Option::is_none")]
1115 pub properties: Option<HashMap<String, Schema>>,
1116 #[serde(skip_serializing_if = "Option::is_none")]
1117 pub required: Option<Vec<String>>,
1118 #[serde(skip_serializing_if = "Option::is_none")]
1119 pub items: Option<Box<Schema>>,
1120 }
1121
1122 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1128 let defs = if let Some(obj) = schema.as_object() {
1130 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1131 } else {
1132 None
1133 };
1134
1135 let Some(defs_value) = defs else {
1136 return Ok(schema);
1137 };
1138
1139 let Some(defs_obj) = defs_value.as_object() else {
1140 return Err(CompletionError::ResponseError(
1141 "$defs must be an object".into(),
1142 ));
1143 };
1144
1145 resolve_refs(&mut schema, defs_obj)?;
1146
1147 if let Some(obj) = schema.as_object_mut() {
1149 obj.remove("$defs");
1150 obj.remove("definitions");
1151 }
1152
1153 Ok(schema)
1154 }
1155
1156 fn resolve_refs(
1159 value: &mut Value,
1160 defs: &serde_json::Map<String, Value>,
1161 ) -> Result<(), CompletionError> {
1162 match value {
1163 Value::Object(obj) => {
1164 if let Some(ref_value) = obj.get("$ref")
1165 && let Some(ref_str) = ref_value.as_str()
1166 {
1167 let def_name = parse_ref_path(ref_str)?;
1169
1170 let def = defs.get(&def_name).ok_or_else(|| {
1171 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1172 })?;
1173
1174 let mut resolved = def.clone();
1175 resolve_refs(&mut resolved, defs)?;
1176 *value = resolved;
1177 return Ok(());
1178 }
1179
1180 for (_, v) in obj.iter_mut() {
1181 resolve_refs(v, defs)?;
1182 }
1183 }
1184 Value::Array(arr) => {
1185 for item in arr.iter_mut() {
1186 resolve_refs(item, defs)?;
1187 }
1188 }
1189 _ => {}
1190 }
1191
1192 Ok(())
1193 }
1194
1195 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1201 if let Some(fragment) = ref_str.strip_prefix('#') {
1202 if let Some(name) = fragment.strip_prefix("/$defs/") {
1203 Ok(name.to_string())
1204 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1205 Ok(name.to_string())
1206 } else {
1207 Err(CompletionError::ResponseError(format!(
1208 "Unsupported reference format: {}",
1209 ref_str
1210 )))
1211 }
1212 } else {
1213 Err(CompletionError::ResponseError(format!(
1214 "Only fragment references (#/...) are supported: {}",
1215 ref_str
1216 )))
1217 }
1218 }
1219
1220 impl TryFrom<Value> for Schema {
1221 type Error = CompletionError;
1222
1223 fn try_from(value: Value) -> Result<Self, Self::Error> {
1224 let flattened_val = flatten_schema(value)?;
1225 if let Some(obj) = flattened_val.as_object() {
1226 Ok(Schema {
1227 r#type: obj
1228 .get("type")
1229 .and_then(|v| {
1230 if v.is_string() {
1231 v.as_str().map(String::from)
1232 } else if v.is_array() {
1233 v.as_array()
1234 .and_then(|arr| arr.first())
1235 .and_then(|v| v.as_str().map(String::from))
1236 } else {
1237 None
1238 }
1239 })
1240 .unwrap_or_default(),
1241 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1242 description: obj
1243 .get("description")
1244 .and_then(|v| v.as_str())
1245 .map(String::from),
1246 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1247 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1248 arr.iter()
1249 .filter_map(|v| v.as_str().map(String::from))
1250 .collect()
1251 }),
1252 max_items: obj
1253 .get("maxItems")
1254 .and_then(|v| v.as_i64())
1255 .map(|v| v as i32),
1256 min_items: obj
1257 .get("minItems")
1258 .and_then(|v| v.as_i64())
1259 .map(|v| v as i32),
1260 properties: obj
1261 .get("properties")
1262 .and_then(|v| v.as_object())
1263 .map(|map| {
1264 map.iter()
1265 .filter_map(|(k, v)| {
1266 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1267 })
1268 .collect()
1269 }),
1270 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
1271 arr.iter()
1272 .filter_map(|v| v.as_str().map(String::from))
1273 .collect()
1274 }),
1275 items: obj
1276 .get("items")
1277 .map(|v| Box::new(v.clone().try_into().unwrap())),
1278 })
1279 } else {
1280 Err(CompletionError::ResponseError(
1281 "Expected a JSON object for Schema".into(),
1282 ))
1283 }
1284 }
1285 }
1286
1287 #[derive(Debug, Serialize)]
1288 #[serde(rename_all = "camelCase")]
1289 pub struct GenerateContentRequest {
1290 pub contents: Vec<Content>,
1291 #[serde(skip_serializing_if = "Option::is_none")]
1292 pub tools: Option<Tool>,
1293 pub tool_config: Option<ToolConfig>,
1294 pub generation_config: Option<GenerationConfig>,
1296 pub safety_settings: Option<Vec<SafetySetting>>,
1310 pub system_instruction: Option<Content>,
1313 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1316 pub additional_params: Option<serde_json::Value>,
1317 }
1318
1319 #[derive(Debug, Serialize)]
1320 #[serde(rename_all = "camelCase")]
1321 pub struct Tool {
1322 pub function_declarations: Vec<FunctionDeclaration>,
1323 pub code_execution: Option<CodeExecution>,
1324 }
1325
1326 #[derive(Debug, Serialize, Clone)]
1327 #[serde(rename_all = "camelCase")]
1328 pub struct FunctionDeclaration {
1329 pub name: String,
1330 pub description: String,
1331 #[serde(skip_serializing_if = "Option::is_none")]
1332 pub parameters: Option<Schema>,
1333 }
1334
1335 #[derive(Debug, Serialize)]
1336 #[serde(rename_all = "camelCase")]
1337 pub struct ToolConfig {
1338 pub schema: Option<Schema>,
1339 }
1340
1341 #[derive(Debug, Serialize)]
1342 #[serde(rename_all = "camelCase")]
1343 pub struct CodeExecution {}
1344
1345 #[derive(Debug, Serialize)]
1346 #[serde(rename_all = "camelCase")]
1347 pub struct SafetySetting {
1348 pub category: HarmCategory,
1349 pub threshold: HarmBlockThreshold,
1350 }
1351
1352 #[derive(Debug, Serialize)]
1353 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1354 pub enum HarmBlockThreshold {
1355 HarmBlockThresholdUnspecified,
1356 BlockLowAndAbove,
1357 BlockMediumAndAbove,
1358 BlockOnlyHigh,
1359 BlockNone,
1360 Off,
1361 }
1362}
1363
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    /// A user `Content` containing one part of every supported kind
    /// deserializes each part into the matching `PartKind` variant.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // serde_path_to_error reports exactly which path failed to deserialize.
        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            // Borrow the output instead of cloning the whole result.
            assert_eq!(
                code_execution_result.output.as_ref().unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    /// The `"model"` role deserializes to `Role::Model`.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// A rig user text message converts to a single-text-part user `Content`.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    /// A rig assistant text message converts to a model-role `Content`.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// An assistant tool call converts to a `FunctionCall` part.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    /// An array schema whose items are a `$defs` reference converts into an
    /// array schema with object-typed items.
    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion failure must fail the test, not just be printed.
        let schema: Schema = schema_with_ref
            .try_into()
            .expect("Schema conversion failed");
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    /// A plain object schema keeps its type and properties.
    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    /// Inline (non-reference) array items are converted recursively.
    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        let items = schema.items.expect("Schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    /// Explicitly flattening a `$ref` schema before conversion yields
    /// object-typed items with properties.
    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // Missing items must fail the test rather than silently pass.
        let items = schema
            .items
            .expect("Flattened array schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}