dynamic_grounding_for_github_copilot 0.1.0

MCP server providing Google Gemini AI integration for enhanced codebase search and analysis
Documentation
use crate::api_key::ApiKeyProvider;
use crate::error::{Error, Result};
use crate::quota::QuotaTracker;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};

/// Base URL of the Gemini REST API (`v1beta` surface); model paths are appended to it.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Model used by [`GeminiClient::new`] until overridden with [`GeminiClient::with_model`].
const DEFAULT_MODEL: &str = "gemini-2.0-flash";

/// Client for the Gemini `generateContent` endpoint.
///
/// Every request is checked against, and afterwards recorded in, a shared
/// [`QuotaTracker`].
pub struct GeminiClient {
    /// HTTP client reused for every request (built in `new` with a 120 s timeout).
    client: Client,
    /// Source of the API key; queried once per request.
    api_key_provider: Arc<dyn ApiKeyProvider>,
    /// Checked before each request and updated after a successful one.
    quota_tracker: Arc<QuotaTracker>,
    /// Model name inserted into the `models/{model}:generateContent` URL.
    model: String,
}

impl GeminiClient {
    /// Maximum number of files embedded by [`GeminiClient::summarize_directory`].
    const MAX_SUMMARY_FILES: usize = 10;

    /// Creates a client for [`DEFAULT_MODEL`] with a 120-second request timeout.
    ///
    /// # Panics
    ///
    /// Panics if `reqwest::ClientBuilder::build` returns an error.
    pub fn new(
        api_key_provider: Arc<dyn ApiKeyProvider>,
        quota_tracker: Arc<QuotaTracker>,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to build HTTP client");

        Self {
            client,
            api_key_provider,
            quota_tracker,
            model: DEFAULT_MODEL.to_string(),
        }
    }

    /// Replaces the model name used in the request URL.
    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    /// Sends `prompt` to the `generateContent` endpoint of the configured model.
    ///
    /// The quota tracker is consulted before the request (its message is logged
    /// as a warning when it flags one). After a successful response, the tracker
    /// records the reported total token count.
    ///
    /// The returned text is all text parts of the first candidate joined
    /// together. It is empty when the response has no candidate.
    ///
    /// `_file_content` is currently ignored. The helpers in this impl embed
    /// file contents directly in `prompt` instead.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the quota check fails;
    /// - the API key cannot be obtained;
    /// - the HTTP request fails;
    /// - the API answers with a non-success status ([`Error::GeminiApiError`]);
    /// - the response body cannot be decoded.
    pub async fn generate_content(
        &self,
        prompt: &str,
        _file_content: Option<String>,
    ) -> Result<GenerateResponse> {
        let quota_status = self.quota_tracker.check_quota()?;
        if quota_status.warning {
            warn!("{}", quota_status.format_message());
        }

        let api_key = self.api_key_provider.get_key().await?;
        let api_key_str = api_key.as_str();
        // Log only the length. Even a key prefix is secret material that must
        // not reach log files, and slicing `[0..4]` panics on a non-ASCII key.
        debug!("API key retrieved, length: {}", api_key_str.len());

        let url = format!("{}/models/{}:generateContent", GEMINI_API_BASE, self.model);
        let request_body = GenerateRequest {
            contents: vec![Content {
                parts: vec![Part {
                    text: prompt.to_string(),
                }],
            }],
        };

        debug!("Sending request to Gemini API with model {}", self.model);
        // The key is sent as a header, so the logged URL carries no secret.
        debug!("Request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", api_key_str)
            .json(&request_body)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(Error::GeminiApiError(format!(
                "API request failed with status {}: {}",
                status, error_text
            )));
        }

        let gemini_response: GeminiResponse = response.json().await?;

        // A missing `usageMetadata` block is treated as zero tokens.
        let (prompt_tokens, candidates_tokens, total_tokens) = gemini_response
            .usage_metadata
            .as_ref()
            .map_or((0, 0, 0), |usage| {
                (
                    usage.prompt_token_count,
                    usage.candidates_token_count,
                    usage.total_token_count,
                )
            });

        self.quota_tracker.record_request(total_tokens);

        info!(
            "Gemini API request successful. Tokens: prompt={}, candidates={}, total={}",
            prompt_tokens, candidates_tokens, total_tokens
        );

