use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use crate::types::{ClaimStructure, LogicalClaim};
/// Strategy for bringing claims into a canonical form.
///
/// Implementors must be `Send + Sync` so a normalizer can be shared across
/// threads (for example behind a `Box<dyn ClaimNormalizer>`).
pub trait ClaimNormalizer: Send + Sync {
    /// Produces a normalized version of `claim`.
    ///
    /// The input is only borrowed; a new [`LogicalClaim`] is returned.
    fn normalize(&self, claim: &LogicalClaim) -> LogicalClaim;

    /// Produces a normalized, owned copy of a single piece of text.
    fn normalize_text(&self, text: &str) -> String;
}
/// Configurable text normalizer driven by three independent switches.
///
/// Each switch enables one normalization step; the builder-style `with_*`
/// methods toggle them individually.
#[derive(Debug, Clone)]
pub struct DefaultClaimNormalizer {
    /// Convert text to lowercase.
    lowercase: bool,
    /// Collapse every run of whitespace into a single space.
    normalize_whitespace: bool,
    /// Strip ASCII punctuation characters.
    remove_punctuation: bool,
}
impl Default for DefaultClaimNormalizer {
    /// Builds the default configuration: lowercase and whitespace
    /// normalization enabled, punctuation removal disabled.
    fn default() -> Self {
        let lowercase = true;
        let normalize_whitespace = true;
        // Punctuation stripping is opt-in.
        let remove_punctuation = false;

        Self {
            lowercase,
            normalize_whitespace,
            remove_punctuation,
        }
    }
}
impl DefaultClaimNormalizer {
    /// Creates a normalizer with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the normalizer with lowercasing switched on or off.
    #[must_use]
    pub fn with_lowercase(self, lowercase: bool) -> Self {
        Self { lowercase, ..self }
    }

    /// Returns the normalizer with whitespace normalization switched on or off.
    #[must_use]
    pub fn with_normalize_whitespace(self, normalize: bool) -> Self {
        Self {
            normalize_whitespace: normalize,
            ..self
        }
    }

    /// Returns the normalizer with punctuation removal switched on or off.
    #[must_use]
    pub fn with_remove_punctuation(self, remove: bool) -> Self {
        Self {
            remove_punctuation: remove,
            ..self
        }
    }

    /// Normalizes a nested structure and boxes the result.
    fn normalize_boxed(&self, inner: &ClaimStructure) -> Box<ClaimStructure> {
        Box::new(self.normalize_structure(inner))
    }

