// Gemini model identifiers; any of these can be passed as the `model` argument of
// `CompletionModel::new`, which accepts `impl Into<String>`.
/// `gemini-3.1-flash-lite-preview` model id.
pub const GEMINI_3_1_FLASH_LITE_PREVIEW: &str = "gemini-3.1-flash-lite-preview";
/// `gemini-3-flash-preview` model id.
pub const GEMINI_3_FLASH_PREVIEW: &str = "gemini-3-flash-preview";
/// `gemini-2.5-pro-preview-06-05` model id.
pub const GEMINI_2_5_PRO_PREVIEW_06_05: &str = "gemini-2.5-pro-preview-06-05";
/// `gemini-2.5-pro-preview-05-06` model id.
pub const GEMINI_2_5_PRO_PREVIEW_05_06: &str = "gemini-2.5-pro-preview-05-06";
/// `gemini-2.5-pro-preview-03-25` model id.
pub const GEMINI_2_5_PRO_PREVIEW_03_25: &str = "gemini-2.5-pro-preview-03-25";
/// `gemini-2.5-flash-preview-04-17` model id.
pub const GEMINI_2_5_FLASH_PREVIEW_04_17: &str = "gemini-2.5-flash-preview-04-17";
/// `gemini-2.5-pro-exp-03-25` model id.
pub const GEMINI_2_5_PRO_EXP_03_25: &str = "gemini-2.5-pro-exp-03-25";
/// `gemini-2.5-flash` model id.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// `gemini-2.0-flash-lite` model id.
pub const GEMINI_2_0_FLASH_LITE: &str = "gemini-2.0-flash-lite";
/// `gemini-2.0-flash` model id.
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
use self::gemini_api_types::tool_parameters_to_schema;
use crate::http_client::HttpClientExt;
use crate::message::{self, MimeType, Reasoning};
use crate::providers::gemini::completion::gemini_api_types::{
AdditionalParameters, FunctionCallingMode, ToolConfig,
};
use crate::providers::gemini::streaming::StreamingCompletionResponse;
use crate::telemetry::SpanCombinator;
use crate::{
OneOrMany,
completion::{self, CompletionError, CompletionRequest, GetTokenUsage},
};
use gemini_api_types::{
Content, FinishReason, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
GenerationConfig, Part, PartKind, Role, Tool,
};
use serde_json::{Map, Value};
use std::convert::TryFrom;
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;
use super::Client;
/// Gemini completion model: a provider client paired with a default model id.
///
/// `T` is the HTTP client type wrapped by [`Client`] (defaults to `reqwest::Client`).
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
/// Provider client used to build and send requests.
pub(crate) client: Client<T>,
/// Default model id; a `CompletionRequest::model` value overrides it per request
/// (see `resolve_request_model`).
pub model: String,
}
impl<T> CompletionModel<T> {
    /// Creates a completion model that sends requests through `client` and targets
    /// `model` unless a request overrides it.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        let model = model.into();
        Self { client, model }
    }

    /// Borrowed-`&str` variant of [`CompletionModel::new`]; the name is copied into
    /// an owned `String`.
    pub fn with_model(client: Client<T>, model: &str) -> Self {
        Self::new(client, model)
    }
}
impl<T> completion::CompletionModel for CompletionModel<T>
where
T: HttpClientExt + Clone + 'static,
{
type Response = GenerateContentResponse;
type StreamingResponse = StreamingCompletionResponse;
type Client = super::Client<T>;
fn make(client: &Self::Client, model: impl Into<String>) -> Self {
Self::new(client.clone(), model)
}
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
let request_model = resolve_request_model(&self.model, &completion_request);
let span = if tracing::Span::current().is_disabled() {
info_span!(
target: "rig::completions",
"generate_content",
gen_ai.operation.name = "generate_content",
gen_ai.provider.name = "gcp.gemini",
gen_ai.request.model = &request_model,
gen_ai.system_instructions = &completion_request.preamble,
gen_ai.response.id = tracing::field::Empty,
gen_ai.response.model = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
gen_ai.usage.reasoning_tokens = tracing::field::Empty,
)
} else {
tracing::Span::current()
};
let request = create_request_body(completion_request)?;
if enabled!(Level::TRACE) {
tracing::trace!(
target: "rig::completions",
"Gemini completion request: {}",
serde_json::to_string_pretty(&request)?
);
}
let body = serde_json::to_vec(&request)?;
let path = completion_endpoint(&request_model);
let request = self
.client
.post(path.as_str())?
.body(body)
.map_err(|e| CompletionError::HttpError(e.into()))?;
async move {
let response = self.client.send::<_, Vec<u8>>(request).await?;
if response.status().is_success() {
let response_body = response
.into_body()
.await
.map_err(CompletionError::HttpError)?;
let response_text = String::from_utf8_lossy(&response_body).to_string();
let response: GenerateContentResponse = serde_json::from_slice(&response_body)
.map_err(|err| {
tracing::error!(
error = %err,
body = %response_text,
"Failed to deserialize Gemini completion response"
);
CompletionError::JsonError(err)
})?;
let span = tracing::Span::current();
span.record_response_metadata(&response);
span.record_token_usage(&response.usage_metadata);
if enabled!(Level::TRACE) {
tracing::trace!(
target: "rig::completions",
"Gemini completion response: {}",
serde_json::to_string_pretty(&response)?
);
}
response.try_into()
} else {
let text = String::from_utf8_lossy(
&response
.into_body()
.await
.map_err(CompletionError::HttpError)?,
)
.into();
Err(CompletionError::ProviderError(text))
}
}
.instrument(span)
.await
}
async fn stream(
&self,
request: CompletionRequest,
) -> Result<
crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
CompletionError,
> {
CompletionModel::stream(self, request).await
}
}
/// Builds the Gemini `generateContent` request body from a provider-agnostic
/// [`CompletionRequest`].
///
/// - The normalized documents message (if any) is placed before the chat history.
/// - `System` messages found in the history are removed from `contents` and appended, after
///   the preamble, to `system_instruction`; empty strings are skipped.
/// - `additional_params.tools` (raw JSON tool objects) are appended after the converted
///   function tools; `additional_params.generationConfig` seeds the generation config and all
///   other keys are forwarded through `additional_params`.
/// - `output_schema` switches the response to `application/json` with that schema.
/// - `temperature` / `max_tokens` override the matching generation-config fields, creating a
///   generation config if none was supplied.
///
/// # Errors
/// Returns [`CompletionError`] when `additional_params` cannot be parsed, a tool or tool
/// choice cannot be converted, or a history message cannot be converted to [`Content`].
pub(crate) fn create_request_body(
    completion_request: CompletionRequest,
) -> Result<GenerateContentRequest, CompletionError> {
    let documents_message = completion_request.normalized_documents();
    let CompletionRequest {
        model: _,
        preamble,
        chat_history,
        documents: _,
        tools: function_tools,
        temperature,
        max_tokens,
        tool_choice,
        additional_params,
        output_schema,
    } = completion_request;

    let mut full_history = Vec::new();
    if let Some(msg) = documents_message {
        full_history.push(msg);
    }
    full_history.extend(chat_history);
    let (history_system, full_history) = split_system_messages_from_history(full_history);

    let mut additional_params_payload =
        additional_params.unwrap_or_else(|| Value::Object(Map::new()));
    let mut additional_tools =
        extract_tools_from_additional_params(&mut additional_params_payload)?;
    let AdditionalParameters {
        mut generation_config,
        additional_params,
    } = serde_json::from_value::<AdditionalParameters>(additional_params_payload)?;

    if let Some(schema) = output_schema {
        let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
        cfg.response_mime_type = Some("application/json".to_string());
        cfg.response_json_schema = Some(schema.to_value());
    }

    // Request-level sampling settings must reach the API even when the caller did not
    // pass a `generationConfig` through `additional_params`. Previously they were only
    // applied to an already-present config and silently dropped otherwise.
    // NOTE(review): `GenerationConfig::default()` is defined outside this view; if it
    // pre-populates other fields, those are now also sent in that case — confirm.
    if temperature.is_some() || max_tokens.is_some() {
        let cfg = generation_config.get_or_insert_with(GenerationConfig::default);
        if let Some(temp) = temperature {
            cfg.temperature = Some(temp);
        }
        if let Some(max_tokens) = max_tokens {
            cfg.max_output_tokens = Some(max_tokens);
        }
    }

    // The preamble comes first, followed by any system messages lifted out of the history.
    let mut system_parts: Vec<Part> = Vec::new();
    if let Some(preamble) = preamble.filter(|preamble| !preamble.is_empty()) {
        system_parts.push(preamble.into());
    }
    for content in history_system {
        if !content.is_empty() {
            system_parts.push(content.into());
        }
    }
    let system_instruction = (!system_parts.is_empty()).then(|| Content {
        parts: system_parts,
        role: Some(Role::Model),
    });

    let mut tools = if function_tools.is_empty() {
        Vec::new()
    } else {
        vec![serde_json::to_value(Tool::try_from(function_tools)?)?]
    };
    tools.append(&mut additional_tools);
    let tools = if tools.is_empty() { None } else { Some(tools) };

    let tool_config = if let Some(cfg) = tool_choice {
        Some(ToolConfig {
            function_calling_config: Some(FunctionCallingMode::try_from(cfg)?),
        })
    } else {
        None
    };

    let request = GenerateContentRequest {
        contents: full_history
            .into_iter()
            .map(|msg| {
                msg.try_into()
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
            })
            .collect::<Result<Vec<_>, _>>()?,
        generation_config,
        safety_settings: None,
        tools,
        tool_config,
        system_instruction,
        additional_params,
    };
    Ok(request)
}
/// Separates `System` messages from the rest of the conversation.
///
/// Returns `(system_texts, conversation)`. Both keep their original relative order, and
/// `conversation` contains every non-system message unchanged.
fn split_system_messages_from_history(
    history: Vec<completion::Message>,
) -> (Vec<String>, Vec<completion::Message>) {
    let mut system_texts: Vec<String> = Vec::new();
    let conversation: Vec<completion::Message> = history
        .into_iter()
        .filter_map(|message| match message {
            completion::Message::System { content } => {
                system_texts.push(content);
                None
            }
            other => Some(other),
        })
        .collect();
    (system_texts, conversation)
}
/// Removes the `tools` key from `additional_params` (when it is a JSON object) and returns
/// its entries as raw JSON tool definitions.
///
/// Returns an empty list when `additional_params` is not an object or has no `tools` key.
///
/// # Errors
/// [`CompletionError::RequestError`] when `tools` is present but is not a JSON array.
fn extract_tools_from_additional_params(
    additional_params: &mut Value,
) -> Result<Vec<Value>, CompletionError> {
    let raw_tools = additional_params
        .as_object_mut()
        .and_then(|map| map.remove("tools"));

    let Some(raw_tools) = raw_tools else {
        return Ok(Vec::new());
    };

    serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
        CompletionError::RequestError(
            format!("Invalid Gemini `additional_params.tools` payload: {err}").into(),
        )
    })
}
/// Picks the model id for a request: the request's own `model` when present, otherwise
/// `default_model`.
pub(crate) fn resolve_request_model(
    default_model: &str,
    completion_request: &CompletionRequest,
) -> String {
    match completion_request.model.as_deref() {
        Some(requested) => requested.to_owned(),
        None => default_model.to_owned(),
    }
}
/// Relative path of the non-streaming `generateContent` method for `model`.
pub(crate) fn completion_endpoint(model: &str) -> String {
    model_method_path(model, "generateContent")
}

