/// Model identifier string `gemini-3.1-flash-lite-preview`.
pub const GEMINI_3_1_FLASH_LITE_PREVIEW: &str = "gemini-3.1-flash-lite-preview";
/// Model identifier string `gemini-3-flash-preview`.
pub const GEMINI_3_FLASH_PREVIEW: &str = "gemini-3-flash-preview";
/// Model identifier string `gemini-2.5-pro-preview-06-05`.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// Model identifier string `gemini-2.5-pro-preview-05-06`.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// Model identifier string `gemini-2.5-pro-preview-03-25`.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// Model identifier string `gemini-2.5-flash-preview-04-17`.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// Model identifier string `gemini-2.5-pro-exp-03-25`.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// Model identifier string `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier string `gemini-2.0-flash-lite`.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// Model identifier string `gemini-2.0-flash`.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
25
26use self::gemini_api_types::Schema;
27use crate::http_client::HttpClientExt;
28use crate::message::{self, MimeType, Reasoning};
29
30use crate::providers::gemini::completion::gemini_api_types::{
31 AdditionalParameters, FunctionCallingMode, ToolConfig,
32};
33use crate::providers::gemini::streaming::StreamingCompletionResponse;
34use crate::telemetry::SpanCombinator;
35use crate::{
36 OneOrMany,
37 completion::{self, CompletionError, CompletionRequest},
38};
39use gemini_api_types::{
40 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
41 GenerationConfig, Part, PartKind, Role, Tool,
42};
43use serde_json::{Map, Value};
44use std::convert::TryFrom;
45use tracing::{Level, enabled, info_span};
46use tracing_futures::Instrument;
47
48use super::Client;
49
50#[derive(Clone, Debug)]
55pub struct CompletionModel<T = reqwest::Client> {
56 pub(crate) client: Client<T>,
57 pub model: String,
58}
59
60impl<T> CompletionModel<T> {
61 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
62 Self {
63 client,
64 model: model.into(),
65 }
66 }
67
68 pub fn with_model(client: Client<T>, model: &str) -> Self {
69 Self {
70 client,
71 model: model.into(),
72 }
73 }
74}
75
76impl<T> completion::CompletionModel for CompletionModel<T>
77where
78 T: HttpClientExt + Clone + 'static,
79{
80 type Response = GenerateContentResponse;
81 type StreamingResponse = StreamingCompletionResponse;
82 type Client = super::Client<T>;
83
84 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
85 Self::new(client.clone(), model)
86 }
87
88 async fn completion(
89 &self,
90 completion_request: CompletionRequest,
91 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
92 let request_model = resolve_request_model(&self.model, &completion_request);
93 let span = if tracing::Span::current().is_disabled() {
94 info_span!(
95 target: "rig::completions",
96 "generate_content",
97 gen_ai.operation.name = "generate_content",
98 gen_ai.provider.name = "gcp.gemini",
99 gen_ai.request.model = &request_model,
100 gen_ai.system_instructions = &completion_request.preamble,
101 gen_ai.response.id = tracing::field::Empty,
102 gen_ai.response.model = tracing::field::Empty,
103 gen_ai.usage.output_tokens = tracing::field::Empty,
104 gen_ai.usage.input_tokens = tracing::field::Empty,
105 gen_ai.usage.cached_tokens = tracing::field::Empty,
106 )
107 } else {
108 tracing::Span::current()
109 };
110
111 let request = create_request_body(completion_request)?;
112
113 if enabled!(Level::TRACE) {
114 tracing::trace!(
115 target: "rig::completions",
116 "Gemini completion request: {}",
117 serde_json::to_string_pretty(&request)?
118 );
119 }
120
121 let body = serde_json::to_vec(&request)?;
122
123 let path = completion_endpoint(&request_model);
124
125 let request = self
126 .client
127 .post(path.as_str())?
128 .body(body)
129 .map_err(|e| CompletionError::HttpError(e.into()))?;
130
131 async move {
132 let response = self.client.send::<_, Vec<u8>>(request).await?;
133
134 if response.status().is_success() {
135 let response_body = response
136 .into_body()
137 .await
138 .map_err(CompletionError::HttpError)?;
139
140 let response_text = String::from_utf8_lossy(&response_body).to_string();
141
142 let response: GenerateContentResponse = serde_json::from_slice(&response_body)
143 .map_err(|err| {
144 tracing::error!(
145 error = %err,
146 body = %response_text,
147 "Failed to deserialize Gemini completion response"
148 );
149 CompletionError::JsonError(err)
150 })?;
151
152 let span = tracing::Span::current();
153 span.record_response_metadata(&response);
154 span.record_token_usage(&response.usage_metadata);
155
156 if enabled!(Level::TRACE) {
157 tracing::trace!(
158 target: "rig::completions",
159 "Gemini completion response: {}",
160 serde_json::to_string_pretty(&response)?
161 );
162 }
163
164 response.try_into()
165 } else {
166 let text = String::from_utf8_lossy(
167 &response
168 .into_body()
169 .await
170 .map_err(CompletionError::HttpError)?,
171 )
172 .into();
173
174 Err(CompletionError::ProviderError(text))
175 }
176 }
177 .instrument(span)
178 .await
179 }
180
181 async fn stream(
182 &self,
183 request: CompletionRequest,
184 ) -> Result<
185 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
186 CompletionError,
187 > {
188 CompletionModel::stream(self, request).await
189 }
190}
191
192pub(crate) fn create_request_body(
193 completion_request: CompletionRequest,
194) -> Result<GenerateContentRequest, CompletionError> {
195 let documents_message = completion_request.normalized_documents();
196
197 let CompletionRequest {
198 model: _,
199 preamble,
200 chat_history,
201 documents: _,
202 tools: function_tools,
203 temperature,
204 max_tokens,
205 tool_choice,
206 mut additional_params,
207 output_schema,
208 } = completion_request;
209
210 let mut full_history = Vec::new();
211 if let Some(msg) = documents_message {
212 full_history.push(msg);
213 }
214 full_history.extend(chat_history);
215 let (history_system, full_history) = split_system_messages_from_history(full_history);
216
217 let mut additional_params_payload = additional_params
218 .take()
219 .unwrap_or_else(|| Value::Object(Map::new()));
220 let mut additional_tools =
221 extract_tools_from_additional_params(&mut additional_params_payload)?;
222
223 let AdditionalParameters {
224 mut generation_config,
225 additional_params,
226 } = serde_json::from_value::<AdditionalParameters>(additional_params_payload)?;
227
228 if let Some(schema) = output_schema {
230 let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
231 cfg.response_mime_type = Some("application/json".to_string());
232 cfg.response_json_schema = Some(schema.to_value());
233 }
234
235 generation_config = generation_config.map(|mut cfg| {
236 if let Some(temp) = temperature {
237 cfg.temperature = Some(temp);
238 };
239
240 if let Some(max_tokens) = max_tokens {
241 cfg.max_output_tokens = Some(max_tokens);
242 };
243
244 cfg
245 });
246
247 let mut system_parts: Vec<Part> = Vec::new();
248 if let Some(preamble) = preamble.filter(|preamble| !preamble.is_empty()) {
249 system_parts.push(preamble.into());
250 }
251 for content in history_system {
252 if !content.is_empty() {
253 system_parts.push(content.into());
254 }
255 }
256 let system_instruction = if system_parts.is_empty() {
257 None
258 } else {
259 Some(Content {
260 parts: system_parts,
261 role: Some(Role::Model),
262 })
263 };
264
265 let mut tools = if function_tools.is_empty() {
266 Vec::new()
267 } else {
268 vec![serde_json::to_value(Tool::try_from(function_tools)?)?]
269 };
270 tools.append(&mut additional_tools);
271 let tools = if tools.is_empty() { None } else { Some(tools) };
272
273 let tool_config = if let Some(cfg) = tool_choice {
274 Some(ToolConfig {
275 function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
276 })
277 } else {
278 None
279 };
280
281 let request = GenerateContentRequest {
282 contents: full_history
283 .into_iter()
284 .map(|msg| {
285 msg.try_into()
286 .map_err(|e| CompletionError::RequestError(Box::new(e)))
287 })
288 .collect::<Result<Vec<_>, _>>()?,
289 generation_config,
290 safety_settings: None,
291 tools,
292 tool_config,
293 system_instruction,
294 additional_params,
295 };
296
297 Ok(request)
298}
299
300fn split_system_messages_from_history(
301 history: Vec<completion::Message>,
302) -> (Vec<String>, Vec<completion::Message>) {
303 let mut system = Vec::new();
304 let mut remaining = Vec::new();
305
306 for message in history {
307 match message {
308 completion::Message::System { content } => system.push(content),
309 other => remaining.push(other),
310 }
311 }
312
313 (system, remaining)
314}
315
316fn extract_tools_from_additional_params(
317 additional_params: &mut Value,
318) -> Result<Vec<Value>, CompletionError> {
319 if let Some(map) = additional_params.as_object_mut()
320 && let Some(raw_tools) = map.remove("tools")
321 {
322 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
323 CompletionError::RequestError(
324 format!("Invalid Gemini `additional_params.tools` payload: {err}").into(),
325 )
326 });
327 }
328
329 Ok(Vec::new())
330}
331
332pub(crate) fn resolve_request_model(
333 default_model: &str,
334 completion_request: &CompletionRequest,
335) -> String {
336 completion_request
337 .model
338 .clone()
339 .unwrap_or_else(|| default_model.to_string())
340}
341
/// Builds the request path of the non-streaming `generateContent` call for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":generateContent"].concat()
}
345
/// Builds the request path of the `streamGenerateContent` call for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    ["/v1beta/models/", model, ":streamGenerateContent"].concat()
}
349
350impl TryFrom<completion::ToolDefinition> for Tool {
351 type Error = CompletionError;
352
353 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
354 let parameters: Option<Schema> =
355 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
356 None
357 } else {
358 Some(tool.parameters.try_into()?)
359 };
360
361 Ok(Self {
362 function_declarations: vec![FunctionDeclaration {
363 name: tool.name,
364 description: tool.description,
365 parameters,
366 }],
367 code_execution: None,
368 })
369 }
370}
371
372impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
373 type Error = CompletionError;
374
375 fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
376 let mut function_declarations = Vec::new();
377
378 for tool in tools {
379 let parameters =
380 if tool.parameters == serde_json::json!({"type": "object", "properties": {}}) {
381 None
382 } else {
383 match tool.parameters.try_into() {
384 Ok(schema) => Some(schema),
385 Err(e) => {
386 let emsg = format!(
387 "Tool '{}' could not be converted to a schema: {:?}",
388 tool.name, e,
389 );
390 return Err(CompletionError::ProviderError(emsg));
391 }
392 }
393 };
394
395 function_declarations.push(FunctionDeclaration {
396 name: tool.name,
397 description: tool.description,
398 parameters,
399 });
400 }
401
402 Ok(Self {
403 function_declarations,
404 code_execution: None,
405 })
406 }
407}
408
409impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
410 type Error = CompletionError;
411
412 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
413 let candidate = response.candidates.first().ok_or_else(|| {
414 CompletionError::ResponseError("No response candidates in response".into())
415 })?;
416
417 let content = candidate
418 .content
419 .as_ref()
420 .ok_or_else(|| {
421 let reason = candidate
422 .finish_reason
423 .as_ref()
424 .map(|r| format!("finish_reason={r:?}"))
425 .unwrap_or_else(|| "finish_reason=<unknown>".to_string());
426 let message = candidate
427 .finish_message
428 .as_deref()
429 .unwrap_or("no finish message provided");
430 CompletionError::ResponseError(format!(
431 "Gemini candidate missing content ({reason}, finish_message={message})"
432 ))
433 })?
434 .parts
435 .iter()
436 .map(
437 |Part {
438 thought,
439 thought_signature,
440 part,
441 ..
442 }| {
443 Ok(match part {
444 PartKind::Text(text) => {
445 if let Some(thought) = thought
446 && *thought
447 {
448 completion::AssistantContent::Reasoning(
449 Reasoning::new_with_signature(text, thought_signature.clone()),
450 )
451 } else {
452 completion::AssistantContent::text(text)
453 }
454 }
455 PartKind::InlineData(inline_data) => {
456 let mime_type =
457 message::MediaType::from_mime_type(&inline_data.mime_type);
458
459 match mime_type {
460 Some(message::MediaType::Image(media_type)) => {
461 message::AssistantContent::image_base64(
462 &inline_data.data,
463 Some(media_type),
464 Some(message::ImageDetail::default()),
465 )
466 }
467 _ => {
468 return Err(CompletionError::ResponseError(format!(
469 "Unsupported media type {mime_type:?}"
470 )));
471 }
472 }
473 }
474 PartKind::FunctionCall(function_call) => {
475 completion::AssistantContent::ToolCall(
476 message::ToolCall::new(
477 function_call.name.clone(),
478 message::ToolFunction::new(
479 function_call.name.clone(),
480 function_call.args.clone(),
481 ),
482 )
483 .with_signature(thought_signature.clone()),
484 )
485 }
486 _ => {
487 return Err(CompletionError::ResponseError(
488 "Response did not contain a message or tool call".into(),
489 ));
490 }
491 })
492 },
493 )
494 .collect::<Result<Vec<_>, _>>()?;
495
496 let choice = OneOrMany::many(content).map_err(|_| {
497 CompletionError::ResponseError(
498 "Response contained no message or tool call (empty)".to_owned(),
499 )
500 })?;
501
502 let usage = response
503 .usage_metadata
504 .as_ref()
505 .map(|usage| completion::Usage {
506 input_tokens: usage.prompt_token_count as u64,
507 output_tokens: usage.candidates_token_count.unwrap_or(0) as u64,
508 total_tokens: usage.total_token_count as u64,
509 cached_input_tokens: 0,
510 cache_creation_input_tokens: 0,
511 })
512 .unwrap_or_default();
513
514 Ok(completion::CompletionResponse {
515 choice,
516 usage,
517 raw_response: response,
518 message_id: None,
519 })
520 }
521}
522
523pub mod gemini_api_types {
524 use crate::telemetry::ProviderResponseExt;
525 use std::{collections::HashMap, convert::Infallible, str::FromStr};
526
527 use serde::{Deserialize, Serialize};
531 use serde_json::{Value, json};
532
533 use crate::completion::GetTokenUsage;
534 use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
535 use crate::{
536 completion::CompletionError,
537 message::{self},
538 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
539 };
540
541 #[derive(Debug, Deserialize, Serialize, Default)]
542 #[serde(rename_all = "camelCase")]
543 pub struct AdditionalParameters {
544 pub generation_config: Option<GenerationConfig>,
546 #[serde(flatten, skip_serializing_if = "Option::is_none")]
548 pub additional_params: Option<serde_json::Value>,
549 }
550
551 impl AdditionalParameters {
552 pub fn with_config(mut self, cfg: GenerationConfig) -> Self {
553 self.generation_config = Some(cfg);
554 self
555 }
556
557 pub fn with_params(mut self, params: serde_json::Value) -> Self {
558 self.additional_params = Some(params);
559 self
560 }
561 }
562
563 #[derive(Debug, Deserialize, Serialize)]
571 #[serde(rename_all = "camelCase")]
572 pub struct GenerateContentResponse {
573 pub response_id: String,
574 pub candidates: Vec<ContentCandidate>,
576 pub prompt_feedback: Option<PromptFeedback>,
578 pub usage_metadata: Option<UsageMetadata>,
580 pub model_version: Option<String>,
581 }
582
583 impl ProviderResponseExt for GenerateContentResponse {
584 type OutputMessage = ContentCandidate;
585 type Usage = UsageMetadata;
586
587 fn get_response_id(&self) -> Option<String> {
588 Some(self.response_id.clone())
589 }
590
591 fn get_response_model_name(&self) -> Option<String> {
592 None
593 }
594
595 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
596 self.candidates.clone()
597 }
598
599 fn get_text_response(&self) -> Option<String> {
600 let str = self
601 .candidates
602 .iter()
603 .filter_map(|x| {
604 let content = x.content.as_ref()?;
605 if content.role.as_ref().is_none_or(|y| y != &Role::Model) {
606 return None;
607 }
608
609 let res = content
610 .parts
611 .iter()
612 .filter_map(|part| {
613 if let PartKind::Text(ref str) = part.part {
614 Some(str.to_owned())
615 } else {
616 None
617 }
618 })
619 .collect::<Vec<String>>()
620 .join("\n");
621
622 Some(res)
623 })
624 .collect::<Vec<String>>()
625 .join("\n");
626
627 if str.is_empty() { None } else { Some(str) }
628 }
629
630 fn get_usage(&self) -> Option<Self::Usage> {
631 self.usage_metadata.clone()
632 }
633 }
634
/// One candidate output inside a [`GenerateContentResponse`] (camelCase on the wire).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContentCandidate {
    /// Generated content; may be absent, in which case `finish_reason` and
    /// `finish_message` are used elsewhere in this file to build the error text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,
    /// Reason the candidate finished, if reported.
    pub finish_reason: Option<FinishReason>,
    /// Safety ratings attached to the candidate, if any.
    pub safety_ratings: Option<Vec<SafetyRating>>,
    /// Citation metadata attached to the candidate, if any.
    pub citation_metadata: Option<CitationMetadata>,
    /// Token count for the candidate, if reported.
    pub token_count: Option<i32>,
    /// Average log-probability value, if reported.
    pub avg_logprobs: Option<f64>,
    /// Detailed log-probability result, if reported.
    pub logprobs_result: Option<LogprobsResult>,
    /// Index of the candidate, if reported.
    pub index: Option<i32>,
    /// Free-text message accompanying the finish reason, if any.
    pub finish_message: Option<String>,
}
663
664 #[derive(Clone, Debug, Deserialize, Serialize)]
665 pub struct Content {
666 #[serde(default)]
668 pub parts: Vec<Part>,
669 pub role: Option<Role>,
672 }
673
674 impl TryFrom<message::Message> for Content {
675 type Error = message::MessageError;
676
677 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
678 Ok(match msg {
679 message::Message::System { content } => Content {
680 parts: vec![content.into()],
681 role: Some(Role::User),
682 },
683 message::Message::User { content } => Content {
684 parts: content
685 .into_iter()
686 .map(|c| c.try_into())
687 .collect::<Result<Vec<_>, _>>()?,
688 role: Some(Role::User),
689 },
690 message::Message::Assistant { content, .. } => Content {
691 role: Some(Role::Model),
692 parts: content
693 .into_iter()
694 .map(|content| content.try_into())
695 .collect::<Result<Vec<_>, _>>()?,
696 },
697 })
698 }
699 }
700
/// Author of a [`Content`] message; serialized in lowercase (`"user"` / `"model"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Content supplied by the caller.
    User,
    /// Content produced by the model.
    Model,
}
707
/// One piece of a [`Content`] message (camelCase on the wire).
///
/// The payload (`part`) and any extra keys (`additional_params`) are flattened
/// into the same JSON object as the `thought` fields.
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    /// Whether this part is a model "thought"; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought: Option<bool>,
    /// Signature attached to the thought, if any; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
    /// The actual payload of the part.
    #[serde(flatten)]
    pub part: PartKind,
    /// Extra raw JSON merged into the part; omitted from output when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<Value>,
}
722
/// Payload of a [`Part`]. Variant names are serialized in camelCase
/// (for example `text`, `inlineData`, `functionCall`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum PartKind {
    /// Plain text.
    Text(String),
    /// Media embedded directly in the message.
    InlineData(Blob),
    /// A function call.
    FunctionCall(FunctionCall),
    /// The result of a function call.
    FunctionResponse(FunctionResponse),
    /// Media referenced by URI.
    FileData(FileData),
    /// Executable code.
    ExecutableCode(ExecutableCode),
    /// Result of a code execution.
    CodeExecutionResult(CodeExecutionResult),
}
737
738 impl Default for PartKind {
741 fn default() -> Self {
742 Self::Text(String::new())
743 }
744 }
745
746 impl From<String> for Part {
747 fn from(text: String) -> Self {
748 Self {
749 thought: Some(false),
750 thought_signature: None,
751 part: PartKind::Text(text),
752 additional_params: None,
753 }
754 }
755 }
756
757 impl From<&str> for Part {
758 fn from(text: &str) -> Self {
759 Self::from(text.to_string())
760 }
761 }
762
763 impl FromStr for Part {
764 type Err = Infallible;
765
766 fn from_str(s: &str) -> Result<Self, Self::Err> {
767 Ok(s.into())
768 }
769 }
770
771 impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
772 type Error = message::MessageError;
773 fn try_from(
774 (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
775 ) -> Result<Self, Self::Error> {
776 let mime_type = mime_type.to_mime_type().to_string();
777 let part = match doc_src {
778 DocumentSourceKind::Url(url) => PartKind::FileData(FileData {
779 mime_type: Some(mime_type),
780 file_uri: url,
781 }),
782 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
783 PartKind::InlineData(Blob { mime_type, data })
784 }
785 DocumentSourceKind::Raw(_) => {
786 return Err(message::MessageError::ConversionError(
787 "Raw files not supported, encode as base64 first".into(),
788 ));
789 }
790 DocumentSourceKind::Unknown => {
791 return Err(message::MessageError::ConversionError(
792 "Can't convert an unknown document source".to_string(),
793 ));
794 }
795 };
796
797 Ok(part)
798 }
799 }
800
801 impl TryFrom<message::UserContent> for Part {
802 type Error = message::MessageError;
803
804 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
805 match content {
806 message::UserContent::Text(message::Text { text }) => Ok(Part {
807 thought: Some(false),
808 thought_signature: None,
809 part: PartKind::Text(text),
810 additional_params: None,
811 }),
812 message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
813 let mut response_json: Option<serde_json::Value> = None;
814 let mut parts: Vec<FunctionResponsePart> = Vec::new();
815
816 for item in content.iter() {
817 match item {
818 message::ToolResultContent::Text(text) => {
819 let result: serde_json::Value =
820 serde_json::from_str(&text.text).unwrap_or_else(|error| {
821 tracing::trace!(
822 ?error,
823 "Tool result is not a valid JSON, treat it as normal string"
824 );
825 json!(&text.text)
826 });
827
828 response_json = Some(match response_json {
829 Some(mut existing) => {
830 if let serde_json::Value::Object(ref mut map) = existing {
831 map.insert("text".to_string(), result);
832 }
833 existing
834 }
835 None => json!({ "result": result }),
836 });
837 }
838 message::ToolResultContent::Image(image) => {
839 let part = match &image.data {
840 DocumentSourceKind::Base64(b64) => {
841 let mime_type = image
842 .media_type
843 .as_ref()
844 .ok_or(message::MessageError::ConversionError(
845 "Image media type is required for Gemini tool results".to_string(),
846 ))?
847 .to_mime_type();
848
849 FunctionResponsePart {
850 inline_data: Some(FunctionResponseInlineData {
851 mime_type: mime_type.to_string(),
852 data: b64.clone(),
853 display_name: None,
854 }),
855 file_data: None,
856 }
857 }
858 DocumentSourceKind::Url(url) => {
859 let mime_type = image
860 .media_type
861 .as_ref()
862 .map(|mt| mt.to_mime_type().to_string());
863
864 FunctionResponsePart {
865 inline_data: None,
866 file_data: Some(FileData {
867 mime_type,
868 file_uri: url.clone(),
869 }),
870 }
871 }
872 _ => {
873 return Err(message::MessageError::ConversionError(
874 "Unsupported image source kind for tool results"
875 .to_string(),
876 ));
877 }
878 };
879 parts.push(part);
880 }
881 }
882 }
883
884 Ok(Part {
885 thought: Some(false),
886 thought_signature: None,
887 part: PartKind::FunctionResponse(FunctionResponse {
888 name: id,
889 response: response_json,
890 parts: if parts.is_empty() { None } else { Some(parts) },
891 }),
892 additional_params: None,
893 })
894 }
895 message::UserContent::Image(message::Image {
896 data, media_type, ..
897 }) => match media_type {
898 Some(media_type) => match media_type {
899 message::ImageMediaType::JPEG
900 | message::ImageMediaType::PNG
901 | message::ImageMediaType::WEBP
902 | message::ImageMediaType::HEIC
903 | message::ImageMediaType::HEIF => {
904 let part = PartKind::try_from((media_type, data))?;
905 Ok(Part {
906 thought: Some(false),
907 thought_signature: None,
908 part,
909 additional_params: None,
910 })
911 }
912 _ => Err(message::MessageError::ConversionError(format!(
913 "Unsupported image media type {media_type:?}"
914 ))),
915 },
916 None => Err(message::MessageError::ConversionError(
917 "Media type for image is required for Gemini".to_string(),
918 )),
919 },
920 message::UserContent::Document(message::Document {
921 data, media_type, ..
922 }) => {
923 let Some(media_type) = media_type else {
924 return Err(MessageError::ConversionError(
925 "A mime type is required for document inputs to Gemini".to_string(),
926 ));
927 };
928
929 if matches!(
932 media_type,
933 message::DocumentMediaType::TXT
934 | message::DocumentMediaType::RTF
935 | message::DocumentMediaType::HTML
936 | message::DocumentMediaType::CSS
937 | message::DocumentMediaType::MARKDOWN
938 | message::DocumentMediaType::CSV
939 | message::DocumentMediaType::XML
940 | message::DocumentMediaType::Javascript
941 | message::DocumentMediaType::Python
942 ) {
943 use base64::Engine;
944 let part = match data {
945 DocumentSourceKind::String(text) => PartKind::Text(text),
946 DocumentSourceKind::Base64(data) => {
947 let text = String::from_utf8(
949 base64::engine::general_purpose::STANDARD
950 .decode(&data)
951 .map_err(|e| {
952 MessageError::ConversionError(format!(
953 "Failed to decode base64: {e}"
954 ))
955 })?,
956 )
957 .map_err(|e| {
958 MessageError::ConversionError(format!(
959 "Invalid UTF-8 in document: {e}"
960 ))
961 })?;
962 PartKind::Text(text)
963 }
964 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
965 mime_type: Some(media_type.to_mime_type().to_string()),
966 file_uri,
967 }),
968 DocumentSourceKind::Raw(_) => {
969 return Err(MessageError::ConversionError(
970 "Raw files not supported, encode as base64 first".to_string(),
971 ));
972 }
973 DocumentSourceKind::Unknown => {
974 return Err(MessageError::ConversionError(
975 "Document has no body".to_string(),
976 ));
977 }
978 };
979
980 Ok(Part {
981 thought: Some(false),
982 part,
983 ..Default::default()
984 })
985 } else if !media_type.is_code() {
986 let mime_type = media_type.to_mime_type().to_string();
987
988 let part = match data {
989 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
990 mime_type: Some(mime_type),
991 file_uri,
992 }),
993 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
994 PartKind::InlineData(Blob { mime_type, data })
995 }
996 DocumentSourceKind::Raw(_) => {
997 return Err(message::MessageError::ConversionError(
998 "Raw files not supported, encode as base64 first".into(),
999 ));
1000 }
1001 _ => {
1002 return Err(message::MessageError::ConversionError(
1003 "Document has no body".to_string(),
1004 ));
1005 }
1006 };
1007
1008 Ok(Part {
1009 thought: Some(false),
1010 part,
1011 ..Default::default()
1012 })
1013 } else {
1014 Err(message::MessageError::ConversionError(format!(
1015 "Unsupported document media type {media_type:?}"
1016 )))
1017 }
1018 }
1019
1020 message::UserContent::Audio(message::Audio {
1021 data, media_type, ..
1022 }) => {
1023 let Some(media_type) = media_type else {
1024 return Err(MessageError::ConversionError(
1025 "A mime type is required for audio inputs to Gemini".to_string(),
1026 ));
1027 };
1028
1029 let mime_type = media_type.to_mime_type().to_string();
1030
1031 let part = match data {
1032 DocumentSourceKind::Base64(data) => {
1033 PartKind::InlineData(Blob { data, mime_type })
1034 }
1035
1036 DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
1037 mime_type: Some(mime_type),
1038 file_uri,
1039 }),
1040 DocumentSourceKind::String(_) => {
1041 return Err(message::MessageError::ConversionError(
1042 "Strings cannot be used as audio files!".into(),
1043 ));
1044 }
1045 DocumentSourceKind::Raw(_) => {
1046 return Err(message::MessageError::ConversionError(
1047 "Raw files not supported, encode as base64 first".into(),
1048 ));
1049 }
1050 DocumentSourceKind::Unknown => {
1051 return Err(message::MessageError::ConversionError(
1052 "Content has no body".to_string(),
1053 ));
1054 }
1055 };
1056
1057 Ok(Part {
1058 thought: Some(false),
1059 part,
1060 ..Default::default()
1061 })
1062 }
1063 message::UserContent::Video(message::Video {
1064 data,
1065 media_type,
1066 additional_params,
1067 ..
1068 }) => {
1069 let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
1070
1071 let part = match data {
1072 DocumentSourceKind::Url(file_uri) => {
1073 if file_uri.starts_with("https://www.youtube.com") {
1074 PartKind::FileData(FileData {
1075 mime_type,
1076 file_uri,
1077 })
1078 } else {
1079 if mime_type.is_none() {
1080 return Err(MessageError::ConversionError(
1081 "A mime type is required for non-Youtube video file inputs to Gemini"
1082 .to_string(),
1083 ));
1084 }
1085
1086 PartKind::FileData(FileData {
1087 mime_type,
1088 file_uri,
1089 })
1090 }
1091 }
1092 DocumentSourceKind::Base64(data) => {
1093 let Some(mime_type) = mime_type else {
1094 return Err(MessageError::ConversionError(
1095 "A media type is expected for base64 encoded strings"
1096 .to_string(),
1097 ));
1098 };
1099 PartKind::InlineData(Blob { mime_type, data })
1100 }
1101 DocumentSourceKind::String(_) => {
1102 return Err(message::MessageError::ConversionError(
1103 "Strings cannot be used as audio files!".into(),
1104 ));
1105 }
1106 DocumentSourceKind::Raw(_) => {
1107 return Err(message::MessageError::ConversionError(
1108 "Raw file data not supported, encode as base64 first".into(),
1109 ));
1110 }
1111 DocumentSourceKind::Unknown => {
1112 return Err(message::MessageError::ConversionError(
1113 "Media type for video is required for Gemini".to_string(),
1114 ));
1115 }
1116 };
1117
1118 Ok(Part {
1119 thought: Some(false),
1120 thought_signature: None,
1121 part,
1122 additional_params,
1123 })
1124 }
1125 }
1126 }
1127 }
1128
1129 impl TryFrom<message::AssistantContent> for Part {
1130 type Error = message::MessageError;
1131
1132 fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
1133 match content {
1134 message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
1135 message::AssistantContent::Image(message::Image {
1136 data, media_type, ..
1137 }) => match media_type {
1138 Some(media_type) => match media_type {
1139 message::ImageMediaType::JPEG
1140 | message::ImageMediaType::PNG
1141 | message::ImageMediaType::WEBP
1142 | message::ImageMediaType::HEIC
1143 | message::ImageMediaType::HEIF => {
1144 let part = PartKind::try_from((media_type, data))?;
1145 Ok(Part {
1146 thought: Some(false),
1147 thought_signature: None,
1148 part,
1149 additional_params: None,
1150 })
1151 }
1152 _ => Err(message::MessageError::ConversionError(format!(
1153 "Unsupported image media type {media_type:?}"
1154 ))),
1155 },
1156 None => Err(message::MessageError::ConversionError(
1157 "Media type for image is required for Gemini".to_string(),
1158 )),
1159 },
1160 message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
1161 message::AssistantContent::Reasoning(reasoning) => Ok(Part {
1162 thought: Some(true),
1163 thought_signature: reasoning.first_signature().map(str::to_owned),
1164 part: PartKind::Text(reasoning.display_text()),
1165 additional_params: None,
1166 }),
1167 }
1168 }
1169 }
1170
1171 impl From<message::ToolCall> for Part {
1172 fn from(tool_call: message::ToolCall) -> Self {
1173 Self {
1174 thought: Some(false),
1175 thought_signature: tool_call.signature,
1176 part: PartKind::FunctionCall(FunctionCall {
1177 name: tool_call.function.name,
1178 args: tool_call.function.arguments,
1179 }),
1180 additional_params: None,
1181 }
1182 }
1183 }
1184
/// Media embedded directly in a part (camelCase on the wire).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
    /// MIME type of the payload.
    pub mime_type: String,
    /// Payload as a string. The conversions in this file fill it from base64
    /// (and, for images and documents, plain string) document sources.
    pub data: String,
}
1196
1197 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
1200 pub struct FunctionCall {
1201 pub name: String,
1204 pub args: serde_json::Value,
1206 }
1207
1208 impl From<message::ToolCall> for FunctionCall {
1209 fn from(tool_call: message::ToolCall) -> Self {
1210 Self {
1211 name: tool_call.function.name,
1212 args: tool_call.function.arguments,
1213 }
1214 }
1215 }
1216
/// Result of a function call sent back to the model.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionResponse {
    /// Name identifying the call this responds to. The tool-result conversion
    /// in this file fills it with the tool result's id.
    pub name: String,
    /// JSON result; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<serde_json::Value>,
    /// Media parts attached to the result; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parts: Option<Vec<FunctionResponsePart>>,
}
1232
/// Media attached to a [`FunctionResponse`] (camelCase on the wire).
///
/// The conversions in this file set exactly one of the two fields.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    /// Media embedded directly; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inline_data: Option<FunctionResponseInlineData>,
    /// Media referenced by URI; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<FileData>,
}
1244
    /// Inline data attached to a function response part.
    ///
    /// Serialized in camelCase (`mimeType`, `data`, `displayName`).
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionResponseInlineData {
        /// MIME type describing `data`.
        pub mime_type: String,
        /// Payload carried as a string; presumably base64-encoded bytes — TODO confirm.
        pub data: String,
        /// Optional display name; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub display_name: Option<String>,
    }
