// perspt_core/llm_provider.rs

//! # LLM Provider Module
//!
//! Thread-safe LLM provider abstraction for multi-agent use.
//! Wraps a shared `genai::Client` in `Arc` and keeps rate-limiting and
//! metrics state behind `Arc<RwLock<_>>`.
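//!
//! ## Example
//!
//! A minimal usage sketch. The crate path (`perspt_core::llm_provider`) and the
//! model name are illustrative assumptions, not fixed by this module:
//!
//! ```no_run
//! # use perspt_core::llm_provider::GenAIProvider;
//! # async fn demo() -> anyhow::Result<()> {
//! // Build a provider from environment configuration (e.g. OPENAI_API_KEY).
//! let provider = GenAIProvider::new()?;
//! let response = provider
//!     .generate_response_simple("gpt-4o-mini", "Say hello")
//!     .await?;
//! println!("{} (tokens out: {:?})", response.text, response.tokens_out);
//! # Ok(())
//! # }
//! ```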

use anyhow::{Context, Result};
use futures::StreamExt;
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatStreamEvent};
use genai::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};

/// End of transmission signal
pub const EOT_SIGNAL: &str = "<|EOT|>";

/// Response from a non-streaming LLM call, carrying text and token usage.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    pub text: String,
    pub tokens_in: Option<i32>,
    pub tokens_out: Option<i32>,
}

/// Shared state for rate limiting and token counting
#[derive(Default)]
struct SharedState {
    total_tokens_used: usize,
    request_count: usize,
}

/// Thread-safe LLM provider implementation using Arc<RwLock<>>.
///
/// This provider can be cheaply cloned and shared across multiple agents.
/// Each clone shares the same underlying client and rate limiting state.
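///
/// # Example
///
/// A sketch of sharing one provider across tasks (crate path and model name
/// are illustrative assumptions):
///
/// ```no_run
/// # use perspt_core::llm_provider::GenAIProvider;
/// # async fn demo() -> anyhow::Result<()> {
/// let provider = GenAIProvider::new()?;
/// // Cloning is cheap: only Arc reference counts are bumped.
/// let worker = provider.clone();
/// tokio::spawn(async move {
///     let _ = worker.generate_response_simple("gpt-4o-mini", "ping").await;
/// });
/// // Metrics are shared across every clone.
/// println!("requests: {}", provider.get_request_count().await);
/// # Ok(())
/// # }
/// ```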
#[derive(Clone)]
pub struct GenAIProvider {
    /// The underlying genai client
    client: Arc<Client>,
    /// Shared state for rate limiting and metrics
    shared: Arc<RwLock<SharedState>>,
}

impl GenAIProvider {
    /// Creates a new GenAI provider with automatic configuration.
    pub fn new() -> Result<Self> {
        let client = Client::default();
        Ok(Self {
            client: Arc::new(client),
            shared: Arc::new(RwLock::new(SharedState::default())),
        })
    }

    /// Creates a new GenAI provider with explicit configuration.
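    ///
    /// Sets the provider's conventional API-key environment variable before
    /// building the client. Illustrative sketch (crate path assumed; the key
    /// value is a placeholder):
    ///
    /// ```no_run
    /// # use perspt_core::llm_provider::GenAIProvider;
    /// # fn main() -> anyhow::Result<()> {
    /// let provider = GenAIProvider::new_with_config(Some("anthropic"), Some("sk-placeholder"))?;
    /// # Ok(())
    /// # }
    /// ```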
    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
        // Set environment variable if API key is provided
        if let (Some(provider), Some(key)) = (provider_type, api_key) {
            let env_var = match provider {
                "openai" => "OPENAI_API_KEY",
                "anthropic" => "ANTHROPIC_API_KEY",
                "gemini" => "GEMINI_API_KEY",
                "groq" => "GROQ_API_KEY",
                "cohere" => "COHERE_API_KEY",
                "xai" => "XAI_API_KEY",
                "deepseek" => "DEEPSEEK_API_KEY",
                "ollama" => {
                    log::info!("Ollama provider detected - no API key required for local setup");
                    return Self::new();
                }
                _ => {
                    log::warn!("Unknown provider type for API key: {provider}");
                    return Self::new();
                }
            };

            log::info!("Setting {env_var} environment variable for genai client");
            std::env::set_var(env_var, key);
        }

        Self::new()
    }

    /// Get total tokens used across all requests
    pub async fn get_total_tokens_used(&self) -> usize {
        self.shared.read().await.total_tokens_used
    }

    /// Get total request count
    pub async fn get_request_count(&self) -> usize {
        self.shared.read().await.request_count
    }

    /// Increment request counter (for metrics)
    async fn increment_request(&self) {
        let mut state = self.shared.write().await;
        state.request_count += 1;
    }

    /// Add tokens to the total count
    pub async fn add_tokens(&self, count: usize) {
        let mut state = self.shared.write().await;
        state.total_tokens_used += count;
    }

    /// Retrieves all available models for a specific provider.
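    ///
    /// Sketch (crate path assumed; Ollama chosen since it needs no API key):
    ///
    /// ```no_run
    /// # use perspt_core::llm_provider::GenAIProvider;
    /// # async fn demo() -> anyhow::Result<()> {
    /// let provider = GenAIProvider::new()?;
    /// for name in provider.get_available_models("ollama").await? {
    ///     println!("{name}");
    /// }
    /// # Ok(())
    /// # }
    /// ```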
    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
        let adapter_kind = str_to_adapter_kind(provider)?;

        let models = self
            .client
            .all_model_names(adapter_kind)
            .await
            .context(format!("Failed to get models for provider: {provider}"))?;

        Ok(models)
    }

    /// Generates a simple text response without streaming.
    /// Includes exponential backoff retry for rate limits and transient errors.
    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
        self.generate_response_with_retry(model, prompt, 3).await
    }

    /// Generates a response with configurable retry count and exponential backoff.
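    ///
    /// Sketch (crate path and model name illustrative). With `max_retries = 5`
    /// the backoff delays are 1s, 2s, 4s, 8s, 16s:
    ///
    /// ```no_run
    /// # use perspt_core::llm_provider::GenAIProvider;
    /// # async fn demo() -> anyhow::Result<()> {
    /// let provider = GenAIProvider::new()?;
    /// let response = provider
    ///     .generate_response_with_retry("gpt-4o-mini", "Hello", 5)
    ///     .await?;
    /// println!("{}", response.text);
    /// # Ok(())
    /// # }
    /// ```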
    pub async fn generate_response_with_retry(
        &self,
        model: &str,
        prompt: &str,
        max_retries: usize,
    ) -> Result<LlmResponse> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!(
            "Sending chat request to model: {model} with prompt length: {} chars",
            prompt.len()
        );

        let start_time = Instant::now();
        let mut last_error: Option<anyhow::Error> = None;
        let mut retry_count = 0;

        while retry_count <= max_retries {
            if retry_count > 0 {
                // Exponential backoff: 1s, 2s, 4s, 8s, ... (capped at 16s)
                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
                log::warn!(
                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
                    retry_count,
                    max_retries,
                    model,
                    delay_secs,
                    last_error.as_ref().map(|e| e.to_string())
                );
                println!(
                    "   ⏳ Rate limited, retrying in {}s (attempt {}/{})",
                    delay_secs, retry_count, max_retries
                );
                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
            }

            match self.client.exec_chat(model, chat_req.clone(), None).await {
                Ok(chat_res) => {
                    let tokens_in = chat_res.usage.prompt_tokens;
                    let tokens_out = chat_res.usage.completion_tokens;
                    let content = chat_res
                        .first_text()
                        .context("No text content in response")?;
                    log::debug!(
                        "Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
                        content.len(),
                        start_time.elapsed().as_millis(),
                        tokens_in,
                        tokens_out,
                    );

                    // Update shared token counter with real values when available
                    let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
                    if total > 0 {
                        self.add_tokens(total as usize).await;
                    }

                    return Ok(LlmResponse {
                        text: content.to_string(),
                        tokens_in,
                        tokens_out,
                    });
                }
                Err(e) => {
                    let err_str = e.to_string();

                    // Check if it's a retryable error (rate limit, server error, network)
                    let is_retryable = err_str.contains("429")
                        || err_str.contains("rate limit")
                        || err_str.contains("Rate limit")
                        || err_str.contains("RESOURCE_EXHAUSTED")
                        || err_str.contains("500")
                        || err_str.contains("502")
                        || err_str.contains("503")
                        || err_str.contains("504")
                        || err_str.contains("timeout")
                        || err_str.contains("connection");

                    if is_retryable && retry_count < max_retries {
                        log::warn!("Retryable error for model {}: {}", model, err_str);
                        last_error = Some(anyhow::anyhow!("{}", err_str));
                        retry_count += 1;
                        continue;
                    } else {
                        return Err(anyhow::anyhow!(
                            "Failed to execute chat request for model {}: {}",
                            model,
                            err_str
                        ));
                    }
                }
            }
        }

        // Should not reach here, but handle gracefully
        Err(last_error
            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
    }

    /// Generates a streaming response and sends chunks via mpsc channel.
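    ///
    /// Chunks are sent as they arrive, followed by [`EOT_SIGNAL`] as the final
    /// message. Consumer sketch (crate path and model name are illustrative):
    ///
    /// ```no_run
    /// # use perspt_core::llm_provider::{GenAIProvider, EOT_SIGNAL};
    /// # async fn demo() -> anyhow::Result<()> {
    /// let provider = GenAIProvider::new()?;
    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    /// let streamer = provider.clone();
    /// tokio::spawn(async move {
    ///     let _ = streamer
    ///         .generate_response_stream_to_channel("gpt-4o-mini", "Tell me a story", tx)
    ///         .await;
    /// });
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk == EOT_SIGNAL {
    ///         break;
    ///     }
    ///     print!("{chunk}");
    /// }
    /// # Ok(())
    /// # }
    /// ```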
    pub async fn generate_response_stream_to_channel(
        &self,
        model: &str,
        prompt: &str,
        tx: mpsc::UnboundedSender<String>,
    ) -> Result<()> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");

        let chat_res_stream = self
            .client
            .exec_chat_stream(model, chat_req, None)
            .await
            .context(format!(
                "Failed to execute streaming chat request for model: {model}"
            ))?;

        let mut stream = chat_res_stream.stream;
        let mut chunk_count = 0;
        let mut total_content_length = 0;
        let mut stream_ended_explicitly = false;
        let start_time = Instant::now();

        log::info!(
            "=== STREAM START === Model: {}, Prompt length: {} chars",
            model,
            prompt.len()
        );

        while let Some(chunk_result) = stream.next().await {
            let elapsed = start_time.elapsed();

            match chunk_result {
                Ok(ChatStreamEvent::Start) => {
                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
                }
                Ok(ChatStreamEvent::Chunk(chunk)) => {
                    chunk_count += 1;
                    total_content_length += chunk.content.len();

                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
                        log::info!(
                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
                            chunk_count,
                            chunk.content.len(),
                            total_content_length,
                            elapsed
                        );
                    }

                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
                        log::error!(
                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
                        );
                        break;
                    }
                }
                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
                    log::info!(
                        "REASONING CHUNK: {} chars at {:?}",
                        chunk.content.len(),
                        elapsed
                    );
                }
                Ok(ChatStreamEvent::End(_)) => {
                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
                    stream_ended_explicitly = true;
                    break;
                }
                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
                    log::debug!("Tool call chunk received (ignored)");
                }
                Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
                    log::debug!("Thought signature chunk received (ignored)");
                }
                Err(e) => {
                    log::error!(
                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
                    );
                    let error_msg = format!("Stream error: {e}");
                    let _ = tx.send(error_msg);
                    return Err(e.into());
                }
            }
        }

        let final_elapsed = start_time.elapsed();
        if !stream_ended_explicitly {
            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
        }

        log::info!(
            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
        );

        // Streaming responses carry no usage metadata here, so approximate
        // the token count with the common ~4 characters per token heuristic.
        self.add_tokens(total_content_length / 4).await;

        if tx.send(EOT_SIGNAL.to_string()).is_err() {
            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
        }

        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
        Ok(())
    }

    /// Generate response with conversation history
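    ///
    /// Sketch (crate path and model name are illustrative):
    ///
    /// ```no_run
    /// # use genai::chat::ChatMessage;
    /// # use perspt_core::llm_provider::GenAIProvider;
    /// # async fn demo() -> anyhow::Result<()> {
    /// let provider = GenAIProvider::new()?;
    /// let messages = vec![
    ///     ChatMessage::system("You are a terse assistant."),
    ///     ChatMessage::user("What is Rust?"),
    /// ];
    /// let reply = provider
    ///     .generate_response_with_history("gpt-4o-mini", messages)
    ///     .await?;
    /// println!("{reply}");
    /// # Ok(())
    /// # }
    /// ```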
    pub async fn generate_response_with_history(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::new(messages);

        log::debug!("Sending chat request to model: {model} with conversation history");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, None)
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

    /// Generate response with custom chat options
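    ///
    /// Sketch (assumes the `ChatOptions` builder methods from genai; crate
    /// path and model name are illustrative):
    ///
    /// ```no_run
    /// # use genai::chat::ChatOptions;
    /// # use perspt_core::llm_provider::GenAIProvider;
    /// # async fn demo() -> anyhow::Result<()> {
    /// let provider = GenAIProvider::new()?;
    /// let options = ChatOptions::default()
    ///     .with_temperature(0.2)
    ///     .with_max_tokens(256);
    /// let reply = provider
    ///     .generate_response_with_options("gpt-4o-mini", "Summarize Rust in one line", options)
    ///     .await?;
    /// println!("{reply}");
    /// # Ok(())
    /// # }
    /// ```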
    pub async fn generate_response_with_options(
        &self,
        model: &str,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending chat request to model: {model} with custom options");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, Some(&options))
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

    /// Get a list of supported providers
    pub fn get_supported_providers() -> Vec<&'static str> {
        vec![
            "openai",
            "anthropic",
            "gemini",
            "groq",
            "cohere",
            "ollama",
            "xai",
            "deepseek",
        ]
    }

    /// Get the supported providers as owned strings.
    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
        Ok(Self::get_supported_providers()
            .iter()
            .map(|s| s.to_string())
            .collect())
    }

    /// Test if a model is available and working
    pub async fn test_model(&self, model: &str) -> Result<bool> {
        match self.generate_response_simple(model, "Hello").await {
            Ok(_) => {
                log::info!("Model {model} is available and working");
                Ok(true)
            }
            Err(e) => {
                log::warn!("Model {model} test failed: {e}");
                Ok(false)
            }
        }
    }

    /// Validate and get the best available model for a provider
    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
        if self.test_model(model).await? {
            return Ok(model.to_string());
        }

        if let Some(provider) = provider_type {
            if let Ok(models) = self.get_available_models(provider).await {
                if !models.is_empty() {
                    log::info!("Model {} not available, using {} instead", model, models[0]);
                    return Ok(models[0].clone());
                }
            }
        }

        log::warn!("Could not validate model {model}, proceeding anyway");
        Ok(model.to_string())
    }
}

/// Convert a provider string to genai AdapterKind
fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
    match provider.to_lowercase().as_str() {
        "openai" => Ok(AdapterKind::OpenAI),
        "anthropic" => Ok(AdapterKind::Anthropic),
        "gemini" | "google" => Ok(AdapterKind::Gemini),
        "groq" => Ok(AdapterKind::Groq),
        "cohere" => Ok(AdapterKind::Cohere),
        "ollama" => Ok(AdapterKind::Ollama),
        "xai" => Ok(AdapterKind::Xai),
        "deepseek" => Ok(AdapterKind::DeepSeek),
        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
        // All clones share the same underlying state
    }
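
    #[tokio::test]
    async fn test_clones_share_token_count() {
        // Token counts recorded through one clone should be visible through
        // another, since every clone shares the same Arc<RwLock<SharedState>>.
        let provider = GenAIProvider::new().unwrap();
        let clone = provider.clone();
        provider.add_tokens(42).await;
        assert_eq!(clone.get_total_tokens_used().await, 42);
    }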
}