/// `gemini-2.5-pro-preview-06-05` completion model
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` completion model
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` completion model
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-05-20` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_05_20: &str = "gemini-2.5-flash-preview-05-20";
/// `gemini-2.5-flash-preview-04-17` completion model
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` experimental completion model
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.0-flash-lite` completion model
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model.
///
/// NOTE: Google publishes the 8B Gemini 1.5 variant as `gemini-1.5-flash-8b`
/// (see [`GEMINI_1_5_FLASH_8B`]). This constant keeps its original value so
/// existing callers are unaffected.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.5-flash-8b` completion model
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
29
30use self::gemini_api_types::Schema;
31use crate::message::Reasoning;
32use crate::providers::gemini::completion::gemini_api_types::{
33 AdditionalParameters, FunctionCallingMode, ToolConfig,
34};
35use crate::providers::gemini::streaming::StreamingCompletionResponse;
36use crate::telemetry::SpanCombinator;
37use crate::{
38 OneOrMany,
39 completion::{self, CompletionError, CompletionRequest},
40};
41use gemini_api_types::{
42 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
43 Role, Tool,
44};
45use serde_json::{Map, Value};
46use std::convert::TryFrom;
47use tracing::info_span;
48
49use super::Client;
50
51#[derive(Clone)]
56pub struct CompletionModel<T = reqwest::Client> {
57 pub(crate) client: Client<T>,
58 pub model: String,
59}
60
61impl<T> CompletionModel<T> {
62 pub fn new(client: Client<T>, model: &str) -> Self {
63 Self {
64 client,
65 model: model.to_string(),
66 }
67 }
68}
69
70impl completion::CompletionModel for CompletionModel<reqwest::Client> {
71 type Response = GenerateContentResponse;
72 type StreamingResponse = StreamingCompletionResponse;
73
74 #[cfg_attr(feature = "worker", worker::send)]
75 async fn completion(
76 &self,
77 completion_request: CompletionRequest,
78 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
79 let span = if tracing::Span::current().is_disabled() {
80 info_span!(
81 target: "rig::completions",
82 "generate_content",
83 gen_ai.operation.name = "generate_content",
84 gen_ai.provider.name = "gcp.gemini",
85 gen_ai.request.model = self.model,
86 gen_ai.system_instructions = &completion_request.preamble,
87 gen_ai.response.id = tracing::field::Empty,
88 gen_ai.response.model = tracing::field::Empty,
89 gen_ai.usage.output_tokens = tracing::field::Empty,
90 gen_ai.usage.input_tokens = tracing::field::Empty,
91 gen_ai.input.messages = tracing::field::Empty,
92 gen_ai.output.messages = tracing::field::Empty,
93 )
94 } else {
95 tracing::Span::current()
96 };
97
98 let request = create_request_body(completion_request)?;
99 span.record_model_input(&request.contents);
100
101 span.record_model_input(&request.contents);
102
103 tracing::debug!(
104 "Sending completion request to Gemini API {}",
105 serde_json::to_string_pretty(&request)?
106 );
107
108 let body = serde_json::to_vec(&request)?;
109
110 let request = self
111 .client
112 .post(&format!("/v1beta/models/{}:generateContent", self.model))
113 .header("Content-Type", "application/json")
114 .body(body)
115 .map_err(|e| CompletionError::HttpError(e.into()))?;
116
117 let response = self.client.send::<_, Vec<u8>>(request).await?;
118
119 if response.status().is_success() {
120 let response_body = response
121 .into_body()
122 .await
123 .map_err(CompletionError::HttpError)?;
124
125 let response: GenerateContentResponse = serde_json::from_slice(&response_body)?;
126
127 match response.usage_metadata {
128 Some(ref usage) => tracing::info!(target: "rig",
129 "Gemini completion token usage: {}",
130 usage
131 ),
132 None => tracing::info!(target: "rig",
133 "Gemini completion token usage: n/a",
134 ),
135 }
136
137 let span = tracing::Span::current();
138 span.record_model_output(&response.candidates);
139 span.record_response_metadata(&response);
140 span.record_token_usage(&response.usage_metadata);
141
142 tracing::debug!(
143 "Received response from Gemini API: {}",
144 serde_json::to_string_pretty(&response)?
145 );
146
147 response.try_into()
148 } else {
149 let text = String::from_utf8_lossy(
150 &response
151 .into_body()
152 .await
153 .map_err(CompletionError::HttpError)?,
154 )
155 .into();
156
157 Err(CompletionError::ProviderError(text))
158 }
159 }
160
161 #[cfg_attr(feature = "worker", worker::send)]
162 async fn stream(
163 &self,
164 request: CompletionRequest,
165 ) -> Result<
166 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
167 CompletionError,
168 > {
169 CompletionModel::stream(self, request).await
170 }
171}
172
173pub(crate) fn create_request_body(
174 completion_request: CompletionRequest,
175) -> Result<GenerateContentRequest, CompletionError> {
176 let mut full_history = Vec::new();
177 full_history.extend(completion_request.chat_history);
178
179 let additional_params = completion_request
180 .additional_params
181 .unwrap_or_else(|| Value::Object(Map::new()));
182
183 let AdditionalParameters {
184 mut generation_config,
185 additional_params,
186 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
187
188 if let Some(temp) = completion_request.temperature {
189 generation_config.temperature = Some(temp);
190 }
191
192 if let Some(max_tokens) = completion_request.max_tokens {
193 generation_config.max_output_tokens = Some(max_tokens);
194 }
195
196 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
197 parts: vec![preamble.into()],
198 role: Some(Role::Model),
199 });
200
201 let tools = if completion_request.tools.is_empty() {
202 None
203 } else {
204 Some(Tool::try_from(completion_request.tools)?)
205 };
206
207 let tool_config = if let Some(cfg) = completion_request.tool_choice {
208 Some(ToolConfig {
209 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
210 })
211 } else {
212 None
213 };
214
215 let request = GenerateContentRequest {
216 contents: full_history
217 .into_iter()
218 .map(|msg| {
219 msg.try_into()
220 .map_err(|e| CompletionError::RequestError(Box::new(e)))
221 })
222 .collect::<Result<Vec<_>, _>>()?,
223 generation_config: Some(generation_config),
224 safety_settings: None,
225 tools,
226 tool_config,
227 system_instruction,
228 additional_params,
229 };
230
231 Ok(request)
232}
233
234impl TryFrom<completion::ToolDefinition> for Tool {
235 type Error = CompletionError;
236
237 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
238 let parameters: Option<Schema> =
239 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
240 None
241 } else {
242 Some(tool.parameters.try_into()?)
243 };
244
245 Ok(Self {
246 function_declarations: vec![FunctionDeclaration {
247 name: tool.name,
248 description: tool.description,
249 parameters,
250 }],
251 code_execution: None,
252 })
253 }
254}
255
256impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
257 type Error = CompletionError;
258
259 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
260 let mut function_declarations = Vec::new();
261
262 for tool in tools {
263 let parameters =
264 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
265 None
266 } else {
267 match tool.parameters.try_into() {
268 Ok(schema) => Some(schema),
269 Err(e) => {
270 let emsg = format!(
271 "Tool '{}' could not be converted to a schema: {:?}",
272 tool.name, e,
273 );
274 return Err(CompletionError::ProviderError(emsg));
275 }
276 }
277 };
278
279 function_declarations.push(FunctionDeclaration {
280 name: tool.name,
281 description: tool.description,
282 parameters,
283 });
284 }
285
286 Ok(Self {
287 function_declarations,
288 code_execution: None,
289 })
290 }
291}
292
293impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
294 type Error = CompletionError;
295
296 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
297 let candidate = response.candidates.first().ok_or_else(|| {
298 CompletionError::ResponseError("No response candidates in response".into())
299 })?;
300
301 let content = candidate
302 .content
303 .parts
304 .iter()
305 .map(|Part { thought, part, .. }| {
306 Ok(match part {
307 PartKind::Text(text) => {
308 if let Some(thought) = thought
309 && *thought
310 {
311 completion::AssistantContent::Reasoning(Reasoning::new(text))
312 } else {
313 completion::AssistantContent::text(text)
314 }
315 }
316 PartKind::FunctionCall(function_call) => {
317 completion::AssistantContent::tool_call(
318 &function_call.name,
319 &function_call.name,
320 function_call.args.clone(),
321 )
322 }
323 _ => {
324 return Err(CompletionError::ResponseError(
325 "Response did not contain a message or tool call".into(),
326 ));
327 }
328 })
329 })
330 .collect::<Result<Vec<_>, _>>()?;
331
332 let choice = OneOrMany::many(content).map_err(|_| {
333 CompletionError::ResponseError(
334 "Response contained no message or tool call (empty)".to_owned(),
335 )
336 })?;
337
338 let usage = response
339 .usage_metadata
340 .as_ref()
341 .map(|usage| completion::Usage {
342 input_tokens: usage.prompt_token_count as u64,
343 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
344 total_tokens: usage.total_token_count as u64,
345 })
346 .unwrap_or_default();
347
348 Ok(completion::CompletionResponse {
349 choice,
350 usage,
351 raw_response: response,
352 })
353 }
354}
355
356pub mod gemini_api_types {
357 use crate::telemetry::ProviderResponseExt;
358 use std::{collections::HashMap, convert::Infallible, str::FromStr};
359
360 use serde::{Deserialize, Serialize};
364 use serde_json::{Value, json};
365
366 use crate::completion::GetTokenUsage;
367 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
368 use crate::{
369 OneOrMany,
370 completion::CompletionError,
371 message::{self, Reasoning, Text},
372 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
373 };
374
375 #[derive(Debug, Deserialize, Serialize, Default)]
376 #[serde(rename_all = "camelCase")]
377 pub struct AdditionalParameters {
378 pub generation_config: GenerationConfig,
380 #[serde(flatten, skip_serializing_if = "Option::is_none")]
382 pub additional_params: Option<serde_json::Value>,
383 }
384
385 impl AdditionalParameters {
386 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
387 self.generation_config = cfg;
388 self
389 }
390
391 pub fn with_params(mut self, params: serde_json::Value) -> Self {
392 self.additional_params = Some(params);
393 self
394 }
395 }
396
397 #[derive(Debug, Deserialize, Serialize)]
405 #[serde(rename_all = "camelCase")]
406 pub struct GenerateContentResponse {
407 pub response_id: String,
408 pub candidates: Vec<ContentCandidate>,
410 pub prompt_feedback: Option<PromptFeedback>,
412 pub usage_metadata: Option<UsageMetadata>,
414 pub model_version: Option<String>,
415 }
416
417 impl ProviderResponseExt for GenerateContentResponse {
418 type OutputMessage = ContentCandidate;
419 type Usage = UsageMetadata;
420
421 fn get_response_id(&self) -> Option<String> {
422 Some(self.response_id.clone())
423 }
424
425 fn get_response_model_name(&self) -> Option<String> {
426 None
427 }
428
429 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
430 self.candidates.clone()
431 }
432
433 fn get_text_response(&self) -> Option<String> {
434 let str = self
435 .candidates
436 .iter()
437 .filter_map(|x| {
438 if x.content.role.as_ref().is_none_or(|y| y != &Role::Model) {
439 return None;
440 }
441
442 let res = x
443 .content
444 .parts
445 .iter()
446 .filter_map(|part| {
447 if let PartKind::Text(ref str) = part.part {
448 Some(str.to_owned())
449 } else {
450 None
451 }
452 })
453 .collect::<Vec<String>>()
454 .join("\n");
455
456 Some(res)
457 })
458 .collect::<Vec<String>>()
459 .join("\n");
460
461 if str.is_empty() { None } else { Some(str) }
462 }
463
464 fn get_usage(&self) -> Option<Self::Usage> {
465 self.usage_metadata.clone()
466 }
467 }
468
469 #[derive(Clone, Debug, Deserialize, Serialize)]
471 #[serde(rename_all = "camelCase")]
472 pub struct ContentCandidate {
473 pub content: Content,
475 pub finish_reason: Option<FinishReason>,
478 pub safety_ratings: Option<Vec<SafetyRating>>,
481 pub citation_metadata: Option<CitationMetadata>,
485 pub token_count: Option<i32>,
487 pub avg_logprobs: Option<f64>,
489 pub logprobs_result: Option<LogprobsResult>,
491 pub index: Option<i32>,
493 }
494
    /// A unit of conversation content: an ordered list of parts plus an optional author role.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Ordered parts making up the content; defaults to empty when the key is absent.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the content. A missing role is treated as a user message when
        /// converting into a `message::Message`.
        pub role: Option<Role>,
    }
