phi_core/provider/google_vertex.rs
1//! Google Vertex AI provider.
2//!
3//! Similar to Google Generative AI but uses OAuth2 authentication
4//! and a different base URL pattern with project/location.
5//!
6//! The API key in StreamConfig is expected to be an OAuth2 access token.
7//! Callers are responsible for obtaining the token (e.g., via service account JWT).
8/*
9ARCHITECTURE: GoogleVertexProvider — enterprise Gemini via Vertex AI
10
11Vertex AI is Google's enterprise AI platform. It hosts the same Gemini models
12as Generative AI (`generativelanguage.googleapis.com`) but with:
13
14 Different URL structure:
15 GenAI: https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent
16 Vertex: https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/
17 publishers/google/models/{model}:streamGenerateContent
18
19 Different authentication:
20 GenAI: `?key={api_key}` query parameter (simple API key)
21 Vertex: `Authorization: Bearer {oauth2_access_token}` header
22
23 Same API format:
24 Both use identical request/response JSON shapes (Gemini content format).
25 We re-use `build_vertex_request_body()` which is structurally the same as
26 Google GenAI's `build_request_body()`.
27
ARCHITECTURE: Shared wire format, provider-local SSE parsing

`GoogleVertexProvider` performs the whole request itself. It:
 1. Constructs the Vertex-specific URL (`vertex_url()` associated function)
 2. Adds the OAuth2 Bearer token as an `Authorization` header
 3. Sends the request and parses the SSE response with
    `parse_google_sse_response()` (defined in this file)

It does not call into the Google GenAI provider, because that provider builds
its own URL (with the API key as a query parameter). The event format is the
same as GenAI, so the only Vertex-specific logic is URL construction and auth.
37
38RUST QUIRK: `fn vertex_url(model_config: &ModelConfig, model: &str) -> String`
39 An associated function on `GoogleVertexProvider` (no `self` parameter).
40 Called as `Self::vertex_url(model_config, model)` or `GoogleVertexProvider::vertex_url(...)`.
41 Python analogy: a `@staticmethod` on the class.
42*/
43
44use super::model::ModelConfig;
45use super::traits::*;
46use crate::types::*;
47use async_trait::async_trait;
48use tokio::sync::mpsc;
49
/// Unit struct — no state. All logic lives in the `StreamProvider` impl.
///
/// Because it carries no data, it derives the cheap standard traits so it can
/// be logged, copied, and default-constructed like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoogleVertexProvider;
52
53impl GoogleVertexProvider {
54 /// Build the Vertex AI URL from model config.
55 /// Expects base_url in format: `https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models`
56 fn vertex_url(
57 model_config: &ModelConfig, // CONFIG — carries base_url (Vertex endpoint) to construct full URL
58 model: &str, // MODEL NAME — appended to base_url to get the per-model endpoint
59 ) -> String {
60 format!(
61 "{}/{}:streamGenerateContent?alt=sse",
62 model_config.base_url, model
63 )
64 }
65}
66
67#[async_trait]
68impl StreamProvider for GoogleVertexProvider {
69 fn provider_id(&self) -> &str {
70 "vertex"
71 }
72
73 async fn stream(
74 &self,
75 config: StreamConfig, // REQUEST — api_key is OAuth2 Bearer token (not API key); base_url is Vertex endpoint
76 tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — delegates to GoogleProvider's stream logic
77 cancel: tokio_util::sync::CancellationToken, // ABORT — forwarded to delegate
78 ) -> Result<Message, ProviderError> {
79 let model_config = &config.model_config;
80 // Resolve via CredentialProvider when set, else use the static `api_key`.
81 let api_key = model_config.resolve_api_key().await?;
82
83 // Override the base_url to use Vertex format.
84 // The GoogleProvider's stream will use model_config.base_url, but we need
85 // a different URL pattern. We delegate to GoogleProvider with a modified config.
86 let vertex_url = Self::vertex_url(model_config, &config.model_config.id);
87
88 // Create a modified model config with the Vertex URL pattern
89 let mut vertex_model = model_config.clone();
90 // For Vertex, auth is via Bearer token (OAuth2), not API key in query param.
91 // We need to add the Authorization header.
92 vertex_model
93 .headers
94 .insert("authorization".to_string(), format!("Bearer {}", api_key));
95
96 // Build request body same as Google (same content format)
97 let body = build_vertex_request_body(&config);
98
99 let client = reqwest::Client::new();
100 let mut request = client
101 .post(&vertex_url)
102 .header("content-type", "application/json");
103
104 for (k, v) in &vertex_model.headers {
105 request = request.header(k, v);
106 }
107
108 let response = request
109 .json(&body)
110 .send()
111 .await
112 .map_err(|e| ProviderError::Network(e.to_string()))?;
113
114 if !response.status().is_success() {
115 let status = response.status();
116 let body = response.text().await.unwrap_or_default();
117 return Err(ProviderError::classify(
118 status.as_u16(),
119 &format!("Vertex AI error {}: {}", status, body),
120 ));
121 }
122
123 // Delegate SSE parsing to the Google provider's streaming logic.
124 // Since the response format is identical, we reuse GoogleProvider.
125 // However, we already have the response, so we'll parse it inline.
126 // Actually, let's just delegate to GoogleProvider. The key difference
127 // is auth (Bearer vs API key in URL). We handle that by using a modified
128 // model config. But GoogleProvider builds its own URL... so let's just
129 // use GoogleProvider with a trick: empty api_key and auth in headers.
130 // We can't easily reuse GoogleProvider because it constructs its own URL.
131 // Instead, parse the SSE response directly (same format as Google GenAI).
132 parse_google_sse_response(response, &config, &model_config.provider, tx, cancel).await
133 }
134}
135
136/// Parse a Google-format SSE response stream. Shared between Google and Vertex.
137async fn parse_google_sse_response(
138 response: reqwest::Response,
139 config: &StreamConfig,
140 provider_name: &str,
141 tx: mpsc::UnboundedSender<StreamEvent>,
142 cancel: tokio_util::sync::CancellationToken,
143) -> Result<Message, ProviderError> {
144 use futures::StreamExt;
145 use serde::Deserialize;
146 use tracing::{debug, warn};
147
148 let mut content: Vec<Content> = Vec::new();
149 let mut usage = Usage::default();
150 let mut stop_reason = StopReason::Stop;
151
152 let _ = tx.send(StreamEvent::Start);
153
154 let mut stream = response.bytes_stream();
155 let mut buffer = String::new();
156
157 loop {
158 tokio::select! {
159 _ = cancel.cancelled() => {
160 return Err(ProviderError::Cancelled);
161 }
162 chunk = stream.next() => {
163 match chunk {
164 None => break,
165 Some(Err(e)) => {
166 warn!("Vertex stream error: {}", e);
167 break;
168 }
169 Some(Ok(bytes)) => {
170 buffer.push_str(&String::from_utf8_lossy(&bytes));
171
172 while let Some(pos) = buffer.find("\n\n") {
173 let event_str = buffer[..pos].to_string();
174 buffer = buffer[pos + 2..].to_string();
175
176 let data = event_str
177 .lines()
178 .find(|l| l.starts_with("data: "))
179 .map(|l| &l[6..])
180 .unwrap_or("");
181
182 if data.is_empty() {
183 continue;
184 }
185
186 #[derive(Deserialize)]
187 struct Chunk {
188 #[serde(default)]
189 candidates: Option<Vec<Candidate>>,
190 #[serde(default, rename = "usageMetadata")]
191 usage_metadata: Option<UsageMeta>,
192 }
193 #[derive(Deserialize)]
194 struct Candidate {
195 #[serde(default)]
196 content: Option<CContent>,
197 #[serde(default, rename = "finishReason")]
198 finish_reason: Option<String>,
199 }
200 #[derive(Deserialize)]
201 struct CContent {
202 #[serde(default)]
203 parts: Vec<Part>,
204 }
205 #[derive(Deserialize)]
206 struct Part {
207 #[serde(default)]
208 text: Option<String>,
209 #[serde(default, rename = "functionCall")]
210 function_call: Option<FCall>,
211 }
212 #[derive(Deserialize)]
213 struct FCall {
214 name: String,
215 #[serde(default)]
216 args: Option<serde_json::Value>,
217 }
218 #[derive(Deserialize)]
219 struct UsageMeta {
220 #[serde(default, rename = "promptTokenCount")]
221 prompt_token_count: Option<u64>,
222 #[serde(default, rename = "candidatesTokenCount")]
223 candidates_token_count: Option<u64>,
224 #[serde(default, rename = "totalTokenCount")]
225 total_token_count: Option<u64>,
226 }
227
228 let parsed: Chunk = match serde_json::from_str(data) {
229 Ok(c) => c,
230 Err(e) => {
231 debug!("Failed to parse Vertex chunk: {}", e);
232 continue;
233 }
234 };
235
236 for candidate in parsed.candidates.unwrap_or_default() {
237 if let Some(c) = candidate.content {
238 for part in c.parts {
239 if let Some(text) = part.text {
240 let idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
241 let idx = match idx {
242 Some(i) => i,
243 None => {
244 content.push(Content::Text { text: String::new() });
245 content.len() - 1
246 }
247 };
248 if let Some(Content::Text { text: t }) = content.get_mut(idx) {
249 t.push_str(&text);
250 }
251 let _ = tx.send(StreamEvent::TextDelta {
252 content_index: idx,
253 delta: text,
254 });
255 }
256 if let Some(fc) = part.function_call {
257 let id = format!("vertex-fc-{}", content.len());
258 let args = fc.args.unwrap_or(serde_json::Value::Object(Default::default()));
259 let idx = content.len();
260 content.push(Content::ToolCall {
261 id: id.clone(),
262 name: fc.name.clone(),
263 arguments: args,
264 });
265 let _ = tx.send(StreamEvent::ToolCallStart {
266 content_index: idx,
267 id,
268 name: fc.name,
269 });
270 let _ = tx.send(StreamEvent::ToolCallEnd { content_index: idx });
271 stop_reason = StopReason::ToolUse;
272 }
273 }
274 }
275 if let Some(reason) = candidate.finish_reason {
276 stop_reason = match reason.as_str() {
277 "STOP" => StopReason::Stop,
278 "MAX_TOKENS" => StopReason::Length,
279 _ => StopReason::Stop,
280 };
281 }
282 }
283
284 if let Some(u) = parsed.usage_metadata {
285 usage.input = u.prompt_token_count.unwrap_or(0);
286 usage.output = u.candidates_token_count.unwrap_or(0);
287 usage.total_tokens = u.total_token_count.unwrap_or(0);
288 }
289 }
290 }
291 }
292 }
293 }
294 }
295
296 let message = Message::Assistant {
297 content,
298 stop_reason,
299 model: config.model_config.id.clone(),
300 provider: provider_name.to_string(),
301 usage,
302 timestamp: now_ms(),
303 error_message: None,
304 };
305
306 let _ = tx.send(StreamEvent::Done {
307 message: message.clone(),
308 });
309 Ok(message)
310}
311
312/// Build the request body for Vertex AI (same format as Google GenAI).
313fn build_vertex_request_body(config: &StreamConfig) -> serde_json::Value {
314 // Same format as Google GenAI
315 let mut contents: Vec<serde_json::Value> = Vec::new();
316
317 for msg in &config.messages {
318 match msg {
319 Message::User { content, .. } => {
320 let parts: Vec<serde_json::Value> = content
321 .iter()
322 .filter_map(|c| match c {
323 Content::Text { text } => Some(serde_json::json!({"text": text})),
324 Content::Image { data, mime_type } => Some(serde_json::json!({
325 "inlineData": {"mimeType": mime_type, "data": data},
326 })),
327 _ => None,
328 })
329 .collect();
330 contents.push(serde_json::json!({"role": "user", "parts": parts}));
331 }
332 Message::Assistant { content, .. } => {
333 let parts: Vec<serde_json::Value> = content
334 .iter()
335 .filter_map(|c| match c {
336 Content::Text { text } => Some(serde_json::json!({"text": text})),
337 Content::ToolCall {
338 name, arguments, ..
339 } => Some(serde_json::json!({
340 "functionCall": {"name": name, "args": arguments},
341 })),
342 _ => None,
343 })
344 .collect();
345 contents.push(serde_json::json!({"role": "model", "parts": parts}));
346 }
347 Message::ToolResult {
348 tool_name, content, ..
349 } => {
350 let text = content
351 .iter()
352 .find_map(|c| match c {
353 Content::Text { text } => Some(text.clone()),
354 _ => None,
355 })
356 .unwrap_or_default();
357
358 let mut parts = vec![serde_json::json!({
359 "functionResponse": {"name": tool_name, "response": {"result": text}}
360 })];
361
362 for c in content {
363 if let Content::Image { data, mime_type } = c {
364 parts.push(serde_json::json!({
365 "inlineData": {"mimeType": mime_type, "data": data},
366 }));
367 }
368 }
369
370 contents.push(serde_json::json!({
371 "role": "user",
372 "parts": parts,
373 }));
374 }
375 }
376 }
377
378 let mut body = serde_json::json!({"contents": contents});
379
380 if !config.system_prompt.is_empty() {
381 body["systemInstruction"] = serde_json::json!({"parts": [{"text": config.system_prompt}]});
382 }
383
384 let mut gen_config = serde_json::json!({});
385 if let Some(max) = config.max_tokens {
386 gen_config["maxOutputTokens"] = serde_json::json!(max);
387 }
388 if let Some(temp) = config.temperature {
389 gen_config["temperature"] = serde_json::json!(temp);
390 }
391 // Vertex AI shares Gemini's structured-output shape (responseMimeType +
392 // optional responseSchema inside generationConfig).
393 match &config.response_format {
394 ResponseFormat::Text => {}
395 ResponseFormat::JsonObject => {
396 gen_config["responseMimeType"] = serde_json::json!("application/json");
397 }
398 ResponseFormat::JsonSchema { schema, .. } => {
399 gen_config["responseMimeType"] = serde_json::json!("application/json");
400 gen_config["responseSchema"] = schema.clone();
401 }
402 }
403 if gen_config != serde_json::json!({}) {
404 body["generationConfig"] = gen_config;
405 }
406
407 if !config.tools.is_empty() {
408 let declarations: Vec<serde_json::Value> = config
409 .tools
410 .iter()
411 .map(|t| {
412 serde_json::json!({
413 "name": t.name,
414 "description": t.description,
415 "parameters": t.parameters,
416 })
417 })
418 .collect();
419 body["tools"] = serde_json::json!([{"functionDeclarations": declarations}]);
420 }
421
422 body
423}