/// Model identifier string `gemini-3.1-flash-lite-preview`.
pub const GEMINI_3_1_FLASH_LITE_PREVIEW: &str = "gemini-3.1-flash-lite-preview";
/// Model identifier string `gemini-3-flash-preview`.
pub const GEMINI_3_FLASH_PREVIEW: &str = "gemini-3-flash-preview";
/// Model identifier string `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier string `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier string `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier string `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier string `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier string `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier string `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier string `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
25
26use self::gemini_api_types::Schema;
27use crate::http_client::HttpClientExt;
28use crate::message::{self, MimeType, Reasoning};
29
30use crate::providers::gemini::completion::gemini_api_types::{
31 AdditionalParameters, FunctionCallingMode, ToolConfig,
32};
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::telemetry::SpanCombinator;
35use crate::{
36 OneOrMany,
37 completion::{self, CompletionError, CompletionRequest},
38};
39use gemini_api_types::{
40 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
41 GenerationConfig, Part, PartKind, Role, Tool,
42};
43use serde_json::{Map, Value};
44use std::convert::TryFrom;
45use tracing::{Level, enabled, info_span};
46use tracing_futures::Instrument;
47
48use super::Client;
49
/// Handle for issuing Gemini completion requests against one default model.
///
/// `T` is the HTTP client type wrapped by [`Client`]; it defaults to `reqwest::Client`.
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Gemini API client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Default model name; a request's own `model` field overrides it
    /// (see `resolve_request_model`).
    pub model: String,
}
59
60impl<T> CompletionModel<T> {
61 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
62 Self {
63 client,
64 model: model.into(),
65 }
66 }
67
68 pub fn with_model(client: Client<T>, model: &str) -> Self {
69 Self {
70 client,
71 model: model.into(),
72 }
73 }
74}
75
76impl<T> completion::CompletionModel for CompletionModel<T>
77where
78 T: HttpClientExt + Clone + 'static,
79{
80 type Response = GenerateContentResponse;
81 type StreamingResponse = StreamingCompletionResponse;
82 type Client = super::Client<T>;
83
84 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
85 Self::new(client.clone(), model)
86 }
87
88 async fn completion(
89 &self,
90 completion_request: CompletionRequest,
91 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
92 let request_model = resolve_request_model(&self.model, &completion_request);
93 let span = if tracing::Span::current().is_disabled() {
94 info_span!(
95 target: "rig::completions",
96 "generate_content",
97 gen_ai.operation.name = "generate_content",
98 gen_ai.provider.name = "gcp.gemini",
99 gen_ai.request.model = &request_model,
100 gen_ai.system_instructions = &completion_request.preamble,
101 gen_ai.response.id = tracing::field::Empty,
102 gen_ai.response.model = tracing::field::Empty,
103 gen_ai.usage.output_tokens = tracing::field::Empty,
104 gen_ai.usage.input_tokens = tracing::field::Empty,
105 gen_ai.usage.cached_tokens = tracing::field::Empty,
106 )
107 } else {
108 tracing::Span::current()
109 };
110
111 let request = create_request_body(completion_request)?;
112
113 if enabled!(Level::TRACE) {
114 tracing::trace!(
115 target: "rig::completions",
116 "Gemini completion request: {}",
117 serde_json::to_string_pretty(&request)?
118 );
119 }
120
121 let body = serde_json::to_vec(&request)?;
122
123 let path = completion_endpoint(&request_model);
124
125 let request = self
126 .client
127 .post(path.as_str())?
128 .body(body)
129 .map_err(|e| CompletionError::HttpError(e.into()))?;
130
131 async move {
132 let response = self.client.send::<_, Vec<u8>>(request).await?;
133
134 if response.status().is_success() {
135 let response_body = response
136 .into_body()
137 .await
138 .map_err(CompletionError::HttpError)?;
139
140 let response_text = String::from_utf8_lossy(&response_body).to_string();
141
142 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
143 .map_err(|err| {
144 tracing::error!(
145 error = %err,
146 body = %response_text,
147 "Failed to deserialize Gemini completion response"
148 );
149 CompletionError::JsonError(err)
150 })?;
151
152 let span = tracing::Span::current();
153 span.record_response_metadata(&response);
154 span.record_token_usage(&response.usage_metadata);
155
156 if enabled!(Level::TRACE) {
157 tracing::trace!(
158 target: "rig::completions",
159 "Gemini completion response: {}",
160 serde_json::to_string_pretty(&response)?
161 );
162 }
163
164 response.try_into()
165 } else {
166 let text = String::from_utf8_lossy(
167 &response
168 .into_body()
169 .await
170 .map_err(CompletionError::HttpError)?,
171 )
172 .into();
173
174 Err(CompletionError::ProviderError(text))
175 }
176 }
177 .instrument(span)
178 .await
179 }
180
181 async fn stream(
182 &self,
183 request: CompletionRequest,
184 ) -> Result<
185 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
186 CompletionError,
187 > {
188 CompletionModel::stream(self, request).await
189 }
190}
191
192pub(crate) fn create_request_body(
193 completion_request: CompletionRequest,
194) -> Result<GenerateContentRequest, CompletionError> {
195 let documents_message = completion_request.normalized_documents();
196
197 let CompletionRequest {
198 model: _,
199 preamble,
200 chat_history,
201 documents: _,
202 tools: function_tools,
203 temperature,
204 max_tokens,
205 tool_choice,
206 mut additional_params,
207 output_schema,
208 } = completion_request;
209
210 let mut full_history = Vec::new();
211 if let Some(msg) = documents_message {
212 full_history.push(msg);
213 }
214 full_history.extend(chat_history);
215 let (history_system, full_history) = split_system_messages_from_history(full_history);
216
217 let mut additional_params_payload = additional_params
218 .take()
219 .unwrap_or_else(|| Value::Object(Map::new()));
220 let mut additional_tools =
221 extract_tools_from_additional_params(&mut additional_params_payload)?;
222
223 let AdditionalParameters {
224 mut generation_config,
225 additional_params,
226 } = serde_json::from_value::<AdditionalParameters>(additional_params_payload)?;
227
228 if let Some(schema) = output_schema {
230 let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
231 cfg.response_mime_type = Some("application/json".to_string());
232 cfg.response_json_schema = Some(schema.to_value());
233 }
234
235 generation_config = generation_config.map(|mut cfg| {
236 if let Some(temp) = temperature {
237 cfg.temperature = Some(temp);
238 };
239
240 if let Some(max_tokens) = max_tokens {
241 cfg.max_output_tokens = Some(max_tokens);
242 };
243
244 cfg
245 });
246
247 let mut system_parts: Vec<Part> = Vec::new();
248 if let Some(preamble) = preamble.filter(|preamble| !preamble.is_empty()) {
249 system_parts.push(preamble.into());
250 }
251 for content in history_system {
252 if !content.is_empty() {
253 system_parts.push(content.into());
254 }
255 }
256 let system_instruction = if system_parts.is_empty() {
257 None
258 } else {
259 Some(Content {
260 parts: system_parts,
261 role: Some(Role::Model),
262 })
263 };
264
265 let mut tools = if function_tools.is_empty() {
266 Vec::new()
267 } else {
268 vec![serde_json::to_value(Tool::try_from(function_tools)?)?]
269 };
270 tools.append(&mut additional_tools);
271 let tools = if tools.is_empty() { None } else { Some(tools) };
272
273 let tool_config = if let Some(cfg) = tool_choice {
274 Some(ToolConfig {
275 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
276 })
277 } else {
278 None
279 };
280
281 let request = GenerateContentRequest {
282 contents: full_history
283 .into_iter()
284 .map(|msg| {
285 msg.try_into()
286 .map_err(|e| CompletionError::RequestError(Box::new(e)))
287 })
288 .collect::<Result<Vec<_>, _>>()?,
289 generation_config,
290 safety_settings: None,
291 tools,
292 tool_config,
293 system_instruction,
294 additional_params,
295 };
296
297 Ok(request)
298}
299
300fn split_system_messages_from_history(
301 history: Vec<completion::Message>,
302) -> (Vec<String>, Vec<completion::Message>) {
303 let mut system = Vec::new();
304 let mut remaining = Vec::new();
305
306 for message in history {
307 match message {
308 completion::Message::System { content } => system.push(content),
309 other => remaining.push(other),
310 }
311 }
312
313 (system, remaining)
314}
315
316fn extract_tools_from_additional_params(
317 additional_params: &mut Value,
318) -> Result<Vec<Value>, CompletionError> {
319 if let Some(map) = additional_params.as_object_mut()
320 && let Some(raw_tools) = map.remove("tools")
321 {
322 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
323 CompletionError::RequestError(
324 format!("Invalid Gemini `additional_params.tools` payload: {err}").into(),
325 )
326 });
327 }
328
329 Ok(Vec::new())
330}
331
332pub(crate) fn resolve_request_model(
333 default_model: &str,
334 completion_request: &CompletionRequest,
335) -> String {
336 completion_request
337 .model
338 .clone()
339 .unwrap_or_else(|| default_model.to_string())
340}
341
/// Relative path of the non-streaming `generateContent` endpoint for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":generateContent"].concat()
}
345
/// Relative path of the `streamGenerateContent` endpoint for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    let mut path = String::from("/v1beta/models/");
    path.push_str(model);
    path.push_str(":streamGenerateContent");
    path
}
349
350impl TryFrom<completion::ToolDefinition> for Tool {
351 type Error = CompletionError;
352
353 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
354 let parameters: Option<Schema> =
355 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
356 None
357 } else {
358 Some(tool.parameters.try_into()?)
359 };
360
361 Ok(Self {
362 function_declarations: vec![FunctionDeclaration {
363 name: tool.name,
364 description: tool.description,
365 parameters,
366 }],
367 code_execution: None,
368 })
369 }
370}
371
372impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
373 type Error = CompletionError;
374
375 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
376 let mut function_declarations = Vec::new();
377
378 for tool in tools {
379 let parameters =
380 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
381 None
382 } else {
383 match tool.parameters.try_into() {
384 Ok(schema) => Some(schema),
385 Err(e) => {
386 let emsg = format!(
387 "Tool '{}' could not be converted to a schema: {:?}",
388 tool.name, e,
389 );
390 return Err(CompletionError::ProviderError(emsg));
391 }
392 }
393 };
394
395 function_declarations.push(FunctionDeclaration {
396 name: tool.name,
397 description: tool.description,
398 parameters,
399 });
400 }
401
402 Ok(Self {
403 function_declarations,
404 code_execution: None,
405 })
406 }
407}
408
409impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
410 type Error = CompletionError;
411
412 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
413 let candidate = response.candidates.first().ok_or_else(|| {
414 CompletionError::ResponseError("No response candidates in response".into())
415 })?;
416
417 let content = candidate
418 .content
419 .as_ref()
420 .ok_or_else(|| {
421 let reason = candidate
422 .finish_reason
423 .as_ref()
424 .map(|r| format!("finish_reason={r:?}"))
425 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
426 let message = candidate
427 .finish_message
428 .as_deref()
429 .unwrap_or("no finish message provided");
430 CompletionError::ResponseError(format!(
431 "Gemini candidate missing content ({reason}, finish_message={message})"
432 ))
433 })?
434 .parts
435 .iter()
436 .map(
437 |Part {
438 thought,
439 thought_signature,
440 part,
441 ..
442 }| {
443 Ok(match part {
444 PartKind::Text(text) => {
445 if let Some(thought) = thought
446 && *thought
447 {
448 completion::AssistantContent::Reasoning(
449 Reasoning::new_with_signature(text, thought_signature.clone()),
450 )
451 } else {
452 completion::AssistantContent::text(text)
453 }
454 }
455 PartKind::InlineData(inline_data) => {
456 let mime_type =
457 message::MediaType::from_mime_type(&inline_data.mime_type);
458
459 match mime_type {
460 Some(message::MediaType::Image(media_type)) => {
461 message::AssistantContent::image_base64(
462 &inline_data.data,
463 Some(media_type),
464 Some(message::ImageDetail::default()),
465 )
466 }
467 _ => {
468 return Err(CompletionError::ResponseError(format!(
469 "Unsupported media type {mime_type:?}"
470 )));
471 }
472 }
473 }
474 PartKind::FunctionCall(function_call) => {
475 completion::AssistantContent::ToolCall(
476 message::ToolCall::new(
477 function_call.name.clone(),
478 message::ToolFunction::new(
479 function_call.name.clone(),
480 function_call.args.clone(),
481 ),
482 )
483 .with_signature(thought_signature.clone()),
484 )
485 }
486 _ => {
487 return Err(CompletionError::ResponseError(
488 "Response did not contain a message or tool call".into(),
489 ));
490 }
491 })
492 },
493 )
494 .collect::<Result<Vec<_>, _>>()?;
495
496 let choice = OneOrMany::many(content).map_err(|_| {
497 CompletionError::ResponseError(
498 "Response contained no message or tool call (empty)".to_owned(),
499 )
500 })?;
501
502 let usage = response
503 .usage_metadata
504 .as_ref()
505 .map(|usage| completion::Usage {
506 input_tokens: usage.prompt_token_count as u64,
507 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
508 total_tokens: usage.total_token_count as u64,
509 cached_input_tokens: 0,
510 })
511 .unwrap_or_default();
512
513 Ok(completion::CompletionResponse {
514 choice,
515 usage,
516 raw_response: response,
517 message_id: None,
518 })
519 }
520}
521
522pub mod gemini_api_types {
523 use crate::telemetry::ProviderResponseExt;
524 use std::{collections::HashMap, convert::Infallible, str::FromStr};
525
526 use serde::{Deserialize, Serialize};
530 use serde_json::{Value, json};
531
532 use crate::completion::GetTokenUsage;
533 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
534 use crate::{
535 completion::CompletionError,
536 message::{self},
537 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
538 };
539
    /// Gemini-specific options taken from `CompletionRequest::additional_params`.
    ///
    /// The `generationConfig` key is parsed into a typed [`GenerationConfig`]. All other keys are
    /// captured as raw JSON in `additional_params`, which `create_request_body` forwards to
    /// `GenerateContentRequest::additional_params`.
    #[derive(Debug, Deserialize, Serialize, Default)]
    #[serde(rename_all = "camelCase")]
    pub struct AdditionalParameters {
        /// Typed `generationConfig`, if the caller supplied one.
        pub generation_config: Option<GenerationConfig>,
        // Every remaining top-level key (serde `flatten`); omitted from output when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
549
550 impl AdditionalParameters {
551 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
552 self.generation_config = Some(cfg);
553 self
554 }
555
556 pub fn with_params(mut self, params: serde_json::Value) -> Self {
557 self.additional_params = Some(params);
558 self
559 }
560 }
561
    /// Body of a Gemini `generateContent` response (camelCase JSON).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentResponse {
        /// Response identifier (not optional in this struct).
        pub response_id: String,
        /// Candidate completions; the completion conversion in this file reads only the first.
        pub candidates: Vec<ContentCandidate>,
        /// Feedback about the prompt, if provided.
        pub prompt_feedback: Option<PromptFeedback>,
        /// Token accounting, if provided.
        pub usage_metadata: Option<UsageMetadata>,
        /// Model version string reported by the API, if provided.
        pub model_version: Option<String>,
    }
