/// Model name string `gemini-3.1-flash-lite-preview`.
pub const GEMINI_3_1_FLASH_LITE_PREVIEW: &str = "gemini-3.1-flash-lite-preview";
/// Model name string `gemini-3-flash-preview`.
pub const GEMINI_3_FLASH_PREVIEW: &str = "gemini-3-flash-preview";
/// Model name string `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model name string `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model name string `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model name string `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model name string `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model name string `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model name string `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model name string `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
25
26use self::gemini_api_types::tool_parameters_to_schema;
27use crate::http_client::HttpClientExt;
28use crate::message::{self, MimeType, Reasoning};
29
30use crate::providers::gemini::completion::gemini_api_types::{
31 AdditionalParameters, FunctionCallingMode, ToolConfig,
32};
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::telemetry::SpanCombinator;
35use crate::{
36 OneOrMany,
37 completion::{self, CompletionError, CompletionRequest, GetTokenUsage},
38};
39use gemini_api_types::{
40 Content, FinishReason, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
41 GenerationConfig, Part, PartKind, Role, Tool,
42};
43use serde_json::{Map, Value};
44use std::convert::TryFrom;
45use tracing::{Level, enabled, info_span};
46use tracing_futures::Instrument;
47
48use super::Client;
49
/// Gemini completion model handle: a provider client paired with a model name.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Provider client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name used when a request does not name its own model.
    pub model: String,
}
59
60impl<T> CompletionModel<T> {
61 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
62 Self {
63 client,
64 model: model.into(),
65 }
66 }
67
68 pub fn with_model(client: Client<T>, model: &str) -> Self {
69 Self {
70 client,
71 model: model.into(),
72 }
73 }
74}
75
76impl<T> completion::CompletionModel for CompletionModel<T>
77where
78 T: HttpClientExt + Clone + 'static,
79{
80 type Response = GenerateContentResponse;
81 type StreamingResponse = StreamingCompletionResponse;
82 type Client = super::Client<T>;
83
84 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
85 Self::new(client.clone(), model)
86 }
87
88 async fn completion(
89 &self,
90 completion_request: CompletionRequest,
91 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
92 let request_model = resolve_request_model(&self.model, &completion_request);
93 let span = if tracing::Span::current().is_disabled() {
94 info_span!(
95 target: "rig::completions",
96 "generate_content",
97 gen_ai.operation.name = "generate_content",
98 gen_ai.provider.name = "gcp.gemini",
99 gen_ai.request.model = &request_model,
100 gen_ai.system_instructions = &completion_request.preamble,
101 gen_ai.response.id = tracing::field::Empty,
102 gen_ai.response.model = tracing::field::Empty,
103 gen_ai.usage.output_tokens = tracing::field::Empty,
104 gen_ai.usage.input_tokens = tracing::field::Empty,
105 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
106 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
107 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
108 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
109 )
110 } else {
111 tracing::Span::current()
112 };
113
114 let request = create_request_body(completion_request)?;
115
116 if enabled!(Level::TRACE) {
117 tracing::trace!(
118 target: "rig::completions",
119 "Gemini completion request: {}",
120 serde_json::to_string_pretty(&request)?
121 );
122 }
123
124 let body = serde_json::to_vec(&request)?;
125
126 let path = completion_endpoint(&request_model);
127
128 let request = self
129 .client
130 .post(path.as_str())?
131 .body(body)
132 .map_err(|e| CompletionError::HttpError(e.into()))?;
133
134 async move {
135 let response = self.client.send::<_, Vec<u8>>(request).await?;
136
137 if response.status().is_success() {
138 let response_body = response
139 .into_body()
140 .await
141 .map_err(CompletionError::HttpError)?;
142
143 let response_text = String::from_utf8_lossy(&response_body).to_string();
144
145 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
146 .map_err(|err| {
147 tracing::error!(
148 error = %err,
149 body = %response_text,
150 "Failed to deserialize Gemini completion response"
151 );
152 CompletionError::JsonError(err)
153 })?;
154
155 let span = tracing::Span::current();
156 span.record_response_metadata(&response);
157 span.record_token_usage(&response.usage_metadata);
158
159 if enabled!(Level::TRACE) {
160 tracing::trace!(
161 target: "rig::completions",
162 "Gemini completion response: {}",
163 serde_json::to_string_pretty(&response)?
164 );
165 }
166
167 response.try_into()
168 } else {
169 let text = String::from_utf8_lossy(
170 &response
171 .into_body()
172 .await
173 .map_err(CompletionError::HttpError)?,
174 )
175 .into();
176
177 Err(CompletionError::ProviderError(text))
178 }
179 }
180 .instrument(span)
181 .await
182 }
183
184 async fn stream(
185 &self,
186 request: CompletionRequest,
187 ) -> Result<
188 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
189 CompletionError,
190 > {
191 CompletionModel::stream(self, request).await
192 }
193}
194
195pub(crate) fn create_request_body(
196 completion_request: CompletionRequest,
197) -> Result<GenerateContentRequest, CompletionError> {
198 let documents_message = completion_request.normalized_documents();
199
200 let CompletionRequest {
201 model: _,
202 preamble,
203 chat_history,
204 documents: _,
205 tools: function_tools,
206 temperature,
207 max_tokens,
208 tool_choice,
209 mut additional_params,
210 output_schema,
211 } = completion_request;
212
213 let mut full_history = Vec::new();
214 if let Some(msg) = documents_message {
215 full_history.push(msg);
216 }
217 full_history.extend(chat_history);
218 let (history_system, full_history) = split_system_messages_from_history(full_history);
219
220 let mut additional_params_payload = additional_params
221 .take()
222 .unwrap_or_else(|| Value::Object(Map::new()));
223 let mut additional_tools =
224 extract_tools_from_additional_params(&mut additional_params_payload)?;
225
226 let AdditionalParameters {
227 mut generation_config,
228 additional_params,
229 } = serde_json::from_value::<AdditionalParameters>(additional_params_payload)?;
230
231 if let Some(schema) = output_schema {
233 let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
234 cfg.response_mime_type = Some("application/json".to_string());
235 cfg.response_json_schema = Some(schema.to_value());
236 }
237
238 generation_config = generation_config.map(|mut cfg| {
239 if let Some(temp) = temperature {
240 cfg.temperature = Some(temp);
241 };
242
243 if let Some(max_tokens) = max_tokens {
244 cfg.max_output_tokens = Some(max_tokens);
245 };
246
247 cfg
248 });
249
250 let mut system_parts: Vec<Part> = Vec::new();
251 if let Some(preamble) = preamble.filter(|preamble| !preamble.is_empty()) {
252 system_parts.push(preamble.into());
253 }
254 for content in history_system {
255 if !content.is_empty() {
256 system_parts.push(content.into());
257 }
258 }
259 let system_instruction = if system_parts.is_empty() {
260 None
261 } else {
262 Some(Content {
263 parts: system_parts,
264 role: Some(Role::Model),
265 })
266 };
267
268 let mut tools = if function_tools.is_empty() {
269 Vec::new()
270 } else {
271 vec![serde_json::to_value(Tool::try_from(function_tools)?)?]
272 };
273 tools.append(&mut additional_tools);
274 let tools = if tools.is_empty() { None } else { Some(tools) };
275
276 let tool_config = if let Some(cfg) = tool_choice {
277 Some(ToolConfig {
278 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
279 })
280 } else {
281 None
282 };
283
284 let request = GenerateContentRequest {
285 contents: full_history
286 .into_iter()
287 .map(|msg| {
288 msg.try_into()
289 .map_err(|e| CompletionError::RequestError(Box::new(e)))
290 })
291 .collect::<Result<Vec<_>, _>>()?,
292 generation_config,
293 safety_settings: None,
294 tools,
295 tool_config,
296 system_instruction,
297 additional_params,
298 };
299
300 Ok(request)
301}
302
303fn split_system_messages_from_history(
304 history: Vec<completion::Message>,
305) -> (Vec<String>, Vec<completion::Message>) {
306 let mut system = Vec::new();
307 let mut remaining = Vec::new();
308
309 for message in history {
310 match message {
311 completion::Message::System { content } => system.push(content),
312 other => remaining.push(other),
313 }
314 }
315
316 (system, remaining)
317}
318
319fn extract_tools_from_additional_params(
320 additional_params: &mut Value,
321) -> Result<Vec<Value>, CompletionError> {
322 if let Some(map) = additional_params.as_object_mut()
323 && let Some(raw_tools) = map.remove("tools")
324 {
325 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
326 CompletionError::RequestError(
327 format!("Invalid Gemini `additional_params.tools` payload: {err}").into(),
328 )
329 });
330 }
331
332 Ok(Vec::new())
333}
334
335pub(crate) fn resolve_request_model(
336 default_model: &str,
337 completion_request: &CompletionRequest,
338) -> String {
339 completion_request
340 .model
341 .clone()
342 .unwrap_or_else(|| default_model.to_string())
343}
344
/// Builds the `generateContent` request path for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":generateContent"].concat()
}
348
/// Builds the `streamGenerateContent` request path for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":streamGenerateContent"].concat()
}
352
353impl TryFrom<completion::ToolDefinition> for Tool {
354 type Error = CompletionError;
355
356 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
357 let parameters = tool_parameters_to_schema(tool.parameters)?;
358
359 Ok(Self {
360 function_declarations: vec![FunctionDeclaration {
361 name: tool.name,
362 description: tool.description,
363 parameters,
364 }],
365 code_execution: None,
366 })
367 }
368}
369
370impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
371 type Error = CompletionError;
372
373 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
374 let mut function_declarations = Vec::new();
375
376 for tool in tools {
377 let parameters = tool_parameters_to_schema(tool.parameters).map_err(|e| {
378 CompletionError::ProviderError(format!(
379 "Tool '{}' could not be converted to a schema: {:?}",
380 tool.name, e,
381 ))
382 })?;
383
384 function_declarations.push(FunctionDeclaration {
385 name: tool.name,
386 description: tool.description,
387 parameters,
388 });
389 }
390
391 Ok(Self {
392 function_declarations,
393 code_execution: None,
394 })
395 }
396}
397
398pub(crate) fn function_call_finish_reason_error(
399 reason: &FinishReason,
400 finish_message: Option<&str>,
401) -> Option<CompletionError> {
402 match reason {
403 FinishReason::MalformedFunctionCall
404 | FinishReason::UnexpectedToolCall
405 | FinishReason::MissingThoughtSignature
406 | FinishReason::TooManyToolCalls
407 | FinishReason::MalformedResponse => {
408 let message = finish_message.unwrap_or("no finish message provided");
409 Some(CompletionError::ResponseError(format!(
410 "Gemini stopped with finish_reason={reason:?}: {message}"
411 )))
412 }
413 _ => None,
414 }
415}
416
417impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
418 type Error = CompletionError;
419
420 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
421 let candidate = response.candidates.first().ok_or_else(|| {
422 CompletionError::ResponseError("No response candidates in response".into())
423 })?;
424
425 if let Some(reason) = candidate.finish_reason.as_ref()
426 && let Some(err) =
427 function_call_finish_reason_error(reason, candidate.finish_message.as_deref())
428 {
429 return Err(err);
430 }
431
432 let content = candidate
433 .content
434 .as_ref()
435 .ok_or_else(|| {
436 let reason = candidate
437 .finish_reason
438 .as_ref()
439 .map(|r| format!("finish_reason={r:?}"))
440 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
441 let message = candidate
442 .finish_message
443 .as_deref()
444 .unwrap_or("no finish message provided");
445 CompletionError::ResponseError(format!(
446 "Gemini candidate missing content ({reason}, finish_message={message})"
447 ))
448 })?
449 .parts
450 .iter()
451 .map(
452 |Part {
453 thought,
454 thought_signature,
455 part,
456 ..
457 }| {
458 Ok(match part {
459 PartKind::Text(text) => {
460 if let Some(thought) = thought
461 && *thought
462 {
463 completion::AssistantContent::Reasoning(
464 Reasoning::new_with_signature(text, thought_signature.clone()),
465 )
466 } else {
467 completion::AssistantContent::text(text)
468 }
469 }
470 PartKind::InlineData(inline_data) => {
471 let mime_type =
472 message::MediaType::from_mime_type(&inline_data.mime_type);
473
474 match mime_type {
475 Some(message::MediaType::Image(media_type)) => {
476 message::AssistantContent::image_base64(
477 &inline_data.data,
478 Some(media_type),
479 Some(message::ImageDetail::default()),
480 )
481 }
482 _ => {
483 return Err(CompletionError::ResponseError(format!(
484 "Unsupported media type {mime_type:?}"
485 )));
486 }
487 }
488 }
489 PartKind::FunctionCall(function_call) => {
490 completion::AssistantContent::ToolCall(
491 message::ToolCall::new(
492 function_call.name.clone(),
493 message::ToolFunction::new(
494 function_call.name.clone(),
495 function_call.args.clone(),
496 ),
497 )
498 .with_signature(thought_signature.clone()),
499 )
500 }
501 _ => {
502 return Err(CompletionError::ResponseError(
503 "Response did not contain a message or tool call".into(),
504 ));
505 }
506 })
507 },
508 )
509 .collect::<Result<Vec<_>, _>>()?;
510
511 let choice = OneOrMany::many(content).map_err(|_| {
512 CompletionError::ResponseError(
513 "Response contained no message or tool call (empty)".to_owned(),
514 )
515 })?;
516
517 let usage = response
518 .usage_metadata
519 .as_ref()
520 .and_then(GetTokenUsage::token_usage)
521 .unwrap_or_default();
522
523 Ok(completion::CompletionResponse {
524 choice,
525 usage,
526 raw_response: response,
527 message_id: None,
528 })
529 }
530}
531
532pub mod gemini_api_types {
533 use crate::telemetry::ProviderResponseExt;
534 use std::{collections::HashMap, convert::Infallible, str::FromStr};
535
536 use serde::{Deserialize, Serialize};
540 use serde_json::{Value, json};
541
542 use crate::completion::GetTokenUsage;
543 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
544 use crate::{
545 completion::CompletionError,
546 message::{self},
547 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
548 };
549
    /// Provider-specific extras carried in a request's `additional_params`.
    ///
    /// Keys are (de)serialized in camelCase; any key that is not
    /// `generationConfig` is captured by the flattened `additional_params` value.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Optional generation configuration.
        pub generation_config: Option<GenerationConfig>,
        /// Remaining keys, flattened into the parent object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
559
560 impl AdditionalParameters {
561 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
562 self.generation_config = Some(cfg);
563 self
564 }
565
566 pub fn with_params(mut self, params: serde_json::Value) -> Self {
567 self.additional_params = Some(params);
568 self
569 }
570 }
571
    /// Response payload of a Gemini content-generation call (camelCase on the wire).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Identifier of this response; required when deserializing.
        pub response_id: String,
        /// Candidate answers; required when deserializing.
        pub candidates: Vec<ContentCandidate>,
        /// Optional feedback about the prompt.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Optional token usage information.
        pub usage_metadata: Option<UsageMetadata>,
        /// Optional model version string; reported as the response model name.
        pub model_version: Option<String>,
    }
