1use std::sync::Arc;
17
18use async_trait::async_trait;
19use futures::stream::{BoxStream, StreamExt};
20use gcp_auth::TokenProvider;
21use polyc_llm::{
22 Chunk, CompletionRequest, Content, LlmProvider, Message, Role, StopReason, Usage,
23 sse::next_event_boundary,
24};
25use serde::Deserialize;
26
/// OAuth 2.0 scope requested for every Vertex AI access token.
const SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";

/// Limit for establishing a connection to the Vertex endpoint (reqwest `connect_timeout`).
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

/// Two minutes, passed to reqwest's `read_timeout`, which bounds each individual read of
/// the response rather than the whole streamed request.
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2 * 60);
34
/// Connection settings for [`VertexProvider`].
#[derive(Debug, Clone)]
pub struct VertexConfig {
    /// Google Cloud project ID; interpolated into the request URL path.
    pub project: String,
    /// Vertex AI location. It fills the `locations/{location}` path segment and also picks
    /// the host: `"global"` uses `aiplatform.googleapis.com`, any other value uses
    /// `{location}-aiplatform.googleapis.com`.
    pub location: String,
    /// Default publisher model ID, used when a request leaves its `model` field empty.
    pub model: String,
}
45
/// Errors produced by [`VertexProvider`].
#[derive(Debug, thiserror::Error)]
pub enum VertexError {
    /// Discovering credentials or fetching an access token through `gcp_auth` failed.
    #[error("auth: {0}")]
    Auth(#[from] gcp_auth::Error),
    /// A `reqwest` transport error, e.g. while sending the request or reading the
    /// streamed response body.
    #[error("http: {0}")]
    Http(#[from] reqwest::Error),
    /// Vertex answered with a non-success HTTP status, or a streamed SSE payload could
    /// not be decoded.
    #[error("provider returned status {status}: {body}")]
    Provider {
        /// HTTP status code; `0` when the error was raised locally while decoding a
        /// streamed event rather than returned by Vertex.
        status: u16,
        /// Response body text, or a description of the decoding failure.
        body: String,
    },
}
64
65impl polyc_llm::LlmError for VertexError {
66 fn kind(&self) -> polyc_llm::LlmErrorKind {
67 use polyc_llm::LlmErrorKind;
68 match self {
69 Self::Auth(_) => LlmErrorKind::Auth,
71 Self::Http(e) if e.is_timeout() => LlmErrorKind::Timeout,
72 Self::Http(_) => LlmErrorKind::Unavailable,
73 Self::Provider { status, .. } => polyc_llm::kind_from_http_status(*status),
74 }
75 }
76}
77
/// [`LlmProvider`] backed by Vertex AI's `streamGenerateContent` endpoint.
pub struct VertexProvider {
    /// Shared HTTP client, configured with `CONNECT_TIMEOUT` / `READ_TIMEOUT` in `new`.
    http: reqwest::Client,
    /// Source of OAuth bearer tokens for `SCOPE`.
    tokens: Arc<dyn TokenProvider>,
    /// Project, location and default model used to build request URLs.
    config: VertexConfig,
}
84
85impl VertexProvider {
86 pub async fn new(config: VertexConfig) -> Result<Self, VertexError> {
92 let tokens = gcp_auth::provider().await?;
93 let http = reqwest::Client::builder()
97 .connect_timeout(CONNECT_TIMEOUT)
98 .read_timeout(READ_TIMEOUT)
99 .build()
100 .unwrap_or_else(|_| reqwest::Client::new());
101 Ok(Self {
102 http,
103 tokens,
104 config,
105 })
106 }
107
108 fn endpoint(&self, model: &str) -> String {
115 let VertexConfig {
116 project, location, ..
117 } = &self.config;
118 let host = if location == "global" {
124 "aiplatform.googleapis.com".to_owned()
125 } else {
126 format!("{location}-aiplatform.googleapis.com")
127 };
128 format!(
129 "https://{host}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent?alt=sse"
130 )
131 }
132}
133
134#[async_trait]
135impl LlmProvider for VertexProvider {
136 type Error = VertexError;
137
138 async fn complete(
139 &self,
140 req: CompletionRequest,
141 ) -> Result<BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
142 let model = if req.model.is_empty() {
146 self.config.model.as_str()
147 } else {
148 req.model.as_str()
149 };
150 let body = build_request(&req);
151 tracing::debug!(
152 model = %model,
153 messages = req.messages.len(),
154 tools = req.tools.len(),
155 max_tokens = ?req.max_tokens,
156 temperature = ?req.temperature,
157 body = %body,
158 "vertex request"
159 );
160 let token = self.tokens.token(&[SCOPE]).await?;
161 let resp = self
162 .http
163 .post(self.endpoint(model))
164 .bearer_auth(token.as_str())
165 .json(&body)
166 .send()
167 .await?;
168
169 let status = resp.status();
170 if !status.is_success() {
171 let body = resp.text().await.unwrap_or_default();
172 return Err(VertexError::Provider {
173 status: status.as_u16(),
174 body,
175 });
176 }
177
178 let byte_stream = resp.bytes_stream();
184 let chunks = async_stream::stream! {
185 use futures::StreamExt as _;
186 let mut byte_stream = byte_stream;
187 let mut buf: Vec<u8> = Vec::new();
188 let mut tool_seq = 0usize;
191 while let Some(item) = byte_stream.next().await {
192 let bytes = match item {
193 Ok(b) => b,
194 Err(e) => { yield Err(VertexError::from(e)); return; }
195 };
196 buf.extend_from_slice(&bytes);
197 while let Some((pos, sep_len)) = next_event_boundary(&buf) {
198 let event_bytes: Vec<u8> = buf.drain(..pos + sep_len).collect();
199 let event = std::str::from_utf8(&event_bytes[..event_bytes.len() - sep_len])
201 .unwrap_or("");
202 for line in event.lines() {
203 let Some(json) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) else {
204 continue;
205 };
206 tracing::debug!(event = %json, "vertex sse event");
207 match serde_json::from_str::<GenerateContentResponse>(json) {
208 Ok(resp) => {
209 for chunk in map_response(resp, &mut tool_seq) {
210 yield chunk;
211 }
212 }
213 Err(err) => {
214 yield Err(VertexError::Provider {
217 status: 0,
218 body: format!("malformed SSE JSON: {err}; line: {json}"),
219 });
220 }
221 }
222 }
223 }
224 }
225 };
226 Ok(chunks.boxed())
227 }
228}
229
230fn build_request(req: &CompletionRequest) -> serde_json::Value {
232 let mut contents = Vec::new();
233 let mut system_parts = Vec::new();
234
235 for msg in &req.messages {
236 if msg.role == Role::System {
237 for c in &msg.content {
238 if let Content::Text(t) = c {
239 system_parts.push(serde_json::json!({ "text": t }));
240 }
241 }
242 } else {
243 let role = if msg.role == Role::Assistant {
244 "model"
245 } else {
246 "user"
247 };
248 let parts = message_parts(msg);
249 if !parts.is_empty() {
250 contents.push(serde_json::json!({ "role": role, "parts": parts }));
251 }
252 }
253 }
254
255 let mut body = serde_json::json!({ "contents": contents });
256 if !system_parts.is_empty() {
257 body["systemInstruction"] = serde_json::json!({ "parts": system_parts });
258 }
259 let mut gen_config = serde_json::Map::new();
260 if let Some(max) = req.max_tokens {
261 gen_config.insert("maxOutputTokens".into(), max.into());
262 }
263 if let Some(temp) = req.temperature {
264 gen_config.insert("temperature".into(), temp.into());
265 }
266 if !req.stop.is_empty() {
267 gen_config.insert("stopSequences".into(), serde_json::json!(req.stop));
268 }
269 if !gen_config.is_empty() {
270 body["generationConfig"] = serde_json::Value::Object(gen_config);
271 }
272 let mut tool_entries: Vec<serde_json::Value> = Vec::new();
279 if !req.tools.is_empty() {
280 let decls: Vec<_> = req
281 .tools
282 .iter()
283 .map(|t| {
284 serde_json::json!({
285 "name": t.name,
286 "description": t.description,
287 "parameters": sanitize_schema_for_gemini(&t.schema_json),
288 })
289 })
290 .collect();
291 tool_entries.push(serde_json::json!({ "functionDeclarations": decls }));
292 }
293 if req.web_search {
294 tool_entries.push(serde_json::json!({ "googleSearch": {} }));
295 }
296 if !tool_entries.is_empty() {
297 body["tools"] = serde_json::Value::Array(tool_entries);
298 }
299 body
300}
301
302const GEMINI_UNSUPPORTED_SCHEMA_KEYS: &[&str] = &[
316 "$schema",
317 "$id",
318 "$ref",
319 "$defs",
320 "$comment",
321 "definitions",
322 "additionalProperties",
323 "unevaluatedProperties",
324 "patternProperties",
325 "exclusiveMinimum",
326 "exclusiveMaximum",
327];
328
329fn sanitize_schema_for_gemini(value: &serde_json::Value) -> serde_json::Value {
333 match value {
334 serde_json::Value::Object(map) => serde_json::Value::Object(
335 map.iter()
336 .filter(|(k, _)| !GEMINI_UNSUPPORTED_SCHEMA_KEYS.contains(&k.as_str()))
337 .map(|(k, v)| (k.clone(), sanitize_schema_for_gemini(v)))
338 .collect(),
339 ),
340 serde_json::Value::Array(arr) => {
341 serde_json::Value::Array(arr.iter().map(sanitize_schema_for_gemini).collect())
342 }
343 other => other.clone(),
344 }
345}
346
347fn message_parts(msg: &Message) -> Vec<serde_json::Value> {
349 let mut parts = Vec::new();
350 for c in &msg.content {
351 match c {
352 Content::Text(t) => parts.push(serde_json::json!({ "text": t })),
353 Content::ToolUse(tc) => {
354 let args: serde_json::Value =
355 serde_json::from_str(&tc.args_json).unwrap_or(serde_json::Value::Null);
356 let mut part = serde_json::json!({
357 "functionCall": { "name": tc.name, "args": args }
358 });
359 if let Some(sig) = &tc.signature {
364 part["thoughtSignature"] = serde_json::json!(sig);
365 }
366 parts.push(part);
367 }
368 Content::ToolResult(tr) => {
369 let result: serde_json::Value =
370 serde_json::from_str(&tr.result_json).unwrap_or(serde_json::Value::Null);
371 parts.push(serde_json::json!({
372 "functionResponse": { "name": tr.tool_call_id, "response": { "result": result } }
373 }));
374 }
375 _ => {}
377 }
378 }
379 parts
380}
381
382fn map_response(
390 resp: GenerateContentResponse,
391 tool_seq: &mut usize,
392) -> Vec<Result<Chunk, VertexError>> {
393 let mut chunks = Vec::new();
394 let candidate = resp.candidates.into_iter().next();
395
396 let mut text = String::new();
397 let mut tool_calls = Vec::new();
398 let mut finish = None;
399 if let Some(c) = candidate {
400 finish = c.finish_reason;
401 if let Some(content) = c.content {
402 for part in content.parts {
403 if let Some(t) = part.text {
404 text.push_str(&t);
405 }
406 if let Some(fc) = part.function_call {
407 tool_calls.push((fc, part.thought_signature));
410 }
411 }
412 }
413 }
414
415 if !text.is_empty() {
416 chunks.push(Ok(Chunk::text_delta(text)));
417 }
418 for (fc, signature) in &tool_calls {
419 let id = format!("call-{}", *tool_seq);
420 *tool_seq += 1;
421 chunks.push(Ok(Chunk::tool_call_start_signed(
422 id.clone(),
423 fc.name.clone(),
424 signature.clone(),
425 )));
426 chunks.push(Ok(Chunk::tool_call_args_delta(
427 id.clone(),
428 fc.args.to_string(),
429 )));
430 chunks.push(Ok(Chunk::tool_call_end(id)));
431 }
432 if let Some(u) = resp.usage_metadata {
433 chunks.push(Ok(Chunk::Usage(Usage {
434 input_tokens: u.prompt_token_count,
435 output_tokens: u.candidates_token_count,
436 })));
437 }
438 if finish.is_some() || !tool_calls.is_empty() {
444 let mapped = map_finish_reason(finish.as_deref());
445 let stop = if !tool_calls.is_empty() && matches!(mapped, StopReason::EndTurn) {
451 StopReason::ToolUse
452 } else {
453 mapped
454 };
455 chunks.push(Ok(Chunk::Stop(stop)));
456 }
457 chunks
458}
459
460fn map_finish_reason(reason: Option<&str>) -> StopReason {
461 match reason {
462 Some("MAX_TOKENS") => StopReason::MaxTokens,
463 Some("STOP_SEQUENCE") => StopReason::StopSequence,
464 Some("SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII") => {
465 StopReason::Refusal
466 }
467 _ => StopReason::EndTurn,
468 }
469}
470
471#[derive(Deserialize)]
474struct GenerateContentResponse {
475 #[serde(default)]
476 candidates: Vec<Candidate>,
477 #[serde(default, rename = "usageMetadata")]
478 usage_metadata: Option<UsageMetadata>,
479}
480
481#[derive(Deserialize)]
482struct Candidate {
483 #[serde(default)]
484 content: Option<RespContent>,
485 #[serde(default, rename = "finishReason")]
486 finish_reason: Option<String>,
487}
488
489#[derive(Deserialize)]
490struct RespContent {
491 #[serde(default)]
492 parts: Vec<Part>,
493}
494
495#[derive(Deserialize)]
496struct Part {
497 #[serde(default)]
498 text: Option<String>,
499 #[serde(default, rename = "functionCall")]
500 function_call: Option<FunctionCall>,
501 #[serde(default, rename = "thoughtSignature")]
502 thought_signature: Option<String>,
503}
504
505#[derive(Deserialize)]
506struct FunctionCall {
507 name: String,
508 #[serde(default)]
509 args: serde_json::Value,
510}
511
512#[derive(Deserialize)]
513struct UsageMetadata {
514 #[serde(default, rename = "promptTokenCount")]
515 prompt_token_count: u64,
516 #[serde(default, rename = "candidatesTokenCount")]
517 candidates_token_count: u64,
518}
519
// Unit tests for the pure translation helpers (`map_response`, `build_request`); no
// network or credentials are involved.
#[cfg(test)]
mod tests {
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use super::*;