504
505 impl TryFrom<message::Message> for Content {
506 type Error = message::MessageError;
507
508 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
509 Ok(match msg {
510 message::Message::User { content } => Content {
511 parts: content
512 .into_iter()
513 .map(|c| c.try_into())
514 .collect::<Result<Vec<_>, _>>()?,
515 role: Some(Role::User),
516 },
517 message::Message::Assistant { content, .. } => Content {
518 role: Some(Role::Model),
519 parts: content.into_iter().map(|content| content.into()).collect(),
520 },
521 })
522 }
523 }
524
    impl TryFrom<Content> for message::Message {
        type Error = message::MessageError;

        /// Converts Gemini content back into a rig chat message.
        ///
        /// - `Role::User` or a missing role yields a user message. Text parts become
        ///   text. Inline data becomes an image, document or audio item depending on how
        ///   its MIME type is classified by `MediaType::from_mime_type`. Any other part
        ///   kind or media type is a `ConversionError`.
        /// - `Role::Model` yields an assistant message with `id: None`. Text parts with
        ///   `thought == Some(true)` become reasoning, other text becomes text, and
        ///   function calls become tool calls. Any other part kind is a `ConversionError`.
        ///
        /// Also errors when the converted parts cannot form a `OneOrMany` (presumably
        /// when the content has no parts).
        fn try_from(content: Content) -> Result<Self, Self::Error> {
            match content.role {
                Some(Role::User) | None => {
                    Ok(message::Message::User {
                        content: {
                            let user_content: Result<Vec<_>, _> = content.parts.into_iter()
                                .map(|Part { part, .. }| {
                                    Ok(match part {
                                        PartKind::Text(text) => message::UserContent::text(text),
                                        PartKind::InlineData(inline_data) => {
                                            // Classify the blob by its MIME type to pick the user content kind.
                                            let mime_type =
                                                message::MediaType::from_mime_type(&inline_data.mime_type);

                                            match mime_type {
                                                Some(message::MediaType::Image(media_type)) => {
                                                    message::UserContent::image_base64(
                                                        inline_data.data,
                                                        Some(media_type),
                                                        Some(message::ImageDetail::default()),
                                                    )
                                                }
                                                Some(message::MediaType::Document(media_type)) => {
                                                    message::UserContent::document(
                                                        inline_data.data,
                                                        Some(media_type),
                                                    )
                                                }
                                                Some(message::MediaType::Audio(media_type)) => {
                                                    message::UserContent::audio(
                                                        inline_data.data,
                                                        Some(media_type),
                                                    )
                                                }
                                                // Unknown MIME type, or a media kind with no user-content mapping here.
                                                _ => {
                                                    return Err(message::MessageError::ConversionError(
                                                        format!("Unsupported media type {mime_type:?}"),
                                                    ));
                                                }
                                            }
                                        }
                                        // e.g. function responses or file references are not mapped back.
                                        _ => {
                                            return Err(message::MessageError::ConversionError(format!(
                                                "Unsupported gemini content part type: {part:?}"
                                            )));
                                        }
                                    })
                                })
                                .collect();
                            OneOrMany::many(user_content?).map_err(|_| {
                                message::MessageError::ConversionError(
                                    "Failed to create OneOrMany from user content".to_string(),
                                )
                            })?
                        },
                    })
                }
                Some(Role::Model) => Ok(message::Message::Assistant {
                    // Gemini content carries no message id.
                    id: None,
                    content: {
                        let assistant_content: Result<Vec<_>, _> = content
                            .parts
                            .into_iter()
                            .map(|Part { thought, part, .. }| {
                                Ok(match part {
                                    // `thought == Some(true)` marks model reasoning rather than answer text.
                                    PartKind::Text(text) => match thought {
                                        Some(true) => message::AssistantContent::Reasoning(
                                            Reasoning::new(&text),
                                        ),
                                        _ => message::AssistantContent::Text(Text { text }),
                                    },

                                    PartKind::FunctionCall(function_call) => {
                                        message::AssistantContent::ToolCall(function_call.into())
                                    }
                                    _ => {
                                        return Err(message::MessageError::ConversionError(
                                            format!("Unsupported part type: {part:?}"),
                                        ));
                                    }
                                })
                            })
                            .collect();
                        OneOrMany::many(assistant_content?).map_err(|_| {
                            message::MessageError::ConversionError(
                                "Failed to create OneOrMany from assistant content".to_string(),
                            )
                        })?
                    },
                }),
            }
        }
    }
