Skip to main content

polyc_llm_vertex/
lib.rs

1//! A concrete [`LlmProvider`] backed by the GCP AI platform REST API,
2//! authenticated via Workload Identity Federation / Application Default
3//! Credentials (handled by `gcp_auth`: in-cluster metadata, WIF, or a local
4//! `gcloud` login, transparently).
5//!
6//! Uses the SSE-streaming `streamGenerateContent` endpoint
7//! (`?alt=sse`) and adapts each partial response into the streaming
8//! [`Chunk`] vocabulary the trait expects (text → [`Chunk::TextDelta`] per
9//! token batch, function calls → tool-call chunks, usage + finish reason
10//! → [`Chunk::Usage`]/[`Chunk::Stop`]). Chunks are yielded as bytes arrive
11//! — the harness pushes them down to the client without buffering the
12//! whole response, so user-visible latency starts at first-token time.
//! All Vertex-specific request/response mapping lives behind the
//! [`LlmProvider`] trait boundary, so changing that mapping stays a
//! single-file change.
15
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use futures::stream::{BoxStream, StreamExt};
20use gcp_auth::TokenProvider;
21use polyc_llm::{
22    Chunk, CompletionRequest, Content, LlmProvider, Message, Role, StopReason, Usage,
23    sse::next_event_boundary,
24};
25use serde::Deserialize;
26
/// OAuth scope requested whenever an access token is minted for a call.
const SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";

