/// Model id `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model id `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model id `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model id `gemini-2.5-flash-preview-05-20`.
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// Model id `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model id `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model id `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model id `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Model id `gemini-1.5-flash`.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model id `gemini-1.5-pro`.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model id `gemini-1.5-pro-8b`.
// NOTE(review): the 8B variant of Gemini 1.5 is published as `gemini-1.5-flash-8b`;
// confirm the API accepts `gemini-1.5-pro-8b` before relying on this constant.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// Model id `gemini-1.0-pro`.
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
use self::gemini_api_types::Schema;
use crate::http_client::HttpClientExt;
use crate::message::Reasoning;
use crate::providers::gemini::completion::gemini_api_types::{
    AdditionalParameters, FunctionCallingMode, ToolConfig,
};
use crate::providers::gemini::streaming::StreamingCompletionResponse;
use crate::telemetry::SpanCombinator;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
};
use gemini_api_types::{
    Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
    Role, Tool,
};
use serde_json::{Map, Value};
use std::convert::TryFrom;
use tracing::{Instrument, info_span};

use super::Client;
51
/// Handle for issuing Gemini completion requests for a single model.
///
/// The model name is interpolated into the request path
/// (`/v1beta/models/{model}:generateContent`) by the `completion` implementation below.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Gemini client used to build and send requests; generic over the HTTP transport `T`.
    pub(crate) client: Client<T>,
    /// Gemini model identifier (for example one of the `GEMINI_*` constants).
    pub model: String,
}
61
62impl<T> CompletionModel<T> {
63 pub fn new(client: Client<T>, model: &str) -> Self {
64 Self {
65 client,
66 model: model.to_string(),
67 }
68 }
69}
70
71impl<T> completion::CompletionModel for CompletionModel<T>
72where
73 T: HttpClientExt + Clone + 'static,
74{
75 type Response = GenerateContentResponse;
76 type StreamingResponse = StreamingCompletionResponse;
77
78 #[cfg_attr(feature = "worker", worker::send)]
79 async fn completion(
80 &self,
81 completion_request: CompletionRequest,
82 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
83 let span = if tracing::Span::current().is_disabled() {
84 info_span!(
85 target: "rig::completions",
86 "generate_content",
87 gen_ai.operation.name = "generate_content",
88 gen_ai.provider.name = "gcp.gemini",
89 gen_ai.request.model = self.model,
90 gen_ai.system_instructions = &completion_request.preamble,
91 gen_ai.response.id = tracing::field::Empty,
92 gen_ai.response.model = tracing::field::Empty,
93 gen_ai.usage.output_tokens = tracing::field::Empty,
94 gen_ai.usage.input_tokens = tracing::field::Empty,
95 gen_ai.input.messages = tracing::field::Empty,
96 gen_ai.output.messages = tracing::field::Empty,
97 )
98 } else {
99 tracing::Span::current()
100 };
101
102 let request = create_request_body(completion_request)?;
103 span.record_model_input(&request.contents);
104
105 span.record_model_input(&request.contents);
106
107 tracing::debug!(
108 "Sending completion request to Gemini API {}",
109 serde_json::to_string_pretty(&request)?
110 );
111
112 let body = serde_json::to_vec(&request)?;
113
114 let request = self
115 .client
116 .post(&format!("/v1beta/models/{}:generateContent", self.model))
117 .header("Content-Type", "application/json")
118 .body(body)
119 .map_err(|e| CompletionError::HttpError(e.into()))?;
120
121 let response = self.client.send::<_, Vec<u8>>(request).await?;
122
123 if response.status().is_success() {
124 let response_body = response
125 .into_body()
126 .await
127 .map_err(CompletionError::HttpError)?;
128
129 let response: GenerateContentResponse = serde_json::from_slice(&response_body)?;
130
131 match response.usage_metadata {
132 Some(ref usage) => tracing::info!(target: "rig",
133 "Gemini completion token usage: {}",
134 usage
135 ),
136 None => tracing::info!(target: "rig",
137 "Gemini completion token usage: n/a",
138 ),
139 }
140
141 let span = tracing::Span::current();
142 span.record_model_output(&response.candidates);
143 span.record_response_metadata(&response);
144 span.record_token_usage(&response.usage_metadata);
145
146 tracing::debug!(
147 "Received response from Gemini API: {}",
148 serde_json::to_string_pretty(&response)?
149 );
150
151 response.try_into()
152 } else {
153 let text = String::from_utf8_lossy(
154 &response
155 .into_body()
156 .await
157 .map_err(CompletionError::HttpError)?,
158 )
159 .into();
160
161 Err(CompletionError::ProviderError(text))
162 }
163 }
164
165 #[cfg_attr(feature = "worker", worker::send)]
166 async fn stream(
167 &self,
168 request: CompletionRequest,
169 ) -> Result<
170 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
171 CompletionError,
172 > {
173 CompletionModel::stream(self, request).await
174 }
175}
176
177pub(crate) fn create_request_body(
178 completion_request: CompletionRequest,
179) -> Result<GenerateContentRequest, CompletionError> {
180 let mut full_history = Vec::new();
181 full_history.extend(completion_request.chat_history);
182
183 let additional_params = completion_request
184 .additional_params
185 .unwrap_or_else(|| Value::Object(Map::new()));
186
187 let AdditionalParameters {
188 mut generation_config,
189 additional_params,
190 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
191
192 if let Some(temp) = completion_request.temperature {
193 generation_config.temperature = Some(temp);
194 }
195
196 if let Some(max_tokens) = completion_request.max_tokens {
197 generation_config.max_output_tokens = Some(max_tokens);
198 }
199
200 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
201 parts: vec![preamble.into()],
202 role: Some(Role::Model),
203 });
204
205 let tools = if completion_request.tools.is_empty() {
206 None
207 } else {
208 Some(Tool::try_from(completion_request.tools)?)
209 };
210
211 let tool_config = if let Some(cfg) = completion_request.tool_choice {
212 Some(ToolConfig {
213 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
214 })
215 } else {
216 None
217 };
218
219 let request = GenerateContentRequest {
220 contents: full_history
221 .into_iter()
222 .map(|msg| {
223 msg.try_into()
224 .map_err(|e| CompletionError::RequestError(Box::new(e)))
225 })
226 .collect::<Result<Vec<_>, _>>()?,
227 generation_config: Some(generation_config),
228 safety_settings: None,
229 tools,
230 tool_config,
231 system_instruction,
232 additional_params,
233 };
234
235 Ok(request)
236}
237
238impl TryFrom<completion::ToolDefinition> for Tool {
239 type Error = CompletionError;
240
241 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
242 let parameters: Option<Schema> =
243 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
244 None
245 } else {
246 Some(tool.parameters.try_into()?)
247 };
248
249 Ok(Self {
250 function_declarations: vec![FunctionDeclaration {
251 name: tool.name,
252 description: tool.description,
253 parameters,
254 }],
255 code_execution: None,
256 })
257 }
258}
259
260impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
261 type Error = CompletionError;
262
263 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
264 let mut function_declarations = Vec::new();
265
266 for tool in tools {
267 let parameters =
268 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
269 None
270 } else {
271 match tool.parameters.try_into() {
272 Ok(schema) => Some(schema),
273 Err(e) => {
274 let emsg = format!(
275 "Tool '{}' could not be converted to a schema: {:?}",
276 tool.name, e,
277 );
278 return Err(CompletionError::ProviderError(emsg));
279 }
280 }
281 };
282
283 function_declarations.push(FunctionDeclaration {
284 name: tool.name,
285 description: tool.description,
286 parameters,
287 });
288 }
289
290 Ok(Self {
291 function_declarations,
292 code_execution: None,
293 })
294 }
295}
296
297impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
298 type Error = CompletionError;
299
300 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
301 let candidate = response.candidates.first().ok_or_else(|| {
302 CompletionError::ResponseError("No response candidates in response".into())
303 })?;
304
305 let content = candidate
306 .content
307 .parts
308 .iter()
309 .map(|Part { thought, part, .. }| {
310 Ok(match part {
311 PartKind::Text(text) => {
312 if let Some(thought) = thought
313 && *thought
314 {
315 completion::AssistantContent::Reasoning(Reasoning::new(text))
316 } else {
317 completion::AssistantContent::text(text)
318 }
319 }
320 PartKind::FunctionCall(function_call) => {
321 completion::AssistantContent::tool_call(
322 &function_call.name,
323 &function_call.name,
324 function_call.args.clone(),
325 )
326 }
327 _ => {
328 return Err(CompletionError::ResponseError(
329 "Response did not contain a message or tool call".into(),
330 ));
331 }
332 })
333 })
334 .collect::<Result<Vec<_>, _>>()?;
335
336 let choice = OneOrMany::many(content).map_err(|_| {
337 CompletionError::ResponseError(
338 "Response contained no message or tool call (empty)".to_owned(),
339 )
340 })?;
341
342 let usage = response
343 .usage_metadata
344 .as_ref()
345 .map(|usage| completion::Usage {
346 input_tokens: usage.prompt_token_count as u64,
347 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
348 total_tokens: usage.total_token_count as u64,
349 })
350 .unwrap_or_default();
351
352 Ok(completion::CompletionResponse {
353 choice,
354 usage,
355 raw_response: response,
356 })
357 }
358}
359
360pub mod gemini_api_types {
361 use crate::telemetry::ProviderResponseExt;
362 use std::{collections::HashMap, convert::Infallible, str::FromStr};
363
364 use serde::{Deserialize, Serialize};
368 use serde_json::{Value, json};
369
370 use crate::completion::GetTokenUsage;
371 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
372 use crate::{
373 OneOrMany,
374 completion::CompletionError,
375 message::{self, Reasoning, Text},
376 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
377 };
378
379 #[derive(Debug, Deserialize, Serialize, Default)]
380 #[serde(rename_all = "camelCase")]
381 pub struct AdditionalParameters {
382 pub generation_config: GenerationConfig,
384 #[serde(flatten, skip_serializing_if = "Option::is_none")]
386 pub additional_params: Option<serde_json::Value>,
387 }
388
389 impl AdditionalParameters {
390 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
391 self.generation_config = cfg;
392 self
393 }
394
395 pub fn with_params(mut self, params: serde_json::Value) -> Self {
396 self.additional_params = Some(params);
397 self
398 }
399 }
400
    /// Response body of a Gemini `generateContent` call (camelCase JSON).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Response identifier (`responseId`); required, so deserialization fails if it is absent.
        pub response_id: String,
        /// Generated candidates; the completion conversion only uses the first one.
        pub candidates: Vec<ContentCandidate>,
        /// Prompt-level feedback (e.g. block reason), if reported.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Token accounting, if reported.
        pub usage_metadata: Option<UsageMetadata>,
        /// Model version string (`modelVersion`), if reported.
        pub model_version: Option<String>,
    }
