// codetether_agent/provider/vertex_glm.rs
//! Vertex AI GLM provider implementation (MaaS endpoint)
//!
//! GLM-5 via Google Cloud Vertex AI Managed API Service.
//! Uses service account JWT auth to obtain OAuth2 access tokens.
//! The service account JSON key is stored in Vault and used to sign JWTs
//! that are exchanged for short-lived access tokens (cached ~55 min).
//!
//! Reference: https://console.cloud.google.com/vertex-ai/publishers/zai/model-garden/glm-5
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use tokio::sync::RwLock;

use super::util;
use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
25
/// Per-request timeout for completion and token-exchange HTTP calls.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// TCP connect timeout for the shared HTTP client.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum attempts for transient failures (timeouts, 503s).
const MAX_RETRIES: u32 = 3;

/// Vertex AI API host.
const VERTEX_ENDPOINT: &str = "aiplatform.googleapis.com";
/// MaaS publisher models are served from the `global` region.
const VERTEX_REGION: &str = "global";
/// Google OAuth2 token endpoint, used when the SA key has no `token_uri`.
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// OAuth2 scope required for Vertex AI calls.
const VERTEX_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
34
/// Cached OAuth2 access token with expiration tracking
struct CachedToken {
    // Bearer token presented to the Vertex AI endpoint.
    token: String,
    // Monotonic deadline; callers refresh before this is reached.
    expires_at: std::time::Instant,
}
40
/// GCP service account key (parsed from JSON)
///
/// Only the fields needed for the JWT-bearer flow are deserialized;
/// any other keys in the JSON are ignored by serde.
#[derive(Debug, Clone, Deserialize)]
struct ServiceAccountKey {
    // `iss` claim of the signed assertion.
    client_email: String,
    // PEM-encoded RSA private key used to sign assertions.
    private_key: String,
    // Token endpoint; falls back to GOOGLE_TOKEN_URL when absent.
    token_uri: Option<String>,
    // Used as the project when no explicit project_id is supplied.
    project_id: Option<String>,
}
49
/// JWT claims for GCP service account auth
///
/// Serialized into the assertion exchanged at the token endpoint
/// (RFC 7523 JWT-bearer grant).
#[derive(Serialize)]
struct JwtClaims {
    // Issuer: the service account's client_email.
    iss: String,
    // Space-separated OAuth2 scopes being requested.
    scope: String,
    // Audience: the token endpoint URL.
    aud: String,
    // Issued-at, seconds since Unix epoch.
    iat: u64,
    // Expiry, seconds since Unix epoch (iat + 1 hour here).
    exp: u64,
}
59
/// Provider for GLM models served through Vertex AI MaaS.
///
/// Holds a shared HTTP client, the parsed service-account credentials,
/// and a process-wide token cache guarded by an async RwLock.
pub struct VertexGlmProvider {
    client: Client,
    project_id: String,
    // OpenAI-compatible endpoint root for this project/region.
    base_url: String,
    sa_key: ServiceAccountKey,
    // Pre-parsed RSA key so each refresh skips PEM parsing.
    encoding_key: EncodingKey,
    /// Cached OAuth2 access token (refreshes ~5 min before expiry)
    cached_token: Arc<RwLock<Option<CachedToken>>>,
}
69
70impl std::fmt::Debug for VertexGlmProvider {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("VertexGlmProvider")
73            .field("project_id", &self.project_id)
74            .field("base_url", &self.base_url)
75            .field("client_email", &self.sa_key.client_email)
76            .finish()
77    }
78}
79
80impl VertexGlmProvider {
81    /// Create from a service account JSON key string
82    pub fn new(sa_json: &str, project_id: Option<String>) -> Result<Self> {
83        let sa_key: ServiceAccountKey =
84            serde_json::from_str(sa_json).context("Failed to parse service account JSON key")?;
85
86        let project_id = project_id
87            .or_else(|| sa_key.project_id.clone())
88            .ok_or_else(|| anyhow::anyhow!("No project_id found in SA key or Vault config"))?;
89
90        let encoding_key = EncodingKey::from_rsa_pem(sa_key.private_key.as_bytes())
91            .context("Failed to parse RSA private key from service account")?;
92
93        let base_url = format!(
94            "https://{}/v1/projects/{}/locations/{}/endpoints/openapi",
95            VERTEX_ENDPOINT, project_id, VERTEX_REGION
96        );
97
98        tracing::debug!(
99            provider = "vertex-glm",
100            project_id = %project_id,
101            client_email = %sa_key.client_email,
102            base_url = %base_url,
103            "Creating Vertex GLM provider with service account"
104        );
105
106        let client = Client::builder()
107            .connect_timeout(CONNECT_TIMEOUT)
108            .timeout(REQUEST_TIMEOUT)
109            .build()
110            .context("Failed to build HTTP client")?;
111
112        Ok(Self {
113            client,
114            project_id,
115            base_url,
116            sa_key,
117            encoding_key,
118            cached_token: Arc::new(RwLock::new(None)),
119        })
120    }
121
122    /// Get a valid OAuth2 access token, refreshing if needed
123    async fn get_access_token(&self) -> Result<String> {
124        // Check cache — refresh 5 minutes before expiration
125        {
126            let cache = self.cached_token.read().await;
127            if let Some(ref cached) = *cache
128                && cached.expires_at
129                    > std::time::Instant::now() + std::time::Duration::from_secs(300)
130            {
131                return Ok(cached.token.clone());
132            }
133        }
134
135        // Sign a JWT assertion
136        let now = std::time::SystemTime::now()
137            .duration_since(std::time::UNIX_EPOCH)
138            .context("System time error")?
139            .as_secs();
140
141        let token_uri = self.sa_key.token_uri.as_deref().unwrap_or(GOOGLE_TOKEN_URL);
142
143        let claims = JwtClaims {
144            iss: self.sa_key.client_email.clone(),
145            scope: VERTEX_SCOPE.to_string(),
146            aud: token_uri.to_string(),
147            iat: now,
148            exp: now + 3600,
149        };
150
151        let header = Header::new(Algorithm::RS256);
152        let assertion = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
153            .context("Failed to sign JWT assertion")?;
154
155        // Exchange JWT for access token
156        let form_body = format!(
157            "grant_type={}&assertion={}",
158            urlencoding::encode("urn:ietf:params:oauth:grant-type:jwt-bearer"),
159            urlencoding::encode(&assertion),
160        );
161        let response = self
162            .client
163            .post(token_uri)
164            .header("Content-Type", "application/x-www-form-urlencoded")
165            .body(form_body)
166            .send()
167            .await
168            .context("Failed to exchange JWT for access token")?;
169
170        let status = response.status();
171        let body = response
172            .text()
173            .await
174            .context("Failed to read token response")?;
175
176        if !status.is_success() {
177            anyhow::bail!("GCP token exchange failed: {status} {body}");
178        }
179
180        #[derive(Deserialize)]
181        struct TokenResponse {
182            access_token: String,
183            #[serde(default)]
184            expires_in: Option<u64>,
185        }
186
187        let token_resp: TokenResponse =
188            serde_json::from_str(&body).context("Failed to parse GCP token response")?;
189
190        let expires_in = token_resp.expires_in.unwrap_or(3600);
191
192        // Cache it
193        {
194            let mut cache = self.cached_token.write().await;
195            *cache = Some(CachedToken {
196                token: token_resp.access_token.clone(),
197                expires_at: std::time::Instant::now() + std::time::Duration::from_secs(expires_in),
198            });
199        }
200
201        tracing::debug!(
202            client_email = %self.sa_key.client_email,
203            expires_in_secs = expires_in,
204            "Refreshed GCP access token via service account JWT"
205        );
206
207        Ok(token_resp.access_token)
208    }
209
210    fn convert_messages(messages: &[Message]) -> Vec<Value> {
211        messages
212            .iter()
213            .map(|msg| {
214                let role = match msg.role {
215                    Role::System => "system",
216                    Role::User => "user",
217                    Role::Assistant => "assistant",
218                    Role::Tool => "tool",
219                };
220
221                match msg.role {
222                    Role::Tool => {
223                        if let Some(ContentPart::ToolResult {
224                            tool_call_id,
225                            content,
226                        }) = msg.content.first()
227                        {
228                            json!({
229                                "role": "tool",
230                                "tool_call_id": tool_call_id,
231                                "content": content
232                            })
233                        } else {
234                            json!({"role": role, "content": ""})
235                        }
236                    }
237                    Role::Assistant => {
238                        let text: String = msg
239                            .content
240                            .iter()
241                            .filter_map(|p| match p {
242                                ContentPart::Text { text } => Some(text.clone()),
243                                _ => None,
244                            })
245                            .collect::<Vec<_>>()
246                            .join("");
247
248                        let tool_calls: Vec<Value> = msg
249                            .content
250                            .iter()
251                            .filter_map(|p| match p {
252                                ContentPart::ToolCall {
253                                    id,
254                                    name,
255                                    arguments,
256                                    ..
257                                } => Some(json!({
258                                    "id": id,
259                                    "type": "function",
260                                    "function": {
261                                        "name": name,
262                                        "arguments": arguments
263                                    }
264                                })),
265                                _ => None,
266                            })
267                            .collect();
268
269                        let mut msg_json = json!({
270                            "role": "assistant",
271                            "content": if text.is_empty() { Value::Null } else { json!(text) },
272                        });
273
274                        if !tool_calls.is_empty() {
275                            msg_json["tool_calls"] = json!(tool_calls);
276                        }
277                        msg_json
278                    }
279                    _ => {
280                        let text: String = msg
281                            .content
282                            .iter()
283                            .filter_map(|p| match p {
284                                ContentPart::Text { text } => Some(text.clone()),
285                                _ => None,
286                            })
287                            .collect::<Vec<_>>()
288                            .join("\n");
289
290                        json!({"role": role, "content": text})
291                    }
292                }
293            })
294            .collect()
295    }
296
297    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
298        tools
299            .iter()
300            .map(|t| {
301                json!({
302                    "type": "function",
303                    "function": {
304                        "name": t.name,
305                        "description": t.description,
306                        "parameters": t.parameters
307                    }
308                })
309            })
310            .collect()
311    }
312}
313
// Response types (OpenAI-compatible chat completion schema)

