auto_commit/api/
gemini.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5use super::client::{build_prompt, parse_commit_message, LlmClient};
6use super::provider::Provider;
7
8#[derive(Debug, Serialize)]
10struct GeminiRequest {
11 contents: Vec<Content>,
12 #[serde(rename = "generationConfig")]
13 generation_config: GenerationConfig,
14}
15
16#[derive(Debug, Serialize)]
17struct Content {
18 parts: Vec<Part>,
19}
20
21#[derive(Debug, Serialize)]
22struct Part {
23 text: String,
24}
25
26#[derive(Debug, Serialize)]
27struct GenerationConfig {
28 temperature: f32,
29}
30
31#[derive(Debug, Deserialize)]
33struct GeminiResponse {
34 candidates: Vec<Candidate>,
35}
36
37#[derive(Debug, Deserialize)]
38struct Candidate {
39 content: ResponseContent,
40}
41
42#[derive(Debug, Deserialize)]
43struct ResponseContent {
44 parts: Vec<ResponsePart>,
45}
46
47#[derive(Debug, Deserialize)]
48struct ResponsePart {
49 text: String,
50}
51
52#[derive(Debug, Clone)]
54pub struct GeminiClient {
55 api_key: String,
56 base_url: String,
57 model: String,
58 client: reqwest::Client,
59}
60
61impl GeminiClient {
62 pub fn new(api_key: String) -> Self {
64 let provider = Provider::Gemini;
65 Self {
66 api_key,
67 base_url: provider.base_url().to_string(),
68 model: provider.default_model().to_string(),
69 client: reqwest::Client::new(),
70 }
71 }
72
73 pub fn with_base_url(api_key: String, base_url: String) -> Self {
75 Self {
76 api_key,
77 base_url,
78 model: Provider::Gemini.default_model().to_string(),
79 client: reqwest::Client::new(),
80 }
81 }
82
83 pub fn with_model(mut self, model: impl Into<String>) -> Self {
85 self.model = model.into();
86 self
87 }
88
89 fn build_system_prompt(&self, template: Option<&str>) -> String {
90 format!(
91 "あなたは経験豊富なソフトウェアエンジニアです。Git diffから適切なコミットメッセージを生成してください。\n\n{}",
92 template.unwrap_or("")
93 )
94 }
95}
96
97#[async_trait]
98impl LlmClient for GeminiClient {
99 async fn generate_commit_message(
100 &self,
101 diff: &str,
102 template: Option<&str>,
103 ) -> Result<(String, String)> {
104 let prompt = build_prompt(diff, template);
105 let system_prompt = self.build_system_prompt(template);
106
107 let request = GeminiRequest {
108 contents: vec![Content {
109 parts: vec![
110 Part {
111 text: system_prompt,
112 },
113 Part { text: prompt },
114 ],
115 }],
116 generation_config: GenerationConfig { temperature: 0.7 },
117 };
118
119 let url = format!(
121 "{}/v1beta/models/{}:generateContent?key={}",
122 self.base_url, self.model, self.api_key
123 );
124
125 let response = self
126 .client
127 .post(&url)
128 .header("Content-Type", "application/json")
129 .json(&request)
130 .send()
131 .await
132 .context("Failed to send request to Gemini API")?;
133
134 if !response.status().is_success() {
135 let status = response.status();
136 let error_text = response.text().await.unwrap_or_default();
137 return Err(anyhow::anyhow!(
138 "Gemini API request failed ({}): {}",
139 status,
140 error_text
141 ));
142 }
143
144 let api_response: GeminiResponse = response
145 .json()
146 .await
147 .context("Failed to parse Gemini API response")?;
148
149 let message = api_response
150 .candidates
151 .first()
152 .context("No candidates in API response")?
153 .content
154 .parts
155 .first()
156 .context("No parts in response content")?
157 .text
158 .trim();
159
160 Ok(parse_commit_message(message))
161 }
162
163 fn provider_name(&self) -> &str {
164 "Gemini"
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{Matcher, Server};

    /// Path matcher shared by the HTTP tests: any `generateContent` call,
    /// whatever the model name and whatever follows the method name.
    fn generate_content_path() -> Matcher {
        Matcher::Regex(r"/v1beta/models/.*:generateContent.*".to_string())
    }

    /// A successful response is split into title and description.
    #[tokio::test]
    async fn test_gemini_generate_commit_message() {
        let success_body = r#"{
            "candidates": [{
                "content": {
                    "parts": [{
                        "text": "feat: Add new feature\n\nImplemented the feature"
                    }]
                }
            }]
        }"#;

        let mut server = Server::new_async().await;
        // Keep the guard alive until the end of the test so the mock stays registered.
        let _mock = server
            .mock("POST", generate_content_path())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(success_body)
            .create_async()
            .await;

        let client = GeminiClient::with_base_url("test-key".into(), server.url());
        let outcome = client.generate_commit_message("diff --git", None).await;
        let (title, desc) = outcome.unwrap();

        assert_eq!(title, "feat: Add new feature");
        assert_eq!(desc, "Implemented the feature");
    }

    /// A non-success HTTP status is surfaced as an error.
    #[tokio::test]
    async fn test_gemini_api_error() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", generate_content_path())
            .with_status(400)
            .with_body(r#"{"error": {"message": "Invalid API key"}}"#)
            .create_async()
            .await;

        let client = GeminiClient::with_base_url("bad-key".into(), server.url());
        let outcome = client.generate_commit_message("diff", None).await;

        assert!(outcome.is_err());
    }

    /// The provider name is the fixed string "Gemini".
    #[test]
    fn test_provider_name() {
        let name_source = GeminiClient::new("key".into());
        assert_eq!(name_source.provider_name(), "Gemini");
    }

    /// `with_model` overrides the default model.
    #[test]
    fn test_with_model() {
        let customised = GeminiClient::new("key".into()).with_model("gemini-1.5-pro");
        assert_eq!(customised.model, "gemini-1.5-pro");
    }
}