420
421 impl ProviderResponseExt for GenerateContentResponse {
422 type OutputMessage = ContentCandidate;
423 type Usage = UsageMetadata;
424
425 fn get_response_id(&self) -> Option<String> {
426 Some(self.response_id.clone())
427 }
428
429 fn get_response_model_name(&self) -> Option<String> {
430 None
431 }
432
433 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
434 self.candidates.clone()
435 }
436
437 fn get_text_response(&self) -> Option<String> {
438 let str = self
439 .candidates
440 .iter()
441 .filter_map(|x| {
442 if x.content.role.as_ref().is_none_or(|y| y != &Role::Model) {
443 return None;
444 }
445
446 let res = x
447 .content
448 .parts
449 .iter()
450 .filter_map(|part| {
451 if let PartKind::Text(ref str) = part.part {
452 Some(str.to_owned())
453 } else {
454 None
455 }
456 })
457 .collect::<Vec<String>>()
458 .join("\n");
459
460 Some(res)
461 })
462 .collect::<Vec<String>>()
463 .join("\n");
464
465 if str.is_empty() { None } else { Some(str) }
466 }
467
468 fn get_usage(&self) -> Option<Self::Usage> {
469 self.usage_metadata.clone()
470 }
471 }
472
    /// One generated response candidate (camelCase JSON). Only `content` is required.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// The generated message.
        pub content: Content,
        /// Reason generation stopped, if reported.
        pub finish_reason: Option<FinishReason>,
        /// Safety ratings for this candidate, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation information, if reported.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate, if reported.
        pub token_count: Option<i32>,
        /// Average log probability, if reported.
        pub avg_logprobs: Option<f64>,
        /// Log-probability details, if reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Position of this candidate in the response, if reported.
        pub index: Option<i32>,
    }
