/// Model name `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model name `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model name `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model name `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model name `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model name `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model name `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model name `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
21
22use self::gemini_api_types::Schema;
23use crate::http_client::HttpClientExt;
24use crate::message::{self, MimeType, Reasoning};
25
26use crate::providers::gemini::completion::gemini_api_types::{
27 AdditionalParameters, FunctionCallingMode, ToolConfig,
28};
29use crate::providers::gemini::streaming::StreamingCompletionResponse;
30use crate::telemetry::SpanCombinator;
31use crate::{
32 OneOrMany,
33 completion::{self, CompletionError, CompletionRequest},
34};
35use gemini_api_types::{
36 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
37 GenerationConfig, Part, PartKind, Role, Tool,
38};
39use serde_json::{Map, Value};
40use std::convert::TryFrom;
41use tracing::{Level, enabled, info_span};
42use tracing_futures::Instrument;
43
44use super::Client;
45
/// Completion model handle for the Gemini provider.
///
/// Pairs a provider [`Client`] with the model name that requests fall back to
/// when they do not name a model themselves.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Provider client used to build and send the HTTP requests.
    pub(crate) client: Client<T>,
    /// Default model name for requests issued through this handle.
    pub model: String,
}
55
56impl<T> CompletionModel<T> {
57 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
58 Self {
59 client,
60 model: model.into(),
61 }
62 }
63
64 pub fn with_model(client: Client<T>, model: &str) -> Self {
65 Self {
66 client,
67 model: model.into(),
68 }
69 }
70}
71
72impl<T> completion::CompletionModel for CompletionModel<T>
73where
74 T: HttpClientExt + Clone + 'static,
75{
76 type Response = GenerateContentResponse;
77 type StreamingResponse = StreamingCompletionResponse;
78 type Client = super::Client<T>;
79
80 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
81 Self::new(client.clone(), model)
82 }
83
84 async fn completion(
85 &self,
86 completion_request: CompletionRequest,
87 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
88 let request_model = resolve_request_model(&self.model, &completion_request);
89 let span = if tracing::Span::current().is_disabled() {
90 info_span!(
91 target: "rig::completions",
92 "generate_content",
93 gen_ai.operation.name = "generate_content",
94 gen_ai.provider.name = "gcp.gemini",
95 gen_ai.request.model = &request_model,
96 gen_ai.system_instructions = &completion_request.preamble,
97 gen_ai.response.id = tracing::field::Empty,
98 gen_ai.response.model = tracing::field::Empty,
99 gen_ai.usage.output_tokens = tracing::field::Empty,
100 gen_ai.usage.input_tokens = tracing::field::Empty,
101 )
102 } else {
103 tracing::Span::current()
104 };
105
106 let request = create_request_body(completion_request)?;
107
108 if enabled!(Level::TRACE) {
109 tracing::trace!(
110 target: "rig::completions",
111 "Gemini completion request: {}",
112 serde_json::to_string_pretty(&request)?
113 );
114 }
115
116 let body = serde_json::to_vec(&request)?;
117
118 let path = completion_endpoint(&request_model);
119
120 let request = self
121 .client
122 .post(path.as_str())?
123 .body(body)
124 .map_err(|e| CompletionError::HttpError(e.into()))?;
125
126 async move {
127 let response = self.client.send::<_, Vec<u8>>(request).await?;
128
129 if response.status().is_success() {
130 let response_body = response
131 .into_body()
132 .await
133 .map_err(CompletionError::HttpError)?;
134
135 let response_text = String::from_utf8_lossy(&response_body).to_string();
136
137 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
138 .map_err(|err| {
139 tracing::error!(
140 error = %err,
141 body = %response_text,
142 "Failed to deserialize Gemini completion response"
143 );
144 CompletionError::JsonError(err)
145 })?;
146
147 let span = tracing::Span::current();
148 span.record_response_metadata(&response);
149 span.record_token_usage(&response.usage_metadata);
150
151 if enabled!(Level::TRACE) {
152 tracing::trace!(
153 target: "rig::completions",
154 "Gemini completion response: {}",
155 serde_json::to_string_pretty(&response)?
156 );
157 }
158
159 response.try_into()
160 } else {
161 let text = String::from_utf8_lossy(
162 &response
163 .into_body()
164 .await
165 .map_err(CompletionError::HttpError)?,
166 )
167 .into();
168
169 Err(CompletionError::ProviderError(text))
170 }
171 }
172 .instrument(span)
173 .await
174 }
175
176 async fn stream(
177 &self,
178 request: CompletionRequest,
179 ) -> Result<
180 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
181 CompletionError,
182 > {
183 CompletionModel::stream(self, request).await
184 }
185}
186
187pub(crate) fn create_request_body(
188 completion_request: CompletionRequest,
189) -> Result<GenerateContentRequest, CompletionError> {
190 let mut full_history = Vec::new();
191
192 if let Some(documents_message) = completion_request.normalized_documents() {
194 full_history.push(documents_message);
195 }
196
197 full_history.extend(completion_request.chat_history);
198
199 let additional_params = completion_request
200 .additional_params
201 .unwrap_or_else(|| Value::Object(Map::new()));
202
203 let AdditionalParameters {
204 mut generation_config,
205 additional_params,
206 } = serde_json::from_value::<AdditionalParameters>(additional_params)?;
207
208 if let Some(schema) = completion_request.output_schema {
210 let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
211 cfg.response_mime_type = Some("application/json".to_string());
212 cfg.response_json_schema = Some(schema.to_value());
213 }
214
215 generation_config = generation_config.map(|mut cfg| {
216 if let Some(temp) = completion_request.temperature {
217 cfg.temperature = Some(temp);
218 };
219
220 if let Some(max_tokens) = completion_request.max_tokens {
221 cfg.max_output_tokens = Some(max_tokens);
222 };
223
224 cfg
225 });
226
227 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
228 parts: vec![preamble.into()],
229 role: Some(Role::Model),
230 });
231
232 let tools = if completion_request.tools.is_empty() {
233 None
234 } else {
235 Some(vec![Tool::try_from(completion_request.tools)?])
236 };
237
238 let tool_config = if let Some(cfg) = completion_request.tool_choice {
239 Some(ToolConfig {
240 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
241 })
242 } else {
243 None
244 };
245
246 let request = GenerateContentRequest {
247 contents: full_history
248 .into_iter()
249 .map(|msg| {
250 msg.try_into()
251 .map_err(|e| CompletionError::RequestError(Box::new(e)))
252 })
253 .collect::<Result<Vec<_>, _>>()?,
254 generation_config,
255 safety_settings: None,
256 tools,
257 tool_config,
258 system_instruction,
259 additional_params,
260 };
261
262 Ok(request)
263}
264
265pub(crate) fn resolve_request_model(
266 default_model: &str,
267 completion_request: &CompletionRequest,
268) -> String {
269 completion_request
270 .model
271 .clone()
272 .unwrap_or_else(|| default_model.to_string())
273}
274
/// Returns the request path of the non-streaming `generateContent` endpoint
/// for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    let mut path = String::from("/v1beta/models/");
    path.push_str(model);
    path.push_str(":generateContent");
    path
}
278
/// Returns the request path of the `streamGenerateContent` endpoint for
/// `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    let mut path = String::from("/v1beta/models/");
    path.push_str(model);
    path.push_str(":streamGenerateContent");
    path
}
282
283impl TryFrom<completion::ToolDefinition> for Tool {
284 type Error = CompletionError;
285
286 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
287 let parameters: Option<Schema> =
288 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
289 None
290 } else {
291 Some(tool.parameters.try_into()?)
292 };
293
294 Ok(Self {
295 function_declarations: vec![FunctionDeclaration {
296 name: tool.name,
297 description: tool.description,
298 parameters,
299 }],
300 code_execution: None,
301 })
302 }
303}
304
305impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
306 type Error = CompletionError;
307
308 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
309 let mut function_declarations = Vec::new();
310
311 for tool in tools {
312 let parameters =
313 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
314 None
315 } else {
316 match tool.parameters.try_into() {
317 Ok(schema) => Some(schema),
318 Err(e) => {
319 let emsg = format!(
320 "Tool '{}' could not be converted to a schema: {:?}",
321 tool.name, e,
322 );
323 return Err(CompletionError::ProviderError(emsg));
324 }
325 }
326 };
327
328 function_declarations.push(FunctionDeclaration {
329 name: tool.name,
330 description: tool.description,
331 parameters,
332 });
333 }
334
335 Ok(Self {
336 function_declarations,
337 code_execution: None,
338 })
339 }
340}
341
342impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
343 type Error = CompletionError;
344
345 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
346 let candidate = response.candidates.first().ok_or_else(|| {
347 CompletionError::ResponseError("No response candidates in response".into())
348 })?;
349
350 let content = candidate
351 .content
352 .as_ref()
353 .ok_or_else(|| {
354 let reason = candidate
355 .finish_reason
356 .as_ref()
357 .map(|r| format!("finish_reason={r:?}"))
358 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
359 let message = candidate
360 .finish_message
361 .as_deref()
362 .unwrap_or("no finish message provided");
363 CompletionError::ResponseError(format!(
364 "Gemini candidate missing content ({reason}, finish_message={message})"
365 ))
366 })?
367 .parts
368 .iter()
369 .map(
370 |Part {
371 thought,
372 thought_signature,
373 part,
374 ..
375 }| {
376 Ok(match part {
377 PartKind::Text(text) => {
378 if let Some(thought) = thought
379 && *thought
380 {
381 completion::AssistantContent::Reasoning(
382 Reasoning::new_with_signature(text, thought_signature.clone()),
383 )
384 } else {
385 completion::AssistantContent::text(text)
386 }
387 }
388 PartKind::InlineData(inline_data) => {
389 let mime_type =
390 message::MediaType::from_mime_type(&inline_data.mime_type);
391
392 match mime_type {
393 Some(message::MediaType::Image(media_type)) => {
394 message::AssistantContent::image_base64(
395 &inline_data.data,
396 Some(media_type),
397 Some(message::ImageDetail::default()),
398 )
399 }
400 _ => {
401 return Err(CompletionError::ResponseError(format!(
402 "Unsupported media type {mime_type:?}"
403 )));
404 }
405 }
406 }
407 PartKind::FunctionCall(function_call) => {
408 completion::AssistantContent::ToolCall(
409 message::ToolCall::new(
410 function_call.name.clone(),
411 message::ToolFunction::new(
412 function_call.name.clone(),
413 function_call.args.clone(),
414 ),
415 )
416 .with_signature(thought_signature.clone()),
417 )
418 }
419 _ => {
420 return Err(CompletionError::ResponseError(
421 "Response did not contain a message or tool call".into(),
422 ));
423 }
424 })
425 },
426 )
427 .collect::<Result<Vec<_>, _>>()?;
428
429 let choice = OneOrMany::many(content).map_err(|_| {
430 CompletionError::ResponseError(
431 "Response contained no message or tool call (empty)".to_owned(),
432 )
433 })?;
434
435 let usage = response
436 .usage_metadata
437 .as_ref()
438 .map(|usage| completion::Usage {
439 input_tokens: usage.prompt_token_count as u64,
440 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
441 total_tokens: usage.total_token_count as u64,
442 cached_input_tokens: 0,
443 })
444 .unwrap_or_default();
445
446 Ok(completion::CompletionResponse {
447 choice,
448 usage,
449 raw_response: response,
450 message_id: None,
451 })
452 }
453}
454
455pub mod gemini_api_types {
456 use crate::telemetry::ProviderResponseExt;
457 use std::{collections::HashMap, convert::Infallible, str::FromStr};
458
459 use serde::{Deserialize, Serialize};
463 use serde_json::{Value, json};
464
465 use crate::completion::GetTokenUsage;
466 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
467 use crate::{
468 completion::CompletionError,
469 message::{self},
470 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
471 };
472
    /// Provider-specific extras carried by a request, (de)serialized with
    /// camelCase keys.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Optional generation settings (`generationConfig` in JSON).
        pub generation_config: Option<GenerationConfig>,
        /// Flattened catch-all for the remaining keys; omitted from the
        /// serialized output when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
482
483 impl AdditionalParameters {
484 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
485 self.generation_config = Some(cfg);
486 self
487 }
488
489 pub fn with_params(mut self, params: serde_json::Value) -> Self {
490 self.additional_params = Some(params);
491 self
492 }
493 }
494
    /// Deserialized body of a Gemini content-generation response (camelCase
    /// keys in JSON).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Identifier of this response (`responseId`); required when
        /// deserializing.
        pub response_id: String,
        /// Candidate answers; required when deserializing.
        pub candidates: Vec<ContentCandidate>,
        /// Optional feedback about the prompt (`promptFeedback`).
        pub prompt_feedback: Option<PromptFeedback>,
        /// Optional token accounting for the request (`usageMetadata`).
        pub usage_metadata: Option<UsageMetadata>,
        /// Optional model version string (`modelVersion`).
        pub model_version: Option<String>,
    }
