Skip to main content

rusty_commit/providers/
openai.rs

1use anyhow::{Context, Result};
2use async_openai::{
3    config::OpenAIConfig,
4    types::chat::{
5        ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
6        CreateChatCompletionRequestArgs,
7    },
8    Client,
9};
10use async_trait::async_trait;
11
12use super::{split_prompt, AIProvider};
13use crate::config::accounts::AccountConfig;
14use crate::config::Config;
15use crate::utils::retry::retry_async;
16
17pub struct OpenAIProvider {
18    client: Client<OpenAIConfig>,
19    model: String,
20}
21
22impl OpenAIProvider {
23    pub fn new(config: &Config) -> Result<Self> {
24        let api_key = config
25            .api_key
26            .as_ref()
27            .context("OpenAI API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://platform.openai.com/api-keys")?;
28
29        let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
30            config
31                .api_url
32                .as_deref()
33                .unwrap_or("https://api.openai.com/v1"),
34        );
35
36        let client = Client::with_config(openai_config);
37        let model = config.model.as_deref().unwrap_or("gpt-4o-mini").to_string();
38
39        Ok(Self { client, model })
40    }
41
42    /// Create provider from account configuration
43    #[allow(dead_code)]
44    pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
45        let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
46            account
47                .api_url
48                .as_deref()
49                .or(config.api_url.as_deref())
50                .unwrap_or("https://api.openai.com/v1"),
51        );
52
53        let client = Client::with_config(openai_config);
54        let model = account
55            .model
56            .as_deref()
57            .or(config.model.as_deref())
58            .unwrap_or("gpt-4o-mini")
59            .to_string();
60
61        Ok(Self { client, model })
62    }
63}
64
65#[async_trait]
66impl AIProvider for OpenAIProvider {
67    async fn generate_commit_message(
68        &self,
69        diff: &str,
70        context: Option<&str>,
71        full_gitmoji: bool,
72        config: &Config,
73    ) -> Result<String> {
74        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
75
76        let messages = vec![
77            ChatCompletionRequestSystemMessage::from(system_prompt).into(),
78            ChatCompletionRequestUserMessage::from(user_prompt).into(),
79        ];
80
81        // Handle model-specific parameters
82        let request = if self.model.contains("gpt-5-nano") {
83            // GPT-5-nano doesn't support temperature=0, use 1.0 (default)
84            CreateChatCompletionRequestArgs::default()
85                .model(&self.model)
86                .messages(messages)
87                .temperature(1.0)
88                .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
89                .build()?
90        } else {
91            // Standard models support temperature=0.7 and max_tokens
92            CreateChatCompletionRequestArgs::default()
93                .model(&self.model)
94                .messages(messages)
95                .temperature(0.7)
96                .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
97                .build()?
98        };
99
100        let response = retry_async(|| async {
101            match self.client.chat().create(request.clone()).await {
102                Ok(resp) => Ok(resp),
103                Err(e) => {
104                    let error_msg = e.to_string();
105                    if error_msg.contains("401") || error_msg.contains("invalid_api_key") {
106                        Err(anyhow::anyhow!("Invalid OpenAI API key. Please check your API key configuration."))
107                    } else if error_msg.contains("insufficient_quota") {
108                        Err(anyhow::anyhow!("OpenAI API quota exceeded. Please check your billing status."))
109                    } else {
110                        Err(anyhow::anyhow!(e).context("Failed to generate commit message from OpenAI"))
111                    }
112                }
113            }
114        }).await.context("Failed to generate commit message from OpenAI after retries. Please check your internet connection and API configuration.")?;
115
116        let message = response
117            .choices
118            .first()
119            .and_then(|choice| choice.message.content.as_ref())
120            .context("OpenAI returned an empty response. The model may be overloaded - please try again.")?
121            .trim()
122            .to_string();
123
124        Ok(message)
125    }
126}
127
/// Registry entry for the `openai` provider.
///
/// Besides its own name, aliases, default endpoint and default model, it carries
/// `compatible_providers`. This table maps OpenAI-compatible provider names (including
/// alternate spellings) to the base URL of each service's chat-completions API.
// `default_api_url` and `compatible_providers` are not read anywhere in this file;
// the allow keeps them from triggering dead-code warnings.
#[allow(dead_code)]
pub struct OpenAICompatibleProvider {
    /// Primary name reported to the provider registry.
    pub name: &'static str,
    /// Alternate names reported to the provider registry.
    pub aliases: Vec<&'static str>,
    /// Base URL of OpenAI's own API.
    pub default_api_url: &'static str,
    /// Default model reported to the provider registry.
    pub default_model: Option<&'static str>,
    /// Provider name -> OpenAI-compatible base URL.
    pub compatible_providers: std::collections::HashMap<&'static str, &'static str>,
}