/// Top-level non-streaming completion response.
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    choices: Vec<Choice>,
    // Token accounting; tolerated as absent.
    #[serde(default)]
    usage: Option<ApiUsage>,
}

/// One completion candidate; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    // Expected values: "stop" | "length" | "tool_calls" | "content_filter"
    // (anything else maps to Stop in `complete`).
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}

/// A tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
    id: String,
    function: FunctionCall,
}

/// Function name plus JSON-encoded arguments string.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    arguments: String,
}

/// Token usage block; all counters default to 0 when missing.
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    /// GLM-on-Vertex KV-cache hit count, subset of `prompt_tokens`.
    #[serde(default)]
    prompt_tokens_details: Option<VertexGlmPromptTokenDetails>,
}

/// Nested cache-hit detail inside `usage`.
#[derive(Debug, Deserialize, Default)]
struct VertexGlmPromptTokenDetails {
    #[serde(default)]
    cached_tokens: usize,
}
367
368impl ApiUsage {
369    fn cached(&self) -> usize {
370        self.prompt_tokens_details
371            .as_ref()
372            .map(|d| d.cached_tokens)
373            .unwrap_or(0)
374    }
375}
376
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}

/// Human-readable message plus optional machine-readable category.
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
    // Wire field is named `type`, which is a Rust keyword.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