591
592 impl ProviderResponseExt for GenerateContentResponse {
593 type OutputMessage = ContentCandidate;
594 type Usage = UsageMetadata;
595
596 fn get_response_id(&self) -> Option<String> {
597 Some(self.response_id.clone())
598 }
599
600 fn get_response_model_name(&self) -> Option<String> {
601 self.model_version.clone()
602 }
603
604 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
605 self.candidates.clone()
606 }
607
608 fn get_text_response(&self) -> Option<String> {
609 let str = self
610 .candidates
611 .iter()
612 .filter_map(|x| {
613 let content = x.content.as_ref()?;
614 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
615 return None;
616 }
617
618 let res = content
619 .parts
620 .iter()
621 .filter_map(|part| {
622 if let PartKind::Text(ref str) = part.part {
623 Some(str.to_owned())
624 } else {
625 None
626 }
627 })
628 .collect::<Vec<String>>()
629 .join("\n");
630
631 Some(res)
632 })
633 .collect::<Vec<String>>()
634 .join("\n");
635
636 if str.is_empty() { None } else { Some(str) }
637 }
638
639 fn get_usage(&self) -> Option<Self::Usage> {
640 self.usage_metadata.clone()
641 }
642 }
643
    /// One candidate answer inside a [`GenerateContentResponse`] (camelCase on the wire).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content; may be absent. Omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Why generation stopped, if reported.
        pub finish_reason: Option<FinishReason>,
        /// Safety ratings attached to the candidate, if any.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation information, if any.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for the candidate, if reported.
        pub token_count: Option<i32>,
        /// Average log probability, if reported.
        pub avg_logprobs: Option<f64>,
        /// Detailed log-probability results, if reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Index of the candidate, if reported.
        pub index: Option<i32>,
        /// Message accompanying the finish reason, if any.
        pub finish_message: Option<String>,
    }
672
    /// A message exchanged with the model: an ordered list of parts plus an optional role.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Parts of the message; defaults to an empty list when missing on input.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the content, if specified.
        pub role: Option<Role>,
    }
682
683 impl TryFrom<message::Message> for Content {
684 type Error = message::MessageError;
685
686 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
687 Ok(match msg {
688 message::Message::System { content } => Content {
689 parts: vec![content.into()],
690 role: Some(Role::User),
691 },
692 message::Message::User { content } => Content {
693 parts: content
694 .into_iter()
695 .map(|c| c.try_into())
696 .collect::<Result<Vec<_>, _>>()?,
697 role: Some(Role::User),
698 },
699 message::Message::Assistant { content, .. } => Content {
700 role: Some(Role::Model),
701 parts: content
702 .into_iter()
703 .map(|content| content.try_into())
704 .collect::<Result<Vec<_>, _>>()?,
705 },
706 })
707 }
708 }
709
    /// Author of a [`Content`]; serialized in lowercase (`"user"` / `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Content supplied by the caller.
        User,
        /// Content produced by the model.
        Model,
    }
716
    /// One piece of a [`Content`] message (camelCase on the wire).
    ///
    /// The payload (`part`) and any extra keys are flattened into the same JSON object.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether this part is a model "thought"; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Signature attached to a thought or function call; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The actual payload of the part.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra keys flattened into the part object; omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
731
    /// Payload variants of a [`Part`]; the variant name (camelCase) is the JSON key.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Data embedded in the request/response together with its MIME type.
        InlineData(Blob),
        /// A function call with its name and arguments.
        FunctionCall(FunctionCall),
        /// The result of a function call.
        FunctionResponse(FunctionResponse),
        /// A reference to data by URI.
        FileData(FileData),
        /// Code for execution.
        ExecutableCode(ExecutableCode),
        /// Result of executed code.
        CodeExecutionResult(CodeExecutionResult),
    }
