/// Model identifier string `gemini-3.1-flash-lite-preview`.
pub const GEMINI_3_1_FLASH_LITE_PREVIEW: &str = "gemini-3.1-flash-lite-preview";
/// Model identifier string `gemini-3-flash-preview`.
pub const GEMINI_3_FLASH_PREVIEW: &str = "gemini-3-flash-preview";
/// Model identifier string `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier string `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier string `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier string `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier string `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier string `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier string `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier string `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
25
26use self::gemini_api_types::Schema;
27use crate::http_client::HttpClientExt;
28use crate::message::{self, MimeType, Reasoning};
29
30use crate::providers::gemini::completion::gemini_api_types::{
31 AdditionalParameters, FunctionCallingMode, ToolConfig,
32};
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::telemetry::SpanCombinator;
35use crate::{
36 OneOrMany,
37 completion::{self, CompletionError, CompletionRequest, GetTokenUsage},
38};
39use gemini_api_types::{
40 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
41 GenerationConfig, Part, PartKind, Role, Tool,
42};
43use serde_json::{Map, Value};
44use std::convert::TryFrom;
45use tracing::{Level, enabled, info_span};
46use tracing_futures::Instrument;
47
48use super::Client;
49
/// A Gemini completion model: a provider client paired with a default model name.
///
/// The stored `model` is only a default; `resolve_request_model` prefers the model named on
/// an individual request when one is present.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send the HTTP requests.
    pub(crate) client: Client<T>,
    /// Default model name used when a request does not name one itself.
    pub model: String,
}
59
60impl<T> CompletionModel<T> {
61 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
62 Self {
63 client,
64 model: model.into(),
65 }
66 }
67
68 pub fn with_model(client: Client<T>, model: &str) -> Self {
69 Self {
70 client,
71 model: model.into(),
72 }
73 }
74}
75
76impl<T> completion::CompletionModel for CompletionModel<T>
77where
78 T: HttpClientExt + Clone + 'static,
79{
80 type Response = GenerateContentResponse;
81 type StreamingResponse = StreamingCompletionResponse;
82 type Client = super::Client<T>;
83
84 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
85 Self::new(client.clone(), model)
86 }
87
88 async fn completion(
89 &self,
90 completion_request: CompletionRequest,
91 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
92 let request_model = resolve_request_model(&self.model, &completion_request);
93 let span = if tracing::Span::current().is_disabled() {
94 info_span!(
95 target: "rig::completions",
96 "generate_content",
97 gen_ai.operation.name = "generate_content",
98 gen_ai.provider.name = "gcp.gemini",
99 gen_ai.request.model = &request_model,
100 gen_ai.system_instructions = &completion_request.preamble,
101 gen_ai.response.id = tracing::field::Empty,
102 gen_ai.response.model = tracing::field::Empty,
103 gen_ai.usage.output_tokens = tracing::field::Empty,
104 gen_ai.usage.input_tokens = tracing::field::Empty,
105 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
106 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
107 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
108 )
109 } else {
110 tracing::Span::current()
111 };
112
113 let request = create_request_body(completion_request)?;
114
115 if enabled!(Level::TRACE) {
116 tracing::trace!(
117 target: "rig::completions",
118 "Gemini completion request: {}",
119 serde_json::to_string_pretty(&request)?
120 );
121 }
122
123 let body = serde_json::to_vec(&request)?;
124
125 let path = completion_endpoint(&request_model);
126
127 let request = self
128 .client
129 .post(path.as_str())?
130 .body(body)
131 .map_err(|e| CompletionError::HttpError(e.into()))?;
132
133 async move {
134 let response = self.client.send::<_, Vec<u8>>(request).await?;
135
136 if response.status().is_success() {
137 let response_body = response
138 .into_body()
139 .await
140 .map_err(CompletionError::HttpError)?;
141
142 let response_text = String::from_utf8_lossy(&response_body).to_string();
143
144 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
145 .map_err(|err| {
146 tracing::error!(
147 error = %err,
148 body = %response_text,
149 "Failed to deserialize Gemini completion response"
150 );
151 CompletionError::JsonError(err)
152 })?;
153
154 let span = tracing::Span::current();
155 span.record_response_metadata(&response);
156 span.record_token_usage(&response.usage_metadata);
157
158 if enabled!(Level::TRACE) {
159 tracing::trace!(
160 target: "rig::completions",
161 "Gemini completion response: {}",
162 serde_json::to_string_pretty(&response)?
163 );
164 }
165
166 response.try_into()
167 } else {
168 let text = String::from_utf8_lossy(
169 &response
170 .into_body()
171 .await
172 .map_err(CompletionError::HttpError)?,
173 )
174 .into();
175
176 Err(CompletionError::ProviderError(text))
177 }
178 }
179 .instrument(span)
180 .await
181 }
182
183 async fn stream(
184 &self,
185 request: CompletionRequest,
186 ) -> Result<
187 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
188 CompletionError,
189 > {
190 CompletionModel::stream(self, request).await
191 }
192}
193
194pub(crate) fn create_request_body(
195 completion_request: CompletionRequest,
196) -> Result<GenerateContentRequest, CompletionError> {
197 let documents_message = completion_request.normalized_documents();
198
199 let CompletionRequest {
200 model: _,
201 preamble,
202 chat_history,
203 documents: _,
204 tools: function_tools,
205 temperature,
206 max_tokens,
207 tool_choice,
208 mut additional_params,
209 output_schema,
210 } = completion_request;
211
212 let mut full_history = Vec::new();
213 if let Some(msg) = documents_message {
214 full_history.push(msg);
215 }
216 full_history.extend(chat_history);
217 let (history_system, full_history) = split_system_messages_from_history(full_history);
218
219 let mut additional_params_payload = additional_params
220 .take()
221 .unwrap_or_else(|| Value::Object(Map::new()));
222 let mut additional_tools =
223 extract_tools_from_additional_params(&mut additional_params_payload)?;
224
225 let AdditionalParameters {
226 mut generation_config,
227 additional_params,
228 } = serde_json::from_value::<AdditionalParameters>(additional_params_payload)?;
229
230 if let Some(schema) = output_schema {
232 let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
233 cfg.response_mime_type = Some("application/json".to_string());
234 cfg.response_json_schema = Some(schema.to_value());
235 }
236
237 generation_config = generation_config.map(|mut cfg| {
238 if let Some(temp) = temperature {
239 cfg.temperature = Some(temp);
240 };
241
242 if let Some(max_tokens) = max_tokens {
243 cfg.max_output_tokens = Some(max_tokens);
244 };
245
246 cfg
247 });
248
249 let mut system_parts: Vec<Part> = Vec::new();
250 if let Some(preamble) = preamble.filter(|preamble| !preamble.is_empty()) {
251 system_parts.push(preamble.into());
252 }
253 for content in history_system {
254 if !content.is_empty() {
255 system_parts.push(content.into());
256 }
257 }
258 let system_instruction = if system_parts.is_empty() {
259 None
260 } else {
261 Some(Content {
262 parts: system_parts,
263 role: Some(Role::Model),
264 })
265 };
266
267 let mut tools = if function_tools.is_empty() {
268 Vec::new()
269 } else {
270 vec![serde_json::to_value(Tool::try_from(function_tools)?)?]
271 };
272 tools.append(&mut additional_tools);
273 let tools = if tools.is_empty() { None } else { Some(tools) };
274
275 let tool_config = if let Some(cfg) = tool_choice {
276 Some(ToolConfig {
277 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
278 })
279 } else {
280 None
281 };
282
283 let request = GenerateContentRequest {
284 contents: full_history
285 .into_iter()
286 .map(|msg| {
287 msg.try_into()
288 .map_err(|e| CompletionError::RequestError(Box::new(e)))
289 })
290 .collect::<Result<Vec<_>, _>>()?,
291 generation_config,
292 safety_settings: None,
293 tools,
294 tool_config,
295 system_instruction,
296 additional_params,
297 };
298
299 Ok(request)
300}
301
302fn split_system_messages_from_history(
303 history: Vec<completion::Message>,
304) -> (Vec<String>, Vec<completion::Message>) {
305 let mut system = Vec::new();
306 let mut remaining = Vec::new();
307
308 for message in history {
309 match message {
310 completion::Message::System { content } => system.push(content),
311 other => remaining.push(other),
312 }
313 }
314
315 (system, remaining)
316}
317
318fn extract_tools_from_additional_params(
319 additional_params: &mut Value,
320) -> Result<Vec<Value>, CompletionError> {
321 if let Some(map) = additional_params.as_object_mut()
322 && let Some(raw_tools) = map.remove("tools")
323 {
324 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
325 CompletionError::RequestError(
326 format!("Invalid Gemini `additional_params.tools` payload: {err}").into(),
327 )
328 });
329 }
330
331 Ok(Vec::new())
332}
333
334pub(crate) fn resolve_request_model(
335 default_model: &str,
336 completion_request: &CompletionRequest,
337) -> String {
338 completion_request
339 .model
340 .clone()
341 .unwrap_or_else(|| default_model.to_string())
342}
343
/// Returns the request path of the non-streaming `generateContent` endpoint for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":generateContent"].concat()
}
347
/// Returns the request path of the `streamGenerateContent` endpoint for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":streamGenerateContent"].concat()
}
351
352impl TryFrom<completion::ToolDefinition> for Tool {
353 type Error = CompletionError;
354
355 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
356 let parameters: Option<Schema> =
357 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
358 None
359 } else {
360 Some(tool.parameters.try_into()?)
361 };
362
363 Ok(Self {
364 function_declarations: vec![FunctionDeclaration {
365 name: tool.name,
366 description: tool.description,
367 parameters,
368 }],
369 code_execution: None,
370 })
371 }
372}
373
374impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
375 type Error = CompletionError;
376
377 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
378 let mut function_declarations = Vec::new();
379
380 for tool in tools {
381 let parameters =
382 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
383 None
384 } else {
385 match tool.parameters.try_into() {
386 Ok(schema) => Some(schema),
387 Err(e) => {
388 let emsg = format!(
389 "Tool '{}' could not be converted to a schema: {:?}",
390 tool.name, e,
391 );
392 return Err(CompletionError::ProviderError(emsg));
393 }
394 }
395 };
396
397 function_declarations.push(FunctionDeclaration {
398 name: tool.name,
399 description: tool.description,
400 parameters,
401 });
402 }
403
404 Ok(Self {
405 function_declarations,
406 code_execution: None,
407 })
408 }
409}
410
411impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
412 type Error = CompletionError;
413
414 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
415 let candidate = response.candidates.first().ok_or_else(|| {
416 CompletionError::ResponseError("No response candidates in response".into())
417 })?;
418
419 let content = candidate
420 .content
421 .as_ref()
422 .ok_or_else(|| {
423 let reason = candidate
424 .finish_reason
425 .as_ref()
426 .map(|r| format!("finish_reason={r:?}"))
427 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
428 let message = candidate
429 .finish_message
430 .as_deref()
431 .unwrap_or("no finish message provided");
432 CompletionError::ResponseError(format!(
433 "Gemini candidate missing content ({reason}, finish_message={message})"
434 ))
435 })?
436 .parts
437 .iter()
438 .map(
439 |Part {
440 thought,
441 thought_signature,
442 part,
443 ..
444 }| {
445 Ok(match part {
446 PartKind::Text(text) => {
447 if let Some(thought) = thought
448 && *thought
449 {
450 completion::AssistantContent::Reasoning(
451 Reasoning::new_with_signature(text, thought_signature.clone()),
452 )
453 } else {
454 completion::AssistantContent::text(text)
455 }
456 }
457 PartKind::InlineData(inline_data) => {
458 let mime_type =
459 message::MediaType::from_mime_type(&inline_data.mime_type);
460
461 match mime_type {
462 Some(message::MediaType::Image(media_type)) => {
463 message::AssistantContent::image_base64(
464 &inline_data.data,
465 Some(media_type),
466 Some(message::ImageDetail::default()),
467 )
468 }
469 _ => {
470 return Err(CompletionError::ResponseError(format!(
471 "Unsupported media type {mime_type:?}"
472 )));
473 }
474 }
475 }
476 PartKind::FunctionCall(function_call) => {
477 completion::AssistantContent::ToolCall(
478 message::ToolCall::new(
479 function_call.name.clone(),
480 message::ToolFunction::new(
481 function_call.name.clone(),
482 function_call.args.clone(),
483 ),
484 )
485 .with_signature(thought_signature.clone()),
486 )
487 }
488 _ => {
489 return Err(CompletionError::ResponseError(
490 "Response did not contain a message or tool call".into(),
491 ));
492 }
493 })
494 },
495 )
496 .collect::<Result<Vec<_>, _>>()?;
497
498 let choice = OneOrMany::many(content).map_err(|_| {
499 CompletionError::ResponseError(
500 "Response contained no message or tool call (empty)".to_owned(),
501 )
502 })?;
503
504 let usage = response
505 .usage_metadata
506 .as_ref()
507 .and_then(GetTokenUsage::token_usage)
508 .unwrap_or_default();
509
510 Ok(completion::CompletionResponse {
511 choice,
512 usage,
513 raw_response: response,
514 message_id: None,
515 })
516 }
517}
518
519pub mod gemini_api_types {
520 use crate::telemetry::ProviderResponseExt;
521 use std::{collections::HashMap, convert::Infallible, str::FromStr};
522
523 use serde::{Deserialize, Serialize};
527 use serde_json::{Value, json};
528
529 use crate::completion::GetTokenUsage;
530 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
531 use crate::{
532 completion::CompletionError,
533 message::{self},
534 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
535 };
536
/// Gemini-specific options deserialized from a request's `additional_params` JSON.
///
/// Keys are camelCase on the wire (`generationConfig`).
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct AdditionalParameters {
    /// Generation settings, when the caller supplied any.
    pub generation_config: Option<GenerationConfig>,
    /// All remaining keys, flattened into the same JSON object as `generationConfig`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
546
547 impl AdditionalParameters {
548 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
549 self.generation_config = Some(cfg);
550 self
551 }
552
553 pub fn with_params(mut self, params: serde_json::Value) -> Self {
554 self.additional_params = Some(params);
555 self
556 }
557 }
558
/// Response body of a Gemini content-generation call.
///
/// Field names are camelCase on the wire.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Identifier of this response; exposed through `get_response_id`.
    pub response_id: String,
    /// Candidate replies produced by the model.
    pub candidates: Vec<ContentCandidate>,
    /// Prompt-level feedback, when the API returned any.
    pub prompt_feedback: Option<PromptFeedback>,
    /// Token usage metadata, when the API returned any.
    pub usage_metadata: Option<UsageMetadata>,
    /// Model version string; exposed through `get_response_model_name`.
    pub model_version: Option<String>,
}
578
579 impl ProviderResponseExt for GenerateContentResponse {
580 type OutputMessage = ContentCandidate;
581 type Usage = UsageMetadata;
582
583 fn get_response_id(&self) -> Option<String> {
584 Some(self.response_id.clone())
585 }
586
587 fn get_response_model_name(&self) -> Option<String> {
588 self.model_version.clone()
589 }
590
591 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
592 self.candidates.clone()
593 }
594
595 fn get_text_response(&self) -> Option<String> {
596 let str = self
597 .candidates
598 .iter()
599 .filter_map(|x| {
600 let content = x.content.as_ref()?;
601 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
602 return None;
603 }
604
605 let res = content
606 .parts
607 .iter()
608 .filter_map(|part| {
609 if let PartKind::Text(ref str) = part.part {
610 Some(str.to_owned())
611 } else {
612 None
613 }
614 })
615 .collect::<Vec<String>>()
616 .join("\n");
617
618 Some(res)
619 })
620 .collect::<Vec<String>>()
621 .join("\n");
622
623 if str.is_empty() { None } else { Some(str) }
624 }
625
626 fn get_usage(&self) -> Option<Self::Usage> {
627 self.usage_metadata.clone()
628 }
629 }
630
/// One candidate reply inside a [`GenerateContentResponse`].
///
/// Field names are camelCase on the wire.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContentCandidate {
    /// Generated content; may be absent, in which case `finish_reason` /
    /// `finish_message` are used to build the error reported to the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,
    /// Why generation stopped, if reported.
    pub finish_reason: Option<FinishReason>,
    /// Safety ratings attached to the candidate, if reported.
    pub safety_ratings: Option<Vec<SafetyRating>>,
    /// Citation metadata for the candidate, if reported.
    pub citation_metadata: Option<CitationMetadata>,
    /// Token count for the candidate, if reported.
    pub token_count: Option<i32>,
    /// Average log-probability score, if reported.
    pub avg_logprobs: Option<f64>,
    /// Detailed log-probability data, if reported.
    pub logprobs_result: Option<LogprobsResult>,
    /// Index of the candidate in the response, if reported.
    pub index: Option<i32>,
    /// Free-form message accompanying the finish reason, if reported.
    pub finish_message: Option<String>,
}
659
/// A message in Gemini's wire format: an ordered list of parts plus an optional role.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Content {
    /// The message's parts; defaults to an empty list when the key is missing.
    #[serde(default)]
    pub parts: Vec<Part>,
    /// Producer of the content (`user` or `model`), when specified.
    pub role: Option<Role>,
}
669
670 impl TryFrom<message::Message> for Content {
671 type Error = message::MessageError;
672
673 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
674 Ok(match msg {
675 message::Message::System { content } => Content {
676 parts: vec![content.into()],
677 role: Some(Role::User),
678 },
679 message::Message::User { content } => Content {
680 parts: content
681 .into_iter()
682 .map(|c| c.try_into())
683 .collect::<Result<Vec<_>, _>>()?,
684 role: Some(Role::User),
685 },
686 message::Message::Assistant { content, .. } => Content {
687 role: Some(Role::Model),
688 parts: content
689 .into_iter()
690 .map(|content| content.try_into())
691 .collect::<Result<Vec<_>, _>>()?,
692 },
693 })
694 }
695 }
696
/// Producer of a [`Content`] value; serialized in lowercase (`user`, `model`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Content supplied by the caller.
    User,
    /// Content produced by the model.
    Model,
}
703
/// One part of a [`Content`] message.
///
/// The payload (`part`) and any extra parameters are flattened into the same JSON object
/// as the `thought` / `thoughtSignature` keys.
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    /// Whether this part is a model "thought"; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought: Option<bool>,
    /// Signature accompanying a thought or function call; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
    /// The actual payload of the part.
    #[serde(flatten)]
    pub part: PartKind,
    /// Extra keys merged into the part's JSON object; omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<Value>,
}
718
/// The payload carried by a [`Part`].
///
/// Variant names are camelCase on the wire (`text`, `inlineData`, `functionCall`, ...).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum PartKind {
    /// Plain text.
    Text(String),
    /// Media embedded directly in the message.
    InlineData(Blob),
    /// A function call requested by the model.
    FunctionCall(FunctionCall),
    /// The result of a function call, sent back to the model.
    FunctionResponse(FunctionResponse),
    /// Media referenced by URI.
    FileData(FileData),
    /// Executable code payload.
    ExecutableCode(ExecutableCode),
    /// Result of a code execution.
    CodeExecutionResult(CodeExecutionResult),
}
733
734 impl Default for PartKind {
737 fn default() -> Self {
738 Self::Text(String::new())
739 }
740 }
741
742 impl From<String> for Part {
743 fn from(text: String) -> Self {
744 Self {
745 thought: Some(false),
746 thought_signature: None,
747 part: PartKind::Text(text),
748 additional_params: None,
749 }
750 }
751 }
752
753 impl From<&str> for Part {
754 fn from(text: &str) -> Self {
755 Self::from(text.to_string())
756 }
757 }
758
759 impl FromStr for Part {
760 type Err = Infallible;
761
762 fn from_str(s: &str) -> Result<Self, Self::Err> {
763 Ok(s.into())
764 }
765 }
766
767 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
768 type Error = message::MessageError;
769 fn try_from(
770 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
771 ) -> Result<Self, Self::Error> {
772 let mime_type = mime_type.to_mime_type().to_string();
773 let part = match doc_src {
774 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
775 mime_type: Some(mime_type),
776 file_uri: url,
777 }),
778 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
779 PartKind::InlineData(Blob { mime_type, data })
780 }
781 DocumentSourceKind::Raw(_) => {
782 return Err(message::MessageError::ConversionError(
783 "Raw files not supported, encode as base64 first".into(),
784 ));
785 }
786 DocumentSourceKind::FileId(_) => {
787 return Err(message::MessageError::ConversionError(
788 "Provider file IDs are not supported for Gemini image inputs".into(),
789 ));
790 }
791 DocumentSourceKind::Unknown => {
792 return Err(message::MessageError::ConversionError(
793 "Can't convert an unknown document source".to_string(),
794 ));
795 }
796 };
797
798 Ok(part)
799 }
800 }
801
802 impl TryFrom<message::UserContent> for Part {
803 type Error = message::MessageError;
804
805 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
806 match content {
807 message::UserContent::Text(message::Text { text }) => Ok(Part {
808 thought: Some(false),
809 thought_signature: None,
810 part: PartKind::Text(text),
811 additional_params: None,
812 }),
813 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
814 let mut response_json: Option<serde_json::Value> = None;
815 let mut parts: Vec<FunctionResponsePart> = Vec::new();
816
817 for item in content.iter() {
818 match item {
819 message::ToolResultContent::Text(text) => {
820 let result: serde_json::Value =
821 serde_json::from_str(&text.text).unwrap_or_else(|error| {
822 tracing::trace!(
823 ?error,
824 "Tool result is not a valid JSON, treat it as normal string"
825 );
826 json!(&text.text)
827 });
828
829 response_json = Some(match response_json {
830 Some(mut existing) => {
831 if let serde_json::Value::Object(ref mut map) = existing {
832 map.insert("text".to_string(), result);
833 }
834 existing
835 }
836 None => json!({ "result": result }),
837 });
838 }
839 message::ToolResultContent::Image(image) => {
840 let part = match &image.data {
841 DocumentSourceKind::Base64(b64) => {
842 let mime_type = image
843 .media_type
844 .as_ref()
845 .ok_or(message::MessageError::ConversionError(
846 "Image media type is required for Gemini tool results".to_string(),
847 ))?
848 .to_mime_type();
849
850 FunctionResponsePart {
851 inline_data: Some(FunctionResponseInlineData {
852 mime_type: mime_type.to_string(),
853 data: b64.clone(),
854 display_name: None,
855 }),
856 file_data: None,
857 }
858 }
859 DocumentSourceKind::Url(url) => {
860 let mime_type = image
861 .media_type
862 .as_ref()
863 .map(|mt| mt.to_mime_type().to_string());
864
865 FunctionResponsePart {
866 inline_data: None,
867 file_data: Some(FileData {
868 mime_type,
869 file_uri: url.clone(),
870 }),
871 }
872 }
873 _ => {
874 return Err(message::MessageError::ConversionError(
875 "Unsupported image source kind for tool results"
876 .to_string(),
877 ));
878 }
879 };
880 parts.push(part);
881 }
882 }
883 }
884
885 Ok(Part {
886 thought: Some(false),
887 thought_signature: None,
888 part: PartKind::FunctionResponse(FunctionResponse {
889 name: id,
890 response: response_json,
891 parts: if parts.is_empty() { None } else { Some(parts) },
892 }),
893 additional_params: None,
894 })
895 }
896 message::UserContent::Image(message::Image {
897 data, media_type, ..
898 }) => match media_type {
899 Some(media_type) => match media_type {
900 message::ImageMediaType::JPEG
901 | message::ImageMediaType::PNG
902 | message::ImageMediaType::WEBP
903 | message::ImageMediaType::HEIC
904 | message::ImageMediaType::HEIF => {
905 let part = PartKind::try_from((media_type, data))?;
906 Ok(Part {
907 thought: Some(false),
908 thought_signature: None,
909 part,
910 additional_params: None,
911 })
912 }
913 _ => Err(message::MessageError::ConversionError(format!(
914 "Unsupported image media type {media_type:?}"
915 ))),
916 },
917 None => Err(message::MessageError::ConversionError(
918 "Media type for image is required for Gemini".to_string(),
919 )),
920 },
921 message::UserContent::Document(message::Document {
922 data, media_type, ..
923 }) => {
924 let Some(media_type) = media_type else {
925 return Err(MessageError::ConversionError(
926 "A mime type is required for document inputs to Gemini".to_string(),
927 ));
928 };
929
930 if matches!(
933 media_type,
934 message::DocumentMediaType::TXT
935 | message::DocumentMediaType::RTF
936 | message::DocumentMediaType::HTML
937 | message::DocumentMediaType::CSS
938 | message::DocumentMediaType::MARKDOWN
939 | message::DocumentMediaType::CSV
940 | message::DocumentMediaType::XML
941 | message::DocumentMediaType::Javascript
942 | message::DocumentMediaType::Python
943 ) {
944 use base64::Engine;
945 let part = match data {
946 DocumentSourceKind::String(text) => PartKind::Text(text),
947 DocumentSourceKind::Base64(data) => {
948 let text = String::from_utf8(
950 base64::engine::general_purpose::STANDARD
951 .decode(&data)
952 .map_err(|e| {
953 MessageError::ConversionError(format!(
954 "Failed to decode base64: {e}"
955 ))
956 })?,
957 )
958 .map_err(|e| {
959 MessageError::ConversionError(format!(
960 "Invalid UTF-8 in document: {e}"
961 ))
962 })?;
963 PartKind::Text(text)
964 }
965 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
966 mime_type: Some(media_type.to_mime_type().to_string()),
967 file_uri,
968 }),
969 DocumentSourceKind::Raw(_) => {
970 return Err(MessageError::ConversionError(
971 "Raw files not supported, encode as base64 first".to_string(),
972 ));
973 }
974 DocumentSourceKind::FileId(_) => {
975 return Err(MessageError::ConversionError(
976 "Provider file IDs are not supported for Gemini documents"
977 .to_string(),
978 ));
979 }
980 DocumentSourceKind::Unknown => {
981 return Err(MessageError::ConversionError(
982 "Document has no body".to_string(),
983 ));
984 }
985 };
986
987 Ok(Part {
988 thought: Some(false),
989 part,
990 ..Default::default()
991 })
992 } else if !media_type.is_code() {
993 let mime_type = media_type.to_mime_type().to_string();
994
995 let part = match data {
996 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
997 mime_type: Some(mime_type),
998 file_uri,
999 }),
1000 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
1001 PartKind::InlineData(Blob { mime_type, data })
1002 }
1003 DocumentSourceKind::Raw(_) => {
1004 return Err(message::MessageError::ConversionError(
1005 "Raw files not supported, encode as base64 first".into(),
1006 ));
1007 }
1008 _ => {
1009 return Err(message::MessageError::ConversionError(
1010 "Document has no body".to_string(),
1011 ));
1012 }
1013 };
1014
1015 Ok(Part {
1016 thought: Some(false),
1017 part,
1018 ..Default::default()
1019 })
1020 } else {
1021 Err(message::MessageError::ConversionError(format!(
1022 "Unsupported document media type {media_type:?}"
1023 )))
1024 }
1025 }
1026
1027 message::UserContent::Audio(message::Audio {
1028 data, media_type, ..
1029 }) => {
1030 let Some(media_type) = media_type else {
1031 return Err(MessageError::ConversionError(
1032 "A mime type is required for audio inputs to Gemini".to_string(),
1033 ));
1034 };
1035
1036 let mime_type = media_type.to_mime_type().to_string();
1037
1038 let part = match data {
1039 DocumentSourceKind::Base64(data) => {
1040 PartKind::InlineData(Blob { data, mime_type })
1041 }
1042
1043 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1044 mime_type: Some(mime_type),
1045 file_uri,
1046 }),
1047 DocumentSourceKind::String(_) => {
1048 return Err(message::MessageError::ConversionError(
1049 "Strings cannot be used as audio files!".into(),
1050 ));
1051 }
1052 DocumentSourceKind::Raw(_) => {
1053 return Err(message::MessageError::ConversionError(
1054 "Raw files not supported, encode as base64 first".into(),
1055 ));
1056 }
1057 DocumentSourceKind::FileId(_) => {
1058 return Err(message::MessageError::ConversionError(
1059 "Provider file IDs are not supported for Gemini audio inputs"
1060 .into(),
1061 ));
1062 }
1063 DocumentSourceKind::Unknown => {
1064 return Err(message::MessageError::ConversionError(
1065 "Content has no body".to_string(),
1066 ));
1067 }
1068 };
1069
1070 Ok(Part {
1071 thought: Some(false),
1072 part,
1073 ..Default::default()
1074 })
1075 }
1076 message::UserContent::Video(message::Video {
1077 data,
1078 media_type,
1079 additional_params,
1080 ..
1081 }) => {
1082 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
1083
1084 let part = match data {
1085 DocumentSourceKind::Url(file_uri) => {
1086 if file_uri.starts_with("https://www.youtube.com") {
1087 PartKind::FileData(FileData {
1088 mime_type,
1089 file_uri,
1090 })
1091 } else {
1092 if mime_type.is_none() {
1093 return Err(MessageError::ConversionError(
1094 "A mime type is required for non-Youtube video file inputs to Gemini"
1095 .to_string(),
1096 ));
1097 }
1098
1099 PartKind::FileData(FileData {
1100 mime_type,
1101 file_uri,
1102 })
1103 }
1104 }
1105 DocumentSourceKind::Base64(data) => {
1106 let Some(mime_type) = mime_type else {
1107 return Err(MessageError::ConversionError(
1108 "A media type is expected for base64 encoded strings"
1109 .to_string(),
1110 ));
1111 };
1112 PartKind::InlineData(Blob { mime_type, data })
1113 }
1114 DocumentSourceKind::String(_) => {
1115 return Err(message::MessageError::ConversionError(
1116 "Strings cannot be used as audio files!".into(),
1117 ));
1118 }
1119 DocumentSourceKind::Raw(_) => {
1120 return Err(message::MessageError::ConversionError(
1121 "Raw file data not supported, encode as base64 first".into(),
1122 ));
1123 }
1124 DocumentSourceKind::FileId(_) => {
1125 return Err(message::MessageError::ConversionError(
1126 "Provider file IDs are not supported for Gemini video inputs"
1127 .into(),
1128 ));
1129 }
1130 DocumentSourceKind::Unknown => {
1131 return Err(message::MessageError::ConversionError(
1132 "Media type for video is required for Gemini".to_string(),
1133 ));
1134 }
1135 };
1136
1137 Ok(Part {
1138 thought: Some(false),
1139 thought_signature: None,
1140 part,
1141 additional_params,
1142 })
1143 }
1144 }
1145 }
1146 }
1147
1148 impl TryFrom<message::AssistantContent> for Part {
1149 type Error = message::MessageError;
1150
1151 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1152 match content {
1153 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1154 message::AssistantContent::Image(message::Image {
1155 data, media_type, ..
1156 }) => match media_type {
1157 Some(media_type) => match media_type {
1158 message::ImageMediaType::JPEG
1159 | message::ImageMediaType::PNG
1160 | message::ImageMediaType::WEBP
1161 | message::ImageMediaType::HEIC
1162 | message::ImageMediaType::HEIF => {
1163 let part = PartKind::try_from((media_type, data))?;
1164 Ok(Part {
1165 thought: Some(false),
1166 thought_signature: None,
1167 part,
1168 additional_params: None,
1169 })
1170 }
1171 _ => Err(message::MessageError::ConversionError(format!(
1172 "Unsupported image media type {media_type:?}"
1173 ))),
1174 },
1175 None => Err(message::MessageError::ConversionError(
1176 "Media type for image is required for Gemini".to_string(),
1177 )),
1178 },
1179 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1180 message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1181 thought: Some(true),
1182 thought_signature: reasoning.first_signature().map(str::to_owned),
1183 part: PartKind::Text(reasoning.display_text()),
1184 additional_params: None,
1185 }),
1186 }
1187 }
1188 }
1189
1190 impl From<message::ToolCall> for Part {
1191 fn from(tool_call: message::ToolCall) -> Self {
1192 Self {
1193 thought: Some(false),
1194 thought_signature: tool_call.signature,
1195 part: PartKind::FunctionCall(FunctionCall {
1196 name: tool_call.function.name,
1197 args: tool_call.function.arguments,
1198 }),
1199 additional_params: None,
1200 }
1201 }
1202 }
1203
/// Media embedded directly in a message (`inlineData`); keys are camelCase on the wire.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
    /// MIME type of the payload.
    pub mime_type: String,
    /// The payload as a string — presumably base64-encoded; confirm against the Gemini API.
    pub data: String,
}
1215
/// A function call: the function's name and its JSON arguments.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments for the call, as a JSON value.
    pub args: serde_json::Value,
}
1226
1227 impl From<message::ToolCall> for FunctionCall {
1228 fn from(tool_call: message::ToolCall) -> Self {
1229 Self {
1230 name: tool_call.function.name,
1231 args: tool_call.function.arguments,
1232 }
1233 }
1234 }
1235
    /// The result of a function call, identified by the function's name.
    ///
    /// Both payload fields are optional and are left out of the serialized JSON
    /// when they are `None`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name of the function this response belongs to.
        pub name: String,
        /// Response payload as arbitrary JSON; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        /// Additional inline-data / file-data parts attached to the response; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }
