/// Model id `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model id `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model id `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model id `gemini-2.5-flash-preview-05-20`.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// Model id `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model id `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model id `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model id `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model id `gemini-1.5-flash`.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model id `gemini-1.5-pro`.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model id `gemini-1.5-pro-8b`.
///
/// NOTE(review): Google's published 8B variant is `gemini-1.5-flash-8b`; this id does not
/// appear to name a public model. It is kept unchanged for backward compatibility — prefer
/// [`GEMINI_1_5_FLASH_8B`].
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// Model id `gemini-1.5-flash-8b`.
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// Model id `gemini-1.0-pro`.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::http_client::HttpClientExt;
32use crate::message::Reasoning;
33use crate::providers::gemini::completion::gemini_api_types::{
34 AdditionalParameters, FunctionCallingMode, ToolConfig,
35};
36use crate::providers::gemini::streaming::StreamingCompletionResponse;
37use crate::telemetry::SpanCombinator;
38use crate::{
39 OneOrMany,
40 completion::{self, CompletionError, CompletionRequest},
41};
42use gemini_api_types::{
43 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
44 Role, Tool,
45};
46use serde_json::{Map, Value};
47use std::convert::TryFrom;
48use tracing::info_span;
49
50use super::Client;
51
/// A Gemini completion model: a [`Client`] plus the id of the model to call.
///
/// `T` is the HTTP client type carried by `Client<T>`; it defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Client used to send `generateContent` requests.
    pub(crate) client: Client<T>,
    /// Model id interpolated into `/v1beta/models/{model}:generateContent`.
    pub model: String,
}
61
62impl<T> CompletionModel<T> {
63 pub fn new(client: Client<T>, model: &str) -> Self {
64 Self {
65 client,
66 model: model.to_string(),
67 }
68 }
69}
70
71impl<T> completion::CompletionModel for CompletionModel<T>
72where
73 T: HttpClientExt + Clone + 'static,
74{
75 type Response = GenerateContentResponse;
76 type StreamingResponse = StreamingCompletionResponse;
77
78 #[cfg_attr(feature = "worker", worker::send)]
79 async fn completion(
80 &self,
81 completion_request: CompletionRequest,
82 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
83 let span = if tracing::Span::current().is_disabled() {
84 info_span!(
85 target: "rig::completions",
86 "generate_content",
87 gen_ai.operation.name = "generate_content",
88 gen_ai.provider.name = "gcp.gemini",
89 gen_ai.request.model = self.model,
90 gen_ai.system_instructions = &completion_request.preamble,
91 gen_ai.response.id = tracing::field::Empty,
92 gen_ai.response.model = tracing::field::Empty,
93 gen_ai.usage.output_tokens = tracing::field::Empty,
94 gen_ai.usage.input_tokens = tracing::field::Empty,
95 gen_ai.input.messages = tracing::field::Empty,
96 gen_ai.output.messages = tracing::field::Empty,
97 )
98 } else {
99 tracing::Span::current()
100 };
101
102 let request = create_request_body(completion_request)?;
103 span.record_model_input(&request.contents);
104
105 span.record_model_input(&request.contents);
106
107 tracing::debug!(
108 "Sending completion request to Gemini API {}",
109 serde_json::to_string_pretty(&request)?
110 );
111
112 let body = serde_json::to_vec(&request)?;
113
114 let request = self
115 .client
116 .post(&format!("/v1beta/models/{}:generateContent", self.model))
117 .header("Content-Type", "application/json")
118 .body(body)
119 .map_err(|e| CompletionError::HttpError(e.into()))?;
120
121 let response = self.client.send::<_, Vec<u8>>(request).await?;
122
123 if response.status().is_success() {
124 let response_body = response
125 .into_body()
126 .await
127 .map_err(CompletionError::HttpError)?;
128
129 let response_text = String::from_utf8_lossy(&response_body).to_string();
130 tracing::debug!("Received raw response from Gemini API: {}", response_text);
131
132 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
133 .map_err(|err| {
134 tracing::error!(
135 error = %err,
136 body = %response_text,
137 "Failed to deserialize Gemini completion response"
138 );
139 CompletionError::JsonError(err)
140 })?;
141
142 match response.usage_metadata {
143 Some(ref usage) => tracing::info!(target: "rig",
144 "Gemini completion token usage: {}",
145 usage
146 ),
147 None => tracing::info!(target: "rig",
148 "Gemini completion token usage: n/a",
149 ),
150 }
151
152 let span = tracing::Span::current();
153 span.record_model_output(&response.candidates);
154 span.record_response_metadata(&response);
155 span.record_token_usage(&response.usage_metadata);
156
157 tracing::debug!(
158 "Received response from Gemini API: {}",
159 serde_json::to_string_pretty(&response)?
160 );
161
162 response.try_into()
163 } else {
164 let text = String::from_utf8_lossy(
165 &response
166 .into_body()
167 .await
168 .map_err(CompletionError::HttpError)?,
169 )
170 .into();
171
172 Err(CompletionError::ProviderError(text))
173 }
174 }
175
176 #[cfg_attr(feature = "worker", worker::send)]
177 async fn stream(
178 &self,
179 request: CompletionRequest,
180 ) -> Result<
181 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
182 CompletionError,
183 > {
184 CompletionModel::stream(self, request).await
185 }
186}
187
188pub(crate) fn create_request_body(
189 completion_request: CompletionRequest,
190) -> Result<GenerateContentRequest, CompletionError> {
191 let mut full_history = Vec::new();
192 full_history.extend(completion_request.chat_history);
193
194 let additional_params = completion_request
195 .additional_params
196 .unwrap_or_else(|| Value::Object(Map::new()));
197
198 let AdditionalParameters {
199 mut generation_config,
200 additional_params,
201 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
202
203 if let Some(temp) = completion_request.temperature {
204 generation_config.temperature = Some(temp);
205 }
206
207 if let Some(max_tokens) = completion_request.max_tokens {
208 generation_config.max_output_tokens = Some(max_tokens);
209 }
210
211 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
212 parts: vec![preamble.into()],
213 role: Some(Role::Model),
214 });
215
216 let tools = if completion_request.tools.is_empty() {
217 None
218 } else {
219 Some(Tool::try_from(completion_request.tools)?)
220 };
221
222 let tool_config = if let Some(cfg) = completion_request.tool_choice {
223 Some(ToolConfig {
224 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
225 })
226 } else {
227 None
228 };
229
230 let request = GenerateContentRequest {
231 contents: full_history
232 .into_iter()
233 .map(|msg| {
234 msg.try_into()
235 .map_err(|e| CompletionError::RequestError(Box::new(e)))
236 })
237 .collect::<Result<Vec<_>, _>>()?,
238 generation_config: Some(generation_config),
239 safety_settings: None,
240 tools,
241 tool_config,
242 system_instruction,
243 additional_params,
244 };
245
246 Ok(request)
247}
248
249impl TryFrom<completion::ToolDefinition> for Tool {
250 type Error = CompletionError;
251
252 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
253 let parameters: Option<Schema> =
254 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
255 None
256 } else {
257 Some(tool.parameters.try_into()?)
258 };
259
260 Ok(Self {
261 function_declarations: vec![FunctionDeclaration {
262 name: tool.name,
263 description: tool.description,
264 parameters,
265 }],
266 code_execution: None,
267 })
268 }
269}
270
271impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
272 type Error = CompletionError;
273
274 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
275 let mut function_declarations = Vec::new();
276
277 for tool in tools {
278 let parameters =
279 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
280 None
281 } else {
282 match tool.parameters.try_into() {
283 Ok(schema) => Some(schema),
284 Err(e) => {
285 let emsg = format!(
286 "Tool '{}' could not be converted to a schema: {:?}",
287 tool.name, e,
288 );
289 return Err(CompletionError::ProviderError(emsg));
290 }
291 }
292 };
293
294 function_declarations.push(FunctionDeclaration {
295 name: tool.name,
296 description: tool.description,
297 parameters,
298 });
299 }
300
301 Ok(Self {
302 function_declarations,
303 code_execution: None,
304 })
305 }
306}
307
308impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
309 type Error = CompletionError;
310
311 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
312 let candidate = response.candidates.first().ok_or_else(|| {
313 CompletionError::ResponseError("No response candidates in response".into())
314 })?;
315
316 let content = candidate
317 .content
318 .as_ref()
319 .ok_or_else(|| {
320 let reason = candidate
321 .finish_reason
322 .as_ref()
323 .map(|r| format!("finish_reason={r:?}"))
324 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
325 let message = candidate
326 .finish_message
327 .as_deref()
328 .unwrap_or("no finish message provided");
329 CompletionError::ResponseError(format!(
330 "Gemini candidate missing content ({reason}, finish_message={message})"
331 ))
332 })?
333 .parts
334 .iter()
335 .map(|Part { thought, part, .. }| {
336 Ok(match part {
337 PartKind::Text(text) => {
338 if let Some(thought) = thought
339 && *thought
340 {
341 completion::AssistantContent::Reasoning(Reasoning::new(text))
342 } else {
343 completion::AssistantContent::text(text)
344 }
345 }
346 PartKind::FunctionCall(function_call) => {
347 completion::AssistantContent::tool_call(
348 &function_call.name,
349 &function_call.name,
350 function_call.args.clone(),
351 )
352 }
353 _ => {
354 return Err(CompletionError::ResponseError(
355 "Response did not contain a message or tool call".into(),
356 ));
357 }
358 })
359 })
360 .collect::<Result<Vec<_>, _>>()?;
361
362 let choice = OneOrMany::many(content).map_err(|_| {
363 CompletionError::ResponseError(
364 "Response contained no message or tool call (empty)".to_owned(),
365 )
366 })?;
367
368 let usage = response
369 .usage_metadata
370 .as_ref()
371 .map(|usage| completion::Usage {
372 input_tokens: usage.prompt_token_count as u64,
373 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
374 total_tokens: usage.total_token_count as u64,
375 })
376 .unwrap_or_default();
377
378 Ok(completion::CompletionResponse {
379 choice,
380 usage,
381 raw_response: response,
382 })
383 }
384}
385
386pub mod gemini_api_types {
387 use crate::telemetry::ProviderResponseExt;
388 use std::{collections::HashMap, convert::Infallible, str::FromStr};
389
390 use serde::{Deserialize, Serialize};
394 use serde_json::{Value, json};
395
396 use crate::completion::GetTokenUsage;
397 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
398 use crate::{
399 OneOrMany,
400 completion::CompletionError,
401 message::{self, Reasoning, Text},
402 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
403 };
404
405 #[derive(Debug, Deserialize, Serialize, Default)]
406 #[serde(rename_all = "camelCase")]
407 pub struct AdditionalParameters {
408 pub generation_config: GenerationConfig,
410 #[serde(flatten, skip_serializing_if = "Option::is_none")]
412 pub additional_params: Option<serde_json::Value>,
413 }
414
415 impl AdditionalParameters {
416 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
417 self.generation_config = cfg;
418 self
419 }
420
421 pub fn with_params(mut self, params: serde_json::Value) -> Self {
422 self.additional_params = Some(params);
423 self
424 }
425 }
426
    /// Body of a Gemini `generateContent` response (fields are camelCase on the wire).
    // NOTE(review): `response_id` is not optional, so a response without `responseId` fails to
    // deserialize — confirm every endpoint/version used always sends it.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Response identifier (`responseId`).
        pub response_id: String,
        /// Generated candidates; the completion conversion uses only the first one.
        pub candidates: Vec<ContentCandidate>,
        /// Feedback about the prompt (`promptFeedback`), if present.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Token usage (`usageMetadata`), if present.
        pub usage_metadata: Option<UsageMetadata>,
        /// Model version string (`modelVersion`), if present.
        pub model_version: Option<String>,
    }