1257
    /// Reference to data stored at a URI.
    ///
    /// Serialized in camelCase (`mimeType`, `fileUri`). `mime_type` has no
    /// skip attribute, so a `None` value is serialized as `null`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        /// MIME type of the referenced file, when known.
        pub mime_type: Option<String>,
        /// URI of the referenced file.
        pub file_uri: String,
    }
1267
    /// Pairs a harm category with the probability assigned to it.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    pub struct SafetyRating {
        /// Category being rated.
        pub category: HarmCategory,
        /// Probability level reported for that category.
        pub probability: HarmProbability,
    }
1273
    /// Probability level attached to a [`SafetyRating`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_PROBABILITY_UNSPECIFIED`, `NEGLIGIBLE`.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmProbability {
        HarmProbabilityUnspecified,
        Negligible,
        Low,
        Medium,
        High,
    }
1283
    /// Harm categories used by safety ratings and safety settings.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `HARM_CATEGORY_HATE_SPEECH`. A value not listed here fails to deserialize.
    #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmCategory {
        HarmCategoryUnspecified,
        HarmCategoryDerogatory,
        HarmCategoryToxicity,
        HarmCategoryViolence,
        HarmCategorySexually,
        HarmCategoryMedical,
        HarmCategoryDangerous,
        HarmCategoryHarassment,
        HarmCategoryHateSpeech,
        HarmCategorySexuallyExplicit,
        HarmCategoryDangerousContent,
        HarmCategoryCivicIntegrity,
    }