1251
    /// One media part attached to a [`FunctionResponse`].
    ///
    /// Keys are (de)serialized in camelCase (`inlineData`, `fileData`); each field is
    /// omitted from the output when `None`. The struct itself does not enforce that
    /// exactly one of the two is set.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        /// Media embedded directly in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        /// Media referenced by URI.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }
1263
    /// Inline media carried inside a [`FunctionResponsePart`].
    ///
    /// Keys are (de)serialized in camelCase (`mimeType`, `data`, `displayName`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// MIME type of `data`.
        pub mime_type: String,
        /// The media content as a string; presumably base64-encoded — TODO confirm.
        pub data: String,
        /// Optional display name; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1276
    /// A reference to media stored elsewhere, addressed by URI.
    ///
    /// Keys are (de)serialized in camelCase (`mimeType`, `fileUri`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file, if known. There is no `skip_serializing_if`
        /// here, so a `None` is serialized as an explicit `null`.
        pub mime_type: Option<String>,
        /// URI of the file.
        pub file_uri: String,
    }
1286
    /// A safety rating: a harm category paired with the probability assigned to it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category being rated.
        pub category: HarmCategory,
        /// The probability level for that category.
        pub probability: HarmProbability,
    }
1292
    /// Probability level attached to a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_PROBABILITY_UNSPECIFIED`, `NEGLIGIBLE`. There is no catch-all variant, so
    /// any other string fails to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1302
    /// Harm categories used by [`SafetyRating`] and [`SafetySetting`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_CATEGORY_HATE_SPEECH`. There is no catch-all variant, so any other string
    /// fails to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1319
