1use crate::config::{Config, ProviderConfig};
2use crate::llm_providers::{
3 create_provider, get_available_providers, get_provider_metadata, LLMProviderConfig,
4 LLMProviderType,
5};
6use crate::{log_debug, LLMProvider};
7use anyhow::{anyhow, Result};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use std::collections::HashMap;
11use std::time::Duration;
12use tokio_retry::strategy::ExponentialBackoff;
13use tokio_retry::Retry;
14
15pub async fn get_refined_message<T>(
17 config: &Config,
18 provider_type: &LLMProviderType,
19 system_prompt: &str,
20 user_prompt: &str,
21) -> Result<T>
22where
23 T: Serialize + DeserializeOwned + std::fmt::Debug,
24 String: Into<T>,
25{
26 let provider_metadata = get_provider_metadata(provider_type);
28 let provider_config = if provider_metadata.requires_api_key {
29 config
30 .get_provider_config(provider_type.as_ref())
31 .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_type))?
32 .clone()
33 } else {
34 ProviderConfig::default_for(provider_type.as_ref())
35 };
36
37 let llm_provider = create_provider(*provider_type, provider_config.to_llm_provider_config())?;
39
40 log_debug!(
41 "Generating refined message using provider: {}",
42 provider_type
43 );
44 log_debug!("System prompt: {}", system_prompt);
45 log_debug!("User prompt: {}", user_prompt);
46
47 let result =
49 get_refined_message_with_provider::<T>(llm_provider, system_prompt, user_prompt).await?;
50
51 Ok(result)
52}
53
54pub async fn get_refined_message_with_provider<T>(
56 llm_provider: Box<dyn LLMProvider + Send + Sync>,
57 system_prompt: &str,
58 user_prompt: &str,
59) -> Result<T>
60where
61 T: Serialize + DeserializeOwned + std::fmt::Debug,
62 String: Into<T>,
63{
64 log_debug!("Entering get_refined_message_with_provider");
65
66 let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); let result = Retry::spawn(retry_strategy, || async {
69 log_debug!("Attempting to generate message");
70 match tokio::time::timeout(
71 Duration::from_secs(30),
72 llm_provider.generate_message(system_prompt, user_prompt),
73 )
74 .await
75 {
76 Ok(Ok(refined_message)) => {
77 log_debug!("Received response from provider");
78 let cleaned_message = clean_json_from_llm(&refined_message);
79 if std::any::type_name::<T>() == std::any::type_name::<String>() {
80 Ok(cleaned_message.into())
82 } else {
83 match serde_json::from_str::<T>(&cleaned_message) {
85 Ok(message) => Ok(message),
86 Err(e) => {
87 log_debug!("Deserialization error: {} message: {}", e, cleaned_message);
88 Err(anyhow!("Deserialization error: {}", e))
89 }
90 }
91 }
92 }
93 Ok(Err(e)) => {
94 log_debug!("Provider error: {}", e);
95 Err(e)
96 }
97 Err(_) => {
98 log_debug!("Provider timed out");
99 Err(anyhow!("Provider timed out"))
100 }
101 }
102 })
103 .await;
104
105 match result {
106 Ok(message) => {
107 log_debug!("Deserialized message: {:?}", message);
108 Ok(message)
109 }
110 Err(e) => {
111 log_debug!("Failed to generate message after retries: {}", e);
112 Err(anyhow!("Failed to generate message: {}", e))
113 }
114 }
115}
116
117pub fn get_available_provider_names() -> Vec<String> {
119 get_available_providers()
120 .into_iter()
121 .filter(|p| *p != LLMProviderType::Test)
122 .map(|p| p.to_string())
123 .collect()
124}
125
126pub fn get_default_model_for_provider(provider_type: &LLMProviderType) -> &'static str {
128 get_provider_metadata(provider_type).default_model
129}
130
131pub fn get_default_token_limit_for_provider(provider_type: &LLMProviderType) -> Result<usize> {
133 Ok(get_provider_metadata(provider_type).default_token_limit)
134}
135
136pub fn provider_requires_api_key(provider_type: &LLMProviderType) -> bool {
138 get_provider_metadata(provider_type).requires_api_key
139}
140
141pub fn validate_provider_config(config: &Config, provider_type: &LLMProviderType) -> Result<()> {
143 let metadata = get_provider_metadata(provider_type);
144
145 if metadata.requires_api_key {
146 let provider_config = config
147 .get_provider_config(provider_type.as_ref())
148 .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_type))?;
149
150 if provider_config.api_key.is_empty() {
151 return Err(anyhow!("API key required for provider: {}", provider_type));
152 }
153 }
154
155 Ok(())
156}
157
158pub fn get_combined_config(
160 config: &Config,
161 provider_type: &LLMProviderType,
162 command_line_args: &LLMProviderConfig,
163) -> LLMProviderConfig {
164 let default_config = LLMProviderConfig {
165 api_key: String::new(),
166 model: get_default_model_for_provider(provider_type).to_string(),
167 additional_params: HashMap::default(),
168 };
169
170 let saved_config = config
171 .get_provider_config(provider_type.as_ref())
172 .cloned()
173 .unwrap_or_default();
174
175 LLMProviderConfig {
176 api_key: if !command_line_args.api_key.is_empty() {
177 command_line_args.api_key.clone()
178 } else if !saved_config.api_key.is_empty() {
179 saved_config.api_key
180 } else {
181 default_config.api_key
182 },
183 model: if !command_line_args.model.is_empty() {
184 command_line_args.model.clone()
185 } else if !saved_config.model.is_empty() {
186 saved_config.model
187 } else {
188 default_config.model
189 },
190 additional_params: if !command_line_args.additional_params.is_empty() {
191 command_line_args.additional_params.clone()
192 } else if !saved_config.additional_params.is_empty() {
193 saved_config.additional_params
194 } else {
195 default_config.additional_params
196 },
197 }
198}
199
/// Extracts the JSON object embedded in a raw LLM response.
///
/// Leading and trailing characters that are whitespace or non-ASCII are
/// stripped first. The result is then narrowed to the span from the first `{`
/// to the last `}` that follows it, which also discards Markdown code fences
/// (```` ``` ````) or prose surrounding the object. When no `{` is present the
/// span starts at the beginning of the text; when no `}` follows the start the
/// span runs to the end of the text. The selected span is trimmed and returned
/// as an owned `String`.
///
/// The closing brace is searched only after the opening brace, so input whose
/// last `}` precedes its first `{` (for example `"} {"`) cannot produce an
/// inverted range and panic.
fn clean_json_from_llm(json_str: &str) -> String {
    let is_noise = |c: char| c.is_whitespace() || !c.is_ascii();
    let trimmed = json_str
        .trim_start_matches(is_noise)
        .trim_end_matches(is_noise);

    let start = trimmed.find('{').unwrap_or(0);
    // Search from `start` so that `end` can never be smaller than `start`.
    // Both indices come from `find`/`rfind` on ASCII braces, so they always
    // fall on character boundaries.
    let end = trimmed[start..]
        .rfind('}')
        .map_or(trimmed.len(), |offset| start + offset + 1);

    trimmed[start..end].trim().to_string()
}