use std::collections::HashMap;
use std::sync::RwLock;
/// Identifier of a single shard.
pub type ShardId = u32;
/// Identifier of a cluster of shards (the value stored in `Centroid::id`).
pub type ClusterId = u32;
/// A cluster centre together with the shards that belong to that cluster.
#[derive(Debug, Clone)]
pub struct Centroid {
    /// Identifier of the cluster this centroid represents.
    pub id: ClusterId,
    /// Coordinates of the cluster centre.
    pub vector: Vec<f32>,
    /// Shards assigned to this cluster.
    pub shards: Vec<ShardId>,
    /// Number of vectors counted for this cluster.
    pub count: usize,
}
impl Centroid {
pub fn new(id: ClusterId, vector: Vec<f32>) -> Self {
Self {
id,
vector,
shards: Vec::new(),
count: 0,
}
}
#[inline]
pub fn distance_squared(&self, query: &[f32]) -> f32 {
self.vector
.iter()
.zip(query.iter())
.map(|(&a, &b)| {
let d = a - b;
d * d
})
.sum()
}
}
/// Result of routing one query: which shards to search and how close the
/// probed clusters were.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Shards selected for the query.
    pub shards: Vec<ShardId>,
    /// Squared distance from the query to each probed cluster centroid.
    pub distances: Vec<f32>,
    /// Number of clusters that were probed.
    pub clusters_probed: usize,
}
impl RoutingDecision {
    /// Returns the fraction of `total_shards` that this decision selects
    /// (for example `0.5` when 8 of 16 shards are selected).
    ///
    /// Returns `1.0` when no shards were selected, and also when
    /// `total_shards` is zero, so the result is never infinite or NaN because
    /// of a zero denominator.
    pub fn work_reduction(&self, total_shards: usize) -> f32 {
        if self.shards.is_empty() || total_shards == 0 {
            return 1.0;
        }
        self.shards.len() as f32 / total_shards as f32
    }
}
/// Tunable parameters for building and querying a shard topology.
#[derive(Debug, Clone)]
pub struct TopologyConfig {
    /// Requested number of clusters.
    pub num_clusters: usize,
    /// Number of shards assigned to each cluster.
    pub shards_per_cluster: usize,
    /// Number of nearest clusters probed per query.
    pub probe_clusters: usize,
    /// Ratio of largest to smallest cluster count above which a rebalance is
    /// considered necessary.
    pub rebalance_threshold: f32,
}
impl Default for TopologyConfig {
fn default() -> Self {
Self {
num_clusters: 16,
shards_per_cluster: 16,
probe_clusters: 2,
rebalance_threshold: 2.0,
}
}
}
/// Cluster centroids plus the bookkeeping needed to route queries to shards.
pub struct ShardTopology {
    /// One centroid per cluster.
    centroids: Vec<Centroid>,
    /// Reverse lookup from a shard to the id of the cluster that owns it.
    shard_to_cluster: HashMap<ShardId, ClusterId>,
    /// Parameters the topology was created with.
    config: TopologyConfig,
    /// Sum of the shard-list lengths of all centroids.
    total_shards: usize,
    /// Routing counters, behind a lock so they can be updated through `&self`.
    stats: RwLock<TopologyStats>,
}
/// Counters accumulated while routing queries.
#[derive(Debug, Clone, Default)]
pub struct TopologyStats {
    /// Number of queries routed so far.
    pub queries_routed: u64,
    /// Total number of shards selected across all routed queries.
    pub shards_probed: u64,
    /// `shards_probed / queries_routed`, refreshed after every routed query.
    pub avg_fanout: f32,
    /// Per-cluster count of how many queries probed that cluster.
    pub cluster_loads: Vec<u64>,
}
impl ShardTopology {
pub fn new(centroids: Vec<Centroid>, config: TopologyConfig) -> Self {
let total_shards = centroids.iter().map(|c| c.shards.len()).sum();
let mut shard_to_cluster = HashMap::new();
for centroid in ¢roids {
for &shard in ¢roid.shards {
shard_to_cluster.insert(shard, centroid.id);
}
}
let cluster_loads = vec![0; centroids.len()];
Self {
centroids,
shard_to_cluster,
config,
total_shards,
stats: RwLock::new(TopologyStats {
cluster_loads,
..Default::default()
}),
}
}
pub fn build_from_vectors(vectors: &[Vec<f32>], config: TopologyConfig) -> Self {
if vectors.is_empty() {
return Self::empty(config);
}
let dimension = vectors[0].len();
let num_clusters = config.num_clusters.min(vectors.len());
let mut centroids: Vec<Centroid> = (0..num_clusters)
.map(|i| {
let idx = (i * vectors.len()) / num_clusters;
Centroid::new(i as ClusterId, vectors[idx].clone())
})
.collect();
for _ in 0..10 {
let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); num_clusters];
for (vec_idx, vector) in vectors.iter().enumerate() {
let nearest = Self::find_nearest_centroid(vector, ¢roids);
assignments[nearest].push(vec_idx);
}
for (cluster_idx, assigned) in assignments.iter().enumerate() {
if assigned.is_empty() {
continue;
}
let mut new_centroid = vec![0.0f32; dimension];
for &vec_idx in assigned {
for (i, &v) in vectors[vec_idx].iter().enumerate() {
new_centroid[i] += v;
}
}
let count = assigned.len() as f32;
for v in &mut new_centroid {
*v /= count;
}
centroids[cluster_idx].vector = new_centroid;
centroids[cluster_idx].count = assigned.len();
}
}
let _total_shards = config.num_clusters * config.shards_per_cluster;
for (i, centroid) in centroids.iter_mut().enumerate() {
let start_shard = i * config.shards_per_cluster;
let end_shard = start_shard + config.shards_per_cluster;
centroid.shards = (start_shard..end_shard).map(|s| s as ShardId).collect();
}
Self::new(centroids, config)
}
pub fn empty(config: TopologyConfig) -> Self {
Self {
centroids: Vec::new(),
shard_to_cluster: HashMap::new(),
config,
total_shards: 0,
stats: RwLock::new(TopologyStats::default()),
}
}
pub fn route(&self, query: &[f32]) -> RoutingDecision {
if self.centroids.is_empty() {
return RoutingDecision {
shards: Vec::new(),
distances: Vec::new(),
clusters_probed: 0,
};
}
let mut cluster_dists: Vec<(ClusterId, f32)> = self
.centroids
.iter()
.map(|c| (c.id, c.distance_squared(query)))
.collect();
cluster_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let probe_count = self.config.probe_clusters.min(cluster_dists.len());
let probed: Vec<_> = cluster_dists[..probe_count].to_vec();
let mut shards = Vec::new();
let mut distances = Vec::new();
for (cluster_id, dist) in &probed {
if let Some(centroid) = self.centroids.get(*cluster_id as usize) {
shards.extend_from_slice(¢roid.shards);
distances.push(*dist);
}
}
if let Ok(mut stats) = self.stats.write() {
stats.queries_routed += 1;
stats.shards_probed += shards.len() as u64;
stats.avg_fanout = stats.shards_probed as f32 / stats.queries_routed as f32;
for (cluster_id, _) in &probed {
if (*cluster_id as usize) < stats.cluster_loads.len() {
stats.cluster_loads[*cluster_id as usize] += 1;
}
}
}
RoutingDecision {
shards,
distances,
clusters_probed: probe_count,
}
}
pub fn shard_cluster(&self, shard: ShardId) -> Option<ClusterId> {
self.shard_to_cluster.get(&shard).copied()
}
pub fn all_shards(&self) -> Vec<ShardId> {
self.shard_to_cluster.keys().copied().collect()
}
pub fn cluster(&self, id: ClusterId) -> Option<&Centroid> {
self.centroids.get(id as usize)
}
pub fn num_clusters(&self) -> usize {
self.centroids.len()
}
pub fn num_shards(&self) -> usize {
self.total_shards
}
pub fn needs_rebalance(&self) -> bool {
if self.centroids.len() < 2 {
return false;
}
let counts: Vec<usize> = self.centroids.iter().map(|c| c.count).collect();
let max_count = *counts.iter().max().unwrap_or(&1) as f32;
let min_count = *counts.iter().min().unwrap_or(&1).max(&1) as f32;
max_count / min_count > self.config.rebalance_threshold
}
pub fn stats(&self) -> TopologyStats {
self.stats.read().unwrap().clone()
}
fn find_nearest_centroid(vector: &[f32], centroids: &[Centroid]) -> usize {
centroids
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
a.distance_squared(vector)
.partial_cmp(&b.distance_squared(vector))
.unwrap()
})
.map(|(i, _)| i)
.unwrap_or(0)
}
}
pub struct ShardRouter {
topology: ShardTopology,
#[allow(dead_code)]
adaptive: bool,
}
impl ShardRouter {
pub fn new(topology: ShardTopology) -> Self {
Self {
topology,
adaptive: true,
}
}
pub fn route_adaptive(&self, query: &[f32], target_recall: f32) -> RoutingDecision {
let base_probe = self.topology.config.probe_clusters;
let _probe = if target_recall > 0.99 {
(base_probe * 2).min(self.topology.num_clusters())
} else if target_recall > 0.95 {
base_probe
} else {
(base_probe / 2).max(1)
};
let mut decision = self.topology.route(query);
if target_recall > 0.95 && decision.shards.len() < 4 {
decision.shards.extend(
self.topology
.all_shards()
.into_iter()
.take(4 - decision.shards.len()),
);
}
decision
}
pub fn estimated_recall(&self, decision: &RoutingDecision) -> f32 {
if self.topology.num_shards() == 0 {
return 0.0;
}
let coverage = decision.shards.len() as f32 / self.topology.num_shards() as f32;
coverage.sqrt().min(1.0)
}
pub fn topology(&self) -> &ShardTopology {
&self.topology
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vectors with every component in `[-1.0, 1.0)`.
    fn random_vectors(count: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        (0..count)
            .map(|i| {
                (0..dim)
                    .map(|d| {
                        let x = ((i as u64 * 13 + d as u64 * 7 + seed) % 1000) as f32 / 1000.0;
                        x * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }

    /// Config shared by most tests: 4 shards per cluster, 2 clusters probed.
    fn test_config(num_clusters: usize) -> TopologyConfig {
        TopologyConfig {
            num_clusters,
            shards_per_cluster: 4,
            probe_clusters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_centroid_distance() {
        let centroid = Centroid::new(0, vec![1.0, 0.0, 0.0]);
        let query = vec![0.0, 0.0, 0.0];
        assert!((centroid.distance_squared(&query) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_topology_build() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));
        assert_eq!(topology.num_clusters(), 4);
        assert_eq!(topology.num_shards(), 16);
    }

    #[test]
    fn test_query_routing() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));
        let query = random_vectors(1, 128, 99)[0].clone();
        let decision = topology.route(&query);
        assert_eq!(decision.clusters_probed, 2);
        assert_eq!(decision.shards.len(), 8);
        assert!((decision.work_reduction(16) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_shard_cluster_mapping() {
        let centroids: Vec<Centroid> = (0..4)
            .map(|i| {
                let mut c = Centroid::new(i, vec![i as f32; 128]);
                c.shards = vec![i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3];
                c
            })
            .collect();
        let topology = ShardTopology::new(centroids, test_config(4));
        assert_eq!(topology.shard_cluster(0), Some(0));
        assert_eq!(topology.shard_cluster(5), Some(1));
        assert_eq!(topology.shard_cluster(10), Some(2));
        assert_eq!(topology.shard_cluster(15), Some(3));
    }

    #[test]
    fn test_adaptive_routing() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(8));
        let router = ShardRouter::new(topology);
        let query = random_vectors(1, 128, 99)[0].clone();
        let low_recall = router.route_adaptive(&query, 0.80);
        let high_recall = router.route_adaptive(&query, 0.99);
        assert!(high_recall.shards.len() >= low_recall.shards.len());
    }

    #[test]
    fn test_empty_topology() {
        let config = TopologyConfig::default();
        let topology = ShardTopology::empty(config);
        assert_eq!(topology.num_clusters(), 0);
        assert_eq!(topology.num_shards(), 0);
        let decision = topology.route(&[0.0, 0.0, 0.0]);
        assert!(decision.shards.is_empty());
    }

    #[test]
    fn test_stats_tracking() {
        let vectors = random_vectors(1000, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));
        for i in 0..10 {
            let query = random_vectors(1, 128, i)[0].clone();
            topology.route(&query);
        }
        let stats = topology.stats();
        assert_eq!(stats.queries_routed, 10);
        assert!(stats.avg_fanout > 0.0);
    }

    #[test]
    fn test_rebalance_detection() {
        // One cluster ten times larger than the others: ratio 10 > threshold 2.
        let centroids: Vec<Centroid> = (0..4)
            .map(|i| {
                let mut c = Centroid::new(i, vec![i as f32; 128]);
                c.shards = vec![i * 4];
                c.count = if i == 0 { 1000 } else { 100 };
                c
            })
            .collect();
        let config = TopologyConfig {
            rebalance_threshold: 2.0,
            ..Default::default()
        };
        let topology = ShardTopology::new(centroids, config);
        assert!(topology.needs_rebalance());
    }

    #[test]
    fn test_estimated_recall() {
        let vectors = random_vectors(100, 128, 42);
        let topology = ShardTopology::build_from_vectors(&vectors, test_config(4));
        let router = ShardRouter::new(topology);
        let query = random_vectors(1, 128, 99)[0].clone();
        let decision = router.topology().route(&query);
        let recall = router.estimated_recall(&decision);
        assert!(recall > 0.5 && recall < 1.0);
    }
}