620
    /// Author of a [`Content`] entry, serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content supplied by the user side (rig user messages, including tool results).
        User,
        /// Content produced by the model (rig assistant messages).
        Model,
    }
627
    /// One part of a [`Content`]: the payload plus optional thought metadata.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether this part is model reasoning ("thought"); omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Opaque thought signature (`thoughtSignature`); omitted when `None`.
        /// NOTE(review): the conversions visible in this module always set it to `None`;
        /// confirm whether signatures received from the API must be echoed back.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The payload. Flattened, so the variant's key (e.g. `text`) sits directly in
        /// the part object.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra keys merged (flattened) into the serialized part object.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
642
    /// Payload of a [`Part`].
    ///
    /// Serialized externally tagged with camelCase variant names (e.g. `{"text": ...}`,
    /// `{"functionCall": {...}}`); `Part` flattens this into the part object.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text (also used for thought parts).
        Text(String),
        /// Inline data with a MIME type.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// A function result sent back to the model.
        FunctionResponse(FunctionResponse),
        /// A reference to file content by URI.
        FileData(FileData),
        /// Executable code (type defined in the parent `gemini` module).
        ExecutableCode(ExecutableCode),
        /// Code execution output (type defined in the parent `gemini` module).
        CodeExecutionResult(CodeExecutionResult),
    }