581
582 impl ProviderResponseExt for GenerateContentResponse {
583 type OutputMessage = ContentCandidate;
584 type Usage = UsageMetadata;
585
586 fn get_response_id(&self) -> Option<String> {
587 Some(self.response_id.clone())
588 }
589
590 fn get_response_model_name(&self) -> Option<String> {
591 None
592 }
593
594 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
595 self.candidates.clone()
596 }
597
598 fn get_text_response(&self) -> Option<String> {
599 let str = self
600 .candidates
601 .iter()
602 .filter_map(|x| {
603 let content = x.content.as_ref()?;
604 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
605 return None;
606 }
607
608 let res = content
609 .parts
610 .iter()
611 .filter_map(|part| {
612 if let PartKind::Text(ref str) = part.part {
613 Some(str.to_owned())
614 } else {
615 None
616 }
617 })
618 .collect::<Vec<String>>()
619 .join("\n");
620
621 Some(res)
622 })
623 .collect::<Vec<String>>()
624 .join("\n");
625
626 if str.is_empty() { None } else { Some(str) }
627 }
628
629 fn get_usage(&self) -> Option<Self::Usage> {
630 self.usage_metadata.clone()
631 }
632 }
633
    /// One candidate completion inside a [`GenerateContentResponse`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ContentCandidate {
        // May be absent; the completion conversion then reports `finish_reason` / `finish_message`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub content: Option<Content>,
        /// Reason generation stopped, if reported.
        pub finish_reason: Option<FinishReason>,
        /// Per-category safety ratings, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
        /// Citation information, if reported.
        pub citation_metadata: Option<CitationMetadata>,
        /// Token count for this candidate, if reported.
        pub token_count: Option<i32>,
        /// Average log-probability, if reported.
        pub avg_logprobs: Option<f64>,
        /// Detailed log-probability results, if reported.
        pub logprobs_result: Option<LogprobsResult>,
        /// Candidate index, if reported.
        pub index: Option<i32>,
        /// Human-readable detail accompanying `finish_reason`, if any.
        pub finish_message: Option<String>,
    }
662
    /// One conversation turn: an ordered list of parts and an optional author role.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct Content {
        // Defaults to an empty list when the key is missing from JSON.
        #[serde(default)]
        pub parts: Vec<Part>,
        /// Author of the turn (`user` or `model`).
        pub role: Option<Role>,
    }
672
673 impl TryFrom<message::Message> for Content {
674 type Error = message::MessageError;
675
676 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
677 Ok(match msg {
678 message::Message::System { content } => Content {
679 parts: vec![content.into()],
680 role: Some(Role::User),
681 },
682 message::Message::User { content } => Content {
683 parts: content
684 .into_iter()
685 .map(|c| c.try_into())
686 .collect::<Result<Vec<_>, _>>()?,
687 role: Some(Role::User),
688 },
689 message::Message::Assistant { content, .. } => Content {
690 role: Some(Role::Model),
691 parts: content
692 .into_iter()
693 .map(|content| content.try_into())
694 .collect::<Result<Vec<_>, _>>()?,
695 },
696 })
697 }
698 }
699
    /// Author of a [`Content`] turn; serialized as `"user"` / `"model"`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        /// End user (system and user messages are also mapped here).
        User,
        /// The model.
        Model,
    }
706
    /// One piece of content within a [`Content`] turn.
    ///
    /// The payload (`part`) is flattened, so its variant key (e.g. `text`, `inlineData`) appears
    /// alongside `thought` / `thoughtSignature` in the same JSON object.
    #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Part {
        // `Some(true)` marks reasoning ("thought") text.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought: Option<bool>,
        // Opaque signature attached to thoughts / function calls; echoed back on later turns.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thought_signature: Option<String>,
        #[serde(flatten)]
        pub part: PartKind,
        // Extra raw JSON keys merged into this part's object.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<Value>,
    }
