use std::collections::{HashMap, HashSet, VecDeque};
/// Tunable parameters for incremental graph refinement.
#[derive(Debug, Clone)]
pub struct IncrementalConfig {
    /// Whether edge-usage statistics should be collected at all.
    /// NOTE(review): not consumed in this file — semantics assumed from name.
    pub track_edge_usage: bool,
    /// Number of searches per statistics window.
    /// NOTE(review): not consumed in this file — semantics assumed from name.
    pub stats_window: usize,
    /// Minimum recorded searches before edge removal is considered
    /// (compared against `EdgeStats::total_searches`).
    pub min_traversals: usize,
    /// Co-occurrence rate at or above which a new edge is suggested.
    pub edge_add_threshold: f32,
    /// Per-search traversal rate below which an edge is flagged for removal.
    pub edge_remove_threshold: f32,
    /// Upper bound on edge additions suggested in one refinement pass.
    pub max_edges_per_pass: usize,
    /// Whether refinement should run asynchronously.
    /// NOTE(review): not consumed in this file — semantics assumed from name.
    pub async_refinement: bool,
}

impl Default for IncrementalConfig {
    /// Conservative defaults: frequent tracking, a 0.3 add threshold,
    /// a 1% removal threshold, and at most 1000 new edges per pass.
    fn default() -> Self {
        IncrementalConfig {
            track_edge_usage: true,
            async_refinement: true,
            stats_window: 10_000,
            min_traversals: 100,
            max_edges_per_pass: 1000,
            edge_add_threshold: 0.3,
            edge_remove_threshold: 0.01,
        }
    }
}
/// Usage statistics for graph edges and searches, used to drive refinement.
///
/// Edge keys are orientation-normalized to `(min, max)`, so `(a, b)` and
/// `(b, a)` always refer to the same undirected edge.
#[derive(Debug, Default)]
pub struct EdgeStats {
    // Traversal count per normalized edge.
    edge_traversals: HashMap<(usize, usize), u64>,
    // Count of traversals that improved the search, per normalized edge.
    edge_improvements: HashMap<(usize, usize), u64>,
    // Count of searches in which both endpoints were visited together.
    cooccurrence: HashMap<(usize, usize), u64>,
    // Total searches recorded via `record_search`.
    total_searches: u64,
    // Bounded (max 1000) history of recent search entry points.
    recent_entry_points: VecDeque<usize>,
}

impl EdgeStats {
    /// Creates an empty statistics collector.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one traversal of the undirected edge `from`-`to`.
    ///
    /// `improved` marks whether the traversal improved the search; its exact
    /// meaning is defined by the caller.
    pub fn record_traversal(&mut self, from: usize, to: usize, improved: bool) {
        let key = (from.min(to), from.max(to));
        *self.edge_traversals.entry(key).or_default() += 1;
        if improved {
            *self.edge_improvements.entry(key).or_default() += 1;
        }
    }

    /// Records that all of `nodes` were visited in the same search,
    /// incrementing the co-occurrence count of every unordered pair.
    /// Quadratic in `nodes.len()`, so callers should keep the slice small.
    pub fn record_covisit(&mut self, nodes: &[usize]) {
        for i in 0..nodes.len() {
            for j in (i + 1)..nodes.len() {
                let key = (nodes[i].min(nodes[j]), nodes[i].max(nodes[j]));
                *self.cooccurrence.entry(key).or_default() += 1;
            }
        }
    }

    /// Records a completed search and its entry point, keeping at most the
    /// 1000 most recent entry points.
    pub fn record_search(&mut self, entry_point: usize) {
        self.total_searches += 1;
        self.recent_entry_points.push_back(entry_point);
        if self.recent_entry_points.len() > 1000 {
            self.recent_entry_points.pop_front();
        }
    }

    /// Fraction of traversals of edge `from`-`to` that improved the search;
    /// 0.0 for an edge that was never traversed.
    pub fn improvement_rate(&self, from: usize, to: usize) -> f32 {
        let key = (from.min(to), from.max(to));
        let traversals = self.edge_traversals.get(&key).copied().unwrap_or(0);
        if traversals == 0 {
            return 0.0;
        }
        let improvements = self.edge_improvements.get(&key).copied().unwrap_or(0);
        improvements as f32 / traversals as f32
    }

