use crate::AletheiaDB;
use crate::core::error::{Error, Result};
use crate::core::id::NodeId;
use crate::core::vector::ops::normalize;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// A viewpoint used to score nodes: a direction in vector space.
///
/// The stored vector is whatever `normalize` returns for the input handed to
/// [`Lens::new`]; constructing the struct literally bypasses that step.
#[derive(Debug, Clone)]
pub struct Lens {
    /// Direction of the lens, as produced by `normalize` in [`Lens::new`].
    pub vector: Vec<f32>,
}

impl Lens {
    /// Builds a lens from `vector`, passing it through `normalize` first.
    pub fn new(vector: &[f32]) -> Self {
        let vector = normalize(vector);
        Self { vector }
    }
}
/// Lens-driven query helper that borrows an [`AletheiaDB`] for its lifetime.
pub struct Spectre<'db> {
    /// Database that node and edge lookups are issued against.
    db: &'db AletheiaDB,
}
/// Priority-queue entry for `Spectre::traverse`: a node together with the
/// accumulated cost of the cheapest known route to it.
///
/// Ordering and equality consider `cost` only, and the ordering is reversed so
/// that `BinaryHeap` (a max-heap) pops the *lowest* cost first. Equality is
/// defined through `cmp` so that `a == b` holds exactly when
/// `a.cmp(&b) == Ordering::Equal`, as the `Ord` contract requires.
#[derive(Clone, Copy)]
struct State {
    /// Accumulated path cost from the start node.
    cost: f32,
    /// Node this entry refers to.
    node: NodeId,
}

