use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use crate::distance::DistanceMetric;
use crate::error::Result;
use super::mmap_store::MmapVectorStorage;
/// Tuning parameters for [`AppendGraph`].
#[derive(Clone, Debug)]
pub struct AppendGraphConfig {
    /// Number of nearest neighbors linked when a node is inserted through
    /// `AppendGraph::insert_with_candidates`.
    pub m: usize,
    /// Upper bound on the length of any node's in-memory neighbor list; when it is
    /// exceeded, only the `m_max` closest neighbors are kept.
    pub m_max: usize,
    /// Distance metric associated with the graph.
    /// NOTE(review): not read inside this module — presumably consumed by callers or the
    /// vector storage; confirm.
    pub metric: DistanceMetric,
}

impl Default for AppendGraphConfig {
    /// `m = 16`, `m_max = 2 * m = 32`, cosine distance.
    fn default() -> Self {
        let m = 16;
        Self {
            metric: DistanceMetric::Cosine,
            m_max: 2 * m,
            m,
        }
    }
}
/// A node id paired with its distance to the current query.
///
/// The ordering is *reversed* on distance: a candidate with a smaller distance compares as
/// greater. As a result, `BinaryHeap<Candidate>` pops the nearest candidate first, and
/// `BinaryHeap<Reverse<Candidate>>` keeps the farthest one on top.
///
/// NaN distances are ranked as farther than every number (and equal to each other). This
/// keeps the ordering total, as `Ord` and `BinaryHeap` require.
#[derive(Clone)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    /// Defined through [`Ord::cmp`] so that equality always agrees with the ordering.
    /// Raw `f32 ==` would report NaN != NaN while `cmp` reports `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare `other` against `self` to reverse the natural distance order.
        // `partial_cmp` only fails when a NaN is involved; NaN then counts as the largest
        // distance, so a NaN candidate has the lowest priority.
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or_else(|| other.distance.is_nan().cmp(&self.distance.is_nan()))
    }
}
/// Neighbor graph backed by an append-only edge log.
///
/// Each added edge is written as a 12-byte little-endian record (`from: u32`, `to: u32`,
/// `distance: f32`) and mirrored into an in-memory adjacency map. Graphs built with
/// [`AppendGraph::in_memory`] keep no log at all.
pub struct AppendGraph {
    /// Tuning parameters (neighbor counts, metric).
    config: AppendGraphConfig,
    /// Adjacency lists: source node id -> `(neighbor id, distance)` pairs, capped at
    /// `config.m_max` entries.
    neighbors: HashMap<u32, Vec<(u32, f32)>>,
    /// Number of edge records added or loaded, duplicates included.
    edge_count: usize,
    /// Location of the edge log; an empty path for in-memory graphs.
    path: PathBuf,
    /// Buffered appender for the edge log; `None` when nothing is persisted.
    edge_file: Option<BufWriter<File>>,
}
impl AppendGraph {
pub fn new<P: AsRef<Path>>(path: P, config: AppendGraphConfig) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let exists = path.exists();
let mut graph = Self {
path: path.clone(),
edge_file: None,
neighbors: HashMap::new(),
config,
edge_count: 0,
};
if exists {
graph.load_edges()?;
}
let file = OpenOptions::new().create(true).append(true).open(&path)?;
graph.edge_file = Some(BufWriter::new(file));
Ok(graph)
}
pub fn in_memory(config: AppendGraphConfig) -> Self {
Self {
path: PathBuf::new(),
edge_file: None,
neighbors: HashMap::new(),
config,
edge_count: 0,
}
}
fn load_edges(&mut self) -> Result<()> {
let file = File::open(&self.path)?;
let mut reader = BufReader::new(file);
let mut buf = [0u8; 12];
while reader.read_exact(&mut buf).is_ok() {
let from = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
let to = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]);
let distance = f32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
self.neighbors.entry(from).or_default().push((to, distance));
self.edge_count += 1;
}
for neighbors in self.neighbors.values_mut() {
neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
neighbors.truncate(self.config.m_max);
}
Ok(())
}
pub fn add_edge(&mut self, from: u32, to: u32, distance: f32) -> Result<()> {
if let Some(ref mut file) = self.edge_file {
file.write_all(&from.to_le_bytes())?;
file.write_all(&to.to_le_bytes())?;
file.write_all(&distance.to_le_bytes())?;
}
let neighbors = self.neighbors.entry(from).or_default();
if !neighbors.iter().any(|(n, _)| *n == to) {
neighbors.push((to, distance));
if neighbors.len() > self.config.m_max {
neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
neighbors.truncate(self.config.m_max);
}
}
self.edge_count += 1;
Ok(())
}
pub fn add_bidirectional_edge(&mut self, a: u32, b: u32, distance: f32) -> Result<()> {
self.add_edge(a, b, distance)?;
self.add_edge(b, a, distance)?;
Ok(())
}
#[inline]
pub fn get_neighbors(&self, node: u32) -> &[(u32, f32)] {
self.neighbors
.get(&node)
.map(|v| v.as_slice())
.unwrap_or(&[])
}
pub fn search(
&self,
storage: &MmapVectorStorage,
query: &[f32],
entry_points: &[u32],
k: usize,
ef: usize,
) -> Vec<(u32, f32)> {
if entry_points.is_empty() {
return Vec::new();
}
let mut visited = HashSet::with_capacity(ef * 2);
let mut candidates = BinaryHeap::new(); let mut results = BinaryHeap::new();
let max_initial = entry_points.len().min(ef);
for &ep in entry_points.iter().take(max_initial) {
if !visited.contains(&ep) {
visited.insert(ep);
if let Some(dist) = storage.distance_to_id(query, ep) {
candidates.push(Candidate {
id: ep,
distance: dist,
});
results.push(std::cmp::Reverse(Candidate {
id: ep,
distance: dist,
}));
}
}
}
let mut stale_count = 0;
let max_stale = 10;
while let Some(current) = candidates.pop() {
let worst_dist = if results.len() >= ef {
results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX)
} else {
f32::MAX
};
if current.distance > worst_dist {
stale_count += 1;
if stale_count >= max_stale {
break;
}
continue;
}
stale_count = 0;
for &(neighbor_id, _) in self.get_neighbors(current.id) {
if !visited.contains(&neighbor_id) {
visited.insert(neighbor_id);
if let Some(dist) = storage.distance_to_id(query, neighbor_id) {
let should_add = results.len() < ef || dist < worst_dist;
if should_add {
candidates.push(Candidate {
id: neighbor_id,
distance: dist,
});
results.push(std::cmp::Reverse(Candidate {
id: neighbor_id,
distance: dist,
}));
while results.len() > ef {
results.pop();
}
}
}
}
}
}
let mut final_results: Vec<(u32, f32)> = results
.into_iter()
.map(|r| (r.0.id, r.0.distance))
.collect();
final_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
final_results.truncate(k);
final_results
}
pub fn insert_with_candidates(
&mut self,
storage: &MmapVectorStorage,
node_id: u32,
candidates: &[u32],
) -> Result<()> {
let vector = match storage.get_vector_by_id(node_id) {
Some(v) => v,
None => return Ok(()),
};
let mut neighbor_candidates: Vec<(u32, f32)> = candidates
.iter()
.filter_map(|&cid| storage.distance_to_id(&vector, cid).map(|d| (cid, d)))
.collect();
if neighbor_candidates.len() < self.config.m && !self.neighbors.is_empty() {
let sample_size = (self.config.m * 4).min(storage.len());
let step = (storage.len() / sample_size).max(1);
for id in (0..storage.len() as u32).step_by(step) {
if id == node_id {
continue;
}
if neighbor_candidates.iter().any(|(c, _)| *c == id) {
continue;
}
if let Some(dist) = storage.distance_to_id(&vector, id) {
neighbor_candidates.push((id, dist));
}
if neighbor_candidates.len() >= self.config.m * 4 {
break;
}
}
}
neighbor_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
neighbor_candidates.truncate(self.config.m);
for (neighbor_id, dist) in neighbor_candidates {
self.add_bidirectional_edge(node_id, neighbor_id, dist)?;
}
Ok(())
}
pub fn node_count(&self) -> usize {
self.neighbors.len()
}
pub fn edge_count(&self) -> usize {
self.edge_count
}
pub fn flush(&mut self) -> Result<()> {
if let Some(ref mut file) = self.edge_file {
file.flush()?;
}
Ok(())
}
}
impl Drop for AppendGraph {
    /// Best-effort flush of buffered edge records when the graph goes out of scope.
    ///
    /// `Drop` cannot report failures, so a flush error is discarded here. Call
    /// [`AppendGraph::flush`] explicitly when write errors must be observed.
    fn drop(&mut self) {
        self.flush().unwrap_or_default();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens the edge log at `path` with the default configuration.
    fn open_default(path: &std::path::Path) -> AppendGraph {
        AppendGraph::new(path, AppendGraphConfig::default()).unwrap()
    }

    #[test]
    fn test_append_graph_basic() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.edges");
        let mut graph = open_default(&log_path);

        // Attach two distinct neighbors to node 0.
        for &(other, dist) in &[(1u32, 0.1f32), (2, 0.2)] {
            graph.add_bidirectional_edge(0, other, dist).unwrap();
        }

        assert_eq!(graph.get_neighbors(0).len(), 2);
    }

    #[test]
    fn test_append_graph_persistence() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.edges");

        // Write the chain 0-1-2-...-10 and push it to disk.
        let mut writer = open_default(&log_path);
        (0..10u32).for_each(|i| {
            writer
                .add_bidirectional_edge(i, i + 1, i as f32 * 0.1)
                .unwrap();
        });
        writer.flush().unwrap();
        drop(writer);

        // Reopening the same file must replay the chain.
        let reloaded = open_default(&log_path);
        assert!(reloaded.node_count() > 0);
        assert!(!reloaded.get_neighbors(5).is_empty());
    }
}