657
658 impl Default for PartKind {
661 fn default() -> Self {
662 Self::Text(String::new())
663 }
664 }
665
666 impl From<String> for Part {
667 fn from(text: String) -> Self {
668 Self {
669 thought: Some(false),
670 thought_signature: None,
671 part: PartKind::Text(text),
672 additional_params: None,
673 }
674 }
675 }
676
677 impl From<&str> for Part {
678 fn from(text: &str) -> Self {
679 Self::from(text.to_string())
680 }
681 }
682
683 impl FromStr for Part {
684 type Err = Infallible;
685
686 fn from_str(s: &str) -> Result<Self, Self::Err> {
687 Ok(s.into())
688 }
689 }
690
691 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
692 type Error = message::MessageError;
693 fn try_from(
694 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
695 ) -> Result<Self, Self::Error> {
696 let mime_type = mime_type.to_mime_type().to_string();
697 let part = match doc_src {
698 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
699 mime_type: Some(mime_type),
700 file_uri: url,
701 }),
702 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
703 PartKind::InlineData(Blob { mime_type, data })
704 }
705 DocumentSourceKind::Raw(_) => {
706 return Err(message::MessageError::ConversionError(
707 "Raw files not supported, encode as base64 first".into(),
708 ));
709 }
710 DocumentSourceKind::Unknown => {
711 return Err(message::MessageError::ConversionError(
712 "Can't convert an unknown document source".to_string(),
713 ));
714 }
715 };
716
717 Ok(part)
718 }
719 }
720
721 impl TryFrom<message::UserContent> for Part {
722 type Error = message::MessageError;
723
724 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
725 match content {
726 message::UserContent::Text(message::Text { text }) => Ok(Part {
727 thought: Some(false),
728 thought_signature: None,
729 part: PartKind::Text(text),
730 additional_params: None,
731 }),
732 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
733 let content = match content.first() {
734 message::ToolResultContent::Text(text) => text.text,
735 message::ToolResultContent::Image(_) => {
736 return Err(message::MessageError::ConversionError(
737 "Tool result content must be text".to_string(),
738 ));
739 }
740 };
741 let result: serde_json::Value =
743 serde_json::from_str(&content).unwrap_or_else(|error| {
744 tracing::trace!(
745 ?error,
746 "Tool result is not a valid JSON, treat it as normal string"
747 );
748 json!(content)
749 });
750 Ok(Part {
751 thought: Some(false),
752 thought_signature: None,
753 part: PartKind::FunctionResponse(FunctionResponse {
754 name: id,
755 response: Some(json!({ "result": result })),
756 }),
757 additional_params: None,
758 })
759 }
760 message::UserContent::Image(message::Image {
761 data, media_type, ..
762 }) => match media_type {
763 Some(media_type) => match media_type {
764 message::ImageMediaType::JPEG
765 | message::ImageMediaType::PNG
766 | message::ImageMediaType::WEBP
767 | message::ImageMediaType::HEIC
768 | message::ImageMediaType::HEIF => {
769 let part = PartKind::try_from((media_type, data))?;
770 Ok(Part {
771 thought: Some(false),
772 thought_signature: None,
773 part,
774 additional_params: None,
775 })
776 }
777 _ => Err(message::MessageError::ConversionError(format!(
778 "Unsupported image media type {media_type:?}"
779 ))),
780 },
781 None => Err(message::MessageError::ConversionError(
782 "Media type for image is required for Gemini".to_string(),
783 )),
784 },
785 message::UserContent::Document(message::Document {
786 data, media_type, ..
787 }) => {
788 let Some(media_type) = media_type else {
789 return Err(MessageError::ConversionError(
790 "A mime type is required for document inputs to Gemini".to_string(),
791 ));
792 };
793
794 if !media_type.is_code() {
795 let mime_type = media_type.to_mime_type().to_string();
796
797 let part = match data {
798 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
799 mime_type: Some(mime_type),
800 file_uri,
801 }),
802 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
803 PartKind::InlineData(Blob { mime_type, data })
804 }
805 DocumentSourceKind::Raw(_) => {
806 return Err(message::MessageError::ConversionError(
807 "Raw files not supported, encode as base64 first".into(),
808 ));
809 }
810 _ => {
811 return Err(message::MessageError::ConversionError(
812 "Document has no body".to_string(),
813 ));
814 }
815 };
816
817 Ok(Part {
818 thought: Some(false),
819 part,
820 ..Default::default()
821 })
822 } else {
823 Err(message::MessageError::ConversionError(format!(
824 "Unsupported document media type {media_type:?}"
825 )))
826 }
827 }
828
829 message::UserContent::Audio(message::Audio {
830 data, media_type, ..
831 }) => {
832 let Some(media_type) = media_type else {
833 return Err(MessageError::ConversionError(
834 "A mime type is required for audio inputs to Gemini".to_string(),
835 ));
836 };
837
838 let mime_type = media_type.to_mime_type().to_string();
839
840 let part = match data {
841 DocumentSourceKind::Base64(data) => {
842 PartKind::InlineData(Blob { data, mime_type })
843 }
844
845 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
846 mime_type: Some(mime_type),
847 file_uri,
848 }),
849 DocumentSourceKind::String(_) => {
850 return Err(message::MessageError::ConversionError(
851 "Strings cannot be used as audio files!".into(),
852 ));
853 }
854 DocumentSourceKind::Raw(_) => {
855 return Err(message::MessageError::ConversionError(
856 "Raw files not supported, encode as base64 first".into(),
857 ));
858 }
859 DocumentSourceKind::Unknown => {
860 return Err(message::MessageError::ConversionError(
861 "Content has no body".to_string(),
862 ));
863 }
864 };
865
866 Ok(Part {
867 thought: Some(false),
868 part,
869 ..Default::default()
870 })
871 }
872 message::UserContent::Video(message::Video {
873 data,
874 media_type,
875 additional_params,
876 ..
877 }) => {
878 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
879
880 let part = match data {
881 DocumentSourceKind::Url(file_uri) => {
882 if file_uri.starts_with("https://www.youtube.com") {
883 PartKind::FileData(FileData {
884 mime_type,
885 file_uri,
886 })
887 } else {
888 if mime_type.is_none() {
889 return Err(MessageError::ConversionError(
890 "A mime type is required for non-Youtube video file inputs to Gemini"
891 .to_string(),
892 ));
893 }
894
895 PartKind::FileData(FileData {
896 mime_type,
897 file_uri,
898 })
899 }
900 }
901 DocumentSourceKind::Base64(data) => {
902 let Some(mime_type) = mime_type else {
903 return Err(MessageError::ConversionError(
904 "A media type is expected for base64 encoded strings"
905 .to_string(),
906 ));
907 };
908 PartKind::InlineData(Blob { mime_type, data })
909 }
910 DocumentSourceKind::String(_) => {
911 return Err(message::MessageError::ConversionError(
912 "Strings cannot be used as audio files!".into(),
913 ));
914 }
915 DocumentSourceKind::Raw(_) => {
916 return Err(message::MessageError::ConversionError(
917 "Raw file data not supported, encode as base64 first".into(),
918 ));
919 }
920 DocumentSourceKind::Unknown => {
921 return Err(message::MessageError::ConversionError(
922 "Media type for video is required for Gemini".to_string(),
923 ));
924 }
925 };
926
927 Ok(Part {
928 thought: Some(false),
929 thought_signature: None,
930 part,
931 additional_params,
932 })
933 }
934 }
935 }
936 }
937
938 impl From<message::AssistantContent> for Part {
939 fn from(content: message::AssistantContent) -> Self {
940 match content {
941 message::AssistantContent::Text(message::Text { text }) => text.into(),
942 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
943 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
944 Part {
945 thought: Some(true),
946 thought_signature: None,
947 part: PartKind::Text(
948 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
949 ),
950 additional_params: None,
951 }
952 }
953 }
954 }
955 }
956
957 impl From<message::ToolCall> for Part {
958 fn from(tool_call: message::ToolCall) -> Self {
959 Self {
960 thought: Some(false),
961 thought_signature: None,
962 part: PartKind::FunctionCall(FunctionCall {
963 name: tool_call.function.name,
964 args: tool_call.function.arguments,
965 }),
966 additional_params: None,
967 }
968 }
969 }
970
    /// Inline data together with its MIME type.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of the data (`mimeType`).
        pub mime_type: String,
        /// The data as a string; the conversions in this module store the caller's
        /// base64 (or plain string) content here unchanged.
        pub data: String,
    }
