// Gemini model identifiers. Any of these can be passed as the `model` of a `CompletionModel`
// (or as a per-request model override); they are interpolated verbatim into the
// `/v1beta/models/{model}:...` request path.

/// Model id `gemini-3.1-flash-lite-preview`.
pub const GEMINI_3_1_FLASH_LITE_PREVIEW: &str = "gemini-3.1-flash-lite-preview";
/// Model id `gemini-3-flash-preview`.
pub const GEMINI_3_FLASH_PREVIEW: &str = "gemini-3-flash-preview";
/// Model id `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model id `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model id `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model id `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model id `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model id `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model id `gemini-2.5-flash-image`; only compiled with the `image` feature.
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
pub const GEMINI_2_5_FLASH_IMAGE: &str = "gemini-2.5-flash-image";
/// Model id `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model id `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
29
30use self::gemini_api_types::tool_parameters_to_schema;
31use crate::http_client::HttpClientExt;
32use crate::message::{self, MimeType, Reasoning};
33
34use crate::providers::gemini::completion::gemini_api_types::{
35 AdditionalParameters, FunctionCallingMode, ToolConfig,
36};
37use crate::providers::gemini::streaming::StreamingCompletionResponse;
38use crate::telemetry::SpanCombinator;
39use crate::{
40 OneOrMany,
41 completion::{self, CompletionError, CompletionRequest, GetTokenUsage},
42};
43use gemini_api_types::{
44 Content, FinishReason, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
45 GenerationConfig, Part, PartKind, Role, Tool,
46};
47use serde_json::{Map, Value};
48use std::convert::TryFrom;
49use tracing::{Level, enabled, info_span};
50use tracing_futures::Instrument;
51
52use super::Client;
53
/// Gemini completion model: an API client paired with the default model name for requests.
///
/// `T` is the HTTP client type used by [`Client`]; it defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send the HTTP requests.
    pub(crate) client: Client<T>,
    /// Default model name; a `model` set on an individual `CompletionRequest` takes precedence.
    pub model: String,
}
63
64impl<T> CompletionModel<T> {
65 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
66 Self {
67 client,
68 model: model.into(),
69 }
70 }
71
72 pub fn with_model(client: Client<T>, model: &str) -> Self {
73 Self {
74 client,
75 model: model.into(),
76 }
77 }
78}
79
80impl<T> completion::CompletionModel for CompletionModel<T>
81where
82 T: HttpClientExt + Clone + 'static,
83{
84 type Response = GenerateContentResponse;
85 type StreamingResponse = StreamingCompletionResponse;
86 type Client = super::Client<T>;
87
88 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
89 Self::new(client.clone(), model)
90 }
91
92 async fn completion(
93 &self,
94 completion_request: CompletionRequest,
95 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
96 let request_model = resolve_request_model(&self.model, &completion_request);
97 let span = if tracing::Span::current().is_disabled() {
98 info_span!(
99 target: "rig::completions",
100 "generate_content",
101 gen_ai.operation.name = "generate_content",
102 gen_ai.provider.name = "gcp.gemini",
103 gen_ai.request.model = &request_model,
104 gen_ai.system_instructions = &completion_request.preamble,
105 gen_ai.response.id = tracing::field::Empty,
106 gen_ai.response.model = tracing::field::Empty,
107 gen_ai.usage.output_tokens = tracing::field::Empty,
108 gen_ai.usage.input_tokens = tracing::field::Empty,
109 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
110 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
111 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
112 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
113 )
114 } else {
115 tracing::Span::current()
116 };
117
118 let request = create_request_body(completion_request)?;
119
120 if enabled!(Level::TRACE) {
121 tracing::trace!(
122 target: "rig::completions",
123 "Gemini completion request: {}",
124 serde_json::to_string_pretty(&request)?
125 );
126 }
127
128 let body = serde_json::to_vec(&request)?;
129
130 let path = completion_endpoint(&request_model);
131
132 let request = self
133 .client
134 .post(path.as_str())?
135 .body(body)
136 .map_err(|e| CompletionError::HttpError(e.into()))?;
137
138 async move {
139 let response = self.client.send::<_, Vec<u8>>(request).await?;
140
141 if response.status().is_success() {
142 let response_body = response
143 .into_body()
144 .await
145 .map_err(CompletionError::HttpError)?;
146
147 let response_text = String::from_utf8_lossy(&response_body).to_string();
148
149 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
150 .map_err(|err| {
151 tracing::error!(
152 error = %err,
153 body = %response_text,
154 "Failed to deserialize Gemini completion response"
155 );
156 CompletionError::JsonError(err)
157 })?;
158
159 let span = tracing::Span::current();
160 span.record_response_metadata(&response);
161 span.record_token_usage(&response.usage_metadata);
162
163 if enabled!(Level::TRACE) {
164 tracing::trace!(
165 target: "rig::completions",
166 "Gemini completion response: {}",
167 serde_json::to_string_pretty(&response)?
168 );
169 }
170
171 response.try_into()
172 } else {
173 let text = String::from_utf8_lossy(
174 &response
175 .into_body()
176 .await
177 .map_err(CompletionError::HttpError)?,
178 )
179 .into();
180
181 Err(CompletionError::ProviderError(text))
182 }
183 }
184 .instrument(span)
185 .await
186 }
187
188 async fn stream(
189 &self,
190 request: CompletionRequest,
191 ) -> Result<
192 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
193 CompletionError,
194 > {
195 CompletionModel::stream(self, request).await
196 }
197}
198
199pub(crate) fn create_request_body(
200 completion_request: CompletionRequest,
201) -> Result<GenerateContentRequest, CompletionError> {
202 let chat_history = completion_request.chat_history_with_documents();
203
204 let CompletionRequest {
205 model: _,
206 preamble,
207 chat_history: _,
208 documents: _,
209 tools: function_tools,
210 temperature,
211 max_tokens,
212 tool_choice,
213 mut additional_params,
214 output_schema,
215 } = completion_request;
216
217 let mut full_history = Vec::new();
218 full_history.extend(chat_history);
219 let (history_system, full_history) = split_system_messages_from_history(full_history);
220
221 let mut additional_params_payload = additional_params
222 .take()
223 .unwrap_or_else(|| Value::Object(Map::new()));
224 let mut additional_tools =
225 extract_tools_from_additional_params(&mut additional_params_payload)?;
226
227 let AdditionalParameters {
228 mut generation_config,
229 additional_params,
230 } = serde_json::from_value::<AdditionalParameters>(additional_params_payload)?;
231
232 if let Some(schema) = output_schema {
234 let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
235 cfg.response_mime_type = Some("application/json".to_string());
236 cfg.response_json_schema = Some(schema.to_value());
237 }
238
239 generation_config = generation_config.map(|mut cfg| {
240 if let Some(temp) = temperature {
241 cfg.temperature = Some(temp);
242 };
243
244 if let Some(max_tokens) = max_tokens {
245 cfg.max_output_tokens = Some(max_tokens);
246 };
247
248 cfg
249 });
250
251 let mut system_parts: Vec<Part> = Vec::new();
252 if let Some(preamble) = preamble.filter(|preamble| !preamble.is_empty()) {
253 system_parts.push(preamble.into());
254 }
255 for content in history_system {
256 if !content.is_empty() {
257 system_parts.push(content.into());
258 }
259 }
260 let system_instruction = if system_parts.is_empty() {
261 None
262 } else {
263 Some(Content {
264 parts: system_parts,
265 role: Some(Role::Model),
266 })
267 };
268
269 let mut tools = if function_tools.is_empty() {
270 Vec::new()
271 } else {
272 vec![serde_json::to_value(Tool::try_from(function_tools)?)?]
273 };
274 tools.append(&mut additional_tools);
275 let tools = if tools.is_empty() { None } else { Some(tools) };
276
277 let tool_config = if let Some(cfg) = tool_choice {
278 Some(ToolConfig {
279 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
280 })
281 } else {
282 None
283 };
284
285 let request = GenerateContentRequest {
286 contents: full_history
287 .into_iter()
288 .map(|msg| {
289 msg.try_into()
290 .map_err(|e| CompletionError::RequestError(Box::new(e)))
291 })
292 .collect::<Result<Vec<_>, _>>()?,
293 generation_config,
294 safety_settings: None,
295 tools,
296 tool_config,
297 system_instruction,
298 additional_params,
299 };
300
301 Ok(request)
302}
303
304fn split_system_messages_from_history(
305 history: Vec<completion::Message>,
306) -> (Vec<String>, Vec<completion::Message>) {
307 let mut system = Vec::new();
308 let mut remaining = Vec::new();
309
310 for message in history {
311 match message {
312 completion::Message::System { content } => system.push(content),
313 other => remaining.push(other),
314 }
315 }
316
317 (system, remaining)
318}
319
320fn extract_tools_from_additional_params(
321 additional_params: &mut Value,
322) -> Result<Vec<Value>, CompletionError> {
323 if let Some(map) = additional_params.as_object_mut()
324 && let Some(raw_tools) = map.remove("tools")
325 {
326 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
327 CompletionError::RequestError(
328 format!("Invalid Gemini `additional_params.tools` payload: {err}").into(),
329 )
330 });
331 }
332
333 Ok(Vec::new())
334}
335
336pub(crate) fn resolve_request_model(
337 default_model: &str,
338 completion_request: &CompletionRequest,
339) -> String {
340 completion_request
341 .model
342 .clone()
343 .unwrap_or_else(|| default_model.to_string())
344}
345
/// Request path of the non-streaming `generateContent` method for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    let mut path = String::from("/v1beta/models/");
    path.push_str(model);
    path.push_str(":generateContent");
    path
}
349
/// Request path of the streaming `streamGenerateContent` method for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":streamGenerateContent"].concat()
}
353
354impl TryFrom<completion::ToolDefinition> for Tool {
355 type Error = CompletionError;
356
357 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
358 let parameters = tool_parameters_to_schema(tool.parameters)?;
359
360 Ok(Self {
361 function_declarations: vec![FunctionDeclaration {
362 name: tool.name,
363 description: tool.description,
364 parameters,
365 }],
366 code_execution: None,
367 })
368 }
369}
370
371impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
372 type Error = CompletionError;
373
374 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
375 let mut function_declarations = Vec::new();
376
377 for tool in tools {
378 let parameters = tool_parameters_to_schema(tool.parameters).map_err(|e| {
379 CompletionError::ProviderError(format!(
380 "Tool '{}' could not be converted to a schema: {:?}",
381 tool.name, e,
382 ))
383 })?;
384
385 function_declarations.push(FunctionDeclaration {
386 name: tool.name,
387 description: tool.description,
388 parameters,
389 });
390 }
391
392 Ok(Self {
393 function_declarations,
394 code_execution: None,
395 })
396 }
397}
398
399pub(crate) fn function_call_finish_reason_error(
400 reason: &FinishReason,
401 finish_message: Option<&str>,
402) -> Option<CompletionError> {
403 match reason {
404 FinishReason::MalformedFunctionCall
405 | FinishReason::UnexpectedToolCall
406 | FinishReason::MissingThoughtSignature
407 | FinishReason::TooManyToolCalls
408 | FinishReason::MalformedResponse => {
409 let message = finish_message.unwrap_or("no finish message provided");
410 Some(CompletionError::ResponseError(format!(
411 "Gemini stopped with finish_reason={reason:?}: {message}"
412 )))
413 }
414 _ => None,
415 }
416}
417
418impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
419 type Error = CompletionError;
420
421 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
422 let candidate = response.candidates.first().ok_or_else(|| {
423 CompletionError::ResponseError("No response candidates in response".into())
424 })?;
425
426 if let Some(reason) = candidate.finish_reason.as_ref()
427 && let Some(err) =
428 function_call_finish_reason_error(reason, candidate.finish_message.as_deref())
429 {
430 return Err(err);
431 }
432
433 let content = candidate
434 .content
435 .as_ref()
436 .ok_or_else(|| {
437 let reason = candidate
438 .finish_reason
439 .as_ref()
440 .map(|r| format!("finish_reason={r:?}"))
441 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
442 let message = candidate
443 .finish_message
444 .as_deref()
445 .unwrap_or("no finish message provided");
446 CompletionError::ResponseError(format!(
447 "Gemini candidate missing content ({reason}, finish_message={message})"
448 ))
449 })?
450 .parts
451 .iter()
452 .map(
453 |Part {
454 thought,
455 thought_signature,
456 part,
457 ..
458 }| {
459 Ok(match part {
460 PartKind::Text(text) => {
461 if let Some(thought) = thought
462 && *thought
463 {
464 completion::AssistantContent::Reasoning(
465 Reasoning::new_with_signature(text, thought_signature.clone()),
466 )
467 } else {
468 completion::AssistantContent::text(text)
469 }
470 }
471 PartKind::InlineData(inline_data) => {
472 let mime_type =
473 message::MediaType::from_mime_type(&inline_data.mime_type);
474
475 match mime_type {
476 Some(message::MediaType::Image(media_type)) => {
477 message::AssistantContent::image_base64(
478 &inline_data.data,
479 Some(media_type),
480 Some(message::ImageDetail::default()),
481 )
482 }
483 _ => {
484 return Err(CompletionError::ResponseError(format!(
485 "Unsupported media type {mime_type:?}"
486 )));
487 }
488 }
489 }
490 PartKind::FunctionCall(function_call) => {
491 completion::AssistantContent::ToolCall(
492 message::ToolCall::new(
493 function_call.name.clone(),
494 message::ToolFunction::new(
495 function_call.name.clone(),
496 function_call.args.clone(),
497 ),
498 )
499 .with_signature(thought_signature.clone()),
500 )
501 }
502 _ => {
503 return Err(CompletionError::ResponseError(
504 "Response did not contain a message or tool call".into(),
505 ));
506 }
507 })
508 },
509 )
510 .collect::<Result<Vec<_>, _>>()?;
511
512 let choice = OneOrMany::many(content).map_err(|_| {
513 CompletionError::ResponseError(
514 "Response contained no message or tool call (empty)".to_owned(),
515 )
516 })?;
517
518 let usage = response
519 .usage_metadata
520 .as_ref()
521 .map(GetTokenUsage::token_usage)
522 .unwrap_or_default();
523
524 Ok(completion::CompletionResponse {
525 choice,
526 usage,
527 raw_response: response,
528 message_id: None,
529 })
530 }
531}
532
533pub mod gemini_api_types {
534 use crate::telemetry::ProviderResponseExt;
535 use std::{collections::HashMap, convert::Infallible, str::FromStr};
536
537 use serde::{Deserialize, Serialize};
541 use serde_json::{Value, json};
542
543 use crate::completion::GetTokenUsage;
544 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
545 use crate::{
546 completion::CompletionError,
547 message::{self},
548 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
549 };
550
    /// Gemini-specific options parsed from a completion request's `additional_params`.
    ///
    /// The `generationConfig` key fills `generation_config`; all other keys are captured by the
    /// flattened `additional_params` value.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Generation settings (e.g. temperature, output-token limit, response format).
        pub generation_config: Option<GenerationConfig>,
        /// Every other key, flattened back into the same JSON object when serialized.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
560
561 impl AdditionalParameters {
562 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
563 self.generation_config = Some(cfg);
564 self
565 }
566
567 pub fn with_params(mut self, params: serde_json::Value) -> Self {
568 self.additional_params = Some(params);
569 self
570 }
571 }
572
573 #[derive(Debug, Deserialize, Serialize)]
581 #[serde(rename_all = "camelCase")]
582 pub struct GenerateContentResponse {
583 pub response_id: String,
584 pub candidates: Vec<ContentCandidate>,
586 pub prompt_feedback: Option<PromptFeedback>,
588 pub usage_metadata: Option<UsageMetadata>,
590 pub model_version: Option<String>,
591 }
592
593 impl ProviderResponseExt for GenerateContentResponse {
594 type OutputMessage = ContentCandidate;
595 type Usage = UsageMetadata;
596
597 fn get_response_id(&self) -> Option<String> {
598 Some(self.response_id.clone())
599 }
600
601 fn get_response_model_name(&self) -> Option<String> {
602 self.model_version.clone()
603 }
604
605 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
606 self.candidates.clone()
607 }
608
609 fn get_text_response(&self) -> Option<String> {
610 let str = self
611 .candidates
612 .iter()
613 .filter_map(|x| {
614 let content = x.content.as_ref()?;
615 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
616 return None;
617 }
618
619 let res = content
620 .parts
621 .iter()
622 .filter_map(|part| {
623 if let PartKind::Text(ref str) = part.part {
624 Some(str.to_owned())
625 } else {
626 None
627 }
628 })
629 .collect::<Vec<String>>()
630 .join("\n");
631
632 Some(res)
633 })
634 .collect::<Vec<String>>()
635 .join("\n");
636
637 if str.is_empty() { None } else { Some(str) }
638 }
639
640 fn get_usage(&self) -> Option<Self::Usage> {
641 self.usage_metadata.clone()
642 }
643 }
644
    /// One generated candidate in a Gemini response. Field names map to the camelCase API keys.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content; may be absent, in which case `finish_reason` / `finish_message`
        /// are the only explanation available.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Reason generation stopped (`finishReason`), if reported.
        pub finish_reason: Option<FinishReason>,
        /// API field `safetyRatings`, if present.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// API field `citationMetadata`, if present.
        pub citation_metadata: Option<CitationMetadata>,
        /// API field `tokenCount`, if present.
        pub token_count: Option<i32>,
        /// API field `avgLogprobs`, if present.
        pub avg_logprobs: Option<f64>,
        /// API field `logprobsResult`, if present.
        pub logprobs_result: Option<LogprobsResult>,
        /// API field `index` (position of this candidate), if present.
        pub index: Option<i32>,
        /// Free-text detail accompanying `finish_reason` (`finishMessage`), if present.
        pub finish_message: Option<String>,
    }
