use crate::error::{GraphError, Result};
use crate::types::{EdgeId, NodeId, Properties, PropertyValue};
use dashmap::DashMap;
use parking_lot::RwLock;
use ruvector_core::index::flat::FlatIndex;
#[cfg(feature = "hnsw_rs")]
use ruvector_core::index::hnsw::HnswIndex;
use ruvector_core::index::VectorIndex;
#[cfg(feature = "hnsw_rs")]
use ruvector_core::types::HnswConfig;
use ruvector_core::types::{DistanceMetric, SearchResult};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VectorIndexType {
Node,
Edge,
Hyperedge,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
pub dimensions: usize,
pub metric: DistanceMetric,
#[cfg(feature = "hnsw_rs")]
pub hnsw_config: HnswConfig,
pub embedding_property: String,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
dimensions: 384, metric: DistanceMetric::Cosine,
#[cfg(feature = "hnsw_rs")]
hnsw_config: HnswConfig::default(),
embedding_property: "embedding".to_string(),
}
}
}
#[cfg(feature = "hnsw_rs")]
type IndexImpl = HnswIndex;
#[cfg(not(feature = "hnsw_rs"))]
type IndexImpl = FlatIndex;
pub struct HybridIndex {
node_index: Arc<RwLock<Option<IndexImpl>>>,
edge_index: Arc<RwLock<Option<IndexImpl>>>,
hyperedge_index: Arc<RwLock<Option<IndexImpl>>>,
node_id_map: Arc<DashMap<NodeId, String>>,
edge_id_map: Arc<DashMap<EdgeId, String>>,
hyperedge_id_map: Arc<DashMap<String, String>>,
config: EmbeddingConfig,
}
impl HybridIndex {
pub fn new(config: EmbeddingConfig) -> Result<Self> {
Ok(Self {
node_index: Arc::new(RwLock::new(None)),
edge_index: Arc::new(RwLock::new(None)),
hyperedge_index: Arc::new(RwLock::new(None)),
node_id_map: Arc::new(DashMap::new()),
edge_id_map: Arc::new(DashMap::new()),
hyperedge_id_map: Arc::new(DashMap::new()),
config,
})
}
#[cfg(feature = "hnsw_rs")]
pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
let index = HnswIndex::new(
self.config.dimensions,
self.config.metric,
self.config.hnsw_config.clone(),
)
.map_err(|e| GraphError::IndexError(format!("Failed to create HNSW index: {}", e)))?;
match index_type {
VectorIndexType::Node => {
*self.node_index.write() = Some(index);
}
VectorIndexType::Edge => {
*self.edge_index.write() = Some(index);
}
VectorIndexType::Hyperedge => {
*self.hyperedge_index.write() = Some(index);
}
}
Ok(())
}
#[cfg(not(feature = "hnsw_rs"))]
pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
let index = FlatIndex::new(self.config.dimensions, self.config.metric);
match index_type {
VectorIndexType::Node => {
*self.node_index.write() = Some(index);
}
VectorIndexType::Edge => {
*self.edge_index.write() = Some(index);
}
VectorIndexType::Hyperedge => {
*self.hyperedge_index.write() = Some(index);
}
}
Ok(())
}
pub fn add_node_embedding(&self, node_id: NodeId, embedding: Vec<f32>) -> Result<()> {
if embedding.len() != self.config.dimensions {
return Err(GraphError::InvalidEmbedding(format!(
"Expected {} dimensions, got {}",
self.config.dimensions,
embedding.len()
)));
}
let mut index_guard = self.node_index.write();
let index = index_guard
.as_mut()
.ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;
let vector_id = format!("node_{}", node_id);
index
.add(vector_id.clone(), embedding)
.map_err(|e| GraphError::IndexError(format!("Failed to add node embedding: {}", e)))?;
self.node_id_map.insert(node_id, vector_id);
Ok(())
}
pub fn add_edge_embedding(&self, edge_id: EdgeId, embedding: Vec<f32>) -> Result<()> {
if embedding.len() != self.config.dimensions {
return Err(GraphError::InvalidEmbedding(format!(
"Expected {} dimensions, got {}",
self.config.dimensions,
embedding.len()
)));
}
let mut index_guard = self.edge_index.write();
let index = index_guard
.as_mut()
.ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;
let vector_id = format!("edge_{}", edge_id);
index
.add(vector_id.clone(), embedding)
.map_err(|e| GraphError::IndexError(format!("Failed to add edge embedding: {}", e)))?;
self.edge_id_map.insert(edge_id, vector_id);
Ok(())
}
pub fn add_hyperedge_embedding(&self, hyperedge_id: String, embedding: Vec<f32>) -> Result<()> {
if embedding.len() != self.config.dimensions {
return Err(GraphError::InvalidEmbedding(format!(
"Expected {} dimensions, got {}",
self.config.dimensions,
embedding.len()
)));
}
let mut index_guard = self.hyperedge_index.write();
let index = index_guard
.as_mut()
.ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;
let vector_id = format!("hyperedge_{}", hyperedge_id);
index.add(vector_id.clone(), embedding).map_err(|e| {
GraphError::IndexError(format!("Failed to add hyperedge embedding: {}", e))
})?;
self.hyperedge_id_map.insert(hyperedge_id, vector_id);
Ok(())
}
pub fn search_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
let index_guard = self.node_index.read();
let index = index_guard
.as_ref()
.ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;
let results = index
.search(query, k)
.map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
Ok(results
.into_iter()
.filter_map(|result| {
let node_id = result.id.strip_prefix("node_")?.to_string();
Some((node_id, result.score))
})
.collect())
}
pub fn search_similar_edges(&self, query: &[f32], k: usize) -> Result<Vec<(EdgeId, f32)>> {
let index_guard = self.edge_index.read();
let index = index_guard
.as_ref()
.ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;
let results = index
.search(query, k)
.map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
Ok(results
.into_iter()
.filter_map(|result| {
let edge_id = result.id.strip_prefix("edge_")?.to_string();
Some((edge_id, result.score))
})
.collect())
}
pub fn search_similar_hyperedges(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
let index_guard = self.hyperedge_index.read();
let index = index_guard
.as_ref()
.ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;
let results = index
.search(query, k)
.map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
Ok(results
.into_iter()
.filter_map(|result| {
let hyperedge_id = result.id.strip_prefix("hyperedge_")?.to_string();
Some((hyperedge_id, result.score))
})
.collect())
}
pub fn extract_embedding(&self, properties: &Properties) -> Result<Option<Vec<f32>>> {
let prop_value = match properties.get(&self.config.embedding_property) {
Some(v) => v,
None => return Ok(None),
};
match prop_value {
PropertyValue::Array(arr) => {
let embedding: Result<Vec<f32>> = arr
.iter()
.map(|v| match v {
PropertyValue::Float(f) => Ok(*f as f32),
PropertyValue::Integer(i) => Ok(*i as f32),
_ => Err(GraphError::InvalidEmbedding(
"Embedding array must contain numbers".to_string(),
)),
})
.collect();
embedding.map(Some)
}
_ => Err(GraphError::InvalidEmbedding(
"Embedding property must be an array".to_string(),
)),
}
}
pub fn stats(&self) -> HybridIndexStats {
let node_count = self.node_id_map.len();
let edge_count = self.edge_id_map.len();
let hyperedge_count = self.hyperedge_id_map.len();
HybridIndexStats {
node_count,
edge_count,
hyperedge_count,
total_embeddings: node_count + edge_count + hyperedge_count,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridIndexStats {
pub node_count: usize,
pub edge_count: usize,
pub hyperedge_count: usize,
pub total_embeddings: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hybrid_index_creation() -> Result<()> {
let config = EmbeddingConfig::default();
let index = HybridIndex::new(config)?;
let stats = index.stats();
assert_eq!(stats.total_embeddings, 0);
Ok(())
}
#[test]
fn test_node_embedding_indexing() -> Result<()> {
let config = EmbeddingConfig {
dimensions: 4,
..Default::default()
};
let index = HybridIndex::new(config)?;
index.initialize_index(VectorIndexType::Node)?;
let embedding = vec![1.0, 2.0, 3.0, 4.0];
index.add_node_embedding("node1".to_string(), embedding)?;
let stats = index.stats();
assert_eq!(stats.node_count, 1);
Ok(())
}
#[test]
fn test_similarity_search() -> Result<()> {
let config = EmbeddingConfig {
dimensions: 4,
..Default::default()
};
let index = HybridIndex::new(config)?;
index.initialize_index(VectorIndexType::Node)?;
index.add_node_embedding("node1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
index.add_node_embedding("node2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;
index.add_node_embedding("node3".to_string(), vec![0.0, 1.0, 0.0, 0.0])?;
let results = index.search_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 2)?;
assert!(results.len() <= 2);
if !results.is_empty() {
assert_eq!(results[0].0, "node1");
}
Ok(())
}
}