use crate::AletheiaDB;
use crate::core::error::{Error, Result};
use crate::core::id::NodeId;
use crate::core::vector::cosine_similarity;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// Path finder that walks the graph stored in an [`AletheiaDB`], steering the
/// search with the cosine similarity of a vector property on each node.
///
/// The navigator only borrows the database for the lifetime `'a`; it holds no
/// other state.
pub struct SemanticNavigator<'a> {
/// Database the navigator reads nodes and outgoing edges from.
db: &'a AletheiaDB,
}
/// Entry in the priority queue (`open_set`) used by `SemanticNavigator::find_path`.
///
/// The derived `PartialEq` compares both fields, whereas the `Ord` impl below
/// looks at `cost` only.
#[derive(Clone, Copy, PartialEq)]
struct State {
/// Priority of this entry; entries with a lower cost are popped first.
cost: f32,
/// Node this queue entry refers to.
node: NodeId,
}
// `Eq` is asserted by hand because `f32` does not implement it (so it cannot be
// derived). `Ord` requires it, and `BinaryHeap` requires `Ord`.
// NOTE: the derived `PartialEq` is not reflexive when `cost` is NaN.
impl Eq for State {}
impl Ord for State {
    /// Orders states by `cost` in *reverse*, so that `BinaryHeap` (a max-heap)
    /// pops the state with the lowest cost first.
    ///
    /// `f32::total_cmp` is used instead of `partial_cmp` with an `Equal`
    /// fallback: the fallback makes NaN compare equal to every other value,
    /// which is not a total order and leaves the heap's pop order unspecified.
    /// `total_cmp` always yields a consistent ordering, NaN included.
    ///
    /// Only `cost` takes part in the ordering; `node` is ignored, so two states
    /// for different nodes with the same cost compare as `Equal` here even
    /// though the derived `PartialEq` considers them different.
    fn cmp(&self, other: &Self) -> Ordering {
        other.cost.total_cmp(&self.cost)
    }
}
impl PartialOrd for State {
    /// Delegates to [`Ord::cmp`], so the partial order always agrees with the
    /// total order and never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl<'a> SemanticNavigator<'a> {
    /// Creates a navigator that reads nodes and edges from `db`.
    pub fn new(db: &'a AletheiaDB) -> Self {
        Self { db }
    }

    /// Searches for a path from `start` to `end` along outgoing edges using an
    /// A*-style best-first search.
    ///
    /// * The cost of an edge is the cosine distance (`1 - cosine_similarity`)
    ///   between the `vector_prop` vectors of its two endpoints, or `1.0` when
    ///   either endpoint has no such vector.
    /// * The heuristic for a node is its cosine distance to `end`, or `1.0`
    ///   when the node has no vector.
    ///
    /// Cosine distance does not satisfy the triangle inequality, so the
    /// heuristic may overestimate the remaining cost; the returned path is the
    /// one the search reaches first and is not guaranteed to be the cheapest.
    ///
    /// Returns the visited node ids in order, beginning with `start` and ending
    /// with `end` (a single-element path when `start == end`).
    ///
    /// # Errors
    ///
    /// * Any error returned by the database lookups or by `cosine_similarity`
    ///   is propagated.
    /// * `Error::other` if `start` or `end` has no vector stored under
    ///   `vector_prop`.
    /// * `Error::other("No path found")` if the frontier is exhausted before
    ///   `end` is reached.
    pub fn find_path(&self, start: NodeId, end: NodeId, vector_prop: &str) -> Result<Vec<NodeId>> {
        let start_node = self.db.get_node(start)?;
        let end_node = self.db.get_node(end)?;

        let start_vec = start_node
            .properties
            .get(vector_prop)
            .and_then(|v| v.as_arc_vector())
            .ok_or_else(|| {
                Error::other(format!(
                    "Start node {} missing vector property '{}'",
                    start, vector_prop
                ))
            })?;
        let end_vec = end_node
            .properties
            .get(vector_prop)
            .and_then(|v| v.as_arc_vector())
            .ok_or_else(|| {
                Error::other(format!(
                    "End node {} missing vector property '{}'",
                    end, vector_prop
                ))
            })?;

        let h_start = Self::cosine_distance(cosine_similarity(&start_vec, &end_vec)?);

        // Frontier ordered by f = g + h (lowest first, see `Ord for State`).
        let mut open_set = BinaryHeap::new();
        open_set.push(State {
            cost: h_start,
            node: start,
        });

        // Predecessor of each node on the best path found so far.
        let mut came_from: HashMap<NodeId, NodeId> = HashMap::new();
        // Cheapest known cost from `start` to each node.
        let mut g_score: HashMap<NodeId, f32> = HashMap::new();
        g_score.insert(start, 0.0);
        // Best f-score each node has been queued with; used to detect stale entries.
        let mut f_score: HashMap<NodeId, f32> = HashMap::new();
        f_score.insert(start, h_start);

        while let Some(State {
            cost: current_f,
            node: current,
        }) = open_set.pop()
        {
            if current == end {
                return Ok(Self::reconstruct_path(&came_from, current));
            }

            // A node is pushed again whenever a cheaper route to it is found.
            // The cheaper entry pops first and performs every relaxation, so an
            // entry whose priority is worse than the recorded best is a no-op;
            // skipping it avoids re-reading the node and all of its neighbours.
            if matches!(f_score.get(&current), Some(&best) if current_f > best) {
                continue;
            }

            // Borrow the already-loaded start node instead of cloning it.
            let fetched;
            let current_node = if current == start {
                &start_node
            } else {
                fetched = self.db.get_node(current)?;
                &fetched
            };
            let current_vec = current_node
                .properties
                .get(vector_prop)
                .and_then(|v| v.as_arc_vector());

            // Constant for the whole edge loop: step costs are non-negative, so
            // no edge out of `current` can lower `current`'s own g-score.
            let current_g = g_score.get(&current).copied().unwrap_or(f32::INFINITY);

            for edge_id in self.db.get_outgoing_edges(current) {
                let neighbor = self.db.get_edge_target(edge_id)?;
                let neighbor_node = self.db.get_node(neighbor)?;
                let neighbor_vec = neighbor_node
                    .properties
                    .get(vector_prop)
                    .and_then(|v| v.as_arc_vector());

                let step_cost = match (&current_vec, &neighbor_vec) {
                    (Some(a), Some(b)) => Self::cosine_distance(cosine_similarity(a, b)?),
                    // Without both vectors the edge gets a neutral cost.
                    _ => 1.0,
                };

                let tentative_g = current_g + step_cost;
                let known_g = g_score.get(&neighbor).copied().unwrap_or(f32::INFINITY);
                // A NaN `tentative_g` fails this comparison, so such edges are ignored.
                if tentative_g < known_g {
                    came_from.insert(neighbor, current);
                    g_score.insert(neighbor, tentative_g);

                    let h_score = match &neighbor_vec {
                        Some(vec) => Self::cosine_distance(cosine_similarity(vec, &end_vec)?),
                        None => 1.0,
                    };
                    let f = tentative_g + h_score;
                    f_score.insert(neighbor, f);
                    open_set.push(State {
                        cost: f,
                        node: neighbor,
                    });
                }
            }
        }

        Err(Error::other("No path found"))
    }