498
    /// A Gemini message: an optional role plus an ordered list of parts.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Message parts; an absent `parts` key deserializes to an empty list.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author role; serialized as `null` when `None` (no `skip_serializing_if`).
        pub role: Option<Role>,
    }
508
509 impl TryFrom<message::Message> for Content {
510 type Error = message::MessageError;
511
512 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
513 Ok(match msg {
514 message::Message::User { content } => Content {
515 parts: content
516 .into_iter()
517 .map(|c| c.try_into())
518 .collect::<Result<Vec<_>, _>>()?,
519 role: Some(Role::User),
520 },
521 message::Message::Assistant { content, .. } => Content {
522 role: Some(Role::Model),
523 parts: content.into_iter().map(|content| content.into()).collect(),
524 },
525 })
526 }
527 }
528
    /// Converts Gemini [`Content`] back into a rig [`message::Message`].
    ///
    /// - Role `User`, or no role: produces `Message::User`. Only text parts and inline-data
    ///   parts whose MIME type maps to an image, document or audio media type are accepted.
    /// - Role `Model`: produces `Message::Assistant` with `id: None`. Text parts become
    ///   reasoning when `thought == Some(true)` and plain text otherwise; function calls
    ///   become tool calls.
    ///
    /// # Errors
    /// `MessageError::ConversionError` for any other part kind or media type, or when
    /// `OneOrMany::many` rejects the converted parts (presumably because there are none —
    /// confirm against `OneOrMany`).
    impl TryFrom<Content> for message::Message {
        type Error = message::MessageError;

        fn try_from(content: Content) -> Result<Self, Self::Error> {
            match content.role {
                Some(Role::User) | None => {
                    Ok(message::Message::User {
                        content: {
                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
                                .map(|Part { part, .. }| {
                                    Ok(match part {
                                        PartKind::Text(text) => message::UserContent::text(text),
                                        PartKind::InlineData(inline_data) => {
                                            // Dispatch on the MIME type declared on the inline blob.
                                            let mime_type =
                                                message::MediaType::from_mime_type(&inline_data.mime_type);

                                            match mime_type {
                                                Some(message::MediaType::Image(media_type)) => {
                                                    message::UserContent::image_base64(
                                                        inline_data.data,
                                                        Some(media_type),
                                                        Some(message::ImageDetail::default()),
                                                    )
                                                }
                                                Some(message::MediaType::Document(media_type)) => {
                                                    message::UserContent::document(
                                                        inline_data.data,
                                                        Some(media_type),
                                                    )
                                                }
                                                Some(message::MediaType::Audio(media_type)) => {
                                                    message::UserContent::audio(
                                                        inline_data.data,
                                                        Some(media_type),
                                                    )
                                                }
                                                // Unrecognised MIME types and other media kinds.
                                                _ => {
                                                    return Err(message::MessageError::ConversionError(
                                                        format!("Unsupported media type {mime_type:?}"),
                                                    ));
                                                }
                                            }
                                        }
                                        // Function responses, file data, code parts, etc.
                                        _ => {
                                            return Err(message::MessageError::ConversionError(format!(
                                                "Unsupported gemini content part type: {part:?}"
                                            )));
                                        }
                                    })
                                })
                                .collect();
                            OneOrMany::many(user_content?).map_err(|_| {
                                message::MessageError::ConversionError(
                                    "Failed to create OneOrMany from user content".to_string(),
                                )
                            })?
                        },
                    })
                }
                Some(Role::Model) => Ok(message::Message::Assistant {
                    id: None,
                    content: {
                        let assistant_content: Result<Vec<_>, _> = content
                            .parts
                            .into_iter()
                            .map(|Part { thought, part, .. }| {
                                Ok(match part {
                                    // A text part flagged as a thought is model reasoning.
                                    PartKind::Text(text) => match thought {
                                        Some(true) => message::AssistantContent::Reasoning(
                                            Reasoning::new(&text),
                                        ),
                                        _ => message::AssistantContent::Text(Text { text }),
                                    },

                                    PartKind::FunctionCall(function_call) => {
                                        message::AssistantContent::ToolCall(function_call.into())
                                    }
                                    _ => {
                                        return Err(message::MessageError::ConversionError(
                                            format!("Unsupported part type: {part:?}"),
                                        ));
                                    }
                                })
                            })
                            .collect();
                        OneOrMany::many(assistant_content?).map_err(|_| {
                            message::MessageError::ConversionError(
                                "Failed to create OneOrMany from assistant content".to_string(),
                            )
                        })?
                    },
                }),
            }
        }
    }
