// vectus 0.1.37
//
// A vector database implemented in Rust for learning purposes.
// Documentation
pub mod metric;
mod node;

use metric::Metric;
use ndarray::Array1;
use node::HNSWNode;
use rand::Rng;
use std::cmp::Ordering;

/// In-memory nearest-neighbour index over `f32` embeddings.
///
/// Nodes are kept in a flat vector; graph links and the entry point refer to
/// nodes by their position in that vector.
pub struct HNSW {
    /// All inserted nodes, in insertion order.
    nodes: Vec<HNSWNode>,
    /// Maximum-level setting supplied through [`HNSWInitializer`].
    max_level: usize,
    /// Copied from [`HNSWInitializer`].
    // NOTE(review): no use of this value is visible in this file — confirm it is still needed.
    ef_construction: usize,
    /// Upper bound on how many nodes a neighbour lookup (`get_neighbors`) collects.
    m: usize,
    /// Copied from [`HNSWInitializer`].
    // NOTE(review): no use of this value is visible in this file — confirm it is still needed.
    m_max: usize,
    /// Scale factor applied to `-ln(u)` when a random level is drawn for a new node.
    norm: f32,
    /// Node from which searches start; `None` until one has been set.
    /// The value is used as an index into `nodes`.
    entry: Option<usize>,
    /// Current top level of the index; raised when a node with a higher level is inserted.
    level: usize,
    /// Distance function used to compare embeddings.
    metric: Metric,
}

/// Construction parameters for [`HNSW`].
///
/// Every field is copied into the index field of the same name by `HNSW::new`.
/// `Default` yields the default value of each field's type.
#[derive(Default)]
pub struct HNSWInitializer {
    /// Maximum-level setting for the index.
    pub max_level: usize,
    /// Stored on the index as `ef_construction`.
    pub ef_construction: usize,
    /// Upper bound on how many nodes a neighbour lookup collects.
    pub m: usize,
    /// Stored on the index as `m_max`.
    pub m_max: usize,
    /// Scale factor applied to `-ln(u)` when a random level is drawn for a new node.
    pub norm: f32,
    /// Initial entry point; the index uses it as a position in its node list.
    pub entry: Option<usize>,
    /// Distance function used to compare embeddings.
    pub metric: Metric,
}

impl HNSW {
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    pub(crate) fn new(hnsw_init: HNSWInitializer) -> HNSW {
        HNSW {
            nodes: Vec::new(),
            max_level: hnsw_init.max_level,
            ef_construction: hnsw_init.ef_construction,
            m: hnsw_init.m,
            m_max: hnsw_init.m_max,
            norm: hnsw_init.norm,
            entry: hnsw_init.entry,
            level: 0,
            metric: hnsw_init.metric,
        }
    }

    fn random_level(&self) -> usize {
        let mut rng = rand::thread_rng();
        let random: f32 = rng.gen_range(0.0..1.0);
        (-random.ln() * self.norm).floor() as usize
    }

    fn set_entry(&mut self, entry: usize) {
        self.entry = Some(entry);
    }

    fn get_entry(&self) -> usize {
        self.entry.unwrap()
    }

    pub async fn insert(&mut self, embedding: &Array1<f32>, id: usize) {
        let mut node = HNSWNode::new(embedding.clone(), id);
        let node_level = self.random_level();
        node.level = node_level;

        if self.entry.is_none() {
            self.set_entry(node.id);
            self.nodes.push(node);
            return;
        }

        let mut closest_node = self.get_entry();
        for level in (node_level..=self.level).rev() {
            // If greedy_search involves any async operations, make it async too
            closest_node = self.greedy_search(node.clone(), closest_node, level).await;
        }

        let node_id = self.nodes.len();
        self.nodes.push(node);

        // Connect the new node with neighbors at each level up to its own level
        for level in 0..=node_level {
            // If get_neighbors involves any async operations, make it async too
            let neighbors = self
                .get_neighbors(self.nodes[node_id].clone(), closest_node, level)
                .await;
            // If connect_neighbors involves any async operations, make it async too
            self.connect_neighbors(node_id, neighbors.clone(), level)
                .await;
        }

        if node_level > self.level {
            self.level = node_level;
            self.set_entry(node_id);
        }
    }

    async fn greedy_search(&self, target: HNSWNode, entry_point: usize, level: usize) -> usize {
        let mut closest_node = entry_point;
        let mut closest_dist = self.distance(&self.nodes[entry_point].embedding, &target.embedding);

        loop {
            let mut improved = false;
            for &neighbor in &self.nodes[closest_node].neighbors {
                let dist = self.distance(&self.nodes[neighbor].embedding, &target.embedding);
                if dist < closest_dist {
                    closest_node = neighbor;
                    closest_dist = dist;
                    improved = true;
                }
            }
            if !improved {
                break;
            }
        }

        closest_node
    }

    async fn get_neighbors(
        &self,
        target: HNSWNode,
        entry_point: usize,
        level: usize,
    ) -> Vec<usize> {
        let mut candidates = vec![entry_point];
        let mut neighbors = Vec::new();

        while candidates.len() > 0 && neighbors.len() < self.m {
            let candidate = candidates.pop().unwrap();
            let dist = self.distance(&self.nodes[candidate].embedding, &target.embedding);
            neighbors.push(candidate);

            for &neighbor in &self.nodes[candidate].neighbors {
                candidates.push(neighbor);
            }
        }

        neighbors
    }

    async fn connect_neighbors(&mut self, node_id: usize, neighbors: Vec<usize>, level: usize) {
        if node_id >= self.nodes.len() {
            panic!("Node ID is out of bounds");
        }

        for &neighbor_id in &neighbors {
            if neighbor_id >= self.nodes.len() {
                panic!("Neighbor ID is out of bounds");
            }

            self.nodes[neighbor_id].neighbors.push(node_id);
            self.nodes[node_id].neighbors.push(neighbor_id);
        }
    }

    pub async fn search(&self, query_emb: Array1<f32>, id: usize, k: usize) -> Vec<usize> {
        let node = HNSWNode::new(query_emb.clone(), id);

        if self.entry.is_none() {
            return Vec::new();
        }

        let mut closest_node = self.get_entry();
        for level in (0..=self.level).rev() {
            closest_node = self.greedy_search(node.clone(), closest_node, level).await;
        }

        let neighbors = self.get_neighbors(node, closest_node, 0).await;
        neighbors
    }

    fn distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        self.metric.distance(emb1, emb2)
    }
}