721
    /// Payload of a [`Part`]; serde renames variants to camelCase (`text`, `inlineData`,
    /// `functionCall`, ...).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub enum PartKind {
        /// Plain text.
        Text(String),
        /// Inline bytes with a mime type.
        InlineData(Blob),
        /// A model-issued function (tool) call.
        FunctionCall(FunctionCall),
        /// The result of a function call sent back to the model.
        FunctionResponse(FunctionResponse),
        /// A reference to a file by URI.
        FileData(FileData),
        /// Code produced for the code-execution tool.
        ExecutableCode(ExecutableCode),
        /// Output of executed code.
        CodeExecutionResult(CodeExecutionResult),
    }
736
737 impl Default for PartKind {
740 fn default() -> Self {
741 Self::Text(String::new())
742 }
743 }
744
745 impl From<String> for Part {
746 fn from(text: String) -> Self {
747 Self {
748 thought: Some(false),
749 thought_signature: None,
750 part: PartKind::Text(text),
751 additional_params: None,
752 }
753 }
754 }
755
756 impl From<&str> for Part {
757 fn from(text: &str) -> Self {
758 Self::from(text.to_string())
759 }
760 }
761
762 impl FromStr for Part {
763 type Err = Infallible;
764
765 fn from_str(s: &str) -> Result<Self, Self::Err> {
766 Ok(s.into())
767 }
768 }
769
770 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
771 type Error = message::MessageError;
772 fn try_from(
773 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
774 ) -> Result<Self, Self::Error> {
775 let mime_type = mime_type.to_mime_type().to_string();
776 let part = match doc_src {
777 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
778 mime_type: Some(mime_type),
779 file_uri: url,
780 }),
781 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
782 PartKind::InlineData(Blob { mime_type, data })
783 }
784 DocumentSourceKind::Raw(_) => {
785 return Err(message::MessageError::ConversionError(
786 "Raw files not supported, encode as base64 first".into(),
787 ));
788 }
789 DocumentSourceKind::Unknown => {
790 return Err(message::MessageError::ConversionError(
791 "Can't convert an unknown document source".to_string(),
792 ));
793 }
794 };
795
796 Ok(part)
797 }
798 }
799
800 impl TryFrom<message::UserContent> for Part {
801 type Error = message::MessageError;
802
803 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
804 match content {
805 message::UserContent::Text(message::Text { text }) => Ok(Part {
806 thought: Some(false),
807 thought_signature: None,
808 part: PartKind::Text(text),
809 additional_params: None,
810 }),
811 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
812 let mut response_json: Option<serde_json::Value> = None;
813 let mut parts: Vec<FunctionResponsePart> = Vec::new();
814
815 for item in content.iter() {
816 match item {
817 message::ToolResultContent::Text(text) => {
818 let result: serde_json::Value =
819 serde_json::from_str(&text.text).unwrap_or_else(|error| {
820 tracing::trace!(
821 ?error,
822 "Tool result is not a valid JSON, treat it as normal string"
823 );
824 json!(&text.text)
825 });
826
827 response_json = Some(match response_json {
828 Some(mut existing) => {
829 if let serde_json::Value::Object(ref mut map) = existing {
830 map.insert("text".to_string(), result);
831 }
832 existing
833 }
834 None => json!({ "result": result }),
835 });
836 }
837 message::ToolResultContent::Image(image) => {
838 let part = match &image.data {
839 DocumentSourceKind::Base64(b64) => {
840 let mime_type = image
841 .media_type
842 .as_ref()
843 .ok_or(message::MessageError::ConversionError(
844 "Image media type is required for Gemini tool results".to_string(),
845 ))?
846 .to_mime_type();
847
848 FunctionResponsePart {
849 inline_data: Some(FunctionResponseInlineData {
850 mime_type: mime_type.to_string(),
851 data: b64.clone(),
852 display_name: None,
853 }),
854 file_data: None,
855 }
856 }
857 DocumentSourceKind::Url(url) => {
858 let mime_type = image
859 .media_type
860 .as_ref()
861 .map(|mt| mt.to_mime_type().to_string());
862
863 FunctionResponsePart {
864 inline_data: None,
865 file_data: Some(FileData {
866 mime_type,
867 file_uri: url.clone(),
868 }),
869 }
870 }
871 _ => {
872 return Err(message::MessageError::ConversionError(
873 "Unsupported image source kind for tool results"
874 .to_string(),
875 ));
876 }
877 };
878 parts.push(part);
879 }
880 }
881 }
882
883 Ok(Part {
884 thought: Some(false),
885 thought_signature: None,
886 part: PartKind::FunctionResponse(FunctionResponse {
887 name: id,
888 response: response_json,
889 parts: if parts.is_empty() { None } else { Some(parts) },
890 }),
891 additional_params: None,
892 })
893 }
894 message::UserContent::Image(message::Image {
895 data, media_type, ..
896 }) => match media_type {
897 Some(media_type) => match media_type {
898 message::ImageMediaType::JPEG
899 | message::ImageMediaType::PNG
900 | message::ImageMediaType::WEBP
901 | message::ImageMediaType::HEIC
902 | message::ImageMediaType::HEIF => {
903 let part = PartKind::try_from((media_type, data))?;
904 Ok(Part {
905 thought: Some(false),
906 thought_signature: None,
907 part,
908 additional_params: None,
909 })
910 }
911 _ => Err(message::MessageError::ConversionError(format!(
912 "Unsupported image media type {media_type:?}"
913 ))),
914 },
915 None => Err(message::MessageError::ConversionError(
916 "Media type for image is required for Gemini".to_string(),
917 )),
918 },
919 message::UserContent::Document(message::Document {
920 data, media_type, ..
921 }) => {
922 let Some(media_type) = media_type else {
923 return Err(MessageError::ConversionError(
924 "A mime type is required for document inputs to Gemini".to_string(),
925 ));
926 };
927
928 if matches!(
931 media_type,
932 message::DocumentMediaType::TXT
933 | message::DocumentMediaType::RTF
934 | message::DocumentMediaType::HTML
935 | message::DocumentMediaType::CSS
936 | message::DocumentMediaType::MARKDOWN
937 | message::DocumentMediaType::CSV
938 | message::DocumentMediaType::XML
939 | message::DocumentMediaType::Javascript
940 | message::DocumentMediaType::Python
941 ) {
942 use base64::Engine;
943 let part = match data {
944 DocumentSourceKind::String(text) => PartKind::Text(text),
945 DocumentSourceKind::Base64(data) => {
946 let text = String::from_utf8(
948 base64::engine::general_purpose::STANDARD
949 .decode(&data)
950 .map_err(|e| {
951 MessageError::ConversionError(format!(
952 "Failed to decode base64: {e}"
953 ))
954 })?,
955 )
956 .map_err(|e| {
957 MessageError::ConversionError(format!(
958 "Invalid UTF-8 in document: {e}"
959 ))
960 })?;
961 PartKind::Text(text)
962 }
963 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
964 mime_type: Some(media_type.to_mime_type().to_string()),
965 file_uri,
966 }),
967 DocumentSourceKind::Raw(_) => {
968 return Err(MessageError::ConversionError(
969 "Raw files not supported, encode as base64 first".to_string(),
970 ));
971 }
972 DocumentSourceKind::Unknown => {
973 return Err(MessageError::ConversionError(
974 "Document has no body".to_string(),
975 ));
976 }
977 };
978
979 Ok(Part {
980 thought: Some(false),
981 part,
982 ..Default::default()
983 })
984 } else if !media_type.is_code() {
985 let mime_type = media_type.to_mime_type().to_string();
986
987 let part = match data {
988 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
989 mime_type: Some(mime_type),
990 file_uri,
991 }),
992 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
993 PartKind::InlineData(Blob { mime_type, data })
994 }
995 DocumentSourceKind::Raw(_) => {
996 return Err(message::MessageError::ConversionError(
997 "Raw files not supported, encode as base64 first".into(),
998 ));
999 }
1000 _ => {
1001 return Err(message::MessageError::ConversionError(
1002 "Document has no body".to_string(),
1003 ));
1004 }
1005 };
1006
1007 Ok(Part {
1008 thought: Some(false),
1009 part,
1010 ..Default::default()
1011 })
1012 } else {
1013 Err(message::MessageError::ConversionError(format!(
1014 "Unsupported document media type {media_type:?}"
1015 )))
1016 }
1017 }
1018
1019 message::UserContent::Audio(message::Audio {
1020 data, media_type, ..
1021 }) => {
1022 let Some(media_type) = media_type else {
1023 return Err(MessageError::ConversionError(
1024 "A mime type is required for audio inputs to Gemini".to_string(),
1025 ));
1026 };
1027
1028 let mime_type = media_type.to_mime_type().to_string();
1029
1030 let part = match data {
1031 DocumentSourceKind::Base64(data) => {
1032 PartKind::InlineData(Blob { data, mime_type })
1033 }
1034
1035 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1036 mime_type: Some(mime_type),
1037 file_uri,
1038 }),
1039 DocumentSourceKind::String(_) => {
1040 return Err(message::MessageError::ConversionError(
1041 "Strings cannot be used as audio files!".into(),
1042 ));
1043 }
1044 DocumentSourceKind::Raw(_) => {
1045 return Err(message::MessageError::ConversionError(
1046 "Raw files not supported, encode as base64 first".into(),
1047 ));
1048 }
1049 DocumentSourceKind::Unknown => {
1050 return Err(message::MessageError::ConversionError(
1051 "Content has no body".to_string(),
1052 ));
1053 }
1054 };
1055
1056 Ok(Part {
1057 thought: Some(false),
1058 part,
1059 ..Default::default()
1060 })
1061 }
1062 message::UserContent::Video(message::Video {
1063 data,
1064 media_type,
1065 additional_params,
1066 ..
1067 }) => {
1068 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
1069
1070 let part = match data {
1071 DocumentSourceKind::Url(file_uri) => {
1072 if file_uri.starts_with("https://www.youtube.com") {
1073 PartKind::FileData(FileData {
1074 mime_type,
1075 file_uri,
1076 })
1077 } else {
1078 if mime_type.is_none() {
1079 return Err(MessageError::ConversionError(
1080 "A mime type is required for non-Youtube video file inputs to Gemini"
1081 .to_string(),
1082 ));
1083 }
1084
1085 PartKind::FileData(FileData {
1086 mime_type,
1087 file_uri,
1088 })
1089 }
1090 }
1091 DocumentSourceKind::Base64(data) => {
1092 let Some(mime_type) = mime_type else {
1093 return Err(MessageError::ConversionError(
1094 "A media type is expected for base64 encoded strings"
1095 .to_string(),
1096 ));
1097 };
1098 PartKind::InlineData(Blob { mime_type, data })
1099 }
1100 DocumentSourceKind::String(_) => {
1101 return Err(message::MessageError::ConversionError(
1102 "Strings cannot be used as audio files!".into(),
1103 ));
1104 }
1105 DocumentSourceKind::Raw(_) => {
1106 return Err(message::MessageError::ConversionError(
1107 "Raw file data not supported, encode as base64 first".into(),
1108 ));
1109 }
1110 DocumentSourceKind::Unknown => {
1111 return Err(message::MessageError::ConversionError(
1112 "Media type for video is required for Gemini".to_string(),
1113 ));
1114 }
1115 };
1116
1117 Ok(Part {
1118 thought: Some(false),
1119 thought_signature: None,
1120 part,
1121 additional_params,
1122 })
1123 }
1124 }
1125 }
1126 }
1127
1128 impl TryFrom<message::AssistantContent> for Part {
1129 type Error = message::MessageError;
1130
1131 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1132 match content {
1133 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1134 message::AssistantContent::Image(message::Image {
1135 data, media_type, ..
1136 }) => match media_type {
1137 Some(media_type) => match media_type {
1138 message::ImageMediaType::JPEG
1139 | message::ImageMediaType::PNG
1140 | message::ImageMediaType::WEBP
1141 | message::ImageMediaType::HEIC
1142 | message::ImageMediaType::HEIF => {
1143 let part = PartKind::try_from((media_type, data))?;
1144 Ok(Part {
1145 thought: Some(false),
1146 thought_signature: None,
1147 part,
1148 additional_params: None,
1149 })
1150 }
1151 _ => Err(message::MessageError::ConversionError(format!(
1152 "Unsupported image media type {media_type:?}"
1153 ))),
1154 },
1155 None => Err(message::MessageError::ConversionError(
1156 "Media type for image is required for Gemini".to_string(),
1157 )),
1158 },
1159 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1160 message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1161 thought: Some(true),
1162 thought_signature: reasoning.first_signature().map(str::to_owned),
1163 part: PartKind::Text(reasoning.display_text()),
1164 additional_params: None,
1165 }),
1166 }
1167 }
1168 }
1169
1170 impl From<message::ToolCall> for Part {
1171 fn from(tool_call: message::ToolCall) -> Self {
1172 Self {
1173 thought: Some(false),
1174 thought_signature: tool_call.signature,
1175 part: PartKind::FunctionCall(FunctionCall {
1176 name: tool_call.function.name,
1177 args: tool_call.function.arguments,
1178 }),
1179 additional_params: None,
1180 }
1181 }
1182 }
1183
    /// Inline binary payload (`inlineData`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        /// Mime type of `data`.
        pub mime_type: String,
        // Encoded payload. Response parsing treats it as base64; note that request conversion
        // also passes `DocumentSourceKind::String` input through unchanged.
        pub data: String,
    }