1320 #[derive(Debug, Deserialize, Clone, Default, Serialize)]
1321 #[serde(rename_all = "camelCase")]
1322 pub struct UsageMetadata {
1323 #[serde(default)]
1324 pub prompt_token_count: i32,
1325 #[serde(skip_serializing_if = "Option::is_none")]
1326 pub cached_content_token_count: Option<i32>,
1327 #[serde(skip_serializing_if = "Option::is_none")]
1328 pub candidates_token_count: Option<i32>,
1329 pub total_token_count: i32,
1330 #[serde(skip_serializing_if = "Option::is_none")]
1331 pub thoughts_token_count: Option<i32>,
1332 #[serde(default, skip_serializing_if = "Option::is_none")]
1333 pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
1334 #[serde(default, skip_serializing_if = "Option::is_none")]
1335 pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
1336 #[serde(default, skip_serializing_if = "Option::is_none")]
1337 pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
1338 #[serde(default, skip_serializing_if = "Option::is_none")]
1339 pub tool_use_prompt_token_count: Option<i32>,
1340 #[serde(default, skip_serializing_if = "Option::is_none")]
1341 pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
1342 #[serde(default, skip_serializing_if = "Option::is_none")]
1343 pub traffic_type: Option<TrafficType>,
1344 }
1345
    /// Token count for a single [`Modality`].
    ///
    /// Keys are (de)serialized in camelCase (`modality`, `tokenCount`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ModalityTokenCount {
        /// The modality these tokens belong to.
        pub modality: Modality,
        /// Number of tokens for that modality.
        pub token_count: i32,
    }
