/// Model name `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model name `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model name `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model name `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model name `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model name `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model name `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model name `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27 AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32 OneOrMany,
33 completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37 Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
/// Completion model handle pairing a Gemini [`Client`] with a model name.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name interpolated into the request path
    /// (`/v1beta/models/{model}:generateContent`).
    pub model: String,
}
55
56impl<T> CompletionModel<T> {
57 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58 Self {
59 client,
60 model: model.into(),
61 }
62 }
63
64 pub fn with_model(client: Client<T>, model: &str) -> Self {
65 Self {
66 client,
67 model: model.into(),
68 }
69 }
70}
71
/// Implements the crate's completion trait for Gemini by POSTing to the
/// `generateContent` endpoint of the configured model.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = GenerateContentResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = super::Client<T>;

    /// Creates a model from a borrowed client (which is cloned) and a model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a single (non-streaming) completion request.
    ///
    /// # Errors
    /// Returns a [`CompletionError`] when the request body cannot be built or
    /// serialized, when the HTTP call fails, when the provider answers with a
    /// non-success status (`ProviderError` carrying the response text), or when
    /// the response body cannot be deserialized or converted.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
        // Only open a new span when the current one is disabled; otherwise
        // reuse the current span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "generate_content",
                gen_ai.operation.name = "generate_content",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = create_request_body(completion_request)?;

        // Pretty-printing is only paid for when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Gemini completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let path = format!("/v1beta/models/{}:generateContent", self.model);

        let request = self
            .client
            .post(path.as_str())?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Vec<u8>>(request).await?;

            if response.status().is_success() {
                let response_body = response
                    .into_body()
                    .await
                    .map_err(CompletionError::HttpError)?;

                // Kept as text so a deserialization failure can log the raw body.
                let response_text = String::from_utf8_lossy(&response_body).to_string();

                let response: GenerateContentResponse = serde_json::from_slice(&response_body)
                    .map_err(|err| {
                        tracing::error!(
                            error = %err,
                            body = %response_text,
                            "Failed to deserialize Gemini completion response"
                        );
                        CompletionError::JsonError(err)
                    })?;

                // Inside `.instrument(span)`, so the current span is the one
                // chosen above.
                let span = tracing::Span::current();
                span.record_response_metadata(&response);
                span.record_token_usage(&response.usage_metadata);

                if enabled!(Level::TRACE) {
                    tracing::trace!(
                        target: "rig::completions",
                        "Gemini completion response: {}",
                        serde_json::to_string_pretty(&response)?
                    );
                }

                response.try_into()
            } else {
                // Non-success status: surface the body text as a provider error.
                let text = String::from_utf8_lossy(
                    &response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?,
                )
                .into();

                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Starts a streaming completion by delegating to `CompletionModel::stream`.
    // NOTE(review): presumably resolves to an inherent `stream` method defined
    // outside this section (otherwise this would recurse) — confirm.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}
185
186pub(crate) fn create_request_body(
187 completion_request: CompletionRequest,
188) -> Result<GenerateContentRequest, CompletionError> {
189 let mut full_history = Vec::new();
190 full_history.extend(completion_request.chat_history);
191
192 let additional_params = completion_request
193 .additional_params
194 .unwrap_or_else(|| Value::Object(Map::new()));
195
196 let AdditionalParameters {
197 mut generation_config,
198 additional_params,
199 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
200
201 generation_config = generation_config.map(|mut cfg| {
202 if let Some(temp) = completion_request.temperature {
203 cfg.temperature = Some(temp);
204 };
205
206 if let Some(max_tokens) = completion_request.max_tokens {
207 cfg.max_output_tokens = Some(max_tokens);
208 };
209
210 cfg
211 });
212
213 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
214 parts: vec![preamble.into()],
215 role: Some(Role::Model),
216 });
217
218 let tools = if completion_request.tools.is_empty() {
219 None
220 } else {
221 Some(Tool::try_from(completion_request.tools)?)
222 };
223
224 let tool_config = if let Some(cfg) = completion_request.tool_choice {
225 Some(ToolConfig {
226 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
227 })
228 } else {
229 None
230 };
231
232 let request = GenerateContentRequest {
233 contents: full_history
234 .into_iter()
235 .map(|msg| {
236 msg.try_into()
237 .map_err(|e| CompletionError::RequestError(Box::new(e)))
238 })
239 .collect::<Result<Vec<_>, _>>()?,
240 generation_config,
241 safety_settings: None,
242 tools,
243 tool_config,
244 system_instruction,
245 additional_params,
246 };
247
248 Ok(request)
249}
250
251impl TryFrom<completion::ToolDefinition> for Tool {
252 type Error = CompletionError;
253
254 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
255 let parameters: Option<Schema> =
256 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
257 None
258 } else {
259 Some(tool.parameters.try_into()?)
260 };
261
262 Ok(Self {
263 function_declarations: vec![FunctionDeclaration {
264 name: tool.name,
265 description: tool.description,
266 parameters,
267 }],
268 code_execution: None,
269 })
270 }
271}
272
273impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
274 type Error = CompletionError;
275
276 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
277 let mut function_declarations = Vec::new();
278
279 for tool in tools {
280 let parameters =
281 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
282 None
283 } else {
284 match tool.parameters.try_into() {
285 Ok(schema) => Some(schema),
286 Err(e) => {
287 let emsg = format!(
288 "Tool '{}' could not be converted to a schema: {:?}",
289 tool.name, e,
290 );
291 return Err(CompletionError::ProviderError(emsg));
292 }
293 }
294 };
295
296 function_declarations.push(FunctionDeclaration {
297 name: tool.name,
298 description: tool.description,
299 parameters,
300 });
301 }
302
303 Ok(Self {
304 function_declarations,
305 code_execution: None,
306 })
307 }
308}
309
310impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
311 type Error = CompletionError;
312
313 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
314 let candidate = response.candidates.first().ok_or_else(|| {
315 CompletionError::ResponseError("No response candidates in response".into())
316 })?;
317
318 let content = candidate
319 .content
320 .as_ref()
321 .ok_or_else(|| {
322 let reason = candidate
323 .finish_reason
324 .as_ref()
325 .map(|r| format!("finish_reason={r:?}"))
326 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
327 let message = candidate
328 .finish_message
329 .as_deref()
330 .unwrap_or("no finish message provided");
331 CompletionError::ResponseError(format!(
332 "Gemini candidate missing content ({reason}, finish_message={message})"
333 ))
334 })?
335 .parts
336 .iter()
337 .map(
338 |Part {
339 thought,
340 thought_signature,
341 part,
342 ..
343 }| {
344 Ok(match part {
345 PartKind::Text(text) => {
346 if let Some(thought) = thought
347 && *thought
348 {
349 completion::AssistantContent::Reasoning(Reasoning::new(text))
350 } else {
351 completion::AssistantContent::text(text)
352 }
353 }
354 PartKind::InlineData(inline_data) => {
355 let mime_type =
356 message::MediaType::from_mime_type(&inline_data.mime_type);
357
358 match mime_type {
359 Some(message::MediaType::Image(media_type)) => {
360 message::AssistantContent::image_base64(
361 &inline_data.data,
362 Some(media_type),
363 Some(message::ImageDetail::default()),
364 )
365 }
366 _ => {
367 return Err(CompletionError::ResponseError(format!(
368 "Unsupported media type {mime_type:?}"
369 )));
370 }
371 }
372 }
373 PartKind::FunctionCall(function_call) => {
374 completion::AssistantContent::ToolCall(
375 message::ToolCall::new(
376 function_call.name.clone(),
377 message::ToolFunction::new(
378 function_call.name.clone(),
379 function_call.args.clone(),
380 ),
381 )
382 .with_signature(thought_signature.clone()),
383 )
384 }
385 _ => {
386 return Err(CompletionError::ResponseError(
387 "Response did not contain a message or tool call".into(),
388 ));
389 }
390 })
391 },
392 )
393 .collect::<Result<Vec<_>, _>>()?;
394
395 let choice = OneOrMany::many(content).map_err(|_| {
396 CompletionError::ResponseError(
397 "Response contained no message or tool call (empty)".to_owned(),
398 )
399 })?;
400
401 let usage = response
402 .usage_metadata
403 .as_ref()
404 .map(|usage| completion::Usage {
405 input_tokens: usage.prompt_token_count as u64,
406 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
407 total_tokens: usage.total_token_count as u64,
408 })
409 .unwrap_or_default();
410
411 Ok(completion::CompletionResponse {
412 choice,
413 usage,
414 raw_response: response,
415 })
416 }
417}
418
419pub mod gemini_api_types {
420 use crate::telemetry::ProviderResponseExt;
421 use std::{collections::HashMap, convert::Infallible, str::FromStr};
422
423 use serde::{Deserialize, Serialize};
427 use serde_json::{Value, json};
428
429 use crate::completion::GetTokenUsage;
430 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
431 use crate::{
432 completion::CompletionError,
433 message::{self},
434 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
435 };
436
    /// Shape of the `additional_params` JSON accepted by the Gemini provider.
    ///
    /// Keys are camelCase (`generationConfig`); every key that is not
    /// `generationConfig` is collected into `additional_params` via
    /// `#[serde(flatten)]`.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Optional generation settings (`generationConfig`).
        pub generation_config: Option<GenerationConfig>,
        /// Remaining pass-through parameters, flattened into the parent object.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
446
447 impl AdditionalParameters {
448 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
449 self.generation_config = Some(cfg);
450 self
451 }
452
453 pub fn with_params(mut self, params: serde_json::Value) -> Self {
454 self.additional_params = Some(params);
455 self
456 }
457 }
458
459 #[derive(Debug, Deserialize, Serialize)]
467 #[serde(rename_all = "camelCase")]
468 pub struct GenerateContentResponse {
469 pub response_id: String,
470 pub candidates: Vec<ContentCandidate>,
472 pub prompt_feedback: Option<PromptFeedback>,
474 pub usage_metadata: Option<UsageMetadata>,
476 pub model_version: Option<String>,
477 }
478
479 impl ProviderResponseExt for GenerateContentResponse {
480 type OutputMessage = ContentCandidate;
481 type Usage = UsageMetadata;
482
483 fn get_response_id(&self) -> Option<String> {
484 Some(self.response_id.clone())
485 }
486
487 fn get_response_model_name(&self) -> Option<String> {
488 None
489 }
490
491 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
492 self.candidates.clone()
493 }
494
495 fn get_text_response(&self) -> Option<String> {
496 let str = self
497 .candidates
498 .iter()
499 .filter_map(|x| {
500 let content = x.content.as_ref()?;
501 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
502 return None;
503 }
504
505 let res = content
506 .parts
507 .iter()
508 .filter_map(|part| {
509 if let PartKind::Text(ref str) = part.part {
510 Some(str.to_owned())
511 } else {
512 None
513 }
514 })
515 .collect::<Vec<String>>()
516 .join("\n");
517
518 Some(res)
519 })
520 .collect::<Vec<String>>()
521 .join("\n");
522
523 if str.is_empty() { None } else { Some(str) }
524 }
525
526 fn get_usage(&self) -> Option<Self::Usage> {
527 self.usage_metadata.clone()
528 }
529 }
530
    /// One generated candidate inside a [`GenerateContentResponse`]
    /// (camelCase keys). All fields are optional.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content; omitted from serialization when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Reported finish reason (`finishReason`).
        pub finish_reason: Option<FinishReason>,
        /// Safety ratings attached to the candidate (`safetyRatings`).
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation information (`citationMetadata`).
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate (`tokenCount`).
        pub token_count: Option<i32>,
        /// Average log probability (`avgLogprobs`).
        pub avg_logprobs: Option<f64>,
        /// Detailed log-probability data (`logprobsResult`).
        pub logprobs_result: Option<LogprobsResult>,
        /// Index of the candidate (`index`).
        pub index: Option<i32>,
        /// Free-text message accompanying the finish reason (`finishMessage`).
        pub finish_message: Option<String>,
    }
