use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
/// Caller-facing identifier of a document.
pub type DocId = String;
/// Identifier of a single stored chunk vector; `MultiVectorMapping` hands these out sequentially.
pub type InternalId = u64;
/// Zero-based position of a chunk within its document.
pub type ChunkIndex = u32;
/// Strategy used to collapse the scores of a document's matched chunks into a
/// single document-level score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AggregationMethod {
    /// Highest chunk score wins (the default).
    #[default]
    Max,
    /// Arithmetic mean of all matched chunk scores.
    Mean,
    /// Score of the matched chunk with the lowest chunk index.
    First,
    /// Score of the matched chunk with the highest chunk index.
    Last,
    /// Sum of all matched chunk scores.
    Sum,
}

impl AggregationMethod {
    /// Parses a method name, ignoring ASCII case and surrounding whitespace.
    ///
    /// Accepted names are `max`, `mean` (aliases `avg`, `average`), `first`,
    /// `last` and `sum`.
    ///
    /// Returns `None` for anything else, including the empty string.
    pub fn from_str(s: &str) -> Option<Self> {
        // `Self` is not nameable inside a nested item, hence the explicit type.
        const NAMES: [(&str, AggregationMethod); 7] = [
            ("max", AggregationMethod::Max),
            ("mean", AggregationMethod::Mean),
            ("avg", AggregationMethod::Mean),
            ("average", AggregationMethod::Mean),
            ("first", AggregationMethod::First),
            ("last", AggregationMethod::Last),
            ("sum", AggregationMethod::Sum),
        ];

        // Compare in place instead of allocating a lowercased copy of the input.
        let wanted = s.trim();
        NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(wanted))
            .map(|&(_, method)| method)
    }
}
/// Document-level search result produced by aggregating the scores of the
/// document's matched chunks.
#[derive(Debug, Clone)]
pub struct DocumentScore {
    /// Identifier of the document the chunks belong to.
    pub doc_id: DocId,
    /// Aggregated score; `0.0` when no chunk matched.
    pub score: f32,
    /// Chunk that supplied `score`. `None` for `Mean` and `Sum`, which have no
    /// single contributing chunk, and when no chunk matched.
    pub best_chunk: Option<ChunkIndex>,
    /// Number of chunk scores that went into the aggregation.
    pub matched_chunks: usize,
    /// The raw `(chunk index, score)` pairs, kept only when details were requested.
    pub chunk_scores: Option<Vec<(ChunkIndex, f32)>>,
}

impl DocumentScore {
    /// Collapses `chunk_scores` into a single score for `doc_id` using `method`.
    ///
    /// * `chunk_scores` - `(chunk index, score)` pairs of the matched chunks, in any order.
    /// * `keep_details` - when `true`, the pairs are stored in the result's
    ///   `chunk_scores` field; otherwise that field is `None`.
    ///
    /// An empty `chunk_scores` yields a score of `0.0` with no best chunk.
    ///
    /// NaN handling: `Max` treats NaN as lower than every real score, so a NaN
    /// only wins when every score is NaN. `Mean` and `Sum` propagate NaN as
    /// ordinary float arithmetic does. This function never panics on NaN.
    pub fn aggregate(
        doc_id: DocId,
        chunk_scores: Vec<(ChunkIndex, f32)>,
        method: AggregationMethod,
        keep_details: bool,
    ) -> Self {
        let matched_chunks = chunk_scores.len();
        let (score, best_chunk) = if chunk_scores.is_empty() {
            (0.0, None)
        } else {
            Self::combine(&chunk_scores, method)
        };

        Self {
            doc_id,
            score,
            best_chunk,
            matched_chunks,
            chunk_scores: keep_details.then_some(chunk_scores),
        }
    }

