use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One entry of an utterance's forward-looking center list (Cf): a single
/// mention of a discourse entity plus the features used to rank it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ForwardCenter {
/// Entity identifier; mentions sharing an id are treated as the same entity
/// (see `CenteringState::mentions`).
pub entity_id: u64,
/// Surface text of the mention (e.g. "John", "he"); not used for ranking.
pub realization: String,
/// Base salience; multiplied by the grammatical-role weight in `effective_salience`.
pub salience: f64,
/// Grammatical role, if known; `None` is weighted 0.5 in `effective_salience`.
pub grammatical_role: Option<GrammaticalRole>,
/// Information status; contributes an additive boost in `effective_salience`.
pub info_status: InformationStatus,
/// Position of the mention; 0 unless set with `at_offset`.
/// NOTE(review): the unit (byte, char or token offset) is not established here — confirm.
pub offset: usize,
}
/// Grammatical role of a mention. Each role maps to a fixed weight via
/// `salience_weight`: Subject > DirectObject > IndirectObject > Oblique > Other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum GrammaticalRole {
Subject,
DirectObject,
IndirectObject,
Oblique,
/// Catch-all role; also the `Default` value.
#[default]
Other,
}
impl GrammaticalRole {
#[must_use]
pub const fn salience_weight(&self) -> f64 {
match self {
GrammaticalRole::Subject => 1.0,
GrammaticalRole::DirectObject => 0.8,
GrammaticalRole::IndirectObject => 0.6,
GrammaticalRole::Oblique => 0.4,
GrammaticalRole::Other => 0.3,
}
}
}
/// Information status of a mention. `New` is the only hearer-new status (see
/// `is_hearer_old`); every status maps to an additive boost via `salience_boost`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum InformationStatus {
/// Hearer-new; contributes no salience boost.
New,
Inferrable,
/// Default status, used e.g. by `ForwardCenter::new`; largest salience boost.
#[default]
Evoked,
Unused,
}
impl InformationStatus {
#[must_use]
pub const fn is_hearer_old(&self) -> bool {
matches!(
self,
InformationStatus::Evoked | InformationStatus::Unused | InformationStatus::Inferrable
)
}
#[must_use]
pub const fn salience_boost(&self) -> f64 {
match self {
InformationStatus::Evoked => 0.3,
InformationStatus::Unused => 0.2,
InformationStatus::Inferrable => 0.1,
InformationStatus::New => 0.0,
}
}
}
impl ForwardCenter {
    /// Creates a center with no grammatical role, the default information
    /// status and offset 0.
    #[must_use]
    pub fn new(entity_id: u64, realization: impl Into<String>, salience: f64) -> Self {
        let realization: String = realization.into();
        Self {
            entity_id,
            realization,
            salience,
            grammatical_role: None,
            info_status: InformationStatus::default(),
            offset: 0,
        }
    }

    /// Builder: sets the grammatical role.
    #[must_use]
    pub fn with_role(self, role: GrammaticalRole) -> Self {
        Self {
            grammatical_role: Some(role),
            ..self
        }
    }

    /// Builder: sets the information status.
    #[must_use]
    pub fn with_info_status(self, status: InformationStatus) -> Self {
        Self {
            info_status: status,
            ..self
        }
    }

    /// Builder: sets the mention offset.
    #[must_use]
    pub fn at_offset(self, offset: usize) -> Self {
        Self { offset, ..self }
    }