1352
    /// Content modality used in per-modality token breakdowns.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `MODALITY_UNSPECIFIED`,
    /// `TEXT`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum Modality {
        ModalityUnspecified,
        Text,
        Image,
        Video,
        Audio,
        Document,
    }
1363
    /// Traffic type reported in [`UsageMetadata`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `TRAFFIC_TYPE_UNSPECIFIED`, `ON_DEMAND`, `PROVISIONED_THROUGHPUT`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TrafficType {
        TrafficTypeUnspecified,
        OnDemand,
        ProvisionedThroughput,
    }
1371
1372 impl std::fmt::Display for UsageMetadata {
1373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1374 write!(
1375 f,
1376 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1377 self.prompt_token_count,
1378 match self.cached_content_token_count {
1379 Some(count) => count.to_string(),
1380 None => "n/a".to_string(),
1381 },
1382 match self.candidates_token_count {
1383 Some(count) => count.to_string(),
1384 None => "n/a".to_string(),
1385 },
1386 self.total_token_count
1387 )
1388 }
1389 }
1390
1391 impl GetTokenUsage for UsageMetadata {
1392 fn token_usage(&self) -> Option<crate::completion::Usage> {
1393 let mut usage = crate::completion::Usage::new();
1394
1395 usage.input_tokens = self.prompt_token_count as u64;
1396 usage.output_tokens = self.candidates_token_count.unwrap_or_default() as u64;
1397 usage.cached_input_tokens = self.cached_content_token_count.unwrap_or_default() as u64;
1398 usage.reasoning_tokens = self.thoughts_token_count.unwrap_or_default() as u64;
1399 usage.total_tokens = self.total_token_count as u64;
1400
1401 Some(usage)
1402 }
1403 }
1404
    /// Feedback about the prompt: an optional block reason and optional safety ratings.
    ///
    /// Keys are (de)serialized in camelCase (`blockReason`, `safetyRatings`). Neither
    /// field has `skip_serializing_if`, so `None` is serialized as `null`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked, if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt, if provided.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1414
    /// Reason a prompt was blocked, as carried by [`PromptFeedback`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `BLOCK_REASON_UNSPECIFIED`, `PROHIBITED_CONTENT`. There is no catch-all variant,
    /// so any other string fails to deserialize.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1430
    /// Reason a candidate stopped generating.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `STOP`, `MAX_TOKENS`,
    /// `MALFORMED_FUNCTION_CALL`. There is no catch-all variant, so any other string
    /// fails to deserialize.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
1457
    /// A collection of citation sources.
    ///
    /// The single field is (de)serialized as `citationSources` and is required when
    /// deserializing.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The individual citations.
        pub citation_sources: Vec<CitationSource>,
    }