559
    /// A message exchanged with Gemini: an ordered list of parts plus an
    /// optional role.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Message parts; defaults to an empty list when absent in the input.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the message, if specified.
        pub role: Option<Role>,
    }
569
570 impl TryFrom<message::Message> for Content {
571 type Error = message::MessageError;
572
573 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
574 Ok(match msg {
575 message::Message::User { content } => Content {
576 parts: content
577 .into_iter()
578 .map(|c| c.try_into())
579 .collect::<Result<Vec<_>, _>>()?,
580 role: Some(Role::User),
581 },
582 message::Message::Assistant { content, .. } => Content {
583 role: Some(Role::Model),
584 parts: content
585 .into_iter()
586 .map(|content| content.try_into())
587 .collect::<Result<Vec<_>, _>>()?,
588 },
589 })
590 }
591 }
592
    /// Author of a [`Content`] message; serialized in lowercase
    /// (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content written by the user.
        User,
        /// Content produced by the model.
        Model,
    }
599
    /// A single part of a [`Content`] message (camelCase keys).
    ///
    /// The payload (`part`) and any extra parameters are flattened into the
    /// same JSON object as `thought` / `thoughtSignature`.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether this part is a model thought; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Signature attached to the part (`thoughtSignature`); omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The actual payload, flattened into this object.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra key/value pairs flattened into this object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
