use std::collections::HashMap;
use std::path::PathBuf;
use crate::analyzer::CodeIssue;
use crate::i18n::I18n;
use super::client::{LlmClient, LlmConfig};
use super::prompt::build_roast_prompt;
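/// Maps an issue key (`path:line:rule`) to its roast message.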
pub type RoastMap = HashMap<String, String>;
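/// A source of roast messages for analyzer-reported issues.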
pub trait RoastProvider {
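    /// Returns one roast per issue, keyed by `issue_key`.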
fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap;
}
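/// Builds the canonical lookup key for an issue: `path:line:rule`.
/// Providers, context extraction, and response parsing all depend on
/// this exact format, so they share this helper.
fn issue_key(issue: &CodeIssue) -> String {
    format!(
        "{}:{}:{}",
        issue.file_path.display(),
        issue.line,
        issue.rule_name
    )
}
/// Offline provider that serves canned roasts from the i18n catalog.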
pub struct LocalRoastProvider;
impl RoastProvider for LocalRoastProvider {
fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
let i18n = I18n::new(lang);
let mut map = RoastMap::new();
for issue in issues {
            let key = issue_key(issue);
let messages = i18n.get_roast_messages(&issue.rule_name);
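            // Pick a message deterministically from the line number so
            // repeated runs stay stable; fall back to the analyzer's own
            // message when no canned roast exists for this rule.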
let roast = if !messages.is_empty() {
messages[issue.line % messages.len()].clone()
} else {
issue.message.clone()
};
map.insert(key, roast);
}
map
}
}
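/// Provider that asks an LLM for roasts, degrading gracefully to
/// `LocalRoastProvider` when the call or response parsing fails.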
pub struct LlmRoastProvider {
client: LlmClient,
fallback: LocalRoastProvider,
}
impl LlmRoastProvider {
pub fn new(config: LlmConfig) -> Self {
Self {
client: LlmClient::new(config),
fallback: LocalRoastProvider,
}
}
}
impl RoastProvider for LlmRoastProvider {
fn generate_roasts(&self, issues: &[CodeIssue], lang: &str) -> RoastMap {
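        // One prompt covers the whole batch of issues; every failure
        // path below falls back to the offline provider.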
let contexts = extract_code_contexts(issues);
let prompt = build_roast_prompt(issues, &contexts, lang);
tracing::debug!("Calling LLM with {} issues...", issues.len());
        // Take the first 500 *chars*: slicing the prompt at a fixed byte
        // offset could panic inside a multi-byte UTF-8 character.
        tracing::debug!(
            "Prompt (first 500 chars): {}",
            prompt.chars().take(500).collect::<String>()
        );
match self.client.call_blocking(&prompt) {
Ok(response) => {
tracing::debug!("LLM response received ({} chars)", response.len());
match parse_llm_response(&response, issues) {
Ok(roasts) => {
tracing::debug!("Parsed {} roasts from LLM", roasts.len());
roasts
}
Err(e) => {
tracing::warn!(
"Failed to parse LLM response: {:#}. Falling back to local roasts.",
e
);
self.fallback.generate_roasts(issues, lang)
}
}
}
Err(e) => {
tracing::warn!("LLM call failed: {:#}. Falling back to local roasts.", e);
self.fallback.generate_roasts(issues, lang)
}
}
}
}
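/// Returns up to five numbered source lines on either side of each
/// issue, keyed by `issue_key`. Reads the affected files from disk.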
fn extract_code_contexts(issues: &[CodeIssue]) -> HashMap<String, String> {
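    // Deduplicate paths via a HashSet so each file is read at most once.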
let file_paths: Vec<PathBuf> = issues
.iter()
.map(|i| i.file_path.clone())
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
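    // Cache file contents up front; unreadable files are skipped, and
    // their issues simply get no context.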
let file_contents: HashMap<PathBuf, Vec<String>> = file_paths
.into_iter()
.filter_map(|path| {
let content = std::fs::read_to_string(&path).ok()?;
let lines: Vec<String> = content.lines().map(String::from).collect();
Some((path, lines))
})
.collect();
let mut contexts = HashMap::new();
for issue in issues {
        let key = issue_key(issue);
        if let Some(lines) = file_contents.get(&issue.file_path) {
            // `issue.line` is 1-based. Clamp both bounds so a stale line
            // number past end-of-file cannot panic the slice below.
            let end = (issue.line + 5).min(lines.len());
            let start = issue.line.saturating_sub(6).min(end);
            let context: String = lines[start..end]
.iter()
.enumerate()
.map(|(i, l)| format!("{:>4} | {}", start + i + 1, l))
.collect::<Vec<_>>()
.join("\n");
contexts.insert(key, context);
}
}
contexts
}
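/// Parses the model's `{"<issue index>": "<roast>"}` object into a
/// `RoastMap`, tolerating code fences and trailing commas.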
fn parse_llm_response(response: &str, issues: &[CodeIssue]) -> Result<RoastMap, anyhow::Error> {
let json_str = extract_json_from_response(response);
let cleaned = fix_trailing_commas(json_str);
let parsed: HashMap<String, String> = serde_json::from_str(&cleaned)?;
let mut roasts = RoastMap::new();
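    // Keys are expected to be stringified indices into `issues`; skip
    // anything the model invents that is non-numeric or out of range.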
for (idx_str, roast) in parsed {
let Ok(idx) = idx_str.parse::<usize>() else {
continue;
};
if idx >= issues.len() {
continue;
}
let issue = &issues[idx];
        let key = issue_key(issue);
roasts.insert(key, roast);
}
Ok(roasts)
}
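/// Extracts the JSON payload from a raw LLM response, trying in order:
/// a `json`-tagged code fence, any generic code fence, then the
/// outermost brace-delimited span. Falls back to the whole response.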
fn extract_json_from_response(response: &str) -> &str {
    if let Some(start) = response.find("```json") {
        let json_start = start + "```json".len();
if let Some(end) = response[json_start..].find("```") {
return response[json_start..json_start + end].trim();
}
}
    if let Some(start) = response.find("```") {
        let fence_start = start + "```".len();
        // Skip the rest of the opening-fence line (usually a language tag).
        let content_start = response[fence_start..]
            .find('\n')
            .map(|n| fence_start + n + 1)
            .unwrap_or(fence_start);
if let Some(end) = response[content_start..].find("```") {
return response[content_start..content_start + end].trim();
}
}
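    // No code fence found: take the outermost brace-delimited span.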
if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') {
return &response[start..=end];
}
}
response
}
/// Strips trailing commas before `}` or `]`, which LLMs often emit but
/// strict JSON parsers reject. Commas inside string literals are kept.
fn fix_trailing_commas(json: &str) -> String {
    let mut result = String::with_capacity(json.len());
    let mut in_string = false;
    let mut escaped = false;
    // Iterate over chars, not bytes: pushing raw bytes one at a time
    // would mangle any multi-byte UTF-8 in the roast text.
    for (i, ch) in json.char_indices() {
        if in_string {
            match ch {
                _ if escaped => escaped = false,
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
        } else if ch == '"' {
            in_string = true;
        } else if ch == ',' {
            let trimmed = json[i + 1..].trim_start();
            if trimmed.starts_with('}') || trimmed.starts_with(']') {
                continue;
            }
        }
        result.push(ch);
    }
    result
}
#[cfg(test)]
mod tests {
use super::*;
use crate::analyzer::Severity;
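    /// Builds a minimal `CodeIssue` against `test.rs` for the given rule and line.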
fn make_issue(rule: &str, line: usize) -> CodeIssue {
CodeIssue {
file_path: PathBuf::from("test.rs"),
line,
column: 1,
rule_name: rule.to_string(),
message: "test message".to_string(),
severity: Severity::Spicy,
}
}
#[test]
fn test_extract_json_from_plain_object() {
let response = r#"{"0": "roast one", "1": "roast two"}"#;
let result = extract_json_from_response(response);
assert_eq!(result, response, "Plain JSON should be returned as-is");
}
#[test]
fn test_extract_json_from_markdown_fence() {
let response = "Here is the JSON:\n```json\n{\"0\": \"roast\"}\n```\nDone.";
let result = extract_json_from_response(response);
assert_eq!(
result, "{\"0\": \"roast\"}",
"JSON inside markdown fences should be extracted"
);
}
#[test]
fn test_parse_response_maps_indices_to_issue_keys() {
let issues = vec![
make_issue("unwrap-abuse", 10),
make_issue("deep-nesting", 25),
];
let response = r#"{"0": "nice unwrap", "1": "so deep"}"#;
let roasts = parse_llm_response(response, &issues).unwrap();
assert_eq!(roasts.len(), 2, "Should have roasts for both issues");
assert!(
roasts.contains_key("test.rs:10:unwrap-abuse"),
"First issue key must be test.rs:10:unwrap-abuse"
);
assert!(
roasts.contains_key("test.rs:25:deep-nesting"),
"Second issue key must be test.rs:25:deep-nesting"
);
}
#[test]
fn test_parse_response_skips_out_of_range_indices() {
let issues = vec![make_issue("unwrap-abuse", 10)];
let response = r#"{"0": "valid", "5": "out of range", "abc": "not a number"}"#;
let roasts = parse_llm_response(response, &issues).unwrap();
assert_eq!(
roasts.len(),
1,
"Only the valid index should produce a roast"
);
assert!(
roasts.contains_key("test.rs:10:unwrap-abuse"),
"Valid index 0 should map to the first issue"
);
}
#[test]
fn test_local_provider_returns_roasts_for_known_rules() {
let issues = vec![make_issue("unwrap-abuse", 1)];
let provider = LocalRoastProvider;
let roasts = provider.generate_roasts(&issues, "en-US");
assert!(
!roasts.is_empty(),
"LocalRoastProvider must return at least one roast for known rules"
);
assert!(
roasts.contains_key("test.rs:1:unwrap-abuse"),
"Roast key must match the issue key format"
);
}
#[test]
fn test_local_provider_returns_something_for_unknown_rules() {
let issues = vec![make_issue("unknown-rule-xyz", 42)];
let provider = LocalRoastProvider;
let roasts = provider.generate_roasts(&issues, "en-US");
assert_eq!(
roasts.len(),
1,
"Should have exactly one roast for one issue"
);
let roast = roasts.get("test.rs:42:unknown-rule-xyz").unwrap();
assert!(
!roast.is_empty(),
"Unknown rules must still produce a non-empty roast message"
);
}
#[test]
fn test_parse_response_with_markdown_wrapped_json() {
let issues = vec![make_issue("deep-nesting", 5)];
let response =
"Sure, here are the roasts:\n```json\n{\"0\": \"nested deeper than inception\"}\n```";
let roasts = parse_llm_response(response, &issues).unwrap();
assert_eq!(roasts.len(), 1, "Should parse one roast from fenced JSON");
let roast = roasts.get("test.rs:5:deep-nesting").unwrap();
assert_eq!(
roast, "nested deeper than inception",
"Roast content must match the JSON value"
);
}
#[test]
fn test_fix_trailing_commas_before_brace() {
let input = r#"{"0": "a", "1": "b",}"#;
let result = fix_trailing_commas(input);
assert_eq!(result, r#"{"0": "a", "1": "b"}"#);
}
#[test]
fn test_fix_trailing_commas_before_bracket() {
let input = r#"["a", "b",]"#;
let result = fix_trailing_commas(input);
assert_eq!(result, r#"["a", "b"]"#);
}
#[test]
fn test_fix_trailing_commas_preserves_valid_json() {
let input = r#"{"0": "a", "1": "b"}"#;
let result = fix_trailing_commas(input);
assert_eq!(result, input, "Valid JSON should be unchanged");
}
#[test]
fn test_fix_trailing_commas_handles_whitespace() {
let input = "{\"0\": \"a\" , \n}";
let result = fix_trailing_commas(input);
assert!(!result.contains(",}"), "Trailing comma must be removed");
assert!(result.contains("\"a\""), "Content must be preserved");
}
#[test]
fn test_parse_response_with_trailing_comma() {
let issues = vec![
make_issue("unwrap-abuse", 10),
make_issue("deep-nesting", 25),
];
let response = "```json\n{\"0\": \"nice unwrap\", \"1\": \"so deep\",}\n```";
let roasts = parse_llm_response(response, &issues).unwrap();
assert_eq!(
roasts.len(),
2,
"Should parse both roasts despite trailing comma"
);
}
#[test]
fn test_extract_json_from_generic_code_fence() {
let response = "Here:\n```\n{\"0\": \"roast\"}\n```";
let result = extract_json_from_response(response);
assert_eq!(result, "{\"0\": \"roast\"}");
}
}