446
447 impl ProviderResponseExt for GenerateContentResponse {
448 type OutputMessage = ContentCandidate;
449 type Usage = UsageMetadata;
450
451 fn get_response_id(&self) -> Option<String> {
452 Some(self.response_id.clone())
453 }
454
455 fn get_response_model_name(&self) -> Option<String> {
456 None
457 }
458
459 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
460 self.candidates.clone()
461 }
462
463 fn get_text_response(&self) -> Option<String> {
464 let str = self
465 .candidates
466 .iter()
467 .filter_map(|x| {
468 let content = x.content.as_ref()?;
469 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
470 return None;
471 }
472
473 let res = content
474 .parts
475 .iter()
476 .filter_map(|part| {
477 if let PartKind::Text(ref str) = part.part {
478 Some(str.to_owned())
479 } else {
480 None
481 }
482 })
483 .collect::<Vec<String>>()
484 .join("\n");
485
486 Some(res)
487 })
488 .collect::<Vec<String>>()
489 .join("\n");
490
491 if str.is_empty() { None } else { Some(str) }
492 }
493
494 fn get_usage(&self) -> Option<Self::Usage> {
495 self.usage_metadata.clone()
496 }
497 }
498
    /// One generated candidate in a [`GenerateContentResponse`] (camelCase on the wire).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content. It may be absent; the completion conversion then reports
        /// `finish_reason` and `finish_message` in its error.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Why generation stopped (`finishReason`), if reported.
        pub finish_reason: Option<FinishReason>,
        /// Safety ratings (`safetyRatings`), if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation metadata (`citationMetadata`), if reported.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count (`tokenCount`), if reported.
        pub token_count: Option<i32>,
        /// Average log-probability (`avgLogprobs`), if reported.
        pub avg_logprobs: Option<f64>,
        /// Log-probability details (`logprobsResult`), if reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Candidate index (`index`), if reported.
        pub index: Option<i32>,
        /// Human-readable finish message (`finishMessage`), if reported.
        pub finish_message: Option<String>,
    }