    /// Ranking score: `salience * role_weight + status_boost`, where the role
    /// weight is 0.5 when no grammatical role has been set.
    #[must_use]
    pub fn effective_salience(&self) -> f64 {
        let weight = match self.grammatical_role {
            Some(role) => role.salience_weight(),
            None => 0.5,
        };
        let boost = self.info_status.salience_boost();
        self.salience * weight + boost
    }
}
/// Centering transition between an utterance and its predecessor, as
/// classified by `compute_transition` from Cb(previous), Cb(current) and Cp(current).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum CenteringTransition {
/// Cb unchanged (or previous Cb undefined) and the Cb is also the Cp.
Continue,
/// Cb unchanged (or previous Cb undefined) but the Cb is not the Cp.
Retain,
/// Cb changed and the new Cb is the Cp.
SmoothShift,
/// Cb changed and the new Cb is not the Cp.
RoughShift,
/// No Cb in the current utterance (always the case for the first utterance in
/// `track_centers`); also the `Default` value.
#[default]
Null,
}
impl CenteringTransition {
#[must_use]
pub const fn coherence_score(&self) -> f64 {
match self {
CenteringTransition::Continue => 1.0,
CenteringTransition::Retain => 0.75,
CenteringTransition::SmoothShift => 0.5,
CenteringTransition::RoughShift => 0.25,
CenteringTransition::Null => 0.0,
}
}
#[must_use]
pub const fn is_continuing(&self) -> bool {
matches!(
self,
CenteringTransition::Continue | CenteringTransition::Retain
)
}
#[must_use]
pub const fn is_shifting(&self) -> bool {
matches!(
self,
CenteringTransition::SmoothShift | CenteringTransition::RoughShift
)
}
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
CenteringTransition::Continue => "CONTINUE",
CenteringTransition::Retain => "RETAIN",
CenteringTransition::SmoothShift => "SMOOTH-SHIFT",
CenteringTransition::RoughShift => "ROUGH-SHIFT",
CenteringTransition::Null => "NULL",
}
}
}
/// Centering bookkeeping for a single utterance.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CenteringState {
/// Position of the utterance in the sequence passed to `track_centers`.
pub utterance_idx: usize,
/// Forward-looking centers; sorted by descending `effective_salience` when
/// installed through `with_cf`.
pub cf: Vec<ForwardCenter>,
/// Backward-looking center: the highest-ranked entity of the previous
/// utterance's `cf` that is mentioned again here; `None` for the first
/// utterance or when no entity carries over.
pub cb: Option<u64>,
/// Preferred center: entity id of the first (highest-ranked) element of `cf`,
/// as set by `with_cf`.
pub cp: Option<u64>,
/// Transition from the previous utterance to this one.
pub transition: CenteringTransition,
/// Entity id -> salience in this utterance plus decayed salience from recent
/// previous utterances; filled by `track_centers` only when `use_recency` is set.
pub recency_scores: HashMap<u64, f64>,
}
impl CenteringState {
    /// Creates the state for utterance `utterance_idx` with no forward centers,
    /// no Cb/Cp, a `Null` transition and no recency scores.
    #[must_use]
    pub fn new(utterance_idx: usize) -> Self {
        Self {
            utterance_idx,
            cf: Vec::new(),
            cb: None,
            cp: None,
            transition: CenteringTransition::Null,
            recency_scores: HashMap::new(),
        }
    }

