#[cfg(not(feature = "std"))]
use alloc::string::{String, ToString};
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::collections::{HashSet, new_set};
use crate::context::{Context, Value};
use crate::language::Language;
/// Negation words whose occurrence counts must agree between the source material and the
/// generated output.
///
/// Matching is exact against lowercased tokens, and these words are also excluded from
/// content-token precision scoring (see `is_polarity` / `is_content_token`).
const POLARITY_TOKENS: &[&str] = &[
    "not", "never", "no", "none",
    "cannot", "won't", "neither", "nor",
];
/// Function words that are never counted as content tokens when computing precision.
///
/// Entries are lowercase; they are compared against already-lowercased tokens.
const STOPWORDS: &[&str] = &[
    // articles and conjunctions
    "a", "an", "the", "and", "or", "but",
    // prepositions
    "in", "on", "at", "to", "for", "of", "with", "by", "from",
    // forms of be / have / do
    "is", "was", "are", "were", "be", "been", "being", "have", "has", "had", "do", "does",
    "did",
    // modal verbs
    "will", "would", "could", "should", "may", "might", "shall", "can",
    // pronouns and demonstratives
    "it", "its", "this", "that", "these", "those",
    // wh-words
    "which", "who", "what", "where", "when", "how",
    // other frequent function words
    "if", "then", "than", "so", "as", "up", "out", "into", "also", "just", "more", "most",
];
/// Outcome of comparing a generated output against its source material.
///
/// As produced by `score_faithfulness`:
/// * `precision` is the fraction of the output's content tokens found in the source
///   (`1.0` when the output has no content tokens);
/// * `polarity_match` is `true` exactly when `polarity_drift` is empty.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FaithfulnessScore {
    /// Share of output content tokens entailed by the source.
    pub precision: f32,
    /// Whether every polarity token occurs equally often in source and output.
    pub polarity_match: bool,
    /// Output content tokens (lowercased) not found in the source, in output order.
    pub unentailed: Vec<String>,
    /// One entry for each polarity token whose counts differ.
    pub polarity_drift: Vec<PolarityDrift>,
}

/// Occurrence-count mismatch for a single polarity (negation) token.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PolarityDrift {
    /// The polarity token, e.g. `"not"`.
    pub token: String,
    /// Number of occurrences among the source tokens.
    pub in_source: usize,
    /// Number of occurrences among the output (hypothesis) tokens.
    pub in_hypothesis: usize,
}

impl FaithfulnessScore {
    /// Strict check: polarity must match and precision must be at least `1.0`.
    ///
    /// Equivalent to `self.passes(1.0)`.
    pub fn is_faithful(&self) -> bool {
        self.passes(1.0)
    }

