/// Model id `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model id `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model id `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model id `gemini-2.5-flash-preview-05-20`.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// Model id `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model id `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model id `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model id `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model id `gemini-1.5-flash`.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model id `gemini-1.5-pro`.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model id `gemini-1.5-pro-8b`.
///
/// NOTE(review): the 8B Gemini 1.5 model listed by Google is `gemini-1.5-flash-8b`
/// (see [`GEMINI_1_5_FLASH_8B`]). This value is kept unchanged for backward compatibility.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// Model id `gemini-1.5-flash-8b`.
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Model id `gemini-1.0-pro`.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters;
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::{
35 OneOrMany,
36 completion::{self, CompletionError, CompletionRequest},
37};
38use gemini_api_types::{
39 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
40 Role, Tool,
41};
42use serde_json::{Map, Value};
43use std::convert::TryFrom;
44
45use super::Client;
46
47#[derive(Clone)]
52pub struct CompletionModel {
53 pub(crate) client: Client,
54 pub model: String,
55}
56
57impl CompletionModel {
58 pub fn new(client: Client, model: &str) -> Self {
59 Self {
60 client,
61 model: model.to_string(),
62 }
63 }
64}
65
66impl completion::CompletionModel for CompletionModel {
67 type Response = GenerateContentResponse;
68 type StreamingResponse = StreamingCompletionResponse;
69
70 #[cfg_attr(feature = "worker", worker::send)]
71 async fn completion(
72 &self,
73 completion_request: CompletionRequest,
74 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
75 let request = create_request_body(completion_request)?;
76
77 tracing::debug!(
78 "Sending completion request to Gemini API {}",
79 serde_json::to_string_pretty(&request)?
80 );
81
82 let response = self
83 .client
84 .post(&format!("/v1beta/models/{}:generateContent", self.model))
85 .json(&request)
86 .send()
87 .await?;
88
89 if response.status().is_success() {
90 let response = response.json::<GenerateContentResponse>().await?;
91 match response.usage_metadata {
92 Some(ref usage) => tracing::info!(target: "rig",
93 "Gemini completion token usage: {}",
94 usage
95 ),
96 None => tracing::info!(target: "rig",
97 "Gemini completion token usage: n/a",
98 ),
99 }
100
101 tracing::debug!("Received response");
102
103 Ok(completion::CompletionResponse::try_from(response))
104 } else {
105 Err(CompletionError::ProviderError(response.text().await?))
106 }?
107 }
108
109 #[cfg_attr(feature = "worker", worker::send)]
110 async fn stream(
111 &self,
112 request: CompletionRequest,
113 ) -> Result<
114 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
115 CompletionError,
116 > {
117 CompletionModel::stream(self, request).await
118 }
119}
120
121pub(crate) fn create_request_body(
122 completion_request: CompletionRequest,
123) -> Result<GenerateContentRequest, CompletionError> {
124 let mut full_history = Vec::new();
125 full_history.extend(completion_request.chat_history);
126
127 let additional_params = completion_request
128 .additional_params
129 .unwrap_or_else(|| Value::Object(Map::new()));
130
131 let AdditionalParameters {
132 mut generation_config,
133 additional_params,
134 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
135
136 if let Some(temp) = completion_request.temperature {
137 generation_config.temperature = Some(temp);
138 }
139
140 if let Some(max_tokens) = completion_request.max_tokens {
141 generation_config.max_output_tokens = Some(max_tokens);
142 }
143
144 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
145 parts: vec![preamble.into()],
146 role: Some(Role::Model),
147 });
148
149 let tools = if completion_request.tools.is_empty() {
150 None
151 } else {
152 Some(Tool::try_from(completion_request.tools)?)
153 };
154
155 let request = GenerateContentRequest {
156 contents: full_history
157 .into_iter()
158 .map(|msg| {
159 msg.try_into()
160 .map_err(|e| CompletionError::RequestError(Box::new(e)))
161 })
162 .collect::<Result<Vec<_>, _>>()?,
163 generation_config: Some(generation_config),
164 safety_settings: None,
165 tools,
166 tool_config: None,
167 system_instruction,
168 additional_params,
169 };
170
171 Ok(request)
172}
173
174impl TryFrom<completion::ToolDefinition> for Tool {
175 type Error = CompletionError;
176
177 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
178 let parameters: Option<Schema> =
179 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
180 None
181 } else {
182 Some(tool.parameters.try_into()?)
183 };
184
185 Ok(Self {
186 function_declarations: vec![FunctionDeclaration {
187 name: tool.name,
188 description: tool.description,
189 parameters,
190 }],
191 code_execution: None,
192 })
193 }
194}
195
196impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
197 type Error = CompletionError;
198
199 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
200 let mut function_declarations = Vec::new();
201
202 for tool in tools {
203 let parameters =
204 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
205 None
206 } else {
207 match tool.parameters.try_into() {
208 Ok(schema) => Some(schema),
209 Err(e) => {
210 let emsg = format!(
211 "Tool '{}' could not be converted to a schema: {:?}",
212 tool.name, e,
213 );
214 return Err(CompletionError::ProviderError(emsg));
215 }
216 }
217 };
218
219 function_declarations.push(FunctionDeclaration {
220 name: tool.name,
221 description: tool.description,
222 parameters,
223 });
224 }
225
226 Ok(Self {
227 function_declarations,
228 code_execution: None,
229 })
230 }
231}
232
233impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
234 type Error = CompletionError;
235
236 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
237 let candidate = response.candidates.first().ok_or_else(|| {
238 CompletionError::ResponseError("No response candidates in response".into())
239 })?;
240
241 let content = candidate
242 .content
243 .parts
244 .iter()
245 .map(|Part { thought, part, .. }| {
246 Ok(match part {
247 PartKind::Text(text) => {
248 if let Some(thought) = thought
249 && *thought
250 {
251 completion::AssistantContent::Reasoning(Reasoning::new(text))
252 } else {
253 completion::AssistantContent::text(text)
254 }
255 }
256 PartKind::FunctionCall(function_call) => {
257 completion::AssistantContent::tool_call(
258 &function_call.name,
259 &function_call.name,
260 function_call.args.clone(),
261 )
262 }
263 _ => {
264 return Err(CompletionError::ResponseError(
265 "Response did not contain a message or tool call".into(),
266 ));
267 }
268 })
269 })
270 .collect::<Result<Vec<_>, _>>()?;
271
272 let choice = OneOrMany::many(content).map_err(|_| {
273 CompletionError::ResponseError(
274 "Response contained no message or tool call (empty)".to_owned(),
275 )
276 })?;
277
278 let usage = response
279 .usage_metadata
280 .as_ref()
281 .map(|usage| completion::Usage {
282 input_tokens: usage.prompt_token_count as u64,
283 output_tokens: usage.candidates_token_count as u64,
284 total_tokens: usage.total_token_count as u64,
285 })
286 .unwrap_or_default();
287
288 Ok(completion::CompletionResponse {
289 choice,
290 usage,
291 raw_response: response,
292 })
293 }
294}
295
296pub mod gemini_api_types {
297 use std::{collections::HashMap, convert::Infallible, str::FromStr};
298
299 use serde::{Deserialize, Serialize};
303 use serde_json::{Value, json};
304
305 use crate::message::ContentFormat;
306 use crate::{
307 OneOrMany,
308 completion::CompletionError,
309 message::{self, MimeType as _, Reasoning, Text},
310 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
311 };
312
313 #[derive(Debug, Deserialize, Serialize, Default)]
314 #[serde(rename_all = "camelCase")]
315 pub struct AdditionalParameters {
316 pub generation_config: GenerationConfig,
318 #[serde(flatten, skip_serializing_if = "Option::is_none")]
320 pub additional_params: Option<serde_json::Value>,
321 }
322
323 impl AdditionalParameters {
324 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
325 self.generation_config = cfg;
326 self
327 }
328
329 pub fn with_params(mut self, params: serde_json::Value) -> Self {
330 self.additional_params = Some(params);
331 self
332 }
333 }
334
335 #[derive(Debug, Deserialize, Serialize)]
343 #[serde(rename_all = "camelCase")]
344 pub struct GenerateContentResponse {
345 pub candidates: Vec<ContentCandidate>,
347 pub prompt_feedback: Option<PromptFeedback>,
349 pub usage_metadata: Option<UsageMetadata>,
351 pub model_version: Option<String>,
352 }
353
    /// One generated candidate inside a [`GenerateContentResponse`].
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content. This field is required: a candidate without `content` fails to
        /// deserialize. NOTE(review): confirm whether the API can omit it (e.g. on safety stops).
        pub content: Content,
        /// Why generation stopped (`finishReason`), when reported.
        pub finish_reason: Option<FinishReason>,
        /// Per-category safety ratings, when reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation information, when reported.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate (`tokenCount`), when reported.
        pub token_count: Option<i32>,
        /// Average log-probability value (`avgLogprobs`), when reported.
        pub avg_logprobs: Option<f64>,
        /// Detailed log-probability results (`logprobsResult`), when reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Candidate index, when reported.
        pub index: Option<i32>,
    }
