use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::{BoxStream, StreamExt};
use gcp_auth::TokenProvider;
use polyc_llm::{
Chunk, CompletionRequest, Content, LlmProvider, Message, Role, StopReason, Usage,
sse::next_event_boundary,
};
use serde::Deserialize;
/// OAuth scope requested when fetching bearer tokens for Vertex AI calls.
const SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
/// Maximum time allowed for the HTTP client's connection phase.
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// Per-read timeout for the HTTP client (two minutes); with reqwest it resets after each
/// successful read, so it bounds stalls between streamed chunks rather than the whole stream.
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
/// Where and what to call on Vertex AI.
#[derive(Debug, Clone)]
pub struct VertexConfig {
/// GCP project substituted into the request path (`/projects/{project}`).
pub project: String,
/// Vertex location; `"global"` uses the `aiplatform.googleapis.com` host, any other value
/// uses the `{location}-aiplatform.googleapis.com` host.
pub location: String,
/// Model used when a request leaves `CompletionRequest::model` empty.
pub model: String,
}
/// Errors produced by [`VertexProvider`].
#[derive(Debug, thiserror::Error)]
pub enum VertexError {
/// GCP credential discovery or access-token acquisition failed.
#[error("auth: {0}")]
Auth(#[from] gcp_auth::Error),
/// Transport-level failure reported by `reqwest`.
#[error("http: {0}")]
Http(#[from] reqwest::Error),
/// The service answered with a non-success status, or a streamed SSE payload could not
/// be parsed as JSON (reported with `status == 0`).
#[error("provider returned status {status}: {body}")]
Provider {
/// HTTP status code, or `0` for a malformed streamed payload.
status: u16,
/// Response body text (empty if it could not be read), or a description of the
/// malformed payload.
body: String,
},
}
impl polyc_llm::LlmError for VertexError {
    /// Classifies this error into the provider-neutral `LlmErrorKind`.
    ///
    /// Auth failures map to `Auth`; transport timeouts to `Timeout` and every other transport
    /// failure to `Unavailable`; provider errors are classified by their status code via
    /// `polyc_llm::kind_from_http_status` (including the `0` used for malformed SSE payloads).
    fn kind(&self) -> polyc_llm::LlmErrorKind {
        use polyc_llm::LlmErrorKind as Kind;
        match self {
            Self::Auth(_) => Kind::Auth,
            Self::Http(err) => {
                if err.is_timeout() {
                    Kind::Timeout
                } else {
                    Kind::Unavailable
                }
            }
            Self::Provider { status, .. } => polyc_llm::kind_from_http_status(*status),
        }
    }
}
/// [`LlmProvider`] that streams completions from Google publisher models on Vertex AI.
pub struct VertexProvider {
/// HTTP client used for every request.
http: reqwest::Client,
/// Source of OAuth bearer tokens, requested with the `cloud-platform` scope.
tokens: Arc<dyn TokenProvider>,
/// Project, location and default model for requests.
config: VertexConfig,
}
impl VertexProvider {
    /// Creates a provider that authenticates with credentials discovered by
    /// `gcp_auth::provider()` and sends requests through a client configured with
    /// `CONNECT_TIMEOUT` and `READ_TIMEOUT`.
    ///
    /// # Errors
    ///
    /// Returns [`VertexError::Auth`] if no credential provider is available and
    /// [`VertexError::Http`] if the HTTP client cannot be constructed.
    pub async fn new(config: VertexConfig) -> Result<Self, VertexError> {
        let tokens = gcp_auth::provider().await?;
        // Surface builder failures instead of falling back to `reqwest::Client::new()`:
        // that fallback silently discarded both timeouts (so a stalled stream could hang
        // indefinitely), and `Client::new()` itself panics when client construction fails.
        let http = reqwest::Client::builder()
            .connect_timeout(CONNECT_TIMEOUT)
            .read_timeout(READ_TIMEOUT)
            .build()?;
        Ok(Self {
            http,
            tokens,
            config,
        })
    }

    /// Returns the SSE `streamGenerateContent` URL for `model` under the configured project
    /// and location.
    ///
    /// The `"global"` location is served from `aiplatform.googleapis.com`; every other
    /// location uses the regional `{location}-aiplatform.googleapis.com` host.
    fn endpoint(&self, model: &str) -> String {
        let VertexConfig {
            project, location, ..
        } = &self.config;
        let host = if location == "global" {
            "aiplatform.googleapis.com".to_owned()
        } else {
            format!("{location}-aiplatform.googleapis.com")
        };
        format!(
            "https://{host}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent?alt=sse"
        )
    }
}
#[async_trait]
impl LlmProvider for VertexProvider {
type Error = VertexError;
async fn complete(
&self,
req: CompletionRequest,
) -> Result<BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
let model = if req.model.is_empty() {
self.config.model.as_str()
} else {
req.model.as_str()
};
let body = build_request(&req);
tracing::debug!(
model = %model,
messages = req.messages.len(),
tools = req.tools.len(),
max_tokens = ?req.max_tokens,
temperature = ?req.temperature,
body = %body,
"vertex request"
);
let token = self.tokens.token(&[SCOPE]).await?;
let resp = self
.http
.post(self.endpoint(model))
.bearer_auth(token.as_str())
.json(&body)
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return Err(VertexError::Provider {
status: status.as_u16(),
body,
});
}
let byte_stream = resp.bytes_stream();
let chunks = async_stream::stream! {
use futures::StreamExt as _;
let mut byte_stream = byte_stream;
let mut buf: Vec<u8> = Vec::new();
let mut tool_seq = 0usize;
while let Some(item) = byte_stream.next().await {
let bytes = match item {
Ok(b) => b,
Err(e) => { yield Err(VertexError::from(e)); return; }
};
buf.extend_from_slice(&bytes);
while let Some((pos, sep_len)) = next_event_boundary(&buf) {
let event_bytes: Vec<u8> = buf.drain(..pos + sep_len).collect();
let event = std::str::from_utf8(&event_bytes[..event_bytes.len() - sep_len])
.unwrap_or("");
for line in event.lines() {
let Some(json) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) else {
continue;
};
tracing::debug!(event = %json, "vertex sse event");
match serde_json::from_str::<GenerateContentResponse>(json) {
Ok(resp) => {
for chunk in map_response(resp, &mut tool_seq) {
yield chunk;
}
}
Err(err) => {
yield Err(VertexError::Provider {
status: 0,
body: format!("malformed SSE JSON: {err}; line: {json}"),
});
}
}
}
}
}
};
Ok(chunks.boxed())
}
}
/// Builds the JSON body for Vertex `streamGenerateContent` from a provider-neutral request.
///
/// * Text content of `Role::System` messages is collected, in order, into
///   `systemInstruction.parts`; non-text content of those messages is ignored.
/// * Every other message becomes a `contents` entry whose role is `"model"` for
///   `Role::Assistant` and `"user"` otherwise; messages that produce no parts are skipped.
/// * `max_tokens`, `temperature` and `stop` populate `generationConfig` (omitted when none is set).
/// * Tool specs become a single `functionDeclarations` tool entry, followed by a
///   `googleSearch` entry when `web_search` is enabled; `tools` is omitted when empty.
///
/// NOTE(review): `req.system` is never read here — only system *messages* reach
/// `systemInstruction`. Confirm that callers fold `req.system` into `req.messages`;
/// otherwise a system prompt supplied that way is silently dropped.
fn build_request(req: &CompletionRequest) -> serde_json::Value {
    let (system_messages, conversation): (Vec<_>, Vec<_>) =
        req.messages.iter().partition(|msg| msg.role == Role::System);

    let system_parts: Vec<serde_json::Value> = system_messages
        .into_iter()
        .flat_map(|msg| &msg.content)
        .filter_map(|content| match content {
            Content::Text(t) => Some(serde_json::json!({ "text": t })),
            _ => None,
        })
        .collect();

    let contents: Vec<serde_json::Value> = conversation
        .into_iter()
        .filter_map(|msg| {
            let parts = message_parts(msg);
            if parts.is_empty() {
                return None;
            }
            let role = if msg.role == Role::Assistant {
                "model"
            } else {
                "user"
            };
            Some(serde_json::json!({ "role": role, "parts": parts }))
        })
        .collect();

    let mut body = serde_json::json!({ "contents": contents });
    if !system_parts.is_empty() {
        body["systemInstruction"] = serde_json::json!({ "parts": system_parts });
    }
    if let Some(generation_config) = gemini_generation_config(req) {
        body["generationConfig"] = generation_config;
    }
    if let Some(tools) = gemini_tools(req) {
        body["tools"] = tools;
    }
    body
}