    /// Returns `true` when polarity matches and `precision >= threshold`.
    ///
    /// A polarity mismatch fails for every `threshold`; a NaN `threshold` never passes.
    pub fn passes(&self, threshold: f32) -> bool {
        if !self.polarity_match {
            return false;
        }
        self.precision >= threshold
    }
}
pub fn score_faithfulness(
output: &str,
context: &Context,
template_literals: &[&str],
language: &dyn Language,
) -> FaithfulnessScore {
let hyp_tokens = tokenize(output);
let src_tokens: Vec<String> = {
let mut v = tokens_from_context(context);
v.extend(tokens_from_literals(template_literals));
v
};
let mut polarity_drift = Vec::new();
let mut polarity_match = true;
for &tok in POLARITY_TOKENS {
let s = src_tokens
.iter()
.filter(|t| t.as_ref() as &str == tok)
.count();
let h = hyp_tokens
.iter()
.filter(|t| t.as_ref() as &str == tok)
.count();
if s != h {
polarity_match = false;
polarity_drift.push(PolarityDrift {
token: tok.to_string(),
in_source: s,
in_hypothesis: h,
});
}
}
let hyp_content: Vec<&String> = hyp_tokens.iter().filter(|t| is_content_token(t)).collect();
if hyp_content.is_empty() {
return FaithfulnessScore {
precision: 1.0,
polarity_match,
unentailed: Vec::new(),
polarity_drift,
};
}
let mut src_set: HashSet<String> = new_set();
for t in &src_tokens {
src_set.insert(t.clone());
src_set.insert(language.singularize(t));
}
let mut entailed: usize = 0;
let mut unentailed: Vec<String> = Vec::new();
for t in &hyp_content {
let t_sing = language.singularize(t);
if src_set.contains(t.as_ref() as &str) || src_set.contains(&t_sing) {
entailed += 1;
} else {
unentailed.push((*t).clone());
}
}
let precision = entailed as f32 / hyp_content.len() as f32;
FaithfulnessScore {
precision,
polarity_match,
unentailed,
polarity_drift,
}
}
/// Splits `text` on whitespace into lowercase tokens.
///
/// Leading and trailing punctuation is trimmed from each word: ASCII `, . : ; ! ? " ' ( ) [ ] -`,
/// en/em dashes, typographic quotes (`‘ ’ “ ”`) and the ellipsis character (`…`). Punctuation
/// inside a word is kept (`user-facing`, `won't`). A typographic apostrophe (U+2019) inside a
/// word is normalised to ASCII `'`, so `won’t` and `won't` produce the same token and polarity
/// words written with curly apostrophes are still recognised. Words consisting only of trimmed
/// punctuation are dropped.
fn tokenize(text: &str) -> Vec<String> {
    // Characters stripped from the edges of each whitespace-delimited word.
    fn is_edge_punctuation(c: char) -> bool {
        matches!(
            c,
            ',' | '.' | ':' | ';' | '!' | '?' | '"' | '\''
                | '(' | ')' | '[' | ']' | '-'
                | '\u{2013}' | '\u{2014}' // en dash, em dash
                | '\u{2018}' | '\u{2019}' | '\u{201C}' | '\u{201D}' // typographic quotes
                | '\u{2026}' // horizontal ellipsis
        )
    }

    text.split_whitespace()
        .map(|word| {
            word.trim_matches(is_edge_punctuation)
                .to_lowercase()
                .replace('\u{2019}', "'")
        })
        .filter(|tok| !tok.is_empty())
        .collect()
}
/// Returns `true` if `tok` is listed in `STOPWORDS` (exact, case-sensitive comparison).
fn is_stopword(tok: &str) -> bool {
    STOPWORDS.iter().any(|&word| word == tok)
}
/// Returns `true` if `tok` is one of the `POLARITY_TOKENS` (exact, case-sensitive comparison).
fn is_polarity(tok: &str) -> bool {
    POLARITY_TOKENS.iter().any(|candidate| *candidate == tok)
}
/// Returns `true` if `tok` is a non-empty run of ASCII digits (`0`-`9`) only.
///
/// Decimal points, signs, separators and non-ASCII digits all make this `false`.
fn is_numeric(tok: &str) -> bool {
    match tok {
        "" => false,
        // Every non-ASCII char encodes to bytes >= 0x80, so a byte scan is equivalent to a
        // char scan for the ASCII-digit test.
        digits => digits.bytes().all(|b| b.is_ascii_digit()),
    }
}
/// Returns `true` if `tok` should be counted when computing faithfulness precision.
///
/// A content token is at least three characters long and is not a stopword, a polarity token,
/// or a pure run of ASCII digits. Length is measured in `char`s (Unicode scalar values), not
/// UTF-8 bytes, so short accented words such as `él` (2 chars, 3 bytes) are not mistaken for
/// content.
fn is_content_token(tok: &str) -> bool {
    tok.chars().count() >= 3 && !is_stopword(tok) && !is_polarity(tok) && !is_numeric(tok)
}
/// Collects source tokens from every value in `ctx`, in the order `ctx.iter()` yields them.
///
/// Keys are ignored. Strings, list items and entity names are run through `tokenize`. Numbers
/// contribute their `to_string()` form as a single token, without further tokenization.
fn tokens_from_context(ctx: &Context) -> Vec<String> {
    let mut collected = Vec::new();
    for (_, value) in ctx.iter() {
        match value {
            Value::Number(num) => collected.push(num.to_string()),
            Value::String(text) => collected.extend(tokenize(text)),
            Value::Entity { name, .. } => collected.extend(tokenize(name)),
            Value::List(entries) => {
                for entry in entries {
                    collected.extend(tokenize(entry));
                }
            }
        }
    }
    collected
}
/// Tokenizes each template literal with `tokenize` and concatenates the tokens in order.
fn tokens_from_literals(literals: &[&str]) -> Vec<String> {
    literals.iter().flat_map(|&lit| tokenize(lit)).collect()
}
/// Panics with a diagnostic report unless an output is fully faithful to its sources.
///
/// `assert_faithful!(output, context, template_literals, language)` calls
/// `score_faithfulness(&output, &context, &template_literals, language)`. If the resulting
/// score is not `is_faithful()`, it panics with a message listing the precision, polarity
/// flag, unentailed tokens, polarity drift and the scored output. A trailing comma is
/// accepted.
///
/// Each argument expression is evaluated exactly once. An `output` expression with side
/// effects (e.g. rendering a template) is therefore not re-run when the panic message is
/// built, and the message always shows the text that was actually scored.
#[macro_export]
macro_rules! assert_faithful {
    ($output:expr, $context:expr, $template_literals:expr, $language:expr $(,)?) => {{
        // Bind once: `$output` is needed both for scoring and for the panic message.
        let output = &$output;
        let score = $crate::score_faithfulness(
            output,
            &$context,
            &$template_literals,
            $language,
        );
        if !score.is_faithful() {
            panic!(
                "faithfulness violation:\n precision: {:.3}\n polarity_match: {}\n unentailed: {:?}\n polarity_drift: {:?}\n output: {}\n",
                score.precision,
                score.polarity_match,
                score.unentailed,
                score.polarity_drift,
                output,
            );
        }
    }};
}
/// Unit tests for tokenization, token classification and the score predicates.
#[cfg(test)]
mod tests {
    use super::*;

