1use super::util;
11use super::{
12 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
13 Role, StreamChunk, ToolDefinition, Usage,
14};
15use anyhow::{Context, Result};
16use async_trait::async_trait;
17use futures::StreamExt;
18use jsonwebtoken::{Algorithm, EncodingKey, Header};
19use reqwest::Client;
20use serde::{Deserialize, Serialize};
21use serde_json::{Value, json};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::RwLock;
25
// HTTP client tuning: overall per-request deadline and TCP connect deadline.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
// Maximum attempts for transient failures (timeouts, 503s) before giving up.
const MAX_RETRIES: u32 = 3;

// Vertex AI host and region used to build the OpenAI-compatible base URL.
const VERTEX_ENDPOINT: &str = "aiplatform.googleapis.com";
const VERTEX_REGION: &str = "global";
// Default OAuth2 token endpoint; overridden by the SA key's token_uri when present.
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
// OAuth scope requested for Vertex AI access.
const VERTEX_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
34
/// An OAuth2 access token held in memory together with its expiry instant.
struct CachedToken {
    // Bearer token value sent in the Authorization header.
    token: String,
    // Monotonic deadline after which the token must be refreshed.
    expires_at: std::time::Instant,
}
40
/// Subset of a GCP service-account JSON key used for the JWT-bearer token flow.
#[derive(Debug, Clone, Deserialize)]
struct ServiceAccountKey {
    // Service-account email; used as the JWT issuer (`iss`).
    client_email: String,
    // PEM-encoded RSA private key used to sign the JWT assertion.
    private_key: String,
    // Optional token endpoint override; defaults to GOOGLE_TOKEN_URL.
    token_uri: Option<String>,
    // Project id embedded in the key; fallback when no explicit project is given.
    project_id: Option<String>,
}
49
/// Claims for the self-signed JWT exchanged for an access token
/// (Google OAuth2 JWT-bearer flow).
#[derive(Serialize)]
struct JwtClaims {
    // Issuer: the service-account email.
    iss: String,
    // Requested OAuth scope(s).
    scope: String,
    // Audience: the token endpoint URL the assertion is sent to.
    aud: String,
    // Issued-at, seconds since the Unix epoch.
    iat: u64,
    // Expiry, seconds since the Unix epoch.
    exp: u64,
}
59
/// Provider for GLM models served through Vertex AI's OpenAI-compatible
/// endpoint, authenticated with a GCP service account.
pub struct VertexGlmProvider {
    client: Client,
    project_id: String,
    // Base URL of the OpenAI-compatible Vertex endpoint for this project.
    base_url: String,
    // Parsed service-account key (holds the private key material).
    sa_key: ServiceAccountKey,
    // Pre-parsed RSA signing key derived from `sa_key.private_key`.
    encoding_key: EncodingKey,
    // Most recently minted access token, guarded for concurrent refresh.
    cached_token: Arc<RwLock<Option<CachedToken>>>,
}
69
// Manual Debug impl that prints only non-sensitive fields; the private key,
// signing key, and cached token are deliberately omitted from the output.
impl std::fmt::Debug for VertexGlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VertexGlmProvider")
            .field("project_id", &self.project_id)
            .field("base_url", &self.base_url)
            .field("client_email", &self.sa_key.client_email)
            .finish()
    }
}
79
80impl VertexGlmProvider {
81 pub fn new(sa_json: &str, project_id: Option<String>) -> Result<Self> {
83 let sa_key: ServiceAccountKey =
84 serde_json::from_str(sa_json).context("Failed to parse service account JSON key")?;
85
86 let project_id = project_id
87 .or_else(|| sa_key.project_id.clone())
88 .ok_or_else(|| anyhow::anyhow!("No project_id found in SA key or Vault config"))?;
89
90 let encoding_key = EncodingKey::from_rsa_pem(sa_key.private_key.as_bytes())
91 .context("Failed to parse RSA private key from service account")?;
92
93 let base_url = format!(
94 "https://{}/v1/projects/{}/locations/{}/endpoints/openapi",
95 VERTEX_ENDPOINT, project_id, VERTEX_REGION
96 );
97
98 tracing::debug!(
99 provider = "vertex-glm",
100 project_id = %project_id,
101 client_email = %sa_key.client_email,
102 base_url = %base_url,
103 "Creating Vertex GLM provider with service account"
104 );
105
106 let client = Client::builder()
107 .connect_timeout(CONNECT_TIMEOUT)
108 .timeout(REQUEST_TIMEOUT)
109 .build()
110 .context("Failed to build HTTP client")?;
111
112 Ok(Self {
113 client,
114 project_id,
115 base_url,
116 sa_key,
117 encoding_key,
118 cached_token: Arc::new(RwLock::new(None)),
119 })
120 }
121
    /// Returns a valid GCP OAuth2 access token, refreshing it via the
    /// JWT-bearer flow when the cached one is missing or close to expiry.
    async fn get_access_token(&self) -> Result<String> {
        // Fast path: reuse the cached token while it has more than 5 minutes
        // of life left. The read guard is dropped at the end of this scope so
        // the refresh path below can take the write lock.
        {
            let cache = self.cached_token.read().await;
            if let Some(ref cached) = *cache
                && cached.expires_at
                    > std::time::Instant::now() + std::time::Duration::from_secs(300)
            {
                return Ok(cached.token.clone());
            }
        }

        // Build and sign a JWT assertion (RS256) for the token exchange.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .context("System time error")?
            .as_secs();

        let token_uri = self.sa_key.token_uri.as_deref().unwrap_or(GOOGLE_TOKEN_URL);

        let claims = JwtClaims {
            iss: self.sa_key.client_email.clone(),
            scope: VERTEX_SCOPE.to_string(),
            aud: token_uri.to_string(),
            iat: now,
            exp: now + 3600, // assertion valid for one hour
        };

        let header = Header::new(Algorithm::RS256);
        let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
            .context("Failed to sign JWT assertion")?;

        // Exchange the assertion for an access token (urlencoded form body).
        let form_body = format!(
            "grant_type={}&assertion={}",
            urlencoding::encode("urn:ietf:params:oauth:grant-type:jwt-bearer"),
            urlencoding::encode(&assertion),
        );
        let response = self
            .client
            .post(token_uri)
            .header("Content-Type", "application/x-www-form-urlencoded")
            .body(form_body)
            .send()
            .await
            .context("Failed to exchange JWT for access token")?;

        // Read the body before checking status so failures include it.
        let status = response.status();
        let body = response
            .text()
            .await
            .context("Failed to read token response")?;

        if !status.is_success() {
            anyhow::bail!("GCP token exchange failed: {status} {body}");
        }

        // Only the fields we need from the token endpoint's JSON response.
        #[derive(Deserialize)]
        struct TokenResponse {
            access_token: String,
            #[serde(default)]
            expires_in: Option<u64>,
        }

        let token_resp: TokenResponse =
            serde_json::from_str(&body).context("Failed to parse GCP token response")?;

        // Default to a one-hour lifetime when the endpoint omits expires_in.
        let expires_in = token_resp.expires_in.unwrap_or(3600);

        // Publish the fresh token. Concurrent refreshes may race here and
        // simply overwrite each other, which is harmless.
        {
            let mut cache = self.cached_token.write().await;
            *cache = Some(CachedToken {
                token: token_resp.access_token.clone(),
                expires_at: std::time::Instant::now() + std::time::Duration::from_secs(expires_in),
            });
        }

        tracing::debug!(
            client_email = %self.sa_key.client_email,
            expires_in_secs = expires_in,
            "Refreshed GCP access token via service account JWT"
        );

        Ok(token_resp.access_token)
    }
209
210 fn convert_messages(messages: &[Message]) -> Vec<Value> {
211 messages
212 .iter()
213 .map(|msg| {
214 let role = match msg.role {
215 Role::System => "system",
216 Role::User => "user",
217 Role::Assistant => "assistant",
218 Role::Tool => "tool",
219 };
220
221 match msg.role {
222 Role::Tool => {
223 if let Some(ContentPart::ToolResult {
224 tool_call_id,
225 content,
226 }) = msg.content.first()
227 {
228 json!({
229 "role": "tool",
230 "tool_call_id": tool_call_id,
231 "content": content
232 })
233 } else {
234 json!({"role": role, "content": ""})
235 }
236 }
237 Role::Assistant => {
238 let text: String = msg
239 .content
240 .iter()
241 .filter_map(|p| match p {
242 ContentPart::Text { text } => Some(text.clone()),
243 _ => None,
244 })
245 .collect::<Vec<_>>()
246 .join("");
247
248 let tool_calls: Vec<Value> = msg
249 .content
250 .iter()
251 .filter_map(|p| match p {
252 ContentPart::ToolCall {
253 id,
254 name,
255 arguments,
256 ..
257 } => Some(json!({
258 "id": id,
259 "type": "function",
260 "function": {
261 "name": name,
262 "arguments": arguments
263 }
264 })),
265 _ => None,
266 })
267 .collect();
268
269 let mut msg_json = json!({
270 "role": "assistant",
271 "content": if text.is_empty() { Value::Null } else { json!(text) },
272 });
273
274 if !tool_calls.is_empty() {
275 msg_json["tool_calls"] = json!(tool_calls);
276 }
277 msg_json
278 }
279 _ => {
280 let text: String = msg
281 .content
282 .iter()
283 .filter_map(|p| match p {
284 ContentPart::Text { text } => Some(text.clone()),
285 _ => None,
286 })
287 .collect::<Vec<_>>()
288 .join("\n");
289
290 json!({"role": role, "content": text})
291 }
292 }
293 })
294 .collect()
295 }
296
297 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
298 tools
299 .iter()
300 .map(|t| {
301 json!({
302 "type": "function",
303 "function": {
304 "name": t.name,
305 "description": t.description,
306 "parameters": t.parameters
307 }
308 })
309 })
310 .collect()
311 }
312}
313
/// Non-streaming chat completion response (OpenAI-compatible shape).
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    choices: Vec<Choice>,
    #[serde(default)]
    usage: Option<ApiUsage>,
}