    // A text part with `finishReason: STOP` yields the text delta first and `EndTurn` last.
    #[test]
    fn maps_text_and_usage_and_stop() {
        let resp: GenerateContentResponse = serde_json::from_value(serde_json::json!({
            "candidates": [{
                "content": { "role": "model", "parts": [{ "text": "parity" }] },
                "finishReason": "STOP"
            }],
            "usageMetadata": { "promptTokenCount": 5, "candidatesTokenCount": 2 }
        }))
        .unwrap();
        let chunks: Vec<_> = map_response(resp, &mut 0)
            .into_iter()
            .map(Result::unwrap)
            .collect();
        assert_eq!(chunks[0], Chunk::text_delta("parity"));
        assert!(matches!(
            chunks[chunks.len() - 1],
            Chunk::Stop(StopReason::EndTurn)
        ));
    }

    // A function call emits a ToolCallStart and upgrades a STOP finish to `ToolUse`.
    #[test]
    fn maps_function_call_to_tool_chunks() {
        let resp: GenerateContentResponse = serde_json::from_value(serde_json::json!({
            "candidates": [{
                "content": { "parts": [{ "functionCall": { "name": "search", "args": { "q": "rust" } } }] },
                "finishReason": "STOP"
            }]
        }))
        .unwrap();
        let chunks: Vec<_> = map_response(resp, &mut 0)
            .into_iter()
            .map(Result::unwrap)
            .collect();
        assert!(
            chunks
                .iter()
                .any(|c| matches!(c, Chunk::ToolCallStart { name, .. } if name == "search"))
        );
        assert!(matches!(
            chunks[chunks.len() - 1],
            Chunk::Stop(StopReason::ToolUse)
        ));
    }

