use super::DistanceMetric;
/// Summary statistics for a group of vectors, used to decide whether the
/// group can be skipped during a distance-bounded search.
///
/// A freshly created (empty) map holds inverted sentinels: minima are
/// `f32::MAX` and maxima are `f32::MIN` until vectors are summarized.
#[derive(Debug, Clone)]
pub struct VectorZoneMap {
/// Number of components per vector; `build` takes it from the first vector.
pub dimensions: usize,
/// Number of vectors summarized by this map (`0` means the map is empty).
pub count: usize,
/// Smallest Euclidean norm among the summarized vectors.
pub min_magnitude: f32,
/// Largest Euclidean norm among the summarized vectors.
pub max_magnitude: f32,
/// Component-wise mean of the summarized vectors.
pub centroid: Vec<f32>,
/// Radius around `centroid` used for pruning; `build` sets it to the largest
/// Euclidean distance from the centroid to any summarized vector.
pub max_radius: f32,
/// Per-dimension minimum component value (lower corner of the bounding box).
pub dim_min: Vec<f32>,
/// Per-dimension maximum component value (upper corner of the bounding box).
pub dim_max: Vec<f32>,
}
impl VectorZoneMap {
#[must_use]
pub fn new(dimensions: usize) -> Self {
Self {
dimensions,
count: 0,
min_magnitude: f32::MAX,
max_magnitude: f32::MIN,
centroid: vec![0.0; dimensions],
max_radius: 0.0,
dim_min: vec![f32::MAX; dimensions],
dim_max: vec![f32::MIN; dimensions],
}
}
#[must_use]
pub fn build(vectors: &[&[f32]]) -> Self {
if vectors.is_empty() {
return Self::new(0);
}
let dimensions = vectors[0].len();
let count = vectors.len();
let mut min_magnitude = f32::MAX;
let mut max_magnitude = f32::MIN;
let mut centroid = vec![0.0; dimensions];
let mut dim_min = vec![f32::MAX; dimensions];
let mut dim_max = vec![f32::MIN; dimensions];
for vec in vectors {
let magnitude: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
min_magnitude = min_magnitude.min(magnitude);
max_magnitude = max_magnitude.max(magnitude);
for (i, &v) in vec.iter().enumerate() {
centroid[i] += v;
dim_min[i] = dim_min[i].min(v);
dim_max[i] = dim_max[i].max(v);
}
}
let count_f = count as f32;
for c in &mut centroid {
*c /= count_f;
}
let mut max_radius = 0.0f32;
for vec in vectors {
let dist_sq: f32 = vec
.iter()
.zip(¢roid)
.map(|(a, b)| (a - b) * (a - b))
.sum();
max_radius = max_radius.max(dist_sq.sqrt());
}
Self {
dimensions,
count,
min_magnitude,
max_magnitude,
centroid,
max_radius,
dim_min,
dim_max,
}
}
#[must_use]
pub fn might_contain_within_distance(
&self,
query: &[f32],
threshold: f32,
metric: DistanceMetric,
) -> bool {
if self.count == 0 {
return false;
}
match metric {
DistanceMetric::Euclidean => {
let centroid_dist = euclidean_distance(query, &self.centroid);
if centroid_dist - self.max_radius > threshold {
return false;
}
let box_dist = self.min_distance_to_box(query);
if box_dist > threshold {
return false;
}
true
}
DistanceMetric::Cosine => {
let centroid_dist = cosine_distance(query, &self.centroid);
centroid_dist - self.max_radius <= threshold
}
DistanceMetric::DotProduct | DistanceMetric::Manhattan => {
true
}
}
}
fn min_distance_to_box(&self, query: &[f32]) -> f32 {
let mut dist_sq = 0.0f32;
for (i, &q) in query.iter().enumerate() {
if i >= self.dimensions {
break;
}
let closest = if q < self.dim_min[i] {
self.dim_min[i]
} else if q > self.dim_max[i] {
self.dim_max[i]
} else {
q };
let diff = q - closest;
dist_sq += diff * diff;
}
dist_sq.sqrt()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.count == 0
}
#[must_use]
pub fn avg_magnitude(&self) -> f32 {
f32::midpoint(self.min_magnitude, self.max_magnitude)
}
#[must_use]
pub fn magnitude_range(&self) -> (f32, f32) {
(self.min_magnitude, self.max_magnitude)
}
#[must_use]
pub fn bounding_box(&self) -> (&[f32], &[f32]) {
(&self.dim_min, &self.dim_max)
}
pub fn merge(&mut self, other: &VectorZoneMap) {
if other.is_empty() {
return;
}
if self.is_empty() {
*self = other.clone();
return;
}
self.min_magnitude = self.min_magnitude.min(other.min_magnitude);
self.max_magnitude = self.max_magnitude.max(other.max_magnitude);
for i in 0..self.dimensions.min(other.dimensions) {
self.dim_min[i] = self.dim_min[i].min(other.dim_min[i]);
self.dim_max[i] = self.dim_max[i].max(other.dim_max[i]);
}
let total_count = self.count + other.count;
let self_weight = self.count as f32 / total_count as f32;
let other_weight = other.count as f32 / total_count as f32;
for i in 0..self.dimensions.min(other.dimensions) {
self.centroid[i] = self.centroid[i] * self_weight + other.centroid[i] * other_weight;
}
self.max_radius = f32::midpoint(self.max_radius, other.max_radius)
+ euclidean_distance(&self.centroid, &other.centroid);
self.count = total_count;
}
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length the
/// surplus components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
/// Cosine distance between `a` and `b`: `1 - cos(angle between them)`.
///
/// Returns `1.0` when either vector has zero norm. The dot product pairs
/// components positionally (surplus components of the longer slice are
/// ignored), while each norm is taken over the whole slice.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |values: &[f32]| values.iter().map(|v| v * v).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
    let norm_a = l2_norm(a);
    let norm_b = l2_norm(b);

