1use anyhow::{Context, Result};
7use futures::StreamExt;
8use genai::adapter::AdapterKind;
9use genai::chat::{ChatMessage, ChatRequest, ChatStreamEvent};
10use genai::resolver::{Endpoint, ProviderConfig, ServiceTargetResolver};
11use genai::Client;
12use genai::ServiceTarget;
13use std::sync::Arc;
14use std::time::Instant;
15use tokio::sync::{mpsc, RwLock};
16
17use crate::config::Config;
18
/// End-of-transmission sentinel.
///
/// `GenAIProvider::generate_response_stream_to_channel` pushes this onto the output channel
/// after the last chunk of a stream that completed without a stream error.
pub const EOT_SIGNAL: &str = "<|EOT|>";
21
/// The provider/model pair that `GenAIProvider::from_config` settled on after applying
/// CLI, config-file and environment precedence.
#[derive(Clone, Debug)]
pub struct ResolvedProvider {
    /// Provider identifier, e.g. `"openai"` or `"ollama"`.
    pub provider: String,
    /// Model name requests should be sent to.
    pub model: String,
}
30
/// Picks a default `(provider, model)` pair based on which API key is present in the environment.
///
/// Keys are checked in a fixed priority order: Gemini, OpenAI, Anthropic, Groq, Cohere, xAI,
/// DeepSeek. The first variable holding a non-blank value wins. When none qualifies, the
/// function falls back to a local Ollama model.
///
/// A variable that is exported but empty or whitespace-only is treated as unset. It carries no
/// credential, and selecting that provider would only fail later at authentication time. This
/// matches how blank `*_BASE_URL` values are ignored elsewhere in this module.
pub fn detect_provider_from_env() -> (&'static str, &'static str) {
    // (API-key env var, provider, default model), in priority order.
    const CANDIDATES: [(&str, &str, &str); 7] = [
        ("GEMINI_API_KEY", "gemini", "gemini-3.1-flash-lite-preview"),
        ("OPENAI_API_KEY", "openai", "gpt-4o-mini"),
        ("ANTHROPIC_API_KEY", "anthropic", "claude-3-5-sonnet-20241022"),
        ("GROQ_API_KEY", "groq", "llama-3.1-8b-instant"),
        ("COHERE_API_KEY", "cohere", "command-r-plus"),
        ("XAI_API_KEY", "xai", "grok-beta"),
        ("DEEPSEEK_API_KEY", "deepseek", "deepseek-chat"),
    ];

    CANDIDATES
        .iter()
        .find(|&&(var, _, _)| {
            matches!(std::env::var(var), Ok(value) if !value.trim().is_empty())
        })
        .map(|&(_, provider, model)| (provider, model))
        .unwrap_or(("ollama", "llama3.2"))
}
55
/// Result of one non-streaming completion request.
#[derive(Clone, Debug)]
pub struct LlmResponse {
    /// First text segment of the model's reply.
    pub text: String,
    /// Prompt-token count from the response's usage data; `None` when it was not reported.
    pub tokens_in: Option<i32>,
    /// Completion-token count from the response's usage data; `None` when it was not reported.
    pub tokens_out: Option<i32>,
}
63
/// Usage counters shared (behind `Arc<RwLock<_>>`) by every clone of a `GenAIProvider`.
/// Both counters start at zero.
#[derive(Default)]
struct SharedState {
    /// Running sum of token counts passed to `GenAIProvider::add_tokens`.
    total_tokens_used: usize,
    /// Number of requests started via `GenAIProvider::increment_request`.
    request_count: usize,
}
70
71#[derive(Clone)]
76pub struct GenAIProvider {
77 client: Arc<Client>,
79 shared: Arc<RwLock<SharedState>>,
81}
82
83impl GenAIProvider {
84 pub fn new() -> Result<Self> {
86 let client = Client::default();
87 Ok(Self::from_client(client))
88 }
89
90 pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
92 let adapter_kind = provider_type.and_then(|provider| match str_to_adapter_kind(provider) {
93 Ok(adapter_kind) => Some(adapter_kind),
94 Err(_) => {
95 log::warn!("Unknown provider type for genai client: {provider}");
96 None
97 }
98 });
99
100 if let (Some(provider), Some(key)) = (provider_type, api_key) {
102 if let Some(env_var) = provider_api_key_env_var(provider) {
103 log::info!("Setting {env_var} environment variable for genai client");
104 std::env::set_var(env_var, key);
105 } else if provider.eq_ignore_ascii_case("ollama") {
106 log::info!("Ollama provider detected - no API key required for local setup");
107 } else {
108 log::warn!("Unknown provider type for API key: {provider}");
109 }
110 }
111
112 let client = match adapter_kind {
113 Some(adapter_kind) => build_bound_client(adapter_kind, provider_type),
114 None => Client::default(),
115 };
116
117 Ok(Self::from_client(client))
118 }
119
120 pub fn from_config(
134 config: &Config,
135 cli_model: Option<&str>,
136 ) -> Result<(Self, ResolvedProvider)> {
137 let (env_provider, env_model) = detect_provider_from_env();
138
139 let provider = config
140 .provider
141 .clone()
142 .unwrap_or_else(|| env_provider.to_string());
143
144 let env_model_override = std::env::var("OPENAI_MODEL")
145 .or_else(|_| std::env::var("MODEL"))
146 .ok();
147
148 let model = cli_model
149 .map(str::to_string)
150 .or_else(|| config.model.clone())
151 .or(env_model_override)
152 .unwrap_or_else(|| env_model.to_string());
153
154 if let Some(base_url) = config.base_url.as_deref() {
157 if let Some(env_var) = provider_base_url_env_var(&provider) {
158 if std::env::var(env_var).is_err() {
159 std::env::set_var(env_var, base_url);
160 }
161 }
162 }
163
164 let provider_obj = Self::new_with_config(Some(&provider), config.api_key.as_deref())?;
165 Ok((provider_obj, ResolvedProvider { provider, model }))
166 }
167
168 fn from_client(client: Client) -> Self {
169 Self {
170 client: Arc::new(client),
171 shared: Arc::new(RwLock::new(SharedState::default())),
172 }
173 }
174
175 pub async fn get_total_tokens_used(&self) -> usize {
177 self.shared.read().await.total_tokens_used
178 }
179
180 pub async fn get_request_count(&self) -> usize {
182 self.shared.read().await.request_count
183 }
184
185 async fn increment_request(&self) {
187 let mut state = self.shared.write().await;
188 state.request_count += 1;
189 }
190
191 pub async fn add_tokens(&self, count: usize) {
193 let mut state = self.shared.write().await;
194 state.total_tokens_used += count;
195 }
196
197 pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
199 let adapter_kind = str_to_adapter_kind(provider)?;
200 let provider_config = provider_base_url_from_env(provider)
201 .map(|base_url| {
202 ProviderConfig::from_endpoint(Endpoint::from_owned(normalize_base_url(&base_url)))
203 })
204 .unwrap_or_default();
205
206 let models = self
207 .client
208 .all_model_names(adapter_kind, provider_config)
209 .await
210 .context(format!("Failed to get models for provider: {provider}"))?;
211
212 Ok(models)
213 }
214
215 pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
218 self.generate_response_with_retry(model, prompt, 3).await
219 }
220
221 pub async fn generate_response_with_retry(
223 &self,
224 model: &str,
225 prompt: &str,
226 max_retries: usize,
227 ) -> Result<LlmResponse> {
228 self.increment_request().await;
229
230 let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
231
232 log::debug!(
233 "Sending chat request to model: {model} with prompt length: {} chars",
234 prompt.len()
235 );
236
237 let start_time = Instant::now();
238 let mut last_error: Option<anyhow::Error> = None;
239 let mut retry_count = 0;
240
241 while retry_count <= max_retries {
242 if retry_count > 0 {
243 let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
245 log::warn!(
246 "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
247 retry_count,
248 max_retries,
249 model,
250 delay_secs,
251 last_error.as_ref().map(|e| e.to_string())
252 );
253 println!(
254 " ⏳ Rate limited, retrying in {}s (attempt {}/{})",
255 delay_secs, retry_count, max_retries
256 );
257 tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
258 }
259
260 match self.client.exec_chat(model, chat_req.clone(), None).await {
261 Ok(chat_res) => {
262 let tokens_in = chat_res.usage.prompt_tokens;
263 let tokens_out = chat_res.usage.completion_tokens;
264 let content = chat_res
265 .first_text()
266 .context("No text content in response")?;
267 log::debug!(
268 "Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
269 content.len(),
270 start_time.elapsed().as_millis(),
271 tokens_in,
272 tokens_out,
273 );
274
275 let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
277 if total > 0 {
278 self.add_tokens(total as usize).await;
279 }
280
281 return Ok(LlmResponse {
282 text: content.to_string(),
283 tokens_in,
284 tokens_out,
285 });
286 }
287 Err(e) => {
288 let err_str = e.to_string();
289
290 let is_retryable = err_str.contains("429")
292 || err_str.contains("rate limit")
293 || err_str.contains("Rate limit")
294 || err_str.contains("RESOURCE_EXHAUSTED")
295 || err_str.contains("500")
296 || err_str.contains("502")
297 || err_str.contains("503")
298 || err_str.contains("504")
299 || err_str.contains("timeout")
300 || err_str.contains("connection");
301
302 if is_retryable && retry_count < max_retries {
303 log::warn!("Retryable error for model {}: {}", model, err_str);
304 last_error = Some(anyhow::anyhow!("{}", err_str));
305 retry_count += 1;
306 continue;
307 } else {
308 return Err(anyhow::anyhow!(
309 "Failed to execute chat request for model {}: {}",
310 model,
311 err_str
312 ));
313 }
314 }
315 }
316 }
317
318 Err(last_error
320 .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
321 }
322
323 pub async fn generate_response_stream_to_channel(
325 &self,
326 model: &str,
327 prompt: &str,
328 tx: mpsc::UnboundedSender<String>,
329 ) -> Result<()> {
330 self.increment_request().await;
331
332 let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
333
334 log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");
335
336 let chat_res_stream = self
337 .client
338 .exec_chat_stream(model, chat_req, None)
339 .await
340 .context(format!(
341 "Failed to execute streaming chat request for model: {model}"
342 ))?;
343
344 let mut stream = chat_res_stream.stream;
345 let mut chunk_count = 0;
346 let mut total_content_length = 0;
347 let mut stream_ended_explicitly = false;
348 let start_time = Instant::now();
349
350 log::info!(
351 "=== STREAM START === Model: {}, Prompt length: {} chars",
352 model,
353 prompt.len()
354 );
355
356 while let Some(chunk_result) = stream.next().await {
357 let elapsed = start_time.elapsed();
358
359 match chunk_result {
360 Ok(ChatStreamEvent::Start) => {
361 log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
362 }
363 Ok(ChatStreamEvent::Chunk(chunk)) => {
364 chunk_count += 1;
365 total_content_length += chunk.content.len();
366
367 if chunk_count % 10 == 0 || chunk.content.len() > 100 {
368 log::info!(
369 "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
370 chunk_count,
371 chunk.content.len(),
372 total_content_length,
373 elapsed
374 );
375 }
376
377 if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
378 log::error!(
379 "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
380 );
381 break;
382 }
383 }
384 Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
385 log::info!(
386 "REASONING CHUNK: {} chars at {:?}",
387 chunk.content.len(),
388 elapsed
389 );
390 if !chunk.content.is_empty() {
391 let _ = tx.send(format!("__PERSPT_REASONING__:{}", chunk.content));
392 }
393 }
394 Ok(ChatStreamEvent::End(_)) => {
395 log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
396 stream_ended_explicitly = true;
397 break;
398 }
399 Ok(ChatStreamEvent::ToolCallChunk(_)) => {
400 log::debug!("Tool call chunk received (ignored)");
401 }
402 Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
403 log::debug!("Thought signature chunk received (ignored)");
404 }
405 Err(e) => {
406 log::error!(
407 "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
408 );
409 let error_msg = format!("Stream error: {e}");
410 let _ = tx.send(error_msg);
411 return Err(e.into());
412 }
413 }
414 }
415
416 let final_elapsed = start_time.elapsed();
417 if !stream_ended_explicitly {
418 log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
419 }
420
421 log::info!(
422 "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
423 );
424
425 self.add_tokens(total_content_length / 4).await; if tx.send(EOT_SIGNAL.to_string()).is_err() {
429 log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
430 return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
431 }
432
433 log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
434 Ok(())
435 }
436
437 pub fn get_supported_providers() -> Vec<&'static str> {
439 vec![
440 "openai",
441 "anthropic",
442 "gemini",
443 "groq",
444 "cohere",
445 "ollama",
446 "xai",
447 "deepseek",
448 ]
449 }
450
451 pub async fn get_available_providers(&self) -> Result<Vec<String>> {
453 Ok(Self::get_supported_providers()
454 .iter()
455 .map(|s| s.to_string())
456 .collect())
457 }
458
459 pub async fn test_model(&self, model: &str) -> Result<bool> {
461 match self.generate_response_simple(model, "Hello").await {
462 Ok(_) => {
463 log::info!("Model {model} is available and working");
464 Ok(true)
465 }
466 Err(e) => {
467 log::warn!("Model {model} test failed: {e}");
468 Ok(false)
469 }
470 }
471 }
472
473 pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
475 if self.test_model(model).await? {
476 return Ok(model.to_string());
477 }
478
479 if let Some(provider) = provider_type {
480 if let Ok(models) = self.get_available_models(provider).await {
481 if !models.is_empty() {
482 log::info!("Model {} not available, using {} instead", model, models[0]);
483 return Ok(models[0].clone());
484 }
485 }
486 }
487
488 log::warn!("Could not validate model {model}, proceeding anyway");
489 Ok(model.to_string())
490 }
491}
492
493fn build_bound_client(adapter_kind: AdapterKind, provider_type: Option<&str>) -> Client {
494 let mut builder = Client::builder().with_adapter_kind(adapter_kind);
495
496 if let Some(base_url) = provider_type.and_then(provider_base_url_from_env) {
497 let endpoint = normalize_base_url(&base_url);
498 let target_resolver = ServiceTargetResolver::from_resolver_fn(
499 move |mut service_target: ServiceTarget| -> genai::resolver::Result<ServiceTarget> {
500 if service_target.model.adapter_kind == adapter_kind {
501 service_target.endpoint = Endpoint::from_owned(endpoint.clone());
502 }
503 Ok(service_target)
504 },
505 );
506 builder = builder.with_service_target_resolver(target_resolver);
507 }
508
509 builder.build()
510}
511
/// Name of the environment variable that overrides `provider`'s API base URL.
///
/// Matching is case-insensitive, and `"google"` is an alias for Gemini. Unknown providers
/// yield `None`.
fn provider_base_url_env_var(provider: &str) -> Option<&'static str> {
    const BASE_URL_VARS: &[(&str, &str)] = &[
        ("openai", "OPENAI_BASE_URL"),
        ("anthropic", "ANTHROPIC_BASE_URL"),
        ("gemini", "GEMINI_BASE_URL"),
        ("google", "GEMINI_BASE_URL"),
        ("groq", "GROQ_BASE_URL"),
        ("cohere", "COHERE_BASE_URL"),
        ("ollama", "OLLAMA_BASE_URL"),
        ("xai", "XAI_BASE_URL"),
        ("deepseek", "DEEPSEEK_BASE_URL"),
    ];

    let wanted = provider.to_lowercase();
    BASE_URL_VARS
        .iter()
        .find(|&&(name, _)| name == wanted)
        .map(|&(_, var)| var)
}
525
526fn provider_base_url_from_env(provider: &str) -> Option<String> {
527 let env_var = provider_base_url_env_var(provider)?;
528
529 std::env::var(env_var)
530 .ok()
531 .map(|value| value.trim().to_string())
532 .filter(|value| !value.is_empty())
533}
534
/// Name of the environment variable holding `provider`'s API key.
///
/// Matching is case-insensitive, and `"google"` is an alias for Gemini. Providers without an
/// entry (including `"ollama"`) yield `None`.
fn provider_api_key_env_var(provider: &str) -> Option<&'static str> {
    const API_KEY_VARS: &[(&str, &str)] = &[
        ("openai", "OPENAI_API_KEY"),
        ("anthropic", "ANTHROPIC_API_KEY"),
        ("gemini", "GEMINI_API_KEY"),
        ("google", "GEMINI_API_KEY"),
        ("groq", "GROQ_API_KEY"),
        ("cohere", "COHERE_API_KEY"),
        ("xai", "XAI_API_KEY"),
        ("deepseek", "DEEPSEEK_API_KEY"),
    ];

    let wanted = provider.to_lowercase();
    for &(name, var) in API_KEY_VARS {
        if name == wanted {
            return Some(var);
        }
    }
    None
}
547
/// Returns `base_url` with exactly one trailing `/` appended if it does not already end in
/// one. genai joins endpoint paths onto the base URL, so the slash keeps a path segment such
/// as `/v1` intact. An empty input becomes `"/"`.
fn normalize_base_url(base_url: &str) -> String {
    let mut normalized = String::with_capacity(base_url.len() + 1);
    normalized.push_str(base_url);
    if !normalized.ends_with('/') {
        normalized.push('/');
    }
    normalized
}
555
556fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
558 match provider.to_lowercase().as_str() {
559 "openai" => Ok(AdapterKind::OpenAI),
560 "anthropic" => Ok(AdapterKind::Anthropic),
561 "gemini" | "google" => Ok(AdapterKind::Gemini),
562 "groq" => Ok(AdapterKind::Groq),
563 "cohere" => Ok(AdapterKind::Cohere),
564 "ollama" => Ok(AdapterKind::Ollama),
565 "xai" => Ok(AdapterKind::Xai),
566 "deepseek" => Ok(AdapterKind::DeepSeek),
567 _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Overrides an environment variable for the guard's lifetime.
    ///
    /// The previous value is restored (or the variable removed) in `Drop`, so restoration also
    /// happens when an assertion panics. Restoring by hand at the end of a test is skipped on
    /// failure and leaks the override into later tests in the same process.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var(key).ok();
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_configured_provider_binds_adapter_for_custom_model_names() {
        let provider = GenAIProvider::new_with_config(Some("openai"), None).unwrap();
        let target = provider
            .client
            .resolve_service_target("gemma4-32b-it")
            .await
            .unwrap();

        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[tokio::test]
    async fn test_namespaced_model_resolves_on_unbound_client() {
        let provider = GenAIProvider::new().unwrap();
        let target = provider
            .client
            .resolve_service_target("openai::phi-4-npu-ov")
            .await
            .unwrap();

        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[tokio::test]
    async fn test_from_config_binds_adapter_for_custom_model() {
        let config = Config {
            provider: Some("openai".to_string()),
            model: Some("phi-4-npu-ov".to_string()),
            ..Default::default()
        };
        let (provider, resolved) = GenAIProvider::from_config(&config, None).unwrap();
        assert_eq!(resolved.provider, "openai");
        assert_eq!(resolved.model, "phi-4-npu-ov");

        let target = provider
            .client
            .resolve_service_target(&resolved.model)
            .await
            .unwrap();
        assert_eq!(target.model.adapter_kind, AdapterKind::OpenAI);
    }

    #[test]
    fn test_from_config_model_precedence() {
        let config = Config {
            provider: Some("openai".to_string()),
            model: Some("config-model".to_string()),
            ..Default::default()
        };
        let (_p, resolved) = GenAIProvider::from_config(&config, Some("cli-model")).unwrap();
        assert_eq!(resolved.model, "cli-model");
    }

    #[tokio::test]
    async fn test_openai_base_url_overrides_bound_provider_endpoint() {
        // Restored on drop, even if an assertion below fails.
        let _base_url = EnvVarGuard::set("OPENAI_BASE_URL", "https://custom.example/v1");

        let provider = GenAIProvider::new_with_config(Some("openai"), None).unwrap();
        let target = provider
            .client
            .resolve_service_target("gemma4-32b-it")
            .await
            .unwrap();

        assert_eq!(target.endpoint.base_url(), "https://custom.example/v1/");
    }

    #[test]
    fn test_provider_env_var_mappings() {
        assert_eq!(provider_base_url_env_var("OpenAI"), Some("OPENAI_BASE_URL"));
        assert_eq!(provider_base_url_env_var("google"), Some("GEMINI_BASE_URL"));
        assert_eq!(provider_base_url_env_var("ollama"), Some("OLLAMA_BASE_URL"));
        assert_eq!(provider_base_url_env_var("unknown"), None);

        assert_eq!(provider_api_key_env_var("gemini"), Some("GEMINI_API_KEY"));
        assert_eq!(provider_api_key_env_var("DeepSeek"), Some("DEEPSEEK_API_KEY"));
        assert_eq!(provider_api_key_env_var("ollama"), None);
    }

    #[test]
    fn test_normalize_base_url() {
        assert_eq!(
            normalize_base_url("https://custom.example/v1"),
            "https://custom.example/v1/"
        );
        assert_eq!(
            normalize_base_url("https://custom.example/v1/"),
            "https://custom.example/v1/"
        );
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
    }

    #[tokio::test]
    async fn test_clones_share_usage_counters() {
        let provider = GenAIProvider::new().unwrap();
        let clone = provider.clone();

        clone.add_tokens(7).await;

        assert_eq!(provider.get_total_tokens_used().await, 7);
        assert_eq!(provider.get_request_count().await, 0);
    }
}