746
747 impl Default for PartKind {
750 fn default() -> Self {
751 Self::Text(String::new())
752 }
753 }
754
755 impl From<String> for Part {
756 fn from(text: String) -> Self {
757 Self {
758 thought: Some(false),
759 thought_signature: None,
760 part: PartKind::Text(text),
761 additional_params: None,
762 }
763 }
764 }
765
766 impl From<&str> for Part {
767 fn from(text: &str) -> Self {
768 Self::from(text.to_string())
769 }
770 }
771
772 impl FromStr for Part {
773 type Err = Infallible;
774
775 fn from_str(s: &str) -> Result<Self, Self::Err> {
776 Ok(s.into())
777 }
778 }
779
780 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
781 type Error = message::MessageError;
782 fn try_from(
783 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
784 ) -> Result<Self, Self::Error> {
785 let mime_type = mime_type.to_mime_type().to_string();
786 let part = match doc_src {
787 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
788 mime_type: Some(mime_type),
789 file_uri: url,
790 }),
791 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
792 PartKind::InlineData(Blob { mime_type, data })
793 }
794 DocumentSourceKind::Raw(_) => {
795 return Err(message::MessageError::ConversionError(
796 "Raw files not supported, encode as base64 first".into(),
797 ));
798 }
799 DocumentSourceKind::FileId(_) => {
800 return Err(message::MessageError::ConversionError(
801 "Provider file IDs are not supported for Gemini image inputs".into(),
802 ));
803 }
804 DocumentSourceKind::Unknown => {
805 return Err(message::MessageError::ConversionError(
806 "Can't convert an unknown document source".to_string(),
807 ));
808 }
809 };
810
811 Ok(part)
812 }
813 }
814
815 impl TryFrom<message::UserContent> for Part {
816 type Error = message::MessageError;
817
818 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
819 match content {
820 message::UserContent::Text(message::Text { text, .. }) => Ok(Part {
821 thought: Some(false),
822 thought_signature: None,
823 part: PartKind::Text(text),
824 additional_params: None,
825 }),
826 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
827 let mut response_json: Option<serde_json::Value> = None;
828 let mut parts: Vec<FunctionResponsePart> = Vec::new();
829
830 for item in content.iter() {
831 match item {
832 message::ToolResultContent::Text(text) => {
833 let result: serde_json::Value =
834 serde_json::from_str(&text.text).unwrap_or_else(|error| {
835 tracing::trace!(
836 ?error,
837 "Tool result is not a valid JSON, treat it as normal string"
838 );
839 json!(&text.text)
840 });
841
842 response_json = Some(match response_json {
843 Some(mut existing) => {
844 if let serde_json::Value::Object(ref mut map) = existing {
845 map.insert("text".to_string(), result);
846 }
847 existing
848 }
849 None => json!({ "result": result }),
850 });
851 }
852 message::ToolResultContent::Image(image) => {
853 let part = match &image.data {
854 DocumentSourceKind::Base64(b64) => {
855 let mime_type = image
856 .media_type
857 .as_ref()
858 .ok_or(message::MessageError::ConversionError(
859 "Image media type is required for Gemini tool results".to_string(),
860 ))?
861 .to_mime_type();
862
863 FunctionResponsePart {
864 inline_data: Some(FunctionResponseInlineData {
865 mime_type: mime_type.to_string(),
866 data: b64.clone(),
867 display_name: None,
868 }),
869 file_data: None,
870 }
871 }
872 DocumentSourceKind::Url(url) => {
873 let mime_type = image
874 .media_type
875 .as_ref()
876 .map(|mt| mt.to_mime_type().to_string());
877
878 FunctionResponsePart {
879 inline_data: None,
880 file_data: Some(FileData {
881 mime_type,
882 file_uri: url.clone(),
883 }),
884 }
885 }
886 _ => {
887 return Err(message::MessageError::ConversionError(
888 "Unsupported image source kind for tool results"
889 .to_string(),
890 ));
891 }
892 };
893 parts.push(part);
894 }
895 }
896 }
897
898 Ok(Part {
899 thought: Some(false),
900 thought_signature: None,
901 part: PartKind::FunctionResponse(FunctionResponse {
902 name: id,
903 response: response_json,
904 parts: if parts.is_empty() { None } else { Some(parts) },
905 }),
906 additional_params: None,
907 })
908 }
909 message::UserContent::Image(message::Image {
910 data, media_type, ..
911 }) => match media_type {
912 Some(media_type) => match media_type {
913 message::ImageMediaType::JPEG
914 | message::ImageMediaType::PNG
915 | message::ImageMediaType::WEBP
916 | message::ImageMediaType::HEIC
917 | message::ImageMediaType::HEIF => {
918 let part = PartKind::try_from((media_type, data))?;
919 Ok(Part {
920 thought: Some(false),
921 thought_signature: None,
922 part,
923 additional_params: None,
924 })
925 }
926 _ => Err(message::MessageError::ConversionError(format!(
927 "Unsupported image media type {media_type:?}"
928 ))),
929 },
930 None => Err(message::MessageError::ConversionError(
931 "Media type for image is required for Gemini".to_string(),
932 )),
933 },
934 message::UserContent::Document(message::Document {
935 data, media_type, ..
936 }) => {
937 let Some(media_type) = media_type else {
938 return Err(MessageError::ConversionError(
939 "A mime type is required for document inputs to Gemini".to_string(),
940 ));
941 };
942
943 if matches!(
946 media_type,
947 message::DocumentMediaType::TXT
948 | message::DocumentMediaType::RTF
949 | message::DocumentMediaType::HTML
950 | message::DocumentMediaType::CSS
951 | message::DocumentMediaType::MARKDOWN
952 | message::DocumentMediaType::CSV
953 | message::DocumentMediaType::XML
954 | message::DocumentMediaType::Javascript
955 | message::DocumentMediaType::Python
956 ) {
957 use base64::Engine;
958 let part = match data {
959 DocumentSourceKind::String(text) => PartKind::Text(text),
960 DocumentSourceKind::Base64(data) => {
961 let text = String::from_utf8(
963 base64::engine::general_purpose::STANDARD
964 .decode(&data)
965 .map_err(|e| {
966 MessageError::ConversionError(format!(
967 "Failed to decode base64: {e}"
968 ))
969 })?,
970 )
971 .map_err(|e| {
972 MessageError::ConversionError(format!(
973 "Invalid UTF-8 in document: {e}"
974 ))
975 })?;
976 PartKind::Text(text)
977 }
978 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
979 mime_type: Some(media_type.to_mime_type().to_string()),
980 file_uri,
981 }),
982 DocumentSourceKind::Raw(_) => {
983 return Err(MessageError::ConversionError(
984 "Raw files not supported, encode as base64 first".to_string(),
985 ));
986 }
987 DocumentSourceKind::FileId(_) => {
988 return Err(MessageError::ConversionError(
989 "Provider file IDs are not supported for Gemini documents"
990 .to_string(),
991 ));
992 }
993 DocumentSourceKind::Unknown => {
994 return Err(MessageError::ConversionError(
995 "Document has no body".to_string(),
996 ));
997 }
998 };
999
1000 Ok(Part {
1001 thought: Some(false),
1002 part,
1003 ..Default::default()
1004 })
1005 } else if !media_type.is_code() {
1006 let mime_type = media_type.to_mime_type().to_string();
1007
1008 let part = match data {
1009 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1010 mime_type: Some(mime_type),
1011 file_uri,
1012 }),
1013 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
1014 PartKind::InlineData(Blob { mime_type, data })
1015 }
1016 DocumentSourceKind::Raw(_) => {
1017 return Err(message::MessageError::ConversionError(
1018 "Raw files not supported, encode as base64 first".into(),
1019 ));
1020 }
1021 _ => {
1022 return Err(message::MessageError::ConversionError(
1023 "Document has no body".to_string(),
1024 ));
1025 }
1026 };
1027
1028 Ok(Part {
1029 thought: Some(false),
1030 part,
1031 ..Default::default()
1032 })
1033 } else {
1034 Err(message::MessageError::ConversionError(format!(
1035 "Unsupported document media type {media_type:?}"
1036 )))
1037 }
1038 }
1039
1040 message::UserContent::Audio(message::Audio {
1041 data, media_type, ..
1042 }) => {
1043 let Some(media_type) = media_type else {
1044 return Err(MessageError::ConversionError(
1045 "A mime type is required for audio inputs to Gemini".to_string(),
1046 ));
1047 };
1048
1049 let mime_type = media_type.to_mime_type().to_string();
1050
1051 let part = match data {
1052 DocumentSourceKind::Base64(data) => {
1053 PartKind::InlineData(Blob { data, mime_type })
1054 }
1055
1056 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1057 mime_type: Some(mime_type),
1058 file_uri,
1059 }),
1060 DocumentSourceKind::String(_) => {
1061 return Err(message::MessageError::ConversionError(
1062 "Strings cannot be used as audio files!".into(),
1063 ));
1064 }
1065 DocumentSourceKind::Raw(_) => {
1066 return Err(message::MessageError::ConversionError(
1067 "Raw files not supported, encode as base64 first".into(),
1068 ));
1069 }
1070 DocumentSourceKind::FileId(_) => {
1071 return Err(message::MessageError::ConversionError(
1072 "Provider file IDs are not supported for Gemini audio inputs"
1073 .into(),
1074 ));
1075 }
1076 DocumentSourceKind::Unknown => {
1077 return Err(message::MessageError::ConversionError(
1078 "Content has no body".to_string(),
1079 ));
1080 }
1081 };
1082
1083 Ok(Part {
1084 thought: Some(false),
1085 part,
1086 ..Default::default()
1087 })
1088 }
1089 message::UserContent::Video(message::Video {
1090 data,
1091 media_type,
1092 additional_params,
1093 ..
1094 }) => {
1095 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
1096
1097 let part = match data {
1098 DocumentSourceKind::Url(file_uri) => {
1099 if file_uri.starts_with("https://www.youtube.com") {
1100 PartKind::FileData(FileData {
1101 mime_type,
1102 file_uri,
1103 })
1104 } else {
1105 if mime_type.is_none() {
1106 return Err(MessageError::ConversionError(
1107 "A mime type is required for non-Youtube video file inputs to Gemini"
1108 .to_string(),
1109 ));
1110 }
1111
1112 PartKind::FileData(FileData {
1113 mime_type,
1114 file_uri,
1115 })
1116 }
1117 }
1118 DocumentSourceKind::Base64(data) => {
1119 let Some(mime_type) = mime_type else {
1120 return Err(MessageError::ConversionError(
1121 "A media type is expected for base64 encoded strings"
1122 .to_string(),
1123 ));
1124 };
1125 PartKind::InlineData(Blob { mime_type, data })
1126 }
1127 DocumentSourceKind::String(_) => {
1128 return Err(message::MessageError::ConversionError(
1129 "Strings cannot be used as audio files!".into(),
1130 ));
1131 }
1132 DocumentSourceKind::Raw(_) => {
1133 return Err(message::MessageError::ConversionError(
1134 "Raw file data not supported, encode as base64 first".into(),
1135 ));
1136 }
1137 DocumentSourceKind::FileId(_) => {
1138 return Err(message::MessageError::ConversionError(
1139 "Provider file IDs are not supported for Gemini video inputs"
1140 .into(),
1141 ));
1142 }
1143 DocumentSourceKind::Unknown => {
1144 return Err(message::MessageError::ConversionError(
1145 "Media type for video is required for Gemini".to_string(),
1146 ));
1147 }
1148 };
1149
1150 Ok(Part {
1151 thought: Some(false),
1152 thought_signature: None,
1153 part,
1154 additional_params,
1155 })
1156 }
1157 }
1158 }
1159 }
1160
1161 impl TryFrom<message::AssistantContent> for Part {
1162 type Error = message::MessageError;
1163
1164 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1165 match content {
1166 message::AssistantContent::Text(message::Text { text, .. }) => Ok(text.into()),
1167 message::AssistantContent::Image(message::Image {
1168 data, media_type, ..
1169 }) => match media_type {
1170 Some(media_type) => match media_type {
1171 message::ImageMediaType::JPEG
1172 | message::ImageMediaType::PNG
1173 | message::ImageMediaType::WEBP
1174 | message::ImageMediaType::HEIC
1175 | message::ImageMediaType::HEIF => {
1176 let part = PartKind::try_from((media_type, data))?;
1177 Ok(Part {
1178 thought: Some(false),
1179 thought_signature: None,
1180 part,
1181 additional_params: None,
1182 })
1183 }
1184 _ => Err(message::MessageError::ConversionError(format!(
1185 "Unsupported image media type {media_type:?}"
1186 ))),
1187 },
1188 None => Err(message::MessageError::ConversionError(
1189 "Media type for image is required for Gemini".to_string(),
1190 )),
1191 },
1192 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1193 message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1194 thought: Some(true),
1195 thought_signature: reasoning.first_signature().map(str::to_owned),
1196 part: PartKind::Text(reasoning.display_text()),
1197 additional_params: None,
1198 }),
1199 }
1200 }
1201 }
1202
1203 impl From<message::ToolCall> for Part {
1204 fn from(tool_call: message::ToolCall) -> Self {
1205 Self {
1206 thought: Some(false),
1207 thought_signature: tool_call.signature,
1208 part: PartKind::FunctionCall(FunctionCall {
1209 name: tool_call.function.name,
1210 args: tool_call.function.arguments,
1211 }),
1212 additional_params: None,
1213 }
1214 }
1215 }
1216
    /// Inline media payload: a MIME type plus the data itself.
    ///
    /// Field names are (de)serialized in camelCase (`mimeType`, `data`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of the payload, e.g. `image/png`.
        pub mime_type: String,
        /// Payload contents as a string; presumably base64-encoded (confirm against the API reference).
        pub data: String,
    }
