use std::collections::{HashMap, HashSet};

use crate::AletheiaDB;
use crate::core::error::Result;
use crate::core::id::NodeId;
use crate::core::vector::cosine_similarity;
/// Link predictor over a single [`AletheiaDB`] graph.
///
/// Ranks nodes two hops away from a target by the neighbors they share with it
/// (Adamic–Adar), optionally boosted by the cosine similarity of a vector property.
pub struct Prophet<'a> {
    /// Explicitly chosen vector property; `None` defers to the property of the
    /// first vector index the database reports.
    property_name: Option<String>,
    /// Graph being queried, borrowed for the predictor's lifetime.
    db: &'a AletheiaDB,
}
impl<'a> Prophet<'a> {
pub fn new(db: &'a AletheiaDB) -> Self {
Self {
db,
property_name: None,
}
}
pub fn with_property(mut self, name: impl Into<String>) -> Self {
self.property_name = Some(name.into());
self
}
fn get_property_name(&self) -> Result<String> {
if let Some(ref name) = self.property_name {
return Ok(name.clone());
}
let indexes = self.db.list_vector_indexes();
if let Some(idx_info) = indexes.first() {
Ok(idx_info.property_name.clone())
} else {
Ok("".to_string())
}
}
fn get_neighbors(&self, node_id: NodeId) -> Result<HashSet<NodeId>> {
let out_edges = self.db.get_outgoing_edges(node_id);
let in_edges = self.db.get_incoming_edges(node_id);
let mut neighbors = HashSet::with_capacity(out_edges.len() + in_edges.len());
for eid in out_edges {
let target = self.db.get_edge_target(eid)?;
neighbors.insert(target);
}
for eid in in_edges {
let source = self.db.get_edge_source(eid)?;
neighbors.insert(source);
}
Ok(neighbors)
}
fn adamic_adar(&self, neighbors_a: &HashSet<NodeId>, neighbors_b: &HashSet<NodeId>) -> f32 {
let mut score = 0.0;
for &neighbor in neighbors_a {
if neighbors_b.contains(&neighbor) {
let degree = self.db.out_degree(neighbor) + self.db.in_degree(neighbor);
if degree > 1 {
score += 1.0 / (degree as f32).ln();
}
}
}
score
}
fn vector_similarity(&self, a: NodeId, b: NodeId, property: &str) -> f32 {
if property.is_empty() {
return 0.0;
}
let get_vec = |id| -> Option<Vec<f32>> {
self.db
.get_node(id)
.ok()?
.properties
.get(property)?
.as_vector()
.map(|v| v.to_vec())
};
let vec_a = match get_vec(a) {
Some(v) => v,
None => return 0.0,
};
let vec_b = match get_vec(b) {
Some(v) => v,
None => return 0.0,
};
cosine_similarity(&vec_a, &vec_b).unwrap_or(0.0)
}
pub fn predict_links(&self, target: NodeId, k: usize) -> Result<Vec<(NodeId, f32)>> {
let neighbors = self.get_neighbors(target)?;
let mut candidates = HashSet::new();
for &neighbor in &neighbors {
if let Ok(neighbor_neighbors) = self.get_neighbors(neighbor) {
for &candidate in &neighbor_neighbors {
if candidate != target && !neighbors.contains(&candidate) {
candidates.insert(candidate);
}
}
}
}
let property = self.get_property_name()?;
let mut scored_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
if let Ok(candidate_neighbors) = self.get_neighbors(candidate) {
let topo_score = self.adamic_adar(&neighbors, &candidate_neighbors);
let vec_score = self.vector_similarity(target, candidate, &property);
let final_score = topo_score * (1.0 + vec_score);
if final_score > 0.0 {
scored_candidates.push((candidate, final_score));
}
}
}
scored_candidates
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
if scored_candidates.len() > k {
scored_candidates.truncate(k);
}
Ok(scored_candidates)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;
    use crate::index::vector::{DistanceMetric, HnswConfig};

    /// Returns the score predicted for `id`, panicking if `id` was not predicted.
    fn score_of(predictions: &[(NodeId, f32)], id: NodeId) -> f32 {
        predictions
            .iter()
            .find(|(candidate, _)| *candidate == id)
            .map(|(_, score)| *score)
            .expect("node missing from predictions")
    }

