perspt_core/llm_provider.rs

//! # LLM Provider Module
//!
//! Thread-safe LLM provider abstraction for multi-agent use.
//! Wraps genai::Client with Arc<RwLock<>> for shared state.
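//!
//! A minimal usage sketch (crate path and model name are illustrative, assuming
//! the relevant `*_API_KEY` environment variable is already set):
//!
//! ```ignore
//! use perspt_core::llm_provider::GenAIProvider;
//!
//! async fn demo() -> anyhow::Result<()> {
//!     // Build a provider that reads credentials from the environment.
//!     let provider = GenAIProvider::new()?;
//!     // Request a single, non-streaming completion.
//!     let reply = provider
//!         .generate_response_simple("gpt-4o-mini", "Say hello")
//!         .await?;
//!     println!("{reply}");
//!     Ok(())
//! }
//! ```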

use anyhow::{Context, Result};
use futures::StreamExt;
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatStreamEvent};
use genai::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};

/// End-of-transmission sentinel sent as the final message on streaming channels.
pub const EOT_SIGNAL: &str = "<|EOT|>";

/// Shared state for rate limiting and token counting
#[derive(Default)]
struct SharedState {
    total_tokens_used: usize,
    request_count: usize,
}

/// Thread-safe LLM provider implementation using Arc<RwLock<>>.
///
/// This provider can be cheaply cloned and shared across multiple agents.
/// Each clone shares the same underlying client and rate limiting state.
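///
/// A sketch of fan-out across tasks (model name is illustrative):
///
/// ```ignore
/// let provider = GenAIProvider::new()?;
/// let mut handles = Vec::new();
/// for prompt in ["Summarize document A", "Summarize document B"] {
///     // Each clone shares the same client and metrics state.
///     let agent = provider.clone();
///     handles.push(tokio::spawn(async move {
///         agent.generate_response_simple("gpt-4o-mini", prompt).await
///     }));
/// }
/// for handle in handles {
///     let _ = handle.await;
/// }
/// ```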
#[derive(Clone)]
pub struct GenAIProvider {
    /// The underlying genai client
    client: Arc<Client>,
    /// Shared state for rate limiting and metrics
    shared: Arc<RwLock<SharedState>>,
}

impl GenAIProvider {
    /// Creates a new GenAI provider with automatic configuration.
    pub fn new() -> Result<Self> {
        let client = Client::default();
        Ok(Self {
            client: Arc::new(client),
            shared: Arc::new(RwLock::new(SharedState::default())),
        })
    }

    /// Creates a new GenAI provider with explicit configuration.
    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
        // Set environment variable if API key is provided
        if let (Some(provider), Some(key)) = (provider_type, api_key) {
            let env_var = match provider {
                "openai" => "OPENAI_API_KEY",
                "anthropic" => "ANTHROPIC_API_KEY",
                "gemini" => "GEMINI_API_KEY",
                "groq" => "GROQ_API_KEY",
                "cohere" => "COHERE_API_KEY",
                "xai" => "XAI_API_KEY",
                "deepseek" => "DEEPSEEK_API_KEY",
                "ollama" => {
                    log::info!("Ollama provider detected - no API key required for local setup");
                    return Self::new();
                }
                _ => {
                    log::warn!("Unknown provider type for API key: {provider}");
                    return Self::new();
                }
            };

            log::info!("Setting {env_var} environment variable for genai client");
            std::env::set_var(env_var, key);
        }

        Self::new()
    }

    /// Get total tokens used across all requests
    pub async fn get_total_tokens_used(&self) -> usize {
        self.shared.read().await.total_tokens_used
    }

    /// Get total request count
    pub async fn get_request_count(&self) -> usize {
        self.shared.read().await.request_count
    }

    /// Increment request counter (for metrics)
    async fn increment_request(&self) {
        let mut state = self.shared.write().await;
        state.request_count += 1;
    }

    /// Add tokens to the total count
    pub async fn add_tokens(&self, count: usize) {
        let mut state = self.shared.write().await;
        state.total_tokens_used += count;
    }

    /// Retrieves all available models for a specific provider.
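    ///
    /// A sketch of listing models (see [`Self::get_supported_providers`] for
    /// accepted provider names):
    ///
    /// ```ignore
    /// let models = provider.get_available_models("openai").await?;
    /// for model in &models {
    ///     println!("{model}");
    /// }
    /// ```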
    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
        let adapter_kind = str_to_adapter_kind(provider)?;

        let models = self
            .client
            .all_model_names(adapter_kind)
            .await
            .context(format!("Failed to get models for provider: {provider}"))?;

        Ok(models)
    }

    /// Generates a simple text response without streaming.
    /// Includes exponential backoff retry for rate limits and transient errors.
    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<String> {
        self.generate_response_with_retry(model, prompt, 3).await
    }

    /// Generates a response with configurable retry count and exponential backoff.
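    ///
    /// Attempts are spaced by an exponential backoff of 1s, 2s, 4s, ... capped at
    /// 16s, and only errors that look transient (HTTP 429/5xx, timeouts, connection
    /// failures) are retried. A sketch with a larger retry budget (model name is
    /// illustrative):
    ///
    /// ```ignore
    /// let text = provider
    ///     .generate_response_with_retry("gpt-4o-mini", "Summarize this diff", 5)
    ///     .await?;
    /// ```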
    pub async fn generate_response_with_retry(
        &self,
        model: &str,
        prompt: &str,
        max_retries: usize,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!(
            "Sending chat request to model: {model} with prompt length: {} chars",
            prompt.len()
        );

        let start_time = Instant::now();
        let mut last_error: Option<anyhow::Error> = None;
        let mut retry_count = 0;

        while retry_count <= max_retries {
            if retry_count > 0 {
                // Exponential backoff: 1s, 2s, 4s, 8s, ... (capped at 16s)
                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
                log::warn!(
                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
                    retry_count,
                    max_retries,
                    model,
                    delay_secs,
                    last_error.as_ref().map(|e| e.to_string())
                );
                println!(
                    "   ⏳ Rate limited, retrying in {}s (attempt {}/{})",
                    delay_secs, retry_count, max_retries
                );
                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
            }

            match self.client.exec_chat(model, chat_req.clone(), None).await {
                Ok(chat_res) => {
                    let content = chat_res
                        .first_text()
                        .context("No text content in response")?;
                    log::debug!(
                        "Received response with {} characters in {}ms",
                        content.len(),
                        start_time.elapsed().as_millis()
                    );

                    return Ok(content.to_string());
                }
                Err(e) => {
                    let err_str = e.to_string();

                    // Check if it's a retryable error (rate limit, server error, network)
                    let is_retryable = err_str.contains("429")
                        || err_str.contains("rate limit")
                        || err_str.contains("Rate limit")
                        || err_str.contains("RESOURCE_EXHAUSTED")
                        || err_str.contains("500")
                        || err_str.contains("502")
                        || err_str.contains("503")
                        || err_str.contains("504")
                        || err_str.contains("timeout")
                        || err_str.contains("connection");

                    if is_retryable && retry_count < max_retries {
                        log::warn!("Retryable error for model {}: {}", model, err_str);
                        last_error = Some(anyhow::anyhow!("{}", err_str));
                        retry_count += 1;
                        continue;
                    } else {
                        return Err(anyhow::anyhow!(
                            "Failed to execute chat request for model {}: {}",
                            model,
                            err_str
                        ));
                    }
                }
            }
        }

        // Should not reach here, but handle gracefully
        Err(last_error
            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
    }

    /// Generates a streaming response and sends chunks via mpsc channel.
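    ///
    /// Content chunks are forwarded over `tx` as they arrive, and [`EOT_SIGNAL`] is
    /// sent as the final message. A consumer sketch (model name is illustrative):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    /// let streamer = provider.clone();
    /// tokio::spawn(async move {
    ///     let _ = streamer
    ///         .generate_response_stream_to_channel("gpt-4o-mini", "Tell a story", tx)
    ///         .await;
    /// });
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk == EOT_SIGNAL {
    ///         break;
    ///     }
    ///     print!("{chunk}");
    /// }
    /// ```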
    pub async fn generate_response_stream_to_channel(
        &self,
        model: &str,
        prompt: &str,
        tx: mpsc::UnboundedSender<String>,
    ) -> Result<()> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");

        let chat_res_stream = self
            .client
            .exec_chat_stream(model, chat_req, None)
            .await
            .context(format!(
                "Failed to execute streaming chat request for model: {model}"
            ))?;

        let mut stream = chat_res_stream.stream;
        let mut chunk_count = 0;
        let mut total_content_length = 0;
        let mut stream_ended_explicitly = false;
        let start_time = Instant::now();

        log::info!(
            "=== STREAM START === Model: {}, Prompt length: {} chars",
            model,
            prompt.len()
        );

        while let Some(chunk_result) = stream.next().await {
            let elapsed = start_time.elapsed();

            match chunk_result {
                Ok(ChatStreamEvent::Start) => {
                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
                }
                Ok(ChatStreamEvent::Chunk(chunk)) => {
                    chunk_count += 1;
                    total_content_length += chunk.content.len();

                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
                        log::info!(
                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
                            chunk_count,
                            chunk.content.len(),
                            total_content_length,
                            elapsed
                        );
                    }

                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
                        log::error!(
                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
                        );
                        break;
                    }
                }
                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
                    log::info!(
                        "REASONING CHUNK: {} chars at {:?}",
                        chunk.content.len(),
                        elapsed
                    );
                }
                Ok(ChatStreamEvent::End(_)) => {
                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
                    stream_ended_explicitly = true;
                    break;
                }
                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
                    log::debug!("Tool call chunk received (ignored)");
                }
                Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
                    log::debug!("Thought signature chunk received (ignored)");
                }
                Err(e) => {
                    log::error!(
                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
                    );
                    let error_msg = format!("Stream error: {e}");
                    let _ = tx.send(error_msg);
                    return Err(e.into());
                }
            }
        }

        let final_elapsed = start_time.elapsed();
        if !stream_ended_explicitly {
            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
        }

        log::info!(
            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
        );

        // Add approximate token count
        self.add_tokens(total_content_length / 4).await; // Rough estimate

        if tx.send(EOT_SIGNAL.to_string()).is_err() {
            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
        }

        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
        Ok(())
    }

    /// Generate response with conversation history
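    ///
    /// A sketch of a multi-turn request (model name is illustrative):
    ///
    /// ```ignore
    /// use genai::chat::ChatMessage;
    ///
    /// let messages = vec![
    ///     ChatMessage::system("You are a terse assistant."),
    ///     ChatMessage::user("What is Rust?"),
    ///     ChatMessage::assistant("A systems programming language."),
    ///     ChatMessage::user("Name one of its safety guarantees."),
    /// ];
    /// let reply = provider
    ///     .generate_response_with_history("gpt-4o-mini", messages)
    ///     .await?;
    /// ```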
    pub async fn generate_response_with_history(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::new(messages);

        log::debug!("Sending chat request to model: {model} with conversation history");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, None)
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

    /// Generate response with custom chat options
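    ///
    /// A sketch using `genai`'s `ChatOptions` builder (the exact builder methods
    /// depend on the `genai` version in use; model name is illustrative):
    ///
    /// ```ignore
    /// use genai::chat::ChatOptions;
    ///
    /// let options = ChatOptions::default()
    ///     .with_temperature(0.2)
    ///     .with_max_tokens(256);
    /// let reply = provider
    ///     .generate_response_with_options("gpt-4o-mini", "Explain lifetimes", options)
    ///     .await?;
    /// ```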
    pub async fn generate_response_with_options(
        &self,
        model: &str,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending chat request to model: {model} with custom options");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, Some(&options))
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

    /// Get a list of supported providers
    pub fn get_supported_providers() -> Vec<&'static str> {
        vec![
            "openai",
            "anthropic",
            "gemini",
            "groq",
            "cohere",
            "ollama",
            "xai",
            "deepseek",
        ]
    }

    /// Get all available providers
    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
        Ok(Self::get_supported_providers()
            .iter()
            .map(|s| s.to_string())
            .collect())
    }

    /// Test if a model is available and working
    pub async fn test_model(&self, model: &str) -> Result<bool> {
        match self.generate_response_simple(model, "Hello").await {
            Ok(_) => {
                log::info!("Model {model} is available and working");
                Ok(true)
            }
            Err(e) => {
                log::warn!("Model {model} test failed: {e}");
                Ok(false)
            }
        }
    }

    /// Validate and get the best available model for a provider
    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
        if self.test_model(model).await? {
            return Ok(model.to_string());
        }

        if let Some(provider) = provider_type {
            if let Ok(models) = self.get_available_models(provider).await {
                if !models.is_empty() {
                    log::info!("Model {} not available, using {} instead", model, models[0]);
                    return Ok(models[0].clone());
                }
            }
        }

        log::warn!("Could not validate model {model}, proceeding anyway");
        Ok(model.to_string())
    }
}

/// Convert a provider string to genai AdapterKind
fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
    match provider.to_lowercase().as_str() {
        "openai" => Ok(AdapterKind::OpenAI),
        "anthropic" => Ok(AdapterKind::Anthropic),
        "gemini" | "google" => Ok(AdapterKind::Gemini),
        "groq" => Ok(AdapterKind::Groq),
        "cohere" => Ok(AdapterKind::Cohere),
        "ollama" => Ok(AdapterKind::Ollama),
        "xai" => Ok(AdapterKind::Xai),
        "deepseek" => Ok(AdapterKind::DeepSeek),
        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
        // All clones share the same underlying state
    }
}
471}