1228
    /// A function (tool) invocation: the function's name and its JSON arguments.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function being called.
        pub name: String,
        /// Arguments for the call as an arbitrary JSON value.
        pub args: serde_json::Value,
    }
1239
1240 impl From<message::ToolCall> for FunctionCall {
1241 fn from(tool_call: message::ToolCall) -> Self {
1242 Self {
1243 name: tool_call.function.name,
1244 args: tool_call.function.arguments,
1245 }
1246 }
1247 }
1248
    /// Result of a function (tool) call, identified by the function's name.
    ///
    /// Both payload fields are optional and are left out of the serialized
    /// output when they are `None`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name of the function this response belongs to.
        pub name: String,
        /// Structured JSON result, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        /// Additional inline or file-based parts attached to the response, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }
1264
    /// One part of a [`FunctionResponse`]: inline data and/or a file reference.
    ///
    /// Fields are (de)serialized in camelCase and omitted when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        /// Data embedded directly in the response (`inlineData`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        /// Data referenced by URI (`fileData`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }
1276
    /// Inline data carried inside a [`FunctionResponsePart`].
    ///
    /// Fields are (de)serialized in camelCase.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// MIME type of the data (`mimeType`).
        pub mime_type: String,
        /// The data as a string; presumably base64-encoded (confirm against the API reference).
        pub data: String,
        /// Optional display name (`displayName`); omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1289
    /// Reference to data stored at a URI rather than embedded inline.
    ///
    /// Fields are (de)serialized in camelCase (`mimeType`, `fileUri`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file, when known. There is no skip
        /// attribute, so `None` is serialized as `null`.
        pub mime_type: Option<String>,
        /// URI of the referenced file.
        pub file_uri: String,
    }
1299
    /// A safety rating: a harm category paired with a harm probability.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// Harm category this rating applies to.
        pub category: HarmCategory,
        /// Probability level reported for that category.
        pub probability: HarmProbability,
    }
1305
    /// Probability level attached to a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_PROBABILITY_UNSPECIFIED`, `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        /// No probability specified.
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1315
    /// Harm categories used by [`SafetyRating`] and [`SafetySetting`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_CATEGORY_HATE_SPEECH`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        /// No category specified.
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1332
    /// Token accounting attached to a response.
    ///
    /// Fields are (de)serialized in camelCase. Optional fields are omitted from
    /// serialized output when `None`. `total_token_count` has no serde default,
    /// so it must be present when deserializing; `prompt_token_count` falls back
    /// to `0` when absent.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt; defaults to `0` when missing from the payload.
        #[serde(default)]
        pub prompt_token_count: i32,
        /// Tokens counted as cached content, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens across the response candidates, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count for the request.
        pub total_token_count: i32,
        /// Tokens counted as thoughts, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
        /// Per-modality breakdown of the prompt tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Per-modality breakdown of the cached tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Per-modality breakdown of the candidate tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Tokens in tool-use prompts, if reported.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub tool_use_prompt_token_count: Option<i32>,
        /// Per-modality breakdown of the tool-use prompt tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
        /// Traffic type reported for the request, if any.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pub traffic_type: Option<TrafficType>,
    }
1358
    /// Token count for a single modality, used in the `*_details` lists of
    /// [`UsageMetadata`]. Fields are (de)serialized in camelCase.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ModalityTokenCount {
        /// Modality the count applies to.
        pub modality: Modality,
        /// Number of tokens for that modality (`tokenCount`).
        pub token_count: i32,
    }
1365
    /// Content modality named in a [`ModalityTokenCount`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `TEXT`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum Modality {
        /// No modality specified.
        ModalityUnspecified,
        Text,
        Image,
        Video,
        Audio,
        Document,
    }
1376
    /// Traffic type reported in [`UsageMetadata`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `ON_DEMAND`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TrafficType {
        /// No traffic type specified.
        TrafficTypeUnspecified,
        OnDemand,
        ProvisionedThroughput,
    }
1384
1385 impl std::fmt::Display for UsageMetadata {
1386 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1387 write!(
1388 f,
1389 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1390 self.prompt_token_count,
1391 match self.cached_content_token_count {
1392 Some(count) => count.to_string(),
1393 None => "n/a".to_string(),
1394 },
1395 match self.candidates_token_count {
1396 Some(count) => count.to_string(),
1397 None => "n/a".to_string(),
1398 },
1399 self.total_token_count
1400 )
1401 }
1402 }
1403
1404 impl GetTokenUsage for UsageMetadata {
1405 fn token_usage(&self) -> Option<crate::completion::Usage> {
1406 let mut usage = crate::completion::Usage::new();
1407
1408 usage.input_tokens = self.prompt_token_count as u64;
1409 usage.output_tokens = self.candidates_token_count.unwrap_or_default() as u64;
1410 usage.cached_input_tokens = self.cached_content_token_count.unwrap_or_default() as u64;
1411 usage.reasoning_tokens = self.thoughts_token_count.unwrap_or_default() as u64;
1412 usage.tool_use_prompt_tokens =
1413 self.tool_use_prompt_token_count.unwrap_or_default() as u64;
1414 usage.total_tokens = self.total_token_count as u64;
1415
1416 Some(usage)
1417 }
1418 }
1419
    /// Feedback about the prompt: an optional block reason and optional safety
    /// ratings. Fields are (de)serialized in camelCase.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked, if it was (`blockReason`).
        pub block_reason: Option<BlockReason>,
        /// Safety ratings reported for the prompt (`safetyRatings`).
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1429
    /// Reason reported in [`PromptFeedback::block_reason`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `PROHIBITED_CONTENT`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        /// No reason specified.
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1445
    /// Reason a candidate stopped generating.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `MAX_TOKENS`,
    /// `MALFORMED_FUNCTION_CALL`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        /// No reason specified.
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        // The five variants below are the ones this file's tests expect to
        // surface as a `CompletionError::ResponseError` during response conversion.
        MalformedFunctionCall,
        UnexpectedToolCall,
        MissingThoughtSignature,
        TooManyToolCalls,
        MalformedResponse,
    }
1480
    /// Citation information: a list of cited sources.
    ///
    /// The field is (de)serialized as `citationSources` and is required when
    /// deserializing (no serde default).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The cited sources.
        pub citation_sources: Vec<CitationSource>,
    }
1486
    /// A single cited source. Every field is optional and omitted from
    /// serialized output when `None`; names are (de)serialized in camelCase.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start index of the cited span (`startIndex`); units are not visible here.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End index of the cited span (`endIndex`); units are not visible here.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License associated with the source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1499
    /// Log-probability information for a candidate.
    ///
    /// Fields are (de)serialized as `topCandidate` and `chosenCandidate`.
    // NOTE(review): both keys are singular here; confirm they match the field
    // names the API actually sends.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogprobsResult {
        /// Lists of top candidates.
        pub top_candidate: Vec<TopCandidate>,
        /// The chosen candidates.
        pub chosen_candidate: Vec<LogProbCandidate>,
    }
1506
    /// A group of log-probability candidates, used by [`LogprobsResult`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// The candidates in this group.
        pub candidates: Vec<LogProbCandidate>,
    }
1511
    /// A token together with its id and log probability.
    ///
    /// Fields are (de)serialized in camelCase (`token`, `tokenId`,
    /// `logProbability`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        /// The token text.
        pub token: String,
        /// The token id.
        // NOTE(review): typed as `String`; confirm the API does not send a JSON number.
        pub token_id: String,
        /// Log probability of the token.
        pub log_probability: f64,
    }