/// Collects the sampling controls sent under `generationConfig`.
///
/// Returns `None` when none of `max_tokens`, `temperature` or `stop` is set.
fn gemini_generation_config(req: &CompletionRequest) -> Option<serde_json::Value> {
    let mut config = serde_json::Map::new();
    if let Some(max) = req.max_tokens {
        config.insert("maxOutputTokens".into(), max.into());
    }
    if let Some(temp) = req.temperature {
        config.insert("temperature".into(), temp.into());
    }
    if !req.stop.is_empty() {
        config.insert("stopSequences".into(), serde_json::json!(req.stop));
    }
    if config.is_empty() {
        None
    } else {
        Some(serde_json::Value::Object(config))
    }
}

/// Assembles the `tools` array: function declarations first, then Google Search grounding.
///
/// Returns `None` when the request has neither tools nor `web_search`.
fn gemini_tools(req: &CompletionRequest) -> Option<serde_json::Value> {
    let declarations = (!req.tools.is_empty()).then(|| {
        let decls: Vec<serde_json::Value> = req
            .tools
            .iter()
            .map(|tool| {
                serde_json::json!({
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": sanitize_schema_for_gemini(&tool.schema_json),
                })
            })
            .collect();
        serde_json::json!({ "functionDeclarations": decls })
    });
    let grounding = req
        .web_search
        .then(|| serde_json::json!({ "googleSearch": {} }));
    let entries: Vec<serde_json::Value> = declarations.into_iter().chain(grounding).collect();
    if entries.is_empty() {
        None
    } else {
        Some(serde_json::Value::Array(entries))
    }
}
/// JSON Schema keywords that `sanitize_schema_for_gemini` removes from tool schemas before
/// they are sent as `functionDeclarations[].parameters`.
// NOTE(review): presumably these are the keywords Vertex's schema subset rejects; confirm
// against the Vertex `Schema` reference when adding entries.
const GEMINI_UNSUPPORTED_SCHEMA_KEYS: &[&str] = &[
// JSON Schema meta / reference keywords.
"$schema",
"$id",
"$ref",
"$defs",
"$comment",
"definitions",
// Object-shape constraints.
"additionalProperties",
"unevaluatedProperties",
"patternProperties",
// Exclusive numeric bounds.
"exclusiveMinimum",
"exclusiveMaximum",
];
/// Returns a copy of a tool's JSON Schema with keywords listed in
/// `GEMINI_UNSUPPORTED_SCHEMA_KEYS` removed, recursively.
///
/// The keys of a `properties` object are parameter *names*, not schema keywords, so they
/// are always kept (a parameter literally named e.g. `definitions` or `$id` is no longer
/// deleted while `required` still lists it); only the schemas under them are sanitized.
///
/// NOTE(review): dropping `$ref`/`$defs` leaves referencing schemas without their target;
/// referenced definitions are not inlined here.
fn sanitize_schema_for_gemini(value: &serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Object(map) => serde_json::Value::Object(
            map.iter()
                .filter(|(k, _)| !GEMINI_UNSUPPORTED_SCHEMA_KEYS.contains(&k.as_str()))
                .map(|(k, v)| {
                    let sanitized = if k.as_str() == "properties" {
                        sanitize_property_map(v)
                    } else {
                        sanitize_schema_for_gemini(v)
                    };
                    (k.clone(), sanitized)
                })
                .collect(),
        ),
        serde_json::Value::Array(arr) => {
            serde_json::Value::Array(arr.iter().map(sanitize_schema_for_gemini).collect())
        }
        other => other.clone(),
    }
}

