/// Model identifier string `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier string `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier string `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier string `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier string `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier string `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier string `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier string `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27 AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32 OneOrMany,
33 completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part, PartKind,
37 Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
/// Pairs a Gemini [`Client`] with the name of the model that requests are sent to.
///
/// `T` is the HTTP client type wrapped by [`Client`]; it defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name; interpolated into the `generateContent` request path.
    pub model: String,
}
55
56impl<T> CompletionModel<T> {
57 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58 Self {
59 client,
60 model: model.into(),
61 }
62 }
63
64 pub fn with_model(client: Client<T>, model: &str) -> Self {
65 Self {
66 client,
67 model: model.into(),
68 }
69 }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74 T: HttpClientExt + Clone + 'static,
75{
76 type Response = GenerateContentResponse;
77 type StreamingResponse = StreamingCompletionResponse;
78 type Client = super::Client<T>;
79
80 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81 Self::new(client.clone(), model)
82 }
83
84 async fn completion(
85 &self,
86 completion_request: CompletionRequest,
87 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88 let span = if tracing::Span::current().is_disabled() {
89 info_span!(
90 target: "rig::completions",
91 "generate_content",
92 gen_ai.operation.name = "generate_content",
93 gen_ai.provider.name = "gcp.gemini",
94 gen_ai.request.model = self.model,
95 gen_ai.system_instructions = &completion_request.preamble,
96 gen_ai.response.id = tracing::field::Empty,
97 gen_ai.response.model = tracing::field::Empty,
98 gen_ai.usage.output_tokens = tracing::field::Empty,
99 gen_ai.usage.input_tokens = tracing::field::Empty,
100 )
101 } else {
102 tracing::Span::current()
103 };
104
105 let request = create_request_body(completion_request)?;
106
107 if enabled!(Level::TRACE) {
108 tracing::trace!(
109 target: "rig::completions",
110 "Gemini completion request: {}",
111 serde_json::to_string_pretty(&request)?
112 );
113 }
114
115 let body = serde_json::to_vec(&request)?;
116
117 let path = format!("/v1beta/models/{}:generateContent", self.model);
118
119 let request = self
120 .client
121 .post(path.as_str())?
122 .body(body)
123 .map_err(|e| CompletionError::HttpError(e.into()))?;
124
125 async move {
126 let response = self.client.send::<_, Vec<u8>>(request).await?;
127
128 if response.status().is_success() {
129 let response_body = response
130 .into_body()
131 .await
132 .map_err(CompletionError::HttpError)?;
133
134 let response_text = String::from_utf8_lossy(&response_body).to_string();
135
136 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
137 .map_err(|err| {
138 tracing::error!(
139 error = %err,
140 body = %response_text,
141 "Failed to deserialize Gemini completion response"
142 );
143 CompletionError::JsonError(err)
144 })?;
145
146 let span = tracing::Span::current();
147 span.record_response_metadata(&response);
148 span.record_token_usage(&response.usage_metadata);
149
150 if enabled!(Level::TRACE) {
151 tracing::trace!(
152 target: "rig::completions",
153 "Gemini completion response: {}",
154 serde_json::to_string_pretty(&response)?
155 );
156 }
157
158 response.try_into()
159 } else {
160 let text = String::from_utf8_lossy(
161 &response
162 .into_body()
163 .await
164 .map_err(CompletionError::HttpError)?,
165 )
166 .into();
167
168 Err(CompletionError::ProviderError(text))
169 }
170 }
171 .instrument(span)
172 .await
173 }
174
175 async fn stream(
176 &self,
177 request: CompletionRequest,
178 ) -> Result<
179 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
180 CompletionError,
181 > {
182 CompletionModel::stream(self, request).await
183 }
184}
185
186pub(crate) fn create_request_body(
187 completion_request: CompletionRequest,
188) -> Result<GenerateContentRequest, CompletionError> {
189 let mut full_history = Vec::new();
190 full_history.extend(completion_request.chat_history);
191
192 let additional_params = completion_request
193 .additional_params
194 .unwrap_or_else(|| Value::Object(Map::new()));
195
196 let AdditionalParameters {
197 mut generation_config,
198 additional_params,
199 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
200
201 generation_config = generation_config.map(|mut cfg| {
202 if let Some(temp) = completion_request.temperature {
203 cfg.temperature = Some(temp);
204 };
205
206 if let Some(max_tokens) = completion_request.max_tokens {
207 cfg.max_output_tokens = Some(max_tokens);
208 };
209
210 cfg
211 });
212
213 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
214 parts: vec![preamble.into()],
215 role: Some(Role::Model),
216 });
217
218 let tools = if completion_request.tools.is_empty() {
219 None
220 } else {
221 Some(vec![Tool::try_from(completion_request.tools)?])
222 };
223
224 let tool_config = if let Some(cfg) = completion_request.tool_choice {
225 Some(ToolConfig {
226 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
227 })
228 } else {
229 None
230 };
231
232 let request = GenerateContentRequest {
233 contents: full_history
234 .into_iter()
235 .map(|msg| {
236 msg.try_into()
237 .map_err(|e| CompletionError::RequestError(Box::new(e)))
238 })
239 .collect::<Result<Vec<_>, _>>()?,
240 generation_config,
241 safety_settings: None,
242 tools,
243 tool_config,
244 system_instruction,
245 additional_params,
246 };
247
248 Ok(request)
249}
250
251impl TryFrom<completion::ToolDefinition> for Tool {
252 type Error = CompletionError;
253
254 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
255 let parameters: Option<Schema> =
256 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
257 None
258 } else {
259 Some(tool.parameters.try_into()?)
260 };
261
262 Ok(Self {
263 function_declarations: vec![FunctionDeclaration {
264 name: tool.name,
265 description: tool.description,
266 parameters,
267 }],
268 code_execution: None,
269 })
270 }
271}
272
273impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
274 type Error = CompletionError;
275
276 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
277 let mut function_declarations = Vec::new();
278
279 for tool in tools {
280 let parameters =
281 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
282 None
283 } else {
284 match tool.parameters.try_into() {
285 Ok(schema) => Some(schema),
286 Err(e) => {
287 let emsg = format!(
288 "Tool '{}' could not be converted to a schema: {:?}",
289 tool.name, e,
290 );
291 return Err(CompletionError::ProviderError(emsg));
292 }
293 }
294 };
295
296 function_declarations.push(FunctionDeclaration {
297 name: tool.name,
298 description: tool.description,
299 parameters,
300 });
301 }
302
303 Ok(Self {
304 function_declarations,
305 code_execution: None,
306 })
307 }
308}
309
310impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
311 type Error = CompletionError;
312
313 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
314 let candidate = response.candidates.first().ok_or_else(|| {
315 CompletionError::ResponseError("No response candidates in response".into())
316 })?;
317
318 let content = candidate
319 .content
320 .as_ref()
321 .ok_or_else(|| {
322 let reason = candidate
323 .finish_reason
324 .as_ref()
325 .map(|r| format!("finish_reason={r:?}"))
326 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
327 let message = candidate
328 .finish_message
329 .as_deref()
330 .unwrap_or("no finish message provided");
331 CompletionError::ResponseError(format!(
332 "Gemini candidate missing content ({reason}, finish_message={message})"
333 ))
334 })?
335 .parts
336 .iter()
337 .map(
338 |Part {
339 thought,
340 thought_signature,
341 part,
342 ..
343 }| {
344 Ok(match part {
345 PartKind::Text(text) => {
346 if let Some(thought) = thought
347 && *thought
348 {
349 completion::AssistantContent::Reasoning(Reasoning::new(text))
350 } else {
351 completion::AssistantContent::text(text)
352 }
353 }
354 PartKind::InlineData(inline_data) => {
355 let mime_type =
356 message::MediaType::from_mime_type(&inline_data.mime_type);
357
358 match mime_type {
359 Some(message::MediaType::Image(media_type)) => {
360 message::AssistantContent::image_base64(
361 &inline_data.data,
362 Some(media_type),
363 Some(message::ImageDetail::default()),
364 )
365 }
366 _ => {
367 return Err(CompletionError::ResponseError(format!(
368 "Unsupported media type {mime_type:?}"
369 )));
370 }
371 }
372 }
373 PartKind::FunctionCall(function_call) => {
374 completion::AssistantContent::ToolCall(
375 message::ToolCall::new(
376 function_call.name.clone(),
377 message::ToolFunction::new(
378 function_call.name.clone(),
379 function_call.args.clone(),
380 ),
381 )
382 .with_signature(thought_signature.clone()),
383 )
384 }
385 _ => {
386 return Err(CompletionError::ResponseError(
387 "Response did not contain a message or tool call".into(),
388 ));
389 }
390 })
391 },
392 )
393 .collect::<Result<Vec<_>, _>>()?;
394
395 let choice = OneOrMany::many(content).map_err(|_| {
396 CompletionError::ResponseError(
397 "Response contained no message or tool call (empty)".to_owned(),
398 )
399 })?;
400
401 let usage = response
402 .usage_metadata
403 .as_ref()
404 .map(|usage| completion::Usage {
405 input_tokens: usage.prompt_token_count as u64,
406 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
407 total_tokens: usage.total_token_count as u64,
408 cached_input_tokens: 0,
409 })
410 .unwrap_or_default();
411
412 Ok(completion::CompletionResponse {
413 choice,
414 usage,
415 raw_response: response,
416 })
417 }
418}
419
420pub mod gemini_api_types {
421 use crate::telemetry::ProviderResponseExt;
422 use std::{collections::HashMap, convert::Infallible, str::FromStr};
423
424 use serde::{Deserialize, Serialize};
428 use serde_json::{Value, json};
429
430 use crate::completion::GetTokenUsage;
431 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
432 use crate::{
433 completion::CompletionError,
434 message::{self},
435 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
436 };
437
/// Gemini-specific extras parsed from a request's `additional_params` JSON.
///
/// Keys are camelCase on the wire (e.g. `generationConfig`).
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct AdditionalParameters {
    /// Optional generation settings.
    pub generation_config: Option<GenerationConfig>,
    /// Remaining keys, flattened into this struct's own JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
447
448 impl AdditionalParameters {
449 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
450 self.generation_config = Some(cfg);
451 self
452 }
453
454 pub fn with_params(mut self, params: serde_json::Value) -> Self {
455 self.additional_params = Some(params);
456 self
457 }
458 }
459
/// Body of a `generateContent` response (camelCase keys on the wire).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Response identifier; required when deserializing.
    pub response_id: String,
    /// Candidate outputs; required when deserializing.
    pub candidates: Vec<ContentCandidate>,
    /// Optional feedback about the prompt (block reason, safety ratings).
    pub prompt_feedback: Option<PromptFeedback>,
    /// Optional token accounting for the request.
    pub usage_metadata: Option<UsageMetadata>,
    /// Optional model version string reported in the response.
    pub model_version: Option<String>,
}
479
480 impl ProviderResponseExt for GenerateContentResponse {
481 type OutputMessage = ContentCandidate;
482 type Usage = UsageMetadata;
483
484 fn get_response_id(&self) -> Option<String> {
485 Some(self.response_id.clone())
486 }
487
488 fn get_response_model_name(&self) -> Option<String> {
489 None
490 }
491
492 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
493 self.candidates.clone()
494 }
495
496 fn get_text_response(&self) -> Option<String> {
497 let str = self
498 .candidates
499 .iter()
500 .filter_map(|x| {
501 let content = x.content.as_ref()?;
502 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
503 return None;
504 }
505
506 let res = content
507 .parts
508 .iter()
509 .filter_map(|part| {
510 if let PartKind::Text(ref str) = part.part {
511 Some(str.to_owned())
512 } else {
513 None
514 }
515 })
516 .collect::<Vec<String>>()
517 .join("\n");
518
519 Some(res)
520 })
521 .collect::<Vec<String>>()
522 .join("\n");
523
524 if str.is_empty() { None } else { Some(str) }
525 }
526
527 fn get_usage(&self) -> Option<Self::Usage> {
528 self.usage_metadata.clone()
529 }
530 }
531
/// One candidate entry of a [`GenerateContentResponse`] (camelCase keys on the wire).
///
/// Every field is optional.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContentCandidate {
    /// Generated content; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,
    /// Why generation finished, when reported.
    pub finish_reason: Option<FinishReason>,
    /// Safety ratings attached to this candidate.
    pub safety_ratings: Option<Vec<SafetyRating>>,
    /// Citation information attached to this candidate.
    pub citation_metadata: Option<CitationMetadata>,
    /// Token count reported for this candidate.
    pub token_count: Option<i32>,
    /// Average log-probability reported for this candidate.
    pub avg_logprobs: Option<f64>,
    /// Detailed log-probability results.
    pub logprobs_result: Option<LogprobsResult>,
    /// Index of this candidate.
    pub index: Option<i32>,
    /// Free-form message accompanying the finish reason.
    pub finish_message: Option<String>,
}
560
/// A message exchanged with Gemini: an ordered list of parts plus an optional role.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Content {
    /// Message parts; defaults to an empty list when the key is missing.
    #[serde(default)]
    pub parts: Vec<Part>,
    /// Author of the message, when specified.
    pub role: Option<Role>,
}
570
571 impl TryFrom<message::Message> for Content {
572 type Error = message::MessageError;
573
574 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
575 Ok(match msg {
576 message::Message::User { content } => Content {
577 parts: content
578 .into_iter()
579 .map(|c| c.try_into())
580 .collect::<Result<Vec<_>, _>>()?,
581 role: Some(Role::User),
582 },
583 message::Message::Assistant { content, .. } => Content {
584 role: Some(Role::Model),
585 parts: content
586 .into_iter()
587 .map(|content| content.try_into())
588 .collect::<Result<Vec<_>, _>>()?,
589 },
590 })
591 }
592 }
593
/// Author of a [`Content`] message; serialized in lowercase (`user`, `model`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message written by the user.
    User,
    /// Message written by the model.
    Model,
}
600
/// A single piece of a [`Content`] message (camelCase keys on the wire).
///
/// The payload (`part`) and any extra parameters are flattened into the same
/// JSON object as the `thought` / `thoughtSignature` keys.
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    /// Whether this part is a model "thought"; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought: Option<bool>,
    /// Signature string associated with the thought; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
    /// The actual payload, flattened into this object.
    #[serde(flatten)]
    pub part: PartKind,
    /// Extra keys flattened into this object; omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<Value>,
}
615
/// Payload of a [`Part`]; the variant name in camelCase is the JSON key
/// (e.g. `text`, `inlineData`, `functionCall`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum PartKind {
    /// Plain text.
    Text(String),
    /// Inline data with a mime type.
    InlineData(Blob),
    /// A function call with a name and arguments.
    FunctionCall(FunctionCall),
    /// The result of a function call.
    FunctionResponse(FunctionResponse),
    /// Data referenced by URI.
    FileData(FileData),
    /// Executable code payload.
    ExecutableCode(ExecutableCode),
    /// Code execution result payload.
    CodeExecutionResult(CodeExecutionResult),
}
630
631 impl Default for PartKind {
634 fn default() -> Self {
635 Self::Text(String::new())
636 }
637 }
638
639 impl From<String> for Part {
640 fn from(text: String) -> Self {
641 Self {
642 thought: Some(false),
643 thought_signature: None,
644 part: PartKind::Text(text),
645 additional_params: None,
646 }
647 }
648 }
649
650 impl From<&str> for Part {
651 fn from(text: &str) -> Self {
652 Self::from(text.to_string())
653 }
654 }
655
656 impl FromStr for Part {
657 type Err = Infallible;
658
659 fn from_str(s: &str) -> Result<Self, Self::Err> {
660 Ok(s.into())
661 }
662 }
663
664 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
665 type Error = message::MessageError;
666 fn try_from(
667 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
668 ) -> Result<Self, Self::Error> {
669 let mime_type = mime_type.to_mime_type().to_string();
670 let part = match doc_src {
671 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
672 mime_type: Some(mime_type),
673 file_uri: url,
674 }),
675 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
676 PartKind::InlineData(Blob { mime_type, data })
677 }
678 DocumentSourceKind::Raw(_) => {
679 return Err(message::MessageError::ConversionError(
680 "Raw files not supported, encode as base64 first".into(),
681 ));
682 }
683 DocumentSourceKind::Unknown => {
684 return Err(message::MessageError::ConversionError(
685 "Can't convert an unknown document source".to_string(),
686 ));
687 }
688 };
689
690 Ok(part)
691 }
692 }
693
694 impl TryFrom<message::UserContent> for Part {
695 type Error = message::MessageError;
696
697 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
698 match content {
699 message::UserContent::Text(message::Text { text }) => Ok(Part {
700 thought: Some(false),
701 thought_signature: None,
702 part: PartKind::Text(text),
703 additional_params: None,
704 }),
705 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
706 let mut response_json: Option<serde_json::Value> = None;
707 let mut parts: Vec<FunctionResponsePart> = Vec::new();
708
709 for item in content.iter() {
710 match item {
711 message::ToolResultContent::Text(text) => {
712 let result: serde_json::Value =
713 serde_json::from_str(&text.text).unwrap_or_else(|error| {
714 tracing::trace!(
715 ?error,
716 "Tool result is not a valid JSON, treat it as normal string"
717 );
718 json!(&text.text)
719 });
720
721 response_json = Some(match response_json {
722 Some(mut existing) => {
723 if let serde_json::Value::Object(ref mut map) = existing {
724 map.insert("text".to_string(), result);
725 }
726 existing
727 }
728 None => json!({ "result": result }),
729 });
730 }
731 message::ToolResultContent::Image(image) => {
732 let part = match &image.data {
733 DocumentSourceKind::Base64(b64) => {
734 let mime_type = image
735 .media_type
736 .as_ref()
737 .ok_or(message::MessageError::ConversionError(
738 "Image media type is required for Gemini tool results".to_string(),
739 ))?
740 .to_mime_type();
741
742 FunctionResponsePart {
743 inline_data: Some(FunctionResponseInlineData {
744 mime_type: mime_type.to_string(),
745 data: b64.clone(),
746 display_name: None,
747 }),
748 file_data: None,
749 }
750 }
751 DocumentSourceKind::Url(url) => {
752 let mime_type = image
753 .media_type
754 .as_ref()
755 .map(|mt| mt.to_mime_type().to_string());
756
757 FunctionResponsePart {
758 inline_data: None,
759 file_data: Some(FileData {
760 mime_type,
761 file_uri: url.clone(),
762 }),
763 }
764 }
765 _ => {
766 return Err(message::MessageError::ConversionError(
767 "Unsupported image source kind for tool results"
768 .to_string(),
769 ));
770 }
771 };
772 parts.push(part);
773 }
774 }
775 }
776
777 Ok(Part {
778 thought: Some(false),
779 thought_signature: None,
780 part: PartKind::FunctionResponse(FunctionResponse {
781 name: id,
782 response: response_json,
783 parts: if parts.is_empty() { None } else { Some(parts) },
784 }),
785 additional_params: None,
786 })
787 }
788 message::UserContent::Image(message::Image {
789 data, media_type, ..
790 }) => match media_type {
791 Some(media_type) => match media_type {
792 message::ImageMediaType::JPEG
793 | message::ImageMediaType::PNG
794 | message::ImageMediaType::WEBP
795 | message::ImageMediaType::HEIC
796 | message::ImageMediaType::HEIF => {
797 let part = PartKind::try_from((media_type, data))?;
798 Ok(Part {
799 thought: Some(false),
800 thought_signature: None,
801 part,
802 additional_params: None,
803 })
804 }
805 _ => Err(message::MessageError::ConversionError(format!(
806 "Unsupported image media type {media_type:?}"
807 ))),
808 },
809 None => Err(message::MessageError::ConversionError(
810 "Media type for image is required for Gemini".to_string(),
811 )),
812 },
813 message::UserContent::Document(message::Document {
814 data, media_type, ..
815 }) => {
816 let Some(media_type) = media_type else {
817 return Err(MessageError::ConversionError(
818 "A mime type is required for document inputs to Gemini".to_string(),
819 ));
820 };
821
822 if !media_type.is_code() {
823 let mime_type = media_type.to_mime_type().to_string();
824
825 let part = match data {
826 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
827 mime_type: Some(mime_type),
828 file_uri,
829 }),
830 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
831 PartKind::InlineData(Blob { mime_type, data })
832 }
833 DocumentSourceKind::Raw(_) => {
834 return Err(message::MessageError::ConversionError(
835 "Raw files not supported, encode as base64 first".into(),
836 ));
837 }
838 _ => {
839 return Err(message::MessageError::ConversionError(
840 "Document has no body".to_string(),
841 ));
842 }
843 };
844
845 Ok(Part {
846 thought: Some(false),
847 part,
848 ..Default::default()
849 })
850 } else {
851 Err(message::MessageError::ConversionError(format!(
852 "Unsupported document media type {media_type:?}"
853 )))
854 }
855 }
856
857 message::UserContent::Audio(message::Audio {
858 data, media_type, ..
859 }) => {
860 let Some(media_type) = media_type else {
861 return Err(MessageError::ConversionError(
862 "A mime type is required for audio inputs to Gemini".to_string(),
863 ));
864 };
865
866 let mime_type = media_type.to_mime_type().to_string();
867
868 let part = match data {
869 DocumentSourceKind::Base64(data) => {
870 PartKind::InlineData(Blob { data, mime_type })
871 }
872
873 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
874 mime_type: Some(mime_type),
875 file_uri,
876 }),
877 DocumentSourceKind::String(_) => {
878 return Err(message::MessageError::ConversionError(
879 "Strings cannot be used as audio files!".into(),
880 ));
881 }
882 DocumentSourceKind::Raw(_) => {
883 return Err(message::MessageError::ConversionError(
884 "Raw files not supported, encode as base64 first".into(),
885 ));
886 }
887 DocumentSourceKind::Unknown => {
888 return Err(message::MessageError::ConversionError(
889 "Content has no body".to_string(),
890 ));
891 }
892 };
893
894 Ok(Part {
895 thought: Some(false),
896 part,
897 ..Default::default()
898 })
899 }
900 message::UserContent::Video(message::Video {
901 data,
902 media_type,
903 additional_params,
904 ..
905 }) => {
906 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
907
908 let part = match data {
909 DocumentSourceKind::Url(file_uri) => {
910 if file_uri.starts_with("https://www.youtube.com") {
911 PartKind::FileData(FileData {
912 mime_type,
913 file_uri,
914 })
915 } else {
916 if mime_type.is_none() {
917 return Err(MessageError::ConversionError(
918 "A mime type is required for non-Youtube video file inputs to Gemini"
919 .to_string(),
920 ));
921 }
922
923 PartKind::FileData(FileData {
924 mime_type,
925 file_uri,
926 })
927 }
928 }
929 DocumentSourceKind::Base64(data) => {
930 let Some(mime_type) = mime_type else {
931 return Err(MessageError::ConversionError(
932 "A media type is expected for base64 encoded strings"
933 .to_string(),
934 ));
935 };
936 PartKind::InlineData(Blob { mime_type, data })
937 }
938 DocumentSourceKind::String(_) => {
939 return Err(message::MessageError::ConversionError(
940 "Strings cannot be used as audio files!".into(),
941 ));
942 }
943 DocumentSourceKind::Raw(_) => {
944 return Err(message::MessageError::ConversionError(
945 "Raw file data not supported, encode as base64 first".into(),
946 ));
947 }
948 DocumentSourceKind::Unknown => {
949 return Err(message::MessageError::ConversionError(
950 "Media type for video is required for Gemini".to_string(),
951 ));
952 }
953 };
954
955 Ok(Part {
956 thought: Some(false),
957 thought_signature: None,
958 part,
959 additional_params,
960 })
961 }
962 }
963 }
964 }
965
966 impl TryFrom<message::AssistantContent> for Part {
967 type Error = message::MessageError;
968
969 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
970 match content {
971 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
972 message::AssistantContent::Image(message::Image {
973 data, media_type, ..
974 }) => match media_type {
975 Some(media_type) => match media_type {
976 message::ImageMediaType::JPEG
977 | message::ImageMediaType::PNG
978 | message::ImageMediaType::WEBP
979 | message::ImageMediaType::HEIC
980 | message::ImageMediaType::HEIF => {
981 let part = PartKind::try_from((media_type, data))?;
982 Ok(Part {
983 thought: Some(false),
984 thought_signature: None,
985 part,
986 additional_params: None,
987 })
988 }
989 _ => Err(message::MessageError::ConversionError(format!(
990 "Unsupported image media type {media_type:?}"
991 ))),
992 },
993 None => Err(message::MessageError::ConversionError(
994 "Media type for image is required for Gemini".to_string(),
995 )),
996 },
997 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
998 message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
999 Ok(Part {
1000 thought: Some(true),
1001 thought_signature: None,
1002 part: PartKind::Text(
1003 reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
1004 ),
1005 additional_params: None,
1006 })
1007 }
1008 }
1009 }
1010 }
1011
1012 impl From<message::ToolCall> for Part {
1013 fn from(tool_call: message::ToolCall) -> Self {
1014 Self {
1015 thought: Some(false),
1016 thought_signature: tool_call.signature,
1017 part: PartKind::FunctionCall(FunctionCall {
1018 name: tool_call.function.name,
1019 args: tool_call.function.arguments,
1020 }),
1021 additional_params: None,
1022 }
1023 }
1024 }
1025
/// Inline data payload with its mime type (camelCase keys on the wire).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
    /// Mime type of `data`.
    pub mime_type: String,
    /// The payload as a string; conversions in this file fill it from base64
    /// or plain-string sources.
    pub data: String,
}
1037
/// A function call: the function's name and its JSON arguments.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub args: serde_json::Value,
}
1048
1049 impl From<message::ToolCall> for FunctionCall {
1050 fn from(tool_call: message::ToolCall) -> Self {
1051 Self {
1052 name: tool_call.function.name,
1053 args: tool_call.function.arguments,
1054 }
1055 }
1056 }
1057
/// The result of a function call, sent back to the model.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionResponse {
    /// Name identifying the call; the tool-result conversion in this file
    /// fills it with the tool result's `id`.
    pub name: String,
    /// JSON response body; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<serde_json::Value>,
    /// Additional non-text parts (inline or file data); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parts: Option<Vec<FunctionResponsePart>>,
}
1073
/// One media part of a [`FunctionResponse`] (camelCase keys on the wire).
///
/// The conversions in this file set exactly one of the two fields.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    /// Inline payload; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inline_data: Option<FunctionResponseInlineData>,
    /// Payload referenced by URI; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<FileData>,
}
1085
/// Inline payload of a [`FunctionResponsePart`] (camelCase keys on the wire).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponseInlineData {
    /// Mime type of `data`.
    pub mime_type: String,
    /// The payload as a string; the tool-result conversion fills it with base64 data.
    pub data: String,
    /// Optional display name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
}
1098
/// Data referenced by URI rather than embedded (camelCase keys on the wire).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FileData {
    /// Mime type of the referenced file, when known.
    pub mime_type: Option<String>,
    /// URI of the file.
    pub file_uri: String,
}
1108
/// A harm category paired with its assessed probability.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct SafetyRating {
    /// Category being rated.
    pub category: HarmCategory,
    /// Probability level assigned to that category.
    pub probability: HarmProbability,
}
1114
/// Probability level of a [`SafetyRating`]; serialized in SCREAMING_SNAKE_CASE
/// (e.g. `HARM_PROBABILITY_UNSPECIFIED`, `NEGLIGIBLE`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    /// No probability specified.
    HarmProbabilityUnspecified,
    /// Negligible probability.
    Negligible,
    /// Low probability.
    Low,
    /// Medium probability.
    Medium,
    /// High probability.
    High,
}
1124
/// Category of a [`SafetyRating`]; serialized in SCREAMING_SNAKE_CASE
/// (e.g. `HARM_CATEGORY_HATE_SPEECH`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmCategory {
    /// No category specified.
    HarmCategoryUnspecified,
    /// Derogatory content.
    HarmCategoryDerogatory,
    /// Toxic content.
    HarmCategoryToxicity,
    /// Violent content.
    HarmCategoryViolence,
    /// Sexual content.
    HarmCategorySexually,
    /// Medical content.
    HarmCategoryMedical,
    /// Dangerous content.
    HarmCategoryDangerous,
    /// Harassment.
    HarmCategoryHarassment,
    /// Hate speech.
    HarmCategoryHateSpeech,
    /// Sexually explicit content.
    HarmCategorySexuallyExplicit,
    /// Dangerous content (distinct wire value from `HarmCategoryDangerous`).
    HarmCategoryDangerousContent,
    /// Civic-integrity related content.
    HarmCategoryCivicIntegrity,
}
1141
/// Token accounting attached to a response (camelCase keys on the wire).
#[derive(Debug, Deserialize, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    /// Token count of the prompt; required when deserializing.
    pub prompt_token_count: i32,
    /// Token count of cached content; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Token count of the generated candidates; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Total token count; required when deserializing.
    pub total_token_count: i32,
    /// Token count of model thoughts; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
}
1154
1155 impl std::fmt::Display for UsageMetadata {
1156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1157 write!(
1158 f,
1159 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1160 self.prompt_token_count,
1161 match self.cached_content_token_count {
1162 Some(count) => count.to_string(),
1163 None => "n/a".to_string(),
1164 },
1165 match self.candidates_token_count {
1166 Some(count) => count.to_string(),
1167 None => "n/a".to_string(),
1168 },
1169 self.total_token_count
1170 )
1171 }
1172 }
1173
1174 impl GetTokenUsage for UsageMetadata {
1175 fn token_usage(&self) -> Option<crate::completion::Usage> {
1176 let mut usage = crate::completion::Usage::new();
1177
1178 usage.input_tokens = self.prompt_token_count as u64;
1179 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1180 + self.candidates_token_count.unwrap_or_default()
1181 + self.thoughts_token_count.unwrap_or_default())
1182 as u64;
1183 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1184
1185 Some(usage)
1186 }
1187 }
1188
/// Feedback about the prompt returned with a response (camelCase keys on the wire).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    /// Why the prompt was blocked, when reported.
    pub block_reason: Option<BlockReason>,
    /// Safety ratings for the prompt, when reported.
    pub safety_ratings: Option<Vec<SafetyRating>>,
}
1198
/// Reason reported in [`PromptFeedback`]; serialized in SCREAMING_SNAKE_CASE
/// (e.g. `BLOCK_REASON_UNSPECIFIED`, `PROHIBITED_CONTENT`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockReason {
    /// No reason specified.
    BlockReasonUnspecified,
    /// Blocked for safety reasons.
    Safety,
    /// Blocked for another reason.
    Other,
    /// Blocked because of a blocklist.
    Blocklist,
    /// Blocked because of prohibited content.
    ProhibitedContent,
}
1214
/// Why a candidate stopped generating; serialized in SCREAMING_SNAKE_CASE
/// (e.g. `STOP`, `MAX_TOKENS`, `MALFORMED_FUNCTION_CALL`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    /// No reason specified.
    FinishReasonUnspecified,
    /// Generation stopped normally.
    Stop,
    /// The token limit was reached.
    MaxTokens,
    /// Stopped for safety reasons.
    Safety,
    /// Stopped because of recitation.
    Recitation,
    /// Stopped for a language-related reason.
    Language,
    /// Stopped for another reason.
    Other,
    /// Stopped because of a blocklist.
    Blocklist,
    /// Stopped because of prohibited content.
    ProhibitedContent,
    /// Wire value `SPII`.
    Spii,
    /// The model produced a malformed function call.
    MalformedFunctionCall,
}
1241
/// Citation information for a candidate (camelCase keys on the wire).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// The cited sources; required when deserializing.
    pub citation_sources: Vec<CitationSource>,
}
1247
/// One cited source (camelCase keys on the wire); every field is optional and
/// omitted from serialized output when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// URI of the source.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    /// Start index of the cited span.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<i32>,
    /// End index of the cited span.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<i32>,
    /// License of the source.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}
