use super::traits::{
ChatRequest, ChatResponse, FinishReason, LlmProvider, MessageRole, TokenUsage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// LLM provider backed by the Google Gemini `generateContent` REST API.
pub struct GeminiProvider {
    /// Human-readable instance name, returned by `LlmProvider::name`.
    name: Arc<str>,
    /// API key sent with every request; `is_available` reports false when empty.
    api_key: Arc<str>,
    /// API base URL; `new` defaults it to the public `v1beta` endpoint.
    base_url: Arc<str>,
    /// Shared reqwest HTTP client (reused for connection pooling).
    client: Client,
}
impl GeminiProvider {
    /// Creates a provider instance.
    ///
    /// When `base_url` is `None`, the public Google endpoint
    /// (`https://generativelanguage.googleapis.com/v1beta`) is used.
    pub fn new(name: Arc<str>, api_key: Arc<str>, base_url: Option<Arc<str>>) -> Self {
        let base_url = match base_url {
            Some(url) => url,
            None => Arc::from("https://generativelanguage.googleapis.com/v1beta"),
        };
        Self {
            name,
            api_key,
            base_url,
            client: Client::new(),
        }
    }
}
#[async_trait]
impl LlmProvider for GeminiProvider {
fn name(&self) -> &str {
&self.name
}
fn provider_type(&self) -> &str {
"gemini"
}
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
let url = format!(
"{}/models/{}:generateContent?key={}",
self.base_url, request.model, self.api_key
);
let contents: Vec<GeminiContent> = request
.messages
.iter()
.filter(|m| m.role != MessageRole::System)
.map(|m| GeminiContent {
role: match m.role {
MessageRole::User | MessageRole::Tool => "user".to_string(),
MessageRole::Assistant => "model".to_string(),
MessageRole::System => "user".to_string(),
},
parts: vec![GeminiPart {
text: m.content.clone(),
}],
})
.collect();
let system_instruction = request
.messages
.iter()
.find(|m| m.role == MessageRole::System)
.map(|m| GeminiSystemInstruction {
parts: vec![GeminiPart {
text: m.content.clone(),
}],
});
let body = GeminiRequest {
contents,
generation_config: GeminiGenerationConfig {
max_output_tokens: request.max_tokens,
temperature: request.temperature,
},
system_instruction,
};
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to send request to Gemini")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("Gemini API error: {}", error_text);
}
let api_response: GeminiResponse = response
.json()
.await
.context("Failed to parse Gemini response")?;
let candidate = api_response
.candidates
.into_iter()
.next()
.context("No candidates in Gemini response")?;
let content = candidate
.content
.parts
.into_iter()
.map(|p| p.text)
.collect::<Vec<String>>()
.join("");
let finish_reason = match candidate.finish_reason.as_deref() {
Some("STOP") => FinishReason::Stop,
Some("MAX_TOKENS") => FinishReason::Length,
_ => FinishReason::Stop,
};
Ok(ChatResponse {
id: uuid::Uuid::new_v4().to_string(),
model: request.model,
content,
finish_reason,
usage: TokenUsage {
prompt_tokens: api_response.usage_metadata.prompt_token_count,
completion_tokens: api_response.usage_metadata.candidates_token_count,
total_tokens: api_response.usage_metadata.total_token_count,
},
tool_calls: None,
})
}
async fn is_available(&self) -> bool {
if self.api_key.is_empty() {
return false;
}
let url = format!("{}/models?key={}", self.base_url, self.api_key);
let response = self.client.get(&url).send().await;
match response {
Ok(resp) => resp.status().is_success(),
Err(_) => false,
}
}
}
/// Request body for `models/{model}:generateContent`.
///
/// Gemini's JSON API uses lowerCamelCase field names (`generationConfig`,
/// `systemInstruction`), so rename on the wire.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    contents: Vec<GeminiContent>,
    generation_config: GeminiGenerationConfig,
    /// Omitted entirely when there is no system message.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiSystemInstruction>,
}
/// One conversation turn in the Gemini wire format.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
    /// Gemini role name: "user" or "model".
    role: String,
    /// Message payload; Gemini allows multiple parts per turn.
    parts: Vec<GeminiPart>,
}
/// A single text fragment of a content turn (only text parts are supported here).
#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
    text: String,
}
/// Sampling/limit parameters for a generateContent request.
///
/// Serialized in lowerCamelCase (`maxOutputTokens`) as documented by the
/// Gemini REST API; unset options are omitted so server defaults apply.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
/// Wrapper for the request-level system instruction (Gemini has no system role).
#[derive(Debug, Serialize)]
struct GeminiSystemInstruction {
    parts: Vec<GeminiPart>,
}
#[derive(Debug, Deserialize)]
struct GeminiResponse {
candidates: Vec<GeminiCandidate>,
usage_metadata: GeminiUsageMetadata,
}
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
content: GeminiContent,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
prompt_token_count: u32,
candidates_token_count: u32,
total_token_count: u32,
}