673
    /// A role-tagged list of parts; used both for conversation turns and for the system
    /// instruction of a request.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// The parts, in order; deserializes to an empty list when the key is missing.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the content (`user` or `model`); optional on the wire.
        pub role: Option<Role>,
    }
683
684 impl TryFrom<message::Message> for Content {
685 type Error = message::MessageError;
686
687 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
688 Ok(match msg {
689 message::Message::System { content } => Content {
690 parts: vec![content.into()],
691 role: Some(Role::User),
692 },
693 message::Message::User { content } => Content {
694 parts: content
695 .into_iter()
696 .map(|c| c.try_into())
697 .collect::<Result<Vec<_>, _>>()?,
698 role: Some(Role::User),
699 },
700 message::Message::Assistant { content, .. } => Content {
701 role: Some(Role::Model),
702 parts: content
703 .into_iter()
704 .map(|content| content.try_into())
705 .collect::<Result<Vec<_>, _>>()?,
706 },
707 })
708 }
709 }
710
    /// Author of a `Content` block; serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// User-side content (this module also uses it for converted system messages).
        User,
        /// Model-produced content.
        Model,
    }
717
    /// A single piece of content: one `PartKind` payload plus optional thought metadata.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// `Some(true)` marks a reasoning ("thought") part; omitted from the wire when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Thought signature attached to this part; round-tripped on reasoning and tool-call
        /// parts. Omitted from the wire when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The payload; its variant key is flattened into the part object.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra keys flattened into the part object (the video conversion forwards the video's
        /// additional params here).
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
732
    /// Payload of a `Part`.
    ///
    /// Serialized as a single camelCase key (`text`, `inlineData`, `functionCall`,
    /// `functionResponse`, `fileData`, `executableCode`, `codeExecutionResult`) that is flattened
    /// into the enclosing part object.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Media embedded inline with its MIME type.
        InlineData(Blob),
        /// A function call requested by the model.
        FunctionCall(FunctionCall),
        /// The result of a function call, sent back to the model.
        FunctionResponse(FunctionResponse),
        /// Media referenced by URI.
        FileData(FileData),
        /// Code emitted by the model (used with the code-execution tool).
        ExecutableCode(ExecutableCode),
        /// Result of executing code (used with the code-execution tool).
        CodeExecutionResult(CodeExecutionResult),
    }