1195
    /// A function (tool) invocation produced by the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionCall {
        /// Name of the function to call.
        pub name: String,
        /// Call arguments as arbitrary JSON.
        pub args: serde_json::Value,
    }
1206
1207 impl From<message::ToolCall> for FunctionCall {
1208 fn from(tool_call: message::ToolCall) -> Self {
1209 Self {
1210 name: tool_call.function.name,
1211 args: tool_call.function.arguments,
1212 }
1213 }
1214 }
1215
    /// The result of a function call, returned to the model.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct FunctionResponse {
        /// Identifies the call being answered (the user-content converter uses the tool result id).
        pub name: String,
        // JSON result body; omitted when the tool result had no text items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response: Option<serde_json::Value>,
        // Multimodal attachments (e.g. images); omitted when empty/None.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parts: Option<Vec<FunctionResponsePart>>,
    }
1231
    /// One multimodal attachment of a [`FunctionResponse`]: inline data or a file reference.
    ///
    /// The tool-result converter in this file sets exactly one of the two fields.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponsePart {
        #[serde(skip_serializing_if = "Option::is_none")]
        pub inline_data: Option<FunctionResponseInlineData>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub file_data: Option<FileData>,
    }
1243
    /// Inline payload attached to a function response (like [`Blob`], plus an optional
    /// display name). Keys are camelCase.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// MIME type of `data`.
        pub mime_type: String,
        /// The payload; NOTE(review): presumably base64-encoded — confirm with the producer.
        pub data: String,
        /// Optional human-readable name; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1256
1257 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1259 #[serde(rename_all = "camelCase")]
1260 pub struct FileData {
1261 pub mime_type: Option<String>,
1263 pub file_uri: String,
1265 }
1266
    /// A (category, probability) pair describing how likely content is to be harmful in one
    /// category.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// The harm category being rated.
        pub category: HarmCategory,
        /// The rated likelihood for that category.
        pub probability: HarmProbability,
    }
1272
    /// Likelihood level used in [`SafetyRating`].
    ///
    /// Variants map to SCREAMING_SNAKE_CASE strings, e.g. `Negligible` ↔ `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1282
    /// Harm categories, used both in response ratings ([`SafetyRating`]) and in request
    /// settings ([`SafetySetting`]).
    ///
    /// Variants map to SCREAMING_SNAKE_CASE strings, e.g. `HarmCategoryHateSpeech` ↔
    /// `HARM_CATEGORY_HATE_SPEECH`. NOTE(review): a category string not listed here will fail
    /// deserialization of the enclosing response.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1299
    /// Token accounting reported by the API for one request (`usageMetadata`).
    ///
    /// Keys are camelCase; optional counters are omitted when serialized if `None`.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Tokens in the prompt.
        pub prompt_token_count: i32,
        /// Tokens served from cached content, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Tokens in the generated candidates, if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total tokens as reported by the API.
        pub total_token_count: i32,
        /// Tokens spent on model "thinking", if reported.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1312
1313 impl std::fmt::Display for UsageMetadata {
1314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1315 write!(
1316 f,
1317 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1318 self.prompt_token_count,
1319 match self.cached_content_token_count {
1320 Some(count) => count.to_string(),
1321 None => "n/a".to_string(),
1322 },
1323 match self.candidates_token_count {
1324 Some(count) => count.to_string(),
1325 None => "n/a".to_string(),
1326 },
1327 self.total_token_count
1328 )
1329 }
1330 }
1331
1332 impl GetTokenUsage for UsageMetadata {
1333 fn token_usage(&self) -> Option<crate::completion::Usage> {
1334 let mut usage = crate::completion::Usage::new();
1335
1336 usage.input_tokens = self.prompt_token_count as u64;
1337 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1338 + self.candidates_token_count.unwrap_or_default()
1339 + self.thoughts_token_count.unwrap_or_default())
1340 as u64;
1341 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1342
1343 Some(usage)
1344 }
1345 }
1346
    /// Feedback about the prompt itself (`promptFeedback`), e.g. why it was blocked.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Reason the prompt was blocked, if it was.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings for the prompt, if reported.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1356
1357 #[derive(Debug, Deserialize, Serialize)]
1359 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1360 pub enum BlockReason {
1361 BlockReasonUnspecified,
1363 Safety,
1365 Other,
1367 Blocklist,
1369 ProhibitedContent,
1371 }
1372
1373 #[derive(Clone, Debug, Deserialize, Serialize)]
1374 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
1375 pub enum FinishReason {
1376 FinishReasonUnspecified,
1378 Stop,
1380 MaxTokens,
1382 Safety,
1384 Recitation,
1386 Language,
1388 Other,
1390 Blocklist,
1392 ProhibitedContent,
1394 Spii,
1396 MalformedFunctionCall,
1398 }
1399
    /// Source attributions for a candidate (`citationMetadata`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// Cited sources (`citationSources`).
        pub citation_sources: Vec<CitationSource>,
    }
1405
    /// A single cited source. All fields are optional and omitted when serialized if `None`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start of the attributed segment of the output (`startIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End of the attributed segment of the output (`endIndex`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License of the source, if known.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1418
