use crate::eval::cluster_encoder::{
ClusterEmbedding, ClusterEncoder, ClusterMention, LocalCluster, MergeScorer,
};
use std::collections::HashMap;
#[cfg(feature = "candle")]
use {
anno::backends::encoder_candle::CandleTextEncoder,
candle_core::{DType, Device, Module, Tensor, D},
candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder},
};
/// Hyper-parameters for the neural cluster encoder and its attention
/// pooling layer.
#[derive(Debug, Clone)]
pub struct NeuralClusterConfig {
    /// Embedding width produced by the text encoder.
    pub hidden_dim: usize,
    /// Number of attention heads in the pooling layer.
    pub num_heads: usize,
    /// Dropout probability (declared for completeness; see pooling layer).
    pub dropout: f32,
    /// Whether to pool via a CLS token instead of mean pooling.
    pub use_cls_pooling: bool,
    /// Cap on how many mentions of a cluster are encoded.
    pub max_mentions: usize,
    /// Minimum score for two clusters to be merged.
    pub merge_threshold: f32,
}

impl Default for NeuralClusterConfig {
    /// BERT-base-like defaults: 768-dim hidden states over 12 heads.
    fn default() -> Self {
        NeuralClusterConfig {
            merge_threshold: 0.5,
            max_mentions: 50,
            use_cls_pooling: false,
            dropout: 0.1,
            num_heads: 12,
            hidden_dim: 768,
        }
    }
}
/// Cluster encoder backed by a candle text encoder plus a learned
/// attention-pooling layer over per-mention embeddings.
#[cfg(feature = "candle")]
pub struct CandleClusterEncoder<E: CandleTextEncoder> {
    // Underlying text encoder used to embed each mention string.
    encoder: E,
    // Self-attention + mean pooling over the (mentions x hidden_dim) matrix.
    pooling_layer: ClusterPoolingLayer,
    config: NeuralClusterConfig,
    // Tensor device: CUDA 0 when available, otherwise CPU (set in `new`).
    device: Device,
}
#[cfg(feature = "candle")]
impl<E: CandleTextEncoder> CandleClusterEncoder<E> {
    /// Builds an encoder, selecting CUDA device 0 when available and
    /// falling back to CPU otherwise. Fails if the pooling layer cannot
    /// be constructed.
    pub fn new(encoder: E, config: NeuralClusterConfig) -> crate::Result<Self> {
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
        let pooling_layer = ClusterPoolingLayer::new(&config, &device)?;
        Ok(Self {
            encoder,
            pooling_layer,
            config,
            device,
        })
    }