/// A single completion candidate.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload: optional text content and/or tool calls.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}

/// A fully-formed tool invocation in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ToolCall {
    id: String,
    function: FunctionCall,
}

/// Function name plus its JSON-encoded arguments string.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    arguments: String,
}

/// Token accounting as reported by the API; counts default to zero when absent.
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<VertexGlmPromptTokenDetails>,
}

/// Breakdown of prompt tokens; only the cached count is consumed here.
#[derive(Debug, Deserialize, Default)]
struct VertexGlmPromptTokenDetails {
    #[serde(default)]
    cached_tokens: usize,
}
367
368impl ApiUsage {
369 fn cached(&self) -> usize {
370 self.prompt_tokens_details
371 .as_ref()
372 .map(|d| d.cached_tokens)
373 .unwrap_or(0)
374 }
375}
376
/// Structured error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}

#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
    // The JSON field is named "type", a Rust keyword, hence the rename.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
388
/// One SSE event payload from the streaming endpoint.
#[derive(Debug, Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

/// A streamed choice: an incremental delta plus an optional finish reason.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental content: a text fragment and/or partial tool calls.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}

/// A partial tool call; id/function fields may arrive across several events.
#[derive(Debug, Deserialize)]
struct StreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<StreamFunction>,
}

/// Partial function data: the name and/or an arguments fragment.
#[derive(Debug, Deserialize)]
struct StreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
424
425#[async_trait]
426impl Provider for VertexGlmProvider {
    /// Returns this provider's identifier.
    fn name(&self) -> &str {
        "vertex-glm"
    }
430
431 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
432 Ok(vec![
433 ModelInfo {
434 id: "zai-org/glm-5-maas".to_string(),
435 name: "GLM-5 (Vertex AI MaaS)".to_string(),
436 provider: "vertex-glm".to_string(),
437 context_window: 200_000,
438 max_output_tokens: Some(128_000),
439 supports_vision: false,
440 supports_tools: true,
441 supports_streaming: true,
442 input_cost_per_million: Some(1.0),
443 output_cost_per_million: Some(3.2),
444 },
445 ModelInfo {
446 id: "glm-5".to_string(),
447 name: "GLM-5 (Vertex AI)".to_string(),
448 provider: "vertex-glm".to_string(),
449 context_window: 200_000,
450 max_output_tokens: Some(128_000),
451 supports_vision: false,
452 supports_tools: true,
453 supports_streaming: true,
454 input_cost_per_million: Some(1.0),
455 output_cost_per_million: Some(3.2),
456 },
457 ])
458 }
459
    /// Sends a non-streaming chat completion request.
    ///
    /// Transient failures (request timeouts and HTTP 503) are retried up to
    /// `MAX_RETRIES` times with exponential backoff; other errors fail fast.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let mut access_token = self.get_access_token().await?;

        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        // Vertex expects fully-qualified ids like "zai-org/glm-5-maas";
        // bare ids such as "glm-5" are expanded into that form.
        let model = if request.model.starts_with("zai-org/") {
            request.model.clone()
        } else {
            format!(
                "zai-org/{}-maas",
                request.model.trim_start_matches("zai-org/")
            )
        };

        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "stream": false,
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Vertex GLM request");

        let url = format!("{}/chat/completions", self.base_url);
        let mut last_err = None;

        for attempt in 0..MAX_RETRIES {
            if attempt > 0 {
                // Exponential backoff: 1s, 2s, 4s, ...
                let backoff = Duration::from_millis(1000 * 2u64.pow(attempt - 1));
                tracing::warn!(
                    attempt,
                    backoff_ms = backoff.as_millis() as u64,
                    "Vertex GLM retrying after transient error"
                );
                tokio::time::sleep(backoff).await;
                // Re-fetch in case the token expired during the backoff.
                access_token = self.get_access_token().await?;
            }

            let send_result = self
                .client
                .post(&url)
                .bearer_auth(&access_token)
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await;

            let response = match send_result {
                Ok(r) => r,
                // Timeouts are retried while attempts remain.
                Err(e) if e.is_timeout() && attempt + 1 < MAX_RETRIES => {
                    tracing::warn!(error = %e, "Vertex GLM request timed out");
                    last_err = Some(format!("Request timed out: {e}"));
                    continue;
                }
                Err(e) => anyhow::bail!("Failed to send request to Vertex AI GLM: {e}"),
            };

            let status = response.status();
            let text = response
                .text()
                .await
                .context("Failed to read Vertex AI GLM response")?;

            // 503s are treated as transient while attempts remain.
            if status == reqwest::StatusCode::SERVICE_UNAVAILABLE && attempt + 1 < MAX_RETRIES {
                tracing::warn!(status = %status, body = %text, "Vertex GLM service unavailable, retrying");
                last_err = Some(format!("503 Service Unavailable: {text}"));
                continue;
            }

            if !status.is_success() {
                // Prefer the structured error payload when the body parses.
                if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
                    anyhow::bail!(
                        "Vertex AI GLM API error: {} ({:?})",
                        err.error.message,
                        err.error.error_type
                    );
                }
                anyhow::bail!("Vertex AI GLM API error: {} {}", status, text);
            }

            let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
                "Failed to parse Vertex AI GLM response: {}",
                util::truncate_bytes_safe(&text, 200)
            ))?;

            let choice = completion
                .choices
                .first()
                .ok_or_else(|| anyhow::anyhow!("No choices in Vertex AI GLM response"))?;

            let mut content = Vec::new();
            let mut has_tool_calls = false;

            // Empty text is dropped rather than emitted as an empty part.
            if let Some(text) = &choice.message.content
                && !text.is_empty()
            {
                content.push(ContentPart::Text { text: text.clone() });
            }

            if let Some(tool_calls) = &choice.message.tool_calls {
                has_tool_calls = !tool_calls.is_empty();
                for tc in tool_calls {
                    content.push(ContentPart::ToolCall {
                        id: tc.id.clone(),
                        name: tc.function.name.clone(),
                        arguments: tc.function.arguments.clone(),
                        thought_signature: None,
                    });
                }
            }

            // The presence of tool calls overrides whatever finish_reason the
            // API reported; unknown reasons map to Stop.
            let finish_reason = if has_tool_calls {
                FinishReason::ToolCalls
            } else {
                match choice.finish_reason.as_deref() {
                    Some("stop") => FinishReason::Stop,
                    Some("length") => FinishReason::Length,
                    Some("tool_calls") => FinishReason::ToolCalls,
                    Some("content_filter") => FinishReason::ContentFilter,
                    _ => FinishReason::Stop,
                }
            };

            return Ok(CompletionResponse {
                message: Message {
                    role: Role::Assistant,
                    content,
                },
                usage: Usage {
                    // Treats prompt_tokens as inclusive of cached tokens and
                    // reports only the non-cached remainder here.
                    prompt_tokens: completion
                        .usage
                        .as_ref()
                        .map(|u| u.prompt_tokens.saturating_sub(u.cached()))
                        .unwrap_or(0),
                    completion_tokens: completion
                        .usage
                        .as_ref()
                        .map(|u| u.completion_tokens)
                        .unwrap_or(0),
                    total_tokens: completion
                        .usage
                        .as_ref()
                        .map(|u| u.total_tokens)
                        .unwrap_or(0),
                    // Cache reads are reported only when non-zero.
                    cache_read_tokens: completion
                        .usage
                        .as_ref()
                        .map(ApiUsage::cached)
                        .filter(|&n| n > 0),
                    cache_write_tokens: None,
                },
                finish_reason,
            });
        }

        anyhow::bail!(
            "Vertex AI GLM request failed after {MAX_RETRIES} attempts: {}",
            last_err.unwrap_or_default()
        )
    }
