use crate::AletheiaDB;
use crate::core::error::Result;
use crate::core::id::{EdgeId, NodeId};
use std::collections::{HashMap, HashSet};
/// One vertex of a [`Pattern`], to be bound to a single database node.
#[derive(Clone, Debug)]
pub struct PatternNode {
    /// Index of this node within [`Pattern::nodes`].
    pub id: usize,
    /// Label the bound database node must carry; `None` accepts any label.
    pub label: Option<String>,
    /// Optional semantic requirement on one of the bound node's vector properties.
    pub vector_constraint: Option<VectorConstraint>,
}
/// Requires a node's vector property to be similar enough to a query vector.
#[derive(Clone, Debug)]
pub struct VectorConstraint {
    /// Name of the node property holding the embedding.
    pub property: String,
    /// Query embedding the property value is compared against.
    pub vector: Vec<f32>,
    /// Minimum similarity (cosine similarity when checked directly) a node must reach.
    pub threshold: f32,
}
/// A relationship in a [`Pattern`], referring to its endpoints by pattern-node index.
#[derive(Clone, Debug)]
pub struct PatternEdge {
    /// Index of the source node within [`Pattern::nodes`].
    pub source: usize,
    /// Index of the target node within [`Pattern::nodes`].
    pub target: usize,
    /// Label the matched database edge must carry; `None` accepts any label.
    pub label: Option<String>,
    /// Whether the edge is directed from `source` to `target`.
    /// [`Pattern::add_edge`] always sets this to `true`.
    pub directed: bool,
}
/// A query graph: nodes to bind to database nodes and edges that must connect them.
///
/// A node's index is its position in `nodes`; edges refer to nodes by that index.
#[derive(Clone, Debug, Default)]
pub struct Pattern {
    /// Pattern vertices, in insertion order.
    pub nodes: Vec<PatternNode>,
    /// Pattern relationships, in insertion order.
    pub edges: Vec<PatternEdge>,
}
impl Pattern {
pub fn new() -> Self {
Self::default()
}
pub fn add_node(&mut self, label: Option<String>) -> usize {
let id = self.nodes.len();
self.nodes.push(PatternNode {
id,
label,
vector_constraint: None,
});
id
}
pub fn add_semantic_node(
&mut self,
label: Option<String>,
property: String,
vector: Vec<f32>,
threshold: f32,
) -> usize {
let id = self.nodes.len();
self.nodes.push(PatternNode {
id,
label,
vector_constraint: Some(VectorConstraint {
property,
vector,
threshold,
}),
});
id
}
pub fn add_edge(&mut self, source: usize, target: usize, label: Option<String>) {
self.edges.push(PatternEdge {
source,
target,
label,
directed: true,
});
}
fn validate(&self) -> Result<()> {
let node_count = self.nodes.len();
for edge in &self.edges {
if edge.source >= node_count || edge.target >= node_count {
return Err(crate::core::error::Error::other(format!(
"Pattern edge references invalid node ID: {} -> {}",
edge.source, edge.target
)));
}
}
Ok(())
}
}
/// One embedding of a [`Pattern`] into the database.
#[derive(Clone, Debug)]
pub struct Match {
    /// Pattern node index -> bound database node.
    pub nodes: HashMap<usize, NodeId>,
    /// Pattern edge index (position in [`Pattern::edges`]) -> matched database edge.
    pub edges: HashMap<usize, EdgeId>,
    /// Mean cosine similarity over the pattern's vector-constrained nodes
    /// (1.0 when no such node contributed).
    pub score: f32,
}
/// Finds embeddings of a [`Pattern`] in the graph stored in an [`AletheiaDB`].
///
/// Holds only a shared borrow of the database; matching state lives per call.
pub struct GestaltMatcher<'a> {
    /// Database used for vector search, adjacency lookups and node properties.
    db: &'a AletheiaDB,
}
impl<'a> GestaltMatcher<'a> {
    /// Creates a matcher that reads from `db`.
    pub fn new(db: &'a AletheiaDB) -> Self {
        Self { db }
    }

