use std::sync::Arc;
use base64::Engine;
use markhor_core::chat::ChatError;
use markhor_core::chat::chat::{
ChatApi, ChatOptions, ChatResponse, ChatStream, ContentPart, FinishReason,
Message, ModelInfo, ToolCallRequest, ToolChoice, ToolParameterSchema,
UsageInfo,
};
use async_trait::async_trait;
use markhor_core::extension::Extension;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{debug, error, instrument, trace, warn};
use uuid::Uuid;
use secrecy::ExposeSecret;
use crate::gemini::error::map_response_error;
use super::error::GeminiError;
use super::shared::{GeminiConfig, SharedGeminiClient, EXTENSION_URI};
/// JSON body sent to the Gemini `generateContent` endpoint.
///
/// Keys are emitted in camelCase; every optional section is dropped from the
/// serialized JSON when it is `None`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerateRequest {
    /// Conversation turns, in the order they are sent.
    contents: Vec<GeminiContent>,

    /// Function declarations the model is allowed to call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,

    /// Function-calling mode that accompanies `tools`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GeminiToolConfig>,

    /// System prompt, carried separately from `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,

    /// Sampling and length settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
}
/// One conversation turn: a role plus the ordered parts that make up its content.
///
/// Used both when building requests and when reading candidates out of a response.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GeminiContent {
    /// Author of the turn (this file writes "user", "model", "function" and "system").
    role: String,

    /// Content parts of the turn.
    ///
    /// Defaults to an empty list on deserialization so that a candidate whose
    /// `content` object carries no `parts` key does not fail parsing of the
    /// whole response. Serialization is unchanged.
    #[serde(default)]
    parts: Vec<GeminiPart>,
}
impl From<Message> for GeminiContent {
    /// Maps a provider-neutral [`Message`] onto a Gemini content turn.
    ///
    /// * `System` becomes a single text part (text parts joined with `\n`) under role `"system"`.
    /// * `User` becomes role `"user"` with each part converted individually.
    /// * `Assistant` becomes role `"model"`; tool calls are appended after the content parts.
    /// * `Tool` becomes role `"function"` with one function-response part per result.
    fn from(message: Message) -> Self {
        let (role, parts): (&str, Vec<GeminiPart>) = match message {
            Message::System(system_parts) => {
                // Non-text parts are dropped; the remaining text is merged into one part.
                let text = system_parts
                    .into_iter()
                    .filter_map(|part| part.into_text())
                    .collect::<Vec<_>>()
                    .join("\n");
                ("system", vec![GeminiPart::Text { text }])
            }
            Message::User(user_parts) => {
                ("user", user_parts.into_iter().map(Into::into).collect())
            }
            Message::Assistant { content, tool_calls } => {
                let mut model_parts: Vec<GeminiPart> =
                    content.into_iter().map(Into::into).collect();
                model_parts.extend(tool_calls.into_iter().map(|call| {
                    GeminiPart::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: call.name,
                            args: call.arguments,
                        },
                    }
                }));
                ("model", model_parts)
            }
            Message::Tool(results) => {
                let response_parts = results
                    .into_iter()
                    .map(|result| GeminiPart::FunctionResponse {
                        function_response: GeminiFunctionResponse {
                            name: result.name,
                            response: result.content,
                        },
                    })
                    .collect();
                ("function", response_parts)
            }
        };

        GeminiContent {
            role: role.to_string(),
            parts,
        }
    }
}
/// A single piece of content inside a [`GeminiContent`].
///
/// The enum is untagged: each variant is recognised purely by the JSON key it
/// carries (`text`, `inlineData`, `functionCall`, `functionResponse`), and
/// variants are tried in declaration order when deserializing.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GeminiPart {
    /// Plain text.
    Text {
        text: String,
    },
    /// Base64-encoded binary data with a MIME type.
    InlineData {
        // The sibling variants rename their key to camelCase; this one did not,
        // so a response part keyed `inlineData` could never match. The
        // snake_case spelling stays accepted on input through the alias.
        #[serde(rename = "inlineData", alias = "inline_data")]
        inline_data: GeminiBlob,
    },
    /// A function call requested by the model.
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    /// The result of a function call, sent back to the model.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}
impl From<ContentPart> for GeminiPart {
fn from(part: ContentPart) -> Self {
match part {
ContentPart::Text(text) => {
GeminiPart::Text { text }
}
ContentPart::Image { mime_type, data } => {
let encoded_data = base64::engine::general_purpose::STANDARD.encode(data);
GeminiPart::InlineData {
inline_data: GeminiBlob {
mime_type: mime_type,
data: encoded_data,
}
}
}
}
}
}
/// Inline binary payload: a MIME type plus the data as a base64 string.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GeminiBlob {
    /// MIME type of the payload (serialized as `mimeType`).
    mime_type: String,

    /// Base64-encoded bytes.
    data: String,
}
/// Fallback for a function call that arrives without an `args` key: an empty JSON object.
fn empty_function_args() -> serde_json::Value {
    serde_json::Value::Object(serde_json::Map::new())
}

