Skip to main content

garbage_code_hunter/llm/
provider.rs

1//! Roast provider abstraction for generating code review messages.
2//!
3//! This module defines the `RoastProvider` trait and two implementations:
4//! - `LocalRoastProvider`: Uses hardcoded roast messages from the i18n module.
5//! - `LlmRoastProvider`: Calls an LLM endpoint to generate dynamic, context-aware roasts.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9
10use crate::analyzer::CodeIssue;
11use crate::i18n::I18n;
12
13use super::client::{LlmClient, LlmConfig};
14use super::prompt::build_roast_prompt;
15
/// Roast messages indexed by issue key.
///
/// Every key has the shape `"{file_path}:{line}:{rule_name}"`, built from an issue's path
/// (rendered with `Path::display`), its line number and the name of the rule it violated.
pub type RoastMap = HashMap<String, String>;
20
21/// Trait for generating roast messages for code issues.
22///
23/// Implementors can use local hardcoded messages or call external LLM services.
24pub trait RoastProvider {
25    /// Generate roast messages for the given issues.
26    ///
27    /// Returns a `RoastMap` mapping issue keys to roast messages.
28    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap;
29}
30
/// Local roast provider using hardcoded messages from the i18n module.
///
/// This is the default provider and serves as the fallback when LLM calls fail. It is a
/// stateless unit struct, so it is freely copyable and constructible via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalRoastProvider;
35
36impl RoastProvider for LocalRoastProvider {
37    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
38        let i18n = I18n::new(lang);
39        let mut map = RoastMap::new();
40
41        for issue in issues {
42            let key = format!(
43                "{}:{}:{}",
44                issue.file_path.display(),
45                issue.line,
46                issue.rule_name
47            );
48            let messages = i18n.get_roast_messages(&issue.rule_name);
49            let roast = if !messages.is_empty() {
50                messages[issue.line % messages.len()].clone()
51            } else {
52                issue.message.clone()
53            };
54            map.insert(key, roast);
55        }
56
57        map
58    }
59}
60
61/// LLM-powered roast provider that generates dynamic, context-aware roasts.
62///
63/// Falls back to `LocalRoastProvider` if the LLM call fails or returns invalid data.
64pub struct LlmRoastProvider {
65    client: LlmClient,
66    fallback: LocalRoastProvider,
67}
68
69impl LlmRoastProvider {
70    /// Create a new LLM roast provider with the given configuration.
71    pub fn new(config: LlmConfig) -> Self {
72        Self {
73            client: LlmClient::new(config),
74            fallback: LocalRoastProvider,
75        }
76    }
77}
78
79impl RoastProvider for LlmRoastProvider {
80    fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
81        let contexts = extract_code_contexts(issues);
82        let prompt = build_roast_prompt(issues, &contexts, lang);
83
84        tracing::debug!("Calling LLM with {} issues...", issues.len());
85        tracing::debug!(
86            "Prompt (first 500 chars): {}",
87            &prompt[..prompt.len().min(500)]
88        );
89
90        match self.client.call_blocking(&prompt) {
91            Ok(response) => {
92                tracing::debug!("LLM response received ({} chars)", response.len());
93                match parse_llm_response(&response, issues) {
94                    Ok(roasts) => {
95                        tracing::debug!("Parsed {} roasts from LLM", roasts.len());
96                        roasts
97                    }
98                    Err(e) => {
99                        tracing::warn!(
100                            "Failed to parse LLM response: {:#}. Falling back to local roasts.",
101                            e
102                        );
103                        self.fallback.generate_roasts(issues, lang)
104                    }
105                }
106            }
107            Err(e) => {
108                tracing::warn!("LLM call failed: {:#}. Falling back to local roasts.", e);
109                self.fallback.generate_roasts(issues, lang)
110            }
111        }
112    }
113}
114
115/// Extract code context (±5 lines) around each issue for the LLM prompt.
116///
117/// Groups issues by file to avoid reading the same file multiple times.
118fn extract_code_contexts(issues: &[CodeIssue]) -> HashMap<String, String> {
119    // Collect unique file paths
120    let file_paths: Vec<PathBuf> = issues
121        .iter()
122        .map(|i| i.file_path.clone())
123        .collect::<std::collections::HashSet<_>>()
124        .into_iter()
125        .collect();
126
127    // Read all file contents upfront
128    let file_contents: HashMap<PathBuf, Vec<String>> = file_paths
129        .into_iter()
130        .filter_map(|path| {
131            let content = std::fs::read_to_string(&path).ok()?;
132            let lines: Vec<String> = content.lines().map(String::from).collect();
133            Some((path, lines))
134        })
135        .collect();
136
137    // Extract context window for each issue
138    let mut contexts = HashMap::new();
139    for issue in issues {
140        let key = format!(
141            "{}:{}:{}",
142            issue.file_path.display(),
143            issue.line,
144            issue.rule_name
145        );
146
147        if let Some(lines) = file_contents.get(&issue.file_path) {
148            let start = issue.line.saturating_sub(6);
149            let end = (issue.line + 5).min(lines.len());
150            let context: String = lines[start..end]
151                .iter()
152                .enumerate()
153                .map(|(i, l)| format!("{:>4} | {}", start + i + 1, l))
154                .collect::<Vec<_>>()
155                .join("\n");
156            contexts.insert(key, context);
157        }
158    }
159
160    contexts
161}
162
163/// Parse the LLM response JSON into a RoastMap.
164///
165/// Expected format: `{"0": "roast message", "1": "roast message", ...}`
166/// where keys are issue indices (0-based) matching the order in the prompt.
167fn parse_llm_response(response: &str, issues: &[CodeIssue]) -> Result<RoastMap, anyhow::Error> {
168    let json_str = extract_json_from_response(response);
169    // LLMs often produce trailing commas in JSON — strip them for robustness
170    let cleaned = fix_trailing_commas(json_str);
171    let parsed: HashMap<String, String> = serde_json::from_str(&cleaned)?;
172
173    let mut roasts = RoastMap::new();
174    for (idx_str, roast) in parsed {
175        let Ok(idx) = idx_str.parse::<usize>() else {
176            continue;
177        };
178        if idx >= issues.len() {
179            continue;
180        }
181        let issue = &issues[idx];
182        let key = format!(
183            "{}:{}:{}",
184            issue.file_path.display(),
185            issue.line,
186            issue.rule_name
187        );
188        roasts.insert(key, roast);
189    }
190
191    Ok(roasts)
192}
193
/// Extract JSON from an LLM response, handling markdown code fences and plain JSON.
///
/// Tried in order:
/// 1. The body of a ```` ```json ```` fence.
/// 2. The body of any other ```` ``` ```` fence. The rest of the opening line is skipped
///    as a language tag, unless it already contains JSON (`{` or `[`).
/// 3. The span from the first `{` to the last `}`, when that `}` comes after the `{`.
///
/// If none applies, the whole response is returned unchanged so the JSON parser can
/// report a meaningful error. This function never panics.
fn extract_json_from_response(response: &str) -> &str {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    if let Some(start) = response.find(JSON_FENCE) {
        let body_start = start + JSON_FENCE.len();
        if let Some(len) = response[body_start..].find(FENCE) {
            return response[body_start..body_start + len].trim();
        }
    }

    if let Some(start) = response.find(FENCE) {
        let fence_end = start + FENCE.len();
        let rest = &response[fence_end..];
        let body_start = match rest.find('\n') {
            // The opening line holds only an optional language tag (e.g. "JSON").
            Some(nl) if !rest[..nl].contains(|c: char| c == '{' || c == '[') => {
                fence_end + nl + 1
            }
            // No newline, or JSON starts on the fence line itself.
            _ => fence_end,
        };
        if let Some(len) = response[body_start..].find(FENCE) {
            return response[body_start..body_start + len].trim();
        }
    }

    // Only slice when the closing brace follows the opening one; otherwise inputs such as
    // "} ... {" would produce an inverted range and panic.
    if let (Some(open), Some(close)) = (response.find('{'), response.rfind('}')) {
        if open < close {
            return &response[open..=close];
        }
    }

    response
}
226
/// Remove trailing commas from JSON before closing braces/brackets.
///
/// LLMs frequently produce invalid JSON like `{"a": 1, "b": 2,}`. This function strips a
/// comma whenever the next non-whitespace character is `}` or `]`. Whitespace around the
/// comma is kept.
///
/// The input is processed character by character, so non-ASCII text (localized roasts,
/// emoji) is preserved exactly. Commas inside string literals, including after escaped
/// quotes, are treated as content and never removed.
fn fix_trailing_commas(json: &str) -> String {
    let mut result = String::with_capacity(json.len());
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in json.char_indices() {
        if in_string {
            result.push(ch);
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            ',' => {
                // `,` is one byte, so `i + 1` is always a char boundary.
                let rest = json[i + 1..].trim_start();
                if rest.starts_with('}') || rest.starts_with(']') {
                    continue;
                }
            }
            _ => {}
        }
        result.push(ch);
    }

    result
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::Severity;

    /// Helper to create a test CodeIssue with minimal fields.
    fn make_issue(rule: &str, line: usize) -> CodeIssue {
        CodeIssue {
            file_path: PathBuf::from("test.rs"),
            line,
            column: 1,
            rule_name: rule.to_string(),
            message: "test message".to_string(),
            severity: Severity::Spicy,
        }
    }

    #[test]
    fn test_extract_json_from_plain_object() {
        // Objective: Verify plain JSON objects are extracted correctly.
        // Invariants: Output must match the input when it is a valid JSON object.
        let response = r#"{"0": "roast one", "1": "roast two"}"#;
        let result = extract_json_from_response(response);
        assert_eq!(result, response, "Plain JSON should be returned as-is");
    }

    #[test]
    fn test_extract_json_from_markdown_fence() {
        // Objective: Verify JSON wrapped in ```json fences is extracted.
        // Invariants: Only the JSON content between fences is returned.
        let response = "Here is the JSON:\n```json\n{\"0\": \"roast\"}\n```\nDone.";
        let result = extract_json_from_response(response);
        assert_eq!(
            result, "{\"0\": \"roast\"}",
            "JSON inside markdown fences should be extracted"
        );
    }

    #[test]
    fn test_parse_response_maps_indices_to_issue_keys() {
        // Objective: Verify LLM response indices map to correct issue keys.
        // Invariants: Each index maps to the corresponding issue's key format.
        let issues = vec![
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let response = r#"{"0": "nice unwrap", "1": "so deep"}"#;
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(roasts.len(), 2, "Should have roasts for both issues");
        assert!(
            roasts.contains_key("test.rs:10:unwrap-abuse"),
            "First issue key must be test.rs:10:unwrap-abuse"
        );
        assert!(
            roasts.contains_key("test.rs:25:deep-nesting"),
            "Second issue key must be test.rs:25:deep-nesting"
        );
    }

    #[test]
    fn test_parse_response_skips_out_of_range_indices() {
        // Objective: Verify out-of-range indices are silently ignored.
        // Invariants: Only valid indices produce roasts; invalid ones are skipped.
        let issues = vec![make_issue("unwrap-abuse", 10)];
        let response = r#"{"0": "valid", "5": "out of range", "abc": "not a number"}"#;
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(
            roasts.len(),
            1,
            "Only the valid index should produce a roast"
        );
        assert!(
            roasts.contains_key("test.rs:10:unwrap-abuse"),
            "Valid index 0 should map to the first issue"
        );
    }

    #[test]
    fn test_parse_response_rejects_malformed_json() {
        // Objective: Verify non-JSON output is reported as an error, not an empty map.
        // Invariants: Callers rely on Err to trigger the local fallback.
        let issues = vec![make_issue("unwrap-abuse", 10)];
        let result = parse_llm_response("I refuse to roast this code.", &issues);
        assert!(result.is_err(), "Malformed LLM output must yield an error");
    }

    #[test]
    fn test_local_provider_returns_roasts_for_known_rules() {
        // Objective: Verify LocalRoastProvider produces roasts for rules with i18n messages.
        // Invariants: At least one roast must be returned for a known rule name.
        let issues = vec![make_issue("unwrap-abuse", 1)];
        let provider = LocalRoastProvider;
        let roasts = provider.generate_roasts(&issues, "en-US");

        assert!(
            !roasts.is_empty(),
            "LocalRoastProvider must return at least one roast for known rules"
        );
        assert!(
            roasts.contains_key("test.rs:1:unwrap-abuse"),
            "Roast key must match the issue key format"
        );
    }

    #[test]
    fn test_local_provider_returns_something_for_unknown_rules() {
        // Objective: Verify unknown rules still produce a roast message.
        // Invariants: The i18n module returns a catch-all message for unknown rules.
        let issues = vec![make_issue("unknown-rule-xyz", 42)];
        let provider = LocalRoastProvider;
        let roasts = provider.generate_roasts(&issues, "en-US");

        assert_eq!(
            roasts.len(),
            1,
            "Should have exactly one roast for one issue"
        );
        let roast = roasts.get("test.rs:42:unknown-rule-xyz").unwrap();
        assert!(
            !roast.is_empty(),
            "Unknown rules must still produce a non-empty roast message"
        );
    }

    #[test]
    fn test_parse_response_with_markdown_wrapped_json() {
        // Objective: Verify end-to-end parsing with markdown-wrapped LLM output.
        // Invariants: JSON inside code fences must parse correctly.
        let issues = vec![make_issue("deep-nesting", 5)];
        let response =
            "Sure, here are the roasts:\n```json\n{\"0\": \"nested deeper than inception\"}\n```";
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(roasts.len(), 1, "Should parse one roast from fenced JSON");
        let roast = roasts.get("test.rs:5:deep-nesting").unwrap();
        assert_eq!(
            roast, "nested deeper than inception",
            "Roast content must match the JSON value"
        );
    }

    #[test]
    fn test_fix_trailing_commas_before_brace() {
        let input = r#"{"0": "a", "1": "b",}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"{"0": "a", "1": "b"}"#);
    }

    #[test]
    fn test_fix_trailing_commas_before_bracket() {
        let input = r#"["a", "b",]"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, r#"["a", "b"]"#);
    }

    #[test]
    fn test_fix_trailing_commas_preserves_valid_json() {
        let input = r#"{"0": "a", "1": "b"}"#;
        let result = fix_trailing_commas(input);
        assert_eq!(result, input, "Valid JSON should be unchanged");
    }

    #[test]
    fn test_fix_trailing_commas_handles_whitespace() {
        // " ,  \n}" -> comma removed -> "   \n}" (space before comma + spaces after).
        // Assert the exact output: the input never contained ",}" verbatim, so a
        // `!contains(",}")` check would pass even if nothing were removed.
        let input = "{\"0\": \"a\" ,  \n}";
        let result = fix_trailing_commas(input);
        assert_eq!(
            result, "{\"0\": \"a\"   \n}",
            "Only the trailing comma must be removed"
        );
        assert!(!result.contains(','), "Trailing comma must be removed");
    }

    #[test]
    fn test_parse_response_with_trailing_comma() {
        // Objective: Verify LLM output with trailing commas is handled.
        // Invariants: Trailing commas must be stripped before parsing.
        let issues = vec![
            make_issue("unwrap-abuse", 10),
            make_issue("deep-nesting", 25),
        ];
        let response = "```json\n{\"0\": \"nice unwrap\", \"1\": \"so deep\",}\n```";
        let roasts = parse_llm_response(response, &issues).unwrap();

        assert_eq!(
            roasts.len(),
            2,
            "Should parse both roasts despite trailing comma"
        );
    }

    #[test]
    fn test_extract_json_from_generic_code_fence() {
        // Objective: Verify JSON in ``` fences (without json tag) is extracted.
        let response = "Here:\n```\n{\"0\": \"roast\"}\n```";
        let result = extract_json_from_response(response);
        assert_eq!(result, "{\"0\": \"roast\"}");
    }
}