use pgrx::prelude::*;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use parking_lot::RwLock;
use super::hnsw::{HnswIndex, NodeId};
use crate::distance::DistanceMetric;
/// Coordination state shared by every worker of one parallel index scan.
///
/// `#[repr(C)]` pins a stable field layout — NOTE(review): presumably because
/// this struct is meant to live in PostgreSQL dynamic shared memory; confirm.
/// The mutable fields are atomics so workers can update them concurrently.
#[repr(C)]
pub struct RuHnswSharedState {
    /// Number of workers participating in the scan.
    pub num_workers: u32,
    /// Work-stealing cursor: index of the next partition to hand out.
    pub next_partition: AtomicU32,
    /// Total number of partitions the scan was split into.
    pub total_partitions: u32,
    /// Dimensionality of the indexed vectors.
    pub dimensions: u32,
    /// Number of nearest neighbors requested by the query.
    pub k: usize,
    /// Search beam width forwarded to the HNSW search.
    pub ef_search: usize,
    /// Distance metric used for this scan.
    pub metric: DistanceMetric,
    /// Count of workers that have finished their scan.
    pub completed_workers: AtomicU32,
    /// Running total of results reported by all workers.
    pub total_results: AtomicUsize,
}
impl RuHnswSharedState {
    /// Builds the shared coordination state with all counters zeroed.
    pub fn new(
        num_workers: u32,
        total_partitions: u32,
        dimensions: u32,
        k: usize,
        ef_search: usize,
        metric: DistanceMetric,
    ) -> Self {
        Self {
            num_workers,
            total_partitions,
            dimensions,
            k,
            ef_search,
            metric,
            next_partition: AtomicU32::new(0),
            completed_workers: AtomicU32::new(0),
            total_results: AtomicUsize::new(0),
        }
    }

    /// Atomically claims the next unscanned partition.
    ///
    /// Returns `None` once every partition has been handed out; the cursor
    /// keeps incrementing past the end, which is harmless.
    pub fn get_next_partition(&self) -> Option<u32> {
        let claimed = self.next_partition.fetch_add(1, AtomicOrdering::SeqCst);
        (claimed < self.total_partitions).then_some(claimed)
    }

    /// Records that one worker has finished scanning.
    pub fn mark_completed(&self) {
        self.completed_workers.fetch_add(1, AtomicOrdering::SeqCst);
    }

    /// True once every worker has called [`Self::mark_completed`].
    pub fn all_completed(&self) -> bool {
        let done = self.completed_workers.load(AtomicOrdering::SeqCst);
        done >= self.num_workers
    }

    /// Adds `count` to the global result tally.
    pub fn add_results(&self, count: usize) {
        self.total_results.fetch_add(count, AtomicOrdering::SeqCst);
    }
}
/// Per-worker scan descriptor: the worker's handle to the shared state, its
/// identity, the query vector, and the candidates it has gathered so far.
pub struct RuHnswParallelScanDesc {
    /// Shared coordination state (partition cursor, counters, parameters).
    pub shared: Arc<RwLock<RuHnswSharedState>>,
    /// Zero-based id of this worker.
    pub worker_id: u32,
    /// Locally accumulated `(distance, item pointer)` candidates.
    pub local_results: Vec<(f32, ItemPointer)>,
    /// Query vector being searched.
    pub query: Vec<f32>,
}
impl RuHnswParallelScanDesc {
    /// Creates a scan descriptor for one worker with an empty result buffer.
    pub fn new(
        shared: Arc<RwLock<RuHnswSharedState>>,
        worker_id: u32,
        query: Vec<f32>,
    ) -> Self {
        Self {
            shared,
            worker_id,
            local_results: Vec::new(),
            query,
        }
    }