514
515 impl ProviderResponseExt for GenerateContentResponse {
516 type OutputMessage = ContentCandidate;
517 type Usage = UsageMetadata;
518
519 fn get_response_id(&self) -> Option<String> {
520 Some(self.response_id.clone())
521 }
522
523 fn get_response_model_name(&self) -> Option<String> {
524 None
525 }
526
527 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
528 self.candidates.clone()
529 }
530
531 fn get_text_response(&self) -> Option<String> {
532 let str = self
533 .candidates
534 .iter()
535 .filter_map(|x| {
536 let content = x.content.as_ref()?;
537 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
538 return None;
539 }
540
541 let res = content
542 .parts
543 .iter()
544 .filter_map(|part| {
545 if let PartKind::Text(ref str) = part.part {
546 Some(str.to_owned())
547 } else {
548 None
549 }
550 })
551 .collect::<Vec<String>>()
552 .join("\n");
553
554 Some(res)
555 })
556 .collect::<Vec<String>>()
557 .join("\n");
558
559 if str.is_empty() { None } else { Some(str) }
560 }
561
562 fn get_usage(&self) -> Option<Self::Usage> {
563 self.usage_metadata.clone()
564 }
565 }
566
    /// One candidate inside a [`GenerateContentResponse`] (camelCase keys in
    /// JSON). Every field is optional.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        /// Generated content; omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Finish reason reported for this candidate.
        pub finish_reason: Option<FinishReason>,
        /// Safety ratings attached to this candidate.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation metadata attached to this candidate.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count reported for this candidate.
        pub token_count: Option<i32>,
        /// Value of the `avgLogprobs` field.
        pub avg_logprobs: Option<f64>,
        /// Value of the `logprobsResult` field.
        pub logprobs_result: Option<LogprobsResult>,
        /// Value of the `index` field.
        pub index: Option<i32>,
        /// Free-text message accompanying the finish reason.
        pub finish_message: Option<String>,
    }
595
    /// A message exchanged with the model: a list of parts plus an optional
    /// role.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        /// Parts of the message; defaults to an empty list when the key is
        /// missing during deserialization.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the message, if specified.
        pub role: Option<Role>,
    }
605
606 impl TryFrom<message::Message> for Content {
607 type Error = message::MessageError;
608
609 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
610 Ok(match msg {
611 message::Message::User { content } => Content {
612 parts: content
613 .into_iter()
614 .map(|c| c.try_into())
615 .collect::<Result<Vec<_>, _>>()?,
616 role: Some(Role::User),
617 },
618 message::Message::Assistant { content, .. } => Content {
619 role: Some(Role::Model),
620 parts: content
621 .into_iter()
622 .map(|content| content.try_into())
623 .collect::<Result<Vec<_>, _>>()?,
624 },
625 })
626 }
627 }
628
    /// Author of a [`Content`] message, serialized in lowercase (`"user"`,
    /// `"model"`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// Serialized as `"user"`.
        User,
        /// Serialized as `"model"`.
        Model,
    }
635
    /// A single piece of a [`Content`] message (camelCase keys in JSON).
    ///
    /// The payload (`part`) and any extra parameters are flattened into the
    /// same JSON object as the `thought` fields.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        /// Whether this part is a model thought; omitted from the serialized
        /// output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        /// Signature accompanying a thought or function call; omitted from the
        /// serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        /// The actual payload of the part.
        #[serde(flatten)]
        pub part: PartKind,
        /// Extra keys flattened into the part object; omitted from the
        /// serialized output when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
650
    /// Payload of a [`Part`]. Variant names are (de)serialized in camelCase
    /// (`text`, `inlineData`, `functionCall`, ...).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Inline payload with its MIME type.
        InlineData(Blob),
        /// A function call with its name and arguments.
        FunctionCall(FunctionCall),
        /// The result of a function call.
        FunctionResponse(FunctionResponse),
        /// A file referenced by URI.
        FileData(FileData),
        /// Serialized as `executableCode`.
        ExecutableCode(ExecutableCode),
        /// Serialized as `codeExecutionResult`.
        CodeExecutionResult(CodeExecutionResult),
    }
