use crate::claim::Claim;
use serde::{Deserialize, Serialize};
/// Result of scoring either a single claim ([`score_claim`]) or a whole
/// passage of claims ([`score_passage`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrustScore {
    /// Aggregate trust, clamped to `[0.0, 1.0]`; higher values map to lower
    /// risk buckets.
    pub score: f64,
    /// Signals behind `score`: per-claim signals for a single claim, or the
    /// per-field mean of the claim signals for a passage.
    pub signals: ScoreSignals,
    /// Risk bucket. Normally `classify_risk(score)`; the empty-passage result
    /// uses a fixed `RiskLevel::Medium` instead.
    pub risk_level: RiskLevel,
    /// Human-readable summary of why this score was assigned.
    pub explanation: String,
}
/// The individual heuristic signals that feed the aggregate trust score.
///
/// For a passage, each field is the mean of the corresponding claim signals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreSignals {
    /// Confidence-calibration signal: hedged claims score highest, while
    /// specific, verifiable, unhedged claims score lowest (treated as
    /// overconfident).
    pub confidence: f64,
    /// Copied from the claim's own `specificity` value.
    pub specificity: f64,
    /// 0.85 for a hedged claim, 0.5 otherwise.
    pub hedging: f64,
    /// 0.7 for a verifiable claim, 0.5 otherwise.
    pub verifiability: f64,
    /// Optional extra signal; `aggregate_score` treats `None` as a neutral
    /// 0.5. Every constructor in this module leaves it `None`.
    pub consistency: Option<f64>,
}
/// Risk bucket for a trust score, declared from least to most risky.
///
/// NOTE: keep the variant order stable; serde formats that encode variant
/// indices depend on it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RiskLevel {
    /// `classify_risk` returns this for scores >= 0.75.
    Low,
    /// `classify_risk` returns this for scores in [0.55, 0.75).
    Medium,
    /// `classify_risk` returns this for scores in [0.35, 0.55).
    High,
    /// `classify_risk` returns this for scores below 0.35 (and for NaN).
    Critical,
}
impl std::fmt::Display for RiskLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RiskLevel::Low => write!(f, "LOW"),
RiskLevel::Medium => write!(f, "MEDIUM"),
RiskLevel::High => write!(f, "HIGH"),
RiskLevel::Critical => write!(f, "CRITICAL"),
}
}
}
/// Scores a single claim.
///
/// The claim's heuristic signals are computed first and aggregated into a
/// trust score in `[0.0, 1.0]`. That score is then bucketed into a
/// [`RiskLevel`], and a human-readable explanation is attached.
pub fn score_claim(claim: &Claim) -> TrustScore {
    let signals = compute_signals(claim);
    let score = aggregate_score(&signals);
    let risk_level = classify_risk(score);
    TrustScore {
        explanation: explain_score(claim, &signals, &risk_level),
        risk_level,
        score,
        signals,
    }
}
/// Scores a passage made of several claims and summarizes the result.
///
/// Each claim is scored with [`score_claim`]. The passage score blends the
/// mean claim score (70%) with the weakest claim score (30%), so a single
/// low-trust claim drags the passage down, and is then clamped to
/// `[0.0, 1.0]`. The reported signals are the per-field means of the claim
/// signals, with `consistency` left as `None`.
///
/// An empty slice yields a neutral result: score 0.5, every signal 0.5, and
/// `RiskLevel::Medium`. That risk level is fixed rather than classified;
/// `classify_risk(0.5)` would give `RiskLevel::High`.
pub fn score_passage(claims: &[Claim]) -> TrustScore {
    if claims.is_empty() {
        let neutral = ScoreSignals {
            confidence: 0.5,
            specificity: 0.5,
            hedging: 0.5,
            verifiability: 0.5,
            consistency: None,
        };
        return TrustScore {
            score: 0.5,
            signals: neutral,
            risk_level: RiskLevel::Medium,
            explanation: String::from("No claims to analyze."),
        };
    }

    /// Arithmetic mean of one numeric field over a non-empty set of results.
    fn mean_of(scores: &[TrustScore], field: impl Fn(&TrustScore) -> f64) -> f64 {
        scores.iter().map(field).sum::<f64>() / scores.len() as f64
    }

    let per_claim: Vec<TrustScore> = claims.iter().map(score_claim).collect();

    let weakest = per_claim
        .iter()
        .map(|t| t.score)
        .fold(f64::INFINITY, f64::min);
    let blended = (0.7 * mean_of(&per_claim, |t| t.score) + 0.3 * weakest).clamp(0.0, 1.0);

    let signals = ScoreSignals {
        confidence: mean_of(&per_claim, |t| t.signals.confidence),
        specificity: mean_of(&per_claim, |t| t.signals.specificity),
        hedging: mean_of(&per_claim, |t| t.signals.hedging),
        verifiability: mean_of(&per_claim, |t| t.signals.verifiability),
        consistency: None,
    };

    // "High-risk" in the summary counts both High and Critical claims.
    let high_risk = per_claim
        .iter()
        .filter(|t| matches!(t.risk_level, RiskLevel::High | RiskLevel::Critical))
        .count();

    TrustScore {
        score: blended,
        signals,
        risk_level: classify_risk(blended),
        // NOTE(review): the text says "Average trust" but reports the blended
        // (mean/weakest) score, not the plain mean.
        explanation: format!(
            "{} claims analyzed. {} high-risk claims detected. Average trust: {:.0}%.",
            claims.len(),
            high_risk,
            blended * 100.0
        ),
    }
}
/// Derives the heuristic [`ScoreSignals`] for one claim.
///
/// The confidence signal rewards hedging. It penalizes statements that are
/// specific (> 0.5), verifiable and unhedged, because those are treated as
/// overconfident. Hedging and verifiability are fixed per-flag values,
/// specificity is passed through unchanged, and `consistency` is left unset.
fn compute_signals(claim: &Claim) -> ScoreSignals {
    let confidence = match (claim.is_hedged, claim.is_verifiable) {
        (true, _) => 0.8,
        (false, true) if claim.specificity > 0.5 => 0.4,
        _ if claim.specificity > 0.3 => 0.6,
        _ => 0.5,
    };

    ScoreSignals {
        confidence,
        specificity: claim.specificity,
        hedging: if claim.is_hedged { 0.85 } else { 0.5 },
        verifiability: if claim.is_verifiable { 0.7 } else { 0.5 },
        consistency: None,
    }
}
// Signal weights used by `aggregate_score`; together they sum to 1.0.
/// Weight of the confidence-calibration signal.
const W_CONFIDENCE: f64 = 0.35;
/// Weight of the hedging signal.
const W_HEDGING: f64 = 0.25;
/// Weight of the specificity signal.
const W_SPECIFICITY: f64 = 0.20;
/// Weight of the verifiability signal.
const W_VERIFIABILITY: f64 = 0.15;
/// Weight of the optional consistency signal (neutral 0.5 when absent).
const W_CONSISTENCY: f64 = 0.05;

