1use std::collections::HashMap;
8use std::path::PathBuf;
9
10use crate::analyzer::CodeIssue;
11use crate::i18n::I18n;
12
13use super::client::{LlmClient, LlmConfig};
14use super::prompt::build_roast_prompt;
15
/// Roast text for each issue, keyed by `"<file path>:<line>:<rule name>"`.
pub type RoastMap = std::collections::HashMap<String, String>;
20
21pub trait RoastProvider {
25 fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap;
29}
30
/// Roast provider backed by the localized roast messages from [`I18n`].
///
/// It is a stateless unit struct, so it is `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalRoastProvider;
35
36impl RoastProvider for LocalRoastProvider {
37 fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
38 let i18n = I18n::new(lang);
39 let mut map = RoastMap::new();
40
41 for issue in issues {
42 let key = format!(
43 "{}:{}:{}",
44 issue.file_path.display(),
45 issue.line,
46 issue.rule_name
47 );
48 let messages = i18n.get_roast_messages(&issue.rule_name);
49 let roast = if !messages.is_empty() {
50 messages[issue.line % messages.len()].clone()
51 } else {
52 issue.message.clone()
53 };
54 map.insert(key, roast);
55 }
56
57 map
58 }
59}
60
61pub struct LlmRoastProvider {
65 client: LlmClient,
66 fallback: LocalRoastProvider,
67}
68
69impl LlmRoastProvider {
70 pub fn new(config: LlmConfig) -> Self {
72 Self {
73 client: LlmClient::new(config),
74 fallback: LocalRoastProvider,
75 }
76 }
77}
78
79impl RoastProvider for LlmRoastProvider {
80 fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
81 let contexts = extract_code_contexts(issues);
82 let prompt = build_roast_prompt(issues, &contexts, lang);
83
84 tracing::debug!("Calling LLM with {} issues...", issues.len());
85 tracing::debug!(
86 "Prompt (first 500 chars): {}",
87 &prompt[..prompt.len().min(500)]
88 );
89
90 match self.client.call_blocking(&prompt) {
91 Ok(response) => {
92 tracing::debug!("LLM response received ({} chars)", response.len());
93 match parse_llm_response(&response, issues) {
94 Ok(roasts) => {
95 tracing::debug!("Parsed {} roasts from LLM", roasts.len());
96 roasts
97 }
98 Err(e) => {
99 tracing::warn!(
100 "Failed to parse LLM response: {:#}. Falling back to local roasts.",
101 e
102 );
103 self.fallback.generate_roasts(issues, lang)
104 }
105 }
106 }
107 Err(e) => {
108 tracing::warn!("LLM call failed: {:#}. Falling back to local roasts.", e);
109 self.fallback.generate_roasts(issues, lang)
110 }
111 }
112 }
113}
114
115fn extract_code_contexts(issues: &[CodeIssue]) -> HashMap<String, String> {
119 let file_paths: Vec<PathBuf> = issues
121 .iter()
122 .map(|i| i.file_path.clone())
123 .collect::<std::collections::HashSet<_>>()
124 .into_iter()
125 .collect();
126
127 let file_contents: HashMap<PathBuf, Vec<String>> = file_paths
129 .into_iter()
130 .filter_map(|path| match std::fs::read_to_string(&path) {
131 Ok(content) => {
132 let lines: Vec<String> = content.lines().map(String::from).collect();
133 Some((path, lines))
134 }
135 Err(e) => {
136 tracing::warn!("Failed to read source file {}: {}", path.display(), e);
137 None
138 }
139 })
140 .collect();
141
142 let mut contexts = HashMap::new();
144 for issue in issues {
145 let key = format!(
146 "{}:{}:{}",
147 issue.file_path.display(),
148 issue.line,
149 issue.rule_name
150 );
151
152 if let Some(lines) = file_contents.get(&issue.file_path) {
153 let start = issue.line.saturating_sub(6);
154 let end = (issue.line + 5).min(lines.len());
155 let context: String = lines[start..end]
156 .iter()
157 .enumerate()
158 .map(|(i, l)| format!("{:>4} | {}", start + i + 1, l))
159 .collect::<Vec<_>>()
160 .join("\n");
161 contexts.insert(key, context);
162 }
163 }
164
165 contexts
166}
167
168fn parse_llm_response(response: &str, issues: &[CodeIssue]) -> Result<RoastMap, anyhow::Error> {
173 let json_str = extract_json_from_response(response);
174 let cleaned = fix_trailing_commas(json_str);
176 let parsed: HashMap<String, String> = serde_json::from_str(&cleaned)?;
177
178 let mut roasts = RoastMap::new();
179 for (idx_str, roast) in parsed {
180 let Ok(idx) = idx_str.parse::<usize>() else {
181 continue;
182 };
183 if idx >= issues.len() {
184 continue;
185 }
186 let issue = &issues[idx];
187 let key = format!(
188 "{}:{}:{}",
189 issue.file_path.display(),
190 issue.line,
191 issue.rule_name
192 );
193 roasts.insert(key, roast);
194 }
195
196 Ok(roasts)
197}
198
/// Extracts the JSON payload from a free-form LLM response.
///
/// Strategies, tried in order:
/// 1. the body of a ```` ```json ```` fenced block;
/// 2. the body of any other fenced block (the rest of the opening fence line, such as a
///    language tag, is skipped);
/// 3. the span from the first `{` to the last `}`, provided that `}` comes after the `{`.
///
/// If none of these apply, the whole response is returned unchanged and the caller's
/// JSON parser reports the failure.
fn extract_json_from_response(response: &str) -> &str {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    if let Some(start) = response.find(JSON_FENCE) {
        let body_start = start + JSON_FENCE.len();
        if let Some(len) = response[body_start..].find(FENCE) {
            return response[body_start..body_start + len].trim();
        }
    }

    if let Some(start) = response.find(FENCE) {
        let fence_end = start + FENCE.len();
        // Skip the remainder of the opening fence line (e.g. a language tag).
        let body_start = response[fence_end..]
            .find('\n')
            .map_or(fence_end, |offset| fence_end + offset + 1);
        if let Some(len) = response[body_start..].find(FENCE) {
            return response[body_start..body_start + len].trim();
        }
    }

    // Last resort: the outermost braces. The last `}` can precede the first `{`
    // (e.g. "} ... {"); slicing that range would panic, so require a proper order.
    if let (Some(open), Some(close)) = (response.find('{'), response.rfind('}')) {
        if open < close {
            return &response[open..=close];
        }
    }

    response
}
231
/// Removes commas that directly precede a closing `}` or `]` (ignoring whitespace in
/// between). Such trailing commas are rejected by strict JSON parsers like `serde_json`.
///
/// The scan works on `char`s rather than bytes, so non-ASCII text is preserved verbatim.
/// It also tracks JSON string literals (honouring `\` escapes), so commas inside string
/// values are never removed.
fn fix_trailing_commas(json: &str) -> String {
    let mut result = String::with_capacity(json.len());
    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in json.char_indices() {
        if in_string {
            // Inside a string literal everything is content; only track where it ends.
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            result.push(ch);
            continue;
        }

        match ch {
            '"' => in_string = true,
            ',' => {
                // `,` is a single byte, so `idx + 1` is always a char boundary.
                let rest = json[idx + 1..].trim_start();
                if rest.starts_with('}') || rest.starts_with(']') {
                    continue;
                }
            }
            _ => {}
        }
        result.push(ch);
    }

    result
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::Severity;

    /// Builds a minimal issue in `test.rs` for `rule` at the 1-based `line`.
    fn make_issue(rule: &str, line: usize) -> CodeIssue {
        CodeIssue {
            file_path: PathBuf::from("test.rs"),
            line,
            column: 1,
            rule_name: rule.to_string(),
            message: "test message".to_string(),
            severity: Severity::Spicy,
        }
    }

    #[test]
    fn test_extract_json_from_plain_object() {
        let response = r#"{"0": "roast one", "1": "roast two"}"#;
        let result = extract_json_from_response(response);
        assert_eq!(result, response, "Plain JSON should be returned as-is");
    }

    #[test]
    fn test_extract_json_from_markdown_fence() {
        let response = "Here is the JSON:\n```json\n{\"0\": \"roast\"}\n```\nDone.";
        let result = extract_json_from_response(response);
        assert_eq!(
            result, "{\"0\": \"roast\"}",
            "JSON inside markdown fences should be extracted"
        );
    }

    #[test]
    fn test_extract_json_from_surrounding_prose() {
        let response = "Sure! {\"0\": \"roast\"} Enjoy.";
        let result = extract_json_from_response(response);
        assert_eq!(
            result, "{\"0\": \"roast\"}",
            "Unfenced JSON should be cut out of the surrounding prose"
        );
    }

    #[test]
    fn test_parse_response_maps_indices_to_issue_keys() {
        let issues = vec![
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let response = r#"{"0": "nice unwrap", "1": "so deep"}"#;
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(roasts.len(), 2, "Should have roasts for both issues");
        assert!(
            roasts.contains_key("test.rs:10:unwrap-abuse"),
            "First issue key must be test.rs:10:unwrap-abuse"
        );
        assert!(
            roasts.contains_key("test.rs:25:deep-nesting"),
            "Second issue key must be test.rs:25:deep-nesting"
        );
    }

    #[test]
    fn test_parse_response_skips_out_of_range_indices() {
        let issues = vec![make_issue("unwrap-abuse", 10)];
        let response = r#"{"0": "valid", "5": "out of range", "abc": "not a number"}"#;
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(
            roasts.len(),
            1,
            "Only the valid index should produce a roast"
        );
        assert!(
            roasts.contains_key("test.rs:10:unwrap-abuse"),
            "Valid index 0 should map to the first issue"
        );
    }

    #[test]
    fn test_parse_response_rejects_non_json() {
        let issues = vec![make_issue("unwrap-abuse", 10)];
        assert!(
            parse_llm_response("I refuse to roast this code.", &issues).is_err(),
            "Responses without JSON must be reported as errors so callers can fall back"
        );
    }

    #[test]
    fn test_local_provider_returns_roasts_for_known_rules() {
        let issues = vec![make_issue("unwrap-abuse", 1)];
        let provider = LocalRoastProvider;
        let roasts = provider.generate_roasts(&issues, "en-US");

        assert!(
            !roasts.is_empty(),
            "LocalRoastProvider must return at least one roast for known rules"
        );
        assert!(
            roasts.contains_key("test.rs:1:unwrap-abuse"),
            "Roast key must match the issue key format"
        );
    }

    #[test]
    fn test_local_provider_returns_something_for_unknown_rules() {
        let issues = vec![make_issue("unknown-rule-xyz", 42)];
        let provider = LocalRoastProvider;
        let roasts = provider.generate_roasts(&issues, "en-US");

        assert_eq!(
            roasts.len(),
            1,
            "Should have exactly one roast for one issue"
        );
        let roast = roasts.get("test.rs:42:unknown-rule-xyz").unwrap();
        assert!(
            !roast.is_empty(),
            "Unknown rules must still produce a non-empty roast message"
        );
    }

    #[test]
    fn test_parse_response_with_markdown_wrapped_json() {
        let issues = vec![make_issue("deep-nesting", 5)];
        let response =
            "Sure, here are the roasts:\n```json\n{\"0\": \"nested deeper than inception\"}\n```";
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(roasts.len(), 1, "Should parse one roast from fenced JSON");
        let roast = roasts.get("test.rs:5:deep-nesting").unwrap();
        assert_eq!(
            roast, "nested deeper than inception",
            "Roast content must match the JSON value"
        );
    }

    #[test]
    fn test_fix_trailing_commas_before_brace() {
        let input = r#"{"0": "a", "1": "b",}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"{"0": "a", "1": "b"}"#);
    }

    #[test]
    fn test_fix_trailing_commas_before_bracket() {
        let input = r#"["a", "b",]"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"["a", "b"]"#);
    }

    #[test]
    fn test_fix_trailing_commas_preserves_valid_json() {
        let input = r#"{"0": "a", "1": "b"}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, input, "Valid JSON should be unchanged");
    }

    #[test]
    fn test_fix_trailing_commas_handles_whitespace() {
        let input = "{\"0\": \"a\" , \n}";
        let result = fix_trailing_commas(input);
        // An exact match is required: checking `!contains(",}")` would pass even if the
        // comma were kept, because whitespace separates it from the brace.
        assert_eq!(
            result, "{\"0\": \"a\"  \n}",
            "Only the trailing comma may be removed; surrounding whitespace stays"
        );
    }

    #[test]
    fn test_fix_trailing_commas_nested_containers() {
        let input = r#"{"a": [1, 2,],}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"{"a": [1, 2]}"#);
    }

    #[test]
    fn test_parse_response_with_trailing_comma() {
        let issues = vec![
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let response = "```json\n{\"0\": \"nice unwrap\", \"1\": \"so deep\",}\n```";
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(
            roasts.len(),
            2,
            "Should parse both roasts despite trailing comma"
        );
    }

    #[test]
    fn test_extract_json_from_generic_code_fence() {
        let response = "Here:\n```\n{\"0\": \"roast\"}\n```";
        let result = extract_json_from_response(response);
        assert_eq!(result, "{\"0\": \"roast\"}");
    }
}