rusty_commit/providers/
vertex.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::{header, Client};
4use serde::Serialize;
5use serde_json::Value;
6
7use super::{split_prompt, AIProvider};
8use crate::config::Config;
9
10#[derive(Clone)]
11pub struct VertexProvider {
12 client: Client,
13 model: String,
14 project_id: String,
15 location: String,
16 access_token: String,
17}
18
19#[derive(Serialize)]
20struct VertexRequest {
21 model: String,
22 contents: Vec<VertexContent>,
23 system_instruction: Option<VertexSystemInstruction>,
24 generation_config: VertexGenerationConfig,
25}
26
27#[derive(Serialize)]
28struct VertexContent {
29 role: String,
30 parts: Vec<VertexPart>,
31}
32
33#[derive(Serialize)]
34struct VertexPart {
35 text: String,
36}
37
38#[derive(Serialize)]
39struct VertexSystemInstruction {
40 role: String,
41 parts: Vec<VertexPart>,
42}
43
44#[derive(Serialize)]
45struct VertexGenerationConfig {
46 max_output_tokens: u32,
47 temperature: f32,
48}
49
50impl VertexProvider {
51 pub fn new(config: &Config) -> Result<Self> {
52 let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
53 rt.block_on(async { Self::new_async(config).await })
54 }
55
56 async fn new_async(config: &Config) -> Result<Self> {
57 let client = Client::new();
58
59 let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
60 .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
61 .context(
62 "Google Cloud project ID not set. Set GOOGLE_CLOUD_PROJECT or run 'gcloud init'",
63 )?;
64
65 let location = config
66 .api_url
67 .as_ref()
68 .and_then(|url| {
69 url.split(".googleapis.com")
70 .next()
71 .and_then(|s| s.split('-').next())
72 .map(|s| s.to_string())
73 })
74 .unwrap_or_else(|| "us-central1".to_string());
75
76 let model = config
77 .model
78 .as_deref()
79 .unwrap_or("gemini-1.5-pro")
80 .to_string();
81
82 let access_token = Self::get_gcloud_token().await?;
83
84 Ok(Self {
85 client,
86 model,
87 project_id,
88 location,
89 access_token,
90 })
91 }
92
93 #[allow(dead_code)]
94 pub fn from_account(
95 account: &crate::config::accounts::AccountConfig,
96 _api_key: &str,
97 config: &Config,
98 ) -> Result<Self> {
99 let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
100 rt.block_on(async { Self::from_account_async(account, config).await })
101 }
102
103 async fn from_account_async(
104 account: &crate::config::accounts::AccountConfig,
105 config: &Config,
106 ) -> Result<Self> {
107 let client = Client::new();
108
109 let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
110 .or_else(|_| std::env::var("CLOUDSDK_CORE_PROJECT"))
111 .context("Google Cloud project ID not set")?;
112
113 let location = account
114 .api_url
115 .as_ref()
116 .and_then(|url| {
117 url.split(".googleapis.com")
118 .next()
119 .and_then(|s| s.split('-').next())
120 .map(|s| s.to_string())
121 })
122 .unwrap_or_else(|| "us-central1".to_string());
123
124 let model = account
125 .model
126 .as_deref()
127 .or(config.model.as_deref())
128 .unwrap_or("gemini-1.5-pro")
129 .to_string();
130
131 let access_token = Self::get_gcloud_token().await?;
132
133 Ok(Self {
134 client,
135 model,
136 project_id,
137 location,
138 access_token,
139 })
140 }
141
142 async fn get_gcloud_token() -> Result<String> {
143 let output = tokio::process::Command::new("gcloud")
144 .args(["auth", "print-access-token"])
145 .output()
146 .await
147 .context("Failed to execute gcloud command")?;
148
149 if !output.status.success() {
150 anyhow::bail!(
151 "Failed to get Google Cloud access token. Run 'gcloud auth login' first."
152 );
153 }
154
155 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
156
157 Ok(token)
158 }
159}
160
161#[async_trait]
162impl AIProvider for VertexProvider {
163 async fn generate_commit_message(
164 &self,
165 diff: &str,
166 context: Option<&str>,
167 full_gitmoji: bool,
168 config: &Config,
169 ) -> Result<String> {
170 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
171
172 let request = VertexRequest {
173 model: format!(
174 "projects/{}/locations/{}/publishers/google/models/{}",
175 self.project_id, self.location, self.model
176 ),
177 contents: vec![VertexContent {
178 role: "user".to_string(),
179 parts: vec![VertexPart { text: user_prompt }],
180 }],
181 system_instruction: Some(VertexSystemInstruction {
182 role: "system".to_string(),
183 parts: vec![VertexPart {
184 text: system_prompt,
185 }],
186 }),
187 generation_config: VertexGenerationConfig {
188 max_output_tokens: config.tokens_max_output.unwrap_or(500),
189 temperature: 0.7,
190 },
191 };
192
193 let url = format!(
194 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/{}:streamGenerateContent",
195 self.location, self.project_id, self.location, self.model
196 );
197
198 let response = self
199 .client
200 .post(&url)
201 .header(
202 header::AUTHORIZATION,
203 format!("Bearer {}", self.access_token),
204 )
205 .header(header::CONTENT_TYPE, "application/json")
206 .json(&request)
207 .send()
208 .await
209 .context("Failed to connect to Vertex AI")?;
210
211 if !response.status().is_success() {
212 let error_text = response.text().await?;
213 anyhow::bail!("Vertex AI API error: {}", error_text);
214 }
215
216 let json: Value = response
217 .json()
218 .await
219 .context("Failed to parse Vertex AI response")?;
220
221 let message = json
223 .get("candidates")
224 .and_then(|candidates| candidates.as_array())
225 .and_then(|candidates| candidates.first())
226 .and_then(|cand| cand.get("content"))
227 .and_then(|content| content.get("parts"))
228 .and_then(|parts| parts.as_array())
229 .and_then(|parts| parts.first())
230 .and_then(|part| part.get("text"))
231 .and_then(|text| text.as_str())
232 .map(|s| s.to_string())
233 .context("No response from Vertex AI")?;
234
235 Ok(message.trim().to_string())
236 }
237}
238
/// Registry entry that constructs [`VertexProvider`] instances.
pub struct VertexProviderBuilder;
241
242impl super::registry::ProviderBuilder for VertexProviderBuilder {
243 fn name(&self) -> &'static str {
244 "vertex"
245 }
246
247 fn aliases(&self) -> Vec<&'static str> {
248 vec!["vertex-ai", "google-vertex", "gcp-vertex"]
249 }
250
251 fn category(&self) -> super::registry::ProviderCategory {
252 super::registry::ProviderCategory::Cloud
253 }
254
255 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
256 Ok(Box::new(VertexProvider::new(config)?))
257 }
258
259 fn requires_api_key(&self) -> bool {
260 false
261 }
262
263 fn default_model(&self) -> Option<&'static str> {
264 Some("gemini-1.5-pro")
265 }
266}