        // A candidate may split its answer across several parts. Join them all
        // instead of keeping only the first.
        let text: String = gemini_response
            .candidates
            .into_iter()
            .next()
            .map(|candidate| {
                candidate
                    .content
                    .parts
                    .into_iter()
                    .map(|part| part.text)
                    .collect::<String>()
            })
            .unwrap_or_default();

        Ok(GenerateResponse {
            text,
            quota_status: self.quota_tracker.get_status(),
            prompt_tokens,
            total_tokens,
        })
    }

    /// Asks the model to locate code in `files_content` matching a
    /// natural-language `query`.
    ///
    /// Every file is embedded in the prompt. The model is asked for file paths,
    /// line numbers, explanations and snippets.
    ///
    /// NOTE(review): file contents are sent without line numbers, so any line
    /// numbers in the answer are the model's own count. Consider numbering lines.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`GeminiClient::generate_content`].
    pub async fn search_codebase(
        &self,
        query: &str,
        files_content: &[FileContent],
    ) -> Result<GenerateResponse> {
        let mut prompt = format!(
            "You are a code search assistant. Search through the following codebase files and find relevant sections matching this query: {}\n\n",
            query
        );

        prompt.push_str("Codebase files:\n\n");
        Self::append_file_sections(&mut prompt, files_content);

        prompt.push_str("\nProvide a structured response with:\n");
        prompt.push_str("1. Relevant file paths and line numbers\n");
        prompt.push_str("2. Brief explanation of why each result matches\n");
        prompt.push_str("3. Code snippets showing the relevant sections\n");

        self.generate_content(&prompt, None).await
    }

    /// Asks the model to answer `question` about every file in `files_content`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`GeminiClient::generate_content`].
    pub async fn analyze_files(
        &self,
        files_content: &[FileContent],
        question: &str,
    ) -> Result<GenerateResponse> {
        let mut prompt = format!(
            "Analyze the following code files and answer this question: {}\n\n",
            question
        );

        prompt.push_str("Files to analyze:\n\n");
        Self::append_file_sections(&mut prompt, files_content);

        self.generate_content(&prompt, None).await
    }

    /// Asks the model to answer `question` using the free-form `context` text.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`GeminiClient::generate_content`].
    pub async fn ask_about_code(&self, context: &str, question: &str) -> Result<GenerateResponse> {
        let prompt = format!(
            "Context from codebase:\n{}\n\nQuestion: {}\n\nProvide a clear, concise answer based on the context provided.",
            context, question
        );

        self.generate_content(&prompt, None).await
    }

    /// Asks the model to summarize a directory from its structure listing and
    /// the first [`Self::MAX_SUMMARY_FILES`] entries of `files_content`.
    ///
    /// Files beyond that limit are silently left out of the prompt.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`GeminiClient::generate_content`].
    pub async fn summarize_directory(
        &self,
        directory_structure: &str,
        files_content: &[FileContent],
    ) -> Result<GenerateResponse> {
        let mut prompt = format!(
            "Summarize this directory and its code:\n\nDirectory structure:\n{}\n\n",
            directory_structure
        );

        prompt.push_str("Key files:\n\n");
        Self::append_file_sections(
            &mut prompt,
            files_content.iter().take(Self::MAX_SUMMARY_FILES),
        );

        prompt.push_str("\nProvide:\n");
        prompt.push_str("1. Overall purpose and architecture\n");
        prompt.push_str("2. Key components and their relationships\n");
        prompt.push_str("3. Notable patterns or technologies used\n");
        prompt.push_str("4. Entry points and main functionality\n");

        self.generate_content(&prompt, None).await
    }

    /// Appends one section per file to `prompt`, using the layout
    /// "File: <path>\n```\n<content>\n```\n\n".
    fn append_file_sections<'a>(
        prompt: &mut String,
        files: impl IntoIterator<Item = &'a FileContent>,
    ) {
        for file in files {
            prompt.push_str("File: ");
            prompt.push_str(&file.path);
            prompt.push_str("\n```\n");
            prompt.push_str(&file.content);
            prompt.push_str("\n```\n\n");
        }
    }
}