impl PartialEq for State {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for State {}

impl Ord for State {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order over f32 (NaN included), unlike
        // `partial_cmp`, which has no answer for NaN. Operands are swapped to
        // turn the max-heap into a min-heap on cost.
        other.cost.total_cmp(&self.cost)
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl<'a> Spectre<'a> {
    /// Creates a `Spectre` that reads from `db`.
    pub fn new(db: &'a AletheiaDB) -> Self {
        Self { db }
    }

    /// Scores how well a node lines up with `lens`.
    ///
    /// Looks up `node_id`, reads its `property` as a vector, and returns the
    /// result of `cosine_similarity` between that vector and `lens.vector`.
    ///
    /// # Errors
    ///
    /// * Propagates the error from `get_node` if the lookup fails.
    /// * Returns `Error::other` if the node has no `property`, or the property
    ///   is not a vector.
    /// * Propagates whatever error `cosine_similarity` reports.
    pub fn focus(&self, node_id: NodeId, lens: &Lens, property: &str) -> Result<f32> {
        let node = self.db.get_node(node_id)?;
        let vec = node
            .properties
            .get(property)
            .and_then(|v| v.as_vector())
            .ok_or_else(|| {
                Error::other(format!(
                    "Node {} missing vector property '{}'",
                    node_id, property
                ))
            })?;
        crate::core::vector::cosine_similarity(vec, &lens.vector)
    }

    /// Finds the cheapest path from `start` to `end` as seen through `lens`.
    ///
    /// Runs a Dijkstra search over outgoing edges. Stepping onto a node costs
    /// `1.0 - focus + 0.001`, where `focus` is that node's [`Spectre::focus`]
    /// score for `vector_prop`, so well-aligned nodes are cheap to enter and
    /// every hop still carries a small fixed penalty.
    ///
    /// Scoring is best-effort:
    /// * a neighbour that cannot be scored (error or NaN) is treated as the
    ///   worst possible match (`-1.0`) instead of being dropped;
    /// * scores are clamped to `[-1.0, 1.0]` so edge weights stay strictly
    ///   positive, which the early exit on reaching `end` relies on;
    /// * edges whose target cannot be resolved are skipped.
    ///
    /// # Returns
    ///
    /// * `Ok(vec![start])` when `start == end`.
    /// * `Ok(path)` from `start` to `end` inclusive when a route exists.
    /// * `Ok(vec![])` when `end` is not reachable from `start`.
    pub fn traverse(
        &self,
        start: NodeId,
        end: NodeId,
        lens: &Lens,
        vector_prop: &str,
    ) -> Result<Vec<NodeId>> {
        // Fixed per-hop penalty so perfectly aligned hops are still not free.
        const HOP_PENALTY: f32 = 0.001;
        // Score assumed for a neighbour whose focus cannot be determined.
        const WORST_FOCUS: f32 = -1.0;

        if start == end {
            return Ok(vec![start]);
        }

        let mut dist: HashMap<NodeId, f32> = HashMap::new();
        let mut came_from: HashMap<NodeId, NodeId> = HashMap::new();
        let mut pq = BinaryHeap::new();

        dist.insert(start, 0.0);
        pq.push(State {
            cost: 0.0,
            node: start,
        });

        while let Some(State { cost, node }) = pq.pop() {
            if node == end {
                return Ok(self.reconstruct_path(came_from, end));
            }

            // Stale heap entry: a cheaper route to `node` was already recorded.
            if cost > dist.get(&node).copied().unwrap_or(f32::INFINITY) {
                continue;
            }

            for edge_id in self.db.get_outgoing_edges(node) {
                let neighbor = match self.db.get_edge_target(edge_id) {
                    Ok(target) => target,
                    Err(_) => continue,
                };

                // A NaN score would make every `<` comparison below false and
                // silently cut the neighbour out of the graph; an out-of-range
                // score could yield a negative edge weight.
                let focus_score = match self.focus(neighbor, lens, vector_prop) {
                    Ok(score) if !score.is_nan() => score.clamp(-1.0, 1.0),
                    _ => WORST_FOCUS,
                };

                let step_cost = 1.0 - focus_score + HOP_PENALTY;
                let new_cost = cost + step_cost;

                if new_cost < dist.get(&neighbor).copied().unwrap_or(f32::INFINITY) {
                    dist.insert(neighbor, new_cost);
                    came_from.insert(neighbor, node);
                    pq.push(State {
                        cost: new_cost,
                        node: neighbor,
                    });
                }
            }
        }

        // Frontier exhausted without reaching `end`.
        Ok(Vec::new())
    }

    /// Walks the `came_from` predecessor links back from `current` and returns
    /// the visited nodes in forward order, ending with `current`.
    fn reconstruct_path(&self, came_from: HashMap<NodeId, NodeId>, current: NodeId) -> Vec<NodeId> {
        let mut path = vec![current];
        let mut curr = current;
        while let Some(&prev) = came_from.get(&curr) {
            path.push(prev);
            curr = prev;
        }
        path.reverse();
        path
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;
    use crate::index::vector::{DistanceMetric, HnswConfig};

    /// Opens a fresh database with a cosine vector index enabled on `"vec"`
    /// (configured via `HnswConfig::new(2, ..)`; the `2` is presumably the
    /// vector dimension).
    fn create_db() -> AletheiaDB {
        let db = AletheiaDB::new().unwrap();
        let config = HnswConfig::new(2, DistanceMetric::Cosine);
        db.enable_vector_index("vec", config).unwrap();
        db
    }

    /// A node pointing along X scores ~1 through an X lens and ~0 through a Y lens.
    #[test]
    fn test_spectre_focus() {
        let db = create_db();
        let node = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();
        let spectre = Spectre::new(&db);

        let along_x = Lens::new(&[1.0, 0.0]);
        let aligned = spectre.focus(node, &along_x, "vec").unwrap();
        assert!((aligned - 1.0).abs() < 1e-5);

        let along_y = Lens::new(&[0.0, 1.0]);
        let orthogonal = spectre.focus(node, &along_y, "vec").unwrap();
        assert!((orthogonal - 0.0).abs() < 1e-5);
    }

    /// Two equal-length routes exist from start to end; the lens decides
    /// which intermediate node is cheaper to pass through.
    #[test]
    fn test_spectre_subjective_traversal() {
        let db = create_db();

        // Nodes are created in the order Start, End, HighRoad, LowRoad.
        let start = db
            .create_node(
                "Start",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.0, 0.0])
                    .build(),
            )
            .unwrap();
        let end = db
            .create_node(
                "End",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.0, 0.0])
                    .build(),
            )
            .unwrap();
        let high = db
            .create_node(
                "HighRoad",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[0.0, 1.0])
                    .build(),
            )
            .unwrap();
        let low = db
            .create_node(
                "LowRoad",
                PropertyMapBuilder::new()
                    .insert_vector("vec", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();

        // Diamond: start -> high -> end and start -> low -> end.
        for &(from, to) in &[(start, high), (high, end), (start, low), (low, end)] {
            db.create_edge(from, to, "path", Default::default())
                .unwrap();
        }

        let spectre = Spectre::new(&db);

        let through_y = spectre
            .traverse(start, end, &Lens::new(&[0.0, 1.0]), "vec")
            .unwrap();
        assert_eq!(
            through_y,
            vec![start, high, end],
            "Should take HighRoad with Y lens"
        );

        let through_x = spectre
            .traverse(start, end, &Lens::new(&[1.0, 0.0]), "vec")
            .unwrap();
        assert_eq!(
            through_x,
            vec![start, low, end],
            "Should take LowRoad with X lens"
        );
    }
}