379
    /// A Gemini message: an ordered list of parts plus an optional author role.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Message parts. A missing `parts` key deserializes as an empty list.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author role. There is no `skip_serializing_if`, so `None` serializes as `null`.
        /// A missing role is treated as a user message when converting to a rig message.
        pub role: Option<Role>,
    }
389
390 impl TryFrom<message::Message> for Content {
391 type Error = message::MessageError;
392
393 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
394 Ok(match msg {
395 message::Message::User { content } => Content {
396 parts: content
397 .into_iter()
398 .map(|c| c.try_into())
399 .collect::<Result<Vec<_>, _>>()?,
400 role: Some(Role::User),
401 },
402 message::Message::Assistant { content, .. } => Content {
403 role: Some(Role::Model),
404 parts: content.into_iter().map(|content| content.into()).collect(),
405 },
406 })
407 }
408 }
409
    impl TryFrom<Content> for message::Message {
        type Error = message::MessageError;

        /// Converts Gemini content back into a rig message.
        ///
        /// - `Role::User` or a missing role yields a user message.
        ///   - Text parts become text.
        ///   - `inlineData` parts become image, document, or audio content, chosen from the
        ///     MIME type and stored with `ContentFormat::default()`.
        ///   - Any other part kind, or a MIME type that maps to none of those media types,
        ///     is a `ConversionError`.
        /// - `Role::Model` yields an assistant message with `id: None`.
        ///   - Text parts with `thought == Some(true)` become reasoning; other text parts
        ///     become text.
        ///   - Function calls become tool calls.
        ///   - Any other part kind is a `ConversionError`.
        ///
        /// Both branches also fail when there are no parts, because `OneOrMany::many` needs
        /// at least one item.
        fn try_from(content: Content) -> Result<Self, Self::Error> {
            match content.role {
                Some(Role::User) | None => {
                    Ok(message::Message::User {
                        content: {
                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
                                .map(|Part { part, .. }| {
                                    Ok(match part {
                                        PartKind::Text(text) => message::UserContent::text(text),
                                        PartKind::InlineData(inline_data) => {
                                            // Pick the rig media kind from the MIME string.
                                            let mime_type =
                                                message::MediaType::from_mime_type(&inline_data.mime_type);

                                            match mime_type {
                                                Some(message::MediaType::Image(media_type)) => {
                                                    message::UserContent::image(
                                                        inline_data.data,
                                                        Some(message::ContentFormat::default()),
                                                        Some(media_type),
                                                        Some(message::ImageDetail::default()),
                                                    )
                                                }
                                                Some(message::MediaType::Document(media_type)) => {
                                                    message::UserContent::document(
                                                        inline_data.data,
                                                        Some(message::ContentFormat::default()),
                                                        Some(media_type),
                                                    )
                                                }
                                                Some(message::MediaType::Audio(media_type)) => {
                                                    message::UserContent::audio(
                                                        inline_data.data,
                                                        Some(message::ContentFormat::default()),
                                                        Some(media_type),
                                                    )
                                                }
                                                // Unknown MIME types and other media kinds
                                                // (e.g. video) are rejected here.
                                                _ => {
                                                    return Err(message::MessageError::ConversionError(
                                                        format!("Unsupported media type {mime_type:?}"),
                                                    ));
                                                }
                                            }
                                        }
                                        // Function calls/responses, file data and code-execution
                                        // parts are not representable as user content.
                                        _ => {
                                            return Err(message::MessageError::ConversionError(format!(
                                                "Unsupported gemini content part type: {part:?}"
                                            )));
                                        }
                                    })
                                })
                                .collect();
                            OneOrMany::many(user_content?).map_err(|_| {
                                message::MessageError::ConversionError(
                                    "Failed to create OneOrMany from user content".to_string(),
                                )
                            })?
                        },
                    })
                }
                Some(Role::Model) => Ok(message::Message::Assistant {
                    id: None,
                    content: {
                        let assistant_content: Result<Vec<_>, _> = content
                            .parts
                            .into_iter()
                            .map(|Part { thought, part, .. }| {
                                Ok(match part {
                                    // Only an explicit `thought: Some(true)` marks reasoning.
                                    PartKind::Text(text) => match thought {
                                        Some(true) => message::AssistantContent::Reasoning(
                                            Reasoning::new(&text),
                                        ),
                                        _ => message::AssistantContent::Text(Text { text }),
                                    },

                                    PartKind::FunctionCall(function_call) => {
                                        message::AssistantContent::ToolCall(function_call.into())
                                    }
                                    _ => {
                                        return Err(message::MessageError::ConversionError(
                                            format!("Unsupported part type: {part:?}"),
                                        ));
                                    }
                                })
                            })
                            .collect();
                        OneOrMany::many(assistant_content?).map_err(|_| {
                            message::MessageError::ConversionError(
                                "Failed to create OneOrMany from assistant content".to_string(),
                            )
                        })?
                    },
                }),
            }
        }
    }