    /// Applies `method` to a non-empty slice of chunk scores and returns the
    /// aggregated score together with the chunk that produced it, if any.
    fn combine(
        chunk_scores: &[(ChunkIndex, f32)],
        method: AggregationMethod,
    ) -> (f32, Option<ChunkIndex>) {
        let total = || chunk_scores.iter().map(|&(_, score)| score).sum::<f32>();
        // The caller guarantees a non-empty slice, so `hit` is always `Some`;
        // the fallback only exists to keep this path free of panics.
        let winner = |hit: Option<&(ChunkIndex, f32)>| {
            hit.map_or((0.0, None), |&(chunk, score)| (score, Some(chunk)))
        };

        match method {
            // On equal scores the later entry wins (`max_by` keeps the last maximum).
            AggregationMethod::Max => winner(
                chunk_scores
                    .iter()
                    .max_by(|a, b| Self::cmp_scores(a.1, b.1)),
            ),
            AggregationMethod::Mean => (total() / chunk_scores.len() as f32, None),
            AggregationMethod::First => winner(chunk_scores.iter().min_by_key(|&&(idx, _)| idx)),
            AggregationMethod::Last => winner(chunk_scores.iter().max_by_key(|&&(idx, _)| idx)),
            AggregationMethod::Sum => (total(), None),
        }
    }

    /// Total ordering for scores in which NaN ranks below every number and all
    /// NaNs compare equal. Non-NaN values compare exactly as `partial_cmp` does.
    fn cmp_scores(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // `partial_cmp` is always `Some` once NaN is excluded.
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        }
    }
}
/// Two-way index between documents and the internal ids of their chunk vectors.
#[derive(Debug, Clone)]
pub struct MultiVectorMapping {
    /// Internal id -> (owning document, chunk position).
    internal_to_doc: HashMap<InternalId, (DocId, ChunkIndex)>,
    /// Document -> internal ids of its chunks, in chunk order.
    doc_to_internal: HashMap<DocId, Vec<InternalId>>,
    /// Next id to hand out. It is only ever incremented, so ids are not reused.
    next_internal_id: InternalId,
}

impl MultiVectorMapping {
    /// Creates an empty mapping whose first assigned internal id will be `0`.
    pub fn new() -> Self {
        Self {
            internal_to_doc: HashMap::new(),
            doc_to_internal: HashMap::new(),
            next_internal_id: 0,
        }
    }

    /// Registers `doc_id` with `num_chunks` chunks and returns the freshly
    /// assigned internal ids, ordered by chunk index.
    ///
    /// A document that is already registered is removed first, so it ends up
    /// with entirely new internal ids.
    pub fn insert_document(&mut self, doc_id: DocId, num_chunks: usize) -> Vec<InternalId> {
        // Drop any previous registration; its ids are not needed here.
        let _ = self.remove_document(&doc_id);

        let assigned: Vec<InternalId> = (0..num_chunks)
            .map(|position| {
                let id = self.next_internal_id;
                self.next_internal_id += 1;
                self.internal_to_doc
                    .insert(id, (doc_id.clone(), position as ChunkIndex));
                id
            })
            .collect();

        self.doc_to_internal.insert(doc_id, assigned.clone());
        assigned
    }

    /// Unregisters `doc_id` and returns the internal ids it owned, or `None`
    /// when the document is unknown.
    pub fn remove_document(&mut self, doc_id: &str) -> Option<Vec<InternalId>> {
        let owned_ids = self.doc_to_internal.remove(doc_id)?;
        for stale in owned_ids.iter() {
            self.internal_to_doc.remove(stale);
        }
        Some(owned_ids)
    }

    /// Resolves an internal id to its document and chunk index.
    #[inline]
    pub fn get_doc(&self, internal_id: InternalId) -> Option<(&DocId, ChunkIndex)> {
        let (doc, chunk) = self.internal_to_doc.get(&internal_id)?;
        Some((doc, *chunk))
    }

    /// Internal ids of `doc_id`'s chunks in chunk order, or `None` when the
    /// document is unknown.
    pub fn get_internal_ids(&self, doc_id: &str) -> Option<&[InternalId]> {
        self.doc_to_internal.get(doc_id).map(Vec::as_slice)
    }

    /// Whether `doc_id` is currently registered.
    pub fn has_document(&self, doc_id: &str) -> bool {
        self.doc_to_internal.get(doc_id).is_some()
    }

    /// Number of registered documents.
    pub fn num_documents(&self) -> usize {
        self.doc_to_internal.len()
    }