1260
1261 #[derive(Clone, Debug, Deserialize, Serialize)]
1262 #[serde(rename_all = "camelCase")]
1263 pub struct LogprobsResult {
1264 pub top_candidate: Vec<TopCandidate>,
1265 pub chosen_candidate: Vec<LogProbCandidate>,
1266 }
1267
    /// A group of [`LogProbCandidate`]s, used as the element type of
    /// [`LogprobsResult::top_candidate`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// The candidates in this group (key `candidates`).
        pub candidates: Vec<LogProbCandidate>,
    }
1272
    /// A single token together with its log probability.
    ///
    /// Keys are camelCase on the wire (`tokenId`, `logProbability`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        /// The token text.
        pub token: String,
        /// Identifier of the token.
        // NOTE(review): modelled as a string here; confirm the API does not send a JSON number for `tokenId`.
        pub token_id: String,
        /// Log probability of the token.
        pub log_probability: f64,
    }
1280
    /// Generation parameters sent with a request (see
    /// `GenerateContentRequest::generation_config`).
    ///
    /// All fields are optional and are omitted from the serialized JSON when
    /// `None`. Keys are camelCase unless a field overrides its name. The
    /// `Default` impl in this file sets only `temperature` and
    /// `max_output_tokens`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Stop sequences (`stopSequences`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type requested for the response (`responseMimeType`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Typed response schema (`responseSchema`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Raw JSON schema sent under the underscore-prefixed key
        /// `_responseJsonSchema` (explicit rename, not derived from camelCase).
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Raw JSON schema sent under `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of candidates to request (`candidateCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Output token limit (`maxOutputTokens`); defaults to `Some(4096)`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature; defaults to `Some(1.0)`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Nucleus sampling parameter (`topP`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling parameter (`topK`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty (`presencePenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty (`frequencyPenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether log probabilities are requested (`responseLogprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Log-probability count setting (`logprobs`).
        // NOTE(review): presumably the number of top logprobs per step — confirm against the API reference.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking configuration (`thinkingConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image output configuration (`imageConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1373
1374 impl Default for GenerationConfig {
1375 fn default() -> Self {
1376 Self {
1377 temperature: Some(1.0),
1378 max_output_tokens: Some(4096),
1379 stop_sequences: None,
1380 response_mime_type: None,
1381 response_schema: None,
1382 _response_json_schema: None,
1383 response_json_schema: None,
1384 candidate_count: None,
1385 top_p: None,
1386 top_k: None,
1387 presence_penalty: None,
1388 frequency_penalty: None,
1389 response_logprobs: None,
1390 logprobs: None,
1391 thinking_config: None,
1392 image_config: None,
1393 }
1394 }
1395 }
1396
    /// Thinking settings nested in [`GenerationConfig::thinking_config`].
    ///
    /// Keys are camelCase (`thinkingBudget`, `includeThoughts`).
    /// `include_thoughts` has no `skip_serializing_if`, so `None` is
    /// serialized as `null`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Thinking budget (required).
        // NOTE(review): presumably a token count — confirm against the API reference.
        pub thinking_budget: u32,
        /// Whether thoughts should be included, when specified.
        pub include_thoughts: Option<bool>,
    }
1403
    /// Image settings nested in [`GenerationConfig::image_config`].
    ///
    /// Both fields are optional, omitted when `None`, and use camelCase keys
    /// (`aspectRatio`, `imageSize`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio, as a free-form string.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size, as a free-form string.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1412
    /// Typed, recursive schema description; built from arbitrary JSON via the
    /// `TryFrom<Value>` impl in this file.
    ///
    /// There is no `rename_all` on this struct, so keys are the field names as
    /// written (`max_items`, `min_items`), with the raw identifiers `r#type` and
    /// `r#enum` appearing as `type` and `enum`. Every field except `type` is
    /// omitted from the serialized JSON when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name; `infer_type` yields an empty string when none can be
        /// determined.
        pub r#type: String,
        /// Optional format qualifier.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Optional human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Optional nullability flag.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Optional list of allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Optional upper bound on item count.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Optional lower bound on item count.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Optional nested schemas keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Optional list of required property names.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Optional schema of the elements; boxed because the type is recursive.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1438
1439 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1445 let defs = if let Some(obj) = schema.as_object() {
1447 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1448 } else {
1449 None
1450 };
1451
1452 let Some(defs_value) = defs else {
1453 return Ok(schema);
1454 };
1455
1456 let Some(defs_obj) = defs_value.as_object() else {
1457 return Err(CompletionError::ResponseError(
1458 "$defs must be an object".into(),
1459 ));
1460 };
1461
1462 resolve_refs(&mut schema, defs_obj)?;
1463
1464 if let Some(obj) = schema.as_object_mut() {
1466 obj.remove("$defs");
1467 obj.remove("definitions");
1468 }
1469
1470 Ok(schema)
1471 }
1472
1473 fn resolve_refs(
1476 value: &mut Value,
1477 defs: &serde_json::Map<String, Value>,
1478 ) -> Result<(), CompletionError> {
1479 match value {
1480 Value::Object(obj) => {
1481 if let Some(ref_value) = obj.get("$ref")
1482 && let Some(ref_str) = ref_value.as_str()
1483 {
1484 let def_name = parse_ref_path(ref_str)?;
1486
1487 let def = defs.get(&def_name).ok_or_else(|| {
1488 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1489 })?;
1490
1491 let mut resolved = def.clone();
1492 resolve_refs(&mut resolved, defs)?;
1493 *value = resolved;
1494 return Ok(());
1495 }
1496
1497 for (_, v) in obj.iter_mut() {
1498 resolve_refs(v, defs)?;
1499 }
1500 }
1501 Value::Array(arr) => {
1502 for item in arr.iter_mut() {
1503 resolve_refs(item, defs)?;
1504 }
1505 }
1506 _ => {}
1507 }
1508
1509 Ok(())
1510 }
1511
1512 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1518 if let Some(fragment) = ref_str.strip_prefix('#') {
1519 if let Some(name) = fragment.strip_prefix("/$defs/") {
1520 Ok(name.to_string())
1521 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1522 Ok(name.to_string())
1523 } else {
1524 Err(CompletionError::ResponseError(format!(
1525 "Unsupported reference format: {}",
1526 ref_str
1527 )))
1528 }
1529 } else {
1530 Err(CompletionError::ResponseError(format!(
1531 "Only fragment references (#/...) are supported: {}",
1532 ref_str
1533 )))
1534 }
1535 }
1536
1537 fn extract_type(type_value: &Value) -> Option<String> {
1540 if type_value.is_string() {
1541 type_value.as_str().map(String::from)
1542 } else if type_value.is_array() {
1543 type_value
1544 .as_array()
1545 .and_then(|arr| arr.first())
1546 .and_then(|v| v.as_str().map(String::from))
1547 } else {
1548 None
1549 }
1550 }
1551
1552 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1555 composition.as_array().and_then(|arr| {
1556 arr.iter().find_map(|schema| {
1557 if let Some(obj) = schema.as_object() {
1558 if let Some(type_val) = obj.get("type")
1560 && let Some(type_str) = type_val.as_str()
1561 && type_str == "null"
1562 {
1563 return None;
1564 }
1565 obj.get("type").and_then(extract_type).or_else(|| {
1567 if obj.contains_key("properties") {
1568 Some("object".to_string())
1569 } else {
1570 None
1571 }
1572 })
1573 } else {
1574 None
1575 }
1576 })
1577 })
1578 }
1579
1580 fn extract_schema_from_composition(
1583 composition: &Value,
1584 ) -> Option<serde_json::Map<String, Value>> {
1585 composition.as_array().and_then(|arr| {
1586 arr.iter().find_map(|schema| {
1587 if let Some(obj) = schema.as_object()
1588 && let Some(type_val) = obj.get("type")
1589 && let Some(type_str) = type_val.as_str()
1590 {
1591 if type_str == "null" {
1592 return None;
1593 }
1594 Some(obj.clone())
1595 } else {
1596 None
1597 }
1598 })
1599 })
1600 }
1601
1602 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1605 if let Some(type_val) = obj.get("type")
1607 && let Some(type_str) = extract_type(type_val)
1608 {
1609 return type_str;
1610 }
1611
1612 if let Some(any_of) = obj.get("anyOf")
1614 && let Some(type_str) = extract_type_from_composition(any_of)
1615 {
1616 return type_str;
1617 }
1618
1619 if let Some(one_of) = obj.get("oneOf")
1620 && let Some(type_str) = extract_type_from_composition(one_of)
1621 {
1622 return type_str;
1623 }
1624
1625 if let Some(all_of) = obj.get("allOf")
1626 && let Some(type_str) = extract_type_from_composition(all_of)
1627 {
1628 return type_str;
1629 }
1630
1631 if obj.contains_key("properties") {
1633 "object".to_string()
1634 } else {
1635 String::new()
1636 }
1637 }
1638
1639 impl TryFrom<Value> for Schema {
1640 type Error = CompletionError;
1641
1642 fn try_from(value: Value) -> Result<Self, Self::Error> {
1643 let flattened_val = flatten_schema(value)?;
1644 if let Some(obj) = flattened_val.as_object() {
1645 let props_source = if obj.get("properties").is_none() {
1648 if let Some(any_of) = obj.get("anyOf") {
1649 extract_schema_from_composition(any_of)
1650 } else if let Some(one_of) = obj.get("oneOf") {
1651 extract_schema_from_composition(one_of)
1652 } else if let Some(all_of) = obj.get("allOf") {
1653 extract_schema_from_composition(all_of)
1654 } else {
1655 None
1656 }
1657 .unwrap_or(obj.clone())
1658 } else {
1659 obj.clone()
1660 };
1661
1662 Ok(Schema {
1663 r#type: infer_type(obj),
1664 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1665 description: obj
1666 .get("description")
1667 .and_then(|v| v.as_str())
1668 .map(String::from),
1669 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1670 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1671 arr.iter()
1672 .filter_map(|v| v.as_str().map(String::from))
1673 .collect()
1674 }),
1675 max_items: obj
1676 .get("maxItems")
1677 .and_then(|v| v.as_i64())
1678 .map(|v| v as i32),
1679 min_items: obj
1680 .get("minItems")
1681 .and_then(|v| v.as_i64())
1682 .map(|v| v as i32),
1683 properties: props_source
1684 .get("properties")
1685 .and_then(|v| v.as_object())
1686 .map(|map| {
1687 map.iter()
1688 .filter_map(|(k, v)| {
1689 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1690 })
1691 .collect()
1692 }),
1693 required: props_source
1694 .get("required")
1695 .and_then(|v| v.as_array())
1696 .map(|arr| {
1697 arr.iter()
1698 .filter_map(|v| v.as_str().map(String::from))
1699 .collect()
1700 }),
1701 items: obj
1702 .get("items")
1703 .and_then(|v| v.clone().try_into().ok())
1704 .map(Box::new),
1705 })
1706 } else {
1707 Err(CompletionError::ResponseError(
1708 "Expected a JSON object for Schema".into(),
1709 ))
1710 }
1711 }
1712 }
1713
    /// Serializable request body for content generation.
    ///
    /// Keys are camelCase. `tools` is omitted when `None`; `tool_config`,
    /// `generation_config`, `safety_settings` and `system_instruction` have no
    /// `skip_serializing_if` and are therefore written as `null` when `None`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// The content entries of the request.
        pub contents: Vec<Content>,
        /// Tools offered with the request, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Vec<Tool>>,
        /// Tool configuration (`toolConfig`).
        pub tool_config: Option<ToolConfig>,
        /// Generation parameters (`generationConfig`).
        pub generation_config: Option<GenerationConfig>,
        /// Safety settings (`safetySettings`).
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System instruction content (`systemInstruction`).
        pub system_instruction: Option<Content>,
        /// Extra parameters flattened into the top level of the request JSON
        /// when present.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1745
    /// A tool entry of a request: a list of function declarations plus an
    /// optional code-execution marker.
    ///
    /// Keys are camelCase (`functionDeclarations`, `codeExecution`).
    /// `code_execution` has no `skip_serializing_if`, so `None` is written as
    /// `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions made available through this tool entry.
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Code-execution marker, if set.
        pub code_execution: Option<CodeExecution>,
    }