1419 #[derive(Clone, Debug, Deserialize, Serialize)]
1420 #[serde(rename_all = "camelCase")]
1421 pub struct LogprobsResult {
1422 pub top_candidate: Vec<TopCandidate>,
1423 pub chosen_candidate: Vec<LogProbCandidate>,
1424 }
1425
1426 #[derive(Clone, Debug, Deserialize, Serialize)]
1427 pub struct TopCandidate {
1428 pub candidates: Vec<LogProbCandidate>,
1429 }
1430
1431 #[derive(Clone, Debug, Deserialize, Serialize)]
1432 #[serde(rename_all = "camelCase")]
1433 pub struct LogProbCandidate {
1434 pub token: String,
1435 pub token_id: String,
1436 pub log_probability: f64,
1437 }
1438
    /// Generation options sent as `generationConfig`.
    ///
    /// Keys are camelCase. Every field is optional and omitted from the request when `None`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Sequences that stop generation when produced.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type of the generated output (e.g. for JSON output).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Output schema in Gemini's OpenAPI-subset [`Schema`] form.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Raw JSON-Schema output schema. NOTE(review): serialized under the leading-underscore
        /// key `_responseJsonSchema`; why this differs from `response_json_schema` is not visible
        /// here — confirm against the Gemini API docs before relying on it.
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Raw JSON-Schema output schema (`responseJsonSchema`).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of candidates to generate.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Cap on generated tokens.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Nucleus-sampling probability mass.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling cutoff.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether to return token log-probabilities.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Number of top log-probabilities to return per step.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking/reasoning options.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image-generation options.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1531
1532 impl Default for GenerationConfig {
1533 fn default() -> Self {
1534 Self {
1535 temperature: Some(1.0),
1536 max_output_tokens: Some(4096),
1537 stop_sequences: None,
1538 response_mime_type: None,
1539 response_schema: None,
1540 _response_json_schema: None,
1541 response_json_schema: None,
1542 candidate_count: None,
1543 top_p: None,
1544 top_k: None,
1545 presence_penalty: None,
1546 frequency_penalty: None,
1547 response_logprobs: None,
1548 logprobs: None,
1549 thinking_config: None,
1550 image_config: None,
1551 }
1552 }
1553 }
1554
    /// Coarse thinking-effort level. Serialized in snake_case: `"minimal"`, `"low"`,
    /// `"medium"`, `"high"`.
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum ThinkingLevel {
        Minimal,
        Low,
        Medium,
        High,
    }
1564
    /// Thinking/reasoning options (`thinkingConfig`). Keys are camelCase; unset fields are
    /// omitted.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Token budget for thinking.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_budget: Option<u32>,
        /// Coarse thinking level, as an alternative to a numeric budget.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_level: Option<ThinkingLevel>,
        /// Whether thought parts should be included in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub include_thoughts: Option<bool>,
    }
1581
    /// Image-generation options (`imageConfig`). Keys are camelCase; unset fields are omitted.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Aspect ratio as a string (format defined by the API).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Image size as a string (format defined by the API).
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1590
    /// Gemini's OpenAPI-subset schema, built from JSON Schema via `TryFrom<Value>`.
    ///
    /// NOTE(review): there is no `rename_all`, so `max_items`/`min_items` serialize in
    /// snake_case; `r#type`/`r#enum` serialize as `type`/`enum`. Unset fields are omitted.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name, e.g. `"object"`, `"array"`, `"string"`; may be empty if none was inferred.
        pub r#type: String,
        /// Optional format hint.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether `null` is allowed.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum array length.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum array length.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Object properties by name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Element schema for arrays.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1616
1617 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1623 let defs = if let Some(obj) = schema.as_object() {
1625 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1626 } else {
1627 None
1628 };
1629
1630 let Some(defs_value) = defs else {
1631 return Ok(schema);
1632 };
1633
1634 let Some(defs_obj) = defs_value.as_object() else {
1635 return Err(CompletionError::ResponseError(
1636 "$defs must be an object".into(),
1637 ));
1638 };
1639
1640 resolve_refs(&mut schema, defs_obj)?;
1641
1642 if let Some(obj) = schema.as_object_mut() {
1644 obj.remove("$defs");
1645 obj.remove("definitions");
1646 }
1647
1648 Ok(schema)
1649 }
1650
1651 fn resolve_refs(
1654 value: &mut Value,
1655 defs: &serde_json::Map<String, Value>,
1656 ) -> Result<(), CompletionError> {
1657 match value {
1658 Value::Object(obj) => {
1659 if let Some(ref_value) = obj.get("$ref")
1660 && let Some(ref_str) = ref_value.as_str()
1661 {
1662 let def_name = parse_ref_path(ref_str)?;
1664
1665 let def = defs.get(&def_name).ok_or_else(|| {
1666 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1667 })?;
1668
1669 let mut resolved = def.clone();
1670 resolve_refs(&mut resolved, defs)?;
1671 *value = resolved;
1672 return Ok(());
1673 }
1674
1675 for (_, v) in obj.iter_mut() {
1676 resolve_refs(v, defs)?;
1677 }
1678 }
1679 Value::Array(arr) => {
1680 for item in arr.iter_mut() {
1681 resolve_refs(item, defs)?;
1682 }
1683 }
1684 _ => {}
1685 }
1686
1687 Ok(())
1688 }
1689
1690 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1696 if let Some(fragment) = ref_str.strip_prefix('#') {
1697 if let Some(name) = fragment.strip_prefix("/$defs/") {
1698 Ok(name.to_string())
1699 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1700 Ok(name.to_string())
1701 } else {
1702 Err(CompletionError::ResponseError(format!(
1703 "Unsupported reference format: {}",
1704 ref_str
1705 )))
1706 }
1707 } else {
1708 Err(CompletionError::ResponseError(format!(
1709 "Only fragment references (#/...) are supported: {}",
1710 ref_str
1711 )))
1712 }
1713 }
1714
1715 fn extract_type(type_value: &Value) -> Option<String> {
1718 if type_value.is_string() {
1719 type_value.as_str().map(String::from)
1720 } else if type_value.is_array() {
1721 type_value
1722 .as_array()
1723 .and_then(|arr| arr.first())
1724 .and_then(|v| v.as_str().map(String::from))
1725 } else {
1726 None
1727 }
1728 }
1729
1730 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1733 composition.as_array().and_then(|arr| {
1734 arr.iter().find_map(|schema| {
1735 if let Some(obj) = schema.as_object() {
1736 if let Some(type_val) = obj.get("type")
1738 && let Some(type_str) = type_val.as_str()
1739 && type_str == "null"
1740 {
1741 return None;
1742 }
1743 obj.get("type").and_then(extract_type).or_else(|| {
1745 if obj.contains_key("properties") {
1746 Some("object".to_string())
1747 } else {
1748 None
1749 }
1750 })
1751 } else {
1752 None
1753 }
1754 })
1755 })
1756 }
1757
1758 fn extract_schema_from_composition(
1761 composition: &Value,
1762 ) -> Option<serde_json::Map<String, Value>> {
1763 composition.as_array().and_then(|arr| {
1764 arr.iter().find_map(|schema| {
1765 if let Some(obj) = schema.as_object()
1766 && let Some(type_val) = obj.get("type")
1767 && let Some(type_str) = type_val.as_str()
1768 {
1769 if type_str == "null" {
1770 return None;
1771 }
1772 Some(obj.clone())
1773 } else {
1774 None
1775 }
1776 })
1777 })
1778 }
1779
1780 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1783 if let Some(type_val) = obj.get("type")
1785 && let Some(type_str) = extract_type(type_val)
1786 {
1787 return type_str;
1788 }
1789
1790 if let Some(any_of) = obj.get("anyOf")
1792 && let Some(type_str) = extract_type_from_composition(any_of)
1793 {
1794 return type_str;
1795 }
1796
1797 if let Some(one_of) = obj.get("oneOf")
1798 && let Some(type_str) = extract_type_from_composition(one_of)
1799 {
1800 return type_str;
1801 }
1802
1803 if let Some(all_of) = obj.get("allOf")
1804 && let Some(type_str) = extract_type_from_composition(all_of)
1805 {
1806 return type_str;
1807 }
1808
1809 if obj.contains_key("properties") {
1811 "object".to_string()
1812 } else {
1813 String::new()
1814 }
1815 }
1816
1817 impl TryFrom<Value> for Schema {
1818 type Error = CompletionError;
1819
1820 fn try_from(value: Value) -> Result<Self, Self::Error> {
1821 let flattened_val = flatten_schema(value)?;
1822 if let Some(obj) = flattened_val.as_object() {
1823 let props_source = if obj.get("properties").is_none() {
1826 if let Some(any_of) = obj.get("anyOf") {
1827 extract_schema_from_composition(any_of)
1828 } else if let Some(one_of) = obj.get("oneOf") {
1829 extract_schema_from_composition(one_of)
1830 } else if let Some(all_of) = obj.get("allOf") {
1831 extract_schema_from_composition(all_of)
1832 } else {
1833 None
1834 }
1835 .unwrap_or(obj.clone())
1836 } else {
1837 obj.clone()
1838 };
1839
1840 let schema_type = infer_type(obj);
1841 let items = obj
1842 .get("items")
1843 .and_then(|v| v.clone().try_into().ok())
1844 .map(Box::new);
1845
1846 let items = if schema_type == "array" && items.is_none() {
1849 Some(Box::new(Schema {
1850 r#type: "string".to_string(),
1851 format: None,
1852 description: None,
1853 nullable: None,
1854 r#enum: None,
1855 max_items: None,
1856 min_items: None,
1857 properties: None,
1858 required: None,
1859 items: None,
1860 }))
1861 } else {
1862 items
1863 };
1864
1865 Ok(Schema {
1866 r#type: schema_type,
1867 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1868 description: obj
1869 .get("description")
1870 .and_then(|v| v.as_str())
1871 .map(String::from),
1872 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1873 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1874 arr.iter()
1875 .filter_map(|v| v.as_str().map(String::from))
1876 .collect()
1877 }),
1878 max_items: obj
1879 .get("maxItems")
1880 .and_then(|v| v.as_i64())
1881 .map(|v| v as i32),
1882 min_items: obj
1883 .get("minItems")
1884 .and_then(|v| v.as_i64())
1885 .map(|v| v as i32),
1886 properties: props_source
1887 .get("properties")
1888 .and_then(|v| v.as_object())
1889 .map(|map| {
1890 map.iter()
1891 .filter_map(|(k, v)| {
1892 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1893 })
1894 .collect()
1895 }),
1896 required: props_source
1897 .get("required")
1898 .and_then(|v| v.as_array())
1899 .map(|arr| {
1900 arr.iter()
1901 .filter_map(|v| v.as_str().map(String::from))
1902 .collect()
1903 }),
1904 items,
1905 })
1906 } else {
1907 Err(CompletionError::ResponseError(
1908 "Expected a JSON object for Schema".into(),
1909 ))
1910 }
1911 }
1912 }
1913
1914 #[derive(Debug, Serialize)]
1915 #[serde(rename_all = "camelCase")]
1916 pub struct GenerateContentRequest {
1917 pub contents: Vec<Content>,
1918 #[serde(skip_serializing_if = "Option::is_none")]
1919 pub tools: Option<Vec<Value>>,
1920 pub tool_config: Option<ToolConfig>,
1921 pub generation_config: Option<GenerationConfig>,
1923 pub safety_settings: Option<Vec<SafetySetting>>,
1937 pub system_instruction: Option<Content>,
1940 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1943 pub additional_params: Option<serde_json::Value>,
1944 }
1945
1946 #[derive(Debug, Serialize)]
1947 #[serde(rename_all = "camelCase")]
1948 pub struct Tool {
1949 pub function_declarations: Vec<FunctionDeclaration>,
1950 pub code_execution: Option<CodeExecution>,
1951 }
1952
    /// Declaration of a callable function exposed to the model.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name the model will use in `functionCall`.
        pub name: String,
        /// Description shown to the model.
        pub description: String,
        /// Parameter schema; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1961
