// Gemini model identifiers. Each value is a plain model name that can be passed
// to `CompletionModel::new`; it is interpolated into the request path.

// --- Gemini 2.5 Pro ---

/// Gemini 2.5 Pro preview, `06-05` snapshot.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Gemini 2.5 Pro preview, `05-06` snapshot.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Gemini 2.5 Pro preview, `03-25` snapshot.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Gemini 2.5 Pro experimental, `03-25` snapshot.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";

// --- Gemini 2.5 Flash ---

/// Gemini 2.5 Flash.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Gemini 2.5 Flash preview, `04-17` snapshot.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";

// --- Gemini 2.0 ---

/// Gemini 2.0 Flash.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// Gemini 2.0 Flash-Lite.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27 AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32 OneOrMany,
33 completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37 Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
/// Handle for one Gemini model: the HTTP client used to reach the API plus the
/// model name that is placed into the `generateContent` request path.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model identifier, e.g. one of the `GEMINI_*` constants.
    pub model: String,
}
55
56impl<T> CompletionModel<T> {
57 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58 Self {
59 client,
60 model: model.into(),
61 }
62 }
63
64 pub fn with_model(client: Client<T>, model: &str) -> Self {
65 Self {
66 client,
67 model: model.into(),
68 }
69 }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74 T: HttpClientExt + Clone + 'static,
75{
76 type Response = GenerateContentResponse;
77 type StreamingResponse = StreamingCompletionResponse;
78 type Client = super::Client<T>;
79
80 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81 Self::new(client.clone(), model)
82 }
83
84 #[cfg_attr(feature = "worker", worker::send)]
85 async fn completion(
86 &self,
87 completion_request: CompletionRequest,
88 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
89 let span = if tracing::Span::current().is_disabled() {
90 info_span!(
91 target: "rig::completions",
92 "generate_content",
93 gen_ai.operation.name = "generate_content",
94 gen_ai.provider.name = "gcp.gemini",
95 gen_ai.request.model = self.model,
96 gen_ai.system_instructions = &completion_request.preamble,
97 gen_ai.response.id = tracing::field::Empty,
98 gen_ai.response.model = tracing::field::Empty,
99 gen_ai.usage.output_tokens = tracing::field::Empty,
100 gen_ai.usage.input_tokens = tracing::field::Empty,
101 )
102 } else {
103 tracing::Span::current()
104 };
105
106 let request = create_request_body(completion_request)?;
107
108 if enabled!(Level::TRACE) {
109 tracing::trace!(
110 target: "rig::completions",
111 "Gemini completion request: {}",
112 serde_json::to_string_pretty(&request)?
113 );
114 }
115
116 let body = serde_json::to_vec(&request)?;
117
118 let path = format!("/v1beta/models/{}:generateContent", self.model);
119
120 let request = self
121 .client
122 .post(path.as_str())?
123 .body(body)
124 .map_err(|e| CompletionError::HttpError(e.into()))?;
125
126 async move {
127 let response = self.client.send::<_, Vec<u8>>(request).await?;
128
129 if response.status().is_success() {
130 let response_body = response
131 .into_body()
132 .await
133 .map_err(CompletionError::HttpError)?;
134
135 let response_text = String::from_utf8_lossy(&response_body).to_string();
136
137 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
138 .map_err(|err| {
139 tracing::error!(
140 error = %err,
141 body = %response_text,
142 "Failed to deserialize Gemini completion response"
143 );
144 CompletionError::JsonError(err)
145 })?;
146
147 let span = tracing::Span::current();
148 span.record_response_metadata(&response);
149 span.record_token_usage(&response.usage_metadata);
150
151 if enabled!(Level::TRACE) {
152 tracing::trace!(
153 target: "rig::completions",
154 "Gemini completion response: {}",
155 serde_json::to_string_pretty(&response)?
156 );
157 }
158
159 response.try_into()
160 } else {
161 let text = String::from_utf8_lossy(
162 &response
163 .into_body()
164 .await
165 .map_err(CompletionError::HttpError)?,
166 )
167 .into();
168
169 Err(CompletionError::ProviderError(text))
170 }
171 }
172 .instrument(span)
173 .await
174 }
175
176 #[cfg_attr(feature = "worker", worker::send)]
177 async fn stream(
178 &self,
179 request: CompletionRequest,
180 ) -> Result<
181 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
182 CompletionError,
183 > {
184 CompletionModel::stream(self, request).await
185 }
186}
187
188pub(crate) fn create_request_body(
189 completion_request: CompletionRequest,
190) -> Result<GenerateContentRequest, CompletionError> {
191 let mut full_history = Vec::new();
192 full_history.extend(completion_request.chat_history);
193
194 let additional_params = completion_request
195 .additional_params
196 .unwrap_or_else(|| Value::Object(Map::new()));
197
198 let AdditionalParameters {
199 mut generation_config,
200 additional_params,
201 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
202
203 generation_config = generation_config.map(|mut cfg| {
204 if let Some(temp) = completion_request.temperature {
205 cfg.temperature = Some(temp);
206 };
207
208 if let Some(max_tokens) = completion_request.max_tokens {
209 cfg.max_output_tokens = Some(max_tokens);
210 };
211
212 cfg
213 });
214
215 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
216 parts: vec![preamble.into()],
217 role: Some(Role::Model),
218 });
219
220 let tools = if completion_request.tools.is_empty() {
221 None
222 } else {
223 Some(Tool::try_from(completion_request.tools)?)
224 };
225
226 let tool_config = if let Some(cfg) = completion_request.tool_choice {
227 Some(ToolConfig {
228 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
229 })
230 } else {
231 None
232 };
233
234 let request = GenerateContentRequest {
235 contents: full_history
236 .into_iter()
237 .map(|msg| {
238 msg.try_into()
239 .map_err(|e| CompletionError::RequestError(Box::new(e)))
240 })
241 .collect::<Result<Vec<_>, _>>()?,
242 generation_config,
243 safety_settings: None,
244 tools,
245 tool_config,
246 system_instruction,
247 additional_params,
248 };
249
250 Ok(request)
251}
252
253impl TryFrom<completion::ToolDefinition> for Tool {
254 type Error = CompletionError;
255
256 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
257 let parameters: Option<Schema> =
258 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
259 None
260 } else {
261 Some(tool.parameters.try_into()?)
262 };
263
264 Ok(Self {
265 function_declarations: vec![FunctionDeclaration {
266 name: tool.name,
267 description: tool.description,
268 parameters,
269 }],
270 code_execution: None,
271 })
272 }
273}
274
275impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
276 type Error = CompletionError;
277
278 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
279 let mut function_declarations = Vec::new();
280
281 for tool in tools {
282 let parameters =
283 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
284 None
285 } else {
286 match tool.parameters.try_into() {
287 Ok(schema) => Some(schema),
288 Err(e) => {
289 let emsg = format!(
290 "Tool '{}' could not be converted to a schema: {:?}",
291 tool.name, e,
292 );
293 return Err(CompletionError::ProviderError(emsg));
294 }
295 }
296 };
297
298 function_declarations.push(FunctionDeclaration {
299 name: tool.name,
300 description: tool.description,
301 parameters,
302 });
303 }
304
305 Ok(Self {
306 function_declarations,
307 code_execution: None,
308 })
309 }
310}
311
312impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
313 type Error = CompletionError;
314
315 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
316 let candidate = response.candidates.first().ok_or_else(|| {
317 CompletionError::ResponseError("No response candidates in response".into())
318 })?;
319
320 let content = candidate
321 .content
322 .as_ref()
323 .ok_or_else(|| {
324 let reason = candidate
325 .finish_reason
326 .as_ref()
327 .map(|r| format!("finish_reason={r:?}"))
328 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
329 let message = candidate
330 .finish_message
331 .as_deref()
332 .unwrap_or("no finish message provided");
333 CompletionError::ResponseError(format!(
334 "Gemini candidate missing content ({reason}, finish_message={message})"
335 ))
336 })?
337 .parts
338 .iter()
339 .map(|Part { thought, part, .. }| {
340 Ok(match part {
341 PartKind::Text(text) => {
342 if let Some(thought) = thought
343 && *thought
344 {
345 completion::AssistantContent::Reasoning(Reasoning::new(text))
346 } else {
347 completion::AssistantContent::text(text)
348 }
349 }
350 PartKind::InlineData(inline_data) => {
351 let mime_type = message::MediaType::from_mime_type(&inline_data.mime_type);
352
353 match mime_type {
354 Some(message::MediaType::Image(media_type)) => {
355 message::AssistantContent::image_base64(
356 &inline_data.data,
357 Some(media_type),
358 Some(message::ImageDetail::default()),
359 )
360 }
361 _ => {
362 return Err(CompletionError::ResponseError(format!(
363 "Unsupported media type {mime_type:?}"
364 )));
365 }
366 }
367 }
368 PartKind::FunctionCall(function_call) => {
369 completion::AssistantContent::tool_call(
370 &function_call.name,
371 &function_call.name,
372 function_call.args.clone(),
373 )
374 }
375 _ => {
376 return Err(CompletionError::ResponseError(
377 "Response did not contain a message or tool call".into(),
378 ));
379 }
380 })
381 })
382 .collect::<Result<Vec<_>, _>>()?;
383
384 let choice = OneOrMany::many(content).map_err(|_| {
385 CompletionError::ResponseError(
386 "Response contained no message or tool call (empty)".to_owned(),
387 )
388 })?;
389
390 let usage = response
391 .usage_metadata
392 .as_ref()
393 .map(|usage| completion::Usage {
394 input_tokens: usage.prompt_token_count as u64,
395 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
396 total_tokens: usage.total_token_count as u64,
397 })
398 .unwrap_or_default();
399
400 Ok(completion::CompletionResponse {
401 choice,
402 usage,
403 raw_response: response,
404 })
405 }
406}
407
408pub mod gemini_api_types {
409 use crate::telemetry::ProviderResponseExt;
410 use std::{collections::HashMap, convert::Infallible, str::FromStr};
411
412 use serde::{Deserialize, Serialize};
416 use serde_json::{Value, json};
417
418 use crate::completion::GetTokenUsage;
419 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
420 use crate::{
421 OneOrMany,
422 completion::CompletionError,
423 message::{self, Reasoning, Text},
424 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
425 };
426
    /// Gemini-specific extras accepted through a completion request's
    /// `additional_params` JSON.
    ///
    /// The `generationConfig` key is parsed into a typed [`GenerationConfig`]. Every
    /// other key is captured via `#[serde(flatten)]` into `additional_params`, which
    /// `create_request_body` passes on to the outgoing request.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Typed generation settings (`generationConfig` on the wire).
        pub generation_config: Option<GenerationConfig>,
        /// All remaining keys, flattened into the surrounding JSON object.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
436
437 impl AdditionalParameters {
438 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
439 self.generation_config = Some(cfg);
440 self
441 }
442
443 pub fn with_params(mut self, params: serde_json::Value) -> Self {
444 self.additional_params = Some(params);
445 self
446 }
447 }
448
    /// Body of a Gemini `generateContent` response (camelCase on the wire).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Identifier of this response (`responseId`).
        pub response_id: String,
        /// Generated candidates; the completion conversion only inspects the first.
        pub candidates: Vec<ContentCandidate>,
        /// Feedback about the prompt, including an optional block reason.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Token accounting for the request, when reported.
        pub usage_metadata: Option<UsageMetadata>,
        /// Version string of the model that produced the response, when reported.
        pub model_version: Option<String>,
    }
