1use anyhow::{Context, Result};
7use futures::StreamExt;
8use genai::adapter::AdapterKind;
9use genai::chat::{ChatMessage, ChatRequest, ChatStreamEvent};
10use genai::resolver::{AuthData, AuthResolver, Endpoint, ProviderConfig, ServiceTargetResolver};
11use genai::Client;
12use genai::ModelIden;
13use genai::ServiceTarget;
14use std::future::Future;
15use std::path::PathBuf;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::time::Instant;
19use tokio::sync::{mpsc, RwLock};
20
21use crate::config::Config;
22
/// Sentinel message sent on the stream channel after the last chunk of a
/// successfully completed streamed response.
pub const EOT_SIGNAL: &str = "<|EOT|>";
25
/// Provider name and model chosen by [`GenAIProvider::from_config`].
#[derive(Debug, Clone)]
pub struct ResolvedProvider {
    /// Provider name, e.g. `"openai"` or `"vertex"`.
    pub provider: String,
    /// Model identifier, possibly namespaced (e.g. `"vertex::gemini-2.5-flash"`).
    pub model: String,
}
34
35pub fn detect_provider_from_env() -> (&'static str, &'static str) {
40 if vertex_project_from_env().is_some() {
41 ("vertex", "vertex::gemini-2.5-flash")
45 } else if std::env::var("GEMINI_API_KEY").is_ok() {
46 ("gemini", "gemini-3.1-flash-lite-preview")
47 } else if std::env::var("OPENAI_API_KEY").is_ok() {
48 ("openai", "gpt-4o-mini")
49 } else if std::env::var("ANTHROPIC_API_KEY").is_ok() {
50 ("anthropic", "claude-3-5-sonnet-20241022")
51 } else if std::env::var("GROQ_API_KEY").is_ok() {
52 ("groq", "llama-3.1-8b-instant")
53 } else if std::env::var("COHERE_API_KEY").is_ok() {
54 ("cohere", "command-r-plus")
55 } else if std::env::var("XAI_API_KEY").is_ok() {
56 ("xai", "grok-beta")
57 } else if std::env::var("DEEPSEEK_API_KEY").is_ok() {
58 ("deepseek", "deepseek-chat")
59 } else {
60 ("ollama", "llama3.2")
62 }
63}
64
/// Text of a completed (non-streamed) chat response plus reported token usage.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// First text content of the response.
    pub text: String,
    /// Prompt tokens as reported by the provider, if any.
    pub tokens_in: Option<i32>,
    /// Completion tokens as reported by the provider, if any.
    pub tokens_out: Option<i32>,
}
72
/// Mutable usage counters; both start at zero.
#[derive(Default)]
struct SharedState {
    /// Running total of tokens recorded via `add_tokens`.
    total_tokens_used: usize,
    /// Number of generation requests started.
    request_count: usize,
}
79
80#[derive(Clone)]
85pub struct GenAIProvider {
86 client: Arc<Client>,
88 shared: Arc<RwLock<SharedState>>,
90}
91
92impl GenAIProvider {
93 pub fn new() -> Result<Self> {
95 let client = Client::default();
96 Ok(Self::from_client(client))
97 }
98
99 pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
101 let adapter_kind = provider_type.and_then(|provider| match str_to_adapter_kind(provider) {
102 Ok(adapter_kind) => Some(adapter_kind),
103 Err(_) => {
104 log::warn!("Unknown provider type for genai client: {provider}");
105 None
106 }
107 });
108
109 if let (Some(provider), Some(key)) = (provider_type, api_key) {
111 if let Some(env_var) = provider_api_key_env_var(provider) {
112 log::info!("Setting {env_var} environment variable for genai client");
113 std::env::set_var(env_var, key);
114 } else if provider.eq_ignore_ascii_case("ollama") {
115 log::info!("Ollama provider detected - no API key required for local setup");
116 } else {
117 log::warn!("Unknown provider type for API key: {provider}");
118 }
119 }
120
121 let is_vertex = provider_type
122 .map(|p| p.eq_ignore_ascii_case("vertex"))
123 .unwrap_or(false);
124
125 let client = if is_vertex {
126 build_vertex_client()
129 } else {
130 match adapter_kind {
131 Some(adapter_kind) => build_bound_client(adapter_kind, provider_type),
132 None => Client::default(),
133 }
134 };
135
136 Ok(Self::from_client(client))
137 }
138
139 pub fn from_config(
153 config: &Config,
154 cli_model: Option<&str>,
155 ) -> Result<(Self, ResolvedProvider)> {
156 let (env_provider, env_model) = detect_provider_from_env();
157
158 let env_model_override = std::env::var("OPENAI_MODEL")
159 .or_else(|_| std::env::var("MODEL"))
160 .ok();
161
162 let model = cli_model
163 .map(str::to_string)
164 .or_else(|| config.model.clone())
165 .or(env_model_override)
166 .unwrap_or_else(|| env_model.to_string());
167
168 let provider = config
169 .provider
170 .clone()
171 .or_else(|| provider_from_model_namespace(&model).map(str::to_string))
172 .unwrap_or_else(|| env_provider.to_string());
173
174 if let Some(base_url) = config.base_url.as_deref() {
177 if let Some(env_var) = provider_base_url_env_var(&provider) {
178 if std::env::var(env_var).is_err() {
179 std::env::set_var(env_var, base_url);
180 }
181 }
182 }
183
184 if provider.eq_ignore_ascii_case("vertex") {
185 configure_vertex_environment(config);
186 }
187
188 let provider_obj = Self::new_with_config(Some(&provider), config.api_key.as_deref())?;
189 Ok((provider_obj, ResolvedProvider { provider, model }))
190 }
191
192 fn from_client(client: Client) -> Self {
193 Self {
194 client: Arc::new(client),
195 shared: Arc::new(RwLock::new(SharedState::default())),
196 }
197 }
198
199 pub async fn get_total_tokens_used(&self) -> usize {
201 self.shared.read().await.total_tokens_used
202 }
203
204 pub async fn get_request_count(&self) -> usize {
206 self.shared.read().await.request_count
207 }
208
209 async fn increment_request(&self) {
211 let mut state = self.shared.write().await;
212 state.request_count += 1;
213 }
214
215 pub async fn add_tokens(&self, count: usize) {
217 let mut state = self.shared.write().await;
218 state.total_tokens_used += count;
219 }
220
221 pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
223 let adapter_kind = str_to_adapter_kind(provider)?;
224 let provider_config = provider_base_url_from_env(provider)
225 .map(|base_url| {
226 ProviderConfig::from_endpoint(Endpoint::from_owned(normalize_base_url(&base_url)))
227 })
228 .unwrap_or_default();
229
230 let models = self
231 .client
232 .all_model_names(adapter_kind, provider_config)
233 .await
234 .context(format!("Failed to get models for provider: {provider}"))?;
235
236 Ok(models)
237 }
238
239 pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
242 self.generate_response_with_retry(model, prompt, 3).await
243 }
244
245 pub async fn generate_response_with_retry(
247 &self,
248 model: &str,
249 prompt: &str,
250 max_retries: usize,
251 ) -> Result<LlmResponse> {
252 self.increment_request().await;
253
254 let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
255
256 log::debug!(
257 "Sending chat request to model: {model} with prompt length: {} chars",
258 prompt.len()
259 );
260
261 let start_time = Instant::now();
262 let mut last_error: Option<anyhow::Error> = None;
263 let mut retry_count = 0;
264
265 while retry_count <= max_retries {
266 if retry_count > 0 {
267 let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
269 log::warn!(
270 "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
271 retry_count,
272 max_retries,
273 model,
274 delay_secs,
275 last_error.as_ref().map(|e| e.to_string())
276 );
277 println!(
278 " ⏳ Rate limited, retrying in {}s (attempt {}/{})",
279 delay_secs, retry_count, max_retries
280 );
281 tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
282 }
283
284 match self.client.exec_chat(model, chat_req.clone(), None).await {
285 Ok(chat_res) => {
286 let tokens_in = chat_res.usage.prompt_tokens;
287 let tokens_out = chat_res.usage.completion_tokens;
288 let content = chat_res
289 .first_text()
290 .context("No text content in response")?;
291 log::debug!(
292 "Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
293 content.len(),
294 start_time.elapsed().as_millis(),
295 tokens_in,
296 tokens_out,
297 );
298
299 let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
301 if total > 0 {
302 self.add_tokens(total as usize).await;
303 }
304
305 return Ok(LlmResponse {
306 text: content.to_string(),
307 tokens_in,
308 tokens_out,
309 });
310 }
311 Err(e) => {
312 let err_str = e.to_string();
313
314 let is_retryable = err_str.contains("429")
316 || err_str.contains("rate limit")
317 || err_str.contains("Rate limit")
318 || err_str.contains("RESOURCE_EXHAUSTED")
319 || err_str.contains("500")
320 || err_str.contains("502")
321 || err_str.contains("503")
322 || err_str.contains("504")
323 || err_str.contains("timeout")
324 || err_str.contains("connection");
325
326 if is_retryable && retry_count < max_retries {
327 log::warn!("Retryable error for model {}: {}", model, err_str);
328 last_error = Some(anyhow::anyhow!("{}", err_str));
329 retry_count += 1;
330 continue;
331 } else {
332 return Err(anyhow::anyhow!(
333 "Failed to execute chat request for model {}: {}",
334 model,
335 err_str
336 ));
337 }
338 }
339 }
340 }
341
342 Err(last_error
344 .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
345 }
346
347 pub async fn generate_response_stream_to_channel(
349 &self,
350 model: &str,
351 prompt: &str,
352 tx: mpsc::UnboundedSender<String>,
353 ) -> Result<()> {
354 self.increment_request().await;
355
356 let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
357
358 log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");
359
360 let chat_res_stream = self
361 .client
362 .exec_chat_stream(model, chat_req, None)
363 .await
364 .context(format!(
365 "Failed to execute streaming chat request for model: {model}"
366 ))?;
367
368 let mut stream = chat_res_stream.stream;
369 let mut chunk_count = 0;
370 let mut total_content_length = 0;
371 let mut stream_ended_explicitly = false;
372 let start_time = Instant::now();
373
374 log::info!(
375 "=== STREAM START === Model: {}, Prompt length: {} chars",
376 model,
377 prompt.len()
378 );
379
380 while let Some(chunk_result) = stream.next().await {
381 let elapsed = start_time.elapsed();
382
383 match chunk_result {
384 Ok(ChatStreamEvent::Start) => {
385 log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
386 }
387 Ok(ChatStreamEvent::Chunk(chunk)) => {
388 chunk_count += 1;
389 total_content_length += chunk.content.len();
390
391 if chunk_count % 10 == 0 || chunk.content.len() > 100 {
392 log::info!(
393 "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
394 chunk_count,
395 chunk.content.len(),
396 total_content_length,
397 elapsed
398 );
399 }
400
401 if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
402 log::error!(
403 "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
404 );
405 break;
406 }
407 }
408 Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
409 log::info!(
410 "REASONING CHUNK: {} chars at {:?}",
411 chunk.content.len(),
412 elapsed
413 );
414 if !chunk.content.is_empty() {
415 let _ = tx.send(format!("__PERSPT_REASONING__:{}", chunk.content));
416 }
417 }
418 Ok(ChatStreamEvent::End(_)) => {
419 log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
420 stream_ended_explicitly = true;
421 break;
422 }
423 Ok(ChatStreamEvent::ToolCallChunk(_)) => {
424 log::debug!("Tool call chunk received (ignored)");
425 }
426 Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
427 log::debug!("Thought signature chunk received (ignored)");
428 }
429 Err(e) => {
430 log::error!(
431 "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
432 );
433 let error_msg = format!("Stream error: {e}");
434 let _ = tx.send(error_msg);
435 return Err(e.into());
436 }
437 }
438 }
439
440 let final_elapsed = start_time.elapsed();
441 if !stream_ended_explicitly {
442 log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
443 }
444
445 log::info!(
446 "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
447 );
448
449 self.add_tokens(total_content_length / 4).await; if tx.send(EOT_SIGNAL.to_string()).is_err() {
453 log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
454 return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
455 }
456
457 log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
458 Ok(())
459 }
460
461 pub fn get_supported_providers() -> Vec<&'static str> {
463 vec![
464 "openai",
465 "anthropic",
466 "gemini",
467 "groq",
468 "cohere",
469 "ollama",
470 "vertex",
471 "xai",
472 "deepseek",
473 ]
474 }
475
476 pub async fn get_available_providers(&self) -> Result<Vec<String>> {
478 Ok(Self::get_supported_providers()
479 .iter()
480 .map(|s| s.to_string())
481 .collect())
482 }
483
484 pub async fn test_model(&self, model: &str) -> Result<bool> {
486 match self.generate_response_simple(model, "Hello").await {
487 Ok(_) => {
488 log::info!("Model {model} is available and working");
489 Ok(true)
490 }
491 Err(e) => {
492 log::warn!("Model {model} test failed: {e}");
493 Ok(false)
494 }
495 }
496 }
497
498 pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
500 if self.test_model(model).await? {
501 return Ok(model.to_string());
502 }
503
504 if let Some(provider) = provider_type {
505 if let Ok(models) = self.get_available_models(provider).await {
506 if !models.is_empty() {
507 log::info!("Model {} not available, using {} instead", model, models[0]);
508 return Ok(models[0].clone());
509 }
510 }
511 }
512
513 log::warn!("Could not validate model {model}, proceeding anyway");
514 Ok(model.to_string())
515 }
516}
517
518fn build_vertex_client() -> Client {
528 let resolver = AuthResolver::from_resolver_async_fn(
529 |_model: ModelIden| -> Pin<
530 Box<dyn Future<Output = genai::resolver::Result<Option<AuthData>>> + Send>,
531 > {
532 Box::pin(async move {
533 if let Ok(token) = std::env::var("VERTEX_API_KEY") {
535 if !token.trim().is_empty() {
536 return Ok(Some(AuthData::from_single(token)));
537 }
538 }
539 let provider = gcp_auth::provider().await.map_err(|e| {
541 genai::resolver::Error::Custom(format!(
542 "Vertex ADC provider init failed (run `gcloud auth application-default login`): {e}"
543 ))
544 })?;
545 let scopes = ["https://www.googleapis.com/auth/cloud-platform"];
546 let token = provider.token(&scopes).await.map_err(|e| {
547 genai::resolver::Error::Custom(format!("Vertex ADC token fetch failed: {e}"))
548 })?;
549 Ok(Some(AuthData::from_single(token.as_str())))
550 })
551 },
552 );
553
554 let mut builder = Client::builder()
555 .with_adapter_kind(AdapterKind::Vertex)
556 .with_auth_resolver(resolver);
557
558 if let Some(endpoint) = resolved_vertex_endpoint() {
559 builder = builder.with_service_target_resolver_fn(move |mut target: ServiceTarget| {
560 target.endpoint = Endpoint::from_owned(endpoint.clone());
561 Ok(target)
562 });
563 }
564
565 builder.build()
566}
567
568fn build_bound_client(adapter_kind: AdapterKind, provider_type: Option<&str>) -> Client {
569 let mut builder = Client::builder().with_adapter_kind(adapter_kind);
570
571 if let Some(base_url) = provider_type.and_then(provider_base_url_from_env) {
572 let endpoint = normalize_base_url(&base_url);
573 let target_resolver = ServiceTargetResolver::from_resolver_fn(
574 move |mut service_target: ServiceTarget| -> genai::resolver::Result<ServiceTarget> {
575 if service_target.model.adapter_kind == adapter_kind {
576 service_target.endpoint = Endpoint::from_owned(endpoint.clone());
577 }
578 Ok(service_target)
579 },
580 );
581 builder = builder.with_service_target_resolver(target_resolver);
582 }
583
584 builder.build()
585}
586
/// Maps a `namespace::model` prefix to a canonical provider name.
///
/// The namespace is matched ASCII case-insensitively, and `google` is an alias
/// for `gemini`. Returns `None` when `model` has no `::` or the namespace is
/// unknown.
fn provider_from_model_namespace(model: &str) -> Option<&'static str> {
    // (namespace, canonical provider)
    const NAMESPACES: [(&str, &str); 10] = [
        ("openai", "openai"),
        ("anthropic", "anthropic"),
        ("gemini", "gemini"),
        ("google", "gemini"),
        ("vertex", "vertex"),
        ("groq", "groq"),
        ("cohere", "cohere"),
        ("ollama", "ollama"),
        ("xai", "xai"),
        ("deepseek", "deepseek"),
    ];

    let (namespace, _) = model.split_once("::")?;
    NAMESPACES
        .iter()
        .find(|entry| namespace.eq_ignore_ascii_case(entry.0))
        .map(|entry| entry.1)
}
602
603fn configure_vertex_environment(config: &Config) {
604 if std::env::var("VERTEX_PROJECT_ID").is_err() {
605 if let Some(project) = config
606 .vertex_project_id
607 .as_deref()
608 .map(str::trim)
609 .filter(|v| !v.is_empty())
610 .map(str::to_string)
611 .or_else(vertex_project_from_env)
612 .or_else(read_gcloud_project)
613 {
614 match valid_vertex_segment(&project) {
618 Some(valid) => std::env::set_var("VERTEX_PROJECT_ID", valid),
619 None => log::warn!(
620 "Ignoring discovered Vertex project ID (must contain only ASCII letters, \
621 digits, and hyphens)"
622 ),
623 }
624 }
625 }
626
627 if std::env::var("VERTEX_LOCATION").is_err() {
628 if let Some(location) = config
629 .vertex_location
630 .as_deref()
631 .map(str::trim)
632 .filter(|v| !v.is_empty())
633 {
634 match valid_vertex_segment(location) {
637 Some(valid) => std::env::set_var("VERTEX_LOCATION", valid),
638 None => log::warn!(
639 "Ignoring invalid vertex_location from config (must contain only ASCII \
640 letters, digits, and hyphens)"
641 ),
642 }
643 }
644 }
645}
646
/// Returns the first non-blank project ID found in the environment, trimmed.
///
/// Variables are checked in this order: `VERTEX_PROJECT_ID`,
/// `GOOGLE_CLOUD_PROJECT`, `GCLOUD_PROJECT`, `CLOUDSDK_CORE_PROJECT`.
fn vertex_project_from_env() -> Option<String> {
    const CANDIDATES: [&str; 4] = [
        "VERTEX_PROJECT_ID",
        "GOOGLE_CLOUD_PROJECT",
        "GCLOUD_PROJECT",
        "CLOUDSDK_CORE_PROJECT",
    ];

    for key in CANDIDATES.iter() {
        if let Ok(raw) = std::env::var(key) {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    None
}
659
660fn gcloud_config_dir() -> Option<PathBuf> {
661 if let Ok(dir) = std::env::var("CLOUDSDK_CONFIG") {
662 let trimmed = dir.trim();
663 if !trimmed.is_empty() {
664 return Some(PathBuf::from(trimmed));
665 }
666 }
667 dirs::home_dir().map(|home| home.join(".config").join("gcloud"))
668}
669
670fn read_gcloud_project() -> Option<String> {
671 let config_dir = gcloud_config_dir()?;
672 let active_config = std::fs::read_to_string(config_dir.join("active_config"))
673 .ok()
674 .map(|s| s.trim().to_string())
675 .filter(|s| !s.is_empty())
676 .unwrap_or_else(|| "default".to_string());
677 let config_path = config_dir
678 .join("configurations")
679 .join(format!("config_{active_config}"));
680 let content = std::fs::read_to_string(config_path).ok()?;
681 parse_gcloud_project(&content)
682}
683
/// Extracts the first non-empty `project = …` value from the `[core]` section of
/// an INI-style gcloud configuration file.
///
/// Section headers are matched case-insensitively. Blank lines and lines starting
/// with `#` are ignored.
fn parse_gcloud_project(content: &str) -> Option<String> {
    let mut in_core_section = false;

    for line in content.lines().map(str::trim) {
        if line.is_empty() || line.starts_with('#') {
            continue;
        }

        if line.starts_with('[') && line.ends_with(']') {
            in_core_section = line.eq_ignore_ascii_case("[core]");
        } else if in_core_section {
            if let Some((key, value)) = line.split_once('=') {
                let value = value.trim();
                if key.trim() == "project" && !value.is_empty() {
                    return Some(value.to_string());
                }
            }
        }
    }

    None
}
710
/// Validates a Vertex project ID or location before it is interpolated into a URL.
///
/// Returns the trimmed value when it is non-empty and made only of ASCII letters,
/// digits and hyphens. This rejects characters such as `.`, `/`, `:`, `@` and
/// spaces, which could redirect the request host or path.
fn valid_vertex_segment(value: &str) -> Option<&str> {
    let candidate = value.trim();
    let is_allowed = |c: char| c == '-' || c.is_ascii_alphanumeric();

    if !candidate.is_empty() && candidate.chars().all(is_allowed) {
        Some(candidate)
    } else {
        None
    }
}
730
731fn resolved_vertex_endpoint() -> Option<String> {
732 let project_raw = std::env::var("VERTEX_PROJECT_ID").ok()?;
733 let project = match valid_vertex_segment(&project_raw) {
734 Some(p) => p.to_string(),
735 None => {
736 log::warn!(
737 "Ignoring VERTEX_PROJECT_ID for endpoint construction: must be non-empty and \
738 contain only ASCII letters, digits, and hyphens"
739 );
740 return None;
741 }
742 };
743 let location = match std::env::var("VERTEX_LOCATION") {
744 Ok(raw) if !raw.trim().is_empty() => match valid_vertex_segment(&raw) {
745 Some(l) => l.to_string(),
746 None => {
747 log::warn!(
748 "Ignoring invalid VERTEX_LOCATION (must contain only ASCII letters, digits, \
749 and hyphens); falling back to 'global'"
750 );
751 "global".to_string()
752 }
753 },
754 _ => "global".to_string(),
755 };
756 Some(vertex_endpoint_base(&project, &location))
757}
758
/// Formats the Vertex AI v1 base URL for `project` and `location`.
///
/// `global` (any ASCII case) uses the un-prefixed host and a lowercase
/// `locations/global/` path. Other locations use a `<location>-` host prefix and
/// the location verbatim in the path. Both inputs are trimmed first.
fn vertex_endpoint_base(project: &str, location: &str) -> String {
    let (project, location) = (project.trim(), location.trim());

    let (host_prefix, location_segment) = if location.eq_ignore_ascii_case("global") {
        (String::new(), "global")
    } else {
        (format!("{location}-"), location)
    };

    format!(
        "https://{host_prefix}aiplatform.googleapis.com/v1/projects/{project}/locations/{location_segment}/"
    )
}
770
/// Name of the environment variable holding a base-URL override for `provider`.
///
/// The provider name is matched case-insensitively. `google` shares Gemini's
/// variable. Vertex and unknown providers have none.
fn provider_base_url_env_var(provider: &str) -> Option<&'static str> {
    const BASE_URL_VARS: [(&str, &str); 9] = [
        ("openai", "OPENAI_BASE_URL"),
        ("anthropic", "ANTHROPIC_BASE_URL"),
        ("gemini", "GEMINI_BASE_URL"),
        ("google", "GEMINI_BASE_URL"),
        ("groq", "GROQ_BASE_URL"),
        ("cohere", "COHERE_BASE_URL"),
        ("ollama", "OLLAMA_BASE_URL"),
        ("xai", "XAI_BASE_URL"),
        ("deepseek", "DEEPSEEK_BASE_URL"),
    ];

    let wanted = provider.to_lowercase();
    BASE_URL_VARS
        .iter()
        .find(|entry| entry.0 == wanted)
        .map(|entry| entry.1)
}
784
785fn provider_base_url_from_env(provider: &str) -> Option<String> {
786 let env_var = provider_base_url_env_var(provider)?;
787
788 std::env::var(env_var)
789 .ok()
790 .map(|value| value.trim().to_string())
791 .filter(|value| !value.is_empty())
792}
793
/// Name of the environment variable that carries the API key for `provider`.
///
/// The provider name is matched case-insensitively. `google` shares Gemini's
/// variable. Ollama and unknown providers have none.
fn provider_api_key_env_var(provider: &str) -> Option<&'static str> {
    const API_KEY_VARS: [(&str, &str); 9] = [
        ("openai", "OPENAI_API_KEY"),
        ("anthropic", "ANTHROPIC_API_KEY"),
        ("gemini", "GEMINI_API_KEY"),
        ("google", "GEMINI_API_KEY"),
        ("vertex", "VERTEX_API_KEY"),
        ("groq", "GROQ_API_KEY"),
        ("cohere", "COHERE_API_KEY"),
        ("xai", "XAI_API_KEY"),
        ("deepseek", "DEEPSEEK_API_KEY"),
    ];

    let wanted = provider.to_lowercase();
    API_KEY_VARS
        .iter()
        .find(|entry| entry.0 == wanted)
        .map(|entry| entry.1)
}
807
/// Returns `base_url` with exactly one guaranteed trailing `/`.
///
/// Existing trailing slashes are kept as-is. The result is used as a base that
/// relative paths are joined onto.
fn normalize_base_url(base_url: &str) -> String {
    let mut normalized = String::from(base_url);
    if !normalized.ends_with('/') {
        normalized.push('/');
    }
    normalized
}
815
816fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
818 match provider.to_lowercase().as_str() {
819 "openai" => Ok(AdapterKind::OpenAI),
820 "anthropic" => Ok(AdapterKind::Anthropic),
821 "gemini" | "google" => Ok(AdapterKind::Gemini),
822 "vertex" => Ok(AdapterKind::Vertex),
823 "groq" => Ok(AdapterKind::Groq),
824 "cohere" => Ok(AdapterKind::Cohere),
825 "ollama" => Ok(AdapterKind::Ollama),
826 "xai" => Ok(AdapterKind::Xai),
827 "deepseek" => Ok(AdapterKind::DeepSeek),
828 _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
829 }
830}
831
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate process-wide environment variables.
    ///
    /// The test harness runs tests on parallel threads that share one process
    /// environment. Without this lock the Vertex tests below can observe each
    /// other's `VERTEX_PROJECT_ID` / `VERTEX_LOCATION` values and fail
    /// intermittently.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] and restores the listed variables when dropped.
    ///
    /// Restoration therefore also happens when an assertion panics mid-test.
    /// `Drop::drop` runs before the fields are dropped, so the environment is
    /// restored while the lock is still held.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<String>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new(keys: &[&'static str]) -> Self {
            // A failed test poisons the lock. The guarded data is `()`, so it is
            // safe to keep going.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = keys
                .iter()
                .map(|&key| (key, std::env::var(key).ok()))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                match value {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("vertex").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_configured_provider_binds_adapter_for_custom_model_names() {
        let provider = GenAIProvider::new_with_config(Some("openai"), None).unwrap();
        let target = provider
            .client
            .resolve_service_target("gemma4-32b-it")
            .await
            .unwrap();

        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[tokio::test]
    async fn test_namespaced_model_resolves_on_unbound_client() {
        let provider = GenAIProvider::new().unwrap();
        let target = provider
            .client
            .resolve_service_target("openai::phi-4-npu-ov")
            .await
            .unwrap();

        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[tokio::test]
    async fn test_from_config_binds_adapter_for_custom_model() {
        let config = Config {
            provider: Some("openai".to_string()),
            model: Some("phi-4-npu-ov".to_string()),
            ..Default::default()
        };
        let (provider, resolved) = GenAIProvider::from_config(&config, None).unwrap();
        assert_eq!(resolved.provider, "openai");
        assert_eq!(resolved.model, "phi-4-npu-ov");

        let target = provider
            .client
            .resolve_service_target(&resolved.model)
            .await
            .unwrap();
        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[test]
    fn test_from_config_model_precedence() {
        let config = Config {
            provider: Some("openai".to_string()),
            model: Some("config-model".to_string()),
            ..Default::default()
        };
        let (_p, resolved) = GenAIProvider::from_config(&config, Some("cli-model")).unwrap();
        assert_eq!(resolved.model, "cli-model");
    }

    #[test]
    fn test_provider_from_model_namespace_detects_vertex() {
        assert_eq!(
            provider_from_model_namespace("vertex::gemini-2.5-flash"),
            Some("vertex")
        );
        assert_eq!(provider_from_model_namespace("gemini-2.5-flash"), None);
    }

    #[tokio::test]
    async fn test_from_config_uses_namespaced_vertex_model_when_provider_absent() {
        let _env = EnvGuard::new(&["VERTEX_PROJECT_ID", "VERTEX_LOCATION"]);
        std::env::set_var("VERTEX_PROJECT_ID", "unit-test-project");
        std::env::remove_var("VERTEX_LOCATION");

        let config = Config::default();
        let (_provider, resolved) =
            GenAIProvider::from_config(&config, Some("vertex::gemini-2.5-flash")).unwrap();
        assert_eq!(resolved.provider, "vertex");
        assert_eq!(resolved.model, "vertex::gemini-2.5-flash");
        assert!(std::env::var("VERTEX_LOCATION").is_err());
        assert_eq!(
            resolved_vertex_endpoint().as_deref(),
            Some(
                "https://aiplatform.googleapis.com/v1/projects/unit-test-project/locations/global/"
            )
        );
    }

    #[test]
    fn test_valid_vertex_segment_accepts_real_values() {
        assert_eq!(valid_vertex_segment("perspt"), Some("perspt"));
        assert_eq!(valid_vertex_segment("us-central1"), Some("us-central1"));
        assert_eq!(valid_vertex_segment("global"), Some("global"));
        assert_eq!(valid_vertex_segment("europe-west4"), Some("europe-west4"));
        assert_eq!(valid_vertex_segment(" perspt "), Some("perspt"));
    }

    #[test]
    fn test_valid_vertex_segment_rejects_host_redirection() {
        assert_eq!(valid_vertex_segment("evil.com/"), None);
        assert_eq!(valid_vertex_segment("evil.com"), None);
        assert_eq!(valid_vertex_segment("a/b"), None);
        assert_eq!(valid_vertex_segment("a:b"), None);
        assert_eq!(valid_vertex_segment("a@b"), None);
        assert_eq!(valid_vertex_segment("a b"), None);
        assert_eq!(valid_vertex_segment(""), None);
        assert_eq!(valid_vertex_segment(" "), None);
    }

    #[test]
    fn test_resolved_vertex_endpoint_rejects_malicious_location() {
        let _env = EnvGuard::new(&["VERTEX_PROJECT_ID", "VERTEX_LOCATION"]);

        std::env::set_var("VERTEX_PROJECT_ID", "perspt");
        std::env::set_var("VERTEX_LOCATION", "evil.com/");
        assert_eq!(
            resolved_vertex_endpoint().as_deref(),
            Some("https://aiplatform.googleapis.com/v1/projects/perspt/locations/global/"),
            "malicious location must fall back to global, never redirect the host"
        );

        std::env::set_var("VERTEX_PROJECT_ID", "bad/project");
        std::env::set_var("VERTEX_LOCATION", "us-central1");
        assert_eq!(resolved_vertex_endpoint(), None);
    }

    #[test]
    fn test_vertex_endpoint_base_matches_genai_vertex_shape() {
        assert_eq!(
            vertex_endpoint_base("test-project", "global"),
            "https://aiplatform.googleapis.com/v1/projects/test-project/locations/global/"
        );
        assert_eq!(
            vertex_endpoint_base("test-project", "test-location"),
            "https://test-location-aiplatform.googleapis.com/v1/projects/test-project/locations/test-location/"
        );
    }

    #[test]
    fn test_parse_gcloud_project_reads_core_project() {
        let content = r#"
[compute]
region = ignored-location

[core]
account = user@example.com
project = test-project
"#;
        assert_eq!(
            parse_gcloud_project(content).as_deref(),
            Some("test-project")
        );
    }

    #[tokio::test]
    async fn test_openai_base_url_overrides_bound_provider_endpoint() {
        // The base URL is captured when the client is built. The env guard can
        // therefore be released before awaiting, so no std lock is held across
        // an `.await`.
        let provider = {
            let _env = EnvGuard::new(&["OPENAI_BASE_URL"]);
            std::env::set_var("OPENAI_BASE_URL", "https://custom.example/v1");
            GenAIProvider::new_with_config(Some("openai"), None).unwrap()
        };

        let target = provider
            .client
            .resolve_service_target("gemma4-32b-it")
            .await
            .unwrap();

        assert_eq!(target.endpoint.base_url(), "https://custom.example/v1/");
    }

    #[test]
    fn test_normalize_base_url() {
        assert_eq!(
            normalize_base_url("https://custom.example/v1"),
            "https://custom.example/v1/"
        );
        assert_eq!(
            normalize_base_url("https://custom.example/v1/"),
            "https://custom.example/v1/"
        );
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
    }
}