use crate::AletheiaDB;
use crate::core::error::{Error, Result};
use crate::core::hasher::IdentityHasher;
use crate::core::id::{EdgeId, NodeId};
use crate::core::vector::cosine_similarity;
use dashmap::DashMap;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::hash::BuildHasherDefault;
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
/// Shared store of per-edge usage counters.
///
/// Counts how often each edge has been traversed via [`SynapseContext::observe`];
/// these counters bias the adaptive path search toward frequently used edges.
pub struct SynapseContext {
    // Keyed by EdgeId with an identity hasher; values are atomic so counters
    // can be incremented through a shared `&self`.
    weights: DashMap<EdgeId, AtomicU64, BuildHasherDefault<IdentityHasher>>,
}
impl Default for SynapseContext {
fn default() -> Self {
Self::new()
}
}
impl SynapseContext {
    /// Creates an empty context with no recorded edge usage.
    pub fn new() -> Self {
        Self {
            weights: DashMap::with_hasher(BuildHasherDefault::default()),
        }
    }

    /// Returns the recorded usage count for `edge_id`, or 0 if the edge has
    /// never been observed.
    pub fn get_usage(&self, edge_id: EdgeId) -> u64 {
        self.weights
            .get(&edge_id)
            .map_or(0, |val| val.load(AtomicOrdering::Relaxed))
    }

    /// Records one traversal of `edge_id`, incrementing its usage counter.
    ///
    /// The DashMap entry API holds the shard lock across the update, so a
    /// concurrent first observation cannot be lost between the miss check
    /// and the insert.
    pub fn observe(&self, edge_id: EdgeId) {
        self.weights
            .entry(edge_id)
            .and_modify(|val| {
                val.fetch_add(1, AtomicOrdering::Relaxed);
            })
            .or_insert(AtomicU64::new(1));
    }

    /// Multiplies every usage counter by `factor` (expected in `[0, 1)`),
    /// truncating toward zero. A `factor >= 1.0` is a no-op so decay can
    /// never inflate counters.
    pub fn decay(&self, factor: f32) {
        if factor >= 1.0 {
            return;
        }
        for entry in self.weights.iter() {
            // fetch_update runs a compare-exchange loop, so a concurrent
            // observe() increment landing between the read and the write is
            // retried instead of silently overwritten (the previous separate
            // load + store could lose such an update).
            let _ = entry.value().fetch_update(
                AtomicOrdering::Relaxed,
                AtomicOrdering::Relaxed,
                |current| Some((current as f32 * factor) as u64),
            );
        }
    }