982
    /// A function call requested by the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Call arguments as arbitrary JSON.
        pub args: serde_json::Value,
    }
993
994 impl From<FunctionCall> for message::ToolCall {
995 fn from(function_call: FunctionCall) -> Self {
996 Self {
997 id: function_call.name.clone(),
998 call_id: None,
999 function: message::ToolFunction {
1000 name: function_call.name,
1001 arguments: function_call.args,
1002 },
1003 }
1004 }
1005 }
1006
1007 impl From<message::ToolCall> for FunctionCall {
1008 fn from(tool_call: message::ToolCall) -> Self {
1009 Self {
1010 name: tool_call.function.name,
1011 args: tool_call.function.arguments,
1012 }
1013 }
1014 }
1015
    /// The result of a function call, returned to the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Function name. The tool-result conversion fills this with the rig tool-call id,
        /// which equals the function name for calls that originated from Gemini.
        pub name: String,
        /// Response payload; the tool-result conversion wraps the result as `{"result": ...}`.
        pub response: Option<serde_json::Value>,
    }
1027
    /// A reference to file content by URI.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type (`mimeType`), if known. There is no `skip_serializing_if`, so `None`
        /// is serialized as `null`.
        pub mime_type: Option<String>,
        /// URI of the file (`fileUri`).
        pub file_uri: String,
    }