624
    /// Author of a [`Content`] message; serialized in lowercase (`"user"`, `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        User,
        Model,
    }
631
    /// One part of a [`Content`] message: a [`PartKind`] payload plus optional thought metadata.
    ///
    /// Both `part` and `additional_params` are flattened, so their keys appear directly in the
    /// part's JSON object.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// `Some(true)` marks the part as model reasoning; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Thought signature (`thoughtSignature`), if any; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The payload (text, inline data, function call, ...).
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra keys flattened into the part's JSON object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
646
    /// Payload of a [`Part`].
    ///
    /// Uses serde's default external tagging with camelCase variant names (`text`,
    /// `inlineData`, `functionCall`, `functionResponse`, `fileData`, `executableCode`,
    /// `codeExecutionResult`); [`Part`] flattens it into its own JSON object.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        Text(String),
        InlineData(Blob),
        FunctionCall(FunctionCall),
        FunctionResponse(FunctionResponse),
        FileData(FileData),
        ExecutableCode(ExecutableCode),
        CodeExecutionResult(CodeExecutionResult),
    }
661
662 impl Default for PartKind {
665 fn default() -> Self {
666 Self::Text(String::new())
667 }
668 }
669
670 impl From<String> for Part {
671 fn from(text: String) -> Self {
672 Self {
673 thought: Some(false),
674 thought_signature: None,
675 part: PartKind::Text(text),
676 additional_params: None,
677 }
678 }
679 }
680
681 impl From<&str> for Part {
682 fn from(text: &str) -> Self {
683 Self::from(text.to_string())
684 }
685 }
686
687 impl FromStr for Part {
688 type Err = Infallible;
689
690 fn from_str(s: &str) -> Result<Self, Self::Err> {
691 Ok(s.into())
692 }
693 }
694
695 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
696 type Error = message::MessageError;
697 fn try_from(
698 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
699 ) -> Result<Self, Self::Error> {
700 let mime_type = mime_type.to_mime_type().to_string();
701 let part = match doc_src {
702 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
703 mime_type: Some(mime_type),
704 file_uri: url,
705 }),
706 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
707 PartKind::InlineData(Blob { mime_type, data })
708 }
709 DocumentSourceKind::Raw(_) => {
710 return Err(message::MessageError::ConversionError(
711 "Raw files not supported, encode as base64 first".into(),
712 ));
713 }
714 DocumentSourceKind::Unknown => {
715 return Err(message::MessageError::ConversionError(
716 "Can't convert an unknown document source".to_string(),
717 ));
718 }
719 };
720
721 Ok(part)
722 }
723 }
724
725 impl TryFrom<message::UserContent> for Part {
726 type Error = message::MessageError;
727
728 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
729 match content {
730 message::UserContent::Text(message::Text { text }) => Ok(Part {
731 thought: Some(false),
732 thought_signature: None,
733 part: PartKind::Text(text),
734 additional_params: None,
735 }),
736 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
737 let content = match content.first() {
738 message::ToolResultContent::Text(text) => text.text,
739 message::ToolResultContent::Image(_) => {
740 return Err(message::MessageError::ConversionError(
741 "Tool result content must be text".to_string(),
742 ));
743 }
744 };
745 let result: serde_json::Value =
747 serde_json::from_str(&content).unwrap_or_else(|error| {
748 tracing::trace!(
749 ?error,
750 "Tool result is not a valid JSON, treat it as normal string"
751 );
752 json!(content)
753 });
754 Ok(Part {
755 thought: Some(false),
756 thought_signature: None,
757 part: PartKind::FunctionResponse(FunctionResponse {
758 name: id,
759 response: Some(json!({ "result": result })),
760 }),
761 additional_params: None,
762 })
763 }
764 message::UserContent::Image(message::Image {
765 data, media_type, ..
766 }) => match media_type {
767 Some(media_type) => match media_type {
768 message::ImageMediaType::JPEG
769 | message::ImageMediaType::PNG
770 | message::ImageMediaType::WEBP
771 | message::ImageMediaType::HEIC
772 | message::ImageMediaType::HEIF => {
773 let part = PartKind::try_from((media_type, data))?;
774 Ok(Part {
775 thought: Some(false),
776 thought_signature: None,
777 part,
778 additional_params: None,
779 })
780 }
781 _ => Err(message::MessageError::ConversionError(format!(
782 "Unsupported image media type {media_type:?}"
783 ))),
784 },
785 None => Err(message::MessageError::ConversionError(
786 "Media type for image is required for Gemini".to_string(),
787 )),
788 },
789 message::UserContent::Document(message::Document {
790 data, media_type, ..
791 }) => {
792 let Some(media_type) = media_type else {
793 return Err(MessageError::ConversionError(
794 "A mime type is required for document inputs to Gemini".to_string(),
795 ));
796 };
797
798 if !media_type.is_code() {
799 let mime_type = media_type.to_mime_type().to_string();
800
801 let part = match data {
802 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
803 mime_type: Some(mime_type),
804 file_uri,
805 }),
806 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
807 PartKind::InlineData(Blob { mime_type, data })
808 }
809 DocumentSourceKind::Raw(_) => {
810 return Err(message::MessageError::ConversionError(
811 "Raw files not supported, encode as base64 first".into(),
812 ));
813 }
814 _ => {
815 return Err(message::MessageError::ConversionError(
816 "Document has no body".to_string(),
817 ));
818 }
819 };
820
821 Ok(Part {
822 thought: Some(false),
823 part,
824 ..Default::default()
825 })
826 } else {
827 Err(message::MessageError::ConversionError(format!(
828 "Unsupported document media type {media_type:?}"
829 )))
830 }
831 }
832
833 message::UserContent::Audio(message::Audio {
834 data, media_type, ..
835 }) => {
836 let Some(media_type) = media_type else {
837 return Err(MessageError::ConversionError(
838 "A mime type is required for audio inputs to Gemini".to_string(),
839 ));
840 };
841
842 let mime_type = media_type.to_mime_type().to_string();
843
844 let part = match data {
845 DocumentSourceKind::Base64(data) => {
846 PartKind::InlineData(Blob { data, mime_type })
847 }
848
849 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
850 mime_type: Some(mime_type),
851 file_uri,
852 }),
853 DocumentSourceKind::String(_) => {
854 return Err(message::MessageError::ConversionError(
855 "Strings cannot be used as audio files!".into(),
856 ));
857 }
858 DocumentSourceKind::Raw(_) => {
859 return Err(message::MessageError::ConversionError(
860 "Raw files not supported, encode as base64 first".into(),
861 ));
862 }
863 DocumentSourceKind::Unknown => {
864 return Err(message::MessageError::ConversionError(
865 "Content has no body".to_string(),
866 ));
867 }
868 };
869
870 Ok(Part {
871 thought: Some(false),
872 part,
873 ..Default::default()
874 })
875 }
876 message::UserContent::Video(message::Video {
877 data,
878 media_type,
879 additional_params,
880 ..
881 }) => {
882 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
883
884 let part = match data {
885 DocumentSourceKind::Url(file_uri) => {
886 if file_uri.starts_with("https://www.youtube.com") {
887 PartKind::FileData(FileData {
888 mime_type,
889 file_uri,
890 })
891 } else {
892 if mime_type.is_none() {
893 return Err(MessageError::ConversionError(
894 "A mime type is required for non-Youtube video file inputs to Gemini"
895 .to_string(),
896 ));
897 }
898
899 PartKind::FileData(FileData {
900 mime_type,
901 file_uri,
902 })
903 }
904 }
905 DocumentSourceKind::Base64(data) => {
906 let Some(mime_type) = mime_type else {
907 return Err(MessageError::ConversionError(
908 "A media type is expected for base64 encoded strings"
909 .to_string(),
910 ));
911 };
912 PartKind::InlineData(Blob { mime_type, data })
913 }
914 DocumentSourceKind::String(_) => {
915 return Err(message::MessageError::ConversionError(
916 "Strings cannot be used as audio files!".into(),
917 ));
918 }
919 DocumentSourceKind::Raw(_) => {
920 return Err(message::MessageError::ConversionError(
921 "Raw file data not supported, encode as base64 first".into(),
922 ));
923 }
924 DocumentSourceKind::Unknown => {
925 return Err(message::MessageError::ConversionError(
926 "Media type for video is required for Gemini".to_string(),
927 ));
928 }
929 };
930
931 Ok(Part {
932 thought: Some(false),
933 thought_signature: None,
934 part,
935 additional_params,
936 })
937 }
938 }
939 }
940 }
941
942 impl From<message::AssistantContent> for Part {
943 fn from(content: message::AssistantContent) -> Self {
944 match content {
945 message::AssistantContent::Text(message::Text { text }) => text.into(),
946 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
947 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
948 Part {
949 thought: Some(true),
950 thought_signature: None,
951 part: PartKind::Text(
952 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
953 ),
954 additional_params: None,
955 }
956 }
957 }
958 }
959 }
960
961 impl From<message::ToolCall> for Part {
962 fn from(tool_call: message::ToolCall) -> Self {
963 Self {
964 thought: Some(false),
965 thought_signature: None,
966 part: PartKind::FunctionCall(FunctionCall {
967 name: tool_call.function.name,
968 args: tool_call.function.arguments,
969 }),
970 additional_params: None,
971 }
972 }
973 }
974
    /// Inline binary payload (`inlineData`), serialized as `{"mimeType", "data"}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of the payload.
        pub mime_type: String,
        /// Encoded payload as a string (presumably base64 — TODO confirm for string sources).
        pub data: String,
    }