468
469 impl ProviderResponseExt for GenerateContentResponse {
470 type OutputMessage = ContentCandidate;
471 type Usage = UsageMetadata;
472
473 fn get_response_id(&self) -> Option<String> {
474 Some(self.response_id.clone())
475 }
476
477 fn get_response_model_name(&self) -> Option<String> {
478 None
479 }
480
481 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
482 self.candidates.clone()
483 }
484
485 fn get_text_response(&self) -> Option<String> {
486 let str = self
487 .candidates
488 .iter()
489 .filter_map(|x| {
490 let content = x.content.as_ref()?;
491 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
492 return None;
493 }
494
495 let res = content
496 .parts
497 .iter()
498 .filter_map(|part| {
499 if let PartKind::Text(ref str) = part.part {
500 Some(str.to_owned())
501 } else {
502 None
503 }
504 })
505 .collect::<Vec<String>>()
506 .join("\n");
507
508 Some(res)
509 })
510 .collect::<Vec<String>>()
511 .join("\n");
512
513 if str.is_empty() { None } else { Some(str) }
514 }
515
516 fn get_usage(&self) -> Option<Self::Usage> {
517 self.usage_metadata.clone()
518 }
519 }
520
    /// One generated candidate of a [`GenerateContentResponse`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content. When absent, the completion conversion reports
        /// `finish_reason` / `finish_message` in its error instead.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Why generation stopped.
        pub finish_reason: Option<FinishReason>,
        /// Per-category safety ratings for this candidate.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Sources cited by the generated content.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate, when reported.
        pub token_count: Option<i32>,
        /// Average log probability of the candidate, when reported.
        pub avg_logprobs: Option<f64>,
        /// Detailed per-token log probabilities, when requested.
        pub logprobs_result: Option<LogprobsResult>,
        /// Position of this candidate in the response's candidate list.
        pub index: Option<i32>,
        /// Human-readable detail accompanying `finish_reason`, when provided.
        pub finish_message: Option<String>,
    }
