1use std::collections::HashMap;
8use std::path::PathBuf;
9
10use crate::analyzer::CodeIssue;
11use crate::i18n::I18n;
12
13use super::client::{LlmClient, LlmConfig};
14use super::prompt::build_roast_prompt;
15
/// Map from an issue key to the roast text generated for that issue.
///
/// The providers in this file build each key as
/// `"{file_path}:{line}:{rule_name}"`.
pub type RoastMap = HashMap<String, String>;
20
/// A source of roast messages for a batch of detected code issues.
pub trait RoastProvider {
    /// Produces roasts for `issues`, using `lang` to select the output
    /// language (the tests in this file pass `"en-US"`).
    ///
    /// The implementations in this file key the returned map by
    /// `"{file_path}:{line}:{rule_name}"`.
    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap;
}
30
31pub struct LocalRoastProvider;
35
36impl RoastProvider for LocalRoastProvider {
37 fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
38 let i18n = I18n::new(lang);
39 let mut map = RoastMap::new();
40
41 for issue in issues {
42 let key = format!(
43 "{}:{}:{}",
44 issue.file_path.display(),
45 issue.line,
46 issue.rule_name
47 );
48 let messages = i18n.get_roast_messages(&issue.rule_name);
49 let roast = if !messages.is_empty() {
50 messages[issue.line % messages.len()].clone()
51 } else {
52 issue.message.clone()
53 };
54 map.insert(key, roast);
55 }
56
57 map
58 }
59}
60
61pub struct LlmRoastProvider {
65 client: LlmClient,
66 fallback: LocalRoastProvider,
67}
68
69impl LlmRoastProvider {
70 pub fn new(config: LlmConfig) -> Self {
72 Self {
73 client: LlmClient::new(config),
74 fallback: LocalRoastProvider,
75 }
76 }
77}
78
79impl RoastProvider for LlmRoastProvider {
80 fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
81 let contexts = extract_code_contexts(issues);
82 let prompt = build_roast_prompt(issues, &contexts, lang);
83
84 tracing::debug!("Calling LLM with {} issues...", issues.len());
85 tracing::debug!(
86 "Prompt (first 500 chars): {}",
87 &prompt[..prompt.len().min(500)]
88 );
89
90 match self.client.call_blocking(&prompt) {
91 Ok(response) => {
92 tracing::debug!("LLM response received ({} chars)", response.len());
93 match parse_llm_response(&response, issues) {
94 Ok(roasts) => {
95 tracing::debug!("Parsed {} roasts from LLM", roasts.len());
96 roasts
97 }
98 Err(e) => {
99 tracing::warn!(
100 "Failed to parse LLM response: {:#}. Falling back to local roasts.",
101 e
102 );
103 self.fallback.generate_roasts(issues, lang)
104 }
105 }
106 }
107 Err(e) => {
108 tracing::warn!("LLM call failed: {:#}. Falling back to local roasts.", e);
109 self.fallback.generate_roasts(issues, lang)
110 }
111 }
112 }
113}
114
115fn extract_code_contexts(issues: &[CodeIssue]) -> HashMap<String, String> {
119 let file_paths: Vec<PathBuf> = issues
121 .iter()
122 .map(|i| i.file_path.clone())
123 .collect::<std::collections::HashSet<_>>()
124 .into_iter()
125 .collect();
126
127 let file_contents: HashMap<PathBuf, Vec<String>> = file_paths
129 .into_iter()
130 .filter_map(|path| {
131 let content = std::fs::read_to_string(&path).ok()?;
132 let lines: Vec<String> = content.lines().map(String::from).collect();
133 Some((path, lines))
134 })
135 .collect();
136
137 let mut contexts = HashMap::new();
139 for issue in issues {
140 let key = format!(
141 "{}:{}:{}",
142 issue.file_path.display(),
143 issue.line,
144 issue.rule_name
145 );
146
147 if let Some(lines) = file_contents.get(&issue.file_path) {
148 let start = issue.line.saturating_sub(6);
149 let end = (issue.line + 5).min(lines.len());
150 let context: String = lines[start..end]
151 .iter()
152 .enumerate()
153 .map(|(i, l)| format!("{:>4} | {}", start + i + 1, l))
154 .collect::<Vec<_>>()
155 .join("\n");
156 contexts.insert(key, context);
157 }
158 }
159
160 contexts
161}
162
163fn parse_llm_response(response: &str, issues: &[CodeIssue]) -> Result<RoastMap, anyhow::Error> {
168 let json_str = extract_json_from_response(response);
169 let cleaned = fix_trailing_commas(json_str);
171 let parsed: HashMap<String, String> = serde_json::from_str(&cleaned)?;
172
173 let mut roasts = RoastMap::new();
174 for (idx_str, roast) in parsed {
175 let Ok(idx) = idx_str.parse::<usize>() else {
176 continue;
177 };
178 if idx >= issues.len() {
179 continue;
180 }
181 let issue = &issues[idx];
182 let key = format!(
183 "{}:{}:{}",
184 issue.file_path.display(),
185 issue.line,
186 issue.rule_name
187 );
188 roasts.insert(key, roast);
189 }
190
191 Ok(roasts)
192}
193
/// Pulls the JSON payload out of a free-form LLM reply.
///
/// Strategies are tried in order:
/// 1. the contents of the first ```` ```json ```` fenced block;
/// 2. the contents of the first generic ```` ``` ```` fenced block, skipping
///    the rest of the opening fence line (e.g. a language tag);
/// 3. the span from the first `{` to the last `}`;
/// 4. the whole response, unchanged.
///
/// Fenced contents are returned trimmed. The brace fallback is used only when
/// the first `{` precedes the last `}`; otherwise the slice range would be
/// inverted and indexing would panic.
fn extract_json_from_response(response: &str) -> &str {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    if let Some(start) = response.find(JSON_FENCE) {
        let body = &response[start + JSON_FENCE.len()..];
        if let Some(end) = body.find(FENCE) {
            return body[..end].trim();
        }
    }

    if let Some(start) = response.find(FENCE) {
        let after_fence = &response[start + FENCE.len()..];
        // Skip whatever follows the fence on its own line (language tag).
        let body = after_fence
            .find('\n')
            .map_or(after_fence, |newline| &after_fence[newline + 1..]);
        if let Some(end) = body.find(FENCE) {
            return body[..end].trim();
        }
    }

    if let (Some(start), Some(end)) = (response.find('{'), response.rfind('}')) {
        // Guard against replies such as "} ... {" where the braces are reversed.
        if start < end {
            return &response[start..=end];
        }
    }

    response
}
226
/// Removes trailing commas (a comma followed, after optional whitespace, by
/// `}` or `]`) so that slightly malformed JSON from the LLM still parses.
///
/// The scan is string-aware: commas inside JSON string literals are left
/// alone, and backslash escapes are honoured so an escaped quote does not end
/// the string. The input is processed character by character, so non-ASCII
/// text is preserved exactly.
fn fix_trailing_commas(json: &str) -> String {
    let mut result = String::with_capacity(json.len());
    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in json.char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash never ends the string.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            result.push(ch);
            continue;
        }

        match ch {
            '"' => in_string = true,
            ',' => {
                // ',' is one byte, so `idx + 1` is always a char boundary.
                let rest = json[idx + 1..].trim_start();
                if rest.starts_with('}') || rest.starts_with(']') {
                    continue;
                }
            }
            _ => {}
        }
        result.push(ch);
    }

    result
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::Severity;

    /// Builds a minimal issue in `test.rs` for the given rule and line.
    fn make_issue(rule: &str, line: usize) -> CodeIssue {
        CodeIssue {
            file_path: PathBuf::from("test.rs"),
            line,
            column: 1,
            rule_name: rule.to_owned(),
            message: String::from("test message"),
            severity: Severity::Spicy,
        }
    }

    /// A bare JSON object is passed through untouched.
    #[test]
    fn test_extract_json_from_plain_object() {
        let raw = r#"{"0": "roast one", "1": "roast two"}"#;
        let extracted = extract_json_from_response(raw);
        assert_eq!(extracted, raw, "Plain JSON should be returned as-is");
    }

    /// A ```json fenced block is unwrapped and trimmed.
    #[test]
    fn test_extract_json_from_markdown_fence() {
        let raw = "Here is the JSON:\n```json\n{\"0\": \"roast\"}\n```\nDone.";
        let extracted = extract_json_from_response(raw);
        assert_eq!(
            extracted, "{\"0\": \"roast\"}",
            "JSON inside markdown fences should be extracted"
        );
    }

    /// Numeric keys in the reply are translated into issue keys.
    #[test]
    fn test_parse_response_maps_indices_to_issue_keys() {
        let issues = [
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let raw = r#"{"0": "nice unwrap", "1": "so deep"}"#;
        let parsed = parse_llm_response(raw, &issues).unwrap();

        assert_eq!(parsed.len(), 2, "Should have roasts for both issues");
        assert!(
            parsed.contains_key("test.rs:10:unwrap-abuse"),
            "First issue key must be test.rs:10:unwrap-abuse"
        );
        assert!(
            parsed.contains_key("test.rs:25:deep-nesting"),
            "Second issue key must be test.rs:25:deep-nesting"
        );
    }

    /// Out-of-range and non-numeric keys are ignored rather than rejected.
    #[test]
    fn test_parse_response_skips_out_of_range_indices() {
        let issues = [make_issue("unwrap-abuse", 10)];
        let raw = r#"{"0": "valid", "5": "out of range", "abc": "not a number"}"#;
        let parsed = parse_llm_response(raw, &issues).unwrap();

        assert_eq!(
            parsed.len(),
            1,
            "Only the valid index should produce a roast"
        );
        assert!(
            parsed.contains_key("test.rs:10:unwrap-abuse"),
            "Valid index 0 should map to the first issue"
        );
    }

    /// The local provider yields a correctly keyed roast for a known rule.
    #[test]
    fn test_local_provider_returns_roasts_for_known_rules() {
        let issues = [make_issue("unwrap-abuse", 1)];
        let generated = LocalRoastProvider.generate_roasts(&issues, "en-US");

        assert!(
            !generated.is_empty(),
            "LocalRoastProvider must return at least one roast for known rules"
        );
        assert!(
            generated.contains_key("test.rs:1:unwrap-abuse"),
            "Roast key must match the issue key format"
        );
    }

    /// The local provider still produces text for a rule it does not know.
    #[test]
    fn test_local_provider_returns_something_for_unknown_rules() {
        let issues = [make_issue("unknown-rule-xyz", 42)];
        let generated = LocalRoastProvider.generate_roasts(&issues, "en-US");

        assert_eq!(
            generated.len(),
            1,
            "Should have exactly one roast for one issue"
        );
        let text = generated.get("test.rs:42:unknown-rule-xyz").unwrap();
        assert!(
            !text.is_empty(),
            "Unknown rules must still produce a non-empty roast message"
        );
    }

    /// Parsing works end to end when the JSON is wrapped in a markdown fence.
    #[test]
    fn test_parse_response_with_markdown_wrapped_json() {
        let issues = [make_issue("deep-nesting", 5)];
        let raw =
            "Sure, here are the roasts:\n```json\n{\"0\": \"nested deeper than inception\"}\n```";
        let parsed = parse_llm_response(raw, &issues).unwrap();

        assert_eq!(parsed.len(), 1, "Should parse one roast from fenced JSON");
        let text = parsed.get("test.rs:5:deep-nesting").unwrap();
        assert_eq!(
            text, "nested deeper than inception",
            "Roast content must match the JSON value"
        );
    }

    /// A comma directly before `}` is dropped.
    #[test]
    fn test_fix_trailing_commas_before_brace() {
        let fixed = fix_trailing_commas(r#"{"0": "a", "1": "b",}"#);
        assert_eq!(fixed, r#"{"0": "a", "1": "b"}"#);
    }

    /// A comma directly before `]` is dropped.
    #[test]
    fn test_fix_trailing_commas_before_bracket() {
        let fixed = fix_trailing_commas(r#"["a", "b",]"#);
        assert_eq!(fixed, r#"["a", "b"]"#);
    }

    /// Already-valid JSON comes back unchanged.
    #[test]
    fn test_fix_trailing_commas_preserves_valid_json() {
        let original = r#"{"0": "a", "1": "b"}"#;
        let fixed = fix_trailing_commas(original);
        assert_eq!(fixed, original, "Valid JSON should be unchanged");
    }

    /// Whitespace between the comma and the closing brace does not hide it.
    #[test]
    fn test_fix_trailing_commas_handles_whitespace() {
        let original = "{\"0\": \"a\" , \n}";
        let fixed = fix_trailing_commas(original);
        assert!(!fixed.contains(",}"), "Trailing comma must be removed");
        assert!(fixed.contains("\"a\""), "Content must be preserved");
    }

    /// A fenced reply with a trailing comma still parses completely.
    #[test]
    fn test_parse_response_with_trailing_comma() {
        let issues = [
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let raw = "```json\n{\"0\": \"nice unwrap\", \"1\": \"so deep\",}\n```";
        let parsed = parse_llm_response(raw, &issues).unwrap();

        assert_eq!(
            parsed.len(),
            2,
            "Should parse both roasts despite trailing comma"
        );
    }

    /// A fence without a language tag is unwrapped as well.
    #[test]
    fn test_extract_json_from_generic_code_fence() {
        let raw = "Here:\n```\n{\"0\": \"roast\"}\n```";
        let extracted = extract_json_from_response(raw);
        assert_eq!(extracted, "{\"0\": \"roast\"}");
    }
}