#![allow(clippy::collapsible_if)]
use crate::AletheiaDB;
use crate::api::transaction::ReadOps;
use crate::core::error::{Error, Result};
use crate::core::id::NodeId;
use crate::core::vector::ops::cosine_similarity;
use std::collections::{HashSet, VecDeque};
/// Outcome of a horizon mapping run started from a seed node.
///
/// Both sets are filled by `HorizonEngine::map_horizon`. A node appears in at
/// most one of them; neighbors that could not be loaded or that lack the
/// requested vector property appear in neither.
#[derive(Debug, Clone, PartialEq)]
pub struct HorizonResult {
/// The seed plus every reached node whose cosine similarity to the seed's
/// vector is greater than or equal to the threshold.
pub interior: HashSet<NodeId>,
/// Nodes reached from an interior node whose cosine similarity to the seed's
/// vector is below the threshold. The traversal does not expand these nodes.
pub horizon: HashSet<NodeId>,
}
/// Read-only helper that maps the "horizon" around a seed node: the boundary
/// where neighbors stop being similar enough to the seed's vector.
///
/// Borrows the database for the lifetime `'a`; it owns no state of its own.
pub struct HorizonEngine<'a> {
// Database handle used to open the read transaction for each mapping run.
db: &'a AletheiaDB,
}
impl<'a> HorizonEngine<'a> {
pub fn new(db: &'a AletheiaDB) -> Self {
Self { db }
}
pub fn map_horizon(
&self,
seed: NodeId,
property_name: &str,
threshold: f32,
max_depth: usize,
) -> Result<HorizonResult> {
let mut interior = HashSet::new();
let mut horizon = HashSet::new();
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
self.db.read(|tx| {
let seed_node = tx.get_node(seed)?;
let seed_vec = match seed_node
.get_property(property_name)
.and_then(|p| p.as_vector())
{
Some(v) => v,
None => {
return Err(Error::other(format!(
"Seed node {} does not have vector property '{}'",
seed, property_name
)));
}
};
queue.push_back((seed, 0));
visited.insert(seed);
interior.insert(seed);
while let Some((current_node_id, depth)) = queue.pop_front() {
if depth >= max_depth {
continue;
}
let outgoing = tx
.get_outgoing_edges(current_node_id)
.into_iter()
.filter_map(|e| tx.get_edge(e).ok().map(|edge| edge.target));
let incoming = tx
.get_incoming_edges(current_node_id)
.into_iter()
.filter_map(|e| tx.get_edge(e).ok().map(|edge| edge.source));
for neighbor_id in outgoing.chain(incoming) {
if visited.contains(&neighbor_id) {
continue;
}
visited.insert(neighbor_id);
if let Ok(neighbor_node) = tx.get_node(neighbor_id) {
if let Some(neighbor_vec) = neighbor_node
.get_property(property_name)
.and_then(|p| p.as_vector())
{
let sim = cosine_similarity(seed_vec, neighbor_vec)?;
if sim >= threshold {
interior.insert(neighbor_id);
queue.push_back((neighbor_id, depth + 1));
} else {
horizon.insert(neighbor_id);
}
}
}
}
}
Ok::<(), Error>(())
})?;
Ok(HorizonResult { interior, horizon })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::transaction::WriteOps;
    use crate::core::property::PropertyMapBuilder;

    /// Builds the chain a - b - c - d - e, where each node's vector drifts
    /// further from `a`'s vector, then checks where the horizon falls for
    /// two thresholds.
    ///
    /// Cosine similarity to `a` is roughly 0.9 for `b`, 0.6 for `c`, 0.2 for
    /// `d` and 0.0 for `e`.
    #[test]
    fn test_horizon_mapping() {
        let db = AletheiaDB::new().unwrap();

        // Placeholder ids; the real ones are assigned inside the write
        // transaction below.
        let mut a = NodeId::new(0).unwrap();
        let mut b = NodeId::new(0).unwrap();
        let mut c = NodeId::new(0).unwrap();
        let mut d = NodeId::new(0).unwrap();
        let mut e = NodeId::new(0).unwrap();

        db.write(|tx| {
            // Property map holding a single "vec" vector property.
            let concept_vec = |components: &[f32; 3]| {
                PropertyMapBuilder::new()
                    .insert_vector("vec", components)
                    .build()
            };

            a = tx
                .create_node("Concept", concept_vec(&[1.0, 0.0, 0.0]))
                .unwrap();
            b = tx
                .create_node("Concept", concept_vec(&[0.9, 0.435889, 0.0]))
                .unwrap();
            c = tx
                .create_node("Concept", concept_vec(&[0.6, 0.8, 0.0]))
                .unwrap();
            d = tx
                .create_node("Concept", concept_vec(&[0.2, 0.979795, 0.0]))
                .unwrap();
            e = tx
                .create_node("Concept", concept_vec(&[0.0, 1.0, 0.0]))
                .unwrap();

            // Link the nodes into a single chain.
            for (source, target) in [(a, b), (b, c), (c, d), (d, e)] {
                tx.create_edge(source, target, "LINK", Default::default())
                    .unwrap();
            }

            Ok::<(), Error>(())
        })
        .unwrap();

        let mapper = HorizonEngine::new(&db);

        // Threshold 0.5: a, b and c are interior, d is the horizon, and e is
        // never reached because the walk stops at d.
        let loose = mapper.map_horizon(a, "vec", 0.5, 10).unwrap();
        for id in [a, b, c] {
            assert!(loose.interior.contains(&id));
        }
        for id in [d, e] {
            assert!(!loose.interior.contains(&id));
        }
        assert!(loose.horizon.contains(&d));
        assert!(!loose.horizon.contains(&e));

        // Threshold 0.8: only a and b are interior, c becomes the horizon,
        // and d is never reached.
        let strict = mapper.map_horizon(a, "vec", 0.8, 10).unwrap();
        for id in [a, b] {
            assert!(strict.interior.contains(&id));
        }
        assert!(strict.horizon.contains(&c));
        assert!(!strict.horizon.contains(&d));
    }
}