    /// Removes all recorded usage counts.
    pub fn clear(&self) {
        self.weights.clear();
    }
}
/// Adaptive pathfinding handle: borrows the graph database together with the
/// shared usage-counter context that biases edge costs.
pub struct Synapse<'a> {
    db: &'a AletheiaDB,
    context: &'a SynapseContext,
}
/// Priority-queue entry for the path search: a node plus the estimated total
/// cost (f-score) at which it was queued.
#[derive(Clone, Copy, PartialEq)]
struct State {
    cost: f32,
    node: NodeId,
}
// f32 is not `Eq` because of NaN. NOTE(review): this impl assumes `cost` is
// never NaN — costs here are built from clamped/finite values, and the Ord
// impl treats incomparable costs as equal.
impl Eq for State {}
impl Ord for State {
    /// Compares by `cost` in *reverse* so that `BinaryHeap` (a max-heap)
    /// pops the lowest-cost state first; incomparable costs (NaN) compare
    /// as equal.
    fn cmp(&self, other: &Self) -> Ordering {
        self.cost
            .partial_cmp(&other.cost)
            .map_or(Ordering::Equal, Ordering::reverse)
    }
}
impl PartialOrd for State {
    /// Total order is defined by [`Ord::cmp`]; delegate to keep the two
    /// impls consistent.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl<'a> Synapse<'a> {
pub fn new(db: &'a AletheiaDB, context: &'a SynapseContext) -> Self {
Self { db, context }
}
pub fn observe(&self, edge_id: EdgeId) {
self.context.observe(edge_id);
}
pub fn adaptive_semantic_path(
&self,
start: NodeId,
end: NodeId,
vector_prop: &str,
) -> Result<Vec<NodeId>> {
let start_node = self.db.get_node(start)?;
let end_node = self.db.get_node(end)?;
let _start_vec = start_node
.properties
.get(vector_prop)
.and_then(|v| v.as_arc_vector())
.ok_or_else(|| {
Error::other(format!(
"Start node {} missing vector property '{}'",
start, vector_prop
))
})?;
let _end_vec = end_node
.properties
.get(vector_prop)
.and_then(|v| v.as_arc_vector())
.ok_or_else(|| {
Error::other(format!(
"End node {} missing vector property '{}'",
end, vector_prop
))
})?;
let mut open_set = BinaryHeap::new();
open_set.push(State {
cost: 0.0,
node: start,
});
let mut came_from: HashMap<NodeId, NodeId, BuildHasherDefault<IdentityHasher>> =
HashMap::with_hasher(BuildHasherDefault::default());
let mut g_score: HashMap<NodeId, f32, BuildHasherDefault<IdentityHasher>> =
HashMap::with_hasher(BuildHasherDefault::default());
g_score.insert(start, 0.0);
while let Some(State {
cost: _current_f,
node: current,
}) = open_set.pop()
{
if current == end {
return Ok(self.reconstruct_path(came_from, current));
}
let current_node = if current == start {
start_node.clone()
} else {
self.db.get_node(current)?
};
let current_vec = current_node
.properties
.get(vector_prop)
.and_then(|v| v.as_arc_vector());
for edge_id in self.db.get_outgoing_edges(current) {
let neighbor = self.db.get_edge_target(edge_id)?;
let neighbor_node = self.db.get_node(neighbor)?;
let neighbor_vec = neighbor_node
.properties
.get(vector_prop)
.and_then(|v| v.as_arc_vector());
let semantic_cost = match (¤t_vec, &neighbor_vec) {
(Some(a), Some(b)) => (1.0 - cosine_similarity(a, b)?).max(0.001), _ => 1.0, };
let usage = self.context.get_usage(edge_id);
let weight_factor = 1.0 / (1.0 + ((1 + usage) as f32).log2());
let edge_cost = semantic_cost * weight_factor;
let tentative_g = g_score.get(¤t).unwrap_or(&f32::INFINITY) + edge_cost;
if tentative_g < *g_score.get(&neighbor).unwrap_or(&f32::INFINITY) {
came_from.insert(neighbor, current);
g_score.insert(neighbor, tentative_g);
let h_score = 0.0;
let f = tentative_g + h_score;
open_set.push(State {
cost: f,
node: neighbor,
});
}
}
}
Err(Error::other("No path found"))
}
fn reconstruct_path(
&self,
came_from: HashMap<NodeId, NodeId, BuildHasherDefault<IdentityHasher>>,
mut current: NodeId,
) -> Vec<NodeId> {
let mut total_path = vec![current];
while let Some(&prev) = came_from.get(¤t) {
current = prev;
total_path.push(current);
}
total_path.reverse();
total_path
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;

    /// All four nodes share the same vector, so routes A->B->D and A->C->D
    /// have equal semantic cost; repeatedly observing the A->B->D edges must
    /// tip the tie toward the trained route.
    #[test]
    fn test_synapse_hebbian_learning() {
        let db = AletheiaDB::new().unwrap();
        let props = PropertyMapBuilder::new()
            .insert_vector("vec", &[1.0, 0.0])
            .build();
        let a = db.create_node("Node", props.clone()).unwrap();
        let b = db.create_node("Node", props.clone()).unwrap();
        let c = db.create_node("Node", props.clone()).unwrap();
        let d = db.create_node("Node", props.clone()).unwrap();
        let e_ab = db
            .create_edge(a, b, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let e_bd = db
            .create_edge(b, d, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let _e_ac = db
            .create_edge(a, c, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let _e_cd = db
            .create_edge(c, d, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let context = SynapseContext::new();
        let synapse = Synapse::new(&db, &context);
        // Train the A->B->D route.
        for _ in 0..10 {
            synapse.observe(e_ab);
            synapse.observe(e_bd);
        }
        let path = synapse.adaptive_semantic_path(a, d, "vec").unwrap();
        assert_eq!(path, vec![a, b, d]);
    }

    /// B is orthogonal to A (semantic cost ~1.0 per hop) while C is close
    /// (cosine 0.8 -> cost ~0.2), so the untrained search picks A->C->target.
    /// Heavy training on A->B->target must make the logarithmic usage
    /// discount outweigh the semantic penalty.
    #[test]
    fn test_synapse_semantic_vs_popularity_tradeoff() {
        let db = AletheiaDB::new().unwrap();
        let props_a = PropertyMapBuilder::new()
            .insert_vector("vec", &[1.0, 0.0])
            .build();
        let a = db.create_node("Node", props_a).unwrap();
        let props_b = PropertyMapBuilder::new()
            .insert_vector("vec", &[0.0, 1.0])
            .build();
        let b = db.create_node("Node", props_b).unwrap();
        let props_c = PropertyMapBuilder::new()
            .insert_vector("vec", &[0.8, 0.6])
            .build();
        let c = db.create_node("Node", props_c).unwrap();
        let props_target = PropertyMapBuilder::new()
            .insert_vector("vec", &[1.0, 0.0])
            .build();
        let target = db.create_node("Node", props_target).unwrap();
        let e_ab = db
            .create_edge(a, b, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let e_bt = db
            .create_edge(b, target, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let _e_ac = db
            .create_edge(a, c, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let _e_ct = db
            .create_edge(c, target, "NEXT", PropertyMapBuilder::new().build())
            .unwrap();
        let context = SynapseContext::new();
        let synapse = Synapse::new(&db, &context);
        let initial_path = synapse.adaptive_semantic_path(a, target, "vec").unwrap();
        assert_eq!(
            initial_path,
            vec![a, c, target],
            "Should prefer semantic match initially"
        );
        // 10_000 observations: weight factor ~= 1/(1 + log2(10001)) ~= 0.07,
        // making the trained hops cheaper than the ~0.2 semantic route.
        for _ in 0..10_000 {
            synapse.observe(e_ab);
            synapse.observe(e_bt);
        }
        let adapted_path = synapse.adaptive_semantic_path(a, target, "vec").unwrap();
        assert_eq!(
            adapted_path,
            vec![a, b, target],
            "Should prefer popular path after training"
        );
    }

    /// decay() multiplies counters by the factor, truncating toward zero:
    /// 2 -> 1 -> 0.
    #[test]
    fn test_synapse_decay() {
        let context = SynapseContext::new();
        let edge = EdgeId::new(1).unwrap();
        context.observe(edge);
        context.observe(edge);
        assert_eq!(context.get_usage(edge), 2);
        context.decay(0.5);
        assert_eq!(context.get_usage(edge), 1);
        context.decay(0.5);
        assert_eq!(context.get_usage(edge), 0);
    }
}