1use crate::config::Config;
2use crate::log_debug;
3use anyhow::{Result, anyhow};
4use llm::{
5 LLMProvider,
6 builder::{LLMBackend, LLMBuilder},
7 chat::ChatMessage,
8};
9use schemars::JsonSchema;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
17pub async fn get_message<T>(
19 config: &Config,
20 provider_name: &str,
21 system_prompt: &str,
22 user_prompt: &str,
23) -> Result<T>
24where
25 T: DeserializeOwned + JsonSchema,
26{
27 log_debug!("Generating message using provider: {}", provider_name);
28 log_debug!("System prompt: {}", system_prompt);
29 log_debug!("User prompt: {}", user_prompt);
30
31 let backend =
33 LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {}", e))?;
34
35 let provider_config = config
37 .get_provider_config(provider_name)
38 .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
39
40 let mut builder = LLMBuilder::new().backend(backend.clone());
42
43 if !provider_config.model.is_empty() {
45 builder = builder.model(provider_config.model.clone());
46 }
47
48 builder = builder.system(system_prompt.to_string());
50
51 if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
53 builder = builder.api_key(provider_config.api_key.clone());
54 }
55
56 if let Some(temp) = provider_config.additional_params.get("temperature") {
58 if let Ok(temp_val) = temp.parse::<f32>() {
59 builder = builder.temperature(temp_val);
60 }
61 }
62
63 if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
65 if let Ok(mt_val) = max_tokens.parse::<u32>() {
66 builder = builder.max_tokens(mt_val);
67 }
68 } else {
69 builder = builder.max_tokens(4096);
70 }
71
72 if let Some(top_p) = provider_config.additional_params.get("top_p") {
74 if let Ok(tp_val) = top_p.parse::<f32>() {
75 builder = builder.top_p(tp_val);
76 }
77 }
78
79 let provider = builder
81 .build()
82 .map_err(|e| anyhow!("Failed to build provider: {}", e))?;
83
84 get_message_with_provider(provider, user_prompt, provider_name).await
86}
87
/// Sends `user_prompt` to an already-built provider and deserializes the reply into `T`.
///
/// Behavior visible in this function:
/// * The whole request is retried with exponential backoff (base 10 ms,
///   factor 2, bounded by `take(2)`).
/// * Each attempt is capped at 30 seconds via `tokio::time::timeout`; a
///   timeout is retried like any other error.
/// * When `T` is not `String`, the prompt is extended with an instruction to
///   answer with a bare JSON object and nothing else.
/// * For the "anthropic" provider (with non-`String` `T`), an assistant
///   message ending in `{` is appended so the model continues the JSON
///   object; the brace is re-attached during parsing.
///
/// # Errors
/// Returns an error when the provider fails, an attempt times out, or the
/// response cannot be parsed as `T` after all retries are exhausted.
pub async fn get_message_with_provider<T>(
    provider: Box<dyn LLMProvider + Send + Sync>,
    user_prompt: &str,
    provider_type: &str,
) -> Result<T>
where
    T: DeserializeOwned + JsonSchema,
{
    log_debug!("Entering get_message_with_provider");

    // Backoff schedule: 10 ms then 20 ms; `take(2)` bounds the retry count.
    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2);
    let result = Retry::spawn(retry_strategy, || async {
        log_debug!("Attempting to generate message");

        // Plain-string responses pass through untouched; structured responses
        // get an explicit "JSON only" instruction appended.
        let enhanced_prompt = if std::any::type_name::<T>() == std::any::type_name::<String>() {
            user_prompt.to_string()
        } else {
            format!("{user_prompt}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.")
        };

        let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];

        // Anthropic prefill: seed the assistant turn with an opening brace so
        // the model continues the JSON directly. The matching brace is put
        // back in `parse_json_response_with_brace_prefix`.
        if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
            messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
        }

        // Hard 30-second cap per attempt.
        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
            Ok(Ok(response)) => {
                log_debug!("Received response from provider");
                let response_text = response.text().unwrap_or_default();

                let result = match provider_type.to_lowercase().as_str() {
                    "anthropic" => {
                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
                            // Wrap the raw text in a JSON string value so it
                            // deserializes into `T` (known here to be `String`).
                            #[allow(clippy::unnecessary_to_owned)]
                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
                                .map_err(|e| anyhow!("String conversion error: {}", e))?;
                            Ok(string_result)
                        } else {
                            // Re-attach the prefilled `{` before parsing.
                            parse_json_response_with_brace_prefix::<T>(&response_text)
                        }
                    },

                    _ => {
                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
                            #[allow(clippy::unnecessary_to_owned)]
                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
                                .map_err(|e| anyhow!("String conversion error: {}", e))?;
                            Ok(string_result)
                        } else {
                            parse_json_response::<T>(&response_text)
                        }
                    }
                };

                match result {
                    Ok(message) => Ok(message),
                    Err(e) => {
                        // Log the raw text to aid debugging malformed replies.
                        log_debug!("JSON parse error: {} text: {}", e, response_text);
                        Err(anyhow!("JSON parse error: {}", e))
                    }
                }
            }
            Ok(Err(e)) => {
                log_debug!("Provider error: {}", e);
                Err(anyhow!("Provider error: {}", e))
            }
            Err(_) => {
                log_debug!("Provider timed out");
                Err(anyhow!("Provider timed out"))
            }
        }
    })
    .await;

    match result {
        Ok(message) => {
            log_debug!("Generated message successfully");
            Ok(message)
        }
        Err(e) => {
            log_debug!("Failed to generate message after retries: {}", e);
            Err(anyhow!("Failed to generate message: {}", e))
        }
    }
}
185
186fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
188 match serde_json::from_str::<T>(text) {
189 Ok(message) => Ok(message),
190 Err(e) => {
191 log_debug!(
193 "Direct JSON parse failed: {}. Attempting fallback extraction.",
194 e
195 );
196 extract_and_parse_json(text)
197 }
198 }
199}
200
201fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
203 let json_text = format!("{{{text}");
205 match serde_json::from_str::<T>(&json_text) {
206 Ok(message) => Ok(message),
207 Err(e) => {
208 log_debug!(
209 "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
210 e
211 );
212 extract_and_parse_json(text)
213 }
214 }
215}
216
217fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
219 let cleaned_json = clean_json_from_llm(text);
220 serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {}", e))
221}
222
/// Returns the names of all supported LLM providers as owned strings.
pub fn get_available_provider_names() -> Vec<String> {
    const PROVIDERS: [&str; 8] = [
        "openai", "anthropic", "ollama", "google", "groq", "xai", "deepseek", "phind",
    ];
    PROVIDERS.iter().map(|name| (*name).to_string()).collect()
}
236
/// Returns the default model identifier for a provider (case-insensitive).
/// Unknown providers fall back to the OpenAI default, "gpt-4.1".
pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
    let normalized = provider_type.to_lowercase();
    match normalized.as_str() {
        "anthropic" => "claude-3-7-sonnet-latest",
        "ollama" => "llama3",
        "google" => "gemini-2.0-flash",
        "groq" => "llama-3.1-70b-versatile",
        "xai" => "grok-2-beta",
        "deepseek" => "deepseek-chat",
        "phind" => "phind-v2",
        _ => "gpt-4.1",
    }
}
250
251pub fn get_default_token_limit_for_provider(provider_type: &str) -> Result<usize> {
253 let limit = match provider_type.to_lowercase().as_str() {
254 "anthropic" => 200_000,
255 "ollama" | "openai" | "groq" | "xai" => 128_000,
256 "google" => 1_000_000,
257 "deepseek" => 64_000,
258 "phind" => 32_000,
259 _ => 8_192, };
261 Ok(limit)
262}
263
264pub fn provider_requires_api_key(provider_type: &str) -> bool {
266 if let Ok(backend) = LLMBackend::from_str(provider_type) {
267 requires_api_key(&backend)
268 } else {
269 true }
271}
272
273fn requires_api_key(backend: &LLMBackend) -> bool {
275 !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
276}
277
278pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
280 if provider_requires_api_key(provider_name) {
281 let provider_config = config
282 .get_provider_config(provider_name)
283 .ok_or_else(|| anyhow!("Provider '{}' not found in configuration", provider_name))?;
284
285 if provider_config.api_key.is_empty() {
286 return Err(anyhow!("API key required for provider: {}", provider_name));
287 }
288 }
289
290 Ok(())
291}
292
293pub fn get_combined_config<S: ::std::hash::BuildHasher>(
295 config: &Config,
296 provider_name: &str,
297 command_line_args: &HashMap<String, String, S>,
298) -> HashMap<String, String> {
299 let mut combined_params = HashMap::default();
300
301 combined_params.insert(
303 "model".to_string(),
304 get_default_model_for_provider(provider_name).to_string(),
305 );
306
307 if let Some(provider_config) = config.get_provider_config(provider_name) {
309 if !provider_config.api_key.is_empty() {
310 combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
311 }
312 if !provider_config.model.is_empty() {
313 combined_params.insert("model".to_string(), provider_config.model.clone());
314 }
315 for (key, value) in &provider_config.additional_params {
316 combined_params.insert(key.clone(), value.clone());
317 }
318 }
319
320 for (key, value) in command_line_args {
322 if !value.is_empty() {
323 combined_params.insert(key.clone(), value.clone());
324 }
325 }
326
327 combined_params
328}
329
/// Best-effort cleanup of an LLM reply so it can be parsed as JSON:
/// strips leading/trailing whitespace and non-ASCII junk, unwraps a
/// ``` fenced code block, and cuts the text down to the span between the
/// first `{` and the last `}`.
///
/// Bug fix: the original sliced `[start..end]` unconditionally, which
/// panics when the last `}` precedes the first `{` (e.g. `"} nope {"`).
/// `brace_span` now guards the range and falls back to the full input.
fn clean_json_from_llm(json_str: &str) -> String {
    // Drop surrounding whitespace and any non-ASCII decoration
    // (smart quotes, emoji, etc.) the model may have added.
    let trimmed = json_str
        .trim_start_matches(|c: char| c.is_whitespace() || !c.is_ascii())
        .trim_end_matches(|c: char| c.is_whitespace() || !c.is_ascii());

    // Unwrap a fenced code block (```...```) by cutting to its brace span.
    let without_codeblock = if trimmed.starts_with("```") && trimmed.ends_with("```") {
        brace_span(trimmed)
    } else {
        trimmed
    };

    brace_span(without_codeblock).trim().to_string()
}

/// Returns the substring from the first `{` through the last `}` inclusive.
/// When either brace is missing the corresponding bound defaults to the
/// string edge; when the computed range is inverted (last `}` before the
/// first `{`) the whole input is returned instead of panicking.
fn brace_span(s: &str) -> &str {
    let start = s.find('{').unwrap_or(0);
    let end = s.rfind('}').map_or(s.len(), |i| i + 1);
    if start <= end { &s[start..end] } else { s }
}