1463
    /// A single citation. Every field is optional and omitted from the output when `None`.
    ///
    /// Keys are (de)serialized in camelCase (`uri`, `startIndex`, `endIndex`, `license`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start index of the cited span; unit and inclusivity are not visible here — TODO confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End index of the cited span; unit and inclusivity are not visible here — TODO confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License associated with the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1476
1477 #[derive(Clone, Debug, Deserialize, Serialize)]
1478 #[serde(rename_all = "camelCase")]
1479 pub struct LogprobsResult {
1480 pub top_candidate: Vec<TopCandidate>,
1481 pub chosen_candidate: Vec<LogProbCandidate>,
1482 }
1483
    /// A list of log-probability candidates, used as one element of
    /// [`LogprobsResult::top_candidate`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// The candidates in this group.
        pub candidates: Vec<LogProbCandidate>,
    }
1488
1489 #[derive(Clone, Debug, Deserialize, Serialize)]
1490 #[serde(rename_all = "camelCase")]
1491 pub struct LogProbCandidate {
1492 pub token: String,
1493 pub token_id: String,
1494 pub log_probability: f64,
1495 }
1496
    /// Generation settings sent with a request.
    ///
    /// Keys are (de)serialized in camelCase (e.g. `maxOutputTokens`), except
    /// `_response_json_schema`, which is explicitly renamed to `_responseJsonSchema`.
    /// Every field is optional and omitted from the serialized output when `None`.
    /// The `Default` impl in this file sets only `temperature` and `max_output_tokens`.
    /// The exact meaning of each knob is defined by the Gemini API; the field comments
    /// below only restate what the names and types show.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Sequences that stop generation.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type requested for the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Response schema expressed with this module's typed [`Schema`].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Response schema as raw JSON, sent under the key `_responseJsonSchema`.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Response schema as raw JSON, sent under the key `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of candidates to generate.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Upper bound on output tokens.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Top-p sampling parameter.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling parameter.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether log probabilities should be included in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Log-probability count setting; presumably how many top candidates to return — TODO confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking-related settings.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image-output settings.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1589
1590 impl Default for GenerationConfig {
1591 fn default() -> Self {
1592 Self {
1593 temperature: Some(1.0),
1594 max_output_tokens: Some(4096),
1595 stop_sequences: None,
1596 response_mime_type: None,
1597 response_schema: None,
1598 _response_json_schema: None,
1599 response_json_schema: None,
1600 candidate_count: None,
1601 top_p: None,
1602 top_k: None,
1603 presence_penalty: None,
1604 frequency_penalty: None,
1605 response_logprobs: None,
1606 logprobs: None,
1607 thinking_config: None,
1608 image_config: None,
1609 }
1610 }
1611 }
1612
    /// Thinking level selectable through [`ThinkingConfig::thinking_level`].
    ///
    /// Variants are (de)serialized in snake_case: `minimal`, `low`, `medium`, `high`.
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum ThinkingLevel {
        Minimal,
        Low,
        Medium,
        High,
    }
1622
    /// Thinking-related settings nested inside [`GenerationConfig`].
    ///
    /// Keys are (de)serialized in camelCase (`thinkingBudget`, `thinkingLevel`,
    /// `includeThoughts`); each is omitted from the output when `None`. The struct does
    /// not itself prevent setting both a budget and a level.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Thinking budget; presumably a token count — TODO confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_budget: Option<u32>,
        /// Thinking level.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_level: Option<ThinkingLevel>,
        /// Whether thoughts should be included in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub include_thoughts: Option<bool>,
    }
1639
    /// Image-output settings nested inside [`GenerationConfig`].
    ///
    /// Keys are (de)serialized in camelCase (`aspectRatio`, `imageSize`); each is
    /// omitted from the output when `None`. Both values are free-form strings here;
    /// the accepted values are defined by the API, not by this type.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1648
    /// Typed representation of a schema, covering a fixed set of JSON-Schema-style keywords.
    ///
    /// Instances are usually built from arbitrary JSON through `TryFrom<Value>` (defined
    /// in this module). There is no `rename_all`, so multi-word fields keep their
    /// snake_case names on the wire (`max_items`, `min_items`). Every optional field is
    /// omitted from the serialized output when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// The schema's type name. May be an empty string when conversion from JSON
        /// could not infer a type.
        pub r#type: String,
        /// Format hint for the type.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed values; only string values are representable.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Schemas of the object's properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of the required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of the array's items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1674
1675 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1681 let defs = if let Some(obj) = schema.as_object() {
1683 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1684 } else {
1685 None
1686 };
1687
1688 let Some(defs_value) = defs else {
1689 return Ok(schema);
1690 };
1691
1692 let Some(defs_obj) = defs_value.as_object() else {
1693 return Err(CompletionError::ResponseError(
1694 "$defs must be an object".into(),
1695 ));
1696 };
1697
1698 resolve_refs(&mut schema, defs_obj)?;
1699
1700 if let Some(obj) = schema.as_object_mut() {
1702 obj.remove("$defs");
1703 obj.remove("definitions");
1704 }
1705
1706 Ok(schema)
1707 }
1708
1709 fn resolve_refs(
1712 value: &mut Value,
1713 defs: &serde_json::Map<String, Value>,
1714 ) -> Result<(), CompletionError> {
1715 match value {
1716 Value::Object(obj) => {
1717 if let Some(ref_value) = obj.get("$ref")
1718 && let Some(ref_str) = ref_value.as_str()
1719 {
1720 let def_name = parse_ref_path(ref_str)?;
1722
1723 let def = defs.get(&def_name).ok_or_else(|| {
1724 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1725 })?;
1726
1727 let mut resolved = def.clone();
1728 resolve_refs(&mut resolved, defs)?;
1729 *value = resolved;
1730 return Ok(());
1731 }
1732
1733 for (_, v) in obj.iter_mut() {
1734 resolve_refs(v, defs)?;
1735 }
1736 }
1737 Value::Array(arr) => {
1738 for item in arr.iter_mut() {
1739 resolve_refs(item, defs)?;
1740 }
1741 }
1742 _ => {}
1743 }
1744
1745 Ok(())
1746 }
1747
1748 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1754 if let Some(fragment) = ref_str.strip_prefix('#') {
1755 if let Some(name) = fragment.strip_prefix("/$defs/") {
1756 Ok(name.to_string())
1757 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1758 Ok(name.to_string())
1759 } else {
1760 Err(CompletionError::ResponseError(format!(
1761 "Unsupported reference format: {}",
1762 ref_str
1763 )))
1764 }
1765 } else {
1766 Err(CompletionError::ResponseError(format!(
1767 "Only fragment references (#/...) are supported: {}",
1768 ref_str
1769 )))
1770 }
1771 }
1772
1773 fn extract_type(type_value: &Value) -> Option<String> {
1776 if type_value.is_string() {
1777 type_value.as_str().map(String::from)
1778 } else if type_value.is_array() {
1779 type_value
1780 .as_array()
1781 .and_then(|arr| arr.first())
1782 .and_then(|v| v.as_str().map(String::from))
1783 } else {
1784 None
1785 }
1786 }
1787
1788 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1791 composition.as_array().and_then(|arr| {
1792 arr.iter().find_map(|schema| {
1793 if let Some(obj) = schema.as_object() {
1794 if let Some(type_val) = obj.get("type")
1796 && let Some(type_str) = type_val.as_str()
1797 && type_str == "null"
1798 {
1799 return None;
1800 }
1801 obj.get("type").and_then(extract_type).or_else(|| {
1803 if obj.contains_key("properties") {
1804 Some("object".to_string())
1805 } else if obj.contains_key("enum") {
1806 Some("string".to_string())
1808 } else {
1809 None
1810 }
1811 })
1812 } else {
1813 None
1814 }
1815 })
1816 })
1817 }
1818
1819 fn extract_schema_from_composition(
1822 composition: &Value,
1823 ) -> Option<serde_json::Map<String, Value>> {
1824 composition.as_array().and_then(|arr| {
1825 arr.iter().find_map(|schema| {
1826 if let Some(obj) = schema.as_object()
1827 && let Some(type_val) = obj.get("type")
1828 && let Some(type_str) = type_val.as_str()
1829 {
1830 if type_str == "null" {
1831 return None;
1832 }
1833 Some(obj.clone())
1834 } else {
1835 None
1836 }
1837 })
1838 })
1839 }
1840
1841 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1844 if let Some(type_val) = obj.get("type")
1846 && let Some(type_str) = extract_type(type_val)
1847 {
1848 return type_str;
1849 }
1850
1851 if let Some(any_of) = obj.get("anyOf")
1853 && let Some(type_str) = extract_type_from_composition(any_of)
1854 {
1855 return type_str;
1856 }
1857
1858 if let Some(one_of) = obj.get("oneOf")
1859 && let Some(type_str) = extract_type_from_composition(one_of)
1860 {
1861 return type_str;
1862 }
1863
1864 if let Some(all_of) = obj.get("allOf")
1865 && let Some(type_str) = extract_type_from_composition(all_of)
1866 {
1867 return type_str;
1868 }
1869
1870 if obj.contains_key("properties") {
1872 "object".to_string()
1873 } else {
1874 String::new()
1875 }
1876 }
1877
1878 impl TryFrom<Value> for Schema {
1879 type Error = CompletionError;
1880
1881 fn try_from(value: Value) -> Result<Self, Self::Error> {
1882 let flattened_val = flatten_schema(value)?;
1883 if let Some(obj) = flattened_val.as_object() {
1884 let props_source = if obj.get("properties").is_none() {
1887 if let Some(any_of) = obj.get("anyOf") {
1888 extract_schema_from_composition(any_of)
1889 } else if let Some(one_of) = obj.get("oneOf") {
1890 extract_schema_from_composition(one_of)
1891 } else if let Some(all_of) = obj.get("allOf") {
1892 extract_schema_from_composition(all_of)
1893 } else {
1894 None
1895 }
1896 .unwrap_or(obj.clone())
1897 } else {
1898 obj.clone()
1899 };
1900
1901 let schema_type = infer_type(obj);
1902 let items = obj
1903 .get("items")
1904 .and_then(|v| v.clone().try_into().ok())
1905 .map(Box::new);
1906
1907 let items = if schema_type == "array" && items.is_none() {
1910 Some(Box::new(Schema {
1911 r#type: "string".to_string(),
1912 format: None,
1913 description: None,
1914 nullable: None,
1915 r#enum: None,
1916 max_items: None,
1917 min_items: None,
1918 properties: None,
1919 required: None,
1920 items: None,
1921 }))
1922 } else {
1923 items
1924 };
1925
1926 Ok(Schema {
1927 r#type: schema_type,
1928 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1929 description: obj
1930 .get("description")
1931 .and_then(|v| v.as_str())
1932 .map(String::from),
1933 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1934 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1935 arr.iter()
1936 .filter_map(|v| v.as_str().map(String::from))
1937 .collect()
1938 }),
1939 max_items: obj
1940 .get("maxItems")
1941 .and_then(|v| v.as_i64())
1942 .map(|v| v as i32),
1943 min_items: obj
1944 .get("minItems")
1945 .and_then(|v| v.as_i64())
1946 .map(|v| v as i32),
1947 properties: props_source
1948 .get("properties")
1949 .and_then(|v| v.as_object())
1950 .map(|map| {
1951 map.iter()
1952 .filter_map(|(k, v)| {
1953 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1954 })
1955 .collect()
1956 }),
1957 required: props_source
1958 .get("required")
1959 .and_then(|v| v.as_array())
1960 .map(|arr| {
1961 arr.iter()
1962 .filter_map(|v| v.as_str().map(String::from))
1963 .collect()
1964 }),
1965 items,
1966 })
1967 } else {
1968 Err(CompletionError::ResponseError(
1969 "Expected a JSON object for Schema".into(),
1970 ))
1971 }
1972 }
1973 }
1974
    /// Request body for content generation.
    ///
    /// Keys are serialized in camelCase (e.g. `toolConfig`, `systemInstruction`).
    /// Only `tools` and `additional_params` are skipped when `None`; the other optional
    /// fields have no `skip_serializing_if` and are therefore serialized as `null`
    /// when unset.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// The conversation contents.
        pub contents: Vec<Content>,
        /// Tool definitions as raw JSON values; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Vec<Value>>,
        /// Tool configuration.
        pub tool_config: Option<ToolConfig>,
        /// Generation settings.
        pub generation_config: Option<GenerationConfig>,
        /// Safety settings.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System instruction content.
        pub system_instruction: Option<Content>,
        /// Extra parameters merged into the top level of the request object
        /// (`flatten`); omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