1962 #[derive(Debug, Serialize, Deserialize)]
1963 #[serde(rename_all = "camelCase")]
1964 pub struct ToolConfig {
1965 pub function_calling_config: Option<FunctionCallingMode>,
1966 }
1967
    /// Function-calling mode. It is internally tagged, so it serializes as
    /// `{"mode": "AUTO" | "NONE" | "ANY", ...}`. Defaults to `Auto`.
    ///
    /// NOTE(review): `rename_all` on an enum renames variants only. The `Any` field therefore
    /// serializes as `allowed_function_names` (snake_case), not `allowedFunctionNames` —
    /// confirm the API accepts this form.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        #[default]
        Auto,
        None,
        /// Restricts calls to the listed functions; the list is omitted when `None`.
        Any {
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1979
1980 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1981 type Error = CompletionError;
1982 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1983 let res = match value {
1984 message::ToolChoice::Auto => Self::Auto,
1985 message::ToolChoice::None => Self::None,
1986 message::ToolChoice::Required => Self::Any {
1987 allowed_function_names: None,
1988 },
1989 message::ToolChoice::Specific { function_names } => Self::Any {
1990 allowed_function_names: Some(function_names),
1991 },
1992 };
1993
1994 Ok(res)
1995 }
1996 }
1997
    /// Empty marker object enabling code execution; serializes as `{}`.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
2000
    /// A per-category blocking threshold sent in `safetySettings`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Category the threshold applies to.
        pub category: HarmCategory,
        /// Blocking threshold for that category.
        pub threshold: HarmBlockThreshold,
    }
2007
    /// Blocking threshold for a [`SafetySetting`].
    ///
    /// Variants map to SCREAMING_SNAKE_CASE strings, e.g. `BlockOnlyHigh` ↔ `BLOCK_ONLY_HIGH`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