508
    /// Author of a [`Content`] message, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        User,
        Model,
    }
515
    /// One piece of a [`Content`] message.
    ///
    /// The payload ([`PartKind`]) is flattened into the part object, e.g. `{"text": "..."}`
    /// or `{"functionCall": {...}}`, next to the optional `thought` flags.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Thought flag. Text parts with `Some(true)` are treated as reasoning in this file.
        /// Omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Thought signature string (`thoughtSignature`). Omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The part's payload, flattened into this object.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra keys flattened into the part object. Omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
530
    /// Payload of a [`Part`]. It is externally tagged with camelCase keys (`text`,
    /// `inlineData`, `functionCall`, `functionResponse`, `fileData`, `executableCode`,
    /// `codeExecutionResult`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text (also used for reasoning when the part's `thought` is `Some(true)`).
        Text(String),
        /// Inline data with a MIME type.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// A function result sent back to the model.
        FunctionResponse(FunctionResponse),
        /// Data referenced by URI.
        FileData(FileData),
        /// Code produced for code execution.
        ExecutableCode(ExecutableCode),
        /// Result of executing code.
        CodeExecutionResult(CodeExecutionResult),
    }
545
546 impl From<String> for Part {
547 fn from(text: String) -> Self {
548 Self {
549 thought: Some(false),
550 thought_signature: None,
551 part: PartKind::Text(text),
552 additional_params: None,
553 }
554 }
555 }
556
557 impl From<&str> for Part {
558 fn from(text: &str) -> Self {
559 Self::from(text.to_string())
560 }
561 }
562
563 impl FromStr for Part {
564 type Err = Infallible;
565
566 fn from_str(s: &str) -> Result<Self, Self::Err> {
567 Ok(s.into())
568 }
569 }
570
    impl TryFrom<message::UserContent> for Part {
        type Error = message::MessageError;

        /// Converts one item of user content into a Gemini part. Every produced part has
        /// `thought: Some(false)`.
        ///
        /// - Text becomes a text part.
        /// - Tool results become `functionResponse` parts named by the tool-result id.
        ///   The response is `{"result": <json or string>}`.
        /// - Images (JPEG/PNG/WEBP/HEIC/HEIF), documents (PDF/TXT/RTF/HTML/CSS/MARKDOWN/CSV/XML)
        ///   and audio become `inlineData` and require a media type.
        /// - Video with `ContentFormat::String` becomes `fileData` (data used as the URI).
        ///   Any other video becomes `inlineData`, which requires a media type.
        ///
        /// # Errors
        /// `ConversionError` for a tool result whose first item is an image, an unsupported
        /// image/document media type, or a missing required media type.
        fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
            match content {
                message::UserContent::Text(message::Text { text }) => Ok(Part {
                    thought: Some(false),
                    thought_signature: None,
                    part: PartKind::Text(text),
                    additional_params: None,
                }),
                message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
                    // NOTE(review): only the first tool-result item is used; any further
                    // items are silently dropped.
                    let content = match content.first() {
                        message::ToolResultContent::Text(text) => text.text,
                        message::ToolResultContent::Image(_) => {
                            return Err(message::MessageError::ConversionError(
                                "Tool result content must be text".to_string(),
                            ));
                        }
                    };
                    let result: serde_json::Value =
                        // Prefer structured JSON; fall back to sending the raw text as a string.
                        serde_json::from_str(&content).unwrap_or_else(|error| {
                            tracing::trace!(
                                ?error,
                                "Tool result is not a valid JSON, treat it as normal string"
                            );
                            json!(content)
                        });
                    Ok(Part {
                        thought: Some(false),
                        thought_signature: None,
                        part: PartKind::FunctionResponse(FunctionResponse {
                            // The tool-result id is used as the function name.
                            name: id,
                            response: Some(json!({ "result": result })),
                        }),
                        additional_params: None,
                    })
                }
                // The image `format` is ignored: the data is always sent inline.
                message::UserContent::Image(message::Image {
                    data, media_type, ..
                }) => match media_type {
                    Some(media_type) => match media_type {
                        message::ImageMediaType::JPEG
                        | message::ImageMediaType::PNG
                        | message::ImageMediaType::WEBP
                        | message::ImageMediaType::HEIC
                        | message::ImageMediaType::HEIF => Ok(Part {
                            thought: Some(false),
                            thought_signature: None,
                            part: PartKind::InlineData(Blob {
                                mime_type: media_type.to_mime_type().to_owned(),
                                data,
                            }),
                            additional_params: None,
                        }),
                        _ => Err(message::MessageError::ConversionError(format!(
                            "Unsupported image media type {media_type:?}"
                        ))),
                    },
                    None => Err(message::MessageError::ConversionError(
                        "Media type for image is required for Gemini".to_string(),
                    )),
                },
                message::UserContent::Document(message::Document {
                    data, media_type, ..
                }) => match media_type {
                    Some(media_type) => match media_type {
                        message::DocumentMediaType::PDF
                        | message::DocumentMediaType::TXT
                        | message::DocumentMediaType::RTF
                        | message::DocumentMediaType::HTML
                        | message::DocumentMediaType::CSS
                        | message::DocumentMediaType::MARKDOWN
                        | message::DocumentMediaType::CSV
                        | message::DocumentMediaType::XML => Ok(Part {
                            thought: Some(false),
                            thought_signature: None,
                            part: PartKind::InlineData(Blob {
                                mime_type: media_type.to_mime_type().to_owned(),
                                data,
                            }),
                            additional_params: None,
                        }),
                        _ => Err(message::MessageError::ConversionError(format!(
                            "Unsupported document media type {media_type:?}"
                        ))),
                    },
                    None => Err(message::MessageError::ConversionError(
                        "Media type for document is required for Gemini".to_string(),
                    )),
                },
                // Every audio media type is accepted as-is.
                message::UserContent::Audio(message::Audio {
                    data, media_type, ..
                }) => match media_type {
                    Some(media_type) => Ok(Part {
                        thought: Some(false),
                        thought_signature: None,
                        part: PartKind::InlineData(Blob {
                            mime_type: media_type.to_mime_type().to_owned(),
                            data,
                        }),
                        additional_params: None,
                    }),
                    None => Err(message::MessageError::ConversionError(
                        "Media type for audio is required for Gemini".to_string(),
                    )),
                },
                message::UserContent::Video(message::Video {
                    data,
                    media_type,
                    format,
                    additional_params,
                }) => {
                    let mime_type = media_type.map(|m| m.to_mime_type().to_owned());

                    // String-format video data is treated as a file URI. Anything else is
                    // sent inline and therefore needs a MIME type.
                    let data = match format {
                        Some(ContentFormat::String) => PartKind::FileData(FileData {
                            mime_type,
                            file_uri: data,
                        }),
                        _ => match mime_type {
                            Some(mime_type) => PartKind::InlineData(Blob { mime_type, data }),
                            None => {
                                return Err(message::MessageError::ConversionError(
                                    "Media type for video is required for Gemini".to_string(),
                                ));
                            }
                        },
                    };

                    Ok(Part {
                        thought: Some(false),
                        thought_signature: None,
                        part: data,
                        // Video is the only variant that carries its extra params through.
                        additional_params,
                    })
                }
            }
        }
    }