/// Combines the signals into one trust score as a weighted sum, clamped to
/// `[0.0, 1.0]`.
///
/// A missing `consistency` signal contributes as if it were 0.5.
fn aggregate_score(signals: &ScoreSignals) -> f64 {
    let consistency = signals.consistency.unwrap_or(0.5);
    let weighted_sum = W_CONFIDENCE * signals.confidence
        + W_HEDGING * signals.hedging
        + W_SPECIFICITY * signals.specificity
        + W_VERIFIABILITY * signals.verifiability
        + W_CONSISTENCY * consistency;
    weighted_sum.clamp(0.0, 1.0)
}
/// Public entry point for [`classify_risk`], for callers outside this module.
pub fn classify_risk_pub(score: f64) -> RiskLevel {
    self::classify_risk(score)
}

/// Maps a trust score onto a [`RiskLevel`].
///
/// Each threshold is an inclusive lower bound: >= 0.75 is `Low`, >= 0.55 is
/// `Medium`, >= 0.35 is `High`, and everything else is `Critical`. That last
/// case includes NaN, which fails every comparison.
fn classify_risk(score: f64) -> RiskLevel {
    match score {
        s if s >= 0.75 => RiskLevel::Low,
        s if s >= 0.55 => RiskLevel::Medium,
        s if s >= 0.35 => RiskLevel::High,
        _ => RiskLevel::Critical,
    }
}
fn explain_score(claim: &Claim, signals: &ScoreSignals, risk: &RiskLevel) -> String {
let mut reasons = Vec::new();
if signals.confidence < 0.5 {
reasons.push("overconfident language without hedging");
}
if claim.is_hedged {
reasons.push("appropriately hedged");
}
if claim.is_verifiable && !claim.is_hedged {
reasons.push("specific verifiable claim — verify independently");
}
if claim.specificity < 0.3 {
reasons.push("vague claim with low specificity");
}
if claim.specificity > 0.6 {
reasons.push("highly specific claim");
}
if reasons.is_empty() {
reasons.push("no strong signals detected");
}
format!(
"[{risk}] Trust: {:.0}% — {}",
signals.confidence * 100.0,
reasons.join("; ")
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claim::extract_claims;

    /// Builds a claim directly; scoring in this module never reads `text`.
    fn make_claim(text: &str, is_verifiable: bool, specificity: f64, is_hedged: bool) -> Claim {
        Claim {
            text: text.to_string(),
            sentence_idx: 0,
            is_verifiable,
            specificity,
            is_hedged,
        }
    }

    /// A hedged, vague-ish claim.
    fn hedged_claim() -> Claim {
        make_claim("This might be related to the discovery.", false, 0.3, true)
    }

    /// A specific, verifiable, unhedged (overconfident) claim.
    fn confident_claim() -> Claim {
        make_claim(
            "Einstein discovered exactly 47 particles in 1903.",
            true,
            0.8,
            false,
        )
    }

    #[test]
    fn score_is_bounded() {
        let claims = extract_claims("Einstein was born in 1879. The sky might be purple.");
        for claim in &claims {
            let score = score_claim(claim);
            assert!(
                score.score >= 0.0 && score.score <= 1.0,
                "Score out of bounds: {}",
                score.score
            );
        }
    }

    #[test]
    fn hedged_claims_score_higher() {
        let hedged_score = score_claim(&hedged_claim());
        let confident_score = score_claim(&confident_claim());
        assert!(
            hedged_score.score > confident_score.score,
            "Hedged {:.3} should score higher than overconfident {:.3}",
            hedged_score.score,
            confident_score.score
        );
    }

    #[test]
    fn overconfident_claim_gets_low_confidence_signal() {
        let signals = compute_signals(&confident_claim());
        assert_eq!(signals.confidence, 0.4);
        assert_eq!(signals.hedging, 0.5);
        assert_eq!(signals.verifiability, 0.7);
        assert!(signals.consistency.is_none());
    }

    #[test]
    fn claim_explanation_mentions_risk_and_reasons() {
        let score = score_claim(&confident_claim());
        assert!(score
            .explanation
            .starts_with(&format!("[{}] Trust: ", score.risk_level)));
        assert!(score.explanation.contains("verify independently"));
    }

    #[test]
    fn passage_scoring() {
        let claims = extract_claims(
            "Albert Einstein was born in 1879. He might have visited Paris. \
             The theory of relativity was published in exactly 1905.",
        );
        let passage = score_passage(&claims);
        assert!(passage.score >= 0.0 && passage.score <= 1.0);
        assert!(!passage.explanation.is_empty());
    }

    #[test]
    fn passage_score_blends_mean_and_weakest() {
        let claims = [hedged_claim(), confident_claim()];
        let per_claim: Vec<f64> = claims.iter().map(|c| score_claim(c).score).collect();
        let mean = per_claim.iter().sum::<f64>() / per_claim.len() as f64;
        let weakest = per_claim.iter().copied().fold(f64::INFINITY, f64::min);

        let passage = score_passage(&claims);
        assert!((passage.score - (0.7 * mean + 0.3 * weakest)).abs() < 1e-12);
        assert!(passage.score >= weakest && passage.score <= mean);
        assert_eq!(passage.risk_level, classify_risk(passage.score));
        assert!(passage.explanation.starts_with("2 claims analyzed. "));
    }

    #[test]
    fn empty_passage() {
        let passage = score_passage(&[]);
        assert_eq!(passage.risk_level, RiskLevel::Medium);
        assert_eq!(passage.score, 0.5);
        assert!(passage.signals.consistency.is_none());
        assert_eq!(passage.explanation, "No claims to analyze.");
    }

    #[test]
    fn risk_classification() {
        assert_eq!(classify_risk(0.8), RiskLevel::Low);
        assert_eq!(classify_risk(0.6), RiskLevel::Medium);
        assert_eq!(classify_risk(0.4), RiskLevel::High);
        assert_eq!(classify_risk(0.2), RiskLevel::Critical);
    }

    #[test]
    fn risk_classification_boundaries() {
        // Each threshold is an inclusive lower bound.
        assert_eq!(classify_risk(0.75), RiskLevel::Low);
        assert_eq!(classify_risk(0.55), RiskLevel::Medium);
        assert_eq!(classify_risk(0.35), RiskLevel::High);
        assert_eq!(classify_risk(0.0), RiskLevel::Critical);
        assert_eq!(classify_risk(f64::NAN), RiskLevel::Critical);
        for score in [0.1, 0.4, 0.6, 0.9] {
            assert_eq!(classify_risk_pub(score), classify_risk(score));
        }
    }

    #[test]
    fn risk_level_display() {
        assert_eq!(RiskLevel::Low.to_string(), "LOW");
        assert_eq!(RiskLevel::Medium.to_string(), "MEDIUM");
        assert_eq!(RiskLevel::High.to_string(), "HIGH");
        assert_eq!(RiskLevel::Critical.to_string(), "CRITICAL");
    }
}