1300
    /// Token accounting for a request, (de)serialized with camelCase keys.
    ///
    /// `Default` yields zero counts and `None` for every optional field.
    #[derive(Debug, Deserialize, Clone, Default, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct UsageMetadata {
        /// Prompt token count; falls back to `0` when the key is absent from the input.
        #[serde(default)]
        pub prompt_token_count: i32,
        /// Token count of cached content, when reported; omitted on output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub cached_content_token_count: Option<i32>,
        /// Token count of the generated candidates, when reported; omitted on output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidates_token_count: Option<i32>,
        /// Total token count; has no default, so it must be present when deserializing.
        pub total_token_count: i32,
        /// Token count of thoughts, when reported; omitted on output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thoughts_token_count: Option<i32>,
    }
1314
1315 impl std::fmt::Display for UsageMetadata {
1316 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1317 write!(
1318 f,
1319 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
1320 self.prompt_token_count,
1321 match self.cached_content_token_count {
1322 Some(count) => count.to_string(),
1323 None => "n/a".to_string(),
1324 },
1325 match self.candidates_token_count {
1326 Some(count) => count.to_string(),
1327 None => "n/a".to_string(),
1328 },
1329 self.total_token_count
1330 )
1331 }
1332 }
1333
1334 impl GetTokenUsage for UsageMetadata {
1335 fn token_usage(&self) -> Option<crate::completion::Usage> {
1336 let mut usage = crate::completion::Usage::new();
1337
1338 usage.input_tokens = self.prompt_token_count as u64;
1339 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
1340 + self.candidates_token_count.unwrap_or_default()
1341 + self.thoughts_token_count.unwrap_or_default())
1342 as u64;
1343 usage.total_tokens = usage.input_tokens + usage.output_tokens;
1344
1345 Some(usage)
1346 }
1347 }
1348
    /// Feedback about the prompt: an optional block reason and optional safety ratings.
    ///
    /// (De)serialized with camelCase keys (`blockReason`, `safetyRatings`).
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct PromptFeedback {
        /// Why the prompt was blocked, if a reason was given.
        pub block_reason: Option<BlockReason>,
        /// Safety ratings attached to the feedback, if any.
        pub safety_ratings: Option<Vec<SafetyRating>>,
    }