527
    /// A Gemini content block: an ordered list of [`Part`]s with an optional [`Role`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Parts of this content; an empty list if the field is missing when deserializing.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Producer of the content. There is no `skip_serializing_if`, so `None` is serialized
        /// as `"role": null`.
        pub role: Option<Role>,
    }
537
538 impl TryFrom<message::Message> for Content {
539 type Error = message::MessageError;
540
541 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
542 Ok(match msg {
543 message::Message::User { content } => Content {
544 parts: content
545 .into_iter()
546 .map(|c| c.try_into())
547 .collect::<Result<Vec<_>, _>>()?,
548 role: Some(Role::User),
549 },
550 message::Message::Assistant { content, .. } => Content {
551 role: Some(Role::Model),
552 parts: content.into_iter().map(|content| content.into()).collect(),
553 },
554 })
555 }
556 }
557
558 impl TryFrom<Content> for message::Message {
559 type Error = message::MessageError;
560
561 fn try_from(content: Content) -> Result<Self, Self::Error> {
562 match content.role {
563 Some(Role::User) | None => {
564 Ok(message::Message::User {
565 content: {
566 let user_content: Result<Vec<_>, _> = content.parts.into_iter()
567 .map(|Part { part, .. }| {
568 Ok(match part {
569 PartKind::Text(text) => message::UserContent::text(text),
570 PartKind::InlineData(inline_data) => {
571 let mime_type =
572 message::MediaType::from_mime_type(&inline_data.mime_type);
573
574 match mime_type {
575 Some(message::MediaType::Image(media_type)) => {
576 message::UserContent::image_base64(
577 inline_data.data,
578 Some(media_type),
579 Some(message::ImageDetail::default()),
580 )
581 }
582 Some(message::MediaType::Document(media_type)) => {
583 message::UserContent::document(
584 inline_data.data,
585 Some(media_type),
586 )
587 }
588 Some(message::MediaType::Audio(media_type)) => {
589 message::UserContent::audio(
590 inline_data.data,
591 Some(media_type),
592 )
593 }
594 _ => {
595 return Err(message::MessageError::ConversionError(
596 format!("Unsupported media type {mime_type:?}"),
597 ));
598 }
599 }
600 }
601 _ => {
602 return Err(message::MessageError::ConversionError(format!(
603 "Unsupported gemini content part type: {part:?}"
604 )));
605 }
606 })
607 })
608 .collect();
609 OneOrMany::many(user_content?).map_err(|_| {
610 message::MessageError::ConversionError(
611 "Failed to create OneOrMany from user content".to_string(),
612 )
613 })?
614 },
615 })
616 }
617 Some(Role::Model) => Ok(message::Message::Assistant {
618 id: None,
619 content: {
620 let assistant_content: Result<Vec<_>, _> = content
621 .parts
622 .into_iter()
623 .map(|Part { thought, part, .. }| {
624 Ok(match part {
625 PartKind::Text(text) => match thought {
626 Some(true) => message::AssistantContent::Reasoning(
627 Reasoning::new(&text),
628 ),
629 _ => message::AssistantContent::Text(Text { text }),
630 },
631
632 PartKind::FunctionCall(function_call) => {
633 message::AssistantContent::ToolCall(function_call.into())
634 }
635 _ => {
636 return Err(message::MessageError::ConversionError(
637 format!("Unsupported part type: {part:?}"),
638 ));
639 }
640 })
641 })
642 .collect();
643 OneOrMany::many(assistant_content?).map_err(|_| {
644 message::MessageError::ConversionError(
645 "Failed to create OneOrMany from assistant content".to_string(),
646 )
647 })?
648 },
649 }),
650 }
651 }
652 }
653
    /// Producer of a [`Content`] block, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content from the user (including tool results sent back to the model).
        User,
        /// Content produced by the model.
        Model,
    }