614
    /// Payload of a [`Part`]. Variant names are serialized in camelCase
    /// (e.g. `inlineData`, `functionCall`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Inline data with a MIME type.
        InlineData(Blob),
        /// A function call with name and arguments.
        FunctionCall(FunctionCall),
        /// The result of a function call.
        FunctionResponse(FunctionResponse),
        /// Data referenced by URI.
        FileData(FileData),
        /// Executable code payload.
        ExecutableCode(ExecutableCode),
        /// Result of code execution.
        CodeExecutionResult(CodeExecutionResult),
    }
629
630 impl Default for PartKind {
633 fn default() -> Self {
634 Self::Text(String::new())
635 }
636 }
637
638 impl From<String> for Part {
639 fn from(text: String) -> Self {
640 Self {
641 thought: Some(false),
642 thought_signature: None,
643 part: PartKind::Text(text),
644 additional_params: None,
645 }
646 }
647 }
648
649 impl From<&str> for Part {
650 fn from(text: &str) -> Self {
651 Self::from(text.to_string())
652 }
653 }
654
655 impl FromStr for Part {
656 type Err = Infallible;
657
658 fn from_str(s: &str) -> Result<Self, Self::Err> {
659 Ok(s.into())
660 }
661 }
662
663 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
664 type Error = message::MessageError;
665 fn try_from(
666 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
667 ) -> Result<Self, Self::Error> {
668 let mime_type = mime_type.to_mime_type().to_string();
669 let part = match doc_src {
670 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
671 mime_type: Some(mime_type),
672 file_uri: url,
673 }),
674 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
675 PartKind::InlineData(Blob { mime_type, data })
676 }
677 DocumentSourceKind::Raw(_) => {
678 return Err(message::MessageError::ConversionError(
679 "Raw files not supported, encode as base64 first".into(),
680 ));
681 }
682 DocumentSourceKind::Unknown => {
683 return Err(message::MessageError::ConversionError(
684 "Can't convert an unknown document source".to_string(),
685 ));
686 }
687 };
688
689 Ok(part)
690 }
691 }
692
693 impl TryFrom<message::UserContent> for Part {
694 type Error = message::MessageError;
695
696 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
697 match content {
698 message::UserContent::Text(message::Text { text }) => Ok(Part {
699 thought: Some(false),
700 thought_signature: None,
701 part: PartKind::Text(text),
702 additional_params: None,
703 }),
704 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
705 let content = match content.first() {
706 message::ToolResultContent::Text(text) => text.text,
707 message::ToolResultContent::Image(_) => {
708 return Err(message::MessageError::ConversionError(
709 "Tool result content must be text".to_string(),
710 ));
711 }
712 };
713 let result: serde_json::Value =
715 serde_json::from_str(&content).unwrap_or_else(|error| {
716 tracing::trace!(
717 ?error,
718 "Tool result is not a valid JSON, treat it as normal string"
719 );
720 json!(content)
721 });
722 Ok(Part {
723 thought: Some(false),
724 thought_signature: None,
725 part: PartKind::FunctionResponse(FunctionResponse {
726 name: id,
727 response: Some(json!({ "result": result })),
728 }),
729 additional_params: None,
730 })
731 }
732 message::UserContent::Image(message::Image {
733 data, media_type, ..
734 }) => match media_type {
735 Some(media_type) => match media_type {
736 message::ImageMediaType::JPEG
737 | message::ImageMediaType::PNG
738 | message::ImageMediaType::WEBP
739 | message::ImageMediaType::HEIC
740 | message::ImageMediaType::HEIF => {
741 let part = PartKind::try_from((media_type, data))?;
742 Ok(Part {
743 thought: Some(false),
744 thought_signature: None,
745 part,
746 additional_params: None,
747 })
748 }
749 _ => Err(message::MessageError::ConversionError(format!(
750 "Unsupported image media type {media_type:?}"
751 ))),
752 },
753 None => Err(message::MessageError::ConversionError(
754 "Media type for image is required for Gemini".to_string(),
755 )),
756 },
757 message::UserContent::Document(message::Document {
758 data, media_type, ..
759 }) => {
760 let Some(media_type) = media_type else {
761 return Err(MessageError::ConversionError(
762 "A mime type is required for document inputs to Gemini".to_string(),
763 ));
764 };
765
766 if !media_type.is_code() {
767 let mime_type = media_type.to_mime_type().to_string();
768
769 let part = match data {
770 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
771 mime_type: Some(mime_type),
772 file_uri,
773 }),
774 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
775 PartKind::InlineData(Blob { mime_type, data })
776 }
777 DocumentSourceKind::Raw(_) => {
778 return Err(message::MessageError::ConversionError(
779 "Raw files not supported, encode as base64 first".into(),
780 ));
781 }
782 _ => {
783 return Err(message::MessageError::ConversionError(
784 "Document has no body".to_string(),
785 ));
786 }
787 };
788
789 Ok(Part {
790 thought: Some(false),
791 part,
792 ..Default::default()
793 })
794 } else {
795 Err(message::MessageError::ConversionError(format!(
796 "Unsupported document media type {media_type:?}"
797 )))
798 }
799 }
800
801 message::UserContent::Audio(message::Audio {
802 data, media_type, ..
803 }) => {
804 let Some(media_type) = media_type else {
805 return Err(MessageError::ConversionError(
806 "A mime type is required for audio inputs to Gemini".to_string(),
807 ));
808 };
809
810 let mime_type = media_type.to_mime_type().to_string();
811
812 let part = match data {
813 DocumentSourceKind::Base64(data) => {
814 PartKind::InlineData(Blob { data, mime_type })
815 }
816
817 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
818 mime_type: Some(mime_type),
819 file_uri,
820 }),
821 DocumentSourceKind::String(_) => {
822 return Err(message::MessageError::ConversionError(
823 "Strings cannot be used as audio files!".into(),
824 ));
825 }
826 DocumentSourceKind::Raw(_) => {
827 return Err(message::MessageError::ConversionError(
828 "Raw files not supported, encode as base64 first".into(),
829 ));
830 }
831 DocumentSourceKind::Unknown => {
832 return Err(message::MessageError::ConversionError(
833 "Content has no body".to_string(),
834 ));
835 }
836 };
837
838 Ok(Part {
839 thought: Some(false),
840 part,
841 ..Default::default()
842 })
843 }
844 message::UserContent::Video(message::Video {
845 data,
846 media_type,
847 additional_params,
848 ..
849 }) => {
850 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
851
852 let part = match data {
853 DocumentSourceKind::Url(file_uri) => {
854 if file_uri.starts_with("https://www.youtube.com") {
855 PartKind::FileData(FileData {
856 mime_type,
857 file_uri,
858 })
859 } else {
860 if mime_type.is_none() {
861 return Err(MessageError::ConversionError(
862 "A mime type is required for non-Youtube video file inputs to Gemini"
863 .to_string(),
864 ));
865 }
866
867 PartKind::FileData(FileData {
868 mime_type,
869 file_uri,
870 })
871 }
872 }
873 DocumentSourceKind::Base64(data) => {
874 let Some(mime_type) = mime_type else {
875 return Err(MessageError::ConversionError(
876 "A media type is expected for base64 encoded strings"
877 .to_string(),
878 ));
879 };
880 PartKind::InlineData(Blob { mime_type, data })
881 }
882 DocumentSourceKind::String(_) => {
883 return Err(message::MessageError::ConversionError(
884 "Strings cannot be used as audio files!".into(),
885 ));
886 }
887 DocumentSourceKind::Raw(_) => {
888 return Err(message::MessageError::ConversionError(
889 "Raw file data not supported, encode as base64 first".into(),
890 ));
891 }
892 DocumentSourceKind::Unknown => {
893 return Err(message::MessageError::ConversionError(
894 "Media type for video is required for Gemini".to_string(),
895 ));
896 }
897 };
898
899 Ok(Part {
900 thought: Some(false),
901 thought_signature: None,
902 part,
903 additional_params,
904 })
905 }
906 }
907 }
908 }
909
910 impl TryFrom<message::AssistantContent> for Part {
911 type Error = message::MessageError;
912
913 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
914 match content {
915 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
916 message::AssistantContent::Image(message::Image {
917 data, media_type, ..
918 }) => match media_type {
919 Some(media_type) => match media_type {
920 message::ImageMediaType::JPEG
921 | message::ImageMediaType::PNG
922 | message::ImageMediaType::WEBP
923 | message::ImageMediaType::HEIC
924 | message::ImageMediaType::HEIF => {
925 let part = PartKind::try_from((media_type, data))?;
926 Ok(Part {
927 thought: Some(false),
928 thought_signature: None,
929 part,
930 additional_params: None,
931 })
932 }
933 _ => Err(message::MessageError::ConversionError(format!(
934 "Unsupported image media type {media_type:?}"
935 ))),
936 },
937 None => Err(message::MessageError::ConversionError(
938 "Media type for image is required for Gemini".to_string(),
939 )),
940 },
941 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
942 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
943 Ok(Part {
944 thought: Some(true),
945 thought_signature: None,
946 part: PartKind::Text(
947 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
948 ),
949 additional_params: None,
950 })
951 }
952 }
953 }
954 }
955
956 impl From<message::ToolCall> for Part {
957 fn from(tool_call: message::ToolCall) -> Self {
958 Self {
959 thought: Some(false),
960 thought_signature: tool_call.signature,
961 part: PartKind::FunctionCall(FunctionCall {
962 name: tool_call.function.name,
963 args: tool_call.function.arguments,
964 }),
965 additional_params: None,
966 }
967 }
968 }
969
    /// Inline data payload (camelCase keys: `mimeType`, `data`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of the data.
        pub mime_type: String,
        /// The data as a string; the conversions in this module fill it from
        /// base64 or plain-string sources.
        pub data: String,
    }