665
666 impl Default for PartKind {
669 fn default() -> Self {
670 Self::Text(String::new())
671 }
672 }
673
674 impl From<String> for Part {
675 fn from(text: String) -> Self {
676 Self {
677 thought: Some(false),
678 thought_signature: None,
679 part: PartKind::Text(text),
680 additional_params: None,
681 }
682 }
683 }
684
685 impl From<&str> for Part {
686 fn from(text: &str) -> Self {
687 Self::from(text.to_string())
688 }
689 }
690
691 impl FromStr for Part {
692 type Err = Infallible;
693
694 fn from_str(s: &str) -> Result<Self, Self::Err> {
695 Ok(s.into())
696 }
697 }
698
699 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
700 type Error = message::MessageError;
701 fn try_from(
702 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
703 ) -> Result<Self, Self::Error> {
704 let mime_type = mime_type.to_mime_type().to_string();
705 let part = match doc_src {
706 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
707 mime_type: Some(mime_type),
708 file_uri: url,
709 }),
710 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
711 PartKind::InlineData(Blob { mime_type, data })
712 }
713 DocumentSourceKind::Raw(_) => {
714 return Err(message::MessageError::ConversionError(
715 "Raw files not supported, encode as base64 first".into(),
716 ));
717 }
718 DocumentSourceKind::Unknown => {
719 return Err(message::MessageError::ConversionError(
720 "Can't convert an unknown document source".to_string(),
721 ));
722 }
723 };
724
725 Ok(part)
726 }
727 }
728
729 impl TryFrom<message::UserContent> for Part {
730 type Error = message::MessageError;
731
732 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
733 match content {
734 message::UserContent::Text(message::Text { text }) => Ok(Part {
735 thought: Some(false),
736 thought_signature: None,
737 part: PartKind::Text(text),
738 additional_params: None,
739 }),
740 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
741 let mut response_json: Option<serde_json::Value> = None;
742 let mut parts: Vec<FunctionResponsePart> = Vec::new();
743
744 for item in content.iter() {
745 match item {
746 message::ToolResultContent::Text(text) => {
747 let result: serde_json::Value =
748 serde_json::from_str(&text.text).unwrap_or_else(|error| {
749 tracing::trace!(
750 ?error,
751 "Tool result is not a valid JSON, treat it as normal string"
752 );
753 json!(&text.text)
754 });
755
756 response_json = Some(match response_json {
757 Some(mut existing) => {
758 if let serde_json::Value::Object(ref mut map) = existing {
759 map.insert("text".to_string(), result);
760 }
761 existing
762 }
763 None => json!({ "result": result }),
764 });
765 }
766 message::ToolResultContent::Image(image) => {
767 let part = match &image.data {
768 DocumentSourceKind::Base64(b64) => {
769 let mime_type = image
770 .media_type
771 .as_ref()
772 .ok_or(message::MessageError::ConversionError(
773 "Image media type is required for Gemini tool results".to_string(),
774 ))?
775 .to_mime_type();
776
777 FunctionResponsePart {
778 inline_data: Some(FunctionResponseInlineData {
779 mime_type: mime_type.to_string(),
780 data: b64.clone(),
781 display_name: None,
782 }),
783 file_data: None,
784 }
785 }
786 DocumentSourceKind::Url(url) => {
787 let mime_type = image
788 .media_type
789 .as_ref()
790 .map(|mt| mt.to_mime_type().to_string());
791
792 FunctionResponsePart {
793 inline_data: None,
794 file_data: Some(FileData {
795 mime_type,
796 file_uri: url.clone(),
797 }),
798 }
799 }
800 _ => {
801 return Err(message::MessageError::ConversionError(
802 "Unsupported image source kind for tool results"
803 .to_string(),
804 ));
805 }
806 };
807 parts.push(part);
808 }
809 }
810 }
811
812 Ok(Part {
813 thought: Some(false),
814 thought_signature: None,
815 part: PartKind::FunctionResponse(FunctionResponse {
816 name: id,
817 response: response_json,
818 parts: if parts.is_empty() { None } else { Some(parts) },
819 }),
820 additional_params: None,
821 })
822 }
823 message::UserContent::Image(message::Image {
824 data, media_type, ..
825 }) => match media_type {
826 Some(media_type) => match media_type {
827 message::ImageMediaType::JPEG
828 | message::ImageMediaType::PNG
829 | message::ImageMediaType::WEBP
830 | message::ImageMediaType::HEIC
831 | message::ImageMediaType::HEIF => {
832 let part = PartKind::try_from((media_type, data))?;
833 Ok(Part {
834 thought: Some(false),
835 thought_signature: None,
836 part,
837 additional_params: None,
838 })
839 }
840 _ => Err(message::MessageError::ConversionError(format!(
841 "Unsupported image media type {media_type:?}"
842 ))),
843 },
844 None => Err(message::MessageError::ConversionError(
845 "Media type for image is required for Gemini".to_string(),
846 )),
847 },
848 message::UserContent::Document(message::Document {
849 data, media_type, ..
850 }) => {
851 let Some(media_type) = media_type else {
852 return Err(MessageError::ConversionError(
853 "A mime type is required for document inputs to Gemini".to_string(),
854 ));
855 };
856
857 if matches!(
859 media_type,
860 message::DocumentMediaType::TXT
861 | message::DocumentMediaType::RTF
862 | message::DocumentMediaType::HTML
863 | message::DocumentMediaType::CSS
864 | message::DocumentMediaType::MARKDOWN
865 | message::DocumentMediaType::CSV
866 | message::DocumentMediaType::XML
867 | message::DocumentMediaType::Javascript
868 | message::DocumentMediaType::Python
869 ) {
870 use base64::Engine;
871 let text = match data {
872 DocumentSourceKind::String(text) => text.clone(),
873 DocumentSourceKind::Base64(data) => {
874 String::from_utf8(
876 base64::engine::general_purpose::STANDARD
877 .decode(&data)
878 .map_err(|e| {
879 MessageError::ConversionError(format!(
880 "Failed to decode base64: {e}"
881 ))
882 })?,
883 )
884 .map_err(|e| {
885 MessageError::ConversionError(format!(
886 "Invalid UTF-8 in document: {e}"
887 ))
888 })?
889 }
890 _ => {
891 return Err(MessageError::ConversionError(
892 "Text-based documents must be String or Base64 encoded"
893 .to_string(),
894 ));
895 }
896 };
897
898 Ok(Part {
899 thought: Some(false),
900 part: PartKind::Text(text),
901 ..Default::default()
902 })
903 } else if !media_type.is_code() {
904 let mime_type = media_type.to_mime_type().to_string();
905
906 let part = match data {
907 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
908 mime_type: Some(mime_type),
909 file_uri,
910 }),
911 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
912 PartKind::InlineData(Blob { mime_type, data })
913 }
914 DocumentSourceKind::Raw(_) => {
915 return Err(message::MessageError::ConversionError(
916 "Raw files not supported, encode as base64 first".into(),
917 ));
918 }
919 _ => {
920 return Err(message::MessageError::ConversionError(
921 "Document has no body".to_string(),
922 ));
923 }
924 };
925
926 Ok(Part {
927 thought: Some(false),
928 part,
929 ..Default::default()
930 })
931 } else {
932 Err(message::MessageError::ConversionError(format!(
933 "Unsupported document media type {media_type:?}"
934 )))
935 }
936 }
937
938 message::UserContent::Audio(message::Audio {
939 data, media_type, ..
940 }) => {
941 let Some(media_type) = media_type else {
942 return Err(MessageError::ConversionError(
943 "A mime type is required for audio inputs to Gemini".to_string(),
944 ));
945 };
946
947 let mime_type = media_type.to_mime_type().to_string();
948
949 let part = match data {
950 DocumentSourceKind::Base64(data) => {
951 PartKind::InlineData(Blob { data, mime_type })
952 }
953
954 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
955 mime_type: Some(mime_type),
956 file_uri,
957 }),
958 DocumentSourceKind::String(_) => {
959 return Err(message::MessageError::ConversionError(
960 "Strings cannot be used as audio files!".into(),
961 ));
962 }
963 DocumentSourceKind::Raw(_) => {
964 return Err(message::MessageError::ConversionError(
965 "Raw files not supported, encode as base64 first".into(),
966 ));
967 }
968 DocumentSourceKind::Unknown => {
969 return Err(message::MessageError::ConversionError(
970 "Content has no body".to_string(),
971 ));
972 }
973 };
974
975 Ok(Part {
976 thought: Some(false),
977 part,
978 ..Default::default()
979 })
980 }
981 message::UserContent::Video(message::Video {
982 data,
983 media_type,
984 additional_params,
985 ..
986 }) => {
987 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
988
989 let part = match data {
990 DocumentSourceKind::Url(file_uri) => {
991 if file_uri.starts_with("https://www.youtube.com") {
992 PartKind::FileData(FileData {
993 mime_type,
994 file_uri,
995 })
996 } else {
997 if mime_type.is_none() {
998 return Err(MessageError::ConversionError(
999 "A mime type is required for non-Youtube video file inputs to Gemini"
1000 .to_string(),
1001 ));
1002 }
1003
1004 PartKind::FileData(FileData {
1005 mime_type,
1006 file_uri,
1007 })
1008 }
1009 }
1010 DocumentSourceKind::Base64(data) => {
1011 let Some(mime_type) = mime_type else {
1012 return Err(MessageError::ConversionError(
1013 "A media type is expected for base64 encoded strings"
1014 .to_string(),
1015 ));
1016 };
1017 PartKind::InlineData(Blob { mime_type, data })
1018 }
1019 DocumentSourceKind::String(_) => {
1020 return Err(message::MessageError::ConversionError(
1021 "Strings cannot be used as audio files!".into(),
1022 ));
1023 }
1024 DocumentSourceKind::Raw(_) => {
1025 return Err(message::MessageError::ConversionError(
1026 "Raw file data not supported, encode as base64 first".into(),
1027 ));
1028 }
1029 DocumentSourceKind::Unknown => {
1030 return Err(message::MessageError::ConversionError(
1031 "Media type for video is required for Gemini".to_string(),
1032 ));
1033 }
1034 };
1035
1036 Ok(Part {
1037 thought: Some(false),
1038 thought_signature: None,
1039 part,
1040 additional_params,
1041 })
1042 }
1043 }
1044 }
1045 }
1046
1047 impl TryFrom<message::AssistantContent> for Part {
1048 type Error = message::MessageError;
1049
1050 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1051 match content {
1052 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1053 message::AssistantContent::Image(message::Image {
1054 data, media_type, ..
1055 }) => match media_type {
1056 Some(media_type) => match media_type {
1057 message::ImageMediaType::JPEG
1058 | message::ImageMediaType::PNG
1059 | message::ImageMediaType::WEBP
1060 | message::ImageMediaType::HEIC
1061 | message::ImageMediaType::HEIF => {
1062 let part = PartKind::try_from((media_type, data))?;
1063 Ok(Part {
1064 thought: Some(false),
1065 thought_signature: None,
1066 part,
1067 additional_params: None,
1068 })
1069 }
1070 _ => Err(message::MessageError::ConversionError(format!(
1071 "Unsupported image media type {media_type:?}"
1072 ))),
1073 },
1074 None => Err(message::MessageError::ConversionError(
1075 "Media type for image is required for Gemini".to_string(),
1076 )),
1077 },
1078 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1079 message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1080 thought: Some(true),
1081 thought_signature: reasoning.first_signature().map(str::to_owned),
1082 part: PartKind::Text(reasoning.display_text()),
1083 additional_params: None,
1084 }),
1085 }
1086 }
1087 }
1088
1089 impl From<message::ToolCall> for Part {
1090 fn from(tool_call: message::ToolCall) -> Self {
1091 Self {
1092 thought: Some(false),
1093 thought_signature: tool_call.signature,
1094 part: PartKind::FunctionCall(FunctionCall {
1095 name: tool_call.function.name,
1096 args: tool_call.function.arguments,
1097 }),
1098 additional_params: None,
1099 }
1100 }
1101 }
1102
    /// Inline payload together with its MIME type (camelCase keys in JSON).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// MIME type of `data` (`mimeType` in JSON).
        pub mime_type: String,
        /// The payload as a string; the conversions in this module place
        /// base64 or plain string data here.
        pub data: String,
    }