    /// Installs `cf` as this utterance's forward-looking centers, ranks them by
    /// descending `ForwardCenter::effective_salience`, and sets `cp` to the
    /// top-ranked entity (`None` when `cf` is empty).
    ///
    /// The sort is stable, so equally salient centers keep their input order.
    /// Centers whose effective salience is NaN are ranked last.
    #[must_use]
    pub fn with_cf(mut self, cf: Vec<ForwardCenter>) -> Self {
        // Ranking key with NaN mapped to -inf. Treating NaN as `Equal` to
        // everything (the previous comparator) is not a total order, for which
        // `sort_by` yields an unspecified order and may panic (Rust >= 1.81).
        fn rank_key(fc: &ForwardCenter) -> f64 {
            let salience = fc.effective_salience();
            if salience.is_nan() {
                f64::NEG_INFINITY
            } else {
                salience
            }
        }

        self.cf = cf;
        self.cf.sort_by(|a, b| {
            rank_key(b)
                .partial_cmp(&rank_key(a))
                // Unreachable: neither key can be NaN.
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.cp = self.cf.first().map(|fc| fc.entity_id);
        self
    }

    /// Entity ids of all forward centers, in `cf` order.
    #[must_use]
    pub fn mentioned_entities(&self) -> Vec<u64> {
        self.cf.iter().map(|fc| fc.entity_id).collect()
    }

    /// True if any forward center refers to `entity_id`.
    #[must_use]
    pub fn mentions(&self, entity_id: u64) -> bool {
        self.cf.iter().any(|fc| fc.entity_id == entity_id)
    }

    /// First forward center (in `cf` order) that refers to `entity_id`.
    #[must_use]
    pub fn get_fc(&self, entity_id: u64) -> Option<&ForwardCenter> {
        self.cf.iter().find(|fc| fc.entity_id == entity_id)
    }

    /// Coherence score of this state's transition.
    #[must_use]
    pub fn coherence_score(&self) -> f64 {
        self.transition.coherence_score()
    }
}
/// Tuning parameters for centering tracking and antecedent scoring.
#[derive(Debug, Clone)]
pub struct CenteringConfig {
    /// Per-utterance decay factor, raised to the utterance distance when
    /// weighting older mentions.
    pub recency_decay: f64,
    /// NOTE(review): not read by any function in this module as visible —
    /// confirm whether it is used elsewhere or is dead configuration.
    pub role_weight: f64,
    /// NOTE(review): not read by any function in this module as visible —
    /// confirm whether it is used elsewhere or is dead configuration.
    pub info_status_weight: f64,
    /// When true, `track_centers` fills each state's `recency_scores`, and
    /// `score_antecedents` may return them directly.
    pub use_recency: bool,
    /// Maximum number of preceding utterances consulted when computing
    /// recency scores.
    pub recency_window: usize,
}

impl Default for CenteringConfig {
    fn default() -> Self {
        let neutral_weight = 1.0;
        Self {
            recency_decay: 0.5,
            role_weight: neutral_weight,
            info_status_weight: neutral_weight,
            use_recency: true,
            recency_window: 5,
        }
    }
}
/// Runs centering over `utterances`, producing one `CenteringState` per
/// utterance in input order.
///
/// Each state's `cf` is a ranked copy of the input list. Every utterance after
/// the first gets its Cb and transition from its predecessor. When
/// `config.use_recency` is set, recency scores are computed against all
/// previously built states.
pub fn track_centers(
    utterances: &[Vec<ForwardCenter>],
    config: &CenteringConfig,
) -> Vec<CenteringState> {
    let mut tracked: Vec<CenteringState> = Vec::with_capacity(utterances.len());
    for (idx, forward_centers) in utterances.iter().enumerate() {
        let mut current = CenteringState::new(idx).with_cf(forward_centers.clone());
        // The first utterance has no predecessor and keeps the `cb: None` /
        // `CenteringTransition::Null` values set by `CenteringState::new`.
        if let Some(previous) = tracked.last() {
            current.cb = compute_cb(previous, &current);
            current.transition = compute_transition(previous, &current);
        }
        if config.use_recency {
            current.recency_scores = compute_recency_scores(&tracked, &current, config);
        }
        tracked.push(current);
    }
    tracked
}
/// Cb of `current_state`: the first entity of `prev_state.cf` (in its stored,
/// ranked order) that `current_state` also mentions, or `None` if none recurs.
fn compute_cb(prev_state: &CenteringState, current_state: &CenteringState) -> Option<u64> {
    prev_state
        .cf
        .iter()
        .map(|fc| fc.entity_id)
        .find(|&candidate| current_state.mentions(candidate))
}
/// Classifies the transition into `current_state` from `prev_state`.
///
/// The result is `Null` when the current utterance has no Cb (or no Cp).
/// Otherwise it is decided by two facts:
/// * whether the Cb is unchanged, where an undefined previous Cb counts as unchanged;
/// * whether the Cb is also the Cp.
pub fn compute_transition(
    prev_state: &CenteringState,
    current_state: &CenteringState,
) -> CenteringTransition {
    let (cb, cp) = match (current_state.cb, current_state.cp) {
        (Some(cb), Some(cp)) => (cb, cp),
        _ => return CenteringTransition::Null,
    };
    // An undefined previous Cb (first tracked pair, or a predecessor with no
    // Cb) is treated like an unchanged Cb.
    let cb_unchanged = match prev_state.cb {
        Some(prev_cb) => prev_cb == cb,
        None => true,
    };
    match (cb_unchanged, cb == cp) {
        (true, true) => CenteringTransition::Continue,
        (true, false) => CenteringTransition::Retain,
        (false, true) => CenteringTransition::SmoothShift,
        (false, false) => CenteringTransition::RoughShift,
    }
}
fn compute_recency_scores(
prev_states: &[CenteringState],
current: &CenteringState,
config: &CenteringConfig,
) -> HashMap<u64, f64> {
let mut scores: HashMap<u64, f64> = HashMap::new();
for fc in ¤t.cf {
scores.insert(fc.entity_id, fc.effective_salience());
}
let start = prev_states.len().saturating_sub(config.recency_window);
for (i, state) in prev_states[start..].iter().enumerate() {
let age = prev_states.len() - start - i; let decay = config.recency_decay.powi(age as i32);
for fc in &state.cf {
let recency_score = fc.effective_salience() * decay;
scores
.entry(fc.entity_id)
.and_modify(|s| *s += recency_score)
.or_insert(recency_score);
}
}
scores
}
/// Summary statistics over a sequence of `CenteringState`s, produced by `analyze_coherence`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CoherenceAnalysis {
/// Number of states per transition, keyed by `CenteringTransition::as_str` labels.
pub transition_counts: HashMap<String, usize>,
/// Mean of the per-state transition coherence scores.
pub avg_coherence: f64,
/// Fraction of states whose transition is CONTINUE or RETAIN.
pub continuity_ratio: f64,
/// Number of SMOOTH-SHIFT and ROUGH-SHIFT transitions.
pub shift_count: usize,
/// Length of the longest run of consecutive CONTINUE/RETAIN transitions.
pub max_continuity_run: usize,
}
pub fn analyze_coherence(states: &[CenteringState]) -> CoherenceAnalysis {
if states.is_empty() {
return CoherenceAnalysis::default();
}
let mut counts: HashMap<String, usize> = HashMap::new();
let mut total_coherence = 0.0;
let mut continuing = 0;
let mut shifts = 0;
let mut current_run = 0;
let mut max_run = 0;
for state in states {
let key = state.transition.as_str().to_string();
*counts.entry(key).or_default() += 1;
total_coherence += state.transition.coherence_score();
if state.transition.is_continuing() {
continuing += 1;
current_run += 1;
max_run = max_run.max(current_run);
} else {
current_run = 0;
if state.transition.is_shifting() {
shifts += 1;
}
}
}
CoherenceAnalysis {
transition_counts: counts,
avg_coherence: total_coherence / states.len() as f64,
continuity_ratio: continuing as f64 / states.len() as f64,
shift_count: shifts,
max_continuity_run: max_run,
}
}
/// Scores candidate antecedent entities for an anaphor in utterance
/// `anaphor_utterance` (an index into `states`). Returns entity id -> score;
/// higher means more likely.
///
/// * Returns an empty map when `anaphor_utterance` is 0 or `states` is empty.
/// * When `config.use_recency` is set and the anaphor's own state has non-empty
///   `recency_scores`, those are returned unchanged.
/// * Otherwise every earlier tracked utterance is scanned. Each mention scores
///   `effective_salience * recency_decay^distance`, plus 0.2 if the entity was
///   that utterance's Cb. An entity keeps its best score.
///
/// `anaphor_utterance` may exceed `states.len()` (e.g. an utterance not yet
/// tracked). In that case all tracked states are treated as preceding it
/// instead of panicking.
pub fn score_antecedents(
    anaphor_utterance: usize,
    states: &[CenteringState],
    config: &CenteringConfig,
) -> HashMap<u64, f64> {
    // Bonus for an entity that was the Cb of the utterance it appears in.
    const CB_BONUS: f64 = 0.2;

    if anaphor_utterance == 0 || states.is_empty() {
        return HashMap::new();
    }
    if let Some(state) = states.get(anaphor_utterance) {
        if config.use_recency && !state.recency_scores.is_empty() {
            return state.recency_scores.clone();
        }
    }

    // Clamp: `states[..anaphor_utterance]` panicked for indices past the end.
    let end = anaphor_utterance.min(states.len());
    let mut scores: HashMap<u64, f64> = HashMap::new();
    for (i, state) in states[..end].iter().enumerate().rev() {
        // i < end <= anaphor_utterance, so this cannot underflow.
        let age = anaphor_utterance - i;
        // Saturate rather than let `as` wrap a huge age into a negative exponent.
        let decay = config.recency_decay.powi(age.min(i32::MAX as usize) as i32);
        for fc in &state.cf {
            let bonus = if state.cb == Some(fc.entity_id) {
                CB_BONUS
            } else {
                0.0
            };
            let score = fc.effective_salience() * decay + bonus;
            scores
                .entry(fc.entity_id)
                .and_modify(|s| *s = s.max(score))
                .or_insert(score);
        }
    }
    scores
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance-based equality for computed salience/coherence values.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn test_forward_center_salience() {
        let fc = ForwardCenter::new(1, "John", 0.8)
            .with_role(GrammaticalRole::Subject)
            .with_info_status(InformationStatus::Evoked);
        let expected = 0.8 * 1.0 + 0.3;
        assert!((fc.effective_salience() - expected).abs() < 0.001);
    }

    #[test]
    fn test_centering_continue() {
        let utterances = vec![
            vec![
                ForwardCenter::new(1, "John", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "Mary", 0.8).with_role(GrammaticalRole::DirectObject),
            ],
            vec![
                ForwardCenter::new(1, "He", 0.9).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(3, "the book", 0.7).with_role(GrammaticalRole::DirectObject),
            ],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        assert_eq!(states[0].cb, None);
        assert_eq!(states[0].cp, Some(1));
        assert_eq!(states[1].cb, Some(1));
        assert_eq!(states[1].cp, Some(1));
        assert_eq!(states[1].transition, CenteringTransition::Continue);
    }

    #[test]
    fn test_centering_retain() {
        let utterances = vec![
            vec![ForwardCenter::new(1, "John", 1.0).with_role(GrammaticalRole::Subject)],
            vec![
                ForwardCenter::new(2, "Mary", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(1, "him", 0.7).with_role(GrammaticalRole::DirectObject),
            ],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        assert_eq!(states[1].cb, Some(1));
        assert_eq!(states[1].cp, Some(2));
        assert_eq!(states[1].transition, CenteringTransition::Retain);
    }

    // No entity of the first utterance recurs in the second, so there is no Cb
    // and the transition is NULL.
    #[test]
    fn test_centering_no_shared_entity_is_null() {
        let utterances = vec![
            vec![ForwardCenter::new(1, "John", 1.0)],
            vec![ForwardCenter::new(2, "Mary", 1.0).with_role(GrammaticalRole::Subject)],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        assert_eq!(states[1].cb, None);
        assert_eq!(states[1].transition, CenteringTransition::Null);
    }

    #[test]
    fn test_centering_smooth_shift() {
        // U1 establishes Cb = John; in U2 John disappears, Mary becomes both
        // the Cb and the Cp.
        let utterances = vec![
            vec![
                ForwardCenter::new(1, "John", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "Mary", 0.8).with_role(GrammaticalRole::DirectObject),
            ],
            vec![
                ForwardCenter::new(1, "He", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "her", 0.8).with_role(GrammaticalRole::DirectObject),
            ],
            vec![
                ForwardCenter::new(2, "She", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(3, "the book", 0.7).with_role(GrammaticalRole::DirectObject),
            ],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        assert_eq!(states[1].cb, Some(1));
        assert_eq!(states[2].cb, Some(2));
        assert_eq!(states[2].cp, Some(2));
        assert_eq!(states[2].transition, CenteringTransition::SmoothShift);
    }

    #[test]
    fn test_centering_rough_shift() {
        // U1 establishes Cb = John; in U2 the Cb moves to Mary, but the book
        // (subject) is the Cp.
        let utterances = vec![
            vec![
                ForwardCenter::new(1, "John", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "Mary", 0.8).with_role(GrammaticalRole::DirectObject),
            ],
            vec![
                ForwardCenter::new(1, "He", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "her", 0.8).with_role(GrammaticalRole::DirectObject),
            ],
            vec![
                ForwardCenter::new(3, "The book", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "her", 0.8).with_role(GrammaticalRole::DirectObject),
            ],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        assert_eq!(states[2].cb, Some(2));
        assert_eq!(states[2].cp, Some(3));
        assert_eq!(states[2].transition, CenteringTransition::RoughShift);
    }

    #[test]
    fn test_coherence_analysis() {
        // Transitions: NULL, CONTINUE, CONTINUE, NULL.
        let utterances = vec![
            vec![ForwardCenter::new(1, "John", 1.0)],
            vec![ForwardCenter::new(1, "he", 0.9)],
            vec![ForwardCenter::new(1, "him", 0.8)],
            vec![ForwardCenter::new(2, "Mary", 1.0)],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        let analysis = analyze_coherence(&states);
        assert!(approx_eq(analysis.avg_coherence, 0.5));
        assert!(approx_eq(analysis.continuity_ratio, 0.5));
        assert_eq!(analysis.max_continuity_run, 2);
        assert_eq!(analysis.shift_count, 0);
        assert_eq!(analysis.transition_counts.get("CONTINUE"), Some(&2));
        assert_eq!(analysis.transition_counts.get("NULL"), Some(&2));
    }

    #[test]
    fn test_analyze_coherence_empty() {
        let analysis = analyze_coherence(&[]);
        assert!(analysis.transition_counts.is_empty());
        assert_eq!(analysis.avg_coherence, 0.0);
        assert_eq!(analysis.continuity_ratio, 0.0);
        assert_eq!(analysis.shift_count, 0);
        assert_eq!(analysis.max_continuity_run, 0);
    }

    #[test]
    fn test_recency_scores() {
        let config = CenteringConfig {
            use_recency: true,
            recency_decay: 0.5,
            recency_window: 3,
            ..Default::default()
        };
        let utterances = vec![
            vec![ForwardCenter::new(1, "John", 1.0)],
            vec![ForwardCenter::new(2, "Mary", 1.0)],
            vec![
                ForwardCenter::new(1, "he", 0.9),
                ForwardCenter::new(2, "her", 0.8),
            ],
        ];
        let states = track_centers(&utterances, &config);
        let scores = &states[2].recency_scores;
        assert!(scores.contains_key(&1));
        assert!(scores.contains_key(&2));
        assert_eq!(scores.len(), 2);
        // Current: he = 0.9*0.5+0.3, her = 0.8*0.5+0.3.
        // Previous: John (age 2, decay 0.25) = 0.8*0.25; Mary (age 1, decay 0.5) = 0.8*0.5.
        assert!(approx_eq(scores[&1], 0.75 + 0.2));
        assert!(approx_eq(scores[&2], 0.7 + 0.4));
    }

    #[test]
    fn test_transition_coherence_ordering() {
        assert!(
            CenteringTransition::Continue.coherence_score()
                > CenteringTransition::Retain.coherence_score()
        );
        assert!(
            CenteringTransition::Retain.coherence_score()
                > CenteringTransition::SmoothShift.coherence_score()
        );
        assert!(
            CenteringTransition::SmoothShift.coherence_score()
                > CenteringTransition::RoughShift.coherence_score()
        );
    }

    #[test]
    fn test_grammatical_role_ordering() {
        assert!(
            GrammaticalRole::Subject.salience_weight()
                > GrammaticalRole::DirectObject.salience_weight()
        );
        assert!(
            GrammaticalRole::DirectObject.salience_weight()
                > GrammaticalRole::IndirectObject.salience_weight()
        );
        assert!(
            GrammaticalRole::IndirectObject.salience_weight()
                > GrammaticalRole::Oblique.salience_weight()
        );
    }

    #[test]
    fn test_information_status() {
        assert!(InformationStatus::Evoked.is_hearer_old());
        assert!(InformationStatus::Unused.is_hearer_old());
        assert!(InformationStatus::Inferrable.is_hearer_old());
        assert!(!InformationStatus::New.is_hearer_old());
        assert!(
            InformationStatus::Evoked.salience_boost() > InformationStatus::New.salience_boost()
        );
    }

    #[test]
    fn test_score_antecedents() {
        let utterances = vec![
            vec![
                ForwardCenter::new(1, "John", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "Mary", 0.8),
            ],
            vec![ForwardCenter::new(3, "the book", 0.7)],
        ];
        let config = CenteringConfig::default();
        let states = track_centers(&utterances, &config);
        let scores = score_antecedents(1, &states, &config);
        assert!(scores.get(&1).unwrap_or(&0.0) >= scores.get(&2).unwrap_or(&0.0));
    }

    #[test]
    fn test_score_antecedents_without_recency() {
        let utterances = vec![
            vec![
                ForwardCenter::new(1, "John", 1.0).with_role(GrammaticalRole::Subject),
                ForwardCenter::new(2, "Mary", 0.8),
            ],
            vec![ForwardCenter::new(1, "he", 0.9).with_role(GrammaticalRole::Subject)],
        ];
        let config = CenteringConfig {
            use_recency: false,
            ..Default::default()
        };
        let states = track_centers(&utterances, &config);
        let scores = score_antecedents(1, &states, &config);
        // Only utterance 0 is scanned (age 1, decay 0.5); its Cb is None, so no bonus.
        assert_eq!(scores.len(), 2);
        assert!(approx_eq(scores[&1], 1.3 * 0.5));
        assert!(approx_eq(scores[&2], 0.7 * 0.5));
    }
}