    /// Returns up to `limit` matches of `pattern`, in discovery order.
    ///
    /// The first pattern node carrying a vector constraint is the anchor. Its
    /// candidates come from a vector search over the constraint's property, asking
    /// for `limit * 10` results (saturating on overflow). Only candidates whose
    /// score reaches the constraint's threshold and whose label matches are kept.
    /// Each kept candidate is then grown into complete matches by backtracking
    /// along pattern edges.
    ///
    /// An empty pattern yields no matches; so does `limit == 0` once an anchor exists.
    ///
    /// # Errors
    ///
    /// Fails if the pattern does not pass validation, if no pattern node has a
    /// vector constraint, or if any database or similarity call fails.
    pub fn find_matches(&self, pattern: &Pattern, limit: usize) -> Result<Vec<Match>> {
        pattern.validate()?;
        if pattern.nodes.is_empty() {
            return Ok(Vec::new());
        }
        let (anchor_idx, constraint) = pattern
            .nodes
            .iter()
            .enumerate()
            .find_map(|(idx, node)| node.vector_constraint.as_ref().map(|vc| (idx, vc)))
            .ok_or_else(|| {
                crate::core::error::Error::other(
                    "Gestalt requires at least one node with a vector constraint to serve as an anchor.",
                )
            })?;
        if limit == 0 {
            return Ok(Vec::new());
        }
        let anchor_label = &pattern.nodes[anchor_idx].label;

        // Over-fetch so that candidates rejected by the label/threshold filters or
        // by the structural search do not starve the result set. `saturating_mul`
        // keeps a huge `limit` from overflowing.
        let search_k = limit.saturating_mul(10);
        let candidates =
            self.db
                .search_vectors_in(&constraint.property, &constraint.vector, search_k)?;

        let mut matches = Vec::new();
        for (candidate_id, score) in candidates {
            // NOTE(review): assumes the search score is a similarity (higher is
            // better) on the same scale as `threshold` — confirm against
            // `search_vectors_in`.
            if !Self::meets_threshold(score, constraint.threshold) {
                continue;
            }
            if !self.check_node_label(candidate_id, anchor_label)? {
                continue;
            }
            let mut current_match = Match {
                nodes: HashMap::new(),
                edges: HashMap::new(),
                // Placeholder; the real score is computed once the match is complete.
                score: 0.0,
            };
            current_match.nodes.insert(anchor_idx, candidate_id);
            self.backtrack(pattern, &mut current_match, &mut matches, limit)?;
            if matches.len() >= limit {
                break;
            }
        }
        Ok(matches)
    }

    /// Returns `true` when `similarity` reaches `threshold`; a NaN similarity never passes.
    fn meets_threshold(similarity: f32, threshold: f32) -> bool {
        similarity >= threshold
    }

    /// Depth-first extension of `current_match`, one pattern node at a time.
    ///
    /// Every complete assignment whose edges verify is scored and pushed into
    /// `results`. The search stops once `results` holds `limit` matches.
    /// `current_match.nodes` is restored to its entry state before returning.
    fn backtrack(
        &self,
        pattern: &Pattern,
        current_match: &mut Match,
        results: &mut Vec<Match>,
        limit: usize,
    ) -> Result<()> {
        if results.len() >= limit {
            return Ok(());
        }
        if current_match.nodes.len() == pattern.nodes.len() {
            if self.verify_edges(pattern, current_match)? {
                let score = self.calculate_score(pattern, current_match)?;
                let mut final_match = current_match.clone();
                final_match.score = score;
                results.push(final_match);
            }
            return Ok(());
        }
        let next_idx = match self.pick_next_node(pattern, current_match) {
            Some(idx) => idx,
            // No unbound pattern node is adjacent to the bound ones: the rest of the
            // pattern is unreachable from the anchor, so this branch cannot complete.
            None => return Ok(()),
        };
        let p_node = &pattern.nodes[next_idx];
        for candidate_id in self.find_candidates_for_node(pattern, next_idx, current_match)? {
            // Injectivity: a database node may fill at most one pattern slot.
            if current_match.nodes.values().any(|&id| id == candidate_id) {
                continue;
            }
            if !self.check_node_constraints(candidate_id, p_node)? {
                continue;
            }
            current_match.nodes.insert(next_idx, candidate_id);
            self.backtrack(pattern, current_match, results, limit)?;
            current_match.nodes.remove(&next_idx);
            if results.len() >= limit {
                return Ok(());
            }
        }
        Ok(())
    }

    /// Picks an unbound pattern node that shares an edge with a bound one.
    ///
    /// Scans edges in pattern order and returns `None` when no such node exists.
    fn pick_next_node(&self, pattern: &Pattern, current_match: &Match) -> Option<usize> {
        pattern.edges.iter().find_map(|edge| {
            let source_bound = current_match.nodes.contains_key(&edge.source);
            let target_bound = current_match.nodes.contains_key(&edge.target);
            match (source_bound, target_bound) {
                (true, false) => Some(edge.target),
                (false, true) => Some(edge.source),
                _ => None,
            }
        })
    }