660
    /// One part of a Gemini [`Content`]: a [`PartKind`] payload plus optional flags.
    ///
    /// `part` and `additional_params` are flattened into the same JSON object, so a text part
    /// looks like `{"thought": false, "text": "..."}` on the wire.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// `Some(true)` marks a thought (reasoning) part; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Opaque thought signature (`thoughtSignature`); omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The payload (text, inline data, function call, ...), flattened into this object.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra fields flattened into the part object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
675
    /// Payload of a [`Part`].
    ///
    /// Externally tagged with camelCase keys — e.g. `{"text": ...}`, `{"inlineData": ...}`,
    /// `{"functionCall": ...}` — and flattened into the enclosing [`Part`].
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Inline encoded data with a mime type.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model.
        FunctionResponse(FunctionResponse),
        /// A reference to a file by URI.
        FileData(FileData),
        /// Code emitted for execution.
        ExecutableCode(ExecutableCode),
        /// Result of executing code.
        CodeExecutionResult(CodeExecutionResult),
    }
690
691 impl Default for PartKind {
694 fn default() -> Self {
695 Self::Text(String::new())
696 }
697 }
698
699 impl From<String> for Part {
700 fn from(text: String) -> Self {
701 Self {
702 thought: Some(false),
703 thought_signature: None,
704 part: PartKind::Text(text),
705 additional_params: None,
706 }
707 }
708 }
709
710 impl From<&str> for Part {
711 fn from(text: &str) -> Self {
712 Self::from(text.to_string())
713 }
714 }
715
716 impl FromStr for Part {
717 type Err = Infallible;
718
719 fn from_str(s: &str) -> Result<Self, Self::Err> {
720 Ok(s.into())
721 }
722 }
723
724 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
725 type Error = message::MessageError;
726 fn try_from(
727 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
728 ) -> Result<Self, Self::Error> {
729 let mime_type = mime_type.to_mime_type().to_string();
730 let part = match doc_src {
731 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
732 mime_type: Some(mime_type),
733 file_uri: url,
734 }),
735 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
736 PartKind::InlineData(Blob { mime_type, data })
737 }
738 DocumentSourceKind::Raw(_) => {
739 return Err(message::MessageError::ConversionError(
740 "Raw files not supported, encode as base64 first".into(),
741 ));
742 }
743 DocumentSourceKind::Unknown => {
744 return Err(message::MessageError::ConversionError(
745 "Can't convert an unknown document source".to_string(),
746 ));
747 }
748 };
749
750 Ok(part)
751 }
752 }
753
754 impl TryFrom<message::UserContent> for Part {
755 type Error = message::MessageError;
756
757 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
758 match content {
759 message::UserContent::Text(message::Text { text }) => Ok(Part {
760 thought: Some(false),
761 thought_signature: None,
762 part: PartKind::Text(text),
763 additional_params: None,
764 }),
765 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
766 let content = match content.first() {
767 message::ToolResultContent::Text(text) => text.text,
768 message::ToolResultContent::Image(_) => {
769 return Err(message::MessageError::ConversionError(
770 "Tool result content must be text".to_string(),
771 ));
772 }
773 };
774 let result: serde_json::Value =
776 serde_json::from_str(&content).unwrap_or_else(|error| {
777 tracing::trace!(
778 ?error,
779 "Tool result is not a valid JSON, treat it as normal string"
780 );
781 json!(content)
782 });
783 Ok(Part {
784 thought: Some(false),
785 thought_signature: None,
786 part: PartKind::FunctionResponse(FunctionResponse {
787 name: id,
788 response: Some(json!({ "result": result })),
789 }),
790 additional_params: None,
791 })
792 }
793 message::UserContent::Image(message::Image {
794 data, media_type, ..
795 }) => match media_type {
796 Some(media_type) => match media_type {
797 message::ImageMediaType::JPEG
798 | message::ImageMediaType::PNG
799 | message::ImageMediaType::WEBP
800 | message::ImageMediaType::HEIC
801 | message::ImageMediaType::HEIF => {
802 let part = PartKind::try_from((media_type, data))?;
803 Ok(Part {
804 thought: Some(false),
805 thought_signature: None,
806 part,
807 additional_params: None,
808 })
809 }
810 _ => Err(message::MessageError::ConversionError(format!(
811 "Unsupported image media type {media_type:?}"
812 ))),
813 },
814 None => Err(message::MessageError::ConversionError(
815 "Media type for image is required for Gemini".to_string(),
816 )),
817 },
818 message::UserContent::Document(message::Document {
819 data, media_type, ..
820 }) => {
821 let Some(media_type) = media_type else {
822 return Err(MessageError::ConversionError(
823 "A mime type is required for document inputs to Gemini".to_string(),
824 ));
825 };
826
827 if !media_type.is_code() {
828 let mime_type = media_type.to_mime_type().to_string();
829
830 let part = match data {
831 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
832 mime_type: Some(mime_type),
833 file_uri,
834 }),
835 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
836 PartKind::InlineData(Blob { mime_type, data })
837 }
838 DocumentSourceKind::Raw(_) => {
839 return Err(message::MessageError::ConversionError(
840 "Raw files not supported, encode as base64 first".into(),
841 ));
842 }
843 _ => {
844 return Err(message::MessageError::ConversionError(
845 "Document has no body".to_string(),
846 ));
847 }
848 };
849
850 Ok(Part {
851 thought: Some(false),
852 part,
853 ..Default::default()
854 })
855 } else {
856 Err(message::MessageError::ConversionError(format!(
857 "Unsupported document media type {media_type:?}"
858 )))
859 }
860 }
861
862 message::UserContent::Audio(message::Audio {
863 data, media_type, ..
864 }) => {
865 let Some(media_type) = media_type else {
866 return Err(MessageError::ConversionError(
867 "A mime type is required for audio inputs to Gemini".to_string(),
868 ));
869 };
870
871 let mime_type = media_type.to_mime_type().to_string();
872
873 let part = match data {
874 DocumentSourceKind::Base64(data) => {
875 PartKind::InlineData(Blob { data, mime_type })
876 }
877
878 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
879 mime_type: Some(mime_type),
880 file_uri,
881 }),
882 DocumentSourceKind::String(_) => {
883 return Err(message::MessageError::ConversionError(
884 "Strings cannot be used as audio files!".into(),
885 ));
886 }
887 DocumentSourceKind::Raw(_) => {
888 return Err(message::MessageError::ConversionError(
889 "Raw files not supported, encode as base64 first".into(),
890 ));
891 }
892 DocumentSourceKind::Unknown => {
893 return Err(message::MessageError::ConversionError(
894 "Content has no body".to_string(),
895 ));
896 }
897 };
898
899 Ok(Part {
900 thought: Some(false),
901 part,
902 ..Default::default()
903 })
904 }
905 message::UserContent::Video(message::Video {
906 data,
907 media_type,
908 additional_params,
909 ..
910 }) => {
911 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
912
913 let part = match data {
914 DocumentSourceKind::Url(file_uri) => {
915 if file_uri.starts_with("https://www.youtube.com") {
916 PartKind::FileData(FileData {
917 mime_type,
918 file_uri,
919 })
920 } else {
921 if mime_type.is_none() {
922 return Err(MessageError::ConversionError(
923 "A mime type is required for non-Youtube video file inputs to Gemini"
924 .to_string(),
925 ));
926 }
927
928 PartKind::FileData(FileData {
929 mime_type,
930 file_uri,
931 })
932 }
933 }
934 DocumentSourceKind::Base64(data) => {
935 let Some(mime_type) = mime_type else {
936 return Err(MessageError::ConversionError(
937 "A media type is expected for base64 encoded strings"
938 .to_string(),
939 ));
940 };
941 PartKind::InlineData(Blob { mime_type, data })
942 }
943 DocumentSourceKind::String(_) => {
944 return Err(message::MessageError::ConversionError(
945 "Strings cannot be used as audio files!".into(),
946 ));
947 }
948 DocumentSourceKind::Raw(_) => {
949 return Err(message::MessageError::ConversionError(
950 "Raw file data not supported, encode as base64 first".into(),
951 ));
952 }
953 DocumentSourceKind::Unknown => {
954 return Err(message::MessageError::ConversionError(
955 "Media type for video is required for Gemini".to_string(),
956 ));
957 }
958 };
959
960 Ok(Part {
961 thought: Some(false),
962 thought_signature: None,
963 part,
964 additional_params,
965 })
966 }
967 }
968 }
969 }
970
971 impl From<message::AssistantContent> for Part {
972 fn from(content: message::AssistantContent) -> Self {
973 match content {
974 message::AssistantContent::Text(message::Text { text }) => text.into(),
975 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
976 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
977 Part {
978 thought: Some(true),
979 thought_signature: None,
980 part: PartKind::Text(
981 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
982 ),
983 additional_params: None,
984 }
985 }
986 }
987 }
988 }
989
990 impl From<message::ToolCall> for Part {
991 fn from(tool_call: message::ToolCall) -> Self {
992 Self {
993 thought: Some(false),
994 thought_signature: None,
995 part: PartKind::FunctionCall(FunctionCall {
996 name: tool_call.function.name,
997 args: tool_call.function.arguments,
998 }),
999 additional_params: None,
1000 }
1001 }
1002 }
1003
    /// Inline data with its mime type (`inlineData`; fields are camelCase on the wire).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// Mime type of `data` (`mimeType`).
        pub mime_type: String,
        /// The encoded payload. Presumably base64: it is built from `DocumentSourceKind::Base64`
        /// (or `String`) sources — TODO confirm for the `String` case.
        pub data: String,
    }