1519
    /// Generation options sent with a request.
    ///
    /// Field names are (de)serialized in camelCase, and every field is omitted
    /// from serialized output when `None`. See the `Default` impl for the values
    /// used when nothing is configured.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Sequences that stop generation (`stopSequences`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type requested for the response (`responseMimeType`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema expressed as a [`Schema`] (`responseSchema`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Output schema as raw JSON, serialized under the explicit key
        /// `_responseJsonSchema` (the rename overrides the camelCase rule).
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Output schema as raw JSON, serialized as `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of candidates to generate (`candidateCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Upper bound on output tokens (`maxOutputTokens`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Top-p sampling parameter (`topP`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling parameter (`topK`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty (`presencePenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty (`frequencyPenalty`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether log probabilities should be returned (`responseLogprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Log-probability count setting (`logprobs`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking options (`thinkingConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image output options (`imageConfig`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1612
1613 impl Default for GenerationConfig {
1614 fn default() -> Self {
1615 Self {
1616 temperature: Some(1.0),
1617 max_output_tokens: Some(4096),
1618 stop_sequences: None,
1619 response_mime_type: None,
1620 response_schema: None,
1621 _response_json_schema: None,
1622 response_json_schema: None,
1623 candidate_count: None,
1624 top_p: None,
1625 top_k: None,
1626 presence_penalty: None,
1627 frequency_penalty: None,
1628 response_logprobs: None,
1629 logprobs: None,
1630 thinking_config: None,
1631 image_config: None,
1632 }
1633 }
1634 }
1635
    /// Thinking level selectable through [`ThinkingConfig::thinking_level`].
    ///
    /// Variants are (de)serialized in snake_case: `minimal`, `low`, `medium`,
    /// `high`.
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum ThinkingLevel {
        Minimal,
        Low,
        Medium,
        High,
    }
1645
    /// Thinking options nested in [`GenerationConfig`].
    ///
    /// Fields are (de)serialized in camelCase and omitted when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Thinking budget (`thinkingBudget`); presumably a token count (confirm against the API reference).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_budget: Option<u32>,
        /// Thinking level (`thinkingLevel`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_level: Option<ThinkingLevel>,
        /// Whether thoughts should be included in the response (`includeThoughts`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub include_thoughts: Option<bool>,
    }
1662
    /// Image output options nested in [`GenerationConfig`].
    ///
    /// Fields are (de)serialized in camelCase and omitted when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio (`aspectRatio`), as a free-form string.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size (`imageSize`), as a free-form string.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1671
    /// Reduced schema representation built from a JSON Schema value (see the
    /// `TryFrom<Value>` impl).
    ///
    /// There is no `rename_all`, so field names are (de)serialized as written
    /// (`max_items`, `min_items`); the raw identifiers `r#type` and `r#enum`
    /// appear as `type` and `enum`. Optional fields are omitted when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name; may be an empty string when no type could be inferred.
        pub r#type: String,
        /// Format hint for the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// `Some(true)` when the value may be null; otherwise unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Schemas of object properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of array elements.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1697
1698 pub fn tool_parameters_to_schema(parameters: Value) -> Result<Option<Schema>, CompletionError> {
1704 if parameters.is_null() || parameters == json!({"type": "object", "properties": {}}) {
1705 Ok(None)
1706 } else {
1707 parameters.try_into().map(Some)
1708 }
1709 }
1710
1711 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1717 let defs = if let Some(obj) = schema.as_object() {
1719 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1720 } else {
1721 None
1722 };
1723
1724 let Some(defs_value) = defs else {
1725 return Ok(schema);
1726 };
1727
1728 let Some(defs_obj) = defs_value.as_object() else {
1729 return Err(CompletionError::ResponseError(
1730 "$defs must be an object".into(),
1731 ));
1732 };
1733
1734 resolve_refs(&mut schema, defs_obj)?;
1735
1736 if let Some(obj) = schema.as_object_mut() {
1738 obj.remove("$defs");
1739 obj.remove("definitions");
1740 }
1741
1742 Ok(schema)
1743 }
1744
1745 fn resolve_refs(
1748 value: &mut Value,
1749 defs: &serde_json::Map<String, Value>,
1750 ) -> Result<(), CompletionError> {
1751 match value {
1752 Value::Object(obj) => {
1753 if let Some(ref_value) = obj.get("$ref")
1754 && let Some(ref_str) = ref_value.as_str()
1755 {
1756 let def_name = parse_ref_path(ref_str)?;
1758
1759 let def = defs.get(&def_name).ok_or_else(|| {
1760 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1761 })?;
1762
1763 let mut resolved = def.clone();
1764 resolve_refs(&mut resolved, defs)?;
1765 *value = resolved;
1766 return Ok(());
1767 }
1768
1769 for (_, v) in obj.iter_mut() {
1770 resolve_refs(v, defs)?;
1771 }
1772 }
1773 Value::Array(arr) => {
1774 for item in arr.iter_mut() {
1775 resolve_refs(item, defs)?;
1776 }
1777 }
1778 _ => {}
1779 }
1780
1781 Ok(())
1782 }
1783
1784 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1790 if let Some(fragment) = ref_str.strip_prefix('#') {
1791 if let Some(name) = fragment.strip_prefix("/$defs/") {
1792 Ok(name.to_string())
1793 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1794 Ok(name.to_string())
1795 } else {
1796 Err(CompletionError::ResponseError(format!(
1797 "Unsupported reference format: {}",
1798 ref_str
1799 )))
1800 }
1801 } else {
1802 Err(CompletionError::ResponseError(format!(
1803 "Only fragment references (#/...) are supported: {}",
1804 ref_str
1805 )))
1806 }
1807 }
1808
1809 fn extract_type(type_value: &Value) -> Option<String> {
1812 if let Some(t) = type_value.as_str() {
1813 return Some(t.to_string());
1814 }
1815
1816 type_value.as_array().and_then(|arr| {
1817 arr.iter()
1818 .filter_map(|v| v.as_str())
1819 .find(|t| *t != "null")
1820 .or_else(|| arr.iter().find_map(|v| v.as_str()))
1821 .map(str::to_owned)
1822 })
1823 }
1824
1825 fn schema_is_null(obj: &serde_json::Map<String, Value>) -> bool {
1826 obj.get("type")
1827 .and_then(extract_type)
1828 .as_deref()
1829 .is_some_and(|t| t == "null")
1830 }
1831
1832 fn schema_is_nullable(obj: &serde_json::Map<String, Value>) -> bool {
1833 obj.get("nullable")
1834 .and_then(|v| v.as_bool())
1835 .unwrap_or(false)
1836 || obj
1837 .get("type")
1838 .and_then(|v| v.as_array())
1839 .is_some_and(|arr| arr.iter().any(|v| v.as_str() == Some("null")))
1840 || ["anyOf", "oneOf", "allOf"].iter().any(|key| {
1841 obj.get(*key).and_then(|v| v.as_array()).is_some_and(|arr| {
1842 arr.iter()
1843 .filter_map(|schema| schema.as_object())
1844 .any(schema_is_null)
1845 })
1846 })
1847 }
1848
1849 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1852 composition.as_array().and_then(|arr| {
1853 arr.iter().find_map(|schema| {
1854 let obj = schema.as_object()?;
1855 if schema_is_null(obj) {
1856 return None;
1857 }
1858
1859 obj.get("type").and_then(extract_type).or_else(|| {
1860 if obj.contains_key("properties") {
1861 Some("object".to_string())
1862 } else if obj.contains_key("enum") {
1863 Some("string".to_string())
1865 } else {
1866 None
1867 }
1868 })
1869 })
1870 })
1871 }
1872
1873 fn extract_schema_from_composition(
1876 composition: &Value,
1877 ) -> Option<serde_json::Map<String, Value>> {
1878 composition.as_array().and_then(|arr| {
1879 arr.iter().find_map(|schema| {
1880 let obj = schema.as_object()?;
1881 if schema_is_null(obj) {
1882 None
1883 } else {
1884 Some(obj.clone())
1885 }
1886 })
1887 })
1888 }
1889
1890 fn extract_schema_from_composition_obj(
1891 obj: &serde_json::Map<String, Value>,
1892 ) -> Option<serde_json::Map<String, Value>> {
1893 obj.get("anyOf")
1894 .and_then(extract_schema_from_composition)
1895 .or_else(|| obj.get("oneOf").and_then(extract_schema_from_composition))
1896 .or_else(|| obj.get("allOf").and_then(extract_schema_from_composition))
1897 }
1898
1899 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1902 if let Some(type_val) = obj.get("type")
1904 && let Some(type_str) = extract_type(type_val)
1905 {
1906 return type_str;
1907 }
1908
1909 if let Some(any_of) = obj.get("anyOf")
1911 && let Some(type_str) = extract_type_from_composition(any_of)
1912 {
1913 return type_str;
1914 }
1915
1916 if let Some(one_of) = obj.get("oneOf")
1917 && let Some(type_str) = extract_type_from_composition(one_of)
1918 {
1919 return type_str;
1920 }
1921
1922 if let Some(all_of) = obj.get("allOf")
1923 && let Some(type_str) = extract_type_from_composition(all_of)
1924 {
1925 return type_str;
1926 }
1927
1928 if obj.contains_key("properties") {
1930 "object".to_string()
1931 } else if obj.contains_key("enum") {
1932 "string".to_string()
1933 } else {
1934 String::new()
1935 }
1936 }
1937
1938 impl TryFrom<Value> for Schema {
1939 type Error = CompletionError;
1940
1941 fn try_from(value: Value) -> Result<Self, Self::Error> {
1942 let flattened_val = flatten_schema(value)?;
1943 if let Some(obj) = flattened_val.as_object() {
1944 let composition_source = extract_schema_from_composition_obj(obj);
1947 let props_source = if obj.get("properties").is_none() {
1948 composition_source.clone().unwrap_or(obj.clone())
1949 } else {
1950 obj.clone()
1951 };
1952
1953 let schema_type = infer_type(obj);
1954 let items = obj
1955 .get("items")
1956 .or_else(|| props_source.get("items"))
1957 .and_then(|v| v.clone().try_into().ok())
1958 .map(Box::new);
1959
1960 let items = if schema_type == "array" && items.is_none() {
1963 Some(Box::new(Schema {
1964 r#type: "string".to_string(),
1965 format: None,
1966 description: None,
1967 nullable: None,
1968 r#enum: None,
1969 max_items: None,
1970 min_items: None,
1971 properties: None,
1972 required: None,
1973 items: None,
1974 }))
1975 } else {
1976 items
1977 };
1978
1979 Ok(Schema {
1980 r#type: schema_type,
1981 format: obj
1982 .get("format")
1983 .or_else(|| props_source.get("format"))
1984 .and_then(|v| v.as_str())
1985 .map(String::from),
1986 description: obj
1987 .get("description")
1988 .or_else(|| props_source.get("description"))
1989 .and_then(|v| v.as_str())
1990 .map(String::from),
1991 nullable: if schema_is_nullable(obj)
1992 || composition_source.as_ref().is_some_and(schema_is_nullable)
1993 {
1994 Some(true)
1995 } else {
1996 None
1997 },
1998 r#enum: obj
1999 .get("enum")
2000 .or_else(|| props_source.get("enum"))
2001 .and_then(|v| v.as_array())
2002 .map(|arr| {
2003 arr.iter()
2004 .filter_map(|v| v.as_str().map(String::from))
2005 .collect()
2006 }),
2007 max_items: obj
2008 .get("maxItems")
2009 .and_then(|v| v.as_i64())
2010 .map(|v| v as i32),
2011 min_items: obj
2012 .get("minItems")
2013 .and_then(|v| v.as_i64())
2014 .map(|v| v as i32),
2015 properties: props_source
2016 .get("properties")
2017 .and_then(|v| v.as_object())
2018 .map(|map| {
2019 map.iter()
2020 .filter_map(|(k, v)| {
2021 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
2022 })
2023 .collect()
2024 }),
2025 required: props_source
2026 .get("required")
2027 .and_then(|v| v.as_array())
2028 .map(|arr| {
2029 arr.iter()
2030 .filter_map(|v| v.as_str().map(String::from))
2031 .collect()
2032 }),
2033 items,
2034 })
2035 } else {
2036 Err(CompletionError::ResponseError(
2037 "Expected a JSON object for Schema".into(),
2038 ))
2039 }
2040 }
2041 }
2042
    /// Request body for content generation. Serialize-only.
    ///
    /// Field names are serialized in camelCase. Only `tools` and
    /// `additional_params` are skipped when `None`; the other optional fields
    /// have no skip attribute and are therefore serialized as `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation contents.
        pub contents: Vec<Content>,
        /// Tool definitions as raw JSON values; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Vec<Value>>,
        /// Tool configuration (`toolConfig`).
        pub tool_config: Option<ToolConfig>,
        /// Generation options (`generationConfig`).
        pub generation_config: Option<GenerationConfig>,
        /// Safety settings (`safetySettings`).
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System instruction content (`systemInstruction`).
        pub system_instruction: Option<Content>,
        /// Extra parameters flattened into the top level of the request body;
        /// omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
2074
    /// A tool entry: function declarations plus optional code execution.
    ///
    /// Serialize-only, with camelCase field names. `code_execution` has no skip
    /// attribute, so `None` is serialized as `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions the model may call (`functionDeclarations`).
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Code-execution tool marker (`codeExecution`).
        pub code_execution: Option<CodeExecution>,
    }
2081
    /// Declaration of a callable function: name, description and an optional
    /// parameter schema. Serialize-only, with camelCase field names.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name.
        pub name: String,
        /// Description of what the function does.
        pub description: String,
        /// Parameter schema; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
2090
    /// Tool configuration wrapper around the function-calling mode.
    ///
    /// The field is (de)serialized as `functionCallingConfig`; it has no skip
    /// attribute, so `None` is serialized as `null`.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Function-calling mode, if one is set.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
2096
    /// Function-calling mode.
    ///
    /// Internally tagged on `mode` with upper-case variant names, e.g.
    /// `{"mode": "AUTO"}`. `Auto` is the default.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// Serialized as `{"mode": "AUTO"}`.
        #[default]
        Auto,
        /// Serialized as `{"mode": "NONE"}`.
        None,
        /// Serialized as `{"mode": "ANY", ...}`.
        Any {
            /// Optional list of function names; omitted when `None`.
            // NOTE(review): the enum-level `rename_all` renames variants only,
            // so this key is emitted as `allowed_function_names` (snake_case).
            // Confirm the API accepts that spelling.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
2108
2109 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
2110 type Error = CompletionError;
2111 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
2112 let res = match value {
2113 message::ToolChoice::Auto => Self::Auto,
2114 message::ToolChoice::None => Self::None,
2115 message::ToolChoice::Required => Self::Any {
2116 allowed_function_names: None,
2117 },
2118 message::ToolChoice::Specific { function_names } => Self::Any {
2119 allowed_function_names: Some(function_names),
2120 },
2121 };
2122
2123 Ok(res)
2124 }
2125 }
2126
    /// Field-less marker used by [`Tool::code_execution`]; serializes as an
    /// empty JSON object (`{}`).
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
2129
    /// A safety setting: the blocking threshold to apply to one harm category.
    /// Serialize-only.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Harm category the setting applies to.
        pub category: HarmCategory,
        /// Blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
