use crate::Slab;
/// Pools per-token embeddings produced by a single document-level encoding
/// pass into one embedding per chunk ("late chunking" style pooling).
///
/// `Copy`/`PartialEq`/`Eq` are derived eagerly: the struct is a single
/// `usize`, so copies are free and equality is well-defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LateChunkingPooler {
    // Embedding dimensionality; used to size zero-vector fallbacks when
    // there are no tokens to pool.
    dim: usize,
}
impl LateChunkingPooler {
    /// Creates a pooler that emits `dim`-dimensional zero vectors for
    /// degenerate inputs (no tokens / empty chunk windows).
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Pools per-token embeddings into one embedding per chunk by mapping
    /// each chunk's character span `[start, end)` onto the token axis
    /// proportionally (`char_pos / doc_len * n_tokens`).
    ///
    /// A chunk that projects onto an empty token range falls back to
    /// mean-pooling the entire document. When there are no tokens, no
    /// chunks, or `doc_len == 0`, every chunk gets a zero vector of
    /// `self.dim`.
    pub fn pool(
        &self,
        token_embeddings: &[Vec<f32>],
        chunks: &[Slab],
        doc_len: usize,
    ) -> Vec<Vec<f32>> {
        if token_embeddings.is_empty() || chunks.is_empty() || doc_len == 0 {
            return vec![vec![0.0; self.dim]; chunks.len()];
        }
        let n_tokens = token_embeddings.len();
        chunks
            .iter()
            .map(|chunk| {
                // Proportional char -> token projection. Truncating `as usize`
                // floors both ends; the `.min` clamp plus the emptiness check
                // below keep the slice range within [0, n_tokens].
                let token_start =
                    (chunk.start as f64 / doc_len as f64 * n_tokens as f64) as usize;
                let token_end = ((chunk.end as f64 / doc_len as f64 * n_tokens as f64)
                    as usize)
                    .min(n_tokens);
                if token_end <= token_start {
                    // Degenerate span: use the whole document as context.
                    return self.mean_pool(token_embeddings);
                }
                self.mean_pool(&token_embeddings[token_start..token_end])
            })
            .collect()
    }

    /// Pools per-token embeddings into one embedding per chunk using exact
    /// per-token character offsets (`(start, end)` for each token, parallel
    /// to `token_embeddings`).
    ///
    /// A token is assigned to a chunk when their character ranges overlap:
    /// `token.start < chunk.end && token.end > chunk.start`. Chunks that
    /// overlap no token fall back to mean-pooling the whole document.
    pub fn pool_with_offsets(
        &self,
        token_embeddings: &[Vec<f32>],
        token_offsets: &[(usize, usize)],
        chunks: &[Slab],
    ) -> Vec<Vec<f32>> {
        if token_embeddings.is_empty() || chunks.is_empty() {
            return vec![vec![0.0; self.dim]; chunks.len()];
        }
        chunks
            .iter()
            .map(|chunk| {
                // Half-open interval overlap test on character ranges.
                let token_indices: Vec<usize> = token_offsets
                    .iter()
                    .enumerate()
                    .filter(|(_, (start, end))| *start < chunk.end && *end > chunk.start)
                    .map(|(i, _)| i)
                    .collect();
                if token_indices.is_empty() {
                    return self.mean_pool(token_embeddings);
                }
                // `get` tolerates offsets lists longer than the embedding
                // list; out-of-range indices are silently skipped.
                let selected: Vec<&[f32]> = token_indices
                    .iter()
                    .filter_map(|&i| token_embeddings.get(i).map(Vec::as_slice))
                    .collect();
                self.mean_pool_refs(&selected)
            })
            .collect()
    }

    /// Mean-pools owned embedding rows. Thin adapter over
    /// [`Self::mean_pool_refs`] so the sum/average/normalize logic lives in
    /// exactly one place (the two bodies were previously duplicated).
    fn mean_pool(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
        let refs: Vec<&[f32]> = embeddings.iter().map(Vec::as_slice).collect();
        self.mean_pool_refs(&refs)
    }

    /// Averages the given rows element-wise, then L2-normalizes the result.
    ///
    /// Returns a zero vector of `self.dim` for empty input. The `> 1e-9`
    /// norm guard leaves an (almost) all-zero mean untouched rather than
    /// dividing by ~0.
    fn mean_pool_refs(&self, embeddings: &[&[f32]]) -> Vec<f32> {
        if embeddings.is_empty() {
            return vec![0.0; self.dim];
        }
        // Output width follows the first row (not `self.dim`) so callers get
        // whatever dimensionality the encoder actually produced.
        let dim = embeddings[0].len();
        let mut result = vec![0.0f32; dim];
        for emb in embeddings {
            // zip truncates to `dim`, so a longer row can no longer cause an
            // out-of-bounds panic (the old `result[i] += v` loop did).
            for (slot, &v) in result.iter_mut().zip(emb.iter()) {
                *slot += v;
            }
        }
        let count = embeddings.len() as f32;
        for v in &mut result {
            *v /= count;
        }
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-9 {
            for v in &mut result {
                *v /= norm;
            }
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // One-hot-ish token embeddings pooled over two equal halves of a
    // 20-character document.
    #[test]
    fn test_late_chunking_pooler_basic() {
        let pooler = LateChunkingPooler::new(4);
        let tokens: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 1.0],
        ];
        let slabs = vec![
            Slab::new("first chunk", 0, 10, 0),
            Slab::new("second chunk", 10, 20, 1),
        ];
        let pooled = pooler.pool(&tokens, &slabs, 20);
        assert_eq!(pooled.len(), 2);
        for emb in &pooled {
            assert_eq!(emb.len(), 4);
        }
        // Pooled vectors must come out L2-normalized.
        let norm0 = pooled[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm0 - 1.0).abs() < 0.01);
    }

    // Exact character offsets: tokens are assigned to chunks by range
    // overlap rather than proportional projection.
    #[test]
    fn test_pool_with_exact_offsets() {
        let pooler = LateChunkingPooler::new(3);
        let tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![0.0, 1.0, 1.0],
        ];
        let offsets = vec![(0, 5), (5, 6), (6, 11), (11, 12), (12, 16)];
        let slabs = vec![
            Slab::new("Hello world.", 0, 12, 0),
            Slab::new(" Bye", 12, 16, 1),
        ];
        let pooled = pooler.pool_with_offsets(&tokens, &offsets, &slabs);
        assert_eq!(pooled.len(), 2);
    }

    // Degenerate inputs must yield one zero vector per chunk (and an empty
    // result when there are no chunks at all).
    #[test]
    fn test_empty_inputs() {
        let pooler = LateChunkingPooler::new(4);
        assert!(pooler.pool(&[], &[], 0).is_empty());

        let slabs = vec![Slab::new("test", 0, 4, 0)];
        let pooled = pooler.pool(&[], &slabs, 4);
        assert_eq!(pooled.len(), 1);
        assert_eq!(pooled[0].len(), 4);
    }
}