549
    /// A conversational turn: an ordered list of parts plus an optional author role.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Parts making up the turn; defaults to an empty list when missing from JSON.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the turn. `None` is treated as a user turn when converting back
        /// into a rig message.
        pub role: Option<Role>,
    }
559
560 impl TryFrom<message::Message> for Content {
561 type Error = message::MessageError;
562
563 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
564 Ok(match msg {
565 message::Message::User { content } => Content {
566 parts: content
567 .into_iter()
568 .map(|c| c.try_into())
569 .collect::<Result<Vec<_>, _>>()?,
570 role: Some(Role::User),
571 },
572 message::Message::Assistant { content, .. } => Content {
573 role: Some(Role::Model),
574 parts: content
575 .into_iter()
576 .map(|content| content.try_into())
577 .collect::<Result<Vec<_>, _>>()?,
578 },
579 })
580 }
581 }
582
583 impl TryFrom<Content> for message::Message {
584 type Error = message::MessageError;
585
586 fn try_from(content: Content) -> Result<Self, Self::Error> {
587 match content.role {
588 Some(Role::User) | None => {
589 Ok(message::Message::User {
590 content: {
591 let user_content: Result<Vec<_>, _> = content.parts.into_iter()
592 .map(|Part { part, .. }| {
593 Ok(match part {
594 PartKind::Text(text) => message::UserContent::text(text),
595 PartKind::InlineData(inline_data) => {
596 let mime_type =
597 message::MediaType::from_mime_type(&inline_data.mime_type);
598
599 match mime_type {
600 Some(message::MediaType::Image(media_type)) => {
601 message::UserContent::image_base64(
602 inline_data.data,
603 Some(media_type),
604 Some(message::ImageDetail::default()),
605 )
606 }
607 Some(message::MediaType::Document(media_type)) => {
608 message::UserContent::document(
609 inline_data.data,
610 Some(media_type),
611 )
612 }
613 Some(message::MediaType::Audio(media_type)) => {
614 message::UserContent::audio(
615 inline_data.data,
616 Some(media_type),
617 )
618 }
619 _ => {
620 return Err(message::MessageError::ConversionError(
621 format!("Unsupported media type {mime_type:?}"),
622 ));
623 }
624 }
625 }
626 _ => {
627 return Err(message::MessageError::ConversionError(format!(
628 "Unsupported gemini content part type: {part:?}"
629 )));
630 }
631 })
632 })
633 .collect();
634 OneOrMany::many(user_content?).map_err(|_| {
635 message::MessageError::ConversionError(
636 "Failed to create OneOrMany from user content".to_string(),
637 )
638 })?
639 },
640 })
641 }
642 Some(Role::Model) => Ok(message::Message::Assistant {
643 id: None,
644 content: {
645 let assistant_content: Result<Vec<_>, _> = content
646 .parts
647 .into_iter()
648 .map(|Part { thought, part, .. }| {
649 Ok(match part {
650 PartKind::Text(text) => match thought {
651 Some(true) => message::AssistantContent::Reasoning(
652 Reasoning::new(&text),
653 ),
654 _ => message::AssistantContent::Text(Text { text }),
655 },
656
657 PartKind::FunctionCall(function_call) => {
658 message::AssistantContent::ToolCall(function_call.into())
659 }
660 _ => {
661 return Err(message::MessageError::ConversionError(
662 format!("Unsupported part type: {part:?}"),
663 ));
664 }
665 })
666 })
667 .collect();
668 OneOrMany::many(assistant_content?).map_err(|_| {
669 message::MessageError::ConversionError(
670 "Failed to create OneOrMany from assistant content".to_string(),
671 )
672 })?
673 },
674 }),
675 }
676 }
677 }
678
    /// Author of a [`Content`] turn, serialized as `"user"` or `"model"`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// The end user (also used for tool results sent back to the model).
        User,
        /// The model.
        Model,
    }
685
    /// A single part of a [`Content`] turn.
    ///
    /// The payload ([`PartKind`]) is flattened into the part object, next to the
    /// optional thought flags and any extra provider fields.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// `Some(true)` marks the text as model reasoning rather than answer text.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Opaque `thoughtSignature`. The rig-content conversions visible in this
        /// module always leave it `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The part's payload, flattened into the same JSON object.
        #[serde(flatten)]
        pub part: PartKind,
        /// Any further fields, flattened into the same JSON object.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
700
    /// Payload of a [`Part`], externally tagged by its camelCase variant name
    /// (e.g. `{"text": "..."}` or `{"functionCall": {...}}`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Inline binary data.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// A function result sent back to the model.
        FunctionResponse(FunctionResponse),
        /// A reference to a file by URI.
        FileData(FileData),
        /// Code produced for the code-execution tool.
        ExecutableCode(ExecutableCode),
        /// Output of executed code.
        CodeExecutionResult(CodeExecutionResult),
    }