/// A function invocation: the function name and its JSON arguments.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionCall {
    /// Name of the function to call.
    name: String,

    /// Arguments as a JSON value.
    ///
    /// `args` is optional in the API, so a call to a parameterless function may
    /// omit it. Without a default the part would fail to match any
    /// [`GeminiPart`] variant and abort parsing of the whole response.
    #[serde(default = "empty_function_args")]
    args: serde_json::Value,
}
/// Result of a function call that is handed back to the model.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionResponse {
    /// Name of the function that produced the result.
    name: String,

    /// The result as a JSON value.
    response: serde_json::Value,
}
/// A tool entry in the request; it groups a set of function declarations.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
    /// Functions exposed to the model (serialized as `functionDeclarations`).
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
/// Declaration of one callable function: name, description and parameter schema.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionDeclaration {
    /// Function name the model uses when requesting a call.
    name: String,

    /// Human-readable description of what the function does.
    description: String,

    /// Schema of the function's parameters, serialized as-is.
    parameters: ToolParameterSchema,
}
/// Wrapper for tool-related request settings.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    /// Function-calling behaviour; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    function_calling_config: Option<GeminiFunctionCallingConfig>,
}
/// Function-calling mode plus an optional allow-list of function names.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionCallingConfig {
    /// How the model may use the declared functions.
    mode: GeminiFunctionCallingMode,

    /// Names of the functions the model may call; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}
