use crate::api_key::ApiKeyProvider;
use crate::error::{Error, Result};
use crate::quota::QuotaTracker;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
/// Base URL of the Gemini REST API (v1beta); `/models/<model>:generateContent`
/// is appended to it when sending requests.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Model used by `GeminiClient::new` unless overridden with `GeminiClient::with_model`.
const DEFAULT_MODEL: &str = "gemini-2.0-flash";
/// Client for the Gemini `generateContent` API.
///
/// The API key is fetched from `api_key_provider` on every request, and the
/// shared `quota_tracker` is checked before and updated after each request.
pub struct GeminiClient {
    /// HTTP client used for all requests (built with a 120-second timeout in `new`).
    client: Client,
    /// Source of the API key; queried once per `generate_content` call.
    api_key_provider: Arc<dyn ApiKeyProvider>,
    /// Quota bookkeeping shared with other users of the tracker.
    quota_tracker: Arc<QuotaTracker>,
    /// Name of the model to call, interpolated into the request URL.
    model: String,
}
impl GeminiClient {
    /// Maximum number of files whose contents `summarize_directory` inlines
    /// into its prompt, keeping the request size bounded.
    const MAX_SUMMARY_FILES: usize = 10;

    /// Creates a client for [`DEFAULT_MODEL`] with a 120-second request timeout.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` HTTP client cannot be constructed.
    pub fn new(
        api_key_provider: Arc<dyn ApiKeyProvider>,
        quota_tracker: Arc<QuotaTracker>,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to build HTTP client");
        Self {
            client,
            api_key_provider,
            quota_tracker,
            model: DEFAULT_MODEL.to_string(),
        }
    }

    /// Returns this client reconfigured to call `model` instead of the default.
    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    /// Sends `prompt` as a single text message to the configured model and
    /// returns the generated reply.
    ///
    /// The quota tracker is consulted first (its warning message is logged
    /// when it flags one), and the token count reported by the API is
    /// recorded with the tracker afterwards.
    ///
    /// `_file_content` is currently ignored; callers must inline any file
    /// text into `prompt` themselves.
    ///
    /// # Errors
    ///
    /// Fails if the quota check fails, the API key cannot be obtained, the
    /// HTTP request or response decoding fails, or the API answers with a
    /// non-success status (reported as `Error::GeminiApiError`).
    pub async fn generate_content(
        &self,
        prompt: &str,
        _file_content: Option<String>,
    ) -> Result<GenerateResponse> {
        let quota_status = self.quota_tracker.check_quota()?;
        if quota_status.warning {
            warn!("{}", quota_status.format_message());
        }

        let api_key = self.api_key_provider.get_key().await?;
        let api_key_str = api_key.as_str();
        // Only the length is logged: even a prefix of a secret must not reach
        // the logs (and byte-slicing it could panic on a non-ASCII key).
        debug!("API key retrieved, length: {}", api_key_str.len());

        let url = format!("{}/models/{}:generateContent", GEMINI_API_BASE, self.model);
        let request_body = GenerateRequest {
            contents: vec![Content {
                parts: vec![Part {
                    text: prompt.to_string(),
                }],
            }],
        };

        debug!("Sending request to Gemini API with model {}", self.model);
        debug!("Request URL: {}", url);
        // The key goes in a header, not the query string, so the URL logged
        // above never contains it.
        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", api_key_str)
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Best effort: an unreadable body must not hide the status code.
            let error_text = response.text().await.unwrap_or_default();
            return Err(Error::GeminiApiError(format!(
                "API request failed with status {}: {}",
                status, error_text
            )));
        }

        let gemini_response: GeminiResponse = response.json().await?;

        // Missing usage metadata is reported as zero tokens.
        let (prompt_tokens, candidates_tokens, total_tokens) = gemini_response
            .usage_metadata
            .as_ref()
            .map_or((0, 0, 0), |u| {
                (
                    u.prompt_token_count,
                    u.candidates_token_count,
                    u.total_token_count,
                )
            });
        self.quota_tracker.record_request(total_tokens);
        info!(
            "Gemini API request successful. Tokens: prompt={}, candidates={}, total={}",
            prompt_tokens, candidates_tokens, total_tokens
        );