715
716 impl Default for PartKind {
719 fn default() -> Self {
720 Self::Text(String::new())
721 }
722 }
723
724 impl From<String> for Part {
725 fn from(text: String) -> Self {
726 Self {
727 thought: Some(false),
728 thought_signature: None,
729 part: PartKind::Text(text),
730 additional_params: None,
731 }
732 }
733 }
734
735 impl From<&str> for Part {
736 fn from(text: &str) -> Self {
737 Self::from(text.to_string())
738 }
739 }
740
741 impl FromStr for Part {
742 type Err = Infallible;
743
744 fn from_str(s: &str) -> Result<Self, Self::Err> {
745 Ok(s.into())
746 }
747 }
748
749 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
750 type Error = message::MessageError;
751 fn try_from(
752 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
753 ) -> Result<Self, Self::Error> {
754 let mime_type = mime_type.to_mime_type().to_string();
755 let part = match doc_src {
756 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
757 mime_type: Some(mime_type),
758 file_uri: url,
759 }),
760 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
761 PartKind::InlineData(Blob { mime_type, data })
762 }
763 DocumentSourceKind::Raw(_) => {
764 return Err(message::MessageError::ConversionError(
765 "Raw files not supported, encode as base64 first".into(),
766 ));
767 }
768 DocumentSourceKind::Unknown => {
769 return Err(message::MessageError::ConversionError(
770 "Can't convert an unknown document source".to_string(),
771 ));
772 }
773 };
774
775 Ok(part)
776 }
777 }
778
779 impl TryFrom<message::UserContent> for Part {
780 type Error = message::MessageError;
781
782 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
783 match content {
784 message::UserContent::Text(message::Text { text }) => Ok(Part {
785 thought: Some(false),
786 thought_signature: None,
787 part: PartKind::Text(text),
788 additional_params: None,
789 }),
790 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
791 let content = match content.first() {
792 message::ToolResultContent::Text(text) => text.text,
793 message::ToolResultContent::Image(_) => {
794 return Err(message::MessageError::ConversionError(
795 "Tool result content must be text".to_string(),
796 ));
797 }
798 };
799 let result: serde_json::Value =
801 serde_json::from_str(&content).unwrap_or_else(|error| {
802 tracing::trace!(
803 ?error,
804 "Tool result is not a valid JSON, treat it as normal string"
805 );
806 json!(content)
807 });
808 Ok(Part {
809 thought: Some(false),
810 thought_signature: None,
811 part: PartKind::FunctionResponse(FunctionResponse {
812 name: id,
813 response: Some(json!({ "result": result })),
814 }),
815 additional_params: None,
816 })
817 }
818 message::UserContent::Image(message::Image {
819 data, media_type, ..
820 }) => match media_type {
821 Some(media_type) => match media_type {
822 message::ImageMediaType::JPEG
823 | message::ImageMediaType::PNG
824 | message::ImageMediaType::WEBP
825 | message::ImageMediaType::HEIC
826 | message::ImageMediaType::HEIF => {
827 let part = PartKind::try_from((media_type, data))?;
828 Ok(Part {
829 thought: Some(false),
830 thought_signature: None,
831 part,
832 additional_params: None,
833 })
834 }
835 _ => Err(message::MessageError::ConversionError(format!(
836 "Unsupported image media type {media_type:?}"
837 ))),
838 },
839 None => Err(message::MessageError::ConversionError(
840 "Media type for image is required for Gemini".to_string(),
841 )),
842 },
843 message::UserContent::Document(message::Document {
844 data, media_type, ..
845 }) => {
846 let Some(media_type) = media_type else {
847 return Err(MessageError::ConversionError(
848 "A mime type is required for document inputs to Gemini".to_string(),
849 ));
850 };
851
852 if !media_type.is_code() {
853 let mime_type = media_type.to_mime_type().to_string();
854
855 let part = match data {
856 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
857 mime_type: Some(mime_type),
858 file_uri,
859 }),
860 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
861 PartKind::InlineData(Blob { mime_type, data })
862 }
863 DocumentSourceKind::Raw(_) => {
864 return Err(message::MessageError::ConversionError(
865 "Raw files not supported, encode as base64 first".into(),
866 ));
867 }
868 _ => {
869 return Err(message::MessageError::ConversionError(
870 "Document has no body".to_string(),
871 ));
872 }
873 };
874
875 Ok(Part {
876 thought: Some(false),
877 part,
878 ..Default::default()
879 })
880 } else {
881 Err(message::MessageError::ConversionError(format!(
882 "Unsupported document media type {media_type:?}"
883 )))
884 }
885 }
886
887 message::UserContent::Audio(message::Audio {
888 data, media_type, ..
889 }) => {
890 let Some(media_type) = media_type else {
891 return Err(MessageError::ConversionError(
892 "A mime type is required for audio inputs to Gemini".to_string(),
893 ));
894 };
895
896 let mime_type = media_type.to_mime_type().to_string();
897
898 let part = match data {
899 DocumentSourceKind::Base64(data) => {
900 PartKind::InlineData(Blob { data, mime_type })
901 }
902
903 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
904 mime_type: Some(mime_type),
905 file_uri,
906 }),
907 DocumentSourceKind::String(_) => {
908 return Err(message::MessageError::ConversionError(
909 "Strings cannot be used as audio files!".into(),
910 ));
911 }
912 DocumentSourceKind::Raw(_) => {
913 return Err(message::MessageError::ConversionError(
914 "Raw files not supported, encode as base64 first".into(),
915 ));
916 }
917 DocumentSourceKind::Unknown => {
918 return Err(message::MessageError::ConversionError(
919 "Content has no body".to_string(),
920 ));
921 }
922 };
923
924 Ok(Part {
925 thought: Some(false),
926 part,
927 ..Default::default()
928 })
929 }
930 message::UserContent::Video(message::Video {
931 data,
932 media_type,
933 additional_params,
934 ..
935 }) => {
936 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
937
938 let part = match data {
939 DocumentSourceKind::Url(file_uri) => {
940 if file_uri.starts_with("https://www.youtube.com") {
941 PartKind::FileData(FileData {
942 mime_type,
943 file_uri,
944 })
945 } else {
946 if mime_type.is_none() {
947 return Err(MessageError::ConversionError(
948 "A mime type is required for non-Youtube video file inputs to Gemini"
949 .to_string(),
950 ));
951 }
952
953 PartKind::FileData(FileData {
954 mime_type,
955 file_uri,
956 })
957 }
958 }
959 DocumentSourceKind::Base64(data) => {
960 let Some(mime_type) = mime_type else {
961 return Err(MessageError::ConversionError(
962 "A media type is expected for base64 encoded strings"
963 .to_string(),
964 ));
965 };
966 PartKind::InlineData(Blob { mime_type, data })
967 }
968 DocumentSourceKind::String(_) => {
969 return Err(message::MessageError::ConversionError(
970 "Strings cannot be used as audio files!".into(),
971 ));
972 }
973 DocumentSourceKind::Raw(_) => {
974 return Err(message::MessageError::ConversionError(
975 "Raw file data not supported, encode as base64 first".into(),
976 ));
977 }
978 DocumentSourceKind::Unknown => {
979 return Err(message::MessageError::ConversionError(
980 "Media type for video is required for Gemini".to_string(),
981 ));
982 }
983 };
984
985 Ok(Part {
986 thought: Some(false),
987 thought_signature: None,
988 part,
989 additional_params,
990 })
991 }
992 }
993 }
994 }
995
996 impl TryFrom<message::AssistantContent> for Part {
997 type Error = message::MessageError;
998
999 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1000 match content {
1001 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1002 message::AssistantContent::Image(message::Image {
1003 data, media_type, ..
1004 }) => match media_type {
1005 Some(media_type) => match media_type {
1006 message::ImageMediaType::JPEG
1007 | message::ImageMediaType::PNG
1008 | message::ImageMediaType::WEBP
1009 | message::ImageMediaType::HEIC
1010 | message::ImageMediaType::HEIF => {
1011 let part = PartKind::try_from((media_type, data))?;
1012 Ok(Part {
1013 thought: Some(false),
1014 thought_signature: None,
1015 part,
1016 additional_params: None,
1017 })
1018 }
1019 _ => Err(message::MessageError::ConversionError(format!(
1020 "Unsupported image media type {media_type:?}"
1021 ))),
1022 },
1023 None => Err(message::MessageError::ConversionError(
1024 "Media type for image is required for Gemini".to_string(),
1025 )),
1026 },
1027 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1028 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
1029 Ok(Part {
1030 thought: Some(true),
1031 thought_signature: None,
1032 part: PartKind::Text(
1033 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
1034 ),
1035 additional_params: None,
1036 })
1037 }
1038 }
1039 }
1040 }
1041
1042 impl From<message::ToolCall> for Part {
1043 fn from(tool_call: message::ToolCall) -> Self {
1044 Self {
1045 thought: Some(false),
1046 thought_signature: None,
1047 part: PartKind::FunctionCall(FunctionCall {
1048 name: tool_call.function.name,
1049 args: tool_call.function.arguments,
1050 }),
1051 additional_params: None,
1052 }
1053 }
1054 }
1055
    /// Inline binary data (`inlineData`): a mime type plus the encoded payload.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// Mime type of the payload, e.g. `image/png`.
        pub mime_type: String,
        /// Encoded payload; the conversions in this module fill it from base64 (or
        /// string) sources.
        pub data: String,
    }
