use crate::error::{GraphError, Result};
use crate::hybrid::vector_index::HybridIndex;
use crate::types::{EdgeId, NodeId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchConfig {
pub max_path_length: usize,
pub min_similarity: f32,
pub top_k: usize,
pub semantic_weight: f32,
}
impl Default for SemanticSearchConfig {
fn default() -> Self {
Self {
max_path_length: 3,
min_similarity: 0.7,
top_k: 10,
semantic_weight: 0.6,
}
}
}
/// Embedding-based search over a [`HybridIndex`], with hits filtered and
/// scored according to a [`SemanticSearchConfig`].
pub struct SemanticSearch {
    /// Vector index queried for nearest-neighbour node and edge hits.
    index: HybridIndex,
    /// Similarity threshold and scoring weights applied to index hits.
    config: SemanticSearchConfig,
}
impl SemanticSearch {
    /// Creates a search engine that queries `index` using the threshold and
    /// weights in `config`.
    pub fn new(index: HybridIndex, config: SemanticSearchConfig) -> Self {
        Self { index, config }
    }

    /// Returns the nodes among the index's `k` nearest neighbours of `query`
    /// whose similarity reaches `config.min_similarity`.
    ///
    /// Similarity is derived from the index distance as `1.0 - distance`.
    /// Hits keep the order in which the index returned them, and every
    /// returned [`SemanticMatch`] has `path_length == 0`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the index search.
    pub fn find_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<SemanticMatch>> {
        let results = self.index.search_similar_nodes(query, k)?;
        Ok(results
            .into_iter()
            .filter_map(|(node_id, distance)| {
                self.similarity_within_threshold(distance)
                    .map(|score| SemanticMatch {
                        node_id,
                        score,
                        path_length: 0,
                    })
            })
            .collect())
    }

    /// Builds one-hop candidate paths from `start_node` to nodes semantically
    /// similar to `query`, returning at most `max_paths` of them.
    ///
    /// `start_node` is never used as a path endpoint, because a path from a
    /// node back to itself is not a hop. Each path is `[start_node, match]`
    /// with `graph_distance == 1`. Graph connectivity is not checked here, and
    /// `edges` is left empty.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the index search.
    pub fn find_semantic_paths(
        &self,
        start_node: &NodeId,
        query: &[f32],
        max_paths: usize,
    ) -> Result<Vec<SemanticPath>> {
        // Over-fetch by one so that dropping `start_node` (if it is among the
        // hits) can still leave `max_paths` results.
        let similar = self.find_similar_nodes(query, max_paths.saturating_add(1))?;
        Ok(similar
            .into_iter()
            .filter(|m| m.node_id != *start_node)
            .take(max_paths)
            .map(|m| SemanticPath {
                nodes: vec![start_node.clone(), m.node_id],
                edges: Vec::new(),
                semantic_score: m.score,
                graph_distance: 1,
                combined_score: self.compute_path_score(m.score, 1),
            })
            .collect())
    }

    /// Groups `nodes` into clusters of at least `min_cluster_size` members.
    ///
    /// NOTE(review): this is a placeholder implementation.
    /// - All of `nodes` form a single cluster (id 0) when there are enough of
    ///   them.
    /// - No embeddings are consulted, and `centroid` is `None`.
    /// - `coherence_score` is the constant `0.85`, not a measured value.
    /// - An empty `nodes` slice never yields a cluster, even when
    ///   `min_cluster_size` is 0.
    pub fn detect_clusters(
        &self,
        nodes: &[NodeId],
        min_cluster_size: usize,
    ) -> Result<Vec<ClusterResult>> {
        if nodes.is_empty() || nodes.len() < min_cluster_size {
            return Ok(Vec::new());
        }
        Ok(vec![ClusterResult {
            cluster_id: 0,
            nodes: nodes.to_vec(),
            centroid: None,
            coherence_score: 0.85,
        }])
    }

    /// Returns the edges among the index's `k` nearest neighbours of `query`
    /// whose similarity reaches `config.min_similarity`.
    ///
    /// Similarity is `1.0 - distance`. Hits keep the index's order.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the index search.
    pub fn find_related_edges(&self, query: &[f32], k: usize) -> Result<Vec<EdgeMatch>> {
        let results = self.index.search_similar_edges(query, k)?;
        Ok(results
            .into_iter()
            .filter_map(|(edge_id, distance)| {
                self.similarity_within_threshold(distance)
                    .map(|score| EdgeMatch { edge_id, score })
            })
            .collect())
    }

    /// Converts an index distance into a similarity score (`1.0 - distance`).
    ///
    /// Returns `None` when `distance` exceeds `1.0 - config.min_similarity`.
    /// A NaN distance also yields `None`, because the comparison fails.
    fn similarity_within_threshold(&self, distance: f32) -> Option<f32> {
        let max_distance = 1.0 - self.config.min_similarity;
        if distance <= max_distance {
            Some(1.0 - distance)
        } else {
            None
        }
    }

    /// Blends a semantic similarity with a graph-distance term.
    ///
    /// Returns `w * semantic_score + (1 - w) / (graph_distance + 1)`, where
    /// `w` is `config.semantic_weight`. Shorter paths therefore score higher.
    fn compute_path_score(&self, semantic_score: f32, graph_distance: usize) -> f32 {
        let w = self.config.semantic_weight;
        let distance_penalty = 1.0 / (graph_distance as f32 + 1.0);
        w * semantic_score + (1.0 - w) * distance_penalty
    }

    /// Expands `query` into a set of related query vectors.
    ///
    /// NOTE(review): expansion is not implemented yet, so the result is always
    /// just a copy of `query`. The neighbour search still runs, which means
    /// index errors are propagated to the caller via `?`. Its hits are
    /// currently discarded.
    pub fn expand_query(&self, query: &[f32], expansion_factor: usize) -> Result<Vec<Vec<f32>>> {
        let _ = self.index.search_similar_nodes(query, expansion_factor)?;
        Ok(vec![query.to_vec()])
    }
}
/// A node returned by a semantic similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticMatch {
    /// Identifier of the matching node.
    pub node_id: NodeId,
    /// Similarity score, computed by `SemanticSearch` as `1.0 - distance`
    /// from the index distance.
    pub score: f32,
    /// Path length associated with the match (presumably graph hops; confirm).
    /// `find_similar_nodes` always sets this to 0.
    pub path_length: usize,
}
/// A candidate path ranked by a mix of semantic similarity and graph distance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticPath {
    /// Node ids along the path. `find_semantic_paths` builds
    /// `[start_node, matched_node]`.
    pub nodes: Vec<NodeId>,
    /// Edge ids along the path. `find_semantic_paths` currently leaves this
    /// empty.
    pub edges: Vec<EdgeId>,
    /// Similarity score of the path's matched node against the query.
    pub semantic_score: f32,
    /// Graph distance assigned to the path. `find_semantic_paths` sets 1.
    pub graph_distance: usize,
    /// Weighted blend of `semantic_score` and `graph_distance`, produced by
    /// `SemanticSearch::compute_path_score`.
    pub combined_score: f32,
}
/// A group of nodes reported by `SemanticSearch::detect_clusters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterResult {
    /// Index of the cluster within the result list.
    pub cluster_id: usize,
    /// Members of the cluster.
    pub nodes: Vec<NodeId>,
    /// Optional cluster centroid embedding. `detect_clusters` currently sets
    /// `None`.
    pub centroid: Option<Vec<f32>>,
    /// Cluster coherence score. NOTE(review): `detect_clusters` currently
    /// fills in the constant 0.85 rather than a measured value.
    pub coherence_score: f32,
}
/// An edge returned by a semantic similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeMatch {
    /// Identifier of the matching edge.
    pub edge_id: EdgeId,
    /// Similarity score, computed by `SemanticSearch` as `1.0 - distance`
    /// from the index distance.
    pub score: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hybrid::vector_index::{EmbeddingConfig, VectorIndexType};

    /// Builds a 4-dimensional node index containing the given
    /// `(id, embedding)` pairs.
    fn node_index(embeddings: &[(&str, [f32; 4])]) -> Result<HybridIndex> {
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;
        for &(id, embedding) in embeddings {
            index.add_node_embedding(id.to_string(), embedding.to_vec())?;
        }
        Ok(index)
    }

    #[test]
    fn test_semantic_search_creation() {
        let index = HybridIndex::new(EmbeddingConfig::default())
            .expect("default embedding config should build an index");
        let _search = SemanticSearch::new(index, SemanticSearchConfig::default());
    }

    #[test]
    fn test_find_similar_nodes() -> Result<()> {
        let index = node_index(&[
            ("doc1", [1.0, 0.0, 0.0, 0.0]),
            ("doc2", [0.9, 0.1, 0.0, 0.0]),
        ])?;
        let search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 5)?;
        assert!(!results.is_empty());
        assert!(
            results.iter().any(|m| m.node_id == "doc1"),
            "an exact match should be returned"
        );
        Ok(())
    }

    #[test]
    fn test_cluster_detection() -> Result<()> {
        let index = HybridIndex::new(EmbeddingConfig::default())?;
        let search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let nodes = vec!["n1".to_string(), "n2".to_string(), "n3".to_string()];
        let clusters = search.detect_clusters(&nodes, 2)?;
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].nodes, nodes);
        // Fewer nodes than the minimum cluster size yields no cluster.
        assert!(search.detect_clusters(&nodes, 4)?.is_empty());
        Ok(())
    }

    #[test]
    fn test_similarity_score_range() -> Result<()> {
        let index = node_index(&[
            ("identical", [1.0, 0.0, 0.0, 0.0]),
            ("similar", [0.9, 0.1, 0.0, 0.0]),
            ("different", [0.0, 1.0, 0.0, 0.0]),
        ])?;
        let search_config = SemanticSearchConfig {
            min_similarity: 0.0,
            ..Default::default()
        };
        let search = SemanticSearch::new(index, search_config);
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 10)?;
        // Without this, the top-result check below could pass vacuously.
        assert!(
            !results.is_empty(),
            "a query identical to a stored vector must produce results"
        );
        for result in &results {
            assert!(
                result.score >= 0.0 && result.score <= 1.0,
                "Score {} out of valid range [0, 1]",
                result.score
            );
        }
        let top_result = &results[0];
        assert!(
            top_result.score > 0.9,
            "Identical vector should have score > 0.9"
        );
        Ok(())
    }

    #[test]
    fn test_min_similarity_filtering() -> Result<()> {
        let index = node_index(&[
            ("high_sim", [1.0, 0.0, 0.0, 0.0]),
            ("low_sim", [0.0, 0.0, 0.0, 1.0]),
        ])?;
        let search_config = SemanticSearchConfig {
            min_similarity: 0.9,
            ..Default::default()
        };
        let search = SemanticSearch::new(index, search_config);
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 10)?;
        for result in &results {
            assert!(
                result.score >= 0.9,
                "Result with score {} should be filtered out (min: 0.9)",
                result.score
            );
        }
        assert!(
            results.iter().any(|m| m.node_id == "high_sim"),
            "the identical vector should survive the threshold"
        );
        assert!(
            results.iter().all(|m| m.node_id != "low_sim"),
            "the orthogonal vector should be filtered out"
        );
        Ok(())
    }
}