712
713 impl From<message::AssistantContent> for Part {
714 fn from(content: message::AssistantContent) -> Self {
715 match content {
716 message::AssistantContent::Text(message::Text { text }) => text.into(),
717 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
718 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
719 Part {
720 thought: Some(true),
721 thought_signature: None,
722 part: PartKind::Text(
723 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
724 ),
725 additional_params: None,
726 }
727 }
728 }
729 }
730 }
731
732 impl From<message::ToolCall> for Part {
733 fn from(tool_call: message::ToolCall) -> Self {
734 Self {
735 thought: Some(false),
736 thought_signature: None,
737 part: PartKind::FunctionCall(FunctionCall {
738 name: tool_call.function.name,
739 args: tool_call.function.arguments,
740 }),
741 additional_params: None,
742 }
743 }
744 }
745
    /// Inline data: a MIME type (`mimeType`) plus the payload as a string.
    /// NOTE(review): the payload is presumably base64-encoded; confirm against the API docs.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        pub mime_type: String,
        pub data: String,
    }
757
    /// A function call requested by the model: the function name and its JSON arguments.
    /// There is no call id field; conversions in this file reuse `name` as the tool-call id.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        pub name: String,
        pub args: serde_json::Value,
    }
768
769 impl From<FunctionCall> for message::ToolCall {
770 fn from(function_call: FunctionCall) -> Self {
771 Self {
772 id: function_call.name.clone(),
773 call_id: None,
774 function: message::ToolFunction {
775 name: function_call.name,
776 arguments: function_call.args,
777 },
778 }
779 }
780 }
781
782 impl From<message::ToolCall> for FunctionCall {
783 fn from(tool_call: message::ToolCall) -> Self {
784 Self {
785 name: tool_call.function.name,
786 args: tool_call.function.arguments,
787 }
788 }
789 }
790
    /// Result of a function call sent back to the model.
    /// When built from a rig tool result in this file, `response` is `{"result": <value>}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        pub name: String,
        pub response: Option<serde_json::Value>,
    }
