use crate::node::{NodeId, SearchHit};
use crate::storage::memtable::MemTable;
use std::collections::HashMap;
/// Expands `seeds` by spreading activation energy along the graph stored in `db`.
///
/// Each seed starts with its `score` as energy. For up to `max_depth` hops, every
/// node in the current tier that has outgoing edges (`db.get_edges`) sends
/// `energy * max(1 - teleport_alpha, 0)` across each edge, scaled by the edge's
/// `weight`. Edges labelled `"inhibition"` transmit the negated amount. Every
/// transmitted amount is added both to the target's energy for the next tier and
/// to the target's total activation. Only targets whose next-tier energy is
/// strictly positive keep propagating.
///
/// Optional modulators:
/// - `enable_inverse_inhibition`: transmitted energy is multiplied by
///   `1 / max(in_degree, 1)^0.55`, using `db.get_in_degree` of the target.
/// - `lateral_inhibition_threshold`: when non-zero, only the strongest
///   `lateral_inhibition_threshold` nodes of each tier keep propagating. Pruned
///   nodes keep the activation they already received.
/// - `enable_refractory_fatigue`: energy sent to a target whose `db.get_fatigue`
///   is positive is multiplied by 0.15. After the pass, those targets (sorted and
///   deduplicated) are passed to `db.consume_fatigue_batch`, and the ids of the
///   top 15 results are passed to `db.mark_fatigued`.
///
/// Returns one `SearchHit` per activated node that has a payload in `db`, sorted
/// by descending total activation. NaN scores sort last, and equal scores are
/// ordered by ascending `NodeId` so the output is deterministic. A node reached
/// only through inhibitory edges is still returned, with a negative score.
///
/// If `max_depth == 0`, `seeds` is returned unchanged: no payload lookup, no
/// sorting and no fatigue bookkeeping.
pub fn expand_graph<T: crate::VectorType>(
    db: &MemTable<T>,
    seeds: Vec<SearchHit>,
    max_depth: usize,
    teleport_alpha: f32,
    enable_inverse_inhibition: bool,
    lateral_inhibition_threshold: usize,
    enable_refractory_fatigue: bool,
) -> Vec<SearchHit> {
    /// Energies at or below this value are neither spread nor carried to the next tier.
    const PROPAGATION_THRESHOLD: f32 = 0.0;
    /// Exponent applied to the target's in-degree under inverse inhibition.
    const IN_DEGREE_EXPONENT: f32 = 0.55;
    /// Multiplier applied to energy sent into a target with positive fatigue.
    const FATIGUE_DISCOUNT: f32 = 0.15;
    /// Number of top-ranked results handed to `db.mark_fatigued`.
    const FATIGUE_TOP_K: usize = 15;

    /// Maps NaN to negative infinity so it ranks below every real score.
    fn rank_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }

    /// Orders `(score, id)` pairs by descending score, then ascending id.
    ///
    /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a genuine total order,
    /// even when NaN is present. `slice::sort_by` requires a total order and may
    /// panic otherwise (Rust >= 1.81). The id tie-break removes the dependence on
    /// `HashMap` iteration order.
    fn by_score_desc(
        a_score: f32,
        a_id: NodeId,
        b_score: f32,
        b_id: NodeId,
    ) -> std::cmp::Ordering {
        rank_key(b_score)
            .total_cmp(&rank_key(a_score))
            .then_with(|| a_id.cmp(&b_id))
    }

    if max_depth == 0 {
        return seeds;
    }

    let mut total_activation = HashMap::<NodeId, f32>::new();
    let mut current_tier = HashMap::<NodeId, f32>::new();
    // Targets found fatigued during propagation; may contain duplicates until deduped below.
    let mut active_fatigue = Vec::new();
    for seed in &seeds {
        total_activation.insert(seed.id, seed.score);
        current_tier.insert(seed.id, seed.score);
    }

    // Loop-invariant damping applied to a node's energy before it is spread.
    let spread_factor = (1.0 - teleport_alpha).max(0.0);

    for _ in 0..max_depth {
        let mut next_tier = HashMap::<NodeId, f32>::new();
        for (curr_id, curr_energy) in current_tier {
            if let Some(edges) = db.get_edges(curr_id) {
                let spread_energy = curr_energy * spread_factor;
                if spread_energy <= PROPAGATION_THRESHOLD {
                    continue;
                }
                for edge in edges {
                    let inhibition_factor = if enable_inverse_inhibition {
                        // Clamp to 1 so a zero in-degree does not divide by zero.
                        let in_degree = db.get_in_degree(edge.target_id).max(1) as f32;
                        1.0 / in_degree.powf(IN_DEGREE_EXPONENT)
                    } else {
                        1.0
                    };
                    let fatigue_discount =
                        if enable_refractory_fatigue && db.get_fatigue(edge.target_id) > 0 {
                            active_fatigue.push(edge.target_id);
                            FATIGUE_DISCOUNT
                        } else {
                            1.0
                        };
                    let magnitude =
                        spread_energy * edge.weight * inhibition_factor * fatigue_discount;
                    let transmitted = if edge.label == "inhibition" {
                        -magnitude
                    } else {
                        magnitude
                    };
                    *next_tier.entry(edge.target_id).or_insert(0.0) += transmitted;
                    *total_activation.entry(edge.target_id).or_insert(0.0) += transmitted;
                }
            }
        }

        // Only positively-energised nodes propagate further; this also drops NaN energies.
        next_tier.retain(|_, energy| *energy > PROPAGATION_THRESHOLD);

        // Lateral inhibition: keep only the strongest nodes of this tier.
        if lateral_inhibition_threshold > 0 && next_tier.len() > lateral_inhibition_threshold {
            let mut strongest: Vec<(NodeId, f32)> = next_tier.into_iter().collect();
            strongest.sort_by(|a, b| by_score_desc(a.1, a.0, b.1, b.0));
            strongest.truncate(lateral_inhibition_threshold);
            next_tier = strongest.into_iter().collect();
        }

        if next_tier.is_empty() {
            break;
        }
        current_tier = next_tier;
    }

    // Nodes without a payload in `db` (seeds included) are dropped from the result.
    let mut expanded_results: Vec<SearchHit> = total_activation
        .into_iter()
        .filter_map(|(id, score)| {
            db.get_payload(id).map(|payload| SearchHit {
                id,
                score,
                payload: payload.clone(),
            })
        })
        .collect();
    expanded_results.sort_by(|a, b| by_score_desc(a.score, a.id, b.score, b.id));

    if enable_refractory_fatigue {
        if !active_fatigue.is_empty() {
            active_fatigue.sort_unstable();
            active_fatigue.dedup();
            // NOTE(review): called through `&MemTable`, so presumably uses interior mutability.
            db.consume_fatigue_batch(&active_fatigue);
        }
        if !expanded_results.is_empty() {
            let top_ids: Vec<NodeId> = expanded_results
                .iter()
                .take(FATIGUE_TOP_K)
                .map(|hit| hit.id)
                .collect();
            db.mark_fatigued(&top_ids);
        }
    }

    expanded_results
}