633
634 async fn complete_stream(
635 &self,
636 request: CompletionRequest,
637 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
638 let mut access_token = self.get_access_token().await?;
639
640 let messages = Self::convert_messages(&request.messages);
641 let tools = Self::convert_tools(&request.tools);
642
643 let model = if request.model.starts_with("zai-org") {
645 request.model.clone()
646 } else {
647 format!(
648 "zai-org/{}-maas",
649 request.model.trim_start_matches("zai-org/")
650 )
651 };
652
653 let temperature = request.temperature.unwrap_or(1.0);
654
655 let mut body = json!({
656 "model": model,
657 "messages": messages,
658 "temperature": temperature,
659 "stream": true,
660 });
661
662 if !tools.is_empty() {
663 body["tools"] = json!(tools);
664 }
665 if let Some(max) = request.max_tokens {
666 body["max_tokens"] = json!(max);
667 }
668
669 tracing::debug!(model = %request.model, "Vertex GLM streaming request");
670
671 let url = format!("{}/chat/completions", self.base_url);
672 let mut last_err = String::new();
673
674 for attempt in 0..MAX_RETRIES {
675 if attempt > 0 {
676 let backoff = Duration::from_millis(1000 * 2u64.pow(attempt - 1));
677 tracing::warn!(
678 attempt,
679 backoff_ms = backoff.as_millis() as u64,
680 "Vertex GLM streaming retrying after transient error"
681 );
682 tokio::time::sleep(backoff).await;
683 access_token = self.get_access_token().await?;
684 }
685
686 let send_result = self
687 .client
688 .post(&url)
689 .bearer_auth(&access_token)
690 .header("Content-Type", "application/json")
691 .json(&body)
692 .send()
693 .await;
694
695 let response = match send_result {
696 Ok(r) => r,
697 Err(e) if e.is_timeout() && attempt + 1 < MAX_RETRIES => {
698 tracing::warn!(error = %e, "Vertex GLM streaming request timed out");
699 last_err = format!("Request timed out: {e}");
700 continue;
701 }
702 Err(e) => anyhow::bail!("Failed to send streaming request to Vertex AI GLM: {e}"),
703 };
704
705 if response.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE
706 && attempt + 1 < MAX_RETRIES
707 {
708 let text = response.text().await.unwrap_or_default();
709 tracing::warn!(body = %text, "Vertex GLM streaming service unavailable, retrying");
710 last_err = format!("503 Service Unavailable: {text}");
711 continue;
712 }
713
714 if !response.status().is_success() {
715 let status = response.status();
716 let text = response.text().await.unwrap_or_default();
717 if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
718 anyhow::bail!(
719 "Vertex AI GLM API error: {} ({:?})",
720 err.error.message,
721 err.error.error_type
722 );
723 }
724 anyhow::bail!("Vertex AI GLM streaming error: {} {}", status, text);
725 }
726
727 let stream = response.bytes_stream();
728 let mut buffer = String::new();
729
730 return Ok(stream
731 .flat_map(move |chunk_result| {
732 let mut chunks: Vec<StreamChunk> = Vec::new();
733 match chunk_result {
734 Ok(bytes) => {
735 let text = String::from_utf8_lossy(&bytes);
736 buffer.push_str(&text);
737
738 let mut text_buf = String::new();
739
740 while let Some(line_end) = buffer.find('\n') {
741 let line = buffer[..line_end].trim().to_string();
742 buffer = buffer[line_end + 1..].to_string();
743
744 if line == "data: [DONE]" {
745 if !text_buf.is_empty() {
746 chunks
747 .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
748 }
749 chunks.push(StreamChunk::Done { usage: None });
750 continue;
751 }
752 if let Some(data) = line.strip_prefix("data: ")
753 && let Ok(parsed) = serde_json::from_str::<StreamResponse>(data)
754 && let Some(choice) = parsed.choices.first()
755 {
756 if let Some(ref content) = choice.delta.content {
757 text_buf.push_str(content);
758 }
759 if let Some(ref tool_calls) = choice.delta.tool_calls {
760 if !text_buf.is_empty() {
761 chunks.push(StreamChunk::Text(std::mem::take(
762 &mut text_buf,
763 )));
764 }
765 for tc in tool_calls {
766 if let Some(ref func) = tc.function {
767 let id = tc.id.clone().unwrap_or_default();
768 if let Some(ref name) = func.name {
769 chunks.push(StreamChunk::ToolCallStart {
770 id: id.clone(),
771 name: name.clone(),
772 });
773 }
774 if let Some(ref args) = func.arguments {
775 chunks.push(StreamChunk::ToolCallDelta {
776 id: id.clone(),
777 arguments_delta: args.clone(),
778 });
779 }
780 }
781 }
782 }
783 if let Some(ref reason) = choice.finish_reason {
784 if !text_buf.is_empty() {
785 chunks.push(StreamChunk::Text(std::mem::take(
786 &mut text_buf,
787 )));
788 }
789 if reason == "tool_calls"
790 && let Some(tc) = choice
791 .delta
792 .tool_calls
793 .as_ref()
794 .and_then(|t| t.last())
795 && let Some(id) = &tc.id
796 {
797 chunks
798 .push(StreamChunk::ToolCallEnd { id: id.clone() });
799 }
800 }
801 }
802 }
803 if !text_buf.is_empty() {
804 chunks.push(StreamChunk::Text(text_buf));
805 }
806 }
807 Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
808 }
809 futures::stream::iter(chunks)
810 })
811 .boxed());
812 }
813
814 anyhow::bail!("Vertex AI GLM streaming failed after {MAX_RETRIES} attempts: {last_err}")
815 }
816}
817
#[cfg(test)]
mod tests {
    use super::*;

    // An empty JSON object lacks the required SA fields, so parsing must fail.
    #[test]
    fn test_rejects_invalid_sa_json() {
        let result = VertexGlmProvider::new("{}", None);
        assert!(result.is_err());
    }

    // A key with no project_id (and none passed explicitly) must be rejected
    // at project-id resolution, before the RSA key is ever parsed.
    #[test]
    fn test_rejects_missing_project_id() {
        let sa_json = json!({
            "type": "service_account",
            "client_email": "test@test.iam.gserviceaccount.com",
            "private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIBogIBAAJBALRiMLAHudeSA/x3hB2f+2NRkJlS\n-----END RSA PRIVATE KEY-----\n",
            "token_uri": "https://oauth2.googleapis.com/token"
        });
        let result = VertexGlmProvider::new(&sa_json.to_string(), None);
        assert!(result.is_err());
    }
}