// auto_commit/api/openai_compatible.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5use super::client::{build_prompt, parse_commit_message, LlmClient};
6use super::provider::Provider;
7
/// Request body for the OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Model identifier sent to the API (provider default unless overridden).
    model: String,
    // Conversation messages in order: system prompt first, then user prompt.
    messages: Vec<ChatMessage>,
    // Sampling temperature; this client always sends 0.7.
    temperature: f32,
}
14
/// A single chat message; serialized in requests and deserialized from responses.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    // Author role: "system" or "user" on send; "assistant" in responses.
    role: String,
    // Message text.
    content: String,
}
20
/// Top-level chat-completions response (only the fields this client reads).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    // Completion candidates; only the first entry is used.
    choices: Vec<Choice>,
}
25
/// One completion candidate within a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct Choice {
    // The assistant-generated message for this candidate.
    message: ChatMessage,
}
30
/// Client for OpenAI-compatible APIs (OpenAI, DeepSeek)
#[derive(Debug, Clone)]
pub struct OpenAiCompatibleClient {
    // Bearer token sent in the `Authorization` header.
    api_key: String,
    // API origin; `/v1/chat/completions` is appended per request.
    base_url: String,
    // Model name included in every request body.
    model: String,
    // Which provider this client targets; used for error/display text.
    provider: Provider,
    // Reused HTTP client instance.
    client: reqwest::Client,
}
40
41impl OpenAiCompatibleClient {
42    /// Create a new client for the specified provider
43    pub fn new(provider: Provider, api_key: String) -> Self {
44        Self {
45            api_key,
46            base_url: provider.base_url().to_string(),
47            model: provider.default_model().to_string(),
48            provider,
49            client: reqwest::Client::new(),
50        }
51    }
52
53    /// Create client with custom base URL (useful for testing or custom endpoints)
54    pub fn with_base_url(provider: Provider, api_key: String, base_url: String) -> Self {
55        Self {
56            api_key,
57            base_url,
58            model: provider.default_model().to_string(),
59            provider,
60            client: reqwest::Client::new(),
61        }
62    }
63
64    /// Set custom model
65    pub fn with_model(mut self, model: impl Into<String>) -> Self {
66        self.model = model.into();
67        self
68    }
69}
70
71#[async_trait]
72impl LlmClient for OpenAiCompatibleClient {
73    async fn generate_commit_message(
74        &self,
75        diff: &str,
76        template: Option<&str>,
77    ) -> Result<(String, String)> {
78        let prompt = build_prompt(diff, template);
79
80        let request = ChatRequest {
81            model: self.model.clone(),
82            messages: vec![
83                ChatMessage {
84                    role: "system".to_string(),
85                    content: "あなたは経験豊富なソフトウェアエンジニアです。Git diffから適切なコミットメッセージを生成してください。".to_string(),
86                },
87                ChatMessage {
88                    role: "user".to_string(),
89                    content: prompt,
90                },
91            ],
92            temperature: 0.7,
93        };
94
95        let response = self
96            .client
97            .post(format!("{}/v1/chat/completions", self.base_url))
98            .header("Authorization", format!("Bearer {}", self.api_key))
99            .header("Content-Type", "application/json")
100            .json(&request)
101            .send()
102            .await
103            .with_context(|| format!("Failed to send request to {} API", self.provider))?;
104
105        if !response.status().is_success() {
106            let status = response.status();
107            let error_text = response.text().await.unwrap_or_default();
108            return Err(anyhow::anyhow!(
109                "{} API request failed ({}): {}",
110                self.provider,
111                status,
112                error_text
113            ));
114        }
115
116        let api_response: ChatResponse = response
117            .json()
118            .await
119            .with_context(|| format!("Failed to parse {} API response", self.provider))?;
120
121        let message = api_response
122            .choices
123            .first()
124            .context("No choices in API response")?
125            .message
126            .content
127            .trim();
128
129        Ok(parse_commit_message(message))
130    }
131
132    fn provider_name(&self) -> &str {
133        match self.provider {
134            Provider::OpenAi => "OpenAI",
135            Provider::DeepSeek => "DeepSeek",
136            _ => "Unknown",
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    // Happy path against a mock OpenAI endpoint: a two-part completion is
    // split into (title, description).
    #[tokio::test]
    async fn test_openai_generate_commit_message() {
        let mut server = Server::new_async().await;
        let _m = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "feat: Add user authentication\n\nImplemented JWT-based auth"
                    }
                }]
            }"#,
            )
            .create_async()
            .await;

        // Point the client at the mock server instead of the real API.
        let client =
            OpenAiCompatibleClient::with_base_url(Provider::OpenAi, "test-key".into(), server.url());

        let (title, desc) = client
            .generate_commit_message("diff --git", None)
            .await
            .unwrap();

        assert_eq!(title, "feat: Add user authentication");
        assert_eq!(desc, "Implemented JWT-based auth");
    }

    // Same flow for DeepSeek; a title-only completion yields an empty description.
    #[tokio::test]
    async fn test_deepseek_generate_commit_message() {
        let mut server = Server::new_async().await;
        let _m = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "fix: Resolve memory leak"
                    }
                }]
            }"#,
            )
            .create_async()
            .await;

        let client = OpenAiCompatibleClient::with_base_url(
            Provider::DeepSeek,
            "test-key".into(),
            server.url(),
        );

        let (title, desc) = client
            .generate_commit_message("diff --git", None)
            .await
            .unwrap();

        assert_eq!(title, "fix: Resolve memory leak");
        assert_eq!(desc, "");
    }

    // A non-2xx status must surface as an Err whose message carries the
    // status and/or provider context.
    #[tokio::test]
    async fn test_api_error_handling() {
        let mut server = Server::new_async().await;
        let _m = server
            .mock("POST", "/v1/chat/completions")
            .with_status(401)
            .with_body(r#"{"error": "Invalid API key"}"#)
            .create_async()
            .await;

        let client =
            OpenAiCompatibleClient::with_base_url(Provider::OpenAi, "bad-key".into(), server.url());

        let result = client.generate_commit_message("diff", None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("401") || err.contains("API"));
    }

    // Display names for the two providers this client supports.
    #[test]
    fn test_provider_name() {
        let openai = OpenAiCompatibleClient::new(Provider::OpenAi, "key".into());
        assert_eq!(openai.provider_name(), "OpenAI");

        let deepseek = OpenAiCompatibleClient::new(Provider::DeepSeek, "key".into());
        assert_eq!(deepseek.provider_name(), "DeepSeek");
    }

    // Builder-style model override replaces the provider default.
    #[test]
    fn test_with_model() {
        let client = OpenAiCompatibleClient::new(Provider::OpenAi, "key".into())
            .with_model("gpt-4-turbo");
        assert_eq!(client.model, "gpt-4-turbo");
    }
}