981
    /// A function call: the function name and its JSON arguments.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Arguments as an arbitrary JSON value.
        pub args: serde_json::Value,
    }
992
993 impl From<message::ToolCall> for FunctionCall {
994 fn from(tool_call: message::ToolCall) -> Self {
995 Self {
996 name: tool_call.function.name,
997 args: tool_call.function.arguments,
998 }
999 }
1000 }
1001
    /// Result of a function call, sent back to the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name associated with the response; the user-content conversion in
        /// this module fills it with the tool result's `id`.
        pub name: String,
        /// JSON payload of the response, if any.
        pub response: Option<serde_json::Value>,
    }
1013
    /// Data referenced by URI (camelCase keys: `mimeType`, `fileUri`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file, when known.
        pub mime_type: Option<String>,
        /// URI of the file.
        pub file_uri: String,
    }
1023
    /// A harm category paired with its probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// Category being rated.
        pub category: HarmCategory,
        /// Probability assigned to that category.
        pub probability: HarmProbability,
    }
1029
    /// Probability level of a [`SafetyRating`]; serialized in
    /// SCREAMING_SNAKE_CASE (e.g. `NEGLIGIBLE`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        /// `HARM_PROBABILITY_UNSPECIFIED`
        HarmProbabilityUnspecified,
        /// `NEGLIGIBLE`
        Negligible,
        /// `LOW`
        Low,
        /// `MEDIUM`
        Medium,
        /// `HIGH`
        High,
    }