    /// Claims partitions from the shared cursor until exhausted, scans each
    /// one, then sorts the accumulated candidates by ascending distance,
    /// keeps the `k` nearest, and reports completion to the shared state.
    pub fn execute_scan(&mut self, index: &HnswIndex) {
        loop {
            // Hold the read lock only long enough to claim a partition.
            let partition_id = {
                let shared = self.shared.read();
                shared.get_next_partition()
            };
            let Some(partition_id) = partition_id else { break };
            let partition_results = self.scan_partition(index, partition_id);
            self.local_results.extend(partition_results);
        }
        self.local_results
            .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
        // Single lock acquisition for `k` instead of a separate read/drop pair.
        let k = self.shared.read().k;
        self.local_results.truncate(k);
        let shared = self.shared.read();
        shared.add_results(self.local_results.len());
        shared.mark_completed();
    }

    /// Scans one partition of the index for the query's nearest neighbors.
    fn scan_partition(
        &self,
        index: &HnswIndex,
        partition_id: u32,
    ) -> Vec<(f32, ItemPointer)> {
        // Read all parameters under one lock acquisition (previously three
        // separate read/drop cycles for k/ef_search and total_partitions).
        let (k, ef_search, partitions) = {
            let shared = self.shared.read();
            (shared.k, shared.ef_search, shared.total_partitions as usize)
        };
        let total_nodes = index.len();
        // Ceiling division so the last partition absorbs the remainder.
        let partition_size = (total_nodes + partitions - 1) / partitions;
        let start_idx = partition_id as usize * partition_size;
        if start_idx >= total_nodes {
            return Vec::new();
        }
        // NOTE(review): the partition bounds computed above are never passed
        // to `index.search`, so every worker searches the FULL graph for each
        // partition it claims — partitioning currently only deduplicates the
        // early-return above. TODO: restrict the search to
        // [start_idx, start_idx + partition_size) once HnswIndex exposes a
        // ranged search.
        let results = index.search(&self.query, k, Some(ef_search));
        results
            .into_iter()
            .map(|(node_id, distance)| (distance, create_item_pointer(node_id)))
            .collect()
    }
}
/// A `(block, offset)` pair identifying a heap tuple, mirroring the shape of
/// PostgreSQL's `ItemPointerData`. `#[repr(C)]` keeps the layout stable for
/// interchange with C-side code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct ItemPointer {
    pub block_number: u32,
    pub offset_number: u16,
}