    /// Fraction of recorded searches in which `a` and `b` co-occurred;
    /// 0.0 when no searches have been recorded.
    pub fn cooccurrence_rate(&self, a: usize, b: usize) -> f32 {
        if self.total_searches == 0 {
            return 0.0;
        }
        let key = (a.min(b), a.max(b));
        let count = self.cooccurrence.get(&key).copied().unwrap_or(0);
        count as f32 / self.total_searches as f32
    }

    /// The `n` node pairs with the highest co-occurrence rate, descending.
    /// Empty when no searches have been recorded.
    pub fn top_cooccurrences(&self, n: usize) -> Vec<((usize, usize), f32)> {
        if self.total_searches == 0 {
            return Vec::new();
        }
        let total = self.total_searches as f32;
        let mut pairs: Vec<_> = self
            .cooccurrence
            .iter()
            .map(|(&k, &v)| (k, v as f32 / total))
            .collect();
        // total_cmp gives a total order over f32 (handles NaN deterministically).
        pairs.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        pairs.truncate(n);
        pairs
    }

    /// Edges whose per-search traversal rate is below `threshold`.
    ///
    /// Only edges that have been traversed at least once are considered:
    /// an edge absent from the statistics is never reported.
    pub fn underutilized_edges(&self, threshold: f32) -> Vec<(usize, usize)> {
        // max(1) guards against division by zero before any search is recorded.
        let denom = self.total_searches.max(1) as f32;
        self.edge_traversals
            .iter()
            .filter(|&(_, &count)| (count as f32 / denom) < threshold)
            .map(|(&edge, _)| edge)
            .collect()
    }

    /// Resets all collected statistics, including the entry-point history.
    pub fn clear(&mut self) {
        self.edge_traversals.clear();
        self.edge_improvements.clear();
        self.cooccurrence.clear();
        self.total_searches = 0;
        // Fix: previously left untouched, so stale entry points leaked into
        // hot-entry-point analysis of the next statistics window.
        self.recent_entry_points.clear();
    }

    /// Total number of searches recorded since creation or the last `clear`.
    pub fn total_searches(&self) -> u64 {
        self.total_searches
    }
}
/// Output of one refinement analysis pass over collected edge statistics.
#[derive(Debug, Clone)]
pub struct RefinementSuggestions {
    // Candidate edges to add as `(a, b, cooccurrence_rate)` with `a < b`
    // (pairs come from orientation-normalized statistics keys).
    pub edges_to_add: Vec<(usize, usize, f32)>,
    // Traversed edges whose per-search rate fell below the removal threshold.
    pub edges_to_remove: Vec<(usize, usize)>,
    // Up to 10 most frequent recent search entry points, most frequent first.
    pub hot_entry_points: Vec<usize>,
}
/// Analyzes collected [`EdgeStats`] and proposes graph refinements.
pub struct RefinementAnalyzer {
    config: IncrementalConfig,
}

impl RefinementAnalyzer {
    /// Creates an analyzer with the given configuration.
    pub fn new(config: IncrementalConfig) -> Self {
        Self { config }
    }