1015
    /// A function call emitted by the model.
    ///
    /// As modeled here there is no call id; conversions to `message::ToolCall` reuse `name`
    /// as the id.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Call arguments as JSON.
        // NOTE(review): not optional, so a call without `args` fails to deserialize — confirm
        // Gemini always sends it (e.g. for zero-argument functions).
        pub args: serde_json::Value,
    }
1026
1027 impl From<FunctionCall> for message::ToolCall {
1028 fn from(function_call: FunctionCall) -> Self {
1029 Self {
1030 id: function_call.name.clone(),
1031 call_id: None,
1032 function: message::ToolFunction {
1033 name: function_call.name,
1034 arguments: function_call.args,
1035 },
1036 }
1037 }
1038 }
1039
1040 impl From<message::ToolCall> for FunctionCall {
1041 fn from(tool_call: message::ToolCall) -> Self {
1042 Self {
1043 name: tool_call.function.name,
1044 args: tool_call.function.arguments,
1045 }
1046 }
1047 }
1048
    /// The result of a function call, sent back to the model as `functionResponse`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Function name. When built from a rig tool result, this holds the tool-call id.
        pub name: String,
        /// Response payload; tool results are wrapped as `{"result": ...}`.
        pub response: Option<serde_json::Value>,
    }
1060
    /// A reference to a file by URI (`fileData`; fields are camelCase on the wire).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Mime type (`mimeType`). There is no `skip_serializing_if`, so `None` is sent as
        /// `null`.
        pub mime_type: Option<String>,
        /// URI of the file (`fileUri`).
        pub file_uri: String,
    }