/// Upper bound on establishing the TCP/TLS connection.
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// Idle timeout applied to each individual read rather than the whole request:
/// it re-arms after every successful read, so it only fires when the SSE stream
/// stalls and never truncates a long but healthy generation. Deliberately NOT
/// `Client::timeout`, which would cap total request time.
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
34
/// Static addressing for the provider: the project and region every call is
/// routed to, plus the model used when a request does not name one.
#[derive(Debug, Clone)]
pub struct VertexConfig {
    /// GCP project id (the `projects/{project}` segment of the endpoint path).
    pub project: String,
    /// Region, e.g. `"us-central1"`; `"global"` is also accepted and is served
    /// from the un-prefixed `aiplatform.googleapis.com` host.
    pub location: String,
    /// Default publisher model id, e.g. `"<family>-<size>-<version>"`; used
    /// only when the request's own model id is empty.
    pub model: String,
}
45
46/// Errors from the provider.
47#[derive(Debug, thiserror::Error)]
48pub enum VertexError {
49    /// Credential acquisition failed.
50    #[error("auth: {0}")]
51    Auth(#[from] gcp_auth::Error),
52    /// Transport/HTTP failure.
53    #[error("http: {0}")]
54    Http(#[from] reqwest::Error),
55    /// The API returned a non-2xx status.
56    #[error("provider returned status {status}: {body}")]
57    Provider {
58        /// HTTP status code.
59        status: u16,
60        /// Response body.
61        body: String,
62    },
63}
64
65impl polyc_llm::LlmError for VertexError {
66    fn kind(&self) -> polyc_llm::LlmErrorKind {
67        use polyc_llm::LlmErrorKind;
68        match self {
69            // Credential acquisition failure → terminal auth error.
70            Self::Auth(_) => LlmErrorKind::Auth,
71            Self::Http(e) if e.is_timeout() => LlmErrorKind::Timeout,
72            Self::Http(_) => LlmErrorKind::Unavailable,
73            Self::Provider { status, .. } => polyc_llm::kind_from_http_status(*status),
74        }
75    }
76}
77
78/// `LlmProvider` over the GCP AI platform REST API.
79pub struct VertexProvider {
80    http: reqwest::Client,
81    tokens: Arc<dyn TokenProvider>,
82    config: VertexConfig,
83}
84
85impl VertexProvider {
86    /// Build a provider, resolving ambient credentials (WIF/ADC/metadata).
87    ///
88    /// # Errors
89    ///
90    /// Returns [`VertexError::Auth`] if no credentials can be resolved.
91    pub async fn new(config: VertexConfig) -> Result<Self, VertexError> {
92        let tokens = gcp_auth::provider().await?;
93        // Bound a stalled stream so a detached harness turn task can't hang
94        // forever waiting on a read that never returns (the control plane's
95        // TURN_DEADLINE can't reclaim it).
96        let http = reqwest::Client::builder()
97            .connect_timeout(CONNECT_TIMEOUT)
98            .read_timeout(READ_TIMEOUT)
99            .build()
100            .unwrap_or_else(|_| reqwest::Client::new());
101        Ok(Self {
102            http,
103            tokens,
104            config,
105        })
106    }
107
108    /// Build the `streamGenerateContent` URL for `model`.
109    ///
110    /// `model` is the per-request model id (`CompletionRequest::model`); the
111    /// caller falls back to [`VertexConfig::model`] when the request leaves it
112    /// empty, so the configured model is only a default and the active model can
113    /// be switched per turn without rebuilding the provider.
114    fn endpoint(&self, model: &str) -> String {
115        let VertexConfig {
116            project, location, ..
117        } = &self.config;
118        // The `global` location is served from the un-prefixed host
119        // (`aiplatform.googleapis.com`), NOT `global-aiplatform.googleapis.com`
120        // — the path still carries `locations/global`. The newest models are
121        // global-only, so this branch is required to reach them; regional
122        // locations keep the `{location}-` host prefix.
123        let host = if location == "global" {
124            "aiplatform.googleapis.com".to_owned()
125        } else {
126            format!("{location}-aiplatform.googleapis.com")
127        };
128        format!(
129            "https://{host}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent?alt=sse"
130        )
131    }
132}
133
134#[async_trait]
135impl LlmProvider for VertexProvider {
136    type Error = VertexError;
137
138    async fn complete(
139        &self,
140        req: CompletionRequest,
141    ) -> Result<BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
142        // Per-request model wins; the configured model is the fallback default,
143        // so an empty request model (or a switch to a different one) needs no
144        // provider rebuild.
145        let model = if req.model.is_empty() {
146            self.config.model.as_str()
147        } else {
148            req.model.as_str()
149        };
150        let body = build_request(&req);
151        tracing::debug!(
152            model = %model,
153            messages = req.messages.len(),
154            tools = req.tools.len(),
155            max_tokens = ?req.max_tokens,
156            temperature = ?req.temperature,
157            body = %body,
158            "vertex request"
159        );
160        let token = self.tokens.token(&[SCOPE]).await?;
161        let resp = self
162            .http
163            .post(self.endpoint(model))
164            .bearer_auth(token.as_str())
165            .json(&body)
166            .send()
167            .await?;
168
169        let status = resp.status();
170        if !status.is_success() {
171            let body = resp.text().await.unwrap_or_default();
172            return Err(VertexError::Provider {
173                status: status.as_u16(),
174                body,
175            });
176        }
177
178        // Yield Chunks as SSE events arrive. Each event is `data: <json>\n\n`
179        // where `<json>` is a partial `GenerateContentResponse` (a slice of
180        // the full reply); `map_response` produces one or more Chunks per
181        // event. Buffering only the smallest amount needed to find the next
182        // `\n\n` boundary keeps end-to-end latency at first-token time.
183        let byte_stream = resp.bytes_stream();
184        let chunks = async_stream::stream! {
185            use futures::StreamExt as _;
186            let mut byte_stream = byte_stream;
187            let mut buf: Vec<u8> = Vec::new();
188            // Stream-scoped counter so synthesized tool-call ids stay unique
189            // across SSE events (see `map_response`).
190            let mut tool_seq = 0usize;
191            while let Some(item) = byte_stream.next().await {
192                let bytes = match item {
193                    Ok(b) => b,
194                    Err(e) => { yield Err(VertexError::from(e)); return; }
195                };
196                buf.extend_from_slice(&bytes);
197                while let Some((pos, sep_len)) = next_event_boundary(&buf) {
198                    let event_bytes: Vec<u8> = buf.drain(..pos + sep_len).collect();
199                    // Trim the trailing separator.
200                    let event = std::str::from_utf8(&event_bytes[..event_bytes.len() - sep_len])
201                        .unwrap_or("");
202                    for line in event.lines() {
203                        let Some(json) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) else {
204                            continue;
205                        };
206                        tracing::debug!(event = %json, "vertex sse event");
207                        match serde_json::from_str::<GenerateContentResponse>(json) {
208                            Ok(resp) => {
209                                for chunk in map_response(resp, &mut tool_seq) {
210                                    yield chunk;
211                                }
212                            }
213                            Err(err) => {
214                                // Don't kill the stream over one malformed
215                                // event — surface it and continue.
216                                yield Err(VertexError::Provider {
217                                    status: 0,
218                                    body: format!("malformed SSE JSON: {err}; line: {json}"),
219                                });
220                            }
221                        }
222                    }
223                }
224            }
225        };
226        Ok(chunks.boxed())
227    }
228}
229
230/// Map a [`CompletionRequest`] to the GCP `generateContent` request body.
231fn build_request(req: &CompletionRequest) -> serde_json::Value {
232    let mut contents = Vec::new();
233    let mut system_parts = Vec::new();
234
235    for msg in &req.messages {
236        if msg.role == Role::System {
237            for c in &msg.content {
238                if let Content::Text(t) = c {
239                    system_parts.push(serde_json::json!({ "text": t }));
240                }
241            }
242        } else {
243            let role = if msg.role == Role::Assistant {
244                "model"
245            } else {
246                "user"
247            };
248            let parts = message_parts(msg);
249            if !parts.is_empty() {
250                contents.push(serde_json::json!({ "role": role, "parts": parts }));
251            }
252        }
253    }
254
255    let mut body = serde_json::json!({ "contents": contents });
256    if !system_parts.is_empty() {
257        body["systemInstruction"] = serde_json::json!({ "parts": system_parts });
258    }
259    let mut gen_config = serde_json::Map::new();
260    if let Some(max) = req.max_tokens {
261        gen_config.insert("maxOutputTokens".into(), max.into());
262    }
263    if let Some(temp) = req.temperature {
264        gen_config.insert("temperature".into(), temp.into());
265    }
266    if !req.stop.is_empty() {
267        gen_config.insert("stopSequences".into(), serde_json::json!(req.stop));
268    }
269    if !gen_config.is_empty() {
270        body["generationConfig"] = serde_json::Value::Object(gen_config);
271    }
272    // Gemini accepts a single `tools` array carrying multiple entries: one
273    // `functionDeclarations` block for our MCP/native tools, plus built-in tools
274    // like `googleSearch`. Gemini 3 supports combining them in one request, so
275    // the model can pivot between searching the public web and calling our tools
276    // within a turn. `web_search` is what `run_turn` sets for answering turns;
277    // auxiliary calls (summarize/classify) leave it false and get no search.
278    let mut tool_entries: Vec<serde_json::Value> = Vec::new();
279    if !req.tools.is_empty() {
280        let decls: Vec<_> = req
281            .tools
282            .iter()
283            .map(|t| {
284                serde_json::json!({
285                    "name": t.name,
286                    "description": t.description,
287                    "parameters": sanitize_schema_for_gemini(&t.schema_json),
288                })
289            })
290            .collect();
291        tool_entries.push(serde_json::json!({ "functionDeclarations": decls }));
292    }
293    if req.web_search {
294        tool_entries.push(serde_json::json!({ "googleSearch": {} }));
295    }
296    if !tool_entries.is_empty() {
297        body["tools"] = serde_json::Value::Array(tool_entries);
298    }
299    body
300}
301
302/// JSON-Schema keywords Gemini's function-declaration parser rejects with a 400.
303///
304/// MCP tool schemas are frequently JSON-Schema draft-07 (e.g. Zod output) and
305/// carry these; Vertex's function-calling `Schema` is a restricted `OpenAPI` 3.0
306/// subset. We drop them recursively before sending. The kept keywords (`type`,
307/// `properties`, `required`, `enum`, `items`, `description`, `minimum`,
308/// `maximum`, `default`, `format`, …) cover what the model needs to call the
309/// tool; the dropped constraints are re-validated by the tool server anyway.
310///
311/// Verified still required on `gemini-3.1-pro-preview` (2026-06-05): sending a
312/// raw draft-07 `parameters` 400s with `Unknown name "$schema"` /
313/// `"exclusiveMinimum"` — the function-declaration parser is OpenAPI-subset, not
314/// model-version-gated, so this is NOT a 2.5-only workaround.
315const GEMINI_UNSUPPORTED_SCHEMA_KEYS: &[&str] = &[
316    "$schema",
317    "$id",
318    "$ref",
319    "$defs",
320    "$comment",
321    "definitions",
322    "additionalProperties",
323    "unevaluatedProperties",
324    "patternProperties",
325    "exclusiveMinimum",
326    "exclusiveMaximum",
327];
328
329/// Recursively strip [`GEMINI_UNSUPPORTED_SCHEMA_KEYS`] from a JSON-Schema value
330/// so a rich MCP tool schema becomes a valid Vertex function-declaration
331/// `parameters` object. Non-object/array values pass through unchanged.
332fn sanitize_schema_for_gemini(value: &serde_json::Value) -> serde_json::Value {
333    match value {
334        serde_json::Value::Object(map) => serde_json::Value::Object(
335            map.iter()
336                .filter(|(k, _)| !GEMINI_UNSUPPORTED_SCHEMA_KEYS.contains(&k.as_str()))
337                .map(|(k, v)| (k.clone(), sanitize_schema_for_gemini(v)))
338                .collect(),
339        ),
340        serde_json::Value::Array(arr) => {
341            serde_json::Value::Array(arr.iter().map(sanitize_schema_for_gemini).collect())
342        }
343        other => other.clone(),
344    }
345}
346
347/// Build the `parts` array for one message (text, tool calls, tool results).
348fn message_parts(msg: &Message) -> Vec<serde_json::Value> {
349    let mut parts = Vec::new();
350    for c in &msg.content {
351        match c {
352            Content::Text(t) => parts.push(serde_json::json!({ "text": t })),
353            Content::ToolUse(tc) => {
354                let args: serde_json::Value =
355                    serde_json::from_str(&tc.args_json).unwrap_or(serde_json::Value::Null);
356                let mut part = serde_json::json!({
357                    "functionCall": { "name": tc.name, "args": args }
358                });
359                // Thinking models require the thought signature they emitted
360                // alongside a function call to be echoed back on the part when
361                // it reappears in the request history, or they reject the
362                // request (400 INVALID_ARGUMENT).
363                if let Some(sig) = &tc.signature {
364                    part["thoughtSignature"] = serde_json::json!(sig);
365                }
366                parts.push(part);
367            }
368            Content::ToolResult(tr) => {
369                let result: serde_json::Value =
370                    serde_json::from_str(&tr.result_json).unwrap_or(serde_json::Value::Null);
371                parts.push(serde_json::json!({
372                    "functionResponse": { "name": tr.tool_call_id, "response": { "result": result } }
373                }));
374            }
375            // Images and any future content variants are not yet mapped.
376            _ => {}
377        }
378    }
379    parts
380}
381
382/// Fold a `generateContent` response into ordered [`Chunk`]s.
383///
384/// `tool_seq` is a stream-scoped counter for synthesizing unique tool-call
385/// ids: the provider emits no ids of its own, and a turn's function calls can
386/// arrive across multiple SSE events, so a per-event index would collide
387/// (every event would restart at `call-0`). Threading one counter across all
388/// events keeps ids unique and stable.
389fn map_response(
390    resp: GenerateContentResponse,
391    tool_seq: &mut usize,
392) -> Vec<Result<Chunk, VertexError>> {
393    let mut chunks = Vec::new();
394    let candidate = resp.candidates.into_iter().next();
395
396    let mut text = String::new();
397    let mut tool_calls = Vec::new();
398    let mut finish = None;
399    if let Some(c) = candidate {
400        finish = c.finish_reason;
401        if let Some(content) = c.content {
402            for part in content.parts {
403                if let Some(t) = part.text {
404                    text.push_str(&t);
405                }
406                if let Some(fc) = part.function_call {
407                    // Carry the per-part thought signature with its call so it
408                    // can be echoed back on the follow-up request.
409                    tool_calls.push((fc, part.thought_signature));
410                }
411            }
412        }
413    }
414
415    if !text.is_empty() {
416        chunks.push(Ok(Chunk::text_delta(text)));
417    }
418    for (fc, signature) in &tool_calls {
419        let id = format!("call-{}", *tool_seq);
420        *tool_seq += 1;
421        chunks.push(Ok(Chunk::tool_call_start_signed(
422            id.clone(),
423            fc.name.clone(),
424            signature.clone(),
425        )));
426        chunks.push(Ok(Chunk::tool_call_args_delta(
427            id.clone(),
428            fc.args.to_string(),
429        )));
430        chunks.push(Ok(Chunk::tool_call_end(id)));
431    }
432    if let Some(u) = resp.usage_metadata {
433        chunks.push(Ok(Chunk::Usage(Usage {
434            input_tokens: u.prompt_token_count,
435            output_tokens: u.candidates_token_count,
436        })));
437    }
438    // Only terminate the stream when this event actually signals end of turn:
439    // either an explicit `finish_reason` (the canonical end marker) or a
440    // tool-use payload (the model handed control back). Partial SSE events
441    // have neither, so they don't push a Stop — keeping the stream open for
442    // subsequent token batches.
443    if finish.is_some() || !tool_calls.is_empty() {
444        let mapped = map_finish_reason(finish.as_deref());
445        // Tool calls present → the model yielded for tool use, INCLUDING the
446        // common streamed case where the call arrives with no finish reason
447        // (mapped → EndTurn). But a *hard* finish (MaxTokens / Refusal /
448        // StopSequence) means the output was truncated or refused, so it wins:
449        // the call may be incomplete and must not be executed downstream.
450        let stop = if !tool_calls.is_empty() && matches!(mapped, StopReason::EndTurn) {
451            StopReason::ToolUse
452        } else {
453            mapped
454        };
455        chunks.push(Ok(Chunk::Stop(stop)));
456    }
457    chunks
458}
459
460fn map_finish_reason(reason: Option<&str>) -> StopReason {
461    match reason {
462        Some("MAX_TOKENS") => StopReason::MaxTokens,
463        Some("STOP_SEQUENCE") => StopReason::StopSequence,
464        Some("SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII") => {
465            StopReason::Refusal
466        }
467        _ => StopReason::EndTurn,
468    }
469}
470
471// ── response wire types ─────────────────────────────────────────────────────
472
473#[derive(Deserialize)]
474struct GenerateContentResponse {
475    #[serde(default)]
476    candidates: Vec<Candidate>,
477    #[serde(default, rename = "usageMetadata")]
478    usage_metadata: Option<UsageMetadata>,
479}
480
481#[derive(Deserialize)]
482struct Candidate {
483    #[serde(default)]
484    content: Option<RespContent>,
485    #[serde(default, rename = "finishReason")]
486    finish_reason: Option<String>,
487}
488
489#[derive(Deserialize)]
490struct RespContent {
491    #[serde(default)]
492    parts: Vec<Part>,
493}
494
495#[derive(Deserialize)]
496struct Part {
497    #[serde(default)]
498    text: Option<String>,
499    #[serde(default, rename = "functionCall")]
500    function_call: Option<FunctionCall>,
501    #[serde(default, rename = "thoughtSignature")]
502    thought_signature: Option<String>,
503}
504
505#[derive(Deserialize)]
506struct FunctionCall {
507    name: String,
508    #[serde(default)]
509    args: serde_json::Value,
510}
511
512#[derive(Deserialize)]
513struct UsageMetadata {
514    #[serde(default, rename = "promptTokenCount")]
515    prompt_token_count: u64,
516    #[serde(default, rename = "candidatesTokenCount")]
517    candidates_token_count: u64,
518}
519
#[cfg(test)]
mod tests {
    // Test-only module: relax the pedantic/nursery lint groups and the
    // missing-docs requirement enforced on the library proper.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use super::*;