/// Relative path of the `streamGenerateContent` method for `model`.
pub(crate) fn streaming_endpoint(model: &str) -> String {
    model_method_path(model, "streamGenerateContent")
}

/// Formats `/v1beta/models/{model}:{method}`; `model` is inserted verbatim.
fn model_method_path(model: &str, method: &str) -> String {
    format!("/v1beta/models/{model}:{method}")
}
impl TryFrom<completion::ToolDefinition> for Tool {
    type Error = CompletionError;

    /// Wraps a single tool definition as a Gemini [`Tool`] with one function declaration
    /// and no code execution. Schema-conversion errors are propagated unchanged via `?`.
    fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
        let declaration = FunctionDeclaration {
            parameters: tool_parameters_to_schema(tool.parameters)?,
            name: tool.name,
            description: tool.description,
        };
        Ok(Self {
            function_declarations: vec![declaration],
            code_execution: None,
        })
    }
}
impl TryFrom<Vec<completion::ToolDefinition>> for Tool {
type Error = CompletionError;
fn try_from(tools: Vec<completion::ToolDefinition>) -> Result<Self, Self::Error> {
let mut function_declarations = Vec::new();
for tool in tools {
let parameters = tool_parameters_to_schema(tool.parameters).map_err(|e| {
CompletionError::ProviderError(format!(
"Tool '{}' could not be converted to a schema: {:?}",
tool.name, e,
))
})?;
function_declarations.push(FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters,
});
}
Ok(Self {
function_declarations,
code_execution: None,
})
}
}
/// Maps tool-call-related finish reasons to a [`CompletionError::ResponseError`].
///
/// Returns `Some` only for `MalformedFunctionCall`, `UnexpectedToolCall`,
/// `MissingThoughtSignature`, `TooManyToolCalls` and `MalformedResponse`. Every other reason
/// yields `None`. The error text includes `finish_message`, or a placeholder when it is absent.
pub(crate) fn function_call_finish_reason_error(
    reason: &FinishReason,
    finish_message: Option<&str>,
) -> Option<CompletionError> {
    let is_tool_call_failure = matches!(
        reason,
        FinishReason::MalformedFunctionCall
            | FinishReason::UnexpectedToolCall
            | FinishReason::MissingThoughtSignature
            | FinishReason::TooManyToolCalls
            | FinishReason::MalformedResponse
    );
    if !is_tool_call_failure {
        return None;
    }

    let message = finish_message.unwrap_or("no finish message provided");
    Some(CompletionError::ResponseError(format!(
        "Gemini stopped with finish_reason={reason:?}: {message}"
    )))
}
impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
type Error = CompletionError;
fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
let candidate = response.candidates.first().ok_or_else(|| {
CompletionError::ResponseError("No response candidates in response".into())
})?;
if let Some(reason) = candidate.finish_reason.as_ref()
&& let Some(err) =
function_call_finish_reason_error(reason, candidate.finish_message.as_deref())
{
return Err(err);
}
let content = candidate
.content
.as_ref()
.ok_or_else(|| {
let reason = candidate
.finish_reason
.as_ref()
.map(|r| format!("finish_reason={r:?}"))
.unwrap_or_else(|| "finish_reason=<unknown>".to_string());
let message = candidate
.finish_message
.as_deref()
.unwrap_or("no finish message provided");
CompletionError::ResponseError(format!(
"Gemini candidate missing content ({reason}, finish_message={message})"
))
})?
.parts
.iter()
.map(
|Part {
thought,
thought_signature,
part,
..
}| {
Ok(match part {
PartKind::Text(text) => {
if let Some(thought) = thought
&& *thought
{
completion::AssistantContent::Reasoning(
Reasoning::new_with_signature(text, thought_signature.clone()),
)
} else {
completion::AssistantContent::text(text)
}
}
PartKind::InlineData(inline_data) => {
let mime_type =
message::MediaType::from_mime_type(&inline_data.mime_type);
match mime_type {
Some(message::MediaType::Image(media_type)) => {
message::AssistantContent::image_base64(
&inline_data.data,
Some(media_type),
Some(message::ImageDetail::default()),
)
}
_ => {
return Err(CompletionError::ResponseError(format!(
"Unsupported media type {mime_type:?}"
)));
}
}
}
PartKind::FunctionCall(function_call) => {
completion::AssistantContent::ToolCall(
message::ToolCall::new(
function_call.name.clone(),
message::ToolFunction::new(
function_call.name.clone(),
function_call.args.clone(),
),
)
.with_signature(thought_signature.clone()),
)
}
_ => {
return Err(CompletionError::ResponseError(
"Response did not contain a message or tool call".into(),
));
}
})
},
)
.collect::<Result<Vec<_>, _>>()?;
let choice = OneOrMany::many(content).map_err(|_| {
CompletionError::ResponseError(
"Response contained no message or tool call (empty)".to_owned(),
)
})?;
let usage = response
.usage_metadata
.as_ref()
.and_then(GetTokenUsage::token_usage)
.unwrap_or_default();
Ok(completion::CompletionResponse {
choice,
usage,
raw_response: response,
message_id: None,
})
}
}
pub mod gemini_api_types {
use crate::telemetry::ProviderResponseExt;
use std::{collections::HashMap, convert::Infallible, str::FromStr};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use crate::completion::GetTokenUsage;
use crate::message::{DocumentSourceKind, ImageMediaType, MessageError, MimeType};
use crate::{
completion::CompletionError,
message::{self},
providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
};
/// Gemini-specific extras parsed from `CompletionRequest::additional_params`.
///
/// The `generationConfig` key is parsed into a typed [`GenerationConfig`]. Every other
/// top-level key is captured by the flattened `additional_params` field and written back
/// at the top level of the request when serialized.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct AdditionalParameters {
/// Typed generation settings (`generationConfig` on the wire).
pub generation_config: Option<GenerationConfig>,
/// Remaining top-level keys; omitted entirely when `None`.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}
impl AdditionalParameters {
    /// Returns `self` with its generation config replaced by `cfg`.
    pub fn with_config(self, cfg: GenerationConfig) -> Self {
        Self {
            generation_config: Some(cfg),
            ..self
        }
    }

    /// Returns `self` with its free-form extra parameters replaced by `params`.
    pub fn with_params(self, params: serde_json::Value) -> Self {
        Self {
            additional_params: Some(params),
            ..self
        }
    }
}
/// Body of a `generateContent` response (camelCase on the wire).
///
/// `response_id` is required: it is neither an `Option` nor `#[serde(default)]`, so a
/// payload without `responseId` fails to deserialize.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
pub response_id: String,
/// Candidate completions. Only the first one is used when converting to a
/// `CompletionResponse`.
pub candidates: Vec<ContentCandidate>,
pub prompt_feedback: Option<PromptFeedback>,
/// Token accounting, if the API returned it.
pub usage_metadata: Option<UsageMetadata>,
/// Model version that served the request; recorded as the response model in telemetry.
pub model_version: Option<String>,
}
impl ProviderResponseExt for GenerateContentResponse {
    type OutputMessage = ContentCandidate;
    type Usage = UsageMetadata;

    fn get_response_id(&self) -> Option<String> {
        Some(self.response_id.clone())
    }

    fn get_response_model_name(&self) -> Option<String> {
        self.model_version.clone()
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.candidates.clone()
    }

    /// Collects the visible text produced by the model.
    ///
    /// Only candidates whose content role is `Model` are considered. Within each candidate,
    /// text parts are joined with `"\n"`. Parts flagged as thoughts (`thought == Some(true)`)
    /// are excluded: they are reasoning, and the completion conversion surfaces them as
    /// `Reasoning` rather than text. Candidates that contribute no text are skipped, so they
    /// do not add empty lines. The per-candidate texts are joined with `"\n"`; returns `None`
    /// when nothing remains.
    fn get_text_response(&self) -> Option<String> {
        let text = self
            .candidates
            .iter()
            .filter_map(|candidate| candidate.content.as_ref())
            .filter(|content| content.role == Some(Role::Model))
            .map(|content| {
                content
                    .parts
                    .iter()
                    .filter(|part| part.thought != Some(true))
                    .filter_map(|part| match &part.part {
                        PartKind::Text(text) => Some(text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<&str>>()
                    .join("\n")
            })
            .filter(|candidate_text| !candidate_text.is_empty())
            .collect::<Vec<String>>()
            .join("\n");

        if text.is_empty() { None } else { Some(text) }
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        self.usage_metadata.clone()
    }
}
/// One candidate completion from a Gemini response (camelCase on the wire).
///
/// Only `content` is omitted when `None` during serialization; the other `Option` fields
/// serialize as `null`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContentCandidate {
/// Generated content; may be absent (e.g. when generation stopped early — the completion
/// conversion reports the finish reason in that case).
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Content>,
pub finish_reason: Option<FinishReason>,
pub safety_ratings: Option<Vec<SafetyRating>>,
pub citation_metadata: Option<CitationMetadata>,
pub token_count: Option<i32>,
pub avg_logprobs: Option<f64>,
pub logprobs_result: Option<LogprobsResult>,
pub index: Option<i32>,
/// Free-form explanation accompanying `finish_reason`, when provided.
pub finish_message: Option<String>,
}
/// A role-tagged list of parts, used for both request turns and response content.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Content {
/// Defaults to an empty list when the key is missing.
#[serde(default)]
pub parts: Vec<Part>,
pub role: Option<Role>,
}
impl TryFrom<message::Message> for Content {
type Error = message::MessageError;
fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
Ok(match msg {
message::Message::System { content } => Content {
parts: vec![content.into()],
role: Some(Role::User),
},
message::Message::User { content } => Content {
parts: content
.into_iter()
.map(|c| c.try_into())
.collect::<Result<Vec<_>, _>>()?,
role: Some(Role::User),
},
message::Message::Assistant { content, .. } => Content {
role: Some(Role::Model),
parts: content
.into_iter()
.map(|content| content.try_into())
.collect::<Result<Vec<_>, _>>()?,
},
})
}
}
/// Author of a content turn; serialized in lowercase (`"user"` / `"model"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Model,
}
/// One element of a [`Content`]'s `parts` list.
///
/// The payload (`part`) is flattened, so a text part serializes as `{"text": "..."}` next to
/// the optional `thought` / `thoughtSignature` flags and any flattened extra keys.
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Part {
/// `Some(true)` marks the part as model reasoning rather than visible output.
#[serde(skip_serializing_if = "Option::is_none")]
pub thought: Option<bool>,
/// Opaque signature attached to thoughts/function calls; echoed back on later turns.
#[serde(skip_serializing_if = "Option::is_none")]
pub thought_signature: Option<String>,
#[serde(flatten)]
pub part: PartKind,
/// Extra keys flattened into the part object; omitted when `None`.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<Value>,
}
/// Payload of a [`Part`]. Externally tagged with camelCase keys (`text`, `inlineData`,
/// `functionCall`, `functionResponse`, `fileData`, `executableCode`, `codeExecutionResult`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum PartKind {
Text(String),
InlineData(Blob),
FunctionCall(FunctionCall),
FunctionResponse(FunctionResponse),
FileData(FileData),
ExecutableCode(ExecutableCode),
CodeExecutionResult(CodeExecutionResult),
}
impl Default for PartKind {
fn default() -> Self {
Self::Text(String::new())
}
}
impl From<String> for Part {
fn from(text: String) -> Self {
Self {
thought: Some(false),
thought_signature: None,
part: PartKind::Text(text),
additional_params: None,
}
}
}
impl From<&str> for Part {
fn from(text: &str) -> Self {
Self::from(text.to_string())
}
}
impl FromStr for Part {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(s.into())
}
}
impl TryFrom<(ImageMediaType, DocumentSourceKind)> for PartKind {
    type Error = message::MessageError;

