use std::collections::HashMap;
use candle_core::{Device, Tensor};
use candle_transformers::models::deepseek2::NonZeroOp;
use kodama::{Method, linkage};
use crate::error::ColbertError;
/// Minimal disjoint-set (union-find) structure keyed by `usize` identifiers.
///
/// Identifiers are registered lazily: an id that has never been seen is
/// treated as the only member of its own set the first time it is looked up.
struct Dsu {
    /// Maps every known id to its parent; a set representative maps to itself.
    parent: HashMap<usize, usize>,
}

impl Dsu {
    /// Creates a structure that knows no identifiers yet.
    fn new() -> Self {
        Dsu {
            parent: HashMap::new(),
        }
    }

    /// Returns the representative of the set containing `i`.
    ///
    /// An unseen `i` is inserted as its own representative. Every node
    /// visited on the way to the representative is re-pointed directly at it,
    /// so later lookups along the same path are shorter.
    fn find(&mut self, i: usize) -> usize {
        // First pass: climb parent links until a self-parented node is reached.
        let mut root = i;
        loop {
            let up = *self.parent.entry(root).or_insert(root);
            if up == root {
                break;
            }
            root = up;
        }

        // Second pass: point every node on the walked path at the root.
        let mut node = i;
        while node != root {
            match self.parent.insert(node, root) {
                Some(previous_parent) => node = previous_parent,
                None => break,
            }
        }
        root
    }