747
748 impl Default for PartKind {
751 fn default() -> Self {
752 Self::Text(String::new())
753 }
754 }
755
756 impl From<String> for Part {
757 fn from(text: String) -> Self {
758 Self {
759 thought: Some(false),
760 thought_signature: None,
761 part: PartKind::Text(text),
762 additional_params: None,
763 }
764 }
765 }
766
767 impl From<&str> for Part {
768 fn from(text: &str) -> Self {
769 Self::from(text.to_string())
770 }
771 }
772
773 impl FromStr for Part {
774 type Err = Infallible;
775
776 fn from_str(s: &str) -> Result<Self, Self::Err> {
777 Ok(s.into())
778 }
779 }
780
781 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
782 type Error = message::MessageError;
783 fn try_from(
784 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
785 ) -> Result<Self, Self::Error> {
786 let mime_type = mime_type.to_mime_type().to_string();
787 let part = match doc_src {
788 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
789 mime_type: Some(mime_type),
790 file_uri: url,
791 }),
792 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
793 PartKind::InlineData(Blob { mime_type, data })
794 }
795 DocumentSourceKind::Raw(_) => {
796 return Err(message::MessageError::ConversionError(
797 "Raw files not supported, encode as base64 first".into(),
798 ));
799 }
800 DocumentSourceKind::FileId(_) => {
801 return Err(message::MessageError::ConversionError(
802 "Provider file IDs are not supported for Gemini image inputs".into(),
803 ));
804 }
805 DocumentSourceKind::Unknown => {
806 return Err(message::MessageError::ConversionError(
807 "Can't convert an unknown document source".to_string(),
808 ));
809 }
810 };
811
812 Ok(part)
813 }
814 }
815
816 impl TryFrom<message::UserContent> for Part {
817 type Error = message::MessageError;
818
819 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
820 match content {
821 message::UserContent::Text(message::Text { text, .. }) => Ok(Part {
822 thought: Some(false),
823 thought_signature: None,
824 part: PartKind::Text(text),
825 additional_params: None,
826 }),
827 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
828 let mut response_json: Option<serde_json::Value> = None;
829 let mut parts: Vec<FunctionResponsePart> = Vec::new();
830
831 for item in content.iter() {
832 match item {
833 message::ToolResultContent::Text(text) => {
834 let result: serde_json::Value =
835 serde_json::from_str(&text.text).unwrap_or_else(|error| {
836 tracing::trace!(
837 ?error,
838 "Tool result is not a valid JSON, treat it as normal string"
839 );
840 json!(&text.text)
841 });
842
843 response_json = Some(match response_json {
844 Some(mut existing) => {
845 if let serde_json::Value::Object(ref mut map) = existing {
846 map.insert("text".to_string(), result);
847 }
848 existing
849 }
850 None => json!({ "result": result }),
851 });
852 }
853 message::ToolResultContent::Image(image) => {
854 let part = match &image.data {
855 DocumentSourceKind::Base64(b64) => {
856 let mime_type = image
857 .media_type
858 .as_ref()
859 .ok_or(message::MessageError::ConversionError(
860 "Image media type is required for Gemini tool results".to_string(),
861 ))?
862 .to_mime_type();
863
864 FunctionResponsePart {
865 inline_data: Some(FunctionResponseInlineData {
866 mime_type: mime_type.to_string(),
867 data: b64.clone(),
868 display_name: None,
869 }),
870 file_data: None,
871 }
872 }
873 DocumentSourceKind::Url(url) => {
874 let mime_type = image
875 .media_type
876 .as_ref()
877 .map(|mt| mt.to_mime_type().to_string());
878
879 FunctionResponsePart {
880 inline_data: None,
881 file_data: Some(FileData {
882 mime_type,
883 file_uri: url.clone(),
884 }),
885 }
886 }
887 _ => {
888 return Err(message::MessageError::ConversionError(
889 "Unsupported image source kind for tool results"
890 .to_string(),
891 ));
892 }
893 };
894 parts.push(part);
895 }
896 }
897 }
898
899 Ok(Part {
900 thought: Some(false),
901 thought_signature: None,
902 part: PartKind::FunctionResponse(FunctionResponse {
903 name: id,
904 response: response_json,
905 parts: if parts.is_empty() { None } else { Some(parts) },
906 }),
907 additional_params: None,
908 })
909 }
910 message::UserContent::Image(message::Image {
911 data, media_type, ..
912 }) => match media_type {
913 Some(media_type) => match media_type {
914 message::ImageMediaType::JPEG
915 | message::ImageMediaType::PNG
916 | message::ImageMediaType::WEBP
917 | message::ImageMediaType::HEIC
918 | message::ImageMediaType::HEIF => {
919 let part = PartKind::try_from((media_type, data))?;
920 Ok(Part {
921 thought: Some(false),
922 thought_signature: None,
923 part,
924 additional_params: None,
925 })
926 }
927 _ => Err(message::MessageError::ConversionError(format!(
928 "Unsupported image media type {media_type:?}"
929 ))),
930 },
931 None => Err(message::MessageError::ConversionError(
932 "Media type for image is required for Gemini".to_string(),
933 )),
934 },
935 message::UserContent::Document(message::Document {
936 data, media_type, ..
937 }) => {
938 let Some(media_type) = media_type else {
939 return Err(MessageError::ConversionError(
940 "A mime type is required for document inputs to Gemini".to_string(),
941 ));
942 };
943
944 if matches!(
947 media_type,
948 message::DocumentMediaType::TXT
949 | message::DocumentMediaType::RTF
950 | message::DocumentMediaType::HTML
951 | message::DocumentMediaType::CSS
952 | message::DocumentMediaType::MARKDOWN
953 | message::DocumentMediaType::CSV
954 | message::DocumentMediaType::XML
955 | message::DocumentMediaType::Javascript
956 | message::DocumentMediaType::Python
957 ) {
958 use base64::Engine;
959 let part = match data {
960 DocumentSourceKind::String(text) => PartKind::Text(text),
961 DocumentSourceKind::Base64(data) => {
962 let text = String::from_utf8(
964 base64::engine::general_purpose::STANDARD
965 .decode(&data)
966 .map_err(|e| {
967 MessageError::ConversionError(format!(
968 "Failed to decode base64: {e}"
969 ))
970 })?,
971 )
972 .map_err(|e| {
973 MessageError::ConversionError(format!(
974 "Invalid UTF-8 in document: {e}"
975 ))
976 })?;
977 PartKind::Text(text)
978 }
979 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
980 mime_type: Some(media_type.to_mime_type().to_string()),
981 file_uri,
982 }),
983 DocumentSourceKind::Raw(_) => {
984 return Err(MessageError::ConversionError(
985 "Raw files not supported, encode as base64 first".to_string(),
986 ));
987 }
988 DocumentSourceKind::FileId(_) => {
989 return Err(MessageError::ConversionError(
990 "Provider file IDs are not supported for Gemini documents"
991 .to_string(),
992 ));
993 }
994 DocumentSourceKind::Unknown => {
995 return Err(MessageError::ConversionError(
996 "Document has no body".to_string(),
997 ));
998 }
999 };
1000
1001 Ok(Part {
1002 thought: Some(false),
1003 part,
1004 ..Default::default()
1005 })
1006 } else if !media_type.is_code() {
1007 let mime_type = media_type.to_mime_type().to_string();
1008
1009 let part = match data {
1010 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1011 mime_type: Some(mime_type),
1012 file_uri,
1013 }),
1014 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
1015 PartKind::InlineData(Blob { mime_type, data })
1016 }
1017 DocumentSourceKind::Raw(_) => {
1018 return Err(message::MessageError::ConversionError(
1019 "Raw files not supported, encode as base64 first".into(),
1020 ));
1021 }
1022 _ => {
1023 return Err(message::MessageError::ConversionError(
1024 "Document has no body".to_string(),
1025 ));
1026 }
1027 };
1028
1029 Ok(Part {
1030 thought: Some(false),
1031 part,
1032 ..Default::default()
1033 })
1034 } else {
1035 Err(message::MessageError::ConversionError(format!(
1036 "Unsupported document media type {media_type:?}"
1037 )))
1038 }
1039 }
1040
1041 message::UserContent::Audio(message::Audio {
1042 data, media_type, ..
1043 }) => {
1044 let Some(media_type) = media_type else {
1045 return Err(MessageError::ConversionError(
1046 "A mime type is required for audio inputs to Gemini".to_string(),
1047 ));
1048 };
1049
1050 let mime_type = media_type.to_mime_type().to_string();
1051
1052 let part = match data {
1053 DocumentSourceKind::Base64(data) => {
1054 PartKind::InlineData(Blob { data, mime_type })
1055 }
1056
1057 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1058 mime_type: Some(mime_type),
1059 file_uri,
1060 }),
1061 DocumentSourceKind::String(_) => {
1062 return Err(message::MessageError::ConversionError(
1063 "Strings cannot be used as audio files!".into(),
1064 ));
1065 }
1066 DocumentSourceKind::Raw(_) => {
1067 return Err(message::MessageError::ConversionError(
1068 "Raw files not supported, encode as base64 first".into(),
1069 ));
1070 }
1071 DocumentSourceKind::FileId(_) => {
1072 return Err(message::MessageError::ConversionError(
1073 "Provider file IDs are not supported for Gemini audio inputs"
1074 .into(),
1075 ));
1076 }
1077 DocumentSourceKind::Unknown => {
1078 return Err(message::MessageError::ConversionError(
1079 "Content has no body".to_string(),
1080 ));
1081 }
1082 };
1083
1084 Ok(Part {
1085 thought: Some(false),
1086 part,
1087 ..Default::default()
1088 })
1089 }
1090 message::UserContent::Video(message::Video {
1091 data,
1092 media_type,
1093 additional_params,
1094 ..
1095 }) => {
1096 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
1097
1098 let part = match data {
1099 DocumentSourceKind::Url(file_uri) => {
1100 if file_uri.starts_with("https://www.youtube.com") {
1101 PartKind::FileData(FileData {
1102 mime_type,
1103 file_uri,
1104 })
1105 } else {
1106 if mime_type.is_none() {
1107 return Err(MessageError::ConversionError(
1108 "A mime type is required for non-Youtube video file inputs to Gemini"
1109 .to_string(),
1110 ));
1111 }
1112
1113 PartKind::FileData(FileData {
1114 mime_type,
1115 file_uri,
1116 })
1117 }
1118 }
1119 DocumentSourceKind::Base64(data) => {
1120 let Some(mime_type) = mime_type else {
1121 return Err(MessageError::ConversionError(
1122 "A media type is expected for base64 encoded strings"
1123 .to_string(),
1124 ));
1125 };
1126 PartKind::InlineData(Blob { mime_type, data })
1127 }
1128 DocumentSourceKind::String(_) => {
1129 return Err(message::MessageError::ConversionError(
1130 "Strings cannot be used as audio files!".into(),
1131 ));
1132 }
1133 DocumentSourceKind::Raw(_) => {
1134 return Err(message::MessageError::ConversionError(
1135 "Raw file data not supported, encode as base64 first".into(),
1136 ));
1137 }
1138 DocumentSourceKind::FileId(_) => {
1139 return Err(message::MessageError::ConversionError(
1140 "Provider file IDs are not supported for Gemini video inputs"
1141 .into(),
1142 ));
1143 }
1144 DocumentSourceKind::Unknown => {
1145 return Err(message::MessageError::ConversionError(
1146 "Media type for video is required for Gemini".to_string(),
1147 ));
1148 }
1149 };
1150
1151 Ok(Part {
1152 thought: Some(false),
1153 thought_signature: None,
1154 part,
1155 additional_params,
1156 })
1157 }
1158 }
1159 }
1160 }
1161
1162 impl TryFrom<message::AssistantContent> for Part {
1163 type Error = message::MessageError;
1164
1165 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1166 match content {
1167 message::AssistantContent::Text(message::Text { text, .. }) => Ok(text.into()),
1168 message::AssistantContent::Image(message::Image {
1169 data, media_type, ..
1170 }) => match media_type {
1171 Some(media_type) => match media_type {
1172 message::ImageMediaType::JPEG
1173 | message::ImageMediaType::PNG
1174 | message::ImageMediaType::WEBP
1175 | message::ImageMediaType::HEIC
1176 | message::ImageMediaType::HEIF => {
1177 let part = PartKind::try_from((media_type, data))?;
1178 Ok(Part {
1179 thought: Some(false),
1180 thought_signature: None,
1181 part,
1182 additional_params: None,
1183 })
1184 }
1185 _ => Err(message::MessageError::ConversionError(format!(
1186 "Unsupported image media type {media_type:?}"
1187 ))),
1188 },
1189 None => Err(message::MessageError::ConversionError(
1190 "Media type for image is required for Gemini".to_string(),
1191 )),
1192 },
1193 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1194 message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1195 thought: Some(true),
1196 thought_signature: reasoning.first_signature().map(str::to_owned),
1197 part: PartKind::Text(reasoning.display_text()),
1198 additional_params: None,
1199 }),
1200 }
1201 }
1202 }
1203
1204 impl From<message::ToolCall> for Part {
1205 fn from(tool_call: message::ToolCall) -> Self {
1206 Self {
1207 thought: Some(false),
1208 thought_signature: tool_call.signature,
1209 part: PartKind::FunctionCall(FunctionCall {
1210 name: tool_call.function.name,
1211 args: tool_call.function.arguments,
1212 }),
1213 additional_params: None,
1214 }
1215 }
1216 }
1217
    /// Binary payload carried inline in a request or response part.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of `data` (serialized as `mimeType`, e.g. `image/png`).
        pub mime_type: String,
        /// Encoded payload. NOTE(review): presumably base64 — the media conversions in this
        /// module reject raw bytes with "encode as base64 first"; confirm for all producers.
        pub data: String,
    }

    /// A function invocation requested by the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the declared function to call.
        pub name: String,
        /// Call arguments as arbitrary JSON (typically an object keyed by parameter name).
        pub args: serde_json::Value,
    }