    /// Number of registered chunk vectors across all documents.
    pub fn num_vectors(&self) -> usize {
        self.internal_to_doc.len()
    }
}

impl Default for MultiVectorMapping {
    /// Equivalent to [`MultiVectorMapping::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Turns vector-level search hits into ranked document-level results using a
/// shared [`MultiVectorMapping`].
pub struct MultiVectorAggregator {
    /// Shared lookup from internal vector ids to documents.
    mapping: Arc<RwLock<MultiVectorMapping>>,
    /// Aggregation method used when a call does not specify one.
    default_method: AggregationMethod,
}

impl MultiVectorAggregator {
    /// Creates an aggregator over `mapping` that defaults to `AggregationMethod::Max`.
    pub fn new(mapping: Arc<RwLock<MultiVectorMapping>>) -> Self {
        Self {
            mapping,
            default_method: AggregationMethod::Max,
        }
    }

    /// Builder-style override of the default aggregation method.
    pub fn with_default_method(mut self, method: AggregationMethod) -> Self {
        self.default_method = method;
        self
    }

    /// Groups `vector_results` by document, aggregates each group and returns
    /// at most `limit` documents, best score first.
    ///
    /// * `vector_results` - `(internal id, score)` hits; ids unknown to the
    ///   mapping are skipped.
    /// * `method` - aggregation method, or `None` for the configured default.
    ///
    /// Documents with a NaN score are ranked last and equal scores are ordered
    /// by `doc_id`, so the output is deterministic and never panics on NaN.
    /// Per-chunk details are not retained; see [`Self::aggregate_detailed`].
    pub fn aggregate(
        &self,
        vector_results: &[(InternalId, f32)],
        method: Option<AggregationMethod>,
        limit: usize,
    ) -> Vec<DocumentScore> {
        self.aggregate_with(vector_results, method, limit, false)
    }

    /// Same as [`Self::aggregate`], but every returned document also carries
    /// the per-chunk scores that were aggregated.
    pub fn aggregate_detailed(
        &self,
        vector_results: &[(InternalId, f32)],
        method: Option<AggregationMethod>,
        limit: usize,
    ) -> Vec<DocumentScore> {
        self.aggregate_with(vector_results, method, limit, true)
    }

    /// Shared implementation of both public entry points.
    fn aggregate_with(
        &self,
        vector_results: &[(InternalId, f32)],
        method: Option<AggregationMethod>,
        limit: usize,
        keep_details: bool,
    ) -> Vec<DocumentScore> {
        let method = method.unwrap_or(self.default_method);

        let mapping = self.mapping.read();
        // Keys borrow from the mapping, so a document id is cloned only once
        // per resulting document rather than once per hit.
        let mut doc_chunks: HashMap<&DocId, Vec<(ChunkIndex, f32)>> = HashMap::new();
        for &(internal_id, score) in vector_results {
            if let Some((doc_id, chunk_idx)) = mapping.get_doc(internal_id) {
                doc_chunks
                    .entry(doc_id)
                    .or_default()
                    .push((chunk_idx, score));
            }
        }
        let mut results: Vec<DocumentScore> = doc_chunks
            .into_iter()
            .map(|(doc_id, chunks)| {
                DocumentScore::aggregate(doc_id.clone(), chunks, method, keep_details)
            })
            .collect();
        // Everything is owned from here on; release the read lock before sorting.
        drop(mapping);

        results.sort_by(Self::rank);
        results.truncate(limit);
        results
    }

    /// Ranking order: higher score first, NaN scores last, ties by `doc_id`.
    fn rank(a: &DocumentScore, b: &DocumentScore) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_score = match (a.score.is_nan(), b.score.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Operands are swapped for descending order; `partial_cmp` is
            // always `Some` once NaN is excluded.
            (false, false) => b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal),
        };
        // Grouping went through a `HashMap`, so without a tie-break the order
        // of equal scores would vary from run to run.
        by_score.then_with(|| a.doc_id.cmp(&b.doc_id))
    }
}
/// Tunable settings for multi-vector documents.
///
/// NOTE(review): nothing in the visible part of this file reads these fields,
/// so the per-field descriptions are inferred from the names; confirm against
/// the code that consumes this config.
#[derive(Debug, Clone)]
pub struct MultiVectorConfig {
    /// Presumably the maximum number of chunks one document may contain.
    pub max_chunks_per_doc: usize,
    /// Presumably the aggregation method used when a query does not name one.
    pub default_aggregation: AggregationMethod,
    /// Presumably a multiplier on the number of vector hits fetched before
    /// they are collapsed into documents.
    pub overfetch_factor: f32,
}