1067
    /// A function call requested by the model: the function name plus JSON arguments.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the declared function to call.
        pub name: String,
        /// Call arguments as arbitrary JSON.
        pub args: serde_json::Value,
    }
1078
1079 impl From<FunctionCall> for message::ToolCall {
1080 fn from(function_call: FunctionCall) -> Self {
1081 Self {
1082 id: function_call.name.clone(),
1083 call_id: None,
1084 function: message::ToolFunction {
1085 name: function_call.name,
1086 arguments: function_call.args,
1087 },
1088 }
1089 }
1090 }
1091
1092 impl From<message::ToolCall> for FunctionCall {
1093 fn from(tool_call: message::ToolCall) -> Self {
1094 Self {
1095 name: tool_call.function.name,
1096 args: tool_call.function.arguments,
1097 }
1098 }
1099 }
1100
    /// The result of a function call sent back to the model (`functionResponse`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Function the result belongs to. The user-content conversion fills this
        /// from the tool-call id, which for Gemini-originated calls equals the
        /// function name.
        pub name: String,
        /// Result payload; the conversions here wrap it as `{"result": ...}`.
        pub response: Option<serde_json::Value>,
    }
1112
    /// A reference to file content by URI (`fileData`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// Mime type of the referenced file; optional (e.g. for YouTube video URLs).
        pub mime_type: Option<String>,
        /// URI of the file.
        pub file_uri: String,
    }
1122
    /// Safety rating of content for a single harm category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category being rated.
        pub category: HarmCategory,
        /// Likelihood that the content is harmful in this category.
        pub probability: HarmProbability,
    }
1128
    /// Likelihood levels reported in a [`SafetyRating`] (SCREAMING_SNAKE_CASE on the wire).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1138
    /// Harm categories used in safety ratings (SCREAMING_SNAKE_CASE on the wire).
    ///
    /// NOTE(review): there is no `#[serde(other)]` fallback, so a category value not
    /// listed here makes deserialization of the enclosing response fail.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1155
    /// Token accounting reported by Gemini (`usageMetadata`).
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt.
        pub prompt_token_count: i32,
        /// Tokens served from cached content, when reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates, when reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count as reported by the API.
        pub total_token_count: i32,
        /// Tokens spent on model thinking, when reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1168
1169 impl std::fmt::Display for UsageMetadata {
1170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1171 write!(
1172 f,
1173 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1174 self.prompt_token_count,
1175 match self.cached_content_token_count {
1176 Some(count) => count.to_string(),
1177 None => "n/a".to_string(),
1178 },
1179 match self.candidates_token_count {
1180 Some(count) => count.to_string(),
1181 None => "n/a".to_string(),
1182 },
1183 self.total_token_count
1184 )
1185 }
1186 }
1187
1188 impl GetTokenUsage for UsageMetadata {
1189 fn token_usage(&self) -> Option<crate::completion::Usage> {
1190 let mut usage = crate::completion::Usage::new();
1191
1192 usage.input_tokens = self.prompt_token_count as u64;
1193 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1194 + self.candidates_token_count.unwrap_or_default()
1195 + self.thoughts_token_count.unwrap_or_default())
1196 as u64;
1197 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1198
1199 Some(usage)
1200 }
1201 }
1202
    /// Feedback about the prompt itself (`promptFeedback`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked, if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1212
    /// Reasons a prompt can be blocked (SCREAMING_SNAKE_CASE on the wire).
    ///
    /// NOTE(review): there is no `#[serde(other)]` fallback, so an unlisted value
    /// fails deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1228
    /// Why a candidate stopped generating (SCREAMING_SNAKE_CASE on the wire).
    ///
    /// NOTE(review): there is no `#[serde(other)]` fallback, so a finish reason not
    /// listed here makes the whole response fail to deserialize. Consider adding
    /// newer API values.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