1240
1241 impl From<message::ToolCall> for FunctionCall {
1242 fn from(tool_call: message::ToolCall) -> Self {
1243 Self {
1244 name: tool_call.function.name,
1245 args: tool_call.function.arguments,
1246 }
1247 }
1248 }
1249
    /// Result of a function call, sent back to the model.
    ///
    /// Both `response` and `parts` are optional; `None` fields are omitted from the JSON.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name of the function that produced this result. NOTE(review): presumably must
        /// match the originating `FunctionCall::name` — confirm against the API.
        pub name: String,
        /// Structured (JSON) function output.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        /// Media attached to the function result.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }

    /// One media attachment of a [`FunctionResponse`]: inline bytes or a file reference.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        /// Inline payload (serialized as `inlineData`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        /// Reference to a file by URI (serialized as `fileData`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }

    /// Inline media inside a function response; like [`Blob`] plus an optional display name.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// MIME type of `data`.
        pub mime_type: String,
        /// Encoded payload (presumably base64, as for [`Blob`] — confirm).
        pub data: String,
        /// Optional human-readable name for the attachment (`displayName`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }

    /// Reference to content stored at a URI rather than sent inline.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file. It is optional here and serialized as `null`
        /// when absent, because it has no `skip_serializing_if`.
        pub mime_type: Option<String>,
        /// Location of the file (`fileUri`).
        pub file_uri: String,
    }
1300
    /// Safety assessment for one harm category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The category being rated.
        pub category: HarmCategory,
        /// How likely the content is to be harmful in that category.
        pub probability: HarmProbability,
    }

    /// Likelihood bucket for a harm category.
    ///
    /// Wire format is SCREAMING_SNAKE_CASE (e.g. `HARM_PROBABILITY_UNSPECIFIED`, `NEGLIGIBLE`).
    /// There is no catch-all variant, so an unrecognized value fails deserialization.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }

    /// Harm categories used in safety ratings and [`SafetySetting`]s.
    ///
    /// Wire format is SCREAMING_SNAKE_CASE (e.g. `HARM_CATEGORY_HATE_SPEECH`). An
    /// unrecognized category fails deserialization because there is no catch-all variant.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1333
    /// Token accounting reported with a response (`usageMetadata`).
    ///
    /// Counts are signed 32-bit integers as received. `GetTokenUsage` maps them onto the
    /// provider-agnostic `completion::Usage`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt; defaults to 0 when the field is omitted.
        #[serde(default)]
        pub prompt_token_count: i32,
        /// Prompt tokens served from cached content.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens across the generated candidates (mapped to output tokens).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count. Required: deserialization fails if it is missing.
        pub total_token_count: i32,
        /// Tokens spent on model "thoughts" (mapped to reasoning tokens).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
        /// Per-modality breakdown of the prompt tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Per-modality breakdown of the cached tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Per-modality breakdown of the candidate tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Prompt tokens attributed to tool use.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub tool_use_prompt_token_count: Option<i32>,
        /// Per-modality breakdown of the tool-use prompt tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Billing/traffic class of the request, when reported.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub traffic_type: Option<TrafficType>,
    }

    /// Token count for a single input/output modality.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ModalityTokenCount {
        /// The modality being counted.
        pub modality: Modality,
        /// Number of tokens of that modality.
        pub token_count: i32,
    }

    /// Content modality (wire format SCREAMING_SNAKE_CASE, e.g. `TEXT`, `IMAGE`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum Modality {
        ModalityUnspecified,
        Text,
        Image,
        Video,
        Audio,
        Document,
    }

    /// Traffic class reported in usage metadata (e.g. `ON_DEMAND`, `PROVISIONED_THROUGHPUT`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TrafficType {
        TrafficTypeUnspecified,
        OnDemand,
        ProvisionedThroughput,
    }
1385
1386 impl std::fmt::Display for UsageMetadata {
1387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1388 write!(
1389 f,
1390 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1391 self.prompt_token_count,
1392 match self.cached_content_token_count {
1393 Some(count) => count.to_string(),
1394 None => "n/a".to_string(),
1395 },
1396 match self.candidates_token_count {
1397 Some(count) => count.to_string(),
1398 None => "n/a".to_string(),
1399 },
1400 self.total_token_count
1401 )
1402 }
1403 }
1404
1405 impl GetTokenUsage for UsageMetadata {
1406 fn token_usage(&self) -> crate::completion::Usage {
1407 let mut usage = crate::completion::Usage::new();
1408
1409 usage.input_tokens = self.prompt_token_count as u64;
1410 usage.output_tokens = self.candidates_token_count.unwrap_or_default() as u64;
1411 usage.cached_input_tokens = self.cached_content_token_count.unwrap_or_default() as u64;
1412 usage.reasoning_tokens = self.thoughts_token_count.unwrap_or_default() as u64;
1413 usage.tool_use_prompt_tokens =
1414 self.tool_use_prompt_token_count.unwrap_or_default() as u64;
1415 usage.total_tokens = self.total_token_count as u64;
1416
1417 usage
1418 }
1419 }
1420
    /// Feedback about the prompt itself, such as whether and why it was blocked.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Set when the prompt was blocked.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings assigned to the prompt.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }

    /// Why a prompt was blocked (wire format SCREAMING_SNAKE_CASE, e.g. `PROHIBITED_CONTENT`).
    ///
    /// There is no catch-all variant, so an unrecognized reason fails deserialization.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }

    /// Why the model stopped generating a candidate (wire format SCREAMING_SNAKE_CASE).
    ///
    /// The tool-protocol reasons (`MalformedFunctionCall`, `UnexpectedToolCall`,
    /// `MissingThoughtSignature`, `TooManyToolCalls`, `MalformedResponse`) cause the
    /// non-streaming response conversion to return `CompletionError::ResponseError`.
    /// There is no catch-all variant, so an unrecognized reason fails deserialization.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        /// Normal completion.
        Stop,
        /// The output token limit was reached.
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
        UnexpectedToolCall,
        MissingThoughtSignature,
        TooManyToolCalls,
        MalformedResponse,
    }

    /// Citations attributed to a generated candidate.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Required on deserialization (no `default`).
        pub citation_sources: Vec<CitationSource>,
    }

    /// One cited source. Every field is optional and omitted from the JSON when `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// Location of the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the cited span. NOTE(review): presumably an offset into the generated
        /// text — confirm the unit (bytes vs. characters).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the cited span (same unit as `start_index`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License string attached to the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1500
1501 #[derive(Clone, Debug, Deserialize, Serialize)]
1502 #[serde(rename_all = "camelCase")]
1503 pub struct LogprobsResult {
1504 pub top_candidate: Vec<TopCandidate>,
1505 pub chosen_candidate: Vec<LogProbCandidate>,
1506 }
1507
1508 #[derive(Clone, Debug, Deserialize, Serialize)]
1509 pub struct TopCandidate {
1510 pub candidates: Vec<LogProbCandidate>,
1511 }
1512
1513 #[derive(Clone, Debug, Deserialize, Serialize)]
1514 #[serde(rename_all = "camelCase")]
1515 pub struct LogProbCandidate {
1516 pub token: String,
1517 pub token_id: String,
1518 pub log_probability: f64,
1519 }
1520
    /// Sampling and output settings sent as `generationConfig`.
    ///
    /// Every field is optional and omitted from the JSON when `None`. The `Default` impl sets
    /// `temperature: Some(1.0)` and `max_output_tokens: Some(4096)` and leaves the rest unset.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Strings that stop generation when produced.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated output (e.g. `application/json` for structured output).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema in Gemini's OpenAPI-subset [`Schema`] form.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Raw JSON Schema sent under the literal key `_responseJsonSchema`.
        /// NOTE(review): presumably for API versions that only accept the underscored
        /// variant of `responseJsonSchema` — confirm before relying on it.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Raw JSON Schema for the output (`responseJsonSchema`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of candidates to generate.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Upper bound on generated tokens.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Nucleus-sampling probability mass.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling cutoff.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether to return log probabilities for the chosen tokens.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// How many top candidate tokens to report per step.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Reasoning ("thinking") controls.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Modalities the model may respond with (text, image, audio).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_modalities: Option<Vec<ResponseModality>>,
        /// Image-generation options.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1616
1617 impl Default for GenerationConfig {
1618 fn default() -> Self {
1619 Self {
1620 temperature: Some(1.0),
1621 max_output_tokens: Some(4096),
1622 stop_sequences: None,
1623 response_mime_type: None,
1624 response_schema: None,
1625 _response_json_schema: None,
1626 response_json_schema: None,
1627 candidate_count: None,
1628 top_p: None,
1629 top_k: None,
1630 presence_penalty: None,
1631 frequency_penalty: None,
1632 response_logprobs: None,
1633 logprobs: None,
1634 thinking_config: None,
1635 response_modalities: None,
1636 image_config: None,
1637 }
1638 }
1639 }
1640
    /// Output modality the model may produce (wire format `TEXT`, `IMAGE`, `AUDIO`).
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum ResponseModality {
        Text,
        Image,
        Audio,
    }

    /// Coarse reasoning-effort level (wire format lowercase: `minimal`, `low`, ...).
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum ThinkingLevel {
        Minimal,
        Low,
        Medium,
        High,
    }

    /// Reasoning controls (`thinkingConfig`); unset fields are omitted from the JSON.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Token budget for reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_budget: Option<u32>,
        /// Reasoning-effort level, as an alternative to an explicit budget.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_level: Option<ThinkingLevel>,
        /// Whether thought summaries should be included in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub include_thoughts: Option<bool>,
    }

    /// Image-generation options; unset fields are omitted from the JSON.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio (presumably a string such as `"16:9"` — confirm values).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size label (accepted values defined by the API — confirm).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }

    /// Gemini's OpenAPI-subset schema, built from JSON Schema via `TryFrom<Value>`.
    ///
    /// No `rename_all` is applied, so multi-word fields serialize with their Rust names
    /// (`max_items`, `min_items`). NOTE(review): this relies on the API accepting
    /// proto-style field names — confirm.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name such as `"object"`, `"array"` or `"string"`; may be empty if nothing
        /// could be inferred.
        pub r#type: String,
        /// Format hint (copied from JSON Schema `format`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// `Some(true)` when the source schema admits `null`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values; non-string enum members are dropped during conversion.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum array length.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum array length.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Object properties by name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Element schema for arrays.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1711
