// Model name constants. Each value is the literal model identifier string; a model name is
// interpolated into the `/v1beta/models/{model}:generateContent` path by `CompletionModel`.

/// Model identifier `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27 AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32 OneOrMany,
33 completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37 Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::info_span;
42
43use super::Client;
44
/// Completion model handle: a provider [`Client`] paired with the name of the model to call.
///
/// The HTTP backend type `T` defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send the HTTP requests.
    pub(crate) client: Client<T>,
    /// Model name; interpolated into the `/v1beta/models/{model}:generateContent` request path.
    pub model: String,
}
54
55impl<T> CompletionModel<T> {
56 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
57 Self {
58 client,
59 model: model.into(),
60 }
61 }
62
63 pub fn with_model(client: Client<T>, model: &str) -> Self {
64 Self {
65 client,
66 model: model.into(),
67 }
68 }
69}
70
71impl<T> completion::CompletionModel for CompletionModel<T>
72where
73 T: HttpClientExt + Clone + 'static,
74{
75 type Response = GenerateContentResponse;
76 type StreamingResponse = StreamingCompletionResponse;
77 type Client = super::Client<T>;
78
79 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
80 Self::new(client.clone(), model)
81 }
82
83 #[cfg_attr(feature = "worker", worker::send)]
84 async fn completion(
85 &self,
86 completion_request: CompletionRequest,
87 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88 let span = if tracing::Span::current().is_disabled() {
89 info_span!(
90 target: "rig::completions",
91 "generate_content",
92 gen_ai.operation.name = "generate_content",
93 gen_ai.provider.name = "gcp.gemini",
94 gen_ai.request.model = self.model,
95 gen_ai.system_instructions = &completion_request.preamble,
96 gen_ai.response.id = tracing::field::Empty,
97 gen_ai.response.model = tracing::field::Empty,
98 gen_ai.usage.output_tokens = tracing::field::Empty,
99 gen_ai.usage.input_tokens = tracing::field::Empty,
100 gen_ai.input.messages = tracing::field::Empty,
101 gen_ai.output.messages = tracing::field::Empty,
102 )
103 } else {
104 tracing::Span::current()
105 };
106
107 let request = create_request_body(completion_request)?;
108 span.record_model_input(&request.contents);
109
110 span.record_model_input(&request.contents);
111
112 tracing::trace!(
113 target: "rig::completions",
114 "Sending completion request to Gemini API {}",
115 serde_json::to_string_pretty(&request)?
116 );
117
118 let body = serde_json::to_vec(&request)?;
119
120 let path = format!("/v1beta/models/{}:generateContent", self.model);
121
122 let request = self
123 .client
124 .post(path.as_str())?
125 .body(body)
126 .map_err(|e| CompletionError::HttpError(e.into()))?;
127
128 let response = self.client.send::<_, Vec<u8>>(request).await?;
129
130 if response.status().is_success() {
131 let response_body = response
132 .into_body()
133 .await
134 .map_err(CompletionError::HttpError)?;
135
136 let response_text = String::from_utf8_lossy(&response_body).to_string();
137 tracing::debug!("Received raw response from Gemini API: {}", response_text);
138
139 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
140 .map_err(|err| {
141 tracing::error!(
142 error = %err,
143 body = %response_text,
144 "Failed to deserialize Gemini completion response"
145 );
146 CompletionError::JsonError(err)
147 })?;
148
149 match response.usage_metadata {
150 Some(ref usage) => tracing::info!(target: "rig",
151 "Gemini completion token usage: {}",
152 usage
153 ),
154 None => tracing::info!(target: "rig",
155 "Gemini completion token usage: n/a",
156 ),
157 }
158
159 let span = tracing::Span::current();
160 span.record_model_output(&response.candidates);
161 span.record_response_metadata(&response);
162 span.record_token_usage(&response.usage_metadata);
163
164 tracing::trace!(
165 "Received response from Gemini API: {}",
166 serde_json::to_string_pretty(&response)?
167 );
168
169 response.try_into()
170 } else {
171 let text = String::from_utf8_lossy(
172 &response
173 .into_body()
174 .await
175 .map_err(CompletionError::HttpError)?,
176 )
177 .into();
178
179 Err(CompletionError::ProviderError(text))
180 }
181 }
182
183 #[cfg_attr(feature = "worker", worker::send)]
184 async fn stream(
185 &self,
186 request: CompletionRequest,
187 ) -> Result<
188 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
189 CompletionError,
190 > {
191 CompletionModel::stream(self, request).await
192 }
193}
194
195pub(crate) fn create_request_body(
196 completion_request: CompletionRequest,
197) -> Result<GenerateContentRequest, CompletionError> {
198 let mut full_history = Vec::new();
199 full_history.extend(completion_request.chat_history);
200
201 let additional_params = completion_request
202 .additional_params
203 .unwrap_or_else(|| Value::Object(Map::new()));
204
205 let AdditionalParameters {
206 mut generation_config,
207 additional_params,
208 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
209
210 generation_config = generation_config.map(|mut cfg| {
211 if let Some(temp) = completion_request.temperature {
212 cfg.temperature = Some(temp);
213 };
214
215 if let Some(max_tokens) = completion_request.max_tokens {
216 cfg.max_output_tokens = Some(max_tokens);
217 };
218
219 cfg
220 });
221
222 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
223 parts: vec![preamble.into()],
224 role: Some(Role::Model),
225 });
226
227 let tools = if completion_request.tools.is_empty() {
228 None
229 } else {
230 Some(Tool::try_from(completion_request.tools)?)
231 };
232
233 let tool_config = if let Some(cfg) = completion_request.tool_choice {
234 Some(ToolConfig {
235 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
236 })
237 } else {
238 None
239 };
240
241 let request = GenerateContentRequest {
242 contents: full_history
243 .into_iter()
244 .map(|msg| {
245 msg.try_into()
246 .map_err(|e| CompletionError::RequestError(Box::new(e)))
247 })
248 .collect::<Result<Vec<_>, _>>()?,
249 generation_config,
250 safety_settings: None,
251 tools,
252 tool_config,
253 system_instruction,
254 additional_params,
255 };
256
257 Ok(request)
258}
259
260impl TryFrom<completion::ToolDefinition> for Tool {
261 type Error = CompletionError;
262
263 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
264 let parameters: Option<Schema> =
265 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
266 None
267 } else {
268 Some(tool.parameters.try_into()?)
269 };
270
271 Ok(Self {
272 function_declarations: vec![FunctionDeclaration {
273 name: tool.name,
274 description: tool.description,
275 parameters,
276 }],
277 code_execution: None,
278 })
279 }
280}
281
282impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
283 type Error = CompletionError;
284
285 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
286 let mut function_declarations = Vec::new();
287
288 for tool in tools {
289 let parameters =
290 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
291 None
292 } else {
293 match tool.parameters.try_into() {
294 Ok(schema) => Some(schema),
295 Err(e) => {
296 let emsg = format!(
297 "Tool '{}' could not be converted to a schema: {:?}",
298 tool.name, e,
299 );
300 return Err(CompletionError::ProviderError(emsg));
301 }
302 }
303 };
304
305 function_declarations.push(FunctionDeclaration {
306 name: tool.name,
307 description: tool.description,
308 parameters,
309 });
310 }
311
312 Ok(Self {
313 function_declarations,
314 code_execution: None,
315 })
316 }
317}
318
319impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
320 type Error = CompletionError;
321
322 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
323 let candidate = response.candidates.first().ok_or_else(|| {
324 CompletionError::ResponseError("No response candidates in response".into())
325 })?;
326
327 let content = candidate
328 .content
329 .as_ref()
330 .ok_or_else(|| {
331 let reason = candidate
332 .finish_reason
333 .as_ref()
334 .map(|r| format!("finish_reason={r:?}"))
335 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
336 let message = candidate
337 .finish_message
338 .as_deref()
339 .unwrap_or("no finish message provided");
340 CompletionError::ResponseError(format!(
341 "Gemini candidate missing content ({reason}, finish_message={message})"
342 ))
343 })?
344 .parts
345 .iter()
346 .map(|Part { thought, part, .. }| {
347 Ok(match part {
348 PartKind::Text(text) => {
349 if let Some(thought) = thought
350 && *thought
351 {
352 completion::AssistantContent::Reasoning(Reasoning::new(text))
353 } else {
354 completion::AssistantContent::text(text)
355 }
356 }
357 PartKind::InlineData(inline_data) => {
358 let mime_type = message::MediaType::from_mime_type(&inline_data.mime_type);
359
360 match mime_type {
361 Some(message::MediaType::Image(media_type)) => {
362 message::AssistantContent::image_base64(
363 &inline_data.data,
364 Some(media_type),
365 Some(message::ImageDetail::default()),
366 )
367 }
368 _ => {
369 return Err(CompletionError::ResponseError(format!(
370 "Unsupported media type {mime_type:?}"
371 )));
372 }
373 }
374 }
375 PartKind::FunctionCall(function_call) => {
376 completion::AssistantContent::tool_call(
377 &function_call.name,
378 &function_call.name,
379 function_call.args.clone(),
380 )
381 }
382 _ => {
383 return Err(CompletionError::ResponseError(
384 "Response did not contain a message or tool call".into(),
385 ));
386 }
387 })
388 })
389 .collect::<Result<Vec<_>, _>>()?;
390
391 let choice = OneOrMany::many(content).map_err(|_| {
392 CompletionError::ResponseError(
393 "Response contained no message or tool call (empty)".to_owned(),
394 )
395 })?;
396
397 let usage = response
398 .usage_metadata
399 .as_ref()
400 .map(|usage| completion::Usage {
401 input_tokens: usage.prompt_token_count as u64,
402 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
403 total_tokens: usage.total_token_count as u64,
404 })
405 .unwrap_or_default();
406
407 Ok(completion::CompletionResponse {
408 choice,
409 usage,
410 raw_response: response,
411 })
412 }
413}
414
415pub mod gemini_api_types {
416 use crate::telemetry::ProviderResponseExt;
417 use std::{collections::HashMap, convert::Infallible, str::FromStr};
418
419 use serde::{Deserialize, Serialize};
423 use serde_json::{Value, json};
424
425 use crate::completion::GetTokenUsage;
426 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
427 use crate::{
428 OneOrMany,
429 completion::CompletionError,
430 message::{self, Reasoning, Text},
431 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
432 };
433
/// Provider-specific request options; `create_request_body` deserializes this from the
/// request's `additional_params` JSON value. Keys use camelCase on the wire.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct AdditionalParameters {
    /// Generation settings, read from the `generationConfig` key when present.
    pub generation_config: Option<GenerationConfig>,
    /// Flattened extra parameters (`#[serde(flatten)]`); omitted from the output when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
443
444 impl AdditionalParameters {
445 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
446 self.generation_config = Some(cfg);
447 self
448 }
449
450 pub fn with_params(mut self, params: serde_json::Value) -> Self {
451 self.additional_params = Some(params);
452 self
453 }
454 }
455
/// Body of a successful `generateContent` reply, as deserialized by the completion call.
/// Field names are camelCase on the wire.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Response identifier (`responseId`); not optional, so deserialization requires it.
    pub response_id: String,
    /// Candidate outputs; only the first one is used when converting to a `CompletionResponse`.
    pub candidates: Vec<ContentCandidate>,
    /// Optional prompt feedback (`promptFeedback`).
    pub prompt_feedback: Option<PromptFeedback>,
    /// Optional token accounting (`usageMetadata`).
    pub usage_metadata: Option<UsageMetadata>,
    /// Optional `modelVersion` string.
    pub model_version: Option<String>,
}
475
476 impl ProviderResponseExt for GenerateContentResponse {
477 type OutputMessage = ContentCandidate;
478 type Usage = UsageMetadata;
479
480 fn get_response_id(&self) -> Option<String> {
481 Some(self.response_id.clone())
482 }
483
484 fn get_response_model_name(&self) -> Option<String> {
485 None
486 }
487
488 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
489 self.candidates.clone()
490 }
491
492 fn get_text_response(&self) -> Option<String> {
493 let str = self
494 .candidates
495 .iter()
496 .filter_map(|x| {
497 let content = x.content.as_ref()?;
498 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
499 return None;
500 }
501
502 let res = content
503 .parts
504 .iter()
505 .filter_map(|part| {
506 if let PartKind::Text(ref str) = part.part {
507 Some(str.to_owned())
508 } else {
509 None
510 }
511 })
512 .collect::<Vec<String>>()
513 .join("\n");
514
515 Some(res)
516 })
517 .collect::<Vec<String>>()
518 .join("\n");
519
520 if str.is_empty() { None } else { Some(str) }
521 }
522
523 fn get_usage(&self) -> Option<Self::Usage> {
524 self.usage_metadata.clone()
525 }
526 }
527
/// One candidate output inside a `GenerateContentResponse`. Field names are camelCase on the
/// wire.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContentCandidate {
    /// Generated content; may be absent. A missing content is reported as an error when the
    /// response is converted into a `CompletionResponse`. Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,
    /// Finish reason, if reported; included in the "missing content" error message.
    pub finish_reason: Option<FinishReason>,
    /// Optional safety ratings (`safetyRatings`).
    pub safety_ratings: Option<Vec<SafetyRating>>,
    /// Optional citation metadata (`citationMetadata`).
    pub citation_metadata: Option<CitationMetadata>,
    /// Optional `tokenCount` value.
    pub token_count: Option<i32>,
    /// Optional `avgLogprobs` value.
    pub avg_logprobs: Option<f64>,
    /// Optional `logprobsResult` value.
    pub logprobs_result: Option<LogprobsResult>,
    /// Optional `index` value.
    pub index: Option<i32>,
    /// Optional finish message; included in the "missing content" error message.
    pub finish_message: Option<String>,
}
556
/// A message in Gemini wire format: a list of parts plus an optional role.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Content {
    /// Message parts; defaults to an empty list when the key is missing on deserialization.
    #[serde(default)]
    pub parts: Vec<Part>,
    /// Author role. A missing role is treated like `Role::User` when converting back into a
    /// `message::Message`.
    pub role: Option<Role>,
}
566
567 impl TryFrom<message::Message> for Content {
568 type Error = message::MessageError;
569
570 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
571 Ok(match msg {
572 message::Message::User { content } => Content {
573 parts: content
574 .into_iter()
575 .map(|c| c.try_into())
576 .collect::<Result<Vec<_>, _>>()?,
577 role: Some(Role::User),
578 },
579 message::Message::Assistant { content, .. } => Content {
580 role: Some(Role::Model),
581 parts: content
582 .into_iter()
583 .map(|content| content.try_into())
584 .collect::<Result<Vec<_>, _>>()?,
585 },
586 })
587 }
588 }
589
590 impl TryFrom<Content> for message::Message {
591 type Error = message::MessageError;
592
593 fn try_from(content: Content) -> Result<Self, Self::Error> {
594 match content.role {
595 Some(Role::User) | None => {
596 Ok(message::Message::User {
597 content: {
598 let user_content: Result<Vec<_>, _> = content.parts.into_iter()
599 .map(|Part { part, .. }| {
600 Ok(match part {
601 PartKind::Text(text) => message::UserContent::text(text),
602 PartKind::InlineData(inline_data) => {
603 let mime_type =
604 message::MediaType::from_mime_type(&inline_data.mime_type);
605
606 match mime_type {
607 Some(message::MediaType::Image(media_type)) => {
608 message::UserContent::image_base64(
609 inline_data.data,
610 Some(media_type),
611 Some(message::ImageDetail::default()),
612 )
613 }
614 Some(message::MediaType::Document(media_type)) => {
615 message::UserContent::document(
616 inline_data.data,
617 Some(media_type),
618 )
619 }
620 Some(message::MediaType::Audio(media_type)) => {
621 message::UserContent::audio(
622 inline_data.data,
623 Some(media_type),
624 )
625 }
626 _ => {
627 return Err(message::MessageError::ConversionError(
628 format!("Unsupported media type {mime_type:?}"),
629 ));
630 }
631 }
632 }
633 _ => {
634 return Err(message::MessageError::ConversionError(format!(
635 "Unsupported gemini content part type: {part:?}"
636 )));
637 }
638 })
639 })
640 .collect();
641 OneOrMany::many(user_content?).map_err(|_| {
642 message::MessageError::ConversionError(
643 "Failed to create OneOrMany from user content".to_string(),
644 )
645 })?
646 },
647 })
648 }
649 Some(Role::Model) => Ok(message::Message::Assistant {
650 id: None,
651 content: {
652 let assistant_content: Result<Vec<_>, _> = content
653 .parts
654 .into_iter()
655 .map(|Part { thought, part, .. }| {
656 Ok(match part {
657 PartKind::Text(text) => match thought {
658 Some(true) => message::AssistantContent::Reasoning(
659 Reasoning::new(&text),
660 ),
661 _ => message::AssistantContent::Text(Text { text }),
662 },
663
664 PartKind::FunctionCall(function_call) => {
665 message::AssistantContent::ToolCall(function_call.into())
666 }
667 _ => {
668 return Err(message::MessageError::ConversionError(
669 format!("Unsupported part type: {part:?}"),
670 ));
671 }
672 })
673 })
674 .collect();
675 OneOrMany::many(assistant_content?).map_err(|_| {
676 message::MessageError::ConversionError(
677 "Failed to create OneOrMany from assistant content".to_string(),
678 )
679 })?
680 },
681 }),
682 }
683 }
684 }
685
/// Author role of a `Content`; serialized in lowercase (`"user"` / `"model"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Set on content converted from user messages.
    User,
    /// Set on content converted from assistant messages and on the system instruction.
    Model,
}
692
/// One piece of a `Content`. The payload (`part`) and any extra parameters are flattened into
/// the same JSON object as the `thought` / `thoughtSignature` keys.
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    /// When `Some(true)` on a text part, the text is converted to reasoning content rather
    /// than plain text. Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought: Option<bool>,
    /// Optional `thoughtSignature` string; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
    /// The actual payload, flattened into this object.
    #[serde(flatten)]
    pub part: PartKind,
    /// Extra key/value pairs flattened into this object; omitted from output when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<Value>,
}
707
/// Payload of a `Part`. Each variant is keyed by its camelCase name on the wire
/// (for example `text`, `inlineData`, `functionCall`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum PartKind {
    /// Plain text.
    Text(String),
    /// Inline data with a mime type (`inlineData`).
    InlineData(Blob),
    /// A function call: name plus arguments (`functionCall`).
    FunctionCall(FunctionCall),
    /// A function response: name plus optional JSON response (`functionResponse`).
    FunctionResponse(FunctionResponse),
    /// A file referenced by URI (`fileData`).
    FileData(FileData),
    /// `executableCode` payload.
    ExecutableCode(ExecutableCode),
    /// `codeExecutionResult` payload.
    CodeExecutionResult(CodeExecutionResult),
}
722
723 impl Default for PartKind {
726 fn default() -> Self {
727 Self::Text(String::new())
728 }
729 }
730
731 impl From<String> for Part {
732 fn from(text: String) -> Self {
733 Self {
734 thought: Some(false),
735 thought_signature: None,
736 part: PartKind::Text(text),
737 additional_params: None,
738 }
739 }
740 }
741
742 impl From<&str> for Part {
743 fn from(text: &str) -> Self {
744 Self::from(text.to_string())
745 }
746 }
747
748 impl FromStr for Part {
749 type Err = Infallible;
750
751 fn from_str(s: &str) -> Result<Self, Self::Err> {
752 Ok(s.into())
753 }
754 }
755
756 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
757 type Error = message::MessageError;
758 fn try_from(
759 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
760 ) -> Result<Self, Self::Error> {
761 let mime_type = mime_type.to_mime_type().to_string();
762 let part = match doc_src {
763 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
764 mime_type: Some(mime_type),
765 file_uri: url,
766 }),
767 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
768 PartKind::InlineData(Blob { mime_type, data })
769 }
770 DocumentSourceKind::Raw(_) => {
771 return Err(message::MessageError::ConversionError(
772 "Raw files not supported, encode as base64 first".into(),
773 ));
774 }
775 DocumentSourceKind::Unknown => {
776 return Err(message::MessageError::ConversionError(
777 "Can't convert an unknown document source".to_string(),
778 ));
779 }
780 };
781
782 Ok(part)
783 }
784 }
785
786 impl TryFrom<message::UserContent> for Part {
787 type Error = message::MessageError;
788
789 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
790 match content {
791 message::UserContent::Text(message::Text { text }) => Ok(Part {
792 thought: Some(false),
793 thought_signature: None,
794 part: PartKind::Text(text),
795 additional_params: None,
796 }),
797 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
798 let content = match content.first() {
799 message::ToolResultContent::Text(text) => text.text,
800 message::ToolResultContent::Image(_) => {
801 return Err(message::MessageError::ConversionError(
802 "Tool result content must be text".to_string(),
803 ));
804 }
805 };
806 let result: serde_json::Value =
808 serde_json::from_str(&content).unwrap_or_else(|error| {
809 tracing::trace!(
810 ?error,
811 "Tool result is not a valid JSON, treat it as normal string"
812 );
813 json!(content)
814 });
815 Ok(Part {
816 thought: Some(false),
817 thought_signature: None,
818 part: PartKind::FunctionResponse(FunctionResponse {
819 name: id,
820 response: Some(json!({ "result": result })),
821 }),
822 additional_params: None,
823 })
824 }
825 message::UserContent::Image(message::Image {
826 data, media_type, ..
827 }) => match media_type {
828 Some(media_type) => match media_type {
829 message::ImageMediaType::JPEG
830 | message::ImageMediaType::PNG
831 | message::ImageMediaType::WEBP
832 | message::ImageMediaType::HEIC
833 | message::ImageMediaType::HEIF => {
834 let part = PartKind::try_from((media_type, data))?;
835 Ok(Part {
836 thought: Some(false),
837 thought_signature: None,
838 part,
839 additional_params: None,
840 })
841 }
842 _ => Err(message::MessageError::ConversionError(format!(
843 "Unsupported image media type {media_type:?}"
844 ))),
845 },
846 None => Err(message::MessageError::ConversionError(
847 "Media type for image is required for Gemini".to_string(),
848 )),
849 },
850 message::UserContent::Document(message::Document {
851 data, media_type, ..
852 }) => {
853 let Some(media_type) = media_type else {
854 return Err(MessageError::ConversionError(
855 "A mime type is required for document inputs to Gemini".to_string(),
856 ));
857 };
858
859 if !media_type.is_code() {
860 let mime_type = media_type.to_mime_type().to_string();
861
862 let part = match data {
863 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
864 mime_type: Some(mime_type),
865 file_uri,
866 }),
867 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
868 PartKind::InlineData(Blob { mime_type, data })
869 }
870 DocumentSourceKind::Raw(_) => {
871 return Err(message::MessageError::ConversionError(
872 "Raw files not supported, encode as base64 first".into(),
873 ));
874 }
875 _ => {
876 return Err(message::MessageError::ConversionError(
877 "Document has no body".to_string(),
878 ));
879 }
880 };
881
882 Ok(Part {
883 thought: Some(false),
884 part,
885 ..Default::default()
886 })
887 } else {
888 Err(message::MessageError::ConversionError(format!(
889 "Unsupported document media type {media_type:?}"
890 )))
891 }
892 }
893
894 message::UserContent::Audio(message::Audio {
895 data, media_type, ..
896 }) => {
897 let Some(media_type) = media_type else {
898 return Err(MessageError::ConversionError(
899 "A mime type is required for audio inputs to Gemini".to_string(),
900 ));
901 };
902
903 let mime_type = media_type.to_mime_type().to_string();
904
905 let part = match data {
906 DocumentSourceKind::Base64(data) => {
907 PartKind::InlineData(Blob { data, mime_type })
908 }
909
910 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
911 mime_type: Some(mime_type),
912 file_uri,
913 }),
914 DocumentSourceKind::String(_) => {
915 return Err(message::MessageError::ConversionError(
916 "Strings cannot be used as audio files!".into(),
917 ));
918 }
919 DocumentSourceKind::Raw(_) => {
920 return Err(message::MessageError::ConversionError(
921 "Raw files not supported, encode as base64 first".into(),
922 ));
923 }
924 DocumentSourceKind::Unknown => {
925 return Err(message::MessageError::ConversionError(
926 "Content has no body".to_string(),
927 ));
928 }
929 };
930
931 Ok(Part {
932 thought: Some(false),
933 part,
934 ..Default::default()
935 })
936 }
937 message::UserContent::Video(message::Video {
938 data,
939 media_type,
940 additional_params,
941 ..
942 }) => {
943 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
944
945 let part = match data {
946 DocumentSourceKind::Url(file_uri) => {
947 if file_uri.starts_with("https://www.youtube.com") {
948 PartKind::FileData(FileData {
949 mime_type,
950 file_uri,
951 })
952 } else {
953 if mime_type.is_none() {
954 return Err(MessageError::ConversionError(
955 "A mime type is required for non-Youtube video file inputs to Gemini"
956 .to_string(),
957 ));
958 }
959
960 PartKind::FileData(FileData {
961 mime_type,
962 file_uri,
963 })
964 }
965 }
966 DocumentSourceKind::Base64(data) => {
967 let Some(mime_type) = mime_type else {
968 return Err(MessageError::ConversionError(
969 "A media type is expected for base64 encoded strings"
970 .to_string(),
971 ));
972 };
973 PartKind::InlineData(Blob { mime_type, data })
974 }
975 DocumentSourceKind::String(_) => {
976 return Err(message::MessageError::ConversionError(
977 "Strings cannot be used as audio files!".into(),
978 ));
979 }
980 DocumentSourceKind::Raw(_) => {
981 return Err(message::MessageError::ConversionError(
982 "Raw file data not supported, encode as base64 first".into(),
983 ));
984 }
985 DocumentSourceKind::Unknown => {
986 return Err(message::MessageError::ConversionError(
987 "Media type for video is required for Gemini".to_string(),
988 ));
989 }
990 };
991
992 Ok(Part {
993 thought: Some(false),
994 thought_signature: None,
995 part,
996 additional_params,
997 })
998 }
999 }
1000 }
1001 }
1002
1003 impl TryFrom<message::AssistantContent> for Part {
1004 type Error = message::MessageError;
1005
1006 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1007 match content {
1008 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1009 message::AssistantContent::Image(message::Image {
1010 data, media_type, ..
1011 }) => match media_type {
1012 Some(media_type) => match media_type {
1013 message::ImageMediaType::JPEG
1014 | message::ImageMediaType::PNG
1015 | message::ImageMediaType::WEBP
1016 | message::ImageMediaType::HEIC
1017 | message::ImageMediaType::HEIF => {
1018 let part = PartKind::try_from((media_type, data))?;
1019 Ok(Part {
1020 thought: Some(false),
1021 thought_signature: None,
1022 part,
1023 additional_params: None,
1024 })
1025 }
1026 _ => Err(message::MessageError::ConversionError(format!(
1027 "Unsupported image media type {media_type:?}"
1028 ))),
1029 },
1030 None => Err(message::MessageError::ConversionError(
1031 "Media type for image is required for Gemini".to_string(),
1032 )),
1033 },
1034 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1035 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
1036 Ok(Part {
1037 thought: Some(true),
1038 thought_signature: None,
1039 part: PartKind::Text(
1040 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
1041 ),
1042 additional_params: None,
1043 })
1044 }
1045 }
1046 }
1047 }
1048
1049 impl From<message::ToolCall> for Part {
1050 fn from(tool_call: message::ToolCall) -> Self {
1051 Self {
1052 thought: Some(false),
1053 thought_signature: None,
1054 part: PartKind::FunctionCall(FunctionCall {
1055 name: tool_call.function.name,
1056 args: tool_call.function.arguments,
1057 }),
1058 additional_params: None,
1059 }
1060 }
1061 }
1062
/// Inline data payload; keys are camelCase on the wire (`mimeType`, `data`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
    /// Mime type string; parsed with `MediaType::from_mime_type` when converting responses.
    pub mime_type: String,
    /// The data itself as a string; base64 and string sources are both stored here.
    pub data: String,
}
1074
/// A function call: the function name plus its JSON arguments.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionCall {
    /// Function name; also used as the tool-call id when converting to `message::ToolCall`.
    pub name: String,
    /// Call arguments as an arbitrary JSON value.
    pub args: serde_json::Value,
}
1085
1086 impl From<FunctionCall> for message::ToolCall {
1087 fn from(function_call: FunctionCall) -> Self {
1088 Self {
1089 id: function_call.name.clone(),
1090 call_id: None,
1091 function: message::ToolFunction {
1092 name: function_call.name,
1093 arguments: function_call.args,
1094 },
1095 }
1096 }
1097 }
1098
1099 impl From<message::ToolCall> for FunctionCall {
1100 fn from(tool_call: message::ToolCall) -> Self {
1101 Self {
1102 name: tool_call.function.name,
1103 args: tool_call.function.arguments,
1104 }
1105 }
1106 }
1107
/// Result of a function call sent back to the model.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionResponse {
    /// Name of the function; set to the tool result id when built from a tool result.
    pub name: String,
    /// Optional JSON response; tool results are wrapped as `{"result": ...}`.
    pub response: Option<serde_json::Value>,
}
1119
/// A file referenced by URI; keys are camelCase on the wire (`mimeType`, `fileUri`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FileData {
    /// Optional mime type; left `None` only for YouTube video URLs without a media type.
    pub mime_type: Option<String>,
    /// URI of the file.
    pub file_uri: String,
}
1129
/// A harm category paired with a probability level.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct SafetyRating {
    /// The category being rated.
    pub category: HarmCategory,
    /// The probability level assigned to that category.
    pub probability: HarmProbability,
}
1135
/// Probability level of a `SafetyRating`. Variants are serialized in SCREAMING_SNAKE_CASE,
/// e.g. `HarmProbabilityUnspecified` is `HARM_PROBABILITY_UNSPECIFIED` and `Low` is `LOW`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    HarmProbabilityUnspecified,
    Negligible,
    Low,
    Medium,
    High,
}
1145
/// Category of a `SafetyRating`. Variants are serialized in SCREAMING_SNAKE_CASE,
/// e.g. `HarmCategoryHateSpeech` is `HARM_CATEGORY_HATE_SPEECH`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmCategory {
    HarmCategoryUnspecified,
    HarmCategoryDerogatory,
    HarmCategoryToxicity,
    HarmCategoryViolence,
    HarmCategorySexually,
    HarmCategoryMedical,
    HarmCategoryDangerous,
    HarmCategoryHarassment,
    HarmCategoryHateSpeech,
    HarmCategorySexuallyExplicit,
    HarmCategoryDangerousContent,
    HarmCategoryCivicIntegrity,
}
1162
/// Token accounting attached to a response; keys are camelCase on the wire. Optional fields
/// are omitted from serialized output when `None`.
#[derive(Debug, Deserialize, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    /// `promptTokenCount`; reported as the input token count.
    pub prompt_token_count: i32,
    /// Optional `cachedContentTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Optional `candidatesTokenCount`; reported as the output token count by the
    /// non-streaming response conversion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// `totalTokenCount`.
    pub total_token_count: i32,
    /// Optional `thoughtsTokenCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
}
1175
1176 impl std::fmt::Display for UsageMetadata {
1177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1178 write!(
1179 f,
1180 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1181 self.prompt_token_count,
1182 match self.cached_content_token_count {
1183 Some(count) => count.to_string(),
1184 None => "n/a".to_string(),
1185 },
1186 match self.candidates_token_count {
1187 Some(count) => count.to_string(),
1188 None => "n/a".to_string(),
1189 },
1190 self.total_token_count
1191 )
1192 }
1193 }
1194
1195 impl GetTokenUsage for UsageMetadata {
1196 fn token_usage(&self) -> Option<crate::completion::Usage> {
1197 let mut usage = crate::completion::Usage::new();
1198
1199 usage.input_tokens = self.prompt_token_count as u64;
1200 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1201 + self.candidates_token_count.unwrap_or_default()
1202 + self.thoughts_token_count.unwrap_or_default())
1203 as u64;
1204 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1205
1206 Some(usage)
1207 }
1208 }
1209
/// Feedback about the prompt carried in a response; keys are camelCase on the wire.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    /// Optional `blockReason`.
    pub block_reason: Option<BlockReason>,
    /// Optional `safetyRatings` list.
    pub safety_ratings: Option<Vec<SafetyRating>>,
}
1219
/// Value of `PromptFeedback::block_reason`. Variants are serialized in SCREAMING_SNAKE_CASE,
/// e.g. `ProhibitedContent` is `PROHIBITED_CONTENT`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockReason {
    BlockReasonUnspecified,
    Safety,
    Other,
    Blocklist,
    ProhibitedContent,
}
1235
/// Reason why generation finished.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `MaxTokens` is `MAX_TOKENS` and
/// `MalformedFunctionCall` is `MALFORMED_FUNCTION_CALL` on the wire.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    /// Wire value `FINISH_REASON_UNSPECIFIED`.
    FinishReasonUnspecified,
    /// Wire value `STOP`.
    Stop,
    /// Wire value `MAX_TOKENS`.
    MaxTokens,
    /// Wire value `SAFETY`.
    Safety,
    /// Wire value `RECITATION`.
    Recitation,
    /// Wire value `LANGUAGE`.
    Language,
    /// Wire value `OTHER`.
    Other,
    /// Wire value `BLOCKLIST`.
    Blocklist,
    /// Wire value `PROHIBITED_CONTENT`.
    ProhibitedContent,
    /// Wire value `SPII`.
    Spii,
    /// Wire value `MALFORMED_FUNCTION_CALL`.
    MalformedFunctionCall,
}
1262
/// Collection of citation sources; serialized under the camelCase key `citationSources`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// The individual sources. The key is required when deserializing.
    pub citation_sources: Vec<CitationSource>,
}
1268
/// A single citation entry.
///
/// Every field is optional and omitted from the serialized form when `None`; keys are
/// camelCase on the wire (`startIndex`, `endIndex`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// URI of the source.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    /// Start index of the cited span (presumably an offset into the response; confirm units
    /// against the API documentation).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<i32>,
    /// End index of the cited span (same caveat as `start_index`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<i32>,
    /// License associated with the source.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}
