use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One entity that an uncertain reference might denote, together with the
/// evidence recorded for it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReferenceCandidate {
/// Identifier of the candidate entity. `UncertainReference::add_candidate`
/// merges candidates that share the same id.
pub entity_id: u64,
/// Free-form text describing the candidate (the tests use names such as "John").
pub description: String,
/// Score of this candidate. Elsewhere in this file it is merged with
/// log-sum-exp and normalised with a softmax, i.e. it is treated as an
/// unnormalised log-weight.
pub weight: f64,
/// Where the candidate came from; `ReferenceCandidate::new` uses the default.
pub source: CandidateSource,
/// Labels appended by `satisfies`.
pub satisfied_constraints: Vec<String>,
/// Labels appended by `violates`; any entry makes `has_violations` return true.
pub violated_constraints: Vec<String>,
}
/// Origin tag attached to a `ReferenceCandidate`.
///
/// Only `Discourse` (the default) and `Cataphoric` are produced by code visible
/// in this file; the notes on the other variants are inferred from their names
/// and should be confirmed against the callers that create them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum CandidateSource {
/// Default source, assigned by `ReferenceCandidate::new`.
#[default]
Discourse,
/// Presumably a candidate supplied by background knowledge — TODO confirm.
WorldKnowledge,
/// Presumably a candidate reached through a bridging inference — TODO confirm.
Bridging,
/// Presumably a candidate introduced by accommodation — TODO confirm.
Accommodation,
/// Assigned by `DeferredResolutionContext::try_resolve_cataphoric` to
/// candidates offered to references flagged as cataphoric.
Cataphoric,
}
impl ReferenceCandidate {
    /// Builds a candidate for `entity_id` with the given description and weight.
    ///
    /// The source is the default [`CandidateSource`] and both constraint lists
    /// start out empty; use the builder methods below to fill them in.
    #[must_use]
    pub fn new(entity_id: u64, description: impl Into<String>, weight: f64) -> Self {
        let description = description.into();
        Self {
            entity_id,
            description,
            weight,
            source: CandidateSource::default(),
            satisfied_constraints: Vec::default(),
            violated_constraints: Vec::default(),
        }
    }

    /// Returns the candidate with its source replaced by `source`.
    #[must_use]
    pub fn with_source(self, source: CandidateSource) -> Self {
        Self { source, ..self }
    }

    /// Returns the candidate with `constraint` appended to the satisfied list.
    #[must_use]
    pub fn satisfies(mut self, constraint: impl Into<String>) -> Self {
        let label = constraint.into();
        self.satisfied_constraints.push(label);
        self
    }

    /// Returns the candidate with `constraint` appended to the violated list.
    #[must_use]
    pub fn violates(mut self, constraint: impl Into<String>) -> Self {
        let label = constraint.into();
        self.violated_constraints.push(label);
        self
    }

