1use crate::config::Config;
2use crate::debug;
3use anyhow::{Result, anyhow};
4use llm::{
5 LLMProvider,
6 builder::{LLMBackend, LLMBuilder},
7 chat::ChatMessage,
8};
9use schemars::JsonSchema;
10use serde::de::DeserializeOwned;
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::time::Duration;
14use tokio_retry::Retry;
15use tokio_retry::strategy::ExponentialBackoff;
16
/// Built-in fallback settings for a single LLM provider.
#[derive(Clone, Debug)]
struct ProviderDefault {
    /// Default model identifier used when the user's config leaves `model` empty.
    model: &'static str,
    /// Default max-token budget for the provider (consumed by
    /// `get_default_token_limit_for_provider`).
    token_limit: usize,
}
22
23static PROVIDER_DEFAULTS: std::sync::LazyLock<HashMap<&'static str, ProviderDefault>> =
24 std::sync::LazyLock::new(|| {
25 let mut map = HashMap::new();
26 map.insert(
27 "openai",
28 ProviderDefault {
29 model: "gpt-4.1",
30 token_limit: 128_000,
31 },
32 );
33 map.insert(
34 "anthropic",
35 ProviderDefault {
36 model: "claude-sonnet-4-20250514",
37 token_limit: 200_000,
38 },
39 );
40 map.insert(
41 "ollama",
42 ProviderDefault {
43 model: "llama3",
44 token_limit: 128_000,
45 },
46 );
47 map.insert(
48 "google",
49 ProviderDefault {
50 model: "gemini-2.5-pro-preview-06-05",
51 token_limit: 1_000_000,
52 },
53 );
54 map.insert(
55 "groq",
56 ProviderDefault {
57 model: "llama-3.1-70b-versatile",
58 token_limit: 128_000,
59 },
60 );
61 map.insert(
62 "xai",
63 ProviderDefault {
64 model: "grok-2-beta",
65 token_limit: 128_000,
66 },
67 );
68 map.insert(
69 "deepseek",
70 ProviderDefault {
71 model: "deepseek-chat",
72 token_limit: 64_000,
73 },
74 );
75 map.insert(
76 "phind",
77 ProviderDefault {
78 model: "phind-v2",
79 token_limit: 32_000,
80 },
81 );
82 map.insert(
83 "openrouter",
84 ProviderDefault {
85 model: "openrouter/sonoma-dusk-alpha",
86 token_limit: 2_000_000,
87 },
88 );
89 map
90 });
91
92pub async fn get_message<T>(
94 config: &Config,
95 provider_name: &str,
96 system_prompt: &str,
97 user_prompt: &str,
98) -> Result<T>
99where
100 T: DeserializeOwned + JsonSchema,
101{
102 debug!("Generating message using provider: {}", provider_name);
103 debug!("System prompt: {}", system_prompt);
104 debug!("User prompt: {}", user_prompt);
105
106 let backend = if provider_name.to_lowercase() == "openrouter" {
108 LLMBackend::OpenRouter
109 } else {
110 LLMBackend::from_str(provider_name).map_err(|e| anyhow!("Invalid provider: {e}"))?
111 };
112
113 let provider_config = config
115 .get_provider_config(provider_name)
116 .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
117
118 let mut builder = LLMBuilder::new().backend(backend.clone());
120
121 if !provider_config.model.is_empty() {
123 builder = builder.model(provider_config.model.clone());
124 }
125
126 builder = builder.system(system_prompt.to_string());
128
129 if requires_api_key(&backend) && !provider_config.api_key.is_empty() {
131 builder = builder.api_key(provider_config.api_key.clone());
132 }
133
134 if let Some(temp) = provider_config.additional_params.get("temperature")
136 && let Ok(temp_val) = temp.parse::<f32>()
137 {
138 builder = builder.temperature(temp_val);
139 }
140
141 if is_openai_thinking_model(&provider_config.model) && provider_name.to_lowercase() == "openai"
144 {
145 } else if let Some(max_tokens) = provider_config.additional_params.get("max_tokens") {
148 if let Ok(mt_val) = max_tokens.parse::<u32>() {
149 builder = builder.max_tokens(mt_val);
150 }
151 } else {
152 let default_max = get_default_token_limit_for_provider(provider_name)
153 .try_into()
154 .map_err(|e| anyhow!("Token limit too large for u32: {e}"))?;
155 builder = builder.max_tokens(default_max);
156 }
157
158 if let Some(top_p) = provider_config.additional_params.get("top_p")
160 && let Ok(tp_val) = top_p.parse::<f32>()
161 {
162 builder = builder.top_p(tp_val);
163 }
164
165 let provider = builder
167 .build()
168 .map_err(|e| anyhow!("Failed to build provider: {e}"))?;
169
170 get_message_with_provider(provider, user_prompt, provider_name).await
172}
173
174pub async fn get_message_with_provider<T>(
176 provider: Box<dyn LLMProvider + Send + Sync>,
177 user_prompt: &str,
178 provider_type: &str,
179) -> Result<T>
180where
181 T: DeserializeOwned + JsonSchema,
182{
183 debug!("Entering get_message_with_provider");
184
185 let retry_strategy = ExponentialBackoff::from_millis(10).factor(2).take(2); let result = Retry::spawn(retry_strategy, || async {
188 debug!("Attempting to generate message");
189
190 let enhanced_prompt = if std::any::type_name::<T>() == std::any::type_name::<String>() {
192 user_prompt.to_string()
193 } else {
194 format!("{user_prompt}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.")
195 };
196
197 let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];
199
200 if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
202 messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
203 }
204
205 match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
206 Ok(Ok(response)) => {
207 debug!("Received response from provider");
208 let response_text = response.text().unwrap_or_default();
209
210 let result = match provider_type.to_lowercase().as_str() {
212 "anthropic" => {
214 if std::any::type_name::<T>() == std::any::type_name::<String>() {
215 #[allow(clippy::unnecessary_to_owned)]
217 let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
218 .map_err(|e| anyhow!("String conversion error: {e}"))?;
219 Ok(string_result)
220 } else {
221 parse_json_response_with_brace_prefix::<T>(&response_text)
222 }
223 },
224
225 _ => {
227 if std::any::type_name::<T>() == std::any::type_name::<String>() {
228 #[allow(clippy::unnecessary_to_owned)]
230 let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
231 .map_err(|e| anyhow!("String conversion error: {e}"))?;
232 Ok(string_result)
233 } else {
234 parse_json_response::<T>(&response_text)
236 }
237 }
238 };
239
240 match result {
241 Ok(message) => Ok(message),
242 Err(e) => {
243 debug!("JSON parse error: {} text: {}", e, response_text);
244 Err(anyhow!("JSON parse error: {e}"))
245 }
246 }
247 }
248 Ok(Err(e)) => {
249 debug!("Provider error: {}", e);
250 Err(anyhow!("Provider error: {e}"))
251 }
252 Err(_) => {
253 debug!("Provider timed out");
254 Err(anyhow!("Provider timed out"))
255 }
256 }
257 })
258 .await;
259
260 match result {
261 Ok(message) => {
262 debug!("Generated message successfully");
263 Ok(message)
264 }
265 Err(e) => {
266 debug!("Failed to generate message after retries: {}", e);
267 Err(anyhow!("Failed to generate message: {e}"))
268 }
269 }
270}
271
272fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
274 match serde_json::from_str::<T>(text) {
275 Ok(message) => Ok(message),
276 Err(e) => {
277 debug!(
279 "Direct JSON parse failed: {}. Attempting fallback extraction.",
280 e
281 );
282 extract_and_parse_json(text)
283 }
284 }
285}
286
287fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
289 let json_text = format!("{{{text}");
291 match serde_json::from_str::<T>(&json_text) {
292 Ok(message) => Ok(message),
293 Err(e) => {
294 debug!(
295 "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
296 e
297 );
298 extract_and_parse_json(text)
299 }
300 }
301}
302
303fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
305 let cleaned_json = clean_json_from_llm(text);
306 serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {e}"))
307}
308
/// Names of every provider backend this module knows how to configure.
pub fn get_available_provider_names() -> Vec<String> {
    [
        "openai",
        "anthropic",
        "ollama",
        "google",
        "groq",
        "xai",
        "deepseek",
        "phind",
        "openrouter",
    ]
    .into_iter()
    .map(str::to_string)
    .collect()
}
323
324pub fn get_default_model_for_provider(provider_type: &str) -> &'static str {
326 PROVIDER_DEFAULTS
327 .get(provider_type.to_lowercase().as_str())
328 .map_or("gpt-4.1", |def| def.model)
329}
330
331pub fn get_default_token_limit_for_provider(provider_type: &str) -> usize {
333 PROVIDER_DEFAULTS
334 .get(provider_type.to_lowercase().as_str())
335 .map_or(8_192, |def| def.token_limit)
336}
337
338pub fn provider_requires_api_key(provider_type: &str) -> bool {
340 if let Ok(backend) = LLMBackend::from_str(provider_type) {
341 requires_api_key(&backend)
342 } else {
343 true }
345}
346
347fn requires_api_key(backend: &LLMBackend) -> bool {
349 !matches!(backend, LLMBackend::Ollama | LLMBackend::Phind)
350}
351
/// Heuristic: treats any model whose lowercase name begins with 'o' as an
/// OpenAI reasoning ("thinking") model, e.g. "o1" or "o3-mini".
fn is_openai_thinking_model(model: &str) -> bool {
    matches!(model.to_lowercase().chars().next(), Some('o'))
}
357
358pub fn validate_provider_config(config: &Config, provider_name: &str) -> Result<()> {
360 if provider_requires_api_key(provider_name) {
361 let provider_config = config
362 .get_provider_config(provider_name)
363 .ok_or_else(|| anyhow!("Provider '{provider_name}' not found in configuration"))?;
364
365 if provider_config.api_key.is_empty() {
366 return Err(anyhow!("API key required for provider: {provider_name}"));
367 }
368 }
369
370 Ok(())
371}
372
373pub fn get_combined_config<S: ::std::hash::BuildHasher>(
375 config: &Config,
376 provider_name: &str,
377 command_line_args: &HashMap<String, String, S>,
378) -> HashMap<String, String> {
379 let mut combined_params = HashMap::default();
380
381 combined_params.insert(
383 "model".to_string(),
384 get_default_model_for_provider(provider_name).to_string(),
385 );
386
387 if let Some(provider_config) = config.get_provider_config(provider_name) {
389 if !provider_config.api_key.is_empty() {
390 combined_params.insert("api_key".to_string(), provider_config.api_key.clone());
391 }
392 if !provider_config.model.is_empty() {
393 combined_params.insert("model".to_string(), provider_config.model.clone());
394 }
395 for (key, value) in &provider_config.additional_params {
396 combined_params.insert(key.clone(), value.clone());
397 }
398 }
399
400 for (key, value) in command_line_args {
402 if !value.is_empty() {
403 combined_params.insert(key.clone(), value.clone());
404 }
405 }
406
407 if provider_name.to_lowercase() == "openai"
409 && let Some(model) = combined_params.get("model")
410 && is_openai_thinking_model(model)
411 && let Some(max_tokens) = combined_params.remove("max_tokens")
412 {
413 combined_params.insert("max_completion_tokens".to_string(), max_tokens);
414 }
415
416 combined_params
417}
418
/// Extracts the JSON object embedded in raw LLM output.
///
/// Strips whitespace and non-ASCII characters from both ends, peels off a
/// ```-fenced code block when present, and finally slices from the first `{`
/// to the last `}` (or keeps the whole text when no braces are found).
fn clean_json_from_llm(json_str: &str) -> String {
    // Characters considered "noise" at either edge of the payload.
    let edge_noise = |c: char| c.is_whitespace() || !c.is_ascii();
    let core = json_str
        .trim_start_matches(edge_noise)
        .trim_end_matches(edge_noise);

    let candidate = if core.starts_with("```") && core.ends_with("```") {
        // Fenced block: keep only the brace-delimited span inside the fences.
        let open = core.find('{').unwrap_or(0);
        let close = core.rfind('}').map_or(core.len(), |i| i + 1);
        &core[open..close]
    } else {
        core
    };

    let open = candidate.find('{').unwrap_or(0);
    let close = candidate.rfind('}').map_or(candidate.len(), |i| i + 1);

    candidate[open..close].trim().to_string()
}