1use crate::config::Config;
2use crate::debug;
3use anyhow::{Result, anyhow};
4use llm::{
5 LLMProvider,
6 builder::{LLMBackend, LLMBuilder},
7 chat::ChatMessage,
8};
9use schemars::JsonSchema;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
/// Built-in default settings for a single LLM provider, applied when the
/// user's configuration does not supply its own values.
#[derive(Clone, Debug)]
struct ProviderDefault {
    /// Default model identifier sent to the provider.
    model: &'static str,
    /// Default token budget; used as the fallback `max_tokens` value.
    token_limit: usize,
}
22
23static PROVIDER_DEFAULTS: std::sync::LazyLock<HashMap<&'static str, ProviderDefault>> =
24 std::sync::LazyLock::new(|| {
25 let mut map = HashMap::new();
26 map.insert(
27 "openai",
28 ProviderDefault {
29 model: "gpt-4.1",
30 token_limit: 128_000,
31 },
32 );
33 map.insert(
34 "anthropic",
35 ProviderDefault {
36 model: "claude-sonnet-4-20250514",
37 token_limit: 200_000,
38 },
39 );
40 map.insert(
41 "ollama",
42 ProviderDefault {
43 model: "llama3",
44 token_limit: 128_000,
45 },
46 );
47 map.insert(
48 "google",
49 ProviderDefault {
50 model: "gemini-2.5-pro-preview-06-05",
51 token_limit: 1_000_000,
52 },
53 );
54 map.insert(
55 "groq",
56 ProviderDefault {
57 model: "llama-3.1-70b-versatile",
58 token_limit: 128_000,
59 },
60 );
61 map.insert(
62 "xai",
63 ProviderDefault {
64 model: "grok-2-beta",
65 token_limit: 128_000,
66 },
67 );
68 map.insert(
69 "deepseek",
70 ProviderDefault {
71 model: "deepseek-chat",
72 token_limit: 64_000,
73 },
74 );
75 map.insert(
76 "phind",
77 ProviderDefault {
78 model: "phind-v2",
79 token_limit: 32_000,
80 },
81 );
82 map.insert(
83 "openrouter",
84 ProviderDefault {
85 model: "openrouter/sonoma-dusk-alpha",
86 token_limit: 2_000_000,
87 },
88 );
89 map
90 });
91
92pub async fn get_message<T>(
94 config: &Config,
95 provider_name: &str,
96 system_prompt: &str,
97 user_prompt: &str,
98) -> Result<T>
99where
100 T: DeserializeOwned + JsonSchema,
101{
102 debug!("Generating message using provider: {}", provider_name);
103 debug!("System prompt: {}", system_prompt);
104 debug!("User prompt: {}", user_prompt);
105
106 let backend = if provider_name.to_lowercase() == "openrouter" {
108 LLMBackend::OpenRouter
109 } else {
110 LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {e}"))?
111 };
112
113 let provider_config = config
115 .get_provider_config(provider_name)
116 .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
117
118 let mut builder = LLMBuilder::new().backend(backend.clone());
120
121 if !provider_config.model_name.is_empty() {
123 builder = builder.model(provider_config.model_name.clone());
124 }
125
126 builder = builder.system(system_prompt.to_string());
128
129 if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
131 builder = builder.api_key(provider_config.api_key.clone());
132 }
133
134 if let Some(temp) = provider_config.additional_params.get("temperature")
136 && let Ok(temp_val) = temp.parse::<f32>()
137 {
138 builder = builder.temperature(temp_val);
139 }
140
141 if is_openai_thinking_model(&provider_config.model_name)
144 && provider_name.to_lowercase() == "openai"
145 {
146 } else if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
149 if let Ok(mt_val) = max_tokens.parse::<u32>() {
150 builder = builder.max_tokens(mt_val);
151 }
152 } else {
153 let default_max = get_default_token_limit_for_provider(provider_name)
154 .try_into()
155 .map_err(|e| anyhow!("Token limit too large for u32: {e}"))?;
156 builder = builder.max_tokens(default_max);
157 }
158
159 if let Some(top_p) = provider_config.additional_params.get("top_p")
161 && let Ok(tp_val) = top_p.parse::<f32>()
162 {
163 builder = builder.top_p(tp_val);
164 }
165
166 let provider = builder
168 .build()
169 .map_err(|e| anyhow!("Failed to build provider: {e}"))?;
170
171 get_message_with_provider(provider, user_prompt, provider_name).await
173}
174
/// Sends `user_prompt` to an already-built `provider` and parses the reply
/// into `T`, retrying failed attempts with exponential backoff.
///
/// When `T` is `String`, the raw response text is returned as-is; otherwise
/// the prompt is augmented with a "JSON only" instruction. For Anthropic,
/// an assistant message ending in `{` is appended to prime a JSON reply,
/// so the response is parsed with the opening brace re-prefixed.
///
/// # Errors
/// Returns an error when every attempt fails: provider error, 30-second
/// timeout, or a response that cannot be parsed into `T`.
pub async fn get_message_with_provider<T>(
    provider: Box<dyn LLMProvider + Send + Sync>,
    user_prompt: &str,
    provider_type: &str,
) -> Result<T>
where
    T: DeserializeOwned + JsonSchema,
{
    debug!("Entering get_message_with_provider");

    // Two retry delays (10ms, then 20ms) after the initial attempt —
    // at most three attempts in total.
    let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2);
    let result = Retry::spawn(retry_strategy, || async {
        debug!("Attempting to generate message");

        // `type_name` comparison detects the plain-text (`T == String`) path.
        let enhanced_prompt = if std::any::type_name::<T>() == std::any::type_name::<String>() {
            user_prompt.to_string()
        } else {
            format!("{user_prompt}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.")
        };

        let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];

        // Prime Anthropic models with a partial assistant turn ending in "{"
        // so the model continues with the JSON body.
        if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
            messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
        }

        // Hard 30-second cap per attempt; a timeout is retried like any error.
        match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
            Ok(Ok(response)) => {
                debug!("Received response from provider");
                let response_text = response.text().unwrap_or_default();

                let result = match provider_type.to_lowercase().as_str() {
                    "anthropic" => {
                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
                            // Round-trip through a JSON string value to produce `T` (== String).
                            #[allow(clippy::unnecessary_to_owned)]
                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
                                .map_err(|e| anyhow!("String conversion error: {e}"))?;
                            Ok(string_result)
                        } else {
                            // Re-attach the "{" consumed by the priming message above.
                            parse_json_response_with_brace_prefix::<T>(&response_text)
                        }
                    },

                    _ => {
                        if std::any::type_name::<T>() == std::any::type_name::<String>() {
                            #[allow(clippy::unnecessary_to_owned)]
                            let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
                                .map_err(|e| anyhow!("String conversion error: {e}"))?;
                            Ok(string_result)
                        } else {
                            parse_json_response::<T>(&response_text)
                        }
                    }
                };

                match result {
                    Ok(message) => Ok(message),
                    Err(e) => {
                        // Log the raw text to aid debugging of malformed replies.
                        debug!("JSON parse error: {} text: {}", e, response_text);
                        Err(anyhow!("JSON parse error: {e}"))
                    }
                }
            }
            Ok(Err(e)) => {
                debug!("Provider error: {}", e);
                Err(anyhow!("Provider error: {e}"))
            }
            Err(_) => {
                debug!("Provider timed out");
                Err(anyhow!("Provider timed out"))
            }
        }
    })
    .await;

    match result {
        Ok(message) => {
            debug!("Generated message successfully");
            Ok(message)
        }
        Err(e) => {
            debug!("Failed to generate message after retries: {}", e);
            Err(anyhow!("Failed to generate message: {e}"))
        }
    }
}
272
273fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
275 match serde_json::from_str::<T>(text) {
276 Ok(message) => Ok(message),
277 Err(e) => {
278 debug!(
280 "Direct JSON parse failed: {}. Attempting fallback extraction.",
281 e
282 );
283 extract_and_parse_json(text)
284 }
285 }
286}
287
288fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
290 let json_text = format!("{{{text}");
292 match serde_json::from_str::<T>(&json_text) {
293 Ok(message) => Ok(message),
294 Err(e) => {
295 debug!(
296 "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
297 e
298 );
299 extract_and_parse_json(text)
300 }
301 }
302}
303
304fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
306 let cleaned_json = clean_json_from_llm(text);
307 serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {e}"))
308}
309
/// Lists the provider identifiers this module supports, in display order.
pub fn get_available_provider_names() -> Vec<String> {
    const PROVIDERS: [&str; 9] = [
        "openai",
        "anthropic",
        "ollama",
        "google",
        "groq",
        "xai",
        "deepseek",
        "phind",
        "openrouter",
    ];
    PROVIDERS.iter().map(|name| (*name).to_string()).collect()
}
324
325pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
327 PROVIDER_DEFAULTS
328 .get(provider_type.to_lowercase().as_str())
329 .map_or("gpt-4.1", |def| def.model)
330}
331
332pub fn get_default_token_limit_for_provider(provider_type: &str) -> usize {
334 PROVIDER_DEFAULTS
335 .get(provider_type.to_lowercase().as_str())
336 .map_or(8_192, |def| def.token_limit)
337}
338
339pub fn provider_requires_api_key(provider_type: &str) -> bool {
341 if let Ok(backend) = LLMBackend::from_str(provider_type) {
342 requires_api_key(&backend)
343 } else {
344 true }
346}
347
348fn requires_api_key(backend: &LLMBackend) -> bool {
350 !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
351}
352
/// Returns `true` when `model` looks like an OpenAI reasoning ("o-series")
/// model such as `o1`, `o1-mini`, `o3`, or `o4-mini`.
///
/// Callers use this to skip or rename the `max_tokens` parameter for these
/// models.
fn is_openai_thinking_model(model: &str) -> bool {
    // Require "o<digit>" rather than any leading 'o', which previously
    // misclassified unrelated names such as "omni-moderation-latest".
    let model_lower = model.to_lowercase();
    let bytes = model_lower.as_bytes();
    bytes.len() >= 2 && bytes[0] == b'o' && bytes[1].is_ascii_digit()
}
358
359pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
361 if provider_requires_api_key(provider_name) {
362 let provider_config = config
363 .get_provider_config(provider_name)
364 .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
365
366 if provider_config.api_key.is_empty() {
367 return Err(anyhow!("API key required for provider: {provider_name}"));
368 }
369 }
370
371 Ok(())
372}
373
374pub fn get_combined_config<S: ::std::hash::BuildHasher>(
376 config: &Config,
377 provider_name: &str,
378 command_line_args: &HashMap<String, String, S>,
379) -> HashMap<String, String> {
380 let mut combined_params = HashMap::default();
381
382 combined_params.insert(
384 "model".to_string(),
385 get_default_model_for_provider(provider_name).to_string(),
386 );
387
388 if let Some(provider_config) = config.get_provider_config(provider_name) {
390 if !provider_config.api_key.is_empty() {
391 combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
392 }
393 if !provider_config.model_name.is_empty() {
394 combined_params.insert("model".to_string(), provider_config.model_name.clone());
395 }
396 for (key, value) in &provider_config.additional_params {
397 combined_params.insert(key.clone(), value.clone());
398 }
399 }
400
401 for (key, value) in command_line_args {
403 if !value.is_empty() {
404 combined_params.insert(key.clone(), value.clone());
405 }
406 }
407
408 if provider_name.to_lowercase() == "openai"
410 && let Some(model) = combined_params.get("model")
411 && is_openai_thinking_model(model)
412 && let Some(max_tokens) = combined_params.remove("max_tokens")
413 {
414 combined_params.insert("max_completion_tokens".to_string(), max_tokens);
415 }
416
417 combined_params
418}
419
/// Best-effort extraction of a JSON object from raw LLM output: strips
/// leading/trailing whitespace and non-ASCII junk, removes a surrounding
/// markdown code fence, and keeps only the outermost `{...}` span.
fn clean_json_from_llm(json_str: &str) -> String {
    // Whitespace and non-ASCII characters count as junk at either edge.
    let is_junk = |c: char| c.is_whitespace() || !c.is_ascii();
    let trimmed = json_str.trim_start_matches(is_junk).trim_end_matches(is_junk);

    // When fenced as a ``` code block, narrow to the braces inside the fence.
    let candidate = if trimmed.starts_with("```") && trimmed.ends_with("```") {
        let open = trimmed.find('{').unwrap_or(0);
        let close = trimmed.rfind('}').map_or(trimmed.len(), |i| i + 1);
        &trimmed[open..close]
    } else {
        trimmed
    };

    // Keep only the outermost JSON object span, then drop stray whitespace.
    let open = candidate.find('{').unwrap_or(0);
    let close = candidate.rfind('}').map_or(candidate.len(), |i| i + 1);

    candidate[open..close].trim().to_string()
}