    /// Lazily normalizes every element of `parts`, preserving their order.
    fn normalize_each<'a>(
        &'a self,
        parts: &'a [ClaimStructure],
    ) -> impl Iterator<Item = ClaimStructure> + 'a {
        parts.iter().map(move |part| self.normalize_structure(part))
    }

    /// Rebuilds `structure` with every textual field passed through
    /// `normalize_text` and every nested structure normalized recursively.
    ///
    /// Non-textual fields (operators, quantifiers, time relations, strengths,
    /// modalities) are copied or cloned unchanged.
    fn normalize_structure(&self, structure: &ClaimStructure) -> ClaimStructure {
        match structure {
            // Leaf variants: only text fields need normalizing.
            ClaimStructure::Raw(text) => ClaimStructure::Raw(self.normalize_text(text)),
            ClaimStructure::Predicate {
                subject,
                predicate,
                object,
            } => {
                let object = object.as_deref().map(|text| self.normalize_text(text));
                ClaimStructure::Predicate {
                    subject: self.normalize_text(subject),
                    predicate: self.normalize_text(predicate),
                    object,
                }
            }
            ClaimStructure::Comparison {
                left,
                operator,
                right,
            } => {
                let left = self.normalize_text(left);
                let right = self.normalize_text(right);
                ClaimStructure::Comparison {
                    left,
                    operator: *operator,
                    right,
                }
            }
            ClaimStructure::Temporal {
                event,
                time_relation,
                reference,
            } => {
                let event = self.normalize_text(event);
                let reference = self.normalize_text(reference);
                ClaimStructure::Temporal {
                    event,
                    time_relation: time_relation.clone(),
                    reference,
                }
            }

            // Variadic connectives: normalize each operand in place.
            ClaimStructure::And(claims) => ClaimStructure::And(self.normalize_each(claims).collect()),
            ClaimStructure::Or(claims) => ClaimStructure::Or(self.normalize_each(claims).collect()),

            // Wrappers around nested structures.
            ClaimStructure::Not(inner) => ClaimStructure::Not(self.normalize_boxed(inner)),
            ClaimStructure::Implies {
                premise,
                conclusion,
            } => {
                let premise = self.normalize_boxed(premise);
                let conclusion = self.normalize_boxed(conclusion);
                ClaimStructure::Implies {
                    premise,
                    conclusion,
                }
            }
            ClaimStructure::Quantified {
                quantifier,
                variable,
                domain,
                body,
            } => {
                let variable = self.normalize_text(variable);
                let domain = self.normalize_text(domain);
                ClaimStructure::Quantified {
                    quantifier: *quantifier,
                    variable,
                    domain,
                    body: self.normalize_boxed(body),
                }
            }
            ClaimStructure::Causal {
                cause,
                effect,
                strength,
            } => {
                let cause = self.normalize_boxed(cause);
                let effect = self.normalize_boxed(effect);
                ClaimStructure::Causal {
                    cause,
                    effect,
                    strength: strength.clone(),
                }
            }
            ClaimStructure::Modal { claim, modality } => {
                let claim = self.normalize_boxed(claim);
                ClaimStructure::Modal {
                    claim,
                    modality: modality.clone(),
                }
            }
        }
    }
}
impl ClaimNormalizer for DefaultClaimNormalizer {
    /// Returns a copy of `claim` whose `text` and `structure` are normalized.
    ///
    /// `id`, `confidence` and `source_span` are carried over unchanged.
    fn normalize(&self, claim: &LogicalClaim) -> LogicalClaim {
        LogicalClaim {
            id: claim.id.clone(),
            text: self.normalize_text(&claim.text),
            structure: self.normalize_structure(&claim.structure),
            confidence: claim.confidence,
            source_span: claim.source_span,
        }
    }

    /// Normalizes `text` according to the enabled switches.
    ///
    /// Steps, in order:
    /// 1. lowercase the text (if `lowercase`);
    /// 2. drop ASCII punctuation (if `remove_punctuation`); non-ASCII
    ///    punctuation is left untouched;
    /// 3. collapse whitespace runs to single spaces (if
    ///    `normalize_whitespace`), otherwise just trim both ends.
    ///
    /// The result never has leading or trailing whitespace.
    fn normalize_text(&self, text: &str) -> String {
        let mut result = if self.lowercase {
            text.to_lowercase()
        } else {
            text.to_owned()
        };

        // Punctuation must go *before* whitespace handling: removing it
        // afterwards can leave doubled or dangling spaces ("a , b" -> "a  b",
        // "a !" -> "a "), which makes the output non-idempotent and lets
        // equivalent claims normalize to different strings.
        if self.remove_punctuation {
            result.retain(|c| !c.is_ascii_punctuation());
        }

        if self.normalize_whitespace {
            // `split_whitespace` also discards leading/trailing whitespace.
            result.split_whitespace().collect::<Vec<_>>().join(" ")
        } else {
            result.trim().to_owned()
        }
    }
}
/// Collapses claims whose normalized structures hash to the same value.
pub struct ClaimDeduplicator {
    /// Normalizer applied to every claim before its structure is hashed.
    normalizer: Box<dyn ClaimNormalizer>,
    /// When `true`, the surviving claim of a merged group gets a confidence
    /// derived from all members instead of keeping its own.
    merge_confidence: bool,
}
impl Default for ClaimDeduplicator {
fn default() -> Self {
Self::new(Box::new(DefaultClaimNormalizer::default()))
}
}
impl ClaimDeduplicator {
    /// Creates a deduplicator that uses `normalizer`; confidence merging is
    /// enabled by default.
    #[must_use]
    pub fn new(normalizer: Box<dyn ClaimNormalizer>) -> Self {
        Self {
            normalizer,
            merge_confidence: true,
        }
    }

    /// Returns the deduplicator with confidence merging switched on or off.
    #[must_use]
    pub fn with_merge_confidence(mut self, merge: bool) -> Self {
        self.merge_confidence = merge;
        self
    }