1358
    /// Reason values carried by [`PromptFeedback::block_reason`].
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `BLOCK_REASON_UNSPECIFIED`, `PROHIBITED_CONTENT`.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum BlockReason {
        BlockReasonUnspecified,
        Safety,
        Other,
        Blocklist,
        ProhibitedContent,
    }
1374
    /// Reason a candidate stopped generating.
    ///
    /// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `STOP`,
    /// `MAX_TOKENS`, `MALFORMED_FUNCTION_CALL`. There is no catch-all variant,
    /// so a value not listed here fails to deserialize.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum FinishReason {
        FinishReasonUnspecified,
        Stop,
        MaxTokens,
        Safety,
        Recitation,
        Language,
        Other,
        Blocklist,
        ProhibitedContent,
        Spii,
        MalformedFunctionCall,
    }
1401
    /// Collection of citation sources, (de)serialized under the `citationSources` key.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationMetadata {
        /// The individual citation entries.
        pub citation_sources: Vec<CitationSource>,
    }
1407
    /// A single citation entry; every field is optional and omitted from the
    /// serialized output when `None`.
    ///
    /// Keys use camelCase (`uri`, `startIndex`, `endIndex`, `license`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CitationSource {
        /// URI of the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub uri: Option<String>,
        /// Start index of the cited span; presumably an offset into the response — TODO confirm units.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub start_index: Option<i32>,
        /// End index of the cited span; presumably an offset into the response — TODO confirm units.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub end_index: Option<i32>,
        /// License associated with the cited source.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub license: Option<String>,
    }
1420
1421 #[derive(Clone, Debug, Deserialize, Serialize)]
1422 #[serde(rename_all = "camelCase")]
1423 pub struct LogprobsResult {
1424 pub top_candidate: Vec<TopCandidate>,
1425 pub chosen_candidate: Vec<LogProbCandidate>,
1426 }
1427
    /// A group of log-probability candidates, as used by [`LogprobsResult::top_candidate`].
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct TopCandidate {
        /// Candidates in this group.
        pub candidates: Vec<LogProbCandidate>,
    }
1432
    /// A token together with its id and log probability.
    ///
    /// Keys are camelCase (`token`, `tokenId`, `logProbability`).
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct LogProbCandidate {
        /// The token text.
        pub token: String,
        /// Identifier of the token.
        // NOTE(review): the Gemini API reference describes `tokenId` as an integer;
        // confirm a `String` field deserializes real responses.
        pub token_id: String,
        /// Log probability of the token.
        pub log_probability: f64,
    }
1440
    /// Generation options sent with a request.
    ///
    /// Keys are camelCase and every field is optional: a `None` field is left
    /// out of the serialized output entirely. See the `Default` impl for the
    /// values used when nothing is specified.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerationConfig {
        /// Sequences that stop generation.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub stop_sequences: Option<Vec<String>>,
        /// MIME type requested for the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_mime_type: Option<String>,
        /// Response schema expressed with the typed [`Schema`].
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_schema: Option<Schema>,
        /// Raw JSON schema serialized under the underscore-prefixed key
        /// `_responseJsonSchema` (explicit rename, not the camelCase default).
        #[serde(
            skip_serializing_if = "Option::is_none",
            rename = "_responseJsonSchema"
        )]
        pub _response_json_schema: Option<Value>,
        /// Raw JSON schema serialized under `responseJsonSchema`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_json_schema: Option<Value>,
        /// Number of candidates to generate.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub candidate_count: Option<i32>,
        /// Upper bound on generated tokens.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Sampling temperature.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f64>,
        /// Top-p sampling parameter.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f64>,
        /// Top-k sampling parameter.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<i32>,
        /// Presence penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub presence_penalty: Option<f64>,
        /// Frequency penalty.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub frequency_penalty: Option<f64>,
        /// Whether log probabilities should be included in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub response_logprobs: Option<bool>,
        /// Log-probability count setting; presumably the number of top candidates
        /// per step — TODO confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub logprobs: Option<i32>,
        /// Thinking-related options.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_config: Option<ThinkingConfig>,
        /// Image-generation options.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_config: Option<ImageConfig>,
    }