1114
    /// A function call: the function name and its JSON arguments.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Arguments of the call as arbitrary JSON.
        pub args: serde_json::Value,
    }
1125
1126 impl From<message::ToolCall> for FunctionCall {
1127 fn from(tool_call: message::ToolCall) -> Self {
1128 Self {
1129 name: tool_call.function.name,
1130 args: tool_call.function.arguments,
1131 }
1132 }
1133 }
1134
    /// Result of a function call that is sent back to the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Name this response is reported under.
        pub name: String,
        /// JSON payload of the response; omitted from the serialized output
        /// when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        /// Additional inline or file parts of the response; omitted from the
        /// serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }
1150
    /// One media part of a [`FunctionResponse`] (camelCase keys in JSON);
    /// either field may be absent.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        /// Inline payload; omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        /// File referenced by URI; omitted from the serialized output when
        /// `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }
1162
    /// Inline payload of a [`FunctionResponsePart`] (camelCase keys in JSON).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// MIME type of `data` (`mimeType` in JSON).
        pub mime_type: String,
        /// The payload as a string.
        pub data: String,
        /// Optional display name; omitted from the serialized output when
        /// `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1175
    /// Reference to a file by URI (camelCase keys in JSON).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file, if known (`mimeType` in JSON).
        pub mime_type: Option<String>,
        /// URI of the file (`fileUri` in JSON).
        pub file_uri: String,
    }
1185
    /// A harm category paired with the probability reported for it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The category this rating is about.
        pub category: HarmCategory,
        /// The probability level reported for `category`.
        pub probability: HarmProbability,
    }
1191
    /// Probability level of a [`SafetyRating`], (de)serialized in
    /// SCREAMING_SNAKE_CASE (e.g. `HARM_PROBABILITY_UNSPECIFIED`, `LOW`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1201
    /// Category of a [`SafetyRating`], (de)serialized in SCREAMING_SNAKE_CASE
    /// (e.g. `HARM_CATEGORY_HATE_SPEECH`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1218
    /// Token accounting attached to a response (camelCase keys in JSON).
    /// Optional counts are omitted from the serialized output when `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Number of tokens in the prompt (`promptTokenCount`).
        pub prompt_token_count: i32,
        /// Number of cached-content tokens (`cachedContentTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Number of tokens in the generated candidates
        /// (`candidatesTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count reported by the provider (`totalTokenCount`).
        pub total_token_count: i32,
        /// Number of thought tokens (`thoughtsTokenCount`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1231