    /// Database nodes that satisfy every pattern edge linking `target_idx` to an
    /// already-bound pattern node.
    ///
    /// This is the intersection of the neighbor sets contributed by each such edge.
    /// It is empty when no edge links `target_idx` to the bound nodes.
    fn find_candidates_for_node(
        &self,
        pattern: &Pattern,
        target_idx: usize,
        current_match: &Match,
    ) -> Result<Vec<NodeId>> {
        let mut candidates: Option<HashSet<NodeId>> = None;
        for edge in &pattern.edges {
            let neighbors = if edge.target == target_idx {
                match current_match.nodes.get(&edge.source) {
                    Some(&source_db_id) => {
                        self.neighbors_via(source_db_id, &edge.label, edge.directed, true)?
                    }
                    None => continue,
                }
            } else if edge.source == target_idx {
                match current_match.nodes.get(&edge.target) {
                    Some(&target_db_id) => {
                        self.neighbors_via(target_db_id, &edge.label, edge.directed, false)?
                    }
                    None => continue,
                }
            } else {
                continue;
            };
            candidates = Some(match candidates.take() {
                None => neighbors,
                Some(mut kept) => {
                    kept.retain(|id| neighbors.contains(id));
                    kept
                }
            });
        }
        Ok(candidates
            .map(|set| set.into_iter().collect())
            .unwrap_or_default())
    }

    /// Database nodes that may sit at the unbound end of a pattern edge whose other
    /// end is bound to `known`.
    ///
    /// `known_is_source` tells whether `known` is bound to the edge's source
    /// (follow outgoing edges) or to its target (follow incoming edges).
    /// Undirected pattern edges follow both directions.
    fn neighbors_via(
        &self,
        known: NodeId,
        label: &Option<String>,
        directed: bool,
        known_is_source: bool,
    ) -> Result<HashSet<NodeId>> {
        if directed {
            return if known_is_source {
                self.get_outgoing_neighbors(known, label)
            } else {
                self.get_incoming_neighbors(known, label)
            };
        }
        let mut either = self.get_outgoing_neighbors(known, label)?;
        either.extend(self.get_incoming_neighbors(known, label)?);
        Ok(either)
    }

    /// Targets of `source`'s outgoing edges, restricted to `label` when given.
    fn get_outgoing_neighbors(
        &self,
        source: NodeId,
        label: &Option<String>,
    ) -> Result<HashSet<NodeId>> {
        let edges = match label {
            Some(l) => self.db.get_outgoing_edges_with_label(source, l),
            None => self.db.get_outgoing_edges(source),
        };
        let mut neighbors = HashSet::new();
        for edge_id in edges {
            neighbors.insert(self.db.get_edge_target(edge_id)?);
        }
        Ok(neighbors)
    }

    /// Sources of `target`'s incoming edges, restricted to `label` when given.
    fn get_incoming_neighbors(
        &self,
        target: NodeId,
        label: &Option<String>,
    ) -> Result<HashSet<NodeId>> {
        // NOTE(review): the labeled lookup goes through `db.current` while the
        // unlabeled one uses a method on `db` itself, presumably because no
        // `AletheiaDB::get_incoming_edges_with_label` exists — confirm.
        let edges = match label {
            Some(l) => self.db.current.get_incoming_edges_with_label(target, l),
            None => self.db.get_incoming_edges(target),
        };
        let mut neighbors = HashSet::new();
        for edge_id in edges {
            neighbors.insert(self.db.get_edge_source(edge_id)?);
        }
        Ok(neighbors)
    }

    /// Whether `node_id` carries `label`; always `true` when `label` is `None`.
    fn check_node_label(&self, node_id: NodeId, label: &Option<String>) -> Result<bool> {
        match label {
            Some(l) => Ok(self.db.get_node(node_id)?.has_label_str(l)),
            None => Ok(true),
        }
    }

    /// Whether `node_id` satisfies `pattern_node`'s label and vector constraints.
    ///
    /// A vector-constrained node must have the constrained property, the property
    /// must be a vector, and its cosine similarity to the query must reach the
    /// threshold.
    fn check_node_constraints(&self, node_id: NodeId, pattern_node: &PatternNode) -> Result<bool> {
        if !self.check_node_label(node_id, &pattern_node.label)? {
            return Ok(false);
        }
        let vc = match &pattern_node.vector_constraint {
            Some(vc) => vc,
            None => return Ok(true),
        };
        let node = self.db.get_node(node_id)?;
        match node.properties.get(&vc.property).and_then(|v| v.as_vector()) {
            Some(vec) => {
                let sim = crate::core::vector::cosine_similarity(vec, &vc.vector)?;
                Ok(Self::meets_threshold(sim, vc.threshold))
            }
            // Property missing or not a vector.
            None => Ok(false),
        }
    }

