use std::collections::HashMap;
use crate::point::SearchResult;
use crate::velesql::SetOperator;
/// Combines two result lists with a VelesQL set operator and returns at most `limit`
/// hits, best first.
///
/// Ranking is a total order: higher score first, NaN scores after every real score,
/// and equal scores broken by ascending point id. The output is therefore
/// deterministic even though the set helpers may return their rows in hash-map order.
/// A `limit` of 0 yields an empty vector.
pub(crate) fn apply_set_operation(
    left: Vec<SearchResult>,
    right: Vec<SearchResult>,
    operator: SetOperator,
    limit: usize,
) -> Vec<SearchResult> {
    let mut results = match operator {
        SetOperator::Union => union_dedup(left, right),
        SetOperator::UnionAll => union_all(left, right),
        SetOperator::Intersect => intersect(left, &right),
        SetOperator::Except => except(left, &right),
    };

    if limit < results.len() {
        // Move the `limit` best hits to the front in O(n), then drop the rest, so
        // only the survivors need a full sort.
        results.select_nth_unstable_by(limit, rank_order);
        results.truncate(limit);
    }
    results.sort_unstable_by(rank_order);
    results
}

/// Total ranking order used by [`apply_set_operation`]: score descending, NaN scores
/// last, ties broken by ascending point id.
///
/// A total order is required because `sort_unstable_by` and
/// `select_nth_unstable_by` may panic or misbehave on inconsistent comparators.
fn rank_order(a: &SearchResult, b: &SearchResult) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let by_score = match (a.score.is_nan(), b.score.is_nan()) {
        (false, false) => b.score.total_cmp(&a.score),
        (true, true) => Ordering::Equal,
        // A NaN score ranks below any real score.
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
    };
    by_score.then_with(|| a.point.id.cmp(&b.point.id))
}
/// `UNION` (distinct): merges both operands, keeping exactly one entry per point id.
///
/// The kept entry is the highest-scoring occurrence of that id, whether it comes from
/// `left` or `right`. Duplicates within the same side are collapsed too. On equal
/// scores the first-seen entry wins (left before right, earlier before later). A real
/// score always displaces a NaN score. Output order is unspecified; the caller sorts.
fn union_dedup(left: Vec<SearchResult>, right: Vec<SearchResult>) -> Vec<SearchResult> {
    /// `true` when `candidate` should replace `incumbent`: strictly higher score, or a
    /// real score displacing a NaN.
    fn beats(candidate: f32, incumbent: f32) -> bool {
        candidate > incumbent || (incumbent.is_nan() && !candidate.is_nan())
    }

    let mut best: HashMap<u64, SearchResult> = HashMap::with_capacity(left.len() + right.len());
    // Treat both sides uniformly so duplicate ids inside `left` are ranked by score
    // rather than overwritten by whichever came last.
    for result in left.into_iter().chain(right) {
        match best.entry(result.point.id) {
            std::collections::hash_map::Entry::Occupied(mut existing) => {
                if beats(result.score, existing.get().score) {
                    existing.insert(result);
                }
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(result);
            }
        }
    }
    best.into_values().collect()
}
/// `UNION ALL`: concatenates both operands without removing anything.
///
/// Every entry survives, including duplicate ids. All `left` entries come before
/// all `right` entries, each side in its original order.
fn union_all(mut left: Vec<SearchResult>, mut right: Vec<SearchResult>) -> Vec<SearchResult> {
    // Moves `right`'s elements onto the end of `left`, reusing `left`'s allocation.
    left.append(&mut right);
    left
}
fn intersect(left: Vec<SearchResult>, right: &[SearchResult]) -> Vec<SearchResult> {
let right_map: HashMap<u64, &SearchResult> = right.iter().map(|r| (r.point.id, r)).collect();
left.into_iter()
.filter_map(|l| {
right_map
.get(&l.point.id)
.map(|r| if r.score > l.score { (*r).clone() } else { l })
})
.collect()
}
fn except(left: Vec<SearchResult>, right: &[SearchResult]) -> Vec<SearchResult> {
let right_ids: std::collections::HashSet<u64> = right.iter().map(|r| r.point.id).collect();
left.into_iter()
.filter(|l| !right_ids.contains(&l.point.id))
.collect()
}
#[cfg(test)]
mod tests {
    // Behavioural tests for the set-operation combinators, driven through the public
    // `apply_set_operation` entry point.
    use super::*;
    use crate::point::{Point, SearchResult};

