rusty_commit/providers/
vertex.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::{header, Client};
4use serde::Serialize;
5use serde_json::Value;
6
7use super::prompt::split_prompt;
8use super::AIProvider;
9use crate::config::Config;
10
11#[derive(Clone)]
12pub struct VertexProvider {
13 client: Client,
14 model: String,
15 project_id: String,
16 location: String,
17 access_token: String,
18}
19
20#[derive(Serialize)]
21struct VertexRequest {
22 model: String,
23 contents: Vec<VertexContent>,
24 system_instruction: Option<VertexSystemInstruction>,
25 generation_config: VertexGenerationConfig,
26}
27
28#[derive(Serialize)]
29struct VertexContent {
30 role: String,
31 parts: Vec<VertexPart>,
32}
33
34#[derive(Serialize)]
35struct VertexPart {
36 text: String,
37}
38
39#[derive(Serialize)]
40struct VertexSystemInstruction {
41 role: String,
42 parts: Vec<VertexPart>,
43}
44
45#[derive(Serialize)]
46struct VertexGenerationConfig {
47 max_output_tokens: u32,
48 temperature: f32,
49}
50
51impl VertexProvider {
52 pub fn new(config: &Config) -> Result<Self> {
53 let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
54 rt.block_on(async { Self::new_async(config).await })
55 }
56
57 async fn new_async(config: &Config) -> Result<Self> {
58 let client = Client::new();
59
60 let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
61 .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
62 .context(
63 "Google Cloud project ID not set. Set GOOGLE_CLOUD_PROJECT or run 'gcloud init'",
64 )?;
65
66 let location = config
67 .api_url
68 .as_ref()
69 .and_then(|url| {
70 url.split(".googleapis.com")
71 .next()
72 .and_then(|s| s.split('-').next())
73 .map(|s| s.to_string())
74 })
75 .unwrap_or_else(|| "us-central1".to_string());
76
77 let model = config
78 .model
79 .as_deref()
80 .unwrap_or("gemini-1.5-pro")
81 .to_string();
82
83 let access_token = Self::get_gcloud_token().await?;
84
85 Ok(Self {
86 client,
87 model,
88 project_id,
89 location,
90 access_token,
91 })
92 }
93
94 #[allow(dead_code)]
95 pub fn from_account(
96 account: &crate::config::accounts::AccountConfig,
97 _api_key: &str,
98 config: &Config,
99 ) -> Result<Self> {
100 let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
101 rt.block_on(async { Self::from_account_async(account, config).await })
102 }
103
104 async fn from_account_async(
105 account: &crate::config::accounts::AccountConfig,
106 config: &Config,
107 ) -> Result<Self> {
108 let client = Client::new();
109
110 let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
111 .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
112 .context("Google Cloud project ID not set")?;
113
114 let location = account
115 .api_url
116 .as_ref()
117 .and_then(|url| {
118 url.split(".googleapis.com")
119 .next()
120 .and_then(|s| s.split('-').next())
121 .map(|s| s.to_string())
122 })
123 .unwrap_or_else(|| "us-central1".to_string());
124
125 let model = account
126 .model
127 .as_deref()
128 .or(config.model.as_deref())
129 .unwrap_or("gemini-1.5-pro")
130 .to_string();
131
132 let access_token = Self::get_gcloud_token().await?;
133
134 Ok(Self {
135 client,
136 model,
137 project_id,
138 location,
139 access_token,
140 })
141 }
142
143 async fn get_gcloud_token() -> Result<String> {
144 let output = tokio::process::Command::new("gcloud")
145 .args(["auth", "print-access-token"])
146 .output()
147 .await
148 .context("Failed to execute gcloud command")?;
149
150 if !output.status.success() {
151 anyhow::bail!(
152 "Failed to get Google Cloud access token. Run 'gcloud auth login' first."
153 );
154 }
155
156 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
157
158 Ok(token)
159 }
160}
161
162#[async_trait]
163impl AIProvider for VertexProvider {
164 async fn generate_commit_message(
165 &self,
166 diff: &str,
167 context: Option<&str>,
168 full_gitmoji: bool,
169 config: &Config,
170 ) -> Result<String> {
171 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
172
173 let request = VertexRequest {
174 model: format!(
175 "projects/{}/locations/{}/publishers/google/models/{}",
176 self.project_id, self.location, self.model
177 ),
178 contents: vec![VertexContent {
179 role: "user".to_string(),
180 parts: vec![VertexPart { text: user_prompt }],
181 }],
182 system_instruction: Some(VertexSystemInstruction {
183 role: "system".to_string(),
184 parts: vec![VertexPart {
185 text: system_prompt,
186 }],
187 }),
188 generation_config: VertexGenerationConfig {
189 max_output_tokens: config.tokens_max_output.unwrap_or(500),
190 temperature: 0.7,
191 },
192 };
193
194 let url = format!(
195 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/{}:streamGenerateContent",
196 self.location, self.project_id, self.location, self.model
197 );
198
199 let response = self
200 .client
201 .post(&url)
202 .header(
203 header::AUTHORIZATION,
204 format!("Bearer {}", self.access_token),
205 )
206 .header(header::CONTENT_TYPE, "application/json")
207 .json(&request)
208 .send()
209 .await
210 .context("Failed to connect to Vertex AI")?;
211
212 if !response.status().is_success() {
213 let error_text = response.text().await?;
214 anyhow::bail!("Vertex AI API error: {}", error_text);
215 }
216
217 let json: Value = response
218 .json()
219 .await
220 .context("Failed to parse Vertex AI response")?;
221
222 let message = json
224 .get("candidates")
225 .and_then(|candidates| candidates.as_array())
226 .and_then(|candidates| candidates.first())
227 .and_then(|cand| cand.get("content"))
228 .and_then(|content| content.get("parts"))
229 .and_then(|parts| parts.as_array())
230 .and_then(|parts| parts.first())
231 .and_then(|part| part.get("text"))
232 .and_then(|text| text.as_str())
233 .map(|s| s.to_string())
234 .context("No response from Vertex AI")?;
235
236 Ok(message.trim().to_string())
237 }
238}
239
240pub struct VertexProviderBuilder;
242
243impl super::registry::ProviderBuilder for VertexProviderBuilder {
244 fn name(&self) -> &'static str {
245 "vertex"
246 }
247
248 fn aliases(&self) -> Vec<&'static str> {
249 vec!["vertex-ai", "google-vertex", "gcp-vertex"]
250 }
251
252 fn category(&self) -> super::registry::ProviderCategory {
253 super::registry::ProviderCategory::Cloud
254 }
255
256 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
257 Ok(Box::new(VertexProvider::new(config)?))
258 }
259
260 fn requires_api_key(&self) -> bool {
261 false
262 }
263
264 fn default_model(&self) -> Option<&'static str> {
265 Some("gemini-1.5-pro")
266 }
267}