    /// Confirms every pattern edge exists between the bound endpoints and records
    /// one database edge for each in `current_match.edges`.
    ///
    /// Directed edges need a database edge source -> target; undirected edges also
    /// accept target -> source. The first edge found is recorded. Returns `false`
    /// as soon as some pattern edge has no realization.
    ///
    /// NOTE(review): two pattern edges between the same endpoints may be mapped to
    /// the same database edge.
    fn verify_edges(&self, pattern: &Pattern, current_match: &mut Match) -> Result<bool> {
        // Rebuild the edge map from scratch so entries left by an earlier failed
        // attempt on this same `Match` cannot survive.
        current_match.edges.clear();
        for (i, edge) in pattern.edges.iter().enumerate() {
            let u = current_match.nodes[&edge.source];
            let v = current_match.nodes[&edge.target];
            let orientations = [(u, v), (v, u)];
            let tries = if edge.directed { 1 } else { 2 };
            let mut found = false;
            'orientations: for &(from, to) in &orientations[..tries] {
                let outgoing = if let Some(l) = &edge.label {
                    self.db.get_outgoing_edges_with_label(from, l)
                } else {
                    self.db.get_outgoing_edges(from)
                };
                for edge_id in outgoing {
                    if self.db.get_edge_target(edge_id)? == to {
                        current_match.edges.insert(i, edge_id);
                        found = true;
                        break 'orientations;
                    }
                }
            }
            if !found {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Mean cosine similarity over the vector-constrained pattern nodes whose bound
    /// node has the constrained vector property; 1.0 if none contributes.
    fn calculate_score(&self, pattern: &Pattern, current_match: &Match) -> Result<f32> {
        let mut total_score = 0.0_f32;
        let mut count = 0_usize;
        for (node_idx, p_node) in pattern.nodes.iter().enumerate() {
            if let Some(vc) = &p_node.vector_constraint {
                let node = self.db.get_node(current_match.nodes[&node_idx])?;
                if let Some(vec) = node
                    .properties
                    .get(&vc.property)
                    .and_then(|v| v.as_vector())
                {
                    total_score += crate::core::vector::cosine_similarity(vec, &vc.vector)?;
                    count += 1;
                }
            }
        }
        Ok(if count > 0 {
            total_score / count as f32
        } else {
            1.0
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;
    use crate::index::vector::{DistanceMetric, HnswConfig};

    /// Creates a node labeled `label` whose `"vec"` property holds `embedding`.
    fn vector_node(db: &AletheiaDB, label: &str, embedding: &[f32]) -> NodeId {
        let props = PropertyMapBuilder::new()
            .insert_vector("vec", embedding)
            .build();
        db.create_node(label, props).unwrap()
    }

    #[test]
    fn test_gestalt_simple_match() {
        let db = AletheiaDB::new().unwrap();
        db.enable_vector_index("vec", HnswConfig::new(2, DistanceMetric::Cosine))
            .unwrap();

        // Two pairs whose embeddings line up with the query.
        let a = vector_node(&db, "Person", &[1.0, 0.0]);
        let b = vector_node(&db, "Company", &[0.0, 1.0]);
        db.create_edge(a, b, "WORKS_FOR", Default::default())
            .unwrap();

        let c = vector_node(&db, "Person", &[1.0, 0.0]);
        let d = vector_node(&db, "Company", &[0.0, 1.0]);
        db.create_edge(c, d, "WORKS_FOR", Default::default())
            .unwrap();

        // A decoy pair: same labels and edge, but with swapped embeddings.
        let e = vector_node(&db, "Person", &[0.0, 1.0]);
        let f = vector_node(&db, "Company", &[1.0, 0.0]);
        db.create_edge(e, f, "WORKS_FOR", Default::default())
            .unwrap();

        let mut pattern = Pattern::new();
        let person = pattern.add_semantic_node(
            Some("Person".to_string()),
            "vec".to_string(),
            vec![1.0, 0.0],
            0.9,
        );
        let company = pattern.add_semantic_node(
            Some("Company".to_string()),
            "vec".to_string(),
            vec![0.0, 1.0],
            0.9,
        );
        pattern.add_edge(person, company, Some("WORKS_FOR".to_string()));

        let matches = GestaltMatcher::new(&db).find_matches(&pattern, 10).unwrap();
        assert_eq!(matches.len(), 2);

        let found_pairs: Vec<(NodeId, NodeId)> = matches
            .iter()
            .map(|m| (m.nodes[&person], m.nodes[&company]))
            .collect();
        assert!(found_pairs.contains(&(a, b)));
        assert!(found_pairs.contains(&(c, d)));
        assert!(!found_pairs.contains(&(e, f)));
    }
}