    /// Encodes one local cluster into a single `hidden_dim`-length vector:
    /// each mention's token embeddings are mean-pooled, the per-mention
    /// vectors are stacked into a (num_mentions x hidden_dim) tensor, and
    /// the attention pooling layer reduces that matrix to one vector.
    ///
    /// Returns an all-zero vector for an empty cluster; only the first
    /// `config.max_mentions` mentions are considered.
    fn encode_cluster_impl(&self, cluster: &LocalCluster) -> crate::Result<Vec<f32>> {
        if cluster.mentions.is_empty() {
            return Ok(vec![0.0; self.config.hidden_dim]);
        }
        let mentions: Vec<&ClusterMention> = cluster
            .mentions
            .iter()
            .take(self.config.max_mentions)
            .collect();
        let mut mention_embeddings = Vec::new();
        for mention in &mentions {
            // NOTE(review): `encode` is assumed to return a flat
            // (seq_len * hidden_dim) buffer whose hidden size equals
            // `config.hidden_dim`; a mismatch would panic on the index
            // below — confirm against the CandleTextEncoder contract.
            let (embeddings, seq_len) = self.encoder.encode(&mention.text)?;
            let hidden_dim = self.config.hidden_dim;
            let mut pooled = vec![0.0f32; hidden_dim];
            if seq_len > 0 {
                // Mean-pool token embeddings over the sequence dimension.
                for i in 0..seq_len {
                    for j in 0..hidden_dim {
                        pooled[j] += embeddings[i * hidden_dim + j];
                    }
                }
                for p in &mut pooled {
                    *p /= seq_len as f32;
                }
            }
            // Zero-length sequences contribute an all-zero row.
            mention_embeddings.push(pooled);
        }
        let num_mentions = mention_embeddings.len();
        // Flatten row-major into one buffer for tensor construction.
        let flat: Vec<f32> = mention_embeddings.into_iter().flatten().collect();
        let tensor = Tensor::from_vec(flat, (num_mentions, self.config.hidden_dim), &self.device)
            .map_err(|e: candle_core::Error| crate::Error::Inference(e.to_string()))?;
        // Attention-pool down to a single hidden_dim vector.
        let pooled = self.pooling_layer.forward(&tensor)?;
        let result = pooled
            .to_vec1::<f32>()
            .map_err(|e: candle_core::Error| crate::Error::Inference(e.to_string()))?;
        Ok(result)
    }
}
#[cfg(feature = "candle")]
impl<E: CandleTextEncoder> ClusterEncoder for CandleClusterEncoder<E> {
    /// Produces a `ClusterEmbedding` for `cluster`. Any encoding failure
    /// degrades to an all-zero embedding because this trait method is
    /// infallible. `_hidden_states` is ignored: mentions are re-encoded
    /// from their surface text instead.
    fn encode_cluster(
        &self,
        cluster: &LocalCluster,
        _hidden_states: Option<&[Vec<f32>]>,
    ) -> ClusterEmbedding {
        let embedding = match self.encode_cluster_impl(cluster) {
            Ok(vector) => vector,
            Err(_) => vec![0.0; self.config.hidden_dim],
        };
        ClusterEmbedding {
            cluster_id: cluster.id,
            context_id: cluster.context_id,
            embedding,
            mention_count: cluster.mentions.len(),
        }
    }

    /// Width of the vectors returned by `encode_cluster`.
    fn embedding_dim(&self) -> usize {
        self.config.hidden_dim
    }
}
/// Single-head-split self-attention block with residual connection and
/// layer norm, used to pool a (mentions x hidden_dim) matrix into one
/// vector (see `forward`).
#[cfg(feature = "candle")]
struct ClusterPoolingLayer {
    // Query / key / value / output projections, each hidden_dim -> hidden_dim.
    wq: Linear,
    wk: Linear,
    wv: Linear,
    wo: Linear,
    // Post-residual layer normalization.
    ln: LayerNorm,
    num_heads: usize,
    // hidden_dim / num_heads (see `new`).
    head_dim: usize,
}
#[cfg(feature = "candle")]
impl ClusterPoolingLayer {
    /// Builds freshly-initialized attention weights on `device`.
    ///
    /// NOTE(review): weights come from a new `VarMap` (random init); no
    /// checkpoint is loaded here — confirm that is intended.
    ///
    /// # Errors
    /// Returns `crate::Error::Inference` if `num_heads` is zero, if
    /// `hidden_dim` is not divisible by `num_heads`, or if any layer
    /// fails to construct.
    fn new(config: &NeuralClusterConfig, device: &Device) -> crate::Result<Self> {
        let hidden_dim = config.hidden_dim;
        let num_heads = config.num_heads;
        // Fail fast with a clear message instead of panicking on a
        // division by zero, or silently truncating `head_dim` and later
        // failing `forward`'s reshape with an opaque shape error.
        if num_heads == 0 || hidden_dim % num_heads != 0 {
            return Err(crate::Error::Inference(format!(
                "hidden_dim ({}) must be a positive multiple of num_heads ({})",
                hidden_dim, num_heads
            )));
        }
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        let head_dim = hidden_dim / num_heads;
        let wq = linear(hidden_dim, hidden_dim, vb.pp("wq"))
            .map_err(|e| crate::Error::Inference(format!("Linear wq: {}", e)))?;
        let wk = linear(hidden_dim, hidden_dim, vb.pp("wk"))
            .map_err(|e| crate::Error::Inference(format!("Linear wk: {}", e)))?;
        let wv = linear(hidden_dim, hidden_dim, vb.pp("wv"))
            .map_err(|e| crate::Error::Inference(format!("Linear wv: {}", e)))?;
        let wo = linear(hidden_dim, hidden_dim, vb.pp("wo"))
            .map_err(|e| crate::Error::Inference(format!("Linear wo: {}", e)))?;
        let ln = layer_norm(hidden_dim, 1e-5, vb.pp("ln"))
            .map_err(|e| crate::Error::Inference(format!("LayerNorm: {}", e)))?;
        Ok(Self {
            wq,
            wk,
            wv,
            wo,
            ln,
            num_heads,
            head_dim,
        })
    }