802
    /// Data referenced by URI (`fileUri`), with an optional MIME type (`mimeType`).
    /// In this file it is produced from video content whose format is `ContentFormat::String`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        pub mime_type: Option<String>,
        pub file_uri: String,
    }
812
    /// A safety rating: a harm category with its reported probability level.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
    }
818
    /// Probability level in a [`SafetyRating`], serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `NEGLIGIBLE`, `HIGH`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
828
    /// Harm category used by safety ratings and safety settings, serialized in
    /// SCREAMING_SNAKE_CASE (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    ///
    /// NOTE(review): there is no catch-all variant, so a category string not listed here
    /// makes the whole response fail to deserialize. Confirm that this list is current.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
845
846 #[derive(Debug, Deserialize, Clone, Default, Serialize)]
847 #[serde(rename_all = "camelCase")]
848 pub struct UsageMetadata {
849 pub prompt_token_count: i32,
850 #[serde(skip_serializing_if = "Option::is_none")]
851 pub cached_content_token_count: Option<i32>,
852 pub candidates_token_count: i32,
853 pub total_token_count: i32,
854 #[serde(skip_serializing_if = "Option::is_none")]
855 pub thoughts_token_count: Option<i32>,
856 }
857
858 impl std::fmt::Display for UsageMetadata {
859 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
860 write!(
861 f,
862 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
863 self.prompt_token_count,
864 match self.cached_content_token_count {
865 Some(count) => count.to_string(),
866 None => "n/a".to_string(),
867 },
868 self.candidates_token_count,
869 self.total_token_count
870 )
871 }
872 }
873
    /// Feedback about the prompt itself.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Reason the prompt was blocked (`blockReason`), presumably set only when blocked.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt, when reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
883
    /// Why a prompt was blocked, serialized in SCREAMING_SNAKE_CASE (e.g. `PROHIBITED_CONTENT`).
    ///
    /// NOTE(review): there is no catch-all variant; an unlisted reason fails deserialization.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
899
    /// Why generation of a candidate stopped, serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `STOP`, `MAX_TOKENS`, `MALFORMED_FUNCTION_CALL`).
    ///
    /// NOTE(review): there is no catch-all variant; an unlisted reason fails deserialization
    /// of the whole response. Confirm that this list is current.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
926
    /// Citation information for a candidate. `citationSources` is required when this object
    /// is present.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        pub citation_sources: Vec<CitationSource>,
    }
932
    /// One cited source. Every field is optional and omitted from JSON when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// Source URI.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the cited span (`startIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the cited span (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License string for the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
945
946 #[derive(Debug, Deserialize, Serialize)]
947 #[serde(rename_all = "camelCase")]
948 pub struct LogprobsResult {
949 pub top_candidate: Vec<TopCandidate>,
950 pub chosen_candidate: Vec<LogProbCandidate>,
951 }
952
953 #[derive(Debug, Deserialize, Serialize)]
954 pub struct TopCandidate {
955 pub candidates: Vec<LogProbCandidate>,
956 }
957
958 #[derive(Debug, Deserialize, Serialize)]
959 #[serde(rename_all = "camelCase")]
960 pub struct LogProbCandidate {
961 pub token: String,
962 pub token_id: String,
963 pub log_probability: f64,
964 }
965
    /// Generation options sent as `generationConfig`.
    ///
    /// Every field is optional and omitted from JSON when `None`. Keys are camelCase
    /// (e.g. `max_output_tokens` is sent as `maxOutputTokens`). Request-level `temperature`
    /// and `max_tokens` overwrite the matching fields when building a request.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Stop sequences (`stopSequences`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// Requested response MIME type (`responseMimeType`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Schema for structured responses (`responseSchema`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of candidates to generate (`candidateCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Output token limit (`maxOutputTokens`); set from the request's `max_tokens`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature; set from the request's `temperature`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Nucleus-sampling parameter (`topP`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling parameter (`topK`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty (`presencePenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty (`frequencyPenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether to return log probabilities (`responseLogprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Number of top log probabilities to return (`logprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking options (`thinkingConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
    }