    /// Turns an image source into a Gemini payload.
    ///
    /// URLs become `fileData`. Base64 and plain-string data become `inlineData`, with the
    /// string treated as already-encoded data. Raw bytes, provider file ids and unknown
    /// sources are rejected.
    fn try_from(
        (mime_type, doc_src): (ImageMediaType, DocumentSourceKind),
    ) -> Result<Self, Self::Error> {
        let mime_type = mime_type.to_mime_type().to_string();
        match doc_src {
            DocumentSourceKind::Url(file_uri) => Ok(PartKind::FileData(FileData {
                mime_type: Some(mime_type),
                file_uri,
            })),
            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
                Ok(PartKind::InlineData(Blob { mime_type, data }))
            }
            DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                "Raw files not supported, encode as base64 first".into(),
            )),
            DocumentSourceKind::FileId(_) => Err(message::MessageError::ConversionError(
                "Provider file IDs are not supported for Gemini image inputs".into(),
            )),
            DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                "Can't convert an unknown document source".to_string(),
            )),
        }
    }
}
impl TryFrom<message::UserContent> for Part {
type Error = message::MessageError;
fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
match content {
message::UserContent::Text(message::Text { text, .. }) => Ok(Part {
thought: Some(false),
thought_signature: None,
part: PartKind::Text(text),
additional_params: None,
}),
message::UserContent::ToolResult(message::ToolResult { id, content, .. }) => {
let mut response_json: Option<serde_json::Value> = None;
let mut parts: Vec<FunctionResponsePart> = Vec::new();
for item in content.iter() {
match item {
message::ToolResultContent::Text(text) => {
let result: serde_json::Value =
serde_json::from_str(&text.text).unwrap_or_else(|error| {
tracing::trace!(
?error,
"Tool result is not a valid JSON, treat it as normal string"
);
json!(&text.text)
});
response_json = Some(match response_json {
Some(mut existing) => {
if let serde_json::Value::Object(ref mut map) = existing {
map.insert("text".to_string(), result);
}
existing
}
None => json!({ "result": result }),
});
}
message::ToolResultContent::Image(image) => {
let part = match &image.data {
DocumentSourceKind::Base64(b64) => {
let mime_type = image
.media_type
.as_ref()
.ok_or(message::MessageError::ConversionError(
"Image media type is required for Gemini tool results".to_string(),
))?
.to_mime_type();
FunctionResponsePart {
inline_data: Some(FunctionResponseInlineData {
mime_type: mime_type.to_string(),
data: b64.clone(),
display_name: None,
}),
file_data: None,
}
}
DocumentSourceKind::Url(url) => {
let mime_type = image
.media_type
.as_ref()
.map(|mt| mt.to_mime_type().to_string());
FunctionResponsePart {
inline_data: None,
file_data: Some(FileData {
mime_type,
file_uri: url.clone(),
}),
}
}
_ => {
return Err(message::MessageError::ConversionError(
"Unsupported image source kind for tool results"
.to_string(),
));
}
};
parts.push(part);
}
}
}
Ok(Part {
thought: Some(false),
thought_signature: None,
part: PartKind::FunctionResponse(FunctionResponse {
name: id,
response: response_json,
parts: if parts.is_empty() { None } else { Some(parts) },
}),
additional_params: None,
})
}
message::UserContent::Image(message::Image {
data, media_type, ..
}) => match media_type {
Some(media_type) => match media_type {
message::ImageMediaType::JPEG
| message::ImageMediaType::PNG
| message::ImageMediaType::WEBP
| message::ImageMediaType::HEIC
| message::ImageMediaType::HEIF => {
let part = PartKind::try_from((media_type, data))?;
Ok(Part {
thought: Some(false),
thought_signature: None,
part,
additional_params: None,
})
}
_ => Err(message::MessageError::ConversionError(format!(
"Unsupported image media type {media_type:?}"
))),
},
None => Err(message::MessageError::ConversionError(
"Media type for image is required for Gemini".to_string(),
)),
},
message::UserContent::Document(message::Document {
data, media_type, ..
}) => {
let Some(media_type) = media_type else {
return Err(MessageError::ConversionError(
"A mime type is required for document inputs to Gemini".to_string(),
));
};
if matches!(
media_type,
message::DocumentMediaType::TXT
| message::DocumentMediaType::RTF
| message::DocumentMediaType::HTML
| message::DocumentMediaType::CSS
| message::DocumentMediaType::MARKDOWN
| message::DocumentMediaType::CSV
| message::DocumentMediaType::XML
| message::DocumentMediaType::Javascript
| message::DocumentMediaType::Python
) {
use base64::Engine;
let part = match data {
DocumentSourceKind::String(text) => PartKind::Text(text),
DocumentSourceKind::Base64(data) => {
let text = String::from_utf8(
base64::engine::general_purpose::STANDARD
.decode(&data)
.map_err(|e| {
MessageError::ConversionError(format!(
"Failed to decode base64: {e}"
))
})?,
)
.map_err(|e| {
MessageError::ConversionError(format!(
"Invalid UTF-8 in document: {e}"
))
})?;
PartKind::Text(text)
}
DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
mime_type: Some(media_type.to_mime_type().to_string()),
file_uri,
}),
DocumentSourceKind::Raw(_) => {
return Err(MessageError::ConversionError(
"Raw files not supported, encode as base64 first".to_string(),
));
}
DocumentSourceKind::FileId(_) => {
return Err(MessageError::ConversionError(
"Provider file IDs are not supported for Gemini documents"
.to_string(),
));
}
DocumentSourceKind::Unknown => {
return Err(MessageError::ConversionError(
"Document has no body".to_string(),
));
}
};
Ok(Part {
thought: Some(false),
part,
..Default::default()
})
} else if !media_type.is_code() {
let mime_type = media_type.to_mime_type().to_string();
let part = match data {
DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
mime_type: Some(mime_type),
file_uri,
}),
DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
PartKind::InlineData(Blob { mime_type, data })
}
DocumentSourceKind::Raw(_) => {
return Err(message::MessageError::ConversionError(
"Raw files not supported, encode as base64 first".into(),
));
}
_ => {
return Err(message::MessageError::ConversionError(
"Document has no body".to_string(),
));
}
};
Ok(Part {
thought: Some(false),
part,
..Default::default()
})
} else {
Err(message::MessageError::ConversionError(format!(
"Unsupported document media type {media_type:?}"
)))
}
}
message::UserContent::Audio(message::Audio {
data, media_type, ..
}) => {
let Some(media_type) = media_type else {
return Err(MessageError::ConversionError(
"A mime type is required for audio inputs to Gemini".to_string(),
));
};
let mime_type = media_type.to_mime_type().to_string();
let part = match data {
DocumentSourceKind::Base64(data) => {
PartKind::InlineData(Blob { data, mime_type })
}
DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData {
mime_type: Some(mime_type),
file_uri,
}),
DocumentSourceKind::String(_) => {
return Err(message::MessageError::ConversionError(
"Strings cannot be used as audio files!".into(),
));
}
DocumentSourceKind::Raw(_) => {
return Err(message::MessageError::ConversionError(
"Raw files not supported, encode as base64 first".into(),
));
}
DocumentSourceKind::FileId(_) => {
return Err(message::MessageError::ConversionError(
"Provider file IDs are not supported for Gemini audio inputs"
.into(),
));
}
DocumentSourceKind::Unknown => {
return Err(message::MessageError::ConversionError(
"Content has no body".to_string(),
));
}
};
Ok(Part {
thought: Some(false),
part,
..Default::default()
})
}
message::UserContent::Video(message::Video {
data,
media_type,
additional_params,
..
}) => {
let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string());
let part = match data {
DocumentSourceKind::Url(file_uri) => {
if file_uri.starts_with("https://www.youtube.com") {
PartKind::FileData(FileData {
mime_type,
file_uri,
})
} else {
if mime_type.is_none() {
return Err(MessageError::ConversionError(
"A mime type is required for non-Youtube video file inputs to Gemini"
.to_string(),
));
}
PartKind::FileData(FileData {
mime_type,
file_uri,
})
}
}
DocumentSourceKind::Base64(data) => {
let Some(mime_type) = mime_type else {
return Err(MessageError::ConversionError(
"A media type is expected for base64 encoded strings"
.to_string(),
));
};
PartKind::InlineData(Blob { mime_type, data })
}
DocumentSourceKind::String(_) => {
return Err(message::MessageError::ConversionError(
"Strings cannot be used as audio files!".into(),
));
}
DocumentSourceKind::Raw(_) => {
return Err(message::MessageError::ConversionError(
"Raw file data not supported, encode as base64 first".into(),
));
}
DocumentSourceKind::FileId(_) => {
return Err(message::MessageError::ConversionError(
"Provider file IDs are not supported for Gemini video inputs"
.into(),
));
}
DocumentSourceKind::Unknown => {
return Err(message::MessageError::ConversionError(
"Media type for video is required for Gemini".to_string(),
));
}
};
Ok(Part {
thought: Some(false),
thought_signature: None,
part,
additional_params,
})
}
}
}
}
impl TryFrom<message::AssistantContent> for Part {
type Error = message::MessageError;
fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
match content {
message::AssistantContent::Text(message::Text { text, .. }) => Ok(text.into()),
message::AssistantContent::Image(message::Image {
data, media_type, ..
}) => match media_type {
Some(media_type) => match media_type {
message::ImageMediaType::JPEG
| message::ImageMediaType::PNG
| message::ImageMediaType::WEBP
| message::ImageMediaType::HEIC
| message::ImageMediaType::HEIF => {
let part = PartKind::try_from((media_type, data))?;
Ok(Part {
thought: Some(false),
thought_signature: None,
part,
additional_params: None,
})
}
_ => Err(message::MessageError::ConversionError(format!(
"Unsupported image media type {media_type:?}"
))),
},
None => Err(message::MessageError::ConversionError(
"Media type for image is required for Gemini".to_string(),
)),
},
message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
message::AssistantContent::Reasoning(reasoning) => Ok(Part {
thought: Some(true),
thought_signature: reasoning.first_signature().map(str::to_owned),
part: PartKind::Text(reasoning.display_text()),
additional_params: None,
}),
}
}
}
impl From<message::ToolCall> for Part {
fn from(tool_call: message::ToolCall) -> Self {
Self {
thought: Some(false),
thought_signature: tool_call.signature,
part: PartKind::FunctionCall(FunctionCall {
name: tool_call.function.name,
args: tool_call.function.arguments,
}),
additional_params: None,
}
}
}
/// Inline binary payload (`inlineData`): a mime type plus the data string.
/// In this file `data` is filled from base64 sources, or from plain strings for images and
/// non-text documents.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
pub mime_type: String,
pub data: String,
}
/// A function invocation requested by the model (`functionCall`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionCall {
pub name: String,
/// Arguments as arbitrary JSON.
pub args: serde_json::Value,
}
impl From<message::ToolCall> for FunctionCall {
fn from(tool_call: message::ToolCall) -> Self {
Self {
name: tool_call.function.name,
args: tool_call.function.arguments,
}
}
}
/// Result of a function call sent back to the model (`functionResponse`).
/// This file fills `name` with the rig tool-call id, which the response conversion sets
/// to the function name.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FunctionResponse {
pub name: String,
/// JSON result object; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<serde_json::Value>,
/// Media attached to the result (e.g. images); omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub parts: Option<Vec<FunctionResponsePart>>,
}
/// One media attachment of a [`FunctionResponse`]; this file sets exactly one of the two.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
#[serde(skip_serializing_if = "Option::is_none")]
pub inline_data: Option<FunctionResponseInlineData>,
#[serde(skip_serializing_if = "Option::is_none")]
pub file_data: Option<FileData>,
}
/// Inline media inside a function response, with an optional display name.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponseInlineData {
pub mime_type: String,
pub data: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub display_name: Option<String>,
}
/// Reference to remote or uploaded content (`fileData`).
/// `mime_type` has no skip attribute, so `None` serializes as `"mimeType": null`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FileData {
pub mime_type: Option<String>,
pub file_uri: String,
}
/// Safety assessment for one harm category.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct SafetyRating {
pub category: HarmCategory,
pub probability: HarmProbability,
}
/// Likelihood bucket, e.g. `"NEGLIGIBLE"` on the wire.
/// There is no `#[serde(other)]` fallback, so an unlisted value fails deserialization of the
/// enclosing response.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
HarmProbabilityUnspecified,
Negligible,
Low,
Medium,
High,
}
/// Harm category, e.g. `"HARM_CATEGORY_HATE_SPEECH"` on the wire.
/// NOTE(review): no catch-all variant — confirm this list is current with the Gemini API,
/// since any unlisted category makes the whole response fail to deserialize.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmCategory {
HarmCategoryUnspecified,
HarmCategoryDerogatory,
HarmCategoryToxicity,
HarmCategoryViolence,
HarmCategorySexually,
HarmCategoryMedical,
HarmCategoryDangerous,
HarmCategoryHarassment,
HarmCategoryHateSpeech,
HarmCategorySexuallyExplicit,
HarmCategoryDangerousContent,
HarmCategoryCivicIntegrity,
}
/// Token accounting returned with a response (camelCase on the wire).
///
/// `prompt_token_count` defaults to 0 when absent. `total_token_count` has no default and
/// is required for deserialization.
#[derive(Debug, Deserialize, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
#[serde(default)]
pub prompt_token_count: i32,
/// Reported to rig as `cached_input_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_content_token_count: Option<i32>,
/// Reported to rig as `output_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
pub candidates_token_count: Option<i32>,
pub total_token_count: i32,
/// Reported to rig as `reasoning_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
pub thoughts_token_count: Option<i32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_use_prompt_token_count: Option<i32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub traffic_type: Option<TrafficType>,
}
/// Token count for a single input/output modality.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ModalityTokenCount {
pub modality: Modality,
pub token_count: i32,
}
/// Content modality, e.g. `"TEXT"` on the wire; unlisted values fail deserialization.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Modality {
ModalityUnspecified,
Text,
Image,
Video,
Audio,
Document,
}
/// Billing traffic class, e.g. `"ON_DEMAND"` on the wire.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TrafficType {
TrafficTypeUnspecified,
OnDemand,
ProvisionedThroughput,
}
impl std::fmt::Display for UsageMetadata {
    /// Multi-line human-readable summary. Absent optional counts print as `n/a`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let or_na = |count: Option<i32>| -> String {
            count.map_or_else(|| "n/a".to_string(), |value| value.to_string())
        };
        let cached = or_na(self.cached_content_token_count);
        let candidates = or_na(self.candidates_token_count);
        write!(
            f,
            "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
            self.prompt_token_count, cached, candidates, self.total_token_count
        )
    }
}
impl GetTokenUsage for UsageMetadata {
    /// Maps Gemini's usage metadata onto the crate's provider-agnostic `Usage` record.
    ///
    /// - `input_tokens` comes from `prompt_token_count`.
    /// - `output_tokens` comes from `candidates_token_count`.
    /// - `cached_input_tokens` comes from `cached_content_token_count`.
    /// - `reasoning_tokens` comes from `thoughts_token_count`.
    /// - `tool_use_prompt_tokens` comes from `tool_use_prompt_token_count`.
    /// - `total_tokens` comes from `total_token_count`.
    ///
    /// Optional counters the API omitted count as 0. The wire type is `i32`. A negative
    /// value is clamped to 0 instead of going through an `as u64` cast, which would
    /// sign-extend it into a huge number and corrupt any aggregated usage.
    /// Always returns `Some`.
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let to_count = |n: i32| u64::try_from(n).unwrap_or(0);
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = to_count(self.prompt_token_count);
        usage.output_tokens = to_count(self.candidates_token_count.unwrap_or_default());
        usage.cached_input_tokens = to_count(self.cached_content_token_count.unwrap_or_default());
        usage.reasoning_tokens = to_count(self.thoughts_token_count.unwrap_or_default());
        usage.tool_use_prompt_tokens =
            to_count(self.tool_use_prompt_token_count.unwrap_or_default());
        usage.total_tokens = to_count(self.total_token_count);
        Some(usage)
    }
}
/// Feedback about the prompt itself, e.g. why it was blocked.
///
/// Both fields are optional; serde's derive treats a missing `Option` field as `None`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
/// Reason the prompt was blocked, if it was (`blockReason` on the wire).
pub block_reason: Option<BlockReason>,
/// Safety ratings for the prompt (`safetyRatings` on the wire).
pub safety_ratings: Option<Vec<SafetyRating>>,
}
/// Why a prompt was blocked. Serialized in SCREAMING_SNAKE_CASE.
///
/// NOTE(review): any value not listed here (the API may have added e.g. an image-safety
/// reason) makes deserialization of the whole response fail. Confirm the list against
/// the current API reference before relying on it.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockReason {
BlockReasonUnspecified,
Safety,
Other,
Blocklist,
ProhibitedContent,
}
/// Why a candidate stopped generating. Serialized in SCREAMING_SNAKE_CASE.
///
/// `MalformedFunctionCall`, `UnexpectedToolCall`, `MissingThoughtSignature`,
/// `TooManyToolCalls` and `MalformedResponse` are tool-protocol failures. The tests
/// show that converting a response carrying one of them into a completion response
/// yields a `CompletionError::ResponseError`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
FinishReasonUnspecified,
Stop,
MaxTokens,
Safety,
Recitation,
Language,
Other,
Blocklist,
ProhibitedContent,
Spii,
MalformedFunctionCall,
UnexpectedToolCall,
MissingThoughtSignature,
TooManyToolCalls,
MalformedResponse,
}
/// Source attributions attached to a candidate.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
/// Individual cited sources (`citationSources` on the wire; required when present).
pub citation_sources: Vec<CitationSource>,
}
/// A single cited source. Every field is optional and is omitted from serialized
/// output when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
/// URI of the cited source.
#[serde(skip_serializing_if = "Option::is_none")]
pub uri: Option<String>,
/// Start of the attributed span (`startIndex`); presumably an offset into the
/// generated text — TODO confirm units against the API reference.
#[serde(skip_serializing_if = "Option::is_none")]
pub start_index: Option<i32>,
/// End of the attributed span (`endIndex`); same caveat as `start_index`.
#[serde(skip_serializing_if = "Option::is_none")]
pub end_index: Option<i32>,
/// License string for the source, if reported.
#[serde(skip_serializing_if = "Option::is_none")]
pub license: Option<String>,
}
/// Log-probability details for a candidate.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LogprobsResult {
/// Per-step lists of top alternative tokens (`topCandidate` on the wire).
pub top_candidate: Vec<TopCandidate>,
/// The tokens actually chosen (`chosenCandidate` on the wire).
pub chosen_candidate: Vec<LogProbCandidate>,
}
/// The alternative tokens considered at one decoding step.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TopCandidate {
pub candidates: Vec<LogProbCandidate>,
}
/// A single token together with its log probability.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LogProbCandidate {
pub token: String,
/// Token id, kept as a string here (`tokenId` on the wire).
pub token_id: String,
/// Log probability of this token (`logProbability` on the wire).
pub log_probability: f64,
}
/// Sampling and output options sent as `generationConfig`.
///
/// Fields serialize in camelCase and are omitted when `None`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
/// Sequences that stop generation when produced.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
/// Requested MIME type of the response body (e.g. for structured output).
#[serde(skip_serializing_if = "Option::is_none")]
pub response_mime_type: Option<String>,
/// Structured-output schema in Gemini's own [`Schema`] dialect.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_schema: Option<Schema>,
/// Raw JSON value serialized under the literal key `_responseJsonSchema` (explicit
/// rename). NOTE(review): purpose versus `response_json_schema` is not visible here —
/// confirm which key the API expects before using it.
#[serde(
skip_serializing_if = "Option::is_none",
rename = "_responseJsonSchema"
)]
pub _response_json_schema: Option<Value>,
/// Raw JSON value serialized as `responseJsonSchema`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_json_schema: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub candidate_count: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
/// Whether to return log probabilities (`responseLogprobs`).
#[serde(skip_serializing_if = "Option::is_none")]
pub response_logprobs: Option<bool>,
/// Number of top log-probability alternatives requested.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<i32>,
/// Thinking/reasoning controls; see [`ThinkingConfig`].
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ThinkingConfig>,
/// Image generation options; see [`ImageConfig`].
#[serde(skip_serializing_if = "Option::is_none")]
pub image_config: Option<ImageConfig>,
}
impl Default for GenerationConfig {
fn default() -> Self {
Self {
temperature: Some(1.0),
max_output_tokens: Some(4096),
stop_sequences: None,
response_mime_type: None,
response_schema: None,
_response_json_schema: None,
response_json_schema: None,
candidate_count: None,
top_p: None,
top_k: None,
presence_penalty: None,
frequency_penalty: None,
response_logprobs: None,
logprobs: None,
thinking_config: None,
image_config: None,
}
}
}
/// Coarse thinking effort level. Serialized in snake_case (`"minimal"`, `"low"`,
/// `"medium"`, `"high"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingLevel {
Minimal,
Low,
Medium,
High,
}
/// Thinking/reasoning controls sent as `generationConfig.thinkingConfig`.
/// Unset fields are omitted.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
/// Token budget for thinking (`thinkingBudget`).
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_budget: Option<u32>,
/// Coarse effort level (`thinkingLevel`).
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_level: Option<ThinkingLevel>,
/// Whether thought parts should be included in the response (`includeThoughts`).
#[serde(skip_serializing_if = "Option::is_none")]
pub include_thoughts: Option<bool>,
}
/// Image output options sent as `generationConfig.imageConfig`. Values are passed
/// through as free-form strings; unset fields are omitted.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageConfig {
/// Requested aspect ratio (`aspectRatio`).
#[serde(skip_serializing_if = "Option::is_none")]
pub aspect_ratio: Option<String>,
/// Requested image size (`imageSize`).
#[serde(skip_serializing_if = "Option::is_none")]
pub image_size: Option<String>,
}
/// Gemini's OpenAPI-style schema subset, used for tool parameters and structured output.
///
/// It is usually built from a JSON Schema value via `TryFrom<Value>`. Unset fields are
/// omitted from serialized output.
///
/// NOTE(review): unlike most types in this module, this struct has no
/// `rename_all = "camelCase"`, so `max_items`/`min_items` serialize as snake_case keys.
/// This presumably relies on the API also accepting proto field names — confirm before
/// changing it, since doing so alters the wire format.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Schema {
/// Schema type name (`"object"`, `"array"`, `"string"`, ...), serialized as `type`.
pub r#type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// `Some(true)` when the value may be null; the converter never emits `Some(false)`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nullable: Option<bool>,
/// Allowed string values, serialized as `enum`.
#[serde(skip_serializing_if = "Option::is_none")]
pub r#enum: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_items: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_items: Option<i32>,
/// Per-property sub-schemas for object types.
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Schema>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
/// Element schema for array types (boxed because the type is recursive).
#[serde(skip_serializing_if = "Option::is_none")]
pub items: Option<Box<Schema>>,
}
pub fn tool_parameters_to_schema(parameters: Value) -> Result<Option<Schema>, CompletionError> {
if parameters.is_null() || parameters == json!({"type": "object", "properties": {}}) {
Ok(None)
} else {
parameters.try_into().map(Some)
}
}
pub fn flatten_schema(mut schema: Value) -> Result<Value, CompletionError> {
let defs = if let Some(obj) = schema.as_object() {
obj.get("$defs").or_else(|| obj.get("definitions")).cloned()
} else {
None
};
let Some(defs_value) = defs else {
return Ok(schema);
};
let Some(defs_obj) = defs_value.as_object() else {
return Err(CompletionError::ResponseError(
"$defs must be an object".into(),
));
};
resolve_refs(&mut schema, defs_obj)?;
if let Some(obj) = schema.as_object_mut() {
obj.remove("$defs");
obj.remove("definitions");
}
Ok(schema)
}
/// Replaces every `$ref` in `value` with a copy of the definition it names in `defs`.
///
/// Supported references are the fragment forms accepted by `parse_ref_path`
/// (`#/$defs/<name>` and `#/definitions/<name>`). Inlined definitions are resolved
/// recursively. Keys written next to a `$ref` (for example a field-level
/// `description`) are kept and override the same key in the definition, instead of
/// being silently dropped.
///
/// # Errors
/// Returns `CompletionError::ResponseError` in three cases:
/// - a reference is malformed;
/// - a reference names a missing definition;
/// - a reference is recursive (a definition that reaches itself again).
///
/// Gemini's `Schema` cannot express recursive types, and expanding one eagerly would
/// otherwise recurse without bound and overflow the stack.
fn resolve_refs(
    value: &mut Value,
    defs: &serde_json::Map<String, Value>,
) -> Result<(), CompletionError> {
    let mut expanding = Vec::new();
    resolve_refs_guarded(value, defs, &mut expanding)
}