1281
/// Log-probability information for a response.
///
/// With the camelCase rename the wire keys are `topCandidate` and `chosenCandidate`; both are
/// required when deserializing.
///
/// NOTE(review): both keys are singular here. Confirm they match the key names the API
/// actually returns.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LogprobsResult {
    /// Lists of top candidates.
    pub top_candidate: Vec<TopCandidate>,
    /// The chosen candidates.
    pub chosen_candidate: Vec<LogProbCandidate>,
}
1288
/// A list of log-probability candidates, used as the element type of
/// `LogprobsResult::top_candidate`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TopCandidate {
    /// The candidates in this group.
    pub candidates: Vec<LogProbCandidate>,
}
1293
/// A token together with its log probability.
///
/// Keys are camelCase on the wire (`tokenId`, `logProbability`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LogProbCandidate {
    /// The token text.
    pub token: String,
    /// Identifier of the token.
    ///
    /// NOTE(review): typed as `String`; confirm the API sends a string rather than a number.
    pub token_id: String,
    /// Log probability of the token.
    pub log_probability: f64,
}
1301
/// Generation options attached to a request.
///
/// Keys are camelCase on the wire. Every field is optional and is omitted from the serialized
/// form when it is `None`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// MIME type requested for the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mime_type: Option<String>,
    /// Response schema expressed with the typed [`Schema`] subset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<Schema>,
    /// Response schema as raw JSON, serialized under the explicit key `_responseJsonSchema`
    /// (leading underscore preserved). Separate from `response_json_schema` below.
    #[serde(
        skip_serializing_if = "Option::is_none",
        rename = "_responseJsonSchema"
    )]
    pub _response_json_schema: Option<Value>,
    /// Response schema as raw JSON, serialized as `responseJsonSchema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_json_schema: Option<Value>,
    /// Number of candidates to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<i32>,
    /// Upper bound on the number of output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u64>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Top-p sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Top-k sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    /// Presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// Whether log probabilities should be included in the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_logprobs: Option<bool>,
    /// Log-probability count setting (presumably the number of top candidates per step;
    /// confirm against the API documentation).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<i32>,
    /// Thinking-related settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
    /// Image-generation settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_config: Option<ImageConfig>,
}
1394
1395 impl Default for GenerationConfig {
1396 fn default() -> Self {
1397 Self {
1398 temperature: Some(1.0),
1399 max_output_tokens: Some(4096),
1400 stop_sequences: None,
1401 response_mime_type: None,
1402 response_schema: None,
1403 _response_json_schema: None,
1404 response_json_schema: None,
1405 candidate_count: None,
1406 top_p: None,
1407 top_k: None,
1408 presence_penalty: None,
1409 frequency_penalty: None,
1410 response_logprobs: None,
1411 logprobs: None,
1412 thinking_config: None,
1413 image_config: None,
1414 }
1415 }
1416 }
1417
/// Thinking settings used by `GenerationConfig::thinking_config`.
///
/// Keys are camelCase on the wire (`thinkingBudget`, `includeThoughts`). `include_thoughts`
/// carries no `skip_serializing_if`, so a `None` value is serialized as `null`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Budget for thinking (presumably measured in tokens; confirm against the API docs).
    pub thinking_budget: u32,
    /// Whether thoughts should be included in the response.
    pub include_thoughts: Option<bool>,
}
1424
/// Image settings used by `GenerationConfig::image_config`.
///
/// Keys are camelCase on the wire (`aspectRatio`, `imageSize`); both fields are omitted from
/// the serialized form when `None`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageConfig {
    /// Requested aspect ratio, as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Requested image size, as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_size: Option<String>,
}
1433
/// Typed subset of a JSON-schema object.
///
/// Only `type` is mandatory; every other field is omitted from the serialized form when
/// `None`. No `rename_all` is applied, so multi-word fields keep their snake_case names
/// (`max_items`, `min_items`) when serialized.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Schema {
    /// Name of the schema type (e.g. `"object"`, `"array"`).
    pub r#type: String,
    /// Format string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<String>,
    /// Human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Whether the value may be null.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nullable: Option<bool>,
    /// Allowed values, as strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#enum: Option<Vec<String>>,
    /// Maximum number of items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_items: Option<i32>,
    /// Minimum number of items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_items: Option<i32>,
    /// Property schemas keyed by property name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, Schema>>,
    /// Names of required properties.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
    /// Schema of the items; boxed because the type is recursive.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<Schema>>,
}
1459
1460 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1466 let defs = if let Some(obj) = schema.as_object() {
1468 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1469 } else {
1470 None
1471 };
1472
1473 let Some(defs_value) = defs else {
1474 return Ok(schema);
1475 };
1476
1477 let Some(defs_obj) = defs_value.as_object() else {
1478 return Err(CompletionError::ResponseError(
1479 "$defs must be an object".into(),
1480 ));
1481 };
1482
1483 resolve_refs(&mut schema, defs_obj)?;
1484
1485 if let Some(obj) = schema.as_object_mut() {
1487 obj.remove("$defs");
1488 obj.remove("definitions");
1489 }
1490
1491 Ok(schema)
1492 }
1493
1494 fn resolve_refs(
1497 value: &mut Value,
1498 defs: &serde_json::Map<String, Value>,
1499 ) -> Result<(), CompletionError> {
1500 match value {
1501 Value::Object(obj) => {
1502 if let Some(ref_value) = obj.get("$ref")
1503 && let Some(ref_str) = ref_value.as_str()
1504 {
1505 let def_name = parse_ref_path(ref_str)?;
1507
1508 let def = defs.get(&def_name).ok_or_else(|| {
1509 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1510 })?;
1511
1512 let mut resolved = def.clone();
1513 resolve_refs(&mut resolved, defs)?;
1514 *value = resolved;
1515 return Ok(());
1516 }
1517
1518 for (_, v) in obj.iter_mut() {
1519 resolve_refs(v, defs)?;
1520 }
1521 }
1522 Value::Array(arr) => {
1523 for item in arr.iter_mut() {
1524 resolve_refs(item, defs)?;
1525 }
1526 }
1527 _ => {}
1528 }
1529
1530 Ok(())
1531 }
1532
1533 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1539 if let Some(fragment) = ref_str.strip_prefix('#') {
1540 if let Some(name) = fragment.strip_prefix("/$defs/") {
1541 Ok(name.to_string())
1542 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1543 Ok(name.to_string())
1544 } else {
1545 Err(CompletionError::ResponseError(format!(
1546 "Unsupported reference format: {}",
1547 ref_str
1548 )))
1549 }
1550 } else {
1551 Err(CompletionError::ResponseError(format!(
1552 "Only fragment references (#/...) are supported: {}",
1553 ref_str
1554 )))
1555 }
1556 }
1557
1558 fn extract_type(type_value: &Value) -> Option<String> {
1561 if type_value.is_string() {
1562 type_value.as_str().map(String::from)
1563 } else if type_value.is_array() {
1564 type_value
1565 .as_array()
1566 .and_then(|arr| arr.first())
1567 .and_then(|v| v.as_str().map(String::from))
1568 } else {
1569 None
1570 }
1571 }
1572
1573 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1576 composition.as_array().and_then(|arr| {
1577 arr.iter().find_map(|schema| {
1578 if let Some(obj) = schema.as_object() {
1579 if let Some(type_val) = obj.get("type")
1581 && let Some(type_str) = type_val.as_str()
1582 && type_str == "null"
1583 {
1584 return None;
1585 }
1586 obj.get("type").and_then(extract_type).or_else(|| {
1588 if obj.contains_key("properties") {
1589 Some("object".to_string())
1590 } else {
1591 None
1592 }
1593 })
1594 } else {
1595 None
1596 }
1597 })
1598 })
1599 }
1600
1601 fn extract_schema_from_composition(
1604 composition: &Value,
1605 ) -> Option<serde_json::Map<String, Value>> {
1606 composition.as_array().and_then(|arr| {
1607 arr.iter().find_map(|schema| {
1608 if let Some(obj) = schema.as_object()
1609 && let Some(type_val) = obj.get("type")
1610 && let Some(type_str) = type_val.as_str()
1611 {
1612 if type_str == "null" {
1613 return None;
1614 }
1615 Some(obj.clone())
1616 } else {
1617 None
1618 }
1619 })
1620 })
1621 }
1622
1623 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1626 if let Some(type_val) = obj.get("type")
1628 && let Some(type_str) = extract_type(type_val)
1629 {
1630 return type_str;
1631 }
1632
1633 if let Some(any_of) = obj.get("anyOf")
1635 && let Some(type_str) = extract_type_from_composition(any_of)
1636 {
1637 return type_str;
1638 }
1639
1640 if let Some(one_of) = obj.get("oneOf")
1641 && let Some(type_str) = extract_type_from_composition(one_of)
1642 {
1643 return type_str;
1644 }
1645
1646 if let Some(all_of) = obj.get("allOf")
1647 && let Some(type_str) = extract_type_from_composition(all_of)
1648 {
1649 return type_str;
1650 }
1651
1652 if obj.contains_key("properties") {
1654 "object".to_string()
1655 } else {
1656 String::new()
1657 }
1658 }
1659
1660 impl TryFrom<Value> for Schema {
1661 type Error = CompletionError;
1662
1663 fn try_from(value: Value) -> Result<Self, Self::Error> {
1664 let flattened_val = flatten_schema(value)?;
1665 if let Some(obj) = flattened_val.as_object() {
1666 let props_source = if obj.get("properties").is_none() {
1669 if let Some(any_of) = obj.get("anyOf") {
1670 extract_schema_from_composition(any_of)
1671 } else if let Some(one_of) = obj.get("oneOf") {
1672 extract_schema_from_composition(one_of)
1673 } else if let Some(all_of) = obj.get("allOf") {
1674 extract_schema_from_composition(all_of)
1675 } else {
1676 None
1677 }
1678 .unwrap_or(obj.clone())
1679 } else {
1680 obj.clone()
1681 };
1682
1683 Ok(Schema {
1684 r#type: infer_type(obj),
1685 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1686 description: obj
1687 .get("description")
1688 .and_then(|v| v.as_str())
1689 .map(String::from),
1690 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1691 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1692 arr.iter()
1693 .filter_map(|v| v.as_str().map(String::from))
1694 .collect()
1695 }),
1696 max_items: obj
1697 .get("maxItems")
1698 .and_then(|v| v.as_i64())
1699 .map(|v| v as i32),
1700 min_items: obj
1701 .get("minItems")
1702 .and_then(|v| v.as_i64())
1703 .map(|v| v as i32),
1704 properties: props_source
1705 .get("properties")
1706 .and_then(|v| v.as_object())
1707 .map(|map| {
1708 map.iter()
1709 .filter_map(|(k, v)| {
1710 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1711 })
1712 .collect()
1713 }),
1714 required: props_source
1715 .get("required")
1716 .and_then(|v| v.as_array())
1717 .map(|arr| {
1718 arr.iter()
1719 .filter_map(|v| v.as_str().map(String::from))
1720 .collect()
1721 }),
1722 items: obj
1723 .get("items")
1724 .and_then(|v| v.clone().try_into().ok())
1725 .map(Box::new),
1726 })
1727 } else {
1728 Err(CompletionError::ResponseError(
1729 "Expected a JSON object for Schema".into(),
1730 ))
1731 }
1732 }
1733 }
1734
/// Body of a generate-content request (serialize-only).
///
/// Keys are camelCase on the wire. Only `tools` and `additional_params` are skipped when
/// `None`; the other optional fields are serialized as `null` when unset.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Conversation contents.
    pub contents: Vec<Content>,
    /// Tool definitions; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Tool>,
    /// Tool configuration.
    pub tool_config: Option<ToolConfig>,
    /// Generation settings.
    pub generation_config: Option<GenerationConfig>,
    /// Safety settings.
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// System instruction content.
    pub system_instruction: Option<Content>,
    /// Extra parameters; flattened into the top level of the request body and omitted when
    /// `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
