Skip to main content

garbage_code_hunter/llm/
provider.rs

1//! Roast provider abstraction for generating code review messages.
2//!
3//! This module defines the `RoastProvider` trait and two implementations:
4//! - `LocalRoastProvider`: Uses hardcoded roast messages from the i18n module.
5//! - `LlmRoastProvider`: Calls an LLM endpoint to generate dynamic, context-aware roasts.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9
10use crate::analyzer::CodeIssue;
11use crate::i18n::I18n;
12
13use super::client::{LlmClient, LlmConfig};
14use super::prompt::build_roast_prompt;
15
/// Roast messages indexed by issue key.
///
/// Every key has the shape `"{file_path}:{line}:{rule_name}"`; the value is the
/// roast text produced for that issue.
pub type RoastMap = std::collections::HashMap<String, String>;
20
21/// Trait for generating roast messages for code issues.
22///
23/// Implementors can use local hardcoded messages or call external LLM services.
24pub trait RoastProvider {
25    /// Generate roast messages for the given issues.
26    ///
27    /// Returns a `RoastMap` mapping issue keys to roast messages.
28    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap;
29}
30
/// Roast provider backed by the hardcoded, localized messages of the i18n module.
///
/// This is the default provider. It is also embedded in `LlmRoastProvider` as the
/// fallback used when the LLM path fails.
///
/// It is a stateless unit struct, so it cheaply derives the common traits: `Debug`
/// for diagnostics, `Clone`/`Copy`, and `Default` so generic code can construct it.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalRoastProvider;
35
36impl RoastProvider for LocalRoastProvider {
37    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
38        let i18n = I18n::new(lang);
39        let mut map = RoastMap::new();
40
41        for issue in issues {
42            let key = format!(
43                "{}:{}:{}",
44                issue.file_path.display(),
45                issue.line,
46                issue.rule_name
47            );
48            let messages = i18n.get_roast_messages(&issue.rule_name);
49            let roast = if !messages.is_empty() {
50                messages[issue.line % messages.len()].clone()
51            } else {
52                issue.message.clone()
53            };
54            map.insert(key, roast);
55        }
56
57        map
58    }
59}
60
61/// LLM-powered roast provider that generates dynamic, context-aware roasts.
62///
63/// Falls back to `LocalRoastProvider` if the LLM call fails or returns invalid data.
64pub struct LlmRoastProvider {
65    client: LlmClient,
66    fallback: LocalRoastProvider,
67}
68
69impl LlmRoastProvider {
70    /// Create a new LLM roast provider with the given configuration.
71    pub fn new(config: LlmConfig) -> Self {
72        Self {
73            client: LlmClient::new(config),
74            fallback: LocalRoastProvider,
75        }
76    }
77}
78
79impl RoastProvider for LlmRoastProvider {
80    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
81        let contexts = extract_code_contexts(issues);
82        let prompt = build_roast_prompt(issues, &contexts, lang);
83
84        tracing::debug!("Calling LLM with {} issues...", issues.len());
85        tracing::debug!(
86            "Prompt (first 500 chars): {}",
87            &prompt[..prompt.len().min(500)]
88        );
89
90        match self.client.call_blocking(&prompt) {
91            Ok(response) => {
92                tracing::debug!("LLM response received ({} chars)", response.len());
93                match parse_llm_response(&response, issues) {
94                    Ok(roasts) => {
95                        tracing::debug!("Parsed {} roasts from LLM", roasts.len());
96                        roasts
97                    }
98                    Err(e) => {
99                        tracing::warn!(
100                            "Failed to parse LLM response: {:#}. Falling back to local roasts.",
101                            e
102                        );
103                        self.fallback.generate_roasts(issues, lang)
104                    }
105                }
106            }
107            Err(e) => {
108                tracing::warn!("LLM call failed: {:#}. Falling back to local roasts.", e);
109                self.fallback.generate_roasts(issues, lang)
110            }
111        }
112    }
113}
114
115/// Extract code context (±5 lines) around each issue for the LLM prompt.
116///
117/// Groups issues by file to avoid reading the same file multiple times.
118fn extract_code_contexts(issues: &[CodeIssue]) -> HashMap<String, String> {
119    // Collect unique file paths
120    let file_paths: Vec<PathBuf> = issues
121        .iter()
122        .map(|i| i.file_path.clone())
123        .collect::<std::collections::HashSet<_>>()
124        .into_iter()
125        .collect();
126
127    // Read all file contents upfront
128    let file_contents: HashMap<PathBuf, Vec<String>> = file_paths
129        .into_iter()
130        .filter_map(|path| match std::fs::read_to_string(&path) {
131            Ok(content) => {
132                let lines: Vec<String> = content.lines().map(String::from).collect();
133                Some((path, lines))
134            }
135            Err(e) => {
136                tracing::warn!("Failed to read source file {}: {}", path.display(), e);
137                None
138            }
139        })
140        .collect();
141
142    // Extract context window for each issue
143    let mut contexts = HashMap::new();
144    for issue in issues {
145        let key = format!(
146            "{}:{}:{}",
147            issue.file_path.display(),
148            issue.line,
149            issue.rule_name
150        );
151
152        if let Some(lines) = file_contents.get(&issue.file_path) {
153            let start = issue.line.saturating_sub(6);
154            let end = (issue.line + 5).min(lines.len());
155            let context: String = lines[start..end]
156                .iter()
157                .enumerate()
158                .map(|(i, l)| format!("{:>4} | {}", start + i + 1, l))
159                .collect::<Vec<_>>()
160                .join("\n");
161            contexts.insert(key, context);
162        }
163    }
164
165    contexts
166}
167
168/// Parse the LLM response JSON into a RoastMap.
169///
170/// Expected format: `{"0": "roast message", "1": "roast message", ...}`
171/// where keys are issue indices (0-based) matching the order in the prompt.
172fn parse_llm_response(response: &str, issues: &[CodeIssue]) -> Result<RoastMap, anyhow::Error> {
173    let json_str = extract_json_from_response(response);
174    // LLMs often produce trailing commas in JSON — strip them for robustness
175    let cleaned = fix_trailing_commas(json_str);
176    let parsed: HashMap<String, String> = serde_json::from_str(&cleaned)?;
177
178    let mut roasts = RoastMap::new();
179    for (idx_str, roast) in parsed {
180        let Ok(idx) = idx_str.parse::<usize>() else {
181            continue;
182        };
183        if idx >= issues.len() {
184            continue;
185        }
186        let issue = &issues[idx];
187        let key = format!(
188            "{}:{}:{}",
189            issue.file_path.display(),
190            issue.line,
191            issue.rule_name
192        );
193        roasts.insert(key, roast);
194    }
195
196    Ok(roasts)
197}
198
/// Extract the JSON payload from a free-form LLM response.
///
/// Strategies are tried in order:
/// 1. the contents of a ```` ```json ```` fenced block;
/// 2. the contents of any ```` ``` ```` fenced block, skipping an optional language
///    tag on the opening fence line;
/// 3. the span from the first `{` through the last `}`, if that `}` comes after
///    the `{`.
///
/// When nothing matches, the response is returned unchanged and the JSON parser
/// reports the error. The function never panics on arbitrary input.
fn extract_json_from_response(response: &str) -> &str {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    // Handle ```json ... ``` wrapper
    if let Some(start) = response.find(JSON_FENCE) {
        let json_start = start + JSON_FENCE.len();
        if let Some(end) = response[json_start..].find(FENCE) {
            return response[json_start..json_start + end].trim();
        }
    }

    // Handle ``` ... ``` wrapper (without json tag)
    if let Some(start) = response.find(FENCE) {
        let fence_start = start + FENCE.len();
        // Skip the optional language tag on the same line
        let content_start = response[fence_start..]
            .find('\n')
            .map_or(fence_start, |n| fence_start + n + 1);
        if let Some(end) = response[content_start..].find(FENCE) {
            return response[content_start..content_start + end].trim();
        }
    }

    // Handle a bare JSON object. Require the last `}` to come after the first `{`.
    // A reply like "sorry :} no json {" would otherwise form an inverted range and
    // panic on slicing.
    if let (Some(start), Some(end)) = (response.find('{'), response.rfind('}')) {
        if start < end {
            return &response[start..=end];
        }
    }

    response
}
231
/// Remove trailing commas that directly precede (modulo whitespace) a `}` or `]`.
///
/// LLMs frequently produce invalid JSON like `{"a": 1, "b": 2,}`. This function
/// strips such commas so the text parses. Everything else is copied through
/// unchanged:
/// - The input is walked as `char`s, so non-ASCII text (e.g. Chinese roasts) is
///   preserved. Pushing raw bytes `as char` would re-encode each UTF-8 byte as a
///   separate Latin-1 character.
/// - Commas inside JSON string literals are never removed, including after escaped
///   quotes (`\"`). For example, `"well, }"` stays intact.
fn fix_trailing_commas(json: &str) -> String {
    let mut result = String::with_capacity(json.len());
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in json.char_indices() {
        if in_string {
            // Inside a string literal: copy verbatim, tracking escapes so that `\"`
            // does not end the string.
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            result.push(ch);
            continue;
        }

        match ch {
            '"' => in_string = true,
            ',' => {
                // `,` is one byte, so `i + 1` is always a char boundary.
                let next = json[i + 1..].trim_start().chars().next();
                if matches!(next, Some('}') | Some(']')) {
                    // Trailing comma: drop it, keep the surrounding whitespace.
                    continue;
                }
            }
            _ => {}
        }
        result.push(ch);
    }

    result
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::Severity;

    /// Helper to create a test CodeIssue with minimal fields.
    fn make_issue(rule: &str, line: usize) -> CodeIssue {
        CodeIssue {
            file_path: PathBuf::from("test.rs"),
            line,
            column: 1,
            rule_name: rule.to_string(),
            message: "test message".to_string(),
            severity: Severity::Spicy,
        }
    }

    #[test]
    fn test_extract_json_from_plain_object() {
        // Objective: Verify plain JSON objects are extracted correctly.
        // Invariants: Output must match the input when it is a valid JSON object.
        let response = r#"{"0": "roast one", "1": "roast two"}"#;
        let result = extract_json_from_response(response);
        assert_eq!(result, response, "Plain JSON should be returned as-is");
    }

    #[test]
    fn test_extract_json_from_markdown_fence() {
        // Objective: Verify JSON wrapped in ```json fences is extracted.
        // Invariants: Only the JSON content between fences is returned.
        let response = "Here is the JSON:\n```json\n{\"0\": \"roast\"}\n```\nDone.";
        let result = extract_json_from_response(response);
        assert_eq!(
            result, "{\"0\": \"roast\"}",
            "JSON inside markdown fences should be extracted"
        );
    }

    #[test]
    fn test_parse_response_maps_indices_to_issue_keys() {
        // Objective: Verify LLM response indices map to correct issue keys.
        // Invariants: Each index maps to the corresponding issue's key format.
        let issues = vec![
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let response = r#"{"0": "nice unwrap", "1": "so deep"}"#;
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(roasts.len(), 2, "Should have roasts for both issues");
        assert!(
            roasts.contains_key("test.rs:10:unwrap-abuse"),
            "First issue key must be test.rs:10:unwrap-abuse"
        );
        assert!(
            roasts.contains_key("test.rs:25:deep-nesting"),
            "Second issue key must be test.rs:25:deep-nesting"
        );
    }

    #[test]
    fn test_parse_response_skips_out_of_range_indices() {
        // Objective: Verify out-of-range and non-numeric indices are silently ignored.
        // Invariants: Only valid indices produce roasts; invalid ones are skipped.
        let issues = vec![make_issue("unwrap-abuse", 10)];
        let response = r#"{"0": "valid", "5": "out of range", "abc": "not a number"}"#;
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(
            roasts.len(),
            1,
            "Only the valid index should produce a roast"
        );
        assert!(
            roasts.contains_key("test.rs:10:unwrap-abuse"),
            "Valid index 0 should map to the first issue"
        );
    }

    #[test]
    fn test_local_provider_returns_roasts_for_known_rules() {
        // Objective: Verify LocalRoastProvider produces roasts for rules with i18n messages.
        // Invariants: At least one roast must be returned for a known rule name.
        let issues = vec![make_issue("unwrap-abuse", 1)];
        let provider = LocalRoastProvider;
        let roasts = provider.generate_roasts(&issues, "en-US");

        assert!(
            !roasts.is_empty(),
            "LocalRoastProvider must return at least one roast for known rules"
        );
        assert!(
            roasts.contains_key("test.rs:1:unwrap-abuse"),
            "Roast key must match the issue key format"
        );
    }

    #[test]
    fn test_local_provider_returns_something_for_unknown_rules() {
        // Objective: Verify unknown rules still produce a roast message.
        // Invariants: The roast is non-empty either way. If i18n has no messages for
        // the rule, the provider falls back to the issue's own `message`.
        let issues = vec![make_issue("unknown-rule-xyz", 42)];
        let provider = LocalRoastProvider;
        let roasts = provider.generate_roasts(&issues, "en-US");

        assert_eq!(
            roasts.len(),
            1,
            "Should have exactly one roast for one issue"
        );
        let roast = roasts.get("test.rs:42:unknown-rule-xyz").unwrap();
        assert!(
            !roast.is_empty(),
            "Unknown rules must still produce a non-empty roast message"
        );
    }

    #[test]
    fn test_parse_response_with_markdown_wrapped_json() {
        // Objective: Verify end-to-end parsing with markdown-wrapped LLM output.
        // Invariants: JSON inside code fences must parse correctly.
        let issues = vec![make_issue("deep-nesting", 5)];
        let response =
            "Sure, here are the roasts:\n```json\n{\"0\": \"nested deeper than inception\"}\n```";
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(roasts.len(), 1, "Should parse one roast from fenced JSON");
        let roast = roasts.get("test.rs:5:deep-nesting").unwrap();
        assert_eq!(
            roast, "nested deeper than inception",
            "Roast content must match the JSON value"
        );
    }

    #[test]
    fn test_fix_trailing_commas_before_brace() {
        // Objective: Verify a trailing comma before `}` is removed.
        let input = r#"{"0": "a", "1": "b",}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"{"0": "a", "1": "b"}"#);
    }

    #[test]
    fn test_fix_trailing_commas_before_bracket() {
        // Objective: Verify a trailing comma before `]` is removed.
        let input = r#"["a", "b",]"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"["a", "b"]"#);
    }

    #[test]
    fn test_fix_trailing_commas_preserves_valid_json() {
        // Objective: Verify already-valid JSON passes through untouched.
        let input = r#"{"0": "a", "1": "b"}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, input, "Valid JSON should be unchanged");
    }

    #[test]
    fn test_fix_trailing_commas_handles_whitespace() {
        // Objective: Verify a trailing comma separated from `}` by whitespace is removed.
        // Invariants: Only the comma is dropped and the surrounding whitespace is kept,
        // so `"a" ,  \n}` becomes `"a"   \n}` (the space before the comma plus the two
        // after it). Asserting the exact output matters: the input never contains the
        // substring ",}", so a `!contains(",}")` check would pass even if nothing
        // were removed.
        let input = "{\"0\": \"a\" ,  \n}";
        let result = fix_trailing_commas(input);
        assert_eq!(
            result, "{\"0\": \"a\"   \n}",
            "Only the trailing comma must be removed"
        );
    }

    #[test]
    fn test_parse_response_with_trailing_comma() {
        // Objective: Verify LLM output with trailing commas is handled.
        // Invariants: Trailing commas must be stripped before parsing.
        let issues = vec![
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let response = "```json\n{\"0\": \"nice unwrap\", \"1\": \"so deep\",}\n```";
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(
            roasts.len(),
            2,
            "Should parse both roasts despite trailing comma"
        );
    }

    #[test]
    fn test_extract_json_from_generic_code_fence() {
        // Objective: Verify JSON in ``` fences (without json tag) is extracted.
        // Invariants: Only the content between the fences is returned, trimmed.
        let response = "Here:\n```\n{\"0\": \"roast\"}\n```";
        let result = extract_json_from_response(response);
        assert_eq!(result, "{\"0\": \"roast\"}");
    }
}