/// Function-calling mode, serialized in SCREAMING_SNAKE_CASE
/// (`MODE_UNSPECIFIED`, `AUTO`, `ANY`, `NONE`).
#[derive(Serialize, Debug)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum GeminiFunctionCallingMode {
    /// No mode specified.
    ModeUnspecified,
    /// Used for `ToolChoice::Auto`.
    Auto,
    /// Used for `ToolChoice::Required` and `ToolChoice::Tool`.
    Any,
    /// Used for `ToolChoice::None`.
    None,
}
impl From<ToolChoice> for GeminiFunctionCallingMode {
    /// Maps a provider-neutral tool choice onto a Gemini calling mode.
    ///
    /// A mode alone cannot name a specific tool, so `ToolChoice::Tool` logs a
    /// warning and falls back to `Any`.
    fn from(choice: ToolChoice) -> Self {
        match choice {
            ToolChoice::Auto => Self::Auto,
            ToolChoice::Required => Self::Any,
            ToolChoice::None => Self::None,
            ToolChoice::Tool { name } => {
                warn!("Forcing use of a specific tool is not supported for Gemini (tool: '{}'). Using ANY", name);
                Self::Any
            }
        }
    }
}
/// Sampling and output-length settings for a generate request.
///
/// Every field is optional and is left out of the JSON when `None`.
#[derive(Serialize, Debug, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    /// Nucleus-sampling probability mass (serialized as `topP`).
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,

    /// Number of candidates to request.
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<u32>,

    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,

    /// Sequences at which generation should stop.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
}
/// Top-level body of a successful `generateContent` response.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerateResponse {
    /// Generated candidates; may be absent.
    candidates: Option<Vec<GeminiCandidate>>,

    /// Token accounting for the request; may be absent.
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
impl GeminiGenerateResponse {
pub fn into_chat_response(self, request_model_id: &str) -> Result<ChatResponse, GeminiError> { let first_candidate = self.candidates.and_then(|mut c| c.into_iter().next());
let usage = self.usage_metadata.map(Into::into);
if let Some(cand) = first_candidate {
let finish_reason = cand.finish_reason.map(Into::into)
.unwrap_or(FinishReason::Other("Unknown finish reason".to_string()));
let mut content_parts = Vec::new();
let mut tool_calls = Vec::new();
if let Some(content) = cand.content {
if content.role == "model" {
for part in content.parts {
match part {
GeminiPart::Text { text } => {
content_parts.push(ContentPart::Text(text)); }
GeminiPart::InlineData { inline_data } => {
let decoded_data = base64::engine::general_purpose::STANDARD.decode(inline_data.data)
.map_err(|e| {
error!("Failed to decode base64 image data from Gemini response: {}", e);
GeminiError::UnexpectedResponse(format!("Failed to decode base64 image data: {}", e))
})?;
content_parts.push(ContentPart::Image {
mime_type: inline_data.mime_type,
data: decoded_data,
});
}
GeminiPart::FunctionCall { function_call } => {
tool_calls.push(ToolCallRequest {
id: format!("gemini-{}", Uuid::new_v4()),
name: function_call.name, arguments: function_call.args, });
}
GeminiPart::FunctionResponse { .. } => {
warn!("Unexpected FunctionResponse part in model content.");
}
}
}
} else {
warn!(role = %content.role, "Unexpected role in Gemini candidate content.");
}
} else {
debug!("Gemini candidate received with no 'content' field.");
}
if content_parts.is_empty() && tool_calls.is_empty() {
debug!("Received response with no text content or tool calls (Finish Reason: {:?}).", finish_reason);
}
Ok(ChatResponse {
content: content_parts,
tool_calls,
usage,
finish_reason: Some(finish_reason),
model_id: Some(request_model_id.to_string()),
})
} else {
warn!("Gemini response contained no candidates.");
Ok(ChatResponse {
content: vec![],
tool_calls: vec![],
usage, finish_reason: Some(FinishReason::Other("No candidate received".to_string())),
model_id: Some(request_model_id.to_string()),
})
}
}
}
/// One generated candidate inside a [`GeminiGenerateResponse`].
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    /// Generated content; may be absent.
    content: Option<GeminiContent>,

    /// Why generation stopped; may be absent.
    finish_reason: Option<GeminiFinishReason>,
}
/// Reason a candidate stopped generating, as reported by Gemini.
///
/// Variants map to SCREAMING_SNAKE_CASE strings (`STOP`, `MAX_TOKENS`, ...).
/// Deserialization is tolerant by design: a finish reason this client does not
/// know maps to [`GeminiFinishReason::Other`] instead of failing the parse of
/// the entire response.
#[derive(Deserialize, Serialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum GeminiFinishReason {
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    Blocklist,
    ProhibitedContent,
    Spii,
    MalformedFunctionCall,
    /// The API spells this value `FINISH_REASON_UNSPECIFIED`; the bare
    /// `UNSPECIFIED` spelling the enum accepted before is kept as an alias.
    #[serde(rename = "FINISH_REASON_UNSPECIFIED", alias = "UNSPECIFIED")]
    Unspecified,
    /// `OTHER`, and also the catch-all for any unrecognised value.
    /// serde requires the `other` variant to be declared last.
    #[serde(other)]
    Other,
}
/// Implemented as `From` rather than a hand-written `Into`; the blanket impl
/// still provides `GeminiFinishReason: Into<FinishReason>`, so existing
/// `.into()` / `Into::into` call sites keep working.
impl From<GeminiFinishReason> for FinishReason {
    /// Maps a Gemini finish reason onto the provider-neutral [`FinishReason`].
    ///
    /// Safety-style stops (`Safety`, `ProhibitedContent`, `Spii`) collapse into
    /// `ContentFilter`; reasons without a neutral equivalent are carried as
    /// `Other` with a descriptive label.
    fn from(reason: GeminiFinishReason) -> Self {
        match reason {
            GeminiFinishReason::Stop => FinishReason::Stop,
            GeminiFinishReason::MaxTokens => FinishReason::Length,
            GeminiFinishReason::Safety
            | GeminiFinishReason::ProhibitedContent
            | GeminiFinishReason::Spii => FinishReason::ContentFilter,
            GeminiFinishReason::Recitation => FinishReason::Other("Recitation".to_string()),
            GeminiFinishReason::Blocklist => FinishReason::Other("Blocklist".to_string()),
            GeminiFinishReason::MalformedFunctionCall => {
                FinishReason::Other("MalformedFunctionCall".to_string())
            }
            GeminiFinishReason::Other => FinishReason::Other("Unknown".to_string()),
            GeminiFinishReason::Unspecified => FinishReason::Unspecified,
        }
    }
}
/// Token counts reported with a generate response; every count may be absent.
#[derive(Deserialize, Debug, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    /// Tokens in the prompt (`promptTokenCount`).
    #[serde(default)]
    prompt_token_count: Option<u32>,

    /// Tokens in the returned candidates (`candidatesTokenCount`).
    #[serde(default)]
    candidates_token_count: Option<u32>,

    /// Total token count (`totalTokenCount`).
    #[serde(default)]
    total_token_count: Option<u32>,
}
/// Implemented as `From` rather than a hand-written `Into`; the blanket impl
/// still provides `GeminiUsageMetadata: Into<UsageInfo>`, so existing
/// `.into()` / `Into::into` call sites keep working.
impl From<GeminiUsageMetadata> for UsageInfo {
    /// Copies Gemini's token counts into the provider-neutral [`UsageInfo`]:
    /// prompt -> `prompt_tokens`, candidates -> `completion_tokens`,
    /// total -> `total_tokens`.
    fn from(usage: GeminiUsageMetadata) -> Self {
        UsageInfo {
            prompt_tokens: usage.prompt_token_count,
            completion_tokens: usage.candidates_token_count,
            total_tokens: usage.total_token_count,
        }
    }
}
/// Envelope of a Gemini error body: `{ "error": { ... } }`.
#[derive(Deserialize, Debug)]
struct GeminiErrorResponse {
    /// The error details.
    error: GeminiErrorDetail,
}
impl GeminiErrorResponse {
fn into_api_error(self, response_status: StatusCode) -> ChatError {
let msg = format!("{} (Status: {}, Code: {})", self.error.message, self.error.status, self.error.code);
match response_status.as_u16() {
400 => ChatError::InvalidRequest(msg),
401 | 403 => ChatError::Authentication(msg),
404 => ChatError::ModelNotFound(msg), 429 => ChatError::RateLimited,
500..=599 => ChatError::Api {
status: Some(response_status.as_u16()),
message: msg,
source: None,
},
_ => ChatError::Api {
status: Some(response_status.as_u16()),
message: msg,
source: None,
},
}
}
}
/// Body of the `error` object in a Gemini error response.
#[derive(Deserialize, Debug)]
struct GeminiErrorDetail {
    /// Numeric error code from the body.
    code: u16,