    /// Multi-head self-attention over `x` (shape: seq_len x hidden_dim),
    /// followed by a residual add, layer norm, and mean pooling over the
    /// sequence axis. Returns a 1-D tensor of length hidden_dim.
    fn forward(&self, x: &Tensor) -> crate::Result<Tensor> {
        let (seq_len, hidden_dim) = x
            .dims2()
            .map_err(|e| crate::Error::Inference(format!("Dims: {}", e)))?;
        // Project into query / key / value spaces.
        let q = self
            .wq
            .forward(x)
            .map_err(|e| crate::Error::Inference(format!("Q: {}", e)))?;
        let k = self
            .wk
            .forward(x)
            .map_err(|e| crate::Error::Inference(format!("K: {}", e)))?;
        let v = self
            .wv
            .forward(x)
            .map_err(|e| crate::Error::Inference(format!("V: {}", e)))?;
        // Split hidden_dim into heads: (seq, hidden) -> (heads, seq, head_dim).
        let q = q
            .reshape((seq_len, self.num_heads, self.head_dim))
            .map_err(|e| crate::Error::Inference(format!("Q reshape: {}", e)))?
            .transpose(0, 1)
            .map_err(|e| crate::Error::Inference(format!("Q transpose: {}", e)))?;
        let k = k
            .reshape((seq_len, self.num_heads, self.head_dim))
            .map_err(|e| crate::Error::Inference(format!("K reshape: {}", e)))?
            .transpose(0, 1)
            .map_err(|e| crate::Error::Inference(format!("K transpose: {}", e)))?;
        let v = v
            .reshape((seq_len, self.num_heads, self.head_dim))
            .map_err(|e| crate::Error::Inference(format!("V reshape: {}", e)))?
            .transpose(0, 1)
            .map_err(|e| crate::Error::Inference(format!("V transpose: {}", e)))?;
        // Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V.
        let scale = (self.head_dim as f64).sqrt();
        let scores = q
            .matmul(
                &k.transpose(1, 2)
                    .map_err(|e| crate::Error::Inference(format!("K^T: {}", e)))?,
            )
            .map_err(|e| crate::Error::Inference(format!("QK^T: {}", e)))?
            .affine(1.0 / scale, 0.0)
            .map_err(|e| crate::Error::Inference(format!("Scale: {}", e)))?;
        let attn = candle_nn::ops::softmax(&scores, D::Minus1)
            .map_err(|e| crate::Error::Inference(format!("Softmax: {}", e)))?;
        let context = attn
            .matmul(&v)
            .map_err(|e| crate::Error::Inference(format!("Attn*V: {}", e)))?;
        // Recombine heads: (heads, seq, head_dim) -> (seq, hidden_dim).
        let context = context
            .transpose(0, 1)
            .map_err(|e| crate::Error::Inference(format!("Context transpose: {}", e)))?
            .reshape((seq_len, hidden_dim))
            .map_err(|e| crate::Error::Inference(format!("Context reshape: {}", e)))?;
        let out = self
            .wo
            .forward(&context)
            .map_err(|e| crate::Error::Inference(format!("Wo: {}", e)))?;
        // Residual connection then layer norm, as in a standard
        // post-norm transformer block.
        let out = (x + &out).map_err(|e| crate::Error::Inference(format!("Residual: {}", e)))?;
        let out = self
            .ln
            .forward(&out)
            .map_err(|e| crate::Error::Inference(format!("LayerNorm: {}", e)))?;
        // Mean-pool over the sequence axis to a single vector.
        out.mean(0)
            .map_err(|e| crate::Error::Inference(format!("Mean pool: {}", e)))
    }
}
/// Learned merge scorer: a two-layer MLP over the concatenation of two
/// cluster embeddings, producing a merge probability via sigmoid.
#[cfg(feature = "candle")]
pub struct NeuralMergeScorer {
    // Projects the concatenated pair (2 * hidden_dim) back to hidden_dim.
    bilinear: Linear,
    // Final hidden_dim -> 1 logit layer.
    classifier: Linear,
    // Device tensors are created on (CUDA 0 when available, else CPU).
    device: Device,
}
#[cfg(feature = "candle")]
impl NeuralMergeScorer {
    /// Builds the scorer for embeddings of width `hidden_dim`, on CUDA 0
    /// when available and on CPU otherwise.
    ///
    /// NOTE(review): layers are created from a fresh `VarMap` (random
    /// initialization) and no trained weights are loaded here, so scores
    /// are untrained — confirm whether a checkpoint load is expected
    /// elsewhere.
    pub fn new(hidden_dim: usize) -> crate::Result<Self> {
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let bilinear = linear(hidden_dim * 2, hidden_dim, vb.pp("bilinear"))
            .map_err(|e| crate::Error::Inference(format!("Bilinear: {}", e)))?;
        let classifier = linear(hidden_dim, 1, vb.pp("classifier"))
            .map_err(|e| crate::Error::Inference(format!("Classifier: {}", e)))?;
        Ok(Self {
            bilinear,
            classifier,
            device,
        })
    }