1043
1044 impl Default for GenerationConfig {
1045 fn default() -> Self {
1046 Self {
1047 temperature: Some(1.0),
1048 max_output_tokens: Some(4096),
1049 stop_sequences: None,
1050 response_mime_type: None,
1051 response_schema: None,
1052 candidate_count: None,
1053 top_p: None,
1054 top_k: None,
1055 presence_penalty: None,
1056 frequency_penalty: None,
1057 response_logprobs: None,
1058 logprobs: None,
1059 thinking_config: None,
1060 }
1061 }
1062 }
1063
    /// Thinking options (`thinkingBudget`, `includeThoughts`).
    /// `include_thoughts` has no `skip_serializing_if`, so `None` is sent as `null`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        pub thinking_budget: u32,
        pub include_thoughts: Option<bool>,
    }
    /// Subset of a JSON/OpenAPI schema sent to Gemini for tool parameters and response schemas.
    ///
    /// Optional fields are omitted from JSON when `None`. NOTE(review): no `rename_all` is
    /// visible here, so `max_items`/`min_items` serialize in snake_case, while
    /// `TryFrom<Value>` reads `maxItems`/`minItems`. Confirm that the API accepts this.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Schema type name, serialized as `type`.
        pub r#type: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values, serialized as `enum`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Sub-schemas of object properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1095
1096 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1102 let defs = if let Some(obj) = schema.as_object() {
1104 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1105 } else {
1106 None
1107 };
1108
1109 let Some(defs_value) = defs else {
1110 return Ok(schema);
1111 };
1112
1113 let Some(defs_obj) = defs_value.as_object() else {
1114 return Err(CompletionError::ResponseError(
1115 "$defs must be an object".into(),
1116 ));
1117 };
1118
1119 resolve_refs(&mut schema, defs_obj)?;
1120
1121 if let Some(obj) = schema.as_object_mut() {
1123 obj.remove("$defs");
1124 obj.remove("definitions");
1125 }
1126
1127 Ok(schema)
1128 }
1129
1130 fn resolve_refs(
1133 value: &mut Value,
1134 defs: &serde_json::Map<String, Value>,
1135 ) -> Result<(), CompletionError> {
1136 match value {
1137 Value::Object(obj) => {
1138 if let Some(ref_value) = obj.get("$ref")
1139 && let Some(ref_str) = ref_value.as_str()
1140 {
1141 let def_name = parse_ref_path(ref_str)?;
1143
1144 let def = defs.get(&def_name).ok_or_else(|| {
1145 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1146 })?;
1147
1148 let mut resolved = def.clone();
1149 resolve_refs(&mut resolved, defs)?;
1150 *value = resolved;
1151 return Ok(());
1152 }
1153
1154 for (_, v) in obj.iter_mut() {
1155 resolve_refs(v, defs)?;
1156 }
1157 }
1158 Value::Array(arr) => {
1159 for item in arr.iter_mut() {
1160 resolve_refs(item, defs)?;
1161 }
1162 }
1163 _ => {}
1164 }
1165
1166 Ok(())
1167 }
1168
1169 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1175 if let Some(fragment) = ref_str.strip_prefix('#') {
1176 if let Some(name) = fragment.strip_prefix("/$defs/") {
1177 Ok(name.to_string())
1178 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1179 Ok(name.to_string())
1180 } else {
1181 Err(CompletionError::ResponseError(format!(
1182 "Unsupported reference format: {}",
1183 ref_str
1184 )))
1185 }
1186 } else {
1187 Err(CompletionError::ResponseError(format!(
1188 "Only fragment references (#/...) are supported: {}",
1189 ref_str
1190 )))
1191 }
1192 }
1193
1194 impl TryFrom<Value> for Schema {
1195 type Error = CompletionError;
1196
1197 fn try_from(value: Value) -> Result<Self, Self::Error> {
1198 let flattened_val = flatten_schema(value)?;
1199 if let Some(obj) = flattened_val.as_object() {
1200 Ok(Schema {
1201 r#type: obj
1202 .get("type")
1203 .and_then(|v| {
1204 if v.is_string() {
1205 v.as_str().map(String::from)
1206 } else if v.is_array() {
1207 v.as_array()
1208 .and_then(|arr| arr.first())
1209 .and_then(|v| v.as_str().map(String::from))
1210 } else {
1211 None
1212 }
1213 })
1214 .unwrap_or_default(),
1215 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1216 description: obj
1217 .get("description")
1218 .and_then(|v| v.as_str())
1219 .map(String::from),
1220 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1221 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1222 arr.iter()
1223 .filter_map(|v| v.as_str().map(String::from))
1224 .collect()
1225 }),
1226 max_items: obj
1227 .get("maxItems")
1228 .and_then(|v| v.as_i64())
1229 .map(|v| v as i32),
1230 min_items: obj
1231 .get("minItems")
1232 .and_then(|v| v.as_i64())
1233 .map(|v| v as i32),
1234 properties: obj
1235 .get("properties")
1236 .and_then(|v| v.as_object())
1237 .map(|map| {
1238 map.iter()
1239 .filter_map(|(k, v)| {
1240 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1241 })
1242 .collect()
1243 }),
1244 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
1245 arr.iter()
1246 .filter_map(|v| v.as_str().map(String::from))
1247 .collect()
1248 }),
1249 items: obj
1250 .get("items")
1251 .map(|v| Box::new(v.clone().try_into().unwrap())),
1252 })
1253 } else {
1254 Err(CompletionError::ResponseError(
1255 "Expected a JSON object for Schema".into(),
1256 ))
1257 }
1258 }
1259 }
1260
    /// Request body for the `generateContent` endpoint (camelCase keys).
    ///
    /// Only `tools` and `additional_params` skip serialization when `None`. The other
    /// optional fields are sent as `null` when unset. The keys of `additional_params` are
    /// flattened into the top level of the body.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation history, one entry per message.
        pub contents: Vec<Content>,
        /// NOTE(review): sent as a single `Tool` object rather than an array; confirm the API
        /// accepts this shape.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        pub tool_config: Option<ToolConfig>,
        pub generation_config: Option<GenerationConfig>,
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System instruction (built from the request preamble).
        pub system_instruction: Option<Content>,
        /// Extra top-level keys, flattened into the body.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1292
    /// Tools offered to the model: function declarations plus an optional code-execution
    /// entry. `code_execution` has no `skip_serializing_if`, so `None` is sent as `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        pub function_declarations: Vec<FunctionDeclaration>,
        pub code_execution: Option<CodeExecution>,
    }