1070
    /// A safety rating: a harm category and the assessed probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// Category being rated.
        pub category: HarmCategory,
        /// Assessed probability for the category.
        pub probability: HarmProbability,
    }

    /// Harm probability level, serialized in SCREAMING_SNAKE_CASE (e.g. `NEGLIGIBLE`).
    // NOTE(review): there is no catch-all variant, so an unlisted value fails deserialization
    // of the whole response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    /// Harm category, serialized in SCREAMING_SNAKE_CASE (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    // NOTE(review): there is no catch-all variant, so a category not listed here fails
    // deserialization of the whole response — confirm against the current API.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1103
    /// Token accounting returned by Gemini (`usageMetadata`; fields are camelCase on the wire).
    ///
    /// Optional counts are omitted when serializing if absent.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Prompt tokens (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens from cached content (`cachedContentTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Generated candidate tokens (`candidatesTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total tokens (`totalTokenCount`).
        pub total_token_count: i32,
        /// Thinking tokens (`thoughtsTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1116
1117 impl std::fmt::Display for UsageMetadata {
1118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1119 write!(
1120 f,
1121 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1122 self.prompt_token_count,
1123 match self.cached_content_token_count {
1124 Some(count) => count.to_string(),
1125 None => "n/a".to_string(),
1126 },
1127 match self.candidates_token_count {
1128 Some(count) => count.to_string(),
1129 None => "n/a".to_string(),
1130 },
1131 self.total_token_count
1132 )
1133 }
1134 }
1135
1136 impl GetTokenUsage for UsageMetadata {
1137 fn token_usage(&self) -> Option<crate::completion::Usage> {
1138 let mut usage = crate::completion::Usage::new();
1139
1140 usage.input_tokens = self.prompt_token_count as u64;
1141 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1142 + self.candidates_token_count.unwrap_or_default()
1143 + self.thoughts_token_count.unwrap_or_default())
1144 as u64;
1145 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1146
1147 Some(usage)
1148 }
1149 }
1150
    /// Feedback about the prompt (`promptFeedback`; fields are camelCase on the wire).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked (`blockReason`), if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt (`safetyRatings`), if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Reason a prompt was blocked, serialized in SCREAMING_SNAKE_CASE (e.g. `SAFETY`).
    // NOTE(review): there is no catch-all variant, so an unlisted reason fails deserialization.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1176
    /// Why a candidate stopped generating, serialized in SCREAMING_SNAKE_CASE (e.g.
    /// `MAX_TOKENS`, `MALFORMED_FUNCTION_CALL`).
    // NOTE(review): there is no catch-all variant, so a finish reason not listed here fails
    // deserialization of the whole response — confirm against the current API.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
1203
    /// Citation information for a candidate (`citationMetadata`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Cited sources (`citationSources`).
        pub citation_sources: Vec<CitationSource>,
    }

    /// One cited source. Absent fields are omitted when serializing.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// Source URI, if given.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment (`startIndex`), if given.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment (`endIndex`), if given.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source, if given.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1222
1223 #[derive(Clone, Debug, Deserialize, Serialize)]
1224 #[serde(rename_all = "camelCase")]
1225 pub struct LogprobsResult {
1226 pub top_candidate: Vec<TopCandidate>,
1227 pub chosen_candidate: Vec<LogProbCandidate>,
1228 }
1229
    /// The candidate tokens considered at one decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens with their log-probabilities.
        pub candidates: Vec<LogProbCandidate>,
    }