1037
    /// A safety rating: a harm category paired with its assessed probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category being rated.
        pub category: HarmCategory,
        /// Probability level reported for that category.
        pub probability: HarmProbability,
    }
1043
    /// Probability level in a [`SafetyRating`], serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `NEGLIGIBLE`). Values not listed here fail to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1053
    /// Harm category in a [`SafetyRating`], serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `HARM_CATEGORY_HATE_SPEECH`). There is no catch-all variant, so a category
    /// not listed here fails to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1070
    /// Token accounting reported with a response (`usageMetadata`).
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Tokens attributed to cached content (`cachedContentTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates (`candidatesTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total tokens reported for the request (`totalTokenCount`).
        pub total_token_count: i32,
        /// Tokens spent on thinking (`thoughtsTokenCount`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1083
1084 impl std::fmt::Display for UsageMetadata {
1085 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1086 write!(
1087 f,
1088 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1089 self.prompt_token_count,
1090 match self.cached_content_token_count {
1091 Some(count) => count.to_string(),
1092 None => "n/a".to_string(),
1093 },
1094 match self.candidates_token_count {
1095 Some(count) => count.to_string(),
1096 None => "n/a".to_string(),
1097 },
1098 self.total_token_count
1099 )
1100 }
1101 }
1102
1103 impl GetTokenUsage for UsageMetadata {
1104 fn token_usage(&self) -> Option<crate::completion::Usage> {
1105 let mut usage = crate::completion::Usage::new();
1106
1107 usage.input_tokens = self.prompt_token_count as u64;
1108 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1109 + self.candidates_token_count.unwrap_or_default()
1110 + self.thoughts_token_count.unwrap_or_default())
1111 as u64;
1112 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1113
1114 Some(usage)
1115 }
1116 }
1117
    /// Feedback about the prompt (`promptFeedback`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked (`blockReason`), if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1127
    /// Reason a prompt was blocked, serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `PROHIBITED_CONTENT`). Values not listed here fail to deserialize.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1143
    /// Why a candidate stopped generating, serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `MAX_TOKENS`). There is no catch-all variant, so an unlisted value makes the
    /// enclosing candidate fail to deserialize.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
1170
    /// Citation information attached to a candidate.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Cited sources (`citationSources`).
        pub citation_sources: Vec<CitationSource>,
    }
1176
    /// A single cited source; every field is optional and omitted from JSON when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the source, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start index of the cited segment (`startIndex`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End index of the cited segment (`endIndex`), if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1189
1190 #[derive(Clone, Debug, Deserialize, Serialize)]
1191 #[serde(rename_all = "camelCase")]
1192 pub struct LogprobsResult {
1193 pub top_candidate: Vec<TopCandidate>,
1194 pub chosen_candidate: Vec<LogProbCandidate>,
1195 }
1196
    /// The top-ranked candidate tokens for one decoding step.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidate tokens with their log probabilities.
        pub candidates: Vec<LogProbCandidate>,
    }
1201
1202 #[derive(Clone, Debug, Deserialize, Serialize)]
1203 #[serde(rename_all = "camelCase")]
1204 pub struct LogProbCandidate {
1205 pub token: String,
1206 pub token_id: String,
1207 pub log_probability: f64,
1208 }
1209
    /// Generation options for a request, read from the `generationConfig` key of
    /// `additional_params` (see `AdditionalParameters`).
    ///
    /// Field names are serialized in camelCase (e.g. `max_output_tokens` →
    /// `maxOutputTokens`), and every field is omitted from the JSON when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Stop sequences (`stopSequences`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// Requested MIME type of the response (`responseMimeType`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Schema for the response (`responseSchema`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Number of candidates to generate (`candidateCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Maximum output tokens (`maxOutputTokens`). Overridden by the request's
        /// `max_tokens` in `create_request_body` when that is set.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature. Overridden by the request's `temperature` in
        /// `create_request_body` when that is set.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Nucleus sampling parameter (`topP`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling parameter (`topK`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty (`presencePenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty (`frequencyPenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether to return log probabilities (`responseLogprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Number of top log probabilities to return (`logprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking options (`thinkingConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
    }
1287
1288 impl Default for GenerationConfig {
1289 fn default() -> Self {
1290 Self {
1291 temperature: Some(1.0),
1292 max_output_tokens: Some(4096),
1293 stop_sequences: None,
1294 response_mime_type: None,
1295 response_schema: None,
1296 candidate_count: None,
1297 top_p: None,
1298 top_k: None,
1299 presence_penalty: None,
1300 frequency_penalty: None,
1301 response_logprobs: None,
1302 logprobs: None,
1303 thinking_config: None,
1304 }
1305 }
1306 }
1307
1308 #[derive(Debug, Deserialize, Serialize)]
1309 #[serde(rename_all = "camelCase")]
1310 pub struct ThinkingConfig {
1311 pub thinking_budget: u32,
1312 pub include_thoughts: Option<bool>,
1313 }
    /// Gemini's schema representation, used for function parameters
    /// ([`FunctionDeclaration::parameters`]) and structured output
    /// ([`GenerationConfig::response_schema`]).
    ///
    /// Usually built from a JSON Schema document via `TryFrom<Value>`, which inlines `$ref`s
    /// first because this type has no field for references.
    ///
    /// NOTE(review): unlike the other request types here there is no `rename_all =
    /// "camelCase"`, so `max_items`/`min_items` serialize in snake_case — confirm the API
    /// accepts these names.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name as produced by `infer_type`; may be the empty string when no type could be
        /// determined.
        pub r#type: String,
        /// Optional format hint, copied from the JSON Schema `format` keyword.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether `null` is an allowed value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values, if the schema is an enumeration.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Object properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required object properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Element schema for arrays.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1339
