// File: rusty_commit/providers/nvidia.rs
1//! NVIDIA NIM Provider - Enterprise GPU Inference
2//!
3//! NVIDIA NIM (NVIDIA Inference Microservices) provides optimized inference
4//! for LLMs on NVIDIA GPUs. Supports both self-hosted and cloud deployments.
5//!
6//! Setup:
7//! 1. Get API key from: https://build.nvidia.com
8//! 2. Configure rco:
9//!    `rco config set RCO_AI_PROVIDER=nvidia RCO_API_KEY=<key> RCO_MODEL=meta/llama-3.1-8b-instruct`
10//!
11//! Docs: https://docs.nvidia.com/nim/
12
13use anyhow::{Context, Result};
14use async_trait::async_trait;
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17
18use super::prompt::build_prompt;
19use super::AIProvider;
20use crate::config::Config;
21use crate::utils::retry::retry_async;
22
/// Provider for NVIDIA NIM (NVIDIA Inference Microservices).
///
/// Holds the reusable HTTP client plus the resolved endpoint, credential,
/// and model used for chat-completion requests.
pub struct NvidiaProvider {
    client: Client,
    // Base URL of the NIM endpoint (without the `/chat/completions` suffix).
    api_url: String,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Model identifier, e.g. "meta/llama-3.1-8b-instruct".
    model: String,
}
29
/// Request body for the OpenAI-compatible `/chat/completions` endpoint.
///
/// Field names are part of the NIM wire format — do not rename without
/// matching `#[serde(rename = …)]` attributes.
#[derive(Serialize)]
struct NvidiaRequest {
    model: String,
    messages: Vec<NvidiaMessage>,
    max_tokens: i32,
    temperature: f32,
    top_p: f32,
    // Always false here; this provider reads the full response at once.
    stream: bool,
}
39
/// A single chat message: a role ("system" / "user" / "assistant") plus text.
/// Used both in requests (Serialize) and in responses (Deserialize).
#[derive(Serialize, Deserialize, Clone)]
struct NvidiaMessage {
    role: String,
    content: String,
}
45
/// Top-level chat-completion response; only `choices` is read,
/// all other response fields are ignored by serde.
#[derive(Deserialize)]
struct NvidiaResponse {
    choices: Vec<NvidiaChoice>,
}
50
/// One completion candidate; carries the assistant message we extract.
#[derive(Deserialize)]
struct NvidiaChoice {
    message: NvidiaMessage,
}
55
56impl NvidiaProvider {
57    pub fn new(config: &Config) -> Result<Self> {
58        let client = Client::new();
59        let api_key = config
60            .api_key
61            .as_ref()
62            .context("NVIDIA API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://build.nvidia.com")?;
63
64        let api_url = config
65            .api_url
66            .as_deref()
67            .unwrap_or("https://integrate.api.nvidia.com/v1")
68            .to_string();
69
70        let model = config
71            .model
72            .as_deref()
73            .unwrap_or("meta/llama-3.1-8b-instruct")
74            .to_string();
75
76        Ok(Self {
77            client,
78            api_url,
79            api_key: api_key.clone(),
80            model,
81        })
82    }
83
84    /// Create provider from account configuration
85    #[allow(dead_code)]
86    pub fn from_account(
87        account: &crate::config::accounts::AccountConfig,
88        api_key: &str,
89        config: &Config,
90    ) -> Result<Self> {
91        let client = Client::new();
92        let api_url = account
93            .api_url
94            .as_deref()
95            .or(config.api_url.as_deref())
96            .unwrap_or("https://integrate.api.nvidia.com/v1")
97            .to_string();
98
99        let model = account
100            .model
101            .as_deref()
102            .or(config.model.as_deref())
103            .unwrap_or("meta/llama-3.1-8b-instruct")
104            .to_string();
105
106        Ok(Self {
107            client,
108            api_url,
109            api_key: api_key.to_string(),
110            model,
111        })
112    }
113}
114
115#[async_trait]
116impl AIProvider for NvidiaProvider {
117    async fn generate_commit_message(
118        &self,
119        diff: &str,
120        context: Option<&str>,
121        full_gitmoji: bool,
122        config: &Config,
123    ) -> Result<String> {
124        let prompt = build_prompt(diff, context, config, full_gitmoji);
125
126        let messages = vec![
127            NvidiaMessage {
128                role: "system".to_string(),
129                content: "You are an expert at writing clear, concise git commit messages."
130                    .to_string(),
131            },
132            NvidiaMessage {
133                role: "user".to_string(),
134                content: prompt,
135            },
136        ];
137
138        let request = NvidiaRequest {
139            model: self.model.clone(),
140            messages,
141            max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
142            temperature: 0.7,
143            top_p: 0.7,
144            stream: false,
145        };
146
147        let nvidia_response: NvidiaResponse = retry_async(|| async {
148            let url = format!("{}/chat/completions", self.api_url);
149            let response = self
150                .client
151                .post(&url)
152                .header("Authorization", format!("Bearer {}", self.api_key))
153                .json(&request)
154                .send()
155                .await
156                .context("Failed to connect to NVIDIA NIM API")?;
157
158            if !response.status().is_success() {
159                let error_text = response.text().await?;
160                if error_text.contains("401") || error_text.contains("Unauthorized") {
161                    return Err(anyhow::anyhow!(
162                        "Invalid NVIDIA API key. Please check your API key configuration."
163                    ));
164                }
165                return Err(anyhow::anyhow!("NVIDIA NIM API error: {}", error_text));
166            }
167
168            let nvidia_response: NvidiaResponse = response
169                .json()
170                .await
171                .context("Failed to parse NVIDIA NIM response")?;
172
173            Ok(nvidia_response)
174        })
175        .await
176        .context("Failed to generate commit message from NVIDIA NIM after retries")?;
177
178        let message = nvidia_response
179            .choices
180            .first()
181            .map(|choice| choice.message.content.trim().to_string())
182            .context("NVIDIA NIM returned an empty response")?;
183
184        Ok(message)
185    }
186}
187
/// ProviderBuilder for NVIDIA NIM.
///
/// Unit struct registered with the provider registry; constructs
/// `NvidiaProvider` instances on demand via `create`.
pub struct NvidiaProviderBuilder;
190
191impl super::registry::ProviderBuilder for NvidiaProviderBuilder {
192    fn name(&self) -> &'static str {
193        "nvidia"
194    }
195
196    fn aliases(&self) -> Vec<&'static str> {
197        vec!["nvidia-nim", "nim", "nvidia-ai"]
198    }
199
200    fn category(&self) -> super::registry::ProviderCategory {
201        super::registry::ProviderCategory::Cloud
202    }
203
204    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
205        Ok(Box::new(NvidiaProvider::new(config)?))
206    }
207
208    fn requires_api_key(&self) -> bool {
209        true
210    }
211
212    fn default_model(&self) -> Option<&'static str> {
213        Some("meta/llama-3.1-8b-instruct")
214    }
215}