2006
    /// A tool definition: function declarations plus optional code execution.
    ///
    /// Keys are serialized in camelCase (`functionDeclarations`, `codeExecution`).
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions made available through this tool.
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Code-execution marker; serialized as `null` when `None`.
        pub code_execution: Option<CodeExecution>,
    }
2013
    /// Declaration of a callable function: its name, description and parameter schema.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Name of the function.
        pub name: String,
        /// Description of what the function does.
        pub description: String,
        /// Schema of the function's parameters; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
2022
    /// Tool configuration wrapper around the function-calling mode.
    ///
    /// The field is (de)serialized as `functionCallingConfig`, and as `null` when `None`.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// How function calling should behave.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
2028
    /// Function-calling mode.
    ///
    /// Internally tagged on `mode` with upper-cased variant names, so values look like
    /// `{"mode": "AUTO"}`, `{"mode": "NONE"}` or `{"mode": "ANY", ...}`. The enum-level
    /// `rename_all` renames variants only, so the `Any` variant's field keeps its
    /// snake_case key `allowed_function_names`. The default is `Auto`.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        /// Default mode.
        #[default]
        Auto,
        /// Mode `NONE`.
        None,
        /// Mode `ANY`, optionally restricted to a list of function names.
        Any {
            /// Function names to restrict to; omitted from the output when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
2040
2041 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
2042 type Error = CompletionError;
2043 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
2044 let res = match value {
2045 message::ToolChoice::Auto => Self::Auto,
2046 message::ToolChoice::None => Self::None,
2047 message::ToolChoice::Required => Self::Any {
2048 allowed_function_names: None,
2049 },
2050 message::ToolChoice::Specific { function_names } => Self::Any {
2051 allowed_function_names: Some(function_names),
2052 },
2053 };
2054
2055 Ok(res)
2056 }
2057 }
2058
    /// Marker for code execution in a [`Tool`]; it has no fields and serializes as an
    /// empty JSON object (`{}`).
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
2061
    /// A safety setting: the blocking threshold to apply to one harm category.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// The harm category this setting applies to.
        pub category: HarmCategory,
        /// The blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
2068
    /// Blocking threshold used in a [`SafetySetting`].
    ///
    /// Variants are serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_BLOCK_THRESHOLD_UNSPECIFIED`, `BLOCK_LOW_AND_ABOVE`, `OFF`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