    /// Computes a 64-bit fingerprint of `structure`.
    ///
    /// The value comes from `DefaultHasher`, whose algorithm is unspecified,
    /// so fingerprints must not be persisted or compared across builds.
    fn hash_structure(structure: &ClaimStructure) -> u64 {
        let mut hasher = DefaultHasher::new();
        Self::hash_structure_recursive(structure, &mut hasher);
        hasher.finish()
    }

    /// Feeds `structure` into `hasher`, depth first.
    ///
    /// Every variant writes a distinct tag before its fields. Field types that
    /// are not hashed directly are fed through their `Debug` rendering.
    fn hash_structure_recursive<H: Hasher>(structure: &ClaimStructure, hasher: &mut H) {
        match structure {
            ClaimStructure::Predicate {
                subject,
                predicate,
                object,
            } => {
                "predicate".hash(hasher);
                subject.hash(hasher);
                predicate.hash(hasher);
                object.hash(hasher);
            }
            ClaimStructure::Comparison {
                left,
                operator,
                right,
            } => {
                "comparison".hash(hasher);
                left.hash(hasher);
                format!("{operator:?}").hash(hasher);
                right.hash(hasher);
            }
            ClaimStructure::And(claims) => {
                "and".hash(hasher);
                // The child count delimits this variadic node. Without it,
                // `And([And([a]), b])` and `And([And([a, b])])` would feed
                // identical byte streams and be treated as duplicates.
                claims.len().hash(hasher);
                for claim in claims {
                    Self::hash_structure_recursive(claim, hasher);
                }
            }
            ClaimStructure::Or(claims) => {
                "or".hash(hasher);
                // Same delimiting rationale as for `And`.
                claims.len().hash(hasher);
                for claim in claims {
                    Self::hash_structure_recursive(claim, hasher);
                }
            }
            ClaimStructure::Not(inner) => {
                "not".hash(hasher);
                Self::hash_structure_recursive(inner, hasher);
            }
            ClaimStructure::Implies {
                premise,
                conclusion,
            } => {
                "implies".hash(hasher);
                Self::hash_structure_recursive(premise, hasher);
                Self::hash_structure_recursive(conclusion, hasher);
            }
            ClaimStructure::Quantified {
                quantifier,
                variable,
                domain,
                body,
            } => {
                "quantified".hash(hasher);
                format!("{quantifier:?}").hash(hasher);
                variable.hash(hasher);
                domain.hash(hasher);
                Self::hash_structure_recursive(body, hasher);
            }
            ClaimStructure::Temporal {
                event,
                time_relation,
                reference,
            } => {
                "temporal".hash(hasher);
                event.hash(hasher);
                format!("{time_relation:?}").hash(hasher);
                reference.hash(hasher);
            }
            ClaimStructure::Causal {
                cause,
                effect,
                strength,
            } => {
                "causal".hash(hasher);
                Self::hash_structure_recursive(cause, hasher);
                Self::hash_structure_recursive(effect, hasher);
                format!("{strength:?}").hash(hasher);
            }
            ClaimStructure::Modal { claim, modality } => {
                "modal".hash(hasher);
                Self::hash_structure_recursive(claim, hasher);
                format!("{modality:?}").hash(hasher);
            }
            ClaimStructure::Raw(text) => {
                "raw".hash(hasher);
                text.hash(hasher);
            }
        }
    }

    /// Normalizes every claim and merges those whose normalized structures
    /// share a fingerprint.
    ///
    /// The returned claims are the *normalized* forms. Each group is
    /// represented by its first member, and groups are returned in the order
    /// their first member appeared in `claims`, so the output is
    /// deterministic.
    ///
    /// Grouping relies on the 64-bit fingerprint alone; no structural
    /// equality check is performed afterwards.
    #[must_use]
    pub fn deduplicate(&self, claims: Vec<LogicalClaim>) -> Vec<LogicalClaim> {
        // Maps a fingerprint to the index of its group in `groups`, so that
        // `groups` itself preserves first-seen order.
        let mut slot_by_hash: HashMap<u64, usize> = HashMap::new();
        let mut groups: Vec<Vec<LogicalClaim>> = Vec::new();

        for claim in claims {
            let normalized = self.normalizer.normalize(&claim);
            let hash = Self::hash_structure(&normalized.structure);
            let slot = *slot_by_hash.entry(hash).or_insert_with(|| {
                groups.push(Vec::new());
                groups.len() - 1
            });
            groups[slot].push(normalized);
        }

        groups
            .into_iter()
            .map(|group| self.merge_group(group))
            .collect()
    }

