#![allow(dead_code)]
use super::{AIClient, AIConfig, AIResponse, AIContext, ModelConfig};
use anyhow::Result;
use futures_util::stream::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::debug;
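
/// Minimal client for Google's Gemini REST API, covering one-shot
/// (`generateContent`) and SSE-streamed (`streamGenerateContent`) completions.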
pub struct GeminiClient {
    client: reqwest::Client,
    api_key: String,
    base_url: String,
    model: String,
    max_tokens: u32,
    temperature: f32,
}

/// Request body for `models/*:generateContent`. Gemini's REST API uses
/// lowerCamelCase field names on the wire, hence the serde renames.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    contents: Vec<Content>,
    generation_config: GenerationConfig,
}

#[derive(Debug, Serialize)]
struct Content {
    role: String,
    parts: Vec<Part>,
}

#[derive(Debug, Serialize)]
struct Part {
    text: String,
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerationConfig {
    max_output_tokens: u32,
    temperature: f32,
}
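
// With the camelCase renames above, a request serializes to the shape the API
// documents, e.g.:
//
//   {"contents":[{"role":"user","parts":[{"text":"hi"}]}],
//    "generationConfig":{"maxOutputTokens":256,"temperature":0.7}}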

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    // A safety-blocked prompt can come back with no candidates at all.
    #[serde(default)]
    candidates: Vec<Candidate>,
    usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Candidate {
    content: ContentResponse,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ContentResponse {
    parts: Vec<PartResponse>,
    role: String,
}

#[derive(Debug, Deserialize)]
struct PartResponse {
    text: String,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UsageMetadata {
    prompt_token_count: u32,
    candidates_token_count: u32,
    total_token_count: u32,
}

impl GeminiClient {
    pub fn new(config: &AIConfig) -> Result<Self> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()?;
        // `GEMINI_BASE_URL` overrides the endpoint (useful for proxies or
        // mock servers in tests); otherwise use the public Google endpoint.
        let base_url = std::env::var("GEMINI_BASE_URL")
            .ok()
            .filter(|s| !s.is_empty())
            .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string())
            .trim_end_matches('/')
            .to_string();
        Ok(Self {
            client,
            api_key: config.api_key.clone(),
            base_url,
            model: config.model.clone(),
            max_tokens: config.max_tokens,
            temperature: config.temperature,
        })
    }

    /// URL for one-shot completions.
    fn build_url(&self) -> String {
        format!(
            "{}/v1beta/models/{}:generateContent?key={}",
            self.base_url, self.model, self.api_key
        )
    }

    /// URL for streaming completions; `alt=sse` makes the API frame partial
    /// responses as server-sent events rather than one JSON array.
    fn build_stream_url(&self) -> String {
        format!(
            "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            self.base_url, self.model, self.api_key
        )
    }
}
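
// A typical call site, as a sketch (the full field set of `AIConfig` is
// defined elsewhere in the crate; only the fields read in `new` matter here):
//
//   let client = GeminiClient::new(&config)?;
//   let response = client.complete("Summarize this diff", None).await?;
//   println!("{}", response.content);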

#[async_trait::async_trait]
impl AIClient for GeminiClient {
    async fn complete(
        &self,
        prompt: &str,
        _context: Option<&AIContext>,
    ) -> Result<AIResponse> {
        debug!("Sending completion request to Gemini");
        let request = GeminiRequest {
            contents: vec![Content {
                role: "user".to_string(),
                parts: vec![Part {
                    text: prompt.to_string(),
                }],
            }],
            generation_config: GenerationConfig {
                max_output_tokens: self.max_tokens,
                temperature: self.temperature,
            },
        };
        let response = self
            .client
            .post(self.build_url())
            .json(&request)
            .send()
            .await?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Gemini API error {}: {}", status, error_text);
        }
        let completion: GeminiResponse = response.json().await?;
        // Take the first part of the first candidate; only one candidate is
        // requested, so anything beyond it is ignored.
        let content = completion
            .candidates
            .first()
            .and_then(|c| c.content.parts.first())
            .map(|p| p.text.clone())
            .unwrap_or_default();
        let tokens_used = completion
            .usage_metadata
            .as_ref()
            .map(|u| u.total_token_count);
        Ok(AIResponse {
            content,
            tokens_used,
            model: self.model.clone(),
            finish_reason: completion
                .candidates
                .first()
                .and_then(|c| c.finish_reason.clone()),
            cost_estimate: None,
        })
    }

    async fn complete_stream(
        &self,
        prompt: &str,
        _context: Option<&AIContext>,
    ) -> Result<tokio::sync::mpsc::Receiver<Result<String>>> {
        debug!("Sending streaming completion request to Gemini");
        let (tx, rx) = tokio::sync::mpsc::channel(100);
        let url = self.build_stream_url();
        let client = self.client.clone();
        let max_tokens = self.max_tokens;
        let temperature = self.temperature;
        let prompt = prompt.to_string();
        tokio::spawn(async move {
            let request = GeminiRequest {
                contents: vec![Content {
                    role: "user".to_string(),
                    parts: vec![Part { text: prompt }],
                }],
                generation_config: GenerationConfig {
                    max_output_tokens: max_tokens,
                    temperature,
                },
            };
            let response = match client.post(&url).json(&request).send().await {
                Ok(r) => r,
                Err(e) => {
                    let _ = tx.send(Err(e.into())).await;
                    return;
                }
            };
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();
                let _ = tx
                    .send(Err(anyhow::anyhow!("Gemini API error {}: {}", status, body)))
                    .await;
                return;
            }
            // Reassemble SSE events incrementally: network chunks can split
            // an event anywhere, so bytes accumulate in `buf` until the blank
            // line that terminates an event appears.
            let mut stream = response.bytes_stream();
            let mut buf = String::new();
            while let Some(chunk) = stream.next().await {
                let bytes = match chunk {
                    Ok(b) => b,
                    Err(e) => {
                        let _ = tx.send(Err(anyhow::anyhow!("stream error: {}", e))).await;
                        return;
                    }
                };
                // Normalize CRLF to LF so the "\n\n" split below works no
                // matter how the server frames event boundaries.
                buf.push_str(&String::from_utf8_lossy(&bytes).replace('\r', ""));
                while let Some(idx) = buf.find("\n\n") {
                    let raw = buf[..idx].to_string();
                    buf.drain(..idx + 2);
                    for line in raw.lines() {
                        if let Some(data) = line.strip_prefix("data: ") {
                            // Gemini ends the stream by closing the connection;
                            // "[DONE]" is tolerated for OpenAI-style proxies.
                            if data.is_empty() || data == "[DONE]" {
                                continue;
                            }
                            // Each event carries one GeminiResponse chunk;
                            // events that fail to parse are skipped.
                            match serde_json::from_str::<GeminiResponse>(data) {
                                Ok(resp) => {
                                    if let Some(text) = resp
                                        .candidates
                                        .first()
                                        .and_then(|c| c.content.parts.first())
                                        .map(|p| p.text.clone())
                                    {
                                        if !text.is_empty() {
                                            // Receiver gone: stop early.
                                            if tx.send(Ok(text)).await.is_err() {
                                                return;
                                            }
                                        }
                                    }
                                }
                                Err(_) => continue,
                            }
                        }
                    }
                }
            }
        });
        Ok(rx)
    }

    fn name(&self) -> &str {
        "Gemini"
    }

    fn model_info(&self) -> ModelConfig {
        // Static, conservative defaults: actual context windows and pricing
        // vary widely across Gemini models and should be kept per-model.
        ModelConfig {
            name: self.model.clone(),
            context_window: 32768,
            supports_functions: true,
            supports_vision: self.model.contains("vision"),
            cost_per_1k_input: 0.0005,
            cost_per_1k_output: 0.0015,
        }
    }
}
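
// Sanity checks for the wire format. The sample payload below is a
// hand-written sketch mirroring Gemini's documented camelCase response shape,
// not a captured API response.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserializes_camel_case_response() {
        let json = r#"{
            "candidates": [{
                "content": {"parts": [{"text": "hello"}], "role": "model"},
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 3,
                "candidatesTokenCount": 5,
                "totalTokenCount": 8
            }
        }"#;
        let resp: GeminiResponse =
            serde_json::from_str(json).expect("valid response JSON");
        assert_eq!(resp.candidates[0].content.parts[0].text, "hello");
        assert_eq!(resp.candidates[0].finish_reason.as_deref(), Some("STOP"));
        assert_eq!(resp.usage_metadata.unwrap().total_token_count, 8);
    }

    #[test]
    fn serializes_request_as_camel_case() {
        let req = GeminiRequest {
            contents: vec![Content {
                role: "user".to_string(),
                parts: vec![Part { text: "hi".to_string() }],
            }],
            generation_config: GenerationConfig {
                max_output_tokens: 256,
                temperature: 0.7,
            },
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"generationConfig\""));
        assert!(json.contains("\"maxOutputTokens\""));
    }
}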