    /// Human-readable error message.
    message: String,

    /// Status string from the body.
    status: String,
}
#[derive(Deserialize, Debug)]
struct GeminiListModelsResponse {
models: Vec<GeminiModelInfo>,
}
/// One model entry in the model-listing response.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GeminiModelInfo {
    /// Resource name; the segment after the last `/` is used as the model id.
    name: String,

    /// Display name; used as a fallback description.
    display_name: Option<String>,

    /// Model description.
    description: Option<String>,

    /// Input token limit, surfaced as the context window.
    input_token_limit: Option<u32>,

    /// Output token limit.
    output_token_limit: Option<u32>,
}
/// URL of the Gemini v1beta `models` resource.
/// NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still needed.
const DEFAULT_GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Model id a `GeminiChatClient` falls back to when it is created without an explicit default model.
const DEFAULT_GEMINI_CHAT_MODEL: &str = "gemini-2.0-flash-lite";
/// Chat client for the Gemini API.
///
/// Cloning copies the `Arc` handle to the shared client and the default model id string.
#[derive(Debug, Clone)]
pub struct GeminiChatClient {
    /// Shared configuration and HTTP client.
    shared_client: Arc<SharedGeminiClient>,

    /// Model used when a request does not name one.
    default_model_id: String,
}
impl GeminiChatClient {
pub fn new(api_key: impl Into<String>) -> Result<Self, GeminiError> {
Self::new_with_options(api_key, None, None)
}
pub fn new_with_options(
api_key: impl Into<String>,
default_model_id: Option<String>,
client_override: Option<Client>,
) -> Result<Self, GeminiError> {
let config = GeminiConfig::new(api_key)?; let shared_client = SharedGeminiClient::new(config, client_override)?;
Self::new_with_shared_client(Arc::new(shared_client), default_model_id)
}
#[instrument(name = "gemini_chat_client_from_config", skip(shared_client))]
pub(crate) fn new_with_shared_client(
shared_client: Arc<SharedGeminiClient>,
default_model_id: Option<String>,
) -> Result<Self, GeminiError> {
let model_id = default_model_id.unwrap_or_else(|| DEFAULT_GEMINI_CHAT_MODEL.to_string());
debug!(default_model_id = %model_id, "GeminiChatClient created.");
Ok(Self {
shared_client,
default_model_id: model_id,
})
}
async fn map_gemini_error(err_resp: reqwest::Response) -> ChatError {
let status = err_resp.status();
let error_text_result = err_resp.text().await;
match error_text_result {
Ok(error_text) => {
trace!(status = %status, error_body = %error_text, "Gemini API error response body");
match serde_json::from_str::<GeminiErrorResponse>(&error_text) {
Ok(gemini_error) => {
gemini_error.into_api_error(status)
}
Err(parse_err) => {
warn!(parse_error = %parse_err, body = %error_text, "Failed to parse Gemini error response JSON");
ChatError::Api {
status: Some(status.as_u16()),
message: error_text,
source: Some(Box::new(parse_err)),
}
}
}
},
Err(text_err) => {
error!(status = %status, text_error = %text_err, "Failed to read Gemini error response body text");
ChatError::Api {
status: Some(status.as_u16()),
message: format!("Failed to read error response body: {}", text_err),
source: Some(Box::new(text_err)),
}
}
}
}
fn convert_messages(
messages: &[Message],
) -> Result<(Option<GeminiContent>, Vec<GeminiContent>), GeminiError> { let mut system_instruction: Option<GeminiContent> = None;
let mut gemini_contents: Vec<GeminiContent> = Vec::with_capacity(messages.len());
let mut system_message_found = false;
for message in messages {
match message {
Message::System(parts) => {
if system_message_found {
return Err(GeminiError::InvalidInput(
"Multiple System messages are not supported by Gemini; use 'system_instruction'.".to_string()
));
}
system_message_found = true;
let combined_text = parts.iter()
.filter_map(|part| part.clone().into_text())
.collect::<Vec<_>>()
.join("\n");
system_instruction = Some(GeminiContent {
role: "user".to_string(),
parts: vec![GeminiPart::Text { text: combined_text }],
});
}
_ => {
gemini_contents.push(GeminiContent::from(message.clone()));
}
}
}
Ok((system_instruction, gemini_contents))
}
fn convert_tools(options: &ChatOptions) -> (Option<Vec<GeminiTool>>, Option<GeminiToolConfig>) {
let tools = options.tools.as_ref().map(|defs| {
vec![GeminiTool {
function_declarations: defs.iter().map(|def| GeminiFunctionDeclaration {
name: def.name.clone(),
description: def.description.clone(),
parameters: def.parameters.clone(), }).collect(),
}]
});
let tool_config = match options.tool_choice.as_ref() {
None | Some(ToolChoice::Auto) => {
if tools.is_some() {
None } else {
Some(GeminiToolConfig { function_calling_config: Some(GeminiFunctionCallingConfig { mode: GeminiFunctionCallingMode::None, allowed_function_names: None })})
}
},
Some(ToolChoice::None) => Some(GeminiToolConfig { function_calling_config: Some(GeminiFunctionCallingConfig { mode: GeminiFunctionCallingMode::None, allowed_function_names: None }) }),
Some(ToolChoice::Required) => {
if tools.is_none() || tools.as_ref().map_or(true, |t| t.is_empty()) {
warn!("ToolChoice::Required specified but no tools were provided.");
None } else {
Some(GeminiToolConfig { function_calling_config: Some(GeminiFunctionCallingConfig { mode: GeminiFunctionCallingMode::Any, allowed_function_names: None }) })
}
},
Some(ToolChoice::Tool { name }) => Some(GeminiToolConfig {
function_calling_config: Some(GeminiFunctionCallingConfig {
mode: GeminiFunctionCallingMode::Any,
allowed_function_names: Some(vec![name.clone()]),
}),
}),
};
(tools, tool_config)
}
}
#[async_trait]
impl ChatApi for GeminiChatClient {
    /// Lists the models offered by the Gemini API.
    ///
    /// The model id is the last `/`-separated segment of each returned name;
    /// entries for which that segment is empty are skipped with a warning.
    ///
    /// # Errors
    ///
    /// Fails on URL construction, network, non-success HTTP status, or JSON
    /// parsing problems; all are converted into [`ChatError`].
    #[instrument(skip(self), fields(client = self.shared_client.config().base_url.as_str()))]
    async fn list_models(&self) -> Result<Vec<ModelInfo>, ChatError> {
        async {
            let url = self.shared_client.build_url("models")?;
            debug!(%url, "Requesting Gemini models list");

            let response = self
                .shared_client
                .http_client()
                .get(url)
                .header("x-goog-api-key", self.shared_client.config().api_key.expose_secret())
                .send()
                .await
                .map_err(GeminiError::Network)?;

            let status = response.status();
            if !status.is_success() {
                error!(%status, "Failed to list models from Gemini API");
                return Err(map_response_error(response).await);
            }
            debug!(%status, "Received successful response for model list");

            let raw_body = match response.text().await {
                Ok(text) => text,
                Err(e) => {
                    error!(error = %e, "Failed to read successful response body for model list");
                    return Err(GeminiError::Network(e));
                }
            };
            trace!(body = %raw_body, "Received model list response body");

            let list_response: GeminiListModelsResponse = serde_json::from_str(&raw_body)
                .map_err(|e| {
                    error!(parse_error = %e, raw_body = %raw_body, "Failed to parse Gemini model list JSON");
                    GeminiError::ResponseParsing {
                        context: "Parsing model list".to_string(),
                        source: e,
                    }
                })?;

            let mut models = Vec::new();
            for m in list_response.models {
                // The id is whatever follows the last '/' in the name.
                match m.name.rsplit('/').next() {
                    Some(id) if !id.is_empty() => models.push(ModelInfo {
                        id: id.to_string(),
                        description: m.description.clone().or(m.display_name.clone()),
                        context_window: m.input_token_limit,
                        max_output_tokens: m.output_token_limit,
                    }),
                    _ => {
                        warn!(raw_name = %m.name, "Could not parse model ID from Gemini model name");
                    }
                }
            }

            debug!(count = models.len(), "Successfully parsed models list");
            Ok(models)
        }
        .await
        .map_err(Into::into)
    }

