git_iris/
llm.rs

1use crate::config::Config;
2use crate::log_debug;
3use anyhow::{Result, anyhow};
4use llm::{
5    LLMProvider,
6    builder::{LLMBackend, LLMBuilder},
7    chat::ChatMessage,
8};
9use schemars::JsonSchema;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
17/// Generates a message using the given configuration
18pub async fn get_message<T>(
19    config: &Config,
20    provider_name: &str,
21    system_prompt: &str,
22    user_prompt: &str,
23) -> Result<T>
24where
25    T: DeserializeOwned + JsonSchema,
26{
27    log_debug!("Generating message using provider: {}", provider_name);
28    log_debug!("System prompt: {}", system_prompt);
29    log_debug!("User prompt: {}", user_prompt);
30
31    // Parse the provider type
32    let backend =
33        LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {}", e))?;
34
35    // Get provider configuration
36    let provider_config = config
37        .get_provider_config(provider_name)
38        .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
39
40    // Build the provider
41    let mut builder = LLMBuilder::new().backend(backend.clone());
42
43    // Set model
44    if !provider_config.model.is_empty() {
45        builder = builder.model(provider_config.model.clone());
46    }
47
48    // Set system prompt
49    builder = builder.system(system_prompt.to_string());
50
51    // Set API key if needed
52    if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
53        builder = builder.api_key(provider_config.api_key.clone());
54    }
55
56    // Set temperature if specified in additional params
57    if let Some(temp) = provider_config.additional_params.get("temperature") {
58        if let Ok(temp_val) = temp.parse::<f32>() {
59            builder = builder.temperature(temp_val);
60        }
61    }
62
63    // Set max tokens if specified in additional params, otherwise use 4096 as default
64    if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
65        if let Ok(mt_val) = max_tokens.parse::<u32>() {
66            builder = builder.max_tokens(mt_val);
67        }
68    } else {
69        builder = builder.max_tokens(4096);
70    }
71
72    // Set top_p if specified in additional params
73    if let Some(top_p) = provider_config.additional_params.get("top_p") {
74        if let Ok(tp_val) = top_p.parse::<f32>() {
75            builder = builder.top_p(tp_val);
76        }
77    }
78
79    // Build the provider
80    let provider = builder
81        .build()
82        .map_err(|e| anyhow!("Failed to build provider: {}", e))?;
83
84    // Generate the message
85    get_message_with_provider(provider, user_prompt, provider_name).await
86}
87
88/// Generates a message using the given provider (mainly for testing purposes)
89pub async fn get_message_with_provider<T>(
90    provider: Box<dyn LLMProvider + Send + Sync>,
91    user_prompt: &str,
92    provider_type: &str,
93) -> Result<T>
94where
95    T: DeserializeOwned + JsonSchema,
96{
97    log_debug!("Entering get_message_with_provider");
98
99    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); // 2 attempts total: initial + 1 retry
100
101    let result = Retry::spawn(retry_strategy, || async {
102        log_debug!("Attempting to generate message");
103
104        // Enhanced prompt that requests specifically formatted JSON output
105        let enhanced_prompt = if std::any::type_name::<T>() == std::any::type_name::<String>() {
106            user_prompt.to_string()
107        } else {
108            format!("{user_prompt}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.")
109        };
110
111        // Create chat message with user prompt
112        let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];
113
114        // Special handling for Anthropic - use the "prefill" technique with "{"
115        if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
116            messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
117        }
118
119        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
120            Ok(Ok(response)) => {
121                log_debug!("Received response from provider");
122                let response_text = response.text().unwrap_or_default();
123
124                // Provider-specific response parsing
125                let result = match provider_type.to_lowercase().as_str() {
126                    // For Anthropic with brace prefixing
127                    "anthropic" => {
128                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
129                            // For String type, we need to handle differently
130                            #[allow(clippy::unnecessary_to_owned)]
131                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
132                                .map_err(|e| anyhow!("String conversion error: {}", e))?;
133                            Ok(string_result)
134                        } else {
135                            parse_json_response_with_brace_prefix::<T>(&response_text)
136                        }
137                    },
138
139                    // For all other providers - use appropriate parsing
140                    _ => {
141                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
142                            // For String type, we need to handle differently
143                            #[allow(clippy::unnecessary_to_owned)]
144                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
145                                .map_err(|e| anyhow!("String conversion error: {}", e))?;
146                            Ok(string_result)
147                        } else {
148                            // First try direct parsing, then fall back to extraction
149                            parse_json_response::<T>(&response_text)
150                        }
151                    }
152                };
153
154                match result {
155                    Ok(message) => Ok(message),
156                    Err(e) => {
157                        log_debug!("JSON parse error: {} text: {}", e, response_text);
158                        Err(anyhow!("JSON parse error: {}", e))
159                    }
160                }
161            }
162            Ok(Err(e)) => {
163                log_debug!("Provider error: {}", e);
164                Err(anyhow!("Provider error: {}", e))
165            }
166            Err(_) => {
167                log_debug!("Provider timed out");
168                Err(anyhow!("Provider timed out"))
169            }
170        }
171    })
172    .await;
173
174    match result {
175        Ok(message) => {
176            log_debug!("Generated message successfully");
177            Ok(message)
178        }
179        Err(e) => {
180            log_debug!("Failed to generate message after retries: {}", e);
181            Err(anyhow!("Failed to generate message: {}", e))
182        }
183    }
184}
185
186/// Parse a provider's response that should be pure JSON
187fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
188    match serde_json::from_str::<T>(text) {
189        Ok(message) => Ok(message),
190        Err(e) => {
191            // Fallback to a more robust extraction if direct parsing fails
192            log_debug!(
193                "Direct JSON parse failed: {}. Attempting fallback extraction.",
194                e
195            );
196            extract_and_parse_json(text)
197        }
198    }
199}
200
201/// Parse a response from Anthropic that needs the prefixed "{"
202fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
203    // Add the opening brace that we prefilled in the prompt
204    let json_text = format!("{{{text}");
205    match serde_json::from_str::<T>(&json_text) {
206        Ok(message) => Ok(message),
207        Err(e) => {
208            log_debug!(
209                "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
210                e
211            );
212            extract_and_parse_json(text)
213        }
214    }
215}
216
217/// Extracts and parses JSON from a potentially non-JSON response
218fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
219    let cleaned_json = clean_json_from_llm(text);
220    serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {}", e))
221}
222
/// Returns the identifiers of every LLM provider this module knows about, in a
/// fixed order.
pub fn get_available_provider_names() -> Vec<String> {
    const PROVIDERS: [&str; 8] = [
        "openai",
        "anthropic",
        "ollama",
        "google",
        "groq",
        "xai",
        "deepseek",
        "phind",
    ];
    PROVIDERS.iter().map(|name| (*name).to_owned()).collect()
}
236
/// Returns the default model name for `provider_type` (matched case-insensitively).
///
/// Unknown providers, and `"openai"` itself, get OpenAI's `"gpt-4.1"`.
pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
    static DEFAULT_MODELS: [(&str, &str); 8] = [
        ("openai", "gpt-4.1"),
        ("anthropic", "claude-3-7-sonnet-latest"),
        ("ollama", "llama3"),
        ("google", "gemini-2.0-flash"),
        ("groq", "llama-3.1-70b-versatile"),
        ("xai", "grok-2-beta"),
        ("deepseek", "deepseek-chat"),
        ("phind", "phind-v2"),
    ];

    let wanted = provider_type.to_lowercase();
    DEFAULT_MODELS
        .iter()
        .find(|&&(name, _)| name == wanted)
        .map_or("gpt-4.1", |&(_, model)| model)
}
250
251/// Returns the default token limit for a given provider
252pub fn get_default_token_limit_for_provider(provider_type: &str) -> Result<usize> {
253    let limit = match provider_type.to_lowercase().as_str() {
254        "anthropic" => 200_000,
255        "ollama" | "openai" | "groq" | "xai" => 128_000,
256        "google" => 1_000_000,
257        "deepseek" => 64_000,
258        "phind" => 32_000,
259        _ => 8_192, // Default token limit
260    };
261    Ok(limit)
262}
263
264/// Checks if a provider requires an API key
265pub fn provider_requires_api_key(provider_type: &str) -> bool {
266    if let Ok(backend) = LLMBackend::from_str(provider_type) {
267        requires_api_key(&backend)
268    } else {
269        true // Default to requiring API key for unknown providers
270    }
271}
272
273/// Helper function: check if `LLMBackend` requires API key
274fn requires_api_key(backend: &LLMBackend) -> bool {
275    !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
276}
277
278/// Validates the provider configuration
279pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
280    if provider_requires_api_key(provider_name) {
281        let provider_config = config
282            .get_provider_config(provider_name)
283            .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
284
285        if provider_config.api_key.is_empty() {
286            return Err(anyhow!("API key required for provider: {}", provider_name));
287        }
288    }
289
290    Ok(())
291}
292
293/// Combines default, saved, and command-line configurations
294pub fn get_combined_config<S: ::std::hash::BuildHasher>(
295    config: &Config,
296    provider_name: &str,
297    command_line_args: &HashMap<String, String, S>,
298) -> HashMap<String, String> {
299    let mut combined_params = HashMap::default();
300
301    // Add default values
302    combined_params.insert(
303        "model".to_string(),
304        get_default_model_for_provider(provider_name).to_string(),
305    );
306
307    // Add saved config values if available
308    if let Some(provider_config) = config.get_provider_config(provider_name) {
309        if !provider_config.api_key.is_empty() {
310            combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
311        }
312        if !provider_config.model.is_empty() {
313            combined_params.insert("model".to_string(), provider_config.model.clone());
314        }
315        for (key, value) in &provider_config.additional_params {
316            combined_params.insert(key.clone(), value.clone());
317        }
318    }
319
320    // Add command line args (these take precedence)
321    for (key, value) in command_line_args {
322        if !value.is_empty() {
323            combined_params.insert(key.clone(), value.clone());
324        }
325    }
326
327    combined_params
328}
329
/// Extracts the most likely JSON object from free-form LLM output.
///
/// Steps:
/// 1. Strip leading and trailing whitespace and non-ASCII characters (e.g. a BOM
///    or zero-width spaces).
/// 2. Return the span from the first `{` through the last `}` that follows it.
///    This also removes Markdown code fences and surrounding prose.
///
/// Fallbacks:
/// - No `{` found: the span starts at the beginning of the text.
/// - No `}` after the opening brace: the span runs to the end of the text.
///
/// A stray `}` appearing before the first `{` can no longer produce an inverted
/// (panicking) slice range.
fn clean_json_from_llm(json_str: &str) -> String {
    let trimmed = json_str
        .trim_start_matches(|c: char| c.is_whitespace() || !c.is_ascii())
        .trim_end_matches(|c: char| c.is_whitespace() || !c.is_ascii());

    // Searching for the closing brace only *after* the opening one keeps the
    // range well-formed. Both indices are at ASCII characters, so they are
    // valid char boundaries.
    let start = trimmed.find('{').unwrap_or(0);
    let candidate = &trimmed[start..];
    let end = candidate.rfind('}').map_or(candidate.len(), |i| i + 1);

    candidate[..end].trim().to_string()
}