1039
    /// Category of a [`SafetyRating`]; serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        /// `HARM_CATEGORY_UNSPECIFIED`
        HarmCategoryUnspecified,
        /// `HARM_CATEGORY_DEROGATORY`
        HarmCategoryDerogatory,
        /// `HARM_CATEGORY_TOXICITY`
        HarmCategoryToxicity,
        /// `HARM_CATEGORY_VIOLENCE`
        HarmCategoryViolence,
        /// `HARM_CATEGORY_SEXUALLY`
        HarmCategorySexually,
        /// `HARM_CATEGORY_MEDICAL`
        HarmCategoryMedical,
        /// `HARM_CATEGORY_DANGEROUS`
        HarmCategoryDangerous,
        /// `HARM_CATEGORY_HARASSMENT`
        HarmCategoryHarassment,
        /// `HARM_CATEGORY_HATE_SPEECH`
        HarmCategoryHateSpeech,
        /// `HARM_CATEGORY_SEXUALLY_EXPLICIT`
        HarmCategorySexuallyExplicit,
        /// `HARM_CATEGORY_DANGEROUS_CONTENT`
        HarmCategoryDangerousContent,
        /// `HARM_CATEGORY_CIVIC_INTEGRITY`
        HarmCategoryCivicIntegrity,
    }
1056
    /// Token accounting reported with a response (camelCase keys).
    ///
    /// `prompt_token_count` and `total_token_count` are required when
    /// deserializing; the optional counts are omitted from serialization when
    /// `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens served from cached content (`cachedContentTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates (`candidatesTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total tokens reported by the API (`totalTokenCount`).
        pub total_token_count: i32,
        /// Tokens spent on thoughts (`thoughtsTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1069
1070 impl std::fmt::Display for UsageMetadata {
1071 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1072 write!(
1073 f,
1074 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1075 self.prompt_token_count,
1076 match self.cached_content_token_count {
1077 Some(count) => count.to_string(),
1078 None => "n/a".to_string(),
1079 },
1080 match self.candidates_token_count {
1081 Some(count) => count.to_string(),
1082 None => "n/a".to_string(),
1083 },
1084 self.total_token_count
1085 )
1086 }
1087 }
1088
1089 impl GetTokenUsage for UsageMetadata {
1090 fn token_usage(&self) -> Option<crate::completion::Usage> {
1091 let mut usage = crate::completion::Usage::new();
1092
1093 usage.input_tokens = self.prompt_token_count as u64;
1094 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1095 + self.candidates_token_count.unwrap_or_default()
1096 + self.thoughts_token_count.unwrap_or_default())
1097 as u64;
1098 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1099
1100 Some(usage)
1101 }
1102 }
1103
    /// Feedback about the prompt attached to a response (camelCase keys).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Reason the prompt was blocked, if any (`blockReason`).
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt (`safetyRatings`).
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1113
    /// Reason carried by [`PromptFeedback`]; serialized in
    /// SCREAMING_SNAKE_CASE.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// `BLOCK_REASON_UNSPECIFIED`
        BlockReasonUnspecified,
        /// `SAFETY`
        Safety,
        /// `OTHER`
        Other,
        /// `BLOCKLIST`
        Blocklist,
        /// `PROHIBITED_CONTENT`
        ProhibitedContent,
    }
1129
    /// Finish reason reported on a [`ContentCandidate`]; serialized in
    /// SCREAMING_SNAKE_CASE.
    // NOTE(review): a value outside this list fails deserialization because
    // there is no catch-all variant — confirm the list covers what the API sends.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// `FINISH_REASON_UNSPECIFIED`
        FinishReasonUnspecified,
        /// `STOP`
        Stop,
        /// `MAX_TOKENS`
        MaxTokens,
        /// `SAFETY`
        Safety,
        /// `RECITATION`
        Recitation,
        /// `LANGUAGE`
        Language,
        /// `OTHER`
        Other,
        /// `BLOCKLIST`
        Blocklist,
        /// `PROHIBITED_CONTENT`
        ProhibitedContent,
        /// `SPII`
        Spii,
        /// `MALFORMED_FUNCTION_CALL`
        MalformedFunctionCall,
    }
1156
    /// Citation information for a candidate (camelCase key: `citationSources`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources.
        pub citation_sources: Vec<CitationSource>,
    }
1162
    /// A single cited source (camelCase keys). Every field is optional and
    /// omitted from serialization when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start index of the cited segment (`startIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End index of the cited segment (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1175
1176 #[derive(Clone, Debug, Deserialize, Serialize)]
1177 #[serde(rename_all = "camelCase")]
1178 pub struct LogprobsResult {
1179 pub top_candidate: Vec<TopCandidate>,
1180 pub chosen_candidate: Vec<LogProbCandidate>,
1181 }
1182
    /// A group of log-probability candidates.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// The candidates in this group.
        pub candidates: Vec<LogProbCandidate>,
    }
1187
    /// A token with its log probability (camelCase keys: `token`, `tokenId`,
    /// `logProbability`).
    // NOTE(review): the Gemini API reference appears to document `tokenId` as
    // an integer; confirm that `String` deserializes real responses.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        /// Token text.
        pub token: String,
        /// Token identifier.
        pub token_id: String,
        /// Log probability of the token.
        pub log_probability: f64,
    }
1195
    /// Generation settings sent as `generationConfig` (camelCase keys).
    ///
    /// Every field is optional and omitted from the serialized request when
    /// `None`. See the Gemini API reference for the meaning of each setting.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// `stopSequences`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// `responseMimeType`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// `responseSchema`, expressed with this module's [`Schema`] type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Raw JSON schema sent under the underscore-prefixed key
        /// `_responseJsonSchema`.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Raw JSON schema sent as `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// `candidateCount`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// `maxOutputTokens`; overwritten by the request's `max_tokens` in
        /// `create_request_body` when that is set.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// `temperature`; overwritten by the request's `temperature` in
        /// `create_request_body` when that is set.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// `topP`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// `topK`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// `presencePenalty`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// `frequencyPenalty`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// `responseLogprobs`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// `logprobs`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// `thinkingConfig`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// `imageConfig`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1288