388
// SSE streaming types (OpenAI-compatible delta schema)

/// One parsed `data:` SSE event.
#[derive(Debug, Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    // Present only on the final delta of a choice.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental payload: text and/or partial tool-call fragments.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}

/// Tool-call fragment. NOTE(review): `id`/`name` presumably arrive on the
/// first fragment with later fragments carrying only `arguments` deltas —
/// confirm against the wire format.
#[derive(Debug, Deserialize)]
struct StreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<StreamFunction>,
}

#[derive(Debug, Deserialize)]
struct StreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
424
425#[async_trait]
426impl Provider for VertexGlmProvider {
427    fn name(&self) -> &str {
428        "vertex-glm"
429    }
430
431    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
432        Ok(vec![
433            ModelInfo {
434                id: "zai-org/glm-5-maas".to_string(),
435                name: "GLM-5 (Vertex AI MaaS)".to_string(),
436                provider: "vertex-glm".to_string(),
437                context_window: 200_000,
438                max_output_tokens: Some(128_000),
439                supports_vision: false,
440                supports_tools: true,
441                supports_streaming: true,
442                input_cost_per_million: Some(1.0),
443                output_cost_per_million: Some(3.2),
444            },
445            ModelInfo {
446                id: "glm-5".to_string(),
447                name: "GLM-5 (Vertex AI)".to_string(),
448                provider: "vertex-glm".to_string(),
449                context_window: 200_000,
450                max_output_tokens: Some(128_000),
451                supports_vision: false,
452                supports_tools: true,
453                supports_streaming: true,
454                input_cost_per_million: Some(1.0),
455                output_cost_per_million: Some(3.2),
456            },
457        ])
458    }
459
460    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
461        let mut access_token = self.get_access_token().await?;
462
463        let messages = Self::convert_messages(&request.messages);
464        let tools = Self::convert_tools(&request.tools);
465
466        // Resolve model ID to Vertex format
467        let model = if request.model.starts_with("zai-org/") {
468            request.model.clone()
469        } else {
470            format!(
471                "zai-org/{}-maas",
472                request.model.trim_start_matches("zai-org/")
473            )
474        };
475
476        // GLM-5 defaults to temperature 1.0 for best results
477        let temperature = request.temperature.unwrap_or(1.0);
478
479        let mut body = json!({
480            "model": model,
481            "messages": messages,
482            "temperature": temperature,
483            "stream": false,
484        });
485
486        if !tools.is_empty() {
487            body["tools"] = json!(tools);
488        }
489        if let Some(max) = request.max_tokens {
490            body["max_tokens"] = json!(max);
491        }
492
493        tracing::debug!(model = %request.model, "Vertex GLM request");
494
495        let url = format!("{}/chat/completions", self.base_url);
496        let mut last_err = None;
497
498        for attempt in 0..MAX_RETRIES {
499            if attempt > 0 {
500                let backoff = Duration::from_millis(1000 * 2u64.pow(attempt - 1));
501                tracing::warn!(
502                    attempt,
503                    backoff_ms = backoff.as_millis() as u64,
504                    "Vertex GLM retrying after transient error"
505                );
506                tokio::time::sleep(backoff).await;
507                // Re-acquire token in case it expired during backoff
508                access_token = self.get_access_token().await?;
509            }
510
511            let send_result = self
512                .client
513                .post(&url)
514                .bearer_auth(&access_token)
515                .header("Content-Type", "application/json")
516                .json(&body)
517                .send()
518                .await;
519
520            let response = match send_result {
521                Ok(r) => r,
522                Err(e) if e.is_timeout() && attempt + 1 < MAX_RETRIES => {
523                    tracing::warn!(error = %e, "Vertex GLM request timed out");
524                    last_err = Some(format!("Request timed out: {e}"));
525                    continue;
526                }
527                Err(e) => anyhow::bail!("Failed to send request to Vertex AI GLM: {e}"),
528            };
529
530            let status = response.status();
531            let text = response
532                .text()
533                .await
534                .context("Failed to read Vertex AI GLM response")?;
535
536            if status == reqwest::StatusCode::SERVICE_UNAVAILABLE && attempt + 1 < MAX_RETRIES {
537                tracing::warn!(status = %status, body = %text, "Vertex GLM service unavailable, retrying");
538                last_err = Some(format!("503 Service Unavailable: {text}"));
539                continue;
540            }
541
542            if !status.is_success() {
543                if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
544                    anyhow::bail!(
545                        "Vertex AI GLM API error: {} ({:?})",
546                        err.error.message,
547                        err.error.error_type
548                    );
549                }
550                anyhow::bail!("Vertex AI GLM API error: {} {}", status, text);
551            }
552
553            let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
554                "Failed to parse Vertex AI GLM response: {}",
555                util::truncate_bytes_safe(&text, 200)
556            ))?;
557
558            let choice = completion
559                .choices
560                .first()
561                .ok_or_else(|| anyhow::anyhow!("No choices in Vertex AI GLM response"))?;
562
563            let mut content = Vec::new();
564            let mut has_tool_calls = false;
565
566            if let Some(text) = &choice.message.content
567                && !text.is_empty()
568            {
569                content.push(ContentPart::Text { text: text.clone() });
570            }
571
572            if let Some(tool_calls) = &choice.message.tool_calls {
573                has_tool_calls = !tool_calls.is_empty();
574                for tc in tool_calls {
575                    content.push(ContentPart::ToolCall {
576                        id: tc.id.clone(),
577                        name: tc.function.name.clone(),
578                        arguments: tc.function.arguments.clone(),
579                        thought_signature: None,
580                    });
581                }
582            }
583
584            let finish_reason = if has_tool_calls {
585                FinishReason::ToolCalls
586            } else {
587                match choice.finish_reason.as_deref() {
588                    Some("stop") => FinishReason::Stop,
589                    Some("length") => FinishReason::Length,
590                    Some("tool_calls") => FinishReason::ToolCalls,
591                    Some("content_filter") => FinishReason::ContentFilter,
592                    _ => FinishReason::Stop,
593                }
594            };
595
596            return Ok(CompletionResponse {
597                message: Message {
598                    role: Role::Assistant,
599                    content,
600                },
601                usage: Usage {
602                    prompt_tokens: completion
603                        .usage
604                        .as_ref()
605                        .map(|u| u.prompt_tokens.saturating_sub(u.cached()))
606                        .unwrap_or(0),
607                    completion_tokens: completion
608                        .usage
609                        .as_ref()
610                        .map(|u| u.completion_tokens)
611                        .unwrap_or(0),
612                    total_tokens: completion
613                        .usage
614                        .as_ref()
615                        .map(|u| u.total_tokens)
616                        .unwrap_or(0),
617                    cache_read_tokens: completion
618                        .usage
619                        .as_ref()
620                        .map(ApiUsage::cached)
621                        .filter(|&n| n > 0),
622                    cache_write_tokens: None,
623                },
624                finish_reason,
625            });
626        }
627
628        anyhow::bail!(
629            "Vertex AI GLM request failed after {MAX_RETRIES} attempts: {}",
630            last_err.unwrap_or_default()
631        )
632    }
633
634    async fn complete_stream(
635        &self,
636        request: CompletionRequest,
637    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
638        let mut access_token = self.get_access_token().await?;
639
640        let messages = Self::convert_messages(&request.messages);
641        let tools = Self::convert_tools(&request.tools);
642
643        // Resolve model ID to Vertex format
644        let model = if request.model.starts_with("zai-org") {
645            request.model.clone()
646        } else {
647            format!(
648                "zai-org/{}-maas",
649                request.model.trim_start_matches("zai-org/")
650            )
651        };
652
653        let temperature = request.temperature.unwrap_or(1.0);
654
655        let mut body = json!({
656            "model": model,
657            "messages": messages,
658            "temperature": temperature,
659            "stream": true,
660        });
661
662        if !tools.is_empty() {
663            body["tools"] = json!(tools);
664        }
665        if let Some(max) = request.max_tokens {
666            body["max_tokens"] = json!(max);
667        }
668
669        tracing::debug!(model = %request.model, "Vertex GLM streaming request");
670
671        let url = format!("{}/chat/completions", self.base_url);
672        let mut last_err = String::new();
673
674        for attempt in 0..MAX_RETRIES {
675            if attempt > 0 {
676                let backoff = Duration::from_millis(1000 * 2u64.pow(attempt - 1));
677                tracing::warn!(
678                    attempt,
679                    backoff_ms = backoff.as_millis() as u64,
680                    "Vertex GLM streaming retrying after transient error"
681                );
682                tokio::time::sleep(backoff).await;
683                access_token = self.get_access_token().await?;
684            }
685
686            let send_result = self
687                .client
688                .post(&url)
689                .bearer_auth(&access_token)
690                .header("Content-Type", "application/json")
691                .json(&body)
692                .send()
693                .await;
694
695            let response = match send_result {
696                Ok(r) => r,
697                Err(e) if e.is_timeout() && attempt + 1 < MAX_RETRIES => {
698                    tracing::warn!(error = %e, "Vertex GLM streaming request timed out");
699                    last_err = format!("Request timed out: {e}");
700                    continue;
701                }
702                Err(e) => anyhow::bail!("Failed to send streaming request to Vertex AI GLM: {e}"),
703            };
704
705            if response.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE
706                && attempt + 1 < MAX_RETRIES
707            {
708                let text = response.text().await.unwrap_or_default();
709                tracing::warn!(body = %text, "Vertex GLM streaming service unavailable, retrying");
710                last_err = format!("503 Service Unavailable: {text}");
711                continue;
712            }
713
714            if !response.status().is_success() {
715                let status = response.status();
716                let text = response.text().await.unwrap_or_default();
717                if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
718                    anyhow::bail!(
719                        "Vertex AI GLM API error: {} ({:?})",
720                        err.error.message,
721                        err.error.error_type
722                    );
723                }
724                anyhow::bail!("Vertex AI GLM streaming error: {} {}", status, text);
725            }
726
727            let stream = response.bytes_stream();
728            let mut buffer = String::new();
729
730            return Ok(stream
731                .flat_map(move |chunk_result| {
732                    let mut chunks: Vec<StreamChunk> = Vec::new();
733                    match chunk_result {
734                        Ok(bytes) => {
735                            let text = String::from_utf8_lossy(&bytes);
736                            buffer.push_str(&text);
737
738                            let mut text_buf = String::new();
739
740                            while let Some(line_end) = buffer.find('\n') {
741                                let line = buffer[..line_end].trim().to_string();
742                                buffer = buffer[line_end + 1..].to_string();
743
744                                if line == "data: [DONE]" {
745                                    if !text_buf.is_empty() {
746                                        chunks
747                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
748                                    }
749                                    chunks.push(StreamChunk::Done { usage: None });
750                                    continue;
751                                }
752                                if let Some(data) = line.strip_prefix("data: ")
753                                    && let Ok(parsed) = serde_json::from_str::<StreamResponse>(data)
754                                    && let Some(choice) = parsed.choices.first()
755                                {
756                                    if let Some(ref content) = choice.delta.content {
757                                        text_buf.push_str(content);
758                                    }
759                                    if let Some(ref tool_calls) = choice.delta.tool_calls {
760                                        if !text_buf.is_empty() {
761                                            chunks.push(StreamChunk::Text(std::mem::take(
762                                                &mut text_buf,
763                                            )));
764                                        }
765                                        for tc in tool_calls {
766                                            if let Some(ref func) = tc.function {
767                                                let id = tc.id.clone().unwrap_or_default();
768                                                if let Some(ref name) = func.name {
769                                                    chunks.push(StreamChunk::ToolCallStart {
770                                                        id: id.clone(),
771                                                        name: name.clone(),
772                                                    });
773                                                }
774                                                if let Some(ref args) = func.arguments {
775                                                    chunks.push(StreamChunk::ToolCallDelta {
776                                                        id: id.clone(),
777                                                        arguments_delta: args.clone(),
778                                                    });
779                                                }
780                                            }
781                                        }
782                                    }
783                                    if let Some(ref reason) = choice.finish_reason {
784                                        if !text_buf.is_empty() {
785                                            chunks.push(StreamChunk::Text(std::mem::take(
786                                                &mut text_buf,
787                                            )));
788                                        }
789                                        if reason == "tool_calls"
790                                            && let Some(tc) = choice
791                                                .delta
792                                                .tool_calls
793                                                .as_ref()
794                                                .and_then(|t| t.last())
795                                            && let Some(id) = &tc.id
796                                        {
797                                            chunks
798                                                .push(StreamChunk::ToolCallEnd { id: id.clone() });
799                                        }
800                                    }
801                                }
802                            }
803                            if !text_buf.is_empty() {
804                                chunks.push(StreamChunk::Text(text_buf));
805                            }
806                        }
807                        Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
808                    }
809                    futures::stream::iter(chunks)
810                })
811                .boxed());
812        }
813
814        anyhow::bail!("Vertex AI GLM streaming failed after {MAX_RETRIES} attempts: {last_err}")
815    }
816}
817
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty JSON object lacks `client_email`/`private_key`, so
    /// deserialization of the service-account key must fail.
    #[test]
    fn test_rejects_invalid_sa_json() {
        let result = VertexGlmProvider::new("{}", None);
        assert!(result.is_err());
    }

    /// A key without `project_id` (and no explicit override) must be
    /// rejected. `new` resolves the project before parsing the RSA key,
    /// so the error is about the missing project_id — the previous comment
    /// wrongly claimed it would be a key-parsing error.
    #[test]
    fn test_rejects_missing_project_id() {
        let sa_json = json!({
            "type": "service_account",
            "client_email": "test@test.iam.gserviceaccount.com",
            "private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIBogIBAAJBALRiMLAHudeSA/x3hB2f+2NRkJlS\n-----END RSA PRIVATE KEY-----\n",
            "token_uri": "https://oauth2.googleapis.com/token"
        });
        let err = VertexGlmProvider::new(&sa_json.to_string(), None).unwrap_err();
        assert!(err.to_string().contains("project_id"));
    }
}