1255
    /// Citation information; the `citationSources` array is required when deserializing
    /// (no `default` is configured).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        pub citation_sources: Vec<CitationSource>,
    }
1261
    /// A single cited source. Every field is optional and is left out of the JSON when `None`;
    /// keys use camelCase (`startIndex`, `endIndex`, ...).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// Location of the cited source, if provided.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the cited span; presumably an offset into the generated text — confirm
        /// the unit against the API reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the cited span; same caveat as `start_index`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License string of the source, if provided.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1274
1275 #[derive(Clone, Debug, Deserialize, Serialize)]
1276 #[serde(rename_all = "camelCase")]
1277 pub struct LogprobsResult {
1278 pub top_candidate: Vec<TopCandidate>,
1279 pub chosen_candidate: Vec<LogProbCandidate>,
1280 }
1281
1282 #[derive(Clone, Debug, Deserialize, Serialize)]
1283 pub struct TopCandidate {
1284 pub candidates: Vec<LogProbCandidate>,
1285 }
1286
1287 #[derive(Clone, Debug, Deserialize, Serialize)]
1288 #[serde(rename_all = "camelCase")]
1289 pub struct LogProbCandidate {
1290 pub token: String,
1291 pub token_id: String,
1292 pub log_probability: f64,
1293 }
1294
    /// Generation options for a request, serialized with camelCase keys.
    ///
    /// Every field is optional and is left out of the JSON when `None`. Field declaration order
    /// is also the serialized key order.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type requested for the response body, passed through as-is.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Structured-output schema in Gemini's restricted [`Schema`] form.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Raw JSON schema sent under the explicit key `_responseJsonSchema` (the leading
        /// underscore is kept by the `rename`).
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Raw JSON schema sent as `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Nested `thinkingConfig` object.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Nested `imageConfig` object.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1387
1388 impl Default for GenerationConfig {
1389 fn default() -> Self {
1390 Self {
1391 temperature: Some(1.0),
1392 max_output_tokens: Some(4096),
1393 stop_sequences: None,
1394 response_mime_type: None,
1395 response_schema: None,
1396 _response_json_schema: None,
1397 response_json_schema: None,
1398 candidate_count: None,
1399 top_p: None,
1400 top_k: None,
1401 presence_penalty: None,
1402 frequency_penalty: None,
1403 response_logprobs: None,
1404 logprobs: None,
1405 thinking_config: None,
1406 image_config: None,
1407 }
1408 }
1409 }
1410
1411 #[derive(Debug, Deserialize, Serialize)]
1412 #[serde(rename_all = "camelCase")]
1413 pub struct ThinkingConfig {
1414 pub thinking_budget: u32,
1415 pub include_thoughts: Option<bool>,
1416 }
1417
    /// Image-output options, nested in the request as `generationConfig.imageConfig`.
    ///
    /// Both values are passed through as plain strings (`aspectRatio`, `imageSize`) and are
    /// omitted when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1426
    /// Gemini's restricted schema object. It is used for function parameters
    /// (`FunctionDeclaration::parameters`) and for structured output
    /// (`GenerationConfig::response_schema`), and is usually built from a JSON Schema document
    /// through its `TryFrom<Value>` impl.
    ///
    /// NOTE(review): no `rename_all` is visible on this struct, so multi-word fields serialize
    /// under their Rust names (e.g. `max_items`) — confirm the API accepts that spelling.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name such as `"object"` or `"array"`. The `TryFrom<Value>` conversion leaves it
        /// as an empty string when no type can be inferred.
        pub r#type: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Property name -> sub-schema, for object types.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Element schema, for array types (boxed because the type is recursive).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1452
1453 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1459 let defs = if let Some(obj) = schema.as_object() {
1461 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1462 } else {
1463 None
1464 };
1465
1466 let Some(defs_value) = defs else {
1467 return Ok(schema);
1468 };
1469
1470 let Some(defs_obj) = defs_value.as_object() else {
1471 return Err(CompletionError::ResponseError(
1472 "$defs must be an object".into(),
1473 ));
1474 };
1475
1476 resolve_refs(&mut schema, defs_obj)?;
1477
1478 if let Some(obj) = schema.as_object_mut() {
1480 obj.remove("$defs");
1481 obj.remove("definitions");
1482 }
1483
1484 Ok(schema)
1485 }
1486
1487 fn resolve_refs(
1490 value: &mut Value,
1491 defs: &serde_json::Map<String, Value>,
1492 ) -> Result<(), CompletionError> {
1493 match value {
1494 Value::Object(obj) => {
1495 if let Some(ref_value) = obj.get("$ref")
1496 && let Some(ref_str) = ref_value.as_str()
1497 {
1498 let def_name = parse_ref_path(ref_str)?;
1500
1501 let def = defs.get(&def_name).ok_or_else(|| {
1502 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1503 })?;
1504
1505 let mut resolved = def.clone();
1506 resolve_refs(&mut resolved, defs)?;
1507 *value = resolved;
1508 return Ok(());
1509 }
1510
1511 for (_, v) in obj.iter_mut() {
1512 resolve_refs(v, defs)?;
1513 }
1514 }
1515 Value::Array(arr) => {
1516 for item in arr.iter_mut() {
1517 resolve_refs(item, defs)?;
1518 }
1519 }
1520 _ => {}
1521 }
1522
1523 Ok(())
1524 }
1525
1526 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1532 if let Some(fragment) = ref_str.strip_prefix('#') {
1533 if let Some(name) = fragment.strip_prefix("/$defs/") {
1534 Ok(name.to_string())
1535 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1536 Ok(name.to_string())
1537 } else {
1538 Err(CompletionError::ResponseError(format!(
1539 "Unsupported reference format: {}",
1540 ref_str
1541 )))
1542 }
1543 } else {
1544 Err(CompletionError::ResponseError(format!(
1545 "Only fragment references (#/...) are supported: {}",
1546 ref_str
1547 )))
1548 }
1549 }
1550
1551 fn extract_type(type_value: &Value) -> Option<String> {
1554 if type_value.is_string() {
1555 type_value.as_str().map(String::from)
1556 } else if type_value.is_array() {
1557 type_value
1558 .as_array()
1559 .and_then(|arr| arr.first())
1560 .and_then(|v| v.as_str().map(String::from))
1561 } else {
1562 None
1563 }
1564 }
1565
1566 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1569 composition.as_array().and_then(|arr| {
1570 arr.iter().find_map(|schema| {
1571 if let Some(obj) = schema.as_object() {
1572 if let Some(type_val) = obj.get("type")
1574 && let Some(type_str) = type_val.as_str()
1575 && type_str == "null"
1576 {
1577 return None;
1578 }
1579 obj.get("type").and_then(extract_type).or_else(|| {
1581 if obj.contains_key("properties") {
1582 Some("object".to_string())
1583 } else {
1584 None
1585 }
1586 })
1587 } else {
1588 None
1589 }
1590 })
1591 })
1592 }
1593
1594 fn extract_schema_from_composition(
1597 composition: &Value,
1598 ) -> Option<serde_json::Map<String, Value>> {
1599 composition.as_array().and_then(|arr| {
1600 arr.iter().find_map(|schema| {
1601 if let Some(obj) = schema.as_object()
1602 && let Some(type_val) = obj.get("type")
1603 && let Some(type_str) = type_val.as_str()
1604 {
1605 if type_str == "null" {
1606 return None;
1607 }
1608 Some(obj.clone())
1609 } else {
1610 None
1611 }
1612 })
1613 })
1614 }
1615
1616 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1619 if let Some(type_val) = obj.get("type")
1621 && let Some(type_str) = extract_type(type_val)
1622 {
1623 return type_str;
1624 }
1625
1626 if let Some(any_of) = obj.get("anyOf")
1628 && let Some(type_str) = extract_type_from_composition(any_of)
1629 {
1630 return type_str;
1631 }
1632
1633 if let Some(one_of) = obj.get("oneOf")
1634 && let Some(type_str) = extract_type_from_composition(one_of)
1635 {
1636 return type_str;
1637 }
1638
1639 if let Some(all_of) = obj.get("allOf")
1640 && let Some(type_str) = extract_type_from_composition(all_of)
1641 {
1642 return type_str;
1643 }
1644
1645 if obj.contains_key("properties") {
1647 "object".to_string()
1648 } else {
1649 String::new()
1650 }
1651 }
1652
1653 impl TryFrom<Value> for Schema {
1654 type Error = CompletionError;
1655
1656 fn try_from(value: Value) -> Result<Self, Self::Error> {
1657 let flattened_val = flatten_schema(value)?;
1658 if let Some(obj) = flattened_val.as_object() {
1659 let props_source = if obj.get("properties").is_none() {
1662 if let Some(any_of) = obj.get("anyOf") {
1663 extract_schema_from_composition(any_of)
1664 } else if let Some(one_of) = obj.get("oneOf") {
1665 extract_schema_from_composition(one_of)
1666 } else if let Some(all_of) = obj.get("allOf") {
1667 extract_schema_from_composition(all_of)
1668 } else {
1669 None
1670 }
1671 .unwrap_or(obj.clone())
1672 } else {
1673 obj.clone()
1674 };
1675
1676 Ok(Schema {
1677 r#type: infer_type(obj),
1678 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1679 description: obj
1680 .get("description")
1681 .and_then(|v| v.as_str())
1682 .map(String::from),
1683 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1684 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1685 arr.iter()
1686 .filter_map(|v| v.as_str().map(String::from))
1687 .collect()
1688 }),
1689 max_items: obj
1690 .get("maxItems")
1691 .and_then(|v| v.as_i64())
1692 .map(|v| v as i32),
1693 min_items: obj
1694 .get("minItems")
1695 .and_then(|v| v.as_i64())
1696 .map(|v| v as i32),
1697 properties: props_source
1698 .get("properties")
1699 .and_then(|v| v.as_object())
1700 .map(|map| {
1701 map.iter()
1702 .filter_map(|(k, v)| {
1703 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1704 })
1705 .collect()
1706 }),
1707 required: props_source
1708 .get("required")
1709 .and_then(|v| v.as_array())
1710 .map(|arr| {
1711 arr.iter()
1712 .filter_map(|v| v.as_str().map(String::from))
1713 .collect()
1714 }),
1715 items: obj
1716 .get("items")
1717 .and_then(|v| v.clone().try_into().ok())
1718 .map(Box::new),
1719 })
1720 } else {
1721 Err(CompletionError::ResponseError(
1722 "Expected a JSON object for Schema".into(),
1723 ))
1724 }
1725 }
1726 }
1727
    /// Request body for a Gemini content-generation call, serialized with camelCase keys
    /// (`contents`, `tools`, `toolConfig`, `generationConfig`, `safetySettings`,
    /// `systemInstruction`). The type derives only `Serialize`, not `Deserialize`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation turns sent to the model.
        pub contents: Vec<Content>,
        /// Tool declarations; the key is omitted entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Tool>,
        // NOTE(review): the next four fields have no visible `skip_serializing_if`, so `None`
        // is sent as an explicit JSON `null` — confirm the API treats that like an absent key.
        pub tool_config: Option<ToolConfig>,
        pub generation_config: Option<GenerationConfig>,
        pub safety_settings: Option<Vec<SafetySetting>>,
        pub system_instruction: Option<Content>,
        /// Extra keys merged into the top level of the request by `#[serde(flatten)]`.
        /// Presumably always a JSON object — flattening other JSON values fails at
        /// serialization time; TODO confirm callers guarantee this.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1759
    /// Tool block of a request: the declared functions plus an optional code-execution marker,
    /// serialized as `functionDeclarations` / `codeExecution`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        pub function_declarations: Vec<FunctionDeclaration>,
        // NOTE(review): no `skip_serializing_if`, so `None` serializes as `"codeExecution": null`.
        pub code_execution: Option<CodeExecution>,
    }