1289 impl Default for GenerationConfig {
1290 fn default() -> Self {
1291 Self {
1292 temperature: Some(1.0),
1293 max_output_tokens: Some(4096),
1294 stop_sequences: None,
1295 response_mime_type: None,
1296 response_schema: None,
1297 _response_json_schema: None,
1298 response_json_schema: None,
1299 candidate_count: None,
1300 top_p: None,
1301 top_k: None,
1302 presence_penalty: None,
1303 frequency_penalty: None,
1304 response_logprobs: None,
1305 logprobs: None,
1306 thinking_config: None,
1307 image_config: None,
1308 }
1309 }
1310 }
1311
1312 #[derive(Debug, Deserialize, Serialize)]
1313 #[serde(rename_all = "camelCase")]
1314 pub struct ThinkingConfig {
1315 pub thinking_budget: u32,
1316 pub include_thoughts: Option<bool>,
1317 }
1318
1319 #[derive(Debug, Deserialize, Serialize)]
1320 #[serde(rename_all = "camelCase")]
1321 pub struct ImageConfig {
1322 #[serde(skip_serializing_if = "Option::is_none")]
1323 pub aspect_ratio: Option<String>,
1324 #[serde(skip_serializing_if = "Option::is_none")]
1325 pub image_size: Option<String>,
1326 }
1327
1328 #[derive(Debug, Deserialize, Serialize, Clone)]
1332 pub struct Schema {
1333 pub r#type: String,
1334 #[serde(skip_serializing_if = "Option::is_none")]
1335 pub format: Option<String>,
1336 #[serde(skip_serializing_if = "Option::is_none")]
1337 pub description: Option<String>,
1338 #[serde(skip_serializing_if = "Option::is_none")]
1339 pub nullable: Option<bool>,
1340 #[serde(skip_serializing_if = "Option::is_none")]
1341 pub r#enum: Option<Vec<String>>,
1342 #[serde(skip_serializing_if = "Option::is_none")]
1343 pub max_items: Option<i32>,
1344 #[serde(skip_serializing_if = "Option::is_none")]
1345 pub min_items: Option<i32>,
1346 #[serde(skip_serializing_if = "Option::is_none")]
1347 pub properties: Option<HashMap<String, Schema>>,
1348 #[serde(skip_serializing_if = "Option::is_none")]
1349 pub required: Option<Vec<String>>,
1350 #[serde(skip_serializing_if = "Option::is_none")]
1351 pub items: Option<Box<Schema>>,
1352 }
1353
1354 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1360 let defs = if let Some(obj) = schema.as_object() {
1362 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1363 } else {
1364 None
1365 };
1366
1367 let Some(defs_value) = defs else {
1368 return Ok(schema);
1369 };
1370
1371 let Some(defs_obj) = defs_value.as_object() else {
1372 return Err(CompletionError::ResponseError(
1373 "$defs must be an object".into(),
1374 ));
1375 };
1376
1377 resolve_refs(&mut schema, defs_obj)?;
1378
1379 if let Some(obj) = schema.as_object_mut() {
1381 obj.remove("$defs");
1382 obj.remove("definitions");
1383 }
1384
1385 Ok(schema)
1386 }
1387
1388 fn resolve_refs(
1391 value: &mut Value,
1392 defs: &serde_json::Map<String, Value>,
1393 ) -> Result<(), CompletionError> {
1394 match value {
1395 Value::Object(obj) => {
1396 if let Some(ref_value) = obj.get("$ref")
1397 && let Some(ref_str) = ref_value.as_str()
1398 {
1399 let def_name = parse_ref_path(ref_str)?;
1401
1402 let def = defs.get(&def_name).ok_or_else(|| {
1403 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1404 })?;
1405
1406 let mut resolved = def.clone();
1407 resolve_refs(&mut resolved, defs)?;
1408 *value = resolved;
1409 return Ok(());
1410 }
1411
1412 for (_, v) in obj.iter_mut() {
1413 resolve_refs(v, defs)?;
1414 }
1415 }
1416 Value::Array(arr) => {
1417 for item in arr.iter_mut() {
1418 resolve_refs(item, defs)?;
1419 }
1420 }
1421 _ => {}
1422 }
1423
1424 Ok(())
1425 }
1426
1427 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1433 if let Some(fragment) = ref_str.strip_prefix('#') {
1434 if let Some(name) = fragment.strip_prefix("/$defs/") {
1435 Ok(name.to_string())
1436 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1437 Ok(name.to_string())
1438 } else {
1439 Err(CompletionError::ResponseError(format!(
1440 "Unsupported reference format: {}",
1441 ref_str
1442 )))
1443 }
1444 } else {
1445 Err(CompletionError::ResponseError(format!(
1446 "Only fragment references (#/...) are supported: {}",
1447 ref_str
1448 )))
1449 }
1450 }
1451
1452 fn extract_type(type_value: &Value) -> Option<String> {
1455 if type_value.is_string() {
1456 type_value.as_str().map(String::from)
1457 } else if type_value.is_array() {
1458 type_value
1459 .as_array()
1460 .and_then(|arr| arr.first())
1461 .and_then(|v| v.as_str().map(String::from))
1462 } else {
1463 None
1464 }
1465 }
1466
1467 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1470 composition.as_array().and_then(|arr| {
1471 arr.iter().find_map(|schema| {
1472 if let Some(obj) = schema.as_object() {
1473 if let Some(type_val) = obj.get("type")
1475 && let Some(type_str) = type_val.as_str()
1476 && type_str == "null"
1477 {
1478 return None;
1479 }
1480 obj.get("type").and_then(extract_type).or_else(|| {
1482 if obj.contains_key("properties") {
1483 Some("object".to_string())
1484 } else {
1485 None
1486 }
1487 })
1488 } else {
1489 None
1490 }
1491 })
1492 })
1493 }
1494
1495 fn extract_schema_from_composition(
1498 composition: &Value,
1499 ) -> Option<serde_json::Map<String, Value>> {
1500 composition.as_array().and_then(|arr| {
1501 arr.iter().find_map(|schema| {
1502 if let Some(obj) = schema.as_object()
1503 && let Some(type_val) = obj.get("type")
1504 && let Some(type_str) = type_val.as_str()
1505 {
1506 if type_str == "null" {
1507 return None;
1508 }
1509 Some(obj.clone())
1510 } else {
1511 None
1512 }
1513 })
1514 })
1515 }
1516
1517 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1520 if let Some(type_val) = obj.get("type")
1522 && let Some(type_str) = extract_type(type_val)
1523 {
1524 return type_str;
1525 }
1526
1527 if let Some(any_of) = obj.get("anyOf")
1529 && let Some(type_str) = extract_type_from_composition(any_of)
1530 {
1531 return type_str;
1532 }
1533
1534 if let Some(one_of) = obj.get("oneOf")
1535 && let Some(type_str) = extract_type_from_composition(one_of)
1536 {
1537 return type_str;
1538 }
1539
1540 if let Some(all_of) = obj.get("allOf")
1541 && let Some(type_str) = extract_type_from_composition(all_of)
1542 {
1543 return type_str;
1544 }
1545
1546 if obj.contains_key("properties") {
1548 "object".to_string()
1549 } else {
1550 String::new()
1551 }
1552 }
1553
1554 impl TryFrom<Value> for Schema {
1555 type Error = CompletionError;
1556
1557 fn try_from(value: Value) -> Result<Self, Self::Error> {
1558 let flattened_val = flatten_schema(value)?;
1559 if let Some(obj) = flattened_val.as_object() {
1560 let props_source = if obj.get("properties").is_none() {
1563 if let Some(any_of) = obj.get("anyOf") {
1564 extract_schema_from_composition(any_of)
1565 } else if let Some(one_of) = obj.get("oneOf") {
1566 extract_schema_from_composition(one_of)
1567 } else if let Some(all_of) = obj.get("allOf") {
1568 extract_schema_from_composition(all_of)
1569 } else {
1570 None
1571 }
1572 .unwrap_or(obj.clone())
1573 } else {
1574 obj.clone()
1575 };
1576
1577 Ok(Schema {
1578 r#type: infer_type(obj),
1579 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1580 description: obj
1581 .get("description")
1582 .and_then(|v| v.as_str())
1583 .map(String::from),
1584 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1585 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1586 arr.iter()
1587 .filter_map(|v| v.as_str().map(String::from))
1588 .collect()
1589 }),
1590 max_items: obj
1591 .get("maxItems")
1592 .and_then(|v| v.as_i64())
1593 .map(|v| v as i32),
1594 min_items: obj
1595 .get("minItems")
1596 .and_then(|v| v.as_i64())
1597 .map(|v| v as i32),
1598 properties: props_source
1599 .get("properties")
1600 .and_then(|v| v.as_object())
1601 .map(|map| {
1602 map.iter()
1603 .filter_map(|(k, v)| {
1604 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1605 })
1606 .collect()
1607 }),
1608 required: props_source
1609 .get("required")
1610 .and_then(|v| v.as_array())
1611 .map(|arr| {
1612 arr.iter()
1613 .filter_map(|v| v.as_str().map(String::from))
1614 .collect()
1615 }),
1616 items: obj
1617 .get("items")
1618 .and_then(|v| v.clone().try_into().ok())
1619 .map(Box::new),
1620 })
1621 } else {
1622 Err(CompletionError::ResponseError(
1623 "Expected a JSON object for Schema".into(),
1624 ))
1625 }
1626 }
1627 }
1628
1629 #[derive(Debug, Serialize)]
1630 #[serde(rename_all = "camelCase")]
1631 pub struct GenerateContentRequest {
1632 pub contents: Vec<Content>,
1633 #[serde(skip_serializing_if = "Option::is_none")]
1634 pub tools: Option<Tool>,
1635 pub tool_config: Option<ToolConfig>,
1636 pub generation_config: Option<GenerationConfig>,
1638 pub safety_settings: Option<Vec<SafetySetting>>,
1652 pub system_instruction: Option<Content>,
1655 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1658 pub additional_params: Option<serde_json::Value>,
1659 }
1660
1661 #[derive(Debug, Serialize)]
1662 #[serde(rename_all = "camelCase")]
1663 pub struct Tool {
1664 pub function_declarations: Vec<FunctionDeclaration>,
1665 pub code_execution: Option<CodeExecution>,
1666 }
1667
1668 #[derive(Debug, Serialize, Clone)]
1669 #[serde(rename_all = "camelCase")]
1670 pub struct FunctionDeclaration {
1671 pub name: String,
1672 pub description: String,
1673 #[serde(skip_serializing_if = "Option::is_none")]
1674 pub parameters: Option<Schema>,
1675 }
1676
1677 #[derive(Debug, Serialize, Deserialize)]
1678 #[serde(rename_all = "camelCase")]
1679 pub struct ToolConfig {
1680 pub function_calling_config: Option<FunctionCallingMode>,
1681 }
1682
1683 #[derive(Debug, Serialize, Deserialize, Default)]
1684 #[serde(tag = "mode", rename_all = "UPPERCASE")]
1685 pub enum FunctionCallingMode {
1686 #[default]
1687 Auto,
1688 None,
1689 Any {
1690 #[serde(skip_serializing_if = "Option::is_none")]
1691 allowed_function_names: Option<Vec<String>>,
1692 },
1693 }
1694
1695 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1696 type Error = CompletionError;
1697 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1698 let res = match value {
1699 message::ToolChoice::Auto => Self::Auto,
1700 message::ToolChoice::None => Self::None,
1701 message::ToolChoice::Required => Self::Any {
1702 allowed_function_names: None,
1703 },
1704 message::ToolChoice::Specific { function_names } => Self::Any {
1705 allowed_function_names: Some(function_names),
1706 },
1707 };
1708
1709 Ok(res)
1710 }
1711 }
1712
1713 #[derive(Debug, Serialize)]
1714 pub struct CodeExecution {}
1715
1716 #[derive(Debug, Serialize)]
1717 #[serde(rename_all = "camelCase")]
1718 pub struct SafetySetting {
1719 pub category: HarmCategory,
1720 pub threshold: HarmBlockThreshold,
1721 }
1722
1723 #[derive(Debug, Serialize)]
1724 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1725 pub enum HarmBlockThreshold {
1726 HarmBlockThresholdUnspecified,
1727 BlockLowAndAbove,
1728 BlockMediumAndAbove,
1729 BlockOnlyHigh,
1730 BlockNone,
1731 Off,
1732 }
1733}
1734
1735#[cfg(test)]
1736mod tests {
1737 use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1738
1739 use super::*;
1740 use serde_json::json;
1741
/// Deserializes a user `Content` holding one part of every kind and checks
/// that each part lands in the expected `PartKind` variant with its payload.
#[test]
fn test_deserialize_message_user() {
    let raw_message = r#"{
        "parts": [
            {"text": "Hello, world!"},
            {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
            {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
            {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
            {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
            {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
            {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
        ],
        "role": "user"
    }"#;

    // Route through serde_path_to_error so a failure names the offending path.
    let mut deserializer = serde_json::Deserializer::from_str(raw_message);
    let content: Content = match serde_path_to_error::deserialize(&mut deserializer) {
        Ok(parsed) => parsed,
        Err(err) => panic!("Deserialization error at {}: {}", err.path(), err),
    };
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 7);

    let parts: Vec<Part> = content.parts.into_iter().collect();

    let PartKind::Text(text) = &parts[0].part else {
        panic!("Expected text part");
    };
    assert_eq!(text, "Hello, world!");

    let PartKind::InlineData(inline_data) = &parts[1].part else {
        panic!("Expected inline data part");
    };
    assert_eq!(inline_data.mime_type, "image/png");
    assert_eq!(inline_data.data, "base64encodeddata");

    let PartKind::FunctionCall(function_call) = &parts[2].part else {
        panic!("Expected function call part");
    };
    assert_eq!(function_call.name, "test_function");
    assert_eq!(
        function_call.args.as_object().unwrap().get("arg1").unwrap(),
        "value1"
    );

    let PartKind::FunctionResponse(function_response) = &parts[3].part else {
        panic!("Expected function response part");
    };
    assert_eq!(function_response.name, "test_function");
    assert_eq!(
        function_response
            .response
            .as_ref()
            .unwrap()
            .get("result")
            .unwrap(),
        "success"
    );

    let PartKind::FileData(file_data) = &parts[4].part else {
        panic!("Expected file data part");
    };
    assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
    assert_eq!(file_data.file_uri, "http://example.com/file.pdf");

    let PartKind::ExecutableCode(executable_code) = &parts[5].part else {
        panic!("Expected executable code part");
    };
    assert_eq!(executable_code.code, "print('Hello, world!')");

    let PartKind::CodeExecutionResult(code_execution_result) = &parts[6].part else {
        panic!("Expected code execution result part");
    };
    assert_eq!(
        code_execution_result.clone().output.unwrap(),
        "Hello, world!"
    );
}
1856
/// Deserializes a model-role `Content` with a single text part.
#[test]
fn test_deserialize_message_model() {
    let payload = json!({
        "parts": [{"text": "Hello, user!"}],
        "role": "model"
    });

    let content: Content = serde_json::from_value(payload).unwrap();
    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::Text(text),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected text part");
    };
    assert_eq!(text, "Hello, user!");
}
1877
/// Converts a plain user message into `Content` and checks role and text.
#[test]
fn test_message_conversion_user() {
    let content: Content = message::Message::user("Hello, world!")
        .try_into()
        .unwrap();

    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    match content.parts.first() {
        Some(Part {
            part: PartKind::Text(text),
            ..
        }) => assert_eq!(text, "Hello, world!"),
        _ => panic!("Expected text part"),
    }
}
1894
/// Converts a plain assistant message into `Content` and checks that it is
/// mapped to the model role with the same text.
#[test]
fn test_message_conversion_model() {
    let content: Content = message::Message::assistant("Hello, user!")
        .try_into()
        .unwrap();

    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    match content.parts.first() {
        Some(Part {
            part: PartKind::Text(text),
            ..
        }) => assert_eq!(text, "Hello, user!"),
        _ => panic!("Expected text part"),
    }
}
1912
/// Converts an assistant message holding a single tool call and checks that
/// it becomes a model-role `Content` with one function-call part.
#[test]
fn test_message_conversion_tool_call() {
    let msg = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::ToolCall(message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
            signature: None,
            additional_params: None,
        })),
    };

    let content: Content = msg.try_into().unwrap();
    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionCall(function_call),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected function call part");
    };
    assert_eq!(function_call.name, "test_function");
    assert_eq!(
        function_call.args.as_object().unwrap().get("arg1").unwrap(),
        "value1"
    );
}
1948
/// Converts an array schema whose items are a `$ref` into `$defs` and checks
/// that the referenced object type is resolved.
///
/// A failed conversion now fails the test; it used to be printed and ignored,
/// which let the test pass without checking anything.
#[test]
fn test_vec_schema_conversion() {
    let schema_with_ref = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "first_name": {
                        "type": ["string", "null"],
                        "description": "The person's first name, if provided (null otherwise)"
                    },
                    "last_name": {
                        "type": ["string", "null"],
                        "description": "The person's last name, if provided (null otherwise)"
                    },
                    "job": {
                        "type": ["string", "null"],
                        "description": "The person's job, if provided (null otherwise)"
                    }
                },
                "required": []
            }
        }
    });

    let schema: Schema = schema_with_ref
        .try_into()
        .expect("schema with a $ref into $defs should convert");
    assert_eq!(schema.r#type, "array");

    let items = schema
        .items
        .expect("Schema should have items field for array type");
    assert_ne!(items.r#type, "", "Items type should not be empty string!");
    assert_eq!(items.r#type, "object", "Items should be object type");
}
1996
/// Converts a minimal object schema and checks type and properties.
#[test]
fn test_object_schema() {
    let schema = Schema::try_from(json!({
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            }
        }
    }))
    .unwrap();

    assert_eq!(schema.r#type, "object");
    assert!(schema.properties.is_some());
}
2012
/// Converts an array schema with an inline object item schema and checks
/// that the nested item schema is preserved.
#[test]
fn test_array_with_inline_items() {
    let schema = Schema::try_from(json!({
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        }
    }))
    .unwrap();
    assert_eq!(schema.r#type, "array");

    let Some(items) = schema.items else {
        panic!("Schema should have items field");
    };
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
/// Flattens a schema with a `$ref` explicitly, converts the result, and checks
/// that the inlined item schema is an object with properties.
///
/// A missing `items` field now fails the test; it used to skip the item
/// assertions silently.
#[test]
fn test_flattened_schema() {
    let ref_schema = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            }
        }
    });

    let flattened = flatten_schema(ref_schema).unwrap();
    let schema: Schema = flattened.try_into().unwrap();

    assert_eq!(schema.r#type, "array");

    let items = schema
        .items
        .expect("flattened array schema should keep its items");
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
2065 }
2066}