    if norm_a == 0.0 || norm_b == 0.0 {
        1.0
    } else {
        1.0 - (dot / (norm_a * norm_b))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every approximate float comparison below.
    const TOLERANCE: f32 = 0.001;

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Three unit basis vectors: checks count, dimensionality, magnitude
    /// range and the centroid.
    #[test]
    fn test_vector_zone_map_build() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];
        let z_axis = [0.0f32, 0.0, 1.0];
        let basis: [&[f32]; 3] = [&x_axis, &y_axis, &z_axis];

        let zone = VectorZoneMap::build(&basis);

        assert_eq!(zone.count, 3);
        assert_eq!(zone.dimensions, 3);
        assert_close(zone.min_magnitude, 1.0);
        assert_close(zone.max_magnitude, 1.0);
        for &component in &zone.centroid {
            assert_close(component, 1.0 / 3.0);
        }
    }

    /// A tight cluster is pruned for a far query and kept for a near one.
    #[test]
    fn test_vector_zone_map_pruning() {
        let a = [5.0f32, 5.0, 5.0];
        let b = [5.1f32, 4.9, 5.0];
        let c = [4.9f32, 5.1, 5.0];
        let cluster: [&[f32]; 3] = [&a, &b, &c];
        let zone = VectorZoneMap::build(&cluster);

        let origin = [0.0f32, 0.0, 0.0];
        assert!(euclidean_distance(&origin, &zone.centroid) > 8.0);
        assert!(!zone.might_contain_within_distance(&origin, 1.0, DistanceMetric::Euclidean));

        let inside = [5.0f32, 5.0, 5.0];
        assert!(zone.might_contain_within_distance(&inside, 1.0, DistanceMetric::Euclidean));
    }

    /// The bounding box spans the per-dimension extremes.
    #[test]
    fn test_vector_zone_map_bounding_box() {
        let low = [0.0f32, 0.0];
        let high = [10.0f32, 10.0];
        let corners: [&[f32]; 2] = [&low, &high];
        let zone = VectorZoneMap::build(&corners);

        let (lower, upper) = zone.bounding_box();
        assert_close(lower[0], 0.0);
        assert_close(lower[1], 0.0);
        assert_close(upper[0], 10.0);
        assert_close(upper[1], 10.0);
    }

    /// Merging sums the counts and widens the bounding box.
    #[test]
    fn test_vector_zone_map_merge() {
        let near_a = [1.0f32, 0.0];
        let near_b = [2.0f32, 0.0];
        let near: [&[f32]; 2] = [&near_a, &near_b];
        let near_zone = VectorZoneMap::build(&near);

        let far_a = [10.0f32, 0.0];
        let far_b = [11.0f32, 0.0];
        let far: [&[f32]; 2] = [&far_a, &far_b];
        let far_zone = VectorZoneMap::build(&far);

        let mut combined = near_zone.clone();
        combined.merge(&far_zone);

        assert_eq!(combined.count, 4);
        let (lower, upper) = combined.bounding_box();
        assert_close(lower[0], 1.0);
        assert_close(upper[0], 11.0);
    }

    /// A freshly created map reports itself as empty.
    #[test]
    fn test_vector_zone_map_empty() {
        let zone = VectorZoneMap::new(3);
        assert!(zone.is_empty());
        assert_eq!(zone.count, 0);
    }
}