2079}
2080
2081#[cfg(test)]
2082mod tests {
2083 use crate::{
2084 message,
2085 providers::gemini::completion::gemini_api_types::{
2086 ContentCandidate, FinishReason, UsageMetadata, flatten_schema,
2087 },
2088 };
2089
2090 use super::*;
2091 use serde_json::json;
2092
2093 #[test]
2094 fn test_resolve_request_model_uses_override() {
2095 let request = CompletionRequest {
2096 model: Some("gemini-2.5-flash".to_string()),
2097 preamble: None,
2098 chat_history: crate::OneOrMany::one("Hello".into()),
2099 documents: vec![],
2100 tools: vec![],
2101 temperature: None,
2102 max_tokens: None,
2103 tool_choice: None,
2104 additional_params: None,
2105 output_schema: None,
2106 };
2107
2108 let request_model = resolve_request_model("gemini-2.0-flash", &request);
2109 assert_eq!(request_model, "gemini-2.5-flash");
2110 assert_eq!(
2111 completion_endpoint(&request_model),
2112 "/v1beta/models/gemini-2.5-flash:generateContent"
2113 );
2114 assert_eq!(
2115 streaming_endpoint(&request_model),
2116 "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
2117 );
2118 }
2119
2120 #[test]
2121 fn test_resolve_request_model_uses_default_when_unset() {
2122 let request = CompletionRequest {
2123 model: None,
2124 preamble: None,
2125 chat_history: crate::OneOrMany::one("Hello".into()),
2126 documents: vec![],
2127 tools: vec![],
2128 temperature: None,
2129 max_tokens: None,
2130 tool_choice: None,
2131 additional_params: None,
2132 output_schema: None,
2133 };
2134
2135 assert_eq!(
2136 resolve_request_model("gemini-2.0-flash", &request),
2137 "gemini-2.0-flash"
2138 );
2139 }
2140
2141 #[test]
2142 fn test_deserialize_message_user() {
2143 let raw_message = r#"{
2144 "parts": [
2145 {"text": "Hello, world!"},
2146 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
2147 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
2148 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
2149 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
2150 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
2151 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
2152 ],
2153 "role": "user"
2154 }"#;
2155
2156 let content: Content = {
2157 let jd = &mut serde_json::Deserializer::from_str(raw_message);
2158 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
2159 panic!("Deserialization error at {}: {}", err.path(), err);
2160 })
2161 };
2162 assert_eq!(content.role, Some(Role::User));
2163 assert_eq!(content.parts.len(), 7);
2164
2165 let parts: Vec<Part> = content.parts.into_iter().collect();
2166
2167 if let Part {
2168 part: PartKind::Text(text),
2169 ..
2170 } = &parts[0]
2171 {
2172 assert_eq!(text, "Hello, world!");
2173 } else {
2174 panic!("Expected text part");
2175 }
2176
2177 if let Part {
2178 part: PartKind::InlineData(inline_data),
2179 ..
2180 } = &parts[1]
2181 {
2182 assert_eq!(inline_data.mime_type, "image/png");
2183 assert_eq!(inline_data.data, "base64encodeddata");
2184 } else {
2185 panic!("Expected inline data part");
2186 }
2187
2188 if let Part {
2189 part: PartKind::FunctionCall(function_call),
2190 ..
2191 } = &parts[2]
2192 {
2193 assert_eq!(function_call.name, "test_function");
2194 assert_eq!(
2195 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2196 "value1"
2197 );
2198 } else {
2199 panic!("Expected function call part");
2200 }
2201
2202 if let Part {
2203 part: PartKind::FunctionResponse(function_response),
2204 ..
2205 } = &parts[3]
2206 {
2207 assert_eq!(function_response.name, "test_function");
2208 assert_eq!(
2209 function_response
2210 .response
2211 .as_ref()
2212 .unwrap()
2213 .get("result")
2214 .unwrap(),
2215 "success"
2216 );
2217 } else {
2218 panic!("Expected function response part");
2219 }
2220
2221 if let Part {
2222 part: PartKind::FileData(file_data),
2223 ..
2224 } = &parts[4]
2225 {
2226 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
2227 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
2228 } else {
2229 panic!("Expected file data part");
2230 }
2231
2232 if let Part {
2233 part: PartKind::ExecutableCode(executable_code),
2234 ..
2235 } = &parts[5]
2236 {
2237 assert_eq!(executable_code.code, "print('Hello, world!')");
2238 } else {
2239 panic!("Expected executable code part");
2240 }
2241
2242 if let Part {
2243 part: PartKind::CodeExecutionResult(code_execution_result),
2244 ..
2245 } = &parts[6]
2246 {
2247 assert_eq!(
2248 code_execution_result.clone().output.unwrap(),
2249 "Hello, world!"
2250 );
2251 } else {
2252 panic!("Expected code execution result part");
2253 }
2254 }
2255
2256 #[test]
2257 fn test_deserialize_message_model() {
2258 let json_data = json!({
2259 "parts": [{"text": "Hello, user!"}],
2260 "role": "model"
2261 });
2262
2263 let content: Content = serde_json::from_value(json_data).unwrap();
2264 assert_eq!(content.role, Some(Role::Model));
2265 assert_eq!(content.parts.len(), 1);
2266 if let Some(Part {
2267 part: PartKind::Text(text),
2268 ..
2269 }) = content.parts.first()
2270 {
2271 assert_eq!(text, "Hello, user!");
2272 } else {
2273 panic!("Expected text part");
2274 }
2275 }
2276
2277 #[test]
2278 fn test_message_conversion_user() {
2279 let msg = message::Message::user("Hello, world!");
2280 let content: Content = msg.try_into().unwrap();
2281 assert_eq!(content.role, Some(Role::User));
2282 assert_eq!(content.parts.len(), 1);
2283 if let Some(Part {
2284 part: PartKind::Text(text),
2285 ..
2286 }) = &content.parts.first()
2287 {
2288 assert_eq!(text, "Hello, world!");
2289 } else {
2290 panic!("Expected text part");
2291 }
2292 }
2293
2294 #[test]
2295 fn test_message_conversion_model() {
2296 let msg = message::Message::assistant("Hello, user!");
2297
2298 let content: Content = msg.try_into().unwrap();
2299 assert_eq!(content.role, Some(Role::Model));
2300 assert_eq!(content.parts.len(), 1);
2301 if let Some(Part {
2302 part: PartKind::Text(text),
2303 ..
2304 }) = &content.parts.first()
2305 {
2306 assert_eq!(text, "Hello, user!");
2307 } else {
2308 panic!("Expected text part");
2309 }
2310 }
2311
2312 #[test]
2313 fn test_thought_signature_is_preserved_from_response_reasoning_part() {
2314 let response = GenerateContentResponse {
2315 response_id: "resp_1".to_string(),
2316 candidates: vec![ContentCandidate {
2317 content: Some(Content {
2318 parts: vec![Part {
2319 thought: Some(true),
2320 thought_signature: Some("thought_sig_123".to_string()),
2321 part: PartKind::Text("thinking text".to_string()),
2322 additional_params: None,
2323 }],
2324 role: Some(Role::Model),
2325 }),
2326 finish_reason: Some(FinishReason::Stop),
2327 safety_ratings: None,
2328 citation_metadata: None,
2329 token_count: None,
2330 avg_logprobs: None,
2331 logprobs_result: None,
2332 index: Some(0),
2333 finish_message: None,
2334 }],
2335 prompt_feedback: None,
2336 usage_metadata: None,
2337 model_version: None,
2338 };
2339
2340 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2341 response.try_into().expect("convert response");
2342 let first = converted.choice.first();
2343 assert!(matches!(
2344 first,
2345 message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2346 if matches!(
2347 content.first(),
2348 Some(message::ReasoningContent::Text {
2349 text,
2350 signature: Some(signature)
2351 }) if text == "thinking text" && signature == "thought_sig_123"
2352 )
2353 ));
2354 }
2355
2356 #[test]
2357 fn test_completion_response_usage_preserves_cached_and_reasoning_tokens() {
2358 let response = GenerateContentResponse {
2359 response_id: "resp_1".to_string(),
2360 candidates: vec![ContentCandidate {
2361 content: Some(Content {
2362 parts: vec![Part {
2363 thought: None,
2364 thought_signature: None,
2365 part: PartKind::Text("answer".to_string()),
2366 additional_params: None,
2367 }],
2368 role: Some(Role::Model),
2369 }),
2370 finish_reason: Some(FinishReason::Stop),
2371 safety_ratings: None,
2372 citation_metadata: None,
2373 token_count: None,
2374 avg_logprobs: None,
2375 logprobs_result: None,
2376 index: Some(0),
2377 finish_message: None,
2378 }],
2379 prompt_feedback: None,
2380 usage_metadata: Some(UsageMetadata {
2381 prompt_token_count: 40,
2382 cached_content_token_count: Some(20),
2383 candidates_token_count: Some(30),
2384 total_token_count: 100,
2385 thoughts_token_count: Some(10),
2386 prompt_tokens_details: None,
2387 cache_tokens_details: None,
2388 candidates_tokens_details: None,
2389 tool_use_prompt_token_count: None,
2390 tool_use_prompt_tokens_details: None,
2391 traffic_type: None,
2392 }),
2393 model_version: Some("gemini-2.0-flash-001".to_string()),
2394 };
2395
2396 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2397 response.try_into().expect("convert response");
2398
2399 assert_eq!(converted.usage.input_tokens, 40);
2400 assert_eq!(converted.usage.cached_input_tokens, 20);
2401 assert_eq!(converted.usage.output_tokens, 30);
2402 assert_eq!(converted.usage.reasoning_tokens, 10);
2403 assert_eq!(converted.usage.total_tokens, 100);
2404 }
2405
2406 #[test]
2407 fn test_reasoning_signature_is_emitted_in_gemini_part() {
2408 let msg = message::Message::Assistant {
2409 id: None,
2410 content: OneOrMany::one(message::AssistantContent::Reasoning(
2411 message::Reasoning::new_with_signature(
2412 "structured thought",
2413 Some("reuse_sig_456".to_string()),
2414 ),
2415 )),
2416 };
2417
2418 let converted: Content = msg.try_into().expect("convert message");
2419 let first = converted.parts.first().expect("reasoning part");
2420 assert_eq!(first.thought, Some(true));
2421 assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
2422 assert!(matches!(
2423 &first.part,
2424 PartKind::Text(text) if text == "structured thought"
2425 ));
2426 }
2427
/// An assistant tool call converts into exactly one model-role `FunctionCall`
/// part that keeps the function name and its arguments.
#[test]
fn test_message_conversion_tool_call() {
    let assistant = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::ToolCall(message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
            signature: None,
            additional_params: None,
        })),
    };

    let converted: Content = assistant.try_into().unwrap();
    assert_eq!(converted.role, Some(Role::Model));
    assert_eq!(converted.parts.len(), 1);

    // Anything other than a function-call part is a failure.
    let Some(Part {
        part: PartKind::FunctionCall(call),
        ..
    }) = converted.parts.first()
    else {
        panic!("Expected function call part");
    };

    assert_eq!(call.name, "test_function");
    assert_eq!(
        call.args.as_object().unwrap().get("arg1").unwrap(),
        "value1"
    );
}
2463
/// Converting an array schema whose `items` is a `$ref` into `$defs` must
/// succeed and give the items a concrete `object` type, not an empty type string.
#[test]
fn test_vec_schema_conversion() {
    let schema_with_ref = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "first_name": {
                        "type": ["string", "null"],
                        "description": "The person's first name, if provided (null otherwise)"
                    },
                    "last_name": {
                        "type": ["string", "null"],
                        "description": "The person's last name, if provided (null otherwise)"
                    },
                    "job": {
                        "type": ["string", "null"],
                        "description": "The person's job, if provided (null otherwise)"
                    }
                },
                "required": []
            }
        }
    });

    // A failed conversion must fail the test. Merely printing the error would
    // let the test pass without checking anything.
    let schema: Schema = schema_with_ref
        .try_into()
        .expect("Schema conversion failed");

    assert_eq!(schema.r#type, "array");

    let items = schema
        .items
        .expect("Schema should have items field for array type");
    assert_ne!(items.r#type, "", "Items type should not be empty string!");
    assert_eq!(items.r#type, "object", "Items should be object type");
}
2511
/// A plain object schema converts with its type preserved and its
/// properties populated.
#[test]
fn test_object_schema() {
    let object_json = json!({
        "type": "object",
        "properties": {
            "name": { "type": "string" }
        }
    });

    let converted: Schema = object_json.try_into().unwrap();

    assert_eq!(converted.r#type, "object");
    assert!(converted.properties.is_some());
}
2527
/// An array schema with an inline object `items` definition keeps both the
/// item type and the item properties after conversion.
#[test]
fn test_array_with_inline_items() {
    let array_json = json!({
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        }
    });

    let converted: Schema = array_json.try_into().unwrap();
    assert_eq!(converted.r#type, "array");

    let Some(item_schema) = converted.items else {
        panic!("Schema should have items field");
    };
    assert_eq!(item_schema.r#type, "object");
    assert!(item_schema.properties.is_some());
}
/// Flattening a schema that references `$defs` and then converting it must
/// produce an array schema whose items are an object with properties.
#[test]
fn test_flattened_schema() {
    let ref_schema = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            }
        }
    });

    let flattened = flatten_schema(ref_schema).unwrap();
    let schema: Schema = flattened.try_into().unwrap();

    assert_eq!(schema.r#type, "array");

    // Missing items must fail the test; skipping the assertions when `items`
    // is `None` would let the test pass vacuously.
    let items = schema
        .items
        .expect("Flattened array schema should have items field");
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
2581
/// An array property declared without `items` is expected to come out of the
/// conversion with a default `string` items schema.
#[test]
fn test_array_without_items_gets_default() {
    let object_json = json!({
        "type": "object",
        "properties": {
            "service_ids": {
                "type": "array",
                "description": "A list of service IDs"
            }
        }
    });

    let converted: Schema = object_json.try_into().unwrap();
    let properties = converted.properties.unwrap();
    let array_prop = properties.get("service_ids").unwrap();
    assert_eq!(array_prop.r#type, "array");

    let default_items = array_prop
        .items
        .as_ref()
        .expect("array schema missing items should get a default");
    assert_eq!(default_items.r#type, "string");
}
2604
/// A TXT document in a user message converts into a text part that still
/// contains the document's content.
#[test]
fn test_txt_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    let user_message = message::Message::User {
        content: crate::OneOrMany::one(UserContent::document(
            "Note: test.md\nPath: /test.md\nContent: Hello World!",
            Some(DocumentMediaType::TXT),
        )),
    };

    let content: Content = user_message.try_into().unwrap();

    let first = &content.parts[0];
    let Part {
        part: PartKind::Text(text),
        ..
    } = first
    else {
        panic!("Expected text part for TXT document, got: {:?}", first);
    };

    assert!(text.contains("Note: test.md"));
    assert!(text.contains("Hello World!"));
}
2635
/// A tool result mixing text and a base64 image converts into one user-role
/// `FunctionResponse` part: the response JSON exposes a `result` key and the
/// image is carried as a single nested inline-data part.
#[test]
fn test_tool_result_with_image_content() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    let text_item = ToolResultContent::Text(message::Text {
        text: r#"{"status": "success"}"#.to_string(),
    });
    let image_item = ToolResultContent::Image(Image {
        data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    });

    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(ToolResult {
            id: "test_tool".to_string(),
            call_id: None,
            content: OneOrMany::many(vec![text_item, image_item])
                .expect("Should create OneOrMany with multiple items"),
        })),
    };

    let converted: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(converted.role, Some(Role::User));
    assert_eq!(converted.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = converted.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };

    assert_eq!(function_response.name, "test_tool");

    // The textual portion of the tool result is exposed under `result`.
    assert!(function_response.response.is_some());
    let payload = function_response.response.as_ref().unwrap();
    assert!(payload.get("result").is_some());

    // The image travels as the only nested part, encoded inline.
    assert!(function_response.parts.is_some());
    let nested_parts = function_response.parts.as_ref().unwrap();
    assert_eq!(nested_parts.len(), 1);

    let image_part = &nested_parts[0];
    assert!(image_part.inline_data.is_some());
    let inline = image_part.inline_data.as_ref().unwrap();
    assert_eq!(inline.mime_type, "image/png");
    assert!(!inline.data.is_empty());
}
2698
/// A MARKDOWN document in a user message converts into a text part whose
/// text equals the original markdown.
#[test]
fn test_markdown_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    let user_message = message::Message::User {
        content: crate::OneOrMany::one(UserContent::document(
            "# Heading\n\n* List item",
            Some(DocumentMediaType::MARKDOWN),
        )),
    };

    let content: Content = user_message.try_into().unwrap();

    let first = &content.parts[0];
    let Part {
        part: PartKind::Text(text),
        ..
    } = first
    else {
        panic!("Expected text part for MARKDOWN document, got: {:?}", first);
    };

    assert_eq!(text, "# Heading\n\n* List item");
}
2728
/// A MARKDOWN document referenced by URL converts into a file-data part that
/// keeps the URI and reports the `text/markdown` MIME type.
#[test]
fn test_markdown_url_document_conversion_to_file_data_part() {
    use crate::message::{DocumentMediaType, DocumentSourceKind, UserContent};

    let document = message::Document {
        data: DocumentSourceKind::Url(
            "https://generativelanguage.googleapis.com/v1beta/files/test-markdown".to_string(),
        ),
        media_type: Some(DocumentMediaType::MARKDOWN),
        additional_params: None,
    };
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(UserContent::Document(document)),
    };

    let content: Content = user_message.try_into().unwrap();

    let first = &content.parts[0];
    let Part {
        part: PartKind::FileData(file_data),
        ..
    } = first
    else {
        panic!(
            "Expected file_data part for URL MARKDOWN document, got: {:?}",
            first
        );
    };

    assert_eq!(
        file_data.file_uri,
        "https://generativelanguage.googleapis.com/v1beta/files/test-markdown"
    );
    assert_eq!(file_data.mime_type.as_deref(), Some("text/markdown"));
}
2765
/// A tool result containing only a URL image converts into one user-role
/// `FunctionResponse` part whose single nested part is file data pointing at
/// that URL with the PNG MIME type.
#[test]
fn test_tool_result_with_url_image() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    let url_image = ToolResultContent::Image(Image {
        data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    });

    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(ToolResult {
            id: "screenshot_tool".to_string(),
            call_id: None,
            content: OneOrMany::one(url_image),
        })),
    };

    let converted: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(converted.role, Some(Role::User));
    assert_eq!(converted.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = converted.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };

    assert_eq!(function_response.name, "screenshot_tool");

    assert!(function_response.parts.is_some());
    let nested_parts = function_response.parts.as_ref().unwrap();
    assert_eq!(nested_parts.len(), 1);

    // URL images are referenced as file data, not inlined.
    let image_part = &nested_parts[0];
    assert!(image_part.file_data.is_some());
    let file_ref = image_part.file_data.as_ref().unwrap();
    assert_eq!(file_ref.file_uri, "https://example.com/image.png");
    assert_eq!(file_ref.mime_type.as_ref().unwrap(), "image/png");
}
2815
/// When the request carries documents, `create_request_body` is expected to
/// emit them as text parts of a leading user content entry, followed by the
/// user's own message as a second content entry.
#[test]
fn test_create_request_body_with_documents() {
    use crate::OneOrMany;
    use crate::completion::request::{CompletionRequest, Document};
    use crate::message::Message;

    let documents = vec![
        Document {
            id: "doc1".to_string(),
            text: "Note: first.md\nContent: First note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
        Document {
            id: "doc2".to_string(),
            text: "Note: second.md\nContent: Second note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
    ];

    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("What are my notes about?")),
        // Moved rather than cloned: the vector is not used again below.
        documents,
        tools: vec![],
        temperature: None,
        model: None,
        output_schema: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
    };

    let request = create_request_body(completion_request).unwrap();

    assert_eq!(
        request.contents.len(),
        2,
        "Expected 2 contents (documents + user message)"
    );

    // First entry: the documents, one text part each.
    let document_content = &request.contents[0];
    assert_eq!(document_content.role, Some(Role::User));
    assert_eq!(
        document_content.parts.len(),
        2,
        "Expected 2 document parts"
    );

    for part in &document_content.parts {
        let Part {
            part: PartKind::Text(text),
            ..
        } = part
        else {
            panic!("Document parts should be text, not {:?}", part);
        };
        assert!(
            text.contains("Note:") && text.contains("Content:"),
            "Document should contain note metadata"
        );
    }

    // Second entry: the user's prompt, unchanged.
    let user_content = &request.contents[1];
    assert_eq!(user_content.role, Some(Role::User));
    let Part {
        part: PartKind::Text(text),
        ..
    } = &user_content.parts[0]
    else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "What are my notes about?");
}
2894
/// Without documents, `create_request_body` is expected to produce a single
/// user content entry holding just the prompt text.
#[test]
fn test_create_request_body_without_documents() {
    use crate::OneOrMany;
    use crate::completion::request::CompletionRequest;
    use crate::message::Message;

    let request = create_request_body(CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("Hello")),
        documents: vec![],
        tools: vec![],
        temperature: None,
        max_tokens: None,
        tool_choice: None,
        model: None,
        output_schema: None,
        additional_params: None,
    })
    .unwrap();

    assert_eq!(request.contents.len(), 1, "Expected only user message");

    let only_content = &request.contents[0];
    assert_eq!(only_content.role, Some(Role::User));

    let Part {
        part: PartKind::Text(text),
        ..
    } = &only_content.parts[0]
    else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "Hello");
}
2931
/// Tool output shaped as an image JSON object is parsed into a single image
/// item with base64 data and the JPEG media type.
#[test]
fn test_from_tool_output_parses_image_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let parsed = ToolResultContent::from_tool_output(
        r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#,
    );

    assert_eq!(parsed.len(), 1);

    let ToolResultContent::Image(image) = parsed.first() else {
        panic!("Expected Image content");
    };

    assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
    if let DocumentSourceKind::Base64(encoded) = &image.data {
        assert_eq!(encoded, "base64data==");
    }
    assert_eq!(image.media_type, Some(crate::message::ImageMediaType::JPEG));
}
2952
/// Hybrid tool output (`response` plus `parts`) is parsed into three items in
/// order: the response as text, a base64 image, then a URL image.
#[test]
fn test_from_tool_output_parses_hybrid_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let hybrid_json = r#"{
        "response": {"status": "ok", "count": 42},
        "parts": [
            {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
            {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
        ]
    }"#;

    let parsed = ToolResultContent::from_tool_output(hybrid_json);
    assert_eq!(parsed.len(), 3);

    // Walk the items in order instead of collecting them into a vector first.
    let mut remaining = parsed.iter();

    let Some(ToolResultContent::Text(text)) = remaining.next() else {
        panic!("Expected Text content first");
    };
    assert!(text.text.contains("status"));
    assert!(text.text.contains("ok"));

    let Some(ToolResultContent::Image(inline_image)) = remaining.next() else {
        panic!("Expected Image content second");
    };
    assert!(matches!(inline_image.data, DocumentSourceKind::Base64(_)));

    let Some(ToolResultContent::Image(linked_image)) = remaining.next() else {
        panic!("Expected Image content third");
    };
    assert!(matches!(linked_image.data, DocumentSourceKind::Url(_)));
}
2995
2996 #[tokio::test]
3000 #[ignore = "requires GEMINI_API_KEY environment variable"]
3001 async fn test_gemini_agent_with_image_tool_result_e2e() -> anyhow::Result<()> {
3002 use crate::completion::Prompt;
3003 use crate::prelude::*;
3004 use crate::providers::gemini;
3005 use crate::test_utils::MockImageGeneratorTool;
3006
3007 let client = gemini::Client::from_env()?;
3008
3009 let agent = client
3010 .agent("gemini-3-flash-preview")
3011 .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
3012 .tool(MockImageGeneratorTool)
3013 .build();
3014
3015 let response_text = agent
3017 .prompt("Please generate a test image and tell me what color the pixel is.")
3018 .await?;
3019 println!("Response: {response_text}");
3020 anyhow::ensure!(!response_text.is_empty(), "Response should not be empty");
3022
3023 Ok(())
3024 }
3025}