    /// Scores one embedding pair: concat -> linear -> ReLU -> linear ->
    /// sigmoid, returning a probability in [0, 1].
    ///
    /// NOTE(review): the concatenated length must equal the bilinear
    /// layer's input width (2 * hidden_dim); a mismatch surfaces as an
    /// `Err` from the forward pass, not a panic.
    fn score_impl(&self, emb_a: &[f32], emb_b: &[f32]) -> crate::Result<f32> {
        let concat: Vec<f32> = emb_a.iter().chain(emb_b.iter()).cloned().collect();
        let input = Tensor::from_vec(concat, (1, emb_a.len() + emb_b.len()), &self.device)
            .map_err(|e| crate::Error::Inference(format!("Input tensor: {}", e)))?;
        let hidden = self
            .bilinear
            .forward(&input)
            .map_err(|e| crate::Error::Inference(format!("Bilinear forward: {}", e)))?;
        let hidden = hidden
            .relu()
            .map_err(|e| crate::Error::Inference(format!("ReLU: {}", e)))?;
        let logit = self
            .classifier
            .forward(&hidden)
            .map_err(|e| crate::Error::Inference(format!("Classifier forward: {}", e)))?;
        let prob = candle_nn::ops::sigmoid(&logit)
            .map_err(|e| crate::Error::Inference(format!("Sigmoid: {}", e)))?;
        // Batch size is 1, so the result is the single [0][0] element.
        let score = prob
            .to_vec2::<f32>()
            .map_err(|e| crate::Error::Inference(format!("To vec: {}", e)))?[0][0];
        Ok(score)
    }
}
#[cfg(feature = "candle")]
impl MergeScorer for NeuralMergeScorer {
    /// Merge probability in [0, 1]. Any tensor error maps to 0.0
    /// ("do not merge") because the trait method is infallible.
    fn score(&self, embedding_a: &ClusterEmbedding, embedding_b: &ClusterEmbedding) -> f32 {
        match self.score_impl(&embedding_a.embedding, &embedding_b.embedding) {
            Ok(probability) => probability,
            Err(_) => 0.0,
        }
    }
}
/// Adapter between the cross-document coreference (`cdcr`) document model
/// and the cluster-encoder's `LocalCluster` representation.
pub struct CDCRAdapter;

