use crate::corpus::evaluation::{evaluate, EvaluationReport};
use crate::linter;
use serde::Serialize;
// Composite-score weights for each evaluation dimension.
// The six weights sum to 1.0, so the composite score is bounded by [0, 1].
pub const DETECTION_F1_WEIGHT: f64 = 0.25;
pub const RULE_CITATION_WEIGHT: f64 = 0.20;
pub const CWE_MAPPING_WEIGHT: f64 = 0.10;
pub const FIX_VALIDITY_WEIGHT: f64 = 0.15;
pub const EXPLANATION_WEIGHT: f64 = 0.15;
pub const OOD_WEIGHT: f64 = 0.15;
/// A model's judgment of a single benchmark entry, joined to ground truth by `id`.
#[derive(Debug, Clone)]
pub struct Prediction {
    /// Entry identifier; must match a `GroundTruth::id` to be scored.
    pub id: String,
    /// Verdict string; `"unsafe"` maps to label 1, anything else to 0.
    pub classification: String,
    /// Lint rule codes the model cited for its verdict.
    pub cited_rules: Vec<String>,
    /// CWE identifiers the model cited for its verdict.
    pub cited_cwes: Vec<String>,
    /// Optional rewritten script; re-linted to score fix validity.
    pub proposed_fix: Option<String>,
    /// Free-text rationale; scored heuristically for quality.
    pub explanation: String,
}
/// Reference answer for one benchmark entry.
#[derive(Debug, Clone)]
pub struct GroundTruth {
    /// Entry identifier; join key against `Prediction::id`.
    pub id: String,
    /// True label: 1 = unsafe, 0 = safe.
    pub label: u8,
    /// Rule codes that should fire on the unsafe script.
    pub rules: Vec<String>,
    /// CWE identifiers associated with the unsafe script.
    pub cwes: Vec<String>,
    /// The original script text (not consumed by the metrics in this file).
    pub script: String,
}
/// Full evaluation output: raw metric values, their weighted contributions,
/// and the composite score.
#[derive(Debug, Clone, Serialize)]
pub struct EvalResult {
    /// Detection report from `evaluate` over (predicted, actual) label pairs.
    pub detection: EvaluationReport,
    /// Raw detection F1, copied out of `detection` for convenience.
    pub detection_f1: f64,
    /// Fraction of unsafe entries where the model cited at least one true rule.
    pub rule_citation: f64,
    /// Fraction of unsafe entries where the model cited at least one true CWE.
    pub cwe_mapping: f64,
    /// Fraction of proposed fixes on which no ground-truth rule still fires.
    pub fix_validity: f64,
    /// Heuristic 0.0-1.0 explanation-quality score, averaged over predictions.
    pub explanation_quality: f64,
    /// OOD score; currently always 0.0 (not computed by `run_eval`).
    pub ood_generalization: f64,
    /// Sum of all weighted components.
    pub composite_score: f64,
    /// Each raw metric multiplied by its weight constant.
    pub weighted_breakdown: WeightedBreakdown,
    /// Total number of predictions evaluated (matched to ground truth or not).
    pub total: usize,
    /// Left `None` by `run_eval`; presumably filled by a caller — confirm.
    pub static_dynamic_gap: Option<f64>,
    /// Left `None` by `run_eval`; presumably filled by a caller — confirm.
    pub model_mcc_vs_keyword: Option<f64>,
}
/// Per-metric weighted contributions (raw metric x its weight constant);
/// these sum to `EvalResult::composite_score`.
#[derive(Debug, Clone, Serialize)]
pub struct WeightedBreakdown {
    pub detection_f1: f64,
    pub rule_citation: f64,
    pub cwe_mapping: f64,
    pub fix_validity: f64,
    pub explanation: f64,
    pub ood: f64,
}
pub fn run_eval(predictions: &[Prediction], ground_truth: &[GroundTruth]) -> EvalResult {
let gt_map: std::collections::HashMap<&str, &GroundTruth> =
ground_truth.iter().map(|gt| (gt.id.as_str(), gt)).collect();
let detection_pairs: Vec<(u8, u8)> = predictions
.iter()
.filter_map(|p| {
gt_map.get(p.id.as_str()).map(|gt| {
let pred_label = if p.classification == "unsafe" {
1u8
} else {
0u8
};
(pred_label, gt.label)
})
})
.collect();
let detection = evaluate(&detection_pairs, "model");
let detection_f1 = detection.f1;
let rule_citation = compute_rule_citation(predictions, >_map);
let cwe_mapping = compute_cwe_mapping(predictions, >_map);
let fix_validity = compute_fix_validity(predictions, >_map);
let explanation_quality = compute_explanation_quality(predictions);
let ood_generalization = 0.0;
let weighted = WeightedBreakdown {
detection_f1: detection_f1 * DETECTION_F1_WEIGHT,
rule_citation: rule_citation * RULE_CITATION_WEIGHT,
cwe_mapping: cwe_mapping * CWE_MAPPING_WEIGHT,
fix_validity: fix_validity * FIX_VALIDITY_WEIGHT,
explanation: explanation_quality * EXPLANATION_WEIGHT,
ood: ood_generalization * OOD_WEIGHT,
};
let composite = weighted.detection_f1
+ weighted.rule_citation
+ weighted.cwe_mapping
+ weighted.fix_validity
+ weighted.explanation
+ weighted.ood;
EvalResult {
detection,
detection_f1,
rule_citation,
cwe_mapping,
fix_validity,
explanation_quality,
ood_generalization,
composite_score: composite,
weighted_breakdown: weighted,
total: predictions.len(),
static_dynamic_gap: None,
model_mcc_vs_keyword: None,
}
}
/// Fraction of scoreable predictions that cite at least one ground-truth rule.
///
/// Only entries whose ground truth is unsafe (`label == 1`) and lists at
/// least one rule are scoreable; returns 0.0 when none qualify.
fn compute_rule_citation(
    predictions: &[Prediction],
    gt_map: &std::collections::HashMap<&str, &GroundTruth>,
) -> f64 {
    // One bool per scoreable prediction: did the model cite a true rule?
    let outcomes: Vec<bool> = predictions
        .iter()
        .filter_map(|p| gt_map.get(p.id.as_str()).map(|gt| (p, *gt)))
        .filter(|(_, gt)| gt.label == 1 && !gt.rules.is_empty())
        .map(|(p, gt)| p.cited_rules.iter().any(|r| gt.rules.contains(r)))
        .collect();
    if outcomes.is_empty() {
        0.0
    } else {
        let hits = outcomes.iter().filter(|&&hit| hit).count();
        hits as f64 / outcomes.len() as f64
    }
}
/// Fraction of scoreable predictions that cite at least one ground-truth CWE.
///
/// Only entries whose ground truth is unsafe (`label == 1`) and lists at
/// least one CWE are scoreable; returns 0.0 when none qualify.
fn compute_cwe_mapping(
    predictions: &[Prediction],
    gt_map: &std::collections::HashMap<&str, &GroundTruth>,
) -> f64 {
    // Fold into (correct, total) counters over the scoreable subset.
    let (correct, total) = predictions
        .iter()
        .fold((0usize, 0usize), |(hits, seen), pred| {
            match gt_map.get(pred.id.as_str()) {
                Some(gt) if gt.label == 1 && !gt.cwes.is_empty() => {
                    let matched = pred.cited_cwes.iter().any(|cwe| gt.cwes.contains(cwe));
                    (hits + usize::from(matched), seen + 1)
                }
                _ => (hits, seen),
            }
        });
    if total == 0 {
        0.0
    } else {
        correct as f64 / total as f64
    }
}
/// Fraction of scored fixes that silence every ground-truth rule when re-linted.
///
/// A prediction is scored only when it proposes a non-empty fix and its
/// ground truth is unsafe (`label == 1`). A fix is valid when none of the
/// ground-truth rule codes still appear in the lint diagnostics of the
/// rewritten script.
fn compute_fix_validity(
    predictions: &[Prediction],
    gt_map: &std::collections::HashMap<&str, &GroundTruth>,
) -> f64 {
    let mut fixed = 0usize;
    let mut scored = 0usize;
    for pred in predictions {
        // Guard: require a non-empty proposed fix.
        let fix = match pred.proposed_fix.as_deref() {
            Some(f) if !f.is_empty() => f,
            _ => continue,
        };
        // Guard: require matching ground truth that is labeled unsafe.
        let gt = match gt_map.get(pred.id.as_str()) {
            Some(gt) if gt.label == 1 => gt,
            _ => continue,
        };
        scored += 1;
        // Re-lint the rewritten script and check whether any original
        // ground-truth rule code still fires on it.
        let lint = linter::lint_shell(fix);
        let any_rule_remaining = gt
            .rules
            .iter()
            .any(|rule| lint.diagnostics.iter().any(|d| d.code == *rule));
        if !any_rule_remaining {
            fixed += 1;
        }
    }
    if scored == 0 {
        // NOTE(review): neutral 0.5 when nothing was scored — differs from the
        // 0.0 fallback used by the other metrics; presumably intentional, confirm.
        0.5
    } else {
        fixed as f64 / scored as f64
    }
}
/// Heuristic 0.0-1.0 explanation-quality score, averaged over all predictions.
///
/// Each prediction earns up to four 0.25-point components:
/// 1. the explanation mentions "safe" or "unsafe";
/// 2. at least one rule is cited;
/// 3. unsafe verdicts carry a substantive (>50 char) explanation
///    (safe verdicts get this point automatically);
/// 4. unsafe verdicts use actionable wording
///    (safe verdicts get this point automatically).
///
/// Returns 0.0 for an empty slice. Classifications that are neither "safe"
/// nor "unsafe" can earn only the first two components.
fn compute_explanation_quality(predictions: &[Prediction]) -> f64 {
    if predictions.is_empty() {
        return 0.0;
    }
    const ACTIONABLE_KEYWORDS: [&str; 6] = ["use", "instead", "replace", "remove", "avoid", "fix"];
    let mut total_score = 0.0;
    for pred in predictions {
        let mut score = 0.0;
        // NOTE: "unsafe" contains "safe", so the first clause alone would
        // match both; the second clause is kept for explicitness.
        if pred.explanation.contains("safe") || pred.explanation.contains("unsafe") {
            score += 0.25;
        }
        if !pred.cited_rules.is_empty() {
            score += 0.25;
        }
        let is_safe = pred.classification == "safe";
        let is_unsafe = pred.classification == "unsafe";
        // Substantive-explanation point: required for unsafe, free for safe.
        if is_safe || (is_unsafe && pred.explanation.len() > 50) {
            score += 0.25;
        }
        // Actionable-wording point. Lowercase once per prediction instead of
        // once per keyword (previously recomputed inside the `any` closure),
        // and only when the verdict is unsafe.
        let actionable = is_unsafe && {
            let lowered = pred.explanation.to_lowercase();
            ACTIONABLE_KEYWORDS.iter().any(|kw| lowered.contains(kw))
        };
        if is_safe || actionable {
            score += 0.25;
        }
        total_score += score;
    }
    total_score / predictions.len() as f64
}
/// Renders an `EvalResult` as a human-readable plain-text report.
///
/// Each metric line shows the raw value, its weight as a percentage, and the
/// weighted contribution. The gap and MCC lines are printed only when the
/// corresponding optional fields are `Some`. The exact text layout is part of
/// the tool's output contract; keep format strings byte-stable.
pub fn format_eval_report(result: &EvalResult) -> String {
    use std::fmt::Write;
    let mut out = String::new();
    // writeln! into a String cannot fail in practice; results are discarded.
    let _ = writeln!(out, "ShellSafetyBench Evaluation Report");
    let _ = writeln!(out, "==================================");
    let _ = writeln!(out, "Total entries: {}", result.total);
    let _ = writeln!(out);
    let _ = writeln!(out, "Metrics (weighted):");
    let _ = writeln!(
        out,
        " Detection F1: {:.3} (x{:.0}% = {:.3})",
        result.detection_f1,
        DETECTION_F1_WEIGHT * 100.0,
        result.weighted_breakdown.detection_f1
    );
    let _ = writeln!(
        out,
        " Rule Citation: {:.3} (x{:.0}% = {:.3})",
        result.rule_citation,
        RULE_CITATION_WEIGHT * 100.0,
        result.weighted_breakdown.rule_citation
    );
    let _ = writeln!(
        out,
        " CWE Mapping: {:.3} (x{:.0}% = {:.3})",
        result.cwe_mapping,
        CWE_MAPPING_WEIGHT * 100.0,
        result.weighted_breakdown.cwe_mapping
    );
    let _ = writeln!(
        out,
        " Fix Validity: {:.3} (x{:.0}% = {:.3})",
        result.fix_validity,
        FIX_VALIDITY_WEIGHT * 100.0,
        result.weighted_breakdown.fix_validity
    );
    let _ = writeln!(
        out,
        " Explanation: {:.3} (x{:.0}% = {:.3})",
        result.explanation_quality,
        EXPLANATION_WEIGHT * 100.0,
        result.weighted_breakdown.explanation
    );
    let _ = writeln!(
        out,
        " OOD Generalize: {:.3} (x{:.0}% = {:.3})",
        result.ood_generalization,
        OOD_WEIGHT * 100.0,
        result.weighted_breakdown.ood
    );
    let _ = writeln!(out);
    let _ = writeln!(
        out,
        " COMPOSITE SCORE: {:.3} / 1.000",
        result.composite_score
    );
    // Optional comparison lines, printed only when populated by the caller.
    if let Some(gap) = result.static_dynamic_gap {
        let _ = writeln!(
            out,
            " Static-Dynamic Gap: {:.1}% (target: <15%)",
            gap * 100.0
        );
    }
    if let Some(mcc_diff) = result.model_mcc_vs_keyword {
        let _ = writeln!(out, " MCC vs Keyword: {:.3} (target: >0)", mcc_diff);
    }
    out
}
/// A single deserialized result record that bundles a model prediction with
/// its ground truth. Every field except `classification` defaults when
/// absent from the input.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct EvalPrediction {
    /// Entry id; when empty, a positional "SSB-NNNNN" id is synthesized.
    #[serde(default)]
    pub id: String,
    /// Model verdict string ("safe"/"unsafe"); the only required field.
    pub classification: String,
    /// Ground-truth label: 1 = unsafe, 0 = safe (default).
    #[serde(default)]
    pub label: u8,
    /// Rule codes cited by the model.
    #[serde(default)]
    pub cited_rules: Vec<String>,
    /// CWE identifiers cited by the model.
    #[serde(default)]
    pub cited_cwes: Vec<String>,
    /// Optional rewritten script proposed by the model.
    #[serde(default)]
    pub proposed_fix: Option<String>,
    /// Free-text rationale from the model.
    #[serde(default)]
    pub explanation: String,
    /// Original script text.
    #[serde(default)]
    pub script: String,
    /// Ground-truth rule codes for the entry.
    #[serde(default)]
    pub ground_truth_rules: Vec<String>,
    /// Ground-truth CWE identifiers for the entry.
    #[serde(default)]
    pub ground_truth_cwes: Vec<String>,
}
/// Flattened summary of an `EvalResult`: raw metric values plus the composite
/// (reported here as `weighted_score`), without the detailed detection report
/// or weighted breakdown.
#[derive(Debug, Clone, Serialize)]
pub struct SimpleEvalResult {
    pub detection_f1: f64,
    pub rule_citation: f64,
    pub cwe_mapping: f64,
    pub fix_validity: f64,
    pub explanation_quality: f64,
    pub ood_generalization: f64,
    /// Copy of `EvalResult::composite_score`.
    pub weighted_score: f64,
    /// Number of predictions evaluated.
    pub total: usize,
}
/// Splits self-contained `EvalPrediction` records into parallel `Prediction`
/// and `GroundTruth` vectors, runs the full evaluation, and flattens the
/// result into a `SimpleEvalResult`.
///
/// Records with an empty `id` receive a positional "SSB-NNNNN" id so the
/// prediction and its ground truth join on the same key.
pub fn evaluate_predictions(preds: &[EvalPrediction]) -> SimpleEvalResult {
    // Single source of truth for id synthesis (previously duplicated in two
    // separate passes over `preds`).
    fn resolve_id(index: usize, pred: &EvalPrediction) -> String {
        if pred.id.is_empty() {
            format!("SSB-{:05}", index)
        } else {
            pred.id.clone()
        }
    }

    // Build both vectors in one pass, preallocated to the known length.
    let mut predictions = Vec::with_capacity(preds.len());
    let mut ground_truth = Vec::with_capacity(preds.len());
    for (i, p) in preds.iter().enumerate() {
        let id = resolve_id(i, p);
        predictions.push(Prediction {
            id: id.clone(),
            classification: p.classification.clone(),
            cited_rules: p.cited_rules.clone(),
            cited_cwes: p.cited_cwes.clone(),
            proposed_fix: p.proposed_fix.clone(),
            explanation: p.explanation.clone(),
        });
        ground_truth.push(GroundTruth {
            id,
            label: p.label,
            rules: p.ground_truth_rules.clone(),
            cwes: p.ground_truth_cwes.clone(),
            script: p.script.clone(),
        });
    }

    let result = run_eval(&predictions, &ground_truth);
    SimpleEvalResult {
        detection_f1: result.detection_f1,
        rule_citation: result.rule_citation,
        cwe_mapping: result.cwe_mapping,
        fix_validity: result.fix_validity,
        explanation_quality: result.explanation_quality,
        ood_generalization: result.ood_generalization,
        weighted_score: result.composite_score,
        total: result.total,
    }
}
// Unit tests live in a sibling file, mounted here via the `path` attribute.
#[cfg(test)]
#[path = "eval_harness_tests_make_predict.rs"]
mod tests_extracted;