/// Sanitizes every property schema in a `properties` object while keeping each property name.
fn sanitize_property_map(value: &serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Object(props) => serde_json::Value::Object(
            props
                .iter()
                .map(|(name, schema)| (name.clone(), sanitize_schema_for_gemini(schema)))
                .collect(),
        ),
        // Not a name → schema map; fall back to the generic treatment.
        other => sanitize_schema_for_gemini(other),
    }
}
/// Converts one non-system message's content into Gemini `parts`.
///
/// Text becomes `{ "text": .. }`, tool calls become `functionCall` parts (carrying the
/// `thoughtSignature` when present), and tool results become `functionResponse` parts with
/// the payload wrapped as `{ "result": .. }`. Other content kinds are not forwarded.
fn message_parts(msg: &Message) -> Vec<serde_json::Value> {
    let mut parts = Vec::new();
    for c in &msg.content {
        match c {
            Content::Text(t) => parts.push(serde_json::json!({ "text": t })),
            Content::ToolUse(tc) => {
                // Unparseable arguments are sent as JSON null.
                let args: serde_json::Value =
                    serde_json::from_str(&tc.args_json).unwrap_or(serde_json::Value::Null);
                let mut part = serde_json::json!({
                    "functionCall": { "name": tc.name, "args": args }
                });
                if let Some(sig) = &tc.signature {
                    part["thoughtSignature"] = serde_json::json!(sig);
                }
                parts.push(part);
            }
            Content::ToolResult(tr) => {
                // Tool output that is not valid JSON (e.g. a plain-text error) is forwarded as
                // a JSON string instead of being replaced with null, so the model still sees it.
                let result: serde_json::Value = serde_json::from_str(&tr.result_json)
                    .unwrap_or_else(|_| serde_json::Value::String(tr.result_json.to_string()));
                // NOTE(review): Gemini documents `functionResponse.name` as the function name
                // (matching `functionCall.name`), but this sends the tool-call id, which
                // `map_response` synthesizes as `call-N`. Confirm whether `ToolResult` carries
                // the tool name, or resolve it from the matching `ToolUse` earlier in the
                // conversation.
                parts.push(serde_json::json!({
                    "functionResponse": { "name": tr.tool_call_id, "response": { "result": result } }
                }));
            }
            _ => {}
        }
    }
    parts
}
/// Converts one streamed `GenerateContentResponse` into provider-neutral chunks.
///
/// Only the first candidate is used. Items are emitted in this order: one aggregated text
/// delta (if any text), a start/args/end triple per function call, a usage chunk (when usage
/// metadata is present), and finally a `Stop` when the payload carries a finish reason or
/// function calls.
///
/// `tool_seq` is the stream's tool-call counter; each call gets the id `call-{n}` and bumps
/// it. Because it only increases, `*tool_seq > 0` also records that this stream already
/// produced a tool call (assuming the caller starts it at 0 per stream). That lets a later
/// payload carrying only `finishReason: STOP` still report `StopReason::ToolUse` instead of
/// overriding it with `EndTurn`.
fn map_response(
    resp: GenerateContentResponse,
    tool_seq: &mut usize,
) -> Vec<Result<Chunk, VertexError>> {
    let mut chunks = Vec::new();
    let mut text = String::new();
    let mut tool_calls = Vec::new();
    let mut finish = None;
    if let Some(candidate) = resp.candidates.into_iter().next() {
        finish = candidate.finish_reason;
        for part in candidate.content.into_iter().flat_map(|content| content.parts) {
            if let Some(t) = part.text {
                text.push_str(&t);
            }
            if let Some(fc) = part.function_call {
                tool_calls.push((fc, part.thought_signature));
            }
        }
    }
    if !text.is_empty() {
        chunks.push(Ok(Chunk::text_delta(text)));
    }
    let payload_has_tool_calls = !tool_calls.is_empty();
    for (fc, signature) in tool_calls {
        let id = format!("call-{}", *tool_seq);
        *tool_seq += 1;
        // A function call without `args` deserializes to null; report an empty JSON object
        // so consumers always receive an object rather than the literal `null`.
        let args = if fc.args.is_null() {
            "{}".to_owned()
        } else {
            fc.args.to_string()
        };
        chunks.push(Ok(Chunk::tool_call_start_signed(
            id.clone(),
            fc.name,
            signature,
        )));
        chunks.push(Ok(Chunk::tool_call_args_delta(id.clone(), args)));
        chunks.push(Ok(Chunk::tool_call_end(id)));
    }
    if let Some(u) = resp.usage_metadata {
        chunks.push(Ok(Chunk::Usage(Usage {
            input_tokens: u.prompt_token_count,
            output_tokens: u.candidates_token_count,
        })));
    }
    if finish.is_some() || payload_has_tool_calls {
        // A normal end of turn after any tool call in this stream means "run the tools".
        let stop = match map_finish_reason(finish.as_deref()) {
            StopReason::EndTurn if *tool_seq > 0 => StopReason::ToolUse,
            other => other,
        };
        chunks.push(Ok(Chunk::Stop(stop)));
    }
    chunks
}
/// Maps a Gemini `finishReason` string onto the provider-neutral [`StopReason`].
///
/// `MAX_TOKENS` and `STOP_SEQUENCE` map directly; the blocking reasons `SAFETY`,
/// `RECITATION`, `BLOCKLIST`, `PROHIBITED_CONTENT` and `SPII` become `Refusal`. Anything
/// else — `STOP`, unrecognized codes, or no reason at all — is treated as `EndTurn`.
fn map_finish_reason(reason: Option<&str>) -> StopReason {
    const REFUSAL_CODES: [&str; 5] = [
        "SAFETY",
        "RECITATION",
        "BLOCKLIST",
        "PROHIBITED_CONTENT",
        "SPII",
    ];
    let Some(code) = reason else {
        return StopReason::EndTurn;
    };
    if code == "MAX_TOKENS" {
        StopReason::MaxTokens
    } else if code == "STOP_SEQUENCE" {
        StopReason::StopSequence
    } else if REFUSAL_CODES.contains(&code) {
        StopReason::Refusal
    } else {
        StopReason::EndTurn
    }
}
/// One `GenerateContentResponse` payload, as carried by a single SSE `data:` line.
///
/// Absent fields deserialize to an empty list / `None`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
    /// Candidate completions (only the first is consumed downstream).
    #[serde(default)]
    candidates: Vec<Candidate>,
    /// Token accounting (`usageMetadata`), when the payload includes it.
    #[serde(default)]
    usage_metadata: Option<UsageMetadata>,
}
/// A single candidate in a response payload.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct Candidate {
    /// Generated content for this payload, if any.
    #[serde(default)]
    content: Option<RespContent>,
    /// Raw `finishReason` code, present once the candidate has finished.
    #[serde(default)]
    finish_reason: Option<String>,
}
/// Candidate content; only its `parts` are read (any `role` field is ignored).
#[derive(Deserialize)]
struct RespContent {
/// Content parts in order; missing means empty.
#[serde(default)]
parts: Vec<Part>,
}
/// One content part; a part may carry text, a function call, or both fields may be absent.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct Part {
    /// Text fragment of this part.
    #[serde(default)]
    text: Option<String>,
    /// Function call requested by the model (`functionCall`).
    #[serde(default)]
    function_call: Option<FunctionCall>,
    /// Opaque `thoughtSignature`, forwarded with the function call it accompanies.
    #[serde(default)]
    thought_signature: Option<String>,
}
/// A function call emitted by the model.
#[derive(Deserialize)]
struct FunctionCall {
/// Name of the declared function to invoke.
name: String,
/// Call arguments as raw JSON; `null` when the field is absent.
#[serde(default)]
args: serde_json::Value,
}
/// Token counts reported in `usageMetadata`; missing counts default to 0.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct UsageMetadata {
    /// `promptTokenCount`: tokens consumed by the request.
    #[serde(default)]
    prompt_token_count: u64,
    /// `candidatesTokenCount`: tokens generated in the candidates.
    #[serde(default)]
    candidates_token_count: u64,
}
#[cfg(test)]
mod tests {
#![allow(clippy::pedantic, clippy::nursery, missing_docs)]
// Offline unit tests for request building and response mapping (no network, no auth).
use super::*;
// A text part with finishReason STOP: first chunk is the text delta, last is Stop(EndTurn).
#[test]
fn maps_text_and_usage_and_stop() {
let resp: GenerateContentResponse = serde_json::from_value(serde_json::json!({
"candidates": [{
"content": { "role": "model", "parts": [{ "text": "parity" }] },
"finishReason": "STOP"
}],
"usageMetadata": { "promptTokenCount": 5, "candidatesTokenCount": 2 }
}))
.unwrap();
let chunks: Vec<_> = map_response(resp, &mut 0)
.into_iter()
.map(Result::unwrap)
.collect();
assert_eq!(chunks[0], Chunk::text_delta("parity"));
assert!(matches!(
chunks[chunks.len() - 1],
Chunk::Stop(StopReason::EndTurn)
));
}
// A functionCall part yields a ToolCallStart for that function, and STOP becomes ToolUse.
#[test]
fn maps_function_call_to_tool_chunks() {
let resp: GenerateContentResponse = serde_json::from_value(serde_json::json!({
"candidates": [{
"content": { "parts": [{ "functionCall": { "name": "search", "args": { "q": "rust" } } }] },
"finishReason": "STOP"
}]
}))
.unwrap();
let chunks: Vec<_> = map_response(resp, &mut 0)
.into_iter()
.map(Result::unwrap)
.collect();
assert!(
chunks
.iter()
.any(|c| matches!(c, Chunk::ToolCallStart { name, .. } if name == "search"))
);
assert!(matches!(
chunks[chunks.len() - 1],
Chunk::Stop(StopReason::ToolUse)
));
}
// System messages go to systemInstruction; user messages become "user" contents.
#[test]
fn build_request_maps_roles_and_system() {
let mut req = CompletionRequest::new("m");
req.system = None;
req.messages = vec![Message::system("be terse"), Message::user("hi")];
let body = build_request(&req);
assert_eq!(body["systemInstruction"]["parts"][0]["text"], "be terse");
assert_eq!(body["contents"][0]["role"], "user");
assert_eq!(body["contents"][0]["parts"][0]["text"], "hi");
}
// Unsupported keywords are removed at every depth while supported ones survive.
#[test]
fn build_request_strips_gemini_incompatible_tool_schema_keys() {
use polyc_llm::ToolSpec;
let mut req = CompletionRequest::new("m");
req.tools = vec![ToolSpec {
name: "list_recent".to_owned(),
description: "recent".to_owned(),
schema_json: serde_json::json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"additionalProperties": false,
"properties": {
"limit": { "type": "integer", "exclusiveMinimum": 0, "maximum": 100 }
}
}),
title: None,
needs_approval: false,
}];
let params = &build_request(&req)["tools"][0]["functionDeclarations"][0]["parameters"];
assert!(params.get("$schema").is_none());
assert!(params.get("additionalProperties").is_none());
assert!(
params["properties"]["limit"]
.get("exclusiveMinimum")
.is_none()
);
assert_eq!(params["type"], "object");
assert_eq!(params["properties"]["limit"]["type"], "integer");
assert_eq!(params["properties"]["limit"]["maximum"], 100);
}
// `web_search` adds a googleSearch tool, placed after any functionDeclarations entry.
#[test]
fn web_search_adds_google_search_grounding_tool() {
use polyc_llm::ToolSpec;
let off = CompletionRequest::new("m");
assert!(build_request(&off).get("tools").is_none());
let mut grounded = CompletionRequest::new("m");
grounded.web_search = true;
let tools = build_request(&grounded)["tools"].clone();
assert_eq!(tools, serde_json::json!([{ "googleSearch": {} }]));
grounded.tools = vec![ToolSpec {
name: "list_recent".to_owned(),
description: "recent".to_owned(),
schema_json: serde_json::json!({ "type": "object" }),
title: None,
needs_approval: false,
}];
let tools = build_request(&grounded)["tools"].clone();
assert_eq!(tools.as_array().map(Vec::len), Some(2));
assert!(tools[0].get("functionDeclarations").is_some());
assert_eq!(tools[1], serde_json::json!({ "googleSearch": {} }));
}
}