impl Default for MultiVectorConfig {
    /// Defaults: 1000 chunks per document, `Max` aggregation, overfetch factor 2.0.
    fn default() -> Self {
        MultiVectorConfig {
            default_aggregation: AggregationMethod::Max,
            max_chunks_per_doc: 1000,
            overfetch_factor: 2.0,
        }
    }
}
/// A document represented by one vector per chunk, with optional chunk texts
/// and free-form metadata.
#[derive(Debug, Clone)]
pub struct MultiVectorDocument {
    /// Document identifier.
    pub id: DocId,
    /// One vector per chunk.
    pub vectors: Vec<Vec<f32>>,
    /// Optional text of each chunk; `validate` requires one entry per vector.
    pub chunks_text: Option<Vec<String>>,
    /// Arbitrary JSON metadata keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}

impl MultiVectorDocument {
    /// Creates a document with the given vectors, no chunk texts and empty metadata.
    pub fn new(id: impl Into<DocId>, vectors: Vec<Vec<f32>>) -> Self {
        let id = id.into();
        Self {
            id,
            vectors,
            chunks_text: None,
            metadata: HashMap::new(),
        }
    }

    /// Attaches chunk texts, replacing any that were set before.
    pub fn with_text(self, chunks: Vec<String>) -> Self {
        Self {
            chunks_text: Some(chunks),
            ..self
        }
    }

    /// Sets one metadata entry, overwriting an existing entry with the same key.
    pub fn with_metadata(
        mut self,
        key: impl Into<String>,
        value: impl Into<serde_json::Value>,
    ) -> Self {
        let name = key.into();
        let payload = value.into();
        self.metadata.insert(name, payload);
        self
    }

    /// Number of chunks, i.e. the number of vectors.
    pub fn num_chunks(&self) -> usize {
        self.vectors.len()
    }

    /// Checks the document's internal consistency against `expected_dim`.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// * `NoVectors` - the document has no vectors.
    /// * `DimensionMismatch` - the first vector whose length differs from
    ///   `expected_dim`.
    /// * `ChunkCountMismatch` - chunk texts are present but their count differs
    ///   from the number of vectors.
    pub fn validate(&self, expected_dim: usize) -> Result<(), MultiVectorError> {
        let vector_count = self.vectors.len();
        if vector_count == 0 {
            return Err(MultiVectorError::NoVectors);
        }

        let first_bad = self
            .vectors
            .iter()
            .map(Vec::len)
            .enumerate()
            .find(|&(_, len)| len != expected_dim);
        if let Some((chunk, actual)) = first_bad {
            return Err(MultiVectorError::DimensionMismatch {
                chunk,
                expected: expected_dim,
                actual,
            });
        }

        match &self.chunks_text {
            Some(texts) if texts.len() != vector_count => {
                Err(MultiVectorError::ChunkCountMismatch {
                    vectors: vector_count,
                    texts: texts.len(),
                })
            }
            _ => Ok(()),
        }
    }
}
/// Errors reported for multi-vector documents.
#[derive(Debug, thiserror::Error)]
pub enum MultiVectorError {
/// The document has no vectors; returned by `MultiVectorDocument::validate`.
#[error("document must have at least one vector")]
NoVectors,
/// A chunk vector's length differs from the expected dimension; returned by
/// `MultiVectorDocument::validate` for the first offending chunk.
#[error("dimension mismatch in chunk {chunk}: expected {expected}, got {actual}")]
DimensionMismatch {
chunk: usize,
expected: usize,
actual: usize,
},
/// Chunk texts are present but their count differs from the number of
/// vectors; returned by `MultiVectorDocument::validate`.
#[error("chunk count mismatch: {vectors} vectors but {texts} texts")]
ChunkCountMismatch { vectors: usize, texts: usize },
/// NOTE(review): not constructed in the visible part of this file;
/// presumably raised when a document exceeds a configured chunk limit.
#[error("too many chunks: {count} exceeds limit of {limit}")]
TooManyChunks { count: usize, limit: usize },
/// NOTE(review): not constructed in the visible part of this file;
/// presumably raised when a document id cannot be resolved.
#[error("document not found: {0}")]
NotFound(DocId),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Aggregates `chunks` for the fixed document id `"doc1"` without keeping details.
    fn score_doc1(chunks: Vec<(ChunkIndex, f32)>, method: AggregationMethod) -> DocumentScore {
        DocumentScore::aggregate("doc1".to_string(), chunks, method, false)
    }

