Skip to main content

rusty_commit/providers/
vertex.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::{header, Client};
4use serde::Serialize;
5use serde_json::Value;
6
7use super::{split_prompt, AIProvider};
8use crate::config::Config;
9
10#[derive(Clone)]
11pub struct VertexProvider {
12    client: Client,
13    model: String,
14    project_id: String,
15    location: String,
16    access_token: String,
17}
18
19#[derive(Serialize)]
20struct VertexRequest {
21    model: String,
22    contents: Vec<VertexContent>,
23    system_instruction: Option<VertexSystemInstruction>,
24    generation_config: VertexGenerationConfig,
25}
26
27#[derive(Serialize)]
28struct VertexContent {
29    role: String,
30    parts: Vec<VertexPart>,
31}
32
33#[derive(Serialize)]
34struct VertexPart {
35    text: String,
36}
37
38#[derive(Serialize)]
39struct VertexSystemInstruction {
40    role: String,
41    parts: Vec<VertexPart>,
42}
43
44#[derive(Serialize)]
45struct VertexGenerationConfig {
46    max_output_tokens: u32,
47    temperature: f32,
48}
49
50impl VertexProvider {
51    pub fn new(config: &Config) -> Result<Self> {
52        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
53        rt.block_on(async { Self::new_async(config).await })
54    }
55
56    async fn new_async(config: &Config) -> Result<Self> {
57        let client = Client::new();
58
59        let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
60            .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
61            .context(
62                "Google Cloud project ID not set. Set GOOGLE_CLOUD_PROJECT or run 'gcloud init'",
63            )?;
64
65        let location = config
66            .api_url
67            .as_ref()
68            .and_then(|url| {
69                url.split(".googleapis.com")
70                    .next()
71                    .and_then(|s| s.split('-').next())
72                    .map(|s| s.to_string())
73            })
74            .unwrap_or_else(|| "us-central1".to_string());
75
76        let model = config
77            .model
78            .as_deref()
79            .unwrap_or("gemini-1.5-pro")
80            .to_string();
81
82        let access_token = Self::get_gcloud_token().await?;
83
84        Ok(Self {
85            client,
86            model,
87            project_id,
88            location,
89            access_token,
90        })
91    }
92
93    #[allow(dead_code)]
94    pub fn from_account(
95        account: &crate::config::accounts::AccountConfig,
96        _api_key: &str,
97        config: &Config,
98    ) -> Result<Self> {
99        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
100        rt.block_on(async { Self::from_account_async(account, config).await })
101    }
102
103    async fn from_account_async(
104        account: &crate::config::accounts::AccountConfig,
105        config: &Config,
106    ) -> Result<Self> {
107        let client = Client::new();
108
109        let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
110            .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
111            .context("Google Cloud project ID not set")?;
112
113        let location = account
114            .api_url
115            .as_ref()
116            .and_then(|url| {
117                url.split(".googleapis.com")
118                    .next()
119                    .and_then(|s| s.split('-').next())
120                    .map(|s| s.to_string())
121            })
122            .unwrap_or_else(|| "us-central1".to_string());
123
124        let model = account
125            .model
126            .as_deref()
127            .or(config.model.as_deref())
128            .unwrap_or("gemini-1.5-pro")
129            .to_string();
130
131        let access_token = Self::get_gcloud_token().await?;
132
133        Ok(Self {
134            client,
135            model,
136            project_id,
137            location,
138            access_token,
139        })
140    }
141
142    async fn get_gcloud_token() -> Result<String> {
143        let output = tokio::process::Command::new("gcloud")
144            .args(["auth", "print-access-token"])
145            .output()
146            .await
147            .context("Failed to execute gcloud command")?;
148
149        if !output.status.success() {
150            anyhow::bail!(
151                "Failed to get Google Cloud access token. Run 'gcloud auth login' first."
152            );
153        }
154
155        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
156
157        Ok(token)
158    }
159}
160
161#[async_trait]
162impl AIProvider for VertexProvider {
163    async fn generate_commit_message(
164        &self,
165        diff: &str,
166        context: Option<&str>,
167        full_gitmoji: bool,
168        config: &Config,
169    ) -> Result<String> {
170        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
171
172        let request = VertexRequest {
173            model: format!(
174                "projects/{}/locations/{}/publishers/google/models/{}",
175                self.project_id, self.location, self.model
176            ),
177            contents: vec![VertexContent {
178                role: "user".to_string(),
179                parts: vec![VertexPart { text: user_prompt }],
180            }],
181            system_instruction: Some(VertexSystemInstruction {
182                role: "system".to_string(),
183                parts: vec![VertexPart {
184                    text: system_prompt,
185                }],
186            }),
187            generation_config: VertexGenerationConfig {
188                max_output_tokens: config.tokens_max_output.unwrap_or(500),
189                temperature: 0.7,
190            },
191        };
192
193        let url = format!(
194            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/{}:streamGenerateContent",
195            self.location, self.project_id, self.location, self.model
196        );
197
198        let response = self
199            .client
200            .post(&url)
201            .header(
202                header::AUTHORIZATION,
203                format!("Bearer {}", self.access_token),
204            )
205            .header(header::CONTENT_TYPE, "application/json")
206            .json(&request)
207            .send()
208            .await
209            .context("Failed to connect to Vertex AI")?;
210
211        if !response.status().is_success() {
212            let error_text = response.text().await?;
213            anyhow::bail!("Vertex AI API error: {}", error_text);
214        }
215
216        let json: Value = response
217            .json()
218            .await
219            .context("Failed to parse Vertex AI response")?;
220
221        // Extract text from response
222        let message = json
223            .get("candidates")
224            .and_then(|candidates| candidates.as_array())
225            .and_then(|candidates| candidates.first())
226            .and_then(|cand| cand.get("content"))
227            .and_then(|content| content.get("parts"))
228            .and_then(|parts| parts.as_array())
229            .and_then(|parts| parts.first())
230            .and_then(|part| part.get("text"))
231            .and_then(|text| text.as_str())
232            .map(|s| s.to_string())
233            .context("No response from Vertex AI")?;
234
235        Ok(message.trim().to_string())
236    }
237}
238
239/// ProviderBuilder for Vertex AI
240pub struct VertexProviderBuilder;
241
242impl super::registry::ProviderBuilder for VertexProviderBuilder {
243    fn name(&self) -> &'static str {
244        "vertex"
245    }
246
247    fn aliases(&self) -> Vec<&'static str> {
248        vec!["vertex-ai", "google-vertex", "gcp-vertex"]
249    }
250
251    fn category(&self) -> super::registry::ProviderCategory {
252        super::registry::ProviderCategory::Cloud
253    }
254
255    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
256        Ok(Box::new(VertexProvider::new(config)?))
257    }
258
259    fn requires_api_key(&self) -> bool {
260        false
261    }
262
263    fn default_model(&self) -> Option<&'static str> {
264        Some("gemini-1.5-pro")
265    }
266}