986
    /// A function call requested by the model (`functionCall`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Call arguments as arbitrary JSON.
        pub args: serde_json::Value,
    }
997
998 impl From<FunctionCall> for message::ToolCall {
999 fn from(function_call: FunctionCall) -> Self {
1000 Self {
1001 id: function_call.name.clone(),
1002 call_id: None,
1003 function: message::ToolFunction {
1004 name: function_call.name,
1005 arguments: function_call.args,
1006 },
1007 }
1008 }
1009 }
1010
1011 impl From<message::ToolCall> for FunctionCall {
1012 fn from(tool_call: message::ToolCall) -> Self {
1013 Self {
1014 name: tool_call.function.name,
1015 args: tool_call.function.arguments,
1016 }
1017 }
1018 }
1019
    /// The result of a function call sent back to the model (`functionResponse`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name of the function that produced the result.
        pub name: String,
        /// Result payload as JSON; serialized as `null` when `None`.
        pub response: Option<serde_json::Value>,
    }
1031
    /// Reference to data by URI (`fileData`), serialized as `{"mimeType", "fileUri"}`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced data; serialized as `null` when `None`.
        pub mime_type: Option<String>,
        /// URI of the referenced data.
        pub file_uri: String,
    }
1041
    /// A safety rating: a harm category paired with its assessed probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        pub category: HarmCategory,
        pub probability: HarmProbability,
    }