1752
    /// Declaration of a single callable function: name, description and an
    /// optional parameter schema.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name.
        pub name: String,
        /// Human-readable description of the function.
        pub description: String,
        /// Parameter schema; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1761
    /// Tool configuration wrapper around an optional [`FunctionCallingMode`],
    /// (de)serialized under the key `functionCallingConfig`.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Function-calling mode; written as `null` when `None`.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1767
    /// Function-calling mode.
    ///
    /// Internally tagged: the variant is written to a `mode` key in upper case
    /// (`"AUTO"`, `"NONE"`, `"ANY"`). The enum-level `rename_all` affects
    /// variant names only, so the field of `Any` keeps its snake_case key
    /// `allowed_function_names`. `Auto` is the default.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// Serialized as `{"mode": "AUTO"}`.
        #[default]
        Auto,
        /// Serialized as `{"mode": "NONE"}`.
        None,
        /// Serialized as `{"mode": "ANY"}` plus the optional name list.
        Any {
            /// Function names to restrict to; omitted when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1779
1780 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1781 type Error = CompletionError;
1782 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1783 let res = match value {
1784 message::ToolChoice::Auto => Self::Auto,
1785 message::ToolChoice::None => Self::None,
1786 message::ToolChoice::Required => Self::Any {
1787 allowed_function_names: None,
1788 },
1789 message::ToolChoice::Specific { function_names } => Self::Any {
1790 allowed_function_names: Some(function_names),
1791 },
1792 };
1793
1794 Ok(res)
1795 }
1796 }
1797
    /// Field-less marker used by [`Tool::code_execution`]; serializes as an
    /// empty JSON object (`{}`).
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
1800
    /// Pairs a harm category with the blocking threshold to apply to it.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The threshold chosen for that category.
        pub threshold: HarmBlockThreshold,
    }