    /// A single final event carrying text, usage and `STOP` maps to a text
    /// delta first and a `Stop(EndTurn)` last.
    #[test]
    fn maps_text_and_usage_and_stop() {
        let resp: GenerateContentResponse = serde_json::from_value(serde_json::json!({
            "candidates": [{
                "content": { "role": "model", "parts": [{ "text": "parity" }] },
                "finishReason": "STOP"
            }],
            "usageMetadata": { "promptTokenCount": 5, "candidatesTokenCount": 2 }
        }))
        .unwrap();
        // `&mut 0`: a fresh tool-call counter, i.e. the first event of a stream.
        let chunks: Vec<_> = map_response(resp, &mut 0)
            .into_iter()
            .map(Result::unwrap)
            .collect();
        assert_eq!(chunks[0], Chunk::text_delta("parity"));
        assert!(matches!(
            chunks[chunks.len() - 1],
            Chunk::Stop(StopReason::EndTurn)
        ));
    }

    /// A function call emits tool-call chunks, and a `STOP` finish alongside it
    /// is reported as `ToolUse` rather than `EndTurn`.
    #[test]
    fn maps_function_call_to_tool_chunks() {
        let resp: GenerateContentResponse = serde_json::from_value(serde_json::json!({
            "candidates": [{
                "content": { "parts": [{ "functionCall": { "name": "search", "args": { "q": "rust" } } }] },
                "finishReason": "STOP"
            }]
        }))
        .unwrap();
        let chunks: Vec<_> = map_response(resp, &mut 0)
            .into_iter()
            .map(Result::unwrap)
            .collect();
        assert!(
            chunks
                .iter()
                .any(|c| matches!(c, Chunk::ToolCallStart { name, .. } if name == "search"))
        );
        assert!(matches!(
            chunks[chunks.len() - 1],
            Chunk::Stop(StopReason::ToolUse)
        ));
    }