    /// Merges the sets containing `i` and `j`.
    ///
    /// When the two sets differ, the representative of `i`'s set is attached
    /// beneath the representative of `j`'s set.
    fn union(&mut self, i: usize, j: usize) {
        let (root_i, root_j) = (self.find(i), self.find(j));
        if root_i == root_j {
            return;
        }
        self.parent.insert(root_i, root_j);
    }
}
/// Reduces the number of token embeddings per document by hierarchical clustering.
///
/// For every document in the batch the first token embedding is kept as-is
/// ("protected"). The remaining embeddings are clustered with Ward linkage on
/// the distance `1 - (E · Eᵀ)` and each cluster is replaced by the mean of its
/// members. The target number of clusters is
/// `max(1, (n_tokens - 1) / pool_factor)`.
///
/// NOTE(review): the distance is only a cosine distance if the embeddings are
/// L2-normalised — confirm with the caller.
///
/// # Arguments
///
/// * `documents_embeddings` - tensor of rank 3, indexed as
///   `[batch_size, n_tokens, embedding_dim]`.
/// * `pool_factor` - divisor applied to the number of poolable tokens. A value
///   of `0` or `1` disables pooling and returns a clone of the input.
///
/// # Returns
///
/// A tensor on the same device as the input holding, per document, the pooled
/// embeddings. Clustering itself is always carried out on the CPU.
///
/// # Errors
///
/// Returns [`ColbertError::Operation`] when the input is not rank 3 (and
/// pooling is enabled), and propagates any tensor-operation error raised
/// while slicing, multiplying, converting or stacking tensors.
pub fn hierarchical_pooling(
    documents_embeddings: &Tensor,
    pool_factor: usize,
) -> Result<Tensor, ColbertError> {
    if pool_factor <= 1 {
        return Ok(documents_embeddings.clone());
    }
    let rank = documents_embeddings.dims().len();
    if rank != 3 {
        return Err(ColbertError::Operation(format!(
            "Input tensor must have 3 dimensions [batch_size, n_tokens, embedding_dim], but got {} dimensions.",
            rank
        )));
    }

    // The distance matrix is pulled into host memory for the linkage step, so
    // all intermediate work is done on the CPU and moved back at the very end.
    let original_device = documents_embeddings.device().clone();
    let documents_embeddings = if original_device.is_cpu() {
        documents_embeddings.clone()
    } else {
        documents_embeddings.to_device(&Device::Cpu)?
    };
    let device = documents_embeddings.device();

    let batch_size = documents_embeddings.dim(0)?;
    let mut all_pooled_embeddings: Vec<Tensor> = Vec::with_capacity(batch_size);

    for i in 0..batch_size {
        let document_embeddings = documents_embeddings.narrow(0, i, 1)?.squeeze(0)?;
        let n_tokens = document_embeddings.dim(0)?;
        if n_tokens <= 1 {
            // Nothing besides (at most) the protected token: pass through.
            all_pooled_embeddings.push(document_embeddings);
            continue;
        }

        let protected_embeddings = document_embeddings.narrow(0, 0, 1)?;
        let embeddings_to_pool = document_embeddings.narrow(0, 1, n_tokens - 1)?;
        let num_embeddings_to_pool = embeddings_to_pool.dim(0)?;

        // Decide whether any merging is needed *before* doing the quadratic
        // distance computation. This also covers the single-poolable-token
        // case, because `num_clusters` is at least 1.
        let num_clusters = (num_embeddings_to_pool / pool_factor).max(1);
        if num_clusters >= num_embeddings_to_pool {
            let final_embeddings =
                Tensor::cat(&[&protected_embeddings, &embeddings_to_pool], 0)?;
            all_pooled_embeddings.push(final_embeddings);
            continue;
        }

        // Pairwise distances between the poolable embeddings.
        let cosine_similarities = embeddings_to_pool.matmul(&embeddings_to_pool.t()?)?;
        let distance_matrix = (1.0 - cosine_similarities)?.to_vec2::<f32>()?;

        // `linkage` expects the strict upper triangle in row-major order.
        let mut condensed_distances: Vec<f64> =
            Vec::with_capacity(num_embeddings_to_pool * (num_embeddings_to_pool - 1) / 2);
        for (row, distances) in distance_matrix.iter().enumerate() {
            condensed_distances.extend(distances[row + 1..].iter().map(|&dist| f64::from(dist)));
        }

        let dend = linkage(
            &mut condensed_distances,
            num_embeddings_to_pool,
            Method::Ward,
        );

        // Replay the first `num_merges` merges to obtain `num_clusters` groups.
        //
        // In the dendrogram, labels `0..N` denote the original observations
        // and the cluster created by step `s` carries the label `N + s`;
        // later steps refer to merged clusters through that label. Both
        // operands are therefore linked to the new label, otherwise a later
        // merge involving `N + s` would not reach the tokens it contains.
        let num_merges = num_embeddings_to_pool - num_clusters;
        let mut dsu = Dsu::new();
        for (step_index, step) in dend.steps().iter().take(num_merges).enumerate() {
            let merged_cluster = num_embeddings_to_pool + step_index;
            dsu.union(step.cluster1, merged_cluster);
            dsu.union(step.cluster2, merged_cluster);
        }

        // Turn set representatives into dense labels `0..k`, numbered in
        // order of first appearance.
        let mut root_to_label: HashMap<usize, u32> = HashMap::new();
        let mut labels: Vec<u32> = Vec::with_capacity(num_embeddings_to_pool);
        for token in 0..num_embeddings_to_pool {
            let root = dsu.find(token);
            // The label count is bounded by the token count, which already
            // had to fit an N x N matrix in memory, so the cast cannot truncate.
            let next_label = root_to_label.len() as u32;
            labels.push(*root_to_label.entry(root).or_insert(next_label));
        }
        let labels_tensor = Tensor::new(labels.as_slice(), device)?;

        // Iterate over the labels that were actually assigned so that no
        // token can be left out of the result. Every label has at least one
        // member by construction.
        let assigned_labels = root_to_label.len() as u32;
        let mut final_embeddings_list: Vec<Tensor> =
            Vec::with_capacity(root_to_label.len() + 1);
        for cluster_id in 0..assigned_labels {
            let mask = labels_tensor.eq(cluster_id)?;
            let cluster_indices = mask.nonzero()?.squeeze(1)?;
            let cluster_embeddings = embeddings_to_pool.index_select(&cluster_indices, 0)?;
            final_embeddings_list.push(cluster_embeddings.mean(0)?);
        }

        // The protected token is appended after the pooled vectors (the
        // pass-through branches above keep it first); this ordering is kept
        // unchanged from the previous implementation.
        final_embeddings_list.push(protected_embeddings.get(0)?);
        all_pooled_embeddings.push(Tensor::stack(&final_embeddings_list, 0)?);
    }

    let stacked = Tensor::stack(&all_pooled_embeddings, 0)?;
    if original_device.is_cpu() {
        Ok(stacked)
    } else {
        Ok(stacked.to_device(&original_device)?)
    }
}