impl CDCRAdapter {
    /// Converts each document into a list of local clusters, keyed by
    /// document index: one cluster per coref chain, plus a singleton
    /// cluster for every entity whose start offset is not covered by any
    /// chain mention.
    pub fn documents_to_clusters(
        docs: &[crate::eval::cdcr::Document],
    ) -> HashMap<usize, Vec<LocalCluster>> {
        docs.iter()
            .enumerate()
            .map(|(doc_idx, doc)| {
                // One cluster per coref chain.
                let mut clusters: Vec<LocalCluster> = doc
                    .coref_chains
                    .iter()
                    .enumerate()
                    .map(|(chain_idx, chain)| {
                        let mut cluster = LocalCluster::new(chain_idx, doc_idx);
                        for m in &chain.mentions {
                            cluster.add_mention(ClusterMention {
                                start: m.start,
                                end: m.end,
                                text: m.text.clone(),
                                context_id: doc_idx,
                            });
                        }
                        cluster.compute_canonical();
                        cluster
                    })
                    .collect();
                // Start offsets already claimed by some chain mention.
                let chained_starts: std::collections::HashSet<usize> = doc
                    .coref_chains
                    .iter()
                    .flat_map(|c| c.mentions.iter().map(|m| m.start))
                    .collect();
                // Unchained entities become singleton clusters.
                for entity in doc
                    .entities
                    .iter()
                    .filter(|e| !chained_starts.contains(&e.start()))
                {
                    let mut singleton = LocalCluster::new(clusters.len(), doc_idx);
                    singleton.add_mention(ClusterMention {
                        start: entity.start(),
                        end: entity.end(),
                        text: entity.text.clone(),
                        context_id: doc_idx,
                    });
                    singleton.compute_canonical();
                    clusters.push(singleton);
                }
                (doc_idx, clusters)
            })
            .collect()
    }

    /// Converts merged clusters back into `CrossDocCluster`s, resolving
    /// each mention to an entity index within its source document by exact
    /// span match.
    pub fn clusters_to_crossdoc(
        merged: &[crate::eval::cluster_encoder::MergedCluster],
        docs: &[crate::eval::cdcr::Document],
    ) -> Vec<crate::eval::cdcr::CrossDocCluster> {
        let mut result = Vec::with_capacity(merged.len());
        for m in merged {
            let canonical = m.canonical.as_deref().unwrap_or("");
            let mut cluster = crate::eval::cdcr::CrossDocCluster::new(m.id as u64, canonical);
            for mention in &m.mentions {
                if let Some(doc) = docs.get(mention.context_id) {
                    // NOTE(review): when no entity matches the span
                    // exactly, this silently falls back to entity 0 —
                    // confirm that is the intended behavior.
                    let entity_idx = doc
                        .entities
                        .iter()
                        .position(|e| e.start() == mention.start && e.end() == mention.end)
                        .unwrap_or(0);
                    cluster.add_mention(&doc.id, entity_idx);
                }
            }
            result.push(cluster);
        }
        result
    }
}
/// Resolver that merges local coreference clusters produced in separate
/// contexts (documents or sliding windows) into unified cross-context
/// clusters, using a pluggable encoder and merge scorer.
pub struct UnifiedCrossContextResolver<E: ClusterEncoder, S: MergeScorer> {
    // Turns a LocalCluster into a fixed-width embedding.
    encoder: E,
    // Scores embedding pairs; pairs at/above `config.merge_threshold` merge.
    scorer: S,
    config: CrossContextConfig,
}
/// Tuning knobs for cross-context resolution.
#[derive(Debug, Clone)]
pub struct CrossContextConfig {
    /// Sliding-window size, in characters/tokens (caller-defined units).
    pub window_size: usize,
    /// Overlap between consecutive windows.
    pub window_overlap: usize,
    /// Minimum pair score required to merge two clusters.
    pub merge_threshold: f32,
}

impl Default for CrossContextConfig {
    /// Defaults: 4000-unit windows with 256-unit overlap, 0.5 threshold.
    fn default() -> Self {
        CrossContextConfig {
            merge_threshold: 0.5,
            window_overlap: 256,
            window_size: 4000,
        }
    }
}
impl<E: ClusterEncoder, S: MergeScorer> UnifiedCrossContextResolver<E, S> {
    /// Assembles a resolver from its encoder, scorer, and configuration.
    pub fn new(encoder: E, scorer: S, config: CrossContextConfig) -> Self {
        Self {
            encoder,
            scorer,
            config,
        }
    }

    /// Cross-document resolution pipeline: documents -> per-document
    /// local clusters -> merged clusters -> `CrossDocCluster`s.
    pub fn resolve_documents(
        &self,
        docs: &[crate::eval::cdcr::Document],
    ) -> Vec<crate::eval::cdcr::CrossDocCluster> {
        let local_clusters = CDCRAdapter::documents_to_clusters(docs);
        let merged = self.merge_clusters(&local_clusters);
        CDCRAdapter::clusters_to_crossdoc(&merged, docs)
    }

