Skip to main content

phi_core/provider/
google_vertex.rs

1//! Google Vertex AI provider.
2//!
3//! Similar to Google Generative AI but uses OAuth2 authentication
4//! and a different base URL pattern with project/location.
5//!
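//! The credential is resolved from the model config via `resolve_api_key()` (a credential
//! provider when one is set, otherwise the static API key). It must be an OAuth2 access
//! token, because it is sent as an `Authorization: Bearer` header. Callers are responsible
//! for obtaining the token (e.g., via a service-account JWT exchange).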
8/*
9ARCHITECTURE: GoogleVertexProvider — enterprise Gemini via Vertex AI
10
11Vertex AI is Google's enterprise AI platform. It hosts the same Gemini models
12as Generative AI (`generativelanguage.googleapis.com`) but with:
13
14  Different URL structure:
15    GenAI:  https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent
16    Vertex: https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/
17              publishers/google/models/{model}:streamGenerateContent
18
19  Different authentication:
20    GenAI:  `?key={api_key}` query parameter (simple API key)
21    Vertex: `Authorization: Bearer {oauth2_access_token}` header
22
23  Same API format:
24    Both use identical request/response JSON shapes (Gemini content format).
25    We re-use `build_vertex_request_body()` which is structurally the same as
26    Google GenAI's `build_request_body()`.
27
ARCHITECTURE: Shared wire format, provider-local streaming

`GoogleVertexProvider` does not delegate to the Google GenAI provider, because that
provider constructs its own `?key=`-style URL. Instead `stream()`:
  1. Constructs the Vertex-specific URL (`vertex_url()` associated function)
  2. Sends the OAuth2 access token as an `Authorization: Bearer` header
  3. Builds the Gemini-format body with `build_vertex_request_body()`
  4. Parses the SSE response with `parse_google_sse_response()`, defined in this file

Only URL construction and auth are Vertex-specific; the request/response JSON shapes
are the same as Google GenAI's.
37
38RUST QUIRK: `fn vertex_url(model_config: &ModelConfig, model: &str) -> String`
39  An associated function on `GoogleVertexProvider` (no `self` parameter).
40  Called as `Self::vertex_url(model_config, model)` or `GoogleVertexProvider::vertex_url(...)`.
41  Python analogy: a `@staticmethod` on the class.
42*/
43
44use super::model::ModelConfig;
45use super::traits::*;
46use crate::types::*;
47use async_trait::async_trait;
48use tokio::sync::mpsc;
49
/// Vertex AI (enterprise Gemini) streaming provider.
///
/// A stateless unit struct: everything request-specific arrives through `StreamConfig`,
/// so one value can be shared or copied freely. All behavior lives in the
/// `StreamProvider` impl and the private `vertex_url` helper.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoogleVertexProvider;
52
53impl GoogleVertexProvider {
54    /// Build the Vertex AI URL from model config.
55    /// Expects base_url in format: `https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models`
56    fn vertex_url(
57        model_config: &ModelConfig, // CONFIG — carries base_url (Vertex endpoint) to construct full URL
58        model: &str, // MODEL NAME — appended to base_url to get the per-model endpoint
59    ) -> String {
60        format!(
61            "{}/{}:streamGenerateContent?alt=sse",
62            model_config.base_url, model
63        )
64    }
65}
66
67#[async_trait]
68impl StreamProvider for GoogleVertexProvider {
69    fn provider_id(&self) -> &str {
70        "vertex"
71    }
72
73    async fn stream(
74        &self,
75        config: StreamConfig, // REQUEST — api_key is OAuth2 Bearer token (not API key); base_url is Vertex endpoint
76        tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — delegates to GoogleProvider's stream logic
77        cancel: tokio_util::sync::CancellationToken, // ABORT — forwarded to delegate
78    ) -> Result<Message, ProviderError> {
79        let model_config = &config.model_config;
80        // Resolve via CredentialProvider when set, else use the static `api_key`.
81        let api_key = model_config.resolve_api_key().await?;
82
83        // Override the base_url to use Vertex format.
84        // The GoogleProvider's stream will use model_config.base_url, but we need
85        // a different URL pattern. We delegate to GoogleProvider with a modified config.
86        let vertex_url = Self::vertex_url(model_config, &config.model_config.id);
87
88        // Create a modified model config with the Vertex URL pattern
89        let mut vertex_model = model_config.clone();
90        // For Vertex, auth is via Bearer token (OAuth2), not API key in query param.
91        // We need to add the Authorization header.
92        vertex_model
93            .headers
94            .insert("authorization".to_string(), format!("Bearer {}", api_key));
95
96        // Build request body same as Google (same content format)
97        let body = build_vertex_request_body(&config);
98
99        let client = reqwest::Client::new();
100        let mut request = client
101            .post(&vertex_url)
102            .header("content-type", "application/json");
103
104        for (k, v) in &vertex_model.headers {
105            request = request.header(k, v);
106        }
107
108        let response = request
109            .json(&body)
110            .send()
111            .await
112            .map_err(|e| ProviderError::Network(e.to_string()))?;
113
114        if !response.status().is_success() {
115            let status = response.status();
116            let body = response.text().await.unwrap_or_default();
117            return Err(ProviderError::classify(
118                status.as_u16(),
119                &format!("Vertex AI error {}: {}", status, body),
120            ));
121        }
122
123        // Delegate SSE parsing to the Google provider's streaming logic.
124        // Since the response format is identical, we reuse GoogleProvider.
125        // However, we already have the response, so we'll parse it inline.
126        // Actually, let's just delegate to GoogleProvider. The key difference
127        // is auth (Bearer vs API key in URL). We handle that by using a modified
128        // model config. But GoogleProvider builds its own URL... so let's just
129        // use GoogleProvider with a trick: empty api_key and auth in headers.
130        // We can't easily reuse GoogleProvider because it constructs its own URL.
131        // Instead, parse the SSE response directly (same format as Google GenAI).
132        parse_google_sse_response(response, &config, &model_config.provider, tx, cancel).await
133    }
134}
135
136/// Parse a Google-format SSE response stream. Shared between Google and Vertex.
137async fn parse_google_sse_response(
138    response: reqwest::Response,
139    config: &StreamConfig,
140    provider_name: &str,
141    tx: mpsc::UnboundedSender<StreamEvent>,
142    cancel: tokio_util::sync::CancellationToken,
143) -> Result<Message, ProviderError> {
144    use futures::StreamExt;
145    use serde::Deserialize;
146    use tracing::{debug, warn};
147
148    let mut content: Vec<Content> = Vec::new();
149    let mut usage = Usage::default();
150    let mut stop_reason = StopReason::Stop;
151
152    let _ = tx.send(StreamEvent::Start);
153
154    let mut stream = response.bytes_stream();
155    let mut buffer = String::new();
156
157    loop {
158        tokio::select! {
159            _ = cancel.cancelled() => {
160                return Err(ProviderError::Cancelled);
161            }
162            chunk = stream.next() => {
163                match chunk {
164                    None => break,
165                    Some(Err(e)) => {
166                        warn!("Vertex stream error: {}", e);
167                        break;
168                    }
169                    Some(Ok(bytes)) => {
170                        buffer.push_str(&String::from_utf8_lossy(&bytes));
171
172                        while let Some(pos) = buffer.find("\n\n") {
173                            let event_str = buffer[..pos].to_string();
174                            buffer = buffer[pos + 2..].to_string();
175
176                            let data = event_str
177                                .lines()
178                                .find(|l| l.starts_with("data: "))
179                                .map(|l| &l[6..])
180                                .unwrap_or("");
181
182                            if data.is_empty() {
183                                continue;
184                            }
185
186                            #[derive(Deserialize)]
187                            struct Chunk {
188                                #[serde(default)]
189                                candidates: Option<Vec<Candidate>>,
190                                #[serde(default, rename = "usageMetadata")]
191                                usage_metadata: Option<UsageMeta>,
192                            }
193                            #[derive(Deserialize)]
194                            struct Candidate {
195                                #[serde(default)]
196                                content: Option<CContent>,
197                                #[serde(default, rename = "finishReason")]
198                                finish_reason: Option<String>,
199                            }
200                            #[derive(Deserialize)]
201                            struct CContent {
202                                #[serde(default)]
203                                parts: Vec<Part>,
204                            }
205                            #[derive(Deserialize)]
206                            struct Part {
207                                #[serde(default)]
208                                text: Option<String>,
209                                #[serde(default, rename = "functionCall")]
210                                function_call: Option<FCall>,
211                            }
212                            #[derive(Deserialize)]
213                            struct FCall {
214                                name: String,
215                                #[serde(default)]
216                                args: Option<serde_json::Value>,
217                            }
218                            #[derive(Deserialize)]
219                            struct UsageMeta {
220                                #[serde(default, rename = "promptTokenCount")]
221                                prompt_token_count: Option<u64>,
222                                #[serde(default, rename = "candidatesTokenCount")]
223                                candidates_token_count: Option<u64>,
224                                #[serde(default, rename = "totalTokenCount")]
225                                total_token_count: Option<u64>,
226                            }
227
228                            let parsed: Chunk = match serde_json::from_str(data) {
229                                Ok(c) => c,
230                                Err(e) => {
231                                    debug!("Failed to parse Vertex chunk: {}", e);
232                                    continue;
233                                }
234                            };
235
236                            for candidate in parsed.candidates.unwrap_or_default() {
237                                if let Some(c) = candidate.content {
238                                    for part in c.parts {
239                                        if let Some(text) = part.text {
240                                            let idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
241                                            let idx = match idx {
242                                                Some(i) => i,
243                                                None => {
244                                                    content.push(Content::Text { text: String::new() });
245                                                    content.len() - 1
246                                                }
247                                            };
248                                            if let Some(Content::Text { text: t }) = content.get_mut(idx) {
249                                                t.push_str(&text);
250                                            }
251                                            let _ = tx.send(StreamEvent::TextDelta {
252                                                content_index: idx,
253                                                delta: text,
254                                            });
255                                        }
256                                        if let Some(fc) = part.function_call {
257                                            let id = format!("vertex-fc-{}", content.len());
258                                            let args = fc.args.unwrap_or(serde_json::Value::Object(Default::default()));
259                                            let idx = content.len();
260                                            content.push(Content::ToolCall {
261                                                id: id.clone(),
262                                                name: fc.name.clone(),
263                                                arguments: args,
264                                            });
265                                            let _ = tx.send(StreamEvent::ToolCallStart {
266                                                content_index: idx,
267                                                id,
268                                                name: fc.name,
269                                            });
270                                            let _ = tx.send(StreamEvent::ToolCallEnd { content_index: idx });
271                                            stop_reason = StopReason::ToolUse;
272                                        }
273                                    }
274                                }
275                                if let Some(reason) = candidate.finish_reason {
276                                    stop_reason = match reason.as_str() {
277                                        "STOP" => StopReason::Stop,
278                                        "MAX_TOKENS" => StopReason::Length,
279                                        _ => StopReason::Stop,
280                                    };
281                                }
282                            }
283
284                            if let Some(u) = parsed.usage_metadata {
285                                usage.input = u.prompt_token_count.unwrap_or(0);
286                                usage.output = u.candidates_token_count.unwrap_or(0);
287                                usage.total_tokens = u.total_token_count.unwrap_or(0);
288                            }
289                        }
290                    }
291                }
292            }
293        }
294    }
295
296    let message = Message::Assistant {
297        content,
298        stop_reason,
299        model: config.model_config.id.clone(),
300        provider: provider_name.to_string(),
301        usage,
302        timestamp: now_ms(),
303        error_message: None,
304    };
305
306    let _ = tx.send(StreamEvent::Done {
307        message: message.clone(),
308    });
309    Ok(message)
310}
311
312/// Build the request body for Vertex AI (same format as Google GenAI).
313fn build_vertex_request_body(config: &StreamConfig) -> serde_json::Value {
314    // Same format as Google GenAI
315    let mut contents: Vec<serde_json::Value> = Vec::new();
316
317    for msg in &config.messages {
318        match msg {
319            Message::User { content, .. } => {
320                let parts: Vec<serde_json::Value> = content
321                    .iter()
322                    .filter_map(|c| match c {
323                        Content::Text { text } => Some(serde_json::json!({"text": text})),
324                        Content::Image { data, mime_type } => Some(serde_json::json!({
325                            "inlineData": {"mimeType": mime_type, "data": data},
326                        })),
327                        _ => None,
328                    })
329                    .collect();
330                contents.push(serde_json::json!({"role": "user", "parts": parts}));
331            }
332            Message::Assistant { content, .. } => {
333                let parts: Vec<serde_json::Value> = content
334                    .iter()
335                    .filter_map(|c| match c {
336                        Content::Text { text } => Some(serde_json::json!({"text": text})),
337                        Content::ToolCall {
338                            name, arguments, ..
339                        } => Some(serde_json::json!({
340                            "functionCall": {"name": name, "args": arguments},
341                        })),
342                        _ => None,
343                    })
344                    .collect();
345                contents.push(serde_json::json!({"role": "model", "parts": parts}));
346            }
347            Message::ToolResult {
348                tool_name, content, ..
349            } => {
350                let text = content
351                    .iter()
352                    .find_map(|c| match c {
353                        Content::Text { text } => Some(text.clone()),
354                        _ => None,
355                    })
356                    .unwrap_or_default();
357
358                let mut parts = vec![serde_json::json!({
359                    "functionResponse": {"name": tool_name, "response": {"result": text}}
360                })];
361
362                for c in content {
363                    if let Content::Image { data, mime_type } = c {
364                        parts.push(serde_json::json!({
365                            "inlineData": {"mimeType": mime_type, "data": data},
366                        }));
367                    }
368                }
369
370                contents.push(serde_json::json!({
371                    "role": "user",
372                    "parts": parts,
373                }));
374            }
375        }
376    }
377
378    let mut body = serde_json::json!({"contents": contents});
379
380    if !config.system_prompt.is_empty() {
381        body["systemInstruction"] = serde_json::json!({"parts": [{"text": config.system_prompt}]});
382    }
383
384    let mut gen_config = serde_json::json!({});
385    if let Some(max) = config.max_tokens {
386        gen_config["maxOutputTokens"] = serde_json::json!(max);
387    }
388    if let Some(temp) = config.temperature {
389        gen_config["temperature"] = serde_json::json!(temp);
390    }
391    // Vertex AI shares Gemini's structured-output shape (responseMimeType +
392    // optional responseSchema inside generationConfig).
393    match &config.response_format {
394        ResponseFormat::Text => {}
395        ResponseFormat::JsonObject => {
396            gen_config["responseMimeType"] = serde_json::json!("application/json");
397        }
398        ResponseFormat::JsonSchema { schema, .. } => {
399            gen_config["responseMimeType"] = serde_json::json!("application/json");
400            gen_config["responseSchema"] = schema.clone();
401        }
402    }
403    if gen_config != serde_json::json!({}) {
404        body["generationConfig"] = gen_config;
405    }
406
407    if !config.tools.is_empty() {
408        let declarations: Vec<serde_json::Value> = config
409            .tools
410            .iter()
411            .map(|t| {
412                serde_json::json!({
413                    "name": t.name,
414                    "description": t.description,
415                    "parameters": t.parameters,
416                })
417            })
418            .collect();
419        body["tools"] = serde_json::json!([{"functionDeclarations": declarations}]);
420    }
421
422    body
423}