    // System messages go to `systemInstruction`; user messages become `contents` turns.
    #[test]
    fn build_request_maps_roles_and_system() {
        let mut req = CompletionRequest::new("m");
        req.system = None;
        req.messages = vec![Message::system("be terse"), Message::user("hi")];
        let body = build_request(&req);
        assert_eq!(body["systemInstruction"]["parts"][0]["text"], "be terse");
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "hi");
    }

    // Unsupported keywords are stripped at every depth; supported ones are kept.
    #[test]
    fn build_request_strips_gemini_incompatible_tool_schema_keys() {
        use polyc_llm::ToolSpec;
        let mut req = CompletionRequest::new("m");
        req.tools = vec![ToolSpec {
            name: "list_recent".to_owned(),
            description: "recent".to_owned(),
            schema_json: serde_json::json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "type": "object",
                "additionalProperties": false,
                "properties": {
                    "limit": { "type": "integer", "exclusiveMinimum": 0, "maximum": 100 }
                }
            }),
            title: None,
            needs_approval: false,
        }];
        let params = &build_request(&req)["tools"][0]["functionDeclarations"][0]["parameters"];
        assert!(params.get("$schema").is_none());
        assert!(params.get("additionalProperties").is_none());
        assert!(
            params["properties"]["limit"]
                .get("exclusiveMinimum")
                .is_none()
        );
        assert_eq!(params["type"], "object");
        assert_eq!(params["properties"]["limit"]["type"], "integer");
        assert_eq!(params["properties"]["limit"]["maximum"], 100);
    }

    // `web_search` adds a `googleSearch` tool after any function declarations; off by default.
    #[test]
    fn web_search_adds_google_search_grounding_tool() {
        use polyc_llm::ToolSpec;
        let off = CompletionRequest::new("m");
        assert!(build_request(&off).get("tools").is_none());

        let mut grounded = CompletionRequest::new("m");
        grounded.web_search = true;
        let tools = build_request(&grounded)["tools"].clone();
        assert_eq!(tools, serde_json::json!([{ "googleSearch": {} }]));

        grounded.tools = vec![ToolSpec {
            name: "list_recent".to_owned(),
            description: "recent".to_owned(),
            schema_json: serde_json::json!({ "type": "object" }),
            title: None,
            needs_approval: false,
        }];
        let tools = build_request(&grounded)["tools"].clone();
        assert_eq!(tools.as_array().map(Vec::len), Some(2));
        assert!(tools[0].get("functionDeclarations").is_some());
        assert_eq!(tools[1], serde_json::json!({ "googleSearch": {} }));
    }
}