1712 pub fn tool_parameters_to_schema(parameters: Value) -> Result<Option<Schema>, CompletionError> {
1718 if parameters.is_null() || parameters == json!({"type": "object", "properties": {}}) {
1719 Ok(None)
1720 } else {
1721 parameters.try_into().map(Some)
1722 }
1723 }
1724
1725 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1731 let defs = if let Some(obj) = schema.as_object() {
1733 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1734 } else {
1735 None
1736 };
1737
1738 let Some(defs_value) = defs else {
1739 return Ok(schema);
1740 };
1741
1742 let Some(defs_obj) = defs_value.as_object() else {
1743 return Err(CompletionError::ResponseError(
1744 "$defs must be an object".into(),
1745 ));
1746 };
1747
1748 resolve_refs(&mut schema, defs_obj)?;
1749
1750 if let Some(obj) = schema.as_object_mut() {
1752 obj.remove("$defs");
1753 obj.remove("definitions");
1754 }
1755
1756 Ok(schema)
1757 }
1758
1759 fn resolve_refs(
1762 value: &mut Value,
1763 defs: &serde_json::Map<String, Value>,
1764 ) -> Result<(), CompletionError> {
1765 match value {
1766 Value::Object(obj) => {
1767 if let Some(ref_value) = obj.get("$ref")
1768 && let Some(ref_str) = ref_value.as_str()
1769 {
1770 let def_name = parse_ref_path(ref_str)?;
1772
1773 let def = defs.get(&def_name).ok_or_else(|| {
1774 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1775 })?;
1776
1777 let mut resolved = def.clone();
1778 resolve_refs(&mut resolved, defs)?;
1779 *value = resolved;
1780 return Ok(());
1781 }
1782
1783 for (_, v) in obj.iter_mut() {
1784 resolve_refs(v, defs)?;
1785 }
1786 }
1787 Value::Array(arr) => {
1788 for item in arr.iter_mut() {
1789 resolve_refs(item, defs)?;
1790 }
1791 }
1792 _ => {}
1793 }
1794
1795 Ok(())
1796 }
1797
1798 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1804 if let Some(fragment) = ref_str.strip_prefix('#') {
1805 if let Some(name) = fragment.strip_prefix("/$defs/") {
1806 Ok(name.to_string())
1807 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1808 Ok(name.to_string())
1809 } else {
1810 Err(CompletionError::ResponseError(format!(
1811 "Unsupported reference format: {}",
1812 ref_str
1813 )))
1814 }
1815 } else {
1816 Err(CompletionError::ResponseError(format!(
1817 "Only fragment references (#/...) are supported: {}",
1818 ref_str
1819 )))
1820 }
1821 }
1822
1823 fn extract_type(type_value: &Value) -> Option<String> {
1826 if let Some(t) = type_value.as_str() {
1827 return Some(t.to_string());
1828 }
1829
1830 type_value.as_array().and_then(|arr| {
1831 arr.iter()
1832 .filter_map(|v| v.as_str())
1833 .find(|t| *t != "null")
1834 .or_else(|| arr.iter().find_map(|v| v.as_str()))
1835 .map(str::to_owned)
1836 })
1837 }
1838
1839 fn schema_is_null(obj: &serde_json::Map<String, Value>) -> bool {
1840 obj.get("type")
1841 .and_then(extract_type)
1842 .as_deref()
1843 .is_some_and(|t| t == "null")
1844 }
1845
1846 fn schema_is_nullable(obj: &serde_json::Map<String, Value>) -> bool {
1847 obj.get("nullable")
1848 .and_then(|v| v.as_bool())
1849 .unwrap_or(false)
1850 || obj
1851 .get("type")
1852 .and_then(|v| v.as_array())
1853 .is_some_and(|arr| arr.iter().any(|v| v.as_str() == Some("null")))
1854 || ["anyOf", "oneOf", "allOf"].iter().any(|key| {
1855 obj.get(*key).and_then(|v| v.as_array()).is_some_and(|arr| {
1856 arr.iter()
1857 .filter_map(|schema| schema.as_object())
1858 .any(schema_is_null)
1859 })
1860 })
1861 }
1862
1863 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1866 composition.as_array().and_then(|arr| {
1867 arr.iter().find_map(|schema| {
1868 let obj = schema.as_object()?;
1869 if schema_is_null(obj) {
1870 return None;
1871 }
1872
1873 obj.get("type").and_then(extract_type).or_else(|| {
1874 if obj.contains_key("properties") {
1875 Some("object".to_string())
1876 } else if obj.contains_key("enum") {
1877 Some("string".to_string())
1879 } else {
1880 None
1881 }
1882 })
1883 })
1884 })
1885 }
1886
1887 fn extract_schema_from_composition(
1890 composition: &Value,
1891 ) -> Option<serde_json::Map<String, Value>> {
1892 composition.as_array().and_then(|arr| {
1893 arr.iter().find_map(|schema| {
1894 let obj = schema.as_object()?;
1895 if schema_is_null(obj) {
1896 None
1897 } else {
1898 Some(obj.clone())
1899 }
1900 })
1901 })
1902 }
1903
1904 fn extract_schema_from_composition_obj(
1905 obj: &serde_json::Map<String, Value>,
1906 ) -> Option<serde_json::Map<String, Value>> {
1907 obj.get("anyOf")
1908 .and_then(extract_schema_from_composition)
1909 .or_else(|| obj.get("oneOf").and_then(extract_schema_from_composition))
1910 .or_else(|| obj.get("allOf").and_then(extract_schema_from_composition))
1911 }
1912
1913 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1916 if let Some(type_val) = obj.get("type")
1918 && let Some(type_str) = extract_type(type_val)
1919 {
1920 return type_str;
1921 }
1922
1923 if let Some(any_of) = obj.get("anyOf")
1925 && let Some(type_str) = extract_type_from_composition(any_of)
1926 {
1927 return type_str;
1928 }
1929
1930 if let Some(one_of) = obj.get("oneOf")
1931 && let Some(type_str) = extract_type_from_composition(one_of)
1932 {
1933 return type_str;
1934 }
1935
1936 if let Some(all_of) = obj.get("allOf")
1937 && let Some(type_str) = extract_type_from_composition(all_of)
1938 {
1939 return type_str;
1940 }
1941
1942 if obj.contains_key("properties") {
1944 "object".to_string()
1945 } else if obj.contains_key("enum") {
1946 "string".to_string()
1947 } else {
1948 String::new()
1949 }
1950 }
1951
1952 impl TryFrom<Value> for Schema {
1953 type Error = CompletionError;
1954
1955 fn try_from(value: Value) -> Result<Self, Self::Error> {
1956 let flattened_val = flatten_schema(value)?;
1957 if let Some(obj) = flattened_val.as_object() {
1958 let composition_source = extract_schema_from_composition_obj(obj);
1961 let props_source = if obj.get("properties").is_none() {
1962 composition_source.clone().unwrap_or(obj.clone())
1963 } else {
1964 obj.clone()
1965 };
1966
1967 let schema_type = infer_type(obj);
1968 let items = obj
1969 .get("items")
1970 .or_else(|| props_source.get("items"))
1971 .and_then(|v| v.clone().try_into().ok())
1972 .map(Box::new);
1973
1974 let items = if schema_type == "array" && items.is_none() {
1977 Some(Box::new(Schema {
1978 r#type: "string".to_string(),
1979 format: None,
1980 description: None,
1981 nullable: None,
1982 r#enum: None,
1983 max_items: None,
1984 min_items: None,
1985 properties: None,
1986 required: None,
1987 items: None,
1988 }))
1989 } else {
1990 items
1991 };
1992
1993 Ok(Schema {
1994 r#type: schema_type,
1995 format: obj
1996 .get("format")
1997 .or_else(|| props_source.get("format"))
1998 .and_then(|v| v.as_str())
1999 .map(String::from),
2000 description: obj
2001 .get("description")
2002 .or_else(|| props_source.get("description"))
2003 .and_then(|v| v.as_str())
2004 .map(String::from),
2005 nullable: if schema_is_nullable(obj)
2006 || composition_source.as_ref().is_some_and(schema_is_nullable)
2007 {
2008 Some(true)
2009 } else {
2010 None
2011 },
2012 r#enum: obj
2013 .get("enum")
2014 .or_else(|| props_source.get("enum"))
2015 .and_then(|v| v.as_array())
2016 .map(|arr| {
2017 arr.iter()
2018 .filter_map(|v| v.as_str().map(String::from))
2019 .collect()
2020 }),
2021 max_items: obj
2022 .get("maxItems")
2023 .and_then(|v| v.as_i64())
2024 .map(|v| v as i32),
2025 min_items: obj
2026 .get("minItems")
2027 .and_then(|v| v.as_i64())
2028 .map(|v| v as i32),
2029 properties: props_source
2030 .get("properties")
2031 .and_then(|v| v.as_object())
2032 .map(|map| {
2033 map.iter()
2034 .filter_map(|(k, v)| {
2035 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
2036 })
2037 .collect()
2038 }),
2039 required: props_source
2040 .get("required")
2041 .and_then(|v| v.as_array())
2042 .map(|arr| {
2043 arr.iter()
2044 .filter_map(|v| v.as_str().map(String::from))
2045 .collect()
2046 }),
2047 items,
2048 })
2049 } else {
2050 Err(CompletionError::ResponseError(
2051 "Expected a JSON object for Schema".into(),
2052 ))
2053 }
2054 }
2055 }
2056
    /// Request body for `generateContent` / `streamGenerateContent`.
    ///
    /// Only `tools` and `additional_params` are skipped when `None`. The other `Option`
    /// fields have no `skip_serializing_if` and are sent as JSON `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation turns, oldest first.
        pub contents: Vec<Content>,
        /// Tool definitions as raw JSON values (not the typed [`Tool`]).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Vec<Value>>,
        /// Function-calling mode configuration.
        pub tool_config: Option<ToolConfig>,
        /// Sampling and output settings.
        pub generation_config: Option<GenerationConfig>,
        /// Per-category safety thresholds.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System prompt, sent separately from `contents`.
        pub system_instruction: Option<Content>,
        /// Extra provider parameters merged (flattened) into the top level of the body.
        /// NOTE(review): serde `flatten` expects a JSON object here — a non-object value
        /// presumably fails serialization; confirm callers only pass objects.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }

    /// Typed tool definition: function declarations plus optional code execution.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions the model may call.
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Enables code execution when `Some`; serialized as `null` when `None`.
        pub code_execution: Option<CodeExecution>,
    }

    /// Declaration of one callable function.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name the model uses in its `FunctionCall`.
        pub name: String,
        /// Description shown to the model.
        pub description: String,
        /// Parameter schema; omitted entirely for argument-less tools.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }

    /// Tool-use configuration (`toolConfig`).
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Serialized as `functionCallingConfig`; `null` when `None`.
        pub function_calling_config: Option<FunctionCallingMode>,
    }

    /// Function-calling mode, internally tagged by `mode`.
    ///
    /// Serializes as `{"mode": "AUTO"}`, `{"mode": "NONE"}` or
    /// `{"mode": "ANY", "allowed_function_names": [...]}`. The inner field keeps its
    /// snake_case name, because `rename_all` only renames variants.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// The model decides whether to call a function (the default).
        #[default]
        Auto,
        /// Function calling disabled.
        None,
        /// The model must call a function, optionally restricted to the listed names.
        Any {
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
2122
2123 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
2124 type Error = CompletionError;
2125 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
2126 let res = match value {
2127 message::ToolChoice::Auto => Self::Auto,
2128 message::ToolChoice::None => Self::None,
2129 message::ToolChoice::Required => Self::Any {
2130 allowed_function_names: None,
2131 },
2132 message::ToolChoice::Specific { function_names } => Self::Any {
2133 allowed_function_names: Some(function_names),
2134 },
2135 };
2136
2137 Ok(res)
2138 }
2139 }
2140
    /// Marker enabling the code-execution tool; serializes as an empty object `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}

    /// Blocking threshold for one harm category.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Category this setting applies to.
        pub category: HarmCategory,
        /// Minimum probability at which content is blocked.
        pub threshold: HarmBlockThreshold,
    }

    /// Safety blocking threshold (wire format SCREAMING_SNAKE_CASE, e.g.
    /// `BLOCK_MEDIUM_AND_ABOVE`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