    /// Produces refinement suggestions from `stats`.
    ///
    /// * Additions: node pairs whose co-occurrence rate reaches
    ///   `edge_add_threshold` and that are not already connected, capped at
    ///   `max_edges_per_pass`.
    /// * Removals: traversed edges below `edge_remove_threshold`, considered
    ///   only once at least `min_traversals` searches were recorded.
    /// * Hot entry points: the up-to-10 most frequent recent entry points.
    pub fn analyze(
        &self,
        stats: &EdgeStats,
        existing_edges: &HashSet<(usize, usize)>,
    ) -> RefinementSuggestions {
        let mut edges_to_add = Vec::new();
        let mut edges_to_remove = Vec::new();
        // Over-fetch so pairs skipped as already-existing don't starve the budget.
        let top_cooc = stats.top_cooccurrences(self.config.max_edges_per_pass * 2);
        for ((a, b), rate) in top_cooc {
            // Candidate pairs are normalized (a < b), but callers may store
            // existing edges in either orientation — check both. Previously
            // only (a, b) was checked, so an edge stored as (b, a) could be
            // re-suggested.
            let already_present =
                existing_edges.contains(&(a, b)) || existing_edges.contains(&(b, a));
            if rate >= self.config.edge_add_threshold && !already_present {
                edges_to_add.push((a, b, rate));
                if edges_to_add.len() >= self.config.max_edges_per_pass {
                    break;
                }
            }
        }
        if stats.total_searches() >= self.config.min_traversals as u64 {
            edges_to_remove = stats.underutilized_edges(self.config.edge_remove_threshold);
        }
        // Tally how often each entry point appears in the recent history.
        let mut entry_counts: HashMap<usize, usize> = HashMap::new();
        for &ep in &stats.recent_entry_points {
            *entry_counts.entry(ep).or_default() += 1;
        }
        let mut hot_entry_points: Vec<_> = entry_counts.into_iter().collect();
        hot_entry_points.sort_unstable_by(|a, b| b.1.cmp(&a.1));
        let hot_entry_points: Vec<_> = hot_entry_points
            .into_iter()
            .take(10)
            .map(|(node, _)| node)
            .collect();
        RefinementSuggestions {
            edges_to_add,
            edges_to_remove,
            hot_entry_points,
        }
    }
}
/// Per-node recency bookkeeping used to bias distances toward recent inserts.
///
/// Each insertion advances a logical clock; a node's bonus decays
/// exponentially with the number of insertions since its own.
#[derive(Debug)]
pub struct RecencyWeighting {
    /// Logical insertion timestamp per node id (0 if never recorded).
    insertion_times: Vec<u64>,
    /// Monotonic logical clock, advanced once per recorded insertion.
    current_time: u64,
    /// Exponential decay rate applied to node age.
    decay: f32,
    /// Bonus granted to a node of age zero.
    max_bonus: f32,
}

impl RecencyWeighting {
    /// Creates a weighting table with room for `initial_capacity` node ids.
    pub fn new(initial_capacity: usize, decay: f32, max_bonus: f32) -> Self {
        Self {
            insertion_times: vec![0; initial_capacity],
            current_time: 0,
            decay,
            max_bonus,
        }
    }

    /// Stamps `node` with the current logical time and advances the clock,
    /// growing the table when the node id is out of range.
    pub fn record_insertion(&mut self, node: usize) {
        if node >= self.insertion_times.len() {
            self.insertion_times.resize(node + 1, 0);
        }
        self.insertion_times[node] = self.current_time;
        self.current_time += 1;
    }

    /// Bonus in `[0, max_bonus]` that decays exponentially with the node's
    /// age (insertions since it was stamped). Unknown nodes are treated as
    /// stamped at time 0.
    pub fn recency_bonus(&self, node: usize) -> f32 {
        let stamped = self.insertion_times.get(node).copied().unwrap_or(0);
        match self.current_time.checked_sub(stamped) {
            Some(age) if age > 0 => self.max_bonus * (-self.decay * age as f32).exp(),
            // Age zero (or a stamp from the "future") earns the full bonus.
            _ => self.max_bonus,
        }
    }

    /// Shrinks `distance` by the node's recency bonus, clamped at zero.
    pub fn adjust_distance(&self, distance: f32, node: usize) -> f32 {
        let scaled = distance * (1.0 - self.recency_bonus(node));
        scaled.max(0.0)
    }
}
/// Detects distribution shift in incoming queries by comparing the running
/// centroid of the current window against the mean of past window centroids.
#[derive(Debug)]
pub struct DriftTracker {
    /// Running mean of queries in the current (partial) window.
    query_centroid: Vec<f32>,
    /// Centroids of up to the last 10 completed windows.
    historical_centroids: VecDeque<Vec<f32>>,
    /// Number of queries accumulated in the current window.
    window_count: usize,
    /// Queries per window before the centroid is archived.
    window_size: usize,
    /// Expected query dimensionality; mismatched queries are ignored.
    dimension: usize,
}

impl DriftTracker {
    /// Creates a tracker for `dimension`-dimensional queries, archiving a
    /// centroid every `window_size` queries.
    pub fn new(dimension: usize, window_size: usize) -> Self {
        Self {
            query_centroid: vec![0.0; dimension],
            historical_centroids: VecDeque::new(),
            window_count: 0,
            window_size,
            dimension,
        }
    }

