Skip to main content

codetether_agent/provider/
vertex_glm.rs

1//! Vertex AI GLM provider implementation (MaaS endpoint)
2//!
3//! GLM-5 via Google Cloud Vertex AI Managed API Service.
4//! Uses service account JWT auth to obtain OAuth2 access tokens.
5//! The service account JSON key is stored in Vault and used to sign JWTs
6//! that are exchanged for short-lived access tokens (cached ~55 min).
7//!
8//! Reference: https://console.cloud.google.com/vertex-ai/publishers/zai/model-garden/glm-5
9
10use super::{
11    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
12    Role, StreamChunk, ToolDefinition, Usage,
13};
14use anyhow::{Context, Result};
15use async_trait::async_trait;
16use futures::StreamExt;
17use jsonwebtoken::{Algorithm, EncodingKey, Header};
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use serde_json::{Value, json};
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::RwLock;
24
// --- HTTP client behaviour ---------------------------------------------------

/// Default overall timeout configured on the shared HTTP client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Timeout for establishing a connection.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Total number of attempts (initial try included) for transient failures.
const MAX_RETRIES: u32 = 3;

// --- Vertex AI / Google OAuth endpoints --------------------------------------

/// Host serving the Vertex AI API.
const VERTEX_ENDPOINT: &str = "aiplatform.googleapis.com";
/// Vertex location segment used in the endpoint path.
const VERTEX_REGION: &str = "global";
/// OAuth2 token endpoint used when the service account key has no `token_uri`.
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// OAuth2 scope requested when exchanging the signed JWT.
const VERTEX_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
33
/// OAuth2 access token remembered between requests, paired with the
/// monotonic instant at which it expires.
struct CachedToken {
    /// Bearer token returned by the token endpoint.
    token: String,
    /// Expiry deadline, derived from the token response's `expires_in`.
    expires_at: std::time::Instant,
}
39
40/// GCP service account key (parsed from JSON)
41#[derive(Debug, Clone, Deserialize)]
42struct ServiceAccountKey {
43    client_email: String,
44    private_key: String,
45    token_uri: Option<String>,
46    project_id: Option<String>,
47}
48
49/// JWT claims for GCP service account auth
50#[derive(Serialize)]
51struct JwtClaims {
52    iss: String,
53    scope: String,
54    aud: String,
55    iat: u64,
56    exp: u64,
57}
58
59pub struct VertexGlmProvider {
60    client: Client,
61    project_id: String,
62    base_url: String,
63    sa_key: ServiceAccountKey,
64    encoding_key: EncodingKey,
65    /// Cached OAuth2 access token (refreshes ~5 min before expiry)
66    cached_token: Arc<RwLock<Option<CachedToken>>>,
67}
68
69impl std::fmt::Debug for VertexGlmProvider {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("VertexGlmProvider")
72            .field("project_id", &self.project_id)
73            .field("base_url", &self.base_url)
74            .field("client_email", &self.sa_key.client_email)
75            .finish()
76    }
77}
78
79impl VertexGlmProvider {
80    /// Create from a service account JSON key string
81    pub fn new(sa_json: &str, project_id: Option<String>) -> Result<Self> {
82        let sa_key: ServiceAccountKey =
83            serde_json::from_str(sa_json).context("Failed to parse service account JSON key")?;
84
85        let project_id = project_id
86            .or_else(|| sa_key.project_id.clone())
87            .ok_or_else(|| anyhow::anyhow!("No project_id found in SA key or Vault config"))?;
88
89        let encoding_key = EncodingKey::from_rsa_pem(sa_key.private_key.as_bytes())
90            .context("Failed to parse RSA private key from service account")?;
91
92        let base_url = format!(
93            "https://{}/v1/projects/{}/locations/{}/endpoints/openapi",
94            VERTEX_ENDPOINT, project_id, VERTEX_REGION
95        );
96
97        tracing::debug!(
98            provider = "vertex-glm",
99            project_id = %project_id,
100            client_email = %sa_key.client_email,
101            base_url = %base_url,
102            "Creating Vertex GLM provider with service account"
103        );
104
105        let client = Client::builder()
106            .connect_timeout(CONNECT_TIMEOUT)
107            .timeout(REQUEST_TIMEOUT)
108            .build()
109            .context("Failed to build HTTP client")?;
110
111        Ok(Self {
112            client,
113            project_id,
114            base_url,
115            sa_key,
116            encoding_key,
117            cached_token: Arc::new(RwLock::new(None)),
118        })
119    }
120
121    /// Get a valid OAuth2 access token, refreshing if needed
122    async fn get_access_token(&self) -> Result<String> {
123        // Check cache — refresh 5 minutes before expiration
124        {
125            let cache = self.cached_token.read().await;
126            if let Some(ref cached) = *cache {
127                if cached.expires_at
128                    > std::time::Instant::now() + std::time::Duration::from_secs(300)
129                {
130                    return Ok(cached.token.clone());
131                }
132            }
133        }
134
135        // Sign a JWT assertion
136        let now = std::time::SystemTime::now()
137            .duration_since(std::time::UNIX_EPOCH)
138            .context("System time error")?
139            .as_secs();
140
141        let token_uri = self.sa_key.token_uri.as_deref().unwrap_or(GOOGLE_TOKEN_URL);
142
143        let claims = JwtClaims {
144            iss: self.sa_key.client_email.clone(),
145            scope: VERTEX_SCOPE.to_string(),
146            aud: token_uri.to_string(),
147            iat: now,
148            exp: now + 3600,
149        };
150
151        let header = Header::new(Algorithm::RS256);
152        let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
153            .context("Failed to sign JWT assertion")?;
154
155        // Exchange JWT for access token
156        let form_body = format!(
157            "grant_type={}&assertion={}",
158            urlencoding::encode("urn:ietf:params:oauth:grant-type:jwt-bearer"),
159            urlencoding::encode(&assertion),
160        );
161        let response = self
162            .client
163            .post(token_uri)
164            .header("Content-Type", "application/x-www-form-urlencoded")
165            .body(form_body)
166            .send()
167            .await
168            .context("Failed to exchange JWT for access token")?;
169
170        let status = response.status();
171        let body = response
172            .text()
173            .await
174            .context("Failed to read token response")?;
175
176        if !status.is_success() {
177            anyhow::bail!("GCP token exchange failed: {status} {body}");
178        }
179
180        #[derive(Deserialize)]
181        struct TokenResponse {
182            access_token: String,
183            #[serde(default)]
184            expires_in: Option<u64>,
185        }
186
187        let token_resp: TokenResponse =
188            serde_json::from_str(&body).context("Failed to parse GCP token response")?;
189
190        let expires_in = token_resp.expires_in.unwrap_or(3600);
191
192        // Cache it
193        {
194            let mut cache = self.cached_token.write().await;
195            *cache = Some(CachedToken {
196                token: token_resp.access_token.clone(),
197                expires_at: std::time::Instant::now() + std::time::Duration::from_secs(expires_in),
198            });
199        }
200
201        tracing::debug!(
202            client_email = %self.sa_key.client_email,
203            expires_in_secs = expires_in,
204            "Refreshed GCP access token via service account JWT"
205        );
206
207        Ok(token_resp.access_token)
208    }
209
210    fn convert_messages(messages: &[Message]) -> Vec<Value> {
211        messages
212            .iter()
213            .map(|msg| {
214                let role = match msg.role {
215                    Role::System => "system",
216                    Role::User => "user",
217                    Role::Assistant => "assistant",
218                    Role::Tool => "tool",
219                };
220
221                match msg.role {
222                    Role::Tool => {
223                        if let Some(ContentPart::ToolResult {
224                            tool_call_id,
225                            content,
226                        }) = msg.content.first()
227                        {
228                            json!({
229                                "role": "tool",
230                                "tool_call_id": tool_call_id,
231                                "content": content
232                            })
233                        } else {
234                            json!({"role": role, "content": ""})
235                        }
236                    }
237                    Role::Assistant => {
238                        let text: String = msg
239                            .content
240                            .iter()
241                            .filter_map(|p| match p {
242                                ContentPart::Text { text } => Some(text.clone()),
243                                _ => None,
244                            })
245                            .collect::<Vec<_>>()
246                            .join("");
247
248                        let tool_calls: Vec<Value> = msg
249                            .content
250                            .iter()
251                            .filter_map(|p| match p {
252                                ContentPart::ToolCall {
253                                    id,
254                                    name,
255                                    arguments,
256                                    ..
257                                } => Some(json!({
258                                    "id": id,
259                                    "type": "function",
260                                    "function": {
261                                        "name": name,
262                                        "arguments": arguments
263                                    }
264                                })),
265                                _ => None,
266                            })
267                            .collect();
268
269                        let mut msg_json = json!({
270                            "role": "assistant",
271                            "content": if text.is_empty() { Value::Null } else { json!(text) },
272                        });
273
274                        if !tool_calls.is_empty() {
275                            msg_json["tool_calls"] = json!(tool_calls);
276                        }
277                        msg_json
278                    }
279                    _ => {
280                        let text: String = msg
281                            .content
282                            .iter()
283                            .filter_map(|p| match p {
284                                ContentPart::Text { text } => Some(text.clone()),
285                                _ => None,
286                            })
287                            .collect::<Vec<_>>()
288                            .join("\n");
289
290                        json!({"role": role, "content": text})
291                    }
292                }
293            })
294            .collect()
295    }
296
297    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
298        tools
299            .iter()
300            .map(|t| {
301                json!({
302                    "type": "function",
303                    "function": {
304                        "name": t.name,
305                        "description": t.description,
306                        "parameters": t.parameters
307                    }
308                })
309            })
310            .collect()
311    }
312}
313
314// Response types
315#[derive(Debug, Deserialize)]
316struct ChatCompletion {
317    choices: Vec<Choice>,
318    #[serde(default)]
319    usage: Option<ApiUsage>,
320}
321
322#[derive(Debug, Deserialize)]
323struct Choice {
324    message: ChoiceMessage,
325    #[serde(default)]
326    finish_reason: Option<String>,
327}
328
329#[derive(Debug, Deserialize)]
330struct ChoiceMessage {
331    #[serde(default)]
332    content: Option<String>,
333    #[serde(default)]
334    tool_calls: Option<Vec<ToolCall>>,
335}
336
337#[derive(Debug, Deserialize)]
338struct ToolCall {
339    id: String,
340    function: FunctionCall,
341}
342
343#[derive(Debug, Deserialize)]
344struct FunctionCall {
345    name: String,
346    arguments: String,
347}
348
349#[derive(Debug, Deserialize)]
350struct ApiUsage {
351    #[serde(default)]
352    prompt_tokens: usize,
353    #[serde(default)]
354    completion_tokens: usize,
355    #[serde(default)]
356    total_tokens: usize,
357}
358
359#[derive(Debug, Deserialize)]
360struct ApiError {
361    error: ApiErrorDetail,
362}
363
364#[derive(Debug, Deserialize)]
365struct ApiErrorDetail {
366    message: String,
367    #[serde(default, rename = "type")]
368    error_type: Option<String>,
369}
370
371// SSE streaming types
372#[derive(Debug, Deserialize)]
373struct StreamResponse {
374    choices: Vec<StreamChoice>,
375}
376
377#[derive(Debug, Deserialize)]
378struct StreamChoice {
379    delta: StreamDelta,
380    #[serde(default)]
381    finish_reason: Option<String>,
382}
383
384#[derive(Debug, Deserialize)]
385struct StreamDelta {
386    #[serde(default)]
387    content: Option<String>,
388    #[serde(default)]
389    tool_calls: Option<Vec<StreamToolCall>>,
390}
391
392#[derive(Debug, Deserialize)]
393struct StreamToolCall {
394    #[serde(default)]
395    id: Option<String>,
396    function: Option<StreamFunction>,
397}
398
399#[derive(Debug, Deserialize)]
400struct StreamFunction {
401    #[serde(default)]
402    name: Option<String>,
403    #[serde(default)]
404    arguments: Option<String>,
405}
406
407#[async_trait]
408impl Provider for VertexGlmProvider {
409    fn name(&self) -> &str {
410        "vertex-glm"
411    }
412
413    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
414        Ok(vec![
415            ModelInfo {
416                id: "zai-org/glm-5-maas".to_string(),
417                name: "GLM-5 (Vertex AI MaaS)".to_string(),
418                provider: "vertex-glm".to_string(),
419                context_window: 200_000,
420                max_output_tokens: Some(128_000),
421                supports_vision: false,
422                supports_tools: true,
423                supports_streaming: true,
424                input_cost_per_million: Some(1.0),
425                output_cost_per_million: Some(3.2),
426            },
427            ModelInfo {
428                id: "glm-5".to_string(),
429                name: "GLM-5 (Vertex AI)".to_string(),
430                provider: "vertex-glm".to_string(),
431                context_window: 200_000,
432                max_output_tokens: Some(128_000),
433                supports_vision: false,
434                supports_tools: true,
435                supports_streaming: true,
436                input_cost_per_million: Some(1.0),
437                output_cost_per_million: Some(3.2),
438            },
439        ])
440    }
441
442    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
443        let mut access_token = self.get_access_token().await?;
444
445        let messages = Self::convert_messages(&request.messages);
446        let tools = Self::convert_tools(&request.tools);
447
448        // Resolve model ID to Vertex format
449        let model = if request.model.starts_with("zai-org/") {
450            request.model.clone()
451        } else {
452            format!(
453                "zai-org/{}-maas",
454                request.model.trim_start_matches("zai-org/")
455            )
456        };
457
458        // GLM-5 defaults to temperature 1.0 for best results
459        let temperature = request.temperature.unwrap_or(1.0);
460
461        let mut body = json!({
462            "model": model,
463            "messages": messages,
464            "temperature": temperature,
465            "stream": false,
466        });
467
468        if !tools.is_empty() {
469            body["tools"] = json!(tools);
470        }
471        if let Some(max) = request.max_tokens {
472            body["max_tokens"] = json!(max);
473        }
474
475        tracing::debug!(model = %request.model, "Vertex GLM request");
476
477        let url = format!("{}/chat/completions", self.base_url);
478        let mut last_err = None;
479
480        for attempt in 0..MAX_RETRIES {
481            if attempt > 0 {
482                let backoff = Duration::from_millis(1000 * 2u64.pow(attempt - 1));
483                tracing::warn!(
484                    attempt,
485                    backoff_ms = backoff.as_millis() as u64,
486                    "Vertex GLM retrying after transient error"
487                );
488                tokio::time::sleep(backoff).await;
489                // Re-acquire token in case it expired during backoff
490                access_token = self.get_access_token().await?;
491            }
492
493            let send_result = self
494                .client
495                .post(&url)
496                .bearer_auth(&access_token)
497                .header("Content-Type", "application/json")
498                .json(&body)
499                .send()
500                .await;
501
502            let response = match send_result {
503                Ok(r) => r,
504                Err(e) if e.is_timeout() && attempt + 1 < MAX_RETRIES => {
505                    tracing::warn!(error = %e, "Vertex GLM request timed out");
506                    last_err = Some(format!("Request timed out: {e}"));
507                    continue;
508                }
509                Err(e) => anyhow::bail!("Failed to send request to Vertex AI GLM: {e}"),
510            };
511
512            let status = response.status();
513            let text = response
514                .text()
515                .await
516                .context("Failed to read Vertex AI GLM response")?;
517
518            if status == reqwest::StatusCode::SERVICE_UNAVAILABLE && attempt + 1 < MAX_RETRIES {
519                tracing::warn!(status = %status, body = %text, "Vertex GLM service unavailable, retrying");
520                last_err = Some(format!("503 Service Unavailable: {text}"));
521                continue;
522            }
523
524            if !status.is_success() {
525                if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
526                    anyhow::bail!(
527                        "Vertex AI GLM API error: {} ({:?})",
528                        err.error.message,
529                        err.error.error_type
530                    );
531                }
532                anyhow::bail!("Vertex AI GLM API error: {} {}", status, text);
533            }
534
535            let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
536                "Failed to parse Vertex AI GLM response: {}",
537                &text[..text.len().min(200)]
538            ))?;
539
540            let choice = completion
541                .choices
542                .first()
543                .ok_or_else(|| anyhow::anyhow!("No choices in Vertex AI GLM response"))?;
544
545            let mut content = Vec::new();
546            let mut has_tool_calls = false;
547
548            if let Some(text) = &choice.message.content {
549                if !text.is_empty() {
550                    content.push(ContentPart::Text { text: text.clone() });
551                }
552            }
553
554            if let Some(tool_calls) = &choice.message.tool_calls {
555                has_tool_calls = !tool_calls.is_empty();
556                for tc in tool_calls {
557                    content.push(ContentPart::ToolCall {
558                        id: tc.id.clone(),
559                        name: tc.function.name.clone(),
560                        arguments: tc.function.arguments.clone(),
561                        thought_signature: None,
562                    });
563                }
564            }
565
566            let finish_reason = if has_tool_calls {
567                FinishReason::ToolCalls
568            } else {
569                match choice.finish_reason.as_deref() {
570                    Some("stop") => FinishReason::Stop,
571                    Some("length") => FinishReason::Length,
572                    Some("tool_calls") => FinishReason::ToolCalls,
573                    Some("content_filter") => FinishReason::ContentFilter,
574                    _ => FinishReason::Stop,
575                }
576            };
577
578            return Ok(CompletionResponse {
579                message: Message {
580                    role: Role::Assistant,
581                    content,
582                },
583                usage: Usage {
584                    prompt_tokens: completion
585                        .usage
586                        .as_ref()
587                        .map(|u| u.prompt_tokens)
588                        .unwrap_or(0),
589                    completion_tokens: completion
590                        .usage
591                        .as_ref()
592                        .map(|u| u.completion_tokens)
593                        .unwrap_or(0),
594                    total_tokens: completion
595                        .usage
596                        .as_ref()
597                        .map(|u| u.total_tokens)
598                        .unwrap_or(0),
599                    cache_read_tokens: None,
600                    cache_write_tokens: None,
601                },
602                finish_reason,
603            });
604        }
605
606        anyhow::bail!(
607            "Vertex AI GLM request failed after {MAX_RETRIES} attempts: {}",
608            last_err.unwrap_or_default()
609        )
610    }
611
612    async fn complete_stream(
613        &self,
614        request: CompletionRequest,
615    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
616        let mut access_token = self.get_access_token().await?;
617
618        let messages = Self::convert_messages(&request.messages);
619        let tools = Self::convert_tools(&request.tools);
620
621        // Resolve model ID to Vertex format
622        let model = if request.model.starts_with("zai-org") {
623            request.model.clone()
624        } else {
625            format!(
626                "zai-org/{}-maas",
627                request.model.trim_start_matches("zai-org/")
628            )
629        };
630
631        let temperature = request.temperature.unwrap_or(1.0);
632
633        let mut body = json!({
634            "model": model,
635            "messages": messages,
636            "temperature": temperature,
637            "stream": true,
638        });
639
640        if !tools.is_empty() {
641            body["tools"] = json!(tools);
642        }
643        if let Some(max) = request.max_tokens {
644            body["max_tokens"] = json!(max);
645        }
646
647        tracing::debug!(model = %request.model, "Vertex GLM streaming request");
648
649        let url = format!("{}/chat/completions", self.base_url);
650        let mut last_err = String::new();
651
652        for attempt in 0..MAX_RETRIES {
653            if attempt > 0 {
654                let backoff = Duration::from_millis(1000 * 2u64.pow(attempt - 1));
655                tracing::warn!(
656                    attempt,
657                    backoff_ms = backoff.as_millis() as u64,
658                    "Vertex GLM streaming retrying after transient error"
659                );
660                tokio::time::sleep(backoff).await;
661                access_token = self.get_access_token().await?;
662            }
663
664            let send_result = self
665                .client
666                .post(&url)
667                .bearer_auth(&access_token)
668                .header("Content-Type", "application/json")
669                .json(&body)
670                .send()
671                .await;
672
673            let response = match send_result {
674                Ok(r) => r,
675                Err(e) if e.is_timeout() && attempt + 1 < MAX_RETRIES => {
676                    tracing::warn!(error = %e, "Vertex GLM streaming request timed out");
677                    last_err = format!("Request timed out: {e}");
678                    continue;
679                }
680                Err(e) => anyhow::bail!("Failed to send streaming request to Vertex AI GLM: {e}"),
681            };
682
683            if response.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE
684                && attempt + 1 < MAX_RETRIES
685            {
686                let text = response.text().await.unwrap_or_default();
687                tracing::warn!(body = %text, "Vertex GLM streaming service unavailable, retrying");
688                last_err = format!("503 Service Unavailable: {text}");
689                continue;
690            }
691
692            if !response.status().is_success() {
693                let status = response.status();
694                let text = response.text().await.unwrap_or_default();
695                if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
696                    anyhow::bail!(
697                        "Vertex AI GLM API error: {} ({:?})",
698                        err.error.message,
699                        err.error.error_type
700                    );
701                }
702                anyhow::bail!("Vertex AI GLM streaming error: {} {}", status, text);
703            }
704
705            let stream = response.bytes_stream();
706            let mut buffer = String::new();
707
708            return Ok(stream
709                .flat_map(move |chunk_result| {
710                    let mut chunks: Vec<StreamChunk> = Vec::new();
711                    match chunk_result {
712                        Ok(bytes) => {
713                            let text = String::from_utf8_lossy(&bytes);
714                            buffer.push_str(&text);
715
716                            let mut text_buf = String::new();
717
718                            while let Some(line_end) = buffer.find('\n') {
719                                let line = buffer[..line_end].trim().to_string();
720                                buffer = buffer[line_end + 1..].to_string();
721
722                                if line == "data: [DONE]" {
723                                    if !text_buf.is_empty() {
724                                        chunks
725                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
726                                    }
727                                    chunks.push(StreamChunk::Done { usage: None });
728                                    continue;
729                                }
730                                if let Some(data) = line.strip_prefix("data: ") {
731                                    if let Ok(parsed) = serde_json::from_str::<StreamResponse>(data)
732                                    {
733                                        if let Some(choice) = parsed.choices.first() {
734                                            if let Some(ref content) = choice.delta.content {
735                                                text_buf.push_str(content);
736                                            }
737                                            if let Some(ref tool_calls) = choice.delta.tool_calls {
738                                                if !text_buf.is_empty() {
739                                                    chunks.push(StreamChunk::Text(std::mem::take(
740                                                        &mut text_buf,
741                                                    )));
742                                                }
743                                                for tc in tool_calls {
744                                                    if let Some(ref func) = tc.function {
745                                                        let id = tc.id.clone().unwrap_or_default();
746                                                        if let Some(ref name) = func.name {
747                                                            chunks.push(
748                                                                StreamChunk::ToolCallStart {
749                                                                    id: id.clone(),
750                                                                    name: name.clone(),
751                                                                },
752                                                            );
753                                                        }
754                                                        if let Some(ref args) = func.arguments {
755                                                            chunks.push(
756                                                                StreamChunk::ToolCallDelta {
757                                                                    id: id.clone(),
758                                                                    arguments_delta: args.clone(),
759                                                                },
760                                                            );
761                                                        }
762                                                    }
763                                                }
764                                            }
765                                            if let Some(ref reason) = choice.finish_reason {
766                                                if !text_buf.is_empty() {
767                                                    chunks.push(StreamChunk::Text(std::mem::take(
768                                                        &mut text_buf,
769                                                    )));
770                                                }
771                                                if reason == "tool_calls" {
772                                                    if let Some(tc) = choice
773                                                        .delta
774                                                        .tool_calls
775                                                        .as_ref()
776                                                        .and_then(|t| t.last())
777                                                    {
778                                                        if let Some(id) = &tc.id {
779                                                            chunks.push(StreamChunk::ToolCallEnd {
780                                                                id: id.clone(),
781                                                            });
782                                                        }
783                                                    }
784                                                }
785                                            }
786                                        }
787                                    }
788                                }
789                            }
790                            if !text_buf.is_empty() {
791                                chunks.push(StreamChunk::Text(text_buf));
792                            }
793                        }
794                        Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
795                    }
796                    futures::stream::iter(chunks)
797                })
798                .boxed());
799        }
800
801        anyhow::bail!("Vertex AI GLM streaming failed after {MAX_RETRIES} attempts: {last_err}")
802    }
803}
804
#[cfg(test)]
mod tests {
    use super::*;

    /// Service account JSON with a syntactically broken RSA key and no `project_id`.
    fn sa_json_without_project() -> String {
        json!({
            "type": "service_account",
            "client_email": "test@test.iam.gserviceaccount.com",
            "private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIBogIBAAJBALRiMLAHudeSA/x3hB2f+2NRkJlS\n-----END RSA PRIVATE KEY-----\n",
            "token_uri": "https://oauth2.googleapis.com/token"
        })
        .to_string()
    }

    #[test]
    fn test_rejects_invalid_sa_json() {
        // `{}` lacks the required `client_email` / `private_key` fields.
        let err = VertexGlmProvider::new("{}", None).unwrap_err();
        assert!(
            err.to_string()
                .contains("Failed to parse service account JSON key"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn test_rejects_missing_project_id() {
        // `new` resolves the project id before it parses the private key, so with
        // neither an explicit project id nor one in the key, construction fails on
        // the project id check even though the key itself is also invalid.
        let err = VertexGlmProvider::new(&sa_json_without_project(), None).unwrap_err();
        assert!(
            err.to_string().contains("project_id"),
            "unexpected error: {err:#}"
        );
    }
}