1use crate::config::Config;
2use crate::log_debug;
3use anyhow::{Result, anyhow};
4use llm::{
5 LLMProvider,
6 builder::{LLMBackend, LLMBuilder},
7 chat::ChatMessage,
8};
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
/// Generates a message of type `T` using the provider named `provider_name`.
///
/// Looks the provider up in `config`, builds an `llm` provider with the
/// configured model, API key (only for backends that need one), and the
/// optional `temperature`, `max_tokens`, and `top_p` tuning parameters from
/// `additional_params`, then delegates the chat call to
/// [`get_message_with_provider`].
///
/// # Errors
/// Fails when `provider_name` is not a recognized `LLMBackend`, when the
/// provider has no entry in `config`, when the provider fails to build, or
/// when the downstream request/deserialization fails.
pub async fn get_message<T>(
    config: &Config,
    provider_name: &str,
    system_prompt: &str,
    user_prompt: &str,
) -> Result<T>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
    // `String: Into<T>` allows the raw response text to be returned directly
    // when the caller asks for a plain `String` (see get_message_with_provider).
    String: Into<T>,
{
    log_debug!("Generating message using provider: {}", provider_name);
    log_debug!("System prompt: {}", system_prompt);
    log_debug!("User prompt: {}", user_prompt);

    let backend =
        LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {}", e))?;

    let provider_config = config
        .get_provider_config(provider_name)
        .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;

    let mut builder = LLMBuilder::new().backend(backend.clone());

    if !provider_config.model.is_empty() {
        builder = builder.model(provider_config.model.clone());
    }

    builder = builder.system(system_prompt.to_string());

    // Only attach a key for backends that require one (e.g. not Ollama/Phind).
    if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
        builder = builder.api_key(provider_config.api_key.clone());
    }

    // Optional tuning parameters; values that fail to parse are silently skipped.
    if let Some(temp) = provider_config.additional_params.get("temperature") {
        if let Ok(temp_val) = temp.parse::<f32>() {
            builder = builder.temperature(temp_val);
        }
    }

    if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
        if let Ok(mt_val) = max_tokens.parse::<u32>() {
            builder = builder.max_tokens(mt_val);
        }
    } else {
        // No configured limit at all: fall back to 4096 tokens. (Note: a
        // configured but unparsable value sets no limit rather than 4096.)
        builder = builder.max_tokens(4096);
    }

    if let Some(top_p) = provider_config.additional_params.get("top_p") {
        if let Ok(tp_val) = top_p.parse::<f32>() {
            builder = builder.top_p(tp_val);
        }
    }

    let provider = builder
        .build()
        .map_err(|e| anyhow!("Failed to build provider: {}", e))?;

    let result = get_message_with_provider::<T>(provider, user_prompt).await?;

    Ok(result)
}
90
/// Sends `user_prompt` to an already-built provider and parses the reply
/// into `T`, retrying transient failures with exponential backoff.
///
/// Each attempt is capped at a 30-second timeout. The response text is first
/// run through `clean_json_from_llm`; when `T` is `String` the cleaned text
/// is returned as-is (via the `String: Into<T>` bound), otherwise it is
/// deserialized from JSON.
///
/// # Errors
/// Fails when every attempt errors out or times out, or when the reply
/// cannot be deserialized into `T`.
pub async fn get_message_with_provider<T>(
    provider: Box<dyn LLMProvider + Send + Sync>,
    user_prompt: &str,
) -> Result<T>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
    String: Into<T>,
{
    log_debug!("Entering get_message_with_provider");

    // Backoff delays start at 10ms and double each time; `take(2)` caps how
    // many retries follow the initial attempt.
    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); let result = Retry::spawn(retry_strategy, || async {
        log_debug!("Attempting to generate message");

        let messages = vec![ChatMessage::user().content(user_prompt.to_string()).build()];

        // Race the provider call against a 30s deadline.
        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
            Ok(Ok(response)) => {
                log_debug!("Received response from provider");
                // An absent text body is treated as an empty string.
                let response_text = response.text().unwrap_or_default();
                let cleaned_message = clean_json_from_llm(&response_text);

                // Runtime dispatch: callers requesting a plain `String` get
                // the cleaned text without any JSON parsing.
                if std::any::type_name::<T>() == std::any::type_name::<String>() {
                    Ok(cleaned_message.into())
                } else {
                    match serde_json::from_str::<T>(&cleaned_message) {
                        Ok(message) => Ok(message),
                        Err(e) => {
                            log_debug!("Deserialization error: {} message: {}", e, cleaned_message);
                            // Returning Err makes the retry strategy try again.
                            Err(anyhow!("Deserialization error: {}", e))
                        }
                    }
                }
            }
            Ok(Err(e)) => {
                log_debug!("Provider error: {}", e);
                Err(anyhow!("Provider error: {}", e))
            }
            Err(_) => {
                log_debug!("Provider timed out");
                Err(anyhow!("Provider timed out"))
            }
        }
    })
    .await;

    match result {
        Ok(message) => {
            log_debug!("Deserialized message: {:?}", message);
            Ok(message)
        }
        Err(e) => {
            log_debug!("Failed to generate message after retries: {}", e);
            Err(anyhow!("Failed to generate message: {}", e))
        }
    }
}
153
/// Returns the names of every provider this module knows how to configure.
pub fn get_available_provider_names() -> Vec<String> {
    [
        "openai", "anthropic", "ollama", "google", "groq", "xai", "deepseek", "phind",
    ]
    .iter()
    .map(|name| (*name).to_string())
    .collect()
}
167
/// Returns the default model identifier for a provider.
///
/// Matching is case-insensitive; any provider without an explicit entry
/// (including "openai") falls back to `"gpt-4o"`.
pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
    const DEFAULTS: [(&str, &str); 7] = [
        ("anthropic", "claude-3-7-sonnet-20250219"),
        ("ollama", "llama3"),
        ("google", "gemini-2.0-flash"),
        ("groq", "llama-3.1-70b-versatile"),
        ("xai", "grok-2-beta"),
        ("deepseek", "deepseek-chat"),
        ("phind", "phind-v2"),
    ];

    let key = provider_type.to_lowercase();
    for (name, model) in DEFAULTS {
        if name == key {
            return model;
        }
    }
    "gpt-4o"
}
181
182pub fn get_default_token_limit_for_provider(provider_type: &str) -> Result<usize> {
184 let limit = match provider_type.to_lowercase().as_str() {
185 "anthropic" => 200_000,
186 "ollama" | "openai" | "groq" | "xai" => 128_000,
187 "google" => 1_000_000,
188 "deepseek" => 64_000,
189 "phind" => 32_000,
190 _ => 8_192, };
192 Ok(limit)
193}
194
195pub fn provider_requires_api_key(provider_type: &str) -> bool {
197 if let Ok(backend) = LLMBackend::from_str(provider_type) {
198 requires_api_key(&backend)
199 } else {
200 true }
202}
203
204fn requires_api_key(backend: &LLMBackend) -> bool {
206 !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
207}
208
209pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
211 if provider_requires_api_key(provider_name) {
212 let provider_config = config
213 .get_provider_config(provider_name)
214 .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
215
216 if provider_config.api_key.is_empty() {
217 return Err(anyhow!("API key required for provider: {}", provider_name));
218 }
219 }
220
221 Ok(())
222}
223
224pub fn get_combined_config<S: ::std::hash::BuildHasher>(
226 config: &Config,
227 provider_name: &str,
228 command_line_args: &HashMap<String, String, S>,
229) -> HashMap<String, String> {
230 let mut combined_params = HashMap::default();
231
232 combined_params.insert(
234 "model".to_string(),
235 get_default_model_for_provider(provider_name).to_string(),
236 );
237
238 if let Some(provider_config) = config.get_provider_config(provider_name) {
240 if !provider_config.api_key.is_empty() {
241 combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
242 }
243 if !provider_config.model.is_empty() {
244 combined_params.insert("model".to_string(), provider_config.model.clone());
245 }
246 for (key, value) in &provider_config.additional_params {
247 combined_params.insert(key.clone(), value.clone());
248 }
249 }
250
251 for (key, value) in command_line_args {
253 if !value.is_empty() {
254 combined_params.insert(key.clone(), value.clone());
255 }
256 }
257
258 combined_params
259}
260
/// Extracts the JSON object from raw LLM output.
///
/// LLM replies often wrap JSON in Markdown code fences (```` ``` ````) or
/// surround it with commentary and non-ASCII noise (BOMs, smart quotes).
/// This strips that wrapping and returns the text spanning the first `{`
/// through the last `}`.
///
/// Fix over the previous version: slicing was unconditional, so input whose
/// last `}` preceded its first `{` (e.g. `"} stray {"`) produced an inverted
/// range and panicked. Such input now falls back to the trimmed text.
fn clean_json_from_llm(json_str: &str) -> String {
    // Drop surrounding whitespace and non-ASCII bytes from both ends.
    let trimmed = json_str
        .trim_start_matches(|c: char| c.is_whitespace() || !c.is_ascii())
        .trim_end_matches(|c: char| c.is_whitespace() || !c.is_ascii());

    // Peel off a Markdown code fence by jumping straight to the braces inside.
    let without_codeblock = if trimmed.starts_with("```") && trimmed.ends_with("```") {
        brace_span(trimmed)
    } else {
        trimmed
    };

    brace_span(without_codeblock).trim().to_string()
}

/// Returns the slice of `s` from the first `{` through the last `}`,
/// defaulting to the start/end of `s` when either brace is absent.
/// Falls back to `s` unchanged when the span would be inverted, which
/// previously caused a slice panic.
fn brace_span(s: &str) -> &str {
    let start = s.find('{').unwrap_or(0);
    let end = s.rfind('}').map_or(s.len(), |i| i + 1);
    if start <= end { &s[start..end] } else { s }
}