    /// Reports whether at least one violated constraint has been recorded.
    #[must_use]
    pub fn has_violations(&self) -> bool {
        let no_violations = self.violated_constraints.is_empty();
        !no_violations
    }
}
/// A referring expression whose referent has not necessarily been decided yet,
/// together with the weighted candidates it might denote.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UncertainReference {
/// Text of the referring expression (the tests use e.g. "she", "the person").
pub description: String,
/// Candidate referents; `add_candidate` keeps at most one entry per `entity_id`.
pub candidates: Vec<ReferenceCandidate>,
/// Set to true by `resolve` (when a candidate exists) and by `resolve_to`.
pub resolved: bool,
/// Entity chosen by `resolve` / `resolve_to`; `None` until then.
pub resolved_entity: Option<u64>,
/// Constraints recorded through `add_constraint`. None of the methods visible
/// in this file consult them.
pub constraints: Vec<ReferenceConstraint>,
/// Position set through `at_position`; `None` by default.
pub discourse_position: Option<usize>,
/// Set through `cataphoric`; makes the reference eligible for
/// `DeferredResolutionContext::try_resolve_cataphoric`.
pub is_cataphoric: bool,
}
/// A single constraint attached to an `UncertainReference` via `add_constraint`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReferenceConstraint {
/// Category of the constraint.
pub kind: ConstraintKind,
/// Constraint value as free-form text; its format is not defined in this file.
pub value: String,
/// Hard/soft flag. Not interpreted by any code visible here — presumably hard
/// constraints must hold while soft ones only bias the choice; TODO confirm.
pub is_hard: bool,
}
/// Category tag of a `ReferenceConstraint`.
///
/// The code visible in this file only stores these values; the per-variant
/// notes are inferred from the names and should be confirmed with the code
/// that evaluates constraints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConstraintKind {
/// Presumably grammatical gender agreement.
Gender,
/// Presumably grammatical number agreement.
Number,
/// Presumably grammatical person agreement.
Person,
/// Presumably animate vs. inanimate compatibility.
Animacy,
/// Presumably compatibility of semantic type.
SemanticType,
/// Presumably a syntactic binding restriction.
Binding,
/// Presumably a preference based on discourse salience.
Salience,
/// Presumably a preference based on how recently the entity was mentioned.
Recency,
}
impl UncertainReference {
    /// Creates an unresolved, non-cataphoric reference with no candidates,
    /// constraints or discourse position.
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..Self::default()
        }
    }

    /// Returns the reference flagged as cataphoric.
    #[must_use]
    pub fn cataphoric(mut self) -> Self {
        self.is_cataphoric = true;
        self
    }

    /// Returns the reference with its discourse position set to `position`.
    #[must_use]
    pub fn at_position(mut self, position: usize) -> Self {
        self.discourse_position = Some(position);
        self
    }

    /// Records a constraint of the given kind and value on this reference.
    pub fn add_constraint(
        &mut self,
        kind: ConstraintKind,
        value: impl Into<String>,
        is_hard: bool,
    ) {
        let value = value.into();
        self.constraints.push(ReferenceConstraint {
            kind,
            value,
            is_hard,
        });
    }

    /// Adds a candidate, merging it into an existing one with the same entity id.
    ///
    /// On a merge only the weight changes: the two weights are combined with
    /// log-sum-exp. The existing candidate keeps its description, source and
    /// constraint lists; those of the incoming candidate are discarded.
    pub fn add_candidate(&mut self, candidate: ReferenceCandidate) {
        if let Some(existing) = self
            .candidates
            .iter_mut()
            .find(|c| c.entity_id == candidate.entity_id)
        {
            existing.weight = log_sum_exp(existing.weight, candidate.weight);
        } else {
            self.candidates.push(candidate);
        }
    }

    /// Adds `evidence` to the weight of the candidate for `entity_id`.
    ///
    /// Does nothing when no candidate with that id exists.
    pub fn update_evidence(&mut self, entity_id: u64, evidence: f64) {
        if let Some(candidate) = self
            .candidates
            .iter_mut()
            .find(|c| c.entity_id == entity_id)
        {
            candidate.weight += evidence;
        }
    }

    /// Drops every candidate whose weight is below `threshold`.
    ///
    /// A NaN weight fails the `>=` comparison and is therefore dropped as well.
    pub fn prune(&mut self, threshold: f64) {
        self.candidates.retain(|c| c.weight >= threshold);
    }

    /// Drops every candidate that has at least one violated constraint.
    pub fn prune_violations(&mut self) {
        self.candidates.retain(|c| !c.has_violations());
    }

    /// Orders two candidates by descending weight, NaN weights last.
    ///
    /// Treating NaN as "equal to everything" is not a total order; handing
    /// such a comparator to `sort_by` yields an unspecified order and may
    /// panic. Ranking NaN consistently below every real weight avoids that
    /// while leaving the order of non-NaN weights exactly as `partial_cmp`
    /// defines it.
    fn by_descending_weight(a: &ReferenceCandidate, b: &ReferenceCandidate) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.weight.is_nan(), b.weight.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` cannot return `None`.
            (false, false) => b
                .weight
                .partial_cmp(&a.weight)
                .unwrap_or(Ordering::Equal),
        }
    }

    /// Returns the candidates sorted by descending weight.
    ///
    /// The sort is stable, so candidates with equal weight keep their
    /// insertion order. Candidates with a NaN weight are placed last.
    #[must_use]
    pub fn ranked_candidates(&self) -> Vec<&ReferenceCandidate> {
        let mut sorted: Vec<&ReferenceCandidate> = self.candidates.iter().collect();
        sorted.sort_by(|a, b| Self::by_descending_weight(a, b));
        sorted
    }

    /// Returns the highest-weighted candidate, or `None` when there are none.
    ///
    /// Equivalent to the first element of [`Self::ranked_candidates`] but found
    /// in a single pass without allocating.
    #[must_use]
    pub fn best_candidate(&self) -> Option<&ReferenceCandidate> {
        // `min_by` returns the FIRST of several equally-minimal elements, which
        // matches what the stable sort in `ranked_candidates` puts in front.
        self.candidates
            .iter()
            .min_by(|a, b| Self::by_descending_weight(a, b))
    }

    /// Shannon entropy, in bits, of the distribution from [`Self::probabilities`].
    ///
    /// Returns `0.0` when there are no candidates.
    #[must_use]
    pub fn entropy(&self) -> f64 {
        if self.candidates.is_empty() {
            return 0.0;
        }
        self.probabilities().values().fold(0.0, |h, &p| {
            // Zero-probability entries contribute nothing (and log2(0) is -inf).
            if p > 0.0 {
                h - p * p.log2()
            } else {
                h
            }
        })
    }

    /// Softmax of the candidate weights, keyed by entity id.
    ///
    /// Returns an empty map when there are no candidates. Degenerate weights
    /// are handled instead of producing NaN for every entry:
    ///
    /// * if the largest weight is `+inf`, or every weight is `-inf`, the mass
    ///   is split evenly among the candidates holding that extreme weight;
    /// * a candidate whose weight is NaN gets probability `0.0`;
    /// * if several entries in `candidates` share an entity id (possible only
    ///   by editing the public field directly), their probabilities are added
    ///   rather than overwriting each other.
    #[must_use]
    pub fn probabilities(&self) -> HashMap<u64, f64> {
        if self.candidates.is_empty() {
            return HashMap::new();
        }

        // `f64::max` returns the other operand when one is NaN, so this is the
        // largest non-NaN weight (or -inf when there is none).
        let max_weight = self
            .candidates
            .iter()
            .map(|c| c.weight)
            .fold(f64::NEG_INFINITY, f64::max);
        let degenerate = max_weight.is_infinite();

        // Unnormalised mass per candidate. Subtracting the maximum before
        // exponentiating keeps `exp` from overflowing; with an infinite
        // maximum that subtraction would be `inf - inf`, so use indicators.
        let mass_of = |weight: f64| -> f64 {
            if weight.is_nan() {
                0.0
            } else if degenerate {
                if weight == max_weight {
                    1.0
                } else {
                    0.0
                }
            } else {
                (weight - max_weight).exp()
            }
        };
        let masses: Vec<f64> = self.candidates.iter().map(|c| mass_of(c.weight)).collect();
        let total: f64 = masses.iter().sum();

        let mut probs = HashMap::with_capacity(self.candidates.len());
        for (candidate, mass) in self.candidates.iter().zip(masses) {
            // `total` is zero only when every weight is NaN.
            let p = if total > 0.0 { mass / total } else { 0.0 };
            *probs.entry(candidate.entity_id).or_insert(0.0) += p;
        }
        probs
    }

    /// Reports whether more than one entity has probability of at least `threshold`.
    #[must_use]
    pub fn is_ambiguous(&self, threshold: f64) -> bool {
        let contenders = self
            .probabilities()
            .values()
            .filter(|&&p| p >= threshold)
            .count();
        contenders > 1
    }

    /// Commits to the best candidate and returns a copy of it.
    ///
    /// If the reference is already resolved, returns a copy of the candidate
    /// matching `resolved_entity` (or `None` when no such candidate is stored,
    /// e.g. after `resolve_to` with an unknown id). If there are no candidates
    /// the reference stays unresolved and `None` is returned.
    #[must_use]
    pub fn resolve(&mut self) -> Option<ReferenceCandidate> {
        if self.resolved {
            let target = self.resolved_entity?;
            return self
                .candidates
                .iter()
                .find(|c| c.entity_id == target)
                .cloned();
        }
        let winner = self.best_candidate()?.clone();
        self.resolved_entity = Some(winner.entity_id);
        self.resolved = true;
        Some(winner)
    }

    /// Marks the reference as resolved to `entity_id`, whether or not a
    /// candidate with that id is stored.
    pub fn resolve_to(&mut self, entity_id: u64) {
        self.resolved_entity = Some(entity_id);
        self.resolved = true;
    }

    /// Reports whether the reference has been resolved.
    #[must_use]
    pub fn is_resolved(&self) -> bool {
        self.resolved
    }

    /// Number of stored candidates.
    #[must_use]
    pub fn candidate_count(&self) -> usize {
        self.candidates.len()
    }

    /// Reports whether there are no candidates to choose from.
    #[must_use]
    pub fn is_unresolvable(&self) -> bool {
        self.candidates.is_empty()
    }
}
/// Bookkeeping for references whose resolution is postponed: the references
/// themselves, a log of entity mentions, and a running position counter.
#[derive(Debug, Clone, Default)]
pub struct DeferredResolutionContext {
/// References added through `add_uncertain` that have not been moved out yet.
pub pending: Vec<UncertainReference>,
/// References moved out of `pending` by `resolve_all`.
pub resolved: Vec<UncertainReference>,
/// For each entity id, the values `position` had when `record_mention` was called.
pub entity_mentions: HashMap<u64, Vec<usize>>,
/// Current position counter; starts at 0 and is incremented by `advance`.
pub position: usize,
}
impl DeferredResolutionContext {
    /// Probability threshold that [`Self::statistics`] uses to count a pending
    /// reference as ambiguous.
    pub const DEFAULT_AMBIGUITY_THRESHOLD: f64 = 0.3;

    /// Creates an empty context positioned at 0.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues `reference` as pending.
    pub fn add_uncertain(&mut self, reference: UncertainReference) {
        self.pending.push(reference);
    }

    /// Records that `entity_id` was mentioned at the current position.
    pub fn record_mention(&mut self, entity_id: u64) {
        let here = self.position;
        self.entity_mentions.entry(entity_id).or_default().push(here);
    }

    /// Moves the position counter forward by one.
    pub fn advance(&mut self) {
        self.position += 1;
    }

    /// Offers `new_entities` as candidates to every pending reference that is
    /// flagged cataphoric and not yet resolved.
    ///
    /// Each tuple is `(entity_id, description, weight)`. Candidates are added
    /// with [`CandidateSource::Cataphoric`] through
    /// `UncertainReference::add_candidate`, so offering an entity that is
    /// already a candidate merges the weights instead of adding a duplicate.
    pub fn try_resolve_cataphoric(&mut self, new_entities: &[(u64, String, f64)]) {
        let waiting = self
            .pending
            .iter_mut()
            .filter(|r| r.is_cataphoric && !r.is_resolved());
        for reference in waiting {
            for (entity_id, description, weight) in new_entities {
                reference.add_candidate(
                    ReferenceCandidate::new(*entity_id, description.clone(), *weight)
                        .with_source(CandidateSource::Cataphoric),
                );
            }
        }
    }

    /// Resolves every pending reference that can be resolved and moves it to
    /// `resolved`.
    ///
    /// References that cannot be resolved (no candidates) stay in `pending`,
    /// in their original order, so that `resolved` only ever holds references
    /// for which `is_resolved()` is true and [`Self::statistics`] does not
    /// count unresolved references as resolved.
    pub fn resolve_all(&mut self) {
        let mut still_pending = Vec::new();
        for mut reference in self.pending.drain(..) {
            // The chosen candidate is not needed here, only the state change.
            let _ = reference.resolve();
            if reference.is_resolved() {
                self.resolved.push(reference);
            } else {
                still_pending.push(reference);
            }
        }
        self.pending = still_pending;
    }

    /// Returns the pending references that are ambiguous at `threshold`
    /// (see `UncertainReference::is_ambiguous`).
    #[must_use]
    pub fn ambiguous_references(&self, threshold: f64) -> Vec<&UncertainReference> {
        self.pending
            .iter()
            .filter(|r| r.is_ambiguous(threshold))
            .collect()
    }

    /// Summarises the context using [`Self::DEFAULT_AMBIGUITY_THRESHOLD`].
    #[must_use]
    pub fn statistics(&self) -> ResolutionStatistics {
        self.statistics_with_threshold(Self::DEFAULT_AMBIGUITY_THRESHOLD)
    }

    /// Summarises the context, counting a pending reference as ambiguous when
    /// more than one of its entities has probability of at least
    /// `ambiguity_threshold`.
    ///
    /// `ambiguous`, `unresolvable`, `cataphoric` and `avg_entropy` are computed
    /// over the pending references only; `avg_entropy` is `0.0` when nothing is
    /// pending.
    #[must_use]
    pub fn statistics_with_threshold(&self, ambiguity_threshold: f64) -> ResolutionStatistics {
        let pending = self.pending.len();
        let resolved = self.resolved.len();

        let mut ambiguous = 0;
        let mut unresolvable = 0;
        let mut cataphoric = 0;
        let mut entropy_sum = 0.0;
        for reference in &self.pending {
            if reference.is_ambiguous(ambiguity_threshold) {
                ambiguous += 1;
            }
            if reference.is_unresolvable() {
                unresolvable += 1;
            }
            if reference.is_cataphoric {
                cataphoric += 1;
            }
            entropy_sum += reference.entropy();
        }

        let avg_entropy = if pending == 0 {
            0.0
        } else {
            entropy_sum / pending as f64
        };

        ResolutionStatistics {
            total: pending + resolved,
            resolved,
            pending,
            ambiguous,
            unresolvable,
            cataphoric,
            avg_entropy,
        }
    }
}
/// Snapshot of a `DeferredResolutionContext`, as produced by its `statistics` method.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResolutionStatistics {
/// Number of references in the pending and resolved lists combined.
pub total: usize,
/// Number of references in the context's `resolved` list.
pub resolved: usize,
/// Number of references in the context's `pending` list.
pub pending: usize,
/// Pending references reported ambiguous by `is_ambiguous`
/// (`statistics` uses a threshold of 0.3).
pub ambiguous: usize,
/// Pending references that have no candidates.
pub unresolvable: usize,
/// Pending references flagged as cataphoric.
pub cataphoric: usize,
/// Mean of `UncertainReference::entropy` over the pending references;
/// 0.0 when nothing is pending.
pub avg_entropy: f64,
}
/// Policy deciding whether an uncertain reference should be resolved right now
/// (see `should_resolve`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ResolutionStrategy {
/// Default. `should_resolve` always returns true.
#[default]
Greedy,
/// `should_resolve` always returns false.
Deferred,
/// `should_resolve` always returns false; within this file it behaves the
/// same as `Deferred`.
Probabilistic,
/// Confidence threshold in percent: `should_resolve` returns true when some
/// candidate's probability times 100 is at least this value.
Confident(u8),
}
impl ResolutionStrategy {
#[must_use]
pub fn should_resolve(&self, reference: &UncertainReference) -> bool {
match self {
ResolutionStrategy::Greedy => true,
ResolutionStrategy::Deferred => false,
ResolutionStrategy::Probabilistic => false,
ResolutionStrategy::Confident(threshold) => {
let probs = reference.probabilities();
probs.values().any(|&p| p * 100.0 >= *threshold as f64)
}
}
}
}
/// Resolves `reference` if `strategy` allows it and returns the chosen candidate.
///
/// An already-resolved reference is never re-decided: the stored candidate
/// matching `resolved_entity` is returned (or `None` if no such candidate is
/// stored) without consulting the strategy. Otherwise the reference is
/// resolved only when `strategy.should_resolve` agrees; if it does not,
/// `None` is returned and the reference is left untouched.
pub fn resolve_uncertain(
    reference: &mut UncertainReference,
    strategy: ResolutionStrategy,
) -> Option<ReferenceCandidate> {
    // `resolve()` itself returns the stored choice for a resolved reference,
    // so both the "already resolved" and the "resolve now" paths end there.
    // The `||` short-circuits, so the strategy is not asked when resolved.
    let proceed = reference.is_resolved() || strategy.should_resolve(reference);
    if !proceed {
        return None;
    }
    reference.resolve()
}
/// Computes `ln(exp(a) + exp(b))` without overflowing for large inputs.
///
/// Special cases:
/// * if either argument is NaN the result is NaN;
/// * if the larger argument is infinite it is returned as is, so
///   `(+inf, x)` gives `+inf` and `(-inf, -inf)` gives `-inf` instead of the
///   NaN that `inf - inf` would produce.
fn log_sum_exp(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() {
        return f64::NAN;
    }
    let (hi, lo) = if a >= b { (a, b) } else { (b, a) };
    if hi.is_infinite() {
        return hi;
    }
    // ln(e^hi + e^lo) = hi + ln(1 + e^(lo - hi)); `lo - hi <= 0`, so `exp`
    // cannot overflow, and `ln_1p` stays accurate when that term is tiny.
    hi + (lo - hi).exp().ln_1p()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a reference and adds the `(entity_id, description, weight)`
    /// candidates to it in the given order.
    fn reference_with(description: &str, candidates: &[(u64, &str, f64)]) -> UncertainReference {
        let mut reference = UncertainReference::new(description);
        for &(entity_id, name, weight) in candidates {
            reference.add_candidate(ReferenceCandidate::new(entity_id, name, weight));
        }
        reference
    }

    /// Two distinct candidates are both kept and the heavier one ranks first.
    #[test]
    fn test_uncertain_reference_basic() {
        let reference = reference_with("the person", &[(1, "John", 0.8), (2, "Mary", 0.6)]);

        assert_eq!(reference.candidate_count(), 2);
        assert!(!reference.is_resolved());

        let best = reference.best_candidate().unwrap();
        assert_eq!(best.entity_id, 1);
    }

    /// Evidence updates break an initial tie in favour of the boosted candidate.
    #[test]
    fn test_evidence_update() {
        let mut reference = reference_with("the person", &[(1, "John", 0.5), (2, "Mary", 0.5)]);
        assert_eq!(
            reference.candidates[0].weight,
            reference.candidates[1].weight
        );

        reference.update_evidence(1, 0.3);
        reference.update_evidence(2, -0.2);

        let best = reference.best_candidate().unwrap();
        assert_eq!(best.entity_id, 1);
    }

    /// Probabilities follow the weight order and sum to one.
    #[test]
    fn test_probabilities() {
        let reference = reference_with("test", &[(1, "A", 1.0), (2, "B", 0.0)]);
        let probs = reference.probabilities();

        assert!(probs[&1] > probs[&2]);

        let total: f64 = probs.values().sum();
        assert!((total - 1.0).abs() < 0.001);
    }

    /// A uniform distribution has higher entropy than a peaked one.
    #[test]
    fn test_entropy() {
        let equal_ref = reference_with("test", &[(1, "A", 0.0), (2, "B", 0.0)]);
        let unequal_ref = reference_with("test", &[(1, "A", 10.0), (2, "B", 0.0)]);

        assert!(equal_ref.entropy() > unequal_ref.entropy());
    }

    /// Only a reference with several likely candidates counts as ambiguous.
    #[test]
    fn test_ambiguity_detection() {
        let reference = reference_with("test", &[(1, "A", 10.0), (2, "B", 0.0)]);
        assert!(!reference.is_ambiguous(0.3));

        let ambiguous_ref = reference_with("test", &[(1, "A", 0.0), (2, "B", 0.0)]);
        assert!(ambiguous_ref.is_ambiguous(0.3));
    }

    /// Resolving picks the best candidate and records the decision.
    #[test]
    fn test_resolution() {
        let mut reference = reference_with("test", &[(1, "John", 0.8), (2, "Mary", 0.6)]);
        assert!(!reference.is_resolved());

        let resolved = reference.resolve().unwrap();

        assert_eq!(resolved.entity_id, 1);
        assert!(reference.is_resolved());
        assert_eq!(reference.resolved_entity, Some(1));
    }

    /// Candidates with a violated constraint are removed by `prune_violations`.
    #[test]
    fn test_constraint_violations() {
        let violating = ReferenceCandidate::new(1, "John", 0.8).violates("gender:feminine");
        assert!(violating.has_violations());

        let mut reference = UncertainReference::new("she");
        reference.add_candidate(violating);
        reference.add_candidate(ReferenceCandidate::new(2, "Mary", 0.6));
        reference.prune_violations();

        assert_eq!(reference.candidate_count(), 1);
        assert_eq!(reference.candidates[0].entity_id, 2);
    }

    /// A later entity becomes a candidate of a pending cataphoric reference.
    #[test]
    fn test_cataphoric_resolution() {
        let mut context = DeferredResolutionContext::new();
        context.add_uncertain(UncertainReference::new("she").cataphoric().at_position(0));
        context.advance();

        let newcomers = [(1, "Mary".to_string(), 0.9)];
        context.try_resolve_cataphoric(&newcomers);

        assert_eq!(context.pending[0].candidate_count(), 1);
    }

    /// Each strategy gives the expected verdict on a clearly decided reference.
    #[test]
    fn test_resolution_strategy() {
        let reference = reference_with("test", &[(1, "A", 5.0), (2, "B", 0.0)]);

        assert!(ResolutionStrategy::Greedy.should_resolve(&reference));
        assert!(!ResolutionStrategy::Deferred.should_resolve(&reference));
        assert!(ResolutionStrategy::Confident(90).should_resolve(&reference));
    }

    /// Statistics count resolved, pending, ambiguous and cataphoric references.
    #[test]
    fn test_context_statistics() {
        let mut context = DeferredResolutionContext::new();

        let mut done = reference_with("resolved", &[(1, "A", 0.9)]);
        let _ = done.resolve();
        context.resolved.push(done);

        context
            .pending
            .push(reference_with("ambiguous", &[(2, "B", 0.0), (3, "C", 0.0)]));
        context
            .pending
            .push(UncertainReference::new("cataphoric").cataphoric());

        let stats = context.statistics();
        assert_eq!(stats.total, 3);
        assert_eq!(stats.resolved, 1);
        assert_eq!(stats.pending, 2);
        assert_eq!(stats.ambiguous, 1);
        assert_eq!(stats.cataphoric, 1);
    }
}