1533
1534 impl Default for GenerationConfig {
1535 fn default() -> Self {
1536 Self {
1537 temperature: Some(1.0),
1538 max_output_tokens: Some(4096),
1539 stop_sequences: None,
1540 response_mime_type: None,
1541 response_schema: None,
1542 _response_json_schema: None,
1543 response_json_schema: None,
1544 candidate_count: None,
1545 top_p: None,
1546 top_k: None,
1547 presence_penalty: None,
1548 frequency_penalty: None,
1549 response_logprobs: None,
1550 logprobs: None,
1551 thinking_config: None,
1552 image_config: None,
1553 }
1554 }
1555 }
1556
    /// Thinking level selectable through [`ThinkingConfig::thinking_level`].
    ///
    /// Variants are (de)serialized in snake_case: `minimal`, `low`, `medium`, `high`.
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    #[serde(rename_all = "snake_case")]
    pub enum ThinkingLevel {
        Minimal,
        Low,
        Medium,
        High,
    }
1566
    /// Thinking-related options nested in [`GenerationConfig`].
    ///
    /// Keys are camelCase (`thinkingBudget`, `thinkingLevel`, `includeThoughts`);
    /// `None` fields are omitted from the serialized output.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ThinkingConfig {
        /// Thinking budget; presumably measured in tokens — TODO confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_budget: Option<u32>,
        /// Requested thinking level.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub thinking_level: Option<ThinkingLevel>,
        /// Whether thoughts should be included in the response.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub include_thoughts: Option<bool>,
    }
1583
    /// Image options nested in [`GenerationConfig`].
    ///
    /// Keys are camelCase (`aspectRatio`, `imageSize`); `None` fields are omitted
    /// from the serialized output. Both values are free-form strings here.
    #[derive(Debug, Deserialize, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ImageConfig {
        /// Requested aspect ratio.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub aspect_ratio: Option<String>,
        /// Requested image size.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub image_size: Option<String>,
    }
1592
    /// Typed subset of a JSON schema, built from arbitrary JSON via `TryFrom<Value>`.
    ///
    /// There is no `rename_all`, so keys are serialized exactly as the fields are
    /// named (`type`, `enum`, `max_items`, ...). Only `type` is always written;
    /// every optional field is omitted when `None`.
    #[derive(Debug, Deserialize, Serialize, Clone)]
    pub struct Schema {
        /// Type name as a string; may be empty when no type could be inferred
        /// during conversion from JSON.
        pub r#type: String,
        /// Format hint for the value.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub format: Option<String>,
        /// Human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,
        /// Whether the value may be null.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub nullable: Option<bool>,
        /// Allowed string values.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub r#enum: Option<Vec<String>>,
        /// Maximum number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_items: Option<i32>,
        /// Minimum number of array items.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub min_items: Option<i32>,
        /// Schemas of an object's properties, keyed by property name.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub properties: Option<HashMap<String, Schema>>,
        /// Names of required properties.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub required: Option<Vec<String>>,
        /// Schema of an array's items; boxed because the type is recursive.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub items: Option<Box<Schema>>,
    }
1618
1619 pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
1625 let defs = if let Some(obj) = schema.as_object() {
1627 obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
1628 } else {
1629 None
1630 };
1631
1632 let Some(defs_value) = defs else {
1633 return Ok(schema);
1634 };
1635
1636 let Some(defs_obj) = defs_value.as_object() else {
1637 return Err(CompletionError::ResponseError(
1638 "$defs must be an object".into(),
1639 ));
1640 };
1641
1642 resolve_refs(&mut schema, defs_obj)?;
1643
1644 if let Some(obj) = schema.as_object_mut() {
1646 obj.remove("$defs");
1647 obj.remove("definitions");
1648 }
1649
1650 Ok(schema)
1651 }
1652
1653 fn resolve_refs(
1656 value: &mut Value,
1657 defs: &serde_json::Map<String, Value>,
1658 ) -> Result<(), CompletionError> {
1659 match value {
1660 Value::Object(obj) => {
1661 if let Some(ref_value) = obj.get("$ref")
1662 && let Some(ref_str) = ref_value.as_str()
1663 {
1664 let def_name = parse_ref_path(ref_str)?;
1666
1667 let def = defs.get(&def_name).ok_or_else(|| {
1668 CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
1669 })?;
1670
1671 let mut resolved = def.clone();
1672 resolve_refs(&mut resolved, defs)?;
1673 *value = resolved;
1674 return Ok(());
1675 }
1676
1677 for (_, v) in obj.iter_mut() {
1678 resolve_refs(v, defs)?;
1679 }
1680 }
1681 Value::Array(arr) => {
1682 for item in arr.iter_mut() {
1683 resolve_refs(item, defs)?;
1684 }
1685 }
1686 _ => {}
1687 }
1688
1689 Ok(())
1690 }
1691
1692 fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
1698 if let Some(fragment) = ref_str.strip_prefix('#') {
1699 if let Some(name) = fragment.strip_prefix("/$defs/") {
1700 Ok(name.to_string())
1701 } else if let Some(name) = fragment.strip_prefix("/definitions/") {
1702 Ok(name.to_string())
1703 } else {
1704 Err(CompletionError::ResponseError(format!(
1705 "Unsupported reference format: {}",
1706 ref_str
1707 )))
1708 }
1709 } else {
1710 Err(CompletionError::ResponseError(format!(
1711 "Only fragment references (#/...) are supported: {}",
1712 ref_str
1713 )))
1714 }
1715 }
1716
1717 fn extract_type(type_value: &Value) -> Option<String> {
1720 if type_value.is_string() {
1721 type_value.as_str().map(String::from)
1722 } else if type_value.is_array() {
1723 type_value
1724 .as_array()
1725 .and_then(|arr| arr.first())
1726 .and_then(|v| v.as_str().map(String::from))
1727 } else {
1728 None
1729 }
1730 }
1731
1732 fn extract_type_from_composition(composition: &Value) -> Option<String> {
1735 composition.as_array().and_then(|arr| {
1736 arr.iter().find_map(|schema| {
1737 if let Some(obj) = schema.as_object() {
1738 if let Some(type_val) = obj.get("type")
1740 && let Some(type_str) = type_val.as_str()
1741 && type_str == "null"
1742 {
1743 return None;
1744 }
1745 obj.get("type").and_then(extract_type).or_else(|| {
1747 if obj.contains_key("properties") {
1748 Some("object".to_string())
1749 } else if obj.contains_key("enum") {
1750 Some("string".to_string())
1752 } else {
1753 None
1754 }
1755 })
1756 } else {
1757 None
1758 }
1759 })
1760 })
1761 }
1762
1763 fn extract_schema_from_composition(
1766 composition: &Value,
1767 ) -> Option<serde_json::Map<String, Value>> {
1768 composition.as_array().and_then(|arr| {
1769 arr.iter().find_map(|schema| {
1770 if let Some(obj) = schema.as_object()
1771 && let Some(type_val) = obj.get("type")
1772 && let Some(type_str) = type_val.as_str()
1773 {
1774 if type_str == "null" {
1775 return None;
1776 }
1777 Some(obj.clone())
1778 } else {
1779 None
1780 }
1781 })
1782 })
1783 }
1784
1785 fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
1788 if let Some(type_val) = obj.get("type")
1790 && let Some(type_str) = extract_type(type_val)
1791 {
1792 return type_str;
1793 }
1794
1795 if let Some(any_of) = obj.get("anyOf")
1797 && let Some(type_str) = extract_type_from_composition(any_of)
1798 {
1799 return type_str;
1800 }
1801
1802 if let Some(one_of) = obj.get("oneOf")
1803 && let Some(type_str) = extract_type_from_composition(one_of)
1804 {
1805 return type_str;
1806 }
1807
1808 if let Some(all_of) = obj.get("allOf")
1809 && let Some(type_str) = extract_type_from_composition(all_of)
1810 {
1811 return type_str;
1812 }
1813
1814 if obj.contains_key("properties") {
1816 "object".to_string()
1817 } else {
1818 String::new()
1819 }
1820 }
1821
1822 impl TryFrom<Value> for Schema {
1823 type Error = CompletionError;
1824
1825 fn try_from(value: Value) -> Result<Self, Self::Error> {
1826 let flattened_val = flatten_schema(value)?;
1827 if let Some(obj) = flattened_val.as_object() {
1828 let props_source = if obj.get("properties").is_none() {
1831 if let Some(any_of) = obj.get("anyOf") {
1832 extract_schema_from_composition(any_of)
1833 } else if let Some(one_of) = obj.get("oneOf") {
1834 extract_schema_from_composition(one_of)
1835 } else if let Some(all_of) = obj.get("allOf") {
1836 extract_schema_from_composition(all_of)
1837 } else {
1838 None
1839 }
1840 .unwrap_or(obj.clone())
1841 } else {
1842 obj.clone()
1843 };
1844
1845 let schema_type = infer_type(obj);
1846 let items = obj
1847 .get("items")
1848 .and_then(|v| v.clone().try_into().ok())
1849 .map(Box::new);
1850
1851 let items = if schema_type == "array" && items.is_none() {
1854 Some(Box::new(Schema {
1855 r#type: "string".to_string(),
1856 format: None,
1857 description: None,
1858 nullable: None,
1859 r#enum: None,
1860 max_items: None,
1861 min_items: None,
1862 properties: None,
1863 required: None,
1864 items: None,
1865 }))
1866 } else {
1867 items
1868 };
1869
1870 Ok(Schema {
1871 r#type: schema_type,
1872 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
1873 description: obj
1874 .get("description")
1875 .and_then(|v| v.as_str())
1876 .map(String::from),
1877 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
1878 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
1879 arr.iter()
1880 .filter_map(|v| v.as_str().map(String::from))
1881 .collect()
1882 }),
1883 max_items: obj
1884 .get("maxItems")
1885 .and_then(|v| v.as_i64())
1886 .map(|v| v as i32),
1887 min_items: obj
1888 .get("minItems")
1889 .and_then(|v| v.as_i64())
1890 .map(|v| v as i32),
1891 properties: props_source
1892 .get("properties")
1893 .and_then(|v| v.as_object())
1894 .map(|map| {
1895 map.iter()
1896 .filter_map(|(k, v)| {
1897 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
1898 })
1899 .collect()
1900 }),
1901 required: props_source
1902 .get("required")
1903 .and_then(|v| v.as_array())
1904 .map(|arr| {
1905 arr.iter()
1906 .filter_map(|v| v.as_str().map(String::from))
1907 .collect()
1908 }),
1909 items,
1910 })
1911 } else {
1912 Err(CompletionError::ResponseError(
1913 "Expected a JSON object for Schema".into(),
1914 ))
1915 }
1916 }
1917 }
1918
    /// Request body for content generation, serialized with camelCase keys.
    ///
    /// Only `tools` and `additional_params` are skipped when `None`; the other
    /// optional fields have no skip attribute and are serialized as `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct GenerateContentRequest {
        /// Conversation contents.
        pub contents: Vec<Content>,
        /// Tool definitions as raw JSON values; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tools: Option<Vec<Value>>,
        /// Tool configuration (function-calling mode).
        pub tool_config: Option<ToolConfig>,
        /// Generation options.
        pub generation_config: Option<GenerationConfig>,
        /// Safety settings applied to the request.
        pub safety_settings: Option<Vec<SafetySetting>>,
        /// System instruction content.
        pub system_instruction: Option<Content>,
        /// Extra parameters flattened into the top level of the request body;
        /// omitted when `None`.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub additional_params: Option<serde_json::Value>,
    }