impl ItemPointer {
    /// Constructs an item pointer from its raw block and offset components.
    pub fn new(block_number: u32, offset_number: u16) -> Self {
        ItemPointer {
            offset_number,
            block_number,
        }
    }
}
fn create_item_pointer(node_id: NodeId) -> ItemPointer {
let block = (node_id / 8191) as u32; let offset = (node_id % 8191) as u16 + 1;
ItemPointer::new(block, offset)
}
/// Recommends a parallel worker count for an index scan, or 0 to run serially.
///
/// Small indexes (fewer than 100 pages or 10k tuples) never go parallel.
/// Otherwise the page count drives a base worker count (one per ~1000 pages),
/// which is scaled up for expensive searches (large `k` / `ef_search`) and
/// clamped to the backend's worker limit.
pub fn ruhnsw_estimate_parallel_workers(
    index_pages: i32,
    index_tuples: i64,
    k: i32,
    ef_search: i32,
) -> i32 {
    if index_pages < 100 || index_tuples < 10000 {
        return 0;
    }
    let max_workers = get_max_parallel_workers();
    let by_size = (index_pages / 1000).min(max_workers);
    // The more demanding of the two search parameters decides the boost.
    let load = ef_search.max(k);
    let complexity_factor = if load > 100 {
        2.0
    } else if load > 50 {
        1.5
    } else {
        1.0
    };
    ((by_size as f32 * complexity_factor) as i32)
        .min(max_workers)
        .max(0)
}
/// Upper bound on parallel workers for a single index scan.
///
/// NOTE(review): hardcoded placeholder — presumably this should consult the
/// `max_parallel_workers_per_gather` GUC via pgrx; confirm before relying on
/// it in production.
fn get_max_parallel_workers() -> i32 {
    const DEFAULT_MAX_PARALLEL_WORKERS: i32 = 4;
    DEFAULT_MAX_PARALLEL_WORKERS
}
/// Estimates how many scan partitions to create for `num_workers` workers.
///
/// Targets roughly three partitions per worker (slack for work stealing),
/// capped at one partition per ~10k tuples, and always at least 1.
pub fn estimate_partitions(num_workers: i32, total_tuples: i64) -> u32 {
    let base_partitions = num_workers * 3;
    let tuples_per_partition: i64 = 10000;
    // BUGFIX: use ceiling division. Truncating division under-counted the
    // data-driven bound — e.g. 15k tuples produced a single partition,
    // serializing a 2-worker scan (test_partition_estimation expects >= 2).
    let partitions_by_size =
        ((total_tuples + tuples_per_partition - 1) / tuples_per_partition) as i32;
    base_partitions.min(partitions_by_size).max(1) as u32
}
/// A search candidate pairing a distance with the heap tuple it points at.
#[derive(Debug, Clone, Copy)]
pub struct KnnNeighbor {
    /// Distance from the query vector.
    pub distance: f32,
    /// Tuple identified by this candidate.
    pub item_pointer: ItemPointer,
}
// Identity is the tuple pointed at, not the distance.
// NOTE(review): this makes Eq inconsistent with Ord (equal pointers can
// compare non-Equal by distance); acceptable for heap use, but do not rely
// on Eq/Ord agreeing.
impl PartialEq for KnnNeighbor {
    fn eq(&self, other: &Self) -> bool {
        self.item_pointer == other.item_pointer
    }
}
impl Eq for KnnNeighbor {}
impl PartialOrd for KnnNeighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for KnnNeighbor {
    /// Natural ascending order by distance: a `BinaryHeap<KnnNeighbor>` is a
    /// max-heap whose root is the WORST (largest-distance) candidate.
    fn cmp(&self, other: &Self) -> Ordering {
        // BUGFIX: was `other.distance.partial_cmp(&self.distance)` (reversed),
        // which turned the heap in merge_knn_results into a min-heap on
        // distance, so the eviction step popped the BEST candidate and kept
        // whatever arrived first instead of the k nearest.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
    }
}
pub fn merge_knn_results(
worker_results: &[Vec<(f32, ItemPointer)>],
k: usize,
) -> Vec<(f32, ItemPointer)> {
if worker_results.is_empty() {
return Vec::new();
}
let mut heap: BinaryHeap<KnnNeighbor> = BinaryHeap::new();
for results in worker_results {
for &(distance, item_pointer) in results {
let neighbor = KnnNeighbor {
distance,
item_pointer,
};
if heap.len() < k {
heap.push(neighbor);
} else if let Some(worst) = heap.peek() {
if neighbor.distance < worst.distance {
heap.pop();
heap.push(neighbor);
}
}
}
}
let mut results: Vec<(f32, ItemPointer)> = heap
.into_iter()
.map(|n| (n.distance, n.item_pointer))
.collect();
results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
results
}
/// K-way merge of per-worker result lists into the global top-`k`.
///
/// Assumes each worker's list is already sorted by ascending distance (as
/// produced by `execute_scan`); repeatedly takes the smallest head element
/// across all lists until `k` results are gathered or every list is drained.
pub fn merge_knn_results_tournament(
    worker_results: &[Vec<(f32, ItemPointer)>],
    k: usize,
) -> Vec<(f32, ItemPointer)> {
    // Trivial cases: nothing to merge, or a single pre-sorted list.
    match worker_results {
        [] => return Vec::new(),
        [only] => return only.iter().take(k).copied().collect(),
        _ => {}
    }
    let mut positions = vec![0usize; worker_results.len()];
    let mut merged = Vec::with_capacity(k);
    while merged.len() < k {
        // Pick the worker whose current head has the smallest distance;
        // ties go to the lowest worker id (strict `<` comparison).
        let mut winner: Option<usize> = None;
        let mut winner_distance = f32::MAX;
        for (worker, &pos) in positions.iter().enumerate() {
            if let Some(&(distance, _)) = worker_results[worker].get(pos) {
                if distance < winner_distance {
                    winner_distance = distance;
                    winner = Some(worker);
                }
            }
        }
        let Some(worker) = winner else { break };
        merged.push(worker_results[worker][positions[worker]]);
        positions[worker] += 1;
    }
    merged
}
/// Drives a parallel scan: owns the shared state handed to every worker and
/// collects each worker's local result list for the final merge.
pub struct ParallelScanCoordinator {
    /// State shared with every worker (partition cursor, counters, params).
    pub shared_state: Arc<RwLock<RuHnswSharedState>>,
    /// Per-worker sorted top-k lists, indexed by worker id.
    pub worker_results: Vec<Vec<(f32, ItemPointer)>>,
}
impl ParallelScanCoordinator {
    /// Builds a coordinator and the shared state its workers will consult.
    pub fn new(
        num_workers: u32,
        total_partitions: u32,
        dimensions: u32,
        k: usize,
        ef_search: usize,
        metric: DistanceMetric,
    ) -> Self {
        let state = RuHnswSharedState::new(
            num_workers,
            total_partitions,
            dimensions,
            k,
            ef_search,
            metric,
        );
        Self {
            shared_state: Arc::new(RwLock::new(state)),
            worker_results: Vec::with_capacity(num_workers as usize),
        }
    }

