// gitai/core/llm.rs

1use crate::config::Config;
2use crate::debug;
3use anyhow::{Result, anyhow};
4use llm::{
5    LLMProvider,
6    builder::{LLMBackend, LLMBuilder},
7    chat::ChatMessage,
8};
9use schemars::JsonSchema;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
/// Built-in fallback settings for a single LLM provider, used when the user's
/// configuration does not specify them.
#[derive(Clone, Debug)]
struct ProviderDefault {
    // Default model identifier for the provider.
    model: &'static str,
    // Default token budget used when no `max_tokens` is configured.
    token_limit: usize,
}
22
23static PROVIDER_DEFAULTS: std::sync::LazyLock<HashMap<&'static str, ProviderDefault>> =
24    std::sync::LazyLock::new(|| {
25        let mut map = HashMap::new();
26        map.insert(
27            "openai",
28            ProviderDefault {
29                model: "gpt-4.1",
30                token_limit: 128_000,
31            },
32        );
33        map.insert(
34            "anthropic",
35            ProviderDefault {
36                model: "claude-sonnet-4-20250514",
37                token_limit: 200_000,
38            },
39        );
40        map.insert(
41            "ollama",
42            ProviderDefault {
43                model: "llama3",
44                token_limit: 128_000,
45            },
46        );
47        map.insert(
48            "google",
49            ProviderDefault {
50                model: "gemini-2.5-pro-preview-06-05",
51                token_limit: 1_000_000,
52            },
53        );
54        map.insert(
55            "groq",
56            ProviderDefault {
57                model: "llama-3.1-70b-versatile",
58                token_limit: 128_000,
59            },
60        );
61        map.insert(
62            "xai",
63            ProviderDefault {
64                model: "grok-2-beta",
65                token_limit: 128_000,
66            },
67        );
68        map.insert(
69            "deepseek",
70            ProviderDefault {
71                model: "deepseek-chat",
72                token_limit: 64_000,
73            },
74        );
75        map.insert(
76            "phind",
77            ProviderDefault {
78                model: "phind-v2",
79                token_limit: 32_000,
80            },
81        );
82        map.insert(
83            "openrouter",
84            ProviderDefault {
85                model: "openrouter/sonoma-dusk-alpha",
86                token_limit: 2_000_000,
87            },
88        );
89        map
90    });
91
92/// Generates a message using the given configuration
93pub async fn get_message<T>(
94    config: &Config,
95    provider_name: &str,
96    system_prompt: &str,
97    user_prompt: &str,
98) -> Result<T>
99where
100    T: DeserializeOwned + JsonSchema,
101{
102    debug!("Generating message using provider: {}", provider_name);
103    debug!("System prompt: {}", system_prompt);
104    debug!("User prompt: {}", user_prompt);
105
106    // Parse the provider type
107    let backend = if provider_name.to_lowercase() == "openrouter" {
108        LLMBackend::OpenRouter
109    } else {
110        LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {e}"))?
111    };
112
113    // Get provider configuration
114    let provider_config = config
115        .get_provider_config(provider_name)
116        .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
117
118    // Build the provider
119    let mut builder = LLMBuilder::new().backend(backend.clone());
120
121    // Set model
122    if !provider_config.model.is_empty() {
123        builder = builder.model(provider_config.model.clone());
124    }
125
126    // Set system prompt
127    builder = builder.system(system_prompt.to_string());
128
129    // Set API key if needed
130    if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
131        builder = builder.api_key(provider_config.api_key.clone());
132    }
133
134    // Set temperature if specified in additional params
135    if let Some(temp) = provider_config.additional_params.get("temperature")
136        && let Ok(temp_val) = temp.parse::<f32>()
137    {
138        builder = builder.temperature(temp_val);
139    }
140
141    // Set max tokens if specified in additional params, otherwise use 4096 as default
142    // For OpenAI thinking models, don't set max_tokens via builder since they use max_completion_tokens
143    if is_openai_thinking_model(&provider_config.model) && provider_name.to_lowercase() == "openai"
144    {
145        // For thinking models, max_completion_tokens should be handled via additional_params
146        // Don't set max_tokens via the builder for these models
147    } else if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
148        if let Ok(mt_val) = max_tokens.parse::<u32>() {
149            builder = builder.max_tokens(mt_val);
150        }
151    } else {
152        let default_max = get_default_token_limit_for_provider(provider_name)
153            .try_into()
154            .map_err(|e| anyhow!("Token limit too large for u32: {e}"))?;
155        builder = builder.max_tokens(default_max);
156    }
157
158    // Set top_p if specified in additional params
159    if let Some(top_p) = provider_config.additional_params.get("top_p")
160        && let Ok(tp_val) = top_p.parse::<f32>()
161    {
162        builder = builder.top_p(tp_val);
163    }
164
165    // Build the provider
166    let provider = builder
167        .build()
168        .map_err(|e| anyhow!("Failed to build provider: {e}"))?;
169
170    // Generate the message
171    get_message_with_provider(provider, user_prompt, provider_name).await
172}
173
174/// Generates a message using the given provider (mainly for testing purposes)
175pub async fn get_message_with_provider<T>(
176    provider: Box<dyn LLMProvider + Send + Sync>,
177    user_prompt: &str,
178    provider_type: &str,
179) -> Result<T>
180where
181    T: DeserializeOwned + JsonSchema,
182{
183    debug!("Entering get_message_with_provider");
184
185    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); // 2 attempts total: initial + 1 retry
186
187    let result = Retry::spawn(retry_strategy, || async {
188        debug!("Attempting to generate message");
189
190        // Enhanced prompt that requests specifically formatted JSON output
191        let enhanced_prompt = if std::any::type_name::<T>() == std::any::type_name::<String>() {
192            user_prompt.to_string()
193        } else {
194            format!("{user_prompt}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.")
195        };
196
197        // Create chat message with user prompt
198        let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];
199
200        // Special handling for Anthropic - use the "prefill" technique with "{"
201        if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
202            messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
203        }
204
205        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
206            Ok(Ok(response)) => {
207                debug!("Received response from provider");
208                let response_text = response.text().unwrap_or_default();
209
210                // Provider-specific response parsing
211                let result = match provider_type.to_lowercase().as_str() {
212                    // For Anthropic with brace prefixing
213                    "anthropic" => {
214                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
215                            // For String type, we need to handle differently
216                            #[allow(clippy::unnecessary_to_owned)]
217                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
218                                .map_err(|e| anyhow!("String conversion error: {e}"))?;
219                            Ok(string_result)
220                        } else {
221                            parse_json_response_with_brace_prefix::<T>(&response_text)
222                        }
223                    },
224
225                    // For all other providers - use appropriate parsing
226                    _ => {
227                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
228                            // For String type, we need to handle differently
229                            #[allow(clippy::unnecessary_to_owned)]
230                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
231                                .map_err(|e| anyhow!("String conversion error: {e}"))?;
232                            Ok(string_result)
233                        } else {
234                            // First try direct parsing, then fall back to extraction
235                            parse_json_response::<T>(&response_text)
236                        }
237                    }
238                };
239
240                match result {
241                    Ok(message) => Ok(message),
242                    Err(e) => {
243                        debug!("JSON parse error: {} text: {}", e, response_text);
244                        Err(anyhow!("JSON parse error: {e}"))
245                    }
246                }
247            }
248            Ok(Err(e)) => {
249                debug!("Provider error: {}", e);
250                Err(anyhow!("Provider error: {e}"))
251            }
252            Err(_) => {
253                debug!("Provider timed out");
254                Err(anyhow!("Provider timed out"))
255            }
256        }
257    })
258    .await;
259
260    match result {
261        Ok(message) => {
262            debug!("Generated message successfully");
263            Ok(message)
264        }
265        Err(e) => {
266            debug!("Failed to generate message after retries: {}", e);
267            Err(anyhow!("Failed to generate message: {e}"))
268        }
269    }
270}
271
272/// Parse a provider's response that should be pure JSON
273fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
274    match serde_json::from_str::<T>(text) {
275        Ok(message) => Ok(message),
276        Err(e) => {
277            // Fallback to a more robust extraction if direct parsing fails
278            debug!(
279                "Direct JSON parse failed: {}. Attempting fallback extraction.",
280                e
281            );
282            extract_and_parse_json(text)
283        }
284    }
285}
286
287/// Parse a response from Anthropic that needs the prefixed "{"
288fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
289    // Add the opening brace that we prefilled in the prompt
290    let json_text = format!("{{{text}");
291    match serde_json::from_str::<T>(&json_text) {
292        Ok(message) => Ok(message),
293        Err(e) => {
294            debug!(
295                "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
296                e
297            );
298            extract_and_parse_json(text)
299        }
300    }
301}
302
303/// Extracts and parses JSON from a potentially non-JSON response
304fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
305    let cleaned_json = clean_json_from_llm(text);
306    serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {e}"))
307}
308
/// Returns a list of available LLM providers as strings
pub fn get_available_provider_names() -> Vec<String> {
    [
        "openai",
        "anthropic",
        "ollama",
        "google",
        "groq",
        "xai",
        "deepseek",
        "phind",
        "openrouter",
    ]
    .iter()
    .map(|name| (*name).to_string())
    .collect()
}
323
324/// Returns the default model for a given provider
325pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
326    PROVIDER_DEFAULTS
327        .get(provider_type.to_lowercase().as_str())
328        .map_or("gpt-4.1", |def| def.model)
329}
330
331/// Returns the default token limit for a given provider
332pub fn get_default_token_limit_for_provider(provider_type: &str) -> usize {
333    PROVIDER_DEFAULTS
334        .get(provider_type.to_lowercase().as_str())
335        .map_or(8_192, |def| def.token_limit)
336}
337
338/// Checks if a provider requires an API key
339pub fn provider_requires_api_key(provider_type: &str) -> bool {
340    if let Ok(backend) = LLMBackend::from_str(provider_type) {
341        requires_api_key(&backend)
342    } else {
343        true // Default to requiring API key for unknown providers
344    }
345}
346
347/// Helper function: check if `LLMBackend` requires API key
348fn requires_api_key(backend: &LLMBackend) -> bool {
349    !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
350}
351
/// Helper function: check if the model is an `OpenAI` thinking model
///
/// Matches the o-series reasoning models ("o1", "o3-mini", "o4-mini", ...),
/// i.e. a leading 'o' followed immediately by a digit. A bare
/// `starts_with('o')` would wrongly match other models whose names begin
/// with 'o' (e.g. "omni-moderation-latest").
fn is_openai_thinking_model(model: &str) -> bool {
    let model_lower = model.to_lowercase();
    let mut chars = model_lower.chars();
    matches!(
        (chars.next(), chars.next()),
        (Some('o'), Some(second)) if second.is_ascii_digit()
    )
}
357
358/// Validates the provider configuration
359pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
360    if provider_requires_api_key(provider_name) {
361        let provider_config = config
362            .get_provider_config(provider_name)
363            .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
364
365        if provider_config.api_key.is_empty() {
366            return Err(anyhow!("API key required for provider: {provider_name}"));
367        }
368    }
369
370    Ok(())
371}
372
373/// Combines default, saved, and command-line configurations
374pub fn get_combined_config<S: ::std::hash::BuildHasher>(
375    config: &Config,
376    provider_name: &str,
377    command_line_args: &HashMap<String, String, S>,
378) -> HashMap<String, String> {
379    let mut combined_params = HashMap::default();
380
381    // Add default values
382    combined_params.insert(
383        "model".to_string(),
384        get_default_model_for_provider(provider_name).to_string(),
385    );
386
387    // Add saved config values if available
388    if let Some(provider_config) = config.get_provider_config(provider_name) {
389        if !provider_config.api_key.is_empty() {
390            combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
391        }
392        if !provider_config.model.is_empty() {
393            combined_params.insert("model".to_string(), provider_config.model.clone());
394        }
395        for (key, value) in &provider_config.additional_params {
396            combined_params.insert(key.clone(), value.clone());
397        }
398    }
399
400    // Add command line args (these take precedence)
401    for (key, value) in command_line_args {
402        if !value.is_empty() {
403            combined_params.insert(key.clone(), value.clone());
404        }
405    }
406
407    // Handle OpenAI thinking models: convert max_tokens to max_completion_tokens
408    if provider_name.to_lowercase() == "openai"
409        && let Some(model) = combined_params.get("model")
410        && is_openai_thinking_model(model)
411        && let Some(max_tokens) = combined_params.remove("max_tokens")
412    {
413        combined_params.insert("max_completion_tokens".to_string(), max_tokens);
414    }
415
416    combined_params
417}
418
/// Best-effort cleanup of an LLM reply so it can be parsed as a JSON object:
/// trims whitespace and non-ASCII (invisible) edge characters, peels off a
/// Markdown code fence if present, and keeps the span from the first '{'
/// through the last '}'.
fn clean_json_from_llm(json_str: &str) -> String {
    // Edge characters to strip: whitespace plus anything non-ASCII (BOMs,
    // zero-width spaces, and similar invisible artifacts).
    let edge = |c: char| c.is_whitespace() || !c.is_ascii();
    let core = json_str.trim_start_matches(edge).trim_end_matches(edge);

    // If the reply is wrapped in a ``` code fence, keep only the {...} payload.
    let candidate = if core.starts_with("```") && core.ends_with("```") {
        let open = core.find('{').unwrap_or(0);
        let close = core.rfind('}').map_or(core.len(), |i| i + 1);
        &core[open..close]
    } else {
        core
    };

    // Extract from the first '{' through the last '}' (whole string if absent).
    let open = candidate.find('{').unwrap_or(0);
    let close = candidate.rfind('}').map_or(candidate.len(), |i| i + 1);

    candidate[open..close].trim().to_string()
}