    /// Merges local clusters across contexts via pairwise scoring and
    /// union-find: every cross-context pair scoring at or above
    /// `config.merge_threshold` is unioned, and each resulting component
    /// becomes one `MergedCluster`.
    ///
    /// NOTE(review): merged ids and mention order depend on HashMap
    /// iteration order, so output is not deterministic across runs —
    /// confirm whether callers rely on stable ids.
    fn merge_clusters(
        &self,
        local_clusters: &HashMap<usize, Vec<LocalCluster>>,
    ) -> Vec<crate::eval::cluster_encoder::MergedCluster> {
        // Embed every local cluster from every context.
        let mut embeddings: Vec<ClusterEmbedding> = Vec::new();
        for clusters in local_clusters.values() {
            for cluster in clusters {
                let emb = self.encoder.encode_cluster(cluster, None);
                embeddings.push(emb);
            }
        }
        if embeddings.is_empty() {
            return Vec::new();
        }
        // Score all cross-context pairs; same-context pairs are never
        // merged (clusters within one document/window stay distinct here).
        let mut merge_pairs: Vec<(usize, usize, f32)> = Vec::new();
        for (i, emb_a) in embeddings.iter().enumerate() {
            for (j, emb_b) in embeddings.iter().enumerate().skip(i + 1) {
                if emb_a.context_id == emb_b.context_id {
                    continue;
                }
                let score = self.scorer.score(emb_a, emb_b);
                if score >= self.config.merge_threshold {
                    merge_pairs.push((i, j, score));
                }
            }
        }
        // NOTE(review): since every qualifying pair is unioned below,
        // this sort does not affect the final partition — presumably a
        // leftover from (or a hook for) a greedy best-first merge policy.
        merge_pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
        // Union-find with path compression and union by rank.
        let n = embeddings.len();
        let mut parent: Vec<usize> = (0..n).collect();
        let mut rank: Vec<usize> = vec![0; n];
        // Finds the root of `i`, compressing the path along the way.
        fn find(parent: &mut [usize], i: usize) -> usize {
            if parent[i] != i {
                parent[i] = find(parent, parent[i]);
            }
            parent[i]
        }
        // Unions the components of `x` and `y` by rank.
        fn union(parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
            let px = find(parent, x);
            let py = find(parent, y);
            if px == py {
                return;
            }
            match rank[px].cmp(&rank[py]) {
                std::cmp::Ordering::Less => parent[px] = py,
                std::cmp::Ordering::Greater => parent[py] = px,
                std::cmp::Ordering::Equal => {
                    parent[py] = px;
                    rank[px] += 1;
                }
            }
        }
        for (i, j, _) in merge_pairs {
            union(&mut parent, &mut rank, i, j);
        }
        // Group embedding indices by their union-find root.
        let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
        for i in 0..n {
            let root = find(&mut parent, i);
            cluster_map.entry(root).or_default().push(i);
        }
        // Materialize one MergedCluster per component, pulling mentions
        // back from the source local clusters. The canonical string is
        // taken from the first member that has one.
        cluster_map
            .into_iter()
            .enumerate()
            .map(|(merged_id, (_root, indices))| {
                let mut merged = crate::eval::cluster_encoder::MergedCluster {
                    id: merged_id,
                    source_clusters: Vec::new(),
                    mentions: Vec::new(),
                    canonical: None,
                };
                for idx in indices {
                    let emb = &embeddings[idx];
                    merged
                        .source_clusters
                        .push((emb.context_id, emb.cluster_id));
                    if let Some(clusters) = local_clusters.get(&emb.context_id) {
                        if let Some(cluster) = clusters.iter().find(|c| c.id == emb.cluster_id) {
                            merged.mentions.extend(cluster.mentions.clone());
                            if merged.canonical.is_none() {
                                merged.canonical = cluster.canonical.clone();
                            }
                        }
                    }
                }
                merged
            })
            .collect()
    }
}
/// Adapter between per-window coreference output and the cluster
/// encoder's `LocalCluster` representation, for long-document resolution.
pub struct IncrementalCorefAdapter;