1950
    /// A tool definition: function declarations plus optional code execution.
    ///
    /// Serialized with camelCase keys (`functionDeclarations`, `codeExecution`);
    /// `code_execution` has no skip attribute, so `None` is written as `null`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Tool {
        /// Functions exposed by this tool.
        pub function_declarations: Vec<FunctionDeclaration>,
        /// Code-execution tool marker, if enabled.
        pub code_execution: Option<CodeExecution>,
    }
1957
    /// Declaration of a callable function: name, description and optional
    /// parameter schema.
    #[derive(Debug, Serialize, Clone)]
    #[serde(rename_all = "camelCase")]
    pub struct FunctionDeclaration {
        /// Function name.
        pub name: String,
        /// Description of what the function does.
        pub description: String,
        /// Schema of the function's parameters; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parameters: Option<Schema>,
    }
1966
    /// Tool configuration wrapper, (de)serialized under `functionCallingConfig`.
    ///
    /// The field has no skip attribute, so `None` is serialized as `null`.
    #[derive(Debug, Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct ToolConfig {
        /// Function-calling mode to apply.
        pub function_calling_config: Option<FunctionCallingMode>,
    }
1972
    /// Function-calling mode, internally tagged under the `mode` key with
    /// uppercased variant names (`AUTO`, `NONE`, `ANY`).
    ///
    /// `Auto` is the default variant.
    #[derive(Debug, Serialize, Deserialize, Default)]
    #[serde(tag = "mode", rename_all = "UPPERCASE")]
    pub enum FunctionCallingMode {
        #[default]
        Auto,
        None,
        Any {
            /// Optional list of function names tied to this mode; omitted from
            /// the output when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            allowed_function_names: Option<Vec<String>>,
        },
    }
1984
1985 impl TryFrom<message::ToolChoice> for FunctionCallingMode {
1986 type Error = CompletionError;
1987 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1988 let res = match value {
1989 message::ToolChoice::Auto => Self::Auto,
1990 message::ToolChoice::None => Self::None,
1991 message::ToolChoice::Required => Self::Any {
1992 allowed_function_names: None,
1993 },
1994 message::ToolChoice::Specific { function_names } => Self::Any {
1995 allowed_function_names: Some(function_names),
1996 },
1997 };
1998
1999 Ok(res)
2000 }
2001 }
2002
    /// Field-less marker for [`Tool::code_execution`]; it has no fields, so it
    /// serializes as an empty object.
    #[derive(Debug, Serialize)]
    pub struct CodeExecution {}
2005
    /// Pairs a harm category with the blocking threshold to apply to it.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct SafetySetting {
        /// Category the setting applies to.
        pub category: HarmCategory,
        /// Threshold chosen for that category.
        pub threshold: HarmBlockThreshold,
    }
2012
    /// Blocking threshold used in a [`SafetySetting`].
    ///
    /// Variants are serialized in SCREAMING_SNAKE_CASE, e.g.
    /// `BLOCK_MEDIUM_AND_ABOVE`, `BLOCK_NONE`, `OFF`.
    #[derive(Debug, Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum HarmBlockThreshold {
        HarmBlockThresholdUnspecified,
        BlockLowAndAbove,
        BlockMediumAndAbove,
        BlockOnlyHigh,
        BlockNone,
        Off,
    }