    /// Fans the query out to one rayon task per worker, gathers every
    /// worker's locally sorted top-k list, and merges them into the global
    /// top-k (sorted by ascending distance).
    pub fn execute_parallel_scan(
        &mut self,
        index: &HnswIndex,
        query: Vec<f32>,
    ) -> Vec<(f32, ItemPointer)> {
        use rayon::prelude::*;
        let (worker_count, k) = {
            let state = self.shared_state.read();
            (state.num_workers, state.k)
        };
        self.worker_results = (0..worker_count)
            .into_par_iter()
            .map(|worker_id| {
                let mut desc = RuHnswParallelScanDesc::new(
                    Arc::clone(&self.shared_state),
                    worker_id,
                    query.clone(),
                );
                desc.execute_scan(index);
                desc.local_results
            })
            .collect();
        merge_knn_results_tournament(&self.worker_results, k)
    }

    /// Snapshot of the scan's progress counters (for monitoring output).
    pub fn get_stats(&self) -> ParallelScanStats {
        let state = self.shared_state.read();
        ParallelScanStats {
            num_workers: state.num_workers,
            total_partitions: state.total_partitions,
            completed_workers: state.completed_workers.load(AtomicOrdering::SeqCst),
            total_results: state.total_results.load(AtomicOrdering::SeqCst),
        }
    }
}
/// Point-in-time snapshot of a parallel scan's progress counters.
#[derive(Debug, Clone)]
pub struct ParallelScanStats {
    /// Workers participating in the scan.
    pub num_workers: u32,
    /// Partitions the scan was split into.
    pub total_partitions: u32,
    /// Workers that had finished at snapshot time.
    pub completed_workers: u32,
    /// Results reported across all workers at snapshot time.
    pub total_results: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shared cursor hands out partitions 0..total exactly once, then None.
    #[test]
    fn test_shared_state_partitioning() {
        let state = RuHnswSharedState::new(
            4, 16, 128, 10, 40, DistanceMetric::Euclidean,
        );
        assert_eq!(state.get_next_partition(), Some(0));
        assert_eq!(state.get_next_partition(), Some(1));
        assert_eq!(state.get_next_partition(), Some(2));
        // Drain the remaining partitions (3..=15).
        for _ in 3..16 {
            state.get_next_partition();
        }
        assert_eq!(state.get_next_partition(), None);
    }