    /// Builds the diamond `a -> {b, c} -> d` (no vectors) and returns `(a, b, c, d)`.
    fn diamond(db: &AletheiaDB) -> (NodeId, NodeId, NodeId, NodeId) {
        let props = PropertyMapBuilder::new().build();
        let a = db.create_node("Node", props.clone()).unwrap();
        let b = db.create_node("Node", props.clone()).unwrap();
        let c = db.create_node("Node", props.clone()).unwrap();
        let d = db.create_node("Node", props.clone()).unwrap();
        for (from, to) in [(a, b), (a, c), (b, d), (c, d)] {
            db.create_edge(from, to, "KNOWS", props.clone()).unwrap();
        }
        (a, b, c, d)
    }

    /// Builds `a -> {b, c} -> {d, e}` with an "embedding" vector index, where `a`
    /// and `d` share the embedding `[1, 0]` and `e` has the orthogonal `[0, 1]`.
    /// Returns `(a, d, e)`.
    fn vector_graph(db: &AletheiaDB) -> (NodeId, NodeId, NodeId) {
        let config = HnswConfig::new(2, DistanceMetric::Cosine);
        db.enable_vector_index("embedding", config).unwrap();
        let p_a = PropertyMapBuilder::new()
            .insert_vector("embedding", &[1.0, 0.0])
            .build();
        let a = db.create_node("Node", p_a).unwrap();
        let plain = PropertyMapBuilder::new().build();
        let b = db.create_node("Node", plain.clone()).unwrap();
        let c = db.create_node("Node", plain.clone()).unwrap();
        let p_d = PropertyMapBuilder::new()
            .insert_vector("embedding", &[1.0, 0.0])
            .build();
        let d = db.create_node("Node", p_d).unwrap();
        let p_e = PropertyMapBuilder::new()
            .insert_vector("embedding", &[0.0, 1.0])
            .build();
        let e = db.create_node("Node", p_e).unwrap();
        for (from, to) in [(a, b), (a, c), (b, d), (c, d), (b, e), (c, e)] {
            db.create_edge(from, to, "KNOWS", plain.clone()).unwrap();
        }
        (a, d, e)
    }

    #[test]
    fn test_prophet_diamond_prediction() {
        let db = AletheiaDB::new().unwrap();
        let (a, _b, _c, d) = diamond(&db);
        let predictions = Prophet::new(&db).predict_links(a, 5).unwrap();
        // `d` is the only node two hops away; `a` itself and its direct
        // neighbors `b` and `c` must never be suggested.
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].0, d);
        assert!(predictions[0].1 > 0.0);
    }

    #[test]
    fn test_prophet_respects_k_zero() {
        let db = AletheiaDB::new().unwrap();
        let (a, ..) = diamond(&db);
        assert!(Prophet::new(&db).predict_links(a, 0).unwrap().is_empty());
    }

    #[test]
    fn test_prophet_isolated_node_has_no_predictions() {
        let db = AletheiaDB::new().unwrap();
        let lonely = db
            .create_node("Node", PropertyMapBuilder::new().build())
            .unwrap();
        assert!(Prophet::new(&db).predict_links(lonely, 5).unwrap().is_empty());
    }

    #[test]
    fn test_prophet_vector_boost() {
        let db = AletheiaDB::new().unwrap();
        let (a, d, e) = vector_graph(&db);
        let predictions = Prophet::new(&db).predict_links(a, 5).unwrap();
        assert_eq!(predictions.len(), 2);
        assert_eq!(predictions[0].0, d);
        let d_score = score_of(&predictions, d);
        let e_score = score_of(&predictions, e);
        assert!(
            d_score > e_score,
            "D should rank higher than E due to vector similarity (D: {}, E: {})",
            d_score,
            e_score
        );
    }

    #[test]
    fn test_prophet_unknown_property_disables_boost() {
        let db = AletheiaDB::new().unwrap();
        let (a, d, e) = vector_graph(&db);
        let predictions = Prophet::new(&db)
            .with_property("missing")
            .predict_links(a, 5)
            .unwrap();
        // `d` and `e` share the same two neighbors with `a`; with no usable
        // vector property their scores are purely topological and equal.
        let d_score = score_of(&predictions, d);
        let e_score = score_of(&predictions, e);
        assert!(
            (d_score - e_score).abs() < 1e-6,
            "expected equal topological scores (D: {}, E: {})",
            d_score,
            e_score
        );
    }
}