1047
    /// Harm probability level, serialized in SCREAMING_SNAKE_CASE (e.g. `NEGLIGIBLE`).
    ///
    /// There is no catch-all variant, so an unrecognized value fails deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1057
    /// Harm category, serialized in SCREAMING_SNAKE_CASE (e.g. `HARM_CATEGORY_HARASSMENT`).
    ///
    /// There is no catch-all variant, so a category not listed here fails deserialization of
    /// the whole response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1074
    /// Token accounting reported by Gemini (`usageMetadata`, camelCase JSON).
    ///
    /// `promptTokenCount` and `totalTokenCount` are required; the optional counts are omitted
    /// from serialized output when `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt.
        pub prompt_token_count: i32,
        /// Tokens served from cached content, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total tokens as reported by the API.
        pub total_token_count: i32,
        /// Tokens spent on thinking, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1087
1088 impl std::fmt::Display for UsageMetadata {
1089 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1090 write!(
1091 f,
1092 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1093 self.prompt_token_count,
1094 match self.cached_content_token_count {
1095 Some(count) => count.to_string(),
1096 None => "n/a".to_string(),
1097 },
1098 match self.candidates_token_count {
1099 Some(count) => count.to_string(),
1100 None => "n/a".to_string(),
1101 },
1102 self.total_token_count
1103 )
1104 }
1105 }
1106
1107 impl GetTokenUsage for UsageMetadata {
1108 fn token_usage(&self) -> Option<crate::completion::Usage> {
1109 let mut usage = crate::completion::Usage::new();
1110
1111 usage.input_tokens = self.prompt_token_count as u64;
1112 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1113 + self.candidates_token_count.unwrap_or_default()
1114 + self.thoughts_token_count.unwrap_or_default())
1115 as u64;
1116 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1117
1118 Some(usage)
1119 }
1120 }
1121
    /// Feedback about the prompt itself (`promptFeedback`, camelCase JSON).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked, if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1131
    /// Reason a prompt was blocked, serialized in SCREAMING_SNAKE_CASE.
    ///
    /// There is no catch-all variant, so an unrecognized value fails deserialization.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1147
    /// Why a candidate stopped generating, serialized in SCREAMING_SNAKE_CASE (e.g. `MAX_TOKENS`).
    ///
    /// There is no catch-all variant. NOTE(review): if the API adds new finish reasons, the
    /// whole response will fail to deserialize — consider `#[serde(other)]`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
1174
    /// Citation information for a candidate (`citationMetadata`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Cited sources (`citationSources`); required when the metadata object is present.
        pub citation_sources: Vec<CitationSource>,
    }
1180
    /// A single cited source. Every field is optional and omitted from output when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment (`startIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1193
1194 #[derive(Clone, Debug, Deserialize, Serialize)]
1195 #[serde(rename_all = "camelCase")]
1196 pub struct LogprobsResult {
1197 pub top_candidate: Vec<TopCandidate>,
1198 pub chosen_candidate: Vec<LogProbCandidate>,
1199 }
1200
    /// The top candidate tokens for a single decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        pub candidates: Vec<LogProbCandidate>,
    }
1205
1206 #[derive(Clone, Debug, Deserialize, Serialize)]
1207 #[serde(rename_all = "camelCase")]
1208 pub struct LogProbCandidate {
1209 pub token: String,
1210 pub token_id: String,
1211 pub log_probability: f64,
1212 }
1213
    /// Generation settings sent as `generationConfig` (camelCase JSON).
    ///
    /// Every field is optional and omitted from the request when `None`. The `Default` impl
    /// in this module sets only `temperature` (1.0) and `max_output_tokens` (4096).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Sequences that stop generation.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated response (e.g. for structured output).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Schema the response should follow.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of candidates to generate.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Maximum number of output tokens; overridden by `CompletionRequest::max_tokens`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature; overridden by `CompletionRequest::temperature`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Nucleus-sampling probability mass.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling limit.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether to return log probabilities.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Number of top log probabilities to return per step.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking (reasoning) configuration.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
    }
