use std::fmt;
use std::sync::Arc;
use futures_util::{StreamExt, stream};
use tokio_util::sync::CancellationToken;
use crate::gemini::{GeminiMapper, GeminiProvider, ParsedEvent, drain_sse_data, parse_sse_data};
use crate::http::HttpClient;
use crate::provider::{EventStream, ModelInfo, Provider, ProviderError, Request};
use crate::stream::AssistantStreamEvent;
pub struct VertexProvider {
access_token: String,
project: String,
location: String,
base_url: String,
models: Vec<ModelInfo>,
inner: GeminiProvider,
client: Arc<HttpClient>,
}
impl fmt::Debug for VertexProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VertexProvider")
.field("project", &self.project)
.field("location", &self.location)
.field("access_token", &"***")
.field("models", &self.models.len())
.finish()
}
}
impl VertexProvider {
pub fn new(
access_token: String,
project: String,
location: String,
base_url: Option<String>,
) -> Self {
let base_url =
base_url.unwrap_or_else(|| format!("https://{location}-aiplatform.googleapis.com"));
let inner = GeminiProvider::new(String::new(), None);
let models = default_vertex_models();
Self {
access_token,
project,
location,
base_url,
models,
inner,
client: Arc::new(HttpClient::new()),
}
}
pub fn from_config(
access_token: String,
project: String,
location: String,
models: Vec<String>,
base_url: Option<String>,
) -> Self {
let base_url =
base_url.unwrap_or_else(|| format!("https://{location}-aiplatform.googleapis.com"));
let inner = GeminiProvider::new(String::new(), None);
let model_list = models
.iter()
.map(|id| ModelInfo {
id: id.clone(),
display_name: id.clone(),
context_window: 1_000_000,
max_output_tokens: 65536,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
})
.collect();
Self {
access_token,
project,
location,
base_url,
models: model_list,
inner,
client: Arc::new(HttpClient::new()),
}
}
pub fn with_client(self, client: Arc<HttpClient>) -> Self {
Self { client, ..self }
}
pub fn build_vertex_url(&self, model_id: &str) -> String {
format!(
"{base}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent?alt=sse",
base = self.base_url,
project = self.project,
location = self.location,
model = model_id,
)
}
pub fn build_request_body(&self, request: &Request) -> serde_json::Value {
self.inner.build_request_body(request)
}
pub fn stream_from_sse(&self, sse_body: &str, cancel: CancellationToken) -> EventStream {
let mut mapper = GeminiMapper::new("vertex");
let mut stream_events: Vec<Result<AssistantStreamEvent, ProviderError>> = Vec::new();
for data in parse_sse_data(sse_body) {
for parsed in ParsedEvent::from_data(&data) {
match parsed {
ParsedEvent::Valid(event) => {
stream_events.extend(mapper.process(event).into_iter().map(Ok));
}
ParsedEvent::Malformed { data, error } => {
stream_events.push(Err(ProviderError::StreamError(format!(
"malformed SSE data: {error} (data: {data:.80})"
))));
}
}
}
}
let _cancel = cancel;
Box::pin(stream::iter(stream_events))
}
}
struct ReceiverStream {
rx: tokio::sync::mpsc::Receiver<Result<AssistantStreamEvent, ProviderError>>,
}
impl futures_core::Stream for ReceiverStream {
type Item = Result<AssistantStreamEvent, ProviderError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}
impl Provider for VertexProvider {
fn id(&self) -> &str {
"vertex"
}
fn models(&self) -> &[ModelInfo] {
&self.models
}
fn stream(&self, request: Request) -> EventStream {
let model_id = request
.model
.split_once(':')
.map(|(_, id)| id)
.unwrap_or(&request.model)
.to_string();
let url = self.build_vertex_url(&model_id);
let body = self.inner.build_request_body(&request);
let cancel = request.cancel;
let http_client = self.client.client().clone();
let access_token = self.access_token.clone();
let (tx, rx) = tokio::sync::mpsc::channel(64);
tokio::spawn(async move {
if let Err(e) =
stream_vertex_http(http_client, access_token, &url, &body, cancel, &tx).await
{
let _ = tx.send(Err(e)).await;
}
});
Box::pin(ReceiverStream { rx })
}
}
async fn stream_vertex_http(
http_client: reqwest::Client,
access_token: String,
url: &str,
body: &serde_json::Value,
cancel: CancellationToken,
tx: &tokio::sync::mpsc::Sender<Result<AssistantStreamEvent, ProviderError>>,
) -> Result<(), ProviderError> {
let req = http_client
.post(url)
.header("authorization", format!("Bearer {access_token}"))
.header("content-type", "application/json");
let response = req
.body(serde_json::to_string(body).unwrap_or_default())
.send()
.await
.map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let headers = response.headers().clone();
let error_body = response.text().await.unwrap_or_default();
return Err(map_vertex_status(status, &error_body, &headers));
}
let mut byte_stream = response.bytes_stream();
let mut buffer = String::new();
let mut mapper = GeminiMapper::new("vertex");
let mut saw_done = false;
loop {
let chunk = tokio::select! {
_ = cancel.cancelled() => {
return Ok(());
}
chunk = byte_stream.next() => {
match chunk {
Some(c) => c,
None => break,
}
}
};
let chunk = chunk.map_err(|e| ProviderError::StreamError(e.to_string()))?;
buffer.push_str(&String::from_utf8_lossy(&chunk));
for parsed in drain_sse_data(&mut buffer) {
match parsed {
ParsedEvent::Valid(event) => {
for stream_event in mapper.process(event) {
let is_terminal = matches!(
stream_event,
AssistantStreamEvent::Done { .. } | AssistantStreamEvent::Error { .. }
);
if tx.send(Ok(stream_event)).await.is_err() {
return Ok(());
}
if is_terminal {
saw_done = true;
}
}
}
ParsedEvent::Malformed { data, error } => {
let err = ProviderError::StreamError(format!(
"malformed SSE data: {error} (data: {data:.80})"
));
if tx.send(Err(err)).await.is_err() {
return Ok(());
}
}
}
}
}
if !saw_done {
let err = ProviderError::StreamError("stream ended without a terminal event".into());
let _ = tx.send(Err(err)).await;
}
Ok(())
}
fn map_vertex_status(
status: reqwest::StatusCode,
body: &str,
headers: &reqwest::header::HeaderMap,
) -> ProviderError {
match status.as_u16() {
401 | 403 => ProviderError::AuthFailed(format!("authentication failed: {body}")),
429 => ProviderError::RateLimited {
retry_after_ms: crate::retry::parse_retry_after(headers),
},
408 | 504 => ProviderError::Timeout,
_ => {
if let Ok(err_body) = serde_json::from_str::<serde_json::Value>(body)
&& let Some(code) = err_body
.get("error")
.and_then(|e| e.get("code"))
.and_then(|c| c.as_i64())
&& (code == 401 || code == 403)
{
return ProviderError::AuthFailed(format!("authentication failed: {body}"));
}
ProviderError::RequestFailed(format!("HTTP {}: {body}", status.as_u16()))
}
}
}
fn default_vertex_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "gemini-2.5-flash".into(),
display_name: "Gemini 2.5 Flash (Vertex)".into(),
context_window: 1_000_000,
max_output_tokens: 65536,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
},
ModelInfo {
id: "gemini-2.5-pro".into(),
display_name: "Gemini 2.5 Pro (Vertex)".into(),
context_window: 1_000_000,
max_output_tokens: 65536,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
},
]
}