1299
1300 #[derive(Debug, Serialize, Clone)]
1301 #[serde(rename_all = "camelCase")]
1302 pub struct FunctionDeclaration {
1303 pub name: String,
1304 pub description: String,
1305 #[serde(skip_serializing_if = "Option::is_none")]
1306 pub parameters: Option<Schema>,
1307 }
1308
1309 #[derive(Debug, Serialize)]
1310 #[serde(rename_all = "camelCase")]
1311 pub struct ToolConfig {
1312 pub schema: Option<Schema>,
1313 }
1314
1315 #[derive(Debug, Serialize)]
1316 #[serde(rename_all = "camelCase")]
1317 pub struct CodeExecution {}
1318
1319 #[derive(Debug, Serialize)]
1320 #[serde(rename_all = "camelCase")]
1321 pub struct SafetySetting {
1322 pub category: HarmCategory,
1323 pub threshold: HarmBlockThreshold,
1324 }
1325
1326 #[derive(Debug, Serialize)]
1327 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1328 pub enum HarmBlockThreshold {
1329 HarmBlockThresholdUnspecified,
1330 BlockLowAndAbove,
1331 BlockMediumAndAbove,
1332 BlockOnlyHigh,
1333 BlockNone,
1334 Off,
1335 }
1336}
1337
1338#[cfg(test)]
1339mod tests {
1340 use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1341
1342 use super::*;
1343 use serde_json::json;
1344
1345 #[test]
1346 fn test_deserialize_message_user() {
1347 let raw_message = r#"{
1348 "parts": [
1349 {"text": "Hello, world!"},
1350 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1351 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1352 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1353 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1354 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1355 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1356 ],
1357 "role": "user"
1358 }"#;
1359
1360 let content: Content = {
1361 let jd = &mut serde_json::Deserializer::from_str(raw_message);
1362 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1363 panic!("Deserialization error at {}: {}", err.path(), err);
1364 })
1365 };
1366 assert_eq!(content.role, Some(Role::User));
1367 assert_eq!(content.parts.len(), 7);
1368
1369 let parts: Vec<Part> = content.parts.into_iter().collect();
1370
1371 if let Part {
1372 part: PartKind::Text(text),
1373 ..
1374 } = &parts[0]
1375 {
1376 assert_eq!(text, "Hello, world!");
1377 } else {
1378 panic!("Expected text part");
1379 }
1380
1381 if let Part {
1382 part: PartKind::InlineData(inline_data),
1383 ..
1384 } = &parts[1]
1385 {
1386 assert_eq!(inline_data.mime_type, "image/png");
1387 assert_eq!(inline_data.data, "base64encodeddata");
1388 } else {
1389 panic!("Expected inline data part");
1390 }
1391
1392 if let Part {
1393 part: PartKind::FunctionCall(function_call),
1394 ..
1395 } = &parts[2]
1396 {
1397 assert_eq!(function_call.name, "test_function");
1398 assert_eq!(
1399 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1400 "value1"
1401 );
1402 } else {
1403 panic!("Expected function call part");
1404 }
1405
1406 if let Part {
1407 part: PartKind::FunctionResponse(function_response),
1408 ..
1409 } = &parts[3]
1410 {
1411 assert_eq!(function_response.name, "test_function");
1412 assert_eq!(
1413 function_response
1414 .response
1415 .as_ref()
1416 .unwrap()
1417 .get("result")
1418 .unwrap(),
1419 "success"
1420 );
1421 } else {
1422 panic!("Expected function response part");
1423 }
1424
1425 if let Part {
1426 part: PartKind::FileData(file_data),
1427 ..
1428 } = &parts[4]
1429 {
1430 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1431 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1432 } else {
1433 panic!("Expected file data part");
1434 }
1435
1436 if let Part {
1437 part: PartKind::ExecutableCode(executable_code),
1438 ..
1439 } = &parts[5]
1440 {
1441 assert_eq!(executable_code.code, "print('Hello, world!')");
1442 } else {
1443 panic!("Expected executable code part");
1444 }
1445
1446 if let Part {
1447 part: PartKind::CodeExecutionResult(code_execution_result),
1448 ..
1449 } = &parts[6]
1450 {
1451 assert_eq!(
1452 code_execution_result.clone().output.unwrap(),
1453 "Hello, world!"
1454 );
1455 } else {
1456 panic!("Expected code execution result part");
1457 }
1458 }
1459
1460 #[test]
1461 fn test_deserialize_message_model() {
1462 let json_data = json!({
1463 "parts": [{"text": "Hello, user!"}],
1464 "role": "model"
1465 });
1466
1467 let content: Content = serde_json::from_value(json_data).unwrap();
1468 assert_eq!(content.role, Some(Role::Model));
1469 assert_eq!(content.parts.len(), 1);
1470 if let Some(Part {
1471 part: PartKind::Text(text),
1472 ..
1473 }) = content.parts.first()
1474 {
1475 assert_eq!(text, "Hello, user!");
1476 } else {
1477 panic!("Expected text part");
1478 }
1479 }
1480
1481 #[test]
1482 fn test_message_conversion_user() {
1483 let msg = message::Message::user("Hello, world!");
1484 let content: Content = msg.try_into().unwrap();
1485 assert_eq!(content.role, Some(Role::User));
1486 assert_eq!(content.parts.len(), 1);
1487 if let Some(Part {
1488 part: PartKind::Text(text),
1489 ..
1490 }) = &content.parts.first()
1491 {
1492 assert_eq!(text, "Hello, world!");
1493 } else {
1494 panic!("Expected text part");
1495 }
1496 }
1497
1498 #[test]
1499 fn test_message_conversion_model() {
1500 let msg = message::Message::assistant("Hello, user!");
1501
1502 let content: Content = msg.try_into().unwrap();
1503 assert_eq!(content.role, Some(Role::Model));
1504 assert_eq!(content.parts.len(), 1);
1505 if let Some(Part {
1506 part: PartKind::Text(text),
1507 ..
1508 }) = &content.parts.first()
1509 {
1510 assert_eq!(text, "Hello, user!");
1511 } else {
1512 panic!("Expected text part");
1513 }
1514 }
1515
1516 #[test]
1517 fn test_message_conversion_tool_call() {
1518 let tool_call = message::ToolCall {
1519 id: "test_tool".to_string(),
1520 call_id: None,
1521 function: message::ToolFunction {
1522 name: "test_function".to_string(),
1523 arguments: json!({"arg1": "value1"}),
1524 },
1525 };
1526
1527 let msg = message::Message::Assistant {
1528 id: None,
1529 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
1530 };
1531
1532 let content: Content = msg.try_into().unwrap();
1533 assert_eq!(content.role, Some(Role::Model));
1534 assert_eq!(content.parts.len(), 1);
1535 if let Some(Part {
1536 part: PartKind::FunctionCall(function_call),
1537 ..
1538 }) = content.parts.first()
1539 {
1540 assert_eq!(function_call.name, "test_function");
1541 assert_eq!(
1542 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1543 "value1"
1544 );
1545 } else {
1546 panic!("Expected function call part");
1547 }
1548 }
1549
1550 #[test]
1551 fn test_vec_schema_conversion() {
1552 let schema_with_ref = json!({
1553 "type": "array",
1554 "items": {
1555 "$ref": "#/$defs/Person"
1556 },
1557 "$defs": {
1558 "Person": {
1559 "type": "object",
1560 "properties": {
1561 "first_name": {
1562 "type": ["string", "null"],
1563 "description": "The person's first name, if provided (null otherwise)"
1564 },
1565 "last_name": {
1566 "type": ["string", "null"],
1567 "description": "The person's last name, if provided (null otherwise)"
1568 },
1569 "job": {
1570 "type": ["string", "null"],
1571 "description": "The person's job, if provided (null otherwise)"
1572 }
1573 },
1574 "required": []
1575 }
1576 }
1577 });
1578
1579 let result: Result<Schema, _> = schema_with_ref.try_into();
1580
1581 match result {
1582 Ok(schema) => {
1583 assert_eq!(schema.r#type, "array");
1584
1585 if let Some(items) = schema.items {
1586 println!("item types: {}", items.r#type);
1587
1588 assert_ne!(items.r#type, "", "Items type should not be empty string!");
1589 assert_eq!(items.r#type, "object", "Items should be object type");
1590 } else {
1591 panic!("Schema should have items field for array type");
1592 }
1593 }
1594 Err(e) => println!("Schema conversion failed: {:?}", e),
1595 }
1596 }
1597
1598 #[test]
1599 fn test_object_schema() {
1600 let simple_schema = json!({
1601 "type": "object",
1602 "properties": {
1603 "name": {
1604 "type": "string"
1605 }
1606 }
1607 });
1608
1609 let schema: Schema = simple_schema.try_into().unwrap();
1610 assert_eq!(schema.r#type, "object");
1611 assert!(schema.properties.is_some());
1612 }
1613
1614 #[test]
1615 fn test_array_with_inline_items() {
1616 let inline_schema = json!({
1617 "type": "array",
1618 "items": {
1619 "type": "object",
1620 "properties": {
1621 "name": {
1622 "type": "string"
1623 }
1624 }
1625 }
1626 });
1627
1628 let schema: Schema = inline_schema.try_into().unwrap();
1629 assert_eq!(schema.r#type, "array");
1630
1631 if let Some(items) = schema.items {
1632 assert_eq!(items.r#type, "object");
1633 assert!(items.properties.is_some());
1634 } else {
1635 panic!("Schema should have items field");
1636 }
1637 }
1638 #[test]
1639 fn test_flattened_schema() {
1640 let ref_schema = json!({
1641 "type": "array",
1642 "items": {
1643 "$ref": "#/$defs/Person"
1644 },
1645 "$defs": {
1646 "Person": {
1647 "type": "object",
1648 "properties": {
1649 "name": { "type": "string" }
1650 }
1651 }
1652 }
1653 });
1654
1655 let flattened = flatten_schema(ref_schema).unwrap();
1656 let schema: Schema = flattened.try_into().unwrap();
1657
1658 assert_eq!(schema.r#type, "array");
1659
1660 if let Some(items) = schema.items {
1661 println!("Flattened items type: '{}'", items.r#type);
1662
1663 assert_eq!(items.r#type, "object");
1664 assert!(items.properties.is_some());
1665 }
1666 }
1667}