        if gemini_response.candidates.is_empty() {
            warn!("Gemini API returned no candidates; response text will be empty");
        }

        // A reply may be split across several parts; join all of them instead
        // of keeping only the first.
        let text: String = gemini_response
            .candidates
            .first()
            .map(|c| {
                c.content
                    .parts
                    .iter()
                    .map(|p| p.text.as_str())
                    .collect::<String>()
            })
            .unwrap_or_default();

        Ok(GenerateResponse {
            text,
            quota_status: self.quota_tracker.get_status(),
            prompt_tokens,
            total_tokens,
        })
    }

    /// Asks the model to locate sections of `files_content` relevant to `query`.
    ///
    /// Every file is inlined into the prompt, and the model is asked for file
    /// paths, line numbers, explanations and code snippets.
    ///
    /// # Errors
    ///
    /// Same as [`GeminiClient::generate_content`].
    pub async fn search_codebase(
        &self,
        query: &str,
        files_content: &[FileContent],
    ) -> Result<GenerateResponse> {
        let mut prompt = format!(
            "You are a code search assistant. Search through the following codebase files and find relevant sections matching this query: {}\n\n",
            query
        );
        prompt.push_str("Codebase files:\n\n");
        Self::push_file_sections(&mut prompt, files_content);
        prompt.push_str("\nProvide a structured response with:\n");
        prompt.push_str("1. Relevant file paths and line numbers\n");
        prompt.push_str("2. Brief explanation of why each result matches\n");
        prompt.push_str("3. Code snippets showing the relevant sections\n");
        self.generate_content(&prompt, None).await
    }

    /// Asks the model to answer `question` about the given files, all of
    /// which are inlined into the prompt.
    ///
    /// # Errors
    ///
    /// Same as [`GeminiClient::generate_content`].
    pub async fn analyze_files(
        &self,
        files_content: &[FileContent],
        question: &str,
    ) -> Result<GenerateResponse> {
        let mut prompt = format!(
            "Analyze the following code files and answer this question: {}\n\n",
            question
        );
        prompt.push_str("Files to analyze:\n\n");
        Self::push_file_sections(&mut prompt, files_content);
        self.generate_content(&prompt, None).await
    }

    /// Asks the model to answer `question` using only the supplied `context` text.
    ///
    /// # Errors
    ///
    /// Same as [`GeminiClient::generate_content`].
    pub async fn ask_about_code(&self, context: &str, question: &str) -> Result<GenerateResponse> {
        let prompt = format!(
            "Context from codebase:\n{}\n\nQuestion: {}\n\nProvide a clear, concise answer based on the context provided.",
            context, question
        );
        self.generate_content(&prompt, None).await
    }

    /// Asks the model to summarize a directory from its structure listing and
    /// file contents.
    ///
    /// Only the first [`Self::MAX_SUMMARY_FILES`] entries of `files_content`
    /// are included; the rest are silently dropped.
    ///
    /// # Errors
    ///
    /// Same as [`GeminiClient::generate_content`].
    pub async fn summarize_directory(
        &self,
        directory_structure: &str,
        files_content: &[FileContent],
    ) -> Result<GenerateResponse> {
        let mut prompt = format!(
            "Summarize this directory and its code:\n\nDirectory structure:\n{}\n\n",
            directory_structure
        );
        prompt.push_str("Key files:\n\n");
        Self::push_file_sections(
            &mut prompt,
            files_content.iter().take(Self::MAX_SUMMARY_FILES),
        );
        prompt.push_str("\nProvide:\n");
        prompt.push_str("1. Overall purpose and architecture\n");
        prompt.push_str("2. Key components and their relationships\n");
        prompt.push_str("3. Notable patterns or technologies used\n");
        prompt.push_str("4. Entry points and main functionality\n");
        self.generate_content(&prompt, None).await
    }

    /// Appends each file to `prompt` as a `File: <path>` line followed by its
    /// contents in a fenced code block and a blank line.
    fn push_file_sections<'a>(
        prompt: &mut String,
        files: impl IntoIterator<Item = &'a FileContent>,
    ) {
        use std::fmt::Write as _;
        for file in files {
            // Writing into a `String` is infallible, so the result is ignored.
            let _ = writeln!(prompt, "File: {}\n```\n{}\n```\n", file.path, file.content);
        }
    }
}
/// JSON body sent to the `generateContent` endpoint.
#[derive(Debug, Serialize)]
struct GenerateRequest {
    /// Messages making up the request; `GeminiClient` sends exactly one.
    contents: Vec<Content>,
}
/// A single message in a request, composed of parts.
#[derive(Debug, Serialize)]
struct Content {
    /// Parts of the message; `GeminiClient` sends a single text part.
    parts: Vec<Part>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Part {
text: String,
}
/// Top-level body of a `generateContent` response (only the fields this client reads).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    /// Generated candidates. Decoded as empty when the field is absent, which
    /// the API does when it produces no candidate (e.g. a blocked prompt);
    /// previously that case was a hard decode error.
    #[serde(default)]
    candidates: Vec<Candidate>,
    /// Token accounting; `None` when the API omits it.
    usage_metadata: Option<UsageMetadata>,
}