/// Worker for `resolve_refs`. `expanding` holds the definition names currently being
/// inlined on the path from the root; seeing one of them again means a cycle.
fn resolve_refs_guarded(
    value: &mut Value,
    defs: &serde_json::Map<String, Value>,
    expanding: &mut Vec<String>,
) -> Result<(), CompletionError> {
    match value {
        Value::Object(obj) => {
            // Owned copy so `obj` can be mutated below.
            if let Some(ref_str) = obj.get("$ref").and_then(Value::as_str).map(str::to_owned) {
                let def_name = parse_ref_path(&ref_str)?;
                if expanding.contains(&def_name) {
                    return Err(CompletionError::ResponseError(format!(
                        "Recursive reference is not supported: {}",
                        ref_str
                    )));
                }
                let def = defs.get(&def_name).ok_or_else(|| {
                    CompletionError::ResponseError(format!("Reference not found: {}", ref_str))
                })?;
                let mut resolved = def.clone();

                expanding.push(def_name);
                let nested = resolve_refs_guarded(&mut resolved, defs, expanding);
                expanding.pop();
                nested?;

                // Keep annotations written alongside the `$ref`; they describe this use
                // site and therefore take precedence over the definition's own values.
                if let Value::Object(resolved_obj) = &mut resolved {
                    for (key, mut sibling) in std::mem::take(obj) {
                        if key == "$ref" {
                            continue;
                        }
                        resolve_refs_guarded(&mut sibling, defs, expanding)?;
                        resolved_obj.insert(key, sibling);
                    }
                }
                *value = resolved;
                return Ok(());
            }
            for child in obj.values_mut() {
                resolve_refs_guarded(child, defs, expanding)?;
            }
        }
        Value::Array(arr) => {
            for item in arr.iter_mut() {
                resolve_refs_guarded(item, defs, expanding)?;
            }
        }
        _ => {}
    }
    Ok(())
}
/// Extracts the definition name from a local reference such as `#/$defs/Person` or
/// `#/definitions/Person`.
///
/// The name is a JSON Pointer reference token, so it is unescaped per RFC 6901: `~1`
/// becomes `/` and `~0` becomes `~`. `~1` is decoded first, so `~01` correctly yields
/// `~1`. Previously, names containing these escapes were looked up verbatim and failed
/// with "Reference not found".
///
/// # Errors
/// Returns `CompletionError::ResponseError` for non-fragment references (anything not
/// starting with `#`), or for fragments outside `/$defs/` and `/definitions/`.
fn parse_ref_path(ref_str: &str) -> Result<String, CompletionError> {
    let Some(fragment) = ref_str.strip_prefix('#') else {
        return Err(CompletionError::ResponseError(format!(
            "Only fragment references (#/...) are supported: {}",
            ref_str
        )));
    };
    let encoded_name = fragment
        .strip_prefix("/$defs/")
        .or_else(|| fragment.strip_prefix("/definitions/"))
        .ok_or_else(|| {
            CompletionError::ResponseError(format!(
                "Unsupported reference format: {}",
                ref_str
            ))
        })?;
    Ok(encoded_name.replace("~1", "/").replace("~0", "~"))
}
fn extract_type(type_value: &Value) -> Option<String> {
if let Some(t) = type_value.as_str() {
return Some(t.to_string());
}
type_value.as_array().and_then(|arr| {
arr.iter()
.filter_map(|v| v.as_str())
.find(|t| *t != "null")
.or_else(|| arr.iter().find_map(|v| v.as_str()))
.map(str::to_owned)
})
}
/// True when the schema's (resolved) `type` is exactly `"null"`, e.g. the `null` arm
/// of an `anyOf` produced for an optional value.
fn schema_is_null(obj: &serde_json::Map<String, Value>) -> bool {
    matches!(obj.get("type").and_then(extract_type), Some(ty) if ty == "null")
}
fn schema_is_nullable(obj: &serde_json::Map<String, Value>) -> bool {
obj.get("nullable")
.and_then(|v| v.as_bool())
.unwrap_or(false)
|| obj
.get("type")
.and_then(|v| v.as_array())
.is_some_and(|arr| arr.iter().any(|v| v.as_str() == Some("null")))
|| ["anyOf", "oneOf", "allOf"].iter().any(|key| {
obj.get(*key).and_then(|v| v.as_array()).is_some_and(|arr| {
arr.iter()
.filter_map(|schema| schema.as_object())
.any(schema_is_null)
})
})
}
/// Infers a type name from a composition array (`anyOf` / `oneOf` / `allOf`).
///
/// Branches are scanned in order, skipping non-objects and `null` branches. The first
/// branch that yields a type wins. A branch's type is its explicit `type`, else
/// `"object"` if it has `properties`, else `"string"` if it has `enum`.
fn extract_type_from_composition(composition: &Value) -> Option<String> {
    let branches = composition.as_array()?;
    for branch in branches {
        let Some(branch_obj) = branch.as_object() else {
            continue;
        };
        if schema_is_null(branch_obj) {
            continue;
        }
        let inferred = match branch_obj.get("type").and_then(extract_type) {
            Some(explicit) => Some(explicit),
            None if branch_obj.contains_key("properties") => Some("object".to_string()),
            None if branch_obj.contains_key("enum") => Some("string".to_string()),
            None => None,
        };
        if inferred.is_some() {
            return inferred;
        }
    }
    None
}
fn extract_schema_from_composition(
composition: &Value,
) -> Option<serde_json::Map<String, Value>> {
composition.as_array().and_then(|arr| {
arr.iter().find_map(|schema| {
let obj = schema.as_object()?;
if schema_is_null(obj) {
None
} else {
Some(obj.clone())
}
})
})
}
/// Picks the first non-null composition branch of `obj`, trying `anyOf`, then `oneOf`,
/// then `allOf`.
fn extract_schema_from_composition_obj(
    obj: &serde_json::Map<String, Value>,
) -> Option<serde_json::Map<String, Value>> {
    ["anyOf", "oneOf", "allOf"]
        .into_iter()
        .find_map(|key| obj.get(key).and_then(extract_schema_from_composition))
}
fn infer_type(obj: &serde_json::Map<String, Value>) -> String {
if let Some(type_val) = obj.get("type")
&& let Some(type_str) = extract_type(type_val)
{
return type_str;
}
if let Some(any_of) = obj.get("anyOf")
&& let Some(type_str) = extract_type_from_composition(any_of)
{
return type_str;
}
if let Some(one_of) = obj.get("oneOf")
&& let Some(type_str) = extract_type_from_composition(one_of)
{
return type_str;
}
if let Some(all_of) = obj.get("allOf")
&& let Some(type_str) = extract_type_from_composition(all_of)
{
return type_str;
}
if obj.contains_key("properties") {
"object".to_string()
} else if obj.contains_key("enum") {
"string".to_string()
} else {
String::new()
}
}
impl TryFrom<Value> for Schema {
type Error = CompletionError;
fn try_from(value: Value) -> Result<Self, Self::Error> {
let flattened_val = flatten_schema(value)?;
if let Some(obj) = flattened_val.as_object() {
let composition_source = extract_schema_from_composition_obj(obj);
let props_source = if obj.get("properties").is_none() {
composition_source.clone().unwrap_or(obj.clone())
} else {
obj.clone()
};
let schema_type = infer_type(obj);
let items = obj
.get("items")
.or_else(|| props_source.get("items"))
.and_then(|v| v.clone().try_into().ok())
.map(Box::new);
let items = if schema_type == "array" && items.is_none() {
Some(Box::new(Schema {
r#type: "string".to_string(),
format: None,
description: None,
nullable: None,
r#enum: None,
max_items: None,
min_items: None,
properties: None,
required: None,
items: None,
}))
} else {
items
};
Ok(Schema {
r#type: schema_type,
format: obj
.get("format")
.or_else(|| props_source.get("format"))
.and_then(|v| v.as_str())
.map(String::from),
description: obj
.get("description")
.or_else(|| props_source.get("description"))
.and_then(|v| v.as_str())
.map(String::from),
nullable: if schema_is_nullable(obj)
|| composition_source.as_ref().is_some_and(schema_is_nullable)
{
Some(true)
} else {
None
},
r#enum: obj
.get("enum")
.or_else(|| props_source.get("enum"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
}),
max_items: obj
.get("maxItems")
.and_then(|v| v.as_i64())
.map(|v| v as i32),
min_items: obj
.get("minItems")
.and_then(|v| v.as_i64())
.map(|v| v as i32),
properties: props_source
.get("properties")
.and_then(|v| v.as_object())
.map(|map| {
map.iter()
.filter_map(|(k, v)| {
v.clone().try_into().ok().map(|schema| (k.clone(), schema))
})
.collect()
}),
required: props_source
.get("required")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
}),
items,
})
} else {
Err(CompletionError::ResponseError(
"Expected a JSON object for Schema".into(),
))
}
}
}
/// Request body for `generateContent` / `streamGenerateContent`.
///
/// `tools` and `additional_params` are omitted when `None`. The other optional fields
/// have no `skip_serializing_if`, so they serialize as JSON `null` when unset.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
/// Conversation turns, in order.
pub contents: Vec<Content>,
/// Tool declarations as raw JSON values.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Value>>,
pub tool_config: Option<ToolConfig>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
/// System prompt, sent as `systemInstruction`.
pub system_instruction: Option<Content>,
/// Extra provider parameters whose keys are flattened into the top-level request
/// object. NOTE(review): presumably this must be a JSON object for `flatten` to
/// serialize — confirm with callers.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}
/// A tool entry: function declarations plus an optional code-execution marker.
/// `code_execution` has no `skip_serializing_if`, so it serializes as `null` when unset.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
pub code_execution: Option<CodeExecution>,
}
/// Declaration of one callable function exposed to the model.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
/// Argument schema; omitted for no-argument tools (see `tool_parameters_to_schema`).
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Schema>,
}
/// Tool-use configuration, sent as `toolConfig`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
/// Function-calling mode (`functionCallingConfig`); serialized as `null` when unset.
pub function_calling_config: Option<FunctionCallingMode>,
}
/// How the model may call functions.
///
/// Internally tagged: serializes as `{"mode": "AUTO"}`, `{"mode": "NONE"}` or
/// `{"mode": "ANY", ...}`. Defaults to `Auto`.
///
/// NOTE(review): `rename_all` on an enum renames variants only, so the inner field
/// serializes as `allowed_function_names` (snake_case). Confirm the API accepts this
/// spelling as well as `allowedFunctionNames`.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(tag = "mode", rename_all = "UPPERCASE")]
pub enum FunctionCallingMode {
#[default]
Auto,
None,
/// Model must call a function; optionally restricted to the named functions.
Any {
#[serde(skip_serializing_if = "Option::is_none")]
allowed_function_names: Option<Vec<String>>,
},
}
impl TryFrom<message::ToolChoice> for FunctionCallingMode {
    type Error = CompletionError;

