// auto_commit/api/gemini.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5use super::client::{build_prompt, parse_commit_message, LlmClient};
6use super::provider::Provider;
7
8/// Gemini API request format
9#[derive(Debug, Serialize)]
10struct GeminiRequest {
11    contents: Vec<Content>,
12    #[serde(rename = "generationConfig")]
13    generation_config: GenerationConfig,
14}
15
16#[derive(Debug, Serialize)]
17struct Content {
18    parts: Vec<Part>,
19}
20
21#[derive(Debug, Serialize)]
22struct Part {
23    text: String,
24}
25
26#[derive(Debug, Serialize)]
27struct GenerationConfig {
28    temperature: f32,
29}
30
31/// Gemini API response format
32#[derive(Debug, Deserialize)]
33struct GeminiResponse {
34    candidates: Vec<Candidate>,
35}
36
37#[derive(Debug, Deserialize)]
38struct Candidate {
39    content: ResponseContent,
40}
41
42#[derive(Debug, Deserialize)]
43struct ResponseContent {
44    parts: Vec<ResponsePart>,
45}
46
47#[derive(Debug, Deserialize)]
48struct ResponsePart {
49    text: String,
50}
51
52/// Client for Google Gemini API
53#[derive(Debug, Clone)]
54pub struct GeminiClient {
55    api_key: String,
56    base_url: String,
57    model: String,
58    client: reqwest::Client,
59}
60
61impl GeminiClient {
62    /// Create a new Gemini client
63    pub fn new(api_key: String) -> Self {
64        let provider = Provider::Gemini;
65        Self {
66            api_key,
67            base_url: provider.base_url().to_string(),
68            model: provider.default_model().to_string(),
69            client: reqwest::Client::new(),
70        }
71    }
72
73    /// Create client with custom base URL (useful for testing or custom endpoints)
74    pub fn with_base_url(api_key: String, base_url: String) -> Self {
75        Self {
76            api_key,
77            base_url,
78            model: Provider::Gemini.default_model().to_string(),
79            client: reqwest::Client::new(),
80        }
81    }
82
83    /// Set custom model
84    pub fn with_model(mut self, model: impl Into<String>) -> Self {
85        self.model = model.into();
86        self
87    }
88
89    fn build_system_prompt(&self, template: Option<&str>) -> String {
90        format!(
91            "あなたは経験豊富なソフトウェアエンジニアです。Git diffから適切なコミットメッセージを生成してください。\n\n{}",
92            template.unwrap_or("")
93        )
94    }
95}
96
97#[async_trait]
98impl LlmClient for GeminiClient {
99    async fn generate_commit_message(
100        &self,
101        diff: &str,
102        template: Option<&str>,
103    ) -> Result<(String, String)> {
104        let prompt = build_prompt(diff, template);
105        let system_prompt = self.build_system_prompt(template);
106
107        let request = GeminiRequest {
108            contents: vec![Content {
109                parts: vec![
110                    Part {
111                        text: system_prompt,
112                    },
113                    Part { text: prompt },
114                ],
115            }],
116            generation_config: GenerationConfig { temperature: 0.7 },
117        };
118
119        // Gemini uses query parameter for API key
120        let url = format!(
121            "{}/v1beta/models/{}:generateContent?key={}",
122            self.base_url, self.model, self.api_key
123        );
124
125        let response = self
126            .client
127            .post(&url)
128            .header("Content-Type", "application/json")
129            .json(&request)
130            .send()
131            .await
132            .context("Failed to send request to Gemini API")?;
133
134        if !response.status().is_success() {
135            let status = response.status();
136            let error_text = response.text().await.unwrap_or_default();
137            return Err(anyhow::anyhow!(
138                "Gemini API request failed ({}): {}",
139                status,
140                error_text
141            ));
142        }
143
144        let api_response: GeminiResponse = response
145            .json()
146            .await
147            .context("Failed to parse Gemini API response")?;
148
149        let message = api_response
150            .candidates
151            .first()
152            .context("No candidates in API response")?
153            .content
154            .parts
155            .first()
156            .context("No parts in response content")?
157            .text
158            .trim();
159
160        Ok(parse_commit_message(message))
161    }
162
163    fn provider_name(&self) -> &str {
164        "Gemini"
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{Matcher, Mock, Server, ServerGuard};

    /// Path pattern for any model's `generateContent` endpoint (query string optional).
    const GENERATE_PATH: &str = r"/v1beta/models/.*:generateContent.*";

    /// Registers a JSON mock for the `generateContent` endpoint on `server`.
    /// Keep the returned `Mock` alive for the duration of the test.
    async fn mock_generate(server: &mut ServerGuard, status: usize, body: &str) -> Mock {
        server
            .mock("POST", Matcher::Regex(GENERATE_PATH.to_string()))
            .with_status(status)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await
    }

    #[tokio::test]
    async fn test_gemini_generate_commit_message() {
        let mut server = Server::new_async().await;
        let _m = mock_generate(
            &mut server,
            200,
            r#"{
                "candidates": [{
                    "content": {
                        "parts": [{
                            "text": "feat: Add new feature\n\nImplemented the feature"
                        }]
                    }
                }]
            }"#,
        )
        .await;

        let client = GeminiClient::with_base_url("test-key".into(), server.url());

        let (title, desc) = client
            .generate_commit_message("diff --git", None)
            .await
            .unwrap();

        assert_eq!(title, "feat: Add new feature");
        assert_eq!(desc, "Implemented the feature");
    }

    #[tokio::test]
    async fn test_gemini_api_error() {
        let mut server = Server::new_async().await;
        let _m = mock_generate(
            &mut server,
            400,
            r#"{"error": {"message": "Invalid API key"}}"#,
        )
        .await;

        let client = GeminiClient::with_base_url("bad-key".into(), server.url());

        let result = client.generate_commit_message("diff", None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_gemini_empty_candidates_is_error() {
        let mut server = Server::new_async().await;
        let _m = mock_generate(&mut server, 200, r#"{"candidates": []}"#).await;

        let client = GeminiClient::with_base_url("test-key".into(), server.url());

        let result = client.generate_commit_message("diff", None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_gemini_blocked_prompt_is_error() {
        let mut server = Server::new_async().await;
        let _m = mock_generate(
            &mut server,
            200,
            r#"{"promptFeedback": {"blockReason": "SAFETY"}}"#,
        )
        .await;

        let client = GeminiClient::with_base_url("test-key".into(), server.url());

        let result = client.generate_commit_message("diff", None).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_name() {
        let client = GeminiClient::new("key".into());
        assert_eq!(client.provider_name(), "Gemini");
    }

    #[test]
    fn test_with_model() {
        let client = GeminiClient::new("key".into()).with_model("gemini-1.5-pro");
        assert_eq!(client.model, "gemini-1.5-pro");
    }
}