2023}
2024
2025#[cfg(test)]
2026mod tests {
2027 use crate::{
2028 message,
2029 providers::gemini::completion::gemini_api_types::{
2030 ContentCandidate, FinishReason, flatten_schema,
2031 },
2032 };
2033
2034 use super::*;
2035 use serde_json::json;
2036
2037 #[test]
2038 fn test_resolve_request_model_uses_override() {
2039 let request = CompletionRequest {
2040 model: Some("gemini-2.5-flash".to_string()),
2041 preamble: None,
2042 chat_history: crate::OneOrMany::one("Hello".into()),
2043 documents: vec![],
2044 tools: vec![],
2045 temperature: None,
2046 max_tokens: None,
2047 tool_choice: None,
2048 additional_params: None,
2049 output_schema: None,
2050 };
2051
2052 let request_model = resolve_request_model("gemini-2.0-flash", &request);
2053 assert_eq!(request_model, "gemini-2.5-flash");
2054 assert_eq!(
2055 completion_endpoint(&request_model),
2056 "/v1beta/models/gemini-2.5-flash:generateContent"
2057 );
2058 assert_eq!(
2059 streaming_endpoint(&request_model),
2060 "/v1beta/models/gemini-2.5-flash:streamGenerateContent"
2061 );
2062 }
2063
2064 #[test]
2065 fn test_resolve_request_model_uses_default_when_unset() {
2066 let request = CompletionRequest {
2067 model: None,
2068 preamble: None,
2069 chat_history: crate::OneOrMany::one("Hello".into()),
2070 documents: vec![],
2071 tools: vec![],
2072 temperature: None,
2073 max_tokens: None,
2074 tool_choice: None,
2075 additional_params: None,
2076 output_schema: None,
2077 };
2078
2079 assert_eq!(
2080 resolve_request_model("gemini-2.0-flash", &request),
2081 "gemini-2.0-flash"
2082 );
2083 }
2084
2085 #[test]
2086 fn test_deserialize_message_user() {
2087 let raw_message = r#"{
2088 "parts": [
2089 {"text": "Hello, world!"},
2090 {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
2091 {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
2092 {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
2093 {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
2094 {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
2095 {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
2096 ],
2097 "role": "user"
2098 }"#;
2099
2100 let content: Content = {
2101 let jd = &mut serde_json::Deserializer::from_str(raw_message);
2102 serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
2103 panic!("Deserialization error at {}: {}", err.path(), err);
2104 })
2105 };
2106 assert_eq!(content.role, Some(Role::User));
2107 assert_eq!(content.parts.len(), 7);
2108
2109 let parts: Vec<Part> = content.parts.into_iter().collect();
2110
2111 if let Part {
2112 part: PartKind::Text(text),
2113 ..
2114 } = &parts[0]
2115 {
2116 assert_eq!(text, "Hello, world!");
2117 } else {
2118 panic!("Expected text part");
2119 }
2120
2121 if let Part {
2122 part: PartKind::InlineData(inline_data),
2123 ..
2124 } = &parts[1]
2125 {
2126 assert_eq!(inline_data.mime_type, "image/png");
2127 assert_eq!(inline_data.data, "base64encodeddata");
2128 } else {
2129 panic!("Expected inline data part");
2130 }
2131
2132 if let Part {
2133 part: PartKind::FunctionCall(function_call),
2134 ..
2135 } = &parts[2]
2136 {
2137 assert_eq!(function_call.name, "test_function");
2138 assert_eq!(
2139 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2140 "value1"
2141 );
2142 } else {
2143 panic!("Expected function call part");
2144 }
2145
2146 if let Part {
2147 part: PartKind::FunctionResponse(function_response),
2148 ..
2149 } = &parts[3]
2150 {
2151 assert_eq!(function_response.name, "test_function");
2152 assert_eq!(
2153 function_response
2154 .response
2155 .as_ref()
2156 .unwrap()
2157 .get("result")
2158 .unwrap(),
2159 "success"
2160 );
2161 } else {
2162 panic!("Expected function response part");
2163 }
2164
2165 if let Part {
2166 part: PartKind::FileData(file_data),
2167 ..
2168 } = &parts[4]
2169 {
2170 assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
2171 assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
2172 } else {
2173 panic!("Expected file data part");
2174 }
2175
2176 if let Part {
2177 part: PartKind::ExecutableCode(executable_code),
2178 ..
2179 } = &parts[5]
2180 {
2181 assert_eq!(executable_code.code, "print('Hello, world!')");
2182 } else {
2183 panic!("Expected executable code part");
2184 }
2185
2186 if let Part {
2187 part: PartKind::CodeExecutionResult(code_execution_result),
2188 ..
2189 } = &parts[6]
2190 {
2191 assert_eq!(
2192 code_execution_result.clone().output.unwrap(),
2193 "Hello, world!"
2194 );
2195 } else {
2196 panic!("Expected code execution result part");
2197 }
2198 }
2199
2200 #[test]
2201 fn test_deserialize_message_model() {
2202 let json_data = json!({
2203 "parts": [{"text": "Hello, user!"}],
2204 "role": "model"
2205 });
2206
2207 let content: Content = serde_json::from_value(json_data).unwrap();
2208 assert_eq!(content.role, Some(Role::Model));
2209 assert_eq!(content.parts.len(), 1);
2210 if let Some(Part {
2211 part: PartKind::Text(text),
2212 ..
2213 }) = content.parts.first()
2214 {
2215 assert_eq!(text, "Hello, user!");
2216 } else {
2217 panic!("Expected text part");
2218 }
2219 }
2220
2221 #[test]
2222 fn test_message_conversion_user() {
2223 let msg = message::Message::user("Hello, world!");
2224 let content: Content = msg.try_into().unwrap();
2225 assert_eq!(content.role, Some(Role::User));
2226 assert_eq!(content.parts.len(), 1);
2227 if let Some(Part {
2228 part: PartKind::Text(text),
2229 ..
2230 }) = &content.parts.first()
2231 {
2232 assert_eq!(text, "Hello, world!");
2233 } else {
2234 panic!("Expected text part");
2235 }
2236 }
2237
2238 #[test]
2239 fn test_message_conversion_model() {
2240 let msg = message::Message::assistant("Hello, user!");
2241
2242 let content: Content = msg.try_into().unwrap();
2243 assert_eq!(content.role, Some(Role::Model));
2244 assert_eq!(content.parts.len(), 1);
2245 if let Some(Part {
2246 part: PartKind::Text(text),
2247 ..
2248 }) = &content.parts.first()
2249 {
2250 assert_eq!(text, "Hello, user!");
2251 } else {
2252 panic!("Expected text part");
2253 }
2254 }
2255
2256 #[test]
2257 fn test_thought_signature_is_preserved_from_response_reasoning_part() {
2258 let response = GenerateContentResponse {
2259 response_id: "resp_1".to_string(),
2260 candidates: vec![ContentCandidate {
2261 content: Some(Content {
2262 parts: vec![Part {
2263 thought: Some(true),
2264 thought_signature: Some("thought_sig_123".to_string()),
2265 part: PartKind::Text("thinking text".to_string()),
2266 additional_params: None,
2267 }],
2268 role: Some(Role::Model),
2269 }),
2270 finish_reason: Some(FinishReason::Stop),
2271 safety_ratings: None,
2272 citation_metadata: None,
2273 token_count: None,
2274 avg_logprobs: None,
2275 logprobs_result: None,
2276 index: Some(0),
2277 finish_message: None,
2278 }],
2279 prompt_feedback: None,
2280 usage_metadata: None,
2281 model_version: None,
2282 };
2283
2284 let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
2285 response.try_into().expect("convert response");
2286 let first = converted.choice.first();
2287 assert!(matches!(
2288 first,
2289 message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2290 if matches!(
2291 content.first(),
2292 Some(message::ReasoningContent::Text {
2293 text,
2294 signature: Some(signature)
2295 }) if text == "thinking text" && signature == "thought_sig_123"
2296 )
2297 ));
2298 }
2299
2300 #[test]
2301 fn test_reasoning_signature_is_emitted_in_gemini_part() {
2302 let msg = message::Message::Assistant {
2303 id: None,
2304 content: OneOrMany::one(message::AssistantContent::Reasoning(
2305 message::Reasoning::new_with_signature(
2306 "structured thought",
2307 Some("reuse_sig_456".to_string()),
2308 ),
2309 )),
2310 };
2311
2312 let converted: Content = msg.try_into().expect("convert message");
2313 let first = converted.parts.first().expect("reasoning part");
2314 assert_eq!(first.thought, Some(true));
2315 assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
2316 assert!(matches!(
2317 &first.part,
2318 PartKind::Text(text) if text == "structured thought"
2319 ));
2320 }
2321
2322 #[test]
2323 fn test_message_conversion_tool_call() {
2324 let tool_call = message::ToolCall {
2325 id: "test_tool".to_string(),
2326 call_id: None,
2327 function: message::ToolFunction {
2328 name: "test_function".to_string(),
2329 arguments: json!({"arg1": "value1"}),
2330 },
2331 signature: None,
2332 additional_params: None,
2333 };
2334
2335 let msg = message::Message::Assistant {
2336 id: None,
2337 content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
2338 };
2339
2340 let content: Content = msg.try_into().unwrap();
2341 assert_eq!(content.role, Some(Role::Model));
2342 assert_eq!(content.parts.len(), 1);
2343 if let Some(Part {
2344 part: PartKind::FunctionCall(function_call),
2345 ..
2346 }) = content.parts.first()
2347 {
2348 assert_eq!(function_call.name, "test_function");
2349 assert_eq!(
2350 function_call.args.as_object().unwrap().get("arg1").unwrap(),
2351 "value1"
2352 );
2353 } else {
2354 panic!("Expected function call part");
2355 }
2356 }
2357
2358 #[test]
2359 fn test_vec_schema_conversion() {
2360 let schema_with_ref = json!({
2361 "type": "array",
2362 "items": {
2363 "$ref": "#/$defs/Person"
2364 },
2365 "$defs": {
2366 "Person": {
2367 "type": "object",
2368 "properties": {
2369 "first_name": {
2370 "type": ["string", "null"],
2371 "description": "The person's first name, if provided (null otherwise)"
2372 },
2373 "last_name": {
2374 "type": ["string", "null"],
2375 "description": "The person's last name, if provided (null otherwise)"
2376 },
2377 "job": {
2378 "type": ["string", "null"],
2379 "description": "The person's job, if provided (null otherwise)"
2380 }
2381 },
2382 "required": []
2383 }
2384 }
2385 });
2386
2387 let result: Result<Schema, _> = schema_with_ref.try_into();
2388
2389 match result {
2390 Ok(schema) => {
2391 assert_eq!(schema.r#type, "array");
2392
2393 if let Some(items) = schema.items {
2394 println!("item types: {}", items.r#type);
2395
2396 assert_ne!(items.r#type, "", "Items type should not be empty string!");
2397 assert_eq!(items.r#type, "object", "Items should be object type");
2398 } else {
2399 panic!("Schema should have items field for array type");
2400 }
2401 }
2402 Err(e) => println!("Schema conversion failed: {:?}", e),
2403 }
2404 }
2405
2406 #[test]
2407 fn test_object_schema() {
2408 let simple_schema = json!({
2409 "type": "object",
2410 "properties": {
2411 "name": {
2412 "type": "string"
2413 }
2414 }
2415 });
2416
2417 let schema: Schema = simple_schema.try_into().unwrap();
2418 assert_eq!(schema.r#type, "object");
2419 assert!(schema.properties.is_some());
2420 }
2421
2422 #[test]
2423 fn test_array_with_inline_items() {
2424 let inline_schema = json!({
2425 "type": "array",
2426 "items": {
2427 "type": "object",
2428 "properties": {
2429 "name": {
2430 "type": "string"
2431 }
2432 }
2433 }
2434 });
2435
2436 let schema: Schema = inline_schema.try_into().unwrap();
2437 assert_eq!(schema.r#type, "array");
2438
2439 if let Some(items) = schema.items {
2440 assert_eq!(items.r#type, "object");
2441 assert!(items.properties.is_some());
2442 } else {
2443 panic!("Schema should have items field");
2444 }
2445 }
2446 #[test]
2447 fn test_flattened_schema() {
2448 let ref_schema = json!({
2449 "type": "array",
2450 "items": {
2451 "$ref": "#/$defs/Person"
2452 },
2453 "$defs": {
2454 "Person": {
2455 "type": "object",
2456 "properties": {
2457 "name": { "type": "string" }
2458 }
2459 }
2460 }
2461 });
2462
2463 let flattened = flatten_schema(ref_schema).unwrap();
2464 let schema: Schema = flattened.try_into().unwrap();
2465
2466 assert_eq!(schema.r#type, "array");
2467
2468 if let Some(items) = schema.items {
2469 println!("Flattened items type: '{}'", items.r#type);
2470
2471 assert_eq!(items.r#type, "object");
2472 assert!(items.properties.is_some());
2473 }
2474 }
2475
2476 #[test]
2477 fn test_array_without_items_gets_default() {
2478 let schema_json = json!({
2479 "type": "object",
2480 "properties": {
2481 "service_ids": {
2482 "type": "array",
2483 "description": "A list of service IDs"
2484 }
2485 }
2486 });
2487
2488 let schema: Schema = schema_json.try_into().unwrap();
2489 let props = schema.properties.unwrap();
2490 let service_ids = props.get("service_ids").unwrap();
2491 assert_eq!(service_ids.r#type, "array");
2492 let items = service_ids
2493 .items
2494 .as_ref()
2495 .expect("array schema missing items should get a default");
2496 assert_eq!(items.r#type, "string");
2497 }
2498
2499 #[test]
2500 fn test_txt_document_conversion_to_text_part() {
2501 use crate::message::{DocumentMediaType, UserContent};
2503
2504 let doc = UserContent::document(
2505 "Note: test.md\nPath: /test.md\nContent: Hello World!",
2506 Some(DocumentMediaType::TXT),
2507 );
2508
2509 let content: Content = message::Message::User {
2510 content: crate::OneOrMany::one(doc),
2511 }
2512 .try_into()
2513 .unwrap();
2514
2515 if let Part {
2516 part: PartKind::Text(text),
2517 ..
2518 } = &content.parts[0]
2519 {
2520 assert!(text.contains("Note: test.md"));
2521 assert!(text.contains("Hello World!"));
2522 } else {
2523 panic!(
2524 "Expected text part for TXT document, got: {:?}",
2525 content.parts[0]
2526 );
2527 }
2528 }
2529
2530 #[test]
2531 fn test_tool_result_with_image_content() {
2532 use crate::OneOrMany;
2534 use crate::message::{
2535 DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2536 };
2537
2538 let tool_result = ToolResult {
2540 id: "test_tool".to_string(),
2541 call_id: None,
2542 content: OneOrMany::many(vec![
2543 ToolResultContent::Text(message::Text {
2544 text: r#"{"status": "success"}"#.to_string(),
2545 }),
2546 ToolResultContent::Image(Image {
2547 data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
2548 media_type: Some(ImageMediaType::PNG),
2549 detail: None,
2550 additional_params: None,
2551 }),
2552 ]).expect("Should create OneOrMany with multiple items"),
2553 };
2554
2555 let user_content = message::UserContent::ToolResult(tool_result);
2556 let msg = message::Message::User {
2557 content: OneOrMany::one(user_content),
2558 };
2559
2560 let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2562 assert_eq!(content.role, Some(Role::User));
2563 assert_eq!(content.parts.len(), 1);
2564
2565 if let Some(Part {
2567 part: PartKind::FunctionResponse(function_response),
2568 ..
2569 }) = content.parts.first()
2570 {
2571 assert_eq!(function_response.name, "test_tool");
2572
2573 assert!(function_response.response.is_some());
2575 let response = function_response.response.as_ref().unwrap();
2576 assert!(response.get("result").is_some());
2577
2578 assert!(function_response.parts.is_some());
2580 let parts = function_response.parts.as_ref().unwrap();
2581 assert_eq!(parts.len(), 1);
2582
2583 let image_part = &parts[0];
2584 assert!(image_part.inline_data.is_some());
2585 let inline_data = image_part.inline_data.as_ref().unwrap();
2586 assert_eq!(inline_data.mime_type, "image/png");
2587 assert!(!inline_data.data.is_empty());
2588 } else {
2589 panic!("Expected FunctionResponse part");
2590 }
2591 }
2592
2593 #[test]
2594 fn test_markdown_document_conversion_to_text_part() {
2595 use crate::message::{DocumentMediaType, UserContent};
2597
2598 let doc = UserContent::document(
2599 "# Heading\n\n* List item",
2600 Some(DocumentMediaType::MARKDOWN),
2601 );
2602
2603 let content: Content = message::Message::User {
2604 content: crate::OneOrMany::one(doc),
2605 }
2606 .try_into()
2607 .unwrap();
2608
2609 if let Part {
2610 part: PartKind::Text(text),
2611 ..
2612 } = &content.parts[0]
2613 {
2614 assert_eq!(text, "# Heading\n\n* List item");
2615 } else {
2616 panic!(
2617 "Expected text part for MARKDOWN document, got: {:?}",
2618 content.parts[0]
2619 );
2620 }
2621 }
2622
2623 #[test]
2624 fn test_markdown_url_document_conversion_to_file_data_part() {
2625 use crate::message::{DocumentMediaType, DocumentSourceKind, UserContent};
2627
2628 let doc = UserContent::Document(message::Document {
2629 data: DocumentSourceKind::Url(
2630 "https://generativelanguage.googleapis.com/v1beta/files/test-markdown".to_string(),
2631 ),
2632 media_type: Some(DocumentMediaType::MARKDOWN),
2633 additional_params: None,
2634 });
2635
2636 let content: Content = message::Message::User {
2637 content: crate::OneOrMany::one(doc),
2638 }
2639 .try_into()
2640 .unwrap();
2641
2642 if let Part {
2643 part: PartKind::FileData(file_data),
2644 ..
2645 } = &content.parts[0]
2646 {
2647 assert_eq!(
2648 file_data.file_uri,
2649 "https://generativelanguage.googleapis.com/v1beta/files/test-markdown"
2650 );
2651 assert_eq!(file_data.mime_type.as_deref(), Some("text/markdown"));
2652 } else {
2653 panic!(
2654 "Expected file_data part for URL MARKDOWN document, got: {:?}",
2655 content.parts[0]
2656 );
2657 }
2658 }
2659
2660 #[test]
2661 fn test_tool_result_with_url_image() {
2662 use crate::OneOrMany;
2664 use crate::message::{
2665 DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
2666 };
2667
2668 let tool_result = ToolResult {
2669 id: "screenshot_tool".to_string(),
2670 call_id: None,
2671 content: OneOrMany::one(ToolResultContent::Image(Image {
2672 data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
2673 media_type: Some(ImageMediaType::PNG),
2674 detail: None,
2675 additional_params: None,
2676 })),
2677 };
2678
2679 let user_content = message::UserContent::ToolResult(tool_result);
2680 let msg = message::Message::User {
2681 content: OneOrMany::one(user_content),
2682 };
2683
2684 let content: Content = msg.try_into().expect("Should convert to Gemini Content");
2685 assert_eq!(content.role, Some(Role::User));
2686 assert_eq!(content.parts.len(), 1);
2687
2688 if let Some(Part {
2689 part: PartKind::FunctionResponse(function_response),
2690 ..
2691 }) = content.parts.first()
2692 {
2693 assert_eq!(function_response.name, "screenshot_tool");
2694
2695 assert!(function_response.parts.is_some());
2697 let parts = function_response.parts.as_ref().unwrap();
2698 assert_eq!(parts.len(), 1);
2699
2700 let image_part = &parts[0];
2701 assert!(image_part.file_data.is_some());
2702 let file_data = image_part.file_data.as_ref().unwrap();
2703 assert_eq!(file_data.file_uri, "https://example.com/image.png");
2704 assert_eq!(file_data.mime_type.as_ref().unwrap(), "image/png");
2705 } else {
2706 panic!("Expected FunctionResponse part");
2707 }
2708 }
2709
2710 #[test]
2711 fn test_create_request_body_with_documents() {
2712 use crate::OneOrMany;
2714 use crate::completion::request::{CompletionRequest, Document};
2715 use crate::message::Message;
2716
2717 let documents = vec![
2718 Document {
2719 id: "doc1".to_string(),
2720 text: "Note: first.md\nContent: First note".to_string(),
2721 additional_props: std::collections::HashMap::new(),
2722 },
2723 Document {
2724 id: "doc2".to_string(),
2725 text: "Note: second.md\nContent: Second note".to_string(),
2726 additional_props: std::collections::HashMap::new(),
2727 },
2728 ];
2729
2730 let completion_request = CompletionRequest {
2731 preamble: Some("You are a helpful assistant".to_string()),
2732 chat_history: OneOrMany::one(Message::user("What are my notes about?")),
2733 documents: documents.clone(),
2734 tools: vec![],
2735 temperature: None,
2736 model: None,
2737 output_schema: None,
2738 max_tokens: None,
2739 tool_choice: None,
2740 additional_params: None,
2741 };
2742
2743 let request = create_request_body(completion_request).unwrap();
2744
2745 assert_eq!(
2747 request.contents.len(),
2748 2,
2749 "Expected 2 contents (documents + user message)"
2750 );
2751
2752 assert_eq!(request.contents[0].role, Some(Role::User));
2754 assert_eq!(
2755 request.contents[0].parts.len(),
2756 2,
2757 "Expected 2 document parts"
2758 );
2759
2760 for part in &request.contents[0].parts {
2762 if let Part {
2763 part: PartKind::Text(text),
2764 ..
2765 } = part
2766 {
2767 assert!(
2768 text.contains("Note:") && text.contains("Content:"),
2769 "Document should contain note metadata"
2770 );
2771 } else {
2772 panic!("Document parts should be text, not {:?}", part);
2773 }
2774 }
2775
2776 assert_eq!(request.contents[1].role, Some(Role::User));
2778 if let Part {
2779 part: PartKind::Text(text),
2780 ..
2781 } = &request.contents[1].parts[0]
2782 {
2783 assert_eq!(text, "What are my notes about?");
2784 } else {
2785 panic!("Expected user message to be text");
2786 }
2787 }
2788
2789 #[test]
2790 fn test_create_request_body_without_documents() {
2791 use crate::OneOrMany;
2793 use crate::completion::request::CompletionRequest;
2794 use crate::message::Message;
2795
2796 let completion_request = CompletionRequest {
2797 preamble: Some("You are a helpful assistant".to_string()),
2798 chat_history: OneOrMany::one(Message::user("Hello")),
2799 documents: vec![], tools: vec![],
2801 temperature: None,
2802 max_tokens: None,
2803 tool_choice: None,
2804 model: None,
2805 output_schema: None,
2806 additional_params: None,
2807 };
2808
2809 let request = create_request_body(completion_request).unwrap();
2810
2811 assert_eq!(request.contents.len(), 1, "Expected only user message");
2813 assert_eq!(request.contents[0].role, Some(Role::User));
2814
2815 if let Part {
2816 part: PartKind::Text(text),
2817 ..
2818 } = &request.contents[0].parts[0]
2819 {
2820 assert_eq!(text, "Hello");
2821 } else {
2822 panic!("Expected user message to be text");
2823 }
2824 }
2825
2826 #[test]
2827 fn test_from_tool_output_parses_image_json() {
2828 use crate::message::{DocumentSourceKind, ToolResultContent};
2830
2831 let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
2833 let result = ToolResultContent::from_tool_output(image_json);
2834
2835 assert_eq!(result.len(), 1);
2836 if let ToolResultContent::Image(img) = result.first() {
2837 assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2838 if let DocumentSourceKind::Base64(data) = &img.data {
2839 assert_eq!(data, "base64data==");
2840 }
2841 assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
2842 } else {
2843 panic!("Expected Image content");
2844 }
2845 }
2846
2847 #[test]
2848 fn test_from_tool_output_parses_hybrid_json() {
2849 use crate::message::{DocumentSourceKind, ToolResultContent};
2851
2852 let hybrid_json = r#"{
2853 "response": {"status": "ok", "count": 42},
2854 "parts": [
2855 {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
2856 {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
2857 ]
2858 }"#;
2859
2860 let result = ToolResultContent::from_tool_output(hybrid_json);
2861
2862 assert_eq!(result.len(), 3);
2864
2865 let items: Vec<_> = result.iter().collect();
2866
2867 if let ToolResultContent::Text(text) = &items[0] {
2869 assert!(text.text.contains("status"));
2870 assert!(text.text.contains("ok"));
2871 } else {
2872 panic!("Expected Text content first");
2873 }
2874
2875 if let ToolResultContent::Image(img) = &items[1] {
2877 assert!(matches!(img.data, DocumentSourceKind::Base64(_)));
2878 } else {
2879 panic!("Expected Image content second");
2880 }
2881
2882 if let ToolResultContent::Image(img) = &items[2] {
2884 assert!(matches!(img.data, DocumentSourceKind::Url(_)));
2885 } else {
2886 panic!("Expected Image content third");
2887 }
2888 }
2889
2890 #[tokio::test]
2894 #[ignore = "requires GEMINI_API_KEY environment variable"]
2895 async fn test_gemini_agent_with_image_tool_result_e2e() {
2896 use crate::completion::{Prompt, ToolDefinition};
2897 use crate::prelude::*;
2898 use crate::providers::gemini;
2899 use crate::tool::Tool;
2900 use serde::{Deserialize, Serialize};
2901
2902 #[derive(Debug, Serialize, Deserialize)]
2904 struct ImageGeneratorTool;
2905
2906 #[derive(Debug, thiserror::Error)]
2907 #[error("Image generation error")]
2908 struct ImageToolError;
2909
2910 impl Tool for ImageGeneratorTool {
2911 const NAME: &'static str = "generate_test_image";
2912 type Error = ImageToolError;
2913 type Args = serde_json::Value;
2914 type Output = String;
2916
2917 async fn definition(&self, _prompt: String) -> ToolDefinition {
2918 ToolDefinition {
2919 name: "generate_test_image".to_string(),
2920 description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
2921 parameters: json!({
2922 "type": "object",
2923 "properties": {},
2924 "required": []
2925 }),
2926 }
2927 }
2928
2929 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
2930 Ok(json!({
2933 "type": "image",
2934 "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
2935 "mimeType": "image/png"
2936 }).to_string())
2937 }
2938 }
2939
2940 let client = gemini::Client::from_env();
2941
2942 let agent = client
2943 .agent("gemini-3-flash-preview")
2944 .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
2945 .tool(ImageGeneratorTool)
2946 .build();
2947
2948 let response = agent
2950 .prompt("Please generate a test image and tell me what color the pixel is.")
2951 .await;
2952
2953 assert!(
2956 response.is_ok(),
2957 "Gemini should successfully process tool result with image: {:?}",
2958 response.err()
2959 );
2960
2961 let response_text = response.unwrap();
2962 println!("Response: {response_text}");
2963 assert!(!response_text.is_empty(), "Response should not be empty");
2965 }
2966}