    /// Converts a cosine similarity into a traversal cost, `1 - similarity`,
    /// clamped so it is never negative.
    ///
    /// The search relies on non-negative costs: a negative step (which would
    /// arise if a similarity came back marginally above `1.0`, e.g. through
    /// floating-point rounding) could make a node its own ancestor in
    /// `came_from` and send `reconstruct_path` into an endless loop.
    /// NaN is passed through unchanged.
    fn cosine_distance(similarity: f32) -> f32 {
        let distance = 1.0 - similarity;
        if distance < 0.0 {
            0.0
        } else {
            distance
        }
    }

    /// Follows the `came_from` predecessor links back from `end` and returns
    /// the nodes in forward order (first node without a predecessor first,
    /// `end` last).
    fn reconstruct_path(came_from: &HashMap<NodeId, NodeId>, end: NodeId) -> Vec<NodeId> {
        let mut path: Vec<NodeId> =
            std::iter::successors(Some(end), |node| came_from.get(node).copied()).collect();
        path.reverse();
        path
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AletheiaDBConfig;
    use crate::config::WalConfigBuilder;
    use crate::core::property::PropertyMapBuilder;
    use tempfile::tempdir;

    /// Builds a database rooted in a fresh temporary directory, with index
    /// persistence switched off. The `TempDir` is returned alongside the
    /// database so the caller decides how long the directory lives.
    fn create_test_db() -> (AletheiaDB, tempfile::TempDir) {
        let tmp = tempdir().unwrap();
        let wal_dir = tmp.path().join("wal");
        let data_dir = tmp.path().join("data");
        for path in [&wal_dir, &data_dir] {
            std::fs::create_dir_all(path).unwrap();
        }

        let config = AletheiaDBConfig::builder()
            .wal(WalConfigBuilder::new().wal_dir(wal_dir).build())
            .persistence(crate::storage::index_persistence::PersistenceConfig {
                data_dir,
                enabled: false,
                ..Default::default()
            })
            .build();

        (AletheiaDB::with_unified_config(config).unwrap(), tmp)
    }

    /// Two routes lead from A to C: one through B, whose vector lies between
    /// A's and C's, and one through D, whose vector points away from C's.
    #[test]
    fn test_semantic_path_linear() {
        let (db, _dir) = create_test_db();

        let a = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();
        let b = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.707, 0.707])
                    .build(),
            )
            .unwrap();
        let c = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.0, 1.0])
                    .build(),
            )
            .unwrap();
        let d = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.0, -1.0])
                    .build(),
            )
            .unwrap();

        // A -> B -> C and A -> D -> C, inserted in this exact order.
        for (from, to) in [(a, b), (b, c), (a, d), (d, c)] {
            db.create_edge(from, to, "NEXT", PropertyMapBuilder::new().build())
                .unwrap();
        }

        let navigator = SemanticNavigator::new(&db);
        let found = navigator.find_path(a, c, "vec").unwrap();
        assert_eq!(
            found,
            vec![a, b, c],
            "Should choose the semantically closer path via B"
        );
    }

    /// Neither endpoint carries the vector property, so the search must fail.
    #[test]
    fn test_missing_vector_fail() {
        let (db, _dir) = create_test_db();

        let empty_node = || {
            db.create_node("Node", PropertyMapBuilder::new().build())
                .unwrap()
        };
        let a = empty_node();
        let b = empty_node();

        let outcome = SemanticNavigator::new(&db).find_path(a, b, "vec");
        assert!(
            outcome.is_err(),
            "Should fail if start/end nodes lack vectors"
        );
    }
}