Skip to main content

rusty_commit/providers/
vertex.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::{header, Client};
4use serde::Serialize;
5use serde_json::Value;
6
7use super::prompt::split_prompt;
8use super::AIProvider;
9use crate::config::Config;
10
11#[derive(Clone)]
12pub struct VertexProvider {
13    client: Client,
14    model: String,
15    project_id: String,
16    location: String,
17    access_token: String,
18}
19
/// JSON body sent to the Vertex AI content-generation endpoint.
///
/// Fields serialize with their Rust (snake_case) names because no `rename_all` is set.
/// The protobuf JSON mapping accepts original field names as well as lowerCamelCase.
/// NOTE(review): confirm against the Vertex REST reference if requests are rejected.
#[derive(Serialize)]
struct VertexRequest {
    /// Full model resource name (`projects/.../locations/.../publishers/google/models/...`).
    model: String,
    /// Conversation turns; the provider in this file sends a single `user` turn.
    contents: Vec<VertexContent>,
    /// Optional system prompt; `None` serializes as JSON `null` (no skip attribute).
    system_instruction: Option<VertexSystemInstruction>,
    /// Output-length and sampling settings.
    generation_config: VertexGenerationConfig,
}
27
/// One conversation turn: a role (e.g. `"user"`) plus its content parts.
#[derive(Serialize)]
struct VertexContent {
    /// Speaker of this turn; this file only sends `"user"`.
    role: String,
    /// Content of the turn; this file only sends text parts.
    parts: Vec<VertexPart>,
}
33
/// A text-only content part.
#[derive(Serialize)]
struct VertexPart {
    /// Prompt text carried by this part.
    text: String,
}
38
/// System prompt payload; same shape as a conversation turn (role + parts).
#[derive(Serialize)]
struct VertexSystemInstruction {
    /// Role label sent with the system prompt; this file sends `"system"`.
    role: String,
    /// System prompt text parts.
    parts: Vec<VertexPart>,
}
44
/// Generation parameters sent with each request.
#[derive(Serialize)]
struct VertexGenerationConfig {
    /// Upper bound on tokens in the generated reply.
    max_output_tokens: u32,
    /// Sampling temperature.
    temperature: f32,
}
50
51impl VertexProvider {
52    pub fn new(config: &Config) -> Result<Self> {
53        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
54        rt.block_on(async { Self::new_async(config).await })
55    }
56
57    async fn new_async(config: &Config) -> Result<Self> {
58        let client = Client::new();
59
60        let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
61            .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
62            .context(
63                "Google Cloud project ID not set. Set GOOGLE_CLOUD_PROJECT or run 'gcloud init'",
64            )?;
65
66        let location = config
67            .api_url
68            .as_ref()
69            .and_then(|url| {
70                url.split(".googleapis.com")
71                    .next()
72                    .and_then(|s| s.split('-').next())
73                    .map(|s| s.to_string())
74            })
75            .unwrap_or_else(|| "us-central1".to_string());
76
77        let model = config
78            .model
79            .as_deref()
80            .unwrap_or("gemini-1.5-pro")
81            .to_string();
82
83        let access_token = Self::get_gcloud_token().await?;
84
85        Ok(Self {
86            client,
87            model,
88            project_id,
89            location,
90            access_token,
91        })
92    }
93
94    #[allow(dead_code)]
95    pub fn from_account(
96        account: &crate::config::accounts::AccountConfig,
97        _api_key: &str,
98        config: &Config,
99    ) -> Result<Self> {
100        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
101        rt.block_on(async { Self::from_account_async(account, config).await })
102    }
103
104    async fn from_account_async(
105        account: &crate::config::accounts::AccountConfig,
106        config: &Config,
107    ) -> Result<Self> {
108        let client = Client::new();
109
110        let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
111            .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
112            .context("Google Cloud project ID not set")?;
113
114        let location = account
115            .api_url
116            .as_ref()
117            .and_then(|url| {
118                url.split(".googleapis.com")
119                    .next()
120                    .and_then(|s| s.split('-').next())
121                    .map(|s| s.to_string())
122            })
123            .unwrap_or_else(|| "us-central1".to_string());
124
125        let model = account
126            .model
127            .as_deref()
128            .or(config.model.as_deref())
129            .unwrap_or("gemini-1.5-pro")
130            .to_string();
131
132        let access_token = Self::get_gcloud_token().await?;
133
134        Ok(Self {
135            client,
136            model,
137            project_id,
138            location,
139            access_token,
140        })
141    }
142
143    async fn get_gcloud_token() -> Result<String> {
144        let output = tokio::process::Command::new("gcloud")
145            .args(["auth", "print-access-token"])
146            .output()
147            .await
148            .context("Failed to execute gcloud command")?;
149
150        if !output.status.success() {
151            anyhow::bail!(
152                "Failed to get Google Cloud access token. Run 'gcloud auth login' first."
153            );
154        }
155
156        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
157
158        Ok(token)
159    }
160}
161
162#[async_trait]
163impl AIProvider for VertexProvider {
164    async fn generate_commit_message(
165        &self,
166        diff: &str,
167        context: Option<&str>,
168        full_gitmoji: bool,
169        config: &Config,
170    ) -> Result<String> {
171        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
172
173        let request = VertexRequest {
174            model: format!(
175                "projects/{}/locations/{}/publishers/google/models/{}",
176                self.project_id, self.location, self.model
177            ),
178            contents: vec![VertexContent {
179                role: "user".to_string(),
180                parts: vec![VertexPart { text: user_prompt }],
181            }],
182            system_instruction: Some(VertexSystemInstruction {
183                role: "system".to_string(),
184                parts: vec![VertexPart {
185                    text: system_prompt,
186                }],
187            }),
188            generation_config: VertexGenerationConfig {
189                max_output_tokens: config.tokens_max_output.unwrap_or(500),
190                temperature: 0.7,
191            },
192        };
193
194        let url = format!(
195            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/{}:streamGenerateContent",
196            self.location, self.project_id, self.location, self.model
197        );
198
199        let response = self
200            .client
201            .post(&url)
202            .header(
203                header::AUTHORIZATION,
204                format!("Bearer {}", self.access_token),
205            )
206            .header(header::CONTENT_TYPE, "application/json")
207            .json(&request)
208            .send()
209            .await
210            .context("Failed to connect to Vertex AI")?;
211
212        if !response.status().is_success() {
213            let error_text = response.text().await?;
214            anyhow::bail!("Vertex AI API error: {}", error_text);
215        }
216
217        let json: Value = response
218            .json()
219            .await
220            .context("Failed to parse Vertex AI response")?;
221
222        // Extract text from response
223        let message = json
224            .get("candidates")
225            .and_then(|candidates| candidates.as_array())
226            .and_then(|candidates| candidates.first())
227            .and_then(|cand| cand.get("content"))
228            .and_then(|content| content.get("parts"))
229            .and_then(|parts| parts.as_array())
230            .and_then(|parts| parts.first())
231            .and_then(|part| part.get("text"))
232            .and_then(|text| text.as_str())
233            .map(|s| s.to_string())
234            .context("No response from Vertex AI")?;
235
236        Ok(message.trim().to_string())
237    }
238}
239
/// Provider-registry builder for Vertex AI.
///
/// Registers the `vertex` provider name (plus aliases) and constructs
/// [`VertexProvider`] instances from the global configuration.
pub struct VertexProviderBuilder;
242
243impl super::registry::ProviderBuilder for VertexProviderBuilder {
244    fn name(&self) -> &'static str {
245        "vertex"
246    }
247
248    fn aliases(&self) -> Vec<&'static str> {
249        vec!["vertex-ai", "google-vertex", "gcp-vertex"]
250    }
251
252    fn category(&self) -> super::registry::ProviderCategory {
253        super::registry::ProviderCategory::Cloud
254    }
255
256    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
257        Ok(Box::new(VertexProvider::new(config)?))
258    }
259
260    fn requires_api_key(&self) -> bool {
261        false
262    }
263
264    fn default_model(&self) -> Option<&'static str> {
265        Some("gemini-1.5-pro")
266    }
267}