use crate::AletheiaDB;
use crate::core::error::Result;
use crate::core::id::NodeId;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
/// A single node visit in a thread produced by [`Ariadne::weave`].
#[derive(Debug, Clone)]
pub struct ThreadStep {
/// Node visited at this step.
pub node_id: NodeId,
/// Integer value of the node's time property, as read when the step was created.
pub timestamp: i64,
/// How this node was reached from the step before it.
pub connection_type: ConnectionType,
}
/// How a [`ThreadStep`] was reached from the step before it.
#[derive(Debug, Clone, PartialEq)]
pub enum ConnectionType {
/// The first step of a thread; there is no preceding step.
Start,
/// Reached by following an outgoing graph edge.
Edge {
/// Edge label resolved to text (`"unknown"` when the interned label cannot be resolved).
label: String,
},
/// Reached through a vector-similarity lookup rather than an explicit edge.
SemanticJump {
/// Score reported by the similarity search for the jump target.
similarity: f32,
},
}
/// A frontier entry of the best-first search run by `Ariadne::weave`.
///
/// Equality and ordering consider `f_cost` only, and the ordering is
/// *reversed* so that `BinaryHeap` (a max-heap) yields the cheapest state
/// first.
#[derive(Debug)]
struct SearchState {
    /// Node this state sits on.
    node_id: NodeId,
    /// Timestamp of `node_id`; successors must not be earlier than this.
    timestamp: i64,
    /// Accumulated cost of the path walked so far.
    g_cost: f32,
    /// `g_cost` plus the heuristic estimate towards the goal; the heap priority.
    f_cost: f32,
    /// Every step from the start node up to and including `node_id`.
    path: Vec<ThreadStep>,
}

impl SearchState {
    /// Cost used for ordering: identical to `f_cost`, except that NaN (of
    /// either sign) is mapped to `+inf` so that a state whose cost could not
    /// be computed is always explored last.
    fn priority_key(&self) -> f32 {
        if self.f_cost.is_nan() {
            f32::INFINITY
        } else {
            self.f_cost
        }
    }
}