1766
    /// A single function advertised to the model as a callable tool.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Name of the function.
        pub name: String,
        /// Free-text description of the function.
        pub description: String,
        /// Parameter schema; left out of the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1775
    /// Tool configuration, serialized as `{"functionCallingConfig": ...}`.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1781
    /// Function-calling mode, internally tagged on a `mode` key with UPPERCASE variant names:
    /// `{"mode": "AUTO"}`, `{"mode": "NONE"}` or `{"mode": "ANY", ...}`. Defaults to `Auto`.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        #[default]
        Auto,
        None,
        Any {
            /// Restricts calls to these function names; omitted when `None`.
            /// `rename_all` on the enum renames variants only, so this key is serialized in
            /// snake_case as `allowed_function_names`.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1793
1794 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1795 type Error = CompletionError;
1796 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1797 let res = match value {
1798 message::ToolChoice::Auto => Self::Auto,
1799 message::ToolChoice::None => Self::None,
1800 message::ToolChoice::Required => Self::Any {
1801 allowed_function_names: None,
1802 },
1803 message::ToolChoice::Specific { function_names } => Self::Any {
1804 allowed_function_names: Some(function_names),
1805 },
1806 };
1807
1808 Ok(res)
1809 }
1810 }
1811
    /// Marker that enables code execution in a [`Tool`]; it has no fields and serializes as `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1814
    /// Blocking threshold for one harm category, serialized as
    /// `{"category": ..., "threshold": ...}`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        pub category: HarmCategory,
        pub threshold: HarmBlockThreshold,
    }