1232 impl std::fmt::Display for UsageMetadata {
1233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1234 write!(
1235 f,
1236 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1237 self.prompt_token_count,
1238 match self.cached_content_token_count {
1239 Some(count) => count.to_string(),
1240 None => "n/a".to_string(),
1241 },
1242 match self.candidates_token_count {
1243 Some(count) => count.to_string(),
1244 None => "n/a".to_string(),
1245 },
1246 self.total_token_count
1247 )
1248 }
1249 }
1250
1251 impl GetTokenUsage for UsageMetadata {
1252 fn token_usage(&self) -> Option<crate::completion::Usage> {
1253 let mut usage = crate::completion::Usage::new();
1254
1255 usage.input_tokens = self.prompt_token_count as u64;
1256 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1257 + self.candidates_token_count.unwrap_or_default()
1258 + self.thoughts_token_count.unwrap_or_default())
1259 as u64;
1260 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1261
1262 Some(usage)
1263 }
1264 }
1265
1266 #[derive(Debug, Deserialize, Serialize)]
1268 #[serde(rename_all = "camelCase")]
1269 pub struct PromptFeedback {
1270 pub block_reason: Option<BlockReason>,
1272 pub safety_ratings: Option<Vec<SafetyRating>>,
1274 }
1275
1276 #[derive(Debug, Deserialize, Serialize)]
1278 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1279 pub enum BlockReason {
1280 BlockReasonUnspecified,
1282 Safety,
1284 Other,
1286 Blocklist,
1288 ProhibitedContent,
1290 }
1291
1292 #[derive(Clone, Debug, Deserialize, Serialize)]
1293 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1294 pub enum FinishReason {
1295 FinishReasonUnspecified,
1297 Stop,
1299 MaxTokens,
1301 Safety,
1303 Recitation,
1305 Language,
1307 Other,
1309 Blocklist,
1311 ProhibitedContent,
1313 Spii,
1315 MalformedFunctionCall,
1317 }
1318
1319 #[derive(Clone, Debug, Deserialize, Serialize)]
1320 #[serde(rename_all = "camelCase")]
1321 pub struct CitationMetadata {
1322 pub citation_sources: Vec<CitationSource>,
1323 }
1324
1325 #[derive(Clone, Debug, Deserialize, Serialize)]
1326 #[serde(rename_all = "camelCase")]
1327 pub struct CitationSource {
1328 #[serde(skip_serializing_if = "Option::is_none")]
1329 pub uri: Option<String>,
1330 #[serde(skip_serializing_if = "Option::is_none")]
1331 pub start_index: Option<i32>,
1332 #[serde(skip_serializing_if = "Option::is_none")]
1333 pub end_index: Option<i32>,
1334 #[serde(skip_serializing_if = "Option::is_none")]
1335 pub license: Option<String>,
1336 }
1337
1338 #[derive(Clone, Debug, Deserialize, Serialize)]
1339 #[serde(rename_all = "camelCase")]
1340 pub struct LogprobsResult {
1341 pub top_candidate: Vec<TopCandidate>,
1342 pub chosen_candidate: Vec<LogProbCandidate>,
1343 }
1344
1345 #[derive(Clone, Debug, Deserialize, Serialize)]
1346 pub struct TopCandidate {
1347 pub candidates: Vec<LogProbCandidate>,
1348 }
1349
1350 #[derive(Clone, Debug, Deserialize, Serialize)]
1351 #[serde(rename_all = "camelCase")]
1352 pub struct LogProbCandidate {
1353 pub token: String,
1354 pub token_id: String,
1355 pub log_probability: f64,
1356 }
1357
1358 #[derive(Debug, Deserialize, Serialize)]
1363 #[serde(rename_all = "camelCase")]
1364 pub struct GenerationConfig {
1365 #[serde(skip_serializing_if = "Option::is_none")]
1368 pub stop_sequences: Option<Vec<String>>,
1369 #[serde(skip_serializing_if = "Option::is_none")]
1375 pub response_mime_type: Option<String>,
1376 #[serde(skip_serializing_if = "Option::is_none")]
1380 pub response_schema: Option<Schema>,
1381 #[serde(
1387 skip_serializing_if = "Option::is_none",
1388 rename = "_responseJsonSchema"
1389 )]
1390 pub _response_json_schema: Option<Value>,
1391 #[serde(skip_serializing_if = "Option::is_none")]
1393 pub response_json_schema: Option<Value>,
1394 #[serde(skip_serializing_if = "Option::is_none")]
1397 pub candidate_count: Option<i32>,
1398 #[serde(skip_serializing_if = "Option::is_none")]
1401 pub max_output_tokens: Option<u64>,
1402 #[serde(skip_serializing_if = "Option::is_none")]
1405 pub temperature: Option<f64>,
1406 #[serde(skip_serializing_if = "Option::is_none")]
1413 pub top_p: Option<f64>,
1414 #[serde(skip_serializing_if = "Option::is_none")]
1420 pub top_k: Option<i32>,
1421 #[serde(skip_serializing_if = "Option::is_none")]
1427 pub presence_penalty: Option<f64>,
1428 #[serde(skip_serializing_if = "Option::is_none")]
1436 pub frequency_penalty: Option<f64>,
1437 #[serde(skip_serializing_if = "Option::is_none")]
1439 pub response_logprobs: Option<bool>,
1440 #[serde(skip_serializing_if = "Option::is_none")]
1443 pub logprobs: Option<i32>,
1444 #[serde(skip_serializing_if = "Option::is_none")]
1446 pub thinking_config: Option<ThinkingConfig>,
1447 #[serde(skip_serializing_if = "Option::is_none")]
1448 pub image_config: Option<ImageConfig>,
1449 }
1450
1451 impl Default for GenerationConfig {
1452 fn default() -> Self {
1453 Self {
1454 temperature: Some(1.0),
1455 max_output_tokens: Some(4096),
1456 stop_sequences: None,
1457 response_mime_type: None,
1458 response_schema: None,
1459 _response_json_schema: None,
1460 response_json_schema: None,
1461 candidate_count: None,
1462 top_p: None,
1463 top_k: None,
1464 presence_penalty: None,
1465 frequency_penalty: None,
1466 response_logprobs: None,
1467 logprobs: None,
1468 thinking_config: None,
1469 image_config: None,
1470 }
1471 }
1472 }
1473
1474 #[derive(Debug, Deserialize, Serialize)]
1475 #[serde(rename_all = "camelCase")]
1476 pub struct ThinkingConfig {
1477 pub thinking_budget: u32,
1478 pub include_thoughts: Option<bool>,
1479 }
1480
1481 #[derive(Debug, Deserialize, Serialize)]
1482 #[serde(rename_all = "camelCase")]
1483 pub struct ImageConfig {
1484 #[serde(skip_serializing_if = "Option::is_none")]
1485 pub aspect_ratio: Option<String>,
1486 #[serde(skip_serializing_if = "Option::is_none")]
1487 pub image_size: Option<String>,
1488 }
1489
1490 #[derive(Debug, Deserialize, Serialize, Clone)]
1494 pub struct Schema {
1495 pub r#type: String,
1496 #[serde(skip_serializing_if = "Option::is_none")]
1497 pub format: Option<String>,
1498 #[serde(skip_serializing_if = "Option::is_none")]
1499 pub description: Option<String>,
1500 #[serde(skip_serializing_if = "Option::is_none")]
1501 pub nullable: Option<bool>,
1502 #[serde(skip_serializing_if = "Option::is_none")]
1503 pub r#enum: Option<Vec<String>>,
1504 #[serde(skip_serializing_if = "Option::is_none")]
1505 pub max_items: Option<i32>,
1506 #[serde(skip_serializing_if = "Option::is_none")]
1507 pub min_items: Option<i32>,
1508 #[serde(skip_serializing_if = "Option::is_none")]
1509 pub properties: Option<HashMap<String, Schema>>,
1510 #[serde(skip_serializing_if = "Option::is_none")]
1511 pub required: Option<Vec<String>>,
1512 #[serde(skip_serializing_if = "Option::is_none")]
1513 pub items: Option<Box<Schema>>,
1514 }
1515
1516 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1522 let defs = if let Some(obj) = schema.as_object() {
1524 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1525 } else {
1526 None
1527 };
1528
1529 let Some(defs_value) = defs else {
1530 return Ok(schema);
1531 };
1532
1533 let Some(defs_obj) = defs_value.as_object() else {
1534 return Err(CompletionError::ResponseError(
1535 "$defs must be an object".into(),
1536 ));
1537 };
1538
1539 resolve_refs(&mut schema, defs_obj)?;
1540
1541 if let Some(obj) = schema.as_object_mut() {
1543 obj.remove("$defs");
1544 obj.remove("definitions");
1545 }
1546
1547 Ok(schema)
1548 }
1549
1550 fn resolve_refs(
1553 value: &mut Value,
1554 defs: &serde_json::Map<String, Value>,
1555 ) -> Result<(), CompletionError> {
1556 match value {
1557 Value::Object(obj) => {
1558 if let Some(ref_value) = obj.get("$ref")
1559 && let Some(ref_str) = ref_value.as_str()
1560 {
1561 let def_name = parse_ref_path(ref_str)?;
1563
1564 let def = defs.get(&def_name).ok_or_else(|| {
1565 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1566 })?;
1567
1568 let mut resolved = def.clone();
1569 resolve_refs(&mut resolved, defs)?;
1570 *value = resolved;
1571 return Ok(());
1572 }
1573
1574 for (_, v) in obj.iter_mut() {
1575 resolve_refs(v, defs)?;
1576 }
1577 }
1578 Value::Array(arr) => {
1579 for item in arr.iter_mut() {
1580 resolve_refs(item, defs)?;
1581 }
1582 }
1583 _ => {}
1584 }
1585
1586 Ok(())
1587 }
1588
1589 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1595 if let Some(fragment) = ref_str.strip_prefix('#') {
1596 if let Some(name) = fragment.strip_prefix("/$defs/") {
1597 Ok(name.to_string())
1598 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1599 Ok(name.to_string())
1600 } else {
1601 Err(CompletionError::ResponseError(format!(
1602 "Unsupported reference format: {}",
1603 ref_str
1604 )))
1605 }
1606 } else {
1607 Err(CompletionError::ResponseError(format!(
1608 "Only fragment references (#/...) are supported: {}",
1609 ref_str
1610 )))
1611 }
1612 }
1613
1614 fn extract_type(type_value: &Value) -> Option<String> {
1617 if type_value.is_string() {
1618 type_value.as_str().map(String::from)
1619 } else if type_value.is_array() {
1620 type_value
1621 .as_array()
1622 .and_then(|arr| arr.first())
1623 .and_then(|v| v.as_str().map(String::from))
1624 } else {
1625 None
1626 }
1627 }
1628
1629 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1632 composition.as_array().and_then(|arr| {
1633 arr.iter().find_map(|schema| {
1634 if let Some(obj) = schema.as_object() {
1635 if let Some(type_val) = obj.get("type")
1637 && let Some(type_str) = type_val.as_str()
1638 && type_str == "null"
1639 {
1640 return None;
1641 }
1642 obj.get("type").and_then(extract_type).or_else(|| {
1644 if obj.contains_key("properties") {
1645 Some("object".to_string())
1646 } else {
1647 None
1648 }
1649 })
1650 } else {
1651 None
1652 }
1653 })
1654 })
1655 }
1656
1657 fn extract_schema_from_composition(
1660 composition: &Value,
1661 ) -> Option<serde_json::Map<String, Value>> {
1662 composition.as_array().and_then(|arr| {
1663 arr.iter().find_map(|schema| {
1664 if let Some(obj) = schema.as_object()
1665 && let Some(type_val) = obj.get("type")
1666 && let Some(type_str) = type_val.as_str()
1667 {
1668 if type_str == "null" {
1669 return None;
1670 }
1671 Some(obj.clone())
1672 } else {
1673 None
1674 }
1675 })
1676 })
1677 }
1678
1679 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1682 if let Some(type_val) = obj.get("type")
1684 && let Some(type_str) = extract_type(type_val)
1685 {
1686 return type_str;
1687 }
1688
1689 if let Some(any_of) = obj.get("anyOf")
1691 && let Some(type_str) = extract_type_from_composition(any_of)
1692 {
1693 return type_str;
1694 }
1695
1696 if let Some(one_of) = obj.get("oneOf")
1697 && let Some(type_str) = extract_type_from_composition(one_of)
1698 {
1699 return type_str;
1700 }
1701
1702 if let Some(all_of) = obj.get("allOf")
1703 && let Some(type_str) = extract_type_from_composition(all_of)
1704 {
1705 return type_str;
1706 }
1707
1708 if obj.contains_key("properties") {
1710 "object".to_string()
1711 } else {
1712 String::new()
1713 }
1714 }
1715
1716 impl TryFrom<Value> for Schema {
1717 type Error = CompletionError;
1718
1719 fn try_from(value: Value) -> Result<Self, Self::Error> {
1720 let flattened_val = flatten_schema(value)?;
1721 if let Some(obj) = flattened_val.as_object() {
1722 let props_source = if obj.get("properties").is_none() {
1725 if let Some(any_of) = obj.get("anyOf") {
1726 extract_schema_from_composition(any_of)
1727 } else if let Some(one_of) = obj.get("oneOf") {
1728 extract_schema_from_composition(one_of)
1729 } else if let Some(all_of) = obj.get("allOf") {
1730 extract_schema_from_composition(all_of)
1731 } else {
1732 None
1733 }
1734 .unwrap_or(obj.clone())
1735 } else {
1736 obj.clone()
1737 };
1738
1739 Ok(Schema {
1740 r#type: infer_type(obj),
1741 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1742 description: obj
1743 .get("description")
1744 .and_then(|v| v.as_str())
1745 .map(String::from),
1746 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1747 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1748 arr.iter()
1749 .filter_map(|v| v.as_str().map(String::from))
1750 .collect()
1751 }),
1752 max_items: obj
1753 .get("maxItems")
1754 .and_then(|v| v.as_i64())
1755 .map(|v| v as i32),
1756 min_items: obj
1757 .get("minItems")
1758 .and_then(|v| v.as_i64())
1759 .map(|v| v as i32),
1760 properties: props_source
1761 .get("properties")
1762 .and_then(|v| v.as_object())
1763 .map(|map| {
1764 map.iter()
1765 .filter_map(|(k, v)| {
1766 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1767 })
1768 .collect()
1769 }),
1770 required: props_source
1771 .get("required")
1772 .and_then(|v| v.as_array())
1773 .map(|arr| {
1774 arr.iter()
1775 .filter_map(|v| v.as_str().map(String::from))
1776 .collect()
1777 }),
1778 items: obj
1779 .get("items")
1780 .and_then(|v| v.clone().try_into().ok())
1781 .map(Box::new),
1782 })
1783 } else {
1784 Err(CompletionError::ResponseError(
1785 "Expected a JSON object for Schema".into(),
1786 ))
1787 }
1788 }
1789 }
1790
1791 #[derive(Debug, Serialize)]
1792 #[serde(rename_all = "camelCase")]
1793 pub struct GenerateContentRequest {
1794 pub contents: Vec<Content>,
1795 #[serde(skip_serializing_if = "Option::is_none")]
1796 pub tools: Option<Vec<Tool>>,
1797 pub tool_config: Option<ToolConfig>,
1798 pub generation_config: Option<GenerationConfig>,
1800 pub safety_settings: Option<Vec<SafetySetting>>,
1814 pub system_instruction: Option<Content>,
1817 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1820 pub additional_params: Option<serde_json::Value>,
1821 }
1822
1823 #[derive(Debug, Serialize)]
1824 #[serde(rename_all = "camelCase")]
1825 pub struct Tool {
1826 pub function_declarations: Vec<FunctionDeclaration>,
1827 pub code_execution: Option<CodeExecution>,
1828 }
1829
1830 #[derive(Debug, Serialize, Clone)]
1831 #[serde(rename_all = "camelCase")]
1832 pub struct FunctionDeclaration {
1833 pub name: String,
1834 pub description: String,
1835 #[serde(skip_serializing_if = "Option::is_none")]
1836 pub parameters: Option<Schema>,
1837 }
1838
1839 #[derive(Debug, Serialize, Deserialize)]
1840 #[serde(rename_all = "camelCase")]
1841 pub struct ToolConfig {
1842 pub function_calling_config: Option<FunctionCallingMode>,
1843 }
1844
1845 #[derive(Debug, Serialize, Deserialize, Default)]
1846 #[serde(tag = "mode", rename_all = "UPPERCASE")]
1847 pub enum FunctionCallingMode {
1848 #[default]
1849 Auto,
1850 None,
1851 Any {
1852 #[serde(skip_serializing_if = "Option::is_none")]
1853 allowed_function_names: Option<Vec<String>>,
1854 },
1855 }
1856
1857 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1858 type Error = CompletionError;
1859 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1860 let res = match value {
1861 message::ToolChoice::Auto => Self::Auto,
1862 message::ToolChoice::None => Self::None,
1863 message::ToolChoice::Required => Self::Any {
1864 allowed_function_names: None,
1865 },
1866 message::ToolChoice::Specific { function_names } => Self::Any {
1867 allowed_function_names: Some(function_names),
1868 },
1869 };
1870
1871 Ok(res)
1872 }
1873 }
1874
1875 #[derive(Debug, Serialize)]
1876 pub struct CodeExecution {}
1877
1878 #[derive(Debug, Serialize)]
1879 #[serde(rename_all = "camelCase")]
1880 pub struct SafetySetting {
1881 pub category: HarmCategory,
1882 pub threshold: HarmBlockThreshold,
1883 }
1884
1885 #[derive(Debug, Serialize)]
1886 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1887 pub enum HarmBlockThreshold {
1888 HarmBlockThresholdUnspecified,
1889 BlockLowAndAbove,
1890 BlockMediumAndAbove,
1891 BlockOnlyHigh,
1892 BlockNone,
1893 Off,
1894 }
1895}
1896
1897#[cfg(test)]
1898mod tests {
1899 use crate::{
1900 message,
1901 providers::gemini::completion::gemini_api_types::{
1902 ContentCandidate, FinishReason, flatten_schema,
1903 },
1904 };
1905
1906 use super::*;
1907 use serde_json::json;
1908
#[test]
fn test_resolve_request_model_uses_override() {
    // A model set on the request must win over the default model, and both
    // endpoints must be built from the resolved name.
    let request = CompletionRequest {
        model: Some(String::from("gemini-2.5-flash")),
        preamble: None,
        chat_history: crate::OneOrMany::one("Hello".into()),
        documents: Vec::new(),
        tools: Vec::new(),
        temperature: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
        output_schema: None,
    };

    let resolved = resolve_request_model("gemini-2.0-flash", &request);
    assert_eq!(resolved, "gemini-2.5-flash");

    let generate_path = completion_endpoint(&resolved);
    assert_eq!(
        generate_path,
        "/v1beta/models/gemini-2.5-flash:generateContent"
    );

    let stream_path = streaming_endpoint(&resolved);
    assert_eq!(
        stream_path,
        "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
    );
}
1935
#[test]
fn test_resolve_request_model_uses_default_when_unset() {
    // Without a model on the request, the default model is used.
    let request = CompletionRequest {
        model: None,
        preamble: None,
        chat_history: crate::OneOrMany::one("Hello".into()),
        documents: Vec::new(),
        tools: Vec::new(),
        temperature: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
        output_schema: None,
    };

    let resolved = resolve_request_model("gemini-2.0-flash", &request);
    assert_eq!(resolved, "gemini-2.0-flash");
}
1956
#[test]
fn test_deserialize_message_user() {
    // One part of every kind, in a fixed order, under the "user" role.
    let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

    // serde_path_to_error reports which JSON path failed, if any.
    let mut deserializer = serde_json::Deserializer::from_str(raw_message);
    let content: Content = match serde_path_to_error::deserialize(&mut deserializer) {
        Ok(parsed) => parsed,
        Err(err) => panic!("Deserialization error at {}: {}", err.path(), err),
    };
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 7);

    let parts: Vec<Part> = content.parts.into_iter().collect();

    let Part {
        part: PartKind::Text(text),
        ..
    } = &parts[0]
    else {
        panic!("Expected text part");
    };
    assert_eq!(text, "Hello, world!");

    let Part {
        part: PartKind::InlineData(inline_data),
        ..
    } = &parts[1]
    else {
        panic!("Expected inline data part");
    };
    assert_eq!(inline_data.mime_type, "image/png");
    assert_eq!(inline_data.data, "base64encodeddata");

    let Part {
        part: PartKind::FunctionCall(function_call),
        ..
    } = &parts[2]
    else {
        panic!("Expected function call part");
    };
    assert_eq!(function_call.name, "test_function");
    assert_eq!(
        function_call.args.as_object().unwrap().get("arg1").unwrap(),
        "value1"
    );

    let Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    } = &parts[3]
    else {
        panic!("Expected function response part");
    };
    assert_eq!(function_response.name, "test_function");
    assert_eq!(
        function_response
            .response
            .as_ref()
            .unwrap()
            .get("result")
            .unwrap(),
        "success"
    );

    let Part {
        part: PartKind::FileData(file_data),
        ..
    } = &parts[4]
    else {
        panic!("Expected file data part");
    };
    assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
    assert_eq!(file_data.file_uri, "http://example.com/file.pdf");

    let Part {
        part: PartKind::ExecutableCode(executable_code),
        ..
    } = &parts[5]
    else {
        panic!("Expected executable code part");
    };
    assert_eq!(executable_code.code, "print('Hello, world!')");

    let Part {
        part: PartKind::CodeExecutionResult(code_execution_result),
        ..
    } = &parts[6]
    else {
        panic!("Expected code execution result part");
    };
    assert_eq!(
        code_execution_result.clone().output.unwrap(),
        "Hello, world!"
    );
}
2071
#[test]
fn test_deserialize_message_model() {
    // A single text part under the "model" role.
    let json_data = json!({
        "parts": [{"text": "Hello, user!"}],
        "role": "model"
    });

    let content: Content = serde_json::from_value(json_data).unwrap();
    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::Text(text),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected text part");
    };
    assert_eq!(text, "Hello, user!");
}
2092
#[test]
fn test_message_conversion_user() {
    // A plain user message becomes one text part with the user role.
    let content: Content = message::Message::user("Hello, world!")
        .try_into()
        .unwrap();
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::Text(text),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected text part");
    };
    assert_eq!(text, "Hello, world!");
}
2109
#[test]
fn test_message_conversion_model() {
    // A plain assistant message becomes one text part with the model role.
    let content: Content = message::Message::assistant("Hello, user!")
        .try_into()
        .unwrap();
    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::Text(text),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected text part");
    };
    assert_eq!(text, "Hello, user!");
}
2127
#[test]
fn test_thought_signature_is_preserved_from_response_reasoning_part() {
    // A thought part with a signature must come out as reasoning content
    // that keeps both the text and the signature.
    let thinking_part = Part {
        thought: Some(true),
        thought_signature: Some("thought_sig_123".to_string()),
        part: PartKind::Text("thinking text".to_string()),
        additional_params: None,
    };

    let candidate = ContentCandidate {
        content: Some(Content {
            parts: vec![thinking_part],
            role: Some(Role::Model),
        }),
        finish_reason: Some(FinishReason::Stop),
        safety_ratings: None,
        citation_metadata: None,
        token_count: None,
        avg_logprobs: None,
        logprobs_result: None,
        index: Some(0),
        finish_message: None,
    };

    let response = GenerateContentResponse {
        response_id: "resp_1".to_string(),
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
    };

    let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
        response.try_into().expect("convert response");
    let first = converted.choice.first();
    assert!(matches!(
        first,
        message::AssistantContent::Reasoning(message::Reasoning { content, .. })
            if matches!(
                content.first(),
                Some(message::ReasoningContent::Text {
                    text,
                    signature: Some(signature)
                }) if text == "thinking text" && signature == "thought_sig_123"
            )
    ));
}
2171
#[test]
fn test_reasoning_signature_is_emitted_in_gemini_part() {
    // Reasoning with a signature must become a thought part that carries
    // the same signature and text.
    let reasoning = message::Reasoning::new_with_signature(
        "structured thought",
        Some("reuse_sig_456".to_string()),
    );
    let msg = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
    };

    let converted: Content = msg.try_into().expect("convert message");
    let first = converted.parts.first().expect("reasoning part");

    assert_eq!(first.thought, Some(true));
    assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
    assert!(matches!(
        &first.part,
        PartKind::Text(text) if text == "structured thought"
    ));
}
2193
#[test]
fn test_message_conversion_tool_call() {
    // An assistant tool call becomes a single function-call part.
    let msg = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::ToolCall(message::ToolCall {
            id: "test_tool".to_string(),
            call_id: None,
            function: message::ToolFunction {
                name: "test_function".to_string(),
                arguments: json!({"arg1": "value1"}),
            },
            signature: None,
            additional_params: None,
        })),
    };

    let content: Content = msg.try_into().unwrap();
    assert_eq!(content.role, Some(Role::Model));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionCall(function_call),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected function call part");
    };
    assert_eq!(function_call.name, "test_function");
    assert_eq!(
        function_call.args.as_object().unwrap().get("arg1").unwrap(),
        "value1"
    );
}
2229
#[test]
fn test_vec_schema_conversion() {
    // An array whose items are a `$ref` into `$defs` must convert to an
    // array schema whose items are typed as "object".
    let schema_with_ref = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "first_name": {
                        "type": ["string", "null"],
                        "description": "The person's first name, if provided (null otherwise)"
                    },
                    "last_name": {
                        "type": ["string", "null"],
                        "description": "The person's last name, if provided (null otherwise)"
                    },
                    "job": {
                        "type": ["string", "null"],
                        "description": "The person's job, if provided (null otherwise)"
                    }
                },
                "required": []
            }
        }
    });

    // Fail loudly: a conversion error must not let the test pass.
    let schema: Schema = schema_with_ref
        .try_into()
        .expect("Schema conversion should succeed");
    assert_eq!(schema.r#type, "array");

    let items = schema
        .items
        .expect("Schema should have items field for array type");
    assert_ne!(items.r#type, "", "Items type should not be empty string!");
    assert_eq!(items.r#type, "object", "Items should be object type");
}
2277
#[test]
fn test_object_schema() {
    // A plain object schema keeps its type and its properties.
    let schema: Schema = json!({
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            }
        }
    })
    .try_into()
    .unwrap();

    assert_eq!(schema.r#type, "object");
    assert!(schema.properties.is_some());
}
2293
#[test]
fn test_array_with_inline_items() {
    // Inline (non-`$ref`) item schemas are converted recursively.
    let inline_schema = json!({
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        }
    });

    let schema: Schema = inline_schema.try_into().unwrap();
    assert_eq!(schema.r#type, "array");

    let Some(items) = schema.items else {
        panic!("Schema should have items field");
    };
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
#[test]
fn test_flattened_schema() {
    // Flattening first and converting afterwards must give an array schema
    // whose items are the inlined object definition.
    let ref_schema = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            }
        }
    });

    let flattened = flatten_schema(ref_schema).unwrap();
    let schema: Schema = flattened.try_into().unwrap();

    assert_eq!(schema.r#type, "array");

    // A missing `items` field is a failure, not a silent pass.
    let items = schema
        .items
        .expect("Flattened schema should keep its items field");
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
2347
#[test]
fn test_txt_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    // A TXT document must be sent as a plain text part.
    let doc = UserContent::document(
        "Note: test.md\nPath: /test.md\nContent: Hello World!",
        Some(DocumentMediaType::TXT),
    );
    let msg = message::Message::User {
        content: crate::OneOrMany::one(doc),
    };

    let content: Content = msg.try_into().unwrap();

    let Part {
        part: PartKind::Text(text),
        ..
    } = &content.parts[0]
    else {
        panic!(
            "Expected text part for TXT document, got: {:?}",
            content.parts[0]
        );
    };
    assert!(text.contains("Note: test.md"));
    assert!(text.contains("Hello World!"));
}
2378
#[test]
fn test_tool_result_with_image_content() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    // A tool result that mixes text with a base64 image.
    let text_item = ToolResultContent::Text(message::Text {
        text: r#"{"status": "success"}"#.to_string(),
    });
    let image_item = ToolResultContent::Image(Image {
        data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    });
    let tool_result = ToolResult {
        id: "test_tool".to_string(),
        call_id: None,
        content: OneOrMany::many(vec![text_item, image_item])
            .expect("Should create OneOrMany with multiple items"),
    };

    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };

    let content: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };
    assert_eq!(function_response.name, "test_tool");

    // The text ends up under the "result" key of the response.
    assert!(function_response.response.is_some());
    let response = function_response.response.as_ref().unwrap();
    assert!(response.get("result").is_some());

    // The image travels as a separate inline-data part.
    assert!(function_response.parts.is_some());
    let parts = function_response.parts.as_ref().unwrap();
    assert_eq!(parts.len(), 1);

    let image_part = &parts[0];
    assert!(image_part.inline_data.is_some());
    let inline_data = image_part.inline_data.as_ref().unwrap();
    assert_eq!(inline_data.mime_type, "image/png");
    assert!(!inline_data.data.is_empty());
}
2441
#[test]
fn test_markdown_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    // A MARKDOWN document must be sent as a text part, unchanged.
    let doc = UserContent::document(
        "# Heading\n\n* List item",
        Some(DocumentMediaType::MARKDOWN),
    );
    let msg = message::Message::User {
        content: crate::OneOrMany::one(doc),
    };

    let content: Content = msg.try_into().unwrap();

    let Part {
        part: PartKind::Text(text),
        ..
    } = &content.parts[0]
    else {
        panic!(
            "Expected text part for MARKDOWN document, got: {:?}",
            content.parts[0]
        );
    };
    assert_eq!(text, "# Heading\n\n* List item");
}
2471
#[test]
fn test_tool_result_with_url_image() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    // A tool result whose only content is an image referenced by URL.
    let url_image = Image {
        data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    };
    let tool_result = ToolResult {
        id: "screenshot_tool".to_string(),
        call_id: None,
        content: OneOrMany::one(ToolResultContent::Image(url_image)),
    };

    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };

    let content: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };
    assert_eq!(function_response.name, "screenshot_tool");

    // The URL image travels as a file-data part.
    assert!(function_response.parts.is_some());
    let parts = function_response.parts.as_ref().unwrap();
    assert_eq!(parts.len(), 1);

    let image_part = &parts[0];
    assert!(image_part.file_data.is_some());
    let file_data = image_part.file_data.as_ref().unwrap();
    assert_eq!(file_data.file_uri, "https://example.com/image.png");
    assert_eq!(file_data.mime_type.as_ref().unwrap(), "image/png");
}
2521
/// Checks how `create_request_body` lays out a request that carries documents:
/// the result must hold exactly two user-role contents, the first made of one
/// text part per document and the second holding the original user prompt.
#[test]
fn test_create_request_body_with_documents() {
    use crate::OneOrMany;
    use crate::completion::request::{CompletionRequest, Document};
    use crate::message::Message;

    let documents = vec![
        Document {
            id: "doc1".to_string(),
            text: "Note: first.md\nContent: First note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
        Document {
            id: "doc2".to_string(),
            text: "Note: second.md\nContent: Second note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
    ];

    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("What are my notes about?")),
        // Moved rather than cloned: the vector is not used again in this test.
        documents,
        tools: vec![],
        temperature: None,
        model: None,
        output_schema: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
    };

    let request = create_request_body(completion_request).unwrap();

    assert_eq!(
        request.contents.len(),
        2,
        "Expected 2 contents (documents + user message)"
    );

    // First content: the documents, one text part each.
    let documents_content = &request.contents[0];
    assert_eq!(documents_content.role, Some(Role::User));
    assert_eq!(
        documents_content.parts.len(),
        2,
        "Expected 2 document parts"
    );

    for part in &documents_content.parts {
        match &part.part {
            PartKind::Text(text) => assert!(
                text.contains("Note:") && text.contains("Content:"),
                "Document should contain note metadata"
            ),
            _ => panic!("Document parts should be text, not {:?}", part),
        }
    }

    // Second content: the user's prompt, unchanged.
    let prompt_content = &request.contents[1];
    assert_eq!(prompt_content.role, Some(Role::User));
    match &prompt_content.parts[0].part {
        PartKind::Text(text) => assert_eq!(text, "What are my notes about?"),
        _ => panic!("Expected user message to be text"),
    }
}
2600
/// With no documents attached, `create_request_body` must produce a single
/// user-role content holding only the prompt text.
#[test]
fn test_create_request_body_without_documents() {
    use crate::OneOrMany;
    use crate::completion::request::CompletionRequest;
    use crate::message::Message;

    let request = create_request_body(CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("Hello")),
        documents: vec![],
        tools: vec![],
        temperature: None,
        max_tokens: None,
        tool_choice: None,
        model: None,
        output_schema: None,
        additional_params: None,
    })
    .unwrap();

    assert_eq!(request.contents.len(), 1, "Expected only user message");

    let only_content = &request.contents[0];
    assert_eq!(only_content.role, Some(Role::User));

    match &only_content.parts[0].part {
        PartKind::Text(text) => assert_eq!(text, "Hello"),
        _ => panic!("Expected user message to be text"),
    }
}
2637
/// `ToolResultContent::from_tool_output` must turn a JSON object tagged
/// `"type": "image"` into exactly one base64 image with a JPEG media type.
#[test]
fn test_from_tool_output_parses_image_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;

    let parsed = ToolResultContent::from_tool_output(image_json);
    assert_eq!(parsed.len(), 1);

    let img = match parsed.first() {
        ToolResultContent::Image(img) => img,
        _ => panic!("Expected Image content"),
    };

    // The payload is not a URL, so it must be kept as base64 data verbatim.
    assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
    if let DocumentSourceKind::Base64(data) = &img.data {
        assert_eq!(data, "base64data==");
    }
    assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
}
2658
/// A "hybrid" tool output (a `response` object plus a `parts` array) must
/// expand into three items in order: the response as text, then a base64
/// image, then a URL image.
#[test]
fn test_from_tool_output_parses_hybrid_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let hybrid_json = r#"{
        "response": {"status": "ok", "count": 42},
        "parts": [
            {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
            {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
        ]
    }"#;

    let result = ToolResultContent::from_tool_output(hybrid_json);
    assert_eq!(result.len(), 3);

    // Walk the items in order; the length check above guarantees three of them.
    let mut items = result.iter();

    match items.next() {
        Some(ToolResultContent::Text(text)) => {
            assert!(text.text.contains("status"));
            assert!(text.text.contains("ok"));
        }
        _ => panic!("Expected Text content first"),
    }

    match items.next() {
        Some(ToolResultContent::Image(img)) => {
            assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
        }
        _ => panic!("Expected Image content second"),
    }

    match items.next() {
        Some(ToolResultContent::Image(img)) => {
            assert!(matches!(img.data, DocumentSourceKind::Url(_)));
        }
        _ => panic!("Expected Image content third"),
    }
}
2701
2702 #[tokio::test]
2706 #[ignore = "requires GEMINI_API_KEY environment variable"]
2707 async fn test_gemini_agent_with_image_tool_result_e2e() {
2708 use crate::completion::{Prompt, ToolDefinition};
2709 use crate::prelude::*;
2710 use crate::providers::gemini;
2711 use crate::tool::Tool;
2712 use serde::{Deserialize, Serialize};
2713
2714 #[derive(Debug, Serialize, Deserialize)]
2716 struct ImageGeneratorTool;
2717
2718 #[derive(Debug, thiserror::Error)]
2719 #[error("Image generation error")]
2720 struct ImageToolError;
2721
2722 impl Tool for ImageGeneratorTool {
2723 const NAME: &'static str = "generate_test_image";
2724 type Error = ImageToolError;
2725 type Args = serde_json::Value;
2726 type Output = String;
2728
2729 async fn definition(&self, _prompt: String) -> ToolDefinition {
2730 ToolDefinition {
2731 name: "generate_test_image".to_string(),
2732 description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
2733 parameters: json!({
2734 "type": "object",
2735 "properties": {},
2736 "required": []
2737 }),
2738 }
2739 }
2740
2741 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
2742 Ok(json!({
2745 "type": "image",
2746 "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
2747 "mimeType": "image/png"
2748 }).to_string())
2749 }
2750 }
2751
2752 let client = gemini::Client::from_env();
2753
2754 let agent = client
2755 .agent("gemini-3-flash-preview")
2756 .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
2757 .tool(ImageGeneratorTool)
2758 .build();
2759
2760 let response = agent
2762 .prompt("Please generate a test image and tell me what color the pixel is.")
2763 .await;
2764
2765 assert!(
2768 response.is_ok(),
2769 "Gemini should successfully process tool result with image: {:?}",
2770 response.err()
2771 );
2772
2773 let response_text = response.unwrap();
2774 println!("Response: {response_text}");
2775 assert!(!response_text.is_empty(), "Response should not be empty");
2777 }
2778}