impl PartialEq for SearchState {
    /// Defined through `cmp` so that `==` and the ordering can never disagree
    /// (a plain `f32 ==` is not reflexive for NaN, which `Eq` requires).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchState {}

impl Ord for SearchState {
    /// Reversed total order on the priority key: the lower the cost, the
    /// "greater" the state.
    ///
    /// `total_cmp` is used instead of `partial_cmp(..).unwrap_or(Equal)`:
    /// the latter makes NaN equal to everything, which is not transitive and
    /// lets a NaN-cost state end up anywhere in the heap.
    fn cmp(&self, other: &Self) -> Ordering {
        other.priority_key().total_cmp(&self.priority_key())
    }
}

impl PartialOrd for SearchState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Path finder that weaves a thread of [`ThreadStep`]s through the nodes of an
/// [`AletheiaDB`], combining graph edges with vector-similarity jumps.
pub struct Ariadne<'a> {
/// Database that every node, edge and similarity lookup goes through.
db: &'a AletheiaDB,
}
impl<'a> Ariadne<'a> {
pub fn new(db: &'a AletheiaDB) -> Self {
Self { db }
}
pub fn weave(
&self,
start_node: NodeId,
goal_node: Option<NodeId>,
time_property: &str,
vector_property: &str,
max_steps: usize,
beam_width: usize,
) -> Result<Vec<ThreadStep>> {
let start_time = self.get_time(start_node, time_property)?;
let start_step = ThreadStep {
node_id: start_node,
timestamp: start_time,
connection_type: ConnectionType::Start,
};
let goal_vector = if let Some(goal) = goal_node {
if let Ok(node) = self.db.get_node(goal) {
node.properties
.get(vector_property)
.and_then(|v| v.as_vector())
.map(|v| v.to_vec())
} else {
None
}
} else {
None
};
let mut pq = BinaryHeap::new();
let heuristic = self.calculate_heuristic(start_node, vector_property, &goal_vector);
pq.push(SearchState {
node_id: start_node,
timestamp: start_time,
g_cost: 0.0,
f_cost: heuristic,
path: vec![start_step],
});
let mut visited = HashSet::new();
let mut best_path = Vec::new();
let mut max_path_len = 0;
while let Some(state) = pq.pop() {
if !visited.insert(state.node_id) {
continue;
}
if let Some(goal) = goal_node
&& state.node_id == goal
{
return Ok(state.path);
}
if state.path.len() > max_path_len {
max_path_len = state.path.len();
best_path = state.path.clone();
}
if state.path.len() >= max_steps {
continue;
}
let edges = self.db.current.get_outgoing_edges(state.node_id);
for edge_id in edges {
if let Ok(edge) = self.db.get_edge(edge_id) {
let target = edge.target;
if visited.contains(&target) {
continue;
}
if let Ok(target_time) = self.get_time(target, time_property)
&& target_time >= state.timestamp
{
let label = self.resolve_label(edge.label);
let mut new_path = state.path.clone();
new_path.push(ThreadStep {
node_id: target,
timestamp: target_time,
connection_type: ConnectionType::Edge { label },
});
let heuristic =
self.calculate_heuristic(target, vector_property, &goal_vector);
let new_g = state.g_cost + 1.0;
let f_cost = new_g + heuristic;
pq.push(SearchState {
node_id: target,
timestamp: target_time,
g_cost: new_g,
f_cost,
path: new_path,
});
}
}
}
if let Ok(current_node) = self.db.get_node(state.node_id)
&& let Some(current_vector) = current_node
.properties
.get(vector_property)
.and_then(|v| v.as_vector())
{
let min_time = state.timestamp;
let candidates = self.db.find_similar_with_predicate(
vector_property,
current_vector,
beam_width, |candidate_id| {
if *candidate_id == state.node_id {
return false;
}
if let Ok(node) = self.db.get_node(*candidate_id)
&& let Some(val) = node.properties.get(time_property)
{
let time = val.as_int().unwrap_or(0);
return time >= min_time;
}
false
},
);
if let Ok(results) = candidates {
for (target, score) in results {
if visited.contains(&target) {
continue;
}
if let Ok(target_time) = self.get_time(target, time_property) {
let mut new_path = state.path.clone();
new_path.push(ThreadStep {
node_id: target,
timestamp: target_time,
connection_type: ConnectionType::SemanticJump { similarity: score },
});
let heuristic =
self.calculate_heuristic(target, vector_property, &goal_vector);
let jump_cost = 1.5 + (1.0 - score) * 5.0;
let new_g = state.g_cost + jump_cost;
let f_cost = new_g + heuristic;
pq.push(SearchState {
node_id: target,
timestamp: target_time,
g_cost: new_g,
f_cost,
path: new_path,
});
}
}
}
}
}
Ok(best_path)
}
fn get_time(&self, node_id: NodeId, property: &str) -> Result<i64> {
let node = self.db.get_node(node_id)?;
let val = node.properties.get(property).ok_or_else(|| {
crate::core::error::Error::Storage(crate::core::error::StorageError::PropertyNotFound(
property.to_string(),
))
})?;
val.as_int().ok_or_else(|| {
crate::core::error::Error::Query(crate::core::error::QueryError::TypeMismatch {
expected: "Integer".to_string(),
actual: format!("{:?}", val),
})
})
}
/// Turns an interned edge label back into owned text, falling back to
/// `"unknown"` when the global interner has no entry for `label_id`.
fn resolve_label(&self, label_id: crate::core::interning::InternedString) -> String {
    use crate::core::interning::GLOBAL_INTERNER;

    match GLOBAL_INTERNER.resolve_with(label_id, |text| text.to_string()) {
        Some(label) => label,
        None => "unknown".to_string(),
    }
}
/// Estimates the remaining cost from `node_id` to the goal as
/// `max(0, 1 - dot(node_vector, goal_vector))`.
///
/// Returns `0.0` (no guidance) when there is no goal vector, the node cannot
/// be loaded, or it has no vector under `vector_prop`. The dot product runs
/// over the shorter of the two vectors.
///
/// NOTE(review): this equals a cosine distance only if both vectors are
/// unit-length — nothing here normalizes them; confirm against the callers.
fn calculate_heuristic(
    &self,
    node_id: NodeId,
    vector_prop: &str,
    goal_vector: &Option<Vec<f32>>,
) -> f32 {
    let Some(goal) = goal_vector else {
        return 0.0;
    };
    let Ok(node) = self.db.get_node(node_id) else {
        return 0.0;
    };
    let Some(embedding) = node
        .properties
        .get(vector_prop)
        .and_then(|value| value.as_vector())
    else {
        return 0.0;
    };

    let dot = embedding
        .iter()
        .zip(goal.iter())
        .map(|(x, y)| x * y)
        .sum::<f32>();
    (1.0 - dot).max(0.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;
    use crate::index::vector::{DistanceMetric, HnswConfig};

    /// Inserts an `Event` node carrying a time, a name and an embedding.
    fn add_event(db: &AletheiaDB, name: &'static str, time: i64, embedding: &[f32]) -> NodeId {
        let props = PropertyMapBuilder::new()
            .insert("time", time)
            .insert("name", name)
            .insert_vector("embedding", embedding)
            .build();
        db.create_node("Event", props).unwrap()
    }

    /// Connects `from` to `to` with a property-less `NEXT` edge.
    fn link_next(db: &AletheiaDB, from: NodeId, to: NodeId) {
        db.create_edge(from, to, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
    }

    /// Two edge-connected pairs (A -> B and C -> D) are only reachable from one
    /// another through a semantic jump B ~> C; E is similar to B but lies in
    /// the past and must therefore be ignored.
    #[test]
    fn test_ariadne_weaving() {
        let db = AletheiaDB::new().unwrap();
        db.enable_vector_index("embedding", HnswConfig::new(2, DistanceMetric::Cosine))
            .unwrap();

        // Creation order is kept exactly as in the scenario description.
        let a = add_event(&db, "A", 10, &[1.0, 0.0]);
        let b = add_event(&db, "B", 20, &[0.9, 0.4]);
        link_next(&db, a, b);
        let c = add_event(&db, "C", 30, &[0.5, 0.8]);
        let d = add_event(&db, "D", 40, &[0.0, 1.0]);
        link_next(&db, c, d);
        let _e = add_event(&db, "E", 5, &[0.9, 0.4]);

        let path = Ariadne::new(&db)
            .weave(a, Some(d), "time", "embedding", 10, 100)
            .unwrap();

        // Print the offending route before the length assertion fires.
        if path.len() != 4 {
            println!(
                "Path found: {:?}",
                path.iter().map(|s| s.node_id).collect::<Vec<_>>()
            );
        }
        assert_eq!(path.len(), 4);
        for (step, expected) in path.iter().zip([a, b, c, d]) {
            assert_eq!(step.node_id, expected);
        }

        assert!(matches!(
            path[1].connection_type,
            ConnectionType::Edge { .. }
        ));
        assert!(matches!(
            path[2].connection_type,
            ConnectionType::SemanticJump { .. }
        ));
        assert!(matches!(
            path[3].connection_type,
            ConnectionType::Edge { .. }
        ));
    }
}