1340 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1346 let defs = if let Some(obj) = schema.as_object() {
1348 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1349 } else {
1350 None
1351 };
1352
1353 let Some(defs_value) = defs else {
1354 return Ok(schema);
1355 };
1356
1357 let Some(defs_obj) = defs_value.as_object() else {
1358 return Err(CompletionError::ResponseError(
1359 "$defs must be an object".into(),
1360 ));
1361 };
1362
1363 resolve_refs(&mut schema, defs_obj)?;
1364
1365 if let Some(obj) = schema.as_object_mut() {
1367 obj.remove("$defs");
1368 obj.remove("definitions");
1369 }
1370
1371 Ok(schema)
1372 }
1373
1374 fn resolve_refs(
1377 value: &mut Value,
1378 defs: &serde_json::Map<String, Value>,
1379 ) -> Result<(), CompletionError> {
1380 match value {
1381 Value::Object(obj) => {
1382 if let Some(ref_value) = obj.get("$ref")
1383 && let Some(ref_str) = ref_value.as_str()
1384 {
1385 let def_name = parse_ref_path(ref_str)?;
1387
1388 let def = defs.get(&def_name).ok_or_else(|| {
1389 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1390 })?;
1391
1392 let mut resolved = def.clone();
1393 resolve_refs(&mut resolved, defs)?;
1394 *value = resolved;
1395 return Ok(());
1396 }
1397
1398 for (_, v) in obj.iter_mut() {
1399 resolve_refs(v, defs)?;
1400 }
1401 }
1402 Value::Array(arr) => {
1403 for item in arr.iter_mut() {
1404 resolve_refs(item, defs)?;
1405 }
1406 }
1407 _ => {}
1408 }
1409
1410 Ok(())
1411 }
1412
1413 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1419 if let Some(fragment) = ref_str.strip_prefix('#') {
1420 if let Some(name) = fragment.strip_prefix("/$defs/") {
1421 Ok(name.to_string())
1422 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1423 Ok(name.to_string())
1424 } else {
1425 Err(CompletionError::ResponseError(format!(
1426 "Unsupported reference format: {}",
1427 ref_str
1428 )))
1429 }
1430 } else {
1431 Err(CompletionError::ResponseError(format!(
1432 "Only fragment references (#/...) are supported: {}",
1433 ref_str
1434 )))
1435 }
1436 }
1437
1438 fn extract_type(type_value: &Value) -> Option<String> {
1441 if type_value.is_string() {
1442 type_value.as_str().map(String::from)
1443 } else if type_value.is_array() {
1444 type_value
1445 .as_array()
1446 .and_then(|arr| arr.first())
1447 .and_then(|v| v.as_str().map(String::from))
1448 } else {
1449 None
1450 }
1451 }
1452
1453 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1456 composition.as_array().and_then(|arr| {
1457 arr.iter().find_map(|schema| {
1458 if let Some(obj) = schema.as_object() {
1459 if let Some(type_val) = obj.get("type")
1461 && let Some(type_str) = type_val.as_str()
1462 && type_str == "null"
1463 {
1464 return None;
1465 }
1466 obj.get("type").and_then(extract_type).or_else(|| {
1468 if obj.contains_key("properties") {
1469 Some("object".to_string())
1470 } else {
1471 None
1472 }
1473 })
1474 } else {
1475 None
1476 }
1477 })
1478 })
1479 }
1480
1481 fn extract_schema_from_composition(
1484 composition: &Value,
1485 ) -> Option<serde_json::Map<String, Value>> {
1486 composition.as_array().and_then(|arr| {
1487 arr.iter().find_map(|schema| {
1488 if let Some(obj) = schema.as_object()
1489 && let Some(type_val) = obj.get("type")
1490 && let Some(type_str) = type_val.as_str()
1491 {
1492 if type_str == "null" {
1493 return None;
1494 }
1495 Some(obj.clone())
1496 } else {
1497 None
1498 }
1499 })
1500 })
1501 }
1502
1503 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1506 if let Some(type_val) = obj.get("type")
1508 && let Some(type_str) = extract_type(type_val)
1509 {
1510 return type_str;
1511 }
1512
1513 if let Some(any_of) = obj.get("anyOf")
1515 && let Some(type_str) = extract_type_from_composition(any_of)
1516 {
1517 return type_str;
1518 }
1519
1520 if let Some(one_of) = obj.get("oneOf")
1521 && let Some(type_str) = extract_type_from_composition(one_of)
1522 {
1523 return type_str;
1524 }
1525
1526 if let Some(all_of) = obj.get("allOf")
1527 && let Some(type_str) = extract_type_from_composition(all_of)
1528 {
1529 return type_str;
1530 }
1531
1532 if obj.contains_key("properties") {
1534 "object".to_string()
1535 } else {
1536 String::new()
1537 }
1538 }
1539
1540 impl TryFrom<Value> for Schema {
1541 type Error = CompletionError;
1542
1543 fn try_from(value: Value) -> Result<Self, Self::Error> {
1544 let flattened_val = flatten_schema(value)?;
1545 if let Some(obj) = flattened_val.as_object() {
1546 let props_source = if obj.get("properties").is_none() {
1549 if let Some(any_of) = obj.get("anyOf") {
1550 extract_schema_from_composition(any_of)
1551 } else if let Some(one_of) = obj.get("oneOf") {
1552 extract_schema_from_composition(one_of)
1553 } else if let Some(all_of) = obj.get("allOf") {
1554 extract_schema_from_composition(all_of)
1555 } else {
1556 None
1557 }
1558 .unwrap_or(obj.clone())
1559 } else {
1560 obj.clone()
1561 };
1562
1563 Ok(Schema {
1564 r#type: infer_type(obj),
1565 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1566 description: obj
1567 .get("description")
1568 .and_then(|v| v.as_str())
1569 .map(String::from),
1570 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1571 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1572 arr.iter()
1573 .filter_map(|v| v.as_str().map(String::from))
1574 .collect()
1575 }),
1576 max_items: obj
1577 .get("maxItems")
1578 .and_then(|v| v.as_i64())
1579 .map(|v| v as i32),
1580 min_items: obj
1581 .get("minItems")
1582 .and_then(|v| v.as_i64())
1583 .map(|v| v as i32),
1584 properties: props_source
1585 .get("properties")
1586 .and_then(|v| v.as_object())
1587 .map(|map| {
1588 map.iter()
1589 .filter_map(|(k, v)| {
1590 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1591 })
1592 .collect()
1593 }),
1594 required: props_source
1595 .get("required")
1596 .and_then(|v| v.as_array())
1597 .map(|arr| {
1598 arr.iter()
1599 .filter_map(|v| v.as_str().map(String::from))
1600 .collect()
1601 }),
1602 items: obj
1603 .get("items")
1604 .and_then(|v| v.clone().try_into().ok())
1605 .map(Box::new),
1606 })
1607 } else {
1608 Err(CompletionError::ResponseError(
1609 "Expected a JSON object for Schema".into(),
1610 ))
1611 }
1612 }
1613 }
1614
    /// Top-level request payload for Gemini content generation.
    ///
    /// Field names are serialized in camelCase.
    ///
    /// NOTE(review): only `tools` and `additional_params` carry `skip_serializing_if`. The
    /// other `Option` fields are sent as explicit `null` when `None` — confirm this is intended.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation history, including the current turn.
        pub contents: Vec<Content>,
        /// Tools available to the model; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        /// Tool-calling configuration (`toolConfig`); serialized as `null` when `None`.
        pub tool_config: Option<ToolConfig>,
        /// Sampling/output options (`generationConfig`); serialized as `null` when `None`.
        pub generation_config: Option<GenerationConfig>,
        /// Per-category safety thresholds (`safetySettings`); serialized as `null` when `None`.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System prompt (`systemInstruction`); serialized as `null` when `None`.
        pub system_instruction: Option<Content>,
        /// Extra provider-specific parameters merged into the top level of the request.
        ///
        /// Presumably must be a JSON object, since serde can only `flatten` map-like values.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1646
    /// Tool set offered to the model.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Callable functions (`functionDeclarations`).
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Code-execution tool (`codeExecution`); serialized as `null` when `None`.
        pub code_execution: Option<CodeExecution>,
    }