    /// `Role::System` text goes to `systemInstruction`; user turns go to
    /// `contents` with role `"user"`.
    #[test]
    fn build_request_maps_roles_and_system() {
        let mut req = CompletionRequest::new("m");
        // No dedicated system prompt: the instruction must come solely from
        // the `Message::system` entry below.
        req.system = None;
        req.messages = vec![Message::system("be terse"), Message::user("hi")];
        let body = build_request(&req);
        assert_eq!(body["systemInstruction"]["parts"][0]["text"], "be terse");
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "hi");
    }

    /// Draft-07 keywords are stripped at every depth, while supported keywords
    /// (`type`, `maximum`, …) survive.
    #[test]
    fn build_request_strips_gemini_incompatible_tool_schema_keys() {
        use polyc_llm::ToolSpec;
        let mut req = CompletionRequest::new("m");
        // Gemini's function-declaration parser rejects draft-07 keywords with a
        // 400 (verified on gemini-3.1-pro-preview 2026-06-05), so build_request
        // must strip them recursively before sending.
        req.tools = vec![ToolSpec {
            name: "list_recent".to_owned(),
            description: "recent".to_owned(),
            schema_json: serde_json::json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "type": "object",
                "additionalProperties": false,
                "properties": {
                    "limit": { "type": "integer", "exclusiveMinimum": 0, "maximum": 100 }
                }
            }),
            title: None,
            needs_approval: false,
        }];
        let params = &build_request(&req)["tools"][0]["functionDeclarations"][0]["parameters"];
        assert!(params.get("$schema").is_none());
        assert!(params.get("additionalProperties").is_none());
        // Nested keyword inside a property's sub-schema is stripped too.
        assert!(
            params["properties"]["limit"]
                .get("exclusiveMinimum")
                .is_none()
        );
        // Valid keywords survive.
        assert_eq!(params["type"], "object");
        assert_eq!(params["properties"]["limit"]["type"], "integer");
        assert_eq!(params["properties"]["limit"]["maximum"], 100);
    }

    /// `web_search` toggles a `googleSearch` entry that coexists with function
    /// declarations in the single `tools` array.
    #[test]
    fn web_search_adds_google_search_grounding_tool() {
        use polyc_llm::ToolSpec;
        // web_search off: no tools array at all when there are no function tools.
        let off = CompletionRequest::new("m");
        assert!(build_request(&off).get("tools").is_none());

        // web_search on, no function tools: a lone googleSearch entry.
        let mut grounded = CompletionRequest::new("m");
        grounded.web_search = true;
        let tools = build_request(&grounded)["tools"].clone();
        assert_eq!(tools, serde_json::json!([{ "googleSearch": {} }]));

        // web_search on WITH a function tool: both entries coexist in one array
        // (Gemini 3 combines built-in + custom tools in a single request).
        grounded.tools = vec![ToolSpec {
            name: "list_recent".to_owned(),
            description: "recent".to_owned(),
            schema_json: serde_json::json!({ "type": "object" }),
            title: None,
            needs_approval: false,
        }];
        let tools = build_request(&grounded)["tools"].clone();
        assert_eq!(tools.as_array().map(Vec::len), Some(2));
        // Function declarations come first, the grounding tool second.
        assert!(tools[0].get("functionDeclarations").is_some());
        assert_eq!(tools[1], serde_json::json!({ "googleSearch": {} }));
    }
}