    /// A limit large enough that truncation never affects the small fixtures below.
    const TEST_LIMIT: usize = 100_000;

    /// Builds a result for point `id` (3-dim zero vector, no payload) with `score`.
    fn make_result(id: u64, score: f32) -> SearchResult {
        let point = Point::new(id, vec![0.0; 3], None);
        SearchResult::new(point, score)
    }

    /// Shared fixture: left = {1: 0.9, 2: 0.5}, right = {2: 0.8, 3: 0.7}.
    /// Id 2 overlaps, with the higher score on the right.
    fn overlapping_operands() -> (Vec<SearchResult>, Vec<SearchResult>) {
        (
            vec![make_result(1, 0.9), make_result(2, 0.5)],
            vec![make_result(2, 0.8), make_result(3, 0.7)],
        )
    }

    #[test]
    fn test_union_dedup_keeps_highest_score() {
        let (left, right) = overlapping_operands();
        let results = apply_set_operation(left, right, SetOperator::Union, TEST_LIMIT);
        assert_eq!(results.len(), 3);
        for expected_id in [1u64, 2, 3] {
            assert!(results.iter().any(|r| r.point.id == expected_id));
        }
        let point2 = results
            .iter()
            .find(|r| r.point.id == 2)
            .expect("point 2 must survive the union");
        assert!((point2.score - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_union_all_keeps_duplicates() {
        let (left, right) = overlapping_operands();
        let results = apply_set_operation(left, right, SetOperator::UnionAll, TEST_LIMIT);
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_intersect_keeps_common_ids() {
        let (left, right) = overlapping_operands();
        let results = apply_set_operation(left, right, SetOperator::Intersect, TEST_LIMIT);
        assert_eq!(results.len(), 1);
        let only = &results[0];
        assert_eq!(only.point.id, 2);
        assert!((only.score - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_except_removes_right_ids() {
        let (left, right) = overlapping_operands();
        let results = apply_set_operation(left, right, SetOperator::Except, TEST_LIMIT);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].point.id, 1);
    }

    #[test]
    fn test_results_sorted_by_score_desc() {
        let left = vec![make_result(1, 0.3), make_result(2, 0.9)];
        let right = vec![make_result(3, 0.6)];
        let results = apply_set_operation(left, right, SetOperator::UnionAll, TEST_LIMIT);
        let is_descending = results
            .windows(2)
            .all(|pair| pair[0].score >= pair[1].score);
        assert!(is_descending, "Results not sorted descending");
    }

    #[test]
    fn test_empty_operands() {
        let single = vec![make_result(1, 0.5)];

        let union = apply_set_operation(
            Vec::new(),
            single.clone(),
            SetOperator::Union,
            TEST_LIMIT,
        );
        assert_eq!(union.len(), 1);

        let intersection = apply_set_operation(
            Vec::new(),
            single.clone(),
            SetOperator::Intersect,
            TEST_LIMIT,
        );
        assert!(intersection.is_empty());

        let difference =
            apply_set_operation(single, Vec::new(), SetOperator::Except, TEST_LIMIT);
        assert_eq!(difference.len(), 1);
    }

    /// A result whose score equals its id, so ranking by score is ranking by id.
    fn scored(id: u16) -> SearchResult {
        make_result(u64::from(id), f32::from(id))
    }

    #[test]
    fn test_union_respects_limit() {
        let left: Vec<SearchResult> = (1..=100).map(scored).collect();
        let right: Vec<SearchResult> = (101..=200).map(scored).collect();
        let results = apply_set_operation(left, right, SetOperator::Union, 5);
        assert_eq!(results.len(), 5);
        let top_ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
        let expected: Vec<u64> = (196..=200).rev().collect();
        assert_eq!(top_ids, expected);
    }

    #[test]
    fn test_intersect_respects_limit() {
        let left: Vec<SearchResult> = (0u16..100).map(scored).collect();
        let right: Vec<SearchResult> = (0u16..100)
            .map(|i| make_result(u64::from(i), 0.0))
            .collect();
        let results = apply_set_operation(left, right, SetOperator::Intersect, 3);
        assert_eq!(results.len(), 3);
        let top_ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
        assert_eq!(top_ids, vec![99, 98, 97]);
    }

    #[test]
    fn test_limit_larger_than_results_is_noop() {
        let left = vec![make_result(1, 0.9), make_result(2, 0.5)];
        let right = vec![make_result(3, 0.7)];
        let results = apply_set_operation(left, right, SetOperator::Union, 100);
        assert_eq!(results.len(), 3);
    }
}