1234
1235 #[derive(Clone, Debug, Deserialize, Serialize)]
1236 #[serde(rename_all = "camelCase")]
1237 pub struct LogProbCandidate {
1238 pub token: String,
1239 pub token_id: String,
1240 pub log_probability: f64,
1241 }
1242
1243 #[derive(Debug, Deserialize, Serialize)]
1248 #[serde(rename_all = "camelCase")]
1249 pub struct GenerationConfig {
1250 #[serde(skip_serializing_if = "Option::is_none")]
1253 pub stop_sequences: Option<Vec<String>>,
1254 #[serde(skip_serializing_if = "Option::is_none")]
1260 pub response_mime_type: Option<String>,
1261 #[serde(skip_serializing_if = "Option::is_none")]
1265 pub response_schema: Option<Schema>,
1266 #[serde(skip_serializing_if = "Option::is_none")]
1269 pub candidate_count: Option<i32>,
1270 #[serde(skip_serializing_if = "Option::is_none")]
1273 pub max_output_tokens: Option<u64>,
1274 #[serde(skip_serializing_if = "Option::is_none")]
1277 pub temperature: Option<f64>,
1278 #[serde(skip_serializing_if = "Option::is_none")]
1285 pub top_p: Option<f64>,
1286 #[serde(skip_serializing_if = "Option::is_none")]
1292 pub top_k: Option<i32>,
1293 #[serde(skip_serializing_if = "Option::is_none")]
1299 pub presence_penalty: Option<f64>,
1300 #[serde(skip_serializing_if = "Option::is_none")]
1308 pub frequency_penalty: Option<f64>,
1309 #[serde(skip_serializing_if = "Option::is_none")]
1311 pub response_logprobs: Option<bool>,
1312 #[serde(skip_serializing_if = "Option::is_none")]
1315 pub logprobs: Option<i32>,
1316 #[serde(skip_serializing_if = "Option::is_none")]
1318 pub thinking_config: Option<ThinkingConfig>,
1319 }
1320
1321 impl Default for GenerationConfig {
1322 fn default() -> Self {
1323 Self {
1324 temperature: Some(1.0),
1325 max_output_tokens: Some(4096),
1326 stop_sequences: None,
1327 response_mime_type: None,
1328 response_schema: None,
1329 candidate_count: None,
1330 top_p: None,
1331 top_k: None,
1332 presence_penalty: None,
1333 frequency_penalty: None,
1334 response_logprobs: None,
1335 logprobs: None,
1336 thinking_config: None,
1337 }
1338 }
1339 }
1340
1341 #[derive(Debug, Deserialize, Serialize)]
1342 #[serde(rename_all = "camelCase")]
1343 pub struct ThinkingConfig {
1344 pub thinking_budget: u32,
1345 pub include_thoughts: Option<bool>,
1346 }
1347 #[derive(Debug, Deserialize, Serialize, Clone)]
1351 pub struct Schema {
1352 pub r#type: String,
1353 #[serde(skip_serializing_if = "Option::is_none")]
1354 pub format: Option<String>,
1355 #[serde(skip_serializing_if = "Option::is_none")]
1356 pub description: Option<String>,
1357 #[serde(skip_serializing_if = "Option::is_none")]
1358 pub nullable: Option<bool>,
1359 #[serde(skip_serializing_if = "Option::is_none")]
1360 pub r#enum: Option<Vec<String>>,
1361 #[serde(skip_serializing_if = "Option::is_none")]
1362 pub max_items: Option<i32>,
1363 #[serde(skip_serializing_if = "Option::is_none")]
1364 pub min_items: Option<i32>,
1365 #[serde(skip_serializing_if = "Option::is_none")]
1366 pub properties: Option<HashMap<String, Schema>>,
1367 #[serde(skip_serializing_if = "Option::is_none")]
1368 pub required: Option<Vec<String>>,
1369 #[serde(skip_serializing_if = "Option::is_none")]
1370 pub items: Option<Box<Schema>>,
1371 }
1372
1373 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1379 let defs = if let Some(obj) = schema.as_object() {
1381 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1382 } else {
1383 None
1384 };
1385
1386 let Some(defs_value) = defs else {
1387 return Ok(schema);
1388 };
1389
1390 let Some(defs_obj) = defs_value.as_object() else {
1391 return Err(CompletionError::ResponseError(
1392 "$defs must be an object".into(),
1393 ));
1394 };
1395
1396 resolve_refs(&mut schema, defs_obj)?;
1397
1398 if let Some(obj) = schema.as_object_mut() {
1400 obj.remove("$defs");
1401 obj.remove("definitions");
1402 }
1403
1404 Ok(schema)
1405 }
1406
1407 fn resolve_refs(
1410 value: &mut Value,
1411 defs: &serde_json::Map<String, Value>,
1412 ) -> Result<(), CompletionError> {
1413 match value {
1414 Value::Object(obj) => {
1415 if let Some(ref_value) = obj.get("$ref")
1416 && let Some(ref_str) = ref_value.as_str()
1417 {
1418 let def_name = parse_ref_path(ref_str)?;
1420
1421 let def = defs.get(&def_name).ok_or_else(|| {
1422 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1423 })?;
1424
1425 let mut resolved = def.clone();
1426 resolve_refs(&mut resolved, defs)?;
1427 *value = resolved;
1428 return Ok(());
1429 }
1430
1431 for (_, v) in obj.iter_mut() {
1432 resolve_refs(v, defs)?;
1433 }
1434 }
1435 Value::Array(arr) => {
1436 for item in arr.iter_mut() {
1437 resolve_refs(item, defs)?;
1438 }
1439 }
1440 _ => {}
1441 }
1442
1443 Ok(())
1444 }
1445
1446 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1452 if let Some(fragment) = ref_str.strip_prefix('#') {
1453 if let Some(name) = fragment.strip_prefix("/$defs/") {
1454 Ok(name.to_string())
1455 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1456 Ok(name.to_string())
1457 } else {
1458 Err(CompletionError::ResponseError(format!(
1459 "Unsupported reference format: {}",
1460 ref_str
1461 )))
1462 }
1463 } else {
1464 Err(CompletionError::ResponseError(format!(
1465 "Only fragment references (#/...) are supported: {}",
1466 ref_str
1467 )))
1468 }
1469 }
1470
1471 fn extract_type(type_value: &Value) -> Option<String> {
1474 if type_value.is_string() {
1475 type_value.as_str().map(String::from)
1476 } else if type_value.is_array() {
1477 type_value
1478 .as_array()
1479 .and_then(|arr| arr.first())
1480 .and_then(|v| v.as_str().map(String::from))
1481 } else {
1482 None
1483 }
1484 }
1485
1486 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1489 composition.as_array().and_then(|arr| {
1490 arr.iter().find_map(|schema| {
1491 if let Some(obj) = schema.as_object() {
1492 if let Some(type_val) = obj.get("type")
1494 && let Some(type_str) = type_val.as_str()
1495 && type_str == "null"
1496 {
1497 return None;
1498 }
1499 obj.get("type").and_then(extract_type).or_else(|| {
1501 if obj.contains_key("properties") {
1502 Some("object".to_string())
1503 } else {
1504 None
1505 }
1506 })
1507 } else {
1508 None
1509 }
1510 })
1511 })
1512 }
1513
1514 fn extract_schema_from_composition(
1517 composition: &Value,
1518 ) -> Option<serde_json::Map<String, Value>> {
1519 composition.as_array().and_then(|arr| {
1520 arr.iter().find_map(|schema| {
1521 if let Some(obj) = schema.as_object()
1522 && let Some(type_val) = obj.get("type")
1523 && let Some(type_str) = type_val.as_str()
1524 {
1525 if type_str == "null" {
1526 return None;
1527 }
1528 Some(obj.clone())
1529 } else {
1530 None
1531 }
1532 })
1533 })
1534 }
1535
1536 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1539 if let Some(type_val) = obj.get("type")
1541 && let Some(type_str) = extract_type(type_val)
1542 {
1543 return type_str;
1544 }
1545
1546 if let Some(any_of) = obj.get("anyOf")
1548 && let Some(type_str) = extract_type_from_composition(any_of)
1549 {
1550 return type_str;
1551 }
1552
1553 if let Some(one_of) = obj.get("oneOf")
1554 && let Some(type_str) = extract_type_from_composition(one_of)
1555 {
1556 return type_str;
1557 }
1558
1559 if let Some(all_of) = obj.get("allOf")
1560 && let Some(type_str) = extract_type_from_composition(all_of)
1561 {
1562 return type_str;
1563 }
1564
1565 if obj.contains_key("properties") {
1567 "object".to_string()
1568 } else {
1569 String::new()
1570 }
1571 }
1572
1573 impl TryFrom<Value> for Schema {
1574 type Error = CompletionError;
1575
1576 fn try_from(value: Value) -> Result<Self, Self::Error> {
1577 let flattened_val = flatten_schema(value)?;
1578 if let Some(obj) = flattened_val.as_object() {
1579 let props_source = if obj.get("properties").is_none() {
1582 if let Some(any_of) = obj.get("anyOf") {
1583 extract_schema_from_composition(any_of)
1584 } else if let Some(one_of) = obj.get("oneOf") {
1585 extract_schema_from_composition(one_of)
1586 } else if let Some(all_of) = obj.get("allOf") {
1587 extract_schema_from_composition(all_of)
1588 } else {
1589 None
1590 }
1591 .unwrap_or(obj.clone())
1592 } else {
1593 obj.clone()
1594 };
1595
1596 Ok(Schema {
1597 r#type: infer_type(obj),
1598 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1599 description: obj
1600 .get("description")
1601 .and_then(|v| v.as_str())
1602 .map(String::from),
1603 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1604 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1605 arr.iter()
1606 .filter_map(|v| v.as_str().map(String::from))
1607 .collect()
1608 }),
1609 max_items: obj
1610 .get("maxItems")
1611 .and_then(|v| v.as_i64())
1612 .map(|v| v as i32),
1613 min_items: obj
1614 .get("minItems")
1615 .and_then(|v| v.as_i64())
1616 .map(|v| v as i32),
1617 properties: props_source
1618 .get("properties")
1619 .and_then(|v| v.as_object())
1620 .map(|map| {
1621 map.iter()
1622 .filter_map(|(k, v)| {
1623 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1624 })
1625 .collect()
1626 }),
1627 required: props_source
1628 .get("required")
1629 .and_then(|v| v.as_array())
1630 .map(|arr| {
1631 arr.iter()
1632 .filter_map(|v| v.as_str().map(String::from))
1633 .collect()
1634 }),
1635 items: obj
1636 .get("items")
1637 .and_then(|v| v.clone().try_into().ok())
1638 .map(Box::new),
1639 })
1640 } else {
1641 Err(CompletionError::ResponseError(
1642 "Expected a JSON object for Schema".into(),
1643 ))
1644 }
1645 }
1646 }
1647
1648 #[derive(Debug, Serialize)]
1649 #[serde(rename_all = "camelCase")]
1650 pub struct GenerateContentRequest {
1651 pub contents: Vec<Content>,
1652 #[serde(skip_serializing_if = "Option::is_none")]
1653 pub tools: Option<Tool>,
1654 pub tool_config: Option<ToolConfig>,
1655 pub generation_config: Option<GenerationConfig>,
1657 pub safety_settings: Option<Vec<SafetySetting>>,
1671 pub system_instruction: Option<Content>,
1674 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1677 pub additional_params: Option<serde_json::Value>,
1678 }
1679
1680 #[derive(Debug, Serialize)]
1681 #[serde(rename_all = "camelCase")]
1682 pub struct Tool {
1683 pub function_declarations: Vec<FunctionDeclaration>,
1684 pub code_execution: Option<CodeExecution>,
1685 }
1686
1687 #[derive(Debug, Serialize, Clone)]
1688 #[serde(rename_all = "camelCase")]
1689 pub struct FunctionDeclaration {
1690 pub name: String,
1691 pub description: String,
1692 #[serde(skip_serializing_if = "Option::is_none")]
1693 pub parameters: Option<Schema>,
1694 }
1695
1696 #[derive(Debug, Serialize, Deserialize)]
1697 #[serde(rename_all = "camelCase")]
1698 pub struct ToolConfig {
1699 pub function_calling_config: Option<FunctionCallingMode>,
1700 }
1701
1702 #[derive(Debug, Serialize, Deserialize, Default)]
1703 #[serde(tag = "mode", rename_all = "UPPERCASE")]
1704 pub enum FunctionCallingMode {
1705 #[default]
1706 Auto,
1707 None,
1708 Any {
1709 #[serde(skip_serializing_if = "Option::is_none")]
1710 allowed_function_names: Option<Vec<String>>,
1711 },
1712 }
1713
1714 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1715 type Error = CompletionError;
1716 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1717 let res = match value {
1718 message::ToolChoice::Auto => Self::Auto,
1719 message::ToolChoice::None => Self::None,
1720 message::ToolChoice::Required => Self::Any {
1721 allowed_function_names: None,
1722 },
1723 message::ToolChoice::Specific { function_names } => Self::Any {
1724 allowed_function_names: Some(function_names),
1725 },
1726 };
1727
1728 Ok(res)
1729 }
1730 }
1731
1732 #[derive(Debug, Serialize)]
1733 pub struct CodeExecution {}
1734
1735 #[derive(Debug, Serialize)]
1736 #[serde(rename_all = "camelCase")]
1737 pub struct SafetySetting {
1738 pub category: HarmCategory,
1739 pub threshold: HarmBlockThreshold,
1740 }
1741
1742 #[derive(Debug, Serialize)]
1743 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1744 pub enum HarmBlockThreshold {
1745 HarmBlockThresholdUnspecified,
1746 BlockLowAndAbove,
1747 BlockMediumAndAbove,
1748 BlockOnlyHigh,
1749 BlockNone,
1750 Off,
1751 }
1752}
1753
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion failure is a test failure, not something to log and ignore.
        let schema: Schema = schema_with_ref
            .try_into()
            .unwrap_or_else(|e| panic!("Schema conversion failed: {:?}", e));
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
        assert_eq!(
            items.properties.as_ref().map(|props| props.len()),
            Some(3),
            "All Person properties should survive $ref resolution"
        );
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        let items = schema.items.expect("Schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        assert!(
            flattened.get("$defs").is_none(),
            "$defs should be removed after flattening"
        );

        let schema: Schema = flattened.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        // Missing items must fail the test instead of silently skipping the checks.
        let items = schema
            .items
            .expect("Flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}