use crate::error::{Result, StorageError};
use crate::types::RowId;
use crate::distance::DistanceMetric;
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::collections::HashSet;
use std::time::{SystemTime, UNIX_EPOCH};
use super::Candidate;
/// Tunable parameters for the in-memory fresh graph.
#[derive(Debug, Clone)]
pub struct FreshGraphConfig {
    /// Upper bound on the number of stored nodes; the graph refuses to grow past it.
    pub max_nodes: usize,
    /// Number of nearest neighbours requested for a node when it is linked in.
    pub max_degree: usize,
    /// Search list size. Not read anywhere in this file (presumably consumed by callers; TODO confirm).
    pub search_list_size: usize,
    /// Pruning factor. Not read anywhere in this file (TODO confirm intended use).
    pub alpha: f32,
    /// Memory budget, presumably in bytes. Not read anywhere in this file (TODO confirm).
    pub memory_threshold: usize,
}

impl Default for FreshGraphConfig {
    /// Builds the stock configuration: 2000 nodes, degree 64, search list 200,
    /// alpha 1.2 and a threshold of 200 * 1024 * 1024.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        FreshGraphConfig {
            max_nodes: 2000,
            max_degree: 64,
            search_list_size: 200,
            alpha: 1.2,
            memory_threshold: 200 * MIB,
        }
    }
}
#[derive(Clone)]
pub struct VectorNode {
pub vector: Vec<f32>,
pub neighbors: Vec<RowId>,
pub timestamp: u64,
pub deleted: bool, }
impl VectorNode {
pub fn new(vector: Vec<f32>) -> Self {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
Self {
vector,
neighbors: Vec::new(),
timestamp,
deleted: false, }
}
pub fn memory_size(&self) -> usize {
self.vector.len() * 4 + self.neighbors.len() * 8 + 16 + 1 }
}
/// Concurrent in-memory graph index over vectors keyed by `RowId`.
pub struct FreshVamanaGraph {
    /// All stored nodes, including those flagged as deleted.
    nodes: DashMap<RowId, VectorNode>,
    /// Id used as the entry point of the greedy neighbour search during `insert`.
    medoid: AtomicU64,
    /// Limits and tuning parameters supplied at construction.
    config: FreshGraphConfig,
    /// Distance function used for every vector comparison.
    metric: Arc<dyn DistanceMetric>,
    /// Running count of inserted vectors.
    insert_count: AtomicUsize,
    /// Memory counter; this file only initialises and resets it.
    memory_usage: AtomicUsize,
}
impl FreshVamanaGraph {
/// Creates an empty graph that uses `metric` for distances and `config` for limits.
///
/// The medoid id and both counters start at zero.
pub fn new(config: FreshGraphConfig, metric: Arc<dyn DistanceMetric>) -> Self {
    let nodes = DashMap::new();
    let medoid = AtomicU64::new(0);
    let insert_count = AtomicUsize::new(0);
    let memory_usage = AtomicUsize::new(0);

    Self {
        nodes,
        medoid,
        config,
        metric,
        insert_count,
        memory_usage,
    }
}
/// Stores `vector` under `id` and links it to its nearest stored neighbours.
///
/// The first node of an empty graph becomes the medoid. Later nodes pick up to
/// `max_degree` neighbours (exhaustive scan below 100 nodes, greedy search from the
/// medoid otherwise) and are added to each neighbour's list while that list has room.
/// Inserting an id that is already stored replaces that node.
///
/// # Errors
///
/// Returns `StorageError::ResourceExhausted` when `id` is new and the graph already
/// holds `max_nodes` nodes, and propagates any error from the greedy search.
pub fn insert(&self, id: RowId, vector: Vec<f32>) -> Result<()> {
    // Re-inserting a stored id (the `update` path) replaces the node in place and does
    // not grow the map, so the capacity limit only applies to genuinely new ids.
    let replacing = self.nodes.contains_key(&id);
    if !replacing && self.nodes.len() >= self.config.max_nodes {
        return Err(StorageError::ResourceExhausted(
            format!("Fresh graph is full ({})", self.config.max_nodes)
        ));
    }

    let node_count = self.nodes.len();
    if node_count == 0 {
        self.nodes.insert(id, VectorNode::new(vector));
        self.medoid.store(id, Ordering::Release);
        self.insert_count.fetch_add(1, Ordering::Relaxed);
        return Ok(());
    }

    let max_degree = self.config.max_degree;
    // When replacing, the old copy of `id` can occupy one result slot; ask for one
    // extra so the node still ends up with a full neighbour list after filtering.
    let wanted = max_degree.saturating_add(usize::from(replacing));
    let mut neighbors = if node_count < 100 {
        self.brute_force_knn(&vector, wanted)
    } else {
        let medoid = self.medoid.load(Ordering::Acquire);
        self.greedy_search_knn(&vector, medoid, wanted)?
    };
    // A node must never list itself as a neighbour.
    neighbors.retain(|&candidate| candidate != id);
    neighbors.truncate(max_degree);

    let mut node = VectorNode::new(vector);
    node.neighbors = neighbors.clone();
    self.nodes.insert(id, node);
    self.insert_count.fetch_add(1, Ordering::Relaxed);

    // Add the reverse edge wherever the neighbour still has spare degree.
    for neighbor_id in neighbors {
        if let Some(mut neighbor_node) = self.nodes.get_mut(&neighbor_id) {
            let has_room = neighbor_node.neighbors.len() < max_degree;
            if has_room && !neighbor_node.neighbors.contains(&id) {
                neighbor_node.neighbors.push(id);
            }
        }
    }
    Ok(())
}
pub fn batch_insert(&self, vectors: &[(RowId, Vec<f32>)]) -> Result<()> {
if vectors.is_empty() {
return Ok(());
}
if self.nodes.len() + vectors.len() > self.config.max_nodes {
return Err(StorageError::ResourceExhausted(
format!("Batch insert would exceed max_nodes: {} + {} > {}",
self.nodes.len(), vectors.len(), self.config.max_nodes)
));
}
let start = std::time::Instant::now();
let batch_size = vectors.len();
for (id, vector) in vectors {
let node = VectorNode::new(vector.clone());
self.nodes.insert(*id, node);
}
let insert_time = start.elapsed();
let graph_start = std::time::Instant::now();
self.batch_build_graph()?;
let graph_time = graph_start.elapsed();
self.insert_count.fetch_add(batch_size, Ordering::Relaxed);
eprintln!("[FreshGraph] 批量插入 {} 个向量: 插入={:?}, 建图={:?}, 总计={:?}",
batch_size, insert_time, graph_time, start.elapsed());
Ok(())
}
/// Recomputes the neighbour list of every stored node.
///
/// The first id yielded by the map becomes the medoid. Graphs under 1000 nodes use
/// the sequential builder, larger ones the parallel builder. Timing goes to stderr.
fn batch_build_graph(&self) -> Result<()> {
    let node_ids: Vec<_> = self.nodes.iter().map(|entry| *entry.key()).collect();
    let node_count = node_ids.len();

    match node_count {
        0 => return Ok(()),
        1 => {
            // A single node has nobody to link to; it simply becomes the medoid.
            self.medoid.store(node_ids[0], Ordering::Release);
            return Ok(());
        }
        _ => {}
    }

    // Provisional entry point: whichever id the map iteration produced first.
    self.medoid.store(node_ids[0], Ordering::Release);

    let max_degree = self.config.max_degree;
    let timer = std::time::Instant::now();
    if node_count >= 1000 {
        self.batch_build_graph_parallel(&node_ids, max_degree)?;
    } else {
        self.batch_build_graph_simple(&node_ids, max_degree)?;
    }
    eprintln!("[FreshGraph] 批量构建图完成:{} 个节点,耗时: {:?}",
        node_count, timer.elapsed());
    Ok(())
}
/// Sequential builder: gives each id in `node_ids` its `max_degree` closest other ids.
///
/// Ids that are no longer in the map are skipped, both as sources and as candidates.
fn batch_build_graph_simple(&self, node_ids: &[RowId], max_degree: usize) -> Result<()> {
    for &node_id in node_ids {
        let nearest: Vec<_> = {
            let source = match self.nodes.get(&node_id) {
                Some(guard) => guard,
                None => continue,
            };

            let mut scored = Vec::new();
            for &other_id in node_ids {
                if other_id == node_id {
                    continue;
                }
                if let Some(other) = self.nodes.get(&other_id) {
                    let dist = self.metric.distance(&source.vector, &other.vector);
                    scored.push((dist, other_id));
                }
            }
            scored.sort_unstable_by(|lhs, rhs| {
                lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal)
            });
            scored.into_iter().take(max_degree).map(|(_, id)| id).collect()
            // The read guard on `node_id` ends with this scope, before the write below.
        };

        if let Some(mut target) = self.nodes.get_mut(&node_id) {
            target.neighbors = nearest;
        }
    }
    Ok(())
}
/// Parallel builder: copies all vectors out of the map, computes every node's
/// `max_degree` closest other nodes with rayon, then writes the lists back sequentially.
fn batch_build_graph_parallel(&self, node_ids: &[RowId], max_degree: usize) -> Result<()> {
    use rayon::prelude::*;

    // Copy the vectors once so the distance phase never touches the map.
    let snapshot: Vec<_> = node_ids
        .par_iter()
        .filter_map(|&id| self.nodes.get(&id).map(|node| (id, node.vector.clone())))
        .collect();
    eprintln!("[FreshGraph] 预加载 {} 个向量", snapshot.len());

    let adjacency: Vec<_> = snapshot
        .par_iter()
        .map(|(node_id, vector)| {
            let mut scored = Vec::new();
            for (other_id, other_vec) in snapshot.iter() {
                if other_id != node_id {
                    scored.push((self.metric.distance(vector, other_vec), *other_id));
                }
            }
            scored.sort_unstable_by(|lhs, rhs| {
                lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal)
            });
            let nearest: Vec<_> = scored
                .into_iter()
                .take(max_degree)
                .map(|(_, id)| id)
                .collect();
            (*node_id, nearest)
        })
        .collect();
    eprintln!("[FreshGraph] 计算 {} 个节点的邻居(自动SIMD优化)", adjacency.len());

    for (node_id, nearest) in adjacency {
        if let Some(mut target) = self.nodes.get_mut(&node_id) {
            target.neighbors = nearest;
        }
    }
    Ok(())
}
/// Scans every stored node and returns the ids of the `k` closest to `query`,
/// nearest first. Nodes flagged as deleted are not excluded.
fn brute_force_knn(&self, query: &[f32], k: usize) -> Vec<RowId> {
    let mut scored: Vec<(RowId, f32)> = Vec::new();
    for entry in self.nodes.iter() {
        let dist = self.metric.distance(query, &entry.value().vector);
        scored.push((*entry.key(), dist));
    }

    // Incomparable distances (NaN) are treated as equal.
    scored.sort_by(|lhs, rhs| {
        lhs.1.partial_cmp(&rhs.1).unwrap_or(std::cmp::Ordering::Equal)
    });

    scored.into_iter().take(k).map(|(id, _)| id).collect()
}
/// Best-first graph walk from `start`; returns the ids of the `k` closest nodes
/// it evaluated, nearest first.
///
/// Every node whose distance was computed is remembered, so the result is populated
/// even when the walk exhausts the reachable part of the graph. The previous version
/// returned only what was left in the frontier, which is empty in exactly that case.
/// At most 1000 nodes are expanded. An unknown `start` yields an empty result.
fn greedy_search_knn(&self, query: &[f32], start: RowId, k: usize) -> Result<Vec<RowId>> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    const MAX_EXPANSIONS: usize = 1000;

    let mut visited = std::collections::HashSet::new();
    // `BinaryHeap` is a max-heap. Wrapping in `Reverse` pops the closest candidate first,
    // assuming `Candidate` orders by ascending distance as `graph_search_from_point` relies on.
    let mut frontier = BinaryHeap::new();
    // Every (distance, id) pair computed during the walk.
    let mut evaluated: Vec<(f32, RowId)> = Vec::new();

    if let Some(start_node) = self.nodes.get(&start) {
        let dist = self.metric.distance(query, &start_node.vector);
        frontier.push(Reverse(Candidate::new(start, dist)));
        evaluated.push((dist, start));
        visited.insert(start);
    }

    let mut expansions = 0;
    while let Some(Reverse(current)) = frontier.pop() {
        expansions += 1;
        if let Some(node) = self.nodes.get(&current.id) {
            for &neighbor_id in &node.neighbors {
                // `insert` returns false when the id was already seen.
                if !visited.insert(neighbor_id) {
                    continue;
                }
                if let Some(neighbor_node) = self.nodes.get(&neighbor_id) {
                    let dist = self.metric.distance(query, &neighbor_node.vector);
                    frontier.push(Reverse(Candidate::new(neighbor_id, dist)));
                    evaluated.push((dist, neighbor_id));
                }
            }
        }
        if expansions >= MAX_EXPANSIONS {
            break;
        }
    }

    // Incomparable distances (NaN) are treated as equal.
    evaluated.sort_by(|lhs, rhs| {
        lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal)
    });
    evaluated.truncate(k);
    Ok(evaluated.into_iter().map(|(_, id)| id).collect())
}
/// Returns up to `k` candidates closest to `query`.
///
/// An empty graph yields an empty list, graphs of at most 50 nodes are scanned
/// linearly, and larger graphs are searched through the graph with `ef` as the
/// requested search width.
pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<Candidate>> {
    match self.nodes.len() {
        0 => Ok(Vec::new()),
        1..=50 => self.linear_search(query, k),
        _ => self.graph_search(query, k, ef),
    }
}
/// Exhaustive search: scores every node not flagged as deleted and returns the
/// `k` closest, nearest first.
fn linear_search(&self, query: &[f32], k: usize) -> Result<Vec<Candidate>> {
    let mut hits: Vec<Candidate> = Vec::new();
    for entry in self.nodes.iter() {
        let node = entry.value();
        if node.deleted {
            continue;
        }
        let dist = self.metric.distance(query, &node.vector);
        hits.push(Candidate::new(*entry.key(), dist));
    }

    // Incomparable distances (NaN) are treated as equal.
    hits.sort_by(|lhs, rhs| {
        lhs.distance
            .partial_cmp(&rhs.distance)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    hits.truncate(k);
    Ok(hits)
}
/// Graph search from several entry points, merged into one list of at most `k`
/// distinct candidates.
///
/// The search width is `ef` raised to at least `3 * k` and 50, then capped at the
/// node count. All entry points share one visited set.
fn graph_search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<Candidate>> {
    use std::collections::{BinaryHeap, HashSet};

    let beam_width = ef.max(k * 3).max(50).min(self.nodes.len());

    let mut shared_visited = HashSet::new();
    let mut merged = BinaryHeap::new();
    for start_id in self.get_start_points() {
        let found = self.graph_search_from_point(
            query,
            k,
            beam_width,
            start_id,
            &mut shared_visited,
        )?;
        found.into_iter().for_each(|candidate| merged.push(candidate));
    }

    // Different entry points can report the same id; keep the first occurrence
    // in sorted order.
    let mut emitted = HashSet::new();
    let top: Vec<Candidate> = merged
        .into_sorted_vec()
        .into_iter()
        .filter(|candidate| emitted.insert(candidate.id))
        .take(k)
        .collect();
    Ok(top)
}
/// Picks the entry points for a graph search: every id when two or fewer are stored,
/// otherwise the first id and the id halfway through the map's iteration order.
fn get_start_points(&self) -> Vec<RowId> {
    const TARGET_STARTS: usize = 2;

    let ids: Vec<_> = self.nodes.iter().map(|entry| *entry.key()).collect();
    if ids.len() <= TARGET_STARTS {
        // Covers the empty map as well: nothing stored means no entry points.
        return ids;
    }

    let stride = ids.len() / TARGET_STARTS;
    (0..TARGET_STARTS).map(|slot| ids[slot * stride]).collect()
}
/// Bounded best-first search from `start_id`; returns the live (not deleted)
/// candidates it kept, nearest first.
///
/// `ef` (raised to at least `2 * k`) caps how many candidates are kept.
/// `global_visited` is shared with the caller, so nodes scored by an earlier search
/// are not scored again. An unknown `start_id` yields an empty result.
fn graph_search_from_point(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    start_id: RowId,
    global_visited: &mut HashSet<RowId>,
) -> Result<Vec<Candidate>> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    let ef = ef.max(k * 2);

    // Only the distance is needed from the start node, so its map guard is released
    // here rather than being held for the whole search.
    let start_dist = match self.nodes.get(&start_id) {
        Some(start_node) => self.metric.distance(query, &start_node.vector),
        None => return Ok(Vec::new()),
    };

    // `frontier` pops the closest unexpanded candidate; `best` keeps the candidates
    // retained so far with the furthest one on top.
    let mut frontier = BinaryHeap::new();
    frontier.push(Reverse(Candidate::new(start_id, start_dist)));
    let mut best = BinaryHeap::new();
    best.push(Candidate::new(start_id, start_dist));
    global_visited.insert(start_id);

    while let Some(Reverse(current)) = frontier.pop() {
        // Stop once the result set is full and the closest remaining candidate is
        // already worse than the worst retained one.
        if best.len() >= ef {
            if let Some(furthest) = best.peek() {
                if current.distance > furthest.distance {
                    break;
                }
            }
        }

        if let Some(node) = self.nodes.get(&current.id) {
            for &neighbor_id in &node.neighbors {
                // `insert` returns false when the id was already seen.
                if !global_visited.insert(neighbor_id) {
                    continue;
                }
                if let Some(neighbor_node) = self.nodes.get(&neighbor_id) {
                    let dist = self.metric.distance(query, &neighbor_node.vector);
                    let admit = best.len() < ef
                        || best.peek().map_or(false, |furthest| dist < furthest.distance);
                    if admit {
                        frontier.push(Reverse(Candidate::new(neighbor_id, dist)));
                        best.push(Candidate::new(neighbor_id, dist));
                        if best.len() > ef {
                            best.pop();
                        }
                    }
                }
            }
        }
    }

    // Tombstoned or vanished nodes are dropped only here, so they can still be walked through.
    let results: Vec<Candidate> = best
        .into_sorted_vec()
        .into_iter()
        .filter(|candidate| {
            self.nodes
                .get(&candidate.id)
                .map(|node| !node.deleted)
                .unwrap_or(false)
        })
        .collect();
    Ok(results)
}
/// Reports whether the graph has reached its configured node limit.
pub fn should_flush(&self) -> bool {
    let capacity = self.config.max_nodes;
    self.nodes.len() >= capacity
}
/// Returns the number of stored nodes, including nodes flagged as deleted.
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn memory_usage(&self) -> usize {
self.memory_usage.load(Ordering::Relaxed)
}
/// Returns a snapshot of the node count, the insert counter and the estimated memory use.
///
/// Memory is summed from `VectorNode::memory_size` over the stored nodes, because the
/// private `memory_usage` counter is never incremented in this file and would always
/// report 0.
pub fn stats(&self) -> FreshGraphStats {
    let memory_usage: usize = self
        .nodes
        .iter()
        .map(|entry| entry.value().memory_size())
        .sum();

    FreshGraphStats {
        node_count: self.nodes.len(),
        insert_count: self.insert_count.load(Ordering::Relaxed),
        memory_usage,
    }
}
/// Returns a copy of every stored node paired with its id, ordered by ascending id.
pub fn export_nodes(&self) -> Result<Vec<(RowId, VectorNode)>> {
    let mut exported = Vec::new();
    for entry in self.nodes.iter() {
        exported.push((*entry.key(), entry.value().clone()));
    }
    exported.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
    Ok(exported)
}
/// Returns the id currently recorded as the graph's medoid.
///
/// The value is 0 on a new or cleared graph, which need not match a stored node.
pub fn medoid(&self) -> RowId {
self.medoid.load(Ordering::Acquire)
}
/// Returns `true` when no node is stored; nodes flagged as deleted still count as stored.
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
/// Flags the node stored under `id` as deleted; the node itself stays in the map.
///
/// # Errors
///
/// Returns `StorageError::InvalidData` when no node is stored under `id`.
pub fn delete(&self, id: RowId) -> Result<()> {
    match self.nodes.get_mut(&id) {
        Some(mut node) => {
            node.deleted = true;
            Ok(())
        }
        None => Err(StorageError::InvalidData(format!("Node {} not found", id))),
    }
}
pub fn update(&self, id: RowId, vector: Vec<f32>) -> Result<()> {
self.delete(id)?;
self.insert(id, vector)?;
Ok(())
}
/// Removes every node and resets the medoid id and both counters to zero.
///
/// Always returns `Ok(())`.
pub fn clear(&mut self) -> Result<()> {
    self.nodes.clear();

    // Exclusive access (`&mut self`) means no other thread can observe these resets midway.
    self.insert_count.store(0, Ordering::Relaxed);
    self.memory_usage.store(0, Ordering::Relaxed);
    self.medoid.store(0, Ordering::Release);

    Ok(())
}
}
/// Point-in-time counters describing a fresh graph, as returned by its `stats` method.
#[derive(Debug, Clone)]
pub struct FreshGraphStats {
    /// Number of stored nodes at the time of the snapshot.
    pub node_count: usize,
    /// Value of the graph's insert counter at the time of the snapshot.
    pub insert_count: usize,
    /// Memory figure reported by the graph at the time of the snapshot.
    pub memory_usage: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::Euclidean;

    /// Inserts 50 constant-valued vectors and checks that the query equal to
    /// vector 25 gets five results with id 25 ranked first.
    #[test]
    fn test_insert_and_search() {
        const DIM: usize = 128;

        let metric = Arc::new(Euclidean);
        let graph = FreshVamanaGraph::new(FreshGraphConfig::default(), metric);

        (0..50u64).for_each(|i| {
            graph.insert(i, vec![i as f32; DIM]).unwrap();
        });
        assert_eq!(graph.node_count(), 50);

        let query = vec![25.0; DIM];
        let hits = graph.search(&query, 5, 10).unwrap();
        assert_eq!(hits.len(), 5);
        assert_eq!(hits[0].id, 25);
    }
}