1807
    /// Blocking threshold used by [`SafetySetting::threshold`].
    ///
    /// Variants serialize in SCREAMING_SNAKE_CASE, e.g.
    /// `BlockMediumAndAbove` -> `"BLOCK_MEDIUM_AND_ABOVE"`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// Wire value `HARM_BLOCK_THRESHOLD_UNSPECIFIED`.
        HarmBlockThresholdUnspecified,
        /// Wire value `BLOCK_LOW_AND_ABOVE`.
        BlockLowAndAbove,
        /// Wire value `BLOCK_MEDIUM_AND_ABOVE`.
        BlockMediumAndAbove,
        /// Wire value `BLOCK_ONLY_HIGH`.
        BlockOnlyHigh,
        /// Wire value `BLOCK_NONE`.
        BlockNone,
        /// Wire value `OFF`.
        Off,
    }
1818}
1819
1820#[cfg(test)]
1821mod tests {
1822 use crate::{message, providers::gemini::completion::gemini_api_types::flatten_schema};
1823
1824 use super::*;
1825 use serde_json::json;
1826
1827 #[test]
1828 fn test_deserialize_message_user() {
1829 let raw_message = r#"{
1830 "parts": [
1831 {"text": "Hello, world!"},
1832 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
1833 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
1834 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
1835 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
1836 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
1837 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
1838 ],
1839 "role": "user"
1840 }"#;
1841
1842 let content: Content = {
1843 let jd = &mut serde_json::Deserializer::from_str(raw_message);
1844 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
1845 panic!("Deserialization error at {}: {}", err.path(), err);
1846 })
1847 };
1848 assert_eq!(content.role, Some(Role::User));
1849 assert_eq!(content.parts.len(), 7);
1850
1851 let parts: Vec<Part> = content.parts.into_iter().collect();
1852
1853 if let Part {
1854 part: PartKind::Text(text),
1855 ..
1856 } = &parts[0]
1857 {
1858 assert_eq!(text, "Hello, world!");
1859 } else {
1860 panic!("Expected text part");
1861 }
1862
1863 if let Part {
1864 part: PartKind::InlineData(inline_data),
1865 ..
1866 } = &parts[1]
1867 {
1868 assert_eq!(inline_data.mime_type, "image/png");
1869 assert_eq!(inline_data.data, "base64encodeddata");
1870 } else {
1871 panic!("Expected inline data part");
1872 }
1873
1874 if let Part {
1875 part: PartKind::FunctionCall(function_call),
1876 ..
1877 } = &parts[2]
1878 {
1879 assert_eq!(function_call.name, "test_function");
1880 assert_eq!(
1881 function_call.args.as_object().unwrap().get("arg1").unwrap(),
1882 "value1"
1883 );
1884 } else {
1885 panic!("Expected function call part");
1886 }
1887
1888 if let Part {
1889 part: PartKind::FunctionResponse(function_response),
1890 ..
1891 } = &parts[3]
1892 {
1893 assert_eq!(function_response.name, "test_function");
1894 assert_eq!(
1895 function_response
1896 .response
1897 .as_ref()
1898 .unwrap()
1899 .get("result")
1900 .unwrap(),
1901 "success"
1902 );
1903 } else {
1904 panic!("Expected function response part");
1905 }
1906
1907 if let Part {
1908 part: PartKind::FileData(file_data),
1909 ..
1910 } = &parts[4]
1911 {
1912 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
1913 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
1914 } else {
1915 panic!("Expected file data part");
1916 }
1917
1918 if let Part {
1919 part: PartKind::ExecutableCode(executable_code),
1920 ..
1921 } = &parts[5]
1922 {
1923 assert_eq!(executable_code.code, "print('Hello, world!')");
1924 } else {
1925 panic!("Expected executable code part");
1926 }
1927
1928 if let Part {
1929 part: PartKind::CodeExecutionResult(code_execution_result),
1930 ..
1931 } = &parts[6]
1932 {
1933 assert_eq!(
1934 code_execution_result.clone().output.unwrap(),
1935 "Hello, world!"
1936 );
1937 } else {
1938 panic!("Expected code execution result part");
1939 }
1940 }
1941
1942 #[test]
1943 fn test_deserialize_message_model() {
1944 let json_data = json!({
1945 "parts": [{"text": "Hello, user!"}],
1946 "role": "model"
1947 });
1948
1949 let content: Content = serde_json::from_value(json_data).unwrap();
1950 assert_eq!(content.role, Some(Role::Model));
1951 assert_eq!(content.parts.len(), 1);
1952 if let Some(Part {
1953 part: PartKind::Text(text),
1954 ..
1955 }) = content.parts.first()
1956 {
1957 assert_eq!(text, "Hello, user!");
1958 } else {
1959 panic!("Expected text part");
1960 }
1961 }
1962
1963 #[test]
1964 fn test_message_conversion_user() {
1965 let msg = message::Message::user("Hello, world!");
1966 let content: Content = msg.try_into().unwrap();
1967 assert_eq!(content.role, Some(Role::User));
1968 assert_eq!(content.parts.len(), 1);
1969 if let Some(Part {
1970 part: PartKind::Text(text),
1971 ..
1972 }) = &content.parts.first()
1973 {
1974 assert_eq!(text, "Hello, world!");
1975 } else {
1976 panic!("Expected text part");
1977 }
1978 }
1979
1980 #[test]
1981 fn test_message_conversion_model() {
1982 let msg = message::Message::assistant("Hello, user!");
1983
1984 let content: Content = msg.try_into().unwrap();
1985 assert_eq!(content.role, Some(Role::Model));
1986 assert_eq!(content.parts.len(), 1);
1987 if let Some(Part {
1988 part: PartKind::Text(text),
1989 ..
1990 }) = &content.parts.first()
1991 {
1992 assert_eq!(text, "Hello, user!");
1993 } else {
1994 panic!("Expected text part");
1995 }
1996 }
1997
1998 #[test]
1999 fn test_message_conversion_tool_call() {
2000 let tool_call = message::ToolCall {
2001 id: "test_tool".to_string(),
2002 call_id: None,
2003 function: message::ToolFunction {
2004 name: "test_function".to_string(),
2005 arguments: json!({"arg1": "value1"}),
2006 },
2007 signature: None,
2008 additional_params: None,
2009 };
2010
2011 let msg = message::Message::Assistant {
2012 id: None,
2013 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2014 };
2015
2016 let content: Content = msg.try_into().unwrap();
2017 assert_eq!(content.role, Some(Role::Model));
2018 assert_eq!(content.parts.len(), 1);
2019 if let Some(Part {
2020 part: PartKind::FunctionCall(function_call),
2021 ..
2022 }) = content.parts.first()
2023 {
2024 assert_eq!(function_call.name, "test_function");
2025 assert_eq!(
2026 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2027 "value1"
2028 );
2029 } else {
2030 panic!("Expected function call part");
2031 }
2032 }
2033
2034 #[test]
2035 fn test_vec_schema_conversion() {
2036 let schema_with_ref = json!({
2037 "type": "array",
2038 "items": {
2039 "$ref": "#/$defs/Person"
2040 },
2041 "$defs": {
2042 "Person": {
2043 "type": "object",
2044 "properties": {
2045 "first_name": {
2046 "type": ["string", "null"],
2047 "description": "The person's first name, if provided (null otherwise)"
2048 },
2049 "last_name": {
2050 "type": ["string", "null"],
2051 "description": "The person's last name, if provided (null otherwise)"
2052 },
2053 "job": {
2054 "type": ["string", "null"],
2055 "description": "The person's job, if provided (null otherwise)"
2056 }
2057 },
2058 "required": []
2059 }
2060 }
2061 });
2062
2063 let result: Result<Schema, _> = schema_with_ref.try_into();
2064
2065 match result {
2066 Ok(schema) => {
2067 assert_eq!(schema.r#type, "array");
2068
2069 if let Some(items) = schema.items {
2070 println!("item types: {}", items.r#type);
2071
2072 assert_ne!(items.r#type, "", "Items type should not be empty string!");
2073 assert_eq!(items.r#type, "object", "Items should be object type");
2074 } else {
2075 panic!("Schema should have items field for array type");
2076 }
2077 }
2078 Err(e) => println!("Schema conversion failed: {:?}", e),
2079 }
2080 }
2081
2082 #[test]
2083 fn test_object_schema() {
2084 let simple_schema = json!({
2085 "type": "object",
2086 "properties": {
2087 "name": {
2088 "type": "string"
2089 }
2090 }
2091 });
2092
2093 let schema: Schema = simple_schema.try_into().unwrap();
2094 assert_eq!(schema.r#type, "object");
2095 assert!(schema.properties.is_some());
2096 }
2097
2098 #[test]
2099 fn test_array_with_inline_items() {
2100 let inline_schema = json!({
2101 "type": "array",
2102 "items": {
2103 "type": "object",
2104 "properties": {
2105 "name": {
2106 "type": "string"
2107 }
2108 }
2109 }
2110 });
2111
2112 let schema: Schema = inline_schema.try_into().unwrap();
2113 assert_eq!(schema.r#type, "array");
2114
2115 if let Some(items) = schema.items {
2116 assert_eq!(items.r#type, "object");
2117 assert!(items.properties.is_some());
2118 } else {
2119 panic!("Schema should have items field");
2120 }
2121 }
2122 #[test]
2123 fn test_flattened_schema() {
2124 let ref_schema = json!({
2125 "type": "array",
2126 "items": {
2127 "$ref": "#/$defs/Person"
2128 },
2129 "$defs": {
2130 "Person": {
2131 "type": "object",
2132 "properties": {
2133 "name": { "type": "string" }
2134 }
2135 }
2136 }
2137 });
2138
2139 let flattened = flatten_schema(ref_schema).unwrap();
2140 let schema: Schema = flattened.try_into().unwrap();
2141
2142 assert_eq!(schema.r#type, "array");
2143
2144 if let Some(items) = schema.items {
2145 println!("Flattened items type: '{}'", items.r#type);
2146
2147 assert_eq!(items.r#type, "object");
2148 assert!(items.properties.is_some());
2149 }
2150 }
2151
2152 #[test]
2153 fn test_tool_result_with_image_content() {
2154 use crate::OneOrMany;
2156 use crate::message::{
2157 DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2158 };
2159
2160 let tool_result = ToolResult {
2162 id: "test_tool".to_string(),
2163 call_id: None,
2164 content: OneOrMany::many(vec![
2165 ToolResultContent::Text(message::Text {
2166 text: r#"{"status": "success"}"#.to_string(),
2167 }),
2168 ToolResultContent::Image(Image {
2169 data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
2170 media_type: Some(ImageMediaType::PNG),
2171 detail: None,
2172 additional_params: None,
2173 }),
2174 ]).expect("Should create OneOrMany with multiple items"),
2175 };
2176
2177 let user_content = message::UserContent::ToolResult(tool_result);
2178 let msg = message::Message::User {
2179 content: OneOrMany::one(user_content),
2180 };
2181
2182 let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2184
2185 assert_eq!(content.role, Some(Role::User));
2186 assert_eq!(content.parts.len(), 1);
2187
2188 if let Some(Part {
2190 part: PartKind::FunctionResponse(function_response),
2191 ..
2192 }) = content.parts.first()
2193 {
2194 assert_eq!(function_response.name, "test_tool");
2195
2196 assert!(function_response.response.is_some());
2198 let response = function_response.response.as_ref().unwrap();
2199 assert!(response.get("result").is_some());
2200
2201 assert!(function_response.parts.is_some());
2203 let parts = function_response.parts.as_ref().unwrap();
2204 assert_eq!(parts.len(), 1);
2205
2206 let image_part = &parts[0];
2207 assert!(image_part.inline_data.is_some());
2208 let inline_data = image_part.inline_data.as_ref().unwrap();
2209 assert_eq!(inline_data.mime_type, "image/png");
2210 assert!(!inline_data.data.is_empty());
2211 } else {
2212 panic!("Expected FunctionResponse part");
2213 }
2214 }
2215
2216 #[test]
2217 fn test_tool_result_with_url_image() {
2218 use crate::OneOrMany;
2220 use crate::message::{
2221 DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2222 };
2223
2224 let tool_result = ToolResult {
2225 id: "screenshot_tool".to_string(),
2226 call_id: None,
2227 content: OneOrMany::one(ToolResultContent::Image(Image {
2228 data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
2229 media_type: Some(ImageMediaType::PNG),
2230 detail: None,
2231 additional_params: None,
2232 })),
2233 };
2234
2235 let user_content = message::UserContent::ToolResult(tool_result);
2236 let msg = message::Message::User {
2237 content: OneOrMany::one(user_content),
2238 };
2239
2240 let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2241
2242 assert_eq!(content.role, Some(Role::User));
2243 assert_eq!(content.parts.len(), 1);
2244
2245 if let Some(Part {
2246 part: PartKind::FunctionResponse(function_response),
2247 ..
2248 }) = content.parts.first()
2249 {
2250 assert_eq!(function_response.name, "screenshot_tool");
2251
2252 assert!(function_response.parts.is_some());
2254 let parts = function_response.parts.as_ref().unwrap();
2255 assert_eq!(parts.len(), 1);
2256
2257 let image_part = &parts[0];
2258 assert!(image_part.file_data.is_some());
2259 let file_data = image_part.file_data.as_ref().unwrap();
2260 assert_eq!(file_data.file_uri, "https://example.com/image.png");
2261 assert_eq!(file_data.mime_type.as_ref().unwrap(), "image/png");
2262 } else {
2263 panic!("Expected FunctionResponse part");
2264 }
2265 }
2266
2267 #[test]
2268 fn test_from_tool_output_parses_image_json() {
2269 use crate::message::{DocumentSourceKind, ToolResultContent};
2271
2272 let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
2274 let result = ToolResultContent::from_tool_output(image_json);
2275
2276 assert_eq!(result.len(), 1);
2277 if let ToolResultContent::Image(img) = result.first() {
2278 assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2279 if let DocumentSourceKind::Base64(data) = &img.data {
2280 assert_eq!(data, "base64data==");
2281 }
2282 assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
2283 } else {
2284 panic!("Expected Image content");
2285 }
2286 }
2287
2288 #[test]
2289 fn test_from_tool_output_parses_hybrid_json() {
2290 use crate::message::{DocumentSourceKind, ToolResultContent};
2292
2293 let hybrid_json = r#"{
2294 "response": {"status": "ok", "count": 42},
2295 "parts": [
2296 {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
2297 {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
2298 ]
2299 }"#;
2300
2301 let result = ToolResultContent::from_tool_output(hybrid_json);
2302
2303 assert_eq!(result.len(), 3);
2305
2306 let items: Vec<_> = result.iter().collect();
2307
2308 if let ToolResultContent::Text(text) = &items[0] {
2310 assert!(text.text.contains("status"));
2311 assert!(text.text.contains("ok"));
2312 } else {
2313 panic!("Expected Text content first");
2314 }
2315
2316 if let ToolResultContent::Image(img) = &items[1] {
2318 assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2319 } else {
2320 panic!("Expected Image content second");
2321 }
2322
2323 if let ToolResultContent::Image(img) = &items[2] {
2325 assert!(matches!(img.data, DocumentSourceKind::Url(_)));
2326 } else {
2327 panic!("Expected Image content third");
2328 }
2329 }
2330
    /// End-to-end check that an agent can call a tool whose output is an
    /// image JSON payload and still produce a non-empty answer.
    ///
    /// Ignored by default: the client is built from the environment and, per
    /// the ignore reason, needs `GEMINI_API_KEY`.
    #[tokio::test]
    #[ignore = "requires GEMINI_API_KEY environment variable"]
    async fn test_gemini_agent_with_image_tool_result_e2e() {
        use crate::completion::{Prompt, ToolDefinition};
        use crate::prelude::*;
        use crate::providers::gemini;
        use crate::tool::Tool;
        use serde::{Deserialize, Serialize};

        // Stateless tool that always returns the same image payload.
        #[derive(Debug, Serialize, Deserialize)]
        struct ImageGeneratorTool;

        // Error type required by the `Tool` impl; `call` below never returns it.
        #[derive(Debug, thiserror::Error)]
        #[error("Image generation error")]
        struct ImageToolError;

        impl Tool for ImageGeneratorTool {
            const NAME: &'static str = "generate_test_image";
            type Error = ImageToolError;
            type Args = serde_json::Value;
            type Output = String;

            // Advertises a tool that takes no parameters.
            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "generate_test_image".to_string(),
                    description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {},
                        "required": []
                    }),
                }
            }

            // Returns the image as a JSON string with `type`, `data` and
            // `mimeType` keys, ignoring the arguments.
            async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
                Ok(json!({
                    "type": "image",
                    "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
                    "mimeType": "image/png"
                }).to_string())
            }
        }

        let client = gemini::Client::from_env();

        let agent = client
            .agent("gemini-3-flash-preview")
            .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
            .tool(ImageGeneratorTool)
            .build();

        let response = agent
            .prompt("Please generate a test image and tell me what color the pixel is.")
            .await;

        // A request error here means the prompt round trip failed.
        assert!(
            response.is_ok(),
            "Gemini should successfully process tool result with image: {:?}",
            response.err()
        );

        let response_text = response.unwrap();
        println!("Response: {response_text}");
        assert!(!response_text.is_empty(), "Response should not be empty");
    }
2407}