impl OpenAICompatibleProvider {
    /// Builds the `openai` registry entry together with its table of OpenAI-compatible
    /// endpoints.
    ///
    /// Several keys are aliases of one service (e.g. `fireworks` / `fireworks-ai`) and map
    /// to the same URL. Local/self-hosted entries point at `localhost` ports.
    pub fn new() -> Self {
        // Base URLs shared by several aliases, so the aliases cannot drift apart.
        // DashScope: mainland-China OpenAI-compatible endpoint (the console host is a web UI,
        // not an API).
        const DASHSCOPE: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
        // Fireworks serves its OpenAI-compatible API under `/inference`.
        const FIREWORKS: &str = "https://api.fireworks.ai/inference/v1";
        // Cohere's native v1 API is not OpenAI-compatible; its compatibility layer is.
        const COHERE: &str = "https://api.cohere.ai/compatibility/v1";
        const HYPERBOLIC: &str = "https://api.hyperbolic.xyz/v1";

        let mut compat = std::collections::HashMap::new();

        // Core OpenAI-compatible providers
        compat.insert("deepseek", "https://api.deepseek.com/v1");
        compat.insert("groq", "https://api.groq.com/openai/v1");
        compat.insert("openrouter", "https://openrouter.ai/api/v1");
        compat.insert("together", "https://api.together.ai/v1");
        compat.insert("deepinfra", "https://api.deepinfra.com/v1/openai");
        compat.insert("mistral", "https://api.mistral.ai/v1");
        compat.insert("github-models", "https://models.inference.ai.azure.com");
        compat.insert("fireworks", FIREWORKS);
        compat.insert("fireworks-ai", FIREWORKS);
        compat.insert("moonshot", "https://api.moonshot.cn/v1");
        compat.insert("moonshot-ai", "https://api.moonshot.cn/v1");
        compat.insert("dashscope", DASHSCOPE);
        compat.insert("alibaba", DASHSCOPE);
        compat.insert("qwen", DASHSCOPE);
        compat.insert("qwen-coder", DASHSCOPE);
        compat.insert("codex", "https://api.openai.com/v1");

        // ═══════════════════════════════════════════════════════════════════════
        // OpenAI-Compatible Providers - Comprehensive list from OpenCode
        // ═══════════════════════════════════════════════════════════════════════

        // ───────────────────────────────────────────────────────────────────────
        // 1. Major Cloud Providers (Enterprise)
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("cohere", COHERE);
        compat.insert("cohere-ai", COHERE);
        compat.insert("ai21", "https://api.ai21.com/studio/v1");
        compat.insert("ai21-labs", "https://api.ai21.com/studio/v1");
        compat.insert("upstage", "https://api.upstage.ai/v1/solar");
        compat.insert("upstage-ai", "https://api.upstage.ai/v1/solar");
        compat.insert("solar", "https://api.upstage.ai/v1/solar");
        compat.insert("solar-pro", "https://api.upstage.ai/v1/solar");

        // ───────────────────────────────────────────────────────────────────────
        // 2. GPU Cloud & Inference Providers
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("nebius", "https://api.studio.nebius.ai/v1");
        compat.insert("nebius-ai", "https://api.studio.nebius.ai/v1");
        compat.insert("nebius-studio", "https://api.studio.nebius.ai/v1");
        compat.insert("ovh", "https://api.ovh.com/v1");
        compat.insert("ovhcloud", "https://api.ovh.com/v1");
        compat.insert("ovh-ai", "https://api.ovh.com/v1");
        compat.insert("scaleway", "https://api.scaleway.ai/v1");
        compat.insert("scaleway-ai", "https://api.scaleway.ai/v1");
        compat.insert("friendli", "https://api.friendli.ai/v1");
        compat.insert("friendli-ai", "https://api.friendli.ai/v1");
        compat.insert("baseten", "https://api.baseten.co/v1");
        compat.insert("baseten-ai", "https://api.baseten.co/v1");
        compat.insert("chutes", "https://api.chutes.ai/v1");
        compat.insert("chutes-ai", "https://api.chutes.ai/v1");
        compat.insert("ionet", "https://api.io.net/v1");
        compat.insert("io-net", "https://api.io.net/v1");
        compat.insert("modelscope", "https://api.modelscope.cn/v1");
        compat.insert("requesty", "https://api.requesty.ai/v1");
        compat.insert("morph", "https://api.morph.so/v1");
        compat.insert("morph-labs", "https://api.morph.so/v1");
        compat.insert("synthetic", "https://api.syntheticai.com/v1");
        compat.insert("nano-gpt", "https://api.nano-gpt.com/v1");
        compat.insert("nanogpt", "https://api.nano-gpt.com/v1");
        compat.insert("zenmux", "https://api.zenmux.com/v1");
        compat.insert("v0", "https://api.v0.dev/v1");
        compat.insert("v0-vercel", "https://api.v0.dev/v1");
        compat.insert("iflowcn", "https://api.iflow.cn/v1");
        compat.insert("venice", "https://api.venice.ai/v1");
        compat.insert("venice-ai", "https://api.venice.ai/v1");
        compat.insert("cortecs", "https://api.cortecs.ai/v1");
        compat.insert("cortecs-ai", "https://api.cortecs.ai/v1");
        compat.insert("kimi-coding", "https://api.moonshot.cn/v1");
        compat.insert("abacus", "https://api.abacus.ai/v1");
        compat.insert("abacus-ai", "https://api.abacus.ai/v1");
        compat.insert("bailing", "https://api.bailing.ai/v1");
        compat.insert("fastrouter", "https://api.fastrouter.ai/v1");
        compat.insert("inference", "https://api.inference.net/v1");
        compat.insert("inference-net", "https://api.inference.net/v1");
        compat.insert("submodel", "https://api.submodel.ai/v1");
        compat.insert("zai", "https://api.z.ai/v1");
        compat.insert("zai-coding", "https://api.z.ai/v1");
        compat.insert("zhipu-coding", "https://open.bigmodel.cn/api/paas/v4");
        compat.insert("poe", "https://api.poe.com/v1");
        compat.insert("poe-ai", "https://api.poe.com/v1");
        compat.insert("cerebras", "https://api.cerebras.ai/v1");
        compat.insert("cerebras-ai", "https://api.cerebras.ai/v1");
        compat.insert("sambanova", "https://api.sambanova.ai/v1");
        compat.insert("sambanova-ai", "https://api.sambanova.ai/v1");
        compat.insert("novita", "https://api.novita.ai/v3/openai");
        compat.insert("novita-ai", "https://api.novita.ai/v3/openai");
        compat.insert("predibase", "https://api.predibase.com/v1");
        compat.insert("tensorops", "https://api.tensorops.ai/v1");
        compat.insert("hyperbolic", HYPERBOLIC);
        compat.insert("hyperbolic-ai", HYPERBOLIC);
        compat.insert("kluster", "https://api.kluster.ai/v1");
        compat.insert("kluster-ai", "https://api.kluster.ai/v1");
        compat.insert("lambda", "https://api.lambda.ai/v1");
        compat.insert("lambda-labs", "https://api.lambda.ai/v1");
        compat.insert("replicate", "https://api.replicate.com/v1");
        compat.insert("targon", "https://api.targon.com/v1");
        compat.insert("corcel", "https://api.corcel.io/v1");
        compat.insert("cybernative", "https://api.cybernative.ai/v1");
        compat.insert("cybernative-ai", "https://api.cybernative.ai/v1");
        compat.insert("edgen", "https://api.edgen.co/v1");
        compat.insert("gigachat", "https://api.gigachat.ru/v1");
        compat.insert("gigachat-ai", "https://api.gigachat.ru/v1");
        compat.insert("hydra", "https://api.hydraai.com/v1");
        compat.insert("hydra-ai", "https://api.hydraai.com/v1");
        compat.insert("jina", "https://api.jina.ai/v1");
        compat.insert("jina-ai", "https://api.jina.ai/v1");
        compat.insert("lingyi", "https://api.lingyiwanwu.com/v1");
        compat.insert("lingyiwanwu", "https://api.lingyiwanwu.com/v1");
        compat.insert("monica", "https://api.monica.ai/v1");
        compat.insert("monica-ai", "https://api.monica.ai/v1");
        compat.insert("pollinations", "https://api.pollinations.ai/v1");
        compat.insert("pollinations-ai", "https://api.pollinations.ai/v1");
        compat.insert("rawechat", "https://api.rawe.chat/v1");
        compat.insert("shuttleai", "https://api.shuttleai.com/v1");
        compat.insert("shuttle-ai", "https://api.shuttleai.com/v1");
        compat.insert("teknium", "https://api.teknium.ai/v1");
        compat.insert("theb", "https://api.theb.ai/v1");
        compat.insert("theb-ai", "https://api.theb.ai/v1");
        compat.insert("tryleap", "https://api.tryleap.ai/v1");
        compat.insert("leap-ai", "https://api.tryleap.ai/v1");

        // ───────────────────────────────────────────────────────────────────────
        // 3. Local/Self-hosted Providers
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("lmstudio", "http://localhost:1234/v1");
        compat.insert("lm-studio", "http://localhost:1234/v1");
        compat.insert("llamacpp", "http://localhost:8080/v1");
        compat.insert("llama-cpp", "http://localhost:8080/v1");
        compat.insert("kobold", "http://localhost:5000/v1");
        compat.insert("koboldcpp", "http://localhost:5000/v1");
        compat.insert("textgen", "http://localhost:5000/v1");
        compat.insert("text-generation", "http://localhost:5000/v1");
        compat.insert("tabby", "http://localhost:8080/v1");

        // ───────────────────────────────────────────────────────────────────────
        // 4. China-based Providers
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("siliconflow", "https://api.siliconflow.cn/v1");
        compat.insert("silicon-flow", "https://api.siliconflow.cn/v1");
        compat.insert("zhipu", "https://open.bigmodel.cn/api/paas/v4");
        compat.insert("zhipu-ai", "https://open.bigmodel.cn/api/paas/v4");
        compat.insert("bigmodel", "https://open.bigmodel.cn/api/paas/v4");
        compat.insert("minimax", "https://api.minimax.chat/v1");
        compat.insert("minimax-ai", "https://api.minimax.chat/v1");
        compat.insert("glm", "https://open.bigmodel.cn/api/paas/v4");
        compat.insert("chatglm", "https://open.bigmodel.cn/api/paas/v4");
        compat.insert("baichuan", "https://api.baichuan-ai.com/v1");
        compat.insert("01-ai", "https://api.01.ai/v1");
        compat.insert("yi", "https://api.01.ai/v1");

        // ───────────────────────────────────────────────────────────────────────
        // 5. Additional Providers from OpenCommit & Others
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("aimlapi", "https://api.aimlapi.com/v1");
        compat.insert("ai-ml-api", "https://api.aimlapi.com/v1");

        // ───────────────────────────────────────────────────────────────────────
        // 6. AI Gateway & Proxy Services
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("helicone", "https://gateway.helicone.ai/v1");
        compat.insert("helicone-ai", "https://gateway.helicone.ai/v1");
        // NOTE: `{account_id}` is a literal template placeholder; nothing in this file
        // substitutes it, so whoever uses this URL must fill it in first.
        compat.insert(
            "workers-ai",
            "https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/v1",
        );
        compat.insert("cloudflare-ai", "https://gateway.ai.cloudflare.com/v1");
        compat.insert("cloudflare-gateway", "https://gateway.ai.cloudflare.com/v1");
        compat.insert("vercel-ai", "https://api.vercel.ai/v1");
        compat.insert("vercel-gateway", "https://api.vercel.ai/v1");

        // ───────────────────────────────────────────────────────────────────────
        // 7. Specialized Providers
        // ───────────────────────────────────────────────────────────────────────
        compat.insert("302ai", "https://api.302.ai/v1");
        compat.insert("302-ai", "https://api.302.ai/v1");
        compat.insert("sap-ai", "https://api.ai.sap.com/v1");
        compat.insert("sap-ai-core", "https://api.ai.sap.com/v1");

        Self {
            name: "openai",
            aliases: vec!["openai"],
            default_api_url: "https://api.openai.com/v1",
            default_model: Some("gpt-4o-mini"),
            compatible_providers: compat,
        }
    }
}

impl Default for OpenAICompatibleProvider {
    /// Same as [`OpenAICompatibleProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
340
341impl super::registry::ProviderBuilder for OpenAICompatibleProvider {
342    fn name(&self) -> &'static str {
343        self.name
344    }
345
346    fn aliases(&self) -> Vec<&'static str> {
347        self.aliases.clone()
348    }
349
350    fn category(&self) -> super::registry::ProviderCategory {
351        super::registry::ProviderCategory::OpenAICompatible
352    }
353
354    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
355        Ok(Box::new(OpenAIProvider::new(config)?))
356    }
357
358    fn requires_api_key(&self) -> bool {
359        true
360    }
361
362    fn default_model(&self) -> Option<&'static str> {
363        self.default_model
364    }
365}