impl IncrementalCorefAdapter {
    /// Converts each window's coref chains into local clusters, keyed by
    /// window index (one cluster per chain).
    pub fn windows_to_clusters(windows: &[WindowOutput]) -> HashMap<usize, Vec<LocalCluster>> {
        windows
            .iter()
            .enumerate()
            .map(|(window_idx, output)| {
                let clusters: Vec<LocalCluster> = output
                    .chains
                    .iter()
                    .enumerate()
                    .map(|(chain_idx, chain)| {
                        let mut cluster = LocalCluster::new(chain_idx, window_idx);
                        for m in &chain.mentions {
                            cluster.add_mention(ClusterMention {
                                start: m.start,
                                end: m.end,
                                text: m.text.clone(),
                                context_id: window_idx,
                            });
                        }
                        cluster.compute_canonical();
                        cluster
                    })
                    .collect();
                (window_idx, clusters)
            })
            .collect()
    }

    /// Flattens merged clusters back into plain `CorefChain`s.
    ///
    /// NOTE(review): original mention types are not carried through the
    /// merge, so every mention is emitted as `MentionType::Proper` —
    /// confirm downstream consumers tolerate this.
    pub fn clusters_to_chains(
        merged: &[crate::eval::cluster_encoder::MergedCluster],
    ) -> Vec<crate::eval::coref::CorefChain> {
        use crate::eval::coref::{CorefChain, Mention, MentionType};
        let mut chains = Vec::with_capacity(merged.len());
        for m in merged {
            let mentions: Vec<Mention> = m
                .mentions
                .iter()
                .map(|cm| Mention {
                    text: cm.text.clone(),
                    start: cm.start,
                    end: cm.end,
                    head_start: None,
                    head_end: None,
                    entity_type: None,
                    mention_type: Some(MentionType::Proper),
                })
                .collect();
            chains.push(CorefChain::new(mentions));
        }
        chains
    }
}
/// Coreference output for one sliding window of a long document.
#[derive(Debug, Clone)]
pub struct WindowOutput {
    // Position of this window in the window sequence.
    pub window_idx: usize,
    // Character/token offset where the window starts in the full document.
    pub start_offset: usize,
    // Offset where the window ends.
    pub end_offset: usize,
    // Coref chains detected within this window.
    pub chains: Vec<crate::eval::coref::CorefChain>,
}
impl WindowOutput {
pub fn new(
window_idx: usize,
start_offset: usize,
end_offset: usize,
chains: Vec<crate::eval::coref::CorefChain>,
) -> Self {
Self {
window_idx,
start_offset,
end_offset,
chains,
}
}
}
impl<E: ClusterEncoder, S: MergeScorer> UnifiedCrossContextResolver<E, S> {
    /// Long-document pipeline: per-window chains -> local clusters ->
    /// cross-window merged clusters -> document-level `CorefChain`s.
    pub fn resolve_long_document_windows(
        &self,
        windows: &[WindowOutput],
    ) -> Vec<crate::eval::coref::CorefChain> {
        let merged =
            self.merge_clusters(&IncrementalCorefAdapter::windows_to_clusters(windows));
        IncrementalCorefAdapter::clusters_to_chains(&merged)
    }

    /// Read-only access to the resolver's configuration.
    pub fn config(&self) -> &CrossContextConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::cdcr::Document;
    use crate::eval::cluster_encoder::{CosineMergeScorer, HeuristicClusterEncoder};
    use crate::eval::coref::{CorefChain, Mention};
    use anno::{Entity, EntityType};
    use anno_core::MentionType;

    /// Builds a span-only `Mention`. Previously this helper was
    /// duplicated verbatim inside three separate test functions.
    fn new_mention(text: &str, start: usize, end: usize, mt: MentionType) -> Mention {
        Mention {
            text: text.to_string(),
            start,
            end,
            head_start: None,
            head_end: None,
            entity_type: None,
            mention_type: Some(mt),
        }
    }