1766
/// Tool definition sent with a request (serialize-only).
///
/// Keys are camelCase on the wire (`functionDeclarations`, `codeExecution`). `code_execution`
/// has no `skip_serializing_if`, so `None` is serialized as `null`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    /// Declarations of the callable functions.
    pub function_declarations: Vec<FunctionDeclaration>,
    /// Code-execution tool marker, when set.
    pub code_execution: Option<CodeExecution>,
}
1773
/// Declaration of a single callable function (serialize-only).
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
    /// Function name.
    pub name: String,
    /// Description of the function.
    pub description: String,
    /// Schema of the function parameters; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Schema>,
}
1782
/// Tool configuration; serialized under the camelCase key `functionCallingConfig`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    /// Function-calling mode; a `None` value is serialized as `null`.
    pub function_calling_config: Option<FunctionCallingMode>,
}
1788
/// Function-calling mode.
///
/// Internally tagged: serialized as an object whose `mode` key holds the upper-cased variant
/// name (`AUTO`, `NONE` or `ANY`). `Auto` is the default.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(tag = "mode", rename_all = "UPPERCASE")]
pub enum FunctionCallingMode {
    /// Wire value `AUTO`; the default variant.
    #[default]
    Auto,
    /// Wire value `NONE`.
    None,
    /// Wire value `ANY`, optionally with a list of function names.
    Any {
        /// Names of the allowed functions; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        allowed_function_names: Option<Vec<String>>,
    },
}
1800
1801 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1802 type Error = CompletionError;
1803 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1804 let res = match value {
1805 message::ToolChoice::Auto => Self::Auto,
1806 message::ToolChoice::None => Self::None,
1807 message::ToolChoice::Required => Self::Any {
1808 allowed_function_names: None,
1809 },
1810 message::ToolChoice::Specific { function_names } => Self::Any {
1811 allowed_function_names: Some(function_names),
1812 },
1813 };
1814
1815 Ok(res)
1816 }
1817 }
1818
/// Field-less marker type used by `Tool::code_execution`; serializes as an empty object.
#[derive(Debug, Serialize)]
pub struct CodeExecution {}
1821
/// Pairs a harm category with the blocking threshold to apply to it (serialize-only).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    /// Category the setting applies to.
    pub category: HarmCategory,
    /// Threshold for that category.
    pub threshold: HarmBlockThreshold,
}
1828
/// Blocking threshold used in [`SafetySetting`].
///
/// Variants are serialized in SCREAMING_SNAKE_CASE, e.g. `BlockLowAndAbove` is
/// `BLOCK_LOW_AND_ABOVE` on the wire.
#[derive(Debug, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    /// Wire value `HARM_BLOCK_THRESHOLD_UNSPECIFIED`.
    HarmBlockThresholdUnspecified,
    /// Wire value `BLOCK_LOW_AND_ABOVE`.
    BlockLowAndAbove,
    /// Wire value `BLOCK_MEDIUM_AND_ABOVE`.
    BlockMediumAndAbove,
    /// Wire value `BLOCK_ONLY_HIGH`.
    BlockOnlyHigh,
    /// Wire value `BLOCK_NONE`.
    BlockNone,
    /// Wire value `OFF`.
    Off,
}
1839}
1840
/// Unit tests for content (de)serialization, message conversion and schema conversion.
#[cfg(test)]
mod tests {
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    /// A user `Content` with one part of every kind deserializes into the matching
    /// `PartKind` variants, in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        // `serde_path_to_error` reports the JSON path of the failing field on error.
        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    /// A `Content` with role `model` and a single text part deserializes correctly.
    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// A user message converts into a `Content` with role `User` and one text part.
    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    /// An assistant message converts into a `Content` with role `Model` and one text part.
    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    /// An assistant tool call converts into a `FunctionCall` part carrying name and args.
    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    /// An array schema whose items are a `$ref` converts with the referenced object inlined.
    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A failed conversion must fail the test instead of merely being printed.
        let schema = Schema::try_from(schema_with_ref)
            .expect("schema with $ref items should convert");
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    /// A plain object schema keeps its type and properties.
    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    /// An array schema with inline object items converts the nested schema as well.
    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        if let Some(items) = schema.items {
            assert_eq!(items.r#type, "object");
            assert!(items.properties.is_some());
        } else {
            panic!("Schema should have items field");
        }
    }

    /// A schema flattened explicitly with `flatten_schema` still converts to the same shape.
    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // Missing `items` must fail the test rather than skip the assertions below.
        let items = schema
            .items
            .expect("Flattened schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}
2170}