2161}
2162
2163#[cfg(test)]
2164mod tests {
2165 use crate::{
2166 message,
2167 providers::gemini::completion::gemini_api_types::{
2168 ContentCandidate, FinishReason, FunctionCall, Schema, UsageMetadata, flatten_schema,
2169 tool_parameters_to_schema,
2170 },
2171 };
2172
2173 use super::*;
2174 use serde_json::json;
2175
2176 #[test]
2177 fn test_resolve_request_model_uses_override() {
2178 let request = CompletionRequest {
2179 model: Some("gemini-2.5-flash".to_string()),
2180 preamble: None,
2181 chat_history: crate::OneOrMany::one("Hello".into()),
2182 documents: vec![],
2183 tools: vec![],
2184 temperature: None,
2185 max_tokens: None,
2186 tool_choice: None,
2187 additional_params: None,
2188 output_schema: None,
2189 };
2190
2191 let request_model = resolve_request_model("gemini-2.0-flash", &request);
2192 assert_eq!(request_model, "gemini-2.5-flash");
2193 assert_eq!(
2194 completion_endpoint(&request_model),
2195 "/v1beta/models/gemini-2.5-flash:generateContent"
2196 );
2197 assert_eq!(
2198 streaming_endpoint(&request_model),
2199 "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
2200 );
2201 }
2202
2203 #[test]
2204 fn test_resolve_request_model_uses_default_when_unset() {
2205 let request = CompletionRequest {
2206 model: None,
2207 preamble: None,
2208 chat_history: crate::OneOrMany::one("Hello".into()),
2209 documents: vec![],
2210 tools: vec![],
2211 temperature: None,
2212 max_tokens: None,
2213 tool_choice: None,
2214 additional_params: None,
2215 output_schema: None,
2216 };
2217
2218 assert_eq!(
2219 resolve_request_model("gemini-2.0-flash", &request),
2220 "gemini-2.0-flash"
2221 );
2222 }
2223
2224 #[test]
2225 fn test_deserialize_message_user() {
2226 let raw_message = r#"{
2227 "parts": [
2228 {"text": "Hello, world!"},
2229 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
2230 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
2231 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
2232 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
2233 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
2234 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
2235 ],
2236 "role": "user"
2237 }"#;
2238
2239 let content: Content = {
2240 let jd = &mut serde_json::Deserializer::from_str(raw_message);
2241 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
2242 panic!("Deserialization error at {}: {}", err.path(), err);
2243 })
2244 };
2245 assert_eq!(content.role, Some(Role::User));
2246 assert_eq!(content.parts.len(), 7);
2247
2248 let parts: Vec<Part> = content.parts.into_iter().collect();
2249
2250 if let Part {
2251 part: PartKind::Text(text),
2252 ..
2253 } = &parts[0]
2254 {
2255 assert_eq!(text, "Hello, world!");
2256 } else {
2257 panic!("Expected text part");
2258 }
2259
2260 if let Part {
2261 part: PartKind::InlineData(inline_data),
2262 ..
2263 } = &parts[1]
2264 {
2265 assert_eq!(inline_data.mime_type, "image/png");
2266 assert_eq!(inline_data.data, "base64encodeddata");
2267 } else {
2268 panic!("Expected inline data part");
2269 }
2270
2271 if let Part {
2272 part: PartKind::FunctionCall(function_call),
2273 ..
2274 } = &parts[2]
2275 {
2276 assert_eq!(function_call.name, "test_function");
2277 assert_eq!(
2278 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2279 "value1"
2280 );
2281 } else {
2282 panic!("Expected function call part");
2283 }
2284
2285 if let Part {
2286 part: PartKind::FunctionResponse(function_response),
2287 ..
2288 } = &parts[3]
2289 {
2290 assert_eq!(function_response.name, "test_function");
2291 assert_eq!(
2292 function_response
2293 .response
2294 .as_ref()
2295 .unwrap()
2296 .get("result")
2297 .unwrap(),
2298 "success"
2299 );
2300 } else {
2301 panic!("Expected function response part");
2302 }
2303
2304 if let Part {
2305 part: PartKind::FileData(file_data),
2306 ..
2307 } = &parts[4]
2308 {
2309 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
2310 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
2311 } else {
2312 panic!("Expected file data part");
2313 }
2314
2315 if let Part {
2316 part: PartKind::ExecutableCode(executable_code),
2317 ..
2318 } = &parts[5]
2319 {
2320 assert_eq!(executable_code.code, "print('Hello, world!')");
2321 } else {
2322 panic!("Expected executable code part");
2323 }
2324
2325 if let Part {
2326 part: PartKind::CodeExecutionResult(code_execution_result),
2327 ..
2328 } = &parts[6]
2329 {
2330 assert_eq!(
2331 code_execution_result.clone().output.unwrap(),
2332 "Hello, world!"
2333 );
2334 } else {
2335 panic!("Expected code execution result part");
2336 }
2337 }
2338
2339 #[test]
2340 fn test_deserialize_message_model() {
2341 let json_data = json!({
2342 "parts": [{"text": "Hello, user!"}],
2343 "role": "model"
2344 });
2345
2346 let content: Content = serde_json::from_value(json_data).unwrap();
2347 assert_eq!(content.role, Some(Role::Model));
2348 assert_eq!(content.parts.len(), 1);
2349 if let Some(Part {
2350 part: PartKind::Text(text),
2351 ..
2352 }) = content.parts.first()
2353 {
2354 assert_eq!(text, "Hello, user!");
2355 } else {
2356 panic!("Expected text part");
2357 }
2358 }
2359
2360 #[test]
2361 fn test_message_conversion_user() {
2362 let msg = message::Message::user("Hello, world!");
2363 let content: Content = msg.try_into().unwrap();
2364 assert_eq!(content.role, Some(Role::User));
2365 assert_eq!(content.parts.len(), 1);
2366 if let Some(Part {
2367 part: PartKind::Text(text),
2368 ..
2369 }) = &content.parts.first()
2370 {
2371 assert_eq!(text, "Hello, world!");
2372 } else {
2373 panic!("Expected text part");
2374 }
2375 }
2376
2377 #[test]
2378 fn test_message_conversion_model() {
2379 let msg = message::Message::assistant("Hello, user!");
2380
2381 let content: Content = msg.try_into().unwrap();
2382 assert_eq!(content.role, Some(Role::Model));
2383 assert_eq!(content.parts.len(), 1);
2384 if let Some(Part {
2385 part: PartKind::Text(text),
2386 ..
2387 }) = &content.parts.first()
2388 {
2389 assert_eq!(text, "Hello, user!");
2390 } else {
2391 panic!("Expected text part");
2392 }
2393 }
2394
2395 #[test]
2396 fn test_thought_signature_is_preserved_from_response_reasoning_part() {
2397 let response = GenerateContentResponse {
2398 response_id: "resp_1".to_string(),
2399 candidates: vec![ContentCandidate {
2400 content: Some(Content {
2401 parts: vec![Part {
2402 thought: Some(true),
2403 thought_signature: Some("thought_sig_123".to_string()),
2404 part: PartKind::Text("thinking text".to_string()),
2405 additional_params: None,
2406 }],
2407 role: Some(Role::Model),
2408 }),
2409 finish_reason: Some(FinishReason::Stop),
2410 safety_ratings: None,
2411 citation_metadata: None,
2412 token_count: None,
2413 avg_logprobs: None,
2414 logprobs_result: None,
2415 index: Some(0),
2416 finish_message: None,
2417 }],
2418 prompt_feedback: None,
2419 usage_metadata: None,
2420 model_version: None,
2421 };
2422
2423 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2424 response.try_into().expect("convert response");
2425 let first = converted.choice.first();
2426 assert!(matches!(
2427 first,
2428 message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2429 if matches!(
2430 content.first(),
2431 Some(message::ReasoningContent::Text {
2432 text,
2433 signature: Some(signature)
2434 }) if text == "thinking text" && signature == "thought_sig_123"
2435 )
2436 ));
2437 }
2438
2439 #[test]
2440 fn test_tool_protocol_finish_reason_returns_response_error() {
2441 for (reason, finish_message) in [
2442 (
2443 FinishReason::MalformedFunctionCall,
2444 "malformed function call: default_api",
2445 ),
2446 (
2447 FinishReason::UnexpectedToolCall,
2448 "unexpected tool call: default_api",
2449 ),
2450 (
2451 FinishReason::MissingThoughtSignature,
2452 "missing thought signature for tool call",
2453 ),
2454 (
2455 FinishReason::TooManyToolCalls,
2456 "too many tool calls in response",
2457 ),
2458 (
2459 FinishReason::MalformedResponse,
2460 "malformed response from provider",
2461 ),
2462 ] {
2463 let reason_name = format!("{reason:?}");
2464 let response = GenerateContentResponse {
2465 response_id: "resp_tool_protocol_error".to_string(),
2466 candidates: vec![ContentCandidate {
2467 content: Some(Content {
2468 parts: vec![Part {
2469 thought: None,
2470 thought_signature: None,
2471 part: PartKind::FunctionCall(FunctionCall {
2472 name: "default_api".to_string(),
2473 args: json!({"x": 1}),
2474 }),
2475 additional_params: None,
2476 }],
2477 role: Some(Role::Model),
2478 }),
2479 finish_reason: Some(reason),
2480 safety_ratings: None,
2481 citation_metadata: None,
2482 token_count: None,
2483 avg_logprobs: None,
2484 logprobs_result: None,
2485 index: Some(0),
2486 finish_message: Some(finish_message.to_string()),
2487 }],
2488 prompt_feedback: None,
2489 usage_metadata: None,
2490 model_version: None,
2491 };
2492
2493 let err = crate::completion::CompletionResponse::<GenerateContentResponse>::try_from(
2494 response,
2495 )
2496 .expect_err("tool protocol finish reason should fail");
2497
2498 assert!(matches!(
2499 err,
2500 CompletionError::ResponseError(message)
2501 if message.contains(&reason_name)
2502 && message.contains(finish_message)
2503 ));
2504 }
2505 }
2506
2507 #[test]
2508 fn test_completion_response_usage_preserves_cached_and_reasoning_tokens() {
2509 let response = GenerateContentResponse {
2510 response_id: "resp_1".to_string(),
2511 candidates: vec![ContentCandidate {
2512 content: Some(Content {
2513 parts: vec![Part {
2514 thought: None,
2515 thought_signature: None,
2516 part: PartKind::Text("answer".to_string()),
2517 additional_params: None,
2518 }],
2519 role: Some(Role::Model),
2520 }),
2521 finish_reason: Some(FinishReason::Stop),
2522 safety_ratings: None,
2523 citation_metadata: None,
2524 token_count: None,
2525 avg_logprobs: None,
2526 logprobs_result: None,
2527 index: Some(0),
2528 finish_message: None,
2529 }],
2530 prompt_feedback: None,
2531 usage_metadata: Some(UsageMetadata {
2532 prompt_token_count: 40,
2533 cached_content_token_count: Some(20),
2534 candidates_token_count: Some(30),
2535 total_token_count: 100,
2536 thoughts_token_count: Some(10),
2537 prompt_tokens_details: None,
2538 cache_tokens_details: None,
2539 candidates_tokens_details: None,
2540 tool_use_prompt_token_count: Some(12),
2541 tool_use_prompt_tokens_details: None,
2542 traffic_type: None,
2543 }),
2544 model_version: Some("gemini-2.0-flash-001".to_string()),
2545 };
2546
2547 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2548 response.try_into().expect("convert response");
2549
2550 assert_eq!(converted.usage.input_tokens, 40);
2551 assert_eq!(converted.usage.cached_input_tokens, 20);
2552 assert_eq!(converted.usage.output_tokens, 30);
2553 assert_eq!(converted.usage.reasoning_tokens, 10);
2554 assert_eq!(converted.usage.tool_use_prompt_tokens, 12);
2555 assert_eq!(converted.usage.total_tokens, 100);
2556 }
2557
2558 #[test]
2559 fn test_reasoning_signature_is_emitted_in_gemini_part() {
2560 let msg = message::Message::Assistant {
2561 id: None,
2562 content: OneOrMany::one(message::AssistantContent::Reasoning(
2563 message::Reasoning::new_with_signature(
2564 "structured thought",
2565 Some("reuse_sig_456".to_string()),
2566 ),
2567 )),
2568 };
2569
2570 let converted: Content = msg.try_into().expect("convert message");
2571 let first = converted.parts.first().expect("reasoning part");
2572 assert_eq!(first.thought, Some(true));
2573 assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
2574 assert!(matches!(
2575 &first.part,
2576 PartKind::Text(text) if text == "structured thought"
2577 ));
2578 }
2579
2580 #[test]
2581 fn test_message_conversion_tool_call() {
2582 let tool_call = message::ToolCall {
2583 id: "test_tool".to_string(),
2584 call_id: None,
2585 function: message::ToolFunction {
2586 name: "test_function".to_string(),
2587 arguments: json!({"arg1": "value1"}),
2588 },
2589 signature: None,
2590 additional_params: None,
2591 };
2592
2593 let msg = message::Message::Assistant {
2594 id: None,
2595 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2596 };
2597
2598 let content: Content = msg.try_into().unwrap();
2599 assert_eq!(content.role, Some(Role::Model));
2600 assert_eq!(content.parts.len(), 1);
2601 if let Some(Part {
2602 part: PartKind::FunctionCall(function_call),
2603 ..
2604 }) = content.parts.first()
2605 {
2606 assert_eq!(function_call.name, "test_function");
2607 assert_eq!(
2608 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2609 "value1"
2610 );
2611 } else {
2612 panic!("Expected function call part");
2613 }
2614 }
2615
2616 #[test]
2617 fn test_vec_schema_conversion() {
2618 let schema_with_ref = json!({
2619 "type": "array",
2620 "items": {
2621 "$ref": "#/$defs/Person"
2622 },
2623 "$defs": {
2624 "Person": {
2625 "type": "object",
2626 "properties": {
2627 "first_name": {
2628 "type": ["string", "null"],
2629 "description": "The person's first name, if provided (null otherwise)"
2630 },
2631 "last_name": {
2632 "type": ["string", "null"],
2633 "description": "The person's last name, if provided (null otherwise)"
2634 },
2635 "job": {
2636 "type": ["string", "null"],
2637 "description": "The person's job, if provided (null otherwise)"
2638 }
2639 },
2640 "required": []
2641 }
2642 }
2643 });
2644
2645 let result: Result<Schema, _> = schema_with_ref.try_into();
2646
2647 match result {
2648 Ok(schema) => {
2649 assert_eq!(schema.r#type, "array");
2650
2651 if let Some(items) = schema.items {
2652 println!("item types: {}", items.r#type);
2653
2654 assert_ne!(items.r#type, "", "Items type should not be empty string!");
2655 assert_eq!(items.r#type, "object", "Items should be object type");
2656 } else {
2657 panic!("Schema should have items field for array type");
2658 }
2659 }
2660 Err(e) => println!("Schema conversion failed: {:?}", e),
2661 }
2662 }
2663
2664 #[test]
2665 fn test_object_schema() {
2666 let simple_schema = json!({
2667 "type": "object",
2668 "properties": {
2669 "name": {
2670 "type": "string"
2671 }
2672 }
2673 });
2674
2675 let schema: Schema = simple_schema.try_into().unwrap();
2676 assert_eq!(schema.r#type, "object");
2677 assert!(schema.properties.is_some());
2678 }
2679
2680 #[test]
2681 fn test_array_with_inline_items() {
2682 let inline_schema = json!({
2683 "type": "array",
2684 "items": {
2685 "type": "object",
2686 "properties": {
2687 "name": {
2688 "type": "string"
2689 }
2690 }
2691 }
2692 });
2693
2694 let schema: Schema = inline_schema.try_into().unwrap();
2695 assert_eq!(schema.r#type, "array");
2696
2697 if let Some(items) = schema.items {
2698 assert_eq!(items.r#type, "object");
2699 assert!(items.properties.is_some());
2700 } else {
2701 panic!("Schema should have items field");
2702 }
2703 }
2704 #[test]
2705 fn test_flattened_schema() {
2706 let ref_schema = json!({
2707 "type": "array",
2708 "items": {
2709 "$ref": "#/$defs/Person"
2710 },
2711 "$defs": {
2712 "Person": {
2713 "type": "object",
2714 "properties": {
2715 "name": { "type": "string" }
2716 }
2717 }
2718 }
2719 });
2720
2721 let flattened = flatten_schema(ref_schema).unwrap();
2722 let schema: Schema = flattened.try_into().unwrap();
2723
2724 assert_eq!(schema.r#type, "array");
2725
2726 if let Some(items) = schema.items {
2727 println!("Flattened items type: '{}'", items.r#type);
2728
2729 assert_eq!(items.r#type, "object");
2730 assert!(items.properties.is_some());
2731 }
2732 }
2733
2734 #[test]
2735 fn test_array_without_items_gets_default() {
2736 let schema_json = json!({
2737 "type": "object",
2738 "properties": {
2739 "service_ids": {
2740 "type": "array",
2741 "description": "A list of service IDs"
2742 }
2743 }
2744 });
2745
2746 let schema: Schema = schema_json.try_into().unwrap();
2747 let props = schema.properties.unwrap();
2748 let service_ids = props.get("service_ids").unwrap();
2749 assert_eq!(service_ids.r#type, "array");
2750 let items = service_ids
2751 .items
2752 .as_ref()
2753 .expect("array schema missing items should get a default");
2754 assert_eq!(items.r#type, "string");
2755 }
2756
2757 #[test]
2758 fn test_tool_parameters_to_schema_maps_no_arg_tool_to_none() {
2759 let schema = tool_parameters_to_schema(json!({"type": "object", "properties": {}}))
2760 .expect("schema conversion");
2761
2762 assert!(schema.is_none());
2763 }
2764
2765 #[test]
2766 fn test_tool_parameters_to_schema_resolves_defs_ref() {
2767 let schema_json = json!({
2768 "type": "object",
2769 "properties": {
2770 "destination": { "$ref": "#/$defs/Destination" }
2771 },
2772 "required": ["destination"],
2773 "$defs": {
2774 "Destination": {
2775 "type": "object",
2776 "properties": {
2777 "city": { "type": "string" }
2778 },
2779 "required": ["city"]
2780 }
2781 }
2782 });
2783
2784 let schema = tool_parameters_to_schema(schema_json)
2785 .expect("schema conversion")
2786 .expect("schema");
2787 let props = schema.properties.expect("properties");
2788 let destination = props.get("destination").expect("destination prop");
2789
2790 assert_eq!(destination.r#type, "object");
2791 assert_eq!(destination.required, Some(vec!["city".to_string()]));
2792 }
2793
2794 #[test]
2795 fn test_tool_parameters_to_schema_handles_nullable_type_arrays() {
2796 let schema_json = json!({
2797 "type": "object",
2798 "properties": {
2799 "nickname": { "type": ["null", "string"] }
2800 }
2801 });
2802
2803 let schema = tool_parameters_to_schema(schema_json)
2804 .expect("schema conversion")
2805 .expect("schema");
2806 let props = schema.properties.expect("properties");
2807 let nickname = props.get("nickname").expect("nickname prop");
2808
2809 assert_eq!(nickname.r#type, "string");
2810 assert_eq!(nickname.nullable, Some(true));
2811 }
2812
2813 #[test]
2814 fn test_txt_document_conversion_to_text_part() {
2815 use crate::message::{DocumentMediaType, UserContent};
2817
2818 let doc = UserContent::document(
2819 "Note: test.md\nPath: /test.md\nContent: Hello World!",
2820 Some(DocumentMediaType::TXT),
2821 );
2822
2823 let content: Content = message::Message::User {
2824 content: crate::OneOrMany::one(doc),
2825 }
2826 .try_into()
2827 .unwrap();
2828
2829 if let Part {
2830 part: PartKind::Text(text),
2831 ..
2832 } = &content.parts[0]
2833 {
2834 assert!(text.contains("Note: test.md"));
2835 assert!(text.contains("Hello World!"));
2836 } else {
2837 panic!(
2838 "Expected text part for TXT document, got: {:?}",
2839 content.parts[0]
2840 );
2841 }
2842 }
2843
2844 #[test]
2845 fn test_tool_result_with_image_content() {
2846 use crate::OneOrMany;
2848 use crate::message::{
2849 DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2850 };
2851
2852 let tool_result = ToolResult {
2854 id: "test_tool".to_string(),
2855 call_id: None,
2856 content: OneOrMany::many(vec![
2857 ToolResultContent::Text(message::Text::new(r#"{"status": "success"}"#.to_string())),
2858 ToolResultContent::Image(Image {
2859 data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
2860 media_type: Some(ImageMediaType::PNG),
2861 detail: None,
2862 additional_params: None,
2863 }),
2864 ]).expect("Should create OneOrMany with multiple items"),
2865 };
2866
2867 let user_content = message::UserContent::ToolResult(tool_result);
2868 let msg = message::Message::User {
2869 content: OneOrMany::one(user_content),
2870 };
2871
2872 let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2874 assert_eq!(content.role, Some(Role::User));
2875 assert_eq!(content.parts.len(), 1);
2876
2877 if let Some(Part {
2879 part: PartKind::FunctionResponse(function_response),
2880 ..
2881 }) = content.parts.first()
2882 {
2883 assert_eq!(function_response.name, "test_tool");
2884
2885 assert!(function_response.response.is_some());
2887 let response = function_response.response.as_ref().unwrap();
2888 assert!(response.get("result").is_some());
2889
2890 assert!(function_response.parts.is_some());
2892 let parts = function_response.parts.as_ref().unwrap();
2893 assert_eq!(parts.len(), 1);
2894
2895 let image_part = &parts[0];
2896 assert!(image_part.inline_data.is_some());
2897 let inline_data = image_part.inline_data.as_ref().unwrap();
2898 assert_eq!(inline_data.mime_type, "image/png");
2899 assert!(!inline_data.data.is_empty());
2900 } else {
2901 panic!("Expected FunctionResponse part");
2902 }
2903 }
2904
2905 #[test]
2906 fn test_markdown_document_conversion_to_text_part() {
2907 use crate::message::{DocumentMediaType, UserContent};
2909
2910 let doc = UserContent::document(
2911 "# Heading\n\n* List item",
2912 Some(DocumentMediaType::MARKDOWN),
2913 );
2914
2915 let content: Content = message::Message::User {
2916 content: crate::OneOrMany::one(doc),
2917 }
2918 .try_into()
2919 .unwrap();
2920
2921 if let Part {
2922 part: PartKind::Text(text),
2923 ..
2924 } = &content.parts[0]
2925 {
2926 assert_eq!(text, "# Heading\n\n* List item");
2927 } else {
2928 panic!(
2929 "Expected text part for MARKDOWN document, got: {:?}",
2930 content.parts[0]
2931 );
2932 }
2933 }
2934
2935 #[test]
2936 fn test_markdown_url_document_conversion_to_file_data_part() {
2937 use crate::message::{DocumentMediaType, DocumentSourceKind, UserContent};
2939
2940 let doc = UserContent::Document(message::Document {
2941 data: DocumentSourceKind::Url(
2942 "https://generativelanguage.googleapis.com/v1beta/files/test-markdown".to_string(),
2943 ),
2944 media_type: Some(DocumentMediaType::MARKDOWN),
2945 additional_params: None,
2946 });
2947
2948 let content: Content = message::Message::User {
2949 content: crate::OneOrMany::one(doc),
2950 }
2951 .try_into()
2952 .unwrap();
2953
2954 if let Part {
2955 part: PartKind::FileData(file_data),
2956 ..
2957 } = &content.parts[0]
2958 {
2959 assert_eq!(
2960 file_data.file_uri,
2961 "https://generativelanguage.googleapis.com/v1beta/files/test-markdown"
2962 );
2963 assert_eq!(file_data.mime_type.as_deref(), Some("text/markdown"));
2964 } else {
2965 panic!(
2966 "Expected file_data part for URL MARKDOWN document, got: {:?}",
2967 content.parts[0]
2968 );
2969 }
2970 }
2971
2972 #[test]
2973 fn test_tool_result_with_url_image() {
2974 use crate::OneOrMany;
2976 use crate::message::{
2977 DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2978 };
2979
2980 let tool_result = ToolResult {
2981 id: "screenshot_tool".to_string(),
2982 call_id: None,
2983 content: OneOrMany::one(ToolResultContent::Image(Image {
2984 data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
2985 media_type: Some(ImageMediaType::PNG),
2986 detail: None,
2987 additional_params: None,
2988 })),
2989 };
2990
2991 let user_content = message::UserContent::ToolResult(tool_result);
2992 let msg = message::Message::User {
2993 content: OneOrMany::one(user_content),
2994 };
2995
2996 let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2997 assert_eq!(content.role, Some(Role::User));
2998 assert_eq!(content.parts.len(), 1);
2999
3000 if let Some(Part {
3001 part: PartKind::FunctionResponse(function_response),
3002 ..
3003 }) = content.parts.first()
3004 {
3005 assert_eq!(function_response.name, "screenshot_tool");
3006
3007 assert!(function_response.parts.is_some());
3009 let parts = function_response.parts.as_ref().unwrap();
3010 assert_eq!(parts.len(), 1);
3011
3012 let image_part = &parts[0];
3013 assert!(image_part.file_data.is_some());
3014 let file_data = image_part.file_data.as_ref().unwrap();
3015 assert_eq!(file_data.file_uri, "https://example.com/image.png");
3016 assert_eq!(file_data.mime_type.as_ref().unwrap(), "image/png");
3017 } else {
3018 panic!("Expected FunctionResponse part");
3019 }
3020 }
3021
3022 #[test]
3023 fn test_create_request_body_with_documents() {
3024 use crate::OneOrMany;
3026 use crate::completion::request::{CompletionRequest, Document};
3027 use crate::message::Message;
3028
3029 let documents = vec![
3030 Document {
3031 id: "doc1".to_string(),
3032 text: "Note: first.md\nContent: First note".to_string(),
3033 additional_props: std::collections::HashMap::new(),
3034 },
3035 Document {
3036 id: "doc2".to_string(),
3037 text: "Note: second.md\nContent: Second note".to_string(),
3038 additional_props: std::collections::HashMap::new(),
3039 },
3040 ];
3041
3042 let documents_message = CompletionRequest {
3043 preamble: None,
3044 chat_history: OneOrMany::one(Message::user("placeholder")),
3045 documents,
3046 tools: vec![],
3047 temperature: None,
3048 model: None,
3049 output_schema: None,
3050 max_tokens: None,
3051 tool_choice: None,
3052 additional_params: None,
3053 }
3054 .normalized_documents()
3055 .unwrap();
3056
3057 let completion_request = CompletionRequest {
3058 preamble: Some("You are a helpful assistant".to_string()),
3059 chat_history: OneOrMany::many(vec![
3060 documents_message,
3061 Message::user("What are my notes about?"),
3062 ])
3063 .unwrap(),
3064 documents: vec![],
3065 tools: vec![],
3066 temperature: None,
3067 model: None,
3068 output_schema: None,
3069 max_tokens: None,
3070 tool_choice: None,
3071 additional_params: None,
3072 };
3073
3074 let request = create_request_body(completion_request).unwrap();
3075
3076 assert_eq!(
3078 request.contents.len(),
3079 2,
3080 "Expected 2 contents (documents + user message)"
3081 );
3082
3083 assert_eq!(request.contents[0].role, Some(Role::User));
3085 assert_eq!(
3086 request.contents[0].parts.len(),
3087 2,
3088 "Expected 2 document parts"
3089 );
3090
3091 for part in &request.contents[0].parts {
3093 if let Part {
3094 part: PartKind::Text(text),
3095 ..
3096 } = part
3097 {
3098 assert!(
3099 text.contains("Note:") && text.contains("Content:"),
3100 "Document should contain note metadata"
3101 );
3102 } else {
3103 panic!("Document parts should be text, not {:?}", part);
3104 }
3105 }
3106
3107 assert_eq!(request.contents[1].role, Some(Role::User));
3109 if let Part {
3110 part: PartKind::Text(text),
3111 ..
3112 } = &request.contents[1].parts[0]
3113 {
3114 assert_eq!(text, "What are my notes about?");
3115 } else {
3116 panic!("Expected user message to be text");
3117 }
3118 }
3119
3120 #[test]
3121 fn test_create_request_body_without_documents() {
3122 use crate::OneOrMany;
3124 use crate::completion::request::CompletionRequest;
3125 use crate::message::Message;
3126
3127 let completion_request = CompletionRequest {
3128 preamble: Some("You are a helpful assistant".to_string()),
3129 chat_history: OneOrMany::one(Message::user("Hello")),
3130 documents: vec![], tools: vec![],
3132 temperature: None,
3133 max_tokens: None,
3134 tool_choice: None,
3135 model: None,
3136 output_schema: None,
3137 additional_params: None,
3138 };
3139
3140 let request = create_request_body(completion_request).unwrap();
3141
3142 assert_eq!(request.contents.len(), 1, "Expected only user message");
3144 assert_eq!(request.contents[0].role, Some(Role::User));
3145
3146 if let Part {
3147 part: PartKind::Text(text),
3148 ..
3149 } = &request.contents[0].parts[0]
3150 {
3151 assert_eq!(text, "Hello");
3152 } else {
3153 panic!("Expected user message to be text");
3154 }
3155 }
3156
3157 #[test]
3158 fn test_from_tool_output_parses_image_json() {
3159 use crate::message::{DocumentSourceKind, ToolResultContent};
3161
3162 let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
3164 let result = ToolResultContent::from_tool_output(image_json);
3165
3166 assert_eq!(result.len(), 1);
3167 if let ToolResultContent::Image(img) = result.first() {
3168 assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
3169 if let DocumentSourceKind::Base64(data) = &img.data {
3170 assert_eq!(data, "base64data==");
3171 }
3172 assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
3173 } else {
3174 panic!("Expected Image content");
3175 }
3176 }
3177
3178 #[test]
3179 fn test_from_tool_output_parses_hybrid_json() {
3180 use crate::message::{DocumentSourceKind, ToolResultContent};
3182
3183 let hybrid_json = r#"{
3184 "response": {"status": "ok", "count": 42},
3185 "parts": [
3186 {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
3187 {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
3188 ]
3189 }"#;
3190
3191 let result = ToolResultContent::from_tool_output(hybrid_json);
3192
3193 assert_eq!(result.len(), 3);
3195
3196 let items: Vec<_> = result.iter().collect();
3197
3198 if let ToolResultContent::Text(text) = &items[0] {
3200 assert!(text.text.contains("status"));
3201 assert!(text.text.contains("ok"));
3202 } else {
3203 panic!("Expected Text content first");
3204 }
3205
3206 if let ToolResultContent::Image(img) = &items[1] {
3208 assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
3209 } else {
3210 panic!("Expected Image content second");
3211 }
3212
3213 if let ToolResultContent::Image(img) = &items[2] {
3215 assert!(matches!(img.data, DocumentSourceKind::Url(_)));
3216 } else {
3217 panic!("Expected Image content third");
3218 }
3219 }
3220
3221 #[tokio::test]
3225 #[ignore = "requires GEMINI_API_KEY environment variable"]
3226 async fn test_gemini_agent_with_image_tool_result_e2e() -> anyhow::Result<()> {
3227 use crate::completion::Prompt;
3228 use crate::prelude::*;
3229 use crate::providers::gemini;
3230 use crate::test_utils::MockImageGeneratorTool;
3231
3232 let client = gemini::Client::from_env()?;
3233
3234 let agent = client
3235 .agent("gemini-3-flash-preview")
3236 .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
3237 .tool(MockImageGeneratorTool)
3238 .build();
3239
3240 let response_text = agent
3242 .prompt("Please generate a test image and tell me what color the pixel is.")
3243 .await?;
3244 println!("Response: {response_text}");
3245 anyhow::ensure!(!response_text.is_empty(), "Response should not be empty");
3247
3248 Ok(())
3249 }
3250}