2136
    /// Blocking threshold used in a [`SafetySetting`].
    ///
    /// Serialize-only; variants are written in SCREAMING_SNAKE_CASE, e.g.
    /// `BLOCK_MEDIUM_AND_ABOVE`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        /// No threshold specified.
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
2147}
2148
2149#[cfg(test)]
2150mod tests {
2151 use crate::{
2152 message,
2153 providers::gemini::completion::gemini_api_types::{
2154 ContentCandidate, FinishReason, FunctionCall, Schema, UsageMetadata, flatten_schema,
2155 tool_parameters_to_schema,
2156 },
2157 };
2158
2159 use super::*;
2160 use serde_json::json;
2161
2162 #[test]
2163 fn test_resolve_request_model_uses_override() {
2164 let request = CompletionRequest {
2165 model: Some("gemini-2.5-flash".to_string()),
2166 preamble: None,
2167 chat_history: crate::OneOrMany::one("Hello".into()),
2168 documents: vec![],
2169 tools: vec![],
2170 temperature: None,
2171 max_tokens: None,
2172 tool_choice: None,
2173 additional_params: None,
2174 output_schema: None,
2175 };
2176
2177 let request_model = resolve_request_model("gemini-2.0-flash", &request);
2178 assert_eq!(request_model, "gemini-2.5-flash");
2179 assert_eq!(
2180 completion_endpoint(&request_model),
2181 "/v1beta/models/gemini-2.5-flash:generateContent"
2182 );
2183 assert_eq!(
2184 streaming_endpoint(&request_model),
2185 "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
2186 );
2187 }
2188
2189 #[test]
2190 fn test_resolve_request_model_uses_default_when_unset() {
2191 let request = CompletionRequest {
2192 model: None,
2193 preamble: None,
2194 chat_history: crate::OneOrMany::one("Hello".into()),
2195 documents: vec![],
2196 tools: vec![],
2197 temperature: None,
2198 max_tokens: None,
2199 tool_choice: None,
2200 additional_params: None,
2201 output_schema: None,
2202 };
2203
2204 assert_eq!(
2205 resolve_request_model("gemini-2.0-flash", &request),
2206 "gemini-2.0-flash"
2207 );
2208 }
2209
2210 #[test]
2211 fn test_deserialize_message_user() {
2212 let raw_message = r#"{
2213 "parts": [
2214 {"text": "Hello, world!"},
2215 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
2216 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
2217 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
2218 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
2219 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
2220 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
2221 ],
2222 "role": "user"
2223 }"#;
2224
2225 let content: Content = {
2226 let jd = &mut serde_json::Deserializer::from_str(raw_message);
2227 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
2228 panic!("Deserialization error at {}: {}", err.path(), err);
2229 })
2230 };
2231 assert_eq!(content.role, Some(Role::User));
2232 assert_eq!(content.parts.len(), 7);
2233
2234 let parts: Vec<Part> = content.parts.into_iter().collect();
2235
2236 if let Part {
2237 part: PartKind::Text(text),
2238 ..
2239 } = &parts[0]
2240 {
2241 assert_eq!(text, "Hello, world!");
2242 } else {
2243 panic!("Expected text part");
2244 }
2245
2246 if let Part {
2247 part: PartKind::InlineData(inline_data),
2248 ..
2249 } = &parts[1]
2250 {
2251 assert_eq!(inline_data.mime_type, "image/png");
2252 assert_eq!(inline_data.data, "base64encodeddata");
2253 } else {
2254 panic!("Expected inline data part");
2255 }
2256
2257 if let Part {
2258 part: PartKind::FunctionCall(function_call),
2259 ..
2260 } = &parts[2]
2261 {
2262 assert_eq!(function_call.name, "test_function");
2263 assert_eq!(
2264 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2265 "value1"
2266 );
2267 } else {
2268 panic!("Expected function call part");
2269 }
2270
2271 if let Part {
2272 part: PartKind::FunctionResponse(function_response),
2273 ..
2274 } = &parts[3]
2275 {
2276 assert_eq!(function_response.name, "test_function");
2277 assert_eq!(
2278 function_response
2279 .response
2280 .as_ref()
2281 .unwrap()
2282 .get("result")
2283 .unwrap(),
2284 "success"
2285 );
2286 } else {
2287 panic!("Expected function response part");
2288 }
2289
2290 if let Part {
2291 part: PartKind::FileData(file_data),
2292 ..
2293 } = &parts[4]
2294 {
2295 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
2296 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
2297 } else {
2298 panic!("Expected file data part");
2299 }
2300
2301 if let Part {
2302 part: PartKind::ExecutableCode(executable_code),
2303 ..
2304 } = &parts[5]
2305 {
2306 assert_eq!(executable_code.code, "print('Hello, world!')");
2307 } else {
2308 panic!("Expected executable code part");
2309 }
2310
2311 if let Part {
2312 part: PartKind::CodeExecutionResult(code_execution_result),
2313 ..
2314 } = &parts[6]
2315 {
2316 assert_eq!(
2317 code_execution_result.clone().output.unwrap(),
2318 "Hello, world!"
2319 );
2320 } else {
2321 panic!("Expected code execution result part");
2322 }
2323 }
2324
2325 #[test]
2326 fn test_deserialize_message_model() {
2327 let json_data = json!({
2328 "parts": [{"text": "Hello, user!"}],
2329 "role": "model"
2330 });
2331
2332 let content: Content = serde_json::from_value(json_data).unwrap();
2333 assert_eq!(content.role, Some(Role::Model));
2334 assert_eq!(content.parts.len(), 1);
2335 if let Some(Part {
2336 part: PartKind::Text(text),
2337 ..
2338 }) = content.parts.first()
2339 {
2340 assert_eq!(text, "Hello, user!");
2341 } else {
2342 panic!("Expected text part");
2343 }
2344 }
2345
2346 #[test]
2347 fn test_message_conversion_user() {
2348 let msg = message::Message::user("Hello, world!");
2349 let content: Content = msg.try_into().unwrap();
2350 assert_eq!(content.role, Some(Role::User));
2351 assert_eq!(content.parts.len(), 1);
2352 if let Some(Part {
2353 part: PartKind::Text(text),
2354 ..
2355 }) = &content.parts.first()
2356 {
2357 assert_eq!(text, "Hello, world!");
2358 } else {
2359 panic!("Expected text part");
2360 }
2361 }
2362
2363 #[test]
2364 fn test_message_conversion_model() {
2365 let msg = message::Message::assistant("Hello, user!");
2366
2367 let content: Content = msg.try_into().unwrap();
2368 assert_eq!(content.role, Some(Role::Model));
2369 assert_eq!(content.parts.len(), 1);
2370 if let Some(Part {
2371 part: PartKind::Text(text),
2372 ..
2373 }) = &content.parts.first()
2374 {
2375 assert_eq!(text, "Hello, user!");
2376 } else {
2377 panic!("Expected text part");
2378 }
2379 }
2380
2381 #[test]
2382 fn test_thought_signature_is_preserved_from_response_reasoning_part() {
2383 let response = GenerateContentResponse {
2384 response_id: "resp_1".to_string(),
2385 candidates: vec![ContentCandidate {
2386 content: Some(Content {
2387 parts: vec![Part {
2388 thought: Some(true),
2389 thought_signature: Some("thought_sig_123".to_string()),
2390 part: PartKind::Text("thinking text".to_string()),
2391 additional_params: None,
2392 }],
2393 role: Some(Role::Model),
2394 }),
2395 finish_reason: Some(FinishReason::Stop),
2396 safety_ratings: None,
2397 citation_metadata: None,
2398 token_count: None,
2399 avg_logprobs: None,
2400 logprobs_result: None,
2401 index: Some(0),
2402 finish_message: None,
2403 }],
2404 prompt_feedback: None,
2405 usage_metadata: None,
2406 model_version: None,
2407 };
2408
2409 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2410 response.try_into().expect("convert response");
2411 let first = converted.choice.first();
2412 assert!(matches!(
2413 first,
2414 message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2415 if matches!(
2416 content.first(),
2417 Some(message::ReasoningContent::Text {
2418 text,
2419 signature: Some(signature)
2420 }) if text == "thinking text" && signature == "thought_sig_123"
2421 )
2422 ));
2423 }
2424
/// Converting a response whose candidate ends with a tool-protocol finish
/// reason must fail with `CompletionError::ResponseError`, and the error text
/// must mention both the reason's `Debug` name and the finish message.
#[test]
fn test_tool_protocol_finish_reason_returns_response_error() {
    let cases = [
        (
            FinishReason::MalformedFunctionCall,
            "malformed function call: default_api",
        ),
        (
            FinishReason::UnexpectedToolCall,
            "unexpected tool call: default_api",
        ),
        (
            FinishReason::MissingThoughtSignature,
            "missing thought signature for tool call",
        ),
        (
            FinishReason::TooManyToolCalls,
            "too many tool calls in response",
        ),
        (
            FinishReason::MalformedResponse,
            "malformed response from provider",
        ),
    ];

    for (reason, finish_message) in cases {
        // Captured up front because `reason` is moved into the candidate below.
        let reason_name = format!("{reason:?}");

        let function_call_part = Part {
            thought: None,
            thought_signature: None,
            part: PartKind::FunctionCall(FunctionCall {
                name: "default_api".to_string(),
                args: json!({"x": 1}),
            }),
            additional_params: None,
        };
        let candidate = ContentCandidate {
            content: Some(Content {
                parts: vec![function_call_part],
                role: Some(Role::Model),
            }),
            finish_reason: Some(reason),
            safety_ratings: None,
            citation_metadata: None,
            token_count: None,
            avg_logprobs: None,
            logprobs_result: None,
            index: Some(0),
            finish_message: Some(finish_message.to_string()),
        };
        let response = GenerateContentResponse {
            response_id: "resp_tool_protocol_error".to_string(),
            candidates: vec![candidate],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
        };

        let outcome =
            crate::completion::CompletionResponse::<GenerateContentResponse>::try_from(response);
        let err = outcome.expect_err("tool protocol finish reason should fail");

        let mentions_reason_and_message = matches!(
            &err,
            CompletionError::ResponseError(message)
                if message.contains(&reason_name)
                    && message.contains(finish_message)
        );
        assert!(mentions_reason_and_message);
    }
}
2492
/// The prompt, cached, candidate, thought, tool-use-prompt and total token
/// counts from `UsageMetadata` must show up in the converted response's `usage`.
#[test]
fn test_completion_response_usage_preserves_cached_and_reasoning_tokens() {
    let usage_metadata = UsageMetadata {
        prompt_token_count: 40,
        cached_content_token_count: Some(20),
        candidates_token_count: Some(30),
        total_token_count: 100,
        thoughts_token_count: Some(10),
        prompt_tokens_details: None,
        cache_tokens_details: None,
        candidates_tokens_details: None,
        tool_use_prompt_token_count: Some(12),
        tool_use_prompt_tokens_details: None,
        traffic_type: None,
    };
    let answer_part = Part {
        thought: None,
        thought_signature: None,
        part: PartKind::Text("answer".to_string()),
        additional_params: None,
    };
    let candidate = ContentCandidate {
        content: Some(Content {
            parts: vec![answer_part],
            role: Some(Role::Model),
        }),
        finish_reason: Some(FinishReason::Stop),
        safety_ratings: None,
        citation_metadata: None,
        token_count: None,
        avg_logprobs: None,
        logprobs_result: None,
        index: Some(0),
        finish_message: None,
    };
    let response = GenerateContentResponse {
        response_id: "resp_1".to_string(),
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: Some(usage_metadata),
        model_version: Some("gemini-2.0-flash-001".to_string()),
    };

    let converted =
        crate::completion::CompletionResponse::<GenerateContentResponse>::try_from(response)
            .expect("convert response");
    let usage = &converted.usage;

    assert_eq!(usage.input_tokens, 40);
    assert_eq!(usage.cached_input_tokens, 20);
    assert_eq!(usage.output_tokens, 30);
    assert_eq!(usage.reasoning_tokens, 10);
    assert_eq!(usage.tool_use_prompt_tokens, 12);
    assert_eq!(usage.total_tokens, 100);
}
2543
/// An assistant reasoning block built with a signature must convert into a
/// Gemini part flagged as a thought, carrying that signature and the text.
#[test]
fn test_reasoning_signature_is_emitted_in_gemini_part() {
    let reasoning = message::Reasoning::new_with_signature(
        "structured thought",
        Some("reuse_sig_456".to_string()),
    );
    let assistant_message = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
    };

    let gemini_content: Content = assistant_message.try_into().expect("convert message");
    let reasoning_part = gemini_content.parts.first().expect("reasoning part");

    assert_eq!(reasoning_part.thought, Some(true));
    assert_eq!(
        reasoning_part.thought_signature.as_deref(),
        Some("reuse_sig_456")
    );

    let carries_thought_text = matches!(
        &reasoning_part.part,
        PartKind::Text(text) if text == "structured thought"
    );
    assert!(carries_thought_text);
}
2565
/// An assistant tool call must convert into a single model-role
/// `FunctionCall` part that keeps the function name and arguments.
#[test]
fn test_message_conversion_tool_call() {
    let function = message::ToolFunction {
        name: "test_function".to_string(),
        arguments: json!({"arg1": "value1"}),
    };
    let assistant_message = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::ToolCall(message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function,
            signature: None,
            additional_params: None,
        })),
    };

    let content: Content = assistant_message.try_into().unwrap();
    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionCall(function_call),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected function call part");
    };

    assert_eq!(function_call.name, "test_function");
    let arguments = function_call.args.as_object().unwrap();
    assert_eq!(arguments.get("arg1").unwrap(), "value1");
}
2601
/// Converts an array schema whose `items` is a `$ref` into `$defs` and, when
/// the conversion succeeds, checks that the items resolve to an object type.
#[test]
fn test_vec_schema_conversion() {
    let schema_with_ref = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "first_name": {
                        "type": ["string", "null"],
                        "description": "The person's first name, if provided (null otherwise)"
                    },
                    "last_name": {
                        "type": ["string", "null"],
                        "description": "The person's last name, if provided (null otherwise)"
                    },
                    "job": {
                        "type": ["string", "null"],
                        "description": "The person's job, if provided (null otherwise)"
                    }
                },
                "required": []
            }
        }
    });

    let conversion: Result<Schema, _> = schema_with_ref.try_into();

    // NOTE(review): a failed conversion is only logged, so this test still
    // passes on `Err` — confirm that this leniency is intended.
    let schema = match conversion {
        Ok(schema) => schema,
        Err(e) => {
            println!("Schema conversion failed: {:?}", e);
            return;
        }
    };

    assert_eq!(schema.r#type, "array");

    let Some(items) = schema.items else {
        panic!("Schema should have items field for array type");
    };
    println!("item types: {}", items.r#type);

    assert_ne!(items.r#type, "", "Items type should not be empty string!");
    assert_eq!(items.r#type, "object", "Items should be object type");
}
2649
/// A plain object schema converts to a `Schema` that keeps its type and
/// exposes a `properties` map.
#[test]
fn test_object_schema() {
    let schema: Schema = json!({
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            }
        }
    })
    .try_into()
    .unwrap();

    assert_eq!(schema.r#type, "object");
    let has_properties = schema.properties.is_some();
    assert!(has_properties);
}
2665
/// An array schema with an inline object `items` definition keeps that
/// definition (type and properties) after conversion.
#[test]
fn test_array_with_inline_items() {
    let schema: Schema = json!({
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        }
    })
    .try_into()
    .unwrap();

    assert_eq!(schema.r#type, "array");

    let Some(items) = schema.items else {
        panic!("Schema should have items field");
    };
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
/// Flattens a schema whose array `items` is a `$ref` into `$defs`, converts
/// the result to `Schema`, and checks that the items became an inline object
/// with properties.
#[test]
fn test_flattened_schema() {
    let ref_schema = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            }
        }
    });

    let flattened = flatten_schema(ref_schema).unwrap();
    let schema: Schema = flattened.try_into().unwrap();

    assert_eq!(schema.r#type, "array");

    // A missing `items` must fail the test; previously the assertions below
    // were skipped silently in that case and the test still passed.
    let Some(items) = schema.items else {
        panic!("Schema should have items field");
    };
    println!("Flattened items type: '{}'", items.r#type);

    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
2719
/// An array property declared without `items` must come out of the
/// conversion with a default `items` schema of type `string`.
#[test]
fn test_array_without_items_gets_default() {
    let schema: Schema = json!({
        "type": "object",
        "properties": {
            "service_ids": {
                "type": "array",
                "description": "A list of service IDs"
            }
        }
    })
    .try_into()
    .unwrap();

    let properties = schema.properties.unwrap();
    let array_prop = properties.get("service_ids").unwrap();
    assert_eq!(array_prop.r#type, "array");

    let default_items = array_prop
        .items
        .as_ref()
        .expect("array schema missing items should get a default");
    assert_eq!(default_items.r#type, "string");
}
2742
/// An object schema with an empty `properties` map (a tool that takes no
/// arguments) converts successfully but yields no schema.
#[test]
fn test_tool_parameters_to_schema_maps_no_arg_tool_to_none() {
    let no_arg_parameters = json!({"type": "object", "properties": {}});

    let converted = tool_parameters_to_schema(no_arg_parameters).expect("schema conversion");

    assert!(converted.is_none());
}
2750
/// A property that points into `$defs` via `$ref` must be resolved to the
/// referenced object definition, including its `required` list.
#[test]
fn test_tool_parameters_to_schema_resolves_defs_ref() {
    let parameters = json!({
        "type": "object",
        "properties": {
            "destination": { "$ref": "#/$defs/Destination" }
        },
        "required": ["destination"],
        "$defs": {
            "Destination": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                },
                "required": ["city"]
            }
        }
    });

    let conversion = tool_parameters_to_schema(parameters).expect("schema conversion");
    let root = conversion.expect("schema");
    let properties = root.properties.expect("properties");
    let destination = properties.get("destination").expect("destination prop");

    assert_eq!(destination.r#type, "object");
    assert_eq!(destination.required, Some(vec!["city".to_string()]));
}
2779
/// A `["null", "string"]` type array must collapse to type `string` with
/// `nullable` set to `Some(true)`.
#[test]
fn test_tool_parameters_to_schema_handles_nullable_type_arrays() {
    let parameters = json!({
        "type": "object",
        "properties": {
            "nickname": { "type": ["null", "string"] }
        }
    });

    let conversion = tool_parameters_to_schema(parameters).expect("schema conversion");
    let root = conversion.expect("schema");
    let properties = root.properties.expect("properties");
    let nickname = properties.get("nickname").expect("nickname prop");

    assert_eq!(nickname.r#type, "string");
    assert_eq!(nickname.nullable, Some(true));
}
2798
/// A TXT document in a user message must be emitted as a plain text part
/// that still contains the document's content.
#[test]
fn test_txt_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    let doc = UserContent::document(
        "Note: test.md\nPath: /test.md\nContent: Hello World!",
        Some(DocumentMediaType::TXT),
    );
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(doc),
    };

    let content: Content = user_message.try_into().unwrap();
    let first_part = &content.parts[0];

    let PartKind::Text(text) = &first_part.part else {
        panic!(
            "Expected text part for TXT document, got: {:?}",
            first_part
        );
    };
    assert!(text.contains("Note: test.md"));
    assert!(text.contains("Hello World!"));
}
2829
/// A tool result mixing text and a base64 PNG must convert into one
/// user-role `FunctionResponse` part: the text lands under the `result` key
/// of `response`, the image becomes a single inline-data entry in `parts`.
#[test]
fn test_tool_result_with_image_content() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    let status_text =
        ToolResultContent::Text(message::Text::new(r#"{"status": "success"}"#.to_string()));
    let png_image = ToolResultContent::Image(Image {
        data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    });
    let tool_result = ToolResult {
        id: "test_tool".to_string(),
        call_id: None,
        content: OneOrMany::many(vec![status_text, png_image])
            .expect("Should create OneOrMany with multiple items"),
    };
    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };

    let content: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };
    assert_eq!(function_response.name, "test_tool");

    // Textual payload.
    assert!(function_response.response.is_some());
    let response_json = function_response.response.as_ref().unwrap();
    assert!(response_json.get("result").is_some());

    // Media payload.
    assert!(function_response.parts.is_some());
    let media_parts = function_response.parts.as_ref().unwrap();
    assert_eq!(media_parts.len(), 1);

    let image_part = &media_parts[0];
    assert!(image_part.inline_data.is_some());
    let inline = image_part.inline_data.as_ref().unwrap();
    assert_eq!(inline.mime_type, "image/png");
    assert!(!inline.data.is_empty());
}
2890
/// An inline MARKDOWN document in a user message must be emitted as a text
/// part whose content is the markdown source, unchanged.
#[test]
fn test_markdown_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    let doc = UserContent::document(
        "# Heading\n\n* List item",
        Some(DocumentMediaType::MARKDOWN),
    );
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(doc),
    };

    let content: Content = user_message.try_into().unwrap();
    let first_part = &content.parts[0];

    let PartKind::Text(text) = &first_part.part else {
        panic!(
            "Expected text part for MARKDOWN document, got: {:?}",
            first_part
        );
    };
    assert_eq!(text, "# Heading\n\n* List item");
}
2920
/// A MARKDOWN document referenced by URL must be emitted as a `FileData`
/// part carrying the URL and the `text/markdown` MIME type.
#[test]
fn test_markdown_url_document_conversion_to_file_data_part() {
    use crate::message::{DocumentMediaType, DocumentSourceKind, UserContent};

    let source = DocumentSourceKind::Url(
        "https://generativelanguage.googleapis.com/v1beta/files/test-markdown".to_string(),
    );
    let doc = UserContent::Document(message::Document {
        data: source,
        media_type: Some(DocumentMediaType::MARKDOWN),
        additional_params: None,
    });
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(doc),
    };

    let content: Content = user_message.try_into().unwrap();
    let first_part = &content.parts[0];

    let PartKind::FileData(file_data) = &first_part.part else {
        panic!(
            "Expected file_data part for URL MARKDOWN document, got: {:?}",
            first_part
        );
    };
    assert_eq!(
        file_data.file_uri,
        "https://generativelanguage.googleapis.com/v1beta/files/test-markdown"
    );
    assert_eq!(file_data.mime_type.as_deref(), Some("text/markdown"));
}
2957
/// A tool result holding an image referenced by URL must convert into a
/// `FunctionResponse` whose single media part is `file_data` with that URL
/// and the PNG MIME type.
#[test]
fn test_tool_result_with_url_image() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    let url_image = Image {
        data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    };
    let tool_result = ToolResult {
        id: "screenshot_tool".to_string(),
        call_id: None,
        content: OneOrMany::one(ToolResultContent::Image(url_image)),
    };
    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };

    let content: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };
    assert_eq!(function_response.name, "screenshot_tool");

    assert!(function_response.parts.is_some());
    let media_parts = function_response.parts.as_ref().unwrap();
    assert_eq!(media_parts.len(), 1);

    let image_part = &media_parts[0];
    assert!(image_part.file_data.is_some());
    let file_ref = image_part.file_data.as_ref().unwrap();
    assert_eq!(file_ref.file_uri, "https://example.com/image.png");
    assert_eq!(file_ref.mime_type.as_ref().unwrap(), "image/png");
}
3007
/// With request documents present, `create_request_body` must emit two
/// contents: first a user-role content holding one text part per document,
/// then the user's actual message.
#[test]
fn test_create_request_body_with_documents() {
    use crate::OneOrMany;
    use crate::completion::request::{CompletionRequest, Document};
    use crate::message::Message;

    let documents = vec![
        Document {
            id: "doc1".to_string(),
            text: "Note: first.md\nContent: First note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
        Document {
            id: "doc2".to_string(),
            text: "Note: second.md\nContent: Second note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
    ];

    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("What are my notes about?")),
        // Moved rather than cloned: the vector is not used again in this test.
        documents,
        tools: vec![],
        temperature: None,
        model: None,
        output_schema: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
    };

    let request = create_request_body(completion_request).unwrap();

    assert_eq!(
        request.contents.len(),
        2,
        "Expected 2 contents (documents + user message)"
    );

    // First content: the documents, one text part each.
    let documents_content = &request.contents[0];
    assert_eq!(documents_content.role, Some(Role::User));
    assert_eq!(
        documents_content.parts.len(),
        2,
        "Expected 2 document parts"
    );

    for part in &documents_content.parts {
        let PartKind::Text(text) = &part.part else {
            panic!("Document parts should be text, not {:?}", part);
        };
        assert!(
            text.contains("Note:") && text.contains("Content:"),
            "Document should contain note metadata"
        );
    }

    // Second content: the user's prompt.
    let prompt_content = &request.contents[1];
    assert_eq!(prompt_content.role, Some(Role::User));

    let PartKind::Text(text) = &prompt_content.parts[0].part else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "What are my notes about?");
}
3086
/// Without request documents, `create_request_body` must emit exactly one
/// content: the user's message as a text part.
#[test]
fn test_create_request_body_without_documents() {
    use crate::OneOrMany;
    use crate::completion::request::CompletionRequest;
    use crate::message::Message;

    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("Hello")),
        documents: vec![],
        tools: vec![],
        temperature: None,
        max_tokens: None,
        tool_choice: None,
        model: None,
        output_schema: None,
        additional_params: None,
    };

    let request = create_request_body(completion_request).unwrap();

    assert_eq!(request.contents.len(), 1, "Expected only user message");
    let only_content = &request.contents[0];
    assert_eq!(only_content.role, Some(Role::User));

    let PartKind::Text(text) = &only_content.parts[0].part else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "Hello");
}
3123
/// Tool output shaped like `{"type": "image", ...}` must parse into a single
/// base64 image with the declared JPEG media type.
#[test]
fn test_from_tool_output_parses_image_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
    let parsed = ToolResultContent::from_tool_output(image_json);

    assert_eq!(parsed.len(), 1);

    let ToolResultContent::Image(img) = parsed.first() else {
        panic!("Expected Image content");
    };

    // One check covers both the variant (must be Base64) and its payload.
    let base64_payload = match &img.data {
        DocumentSourceKind::Base64(data) => Some(data.as_str()),
        _ => None,
    };
    assert_eq!(base64_payload, Some("base64data=="));
    assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
}
3144
/// Hybrid tool output (`response` object plus image `parts`) must parse into
/// three items in order: the response as text, a base64 image, a URL image.
#[test]
fn test_from_tool_output_parses_hybrid_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let hybrid_json = r#"{
            "response": {"status": "ok", "count": 42},
            "parts": [
                {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
                {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
            ]
        }"#;

    let parsed = ToolResultContent::from_tool_output(hybrid_json);
    assert_eq!(parsed.len(), 3);

    // Walk the items in order; the length check above guarantees three.
    let mut contents = parsed.iter();

    let Some(ToolResultContent::Text(text)) = contents.next() else {
        panic!("Expected Text content first");
    };
    assert!(text.text.contains("status"));
    assert!(text.text.contains("ok"));

    let Some(ToolResultContent::Image(inline_image)) = contents.next() else {
        panic!("Expected Image content second");
    };
    assert!(matches!(inline_image.data, DocumentSourceKind::Base64(_)));

    let Some(ToolResultContent::Image(linked_image)) = contents.next() else {
        panic!("Expected Image content third");
    };
    assert!(matches!(linked_image.data, DocumentSourceKind::Url(_)));
}
3187
3188 #[tokio::test]
3192 #[ignore = "requires GEMINI_API_KEY environment variable"]
3193 async fn test_gemini_agent_with_image_tool_result_e2e() -> anyhow::Result<()> {
3194 use crate::completion::Prompt;
3195 use crate::prelude::*;
3196 use crate::providers::gemini;
3197 use crate::test_utils::MockImageGeneratorTool;
3198
3199 let client = gemini::Client::from_env()?;
3200
3201 let agent = client
3202 .agent("gemini-3-flash-preview")
3203 .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
3204 .tool(MockImageGeneratorTool)
3205 .build();
3206
3207 let response_text = agent
3209 .prompt("Please generate a test image and tell me what color the pixel is.")
3210 .await?;
3211 println!("Response: {response_text}");
3212 anyhow::ensure!(!response_text.is_empty(), "Response should not be empty");
3214
3215 Ok(())
3216 }
3217}