    /// Drift record for a single `not` that appears in the hypothesis but not the source.
    fn extra_not() -> PolarityDrift {
        PolarityDrift {
            token: "not".into(),
            in_source: 0,
            in_hypothesis: 1,
        }
    }

    #[test]
    fn tokenize_strips_edge_punctuation() {
        let toks = tokenize("hello, world! (foo)");
        assert_eq!(toks, vec!["hello", "world", "foo"]);
    }

    #[test]
    fn tokenize_lowercases() {
        let toks = tokenize("UserService AccountService");
        assert_eq!(toks, vec!["userservice", "accountservice"]);
    }

    #[test]
    fn tokenize_preserves_inner_hyphens_and_apostrophes() {
        let toks = tokenize("user-facing won't");
        assert_eq!(toks, vec!["user-facing", "won't"]);
    }

    #[test]
    fn tokenize_empty_string_produces_no_tokens() {
        assert!(tokenize("").is_empty());
    }

    #[test]
    fn tokenize_punctuation_only_produces_no_tokens() {
        assert!(tokenize("... ,,, ???").is_empty());
    }

    #[test]
    fn tokenize_splits_on_any_whitespace() {
        let toks = tokenize("  alpha\tbeta\n\ngamma  ");
        assert_eq!(toks, vec!["alpha", "beta", "gamma"]);
    }

    #[test]
    fn content_token_respects_length_threshold() {
        assert!(!is_content_token("it"));
        assert!(is_content_token("foo"));
    }

    #[test]
    fn content_token_excludes_stopwords() {
        assert!(!is_content_token("the"));
        assert!(!is_content_token("and"));
        assert!(!is_content_token("was"));
    }

    #[test]
    fn content_token_excludes_polarity() {
        assert!(!is_content_token("not"));
        assert!(!is_content_token("never"));
        assert!(!is_content_token("nor"));
    }

    #[test]
    fn content_token_excludes_pure_digits() {
        assert!(!is_content_token("123"));
        assert!(!is_content_token("42"));
    }

    #[test]
    fn content_token_admits_alphanumeric_mixed() {
        assert!(is_content_token("a123"));
    }

    #[test]
    fn numeric_requires_nonempty_ascii_digits() {
        assert!(is_numeric("0"));
        assert!(is_numeric("2024"));
        assert!(!is_numeric(""));
        assert!(!is_numeric("3.5"));
        assert!(!is_numeric("12a"));
    }

    #[test]
    fn is_faithful_requires_both_precision_and_polarity() {
        let score_bad_polarity = FaithfulnessScore {
            precision: 1.0,
            polarity_match: false,
            unentailed: vec![],
            polarity_drift: vec![extra_not()],
        };
        assert!(!score_bad_polarity.is_faithful());

        let score_bad_precision = FaithfulnessScore {
            precision: 0.8,
            polarity_match: true,
            unentailed: vec!["extra".into()],
            polarity_drift: vec![],
        };
        assert!(!score_bad_precision.is_faithful());

        let score_good = FaithfulnessScore {
            precision: 1.0,
            polarity_match: true,
            unentailed: vec![],
            polarity_drift: vec![],
        };
        assert!(score_good.is_faithful());
    }

    #[test]
    fn passes_threshold_semantics() {
        let score = FaithfulnessScore {
            precision: 0.75,
            polarity_match: true,
            unentailed: vec!["extra".into()],
            polarity_drift: vec![],
        };
        assert!(score.passes(0.5), "0.75 >= 0.5 should pass");
        assert!(score.passes(0.75), "0.75 >= 0.75 should pass (boundary)");
        assert!(!score.passes(0.76), "0.75 < 0.76 should fail");

        let score_polarity_bad = FaithfulnessScore {
            precision: 1.0,
            polarity_match: false,
            unentailed: vec![],
            polarity_drift: vec![extra_not()],
        };
        assert!(
            !score_polarity_bad.passes(0.0),
            "polarity mismatch always fails"
        );
    }
}