    #[test]
    fn test_aggregation_max() {
        let scored = score_doc1(vec![(0, 0.5), (1, 0.9), (2, 0.3)], AggregationMethod::Max);

        assert_eq!(scored.score, 0.9);
        assert_eq!(scored.best_chunk, Some(1));
        assert_eq!(scored.matched_chunks, 3);
    }

    #[test]
    fn test_aggregation_mean() {
        let scored = score_doc1(vec![(0, 0.6), (1, 0.9), (2, 0.3)], AggregationMethod::Mean);

        // (0.6 + 0.9 + 0.3) / 3 = 0.6, compared with a tolerance for float rounding.
        assert!((scored.score - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_aggregation_first() {
        // Chunks arrive out of order; `First` must select chunk index 0.
        let scored = score_doc1(vec![(2, 0.3), (0, 0.5), (1, 0.9)], AggregationMethod::First);

        assert_eq!(scored.score, 0.5);
        assert_eq!(scored.best_chunk, Some(0));
    }

    #[test]
    fn test_mapping_insert() {
        let mut index = MultiVectorMapping::new();
        let assigned = index.insert_document("doc1".to_string(), 3);
        assert_eq!(assigned.len(), 3);

        // Every assigned id resolves back to the document and its chunk position.
        for (position, &internal_id) in assigned.iter().enumerate() {
            let (owner, chunk) = index.get_doc(internal_id).unwrap();
            assert_eq!(owner, "doc1");
            assert_eq!(chunk as usize, position);
        }
    }

    #[test]
    fn test_mapping_remove() {
        let mut index = MultiVectorMapping::new();
        let assigned = index.insert_document("doc1".to_string(), 3);

        let removed = index.remove_document("doc1").unwrap();

        assert_eq!(removed, assigned);
        assert!(index.get_doc(assigned[0]).is_none());
        assert!(!index.has_document("doc1"));
    }

    #[test]
    fn test_aggregator() {
        let shared = Arc::new(RwLock::new(MultiVectorMapping::new()));
        {
            let mut index = shared.write();
            // doc1 receives internal ids 0..=2, doc2 receives 3..=4.
            index.insert_document("doc1".to_string(), 3);
            index.insert_document("doc2".to_string(), 2);
        }
        let aggregator = MultiVectorAggregator::new(shared);

        // Hits for doc1 (ids 1 and 0) and doc2 (ids 3 and 4).
        let hits = vec![(1, 0.95), (3, 0.90), (0, 0.85), (4, 0.80)];
        let ranked = aggregator.aggregate(&hits, Some(AggregationMethod::Max), 10);

        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].doc_id, "doc1");
        assert_eq!(ranked[0].score, 0.95);
        assert_eq!(ranked[1].doc_id, "doc2");
        assert_eq!(ranked[1].score, 0.90);
    }

    #[test]
    fn test_multi_vector_document() {
        let vectors = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let texts = vec!["chunk 1".to_string(), "chunk 2".to_string()];
        let doc = MultiVectorDocument::new("doc1", vectors)
            .with_text(texts)
            .with_metadata("author", serde_json::json!("Alice"));

        assert_eq!(doc.num_chunks(), 2);
        assert!(doc.validate(3).is_ok());
        // Both vectors have three components, so a dimension of four is rejected.
        assert!(doc.validate(4).is_err());
    }
}