git_iris/
llm.rs

1use crate::config::Config;
2use crate::log_debug;
3use anyhow::{Result, anyhow};
4use llm::{
5    LLMProvider,
6    builder::{LLMBackend, LLMBuilder},
7    chat::ChatMessage,
8};
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
17/// Generates a message using the given configuration
18pub async fn get_message<T>(
19    config: &Config,
20    provider_name: &str,
21    system_prompt: &str,
22    user_prompt: &str,
23) -> Result<T>
24where
25    T: Serialize + DeserializeOwned + std::fmt::Debug,
26    String: Into<T>,
27{
28    log_debug!("Generating message using provider: {}", provider_name);
29    log_debug!("System prompt: {}", system_prompt);
30    log_debug!("User prompt: {}", user_prompt);
31
32    // Parse the provider type
33    let backend =
34        LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {}", e))?;
35
36    // Get provider configuration
37    let provider_config = config
38        .get_provider_config(provider_name)
39        .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
40
41    // Build the provider
42    let mut builder = LLMBuilder::new().backend(backend.clone());
43
44    // Set model
45    if !provider_config.model.is_empty() {
46        builder = builder.model(provider_config.model.clone());
47    }
48
49    // Set system prompt
50    builder = builder.system(system_prompt.to_string());
51
52    // Set API key if needed
53    if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
54        builder = builder.api_key(provider_config.api_key.clone());
55    }
56
57    // Set temperature if specified in additional params
58    if let Some(temp) = provider_config.additional_params.get("temperature") {
59        if let Ok(temp_val) = temp.parse::<f32>() {
60            builder = builder.temperature(temp_val);
61        }
62    }
63
64    // Set max tokens if specified in additional params, otherwise use 4096 as default
65    if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
66        if let Ok(mt_val) = max_tokens.parse::<u32>() {
67            builder = builder.max_tokens(mt_val);
68        }
69    } else {
70        builder = builder.max_tokens(4096);
71    }
72
73    // Set top_p if specified in additional params
74    if let Some(top_p) = provider_config.additional_params.get("top_p") {
75        if let Ok(tp_val) = top_p.parse::<f32>() {
76            builder = builder.top_p(tp_val);
77        }
78    }
79
80    // Build the provider
81    let provider = builder
82        .build()
83        .map_err(|e| anyhow!("Failed to build provider: {}", e))?;
84
85    // Generate the message
86    let result = get_message_with_provider::<T>(provider, user_prompt).await?;
87
88    Ok(result)
89}
90
/// Generates a message using the given provider (mainly for testing purposes)
///
/// Sends `user_prompt` as a single user chat message, bounding each attempt
/// to 30 seconds and retrying with exponential backoff. When `T` is `String`,
/// the cleaned response text is returned as-is; otherwise the cleaned text is
/// deserialized from JSON into `T`.
pub async fn get_message_with_provider<T>(
    provider: Box<dyn LLMProvider + Send + Sync>,
    user_prompt: &str,
) -> Result<T>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
    // Required so the `T == String` fast path below can convert the raw text.
    String: Into<T>,
{
    log_debug!("Entering get_message_with_provider");

    // NOTE(review): `take(2)` yields two backoff delays, and tokio_retry
    // consumes one delay per retry — so this permits up to two retries
    // (three attempts), not "initial + 1 retry" as noted below. Confirm
    // the intended attempt count.
    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); // 2 attempts total: initial + 1 retry

    let result = Retry::spawn(retry_strategy, || async {
        log_debug!("Attempting to generate message");

        // Create chat message with user prompt
        let messages = vec![ChatMessage::user().content(user_prompt.to_string()).build()];

        // Bound each attempt to 30 seconds; a timeout is surfaced as a
        // retryable error just like a provider failure.
        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
            Ok(Ok(response)) => {
                log_debug!("Received response from provider");
                // A missing text payload degrades to "" rather than erroring.
                let response_text = response.text().unwrap_or_default();
                // Strip code fences / surrounding junk from the payload.
                let cleaned_message = clean_json_from_llm(&response_text);

                if std::any::type_name::<T>() == std::any::type_name::<String>() {
                    // If T is String, return the raw string response
                    Ok(cleaned_message.into())
                } else {
                    // Attempt to deserialize the response
                    match serde_json::from_str::<T>(&cleaned_message) {
                        Ok(message) => Ok(message),
                        Err(e) => {
                            // A malformed payload is retryable: a later attempt
                            // may yield valid JSON.
                            log_debug!("Deserialization error: {} message: {}", e, cleaned_message);
                            Err(anyhow!("Deserialization error: {}", e))
                        }
                    }
                }
            }
            Ok(Err(e)) => {
                log_debug!("Provider error: {}", e);
                Err(anyhow!("Provider error: {}", e))
            }
            Err(_) => {
                log_debug!("Provider timed out");
                Err(anyhow!("Provider timed out"))
            }
        }
    })
    .await;

    // Map the final outcome, logging either the parsed value or the last error.
    match result {
        Ok(message) => {
            log_debug!("Deserialized message: {:?}", message);
            Ok(message)
        }
        Err(e) => {
            log_debug!("Failed to generate message after retries: {}", e);
            Err(anyhow!("Failed to generate message: {}", e))
        }
    }
}
153
/// Returns a list of available LLM providers as strings
pub fn get_available_provider_names() -> Vec<String> {
    // Static roster of supported backends, converted to owned strings.
    [
        "openai",
        "anthropic",
        "ollama",
        "google",
        "groq",
        "xai",
        "deepseek",
        "phind",
    ]
    .iter()
    .map(|name| (*name).to_string())
    .collect()
}
167
/// Returns the default model for a given provider
///
/// Matching is case-insensitive; unrecognized providers fall back to
/// OpenAI's default model.
pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
    let normalized = provider_type.to_lowercase();
    match normalized.as_str() {
        "anthropic" => "claude-3-7-sonnet-20250219",
        "deepseek" => "deepseek-chat",
        "google" => "gemini-2.0-flash",
        "groq" => "llama-3.1-70b-versatile",
        "ollama" => "llama3",
        "phind" => "phind-v2",
        "xai" => "grok-2-beta",
        _ => "gpt-4o", // Default to OpenAI's model
    }
}
181
182/// Returns the default token limit for a given provider
183pub fn get_default_token_limit_for_provider(provider_type: &str) -> Result<usize> {
184    let limit = match provider_type.to_lowercase().as_str() {
185        "anthropic" => 200_000,
186        "ollama" | "openai" | "groq" | "xai" => 128_000,
187        "google" => 1_000_000,
188        "deepseek" => 64_000,
189        "phind" => 32_000,
190        _ => 8_192, // Default token limit
191    };
192    Ok(limit)
193}
194
195/// Checks if a provider requires an API key
196pub fn provider_requires_api_key(provider_type: &str) -> bool {
197    if let Ok(backend) = LLMBackend::from_str(provider_type) {
198        requires_api_key(&backend)
199    } else {
200        true // Default to requiring API key for unknown providers
201    }
202}
203
204/// Helper function: check if `LLMBackend` requires API key
205fn requires_api_key(backend: &LLMBackend) -> bool {
206    !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
207}
208
209/// Validates the provider configuration
210pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
211    if provider_requires_api_key(provider_name) {
212        let provider_config = config
213            .get_provider_config(provider_name)
214            .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
215
216        if provider_config.api_key.is_empty() {
217            return Err(anyhow!("API key required for provider: {}", provider_name));
218        }
219    }
220
221    Ok(())
222}
223
224/// Combines default, saved, and command-line configurations
225pub fn get_combined_config<S: ::std::hash::BuildHasher>(
226    config: &Config,
227    provider_name: &str,
228    command_line_args: &HashMap<String, String, S>,
229) -> HashMap<String, String> {
230    let mut combined_params = HashMap::default();
231
232    // Add default values
233    combined_params.insert(
234        "model".to_string(),
235        get_default_model_for_provider(provider_name).to_string(),
236    );
237
238    // Add saved config values if available
239    if let Some(provider_config) = config.get_provider_config(provider_name) {
240        if !provider_config.api_key.is_empty() {
241            combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
242        }
243        if !provider_config.model.is_empty() {
244            combined_params.insert("model".to_string(), provider_config.model.clone());
245        }
246        for (key, value) in &provider_config.additional_params {
247            combined_params.insert(key.clone(), value.clone());
248        }
249    }
250
251    // Add command line args (these take precedence)
252    for (key, value) in command_line_args {
253        if !value.is_empty() {
254            combined_params.insert(key.clone(), value.clone());
255        }
256    }
257
258    combined_params
259}
260
/// Extracts a JSON object from raw LLM output.
///
/// Strips surrounding whitespace and non-ASCII junk (BOMs, zero-width
/// characters some models emit), unwraps a Markdown code fence if present,
/// and narrows to the span from the first `'{'` to the last `'}'`.
///
/// Fix: the original sliced `[start..end]` unconditionally, which panics when
/// the last `'}'` precedes the first `'{'` (e.g. input `"}a{"` gives
/// `start=2, end=1`). Both the fence branch and the final extraction now
/// guard the span and fall back to the input unchanged.
fn clean_json_from_llm(json_str: &str) -> String {
    // Remove potential leading/trailing whitespace and invisible characters
    let trimmed = json_str
        .trim_start_matches(|c: char| c.is_whitespace() || !c.is_ascii())
        .trim_end_matches(|c: char| c.is_whitespace() || !c.is_ascii());

    // If wrapped in a code block, narrow to the braces inside the markers.
    let without_codeblock = if trimmed.starts_with("```") && trimmed.ends_with("```") {
        extract_object_span(trimmed)
    } else {
        trimmed
    };

    // Find the first '{' and last '}' to extract the JSON object.
    extract_object_span(without_codeblock).trim().to_string()
}

/// Returns the `{ ... }` span of `s` (first `'{'` through last `'}'`), or `s`
/// unchanged when no well-ordered brace pair exists. Guards against the
/// out-of-order case (`'}'` before `'{'`), which would otherwise panic on
/// slicing with `start > end`.
fn extract_object_span(s: &str) -> &str {
    let start = s.find('{').unwrap_or(0);
    let end = s.rfind('}').map_or(s.len(), |i| i + 1);
    if start < end { &s[start..end] } else { s }
}