2018}
2019
2020#[cfg(test)]
2021mod tests {
2022 use crate::{
2023 message,
2024 providers::gemini::completion::gemini_api_types::{
2025 ContentCandidate, FinishReason, flatten_schema,
2026 },
2027 };
2028
2029 use super::*;
2030 use serde_json::json;
2031
2032 #[test]
2033 fn test_resolve_request_model_uses_override() {
2034 let request = CompletionRequest {
2035 model: Some("gemini-2.5-flash".to_string()),
2036 preamble: None,
2037 chat_history: crate::OneOrMany::one("Hello".into()),
2038 documents: vec![],
2039 tools: vec![],
2040 temperature: None,
2041 max_tokens: None,
2042 tool_choice: None,
2043 additional_params: None,
2044 output_schema: None,
2045 };
2046
2047 let request_model = resolve_request_model("gemini-2.0-flash", &request);
2048 assert_eq!(request_model, "gemini-2.5-flash");
2049 assert_eq!(
2050 completion_endpoint(&request_model),
2051 "/v1beta/models/gemini-2.5-flash:generateContent"
2052 );
2053 assert_eq!(
2054 streaming_endpoint(&request_model),
2055 "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
2056 );
2057 }
2058
2059 #[test]
2060 fn test_resolve_request_model_uses_default_when_unset() {
2061 let request = CompletionRequest {
2062 model: None,
2063 preamble: None,
2064 chat_history: crate::OneOrMany::one("Hello".into()),
2065 documents: vec![],
2066 tools: vec![],
2067 temperature: None,
2068 max_tokens: None,
2069 tool_choice: None,
2070 additional_params: None,
2071 output_schema: None,
2072 };
2073
2074 assert_eq!(
2075 resolve_request_model("gemini-2.0-flash", &request),
2076 "gemini-2.0-flash"
2077 );
2078 }
2079
2080 #[test]
2081 fn test_deserialize_message_user() {
2082 let raw_message = r#"{
2083 "parts": [
2084 {"text": "Hello, world!"},
2085 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
2086 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
2087 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
2088 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
2089 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
2090 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
2091 ],
2092 "role": "user"
2093 }"#;
2094
2095 let content: Content = {
2096 let jd = &mut serde_json::Deserializer::from_str(raw_message);
2097 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
2098 panic!("Deserialization error at {}: {}", err.path(), err);
2099 })
2100 };
2101 assert_eq!(content.role, Some(Role::User));
2102 assert_eq!(content.parts.len(), 7);
2103
2104 let parts: Vec<Part> = content.parts.into_iter().collect();
2105
2106 if let Part {
2107 part: PartKind::Text(text),
2108 ..
2109 } = &parts[0]
2110 {
2111 assert_eq!(text, "Hello, world!");
2112 } else {
2113 panic!("Expected text part");
2114 }
2115
2116 if let Part {
2117 part: PartKind::InlineData(inline_data),
2118 ..
2119 } = &parts[1]
2120 {
2121 assert_eq!(inline_data.mime_type, "image/png");
2122 assert_eq!(inline_data.data, "base64encodeddata");
2123 } else {
2124 panic!("Expected inline data part");
2125 }
2126
2127 if let Part {
2128 part: PartKind::FunctionCall(function_call),
2129 ..
2130 } = &parts[2]
2131 {
2132 assert_eq!(function_call.name, "test_function");
2133 assert_eq!(
2134 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2135 "value1"
2136 );
2137 } else {
2138 panic!("Expected function call part");
2139 }
2140
2141 if let Part {
2142 part: PartKind::FunctionResponse(function_response),
2143 ..
2144 } = &parts[3]
2145 {
2146 assert_eq!(function_response.name, "test_function");
2147 assert_eq!(
2148 function_response
2149 .response
2150 .as_ref()
2151 .unwrap()
2152 .get("result")
2153 .unwrap(),
2154 "success"
2155 );
2156 } else {
2157 panic!("Expected function response part");
2158 }
2159
2160 if let Part {
2161 part: PartKind::FileData(file_data),
2162 ..
2163 } = &parts[4]
2164 {
2165 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
2166 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
2167 } else {
2168 panic!("Expected file data part");
2169 }
2170
2171 if let Part {
2172 part: PartKind::ExecutableCode(executable_code),
2173 ..
2174 } = &parts[5]
2175 {
2176 assert_eq!(executable_code.code, "print('Hello, world!')");
2177 } else {
2178 panic!("Expected executable code part");
2179 }
2180
2181 if let Part {
2182 part: PartKind::CodeExecutionResult(code_execution_result),
2183 ..
2184 } = &parts[6]
2185 {
2186 assert_eq!(
2187 code_execution_result.clone().output.unwrap(),
2188 "Hello, world!"
2189 );
2190 } else {
2191 panic!("Expected code execution result part");
2192 }
2193 }
2194
2195 #[test]
2196 fn test_deserialize_message_model() {
2197 let json_data = json!({
2198 "parts": [{"text": "Hello, user!"}],
2199 "role": "model"
2200 });
2201
2202 let content: Content = serde_json::from_value(json_data).unwrap();
2203 assert_eq!(content.role, Some(Role::Model));
2204 assert_eq!(content.parts.len(), 1);
2205 if let Some(Part {
2206 part: PartKind::Text(text),
2207 ..
2208 }) = content.parts.first()
2209 {
2210 assert_eq!(text, "Hello, user!");
2211 } else {
2212 panic!("Expected text part");
2213 }
2214 }
2215
2216 #[test]
2217 fn test_message_conversion_user() {
2218 let msg = message::Message::user("Hello, world!");
2219 let content: Content = msg.try_into().unwrap();
2220 assert_eq!(content.role, Some(Role::User));
2221 assert_eq!(content.parts.len(), 1);
2222 if let Some(Part {
2223 part: PartKind::Text(text),
2224 ..
2225 }) = &content.parts.first()
2226 {
2227 assert_eq!(text, "Hello, world!");
2228 } else {
2229 panic!("Expected text part");
2230 }
2231 }
2232
2233 #[test]
2234 fn test_message_conversion_model() {
2235 let msg = message::Message::assistant("Hello, user!");
2236
2237 let content: Content = msg.try_into().unwrap();
2238 assert_eq!(content.role, Some(Role::Model));
2239 assert_eq!(content.parts.len(), 1);
2240 if let Some(Part {
2241 part: PartKind::Text(text),
2242 ..
2243 }) = &content.parts.first()
2244 {
2245 assert_eq!(text, "Hello, user!");
2246 } else {
2247 panic!("Expected text part");
2248 }
2249 }
2250
2251 #[test]
2252 fn test_thought_signature_is_preserved_from_response_reasoning_part() {
2253 let response = GenerateContentResponse {
2254 response_id: "resp_1".to_string(),
2255 candidates: vec![ContentCandidate {
2256 content: Some(Content {
2257 parts: vec![Part {
2258 thought: Some(true),
2259 thought_signature: Some("thought_sig_123".to_string()),
2260 part: PartKind::Text("thinking text".to_string()),
2261 additional_params: None,
2262 }],
2263 role: Some(Role::Model),
2264 }),
2265 finish_reason: Some(FinishReason::Stop),
2266 safety_ratings: None,
2267 citation_metadata: None,
2268 token_count: None,
2269 avg_logprobs: None,
2270 logprobs_result: None,
2271 index: Some(0),
2272 finish_message: None,
2273 }],
2274 prompt_feedback: None,
2275 usage_metadata: None,
2276 model_version: None,
2277 };
2278
2279 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2280 response.try_into().expect("convert response");
2281 let first = converted.choice.first();
2282 assert!(matches!(
2283 first,
2284 message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2285 if matches!(
2286 content.first(),
2287 Some(message::ReasoningContent::Text {
2288 text,
2289 signature: Some(signature)
2290 }) if text == "thinking text" && signature == "thought_sig_123"
2291 )
2292 ));
2293 }
2294
2295 #[test]
2296 fn test_reasoning_signature_is_emitted_in_gemini_part() {
2297 let msg = message::Message::Assistant {
2298 id: None,
2299 content: OneOrMany::one(message::AssistantContent::Reasoning(
2300 message::Reasoning::new_with_signature(
2301 "structured thought",
2302 Some("reuse_sig_456".to_string()),
2303 ),
2304 )),
2305 };
2306
2307 let converted: Content = msg.try_into().expect("convert message");
2308 let first = converted.parts.first().expect("reasoning part");
2309 assert_eq!(first.thought, Some(true));
2310 assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
2311 assert!(matches!(
2312 &first.part,
2313 PartKind::Text(text) if text == "structured thought"
2314 ));
2315 }
2316
2317 #[test]
2318 fn test_message_conversion_tool_call() {
2319 let tool_call = message::ToolCall {
2320 id: "test_tool".to_string(),
2321 call_id: None,
2322 function: message::ToolFunction {
2323 name: "test_function".to_string(),
2324 arguments: json!({"arg1": "value1"}),
2325 },
2326 signature: None,
2327 additional_params: None,
2328 };
2329
2330 let msg = message::Message::Assistant {
2331 id: None,
2332 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2333 };
2334
2335 let content: Content = msg.try_into().unwrap();
2336 assert_eq!(content.role, Some(Role::Model));
2337 assert_eq!(content.parts.len(), 1);
2338 if let Some(Part {
2339 part: PartKind::FunctionCall(function_call),
2340 ..
2341 }) = content.parts.first()
2342 {
2343 assert_eq!(function_call.name, "test_function");
2344 assert_eq!(
2345 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2346 "value1"
2347 );
2348 } else {
2349 panic!("Expected function call part");
2350 }
2351 }
2352
2353 #[test]
2354 fn test_vec_schema_conversion() {
2355 let schema_with_ref = json!({
2356 "type": "array",
2357 "items": {
2358 "$ref": "#/$defs/Person"
2359 },
2360 "$defs": {
2361 "Person": {
2362 "type": "object",
2363 "properties": {
2364 "first_name": {
2365 "type": ["string", "null"],
2366 "description": "The person's first name, if provided (null otherwise)"
2367 },
2368 "last_name": {
2369 "type": ["string", "null"],
2370 "description": "The person's last name, if provided (null otherwise)"
2371 },
2372 "job": {
2373 "type": ["string", "null"],
2374 "description": "The person's job, if provided (null otherwise)"
2375 }
2376 },
2377 "required": []
2378 }
2379 }
2380 });
2381
2382 let result: Result<Schema, _> = schema_with_ref.try_into();
2383
2384 match result {
2385 Ok(schema) => {
2386 assert_eq!(schema.r#type, "array");
2387
2388 if let Some(items) = schema.items {
2389 println!("item types: {}", items.r#type);
2390
2391 assert_ne!(items.r#type, "", "Items type should not be empty string!");
2392 assert_eq!(items.r#type, "object", "Items should be object type");
2393 } else {
2394 panic!("Schema should have items field for array type");
2395 }
2396 }
2397 Err(e) => println!("Schema conversion failed: {:?}", e),
2398 }
2399 }
2400
2401 #[test]
2402 fn test_object_schema() {
2403 let simple_schema = json!({
2404 "type": "object",
2405 "properties": {
2406 "name": {
2407 "type": "string"
2408 }
2409 }
2410 });
2411
2412 let schema: Schema = simple_schema.try_into().unwrap();
2413 assert_eq!(schema.r#type, "object");
2414 assert!(schema.properties.is_some());
2415 }
2416
2417 #[test]
2418 fn test_array_with_inline_items() {
2419 let inline_schema = json!({
2420 "type": "array",
2421 "items": {
2422 "type": "object",
2423 "properties": {
2424 "name": {
2425 "type": "string"
2426 }
2427 }
2428 }
2429 });
2430
2431 let schema: Schema = inline_schema.try_into().unwrap();
2432 assert_eq!(schema.r#type, "array");
2433
2434 if let Some(items) = schema.items {
2435 assert_eq!(items.r#type, "object");
2436 assert!(items.properties.is_some());
2437 } else {
2438 panic!("Schema should have items field");
2439 }
2440 }
2441 #[test]
2442 fn test_flattened_schema() {
2443 let ref_schema = json!({
2444 "type": "array",
2445 "items": {
2446 "$ref": "#/$defs/Person"
2447 },
2448 "$defs": {
2449 "Person": {
2450 "type": "object",
2451 "properties": {
2452 "name": { "type": "string" }
2453 }
2454 }
2455 }
2456 });
2457
2458 let flattened = flatten_schema(ref_schema).unwrap();
2459 let schema: Schema = flattened.try_into().unwrap();
2460
2461 assert_eq!(schema.r#type, "array");
2462
2463 if let Some(items) = schema.items {
2464 println!("Flattened items type: '{}'", items.r#type);
2465
2466 assert_eq!(items.r#type, "object");
2467 assert!(items.properties.is_some());
2468 }
2469 }
2470
2471 #[test]
2472 fn test_array_without_items_gets_default() {
2473 let schema_json = json!({
2474 "type": "object",
2475 "properties": {
2476 "service_ids": {
2477 "type": "array",
2478 "description": "A list of service IDs"
2479 }
2480 }
2481 });
2482
2483 let schema: Schema = schema_json.try_into().unwrap();
2484 let props = schema.properties.unwrap();
2485 let service_ids = props.get("service_ids").unwrap();
2486 assert_eq!(service_ids.r#type, "array");
2487 let items = service_ids
2488 .items
2489 .as_ref()
2490 .expect("array schema missing items should get a default");
2491 assert_eq!(items.r#type, "string");
2492 }
2493
#[test]
fn test_txt_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    // A TXT document should become an inline text part that preserves its body.
    let document = UserContent::document(
        "Note: test.md\nPath: /test.md\nContent: Hello World!",
        Some(DocumentMediaType::TXT),
    );
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(document),
    };
    let content: Content = user_message.try_into().unwrap();

    match &content.parts[0] {
        Part {
            part: PartKind::Text(text),
            ..
        } => {
            assert!(text.contains("Note: test.md"));
            assert!(text.contains("Hello World!"));
        }
        unexpected => panic!(
            "Expected text part for TXT document, got: {:?}",
            unexpected
        ),
    }
}
2524
#[test]
fn test_tool_result_with_image_content() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    // Base64 PNG whose IHDR declares a 1x1 image.
    const PNG_PIXEL: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";

    // A tool result mixing JSON text and an inline image.
    let text_item = ToolResultContent::Text(message::Text {
        text: r#"{"status": "success"}"#.to_string(),
    });
    let image_item = ToolResultContent::Image(Image {
        data: DocumentSourceKind::Base64(PNG_PIXEL.to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    });
    let tool_result = ToolResult {
        id: "test_tool".to_string(),
        call_id: None,
        content: OneOrMany::many(vec![text_item, image_item])
            .expect("Should create OneOrMany with multiple items"),
    };

    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };
    let content: Content = msg.try_into().expect("Should convert to Gemini Content");

    // The whole tool result should collapse into a single function-response part.
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };

    assert_eq!(function_response.name, "test_tool");

    // The text content is expected under the `result` key of `response`.
    assert!(function_response.response.is_some());
    let response_json = function_response.response.as_ref().unwrap();
    assert!(response_json.get("result").is_some());

    // The image is expected as a separate part carrying inline PNG data.
    assert!(function_response.parts.is_some());
    let nested_parts = function_response.parts.as_ref().unwrap();
    assert_eq!(nested_parts.len(), 1);

    let first_nested = &nested_parts[0];
    assert!(first_nested.inline_data.is_some());
    let blob = first_nested.inline_data.as_ref().unwrap();
    assert_eq!(blob.mime_type, "image/png");
    assert!(!blob.data.is_empty());
}
2587
#[test]
fn test_markdown_document_conversion_to_text_part() {
    use crate::message::{DocumentMediaType, UserContent};

    const MARKDOWN_BODY: &str = "# Heading\n\n* List item";

    // An inline MARKDOWN document should become a text part with its body unchanged.
    let document = UserContent::document(MARKDOWN_BODY, Some(DocumentMediaType::MARKDOWN));
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(document),
    };
    let content: Content = user_message.try_into().unwrap();

    match &content.parts[0] {
        Part {
            part: PartKind::Text(text),
            ..
        } => assert_eq!(text, MARKDOWN_BODY),
        unexpected => panic!(
            "Expected text part for MARKDOWN document, got: {:?}",
            unexpected
        ),
    }
}
2617
#[test]
fn test_markdown_url_document_conversion_to_file_data_part() {
    use crate::message::{DocumentMediaType, DocumentSourceKind, UserContent};

    const FILE_URI: &str = "https://generativelanguage.googleapis.com/v1beta/files/test-markdown";

    // A MARKDOWN document referenced by URL should become a file-data part that keeps
    // the URI and carries the `text/markdown` MIME type.
    let document = UserContent::Document(message::Document {
        data: DocumentSourceKind::Url(FILE_URI.to_string()),
        media_type: Some(DocumentMediaType::MARKDOWN),
        additional_params: None,
    });
    let user_message = message::Message::User {
        content: crate::OneOrMany::one(document),
    };
    let content: Content = user_message.try_into().unwrap();

    match &content.parts[0] {
        Part {
            part: PartKind::FileData(file_data),
            ..
        } => {
            assert_eq!(file_data.file_uri, FILE_URI);
            assert_eq!(file_data.mime_type.as_deref(), Some("text/markdown"));
        }
        unexpected => panic!(
            "Expected file_data part for URL MARKDOWN document, got: {:?}",
            unexpected
        ),
    }
}
2654
#[test]
fn test_tool_result_with_url_image() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    // A tool result containing only an image referenced by URL.
    let screenshot = Image {
        data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
        media_type: Some(ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    };
    let tool_result = ToolResult {
        id: "screenshot_tool".to_string(),
        call_id: None,
        content: OneOrMany::one(ToolResultContent::Image(screenshot)),
    };

    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };
    let content: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };

    assert_eq!(function_response.name, "screenshot_tool");

    // A URL image is expected as a file-data reference, not inline bytes.
    assert!(function_response.parts.is_some());
    let nested_parts = function_response.parts.as_ref().unwrap();
    assert_eq!(nested_parts.len(), 1);

    let first_nested = &nested_parts[0];
    assert!(first_nested.file_data.is_some());
    let file_ref = first_nested.file_data.as_ref().unwrap();
    assert_eq!(file_ref.file_uri, "https://example.com/image.png");
    assert_eq!(file_ref.mime_type.as_ref().unwrap(), "image/png");
}
2704
#[test]
fn test_create_request_body_with_documents() {
    use crate::OneOrMany;
    use crate::completion::request::{CompletionRequest, Document};
    use crate::message::Message;

    let documents = vec![
        Document {
            id: "doc1".to_string(),
            text: "Note: first.md\nContent: First note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
        Document {
            id: "doc2".to_string(),
            text: "Note: second.md\nContent: Second note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
    ];

    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("What are my notes about?")),
        // `documents` is not used again, so move it rather than cloning.
        documents,
        tools: vec![],
        temperature: None,
        model: None,
        output_schema: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
    };

    let request = create_request_body(completion_request).unwrap();

    // Documents are expected as their own user turn ahead of the chat history.
    assert_eq!(
        request.contents.len(),
        2,
        "Expected 2 contents (documents + user message)"
    );

    // First turn: one text part per document, each carrying its note metadata.
    assert_eq!(request.contents[0].role, Some(Role::User));
    assert_eq!(
        request.contents[0].parts.len(),
        2,
        "Expected 2 document parts"
    );

    for part in &request.contents[0].parts {
        let Part {
            part: PartKind::Text(text),
            ..
        } = part
        else {
            panic!("Document parts should be text, not {:?}", part);
        };
        assert!(
            text.contains("Note:") && text.contains("Content:"),
            "Document should contain note metadata"
        );
    }

    // Second turn: the original user prompt, unchanged.
    assert_eq!(request.contents[1].role, Some(Role::User));
    let Part {
        part: PartKind::Text(text),
        ..
    } = &request.contents[1].parts[0]
    else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "What are my notes about?");
}
2783
#[test]
fn test_create_request_body_without_documents() {
    use crate::OneOrMany;
    use crate::completion::request::CompletionRequest;
    use crate::message::Message;

    // With no documents, no extra document turn should be added. Even though a
    // preamble is set, `contents` is expected to hold only the single user message.
    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("Hello")),
        documents: vec![],
        tools: vec![],
        temperature: None,
        model: None,
        output_schema: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
    };

    let request = create_request_body(completion_request).unwrap();

    assert_eq!(request.contents.len(), 1, "Expected only user message");
    assert_eq!(request.contents[0].role, Some(Role::User));

    let Part {
        part: PartKind::Text(text),
        ..
    } = &request.contents[0].parts[0]
    else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "Hello");
}
2820
#[test]
fn test_from_tool_output_parses_image_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    // A bare `{"type": "image", ...}` tool output should parse into exactly one
    // base64 image whose MIME type is taken from `mimeType`.
    let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
    let parsed = ToolResultContent::from_tool_output(image_json);

    assert_eq!(parsed.len(), 1);

    match parsed.first() {
        ToolResultContent::Image(image) => {
            assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
            if let DocumentSourceKind::Base64(payload) = &image.data {
                assert_eq!(payload, "base64data==");
            }
            assert_eq!(image.media_type, Some(crate::message::ImageMediaType::JPEG));
        }
        _ => panic!("Expected Image content"),
    }
}
2841
#[test]
fn test_from_tool_output_parses_hybrid_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    // A `{ "response": ..., "parts": [...] }` payload should yield the response
    // serialized as text first, followed by one image per entry in `parts`.
    let hybrid_json = r#"{
        "response": {"status": "ok", "count": 42},
        "parts": [
            {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
            {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
        ]
    }"#;

    let parsed = ToolResultContent::from_tool_output(hybrid_json);
    assert_eq!(parsed.len(), 3);

    let mut entries = parsed.iter();

    // 1. Structured response rendered as text.
    match entries.next() {
        Some(ToolResultContent::Text(text)) => {
            assert!(text.text.contains("status"));
            assert!(text.text.contains("ok"));
        }
        _ => panic!("Expected Text content first"),
    }

    // 2. Raw base64 payload stays base64.
    match entries.next() {
        Some(ToolResultContent::Image(image)) => {
            assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
        }
        _ => panic!("Expected Image content second"),
    }

    // 3. An http(s) payload is recognised as a URL.
    match entries.next() {
        Some(ToolResultContent::Image(image)) => {
            assert!(matches!(image.data, DocumentSourceKind::Url(_)));
        }
        _ => panic!("Expected Image content third"),
    }
}
2884
2885 #[tokio::test]
2889 #[ignore = "requires GEMINI_API_KEY environment variable"]
2890 async fn test_gemini_agent_with_image_tool_result_e2e() {
2891 use crate::completion::{Prompt, ToolDefinition};
2892 use crate::prelude::*;
2893 use crate::providers::gemini;
2894 use crate::tool::Tool;
2895 use serde::{Deserialize, Serialize};
2896
2897 #[derive(Debug, Serialize, Deserialize)]
2899 struct ImageGeneratorTool;
2900
2901 #[derive(Debug, thiserror::Error)]
2902 #[error("Image generation error")]
2903 struct ImageToolError;
2904
2905 impl Tool for ImageGeneratorTool {
2906 const NAME: &'static str = "generate_test_image";
2907 type Error = ImageToolError;
2908 type Args = serde_json::Value;
2909 type Output = String;
2911
2912 async fn definition(&self, _prompt: String) -> ToolDefinition {
2913 ToolDefinition {
2914 name: "generate_test_image".to_string(),
2915 description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
2916 parameters: json!({
2917 "type": "object",
2918 "properties": {},
2919 "required": []
2920 }),
2921 }
2922 }
2923
2924 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
2925 Ok(json!({
2928 "type": "image",
2929 "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
2930 "mimeType": "image/png"
2931 }).to_string())
2932 }
2933 }
2934
2935 let client = gemini::Client::from_env();
2936
2937 let agent = client
2938 .agent("gemini-3-flash-preview")
2939 .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
2940 .tool(ImageGeneratorTool)
2941 .build();
2942
2943 let response = agent
2945 .prompt("Please generate a test image and tell me what color the pixel is.")
2946 .await;
2947
2948 assert!(
2951 response.is_ok(),
2952 "Gemini should successfully process tool result with image: {:?}",
2953 response.err()
2954 );
2955
2956 let response_text = response.unwrap();
2957 println!("Response: {response_text}");
2958 assert!(!response_text.is_empty(), "Response should not be empty");
2960 }
2961}