use super::model::ModelConfig;
use super::traits::*;
use crate::types::*;
use async_trait::async_trait;
use tokio::sync::mpsc;
/// Stateless provider that streams completions from Google Vertex AI.
///
/// The type carries no fields; everything needed for a request (base URL, headers,
/// API key, model id) is read from the `StreamConfig` passed to `stream`.
pub struct GoogleVertexProvider;
impl GoogleVertexProvider {
    /// Builds the server-sent-events streaming endpoint for `model`.
    ///
    /// The result is `<base_url>/<model>:streamGenerateContent?alt=sse`, where `base_url`
    /// is taken verbatim from `model_config` (no slash normalisation is performed).
    fn vertex_url(model_config: &ModelConfig, model: &str) -> String {
        let base_url = &model_config.base_url;
        let model_id = model;
        format!("{}/{}:streamGenerateContent?alt=sse", base_url, model_id)
    }
}
#[async_trait]
impl StreamProvider for GoogleVertexProvider {
fn provider_id(&self) -> &str {
"vertex"
}
async fn stream(
&self,
config: StreamConfig, tx: mpsc::UnboundedSender<StreamEvent>, cancel: tokio_util::sync::CancellationToken, ) -> Result<Message, ProviderError> {
let model_config = &config.model_config;
let api_key = model_config.resolve_api_key().await?;
let vertex_url = Self::vertex_url(model_config, &config.model_config.id);
let mut vertex_model = model_config.clone();
vertex_model
.headers
.insert("authorization".to_string(), format!("Bearer {}", api_key));
let body = build_vertex_request_body(&config);
let client = reqwest::Client::new();
let mut request = client
.post(&vertex_url)
.header("content-type", "application/json");
for (k, v) in &vertex_model.headers {
request = request.header(k, v);
}
let response = request
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(ProviderError::classify(
status.as_u16(),
&format!("Vertex AI error {}: {}", status, body),
));
}
parse_google_sse_response(response, &config, &model_config.provider, tx, cancel).await
}
}
async fn parse_google_sse_response(
response: reqwest::Response,
config: &StreamConfig,
provider_name: &str,
tx: mpsc::UnboundedSender<StreamEvent>,
cancel: tokio_util::sync::CancellationToken,
) -> Result<Message, ProviderError> {
use futures::StreamExt;
use serde::Deserialize;
use tracing::{debug, warn};
let mut content: Vec<Content> = Vec::new();
let mut usage = Usage::default();
let mut stop_reason = StopReason::Stop;
let _ = tx.send(StreamEvent::Start);
let mut stream = response.bytes_stream();
let mut buffer = String::new();
loop {
tokio::select! {
_ = cancel.cancelled() => {
return Err(ProviderError::Cancelled);
}
chunk = stream.next() => {
match chunk {
None => break,
Some(Err(e)) => {
warn!("Vertex stream error: {}", e);
break;
}
Some(Ok(bytes)) => {
buffer.push_str(&String::from_utf8_lossy(&bytes));
while let Some(pos) = buffer.find("\n\n") {
let event_str = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
let data = event_str
.lines()
.find(|l| l.starts_with("data: "))
.map(|l| &l[6..])
.unwrap_or("");
if data.is_empty() {
continue;
}
#[derive(Deserialize)]
struct Chunk {
#[serde(default)]
candidates: Option<Vec<Candidate>>,
#[serde(default, rename = "usageMetadata")]
usage_metadata: Option<UsageMeta>,
}
#[derive(Deserialize)]
struct Candidate {
#[serde(default)]
content: Option<CContent>,
#[serde(default, rename = "finishReason")]
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct CContent {
#[serde(default)]
parts: Vec<Part>,
}
#[derive(Deserialize)]
struct Part {
#[serde(default)]
text: Option<String>,
#[serde(default, rename = "functionCall")]
function_call: Option<FCall>,
}
#[derive(Deserialize)]
struct FCall {
name: String,
#[serde(default)]
args: Option<serde_json::Value>,
}
#[derive(Deserialize)]
struct UsageMeta {
#[serde(default, rename = "promptTokenCount")]
prompt_token_count: Option<u64>,
#[serde(default, rename = "candidatesTokenCount")]
candidates_token_count: Option<u64>,
#[serde(default, rename = "totalTokenCount")]
total_token_count: Option<u64>,
}
let parsed: Chunk = match serde_json::from_str(data) {
Ok(c) => c,
Err(e) => {
debug!("Failed to parse Vertex chunk: {}", e);
continue;
}
};
for candidate in parsed.candidates.unwrap_or_default() {
if let Some(c) = candidate.content {
for part in c.parts {
if let Some(text) = part.text {
let idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
let idx = match idx {
Some(i) => i,
None => {
content.push(Content::Text { text: String::new() });
content.len() - 1
}
};
if let Some(Content::Text { text: t }) = content.get_mut(idx) {
t.push_str(&text);
}
let _ = tx.send(StreamEvent::TextDelta {
content_index: idx,
delta: text,
});
}
if let Some(fc) = part.function_call {
let id = format!("vertex-fc-{}", content.len());
let args = fc.args.unwrap_or(serde_json::Value::Object(Default::default()));
let idx = content.len();
content.push(Content::ToolCall {
id: id.clone(),
name: fc.name.clone(),
arguments: args,
});
let _ = tx.send(StreamEvent::ToolCallStart {
content_index: idx,
id,
name: fc.name,
});
let _ = tx.send(StreamEvent::ToolCallEnd { content_index: idx });
stop_reason = StopReason::ToolUse;
}
}
}
if let Some(reason) = candidate.finish_reason {
stop_reason = match reason.as_str() {
"STOP" => StopReason::Stop,
"MAX_TOKENS" => StopReason::Length,
_ => StopReason::Stop,
};
}
}
if let Some(u) = parsed.usage_metadata {
usage.input = u.prompt_token_count.unwrap_or(0);
usage.output = u.candidates_token_count.unwrap_or(0);
usage.total_tokens = u.total_token_count.unwrap_or(0);
}
}
}
}
}
}
}
let message = Message::Assistant {
content,
stop_reason,
model: config.model_config.id.clone(),
provider: provider_name.to_string(),
usage,
timestamp: now_ms(),
error_message: None,
};
let _ = tx.send(StreamEvent::Done {
message: message.clone(),
});
Ok(message)
}
fn build_vertex_request_body(config: &StreamConfig) -> serde_json::Value {
let mut contents: Vec<serde_json::Value> = Vec::new();
for msg in &config.messages {
match msg {
Message::User { content, .. } => {
let parts: Vec<serde_json::Value> = content
.iter()
.filter_map(|c| match c {
Content::Text { text } => Some(serde_json::json!({"text": text})),
Content::Image { data, mime_type } => Some(serde_json::json!({
"inlineData": {"mimeType": mime_type, "data": data},
})),
_ => None,
})
.collect();
contents.push(serde_json::json!({"role": "user", "parts": parts}));
}
Message::Assistant { content, .. } => {
let parts: Vec<serde_json::Value> = content
.iter()
.filter_map(|c| match c {
Content::Text { text } => Some(serde_json::json!({"text": text})),
Content::ToolCall {
name, arguments, ..
} => Some(serde_json::json!({
"functionCall": {"name": name, "args": arguments},
})),
_ => None,
})
.collect();
contents.push(serde_json::json!({"role": "model", "parts": parts}));
}
Message::ToolResult {
tool_name, content, ..
} => {
let text = content
.iter()
.find_map(|c| match c {
Content::Text { text } => Some(text.clone()),
_ => None,
})
.unwrap_or_default();
let mut parts = vec![serde_json::json!({
"functionResponse": {"name": tool_name, "response": {"result": text}}
})];
for c in content {
if let Content::Image { data, mime_type } = c {
parts.push(serde_json::json!({
"inlineData": {"mimeType": mime_type, "data": data},
}));
}
}
contents.push(serde_json::json!({
"role": "user",
"parts": parts,
}));
}
}
}
let mut body = serde_json::json!({"contents": contents});
if !config.system_prompt.is_empty() {
body["systemInstruction"] = serde_json::json!({"parts": [{"text": config.system_prompt}]});
}
let mut gen_config = serde_json::json!({});
if let Some(max) = config.max_tokens {
gen_config["maxOutputTokens"] = serde_json::json!(max);
}
if let Some(temp) = config.temperature {
gen_config["temperature"] = serde_json::json!(temp);
}
match &config.response_format {
ResponseFormat::Text => {}
ResponseFormat::JsonObject => {
gen_config["responseMimeType"] = serde_json::json!("application/json");
}
ResponseFormat::JsonSchema { schema, .. } => {
gen_config["responseMimeType"] = serde_json::json!("application/json");
gen_config["responseSchema"] = schema.clone();
}
}
if gen_config != serde_json::json!({}) {
body["generationConfig"] = gen_config;
}
if !config.tools.is_empty() {
let declarations: Vec<serde_json::Value> = config
.tools
.iter()
.map(|t| {
serde_json::json!({
"name": t.name,
"description": t.description,
"parameters": t.parameters,
})
})
.collect();
body["tools"] = serde_json::json!([{"functionDeclarations": declarations}]);
}
body
}