    /// Sends a non-streaming `generateContent` request and converts the reply.
    ///
    /// The model is `options.model_id` when set, otherwise the client's default.
    /// A generation config is attached only when at least one of temperature,
    /// top-p, max tokens or stop sequences is set.
    ///
    /// # Errors
    ///
    /// Fails on invalid messages, request serialization, network, non-success
    /// HTTP status, or response parsing problems; all are converted into
    /// [`ChatError`].
    #[instrument(skip(self, messages, options), fields(model = options.model_id.as_deref().unwrap_or(&self.default_model_id)))]
    async fn generate(
        &self,
        messages: &[Message],
        options: &ChatOptions,
    ) -> Result<ChatResponse, ChatError> {
        async {
            let model_id = options
                .model_id
                .as_deref()
                .unwrap_or(&self.default_model_id);
            let path_segment = format!("models/{}:generateContent", model_id);
            let url = self.shared_client.build_url(&path_segment)?;
            debug!(%url, %model_id, "Sending generate request to Gemini");

            let (system_instruction, gemini_contents) = Self::convert_messages(messages)?;
            let (tools, tool_config) = Self::convert_tools(options);

            let generation_config = GeminiGenerationConfig {
                temperature: options.temperature,
                top_p: options.top_p,
                max_output_tokens: options.max_tokens,
                stop_sequences: options.stop_sequences.clone(),
                candidate_count: Some(1),
                ..Default::default()
            };
            // The candidate count alone does not justify sending a config section.
            let has_overrides = generation_config.temperature.is_some()
                || generation_config.top_p.is_some()
                || generation_config.max_output_tokens.is_some()
                || generation_config.stop_sequences.is_some();

            let request_body = GeminiGenerateRequest {
                contents: gemini_contents,
                tools,
                tool_config,
                system_instruction,
                generation_config: if has_overrides { Some(generation_config) } else { None },
            };

            let request_json = match serde_json::to_string(&request_body) {
                Ok(json_text) => json_text,
                Err(e) => {
                    error!(error = %e, "Failed to serialize Gemini generate request body");
                    return Err(GeminiError::RequestSerialization(e));
                }
            };
            trace!(body = %request_json, "Constructed Gemini request body JSON");

            let response = self
                .shared_client
                .http_client()
                .post(url)
                .header("x-goog-api-key", self.shared_client.config().api_key.expose_secret())
                .header("Content-Type", "application/json")
                .body(request_json)
                .send()
                .await
                .map_err(GeminiError::Network)?;

            let status = response.status();
            if !status.is_success() {
                error!(%status, "Gemini generate API returned error status");
                return Err(map_response_error(response).await);
            }
            debug!(%status, "Received successful response for generate request");

            let raw_body = match response.text().await {
                Ok(text) => text,
                Err(e) => {
                    error!(error = %e, "Failed to read successful response body for generate");
                    return Err(GeminiError::Network(e));
                }
            };
            trace!(body = %raw_body, "Received Gemini generate response body");

            let gemini_response: GeminiGenerateResponse = serde_json::from_str(&raw_body)
                .map_err(|e| {
                    error!(parse_error = %e, raw_body = %raw_body, "Failed to parse Gemini generate response JSON");
                    GeminiError::ResponseParsing {
                        context: "Parsing generate response".to_string(),
                        source: e,
                    }
                })?;
            debug!("Successfully parsed Gemini generate response");

            gemini_response.into_chat_response(model_id)
        }
        .await
        .map_err(Into::into)
    }

    /// Streaming generation; not implemented for this client.
    ///
    /// # Errors
    ///
    /// Always returns [`ChatError::NotSupported`].
    #[instrument(skip(self, messages, options))]
    async fn generate_stream(
        &self,
        messages: &[Message],
        options: &ChatOptions,
    ) -> Result<ChatStream, ChatError> {
        warn!("Gemini streaming is not yet implemented.");
        Err(ChatError::NotSupported("Streaming is not yet implemented for the Gemini client.".to_string()))
    }
}
/// Builds a `reqwest` client with a 60-second request timeout.
///
/// # Errors
///
/// Returns [`ChatError::Configuration`] when the client cannot be constructed.
pub fn create_default_http_client() -> Result<reqwest::Client, ChatError> {
    let request_timeout = std::time::Duration::from_secs(60);
    match reqwest::Client::builder().timeout(request_timeout).build() {
        Ok(client) => Ok(client),
        Err(e) => Err(ChatError::Configuration(format!("Failed to build HTTP client: {}", e))),
    }
}