    /// Small indexes run serially; complex searches get at least as many
    /// workers as simple ones on the same index.
    #[test]
    fn test_worker_estimation() {
        assert_eq!(ruhnsw_estimate_parallel_workers(50, 5000, 10, 40), 0);
        let workers = ruhnsw_estimate_parallel_workers(2000, 100000, 10, 40);
        assert!(workers > 0 && workers <= 4);
        let workers_complex = ruhnsw_estimate_parallel_workers(5000, 500000, 100, 200);
        let workers_simple = ruhnsw_estimate_parallel_workers(5000, 500000, 10, 40);
        assert!(workers_complex >= workers_simple);
    }

    /// Pins the intended semantics of merge_knn_results: the global k
    /// SMALLEST distances across all workers, sorted ascending.
    #[test]
    fn test_merge_knn_results() {
        let worker1 = vec![
            (0.1, ItemPointer::new(1, 1)),
            (0.3, ItemPointer::new(1, 3)),
            (0.5, ItemPointer::new(1, 5)),
        ];
        let worker2 = vec![
            (0.2, ItemPointer::new(2, 2)),
            (0.4, ItemPointer::new(2, 4)),
            (0.6, ItemPointer::new(2, 6)),
        ];
        let worker3 = vec![
            (0.15, ItemPointer::new(3, 1)),
            (0.35, ItemPointer::new(3, 3)),
        ];
        let results = merge_knn_results(&[worker1, worker2, worker3], 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, 0.1);
        assert_eq!(results[1].0, 0.15);
        assert_eq!(results[2].0, 0.2);
        assert_eq!(results[3].0, 0.3);
        assert_eq!(results[4].0, 0.35);
    }

    /// The tournament merge interleaves pre-sorted per-worker lists into a
    /// single ascending top-k.
    #[test]
    fn test_merge_tournament() {
        let worker1 = vec![
            (0.1, ItemPointer::new(1, 1)),
            (0.4, ItemPointer::new(1, 4)),
        ];
        let worker2 = vec![
            (0.2, ItemPointer::new(2, 2)),
            (0.5, ItemPointer::new(2, 5)),
        ];
        let worker3 = vec![
            (0.3, ItemPointer::new(3, 3)),
            (0.6, ItemPointer::new(3, 6)),
        ];
        let results = merge_knn_results_tournament(&[worker1, worker2, worker3], 4);
        assert_eq!(results.len(), 4);
        assert_eq!(results[0].0, 0.1);
        assert_eq!(results[1].0, 0.2);
        assert_eq!(results[2].0, 0.3);
        assert_eq!(results[3].0, 0.4);
    }

    /// Partition count scales with workers and data size, never below 1.
    /// NOTE(review): 15k tuples must yield >= 2 partitions, which requires
    /// ceiling division in estimate_partitions.
    #[test]
    fn test_partition_estimation() {
        let partitions = estimate_partitions(2, 15000);
        assert!(partitions >= 2 && partitions <= 6);
        let partitions_large = estimate_partitions(4, 500000);
        assert!(partitions_large > partitions);
    }

    /// Node ids pack 8191 per block with 1-based offsets.
    #[test]
    fn test_item_pointer_creation() {
        let ip1 = create_item_pointer(0);
        assert_eq!(ip1.block_number, 0);
        assert_eq!(ip1.offset_number, 1);
        // First id of the second block.
        let ip2 = create_item_pointer(8191);
        assert_eq!(ip2.block_number, 1);
        assert_eq!(ip2.offset_number, 1);
        let ip3 = create_item_pointer(100);
        assert_eq!(ip3.block_number, 0);
        assert_eq!(ip3.offset_number, 101);
    }
}