    /// Collapses a group of duplicate claims into its first member.
    ///
    /// With confidence merging enabled and more than one member, the
    /// survivor's confidence becomes the group's mean confidence multiplied
    /// by 1.1 and capped at 1.0.
    ///
    /// # Panics
    ///
    /// Panics if `group` is empty; `deduplicate` never produces empty groups.
    fn merge_group(&self, group: Vec<LogicalClaim>) -> LogicalClaim {
        let member_count = group.len();
        let mut members = group.into_iter();
        let mut merged = members
            .next()
            .expect("merge_group requires a non-empty group");

        if self.merge_confidence && member_count > 1 {
            let total_confidence: f32 =
                merged.confidence + members.map(|c| c.confidence).sum::<f32>();
            // Group sizes are far below the range where `usize -> f32`
            // rounding would matter for an average.
            #[allow(clippy::cast_precision_loss)]
            let count = member_count as f32;
            let mean = total_confidence / count;
            merged.confidence = (mean * 1.1).min(1.0);
        }
        merged
    }

    /// Groups clones of the *original* (un-normalized) claims by the
    /// fingerprint of their normalized structure.
    ///
    /// Map keys are the fingerprints; within each group, claims keep their
    /// input order.
    #[must_use]
    pub fn find_groups(&self, claims: &[LogicalClaim]) -> HashMap<u64, Vec<LogicalClaim>> {
        let mut groups: HashMap<u64, Vec<LogicalClaim>> = HashMap::new();
        for claim in claims {
            let normalized = self.normalizer.normalize(claim);
            let hash = Self::hash_structure(&normalized.structure);
            groups.entry(hash).or_default().push(claim.clone());
        }
        groups
    }
}
// NOTE(review): this module is compiled out by `#[cfg(disabled)]`; switch the
// gate to `#[cfg(test)]` to run it.
#[cfg(disabled)]
#[allow(clippy::similar_names)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_text_lowercase() {
        let normalizer = DefaultClaimNormalizer::new();
        assert_eq!(normalizer.normalize_text("Hello World"), "hello world");
    }

    #[test]
    fn test_normalize_text_whitespace() {
        let normalizer = DefaultClaimNormalizer::new();
        assert_eq!(
            normalizer.normalize_text(" hello world "),
            "hello world"
        );
    }

    #[test]
    fn test_normalize_text_combined() {
        let normalizer = DefaultClaimNormalizer::new()
            .with_lowercase(true)
            .with_normalize_whitespace(true);
        assert_eq!(
            normalizer.normalize_text(" Hello WORLD "),
            "hello world"
        );
    }

    #[test]
    fn test_normalize_text_no_lowercase() {
        let normalizer = DefaultClaimNormalizer::new().with_lowercase(false);
        assert_eq!(normalizer.normalize_text("Hello World"), "Hello World");
    }

    #[test]
    fn test_normalize_text_remove_punctuation() {
        let normalizer = DefaultClaimNormalizer::new().with_remove_punctuation(true);
        assert_eq!(normalizer.normalize_text("hello, world!"), "hello world");
    }

    #[test]
    fn test_normalize_claim_predicate() {
        let normalizer = DefaultClaimNormalizer::new();
        let claim = LogicalClaim::new(
            " The SKY is BLUE ",
            ClaimStructure::Predicate {
                subject: " SKY ".to_string(),
                predicate: "IS".to_string(),
                object: Some("BLUE".to_string()),
            },
        );
        let normalized = normalizer.normalize(&claim);
        assert_eq!(normalized.text, "the sky is blue");
        if let ClaimStructure::Predicate {
            subject,
            predicate,
            object,
        } = &normalized.structure
        {
            assert_eq!(subject, "sky");
            assert_eq!(predicate, "is");
            assert_eq!(object.as_deref(), Some("blue"));
        } else {
            panic!("Expected Predicate structure");
        }
    }

    #[test]
    fn test_normalize_claim_raw() {
        let normalizer = DefaultClaimNormalizer::new();
        let claim = LogicalClaim::new("TEST", ClaimStructure::Raw(" HELLO WORLD ".to_string()));
        let normalized = normalizer.normalize(&claim);
        if let ClaimStructure::Raw(text) = &normalized.structure {
            assert_eq!(text, "hello world");
        } else {
            panic!("Expected Raw structure");
        }
    }

    #[test]
    fn test_deduplicator_identical_claims() {
        let dedup = ClaimDeduplicator::default();
        let claims = vec![
            LogicalClaim::new(
                "sky is blue",
                ClaimStructure::Raw("sky is blue".to_string()),
            )
            .with_confidence(0.8),
            LogicalClaim::new(
                "SKY IS BLUE",
                ClaimStructure::Raw("SKY IS BLUE".to_string()),
            )
            .with_confidence(0.6),
        ];
        let deduped = dedup.deduplicate(claims);
        assert_eq!(deduped.len(), 1);
        assert!(deduped[0].confidence > 0.7);
    }

    #[test]
    fn test_deduplicator_different_claims() {
        let dedup = ClaimDeduplicator::default();
        let claims = vec![
            LogicalClaim::new(
                "sky is blue",
                ClaimStructure::Raw("sky is blue".to_string()),
            ),
            LogicalClaim::new(
                "grass is green",
                ClaimStructure::Raw("grass is green".to_string()),
            ),
        ];
        let deduped = dedup.deduplicate(claims);
        assert_eq!(deduped.len(), 2);
    }

    #[test]
    fn test_deduplicator_no_merge_confidence() {
        let dedup = ClaimDeduplicator::default().with_merge_confidence(false);
        let claims = vec![
            LogicalClaim::new("test", ClaimStructure::Raw("test".to_string())).with_confidence(0.8),
            LogicalClaim::new("TEST", ClaimStructure::Raw("TEST".to_string())).with_confidence(0.6),
        ];
        let deduped = dedup.deduplicate(claims);
        assert_eq!(deduped.len(), 1);
        assert!((deduped[0].confidence - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_find_groups() {
        let dedup = ClaimDeduplicator::default();
        let claims = vec![
            LogicalClaim::new("a", ClaimStructure::Raw("test".to_string())),
            LogicalClaim::new("b", ClaimStructure::Raw("TEST".to_string())),
            LogicalClaim::new("c", ClaimStructure::Raw("other".to_string())),
        ];
        let groups = dedup.find_groups(&claims);
        assert_eq!(groups.len(), 2);
        assert!(groups.values().any(|g| g.len() == 2));
    }

    #[test]
    fn test_hash_structure_predicate() {
        let s1 = ClaimStructure::Predicate {
            subject: "sky".to_string(),
            predicate: "is".to_string(),
            object: Some("blue".to_string()),
        };
        let s2 = ClaimStructure::Predicate {
            subject: "sky".to_string(),
            predicate: "is".to_string(),
            object: Some("blue".to_string()),
        };
        let s3 = ClaimStructure::Predicate {
            subject: "grass".to_string(),
            predicate: "is".to_string(),
            object: Some("green".to_string()),
        };
        assert_eq!(
            ClaimDeduplicator::hash_structure(&s1),
            ClaimDeduplicator::hash_structure(&s2)
        );
        assert_ne!(
            ClaimDeduplicator::hash_structure(&s1),
            ClaimDeduplicator::hash_structure(&s3)
        );
    }

    #[test]
    fn test_normalize_nested_structure() {
        let normalizer = DefaultClaimNormalizer::new();
        let claim = LogicalClaim::new(
            "test",
            ClaimStructure::And(vec![
                ClaimStructure::Raw(" HELLO ".to_string()),
                ClaimStructure::Raw(" WORLD ".to_string()),
            ]),
        );
        let normalized = normalizer.normalize(&claim);
        if let ClaimStructure::And(parts) = &normalized.structure {
            if let ClaimStructure::Raw(text) = &parts[0] {
                assert_eq!(text, "hello");
            }
            if let ClaimStructure::Raw(text) = &parts[1] {
                assert_eq!(text, "world");
            }
        } else {
            panic!("Expected And structure");
        }
    }

    // Property-based tests; they additionally require the `proptest` crate.
    #[cfg(disabled)]
    #[allow(dead_code)]
    mod proptest_tests {
        #[allow(unused_imports)]
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // `#[test]` is required for the generated function to be
            // registered with the test harness.
            #[test]
            fn normalization_is_idempotent(
                text in "[A-Za-z ]{5,50}"
            ) {
                let normalizer = DefaultClaimNormalizer::new();
                let norm1 = normalizer.normalize_text(&text);
                let norm2 = normalizer.normalize_text(&norm1);
                prop_assert_eq!(norm1, norm2,
                    "Normalization should be idempotent: normalize(normalize(x)) == normalize(x)");
            }

            #[test]
            fn normalization_is_consistent(
                text in "[A-Za-z ]{5,50}"
            ) {
                let normalizer = DefaultClaimNormalizer::new();
                let norm1 = normalizer.normalize_text(&text);
                let norm2 = normalizer.normalize_text(&text);
                let norm3 = normalizer.normalize_text(&text);
                prop_assert_eq!(norm1, norm2,
                    "First and second normalization should be identical");
                prop_assert_eq!(norm2, norm3,
                    "Second and third normalization should be identical");
            }

            #[test]
            fn normalized_text_no_leading_trailing_space(
                text in " *[A-Za-z ]{5,50} *"
            ) {
                let normalizer = DefaultClaimNormalizer::new();
                let normalized = normalizer.normalize_text(&text);
                prop_assert!(!normalized.starts_with(' '),
                    "Normalized text should not start with space: {:?}", normalized);
                prop_assert!(!normalized.ends_with(' '),
                    "Normalized text should not end with space: {:?}", normalized);
            }

            #[test]
            fn lowercase_normalization_produces_lowercase(
                text in "[A-Za-z ]{5,30}"
            ) {
                let normalizer = DefaultClaimNormalizer::new().with_lowercase(true);
                let normalized = normalizer.normalize_text(&text);
                for c in normalized.chars() {
                    if c.is_alphabetic() {
                        prop_assert!(c.is_lowercase(),
                            "Character '{}' should be lowercase in {:?}", c, normalized);
                    }
                }
            }

            #[test]
            fn no_lowercase_preserves_case(
                text in "[A-Z]{2,10}"
            ) {
                let normalizer = DefaultClaimNormalizer::new().with_lowercase(false);
                let normalized = normalizer.normalize_text(&text);
                let has_uppercase = normalized.chars().any(|c| c.is_uppercase());
                if !text.trim().is_empty() {
                    prop_assert!(has_uppercase || normalized.trim().is_empty(),
                        "Should preserve at least some uppercase: input={:?}, output={:?}",
                        text, normalized);
                }
            }

            #[test]
            fn normalized_whitespace_no_multiple_spaces(
                text in "[A-Za-z ]{10,50}"
            ) {
                let normalizer = DefaultClaimNormalizer::new().with_normalize_whitespace(true);
                let normalized = normalizer.normalize_text(&text);
                // Two consecutive spaces: single spaces between words are
                // expected and must not fail the property.
                prop_assert!(!normalized.contains("  "),
                    "Normalized text should not contain multiple spaces: {:?}", normalized);
            }

            #[test]
            fn deduplication_reduces_identical_claims(
                base_text in "[a-z ]{5,20}",
                num_duplicates in 2usize..10
            ) {
                let dedup = ClaimDeduplicator::default();
                let claims: Vec<_> = (0..num_duplicates)
                    .map(|_| {
                        LogicalClaim::new(
                            &base_text,
                            ClaimStructure::Raw(base_text.clone()),
                        )
                    })
                    .collect();
                let deduped = dedup.deduplicate(claims);
                prop_assert_eq!(deduped.len(), 1,
                    "Deduplication should reduce {} identical claims to 1, got {}",
                    num_duplicates, deduped.len());
            }

            #[test]
            fn deduplication_preserves_unique_claims(
                texts in prop::collection::hash_set("[a-z]{3,15}", 2..10)
            ) {
                let dedup = ClaimDeduplicator::default();
                let claims: Vec<_> = texts.iter()
                    .map(|text| {
                        LogicalClaim::new(text, ClaimStructure::Raw(text.clone()))
                    })
                    .collect();
                let num_unique = texts.len();
                let deduped = dedup.deduplicate(claims);
                prop_assert_eq!(deduped.len(), num_unique,
                    "Deduplication should preserve {} unique claims, got {}",
                    num_unique, deduped.len());
            }

            #[test]
            fn hash_structure_consistent(
                text in "[a-z]{5,20}"
            ) {
                let s1 = ClaimStructure::Raw(text.clone());
                let s2 = ClaimStructure::Raw(text.clone());
                let hash1 = ClaimDeduplicator::hash_structure(&s1);
                let hash2 = ClaimDeduplicator::hash_structure(&s2);
                prop_assert_eq!(hash1, hash2,
                    "Identical structures should have same hash");
            }

            // NOTE(review): the assertion below is a tautology by design, so
            // this property only checks that hashing does not panic.
            #[test]
            fn hash_structure_different(
                text1 in "[a-z]{5,15}",
                text2 in "[a-z]{5,15}"
            ) {
                if text1 != text2 {
                    let s1 = ClaimStructure::Raw(text1);
                    let s2 = ClaimStructure::Raw(text2);
                    let hash1 = ClaimDeduplicator::hash_structure(&s1);
                    let hash2 = ClaimDeduplicator::hash_structure(&s2);
                    prop_assert!(hash1 == hash2 || hash1 != hash2,
                        "Hash function should work (tautology to avoid false positives)");
                }
            }

            #[test]
            fn normalization_preserves_confidence(
                text in "[A-Za-z ]{5,30}",
                confidence in 0.0f32..1.0
            ) {
                let normalizer = DefaultClaimNormalizer::new();
                let claim = LogicalClaim::new(
                    &text,
                    ClaimStructure::Raw(text.clone()),
                ).with_confidence(confidence);
                let normalized = normalizer.normalize(&claim);
                prop_assert_eq!(normalized.confidence, confidence,
                    "Normalization should preserve confidence");
            }

            #[test]
            fn find_groups_correct_count(
                texts in prop::collection::vec("[a-z]{3,10}", 1..15)
            ) {
                let dedup = ClaimDeduplicator::default();
                let claims: Vec<_> = texts.iter()
                    .map(|text| {
                        LogicalClaim::new(text, ClaimStructure::Raw(text.clone()))
                    })
                    .collect();
                let groups = dedup.find_groups(&claims);
                prop_assert!(groups.len() <= claims.len(),
                    "Number of groups {} should not exceed number of claims {}",
                    groups.len(), claims.len());
                let total_in_groups: usize = groups.values().map(|g| g.len()).sum();
                prop_assert_eq!(total_in_groups, claims.len(),
                    "Total claims in groups should equal input");
            }

            #[test]
            fn deduplication_merge_boosts_confidence(
                text in "[a-z]{5,15}",
                num_copies in 2usize..8
            ) {
                let dedup = ClaimDeduplicator::default().with_merge_confidence(true);
                let base_confidence = 0.5;
                let claims: Vec<_> = (0..num_copies)
                    .map(|_| {
                        LogicalClaim::new(
                            &text,
                            ClaimStructure::Raw(text.clone()),
                        ).with_confidence(base_confidence)
                    })
                    .collect();
                let deduped = dedup.deduplicate(claims);
                prop_assert_eq!(deduped.len(), 1,
                    "Should deduplicate to single claim");
                prop_assert!(deduped[0].confidence >= base_confidence,
                    "Merged confidence should be at least base confidence");
                prop_assert!(deduped[0].confidence <= 1.0,
                    "Merged confidence should not exceed 1.0");
            }
        }

        // Takes no generated inputs, so it is a plain unit test rather than a
        // `proptest!` property.
        #[test]
        fn normalization_handles_empty() {
            let normalizer = DefaultClaimNormalizer::new();
            let normalized = normalizer.normalize_text("");
            assert_eq!(
                normalized, "",
                "Empty text should remain empty after normalization"
            );
        }
    }
}