/// One generated candidate reply.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Reply content. Decoded as empty when absent (e.g. a candidate stopped
    /// for safety reasons) rather than failing the whole response.
    #[serde(default)]
    content: ContentResponse,
}

/// Content of a candidate reply.
#[derive(Debug, Default, Deserialize)]
struct ContentResponse {
    /// Reply fragments, in order. Decoded as empty when the API sends
    /// content without parts.
    #[serde(default)]
    parts: Vec<Part>,
}
/// Token usage reported by the API for one request.
///
/// Every counter defaults to 0 when missing. Zero-valued counters may be
/// omitted from the JSON, and before this change a single missing counter
/// made the whole response fail to decode.
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase", default)]
struct UsageMetadata {
    /// Tokens consumed by the prompt.
    prompt_token_count: usize,
    /// Tokens in the generated candidates.
    candidates_token_count: usize,
    /// Total tokens billed for the request.
    total_token_count: usize,
}
/// Result of a successful `GeminiClient` request.
#[derive(Debug)]
pub struct GenerateResponse {
    /// Text generated by the model; empty when no candidate text was returned.
    pub text: String,
    /// Quota status read from the tracker after this request was recorded.
    pub quota_status: crate::quota::QuotaStatus,
    /// Prompt tokens reported by the API (0 if usage metadata was absent).
    pub prompt_tokens: usize,
    /// Total tokens reported by the API (0 if usage metadata was absent).
    pub total_tokens: usize,
}
/// A file handed to the model: its path together with its full text.
#[derive(Debug, Clone)]
pub struct FileContent {
    /// File path, used as the label for this file in generated prompts.
    pub path: String,
    /// Complete text of the file.
    pub content: String,
}

impl FileContent {
    /// Pairs `path` with its `content`.
    pub fn new(path: String, content: String) -> Self {
        FileContent { content, path }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_key::{ApiKeyProvider, SecureString};

    /// Key provider that always yields the same fake key, so tests need no
    /// real credentials.
    struct MockApiKeyProvider;

    #[async_trait::async_trait]
    impl ApiKeyProvider for MockApiKeyProvider {
        async fn get_key(&self) -> Result<SecureString> {
            Ok(SecureString::new(
                "AIzaSyDMockKey1234567890123456789".to_string(),
            ))
        }
    }

    /// Builds a client backed by the mock key provider and a fresh quota tracker.
    fn test_client() -> GeminiClient {
        let api_key_provider = Arc::new(MockApiKeyProvider) as Arc<dyn ApiKeyProvider>;
        let quota_tracker = Arc::new(QuotaTracker::new());
        GeminiClient::new(api_key_provider, quota_tracker)
    }

    #[test]
    fn test_file_content_creation() {
        let file = FileContent::new("src/main.rs".to_string(), "fn main() {}".to_string());
        assert_eq!(file.path, "src/main.rs");
        assert_eq!(file.content, "fn main() {}");
    }

    // Construction performs no I/O, so no async runtime is needed.
    #[test]
    fn test_gemini_client_creation() {
        let _client = test_client();
    }

    #[test]
    fn test_default_model() {
        let client = test_client();
        assert_eq!(client.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_with_model() {
        let client = test_client().with_model("gemini-2.5-pro".to_string());
        assert_eq!(client.model, "gemini-2.5-pro");
    }
}