// gitai/core/llm.rs

1use crate::config::Config;
2use crate::debug;
3use anyhow::{Result, anyhow};
4use llm::{
5    LLMProvider,
6    builder::{LLMBackend, LLMBuilder},
7    chat::ChatMessage,
8};
9use schemars::JsonSchema;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
/// Built-in defaults for a single LLM provider, keyed by provider name in
/// [`PROVIDER_DEFAULTS`].
#[derive(Clone, Debug)]
struct ProviderDefault {
    /// Model identifier used when the configuration does not name one.
    model: &'static str,
    /// Token budget used when `max_tokens` is not supplied in the config.
    token_limit: usize,
}
22
23static PROVIDER_DEFAULTS: std::sync::LazyLock<HashMap<&'static str, ProviderDefault>> =
24    std::sync::LazyLock::new(|| {
25        let mut map = HashMap::new();
26        map.insert(
27            "openai",
28            ProviderDefault {
29                model: "gpt-4.1",
30                token_limit: 128_000,
31            },
32        );
33        map.insert(
34            "anthropic",
35            ProviderDefault {
36                model: "claude-sonnet-4-20250514",
37                token_limit: 200_000,
38            },
39        );
40        map.insert(
41            "ollama",
42            ProviderDefault {
43                model: "llama3",
44                token_limit: 128_000,
45            },
46        );
47        map.insert(
48            "google",
49            ProviderDefault {
50                model: "gemini-2.5-pro-preview-06-05",
51                token_limit: 1_000_000,
52            },
53        );
54        map.insert(
55            "groq",
56            ProviderDefault {
57                model: "llama-3.1-70b-versatile",
58                token_limit: 128_000,
59            },
60        );
61        map.insert(
62            "xai",
63            ProviderDefault {
64                model: "grok-2-beta",
65                token_limit: 128_000,
66            },
67        );
68        map.insert(
69            "deepseek",
70            ProviderDefault {
71                model: "deepseek-chat",
72                token_limit: 64_000,
73            },
74        );
75        map.insert(
76            "phind",
77            ProviderDefault {
78                model: "phind-v2",
79                token_limit: 32_000,
80            },
81        );
82        map.insert(
83            "openrouter",
84            ProviderDefault {
85                model: "openrouter/sonoma-dusk-alpha",
86                token_limit: 2_000_000,
87            },
88        );
89        map
90    });
91
/// Generates a message using the given configuration.
///
/// Builds an LLM provider from `config` (model, API key, temperature,
/// max tokens, top_p) and delegates the request to
/// [`get_message_with_provider`], deserializing the response into `T`.
///
/// # Errors
/// Fails when `provider_name` is not a recognized backend, is missing from
/// the configuration, the provider cannot be built, or generation/parsing
/// fails downstream.
pub async fn get_message<T>(
    config: &Config,
    provider_name: &str,
    system_prompt: &str,
    user_prompt: &str,
) -> Result<T>
where
    T: DeserializeOwned + JsonSchema,
{
    debug!("Generating message using provider: {}", provider_name);
    debug!("System prompt: {}", system_prompt);
    debug!("User prompt: {}", user_prompt);

    // Parse the provider type. "openrouter" is mapped explicitly rather than
    // going through `LLMBackend::from_str`.
    let backend = if provider_name.to_lowercase() == "openrouter" {
        LLMBackend::OpenRouter
    } else {
        LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {e}"))?
    };

    // Get provider configuration
    let provider_config = config
        .get_provider_config(provider_name)
        .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;

    // Build the provider
    let mut builder = LLMBuilder::new().backend(backend.clone());

    // Set model only when configured; otherwise the backend's own default applies.
    if !provider_config.model_name.is_empty() {
        builder = builder.model(provider_config.model_name.clone());
    }

    // Set system prompt
    builder = builder.system(system_prompt.to_string());

    // Set API key only for backends that need one (see `requires_api_key`).
    if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
        builder = builder.api_key(provider_config.api_key.clone());
    }

    // Set temperature if specified in additional params (silently skipped if
    // the value does not parse as f32).
    if let Some(temp) = provider_config.additional_params.get("temperature")
        && let Ok(temp_val) = temp.parse::<f32>()
    {
        builder = builder.temperature(temp_val);
    }

    // Set max tokens: prefer `max_tokens` from additional params, otherwise
    // fall back to the provider's default token limit.
    // For OpenAI thinking models, don't set max_tokens via the builder since
    // they use max_completion_tokens instead.
    if is_openai_thinking_model(&provider_config.model_name)
        && provider_name.to_lowercase() == "openai"
    {
        // For thinking models, max_completion_tokens should be handled via additional_params
        // Don't set max_tokens via the builder for these models
    } else if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
        if let Ok(mt_val) = max_tokens.parse::<u32>() {
            builder = builder.max_tokens(mt_val);
        }
    } else {
        let default_max = get_default_token_limit_for_provider(provider_name)
            .try_into()
            .map_err(|e| anyhow!("Token limit too large for u32: {e}"))?;
        builder = builder.max_tokens(default_max);
    }

    // Set top_p if specified in additional params (silently skipped on parse failure).
    if let Some(top_p) = provider_config.additional_params.get("top_p")
        && let Ok(tp_val) = top_p.parse::<f32>()
    {
        builder = builder.top_p(tp_val);
    }

    // Build the provider
    let provider = builder
        .build()
        .map_err(|e| anyhow!("Failed to build provider: {e}"))?;

    // Generate the message
    get_message_with_provider(provider, user_prompt, provider_name).await
}
174
/// Generates a message using the given provider (mainly for testing purposes).
///
/// Wraps `provider.chat` with a 30-second timeout per attempt and retries
/// with exponential backoff. When `T` is not `String`, the prompt is extended
/// to request JSON-only output and the response is parsed as JSON; for
/// Anthropic an assistant "prefill" message ending in `{` is appended so the
/// model continues an existing JSON object.
///
/// # Errors
/// Fails when the provider errors, the request times out, or the response
/// cannot be parsed into `T` after all retries.
pub async fn get_message_with_provider<T>(
    provider: Box<dyn LLMProvider + Send + Sync>,
    user_prompt: &str,
    provider_type: &str,
) -> Result<T>
where
    T: DeserializeOwned + JsonSchema,
{
    debug!("Entering get_message_with_provider");

    // NOTE(review): tokio_retry treats each delay yielded by the strategy as
    // one retry, so `take(2)` likely permits 3 total attempts (initial + 2
    // retries), not the 2 stated here — confirm intent against tokio_retry docs.
    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); // 2 attempts total: initial + 1 retry

    let result = Retry::spawn(retry_strategy, || async {
        debug!("Attempting to generate message");

        // Enhanced prompt that requests specifically formatted JSON output.
        // Plain `String` targets get the prompt untouched.
        let enhanced_prompt = if std::any::type_name::<T>() == std::any::type_name::<String>() {
            user_prompt.to_string()
        } else {
            format!("{user_prompt}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.")
        };

        // Create chat message with user prompt
        let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];

        // Special handling for Anthropic - use the "prefill" technique with "{"
        // so the model's completion starts inside a JSON object. The matching
        // brace is re-attached in `parse_json_response_with_brace_prefix`.
        if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
            messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
        }

        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
            Ok(Ok(response)) => {
                debug!("Received response from provider");
                let response_text = response.text().unwrap_or_default();

                // Provider-specific response parsing
                let result = match provider_type.to_lowercase().as_str() {
                    // For Anthropic with brace prefixing
                    "anthropic" => {
                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
                            // For String type, round-trip through a JSON string
                            // value to satisfy the generic `T` bound.
                            #[allow(clippy::unnecessary_to_owned)]
                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
                                .map_err(|e| anyhow!("String conversion error: {e}"))?;
                            Ok(string_result)
                        } else {
                            parse_json_response_with_brace_prefix::<T>(&response_text)
                        }
                    },

                    // For all other providers - use appropriate parsing
                    _ => {
                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
                            // For String type, round-trip through a JSON string
                            // value to satisfy the generic `T` bound.
                            #[allow(clippy::unnecessary_to_owned)]
                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
                                .map_err(|e| anyhow!("String conversion error: {e}"))?;
                            Ok(string_result)
                        } else {
                            // First try direct parsing, then fall back to extraction
                            parse_json_response::<T>(&response_text)
                        }
                    }
                };

                match result {
                    Ok(message) => Ok(message),
                    Err(e) => {
                        debug!("JSON parse error: {} text: {}", e, response_text);
                        Err(anyhow!("JSON parse error: {e}"))
                    }
                }
            }
            Ok(Err(e)) => {
                debug!("Provider error: {}", e);
                Err(anyhow!("Provider error: {e}"))
            }
            Err(_) => {
                debug!("Provider timed out");
                Err(anyhow!("Provider timed out"))
            }
        }
    })
    .await;

    match result {
        Ok(message) => {
            debug!("Generated message successfully");
            Ok(message)
        }
        Err(e) => {
            debug!("Failed to generate message after retries: {}", e);
            Err(anyhow!("Failed to generate message: {e}"))
        }
    }
}
272
273/// Parse a provider's response that should be pure JSON
274fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
275    match serde_json::from_str::<T>(text) {
276        Ok(message) => Ok(message),
277        Err(e) => {
278            // Fallback to a more robust extraction if direct parsing fails
279            debug!(
280                "Direct JSON parse failed: {}. Attempting fallback extraction.",
281                e
282            );
283            extract_and_parse_json(text)
284        }
285    }
286}
287
288/// Parse a response from Anthropic that needs the prefixed "{"
289fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
290    // Add the opening brace that we prefilled in the prompt
291    let json_text = format!("{{{text}");
292    match serde_json::from_str::<T>(&json_text) {
293        Ok(message) => Ok(message),
294        Err(e) => {
295            debug!(
296                "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
297                e
298            );
299            extract_and_parse_json(text)
300        }
301    }
302}
303
304/// Extracts and parses JSON from a potentially non-JSON response
305fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
306    let cleaned_json = clean_json_from_llm(text);
307    serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {e}"))
308}
309
/// Returns a list of available LLM providers as strings.
pub fn get_available_provider_names() -> Vec<String> {
    // Order matters to callers that display this list; keep it stable.
    [
        "openai",
        "anthropic",
        "ollama",
        "google",
        "groq",
        "xai",
        "deepseek",
        "phind",
        "openrouter",
    ]
    .iter()
    .map(|name| (*name).to_string())
    .collect()
}
324
325/// Returns the default model for a given provider
326pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
327    PROVIDER_DEFAULTS
328        .get(provider_type.to_lowercase().as_str())
329        .map_or("gpt-4.1", |def| def.model)
330}
331
332/// Returns the default token limit for a given provider
333pub fn get_default_token_limit_for_provider(provider_type: &str) -> usize {
334    PROVIDER_DEFAULTS
335        .get(provider_type.to_lowercase().as_str())
336        .map_or(8_192, |def| def.token_limit)
337}
338
339/// Checks if a provider requires an API key
340pub fn provider_requires_api_key(provider_type: &str) -> bool {
341    if let Ok(backend) = LLMBackend::from_str(provider_type) {
342        requires_api_key(&backend)
343    } else {
344        true // Default to requiring API key for unknown providers
345    }
346}
347
348/// Helper function: check if `LLMBackend` requires API key
349fn requires_api_key(backend: &LLMBackend) -> bool {
350    !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
351}
352
/// Helper function: check if the model is an `OpenAI` thinking model.
///
/// NOTE(review): this matches ANY model whose lowercase name starts with 'o'
/// (e.g. "open-mistral"), not just o1/o3/o4-style reasoning models — confirm
/// whether the broad match is intentional.
fn is_openai_thinking_model(model: &str) -> bool {
    model
        .chars()
        .next()
        .is_some_and(|c| c.to_ascii_lowercase() == 'o')
}
358
359/// Validates the provider configuration
360pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
361    if provider_requires_api_key(provider_name) {
362        let provider_config = config
363            .get_provider_config(provider_name)
364            .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
365
366        if provider_config.api_key.is_empty() {
367            return Err(anyhow!("API key required for provider: {provider_name}"));
368        }
369    }
370
371    Ok(())
372}
373
374/// Combines default, saved, and command-line configurations
375pub fn get_combined_config<S: ::std::hash::BuildHasher>(
376    config: &Config,
377    provider_name: &str,
378    command_line_args: &HashMap<String, String, S>,
379) -> HashMap<String, String> {
380    let mut combined_params = HashMap::default();
381
382    // Add default values
383    combined_params.insert(
384        "model".to_string(),
385        get_default_model_for_provider(provider_name).to_string(),
386    );
387
388    // Add saved config values if available
389    if let Some(provider_config) = config.get_provider_config(provider_name) {
390        if !provider_config.api_key.is_empty() {
391            combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
392        }
393        if !provider_config.model_name.is_empty() {
394            combined_params.insert("model".to_string(), provider_config.model_name.clone());
395        }
396        for (key, value) in &provider_config.additional_params {
397            combined_params.insert(key.clone(), value.clone());
398        }
399    }
400
401    // Add command line args (these take precedence)
402    for (key, value) in command_line_args {
403        if !value.is_empty() {
404            combined_params.insert(key.clone(), value.clone());
405        }
406    }
407
408    // Handle OpenAI thinking models: convert max_tokens to max_completion_tokens
409    if provider_name.to_lowercase() == "openai"
410        && let Some(model) = combined_params.get("model")
411        && is_openai_thinking_model(model)
412        && let Some(max_tokens) = combined_params.remove("max_tokens")
413    {
414        combined_params.insert("max_completion_tokens".to_string(), max_tokens);
415    }
416
417    combined_params
418}
419
/// Extracts the JSON object embedded in raw LLM output.
///
/// Strips leading/trailing whitespace and non-ASCII noise (e.g. BOM or other
/// invisible characters), then slices from the first '{' to the last '}' —
/// which also discards Markdown code-fence markers and surrounding prose.
fn clean_json_from_llm(json_str: &str) -> String {
    // Remove potential leading/trailing whitespace and invisible characters.
    let trimmed = json_str
        .trim_start_matches(|c: char| c.is_whitespace() || !c.is_ascii())
        .trim_end_matches(|c: char| c.is_whitespace() || !c.is_ascii());

    // Find the first '{' and last '}' to extract the JSON object. The old
    // dedicated ``` code-fence branch computed exactly this same span, so it
    // was redundant and has been removed.
    let start = trimmed.find('{').unwrap_or(0);
    let end = trimmed.rfind('}').map_or(trimmed.len(), |i| i + 1);

    // Guard against pathological input such as "} {" where the first '{'
    // comes after the last '}': slicing a reversed range would panic.
    if start >= end {
        return trimmed.trim().to_string();
    }

    trimmed[start..end].trim().to_string()
}