git_iris/
llm.rs

1use crate::config::{Config, ProviderConfig};
2use crate::llm_providers::{
3    create_provider, get_available_providers, get_provider_metadata, LLMProviderConfig,
4    LLMProviderType,
5};
6use crate::{log_debug, LLMProvider};
7use anyhow::{anyhow, Result};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use std::collections::HashMap;
11use std::time::Duration;
12use tokio_retry::strategy::ExponentialBackoff;
13use tokio_retry::Retry;
14
15/// Generates a message using the given configuration
16pub async fn get_refined_message<T>(
17    config: &Config,
18    provider_type: &LLMProviderType,
19    system_prompt: &str,
20    user_prompt: &str,
21) -> Result<T>
22where
23    T: Serialize + DeserializeOwned + std::fmt::Debug,
24    String: Into<T>,
25{
26    // Get provider metadata and configuration
27    let provider_metadata = get_provider_metadata(provider_type);
28    let provider_config = if provider_metadata.requires_api_key {
29        config
30            .get_provider_config(provider_type.as_ref())
31            .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_type))?
32            .clone()
33    } else {
34        ProviderConfig::default_for(provider_type.as_ref())
35    };
36
37    // Create the LLM provider instance
38    let llm_provider = create_provider(*provider_type, provider_config.to_llm_provider_config())?;
39
40    log_debug!(
41        "Generating refined message using provider: {}",
42        provider_type
43    );
44    log_debug!("System prompt: {}", system_prompt);
45    log_debug!("User prompt: {}", user_prompt);
46
47    // Call get_refined_message_with_provider
48    let result =
49        get_refined_message_with_provider::<T>(llm_provider, system_prompt, user_prompt).await?;
50
51    Ok(result)
52}
53
54/// Generates a message using the given provider (mainly for testing purposes)
55pub async fn get_refined_message_with_provider<T>(
56    llm_provider: Box<dyn LLMProvider + Send + Sync>,
57    system_prompt: &str,
58    user_prompt: &str,
59) -> Result<T>
60where
61    T: Serialize + DeserializeOwned + std::fmt::Debug,
62    String: Into<T>,
63{
64    log_debug!("Entering get_refined_message_with_provider");
65
66    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); // 2 attempts total: initial + 1 retry
67
68    let result = Retry::spawn(retry_strategy, || async {
69        log_debug!("Attempting to generate message");
70        match tokio::time::timeout(
71            Duration::from_secs(30),
72            llm_provider.generate_message(system_prompt, user_prompt),
73        )
74        .await
75        {
76            Ok(Ok(refined_message)) => {
77                log_debug!("Received response from provider");
78                let cleaned_message = clean_json_from_llm(&refined_message);
79                if std::any::type_name::<T>() == std::any::type_name::<String>() {
80                    // If T is String, return the raw string response
81                    Ok(cleaned_message.into())
82                } else {
83                    // Attempt to deserialize the response
84                    match serde_json::from_str::<T>(&cleaned_message) {
85                        Ok(message) => Ok(message),
86                        Err(e) => {
87                            log_debug!("Deserialization error: {} message: {}", e, cleaned_message);
88                            Err(anyhow!("Deserialization error: {}", e))
89                        }
90                    }
91                }
92            }
93            Ok(Err(e)) => {
94                log_debug!("Provider error: {}", e);
95                Err(e)
96            }
97            Err(_) => {
98                log_debug!("Provider timed out");
99                Err(anyhow!("Provider timed out"))
100            }
101        }
102    })
103    .await;
104
105    match result {
106        Ok(message) => {
107            log_debug!("Deserialized message: {:?}", message);
108            Ok(message)
109        }
110        Err(e) => {
111            log_debug!("Failed to generate message after retries: {}", e);
112            Err(anyhow!("Failed to generate message: {}", e))
113        }
114    }
115}
116
117/// Returns a list of available LLM providers as strings
118pub fn get_available_provider_names() -> Vec<String> {
119    get_available_providers()
120        .into_iter()
121        .filter(|p| *p != LLMProviderType::Test)
122        .map(|p| p.to_string())
123        .collect()
124}
125
126/// Returns the default model for a given provider
127pub fn get_default_model_for_provider(provider_type: &LLMProviderType) -> &'static str {
128    get_provider_metadata(provider_type).default_model
129}
130
131/// Returns the default token limit for a given provider
132pub fn get_default_token_limit_for_provider(provider_type: &LLMProviderType) -> Result<usize> {
133    Ok(get_provider_metadata(provider_type).default_token_limit)
134}
135
136/// Checks if a provider requires an API key
137pub fn provider_requires_api_key(provider_type: &LLMProviderType) -> bool {
138    get_provider_metadata(provider_type).requires_api_key
139}
140
141/// Validates the provider configuration
142pub fn validate_provider_config(config: &Config, provider_type: &LLMProviderType) -> Result<()> {
143    let metadata = get_provider_metadata(provider_type);
144
145    if metadata.requires_api_key {
146        let provider_config = config
147            .get_provider_config(provider_type.as_ref())
148            .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_type))?;
149
150        if provider_config.api_key.is_empty() {
151            return Err(anyhow!("API key required for provider: {}", provider_type));
152        }
153    }
154
155    Ok(())
156}
157
158/// Combines default, saved, and command-line configurations
159pub fn get_combined_config(
160    config: &Config,
161    provider_type: &LLMProviderType,
162    command_line_args: &LLMProviderConfig,
163) -> LLMProviderConfig {
164    let default_config = LLMProviderConfig {
165        api_key: String::new(),
166        model: get_default_model_for_provider(provider_type).to_string(),
167        additional_params: HashMap::default(),
168    };
169
170    let saved_config = config
171        .get_provider_config(provider_type.as_ref())
172        .cloned()
173        .unwrap_or_default();
174
175    LLMProviderConfig {
176        api_key: if !command_line_args.api_key.is_empty() {
177            command_line_args.api_key.clone()
178        } else if !saved_config.api_key.is_empty() {
179            saved_config.api_key
180        } else {
181            default_config.api_key
182        },
183        model: if !command_line_args.model.is_empty() {
184            command_line_args.model.clone()
185        } else if !saved_config.model.is_empty() {
186            saved_config.model
187        } else {
188            default_config.model
189        },
190        additional_params: if !command_line_args.additional_params.is_empty() {
191            command_line_args.additional_params.clone()
192        } else if !saved_config.additional_params.is_empty() {
193            saved_config.additional_params
194        } else {
195            default_config.additional_params
196        },
197    }
198}
199
/// Extracts the JSON payload from a raw LLM response.
///
/// Leading and trailing whitespace and non-ASCII characters (such as a byte-order mark)
/// are discarded first. Then:
///
/// * If the text contains braces, everything from the first `{` through the last `}` is
///   returned. This also drops any Markdown fence or prose surrounding the object. If only
///   one kind of brace is present, the text is cut at that brace alone.
/// * If the last `}` appears before the first `{`, there is no object to extract and the
///   trimmed text is returned unchanged, so the caller's parser can report the problem.
/// * If the text contains no braces at all, a surrounding Markdown code fence (including
///   its info-string line, e.g. `json`) is removed.
///
/// The result is whitespace-trimmed before being returned.
fn clean_json_from_llm(json_str: &str) -> String {
    let trimmed = json_str.trim_matches(|c: char| c.is_whitespace() || !c.is_ascii());

    let payload = match (trimmed.find('{'), trimmed.rfind('}')) {
        (Some(start), Some(last)) if start < last => &trimmed[start..=last],
        // The closing brace precedes the opening one: slicing `start..=last` would panic.
        (Some(_), Some(_)) => trimmed,
        (Some(start), None) => &trimmed[start..],
        (None, Some(last)) => &trimmed[..=last],
        (None, None) => strip_code_fence(trimmed),
    };

    payload.trim().to_string()
}

/// Returns the content of `text` without its Markdown code fence, or `text` itself when
/// it is not wrapped in one.
fn strip_code_fence(text: &str) -> &str {
    match text
        .strip_prefix("```")
        .and_then(|rest| rest.strip_suffix("```"))
    {
        // Whatever shares the line with the opening fence is the info string, not content.
        Some(inner) => inner.split_once('\n').map_or(inner, |(_info, body)| body),
        None => text,
    }
}