    /// Maps the crate-level tool choice onto Gemini's calling mode. This never fails.
    /// `Required` becomes unrestricted `ANY`; `Specific` becomes `ANY` with an
    /// allow-list of function names.
    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
        Ok(match value {
            message::ToolChoice::Specific { function_names } => Self::Any {
                allowed_function_names: Some(function_names),
            },
            message::ToolChoice::Required => Self::Any {
                allowed_function_names: None,
            },
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Auto => Self::Auto,
        })
    }
}
/// Empty marker that enables code execution when present; serializes as `{}`.
#[derive(Debug, Serialize)]
pub struct CodeExecution {}
/// Blocking threshold for one harm category.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
pub threshold: HarmBlockThreshold,
}
/// Blocking threshold for a safety setting. Serialized in SCREAMING_SNAKE_CASE
/// (`BLOCK_LOW_AND_ABOVE`, `BLOCK_NONE`, `OFF`, ...).
#[derive(Debug, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
HarmBlockThresholdUnspecified,
BlockLowAndAbove,
BlockMediumAndAbove,
BlockOnlyHigh,
BlockNone,
Off,
}
}
#[cfg(test)]
mod tests {
use crate::{
message,
providers::gemini::completion::gemini_api_types::{
ContentCandidate, FinishReason, FunctionCall, Schema, UsageMetadata, flatten_schema,
tool_parameters_to_schema,
},
};
use super::*;
use serde_json::json;
/// A per-request `model` overrides the default. The completion and streaming
/// endpoints are then built from the overriding model name.
#[test]
fn test_resolve_request_model_uses_override() {
let request = CompletionRequest {
model: Some("gemini-2.5-flash".to_string()),
preamble: None,
chat_history: crate::OneOrMany::one("Hello".into()),
documents: vec![],
tools: vec![],
temperature: None,
max_tokens: None,
tool_choice: None,
additional_params: None,
output_schema: None,
};
let request_model = resolve_request_model("gemini-2.0-flash", &request);
assert_eq!(request_model, "gemini-2.5-flash");
assert_eq!(
completion_endpoint(&request_model),
"/v1beta/models/gemini-2.5-flash:generateContent"
);
assert_eq!(
streaming_endpoint(&request_model),
"/v1beta/models/gemini-2.5-flash:streamGenerateContent"
);
}
/// With no per-request `model`, the default model name is used.
#[test]
fn test_resolve_request_model_uses_default_when_unset() {
let request = CompletionRequest {
model: None,
preamble: None,
chat_history: crate::OneOrMany::one("Hello".into()),
documents: vec![],
tools: vec![],
temperature: None,
max_tokens: None,
tool_choice: None,
additional_params: None,
output_schema: None,
};
assert_eq!(
resolve_request_model("gemini-2.0-flash", &request),
"gemini-2.0-flash"
);
}
/// A user `Content` with one part of every supported kind deserializes into the
/// matching `PartKind` variants, in order. `serde_path_to_error` makes any failure
/// report the JSON path of the offending part.
#[test]
fn test_deserialize_message_user() {
let raw_message = r#"{
"parts": [
{"text": "Hello, world!"},
{"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
{"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
{"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
{"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
{"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
{"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
],
"role": "user"
}"#;
let content: Content = {
let jd = &mut serde_json::Deserializer::from_str(raw_message);
serde_path_to_error::deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err);
})
};
assert_eq!(content.role, Some(Role::User));
assert_eq!(content.parts.len(), 7);
let parts: Vec<Part> = content.parts.into_iter().collect();
// Each part is checked by position against the fixture above.
if let Part {
part: PartKind::Text(text),
..
} = &parts[0]
{
assert_eq!(text, "Hello, world!");
} else {
panic!("Expected text part");
}
if let Part {
part: PartKind::InlineData(inline_data),
..
} = &parts[1]
{
assert_eq!(inline_data.mime_type, "image/png");
assert_eq!(inline_data.data, "base64encodeddata");
} else {
panic!("Expected inline data part");
}
if let Part {
part: PartKind::FunctionCall(function_call),
..
} = &parts[2]
{
assert_eq!(function_call.name, "test_function");
assert_eq!(
function_call.args.as_object().unwrap().get("arg1").unwrap(),
"value1"
);
} else {
panic!("Expected function call part");
}
if let Part {
part: PartKind::FunctionResponse(function_response),
..
} = &parts[3]
{
assert_eq!(function_response.name, "test_function");
assert_eq!(
function_response
.response
.as_ref()
.unwrap()
.get("result")
.unwrap(),
"success"
);
} else {
panic!("Expected function response part");
}
if let Part {
part: PartKind::FileData(file_data),
..
} = &parts[4]
{
assert_eq!(file_data.mime_type.as_ref().unwrap(), "application/pdf");
assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
} else {
panic!("Expected file data part");
}
if let Part {
part: PartKind::ExecutableCode(executable_code),
..
} = &parts[5]
{
assert_eq!(executable_code.code, "print('Hello, world!')");
} else {
panic!("Expected executable code part");
}
if let Part {
part: PartKind::CodeExecutionResult(code_execution_result),
..
} = &parts[6]
{
assert_eq!(
code_execution_result.clone().output.unwrap(),
"Hello, world!"
);
} else {
panic!("Expected code execution result part");
}
}
/// A `"model"`-role `Content` with a single text part deserializes correctly.
#[test]
fn test_deserialize_message_model() {
let json_data = json!({
"parts": [{"text": "Hello, user!"}],
"role": "model"
});
let content: Content = serde_json::from_value(json_data).unwrap();
assert_eq!(content.role, Some(Role::Model));
assert_eq!(content.parts.len(), 1);
if let Some(Part {
part: PartKind::Text(text),
..
}) = content.parts.first()
{
assert_eq!(text, "Hello, user!");
} else {
panic!("Expected text part");
}
}
/// A crate-level user text message converts to a user-role `Content` with one text part.
#[test]
fn test_message_conversion_user() {
let msg = message::Message::user("Hello, world!");
let content: Content = msg.try_into().unwrap();
assert_eq!(content.role, Some(Role::User));
assert_eq!(content.parts.len(), 1);
if let Some(Part {
part: PartKind::Text(text),
..
}) = &content.parts.first()
{
assert_eq!(text, "Hello, world!");
} else {
panic!("Expected text part");
}
}
/// A crate-level assistant text message converts to a model-role `Content` with one
/// text part.
#[test]
fn test_message_conversion_model() {
let msg = message::Message::assistant("Hello, user!");
let content: Content = msg.try_into().unwrap();
assert_eq!(content.role, Some(Role::Model));
assert_eq!(content.parts.len(), 1);
if let Some(Part {
part: PartKind::Text(text),
..
}) = &content.parts.first()
{
assert_eq!(text, "Hello, user!");
} else {
panic!("Expected text part");
}
}
/// A response text part flagged `thought: true` becomes a `Reasoning` choice. Its
/// `thought_signature` is kept as the reasoning text's signature.
#[test]
fn test_thought_signature_is_preserved_from_response_reasoning_part() {
let response = GenerateContentResponse {
response_id: "resp_1".to_string(),
candidates: vec![ContentCandidate {
content: Some(Content {
parts: vec![Part {
thought: Some(true),
thought_signature: Some("thought_sig_123".to_string()),
part: PartKind::Text("thinking text".to_string()),
additional_params: None,
}],
role: Some(Role::Model),
}),
finish_reason: Some(FinishReason::Stop),
safety_ratings: None,
citation_metadata: None,
token_count: None,
avg_logprobs: None,
logprobs_result: None,
index: Some(0),
finish_message: None,
}],
prompt_feedback: None,
usage_metadata: None,
model_version: None,
};
let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
response.try_into().expect("convert response");
let first = converted.choice.first();
assert!(matches!(
first,
message::AssistantContent::Reasoning(message::Reasoning { content, .. })
if matches!(
content.first(),
Some(message::ReasoningContent::Text {
text,
signature: Some(signature)
}) if text == "thinking text" && signature == "thought_sig_123"
)
));
}
/// Each tool-protocol `FinishReason` makes response conversion fail with a
/// `ResponseError`. The error message must mention both the reason's `Debug` name and
/// the candidate's `finish_message`.
#[test]
fn test_tool_protocol_finish_reason_returns_response_error() {
for (reason, finish_message) in [
(
FinishReason::MalformedFunctionCall,
"malformed function call: default_api",
),
(
FinishReason::UnexpectedToolCall,
"unexpected tool call: default_api",
),
(
FinishReason::MissingThoughtSignature,
"missing thought signature for tool call",
),
(
FinishReason::TooManyToolCalls,
"too many tool calls in response",
),
(
FinishReason::MalformedResponse,
"malformed response from provider",
),
] {
let reason_name = format!("{reason:?}");
let response = GenerateContentResponse {
response_id: "resp_tool_protocol_error".to_string(),
candidates: vec![ContentCandidate {
content: Some(Content {
parts: vec![Part {
thought: None,
thought_signature: None,
part: PartKind::FunctionCall(FunctionCall {
name: "default_api".to_string(),
args: json!({"x": 1}),
}),
additional_params: None,
}],
role: Some(Role::Model),
}),
finish_reason: Some(reason),
safety_ratings: None,
citation_metadata: None,
token_count: None,
avg_logprobs: None,
logprobs_result: None,
index: Some(0),
finish_message: Some(finish_message.to_string()),
}],
prompt_feedback: None,
usage_metadata: None,
model_version: None,
};
let err = crate::completion::CompletionResponse::<GenerateContentResponse>::try_from(
response,
)
.expect_err("tool protocol finish reason should fail");
assert!(matches!(
err,
CompletionError::ResponseError(message)
if message.contains(&reason_name)
&& message.contains(finish_message)
));
}
}
/// Every `UsageMetadata` counter (prompt, cached, candidates, thoughts, tool-use
/// prompt, total) reaches the converted response's `Usage` unchanged.
#[test]
fn test_completion_response_usage_preserves_cached_and_reasoning_tokens() {
let response = GenerateContentResponse {
response_id: "resp_1".to_string(),
candidates: vec![ContentCandidate {
content: Some(Content {
parts: vec![Part {
thought: None,
thought_signature: None,
part: PartKind::Text("answer".to_string()),
additional_params: None,
}],
role: Some(Role::Model),
}),
finish_reason: Some(FinishReason::Stop),
safety_ratings: None,
citation_metadata: None,
token_count: None,
avg_logprobs: None,
logprobs_result: None,
index: Some(0),
finish_message: None,
}],
prompt_feedback: None,
usage_metadata: Some(UsageMetadata {
prompt_token_count: 40,
cached_content_token_count: Some(20),
candidates_token_count: Some(30),
total_token_count: 100,
thoughts_token_count: Some(10),
prompt_tokens_details: None,
cache_tokens_details: None,
candidates_tokens_details: None,
tool_use_prompt_token_count: Some(12),
tool_use_prompt_tokens_details: None,
traffic_type: None,
}),
model_version: Some("gemini-2.0-flash-001".to_string()),
};
let converted: crate::completion::CompletionResponse<GenerateContentResponse> =
response.try_into().expect("convert response");
assert_eq!(converted.usage.input_tokens, 40);
assert_eq!(converted.usage.cached_input_tokens, 20);
assert_eq!(converted.usage.output_tokens, 30);
assert_eq!(converted.usage.reasoning_tokens, 10);
assert_eq!(converted.usage.tool_use_prompt_tokens, 12);
assert_eq!(converted.usage.total_tokens, 100);
}
/// An assistant `Reasoning` with a signature converts to a Gemini text part. The part
/// has `thought: Some(true)` and carries the signature as `thought_signature`.
#[test]
fn test_reasoning_signature_is_emitted_in_gemini_part() {
let msg = message::Message::Assistant {
id: None,
content: OneOrMany::one(message::AssistantContent::Reasoning(
message::Reasoning::new_with_signature(
"structured thought",
Some("reuse_sig_456".to_string()),
),
)),
};
let converted: Content = msg.try_into().expect("convert message");
let first = converted.parts.first().expect("reasoning part");
assert_eq!(first.thought, Some(true));
assert_eq!(first.thought_signature.as_deref(), Some("reuse_sig_456"));
assert!(matches!(
&first.part,
PartKind::Text(text) if text == "structured thought"
));
}
/// An assistant `ToolCall` converts to a model-role `FunctionCall` part. The part keeps
/// the function name and the JSON arguments.
#[test]
fn test_message_conversion_tool_call() {
let tool_call = message::ToolCall {
id: "test_tool".to_string(),
call_id: None,
function: message::ToolFunction {
name: "test_function".to_string(),
arguments: json!({"arg1": "value1"}),
},
signature: None,
additional_params: None,
};
let msg = message::Message::Assistant {
id: None,
content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)),
};
let content: Content = msg.try_into().unwrap();
assert_eq!(content.role, Some(Role::Model));
assert_eq!(content.parts.len(), 1);
if let Some(Part {
part: PartKind::FunctionCall(function_call),
..
}) = content.parts.first()
{
assert_eq!(function_call.name, "test_function");
assert_eq!(
function_call.args.as_object().unwrap().get("arg1").unwrap(),
"value1"
);
} else {
panic!("Expected function call part");
}
}
/// An array whose `items` is a `$ref` into `$defs` resolves to the referenced object
/// schema, and each `["string", "null"]` property becomes a nullable string.
///
/// A conversion error fails the test. The previous version only printed the error,
/// so the test passed even when conversion broke.
#[test]
fn test_vec_schema_conversion() {
    let schema_with_ref = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "first_name": {
                        "type": ["string", "null"],
                        "description": "The person's first name, if provided (null otherwise)"
                    },
                    "last_name": {
                        "type": ["string", "null"],
                        "description": "The person's last name, if provided (null otherwise)"
                    },
                    "job": {
                        "type": ["string", "null"],
                        "description": "The person's job, if provided (null otherwise)"
                    }
                },
                "required": []
            }
        }
    });
    let schema: Schema = schema_with_ref
        .try_into()
        .expect("schema with $defs reference should convert");
    assert_eq!(schema.r#type, "array");
    let items = schema
        .items
        .expect("Schema should have items field for array type");
    assert_ne!(items.r#type, "", "Items type should not be empty string!");
    assert_eq!(items.r#type, "object", "Items should be object type");
    let props = items
        .properties
        .expect("resolved Person schema should keep its properties");
    for name in ["first_name", "last_name", "job"] {
        let prop = props.get(name).expect("Person property should be present");
        assert_eq!(prop.r#type, "string");
        assert_eq!(prop.nullable, Some(true));
    }
}
/// A plain object schema with one string property keeps its `object` type and its
/// properties map.
#[test]
fn test_object_schema() {
    let converted = Schema::try_from(json!({
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            }
        }
    }))
    .unwrap();
    assert_eq!(converted.r#type, "object");
    assert!(converted.properties.is_some());
}
/// An array with an inline object `items` schema converts both levels, including the
/// item's properties.
#[test]
fn test_array_with_inline_items() {
let inline_schema = json!({
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string"
}
}
}
});
let schema: Schema = inline_schema.try_into().unwrap();
assert_eq!(schema.r#type, "array");
if let Some(items) = schema.items {
assert_eq!(items.r#type, "object");
assert!(items.properties.is_some());
} else {
panic!("Schema should have items field");
}
}
/// `flatten_schema` inlines the `$ref` and drops `$defs`; the flattened value then
/// converts to an array of objects.
///
/// Missing `items` now fails the test. The previous `if let Some(items)` let it pass
/// vacuously.
#[test]
fn test_flattened_schema() {
    let ref_schema = json!({
        "type": "array",
        "items": {
            "$ref": "#/$defs/Person"
        },
        "$defs": {
            "Person": {
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                }
            }
        }
    });
    let flattened = flatten_schema(ref_schema).unwrap();
    assert!(
        flattened.get("$defs").is_none(),
        "definitions table should be removed after flattening"
    );
    let schema: Schema = flattened.try_into().unwrap();
    assert_eq!(schema.r#type, "array");
    let items = schema
        .items
        .expect("array schema should keep its resolved items");
    assert_eq!(items.r#type, "object");
    assert!(items.properties.is_some());
}
/// An `array` property without `items` gets a default string item schema.
#[test]
fn test_array_without_items_gets_default() {
let schema_json = json!({
"type": "object",
"properties": {
"service_ids": {
"type": "array",
"description": "A list of service IDs"
}
}
});
let schema: Schema = schema_json.try_into().unwrap();
let props = schema.properties.unwrap();
let service_ids = props.get("service_ids").unwrap();
assert_eq!(service_ids.r#type, "array");
let items = service_ids
.items
.as_ref()
.expect("array schema missing items should get a default");
assert_eq!(items.r#type, "string");
}
/// An empty object schema means "no arguments" and maps to `None`.
#[test]
fn test_tool_parameters_to_schema_maps_no_arg_tool_to_none() {
    let no_args = json!({"type": "object", "properties": {}});
    let converted = tool_parameters_to_schema(no_args).expect("schema conversion");
    assert!(converted.is_none());
}
/// A property defined via `$ref` into `$defs` is inlined, including the referenced
/// definition's own `required` list.
#[test]
fn test_tool_parameters_to_schema_resolves_defs_ref() {
let schema_json = json!({
"type": "object",
"properties": {
"destination": { "$ref": "#/$defs/Destination" }
},
"required": ["destination"],
"$defs": {
"Destination": {
"type": "object",
"properties": {
"city": { "type": "string" }
},
"required": ["city"]
}
}
});
let schema = tool_parameters_to_schema(schema_json)
.expect("schema conversion")
.expect("schema");
let props = schema.properties.expect("properties");
let destination = props.get("destination").expect("destination prop");
assert_eq!(destination.r#type, "object");
assert_eq!(destination.required, Some(vec!["city".to_string()]));
}
/// A `["null", "string"]` type array, with `null` listed first, still resolves to
/// `string` and is marked nullable.
#[test]
fn test_tool_parameters_to_schema_handles_nullable_type_arrays() {
let schema_json = json!({
"type": "object",
"properties": {
"nickname": { "type": ["null", "string"] }
}
});
let schema = tool_parameters_to_schema(schema_json)
.expect("schema conversion")
.expect("schema");
let props = schema.properties.expect("properties");
let nickname = props.get("nickname").expect("nickname prop");
assert_eq!(nickname.r#type, "string");
assert_eq!(nickname.nullable, Some(true));
}
/// An inline TXT document is sent as a plain text part containing the document body.
#[test]
fn test_txt_document_conversion_to_text_part() {
use crate::message::{DocumentMediaType, UserContent};
let doc = UserContent::document(
"Note: test.md\nPath: /test.md\nContent: Hello World!",
Some(DocumentMediaType::TXT),
);
let content: Content = message::Message::User {
content: crate::OneOrMany::one(doc),
}
.try_into()
.unwrap();
if let Part {
part: PartKind::Text(text),
..
} = &content.parts[0]
{
assert!(text.contains("Note: test.md"));
assert!(text.contains("Hello World!"));
} else {
panic!(
"Expected text part for TXT document, got: {:?}",
content.parts[0]
);
}
}
/// A tool result with text plus a base64 PNG becomes a single `FunctionResponse` part.
/// The text lands under `response.result`; the image is attached as an inline-data
/// sub-part.
#[test]
fn test_tool_result_with_image_content() {
use crate::OneOrMany;
use crate::message::{
DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
};
let tool_result = ToolResult {
id: "test_tool".to_string(),
call_id: None,
content: OneOrMany::many(vec![
ToolResultContent::Text(message::Text::new(r#"{"status": "success"}"#.to_string())),
ToolResultContent::Image(Image {
data: DocumentSourceKind::Base64("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==".to_string()),
media_type: Some(ImageMediaType::PNG),
detail: None,
additional_params: None,
}),
]).expect("Should create OneOrMany with multiple items"),
};
let user_content = message::UserContent::ToolResult(tool_result);
let msg = message::Message::User {
content: OneOrMany::one(user_content),
};
let content: Content = msg.try_into().expect("Should convert to Gemini Content");
assert_eq!(content.role, Some(Role::User));
assert_eq!(content.parts.len(), 1);
if let Some(Part {
part: PartKind::FunctionResponse(function_response),
..
}) = content.parts.first()
{
assert_eq!(function_response.name, "test_tool");
assert!(function_response.response.is_some());
let response = function_response.response.as_ref().unwrap();
assert!(response.get("result").is_some());
assert!(function_response.parts.is_some());
let parts = function_response.parts.as_ref().unwrap();
assert_eq!(parts.len(), 1);
let image_part = &parts[0];
assert!(image_part.inline_data.is_some());
let inline_data = image_part.inline_data.as_ref().unwrap();
assert_eq!(inline_data.mime_type, "image/png");
assert!(!inline_data.data.is_empty());
} else {
panic!("Expected FunctionResponse part");
}
}
/// An inline Markdown document is sent verbatim as a text part.
#[test]
fn test_markdown_document_conversion_to_text_part() {
use crate::message::{DocumentMediaType, UserContent};
let doc = UserContent::document(
"# Heading\n\n* List item",
Some(DocumentMediaType::MARKDOWN),
);
let content: Content = message::Message::User {
content: crate::OneOrMany::one(doc),
}
.try_into()
.unwrap();
if let Part {
part: PartKind::Text(text),
..
} = &content.parts[0]
{
assert_eq!(text, "# Heading\n\n* List item");
} else {
panic!(
"Expected text part for MARKDOWN document, got: {:?}",
content.parts[0]
);
}
}
/// A Markdown document referenced by URL should be forwarded to Gemini as a
/// `file_data` part that points at the same URI and carries the `text/markdown`
/// MIME type.
#[test]
fn test_markdown_url_document_conversion_to_file_data_part() {
    use crate::message::{DocumentMediaType, DocumentSourceKind, UserContent};

    const FILE_URI: &str = "https://generativelanguage.googleapis.com/v1beta/files/test-markdown";

    let url_doc = message::Document {
        data: DocumentSourceKind::Url(FILE_URI.to_string()),
        media_type: Some(DocumentMediaType::MARKDOWN),
        additional_params: None,
    };
    let msg = message::Message::User {
        content: crate::OneOrMany::one(UserContent::Document(url_doc)),
    };
    let content: Content = msg.try_into().unwrap();

    match &content.parts[0] {
        Part {
            part: PartKind::FileData(file_data),
            ..
        } => {
            assert_eq!(file_data.file_uri, FILE_URI);
            assert_eq!(file_data.mime_type.as_deref(), Some("text/markdown"));
        }
        other => panic!(
            "Expected file_data part for URL MARKDOWN document, got: {:?}",
            other
        ),
    }
}
/// A tool result that consists only of a URL-referenced PNG should convert into a
/// single Gemini `FunctionResponse` part. The image must be carried as
/// `file_data` (URI plus `image/png` MIME type) in the function response's
/// `parts`, not as inline data.
#[test]
fn test_tool_result_with_url_image() {
    use crate::OneOrMany;
    use crate::message::{
        DocumentSourceKind, Image, ImageMediaType, ToolResult, ToolResultContent,
    };

    let tool_result = ToolResult {
        id: "screenshot_tool".to_string(),
        call_id: None,
        content: OneOrMany::one(ToolResultContent::Image(Image {
            data: DocumentSourceKind::Url("https://example.com/image.png".to_string()),
            media_type: Some(ImageMediaType::PNG),
            detail: None,
            additional_params: None,
        })),
    };
    let msg = message::Message::User {
        content: OneOrMany::one(message::UserContent::ToolResult(tool_result)),
    };

    let content: Content = msg.try_into().expect("Should convert to Gemini Content");
    assert_eq!(content.role, Some(Role::User));
    assert_eq!(content.parts.len(), 1);

    let Some(Part {
        part: PartKind::FunctionResponse(function_response),
        ..
    }) = content.parts.first()
    else {
        panic!("Expected FunctionResponse part");
    };
    assert_eq!(function_response.name, "screenshot_tool");

    let parts = function_response
        .parts
        .as_ref()
        .expect("FunctionResponse should carry the image in `parts`");
    assert_eq!(parts.len(), 1);

    let file_data = parts[0]
        .file_data
        .as_ref()
        .expect("URL image should be sent as file_data");
    assert_eq!(file_data.file_uri, "https://example.com/image.png");
    // Same `as_deref` comparison style as the Markdown URL document test.
    assert_eq!(file_data.mime_type.as_deref(), Some("image/png"));
}
/// When a request carries documents, `create_request_body` should emit them as a
/// leading user `Content` with one text part per document, followed by the chat
/// history (here a single user text message).
#[test]
fn test_create_request_body_with_documents() {
    use crate::OneOrMany;
    use crate::completion::request::{CompletionRequest, Document};
    use crate::message::Message;

    let documents = vec![
        Document {
            id: "doc1".to_string(),
            text: "Note: first.md\nContent: First note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
        Document {
            id: "doc2".to_string(),
            text: "Note: second.md\nContent: Second note".to_string(),
            additional_props: std::collections::HashMap::new(),
        },
    ];
    let completion_request = CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("What are my notes about?")),
        // `documents` is not used after this point, so move it rather than clone.
        documents,
        tools: vec![],
        temperature: None,
        model: None,
        output_schema: None,
        max_tokens: None,
        tool_choice: None,
        additional_params: None,
    };
    let request = create_request_body(completion_request).unwrap();

    assert_eq!(
        request.contents.len(),
        2,
        "Expected 2 contents (documents + user message)"
    );

    let documents_content = &request.contents[0];
    assert_eq!(documents_content.role, Some(Role::User));
    assert_eq!(
        documents_content.parts.len(),
        2,
        "Expected 2 document parts"
    );
    for part in &documents_content.parts {
        let Part {
            part: PartKind::Text(text),
            ..
        } = part
        else {
            panic!("Document parts should be text, not {:?}", part);
        };
        assert!(
            text.contains("Note:") && text.contains("Content:"),
            "Document should contain note metadata"
        );
    }

    let user_content = &request.contents[1];
    assert_eq!(user_content.role, Some(Role::User));
    // `first()` instead of `[0]` so an empty parts list fails with a clear message.
    let Some(Part {
        part: PartKind::Text(text),
        ..
    }) = user_content.parts.first()
    else {
        panic!("Expected user message to be text");
    };
    assert_eq!(text, "What are my notes about?");
}
/// With no documents attached, `create_request_body` should produce exactly one
/// `Content`: the user's chat message as a text part.
#[test]
fn test_create_request_body_without_documents() {
    use crate::OneOrMany;
    use crate::completion::request::CompletionRequest;
    use crate::message::Message;

    let request = create_request_body(CompletionRequest {
        preamble: Some("You are a helpful assistant".to_string()),
        chat_history: OneOrMany::one(Message::user("Hello")),
        documents: Vec::new(),
        tools: Vec::new(),
        temperature: None,
        max_tokens: None,
        tool_choice: None,
        model: None,
        output_schema: None,
        additional_params: None,
    })
    .unwrap();

    assert_eq!(request.contents.len(), 1, "Expected only user message");

    let user_content = &request.contents[0];
    assert_eq!(user_content.role, Some(Role::User));
    match &user_content.parts[0] {
        Part {
            part: PartKind::Text(text),
            ..
        } => assert_eq!(text, "Hello"),
        _ => panic!("Expected user message to be text"),
    }
}
/// A tool output that is a single image JSON object (`type`, `data`, `mimeType`)
/// should parse into exactly one `Image` item. Its base64 payload and JPEG media
/// type must be preserved.
#[test]
fn test_from_tool_output_parses_image_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let image_json = r#"{"type": "image", "data": "base64data==", "mimeType": "image/jpeg"}"#;
    let result = ToolResultContent::from_tool_output(image_json);
    assert_eq!(result.len(), 1);

    let ToolResultContent::Image(img) = result.first() else {
        panic!("Expected Image content");
    };
    // One `let … else` replaces `assert!(matches!(..))` followed by a second
    // `if let`, so the payload check can no longer be skipped silently.
    let DocumentSourceKind::Base64(data) = &img.data else {
        panic!("Expected base64 image data");
    };
    assert_eq!(data, "base64data==");
    assert_eq!(img.media_type, Some(crate::message::ImageMediaType::JPEG));
}
/// A "hybrid" tool output, meaning a JSON object with a `response` payload and a
/// `parts` array of images, should expand in order into:
/// 1. a `Text` item holding the serialized response;
/// 2. an `Image` item per part.
///
/// In this input the plain-string `data` becomes a `Base64` source and the
/// `https://` `data` becomes a `Url` source.
#[test]
fn test_from_tool_output_parses_hybrid_json() {
    use crate::message::{DocumentSourceKind, ToolResultContent};

    let hybrid_json = r#"{
        "response": {"status": "ok", "count": 42},
        "parts": [
            {"type": "image", "data": "imgdata1==", "mimeType": "image/png"},
            {"type": "image", "data": "https://example.com/img.jpg", "mimeType": "image/jpeg"}
        ]
    }"#;
    let result = ToolResultContent::from_tool_output(hybrid_json);
    assert_eq!(result.len(), 3);

    let items: Vec<_> = result.iter().collect();
    let [response, base64_image, url_image] = items.as_slice() else {
        panic!("Expected exactly three tool result items");
    };

    let ToolResultContent::Text(text) = response else {
        panic!("Expected Text content first");
    };
    assert!(text.text.contains("status"));
    assert!(text.text.contains("ok"));
    assert!(text.text.contains("42"));

    let ToolResultContent::Image(img) = base64_image else {
        panic!("Expected Image content second");
    };
    let DocumentSourceKind::Base64(data) = &img.data else {
        panic!("Expected base64 data for the second item");
    };
    assert_eq!(data, "imgdata1==");

    let ToolResultContent::Image(img) = url_image else {
        panic!("Expected Image content third");
    };
    let DocumentSourceKind::Url(url) = &img.data else {
        panic!("Expected URL data for the third item");
    };
    assert_eq!(url, "https://example.com/img.jpg");
}
/// End-to-end check against the live Gemini API. An agent equipped with
/// `MockImageGeneratorTool` is asked to generate an image and describe it, and
/// the test only requires a non-empty text reply.
///
/// Ignored by default: it needs `GEMINI_API_KEY` and network access.
#[tokio::test]
#[ignore = "requires GEMINI_API_KEY environment variable"]
async fn test_gemini_agent_with_image_tool_result_e2e() -> anyhow::Result<()> {
    use crate::completion::Prompt;
    use crate::prelude::*;
    use crate::providers::gemini;
    use crate::test_utils::MockImageGeneratorTool;

    let client = gemini::Client::from_env()?;
    // Use the model constant defined in this module rather than a duplicated literal.
    let agent = client
        .agent(gemini::completion::GEMINI_3_FLASH_PREVIEW)
        .preamble("You are a helpful assistant. When asked about images, use the generate_test_image tool to create one, then describe what you see in the image.")
        .tool(MockImageGeneratorTool)
        .build();

    let response_text = agent
        .prompt("Please generate a test image and tell me what color the pixel is.")
        .await?;
    println!("Response: {response_text}");

    anyhow::ensure!(!response_text.is_empty(), "Response should not be empty");
    Ok(())
}
}