1653
    /// Declaration of a single callable function exposed to the model.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name the model uses when calling it.
        pub name: String,
        /// Description of what the function does.
        pub description: String,
        /// Parameter schema; omitted when the function takes no parameters (`None`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1662
    /// Tool-calling configuration for a request.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Function-calling mode (`functionCallingConfig`); serialized as `null` when `None`.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1668
    /// How the model may call functions, internally tagged by a `mode` field.
    ///
    /// Serializes as `{"mode": "AUTO"}`, `{"mode": "NONE"}` or
    /// `{"mode": "ANY", "allowed_function_names": [...]}`. Defaults to `Auto`.
    ///
    /// NOTE(review): `rename_all` on an enum renames only the variants, so the field is sent as
    /// snake_case `allowed_function_names` rather than `allowedFunctionNames` — confirm the
    /// API accepts it.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// The model decides whether to call a function.
        #[default]
        Auto,
        /// Function calling is disabled.
        None,
        /// The model must call a function, optionally restricted to the listed names
        /// (the list is omitted when `None`).
        Any {
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1680
1681 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1682 type Error = CompletionError;
1683 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1684 let res = match value {
1685 message::ToolChoice::Auto => Self::Auto,
1686 message::ToolChoice::None => Self::None,
1687 message::ToolChoice::Required => Self::Any {
1688 allowed_function_names: None,
1689 },
1690 message::ToolChoice::Specific { function_names } => Self::Any {
1691 allowed_function_names: Some(function_names),
1692 },
1693 };
1694
1695 Ok(res)
1696 }
1697 }
1698
    /// Marker enabling the code-execution tool; carries no options and serializes as `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1701
    /// Blocking threshold for one harm category.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
1708
    /// Blocking threshold for a [`SafetySetting`].
    ///
    /// Variants serialize in SCREAMING_SNAKE_CASE, e.g. `BLOCK_LOW_AND_ABOVE` or `OFF`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
1719}
1720
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion error must fail the test rather than being printed and ignored.
        let schema = Schema::try_from(schema_with_ref)
            .unwrap_or_else(|e| panic!("Schema conversion failed: {:?}", e));

        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        let items = schema.items.expect("Schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // A missing `items` must fail the test instead of silently skipping the assertions.
        let items = schema
            .items
            .expect("Flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}