1291
1292 impl Default for GenerationConfig {
1293 fn default() -> Self {
1294 Self {
1295 temperature: Some(1.0),
1296 max_output_tokens: Some(4096),
1297 stop_sequences: None,
1298 response_mime_type: None,
1299 response_schema: None,
1300 candidate_count: None,
1301 top_p: None,
1302 top_k: None,
1303 presence_penalty: None,
1304 frequency_penalty: None,
1305 response_logprobs: None,
1306 logprobs: None,
1307 thinking_config: None,
1308 }
1309 }
1310 }
1311
1312 #[derive(Debug, Deserialize, Serialize)]
1313 #[serde(rename_all = "camelCase")]
1314 pub struct ThinkingConfig {
1315 pub thinking_budget: u32,
1316 pub include_thoughts: Option<bool>,
1317 }
1318 #[derive(Debug, Deserialize, Serialize, Clone)]
1322 pub struct Schema {
1323 pub r#type: String,
1324 #[serde(skip_serializing_if = "Option::is_none")]
1325 pub format: Option<String>,
1326 #[serde(skip_serializing_if = "Option::is_none")]
1327 pub description: Option<String>,
1328 #[serde(skip_serializing_if = "Option::is_none")]
1329 pub nullable: Option<bool>,
1330 #[serde(skip_serializing_if = "Option::is_none")]
1331 pub r#enum: Option<Vec<String>>,
1332 #[serde(skip_serializing_if = "Option::is_none")]
1333 pub max_items: Option<i32>,
1334 #[serde(skip_serializing_if = "Option::is_none")]
1335 pub min_items: Option<i32>,
1336 #[serde(skip_serializing_if = "Option::is_none")]
1337 pub properties: Option<HashMap<String, Schema>>,
1338 #[serde(skip_serializing_if = "Option::is_none")]
1339 pub required: Option<Vec<String>>,
1340 #[serde(skip_serializing_if = "Option::is_none")]
1341 pub items: Option<Box<Schema>>,
1342 }
1343
1344 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1350 let defs = if let Some(obj) = schema.as_object() {
1352 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1353 } else {
1354 None
1355 };
1356
1357 let Some(defs_value) = defs else {
1358 return Ok(schema);
1359 };
1360
1361 let Some(defs_obj) = defs_value.as_object() else {
1362 return Err(CompletionError::ResponseError(
1363 "$defs must be an object".into(),
1364 ));
1365 };
1366
1367 resolve_refs(&mut schema, defs_obj)?;
1368
1369 if let Some(obj) = schema.as_object_mut() {
1371 obj.remove("$defs");
1372 obj.remove("definitions");
1373 }
1374
1375 Ok(schema)
1376 }
1377
1378 fn resolve_refs(
1381 value: &mut Value,
1382 defs: &serde_json::Map<String, Value>,
1383 ) -> Result<(), CompletionError> {
1384 match value {
1385 Value::Object(obj) => {
1386 if let Some(ref_value) = obj.get("$ref")
1387 && let Some(ref_str) = ref_value.as_str()
1388 {
1389 let def_name = parse_ref_path(ref_str)?;
1391
1392 let def = defs.get(&def_name).ok_or_else(|| {
1393 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1394 })?;
1395
1396 let mut resolved = def.clone();
1397 resolve_refs(&mut resolved, defs)?;
1398 *value = resolved;
1399 return Ok(());
1400 }
1401
1402 for (_, v) in obj.iter_mut() {
1403 resolve_refs(v, defs)?;
1404 }
1405 }
1406 Value::Array(arr) => {
1407 for item in arr.iter_mut() {
1408 resolve_refs(item, defs)?;
1409 }
1410 }
1411 _ => {}
1412 }
1413
1414 Ok(())
1415 }
1416
1417 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1423 if let Some(fragment) = ref_str.strip_prefix('#') {
1424 if let Some(name) = fragment.strip_prefix("/$defs/") {
1425 Ok(name.to_string())
1426 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1427 Ok(name.to_string())
1428 } else {
1429 Err(CompletionError::ResponseError(format!(
1430 "Unsupported reference format: {}",
1431 ref_str
1432 )))
1433 }
1434 } else {
1435 Err(CompletionError::ResponseError(format!(
1436 "Only fragment references (#/...) are supported: {}",
1437 ref_str
1438 )))
1439 }
1440 }
1441
1442 fn extract_type(type_value: &Value) -> Option<String> {
1445 if type_value.is_string() {
1446 type_value.as_str().map(String::from)
1447 } else if type_value.is_array() {
1448 type_value
1449 .as_array()
1450 .and_then(|arr| arr.first())
1451 .and_then(|v| v.as_str().map(String::from))
1452 } else {
1453 None
1454 }
1455 }
1456
1457 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1460 composition.as_array().and_then(|arr| {
1461 arr.iter().find_map(|schema| {
1462 if let Some(obj) = schema.as_object() {
1463 if let Some(type_val) = obj.get("type")
1465 && let Some(type_str) = type_val.as_str()
1466 && type_str == "null"
1467 {
1468 return None;
1469 }
1470 obj.get("type").and_then(extract_type).or_else(|| {
1472 if obj.contains_key("properties") {
1473 Some("object".to_string())
1474 } else {
1475 None
1476 }
1477 })
1478 } else {
1479 None
1480 }
1481 })
1482 })
1483 }
1484
1485 fn extract_schema_from_composition(
1488 composition: &Value,
1489 ) -> Option<serde_json::Map<String, Value>> {
1490 composition.as_array().and_then(|arr| {
1491 arr.iter().find_map(|schema| {
1492 if let Some(obj) = schema.as_object()
1493 && let Some(type_val) = obj.get("type")
1494 && let Some(type_str) = type_val.as_str()
1495 {
1496 if type_str == "null" {
1497 return None;
1498 }
1499 Some(obj.clone())
1500 } else {
1501 None
1502 }
1503 })
1504 })
1505 }
1506
1507 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1510 if let Some(type_val) = obj.get("type")
1512 && let Some(type_str) = extract_type(type_val)
1513 {
1514 return type_str;
1515 }
1516
1517 if let Some(any_of) = obj.get("anyOf")
1519 && let Some(type_str) = extract_type_from_composition(any_of)
1520 {
1521 return type_str;
1522 }
1523
1524 if let Some(one_of) = obj.get("oneOf")
1525 && let Some(type_str) = extract_type_from_composition(one_of)
1526 {
1527 return type_str;
1528 }
1529
1530 if let Some(all_of) = obj.get("allOf")
1531 && let Some(type_str) = extract_type_from_composition(all_of)
1532 {
1533 return type_str;
1534 }
1535
1536 if obj.contains_key("properties") {
1538 "object".to_string()
1539 } else {
1540 String::new()
1541 }
1542 }
1543
1544 impl TryFrom<Value> for Schema {
1545 type Error = CompletionError;
1546
1547 fn try_from(value: Value) -> Result<Self, Self::Error> {
1548 let flattened_val = flatten_schema(value)?;
1549 if let Some(obj) = flattened_val.as_object() {
1550 let props_source = if obj.get("properties").is_none() {
1553 if let Some(any_of) = obj.get("anyOf") {
1554 extract_schema_from_composition(any_of)
1555 } else if let Some(one_of) = obj.get("oneOf") {
1556 extract_schema_from_composition(one_of)
1557 } else if let Some(all_of) = obj.get("allOf") {
1558 extract_schema_from_composition(all_of)
1559 } else {
1560 None
1561 }
1562 .unwrap_or(obj.clone())
1563 } else {
1564 obj.clone()
1565 };
1566
1567 Ok(Schema {
1568 r#type: infer_type(obj),
1569 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1570 description: obj
1571 .get("description")
1572 .and_then(|v| v.as_str())
1573 .map(String::from),
1574 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1575 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1576 arr.iter()
1577 .filter_map(|v| v.as_str().map(String::from))
1578 .collect()
1579 }),
1580 max_items: obj
1581 .get("maxItems")
1582 .and_then(|v| v.as_i64())
1583 .map(|v| v as i32),
1584 min_items: obj
1585 .get("minItems")
1586 .and_then(|v| v.as_i64())
1587 .map(|v| v as i32),
1588 properties: props_source
1589 .get("properties")
1590 .and_then(|v| v.as_object())
1591 .map(|map| {
1592 map.iter()
1593 .filter_map(|(k, v)| {
1594 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1595 })
1596 .collect()
1597 }),
1598 required: props_source
1599 .get("required")
1600 .and_then(|v| v.as_array())
1601 .map(|arr| {
1602 arr.iter()
1603 .filter_map(|v| v.as_str().map(String::from))
1604 .collect()
1605 }),
1606 items: obj
1607 .get("items")
1608 .and_then(|v| v.clone().try_into().ok())
1609 .map(Box::new),
1610 })
1611 } else {
1612 Err(CompletionError::ResponseError(
1613 "Expected a JSON object for Schema".into(),
1614 ))
1615 }
1616 }
1617 }
1618
1619 #[derive(Debug, Serialize)]
1620 #[serde(rename_all = "camelCase")]
1621 pub struct GenerateContentRequest {
1622 pub contents: Vec<Content>,
1623 #[serde(skip_serializing_if = "Option::is_none")]
1624 pub tools: Option<Tool>,
1625 pub tool_config: Option<ToolConfig>,
1626 pub generation_config: Option<GenerationConfig>,
1628 pub safety_settings: Option<Vec<SafetySetting>>,
1642 pub system_instruction: Option<Content>,
1645 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1648 pub additional_params: Option<serde_json::Value>,
1649 }
1650
1651 #[derive(Debug, Serialize)]
1652 #[serde(rename_all = "camelCase")]
1653 pub struct Tool {
1654 pub function_declarations: Vec<FunctionDeclaration>,
1655 pub code_execution: Option<CodeExecution>,
1656 }
1657
1658 #[derive(Debug, Serialize, Clone)]
1659 #[serde(rename_all = "camelCase")]
1660 pub struct FunctionDeclaration {
1661 pub name: String,
1662 pub description: String,
1663 #[serde(skip_serializing_if = "Option::is_none")]
1664 pub parameters: Option<Schema>,
1665 }
1666
1667 #[derive(Debug, Serialize, Deserialize)]
1668 #[serde(rename_all = "camelCase")]
1669 pub struct ToolConfig {
1670 pub function_calling_config: Option<FunctionCallingMode>,
1671 }
1672
1673 #[derive(Debug, Serialize, Deserialize, Default)]
1674 #[serde(tag = "mode", rename_all = "UPPERCASE")]
1675 pub enum FunctionCallingMode {
1676 #[default]
1677 Auto,
1678 None,
1679 Any {
1680 #[serde(skip_serializing_if = "Option::is_none")]
1681 allowed_function_names: Option<Vec<String>>,
1682 },
1683 }
1684
1685 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1686 type Error = CompletionError;
1687 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1688 let res = match value {
1689 message::ToolChoice::Auto => Self::Auto,
1690 message::ToolChoice::None => Self::None,
1691 message::ToolChoice::Required => Self::Any {
1692 allowed_function_names: None,
1693 },
1694 message::ToolChoice::Specific { function_names } => Self::Any {
1695 allowed_function_names: Some(function_names),
1696 },
1697 };
1698
1699 Ok(res)
1700 }
1701 }
1702
1703 #[derive(Debug, Serialize)]
1704 pub struct CodeExecution {}
1705
1706 #[derive(Debug, Serialize)]
1707 #[serde(rename_all = "camelCase")]
1708 pub struct SafetySetting {
1709 pub category: HarmCategory,
1710 pub threshold: HarmBlockThreshold,
1711 }
1712
1713 #[derive(Debug, Serialize)]
1714 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1715 pub enum HarmBlockThreshold {
1716 HarmBlockThresholdUnspecified,
1717 BlockLowAndAbove,
1718 BlockMediumAndAbove,
1719 BlockOnlyHigh,
1720 BlockNone,
1721 Off,
1722 }
1723}
1724
#[cfg(test)]
mod tests {
    //! Tests for Gemini content (de)serialization, message conversion, and
    //! JSON Schema -> `Schema` conversion.

    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion error must fail the test instead of being printed and ignored.
        let schema: Schema = schema_with_ref
            .try_into()
            .expect("schema with a $ref into $defs should convert");

        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        let items = schema.items.expect("Schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // Missing items must fail the test rather than silently pass.
        let items = schema
            .items
            .expect("flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}