    #[test]
    fn test_cdcr_adapter_empty() {
        // No documents -> no per-document cluster lists.
        let docs: Vec<Document> = vec![];
        let clusters = CDCRAdapter::documents_to_clusters(&docs);
        assert!(clusters.is_empty());
    }

    #[test]
    fn test_cdcr_adapter_single_doc() {
        // A document with one unchained entity yields a singleton cluster.
        let doc = Document::new("doc1", "Obama visited France.").with_entities(vec![Entity::new(
            "Obama",
            EntityType::Person,
            0,
            5,
            0.9,
        )]);
        let clusters = CDCRAdapter::documents_to_clusters(&[doc]);
        assert_eq!(clusters.len(), 1);
        assert!(!clusters[&0].is_empty());
    }

    #[test]
    fn test_unified_resolver() {
        // End-to-end cross-document resolution with heuristic components.
        let encoder = HeuristicClusterEncoder::new(64);
        let scorer = CosineMergeScorer::new();
        let config = CrossContextConfig::default();
        let resolver = UnifiedCrossContextResolver::new(encoder, scorer, config);
        let docs = vec![
            Document::new("doc1", "Barack Obama gave a speech.").with_entities(vec![
                Entity::new("Barack Obama", EntityType::Person, 0, 12, 0.9),
            ]),
            Document::new("doc2", "Obama met with leaders.").with_entities(vec![Entity::new(
                "Obama",
                EntityType::Person,
                0,
                5,
                0.9,
            )]),
        ];
        let result = resolver.resolve_documents(&docs);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_neural_config_default() {
        let config = NeuralClusterConfig::default();
        assert_eq!(config.hidden_dim, 768);
        assert_eq!(config.num_heads, 12);
        assert!(!config.use_cls_pooling);
    }

    #[test]
    fn test_incremental_adapter_empty() {
        let windows: Vec<WindowOutput> = vec![];
        let clusters = IncrementalCorefAdapter::windows_to_clusters(&windows);
        assert!(clusters.is_empty());
    }

    #[test]
    fn test_incremental_adapter_single_window() {
        // One window, one chain with two mentions -> one cluster.
        let chain = CorefChain::new(vec![
            new_mention("Obama", 0, 5, MentionType::Proper),
            new_mention("he", 20, 22, MentionType::Pronominal),
        ]);
        let window = WindowOutput::new(0, 0, 100, vec![chain]);
        let clusters = IncrementalCorefAdapter::windows_to_clusters(&[window]);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[&0].len(), 1);
        assert_eq!(clusters[&0][0].mentions.len(), 2);
    }

    #[test]
    fn test_incremental_adapter_multi_window() {
        // Two windows each contribute their own cluster list.
        let window1 = WindowOutput::new(
            0,
            0,
            100,
            vec![CorefChain::new(vec![new_mention(
                "Obama",
                0,
                5,
                MentionType::Proper,
            )])],
        );
        let window2 = WindowOutput::new(
            1,
            80,
            180,
            vec![CorefChain::new(vec![new_mention(
                "the President",
                90,
                103,
                MentionType::Nominal,
            )])],
        );
        let clusters = IncrementalCorefAdapter::windows_to_clusters(&[window1, window2]);
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[&0].len(), 1);
        assert_eq!(clusters[&1].len(), 1);
    }

    #[test]
    fn test_long_document_resolution() {
        // Cross-window merge with a lowered threshold should still emit
        // at least one chain.
        let encoder = HeuristicClusterEncoder::new(64);
        let scorer = CosineMergeScorer::new();
        let config = CrossContextConfig {
            merge_threshold: 0.3,
            ..Default::default()
        };
        let resolver = UnifiedCrossContextResolver::new(encoder, scorer, config);
        let window1 = WindowOutput::new(
            0,
            0,
            1000,
            vec![CorefChain::new(vec![new_mention(
                "Barack Obama",
                0,
                12,
                MentionType::Proper,
            )])],
        );
        let window2 = WindowOutput::new(
            1,
            800,
            1800,
            vec![CorefChain::new(vec![new_mention(
                "Obama",
                900,
                905,
                MentionType::Proper,
            )])],
        );
        let chains = resolver.resolve_long_document_windows(&[window1, window2]);
        assert!(!chains.is_empty());
    }
}