1821
    /// Safety blocking threshold, serialized as SCREAMING_SNAKE_CASE strings
    /// (e.g. `BLOCK_MEDIUM_AND_ABOVE`, `BLOCK_NONE`, `OFF`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
1832}
1833
#[cfg(test)]
mod tests {
    // Unit tests for Gemini wire-type (de)serialization, message conversion and JSON Schema
    // -> `Schema` conversion.
    use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let content: Content = {
            let jd = &mut serde_json::Deserializer::from_str(raw_message);
            serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let parts: Vec<Part> = content.parts.into_iter().collect();

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }

        if let Part {
            part: PartKind::InlineData(inline_data),
            ..
        } = &parts[1]
        {
            assert_eq!(inline_data.mime_type, "image/png");
            assert_eq!(inline_data.data, "base64encodeddata");
        } else {
            panic!("Expected inline data part");
        }

        if let Part {
            part: PartKind::FunctionCall(function_call),
            ..
        } = &parts[2]
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }

        if let Part {
            part: PartKind::FunctionResponse(function_response),
            ..
        } = &parts[3]
        {
            assert_eq!(function_response.name, "test_function");
            assert_eq!(
                function_response
                    .response
                    .as_ref()
                    .unwrap()
                    .get("result")
                    .unwrap(),
                "success"
            );
        } else {
            panic!("Expected function response part");
        }

        if let Part {
            part: PartKind::FileData(file_data),
            ..
        } = &parts[4]
        {
            assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
            assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
        } else {
            panic!("Expected file data part");
        }

        if let Part {
            part: PartKind::ExecutableCode(executable_code),
            ..
        } = &parts[5]
        {
            assert_eq!(executable_code.code, "print('Hello, world!')");
        } else {
            panic!("Expected executable code part");
        }

        if let Part {
            part: PartKind::CodeExecutionResult(code_execution_result),
            ..
        } = &parts[6]
        {
            assert_eq!(
                code_execution_result.clone().output.unwrap(),
                "Hello, world!"
            );
        } else {
            panic!("Expected code execution result part");
        }
    }

    #[test]
    fn test_deserialize_message_model() {
        let json_data = json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        });

        let content: Content = serde_json::from_value(json_data).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_user() {
        let msg = message::Message::user("Hello, world!");
        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_model() {
        let msg = message::Message::assistant("Hello, user!");

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::Text(text),
            ..
        }) = &content.parts.first()
        {
            assert_eq!(text, "Hello, user!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_message_conversion_tool_call() {
        let tool_call = message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
        };

        let content: Content = msg.try_into().unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        if let Some(Part {
            part: PartKind::FunctionCall(function_call),
            ..
        }) = content.parts.first()
        {
            assert_eq!(function_call.name, "test_function");
            assert_eq!(
                function_call.args.as_object().unwrap().get("arg1").unwrap(),
                "value1"
            );
        } else {
            panic!("Expected function call part");
        }
    }

    #[test]
    fn test_vec_schema_conversion() {
        let schema_with_ref = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "first_name": {
                            "type": ["string", "null"],
                            "description": "The person's first name, if provided (null otherwise)"
                        },
                        "last_name": {
                            "type": ["string", "null"],
                            "description": "The person's last name, if provided (null otherwise)"
                        },
                        "job": {
                            "type": ["string", "null"],
                            "description": "The person's job, if provided (null otherwise)"
                        }
                    },
                    "required": []
                }
            }
        });

        // A conversion error must fail the test, not merely be printed.
        let schema: Schema = schema_with_ref
            .try_into()
            .expect("schema with a $defs reference should convert");
        assert_eq!(schema.r#type, "array");

        let items = schema
            .items
            .expect("Schema should have items field for array type");
        assert_ne!(items.r#type, "", "Items type should not be empty string!");
        assert_eq!(items.r#type, "object", "Items should be object type");
    }

    #[test]
    fn test_object_schema() {
        let simple_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        let schema: Schema = simple_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "object");
        assert!(schema.properties.is_some());
    }

    #[test]
    fn test_array_with_inline_items() {
        let inline_schema = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string"
                    }
                }
            }
        });

        let schema: Schema = inline_schema.try_into().unwrap();
        assert_eq!(schema.r#type, "array");

        let items = schema.items.expect("Schema should have items field");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }

    #[test]
    fn test_flattened_schema() {
        let ref_schema = json!({
            "type": "array",
            "items": {
                "$ref": "#/$defs/Person"
            },
            "$defs": {
                "Person": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    }
                }
            }
        });

        let flattened = flatten_schema(ref_schema).unwrap();
        let schema: Schema = flattened.try_into().unwrap();

        assert_eq!(schema.r#type, "array");

        // Missing `items` must fail the test instead of silently skipping the checks.
        let items = schema
            .items
            .expect("flattened array schema should keep its items");
        assert_eq!(items.r#type, "object");
        assert!(items.properties.is_some());
    }
}