    /// Folds `query` into the current window's running mean; queries of the
    /// wrong length are silently dropped. When the window fills, its centroid
    /// is archived (keeping at most 10) and the running mean restarts at zero.
    pub fn record_query(&mut self, query: &[f32]) {
        if query.len() != self.dimension {
            return;
        }
        self.window_count += 1;
        // Incremental mean: weight the new sample by 1/count.
        let alpha = 1.0 / self.window_count as f32;
        for (c, &q) in self.query_centroid.iter_mut().zip(query) {
            *c = (1.0 - alpha) * *c + alpha * q;
        }
        if self.window_count < self.window_size {
            return;
        }
        // Window complete: archive the finished centroid and start fresh.
        let finished =
            std::mem::replace(&mut self.query_centroid, vec![0.0; self.dimension]);
        self.historical_centroids.push_back(finished);
        while self.historical_centroids.len() > 10 {
            self.historical_centroids.pop_front();
        }
        self.window_count = 0;
    }

    /// Euclidean distance between the current running centroid and the mean
    /// of archived centroids; 0.0 with fewer than two archived windows.
    ///
    /// NOTE(review): right after a window completes the running centroid has
    /// been reset to zero, so drift is then measured from the origin —
    /// confirm this baseline is intended.
    pub fn drift_magnitude(&self) -> f32 {
        let n = self.historical_centroids.len();
        if n < 2 {
            return 0.0;
        }
        let mut sum_sq = 0.0f32;
        for (d, &cur) in self.query_centroid.iter().enumerate() {
            let avg = self
                .historical_centroids
                .iter()
                .map(|c| c[d])
                .sum::<f32>()
                / n as f32;
            let diff = cur - avg;
            sum_sq += diff * diff;
        }
        sum_sq.sqrt()
    }

    /// True when the current drift magnitude exceeds `threshold`.
    pub fn has_drift(&self, threshold: f32) -> bool {
        threshold < self.drift_magnitude()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // improvement_rate is improvements / traversals: 2 improved of 3 ~= 0.666.
    #[test]
    fn test_edge_stats_basic() {
        let mut stats = EdgeStats::new();
        stats.record_traversal(0, 1, true);
        stats.record_traversal(0, 1, false);
        stats.record_traversal(0, 1, true);
        assert!((stats.improvement_rate(0, 1) - 0.666).abs() < 0.01);
    }

    // Co-occurrence rates are normalized by recorded searches:
    // (0, 1) appears in both searches, (0, 2) only in the first.
    #[test]
    fn test_cooccurrence() {
        let mut stats = EdgeStats::new();
        stats.record_covisit(&[0, 1, 2]);
        stats.record_search(0);
        stats.record_covisit(&[0, 1, 3]);
        stats.record_search(0);
        assert_eq!(stats.cooccurrence_rate(0, 1), 1.0);
        assert_eq!(stats.cooccurrence_rate(0, 2), 0.5);
    }

    // Later insertions get a larger recency bonus and therefore a smaller
    // adjusted distance.
    #[test]
    fn test_recency_weighting() {
        let mut rw = RecencyWeighting::new(10, 0.1, 0.2);
        rw.record_insertion(0);
        rw.record_insertion(1);
        rw.record_insertion(2);
        let bonus_0 = rw.recency_bonus(0);
        let bonus_2 = rw.recency_bonus(2);
        assert!(bonus_2 > bonus_0);
        let dist = 1.0;
        let adj_0 = rw.adjust_distance(dist, 0);
        let adj_2 = rw.adjust_distance(dist, 2);
        assert!(adj_2 < adj_0);
    }

    // A pair co-occurring in every search (rate 1.0, above the default 0.3
    // add threshold) and absent from the existing edges must be suggested.
    #[test]
    fn test_refinement_analyzer() {
        let config = IncrementalConfig::default();
        let analyzer = RefinementAnalyzer::new(config);
        let mut stats = EdgeStats::new();
        for _ in 0..100 {
            stats.record_covisit(&[0, 5]);
            stats.record_search(0);
        }
        let existing: HashSet<_> = [(0, 1), (1, 2)].into_iter().collect();
        let suggestions = analyzer.analyze(&stats, &existing);
        assert!(suggestions
            .edges_to_add
            .iter()
            .any(|&(a, b, _)| { (a == 0 && b == 5) || (a == 5 && b == 0) }));
    }

    // Two completed windows with different centroids should register
    // non-zero drift.
    #[test]
    fn test_drift_tracker() {
        let mut tracker = DriftTracker::new(3, 10);
        for _ in 0..10 {
            tracker.record_query(&[0.0, 0.0, 0.0]);
        }
        for _ in 0..10 {
            tracker.record_query(&[1.0, 1.0, 1.0]);
        }
        assert!(tracker.drift_magnitude() > 0.0);
    }
}