Skip to main content

perspt_core/
llm_provider.rs

1//! # LLM Provider Module
2//!
//! Thread-safe LLM provider abstraction for multi-agent use.
//! Shares one genai::Client behind an Arc and keeps usage counters in an
//! Arc<RwLock<>> so that every clone observes the same state.
5
6use anyhow::{Context, Result};
7use futures::StreamExt;
8use genai::adapter::AdapterKind;
9use genai::chat::{ChatMessage, ChatRequest, ChatStreamEvent};
10use genai::resolver::{AuthData, AuthResolver, Endpoint, ProviderConfig, ServiceTargetResolver};
11use genai::Client;
12use genai::ModelIden;
13use genai::ServiceTarget;
14use std::future::Future;
15use std::path::PathBuf;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::time::Instant;
19use tokio::sync::{mpsc, RwLock};
20
21use crate::config::Config;
22
/// Sentinel string marking the end of a transmission.
///
/// The streaming API sends this value as the final channel message once a
/// stream has finished.
pub const EOT_SIGNAL: &str = "<|EOT|>";
25
/// The provider/model pair in effect once config, CLI flags, and environment
/// have been merged.
#[derive(Clone, Debug)]
pub struct ResolvedProvider {
    /// Identifier of the provider, for example `openai` or `ollama`.
    pub provider: String,
    /// Name of the model; handed to genai unchanged so namespaced names work.
    pub model: String,
}
34
35/// Detect the provider id and a sensible default model from environment keys.
36///
37/// Used as the fallback when no provider is configured. Falls back to a local
38/// Ollama setup when no API keys are present.
39pub fn detect_provider_from_env() -> (&'static str, &'static str) {
40    if vertex_project_from_env().is_some() {
41        // Google Vertex AI (Agent/AI Platform). Models are namespace-routed as
42        // `vertex::<model>`; auth is an OAuth2 Bearer token from ADC or
43        // VERTEX_API_KEY.
44        ("vertex", "vertex::gemini-2.5-flash")
45    } else if std::env::var("GEMINI_API_KEY").is_ok() {
46        ("gemini", "gemini-3.1-flash-lite-preview")
47    } else if std::env::var("OPENAI_API_KEY").is_ok() {
48        ("openai", "gpt-4o-mini")
49    } else if std::env::var("ANTHROPIC_API_KEY").is_ok() {
50        ("anthropic", "claude-3-5-sonnet-20241022")
51    } else if std::env::var("GROQ_API_KEY").is_ok() {
52        ("groq", "llama-3.1-8b-instant")
53    } else if std::env::var("COHERE_API_KEY").is_ok() {
54        ("cohere", "command-r-plus")
55    } else if std::env::var("XAI_API_KEY").is_ok() {
56        ("xai", "grok-beta")
57    } else if std::env::var("DEEPSEEK_API_KEY").is_ok() {
58        ("deepseek", "deepseek-chat")
59    } else {
60        // Default to Ollama for local usage
61        ("ollama", "llama3.2")
62    }
63}
64
/// Result of a non-streaming LLM call: the reply text plus token usage.
#[derive(Clone, Debug)]
pub struct LlmResponse {
    /// Text content of the reply.
    pub text: String,
    /// Prompt-side token count, when the backend reported one.
    pub tokens_in: Option<i32>,
    /// Completion-side token count, when the backend reported one.
    pub tokens_out: Option<i32>,
}
72
/// Counters shared by every clone of a provider (token and request totals).
///
/// Both counters start at zero via `Default`.
#[derive(Default)]
struct SharedState {
    /// Running sum of tokens recorded so far.
    total_tokens_used: usize,
    /// Number of requests issued so far.
    request_count: usize,
}
79
80/// Thread-safe LLM provider implementation using Arc<RwLock<>>.
81///
82/// This provider can be cheaply cloned and shared across multiple agents.
83/// Each clone shares the same underlying client and rate limiting state.
84#[derive(Clone)]
85pub struct GenAIProvider {
86    /// The underlying genai client
87    client: Arc<Client>,
88    /// Shared state for rate limiting and metrics
89    shared: Arc<RwLock<SharedState>>,
90}
91
92impl GenAIProvider {
93    /// Creates a new GenAI provider with automatic configuration.
94    pub fn new() -> Result<Self> {
95        let client = Client::default();
96        Ok(Self::from_client(client))
97    }
98
99    /// Creates a new GenAI provider with explicit configuration.
100    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
101        let adapter_kind = provider_type.and_then(|provider| match str_to_adapter_kind(provider) {
102            Ok(adapter_kind) => Some(adapter_kind),
103            Err(_) => {
104                log::warn!("Unknown provider type for genai client: {provider}");
105                None
106            }
107        });
108
109        // Set environment variable if API key is provided
110        if let (Some(provider), Some(key)) = (provider_type, api_key) {
111            if let Some(env_var) = provider_api_key_env_var(provider) {
112                log::info!("Setting {env_var} environment variable for genai client");
113                std::env::set_var(env_var, key);
114            } else if provider.eq_ignore_ascii_case("ollama") {
115                log::info!("Ollama provider detected - no API key required for local setup");
116            } else {
117                log::warn!("Unknown provider type for API key: {provider}");
118            }
119        }
120
121        let is_vertex = provider_type
122            .map(|p| p.eq_ignore_ascii_case("vertex"))
123            .unwrap_or(false);
124
125        let client = if is_vertex {
126            // Vertex AI authenticates with an OAuth2 Bearer token from ADC; no
127            // static API key is required when ADC is configured.
128            build_vertex_client()
129        } else {
130            match adapter_kind {
131                Some(adapter_kind) => build_bound_client(adapter_kind, provider_type),
132                None => Client::default(),
133            }
134        };
135
136        Ok(Self::from_client(client))
137    }
138
139    /// Build a provider from a `Config`, merging in environment detection and an
140    /// optional CLI model override, and return the effective provider/model.
141    ///
142    /// Precedence:
143    ///   - provider: `config.provider` > environment detection
144    ///   - model:    `cli_model` > `config.model` > provider default
145    ///   - api_key:  `config.api_key` > ambient environment
146    ///   - base_url: `config.base_url` > ambient environment
147    ///
148    /// The returned client is bound to the resolved adapter, so custom/local
149    /// OpenAI-compatible model names (e.g. `phi-4-npu-ov`) route correctly while
150    /// recognized names still resolve by prefix. Model names are passed through
151    /// verbatim so genai namespacing (`openai::model`) keeps working.
152    pub fn from_config(
153        config: &Config,
154        cli_model: Option<&str>,
155    ) -> Result<(Self, ResolvedProvider)> {
156        let (env_provider, env_model) = detect_provider_from_env();
157
158        let env_model_override = std::env::var("OPENAI_MODEL")
159            .or_else(|_| std::env::var("MODEL"))
160            .ok();
161
162        let model = cli_model
163            .map(str::to_string)
164            .or_else(|| config.model.clone())
165            .or(env_model_override)
166            .unwrap_or_else(|| env_model.to_string());
167
168        let provider = config
169            .provider
170            .clone()
171            .or_else(|| provider_from_model_namespace(&model).map(str::to_string))
172            .unwrap_or_else(|| env_provider.to_string());
173
174        // Propagate a configured base URL into the env var that build_bound_client
175        // reads, without clobbering an explicit ambient override.
176        if let Some(base_url) = config.base_url.as_deref() {
177            if let Some(env_var) = provider_base_url_env_var(&provider) {
178                if std::env::var(env_var).is_err() {
179                    std::env::set_var(env_var, base_url);
180                }
181            }
182        }
183
184        if provider.eq_ignore_ascii_case("vertex") {
185            configure_vertex_environment(config);
186        }
187
188        let provider_obj = Self::new_with_config(Some(&provider), config.api_key.as_deref())?;
189        Ok((provider_obj, ResolvedProvider { provider, model }))
190    }
191
192    fn from_client(client: Client) -> Self {
193        Self {
194            client: Arc::new(client),
195            shared: Arc::new(RwLock::new(SharedState::default())),
196        }
197    }
198
199    /// Get total tokens used across all requests
200    pub async fn get_total_tokens_used(&self) -> usize {
201        self.shared.read().await.total_tokens_used
202    }
203
204    /// Get total request count
205    pub async fn get_request_count(&self) -> usize {
206        self.shared.read().await.request_count
207    }
208
209    /// Increment request counter (for metrics)
210    async fn increment_request(&self) {
211        let mut state = self.shared.write().await;
212        state.request_count += 1;
213    }
214
215    /// Add tokens to the total count
216    pub async fn add_tokens(&self, count: usize) {
217        let mut state = self.shared.write().await;
218        state.total_tokens_used += count;
219    }
220
221    /// Retrieves all available models for a specific provider.
222    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
223        let adapter_kind = str_to_adapter_kind(provider)?;
224        let provider_config = provider_base_url_from_env(provider)
225            .map(|base_url| {
226                ProviderConfig::from_endpoint(Endpoint::from_owned(normalize_base_url(&base_url)))
227            })
228            .unwrap_or_default();
229
230        let models = self
231            .client
232            .all_model_names(adapter_kind, provider_config)
233            .await
234            .context(format!("Failed to get models for provider: {provider}"))?;
235
236        Ok(models)
237    }
238
239    /// Generates a simple text response without streaming.
240    /// Includes exponential backoff retry for rate limits and transient errors.
241    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
242        self.generate_response_with_retry(model, prompt, 3).await
243    }
244
245    /// Generates a response with configurable retry count and exponential backoff.
246    pub async fn generate_response_with_retry(
247        &self,
248        model: &str,
249        prompt: &str,
250        max_retries: usize,
251    ) -> Result<LlmResponse> {
252        self.increment_request().await;
253
254        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
255
256        log::debug!(
257            "Sending chat request to model: {model} with prompt length: {} chars",
258            prompt.len()
259        );
260
261        let start_time = Instant::now();
262        let mut last_error: Option<anyhow::Error> = None;
263        let mut retry_count = 0;
264
265        while retry_count <= max_retries {
266            if retry_count > 0 {
267                // Exponential backoff: 1s, 2s, 4s, 8s, ... (capped at 16s)
268                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
269                log::warn!(
270                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
271                    retry_count,
272                    max_retries,
273                    model,
274                    delay_secs,
275                    last_error.as_ref().map(|e| e.to_string())
276                );
277                println!(
278                    "   ⏳ Rate limited, retrying in {}s (attempt {}/{})",
279                    delay_secs, retry_count, max_retries
280                );
281                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
282            }
283
284            match self.client.exec_chat(model, chat_req.clone(), None).await {
285                Ok(chat_res) => {
286                    let tokens_in = chat_res.usage.prompt_tokens;
287                    let tokens_out = chat_res.usage.completion_tokens;
288                    let content = chat_res
289                        .first_text()
290                        .context("No text content in response")?;
291                    log::debug!(
292                        "Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
293                        content.len(),
294                        start_time.elapsed().as_millis(),
295                        tokens_in,
296                        tokens_out,
297                    );
298
299                    // Update shared token counter with real values when available
300                    let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
301                    if total > 0 {
302                        self.add_tokens(total as usize).await;
303                    }
304
305                    return Ok(LlmResponse {
306                        text: content.to_string(),
307                        tokens_in,
308                        tokens_out,
309                    });
310                }
311                Err(e) => {
312                    let err_str = e.to_string();
313
314                    // Check if it's a retryable error (rate limit, server error, network)
315                    let is_retryable = err_str.contains("429")
316                        || err_str.contains("rate limit")
317                        || err_str.contains("Rate limit")
318                        || err_str.contains("RESOURCE_EXHAUSTED")
319                        || err_str.contains("500")
320                        || err_str.contains("502")
321                        || err_str.contains("503")
322                        || err_str.contains("504")
323                        || err_str.contains("timeout")
324                        || err_str.contains("connection");
325
326                    if is_retryable && retry_count < max_retries {
327                        log::warn!("Retryable error for model {}: {}", model, err_str);
328                        last_error = Some(anyhow::anyhow!("{}", err_str));
329                        retry_count += 1;
330                        continue;
331                    } else {
332                        return Err(anyhow::anyhow!(
333                            "Failed to execute chat request for model {}: {}",
334                            model,
335                            err_str
336                        ));
337                    }
338                }
339            }
340        }
341
342        // Should not reach here, but handle gracefully
343        Err(last_error
344            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
345    }
346
347    /// Generates a streaming response and sends chunks via mpsc channel.
348    pub async fn generate_response_stream_to_channel(
349        &self,
350        model: &str,
351        prompt: &str,
352        tx: mpsc::UnboundedSender<String>,
353    ) -> Result<()> {
354        self.increment_request().await;
355
356        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
357
358        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");
359
360        let chat_res_stream = self
361            .client
362            .exec_chat_stream(model, chat_req, None)
363            .await
364            .context(format!(
365                "Failed to execute streaming chat request for model: {model}"
366            ))?;
367
368        let mut stream = chat_res_stream.stream;
369        let mut chunk_count = 0;
370        let mut total_content_length = 0;
371        let mut stream_ended_explicitly = false;
372        let start_time = Instant::now();
373
374        log::info!(
375            "=== STREAM START === Model: {}, Prompt length: {} chars",
376            model,
377            prompt.len()
378        );
379
380        while let Some(chunk_result) = stream.next().await {
381            let elapsed = start_time.elapsed();
382
383            match chunk_result {
384                Ok(ChatStreamEvent::Start) => {
385                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
386                }
387                Ok(ChatStreamEvent::Chunk(chunk)) => {
388                    chunk_count += 1;
389                    total_content_length += chunk.content.len();
390
391                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
392                        log::info!(
393                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
394                            chunk_count,
395                            chunk.content.len(),
396                            total_content_length,
397                            elapsed
398                        );
399                    }
400
401                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
402                        log::error!(
403                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
404                        );
405                        break;
406                    }
407                }
408                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
409                    log::info!(
410                        "REASONING CHUNK: {} chars at {:?}",
411                        chunk.content.len(),
412                        elapsed
413                    );
414                    if !chunk.content.is_empty() {
415                        let _ = tx.send(format!("__PERSPT_REASONING__:{}", chunk.content));
416                    }
417                }
418                Ok(ChatStreamEvent::End(_)) => {
419                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
420                    stream_ended_explicitly = true;
421                    break;
422                }
423                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
424                    log::debug!("Tool call chunk received (ignored)");
425                }
426                Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
427                    log::debug!("Thought signature chunk received (ignored)");
428                }
429                Err(e) => {
430                    log::error!(
431                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
432                    );
433                    let error_msg = format!("Stream error: {e}");
434                    let _ = tx.send(error_msg);
435                    return Err(e.into());
436                }
437            }
438        }
439
440        let final_elapsed = start_time.elapsed();
441        if !stream_ended_explicitly {
442            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
443        }
444
445        log::info!(
446            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
447        );
448
449        // Add approximate token count
450        self.add_tokens(total_content_length / 4).await; // Rough estimate
451
452        if tx.send(EOT_SIGNAL.to_string()).is_err() {
453            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
454            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
455        }
456
457        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
458        Ok(())
459    }
460
461    /// Get a list of supported providers
462    pub fn get_supported_providers() -> Vec<&'static str> {
463        vec![
464            "openai",
465            "anthropic",
466            "gemini",
467            "groq",
468            "cohere",
469            "ollama",
470            "vertex",
471            "xai",
472            "deepseek",
473        ]
474    }
475
476    /// Get all available providers
477    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
478        Ok(Self::get_supported_providers()
479            .iter()
480            .map(|s| s.to_string())
481            .collect())
482    }
483
484    /// Test if a model is available and working
485    pub async fn test_model(&self, model: &str) -> Result<bool> {
486        match self.generate_response_simple(model, "Hello").await {
487            Ok(_) => {
488                log::info!("Model {model} is available and working");
489                Ok(true)
490            }
491            Err(e) => {
492                log::warn!("Model {model} test failed: {e}");
493                Ok(false)
494            }
495        }
496    }
497
498    /// Validate and get the best available model for a provider
499    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
500        if self.test_model(model).await? {
501            return Ok(model.to_string());
502        }
503
504        if let Some(provider) = provider_type {
505            if let Ok(models) = self.get_available_models(provider).await {
506                if !models.is_empty() {
507                    log::info!("Model {} not available, using {} instead", model, models[0]);
508                    return Ok(models[0].clone());
509                }
510            }
511        }
512
513        log::warn!("Could not validate model {model}, proceeding anyway");
514        Ok(model.to_string())
515    }
516}
517
518/// Build a genai client for Google Vertex AI authenticated via Application
519/// Default Credentials (ADC).
520///
521/// genai's Vertex adapter reads `VERTEX_PROJECT_ID` (required) and
522/// `VERTEX_LOCATION` to construct the request
523/// URL, and expects an OAuth2 Bearer token from an [`AuthResolver`]. This
524/// resolver fetches that token from ADC (gcloud login, a service account, or the
525/// metadata server) on each request, so no static API key is needed. If a
526/// `VERTEX_API_KEY` bearer token is explicitly set, it is used as an override.
527fn build_vertex_client() -> Client {
528    let resolver = AuthResolver::from_resolver_async_fn(
529        |_model: ModelIden| -> Pin<
530            Box<dyn Future<Output = genai::resolver::Result<Option<AuthData>>> + Send>,
531        > {
532            Box::pin(async move {
533                // Explicit bearer-token override wins when present.
534                if let Ok(token) = std::env::var("VERTEX_API_KEY") {
535                    if !token.trim().is_empty() {
536                        return Ok(Some(AuthData::from_single(token)));
537                    }
538                }
539                // Otherwise resolve an access token from ADC.
540                let provider = gcp_auth::provider().await.map_err(|e| {
541                    genai::resolver::Error::Custom(format!(
542                        "Vertex ADC provider init failed (run `gcloud auth application-default login`): {e}"
543                    ))
544                })?;
545                let scopes = ["https://www.googleapis.com/auth/cloud-platform"];
546                let token = provider.token(&scopes).await.map_err(|e| {
547                    genai::resolver::Error::Custom(format!("Vertex ADC token fetch failed: {e}"))
548                })?;
549                Ok(Some(AuthData::from_single(token.as_str())))
550            })
551        },
552    );
553
554    let mut builder = Client::builder()
555        .with_adapter_kind(AdapterKind::Vertex)
556        .with_auth_resolver(resolver);
557
558    if let Some(endpoint) = resolved_vertex_endpoint() {
559        builder = builder.with_service_target_resolver_fn(move |mut target: ServiceTarget| {
560            target.endpoint = Endpoint::from_owned(endpoint.clone());
561            Ok(target)
562        });
563    }
564
565    builder.build()
566}
567
568fn build_bound_client(adapter_kind: AdapterKind, provider_type: Option<&str>) -> Client {
569    let mut builder = Client::builder().with_adapter_kind(adapter_kind);
570
571    if let Some(base_url) = provider_type.and_then(provider_base_url_from_env) {
572        let endpoint = normalize_base_url(&base_url);
573        let target_resolver = ServiceTargetResolver::from_resolver_fn(
574            move |mut service_target: ServiceTarget| -> genai::resolver::Result<ServiceTarget> {
575                if service_target.model.adapter_kind == adapter_kind {
576                    service_target.endpoint = Endpoint::from_owned(endpoint.clone());
577                }
578                Ok(service_target)
579            },
580        );
581        builder = builder.with_service_target_resolver(target_resolver);
582    }
583
584    builder.build()
585}
586
/// Map the `prefix::` namespace of a model name to a provider id.
///
/// The prefix is matched case-insensitively and `google` is an alias for
/// `gemini`. Returns `None` when the name has no `::` separator or the prefix
/// is not a known provider.
fn provider_from_model_namespace(model: &str) -> Option<&'static str> {
    let (namespace, _rest) = model.split_once("::")?;
    let provider = match namespace.to_ascii_lowercase().as_str() {
        "openai" => "openai",
        "anthropic" => "anthropic",
        "gemini" | "google" => "gemini",
        "vertex" => "vertex",
        "groq" => "groq",
        "cohere" => "cohere",
        "ollama" => "ollama",
        "xai" => "xai",
        "deepseek" => "deepseek",
        _ => return None,
    };
    Some(provider)
}
602
603fn configure_vertex_environment(config: &Config) {
604    if std::env::var("VERTEX_PROJECT_ID").is_err() {
605        if let Some(project) = config
606            .vertex_project_id
607            .as_deref()
608            .map(str::trim)
609            .filter(|v| !v.is_empty())
610            .map(str::to_string)
611            .or_else(vertex_project_from_env)
612            .or_else(read_gcloud_project)
613        {
614            // Only propagate a syntactically valid project into the process env;
615            // genai reads VERTEX_PROJECT_ID to build the request URL, so an
616            // invalid value should not be planted there from config discovery.
617            match valid_vertex_segment(&project) {
618                Some(valid) => std::env::set_var("VERTEX_PROJECT_ID", valid),
619                None => log::warn!(
620                    "Ignoring discovered Vertex project ID (must contain only ASCII letters, \
621                     digits, and hyphens)"
622                ),
623            }
624        }
625    }
626
627    if std::env::var("VERTEX_LOCATION").is_err() {
628        if let Some(location) = config
629            .vertex_location
630            .as_deref()
631            .map(str::trim)
632            .filter(|v| !v.is_empty())
633        {
634            // The location is interpolated into the endpoint host, so never
635            // plant an invalid value into the env from a config file.
636            match valid_vertex_segment(location) {
637                Some(valid) => std::env::set_var("VERTEX_LOCATION", valid),
638                None => log::warn!(
639                    "Ignoring invalid vertex_location from config (must contain only ASCII \
640                     letters, digits, and hyphens)"
641                ),
642            }
643        }
644    }
645}
646
/// Return the first non-blank project id found in the environment.
///
/// The variables are consulted in a fixed order; values are trimmed, and
/// variables that are unset or contain only whitespace are skipped.
fn vertex_project_from_env() -> Option<String> {
    const PROJECT_VARS: [&str; 4] = [
        "VERTEX_PROJECT_ID",
        "GOOGLE_CLOUD_PROJECT",
        "GCLOUD_PROJECT",
        "CLOUDSDK_CORE_PROJECT",
    ];

    for key in PROJECT_VARS.iter() {
        if let Ok(raw) = std::env::var(key) {
            let project = raw.trim();
            if !project.is_empty() {
                return Some(project.to_string());
            }
        }
    }
    None
}
659
660fn gcloud_config_dir() -> Option<PathBuf> {
661    if let Ok(dir) = std::env::var("CLOUDSDK_CONFIG") {
662        let trimmed = dir.trim();
663        if !trimmed.is_empty() {
664            return Some(PathBuf::from(trimmed));
665        }
666    }
667    dirs::home_dir().map(|home| home.join(".config").join("gcloud"))
668}
669
670fn read_gcloud_project() -> Option<String> {
671    let config_dir = gcloud_config_dir()?;
672    let active_config = std::fs::read_to_string(config_dir.join("active_config"))
673        .ok()
674        .map(|s| s.trim().to_string())
675        .filter(|s| !s.is_empty())
676        .unwrap_or_else(|| "default".to_string());
677    let config_path = config_dir
678        .join("configurations")
679        .join(format!("config_{active_config}"));
680    let content = std::fs::read_to_string(config_path).ok()?;
681    parse_gcloud_project(&content)
682}
683
/// Extract the `project` value of the `[core]` section from gcloud
/// configuration text.
///
/// Blank lines and lines starting with `#` are skipped. The section header is
/// matched case-insensitively, the key must be exactly `project`, and an empty
/// value is ignored so a later non-empty entry can still match.
fn parse_gcloud_project(content: &str) -> Option<String> {
    let mut inside_core = false;

    content
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .find_map(|line| {
            // A section header switches the "inside [core]" state.
            if line.starts_with('[') && line.ends_with(']') {
                inside_core = line.eq_ignore_ascii_case("[core]");
                return None;
            }
            if !inside_core {
                return None;
            }

            let (key, value) = line.split_once('=')?;
            let value = value.trim();
            if key.trim() == "project" && !value.is_empty() {
                Some(value.to_string())
            } else {
                None
            }
        })
}
710
/// Check a Vertex project or location value before it is placed into the
/// request endpoint URL.
///
/// Only ASCII letters, digits, and hyphens are accepted. Anything else is
/// rejected so that a crafted value (e.g. one containing `/`, `.`, `:`, or
/// `@`) cannot change the endpoint *host*: the location is interpolated into
/// `{location}-aiplatform.googleapis.com`, and an unchecked value could send
/// the OAuth bearer token to an attacker-controlled host.
///
/// Returns the trimmed input when it is valid, `None` otherwise (including
/// for empty or whitespace-only input).
fn valid_vertex_segment(value: &str) -> Option<&str> {
    let candidate = value.trim();
    // Every byte of a non-ASCII character is >= 0x80, so a byte-wise check
    // rejects exactly the same inputs as a char-wise one.
    let well_formed = !candidate.is_empty()
        && candidate
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-');

    if well_formed {
        Some(candidate)
    } else {
        None
    }
}
730
731fn resolved_vertex_endpoint() -> Option<String> {
732    let project_raw = std::env::var("VERTEX_PROJECT_ID").ok()?;
733    let project = match valid_vertex_segment(&project_raw) {
734        Some(p) => p.to_string(),
735        None => {
736            log::warn!(
737                "Ignoring VERTEX_PROJECT_ID for endpoint construction: must be non-empty and \
738                 contain only ASCII letters, digits, and hyphens"
739            );
740            return None;
741        }
742    };
743    let location = match std::env::var("VERTEX_LOCATION") {
744        Ok(raw) if !raw.trim().is_empty() => match valid_vertex_segment(&raw) {
745            Some(l) => l.to_string(),
746            None => {
747                log::warn!(
748                    "Ignoring invalid VERTEX_LOCATION (must contain only ASCII letters, digits, \
749                     and hyphens); falling back to 'global'"
750                );
751                "global".to_string()
752            }
753        },
754        _ => "global".to_string(),
755    };
756    Some(vertex_endpoint_base(&project, &location))
757}
758
/// Format the Vertex AI base URL for a project and location.
///
/// Both inputs are trimmed. The `global` location (any letter case) uses the
/// un-prefixed host; every other location is used as the host prefix and as
/// the `locations/` path segment.
fn vertex_endpoint_base(project: &str, location: &str) -> String {
    let (project, location) = (project.trim(), location.trim());

    if location.eq_ignore_ascii_case("global") {
        return format!(
            "https://aiplatform.googleapis.com/v1/projects/{project}/locations/global/"
        );
    }

    format!(
        "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/"
    )
}
770
/// Name of the environment variable that holds a provider's base URL.
///
/// The provider id is matched case-insensitively and `google` is an alias for
/// `gemini`. Returns `None` for providers without such a variable.
fn provider_base_url_env_var(provider: &str) -> Option<&'static str> {
    /// `(provider id, base URL variable)` pairs.
    const BASE_URL_VARS: [(&str, &str); 9] = [
        ("openai", "OPENAI_BASE_URL"),
        ("anthropic", "ANTHROPIC_BASE_URL"),
        ("gemini", "GEMINI_BASE_URL"),
        ("google", "GEMINI_BASE_URL"),
        ("groq", "GROQ_BASE_URL"),
        ("cohere", "COHERE_BASE_URL"),
        ("ollama", "OLLAMA_BASE_URL"),
        ("xai", "XAI_BASE_URL"),
        ("deepseek", "DEEPSEEK_BASE_URL"),
    ];

    let wanted = provider.to_lowercase();
    BASE_URL_VARS
        .iter()
        .find(|(id, _)| *id == wanted)
        .map(|&(_, var)| var)
}
784
785fn provider_base_url_from_env(provider: &str) -> Option<String> {
786    let env_var = provider_base_url_env_var(provider)?;
787
788    std::env::var(env_var)
789        .ok()
790        .map(|value| value.trim().to_string())
791        .filter(|value| !value.is_empty())
792}
793
/// Name of the environment variable that holds the API key for `provider`.
///
/// The provider id is lowercased before matching, and `google` is accepted as
/// an alias for `gemini`. Ids without an entry here yield `None`.
fn provider_api_key_env_var(provider: &str) -> Option<&'static str> {
    let normalized = provider.to_lowercase();
    let env_var = match normalized.as_str() {
        "openai" => "OPENAI_API_KEY",
        "anthropic" => "ANTHROPIC_API_KEY",
        "gemini" | "google" => "GEMINI_API_KEY",
        "vertex" => "VERTEX_API_KEY",
        "groq" => "GROQ_API_KEY",
        "cohere" => "COHERE_API_KEY",
        "xai" => "XAI_API_KEY",
        "deepseek" => "DEEPSEEK_API_KEY",
        _ => return None,
    };
    Some(env_var)
}
807
/// Return `base_url` with exactly one guaranteed trailing `/`.
///
/// Input that already ends with `/` is returned unchanged; otherwise a single
/// `/` is appended. No other normalization is performed.
fn normalize_base_url(base_url: &str) -> String {
    let mut normalized = String::with_capacity(base_url.len() + 1);
    normalized.push_str(base_url);
    if !normalized.ends_with('/') {
        normalized.push('/');
    }
    normalized
}
815
816/// Convert a provider string to genai AdapterKind
817fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
818    match provider.to_lowercase().as_str() {
819        "openai" => Ok(AdapterKind::OpenAI),
820        "anthropic" => Ok(AdapterKind::Anthropic),
821        "gemini" | "google" => Ok(AdapterKind::Gemini),
822        "vertex" => Ok(AdapterKind::Vertex),
823        "groq" => Ok(AdapterKind::Groq),
824        "cohere" => Ok(AdapterKind::Cohere),
825        "ollama" => Ok(AdapterKind::Ollama),
826        "xai" => Ok(AdapterKind::Xai),
827        "deepseek" => Ok(AdapterKind::DeepSeek),
828        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
829    }
830}
831
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test that mutates process-wide environment variables.
    ///
    /// The test harness runs tests on several threads by default, so tests
    /// that write the same variables would otherwise observe each other's
    /// values and fail intermittently.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// RAII helper for tests that need to change environment variables.
    ///
    /// Holds [`ENV_LOCK`] for its whole lifetime and, on drop, puts every
    /// captured variable back to the value it had at capture time. Because the
    /// restore happens in `Drop`, it also runs when an assertion panics.
    struct EnvGuard {
        /// Variable name paired with its value at capture time (`None` = unset).
        saved: Vec<(&'static str, Option<String>)>,
        /// Declared last so the lock is released only after `Drop::drop` has
        /// restored the environment.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Take the environment lock and snapshot the given variables.
        fn capture(names: &[&'static str]) -> Self {
            // A poisoned lock only means an earlier env test failed; its guard
            // already restored the environment, so continuing is safe.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let saved = names
                .iter()
                .map(|&name| (name, std::env::var(name).ok()))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, previous) in self.saved.drain(..) {
                match previous {
                    Some(value) => std::env::set_var(name, value),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    /// Every supported provider id maps to an adapter; unknown ids are errors.
    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("vertex").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_configured_provider_binds_adapter_for_custom_model_names() {
        let provider = GenAIProvider::new_with_config(Some("openai"), None).unwrap();
        let target = provider
            .client
            .resolve_service_target("gemma4-32b-it")
            .await
            .unwrap();

        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[tokio::test]
    async fn test_namespaced_model_resolves_on_unbound_client() {
        // genai-native namespacing must work without a bound client.
        let provider = GenAIProvider::new().unwrap();
        let target = provider
            .client
            .resolve_service_target("openai::phi-4-npu-ov")
            .await
            .unwrap();

        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[tokio::test]
    async fn test_from_config_binds_adapter_for_custom_model() {
        let config = Config {
            provider: Some("openai".to_string()),
            model: Some("phi-4-npu-ov".to_string()),
            ..Default::default()
        };
        let (provider, resolved) = GenAIProvider::from_config(&config, None).unwrap();
        assert_eq!(resolved.provider, "openai");
        assert_eq!(resolved.model, "phi-4-npu-ov");

        let target = provider
            .client
            .resolve_service_target(&resolved.model)
            .await
            .unwrap();
        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[test]
    fn test_from_config_model_precedence() {
        let config = Config {
            provider: Some("openai".to_string()),
            model: Some("config-model".to_string()),
            ..Default::default()
        };
        // CLI override wins over config model.
        let (_p, resolved) = GenAIProvider::from_config(&config, Some("cli-model")).unwrap();
        assert_eq!(resolved.model, "cli-model");
    }

    #[test]
    fn test_provider_from_model_namespace_detects_vertex() {
        assert_eq!(
            provider_from_model_namespace("vertex::gemini-2.5-flash"),
            Some("vertex")
        );
        assert_eq!(provider_from_model_namespace("gemini-2.5-flash"), None);
    }

    #[tokio::test]
    async fn test_from_config_uses_namespaced_vertex_model_when_provider_absent() {
        // Restores both variables (and releases the env lock) on scope exit.
        let _env = EnvGuard::capture(&["VERTEX_PROJECT_ID", "VERTEX_LOCATION"]);
        std::env::set_var("VERTEX_PROJECT_ID", "unit-test-project");
        std::env::remove_var("VERTEX_LOCATION");

        let config = Config::default();
        let (_provider, resolved) =
            GenAIProvider::from_config(&config, Some("vertex::gemini-2.5-flash")).unwrap();
        assert_eq!(resolved.provider, "vertex");
        assert_eq!(resolved.model, "vertex::gemini-2.5-flash");
        // Building the provider must not set a location as a side effect.
        assert!(std::env::var("VERTEX_LOCATION").is_err());
        assert_eq!(
            resolved_vertex_endpoint().as_deref(),
            Some(
                "https://aiplatform.googleapis.com/v1/projects/unit-test-project/locations/global/"
            )
        );
    }

    #[test]
    fn test_valid_vertex_segment_accepts_real_values() {
        assert_eq!(valid_vertex_segment("perspt"), Some("perspt"));
        assert_eq!(valid_vertex_segment("us-central1"), Some("us-central1"));
        assert_eq!(valid_vertex_segment("global"), Some("global"));
        assert_eq!(valid_vertex_segment("europe-west4"), Some("europe-west4"));
        assert_eq!(valid_vertex_segment("  perspt  "), Some("perspt")); // trimmed
    }

    #[test]
    fn test_valid_vertex_segment_rejects_host_redirection() {
        // Values that could alter the endpoint host or path must be rejected.
        assert_eq!(valid_vertex_segment("evil.com/"), None);
        assert_eq!(valid_vertex_segment("evil.com"), None); // '.'
        assert_eq!(valid_vertex_segment("a/b"), None);
        assert_eq!(valid_vertex_segment("a:b"), None);
        assert_eq!(valid_vertex_segment("a@b"), None);
        assert_eq!(valid_vertex_segment("a b"), None);
        assert_eq!(valid_vertex_segment(""), None);
        assert_eq!(valid_vertex_segment("   "), None);
    }

    #[test]
    fn test_resolved_vertex_endpoint_rejects_malicious_location() {
        // Restores both variables (and releases the env lock) on scope exit,
        // even if one of the assertions below panics.
        let _env = EnvGuard::capture(&["VERTEX_PROJECT_ID", "VERTEX_LOCATION"]);

        // A crafted location must not be interpolated into the host; it falls
        // back to the safe global endpoint instead.
        std::env::set_var("VERTEX_PROJECT_ID", "perspt");
        std::env::set_var("VERTEX_LOCATION", "evil.com/");
        assert_eq!(
            resolved_vertex_endpoint().as_deref(),
            Some("https://aiplatform.googleapis.com/v1/projects/perspt/locations/global/"),
            "malicious location must fall back to global, never redirect the host"
        );

        // An invalid project yields no endpoint override at all.
        std::env::set_var("VERTEX_PROJECT_ID", "bad/project");
        std::env::set_var("VERTEX_LOCATION", "us-central1");
        assert_eq!(resolved_vertex_endpoint(), None);
    }

    #[test]
    fn test_vertex_endpoint_base_matches_genai_vertex_shape() {
        assert_eq!(
            vertex_endpoint_base("test-project", "global"),
            "https://aiplatform.googleapis.com/v1/projects/test-project/locations/global/"
        );
        assert_eq!(
            vertex_endpoint_base("test-project", "test-location"),
            "https://test-location-aiplatform.googleapis.com/v1/projects/test-project/locations/test-location/"
        );
    }

    #[test]
    fn test_parse_gcloud_project_reads_core_project() {
        let content = r#"
        [compute]
        region = ignored-location

        [core]
        account = user@example.com
        project = test-project
        "#;
        assert_eq!(
            parse_gcloud_project(content).as_deref(),
            Some("test-project")
        );
    }

    #[tokio::test]
    async fn test_openai_base_url_overrides_bound_provider_endpoint() {
        // The guard is deliberately kept alive across the `.await` below so the
        // override stays in place until the service target has been resolved.
        let _env = EnvGuard::capture(&["OPENAI_BASE_URL"]);
        std::env::set_var("OPENAI_BASE_URL", "https://custom.example/v1");

        let provider = GenAIProvider::new_with_config(Some("openai"), None).unwrap();
        let target = provider
            .client
            .resolve_service_target("gemma4-32b-it")
            .await
            .unwrap();

        assert_eq!(target.endpoint.base_url(), "https://custom.example/v1/");
    }

    #[test]
    fn test_normalize_base_url() {
        assert_eq!(
            normalize_base_url("https://custom.example/v1"),
            "https://custom.example/v1/"
        );
        assert_eq!(
            normalize_base_url("https://custom.example/v1/"),
            "https://custom.example/v1/"
        );
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
        // All clones share the same underlying state
    }
}