/// JSON body of a `generateContent` request.
#[derive(Debug, Serialize)]
struct GenerateRequest {
    /// Request turns; `generate_content` always sends exactly one.
    contents: Vec<Content>,
}

/// A single request turn made up of parts.
#[derive(Debug, Serialize)]
struct Content {
    /// Parts of the turn; `generate_content` sends one text part (the prompt).
    parts: Vec<Part>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Part {
    text: String,
}

/// Fields of the `generateContent` response body that this client reads.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    /// Generated candidates; only the first one is used.
    /// NOTE(review): this field is required, so a response without
    /// `candidates` (presumably a blocked prompt) fails to decode.
    /// Confirm the desired handling.
    candidates: Vec<Candidate>,
    /// Token usage (`usageMetadata`); counts are treated as zero when absent.
    usage_metadata: Option<UsageMetadata>,
}

/// One generated answer in a response.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Generated content of this candidate.
    /// NOTE(review): required, so a candidate returned without `content`
    /// (possibly when generation is stopped early) fails to decode. Confirm.
    content: ContentResponse,
}

#[derive(Debug, Deserialize)]
struct ContentResponse {
    parts: Vec<Part>,
}

/// Token accounting reported by the API in `usageMetadata`.
///
/// Each count defaults to 0 when absent. Proto3 JSON encoding omits
/// zero-valued fields (e.g. `candidatesTokenCount` for an empty answer), and
/// a missing required field would make the whole response fail to decode.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UsageMetadata {
    /// Tokens in the prompt (`promptTokenCount`).
    #[serde(default)]
    prompt_token_count: usize,
    /// Tokens in the generated candidates (`candidatesTokenCount`).
    #[serde(default)]
    candidates_token_count: usize,
    /// Total token count reported for the request (`totalTokenCount`).
    #[serde(default)]
    total_token_count: usize,
}

/// Result of a successful [`GeminiClient::generate_content`] call.
///
/// Includes the quota status read right after the request was recorded.
#[derive(Debug)]
pub struct GenerateResponse {
    /// Generated text from the first candidate; empty when there was none.
    pub text: String,
    /// Quota status from `QuotaTracker::get_status` after recording this request.
    pub quota_status: crate::quota::QuotaStatus,
    /// Prompt token count reported by the API (0 when not reported).
    pub prompt_tokens: usize,
    /// Total token count reported by the API (0 when not reported).
    pub total_tokens: usize,
}

/// A file's path and text, embedded into prompts by the `GeminiClient` helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileContent {
    /// Path shown to the model in the `File:` header.
    pub path: String,
    /// Full text of the file.
    pub content: String,
}

impl FileContent {
    /// Creates a `FileContent` from anything convertible into `String`.
    ///
    /// Accepts owned `String`s as before, and also `&str` without an explicit
    /// `.to_string()` at the call site.
    pub fn new(path: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            content: content.into(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_key::{ApiKeyProvider, SecureString};

    /// Key provider returning a fixed fake key; no secret store is touched.
    struct MockApiKeyProvider;

    #[async_trait::async_trait]
    impl ApiKeyProvider for MockApiKeyProvider {
        async fn get_key(&self) -> Result<SecureString> {
            Ok(SecureString::new(
                "AIzaSyDMockKey1234567890123456789".to_string(),
            ))
        }
    }

    /// Builds a client backed by the mock key provider and a fresh quota tracker.
    fn mock_client() -> GeminiClient {
        let api_key_provider = Arc::new(MockApiKeyProvider) as Arc<dyn ApiKeyProvider>;
        GeminiClient::new(api_key_provider, Arc::new(QuotaTracker::new()))
    }

    #[test]
    fn test_file_content_creation() {
        let file = FileContent::new("src/main.rs".to_string(), "fn main() {}".to_string());
        assert_eq!(file.path, "src/main.rs");
        assert_eq!(file.content, "fn main() {}");
    }

    // Construction is synchronous; no async runtime is needed.
    #[test]
    fn test_gemini_client_creation() {
        let _client = mock_client();
    }

    #[test]
    fn test_default_model() {
        assert_eq!(mock_client().model, DEFAULT_MODEL);
    }

    #[test]
    fn test_with_model() {
        let client = mock_client().with_model("gemini-2.5-pro".to_string());
        assert_eq!(client.model, "gemini-2.5-pro");
    }
}