1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5use super::client::{build_prompt, parse_commit_message, LlmClient};
6use super::provider::Provider;
7
8#[derive(Debug, Serialize)]
9struct ChatRequest {
10 model: String,
11 messages: Vec<ChatMessage>,
12 temperature: f32,
13}
14
15#[derive(Debug, Serialize, Deserialize)]
16struct ChatMessage {
17 role: String,
18 content: String,
19}
20
21#[derive(Debug, Deserialize)]
22struct ChatResponse {
23 choices: Vec<Choice>,
24}
25
26#[derive(Debug, Deserialize)]
27struct Choice {
28 message: ChatMessage,
29}
30
31#[derive(Debug, Clone)]
33pub struct OpenAiCompatibleClient {
34 api_key: String,
35 base_url: String,
36 model: String,
37 provider: Provider,
38 client: reqwest::Client,
39}
40
41impl OpenAiCompatibleClient {
42 pub fn new(provider: Provider, api_key: String) -> Self {
44 Self {
45 api_key,
46 base_url: provider.base_url().to_string(),
47 model: provider.default_model().to_string(),
48 provider,
49 client: reqwest::Client::new(),
50 }
51 }
52
53 pub fn with_base_url(provider: Provider, api_key: String, base_url: String) -> Self {
55 Self {
56 api_key,
57 base_url,
58 model: provider.default_model().to_string(),
59 provider,
60 client: reqwest::Client::new(),
61 }
62 }
63
64 pub fn with_model(mut self, model: impl Into<String>) -> Self {
66 self.model = model.into();
67 self
68 }
69}
70
71#[async_trait]
72impl LlmClient for OpenAiCompatibleClient {
73 async fn generate_commit_message(
74 &self,
75 diff: &str,
76 template: Option<&str>,
77 ) -> Result<(String, String)> {
78 let prompt = build_prompt(diff, template);
79
80 let request = ChatRequest {
81 model: self.model.clone(),
82 messages: vec![
83 ChatMessage {
84 role: "system".to_string(),
85 content: "あなたは経験豊富なソフトウェアエンジニアです。Git diffから適切なコミットメッセージを生成してください。".to_string(),
86 },
87 ChatMessage {
88 role: "user".to_string(),
89 content: prompt,
90 },
91 ],
92 temperature: 0.7,
93 };
94
95 let response = self
96 .client
97 .post(format!("{}/v1/chat/completions", self.base_url))
98 .header("Authorization", format!("Bearer {}", self.api_key))
99 .header("Content-Type", "application/json")
100 .json(&request)
101 .send()
102 .await
103 .with_context(|| format!("Failed to send request to {} API", self.provider))?;
104
105 if !response.status().is_success() {
106 let status = response.status();
107 let error_text = response.text().await.unwrap_or_default();
108 return Err(anyhow::anyhow!(
109 "{} API request failed ({}): {}",
110 self.provider,
111 status,
112 error_text
113 ));
114 }
115
116 let api_response: ChatResponse = response
117 .json()
118 .await
119 .with_context(|| format!("Failed to parse {} API response", self.provider))?;
120
121 let message = api_response
122 .choices
123 .first()
124 .context("No choices in API response")?
125 .message
126 .content
127 .trim();
128
129 Ok(parse_commit_message(message))
130 }
131
132 fn provider_name(&self) -> &str {
133 match self.provider {
134 Provider::OpenAi => "OpenAI",
135 Provider::DeepSeek => "DeepSeek",
136 _ => "Unknown",
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    #[tokio::test]
    async fn test_openai_generate_commit_message() {
        let mut server = Server::new_async().await;
        // Only answer if the API key is sent as a bearer token.
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .match_header("authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "feat: Add user authentication\n\nImplemented JWT-based auth"
                    }
                }]
            }"#,
            )
            .create_async()
            .await;

        let client =
            OpenAiCompatibleClient::with_base_url(Provider::OpenAi, "test-key".into(), server.url());

        let (title, desc) = client
            .generate_commit_message("diff --git", None)
            .await
            .unwrap();

        assert_eq!(title, "feat: Add user authentication");
        assert_eq!(desc, "Implemented JWT-based auth");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_deepseek_generate_commit_message() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .match_header("authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "fix: Resolve memory leak"
                    }
                }]
            }"#,
            )
            .create_async()
            .await;

        let client = OpenAiCompatibleClient::with_base_url(
            Provider::DeepSeek,
            "test-key".into(),
            server.url(),
        );

        let (title, desc) = client
            .generate_commit_message("diff --git", None)
            .await
            .unwrap();

        assert_eq!(title, "fix: Resolve memory leak");
        assert_eq!(desc, "");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_api_error_handling() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .match_header("authorization", "Bearer bad-key")
            .with_status(401)
            .with_body(r#"{"error": "Invalid API key"}"#)
            .create_async()
            .await;

        let client =
            OpenAiCompatibleClient::with_base_url(Provider::OpenAi, "bad-key".into(), server.url());

        let result = client.generate_commit_message("diff", None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        // The status code must be surfaced; "API" alone would match any error.
        assert!(err.contains("401"), "unexpected error message: {err}");
        mock.assert_async().await;
    }

    #[test]
    fn test_provider_name() {
        let openai = OpenAiCompatibleClient::new(Provider::OpenAi, "key".into());
        assert_eq!(openai.provider_name(), "OpenAI");

        let deepseek = OpenAiCompatibleClient::new(Provider::DeepSeek, "key".into());
        assert_eq!(deepseek.provider_name(), "DeepSeek");
    }

    #[test]
    fn test_with_model() {
        let client = OpenAiCompatibleClient::new(Provider::OpenAi, "key".into())
            .with_model("gpt-4-turbo");
        assert_eq!(client.model, "gpt-4-turbo");
    }
}